diff --git a/.gitignore b/.gitignore index 6e09e097456e9..a32d408493d94 100644 --- a/.gitignore +++ b/.gitignore @@ -22,6 +22,7 @@ /lib/ R-unit-tests.log R/unit-tests.out +R/cran-check.out build/*.jar build/apache-maven* build/scala* @@ -72,7 +73,13 @@ metastore/ metastore_db/ sql/hive-thriftserver/test_warehouses warehouse/ +spark-warehouse/ # For R session data .RData .RHistory +.Rhistory +*.Rproj +*.Rproj.* + +.Rproj.user diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index f10d7e277eea3..1a8206abe3838 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -6,7 +6,7 @@ It lists steps that are required before creating a PR. In particular, consider: - Is the change important and ready enough to ask the community to spend time reviewing? - Have you searched for existing, related JIRAs and pull requests? -- Is this a new feature that can stand alone as a package on http://spark-packages.org ? +- Is this a new feature that can stand alone as a [third party project](https://cwiki.apache.org/confluence/display/SPARK/Third+Party+Projects) ? - Is the change being proposed clearly explained and motivated? When you contribute code, you affirm that the contribution is your original work and that you diff --git a/LICENSE b/LICENSE index 9714b3b1e4d17..d68609cc28733 100644 --- a/LICENSE +++ b/LICENSE @@ -263,7 +263,7 @@ The text of each license is also included at licenses/LICENSE-[project].txt. (New BSD license) Protocol Buffer Java API (org.spark-project.protobuf:protobuf-java:2.4.1-shaded - http://code.google.com/p/protobuf) (The BSD License) Fortran to Java ARPACK (net.sourceforge.f2j:arpack_combined_all:0.1 - http://f2j.sourceforge.net) (The BSD License) xmlenc Library (xmlenc:xmlenc:0.52 - http://xmlenc.sourceforge.net) - (The New BSD License) Py4J (net.sf.py4j:py4j:0.9.2 - http://py4j.sourceforge.net/) + (The New BSD License) Py4J (net.sf.py4j:py4j:0.10.3 - http://py4j.sourceforge.net/) (Two-clause BSD-style license) JUnit-Interface (com.novocode:junit-interface:0.10 - http://github.com/szeiger/junit-interface/) (BSD licence) sbt and sbt-launch-lib.bash (BSD 3 Clause) d3.min.js (https://github.com/mbostock/d3/blob/master/LICENSE) @@ -296,3 +296,4 @@ The text of each license is also included at licenses/LICENSE-[project].txt. (MIT License) blockUI (http://jquery.malsup.com/block/) (MIT License) RowsGroup (http://datatables.net/license/mit) (MIT License) jsonFormatter (http://www.jqueryscript.net/other/jQuery-Plugin-For-Pretty-JSON-Formatting-jsonFormatter.html) + (MIT License) modernizr (https://github.com/Modernizr/Modernizr/blob/master/LICENSE) diff --git a/NOTICE b/NOTICE index 2a6fe237dcbea..69b513ea3ba3c 100644 --- a/NOTICE +++ b/NOTICE @@ -1,5 +1,5 @@ Apache Spark -Copyright 2014 The Apache Software Foundation. +Copyright 2014 and onwards The Apache Software Foundation. This product includes software developed at The Apache Software Foundation (http://www.apache.org/). @@ -12,7 +12,9 @@ Common Development and Distribution License 1.0 The following components are provided under the Common Development and Distribution License 1.0. See project link for details. 
(CDDL 1.0) Glassfish Jasper (org.mortbay.jetty:jsp-2.1:6.1.14 - http://jetty.mortbay.org/project/modules/jsp-2.1) + (CDDL 1.0) JAX-RS (https://jax-rs-spec.java.net/) (CDDL 1.0) Servlet Specification 2.5 API (org.mortbay.jetty:servlet-api-2.5:6.1.14 - http://jetty.mortbay.org/project/modules/servlet-api-2.5) + (CDDL 1.0) (GPL2 w/ CPE) javax.annotation API (https://glassfish.java.net/nonav/public/CDDL+GPL.html) (COMMON DEVELOPMENT AND DISTRIBUTION LICENSE (CDDL) Version 1.0) (GNU General Public Library) Streaming API for XML (javax.xml.stream:stax-api:1.0-2 - no url defined) (Common Development and Distribution License (CDDL) v1.0) JavaBeans Activation Framework (JAF) (javax.activation:activation:1.1 - http://java.sun.com/products/javabeans/jaf/index.jsp) @@ -22,15 +24,10 @@ Common Development and Distribution License 1.1 The following components are provided under the Common Development and Distribution License 1.1. See project link for details. + (CDDL 1.1) (GPL2 w/ CPE) org.glassfish.hk2 (https://hk2.java.net) (CDDL 1.1) (GPL2 w/ CPE) JAXB API bundle for GlassFish V3 (javax.xml.bind:jaxb-api:2.2.2 - https://jaxb.dev.java.net/) (CDDL 1.1) (GPL2 w/ CPE) JAXB RI (com.sun.xml.bind:jaxb-impl:2.2.3-1 - http://jaxb.java.net/) - (CDDL 1.1) (GPL2 w/ CPE) jersey-core (com.sun.jersey:jersey-core:1.8 - https://jersey.dev.java.net/jersey-core/) - (CDDL 1.1) (GPL2 w/ CPE) jersey-core (com.sun.jersey:jersey-core:1.9 - https://jersey.java.net/jersey-core/) - (CDDL 1.1) (GPL2 w/ CPE) jersey-guice (com.sun.jersey.contribs:jersey-guice:1.9 - https://jersey.java.net/jersey-contribs/jersey-guice/) - (CDDL 1.1) (GPL2 w/ CPE) jersey-json (com.sun.jersey:jersey-json:1.8 - https://jersey.dev.java.net/jersey-json/) - (CDDL 1.1) (GPL2 w/ CPE) jersey-json (com.sun.jersey:jersey-json:1.9 - https://jersey.java.net/jersey-json/) - (CDDL 1.1) (GPL2 w/ CPE) jersey-server (com.sun.jersey:jersey-server:1.8 - https://jersey.dev.java.net/jersey-server/) - (CDDL 1.1) (GPL2 w/ CPE) jersey-server (com.sun.jersey:jersey-server:1.9 - https://jersey.java.net/jersey-server/) + (CDDL 1.1) (GPL2 w/ CPE) Jersey 2 (https://jersey.java.net) ======================================================================== Common Public License 1.0 diff --git a/R/.gitignore b/R/.gitignore index 9a5889ba28b2a..c98504ab07781 100644 --- a/R/.gitignore +++ b/R/.gitignore @@ -4,3 +4,5 @@ lib pkg/man pkg/html +SparkR.Rcheck/ +SparkR_*.tar.gz diff --git a/R/DOCUMENTATION.md b/R/DOCUMENTATION.md index 931d01549b265..7314a1fcccda9 100644 --- a/R/DOCUMENTATION.md +++ b/R/DOCUMENTATION.md @@ -1,12 +1,12 @@ # SparkR Documentation -SparkR documentation is generated using in-source comments annotated using using -`roxygen2`. After making changes to the documentation, to generate man pages, +SparkR documentation is generated by using in-source comments and annotated by using +[`roxygen2`](https://cran.r-project.org/web/packages/roxygen2/index.html). After making changes to the documentation and generating man pages, you can run the following from an R console in the SparkR home directory - - library(devtools) - devtools::document(pkg="./pkg", roclets=c("rd")) - +```R +library(devtools) +devtools::document(pkg="./pkg", roclets=c("rd")) +``` You can verify if your changes are good by running R CMD check pkg/ diff --git a/R/README.md b/R/README.md index 810bfc14e977e..932d5272d0b4f 100644 --- a/R/README.md +++ b/R/README.md @@ -1,12 +1,13 @@ # R on Spark SparkR is an R package that provides a light-weight frontend to use Spark from R. 
+ ### Installing sparkR Libraries of sparkR need to be created in `$SPARK_HOME/R/lib`. This can be done by running the script `$SPARK_HOME/R/install-dev.sh`. By default the above script uses the system wide installation of R. However, this can be changed to any user installed location of R by setting the environment variable `R_HOME` the full path of the base directory where R is installed, before running install-dev.sh script. Example: -``` +```bash # where /home/username/R is where R is installed and /home/username/R/bin contains the files R and RScript export R_HOME=/home/username/R ./install-dev.sh @@ -17,8 +18,9 @@ export R_HOME=/home/username/R #### Build Spark Build Spark with [Maven](http://spark.apache.org/docs/latest/building-spark.html#building-with-buildmvn) and include the `-Psparkr` profile to build the R package. For example to use the default Hadoop versions you can run -``` - build/mvn -DskipTests -Psparkr package + +```bash +build/mvn -DskipTests -Psparkr package ``` #### Running sparkR @@ -37,8 +39,8 @@ To set other options like driver memory, executor memory etc. you can pass in th #### Using SparkR from RStudio -If you wish to use SparkR from RStudio or other R frontends you will need to set some environment variables which point SparkR to your Spark installation. For example -``` +If you wish to use SparkR from RStudio or other R frontends you will need to set some environment variables which point SparkR to your Spark installation. For example +```R # Set this to where Spark is installed Sys.setenv(SPARK_HOME="/Users/username/spark") # This line loads SparkR from the installed directory @@ -55,23 +57,25 @@ Once you have made your changes, please include unit tests for them and run exis #### Generating documentation -The SparkR documentation (Rd files and HTML files) are not a part of the source repository. To generate them you can run the script `R/create-docs.sh`. This script uses `devtools` and `knitr` to generate the docs and these packages need to be installed on the machine before using the script. +The SparkR documentation (Rd files and HTML files) are not a part of the source repository. To generate them you can run the script `R/create-docs.sh`. This script uses `devtools` and `knitr` to generate the docs and these packages need to be installed on the machine before using the script. Also, you may need to install these [prerequisites](https://github.com/apache/spark/tree/master/docs#prerequisites). See also, `R/DOCUMENTATION.md` ### Examples, Unit tests SparkR comes with several sample programs in the `examples/src/main/r` directory. To run one of them, use `./bin/spark-submit `. For example: - - ./bin/spark-submit examples/src/main/r/dataframe.R - -You can also run the unit-tests for SparkR by running (you need to install the [testthat](http://cran.r-project.org/web/packages/testthat/index.html) package first): - - R -e 'install.packages("testthat", repos="http://cran.us.r-project.org")' - ./R/run-tests.sh +```bash +./bin/spark-submit examples/src/main/r/dataframe.R +``` +You can also run the unit tests for SparkR by running. You need to install the [testthat](http://cran.r-project.org/web/packages/testthat/index.html) package first: +```bash +R -e 'install.packages("testthat", repos="http://cran.us.r-project.org")' +./R/run-tests.sh +``` ### Running on YARN + The `./bin/spark-submit` can also be used to submit jobs to YARN clusters. You will need to set YARN conf dir before doing so. 
For example on CDH you can run -``` +```bash export YARN_CONF_DIR=/etc/hadoop/conf ./bin/spark-submit --master yarn examples/src/main/r/dataframe.R ``` diff --git a/R/WINDOWS.md b/R/WINDOWS.md index 3f889c0ca3d1e..f67a1c51d1785 100644 --- a/R/WINDOWS.md +++ b/R/WINDOWS.md @@ -11,3 +11,23 @@ include Rtools and R in `PATH`. directory in Maven in `PATH`. 4. Set `MAVEN_OPTS` as described in [Building Spark](http://spark.apache.org/docs/latest/building-spark.html). 5. Open a command shell (`cmd`) in the Spark directory and run `mvn -DskipTests -Psparkr package` + +## Unit tests + +To run the SparkR unit tests on Windows, the following steps are required —assuming you are in the Spark root directory and do not have Apache Hadoop installed already: + +1. Create a folder to download Hadoop related files for Windows. For example, `cd ..` and `mkdir hadoop`. + +2. Download the relevant Hadoop bin package from [steveloughran/winutils](https://github.com/steveloughran/winutils). While these are not official ASF artifacts, they are built from the ASF release git hashes by a Hadoop PMC member on a dedicated Windows VM. For further reading, consult [Windows Problems on the Hadoop wiki](https://wiki.apache.org/hadoop/WindowsProblems). + +3. Install the files into `hadoop\bin`; make sure that `winutils.exe` and `hadoop.dll` are present. + +4. Set the environment variable `HADOOP_HOME` to the full path to the newly created `hadoop` directory. + +5. Run unit tests for SparkR by running the command below. You need to install the [testthat](http://cran.r-project.org/web/packages/testthat/index.html) package first: + + ``` + R -e "install.packages('testthat', repos='http://cran.us.r-project.org')" + .\bin\spark-submit2.cmd --conf spark.hadoop.fs.default.name="file:///" R\pkg\tests\run-all.R + ``` + diff --git a/R/check-cran.sh b/R/check-cran.sh new file mode 100755 index 0000000000000..bb331466ae931 --- /dev/null +++ b/R/check-cran.sh @@ -0,0 +1,64 @@ +#!/bin/bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +set -o pipefail +set -e + +FWDIR="$(cd `dirname $0`; pwd)" +pushd $FWDIR > /dev/null + +if [ ! -z "$R_HOME" ] + then + R_SCRIPT_PATH="$R_HOME/bin" + else + # if system wide R_HOME is not found, then exit + if [ ! `command -v R` ]; then + echo "Cannot find 'R_HOME'. Please specify 'R_HOME' or make sure R is properly installed." + exit 1 + fi + R_SCRIPT_PATH="$(dirname $(which R))" +fi +echo "USING R_HOME = $R_HOME" + +# Build the latest docs +$FWDIR/create-docs.sh + +# Build a zip file containing the source package +"$R_SCRIPT_PATH/"R CMD build $FWDIR/pkg + +# Run check as-cran. 
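+# The lines that follow read the package version from pkg/DESCRIPTION and
+# assemble the R CMD check flags: setting NO_TESTS or NO_MANUAL in the
+# environment appends --no-tests or --no-manual respectively.
+#
+# A usage sketch (assumed invocation, not part of the script itself):
+#   ./R/check-cran.sh                          # full --as-cran check
+#   NO_TESTS=1 NO_MANUAL=1 ./R/check-cran.sh   # skip tests and the PDF manual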
+VERSION=`grep Version $FWDIR/pkg/DESCRIPTION | awk '{print $NF}'` + +CRAN_CHECK_OPTIONS="--as-cran" + +if [ -n "$NO_TESTS" ] +then + CRAN_CHECK_OPTIONS=$CRAN_CHECK_OPTIONS" --no-tests" +fi + +if [ -n "$NO_MANUAL" ] +then + CRAN_CHECK_OPTIONS=$CRAN_CHECK_OPTIONS" --no-manual" +fi + +echo "Running CRAN check with $CRAN_CHECK_OPTIONS options" + +"$R_SCRIPT_PATH/"R CMD check $CRAN_CHECK_OPTIONS SparkR_"$VERSION".tar.gz + +popd > /dev/null diff --git a/R/create-docs.sh b/R/create-docs.sh index d2ae160b50021..69ffc5f678c36 100755 --- a/R/create-docs.sh +++ b/R/create-docs.sh @@ -17,17 +17,26 @@ # limitations under the License. # -# Script to create API docs for SparkR -# This requires `devtools` and `knitr` to be installed on the machine. +# Script to create API docs and vignettes for SparkR +# This requires `devtools`, `knitr` and `rmarkdown` to be installed on the machine. # After running this script the html docs can be found in # $SPARK_HOME/R/pkg/html +# The vignettes can be found in +# $SPARK_HOME/R/pkg/vignettes/sparkr_vignettes.html set -o pipefail set -e # Figure out where the script is export FWDIR="$(cd "`dirname "$0"`"; pwd)" +export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + +# Required for setting SPARK_SCALA_VERSION +. "${SPARK_HOME}"/bin/load-spark-env.sh + +echo "Using Scala $SPARK_SCALA_VERSION" + pushd $FWDIR # Install the package (this will also generate the Rd files) @@ -43,4 +52,21 @@ Rscript -e 'libDir <- "../../lib"; library(SparkR, lib.loc=libDir); library(knit popd +# Find Spark jars. +if [ -f "${SPARK_HOME}/RELEASE" ]; then + SPARK_JARS_DIR="${SPARK_HOME}/jars" +else + SPARK_JARS_DIR="${SPARK_HOME}/assembly/target/scala-$SPARK_SCALA_VERSION/jars" +fi + +# Only create vignettes if Spark JARs exist +if [ -d "$SPARK_JARS_DIR" ]; then + # render creates SparkR vignettes + Rscript -e 'library(rmarkdown); paths <- .libPaths(); .libPaths(c("lib", paths)); Sys.setenv(SPARK_HOME=tools::file_path_as_absolute("..")); render("pkg/vignettes/sparkr-vignettes.Rmd"); .libPaths(paths)' + + find pkg/vignettes/. -not -name '.' -not -name '*.Rmd' -not -name '*.md' -not -name '*.pdf' -not -name '*.html' -delete +else + echo "Skipping R vignettes as Spark JARs not found in $SPARK_HOME" +fi + popd diff --git a/R/install-dev.sh b/R/install-dev.sh index befd413c4cd26..ada6303a722b7 100755 --- a/R/install-dev.sh +++ b/R/install-dev.sh @@ -38,7 +38,12 @@ pushd $FWDIR > /dev/null if [ ! -z "$R_HOME" ] then R_SCRIPT_PATH="$R_HOME/bin" - else + else + # if system wide R_HOME is not found, then exit + if [ ! `command -v R` ]; then + echo "Cannot find 'R_HOME'. Please specify 'R_HOME' or make sure R is properly installed." 
+ exit 1 + fi R_SCRIPT_PATH="$(dirname $(which R))" fi echo "USING R_HOME = $R_HOME" diff --git a/R/pkg/.Rbuildignore b/R/pkg/.Rbuildignore new file mode 100644 index 0000000000000..544d203a6dce6 --- /dev/null +++ b/R/pkg/.Rbuildignore @@ -0,0 +1,5 @@ +^.*\.Rproj$ +^\.Rproj\.user$ +^\.lintr$ +^src-native$ +^html$ diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 7179438efc1d9..0b01ca8dc8b20 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -1,20 +1,25 @@ Package: SparkR Type: Package -Title: R frontend for Spark -Version: 2.0.0 -Date: 2013-09-09 -Author: The Apache Software Foundation -Maintainer: Shivaram Venkataraman -Imports: - methods +Title: R Frontend for Apache Spark +Version: 2.0.3 +Date: 2016-08-27 +Authors@R: c(person("Shivaram", "Venkataraman", role = c("aut", "cre"), + email = "shivaram@cs.berkeley.edu"), + person("Xiangrui", "Meng", role = "aut", + email = "meng@databricks.com"), + person("Felix", "Cheung", role = "aut", + email = "felixcheung@apache.org"), + person(family = "The Apache Software Foundation", role = c("aut", "cph"))) +URL: http://www.apache.org/ http://spark.apache.org/ +BugReports: https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark#ContributingtoSpark-ContributingBugReports Depends: R (>= 3.0), - methods, + methods Suggests: testthat, e1071, survival -Description: R frontend for Spark +Description: The SparkR package provides an R frontend for Apache Spark. License: Apache License (== 2.0) Collate: 'schema.R' @@ -26,16 +31,20 @@ Collate: 'pairRDD.R' 'DataFrame.R' 'SQLContext.R' + 'WindowSpec.R' 'backend.R' 'broadcast.R' 'client.R' 'context.R' 'deserialize.R' 'functions.R' + 'install.R' + 'jvm.R' 'mllib.R' 'serialize.R' 'sparkR.R' 'stats.R' 'types.R' 'utils.R' + 'window.R' RoxygenNote: 5.0.1 diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 73f7c595f4437..62c33a7323ede 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -1,15 +1,32 @@ # Imports from base R -importFrom(methods, setGeneric, setMethod, setOldClass) +# Do not include stats:: "rpois", "runif" - causes error at runtime +importFrom("methods", "setGeneric", "setMethod", "setOldClass") +importFrom("methods", "is", "new", "signature", "show") +importFrom("stats", "gaussian", "setNames") +importFrom("utils", "download.file", "object.size", "packageVersion", "untar") # Disable native libraries till we figure out how to package it # See SPARKR-7839 #useDynLib(SparkR, stringHashCode) # S3 methods exported +export("sparkR.session") export("sparkR.init") export("sparkR.stop") +export("sparkR.session.stop") +export("sparkR.conf") +export("sparkR.version") export("print.jobj") +export("sparkR.newJObject") +export("sparkR.callJMethod") +export("sparkR.callJStatic") + +export("install.spark") + +export("sparkRSQL.init", + "sparkRHive.init") + # MLlib integration exportMethods("glm", "spark.glm", @@ -45,8 +62,10 @@ exportMethods("arrange", "corr", "covar_samp", "covar_pop", + "createOrReplaceTempView", "crosstab", "dapply", + "dapplyCollect", "describe", "dim", "distinct", @@ -60,6 +79,8 @@ exportMethods("arrange", "filter", "first", "freqItems", + "gapply", + "gapplyCollect", "group_by", "groupBy", "head", @@ -78,6 +99,7 @@ exportMethods("arrange", "orderBy", "persist", "printSchema", + "randomSplit", "rbind", "registerTempTable", "rename", @@ -98,6 +120,7 @@ exportMethods("arrange", "summary", "take", "transform", + "union", "unionAll", "unique", "unpersist", @@ -108,6 +131,7 @@ exportMethods("arrange", "write.df", "write.jdbc", "write.json", + "write.orc", 
"write.parquet", "write.text", "write.ml") @@ -184,6 +208,8 @@ exportMethods("%in%", "isNaN", "isNotNull", "isNull", + "is.nan", + "isnan", "kurtosis", "lag", "last", @@ -207,6 +233,7 @@ exportMethods("%in%", "mean", "min", "minute", + "monotonically_increasing_id", "month", "months_between", "n", @@ -216,8 +243,10 @@ exportMethods("%in%", "next_day", "ntile", "otherwise", + "over", "percent_rank", "pmod", + "posexplode", "quarter", "rand", "randn", @@ -246,6 +275,7 @@ exportMethods("%in%", "skewness", "sort_array", "soundex", + "spark_partition_id", "stddev", "stddev_pop", "stddev_samp", @@ -279,9 +309,7 @@ exportMethods("%in%", exportClasses("GroupedData") exportMethods("agg") - -export("sparkRSQL.init", - "sparkRHive.init") +exportMethods("pivot") export("as.DataFrame", "cacheTable", @@ -289,12 +317,14 @@ export("as.DataFrame", "createDataFrame", "createExternalTable", "dropTempTable", + "dropTempView", "jsonFile", "loadDF", "parquetFile", "read.df", "read.jdbc", "read.json", + "read.orc", "read.parquet", "read.text", "spark.lapply", @@ -315,3 +345,21 @@ export("structField", "structType.jobj", "structType.structField", "print.structType") + +exportClasses("WindowSpec") + +export("partitionBy", + "rowsBetween", + "rangeBetween") + +export("windowPartitionBy", + "windowOrderBy") + +S3method(print, jobj) +S3method(print, structField) +S3method(print, structType) +S3method(print, summary.GeneralizedLinearRegressionModel) +S3method(structField, character) +S3method(structField, jobj) +S3method(structType, jobj) +S3method(structType, structField) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 9e30fa0dbf26a..d5c5486ce2dc9 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -23,9 +23,11 @@ NULL setOldClass("jobj") setOldClass("structType") -#' @title S4 class that represents a SparkDataFrame -#' @description DataFrames can be created using functions like \link{createDataFrame}, -#' \link{read.json}, \link{table} etc. +#' S4 class that represents a SparkDataFrame +#' +#' SparkDataFrames can be created using functions like \link{createDataFrame}, +#' \link{read.json}, \link{table} etc. 
+#' #' @family SparkDataFrame functions #' @rdname SparkDataFrame #' @docType class @@ -37,10 +39,10 @@ setOldClass("structType") #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) -#' df <- createDataFrame(sqlContext, faithful) +#' sparkR.session() +#' df <- createDataFrame(faithful) #'} +#' @note SparkDataFrame since 2.0.0 setClass("SparkDataFrame", slots = list(env = "environment", sdf = "jobj")) @@ -53,10 +55,10 @@ setMethod("initialize", "SparkDataFrame", function(.Object, sdf, isCached) { .Object }) -#' @rdname SparkDataFrame #' @export #' @param sdf A Java object reference to the backing Scala DataFrame #' @param isCached TRUE if the SparkDataFrame is cached +#' @noRd dataFrame <- function(sdf, isCached = FALSE) { new("SparkDataFrame", sdf, isCached) } @@ -72,15 +74,16 @@ dataFrame <- function(sdf, isCached = FALSE) { #' @family SparkDataFrame functions #' @rdname printSchema #' @name printSchema +#' @aliases printSchema,SparkDataFrame-method #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) +#' sparkR.session() #' path <- "path/to/file.json" -#' df <- read.json(sqlContext, path) +#' df <- read.json(path) #' printSchema(df) #'} +#' @note printSchema since 1.4.0 setMethod("printSchema", signature(x = "SparkDataFrame"), function(x) { @@ -97,15 +100,16 @@ setMethod("printSchema", #' @family SparkDataFrame functions #' @rdname schema #' @name schema +#' @aliases schema,SparkDataFrame-method #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) +#' sparkR.session() #' path <- "path/to/file.json" -#' df <- read.json(sqlContext, path) +#' df <- read.json(path) #' dfSchema <- schema(df) #'} +#' @note schema since 1.4.0 setMethod("schema", signature(x = "SparkDataFrame"), function(x) { @@ -116,20 +120,22 @@ setMethod("schema", #' #' Print the logical and physical Catalyst plans to the console for debugging. #' -#' @param x A SparkDataFrame -#' @param extended Logical. If extended is False, explain() only prints the physical plan. +#' @param x a SparkDataFrame. +#' @param extended Logical. If extended is FALSE, explain() only prints the physical plan. +#' @param ... further arguments to be passed to or from other methods. #' @family SparkDataFrame functions +#' @aliases explain,SparkDataFrame-method #' @rdname explain #' @name explain #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) +#' sparkR.session() #' path <- "path/to/file.json" -#' df <- read.json(sqlContext, path) +#' df <- read.json(path) #' explain(df, TRUE) #'} +#' @note explain since 1.4.0 setMethod("explain", signature(x = "SparkDataFrame"), function(x, extended = FALSE) { @@ -144,7 +150,7 @@ setMethod("explain", #' isLocal #' -#' Returns True if the `collect` and `take` methods can be run locally +#' Returns True if the \code{collect} and \code{take} methods can be run locally #' (without any Spark executors). 
#' #' @param x A SparkDataFrame @@ -152,15 +158,16 @@ setMethod("explain", #' @family SparkDataFrame functions #' @rdname isLocal #' @name isLocal +#' @aliases isLocal,SparkDataFrame-method #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) +#' sparkR.session() #' path <- "path/to/file.json" -#' df <- read.json(sqlContext, path) +#' df <- read.json(path) #' isLocal(df) #'} +#' @note isLocal since 1.4.0 setMethod("isLocal", signature(x = "SparkDataFrame"), function(x) { @@ -171,21 +178,24 @@ setMethod("isLocal", #' #' Print the first numRows rows of a SparkDataFrame #' -#' @param x A SparkDataFrame -#' @param numRows The number of rows to print. Defaults to 20. -#' +#' @param x a SparkDataFrame. +#' @param numRows the number of rows to print. Defaults to 20. +#' @param truncate whether truncate long strings. If \code{TRUE}, strings more than +#' 20 characters will be truncated and all cells will be aligned right. +#' @param ... further arguments to be passed to or from other methods. #' @family SparkDataFrame functions +#' @aliases showDF,SparkDataFrame-method #' @rdname showDF #' @name showDF #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) +#' sparkR.session() #' path <- "path/to/file.json" -#' df <- read.json(sqlContext, path) +#' df <- read.json(path) #' showDF(df) #'} +#' @note showDF since 1.4.0 setMethod("showDF", signature(x = "SparkDataFrame"), function(x, numRows = 20, truncate = TRUE) { @@ -195,22 +205,23 @@ setMethod("showDF", #' show #' -#' Print the SparkDataFrame column names and types +#' Print class and type information of a Spark object. #' -#' @param x A SparkDataFrame +#' @param object a Spark object. Can be a SparkDataFrame, Column, GroupedData, WindowSpec. #' #' @family SparkDataFrame functions #' @rdname show +#' @aliases show,SparkDataFrame-method #' @name show #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) +#' sparkR.session() #' path <- "path/to/file.json" -#' df <- read.json(sqlContext, path) -#' df +#' df <- read.json(path) +#' show(df) #'} +#' @note show(SparkDataFrame) since 1.4.0 setMethod("show", "SparkDataFrame", function(object) { cols <- lapply(dtypes(object), function(l) { @@ -229,15 +240,16 @@ setMethod("show", "SparkDataFrame", #' @family SparkDataFrame functions #' @rdname dtypes #' @name dtypes +#' @aliases dtypes,SparkDataFrame-method #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) +#' sparkR.session() #' path <- "path/to/file.json" -#' df <- read.json(sqlContext, path) +#' df <- read.json(path) #' dtypes(df) #'} +#' @note dtypes since 1.4.0 setMethod("dtypes", signature(x = "SparkDataFrame"), function(x) { @@ -246,26 +258,26 @@ setMethod("dtypes", }) }) -#' Column names +#' Column Names of SparkDataFrame #' -#' Return all column names as a list +#' Return all column names as a list. #' -#' @param x A SparkDataFrame +#' @param x a SparkDataFrame. 
#' #' @family SparkDataFrame functions #' @rdname columns #' @name columns - +#' @aliases columns,SparkDataFrame-method #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) +#' sparkR.session() #' path <- "path/to/file.json" -#' df <- read.json(sqlContext, path) +#' df <- read.json(path) #' columns(df) #' colnames(df) #'} +#' @note columns since 1.4.0 setMethod("columns", signature(x = "SparkDataFrame"), function(x) { @@ -276,6 +288,8 @@ setMethod("columns", #' @rdname columns #' @name names +#' @aliases names,SparkDataFrame-method +#' @note names since 1.5.0 setMethod("names", signature(x = "SparkDataFrame"), function(x) { @@ -283,7 +297,9 @@ setMethod("names", }) #' @rdname columns +#' @aliases names<-,SparkDataFrame-method #' @name names<- +#' @note names<- since 1.5.0 setMethod("names<-", signature(x = "SparkDataFrame"), function(x, value) { @@ -294,15 +310,21 @@ setMethod("names<-", }) #' @rdname columns +#' @aliases colnames,SparkDataFrame-method #' @name colnames +#' @note colnames since 1.6.0 setMethod("colnames", signature(x = "SparkDataFrame"), function(x) { columns(x) }) +#' @param value a character vector. Must have the same length as the number +#' of columns in the SparkDataFrame. #' @rdname columns +#' @aliases colnames<-,SparkDataFrame-method #' @name colnames<- +#' @note colnames<- since 1.6.0 setMethod("colnames<-", signature(x = "SparkDataFrame"), function(x, value) { @@ -323,7 +345,7 @@ setMethod("colnames<-", # Check if the column names have . in it if (any(regexec(".", value, fixed = TRUE)[[1]][1] != -1)) { - stop("Colum names cannot contain the '.' symbol.") + stop("Column names cannot contain the '.' symbol.") } sdf <- callJMethod(x@sdf, "toDF", as.list(value)) @@ -337,14 +359,16 @@ setMethod("colnames<-", #' @param x A SparkDataFrame #' @return value A character vector with the column types of the given SparkDataFrame #' @rdname coltypes +#' @aliases coltypes,SparkDataFrame-method #' @name coltypes #' @family SparkDataFrame functions #' @export #' @examples #'\dontrun{ -#' irisDF <- createDataFrame(sqlContext, iris) -#' coltypes(irisDF) +#' irisDF <- createDataFrame(iris) +#' coltypes(irisDF) # get column types #'} +#' @note coltypes since 1.6.0 setMethod("coltypes", signature(x = "SparkDataFrame"), function(x) { @@ -366,7 +390,11 @@ setMethod("coltypes", } if (is.null(type)) { - stop(paste("Unsupported data type: ", x)) + specialtype <- specialtypeshandle(x) + if (is.null(specialtype)) { + stop(paste("Unsupported data type: ", x)) + } + type <- PRIMITIVE_TYPES[[specialtype]] } } type @@ -385,22 +413,22 @@ setMethod("coltypes", #' #' Set the column types of a SparkDataFrame. #' -#' @param x A SparkDataFrame #' @param value A character vector with the target column types for the given #' SparkDataFrame. Column types can be one of integer, numeric/double, character, logical, or NA #' to keep that column as-is. 
#' @rdname coltypes #' @name coltypes<- +#' @aliases coltypes<-,SparkDataFrame,character-method #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) +#' sparkR.session() #' path <- "path/to/file.json" -#' df <- read.json(sqlContext, path) -#' coltypes(df) <- c("character", "integer") -#' coltypes(df) <- c(NA, "numeric") +#' df <- read.json(path) +#' coltypes(df) <- c("character", "integer") # set column types +#' coltypes(df) <- c(NA, "numeric") # set column types #'} +#' @note coltypes<- since 1.6.0 setMethod("coltypes<-", signature(x = "SparkDataFrame", value = "character"), function(x, value) { @@ -428,54 +456,86 @@ setMethod("coltypes<-", dataFrame(nx@sdf) }) -#' Register Temporary Table +#' Creates a temporary view using the given name. +#' +#' Creates a new temporary view using a SparkDataFrame in the Spark Session. If a +#' temporary view with the same name already exists, replaces it. +#' +#' @param x A SparkDataFrame +#' @param viewName A character vector containing the name of the table #' -#' Registers a SparkDataFrame as a Temporary Table in the SQLContext +#' @family SparkDataFrame functions +#' @rdname createOrReplaceTempView +#' @name createOrReplaceTempView +#' @aliases createOrReplaceTempView,SparkDataFrame,character-method +#' @export +#' @examples +#'\dontrun{ +#' sparkR.session() +#' path <- "path/to/file.json" +#' df <- read.json(path) +#' createOrReplaceTempView(df, "json_df") +#' new_df <- sql("SELECT * FROM json_df") +#'} +#' @note createOrReplaceTempView since 2.0.0 +setMethod("createOrReplaceTempView", + signature(x = "SparkDataFrame", viewName = "character"), + function(x, viewName) { + invisible(callJMethod(x@sdf, "createOrReplaceTempView", viewName)) + }) + +#' (Deprecated) Register Temporary Table #' +#' Registers a SparkDataFrame as a Temporary Table in the SparkSession #' @param x A SparkDataFrame #' @param tableName A character vector containing the name of the table #' #' @family SparkDataFrame functions -#' @rdname registerTempTable +#' @seealso \link{createOrReplaceTempView} +#' @rdname registerTempTable-deprecated #' @name registerTempTable +#' @aliases registerTempTable,SparkDataFrame,character-method #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) +#' sparkR.session() #' path <- "path/to/file.json" -#' df <- read.json(sqlContext, path) +#' df <- read.json(path) #' registerTempTable(df, "json_df") -#' new_df <- sql(sqlContext, "SELECT * FROM json_df") +#' new_df <- sql("SELECT * FROM json_df") #'} +#' @note registerTempTable since 1.4.0 setMethod("registerTempTable", signature(x = "SparkDataFrame", tableName = "character"), function(x, tableName) { - invisible(callJMethod(x@sdf, "registerTempTable", tableName)) + .Deprecated("createOrReplaceTempView") + invisible(callJMethod(x@sdf, "createOrReplaceTempView", tableName)) }) #' insertInto #' -#' Insert the contents of a SparkDataFrame into a table registered in the current SQL Context. +#' Insert the contents of a SparkDataFrame into a table registered in the current SparkSession. #' -#' @param x A SparkDataFrame -#' @param tableName A character vector containing the name of the table -#' @param overwrite A logical argument indicating whether or not to overwrite +#' @param x a SparkDataFrame. +#' @param tableName a character vector containing the name of the table. +#' @param overwrite a logical argument indicating whether or not to overwrite. +#' @param ... 
further arguments to be passed to or from other methods. #' the existing rows in the table. #' #' @family SparkDataFrame functions #' @rdname insertInto #' @name insertInto +#' @aliases insertInto,SparkDataFrame,character-method #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) -#' df <- read.df(sqlContext, path, "parquet") -#' df2 <- read.df(sqlContext, path2, "parquet") -#' registerTempTable(df, "table1") +#' sparkR.session() +#' df <- read.df(path, "parquet") +#' df2 <- read.df(path2, "parquet") +#' createOrReplaceTempView(df, "table1") #' insertInto(df2, "table1", overwrite = TRUE) #'} +#' @note insertInto since 1.4.0 setMethod("insertInto", signature(x = "SparkDataFrame", tableName = "character"), function(x, tableName, overwrite = FALSE) { @@ -492,17 +552,18 @@ setMethod("insertInto", #' @param x A SparkDataFrame #' #' @family SparkDataFrame functions +#' @aliases cache,SparkDataFrame-method #' @rdname cache #' @name cache #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) +#' sparkR.session() #' path <- "path/to/file.json" -#' df <- read.json(sqlContext, path) +#' df <- read.json(path) #' cache(df) #'} +#' @note cache since 1.4.0 setMethod("cache", signature(x = "SparkDataFrame"), function(x) { @@ -517,20 +578,23 @@ setMethod("cache", #' supported storage levels, refer to #' \url{http://spark.apache.org/docs/latest/programming-guide.html#rdd-persistence}. #' -#' @param x The SparkDataFrame to persist +#' @param x the SparkDataFrame to persist. +#' @param newLevel storage level chosen for the persistance. See available options in +#' the description. #' #' @family SparkDataFrame functions #' @rdname persist #' @name persist +#' @aliases persist,SparkDataFrame,character-method #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) +#' sparkR.session() #' path <- "path/to/file.json" -#' df <- read.json(sqlContext, path) +#' df <- read.json(path) #' persist(df, "MEMORY_AND_DISK") #'} +#' @note persist since 1.4.0 setMethod("persist", signature(x = "SparkDataFrame", newLevel = "character"), function(x, newLevel) { @@ -544,22 +608,24 @@ setMethod("persist", #' Mark this SparkDataFrame as non-persistent, and remove all blocks for it from memory and #' disk. #' -#' @param x The SparkDataFrame to unpersist -#' @param blocking Whether to block until all blocks are deleted +#' @param x the SparkDataFrame to unpersist. +#' @param blocking whether to block until all blocks are deleted. +#' @param ... further arguments to be passed to or from other methods. #' #' @family SparkDataFrame functions #' @rdname unpersist-methods +#' @aliases unpersist,SparkDataFrame-method #' @name unpersist #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) +#' sparkR.session() #' path <- "path/to/file.json" -#' df <- read.json(sqlContext, path) +#' df <- read.json(path) #' persist(df, "MEMORY_AND_DISK") #' unpersist(df) #'} +#' @note unpersist since 1.4.0 setMethod("unpersist", signature(x = "SparkDataFrame"), function(x, blocking = TRUE) { @@ -570,27 +636,56 @@ setMethod("unpersist", #' Repartition #' -#' Return a new SparkDataFrame that has exactly numPartitions partitions. -#' -#' @param x A SparkDataFrame -#' @param numPartitions The number of partitions to use. 
+#' The following options for repartition are possible: +#' \itemize{ +#' \item{1.} {Return a new SparkDataFrame partitioned by +#' the given columns into \code{numPartitions}.} +#' \item{2.} {Return a new SparkDataFrame that has exactly \code{numPartitions}.} +#' \item{3.} {Return a new SparkDataFrame partitioned by the given column(s), +#' using \code{spark.sql.shuffle.partitions} as number of partitions.} +#'} +#' @param x a SparkDataFrame. +#' @param numPartitions the number of partitions to use. +#' @param col the column by which the partitioning will be performed. +#' @param ... additional column(s) to be used in the partitioning. #' #' @family SparkDataFrame functions #' @rdname repartition #' @name repartition +#' @aliases repartition,SparkDataFrame-method #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) +#' sparkR.session() #' path <- "path/to/file.json" -#' df <- read.json(sqlContext, path) +#' df <- read.json(path) #' newDF <- repartition(df, 2L) +#' newDF <- repartition(df, numPartitions = 2L) +#' newDF <- repartition(df, col = df$"col1", df$"col2") +#' newDF <- repartition(df, 3L, col = df$"col1", df$"col2") #'} +#' @note repartition since 1.4.0 setMethod("repartition", - signature(x = "SparkDataFrame", numPartitions = "numeric"), - function(x, numPartitions) { - sdf <- callJMethod(x@sdf, "repartition", numToInt(numPartitions)) + signature(x = "SparkDataFrame"), + function(x, numPartitions = NULL, col = NULL, ...) { + if (!is.null(numPartitions) && is.numeric(numPartitions)) { + # number of partitions and columns both are specified + if (!is.null(col) && class(col) == "Column") { + cols <- list(col, ...) + jcol <- lapply(cols, function(c) { c@jc }) + sdf <- callJMethod(x@sdf, "repartition", numToInt(numPartitions), jcol) + } else { + # only number of partitions is specified + sdf <- callJMethod(x@sdf, "repartition", numToInt(numPartitions)) + } + } else if (!is.null(col) && class(col) == "Column") { + # only columns are specified + cols <- list(col, ...) + jcol <- lapply(cols, function(c) { c@jc }) + sdf <- callJMethod(x@sdf, "repartition", jcol) + } else { + stop("Please, specify the number of partitions and/or a column(s)") + } dataFrame(sdf) }) @@ -601,15 +696,13 @@ setMethod("repartition", #' #' @param x A SparkDataFrame #' @return A StringRRDD of JSON objects -#' @family SparkDataFrame functions -#' @rdname tojson +#' @aliases toJSON,SparkDataFrame-method #' @noRd #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) +#' sparkR.session() #' path <- "path/to/file.json" -#' df <- read.json(sqlContext, path) +#' df <- read.json(path) #' newRDD <- toJSON(df) #'} setMethod("toJSON", @@ -620,7 +713,7 @@ setMethod("toJSON", RDD(jrdd, serializedMode = "string") }) -#' write.json +#' Save the contents of SparkDataFrame as a JSON file #' #' Save the contents of a SparkDataFrame as a JSON file (one object per line). Files written out #' with this method can be read back in as a SparkDataFrame using read.json(). 
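A minimal round-trip sketch of the write.json / read.json pair documented above; the output directory is a hypothetical example, and an active session started with `sparkR.session()` is assumed:

```r
# assumes sparkR.session() has already been called
df <- createDataFrame(faithful)           # build a SparkDataFrame from a local data.frame
write.json(df, "/tmp/sparkr-json-out")    # hypothetical output directory (one JSON object per line)
df2 <- read.json("/tmp/sparkr-json-out")  # read the files back as a SparkDataFrame
printSchema(df2)
```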
@@ -631,15 +724,16 @@ setMethod("toJSON", #' @family SparkDataFrame functions #' @rdname write.json #' @name write.json +#' @aliases write.json,SparkDataFrame,character-method #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) +#' sparkR.session() #' path <- "path/to/file.json" -#' df <- read.json(sqlContext, path) +#' df <- read.json(path) #' write.json(df, "/tmp/sparkr-tmp/") #'} +#' @note write.json since 1.6.0 setMethod("write.json", signature(x = "SparkDataFrame", path = "character"), function(x, path) { @@ -647,7 +741,35 @@ setMethod("write.json", invisible(callJMethod(write, "json", path)) }) -#' write.parquet +#' Save the contents of SparkDataFrame as an ORC file, preserving the schema. +#' +#' Save the contents of a SparkDataFrame as an ORC file, preserving the schema. Files written out +#' with this method can be read back in as a SparkDataFrame using read.orc(). +#' +#' @param x A SparkDataFrame +#' @param path The directory where the file is saved +#' +#' @family SparkDataFrame functions +#' @aliases write.orc,SparkDataFrame,character-method +#' @rdname write.orc +#' @name write.orc +#' @export +#' @examples +#'\dontrun{ +#' sparkR.session() +#' path <- "path/to/file.json" +#' df <- read.json(path) +#' write.orc(df, "/tmp/sparkr-tmp1/") +#' } +#' @note write.orc since 2.0.0 +setMethod("write.orc", + signature(x = "SparkDataFrame", path = "character"), + function(x, path) { + write <- callJMethod(x@sdf, "write") + invisible(callJMethod(write, "orc", path)) + }) + +#' Save the contents of SparkDataFrame as a Parquet file, preserving the schema. #' #' Save the contents of a SparkDataFrame as a Parquet file, preserving the schema. Files written out #' with this method can be read back in as a SparkDataFrame using read.parquet(). @@ -658,16 +780,17 @@ setMethod("write.json", #' @family SparkDataFrame functions #' @rdname write.parquet #' @name write.parquet +#' @aliases write.parquet,SparkDataFrame,character-method #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) +#' sparkR.session() #' path <- "path/to/file.json" -#' df <- read.json(sqlContext, path) +#' df <- read.json(path) #' write.parquet(df, "/tmp/sparkr-tmp1/") #' saveAsParquetFile(df, "/tmp/sparkr-tmp2/") #'} +#' @note write.parquet since 1.6.0 setMethod("write.parquet", signature(x = "SparkDataFrame", path = "character"), function(x, path) { @@ -677,7 +800,9 @@ setMethod("write.parquet", #' @rdname write.parquet #' @name saveAsParquetFile +#' @aliases saveAsParquetFile,SparkDataFrame,character-method #' @export +#' @note saveAsParquetFile since 1.4.0 setMethod("saveAsParquetFile", signature(x = "SparkDataFrame", path = "character"), function(x, path) { @@ -685,9 +810,9 @@ setMethod("saveAsParquetFile", write.parquet(x, path) }) -#' write.text +#' Save the content of SparkDataFrame in a text file at the specified path. #' -#' Saves the content of the SparkDataFrame in a text file at the specified path. +#' Save the content of the SparkDataFrame in a text file at the specified path. #' The SparkDataFrame must have only one column of string type with the name "value". #' Each row becomes a new line in the output file. 
#' @@ -695,17 +820,18 @@ setMethod("saveAsParquetFile", #' @param path The directory where the file is saved #' #' @family SparkDataFrame functions +#' @aliases write.text,SparkDataFrame,character-method #' @rdname write.text #' @name write.text #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) +#' sparkR.session() #' path <- "path/to/file.txt" -#' df <- read.text(sqlContext, path) +#' df <- read.text(path) #' write.text(df, "/tmp/sparkr-tmp/") #'} +#' @note write.text since 2.0.0 setMethod("write.text", signature(x = "SparkDataFrame", path = "character"), function(x, path) { @@ -720,17 +846,18 @@ setMethod("write.text", #' @param x A SparkDataFrame #' #' @family SparkDataFrame functions +#' @aliases distinct,SparkDataFrame-method #' @rdname distinct #' @name distinct #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) +#' sparkR.session() #' path <- "path/to/file.json" -#' df <- read.json(sqlContext, path) +#' df <- read.json(path) #' distinctDF <- distinct(df) #'} +#' @note distinct since 1.4.0 setMethod("distinct", signature(x = "SparkDataFrame"), function(x) { @@ -740,6 +867,8 @@ setMethod("distinct", #' @rdname distinct #' @name unique +#' @aliases unique,SparkDataFrame-method +#' @note unique since 1.5.0 setMethod("unique", signature(x = "SparkDataFrame"), function(x) { @@ -756,18 +885,19 @@ setMethod("unique", #' @param seed Randomness seed value #' #' @family SparkDataFrame functions +#' @aliases sample,SparkDataFrame,logical,numeric-method #' @rdname sample #' @name sample #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) +#' sparkR.session() #' path <- "path/to/file.json" -#' df <- read.json(sqlContext, path) +#' df <- read.json(path) #' collect(sample(df, FALSE, 0.5)) #' collect(sample(df, TRUE, 0.5)) #'} +#' @note sample since 1.4.0 setMethod("sample", signature(x = "SparkDataFrame", withReplacement = "logical", fraction = "numeric"), @@ -784,7 +914,9 @@ setMethod("sample", }) #' @rdname sample +#' @aliases sample_frac,SparkDataFrame,logical,numeric-method #' @name sample_frac +#' @note sample_frac since 1.4.0 setMethod("sample_frac", signature(x = "SparkDataFrame", withReplacement = "logical", fraction = "numeric"), @@ -792,24 +924,22 @@ setMethod("sample_frac", sample(x, withReplacement, fraction, seed) }) -#' nrow -#' #' Returns the number of rows in a SparkDataFrame #' -#' @param x A SparkDataFrame -#' +#' @param x a SparkDataFrame. 
#' @family SparkDataFrame functions #' @rdname nrow -#' @name count +#' @name nrow +#' @aliases count,SparkDataFrame-method #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) +#' sparkR.session() #' path <- "path/to/file.json" -#' df <- read.json(sqlContext, path) +#' df <- read.json(path) #' count(df) #' } +#' @note count since 1.4.0 setMethod("count", signature(x = "SparkDataFrame"), function(x) { @@ -818,6 +948,8 @@ setMethod("count", #' @name nrow #' @rdname nrow +#' @aliases nrow,SparkDataFrame-method +#' @note nrow since 1.5.0 setMethod("nrow", signature(x = "SparkDataFrame"), function(x) { @@ -831,36 +963,40 @@ setMethod("nrow", #' @family SparkDataFrame functions #' @rdname ncol #' @name ncol +#' @aliases ncol,SparkDataFrame-method #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) +#' sparkR.session() #' path <- "path/to/file.json" -#' df <- read.json(sqlContext, path) +#' df <- read.json(path) #' ncol(df) #' } +#' @note ncol since 1.5.0 setMethod("ncol", signature(x = "SparkDataFrame"), function(x) { length(columns(x)) }) +#' Returns the dimensions of SparkDataFrame +#' #' Returns the dimensions (number of rows and columns) of a SparkDataFrame #' @param x a SparkDataFrame #' #' @family SparkDataFrame functions #' @rdname dim +#' @aliases dim,SparkDataFrame-method #' @name dim #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) +#' sparkR.session() #' path <- "path/to/file.json" -#' df <- read.json(sqlContext, path) +#' df <- read.json(path) #' dim(df) #' } +#' @note dim since 1.5.0 setMethod("dim", signature(x = "SparkDataFrame"), function(x) { @@ -869,23 +1005,25 @@ setMethod("dim", #' Collects all the elements of a SparkDataFrame and coerces them into an R data.frame. #' -#' @param x A SparkDataFrame -#' @param stringsAsFactors (Optional) A logical indicating whether or not string columns +#' @param x a SparkDataFrame. +#' @param stringsAsFactors (Optional) a logical indicating whether or not string columns #' should be converted to factors. FALSE by default. +#' @param ... further arguments to be passed to or from other methods. #' #' @family SparkDataFrame functions #' @rdname collect +#' @aliases collect,SparkDataFrame-method #' @name collect #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) +#' sparkR.session() #' path <- "path/to/file.json" -#' df <- read.json(sqlContext, path) +#' df <- read.json(path) #' collected <- collect(df) #' firstName <- collected[[1]]$name #' } +#' @note collect since 1.4.0 setMethod("collect", signature(x = "SparkDataFrame"), function(x, stringsAsFactors = FALSE) { @@ -922,6 +1060,13 @@ setMethod("collect", df[[colIndex]] <- col } else { colType <- dtypes[[colIndex]][[2]] + if (is.null(PRIMITIVE_TYPES[[colType]])) { + specialtype <- specialtypeshandle(colType) + if (!is.null(specialtype)) { + colType <- specialtype + } + } + # Note that "binary" columns behave like complex types. 
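+          # Primitive, non-binary column types can be flattened into an atomic
+          # R vector with c(); binary and complex types are handled separately below.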
if (!is.null(PRIMITIVE_TYPES[[colType]]) && colType != "binary") { vec <- do.call(c, col) @@ -948,15 +1093,16 @@ setMethod("collect", #' @family SparkDataFrame functions #' @rdname limit #' @name limit +#' @aliases limit,SparkDataFrame,numeric-method #' @export #' @examples #' \dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) +#' sparkR.session() #' path <- "path/to/file.json" -#' df <- read.json(sqlContext, path) +#' df <- read.json(path) #' limitedDF <- limit(df, 10) #' } +#' @note limit since 1.4.0 setMethod("limit", signature(x = "SparkDataFrame", num = "numeric"), function(x, num) { @@ -964,20 +1110,23 @@ setMethod("limit", dataFrame(res) }) -#' Take the first NUM rows of a SparkDataFrame and return a the results as a data.frame +#' Take the first NUM rows of a SparkDataFrame and return the results as a R data.frame #' +#' @param x a SparkDataFrame. +#' @param num number of rows to take. #' @family SparkDataFrame functions #' @rdname take #' @name take +#' @aliases take,SparkDataFrame,numeric-method #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) +#' sparkR.session() #' path <- "path/to/file.json" -#' df <- read.json(sqlContext, path) +#' df <- read.json(path) #' take(df, 2) #' } +#' @note take since 1.4.0 setMethod("take", signature(x = "SparkDataFrame", num = "numeric"), function(x, num) { @@ -987,26 +1136,26 @@ setMethod("take", #' Head #' -#' Return the first NUM rows of a SparkDataFrame as a data.frame. If NUM is NULL, -#' then head() returns the first 6 rows in keeping with the current data.frame -#' convention in R. +#' Return the first \code{num} rows of a SparkDataFrame as a R data.frame. If \code{num} is not +#' specified, then head() returns the first 6 rows as with R data.frame. #' -#' @param x A SparkDataFrame -#' @param num The number of rows to return. Default is 6. -#' @return A data.frame +#' @param x a SparkDataFrame. +#' @param num the number of rows to return. Default is 6. +#' @return A data.frame. #' #' @family SparkDataFrame functions +#' @aliases head,SparkDataFrame-method #' @rdname head #' @name head #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) +#' sparkR.session() #' path <- "path/to/file.json" -#' df <- read.json(sqlContext, path) +#' df <- read.json(path) #' head(df) #' } +#' @note head since 1.4.0 setMethod("head", signature(x = "SparkDataFrame"), function(x, num = 6L) { @@ -1016,20 +1165,22 @@ setMethod("head", #' Return the first row of a SparkDataFrame #' -#' @param x A SparkDataFrame +#' @param x a SparkDataFrame or a column used in aggregation function. +#' @param ... further arguments to be passed to or from other methods. 
#' #' @family SparkDataFrame functions +#' @aliases first,SparkDataFrame-method #' @rdname first #' @name first #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) +#' sparkR.session() #' path <- "path/to/file.json" -#' df <- read.json(sqlContext, path) +#' df <- read.json(path) #' first(df) #' } +#' @note first(SparkDataFrame) since 1.4.0 setMethod("first", signature(x = "SparkDataFrame"), function(x) { @@ -1045,10 +1196,9 @@ setMethod("first", #' @noRd #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) +#' sparkR.session() #' path <- "path/to/file.json" -#' df <- read.json(sqlContext, path) +#' df <- read.json(path) #' rdd <- toRDD(df) #'} setMethod("toRDD", @@ -1067,10 +1217,11 @@ setMethod("toRDD", #' #' Groups the SparkDataFrame using the specified columns, so we can run aggregation on them. #' -#' @param x a SparkDataFrame -#' @return a GroupedData -#' @seealso GroupedData +#' @param x a SparkDataFrame. +#' @param ... variable(s) (character names(s) or Column(s)) to group on. +#' @return A GroupedData. #' @family SparkDataFrame functions +#' @aliases groupBy,SparkDataFrame-method #' @rdname groupBy #' @name groupBy #' @export @@ -1082,6 +1233,7 @@ setMethod("toRDD", #' # Compute the max age and average salary, grouped by department and gender. #' agg(groupBy(df, "department", "gender"), salary="avg", "age" -> "max") #' } +#' @note groupBy since 1.4.0 setMethod("groupBy", signature(x = "SparkDataFrame"), function(x, ...) { @@ -1097,6 +1249,8 @@ setMethod("groupBy", #' @rdname groupBy #' @name group_by +#' @aliases group_by,SparkDataFrame-method +#' @note group_by since 1.4.0 setMethod("group_by", signature(x = "SparkDataFrame"), function(x, ...) { @@ -1107,49 +1261,71 @@ setMethod("group_by", #' #' Compute aggregates by specifying a list of columns #' -#' @param x a SparkDataFrame #' @family SparkDataFrame functions -#' @rdname agg +#' @aliases agg,SparkDataFrame-method +#' @rdname summarize #' @name agg #' @export +#' @note agg since 1.4.0 setMethod("agg", signature(x = "SparkDataFrame"), function(x, ...) { agg(groupBy(x), ...) }) -#' @rdname agg +#' @rdname summarize #' @name summarize +#' @aliases summarize,SparkDataFrame-method +#' @note summarize since 1.4.0 setMethod("summarize", signature(x = "SparkDataFrame"), function(x, ...) { agg(x, ...) }) +dapplyInternal <- function(x, func, schema) { + packageNamesArr <- serialize(.sparkREnv[[".packages"]], + connection = NULL) + + broadcastArr <- lapply(ls(.broadcastNames), + function(name) { get(name, .broadcastNames) }) + + sdf <- callJStatic( + "org.apache.spark.sql.api.r.SQLUtils", + "dapply", + x@sdf, + serialize(cleanClosure(func), connection = NULL), + packageNamesArr, + broadcastArr, + if (is.null(schema)) { schema } else { schema$jobj }) + dataFrame(sdf) +} + #' dapply #' -#' Apply a function to each partition of a DataFrame. +#' Apply a function to each partition of a SparkDataFrame. #' #' @param x A SparkDataFrame #' @param func A function to be applied to each partition of the SparkDataFrame. -#' func should have only one parameter, to which a data.frame corresponds +#' func should have only one parameter, to which a R data.frame corresponds #' to each partition will be passed. -#' The output of func should be a data.frame. -#' @param schema The schema of the resulting DataFrame after the function is applied. +#' The output of func should be a R data.frame. 
+#' @param schema The schema of the resulting SparkDataFrame after the function is applied. #' It must match the output of func. #' @family SparkDataFrame functions #' @rdname dapply +#' @aliases dapply,SparkDataFrame,function,structType-method #' @name dapply +#' @seealso \link{dapplyCollect} #' @export #' @examples #' \dontrun{ -#' df <- createDataFrame (sqlContext, iris) +#' df <- createDataFrame(iris) #' df1 <- dapply(df, function(x) { x }, schema(df)) #' collect(df1) #' #' # filter and add a column -#' df <- createDataFrame ( -#' sqlContext, +#' df <- createDataFrame( #' list(list(1L, 1, "1"), list(2L, 2, "2"), list(3L, 3, "3")), #' c("a", "b", "c")) #' schema <- structType(structField("a", "integer"), structField("b", "double"), @@ -1167,24 +1343,237 @@ setMethod("summarize", #' # 1 2 2 2 3 #' # 2 3 3 3 4 #' } +#' @note dapply since 2.0.0 setMethod("dapply", signature(x = "SparkDataFrame", func = "function", schema = "structType"), function(x, func, schema) { - packageNamesArr <- serialize(.sparkREnv[[".packages"]], - connection = NULL) - - broadcastArr <- lapply(ls(.broadcastNames), - function(name) { get(name, .broadcastNames) }) - - sdf <- callJStatic( - "org.apache.spark.sql.api.r.SQLUtils", - "dapply", - x@sdf, - serialize(cleanClosure(func), connection = NULL), - packageNamesArr, - broadcastArr, - schema$jobj) - dataFrame(sdf) + dapplyInternal(x, func, schema) + }) + +#' dapplyCollect +#' +#' Apply a function to each partition of a SparkDataFrame and collect the result back +#' to R as a data.frame. +#' +#' @param x A SparkDataFrame +#' @param func A function to be applied to each partition of the SparkDataFrame. +#' func should have only one parameter, to which a R data.frame corresponds +#' to each partition will be passed. +#' The output of func should be a R data.frame. +#' @family SparkDataFrame functions +#' @rdname dapplyCollect +#' @aliases dapplyCollect,SparkDataFrame,function-method +#' @name dapplyCollect +#' @seealso \link{dapply} +#' @export +#' @examples +#' \dontrun{ +#' df <- createDataFrame(iris) +#' ldf <- dapplyCollect(df, function(x) { x }) +#' +#' # filter and add a column +#' df <- createDataFrame( +#' list(list(1L, 1, "1"), list(2L, 2, "2"), list(3L, 3, "3")), +#' c("a", "b", "c")) +#' ldf <- dapplyCollect( +#' df, +#' function(x) { +#' y <- x[x[1] > 1, ] +#' y <- cbind(y, y[1] + 1L) +#' }) +#' # the result +#' # a b c d +#' # 2 2 2 3 +#' # 3 3 3 4 +#' } +#' @note dapplyCollect since 2.0.0 +setMethod("dapplyCollect", + signature(x = "SparkDataFrame", func = "function"), + function(x, func) { + df <- dapplyInternal(x, func, NULL) + + content <- callJMethod(df@sdf, "collect") + # content is a list of items of struct type. Each item has a single field + # which is a serialized data.frame corresponds to one partition of the + # SparkDataFrame. + ldfs <- lapply(content, function(x) { unserialize(x[[1]]) }) + ldf <- do.call(rbind, ldfs) + row.names(ldf) <- NULL + ldf + }) + +#' gapply +#' +#' Groups the SparkDataFrame using the specified columns and applies the R function to each +#' group. +#' +#' @param cols grouping columns. +#' @param func a function to be applied to each group partition specified by grouping +#' column of the SparkDataFrame. The function \code{func} takes as argument +#' a key - grouping columns and a data frame - a local R data.frame. +#' The output of \code{func} is a local R data.frame. +#' @param schema the schema of the resulting SparkDataFrame after the function is applied. +#' The schema must match to output of \code{func}. 
It has to be defined for each +#' output column with preferred output column name and corresponding data type. +#' @return A SparkDataFrame. +#' @family SparkDataFrame functions +#' @aliases gapply,SparkDataFrame-method +#' @rdname gapply +#' @name gapply +#' @seealso \link{gapplyCollect} +#' @export +#' @examples +#' +#' \dontrun{ +#' Computes the arithmetic mean of the second column by grouping +#' on the first and third columns. Output the grouping values and the average. +#' +#' df <- createDataFrame ( +#' list(list(1L, 1, "1", 0.1), list(1L, 2, "1", 0.2), list(3L, 3, "3", 0.3)), +#' c("a", "b", "c", "d")) +#' +#' Here our output contains three columns, the key which is a combination of two +#' columns with data types integer and string and the mean which is a double. +#' schema <- structType(structField("a", "integer"), structField("c", "string"), +#' structField("avg", "double")) +#' result <- gapply( +#' df, +#' c("a", "c"), +#' function(key, x) { +#' y <- data.frame(key, mean(x$b), stringsAsFactors = FALSE) +#' }, schema) +#' +#' We can also group the data and afterwards call gapply on GroupedData. +#' For Example: +#' gdf <- group_by(df, "a", "c") +#' result <- gapply( +#' gdf, +#' function(key, x) { +#' y <- data.frame(key, mean(x$b), stringsAsFactors = FALSE) +#' }, schema) +#' collect(result) +#' +#' Result +#' ------ +#' a c avg +#' 3 3 3.0 +#' 1 1 1.5 +#' +#' Fits linear models on iris dataset by grouping on the 'Species' column and +#' using 'Sepal_Length' as a target variable, 'Sepal_Width', 'Petal_Length' +#' and 'Petal_Width' as training features. +#' +#' df <- createDataFrame (iris) +#' schema <- structType(structField("(Intercept)", "double"), +#' structField("Sepal_Width", "double"),structField("Petal_Length", "double"), +#' structField("Petal_Width", "double")) +#' df1 <- gapply( +#' df, +#' df$"Species", +#' function(key, x) { +#' m <- suppressWarnings(lm(Sepal_Length ~ +#' Sepal_Width + Petal_Length + Petal_Width, x)) +#' data.frame(t(coef(m))) +#' }, schema) +#' collect(df1) +#' +#' Result +#' --------- +#' Model (Intercept) Sepal_Width Petal_Length Petal_Width +#' 1 0.699883 0.3303370 0.9455356 -0.1697527 +#' 2 1.895540 0.3868576 0.9083370 -0.6792238 +#' 3 2.351890 0.6548350 0.2375602 0.2521257 +#' +#'} +#' @note gapply(SparkDataFrame) since 2.0.0 +setMethod("gapply", + signature(x = "SparkDataFrame"), + function(x, cols, func, schema) { + grouped <- do.call("groupBy", c(x, cols)) + gapply(grouped, func, schema) + }) + +#' gapplyCollect +#' +#' Groups the SparkDataFrame using the specified columns, applies the R function to each +#' group and collects the result back to R as data.frame. +#' +#' @param cols grouping columns. +#' @param func a function to be applied to each group partition specified by grouping +#' column of the SparkDataFrame. The function \code{func} takes as argument +#' a key - grouping columns and a data frame - a local R data.frame. +#' The output of \code{func} is a local R data.frame. +#' @return A data.frame. +#' @family SparkDataFrame functions +#' @aliases gapplyCollect,SparkDataFrame-method +#' @rdname gapplyCollect +#' @name gapplyCollect +#' @seealso \link{gapply} +#' @export +#' @examples +#' +#' \dontrun{ +#' Computes the arithmetic mean of the second column by grouping +#' on the first and third columns. Output the grouping values and the average. 
+#' +#' df <- createDataFrame ( +#' list(list(1L, 1, "1", 0.1), list(1L, 2, "1", 0.2), list(3L, 3, "3", 0.3)), +#' c("a", "b", "c", "d")) +#' +#' result <- gapplyCollect( +#' df, +#' c("a", "c"), +#' function(key, x) { +#' y <- data.frame(key, mean(x$b), stringsAsFactors = FALSE) +#' colnames(y) <- c("key_a", "key_c", "mean_b") +#' y +#' }) +#' +#' We can also group the data and afterwards call gapply on GroupedData. +#' For Example: +#' gdf <- group_by(df, "a", "c") +#' result <- gapplyCollect( +#' gdf, +#' function(key, x) { +#' y <- data.frame(key, mean(x$b), stringsAsFactors = FALSE) +#' colnames(y) <- c("key_a", "key_c", "mean_b") +#' y +#' }) +#' +#' Result +#' ------ +#' key_a key_c mean_b +#' 3 3 3.0 +#' 1 1 1.5 +#' +#' Fits linear models on iris dataset by grouping on the 'Species' column and +#' using 'Sepal_Length' as a target variable, 'Sepal_Width', 'Petal_Length' +#' and 'Petal_Width' as training features. +#' +#' df <- createDataFrame (iris) +#' result <- gapplyCollect( +#' df, +#' df$"Species", +#' function(key, x) { +#' m <- suppressWarnings(lm(Sepal_Length ~ +#' Sepal_Width + Petal_Length + Petal_Width, x)) +#' data.frame(t(coef(m))) +#' }) +#' +#' Result +#'--------- +#' Model X.Intercept. Sepal_Width Petal_Length Petal_Width +#' 1 0.699883 0.3303370 0.9455356 -0.1697527 +#' 2 1.895540 0.3868576 0.9083370 -0.6792238 +#' 3 2.351890 0.6548350 0.2375602 0.2521257 +#' +#'} +#' @note gapplyCollect(SparkDataFrame) since 2.0.0 +setMethod("gapplyCollect", + signature(x = "SparkDataFrame"), + function(x, cols, func) { + grouped <- do.call("groupBy", c(x, cols)) + gapplyCollect(grouped, func) }) ############################## RDD Map Functions ################################## @@ -1261,15 +1650,21 @@ getColumn <- function(x, c) { column(callJMethod(x@sdf, "col", c)) } +#' @param name name of a Column (without being wrapped by \code{""}). #' @rdname select #' @name $ +#' @aliases $,SparkDataFrame-method +#' @note $ since 1.4.0 setMethod("$", signature(x = "SparkDataFrame"), function(x, name) { getColumn(x, name) }) +#' @param value a Column or \code{NULL}. If \code{NULL}, the specified Column is dropped. #' @rdname select #' @name $<- +#' @aliases $<-,SparkDataFrame-method +#' @note $<- since 1.4.0 setMethod("$<-", signature(x = "SparkDataFrame"), function(x, name, value) { stopifnot(class(value) == "Column" || is.null(value)) @@ -1287,6 +1682,8 @@ setClassUnion("numericOrcharacter", c("numeric", "character")) #' @rdname subset #' @name [[ +#' @aliases [[,SparkDataFrame,numericOrcharacter-method +#' @note [[ since 1.4.0 setMethod("[[", signature(x = "SparkDataFrame", i = "numericOrcharacter"), function(x, i) { if (is.numeric(i)) { @@ -1298,6 +1695,8 @@ setMethod("[[", signature(x = "SparkDataFrame", i = "numericOrcharacter"), #' @rdname subset #' @name [ +#' @aliases [,SparkDataFrame-method +#' @note [ since 1.4.0 setMethod("[", signature(x = "SparkDataFrame"), function(x, i, j, ..., drop = F) { # Perform filtering first if needed @@ -1336,20 +1735,22 @@ setMethod("[", signature(x = "SparkDataFrame"), #' Subset #' #' Return subsets of SparkDataFrame according to given conditions -#' @param x A SparkDataFrame -#' @param subset (Optional) A logical expression to filter on rows -#' @param select expression for the single Column or a list of columns to select from the SparkDataFrame +#' @param x a SparkDataFrame. +#' @param i,subset (Optional) a logical expression to filter on rows. 
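A compact sketch of the gapply / gapplyCollect contract described above (the function receives the grouping key plus a local R data.frame for that group and must return a local R data.frame); it assumes an active sparkR.session(), and the data and column names are illustrative:

```R
df <- createDataFrame(
  list(list(1L, 1, "1", 0.1), list(1L, 2, "1", 0.2), list(3L, 3, "3", 0.3)),
  c("a", "b", "c", "d"))

# gapply: the result schema must be declared
schema <- structType(structField("a", "integer"), structField("c", "string"),
                     structField("avg", "double"))
result <- gapply(df, c("a", "c"), function(key, x) {
  # key holds the grouping values, x is the local data.frame for that group
  data.frame(key, mean(x$b), stringsAsFactors = FALSE)
}, schema)
collect(result)

# gapplyCollect: same contract, but the result comes straight back as a data.frame
ldf <- gapplyCollect(df, c("a", "c"), function(key, x) {
  y <- data.frame(key, mean(x$b), stringsAsFactors = FALSE)
  colnames(y) <- c("key_a", "key_c", "mean_b")
  y
})
```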
+#' @param j,select expression for the single Column or a list of columns to select from the SparkDataFrame. #' @param drop if TRUE, a Column will be returned if the resulting dataset has only one column. -#' Otherwise, a SparkDataFrame will always be returned. -#' @return A new SparkDataFrame containing only the rows that meet the condition with selected columns +#' Otherwise, a SparkDataFrame will always be returned. +#' @param ... currently not used. +#' @return A new SparkDataFrame containing only the rows that meet the condition with selected columns. #' @export #' @family SparkDataFrame functions +#' @aliases subset,SparkDataFrame-method #' @rdname subset #' @name subset #' @family subsetting functions #' @examples #' \dontrun{ -#' # Columns can be selected using `[[` and `[` +#' # Columns can be selected using [[ and [ #' df[[2]] == df[["age"]] #' df[,2] == df[,"age"] #' df[,c("name", "age")] @@ -1362,20 +1763,29 @@ setMethod("[", signature(x = "SparkDataFrame"), #' subset(df, df$age %in% c(19), select = c(1,2)) #' subset(df, select = c(1,2)) #' } +#' @note subset since 1.5.0 setMethod("subset", signature(x = "SparkDataFrame"), function(x, subset, select, drop = F, ...) { - x[subset, select, drop = drop] + if (missing(subset)) { + x[, select, drop = drop, ...] + } else { + x[subset, select, drop = drop, ...] + } }) #' Select #' #' Selects a set of columns with names or Column expressions. -#' @param x A SparkDataFrame -#' @param col A list of columns or single Column or name -#' @return A new SparkDataFrame with selected columns +#' @param x a SparkDataFrame. +#' @param col a list of columns or single Column or name. +#' @param ... additional column(s) if only one column is specified in \code{col}. +#' If more than one column is assigned in \code{col}, \code{...} +#' should be left empty. +#' @return A new SparkDataFrame with selected columns. #' @export #' @family SparkDataFrame functions #' @rdname select +#' @aliases select,SparkDataFrame,character-method #' @name select #' @family subsetting functions #' @examples @@ -1385,9 +1795,10 @@ setMethod("subset", signature(x = "SparkDataFrame"), #' select(df, df$name, df$age + 1) #' select(df, c("col1", "col2")) #' select(df, list(df$name, df$age + 1)) -#' # Similar to R data frames columns can also be selected using `$` +#' # Similar to R data frames columns can also be selected using $ #' df[,df$age] #' } +#' @note select(SparkDataFrame, character) since 1.4.0 setMethod("select", signature(x = "SparkDataFrame", col = "character"), function(x, col, ...) { if (length(col) > 1) { @@ -1402,9 +1813,10 @@ setMethod("select", signature(x = "SparkDataFrame", col = "character"), } }) -#' @family SparkDataFrame functions #' @rdname select #' @export +#' @aliases select,SparkDataFrame,Column-method +#' @note select(SparkDataFrame, Column) since 1.4.0 setMethod("select", signature(x = "SparkDataFrame", col = "Column"), function(x, col, ...) { jcols <- lapply(list(col, ...), function(c) { @@ -1414,9 +1826,10 @@ setMethod("select", signature(x = "SparkDataFrame", col = "Column"), dataFrame(sdf) }) -#' @family SparkDataFrame functions #' @rdname select #' @export +#' @aliases select,SparkDataFrame,list-method +#' @note select(SparkDataFrame, list) since 1.4.0 setMethod("select", signature(x = "SparkDataFrame", col = "list"), function(x, col) { @@ -1440,17 +1853,18 @@ setMethod("select", #' @param ... 
Additional expressions #' @return A SparkDataFrame #' @family SparkDataFrame functions +#' @aliases selectExpr,SparkDataFrame,character-method #' @rdname selectExpr #' @name selectExpr #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) +#' sparkR.session() #' path <- "path/to/file.json" -#' df <- read.json(sqlContext, path) +#' df <- read.json(path) #' selectExpr(df, "col1", "(col2 * 5) as newCol") #' } +#' @note selectExpr since 1.4.0 setMethod("selectExpr", signature(x = "SparkDataFrame", expr = "character"), function(x, expr, ...) { @@ -1464,25 +1878,26 @@ setMethod("selectExpr", #' Return a new SparkDataFrame by adding a column or replacing the existing column #' that has the same name. #' -#' @param x A SparkDataFrame -#' @param colName A column name. -#' @param col A Column expression. +#' @param x a SparkDataFrame. +#' @param colName a column name. +#' @param col a Column expression. #' @return A SparkDataFrame with the new column added or the existing column replaced. #' @family SparkDataFrame functions +#' @aliases withColumn,SparkDataFrame,character,Column-method #' @rdname withColumn #' @name withColumn #' @seealso \link{rename} \link{mutate} #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) +#' sparkR.session() #' path <- "path/to/file.json" -#' df <- read.json(sqlContext, path) +#' df <- read.json(path) #' newDF <- withColumn(df, "newCol", df$col1 * 5) #' # Replace an existing column #' newDF2 <- withColumn(newDF, "newCol", newDF$col1) #' } +#' @note withColumn since 1.4.0 setMethod("withColumn", signature(x = "SparkDataFrame", colName = "character", col = "Column"), function(x, colName, col) { @@ -1494,29 +1909,29 @@ setMethod("withColumn", #' #' Return a new SparkDataFrame with the specified columns added or replaced. #' -#' @param .data A SparkDataFrame -#' @param col a named argument of the form name = col +#' @param .data a SparkDataFrame. +#' @param ... additional column argument(s) each in the form name = col. #' @return A new SparkDataFrame with the new columns added or replaced. #' @family SparkDataFrame functions +#' @aliases mutate,SparkDataFrame-method #' @rdname mutate #' @name mutate #' @seealso \link{rename} \link{withColumn} #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) +#' sparkR.session() #' path <- "path/to/file.json" -#' df <- read.json(sqlContext, path) +#' df <- read.json(path) #' newDF <- mutate(df, newCol = df$col1 * 5, newCol2 = df$col1 * 2) #' names(newDF) # Will contain newCol, newCol2 #' newDF2 <- transform(df, newCol = df$col1 / 5, newCol2 = df$col1 * 2) #' -#' df <- createDataFrame(sqlContext, -#' list(list("Andy", 30L), list("Justin", 19L)), c("name", "age")) +#' df <- createDataFrame(list(list("Andy", 30L), list("Justin", 19L)), c("name", "age")) #' # Replace the "age" column #' df1 <- mutate(df, age = df$age + 1L) #' } +#' @note mutate since 1.4.0 setMethod("mutate", signature(.data = "SparkDataFrame"), function(.data, ...) { @@ -1572,9 +1987,12 @@ setMethod("mutate", do.call(select, c(x, colList, deDupCols)) }) +#' @param _data a SparkDataFrame. #' @export #' @rdname mutate +#' @aliases transform,SparkDataFrame-method #' @name transform +#' @note transform since 1.5.0 setMethod("transform", signature(`_data` = "SparkDataFrame"), function(`_data`, ...) 
{ @@ -1592,16 +2010,17 @@ setMethod("transform", #' @family SparkDataFrame functions #' @rdname rename #' @name withColumnRenamed +#' @aliases withColumnRenamed,SparkDataFrame,character,character-method #' @seealso \link{mutate} #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) +#' sparkR.session() #' path <- "path/to/file.json" -#' df <- read.json(sqlContext, path) +#' df <- read.json(path) #' newDF <- withColumnRenamed(df, "col1", "newCol1") #' } +#' @note withColumnRenamed since 1.4.0 setMethod("withColumnRenamed", signature(x = "SparkDataFrame", existingCol = "character", newCol = "character"), function(x, existingCol, newCol) { @@ -1615,18 +2034,19 @@ setMethod("withColumnRenamed", select(x, cols) }) -#' @param newColPair A named pair of the form new_column_name = existing_column +#' @param ... A named pair of the form new_column_name = existing_column #' @rdname rename #' @name rename +#' @aliases rename,SparkDataFrame-method #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) +#' sparkR.session() #' path <- "path/to/file.json" -#' df <- read.json(sqlContext, path) +#' df <- read.json(path) #' newDF <- rename(df, col1 = df$newCol1) #' } +#' @note rename since 1.4.0 setMethod("rename", signature(x = "SparkDataFrame"), function(x, ...) { @@ -1649,31 +2069,32 @@ setMethod("rename", setClassUnion("characterOrColumn", c("character", "Column")) -#' Arrange +#' Arrange Rows by Variables #' #' Sort a SparkDataFrame by the specified column(s). #' -#' @param x A SparkDataFrame to be sorted. -#' @param col A character or Column object vector indicating the fields to sort on -#' @param ... Additional sorting fields -#' @param decreasing A logical argument indicating sorting order for columns when +#' @param x a SparkDataFrame to be sorted. +#' @param col a character or Column object indicating the fields to sort on +#' @param ... additional sorting fields +#' @param decreasing a logical argument indicating sorting order for columns when #' a character vector is specified for col #' @return A SparkDataFrame where all elements are sorted. #' @family SparkDataFrame functions +#' @aliases arrange,SparkDataFrame,Column-method #' @rdname arrange #' @name arrange #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) +#' sparkR.session() #' path <- "path/to/file.json" -#' df <- read.json(sqlContext, path) +#' df <- read.json(path) #' arrange(df, df$col1) #' arrange(df, asc(df$col1), desc(abs(df$col2))) #' arrange(df, "col1", decreasing = TRUE) #' arrange(df, "col1", "col2", decreasing = c(TRUE, FALSE)) #' } +#' @note arrange(SparkDataFrame, Column) since 1.4.0 setMethod("arrange", signature(x = "SparkDataFrame", col = "Column"), function(x, col, ...) { @@ -1687,7 +2108,9 @@ setMethod("arrange", #' @rdname arrange #' @name arrange +#' @aliases arrange,SparkDataFrame,character-method #' @export +#' @note arrange(SparkDataFrame, character) since 1.4.0 setMethod("arrange", signature(x = "SparkDataFrame", col = "character"), function(x, col, ..., decreasing = FALSE) { @@ -1718,12 +2141,13 @@ setMethod("arrange", }) #' @rdname arrange -#' @name orderBy +#' @aliases orderBy,SparkDataFrame,characterOrColumn-method #' @export +#' @note orderBy(SparkDataFrame, characterOrColumn) since 1.4.0 setMethod("orderBy", signature(x = "SparkDataFrame", col = "characterOrColumn"), - function(x, col) { - arrange(x, col) + function(x, col, ...) { + arrange(x, col, ...) 
}) #' Filter @@ -1735,19 +2159,20 @@ setMethod("orderBy", #' or a string containing a SQL statement #' @return A SparkDataFrame containing only the rows that meet the condition. #' @family SparkDataFrame functions +#' @aliases filter,SparkDataFrame,characterOrColumn-method #' @rdname filter #' @name filter #' @family subsetting functions #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) +#' sparkR.session() #' path <- "path/to/file.json" -#' df <- read.json(sqlContext, path) +#' df <- read.json(path) #' filter(df, "col1 > 0") #' filter(df, df$col2 != "abcdefg") #' } +#' @note filter since 1.4.0 setMethod("filter", signature(x = "SparkDataFrame", condition = "characterOrColumn"), function(x, condition) { @@ -1758,9 +2183,10 @@ setMethod("filter", dataFrame(sdf) }) -#' @family SparkDataFrame functions #' @rdname filter #' @name where +#' @aliases where,SparkDataFrame,characterOrColumn-method +#' @note where since 1.4.0 setMethod("where", signature(x = "SparkDataFrame", condition = "characterOrColumn"), function(x, condition) { @@ -1773,27 +2199,41 @@ setMethod("where", #' the subset of columns. #' #' @param x A SparkDataFrame. -#' @param colnames A character vector of column names. +#' @param ... A character vector of column names or string column names. +#' If the first argument contains a character vector, the following arguments are ignored. #' @return A SparkDataFrame with duplicate rows removed. #' @family SparkDataFrame functions -#' @rdname dropduplicates +#' @aliases dropDuplicates,SparkDataFrame-method +#' @rdname dropDuplicates #' @name dropDuplicates #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) +#' sparkR.session() #' path <- "path/to/file.json" -#' df <- read.json(sqlContext, path) +#' df <- read.json(path) #' dropDuplicates(df) +#' dropDuplicates(df, "col1", "col2") #' dropDuplicates(df, c("col1", "col2")) #' } +#' @note dropDuplicates since 2.0.0 setMethod("dropDuplicates", signature(x = "SparkDataFrame"), - function(x, colNames = columns(x)) { - stopifnot(class(colNames) == "character") - - sdf <- callJMethod(x@sdf, "dropDuplicates", as.list(colNames)) + function(x, ...) { + cols <- list(...) + if (length(cols) == 0) { + sdf <- callJMethod(x@sdf, "dropDuplicates", as.list(columns(x))) + } else { + if (!all(sapply(cols, function(c) { is.character(c) }))) { + stop("all column names should be characters") + } + col <- cols[[1]] + if (length(col) > 1) { + sdf <- callJMethod(x@sdf, "dropDuplicates", as.list(col)) + } else { + sdf <- callJMethod(x@sdf, "dropDuplicates", cols) + } + } dataFrame(sdf) }) @@ -1810,20 +2250,21 @@ setMethod("dropDuplicates", #' 'right_outer', 'rightouter', 'right', and 'leftsemi'. The default joinType is "inner". #' @return A SparkDataFrame containing the result of the join operation.
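The reworked dropDuplicates above accepts the subset of columns either as separate names through ... or as one character vector; a short sketch of the equivalent call forms, assuming an active sparkR.session():

```R
df <- createDataFrame(data.frame(col1 = c(1, 1, 2), col2 = c("a", "a", "b"),
                                 stringsAsFactors = FALSE))

dropDuplicates(df)                     # deduplicate on all columns
dropDuplicates(df, "col1", "col2")     # column names passed individually via ...
dropDuplicates(df, c("col1", "col2"))  # or as a single character vector
```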
#' @family SparkDataFrame functions +#' @aliases join,SparkDataFrame,SparkDataFrame-method #' @rdname join #' @name join #' @seealso \link{merge} #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) -#' df1 <- read.json(sqlContext, path) -#' df2 <- read.json(sqlContext, path2) +#' sparkR.session() +#' df1 <- read.json(path) +#' df2 <- read.json(path2) #' join(df1, df2) # Performs a Cartesian #' join(df1, df2, df1$col1 == df2$col2) # Performs an inner join based on expression #' join(df1, df2, df1$col1 == df2$col2, "right_outer") #' } +#' @note join since 1.4.0 setMethod("join", signature(x = "SparkDataFrame", y = "SparkDataFrame"), function(x, y, joinExpr = NULL, joinType = NULL) { @@ -1849,47 +2290,59 @@ setMethod("join", dataFrame(sdf) }) +#' Merges two data frames +#' #' @name merge -#' @title Merges two data frames #' @param x the first data frame to be joined #' @param y the second data frame to be joined #' @param by a character vector specifying the join columns. If by is not #' specified, the common column names in \code{x} and \code{y} will be used. +#' If by or both by.x and by.y are explicitly set to NULL or of length 0, the Cartesian +#' Product of x and y will be returned. #' @param by.x a character vector specifying the joining columns for x. #' @param by.y a character vector specifying the joining columns for y. +#' @param all a boolean value setting \code{all.x} and \code{all.y} +#' if any of them are unset. #' @param all.x a boolean value indicating whether all the rows in x should #' be including in the join #' @param all.y a boolean value indicating whether all the rows in y should #' be including in the join #' @param sort a logical argument indicating whether the resulting columns should be sorted +#' @param suffixes a string vector of length 2 used to make colnames of +#' \code{x} and \code{y} unique. +#' The first element is appended to each colname of \code{x}. +#' The second element is appended to each colname of \code{y}. +#' @param ... additional argument(s) passed to the method. #' @details If all.x and all.y are set to FALSE, a natural join will be returned. If #' all.x is set to TRUE and all.y is set to FALSE, a left outer join will #' be returned. If all.x is set to FALSE and all.y is set to TRUE, a right #' outer join will be returned. If all.x and all.y are set to TRUE, a full #' outer join will be returned. 
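The merge() @details above map all.x / all.y onto join types; a brief sketch of that mapping, under the assumption that both frames share a column named col1 (the data and names are illustrative):

```R
df1 <- createDataFrame(data.frame(col1 = c(1, 2, 3), x = c("a", "b", "c"),
                                  stringsAsFactors = FALSE))
df2 <- createDataFrame(data.frame(col1 = c(2, 3, 4), y = c("B", "C", "D"),
                                  stringsAsFactors = FALSE))

merge(df1, df2, by = "col1")               # all.x = all.y = FALSE: natural (inner) join
merge(df1, df2, by = "col1", all.x = TRUE) # left outer join
merge(df1, df2, by = "col1", all.y = TRUE) # right outer join
merge(df1, df2, by = "col1", all = TRUE)   # full outer join
merge(df1, df2, by = NULL)                 # Cartesian product, as documented above
```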
#' @family SparkDataFrame functions +#' @aliases merge,SparkDataFrame,SparkDataFrame-method #' @rdname merge #' @seealso \link{join} #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) -#' df1 <- read.json(sqlContext, path) -#' df2 <- read.json(sqlContext, path2) -#' merge(df1, df2) # Performs a Cartesian +#' sparkR.session() +#' df1 <- read.json(path) +#' df2 <- read.json(path2) +#' merge(df1, df2) # Performs an inner join by common columns #' merge(df1, df2, by = "col1") # Performs an inner join based on expression #' merge(df1, df2, by.x = "col1", by.y = "col2", all.y = TRUE) #' merge(df1, df2, by.x = "col1", by.y = "col2", all.x = TRUE) #' merge(df1, df2, by.x = "col1", by.y = "col2", all.x = TRUE, all.y = TRUE) #' merge(df1, df2, by.x = "col1", by.y = "col2", all = TRUE, sort = FALSE) #' merge(df1, df2, by = "col1", all = TRUE, suffixes = c("-X", "-Y")) +#' merge(df1, df2, by = NULL) # Performs a Cartesian join #' } +#' @note merge since 1.5.0 setMethod("merge", signature(x = "SparkDataFrame", y = "SparkDataFrame"), function(x, y, by = intersect(names(x), names(y)), by.x = by, by.y = by, all = FALSE, all.x = all, all.y = all, - sort = TRUE, suffixes = c("_x", "_y"), ... ) { + sort = TRUE, suffixes = c("_x", "_y"), ...) { if (length(suffixes) != 2) { stop("suffixes must have length 2") @@ -1964,15 +2417,17 @@ setMethod("merge", joinRes }) +#' Creates a list of columns by replacing the intersected ones with aliases #' #' Creates a list of columns by replacing the intersected ones with aliases. #' The name of the alias column is formed by concatanating the original column name and a suffix. #' -#' @param x a SparkDataFrame on which the -#' @param intersectedColNames a list of intersected column names +#' @param x a SparkDataFrame +#' @param intersectedColNames a list of intersected column names of the SparkDataFrame #' @param suffix a suffix for the column name #' @return list of columns #' +#' @note generateAliasesForIntersectedCols since 1.6.0 generateAliasesForIntersectedCols <- function (x, intersectedColNames, suffix) { allColNames <- names(x) # sets alias for making colnames unique in dataframe 'x' @@ -1991,70 +2446,103 @@ generateAliasesForIntersectedCols <- function (x, intersectedColNames, suffix) { cols } -#' rbind +#' Return a new SparkDataFrame containing the union of rows #' #' Return a new SparkDataFrame containing the union of rows in this SparkDataFrame -#' and another SparkDataFrame. This is equivalent to `UNION ALL` in SQL. +#' and another SparkDataFrame. This is equivalent to \code{UNION ALL} in SQL. #' Note that this does not remove duplicate rows across the two SparkDataFrames. #' #' @param x A SparkDataFrame #' @param y A SparkDataFrame #' @return A SparkDataFrame containing the result of the union. 
#' @family SparkDataFrame functions -#' @rdname rbind -#' @name unionAll +#' @rdname union +#' @name union +#' @aliases union,SparkDataFrame,SparkDataFrame-method +#' @seealso \link{rbind} #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) -#' df1 <- read.json(sqlContext, path) -#' df2 <- read.json(sqlContext, path2) -#' unioned <- unionAll(df, df2) +#' sparkR.session() +#' df1 <- read.json(path) +#' df2 <- read.json(path2) +#' unioned <- union(df, df2) +#' unions <- rbind(df, df2, df3, df4) #' } -setMethod("unionAll", +#' @note union since 2.0.0 +setMethod("union", signature(x = "SparkDataFrame", y = "SparkDataFrame"), function(x, y) { - unioned <- callJMethod(x@sdf, "unionAll", y@sdf) + unioned <- callJMethod(x@sdf, "union", y@sdf) dataFrame(unioned) }) -#' @title Union two or more SparkDataFrames -#' @description Returns a new SparkDataFrame containing rows of all parameters. +#' unionAll is deprecated - use union instead +#' @rdname union +#' @name unionAll +#' @aliases unionAll,SparkDataFrame,SparkDataFrame-method +#' @export +#' @note unionAll since 1.4.0 +setMethod("unionAll", + signature(x = "SparkDataFrame", y = "SparkDataFrame"), + function(x, y) { + .Deprecated("union") + union(x, y) + }) + +#' Union two or more SparkDataFrames #' +#' Union two or more SparkDataFrames. This is equivalent to \code{UNION ALL} in SQL. +#' Note that this does not remove duplicate rows across the two SparkDataFrames. +#' +#' @param x a SparkDataFrame. +#' @param ... additional SparkDataFrame(s). +#' @param deparse.level currently not used (put here to match the signature of +#' the base implementation). +#' @return A SparkDataFrame containing the result of the union. +#' @family SparkDataFrame functions +#' @aliases rbind,SparkDataFrame-method #' @rdname rbind #' @name rbind +#' @seealso \link{union} #' @export +#' @examples +#'\dontrun{ +#' sparkR.session() +#' unions <- rbind(df, df2, df3, df4) +#' } +#' @note rbind since 1.5.0 setMethod("rbind", signature(... = "SparkDataFrame"), function(x, ..., deparse.level = 1) { if (nargs() == 3) { - unionAll(x, ...) + union(x, ...) } else { - unionAll(x, Recall(..., deparse.level = 1)) + union(x, Recall(..., deparse.level = 1)) } }) #' Intersect #' #' Return a new SparkDataFrame containing rows only in both this SparkDataFrame -#' and another SparkDataFrame. This is equivalent to `INTERSECT` in SQL. +#' and another SparkDataFrame. This is equivalent to \code{INTERSECT} in SQL. #' #' @param x A SparkDataFrame #' @param y A SparkDataFrame #' @return A SparkDataFrame containing the result of the intersect. #' @family SparkDataFrame functions +#' @aliases intersect,SparkDataFrame,SparkDataFrame-method #' @rdname intersect #' @name intersect #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) -#' df1 <- read.json(sqlContext, path) -#' df2 <- read.json(sqlContext, path2) +#' sparkR.session() +#' df1 <- read.json(path) +#' df2 <- read.json(path2) #' intersectDF <- intersect(df, df2) #' } +#' @note intersect since 1.4.0 setMethod("intersect", signature(x = "SparkDataFrame", y = "SparkDataFrame"), function(x, y) { @@ -2065,25 +2553,26 @@ setMethod("intersect", #' except #' #' Return a new SparkDataFrame containing rows in this SparkDataFrame -#' but not in another SparkDataFrame. This is equivalent to `EXCEPT` in SQL. +#' but not in another SparkDataFrame. This is equivalent to \code{EXCEPT} in SQL. 
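The union / unionAll / rbind changes above keep UNION ALL semantics, deprecate unionAll in favor of union, and let rbind chain any number of SparkDataFrames; a minimal sketch, assuming an active sparkR.session() and identically-typed frames:

```R
df  <- createDataFrame(data.frame(id = 1:2))
df2 <- createDataFrame(data.frame(id = 3:4))
df3 <- createDataFrame(data.frame(id = 5:6))

u <- union(df, df2)         # UNION ALL semantics, duplicates are kept
all3 <- rbind(df, df2, df3) # recursively unions any number of frames
old <- unionAll(df, df2)    # still works, but now raises a deprecation warning
```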
#' -#' @param x A SparkDataFrame -#' @param y A SparkDataFrame +#' @param x a SparkDataFrame. +#' @param y a SparkDataFrame. #' @return A SparkDataFrame containing the result of the except operation. #' @family SparkDataFrame functions +#' @aliases except,SparkDataFrame,SparkDataFrame-method #' @rdname except #' @name except #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) -#' df1 <- read.json(sqlContext, path) -#' df2 <- read.json(sqlContext, path2) +#' sparkR.session() +#' df1 <- read.json(path) +#' df2 <- read.json(path2) #' exceptDF <- except(df, df2) #' } #' @rdname except #' @export +#' @note except since 1.4.0 setMethod("except", signature(x = "SparkDataFrame", y = "SparkDataFrame"), function(x, y) { @@ -2091,52 +2580,48 @@ setMethod("except", dataFrame(excepted) }) -#' Save the contents of the SparkDataFrame to a data source +#' Save the contents of SparkDataFrame to a data source. #' -#' The data source is specified by the `source` and a set of options (...). -#' If `source` is not specified, the default data source configured by +#' The data source is specified by the \code{source} and a set of options (...). +#' If \code{source} is not specified, the default data source configured by #' spark.sql.sources.default will be used. #' -#' Additionally, mode is used to specify the behavior of the save operation when -#' data already exists in the data source. There are four modes: \cr -#' append: Contents of this SparkDataFrame are expected to be appended to existing data. \cr -#' overwrite: Existing data is expected to be overwritten by the contents of this -#' SparkDataFrame. \cr -#' error: An exception is expected to be thrown. \cr -#' ignore: The save operation is expected to not save the contents of the SparkDataFrame -#' and to not change the existing data. \cr +#' Additionally, mode is used to specify the behavior of the save operation when data already +#' exists in the data source. There are four modes: +#' \itemize{ +#' \item append: Contents of this SparkDataFrame are expected to be appended to existing data. +#' \item overwrite: Existing data is expected to be overwritten by the contents of this +#' SparkDataFrame. +#' \item error: An exception is expected to be thrown. +#' \item ignore: The save operation is expected to not save the contents of the SparkDataFrame +#' and to not change the existing data. +#' } #' -#' @param df A SparkDataFrame -#' @param path A name for the table -#' @param source A name for external data source -#' @param mode One of 'append', 'overwrite', 'error', 'ignore' save mode (it is 'error' by default) +#' @param df a SparkDataFrame. +#' @param path a name for the table. +#' @param source a name for external data source. +#' @param mode one of 'append', 'overwrite', 'error', 'ignore' save mode (it is 'error' by default) +#' @param ... additional argument(s) passed to the method. 
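The itemized save modes above govern what write.df does when the target already holds data; a hedged sketch of the four modes, assuming an active sparkR.session() (the path and the parquet source are placeholders):

```R
df <- createDataFrame(iris)

write.df(df, path = "/tmp/iris_parquet", source = "parquet", mode = "error")     # default: fail if the path exists
write.df(df, path = "/tmp/iris_parquet", source = "parquet", mode = "overwrite") # replace any existing data
write.df(df, path = "/tmp/iris_parquet", source = "parquet", mode = "append")    # add to the existing data
write.df(df, path = "/tmp/iris_parquet", source = "parquet", mode = "ignore")    # keep existing data, write nothing
```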
#' #' @family SparkDataFrame functions +#' @aliases write.df,SparkDataFrame,character-method #' @rdname write.df #' @name write.df #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) +#' sparkR.session() #' path <- "path/to/file.json" -#' df <- read.json(sqlContext, path) +#' df <- read.json(path) #' write.df(df, "myfile", "parquet", "overwrite") #' saveDF(df, parquetPath2, "parquet", mode = saveMode, mergeSchema = mergeSchema) #' } +#' @note write.df since 1.4.0 setMethod("write.df", signature(df = "SparkDataFrame", path = "character"), - function(df, path, source = NULL, mode = "error", ...){ + function(df, path, source = NULL, mode = "error", ...) { if (is.null(source)) { - if (exists(".sparkRSQLsc", envir = .sparkREnv)) { - sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv) - } else if (exists(".sparkRHivesc", envir = .sparkREnv)) { - sqlContext <- get(".sparkRHivesc", envir = .sparkREnv) - } else { - stop("sparkRHive or sparkRSQL context has to be specified") - } - source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default", - "org.apache.spark.sql.parquet") + source <- getDefaultSqlSource() } jmode <- convertToJSaveMode(mode) options <- varargsToEnv(...) @@ -2146,24 +2631,25 @@ setMethod("write.df", write <- callJMethod(df@sdf, "write") write <- callJMethod(write, "format", source) write <- callJMethod(write, "mode", jmode) + write <- callJMethod(write, "options", options) write <- callJMethod(write, "save", path) }) #' @rdname write.df #' @name saveDF +#' @aliases saveDF,SparkDataFrame,character-method #' @export +#' @note saveDF since 1.4.0 setMethod("saveDF", signature(df = "SparkDataFrame", path = "character"), - function(df, path, source = NULL, mode = "error", ...){ + function(df, path, source = NULL, mode = "error", ...) { write.df(df, path, source, mode, ...) }) -#' saveAsTable -#' #' Save the contents of the SparkDataFrame to a data source as a table #' -#' The data source is specified by the `source` and a set of options (...). -#' If `source` is not specified, the default data source configured by +#' The data source is specified by the \code{source} and a set of options (...). +#' If \code{source} is not specified, the default data source configured by #' spark.sql.sources.default will be used. #' #' Additionally, mode is used to specify the behavior of the save operation when @@ -2175,36 +2661,30 @@ setMethod("saveDF", #' ignore: The save operation is expected to not save the contents of the SparkDataFrame #' and to not change the existing data. \cr #' -#' @param df A SparkDataFrame -#' @param tableName A name for the table -#' @param source A name for external data source -#' @param mode One of 'append', 'overwrite', 'error', 'ignore' save mode (it is 'error' by default) +#' @param df a SparkDataFrame. +#' @param tableName a name for the table. +#' @param source a name for external data source. +#' @param mode one of 'append', 'overwrite', 'error', 'ignore' save mode (it is 'error' by default). +#' @param ... additional option(s) passed to the method. 
#' #' @family SparkDataFrame functions +#' @aliases saveAsTable,SparkDataFrame,character-method #' @rdname saveAsTable #' @name saveAsTable #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) +#' sparkR.session() #' path <- "path/to/file.json" -#' df <- read.json(sqlContext, path) +#' df <- read.json(path) #' saveAsTable(df, "myfile") #' } +#' @note saveAsTable since 1.4.0 setMethod("saveAsTable", signature(df = "SparkDataFrame", tableName = "character"), - function(df, tableName, source = NULL, mode="error", ...){ + function(df, tableName, source = NULL, mode="error", ...) { if (is.null(source)) { - if (exists(".sparkRSQLsc", envir = .sparkREnv)) { - sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv) - } else if (exists(".sparkRHivesc", envir = .sparkREnv)) { - sqlContext <- get(".sparkRHivesc", envir = .sparkREnv) - } else { - stop("sparkRHive or sparkRSQL context has to be specified") - } - source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default", - "org.apache.spark.sql.parquet") + source <- getDefaultSqlSource() } jmode <- convertToJSaveMode(mode) options <- varargsToEnv(...) @@ -2221,24 +2701,25 @@ setMethod("saveAsTable", #' Computes statistics for numeric columns. #' If no columns are given, this function computes statistics for all numerical columns. #' -#' @param x A SparkDataFrame to be computed. -#' @param col A string of name -#' @param ... Additional expressions -#' @return A SparkDataFrame +#' @param x a SparkDataFrame to be computed. +#' @param col a string of name. +#' @param ... additional expressions. +#' @return A SparkDataFrame. #' @family SparkDataFrame functions +#' @aliases describe,SparkDataFrame,character-method describe,SparkDataFrame,ANY-method #' @rdname summary #' @name describe #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) +#' sparkR.session() #' path <- "path/to/file.json" -#' df <- read.json(sqlContext, path) +#' df <- read.json(path) #' describe(df) #' describe(df, "col1") #' describe(df, "col1", "col2") #' } +#' @note describe(SparkDataFrame, character) since 1.4.0 setMethod("describe", signature(x = "SparkDataFrame", col = "character"), function(x, col, ...) { @@ -2249,16 +2730,20 @@ setMethod("describe", #' @rdname summary #' @name describe +#' @aliases describe,SparkDataFrame-method +#' @note describe(SparkDataFrame) since 1.4.0 setMethod("describe", signature(x = "SparkDataFrame"), function(x) { - colList <- as.list(c(columns(x))) - sdf <- callJMethod(x@sdf, "describe", colList) + sdf <- callJMethod(x@sdf, "describe", list()) dataFrame(sdf) }) +#' @param object a SparkDataFrame to be summarized. #' @rdname summary #' @name summary +#' @aliases summary,SparkDataFrame-method +#' @note summary(SparkDataFrame) since 1.5.0 setMethod("summary", signature(object = "SparkDataFrame"), function(object, ...) { @@ -2266,33 +2751,38 @@ setMethod("summary", }) -#' dropna +#' A set of SparkDataFrame functions working with NA values #' -#' Returns a new SparkDataFrame omitting rows with null values. +#' dropna, na.omit - Returns a new SparkDataFrame omitting rows with null values. #' -#' @param x A SparkDataFrame. +#' @param x a SparkDataFrame. #' @param how "any" or "all". #' if "any", drop a row if it contains any nulls. #' if "all", drop a row only if all its values are null. -#' if minNonNulls is specified, how is ignored. -#' @param minNonNulls If specified, drop rows that have less than -#' minNonNulls non-null values. 
+#' if \code{minNonNulls} is specified, how is ignored. +#' @param minNonNulls if specified, drop rows that have less than +#' \code{minNonNulls} non-null values. #' This overwrites the how parameter. -#' @param cols Optional list of column names to consider. -#' @return A SparkDataFrame +#' @param cols optional list of column names to consider. In \code{fillna}, +#' columns specified in cols that do not have matching data +#' type are ignored. For example, if value is a character, and +#' subset contains a non-character column, then the non-character +#' column is simply ignored. +#' @return A SparkDataFrame. #' #' @family SparkDataFrame functions #' @rdname nafunctions +#' @aliases dropna,SparkDataFrame-method #' @name dropna #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sparkR.session() #' path <- "path/to/file.json" -#' df <- read.json(sqlCtx, path) +#' df <- read.json(path) #' dropna(df) #' } +#' @note dropna since 1.4.0 setMethod("dropna", signature(x = "SparkDataFrame"), function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) { @@ -2310,44 +2800,41 @@ setMethod("dropna", dataFrame(sdf) }) +#' @param object a SparkDataFrame. +#' @param ... further arguments to be passed to or from other methods. #' @rdname nafunctions #' @name na.omit +#' @aliases na.omit,SparkDataFrame-method #' @export +#' @note na.omit since 1.5.0 setMethod("na.omit", signature(object = "SparkDataFrame"), function(object, how = c("any", "all"), minNonNulls = NULL, cols = NULL) { dropna(object, how, minNonNulls, cols) }) -#' fillna +#' fillna - Replace null values. #' -#' Replace null values. -#' -#' @param x A SparkDataFrame. -#' @param value Value to replace null values with. +#' @param value value to replace null values with. #' Should be an integer, numeric, character or named list. #' If the value is a named list, then cols is ignored and #' value must be a mapping from column name (character) to #' replacement value. The replacement value must be an #' integer, numeric or character. -#' @param cols optional list of column names to consider. -#' Columns specified in cols that do not have matching data -#' type are ignored. For example, if value is a character, and -#' subset contains a non-character column, then the non-character -#' column is simply ignored. #' #' @rdname nafunctions #' @name fillna +#' @aliases fillna,SparkDataFrame-method #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sparkR.session() #' path <- "path/to/file.json" -#' df <- read.json(sqlCtx, path) +#' df <- read.json(path) #' fillna(df, 1) #' fillna(df, list("age" = 20, "name" = "unknown")) #' } +#' @note fillna since 1.4.0 setMethod("fillna", signature(x = "SparkDataFrame"), function(x, value, cols = NULL) { @@ -2392,33 +2879,41 @@ setMethod("fillna", dataFrame(sdf) }) +#' Download data from a SparkDataFrame into a R data.frame +#' #' This function downloads the contents of a SparkDataFrame into an R's data.frame. #' Since data.frames are held in memory, ensure that you have enough memory #' in your system to accommodate the contents. #' -#' @title Download data from a SparkDataFrame into a data.frame -#' @param x a SparkDataFrame -#' @return a data.frame +#' @param x a SparkDataFrame. +#' @param row.names \code{NULL} or a character vector giving the row names for the data frame. +#' @param optional If \code{TRUE}, converting column names is optional. +#' @param ... additional arguments to pass to base::as.data.frame. 
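The consolidated NA-handling docs above (dropna / na.omit / fillna) expose three knobs: how, minNonNulls, and a fill value that is either a scalar or a named list; a small sketch, assuming an active sparkR.session() (the toy data is illustrative):

```R
df <- createDataFrame(data.frame(age = c(20, NA, 30),
                                 name = c("a", NA, NA),
                                 stringsAsFactors = FALSE))

dropna(df)                   # drop rows containing any null
dropna(df, how = "all")      # drop rows only when every value is null
dropna(df, minNonNulls = 2)  # keep rows with at least 2 non-null values; 'how' is ignored
fillna(df, 0, cols = "age")  # scalar fill; columns of a non-matching type are ignored
fillna(df, list(age = 20, name = "unknown"))  # per-column replacements via a named list
```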
+#' @return A data.frame. #' @family SparkDataFrame functions +#' @aliases as.data.frame,SparkDataFrame-method #' @rdname as.data.frame #' @examples \dontrun{ #' -#' irisDF <- createDataFrame(sqlContext, iris) +#' irisDF <- createDataFrame(iris) #' df <- as.data.frame(irisDF[irisDF$Species == "setosa", ]) #' } +#' @note as.data.frame since 1.6.0 setMethod("as.data.frame", signature(x = "SparkDataFrame"), function(x, row.names = NULL, optional = FALSE, ...) { as.data.frame(collect(x), row.names, optional, ...) }) +#' Attach SparkDataFrame to R search path +#' #' The specified SparkDataFrame is attached to the R search path. This means that #' the SparkDataFrame is searched by R when evaluating a variable, so columns in #' the SparkDataFrame can be accessed by simply giving their names. #' #' @family SparkDataFrame functions #' @rdname attach -#' @title Attach SparkDataFrame to R search path +#' @aliases attach,SparkDataFrame-method #' @param what (SparkDataFrame) The SparkDataFrame to attach #' @param pos (integer) Specify position in search() where to attach. #' @param name (character) Name to use for the attached SparkDataFrame. Names @@ -2431,6 +2926,7 @@ setMethod("as.data.frame", #' summary(Sepal_Width) #' } #' @seealso \link{detach} +#' @note attach since 1.6.0 setMethod("attach", signature(what = "SparkDataFrame"), function(what, pos = 2, name = deparse(substitute(what)), warn.conflicts = TRUE) { @@ -2438,6 +2934,8 @@ setMethod("attach", attach(newEnv, pos = pos, name = name, warn.conflicts = warn.conflicts) }) +#' Evaluate a R expression in an environment constructed from a SparkDataFrame +#' #' Evaluate a R expression in an environment constructed from a SparkDataFrame #' with() allows access to columns of a SparkDataFrame by simply referring to #' their name. It appends every column of a SparkDataFrame into a new @@ -2445,7 +2943,8 @@ setMethod("attach", #' environment. #' #' @rdname with -#' @title Evaluate a R expression in an environment constructed from a SparkDataFrame +#' @family SparkDataFrame functions +#' @aliases with,SparkDataFrame-method #' @param data (SparkDataFrame) SparkDataFrame to use for constructing an environment. #' @param expr (expression) Expression to evaluate. #' @param ... arguments to be passed to future methods. @@ -2454,6 +2953,7 @@ setMethod("attach", #' with(irisDf, nrow(Sepal_Width)) #' } #' @seealso \link{attach} +#' @note with since 1.6.0 setMethod("with", signature(data = "SparkDataFrame"), function(data, expr, ...) { @@ -2461,20 +2961,24 @@ setMethod("with", eval(substitute(expr), envir = newEnv, enclos = newEnv) }) +#' Compactly display the structure of a dataset +#' #' Display the structure of a SparkDataFrame, including column names, column types, as well as a #' a small sample of rows. +#' #' @name str -#' @title Compactly display the structure of a dataset #' @rdname str +#' @aliases str,SparkDataFrame-method #' @family SparkDataFrame functions #' @param object a SparkDataFrame #' @examples \dontrun{ #' # Create a SparkDataFrame from the Iris dataset -#' irisDF <- createDataFrame(sqlContext, iris) +#' irisDF <- createDataFrame(iris) #' #' # Show the structure of the SparkDataFrame #' str(irisDF) #' } +#' @note str since 1.6.1 setMethod("str", signature(object = "SparkDataFrame"), function(object) { @@ -2539,24 +3043,26 @@ setMethod("str", #' Returns a new SparkDataFrame with columns dropped. #' This is a no-op if schema doesn't contain column name(s). #' -#' @param x A SparkDataFrame. 
-#' @param cols A character vector of column names or a Column. -#' @return A SparkDataFrame +#' @param x a SparkDataFrame. +#' @param col a character vector of column names or a Column. +#' @param ... further arguments to be passed to or from other methods. +#' @return A SparkDataFrame. #' #' @family SparkDataFrame functions #' @rdname drop #' @name drop +#' @aliases drop,SparkDataFrame-method #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sparkR.session() #' path <- "path/to/file.json" -#' df <- read.json(sqlCtx, path) +#' df <- read.json(path) #' drop(df, "col1") #' drop(df, c("col1", "col2")) #' drop(df, df$col1) #' } +#' @note drop since 2.0.0 setMethod("drop", signature(x = "SparkDataFrame"), function(x, col) { @@ -2571,29 +3077,35 @@ setMethod("drop", }) # Expose base::drop +#' @name drop +#' @rdname drop +#' @aliases drop,ANY-method +#' @export setMethod("drop", signature(x = "ANY"), function(x) { base::drop(x) }) +#' Compute histogram statistics for given column +#' #' This function computes a histogram for a given SparkR Column. -#' +#' #' @name histogram -#' @title Histogram #' @param nbins the number of bins (optional). Default value is 10. +#' @param col the column as Character string or a Column to build the histogram from. #' @param df the SparkDataFrame containing the Column to build the histogram from. -#' @param colname the name of the column to build the histogram from. #' @return a data.frame with the histogram statistics, i.e., counts and centroids. #' @rdname histogram +#' @aliases histogram,SparkDataFrame,characterOrColumn-method #' @family SparkDataFrame functions #' @export -#' @examples +#' @examples #' \dontrun{ -#' +#' #' # Create a SparkDataFrame from the Iris dataset -#' irisDF <- createDataFrame(sqlContext, iris) -#' +#' irisDF <- createDataFrame(iris) +#' #' # Compute histogram statistics #' histStats <- histogram(irisDF, irisDF$Sepal_Length, nbins = 12) #' @@ -2603,8 +3115,9 @@ setMethod("drop", #' require(ggplot2) #' plot <- ggplot(histStats, aes(x = centroids, y = counts)) + #' geom_bar(stat = "identity") + -#' xlab("Sepal_Length") + ylab("Frequency") -#' } +#' xlab("Sepal_Length") + ylab("Frequency") +#' } +#' @note histogram since 2.0.0 setMethod("histogram", signature(df = "SparkDataFrame", col = "characterOrColumn"), function(df, col, nbins = 10) { @@ -2644,7 +3157,7 @@ setMethod("histogram", # columns AND all of them have names 100 characters long (which is very unlikely), # AND they run 1 billion histograms, the probability of collision will roughly be # 1 in 4.4 x 10 ^ 96 - colname <- paste(base:::sample(c(letters, LETTERS), + colname <- paste(base::sample(c(letters, LETTERS), size = min(max(nchar(colnames(df))) + 1, 100), replace = TRUE), collapse = "") @@ -2696,40 +3209,82 @@ setMethod("histogram", return(histStats) }) -#' Saves the content of the SparkDataFrame to an external database table via JDBC +#' Save the content of SparkDataFrame to an external database table via JDBC. #' -#' Additional JDBC database connection properties can be set (...) +#' Save the content of the SparkDataFrame to an external database table via JDBC. Additional JDBC +#' database connection properties can be set (...) #' #' Also, mode is used to specify the behavior of the save operation when -#' data already exists in the data source. There are four modes: \cr -#' append: Contents of this SparkDataFrame are expected to be appended to existing data. 
\cr -#' overwrite: Existing data is expected to be overwritten by the contents of this -#' SparkDataFrame. \cr -#' error: An exception is expected to be thrown. \cr -#' ignore: The save operation is expected to not save the contents of the SparkDataFrame -#' and to not change the existing data. \cr +#' data already exists in the data source. There are four modes: +#' \itemize{ +#' \item append: Contents of this SparkDataFrame are expected to be appended to existing data. +#' \item overwrite: Existing data is expected to be overwritten by the contents of this +#' SparkDataFrame. +#' \item error: An exception is expected to be thrown. +#' \item ignore: The save operation is expected to not save the contents of the SparkDataFrame +#' and to not change the existing data. +#' } #' -#' @param x A SparkDataFrame -#' @param url JDBC database url of the form `jdbc:subprotocol:subname` -#' @param tableName The name of the table in the external database -#' @param mode One of 'append', 'overwrite', 'error', 'ignore' save mode (it is 'error' by default) +#' @param x a SparkDataFrame. +#' @param url JDBC database url of the form \code{jdbc:subprotocol:subname}. +#' @param tableName the name of the table in the external database. +#' @param mode one of 'append', 'overwrite', 'error', 'ignore' save mode (it is 'error' by default). +#' @param ... additional JDBC database connection properties. #' @family SparkDataFrame functions #' @rdname write.jdbc #' @name write.jdbc +#' @aliases write.jdbc,SparkDataFrame,character,character-method #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) +#' sparkR.session() #' jdbcUrl <- "jdbc:mysql://localhost:3306/databasename" #' write.jdbc(df, jdbcUrl, "table", user = "username", password = "password") #' } +#' @note write.jdbc since 2.0.0 setMethod("write.jdbc", signature(x = "SparkDataFrame", url = "character", tableName = "character"), - function(x, url, tableName, mode = "error", ...){ + function(x, url, tableName, mode = "error", ...) { jmode <- convertToJSaveMode(mode) jprops <- varargsToJProperties(...) write <- callJMethod(x@sdf, "write") write <- callJMethod(write, "mode", jmode) invisible(callJMethod(write, "jdbc", url, tableName, jprops)) }) + +#' randomSplit +#' +#' Return a list of randomly split dataframes with the provided weights.
+#' +#' @param x A SparkDataFrame +#' @param weights A vector of weights for splits, will be normalized if they don't sum to 1 +#' @param seed A seed to use for random split +#' +#' @family SparkDataFrame functions +#' @aliases randomSplit,SparkDataFrame,numeric-method +#' @rdname randomSplit +#' @name randomSplit +#' @export +#' @examples +#'\dontrun{ +#' sparkR.session() +#' df <- createDataFrame(data.frame(id = 1:1000)) +#' df_list <- randomSplit(df, c(2, 3, 5), 0) +#' # df_list contains 3 SparkDataFrames with each having about 200, 300 and 500 rows respectively +#' sapply(df_list, count) +#' } +#' @note randomSplit since 2.0.0 +setMethod("randomSplit", + signature(x = "SparkDataFrame", weights = "numeric"), + function(x, weights, seed) { + if (!all(sapply(weights, function(c) { c >= 0 }))) { + stop("all weight values should not be negative") + } + normalized_list <- as.list(weights / sum(weights)) + if (!missing(seed)) { + sdfs <- callJMethod(x@sdf, "randomSplit", normalized_list, as.integer(seed)) + } else { + sdfs <- callJMethod(x@sdf, "randomSplit", normalized_list) + } + sapply(sdfs, dataFrame) + }) diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R index 34d29ddbfdd52..6cd0704003f1a 100644 --- a/R/pkg/R/RDD.R +++ b/R/pkg/R/RDD.R @@ -19,9 +19,11 @@ setOldClass("jobj") -#' @title S4 class that represents an RDD -#' @description RDD can be created using functions like +#' S4 class that represents an RDD +#' +#' RDD can be created using functions like #' \code{parallelize}, \code{textFile} etc. +#' #' @rdname RDD #' @seealso parallelize, textFile #' @slot env An R environment that stores bookkeeping states of the RDD @@ -65,7 +67,7 @@ setMethod("initialize", "RDD", function(.Object, jrdd, serializedMode, .Object }) -setMethod("show", "RDD", +setMethod("showRDD", "RDD", function(object) { cat(paste(callJMethod(getJRDD(object), "toString"), "\n", sep = "")) }) @@ -213,7 +215,7 @@ setValidity("RDD", #' @rdname cache-methods #' @aliases cache,RDD-method #' @noRd -setMethod("cache", +setMethod("cacheRDD", signature(x = "RDD"), function(x) { callJMethod(getJRDD(x), "cache") @@ -233,12 +235,12 @@ setMethod("cache", #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, 1:10, 2L) -#' persist(rdd, "MEMORY_AND_DISK") +#' persistRDD(rdd, "MEMORY_AND_DISK") #'} #' @rdname persist #' @aliases persist,RDD-method #' @noRd -setMethod("persist", +setMethod("persistRDD", signature(x = "RDD", newLevel = "character"), function(x, newLevel = "MEMORY_ONLY") { callJMethod(getJRDD(x), "persist", getStorageLevel(newLevel)) @@ -257,12 +259,12 @@ setMethod("persist", #' sc <- sparkR.init() #' rdd <- parallelize(sc, 1:10, 2L) #' cache(rdd) # rdd@@env$isCached == TRUE -#' unpersist(rdd) # rdd@@env$isCached == FALSE +#' unpersistRDD(rdd) # rdd@@env$isCached == FALSE #'} #' @rdname unpersist-methods #' @aliases unpersist,RDD-method #' @noRd -setMethod("unpersist", +setMethod("unpersistRDD", signature(x = "RDD"), function(x) { callJMethod(getJRDD(x), "unpersist") @@ -343,13 +345,13 @@ setMethod("numPartitions", #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, 1:10, 2L) -#' collect(rdd) # list from 1 to 10 +#' collectRDD(rdd) # list from 1 to 10 #' collectPartition(rdd, 0L) # list from 1 to 5 #'} #' @rdname collect-methods #' @aliases collect,RDD-method #' @noRd -setMethod("collect", +setMethod("collectRDD", signature(x = "RDD"), function(x, flatten = TRUE) { # Assumes a pairwise RDD is backed by a JavaPairRDD. 
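The randomSplit method added in the DataFrame.R hunk above normalizes the supplied weights before splitting; a short sketch, assuming an active sparkR.session() (the split is random, so the counts are only approximate):

```R
df <- createDataFrame(data.frame(id = 1:1000))

# weights 2, 3, 5 are normalized to 0.2, 0.3, 0.5
parts <- randomSplit(df, c(2, 3, 5), seed = 0)
sapply(parts, count)  # roughly 200, 300 and 500 rows
```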
@@ -395,7 +397,7 @@ setMethod("collectPartition", setMethod("collectAsMap", signature(x = "RDD"), function(x) { - pairList <- collect(x) + pairList <- collectRDD(x) map <- new.env() lapply(pairList, function(i) { assign(as.character(i[[1]]), i[[2]], envir = map) }) as.list(map) @@ -409,30 +411,30 @@ setMethod("collectAsMap", #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, 1:10) -#' count(rdd) # 10 +#' countRDD(rdd) # 10 #' length(rdd) # Same as count #'} #' @rdname count #' @aliases count,RDD-method #' @noRd -setMethod("count", +setMethod("countRDD", signature(x = "RDD"), function(x) { countPartition <- function(part) { as.integer(length(part)) } valsRDD <- lapplyPartition(x, countPartition) - vals <- collect(valsRDD) + vals <- collectRDD(valsRDD) sum(as.integer(vals)) }) #' Return the number of elements in the RDD #' @rdname count #' @noRd -setMethod("length", +setMethod("lengthRDD", signature(x = "RDD"), function(x) { - count(x) + countRDD(x) }) #' Return the count of each unique value in this RDD as a list of @@ -458,7 +460,7 @@ setMethod("countByValue", signature(x = "RDD"), function(x) { ones <- lapply(x, function(item) { list(item, 1L) }) - collect(reduceByKey(ones, `+`, getNumPartitions(x))) + collectRDD(reduceByKey(ones, `+`, getNumPartitions(x))) }) #' Apply a function to all elements @@ -477,7 +479,7 @@ setMethod("countByValue", #' sc <- sparkR.init() #' rdd <- parallelize(sc, 1:10) #' multiplyByTwo <- lapply(rdd, function(x) { x * 2 }) -#' collect(multiplyByTwo) # 2,4,6... +#' collectRDD(multiplyByTwo) # 2,4,6... #'} setMethod("lapply", signature(X = "RDD", FUN = "function"), @@ -497,9 +499,9 @@ setMethod("map", lapply(X, FUN) }) -#' Flatten results after apply a function to all elements +#' Flatten results after applying a function to all elements #' -#' This function return a new RDD by first applying a function to all +#' This function returns a new RDD by first applying a function to all #' elements of this RDD, and then flattening the results. #' #' @param X The RDD to apply the transformation. @@ -510,7 +512,7 @@ setMethod("map", #' sc <- sparkR.init() #' rdd <- parallelize(sc, 1:10) #' multiplyByTwo <- flatMap(rdd, function(x) { list(x*2, x*10) }) -#' collect(multiplyByTwo) # 2,20,4,40,6,60... +#' collectRDD(multiplyByTwo) # 2,20,4,40,6,60... 
#'} #' @rdname flatMap #' @aliases flatMap,RDD,function-method @@ -539,7 +541,7 @@ setMethod("flatMap", #' sc <- sparkR.init() #' rdd <- parallelize(sc, 1:10) #' partitionSum <- lapplyPartition(rdd, function(part) { Reduce("+", part) }) -#' collect(partitionSum) # 15, 40 +#' collectRDD(partitionSum) # 15, 40 #'} #' @rdname lapplyPartition #' @aliases lapplyPartition,RDD,function-method @@ -574,7 +576,7 @@ setMethod("mapPartitions", #' rdd <- parallelize(sc, 1:10, 5L) #' prod <- lapplyPartitionsWithIndex(rdd, function(partIndex, part) { #' partIndex * Reduce("+", part) }) -#' collect(prod, flatten = FALSE) # 0, 7, 22, 45, 76 +#' collectRDD(prod, flatten = FALSE) # 0, 7, 22, 45, 76 #'} #' @rdname lapplyPartitionsWithIndex #' @aliases lapplyPartitionsWithIndex,RDD,function-method @@ -605,7 +607,7 @@ setMethod("mapPartitionsWithIndex", #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, 1:10) -#' unlist(collect(filterRDD(rdd, function (x) { x < 3 }))) # c(1, 2) +#' unlist(collectRDD(filterRDD(rdd, function (x) { x < 3 }))) # c(1, 2) #'} # nolint end #' @rdname filterRDD @@ -654,7 +656,7 @@ setMethod("reduce", Reduce(func, part) } - partitionList <- collect(lapplyPartition(x, reducePartition), + partitionList <- collectRDD(lapplyPartition(x, reducePartition), flatten = FALSE) Reduce(func, partitionList) }) @@ -713,7 +715,7 @@ setMethod("sumRDD", reduce(x, "+") }) -#' Applies a function to all elements in an RDD, and force evaluation. +#' Applies a function to all elements in an RDD, and forces evaluation. #' #' @param x The RDD to apply the function #' @param func The function to be applied. @@ -734,10 +736,10 @@ setMethod("foreach", lapply(x, func) NULL } - invisible(collect(mapPartitions(x, partition.func))) + invisible(collectRDD(mapPartitions(x, partition.func))) }) -#' Applies a function to each partition in an RDD, and force evaluation. +#' Applies a function to each partition in an RDD, and forces evaluation. #' #' @examples #'\dontrun{ @@ -751,7 +753,7 @@ setMethod("foreach", setMethod("foreachPartition", signature(x = "RDD", func = "function"), function(x, func) { - invisible(collect(mapPartitions(x, func))) + invisible(collectRDD(mapPartitions(x, func))) }) #' Take elements from an RDD. @@ -766,13 +768,13 @@ setMethod("foreachPartition", #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, 1:10) -#' take(rdd, 2L) # list(1, 2) +#' takeRDD(rdd, 2L) # list(1, 2) #'} # nolint end #' @rdname take #' @aliases take,RDD,numeric-method #' @noRd -setMethod("take", +setMethod("takeRDD", signature(x = "RDD", num = "numeric"), function(x, num) { resList <- list() @@ -815,13 +817,13 @@ setMethod("take", #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, 1:10) -#' first(rdd) +#' firstRDD(rdd) #' } #' @noRd -setMethod("first", +setMethod("firstRDD", signature(x = "RDD"), function(x) { - take(x, 1)[[1]] + takeRDD(x, 1)[[1]] }) #' Removes the duplicates from RDD. 
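Because the renamed RDD methods above are tagged @noRd and kept internal, reaching them from user code goes through the triple-colon namespace operator; a hedged sketch, assuming the internal API keeps the shape shown in these hunks:

```R
library(SparkR)
sc <- sparkR.init()  # deprecated entry point, matching the examples above

rdd <- SparkR:::parallelize(sc, 1:10)
SparkR:::takeRDD(rdd, 2L)  # list(1, 2)
SparkR:::firstRDD(rdd)     # 1
SparkR:::countRDD(rdd)     # 10
```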
@@ -836,13 +838,13 @@ setMethod("first", #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, c(1,2,2,3,3,3)) -#' sort(unlist(collect(distinct(rdd)))) # c(1, 2, 3) +#' sort(unlist(collectRDD(distinctRDD(rdd)))) # c(1, 2, 3) #'} # nolint end #' @rdname distinct #' @aliases distinct,RDD-method #' @noRd -setMethod("distinct", +setMethod("distinctRDD", signature(x = "RDD"), function(x, numPartitions = SparkR:::getNumPartitions(x)) { identical.mapped <- lapply(x, function(x) { list(x, NULL) }) @@ -866,8 +868,8 @@ setMethod("distinct", #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, 1:10) -#' collect(sampleRDD(rdd, FALSE, 0.5, 1618L)) # ~5 distinct elements -#' collect(sampleRDD(rdd, TRUE, 0.5, 9L)) # ~5 elements possibly with duplicates +#' collectRDD(sampleRDD(rdd, FALSE, 0.5, 1618L)) # ~5 distinct elements +#' collectRDD(sampleRDD(rdd, TRUE, 0.5, 9L)) # ~5 elements possibly with duplicates #'} #' @rdname sampleRDD #' @aliases sampleRDD,RDD @@ -885,17 +887,17 @@ setMethod("sampleRDD", # Discards some random values to ensure each partition has a # different random seed. - runif(partIndex) + stats::runif(partIndex) for (elem in part) { if (withReplacement) { - count <- rpois(1, fraction) + count <- stats::rpois(1, fraction) if (count > 0) { res[ (len + 1) : (len + count) ] <- rep(list(elem), count) len <- len + count } } else { - if (runif(1) < fraction) { + if (stats::runif(1) < fraction) { len <- len + 1 res[[len]] <- elem } @@ -940,7 +942,7 @@ setMethod("takeSample", signature(x = "RDD", withReplacement = "logical", fraction <- 0.0 total <- 0 multiplier <- 3.0 - initialCount <- count(x) + initialCount <- countRDD(x) maxSelected <- 0 MAXINT <- .Machine$integer.max @@ -962,16 +964,16 @@ setMethod("takeSample", signature(x = "RDD", withReplacement = "logical", } set.seed(seed) - samples <- collect(sampleRDD(x, withReplacement, fraction, - as.integer(ceiling(runif(1, + samples <- collectRDD(sampleRDD(x, withReplacement, fraction, + as.integer(ceiling(stats::runif(1, -MAXINT, MAXINT))))) # If the first sample didn't turn out large enough, keep trying to # take samples; this shouldn't happen often because we use a big # multiplier for the initial size while (length(samples) < total) - samples <- collect(sampleRDD(x, withReplacement, fraction, - as.integer(ceiling(runif(1, + samples <- collectRDD(sampleRDD(x, withReplacement, fraction, + as.integer(ceiling(stats::runif(1, -MAXINT, MAXINT))))) @@ -988,7 +990,7 @@ setMethod("takeSample", signature(x = "RDD", withReplacement = "logical", #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list(1, 2, 3)) -#' collect(keyBy(rdd, function(x) { x*x })) # list(list(1, 1), list(4, 2), list(9, 3)) +#' collectRDD(keyBy(rdd, function(x) { x*x })) # list(list(1, 1), list(4, 2), list(9, 3)) #'} # nolint end #' @rdname keyBy @@ -1017,15 +1019,19 @@ setMethod("keyBy", #' sc <- sparkR.init() #' rdd <- parallelize(sc, list(1, 2, 3, 4, 5, 6, 7), 4L) #' getNumPartitions(rdd) # 4 -#' getNumPartitions(repartition(rdd, 2L)) # 2 +#' getNumPartitions(repartitionRDD(rdd, 2L)) # 2 #'} #' @rdname repartition #' @aliases repartition,RDD #' @noRd -setMethod("repartition", - signature(x = "RDD", numPartitions = "numeric"), +setMethod("repartitionRDD", + signature(x = "RDD"), function(x, numPartitions) { - coalesce(x, numPartitions, TRUE) + if (!is.null(numPartitions) && is.numeric(numPartitions)) { + coalesce(x, numPartitions, TRUE) + } else { + stop("Please specify the number of partitions") + } }) #' Return a new RDD that is reduced into
numPartitions partitions. @@ -1058,7 +1064,7 @@ setMethod("coalesce", }) } shuffled <- lapplyPartitionsWithIndex(x, func) - repartitioned <- partitionBy(shuffled, numPartitions) + repartitioned <- partitionByRDD(shuffled, numPartitions) values(repartitioned) } else { jrdd <- callJMethod(getJRDD(x), "coalesce", numPartitions, shuffle) @@ -1129,7 +1135,7 @@ setMethod("saveAsTextFile", #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list(3, 2, 1)) -#' collect(sortBy(rdd, function(x) { x })) # list (1, 2, 3) +#' collectRDD(sortBy(rdd, function(x) { x })) # list (1, 2, 3) #'} # nolint end #' @rdname sortBy @@ -1298,7 +1304,7 @@ setMethod("aggregateRDD", Reduce(seqOp, part, zeroValue) } - partitionList <- collect(lapplyPartition(x, partitionFunc), + partitionList <- collectRDD(lapplyPartition(x, partitionFunc), flatten = FALSE) Reduce(combOp, partitionList, zeroValue) }) @@ -1316,7 +1322,7 @@ setMethod("aggregateRDD", #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, 1:10) -#' collect(pipeRDD(rdd, "more") +#' pipeRDD(rdd, "more") #' Output: c("1", "2", ..., "10") #'} #' @aliases pipeRDD,RDD,character-method @@ -1391,7 +1397,7 @@ setMethod("setName", #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) -#' collect(zipWithUniqueId(rdd)) +#' collectRDD(zipWithUniqueId(rdd)) #' # list(list("a", 0), list("b", 3), list("c", 1), list("d", 4), list("e", 2)) #'} # nolint end @@ -1434,7 +1440,7 @@ setMethod("zipWithUniqueId", #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) -#' collect(zipWithIndex(rdd)) +#' collectRDD(zipWithIndex(rdd)) #' # list(list("a", 0), list("b", 1), list("c", 2), list("d", 3), list("e", 4)) #'} # nolint end @@ -1446,7 +1452,7 @@ setMethod("zipWithIndex", function(x) { n <- getNumPartitions(x) if (n > 1) { - nums <- collect(lapplyPartition(x, + nums <- collectRDD(lapplyPartition(x, function(part) { list(length(part)) })) @@ -1482,7 +1488,7 @@ setMethod("zipWithIndex", #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, as.list(1:4), 2L) -#' collect(glom(rdd)) +#' collectRDD(glom(rdd)) #' # list(list(1, 2), list(3, 4)) #'} # nolint end @@ -1550,7 +1556,7 @@ setMethod("unionRDD", #' sc <- sparkR.init() #' rdd1 <- parallelize(sc, 0:4) #' rdd2 <- parallelize(sc, 1000:1004) -#' collect(zipRDD(rdd1, rdd2)) +#' collectRDD(zipRDD(rdd1, rdd2)) #' # list(list(0, 1000), list(1, 1001), list(2, 1002), list(3, 1003), list(4, 1004)) #'} # nolint end @@ -1622,7 +1628,7 @@ setMethod("cartesian", #' sc <- sparkR.init() #' rdd1 <- parallelize(sc, list(1, 1, 2, 2, 3, 4)) #' rdd2 <- parallelize(sc, list(2, 4)) -#' collect(subtract(rdd1, rdd2)) +#' collectRDD(subtract(rdd1, rdd2)) #' # list(1, 1, 3) #'} # nolint end @@ -1656,7 +1662,7 @@ setMethod("subtract", #' sc <- sparkR.init() #' rdd1 <- parallelize(sc, list(1, 10, 2, 3, 4, 5)) #' rdd2 <- parallelize(sc, list(1, 6, 2, 3, 7, 8)) -#' collect(sortBy(intersection(rdd1, rdd2), function(x) { x })) +#' collectRDD(sortBy(intersection(rdd1, rdd2), function(x) { x })) #' # list(1, 2, 3) #'} # nolint end @@ -1693,7 +1699,7 @@ setMethod("intersection", #' rdd1 <- parallelize(sc, 1:2, 2L) # 1, 2 #' rdd2 <- parallelize(sc, 1:4, 2L) # 1:2, 3:4 #' rdd3 <- parallelize(sc, 1:6, 2L) # 1:3, 4:6 -#' collect(zipPartitions(rdd1, rdd2, rdd3, +#' collectRDD(zipPartitions(rdd1, rdd2, rdd3, #' func = function(x, y, z) { list(list(x, y, z))} )) #' # list(list(1, c(1,2), c(1,2,3)), list(2, c(3,4), c(4,5,6))) #'} diff --git a/R/pkg/R/SQLContext.R 
b/R/pkg/R/SQLContext.R index 3824e0a99557c..ce531c3f88863 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -37,7 +37,49 @@ getInternalType <- function(x) { stop(paste("Unsupported type for SparkDataFrame:", class(x)))) } +#' Temporary function to reroute old S3 method calls to the new implementation +#' This function is specifically implemented to remove SQLContext from the parameter list. +#' It determines the target to route the call to by checking the parent of this callsite (say 'func'). +#' The target should be called 'func.default'. +#' We need to check the class of x to ensure it is SQLContext/HiveContext before dispatching. +#' @param newFuncSig name of the function the user should call instead in the deprecation message +#' @param x the first parameter of the original call +#' @param ... the rest of the parameters to pass along +#' @return whatever the target returns +#' @noRd +dispatchFunc <- function(newFuncSig, x, ...) { + # When called with SparkR::createDataFrame, sys.call()[[1]] returns c(::, SparkR, createDataFrame) + callsite <- as.character(sys.call(sys.parent())[[1]]) + funcName <- callsite[[length(callsite)]] + f <- get(paste0(funcName, ".default")) + # Strip sqlContext from list of parameters and then pass the rest along. + contextNames <- c("org.apache.spark.sql.SQLContext", + "org.apache.spark.sql.hive.HiveContext", + "org.apache.spark.sql.hive.test.TestHiveContext", + "org.apache.spark.sql.SparkSession") + if (missing(x) && length(list(...)) == 0) { + f() + } else if (class(x) == "jobj" && + any(grepl(paste(contextNames, collapse = "|"), getClassName.jobj(x)))) { + .Deprecated(newFuncSig, old = paste0(funcName, "(sqlContext...)")) + f(...) + } else { + f(x, ...) + } +} + +#' Return the SparkSession +#' @noRd +getSparkSession <- function() { + if (exists(".sparkRsession", envir = .sparkREnv)) { + get(".sparkRsession", envir = .sparkREnv) + } else { + stop("SparkSession not initialized") + } +} + #' infer the SQL type +#' @noRd infer_type <- function(x) { if (is.null(x)) { stop("can not infer type from NULL") @@ -70,28 +112,100 @@ infer_type <- function(x) { } } +#' Get Runtime Config from the current active SparkSession +#' +#' Get Runtime Config from the current active SparkSession. +#' To change SparkSession Runtime Config, please see \code{sparkR.session()}.
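In effect, dispatchFunc lets the old context-first calling convention and the new session-based one coexist during the deprecation window; a sketch of both paths (sqlContext below stands for a legacy SQLContext jobj):

```R
# New form: dispatchFunc forwards directly to createDataFrame.default
df1 <- createDataFrame(iris)

# Legacy form: the jobj's class matches one of contextNames, so a deprecation
# warning suggesting "createDataFrame(data, schema = NULL)" is raised and the
# context argument is stripped before calling createDataFrame.default
df2 <- createDataFrame(sqlContext, iris)
```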
+#' +#' @param key (optional) The key of the config to get; if omitted, all config values are returned +#' @param defaultValue (optional) The default value of the config to return if the config is not +#' set; if omitted, the call fails if the config key is not set +#' @return a list of config values with keys as their names +#' @rdname sparkR.conf +#' @name sparkR.conf +#' @export +#' @examples +#'\dontrun{ +#' sparkR.session() +#' allConfigs <- sparkR.conf() +#' masterValue <- unlist(sparkR.conf("spark.master")) +#' namedConfig <- sparkR.conf("spark.executor.memory", "0g") +#' } +#' @note sparkR.conf since 2.0.0 +sparkR.conf <- function(key, defaultValue) { + sparkSession <- getSparkSession() + if (missing(key)) { + m <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getSessionConf", sparkSession) + as.list(m, all.names = TRUE, sorted = TRUE) + } else { + conf <- callJMethod(sparkSession, "conf") + value <- if (missing(defaultValue)) { + tryCatch(callJMethod(conf, "get", key), + error = function(e) { + if (any(grep("java.util.NoSuchElementException", as.character(e)))) { + stop(paste0("Config '", key, "' is not set")) + } else { + stop(paste0("Unknown error: ", as.character(e))) + } + }) + } else { + callJMethod(conf, "get", key, defaultValue) + } + l <- setNames(list(value), key) + l + } +} + +#' Get version of Spark on which this application is running +#' +#' Get version of Spark on which this application is running. +#' +#' @return a character string of the Spark version +#' @rdname sparkR.version +#' @name sparkR.version +#' @export +#' @examples +#'\dontrun{ +#' sparkR.session() +#' version <- sparkR.version() +#' } +#' @note sparkR.version since 2.0.1 +sparkR.version <- function() { + sparkSession <- getSparkSession() + callJMethod(sparkSession, "version") +} + +getDefaultSqlSource <- function() { + l <- sparkR.conf("spark.sql.sources.default", "org.apache.spark.sql.parquet") + l[["spark.sql.sources.default"]] +} + #' Create a SparkDataFrame #' #' Converts R data.frame or list into SparkDataFrame. #' -#' @param sqlContext A SQLContext -#' @param data An RDD or list or data.frame -#' @param schema a list of column names or named list (StructType), optional -#' @return a SparkDataFrame +#' @param data an RDD or list or data.frame. +#' @param schema a list of column names or named list (StructType), optional. +#' @return A SparkDataFrame. #' @rdname createDataFrame #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) -#' df1 <- as.DataFrame(sqlContext, iris) -#' df2 <- as.DataFrame(sqlContext, list(3,4,5,6)) -#' df3 <- createDataFrame(sqlContext, iris) +#' sparkR.session() +#' df1 <- as.DataFrame(iris) +#' df2 <- as.DataFrame(list(3,4,5,6)) +#' df3 <- createDataFrame(iris) #' } - +#' @name createDataFrame +#' @method createDataFrame default +#' @note createDataFrame since 1.4.0 # TODO(davies): support sampling and infer type from NA -createDataFrame <- function(sqlContext, data, schema = NULL, samplingRatio = 1.0) { +createDataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0) { + sparkSession <- getSparkSession() + if (is.data.frame(data)) { + # Convert data into a list of rows. Each row is a list.
+ # get the names of the columns; they will be put into the RDD if (is.null(schema)) { schema <- names(data) @@ -116,8 +230,9 @@ createDataFrame <- function(sqlContext, data, schema = NULL, samplingRatio = 1.0 args <- list(FUN = list, SIMPLIFY = FALSE, USE.NAMES = FALSE) data <- do.call(mapply, append(args, data)) } + if (is.list(data)) { - sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sqlContext) + sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) rdd <- parallelize(sc, data) } else if (inherits(data, "RDD")) { rdd <- data @@ -126,7 +241,7 @@ createDataFrame <- function(sqlContext, data, schema = NULL, samplingRatio = 1.0 } if (is.null(schema) || (!inherits(schema, "structType") && is.null(names(schema)))) { - row <- first(rdd) + row <- firstRDD(rdd) names <- if (is.null(schema)) { names(row) } else { @@ -160,15 +275,30 @@ createDataFrame <- function(sqlContext, data, schema = NULL, samplingRatio = 1.0 jrdd <- getJRDD(lapply(rdd, function(x) x), "row") srdd <- callJMethod(jrdd, "rdd") sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "createDF", - srdd, schema$jobj, sqlContext) + srdd, schema$jobj, sparkSession) dataFrame(sdf) } +createDataFrame <- function(x, ...) { + dispatchFunc("createDataFrame(data, schema = NULL)", x, ...) +} + +#' @param samplingRatio Currently not used. #' @rdname createDataFrame #' @aliases createDataFrame #' @export -as.DataFrame <- function(sqlContext, data, schema = NULL, samplingRatio = 1.0) { - createDataFrame(sqlContext, data, schema, samplingRatio) +#' @method as.DataFrame default +#' @note as.DataFrame since 1.6.0 +as.DataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0) { + createDataFrame(data, schema) +} + +#' @param ... additional argument(s). +#' @rdname createDataFrame +#' @aliases as.DataFrame +#' @export +as.DataFrame <- function(data, ...) { + dispatchFunc("as.DataFrame(data, schema = NULL)", data, ...) } #' toDF #' @@ -181,8 +311,7 @@ as.DataFrame <- function(sqlContext, data, schema = NULL, samplingRatio = 1.0) { #' @noRd #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) +#' sparkR.session() #' rdd <- lapply(parallelize(sc, 1:10), function(x) list(a=x, b=as.character(x))) #' df <- toDF(rdd) #'} @@ -190,14 +319,7 @@ setGeneric("toDF", function(x, ...) { standardGeneric("toDF") }) setMethod("toDF", signature(x = "RDD"), function(x, ...) { - sqlContext <- if (exists(".sparkRHivesc", envir = .sparkREnv)) { - get(".sparkRHivesc", envir = .sparkREnv) - } else if (exists(".sparkRSQLsc", envir = .sparkREnv)) { - get(".sparkRSQLsc", envir = .sparkREnv) - } else { - stop("no SQL context available") - } - createDataFrame(sqlContext, x, ...) + createDataFrame(x, ...) }) #' Create a SparkDataFrame from a JSON file. #' #' Loads a JSON file (one object per line), returning the result as a SparkDataFrame. #' It goes through the entire dataset once to determine the schema. #' -#' @param sqlContext SQLContext to use #' @param path Path of file to read. A vector of multiple paths is allowed.
#' @return SparkDataFrame #' @rdname read.json -#' @name read.json #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) +#' sparkR.session() #' path <- "path/to/file.json" -#' df <- read.json(sqlContext, path) -#' df <- jsonFile(sqlContext, path) +#' df <- read.json(path) +#' df <- jsonFile(path) #' } -read.json <- function(sqlContext, path) { +#' @name read.json +#' @method read.json default +#' @note read.json since 1.6.0 +read.json.default <- function(path) { + sparkSession <- getSparkSession() # Allow the user to have a more flexible definition of the text file path paths <- as.list(suppressWarnings(normalizePath(path))) - read <- callJMethod(sqlContext, "read") + read <- callJMethod(sparkSession, "read") sdf <- callJMethod(read, "json", paths) dataFrame(sdf) } +read.json <- function(x, ...) { + dispatchFunc("read.json(path)", x, ...) +} + #' @rdname read.json #' @name jsonFile #' @export -jsonFile <- function(sqlContext, path) { +#' @method jsonFile default +#' @note jsonFile since 1.4.0 +jsonFile.default <- function(path) { .Deprecated("read.json") - read.json(sqlContext, path) + read.json(path) } +jsonFile <- function(x, ...) { + dispatchFunc("jsonFile(path)", x, ...) +} #' JSON RDD #' @@ -248,12 +380,12 @@ jsonFile <- function(sqlContext, path) { #' @noRd #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) +#' sparkR.session() #' rdd <- textFile(sc, "path/to/json") #' df <- jsonRDD(sqlContext, rdd) #'} +# TODO: remove - this method is no longer exported # TODO: support schema jsonRDD <- function(sqlContext, rdd, schema = NULL, samplingRatio = 1.0) { .Deprecated("read.json") @@ -268,88 +400,132 @@ jsonRDD <- function(sqlContext, rdd, schema = NULL, samplingRatio = 1.0) { } } +#' Create a SparkDataFrame from an ORC file. +#' +#' Loads an ORC file, returning the result as a SparkDataFrame. +#' +#' @param path Path of file to read. +#' @return SparkDataFrame +#' @rdname read.orc +#' @export +#' @name read.orc +#' @note read.orc since 2.0.0 +read.orc <- function(path) { + sparkSession <- getSparkSession() + # Allow the user to have a more flexible definition of the ORC file path + path <- suppressWarnings(normalizePath(path)) + read <- callJMethod(sparkSession, "read") + sdf <- callJMethod(read, "orc", path) + dataFrame(sdf) +} + #' Create a SparkDataFrame from a Parquet file. #' #' Loads a Parquet file, returning the result as a SparkDataFrame. #' -#' @param sqlContext SQLContext to use -#' @param path Path of file to read. A vector of multiple paths is allowed. +#' @param path path of file to read. A vector of multiple paths is allowed. #' @return SparkDataFrame #' @rdname read.parquet -#' @name read.parquet #' @export -read.parquet <- function(sqlContext, path) { - # Allow the user to have a more flexible definiton of the text file path +#' @name read.parquet +#' @method read.parquet default +#' @note read.parquet since 1.6.0 +read.parquet.default <- function(path) { + sparkSession <- getSparkSession() + # Allow the user to have a more flexible definition of the Parquet file path paths <- as.list(suppressWarnings(normalizePath(path))) - read <- callJMethod(sqlContext, "read") + read <- callJMethod(sparkSession, "read") sdf <- callJMethod(read, "parquet", paths) dataFrame(sdf) } +read.parquet <- function(x, ...) { + dispatchFunc("read.parquet(...)", x, ...) +} + +#' @param ... argument(s) passed to the method.
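Once the context argument is gone, the reader functions all share the same shape; a usage sketch with placeholder paths:

```R
sparkR.session()

df_json    <- read.json("path/to/file.json")        # one JSON object per line
df_orc     <- read.orc("path/to/file.orc")          # single path only
df_parquet <- read.parquet("path/to/file.parquet")  # a vector of paths is allowed
```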
#' @rdname read.parquet #' @name parquetFile #' @export -# TODO: Implement saveasParquetFile and write examples for both -parquetFile <- function(sqlContext, ...) { +#' @method parquetFile default +#' @note parquetFile since 1.4.0 +parquetFile.default <- function(...) { .Deprecated("read.parquet") - read.parquet(sqlContext, unlist(list(...))) + read.parquet(unlist(list(...))) +} + +parquetFile <- function(x, ...) { + dispatchFunc("parquetFile(...)", x, ...) } #' Create a SparkDataFrame from a text file. #' -#' Loads a text file and returns a SparkDataFrame with a single string column named "value". +#' Loads text files and returns a SparkDataFrame whose schema starts with +#' a string column named "value", followed by the partitioned columns, if +#' there are any. +#' #' Each line in the text file is a new row in the resulting SparkDataFrame. #' -#' @param sqlContext SQLContext to use #' @param path Path of file to read. A vector of multiple paths is allowed. #' @return SparkDataFrame #' @rdname read.text -#' @name read.text #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) +#' sparkR.session() #' path <- "path/to/file.txt" -#' df <- read.text(sqlContext, path) +#' df <- read.text(path) #' } -read.text <- function(sqlContext, path) { +#' @name read.text +#' @method read.text default +#' @note read.text since 1.6.1 +read.text.default <- function(path) { + sparkSession <- getSparkSession() # Allow the user to have a more flexible definition of the text file path paths <- as.list(suppressWarnings(normalizePath(path))) - read <- callJMethod(sqlContext, "read") + read <- callJMethod(sparkSession, "read") sdf <- callJMethod(read, "text", paths) dataFrame(sdf) } +read.text <- function(x, ...) { + dispatchFunc("read.text(path)", x, ...) +} + #' SQL Query #' #' Executes a SQL query using Spark, returning the result as a SparkDataFrame. #' -#' @param sqlContext SQLContext to use #' @param sqlQuery A character vector containing the SQL query #' @return SparkDataFrame +#' @rdname sql #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) +#' sparkR.session() #' path <- "path/to/file.json" -#' df <- read.json(sqlContext, path) -#' registerTempTable(df, "table") -#' new_df <- sql(sqlContext, "SELECT * FROM table") +#' df <- read.json(path) +#' createOrReplaceTempView(df, "table") +#' new_df <- sql("SELECT * FROM table") #' } +#' @name sql +#' @method sql default +#' @note sql since 1.4.0 +sql.default <- function(sqlQuery) { + sparkSession <- getSparkSession() + sdf <- callJMethod(sparkSession, "sql", sqlQuery) + dataFrame(sdf) +} -sql <- function(sqlContext, sqlQuery) { - sdf <- callJMethod(sqlContext, "sql", sqlQuery) - dataFrame(sdf) +sql <- function(x, ...) { + dispatchFunc("sql(sqlQuery)", x, ...) } #' Create a SparkDataFrame from a SparkSQL Table #' #' Returns the specified Table as a SparkDataFrame. The Table must have already been registered -#' in the SQLContext. +#' in the SparkSession. #' -#' @param sqlContext SQLContext to use #' @param tableName The SparkSQL Table to convert to a SparkDataFrame.
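Combining read.text with the reworked sql entry point gives a short end-to-end flow; the path and view name below are placeholders:

```R
sparkR.session()

df <- read.text("path/to/file.txt")  # schema: a single string column "value"
createOrReplaceTempView(df, "lines")
long_lines <- sql("SELECT value FROM lines WHERE length(value) > 80")
head(long_lines)
```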
#' @return SparkDataFrame #' @rdname tableToDF @@ -357,16 +533,16 @@ sql <- function(sqlContext, sqlQuery) { #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) +#' sparkR.session() #' path <- "path/to/file.json" -#' df <- read.json(sqlContext, path) -#' registerTempTable(df, "table") -#' new_df <- tableToDF(sqlContext, "table") +#' df <- read.json(path) +#' createOrReplaceTempView(df, "table") +#' new_df <- tableToDF("table") #' } - -tableToDF <- function(sqlContext, tableName) { - sdf <- callJMethod(sqlContext, "table", tableName) +#' @note tableToDF since 2.0.0 +tableToDF <- function(tableName) { + sparkSession <- getSparkSession() + sdf <- callJMethod(sparkSession, "table", tableName) dataFrame(sdf) } @@ -374,183 +550,263 @@ tableToDF <- function(sqlContext, tableName) { #' #' Returns a SparkDataFrame containing names of tables in the given database. #' -#' @param sqlContext SQLContext to use #' @param databaseName name of the database #' @return a SparkDataFrame +#' @rdname tables #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) -#' tables(sqlContext, "hive") +#' sparkR.session() +#' tables("hive") #' } - -tables <- function(sqlContext, databaseName = NULL) { - jdf <- if (is.null(databaseName)) { - callJMethod(sqlContext, "tables") - } else { - callJMethod(sqlContext, "tables", databaseName) - } +#' @name tables +#' @method tables default +#' @note tables since 1.4.0 +tables.default <- function(databaseName = NULL) { + sparkSession <- getSparkSession() + jdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getTables", sparkSession, databaseName) dataFrame(jdf) } +tables <- function(x, ...) { + dispatchFunc("tables(databaseName = NULL)", x, ...) +} #' Table Names #' #' Returns the names of tables in the given database as an array. #' -#' @param sqlContext SQLContext to use #' @param databaseName name of the database #' @return a list of table names +#' @rdname tableNames #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) -#' tableNames(sqlContext, "hive") +#' sparkR.session() +#' tableNames("hive") #' } - -tableNames <- function(sqlContext, databaseName = NULL) { - if (is.null(databaseName)) { - callJMethod(sqlContext, "tableNames") - } else { - callJMethod(sqlContext, "tableNames", databaseName) - } +#' @name tableNames +#' @method tableNames default +#' @note tableNames since 1.4.0 +tableNames.default <- function(databaseName = NULL) { + sparkSession <- getSparkSession() + callJStatic("org.apache.spark.sql.api.r.SQLUtils", + "getTableNames", + sparkSession, + databaseName) } +tableNames <- function(x, ...) { + dispatchFunc("tableNames(databaseName = NULL)", x, ...) +} #' Cache Table #' #' Caches the specified table in-memory. 
#' -#' @param sqlContext SQLContext to use #' @param tableName The name of the table being cached #' @return SparkDataFrame +#' @rdname cacheTable #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) +#' sparkR.session() #' path <- "path/to/file.json" -#' df <- read.json(sqlContext, path) -#' registerTempTable(df, "table") -#' cacheTable(sqlContext, "table") +#' df <- read.json(path) +#' createOrReplaceTempView(df, "table") +#' cacheTable("table") #' } +#' @name cacheTable +#' @method cacheTable default +#' @note cacheTable since 1.4.0 +cacheTable.default <- function(tableName) { + sparkSession <- getSparkSession() + catalog <- callJMethod(sparkSession, "catalog") + callJMethod(catalog, "cacheTable", tableName) +} -cacheTable <- function(sqlContext, tableName) { - callJMethod(sqlContext, "cacheTable", tableName) +cacheTable <- function(x, ...) { + dispatchFunc("cacheTable(tableName)", x, ...) } #' Uncache Table #' #' Removes the specified table from the in-memory cache. #' -#' @param sqlContext SQLContext to use #' @param tableName The name of the table being uncached #' @return SparkDataFrame +#' @rdname uncacheTable #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) +#' sparkR.session() #' path <- "path/to/file.json" -#' df <- read.json(sqlContext, path) -#' registerTempTable(df, "table") -#' uncacheTable(sqlContext, "table") +#' df <- read.json(path) +#' createOrReplaceTempView(df, "table") +#' uncacheTable("table") #' } +#' @name uncacheTable +#' @method uncacheTable default +#' @note uncacheTable since 1.4.0 +uncacheTable.default <- function(tableName) { + sparkSession <- getSparkSession() + catalog <- callJMethod(sparkSession, "catalog") + callJMethod(catalog, "uncacheTable", tableName) +} -uncacheTable <- function(sqlContext, tableName) { - callJMethod(sqlContext, "uncacheTable", tableName) +uncacheTable <- function(x, ...) { + dispatchFunc("uncacheTable(tableName)", x, ...) } #' Clear Cache #' #' Removes all cached tables from the in-memory cache. #' -#' @param sqlContext SQLContext to use +#' @rdname clearCache +#' @export #' @examples #' \dontrun{ -#' clearCache(sqlContext) +#' clearCache() #' } +#' @name clearCache +#' @method clearCache default +#' @note clearCache since 1.4.0 +clearCache.default <- function() { + sparkSession <- getSparkSession() + catalog <- callJMethod(sparkSession, "catalog") + callJMethod(catalog, "clearCache") +} -clearCache <- function(sqlContext) { - callJMethod(sqlContext, "clearCache") +clearCache <- function() { + dispatchFunc("clearCache()") } -#' Drop Temporary Table +#' (Deprecated) Drop Temporary Table #' #' Drops the temporary table with the given table name in the catalog. #' If the table has been cached/persisted before, it's also unpersisted. #' -#' @param sqlContext SQLContext to use #' @param tableName The name of the SparkSQL table to be dropped. 
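The catalog-backed cache calls compose into a simple lifecycle; a sketch assuming a temporary view named "table" (the path is a placeholder):

```R
sparkR.session()
df <- read.json("path/to/file.json")
createOrReplaceTempView(df, "table")

cacheTable("table")                          # pin the view's data in memory
collect(sql("SELECT count(*) FROM table"))   # the first scan populates the cache
uncacheTable("table")                        # evict just this table
clearCache()                                 # or drop every cached table at once
```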
+#' @seealso \link{dropTempView} +#' @rdname dropTempTable-deprecated +#' @export #' @examples #' \dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) -#' df <- read.df(sqlContext, path, "parquet") -#' registerTempTable(df, "table") -#' dropTempTable(sqlContext, "table") +#' sparkR.session() +#' df <- read.df(path, "parquet") +#' createOrReplaceTempView(df, "table") +#' dropTempTable("table") #' } - -dropTempTable <- function(sqlContext, tableName) { +#' @name dropTempTable +#' @method dropTempTable default +#' @note dropTempTable since 1.4.0 +dropTempTable.default <- function(tableName) { if (class(tableName) != "character") { stop("tableName must be a string.") } - callJMethod(sqlContext, "dropTempTable", tableName) + dropTempView(tableName) +} + +dropTempTable <- function(x, ...) { + .Deprecated("dropTempView") + dispatchFunc("dropTempView(viewName)", x, ...) +} + +#' Drops the temporary view with the given view name in the catalog. +#' +#' Drops the temporary view with the given view name in the catalog. +#' If the view has been cached before, then it will also be uncached. +#' +#' @param viewName the name of the view to be dropped. +#' @rdname dropTempView +#' @name dropTempView +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' df <- read.df(path, "parquet") +#' createOrReplaceTempView(df, "table") +#' dropTempView("table") +#' } +#' @note dropTempView since 2.0.0 +dropTempView <- function(viewName) { + sparkSession <- getSparkSession() + if (class(viewName) != "character") { + stop("viewName must be a string.") + } + catalog <- callJMethod(sparkSession, "catalog") + callJMethod(catalog, "dropTempView", viewName) } #' Load a SparkDataFrame #' #' Returns the dataset in a data source as a SparkDataFrame #' -#' The data source is specified by the `source` and a set of options(...). -#' If `source` is not specified, the default data source configured by -#' "spark.sql.sources.default" will be used. +#' The data source is specified by the \code{source} and a set of options(...). +#' If \code{source} is not specified, the default data source configured by +#' "spark.sql.sources.default" will be used. \cr +#' Similar to R read.csv, when \code{source} is "csv", by default, a value of "NA" will be +#' interpreted as NA. #' -#' @param sqlContext SQLContext to use #' @param path The path of files to load #' @param source The name of external data source #' @param schema The data schema defined in structType +#' @param na.strings Default string value for NA when source is "csv" +#' @param ... additional external data source specific named properties. #' @return SparkDataFrame #' @rdname read.df #' @name read.df #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) -#' df1 <- read.df(sqlContext, "path/to/file.json", source = "json") +#' sparkR.session() +#' df1 <- read.df("path/to/file.json", source = "json") #' schema <- structType(structField("name", "string"), #' structField("info", "map")) -#' df2 <- read.df(sqlContext, mapTypeJsonPath, "json", schema) -#' df3 <- loadDF(sqlContext, "data/test_table", "parquet", mergeSchema = "true") +#' df2 <- read.df(mapTypeJsonPath, "json", schema) +#' df3 <- loadDF("data/test_table", "parquet", mergeSchema = "true") #' } - -read.df <- function(sqlContext, path = NULL, source = NULL, schema = NULL, ...) { +#' @name read.df +#' @method read.df default +#' @note read.df since 1.4.0 +read.df.default <- function(path = NULL, source = NULL, schema = NULL, na.strings = "NA", ...)
{ + sparkSession <- getSparkSession() options <- varargsToEnv(...) if (!is.null(path)) { options[["path"]] <- path } if (is.null(source)) { - sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv) - source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default", - "org.apache.spark.sql.parquet") + source <- getDefaultSqlSource() + } + if (source == "csv" && is.null(options[["nullValue"]])) { + options[["nullValue"]] <- na.strings } if (!is.null(schema)) { stopifnot(class(schema) == "structType") - sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sqlContext, source, + sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sparkSession, source, schema$jobj, options) } else { - sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sqlContext, source, options) + sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", + "loadDF", sparkSession, source, options) } dataFrame(sdf) } +read.df <- function(x, ...) { + dispatchFunc("read.df(path = NULL, source = NULL, schema = NULL, ...)", x, ...) +} + #' @rdname read.df #' @name loadDF -loadDF <- function(sqlContext, path = NULL, source = NULL, schema = NULL, ...) { - read.df(sqlContext, path, source, schema, ...) +#' @method loadDF default +#' @note loadDF since 1.6.0 +loadDF.default <- function(path = NULL, source = NULL, schema = NULL, ...) { + read.df(path, source, schema, ...) +} + +loadDF <- function(x, ...) { + dispatchFunc("loadDF(path = NULL, source = NULL, schema = NULL, ...)", x, ...) } #' Create an external table @@ -558,76 +814,83 @@ loadDF <- function(sqlContext, path = NULL, source = NULL, schema = NULL, ...) { #' Creates an external table based on the dataset in a data source, #' returning a SparkDataFrame associated with the external table. #' -#' The data source is specified by the `source` and a set of options(...). -#' If `source` is not specified, the default data source configured by +#' The data source is specified by the \code{source} and a set of options(...). +#' If \code{source} is not specified, the default data source configured by #' "spark.sql.sources.default" will be used. #' -#' @param sqlContext SQLContext to use -#' @param tableName A name of the table -#' @param path The path of files to load -#' @param source the name of external data source -#' @return SparkDataFrame +#' @param tableName a name of the table. +#' @param path the path of files to load. +#' @param source the name of external data source. +#' @param ... additional argument(s) passed to the method. +#' @return A SparkDataFrame. +#' @rdname createExternalTable #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) -#' df <- sparkRSQL.createExternalTable(sqlContext, "myjson", path="path/to/json", source="json") +#' sparkR.session() +#' df <- createExternalTable("myjson", path="path/to/json", source="json") #' } - -createExternalTable <- function(sqlContext, tableName, path = NULL, source = NULL, ...) { +#' @name createExternalTable +#' @method createExternalTable default +#' @note createExternalTable since 1.4.0 +createExternalTable.default <- function(tableName, path = NULL, source = NULL, ...) { + sparkSession <- getSparkSession() options <- varargsToEnv(...)
if (!is.null(path)) { options[["path"]] <- path } - sdf <- callJMethod(sqlContext, "createExternalTable", tableName, source, options) + catalog <- callJMethod(sparkSession, "catalog") + sdf <- callJMethod(catalog, "createExternalTable", tableName, source, options) dataFrame(sdf) } +createExternalTable <- function(x, ...) { + dispatchFunc("createExternalTable(tableName, path = NULL, source = NULL, ...)", x, ...) +} + #' Create a SparkDataFrame representing the database table accessible via JDBC URL #' #' Additional JDBC database connection properties can be set (...) #' #' Only one of partitionColumn or predicates should be set. Partitions of the table will be -#' retrieved in parallel based on the `numPartitions` or by the predicates. +#' retrieved in parallel based on the \code{numPartitions} or by the predicates. #' #' Don't create too many partitions in parallel on a large cluster; otherwise Spark might crash #' your external database systems. #' -#' @param sqlContext SQLContext to use -#' @param url JDBC database url of the form `jdbc:subprotocol:subname` +#' @param url JDBC database url of the form \code{jdbc:subprotocol:subname} #' @param tableName the name of the table in the external database #' @param partitionColumn the name of a column of integral type that will be used for partitioning -#' @param lowerBound the minimum value of `partitionColumn` used to decide partition stride -#' @param upperBound the maximum value of `partitionColumn` used to decide partition stride -#' @param numPartitions the number of partitions, This, along with `lowerBound` (inclusive), -#' `upperBound` (exclusive), form partition strides for generated WHERE -#' clause expressions used to split the column `partitionColumn` evenly. +#' @param lowerBound the minimum value of \code{partitionColumn} used to decide partition stride +#' @param upperBound the maximum value of \code{partitionColumn} used to decide partition stride +#' @param numPartitions the number of partitions. This, along with \code{lowerBound} (inclusive), +#' \code{upperBound} (exclusive), form partition strides for generated WHERE +#' clause expressions used to split the column \code{partitionColumn} evenly. #' This defaults to SparkContext.defaultParallelism when unset. #' @param predicates a list of conditions in the where clause; each one defines one partition +#' @param ... additional JDBC database connection named properties. #' @return SparkDataFrame #' @rdname read.jdbc #' @name read.jdbc #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) +#' sparkR.session() #' jdbcUrl <- "jdbc:mysql://localhost:3306/databasename" -#' df <- read.jdbc(sqlContext, jdbcUrl, "table", predicates = list("field<=123"), user = "username") -#' df2 <- read.jdbc(sqlContext, jdbcUrl, "table2", partitionColumn = "index", lowerBound = 0, +#' df <- read.jdbc(jdbcUrl, "table", predicates = list("field<=123"), user = "username") +#' df2 <- read.jdbc(jdbcUrl, "table2", partitionColumn = "index", lowerBound = 0, #' upperBound = 10000, user = "username", password = "password") #' } - -read.jdbc <- function(sqlContext, url, tableName, +#' @note read.jdbc since 2.0.0 +read.jdbc <- function(url, tableName, partitionColumn = NULL, lowerBound = NULL, upperBound = NULL, numPartitions = 0L, predicates = list(), ...) { jprops <- varargsToJProperties(...)
- - read <- callJMethod(sqlContext, "read") + sparkSession <- getSparkSession() + read <- callJMethod(sparkSession, "read") if (!is.null(partitionColumn)) { if (is.null(numPartitions) || numPartitions == 0) { - sc <- callJMethod(sqlContext, "sparkContext") + sc <- callJMethod(sparkSession, "sparkContext") numPartitions <- callJMethod(sc, "defaultParallelism") } else { numPartitions <- numToInt(numPartitions) diff --git a/R/pkg/R/WindowSpec.R b/R/pkg/R/WindowSpec.R new file mode 100644 index 0000000000000..4ac83c29c6f7e --- /dev/null +++ b/R/pkg/R/WindowSpec.R @@ -0,0 +1,223 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# WindowSpec.R - WindowSpec class and methods implemented in S4 OO classes + +#' @include generics.R jobj.R column.R +NULL + +#' S4 class that represents a WindowSpec +#' +#' WindowSpec can be created by using windowPartitionBy() or windowOrderBy() +#' +#' @rdname WindowSpec +#' @seealso \link{windowPartitionBy}, \link{windowOrderBy} +#' +#' @param sws A Java object reference to the backing Scala WindowSpec +#' @export +#' @note WindowSpec since 2.0.0 +setClass("WindowSpec", + slots = list(sws = "jobj")) + +setMethod("initialize", "WindowSpec", function(.Object, sws) { + .Object@sws <- sws + .Object +}) + +windowSpec <- function(sws) { + stopifnot(class(sws) == "jobj") + new("WindowSpec", sws) +} + +#' @rdname show +#' @export +#' @note show(WindowSpec) since 2.0.0 +setMethod("show", "WindowSpec", + function(object) { + cat("WindowSpec", callJMethod(object@sws, "toString"), "\n") + }) + +#' partitionBy +#' +#' Defines the partitioning columns in a WindowSpec. +#' +#' @param x a WindowSpec. +#' @param col a column to partition on (described by the name or Column). +#' @param ... additional column(s) to partition on. +#' @return A WindowSpec. +#' @rdname partitionBy +#' @name partitionBy +#' @aliases partitionBy,WindowSpec-method +#' @family windowspec_method +#' @export +#' @examples +#' \dontrun{ +#' partitionBy(ws, "col1", "col2") +#' partitionBy(ws, df$col1, df$col2) +#' } +#' @note partitionBy(WindowSpec) since 2.0.0 +setMethod("partitionBy", + signature(x = "WindowSpec"), + function(x, col, ...) { + stopifnot(class(col) %in% c("character", "Column")) + + if (class(col) == "character") { + windowSpec(callJMethod(x@sws, "partitionBy", col, list(...))) + } else { + jcols <- lapply(list(col, ...), function(c) { + c@jc + }) + windowSpec(callJMethod(x@sws, "partitionBy", jcols)) + } + }) + +#' Ordering Columns in a WindowSpec +#' +#' Defines the ordering columns in a WindowSpec. +#' @param x a WindowSpec +#' @param col a character or Column indicating an ordering column +#' @param ... additional sorting fields +#' @return A WindowSpec.
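A WindowSpec is built up incrementally from these methods; a hedged sketch using the windowPartitionBy constructor referenced in the @seealso tags above:

```R
df <- createDataFrame(mtcars)

# Partition by transmission, order by horsepower, and keep a trailing 3-row frame
ws <- windowPartitionBy("am")
ws <- orderBy(ws, "hp")
ws <- rowsBetween(ws, -2, 0)  # the current row plus the two rows before it
```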
+#' @name orderBy +#' @rdname orderBy +#' @aliases orderBy,WindowSpec,character-method +#' @family windowspec_method +#' @seealso See \link{arrange} for use in sorting a SparkDataFrame +#' @export +#' @examples +#' \dontrun{ +#' orderBy(ws, "col1", "col2") +#' orderBy(ws, df$col1, df$col2) +#' } +#' @note orderBy(WindowSpec, character) since 2.0.0 +setMethod("orderBy", + signature(x = "WindowSpec", col = "character"), + function(x, col, ...) { + windowSpec(callJMethod(x@sws, "orderBy", col, list(...))) + }) + +#' @rdname orderBy +#' @name orderBy +#' @aliases orderBy,WindowSpec,Column-method +#' @export +#' @note orderBy(WindowSpec, Column) since 2.0.0 +setMethod("orderBy", + signature(x = "WindowSpec", col = "Column"), + function(x, col, ...) { + jcols <- lapply(list(col, ...), function(c) { + c@jc + }) + windowSpec(callJMethod(x@sws, "orderBy", jcols)) + }) + +#' rowsBetween +#' +#' Defines the frame boundaries, from \code{start} (inclusive) to \code{end} (inclusive). +#' +#' Both \code{start} and \code{end} are relative positions from the current row. For example, +#' "0" means "current row", while "-1" means the row before the current row, and "5" means the +#' fifth row after the current row. +#' +#' @param x a WindowSpec +#' @param start boundary start, inclusive. +#' The frame is unbounded if this is the minimum long value. +#' @param end boundary end, inclusive. +#' The frame is unbounded if this is the maximum long value. +#' @return a WindowSpec +#' @rdname rowsBetween +#' @aliases rowsBetween,WindowSpec,numeric,numeric-method +#' @name rowsBetween +#' @family windowspec_method +#' @export +#' @examples +#' \dontrun{ +#' rowsBetween(ws, 0, 3) +#' } +#' @note rowsBetween since 2.0.0 +setMethod("rowsBetween", + signature(x = "WindowSpec", start = "numeric", end = "numeric"), + function(x, start, end) { + # "start" and "end" should be long, due to serde limitation, + # limit "start" and "end" as integer now + windowSpec(callJMethod(x@sws, "rowsBetween", as.integer(start), as.integer(end))) + }) + +#' rangeBetween +#' +#' Defines the frame boundaries, from \code{start} (inclusive) to \code{end} (inclusive). +#' +#' Both \code{start} and \code{end} are relative to the current row. For example, "0" means +#' "current row", while "-1" means one off before the current row, and "5" means five off +#' after the current row. +#' +#' @param x a WindowSpec +#' @param start boundary start, inclusive. +#' The frame is unbounded if this is the minimum long value. +#' @param end boundary end, inclusive. +#' The frame is unbounded if this is the maximum long value. +#' @return a WindowSpec +#' @rdname rangeBetween +#' @aliases rangeBetween,WindowSpec,numeric,numeric-method +#' @name rangeBetween +#' @family windowspec_method +#' @export +#' @examples +#' \dontrun{ +#' rangeBetween(ws, 0, 3) +#' } +#' @note rangeBetween since 2.0.0 +setMethod("rangeBetween", + signature(x = "WindowSpec", start = "numeric", end = "numeric"), + function(x, start, end) { + # "start" and "end" should be long, due to serde limitation, + # limit "start" and "end" as integer now + windowSpec(callJMethod(x@sws, "rangeBetween", as.integer(start), as.integer(end))) + }) + +# Note that over is a method of Column class, but it is placed here to +# avoid Roxygen circular-dependency between class Column and WindowSpec. + +#' over +#' +#' Define a windowing column. +#' +#' @param x a Column, usually one returned by window function(s). +#' @param window a WindowSpec object.
Can be created by \code{windowPartitionBy} or +#' \code{windowOrderBy} and configured by other WindowSpec methods. +#' @rdname over +#' @name over +#' @aliases over,Column,WindowSpec-method +#' @family colum_func +#' @export +#' @examples \dontrun{ +#' df <- createDataFrame(mtcars) +#' +#' # Partition by am (transmission) and order by hp (horsepower) +#' ws <- orderBy(windowPartitionBy("am"), "hp") +#' +#' # Rank on hp within each partition +#' out <- select(df, over(rank(), ws), df$hp, df$am) +#' +#' # Lead mpg values by 1 row on the partition-and-ordered table +#' out <- select(df, over(lead(df$mpg), ws), df$mpg, df$hp, df$am) +#' } +#' @note over since 2.0.0 +setMethod("over", + signature(x = "Column", window = "WindowSpec"), + function(x, window) { + column(callJMethod(x@jc, "over", window@sws)) + }) diff --git a/R/pkg/R/backend.R b/R/pkg/R/backend.R index 6c81492f8b675..03e70bb2cb82e 100644 --- a/R/pkg/R/backend.R +++ b/R/pkg/R/backend.R @@ -68,7 +68,7 @@ isRemoveMethod <- function(isStatic, objId, methodName) { # methodName - name of method to be invoked invokeJava <- function(isStatic, objId, methodName, ...) { if (!exists(".sparkRCon", .sparkREnv)) { - stop("No connection to backend found. Please re-run sparkR.init") + stop("No connection to backend found. Please re-run sparkR.session()") } # If this isn't a removeJObject call diff --git a/R/pkg/R/broadcast.R b/R/pkg/R/broadcast.R index 38f0eed95e065..398dffc4ab1b4 100644 --- a/R/pkg/R/broadcast.R +++ b/R/pkg/R/broadcast.R @@ -23,9 +23,11 @@ .broadcastValues <- new.env() .broadcastIdToName <- new.env() -# @title S4 class that represents a Broadcast variable -# @description Broadcast variables can be created using the broadcast -# function from a \code{SparkContext}. +# S4 class that represents a Broadcast variable +# +# Broadcast variables can be created using the broadcast +# function from a \code{SparkContext}.
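Putting a configured WindowSpec to work then goes through over(), as in the examples above; a compact sketch:

```R
df <- createDataFrame(mtcars)
ws <- orderBy(windowPartitionBy("am"), "hp")

# Rank cars by horsepower within each transmission group
ranked <- select(df, over(rank(), ws), df$hp, df$am)
head(ranked)
```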
+# # @rdname broadcast-class # @seealso broadcast # diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R index 25e99390a9c89..2d341d836c133 100644 --- a/R/pkg/R/client.R +++ b/R/pkg/R/client.R @@ -38,7 +38,7 @@ determineSparkSubmitBin <- function() { if (.Platform$OS.type == "unix") { sparkSubmitBinName <- "spark-submit" } else { - sparkSubmitBinName <- "spark-submit.cmd" + sparkSubmitBinName <- "spark-submit2.cmd" } sparkSubmitBinName } @@ -69,5 +69,5 @@ launchBackend <- function(args, sparkHome, jars, sparkSubmitOpts, packages) { } combinedArgs <- generateSparkSubmitArgs(args, sparkHome, jars, sparkSubmitOpts, packages) cat("Launching java with spark-submit command", sparkSubmitBin, combinedArgs, "\n") - invisible(system2(sparkSubmitBin, combinedArgs, wait = F)) + invisible(launchScript(sparkSubmitBin, combinedArgs)) } diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index a3e09372bb5e9..539d91b0f8797 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -22,20 +22,31 @@ NULL setOldClass("jobj") -#' @title S4 class that represents a SparkDataFrame column -#' @description The column class supports unary, binary operations on SparkDataFrame columns +#' S4 class that represents a SparkDataFrame column +#' +#' The column class supports unary, binary operations on SparkDataFrame columns +#' #' @rdname column #' #' @slot jc reference to JVM SparkDataFrame column #' @export +#' @note Column since 1.4.0 setClass("Column", slots = list(jc = "jobj")) +#' A set of operations working with SparkDataFrame columns +#' @rdname columnfunctions +#' @name columnfunctions +NULL + setMethod("initialize", "Column", function(.Object, jc) { .Object@jc <- jc .Object }) +#' @rdname column +#' @name column +#' @aliases column,jobj-method setMethod("column", signature(x = "jobj"), function(x) { @@ -44,6 +55,9 @@ setMethod("column", #' @rdname show #' @name show +#' @aliases show,Column-method +#' @export +#' @note show(Column) since 1.4.0 setMethod("show", "Column", function(object) { cat("Column", callJMethod(object@jc, "toString"), "\n") @@ -57,7 +71,7 @@ operators <- list( "^" = "pow" ) column_functions1 <- c("asc", "desc", "isNaN", "isNull", "isNotNull") -column_functions2 <- c("like", "rlike", "startsWith", "endsWith", "getField", "getItem", "contains") +column_functions2 <- c("like", "rlike", "getField", "getItem", "contains") createOperator <- function(op) { setMethod(op, @@ -121,10 +135,15 @@ createMethods() #' #' Set a new name for a column #' +#' @param object Column to rename +#' @param data new name to use +#' #' @rdname alias #' @name alias +#' @aliases alias,Column-method #' @family colum_func #' @export +#' @note alias since 1.4.0 setMethod("alias", signature(object = "Column"), function(object, data) { @@ -142,15 +161,56 @@ setMethod("alias", #' @rdname substr #' @name substr #' @family colum_func +#' @aliases substr,Column-method #' -#' @param start starting position -#' @param stop ending position +#' @param x a Column. +#' @param start starting position. +#' @param stop ending position. +#' @note substr since 1.4.0 setMethod("substr", signature(x = "Column"), function(x, start, stop) { jc <- callJMethod(x@jc, "substr", as.integer(start - 1), as.integer(stop - start + 1)) column(jc) }) +#' startsWith +#' +#' Determines if entries of x start with string (entries of) prefix respectively, +#' where strings are recycled to common lengths. 
+#' +#' @rdname startsWith +#' @name startsWith +#' @family colum_func +#' @aliases startsWith,Column-method +#' +#' @param x vector of character string whose "starts" are considered +#' @param prefix character vector (often of length one) +#' @note startsWith since 1.4.0 +setMethod("startsWith", signature(x = "Column"), + function(x, prefix) { + jc <- callJMethod(x@jc, "startsWith", as.vector(prefix)) + column(jc) + }) + +#' endsWith +#' +#' Determines if entries of x end with string (entries of) suffix respectively, +#' where strings are recycled to common lengths. +#' +#' @rdname endsWith +#' @name endsWith +#' @family colum_func +#' @aliases endsWith,Column-method +#' +#' @param x vector of character string whose "ends" are considered +#' @param suffix character vector (often of length one) +#' @note endsWith since 1.4.0 +setMethod("endsWith", signature(x = "Column"), + function(x, suffix) { + jc <- callJMethod(x@jc, "endsWith", as.vector(suffix)) + column(jc) + }) + #' between #' #' Test if the column is between the lower bound and upper bound, inclusive. @@ -158,8 +218,11 @@ setMethod("substr", signature(x = "Column"), #' @rdname between #' @name between #' @family colum_func +#' @aliases between,Column-method #' +#' @param x a Column #' @param bounds lower and upper bounds +#' @note between since 1.5.0 setMethod("between", signature(x = "Column"), function(x, bounds) { if (is.vector(bounds) && length(bounds) == 2) { @@ -172,40 +235,45 @@ setMethod("between", signature(x = "Column"), #' Casts the column to a different data type. #' +#' @param x a Column. +#' @param dataType a character object describing the target data type. +#' See +#' \href{https://spark.apache.org/docs/latest/sparkr.html#data-type-mapping-between-r-and-spark}{ +#' Spark Data Types} for available data types. #' @rdname cast #' @name cast #' @family colum_func +#' @aliases cast,Column-method #' #' @examples \dontrun{ #' cast(df$age, "string") -#' cast(df$name, list(type="array", elementType="byte", containsNull = TRUE)) #' } +#' @note cast since 1.4.0 setMethod("cast", signature(x = "Column"), function(x, dataType) { if (is.character(dataType)) { column(callJMethod(x@jc, "cast", dataType)) - } else if (is.list(dataType)) { - json <- tojson(dataType) - jdataType <- callJStatic("org.apache.spark.sql.types.DataType", "fromJson", json) - column(callJMethod(x@jc, "cast", jdataType)) } else { - stop("dataType should be character or list") + stop("dataType should be character") } }) #' Match a column with given values. #' +#' @param x a Column. +#' @param table a collection of values (coercible to list) to compare with. #' @rdname match #' @name %in% -#' @aliases %in% -#' @return a matched values as a result of comparing with given values. +#' @aliases %in%,Column-method +#' @return The matched values as a result of comparing with the given values. #' @export #' @examples #' \dontrun{ #' filter(df, "age in (10, 30)") #' where(df, df$age %in% c(10, 30)) #' } +#' @note \%in\% since 1.5.0 setMethod("%in%", signature(x = "Column"), function(x, table) { @@ -216,12 +284,17 @@ setMethod("%in%", #' otherwise #' #' If values in the specified column are null, returns the value. -#' Can be used in conjunction with `when` to specify a default value for expressions. +#' Can be used in conjunction with \code{when} to specify a default value for expressions. #' +#' @param x a Column. +#' @param value value to replace when the corresponding entry in \code{x} is NA. +#' Can be a single value or a Column.
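The column helpers read naturally inside filters and projections; a sketch over an assumed SparkDataFrame df with a string column name and a numeric column age:

```R
a_names <- filter(df, startsWith(df$name, "A"))
adults  <- filter(df, between(df$age, c(18, 65)))
picked  <- filter(df, df$age %in% c(10, 30))
age_str <- select(df, cast(df$age, "string"))  # cast now accepts only a type name
```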
#' @rdname otherwise #' @name otherwise #' @family colum_func +#' @aliases otherwise,Column-method #' @export +#' @note otherwise since 1.5.0 setMethod("otherwise", signature(x = "Column", value = "ANY"), function(x, value) { diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R index 44bca877fd45a..cec108cd86486 100644 --- a/R/pkg/R/context.R +++ b/R/pkg/R/context.R @@ -87,6 +87,10 @@ objectFile <- function(sc, path, minPartitions = NULL) { #' in the list are split into \code{numSlices} slices and distributed to nodes #' in the cluster. #' +#' If size of serialized slices is larger than spark.r.maxAllocationLimit or (200MB), the function +#' will write it to disk and send the file name to JVM. Also to make sure each slice is not +#' larger than that limit, number of slices may be increased. +#' #' @param sc SparkContext to use #' @param coll collection to parallelize #' @param numSlices number of partitions to create in the RDD @@ -120,6 +124,11 @@ parallelize <- function(sc, coll, numSlices = 1) { coll <- as.list(coll) } + sizeLimit <- getMaxAllocationLimit(sc) + objectSize <- object.size(coll) + + # For large objects we make sure the size of each slice is also smaller than sizeLimit + numSlices <- max(numSlices, ceiling(objectSize / sizeLimit)) if (numSlices > length(coll)) numSlices <- length(coll) @@ -130,12 +139,44 @@ parallelize <- function(sc, coll, numSlices = 1) { # 2-tuples of raws serializedSlices <- lapply(slices, serialize, connection = NULL) - jrdd <- callJStatic("org.apache.spark.api.r.RRDD", - "createRDDFromArray", sc, serializedSlices) + # The PRC backend cannot handle arguments larger than 2GB (INT_MAX) + # If serialized data is safely less than that threshold we send it over the PRC channel. + # Otherwise, we write it to a file and send the file name + if (objectSize < sizeLimit) { + jrdd <- callJStatic("org.apache.spark.api.r.RRDD", "createRDDFromArray", sc, serializedSlices) + } else { + fileName <- writeToTempFile(serializedSlices) + jrdd <- tryCatch(callJStatic( + "org.apache.spark.api.r.RRDD", "createRDDFromFile", sc, fileName, as.integer(numSlices)), + finally = { + file.remove(fileName) + }) + } RDD(jrdd, "byte") } +getMaxAllocationLimit <- function(sc) { + conf <- callJMethod(sc, "getConf") + as.numeric( + callJMethod(conf, + "get", + "spark.r.maxAllocationLimit", + toString(.Machine$integer.max / 10) # Default to a safe value: 200MB + )) +} + +writeToTempFile <- function(serializedSlices) { + fileName <- tempfile() + conn <- file(fileName, "wb") + for (slice in serializedSlices) { + writeBin(as.integer(length(slice)), conn, endian = "big") + writeBin(slice, conn, endian = "big") + } + close(conn) + fileName +} + #' Include this specified package on all workers #' #' This function can be used to include a package on all workers before the @@ -173,9 +214,8 @@ includePackage <- function(sc, pkg) { .sparkREnv$.packages <- packages } -#' @title Broadcast a variable to all workers +#' Broadcast a variable to all workers #' -#' @description #' Broadcast a read-only variable to the cluster, returning a \code{Broadcast} #' object for reading it in distributed functions. #' @@ -207,7 +247,7 @@ broadcast <- function(sc, object) { Broadcast(id, object, jBroadcast, objName) } -#' @title Set the checkpoint directory +#' Set the checkpoint directory #' #' Set the directory under which RDDs are going to be checkpointed. The #' directory must be a HDFS path if running on a cluster. 
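The size-limit logic that `parallelize` now applies can be sketched without Spark at all; the numbers below only mirror the defaults quoted above and are not a definitive implementation.

```R
# Spark-free sketch of the slice-count adjustment described above: each serialized
# slice must stay under spark.r.maxAllocationLimit (default ~200MB), so the number
# of slices grows with the size of the collection.
coll <- as.list(runif(1e6))
sizeLimit <- .Machine$integer.max / 10              # SparkR's default limit (~200MB)
objectSize <- as.numeric(object.size(coll))

numSlices <- max(1, ceiling(objectSize / sizeLimit))
numSlices <- min(numSlices, length(coll))           # never more slices than elements
numSlices
```

When the total serialized size still exceeds the limit, the slices are written to a temporary file (length-prefixed, big-endian) and only the file name crosses the RPC channel.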
@@ -226,45 +266,49 @@ setCheckpointDir <- function(sc, dirName) { invisible(callJMethod(sc, "setCheckpointDir", suppressWarnings(normalizePath(dirName)))) } -#' @title Run a function over a list of elements, distributing the computations with Spark. +#' Run a function over a list of elements, distributing the computations with Spark #' -#' @description -#' Applies a function in a manner that is similar to doParallel or lapply to elements of a list. +#' Run a function over a list of elements, distributing the computations with Spark. Applies a +#' function in a manner that is similar to doParallel or lapply to elements of a list. #' The computations are distributed using Spark. It is conceptually the same as the following code: #' lapply(list, func) #' #' Known limitations: -#' - variable scoping and capture: compared to R's rich support for variable resolutions, the -# distributed nature of SparkR limits how variables are resolved at runtime. All the variables -# that are available through lexical scoping are embedded in the closure of the function and -# available as read-only variables within the function. The environment variables should be -# stored into temporary variables outside the function, and not directly accessed within the -# function. +#' \itemize{ +#' \item variable scoping and capture: compared to R's rich support for variable resolutions, +#' the distributed nature of SparkR limits how variables are resolved at runtime. All the +#' variables that are available through lexical scoping are embedded in the closure of the +#' function and available as read-only variables within the function. The environment variables +#' should be stored into temporary variables outside the function, and not directly accessed +#' within the function. #' -#' - loading external packages: In order to use a package, you need to load it inside the -#' closure. For example, if you rely on the MASS module, here is how you would use it: -#'\dontrun{ -#' train <- function(hyperparam) { -#' library(MASS) -#' lm.ridge(“y ~ x+z”, data, lambda=hyperparam) -#' model +#' \item loading external packages: In order to use a package, you need to load it inside the +#' closure. For example, if you rely on the MASS module, here is how you would use it: +#' \preformatted{ +#' train <- function(hyperparam) { +#' library(MASS) +#' lm.ridge("y ~ x+z", data, lambda=hyperparam) +#' model +#' } +#' } #' } -#'} #' #' @rdname spark.lapply -#' @param sc Spark Context to use #' @param list the list of elements #' @param func a function that takes one argument. 
#' @return a list of results (the exact type being determined by the function) #' @export #' @examples #'\dontrun{ +#' sparkR.session() #' doubled <- spark.lapply(1:10, function(x){2 * x}) #'} -spark.lapply <- function(sc, list, func) { +#' @note spark.lapply since 2.0.0 +spark.lapply <- function(list, func) { + sc <- getSparkContext() rdd <- parallelize(sc, list, length(list)) results <- map(rdd, func) - local <- collect(results) + local <- collectRDD(results) local } @@ -273,14 +317,14 @@ spark.lapply <- function(sc, list, func) { #' Set new log level: "ALL", "DEBUG", "ERROR", "FATAL", "INFO", "OFF", "TRACE", "WARN" #' #' @rdname setLogLevel -#' @param sc Spark Context to use #' @param level New log level #' @export #' @examples #'\dontrun{ -#' setLogLevel(sc, "ERROR") +#' setLogLevel("ERROR") #'} - -setLogLevel <- function(sc, level) { +#' @note setLogLevel since 2.0.0 +setLogLevel <- function(level) { + sc <- getSparkContext() callJMethod(sc, "setLogLevel", level) } diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R index ce071b1a848bb..0e99b171cabeb 100644 --- a/R/pkg/R/deserialize.R +++ b/R/pkg/R/deserialize.R @@ -197,6 +197,36 @@ readMultipleObjects <- function(inputCon) { data # this is a list of named lists now } +readMultipleObjectsWithKeys <- function(inputCon) { + # readMultipleObjectsWithKeys will read multiple continuous objects from + # a DataOutputStream. There is no preceding field telling the count + # of the objects, so the number of objects varies, we try to read + # all objects in a loop until the end of the stream. This function + # is for use by gapply. Each group of rows is followed by the grouping + # key for this group which is then followed by next group. + keys <- list() + data <- list() + subData <- list() + while (TRUE) { + # If reaching the end of the stream, type returned should be "". + type <- readType(inputCon) + if (type == "") { + break + } else if (type == "r") { + type <- readType(inputCon) + # A grouping boundary detected + key <- readTypedObject(inputCon, type) + index <- length(data) + 1L + data[[index]] <- subData + keys[[index]] <- key + subData <- list() + } else { + subData[[length(subData) + 1L]] <- readTypedObject(inputCon, type) + } + } + list(keys = keys, data = data) # this is a list of keys and corresponding data +} + readRowList <- function(obj) { # readRowList is meant for use inside an lapply. As a result, it is # necessary to open a standalone connection for the row and consume diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 4a0bdf331523e..4d94b4cd05d44 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -23,16 +23,19 @@ NULL #' A new \linkS4class{Column} is created to represent the literal value. #' If the parameter is a \linkS4class{Column}, it is returned unchanged. #' +#' @param x a literal value or a Column. #' @family normal_funcs #' @rdname lit #' @name lit #' @export +#' @aliases lit,ANY-method #' @examples #' \dontrun{ #' lit(df$name) #' select(df, lit("x")) #' select(df, lit("2015-01-01")) #'} +#' @note lit since 1.5.0 setMethod("lit", signature("ANY"), function(x) { jc <- callJStatic("org.apache.spark.sql.functions", @@ -45,11 +48,15 @@ setMethod("lit", signature("ANY"), #' #' Computes the absolute value. #' +#' @param x Column to compute on. 
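A hedged sketch of the updated, session-based API shown above: `spark.lapply` and `setLogLevel` no longer take a SparkContext, and values needed inside the closure are copied into local variables first, per the scoping limitation described earlier. The `scaleFactor` variable is purely illustrative.

```R
# Hedged sketch of the session-based API; assumes library(SparkR) is attached.
sparkR.session()
setLogLevel("WARN")

scaleFactor <- 3                      # captured read-only inside the closure
scaled <- spark.lapply(1:5, function(x) {
  x * scaleFactor                     # packages, if needed, would be library()'d here
})
unlist(scaled)                        # 3 6 9 12 15
```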
+#' #' @rdname abs #' @name abs #' @family normal_funcs #' @export #' @examples \dontrun{abs(df$c)} +#' @aliases abs,Column-method +#' @note abs since 1.5.0 setMethod("abs", signature(x = "Column"), function(x) { @@ -62,11 +69,15 @@ setMethod("abs", #' Computes the cosine inverse of the given value; the returned angle is in the range #' 0.0 through pi. #' +#' @param x Column to compute on. +#' #' @rdname acos #' @name acos #' @family math_funcs #' @export #' @examples \dontrun{acos(df$c)} +#' @aliases acos,Column-method +#' @note acos since 1.5.0 setMethod("acos", signature(x = "Column"), function(x) { @@ -74,15 +85,18 @@ setMethod("acos", column(jc) }) -#' approxCountDistinct +#' Returns the approximate number of distinct items in a group #' -#' Aggregate function: returns the approximate number of distinct items in a group. +#' Returns the approximate number of distinct items in a group. This is a column +#' aggregate function. #' #' @rdname approxCountDistinct #' @name approxCountDistinct -#' @family agg_funcs +#' @return the approximate number of distinct items in a group. #' @export +#' @aliases approxCountDistinct,Column-method #' @examples \dontrun{approxCountDistinct(df$c)} +#' @note approxCountDistinct(Column) since 1.4.0 setMethod("approxCountDistinct", signature(x = "Column"), function(x) { @@ -95,11 +109,15 @@ setMethod("approxCountDistinct", #' Computes the numeric value of the first character of the string column, and returns the #' result as a int column. #' +#' @param x Column to compute on. +#' #' @rdname ascii #' @name ascii #' @family string_funcs #' @export +#' @aliases ascii,Column-method #' @examples \dontrun{\dontrun{ascii(df$c)}} +#' @note ascii since 1.5.0 setMethod("ascii", signature(x = "Column"), function(x) { @@ -112,11 +130,15 @@ setMethod("ascii", #' Computes the sine inverse of the given value; the returned angle is in the range #' -pi/2 through pi/2. #' +#' @param x Column to compute on. +#' #' @rdname asin #' @name asin #' @family math_funcs #' @export +#' @aliases asin,Column-method #' @examples \dontrun{asin(df$c)} +#' @note asin since 1.5.0 setMethod("asin", signature(x = "Column"), function(x) { @@ -128,11 +150,15 @@ setMethod("asin", #' #' Computes the tangent inverse of the given value. #' +#' @param x Column to compute on. +#' #' @rdname atan #' @name atan #' @family math_funcs #' @export +#' @aliases atan,Column-method #' @examples \dontrun{atan(df$c)} +#' @note atan since 1.5.0 setMethod("atan", signature(x = "Column"), function(x) { @@ -148,7 +174,9 @@ setMethod("atan", #' @name avg #' @family agg_funcs #' @export +#' @aliases avg,Column-method #' @examples \dontrun{avg(df$c)} +#' @note avg since 1.4.0 setMethod("avg", signature(x = "Column"), function(x) { @@ -161,11 +189,15 @@ setMethod("avg", #' Computes the BASE64 encoding of a binary column and returns it as a string column. #' This is the reverse of unbase64. #' +#' @param x Column to compute on. +#' #' @rdname base64 #' @name base64 #' @family string_funcs #' @export +#' @aliases base64,Column-method #' @examples \dontrun{base64(df$c)} +#' @note base64 since 1.5.0 setMethod("base64", signature(x = "Column"), function(x) { @@ -178,11 +210,15 @@ setMethod("base64", #' An expression that returns the string representation of the binary value of the given long #' column. For example, bin("12") returns "1100". #' +#' @param x Column to compute on. 
+#' #' @rdname bin #' @name bin #' @family math_funcs #' @export +#' @aliases bin,Column-method #' @examples \dontrun{bin(df$c)} +#' @note bin since 1.5.0 setMethod("bin", signature(x = "Column"), function(x) { @@ -194,11 +230,15 @@ setMethod("bin", #' #' Computes bitwise NOT. #' +#' @param x Column to compute on. +#' #' @rdname bitwiseNOT #' @name bitwiseNOT #' @family normal_funcs #' @export +#' @aliases bitwiseNOT,Column-method #' @examples \dontrun{bitwiseNOT(df$c)} +#' @note bitwiseNOT since 1.5.0 setMethod("bitwiseNOT", signature(x = "Column"), function(x) { @@ -210,11 +250,15 @@ setMethod("bitwiseNOT", #' #' Computes the cube-root of the given value. #' +#' @param x Column to compute on. +#' #' @rdname cbrt #' @name cbrt #' @family math_funcs #' @export +#' @aliases cbrt,Column-method #' @examples \dontrun{cbrt(df$c)} +#' @note cbrt since 1.4.0 setMethod("cbrt", signature(x = "Column"), function(x) { @@ -222,15 +266,19 @@ setMethod("cbrt", column(jc) }) -#' ceil +#' Computes the ceiling of the given value #' #' Computes the ceiling of the given value. #' +#' @param x Column to compute on. +#' #' @rdname ceil #' @name ceil #' @family math_funcs #' @export +#' @aliases ceil,Column-method #' @examples \dontrun{ceil(df$c)} +#' @note ceil since 1.5.0 setMethod("ceil", signature(x = "Column"), function(x) { @@ -241,19 +289,24 @@ setMethod("ceil", #' Though scala functions has "col" function, we don't expose it in SparkR #' because we don't want to conflict with the "col" function in the R base #' package and we also have "column" function exported which is an alias of "col". +#' @noRd col <- function(x) { column(callJStatic("org.apache.spark.sql.functions", "col", x)) } -#' column +#' Returns a Column based on the given column name #' #' Returns a Column based on the given column name. +# +#' @param x Character column name. #' -#' @rdname col +#' @rdname column #' @name column #' @family normal_funcs #' @export +#' @aliases column,character-method #' @examples \dontrun{column(df)} +#' @note column since 1.6.0 setMethod("column", signature(x = "character"), function(x) { @@ -263,11 +316,15 @@ setMethod("column", #' #' Computes the Pearson Correlation Coefficient for two Columns. #' +#' @param col2 a (second) Column. +#' #' @rdname corr #' @name corr #' @family math_funcs #' @export +#' @aliases corr,Column-method #' @examples \dontrun{corr(df$c, df$d)} +#' @note corr since 1.6.0 setMethod("corr", signature(x = "Column"), function(x, col2) { stopifnot(class(col2) == "Column") @@ -283,6 +340,7 @@ setMethod("corr", signature(x = "Column"), #' @name cov #' @family math_funcs #' @export +#' @aliases cov,characterOrColumn-method #' @examples #' \dontrun{ #' cov(df$c, df$d) @@ -290,6 +348,7 @@ setMethod("corr", signature(x = "Column"), #' covar_samp(df$c, df$d) #' covar_samp("c", "d") #' } +#' @note cov since 1.6.0 setMethod("cov", signature(x = "characterOrColumn"), function(x, col2) { stopifnot(is(class(col2), "characterOrColumn")) @@ -297,7 +356,12 @@ setMethod("cov", signature(x = "characterOrColumn"), }) #' @rdname cov +#' +#' @param col1 the first Column. +#' @param col2 the second Column. 
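As the note above explains, base R's `col` is deliberately not exposed; here is a hedged sketch of addressing a column by name with `column()`, alongside a couple of the unary math helpers. The frame and its `age` column are illustrative assumptions.

```R
# Hedged sketch: column("age") builds the same Column as df$age.
df <- createDataFrame(data.frame(age = c(25, 36, 49)))
head(select(df, column("age"), ceil(df$age / 7), cbrt(df$age)))
```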
#' @name covar_samp +#' @aliases covar_samp,characterOrColumn,characterOrColumn-method +#' @note covar_samp since 2.0.0 setMethod("covar_samp", signature(col1 = "characterOrColumn", col2 = "characterOrColumn"), function(col1, col2) { stopifnot(class(col1) == class(col2)) @@ -313,15 +377,20 @@ setMethod("covar_samp", signature(col1 = "characterOrColumn", col2 = "characterO #' #' Compute the population covariance between two expressions. #' +#' @param col1 First column to compute cov_pop. +#' @param col2 Second column to compute cov_pop. +#' #' @rdname covar_pop #' @name covar_pop #' @family math_funcs #' @export +#' @aliases covar_pop,characterOrColumn,characterOrColumn-method #' @examples #' \dontrun{ #' covar_pop(df$c, df$d) #' covar_pop("c", "d") #' } +#' @note covar_pop since 2.0.0 setMethod("covar_pop", signature(col1 = "characterOrColumn", col2 = "characterOrColumn"), function(col1, col2) { stopifnot(class(col1) == class(col2)) @@ -337,11 +406,15 @@ setMethod("covar_pop", signature(col1 = "characterOrColumn", col2 = "characterOr #' #' Computes the cosine of the given value. #' +#' @param x Column to compute on. +#' #' @rdname cos #' @name cos #' @family math_funcs +#' @aliases cos,Column-method #' @export #' @examples \dontrun{cos(df$c)} +#' @note cos since 1.5.0 setMethod("cos", signature(x = "Column"), function(x) { @@ -353,11 +426,15 @@ setMethod("cos", #' #' Computes the hyperbolic cosine of the given value. #' +#' @param x Column to compute on. +#' #' @rdname cosh #' @name cosh #' @family math_funcs +#' @aliases cosh,Column-method #' @export #' @examples \dontrun{cosh(df$c)} +#' @note cosh since 1.5.0 setMethod("cosh", signature(x = "Column"), function(x) { @@ -365,15 +442,18 @@ setMethod("cosh", column(jc) }) -#' count +#' Returns the number of items in a group #' -#' Aggregate function: returns the number of items in a group. +#' This can be used as a column aggregate function with \code{Column} as input, +#' and returns the number of items in a group. #' #' @rdname count #' @name count #' @family agg_funcs +#' @aliases count,Column-method #' @export #' @examples \dontrun{count(df$c)} +#' @note count since 1.4.0 setMethod("count", signature(x = "Column"), function(x) { @@ -386,11 +466,15 @@ setMethod("count", #' Calculates the cyclic redundancy check value (CRC32) of a binary column and #' returns the value as a bigint. #' +#' @param x Column to compute on. +#' #' @rdname crc32 #' @name crc32 #' @family misc_funcs +#' @aliases crc32,Column-method #' @export #' @examples \dontrun{crc32(df$c)} +#' @note crc32 since 1.5.0 setMethod("crc32", signature(x = "Column"), function(x) { @@ -402,11 +486,16 @@ setMethod("crc32", #' #' Calculates the hash code of given columns, and returns the result as a int column. #' +#' @param x Column to compute on. +#' @param ... additional Column(s) to be included. +#' #' @rdname hash #' @name hash #' @family misc_funcs +#' @aliases hash,Column-method #' @export #' @examples \dontrun{hash(df$c)} +#' @note hash since 2.0.0 setMethod("hash", signature(x = "Column"), function(x, ...) { @@ -422,11 +511,15 @@ setMethod("hash", #' #' Extracts the day of the month as an integer from a given date/timestamp/string. #' +#' @param x Column to compute on. 
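A hedged sketch of the two-column statistics documented above, using the built-in `faithful` dataset purely for illustration:

```R
# Hedged sketch of corr/covar_samp/covar_pop on the built-in faithful dataset.
dfF <- createDataFrame(faithful)
head(agg(dfF,
         corr(dfF$eruptions, dfF$waiting),
         covar_samp(dfF$eruptions, dfF$waiting),
         covar_pop("eruptions", "waiting")))
```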
+#' #' @rdname dayofmonth #' @name dayofmonth #' @family datetime_funcs +#' @aliases dayofmonth,Column-method #' @export #' @examples \dontrun{dayofmonth(df$c)} +#' @note dayofmonth since 1.5.0 setMethod("dayofmonth", signature(x = "Column"), function(x) { @@ -438,11 +531,15 @@ setMethod("dayofmonth", #' #' Extracts the day of the year as an integer from a given date/timestamp/string. #' +#' @param x Column to compute on. +#' #' @rdname dayofyear #' @name dayofyear #' @family datetime_funcs +#' @aliases dayofyear,Column-method #' @export #' @examples \dontrun{dayofyear(df$c)} +#' @note dayofyear since 1.5.0 setMethod("dayofyear", signature(x = "Column"), function(x) { @@ -455,11 +552,16 @@ setMethod("dayofyear", #' Computes the first argument into a string from a binary using the provided character set #' (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). #' +#' @param x Column to compute on. +#' @param charset Character set to use +#' #' @rdname decode #' @name decode #' @family string_funcs +#' @aliases decode,Column,character-method #' @export #' @examples \dontrun{decode(df$c, "UTF-8")} +#' @note decode since 1.6.0 setMethod("decode", signature(x = "Column", charset = "character"), function(x, charset) { @@ -472,11 +574,16 @@ setMethod("decode", #' Computes the first argument into a binary from a string using the provided character set #' (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). #' +#' @param x Column to compute on. +#' @param charset Character set to use +#' #' @rdname encode #' @name encode #' @family string_funcs +#' @aliases encode,Column,character-method #' @export #' @examples \dontrun{encode(df$c, "UTF-8")} +#' @note encode since 1.6.0 setMethod("encode", signature(x = "Column", charset = "character"), function(x, charset) { @@ -488,11 +595,15 @@ setMethod("encode", #' #' Computes the exponential of the given value. #' +#' @param x Column to compute on. +#' #' @rdname exp #' @name exp #' @family math_funcs +#' @aliases exp,Column-method #' @export #' @examples \dontrun{exp(df$c)} +#' @note exp since 1.5.0 setMethod("exp", signature(x = "Column"), function(x) { @@ -504,11 +615,15 @@ setMethod("exp", #' #' Computes the exponential of the given value minus one. #' +#' @param x Column to compute on. +#' #' @rdname expm1 #' @name expm1 +#' @aliases expm1,Column-method #' @family math_funcs #' @export #' @examples \dontrun{expm1(df$c)} +#' @note expm1 since 1.5.0 setMethod("expm1", signature(x = "Column"), function(x) { @@ -520,11 +635,15 @@ setMethod("expm1", #' #' Computes the factorial of the given value. #' +#' @param x Column to compute on. +#' #' @rdname factorial #' @name factorial +#' @aliases factorial,Column-method #' @family math_funcs #' @export #' @examples \dontrun{factorial(df$c)} +#' @note factorial since 1.5.0 setMethod("factorial", signature(x = "Column"), function(x) { @@ -539,8 +658,12 @@ setMethod("factorial", #' The function by default returns the first values it sees. It will return the first non-missing #' value it sees when na.rm is set to true. If all values are missing, then NA is returned. #' +#' @param na.rm a logical value indicating whether NA values should be stripped +#' before the computation proceeds. 
+#' #' @rdname first #' @name first +#' @aliases first,characterOrColumn-method #' @family agg_funcs #' @export #' @examples @@ -548,6 +671,7 @@ setMethod("factorial", #' first(df$c) #' first(df$c, TRUE) #' } +#' @note first(characterOrColumn) since 1.4.0 setMethod("first", signature(x = "characterOrColumn"), function(x, na.rm = FALSE) { @@ -564,11 +688,15 @@ setMethod("first", #' #' Computes the floor of the given value. #' +#' @param x Column to compute on. +#' #' @rdname floor #' @name floor +#' @aliases floor,Column-method #' @family math_funcs #' @export #' @examples \dontrun{floor(df$c)} +#' @note floor since 1.5.0 setMethod("floor", signature(x = "Column"), function(x) { @@ -580,11 +708,15 @@ setMethod("floor", #' #' Computes hex value of the given column. #' +#' @param x Column to compute on. +#' #' @rdname hex #' @name hex #' @family math_funcs +#' @aliases hex,Column-method #' @export #' @examples \dontrun{hex(df$c)} +#' @note hex since 1.5.0 setMethod("hex", signature(x = "Column"), function(x) { @@ -596,11 +728,15 @@ setMethod("hex", #' #' Extracts the hours as an integer from a given date/timestamp/string. #' +#' @param x Column to compute on. +#' #' @rdname hour #' @name hour +#' @aliases hour,Column-method #' @family datetime_funcs #' @export #' @examples \dontrun{hour(df$c)} +#' @note hour since 1.5.0 setMethod("hour", signature(x = "Column"), function(x) { @@ -615,11 +751,15 @@ setMethod("hour", #' #' For example, "hello world" will become "Hello World". #' +#' @param x Column to compute on. +#' #' @rdname initcap #' @name initcap #' @family string_funcs +#' @aliases initcap,Column-method #' @export #' @examples \dontrun{initcap(df$c)} +#' @note initcap since 1.5.0 setMethod("initcap", signature(x = "Column"), function(x) { @@ -631,15 +771,19 @@ setMethod("initcap", #' #' Return true if the column is NaN, alias for \link{isnan} #' +#' @param x Column to compute on. +#' #' @rdname is.nan #' @name is.nan #' @family normal_funcs +#' @aliases is.nan,Column-method #' @export #' @examples #' \dontrun{ #' is.nan(df$c) #' isnan(df$c) #' } +#' @note is.nan since 2.0.0 setMethod("is.nan", signature(x = "Column"), function(x) { @@ -648,6 +792,8 @@ setMethod("is.nan", #' @rdname is.nan #' @name isnan +#' @aliases isnan,Column-method +#' @note isnan since 2.0.0 setMethod("isnan", signature(x = "Column"), function(x) { @@ -659,11 +805,15 @@ setMethod("isnan", #' #' Aggregate function: returns the kurtosis of the values in a group. #' +#' @param x Column to compute on. +#' #' @rdname kurtosis #' @name kurtosis +#' @aliases kurtosis,Column-method #' @family agg_funcs #' @export #' @examples \dontrun{kurtosis(df$c)} +#' @note kurtosis since 1.6.0 setMethod("kurtosis", signature(x = "Column"), function(x) { @@ -678,8 +828,14 @@ setMethod("kurtosis", #' The function by default returns the last values it sees. It will return the last non-missing #' value it sees when na.rm is set to true. If all values are missing, then NA is returned. #' +#' @param x column to compute on. +#' @param na.rm a logical value indicating whether NA values should be stripped +#' before the computation proceeds. +#' @param ... further arguments to be passed to or from other methods. 
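A hedged sketch of `first`/`last` with `na.rm`, as documented above; the one-column frame is an assumption for illustration.

```R
# Hedged sketch: with na.rm = TRUE the first/last non-missing value is returned.
dfNA <- createDataFrame(data.frame(v = c(NA, 1.5, 2.5, NA)))
head(agg(dfNA, first(dfNA$v, na.rm = TRUE), last(dfNA$v, na.rm = TRUE)))
```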
+#' #' @rdname last #' @name last +#' @aliases last,characterOrColumn-method #' @family agg_funcs #' @export #' @examples @@ -687,6 +843,7 @@ setMethod("kurtosis", #' last(df$c) #' last(df$c, TRUE) #' } +#' @note last since 1.4.0 setMethod("last", signature(x = "characterOrColumn"), function(x, na.rm = FALSE) { @@ -705,11 +862,15 @@ setMethod("last", #' For example, input "2015-07-27" returns "2015-07-31" since July 31 is the last day of the #' month in July 2015. #' +#' @param x Column to compute on. +#' #' @rdname last_day #' @name last_day +#' @aliases last_day,Column-method #' @family datetime_funcs #' @export #' @examples \dontrun{last_day(df$c)} +#' @note last_day since 1.5.0 setMethod("last_day", signature(x = "Column"), function(x) { @@ -721,11 +882,15 @@ setMethod("last_day", #' #' Computes the length of a given string or binary column. #' +#' @param x Column to compute on. +#' #' @rdname length #' @name length +#' @aliases length,Column-method #' @family string_funcs #' @export #' @examples \dontrun{length(df$c)} +#' @note length since 1.5.0 setMethod("length", signature(x = "Column"), function(x) { @@ -737,11 +902,15 @@ setMethod("length", #' #' Computes the natural logarithm of the given value. #' +#' @param x Column to compute on. +#' #' @rdname log #' @name log +#' @aliases log,Column-method #' @family math_funcs #' @export #' @examples \dontrun{log(df$c)} +#' @note log since 1.5.0 setMethod("log", signature(x = "Column"), function(x) { @@ -753,11 +922,15 @@ setMethod("log", #' #' Computes the logarithm of the given value in base 10. #' +#' @param x Column to compute on. +#' #' @rdname log10 #' @name log10 #' @family math_funcs +#' @aliases log10,Column-method #' @export #' @examples \dontrun{log10(df$c)} +#' @note log10 since 1.5.0 setMethod("log10", signature(x = "Column"), function(x) { @@ -769,11 +942,15 @@ setMethod("log10", #' #' Computes the natural logarithm of the given value plus one. #' +#' @param x Column to compute on. +#' #' @rdname log1p #' @name log1p #' @family math_funcs +#' @aliases log1p,Column-method #' @export #' @examples \dontrun{log1p(df$c)} +#' @note log1p since 1.5.0 setMethod("log1p", signature(x = "Column"), function(x) { @@ -785,11 +962,15 @@ setMethod("log1p", #' #' Computes the logarithm of the given column in base 2. #' +#' @param x Column to compute on. +#' #' @rdname log2 #' @name log2 #' @family math_funcs +#' @aliases log2,Column-method #' @export #' @examples \dontrun{log2(df$c)} +#' @note log2 since 1.5.0 setMethod("log2", signature(x = "Column"), function(x) { @@ -801,11 +982,15 @@ setMethod("log2", #' #' Converts a string column to lower case. #' +#' @param x Column to compute on. +#' #' @rdname lower #' @name lower #' @family string_funcs +#' @aliases lower,Column-method #' @export #' @examples \dontrun{lower(df$c)} +#' @note lower since 1.4.0 setMethod("lower", signature(x = "Column"), function(x) { @@ -817,11 +1002,15 @@ setMethod("lower", #' #' Trim the spaces from left end for the specified string value. #' +#' @param x Column to compute on. +#' #' @rdname ltrim #' @name ltrim #' @family string_funcs +#' @aliases ltrim,Column-method #' @export #' @examples \dontrun{ltrim(df$c)} +#' @note ltrim since 1.5.0 setMethod("ltrim", signature(x = "Column"), function(x) { @@ -833,11 +1022,15 @@ setMethod("ltrim", #' #' Aggregate function: returns the maximum value of the expression in a group. #' +#' @param x Column to compute on. 
+#' #' @rdname max #' @name max #' @family agg_funcs +#' @aliases max,Column-method #' @export #' @examples \dontrun{max(df$c)} +#' @note max since 1.5.0 setMethod("max", signature(x = "Column"), function(x) { @@ -850,11 +1043,15 @@ setMethod("max", #' Calculates the MD5 digest of a binary column and returns the value #' as a 32 character hex string. #' +#' @param x Column to compute on. +#' #' @rdname md5 #' @name md5 #' @family misc_funcs +#' @aliases md5,Column-method #' @export #' @examples \dontrun{md5(df$c)} +#' @note md5 since 1.5.0 setMethod("md5", signature(x = "Column"), function(x) { @@ -867,11 +1064,15 @@ setMethod("md5", #' Aggregate function: returns the average of the values in a group. #' Alias for avg. #' +#' @param x Column to compute on. +#' #' @rdname mean #' @name mean #' @family agg_funcs +#' @aliases mean,Column-method #' @export #' @examples \dontrun{mean(df$c)} +#' @note mean since 1.5.0 setMethod("mean", signature(x = "Column"), function(x) { @@ -883,11 +1084,15 @@ setMethod("mean", #' #' Aggregate function: returns the minimum value of the expression in a group. #' +#' @param x Column to compute on. +#' #' @rdname min #' @name min +#' @aliases min,Column-method #' @family agg_funcs #' @export #' @examples \dontrun{min(df$c)} +#' @note min since 1.5.0 setMethod("min", signature(x = "Column"), function(x) { @@ -899,11 +1104,15 @@ setMethod("min", #' #' Extracts the minutes as an integer from a given date/timestamp/string. #' +#' @param x Column to compute on. +#' #' @rdname minute #' @name minute +#' @aliases minute,Column-method #' @family datetime_funcs #' @export #' @examples \dontrun{minute(df$c)} +#' @note minute since 1.5.0 setMethod("minute", signature(x = "Column"), function(x) { @@ -911,15 +1120,47 @@ setMethod("minute", column(jc) }) +#' monotonically_increasing_id +#' +#' Return a column that generates monotonically increasing 64-bit integers. +#' +#' The generated ID is guaranteed to be monotonically increasing and unique, but not consecutive. +#' The current implementation puts the partition ID in the upper 31 bits, and the record number +#' within each partition in the lower 33 bits. The assumption is that the SparkDataFrame has +#' less than 1 billion partitions, and each partition has less than 8 billion records. +#' +#' As an example, consider a SparkDataFrame with two partitions, each with 3 records. +#' This expression would return the following IDs: +#' 0, 1, 2, 8589934592 (1L << 33), 8589934593, 8589934594. +#' +#' This is equivalent to the MONOTONICALLY_INCREASING_ID function in SQL. +#' +#' @rdname monotonically_increasing_id +#' @aliases monotonically_increasing_id,missing-method +#' @name monotonically_increasing_id +#' @family misc_funcs +#' @export +#' @examples \dontrun{select(df, monotonically_increasing_id())} +setMethod("monotonically_increasing_id", + signature("missing"), + function() { + jc <- callJStatic("org.apache.spark.sql.functions", "monotonically_increasing_id") + column(jc) + }) + #' month #' #' Extracts the month as an integer from a given date/timestamp/string. #' +#' @param x Column to compute on. +#' #' @rdname month #' @name month +#' @aliases month,Column-method #' @family datetime_funcs #' @export #' @examples \dontrun{month(df$c)} +#' @note month since 1.5.0 setMethod("month", signature(x = "Column"), function(x) { @@ -931,11 +1172,15 @@ setMethod("month", #' #' Unary minus, i.e. negate the expression. #' +#' @param x Column to compute on. 
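A hedged sketch of `monotonically_increasing_id()`: as described above, the IDs are unique and increasing but not consecutive, because the partition ID occupies the upper bits. Reusing `faithful` here is illustrative only.

```R
# Hedged sketch: IDs are unique and increasing, but not consecutive across partitions.
dfF <- createDataFrame(faithful)
head(select(dfF, dfF$waiting, monotonically_increasing_id()))
```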
+#' #' @rdname negate #' @name negate #' @family normal_funcs +#' @aliases negate,Column-method #' @export #' @examples \dontrun{negate(df$c)} +#' @note negate since 1.5.0 setMethod("negate", signature(x = "Column"), function(x) { @@ -947,11 +1192,15 @@ setMethod("negate", #' #' Extracts the quarter as an integer from a given date/timestamp/string. #' +#' @param x Column to compute on. +#' #' @rdname quarter #' @name quarter #' @family datetime_funcs +#' @aliases quarter,Column-method #' @export #' @examples \dontrun{quarter(df$c)} +#' @note quarter since 1.5.0 setMethod("quarter", signature(x = "Column"), function(x) { @@ -963,11 +1212,15 @@ setMethod("quarter", #' #' Reverses the string column and returns it as a new string column. #' +#' @param x Column to compute on. +#' #' @rdname reverse #' @name reverse #' @family string_funcs +#' @aliases reverse,Column-method #' @export #' @examples \dontrun{reverse(df$c)} +#' @note reverse since 1.5.0 setMethod("reverse", signature(x = "Column"), function(x) { @@ -980,11 +1233,15 @@ setMethod("reverse", #' Returns the double value that is closest in value to the argument and #' is equal to a mathematical integer. #' +#' @param x Column to compute on. +#' #' @rdname rint #' @name rint #' @family math_funcs +#' @aliases rint,Column-method #' @export #' @examples \dontrun{rint(df$c)} +#' @note rint since 1.5.0 setMethod("rint", signature(x = "Column"), function(x) { @@ -994,13 +1251,17 @@ setMethod("rint", #' round #' -#' Returns the value of the column `e` rounded to 0 decimal places using HALF_UP rounding mode. +#' Returns the value of the column \code{e} rounded to 0 decimal places using HALF_UP rounding mode. +#' +#' @param x Column to compute on. #' #' @rdname round #' @name round #' @family math_funcs +#' @aliases round,Column-method #' @export #' @examples \dontrun{round(df$c)} +#' @note round since 1.5.0 setMethod("round", signature(x = "Column"), function(x) { @@ -1010,16 +1271,23 @@ setMethod("round", #' bround #' -#' Returns the value of the column `e` rounded to `scale` decimal places using HALF_EVEN rounding -#' mode if `scale` >= 0 or at integral part when `scale` < 0. +#' Returns the value of the column \code{e} rounded to \code{scale} decimal places using HALF_EVEN rounding +#' mode if \code{scale} >= 0 or at integer part when \code{scale} < 0. #' Also known as Gaussian rounding or bankers' rounding that rounds to the nearest even number. #' bround(2.5, 0) = 2, bround(3.5, 0) = 4. #' +#' @param x Column to compute on. +#' @param scale round to \code{scale} digits to the right of the decimal point when \code{scale} > 0, +#' the nearest even number when \code{scale} = 0, and \code{scale} digits to the left +#' of the decimal point when \code{scale} < 0. +#' @param ... further arguments to be passed to or from other methods. #' @rdname bround #' @name bround #' @family math_funcs +#' @aliases bround,Column-method #' @export #' @examples \dontrun{bround(df$c, 0)} +#' @note bround since 2.0.0 setMethod("bround", signature(x = "Column"), function(x, scale = 0) { @@ -1032,11 +1300,15 @@ setMethod("bround", #' #' Trim the spaces from right end for the specified string value. #' +#' @param x Column to compute on. 
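A hedged sketch contrasting `round` (HALF_UP) with `bround` (HALF_EVEN, bankers' rounding) as documented above; the column `x` is an illustrative assumption.

```R
# Hedged sketch: round() uses HALF_UP, bround() uses HALF_EVEN (bankers' rounding).
dfR <- createDataFrame(data.frame(x = c(2.5, 3.5)))
head(select(dfR, round(dfR$x), bround(dfR$x, 0)))   # round: 3, 4; bround: 2, 4
```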
+#' #' @rdname rtrim #' @name rtrim #' @family string_funcs +#' @aliases rtrim,Column-method #' @export #' @examples \dontrun{rtrim(df$c)} +#' @note rtrim since 1.5.0 setMethod("rtrim", signature(x = "Column"), function(x) { @@ -1048,9 +1320,12 @@ setMethod("rtrim", #' #' Aggregate function: alias for \link{stddev_samp} #' +#' @param x Column to compute on. +#' @param na.rm currently not used. #' @rdname sd #' @name sd #' @family agg_funcs +#' @aliases sd,Column-method #' @seealso \link{stddev_pop}, \link{stddev_samp} #' @export #' @examples @@ -1059,6 +1334,7 @@ setMethod("rtrim", #'select(df, stddev(df$age)) #'agg(df, sd(df$age)) #'} +#' @note sd since 1.6.0 setMethod("sd", signature(x = "Column"), function(x) { @@ -1070,11 +1346,15 @@ setMethod("sd", #' #' Extracts the seconds as an integer from a given date/timestamp/string. #' +#' @param x Column to compute on. +#' #' @rdname second #' @name second #' @family datetime_funcs +#' @aliases second,Column-method #' @export #' @examples \dontrun{second(df$c)} +#' @note second since 1.5.0 setMethod("second", signature(x = "Column"), function(x) { @@ -1087,11 +1367,15 @@ setMethod("second", #' Calculates the SHA-1 digest of a binary column and returns the value #' as a 40 character hex string. #' +#' @param x Column to compute on. +#' #' @rdname sha1 #' @name sha1 #' @family misc_funcs +#' @aliases sha1,Column-method #' @export #' @examples \dontrun{sha1(df$c)} +#' @note sha1 since 1.5.0 setMethod("sha1", signature(x = "Column"), function(x) { @@ -1103,11 +1387,15 @@ setMethod("sha1", #' #' Computes the signum of the given value. #' -#' @rdname signum +#' @param x Column to compute on. +#' +#' @rdname sign #' @name signum +#' @aliases signum,Column-method #' @family math_funcs #' @export #' @examples \dontrun{signum(df$c)} +#' @note signum since 1.5.0 setMethod("signum", signature(x = "Column"), function(x) { @@ -1119,11 +1407,15 @@ setMethod("signum", #' #' Computes the sine of the given value. #' +#' @param x Column to compute on. +#' #' @rdname sin #' @name sin #' @family math_funcs +#' @aliases sin,Column-method #' @export #' @examples \dontrun{sin(df$c)} +#' @note sin since 1.5.0 setMethod("sin", signature(x = "Column"), function(x) { @@ -1135,11 +1427,15 @@ setMethod("sin", #' #' Computes the hyperbolic sine of the given value. #' +#' @param x Column to compute on. +#' #' @rdname sinh #' @name sinh #' @family math_funcs +#' @aliases sinh,Column-method #' @export #' @examples \dontrun{sinh(df$c)} +#' @note sinh since 1.5.0 setMethod("sinh", signature(x = "Column"), function(x) { @@ -1151,11 +1447,15 @@ setMethod("sinh", #' #' Aggregate function: returns the skewness of the values in a group. #' +#' @param x Column to compute on. +#' #' @rdname skewness #' @name skewness #' @family agg_funcs +#' @aliases skewness,Column-method #' @export #' @examples \dontrun{skewness(df$c)} +#' @note skewness since 1.6.0 setMethod("skewness", signature(x = "Column"), function(x) { @@ -1167,11 +1467,15 @@ setMethod("skewness", #' #' Return the soundex code for the specified expression. #' +#' @param x Column to compute on. +#' #' @rdname soundex #' @name soundex #' @family string_funcs +#' @aliases soundex,Column-method #' @export #' @examples \dontrun{soundex(df$c)} +#' @note soundex since 1.5.0 setMethod("soundex", signature(x = "Column"), function(x) { @@ -1179,8 +1483,32 @@ setMethod("soundex", column(jc) }) +#' Return the partition ID as a column +#' +#' Return the partition ID of the Spark task as a SparkDataFrame column. 
+#' Note that this is nondeterministic because it depends on data partitioning and +#' task scheduling. +#' +#' This is equivalent to the SPARK_PARTITION_ID function in SQL. +#' +#' @rdname spark_partition_id +#' @name spark_partition_id +#' @aliases spark_partition_id,missing-method +#' @export +#' @examples +#' \dontrun{select(df, spark_partition_id())} +#' @note spark_partition_id since 2.0.0 +setMethod("spark_partition_id", + signature("missing"), + function() { + jc <- callJStatic("org.apache.spark.sql.functions", "spark_partition_id") + column(jc) + }) + #' @rdname sd +#' @aliases stddev,Column-method #' @name stddev +#' @note stddev since 1.6.0 setMethod("stddev", signature(x = "Column"), function(x) { @@ -1192,12 +1520,16 @@ setMethod("stddev", #' #' Aggregate function: returns the population standard deviation of the expression in a group. #' +#' @param x Column to compute on. +#' #' @rdname stddev_pop #' @name stddev_pop #' @family agg_funcs +#' @aliases stddev_pop,Column-method #' @seealso \link{sd}, \link{stddev_samp} #' @export #' @examples \dontrun{stddev_pop(df$c)} +#' @note stddev_pop since 1.6.0 setMethod("stddev_pop", signature(x = "Column"), function(x) { @@ -1209,12 +1541,16 @@ setMethod("stddev_pop", #' #' Aggregate function: returns the unbiased sample standard deviation of the expression in a group. #' +#' @param x Column to compute on. +#' #' @rdname stddev_samp #' @name stddev_samp #' @family agg_funcs +#' @aliases stddev_samp,Column-method #' @seealso \link{stddev_pop}, \link{sd} #' @export #' @examples \dontrun{stddev_samp(df$c)} +#' @note stddev_samp since 1.6.0 setMethod("stddev_samp", signature(x = "Column"), function(x) { @@ -1226,15 +1562,20 @@ setMethod("stddev_samp", #' #' Creates a new struct column that composes multiple input columns. #' +#' @param x a column to compute on. +#' @param ... optional column(s) to be included. +#' #' @rdname struct #' @name struct #' @family normal_funcs +#' @aliases struct,characterOrColumn-method #' @export #' @examples #' \dontrun{ #' struct(df$c, df$d) #' struct("col1", "col2") #' } +#' @note struct since 1.6.0 setMethod("struct", signature(x = "characterOrColumn"), function(x, ...) { @@ -1251,11 +1592,15 @@ setMethod("struct", #' #' Computes the square root of the specified float value. #' +#' @param x Column to compute on. +#' #' @rdname sqrt #' @name sqrt #' @family math_funcs +#' @aliases sqrt,Column-method #' @export #' @examples \dontrun{sqrt(df$c)} +#' @note sqrt since 1.5.0 setMethod("sqrt", signature(x = "Column"), function(x) { @@ -1267,11 +1612,15 @@ setMethod("sqrt", #' #' Aggregate function: returns the sum of all values in the expression. #' +#' @param x Column to compute on. +#' #' @rdname sum #' @name sum #' @family agg_funcs +#' @aliases sum,Column-method #' @export #' @examples \dontrun{sum(df$c)} +#' @note sum since 1.5.0 setMethod("sum", signature(x = "Column"), function(x) { @@ -1283,11 +1632,15 @@ setMethod("sum", #' #' Aggregate function: returns the sum of distinct values in the expression. #' +#' @param x Column to compute on. +#' #' @rdname sumDistinct #' @name sumDistinct #' @family agg_funcs +#' @aliases sumDistinct,Column-method #' @export #' @examples \dontrun{sumDistinct(df$c)} +#' @note sumDistinct since 1.4.0 setMethod("sumDistinct", signature(x = "Column"), function(x) { @@ -1299,11 +1652,15 @@ setMethod("sumDistinct", #' #' Computes the tangent of the given value. #' +#' @param x Column to compute on. 
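A hedged sketch of the standard-deviation aggregates and `spark_partition_id()` documented above, again using `faithful` only for illustration; as noted, the partition ID is nondeterministic with respect to data placement.

```R
# Hedged sketch of sd/stddev_pop/stddev_samp, spark_partition_id() and struct().
dfF <- createDataFrame(faithful)
head(agg(dfF, sd(dfF$waiting), stddev_pop(dfF$waiting), stddev_samp(dfF$waiting)))
head(select(dfF, spark_partition_id(), struct(dfF$eruptions, dfF$waiting)))
```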
+#' #' @rdname tan #' @name tan #' @family math_funcs +#' @aliases tan,Column-method #' @export #' @examples \dontrun{tan(df$c)} +#' @note tan since 1.5.0 setMethod("tan", signature(x = "Column"), function(x) { @@ -1315,11 +1672,15 @@ setMethod("tan", #' #' Computes the hyperbolic tangent of the given value. #' +#' @param x Column to compute on. +#' #' @rdname tanh #' @name tanh #' @family math_funcs +#' @aliases tanh,Column-method #' @export #' @examples \dontrun{tanh(df$c)} +#' @note tanh since 1.5.0 setMethod("tanh", signature(x = "Column"), function(x) { @@ -1331,11 +1692,15 @@ setMethod("tanh", #' #' Converts an angle measured in radians to an approximately equivalent angle measured in degrees. #' +#' @param x Column to compute on. +#' #' @rdname toDegrees #' @name toDegrees #' @family math_funcs +#' @aliases toDegrees,Column-method #' @export #' @examples \dontrun{toDegrees(df$c)} +#' @note toDegrees since 1.4.0 setMethod("toDegrees", signature(x = "Column"), function(x) { @@ -1347,11 +1712,15 @@ setMethod("toDegrees", #' #' Converts an angle measured in degrees to an approximately equivalent angle measured in radians. #' +#' @param x Column to compute on. +#' #' @rdname toRadians #' @name toRadians #' @family math_funcs +#' @aliases toRadians,Column-method #' @export #' @examples \dontrun{toRadians(df$c)} +#' @note toRadians since 1.4.0 setMethod("toRadians", signature(x = "Column"), function(x) { @@ -1363,11 +1732,15 @@ setMethod("toRadians", #' #' Converts the column into DateType. #' +#' @param x Column to compute on. +#' #' @rdname to_date #' @name to_date #' @family datetime_funcs +#' @aliases to_date,Column-method #' @export #' @examples \dontrun{to_date(df$c)} +#' @note to_date since 1.5.0 setMethod("to_date", signature(x = "Column"), function(x) { @@ -1379,11 +1752,15 @@ setMethod("to_date", #' #' Trim the spaces from both ends for the specified string column. #' +#' @param x Column to compute on. +#' #' @rdname trim #' @name trim #' @family string_funcs +#' @aliases trim,Column-method #' @export #' @examples \dontrun{trim(df$c)} +#' @note trim since 1.5.0 setMethod("trim", signature(x = "Column"), function(x) { @@ -1396,11 +1773,15 @@ setMethod("trim", #' Decodes a BASE64 encoded string column and returns it as a binary column. #' This is the reverse of base64. #' +#' @param x Column to compute on. +#' #' @rdname unbase64 #' @name unbase64 #' @family string_funcs +#' @aliases unbase64,Column-method #' @export #' @examples \dontrun{unbase64(df$c)} +#' @note unbase64 since 1.5.0 setMethod("unbase64", signature(x = "Column"), function(x) { @@ -1413,11 +1794,15 @@ setMethod("unbase64", #' Inverse of hex. Interprets each pair of characters as a hexadecimal number #' and converts to the byte representation of number. #' +#' @param x Column to compute on. +#' #' @rdname unhex #' @name unhex #' @family math_funcs +#' @aliases unhex,Column-method #' @export #' @examples \dontrun{unhex(df$c)} +#' @note unhex since 1.5.0 setMethod("unhex", signature(x = "Column"), function(x) { @@ -1429,11 +1814,15 @@ setMethod("unhex", #' #' Converts a string column to upper case. #' +#' @param x Column to compute on. +#' #' @rdname upper #' @name upper #' @family string_funcs +#' @aliases upper,Column-method #' @export #' @examples \dontrun{upper(df$c)} +#' @note upper since 1.4.0 setMethod("upper", signature(x = "Column"), function(x) { @@ -1445,9 +1834,12 @@ setMethod("upper", #' #' Aggregate function: alias for \link{var_samp}. #' +#' @param x a Column to compute on. 
+#' @param y,na.rm,use currently not used. #' @rdname var #' @name var #' @family agg_funcs +#' @aliases var,Column-method #' @seealso \link{var_pop}, \link{var_samp} #' @export #' @examples @@ -1456,6 +1848,7 @@ setMethod("upper", #'select(df, var_pop(df$age)) #'agg(df, var(df$age)) #'} +#' @note var since 1.6.0 setMethod("var", signature(x = "Column"), function(x) { @@ -1464,7 +1857,9 @@ setMethod("var", }) #' @rdname var +#' @aliases variance,Column-method #' @name variance +#' @note variance since 1.6.0 setMethod("variance", signature(x = "Column"), function(x) { @@ -1476,12 +1871,16 @@ setMethod("variance", #' #' Aggregate function: returns the population variance of the values in a group. #' +#' @param x Column to compute on. +#' #' @rdname var_pop #' @name var_pop #' @family agg_funcs +#' @aliases var_pop,Column-method #' @seealso \link{var}, \link{var_samp} #' @export #' @examples \dontrun{var_pop(df$c)} +#' @note var_pop since 1.5.0 setMethod("var_pop", signature(x = "Column"), function(x) { @@ -1493,12 +1892,16 @@ setMethod("var_pop", #' #' Aggregate function: returns the unbiased variance of the values in a group. #' +#' @param x Column to compute on. +#' #' @rdname var_samp #' @name var_samp +#' @aliases var_samp,Column-method #' @family agg_funcs #' @seealso \link{var_pop}, \link{var} #' @export #' @examples \dontrun{var_samp(df$c)} +#' @note var_samp since 1.6.0 setMethod("var_samp", signature(x = "Column"), function(x) { @@ -1510,11 +1913,15 @@ setMethod("var_samp", #' #' Extracts the week number as an integer from a given date/timestamp/string. #' +#' @param x Column to compute on. +#' #' @rdname weekofyear #' @name weekofyear +#' @aliases weekofyear,Column-method #' @family datetime_funcs #' @export #' @examples \dontrun{weekofyear(df$c)} +#' @note weekofyear since 1.5.0 setMethod("weekofyear", signature(x = "Column"), function(x) { @@ -1526,11 +1933,15 @@ setMethod("weekofyear", #' #' Extracts the year as an integer from a given date/timestamp/string. #' +#' @param x Column to compute on. +#' #' @rdname year #' @name year #' @family datetime_funcs +#' @aliases year,Column-method #' @export #' @examples \dontrun{year(df$c)} +#' @note year since 1.5.0 setMethod("year", signature(x = "Column"), function(x) { @@ -1542,12 +1953,17 @@ setMethod("year", #' #' Returns the angle theta from the conversion of rectangular coordinates (x, y) to #' polar coordinates (r, theta). +# +#' @param x Column to compute on. +#' @param y Column to compute on. #' #' @rdname atan2 #' @name atan2 #' @family math_funcs +#' @aliases atan2,Column-method #' @export #' @examples \dontrun{atan2(df$c, x)} +#' @note atan2 since 1.5.0 setMethod("atan2", signature(y = "Column"), function(y, x) { if (class(x) == "Column") { @@ -1559,13 +1975,18 @@ setMethod("atan2", signature(y = "Column"), #' datediff #' -#' Returns the number of days from `start` to `end`. +#' Returns the number of days from \code{start} to \code{end}. +#' +#' @param x start Column to use. +#' @param y end Column to use. #' #' @rdname datediff #' @name datediff +#' @aliases datediff,Column-method #' @family datetime_funcs #' @export #' @examples \dontrun{datediff(df$c, x)} +#' @note datediff since 1.5.0 setMethod("datediff", signature(y = "Column"), function(y, x) { if (class(x) == "Column") { @@ -1577,13 +1998,18 @@ setMethod("datediff", signature(y = "Column"), #' hypot #' -#' Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. +#' Computes "sqrt(a^2 + b^2)" without intermediate overflow or underflow. 
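A hedged sketch of the variance aggregates and the two-argument math helper `atan2` documented above; `faithful` is used only for illustration.

```R
# Hedged sketch of var/var_pop/var_samp and atan2 on two columns.
dfF <- createDataFrame(faithful)
head(agg(dfF, var(dfF$waiting), var_pop(dfF$waiting), var_samp(dfF$waiting)))
head(select(dfF, atan2(dfF$eruptions, dfF$waiting)))
```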
+# +#' @param x Column to compute on. +#' @param y Column to compute on. #' #' @rdname hypot #' @name hypot #' @family math_funcs +#' @aliases hypot,Column-method #' @export #' @examples \dontrun{hypot(df$c, x)} +#' @note hypot since 1.4.0 setMethod("hypot", signature(y = "Column"), function(y, x) { if (class(x) == "Column") { @@ -1597,11 +2023,16 @@ setMethod("hypot", signature(y = "Column"), #' #' Computes the Levenshtein distance of the two given string columns. #' +#' @param x Column to compute on. +#' @param y Column to compute on. +#' #' @rdname levenshtein #' @name levenshtein #' @family string_funcs +#' @aliases levenshtein,Column-method #' @export #' @examples \dontrun{levenshtein(df$c, x)} +#' @note levenshtein since 1.5.0 setMethod("levenshtein", signature(y = "Column"), function(y, x) { if (class(x) == "Column") { @@ -1613,13 +2044,18 @@ setMethod("levenshtein", signature(y = "Column"), #' months_between #' -#' Returns number of months between dates `date1` and `date2`. +#' Returns number of months between dates \code{date1} and \code{date2}. +#' +#' @param x start Column to use. +#' @param y end Column to use. #' #' @rdname months_between #' @name months_between #' @family datetime_funcs +#' @aliases months_between,Column-method #' @export #' @examples \dontrun{months_between(df$c, x)} +#' @note months_between since 1.5.0 setMethod("months_between", signature(y = "Column"), function(y, x) { if (class(x) == "Column") { @@ -1632,13 +2068,18 @@ setMethod("months_between", signature(y = "Column"), #' nanvl #' #' Returns col1 if it is not NaN, or col2 if col1 is NaN. -#' hhBoth inputs should be floating point columns (DoubleType or FloatType). +#' Both inputs should be floating point columns (DoubleType or FloatType). +#' +#' @param x first Column. +#' @param y second Column. #' #' @rdname nanvl #' @name nanvl #' @family normal_funcs +#' @aliases nanvl,Column-method #' @export #' @examples \dontrun{nanvl(df$c, x)} +#' @note nanvl since 1.5.0 setMethod("nanvl", signature(y = "Column"), function(y, x) { if (class(x) == "Column") { @@ -1652,12 +2093,17 @@ setMethod("nanvl", signature(y = "Column"), #' #' Returns the positive value of dividend mod divisor. #' +#' @param x divisor Column. +#' @param y dividend Column. +#' #' @rdname pmod #' @name pmod #' @docType methods #' @family math_funcs +#' @aliases pmod,Column-method #' @export #' @examples \dontrun{pmod(df$c, x)} +#' @note pmod since 1.5.0 setMethod("pmod", signature(y = "Column"), function(y, x) { if (class(x) == "Column") { @@ -1668,14 +2114,17 @@ setMethod("pmod", signature(y = "Column"), }) -#' Approx Count Distinct -#' -#' @family agg_funcs #' @rdname approxCountDistinct #' @name approxCountDistinct -#' @return the approximate number of distinct items in a group. +#' +#' @param x Column to compute on. +#' @param rsd maximum estimation error allowed (default = 0.05) +#' @param ... further arguments to be passed to or from other methods. +#' +#' @aliases approxCountDistinct,Column-method #' @export #' @examples \dontrun{approxCountDistinct(df$c, 0.02)} +#' @note approxCountDistinct(Column, numeric) since 1.4.0 setMethod("approxCountDistinct", signature(x = "Column"), function(x, rsd = 0.05) { @@ -1683,14 +2132,19 @@ setMethod("approxCountDistinct", column(jc) }) -#' Count Distinct +#' Count Distinct Values +#' +#' @param x Column to compute on +#' @param ... 
other columns #' #' @family agg_funcs #' @rdname countDistinct #' @name countDistinct +#' @aliases countDistinct,Column-method #' @return the number of distinct items in a group. #' @export #' @examples \dontrun{countDistinct(df$c)} +#' @note countDistinct since 1.4.0 setMethod("countDistinct", signature(x = "Column"), function(x, ...) { @@ -1708,11 +2162,16 @@ setMethod("countDistinct", #' #' Concatenates multiple input string columns together into a single string column. #' +#' @param x Column to compute on +#' @param ... other columns +#' #' @family string_funcs #' @rdname concat #' @name concat +#' @aliases concat,Column-method #' @export #' @examples \dontrun{concat(df$strings, df$strings2)} +#' @note concat since 1.5.0 setMethod("concat", signature(x = "Column"), function(x, ...) { @@ -1729,11 +2188,16 @@ setMethod("concat", #' Returns the greatest value of the list of column names, skipping null values. #' This function takes at least 2 parameters. It will return null if all parameters are null. #' +#' @param x Column to compute on +#' @param ... other columns +#' #' @family normal_funcs #' @rdname greatest #' @name greatest +#' @aliases greatest,Column-method #' @export #' @examples \dontrun{greatest(df$c, df$d)} +#' @note greatest since 1.5.0 setMethod("greatest", signature(x = "Column"), function(x, ...) { @@ -1751,11 +2215,16 @@ setMethod("greatest", #' Returns the least value of the list of column names, skipping null values. #' This function takes at least 2 parameters. It will return null if all parameters are null. #' +#' @param x Column to compute on +#' @param ... other columns +#' #' @family normal_funcs #' @rdname least +#' @aliases least,Column-method #' @name least #' @export #' @examples \dontrun{least(df$c, df$d)} +#' @note least since 1.5.0 setMethod("least", signature(x = "Column"), function(x, ...) { @@ -1768,28 +2237,26 @@ setMethod("least", column(jc) }) -#' ceiling -#' -#' Computes the ceiling of the given value. -#' #' @rdname ceil +#' #' @name ceiling +#' @aliases ceiling,Column-method #' @export #' @examples \dontrun{ceiling(df$c)} +#' @note ceiling since 1.5.0 setMethod("ceiling", signature(x = "Column"), function(x) { ceil(x) }) -#' sign +#' @rdname sign #' -#' Computes the signum of the given value. -#' -#' @rdname signum #' @name sign +#' @aliases sign,Column-method #' @export #' @examples \dontrun{sign(df$c)} +#' @note sign since 1.5.0 setMethod("sign", signature(x = "Column"), function(x) { signum(x) @@ -1801,21 +2268,21 @@ setMethod("sign", signature(x = "Column"), #' #' @rdname countDistinct #' @name n_distinct +#' @aliases n_distinct,Column-method #' @export #' @examples \dontrun{n_distinct(df$c)} +#' @note n_distinct since 1.4.0 setMethod("n_distinct", signature(x = "Column"), function(x, ...) { countDistinct(x, ...) }) -#' n -#' -#' Aggregate function: returns the number of items in a group. -#' #' @rdname count #' @name n +#' @aliases n,Column-method #' @export #' @examples \dontrun{n(df$c)} +#' @note n since 1.4.0 setMethod("n", signature(x = "Column"), function(x) { count(x) @@ -1832,11 +2299,16 @@ setMethod("n", signature(x = "Column"), #' NOTE: Use when ever possible specialized functions like \code{year}. These benefit from a #' specialized implementation. #' +#' @param y Column to compute on. +#' @param x date format specification. 
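A hedged sketch of the two-column string helper, the approximate and exact distinct counts, and the multi-column helpers documented above; the small frames and their columns are illustrative assumptions.

```R
# Hedged sketch of levenshtein, approxCountDistinct(rsd), countDistinct/n_distinct,
# greatest/least and concat; all data below is illustrative.
dfS <- createDataFrame(data.frame(a = c("kitten", "flaw"), b = c("sitting", "lawn"),
                                  stringsAsFactors = FALSE))
head(select(dfS, levenshtein(dfS$a, dfS$b)))

dfM <- createDataFrame(data.frame(x = c(1, 5, NA), y = c(4, 2, 7),
                                  s = c("u", "v", "u"), stringsAsFactors = FALSE))
head(select(dfM, greatest(dfM$x, dfM$y), least(dfM$x, dfM$y), concat(dfM$s, dfM$s)))
head(agg(dfM, approxCountDistinct(dfM$s, 0.02), countDistinct(dfM$s), n_distinct(dfM$s)))
```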
+#' #' @family datetime_funcs #' @rdname date_format #' @name date_format +#' @aliases date_format,Column,character-method #' @export #' @examples \dontrun{date_format(df$t, 'MM/dd/yyy')} +#' @note date_format since 1.5.0 setMethod("date_format", signature(y = "Column", x = "character"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", "date_format", y@jc, x) @@ -1847,11 +2319,16 @@ setMethod("date_format", signature(y = "Column", x = "character"), #' #' Assumes given timestamp is UTC and converts to given timezone. #' +#' @param y Column to compute on. +#' @param x time zone to use. +#' #' @family datetime_funcs #' @rdname from_utc_timestamp #' @name from_utc_timestamp +#' @aliases from_utc_timestamp,Column,character-method #' @export #' @examples \dontrun{from_utc_timestamp(df$t, 'PST')} +#' @note from_utc_timestamp since 1.5.0 setMethod("from_utc_timestamp", signature(y = "Column", x = "character"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", "from_utc_timestamp", y@jc, x) @@ -1866,11 +2343,15 @@ setMethod("from_utc_timestamp", signature(y = "Column", x = "character"), #' NOTE: The position is not zero based, but 1 based index, returns 0 if substr #' could not be found in str. #' +#' @param y column to check +#' @param x substring to check #' @family string_funcs +#' @aliases instr,Column,character-method #' @rdname instr #' @name instr #' @export #' @examples \dontrun{instr(df$c, 'b')} +#' @note instr since 1.5.0 setMethod("instr", signature(y = "Column", x = "character"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", "instr", y@jc, x) @@ -1888,15 +2369,20 @@ setMethod("instr", signature(y = "Column", x = "character"), #' Day of the week parameter is case insensitive, and accepts first three or two characters: #' "Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun". #' +#' @param y Column to compute on. +#' @param x Day of the week string. +#' #' @family datetime_funcs #' @rdname next_day #' @name next_day +#' @aliases next_day,Column,character-method #' @export #' @examples #'\dontrun{ #'next_day(df$d, 'Sun') #'next_day(df$d, 'Sunday') #'} +#' @note next_day since 1.5.0 setMethod("next_day", signature(y = "Column", x = "character"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", "next_day", y@jc, x) @@ -1907,11 +2393,16 @@ setMethod("next_day", signature(y = "Column", x = "character"), #' #' Assumes given timestamp is in given timezone and converts to UTC. #' +#' @param y Column to compute on +#' @param x timezone to use +#' #' @family datetime_funcs #' @rdname to_utc_timestamp #' @name to_utc_timestamp +#' @aliases to_utc_timestamp,Column,character-method #' @export #' @examples \dontrun{to_utc_timestamp(df$t, 'PST')} +#' @note to_utc_timestamp since 1.5.0 setMethod("to_utc_timestamp", signature(y = "Column", x = "character"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", "to_utc_timestamp", y@jc, x) @@ -1922,11 +2413,16 @@ setMethod("to_utc_timestamp", signature(y = "Column", x = "character"), #' #' Returns the date that is numMonths after startDate. 
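A hedged sketch of the date/time and position helpers documented above; the timestamp column `t`, the format strings, and the time zone are illustrative assumptions.

```R
# Hedged sketch of date_format/from_utc_timestamp/next_day/instr.
dfT <- createDataFrame(data.frame(t = as.POSIXct("2015-07-27 10:30:00", tz = "UTC")))
head(select(dfT,
            date_format(dfT$t, "MM/dd/yyyy"),
            from_utc_timestamp(dfT$t, "PST"),
            next_day(dfT$t, "Sun"),
            instr(date_format(dfT$t, "yyyy-MM-dd"), "-")))
```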
#' +#' @param y Column to compute on +#' @param x Number of months to add +#' #' @name add_months #' @family datetime_funcs #' @rdname add_months +#' @aliases add_months,Column,numeric-method #' @export #' @examples \dontrun{add_months(df$d, 1)} +#' @note add_months since 1.5.0 setMethod("add_months", signature(y = "Column", x = "numeric"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", "add_months", y@jc, as.integer(x)) @@ -1935,13 +2431,18 @@ setMethod("add_months", signature(y = "Column", x = "numeric"), #' date_add #' -#' Returns the date that is `days` days after `start` +#' Returns the date that is \code{x} days after \code{y}. +#' +#' @param y Column to compute on +#' @param x Number of days to add #' #' @family datetime_funcs #' @rdname date_add #' @name date_add +#' @aliases date_add,Column,numeric-method #' @export #' @examples \dontrun{date_add(df$d, 1)} +#' @note date_add since 1.5.0 setMethod("date_add", signature(y = "Column", x = "numeric"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", "date_add", y@jc, as.integer(x)) @@ -1950,13 +2451,18 @@ setMethod("date_add", signature(y = "Column", x = "numeric"), #' date_sub #' -#' Returns the date that is `days` days before `start` +#' Returns the date that is \code{x} days before \code{y}. +#' +#' @param y Column to compute on +#' @param x Number of days to subtract #' #' @family datetime_funcs #' @rdname date_sub #' @name date_sub +#' @aliases date_sub,Column,numeric-method #' @export #' @examples \dontrun{date_sub(df$d, 1)} +#' @note date_sub since 1.5.0 setMethod("date_sub", signature(y = "Column", x = "numeric"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", "date_sub", y@jc, as.integer(x)) @@ -1976,8 +2482,10 @@ setMethod("date_sub", signature(y = "Column", x = "numeric"), #' @family string_funcs #' @rdname format_number #' @name format_number +#' @aliases format_number,Column,numeric-method #' @export #' @examples \dontrun{format_number(df$n, 4)} +#' @note format_number since 1.5.0 setMethod("format_number", signature(y = "Column", x = "numeric"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", @@ -1996,8 +2504,10 @@ setMethod("format_number", signature(y = "Column", x = "numeric"), #' @family misc_funcs #' @rdname sha2 #' @name sha2 +#' @aliases sha2,Column,numeric-method #' @export #' @examples \dontrun{sha2(df$c, 256)} +#' @note sha2 since 1.5.0 setMethod("sha2", signature(y = "Column", x = "numeric"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", "sha2", y@jc, as.integer(x)) @@ -2009,11 +2519,16 @@ setMethod("sha2", signature(y = "Column", x = "numeric"), #' Shift the given value numBits left. If the given value is a long value, this function #' will return a long value else it will return an integer value. #' +#' @param y column to compute on. +#' @param x number of bits to shift. +#' #' @family math_funcs #' @rdname shiftLeft #' @name shiftLeft +#' @aliases shiftLeft,Column,numeric-method #' @export #' @examples \dontrun{shiftLeft(df$c, 1)} +#' @note shiftLeft since 1.5.0 setMethod("shiftLeft", signature(y = "Column", x = "numeric"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", @@ -2027,11 +2542,16 @@ setMethod("shiftLeft", signature(y = "Column", x = "numeric"), #' #' Shift the given value numBits right. If the given value is a long value, it will return #' a long value else it will return an integer value. #' +#' @param y column to compute on. +#' @param x number of bits to shift.
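The date arithmetic functions just documented (`add_months`, `date_add`, `date_sub`) can be exercised with a one-column date frame. A minimal sketch, assuming a running session; the column name `d` and the date are illustrative.

```R
library(SparkR)
sparkR.session()

# Single illustrative date column
df <- createDataFrame(data.frame(d = as.Date("2016-07-01")))

head(select(df,
            add_months(df$d, 1),   # 2016-08-01
            date_add(df$d, 10),    # 2016-07-11
            date_sub(df$d, 10)))   # 2016-06-21
```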
+#' #' @family math_funcs #' @rdname shiftRight #' @name shiftRight +#' @aliases shiftRight,Column,numeric-method #' @export #' @examples \dontrun{shiftRight(df$c, 1)} +#' @note shiftRight since 1.5.0 setMethod("shiftRight", signature(y = "Column", x = "numeric"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", @@ -2045,11 +2565,16 @@ setMethod("shiftRight", signature(y = "Column", x = "numeric"), #' Unsigned shift the given value numBits right. If the given value is a long value, #' it will return a long value else it will return an integer value. #' +#' @param y column to compute on. +#' @param x number of bits to shift. +#' #' @family math_funcs #' @rdname shiftRightUnsigned #' @name shiftRightUnsigned +#' @aliases shiftRightUnsigned,Column,numeric-method #' @export #' @examples \dontrun{shiftRightUnsigned(df$c, 1)} +#' @note shiftRightUnsigned since 1.5.0 setMethod("shiftRightUnsigned", signature(y = "Column", x = "numeric"), function(y, x) { jc <- callJStatic("org.apache.spark.sql.functions", @@ -2063,11 +2588,17 @@ setMethod("shiftRightUnsigned", signature(y = "Column", x = "numeric"), #' Concatenates multiple input string columns together into a single string column, #' using the given separator. #' +#' @param x column to concatenate. +#' @param sep separator to use. +#' @param ... other columns to concatenate. +#' #' @family string_funcs #' @rdname concat_ws #' @name concat_ws +#' @aliases concat_ws,character,Column-method #' @export #' @examples \dontrun{concat_ws('-', df$s, df$d)} +#' @note concat_ws since 1.5.0 setMethod("concat_ws", signature(sep = "character", x = "Column"), function(sep, x, ...) { jcols <- lapply(list(x, ...), function(x) { x@jc }) @@ -2079,11 +2610,17 @@ setMethod("concat_ws", signature(sep = "character", x = "Column"), #' #' Convert a number in a string column from one base to another. #' +#' @param x column to convert. +#' @param fromBase base to convert from. +#' @param toBase base to convert to. +#' #' @family math_funcs #' @rdname conv +#' @aliases conv,Column,numeric,numeric-method #' @name conv #' @export #' @examples \dontrun{conv(df$n, 2, 16)} +#' @note conv since 1.5.0 setMethod("conv", signature(x = "Column", fromBase = "numeric", toBase = "numeric"), function(x, fromBase, toBase) { fromBase <- as.integer(fromBase) @@ -2099,11 +2636,14 @@ setMethod("conv", signature(x = "Column", fromBase = "numeric", toBase = "numeri #' Parses the expression string into the column that it represents, similar to #' SparkDataFrame.selectExpr #' +#' @param x an expression character object to be parsed. #' @family normal_funcs #' @rdname expr +#' @aliases expr,character-method #' @name expr #' @export #' @examples \dontrun{expr('length(name)')} +#' @note expr since 1.5.0 setMethod("expr", signature(x = "character"), function(x) { jc <- callJStatic("org.apache.spark.sql.functions", "expr", x) @@ -2114,11 +2654,16 @@ setMethod("expr", signature(x = "character"), #' #' Formats the arguments in printf-style and returns the result as a string column. #' +#' @param format a character object of format strings. +#' @param x a Column. +#' @param ... additional Column(s). #' @family string_funcs #' @rdname format_string #' @name format_string +#' @aliases format_string,character,Column-method #' @export #' @examples \dontrun{format_string('%d %s', df$a, df$b)} +#' @note format_string since 1.5.0 setMethod("format_string", signature(format = "character", x = "Column"), function(format, x, ...) 
{ jcols <- lapply(list(x, ...), function(arg) { arg@jc }) @@ -2134,15 +2679,22 @@ setMethod("format_string", signature(format = "character", x = "Column"), #' representing the timestamp of that moment in the current system time zone in the given #' format. #' +#' @param x a Column of unix timestamp. +#' @param format the target format. See +#' \href{http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html}{ +#' Customizing Formats} for available options. +#' @param ... further arguments to be passed to or from other methods. #' @family datetime_funcs #' @rdname from_unixtime #' @name from_unixtime +#' @aliases from_unixtime,Column-method #' @export #' @examples #'\dontrun{ #'from_unixtime(df$t) #'from_unixtime(df$t, 'yyyy/MM/dd HH') #'} +#' @note from_unixtime since 1.5.0 setMethod("from_unixtime", signature(x = "Column"), function(x, format = "yyyy-MM-dd HH:mm:ss") { jc <- callJStatic("org.apache.spark.sql.functions", @@ -2158,22 +2710,29 @@ setMethod("from_unixtime", signature(x = "Column"), #' [12:05,12:10) but not in [12:00,12:05). Windows can support microsecond precision. Windows in #' the order of months are not supported. #' -#' The time column must be of TimestampType. -#' -#' Durations are provided as strings, e.g. '1 second', '1 day 12 hours', '2 minutes'. Valid -#' interval strings are 'week', 'day', 'hour', 'minute', 'second', 'millisecond', 'microsecond'. -#' If the `slideDuration` is not provided, the windows will be tumbling windows. -#' -#' The startTime is the offset with respect to 1970-01-01 00:00:00 UTC with which to start -#' window intervals. For example, in order to have hourly tumbling windows that start 15 minutes -#' past the hour, e.g. 12:15-13:15, 13:15-14:15... provide `startTime` as `15 minutes`. -#' -#' The output column will be a struct called 'window' by default with the nested columns 'start' -#' and 'end'. -#' +#' @param x a time Column. Must be of TimestampType. +#' @param windowDuration a string specifying the width of the window, e.g. '1 second', +#' '1 day 12 hours', '2 minutes'. Valid interval strings are 'week', +#' 'day', 'hour', 'minute', 'second', 'millisecond', 'microsecond'. Note that +#' the duration is a fixed length of time, and does not vary over time +#' according to a calendar. For example, '1 day' always means 86,400,000 +#' milliseconds, not a calendar day. +#' @param slideDuration a string specifying the sliding interval of the window. Same format as +#' \code{windowDuration}. A new window will be generated every +#' \code{slideDuration}. Must be less than or equal to +#' the \code{windowDuration}. This duration is likewise absolute, and does not +#' vary according to a calendar. +#' @param startTime the offset with respect to 1970-01-01 00:00:00 UTC with which to start +#' window intervals. For example, in order to have hourly tumbling windows +#' that start 15 minutes past the hour, e.g. 12:15-13:15, 13:15-14:15... provide +#' \code{startTime} as \code{"15 minutes"}. +#' @param ... further arguments to be passed to or from other methods. +#' @return An output column of struct called 'window' by default with the nested columns 'start' +#' and 'end'. #' @family datetime_funcs #' @rdname window #' @name window +#' @aliases window,Column-method #' @export #' @examples #'\dontrun{ @@ -2185,9 +2744,10 @@ setMethod("from_unixtime", signature(x = "Column"), #' # 09:01:15-09:02:15... #' window(df$time, "1 minute", startTime = "15 seconds") #' -#' # Thirty second windows every 10 seconds, e.g. 
09:00:00-09:00:30, 09:00:10-09:00:40, ... +#' # Thirty-second windows every 10 seconds, e.g. 09:00:00-09:00:30, 09:00:10-09:00:40, ... #' window(df$time, "30 seconds", "10 seconds") #'} +#' @note window since 2.0.0 setMethod("window", signature(x = "Column"), function(x, windowDuration, slideDuration = NULL, startTime = NULL) { stopifnot(is.character(windowDuration)) @@ -2220,13 +2780,19 @@ setMethod("window", signature(x = "Column"), #' NOTE: The position is not zero based, but 1 based index, returns 0 if substr #' could not be found in str. #' +#' @param substr a character string to be matched. +#' @param str a Column where matches are sought for each entry. +#' @param pos start position of search. +#' @param ... further arguments to be passed to or from other methods. #' @family string_funcs #' @rdname locate +#' @aliases locate,character,Column-method #' @name locate #' @export #' @examples \dontrun{locate('b', df$c, 1)} +#' @note locate since 1.5.0 setMethod("locate", signature(substr = "character", str = "Column"), - function(substr, str, pos = 0) { + function(substr, str, pos = 1) { jc <- callJStatic("org.apache.spark.sql.functions", "locate", substr, str@jc, as.integer(pos)) @@ -2237,11 +2803,16 @@ setMethod("locate", signature(substr = "character", str = "Column"), #' #' Left-pad the string column with #' +#' @param x the string Column to be left-padded. +#' @param len maximum length of each output result. +#' @param pad a character string to be padded with. #' @family string_funcs #' @rdname lpad +#' @aliases lpad,Column,numeric,character-method #' @name lpad #' @export #' @examples \dontrun{lpad(df$c, 6, '#')} +#' @note lpad since 1.5.0 setMethod("lpad", signature(x = "Column", len = "numeric", pad = "character"), function(x, len, pad) { jc <- callJStatic("org.apache.spark.sql.functions", @@ -2254,11 +2825,14 @@ setMethod("lpad", signature(x = "Column", len = "numeric", pad = "character"), #' #' Generate a random column with i.i.d. samples from U[0.0, 1.0]. #' +#' @param seed a random seed. Can be missing. #' @family normal_funcs #' @rdname rand #' @name rand +#' @aliases rand,missing-method #' @export #' @examples \dontrun{rand()} +#' @note rand since 1.5.0 setMethod("rand", signature(seed = "missing"), function(seed) { jc <- callJStatic("org.apache.spark.sql.functions", "rand") @@ -2267,7 +2841,9 @@ setMethod("rand", signature(seed = "missing"), #' @rdname rand #' @name rand +#' @aliases rand,numeric-method #' @export +#' @note rand(numeric) since 1.5.0 setMethod("rand", signature(seed = "numeric"), function(seed) { jc <- callJStatic("org.apache.spark.sql.functions", "rand", as.integer(seed)) @@ -2278,11 +2854,14 @@ setMethod("rand", signature(seed = "numeric"), #' #' Generate a column with i.i.d. samples from the standard normal distribution. #' +#' @param seed a random seed. Can be missing. 
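Because the `window()` documentation above is dense, a small batch example may help: group rows into one-minute tumbling windows and aggregate within each window. This is only a sketch under stated assumptions (a running session; the event times and the `value` column are made up).

```R
library(SparkR)
sparkR.session()

# Illustrative event times and values
df <- createDataFrame(data.frame(
  time  = as.POSIXct(c("2016-01-01 09:00:10",
                       "2016-01-01 09:00:45",
                       "2016-01-01 09:01:20"), tz = "UTC"),
  value = c(1, 2, 3)))

# One-minute tumbling windows; the result carries a struct column 'window'
# with nested 'start' and 'end' fields
byWindow <- groupBy(df, window(df$time, "1 minute"))
head(agg(byWindow, total = sum(df$value)))
```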
#' @family normal_funcs #' @rdname randn #' @name randn +#' @aliases randn,missing-method #' @export #' @examples \dontrun{randn()} +#' @note randn since 1.5.0 setMethod("randn", signature(seed = "missing"), function(seed) { jc <- callJStatic("org.apache.spark.sql.functions", "randn") @@ -2291,7 +2870,9 @@ setMethod("randn", signature(seed = "missing"), #' @rdname randn #' @name randn +#' @aliases randn,numeric-method #' @export +#' @note randn(numeric) since 1.5.0 setMethod("randn", signature(seed = "numeric"), function(seed) { jc <- callJStatic("org.apache.spark.sql.functions", "randn", as.integer(seed)) @@ -2300,13 +2881,19 @@ setMethod("randn", signature(seed = "numeric"), #' regexp_extract #' -#' Extract a specific(idx) group identified by a java regex, from the specified string column. +#' Extract a specific \code{idx} group identified by a Java regex, from the specified string column. +#' If the regex did not match, or the specified group did not match, an empty string is returned. #' +#' @param x a string Column. +#' @param pattern a regular expression. +#' @param idx a group index. #' @family string_funcs #' @rdname regexp_extract #' @name regexp_extract +#' @aliases regexp_extract,Column,character,numeric-method #' @export #' @examples \dontrun{regexp_extract(df$c, '(\d+)-(\d+)', 1)} +#' @note regexp_extract since 1.5.0 setMethod("regexp_extract", signature(x = "Column", pattern = "character", idx = "numeric"), function(x, pattern, idx) { @@ -2320,11 +2907,16 @@ setMethod("regexp_extract", #' #' Replace all substrings of the specified string value that match regexp with rep. #' +#' @param x a string Column. +#' @param pattern a regular expression. +#' @param replacement a character string that a matched \code{pattern} is replaced with. #' @family string_funcs #' @rdname regexp_replace #' @name regexp_replace +#' @aliases regexp_replace,Column,character,character-method #' @export #' @examples \dontrun{regexp_replace(df$c, '(\\d+)', '--')} +#' @note regexp_replace since 1.5.0 setMethod("regexp_replace", signature(x = "Column", pattern = "character", replacement = "character"), function(x, pattern, replacement) { @@ -2338,11 +2930,16 @@ setMethod("regexp_replace", #' #' Right-padded with pad to a length of len. #' +#' @param x the string Column to be right-padded. +#' @param len maximum length of each output result. +#' @param pad a character string to be padded with. #' @family string_funcs #' @rdname rpad #' @name rpad +#' @aliases rpad,Column,numeric,character-method #' @export #' @examples \dontrun{rpad(df$c, 6, '#')} +#' @note rpad since 1.5.0 setMethod("rpad", signature(x = "Column", len = "numeric", pad = "character"), function(x, len, pad) { jc <- callJStatic("org.apache.spark.sql.functions", @@ -2358,8 +2955,14 @@ setMethod("rpad", signature(x = "Column", len = "numeric", pad = "character"), #' returned. If count is negative, every to the right of the final delimiter (counting from the #' right) is returned. substring_index performs a case-sensitive match when searching for delim. #' +#' @param x a Column. +#' @param delim a delimiter string. +#' @param count number of occurrences of \code{delim} before the substring is returned. +#' A positive number means counting from the left, while negative means +#' counting from the right. 
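A quick sketch of `regexp_extract` and `regexp_replace` on a toy string column. Note the doubled backslashes required inside R string literals; the column name `s` and the values are illustrative, and a running session is assumed.

```R
library(SparkR)
sparkR.session()

df <- createDataFrame(data.frame(s = c("100-200", "300-400"), stringsAsFactors = FALSE))

head(select(df,
            regexp_extract(df$s, "(\\d+)-(\\d+)", 1),  # first capture group: "100", "300"
            regexp_replace(df$s, "(\\d+)", "#")))      # each digit run replaced by "#"
```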
#' @family string_funcs #' @rdname substring_index +#' @aliases substring_index,Column,character,numeric-method #' @name substring_index #' @export #' @examples @@ -2367,6 +2970,7 @@ setMethod("rpad", signature(x = "Column", len = "numeric", pad = "character"), #'substring_index(df$c, '.', 2) #'substring_index(df$c, '.', -1) #'} +#' @note substring_index since 1.5.0 setMethod("substring_index", signature(x = "Column", delim = "character", count = "numeric"), function(x, delim, count) { @@ -2383,11 +2987,18 @@ setMethod("substring_index", #' The translate will happen when any character in the string matching with the character #' in the matchingString. #' +#' @param x a string Column. +#' @param matchingString a source string where each character will be translated. +#' @param replaceString a target string where each \code{matchingString} character will +#' be replaced by the character in \code{replaceString} +#' at the same location, if any. #' @family string_funcs #' @rdname translate #' @name translate +#' @aliases translate,Column,character,character-method #' @export #' @examples \dontrun{translate(df$c, 'rnlt', '123')} +#' @note translate since 1.5.0 setMethod("translate", signature(x = "Column", matchingString = "character", replaceString = "character"), function(x, matchingString, replaceString) { @@ -2403,6 +3014,7 @@ setMethod("translate", #' @family datetime_funcs #' @rdname unix_timestamp #' @name unix_timestamp +#' @aliases unix_timestamp,missing,missing-method #' @export #' @examples #'\dontrun{ @@ -2410,6 +3022,7 @@ setMethod("translate", #'unix_timestamp(df$t) #'unix_timestamp(df$t, 'yyyy-MM-dd HH') #'} +#' @note unix_timestamp since 1.5.0 setMethod("unix_timestamp", signature(x = "missing", format = "missing"), function(x, format) { jc <- callJStatic("org.apache.spark.sql.functions", "unix_timestamp") @@ -2418,16 +3031,24 @@ setMethod("unix_timestamp", signature(x = "missing", format = "missing"), #' @rdname unix_timestamp #' @name unix_timestamp +#' @aliases unix_timestamp,Column,missing-method #' @export +#' @note unix_timestamp(Column) since 1.5.0 setMethod("unix_timestamp", signature(x = "Column", format = "missing"), function(x, format) { jc <- callJStatic("org.apache.spark.sql.functions", "unix_timestamp", x@jc) column(jc) }) +#' @param x a Column of date, in string, date or timestamp type. +#' @param format the target format. See +#' \href{http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html}{ +#' Customizing Formats} for available options. #' @rdname unix_timestamp #' @name unix_timestamp +#' @aliases unix_timestamp,Column,character-method #' @export +#' @note unix_timestamp(Column, character) since 1.5.0 setMethod("unix_timestamp", signature(x = "Column", format = "character"), function(x, format = "yyyy-MM-dd HH:mm:ss") { jc <- callJStatic("org.apache.spark.sql.functions", "unix_timestamp", x@jc, format) @@ -2438,12 +3059,16 @@ setMethod("unix_timestamp", signature(x = "Column", format = "character"), #' Evaluates a list of conditions and returns one of multiple possible result expressions. #' For unmatched expressions null is returned. #' +#' @param condition the condition to test on. Must be a Column expression. +#' @param value result expression. 
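A minimal round-trip sketch for the `unix_timestamp` variants documented below together with `from_unixtime`, assuming a running session; the string column `t` is illustrative.

```R
library(SparkR)
sparkR.session()

df <- createDataFrame(data.frame(t = "2016-07-01 10:30:00", stringsAsFactors = FALSE))

head(select(df,
            unix_timestamp(df$t),                               # seconds since the epoch
            from_unixtime(unix_timestamp(df$t), "yyyy/MM/dd"))) # back to a formatted string
```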
#' @family normal_funcs #' @rdname when #' @name when +#' @aliases when,Column-method #' @seealso \link{ifelse} #' @export #' @examples \dontrun{when(df$age == 2, df$age + 1)} +#' @note when since 1.5.0 setMethod("when", signature(condition = "Column", value = "ANY"), function(condition, value) { condition <- condition@jc @@ -2457,15 +3082,20 @@ setMethod("when", signature(condition = "Column", value = "ANY"), #' Evaluates a list of conditions and returns \code{yes} if the conditions are satisfied. #' Otherwise \code{no} is returned for unmatched conditions. #' +#' @param test a Column expression that describes the condition. +#' @param yes return values for \code{TRUE} elements of test. +#' @param no return values for \code{FALSE} elements of test. #' @family normal_funcs #' @rdname ifelse #' @name ifelse +#' @aliases ifelse,Column-method #' @seealso \link{when} #' @export #' @examples \dontrun{ #' ifelse(df$a > 1 & df$b > 2, 0, 1) #' ifelse(df$a > 1, df$a, 1) #' } +#' @note ifelse since 1.5.0 setMethod("ifelse", signature(test = "Column", yes = "ANY", no = "ANY"), function(test, yes, no) { @@ -2489,15 +3119,21 @@ setMethod("ifelse", #' N = total number of rows in the partition #' cume_dist(x) = number of values before (and including) x / N #' -#' This is equivalent to the CUME_DIST function in SQL. +#' This is equivalent to the \code{CUME_DIST} function in SQL. #' #' @rdname cume_dist #' @name cume_dist #' @family window_funcs +#' @aliases cume_dist,missing-method #' @export -#' @examples \dontrun{cume_dist()} +#' @examples \dontrun{ +#' df <- createDataFrame(mtcars) +#' ws <- orderBy(windowPartitionBy("am"), "hp") +#' out <- select(df, over(cume_dist(), ws), df$hp, df$am) +#' } +#' @note cume_dist since 1.6.0 setMethod("cume_dist", - signature(x = "missing"), + signature("missing"), function() { jc <- callJStatic("org.apache.spark.sql.functions", "cume_dist") column(jc) @@ -2511,15 +3147,21 @@ setMethod("cume_dist", #' and had three people tie for second place, you would say that all three were in second #' place and that the next person came in third. #' -#' This is equivalent to the DENSE_RANK function in SQL. +#' This is equivalent to the \code{DENSE_RANK} function in SQL. #' #' @rdname dense_rank #' @name dense_rank #' @family window_funcs +#' @aliases dense_rank,missing-method #' @export -#' @examples \dontrun{dense_rank()} +#' @examples \dontrun{ +#' df <- createDataFrame(mtcars) +#' ws <- orderBy(windowPartitionBy("am"), "hp") +#' out <- select(df, over(dense_rank(), ws), df$hp, df$am) +#' } +#' @note dense_rank since 1.6.0 setMethod("dense_rank", - signature(x = "missing"), + signature("missing"), function() { jc <- callJStatic("org.apache.spark.sql.functions", "dense_rank") column(jc) @@ -2527,20 +3169,35 @@ setMethod("dense_rank", #' lag #' -#' Window function: returns the value that is `offset` rows before the current row, and -#' `defaultValue` if there is less than `offset` rows before the current row. For example, -#' an `offset` of one will return the previous row at any given point in the window partition. +#' Window function: returns the value that is \code{offset} rows before the current row, and +#' \code{defaultValue} if there is less than \code{offset} rows before the current row. For example, +#' an \code{offset} of one will return the previous row at any given point in the window partition. #' -#' This is equivalent to the LAG function in SQL. +#' This is equivalent to the \code{LAG} function in SQL. 
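The conditional helpers `when` and `ifelse` described above can be sketched as follows, assuming a running session and an illustrative `age` column.

```R
library(SparkR)
sparkR.session()

df <- createDataFrame(data.frame(age = c(1, 2, 5)))

head(select(df,
            when(df$age == 2, df$age + 1),          # NULL where the condition does not hold
            ifelse(df$age > 1, "older", "young")))  # two-sided variant
```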
#' +#' @param x the column as a character string or a Column to compute on. +#' @param offset the number of rows back from the current row from which to obtain a value. +#' If not specified, the default is 1. +#' @param defaultValue (optional) default to use when the offset row does not exist. +#' @param ... further arguments to be passed to or from other methods. #' @rdname lag #' @name lag +#' @aliases lag,characterOrColumn-method #' @family window_funcs #' @export -#' @examples \dontrun{lag(df$c)} +#' @examples \dontrun{ +#' df <- createDataFrame(mtcars) +#' +#' # Partition by am (transmission) and order by hp (horsepower) +#' ws <- orderBy(windowPartitionBy("am"), "hp") +#' +#' # Lag mpg values by 1 row on the partition-and-ordered table +#' out <- select(df, over(lag(df$mpg), ws), df$mpg, df$hp, df$am) +#' } +#' @note lag since 1.6.0 setMethod("lag", signature(x = "characterOrColumn"), - function(x, offset, defaultValue = NULL) { + function(x, offset = 1, defaultValue = NULL) { col <- if (class(x) == "Column") { x@jc } else { @@ -2554,20 +3211,36 @@ setMethod("lag", #' lead #' -#' Window function: returns the value that is `offset` rows after the current row, and -#' `null` if there is less than `offset` rows after the current row. For example, -#' an `offset` of one will return the next row at any given point in the window partition. +#' Window function: returns the value that is \code{offset} rows after the current row, and +#' \code{defaultValue} if there is less than \code{offset} rows after the current row. +#' For example, an \code{offset} of one will return the next row at any given point +#' in the window partition. +#' +#' This is equivalent to the \code{LEAD} function in SQL. #' -#' This is equivalent to the LEAD function in SQL. +#' @param x the column as a character string or a Column to compute on. +#' @param offset the number of rows after the current row from which to obtain a value. +#' If not specified, the default is 1. +#' @param defaultValue (optional) default to use when the offset row does not exist. #' #' @rdname lead #' @name lead #' @family window_funcs +#' @aliases lead,characterOrColumn,numeric-method #' @export -#' @examples \dontrun{lead(df$c)} +#' @examples \dontrun{ +#' df <- createDataFrame(mtcars) +#' +#' # Partition by am (transmission) and order by hp (horsepower) +#' ws <- orderBy(windowPartitionBy("am"), "hp") +#' +#' # Lead mpg values by 1 row on the partition-and-ordered table +#' out <- select(df, over(lead(df$mpg), ws), df$mpg, df$hp, df$am) +#' } +#' @note lead since 1.6.0 setMethod("lead", signature(x = "characterOrColumn", offset = "numeric", defaultValue = "ANY"), - function(x, offset, defaultValue = NULL) { + function(x, offset = 1, defaultValue = NULL) { col <- if (class(x) == "Column") { x@jc } else { @@ -2581,17 +3254,29 @@ setMethod("lead", #' ntile #' -#' Window function: returns the ntile group id (from 1 to `n` inclusive) in an ordered window -#' partition. Fow example, if `n` is 4, the first quarter of the rows will get value 1, the second +#' Window function: returns the ntile group id (from 1 to n inclusive) in an ordered window +#' partition. For example, if n is 4, the first quarter of the rows will get value 1, the second #' quarter will get 2, the third quarter will get 3, and the last quarter will get 4. #' -#' This is equivalent to the NTILE function in SQL. +#' This is equivalent to the \code{NTILE} function in SQL. 
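Complementing the mtcars-based roxygen examples, here is one combined sketch of `lag` and `lead` over the same window specification. It assumes a running session and uses the built-in `mtcars` data, mirroring the documentation above rather than prescribing an API of its own.

```R
library(SparkR)
sparkR.session()

df <- createDataFrame(mtcars)
ws <- orderBy(windowPartitionBy("am"), "hp")

# Previous and next mpg value within each transmission group, ordered by hp
head(select(df, df$am, df$hp, df$mpg,
            over(lag(df$mpg, 1), ws),
            over(lead(df$mpg, 1), ws)))
```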
+#' +#' @param x Number of ntile groups #' #' @rdname ntile #' @name ntile +#' @aliases ntile,numeric-method #' @family window_funcs #' @export -#' @examples \dontrun{ntile(1)} +#' @examples \dontrun{ +#' df <- createDataFrame(mtcars) +#' +#' # Partition by am (transmission) and order by hp (horsepower) +#' ws <- orderBy(windowPartitionBy("am"), "hp") +#' +#' # Get ntile group id (1-4) for hp +#' out <- select(df, over(ntile(4), ws), df$hp, df$am) +#' } +#' @note ntile since 1.6.0 setMethod("ntile", signature(x = "numeric"), function(x) { @@ -2612,10 +3297,16 @@ setMethod("ntile", #' @rdname percent_rank #' @name percent_rank #' @family window_funcs +#' @aliases percent_rank,missing-method #' @export -#' @examples \dontrun{percent_rank()} +#' @examples \dontrun{ +#' df <- createDataFrame(mtcars) +#' ws <- orderBy(windowPartitionBy("am"), "hp") +#' out <- select(df, over(percent_rank(), ws), df$hp, df$am) +#' } +#' @note percent_rank since 1.6.0 setMethod("percent_rank", - signature(x = "missing"), + signature("missing"), function() { jc <- callJStatic("org.apache.spark.sql.functions", "percent_rank") column(jc) @@ -2635,8 +3326,14 @@ setMethod("percent_rank", #' @rdname rank #' @name rank #' @family window_funcs +#' @aliases rank,missing-method #' @export -#' @examples \dontrun{rank()} +#' @examples \dontrun{ +#' df <- createDataFrame(mtcars) +#' ws <- orderBy(windowPartitionBy("am"), "hp") +#' out <- select(df, over(rank(), ws), df$hp, df$am) +#' } +#' @note rank since 1.6.0 setMethod("rank", signature(x = "missing"), function() { @@ -2645,6 +3342,12 @@ setMethod("rank", }) # Expose rank() in the R base package +#' @param x a numeric, complex, character or logical vector. +#' @param ... additional argument(s) passed to the method. +#' @name rank +#' @rdname rank +#' @aliases rank,ANY-method +#' @export setMethod("rank", signature(x = "ANY"), function(x, ...) { @@ -2659,11 +3362,17 @@ setMethod("rank", #' #' @rdname row_number #' @name row_number +#' @aliases row_number,missing-method #' @family window_funcs #' @export -#' @examples \dontrun{row_number()} +#' @examples \dontrun{ +#' df <- createDataFrame(mtcars) +#' ws <- orderBy(windowPartitionBy("am"), "hp") +#' out <- select(df, over(row_number(), ws), df$hp, df$am) +#' } +#' @note row_number since 1.6.0 setMethod("row_number", - signature(x = "missing"), + signature("missing"), function() { jc <- callJStatic("org.apache.spark.sql.functions", "row_number") column(jc) @@ -2678,10 +3387,12 @@ setMethod("row_number", #' @param x A Column #' @param value A value to be checked if contained in the column #' @rdname array_contains +#' @aliases array_contains,Column-method #' @name array_contains #' @family collection_funcs #' @export #' @examples \dontrun{array_contains(df$c, 1)} +#' @note array_contains since 1.6.0 setMethod("array_contains", signature(x = "Column", value = "ANY"), function(x, value) { @@ -2693,11 +3404,15 @@ setMethod("array_contains", #' #' Creates a new row for each element in the given array or map column. #' +#' @param x Column to compute on +#' #' @rdname explode #' @name explode #' @family collection_funcs +#' @aliases explode,Column-method #' @export #' @examples \dontrun{explode(df$c)} +#' @note explode since 1.5.0 setMethod("explode", signature(x = "Column"), function(x) { @@ -2709,11 +3424,15 @@ setMethod("explode", #' #' Returns length of array or map. 
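The collection functions documented above (`array_contains`, `explode`, `size`) need an array column to operate on. A sketch under stated assumptions: a running session, an illustrative string column `s`, and an array built via a SQL `split` expression through `selectExpr`.

```R
library(SparkR)
sparkR.session()

df <- createDataFrame(data.frame(s = c("a,b,c", "d,e"), stringsAsFactors = FALSE))

# Build an array column with a SQL expression, then apply the collection functions
arrDF <- selectExpr(df, "split(s, ',') AS items")
head(select(arrDF, size(arrDF$items), array_contains(arrDF$items, "a")))
head(select(arrDF, explode(arrDF$items)))   # one output row per array element
```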
#' +#' @param x Column to compute on +#' #' @rdname size #' @name size +#' @aliases size,Column-method #' @family collection_funcs #' @export #' @examples \dontrun{size(df$c)} +#' @note size since 1.5.0 setMethod("size", signature(x = "Column"), function(x) { @@ -2732,6 +3451,7 @@ setMethod("size", #' FALSE, sorting is in descending order. #' @rdname sort_array #' @name sort_array +#' @aliases sort_array,Column-method #' @family collection_funcs #' @export #' @examples @@ -2739,9 +3459,30 @@ setMethod("size", #' sort_array(df$c) #' sort_array(df$c, FALSE) #' } +#' @note sort_array since 1.6.0 setMethod("sort_array", signature(x = "Column"), function(x, asc = TRUE) { jc <- callJStatic("org.apache.spark.sql.functions", "sort_array", x@jc, asc) column(jc) }) + +#' posexplode +#' +#' Creates a new row for each element with position in the given array or map column. +#' +#' @param x Column to compute on +#' +#' @rdname posexplode +#' @name posexplode +#' @family collection_funcs +#' @aliases posexplode,Column-method +#' @export +#' @examples \dontrun{posexplode(df$c)} +#' @note posexplode since 2.1.0 +setMethod("posexplode", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "posexplode", x@jc) + column(jc) + }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index f936ea6039981..b54a92a3c6ddd 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -23,9 +23,7 @@ setGeneric("aggregateRDD", function(x, zeroValue, seqOp, combOp) { standardGeneric("aggregateRDD") }) -# @rdname cache-methods -# @export -setGeneric("cache", function(x) { standardGeneric("cache") }) +setGeneric("cacheRDD", function(x) { standardGeneric("cacheRDD") }) # @rdname coalesce # @seealso repartition @@ -36,9 +34,7 @@ setGeneric("coalesce", function(x, numPartitions, ...) { standardGeneric("coales # @export setGeneric("checkpoint", function(x) { standardGeneric("checkpoint") }) -# @rdname collect-methods -# @export -setGeneric("collect", function(x, ...) { standardGeneric("collect") }) +setGeneric("collectRDD", function(x, ...) { standardGeneric("collectRDD") }) # @rdname collect-methods # @export @@ -51,40 +47,36 @@ setGeneric("collectPartition", standardGeneric("collectPartition") }) -# @rdname count -# @export -setGeneric("count", function(x) { standardGeneric("count") }) +setGeneric("countRDD", function(x) { standardGeneric("countRDD") }) + +setGeneric("lengthRDD", function(x) { standardGeneric("lengthRDD") }) # @rdname countByValue # @export setGeneric("countByValue", function(x) { standardGeneric("countByValue") }) -# @rdname statfunctions +# @rdname crosstab # @export setGeneric("crosstab", function(x, col1, col2) { standardGeneric("crosstab") }) -# @rdname statfunctions +# @rdname freqItems # @export setGeneric("freqItems", function(x, cols, support = 0.01) { standardGeneric("freqItems") }) -# @rdname statfunctions +# @rdname approxQuantile # @export setGeneric("approxQuantile", function(x, col, probabilities, relativeError) { standardGeneric("approxQuantile") }) -# @rdname distinct -# @export -setGeneric("distinct", function(x, numPartitions = 1) { standardGeneric("distinct") }) +setGeneric("distinctRDD", function(x, numPartitions = 1) { standardGeneric("distinctRDD") }) # @rdname filterRDD # @export setGeneric("filterRDD", function(x, f) { standardGeneric("filterRDD") }) -# @rdname first -# @export -setGeneric("first", function(x, ...) { standardGeneric("first") }) +setGeneric("firstRDD", function(x, ...) 
{ standardGeneric("firstRDD") }) # @rdname flatMap # @export @@ -110,6 +102,8 @@ setGeneric("glom", function(x) { standardGeneric("glom") }) # @export setGeneric("histogram", function(df, col, nbins=10) { standardGeneric("histogram") }) +setGeneric("joinRDD", function(x, y, ...) { standardGeneric("joinRDD") }) + # @rdname keyBy # @export setGeneric("keyBy", function(x, func) { standardGeneric("keyBy") }) @@ -152,22 +146,21 @@ setGeneric("getNumPartitions", function(x) { standardGeneric("getNumPartitions") # @export setGeneric("numPartitions", function(x) { standardGeneric("numPartitions") }) -# @rdname persist -# @export -setGeneric("persist", function(x, newLevel) { standardGeneric("persist") }) +setGeneric("persistRDD", function(x, newLevel) { standardGeneric("persistRDD") }) # @rdname pipeRDD # @export setGeneric("pipeRDD", function(x, command, env = list()) { standardGeneric("pipeRDD")}) +# @rdname pivot +# @export +setGeneric("pivot", function(x, colname, values = list()) { standardGeneric("pivot") }) + # @rdname reduce # @export setGeneric("reduce", function(x, func) { standardGeneric("reduce") }) -# @rdname repartition -# @seealso coalesce -# @export -setGeneric("repartition", function(x, numPartitions) { standardGeneric("repartition") }) +setGeneric("repartitionRDD", function(x, ...) { standardGeneric("repartitionRDD") }) # @rdname sampleRDD # @export @@ -189,6 +182,8 @@ setGeneric("saveAsTextFile", function(x, path) { standardGeneric("saveAsTextFile # @export setGeneric("setName", function(x, name) { standardGeneric("setName") }) +setGeneric("showRDD", function(object, ...) { standardGeneric("showRDD") }) + # @rdname sortBy # @export setGeneric("sortBy", @@ -196,9 +191,7 @@ setGeneric("sortBy", standardGeneric("sortBy") }) -# @rdname take -# @export -setGeneric("take", function(x, num) { standardGeneric("take") }) +setGeneric("takeRDD", function(x, num) { standardGeneric("takeRDD") }) # @rdname takeOrdered # @export @@ -219,9 +212,7 @@ setGeneric("top", function(x, num) { standardGeneric("top") }) # @export setGeneric("unionRDD", function(x, y) { standardGeneric("unionRDD") }) -# @rdname unpersist-methods -# @export -setGeneric("unpersist", function(x, ...) { standardGeneric("unpersist") }) +setGeneric("unpersistRDD", function(x, ...) { standardGeneric("unpersistRDD") }) # @rdname zipRDD # @export @@ -339,9 +330,7 @@ setGeneric("join", function(x, y, ...) { standardGeneric("join") }) # @export setGeneric("leftOuterJoin", function(x, y, numPartitions) { standardGeneric("leftOuterJoin") }) -# @rdname partitionBy -# @export -setGeneric("partitionBy", function(x, numPartitions, ...) { standardGeneric("partitionBy") }) +setGeneric("partitionByRDD", function(x, ...) { standardGeneric("partitionByRDD") }) # @rdname reduceByKey # @seealso groupByKey @@ -391,7 +380,10 @@ setGeneric("value", function(bcast) { standardGeneric("value") }) #################### SparkDataFrame Methods ######################## -#' @rdname agg +#' @param x a SparkDataFrame or GroupedData. +#' @param ... further arguments to be passed to or from other methods. +#' @return A SparkDataFrame. +#' @rdname summarize #' @export setGeneric("agg", function (x, ...) { standardGeneric("agg") }) @@ -410,6 +402,16 @@ setGeneric("as.data.frame", #' @export setGeneric("attach") +#' @rdname cache +#' @export +setGeneric("cache", function(x) { standardGeneric("cache") }) + +#' @rdname collect +#' @export +setGeneric("collect", function(x, ...) { standardGeneric("collect") }) + +#' @param do.NULL currently not used. 
+#' @param prefix currently not used. #' @rdname columns #' @export setGeneric("colnames", function(x, do.NULL = TRUE, prefix = "col") { standardGeneric("colnames") }) @@ -426,44 +428,81 @@ setGeneric("coltypes", function(x) { standardGeneric("coltypes") }) #' @export setGeneric("coltypes<-", function(x, value) { standardGeneric("coltypes<-") }) -#' @rdname schema +#' @rdname columns #' @export setGeneric("columns", function(x) {standardGeneric("columns") }) -#' @rdname statfunctions +#' @param x a GroupedData or Column. +#' @rdname count +#' @export +setGeneric("count", function(x) { standardGeneric("count") }) + +#' @rdname cov +#' @param x a Column or a SparkDataFrame. +#' @param ... additional argument(s). If \code{x} is a Column, a Column +#' should be provided. If \code{x} is a SparkDataFrame, two column names should +#' be provided. #' @export setGeneric("cov", function(x, ...) {standardGeneric("cov") }) -#' @rdname statfunctions +#' @rdname corr +#' @param x a Column or a SparkDataFrame. +#' @param ... additional argument(s). If \code{x} is a Column, a Column +#' should be provided. If \code{x} is a SparkDataFrame, two column names should +#' be provided. #' @export setGeneric("corr", function(x, ...) {standardGeneric("corr") }) -#' @rdname statfunctions +#' @rdname cov #' @export setGeneric("covar_samp", function(col1, col2) {standardGeneric("covar_samp") }) -#' @rdname statfunctions +#' @rdname covar_pop #' @export setGeneric("covar_pop", function(col1, col2) {standardGeneric("covar_pop") }) +#' @rdname createOrReplaceTempView +#' @export +setGeneric("createOrReplaceTempView", + function(x, viewName) { + standardGeneric("createOrReplaceTempView") + }) + #' @rdname dapply #' @export setGeneric("dapply", function(x, func, schema) { standardGeneric("dapply") }) +#' @rdname dapplyCollect +#' @export +setGeneric("dapplyCollect", function(x, func) { standardGeneric("dapplyCollect") }) + +#' @param x a SparkDataFrame or GroupedData. +#' @param ... additional argument(s) passed to the method. +#' @rdname gapply +#' @export +setGeneric("gapply", function(x, ...) { standardGeneric("gapply") }) + +#' @param x a SparkDataFrame or GroupedData. +#' @param ... additional argument(s) passed to the method. +#' @rdname gapplyCollect +#' @export +setGeneric("gapplyCollect", function(x, ...) { standardGeneric("gapplyCollect") }) + #' @rdname summary #' @export setGeneric("describe", function(x, col, ...) { standardGeneric("describe") }) +#' @rdname distinct +#' @export +setGeneric("distinct", function(x) { standardGeneric("distinct") }) + #' @rdname drop #' @export setGeneric("drop", function(x, ...) { standardGeneric("drop") }) -#' @rdname dropduplicates +#' @rdname dropDuplicates #' @export -setGeneric("dropDuplicates", - function(x, colNames = columns(x)) { - standardGeneric("dropDuplicates") - }) +setGeneric("dropDuplicates", function(x, ...) { standardGeneric("dropDuplicates") }) #' @rdname nafunctions #' @export @@ -479,7 +518,7 @@ setGeneric("na.omit", standardGeneric("na.omit") }) -#' @rdname schema +#' @rdname dtypes #' @export setGeneric("dtypes", function(x) { standardGeneric("dtypes") }) @@ -499,6 +538,10 @@ setGeneric("fillna", function(x, value, cols = NULL) { standardGeneric("fillna") #' @export setGeneric("filter", function(x, condition) { standardGeneric("filter") }) +#' @rdname first +#' @export +setGeneric("first", function(x, ...) { standardGeneric("first") }) + #' @rdname groupBy #' @export setGeneric("group_by", function(x, ...) 
{ standardGeneric("group_by") }) @@ -531,21 +574,29 @@ setGeneric("merge") #' @export setGeneric("mutate", function(.data, ...) {standardGeneric("mutate") }) -#' @rdname arrange +#' @rdname orderBy #' @export -setGeneric("orderBy", function(x, col) { standardGeneric("orderBy") }) +setGeneric("orderBy", function(x, col, ...) { standardGeneric("orderBy") }) -#' @rdname schema +#' @rdname persist +#' @export +setGeneric("persist", function(x, newLevel) { standardGeneric("persist") }) + +#' @rdname printSchema #' @export setGeneric("printSchema", function(x) { standardGeneric("printSchema") }) +#' @rdname registerTempTable-deprecated +#' @export +setGeneric("registerTempTable", function(x, tableName) { standardGeneric("registerTempTable") }) + #' @rdname rename #' @export setGeneric("rename", function(x, ...) { standardGeneric("rename") }) -#' @rdname registerTempTable +#' @rdname repartition #' @export -setGeneric("registerTempTable", function(x, tableName) { standardGeneric("registerTempTable") }) +setGeneric("repartition", function(x, ...) { standardGeneric("repartition") }) #' @rdname sample #' @export @@ -559,7 +610,7 @@ setGeneric("sample", setGeneric("sample_frac", function(x, withReplacement, fraction, seed) { standardGeneric("sample_frac") }) -#' @rdname statfunctions +#' @rdname sampleBy #' @export setGeneric("sampleBy", function(x, col, fractions, seed) { standardGeneric("sampleBy") }) @@ -572,6 +623,10 @@ setGeneric("saveAsTable", function(df, tableName, source = NULL, mode = "error", #' @export setGeneric("str") +#' @rdname take +#' @export +setGeneric("take", function(x, num) { standardGeneric("take") }) + #' @rdname mutate #' @export setGeneric("transform", function(`_data`, ...) {standardGeneric("transform") }) @@ -598,6 +653,10 @@ setGeneric("write.jdbc", function(x, url, tableName, mode = "error", ...) { #' @export setGeneric("write.json", function(x, path) { standardGeneric("write.json") }) +#' @rdname write.orc +#' @export +setGeneric("write.orc", function(x, path) { standardGeneric("write.orc") }) + #' @rdname write.parquet #' @export setGeneric("write.parquet", function(x, path) { standardGeneric("write.parquet") }) @@ -618,7 +677,7 @@ setGeneric("schema", function(x) { standardGeneric("schema") }) #' @export setGeneric("select", function(x, col, ...) { standardGeneric("select") } ) -#' @rdname select +#' @rdname selectExpr #' @export setGeneric("selectExpr", function(x, expr, ...) { standardGeneric("selectExpr") }) @@ -626,11 +685,11 @@ setGeneric("selectExpr", function(x, expr, ...) { standardGeneric("selectExpr") #' @export setGeneric("showDF", function(x, ...) { standardGeneric("showDF") }) -# @rdname subset -# @export +#' @rdname subset +#' @export setGeneric("subset", function(x, ...) { standardGeneric("subset") }) -#' @rdname agg +#' @rdname summarize #' @export setGeneric("summarize", function(x, ...) { standardGeneric("summarize") }) @@ -642,10 +701,18 @@ setGeneric("toJSON", function(x) { standardGeneric("toJSON") }) setGeneric("toRDD", function(x) { standardGeneric("toRDD") }) -#' @rdname rbind +#' @rdname union +#' @export +setGeneric("union", function(x, y) { standardGeneric("union") }) + +#' @rdname union #' @export setGeneric("unionAll", function(x, y) { standardGeneric("unionAll") }) +#' @rdname unpersist-methods +#' @export +setGeneric("unpersist", function(x, ...) 
{ standardGeneric("unpersist") }) + #' @rdname filter #' @export setGeneric("where", function(x, condition) { standardGeneric("where") }) @@ -667,72 +734,103 @@ setGeneric("withColumnRenamed", #' @export setGeneric("write.df", function(df, path, ...) { standardGeneric("write.df") }) +#' @rdname randomSplit +#' @export +setGeneric("randomSplit", function(x, weights, seed) { standardGeneric("randomSplit") }) + ###################### Column Methods ########################## -#' @rdname column +#' @rdname columnfunctions #' @export setGeneric("asc", function(x) { standardGeneric("asc") }) -#' @rdname column +#' @rdname between #' @export setGeneric("between", function(x, bounds) { standardGeneric("between") }) -#' @rdname column +#' @rdname cast #' @export setGeneric("cast", function(x, dataType) { standardGeneric("cast") }) -#' @rdname column +#' @rdname columnfunctions +#' @param x a Column object. +#' @param ... additional argument(s). #' @export setGeneric("contains", function(x, ...) { standardGeneric("contains") }) -#' @rdname column +#' @rdname columnfunctions #' @export setGeneric("desc", function(x) { standardGeneric("desc") }) -#' @rdname column +#' @rdname endsWith #' @export -setGeneric("endsWith", function(x, ...) { standardGeneric("endsWith") }) +setGeneric("endsWith", function(x, suffix) { standardGeneric("endsWith") }) -#' @rdname column +#' @rdname columnfunctions #' @export setGeneric("getField", function(x, ...) { standardGeneric("getField") }) -#' @rdname column +#' @rdname columnfunctions #' @export setGeneric("getItem", function(x, ...) { standardGeneric("getItem") }) -#' @rdname column +#' @rdname columnfunctions #' @export setGeneric("isNaN", function(x) { standardGeneric("isNaN") }) -#' @rdname column +#' @rdname columnfunctions #' @export setGeneric("isNull", function(x) { standardGeneric("isNull") }) -#' @rdname column +#' @rdname columnfunctions #' @export setGeneric("isNotNull", function(x) { standardGeneric("isNotNull") }) -#' @rdname column +#' @rdname columnfunctions #' @export setGeneric("like", function(x, ...) { standardGeneric("like") }) -#' @rdname column +#' @rdname columnfunctions #' @export setGeneric("rlike", function(x, ...) { standardGeneric("rlike") }) -#' @rdname column +#' @rdname startsWith #' @export -setGeneric("startsWith", function(x, ...) { standardGeneric("startsWith") }) +setGeneric("startsWith", function(x, prefix) { standardGeneric("startsWith") }) -#' @rdname column +#' @rdname when #' @export setGeneric("when", function(condition, value) { standardGeneric("when") }) -#' @rdname column +#' @rdname otherwise #' @export setGeneric("otherwise", function(x, value) { standardGeneric("otherwise") }) +#' @rdname over +#' @export +setGeneric("over", function(x, window) { standardGeneric("over") }) + +###################### WindowSpec Methods ########################## + +#' @rdname partitionBy +#' @export +setGeneric("partitionBy", function(x, ...) { standardGeneric("partitionBy") }) + +#' @rdname rowsBetween +#' @export +setGeneric("rowsBetween", function(x, start, end) { standardGeneric("rowsBetween") }) + +#' @rdname rangeBetween +#' @export +setGeneric("rangeBetween", function(x, start, end) { standardGeneric("rangeBetween") }) + +#' @rdname windowPartitionBy +#' @export +setGeneric("windowPartitionBy", function(col, ...) { standardGeneric("windowPartitionBy") }) + +#' @rdname windowOrderBy +#' @export +setGeneric("windowOrderBy", function(col, ...) 
{ standardGeneric("windowOrderBy") }) ###################### Expression Function Methods ########################## @@ -752,6 +850,8 @@ setGeneric("array_contains", function(x, value) { standardGeneric("array_contain #' @export setGeneric("ascii", function(x) { standardGeneric("ascii") }) +#' @param x Column to compute on or a GroupedData object. +#' @param ... additional argument(s) when \code{x} is a GroupedData object. #' @rdname avg #' @export setGeneric("avg", function(x, ...) { standardGeneric("avg") }) @@ -780,7 +880,7 @@ setGeneric("cbrt", function(x) { standardGeneric("cbrt") }) #' @export setGeneric("ceil", function(x) { standardGeneric("ceil") }) -#' @rdname col +#' @rdname column #' @export setGeneric("column", function(x) { standardGeneric("column") }) @@ -808,9 +908,10 @@ setGeneric("crc32", function(x) { standardGeneric("crc32") }) #' @export setGeneric("hash", function(x, ...) { standardGeneric("hash") }) +#' @param x empty. Should be used with no argument. #' @rdname cume_dist #' @export -setGeneric("cume_dist", function(x) { standardGeneric("cume_dist") }) +setGeneric("cume_dist", function(x = "missing") { standardGeneric("cume_dist") }) #' @rdname datediff #' @export @@ -840,9 +941,10 @@ setGeneric("dayofyear", function(x) { standardGeneric("dayofyear") }) #' @export setGeneric("decode", function(x, charset) { standardGeneric("decode") }) +#' @param x empty. Should be used with no argument. #' @rdname dense_rank #' @export -setGeneric("dense_rank", function(x) { standardGeneric("dense_rank") }) +setGeneric("dense_rank", function(x = "missing") { standardGeneric("dense_rank") }) #' @rdname encode #' @export @@ -956,6 +1058,12 @@ setGeneric("md5", function(x) { standardGeneric("md5") }) #' @export setGeneric("minute", function(x) { standardGeneric("minute") }) +#' @param x empty. Should be used with no argument. +#' @rdname monotonically_increasing_id +#' @export +setGeneric("monotonically_increasing_id", + function(x = "missing") { standardGeneric("monotonically_increasing_id") }) + #' @rdname month #' @export setGeneric("month", function(x) { standardGeneric("month") }) @@ -988,14 +1096,19 @@ setGeneric("ntile", function(x) { standardGeneric("ntile") }) #' @export setGeneric("n_distinct", function(x, ...) { standardGeneric("n_distinct") }) +#' @param x empty. Should be used with no argument. #' @rdname percent_rank #' @export -setGeneric("percent_rank", function(x) { standardGeneric("percent_rank") }) +setGeneric("percent_rank", function(x = "missing") { standardGeneric("percent_rank") }) #' @rdname pmod #' @export setGeneric("pmod", function(y, x) { standardGeneric("pmod") }) +#' @rdname posexplode +#' @export +setGeneric("posexplode", function(x) { standardGeneric("posexplode") }) + #' @rdname quarter #' @export setGeneric("quarter", function(x) { standardGeneric("quarter") }) @@ -1027,11 +1140,12 @@ setGeneric("reverse", function(x) { standardGeneric("reverse") }) #' @rdname rint #' @export -setGeneric("rint", function(x, ...) { standardGeneric("rint") }) +setGeneric("rint", function(x) { standardGeneric("rint") }) +#' @param x empty. Should be used with no argument. 
#' @rdname row_number #' @export -setGeneric("row_number", function(x) { standardGeneric("row_number") }) +setGeneric("row_number", function(x = "missing") { standardGeneric("row_number") }) #' @rdname rpad #' @export @@ -1069,7 +1183,7 @@ setGeneric("shiftRight", function(y, x) { standardGeneric("shiftRight") }) #' @export setGeneric("shiftRightUnsigned", function(y, x) { standardGeneric("shiftRightUnsigned") }) -#' @rdname signum +#' @rdname sign #' @export setGeneric("signum", function(x) { standardGeneric("signum") }) @@ -1089,6 +1203,11 @@ setGeneric("sort_array", function(x, asc = TRUE) { standardGeneric("sort_array") #' @export setGeneric("soundex", function(x) { standardGeneric("soundex") }) +#' @param x empty. Should be used with no argument. +#' @rdname spark_partition_id +#' @export +setGeneric("spark_partition_id", function(x = "missing") { standardGeneric("spark_partition_id") }) + #' @rdname sd #' @export setGeneric("stddev", function(x) { standardGeneric("stddev") }) @@ -1185,10 +1304,16 @@ setGeneric("year", function(x) { standardGeneric("year") }) #' @export setGeneric("spark.glm", function(data, formula, ...) { standardGeneric("spark.glm") }) +#' @param x,y For \code{glm}: logical values indicating whether the response vector +#' and model matrix used in the fitting process should be returned as +#' components of the returned value. +#' @inheritParams stats::glm #' @rdname glm #' @export setGeneric("glm") +#' @param object a fitted ML model object. +#' @param ... additional argument(s) passed to the method. #' @rdname predict #' @export setGeneric("predict", function(object, ...) { standardGeneric("predict") }) @@ -1211,8 +1336,11 @@ setGeneric("spark.naiveBayes", function(data, formula, ...) { standardGeneric("s #' @rdname spark.survreg #' @export -setGeneric("spark.survreg", function(data, formula, ...) { standardGeneric("spark.survreg") }) +setGeneric("spark.survreg", function(data, formula) { standardGeneric("spark.survreg") }) +#' @param object a fitted ML model object. +#' @param path the directory where the model is saved. +#' @param ... additional argument(s) passed to the method. #' @rdname write.ml #' @export setGeneric("write.ml", function(object, path, ...) { standardGeneric("write.ml") }) diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R index 08f4a490c883e..17f5283abead1 100644 --- a/R/pkg/R/group.R +++ b/R/pkg/R/group.R @@ -22,13 +22,16 @@ NULL setOldClass("jobj") -#' @title S4 class that represents a GroupedData -#' @description GroupedDatas can be created using groupBy() on a SparkDataFrame +#' S4 class that represents a GroupedData +#' +#' GroupedDatas can be created using groupBy() on a SparkDataFrame +#' #' @rdname GroupedData #' @seealso groupBy #' #' @param sgd A Java object reference to the backing Scala GroupedData #' @export +#' @note GroupedData since 1.4.0 setClass("GroupedData", slots = list(sgd = "jobj")) @@ -44,6 +47,9 @@ groupedData <- function(sgd) { #' @rdname show +#' @aliases show,GroupedData-method +#' @export +#' @note show(GroupedData) since 1.4.0 setMethod("show", "GroupedData", function(object) { cat("GroupedData\n") @@ -51,17 +57,18 @@ setMethod("show", "GroupedData", #' Count #' -#' Count the number of rows for each group. +#' Count the number of rows for each group when we have \code{GroupedData} input. #' The resulting SparkDataFrame will also contain the grouping columns. #' -#' @param x a GroupedData -#' @return a SparkDataFrame -#' @rdname agg +#' @return A SparkDataFrame. 
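The `count(GroupedData)` method documented above can be sketched with the built-in `mtcars` data, assuming a running session.

```R
library(SparkR)
sparkR.session()

df <- createDataFrame(mtcars)

# Rows per transmission type; the grouping column is kept in the result
head(count(groupBy(df, "am")))
```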
+#' @rdname count +#' @aliases count,GroupedData-method #' @export #' @examples #' \dontrun{ #' count(groupBy(df, "name")) #' } +#' @note count since 1.4.0 setMethod("count", signature(x = "GroupedData"), function(x) { @@ -76,17 +83,18 @@ setMethod("count", #' df2 <- agg(df, = ) #' df2 <- agg(df, newColName = aggFunction(column)) #' -#' @param x a GroupedData -#' @return a SparkDataFrame #' @rdname summarize +#' @aliases agg,GroupedData-method #' @name agg #' @family agg_funcs +#' @export #' @examples #' \dontrun{ #' df2 <- agg(df, age = "sum") # new column name will be created as 'SUM(age#0)' #' df3 <- agg(df, ageSum = sum(df$age)) # Creates a new column named ageSum #' df4 <- summarize(df, ageSum = max(df$age)) #' } +#' @note agg since 1.4.0 setMethod("agg", signature(x = "GroupedData"), function(x, ...) { @@ -114,6 +122,8 @@ setMethod("agg", #' @rdname summarize #' @name summarize +#' @aliases summarize,GroupedData-method +#' @note summarize since 1.4.0 setMethod("summarize", signature(x = "GroupedData"), function(x, ...) { @@ -126,6 +136,50 @@ methods <- c("avg", "max", "mean", "min", "sum") # These are not exposed on GroupedData: "kurtosis", "skewness", "stddev", "stddev_samp", "stddev_pop", # "variance", "var_samp", "var_pop" +#' Pivot a column of the GroupedData and perform the specified aggregation. +#' +#' Pivot a column of the GroupedData and perform the specified aggregation. +#' There are two versions of pivot function: one that requires the caller to specify the list +#' of distinct values to pivot on, and one that does not. The latter is more concise but less +#' efficient, because Spark needs to first compute the list of distinct values internally. +#' +#' @param x a GroupedData object +#' @param colname A column name +#' @param values A value or a list/vector of distinct values for the output columns. 
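Putting the grouped aggregation and pivot documentation above together, here is a minimal sketch, assuming a running session; the data frame mirrors the roxygen example and its values are illustrative.

```R
library(SparkR)
sparkR.session()

# Illustrative data in the spirit of the roxygen example above
df <- createDataFrame(data.frame(
  year     = c(2015, 2015, 2016, 2016),
  course   = c("R", "Python", "R", "Python"),
  earnings = c(10000, 12000, 11000, 14000)))

# Named aggregation over a grouping
head(summarize(groupBy(df, "year"), total = sum(df$earnings)))

# Pivot on course; listing the values up front avoids an extra pass to find them
head(sum(pivot(groupBy(df, "year"), "course", c("R", "Python")), "earnings"))
```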
+#' @return GroupedData object +#' @rdname pivot +#' @aliases pivot,GroupedData,character-method +#' @name pivot +#' @export +#' @examples +#' \dontrun{ +#' df <- createDataFrame(data.frame( +#' earnings = c(10000, 10000, 11000, 15000, 12000, 20000, 21000, 22000), +#' course = c("R", "Python", "R", "Python", "R", "Python", "R", "Python"), +#' period = c("1H", "1H", "2H", "2H", "1H", "1H", "2H", "2H"), +#' year = c(2015, 2015, 2015, 2015, 2016, 2016, 2016, 2016) +#' )) +#' group_sum <- sum(pivot(groupBy(df, "year"), "course"), "earnings") +#' group_min <- min(pivot(groupBy(df, "year"), "course", "R"), "earnings") +#' group_max <- max(pivot(groupBy(df, "year"), "course", c("Python", "R")), "earnings") +#' group_mean <- mean(pivot(groupBy(df, "year"), "course", list("Python", "R")), "earnings") +#' } +#' @note pivot since 2.0.0 +setMethod("pivot", + signature(x = "GroupedData", colname = "character"), + function(x, colname, values = list()){ + stopifnot(length(colname) == 1) + if (length(values) == 0) { + result <- callJMethod(x@sgd, "pivot", colname) + } else { + if (length(values) > length(unique(values))) { + stop("Values are not unique") + } + result <- callJMethod(x@sgd, "pivot", colname, as.list(values)) + } + groupedData(result) + }) + createMethod <- function(name) { setMethod(name, signature(x = "GroupedData"), @@ -142,3 +196,54 @@ createMethods <- function() { } createMethods() + +#' gapply +#' +#' @rdname gapply +#' @aliases gapply,GroupedData-method +#' @name gapply +#' @export +#' @note gapply(GroupedData) since 2.0.0 +setMethod("gapply", + signature(x = "GroupedData"), + function(x, func, schema) { + if (is.null(schema)) stop("schema cannot be NULL") + gapplyInternal(x, func, schema) + }) + +#' gapplyCollect +#' +#' @rdname gapplyCollect +#' @aliases gapplyCollect,GroupedData-method +#' @name gapplyCollect +#' @export +#' @note gapplyCollect(GroupedData) since 2.0.0 +setMethod("gapplyCollect", + signature(x = "GroupedData"), + function(x, func) { + gdf <- gapplyInternal(x, func, NULL) + content <- callJMethod(gdf@sdf, "collect") + # content is a list of items of struct type. Each item has a single field + # which is a serialized data.frame corresponds to one group of the + # SparkDataFrame. + ldfs <- lapply(content, function(x) { unserialize(x[[1]]) }) + ldf <- do.call(rbind, ldfs) + row.names(ldf) <- NULL + ldf + }) + +gapplyInternal <- function(x, func, schema) { + packageNamesArr <- serialize(.sparkREnv[[".packages"]], + connection = NULL) + broadcastArr <- lapply(ls(.broadcastNames), + function(name) { get(name, .broadcastNames) }) + sdf <- callJStatic( + "org.apache.spark.sql.api.r.SQLUtils", + "gapply", + x@sgd, + serialize(cleanClosure(func), connection = NULL), + packageNamesArr, + broadcastArr, + if (class(schema) == "structType") { schema$jobj } else { NULL }) + dataFrame(sdf) +} diff --git a/R/pkg/R/install.R b/R/pkg/R/install.R new file mode 100644 index 0000000000000..69b0a523b84e4 --- /dev/null +++ b/R/pkg/R/install.R @@ -0,0 +1,257 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Functions to install Spark in case the user directly downloads SparkR +# from CRAN. + +#' Download and Install Apache Spark to a Local Directory +#' +#' \code{install.spark} downloads and installs Spark to a local directory if +#' it is not found. The Spark version we use is the same as the SparkR version. +#' Users can specify a desired Hadoop version, the remote mirror site, and +#' the directory where the package is installed locally. +#' +#' The full url of remote file is inferred from \code{mirrorUrl} and \code{hadoopVersion}. +#' \code{mirrorUrl} specifies the remote path to a Spark folder. It is followed by a subfolder +#' named after the Spark version (that corresponds to SparkR), and then the tar filename. +#' The filename is composed of four parts, i.e. [Spark version]-bin-[Hadoop version].tgz. +#' For example, the full path for a Spark 2.0.0 package for Hadoop 2.7 from +#' \code{http://apache.osuosl.org} has path: +#' \code{http://apache.osuosl.org/spark/spark-2.0.0/spark-2.0.0-bin-hadoop2.7.tgz}. +#' For \code{hadoopVersion = "without"}, [Hadoop version] in the filename is then +#' \code{without-hadoop}. +#' +#' @param hadoopVersion Version of Hadoop to install. Default is \code{"2.7"}. It can take other +#' version number in the format of "x.y" where x and y are integer. +#' If \code{hadoopVersion = "without"}, "Hadoop free" build is installed. +#' See +#' \href{http://spark.apache.org/docs/latest/hadoop-provided.html}{ +#' "Hadoop Free" Build} for more information. +#' Other patched version names can also be used, e.g. \code{"cdh4"} +#' @param mirrorUrl base URL of the repositories to use. The directory layout should follow +#' \href{http://www.apache.org/dyn/closer.lua/spark/}{Apache mirrors}. +#' @param localDir a local directory where Spark is installed. The directory contains +#' version-specific folders of Spark packages. Default is path to +#' the cache directory: +#' \itemize{ +#' \item Mac OS X: \file{~/Library/Caches/spark} +#' \item Unix: \env{$XDG_CACHE_HOME} if defined, otherwise \file{~/.cache/spark} +#' \item Windows: \file{\%LOCALAPPDATA\%\\spark\\spark\\Cache}. 
+#' } +#' @param overwrite If \code{TRUE}, download and overwrite the existing tar file in localDir +#' and force re-install Spark (in case the local directory or file is corrupted) +#' @return \code{install.spark} returns the local directory where Spark is found or installed +#' @rdname install.spark +#' @name install.spark +#' @aliases install.spark +#' @export +#' @examples +#'\dontrun{ +#' install.spark() +#'} +#' @note install.spark since 2.1.0 +#' @seealso See available Hadoop versions: +#' \href{http://spark.apache.org/downloads.html}{Apache Spark} +install.spark <- function(hadoopVersion = "2.7", mirrorUrl = NULL, + localDir = NULL, overwrite = FALSE) { + version <- paste0("spark-", packageVersion("SparkR")) + hadoopVersion <- tolower(hadoopVersion) + hadoopVersionName <- hadoopVersionName(hadoopVersion) + packageName <- paste(version, "bin", hadoopVersionName, sep = "-") + localDir <- ifelse(is.null(localDir), sparkCachePath(), + normalizePath(localDir, mustWork = FALSE)) + + if (is.na(file.info(localDir)$isdir)) { + dir.create(localDir, recursive = TRUE) + } + + packageLocalDir <- file.path(localDir, packageName) + + if (overwrite) { + message(paste0("Overwrite = TRUE: download and overwrite the tar file", + "and Spark package directory if they exist.")) + } + + # can use dir.exists(packageLocalDir) under R 3.2.0 or later + if (!is.na(file.info(packageLocalDir)$isdir) && !overwrite) { + fmt <- "%s for Hadoop %s found, with SPARK_HOME set to %s" + msg <- sprintf(fmt, version, ifelse(hadoopVersion == "without", "Free build", hadoopVersion), + packageLocalDir) + message(msg) + Sys.setenv(SPARK_HOME = packageLocalDir) + return(invisible(packageLocalDir)) + } else { + message("Spark not found in the cache directory. Installation will start.") + } + + packageLocalPath <- paste0(packageLocalDir, ".tgz") + tarExists <- file.exists(packageLocalPath) + + if (tarExists && !overwrite) { + message("tar file found.") + } else { + robustDownloadTar(mirrorUrl, version, hadoopVersion, packageName, packageLocalPath) + } + + message(sprintf("Installing to %s", localDir)) + untar(tarfile = packageLocalPath, exdir = localDir) + if (!tarExists || overwrite) { + unlink(packageLocalPath) + } + message("DONE.") + Sys.setenv(SPARK_HOME = packageLocalDir) + message(paste("SPARK_HOME set to", packageLocalDir)) + invisible(packageLocalDir) +} + +robustDownloadTar <- function(mirrorUrl, version, hadoopVersion, packageName, packageLocalPath) { + # step 1: use user-provided url + if (!is.null(mirrorUrl)) { + msg <- sprintf("Use user-provided mirror site: %s.", mirrorUrl) + message(msg) + success <- directDownloadTar(mirrorUrl, version, hadoopVersion, + packageName, packageLocalPath) + if (success) { + return() + } else { + message(paste0("Unable to download from mirrorUrl: ", mirrorUrl)) + } + } else { + message("MirrorUrl not provided.") + } + + # step 2: use url suggested from apache website + message("Looking for preferred site from apache website...") + mirrorUrl <- getPreferredMirror(version, packageName) + if (!is.null(mirrorUrl)) { + success <- directDownloadTar(mirrorUrl, version, hadoopVersion, + packageName, packageLocalPath) + if (success) return() + } else { + message("Unable to find preferred mirror site.") + } + + # step 3: use backup option + message("To use backup site...") + mirrorUrl <- defaultMirrorUrl() + success <- directDownloadTar(mirrorUrl, version, hadoopVersion, + packageName, packageLocalPath) + if (success) { + return(packageLocalPath) + } else { + msg <- sprintf(paste("Unable to 
download Spark %s for Hadoop %s.", + "Please check network connection, Hadoop version,", + "or provide other mirror sites."), + version, ifelse(hadoopVersion == "without", "Free build", hadoopVersion)) + stop(msg) + } +} + +getPreferredMirror <- function(version, packageName) { + jsonUrl <- paste0("http://www.apache.org/dyn/closer.cgi?path=", + file.path("spark", version, packageName), + ".tgz&as_json=1") + textLines <- readLines(jsonUrl, warn = FALSE) + rowNum <- grep("\"preferred\"", textLines) + linePreferred <- textLines[rowNum] + matchInfo <- regexpr("\"[A-Za-z][A-Za-z0-9+-.]*://.+\"", linePreferred) + if (matchInfo != -1) { + startPos <- matchInfo + 1 + endPos <- matchInfo + attr(matchInfo, "match.length") - 2 + mirrorPreferred <- base::substr(linePreferred, startPos, endPos) + mirrorPreferred <- paste0(mirrorPreferred, "spark") + message(sprintf("Preferred mirror site found: %s", mirrorPreferred)) + } else { + mirrorPreferred <- NULL + } + mirrorPreferred +} + +directDownloadTar <- function(mirrorUrl, version, hadoopVersion, packageName, packageLocalPath) { + packageRemotePath <- paste0( + file.path(mirrorUrl, version, packageName), ".tgz") + fmt <- "Downloading %s for Hadoop %s from:\n- %s" + msg <- sprintf(fmt, version, ifelse(hadoopVersion == "without", "Free build", hadoopVersion), + packageRemotePath) + message(msg) + + isFail <- tryCatch(download.file(packageRemotePath, packageLocalPath), + error = function(e) { + message(sprintf("Fetch failed from %s", mirrorUrl)) + print(e) + TRUE + }) + !isFail +} + +defaultMirrorUrl <- function() { + "http://www-us.apache.org/dist/spark" +} + +hadoopVersionName <- function(hadoopVersion) { + if (hadoopVersion == "without") { + "without-hadoop" + } else if (grepl("^[0-9]+\\.[0-9]+$", hadoopVersion, perl = TRUE)) { + paste0("hadoop", hadoopVersion) + } else { + hadoopVersion + } +} + +# The implementation refers to appdirs package: https://pypi.python.org/pypi/appdirs and +# adapt to Spark context +sparkCachePath <- function() { + if (.Platform$OS.type == "windows") { + winAppPath <- Sys.getenv("LOCALAPPDATA", unset = NA) + if (is.na(winAppPath)) { + msg <- paste("%LOCALAPPDATA% not found.", + "Please define the environment variable", + "or restart and enter an installation path in localDir.") + stop(msg) + } else { + path <- file.path(winAppPath, "spark", "spark", "Cache") + } + } else if (.Platform$OS.type == "unix") { + if (Sys.info()["sysname"] == "Darwin") { + path <- file.path(Sys.getenv("HOME"), "Library/Caches", "spark") + } else { + path <- file.path( + Sys.getenv("XDG_CACHE_HOME", file.path(Sys.getenv("HOME"), ".cache")), "spark") + } + } else { + stop(sprintf("Unknown OS: %s", .Platform$OS.type)) + } + normalizePath(path, mustWork = FALSE) +} + + +installInstruction <- function(mode) { + if (mode == "remote") { + paste0("Connecting to a remote Spark master. ", + "Please make sure Spark package is also installed in this machine.\n", + "- If there is one, set the path in sparkHome parameter or ", + "environment variable SPARK_HOME.\n", + "- If not, you may run install.spark function to do the job. ", + "Please make sure the Spark and the Hadoop versions ", + "match the versions on the cluster. 
", + "SparkR package is compatible with Spark ", packageVersion("SparkR"), ".", + "If you need further help, ", + "contact the administrators of the cluster.") + } else { + stop(paste0("No instruction found for ", mode, " mode.")) + } +} diff --git a/R/pkg/R/jobj.R b/R/pkg/R/jobj.R index 0838a7bb35e0d..4905e1fe5c61f 100644 --- a/R/pkg/R/jobj.R +++ b/R/pkg/R/jobj.R @@ -71,12 +71,17 @@ jobj <- function(objId) { #' #' @param x The JVM object reference #' @param ... further arguments passed to or from other methods +#' @note print.jobj since 1.4.0 print.jobj <- function(x, ...) { - cls <- callJMethod(x, "getClass") - name <- callJMethod(cls, "getName") + name <- getClassName.jobj(x) cat("Java ref type", name, "id", x$id, "\n", sep = " ") } +getClassName.jobj <- function(x) { + cls <- callJMethod(x, "getClass") + callJMethod(cls, "getName") +} + cleanup.jobj <- function(jobj) { if (isValidJobj(jobj)) { objId <- jobj$id diff --git a/R/pkg/R/jvm.R b/R/pkg/R/jvm.R new file mode 100644 index 0000000000000..bb5c77544a3da --- /dev/null +++ b/R/pkg/R/jvm.R @@ -0,0 +1,117 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Methods to directly access the JVM running the SparkR backend. + +#' Call Java Methods +#' +#' Call a Java method in the JVM running the Spark driver. The return +#' values are automatically converted to R objects for simple objects. Other +#' values are returned as "jobj" which are references to objects on JVM. +#' +#' @details +#' This is a low level function to access the JVM directly and should only be used +#' for advanced use cases. The arguments and return values that are primitive R +#' types (like integer, numeric, character, lists) are automatically translated to/from +#' Java types (like Integer, Double, String, Array). A full list can be found in +#' serialize.R and deserialize.R in the Apache Spark code base. +#' +#' @param x object to invoke the method on. Should be a "jobj" created by newJObject. +#' @param methodName method name to call. +#' @param ... parameters to pass to the Java method. +#' @return the return value of the Java method. Either returned as a R object +#' if it can be deserialized or returned as a "jobj". See details section for more. +#' @export +#' @seealso \link{sparkR.callJStatic}, \link{sparkR.newJObject} +#' @rdname sparkR.callJMethod +#' @examples +#' \dontrun{ +#' sparkR.session() # Need to have a Spark JVM running before calling newJObject +#' # Create a Java ArrayList and populate it +#' jarray <- sparkR.newJObject("java.util.ArrayList") +#' sparkR.callJMethod(jarray, "add", 42L) +#' sparkR.callJMethod(jarray, "get", 0L) # Will print 42 +#' } +#' @note sparkR.callJMethod since 2.0.1 +sparkR.callJMethod <- function(x, methodName, ...) { + callJMethod(x, methodName, ...) 
+} + +#' Call Static Java Methods +#' +#' Call a static method in the JVM running the Spark driver. The return +#' value is automatically converted to R objects for simple objects. Other +#' values are returned as "jobj" which are references to objects on JVM. +#' +#' @details +#' This is a low level function to access the JVM directly and should only be used +#' for advanced use cases. The arguments and return values that are primitive R +#' types (like integer, numeric, character, lists) are automatically translated to/from +#' Java types (like Integer, Double, String, Array). A full list can be found in +#' serialize.R and deserialize.R in the Apache Spark code base. +#' +#' @param x fully qualified Java class name that contains the static method to invoke. +#' @param methodName name of static method to invoke. +#' @param ... parameters to pass to the Java method. +#' @return the return value of the Java method. Either returned as a R object +#' if it can be deserialized or returned as a "jobj". See details section for more. +#' @export +#' @seealso \link{sparkR.callJMethod}, \link{sparkR.newJObject} +#' @rdname sparkR.callJStatic +#' @examples +#' \dontrun{ +#' sparkR.session() # Need to have a Spark JVM running before calling callJStatic +#' sparkR.callJStatic("java.lang.System", "currentTimeMillis") +#' sparkR.callJStatic("java.lang.System", "getProperty", "java.home") +#' } +#' @note sparkR.callJStatic since 2.0.1 +sparkR.callJStatic <- function(x, methodName, ...) { + callJStatic(x, methodName, ...) +} + +#' Create Java Objects +#' +#' Create a new Java object in the JVM running the Spark driver. The return +#' value is automatically converted to an R object for simple objects. Other +#' values are returned as a "jobj" which is a reference to an object on JVM. +#' +#' @details +#' This is a low level function to access the JVM directly and should only be used +#' for advanced use cases. The arguments and return values that are primitive R +#' types (like integer, numeric, character, lists) are automatically translated to/from +#' Java types (like Integer, Double, String, Array). A full list can be found in +#' serialize.R and deserialize.R in the Apache Spark code base. +#' +#' @param x fully qualified Java class name. +#' @param ... arguments to be passed to the constructor. +#' @return the object created. Either returned as a R object +#' if it can be deserialized or returned as a "jobj". See details section for more. +#' @export +#' @seealso \link{sparkR.callJMethod}, \link{sparkR.callJStatic} +#' @rdname sparkR.newJObject +#' @examples +#' \dontrun{ +#' sparkR.session() # Need to have a Spark JVM running before calling newJObject +#' # Create a Java ArrayList and populate it +#' jarray <- sparkR.newJObject("java.util.ArrayList") +#' sparkR.callJMethod(jarray, "add", 42L) +#' sparkR.callJMethod(jarray, "get", 0L) # Will print 42 +#' } +#' @note sparkR.newJObject since 2.0.1 +sparkR.newJObject <- function(x, ...) { + newJObject(x, ...) +} diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index f46681149d5ac..b33a16a7cef97 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -25,117 +25,161 @@ # - a set of methods that reflect the arguments of the other languages supported by Spark. These # methods are prefixed with the `spark.` prefix: spark.glm, spark.kmeans, etc. 
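To make the two naming conventions described in that comment concrete, here is a brief illustrative sketch (not part of the patch itself) contrasting the R-style entry point with its `spark.`-prefixed counterpart. It reuses the `iris` example that also appears in the roxygen examples below; column names assume the dot-to-underscore renaming performed by `createDataFrame`.

```R
sparkR.session()
df <- createDataFrame(iris)   # Sepal.Length becomes Sepal_Length, etc.

# R-flavored API: mirrors stats::glm()
m1 <- glm(Sepal_Length ~ Sepal_Width, data = df, family = "gaussian")

# Spark-flavored API: mirrors the argument names used by MLlib in Scala/Python
m2 <- spark.glm(df, Sepal_Length ~ Sepal_Width, family = "gaussian",
                tol = 1e-6, maxIter = 25)

summary(m2)                   # model summary, similar to R's glm
head(predict(m2, df))         # predictions come back as a SparkDataFrame
```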
-#' @title S4 class that represents a generalized linear model +#' S4 class that represents a generalized linear model +#' #' @param jobj a Java object reference to the backing Scala GeneralizedLinearRegressionWrapper #' @export +#' @note GeneralizedLinearRegressionModel since 2.0.0 setClass("GeneralizedLinearRegressionModel", representation(jobj = "jobj")) -#' @title S4 class that represents a NaiveBayesModel +#' S4 class that represents a NaiveBayesModel +#' #' @param jobj a Java object reference to the backing Scala NaiveBayesWrapper #' @export +#' @note NaiveBayesModel since 2.0.0 setClass("NaiveBayesModel", representation(jobj = "jobj")) -#' @title S4 class that represents a AFTSurvivalRegressionModel +#' S4 class that represents a AFTSurvivalRegressionModel +#' #' @param jobj a Java object reference to the backing Scala AFTSurvivalRegressionWrapper #' @export +#' @note AFTSurvivalRegressionModel since 2.0.0 setClass("AFTSurvivalRegressionModel", representation(jobj = "jobj")) -#' @title S4 class that represents a KMeansModel +#' S4 class that represents a KMeansModel +#' #' @param jobj a Java object reference to the backing Scala KMeansModel #' @export +#' @note KMeansModel since 2.0.0 setClass("KMeansModel", representation(jobj = "jobj")) -#' Fits a generalized linear model +#' Saves the MLlib model to the input path #' -#' Fits a generalized linear model against a Spark DataFrame. +#' Saves the MLlib model to the input path. For more information, see the specific +#' MLlib model below. +#' @rdname write.ml +#' @name write.ml +#' @export +#' @seealso \link{spark.glm}, \link{glm} +#' @seealso \link{spark.kmeans}, \link{spark.naiveBayes}, \link{spark.survreg} +#' @seealso \link{read.ml} +NULL + +#' Makes predictions from a MLlib model #' -#' @param data SparkDataFrame for training. -#' @param formula A symbolic description of the model to be fitted. Currently only a few formula +#' Makes predictions from a MLlib model. For more information, see the specific +#' MLlib model below. +#' @rdname predict +#' @name predict +#' @export +#' @seealso \link{spark.glm}, \link{glm} +#' @seealso \link{spark.kmeans}, \link{spark.naiveBayes}, \link{spark.survreg} +NULL + +#' Generalized Linear Models +#' +#' Fits generalized linear model against a Spark DataFrame. +#' Users can call \code{summary} to print a summary of the fitted model, \code{predict} to make +#' predictions on new data, and \code{write.ml}/\code{read.ml} to save/load fitted models. +#' +#' @param data a SparkDataFrame for training. +#' @param formula a symbolic description of the model to be fitted. Currently only a few formula #' operators are supported, including '~', '.', ':', '+', and '-'. -#' @param family A description of the error distribution and link function to be used in the model. +#' @param family a description of the error distribution and link function to be used in the model. #' This can be a character string naming a family function, a family function or #' the result of a call to a family function. Refer R family at #' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}. -#' @param epsilon Positive convergence tolerance of iterations. -#' @param maxit Integer giving the maximal number of IRLS iterations. -#' @return a fitted generalized linear model +#' @param tol positive convergence tolerance of iterations. +#' @param maxIter integer giving the maximal number of IRLS iterations. +#' @param ... additional arguments passed to the method. 
+#' @aliases spark.glm,SparkDataFrame,formula-method +#' @return \code{spark.glm} returns a fitted generalized linear model #' @rdname spark.glm +#' @name spark.glm #' @export #' @examples #' \dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) +#' sparkR.session() #' data(iris) -#' df <- createDataFrame(sqlContext, iris) -#' model <- spark.glm(df, Sepal_Length ~ Sepal_Width, family="gaussian") +#' df <- createDataFrame(iris) +#' model <- spark.glm(df, Sepal_Length ~ Sepal_Width, family = "gaussian") #' summary(model) +#' +#' # fitted values on training data +#' fitted <- predict(model, df) +#' head(select(fitted, "Sepal_Length", "prediction")) +#' +#' # save fitted model to input path +#' path <- "path/to/model" +#' write.ml(model, path) +#' +#' # can also read back the saved model and print +#' savedModel <- read.ml(path) +#' summary(savedModel) #' } -setMethod( - "spark.glm", - signature(data = "SparkDataFrame", formula = "formula"), - function(data, formula, family = gaussian, epsilon = 1e-06, maxit = 25) { - if (is.character(family)) { - family <- get(family, mode = "function", envir = parent.frame()) - } - if (is.function(family)) { - family <- family() - } - if (is.null(family$family)) { - print(family) - stop("'family' not recognized") - } - - formula <- paste(deparse(formula), collapse = "") - - jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper", - "fit", formula, data@sdf, family$family, family$link, - epsilon, as.integer(maxit)) - return(new("GeneralizedLinearRegressionModel", jobj = jobj)) -}) - -#' Fits a generalized linear model (R-compliant). +#' @note spark.glm since 2.0.0 +#' @seealso \link{glm}, \link{read.ml} +setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), + function(data, formula, family = gaussian, tol = 1e-6, maxIter = 25) { + if (is.character(family)) { + family <- get(family, mode = "function", envir = parent.frame()) + } + if (is.function(family)) { + family <- family() + } + if (is.null(family$family)) { + print(family) + stop("'family' not recognized") + } + + formula <- paste(deparse(formula), collapse = "") + + jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper", + "fit", formula, data@sdf, family$family, family$link, + tol, as.integer(maxIter)) + return(new("GeneralizedLinearRegressionModel", jobj = jobj)) + }) + +#' Generalized Linear Models (R-compliant) #' #' Fits a generalized linear model, similarly to R's glm(). -#' -#' @param formula A symbolic description of the model to be fitted. Currently only a few formula +#' @param formula a symbolic description of the model to be fitted. Currently only a few formula #' operators are supported, including '~', '.', ':', '+', and '-'. -#' @param data SparkDataFrame for training. -#' @param family A description of the error distribution and link function to be used in the model. +#' @param data a SparkDataFrame or R's glm data for training. +#' @param family a description of the error distribution and link function to be used in the model. #' This can be a character string naming a family function, a family function or #' the result of a call to a family function. Refer R family at #' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}. -#' @param epsilon Positive convergence tolerance of iterations. -#' @param maxit Integer giving the maximal number of IRLS iterations. -#' @return a fitted generalized linear model +#' @param epsilon positive convergence tolerance of iterations. 
+#' @param maxit integer giving the maximal number of IRLS iterations. +#' @return \code{glm} returns a fitted generalized linear model. #' @rdname glm #' @export #' @examples #' \dontrun{ -#' sc <- sparkR.init() -#' sqlContext <- sparkRSQL.init(sc) +#' sparkR.session() #' data(iris) -#' df <- createDataFrame(sqlContext, iris) -#' model <- glm(Sepal_Length ~ Sepal_Width, df, family="gaussian") +#' df <- createDataFrame(iris) +#' model <- glm(Sepal_Length ~ Sepal_Width, df, family = "gaussian") #' summary(model) #' } +#' @note glm since 1.5.0 +#' @seealso \link{spark.glm} setMethod("glm", signature(formula = "formula", family = "ANY", data = "SparkDataFrame"), - function(formula, family = gaussian, data, epsilon = 1e-06, maxit = 25) { - spark.glm(data, formula, family, epsilon, maxit) + function(formula, family = gaussian, data, epsilon = 1e-6, maxit = 25) { + spark.glm(data, formula, family, tol = epsilon, maxIter = maxit) }) -#' Get the summary of a generalized linear model -#' -#' Returns the summary of a model produced by glm() or spark.glm(), similarly to R's summary(). +# Returns the summary of a model produced by glm() or spark.glm(), similarly to R's summary(). + +#' @param object a fitted generalized linear model. +#' @return \code{summary} returns a summary object of the fitted model, a list of components +#' including at least the coefficients, null/residual deviance, null/residual degrees +#' of freedom, AIC and number of iterations IRLS takes. #' -#' @param object A fitted generalized linear model -#' @return coefficients the model's coefficients, intercept -#' @rdname summary +#' @rdname spark.glm #' @export -#' @examples -#' \dontrun{ -#' model <- glm(y ~ x, trainingData) -#' summary(model) -#' } +#' @note summary(GeneralizedLinearRegressionModel) since 2.0.0 setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"), function(object, ...) { jobj <- object@jobj @@ -166,11 +210,12 @@ setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"), return(ans) }) -#' Print the summary of GeneralizedLinearRegressionModel -#' -#' @rdname print -#' @name print.summary.GeneralizedLinearRegressionModel +# Prints the summary of GeneralizedLinearRegressionModel + +#' @rdname spark.glm +#' @param x summary object of fitted generalized linear model returned by \code{summary} function #' @export +#' @note print.summary.GeneralizedLinearRegressionModel since 2.0.0 print.summary.GeneralizedLinearRegressionModel <- function(x, ...) { if (x$is.loaded) { cat("\nSaved-loaded model does not support output 'Deviance Residuals'.\n") @@ -197,63 +242,42 @@ print.summary.GeneralizedLinearRegressionModel <- function(x, ...) { invisible(x) } -#' Make predictions from a generalized linear model -#' -#' Makes predictions from a generalized linear model produced by glm() or spark.glm(), -#' similarly to R's predict(). -#' -#' @param object A fitted generalized linear model -#' @param newData SparkDataFrame for testing -#' @return SparkDataFrame containing predicted labels in a column named "prediction" -#' @rdname predict +# Makes predictions from a generalized linear model produced by glm() or spark.glm(), +# similarly to R's predict(). + +#' @param newData a SparkDataFrame for testing. 
+#' @return \code{predict} returns a SparkDataFrame containing predicted labels in a column named +#' "prediction" +#' @rdname spark.glm #' @export -#' @examples -#' \dontrun{ -#' model <- glm(y ~ x, trainingData) -#' predicted <- predict(model, testData) -#' showDF(predicted) -#' } +#' @note predict(GeneralizedLinearRegressionModel) since 1.5.0 setMethod("predict", signature(object = "GeneralizedLinearRegressionModel"), function(object, newData) { return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf))) }) -#' Make predictions from a naive Bayes model -#' -#' Makes predictions from a model produced by spark.naiveBayes(), -#' similarly to R package e1071's predict. -#' -#' @param object A fitted naive Bayes model -#' @param newData SparkDataFrame for testing -#' @return SparkDataFrame containing predicted labels in a column named "prediction" -#' @rdname predict +# Makes predictions from a naive Bayes model or a model produced by spark.naiveBayes(), +# similarly to R package e1071's predict. + +#' @param newData a SparkDataFrame for testing. +#' @return \code{predict} returns a SparkDataFrame containing predicted labeled in a column named +#' "prediction" +#' @rdname spark.naiveBayes #' @export -#' @examples -#' \dontrun{ -#' model <- spark.naiveBayes(trainingData, y ~ x) -#' predicted <- predict(model, testData) -#' showDF(predicted) -#'} +#' @note predict(NaiveBayesModel) since 2.0.0 setMethod("predict", signature(object = "NaiveBayesModel"), function(object, newData) { return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf))) }) -#' Get the summary of a naive Bayes model -#' -#' Returns the summary of a naive Bayes model produced by spark.naiveBayes(), -#' similarly to R's summary(). -#' -#' @param object A fitted MLlib model -#' @return a list containing 'apriori', the label distribution, and 'tables', conditional -# probabilities given the target label -#' @rdname summary +# Returns the summary of a naive Bayes model produced by \code{spark.naiveBayes} + +#' @param object a naive Bayes model fitted by \code{spark.naiveBayes}. +#' @return \code{summary} returns a list containing \code{apriori}, the label distribution, and +#' \code{tables}, conditional probabilities given the target label. +#' @rdname spark.naiveBayes #' @export -#' @examples -#' \dontrun{ -#' model <- spark.naiveBayes(trainingData, y ~ x) -#' summary(model) -#'} +#' @note summary(NaiveBayesModel) since 2.0.0 setMethod("summary", signature(object = "NaiveBayesModel"), function(object, ...) { jobj <- object@jobj @@ -269,40 +293,66 @@ setMethod("summary", signature(object = "NaiveBayesModel"), return(list(apriori = apriori, tables = tables)) }) -#' Fit a k-means model +#' K-Means Clustering Model #' -#' Fit a k-means model, similarly to R's kmeans(). +#' Fits a k-means clustering model against a Spark DataFrame, similarly to R's kmeans(). +#' Users can call \code{summary} to print a summary of the fitted model, \code{predict} to make +#' predictions on new data, and \code{write.ml}/\code{read.ml} to save/load fitted models. #' -#' @param data SparkDataFrame for training -#' @param formula A symbolic description of the model to be fitted. Currently only a few formula +#' @param data a SparkDataFrame for training. +#' @param formula a symbolic description of the model to be fitted. Currently only a few formula #' operators are supported, including '~', '.', ':', '+', and '-'. #' Note that the response variable of formula is empty in spark.kmeans. 
-#' @param k Number of centers -#' @param maxIter Maximum iteration number -#' @param initMode The initialization algorithm choosen to fit the model -#' @return A fitted k-means model +#' @param k number of centers. +#' @param maxIter maximum iteration number. +#' @param initMode the initialization algorithm choosen to fit the model. +#' @param ... additional argument(s) passed to the method. +#' @return \code{spark.kmeans} returns a fitted k-means model. #' @rdname spark.kmeans +#' @aliases spark.kmeans,SparkDataFrame,formula-method +#' @name spark.kmeans #' @export #' @examples #' \dontrun{ -#' model <- spark.kmeans(data, ~ ., k=2, initMode="random") +#' sparkR.session() +#' data(iris) +#' df <- createDataFrame(iris) +#' model <- spark.kmeans(df, Sepal_Length ~ Sepal_Width, k = 4, initMode = "random") +#' summary(model) +#' +#' # fitted values on training data +#' fitted <- predict(model, df) +#' head(select(fitted, "Sepal_Length", "prediction")) +#' +#' # save fitted model to input path +#' path <- "path/to/model" +#' write.ml(model, path) +#' +#' # can also read back the saved model and print +#' savedModel <- read.ml(path) +#' summary(savedModel) #' } +#' @note spark.kmeans since 2.0.0 +#' @seealso \link{predict}, \link{read.ml}, \link{write.ml} setMethod("spark.kmeans", signature(data = "SparkDataFrame", formula = "formula"), - function(data, formula, k, maxIter = 10, initMode = c("random", "k-means||")) { + function(data, formula, k = 2, maxIter = 20, initMode = c("k-means||", "random")) { formula <- paste(deparse(formula), collapse = "") initMode <- match.arg(initMode) jobj <- callJStatic("org.apache.spark.ml.r.KMeansWrapper", "fit", data@sdf, formula, as.integer(k), as.integer(maxIter), initMode) return(new("KMeansModel", jobj = jobj)) - }) + }) #' Get fitted result from a k-means model #' #' Get fitted result from a k-means model, similarly to R's fitted(). #' Note: A saved-loaded model does not support this method. #' -#' @param object A fitted k-means model -#' @return SparkDataFrame containing fitted values +#' @param object a fitted k-means model. +#' @param method type of fitted results, \code{"centers"} for cluster centers +#' or \code{"classes"} for assigned classes. +#' @param ... additional argument(s) passed to the method. +#' @return \code{fitted} returns a SparkDataFrame containing fitted values. #' @rdname fitted #' @export #' @examples @@ -311,6 +361,7 @@ setMethod("spark.kmeans", signature(data = "SparkDataFrame", formula = "formula" #' fitted.model <- fitted(model) #' showDF(fitted.model) #'} +#' @note fitted since 2.0.0 setMethod("fitted", signature(object = "KMeansModel"), function(object, method = c("centers", "classes"), ...) { method <- match.arg(method) @@ -323,20 +374,13 @@ setMethod("fitted", signature(object = "KMeansModel"), } }) -#' Get the summary of a k-means model -#' -#' Returns the summary of a k-means model produced by spark.kmeans(), -#' similarly to R's summary(). -#' -#' @param object a fitted k-means model -#' @return the model's coefficients, size and cluster -#' @rdname summary +# Get the summary of a k-means model + +#' @param object a fitted k-means model. +#' @return \code{summary} returns the model's coefficients, size and cluster. +#' @rdname spark.kmeans #' @export -#' @examples -#' \dontrun{ -#' model <- spark.kmeans(trainingData, ~ ., 2) -#' summary(model) -#' } +#' @note summary(KMeansModel) since 2.0.0 setMethod("summary", signature(object = "KMeansModel"), function(object, ...) 
{ jobj <- object@jobj @@ -357,68 +401,75 @@ setMethod("summary", signature(object = "KMeansModel"), cluster = cluster, is.loaded = is.loaded)) }) -#' Make predictions from a k-means model -#' -#' Make predictions from a model produced by spark.kmeans(). -#' -#' @param object A fitted k-means model -#' @param newData SparkDataFrame for testing -#' @return SparkDataFrame containing predicted labels in a column named "prediction" -#' @rdname predict +# Predicted values based on a k-means model + +#' @param newData a SparkDataFrame for testing. +#' @return \code{predict} returns the predicted values based on a k-means model. +#' @rdname spark.kmeans #' @export -#' @examples -#' \dontrun{ -#' model <- spark.kmeans(trainingData, ~ ., 2) -#' predicted <- predict(model, testData) -#' showDF(predicted) -#' } +#' @note predict(KMeansModel) since 2.0.0 setMethod("predict", signature(object = "KMeansModel"), function(object, newData) { return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf))) }) -#' Fit a Bernoulli naive Bayes model +#' Naive Bayes Models #' -#' Fit a Bernoulli naive Bayes model on a Spark DataFrame (only categorical data is supported). +#' \code{spark.naiveBayes} fits a Bernoulli naive Bayes model against a SparkDataFrame. +#' Users can call \code{summary} to print a summary of the fitted model, \code{predict} to make +#' predictions on new data, and \code{write.ml}/\code{read.ml} to save/load fitted models. +#' Only categorical data is supported. #' -#' @param data SparkDataFrame for training -#' @param formula A symbolic description of the model to be fitted. Currently only a few formula +#' @param data a \code{SparkDataFrame} of observations and labels for model fitting. +#' @param formula a symbolic description of the model to be fitted. Currently only a few formula #' operators are supported, including '~', '.', ':', '+', and '-'. -#' @param laplace Smoothing parameter -#' @return a fitted naive Bayes model +#' @param smoothing smoothing parameter. +#' @param ... additional argument(s) passed to the method. Currently only \code{smoothing}. +#' @return \code{spark.naiveBayes} returns a fitted naive Bayes model. #' @rdname spark.naiveBayes -#' @seealso e1071: \url{https://cran.r-project.org/web/packages/e1071/} +#' @aliases spark.naiveBayes,SparkDataFrame,formula-method +#' @name spark.naiveBayes +#' @seealso e1071: \url{https://cran.r-project.org/package=e1071} #' @export #' @examples #' \dontrun{ -#' df <- createDataFrame(sqlContext, infert) -#' model <- spark.naiveBayes(df, education ~ ., laplace = 0) -#'} -setMethod("spark.naiveBayes", signature(data = "SparkDataFrame", formula = "formula"), - function(data, formula, laplace = 0, ...) { - formula <- paste(deparse(formula), collapse = "") - jobj <- callJStatic("org.apache.spark.ml.r.NaiveBayesWrapper", "fit", - formula, data@sdf, laplace) - return(new("NaiveBayesModel", jobj = jobj)) - }) - -#' Save the Bernoulli naive Bayes model to the input path. -#' -#' @param object A fitted Bernoulli naive Bayes model -#' @param path The directory where the model is saved -#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE -#' which means throw exception if the output path exists. 
+#' data <- as.data.frame(UCBAdmissions) +#' df <- createDataFrame(data) #' -#' @rdname write.ml -#' @name write.ml -#' @export -#' @examples -#' \dontrun{ -#' df <- createDataFrame(sqlContext, infert) -#' model <- spark.naiveBayes(df, education ~ ., laplace = 0) +#' # fit a Bernoulli naive Bayes model +#' model <- spark.naiveBayes(df, Admit ~ Gender + Dept, smoothing = 0) +#' +#' # get the summary of the model +#' summary(model) +#' +#' # make predictions +#' predictions <- predict(model, df) +#' +#' # save and load the model #' path <- "path/to/model" #' write.ml(model, path) +#' savedModel <- read.ml(path) +#' summary(savedModel) #' } +#' @note spark.naiveBayes since 2.0.0 +setMethod("spark.naiveBayes", signature(data = "SparkDataFrame", formula = "formula"), + function(data, formula, smoothing = 1.0, ...) { + formula <- paste(deparse(formula), collapse = "") + jobj <- callJStatic("org.apache.spark.ml.r.NaiveBayesWrapper", "fit", + formula, data@sdf, smoothing) + return(new("NaiveBayesModel", jobj = jobj)) + }) + +# Saves the Bernoulli naive Bayes model to the input path. + +#' @param path the directory where the model is saved +#' @param overwrite overwrites or not if the output path already exists. Default is FALSE +#' which means throw exception if the output path exists. +#' +#' @rdname spark.naiveBayes +#' @export +#' @seealso \link{read.ml} +#' @note write.ml(NaiveBayesModel, character) since 2.0.0 setMethod("write.ml", signature(object = "NaiveBayesModel", path = "character"), function(object, path, overwrite = FALSE) { writer <- callJMethod(object@jobj, "write") @@ -428,22 +479,15 @@ setMethod("write.ml", signature(object = "NaiveBayesModel", path = "character"), invisible(callJMethod(writer, "save", path)) }) -#' Save the AFT survival regression model to the input path. -#' -#' @param object A fitted AFT survival regression model -#' @param path The directory where the model is saved -#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE +# Saves the AFT survival regression model to the input path. + +#' @param path the directory where the model is saved. +#' @param overwrite overwrites or not if the output path already exists. Default is FALSE #' which means throw exception if the output path exists. -#' -#' @rdname write.ml -#' @name write.ml +#' @rdname spark.survreg #' @export -#' @examples -#' \dontrun{ -#' model <- spark.survreg(trainingData, Surv(futime, fustat) ~ ecog_ps + rx) -#' path <- "path/to/model" -#' write.ml(model, path) -#' } +#' @note write.ml(AFTSurvivalRegressionModel, character) since 2.0.0 +#' @seealso \link{read.ml} setMethod("write.ml", signature(object = "AFTSurvivalRegressionModel", path = "character"), function(object, path, overwrite = FALSE) { writer <- callJMethod(object@jobj, "write") @@ -453,22 +497,15 @@ setMethod("write.ml", signature(object = "AFTSurvivalRegressionModel", path = "c invisible(callJMethod(writer, "save", path)) }) -#' Save the generalized linear model to the input path. -#' -#' @param object A fitted generalized linear model -#' @param path The directory where the model is saved -#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE +# Saves the generalized linear model to the input path. + +#' @param path the directory where the model is saved. +#' @param overwrite overwrites or not if the output path already exists. Default is FALSE #' which means throw exception if the output path exists. 
#' -#' @rdname write.ml -#' @name write.ml +#' @rdname spark.glm #' @export -#' @examples -#' \dontrun{ -#' model <- glm(y ~ x, trainingData) -#' path <- "path/to/model" -#' write.ml(model, path) -#' } +#' @note write.ml(GeneralizedLinearRegressionModel, character) since 2.0.0 setMethod("write.ml", signature(object = "GeneralizedLinearRegressionModel", path = "character"), function(object, path, overwrite = FALSE) { writer <- callJMethod(object@jobj, "write") @@ -478,22 +515,15 @@ setMethod("write.ml", signature(object = "GeneralizedLinearRegressionModel", pat invisible(callJMethod(writer, "save", path)) }) -#' Save the k-means model to the input path. -#' -#' @param object A fitted k-means model -#' @param path The directory where the model is saved -#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE +# Save fitted MLlib model to the input path + +#' @param path the directory where the model is saved. +#' @param overwrite overwrites or not if the output path already exists. Default is FALSE #' which means throw exception if the output path exists. #' -#' @rdname write.ml -#' @name write.ml +#' @rdname spark.kmeans #' @export -#' @examples -#' \dontrun{ -#' model <- spark.kmeans(trainingData, ~ ., k = 2) -#' path <- "path/to/model" -#' write.ml(model, path) -#' } +#' @note write.ml(KMeansModel, character) since 2.0.0 setMethod("write.ml", signature(object = "KMeansModel", path = "character"), function(object, path, overwrite = FALSE) { writer <- callJMethod(object@jobj, "write") @@ -505,16 +535,18 @@ setMethod("write.ml", signature(object = "KMeansModel", path = "character"), #' Load a fitted MLlib model from the input path. #' -#' @param path Path of the model to read. -#' @return a fitted MLlib model +#' @param path path of the model to read. +#' @return A fitted MLlib model. #' @rdname read.ml #' @name read.ml #' @export +#' @seealso \link{write.ml} #' @examples #' \dontrun{ #' path <- "path/to/model" #' model <- read.ml(path) #' } +#' @note read.ml since 2.0.0 read.ml <- function(path) { path <- suppressWarnings(normalizePath(path)) jobj <- callJStatic("org.apache.spark.ml.r.RWrappers", "load", path) @@ -531,48 +563,59 @@ read.ml <- function(path) { } } -#' Fit an accelerated failure time (AFT) survival regression model. +#' Accelerated Failure Time (AFT) Survival Regression Model #' -#' Fit an accelerated failure time (AFT) survival regression model on a Spark DataFrame. +#' \code{spark.survreg} fits an accelerated failure time (AFT) survival regression model on +#' a SparkDataFrame. Users can call \code{summary} to get a summary of the fitted AFT model, +#' \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml} to +#' save/load fitted models. #' -#' @param data SparkDataFrame for training. -#' @param formula A symbolic description of the model to be fitted. Currently only a few formula +#' @param data a SparkDataFrame for training. +#' @param formula a symbolic description of the model to be fitted. Currently only a few formula #' operators are supported, including '~', ':', '+', and '-'. #' Note that operator '.' is not supported currently. -#' @return a fitted AFT survival regression model +#' @return \code{spark.survreg} returns a fitted AFT survival regression model. 
#' @rdname spark.survreg -#' @seealso survival: \url{https://cran.r-project.org/web/packages/survival/} +#' @seealso survival: \url{https://cran.r-project.org/package=survival} #' @export #' @examples #' \dontrun{ -#' df <- createDataFrame(sqlContext, ovarian) +#' df <- createDataFrame(ovarian) #' model <- spark.survreg(df, Surv(futime, fustat) ~ ecog_ps + rx) +#' +#' # get a summary of the model +#' summary(model) +#' +#' # make predictions +#' predicted <- predict(model, df) +#' showDF(predicted) +#' +#' # save and load the model +#' path <- "path/to/model" +#' write.ml(model, path) +#' savedModel <- read.ml(path) +#' summary(savedModel) #' } +#' @note spark.survreg since 2.0.0 setMethod("spark.survreg", signature(data = "SparkDataFrame", formula = "formula"), - function(data, formula, ...) { + function(data, formula) { formula <- paste(deparse(formula), collapse = "") jobj <- callJStatic("org.apache.spark.ml.r.AFTSurvivalRegressionWrapper", "fit", formula, data@sdf) return(new("AFTSurvivalRegressionModel", jobj = jobj)) }) +# Returns a summary of the AFT survival regression model produced by spark.survreg, +# similarly to R's summary(). -#' Get the summary of an AFT survival regression model -#' -#' Returns the summary of an AFT survival regression model produced by spark.survreg(), -#' similarly to R's summary(). -#' -#' @param object a fitted AFT survival regression model -#' @return coefficients the model's coefficients, intercept and log(scale). -#' @rdname summary +#' @param object a fitted AFT survival regression model. +#' @return \code{summary} returns a list containing the model's coefficients, +#' intercept and log(scale) +#' @rdname spark.survreg #' @export -#' @examples -#' \dontrun{ -#' model <- spark.survreg(trainingData, Surv(futime, fustat) ~ ecog_ps + rx) -#' summary(model) -#' } +#' @note summary(AFTSurvivalRegressionModel) since 2.0.0 setMethod("summary", signature(object = "AFTSurvivalRegressionModel"), - function(object, ...) { + function(object) { jobj <- object@jobj features <- callJMethod(jobj, "rFeatures") coefficients <- callJMethod(jobj, "rCoefficients") @@ -582,22 +625,15 @@ setMethod("summary", signature(object = "AFTSurvivalRegressionModel"), return(list(coefficients = coefficients)) }) -#' Make predictions from an AFT survival regression model -#' -#' Make predictions from a model produced by spark.survreg(), -#' similarly to R package survival's predict. -#' -#' @param object A fitted AFT survival regression model -#' @param newData SparkDataFrame for testing -#' @return SparkDataFrame containing predicted labels in a column named "prediction" -#' @rdname predict +# Makes predictions from an AFT survival regression model or a model produced by +# spark.survreg, similarly to R package survival's predict. + +#' @param newData a SparkDataFrame for testing. +#' @return \code{predict} returns a SparkDataFrame containing predicted values +#' on the original scale of the data (mean predicted value at scale = 1.0). 
+#' @rdname spark.survreg #' @export -#' @examples -#' \dontrun{ -#' model <- spark.survreg(trainingData, Surv(futime, fustat) ~ ecog_ps + rx) -#' predicted <- predict(model, testData) -#' showDF(predicted) -#' } +#' @note predict(AFTSurvivalRegressionModel) since 2.0.0 setMethod("predict", signature(object = "AFTSurvivalRegressionModel"), function(object, newData) { return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf))) diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R index 4075ef4377acf..4dee3245f9b75 100644 --- a/R/pkg/R/pairRDD.R +++ b/R/pkg/R/pairRDD.R @@ -49,7 +49,7 @@ setMethod("lookup", lapply(filtered, function(i) { i[[2]] }) } valsRDD <- lapplyPartition(x, partitionFunc) - collect(valsRDD) + collectRDD(valsRDD) }) #' Count the number of elements for each key, and return the result to the @@ -85,7 +85,7 @@ setMethod("countByKey", #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list(list(1, 2), list(3, 4))) -#' collect(keys(rdd)) # list(1, 3) +#' collectRDD(keys(rdd)) # list(1, 3) #'} # nolint end #' @rdname keys @@ -108,7 +108,7 @@ setMethod("keys", #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list(list(1, 2), list(3, 4))) -#' collect(values(rdd)) # list(2, 4) +#' collectRDD(values(rdd)) # list(2, 4) #'} # nolint end #' @rdname values @@ -135,7 +135,7 @@ setMethod("values", #' sc <- sparkR.init() #' rdd <- parallelize(sc, 1:10) #' makePairs <- lapply(rdd, function(x) { list(x, x) }) -#' collect(mapValues(makePairs, function(x) { x * 2) }) +#' collectRDD(mapValues(makePairs, function(x) { x * 2) }) #' Output: list(list(1,2), list(2,4), list(3,6), ...) #'} #' @rdname mapValues @@ -162,7 +162,7 @@ setMethod("mapValues", #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list(list(1, c(1,2)), list(2, c(3,4)))) -#' collect(flatMapValues(rdd, function(x) { x })) +#' collectRDD(flatMapValues(rdd, function(x) { x })) #' Output: list(list(1,1), list(1,2), list(2,3), list(2,4)) #'} #' @rdname flatMapValues @@ -198,15 +198,17 @@ setMethod("flatMapValues", #' sc <- sparkR.init() #' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) #' rdd <- parallelize(sc, pairs) -#' parts <- partitionBy(rdd, 2L) +#' parts <- partitionByRDD(rdd, 2L) #' collectPartition(parts, 0L) # First partition should contain list(1, 2) and list(1, 4) #'} #' @rdname partitionBy #' @aliases partitionBy,RDD,integer-method #' @noRd -setMethod("partitionBy", - signature(x = "RDD", numPartitions = "numeric"), +setMethod("partitionByRDD", + signature(x = "RDD"), function(x, numPartitions, partitionFunc = hashCode) { + stopifnot(is.numeric(numPartitions)) + partitionFunc <- cleanClosure(partitionFunc) serializedHashFuncBytes <- serialize(partitionFunc, connection = NULL) @@ -259,7 +261,7 @@ setMethod("partitionBy", #' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) #' rdd <- parallelize(sc, pairs) #' parts <- groupByKey(rdd, 2L) -#' grouped <- collect(parts) +#' grouped <- collectRDD(parts) #' grouped[[1]] # Should be a list(1, list(2, 4)) #'} #' @rdname groupByKey @@ -268,7 +270,7 @@ setMethod("partitionBy", setMethod("groupByKey", signature(x = "RDD", numPartitions = "numeric"), function(x, numPartitions) { - shuffled <- partitionBy(x, numPartitions) + shuffled <- partitionByRDD(x, numPartitions) groupVals <- function(part) { vals <- new.env() keys <- new.env() @@ -319,7 +321,7 @@ setMethod("groupByKey", #' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) #' rdd <- parallelize(sc, pairs) #' parts <- reduceByKey(rdd, "+", 2L) -#' reduced <- collect(parts) +#' 
reduced <- collectRDD(parts) #' reduced[[1]] # Should be a list(1, 6) #'} #' @rdname reduceByKey @@ -340,7 +342,7 @@ setMethod("reduceByKey", convertEnvsToList(keys, vals) } locallyReduced <- lapplyPartition(x, reduceVals) - shuffled <- partitionBy(locallyReduced, numToInt(numPartitions)) + shuffled <- partitionByRDD(locallyReduced, numToInt(numPartitions)) lapplyPartition(shuffled, reduceVals) }) @@ -428,7 +430,7 @@ setMethod("reduceByKeyLocally", #' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) #' rdd <- parallelize(sc, pairs) #' parts <- combineByKey(rdd, function(x) { x }, "+", "+", 2L) -#' combined <- collect(parts) +#' combined <- collectRDD(parts) #' combined[[1]] # Should be a list(1, 6) #'} # nolint end @@ -451,7 +453,7 @@ setMethod("combineByKey", convertEnvsToList(keys, combiners) } locallyCombined <- lapplyPartition(x, combineLocally) - shuffled <- partitionBy(locallyCombined, numToInt(numPartitions)) + shuffled <- partitionByRDD(locallyCombined, numToInt(numPartitions)) mergeAfterShuffle <- function(part) { combiners <- new.env() keys <- new.env() @@ -561,13 +563,13 @@ setMethod("foldByKey", #' sc <- sparkR.init() #' rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) #' rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) -#' join(rdd1, rdd2, 2L) # list(list(1, list(1, 2)), list(1, list(1, 3)) +#' joinRDD(rdd1, rdd2, 2L) # list(list(1, list(1, 2)), list(1, list(1, 3)) #'} # nolint end #' @rdname join-methods #' @aliases join,RDD,RDD-method #' @noRd -setMethod("join", +setMethod("joinRDD", signature(x = "RDD", y = "RDD"), function(x, y, numPartitions) { xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) }) @@ -770,7 +772,7 @@ setMethod("cogroup", #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list(list(3, 1), list(2, 2), list(1, 3))) -#' collect(sortByKey(rdd)) # list (list(1, 3), list(2, 2), list(3, 1)) +#' collectRDD(sortByKey(rdd)) # list (list(1, 3), list(2, 2), list(3, 1)) #'} # nolint end #' @rdname sortByKey @@ -782,12 +784,12 @@ setMethod("sortByKey", rangeBounds <- list() if (numPartitions > 1) { - rddSize <- count(x) + rddSize <- countRDD(x) # constant from Spark's RangePartitioner maxSampleSize <- numPartitions * 20 fraction <- min(maxSampleSize / max(rddSize, 1), 1.0) - samples <- collect(keys(sampleRDD(x, FALSE, fraction, 1L))) + samples <- collectRDD(keys(sampleRDD(x, FALSE, fraction, 1L))) # Note: the built-in R sort() function only works on atomic vectors samples <- sort(unlist(samples, recursive = FALSE), decreasing = !ascending) @@ -820,7 +822,7 @@ setMethod("sortByKey", sortKeyValueList(part, decreasing = !ascending) } - newRDD <- partitionBy(x, numPartitions, rangePartitionFunc) + newRDD <- partitionByRDD(x, numPartitions, rangePartitionFunc) lapplyPartition(newRDD, partitionFunc) }) @@ -839,7 +841,7 @@ setMethod("sortByKey", #' rdd1 <- parallelize(sc, list(list("a", 1), list("b", 4), #' list("b", 5), list("a", 2))) #' rdd2 <- parallelize(sc, list(list("a", 3), list("c", 1))) -#' collect(subtractByKey(rdd1, rdd2)) +#' collectRDD(subtractByKey(rdd1, rdd2)) #' # list(list("b", 4), list("b", 5)) #'} # nolint end @@ -915,19 +917,19 @@ setMethod("sampleByKey", len <- 0 # mixing because the initial seeds are close to each other - runif(10) + stats::runif(10) for (elem in part) { if (elem[[1]] %in% names(fractions)) { frac <- as.numeric(fractions[which(elem[[1]] == names(fractions))]) if (withReplacement) { - count <- rpois(1, frac) + count <- stats::rpois(1, frac) if (count > 0) { res[ (len + 1) : (len + count) ] <- 
rep(list(elem), count) len <- len + count } } else { - if (runif(1) < frac) { + if (stats::runif(1) < frac) { len <- len + 1 res[[len]] <- elem } diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R index 039aa008b3edd..cb5bdb90175bf 100644 --- a/R/pkg/R/schema.R +++ b/R/pkg/R/schema.R @@ -26,26 +26,34 @@ #' @param x a structField object (created with the field() function) #' @param ... additional structField objects #' @return a structType object +#' @rdname structType #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) -#' rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) -#' schema <- structType(structField("a", "integer"), structField("b", "string")) -#' df <- createDataFrame(sqlCtx, rdd, schema) +#' schema <- structType(structField("a", "integer"), structField("c", "string"), +#' structField("avg", "double")) +#' df1 <- gapply(df, list("a", "c"), +#' function(key, x) { y <- data.frame(key, mean(x$b), stringsAsFactors = FALSE) }, +#' schema) #' } +#' @note structType since 1.4.0 structType <- function(x, ...) { UseMethod("structType", x) } -structType.jobj <- function(x) { +#' @rdname structType +#' @method structType jobj +#' @export +structType.jobj <- function(x, ...) { obj <- structure(list(), class = "structType") obj$jobj <- x obj$fields <- function() { lapply(callJMethod(obj$jobj, "fields"), structField) } obj } +#' @rdname structType +#' @method structType structField +#' @export structType.structField <- function(x, ...) { fields <- list(x, ...) if (!all(sapply(fields, inherits, "structField"))) { @@ -67,6 +75,7 @@ structType.structField <- function(x, ...) { #' #' @param x A StructType object #' @param ... further arguments passed to or from other methods +#' @note print.structType since 1.4.0 print.structType <- function(x, ...) { cat("StructType\n", sapply(x$fields(), @@ -83,27 +92,30 @@ print.structType <- function(x, ...) { #' #' Create a structField object that contains the metadata for a single field in a schema. #' -#' @param x The name of the field -#' @param type The data type of the field -#' @param nullable A logical vector indicating whether or not the field is nullable -#' @return a structField object +#' @param x the name of the field. +#' @param ... additional argument(s) passed to the method. +#' @return A structField object. +#' @rdname structField #' @export #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) -#' rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) -#' field1 <- structField("a", "integer", TRUE) -#' field2 <- structField("b", "string", TRUE) -#' schema <- structType(field1, field2) -#' df <- createDataFrame(sqlCtx, rdd, schema) +#' field1 <- structField("a", "integer") +#' field2 <- structField("c", "string") +#' field3 <- structField("avg", "double") +#' schema <- structType(field1, field2, field3) +#' df1 <- gapply(df, list("a", "c"), +#' function(key, x) { y <- data.frame(key, mean(x$b), stringsAsFactors = FALSE) }, +#' schema) #' } - +#' @note structField since 1.4.0 structField <- function(x, ...) { UseMethod("structField", x) } -structField.jobj <- function(x) { +#' @rdname structField +#' @method structField jobj +#' @export +structField.jobj <- function(x, ...) 
{ obj <- structure(list(), class = "structField") obj$jobj <- x obj$name <- function() { callJMethod(x, "name") } @@ -174,7 +186,11 @@ checkType <- function(type) { stop(paste("Unsupported type for SparkDataframe:", type)) } -structField.character <- function(x, type, nullable = TRUE) { +#' @param type The data type of the field +#' @param nullable A logical vector indicating whether or not the field is nullable +#' @rdname structField +#' @export +structField.character <- function(x, type, nullable = TRUE, ...) { if (class(x) != "character") { stop("Field name must be a string.") } @@ -202,6 +218,7 @@ structField.character <- function(x, type, nullable = TRUE) { #' #' @param x A StructField object #' @param ... further arguments passed to or from other methods +#' @note print.structField since 1.4.0 print.structField <- function(x, ...) { cat("StructField(name = \"", x$name(), "\", type = \"", x$dataType.toString(), diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index c187869fdf121..cc6d591bb2f4c 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -28,10 +28,16 @@ connExists <- function(env) { }) } -#' Stop the Spark context. +#' Stop the Spark Session and Spark Context #' -#' Also terminates the backend this R session is connected to -sparkR.stop <- function() { +#' Stop the Spark Session and Spark Context. +#' +#' Also terminates the backend this R session is connected to. +#' @rdname sparkR.session.stop +#' @name sparkR.session.stop +#' @export +#' @note sparkR.session.stop since 2.0.0 +sparkR.session.stop <- function() { env <- .sparkREnv if (exists(".sparkRCon", envir = env)) { if (exists(".sparkRjsc", envir = env)) { @@ -39,12 +45,8 @@ sparkR.stop <- function() { callJMethod(sc, "stop") rm(".sparkRjsc", envir = env) - if (exists(".sparkRSQLsc", envir = env)) { - rm(".sparkRSQLsc", envir = env) - } - - if (exists(".sparkRHivesc", envir = env)) { - rm(".sparkRHivesc", envir = env) + if (exists(".sparkRsession", envir = env)) { + rm(".sparkRsession", envir = env) } } @@ -80,11 +82,17 @@ sparkR.stop <- function() { clearJobjs() } -#' Initialize a new Spark Context. +#' @rdname sparkR.session.stop +#' @name sparkR.stop +#' @export +#' @note sparkR.stop since 1.4.0 +sparkR.stop <- function() { + sparkR.session.stop() +} + +#' (Deprecated) Initialize a new Spark Context #' -#' This function initializes a new SparkContext. For details on how to initialize -#' and use SparkR, refer to SparkR programming guide at -#' \url{http://spark.apache.org/docs/latest/sparkr.html#starting-up-sparkcontext-sqlcontext}. +#' This function initializes a new SparkContext. 
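As a rough sketch of the migration this deprecation points to: the old `sparkR.init()` entry point still works but now warns and forwards to the new session API, `sparkR.session()` is the replacement, and `sparkR.session.stop()` tears either one down. The master URL below is an illustrative assumption, not part of the patch.

```R
library(SparkR)

# Deprecated 1.x entry point: still returns a SparkContext handle,
# but now emits a deprecation warning pointing to sparkR.session().
sc <- sparkR.init(master = "local[2]", appName = "SparkR")
sparkR.session.stop()

# 2.0 entry point: gets or creates a SparkSession (and the context under it).
spark <- sparkR.session(master = "local[2]", appName = "SparkR")
sparkR.session.stop()
```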
#' #' @param master The Spark master URL #' @param appName Application name to register with cluster manager @@ -92,7 +100,9 @@ sparkR.stop <- function() { #' @param sparkEnvir Named list of environment variables to set on worker nodes #' @param sparkExecutorEnv Named list of environment variables to be used when launching executors #' @param sparkJars Character vector of jar files to pass to the worker nodes -#' @param sparkPackages Character vector of packages from spark-packages.org +#' @param sparkPackages Character vector of package coordinates +#' @seealso \link{sparkR.session} +#' @rdname sparkR.init-deprecated #' @export #' @examples #'\dontrun{ @@ -103,10 +113,9 @@ sparkR.stop <- function() { #' list(spark.executor.memory="4g"), #' list(LD_LIBRARY_PATH="/directory of JVM libraries (libjvm.so) on workers/"), #' c("one.jar", "two.jar", "three.jar"), -#' c("com.databricks:spark-avro_2.10:2.0.1", -#' "com.databricks:spark-csv_2.10:1.3.0")) +#' c("com.databricks:spark-avro_2.10:2.0.1")) #'} - +#' @note sparkR.init since 1.4.0 sparkR.init <- function( master = "", appName = "SparkR", @@ -115,20 +124,41 @@ sparkR.init <- function( sparkExecutorEnv = list(), sparkJars = "", sparkPackages = "") { + .Deprecated("sparkR.session") + sparkR.sparkContext(master, + appName, + sparkHome, + convertNamedListToEnv(sparkEnvir), + convertNamedListToEnv(sparkExecutorEnv), + sparkJars, + sparkPackages) +} + +# Internal function to handle creating the SparkContext. +sparkR.sparkContext <- function( + master = "", + appName = "SparkR", + sparkHome = Sys.getenv("SPARK_HOME"), + sparkEnvirMap = new.env(), + sparkExecutorEnvMap = new.env(), + sparkJars = "", + sparkPackages = "") { if (exists(".sparkRjsc", envir = .sparkREnv)) { cat(paste("Re-using existing Spark Context.", - "Please stop SparkR with sparkR.stop() or restart R to create a new Spark Context\n")) + "Call sparkR.session.stop() or restart R to create a new Spark Context\n")) return(get(".sparkRjsc", envir = .sparkREnv)) } jars <- processSparkJars(sparkJars) packages <- processSparkPackages(sparkPackages) - sparkEnvirMap <- convertNamedListToEnv(sparkEnvir) - existingPort <- Sys.getenv("EXISTING_SPARKR_BACKEND_PORT", "") if (existingPort != "") { + if (length(packages) != 0) { + warning(paste("sparkPackages has no effect when using spark-submit or sparkR shell", + " please use the --packages commandline instead", sep = ",")) + } backendPort <- existingPort } else { path <- tempfile(pattern = "backend_port") @@ -184,7 +214,6 @@ sparkR.init <- function( sparkHome <- suppressWarnings(normalizePath(sparkHome)) } - sparkExecutorEnvMap <- convertNamedListToEnv(sparkExecutorEnv) if (is.null(sparkExecutorEnvMap$LD_LIBRARY_PATH)) { sparkExecutorEnvMap[["LD_LIBRARY_PATH"]] <- paste0("$LD_LIBRARY_PATH:", Sys.getenv("LD_LIBRARY_PATH")) @@ -226,116 +255,235 @@ sparkR.init <- function( sc } -#' Initialize a new SQLContext. +#' (Deprecated) Initialize a new SQLContext #' #' This function creates a SparkContext from an existing JavaSparkContext and #' then uses it to initialize a new SQLContext #' +#' Starting SparkR 2.0, a SparkSession is initialized and returned instead. +#' This API is deprecated and kept for backward compatibility only. 
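A minimal sketch of the backward-compatibility behavior described above: the deprecated SQLContext initializer now delegates to `sparkR.session`, so both paths end up with a SparkSession.

```R
library(SparkR)

# Deprecated: warns and returns a SparkSession rather than a SQLContext.
sc <- sparkR.init()
sqlContext <- sparkRSQL.init(sc)  # delegates to sparkR.session(enableHiveSupport = FALSE)

# Preferred: ask for the session directly.
spark <- sparkR.session(enableHiveSupport = FALSE)
```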
+#' #' @param jsc The existing JavaSparkContext created with SparkR.init() +#' @seealso \link{sparkR.session} +#' @rdname sparkRSQL.init-deprecated #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRSQL.init(sc) #'} - +#' @note sparkRSQL.init since 1.4.0 sparkRSQL.init <- function(jsc = NULL) { - if (exists(".sparkRSQLsc", envir = .sparkREnv)) { - return(get(".sparkRSQLsc", envir = .sparkREnv)) - } + .Deprecated("sparkR.session") - # If jsc is NULL, create a Spark Context - sc <- if (is.null(jsc)) { - sparkR.init() - } else { - jsc + if (exists(".sparkRsession", envir = .sparkREnv)) { + return(get(".sparkRsession", envir = .sparkREnv)) } - sqlContext <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", - "createSQLContext", - sc) - assign(".sparkRSQLsc", sqlContext, envir = .sparkREnv) - sqlContext + # Default to without Hive support for backward compatibility. + sparkR.session(enableHiveSupport = FALSE) } -#' Initialize a new HiveContext. +#' (Deprecated) Initialize a new HiveContext #' #' This function creates a HiveContext from an existing JavaSparkContext #' +#' Starting SparkR 2.0, a SparkSession is initialized and returned instead. +#' This API is deprecated and kept for backward compatibility only. +#' #' @param jsc The existing JavaSparkContext created with SparkR.init() +#' @seealso \link{sparkR.session} +#' @rdname sparkRHive.init-deprecated #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() #' sqlContext <- sparkRHive.init(sc) #'} - +#' @note sparkRHive.init since 1.4.0 sparkRHive.init <- function(jsc = NULL) { - if (exists(".sparkRHivesc", envir = .sparkREnv)) { - return(get(".sparkRHivesc", envir = .sparkREnv)) + .Deprecated("sparkR.session") + + if (exists(".sparkRsession", envir = .sparkREnv)) { + return(get(".sparkRsession", envir = .sparkREnv)) } - # If jsc is NULL, create a Spark Context - sc <- if (is.null(jsc)) { - sparkR.init() - } else { - jsc + # Default to without Hive support for backward compatibility. + sparkR.session(enableHiveSupport = TRUE) +} + +#' Get the existing SparkSession or initialize a new SparkSession. +#' +#' SparkSession is the entry point into SparkR. \code{sparkR.session} gets the existing +#' SparkSession or initializes a new SparkSession. +#' Additional Spark properties can be set in \code{...}, and these named parameters take priority +#' over values in \code{master}, \code{appName}, named lists of \code{sparkConfig}. +#' +#' For details on how to initialize and use SparkR, refer to SparkR programming guide at +#' \url{http://spark.apache.org/docs/latest/sparkr.html#starting-up-sparksession}. +#' +#' @param master the Spark master URL. +#' @param appName application name to register with cluster manager. +#' @param sparkHome Spark Home directory. +#' @param sparkConfig named list of Spark configuration to set on worker nodes. +#' @param sparkJars character vector of jar files to pass to the worker nodes. +#' @param sparkPackages character vector of package coordinates +#' @param enableHiveSupport enable support for Hive, fallback if not built with Hive support; once +#' set, this cannot be turned off on an existing session +#' @param ... named Spark properties passed to the method. 
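A short sketch of the precedence rule stated above: `spark.*` properties passed through `...` are merged into `sparkConfig` and win over the positional `master` and `appName` arguments. All values shown are illustrative.

```R
library(SparkR)

spark <- sparkR.session(
  master = "local[2]",                            # used unless spark.master is passed
  appName = "ignored",                            # overridden by spark.app.name below
  sparkConfig = list(spark.driver.memory = "1g"),
  spark.app.name = "SparkRExample",               # named property takes priority
  spark.executor.memory = "2g"                    # merged into the session config
)
```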
+#' @export +#' @examples +#'\dontrun{ +#' sparkR.session() +#' df <- read.json(path) +#' +#' sparkR.session("local[2]", "SparkR", "/home/spark") +#' sparkR.session("yarn-client", "SparkR", "/home/spark", +#' list(spark.executor.memory="4g"), +#' c("one.jar", "two.jar", "three.jar"), +#' c("com.databricks:spark-avro_2.10:2.0.1")) +#' sparkR.session(spark.master = "yarn-client", spark.executor.memory = "4g") +#'} +#' @note sparkR.session since 2.0.0 +sparkR.session <- function( + master = "", + appName = "SparkR", + sparkHome = Sys.getenv("SPARK_HOME"), + sparkConfig = list(), + sparkJars = "", + sparkPackages = "", + enableHiveSupport = TRUE, + ...) { + + sparkConfigMap <- convertNamedListToEnv(sparkConfig) + namedParams <- list(...) + if (length(namedParams) > 0) { + paramMap <- convertNamedListToEnv(namedParams) + # Override for certain named parameters + if (exists("spark.master", envir = paramMap)) { + master <- paramMap[["spark.master"]] + } + if (exists("spark.app.name", envir = paramMap)) { + appName <- paramMap[["spark.app.name"]] + } + overrideEnvs(sparkConfigMap, paramMap) } - ssc <- callJMethod(sc, "sc") - hiveCtx <- tryCatch({ - newJObject("org.apache.spark.sql.hive.HiveContext", ssc) - }, - error = function(err) { - stop("Spark SQL is not built with Hive support") - }) + if (!exists(".sparkRjsc", envir = .sparkREnv)) { + retHome <- sparkCheckInstall(sparkHome, master) + if (!is.null(retHome)) sparkHome <- retHome + sparkExecutorEnvMap <- new.env() + sparkR.sparkContext(master, appName, sparkHome, sparkConfigMap, sparkExecutorEnvMap, + sparkJars, sparkPackages) + stopifnot(exists(".sparkRjsc", envir = .sparkREnv)) + } - assign(".sparkRHivesc", hiveCtx, envir = .sparkREnv) - hiveCtx + if (exists(".sparkRsession", envir = .sparkREnv)) { + sparkSession <- get(".sparkRsession", envir = .sparkREnv) + # Apply config to Spark Context and Spark Session if already there + # Cannot change enableHiveSupport + callJStatic("org.apache.spark.sql.api.r.SQLUtils", + "setSparkContextSessionConf", + sparkSession, + sparkConfigMap) + } else { + jsc <- get(".sparkRjsc", envir = .sparkREnv) + sparkSession <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", + "getOrCreateSparkSession", + jsc, + sparkConfigMap, + enableHiveSupport) + assign(".sparkRsession", sparkSession, envir = .sparkREnv) + } + sparkSession } #' Assigns a group ID to all the jobs started by this thread until the group ID is set to a #' different value or cleared. #' -#' @param sc existing spark context -#' @param groupid the ID to be assigned to job groups -#' @param description description for the job group ID -#' @param interruptOnCancel flag to indicate if the job is interrupted on job cancellation +#' @param groupId the ID to be assigned to job groups. +#' @param description description for the job group ID. +#' @param interruptOnCancel flag to indicate if the job is interrupted on job cancellation. 
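To make the compatibility shim below easier to follow, a brief sketch of the two supported call shapes: the new form takes no SparkContext and looks up the active one internally, while the old form that passes a context keeps working but triggers a deprecation warning.

```R
library(SparkR)
sparkR.session()

# New form: operate on the active context implicitly.
setJobGroup("myJobGroup", "My job group description", TRUE)
cancelJobGroup("myJobGroup")
clearJobGroup()

# Old form (deprecated): passing a JavaSparkContext jobj still works but warns.
# sc <- suppressWarnings(sparkR.init())
# suppressWarnings(setJobGroup(sc, "myJobGroup", "My job group description", TRUE))

sparkR.session.stop()
```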
+#' @rdname setJobGroup +#' @name setJobGroup #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' setJobGroup(sc, "myJobGroup", "My job group description", TRUE) +#' sparkR.session() +#' setJobGroup("myJobGroup", "My job group description", TRUE) #'} +#' @note setJobGroup since 1.5.0 +#' @method setJobGroup default +setJobGroup.default <- function(groupId, description, interruptOnCancel) { + sc <- getSparkContext() + callJMethod(sc, "setJobGroup", groupId, description, interruptOnCancel) +} setJobGroup <- function(sc, groupId, description, interruptOnCancel) { - callJMethod(sc, "setJobGroup", groupId, description, interruptOnCancel) + if (class(sc) == "jobj" && any(grepl("JavaSparkContext", getClassName.jobj(sc)))) { + .Deprecated("setJobGroup(groupId, description, interruptOnCancel)", + old = "setJobGroup(sc, groupId, description, interruptOnCancel)") + setJobGroup.default(groupId, description, interruptOnCancel) + } else { + # Parameter order is shifted + groupIdToUse <- sc + descriptionToUse <- groupId + interruptOnCancelToUse <- description + setJobGroup.default(groupIdToUse, descriptionToUse, interruptOnCancelToUse) + } } #' Clear current job group ID and its description #' -#' @param sc existing spark context +#' @rdname clearJobGroup +#' @name clearJobGroup #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' clearJobGroup(sc) +#' sparkR.session() +#' clearJobGroup() #'} +#' @note clearJobGroup since 1.5.0 +#' @method clearJobGroup default +clearJobGroup.default <- function() { + sc <- getSparkContext() + callJMethod(sc, "clearJobGroup") +} clearJobGroup <- function(sc) { - callJMethod(sc, "clearJobGroup") + if (!missing(sc) && + class(sc) == "jobj" && + any(grepl("JavaSparkContext", getClassName.jobj(sc)))) { + .Deprecated("clearJobGroup()", old = "clearJobGroup(sc)") + } + clearJobGroup.default() } + #' Cancel active jobs for the specified group #' -#' @param sc existing spark context #' @param groupId the ID of job group to be cancelled +#' @rdname cancelJobGroup +#' @name cancelJobGroup #' @examples #'\dontrun{ -#' sc <- sparkR.init() -#' cancelJobGroup(sc, "myJobGroup") +#' sparkR.session() +#' cancelJobGroup("myJobGroup") #'} +#' @note cancelJobGroup since 1.5.0 +#' @method cancelJobGroup default +cancelJobGroup.default <- function(groupId) { + sc <- getSparkContext() + callJMethod(sc, "cancelJobGroup", groupId) +} cancelJobGroup <- function(sc, groupId) { - callJMethod(sc, "cancelJobGroup", groupId) + if (class(sc) == "jobj" && any(grepl("JavaSparkContext", getClassName.jobj(sc)))) { + .Deprecated("cancelJobGroup(groupId)", old = "cancelJobGroup(sc, groupId)") + cancelJobGroup.default(groupId) + } else { + # Parameter order is shifted + groupIdToUse <- sc + cancelJobGroup.default(groupIdToUse) + } } sparkConfToSubmitOps <- new.env() @@ -343,6 +491,10 @@ sparkConfToSubmitOps[["spark.driver.memory"]] <- "--driver-memory" sparkConfToSubmitOps[["spark.driver.extraClassPath"]] <- "--driver-class-path" sparkConfToSubmitOps[["spark.driver.extraJavaOptions"]] <- "--driver-java-options" sparkConfToSubmitOps[["spark.driver.extraLibraryPath"]] <- "--driver-library-path" +sparkConfToSubmitOps[["spark.master"]] <- "--master" +sparkConfToSubmitOps[["spark.yarn.keytab"]] <- "--keytab" +sparkConfToSubmitOps[["spark.yarn.principal"]] <- "--principal" + # Utility function that returns Spark Submit arguments as a string # @@ -386,3 +538,35 @@ processSparkPackages <- function(packages) { } splittedPackages } + +# Utility function that checks and install Spark to local folder if not found +# +# 
Installation will not be triggered if it's called from sparkR shell +# or if the master url is not local +# +# @param sparkHome directory to find Spark package. +# @param master the Spark master URL, used to check local or remote mode. +# @return NULL if no need to update sparkHome, and new sparkHome otherwise. +sparkCheckInstall <- function(sparkHome, master) { + if (!isSparkRShell()) { + if (!is.na(file.info(sparkHome)$isdir)) { + msg <- paste0("Spark package found in SPARK_HOME: ", sparkHome) + message(msg) + NULL + } else { + if (!nzchar(master) || isMasterLocal(master)) { + msg <- paste0("Spark not found in SPARK_HOME: ", + sparkHome) + message(msg) + packageLocalDir <- install.spark() + packageLocalDir + } else { + msg <- paste0("Spark not found in SPARK_HOME: ", + sparkHome, "\n", installInstruction("remote")) + stop(msg) + } + } + } else { + NULL + } +} diff --git a/R/pkg/R/stats.R b/R/pkg/R/stats.R index 879b664421316..dcd7198f41ea7 100644 --- a/R/pkg/R/stats.R +++ b/R/pkg/R/stats.R @@ -19,27 +19,31 @@ setOldClass("jobj") -#' crosstab +#' Computes a pair-wise frequency table of the given columns #' #' Computes a pair-wise frequency table of the given columns. Also known as a contingency #' table. The number of distinct values for each column should be less than 1e4. At most 1e6 #' non-zero pair frequencies will be returned. #' +#' @param x a SparkDataFrame #' @param col1 name of the first column. Distinct items will make the first item of each row. #' @param col2 name of the second column. Distinct items will make the column names of the output. #' @return a local R data.frame representing the contingency table. The first column of each row -#' will be the distinct values of `col1` and the column names will be the distinct values -#' of `col2`. The name of the first column will be `$col1_$col2`. Pairs that have no -#' occurrences will have zero as their counts. +#' will be the distinct values of \code{col1} and the column names will be the distinct values +#' of \code{col2}. The name of the first column will be "\code{col1}_\code{col2}". Pairs +#' that have no occurrences will have zero as their counts. #' -#' @rdname statfunctions +#' @rdname crosstab #' @name crosstab +#' @aliases crosstab,SparkDataFrame,character,character-method +#' @family stat functions #' @export #' @examples #' \dontrun{ -#' df <- jsonFile(sqlContext, "/path/to/file.json") +#' df <- read.json("/path/to/file.json") #' ct <- crosstab(df, "title", "gender") #' } +#' @note crosstab since 1.5.0 setMethod("crosstab", signature(x = "SparkDataFrame", col1 = "character", col2 = "character"), function(x, col1, col2) { @@ -48,62 +52,63 @@ setMethod("crosstab", collect(dataFrame(sct)) }) -#' cov -#' #' Calculate the sample covariance of two numerical columns of a SparkDataFrame. #' -#' @param x A SparkDataFrame -#' @param col1 the name of the first column -#' @param col2 the name of the second column -#' @return the covariance of the two columns. +#' @param colName1 the name of the first column +#' @param colName2 the name of the second column +#' @return The covariance of the two columns. 
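A hedged usage sketch for the renamed arguments: the covariance (and correlation, below) helpers now document their column arguments as `colName1`/`colName2`. The input path and column names are placeholders.

```R
library(SparkR)
sparkR.session()

# Hypothetical dataset with two numeric columns.
df <- read.json("/path/to/file.json")

# Sample covariance and Pearson correlation of the two columns,
# passing the column names as colName1 / colName2.
covariance  <- cov(df, "price", "quantity")
correlation <- corr(df, "price", "quantity", method = "pearson")
```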
#' -#' @rdname statfunctions +#' @rdname cov #' @name cov +#' @aliases cov,SparkDataFrame-method +#' @family stat functions #' @export #' @examples #'\dontrun{ -#' df <- jsonFile(sqlContext, "/path/to/file.json") +#' df <- read.json("/path/to/file.json") #' cov <- cov(df, "title", "gender") #' } +#' @note cov since 1.6.0 setMethod("cov", signature(x = "SparkDataFrame"), - function(x, col1, col2) { - stopifnot(class(col1) == "character" && class(col2) == "character") + function(x, colName1, colName2) { + stopifnot(class(colName1) == "character" && class(colName2) == "character") statFunctions <- callJMethod(x@sdf, "stat") - callJMethod(statFunctions, "cov", col1, col2) + callJMethod(statFunctions, "cov", colName1, colName2) }) -#' corr -#' #' Calculates the correlation of two columns of a SparkDataFrame. #' Currently only supports the Pearson Correlation Coefficient. #' For Spearman Correlation, consider using RDD methods found in MLlib's Statistics. #' -#' @param x A SparkDataFrame -#' @param col1 the name of the first column -#' @param col2 the name of the second column +#' @param colName1 the name of the first column +#' @param colName2 the name of the second column #' @param method Optional. A character specifying the method for calculating the correlation. #' only "pearson" is allowed now. #' @return The Pearson Correlation Coefficient as a Double. #' -#' @rdname statfunctions +#' @rdname corr #' @name corr +#' @aliases corr,SparkDataFrame-method +#' @family stat functions #' @export #' @examples #'\dontrun{ -#' df <- jsonFile(sqlContext, "/path/to/file.json") +#' df <- read.json("/path/to/file.json") #' corr <- corr(df, "title", "gender") #' corr <- corr(df, "title", "gender", method = "pearson") #' } +#' @note corr since 1.6.0 setMethod("corr", signature(x = "SparkDataFrame"), - function(x, col1, col2, method = "pearson") { - stopifnot(class(col1) == "character" && class(col2) == "character") + function(x, colName1, colName2, method = "pearson") { + stopifnot(class(colName1) == "character" && class(colName2) == "character") statFunctions <- callJMethod(x@sdf, "stat") - callJMethod(statFunctions, "corr", col1, col2, method) + callJMethod(statFunctions, "corr", colName1, colName2, method) }) -#' freqItems + +#' Finding frequent items for columns, possibly with false positives #' #' Finding frequent items for columns, possibly with false positives. #' Using the frequent element count algorithm described in @@ -111,18 +116,21 @@ setMethod("corr", #' #' @param x A SparkDataFrame. #' @param cols A vector column names to search frequent items in. -#' @param support (Optional) The minimum frequency for an item to be considered `frequent`. +#' @param support (Optional) The minimum frequency for an item to be considered \code{frequent}. #' Should be greater than 1e-4. Default support = 0.01. 
#' @return a local R data.frame with the frequent items in each column #' -#' @rdname statfunctions +#' @rdname freqItems #' @name freqItems +#' @aliases freqItems,SparkDataFrame,character-method +#' @family stat functions #' @export #' @examples #' \dontrun{ -#' df <- jsonFile(sqlContext, "/path/to/file.json") +#' df <- read.json("/path/to/file.json") #' fi = freqItems(df, c("title", "gender")) #' } +#' @note freqItems since 1.6.0 setMethod("freqItems", signature(x = "SparkDataFrame", cols = "character"), function(x, cols, support = 0.01) { statFunctions <- callJMethod(x@sdf, "stat") @@ -130,14 +138,13 @@ setMethod("freqItems", signature(x = "SparkDataFrame", cols = "character"), collect(dataFrame(sct)) }) -#' approxQuantile +#' Calculates the approximate quantiles of a numerical column of a SparkDataFrame #' #' Calculates the approximate quantiles of a numerical column of a SparkDataFrame. -#' #' The result of this algorithm has the following deterministic bound: -#' If the SparkDataFrame has N elements and if we request the quantile at probability `p` up to -#' error `err`, then the algorithm will return a sample `x` from the SparkDataFrame so that the -#' *exact* rank of `x` is close to (p * N). More precisely, +#' If the SparkDataFrame has N elements and if we request the quantile at probability p up to +#' error err, then the algorithm will return a sample x from the SparkDataFrame so that the +#' *exact* rank of x is close to (p * N). More precisely, #' floor((p - err) * N) <= rank(x) <= ceil((p + err) * N). #' This method implements a variation of the Greenwald-Khanna algorithm (with some speed #' optimizations). The algorithm was first present in [[http://dx.doi.org/10.1145/375663.375670 @@ -152,14 +159,17 @@ setMethod("freqItems", signature(x = "SparkDataFrame", cols = "character"), #' Note that values greater than 1 are accepted but give the same result as 1. #' @return The approximate quantiles at the given probabilities. #' -#' @rdname statfunctions +#' @rdname approxQuantile #' @name approxQuantile +#' @aliases approxQuantile,SparkDataFrame,character,numeric,numeric-method +#' @family stat functions #' @export #' @examples #' \dontrun{ -#' df <- jsonFile(sqlContext, "/path/to/file.json") +#' df <- read.json("/path/to/file.json") #' quantiles <- approxQuantile(df, "key", c(0.5, 0.8), 0.0) #' } +#' @note approxQuantile since 2.0.0 setMethod("approxQuantile", signature(x = "SparkDataFrame", col = "character", probabilities = "numeric", relativeError = "numeric"), @@ -169,9 +179,10 @@ setMethod("approxQuantile", as.list(probabilities), relativeError) }) -#' sampleBy +#' Returns a stratified sample without replacement #' -#' Returns a stratified sample without replacement based on the fraction given on each stratum. +#' Returns a stratified sample without replacement based on the fraction given on each +#' stratum. 
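To make the approxQuantile error bound quoted above concrete, a small worked example (all numbers are illustrative): for N rows, requested probability p, and relative error err, the rank of the returned value is guaranteed to fall in the window computed below.

```R
# Illustrative values only.
N   <- 10000   # number of rows in the SparkDataFrame
p   <- 0.5     # requested quantile probability
err <- 0.01    # relativeError passed to approxQuantile

# Guaranteed bound on the exact rank of the returned sample x:
lower <- floor((p - err) * N)     # 4900
upper <- ceiling((p + err) * N)   # 5100
c(lower, upper)
```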
#' #' @param x A SparkDataFrame #' @param col column that defines strata @@ -180,14 +191,17 @@ setMethod("approxQuantile", #' @param seed random seed #' @return A new SparkDataFrame that represents the stratified sample #' -#' @rdname statfunctions +#' @rdname sampleBy +#' @aliases sampleBy,SparkDataFrame,character,list,numeric-method #' @name sampleBy +#' @family stat functions #' @export #' @examples #'\dontrun{ -#' df <- jsonFile(sqlContext, "/path/to/file.json") +#' df <- read.json("/path/to/file.json") #' sample <- sampleBy(df, "key", fractions, 36) #' } +#' @note sampleBy since 1.6.0 setMethod("sampleBy", signature(x = "SparkDataFrame", col = "character", fractions = "list", seed = "numeric"), diff --git a/R/pkg/R/types.R b/R/pkg/R/types.R index ad048b1cd1795..abca703617c7b 100644 --- a/R/pkg/R/types.R +++ b/R/pkg/R/types.R @@ -67,3 +67,19 @@ rToSQLTypes <- as.environment(list( "double" = "double", "character" = "string", "logical" = "boolean")) + +# Helper function for converting decimal types. When the backend returns a column type in the +# format of decimal(,) (e.g., decimal(10, 0)), this function converts the column type +# to double. This function converts backend-returned types that are not keys +# of PRIMITIVE_TYPES, but should be treated as PRIMITIVE_TYPES. +# @param type A type returned from the JVM backend. +# @return A key of PRIMITIVE_TYPES, or NULL if no special handling is needed. +specialtypeshandle <- function(type) { + returntype <- NULL + m <- regexec("^decimal(.+)$", type) + matchedStrings <- regmatches(type, m) + if (length(matchedStrings[[1]]) >= 2) { + returntype <- "double" + } + returntype +} diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index bf67e231d56f5..248c57532b6cf 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -110,9 +110,12 @@ isRDD <- function(name, env) { #' @return the hash code as an integer #' @export #' @examples +#'\dontrun{ #' hashCode(1L) # 1 #' hashCode(1.0) # 1072693248 #' hashCode("1") # 49 +#'} +#' @note hashCode since 1.4.0 hashCode <- function(key) { if (class(key) == "integer") { as.integer(key[[1]]) @@ -123,20 +126,16 @@ hashCode <- function(key) { as.integer(bitwXor(intBits[2], intBits[1])) } else if (class(key) == "character") { # TODO: SPARK-7839 means we might not have the native library available - if (is.loaded("stringHashCode")) { - .Call("stringHashCode", key) + n <- nchar(key) + if (n == 0) { + 0L } else { - n <- nchar(key) - if (n == 0) { - 0L - } else { - asciiVals <- sapply(charToRaw(key), function(x) { strtoi(x, 16L) }) - hashC <- 0 - for (k in 1:length(asciiVals)) { - hashC <- mult31AndAdd(hashC, asciiVals[k]) - } - as.integer(hashC) + asciiVals <- sapply(charToRaw(key), function(x) { strtoi(x, 16L) }) + hashC <- 0 + for (k in 1:length(asciiVals)) { + hashC <- mult31AndAdd(hashC, asciiVals[k]) } + as.integer(hashC) } } else { warning(paste("Could not hash object, returning 0", sep = "")) @@ -157,8 +156,11 @@ wrapInt <- function(value) { # Multiply `val` by 31 and add `addVal` to the result. Ensures that # integer-overflows are handled at every step. 
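A worked illustration of how the string branch of `hashCode` above composes with `mult31AndAdd`, mirroring Java's `String.hashCode`; the two-character key is just an example.

```R
# For the key "ab", charToRaw/strtoi yield the ASCII codes 97 ("a") and 98 ("b").
# The loop folds them left to right as hash = hash * 31 + code:
h <- 0
h <- h * 31 + 97   # 97
h <- h * 31 + 98   # 97 * 31 + 98 = 3105
h                  # same value as Java's "ab".hashCode()
```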
+# +# TODO: this function does not handle integer overflow well mult31AndAdd <- function(val, addVal) { vec <- c(bitwShiftL(val, c(4, 3, 2, 1, 0)), addVal) + vec[is.na(vec)] <- 0 Reduce(function(a, b) { wrapInt(as.numeric(a) + as.numeric(b)) }, @@ -312,6 +314,15 @@ convertEnvsToList <- function(keys, vals) { }) } +# Utility function to merge 2 environments with the second overriding values in the first +# env1 is changed in place +overrideEnvs <- function(env1, env2) { + lapply(ls(env2), + function(name) { + env1[[name]] <- env2[[name]] + }) +} + # Utility function to capture the varargs into environment object varargsToEnv <- function(...) { # Based on http://stackoverflow.com/a/3057419/4577954 @@ -486,7 +497,7 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) { # checkedFunc An environment of function objects examined during cleanClosure. It can be # considered as a "name"-to-"list of functions" mapping. # return value -# a new version of func that has an correct environment (closure). +# a new version of func that has a correct environment (closure). cleanClosure <- function(func, checkedFuncs = new.env()) { if (is.function(func)) { newEnv <- new.env(parent = .GlobalEnv) @@ -661,3 +672,43 @@ varargsToJProperties <- function(...) { } props } + +launchScript <- function(script, combinedArgs, capture = FALSE) { + if (.Platform$OS.type == "windows") { + scriptWithArgs <- paste(script, combinedArgs, sep = " ") + shell(scriptWithArgs, translate = TRUE, wait = capture, intern = capture) # nolint + } else { + system2(script, combinedArgs, wait = capture, stdout = capture) + } +} + +getSparkContext <- function() { + if (!exists(".sparkRjsc", envir = .sparkREnv)) { + stop("SparkR has not been initialized. Please call sparkR.session()") + } + sc <- get(".sparkRjsc", envir = .sparkREnv) + sc +} + +isMasterLocal <- function(master) { + grepl("^local(\\[([0-9]+|\\*)\\])?$", master, perl = TRUE) +} + +isSparkRShell <- function() { + grepl(".*shell\\.R$", Sys.getenv("R_PROFILE_USER"), perl = TRUE) +} + +# rbind a list of rows with raw (binary) columns +# +# @param inputData a list of rows, with each row a list +# @return data.frame with raw columns as lists +rbindRaws <- function(inputData){ + row1 <- inputData[[1]] + rawcolumns <- ("raw" == sapply(row1, class)) + + listmatrix <- do.call(rbind, inputData) + # A dataframe with all list columns + out <- as.data.frame(listmatrix) + out[!rawcolumns] <- lapply(out[!rawcolumns], unlist) + out +} diff --git a/R/pkg/R/window.R b/R/pkg/R/window.R new file mode 100644 index 0000000000000..0799d841e5dc9 --- /dev/null +++ b/R/pkg/R/window.R @@ -0,0 +1,116 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
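Before the new window.R utilities start, a quick sketch of what `rbindRaws` (defined just above) produces for rows that mix ordinary and raw columns; the sample rows are hypothetical, and the function is internal, hence the `:::` access.

```R
# Two hypothetical rows: a numeric id and a raw (binary) payload.
rows <- list(
  list(1, as.raw(c(0x01, 0x02))),
  list(2, as.raw(c(0x03, 0x04)))
)

df <- SparkR:::rbindRaws(rows)
str(df)  # column 1 is unlisted to a numeric vector; column 2 stays a list of raws
```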
+# + +# window.R - Utility functions for defining window in DataFrames + +#' windowPartitionBy +#' +#' Creates a WindowSpec with the partitioning defined. +#' +#' @param col A column name or Column by which rows are partitioned to +#' windows. +#' @param ... Optional column names or Columns in addition to col, by +#' which rows are partitioned to windows. +#' +#' @rdname windowPartitionBy +#' @name windowPartitionBy +#' @aliases windowPartitionBy,character-method +#' @export +#' @examples +#' \dontrun{ +#' ws <- orderBy(windowPartitionBy("key1", "key2"), "key3") +#' df1 <- select(df, over(lead("value", 1), ws)) +#' +#' ws <- orderBy(windowPartitionBy(df$key1, df$key2), df$key3) +#' df1 <- select(df, over(lead("value", 1), ws)) +#' } +#' @note windowPartitionBy(character) since 2.0.0 +setMethod("windowPartitionBy", + signature(col = "character"), + function(col, ...) { + windowSpec( + callJStatic("org.apache.spark.sql.expressions.Window", + "partitionBy", + col, + list(...))) + }) + +#' @rdname windowPartitionBy +#' @name windowPartitionBy +#' @aliases windowPartitionBy,Column-method +#' @export +#' @note windowPartitionBy(Column) since 2.0.0 +setMethod("windowPartitionBy", + signature(col = "Column"), + function(col, ...) { + jcols <- lapply(list(col, ...), function(c) { + c@jc + }) + windowSpec( + callJStatic("org.apache.spark.sql.expressions.Window", + "partitionBy", + jcols)) + }) + +#' windowOrderBy +#' +#' Creates a WindowSpec with the ordering defined. +#' +#' @param col A column name or Column by which rows are ordered within +#' windows. +#' @param ... Optional column names or Columns in addition to col, by +#' which rows are ordered within windows. +#' +#' @rdname windowOrderBy +#' @name windowOrderBy +#' @aliases windowOrderBy,character-method +#' @export +#' @examples +#' \dontrun{ +#' ws <- windowOrderBy("key1", "key2") +#' df1 <- select(df, over(lead("value", 1), ws)) +#' +#' ws <- windowOrderBy(df$key1, df$key2) +#' df1 <- select(df, over(lead("value", 1), ws)) +#' } +#' @note windowOrderBy(character) since 2.0.0 +setMethod("windowOrderBy", + signature(col = "character"), + function(col, ...) { + windowSpec( + callJStatic("org.apache.spark.sql.expressions.Window", + "orderBy", + col, + list(...))) + }) + +#' @rdname windowOrderBy +#' @name windowOrderBy +#' @aliases windowOrderBy,Column-method +#' @export +#' @note windowOrderBy(Column) since 2.0.0 +setMethod("windowOrderBy", + signature(col = "Column"), + function(col, ...) 
{ + jcols <- lapply(list(col, ...), function(c) { + c@jc + }) + windowSpec( + callJStatic("org.apache.spark.sql.expressions.Window", + "orderBy", + jcols)) + }) diff --git a/R/pkg/inst/profile/shell.R b/R/pkg/inst/profile/shell.R index 90a3761e41f82..8a8111a8c5419 100644 --- a/R/pkg/inst/profile/shell.R +++ b/R/pkg/inst/profile/shell.R @@ -18,17 +18,17 @@ .First <- function() { home <- Sys.getenv("SPARK_HOME") .libPaths(c(file.path(home, "R", "lib"), .libPaths())) - Sys.setenv(NOAWT=1) + Sys.setenv(NOAWT = 1) # Make sure SparkR package is the last loaded one old <- getOption("defaultPackages") options(defaultPackages = c(old, "SparkR")) - sc <- SparkR::sparkR.init() - assign("sc", sc, envir=.GlobalEnv) - sqlContext <- SparkR::sparkRSQL.init(sc) + spark <- SparkR::sparkR.session() + assign("spark", spark, envir = .GlobalEnv) + sc <- SparkR:::callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", spark) + assign("sc", sc, envir = .GlobalEnv) sparkVer <- SparkR:::callJMethod(sc, "version") - assign("sqlContext", sqlContext, envir=.GlobalEnv) cat("\n Welcome to") cat("\n") cat(" ____ __", "\n") @@ -43,5 +43,5 @@ cat(" /_/", "\n") cat("\n") - cat("\n Spark context is available as sc, SQL context is available as sqlContext\n") + cat("\n SparkSession available as 'spark'.\n") } diff --git a/R/pkg/inst/test_support/sparktestjar_2.10-1.0.jar b/R/pkg/inst/test_support/sparktestjar_2.10-1.0.jar deleted file mode 100644 index 1d5c2af631aa3..0000000000000 Binary files a/R/pkg/inst/test_support/sparktestjar_2.10-1.0.jar and /dev/null differ diff --git a/R/pkg/inst/tests/testthat/jarTest.R b/R/pkg/inst/tests/testthat/jarTest.R index d68bb20950b00..c9615c8d4faf6 100644 --- a/R/pkg/inst/tests/testthat/jarTest.R +++ b/R/pkg/inst/tests/testthat/jarTest.R @@ -16,17 +16,17 @@ # library(SparkR) -sc <- sparkR.init() +sc <- sparkR.session() -helloTest <- SparkR:::callJStatic("sparkR.test.hello", +helloTest <- SparkR:::callJStatic("sparkrtest.DummyClass", "helloWorld", "Dave") +stopifnot(identical(helloTest, "Hello Dave")) -basicFunction <- SparkR:::callJStatic("sparkR.test.basicFunction", +basicFunction <- SparkR:::callJStatic("sparkrtest.DummyClass", "addStuff", 2L, 2L) +stopifnot(basicFunction == 4L) -sparkR.stop() -output <- c(helloTest, basicFunction) -writeLines(output) +sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/packageInAJarTest.R b/R/pkg/inst/tests/testthat/packageInAJarTest.R index c26b28b78dee8..4bc935c79eb0f 100644 --- a/R/pkg/inst/tests/testthat/packageInAJarTest.R +++ b/R/pkg/inst/tests/testthat/packageInAJarTest.R @@ -17,13 +17,13 @@ library(SparkR) library(sparkPackageTest) -sc <- sparkR.init() +sparkR.session() run1 <- myfunc(5L) run2 <- myfunc(-4L) -sparkR.stop() +sparkR.session.stop() if (run1 != 6) quit(save = "no", status = 1) diff --git a/R/pkg/inst/tests/testthat/test_Serde.R b/R/pkg/inst/tests/testthat/test_Serde.R index dddce54d70443..96fb6dda26450 100644 --- a/R/pkg/inst/tests/testthat/test_Serde.R +++ b/R/pkg/inst/tests/testthat/test_Serde.R @@ -17,7 +17,7 @@ context("SerDe functionality") -sc <- sparkR.init() +sparkSession <- sparkR.session() test_that("SerDe of primitive types", { x <- callJStatic("SparkRHandler", "echo", 1L) diff --git a/R/pkg/inst/tests/testthat/test_includeJAR.R b/R/pkg/inst/tests/testthat/test_Windows.R similarity index 54% rename from R/pkg/inst/tests/testthat/test_includeJAR.R rename to R/pkg/inst/tests/testthat/test_Windows.R index f89aa8e507fd5..8813e18a1fa4d 100644 --- a/R/pkg/inst/tests/testthat/test_includeJAR.R +++ 
b/R/pkg/inst/tests/testthat/test_Windows.R @@ -14,24 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # -context("include an external JAR in SparkContext") - -runScript <- function() { - sparkHome <- Sys.getenv("SPARK_HOME") - sparkTestJarPath <- "R/lib/SparkR/test_support/sparktestjar_2.10-1.0.jar" - jarPath <- paste("--jars", shQuote(file.path(sparkHome, sparkTestJarPath))) - scriptPath <- file.path(sparkHome, "R/lib/SparkR/tests/testthat/jarTest.R") - submitPath <- file.path(sparkHome, "bin/spark-submit") - res <- system2(command = submitPath, - args = c(jarPath, scriptPath), - stdout = TRUE) - tail(res, 2) -} +context("Windows-specific tests") test_that("sparkJars tag in SparkContext", { - testOutput <- runScript() - helloTest <- testOutput[1] - expect_equal(helloTest, "Hello, Dave") - basicFunction <- testOutput[2] - expect_equal(basicFunction, "4") + if (.Platform$OS.type != "windows") { + skip("This test is only for Windows, skipped") + } + testOutput <- launchScript("ECHO", "a/b/c", capture = TRUE) + abcPath <- testOutput[1] + expect_equal(abcPath, "a\\b\\c") }) diff --git a/R/pkg/inst/tests/testthat/test_binaryFile.R b/R/pkg/inst/tests/testthat/test_binaryFile.R index 976a7558a816d..f7a0510711da9 100644 --- a/R/pkg/inst/tests/testthat/test_binaryFile.R +++ b/R/pkg/inst/tests/testthat/test_binaryFile.R @@ -18,7 +18,8 @@ context("functions on binary files") # JavaSparkContext handle -sc <- sparkR.init() +sparkSession <- sparkR.session() +sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) mockFile <- c("Spark is pretty.", "Spark is awesome.") @@ -30,7 +31,7 @@ test_that("saveAsObjectFile()/objectFile() following textFile() works", { rdd <- textFile(sc, fileName1, 1) saveAsObjectFile(rdd, fileName2) rdd <- objectFile(sc, fileName2) - expect_equal(collect(rdd), as.list(mockFile)) + expect_equal(collectRDD(rdd), as.list(mockFile)) unlink(fileName1) unlink(fileName2, recursive = TRUE) @@ -43,7 +44,7 @@ test_that("saveAsObjectFile()/objectFile() works on a parallelized list", { rdd <- parallelize(sc, l, 1) saveAsObjectFile(rdd, fileName) rdd <- objectFile(sc, fileName) - expect_equal(collect(rdd), l) + expect_equal(collectRDD(rdd), l) unlink(fileName, recursive = TRUE) }) @@ -63,7 +64,7 @@ test_that("saveAsObjectFile()/objectFile() following RDD transformations works", saveAsObjectFile(counts, fileName2) counts <- objectFile(sc, fileName2) - output <- collect(counts) + output <- collectRDD(counts) expected <- list(list("awesome.", 1), list("Spark", 2), list("pretty.", 1), list("is", 2)) expect_equal(sortKeyValueList(output), sortKeyValueList(expected)) @@ -82,7 +83,7 @@ test_that("saveAsObjectFile()/objectFile() works with multiple paths", { saveAsObjectFile(rdd2, fileName2) rdd <- objectFile(sc, c(fileName1, fileName2)) - expect_equal(count(rdd), 2) + expect_equal(countRDD(rdd), 2) unlink(fileName1, recursive = TRUE) unlink(fileName2, recursive = TRUE) diff --git a/R/pkg/inst/tests/testthat/test_binary_function.R b/R/pkg/inst/tests/testthat/test_binary_function.R index 7bad4d2a7e106..b780b9458545c 100644 --- a/R/pkg/inst/tests/testthat/test_binary_function.R +++ b/R/pkg/inst/tests/testthat/test_binary_function.R @@ -18,7 +18,8 @@ context("binary functions") # JavaSparkContext handle -sc <- sparkR.init() +sparkSession <- sparkR.session() +sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) # Data nums <- 1:10 @@ -28,7 +29,7 @@ rdd <- 
parallelize(sc, nums, 2L) mockFile <- c("Spark is pretty.", "Spark is awesome.") test_that("union on two RDDs", { - actual <- collect(unionRDD(rdd, rdd)) + actual <- collectRDD(unionRDD(rdd, rdd)) expect_equal(actual, as.list(rep(nums, 2))) fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") @@ -36,13 +37,13 @@ test_that("union on two RDDs", { text.rdd <- textFile(sc, fileName) union.rdd <- unionRDD(rdd, text.rdd) - actual <- collect(union.rdd) + actual <- collectRDD(union.rdd) expect_equal(actual, c(as.list(nums), mockFile)) expect_equal(getSerializedMode(union.rdd), "byte") rdd <- map(text.rdd, function(x) {x}) union.rdd <- unionRDD(rdd, text.rdd) - actual <- collect(union.rdd) + actual <- collectRDD(union.rdd) expect_equal(actual, as.list(c(mockFile, mockFile))) expect_equal(getSerializedMode(union.rdd), "byte") @@ -53,14 +54,14 @@ test_that("cogroup on two RDDs", { rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) cogroup.rdd <- cogroup(rdd1, rdd2, numPartitions = 2L) - actual <- collect(cogroup.rdd) + actual <- collectRDD(cogroup.rdd) expect_equal(actual, list(list(1, list(list(1), list(2, 3))), list(2, list(list(4), list())))) rdd1 <- parallelize(sc, list(list("a", 1), list("a", 4))) rdd2 <- parallelize(sc, list(list("b", 2), list("a", 3))) cogroup.rdd <- cogroup(rdd1, rdd2, numPartitions = 2L) - actual <- collect(cogroup.rdd) + actual <- collectRDD(cogroup.rdd) expected <- list(list("b", list(list(), list(2))), list("a", list(list(1, 4), list(3)))) expect_equal(sortKeyValueList(actual), @@ -71,7 +72,7 @@ test_that("zipPartitions() on RDDs", { rdd1 <- parallelize(sc, 1:2, 2L) # 1, 2 rdd2 <- parallelize(sc, 1:4, 2L) # 1:2, 3:4 rdd3 <- parallelize(sc, 1:6, 2L) # 1:3, 4:6 - actual <- collect(zipPartitions(rdd1, rdd2, rdd3, + actual <- collectRDD(zipPartitions(rdd1, rdd2, rdd3, func = function(x, y, z) { list(list(x, y, z))} )) expect_equal(actual, list(list(1, c(1, 2), c(1, 2, 3)), list(2, c(3, 4), c(4, 5, 6)))) @@ -81,19 +82,19 @@ test_that("zipPartitions() on RDDs", { writeLines(mockFile, fileName) rdd <- textFile(sc, fileName, 1) - actual <- collect(zipPartitions(rdd, rdd, + actual <- collectRDD(zipPartitions(rdd, rdd, func = function(x, y) { list(paste(x, y, sep = "\n")) })) expected <- list(paste(mockFile, mockFile, sep = "\n")) expect_equal(actual, expected) rdd1 <- parallelize(sc, 0:1, 1) - actual <- collect(zipPartitions(rdd1, rdd, + actual <- collectRDD(zipPartitions(rdd1, rdd, func = function(x, y) { list(x + nchar(y)) })) expected <- list(0:1 + nchar(mockFile)) expect_equal(actual, expected) rdd <- map(rdd, function(x) { x }) - actual <- collect(zipPartitions(rdd, rdd1, + actual <- collectRDD(zipPartitions(rdd, rdd1, func = function(x, y) { list(y + nchar(x)) })) expect_equal(actual, expected) diff --git a/R/pkg/inst/tests/testthat/test_broadcast.R b/R/pkg/inst/tests/testthat/test_broadcast.R index 8be6efc3dbed3..064249a57aed4 100644 --- a/R/pkg/inst/tests/testthat/test_broadcast.R +++ b/R/pkg/inst/tests/testthat/test_broadcast.R @@ -18,7 +18,8 @@ context("broadcast variables") # JavaSparkContext handle -sc <- sparkR.init() +sparkSession <- sparkR.session() +sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) # Partitioned data nums <- 1:2 @@ -31,7 +32,7 @@ test_that("using broadcast variable", { useBroadcast <- function(x) { sum(SparkR:::value(randomMatBr) * x) } - actual <- collect(lapply(rrdd, useBroadcast)) + actual <- collectRDD(lapply(rrdd, useBroadcast)) 
expected <- list(sum(randomMat) * 1, sum(randomMat) * 2) expect_equal(actual, expected) }) @@ -42,7 +43,7 @@ test_that("without using broadcast variable", { useBroadcast <- function(x) { sum(randomMat * x) } - actual <- collect(lapply(rrdd, useBroadcast)) + actual <- collectRDD(lapply(rrdd, useBroadcast)) expected <- list(sum(randomMat) * 1, sum(randomMat) * 2) expect_equal(actual, expected) }) diff --git a/R/pkg/inst/tests/testthat/test_client.R b/R/pkg/inst/tests/testthat/test_client.R index a0664f32f31c1..0cf25fe1dbf39 100644 --- a/R/pkg/inst/tests/testthat/test_client.R +++ b/R/pkg/inst/tests/testthat/test_client.R @@ -32,14 +32,12 @@ test_that("no package specified doesn't add packages flag", { }) test_that("multiple packages don't produce a warning", { - expect_that(generateSparkSubmitArgs("", "", "", "", c("A", "B")), not(gives_warning())) + expect_warning(generateSparkSubmitArgs("", "", "", "", c("A", "B")), NA) }) test_that("sparkJars sparkPackages as character vectors", { args <- generateSparkSubmitArgs("", "", c("one.jar", "two.jar", "three.jar"), "", - c("com.databricks:spark-avro_2.10:2.0.1", - "com.databricks:spark-csv_2.10:1.3.0")) + c("com.databricks:spark-avro_2.10:2.0.1")) expect_match(args, "--jars one.jar,two.jar,three.jar") - expect_match(args, - "--packages com.databricks:spark-avro_2.10:2.0.1,com.databricks:spark-csv_2.10:1.3.0") + expect_match(args, "--packages com.databricks:spark-avro_2.10:2.0.1") }) diff --git a/R/pkg/inst/tests/testthat/test_context.R b/R/pkg/inst/tests/testthat/test_context.R index ca04342cd5124..66640c4b08459 100644 --- a/R/pkg/inst/tests/testthat/test_context.R +++ b/R/pkg/inst/tests/testthat/test_context.R @@ -19,16 +19,26 @@ context("test functions in sparkR.R") test_that("Check masked functions", { # Check that we are not masking any new function from base, stats, testthat unexpectedly + # NOTE: We should avoid adding entries to *namesOfMaskedCompletely* as masked functions make it + # hard for users to use base R functions. Please check when in doubt. + namesOfMaskedCompletely <- c("cov", "filter", "sample") + namesOfMasked <- c("describe", "cov", "filter", "lag", "na.omit", "predict", "sd", "var", + "colnames", "colnames<-", "intersect", "rank", "rbind", "sample", "subset", + "summary", "transform", "drop", "window", "as.data.frame", "union") + if (as.numeric(R.version$major) >= 3 && as.numeric(R.version$minor) >= 3) { + namesOfMasked <- c("endsWith", "startsWith", namesOfMasked) + } masked <- conflicts(detail = TRUE)$`package:SparkR` expect_true("describe" %in% masked) # only when with testthat.. 
func <- lapply(masked, function(x) { capture.output(showMethods(x))[[1]] }) funcSparkROrEmpty <- grepl("\\(package SparkR\\)$|^$", func) maskedBySparkR <- masked[funcSparkROrEmpty] - namesOfMasked <- c("describe", "cov", "filter", "lag", "na.omit", "predict", "sd", "var", - "colnames", "colnames<-", "intersect", "rank", "rbind", "sample", "subset", - "summary", "transform", "drop", "window", "as.data.frame") expect_equal(length(maskedBySparkR), length(namesOfMasked)) - expect_equal(sort(maskedBySparkR), sort(namesOfMasked)) + # make the 2 lists the same length so expect_equal will print their content + l <- max(length(maskedBySparkR), length(namesOfMasked)) + length(maskedBySparkR) <- l + length(namesOfMasked) <- l + expect_equal(sort(maskedBySparkR, na.last = TRUE), sort(namesOfMasked, na.last = TRUE)) # above are those reported as masked when `library(SparkR)` # note that many of these methods are still callable without base:: or stats:: prefix # there should be a test for each of these, except followings, which are currently "broken" @@ -36,38 +46,43 @@ test_that("Check masked functions", { any(grepl("=\"ANY\"", capture.output(showMethods(x)[-1]))) })) maskedCompletely <- masked[!funcHasAny] - namesOfMaskedCompletely <- c("cov", "filter", "sample") expect_equal(length(maskedCompletely), length(namesOfMaskedCompletely)) - expect_equal(sort(maskedCompletely), sort(namesOfMaskedCompletely)) + l <- max(length(maskedCompletely), length(namesOfMaskedCompletely)) + length(maskedCompletely) <- l + length(namesOfMaskedCompletely) <- l + expect_equal(sort(maskedCompletely, na.last = TRUE), + sort(namesOfMaskedCompletely, na.last = TRUE)) }) test_that("repeatedly starting and stopping SparkR", { for (i in 1:4) { - sc <- sparkR.init() + sc <- suppressWarnings(sparkR.init()) rdd <- parallelize(sc, 1:20, 2L) - expect_equal(count(rdd), 20) - sparkR.stop() + expect_equal(countRDD(rdd), 20) + suppressWarnings(sparkR.stop()) } }) -test_that("repeatedly starting and stopping SparkR SQL", { - for (i in 1:4) { - sc <- sparkR.init() - sqlContext <- sparkRSQL.init(sc) - df <- createDataFrame(sqlContext, data.frame(a = 1:20)) - expect_equal(count(df), 20) - sparkR.stop() - } -}) +# Does not work consistently even with Hive off +# nolint start +# test_that("repeatedly starting and stopping SparkR", { +# for (i in 1:4) { +# sparkR.session(enableHiveSupport = FALSE) +# df <- createDataFrame(data.frame(dummy=1:i)) +# expect_equal(count(df), i) +# sparkR.session.stop() +# Sys.sleep(5) # Need more time to shutdown Hive metastore +# } +# }) +# nolint end test_that("rdd GC across sparkR.stop", { - sparkR.stop() - sc <- sparkR.init() # sc should get id 0 + sc <- sparkR.sparkContext() # sc should get id 0 rdd1 <- parallelize(sc, 1:20, 2L) # rdd1 should get id 1 rdd2 <- parallelize(sc, 1:10, 2L) # rdd2 should get id 2 - sparkR.stop() + sparkR.session.stop() - sc <- sparkR.init() # sc should get id 0 again + sc <- sparkR.sparkContext() # sc should get id 0 again # GC rdd1 before creating rdd3 and rdd2 after rm(rdd1) @@ -79,20 +94,27 @@ test_that("rdd GC across sparkR.stop", { rm(rdd2) gc() - count(rdd3) - count(rdd4) + countRDD(rdd3) + countRDD(rdd4) + sparkR.session.stop() }) test_that("job group functions can be called", { - sc <- sparkR.init() - setJobGroup(sc, "groupId", "job description", TRUE) - cancelJobGroup(sc, "groupId") - clearJobGroup(sc) + sc <- sparkR.sparkContext() + setJobGroup("groupId", "job description", TRUE) + cancelJobGroup("groupId") + clearJobGroup() + + suppressWarnings(setJobGroup(sc, "groupId", 
"job description", TRUE)) + suppressWarnings(cancelJobGroup(sc, "groupId")) + suppressWarnings(clearJobGroup(sc)) + sparkR.session.stop() }) test_that("utility function can be called", { - sc <- sparkR.init() - setLogLevel(sc, "ERROR") + sparkR.sparkContext() + setLogLevel("ERROR") + sparkR.session.stop() }) test_that("getClientModeSparkSubmitOpts() returns spark-submit args from whitelist", { @@ -125,25 +147,26 @@ test_that("getClientModeSparkSubmitOpts() returns spark-submit args from whiteli test_that("sparkJars sparkPackages as comma-separated strings", { expect_warning(processSparkJars(" a, b ")) jars <- suppressWarnings(processSparkJars(" a, b ")) - expect_equal(jars, c("a", "b")) + expect_equal(lapply(jars, basename), list("a", "b")) jars <- suppressWarnings(processSparkJars(" abc ,, def ")) - expect_equal(jars, c("abc", "def")) + expect_equal(lapply(jars, basename), list("abc", "def")) jars <- suppressWarnings(processSparkJars(c(" abc ,, def ", "", "xyz", " ", "a,b"))) - expect_equal(jars, c("abc", "def", "xyz", "a", "b")) + expect_equal(lapply(jars, basename), list("abc", "def", "xyz", "a", "b")) p <- processSparkPackages(c("ghi", "lmn")) expect_equal(p, c("ghi", "lmn")) # check normalizePath f <- dir()[[1]] - expect_that(processSparkJars(f), not(gives_warning())) + expect_warning(processSparkJars(f), NA) expect_match(processSparkJars(f), f) }) test_that("spark.lapply should perform simple transforms", { - sc <- sparkR.init() - doubled <- spark.lapply(sc, 1:10, function(x) { 2 * x }) + sc <- sparkR.sparkContext() + doubled <- spark.lapply(1:10, function(x) { 2 * x }) expect_equal(doubled, as.list(2 * 1:10)) + sparkR.session.stop() }) diff --git a/R/pkg/inst/tests/testthat/test_includePackage.R b/R/pkg/inst/tests/testthat/test_includePackage.R index 8152b448d0870..025eb9b9fc9d6 100644 --- a/R/pkg/inst/tests/testthat/test_includePackage.R +++ b/R/pkg/inst/tests/testthat/test_includePackage.R @@ -18,7 +18,8 @@ context("include R packages") # JavaSparkContext handle -sc <- sparkR.init() +sparkSession <- sparkR.session() +sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) # Partitioned data nums <- 1:2 @@ -36,7 +37,7 @@ test_that("include inside function", { } data <- lapplyPartition(rdd, generateData) - actual <- collect(data) + actual <- collectRDD(data) } }) @@ -52,6 +53,6 @@ test_that("use include package", { includePackage(sc, plyr) data <- lapplyPartition(rdd, generateData) - actual <- collect(data) + actual <- collectRDD(data) } }) diff --git a/R/pkg/inst/tests/testthat/test_jvm_api.R b/R/pkg/inst/tests/testthat/test_jvm_api.R new file mode 100644 index 0000000000000..7348c893d0af3 --- /dev/null +++ b/R/pkg/inst/tests/testthat/test_jvm_api.R @@ -0,0 +1,36 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +context("JVM API") + +sparkSession <- sparkR.session(enableHiveSupport = FALSE) + +test_that("Create and call methods on object", { + jarr <- sparkR.newJObject("java.util.ArrayList") + # Add an element to the array + sparkR.callJMethod(jarr, "add", 1L) + # Check if get returns the same element + expect_equal(sparkR.callJMethod(jarr, "get", 0L), 1L) +}) + +test_that("Call static methods", { + # Convert a boolean to a string + strTrue <- sparkR.callJStatic("java.lang.String", "valueOf", TRUE) + expect_equal(strTrue, "true") +}) + +sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index 37d87aa8a0469..753da81760971 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -20,13 +20,10 @@ library(testthat) context("MLlib functions") # Tests for MLlib functions in SparkR - -sc <- sparkR.init() - -sqlContext <- sparkRSQL.init(sc) +sparkSession <- sparkR.session() test_that("formula of spark.glm", { - training <- suppressWarnings(createDataFrame(sqlContext, iris)) + training <- suppressWarnings(createDataFrame(iris)) # directly calling the spark API # dot minus and intercept vs native glm model <- spark.glm(training, Sepal_Width ~ . - Species + 0) @@ -41,7 +38,7 @@ test_that("formula of spark.glm", { expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) # glm should work with long formula - training <- suppressWarnings(createDataFrame(sqlContext, iris)) + training <- suppressWarnings(createDataFrame(iris)) training$LongLongLongLongLongName <- training$Sepal_Width training$VeryLongLongLongLonLongName <- training$Sepal_Length training$AnotherLongLongLongLongName <- training$Species @@ -53,7 +50,7 @@ test_that("formula of spark.glm", { }) test_that("spark.glm and predict", { - training <- suppressWarnings(createDataFrame(sqlContext, iris)) + training <- suppressWarnings(createDataFrame(iris)) # gaussian family model <- spark.glm(training, Sepal_Width ~ Sepal_Length + Species) prediction <- predict(model, training) @@ -80,7 +77,7 @@ test_that("spark.glm and predict", { test_that("spark.glm summary", { # gaussian family - training <- suppressWarnings(createDataFrame(sqlContext, iris)) + training <- suppressWarnings(createDataFrame(iris)) stats <- summary(spark.glm(training, Sepal_Width ~ Sepal_Length + Species)) rStats <- summary(glm(Sepal.Width ~ Sepal.Length + Species, data = iris)) @@ -99,7 +96,7 @@ test_that("spark.glm summary", { expect_equal(stats$aic, rStats$aic) # binomial family - df <- suppressWarnings(createDataFrame(sqlContext, iris)) + df <- suppressWarnings(createDataFrame(iris)) training <- df[df$Species %in% c("versicolor", "virginica"), ] stats <- summary(spark.glm(training, Species ~ Sepal_Length + Sepal_Width, family = binomial(link = "logit"))) @@ -128,7 +125,7 @@ test_that("spark.glm summary", { }) test_that("spark.glm save/load", { - training <- suppressWarnings(createDataFrame(sqlContext, iris)) + training <- suppressWarnings(createDataFrame(iris)) m <- spark.glm(training, Sepal_Width ~ Sepal_Length + Species) s <- summary(m) @@ -157,7 +154,7 @@ test_that("spark.glm save/load", { test_that("formula of glm", { - training <- suppressWarnings(createDataFrame(sqlContext, iris)) + training <- suppressWarnings(createDataFrame(iris)) # dot minus and intercept vs native glm model <- glm(Sepal_Width ~ . 
- Species + 0, data = training) vals <- collect(select(predict(model, training), "prediction")) @@ -171,7 +168,7 @@ test_that("formula of glm", { expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) # glm should work with long formula - training <- suppressWarnings(createDataFrame(sqlContext, iris)) + training <- suppressWarnings(createDataFrame(iris)) training$LongLongLongLongLongName <- training$Sepal_Width training$VeryLongLongLongLonLongName <- training$Sepal_Length training$AnotherLongLongLongLongName <- training$Species @@ -183,7 +180,7 @@ test_that("formula of glm", { }) test_that("glm and predict", { - training <- suppressWarnings(createDataFrame(sqlContext, iris)) + training <- suppressWarnings(createDataFrame(iris)) # gaussian family model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training) prediction <- predict(model, training) @@ -210,7 +207,7 @@ test_that("glm and predict", { test_that("glm summary", { # gaussian family - training <- suppressWarnings(createDataFrame(sqlContext, iris)) + training <- suppressWarnings(createDataFrame(iris)) stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training)) rStats <- summary(glm(Sepal.Width ~ Sepal.Length + Species, data = iris)) @@ -229,7 +226,7 @@ test_that("glm summary", { expect_equal(stats$aic, rStats$aic) # binomial family - df <- suppressWarnings(createDataFrame(sqlContext, iris)) + df <- suppressWarnings(createDataFrame(iris)) training <- df[df$Species %in% c("versicolor", "virginica"), ] stats <- summary(glm(Species ~ Sepal_Length + Sepal_Width, data = training, family = binomial(link = "logit"))) @@ -258,7 +255,7 @@ test_that("glm summary", { }) test_that("glm save/load", { - training <- suppressWarnings(createDataFrame(sqlContext, iris)) + training <- suppressWarnings(createDataFrame(iris)) m <- glm(Sepal_Width ~ Sepal_Length + Species, data = training) s <- summary(m) @@ -287,11 +284,11 @@ test_that("glm save/load", { test_that("spark.kmeans", { newIris <- iris newIris$Species <- NULL - training <- suppressWarnings(createDataFrame(sqlContext, newIris)) + training <- suppressWarnings(createDataFrame(newIris)) take(training, 1) - model <- spark.kmeans(data = training, ~ ., k = 2) + model <- spark.kmeans(data = training, ~ ., k = 2, maxIter = 10, initMode = "random") sample <- take(select(predict(model, training), "prediction"), 1) expect_equal(typeof(sample$prediction), "integer") expect_equal(sample$prediction, 1) @@ -365,8 +362,8 @@ test_that("spark.naiveBayes", { t <- as.data.frame(Titanic) t1 <- t[t$Freq > 0, -5] - df <- suppressWarnings(createDataFrame(sqlContext, t1)) - m <- spark.naiveBayes(df, Survived ~ .) 
+ df <- suppressWarnings(createDataFrame(t1)) + m <- spark.naiveBayes(df, Survived ~ ., smoothing = 0.0) s <- summary(m) expect_equal(as.double(s$apriori[1, "Yes"]), 0.5833333, tolerance = 1e-6) expect_equal(sum(s$apriori), 1) @@ -420,7 +417,7 @@ test_that("spark.survreg", { # data <- list(list(4, 1, 0, 0), list(3, 1, 2, 0), list(1, 1, 1, 0), list(1, 0, 1, 0), list(2, 1, 1, 1), list(2, 1, 0, 1), list(3, 0, 0, 1)) - df <- createDataFrame(sqlContext, data, c("time", "status", "x", "sex")) + df <- createDataFrame(data, c("time", "status", "x", "sex")) model <- spark.survreg(df, Surv(time, status) ~ x + sex) stats <- summary(model) coefs <- as.vector(stats$coefficients[, 1]) @@ -450,9 +447,9 @@ test_that("spark.survreg", { if (requireNamespace("survival", quietly = TRUE)) { rData <- list(time = c(4, 3, 1, 1, 2, 2, 3), status = c(1, 1, 1, 0, 1, 1, 0), x = c(0, 2, 1, 1, 1, 0, 0), sex = c(0, 0, 0, 0, 1, 1, 1)) - expect_that( + expect_error( model <- survival::survreg(formula = survival::Surv(time, status) ~ x + sex, data = rData), - not(throws_error())) + NA) expect_equal(predict(model, rData)[[1]], 3.724591, tolerance = 1e-4) } }) diff --git a/R/pkg/inst/tests/testthat/test_parallelize_collect.R b/R/pkg/inst/tests/testthat/test_parallelize_collect.R index 2552127cc547f..1b230554f7a0e 100644 --- a/R/pkg/inst/tests/testthat/test_parallelize_collect.R +++ b/R/pkg/inst/tests/testthat/test_parallelize_collect.R @@ -33,7 +33,8 @@ numPairs <- list(list(1, 1), list(1, 2), list(2, 2), list(2, 3)) strPairs <- list(list(strList, strList), list(strList, strList)) # JavaSparkContext handle -jsc <- sparkR.init() +sparkSession <- sparkR.session() +jsc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) # Tests @@ -66,22 +67,22 @@ test_that("parallelize() on simple vectors and lists returns an RDD", { test_that("collect(), following a parallelize(), gives back the original collections", { numVectorRDD <- parallelize(jsc, numVector, 10) - expect_equal(collect(numVectorRDD), as.list(numVector)) + expect_equal(collectRDD(numVectorRDD), as.list(numVector)) numListRDD <- parallelize(jsc, numList, 1) numListRDD2 <- parallelize(jsc, numList, 4) - expect_equal(collect(numListRDD), as.list(numList)) - expect_equal(collect(numListRDD2), as.list(numList)) + expect_equal(collectRDD(numListRDD), as.list(numList)) + expect_equal(collectRDD(numListRDD2), as.list(numList)) strVectorRDD <- parallelize(jsc, strVector, 2) strVectorRDD2 <- parallelize(jsc, strVector, 3) - expect_equal(collect(strVectorRDD), as.list(strVector)) - expect_equal(collect(strVectorRDD2), as.list(strVector)) + expect_equal(collectRDD(strVectorRDD), as.list(strVector)) + expect_equal(collectRDD(strVectorRDD2), as.list(strVector)) strListRDD <- parallelize(jsc, strList, 4) strListRDD2 <- parallelize(jsc, strList, 1) - expect_equal(collect(strListRDD), as.list(strList)) - expect_equal(collect(strListRDD2), as.list(strList)) + expect_equal(collectRDD(strListRDD), as.list(strList)) + expect_equal(collectRDD(strListRDD2), as.list(strList)) }) test_that("regression: collect() following a parallelize() does not drop elements", { @@ -89,7 +90,7 @@ test_that("regression: collect() following a parallelize() does not drop element collLen <- 10 numPart <- 6 expected <- runif(collLen) - actual <- collect(parallelize(jsc, expected, numPart)) + actual <- collectRDD(parallelize(jsc, expected, numPart)) expect_equal(actual, as.list(expected)) }) @@ -98,12 +99,12 @@ test_that("parallelize() and collect() work for lists of pairs 
(pairwise data)", numPairsRDDD1 <- parallelize(jsc, numPairs, 1) numPairsRDDD2 <- parallelize(jsc, numPairs, 2) numPairsRDDD3 <- parallelize(jsc, numPairs, 3) - expect_equal(collect(numPairsRDDD1), numPairs) - expect_equal(collect(numPairsRDDD2), numPairs) - expect_equal(collect(numPairsRDDD3), numPairs) + expect_equal(collectRDD(numPairsRDDD1), numPairs) + expect_equal(collectRDD(numPairsRDDD2), numPairs) + expect_equal(collectRDD(numPairsRDDD3), numPairs) # can also leave out the parameter name, if the params are supplied in order strPairsRDDD1 <- parallelize(jsc, strPairs, 1) strPairsRDDD2 <- parallelize(jsc, strPairs, 2) - expect_equal(collect(strPairsRDDD1), strPairs) - expect_equal(collect(strPairsRDDD2), strPairs) + expect_equal(collectRDD(strPairsRDDD1), strPairs) + expect_equal(collectRDD(strPairsRDDD2), strPairs) }) diff --git a/R/pkg/inst/tests/testthat/test_rdd.R b/R/pkg/inst/tests/testthat/test_rdd.R index b6c8e1dc6c1b7..d38a763bab8c6 100644 --- a/R/pkg/inst/tests/testthat/test_rdd.R +++ b/R/pkg/inst/tests/testthat/test_rdd.R @@ -18,7 +18,8 @@ context("basic RDD functions") # JavaSparkContext handle -sc <- sparkR.init() +sparkSession <- sparkR.session() +sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) # Data nums <- 1:10 @@ -33,14 +34,14 @@ test_that("get number of partitions in RDD", { }) test_that("first on RDD", { - expect_equal(first(rdd), 1) + expect_equal(firstRDD(rdd), 1) newrdd <- lapply(rdd, function(x) x + 1) - expect_equal(first(newrdd), 2) + expect_equal(firstRDD(newrdd), 2) }) test_that("count and length on RDD", { - expect_equal(count(rdd), 10) - expect_equal(length(rdd), 10) + expect_equal(countRDD(rdd), 10) + expect_equal(lengthRDD(rdd), 10) }) test_that("count by values and keys", { @@ -56,40 +57,40 @@ test_that("count by values and keys", { test_that("lapply on RDD", { multiples <- lapply(rdd, function(x) { 2 * x }) - actual <- collect(multiples) + actual <- collectRDD(multiples) expect_equal(actual, as.list(nums * 2)) }) test_that("lapplyPartition on RDD", { sums <- lapplyPartition(rdd, function(part) { sum(unlist(part)) }) - actual <- collect(sums) + actual <- collectRDD(sums) expect_equal(actual, list(15, 40)) }) test_that("mapPartitions on RDD", { sums <- mapPartitions(rdd, function(part) { sum(unlist(part)) }) - actual <- collect(sums) + actual <- collectRDD(sums) expect_equal(actual, list(15, 40)) }) test_that("flatMap() on RDDs", { flat <- flatMap(intRdd, function(x) { list(x, x) }) - actual <- collect(flat) + actual <- collectRDD(flat) expect_equal(actual, rep(intPairs, each = 2)) }) test_that("filterRDD on RDD", { filtered.rdd <- filterRDD(rdd, function(x) { x %% 2 == 0 }) - actual <- collect(filtered.rdd) + actual <- collectRDD(filtered.rdd) expect_equal(actual, list(2, 4, 6, 8, 10)) filtered.rdd <- Filter(function(x) { x[[2]] < 0 }, intRdd) - actual <- collect(filtered.rdd) + actual <- collectRDD(filtered.rdd) expect_equal(actual, list(list(1L, -1))) # Filter out all elements. 
filtered.rdd <- filterRDD(rdd, function(x) { x > 10 }) - actual <- collect(filtered.rdd) + actual <- collectRDD(filtered.rdd) expect_equal(actual, list()) }) @@ -109,7 +110,7 @@ test_that("several transformations on RDD (a benchmark on PipelinedRDD)", { part <- as.list(unlist(part) * partIndex + i) }) rdd2 <- lapply(rdd2, function(x) x + x) - actual <- collect(rdd2) + actual <- collectRDD(rdd2) expected <- list(24, 24, 24, 24, 24, 168, 170, 172, 174, 176) expect_equal(actual, expected) @@ -125,20 +126,20 @@ test_that("PipelinedRDD support actions: cache(), persist(), unpersist(), checkp part <- as.list(unlist(part) * partIndex) }) - cache(rdd2) + cacheRDD(rdd2) expect_true(rdd2@env$isCached) rdd2 <- lapply(rdd2, function(x) x) expect_false(rdd2@env$isCached) - unpersist(rdd2) + unpersistRDD(rdd2) expect_false(rdd2@env$isCached) - persist(rdd2, "MEMORY_AND_DISK") + persistRDD(rdd2, "MEMORY_AND_DISK") expect_true(rdd2@env$isCached) rdd2 <- lapply(rdd2, function(x) x) expect_false(rdd2@env$isCached) - unpersist(rdd2) + unpersistRDD(rdd2) expect_false(rdd2@env$isCached) tempDir <- tempfile(pattern = "checkpoint") @@ -151,7 +152,7 @@ test_that("PipelinedRDD support actions: cache(), persist(), unpersist(), checkp expect_false(rdd2@env$isCheckpointed) # make sure the data is collectable - collect(rdd2) + collectRDD(rdd2) unlink(tempDir) }) @@ -168,21 +169,21 @@ test_that("reduce on RDD", { test_that("lapply with dependency", { fa <- 5 multiples <- lapply(rdd, function(x) { fa * x }) - actual <- collect(multiples) + actual <- collectRDD(multiples) expect_equal(actual, as.list(nums * 5)) }) test_that("lapplyPartitionsWithIndex on RDDs", { func <- function(partIndex, part) { list(partIndex, Reduce("+", part)) } - actual <- collect(lapplyPartitionsWithIndex(rdd, func), flatten = FALSE) + actual <- collectRDD(lapplyPartitionsWithIndex(rdd, func), flatten = FALSE) expect_equal(actual, list(list(0, 15), list(1, 40))) pairsRDD <- parallelize(sc, list(list(1, 2), list(3, 4), list(4, 8)), 1L) partitionByParity <- function(key) { if (key %% 2 == 1) 0 else 1 } mkTup <- function(partIndex, part) { list(partIndex, part) } - actual <- collect(lapplyPartitionsWithIndex( - partitionBy(pairsRDD, 2L, partitionByParity), + actual <- collectRDD(lapplyPartitionsWithIndex( + partitionByRDD(pairsRDD, 2L, partitionByParity), mkTup), FALSE) expect_equal(actual, list(list(0, list(list(1, 2), list(3, 4))), @@ -190,7 +191,7 @@ test_that("lapplyPartitionsWithIndex on RDDs", { }) test_that("sampleRDD() on RDDs", { - expect_equal(unlist(collect(sampleRDD(rdd, FALSE, 1.0, 2014L))), nums) + expect_equal(unlist(collectRDD(sampleRDD(rdd, FALSE, 1.0, 2014L))), nums) }) test_that("takeSample() on RDDs", { @@ -237,7 +238,7 @@ test_that("takeSample() on RDDs", { test_that("mapValues() on pairwise RDDs", { multiples <- mapValues(intRdd, function(x) { x * 2 }) - actual <- collect(multiples) + actual <- collectRDD(multiples) expected <- lapply(intPairs, function(x) { list(x[[1]], x[[2]] * 2) }) @@ -246,11 +247,11 @@ test_that("mapValues() on pairwise RDDs", { test_that("flatMapValues() on pairwise RDDs", { l <- parallelize(sc, list(list(1, c(1, 2)), list(2, c(3, 4)))) - actual <- collect(flatMapValues(l, function(x) { x })) + actual <- collectRDD(flatMapValues(l, function(x) { x })) expect_equal(actual, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) # Generate x to x+1 for every value - actual <- collect(flatMapValues(intRdd, function(x) { x: (x + 1) })) + actual <- collectRDD(flatMapValues(intRdd, function(x) { x: (x + 1) })) 
expect_equal(actual, list(list(1L, -1), list(1L, 0), list(2L, 100), list(2L, 101), list(2L, 1), list(2L, 2), list(1L, 200), list(1L, 201))) @@ -272,8 +273,8 @@ test_that("reduceByKeyLocally() on PairwiseRDDs", { test_that("distinct() on RDDs", { nums.rep2 <- rep(1:10, 2) rdd.rep2 <- parallelize(sc, nums.rep2, 2L) - uniques <- distinct(rdd.rep2) - actual <- sort(unlist(collect(uniques))) + uniques <- distinctRDD(rdd.rep2) + actual <- sort(unlist(collectRDD(uniques))) expect_equal(actual, nums) }) @@ -295,7 +296,7 @@ test_that("sumRDD() on RDDs", { test_that("keyBy on RDDs", { func <- function(x) { x * x } keys <- keyBy(rdd, func) - actual <- collect(keys) + actual <- collectRDD(keys) expect_equal(actual, lapply(nums, function(x) { list(func(x), x) })) }) @@ -303,12 +304,12 @@ test_that("repartition/coalesce on RDDs", { rdd <- parallelize(sc, 1:20, 4L) # each partition contains 5 elements # repartition - r1 <- repartition(rdd, 2) + r1 <- repartitionRDD(rdd, 2) expect_equal(getNumPartitions(r1), 2L) count <- length(collectPartition(r1, 0L)) expect_true(count >= 8 && count <= 12) - r2 <- repartition(rdd, 6) + r2 <- repartitionRDD(rdd, 6) expect_equal(getNumPartitions(r2), 6L) count <- length(collectPartition(r2, 0L)) expect_true(count >= 0 && count <= 4) @@ -322,12 +323,12 @@ test_that("repartition/coalesce on RDDs", { test_that("sortBy() on RDDs", { sortedRdd <- sortBy(rdd, function(x) { x * x }, ascending = FALSE) - actual <- collect(sortedRdd) + actual <- collectRDD(sortedRdd) expect_equal(actual, as.list(sort(nums, decreasing = TRUE))) rdd2 <- parallelize(sc, sort(nums, decreasing = TRUE), 2L) sortedRdd2 <- sortBy(rdd2, function(x) { x * x }) - actual <- collect(sortedRdd2) + actual <- collectRDD(sortedRdd2) expect_equal(actual, as.list(nums)) }) @@ -379,13 +380,13 @@ test_that("aggregateRDD() on RDDs", { test_that("zipWithUniqueId() on RDDs", { rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) - actual <- collect(zipWithUniqueId(rdd)) + actual <- collectRDD(zipWithUniqueId(rdd)) expected <- list(list("a", 0), list("b", 3), list("c", 1), list("d", 4), list("e", 2)) expect_equal(actual, expected) rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 1L) - actual <- collect(zipWithUniqueId(rdd)) + actual <- collectRDD(zipWithUniqueId(rdd)) expected <- list(list("a", 0), list("b", 1), list("c", 2), list("d", 3), list("e", 4)) expect_equal(actual, expected) @@ -393,13 +394,13 @@ test_that("zipWithUniqueId() on RDDs", { test_that("zipWithIndex() on RDDs", { rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) - actual <- collect(zipWithIndex(rdd)) + actual <- collectRDD(zipWithIndex(rdd)) expected <- list(list("a", 0), list("b", 1), list("c", 2), list("d", 3), list("e", 4)) expect_equal(actual, expected) rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 1L) - actual <- collect(zipWithIndex(rdd)) + actual <- collectRDD(zipWithIndex(rdd)) expected <- list(list("a", 0), list("b", 1), list("c", 2), list("d", 3), list("e", 4)) expect_equal(actual, expected) @@ -407,35 +408,35 @@ test_that("zipWithIndex() on RDDs", { test_that("glom() on RDD", { rdd <- parallelize(sc, as.list(1:4), 2L) - actual <- collect(glom(rdd)) + actual <- collectRDD(glom(rdd)) expect_equal(actual, list(list(1, 2), list(3, 4))) }) test_that("keys() on RDDs", { keys <- keys(intRdd) - actual <- collect(keys) + actual <- collectRDD(keys) expect_equal(actual, lapply(intPairs, function(x) { x[[1]] })) }) test_that("values() on RDDs", { values <- values(intRdd) - actual <- collect(values) + actual <- 
collectRDD(values) expect_equal(actual, lapply(intPairs, function(x) { x[[2]] })) }) test_that("pipeRDD() on RDDs", { - actual <- collect(pipeRDD(rdd, "more")) + actual <- collectRDD(pipeRDD(rdd, "more")) expected <- as.list(as.character(1:10)) expect_equal(actual, expected) trailed.rdd <- parallelize(sc, c("1", "", "2\n", "3\n\r\n")) - actual <- collect(pipeRDD(trailed.rdd, "sort")) + actual <- collectRDD(pipeRDD(trailed.rdd, "sort")) expected <- list("", "1", "2", "3") expect_equal(actual, expected) rev.nums <- 9:0 rev.rdd <- parallelize(sc, rev.nums, 2L) - actual <- collect(pipeRDD(rev.rdd, "sort")) + actual <- collectRDD(pipeRDD(rev.rdd, "sort")) expected <- as.list(as.character(c(5:9, 0:4))) expect_equal(actual, expected) }) @@ -443,7 +444,7 @@ test_that("pipeRDD() on RDDs", { test_that("zipRDD() on RDDs", { rdd1 <- parallelize(sc, 0:4, 2) rdd2 <- parallelize(sc, 1000:1004, 2) - actual <- collect(zipRDD(rdd1, rdd2)) + actual <- collectRDD(zipRDD(rdd1, rdd2)) expect_equal(actual, list(list(0, 1000), list(1, 1001), list(2, 1002), list(3, 1003), list(4, 1004))) @@ -452,17 +453,17 @@ test_that("zipRDD() on RDDs", { writeLines(mockFile, fileName) rdd <- textFile(sc, fileName, 1) - actual <- collect(zipRDD(rdd, rdd)) + actual <- collectRDD(zipRDD(rdd, rdd)) expected <- lapply(mockFile, function(x) { list(x, x) }) expect_equal(actual, expected) rdd1 <- parallelize(sc, 0:1, 1) - actual <- collect(zipRDD(rdd1, rdd)) + actual <- collectRDD(zipRDD(rdd1, rdd)) expected <- lapply(0:1, function(x) { list(x, mockFile[x + 1]) }) expect_equal(actual, expected) rdd1 <- map(rdd, function(x) { x }) - actual <- collect(zipRDD(rdd, rdd1)) + actual <- collectRDD(zipRDD(rdd, rdd1)) expected <- lapply(mockFile, function(x) { list(x, x) }) expect_equal(actual, expected) @@ -471,7 +472,7 @@ test_that("zipRDD() on RDDs", { test_that("cartesian() on RDDs", { rdd <- parallelize(sc, 1:3) - actual <- collect(cartesian(rdd, rdd)) + actual <- collectRDD(cartesian(rdd, rdd)) expect_equal(sortKeyValueList(actual), list( list(1, 1), list(1, 2), list(1, 3), @@ -480,7 +481,7 @@ test_that("cartesian() on RDDs", { # test case where one RDD is empty emptyRdd <- parallelize(sc, list()) - actual <- collect(cartesian(rdd, emptyRdd)) + actual <- collectRDD(cartesian(rdd, emptyRdd)) expect_equal(actual, list()) mockFile <- c("Spark is pretty.", "Spark is awesome.") @@ -488,7 +489,7 @@ test_that("cartesian() on RDDs", { writeLines(mockFile, fileName) rdd <- textFile(sc, fileName) - actual <- collect(cartesian(rdd, rdd)) + actual <- collectRDD(cartesian(rdd, rdd)) expected <- list( list("Spark is awesome.", "Spark is pretty."), list("Spark is awesome.", "Spark is awesome."), @@ -497,7 +498,7 @@ test_that("cartesian() on RDDs", { expect_equal(sortKeyValueList(actual), expected) rdd1 <- parallelize(sc, 0:1) - actual <- collect(cartesian(rdd1, rdd)) + actual <- collectRDD(cartesian(rdd1, rdd)) expect_equal(sortKeyValueList(actual), list( list(0, "Spark is pretty."), @@ -506,7 +507,7 @@ test_that("cartesian() on RDDs", { list(1, "Spark is awesome."))) rdd1 <- map(rdd, function(x) { x }) - actual <- collect(cartesian(rdd, rdd1)) + actual <- collectRDD(cartesian(rdd, rdd1)) expect_equal(sortKeyValueList(actual), expected) unlink(fileName) @@ -517,24 +518,24 @@ test_that("subtract() on RDDs", { rdd1 <- parallelize(sc, l) # subtract by itself - actual <- collect(subtract(rdd1, rdd1)) + actual <- collectRDD(subtract(rdd1, rdd1)) expect_equal(actual, list()) # subtract by an empty RDD rdd2 <- parallelize(sc, list()) - actual <- 
collect(subtract(rdd1, rdd2)) + actual <- collectRDD(subtract(rdd1, rdd2)) expect_equal(as.list(sort(as.vector(actual, mode = "integer"))), l) rdd2 <- parallelize(sc, list(2, 4)) - actual <- collect(subtract(rdd1, rdd2)) + actual <- collectRDD(subtract(rdd1, rdd2)) expect_equal(as.list(sort(as.vector(actual, mode = "integer"))), list(1, 1, 3)) l <- list("a", "a", "b", "b", "c", "d") rdd1 <- parallelize(sc, l) rdd2 <- parallelize(sc, list("b", "d")) - actual <- collect(subtract(rdd1, rdd2)) + actual <- collectRDD(subtract(rdd1, rdd2)) expect_equal(as.list(sort(as.vector(actual, mode = "character"))), list("a", "a", "c")) }) @@ -545,17 +546,17 @@ test_that("subtractByKey() on pairwise RDDs", { rdd1 <- parallelize(sc, l) # subtractByKey by itself - actual <- collect(subtractByKey(rdd1, rdd1)) + actual <- collectRDD(subtractByKey(rdd1, rdd1)) expect_equal(actual, list()) # subtractByKey by an empty RDD rdd2 <- parallelize(sc, list()) - actual <- collect(subtractByKey(rdd1, rdd2)) + actual <- collectRDD(subtractByKey(rdd1, rdd2)) expect_equal(sortKeyValueList(actual), sortKeyValueList(l)) rdd2 <- parallelize(sc, list(list("a", 3), list("c", 1))) - actual <- collect(subtractByKey(rdd1, rdd2)) + actual <- collectRDD(subtractByKey(rdd1, rdd2)) expect_equal(actual, list(list("b", 4), list("b", 5))) @@ -563,76 +564,76 @@ test_that("subtractByKey() on pairwise RDDs", { list(2, 5), list(1, 2)) rdd1 <- parallelize(sc, l) rdd2 <- parallelize(sc, list(list(1, 3), list(3, 1))) - actual <- collect(subtractByKey(rdd1, rdd2)) + actual <- collectRDD(subtractByKey(rdd1, rdd2)) expect_equal(actual, list(list(2, 4), list(2, 5))) }) test_that("intersection() on RDDs", { # intersection with self - actual <- collect(intersection(rdd, rdd)) + actual <- collectRDD(intersection(rdd, rdd)) expect_equal(sort(as.integer(actual)), nums) # intersection with an empty RDD emptyRdd <- parallelize(sc, list()) - actual <- collect(intersection(rdd, emptyRdd)) + actual <- collectRDD(intersection(rdd, emptyRdd)) expect_equal(actual, list()) rdd1 <- parallelize(sc, list(1, 10, 2, 3, 4, 5)) rdd2 <- parallelize(sc, list(1, 6, 2, 3, 7, 8)) - actual <- collect(intersection(rdd1, rdd2)) + actual <- collectRDD(intersection(rdd1, rdd2)) expect_equal(sort(as.integer(actual)), 1:3) }) test_that("join() on pairwise RDDs", { rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) - actual <- collect(join(rdd1, rdd2, 2L)) + actual <- collectRDD(joinRDD(rdd1, rdd2, 2L)) expect_equal(sortKeyValueList(actual), sortKeyValueList(list(list(1, list(1, 2)), list(1, list(1, 3))))) rdd1 <- parallelize(sc, list(list("a", 1), list("b", 4))) rdd2 <- parallelize(sc, list(list("a", 2), list("a", 3))) - actual <- collect(join(rdd1, rdd2, 2L)) + actual <- collectRDD(joinRDD(rdd1, rdd2, 2L)) expect_equal(sortKeyValueList(actual), sortKeyValueList(list(list("a", list(1, 2)), list("a", list(1, 3))))) rdd1 <- parallelize(sc, list(list(1, 1), list(2, 2))) rdd2 <- parallelize(sc, list(list(3, 3), list(4, 4))) - actual <- collect(join(rdd1, rdd2, 2L)) + actual <- collectRDD(joinRDD(rdd1, rdd2, 2L)) expect_equal(actual, list()) rdd1 <- parallelize(sc, list(list("a", 1), list("b", 2))) rdd2 <- parallelize(sc, list(list("c", 3), list("d", 4))) - actual <- collect(join(rdd1, rdd2, 2L)) + actual <- collectRDD(joinRDD(rdd1, rdd2, 2L)) expect_equal(actual, list()) }) test_that("leftOuterJoin() on pairwise RDDs", { rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) - 
actual <- collect(leftOuterJoin(rdd1, rdd2, 2L)) + actual <- collectRDD(leftOuterJoin(rdd1, rdd2, 2L)) expected <- list(list(1, list(1, 2)), list(1, list(1, 3)), list(2, list(4, NULL))) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) rdd1 <- parallelize(sc, list(list("a", 1), list("b", 4))) rdd2 <- parallelize(sc, list(list("a", 2), list("a", 3))) - actual <- collect(leftOuterJoin(rdd1, rdd2, 2L)) + actual <- collectRDD(leftOuterJoin(rdd1, rdd2, 2L)) expected <- list(list("b", list(4, NULL)), list("a", list(1, 2)), list("a", list(1, 3))) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) rdd1 <- parallelize(sc, list(list(1, 1), list(2, 2))) rdd2 <- parallelize(sc, list(list(3, 3), list(4, 4))) - actual <- collect(leftOuterJoin(rdd1, rdd2, 2L)) + actual <- collectRDD(leftOuterJoin(rdd1, rdd2, 2L)) expected <- list(list(1, list(1, NULL)), list(2, list(2, NULL))) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) rdd1 <- parallelize(sc, list(list("a", 1), list("b", 2))) rdd2 <- parallelize(sc, list(list("c", 3), list("d", 4))) - actual <- collect(leftOuterJoin(rdd1, rdd2, 2L)) + actual <- collectRDD(leftOuterJoin(rdd1, rdd2, 2L)) expected <- list(list("b", list(2, NULL)), list("a", list(1, NULL))) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) @@ -641,26 +642,26 @@ test_that("leftOuterJoin() on pairwise RDDs", { test_that("rightOuterJoin() on pairwise RDDs", { rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3))) rdd2 <- parallelize(sc, list(list(1, 1), list(2, 4))) - actual <- collect(rightOuterJoin(rdd1, rdd2, 2L)) + actual <- collectRDD(rightOuterJoin(rdd1, rdd2, 2L)) expected <- list(list(1, list(2, 1)), list(1, list(3, 1)), list(2, list(NULL, 4))) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) rdd1 <- parallelize(sc, list(list("a", 2), list("a", 3))) rdd2 <- parallelize(sc, list(list("a", 1), list("b", 4))) - actual <- collect(rightOuterJoin(rdd1, rdd2, 2L)) + actual <- collectRDD(rightOuterJoin(rdd1, rdd2, 2L)) expected <- list(list("b", list(NULL, 4)), list("a", list(2, 1)), list("a", list(3, 1))) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) rdd1 <- parallelize(sc, list(list(1, 1), list(2, 2))) rdd2 <- parallelize(sc, list(list(3, 3), list(4, 4))) - actual <- collect(rightOuterJoin(rdd1, rdd2, 2L)) + actual <- collectRDD(rightOuterJoin(rdd1, rdd2, 2L)) expect_equal(sortKeyValueList(actual), sortKeyValueList(list(list(3, list(NULL, 3)), list(4, list(NULL, 4))))) rdd1 <- parallelize(sc, list(list("a", 1), list("b", 2))) rdd2 <- parallelize(sc, list(list("c", 3), list("d", 4))) - actual <- collect(rightOuterJoin(rdd1, rdd2, 2L)) + actual <- collectRDD(rightOuterJoin(rdd1, rdd2, 2L)) expect_equal(sortKeyValueList(actual), sortKeyValueList(list(list("d", list(NULL, 4)), list("c", list(NULL, 3))))) }) @@ -668,14 +669,14 @@ test_that("rightOuterJoin() on pairwise RDDs", { test_that("fullOuterJoin() on pairwise RDDs", { rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3), list(3, 3))) rdd2 <- parallelize(sc, list(list(1, 1), list(2, 4))) - actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) + actual <- collectRDD(fullOuterJoin(rdd1, rdd2, 2L)) expected <- list(list(1, list(2, 1)), list(1, list(3, 1)), list(2, list(NULL, 4)), list(3, list(3, NULL))) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) rdd1 <- parallelize(sc, list(list("a", 2), list("a", 3), list("c", 1))) rdd2 <- parallelize(sc, list(list("a", 1), list("b", 4))) - actual <- collect(fullOuterJoin(rdd1, 
rdd2, 2L)) + actual <- collectRDD(fullOuterJoin(rdd1, rdd2, 2L)) expected <- list(list("b", list(NULL, 4)), list("a", list(2, 1)), list("a", list(3, 1)), list("c", list(1, NULL))) expect_equal(sortKeyValueList(actual), @@ -683,14 +684,14 @@ test_that("fullOuterJoin() on pairwise RDDs", { rdd1 <- parallelize(sc, list(list(1, 1), list(2, 2))) rdd2 <- parallelize(sc, list(list(3, 3), list(4, 4))) - actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) + actual <- collectRDD(fullOuterJoin(rdd1, rdd2, 2L)) expect_equal(sortKeyValueList(actual), sortKeyValueList(list(list(1, list(1, NULL)), list(2, list(2, NULL)), list(3, list(NULL, 3)), list(4, list(NULL, 4))))) rdd1 <- parallelize(sc, list(list("a", 1), list("b", 2))) rdd2 <- parallelize(sc, list(list("c", 3), list("d", 4))) - actual <- collect(fullOuterJoin(rdd1, rdd2, 2L)) + actual <- collectRDD(fullOuterJoin(rdd1, rdd2, 2L)) expect_equal(sortKeyValueList(actual), sortKeyValueList(list(list("a", list(1, NULL)), list("b", list(2, NULL)), list("d", list(NULL, 4)), list("c", list(NULL, 3))))) @@ -699,21 +700,21 @@ test_that("fullOuterJoin() on pairwise RDDs", { test_that("sortByKey() on pairwise RDDs", { numPairsRdd <- map(rdd, function(x) { list (x, x) }) sortedRdd <- sortByKey(numPairsRdd, ascending = FALSE) - actual <- collect(sortedRdd) + actual <- collectRDD(sortedRdd) numPairs <- lapply(nums, function(x) { list (x, x) }) expect_equal(actual, sortKeyValueList(numPairs, decreasing = TRUE)) rdd2 <- parallelize(sc, sort(nums, decreasing = TRUE), 2L) numPairsRdd2 <- map(rdd2, function(x) { list (x, x) }) sortedRdd2 <- sortByKey(numPairsRdd2) - actual <- collect(sortedRdd2) + actual <- collectRDD(sortedRdd2) expect_equal(actual, numPairs) # sort by string keys l <- list(list("a", 1), list("b", 2), list("1", 3), list("d", 4), list("2", 5)) rdd3 <- parallelize(sc, l, 2L) sortedRdd3 <- sortByKey(rdd3) - actual <- collect(sortedRdd3) + actual <- collectRDD(sortedRdd3) expect_equal(actual, list(list("1", 3), list("2", 5), list("a", 1), list("b", 2), list("d", 4))) # test on the boundary cases @@ -721,27 +722,27 @@ test_that("sortByKey() on pairwise RDDs", { # boundary case 1: the RDD to be sorted has only 1 partition rdd4 <- parallelize(sc, l, 1L) sortedRdd4 <- sortByKey(rdd4) - actual <- collect(sortedRdd4) + actual <- collectRDD(sortedRdd4) expect_equal(actual, list(list("1", 3), list("2", 5), list("a", 1), list("b", 2), list("d", 4))) # boundary case 2: the sorted RDD has only 1 partition rdd5 <- parallelize(sc, l, 2L) sortedRdd5 <- sortByKey(rdd5, numPartitions = 1L) - actual <- collect(sortedRdd5) + actual <- collectRDD(sortedRdd5) expect_equal(actual, list(list("1", 3), list("2", 5), list("a", 1), list("b", 2), list("d", 4))) # boundary case 3: the RDD to be sorted has only 1 element l2 <- list(list("a", 1)) rdd6 <- parallelize(sc, l2, 2L) sortedRdd6 <- sortByKey(rdd6) - actual <- collect(sortedRdd6) + actual <- collectRDD(sortedRdd6) expect_equal(actual, l2) # boundary case 4: the RDD to be sorted has 0 element l3 <- list() rdd7 <- parallelize(sc, l3, 2L) sortedRdd7 <- sortByKey(rdd7) - actual <- collect(sortedRdd7) + actual <- collectRDD(sortedRdd7) expect_equal(actual, l3) }) @@ -765,7 +766,7 @@ test_that("collectAsMap() on a pairwise RDD", { test_that("show()", { rdd <- parallelize(sc, list(1:10)) - expect_output(show(rdd), "ParallelCollectionRDD\\[\\d+\\] at parallelize at RRDD\\.scala:\\d+") + expect_output(showRDD(rdd), "ParallelCollectionRDD\\[\\d+\\] at parallelize at RRDD\\.scala:\\d+") }) test_that("sampleByKey() on pairwise RDDs", { 
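Taken together, the test-file hunks above and below follow one migration: `sparkR.init()`/`sparkRSQL.init()` give way to `sparkR.session()`, the RDD-level helpers are renamed so they no longer mask the DataFrame generics (`collect` becomes `collectRDD`, `partitionBy` becomes `partitionByRDD`, and so on), and the new `test_jvm_api.R` exercises the public JVM bridge (`sparkR.newJObject`, `sparkR.callJMethod`, `sparkR.callJStatic`). As a rough illustrative sketch of that usage — not part of the patch, assuming a local SparkR 2.x installation, and reaching internal helpers via `SparkR:::` exactly as the test files themselves do:

```R
library(SparkR)

# Start (or reuse) a SparkSession; this replaces the deprecated sparkR.init().
sparkSession <- sparkR.session(enableHiveSupport = FALSE)
setLogLevel("ERROR")  # no longer takes an sc argument after the migration

# Distribute a simple computation; spark.lapply no longer needs an sc handle.
doubled <- spark.lapply(1:10, function(x) { 2 * x })
stopifnot(unlist(doubled) == 2 * 1:10)

# The public JVM API exercised by the new test_jvm_api.R.
jarr <- sparkR.newJObject("java.util.ArrayList")
sparkR.callJMethod(jarr, "add", 1L)
stopifnot(sparkR.callJMethod(jarr, "get", 0L) == 1L)
stopifnot(sparkR.callJStatic("java.lang.String", "valueOf", TRUE) == "true")

# RDD-level tests fetch the JavaSparkContext through the JVM bridge and use
# the renamed *RDD helpers; these are internal functions, hence ":::".
sc <- SparkR:::callJStatic("org.apache.spark.sql.api.r.SQLUtils",
                           "getJavaSparkContext", sparkSession)
rdd <- SparkR:::parallelize(sc, 1:10, 2L)
stopifnot(unlist(SparkR:::collectRDD(rdd)) == 1:10)

sparkR.session.stop()
```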
diff --git a/R/pkg/inst/tests/testthat/test_shuffle.R b/R/pkg/inst/tests/testthat/test_shuffle.R index d3d0f8a24d01c..07f3b02df6649 100644 --- a/R/pkg/inst/tests/testthat/test_shuffle.R +++ b/R/pkg/inst/tests/testthat/test_shuffle.R @@ -18,7 +18,8 @@ context("partitionBy, groupByKey, reduceByKey etc.") # JavaSparkContext handle -sc <- sparkR.init() +sparkSession <- sparkR.session() +sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) # Data intPairs <- list(list(1L, -1), list(2L, 100), list(2L, 1), list(1L, 200)) @@ -38,7 +39,7 @@ strListRDD <- parallelize(sc, strList, 4) test_that("groupByKey for integers", { grouped <- groupByKey(intRdd, 2L) - actual <- collect(grouped) + actual <- collectRDD(grouped) expected <- list(list(2L, list(100, 1)), list(1L, list(-1, 200))) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) @@ -47,7 +48,7 @@ test_that("groupByKey for integers", { test_that("groupByKey for doubles", { grouped <- groupByKey(doubleRdd, 2L) - actual <- collect(grouped) + actual <- collectRDD(grouped) expected <- list(list(1.5, list(-1, 200)), list(2.5, list(100, 1))) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) @@ -56,7 +57,7 @@ test_that("groupByKey for doubles", { test_that("reduceByKey for ints", { reduced <- reduceByKey(intRdd, "+", 2L) - actual <- collect(reduced) + actual <- collectRDD(reduced) expected <- list(list(2L, 101), list(1L, 199)) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) @@ -64,7 +65,7 @@ test_that("reduceByKey for ints", { test_that("reduceByKey for doubles", { reduced <- reduceByKey(doubleRdd, "+", 2L) - actual <- collect(reduced) + actual <- collectRDD(reduced) expected <- list(list(1.5, 199), list(2.5, 101)) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) @@ -73,7 +74,7 @@ test_that("reduceByKey for doubles", { test_that("combineByKey for ints", { reduced <- combineByKey(intRdd, function(x) { x }, "+", "+", 2L) - actual <- collect(reduced) + actual <- collectRDD(reduced) expected <- list(list(2L, 101), list(1L, 199)) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) @@ -81,7 +82,7 @@ test_that("combineByKey for ints", { test_that("combineByKey for doubles", { reduced <- combineByKey(doubleRdd, function(x) { x }, "+", "+", 2L) - actual <- collect(reduced) + actual <- collectRDD(reduced) expected <- list(list(1.5, 199), list(2.5, 101)) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) @@ -93,7 +94,7 @@ test_that("combineByKey for characters", { list("other", 3L), list("max", 4L)), 2L) reduced <- combineByKey(stringKeyRDD, function(x) { x }, "+", "+", 2L) - actual <- collect(reduced) + actual <- collectRDD(reduced) expected <- list(list("max", 5L), list("min", 2L), list("other", 3L)) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) @@ -108,7 +109,7 @@ test_that("aggregateByKey", { combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } aggregatedRDD <- aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L) - actual <- collect(aggregatedRDD) + actual <- collectRDD(aggregatedRDD) expected <- list(list(1, list(3, 2)), list(2, list(7, 2))) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) @@ -121,7 +122,7 @@ test_that("aggregateByKey", { combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } aggregatedRDD <- aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L) - actual <- collect(aggregatedRDD) + actual <- collectRDD(aggregatedRDD) expected <- list(list("a", 
list(3, 2)), list("b", list(7, 2))) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) @@ -131,7 +132,7 @@ test_that("foldByKey", { # test foldByKey for int keys folded <- foldByKey(intRdd, 0, "+", 2L) - actual <- collect(folded) + actual <- collectRDD(folded) expected <- list(list(2L, 101), list(1L, 199)) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) @@ -139,7 +140,7 @@ test_that("foldByKey", { # test foldByKey for double keys folded <- foldByKey(doubleRdd, 0, "+", 2L) - actual <- collect(folded) + actual <- collectRDD(folded) expected <- list(list(1.5, 199), list(2.5, 101)) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) @@ -150,7 +151,7 @@ test_that("foldByKey", { stringKeyRDD <- parallelize(sc, stringKeyPairs) folded <- foldByKey(stringKeyRDD, 0, "+", 2L) - actual <- collect(folded) + actual <- collectRDD(folded) expected <- list(list("b", 101), list("a", 199)) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) @@ -158,14 +159,14 @@ test_that("foldByKey", { # test foldByKey for empty pair RDD rdd <- parallelize(sc, list()) folded <- foldByKey(rdd, 0, "+", 2L) - actual <- collect(folded) + actual <- collectRDD(folded) expected <- list() expect_equal(actual, expected) # test foldByKey for RDD with only 1 pair rdd <- parallelize(sc, list(list(1, 1))) folded <- foldByKey(rdd, 0, "+", 2L) - actual <- collect(folded) + actual <- collectRDD(folded) expected <- list(list(1, 1)) expect_equal(actual, expected) }) @@ -174,7 +175,7 @@ test_that("partitionBy() partitions data correctly", { # Partition by magnitude partitionByMagnitude <- function(key) { if (key >= 3) 1 else 0 } - resultRDD <- partitionBy(numPairsRdd, 2L, partitionByMagnitude) + resultRDD <- partitionByRDD(numPairsRdd, 2L, partitionByMagnitude) expected_first <- list(list(1, 100), list(2, 200)) # key less than 3 expected_second <- list(list(4, -1), list(3, 1), list(3, 0)) # key greater than or equal 3 @@ -190,7 +191,7 @@ test_that("partitionBy works with dependencies", { partitionByParity <- function(key) { if (key %% 2 == kOne) 7 else 4 } # Partition by parity - resultRDD <- partitionBy(numPairsRdd, numPartitions = 2L, partitionByParity) + resultRDD <- partitionByRDD(numPairsRdd, numPartitions = 2L, partitionByParity) # keys even; 100 %% 2 == 0 expected_first <- list(list(2, 200), list(4, -1)) @@ -207,7 +208,7 @@ test_that("test partitionBy with string keys", { words <- flatMap(strListRDD, function(line) { strsplit(line, " ")[[1]] }) wordCount <- lapply(words, function(word) { list(word, 1L) }) - resultRDD <- partitionBy(wordCount, 2L) + resultRDD <- partitionByRDD(wordCount, 2L) expected_first <- list(list("Dexter", 1), list("Dexter", 1)) expected_second <- list(list("and", 1), list("and", 1)) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 5cf9dc405b169..ef6cab1e1c92d 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -32,17 +32,43 @@ markUtf8 <- function(s) { s } -# Tests for SparkSQL functions in SparkR +setHiveContext <- function(sc) { + if (exists(".testHiveSession", envir = .sparkREnv)) { + hiveSession <- get(".testHiveSession", envir = .sparkREnv) + } else { + # initialize once and reuse + ssc <- callJMethod(sc, "sc") + hiveCtx <- tryCatch({ + newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc, FALSE) + }, + error = function(err) { + skip("Hive is not build with SparkSQL, skipped") + }) + hiveSession <- callJMethod(hiveCtx, 
"sparkSession") + } + previousSession <- get(".sparkRsession", envir = .sparkREnv) + assign(".sparkRsession", hiveSession, envir = .sparkREnv) + assign(".prevSparkRsession", previousSession, envir = .sparkREnv) + hiveSession +} -sc <- sparkR.init() +unsetHiveContext <- function() { + previousSession <- get(".prevSparkRsession", envir = .sparkREnv) + assign(".sparkRsession", previousSession, envir = .sparkREnv) + remove(".prevSparkRsession", envir = .sparkREnv) +} -sqlContext <- sparkRSQL.init(sc) +# Tests for SparkSQL functions in SparkR + +sparkSession <- sparkR.session() +sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) mockLines <- c("{\"name\":\"Michael\"}", "{\"name\":\"Andy\", \"age\":30}", "{\"name\":\"Justin\", \"age\":19}") jsonPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") parquetPath <- tempfile(pattern = "sparkr-test", fileext = ".parquet") +orcPath <- tempfile(pattern = "sparkr-test", fileext = ".orc") writeLines(mockLines, jsonPath) # For test nafunctions, like dropna(), fillna(),... @@ -63,7 +89,16 @@ complexTypeJsonPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") writeLines(mockLinesComplexType, complexTypeJsonPath) test_that("calling sparkRSQL.init returns existing SQL context", { - expect_equal(sparkRSQL.init(sc), sqlContext) + sqlContext <- suppressWarnings(sparkRSQL.init(sc)) + expect_equal(suppressWarnings(sparkRSQL.init(sc)), sqlContext) +}) + +test_that("calling sparkRSQL.init returns existing SparkSession", { + expect_equal(suppressWarnings(sparkRSQL.init(sc)), sparkSession) +}) + +test_that("calling sparkR.session returns existing SparkSession", { + expect_equal(sparkR.session(), sparkSession) }) test_that("infer types and check types", { @@ -99,8 +134,8 @@ test_that("structType and structField", { test_that("create DataFrame from RDD", { rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) - df <- createDataFrame(sqlContext, rdd, list("a", "b")) - dfAsDF <- as.DataFrame(sqlContext, rdd, list("a", "b")) + df <- createDataFrame(rdd, list("a", "b")) + dfAsDF <- as.DataFrame(rdd, list("a", "b")) expect_is(df, "SparkDataFrame") expect_is(dfAsDF, "SparkDataFrame") expect_equal(count(df), 10) @@ -116,8 +151,8 @@ test_that("create DataFrame from RDD", { expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) expect_equal(dtypes(dfAsDF), list(c("a", "int"), c("b", "string"))) - df <- createDataFrame(sqlContext, rdd) - dfAsDF <- as.DataFrame(sqlContext, rdd) + df <- createDataFrame(rdd) + dfAsDF <- as.DataFrame(rdd) expect_is(df, "SparkDataFrame") expect_is(dfAsDF, "SparkDataFrame") expect_equal(columns(df), c("_1", "_2")) @@ -125,13 +160,13 @@ test_that("create DataFrame from RDD", { schema <- structType(structField(x = "a", type = "integer", nullable = TRUE), structField(x = "b", type = "string", nullable = TRUE)) - df <- createDataFrame(sqlContext, rdd, schema) + df <- createDataFrame(rdd, schema) expect_is(df, "SparkDataFrame") expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) rdd <- lapply(parallelize(sc, 1:10), function(x) { list(a = x, b = as.character(x)) }) - df <- createDataFrame(sqlContext, rdd) + df <- createDataFrame(rdd) expect_is(df, "SparkDataFrame") expect_equal(count(df), 10) expect_equal(columns(df), c("a", "b")) @@ -139,9 +174,9 @@ test_that("create DataFrame from RDD", { schema <- structType(structField("name", "string"), structField("age", "integer"), structField("height", "float")) - df <- 
read.df(sqlContext, jsonPathNa, "json", schema) - df2 <- createDataFrame(sqlContext, toRDD(df), schema) - df2AsDF <- as.DataFrame(sqlContext, toRDD(df), schema) + df <- read.df(jsonPathNa, "json", schema) + df2 <- createDataFrame(toRDD(df), schema) + df2AsDF <- as.DataFrame(toRDD(df), schema) expect_equal(columns(df2), c("name", "age", "height")) expect_equal(columns(df2AsDF), c("name", "age", "height")) expect_equal(dtypes(df2), list(c("name", "string"), c("age", "int"), c("height", "float"))) @@ -154,7 +189,7 @@ test_that("create DataFrame from RDD", { localDF <- data.frame(name = c("John", "Smith", "Sarah"), age = c(19L, 23L, 18L), height = c(176.5, 181.4, 173.7)) - df <- createDataFrame(sqlContext, localDF, schema) + df <- createDataFrame(localDF, schema) expect_is(df, "SparkDataFrame") expect_equal(count(df), 3) expect_equal(columns(df), c("name", "age", "height")) @@ -162,55 +197,109 @@ test_that("create DataFrame from RDD", { expect_equal(as.list(collect(where(df, df$name == "John"))), list(name = "John", age = 19L, height = 176.5)) - ssc <- callJMethod(sc, "sc") - hiveCtx <- tryCatch({ - newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc) - }, - error = function(err) { - skip("Hive is not build with SparkSQL, skipped") - }) - sql(hiveCtx, "CREATE TABLE people (name string, age double, height float)") - df <- read.df(hiveCtx, jsonPathNa, "json", schema) + setHiveContext(sc) + sql("CREATE TABLE people (name string, age double, height float)") + df <- read.df(jsonPathNa, "json", schema) invisible(insertInto(df, "people")) - expect_equal(collect(sql(hiveCtx, "SELECT age from people WHERE name = 'Bob'"))$age, + expect_equal(collect(sql("SELECT age from people WHERE name = 'Bob'"))$age, c(16)) - expect_equal(collect(sql(hiveCtx, "SELECT height from people WHERE name ='Bob'"))$height, + expect_equal(collect(sql("SELECT height from people WHERE name ='Bob'"))$height, c(176.5)) + unsetHiveContext() +}) + +test_that("createDataFrame uses files for large objects", { + # To simulate a large file scenario, we set spark.r.maxAllocationLimit to a smaller value + conf <- callJMethod(sparkSession, "conf") + callJMethod(conf, "set", "spark.r.maxAllocationLimit", "100") + df <- suppressWarnings(createDataFrame(iris)) + + # Resetting the conf back to default value + callJMethod(conf, "set", "spark.r.maxAllocationLimit", toString(.Machine$integer.max / 10)) + expect_equal(dim(df), dim(iris)) +}) + +test_that("read/write csv as DataFrame", { + csvPath <- tempfile(pattern = "sparkr-test", fileext = ".csv") + mockLinesCsv <- c("year,make,model,comment,blank", + "\"2012\",\"Tesla\",\"S\",\"No comment\",", + "1997,Ford,E350,\"Go get one now they are going fast\",", + "2015,Chevy,Volt", + "NA,Dummy,Placeholder") + writeLines(mockLinesCsv, csvPath) + + # default "header" is false, inferSchema to handle "year" as "int" + df <- read.df(csvPath, "csv", header = "true", inferSchema = "true") + expect_equal(count(df), 4) + expect_equal(columns(df), c("year", "make", "model", "comment", "blank")) + expect_equal(sort(unlist(collect(where(df, df$year == 2015)))), + sort(unlist(list(year = 2015, make = "Chevy", model = "Volt")))) + + # since "year" is "int", let's skip the NA values + withoutna <- na.omit(df, how = "any", cols = "year") + expect_equal(count(withoutna), 3) + + unlink(csvPath) + csvPath <- tempfile(pattern = "sparkr-test", fileext = ".csv") + mockLinesCsv <- c("year,make,model,comment,blank", + "\"2012\",\"Tesla\",\"S\",\"No comment\",", + "1997,Ford,E350,\"Go get one now they are going 
fast\",", + "2015,Chevy,Volt", + "Empty,Dummy,Placeholder") + writeLines(mockLinesCsv, csvPath) + + df2 <- read.df(csvPath, "csv", header = "true", inferSchema = "true", na.strings = "Empty") + expect_equal(count(df2), 4) + withoutna2 <- na.omit(df2, how = "any", cols = "year") + expect_equal(count(withoutna2), 3) + expect_equal(count(where(withoutna2, withoutna2$make == "Dummy")), 0) + + # writing csv file + csvPath2 <- tempfile(pattern = "csvtest2", fileext = ".csv") + write.df(df2, path = csvPath2, "csv", header = "true") + df3 <- read.df(csvPath2, "csv", header = "true") + expect_equal(nrow(df3), nrow(df2)) + expect_equal(colnames(df3), colnames(df2)) + csv <- read.csv(file = list.files(csvPath2, pattern = "^part", full.names = T)[[1]]) + expect_equal(colnames(df3), colnames(csv)) + + unlink(csvPath) + unlink(csvPath2) }) test_that("convert NAs to null type in DataFrames", { rdd <- parallelize(sc, list(list(1L, 2L), list(NA, 4L))) - df <- createDataFrame(sqlContext, rdd, list("a", "b")) + df <- createDataFrame(rdd, list("a", "b")) expect_true(is.na(collect(df)[2, "a"])) expect_equal(collect(df)[2, "b"], 4L) l <- data.frame(x = 1L, y = c(1L, NA_integer_, 3L)) - df <- createDataFrame(sqlContext, l) + df <- createDataFrame(l) expect_equal(collect(df)[2, "x"], 1L) expect_true(is.na(collect(df)[2, "y"])) rdd <- parallelize(sc, list(list(1, 2), list(NA, 4))) - df <- createDataFrame(sqlContext, rdd, list("a", "b")) + df <- createDataFrame(rdd, list("a", "b")) expect_true(is.na(collect(df)[2, "a"])) expect_equal(collect(df)[2, "b"], 4) l <- data.frame(x = 1, y = c(1, NA_real_, 3)) - df <- createDataFrame(sqlContext, l) + df <- createDataFrame(l) expect_equal(collect(df)[2, "x"], 1) expect_true(is.na(collect(df)[2, "y"])) l <- list("a", "b", NA, "d") - df <- createDataFrame(sqlContext, l) + df <- createDataFrame(l) expect_true(is.na(collect(df)[3, "_1"])) expect_equal(collect(df)[4, "_1"], "d") l <- list("a", "b", NA_character_, "d") - df <- createDataFrame(sqlContext, l) + df <- createDataFrame(l) expect_true(is.na(collect(df)[3, "_1"])) expect_equal(collect(df)[4, "_1"], "d") l <- list(TRUE, FALSE, NA, TRUE) - df <- createDataFrame(sqlContext, l) + df <- createDataFrame(l) expect_true(is.na(collect(df)[3, "_1"])) expect_equal(collect(df)[4, "_1"], TRUE) }) @@ -244,46 +333,59 @@ test_that("toDF", { test_that("create DataFrame from list or data.frame", { l <- list(list(1, 2), list(3, 4)) - df <- createDataFrame(sqlContext, l, c("a", "b")) + df <- createDataFrame(l, c("a", "b")) expect_equal(columns(df), c("a", "b")) l <- list(list(a = 1, b = 2), list(a = 3, b = 4)) - df <- createDataFrame(sqlContext, l) + df <- createDataFrame(l) expect_equal(columns(df), c("a", "b")) a <- 1:3 b <- c("a", "b", "c") ldf <- data.frame(a, b) - df <- createDataFrame(sqlContext, ldf) + df <- createDataFrame(ldf) expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) expect_equal(count(df), 3) ldf2 <- collect(df) expect_equal(ldf$a, ldf2$a) - irisdf <- suppressWarnings(createDataFrame(sqlContext, iris)) + irisdf <- suppressWarnings(createDataFrame(iris)) iris_collected <- collect(irisdf) expect_equivalent(iris_collected[, -5], iris[, -5]) expect_equal(iris_collected$Species, as.character(iris$Species)) - mtcarsdf <- createDataFrame(sqlContext, mtcars) + mtcarsdf <- createDataFrame(mtcars) expect_equivalent(collect(mtcarsdf), mtcars) bytes <- as.raw(c(1, 2, 3)) - df <- createDataFrame(sqlContext, list(list(bytes))) + df <- createDataFrame(list(list(bytes))) 
expect_equal(collect(df)[[1]][[1]], bytes) }) test_that("create DataFrame with different data types", { l <- list(a = 1L, b = 2, c = TRUE, d = "ss", e = as.Date("2012-12-13"), f = as.POSIXct("2015-03-15 12:13:14.056")) - df <- createDataFrame(sqlContext, list(l)) + df <- createDataFrame(list(l)) expect_equal(dtypes(df), list(c("a", "int"), c("b", "double"), c("c", "boolean"), c("d", "string"), c("e", "date"), c("f", "timestamp"))) expect_equal(count(df), 1) expect_equal(collect(df), data.frame(l, stringsAsFactors = FALSE)) }) +test_that("SPARK-17811: can create DataFrame containing NA as date and time", { + df <- data.frame( + id = 1:2, + time = c(as.POSIXlt("2016-01-10"), NA), + date = c(as.Date("2016-10-01"), NA)) + + DF <- collect(createDataFrame(df)) + expect_true(is.na(DF$date[2])) + expect_equal(DF$date[1], as.Date("2016-10-01")) + expect_true(is.na(DF$time[2])) + expect_equal(DF$time[1], as.POSIXlt("2016-01-10")) +}) + test_that("create DataFrame with complex types", { e <- new.env() assign("n", 3L, envir = e) @@ -291,7 +393,7 @@ test_that("create DataFrame with complex types", { s <- listToStruct(list(a = "aa", b = 3L)) l <- list(as.list(1:10), list("a", "b"), e, s) - df <- createDataFrame(sqlContext, list(l), c("a", "b", "c", "d")) + df <- createDataFrame(list(l), c("a", "b", "c", "d")) expect_equal(dtypes(df), list(c("a", "array"), c("b", "array"), c("c", "map"), @@ -318,7 +420,7 @@ test_that("create DataFrame from a data.frame with complex types", { ldf$a_list <- list(list(1, 2), list(3, 4)) ldf$an_envir <- c(as.environment(list(a = 1, b = 2)), as.environment(list(c = 3))) - sdf <- createDataFrame(sqlContext, ldf) + sdf <- createDataFrame(ldf) collected <- collect(sdf) expect_identical(ldf[, 1, FALSE], collected[, 1, FALSE]) @@ -334,7 +436,7 @@ writeLines(mockLinesMapType, mapTypeJsonPath) test_that("Collect DataFrame with complex types", { # ArrayType - df <- read.json(sqlContext, complexTypeJsonPath) + df <- read.json(complexTypeJsonPath) ldf <- collect(df) expect_equal(nrow(ldf), 3) expect_equal(ncol(ldf), 3) @@ -346,7 +448,7 @@ test_that("Collect DataFrame with complex types", { # MapType schema <- structType(structField("name", "string"), structField("info", "map")) - df <- read.df(sqlContext, mapTypeJsonPath, "json", schema) + df <- read.df(mapTypeJsonPath, "json", schema) expect_equal(dtypes(df), list(c("name", "string"), c("info", "map"))) ldf <- collect(df) @@ -360,7 +462,7 @@ test_that("Collect DataFrame with complex types", { expect_equal(bob$height, 176.5) # StructType - df <- read.json(sqlContext, mapTypeJsonPath) + df <- read.json(mapTypeJsonPath) expect_equal(dtypes(df), list(c("info", "struct"), c("name", "string"))) ldf <- collect(df) @@ -376,7 +478,7 @@ test_that("Collect DataFrame with complex types", { test_that("read/write json files", { # Test read.df - df <- read.df(sqlContext, jsonPath, "json") + df <- read.df(jsonPath, "json") expect_is(df, "SparkDataFrame") expect_equal(count(df), 3) @@ -384,17 +486,17 @@ test_that("read/write json files", { schema <- structType(structField("name", type = "string"), structField("age", type = "double")) - df1 <- read.df(sqlContext, jsonPath, "json", schema) + df1 <- read.df(jsonPath, "json", schema) expect_is(df1, "SparkDataFrame") expect_equal(dtypes(df1), list(c("name", "string"), c("age", "double"))) # Test loadDF - df2 <- loadDF(sqlContext, jsonPath, "json", schema) + df2 <- loadDF(jsonPath, "json", schema) expect_is(df2, "SparkDataFrame") expect_equal(dtypes(df2), list(c("name", "string"), c("age", "double"))) # 
Test read.json - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) expect_is(df, "SparkDataFrame") expect_equal(count(df), 3) @@ -407,11 +509,11 @@ test_that("read/write json files", { write.json(df, jsonPath3) # Test read.json()/jsonFile() works with multiple input paths - jsonDF1 <- read.json(sqlContext, c(jsonPath2, jsonPath3)) + jsonDF1 <- read.json(c(jsonPath2, jsonPath3)) expect_is(jsonDF1, "SparkDataFrame") expect_equal(count(jsonDF1), 6) # Suppress warnings because jsonFile is deprecated - jsonDF2 <- suppressWarnings(jsonFile(sqlContext, c(jsonPath2, jsonPath3))) + jsonDF2 <- suppressWarnings(jsonFile(c(jsonPath2, jsonPath3))) expect_is(jsonDF2, "SparkDataFrame") expect_equal(count(jsonDF2), 6) @@ -420,8 +522,9 @@ test_that("read/write json files", { }) test_that("jsonRDD() on a RDD with json string", { + sqlContext <- suppressWarnings(sparkRSQL.init(sc)) rdd <- parallelize(sc, mockLines) - expect_equal(count(rdd), 3) + expect_equal(countRDD(rdd), 3) df <- suppressWarnings(jsonRDD(sqlContext, rdd)) expect_is(df, "SparkDataFrame") expect_equal(count(df), 3) @@ -432,89 +535,109 @@ test_that("jsonRDD() on a RDD with json string", { expect_equal(count(df), 6) }) -test_that("test cache, uncache and clearCache", { - df <- read.json(sqlContext, jsonPath) - registerTempTable(df, "table1") - cacheTable(sqlContext, "table1") - uncacheTable(sqlContext, "table1") - clearCache(sqlContext) - dropTempTable(sqlContext, "table1") -}) - test_that("test tableNames and tables", { - df <- read.json(sqlContext, jsonPath) - registerTempTable(df, "table1") - expect_equal(length(tableNames(sqlContext)), 1) - df <- tables(sqlContext) - expect_equal(count(df), 1) - dropTempTable(sqlContext, "table1") -}) - -test_that("registerTempTable() results in a queryable table and sql() results in a new DataFrame", { - df <- read.json(sqlContext, jsonPath) - registerTempTable(df, "table1") - newdf <- sql(sqlContext, "SELECT * FROM table1 where name = 'Michael'") + df <- read.json(jsonPath) + createOrReplaceTempView(df, "table1") + expect_equal(length(tableNames()), 1) + tables <- tables() + expect_equal(count(tables), 1) + + suppressWarnings(registerTempTable(df, "table2")) + tables <- tables() + expect_equal(count(tables), 2) + suppressWarnings(dropTempTable("table1")) + dropTempView("table2") + + tables <- tables() + expect_equal(count(tables), 0) +}) + +test_that( + "createOrReplaceTempView() results in a queryable table and sql() results in a new DataFrame", { + df <- read.json(jsonPath) + createOrReplaceTempView(df, "table1") + newdf <- sql("SELECT * FROM table1 where name = 'Michael'") expect_is(newdf, "SparkDataFrame") expect_equal(count(newdf), 1) - dropTempTable(sqlContext, "table1") + dropTempView("table1") + + createOrReplaceTempView(df, "dfView") + sqlCast <- collect(sql("select cast('2' as decimal) as x from dfView limit 1")) + out <- capture.output(sqlCast) + expect_true(is.data.frame(sqlCast)) + expect_equal(names(sqlCast)[1], "x") + expect_equal(nrow(sqlCast), 1) + expect_equal(ncol(sqlCast), 1) + expect_equal(out[1], " x") + expect_equal(out[2], "1 2") + dropTempView("dfView") +}) + +test_that("test cache, uncache and clearCache", { + df <- read.json(jsonPath) + createOrReplaceTempView(df, "table1") + cacheTable("table1") + uncacheTable("table1") + clearCache() + dropTempView("table1") }) test_that("insertInto() on a registered table", { - df <- read.df(sqlContext, jsonPath, "json") + df <- read.df(jsonPath, "json") write.df(df, parquetPath, "parquet", "overwrite") - dfParquet <- 
read.df(sqlContext, parquetPath, "parquet") + dfParquet <- read.df(parquetPath, "parquet") lines <- c("{\"name\":\"Bob\", \"age\":24}", "{\"name\":\"James\", \"age\":35}") jsonPath2 <- tempfile(pattern = "jsonPath2", fileext = ".tmp") parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") writeLines(lines, jsonPath2) - df2 <- read.df(sqlContext, jsonPath2, "json") + df2 <- read.df(jsonPath2, "json") write.df(df2, parquetPath2, "parquet", "overwrite") - dfParquet2 <- read.df(sqlContext, parquetPath2, "parquet") + dfParquet2 <- read.df(parquetPath2, "parquet") - registerTempTable(dfParquet, "table1") + createOrReplaceTempView(dfParquet, "table1") insertInto(dfParquet2, "table1") - expect_equal(count(sql(sqlContext, "select * from table1")), 5) - expect_equal(first(sql(sqlContext, "select * from table1 order by age"))$name, "Michael") - dropTempTable(sqlContext, "table1") + expect_equal(count(sql("select * from table1")), 5) + expect_equal(first(sql("select * from table1 order by age"))$name, "Michael") + dropTempView("table1") - registerTempTable(dfParquet, "table1") + createOrReplaceTempView(dfParquet, "table1") insertInto(dfParquet2, "table1", overwrite = TRUE) - expect_equal(count(sql(sqlContext, "select * from table1")), 2) - expect_equal(first(sql(sqlContext, "select * from table1 order by age"))$name, "Bob") - dropTempTable(sqlContext, "table1") + expect_equal(count(sql("select * from table1")), 2) + expect_equal(first(sql("select * from table1 order by age"))$name, "Bob") + dropTempView("table1") unlink(jsonPath2) unlink(parquetPath2) }) test_that("tableToDF() returns a new DataFrame", { - df <- read.json(sqlContext, jsonPath) - registerTempTable(df, "table1") - tabledf <- tableToDF(sqlContext, "table1") + df <- read.json(jsonPath) + createOrReplaceTempView(df, "table1") + tabledf <- tableToDF("table1") expect_is(tabledf, "SparkDataFrame") expect_equal(count(tabledf), 3) - tabledf2 <- tableToDF(sqlContext, "table1") + tabledf2 <- tableToDF("table1") expect_equal(count(tabledf2), 3) - dropTempTable(sqlContext, "table1") + dropTempView("table1") }) test_that("toRDD() returns an RRDD", { - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) testRDD <- toRDD(df) expect_is(testRDD, "RDD") - expect_equal(count(testRDD), 3) + expect_equal(countRDD(testRDD), 3) }) test_that("union on two RDDs created from DataFrames returns an RRDD", { - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) RDD1 <- toRDD(df) RDD2 <- toRDD(df) unioned <- unionRDD(RDD1, RDD2) expect_is(unioned, "RDD") expect_equal(getSerializedMode(unioned), "byte") - expect_equal(collect(unioned)[[2]]$name, "Andy") + expect_equal(collectRDD(unioned)[[2]]$name, "Andy") }) test_that("union on mixed serialization types correctly returns a byte RRDD", { @@ -530,48 +653,48 @@ test_that("union on mixed serialization types correctly returns a byte RRDD", { writeLines(textLines, textPath) textRDD <- textFile(sc, textPath) - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) dfRDD <- toRDD(df) unionByte <- unionRDD(rdd, dfRDD) expect_is(unionByte, "RDD") expect_equal(getSerializedMode(unionByte), "byte") - expect_equal(collect(unionByte)[[1]], 1) - expect_equal(collect(unionByte)[[12]]$name, "Andy") + expect_equal(collectRDD(unionByte)[[1]], 1) + expect_equal(collectRDD(unionByte)[[12]]$name, "Andy") unionString <- unionRDD(textRDD, dfRDD) expect_is(unionString, "RDD") expect_equal(getSerializedMode(unionString), "byte") - expect_equal(collect(unionString)[[1]], "Michael") - 
expect_equal(collect(unionString)[[5]]$name, "Andy") + expect_equal(collectRDD(unionString)[[1]], "Michael") + expect_equal(collectRDD(unionString)[[5]]$name, "Andy") }) test_that("objectFile() works with row serialization", { objectPath <- tempfile(pattern = "spark-test", fileext = ".tmp") - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) dfRDD <- toRDD(df) saveAsObjectFile(coalesce(dfRDD, 1L), objectPath) objectIn <- objectFile(sc, objectPath) expect_is(objectIn, "RDD") expect_equal(getSerializedMode(objectIn), "byte") - expect_equal(collect(objectIn)[[2]]$age, 30) + expect_equal(collectRDD(objectIn)[[2]]$age, 30) }) test_that("lapply() on a DataFrame returns an RDD with the correct columns", { - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) testRDD <- lapply(df, function(row) { row$newCol <- row$age + 5 row }) expect_is(testRDD, "RDD") - collected <- collect(testRDD) + collected <- collectRDD(testRDD) expect_equal(collected[[1]]$name, "Michael") expect_equal(collected[[2]]$newCol, 35) }) test_that("collect() returns a data.frame", { - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) rdf <- collect(df) expect_true(is.data.frame(rdf)) expect_equal(names(rdf)[1], "age") @@ -587,20 +710,20 @@ test_that("collect() returns a data.frame", { expect_equal(ncol(rdf), 2) # collect() correctly handles multiple columns with same name - df <- createDataFrame(sqlContext, list(list(1, 2)), schema = c("name", "name")) + df <- createDataFrame(list(list(1, 2)), schema = c("name", "name")) ldf <- collect(df) expect_equal(names(ldf), c("name", "name")) }) test_that("limit() returns DataFrame with the correct number of rows", { - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) dfLimited <- limit(df, 2) expect_is(dfLimited, "SparkDataFrame") expect_equal(count(dfLimited), 2) }) test_that("collect() and take() on a DataFrame return the same number of rows and columns", { - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) expect_equal(nrow(collect(df)), nrow(take(df, 10))) expect_equal(ncol(collect(df)), ncol(take(df, 10))) }) @@ -614,7 +737,7 @@ test_that("collect() support Unicode characters", { jsonPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") writeLines(lines, jsonPath) - df <- read.df(sqlContext, jsonPath, "json") + df <- read.df(jsonPath, "json") rdf <- collect(df) expect_true(is.data.frame(rdf)) expect_equal(rdf$name[1], markUtf8("안녕하세요")) @@ -622,12 +745,12 @@ test_that("collect() support Unicode characters", { expect_equal(rdf$name[3], markUtf8("こんにちは")) expect_equal(rdf$name[4], markUtf8("Xin chào")) - df1 <- createDataFrame(sqlContext, rdf) + df1 <- createDataFrame(rdf) expect_equal(collect(where(df1, df1$name == markUtf8("您好")))$name, markUtf8("您好")) }) test_that("multiple pipeline transformations result in an RDD with the correct values", { - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) first <- lapply(df, function(row) { row$age <- row$age + 5 row @@ -637,14 +760,14 @@ test_that("multiple pipeline transformations result in an RDD with the correct v row }) expect_is(second, "RDD") - expect_equal(count(second), 3) - expect_equal(collect(second)[[2]]$age, 35) - expect_true(collect(second)[[2]]$testCol) - expect_false(collect(second)[[3]]$testCol) + expect_equal(countRDD(second), 3) + expect_equal(collectRDD(second)[[2]]$age, 35) + expect_true(collectRDD(second)[[2]]$testCol) + expect_false(collectRDD(second)[[3]]$testCol) }) test_that("cache(), persist(), and unpersist() on a 
DataFrame", { - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) expect_false(df@env$isCached) cache(df) expect_true(df@env$isCached) @@ -663,7 +786,7 @@ test_that("cache(), persist(), and unpersist() on a DataFrame", { }) test_that("schema(), dtypes(), columns(), names() return the correct values/format", { - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) testSchema <- schema(df) expect_equal(length(testSchema$fields()), 2) expect_equal(testSchema$fields()[[1]]$dataType.toString(), "LongType") @@ -684,7 +807,7 @@ test_that("schema(), dtypes(), columns(), names() return the correct values/form }) test_that("names() colnames() set the column names", { - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) names(df) <- c("col1", "col2") expect_equal(colnames(df)[2], "col2") @@ -692,14 +815,14 @@ test_that("names() colnames() set the column names", { expect_equal(names(df)[1], "col3") expect_error(colnames(df) <- c("sepal.length", "sepal_width"), - "Colum names cannot contain the '.' symbol.") + "Column names cannot contain the '.' symbol.") expect_error(colnames(df) <- c(1, 2), "Invalid column names.") expect_error(colnames(df) <- c("a"), "Column names must have the same length as the number of columns in the dataset.") expect_error(colnames(df) <- c("1", NA), "Column names cannot be NA.") # Note: if this test is broken, remove check for "." character on colnames<- method - irisDF <- suppressWarnings(createDataFrame(sqlContext, iris)) + irisDF <- suppressWarnings(createDataFrame(iris)) expect_equal(names(irisDF)[1], "Sepal_Length") # Test base::colnames base::names @@ -715,7 +838,7 @@ test_that("names() colnames() set the column names", { }) test_that("head() and first() return the correct data", { - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) testHead <- head(df) expect_equal(nrow(testHead), 3) expect_equal(ncol(testHead), 2) @@ -748,7 +871,7 @@ test_that("distinct(), unique() and dropDuplicates() on DataFrames", { jsonPathWithDup <- tempfile(pattern = "sparkr-test", fileext = ".tmp") writeLines(lines, jsonPathWithDup) - df <- read.json(sqlContext, jsonPathWithDup) + df <- read.json(jsonPathWithDup) uniques <- distinct(df) expect_is(uniques, "SparkDataFrame") expect_equal(count(uniques), 3) @@ -759,7 +882,6 @@ test_that("distinct(), unique() and dropDuplicates() on DataFrames", { # Test dropDuplicates() df <- createDataFrame( - sqlContext, list( list(2, 1, 2), list(1, 1, 1), list(1, 2, 1), list(2, 1, 2), @@ -785,6 +907,14 @@ test_that("distinct(), unique() and dropDuplicates() on DataFrames", { result[order(result$key, result$value1, result$value2), ], expected) + result <- collect(dropDuplicates(df, "key", "value1")) + expected <- rbind.data.frame( + c(1, 1, 1), c(1, 2, 1), c(2, 1, 2), c(2, 2, 2)) + names(expected) <- c("key", "value1", "value2") + expect_equivalent( + result[order(result$key, result$value1, result$value2), ], + expected) + result <- collect(dropDuplicates(df, "key")) expected <- rbind.data.frame( c(1, 1, 1), c(2, 1, 2)) @@ -795,7 +925,7 @@ test_that("distinct(), unique() and dropDuplicates() on DataFrames", { }) test_that("sample on a DataFrame", { - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) sampled <- sample(df, FALSE, 1.0) expect_equal(nrow(collect(sampled)), count(df)) expect_is(sampled, "SparkDataFrame") @@ -817,7 +947,7 @@ test_that("sample on a DataFrame", { }) test_that("select operators", { - df <- select(read.json(sqlContext, jsonPath), "name", "age") + df <- 
select(read.json(jsonPath), "name", "age") expect_is(df$name, "Column") expect_is(df[[2]], "Column") expect_is(df[["age"]], "Column") @@ -846,7 +976,7 @@ test_that("select operators", { }) test_that("select with column", { - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) df1 <- select(df, "name") expect_equal(columns(df1), c("name")) expect_equal(count(df1), 3) @@ -869,7 +999,7 @@ test_that("select with column", { }) test_that("drop column", { - df <- select(read.json(sqlContext, jsonPath), "name", "age") + df <- select(read.json(jsonPath), "name", "age") df1 <- drop(df, "name") expect_equal(columns(df1), c("age")) @@ -891,7 +1021,7 @@ test_that("drop column", { test_that("subsetting", { # read.json returns columns in random order - df <- select(read.json(sqlContext, jsonPath), "name", "age") + df <- select(read.json(jsonPath), "name", "age") filtered <- df[df$age > 20, ] expect_equal(count(filtered), 1) expect_equal(columns(filtered), c("name", "age")) @@ -928,7 +1058,7 @@ test_that("subsetting", { }) test_that("selectExpr() on a DataFrame", { - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) selected <- selectExpr(df, "age * 2") expect_equal(names(selected), "(age * 2)") expect_equal(collect(selected), collect(select(df, df$age * 2L))) @@ -939,12 +1069,12 @@ test_that("selectExpr() on a DataFrame", { }) test_that("expr() on a DataFrame", { - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) expect_equal(collect(select(df, expr("abs(-123)")))[1, 1], 123) }) test_that("column calculation", { - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) d <- collect(select(df, alias(df$age + 1, "age2"))) expect_equal(names(d), c("age2")) df2 <- select(df, lower(df$name), abs(df$age)) @@ -953,40 +1083,35 @@ test_that("column calculation", { }) test_that("test HiveContext", { - ssc <- callJMethod(sc, "sc") - hiveCtx <- tryCatch({ - newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc) - }, - error = function(err) { - skip("Hive is not build with SparkSQL, skipped") - }) - df <- createExternalTable(hiveCtx, "json", jsonPath, "json") + setHiveContext(sc) + df <- createExternalTable("json", jsonPath, "json") expect_is(df, "SparkDataFrame") expect_equal(count(df), 3) - df2 <- sql(hiveCtx, "select * from json") + df2 <- sql("select * from json") expect_is(df2, "SparkDataFrame") expect_equal(count(df2), 3) jsonPath2 <- tempfile(pattern = "sparkr-test", fileext = ".tmp") invisible(saveAsTable(df, "json2", "json", "append", path = jsonPath2)) - df3 <- sql(hiveCtx, "select * from json2") + df3 <- sql("select * from json2") expect_is(df3, "SparkDataFrame") expect_equal(count(df3), 3) unlink(jsonPath2) hivetestDataPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") invisible(saveAsTable(df, "hivetestbl", path = hivetestDataPath)) - df4 <- sql(hiveCtx, "select * from hivetestbl") + df4 <- sql("select * from hivetestbl") expect_is(df4, "SparkDataFrame") expect_equal(count(df4), 3) unlink(hivetestDataPath) parquetDataPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") invisible(saveAsTable(df, "parquetest", "parquet", mode = "overwrite", path = parquetDataPath)) - df5 <- sql(hiveCtx, "select * from parquetest") + df5 <- sql("select * from parquetest") expect_is(df5, "SparkDataFrame") expect_equal(count(df5), 3) unlink(parquetDataPath) + unsetHiveContext() }) test_that("column operators", { @@ -1005,8 +1130,8 @@ test_that("column functions", { c4 <- explode(c) + expm1(c) + factorial(c) + first(c) + floor(c) + hex(c) c5 <- 
hour(c) + initcap(c) + last(c) + last_day(c) + length(c) c6 <- log(c) + (c) + log1p(c) + log2(c) + lower(c) + ltrim(c) + max(c) + md5(c) - c7 <- mean(c) + min(c) + month(c) + negate(c) + quarter(c) - c8 <- reverse(c) + rint(c) + round(c) + rtrim(c) + sha1(c) + c7 <- mean(c) + min(c) + month(c) + negate(c) + posexplode(c) + quarter(c) + c8 <- reverse(c) + rint(c) + round(c) + rtrim(c) + sha1(c) + monotonically_increasing_id() c9 <- signum(c) + sin(c) + sinh(c) + size(c) + stddev(c) + soundex(c) + sqrt(c) + sum(c) c10 <- sumDistinct(c) + tan(c) + tanh(c) + toDegrees(c) + toRadians(c) c11 <- to_date(c) + trim(c) + unbase64(c) + unhex(c) + upper(c) @@ -1017,6 +1142,7 @@ test_that("column functions", { c16 <- is.nan(c) + isnan(c) + isNaN(c) c17 <- cov(c, c1) + cov("c", "c1") + covar_samp(c, c1) + covar_samp("c", "c1") c18 <- covar_pop(c, c1) + covar_pop("c", "c1") + c19 <- spark_partition_id() # Test if base::is.nan() is exposed expect_equal(is.nan(c("a", "b")), c(FALSE, FALSE)) @@ -1025,7 +1151,7 @@ test_that("column functions", { expect_equal(class(rank())[[1]], "Column") expect_equal(rank(1:3), as.numeric(c(1:3))) - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) df2 <- select(df, between(df$age, c(20, 30)), between(df$age, c(10, 20))) expect_equal(collect(df2)[[2, 1]], TRUE) expect_equal(collect(df2)[[2, 2]], FALSE) @@ -1044,11 +1170,11 @@ test_that("column functions", { expect_true(abs(collect(select(df, stddev(df$age)))[1, 1] - 7.778175) < 1e-6) expect_equal(collect(select(df, var_pop(df$age)))[1, 1], 30.25) - df5 <- createDataFrame(sqlContext, list(list(a = "010101"))) + df5 <- createDataFrame(list(list(a = "010101"))) expect_equal(collect(select(df5, conv(df5$a, 2, 16)))[1, 1], "15") # Test array_contains() and sort_array() - df <- createDataFrame(sqlContext, list(list(list(1L, 2L, 3L)), list(list(6L, 5L, 4L)))) + df <- createDataFrame(list(list(list(1L, 2L, 3L)), list(list(6L, 5L, 4L)))) result <- collect(select(df, array_contains(df[[1]], 1L)))[[1]] expect_equal(result, c(TRUE, FALSE)) @@ -1061,8 +1187,7 @@ test_that("column functions", { expect_equal(length(lag(ldeaths, 12)), 72) # Test struct() - df <- createDataFrame(sqlContext, - list(list(1L, 2L, 3L), list(4L, 5L, 6L)), + df <- createDataFrame(list(list(1L, 2L, 3L), list(4L, 5L, 6L)), schema = c("a", "b", "c")) result <- collect(select(df, struct("a", "c"))) expected <- data.frame(row.names = 1:2) @@ -1078,15 +1203,14 @@ test_that("column functions", { # Test encode(), decode() bytes <- as.raw(c(0xe5, 0xa4, 0xa7, 0xe5, 0x8d, 0x83, 0xe4, 0xb8, 0x96, 0xe7, 0x95, 0x8c)) - df <- createDataFrame(sqlContext, - list(list(markUtf8("大千世界"), "utf-8", bytes)), + df <- createDataFrame(list(list(markUtf8("大千世界"), "utf-8", bytes)), schema = c("a", "b", "c")) result <- collect(select(df, encode(df$a, "utf-8"), decode(df$c, "utf-8"))) expect_equal(result[[1]][[1]], bytes) expect_equal(result[[2]], markUtf8("大千世界")) # Test first(), last() - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) expect_equal(collect(select(df, first(df$age)))[[1]], NA) expect_equal(collect(select(df, first(df$age, TRUE)))[[1]], 30) expect_equal(collect(select(df, first("age")))[[1]], NA) @@ -1097,7 +1221,7 @@ test_that("column functions", { expect_equal(collect(select(df, last("age", TRUE)))[[1]], 19) # Test bround() - df <- createDataFrame(sqlContext, data.frame(x = c(2.5, 3.5))) + df <- createDataFrame(data.frame(x = c(2.5, 3.5))) expect_equal(collect(select(df, bround(df$x, 0)))[[1]][1], 2) expect_equal(collect(select(df, 
bround(df$x, 0)))[[1]][2], 4) }) @@ -1109,7 +1233,7 @@ test_that("column binary mathfunctions", { "{\"a\":4, \"b\":8}") jsonPathWithDup <- tempfile(pattern = "sparkr-test", fileext = ".tmp") writeLines(lines, jsonPathWithDup) - df <- read.json(sqlContext, jsonPathWithDup) + df <- read.json(jsonPathWithDup) expect_equal(collect(select(df, atan2(df$a, df$b)))[1, "ATAN2(a, b)"], atan2(1, 5)) expect_equal(collect(select(df, atan2(df$a, df$b)))[2, "ATAN2(a, b)"], atan2(2, 6)) expect_equal(collect(select(df, atan2(df$a, df$b)))[3, "ATAN2(a, b)"], atan2(3, 7)) @@ -1130,10 +1254,17 @@ test_that("column binary mathfunctions", { }) test_that("string operators", { - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) expect_equal(count(where(df, like(df$name, "A%"))), 1) expect_equal(count(where(df, startsWith(df$name, "A"))), 1) + expect_true(first(select(df, startsWith(df$name, "M")))[[1]]) + expect_false(first(select(df, startsWith(df$name, "m")))[[1]]) + expect_true(first(select(df, endsWith(df$name, "el")))[[1]]) expect_equal(first(select(df, substr(df$name, 1, 2)))[[1]], "Mi") + if (as.numeric(R.version$major) >= 3 && as.numeric(R.version$minor) >= 3) { + expect_true(startsWith("Hello World", "Hello")) + expect_false(endsWith("Hello World", "a")) + } expect_equal(collect(select(df, cast(df$age, "string")))[[2, 1]], "30") expect_equal(collect(select(df, concat(df$name, lit(":"), df$age)))[[2, 1]], "Andy:30") expect_equal(collect(select(df, concat_ws(":", df$name)))[[2, 1]], "Andy") @@ -1150,14 +1281,14 @@ test_that("string operators", { expect_equal(collect(select(df, regexp_replace(df$name, "(n.y)", "ydn")))[2, 1], "Aydn") l2 <- list(list(a = "aaads")) - df2 <- createDataFrame(sqlContext, l2) + df2 <- createDataFrame(l2) expect_equal(collect(select(df2, locate("aa", df2$a)))[1, 1], 1) - expect_equal(collect(select(df2, locate("aa", df2$a, 1)))[1, 1], 2) + expect_equal(collect(select(df2, locate("aa", df2$a, 2)))[1, 1], 2) expect_equal(collect(select(df2, lpad(df2$a, 8, "#")))[1, 1], "###aaads") # nolint expect_equal(collect(select(df2, rpad(df2$a, 8, "#")))[1, 1], "aaads###") # nolint l3 <- list(list(a = "a.b.c.d")) - df3 <- createDataFrame(sqlContext, l3) + df3 <- createDataFrame(l3) expect_equal(collect(select(df3, substring_index(df3$a, ".", 2)))[1, 1], "a.b") expect_equal(collect(select(df3, substring_index(df3$a, ".", -3)))[1, 1], "b.c.d") expect_equal(collect(select(df3, translate(df3$a, "bc", "12")))[1, 1], "a.1.2.d") @@ -1169,7 +1300,7 @@ test_that("date functions on a DataFrame", { l <- list(list(a = 1L, b = as.Date("2012-12-13")), list(a = 2L, b = as.Date("2013-12-14")), list(a = 3L, b = as.Date("2014-12-15"))) - df <- createDataFrame(sqlContext, l) + df <- createDataFrame(l) expect_equal(collect(select(df, dayofmonth(df$b)))[, 1], c(13, 14, 15)) expect_equal(collect(select(df, dayofyear(df$b)))[, 1], c(348, 348, 349)) expect_equal(collect(select(df, weekofyear(df$b)))[, 1], c(50, 50, 51)) @@ -1189,19 +1320,19 @@ test_that("date functions on a DataFrame", { l2 <- list(list(a = 1L, b = as.POSIXlt("2012-12-13 12:34:00", tz = "UTC")), list(a = 2L, b = as.POSIXlt("2014-12-15 01:24:34", tz = "UTC"))) - df2 <- createDataFrame(sqlContext, l2) + df2 <- createDataFrame(l2) expect_equal(collect(select(df2, minute(df2$b)))[, 1], c(34, 24)) expect_equal(collect(select(df2, second(df2$b)))[, 1], c(0, 34)) expect_equal(collect(select(df2, from_utc_timestamp(df2$b, "JST")))[, 1], c(as.POSIXlt("2012-12-13 21:34:00 UTC"), as.POSIXlt("2014-12-15 10:24:34 UTC"))) 
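(Editorial aside, not part of the patch: the hunk above exercises SparkR's timezone helpers inside the test harness. The sketch below restates the same behavior as a standalone snippet for readers skimming the diff; it assumes a Spark 2.0-style session started with `sparkR.session()`, and the data frame `ts_df` and column name `b` are illustrative only.)

```R
library(SparkR)
sparkR.session()  # assumes a local Spark installation; illustrative only

# A one-row DataFrame with a UTC timestamp column, mirroring the test data above.
ts_df <- createDataFrame(data.frame(b = as.POSIXct("2012-12-13 12:34:00", tz = "UTC")))

# from_utc_timestamp() reinterprets the UTC instant in the given zone (here JST, UTC+9),
# while to_utc_timestamp() performs the inverse conversion back to UTC.
collect(select(ts_df, from_utc_timestamp(ts_df$b, "JST"),
                      to_utc_timestamp(ts_df$b, "JST")))
```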
expect_equal(collect(select(df2, to_utc_timestamp(df2$b, "JST")))[, 1], c(as.POSIXlt("2012-12-13 03:34:00 UTC"), as.POSIXlt("2014-12-14 16:24:34 UTC"))) - expect_more_than(collect(select(df2, unix_timestamp()))[1, 1], 0) - expect_more_than(collect(select(df2, unix_timestamp(df2$b)))[1, 1], 0) - expect_more_than(collect(select(df2, unix_timestamp(lit("2015-01-01"), "yyyy-MM-dd")))[1, 1], 0) + expect_gt(collect(select(df2, unix_timestamp()))[1, 1], 0) + expect_gt(collect(select(df2, unix_timestamp(df2$b)))[1, 1], 0) + expect_gt(collect(select(df2, unix_timestamp(lit("2015-01-01"), "yyyy-MM-dd")))[1, 1], 0) l3 <- list(list(a = 1000), list(a = -1000)) - df3 <- createDataFrame(sqlContext, l3) + df3 <- createDataFrame(l3) result31 <- collect(select(df3, from_unixtime(df3$a))) expect_equal(grep("\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}", result31[, 1], perl = TRUE), c(1, 2)) @@ -1212,13 +1343,13 @@ test_that("date functions on a DataFrame", { test_that("greatest() and least() on a DataFrame", { l <- list(list(a = 1, b = 2), list(a = 3, b = 4)) - df <- createDataFrame(sqlContext, l) + df <- createDataFrame(l) expect_equal(collect(select(df, greatest(df$a, df$b)))[, 1], c(2, 4)) expect_equal(collect(select(df, least(df$a, df$b)))[, 1], c(1, 3)) }) test_that("time windowing (window()) with all inputs", { - df <- createDataFrame(sqlContext, data.frame(t = c("2016-03-11 09:00:07"), v = c(1))) + df <- createDataFrame(data.frame(t = c("2016-03-11 09:00:07"), v = c(1))) df$window <- window(df$t, "5 seconds", "5 seconds", "0 seconds") local <- collect(df)$v # Not checking time windows because of possible time zone issues. Just checking that the function @@ -1227,7 +1358,7 @@ test_that("time windowing (window()) with all inputs", { }) test_that("time windowing (window()) with slide duration", { - df <- createDataFrame(sqlContext, data.frame(t = c("2016-03-11 09:00:07"), v = c(1))) + df <- createDataFrame(data.frame(t = c("2016-03-11 09:00:07"), v = c(1))) df$window <- window(df$t, "5 seconds", "2 seconds") local <- collect(df)$v # Not checking time windows because of possible time zone issues. Just checking that the function @@ -1236,7 +1367,7 @@ test_that("time windowing (window()) with slide duration", { }) test_that("time windowing (window()) with start time", { - df <- createDataFrame(sqlContext, data.frame(t = c("2016-03-11 09:00:07"), v = c(1))) + df <- createDataFrame(data.frame(t = c("2016-03-11 09:00:07"), v = c(1))) df$window <- window(df$t, "5 seconds", startTime = "2 seconds") local <- collect(df)$v # Not checking time windows because of possible time zone issues. Just checking that the function @@ -1245,7 +1376,7 @@ test_that("time windowing (window()) with start time", { }) test_that("time windowing (window()) with just window duration", { - df <- createDataFrame(sqlContext, data.frame(t = c("2016-03-11 09:00:07"), v = c(1))) + df <- createDataFrame(data.frame(t = c("2016-03-11 09:00:07"), v = c(1))) df$window <- window(df$t, "5 seconds") local <- collect(df)$v # Not checking time windows because of possible time zone issues. 
Just checking that the function @@ -1255,7 +1386,7 @@ test_that("time windowing (window()) with just window duration", { test_that("when(), otherwise() and ifelse() on a DataFrame", { l <- list(list(a = 1, b = 2), list(a = 3, b = 4)) - df <- createDataFrame(sqlContext, l) + df <- createDataFrame(l) expect_equal(collect(select(df, when(df$a > 1 & df$b > 2, 1)))[, 1], c(NA, 1)) expect_equal(collect(select(df, otherwise(when(df$a > 1, 1), 0)))[, 1], c(0, 1)) expect_equal(collect(select(df, ifelse(df$a > 1 & df$b > 2, 0, 1)))[, 1], c(1, 0)) @@ -1263,14 +1394,14 @@ test_that("when(), otherwise() and ifelse() on a DataFrame", { test_that("when(), otherwise() and ifelse() with column on a DataFrame", { l <- list(list(a = 1, b = 2), list(a = 3, b = 4)) - df <- createDataFrame(sqlContext, l) + df <- createDataFrame(l) expect_equal(collect(select(df, when(df$a > 1 & df$b > 2, lit(1))))[, 1], c(NA, 1)) expect_equal(collect(select(df, otherwise(when(df$a > 1, lit(1)), lit(0))))[, 1], c(0, 1)) expect_equal(collect(select(df, ifelse(df$a > 1 & df$b > 2, lit(0), lit(1))))[, 1], c(1, 0)) }) test_that("group by, agg functions", { - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) df1 <- agg(df, name = "max", age = "sum") expect_equal(1, count(df1)) df1 <- agg(df, age2 = max(df$age)) @@ -1315,7 +1446,7 @@ test_that("group by, agg functions", { "{\"name\":\"ID2\", \"value\": \"-3\"}") jsonPath2 <- tempfile(pattern = "sparkr-test", fileext = ".tmp") writeLines(mockLines2, jsonPath2) - gd2 <- groupBy(read.json(sqlContext, jsonPath2), "name") + gd2 <- groupBy(read.json(jsonPath2), "name") df6 <- agg(gd2, value = "sum") df6_local <- collect(df6) expect_equal(42, df6_local[df6_local$name == "ID1", ][1, 2]) @@ -1332,7 +1463,7 @@ test_that("group by, agg functions", { "{\"name\":\"Justin\", \"age\":1}") jsonPath3 <- tempfile(pattern = "sparkr-test", fileext = ".tmp") writeLines(mockLines3, jsonPath3) - df8 <- read.json(sqlContext, jsonPath3) + df8 <- read.json(jsonPath3) gd3 <- groupBy(df8, "name") gd3_local <- collect(sum(gd3)) expect_equal(60, gd3_local[gd3_local$name == "Andy", ][1, 2]) @@ -1350,8 +1481,33 @@ test_that("group by, agg functions", { unlink(jsonPath3) }) +test_that("pivot GroupedData column", { + df <- createDataFrame(data.frame( + earnings = c(10000, 10000, 11000, 15000, 12000, 20000, 21000, 22000), + course = c("R", "Python", "R", "Python", "R", "Python", "R", "Python"), + year = c(2013, 2013, 2014, 2014, 2015, 2015, 2016, 2016) + )) + sum1 <- collect(sum(pivot(groupBy(df, "year"), "course"), "earnings")) + sum2 <- collect(sum(pivot(groupBy(df, "year"), "course", c("Python", "R")), "earnings")) + sum3 <- collect(sum(pivot(groupBy(df, "year"), "course", list("Python", "R")), "earnings")) + sum4 <- collect(sum(pivot(groupBy(df, "year"), "course", "R"), "earnings")) + + correct_answer <- data.frame( + year = c(2013, 2014, 2015, 2016), + Python = c(10000, 15000, 20000, 22000), + R = c(10000, 11000, 12000, 21000) + ) + expect_equal(sum1, correct_answer) + expect_equal(sum2, correct_answer) + expect_equal(sum3, correct_answer) + expect_equal(sum4, correct_answer[, c("year", "R")]) + + expect_error(collect(sum(pivot(groupBy(df, "year"), "course", c("R", "R")), "earnings"))) + expect_error(collect(sum(pivot(groupBy(df, "year"), "course", list("R", "R")), "earnings"))) +}) + test_that("arrange() and orderBy() on a DataFrame", { - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) sorted <- arrange(df, df$age) expect_equal(collect(sorted)[1, 2], "Michael") @@ -1377,7 +1533,7 
@@ test_that("arrange() and orderBy() on a DataFrame", { }) test_that("filter() on a DataFrame", { - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) filtered <- filter(df, "age > 20") expect_equal(count(filtered), 1) expect_equal(collect(filtered)$name, "Andy") @@ -1400,7 +1556,7 @@ test_that("filter() on a DataFrame", { }) test_that("join() and merge() on a DataFrame", { - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) mockLines2 <- c("{\"name\":\"Michael\", \"test\": \"yes\"}", "{\"name\":\"Andy\", \"test\": \"no\"}", @@ -1408,7 +1564,7 @@ test_that("join() and merge() on a DataFrame", { "{\"name\":\"Bob\", \"test\": \"yes\"}") jsonPath2 <- tempfile(pattern = "sparkr-test", fileext = ".tmp") writeLines(mockLines2, jsonPath2) - df2 <- read.json(sqlContext, jsonPath2) + df2 <- read.json(jsonPath2) joined <- join(df, df2) expect_equal(names(joined), c("age", "name", "name", "test")) @@ -1483,7 +1639,7 @@ test_that("join() and merge() on a DataFrame", { "{\"name\":\"Bob\", \"name_y\":\"Bob\", \"test\": \"yes\"}") jsonPath3 <- tempfile(pattern = "sparkr-test", fileext = ".tmp") writeLines(mockLines3, jsonPath3) - df3 <- read.json(sqlContext, jsonPath3) + df3 <- read.json(jsonPath3) expect_error(merge(df, df3), paste("The following column name: name_y occurs more than once in the 'DataFrame'.", "Please use different suffixes for the intersected columns.", sep = "")) @@ -1493,16 +1649,15 @@ test_that("join() and merge() on a DataFrame", { }) test_that("toJSON() returns an RDD of the correct values", { - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) testRDD <- toJSON(df) expect_is(testRDD, "RDD") expect_equal(getSerializedMode(testRDD), "string") - expect_equal(collect(testRDD)[[1]], mockLines[1]) + expect_equal(collectRDD(testRDD)[[1]], mockLines[1]) }) test_that("showDF()", { - df <- read.json(sqlContext, jsonPath) - s <- capture.output(showDF(df)) + df <- read.json(jsonPath) expected <- paste("+----+-------+\n", "| age| name|\n", "+----+-------+\n", @@ -1510,28 +1665,29 @@ test_that("showDF()", { "| 30| Andy|\n", "| 19| Justin|\n", "+----+-------+\n", sep = "") - expect_output(s, expected) + expect_output(showDF(df), expected) }) test_that("isLocal()", { - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) expect_false(isLocal(df)) }) -test_that("unionAll(), rbind(), except(), and intersect() on a DataFrame", { - df <- read.json(sqlContext, jsonPath) +test_that("union(), rbind(), except(), and intersect() on a DataFrame", { + df <- read.json(jsonPath) lines <- c("{\"name\":\"Bob\", \"age\":24}", "{\"name\":\"Andy\", \"age\":30}", "{\"name\":\"James\", \"age\":35}") jsonPath2 <- tempfile(pattern = "sparkr-test", fileext = ".tmp") writeLines(lines, jsonPath2) - df2 <- read.df(sqlContext, jsonPath2, "json") + df2 <- read.df(jsonPath2, "json") - unioned <- arrange(unionAll(df, df2), df$age) + unioned <- arrange(union(df, df2), df$age) expect_is(unioned, "SparkDataFrame") expect_equal(count(unioned), 6) expect_equal(first(unioned)$name, "Michael") + expect_equal(count(arrange(suppressWarnings(unionAll(df, df2)), df$age)), 6) unioned2 <- arrange(rbind(unioned, df, df2), df$age) expect_is(unioned2, "SparkDataFrame") @@ -1548,6 +1704,9 @@ test_that("unionAll(), rbind(), except(), and intersect() on a DataFrame", { expect_equal(count(intersected), 1) expect_equal(first(intersected)$name, "Andy") + # Test base::union is working + expect_equal(union(c(1:3), c(3:5)), c(1:5)) + # Test base::rbind is working 
expect_equal(length(rbind(1:4, c = 2, a = 10, 10, deparse.level = 0)), 16) @@ -1558,7 +1717,7 @@ test_that("unionAll(), rbind(), except(), and intersect() on a DataFrame", { }) test_that("withColumn() and withColumnRenamed()", { - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) newDF <- withColumn(df, "newAge", df$age + 2) expect_equal(length(columns(newDF)), 3) expect_equal(columns(newDF)[3], "newAge") @@ -1575,7 +1734,7 @@ test_that("withColumn() and withColumnRenamed()", { }) test_that("mutate(), transform(), rename() and names()", { - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) newDF <- mutate(df, newAge = df$age + 2) expect_equal(length(columns(newDF)), 3) expect_equal(columns(newDF)[3], "newAge") @@ -1622,11 +1781,32 @@ test_that("mutate(), transform(), rename() and names()", { detach(airquality) }) +test_that("read/write ORC files", { + setHiveContext(sc) + df <- read.df(jsonPath, "json") + + # Test write.df and read.df + write.df(df, orcPath, "orc", mode = "overwrite") + df2 <- read.df(orcPath, "orc") + expect_is(df2, "SparkDataFrame") + expect_equal(count(df), count(df2)) + + # Test write.orc and read.orc + orcPath2 <- tempfile(pattern = "orcPath2", fileext = ".orc") + write.orc(df, orcPath2) + orcDF <- read.orc(orcPath2) + expect_is(orcDF, "SparkDataFrame") + expect_equal(count(orcDF), count(df)) + + unlink(orcPath2) + unsetHiveContext() +}) + test_that("read/write Parquet files", { - df <- read.df(sqlContext, jsonPath, "json") + df <- read.df(jsonPath, "json") # Test write.df and read.df write.df(df, parquetPath, "parquet", mode = "overwrite") - df2 <- read.df(sqlContext, parquetPath, "parquet") + df2 <- read.df(parquetPath, "parquet") expect_is(df2, "SparkDataFrame") expect_equal(count(df2), 3) @@ -1635,10 +1815,10 @@ test_that("read/write Parquet files", { write.parquet(df, parquetPath2) parquetPath3 <- tempfile(pattern = "parquetPath3", fileext = ".parquet") suppressWarnings(saveAsParquetFile(df, parquetPath3)) - parquetDF <- read.parquet(sqlContext, c(parquetPath2, parquetPath3)) + parquetDF <- read.parquet(c(parquetPath2, parquetPath3)) expect_is(parquetDF, "SparkDataFrame") expect_equal(count(parquetDF), count(df) * 2) - parquetDF2 <- suppressWarnings(parquetFile(sqlContext, parquetPath2, parquetPath3)) + parquetDF2 <- suppressWarnings(parquetFile(parquetPath2, parquetPath3)) expect_is(parquetDF2, "SparkDataFrame") expect_equal(count(parquetDF2), count(df) * 2) @@ -1655,7 +1835,7 @@ test_that("read/write Parquet files", { test_that("read/write text files", { # Test write.df and read.df - df <- read.df(sqlContext, jsonPath, "text") + df <- read.df(jsonPath, "text") expect_is(df, "SparkDataFrame") expect_equal(colnames(df), c("value")) expect_equal(count(df), 3) @@ -1665,7 +1845,7 @@ test_that("read/write text files", { # Test write.text and read.text textPath2 <- tempfile(pattern = "textPath2", fileext = ".txt") write.text(df, textPath2) - df2 <- read.text(sqlContext, c(textPath, textPath2)) + df2 <- read.text(c(textPath, textPath2)) expect_is(df2, "SparkDataFrame") expect_equal(colnames(df2), c("value")) expect_equal(count(df2), count(df) * 2) @@ -1675,25 +1855,29 @@ test_that("read/write text files", { }) test_that("describe() and summarize() on a DataFrame", { - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) stats <- describe(df, "age") expect_equal(collect(stats)[1, "summary"], "count") expect_equal(collect(stats)[2, "age"], "24.5") expect_equal(collect(stats)[3, "age"], "7.7781745930520225") stats <- 
describe(df) - expect_equal(collect(stats)[4, "name"], "Andy") + expect_equal(collect(stats)[4, "name"], NULL) expect_equal(collect(stats)[5, "age"], "30") stats2 <- summary(df) - expect_equal(collect(stats2)[4, "name"], "Andy") + expect_equal(collect(stats2)[4, "name"], NULL) expect_equal(collect(stats2)[5, "age"], "30") + # SPARK-16425: SparkR summary() fails on column of type logical + df <- withColumn(df, "boolean", df$age == 30) + summary(df) + # Test base::summary is working expect_equal(length(summary(attenu, digits = 4)), 35) }) test_that("dropna() and na.omit() on a DataFrame", { - df <- read.json(sqlContext, jsonPathNa) + df <- read.json(jsonPathNa) rows <- collect(df) # drop with columns @@ -1779,7 +1963,7 @@ test_that("dropna() and na.omit() on a DataFrame", { }) test_that("fillna() on a DataFrame", { - df <- read.json(sqlContext, jsonPathNa) + df <- read.json(jsonPathNa) rows <- collect(df) # fill with value @@ -1830,7 +2014,7 @@ test_that("crosstab() on a DataFrame", { test_that("cov() and corr() on a DataFrame", { l <- lapply(c(0:9), function(x) { list(x, x * 2.0) }) - df <- createDataFrame(sqlContext, l, c("singles", "doubles")) + df <- createDataFrame(l, c("singles", "doubles")) result <- cov(df, "singles", "doubles") expect_true(abs(result - 55.0 / 3) < 1e-12) @@ -1848,7 +2032,7 @@ test_that("freqItems() on a DataFrame", { rdf <- data.frame(numbers = input, letters = as.character(input), negDoubles = input * -1.0, stringsAsFactors = F) rdf[ input %% 3 == 0, ] <- c(1, "1", -1) - df <- createDataFrame(sqlContext, rdf) + df <- createDataFrame(rdf) multiColResults <- freqItems(df, c("numbers", "letters"), support = 0.1) expect_true(1 %in% multiColResults$numbers[[1]]) expect_true("1" %in% multiColResults$letters[[1]]) @@ -1858,7 +2042,7 @@ test_that("freqItems() on a DataFrame", { l <- lapply(c(0:99), function(i) { if (i %% 2 == 0) { list(1L, -1.0) } else { list(i, i * -1.0) }}) - df <- createDataFrame(sqlContext, l, c("a", "b")) + df <- createDataFrame(l, c("a", "b")) result <- freqItems(df, c("a", "b"), 0.4) expect_identical(result[[1]], list(list(1L, 99L))) expect_identical(result[[2]], list(list(-1, -99))) @@ -1866,7 +2050,7 @@ test_that("freqItems() on a DataFrame", { test_that("sampleBy() on a DataFrame", { l <- lapply(c(0:99), function(i) { as.character(i %% 3) }) - df <- createDataFrame(sqlContext, l, "key") + df <- createDataFrame(l, "key") fractions <- list("0" = 0.1, "1" = 0.2) sample <- sampleBy(df, "key", fractions, 0) result <- collect(orderBy(count(groupBy(sample, "key")), "key")) @@ -1876,19 +2060,19 @@ test_that("sampleBy() on a DataFrame", { test_that("approxQuantile() on a DataFrame", { l <- lapply(c(0:99), function(i) { i }) - df <- createDataFrame(sqlContext, l, "key") + df <- createDataFrame(l, "key") quantiles <- approxQuantile(df, "key", c(0.5, 0.8), 0.0) expect_equal(quantiles[[1]], 50) expect_equal(quantiles[[2]], 80) }) test_that("SQL error message is returned from JVM", { - retError <- tryCatch(sql(sqlContext, "select * from blah"), error = function(e) e) + retError <- tryCatch(sql("select * from blah"), error = function(e) e) expect_equal(grepl("Table or view not found", retError), TRUE) expect_equal(grepl("blah", retError), TRUE) }) -irisDF <- suppressWarnings(createDataFrame(sqlContext, iris)) +irisDF <- suppressWarnings(createDataFrame(iris)) test_that("Method as.data.frame as a synonym for collect()", { expect_equal(as.data.frame(irisDF), collect(irisDF)) @@ -1896,11 +2080,11 @@ test_that("Method as.data.frame as a synonym for collect()", { 
expect_equal(as.data.frame(irisDF2), collect(irisDF2)) # Make sure as.data.frame in the R base package is not covered - expect_that(as.data.frame(c(1, 2)), not(throws_error())) + expect_error(as.data.frame(c(1, 2)), NA) }) test_that("attach() on a DataFrame", { - df <- read.json(sqlContext, jsonPath) + df <- read.json(jsonPath) expect_error(age) attach(df) expect_is(age, "SparkDataFrame") @@ -1920,7 +2104,7 @@ test_that("attach() on a DataFrame", { }) test_that("with() on a DataFrame", { - df <- suppressWarnings(createDataFrame(sqlContext, iris)) + df <- suppressWarnings(createDataFrame(iris)) expect_error(Sepal_Length) sum1 <- with(df, list(summary(Sepal_Length), summary(Sepal_Width))) expect_equal(collect(sum1[[1]])[1, "Sepal_Length"], "150") @@ -1940,15 +2124,18 @@ test_that("Method coltypes() to get and set R's data types of a DataFrame", { structField("c4", "timestamp")) # Test primitive types - DF <- createDataFrame(sqlContext, data, schema) + DF <- createDataFrame(data, schema) expect_equal(coltypes(DF), c("integer", "logical", "POSIXct")) + createOrReplaceTempView(DF, "DFView") + sqlCast <- sql("select cast('2' as decimal) as x from DFView limit 1") + expect_equal(coltypes(sqlCast), "numeric") # Test complex types - x <- createDataFrame(sqlContext, list(list(as.environment( + x <- createDataFrame(list(list(as.environment( list("a" = "b", "c" = "d", "e" = "f"))))) expect_equal(coltypes(x), "map") - df <- selectExpr(read.json(sqlContext, jsonPath), "name", "(age * 1.21) as age") + df <- selectExpr(read.json(jsonPath), "name", "(age * 1.21) as age") expect_equal(dtypes(df), list(c("name", "string"), c("age", "decimal(24,2)"))) df1 <- select(df, cast(df$age, "integer")) @@ -1972,7 +2159,7 @@ test_that("Method str()", { iris2 <- iris colnames(iris2) <- c("Sepal_Length", "Sepal_Width", "Petal_Length", "Petal_Width", "Species") iris2$col <- TRUE - irisDF2 <- createDataFrame(sqlContext, iris2) + irisDF2 <- createDataFrame(iris2) out <- capture.output(str(irisDF2)) expect_equal(length(out), 7) @@ -1985,12 +2172,20 @@ test_that("Method str()", { "setosa\" \"setosa\" \"setosa\" \"setosa\"")) expect_equal(out[7], " $ col : logi TRUE TRUE TRUE TRUE TRUE TRUE") + createOrReplaceTempView(irisDF2, "irisView") + + sqlCast <- sql("select cast('2' as decimal) as x from irisView limit 1") + castStr <- capture.output(str(sqlCast)) + expect_equal(length(castStr), 2) + expect_equal(castStr[1], "'SparkDataFrame': 1 variables:") + expect_equal(castStr[2], " $ x: num 2") + # A random dataset with many columns. This test is to check str limits # the number of columns. 
Therefore, it will suffice to check for the # number of returned rows x <- runif(200, 1, 10) df <- data.frame(t(as.matrix(data.frame(x, x, x, x, x, x, x, x, x)))) - DF <- createDataFrame(sqlContext, df) + DF <- createDataFrame(df) out <- capture.output(str(DF)) expect_equal(length(out), 103) @@ -2040,13 +2235,12 @@ test_that("Histogram", { histogram(irisDF, "Sepal_Width", 12)$counts), T) # Test when there are zero counts - df <- as.DataFrame(sqlContext, data.frame(x = c(1, 2, 3, 4, 100))) + df <- as.DataFrame(data.frame(x = c(1, 2, 3, 4, 100))) expect_equal(histogram(df, "x")$counts, c(4, 0, 0, 0, 0, 0, 0, 0, 0, 1)) }) -test_that("dapply() on a DataFrame", { - df <- createDataFrame ( - sqlContext, +test_that("dapply() and dapplyCollect() on a DataFrame", { + df <- createDataFrame( list(list(1L, 1, "1"), list(2L, 2, "2"), list(3L, 3, "3")), c("a", "b", "c")) ldf <- collect(df) @@ -2054,6 +2248,8 @@ test_that("dapply() on a DataFrame", { result <- collect(df1) expect_identical(ldf, result) + result <- dapplyCollect(df, function(x) { x }) + expect_identical(ldf, result) # Filter and add a column schema <- structType(structField("a", "integer"), structField("b", "double"), @@ -2071,6 +2267,16 @@ test_that("dapply() on a DataFrame", { rownames(expected) <- NULL expect_identical(expected, result) + result <- dapplyCollect( + df, + function(x) { + y <- x[x$a > 1, ] + y <- cbind(y, y$a + 1L) + }) + expected1 <- expected + names(expected1) <- names(result) + expect_identical(expected1, result) + # Remove the added column df2 <- dapply( df1, @@ -2081,8 +2287,280 @@ test_that("dapply() on a DataFrame", { result <- collect(df2) expected <- expected[, c("a", "b", "c")] expect_identical(expected, result) + + result <- dapplyCollect( + df1, + function(x) { + x[, c("a", "b", "c")] + }) + expect_identical(expected, result) +}) + +test_that("dapplyCollect() on DataFrame with a binary column", { + + df <- data.frame(key = 1:3) + df$bytes <- lapply(df$key, serialize, connection = NULL) + + df_spark <- createDataFrame(df) + + result1 <- collect(df_spark) + expect_identical(df, result1) + + result2 <- dapplyCollect(df_spark, function(x) x) + expect_identical(df, result2) + + # A data.frame with a single column of bytes + scb <- subset(df, select = "bytes") + scb_spark <- createDataFrame(scb) + result <- dapplyCollect(scb_spark, function(x) x) + expect_identical(scb, result) + +}) + +test_that("repartition by columns on DataFrame", { + df <- createDataFrame( + list(list(1L, 1, "1", 0.1), list(1L, 2, "2", 0.2), list(3L, 3, "3", 0.3)), + c("a", "b", "c", "d")) + + # no column and number of partitions specified + retError <- tryCatch(repartition(df), error = function(e) e) + expect_equal(grepl + ("Please, specify the number of partitions and/or a column\\(s\\)", retError), TRUE) + + # repartition by column and number of partitions + actual <- repartition(df, 3L, col = df$"a") + + # since we cannot access the number of partitions from dataframe, checking + # that at least the dimensions are identical + expect_identical(dim(df), dim(actual)) + + # repartition by number of partitions + actual <- repartition(df, 13L) + expect_identical(dim(df), dim(actual)) + + # a test case with a column and dapply + schema <- structType(structField("a", "integer"), structField("avg", "double")) + df <- repartition(df, col = df$"a") + df1 <- dapply( + df, + function(x) { + y <- (data.frame(x$a[1], mean(x$b))) + }, + schema) + + # Number of partitions is equal to 2 + expect_equal(nrow(df1), 2) +}) + +test_that("gapply() and 
gapplyCollect() on a DataFrame", { + df <- createDataFrame ( + list(list(1L, 1, "1", 0.1), list(1L, 2, "1", 0.2), list(3L, 3, "3", 0.3)), + c("a", "b", "c", "d")) + expected <- collect(df) + df1 <- gapply(df, "a", function(key, x) { x }, schema(df)) + actual <- collect(df1) + expect_identical(actual, expected) + + df1Collect <- gapplyCollect(df, list("a"), function(key, x) { x }) + expect_identical(df1Collect, expected) + + # Computes the sum of second column by grouping on the first and third columns + # and checks if the sum is larger than 2 + schema <- structType(structField("a", "integer"), structField("e", "boolean")) + df2 <- gapply( + df, + c(df$"a", df$"c"), + function(key, x) { + y <- data.frame(key[1], sum(x$b) > 2) + }, + schema) + actual <- collect(df2)$e + expected <- c(TRUE, TRUE) + expect_identical(actual, expected) + + df2Collect <- gapplyCollect( + df, + c(df$"a", df$"c"), + function(key, x) { + y <- data.frame(key[1], sum(x$b) > 2) + colnames(y) <- c("a", "e") + y + }) + actual <- df2Collect$e + expect_identical(actual, expected) + + # Computes the arithmetic mean of the second column by grouping + # on the first and third columns. Output the groupping value and the average. + schema <- structType(structField("a", "integer"), structField("c", "string"), + structField("avg", "double")) + df3 <- gapply( + df, + c("a", "c"), + function(key, x) { + y <- data.frame(key, mean(x$b), stringsAsFactors = FALSE) + }, + schema) + actual <- collect(df3) + actual <- actual[order(actual$a), ] + rownames(actual) <- NULL + expected <- collect(select(df, "a", "b", "c")) + expected <- data.frame(aggregate(expected$b, by = list(expected$a, expected$c), FUN = mean)) + colnames(expected) <- c("a", "c", "avg") + expected <- expected[order(expected$a), ] + rownames(expected) <- NULL + expect_identical(actual, expected) + + df3Collect <- gapplyCollect( + df, + c("a", "c"), + function(key, x) { + y <- data.frame(key, mean(x$b), stringsAsFactors = FALSE) + colnames(y) <- c("a", "c", "avg") + y + }) + actual <- df3Collect[order(df3Collect$a), ] + expect_identical(actual$avg, expected$avg) + + irisDF <- suppressWarnings(createDataFrame (iris)) + schema <- structType(structField("Sepal_Length", "double"), structField("Avg", "double")) + # Groups by `Sepal_Length` and computes the average for `Sepal_Width` + df4 <- gapply( + cols = "Sepal_Length", + irisDF, + function(key, x) { + y <- data.frame(key, mean(x$Sepal_Width), stringsAsFactors = FALSE) + }, + schema) + actual <- collect(df4) + actual <- actual[order(actual$Sepal_Length), ] + rownames(actual) <- NULL + agg_local_df <- data.frame(aggregate(iris$Sepal.Width, by = list(iris$Sepal.Length), FUN = mean), + stringsAsFactors = FALSE) + colnames(agg_local_df) <- c("Sepal_Length", "Avg") + expected <- agg_local_df[order(agg_local_df$Sepal_Length), ] + rownames(expected) <- NULL + expect_identical(actual, expected) +}) + +test_that("Window functions on a DataFrame", { + df <- createDataFrame(list(list(1L, "1"), list(2L, "2"), list(1L, "1"), list(2L, "2")), + schema = c("key", "value")) + ws <- orderBy(windowPartitionBy("key"), "value") + result <- collect(select(df, over(lead("key", 1), ws), over(lead("value", 1), ws))) + names(result) <- c("key", "value") + expected <- data.frame(key = c(1L, NA, 2L, NA), + value = c("1", NA, "2", NA), + stringsAsFactors = FALSE) + expect_equal(result, expected) + + ws <- orderBy(windowPartitionBy(df$key), df$value) + result <- collect(select(df, over(lead("key", 1), ws), over(lead("value", 1), ws))) + names(result) <- 
c("key", "value") + expect_equal(result, expected) + + ws <- partitionBy(windowOrderBy("value"), "key") + result <- collect(select(df, over(lead("key", 1), ws), over(lead("value", 1), ws))) + names(result) <- c("key", "value") + expect_equal(result, expected) + + ws <- partitionBy(windowOrderBy(df$value), df$key) + result <- collect(select(df, over(lead("key", 1), ws), over(lead("value", 1), ws))) + names(result) <- c("key", "value") + expect_equal(result, expected) +}) + +test_that("createDataFrame sqlContext parameter backward compatibility", { + sqlContext <- suppressWarnings(sparkRSQL.init(sc)) + a <- 1:3 + b <- c("a", "b", "c") + ldf <- data.frame(a, b) + # Call function with namespace :: operator - SPARK-16538 + df <- suppressWarnings(SparkR::createDataFrame(sqlContext, ldf)) + expect_equal(columns(df), c("a", "b")) + expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) + expect_equal(count(df), 3) + ldf2 <- collect(df) + expect_equal(ldf$a, ldf2$a) + + df2 <- suppressWarnings(createDataFrame(sqlContext, iris)) + expect_equal(count(df2), 150) + expect_equal(ncol(df2), 5) + + df3 <- suppressWarnings(read.df(sqlContext, jsonPath, "json")) + expect_is(df3, "SparkDataFrame") + expect_equal(count(df3), 3) + + before <- suppressWarnings(createDataFrame(sqlContext, iris)) + after <- suppressWarnings(createDataFrame(iris)) + expect_equal(collect(before), collect(after)) + + # more tests for SPARK-16538 + createOrReplaceTempView(df, "table") + SparkR::tables() + SparkR::sql("SELECT 1") + suppressWarnings(SparkR::sql(sqlContext, "SELECT * FROM table")) + suppressWarnings(SparkR::dropTempTable(sqlContext, "table")) +}) + +test_that("randomSplit", { + num <- 4000 + df <- createDataFrame(data.frame(id = 1:num)) + weights <- c(2, 3, 5) + df_list <- randomSplit(df, weights) + expect_equal(length(weights), length(df_list)) + counts <- sapply(df_list, count) + expect_equal(num, sum(counts)) + expect_true(all(sapply(abs(counts / num - weights / sum(weights)), function(e) { e < 0.05 }))) + + df_list <- randomSplit(df, weights, 0) + expect_equal(length(weights), length(df_list)) + counts <- sapply(df_list, count) + expect_equal(num, sum(counts)) + expect_true(all(sapply(abs(counts / num - weights / sum(weights)), function(e) { e < 0.05 }))) +}) + +test_that("Setting and getting config on SparkSession", { + # first, set it to a random but known value + conf <- callJMethod(sparkSession, "conf") + property <- paste0("spark.testing.", as.character(runif(1))) + value1 <- as.character(runif(1)) + callJMethod(conf, "set", property, value1) + + # next, change the same property to the new value + value2 <- as.character(runif(1)) + l <- list(value2) + names(l) <- property + sparkR.session(sparkConfig = l) + + newValue <- unlist(sparkR.conf(property, ""), use.names = FALSE) + expect_equal(value2, newValue) + + value <- as.character(runif(1)) + sparkR.session(spark.app.name = "sparkSession test", spark.testing.r.session.r = value) + allconf <- sparkR.conf() + appNameValue <- allconf[["spark.app.name"]] + testValue <- allconf[["spark.testing.r.session.r"]] + expect_equal(appNameValue, "sparkSession test") + expect_equal(testValue, value) + expect_error(sparkR.conf("completely.dummy"), "Config 'completely.dummy' is not set") +}) + +test_that("enableHiveSupport on SparkSession", { + setHiveContext(sc) + unsetHiveContext() + # if we are still here, it must be built with hive + conf <- callJMethod(sparkSession, "conf") + value <- callJMethod(conf, "get", "spark.sql.catalogImplementation", "") + 
expect_equal(value, "hive") +}) + +test_that("Spark version from SparkSession", { + ver <- callJMethod(sc, "version") + version <- sparkR.version() + expect_equal(ver, version) }) unlink(parquetPath) +unlink(orcPath) unlink(jsonPath) unlink(jsonPathNa) diff --git a/R/pkg/inst/tests/testthat/test_take.R b/R/pkg/inst/tests/testthat/test_take.R index c2c724cdc762f..dcf479363b9a8 100644 --- a/R/pkg/inst/tests/testthat/test_take.R +++ b/R/pkg/inst/tests/testthat/test_take.R @@ -30,37 +30,38 @@ strList <- list("Dexter Morgan: Blood. Sometimes it sets my teeth on edge, ", "raising me. But they're both dead now. I didn't kill them. Honest.") # JavaSparkContext handle -jsc <- sparkR.init() +sparkSession <- sparkR.session() +sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) test_that("take() gives back the original elements in correct count and order", { - numVectorRDD <- parallelize(jsc, numVector, 10) + numVectorRDD <- parallelize(sc, numVector, 10) # case: number of elements to take is less than the size of the first partition - expect_equal(take(numVectorRDD, 1), as.list(head(numVector, n = 1))) + expect_equal(takeRDD(numVectorRDD, 1), as.list(head(numVector, n = 1))) # case: number of elements to take is the same as the size of the first partition - expect_equal(take(numVectorRDD, 11), as.list(head(numVector, n = 11))) + expect_equal(takeRDD(numVectorRDD, 11), as.list(head(numVector, n = 11))) # case: number of elements to take is greater than all elements - expect_equal(take(numVectorRDD, length(numVector)), as.list(numVector)) - expect_equal(take(numVectorRDD, length(numVector) + 1), as.list(numVector)) + expect_equal(takeRDD(numVectorRDD, length(numVector)), as.list(numVector)) + expect_equal(takeRDD(numVectorRDD, length(numVector) + 1), as.list(numVector)) - numListRDD <- parallelize(jsc, numList, 1) - numListRDD2 <- parallelize(jsc, numList, 4) - expect_equal(take(numListRDD, 3), take(numListRDD2, 3)) - expect_equal(take(numListRDD, 5), take(numListRDD2, 5)) - expect_equal(take(numListRDD, 1), as.list(head(numList, n = 1))) - expect_equal(take(numListRDD2, 999), numList) + numListRDD <- parallelize(sc, numList, 1) + numListRDD2 <- parallelize(sc, numList, 4) + expect_equal(takeRDD(numListRDD, 3), takeRDD(numListRDD2, 3)) + expect_equal(takeRDD(numListRDD, 5), takeRDD(numListRDD2, 5)) + expect_equal(takeRDD(numListRDD, 1), as.list(head(numList, n = 1))) + expect_equal(takeRDD(numListRDD2, 999), numList) - strVectorRDD <- parallelize(jsc, strVector, 2) - strVectorRDD2 <- parallelize(jsc, strVector, 3) - expect_equal(take(strVectorRDD, 4), as.list(strVector)) - expect_equal(take(strVectorRDD2, 2), as.list(head(strVector, n = 2))) + strVectorRDD <- parallelize(sc, strVector, 2) + strVectorRDD2 <- parallelize(sc, strVector, 3) + expect_equal(takeRDD(strVectorRDD, 4), as.list(strVector)) + expect_equal(takeRDD(strVectorRDD2, 2), as.list(head(strVector, n = 2))) - strListRDD <- parallelize(jsc, strList, 4) - strListRDD2 <- parallelize(jsc, strList, 1) - expect_equal(take(strListRDD, 3), as.list(head(strList, n = 3))) - expect_equal(take(strListRDD2, 1), as.list(head(strList, n = 1))) + strListRDD <- parallelize(sc, strList, 4) + strListRDD2 <- parallelize(sc, strList, 1) + expect_equal(takeRDD(strListRDD, 3), as.list(head(strList, n = 3))) + expect_equal(takeRDD(strListRDD2, 1), as.list(head(strList, n = 1))) - expect_equal(length(take(strListRDD, 0)), 0) - expect_equal(length(take(strVectorRDD, 0)), 0) - expect_equal(length(take(numListRDD, 0)), 0) - 
expect_equal(length(take(numVectorRDD, 0)), 0) + expect_equal(length(takeRDD(strListRDD, 0)), 0) + expect_equal(length(takeRDD(strVectorRDD, 0)), 0) + expect_equal(length(takeRDD(numListRDD, 0)), 0) + expect_equal(length(takeRDD(numVectorRDD, 0)), 0) }) diff --git a/R/pkg/inst/tests/testthat/test_textFile.R b/R/pkg/inst/tests/testthat/test_textFile.R index e64ef1bb31a3a..ba434a5d4127b 100644 --- a/R/pkg/inst/tests/testthat/test_textFile.R +++ b/R/pkg/inst/tests/testthat/test_textFile.R @@ -18,7 +18,8 @@ context("the textFile() function") # JavaSparkContext handle -sc <- sparkR.init() +sparkSession <- sparkR.session() +sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) mockFile <- c("Spark is pretty.", "Spark is awesome.") @@ -28,8 +29,8 @@ test_that("textFile() on a local file returns an RDD", { rdd <- textFile(sc, fileName) expect_is(rdd, "RDD") - expect_true(count(rdd) > 0) - expect_equal(count(rdd), 2) + expect_true(countRDD(rdd) > 0) + expect_equal(countRDD(rdd), 2) unlink(fileName) }) @@ -39,7 +40,7 @@ test_that("textFile() followed by a collect() returns the same content", { writeLines(mockFile, fileName) rdd <- textFile(sc, fileName) - expect_equal(collect(rdd), as.list(mockFile)) + expect_equal(collectRDD(rdd), as.list(mockFile)) unlink(fileName) }) @@ -54,7 +55,7 @@ test_that("textFile() word count works as expected", { wordCount <- lapply(words, function(word) { list(word, 1L) }) counts <- reduceByKey(wordCount, "+", 2L) - output <- collect(counts) + output <- collectRDD(counts) expected <- list(list("pretty.", 1), list("is", 2), list("awesome.", 1), list("Spark", 2)) expect_equal(sortKeyValueList(output), sortKeyValueList(expected)) @@ -71,7 +72,7 @@ test_that("several transformations on RDD created by textFile()", { # PipelinedRDD initially created from RDD rdd <- lapply(rdd, function(x) paste(x, x)) } - collect(rdd) + collectRDD(rdd) unlink(fileName) }) @@ -84,7 +85,7 @@ test_that("textFile() followed by a saveAsTextFile() returns the same content", rdd <- textFile(sc, fileName1, 1L) saveAsTextFile(rdd, fileName2) rdd <- textFile(sc, fileName2) - expect_equal(collect(rdd), as.list(mockFile)) + expect_equal(collectRDD(rdd), as.list(mockFile)) unlink(fileName1) unlink(fileName2) @@ -96,7 +97,7 @@ test_that("saveAsTextFile() on a parallelized list works as expected", { rdd <- parallelize(sc, l, 1L) saveAsTextFile(rdd, fileName) rdd <- textFile(sc, fileName) - expect_equal(collect(rdd), lapply(l, function(x) {toString(x)})) + expect_equal(collectRDD(rdd), lapply(l, function(x) {toString(x)})) unlink(fileName) }) @@ -116,7 +117,7 @@ test_that("textFile() and saveAsTextFile() word count works as expected", { saveAsTextFile(counts, fileName2) rdd <- textFile(sc, fileName2) - output <- collect(rdd) + output <- collectRDD(rdd) expected <- list(list("awesome.", 1), list("Spark", 2), list("pretty.", 1), list("is", 2)) expectedStr <- lapply(expected, function(x) { toString(x) }) @@ -133,7 +134,7 @@ test_that("textFile() on multiple paths", { writeLines("Spark is awesome.", fileName2) rdd <- textFile(sc, c(fileName1, fileName2)) - expect_equal(count(rdd), 2) + expect_equal(countRDD(rdd), 2) unlink(fileName1) unlink(fileName2) @@ -146,16 +147,16 @@ test_that("Pipelined operations on RDDs created using textFile", { rdd <- textFile(sc, fileName) lengths <- lapply(rdd, function(x) { length(x) }) - expect_equal(collect(lengths), list(1, 1)) + expect_equal(collectRDD(lengths), list(1, 1)) lengthsPipelined <- lapply(lengths, function(x) { x + 10 }) - 
expect_equal(collect(lengthsPipelined), list(11, 11)) + expect_equal(collectRDD(lengthsPipelined), list(11, 11)) lengths30 <- lapply(lengthsPipelined, function(x) { x + 20 }) - expect_equal(collect(lengths30), list(31, 31)) + expect_equal(collectRDD(lengths30), list(31, 31)) lengths20 <- lapply(lengths, function(x) { x + 20 }) - expect_equal(collect(lengths20), list(21, 21)) + expect_equal(collectRDD(lengths20), list(21, 21)) unlink(fileName) }) diff --git a/R/pkg/inst/tests/testthat/test_utils.R b/R/pkg/inst/tests/testthat/test_utils.R index 01694ab5c4f61..dd5001b1599ae 100644 --- a/R/pkg/inst/tests/testthat/test_utils.R +++ b/R/pkg/inst/tests/testthat/test_utils.R @@ -18,12 +18,13 @@ context("functions in utils.R") # JavaSparkContext handle -sc <- sparkR.init() +sparkSession <- sparkR.session() +sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) test_that("convertJListToRList() gives back (deserializes) the original JLists of strings and integers", { # It's hard to manually create a Java List using rJava, since it does not - # support generics well. Instead, we rely on collect() returning a + # support generics well. Instead, we rely on collectRDD() returning a # JList. nums <- as.list(1:10) rdd <- parallelize(sc, nums, 1L) @@ -47,7 +48,7 @@ test_that("serializeToBytes on RDD", { text.rdd <- textFile(sc, fileName) expect_equal(getSerializedMode(text.rdd), "string") ser.rdd <- serializeToBytes(text.rdd) - expect_equal(collect(ser.rdd), as.list(mockFile)) + expect_equal(collectRDD(ser.rdd), as.list(mockFile)) expect_equal(getSerializedMode(ser.rdd), "byte") unlink(fileName) @@ -127,7 +128,7 @@ test_that("cleanClosure on R functions", { env <- environment(newF) expect_equal(ls(env), "t") expect_equal(get("t", envir = env, inherits = FALSE), t) - actual <- collect(lapply(rdd, f)) + actual <- collectRDD(lapply(rdd, f)) expected <- as.list(c(rep(FALSE, 4), rep(TRUE, 6))) expect_equal(actual, expected) @@ -164,3 +165,44 @@ test_that("convertToJSaveMode", { expect_error(convertToJSaveMode("foo"), 'mode should be one of "append", "overwrite", "error", "ignore"') #nolint }) + +test_that("hashCode", { + expect_error(hashCode("bc53d3605e8a5b7de1e8e271c2317645"), NA) +}) + +test_that("overrideEnvs", { + config <- new.env() + config[["spark.master"]] <- "foo" + config[["config_only"]] <- "ok" + param <- new.env() + param[["spark.master"]] <- "local" + param[["param_only"]] <- "blah" + overrideEnvs(config, param) + expect_equal(config[["spark.master"]], "local") + expect_equal(config[["param_only"]], "blah") + expect_equal(config[["config_only"]], "ok") +}) + +test_that("rbindRaws", { + + # Mixed Column types + r <- serialize(1:5, connection = NULL) + r1 <- serialize(1, connection = NULL) + r2 <- serialize(letters, connection = NULL) + r3 <- serialize(1:10, connection = NULL) + inputData <- list(list(1L, r1, "a", r), list(2L, r2, "b", r), + list(3L, r3, "c", r)) + expected <- data.frame(V1 = 1:3) + expected$V2 <- list(r1, r2, r3) + expected$V3 <- c("a", "b", "c") + expected$V4 <- list(r, r, r) + result <- rbindRaws(inputData) + expect_equal(expected, result) + + # Single binary column + input <- list(list(r1), list(r2), list(r3)) + expected <- subset(expected, select = "V2") + result <- setNames(rbindRaws(input), "V2") + expect_equal(expected, result) + +}) diff --git a/R/pkg/inst/worker/daemon.R b/R/pkg/inst/worker/daemon.R index f55beac6c8c07..b92e6be995ca9 100644 --- a/R/pkg/inst/worker/daemon.R +++ b/R/pkg/inst/worker/daemon.R @@ -44,7 +44,7 @@ while 
(TRUE) { if (inherits(p, "masterProcess")) { close(inputCon) Sys.setenv(SPARKR_WORKER_PORT = port) - source(script) + try(source(script)) # Set SIGUSR1 so that child can exit tools::pskill(Sys.getpid(), tools::SIGUSR1) parallel:::mcexit(0L) diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R index 40cda0c5ef9c1..cfe41ded200c2 100644 --- a/R/pkg/inst/worker/worker.R +++ b/R/pkg/inst/worker/worker.R @@ -27,6 +27,61 @@ elapsedSecs <- function() { proc.time()[3] } +compute <- function(mode, partition, serializer, deserializer, key, + colNames, computeFunc, inputData) { + if (mode > 0) { + if (deserializer == "row") { + # Transform the list of rows into a data.frame + # Note that the optional argument stringsAsFactors for rbind is + # available since R 3.2.4. So we set the global option here. + oldOpt <- getOption("stringsAsFactors") + options(stringsAsFactors = FALSE) + + # Handle binary data types + if ("raw" %in% sapply(inputData[[1]], class)) { + inputData <- SparkR:::rbindRaws(inputData) + } else { + inputData <- do.call(rbind.data.frame, inputData) + } + + options(stringsAsFactors = oldOpt) + + names(inputData) <- colNames + } else { + # Check to see if inputData is a valid data.frame + stopifnot(deserializer == "byte") + stopifnot(class(inputData) == "data.frame") + } + + if (mode == 2) { + output <- computeFunc(key, inputData) + } else { + output <- computeFunc(inputData) + } + if (serializer == "row") { + # Transform the result data.frame back to a list of rows + output <- split(output, seq(nrow(output))) + } else { + # Serialize the ouput to a byte array + stopifnot(serializer == "byte") + } + } else { + output <- computeFunc(partition, inputData) + } + return (output) +} + +outputResult <- function(serializer, output, outputCon) { + if (serializer == "byte") { + SparkR:::writeRawSerialize(outputCon, output) + } else if (serializer == "row") { + SparkR:::writeRowSerialize(outputCon, output) + } else { + # write lines one-by-one with flag + lapply(output, function(line) SparkR:::writeString(outputCon, line)) + } +} + # Constants specialLengths <- list(END_OF_STERAM = 0L, TIMING_DATA = -1L) @@ -79,75 +134,71 @@ if (numBroadcastVars > 0) { # Timing broadcast broadcastElap <- elapsedSecs() +# Initial input timing +inputElap <- broadcastElap # If -1: read as normal RDD; if >= 0, treat as pairwise RDD and treat the int # as number of partitions to create. 
numPartitions <- SparkR:::readInt(inputCon) -isDataFrame <- as.logical(SparkR:::readInt(inputCon)) +# 0 - RDD mode, 1 - dapply mode, 2 - gapply mode +mode <- SparkR:::readInt(inputCon) -# If isDataFrame, then read column names -if (isDataFrame) { +if (mode > 0) { colNames <- SparkR:::readObject(inputCon) } isEmpty <- SparkR:::readInt(inputCon) +computeInputElapsDiff <- 0 +outputComputeElapsDiff <- 0 if (isEmpty != 0) { - if (numPartitions == -1) { if (deserializer == "byte") { # Now read as many characters as described in funcLen data <- SparkR:::readDeserialize(inputCon) } else if (deserializer == "string") { data <- as.list(readLines(inputCon)) + } else if (deserializer == "row" && mode == 2) { + dataWithKeys <- SparkR:::readMultipleObjectsWithKeys(inputCon) + keys <- dataWithKeys$keys + data <- dataWithKeys$data } else if (deserializer == "row") { data <- SparkR:::readMultipleObjects(inputCon) } + # Timing reading input data for execution inputElap <- elapsedSecs() - - if (isDataFrame) { - if (deserializer == "row") { - # Transform the list of rows into a data.frame - # Note that the optional argument stringsAsFactors for rbind is - # available since R 3.2.4. So we set the global option here. - oldOpt <- getOption("stringsAsFactors") - options(stringsAsFactors = FALSE) - data <- do.call(rbind.data.frame, data) - options(stringsAsFactors = oldOpt) - - names(data) <- colNames - } else { - # Check to see if data is a valid data.frame - stopifnot(deserializer == "byte") - stopifnot(class(data) == "data.frame") - } - output <- computeFunc(data) - if (serializer == "row") { - # Transform the result data.frame back to a list of rows - output <- split(output, seq(nrow(output))) - } else { - # Serialize the ouput to a byte array - stopifnot(serializer == "byte") + if (mode > 0) { + if (mode == 1) { + output <- compute(mode, partition, serializer, deserializer, NULL, + colNames, computeFunc, data) + } else { + # gapply mode + for (i in 1:length(data)) { + # Timing reading input data for execution + inputElap <- elapsedSecs() + output <- compute(mode, partition, serializer, deserializer, keys[[i]], + colNames, computeFunc, data[[i]]) + computeElap <- elapsedSecs() + outputResult(serializer, output, outputCon) + outputElap <- elapsedSecs() + computeInputElapsDiff <- computeInputElapsDiff + (computeElap - inputElap) + outputComputeElapsDiff <- outputComputeElapsDiff + (outputElap - computeElap) + } } } else { - output <- computeFunc(partition, data) + output <- compute(mode, partition, serializer, deserializer, NULL, + colNames, computeFunc, data) } - - # Timing computing - computeElap <- elapsedSecs() - - if (serializer == "byte") { - SparkR:::writeRawSerialize(outputCon, output) - } else if (serializer == "row") { - SparkR:::writeRowSerialize(outputCon, output) - } else { - # write lines one-by-one with flag - lapply(output, function(line) SparkR:::writeString(outputCon, line)) + if (mode != 2) { + # Not a gapply mode + computeElap <- elapsedSecs() + outputResult(serializer, output, outputCon) + outputElap <- elapsedSecs() + computeInputElapsDiff <- computeElap - inputElap + outputComputeElapsDiff <- outputElap - computeElap } - # Timing output - outputElap <- elapsedSecs() } else { if (deserializer == "byte") { # Now read as many characters as described in funcLen @@ -189,11 +240,9 @@ if (isEmpty != 0) { } # Timing output outputElap <- elapsedSecs() + computeInputElapsDiff <- computeElap - inputElap + outputComputeElapsDiff <- outputElap - computeElap } -} else { - inputElap <- broadcastElap - 
computeElap <- broadcastElap - outputElap <- broadcastElap } # Report timing @@ -202,8 +251,8 @@ SparkR:::writeDouble(outputCon, bootTime) SparkR:::writeDouble(outputCon, initElap - bootElap) # init SparkR:::writeDouble(outputCon, broadcastElap - initElap) # broadcast SparkR:::writeDouble(outputCon, inputElap - broadcastElap) # input -SparkR:::writeDouble(outputCon, computeElap - inputElap) # compute -SparkR:::writeDouble(outputCon, outputElap - computeElap) # output +SparkR:::writeDouble(outputCon, computeInputElapsDiff) # compute +SparkR:::writeDouble(outputCon, outputComputeElapsDiff) # output # End of output SparkR:::writeInt(outputCon, specialLengths$END_OF_STERAM) diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd new file mode 100644 index 0000000000000..babfb71c4af49 --- /dev/null +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -0,0 +1,652 @@ +--- +title: "SparkR - Practical Guide" +output: + html_document: + theme: united + toc: true + toc_depth: 4 + toc_float: true + highlight: textmate +--- + +## Overview + +SparkR is an R package that provides a light-weight frontend to use Apache Spark from R. With Spark `r packageVersion("SparkR")`, SparkR provides a distributed data frame implementation that supports data processing operations like selection, filtering, aggregation etc. and distributed machine learning using [MLlib](http://spark.apache.org/mllib/). + +## Getting Started + +We begin with an example running on the local machine and provide an overview of the use of SparkR: data ingestion, data processing and machine learning. + +First, let's load and attach the package. +```{r, message=FALSE} +library(SparkR) +``` + +`SparkSession` is the entry point into SparkR which connects your R program to a Spark cluster. You can create a `SparkSession` using `sparkR.session` and pass in options such as the application name, any Spark packages depended on, etc. + +We use default settings in which it runs in local mode. It auto downloads Spark package in the background if no previous installation is found. For more details about setup, see [Spark Session](#SetupSparkSession). + +```{r, message=FALSE, results="hide"} +sparkR.session() +``` + +The operations in SparkR are centered around an R class called `SparkDataFrame`. It is a distributed collection of data organized into named columns, which is conceptually equivalent to a table in a relational database or a data frame in R, but with richer optimizations under the hood. + +`SparkDataFrame` can be constructed from a wide array of sources such as: structured data files, tables in Hive, external databases, or existing local R data frames. For example, we create a `SparkDataFrame` from a local R data frame, + +```{r} +cars <- cbind(model = rownames(mtcars), mtcars) +carsDF <- createDataFrame(cars) +``` + +We can view the first few rows of the `SparkDataFrame` by `head` or `showDF` function. +```{r} +head(carsDF) +``` + +Common data processing operations such as `filter`, `select` are supported on the `SparkDataFrame`. +```{r} +carsSubDF <- select(carsDF, "model", "mpg", "hp") +carsSubDF <- filter(carsSubDF, carsSubDF$hp >= 200) +head(carsSubDF) +``` + +SparkR can use many common aggregation functions after grouping. + +```{r} +carsGPDF <- summarize(groupBy(carsDF, carsDF$gear), count = n(carsDF$gear)) +head(carsGPDF) +``` + +The results `carsDF` and `carsSubDF` are `SparkDataFrame` objects. To convert back to R `data.frame`, we can use `collect`. 
**Caution**: This can cause your interactive environment to run out of memory, though, because `collect()` fetches the entire distributed `DataFrame` to your client, which is acting as a Spark driver. +```{r} +carsGP <- collect(carsGPDF) +class(carsGP) +``` + +SparkR supports a number of commonly used machine learning algorithms. Under the hood, SparkR uses MLlib to train the model. Users can call `summary` to print a summary of the fitted model, `predict` to make predictions on new data, and `write.ml`/`read.ml` to save/load fitted models. + +SparkR supports a subset of R formula operators for model fitting, including ‘~’, ‘.’, ‘:’, ‘+’, and ‘-‘. We use linear regression as an example. +```{r} +model <- spark.glm(carsDF, mpg ~ wt + cyl) +``` + +The result matches that returned by R `glm` function applied to the corresponding `data.frame` `mtcars` of `carsDF`. In fact, for Generalized Linear Model, we specifically expose `glm` for `SparkDataFrame` as well so that the above is equivalent to `model <- glm(mpg ~ wt + cyl, data = carsDF)`. + +```{r} +summary(model) +``` + +The model can be saved by `write.ml` and loaded back using `read.ml`. +```{r, eval=FALSE} +write.ml(model, path = "/HOME/tmp/mlModel/glmModel") +``` + +In the end, we can stop Spark Session by running +```{r, eval=FALSE} +sparkR.session.stop() +``` + +## Setup + +### Installation + +Different from many other R packages, to use SparkR, you need an additional installation of Apache Spark. The Spark installation will be used to run a backend process that will compile and execute SparkR programs. + +If you don't have Spark installed on the computer, you may download it from [Apache Spark Website](http://spark.apache.org/downloads.html). Alternatively, we provide an easy-to-use function `install.spark` to complete this process. You don't have to call it explicitly. We will check the installation when `sparkR.session` is called and `install.spark` function will be triggered automatically if no installation is found. + +```{r, eval=FALSE} +install.spark() +``` + +If you already have Spark installed, you don't have to install again and can pass the `sparkHome` argument to `sparkR.session` to let SparkR know where the Spark installation is. + +```{r, eval=FALSE} +sparkR.session(sparkHome = "/HOME/spark") +``` + +### Spark Session {#SetupSparkSession} + + +In addition to `sparkHome`, many other options can be specified in `sparkR.session`. For a complete list, see [Starting up: SparkSession](http://spark.apache.org/docs/latest/sparkr.html#starting-up-sparksession) and [SparkR API doc](http://spark.apache.org/docs/latest/api/R/sparkR.session.html). + +In particular, the following Spark driver properties can be set in `sparkConfig`. + +Property Name | Property group | spark-submit equivalent +---------------- | ------------------ | ---------------------- +`spark.driver.memory` | Application Properties | `--driver-memory` +`spark.driver.extraClassPath` | Runtime Environment | `--driver-class-path` +`spark.driver.extraJavaOptions` | Runtime Environment | `--driver-java-options` +`spark.driver.extraLibraryPath` | Runtime Environment | `--driver-library-path` +`spark.yarn.keytab` | Application Properties | `--keytab` +`spark.yarn.principal` | Application Properties | `--principal` + +**For Windows users**: Due to different file prefixes across operating systems, to avoid the issue of potential wrong prefix, a current workaround is to specify `spark.sql.warehouse.dir` when starting the `SparkSession`. 
+ +```{r, eval=FALSE} +spark_warehouse_path <- file.path(path.expand('~'), "spark-warehouse") +sparkR.session(spark.sql.warehouse.dir = spark_warehouse_path) +``` + + +#### Cluster Mode +SparkR can connect to remote Spark clusters. [Cluster Mode Overview](http://spark.apache.org/docs/latest/cluster-overview.html) is a good introduction to different Spark cluster modes. + +When connecting SparkR to a remote Spark cluster, make sure that the Spark version and Hadoop version on the machine match the corresponding versions on the cluster. Current SparkR package is compatible with +```{r, echo=FALSE, tidy = TRUE} +paste("Spark", packageVersion("SparkR")) +``` +It should be used both on the local computer and on the remote cluster. + +To connect, pass the URL of the master node to `sparkR.session`. A complete list can be seen in [Spark Master URLs](http://spark.apache.org/docs/latest/submitting-applications.html#master-urls). +For example, to connect to a local standalone Spark master, we can call + +```{r, eval=FALSE} +sparkR.session(master = "spark://local:7077") +``` + +For YARN cluster, SparkR supports the client mode with the master set as "yarn". +```{r, eval=FALSE} +sparkR.session(master = "yarn") +``` +Yarn cluster mode is not supported in the current version. + +## Data Import + +### Local Data Frame +The simplest way is to convert a local R data frame into a `SparkDataFrame`. Specifically we can use `as.DataFrame` or `createDataFrame` and pass in the local R data frame to create a `SparkDataFrame`. As an example, the following creates a `SparkDataFrame` based using the `faithful` dataset from R. +```{r} +df <- as.DataFrame(faithful) +head(df) +``` + +### Data Sources +SparkR supports operating on a variety of data sources through the `SparkDataFrame` interface. You can check the Spark SQL programming guide for more [specific options](https://spark.apache.org/docs/latest/sql-programming-guide.html#manually-specifying-options) that are available for the built-in data sources. + +The general method for creating `SparkDataFrame` from data sources is `read.df`. This method takes in the path for the file to load and the type of data source, and the currently active Spark Session will be used automatically. SparkR supports reading CSV, JSON and Parquet files natively and through Spark Packages you can find data source connectors for popular file formats like Avro. These packages can be added with `sparkPackages` parameter when initializing SparkSession using `sparkR.session`. + +```{r, eval=FALSE} +sparkR.session(sparkPackages = "com.databricks:spark-avro_2.11:3.0.0") +``` + +We can see how to use data sources using an example CSV input file. For more information please refer to SparkR [read.df](https://spark.apache.org/docs/latest/api/R/read.df.html) API documentation. +```{r, eval=FALSE} +df <- read.df(csvPath, "csv", header = "true", inferSchema = "true", na.strings = "NA") +``` + +The data sources API natively supports JSON formatted input files. Note that the file that is used here is not a typical JSON file. Each line in the file must contain a separate, self-contained valid JSON object. As a consequence, a regular multi-line JSON file will most often fail. + +Let's take a look at the first two lines of the raw JSON file used here. + +```{r} +filePath <- paste0(sparkR.conf("spark.home"), + "/examples/src/main/resources/people.json") +readLines(filePath, n = 2L) +``` + +We use `read.df` to read that into a `SparkDataFrame`. 
+ +```{r} +people <- read.df(filePath, "json") +count(people) +head(people) +``` + +SparkR automatically infers the schema from the JSON file. +```{r} +printSchema(people) +``` + +If we want to read multiple JSON files, `read.json` can be used. +```{r} +people <- read.json(paste0(Sys.getenv("SPARK_HOME"), + c("/examples/src/main/resources/people.json", + "/examples/src/main/resources/people.json"))) +count(people) +``` + +The data sources API can also be used to save out `SparkDataFrames` into multiple file formats. For example we can save the `SparkDataFrame` from the previous example to a Parquet file using `write.df`. +```{r, eval=FALSE} +write.df(people, path = "people.parquet", source = "parquet", mode = "overwrite") +``` + +### Hive Tables +You can also create SparkDataFrames from Hive tables. To do this we will need to create a SparkSession with Hive support which can access tables in the Hive MetaStore. Note that Spark should have been built with Hive support and more details can be found in the [SQL programming guide](https://spark.apache.org/docs/latest/sql-programming-guide.html). In SparkR, by default it will attempt to create a SparkSession with Hive support enabled (`enableHiveSupport = TRUE`). + +```{r, eval=FALSE} +sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") + +txtPath <- paste0(sparkR.conf("spark.home"), "/examples/src/main/resources/kv1.txt") +sqlCMD <- sprintf("LOAD DATA LOCAL INPATH '%s' INTO TABLE src", txtPath) +sql(sqlCMD) + +results <- sql("FROM src SELECT key, value") + +# results is now a SparkDataFrame +head(results) +``` + + +## Data Processing + +**To dplyr users**: SparkR has similar interface as dplyr in data processing. However, some noticeable differences are worth mentioning in the first place. We use `df` to represent a `SparkDataFrame` and `col` to represent the name of column here. + +1. indicate columns. SparkR uses either a character string of the column name or a Column object constructed with `$` to indicate a column. For example, to select `col` in `df`, we can write `select(df, "col")` or `select(df, df$col)`. + +2. describe conditions. In SparkR, the Column object representation can be inserted into the condition directly, or we can use a character string to describe the condition, without referring to the `SparkDataFrame` used. For example, to select rows with value > 1, we can write `filter(df, df$col > 1)` or `filter(df, "col > 1")`. + +Here are more concrete examples. + +dplyr | SparkR +-------- | --------- +`select(mtcars, mpg, hp)` | `select(carsDF, "mpg", "hp")` +`filter(mtcars, mpg > 20, hp > 100)` | `filter(carsDF, carsDF$mpg > 20, carsDF$hp > 100)` + +Other differences will be mentioned in the specific methods. + +We use the `SparkDataFrame` `carsDF` created above. We can get basic information about the `SparkDataFrame`. +```{r} +carsDF +``` + +Print out the schema in tree format. +```{r} +printSchema(carsDF) +``` + +### SparkDataFrame Operations + +#### Selecting rows, columns + +SparkDataFrames support a number of functions to do structured data processing. Here we include some basic examples and a complete list can be found in the [API](https://spark.apache.org/docs/latest/api/R/index.html) docs: + +You can also pass in column name as strings. +```{r} +head(select(carsDF, "mpg")) +``` + +Filter the SparkDataFrame to only retain rows with mpg less than 20 miles/gallon. +```{r} +head(filter(carsDF, carsDF$mpg < 20)) +``` + +#### Grouping, Aggregation + +A common flow of grouping and aggregation is + +1. 
Use `groupBy` or `group_by` with respect to some grouping variables to create a `GroupedData` object + +2. Feed the `GroupedData` object to `agg` or `summarize` functions, with some provided aggregation functions to compute a number within each group. + +A number of widely used functions are supported to aggregate data after grouping, including `avg`, `countDistinct`, `count`, `first`, `kurtosis`, `last`, `max`, `mean`, `min`, `sd`, `skewness`, `stddev_pop`, `stddev_samp`, `sumDistinct`, `sum`, `var_pop`, `var_samp`, `var`. See the [API doc for `mean`](http://spark.apache.org/docs/latest/api/R/mean.html) and other `agg_funcs` linked there. + +For example we can compute a histogram of the number of cylinders in the `mtcars` dataset as shown below. + +```{r} +numCyl <- summarize(groupBy(carsDF, carsDF$cyl), count = n(carsDF$cyl)) +head(numCyl) +``` + +#### Operating on Columns + +SparkR also provides a number of functions that can directly applied to columns for data processing and during aggregation. The example below shows the use of basic arithmetic functions. + +```{r} +carsDF_km <- carsDF +carsDF_km$kmpg <- carsDF_km$mpg * 1.61 +head(select(carsDF_km, "model", "mpg", "kmpg")) +``` + + +### Window Functions +A window function is a variation of aggregation function. In simple words, + +* aggregation function: `n` to `1` mapping - returns a single value for a group of entries. Examples include `sum`, `count`, `max`. + +* window function: `n` to `n` mapping - returns one value for each entry in the group, but the value may depend on all the entries of the *group*. Examples include `rank`, `lead`, `lag`. + +Formally, the *group* mentioned above is called the *frame*. Every input row can have a unique frame associated with it and the output of the window function on that row is based on the rows confined in that frame. + +Window functions are often used in conjunction with the following functions: `windowPartitionBy`, `windowOrderBy`, `partitionBy`, `orderBy`, `over`. To illustrate this we next look at an example. + +We still use the `mtcars` dataset. The corresponding `SparkDataFrame` is `carsDF`. Suppose for each number of cylinders, we want to calculate the rank of each car in `mpg` within the group. +```{r} +carsSubDF <- select(carsDF, "model", "mpg", "cyl") +ws <- orderBy(windowPartitionBy("cyl"), "mpg") +carsRank <- withColumn(carsSubDF, "rank", over(rank(), ws)) +head(carsRank, n = 20L) +``` + +We explain in detail the above steps. + +* `windowPartitionBy` creates a window specification object `WindowSpec` that defines the partition. It controls which rows will be in the same partition as the given row. In this case, rows with the same value in `cyl` will be put in the same partition. `orderBy` further defines the ordering - the position a given row is in the partition. The resulting `WindowSpec` is returned as `ws`. + +More window specification methods include `rangeBetween`, which can define boundaries of the frame by value, and `rowsBetween`, which can define the boundaries by row indices. + +* `withColumn` appends a Column called `rank` to the `SparkDataFrame`. `over` returns a windowing column. The first argument is usually a Column returned by window function(s) such as `rank()`, `lead(carsDF$wt)`. That calculates the corresponding values according to the partitioned-and-ordered table. + +### User-Defined Function + +In SparkR, we support several kinds of user-defined functions (UDFs). 
+ +#### Apply by Partition + +`dapply` can apply a function to each partition of a `SparkDataFrame`. The function to be applied to each partition of the `SparkDataFrame` should have only one parameter, a `data.frame` corresponding to a partition, and the output should be a `data.frame` as well. Schema specifies the row format of the resulting a `SparkDataFrame`. It must match to data types of returned value. See [here](#DataTypes) for mapping between R and Spark. + +We convert `mpg` to `kmpg` (kilometers per gallon). `carsSubDF` is a `SparkDataFrame` with a subset of `carsDF` columns. + +```{r} +carsSubDF <- select(carsDF, "model", "mpg") +schema <- structType(structField("model", "string"), structField("mpg", "double"), + structField("kmpg", "double")) +out <- dapply(carsSubDF, function(x) { x <- cbind(x, x$mpg * 1.61) }, schema) +head(collect(out)) +``` + +Like `dapply`, apply a function to each partition of a `SparkDataFrame` and collect the result back. The output of function should be a `data.frame`, but no schema is required in this case. Note that `dapplyCollect` can fail if the output of UDF run on all the partition cannot be pulled to the driver and fit in driver memory. + +```{r} +out <- dapplyCollect( + carsSubDF, + function(x) { + x <- cbind(x, "kmpg" = x$mpg * 1.61) + }) +head(out, 3) +``` + +#### Apply by Group +`gapply` can apply a function to each group of a `SparkDataFrame`. The function is to be applied to each group of the `SparkDataFrame` and should have only two parameters: grouping key and R `data.frame` corresponding to that key. The groups are chosen from `SparkDataFrames` column(s). The output of function should be a `data.frame`. Schema specifies the row format of the resulting `SparkDataFrame`. It must represent R function’s output schema on the basis of Spark data types. The column names of the returned `data.frame` are set by user. See [here](#DataTypes) for mapping between R and Spark. + +```{r} +schema <- structType(structField("cyl", "double"), structField("max_mpg", "double")) +result <- gapply( + carsDF, + "cyl", + function(key, x) { + y <- data.frame(key, max(x$mpg)) + }, + schema) +head(arrange(result, "max_mpg", decreasing = TRUE)) +``` + +Like gapply, `gapplyCollect` applies a function to each partition of a `SparkDataFrame` and collect the result back to R `data.frame`. The output of the function should be a `data.frame` but no schema is required in this case. Note that `gapplyCollect` can fail if the output of UDF run on all the partition cannot be pulled to the driver and fit in driver memory. + +```{r} +result <- gapplyCollect( + carsDF, + "cyl", + function(key, x) { + y <- data.frame(key, max(x$mpg)) + colnames(y) <- c("cyl", "max_mpg") + y + }) +head(result[order(result$max_mpg, decreasing = TRUE), ]) +``` + +#### Distribute Local Functions + +Similar to `lapply` in native R, `spark.lapply` runs a function over a list of elements and distributes the computations with Spark. `spark.lapply` works in a manner that is similar to `doParallel` or `lapply` to elements of a list. The results of all the computations should fit in a single machine. If that is not the case you can do something like `df <- createDataFrame(list)` and then use `dapply`. + +We use `svm` in package `e1071` as an example. We use all default settings except for varying costs of constraints violation. `spark.lapply` can train those different models in parallel. 
+ +```{r} +costs <- exp(seq(from = log(1), to = log(1000), length.out = 5)) +train <- function(cost) { + stopifnot(requireNamespace("e1071", quietly = TRUE)) + model <- e1071::svm(Species ~ ., data = iris, cost = cost) + summary(model) +} +``` + +Return a list of model's summaries. +```{r} +model.summaries <- spark.lapply(costs, train) +``` + +```{r} +class(model.summaries) +``` + + +To avoid lengthy display, we only present the partial result of the second fitted model. You are free to inspect other models as well. +```{r, include=FALSE} +ops <- options() +options(max.print=40) +``` +```{r} +print(model.summaries[[2]]) +``` +```{r, include=FALSE} +options(ops) +``` + + +### SQL Queries +A `SparkDataFrame` can also be registered as a temporary view in Spark SQL and that allows you to run SQL queries over its data. The sql function enables applications to run SQL queries programmatically and returns the result as a `SparkDataFrame`. + +```{r} +people <- read.df(paste0(sparkR.conf("spark.home"), + "/examples/src/main/resources/people.json"), "json") +``` + +Register this SparkDataFrame as a temporary view. + +```{r} +createOrReplaceTempView(people, "people") +``` + +SQL statements can be run by using the sql method. +```{r} +teenagers <- sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") +head(teenagers) +``` + + +## Machine Learning + +SparkR supports the following machine learning models and algorithms. + +* Generalized Linear Model (GLM) + +* Naive Bayes Model + +* $k$-means Clustering + +* Accelerated Failure Time (AFT) Survival Model + +More will be added in the future. + +### R Formula + +For most above, SparkR supports **R formula operators**, including `~`, `.`, `:`, `+` and `-` for model fitting. This makes it a similar experience as using R functions. + +### Training and Test Sets + +We can easily split `SparkDataFrame` into random training and test sets by the `randomSplit` function. It returns a list of split `SparkDataFrames` with provided `weights`. We use `carsDF` as an example and want to have about $70%$ training data and $30%$ test data. +```{r} +splitDF_list <- randomSplit(carsDF, c(0.7, 0.3), seed = 0) +carsDF_train <- splitDF_list[[1]] +carsDF_test <- splitDF_list[[2]] +``` + +```{r} +count(carsDF_train) +head(carsDF_train) +``` + +```{r} +count(carsDF_test) +head(carsDF_test) +``` + + +### Models and Algorithms + +#### Generalized Linear Model + +The main function is `spark.glm`. The following families and link functions are supported. The default is gaussian. + +Family | Link Function +------ | --------- +gaussian | identity, log, inverse +binomial | logit, probit, cloglog (complementary log-log) +poisson | log, identity, sqrt +gamma | inverse, identity, log + +There are three ways to specify the `family` argument. + +* Family name as a character string, e.g. `family = "gaussian"`. + +* Family function, e.g. `family = binomial`. + +* Result returned by a family function, e.g. `family = poisson(link = log)` + +For more information regarding the families and their link functions, see the Wikipedia page [Generalized Linear Model](https://en.wikipedia.org/wiki/Generalized_linear_model). + +We use the `mtcars` dataset as an illustration. The corresponding `SparkDataFrame` is `carsDF`. After fitting the model, we print out a summary and see the fitted values by making predictions on the original dataset. We can also pass into a new `SparkDataFrame` of same schema to predict on new data. 
+ +```{r} +gaussianGLM <- spark.glm(carsDF, mpg ~ wt + hp) +summary(gaussianGLM) +``` +When doing prediction, a new column called `prediction` will be appended. Let's look at only a subset of columns here. +```{r} +gaussianFitted <- predict(gaussianGLM, carsDF) +head(select(gaussianFitted, "model", "prediction", "mpg", "wt", "hp")) +``` + +#### Naive Bayes Model + +Naive Bayes model assumes independence among the features. `spark.naiveBayes` fits a [Bernoulli naive Bayes model](https://en.wikipedia.org/wiki/Naive_Bayes_classifier#Bernoulli_naive_Bayes) against a SparkDataFrame. The data should be all categorical. These models are often used for document classification. + +```{r} +titanic <- as.data.frame(Titanic) +titanicDF <- createDataFrame(titanic[titanic$Freq > 0, -5]) +naiveBayesModel <- spark.naiveBayes(titanicDF, Survived ~ Class + Sex + Age) +summary(naiveBayesModel) +naiveBayesPrediction <- predict(naiveBayesModel, titanicDF) +head(select(naiveBayesPrediction, "Class", "Sex", "Age", "Survived", "prediction")) +``` + +#### k-Means Clustering + +`spark.kmeans` fits a $k$-means clustering model against a `SparkDataFrame`. As an unsupervised learning method, we don't need a response variable. Hence, the left hand side of the R formula should be left blank. The clustering is based only on the variables on the right hand side. + +```{r} +kmeansModel <- spark.kmeans(carsDF, ~ mpg + hp + wt, k = 3) +summary(kmeansModel) +kmeansPredictions <- predict(kmeansModel, carsDF) +head(select(kmeansPredictions, "model", "mpg", "hp", "wt", "prediction"), n = 20L) +``` + +#### AFT Survival Model +Survival analysis studies the expected duration of time until an event happens, and often the relationship with risk factors or treatment taken on the subject. In contrast to standard regression analysis, survival modeling has to deal with special characteristics in the data including non-negative survival time and censoring. + +Accelerated Failure Time (AFT) model is a parametric survival model for censored data that assumes the effect of a covariate is to accelerate or decelerate the life course of an event by some constant. For more information, refer to the Wikipedia page [AFT Model](https://en.wikipedia.org/wiki/Accelerated_failure_time_model) and the references there. Different from a [Proportional Hazards Model](https://en.wikipedia.org/wiki/Proportional_hazards_model) designed for the same purpose, the AFT model is easier to parallelize because each instance contributes to the objective function independently. +```{r, warning=FALSE} +library(survival) +ovarianDF <- createDataFrame(ovarian) +aftModel <- spark.survreg(ovarianDF, Surv(futime, fustat) ~ ecog_ps + rx) +summary(aftModel) +aftPredictions <- predict(aftModel, ovarianDF) +head(aftPredictions) +``` + +### Model Persistence +The following example shows how to save/load an ML model by SparkR. +```{r, warning=FALSE} +irisDF <- createDataFrame(iris) +gaussianGLM <- spark.glm(irisDF, Sepal_Length ~ Sepal_Width + Species, family = "gaussian") + +# Save and then load a fitted MLlib model +modelPath <- tempfile(pattern = "ml", fileext = ".tmp") +write.ml(gaussianGLM, modelPath) +gaussianGLM2 <- read.ml(modelPath) + +# Check model summary +summary(gaussianGLM2) + +# Check model prediction +gaussianPredictions <- predict(gaussianGLM2, irisDF) +head(gaussianPredictions) + +unlink(modelPath) +``` + + +## Advanced Topics + +### SparkR Object Classes + +There are three main object classes in SparkR you may be working with. 
+ +* `SparkDataFrame`: the central component of SparkR. It is an S4 class representing a distributed collection of data organized into named columns, which is conceptually equivalent to a table in a relational database or a data frame in R. It has two slots, `sdf` and `env`. + + `sdf` stores a reference to the corresponding Spark Dataset in the Spark JVM backend. + + `env` saves the meta-information of the object, such as `isCached`. +
It can be created by data import methods or by transforming an existing `SparkDataFrame`. We can manipulate a `SparkDataFrame` with numerous data processing functions and feed the result into machine learning algorithms. + +* `Column`: an S4 class representing a column of a `SparkDataFrame`. The slot `jc` saves a reference to the corresponding Column object in the Spark JVM backend. + +It can be obtained from a `SparkDataFrame` with the `$` operator, e.g. `df$col`. More often, it is used together with other functions, for example, with `select` to select particular columns, with `filter` and constructed conditions to select rows, and with aggregation functions to compute aggregate statistics for each group. + +* `GroupedData`: an S4 class representing grouped data, created by `groupBy` or by transforming another `GroupedData`. Its `sgd` slot saves a reference to a `RelationalGroupedDataset` object in the backend. + +This is often an intermediate object that carries grouping information and is typically followed by aggregation operations. + +### Architecture + +A complete description of the architecture can be found in the references, in particular the paper *SparkR: Scaling R Programs with Spark*. + +Under the hood, SparkR is built on the Spark SQL engine. This avoids the overhead of running interpreted R code, and the optimized SQL execution engine in Spark uses structural information about data and computation flow to perform a number of optimizations that speed up the computation. + +The main method calls of the actual computation happen in the Spark JVM of the driver. We have a socket-based SparkR API that allows us to invoke functions on the JVM from R. We use a SparkR JVM backend that listens on a Netty-based socket server. + +Two kinds of RPCs are supported in the SparkR JVM backend: method invocation and creating new objects. Method invocation can be done in two ways. + +* `sparkR.invokeJMethod` takes a reference to an existing Java object and a list of arguments to be passed on to the method. + +* `sparkR.invokeJStatic` takes a class name for a static method and a list of arguments to be passed on to the method. + +The arguments are serialized using our custom wire format, which is then deserialized on the JVM side. We then use Java reflection to invoke the appropriate method. + +To create objects, `sparkR.newJObject` is used, and the appropriate constructor is then invoked with the provided arguments. + +Finally, we use a new R class `jobj` that refers to a Java object existing in the backend. These references are tracked on the Java side and are automatically garbage collected when they go out of scope on the R side.
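As a small illustration of the RPC mechanism described above, the sketch below calls into the JVM backend directly using SparkR's internal helpers. This is a minimal sketch for illustration only: `callJStatic` is the same non-exported helper the tests in this change use to obtain a `JavaSparkContext`, and `callJMethod` is its counterpart for instance methods; neither is part of the public SparkR API, and the names may change between releases.

```{r, eval=FALSE}
sparkSession <- sparkR.session()

# Static method invocation: a class name, a method name, and the arguments
# to pass; this mirrors the test code elsewhere in this change.
jsc <- SparkR:::callJStatic("org.apache.spark.sql.api.r.SQLUtils",
                            "getJavaSparkContext", sparkSession)

# Method invocation on the `jobj` reference returned by the backend.
SparkR:::callJMethod(jsc, "version")

# The reference is an R object of class "jobj"; the Java object it points to
# is released once the R reference goes out of scope.
class(jsc)
```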
+ +## Appendix + +### R and Spark Data Types {#DataTypes} + +R | Spark +----------- | ------------- +byte | byte +integer | integer +float | float +double | double +numeric | double +character | string +string | string +binary | binary +raw | binary +logical | boolean +POSIXct | timestamp +POSIXlt | timestamp +Date | date +array | array +list | array +env | map + +## References + +* [Spark Cluster Mode Overview](http://spark.apache.org/docs/latest/cluster-overview.html) + +* [Submitting Spark Applications](http://spark.apache.org/docs/latest/submitting-applications.html) + +* [Machine Learning Library Guide (MLlib)](http://spark.apache.org/docs/latest/ml-guide.html) + +* [SparkR: Scaling R Programs with Spark](https://people.csail.mit.edu/matei/papers/2016/sigmod_sparkr.pdf), Shivaram Venkataraman, Zongheng Yang, Davies Liu, Eric Liang, Hossein Falaki, Xiangrui Meng, Reynold Xin, Ali Ghodsi, Michael Franklin, Ion Stoica, and Matei Zaharia. SIGMOD 2016. June 2016. + +```{r, echo=FALSE} +sparkR.session.stop() +``` diff --git a/R/run-tests.sh b/R/run-tests.sh index 9dcf0ace7d97e..1a1e8ab9ffe18 100755 --- a/R/run-tests.sh +++ b/R/run-tests.sh @@ -26,6 +26,17 @@ rm -f $LOGFILE SPARK_TESTING=1 $FWDIR/../bin/spark-submit --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" --conf spark.hadoop.fs.default.name="file:///" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE FAILED=$((PIPESTATUS[0]||$FAILED)) +# Also run the documentation tests for CRAN +CRAN_CHECK_LOG_FILE=$FWDIR/cran-check.out +rm -f $CRAN_CHECK_LOG_FILE + +NO_TESTS=1 NO_MANUAL=1 $FWDIR/check-cran.sh 2>&1 | tee -a $CRAN_CHECK_LOG_FILE +FAILED=$((PIPESTATUS[0]||$FAILED)) + +NUM_CRAN_WARNING="$(grep -c WARNING$ $CRAN_CHECK_LOG_FILE)" +NUM_CRAN_ERROR="$(grep -c ERROR$ $CRAN_CHECK_LOG_FILE)" +NUM_CRAN_NOTES="$(grep -c NOTE$ $CRAN_CHECK_LOG_FILE)" + if [[ $FAILED != 0 ]]; then cat $LOGFILE echo -en "\033[31m" # Red @@ -33,7 +44,17 @@ if [[ $FAILED != 0 ]]; then echo -en "\033[0m" # No color exit -1 else - echo -en "\033[32m" # Green - echo "Tests passed." - echo -en "\033[0m" # No color + # We have 2 existing NOTEs for new maintainer, attach() + # We have one more NOTE in Jenkins due to "No repository set" + if [[ $NUM_CRAN_WARNING != 0 || $NUM_CRAN_ERROR != 0 || $NUM_CRAN_NOTES -gt 3 ]]; then + cat $CRAN_CHECK_LOG_FILE + echo -en "\033[31m" # Red + echo "Had CRAN check errors; see logs." + echo -en "\033[0m" # No color + exit -1 + else + echo -en "\033[32m" # Green + echo "Tests passed." + echo -en "\033[0m" # No color + fi fi diff --git a/README.md b/README.md index d5804d1a20b43..c77c429e577cd 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,8 @@ To build Spark and its example programs, run: build/mvn -DskipTests clean package (You do not need to do this if you downloaded a pre-built package.) + +You can build Spark using more than one thread by using the -T option with Maven, see ["Parallel builds in Maven 3"](https://cwiki.apache.org/confluence/display/MAVEN/Parallel+builds+in+Maven+3). More detailed documentation is available from the project site, at ["Building Spark"](http://spark.apache.org/docs/latest/building-spark.html). 
For developing Spark using an IDE, see [Eclipse](https://cwiki.apache.org/confluence/display/SPARK/Useful+Developer+Tools#UsefulDeveloperTools-Eclipse) diff --git a/assembly/pom.xml b/assembly/pom.xml index 75ac9262cbae5..de09fce6a48d3 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.0.0-SNAPSHOT + 2.0.3-SNAPSHOT ../pom.xml diff --git a/bin/pyspark b/bin/pyspark index d1fe75a08bdac..037645dbd64d7 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -50,9 +50,11 @@ if [[ -z "$PYSPARK_DRIVER_PYTHON" ]]; then PYSPARK_DRIVER_PYTHON="${PYSPARK_PYTHON:-"$DEFAULT_PYTHON"}" fi +WORKS_WITH_IPYTHON=$($DEFAULT_PYTHON -c 'import sys; print(sys.version_info >= (2, 7, 0))') + # Determine the Python executable to use for the executors: if [[ -z "$PYSPARK_PYTHON" ]]; then - if [[ $PYSPARK_DRIVER_PYTHON == *ipython* && $DEFAULT_PYTHON != "python2.7" ]]; then + if [[ $PYSPARK_DRIVER_PYTHON == *ipython* && ! WORKS_WITH_IPYTHON ]]; then echo "IPython requires Python 2.7+; please install python2.7 or set PYSPARK_PYTHON" 1>&2 exit 1 else @@ -63,7 +65,7 @@ export PYSPARK_PYTHON # Add the PySpark classes to the Python path: export PYTHONPATH="${SPARK_HOME}/python/:$PYTHONPATH" -export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.9.2-src.zip:$PYTHONPATH" +export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.3-src.zip:$PYTHONPATH" # Load the PySpark shell.py script when ./pyspark is used interactively: export OLD_PYTHONSTARTUP="$PYTHONSTARTUP" diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd index cb788497ffc79..1217a4f2f97a2 100644 --- a/bin/pyspark2.cmd +++ b/bin/pyspark2.cmd @@ -30,7 +30,7 @@ if "x%PYSPARK_DRIVER_PYTHON%"=="x" ( ) set PYTHONPATH=%SPARK_HOME%\python;%PYTHONPATH% -set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.9.2-src.zip;%PYTHONPATH% +set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.10.3-src.zip;%PYTHONPATH% set OLD_PYTHONSTARTUP=%PYTHONSTARTUP% set PYTHONSTARTUP=%SPARK_HOME%\python\pyspark\shell.py diff --git a/bin/spark-class b/bin/spark-class index b2a36b9846780..377c8d1add3f6 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -64,8 +64,34 @@ fi # The launcher library will print arguments separated by a NULL character, to allow arguments with # characters that would be otherwise interpreted by the shell. Read that in a while loop, populating # an array that will be used to exec the final command. +# +# The exit code of the launcher is appended to the output, so the parent shell removes it from the +# command array and checks the value to see if the launcher succeeded. +build_command() { + "$RUNNER" -Xmx128m -cp "$LAUNCH_CLASSPATH" org.apache.spark.launcher.Main "$@" + printf "%d\0" $? +} + CMD=() while IFS= read -d '' -r ARG; do CMD+=("$ARG") -done < <("$RUNNER" -cp "$LAUNCH_CLASSPATH" org.apache.spark.launcher.Main "$@") +done < <(build_command "$@") + +COUNT=${#CMD[@]} +LAST=$((COUNT - 1)) +LAUNCHER_EXIT_CODE=${CMD[$LAST]} + +# Certain JVM failures result in errors being printed to stdout (instead of stderr), which causes +# the code that parses the output of the launcher to get confused. In those cases, check if the +# exit code is an integer, and if it's not, handle it as a special error case. +if ! 
[[ $LAUNCHER_EXIT_CODE =~ ^[0-9]+$ ]]; then + echo "${CMD[@]}" | head -n-1 1>&2 + exit 1 +fi + +if [ $LAUNCHER_EXIT_CODE != 0 ]; then + exit $LAUNCHER_EXIT_CODE +fi + +CMD=("${CMD[@]:0:$LAST}") exec "${CMD[@]}" diff --git a/bin/spark-class2.cmd b/bin/spark-class2.cmd index db680218dc964..869c0b202f7f3 100644 --- a/bin/spark-class2.cmd +++ b/bin/spark-class2.cmd @@ -55,7 +55,7 @@ if not "x%JAVA_HOME%"=="x" set RUNNER=%JAVA_HOME%\bin\java rem The launcher library prints the command to be executed in a single line suitable for being rem executed by the batch interpreter. So read all the output of the launcher into a variable. set LAUNCHER_OUTPUT=%temp%\spark-class-launcher-output-%RANDOM%.txt -"%RUNNER%" -cp "%LAUNCH_CLASSPATH%" org.apache.spark.launcher.Main %* > %LAUNCHER_OUTPUT% +"%RUNNER%" -Xmx128m -cp "%LAUNCH_CLASSPATH%" org.apache.spark.launcher.Main %* > %LAUNCHER_OUTPUT% for /f "tokens=*" %%i in (%LAUNCHER_OUTPUT%) do ( set SPARK_CMD=%%i ) diff --git a/build/mvn b/build/mvn index a78b93a685689..c3ab62da36868 100755 --- a/build/mvn +++ b/build/mvn @@ -141,9 +141,13 @@ cd "${_CALLING_DIR}" # Now that zinc is ensured to be installed, check its status and, if its # not running or just installed, start it if [ -n "${ZINC_INSTALL_FLAG}" -o -z "`"${ZINC_BIN}" -status -port ${ZINC_PORT}`" ]; then + ZINC_JAVA_HOME= + if [ -n "$JAVA_7_HOME" ]; then + ZINC_JAVA_HOME="env JAVA_HOME=$JAVA_7_HOME" + fi export ZINC_OPTS=${ZINC_OPTS:-"$_COMPILE_JVM_OPTS"} "${ZINC_BIN}" -shutdown -port ${ZINC_PORT} - "${ZINC_BIN}" -start -port ${ZINC_PORT} \ + $ZINC_JAVA_HOME "${ZINC_BIN}" -start -port ${ZINC_PORT} \ -scala-compiler "${SCALA_COMPILER}" \ -scala-library "${SCALA_LIBRARY}" &>/dev/null fi diff --git a/build/spark-build-info b/build/spark-build-info new file mode 100755 index 0000000000000..ad0ec67f455cb --- /dev/null +++ b/build/spark-build-info @@ -0,0 +1,38 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# This script generates the build info for spark and places it into the spark-version-info.properties file. +# Arguments: +# build_tgt_directory - The target directory where properties file would be created. 
[./core/target/extra-resources] +# spark_version - The current version of spark + +RESOURCE_DIR="$1" +mkdir -p "$RESOURCE_DIR" +SPARK_BUILD_INFO="${RESOURCE_DIR}"/spark-version-info.properties + +echo_build_properties() { + echo version=$1 + echo user=$USER + echo revision=$(git rev-parse HEAD) + echo branch=$(git rev-parse --abbrev-ref HEAD) + echo date=$(date -u +%Y-%m-%dT%H:%M:%SZ) + echo url=$(git config --get remote.origin.url) +} + +echo_build_properties $2 > "$SPARK_BUILD_INFO" diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml index bd507c2cb6c4b..2ee104f3eaf39 100644 --- a/common/network-common/pom.xml +++ b/common/network-common/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.0.0-SNAPSHOT + 2.0.3-SNAPSHOT ../../pom.xml @@ -42,6 +42,22 @@ netty-all + + org.fusesource.leveldbjni + leveldbjni-all + 1.8 + + + + com.fasterxml.jackson.core + jackson-databind + + + + com.fasterxml.jackson.core + jackson-annotations + + org.slf4j @@ -66,7 +82,7 @@ org.apache.spark - spark-test-tags_${scala.binary.version} + spark-tags_${scala.binary.version} org.mockito diff --git a/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java index 844eff4f4c701..c20fab83c3460 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java @@ -130,7 +130,7 @@ public ManagedBuffer release() { @Override public Object convertToNetty() throws IOException { if (conf.lazyFileDescriptor()) { - return new LazyFileRegion(file, offset, length); + return new DefaultFileRegion(file, offset, length); } else { FileChannel fileChannel = new FileInputStream(file).getChannel(); return new DefaultFileRegion(fileChannel, offset, length); diff --git a/common/network-common/src/main/java/org/apache/spark/network/buffer/LazyFileRegion.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/LazyFileRegion.java deleted file mode 100644 index 162cf6da0dffe..0000000000000 --- a/common/network-common/src/main/java/org/apache/spark/network/buffer/LazyFileRegion.java +++ /dev/null @@ -1,111 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.network.buffer; - -import java.io.FileInputStream; -import java.io.File; -import java.io.IOException; -import java.nio.channels.FileChannel; -import java.nio.channels.WritableByteChannel; - -import com.google.common.base.Objects; -import io.netty.channel.FileRegion; -import io.netty.util.AbstractReferenceCounted; - -import org.apache.spark.network.util.JavaUtils; - -/** - * A FileRegion implementation that only creates the file descriptor when the region is being - * transferred. This cannot be used with Epoll because there is no native support for it. - * - * This is mostly copied from DefaultFileRegion implementation in Netty. In the future, we - * should push this into Netty so the native Epoll transport can support this feature. - */ -public final class LazyFileRegion extends AbstractReferenceCounted implements FileRegion { - - private final File file; - private final long position; - private final long count; - - private FileChannel channel; - - private long numBytesTransferred = 0L; - - /** - * @param file file to transfer. - * @param position start position for the transfer. - * @param count number of bytes to transfer starting from position. - */ - public LazyFileRegion(File file, long position, long count) { - this.file = file; - this.position = position; - this.count = count; - } - - @Override - protected void deallocate() { - JavaUtils.closeQuietly(channel); - } - - @Override - public long position() { - return position; - } - - @Override - public long transfered() { - return numBytesTransferred; - } - - @Override - public long count() { - return count; - } - - @Override - public long transferTo(WritableByteChannel target, long position) throws IOException { - if (channel == null) { - channel = new FileInputStream(file).getChannel(); - } - - long count = this.count - position; - if (count < 0 || position < 0) { - throw new IllegalArgumentException( - "position out of range: " + position + " (expected: 0 - " + (count - 1) + ')'); - } - - if (count == 0) { - return 0L; - } - - long written = channel.transferTo(this.position + position, count, target); - if (written > 0) { - numBytesTransferred += written; - } - return written; - } - - @Override - public String toString() { - return Objects.toStringHelper(this) - .add("file", file) - .add("position", position) - .add("count", count) - .toString(); - } -} diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java index 64a83171e9e90..17ac91dd4cdd3 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -43,7 +43,7 @@ import org.apache.spark.network.protocol.RpcRequest; import org.apache.spark.network.protocol.StreamChunkId; import org.apache.spark.network.protocol.StreamRequest; -import org.apache.spark.network.util.NettyUtils; +import static org.apache.spark.network.util.NettyUtils.getRemoteAddress; /** * Client for fetching consecutive chunks of a pre-negotiated stream. 
This API is intended to allow @@ -135,9 +135,10 @@ public void fetchChunk( long streamId, final int chunkIndex, final ChunkReceivedCallback callback) { - final String serverAddr = NettyUtils.getRemoteAddress(channel); final long startTime = System.currentTimeMillis(); - logger.debug("Sending fetch chunk request {} to {}", chunkIndex, serverAddr); + if (logger.isDebugEnabled()) { + logger.debug("Sending fetch chunk request {} to {}", chunkIndex, getRemoteAddress(channel)); + } final StreamChunkId streamChunkId = new StreamChunkId(streamId, chunkIndex); handler.addFetchRequest(streamChunkId, callback); @@ -148,11 +149,13 @@ public void fetchChunk( public void operationComplete(ChannelFuture future) throws Exception { if (future.isSuccess()) { long timeTaken = System.currentTimeMillis() - startTime; - logger.trace("Sending request {} to {} took {} ms", streamChunkId, serverAddr, - timeTaken); + if (logger.isTraceEnabled()) { + logger.trace("Sending request {} to {} took {} ms", streamChunkId, + getRemoteAddress(channel), timeTaken); + } } else { String errorMsg = String.format("Failed to send request %s to %s: %s", streamChunkId, - serverAddr, future.cause()); + getRemoteAddress(channel), future.cause()); logger.error(errorMsg, future.cause()); handler.removeFetchRequest(streamChunkId); channel.close(); @@ -173,9 +176,10 @@ public void operationComplete(ChannelFuture future) throws Exception { * @param callback Object to call with the stream data. */ public void stream(final String streamId, final StreamCallback callback) { - final String serverAddr = NettyUtils.getRemoteAddress(channel); final long startTime = System.currentTimeMillis(); - logger.debug("Sending stream request for {} to {}", streamId, serverAddr); + if (logger.isDebugEnabled()) { + logger.debug("Sending stream request for {} to {}", streamId, getRemoteAddress(channel)); + } // Need to synchronize here so that the callback is added to the queue and the RPC is // written to the socket atomically, so that callbacks are called in the right order @@ -188,11 +192,13 @@ public void stream(final String streamId, final StreamCallback callback) { public void operationComplete(ChannelFuture future) throws Exception { if (future.isSuccess()) { long timeTaken = System.currentTimeMillis() - startTime; - logger.trace("Sending request for {} to {} took {} ms", streamId, serverAddr, - timeTaken); + if (logger.isTraceEnabled()) { + logger.trace("Sending request for {} to {} took {} ms", streamId, + getRemoteAddress(channel), timeTaken); + } } else { String errorMsg = String.format("Failed to send request for %s to %s: %s", streamId, - serverAddr, future.cause()); + getRemoteAddress(channel), future.cause()); logger.error(errorMsg, future.cause()); channel.close(); try { @@ -215,9 +221,10 @@ public void operationComplete(ChannelFuture future) throws Exception { * @return The RPC's id. 
*/ public long sendRpc(ByteBuffer message, final RpcResponseCallback callback) { - final String serverAddr = NettyUtils.getRemoteAddress(channel); final long startTime = System.currentTimeMillis(); - logger.trace("Sending RPC to {}", serverAddr); + if (logger.isTraceEnabled()) { + logger.trace("Sending RPC to {}", getRemoteAddress(channel)); + } final long requestId = Math.abs(UUID.randomUUID().getLeastSignificantBits()); handler.addRpcRequest(requestId, callback); @@ -228,10 +235,13 @@ public long sendRpc(ByteBuffer message, final RpcResponseCallback callback) { public void operationComplete(ChannelFuture future) throws Exception { if (future.isSuccess()) { long timeTaken = System.currentTimeMillis() - startTime; - logger.trace("Sending request {} to {} took {} ms", requestId, serverAddr, timeTaken); + if (logger.isTraceEnabled()) { + logger.trace("Sending request {} to {} took {} ms", requestId, + getRemoteAddress(channel), timeTaken); + } } else { String errorMsg = String.format("Failed to send RPC %s to %s: %s", requestId, - serverAddr, future.cause()); + getRemoteAddress(channel), future.cause()); logger.error(errorMsg, future.cause()); handler.removeRpcRequest(requestId); channel.close(); diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java index a27aaf2b277f7..1c9916baee07c 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java @@ -195,7 +195,7 @@ public TransportClient createUnmanagedClient(String remoteHost, int remotePort) /** Create a completely new {@link TransportClient} to the remote address. 
*/ private TransportClient createClient(InetSocketAddress address) throws IOException { - logger.debug("Creating new connection to " + address); + logger.debug("Creating new connection to {}", address); Bootstrap bootstrap = new Bootstrap(); bootstrap.group(workerGroup) diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index 8a69223c88ee4..179667296ec7d 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -38,7 +38,7 @@ import org.apache.spark.network.protocol.StreamFailure; import org.apache.spark.network.protocol.StreamResponse; import org.apache.spark.network.server.MessageHandler; -import org.apache.spark.network.util.NettyUtils; +import static org.apache.spark.network.util.NettyUtils.getRemoteAddress; import org.apache.spark.network.util.TransportFrameDecoder; /** @@ -122,7 +122,7 @@ public void channelActive() { @Override public void channelInactive() { if (numOutstandingRequests() > 0) { - String remoteAddress = NettyUtils.getRemoteAddress(channel); + String remoteAddress = getRemoteAddress(channel); logger.error("Still have {} requests outstanding when connection from {} is closed", numOutstandingRequests(), remoteAddress); failOutstandingRequests(new IOException("Connection from " + remoteAddress + " closed")); @@ -132,7 +132,7 @@ public void channelInactive() { @Override public void exceptionCaught(Throwable cause) { if (numOutstandingRequests() > 0) { - String remoteAddress = NettyUtils.getRemoteAddress(channel); + String remoteAddress = getRemoteAddress(channel); logger.error("Still have {} requests outstanding when connection from {} is closed", numOutstandingRequests(), remoteAddress); failOutstandingRequests(cause); @@ -141,13 +141,12 @@ public void exceptionCaught(Throwable cause) { @Override public void handle(ResponseMessage message) throws Exception { - String remoteAddress = NettyUtils.getRemoteAddress(channel); if (message instanceof ChunkFetchSuccess) { ChunkFetchSuccess resp = (ChunkFetchSuccess) message; ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId); if (listener == null) { logger.warn("Ignoring response for block {} from {} since it is not outstanding", - resp.streamChunkId, remoteAddress); + resp.streamChunkId, getRemoteAddress(channel)); resp.body().release(); } else { outstandingFetches.remove(resp.streamChunkId); @@ -159,7 +158,7 @@ public void handle(ResponseMessage message) throws Exception { ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId); if (listener == null) { logger.warn("Ignoring response for block {} from {} ({}) since it is not outstanding", - resp.streamChunkId, remoteAddress, resp.errorString); + resp.streamChunkId, getRemoteAddress(channel), resp.errorString); } else { outstandingFetches.remove(resp.streamChunkId); listener.onFailure(resp.streamChunkId.chunkIndex, new ChunkFetchFailureException( @@ -170,7 +169,7 @@ public void handle(ResponseMessage message) throws Exception { RpcResponseCallback listener = outstandingRpcs.get(resp.requestId); if (listener == null) { logger.warn("Ignoring response for RPC {} from {} ({} bytes) since it is not outstanding", - resp.requestId, remoteAddress, resp.body().size()); + resp.requestId, getRemoteAddress(channel), 
resp.body().size()); } else { outstandingRpcs.remove(resp.requestId); try { @@ -184,7 +183,7 @@ public void handle(ResponseMessage message) throws Exception { RpcResponseCallback listener = outstandingRpcs.get(resp.requestId); if (listener == null) { logger.warn("Ignoring response for RPC {} from {} ({}) since it is not outstanding", - resp.requestId, remoteAddress, resp.errorString); + resp.requestId, getRemoteAddress(channel), resp.errorString); } else { outstandingRpcs.remove(resp.requestId); listener.onFailure(new RuntimeException(resp.errorString)); diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java index 074780f2b95ce..f0453186185e1 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java @@ -39,7 +39,7 @@ public void decode(ChannelHandlerContext ctx, ByteBuf in, List out) { Message.Type msgType = Message.Type.decode(in); Message decoded = decode(msgType, in); assert decoded.type() == msgType; - logger.trace("Received message " + msgType + ": " + decoded); + logger.trace("Received message {}: {}", msgType, decoded); out.add(decoded); } diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java index f2223379a9d24..884ea7d1152a5 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java @@ -29,7 +29,7 @@ import org.apache.spark.network.protocol.Message; import org.apache.spark.network.protocol.RequestMessage; import org.apache.spark.network.protocol.ResponseMessage; -import org.apache.spark.network.util.NettyUtils; +import static org.apache.spark.network.util.NettyUtils.getRemoteAddress; /** * The single Transport-level Channel handler which is used for delegating requests to the @@ -76,7 +76,7 @@ public TransportClient getClient() { @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { - logger.warn("Exception in connection from " + NettyUtils.getRemoteAddress(ctx.channel()), + logger.warn("Exception in connection from " + getRemoteAddress(ctx.channel()), cause); requestHandler.exceptionCaught(cause); responseHandler.exceptionCaught(cause); @@ -139,7 +139,7 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc System.nanoTime() - responseHandler.getTimeOfLastRequestNs() > requestTimeoutNs; if (e.state() == IdleState.ALL_IDLE && isActuallyOverdue) { if (responseHandler.numOutstandingRequests() > 0) { - String address = NettyUtils.getRemoteAddress(ctx.channel()); + String address = getRemoteAddress(ctx.channel()); logger.error("Connection to {} has been quiet for {} ms while there are outstanding " + "requests. 
Assuming connection is dead; please adjust spark.network.timeout if " + "this is wrong.", address, requestTimeoutNs / 1000 / 1000); diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index bebe88ec5d503..143c29c6a57cf 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -17,6 +17,7 @@ package org.apache.spark.network.server; +import java.net.SocketAddress; import java.nio.ByteBuffer; import com.google.common.base.Throwables; @@ -42,7 +43,7 @@ import org.apache.spark.network.protocol.StreamFailure; import org.apache.spark.network.protocol.StreamRequest; import org.apache.spark.network.protocol.StreamResponse; -import org.apache.spark.network.util.NettyUtils; +import static org.apache.spark.network.util.NettyUtils.getRemoteAddress; /** * A handler that processes requests from clients and writes chunk data back. Each handler is @@ -114,9 +115,10 @@ public void handle(RequestMessage request) { } private void processFetchRequest(final ChunkFetchRequest req) { - final String client = NettyUtils.getRemoteAddress(channel); - - logger.trace("Received req from {} to fetch block {}", client, req.streamChunkId); + if (logger.isTraceEnabled()) { + logger.trace("Received req from {} to fetch block {}", getRemoteAddress(channel), + req.streamChunkId); + } ManagedBuffer buf; try { @@ -124,8 +126,8 @@ private void processFetchRequest(final ChunkFetchRequest req) { streamManager.registerChannel(channel, req.streamChunkId.streamId); buf = streamManager.getChunk(req.streamChunkId.streamId, req.streamChunkId.chunkIndex); } catch (Exception e) { - logger.error(String.format( - "Error opening block %s for request from %s", req.streamChunkId, client), e); + logger.error(String.format("Error opening block %s for request from %s", req.streamChunkId, + getRemoteAddress(channel)), e); respond(new ChunkFetchFailure(req.streamChunkId, Throwables.getStackTraceAsString(e))); return; } @@ -134,13 +136,12 @@ private void processFetchRequest(final ChunkFetchRequest req) { } private void processStreamRequest(final StreamRequest req) { - final String client = NettyUtils.getRemoteAddress(channel); ManagedBuffer buf; try { buf = streamManager.openStream(req.streamId); } catch (Exception e) { logger.error(String.format( - "Error opening stream %s for request from %s", req.streamId, client), e); + "Error opening stream %s for request from %s", req.streamId, getRemoteAddress(channel)), e); respond(new StreamFailure(req.streamId, Throwables.getStackTraceAsString(e))); return; } @@ -189,13 +190,13 @@ private void processOneWayMessage(OneWayMessage req) { * it will be logged and the channel closed. 
*/ private void respond(final Encodable result) { - final String remoteAddress = channel.remoteAddress().toString(); + final SocketAddress remoteAddress = channel.remoteAddress(); channel.writeAndFlush(result).addListener( new ChannelFutureListener() { @Override public void operationComplete(ChannelFuture future) throws Exception { if (future.isSuccess()) { - logger.trace(String.format("Sent result %s to client %s", result, remoteAddress)); + logger.trace("Sent result {} to client {}", result, remoteAddress); } else { logger.error(String.format("Error sending result %s to %s; closing connection", result, remoteAddress), future.cause()); diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java index baae235e02205..a67db4f69f086 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java @@ -130,7 +130,7 @@ protected void initChannel(SocketChannel ch) throws Exception { channelFuture.syncUninterruptibly(); port = ((InetSocketAddress) channelFuture.channel().localAddress()).getPort(); - logger.debug("Shuffle server started on port :" + port); + logger.debug("Shuffle server started on port: {}", port); } @Override diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/LevelDBProvider.java b/common/network-common/src/main/java/org/apache/spark/network/util/LevelDBProvider.java new file mode 100644 index 0000000000000..f96d068cf3d59 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/util/LevelDBProvider.java @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +import java.io.File; +import java.io.IOException; +import java.nio.charset.StandardCharsets; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.fusesource.leveldbjni.JniDBFactory; +import org.fusesource.leveldbjni.internal.NativeDB; +import org.iq80.leveldb.DB; +import org.iq80.leveldb.Options; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * LevelDB utility class available in the network package. 
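The `respond()` change above combines two small wins: the `SocketAddress` is kept as an object rather than eagerly calling `toString()`, and `String.format` inside `logger.trace` is replaced by placeholder formatting, so no message string is built unless TRACE is enabled. A compilable sketch of the same idea (the class name and message are illustrative):

```java
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

class DeferredFormattingSketch {
  private static final Logger logger = LoggerFactory.getLogger(DeferredFormattingSketch.class);

  void onWriteSuccess(Object result, SocketAddress remoteAddress) {
    // String.format would always build the message, even with TRACE disabled:
    //   logger.trace(String.format("Sent result %s to client %s", result, remoteAddress));
    // Placeholders defer both the formatting and remoteAddress.toString() to SLF4J:
    logger.trace("Sent result {} to client {}", result, remoteAddress);
  }

  public static void main(String[] args) {
    new DeferredFormattingSketch()
        .onWriteSuccess("OK", new InetSocketAddress("localhost", 7337));
  }
}
```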
+ */ +public class LevelDBProvider { + private static final Logger logger = LoggerFactory.getLogger(LevelDBProvider.class); + + public static DB initLevelDB(File dbFile, StoreVersion version, ObjectMapper mapper) throws + IOException { + DB tmpDb = null; + if (dbFile != null) { + Options options = new Options(); + options.createIfMissing(false); + options.logger(new LevelDBLogger()); + try { + tmpDb = JniDBFactory.factory.open(dbFile, options); + } catch (NativeDB.DBException e) { + if (e.isNotFound() || e.getMessage().contains(" does not exist ")) { + logger.info("Creating state database at " + dbFile); + options.createIfMissing(true); + try { + tmpDb = JniDBFactory.factory.open(dbFile, options); + } catch (NativeDB.DBException dbExc) { + throw new IOException("Unable to create state store", dbExc); + } + } else { + // the leveldb file seems to be corrupt somehow. Lets just blow it away and create a new + // one, so we can keep processing new apps + logger.error("error opening leveldb file {}. Creating new file, will not be able to " + + "recover state for existing applications", dbFile, e); + if (dbFile.isDirectory()) { + for (File f : dbFile.listFiles()) { + if (!f.delete()) { + logger.warn("error deleting {}", f.getPath()); + } + } + } + if (!dbFile.delete()) { + logger.warn("error deleting {}", dbFile.getPath()); + } + options.createIfMissing(true); + try { + tmpDb = JniDBFactory.factory.open(dbFile, options); + } catch (NativeDB.DBException dbExc) { + throw new IOException("Unable to create state store", dbExc); + } + + } + } + // if there is a version mismatch, we throw an exception, which means the service is unusable + checkVersion(tmpDb, version, mapper); + } + return tmpDb; + } + + private static class LevelDBLogger implements org.iq80.leveldb.Logger { + private static final Logger LOG = LoggerFactory.getLogger(LevelDBLogger.class); + + @Override + public void log(String message) { + LOG.info(message); + } + } + + /** + * Simple major.minor versioning scheme. Any incompatible changes should be across major + * versions. Minor version differences are allowed -- meaning we should be able to read + * dbs that are either earlier *or* later on the minor version. 
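For context, a caller hands `initLevelDB` a file location, the current `StoreVersion`, and a Jackson `ObjectMapper`; the helper opens the existing database, creates it if missing, recreates it if corrupt, and fails if the stored major version is incompatible. A minimal sketch of such a caller (the wrapper class and the key/value written are illustrative, not from the patch):

```java
import java.io.File;
import java.io.IOException;
import java.nio.charset.StandardCharsets;

import com.fasterxml.jackson.databind.ObjectMapper;
import org.iq80.leveldb.DB;

import org.apache.spark.network.util.LevelDBProvider;
import org.apache.spark.network.util.LevelDBProvider.StoreVersion;

class RecoveryStoreExample {
  // Hypothetical caller; ExternalShuffleBlockResolver and YarnShuffleService follow the same shape.
  static DB openRecoveryStore(File dbFile) throws IOException {
    ObjectMapper mapper = new ObjectMapper();
    // Opens (or creates/recreates) the LevelDB file and verifies the stored major version.
    DB db = LevelDBProvider.initLevelDB(dbFile, new StoreVersion(1, 0), mapper);
    if (db != null) {
      db.put("example-key".getBytes(StandardCharsets.UTF_8),
             "example-value".getBytes(StandardCharsets.UTF_8));
    }
    return db;
  }
}
```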
+ */ + public static void checkVersion(DB db, StoreVersion newversion, ObjectMapper mapper) throws + IOException { + byte[] bytes = db.get(StoreVersion.KEY); + if (bytes == null) { + storeVersion(db, newversion, mapper); + } else { + StoreVersion version = mapper.readValue(bytes, StoreVersion.class); + if (version.major != newversion.major) { + throw new IOException("cannot read state DB with version " + version + ", incompatible " + + "with current version " + newversion); + } + storeVersion(db, newversion, mapper); + } + } + + public static void storeVersion(DB db, StoreVersion version, ObjectMapper mapper) + throws IOException { + db.put(StoreVersion.KEY, mapper.writeValueAsBytes(version)); + } + + public static class StoreVersion { + + static final byte[] KEY = "StoreVersion".getBytes(StandardCharsets.UTF_8); + + public final int major; + public final int minor; + + @JsonCreator + public StoreVersion(@JsonProperty("major") int major, @JsonProperty("minor") int minor) { + this.major = major; + this.minor = minor; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + StoreVersion that = (StoreVersion) o; + + return major == that.major && minor == that.minor; + } + + @Override + public int hashCode() { + int result = major; + result = 31 * result + minor; + return result; + } + } +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java b/common/network-common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java index 922c37a10efdd..e79eef0325897 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java @@ -48,11 +48,27 @@ * use this functionality in both a Guava 11 environment and a Guava >14 environment. */ public final class LimitedInputStream extends FilterInputStream { + private final boolean closeWrappedStream; private long left; private long mark = -1; public LimitedInputStream(InputStream in, long limit) { + this(in, limit, true); + } + + /** + * Create a LimitedInputStream that will read {@code limit} bytes from {@code in}. + *
<p>
+ * If {@code closeWrappedStream} is true, this will close {@code in} when it is closed. + * Otherwise, the stream is left open for reading its remaining content. + * + * @param in a {@link InputStream} to read from + * @param limit the number of bytes to read + * @param closeWrappedStream whether to close {@code in} when {@link #close} is called + */ + public LimitedInputStream(InputStream in, long limit, boolean closeWrappedStream) { super(in); + this.closeWrappedStream = closeWrappedStream; Preconditions.checkNotNull(in); Preconditions.checkArgument(limit >= 0, "limit must be non-negative"); left = limit; @@ -102,4 +118,11 @@ public LimitedInputStream(InputStream in, long limit) { left -= skipped; return skipped; } + + @Override + public void close() throws IOException { + if (closeWrappedStream) { + super.close(); + } + } } diff --git a/common/network-shuffle/pom.xml b/common/network-shuffle/pom.xml index 810ec10ca05b3..b20f9e2aab2c3 100644 --- a/common/network-shuffle/pom.xml +++ b/common/network-shuffle/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.0.0-SNAPSHOT + 2.0.3-SNAPSHOT ../../pom.xml @@ -43,22 +43,6 @@ ${project.version} - - org.fusesource.leveldbjni - leveldbjni-all - 1.8 - - - - com.fasterxml.jackson.core - jackson-databind - - - - com.fasterxml.jackson.core - jackson-annotations - - org.slf4j @@ -80,7 +64,7 @@ org.apache.spark - spark-test-tags_${scala.binary.version} + spark-tags_${scala.binary.version} log4j diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java index f8d03b3b9433a..22fd592a321d2 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java @@ -35,6 +35,7 @@ import org.apache.spark.network.server.StreamManager; import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver.AppExecId; import org.apache.spark.network.shuffle.protocol.*; +import org.apache.spark.network.util.NettyUtils; import org.apache.spark.network.util.TransportConf; @@ -86,7 +87,11 @@ protected void handleMessage( blocks.add(blockManager.getBlockData(msg.appId, msg.execId, blockId)); } long streamId = streamManager.registerStream(client.getClientId(), blocks.iterator()); - logger.trace("Registered streamId {} with {} buffers", streamId, msg.blockIds.length); + logger.trace("Registered streamId {} with {} buffers for client {} from host {}", + streamId, + msg.blockIds.length, + client.getClientId(), + NettyUtils.getRemoteAddress(client.getChannel())); callback.onSuccess(new StreamHandle(streamId, msg.blockIds.length).toByteBuffer()); } else if (msgObj instanceof RegisterExecutor) { diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index 54e870a9b56a6..e34dc1f5b1de5 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -30,17 +30,16 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Objects; import com.google.common.collect.Maps; -import 
org.fusesource.leveldbjni.JniDBFactory; -import org.fusesource.leveldbjni.internal.NativeDB; import org.iq80.leveldb.DB; import org.iq80.leveldb.DBIterator; -import org.iq80.leveldb.Options; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.network.buffer.FileSegmentManagedBuffer; import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; +import org.apache.spark.network.util.LevelDBProvider; +import org.apache.spark.network.util.LevelDBProvider.StoreVersion; import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.NettyUtils; import org.apache.spark.network.util.TransportConf; @@ -95,52 +94,10 @@ public ExternalShuffleBlockResolver(TransportConf conf, File registeredExecutorF Executor directoryCleaner) throws IOException { this.conf = conf; this.registeredExecutorFile = registeredExecutorFile; - if (registeredExecutorFile != null) { - Options options = new Options(); - options.createIfMissing(false); - options.logger(new LevelDBLogger()); - DB tmpDb; - try { - tmpDb = JniDBFactory.factory.open(registeredExecutorFile, options); - } catch (NativeDB.DBException e) { - if (e.isNotFound() || e.getMessage().contains(" does not exist ")) { - logger.info("Creating state database at " + registeredExecutorFile); - options.createIfMissing(true); - try { - tmpDb = JniDBFactory.factory.open(registeredExecutorFile, options); - } catch (NativeDB.DBException dbExc) { - throw new IOException("Unable to create state store", dbExc); - } - } else { - // the leveldb file seems to be corrupt somehow. Lets just blow it away and create a new - // one, so we can keep processing new apps - logger.error("error opening leveldb file {}. Creating new file, will not be able to " + - "recover state for existing applications", registeredExecutorFile, e); - if (registeredExecutorFile.isDirectory()) { - for (File f : registeredExecutorFile.listFiles()) { - if (!f.delete()) { - logger.warn("error deleting {}", f.getPath()); - } - } - } - if (!registeredExecutorFile.delete()) { - logger.warn("error deleting {}", registeredExecutorFile.getPath()); - } - options.createIfMissing(true); - try { - tmpDb = JniDBFactory.factory.open(registeredExecutorFile, options); - } catch (NativeDB.DBException dbExc) { - throw new IOException("Unable to create state store", dbExc); - } - - } - } - // if there is a version mismatch, we throw an exception, which means the service is unusable - checkVersion(tmpDb); - executors = reloadRegisteredExecutors(tmpDb); - db = tmpDb; + db = LevelDBProvider.initLevelDB(this.registeredExecutorFile, CURRENT_VERSION, mapper); + if (db != null) { + executors = reloadRegisteredExecutors(db); } else { - db = null; executors = Maps.newConcurrentMap(); } this.directoryCleaner = directoryCleaner; @@ -244,7 +201,7 @@ private void deleteExecutorDirs(String[] dirs) { for (String localDir : dirs) { try { JavaUtils.deleteRecursively(new File(localDir)); - logger.debug("Successfully cleaned up directory: " + localDir); + logger.debug("Successfully cleaned up directory: {}", localDir); } catch (Exception e) { logger.error("Failed to delete directory: " + localDir, e); } @@ -368,76 +325,11 @@ static ConcurrentMap reloadRegisteredExecutors(D break; } AppExecId id = parseDbAppExecKey(key); + logger.info("Reloading registered executors: " + id.toString()); ExecutorShuffleInfo shuffleInfo = mapper.readValue(e.getValue(), ExecutorShuffleInfo.class); registeredExecutors.put(id, shuffleInfo); } } return 
registeredExecutors; } - - private static class LevelDBLogger implements org.iq80.leveldb.Logger { - private static final Logger LOG = LoggerFactory.getLogger(LevelDBLogger.class); - - @Override - public void log(String message) { - LOG.info(message); - } - } - - /** - * Simple major.minor versioning scheme. Any incompatible changes should be across major - * versions. Minor version differences are allowed -- meaning we should be able to read - * dbs that are either earlier *or* later on the minor version. - */ - private static void checkVersion(DB db) throws IOException { - byte[] bytes = db.get(StoreVersion.KEY); - if (bytes == null) { - storeVersion(db); - } else { - StoreVersion version = mapper.readValue(bytes, StoreVersion.class); - if (version.major != CURRENT_VERSION.major) { - throw new IOException("cannot read state DB with version " + version + ", incompatible " + - "with current version " + CURRENT_VERSION); - } - storeVersion(db); - } - } - - private static void storeVersion(DB db) throws IOException { - db.put(StoreVersion.KEY, mapper.writeValueAsBytes(CURRENT_VERSION)); - } - - - public static class StoreVersion { - - static final byte[] KEY = "StoreVersion".getBytes(StandardCharsets.UTF_8); - - public final int major; - public final int minor; - - @JsonCreator public StoreVersion( - @JsonProperty("major") int major, - @JsonProperty("minor") int minor) { - this.major = major; - this.minor = minor; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - - StoreVersion that = (StoreVersion) o; - - return major == that.major && minor == that.minor; - } - - @Override - public int hashCode() { - int result = major; - result = 31 * result + minor; - return result; - } - } - } diff --git a/common/network-yarn/pom.xml b/common/network-yarn/pom.xml index bc83ef24c30ec..06895c67560a6 100644 --- a/common/network-yarn/pom.xml +++ b/common/network-yarn/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.0.0-SNAPSHOT + 2.0.3-SNAPSHOT ../../pom.xml @@ -48,7 +48,7 @@ org.apache.spark - spark-test-tags_${scala.binary.version} + spark-tags_${scala.binary.version} @@ -96,6 +96,13 @@ com.fasterxml.jackson.** + + io.netty + ${spark.shade.packageName}.io.netty + + io.netty.** + + diff --git a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java index 4bc3c1a3c8a64..7bdf73539671c 100644 --- a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java +++ b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java @@ -18,15 +18,28 @@ package org.apache.spark.network.yarn; import java.io.File; +import java.io.IOException; +import java.nio.charset.StandardCharsets; import java.nio.ByteBuffer; import java.util.List; +import java.util.Map; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Objects; import com.google.common.collect.Lists; import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; +import org.apache.hadoop.fs.permission.FsPermission; import org.apache.hadoop.yarn.api.records.ContainerId; import org.apache.hadoop.yarn.server.api.*; +import 
org.apache.spark.network.util.LevelDBProvider; +import org.iq80.leveldb.DB; +import org.iq80.leveldb.DBIterator; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -68,6 +81,22 @@ public class YarnShuffleService extends AuxiliaryService { private static final String SPARK_AUTHENTICATE_KEY = "spark.authenticate"; private static final boolean DEFAULT_SPARK_AUTHENTICATE = false; + private static final String RECOVERY_FILE_NAME = "registeredExecutors.ldb"; + private static final String SECRETS_RECOVERY_FILE_NAME = "sparkShuffleRecovery.ldb"; + + // just for testing when you want to find an open port + @VisibleForTesting + static int boundPort = -1; + private static final ObjectMapper mapper = new ObjectMapper(); + private static final String APP_CREDS_KEY_PREFIX = "AppCreds"; + private static final LevelDBProvider.StoreVersion CURRENT_VERSION = new LevelDBProvider + .StoreVersion(1, 0); + + // just for integration tests that want to look at this file -- in general not sensible as + // a static + @VisibleForTesting + static YarnShuffleService instance; + // An entity that manages the shuffle secret per application // This is used only if authentication is enabled private ShuffleSecretManager secretManager; @@ -75,6 +104,8 @@ public class YarnShuffleService extends AuxiliaryService { // The actual server that serves shuffle files private TransportServer shuffleServer = null; + private Configuration _conf = null; + // Handles registering executors and opening shuffle blocks @VisibleForTesting ExternalShuffleBlockHandler blockHandler; @@ -83,14 +114,11 @@ public class YarnShuffleService extends AuxiliaryService { @VisibleForTesting File registeredExecutorFile; - // just for testing when you want to find an open port + // Where to store & reload application secrets for recovering state after an NM restart @VisibleForTesting - static int boundPort = -1; + File secretsFile; - // just for integration tests that want to look at this file -- in general not sensible as - // a static - @VisibleForTesting - static YarnShuffleService instance; + private DB db; public YarnShuffleService() { super("spark_shuffle"); @@ -112,42 +140,86 @@ private boolean isAuthenticationEnabled() { */ @Override protected void serviceInit(Configuration conf) { + _conf = conf; - // In case this NM was killed while there were running spark applications, we need to restore - // lost state for the existing executors. We look for an existing file in the NM's local dirs. - // If we don't find one, then we choose a file to use to save the state next time. Even if - // an application was stopped while the NM was down, we expect yarn to call stopApplication() - // when it comes back - registeredExecutorFile = - findRegisteredExecutorFile(conf.getTrimmedStrings("yarn.nodemanager.local-dirs")); - - TransportConf transportConf = new TransportConf("shuffle", new HadoopConfigProvider(conf)); - // If authentication is enabled, set up the shuffle server to use a - // special RPC handler that filters out unauthenticated fetch requests - boolean authEnabled = conf.getBoolean(SPARK_AUTHENTICATE_KEY, DEFAULT_SPARK_AUTHENTICATE); try { - blockHandler = new ExternalShuffleBlockHandler(transportConf, registeredExecutorFile); + // In case this NM was killed while there were running spark applications, we need to restore + // lost state for the existing executors. We look for an existing file in the NM's local dirs. + // If we don't find one, then we choose a file to use to save the state next time. 
Even if + // an application was stopped while the NM was down, we expect yarn to call stopApplication() + // when it comes back + registeredExecutorFile = findRecoveryDb(RECOVERY_FILE_NAME); + + TransportConf transportConf = new TransportConf("shuffle", new HadoopConfigProvider(conf)); + blockHandler = new ExternalShuffleBlockHandler(transportConf, registeredExecutorFile); + + // If authentication is enabled, set up the shuffle server to use a + // special RPC handler that filters out unauthenticated fetch requests + boolean authEnabled = conf.getBoolean(SPARK_AUTHENTICATE_KEY, DEFAULT_SPARK_AUTHENTICATE); + List bootstraps = Lists.newArrayList(); + if (authEnabled) { + createSecretManager(); + bootstraps.add(new SaslServerBootstrap(transportConf, secretManager)); + } + + int port = conf.getInt( + SPARK_SHUFFLE_SERVICE_PORT_KEY, DEFAULT_SPARK_SHUFFLE_SERVICE_PORT); + TransportContext transportContext = new TransportContext(transportConf, blockHandler); + shuffleServer = transportContext.createServer(port, bootstraps); + // the port should normally be fixed, but for tests its useful to find an open port + port = shuffleServer.getPort(); + boundPort = port; + String authEnabledString = authEnabled ? "enabled" : "not enabled"; + logger.info("Started YARN shuffle service for Spark on port {}. " + + "Authentication is {}. Registered executor file is {}", port, authEnabledString, + registeredExecutorFile); } catch (Exception e) { logger.error("Failed to initialize external shuffle service", e); } + } - List bootstraps = Lists.newArrayList(); - if (authEnabled) { - secretManager = new ShuffleSecretManager(); - bootstraps.add(new SaslServerBootstrap(transportConf, secretManager)); + private void createSecretManager() throws IOException { + secretManager = new ShuffleSecretManager(); + secretsFile = findRecoveryDb(SECRETS_RECOVERY_FILE_NAME); + + // Make sure this is protected in case its not in the NM recovery dir + FileSystem fs = FileSystem.getLocal(_conf); + fs.mkdirs(new Path(secretsFile.getPath()), new FsPermission((short)0700)); + + db = LevelDBProvider.initLevelDB(secretsFile, CURRENT_VERSION, mapper); + logger.info("Recovery location is: " + secretsFile.getPath()); + if (db != null) { + logger.info("Going to reload spark shuffle data"); + DBIterator itr = db.iterator(); + itr.seek(APP_CREDS_KEY_PREFIX.getBytes(StandardCharsets.UTF_8)); + while (itr.hasNext()) { + Map.Entry e = itr.next(); + String key = new String(e.getKey(), StandardCharsets.UTF_8); + if (!key.startsWith(APP_CREDS_KEY_PREFIX)) { + break; + } + String id = parseDbAppKey(key); + ByteBuffer secret = mapper.readValue(e.getValue(), ByteBuffer.class); + logger.info("Reloading tokens for app: " + id); + secretManager.registerApp(id, secret); + } } + } - int port = conf.getInt( - SPARK_SHUFFLE_SERVICE_PORT_KEY, DEFAULT_SPARK_SHUFFLE_SERVICE_PORT); - TransportContext transportContext = new TransportContext(transportConf, blockHandler); - shuffleServer = transportContext.createServer(port, bootstraps); - // the port should normally be fixed, but for tests its useful to find an open port - port = shuffleServer.getPort(); - boundPort = port; - String authEnabledString = authEnabled ? "enabled" : "not enabled"; - logger.info("Started YARN shuffle service for Spark on port {}. " + - "Authentication is {}. 
Registered executor file is {}", port, authEnabledString, - registeredExecutorFile); + private static String parseDbAppKey(String s) throws IOException { + if (!s.startsWith(APP_CREDS_KEY_PREFIX)) { + throw new IllegalArgumentException("expected a string starting with " + APP_CREDS_KEY_PREFIX); + } + String json = s.substring(APP_CREDS_KEY_PREFIX.length() + 1); + AppId parsed = mapper.readValue(json, AppId.class); + return parsed.appId; + } + + private static byte[] dbAppKey(AppId appExecId) throws IOException { + // we stick a common prefix on all the keys so we can find them in the DB + String appExecJson = mapper.writeValueAsString(appExecId); + String key = (APP_CREDS_KEY_PREFIX + ";" + appExecJson); + return key.getBytes(StandardCharsets.UTF_8); } @Override @@ -157,6 +229,12 @@ public void initializeApplication(ApplicationInitializationContext context) { ByteBuffer shuffleSecret = context.getApplicationDataForService(); logger.info("Initializing application {}", appId); if (isAuthenticationEnabled()) { + AppId fullId = new AppId(appId); + if (db != null) { + byte[] key = dbAppKey(fullId); + byte[] value = mapper.writeValueAsString(shuffleSecret).getBytes(StandardCharsets.UTF_8); + db.put(key, value); + } secretManager.registerApp(appId, shuffleSecret); } } catch (Exception e) { @@ -170,6 +248,14 @@ public void stopApplication(ApplicationTerminationContext context) { try { logger.info("Stopping application {}", appId); if (isAuthenticationEnabled()) { + AppId fullId = new AppId(appId); + if (db != null) { + try { + db.delete(dbAppKey(fullId)); + } catch (IOException e) { + logger.error("Error deleting {} from executor state db", appId, e); + } + } secretManager.unregisterApp(appId); } blockHandler.applicationRemoved(appId, false /* clean up local dirs */); @@ -190,14 +276,15 @@ public void stopContainer(ContainerTerminationContext context) { logger.info("Stopping container {}", containerId); } - private File findRegisteredExecutorFile(String[] localDirs) { + private File findRecoveryDb(String fileName) { + String[] localDirs = _conf.getTrimmedStrings("yarn.nodemanager.local-dirs"); for (String dir: localDirs) { - File f = new File(new Path(dir).toUri().getPath(), "registeredExecutors.ldb"); + File f = new File(new Path(dir).toUri().getPath(), fileName); if (f.exists()) { return f; } } - return new File(new Path(localDirs[0]).toUri().getPath(), "registeredExecutors.ldb"); + return new File(new Path(localDirs[0]).toUri().getPath(), fileName); } /** @@ -212,6 +299,9 @@ protected void serviceStop() { if (blockHandler != null) { blockHandler.close(); } + if (db != null) { + db.close(); + } } catch (Exception e) { logger.error("Exception when stopping service", e); } @@ -222,4 +312,38 @@ protected void serviceStop() { public ByteBuffer getMetaData() { return ByteBuffer.allocate(0); } + + /** + * Simply encodes an application ID. 
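The secret-recovery keys added above are stored as the `AppCreds` prefix, a `;` separator, and a JSON payload, which is what `dbAppKey` builds and `parseDbAppKey` unpacks. A standalone sketch of that key scheme, mirroring (not calling) the private helpers, and using a plain string where the service uses its `AppId` wrapper; the class and method names here are illustrative:

```java
import java.nio.charset.StandardCharsets;
import com.fasterxml.jackson.databind.ObjectMapper;

class AppSecretKeySketch {
  private static final String APP_CREDS_KEY_PREFIX = "AppCreds";
  private static final ObjectMapper mapper = new ObjectMapper();

  // Mirrors dbAppKey: prefix + ";" + JSON of the application id, UTF-8 encoded.
  static byte[] dbKeyFor(String appId) throws Exception {
    String json = mapper.writeValueAsString(appId);
    return (APP_CREDS_KEY_PREFIX + ";" + json).getBytes(StandardCharsets.UTF_8);
  }

  // Mirrors parseDbAppKey: strip the prefix and separator, then decode the JSON payload.
  static String appIdFrom(byte[] key) throws Exception {
    String s = new String(key, StandardCharsets.UTF_8);
    if (!s.startsWith(APP_CREDS_KEY_PREFIX)) {
      throw new IllegalArgumentException("expected a key starting with " + APP_CREDS_KEY_PREFIX);
    }
    return mapper.readValue(s.substring(APP_CREDS_KEY_PREFIX.length() + 1), String.class);
  }

  public static void main(String[] args) throws Exception {
    byte[] key = dbKeyFor("application_1234_0001");
    System.out.println(appIdFrom(key)); // application_1234_0001
  }
}
```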
+ */ + public static class AppId { + public final String appId; + + @JsonCreator + public AppId(@JsonProperty("appId") String appId) { + this.appId = appId; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + AppId appExecId = (AppId) o; + return Objects.equal(appId, appExecId.appId); + } + + @Override + public int hashCode() { + return Objects.hashCode(appId); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("appId", appId) + .toString(); + } + } + } diff --git a/common/sketch/pom.xml b/common/sketch/pom.xml index 8bc1f52798941..d42be3a792778 100644 --- a/common/sketch/pom.xml +++ b/common/sketch/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.0.0-SNAPSHOT + 2.0.3-SNAPSHOT ../../pom.xml @@ -38,7 +38,7 @@ org.apache.spark - spark-test-tags_${scala.binary.version} + spark-tags_${scala.binary.version} diff --git a/common/tags/pom.xml b/common/tags/pom.xml index 8e702b4fefe8c..43531a256de9a 100644 --- a/common/tags/pom.xml +++ b/common/tags/pom.xml @@ -22,17 +22,17 @@ org.apache.spark spark-parent_2.11 - 2.0.0-SNAPSHOT + 2.0.3-SNAPSHOT ../../pom.xml org.apache.spark - spark-test-tags_2.11 + spark-tags_2.11 jar - Spark Project Test Tags + Spark Project Tags http://spark.apache.org/ - test-tags + tags diff --git a/core/src/main/java/org/apache/spark/annotation/AlphaComponent.java b/common/tags/src/main/java/org/apache/spark/annotation/AlphaComponent.java similarity index 100% rename from core/src/main/java/org/apache/spark/annotation/AlphaComponent.java rename to common/tags/src/main/java/org/apache/spark/annotation/AlphaComponent.java diff --git a/core/src/main/java/org/apache/spark/annotation/DeveloperApi.java b/common/tags/src/main/java/org/apache/spark/annotation/DeveloperApi.java similarity index 100% rename from core/src/main/java/org/apache/spark/annotation/DeveloperApi.java rename to common/tags/src/main/java/org/apache/spark/annotation/DeveloperApi.java diff --git a/core/src/main/java/org/apache/spark/annotation/Experimental.java b/common/tags/src/main/java/org/apache/spark/annotation/Experimental.java similarity index 100% rename from core/src/main/java/org/apache/spark/annotation/Experimental.java rename to common/tags/src/main/java/org/apache/spark/annotation/Experimental.java diff --git a/core/src/main/java/org/apache/spark/annotation/Private.java b/common/tags/src/main/java/org/apache/spark/annotation/Private.java similarity index 100% rename from core/src/main/java/org/apache/spark/annotation/Private.java rename to common/tags/src/main/java/org/apache/spark/annotation/Private.java diff --git a/core/src/main/scala/org/apache/spark/annotation/Since.scala b/common/tags/src/main/scala/org/apache/spark/annotation/Since.scala similarity index 100% rename from core/src/main/scala/org/apache/spark/annotation/Since.scala rename to common/tags/src/main/scala/org/apache/spark/annotation/Since.scala diff --git a/core/src/main/scala/org/apache/spark/annotation/package-info.java b/common/tags/src/main/scala/org/apache/spark/annotation/package-info.java similarity index 100% rename from core/src/main/scala/org/apache/spark/annotation/package-info.java rename to common/tags/src/main/scala/org/apache/spark/annotation/package-info.java diff --git a/core/src/main/scala/org/apache/spark/annotation/package.scala b/common/tags/src/main/scala/org/apache/spark/annotation/package.scala similarity index 100% rename from 
core/src/main/scala/org/apache/spark/annotation/package.scala rename to common/tags/src/main/scala/org/apache/spark/annotation/package.scala diff --git a/common/unsafe/pom.xml b/common/unsafe/pom.xml index 93b9580f26b86..61f4a38531e68 100644 --- a/common/unsafe/pom.xml +++ b/common/unsafe/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.0.0-SNAPSHOT + 2.0.3-SNAPSHOT ../../pom.xml @@ -36,6 +36,10 @@ + + org.apache.spark + spark-tags_${scala.binary.version} + com.twitter chill_${scala.binary.version} @@ -59,10 +63,6 @@ - - org.apache.spark - spark-test-tags_${scala.binary.version} - org.mockito mockito-core diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java index bdf52f32c6fe1..77c8c398be955 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java @@ -155,8 +155,8 @@ public static long reallocateMemory(long address, long oldSize, long newSize) { @SuppressWarnings("unchecked") public static ByteBuffer allocateDirectBuffer(int size) { try { - Class cls = Class.forName("java.nio.DirectByteBuffer"); - Constructor constructor = cls.getDeclaredConstructor(Long.TYPE, Integer.TYPE); + Class cls = Class.forName("java.nio.DirectByteBuffer"); + Constructor constructor = cls.getDeclaredConstructor(Long.TYPE, Integer.TYPE); constructor.setAccessible(true); Field cleanerField = cls.getDeclaredField("cleaner"); cleanerField.setAccessible(true); diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java index e3e79471154df..1bc924d424c02 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java @@ -51,6 +51,6 @@ public long size() { * Creates a memory block pointing to the memory used by the long array. 
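The `fromLongArray` fix just below is a classic integer-overflow repair: `array.length * 8` is evaluated in 32-bit int arithmetic and wraps once the array exceeds roughly 268 million elements, while the `8L` literal forces the multiplication into long. A small arithmetic illustration (the length is hypothetical and nothing is actually allocated):

```java
class IntOverflowSketch {
  public static void main(String[] args) {
    int length = 300_000_000;   // 300 million longs, about 2.4 GB of data
    long wrong = length * 8;    // int multiplication overflows, then widens: -1894967296
    long right = length * 8L;   // promoted to long before multiplying: 2400000000
    System.out.println(wrong + " vs " + right);
  }
}
```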
*/ public static MemoryBlock fromLongArray(final long[] array) { - return new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, array.length * 8); + return new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, array.length * 8L); } } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 54a54569240c0..ce8351694308e 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -465,9 +465,9 @@ public UTF8String trim() { int s = 0; int e = this.numBytes - 1; // skip all of the space (0x20) in the left side - while (s < this.numBytes && getByte(s) <= 0x20 && getByte(s) >= 0x00) s++; + while (s < this.numBytes && getByte(s) == 0x20) s++; // skip all of the space (0x20) in the right side - while (e >= 0 && getByte(e) <= 0x20 && getByte(e) >= 0x00) e--; + while (e >= 0 && getByte(e) == 0x20) e--; if (s > e) { // empty string return UTF8String.fromBytes(new byte[0]); @@ -479,7 +479,7 @@ public UTF8String trim() { public UTF8String trimLeft() { int s = 0; // skip all of the space (0x20) in the left side - while (s < this.numBytes && getByte(s) <= 0x20 && getByte(s) >= 0x00) s++; + while (s < this.numBytes && getByte(s) == 0x20) s++; if (s == this.numBytes) { // empty string return UTF8String.fromBytes(new byte[0]); @@ -491,7 +491,7 @@ public UTF8String trimLeft() { public UTF8String trimRight() { int e = numBytes - 1; // skip all of the space (0x20) in the right side - while (e >= 0 && getByte(e) <= 0x20 && getByte(e) >= 0x00) e--; + while (e >= 0 && getByte(e) == 0x20) e--; if (e < 0) { // empty string diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index d4160ad029eb3..7f03686dcec41 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -232,6 +232,16 @@ public void trims() { assertEquals(fromString("数据砖头"), fromString("数据砖头").trim()); assertEquals(fromString("数据砖头"), fromString("数据砖头").trimLeft()); assertEquals(fromString("数据砖头"), fromString("数据砖头").trimRight()); + + char[] charsLessThan0x20 = new char[10]; + Arrays.fill(charsLessThan0x20, (char)(' ' - 1)); + String stringStartingWithSpace = + new String(charsLessThan0x20) + "hello" + new String(charsLessThan0x20); + assertEquals(fromString(stringStartingWithSpace), fromString(stringStartingWithSpace).trim()); + assertEquals(fromString(stringStartingWithSpace), + fromString(stringStartingWithSpace).trimLeft()); + assertEquals(fromString(stringStartingWithSpace), + fromString(stringStartingWithSpace).trimRight()); } @Test diff --git a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala index 8a6b9e3e4536d..62d4176d00f94 100644 --- a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala +++ b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala @@ -98,7 +98,7 @@ class UTF8StringPropertyCheckSuite extends FunSuite with GeneratorDrivenProperty } } - val whitespaceChar: Gen[Char] = Gen.choose(0x00, 0x20).map(_.toChar) + val whitespaceChar: Gen[Char] = Gen.const(0x20.toChar) 
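After this change only the ASCII space character (0x20) is treated as trimmable, matching the new unit test that pads a string with `(char)(' ' - 1)`. A short sketch of the observable difference (the class name is illustrative):

```java
import org.apache.spark.unsafe.types.UTF8String;

class TrimBehaviorSketch {
  public static void main(String[] args) {
    // Plain ASCII spaces (0x20) are still removed:
    System.out.println(UTF8String.fromString("  hello  ").trim());          // "hello"
    // Control characters below 0x20 (here, 0x1F) are now preserved by trim():
    String padded = "\u001Fhello\u001F";
    System.out.println(UTF8String.fromString(padded).trim()
        .equals(UTF8String.fromString(padded)));                            // true
  }
}
```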
val whitespaceString: Gen[String] = Gen.listOf(whitespaceChar).map(_.mkString) val randomString: Gen[String] = Arbitrary.arbString.arbitrary @@ -107,7 +107,7 @@ class UTF8StringPropertyCheckSuite extends FunSuite with GeneratorDrivenProperty def lTrim(s: String): String = { var st = 0 val array: Array[Char] = s.toCharArray - while ((st < s.length) && (array(st) <= ' ')) { + while ((st < s.length) && (array(st) == ' ')) { st += 1 } if (st > 0) s.substring(st, s.length) else s @@ -115,7 +115,7 @@ class UTF8StringPropertyCheckSuite extends FunSuite with GeneratorDrivenProperty def rTrim(s: String): String = { var len = s.length val array: Array[Char] = s.toCharArray - while ((len > 0) && (array(len - 1) <= ' ')) { + while ((len > 0) && (array(len - 1) == ' ')) { len -= 1 } if (len < s.length) s.substring(0, len) else s @@ -127,7 +127,7 @@ class UTF8StringPropertyCheckSuite extends FunSuite with GeneratorDrivenProperty whitespaceString ) { (start: String, middle: String, end: String) => val s = start + middle + end - assert(toUTF8(s).trim() === toUTF8(s.trim())) + assert(toUTF8(s).trim() === toUTF8(rTrim(lTrim(s)))) assert(toUTF8(s).trimLeft() === toUTF8(lTrim(s))) assert(toUTF8(s).trimRight() === toUTF8(rTrim(s))) } diff --git a/conf/spark-env.sh.template b/conf/spark-env.sh.template index a031cd6a722f9..c750c72d19880 100755 --- a/conf/spark-env.sh.template +++ b/conf/spark-env.sh.template @@ -40,13 +40,9 @@ # - SPARK_EXECUTOR_CORES, Number of cores for the executors (Default: 1). # - SPARK_EXECUTOR_MEMORY, Memory per Executor (e.g. 1000M, 2G) (Default: 1G) # - SPARK_DRIVER_MEMORY, Memory for Driver (e.g. 1000M, 2G) (Default: 1G) -# - SPARK_YARN_APP_NAME, The name of your application (Default: Spark) -# - SPARK_YARN_QUEUE, The hadoop queue to use for allocation requests (Default: 'default') -# - SPARK_YARN_DIST_FILES, Comma separated list of files to be distributed with the job. -# - SPARK_YARN_DIST_ARCHIVES, Comma separated list of archives to be distributed with the job. # Options for the daemons used in the standalone deploy mode -# - SPARK_MASTER_IP, to bind the master to a different IP address or hostname +# - SPARK_MASTER_HOST, to bind the master to a different IP address or hostname # - SPARK_MASTER_PORT / SPARK_MASTER_WEBUI_PORT, to use non-default ports for the master # - SPARK_MASTER_OPTS, to set config properties only for the master (e.g. 
"-Dx=y") # - SPARK_WORKER_CORES, to set the number of cores to use on this machine diff --git a/core/pom.xml b/core/pom.xml index 7349ad35b9595..55cbe898537cb 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.0.0-SNAPSHOT + 2.0.3-SNAPSHOT ../pom.xml @@ -125,12 +125,15 @@ jetty-servlet compile - - org.eclipse.jetty.orbit - javax.servlet - ${orbit.version} + org.eclipse.jetty + jetty-servlets + compile + + + javax.servlet + javax.servlet-api + ${javaxservlet.version} @@ -194,12 +197,24 @@ json4s-jackson_${scala.binary.version} - com.sun.jersey + org.glassfish.jersey.core + jersey-client + + + org.glassfish.jersey.core + jersey-common + + + org.glassfish.jersey.core jersey-server - com.sun.jersey - jersey-core + org.glassfish.jersey.containers + jersey-container-servlet + + + org.glassfish.jersey.containers + jersey-container-servlet-core org.apache.mesos @@ -260,12 +275,11 @@ org.seleniumhq.selenium selenium-java - - - com.google.guava - guava - - + test + + + org.seleniumhq.selenium + selenium-htmlunit-driver test @@ -302,7 +316,7 @@ net.razorvine pyrolite - 4.9 + 4.13 net.razorvine @@ -313,17 +327,49 @@ net.sf.py4j py4j - 0.9.2 + 0.10.3 org.apache.spark - spark-test-tags_${scala.binary.version} + spark-tags_${scala.binary.version} target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes + + + ${project.basedir}/src/main/resources + + + + ${project.build.directory}/extra-resources + true + + + + org.apache.maven.plugins + maven-antrun-plugin + + + generate-resources + + + + + + + + + + + + run + + + + org.apache.maven.plugins maven-dependency-plugin @@ -344,7 +390,7 @@ true true - guava,jetty-io,jetty-servlet,jetty-continuation,jetty-http,jetty-plus,jetty-util,jetty-server,jetty-security + guava,jetty-io,jetty-servlet,jetty-servlets,jetty-continuation,jetty-http,jetty-plus,jetty-util,jetty-server,jetty-security true diff --git a/core/src/main/java/org/apache/spark/api/java/function/FilterFunction.java b/core/src/main/java/org/apache/spark/api/java/function/FilterFunction.java index e8d999dd00135..462ca3f6f6d19 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/FilterFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/FilterFunction.java @@ -22,7 +22,7 @@ /** * Base interface for a function used in Dataset's filter function. * - * If the function returns true, the element is discarded in the returned Dataset. + * If the function returns true, the element is included in the returned Dataset. */ public interface FilterFunction extends Serializable { boolean call(T value) throws Exception; diff --git a/core/src/main/java/org/apache/spark/api/java/function/package.scala b/core/src/main/java/org/apache/spark/api/java/function/package.scala index 0f9bac7164162..e19f12fdac090 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/package.scala +++ b/core/src/main/java/org/apache/spark/api/java/function/package.scala @@ -22,4 +22,4 @@ package org.apache.spark.api.java * these interfaces to pass functions to various Java API methods for Spark. Please visit Spark's * Java programming guide for more details. 
*/ -package object function +package object function diff --git a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java index 840f13b39464c..fc1f3a80239ba 100644 --- a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java +++ b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java @@ -23,7 +23,7 @@ import org.apache.spark.unsafe.memory.MemoryBlock; /** - * An memory consumer of TaskMemoryManager, which support spilling. + * A memory consumer of {@link TaskMemoryManager} that supports spilling. * * Note: this only supports allocation / spilling of Tungsten memory. */ @@ -31,15 +31,24 @@ public abstract class MemoryConsumer { protected final TaskMemoryManager taskMemoryManager; private final long pageSize; + private final MemoryMode mode; protected long used; - protected MemoryConsumer(TaskMemoryManager taskMemoryManager, long pageSize) { + protected MemoryConsumer(TaskMemoryManager taskMemoryManager, long pageSize, MemoryMode mode) { this.taskMemoryManager = taskMemoryManager; this.pageSize = pageSize; + this.mode = mode; } protected MemoryConsumer(TaskMemoryManager taskMemoryManager) { - this(taskMemoryManager, taskMemoryManager.pageSizeBytes()); + this(taskMemoryManager, taskMemoryManager.pageSizeBytes(), MemoryMode.ON_HEAP); + } + + /** + * Returns the memory mode, {@link MemoryMode#ON_HEAP} or {@link MemoryMode#OFF_HEAP}. + */ + public MemoryMode getMode() { + return mode; } /** @@ -132,19 +141,19 @@ protected void freePage(MemoryBlock page) { } /** - * Allocates a heap memory of `size`. + * Allocates memory of `size`. */ - public long acquireOnHeapMemory(long size) { - long granted = taskMemoryManager.acquireExecutionMemory(size, MemoryMode.ON_HEAP, this); + public long acquireMemory(long size) { + long granted = taskMemoryManager.acquireExecutionMemory(size, this); used += granted; return granted; } /** - * Release N bytes of heap memory. + * Release N bytes of memory. */ - public void freeOnHeapMemory(long size) { - taskMemoryManager.releaseExecutionMemory(size, MemoryMode.ON_HEAP, this); + public void freeMemory(long size) { + taskMemoryManager.releaseExecutionMemory(size, this); used -= size; } } diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java index 2796114fc545a..867c4a1050677 100644 --- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java +++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -76,9 +76,6 @@ public class TaskMemoryManager { /** Bit mask for the lower 51 bits of a long. */ private static final long MASK_LONG_LOWER_51_BITS = 0x7FFFFFFFFFFFFL; - /** Bit mask for the upper 13 bits of a long */ - private static final long MASK_LONG_UPPER_13_BITS = ~MASK_LONG_LOWER_51_BITS; - /** * Similar to an operating system's page table, this array maps page numbers into base object * pointers, allowing us to translate between the hashtable's internal 64-bit address @@ -132,11 +129,10 @@ public TaskMemoryManager(MemoryManager memoryManager, long taskAttemptId) { * * @return number of bytes successfully granted (<= N). 
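With this change a `MemoryConsumer` carries its `MemoryMode`, and the renamed `acquireMemory`/`freeMemory` helpers route through that mode instead of being hard-coded to on-heap. A minimal, hypothetical subclass to show the new shape (the class name and the empty `spill` body are placeholders, not Spark code):

```java
import java.io.IOException;
import org.apache.spark.memory.MemoryConsumer;
import org.apache.spark.memory.MemoryMode;
import org.apache.spark.memory.TaskMemoryManager;

// Illustrative subclass only; real consumers (sorters, hash maps) track actual data structures.
class CountingConsumer extends MemoryConsumer {
  private long reserved = 0;

  CountingConsumer(TaskMemoryManager manager) {
    // The consumer now declares its memory mode up front; the task memory manager
    // reads it via getMode() instead of taking a MemoryMode argument per call.
    super(manager, manager.pageSizeBytes(), MemoryMode.ON_HEAP);
  }

  void reserve(long bytes) {
    reserved += acquireMemory(bytes);   // was acquireOnHeapMemory(bytes)
  }

  void release() {
    freeMemory(reserved);               // was freeOnHeapMemory(bytes)
    reserved = 0;
  }

  @Override
  public long spill(long size, MemoryConsumer trigger) throws IOException {
    // Nothing to spill in this sketch; a real consumer would write its data to disk here.
    return 0;
  }
}
```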
*/ - public long acquireExecutionMemory( - long required, - MemoryMode mode, - MemoryConsumer consumer) { + public long acquireExecutionMemory(long required, MemoryConsumer consumer) { assert(required >= 0); + assert(consumer != null); + MemoryMode mode = consumer.getMode(); // If we are allocating Tungsten pages off-heap and receive a request to allocate on-heap // memory here, then it may not make sense to spill since that would only end up freeing // off-heap memory. This is subject to change, though, so it may be risky to make this @@ -149,10 +145,10 @@ public long acquireExecutionMemory( if (got < required) { // Call spill() on other consumers to release memory for (MemoryConsumer c: consumers) { - if (c != consumer && c.getUsed() > 0) { + if (c != consumer && c.getUsed() > 0 && c.getMode() == mode) { try { long released = c.spill(required - got, consumer); - if (released > 0 && mode == tungstenMemoryMode) { + if (released > 0) { logger.debug("Task {} released {} from {} for {}", taskAttemptId, Utils.bytesToString(released), c, consumer); got += memoryManager.acquireExecutionMemory(required - got, taskAttemptId, mode); @@ -170,10 +166,10 @@ public long acquireExecutionMemory( } // call spill() on itself - if (got < required && consumer != null) { + if (got < required) { try { long released = consumer.spill(required - got, consumer); - if (released > 0 && mode == tungstenMemoryMode) { + if (released > 0) { logger.debug("Task {} released {} from itself ({})", taskAttemptId, Utils.bytesToString(released), consumer); got += memoryManager.acquireExecutionMemory(required - got, taskAttemptId, mode); @@ -185,10 +181,8 @@ public long acquireExecutionMemory( } } - if (consumer != null) { - consumers.add(consumer); - } - logger.debug("Task {} acquire {} for {}", taskAttemptId, Utils.bytesToString(got), consumer); + consumers.add(consumer); + logger.debug("Task {} acquired {} for {}", taskAttemptId, Utils.bytesToString(got), consumer); return got; } } @@ -196,9 +190,9 @@ public long acquireExecutionMemory( /** * Release N bytes of execution memory for a MemoryConsumer. */ - public void releaseExecutionMemory(long size, MemoryMode mode, MemoryConsumer consumer) { + public void releaseExecutionMemory(long size, MemoryConsumer consumer) { logger.debug("Task {} release {} from {}", taskAttemptId, Utils.bytesToString(size), consumer); - memoryManager.releaseExecutionMemory(size, taskAttemptId, mode); + memoryManager.releaseExecutionMemory(size, taskAttemptId, consumer.getMode()); } /** @@ -241,12 +235,14 @@ public long pageSizeBytes() { * contains fewer bytes than requested, so callers should verify the size of returned pages. 
*/ public MemoryBlock allocatePage(long size, MemoryConsumer consumer) { + assert(consumer != null); + assert(consumer.getMode() == tungstenMemoryMode); if (size > MAXIMUM_PAGE_SIZE_BYTES) { throw new IllegalArgumentException( "Cannot allocate a page with more than " + MAXIMUM_PAGE_SIZE_BYTES + " bytes"); } - long acquired = acquireExecutionMemory(size, tungstenMemoryMode, consumer); + long acquired = acquireExecutionMemory(size, consumer); if (acquired <= 0) { return null; } @@ -255,7 +251,7 @@ public MemoryBlock allocatePage(long size, MemoryConsumer consumer) { synchronized (this) { pageNumber = allocatedPages.nextClearBit(0); if (pageNumber >= PAGE_TABLE_SIZE) { - releaseExecutionMemory(acquired, tungstenMemoryMode, consumer); + releaseExecutionMemory(acquired, consumer); throw new IllegalStateException( "Have already allocated a maximum of " + PAGE_TABLE_SIZE + " pages"); } @@ -299,7 +295,7 @@ public void freePage(MemoryBlock page, MemoryConsumer consumer) { } long pageSize = page.size(); memoryManager.tungstenMemoryAllocator().free(page); - releaseExecutionMemory(pageSize, tungstenMemoryMode, consumer); + releaseExecutionMemory(pageSize, consumer); } /** @@ -379,7 +375,6 @@ public long getOffsetInPage(long pagePlusOffsetAddress) { */ public long cleanUpAllAllocatedMemory() { synchronized (this) { - Arrays.fill(pageTable, null); for (MemoryConsumer c: consumers) { if (c != null && c.getUsed() > 0) { // In case of failed task, it's normal to see leaked memory } } consumers.clear(); - } - for (MemoryBlock page : pageTable) { - if (page != null) { - memoryManager.tungstenMemoryAllocator().free(page); + for (MemoryBlock page : pageTable) { + if (page != null) { + logger.warn("leaked a page: " + page + " in task " + taskAttemptId); + memoryManager.tungstenMemoryAllocator().free(page); + } } + Arrays.fill(pageTable, null); } - Arrays.fill(pageTable, null); - // release the memory that is not used by any consumer. + // release the memory that is not used by any consumer (acquired for pages in tungsten mode).
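allocatePage hands out MemoryBlocks whose 64-bit addresses the page table above can decode: the upper 13 bits select one of the PAGE_TABLE_SIZE pages and the lower 51 bits hold the offset. A sketch of that encoding, consistent with MASK_LONG_LOWER_51_BITS from this file (encode and decodeOffset are illustrative names; decodePageNumber mirrors the TaskMemoryManager method this patch calls from SortedIterator.loadNext() further down):

```java
// Sketch of the page-number/offset encoding; constant copied from the patch.
public final class PageAddressDemo {
  private static final long MASK_LONG_LOWER_51_BITS = 0x7FFFFFFFFFFFFL;

  // Upper 13 bits: page number (2^13 = 8192 pages); lower 51 bits: offset.
  static long encode(int pageNumber, long offsetInPage) {
    return (((long) pageNumber) << 51) | (offsetInPage & MASK_LONG_LOWER_51_BITS);
  }

  static int decodePageNumber(long address) {
    return (int) (address >>> 51);
  }

  static long decodeOffset(long address) {
    return address & MASK_LONG_LOWER_51_BITS;
  }

  public static void main(String[] args) {
    long addr = encode(3, 100);
    System.out.println(decodePageNumber(addr) + " / " + decodeOffset(addr)); // 3 / 100
  }
}
```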
memoryManager.releaseExecutionMemory(acquiredButNotUsed, taskAttemptId, tungstenMemoryMode); return memoryManager.releaseAllExecutionMemoryForTask(taskAttemptId); @@ -412,7 +408,7 @@ public long getMemoryConsumptionForThisTask() { /** * Returns Tungsten memory mode */ - public MemoryMode getTungstenMemoryMode(){ + public MemoryMode getTungstenMemoryMode() { return tungstenMemoryMode; } } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index 0e9defe5b4a51..601dd6edfcf7c 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -156,8 +156,14 @@ public void write(Iterator> records) throws IOException { File output = shuffleBlockResolver.getDataFile(shuffleId, mapId); File tmp = Utils.tempFileWith(output); - partitionLengths = writePartitionedFile(tmp); - shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp); + try { + partitionLengths = writePartitionedFile(tmp); + shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp); + } finally { + if (tmp.exists() && !tmp.delete()) { + logger.error("Error while deleting temp file {}", tmp.getAbsolutePath()); + } + } mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java index 2be5a16b2d1e4..cf38a04ed7cfb 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java @@ -72,7 +72,10 @@ final class ShuffleExternalSorter extends MemoryConsumer { private final TaskContext taskContext; private final ShuffleWriteMetrics writeMetrics; - /** Force this sorter to spill when there are this many elements in memory. For testing only */ + /** + * Force this sorter to spill when there are this many elements in memory. The default value is + * 1024 * 1024 * 1024, which allows the maximum size of the pointer array to be 8G. 
+ */ private final long numElementsForSpillThreshold; /** The buffer size to use when writing spills using DiskBlockObjectWriter */ @@ -104,8 +107,9 @@ final class ShuffleExternalSorter extends MemoryConsumer { int numPartitions, SparkConf conf, ShuffleWriteMetrics writeMetrics) { - super(memoryManager, (int) Math.min(PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, - memoryManager.pageSizeBytes())); + super(memoryManager, + (int) Math.min(PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, memoryManager.pageSizeBytes()), + memoryManager.getTungstenMemoryMode()); this.taskMemoryManager = memoryManager; this.blockManager = blockManager; this.taskContext = taskContext; @@ -113,7 +117,7 @@ final class ShuffleExternalSorter extends MemoryConsumer { // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; this.numElementsForSpillThreshold = - conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MAX_VALUE); + conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", 1024 * 1024 * 1024); this.writeMetrics = writeMetrics; this.inMemSorter = new ShuffleInMemorySorter( this, initialSize, conf.getBoolean("spark.shuffle.sort.useRadixSort", true)); @@ -371,7 +375,9 @@ public void insertRecord(Object recordBase, long recordOffset, int length, int p // for tests assert(inMemSorter != null); - if (inMemSorter.numRecords() > numElementsForSpillThreshold) { + if (inMemSorter.numRecords() >= numElementsForSpillThreshold) { + logger.info("Spilling data because the number of in-memory records crossed the threshold " + + numElementsForSpillThreshold); spill(); } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java index 75a0e807d76f5..dc36809d8911f 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java @@ -22,12 +22,12 @@ import org.apache.spark.memory.MemoryConsumer; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.LongArray; +import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.util.collection.Sorter; import org.apache.spark.util.collection.unsafe.sort.RadixSort; final class ShuffleInMemorySorter { - private final Sorter sorter; private static final class SortComparator implements Comparator { @Override public int compare(PackedRecordPointer left, PackedRecordPointer right) { @@ -44,6 +44,9 @@ public int compare(PackedRecordPointer left, PackedRecordPointer right) { * An array of record pointers and partition ids that have been encoded by * {@link PackedRecordPointer}. The sort operates on this array instead of directly manipulating * records. + * + * Only part of the array will be used to store the pointers; the rest is reserved as a + * temporary buffer for sorting. */ private LongArray array; @@ -54,14 +57,14 @@ public int compare(PackedRecordPointer left, PackedRecordPointer right) { private final boolean useRadixSort; /** - * Set to 2x for radix sort to reserve extra memory for sorting, otherwise 1x. + * The position in the pointer array where new records can be inserted. */ - private final int memoryAllocationFactor; + private int pos = 0; /** - * The position in the pointer array where new records can be inserted.
+ * The maximum number of records that can be inserted, since part of the array must be left as a sort buffer. */ - private int pos = 0; + private int usableCapacity = 0; private int initialSize; @@ -70,9 +73,14 @@ public int compare(PackedRecordPointer left, PackedRecordPointer right) { assert (initialSize > 0); this.initialSize = initialSize; this.useRadixSort = useRadixSort; - this.memoryAllocationFactor = useRadixSort ? 2 : 1; this.array = consumer.allocateArray(initialSize); - this.sorter = new Sorter<>(ShuffleSortDataFormat.INSTANCE); + this.usableCapacity = getUsableCapacity(); + } + + private int getUsableCapacity() { + // Radix sort requires a scratch buffer as large as the memory in use, while Tim sort requires + // a buffer half that size. + return (int) (array.size() / (useRadixSort ? 2 : 1.5)); } public void free() { @@ -89,7 +97,8 @@ public int numRecords() { public void reset() { if (consumer != null) { consumer.freeArray(array); - this.array = consumer.allocateArray(initialSize); + array = consumer.allocateArray(initialSize); + usableCapacity = getUsableCapacity(); } pos = 0; } @@ -101,14 +110,15 @@ public void expandPointerArray(LongArray newArray) { array.getBaseOffset(), newArray.getBaseObject(), newArray.getBaseOffset(), - array.size() * (8 / memoryAllocationFactor) + pos * 8L ); consumer.freeArray(array); array = newArray; + usableCapacity = getUsableCapacity(); } public boolean hasSpaceForAnotherRecord() { - return pos < array.size() / memoryAllocationFactor; + return pos < usableCapacity; } public long getMemoryUsage() { @@ -170,6 +180,14 @@ public ShuffleSorterIterator getSortedIterator() { PackedRecordPointer.PARTITION_ID_START_BYTE_INDEX, PackedRecordPointer.PARTITION_ID_END_BYTE_INDEX, false, false); } else { + MemoryBlock unused = new MemoryBlock( + array.getBaseObject(), + array.getBaseOffset() + pos * 8L, + (array.size() - pos) * 8L); + LongArray buffer = new LongArray(unused); + Sorter sorter = + new Sorter<>(new ShuffleSortDataFormat(buffer)); + sorter.sort(array, 0, pos, SORT_COMPARATOR); } return new ShuffleSorterIterator(pos, array, offset); diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java index 8f4e3229976dc..717bdd79d47ef 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java @@ -19,14 +19,15 @@ import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.LongArray; -import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.util.collection.SortDataFormat; final class ShuffleSortDataFormat extends SortDataFormat { - public static final ShuffleSortDataFormat INSTANCE = new ShuffleSortDataFormat(); + private final LongArray buffer; - private ShuffleSortDataFormat() { } + ShuffleSortDataFormat(LongArray buffer) { + this.buffer = buffer; + } @Override public PackedRecordPointer getKey(LongArray data, int pos) { @@ -61,17 +62,17 @@ public void copyElement(LongArray src, int srcPos, LongArray dst, int dstPos) { public void copyRange(LongArray src, int srcPos, LongArray dst, int dstPos, int length) { Platform.copyMemory( src.getBaseObject(), - src.getBaseOffset() + srcPos * 8L, + src.getBaseOffset() + srcPos * 8L, dst.getBaseObject(), - dst.getBaseOffset() + dstPos * 8, - length * 8 + dst.getBaseOffset() + dstPos * 8L, + length * 8L ); } @Override public LongArray allocate(int length) { - // This buffer is used temporary
(usually small), so it's fine to allocated from JVM heap. - return new LongArray(MemoryBlock.fromLongArray(new long[length])); + assert (length <= buffer.size()) : + "the buffer is smaller than required: " + buffer.size() + " < " + length; + return buffer; } - } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index daa63d47e6aed..c08a5d424ea6a 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -207,15 +207,21 @@ void closeAndWriteOutput() throws IOException { final File output = shuffleBlockResolver.getDataFile(shuffleId, mapId); final File tmp = Utils.tempFileWith(output); try { - partitionLengths = mergeSpills(spills, tmp); - } finally { - for (SpillInfo spill : spills) { - if (spill.file.exists() && ! spill.file.delete()) { - logger.error("Error while deleting spill file {}", spill.file.getPath()); + try { + partitionLengths = mergeSpills(spills, tmp); + } finally { + for (SpillInfo spill : spills) { + if (spill.file.exists() && ! spill.file.delete()) { + logger.error("Error while deleting spill file {}", spill.file.getPath()); + } } } + shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp); + } finally { + if (tmp.exists() && !tmp.delete()) { + logger.error("Error while deleting temp file {}", tmp.getAbsolutePath()); + } } - shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp); mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); } @@ -346,12 +352,19 @@ private long[] mergeSpillsWithFileStream( for (int i = 0; i < spills.length; i++) { final long partitionLengthInSpill = spills[i].partitionLengths[partition]; if (partitionLengthInSpill > 0) { - InputStream partitionInputStream = - new LimitedInputStream(spillInputStreams[i], partitionLengthInSpill); - if (compressionCodec != null) { - partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream); + InputStream partitionInputStream = null; + boolean innerThrewException = true; + try { + partitionInputStream = + new LimitedInputStream(spillInputStreams[i], partitionLengthInSpill, false); + if (compressionCodec != null) { + partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream); + } + ByteStreams.copy(partitionInputStream, mergedFileOutputStream); + innerThrewException = false; + } finally { + Closeables.close(partitionInputStream, innerThrewException); } - ByteStreams.copy(partitionInputStream, mergedFileOutputStream); } } mergedFileOutputStream.flush(); diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 6807710f9fef1..dc04025692909 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -182,7 +182,7 @@ public BytesToBytesMap( double loadFactor, long pageSizeBytes, boolean enablePerfMetrics) { - super(taskMemoryManager, pageSizeBytes); + super(taskMemoryManager, pageSizeBytes, taskMemoryManager.getTungstenMemoryMode()); this.taskMemoryManager = taskMemoryManager; this.blockManager = blockManager; this.serializerManager = serializerManager; @@ -221,7 +221,8 @@ public BytesToBytesMap( SparkEnv.get() != null ? 
SparkEnv.get().blockManager() : null, SparkEnv.get() != null ? SparkEnv.get().serializerManager() : null, initialCapacity, - 0.70, + // In order to re-use the longArray for sorting, the load factor cannot be larger than 0.5. + 0.5, pageSizeBytes, enablePerfMetrics); } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java index de92b8db47131..e9571aa8bb052 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java @@ -17,7 +17,7 @@ package org.apache.spark.util.collection.unsafe.sort; -final class RecordPointerAndKeyPrefix { +public final class RecordPointerAndKeyPrefix { /** * A pointer to a record; see {@link org.apache.spark.memory.TaskMemoryManager} for a * description of how these addresses are encoded. diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 8b6c96a4c4e63..56d54a18c00ec 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -59,6 +59,13 @@ public final class UnsafeExternalSorter extends MemoryConsumer { /** The buffer size to use when writing spills using DiskBlockObjectWriter */ private final int fileBufferSizeBytes; + /** + * Force this sorter to spill when there are this many elements in memory. The default value is + * 1024 * 1024 * 1024 / 2, which allows the maximum size of the pointer array to be 8G. + */ + public static final long DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD = 1024 * 1024 * 1024 / 2; + + private final long numElementsForSpillThreshold; /** * Memory pages that hold the records being sorted. The pages in this list are freed when * spilling, although in principle we could recycle these pages across spills (on the other hand, @@ -76,6 +83,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer { private long pageCursor = -1; private long peakMemoryUsedBytes = 0; private long totalSpillBytes = 0L; + private long totalSortTimeNanos = 0L; private volatile SpillableIterator readingIterator = null; public static UnsafeExternalSorter createWithExistingInMemorySorter( @@ -87,10 +95,11 @@ public static UnsafeExternalSorter createWithExistingInMemorySorter( PrefixComparator prefixComparator, int initialSize, long pageSizeBytes, + long numElementsForSpillThreshold, UnsafeInMemorySorter inMemorySorter) throws IOException { UnsafeExternalSorter sorter = new UnsafeExternalSorter(taskMemoryManager, blockManager, serializerManager, taskContext, recordComparator, prefixComparator, initialSize, - pageSizeBytes, inMemorySorter, false /* ignored */); + pageSizeBytes, numElementsForSpillThreshold, inMemorySorter, false /* ignored */); sorter.spill(Long.MAX_VALUE, sorter); // The external sorter will be used to insert records; the in-memory sorter is not needed.
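The two spill-threshold defaults are consistent with each other: both cap the in-memory pointer array at 8 GiB, since ShuffleExternalSorter stores one packed 8-byte pointer per record while UnsafeExternalSorter stores an 8-byte pointer plus an 8-byte key prefix per record. A back-of-the-envelope check (class and variable names invented):

```java
// Verifies the "8G" figure quoted in both Javadoc comments above.
public final class SpillThresholdMath {
  public static void main(String[] args) {
    long gib = 1024L * 1024L * 1024L;

    // ShuffleExternalSorter: one packed 8-byte pointer per record.
    long shuffleThreshold = 1024L * 1024L * 1024L;            // records
    System.out.println(shuffleThreshold * 8 / gib + " GiB");  // 8 GiB

    // UnsafeExternalSorter: an 8-byte pointer plus an 8-byte key prefix
    // per record, i.e. 16 bytes each.
    long unsafeThreshold = 1024L * 1024L * 1024L / 2;         // records
    System.out.println(unsafeThreshold * 16 / gib + " GiB");  // 8 GiB
  }
}
```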
sorter.inMemSorter = null; @@ -106,10 +115,11 @@ public static UnsafeExternalSorter create( PrefixComparator prefixComparator, int initialSize, long pageSizeBytes, + long numElementsForSpillThreshold, boolean canUseRadixSort) { return new UnsafeExternalSorter(taskMemoryManager, blockManager, serializerManager, - taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, null, - canUseRadixSort); + taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, + numElementsForSpillThreshold, null, canUseRadixSort); } private UnsafeExternalSorter( @@ -121,9 +131,10 @@ private UnsafeExternalSorter( PrefixComparator prefixComparator, int initialSize, long pageSizeBytes, + long numElementsForSpillThreshold, @Nullable UnsafeInMemorySorter existingInMemorySorter, boolean canUseRadixSort) { - super(taskMemoryManager, pageSizeBytes); + super(taskMemoryManager, pageSizeBytes, taskMemoryManager.getTungstenMemoryMode()); this.taskMemoryManager = taskMemoryManager; this.blockManager = blockManager; this.serializerManager = serializerManager; @@ -131,9 +142,12 @@ private UnsafeExternalSorter( this.recordComparator = recordComparator; this.prefixComparator = prefixComparator; // Use getSizeAsKb (not bytes) to maintain backwards compatibility for units - // this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; + // this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024 this.fileBufferSizeBytes = 32 * 1024; - this.writeMetrics = taskContext.taskMetrics().shuffleWriteMetrics(); + // The spill metrics are stored in a new ShuffleWriteMetrics, + // and then discarded (this fixes SPARK-16827). + // TODO: Instead, separate spill metrics should be stored and reported (tracked in SPARK-3577). + this.writeMetrics = new ShuffleWriteMetrics(); if (existingInMemorySorter == null) { this.inMemSorter = new UnsafeInMemorySorter( @@ -142,6 +156,7 @@ private UnsafeExternalSorter( this.inMemSorter = existingInMemorySorter; } this.peakMemoryUsedBytes = getMemoryUsage(); + this.numElementsForSpillThreshold = numElementsForSpillThreshold; // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at // the end of the task. This is necessary to avoid memory leaks when the downstream operator @@ -247,6 +262,17 @@ public long getPeakMemoryUsedBytes() { return peakMemoryUsedBytes; } + /** + * @return the total amount of time spent sorting data (in-memory only). + */ + public long getSortTimeNanos() { + UnsafeInMemorySorter sorter = inMemSorter; + if (sorter != null) { + return sorter.getSortTimeNanos(); + } + return totalSortTimeNanos; + } + /** * Return the total number of bytes that has been spilled into disk so far. */ @@ -360,6 +386,13 @@ private void acquireNewPageIfNecessary(int required) { public void insertRecord(Object recordBase, long recordOffset, int length, long prefix) throws IOException { + assert(inMemSorter != null); + if (inMemSorter.numRecords() >= numElementsForSpillThreshold) { + logger.info("Spilling data because the number of in-memory records crossed the threshold " + + numElementsForSpillThreshold); + spill(); + } + growPointerArrayIfNecessary(); // Need 4 bytes to store the record length.
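The insertRecord path writes a 4-byte length header followed by the record bytes at the current page cursor. A simplified sketch of just that layout step (page acquisition, bookkeeping, and the pointer insertion into the in-memory sorter are omitted; the helper name is invented):

```java
import org.apache.spark.unsafe.Platform;

// Sketch of the record layout written by insertRecord() above.
final class RecordLayoutSketch {
  static long writeRecord(Object base, long pageCursor,
                          Object recordBase, long recordOffset, int length) {
    Platform.putInt(base, pageCursor, length);    // 4-byte length header
    pageCursor += 4;
    Platform.copyMemory(recordBase, recordOffset, base, pageCursor, length);
    pageCursor += length;                         // cursor now points past the record
    return pageCursor;
  }
}
```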
final int required = length + 4; @@ -371,7 +404,6 @@ public void insertRecord(Object recordBase, long recordOffset, int length, long pageCursor += 4; Platform.copyMemory(recordBase, recordOffset, base, pageCursor, length); pageCursor += length; - assert(inMemSorter != null); inMemSorter.insertRecord(recordAddress, prefix); } @@ -492,7 +524,8 @@ public long spill() throws IOException { // is accessing the current record. We free this page in that caller's next loadNext() // call. for (MemoryBlock page : allocatedPages) { - if (!loaded || page.getBaseObject() != upstream.getBaseObject()) { + if (!loaded || page.pageNumber != + ((UnsafeInMemorySorter.SortedIterator)upstream).getCurrentPageNumber()) { released += page.size(); freePage(page); } else { @@ -505,6 +538,7 @@ public long spill() throws IOException { // in-memory sorter will not be used after spilling assert(inMemSorter != null); released += inMemSorter.getMemoryUsage(); + totalSortTimeNanos += inMemSorter.getSortTimeNanos(); inMemSorter.free(); inMemSorter = null; taskContext.taskMetrics().incMemoryBytesSpilled(released); diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index 03973f3c124ee..b51737158098d 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -25,6 +25,7 @@ import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.LongArray; +import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.util.collection.Sorter; /** @@ -69,8 +70,6 @@ public int compare(RecordPointerAndKeyPrefix r1, RecordPointerAndKeyPrefix r2) { private final MemoryConsumer consumer; private final TaskMemoryManager memoryManager; @Nullable - private final Sorter sorter; - @Nullable private final Comparator sortComparator; /** @@ -79,14 +78,12 @@ public int compare(RecordPointerAndKeyPrefix r1, RecordPointerAndKeyPrefix r2) { @Nullable private final PrefixComparators.RadixSortSupport radixSortSupport; - /** - * Set to 2x for radix sort to reserve extra memory for sorting, otherwise 1x. - */ - private final int memoryAllocationFactor; - /** * Within this buffer, position {@code 2 * i} holds a pointer to the record at * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix. + * + * Only part of the array will be used to store the pointers; the rest is reserved as a + * temporary buffer for sorting. */ private LongArray array; @@ -95,8 +92,15 @@ public int compare(RecordPointerAndKeyPrefix r1, RecordPointerAndKeyPrefix r2) { */ private int pos = 0; + /** + * The maximum number of records that can be inserted, since part of the array must be left as a sort buffer.
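Instead of over-allocating by a fixed factor, both in-memory sorters now compute how many slots of the single pointer array may hold data and leave the tail as the sort's scratch buffer. A worked example of getUsableCapacity() with concrete numbers (a standalone restatement, not the patched method itself):

```java
// Radix sort needs scratch space equal to the data it sorts, so half the
// array is usable; Tim sort needs half as much scratch, so two thirds is.
public class UsableCapacityDemo {
  static int usableCapacity(long arraySize, boolean useRadixSort) {
    return (int) (arraySize / (useRadixSort ? 2 : 1.5));
  }

  public static void main(String[] args) {
    System.out.println(usableCapacity(1024, true));   // 512 slots for radix sort
    System.out.println(usableCapacity(1024, false));  // 682 slots for Tim sort
    // In UnsafeInMemorySorter each record takes two slots (pointer + prefix),
    // so 512 usable slots hold 256 records.
  }
}
```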
+ */ + private int usableCapacity = 0; + private long initialSize; + private long totalSortTimeNanos = 0L; + public UnsafeInMemorySorter( final MemoryConsumer consumer, final TaskMemoryManager memoryManager, @@ -119,7 +123,6 @@ public UnsafeInMemorySorter( this.memoryManager = memoryManager; this.initialSize = array.size(); if (recordComparator != null) { - this.sorter = new Sorter<>(UnsafeSortDataFormat.INSTANCE); this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager); if (canUseRadixSort && prefixComparator instanceof PrefixComparators.RadixSortSupport) { this.radixSortSupport = (PrefixComparators.RadixSortSupport)prefixComparator; @@ -127,12 +130,17 @@ public UnsafeInMemorySorter( this.radixSortSupport = null; } } else { - this.sorter = null; this.sortComparator = null; this.radixSortSupport = null; } - this.memoryAllocationFactor = this.radixSortSupport != null ? 2 : 1; this.array = array; + this.usableCapacity = getUsableCapacity(); + } + + private int getUsableCapacity() { + // Radix sort requires a scratch buffer as large as the memory in use, while Tim sort requires + // a buffer half that size. + return (int) (array.size() / (radixSortSupport != null ? 2 : 1.5)); } /** @@ -148,7 +156,8 @@ public void free() { public void reset() { if (consumer != null) { consumer.freeArray(array); - this.array = consumer.allocateArray(initialSize); + array = consumer.allocateArray(initialSize); + usableCapacity = getUsableCapacity(); } pos = 0; } @@ -160,12 +169,19 @@ public int numRecords() { return pos / 2; } + /** + * @return the total amount of time spent sorting data (in-memory only). + */ + public long getSortTimeNanos() { + return totalSortTimeNanos; + } + public long getMemoryUsage() { return array.size() * 8; } public boolean hasSpaceForAnotherRecord() { - return pos + 1 < (array.size() / memoryAllocationFactor); + return pos + 1 < usableCapacity; } public void expandPointerArray(LongArray newArray) { @@ -177,9 +193,10 @@ public void expandPointerArray(LongArray newArray) { array.getBaseOffset(), newArray.getBaseObject(), newArray.getBaseOffset(), - array.size() * (8 / memoryAllocationFactor)); + pos * 8L); consumer.freeArray(array); array = newArray; + usableCapacity = getUsableCapacity(); } /** @@ -208,6 +225,7 @@ public final class SortedIterator extends UnsafeSorterIterator implements Clonea private long baseOffset; private long keyPrefix; private int recordLength; + private long currentPageNumber; private SortedIterator(int numRecords, int offset) { this.numRecords = numRecords; @@ -222,6 +240,7 @@ public SortedIterator clone() { iter.baseOffset = baseOffset; iter.keyPrefix = keyPrefix; iter.recordLength = recordLength; + iter.currentPageNumber = currentPageNumber; return iter; } @@ -239,6 +258,7 @@ public boolean hasNext() { public void loadNext() { // This pointer points to a 4-byte record length, followed by the record's bytes final long recordPointer = array.get(offset + position); + currentPageNumber = TaskMemoryManager.decodePageNumber(recordPointer); baseObject = memoryManager.getPage(recordPointer); baseOffset = memoryManager.getOffsetInPage(recordPointer) + 4; // Skip over record length recordLength = Platform.getInt(baseObject, baseOffset - 4); @@ -252,6 +272,10 @@ public void loadNext() { @Override public long getBaseOffset() { return baseOffset; } + public long getCurrentPageNumber() { + return currentPageNumber; + } + @Override public int getRecordLength() { return recordLength; } @@ -265,16 +289,25 @@ public void loadNext() { */ public
SortedIterator getSortedIterator() { int offset = 0; - if (sorter != null) { + long start = System.nanoTime(); + if (sortComparator != null) { if (this.radixSortSupport != null) { // TODO(ekl) we should handle NULL values before radix sort for efficiency, since they // force a full-width sort (and we cannot radix-sort nullable long fields at all). offset = RadixSort.sortKeyPrefixArray( array, pos / 2, 0, 7, radixSortSupport.sortDescending(), radixSortSupport.sortSigned()); } else { + MemoryBlock unused = new MemoryBlock( + array.getBaseObject(), + array.getBaseOffset() + pos * 8L, + (array.size() - pos) * 8L); + LongArray buffer = new LongArray(unused); + Sorter sorter = + new Sorter<>(new UnsafeSortDataFormat(buffer)); sorter.sort(array, 0, pos / 2, sortComparator); } } + totalSortTimeNanos += System.nanoTime() - start; return new SortedIterator(pos / 2, offset); } } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java index 12fb62fb77f0f..430bf677edbdf 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java @@ -19,7 +19,6 @@ import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.LongArray; -import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.util.collection.SortDataFormat; /** @@ -29,11 +28,14 @@ * Within each long[] buffer, position {@code 2 * i} holds a pointer to the record at * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix. */ -final class UnsafeSortDataFormat extends SortDataFormat { +public final class UnsafeSortDataFormat + extends SortDataFormat { - public static final UnsafeSortDataFormat INSTANCE = new UnsafeSortDataFormat(); + private final LongArray buffer; - private UnsafeSortDataFormat() { } + public UnsafeSortDataFormat(LongArray buffer) { + this.buffer = buffer; + } @Override public RecordPointerAndKeyPrefix getKey(LongArray data, int pos) { @@ -74,17 +76,17 @@ public void copyElement(LongArray src, int srcPos, LongArray dst, int dstPos) { public void copyRange(LongArray src, int srcPos, LongArray dst, int dstPos, int length) { Platform.copyMemory( src.getBaseObject(), - src.getBaseOffset() + srcPos * 16, + src.getBaseOffset() + srcPos * 16L, dst.getBaseObject(), - dst.getBaseOffset() + dstPos * 16, - length * 16); + dst.getBaseOffset() + dstPos * 16L, + length * 16L); } @Override public LongArray allocate(int length) { - assert (length < Integer.MAX_VALUE / 2) : "Length " + length + " is too large"; - // This is used as temporary buffer, it's fine to allocate from JVM heap. - return new LongArray(MemoryBlock.fromLongArray(new long[length * 2])); + assert (length * 2 <= buffer.size()) : + "the buffer is smaller than required: " + buffer.size() + " < " + (length * 2); + return buffer; } } diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage.js b/core/src/main/resources/org/apache/spark/ui/static/historypage.js index ef89a9a86f093..177120aaa6c1f 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/historypage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage.js @@ -15,6 +15,12 @@ * limitations under the License.
*/ +var appLimit = -1; + +function setAppLimit(val) { + appLimit = val; +} + // this function works exactly the same as UIUtils.formatDuration function formatDuration(milliseconds) { if (milliseconds < 100) { @@ -54,7 +60,8 @@ function makeIdNumeric(id) { } function formatDate(date) { - return date.split(".")[0].replace("T", " "); + if (date <= 0) return "-"; + else return date.split(".")[0].replace("T", " "); } function getParameterByName(name, searchString) { @@ -110,7 +117,7 @@ $(document).ready(function() { requestedIncomplete = getParameterByName("showIncomplete", searchString); requestedIncomplete = (requestedIncomplete == "true" ? true : false); - $.getJSON("api/v1/applications", function(response,status,jqXHR) { + $.getJSON("api/v1/applications?limit=" + appLimit, function(response,status,jqXHR) { var array = []; var hasMultipleAttempts = false; for (i in response) { diff --git a/core/src/main/resources/org/apache/spark/ui/static/spark-logo-77x50px-hd.png b/core/src/main/resources/org/apache/spark/ui/static/spark-logo-77x50px-hd.png index 6c5f0993c43f0..cee28916e8db7 100644 Binary files a/core/src/main/resources/org/apache/spark/ui/static/spark-logo-77x50px-hd.png and b/core/src/main/resources/org/apache/spark/ui/static/spark-logo-77x50px-hd.png differ diff --git a/core/src/main/resources/org/apache/spark/ui/static/spark_logo.png b/core/src/main/resources/org/apache/spark/ui/static/spark_logo.png deleted file mode 100644 index 4b187347792a6..0000000000000 Binary files a/core/src/main/resources/org/apache/spark/ui/static/spark_logo.png and /dev/null differ diff --git a/core/src/main/resources/org/apache/spark/ui/static/timeline-view.css b/core/src/main/resources/org/apache/spark/ui/static/timeline-view.css index 0f400461c5293..3bf3e8bfa1f31 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/timeline-view.css +++ b/core/src/main/resources/org/apache/spark/ui/static/timeline-view.css @@ -33,12 +33,15 @@ div#application-timeline, div#job-timeline { height: 55px; } -#task-assignment-timeline div.item.range { - padding: 0px; +#task-assignment-timeline div.vis-item.vis-range { height: 26px; border-width: 0; } +#task-assignment-timeline .vis-item-content { + padding: 0px; +} + .task-assignment-timeline-content { width: 100%; } @@ -83,24 +86,24 @@ rect.getting-result-time-proportion { stroke: #75B0A6; } -.vis.timeline { +.vis-timeline { line-height: 14px; } -.vis.timeline div.content { +.vis-timeline div.vis-item-content { width: 100%; } -.vis.timeline .item.stage { +.vis-timeline .vis-item.stage { cursor: pointer; } -.vis.timeline .item.stage.succeeded { +.vis-timeline .vis-item.stage.succeeded { background-color: #A0DFFF; border-color: #3EC0FF; } -.vis.timeline .item.stage.succeeded.selected { +.vis-timeline .vis-item.stage.succeeded.vis-selected { background-color: #A0DFFF; border-color: #3EC0FF; z-index: auto; @@ -111,12 +114,12 @@ rect.getting-result-time-proportion { stroke: #3EC0FF; } -.vis.timeline .item.stage.failed { +.vis-timeline .vis-item.stage.failed { background-color: #FFA1B0; border-color: #FF4D6D; } -.vis.timeline .item.stage.failed.selected { +.vis-timeline .vis-item.stage.failed.vis-selected { background-color: #FFA1B0; border-color: #FF4D6D; z-index: auto; @@ -127,12 +130,12 @@ rect.getting-result-time-proportion { stroke: #FF4D6D; } -.vis.timeline .item.stage.running { +.vis-timeline .vis-item.stage.running { background-color: #A2FCC0; border-color: #36F572; } -.vis.timeline .item.stage.running.selected { +.vis-timeline .vis-item.stage.running.vis-selected 
{ background-color: #A2FCC0; border-color: #36F572; z-index: auto; @@ -143,20 +146,20 @@ rect.getting-result-time-proportion { stroke: #36F572; } -.vis.timeline .foreground { +.vis-timeline .vis-foreground { cursor: move; } -.vis.timeline .item.job { +.vis-timeline .vis-item.job { cursor: pointer; } -.vis.timeline .item.job.succeeded { +.vis-timeline .vis-item.job.succeeded { background-color: #A0DFFF; border-color: #3EC0FF; } -.vis.timeline .item.job.succeeded.selected { +.vis-timeline .vis-item.job.succeeded.vis-selected { background-color: #A0DFFF; border-color: #3EC0FF; z-index: auto; @@ -167,12 +170,12 @@ rect.getting-result-time-proportion { stroke: #3EC0FF; } -.vis.timeline .item.job.failed { +.vis-timeline .vis-item.job.failed { background-color: #FFA1B0; border-color: #FF4D6D; } -.vis.timeline .item.job.failed.selected { +.vis-timeline .vis-item.job.failed.vis-selected { background-color: #FFA1B0; border-color: #FF4D6D; z-index: auto; @@ -183,12 +186,12 @@ rect.getting-result-time-proportion { stroke: #FF4D6D; } -.vis.timeline .item.job.running { +.vis-timeline .vis-item.job.running { background-color: #A2FCC0; border-color: #36F572; } -.vis.timeline .item.job.running.selected { +.vis-timeline .vis-item.job.running.vis-selected { background-color: #A2FCC0; border-color: #36F572; z-index: auto; @@ -199,7 +202,7 @@ rect.getting-result-time-proportion { stroke: #36F572; } -.vis.timeline .item.executor.added { +.vis-timeline .vis-item.executor.added { background-color: #A0DFFF; border-color: #3EC0FF; } @@ -209,7 +212,7 @@ rect.getting-result-time-proportion { stroke: #3EC0FF; } -.vis.timeline .item.executor.removed { +.vis-timeline .vis-item.executor.removed { background-color: #FFA1B0; border-color: #FF4D6D; } @@ -219,7 +222,7 @@ rect.getting-result-time-proportion { stroke: #FF4D6D; } -.vis.timeline .item.executor.selected { +.vis-timeline .vis-item.executor.vis-selected { background-color: #A2FCC0; border-color: #36F572; z-index: 2; @@ -258,15 +261,15 @@ span.expand-task-assignment-timeline { cursor: pointer; } -.vis.timeline .item.range .content { +.vis-timeline .vis-item.vis-range .vis-item-content { position: unset; } -.vis.timeline .item .tooltip-inner { +.vis-timeline .vis-item .tooltip-inner { max-width: unset !important; } -.vispanel.center { +.vis-panel.vis-center { font-size: 12px; line-height: 12px; } diff --git a/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js b/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js index f4453c71df1ea..a6153ceda75e2 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js +++ b/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -function drawApplicationTimeline(groupArray, eventObjArray, startTime) { +function drawApplicationTimeline(groupArray, eventObjArray, startTime, offset) { var groups = new vis.DataSet(groupArray); var items = new vis.DataSet(eventObjArray); var container = $("#application-timeline")[0]; @@ -26,7 +26,10 @@ function drawApplicationTimeline(groupArray, eventObjArray, startTime) { editable: false, showCurrentTime: false, min: startTime, - zoomable: false + zoomable: false, + moment: function (date) { + return vis.moment(date).utcOffset(offset); + } }; var applicationTimeline = new vis.Timeline(container); @@ -38,10 +41,10 @@ function drawApplicationTimeline(groupArray, eventObjArray, startTime) { setupExecutorEventAction(); function setupJobEventAction() { - $(".item.range.job.application-timeline-object").each(function() { + $(".vis-item.vis-range.job.application-timeline-object").each(function() { var getSelectorForJobEntry = function(baseElem) { var jobIdText = $($(baseElem).find(".application-timeline-content")[0]).text(); - var jobId = jobIdText.match("\\(Job (\\d+)\\)")[1]; + var jobId = jobIdText.match("\\(Job (\\d+)\\)$")[1]; return "#job-" + jobId; }; @@ -87,7 +90,7 @@ $(function (){ } }); -function drawJobTimeline(groupArray, eventObjArray, startTime) { +function drawJobTimeline(groupArray, eventObjArray, startTime, offset) { var groups = new vis.DataSet(groupArray); var items = new vis.DataSet(eventObjArray); var container = $('#job-timeline')[0]; @@ -99,6 +102,9 @@ function drawJobTimeline(groupArray, eventObjArray, startTime) { showCurrentTime: false, min: startTime, zoomable: false, + moment: function (date) { + return vis.moment(date).utcOffset(offset); + } }; var jobTimeline = new vis.Timeline(container); @@ -110,10 +116,10 @@ function drawJobTimeline(groupArray, eventObjArray, startTime) { setupExecutorEventAction(); function setupStageEventAction() { - $(".item.range.stage.job-timeline-object").each(function() { + $(".vis-item.vis-range.stage.job-timeline-object").each(function() { var getSelectorForStageEntry = function(baseElem) { var stageIdText = $($(baseElem).find(".job-timeline-content")[0]).text(); - var stageIdAndAttempt = stageIdText.match("\\(Stage (\\d+\\.\\d+)\\)")[1].split("."); + var stageIdAndAttempt = stageIdText.match("\\(Stage (\\d+\\.\\d+)\\)$")[1].split("."); return "#stage-" + stageIdAndAttempt[0] + "-" + stageIdAndAttempt[1]; }; @@ -159,7 +165,7 @@ $(function (){ } }); -function drawTaskAssignmentTimeline(groupArray, eventObjArray, minLaunchTime, maxFinishTime) { +function drawTaskAssignmentTimeline(groupArray, eventObjArray, minLaunchTime, maxFinishTime, offset) { var groups = new vis.DataSet(groupArray); var items = new vis.DataSet(eventObjArray); var container = $("#task-assignment-timeline")[0] @@ -173,7 +179,10 @@ function drawTaskAssignmentTimeline(groupArray, eventObjArray, minLaunchTime, ma showCurrentTime: false, min: minLaunchTime, max: maxFinishTime, - zoomable: false + zoomable: false, + moment: function (date) { + return vis.moment(date).utcOffset(offset); + } }; var taskTimeline = new vis.Timeline(container) @@ -224,7 +233,7 @@ $(function (){ }); function setupExecutorEventAction() { - $(".item.box.executor").each(function () { + $(".vis-item.vis-box.executor").each(function () { $(this).hover( function() { $($(this).find(".executor-event-content")[0]).tooltip("show"); diff --git a/core/src/main/resources/org/apache/spark/ui/static/vis.min.css b/core/src/main/resources/org/apache/spark/ui/static/vis.min.css index 
a390c40d67574..40d182cfde231 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/vis.min.css +++ b/core/src/main/resources/org/apache/spark/ui/static/vis.min.css @@ -1 +1 @@ -.vis .overlay{position:absolute;top:0;left:0;width:100%;height:100%;z-index:10}.vis-active{box-shadow:0 0 10px #86d5f8}.vis [class*=span]{min-height:0;width:auto}.vis.timeline.root{position:relative;border:1px solid #bfbfbf;overflow:hidden;padding:0;margin:0;box-sizing:border-box}.vis.timeline .vispanel{position:absolute;padding:0;margin:0;box-sizing:border-box}.vis.timeline .vispanel.bottom,.vis.timeline .vispanel.center,.vis.timeline .vispanel.left,.vis.timeline .vispanel.right,.vis.timeline .vispanel.top{border:1px #bfbfbf}.vis.timeline .vispanel.center,.vis.timeline .vispanel.left,.vis.timeline .vispanel.right{border-top-style:solid;border-bottom-style:solid;overflow:hidden}.vis.timeline .vispanel.bottom,.vis.timeline .vispanel.center,.vis.timeline .vispanel.top{border-left-style:solid;border-right-style:solid}.vis.timeline .background{overflow:hidden}.vis.timeline .vispanel>.content{position:relative}.vis.timeline .vispanel .shadow{position:absolute;width:100%;height:1px;box-shadow:0 0 10px rgba(0,0,0,.8)}.vis.timeline .vispanel .shadow.top{top:-1px;left:0}.vis.timeline .vispanel .shadow.bottom{bottom:-1px;left:0}.vis.timeline .labelset{position:relative;overflow:hidden;box-sizing:border-box}.vis.timeline .labelset .vlabel{position:relative;left:0;top:0;width:100%;color:#4d4d4d;box-sizing:border-box;border-bottom:1px solid #bfbfbf}.vis.timeline .labelset .vlabel:last-child{border-bottom:none}.vis.timeline .labelset .vlabel .inner{display:inline-block;padding:5px}.vis.timeline .labelset .vlabel .inner.hidden{padding:0}.vis.timeline .itemset{position:relative;padding:0;margin:0;box-sizing:border-box}.vis.timeline .itemset .background,.vis.timeline .itemset .foreground{position:absolute;width:100%;height:100%;overflow:visible}.vis.timeline .axis{position:absolute;width:100%;height:0;left:0;z-index:1}.vis.timeline .foreground .group{position:relative;box-sizing:border-box;border-bottom:1px solid #bfbfbf}.vis.timeline .foreground .group:last-child{border-bottom:none}.vis.timeline .item{position:absolute;color:#1A1A1A;border-color:#97B0F8;border-width:1px;background-color:#D5DDF6;display:inline-block;padding:5px}.vis.timeline .item.selected{border-color:#FFC200;background-color:#FFF785;z-index:2}.vis.timeline .editable .item.selected{cursor:move}.vis.timeline .item.point.selected{background-color:#FFF785}.vis.timeline .item.box{text-align:center;border-style:solid;border-radius:2px}.vis.timeline .item.point{background:0 0}.vis.timeline .item.dot{position:absolute;padding:0;border-width:4px;border-style:solid;border-radius:4px}.vis.timeline .item.range{border-style:solid;border-radius:2px;box-sizing:border-box}.vis.timeline .item.background{overflow:hidden;border:none;background-color:rgba(213,221,246,.4);box-sizing:border-box;padding:0;margin:0}.vis.timeline .item.range .content{position:relative;display:inline-block;max-width:100%;overflow:hidden}.vis.timeline .item.background .content{position:absolute;display:inline-block;overflow:hidden;max-width:100%;margin:5px}.vis.timeline .item.line{padding:0;position:absolute;width:0;border-left-width:1px;border-left-style:solid}.vis.timeline .item .content{white-space:nowrap;overflow:hidden}.vis.timeline .item .delete{background:url(img/timeline/delete.png) top center no-repeat;position:absolute;width:24px;height:24px;top:0;right:-24px;cursor:pointer}.vis.timeline 
.item.range .drag-left{position:absolute;width:24px;height:100%;top:0;left:-4px;cursor:w-resize}.vis.timeline .item.range .drag-right{position:absolute;width:24px;height:100%;top:0;right:-4px;cursor:e-resize}.vis.timeline .timeaxis{position:relative;overflow:hidden}.vis.timeline .timeaxis.foreground{top:0;left:0;width:100%}.vis.timeline .timeaxis.background{position:absolute;top:0;left:0;width:100%;height:100%}.vis.timeline .timeaxis .text{position:absolute;color:#4d4d4d;padding:3px;white-space:nowrap}.vis.timeline .timeaxis .text.measure{position:absolute;padding-left:0;padding-right:0;margin-left:0;margin-right:0;visibility:hidden}.vis.timeline .timeaxis .grid.vertical{position:absolute;border-left:1px solid}.vis.timeline .timeaxis .grid.minor{border-color:#e5e5e5}.vis.timeline .timeaxis .grid.major{border-color:#bfbfbf}.vis.timeline .currenttime{background-color:#FF7F6E;width:2px;z-index:1}.vis.timeline .customtime{background-color:#6E94FF;width:2px;cursor:move;z-index:1}.vis.timeline .vispanel.background.horizontal .grid.horizontal{position:absolute;width:100%;height:0;border-bottom:1px solid}.vis.timeline .vispanel.background.horizontal .grid.minor{border-color:#e5e5e5}.vis.timeline .vispanel.background.horizontal .grid.major{border-color:#bfbfbf}.vis.timeline .dataaxis .yAxis.major{width:100%;position:absolute;color:#4d4d4d;white-space:nowrap}.vis.timeline .dataaxis .yAxis.major.measure{padding:0;margin:0;border:0;visibility:hidden;width:auto}.vis.timeline .dataaxis .yAxis.minor{position:absolute;width:100%;color:#bebebe;white-space:nowrap}.vis.timeline .dataaxis .yAxis.minor.measure{padding:0;margin:0;border:0;visibility:hidden;width:auto}.vis.timeline .dataaxis .yAxis.title{position:absolute;color:#4d4d4d;white-space:nowrap;bottom:20px;text-align:center}.vis.timeline .dataaxis .yAxis.title.measure{padding:0;margin:0;visibility:hidden;width:auto}.vis.timeline .dataaxis .yAxis.title.left{bottom:0;-webkit-transform-origin:left top;-moz-transform-origin:left top;-ms-transform-origin:left top;-o-transform-origin:left top;transform-origin:left bottom;-webkit-transform:rotate(-90deg);-moz-transform:rotate(-90deg);-ms-transform:rotate(-90deg);-o-transform:rotate(-90deg);transform:rotate(-90deg)}.vis.timeline .dataaxis .yAxis.title.right{bottom:0;-webkit-transform-origin:right bottom;-moz-transform-origin:right bottom;-ms-transform-origin:right bottom;-o-transform-origin:right bottom;transform-origin:right bottom;-webkit-transform:rotate(90deg);-moz-transform:rotate(90deg);-ms-transform:rotate(90deg);-o-transform:rotate(90deg);transform:rotate(90deg)}.vis.timeline .legend{background-color:rgba(247,252,255,.65);padding:5px;border-color:#b3b3b3;border-style:solid;border-width:1px;box-shadow:2px 2px 10px rgba(154,154,154,.55)}.vis.timeline .legendText{white-space:nowrap;display:inline-block}.vis.timeline .graphGroup0{fill:#4f81bd;fill-opacity:0;stroke-width:2px;stroke:#4f81bd}.vis.timeline .graphGroup1{fill:#f79646;fill-opacity:0;stroke-width:2px;stroke:#f79646}.vis.timeline .graphGroup2{fill:#8c51cf;fill-opacity:0;stroke-width:2px;stroke:#8c51cf}.vis.timeline .graphGroup3{fill:#75c841;fill-opacity:0;stroke-width:2px;stroke:#75c841}.vis.timeline .graphGroup4{fill:#ff0100;fill-opacity:0;stroke-width:2px;stroke:#ff0100}.vis.timeline .graphGroup5{fill:#37d8e6;fill-opacity:0;stroke-width:2px;stroke:#37d8e6}.vis.timeline .graphGroup6{fill:#042662;fill-opacity:0;stroke-width:2px;stroke:#042662}.vis.timeline .graphGroup7{fill:#00ff26;fill-opacity:0;stroke-width:2px;stroke:#00ff26}.vis.timeline 
.graphGroup8{fill:#f0f;fill-opacity:0;stroke-width:2px;stroke:#f0f}.vis.timeline .graphGroup9{fill:#8f3938;fill-opacity:0;stroke-width:2px;stroke:#8f3938}.vis.timeline .fill{fill-opacity:.1;stroke:none}.vis.timeline .bar{fill-opacity:.5;stroke-width:1px}.vis.timeline .point{stroke-width:2px;fill-opacity:1}.vis.timeline .legendBackground{stroke-width:1px;fill-opacity:.9;fill:#fff;stroke:#c2c2c2}.vis.timeline .outline{stroke-width:1px;fill-opacity:1;fill:#fff;stroke:#e5e5e5}.vis.timeline .iconFill{fill-opacity:.3;stroke:none}div.network-manipulationDiv{border-width:0;border-bottom:1px;border-style:solid;border-color:#d6d9d8;background:#fff;background:-moz-linear-gradient(top,#fff 0,#fcfcfc 48%,#fafafa 50%,#fcfcfc 100%);background:-webkit-gradient(linear,left top,left bottom,color-stop(0,#fff),color-stop(48%,#fcfcfc),color-stop(50%,#fafafa),color-stop(100%,#fcfcfc));background:-webkit-linear-gradient(top,#fff 0,#fcfcfc 48%,#fafafa 50%,#fcfcfc 100%);background:-o-linear-gradient(top,#fff 0,#fcfcfc 48%,#fafafa 50%,#fcfcfc 100%);background:-ms-linear-gradient(top,#fff 0,#fcfcfc 48%,#fafafa 50%,#fcfcfc 100%);background:linear-gradient(to bottom,#fff 0,#fcfcfc 48%,#fafafa 50%,#fcfcfc 100%);filter:progid:DXImageTransform.Microsoft.gradient(startColorstr='#ffffff', endColorstr='#fcfcfc', GradientType=0);position:absolute;left:0;top:0;width:100%;height:30px}div.network-manipulation-editMode{position:absolute;left:0;top:0;height:30px;margin-top:20px}div.network-manipulation-closeDiv{position:absolute;right:0;top:0;width:30px;height:30px;background-position:20px 3px;background-repeat:no-repeat;background-image:url(img/network/cross.png);cursor:pointer;-webkit-touch-callout:none;-webkit-user-select:none;-khtml-user-select:none;-moz-user-select:none;-ms-user-select:none;user-select:none}div.network-manipulation-closeDiv:hover{opacity:.6}span.network-manipulationUI{font-family:verdana;font-size:12px;-moz-border-radius:15px;border-radius:15px;display:inline-block;background-position:0 0;background-repeat:no-repeat;height:24px;margin:-14px 0 0 10px;vertical-align:middle;cursor:pointer;padding:0 8px;-webkit-touch-callout:none;-webkit-user-select:none;-khtml-user-select:none;-moz-user-select:none;-ms-user-select:none;user-select:none}span.network-manipulationUI:hover{box-shadow:1px 1px 8px rgba(0,0,0,.2)}span.network-manipulationUI:active{box-shadow:1px 1px 8px rgba(0,0,0,.5)}span.network-manipulationUI.back{background-image:url(img/network/backIcon.png)}span.network-manipulationUI.none:hover{box-shadow:1px 1px 8px transparent;cursor:default}span.network-manipulationUI.none:active{box-shadow:1px 1px 8px transparent}span.network-manipulationUI.none{padding:0}span.network-manipulationUI.notification{margin:2px;font-weight:700}span.network-manipulationUI.add{background-image:url(img/network/addNodeIcon.png)}span.network-manipulationUI.edit{background-image:url(img/network/editIcon.png)}span.network-manipulationUI.edit.editmode{background-color:#fcfcfc;border-style:solid;border-width:1px;border-color:#ccc}span.network-manipulationUI.connect{background-image:url(img/network/connectIcon.png)}span.network-manipulationUI.delete{background-image:url(img/network/deleteIcon.png)}span.network-manipulationLabel{margin:0 0 0 23px;line-height:25px}div.network-seperatorLine{display:inline-block;width:1px;height:20px;background-color:#bdbdbd;margin:5px 7px 0 
15px}div.network-navigation_wrapper{position:absolute;left:0;top:0;width:100%;height:100%}div.network-navigation{width:34px;height:34px;-moz-border-radius:17px;border-radius:17px;position:absolute;display:inline-block;background-position:2px 2px;background-repeat:no-repeat;cursor:pointer;-webkit-touch-callout:none;-webkit-user-select:none;-khtml-user-select:none;-moz-user-select:none;-ms-user-select:none;user-select:none}div.network-navigation:hover{box-shadow:0 0 3px 3px rgba(56,207,21,.3)}div.network-navigation:active{box-shadow:0 0 1px 3px rgba(56,207,21,.95)}div.network-navigation.up{background-image:url(img/network/upArrow.png);bottom:50px;left:55px}div.network-navigation.down{background-image:url(img/network/downArrow.png);bottom:10px;left:55px}div.network-navigation.left{background-image:url(img/network/leftArrow.png);bottom:10px;left:15px}div.network-navigation.right{background-image:url(img/network/rightArrow.png);bottom:10px;left:95px}div.network-navigation.zoomIn{background-image:url(img/network/plus.png);bottom:10px;right:15px}div.network-navigation.zoomOut{background-image:url(img/network/minus.png);bottom:10px;right:55px}div.network-navigation.zoomExtends{background-image:url(img/network/zoomExtends.png);bottom:50px;right:15px} \ No newline at end of file +.vis-background,.vis-labelset,.vis-timeline{overflow:hidden}.vis .overlay{position:absolute;top:0;left:0;width:100%;height:100%;z-index:10}.vis-active{box-shadow:0 0 10px #86d5f8}.vis [class*=span]{min-height:0;width:auto}div.vis-configuration{position:relative;display:block;float:left;font-size:12px}div.vis-configuration-wrapper{display:block;width:700px}div.vis-configuration-wrapper::after{clear:both;content:"";display:block}div.vis-configuration.vis-config-option-container{display:block;width:495px;background-color:#fff;border:2px solid #f7f8fa;border-radius:4px;margin-top:20px;left:10px;padding-left:5px}div.vis-configuration.vis-config-button{display:block;width:495px;height:25px;vertical-align:middle;line-height:25px;background-color:#f7f8fa;border:2px solid #ceced0;border-radius:4px;margin-top:20px;left:10px;padding-left:5px;cursor:pointer;margin-bottom:30px}div.vis-configuration.vis-config-button.hover{background-color:#4588e6;border:2px solid #214373;color:#fff}div.vis-configuration.vis-config-item{display:block;float:left;width:495px;height:25px;vertical-align:middle;line-height:25px}div.vis-configuration.vis-config-item.vis-config-s2{left:10px;background-color:#f7f8fa;padding-left:5px;border-radius:3px}div.vis-configuration.vis-config-item.vis-config-s3{left:20px;background-color:#e4e9f0;padding-left:5px;border-radius:3px}div.vis-configuration.vis-config-item.vis-config-s4{left:30px;background-color:#cfd8e6;padding-left:5px;border-radius:3px}div.vis-configuration.vis-config-header{font-size:18px;font-weight:700}div.vis-configuration.vis-config-label{width:120px;height:25px;line-height:25px}div.vis-configuration.vis-config-label.vis-config-s3{width:110px}div.vis-configuration.vis-config-label.vis-config-s4{width:100px}div.vis-configuration.vis-config-colorBlock{top:1px;width:30px;height:19px;border:1px solid #444;border-radius:2px;padding:0;margin:0;cursor:pointer}input.vis-configuration.vis-config-checkbox{left:-5px}input.vis-configuration.vis-config-rangeinput{position:relative;top:-5px;width:60px;padding:1px;margin:0;pointer-events:none}.vis-panel,.vis-timeline{padding:0;box-sizing:border-box}input.vis-configuration.vis-config-range{-webkit-appearance:none;border:0 solid 
#fff;background-color:rgba(0,0,0,0);width:300px;height:20px}input.vis-configuration.vis-config-range::-webkit-slider-runnable-track{width:300px;height:5px;background:#dedede;background:-moz-linear-gradient(top,#dedede 0,#c8c8c8 99%);background:-webkit-gradient(linear,left top,left bottom,color-stop(0,#dedede),color-stop(99%,#c8c8c8));background:-webkit-linear-gradient(top,#dedede 0,#c8c8c8 99%);background:-o-linear-gradient(top,#dedede 0,#c8c8c8 99%);background:-ms-linear-gradient(top,#dedede 0,#c8c8c8 99%);background:linear-gradient(to bottom,#dedede 0,#c8c8c8 99%);filter:progid:DXImageTransform.Microsoft.gradient( startColorstr='#dedede', endColorstr='#c8c8c8', GradientType=0 );border:1px solid #999;box-shadow:#aaa 0 0 3px 0;border-radius:3px}input.vis-configuration.vis-config-range::-webkit-slider-thumb{-webkit-appearance:none;border:1px solid #14334b;height:17px;width:17px;border-radius:50%;background:#3876c2;background:-moz-linear-gradient(top,#3876c2 0,#385380 100%);background:-webkit-gradient(linear,left top,left bottom,color-stop(0,#3876c2),color-stop(100%,#385380));background:-webkit-linear-gradient(top,#3876c2 0,#385380 100%);background:-o-linear-gradient(top,#3876c2 0,#385380 100%);background:-ms-linear-gradient(top,#3876c2 0,#385380 100%);background:linear-gradient(to bottom,#3876c2 0,#385380 100%);filter:progid:DXImageTransform.Microsoft.gradient( startColorstr='#3876c2', endColorstr='#385380', GradientType=0 );box-shadow:#111927 0 0 1px 0;margin-top:-7px}input.vis-configuration.vis-config-range:focus{outline:0}input.vis-configuration.vis-config-range:focus::-webkit-slider-runnable-track{background:#9d9d9d;background:-moz-linear-gradient(top,#9d9d9d 0,#c8c8c8 99%);background:-webkit-gradient(linear,left top,left bottom,color-stop(0,#9d9d9d),color-stop(99%,#c8c8c8));background:-webkit-linear-gradient(top,#9d9d9d 0,#c8c8c8 99%);background:-o-linear-gradient(top,#9d9d9d 0,#c8c8c8 99%);background:-ms-linear-gradient(top,#9d9d9d 0,#c8c8c8 99%);background:linear-gradient(to bottom,#9d9d9d 0,#c8c8c8 99%);filter:progid:DXImageTransform.Microsoft.gradient( startColorstr='#9d9d9d', endColorstr='#c8c8c8', GradientType=0 )}input.vis-configuration.vis-config-range::-moz-range-track{width:300px;height:10px;background:#dedede;background:-moz-linear-gradient(top,#dedede 0,#c8c8c8 99%);background:-webkit-gradient(linear,left top,left bottom,color-stop(0,#dedede),color-stop(99%,#c8c8c8));background:-webkit-linear-gradient(top,#dedede 0,#c8c8c8 99%);background:-o-linear-gradient(top,#dedede 0,#c8c8c8 99%);background:-ms-linear-gradient(top,#dedede 0,#c8c8c8 99%);background:linear-gradient(to bottom,#dedede 0,#c8c8c8 99%);filter:progid:DXImageTransform.Microsoft.gradient( startColorstr='#dedede', endColorstr='#c8c8c8', GradientType=0 );border:1px solid #999;box-shadow:#aaa 0 0 3px 0;border-radius:3px}input.vis-configuration.vis-config-range::-moz-range-thumb{border:none;height:16px;width:16px;border-radius:50%;background:#385380}input.vis-configuration.vis-config-range:-moz-focusring{outline:#fff solid 1px;outline-offset:-1px}input.vis-configuration.vis-config-range::-ms-track{width:300px;height:5px;background:0 0;border-color:transparent;border-width:6px 
0;color:transparent}input.vis-configuration.vis-config-range::-ms-fill-lower{background:#777;border-radius:10px}input.vis-configuration.vis-config-range::-ms-fill-upper{background:#ddd;border-radius:10px}input.vis-configuration.vis-config-range::-ms-thumb{border:none;height:16px;width:16px;border-radius:50%;background:#385380}input.vis-configuration.vis-config-range:focus::-ms-fill-lower{background:#888}input.vis-configuration.vis-config-range:focus::-ms-fill-upper{background:#ccc}.vis-configuration-popup{position:absolute;background:rgba(57,76,89,.85);border:2px solid #f2faff;line-height:30px;height:30px;width:150px;text-align:center;color:#fff;font-size:14px;border-radius:4px;-webkit-transition:opacity .3s ease-in-out;-moz-transition:opacity .3s ease-in-out;transition:opacity .3s ease-in-out}.vis-configuration-popup:after,.vis-configuration-popup:before{left:100%;top:50%;border:solid transparent;content:" ";height:0;width:0;position:absolute;pointer-events:none}.vis-configuration-popup:after{border-color:rgba(136,183,213,0);border-left-color:rgba(57,76,89,.85);border-width:8px;margin-top:-8px}.vis-configuration-popup:before{border-color:rgba(194,225,245,0);border-left-color:#f2faff;border-width:12px;margin-top:-12px}.vis-timeline{position:relative;border:1px solid #bfbfbf;margin:0}.vis-panel{position:absolute;margin:0}.vis-panel.vis-bottom,.vis-panel.vis-center,.vis-panel.vis-left,.vis-panel.vis-right,.vis-panel.vis-top{border:1px #bfbfbf}.vis-panel.vis-center,.vis-panel.vis-left,.vis-panel.vis-right{border-top-style:solid;border-bottom-style:solid;overflow:hidden}.vis-panel.vis-bottom,.vis-panel.vis-center,.vis-panel.vis-top{border-left-style:solid;border-right-style:solid}.vis-panel>.vis-content{position:relative}.vis-panel .vis-shadow{position:absolute;width:100%;height:1px;box-shadow:0 0 10px rgba(0,0,0,.8)}.vis-itemset,.vis-labelset,.vis-labelset .vis-label{position:relative;box-sizing:border-box}.vis-panel .vis-shadow.vis-top{top:-1px;left:0}.vis-panel .vis-shadow.vis-bottom{bottom:-1px;left:0}.vis-labelset .vis-label{left:0;top:0;width:100%;color:#4d4d4d;border-bottom:1px solid #bfbfbf}.vis-labelset .vis-label.draggable{cursor:pointer}.vis-labelset .vis-label:last-child{border-bottom:none}.vis-labelset .vis-label .vis-inner{display:inline-block;padding:5px}.vis-labelset .vis-label .vis-inner.vis-hidden{padding:0}.vis-itemset{padding:0;margin:0}.vis-itemset .vis-background,.vis-itemset .vis-foreground{position:absolute;width:100%;height:100%;overflow:visible}.vis-axis{position:absolute;width:100%;height:0;left:0;z-index:1}.vis-foreground .vis-group{position:relative;box-sizing:border-box;border-bottom:1px solid #bfbfbf}.vis-foreground .vis-group:last-child{border-bottom:none}.vis-overlay{position:absolute;top:0;left:0;width:100%;height:100%;z-index:10}.vis-item{position:absolute;color:#1A1A1A;border-color:#97B0F8;border-width:1px;background-color:#D5DDF6;display:inline-block}.vis-item.vis-point.vis-selected,.vis-item.vis-selected{background-color:#FFF785}.vis-item.vis-selected{border-color:#FFC200;z-index:2}.vis-editable.vis-selected{cursor:move}.vis-item.vis-box{text-align:center;border-style:solid;border-radius:2px}.vis-item.vis-point{background:0 0}.vis-item.vis-dot{position:absolute;padding:0;border-width:4px;border-style:solid;border-radius:4px}.vis-item.vis-range{border-style:solid;border-radius:2px;box-sizing:border-box}.vis-item.vis-background{border:none;background-color:rgba(213,221,246,.4);box-sizing:border-box;padding:0;margin:0}.vis-item 
.vis-item-overflow{position:relative;width:100%;height:100%;padding:0;margin:0;overflow:hidden}.vis-item .vis-delete,.vis-item .vis-delete-rtl{background:url(img/timeline/delete.png) center no-repeat;height:24px;top:-4px;cursor:pointer}.vis-item.vis-range .vis-item-content{position:relative;display:inline-block}.vis-item.vis-background .vis-item-content{position:absolute;display:inline-block}.vis-item.vis-line{padding:0;position:absolute;width:0;border-left-width:1px;border-left-style:solid}.vis-item .vis-item-content{white-space:nowrap;box-sizing:border-box;padding:5px}.vis-item .vis-delete{position:absolute;width:24px;right:-24px}.vis-item .vis-delete-rtl{position:absolute;width:24px;left:-24px}.vis-item.vis-range .vis-drag-left{position:absolute;width:24px;max-width:20%;min-width:2px;height:100%;top:0;left:-4px;cursor:w-resize}.vis-item.vis-range .vis-drag-right{position:absolute;width:24px;max-width:20%;min-width:2px;height:100%;top:0;right:-4px;cursor:e-resize}.vis-range.vis-item.vis-readonly .vis-drag-left,.vis-range.vis-item.vis-readonly .vis-drag-right{cursor:auto}.vis-time-axis{position:relative;overflow:hidden}.vis-time-axis.vis-foreground{top:0;left:0;width:100%}.vis-time-axis.vis-background{position:absolute;top:0;left:0;width:100%;height:100%}.vis-time-axis .vis-text{position:absolute;color:#4d4d4d;padding:3px;overflow:hidden;box-sizing:border-box;white-space:nowrap}.vis-time-axis .vis-text.vis-measure{position:absolute;padding-left:0;padding-right:0;margin-left:0;margin-right:0;visibility:hidden}.vis-time-axis .vis-grid.vis-vertical{position:absolute;border-left:1px solid}.vis-time-axis .vis-grid.vis-vertical-rtl{position:absolute;border-right:1px solid}.vis-time-axis .vis-grid.vis-minor{border-color:#e5e5e5}.vis-time-axis .vis-grid.vis-major{border-color:#bfbfbf}.vis-current-time{background-color:#FF7F6E;width:2px;z-index:1}.vis-custom-time{background-color:#6E94FF;width:2px;cursor:move;z-index:1}div.vis-network div.vis-close,div.vis-network div.vis-edit-mode div.vis-button,div.vis-network div.vis-manipulation div.vis-button{cursor:pointer;-webkit-user-select:none;-moz-user-select:none;-ms-user-select:none;-webkit-touch-callout:none;-khtml-user-select:none}.vis-panel.vis-background.vis-horizontal .vis-grid.vis-horizontal{position:absolute;width:100%;height:0;border-bottom:1px solid}.vis-panel.vis-background.vis-horizontal .vis-grid.vis-minor{border-color:#e5e5e5}.vis-panel.vis-background.vis-horizontal .vis-grid.vis-major{border-color:#bfbfbf}.vis-data-axis .vis-y-axis.vis-major{width:100%;position:absolute;color:#4d4d4d;white-space:nowrap}.vis-data-axis .vis-y-axis.vis-major.vis-measure{padding:0;margin:0;border:0;visibility:hidden;width:auto}.vis-data-axis .vis-y-axis.vis-minor{position:absolute;width:100%;color:#bebebe;white-space:nowrap}.vis-data-axis .vis-y-axis.vis-minor.vis-measure{padding:0;margin:0;border:0;visibility:hidden;width:auto}.vis-data-axis .vis-y-axis.vis-title{position:absolute;color:#4d4d4d;white-space:nowrap;bottom:20px;text-align:center}.vis-data-axis .vis-y-axis.vis-title.vis-measure{padding:0;margin:0;visibility:hidden;width:auto}.vis-data-axis .vis-y-axis.vis-title.vis-left{bottom:0;-webkit-transform-origin:left top;-moz-transform-origin:left top;-ms-transform-origin:left top;-o-transform-origin:left top;transform-origin:left bottom;-webkit-transform:rotate(-90deg);-moz-transform:rotate(-90deg);-ms-transform:rotate(-90deg);-o-transform:rotate(-90deg);transform:rotate(-90deg)}.vis-data-axis 
.vis-y-axis.vis-title.vis-right{bottom:0;-webkit-transform-origin:right bottom;-moz-transform-origin:right bottom;-ms-transform-origin:right bottom;-o-transform-origin:right bottom;transform-origin:right bottom;-webkit-transform:rotate(90deg);-moz-transform:rotate(90deg);-ms-transform:rotate(90deg);-o-transform:rotate(90deg);transform:rotate(90deg)}.vis-legend{background-color:rgba(247,252,255,.65);padding:5px;border:1px solid #b3b3b3;box-shadow:2px 2px 10px rgba(154,154,154,.55)}.vis-legend-text{white-space:nowrap;display:inline-block}.vis-graph-group0{fill:#4f81bd;fill-opacity:0;stroke-width:2px;stroke:#4f81bd}.vis-graph-group1{fill:#f79646;fill-opacity:0;stroke-width:2px;stroke:#f79646}.vis-graph-group2{fill:#8c51cf;fill-opacity:0;stroke-width:2px;stroke:#8c51cf}.vis-graph-group3{fill:#75c841;fill-opacity:0;stroke-width:2px;stroke:#75c841}.vis-graph-group4{fill:#ff0100;fill-opacity:0;stroke-width:2px;stroke:#ff0100}.vis-graph-group5{fill:#37d8e6;fill-opacity:0;stroke-width:2px;stroke:#37d8e6}.vis-graph-group6{fill:#042662;fill-opacity:0;stroke-width:2px;stroke:#042662}.vis-graph-group7{fill:#00ff26;fill-opacity:0;stroke-width:2px;stroke:#00ff26}.vis-graph-group8{fill:#f0f;fill-opacity:0;stroke-width:2px;stroke:#f0f}.vis-graph-group9{fill:#8f3938;fill-opacity:0;stroke-width:2px;stroke:#8f3938}.vis-timeline .vis-fill{fill-opacity:.1;stroke:none}.vis-timeline .vis-bar{fill-opacity:.5;stroke-width:1px}.vis-timeline .vis-point{stroke-width:2px;fill-opacity:1}.vis-timeline .vis-legend-background{stroke-width:1px;fill-opacity:.9;fill:#fff;stroke:#c2c2c2}.vis-timeline .vis-outline{stroke-width:1px;fill-opacity:1;fill:#fff;stroke:#e5e5e5}.vis-timeline .vis-icon-fill{fill-opacity:.3;stroke:none}div.vis-network div.vis-manipulation{border-width:0;border-bottom:1px;border-style:solid;border-color:#d6d9d8;background:#fff;background:-moz-linear-gradient(top,#fff 0,#fcfcfc 48%,#fafafa 50%,#fcfcfc 100%);background:-webkit-gradient(linear,left top,left bottom,color-stop(0,#fff),color-stop(48%,#fcfcfc),color-stop(50%,#fafafa),color-stop(100%,#fcfcfc));background:-webkit-linear-gradient(top,#fff 0,#fcfcfc 48%,#fafafa 50%,#fcfcfc 100%);background:-o-linear-gradient(top,#fff 0,#fcfcfc 48%,#fafafa 50%,#fcfcfc 100%);background:-ms-linear-gradient(top,#fff 0,#fcfcfc 48%,#fafafa 50%,#fcfcfc 100%);background:linear-gradient(to bottom,#fff 0,#fcfcfc 48%,#fafafa 50%,#fcfcfc 100%);filter:progid:DXImageTransform.Microsoft.gradient( startColorstr='#ffffff', endColorstr='#fcfcfc', GradientType=0 );padding-top:4px;position:absolute;left:0;top:0;width:100%;height:28px}div.vis-network div.vis-edit-mode{position:absolute;left:0;top:5px;height:30px}div.vis-network div.vis-close{position:absolute;right:0;top:0;width:30px;height:30px;background-position:20px 3px;background-repeat:no-repeat;background-image:url(img/network/cross.png);user-select:none}div.vis-network div.vis-close:hover{opacity:.6}div.vis-network div.vis-edit-mode div.vis-button,div.vis-network div.vis-manipulation div.vis-button{float:left;font-family:verdana;font-size:12px;-moz-border-radius:15px;border-radius:15px;display:inline-block;background-position:0 0;background-repeat:no-repeat;height:24px;margin-left:10px;padding:0 8px;user-select:none}div.vis-network div.vis-manipulation div.vis-button:hover{box-shadow:1px 1px 8px rgba(0,0,0,.2)}div.vis-network div.vis-manipulation div.vis-button:active{box-shadow:1px 1px 8px rgba(0,0,0,.5)}div.vis-network div.vis-manipulation div.vis-button.vis-back{background-image:url(img/network/backIcon.png)}div.vis-network 
div.vis-manipulation div.vis-button.vis-none:hover{box-shadow:1px 1px 8px transparent;cursor:default}div.vis-network div.vis-manipulation div.vis-button.vis-none:active{box-shadow:1px 1px 8px transparent}div.vis-network div.vis-manipulation div.vis-button.vis-none{padding:0}div.vis-network div.vis-manipulation div.notification{margin:2px;font-weight:700}div.vis-network div.vis-manipulation div.vis-button.vis-add{background-image:url(img/network/addNodeIcon.png)}div.vis-network div.vis-edit-mode div.vis-button.vis-edit,div.vis-network div.vis-manipulation div.vis-button.vis-edit{background-image:url(img/network/editIcon.png)}div.vis-network div.vis-edit-mode div.vis-button.vis-edit.vis-edit-mode{background-color:#fcfcfc;border:1px solid #ccc}div.vis-network div.vis-manipulation div.vis-button.vis-connect{background-image:url(img/network/connectIcon.png)}div.vis-network div.vis-manipulation div.vis-button.vis-delete{background-image:url(img/network/deleteIcon.png)}div.vis-network div.vis-edit-mode div.vis-label,div.vis-network div.vis-manipulation div.vis-label{margin:0 0 0 23px;line-height:25px}div.vis-network div.vis-manipulation div.vis-separator-line{float:left;display:inline-block;width:1px;height:21px;background-color:#bdbdbd;margin:0 7px 0 15px}div.vis-network-tooltip{position:absolute;visibility:hidden;padding:5px;white-space:nowrap;font-family:verdana;font-size:14px;color:#000;background-color:#f5f4ed;-moz-border-radius:3px;-webkit-border-radius:3px;border-radius:3px;border:1px solid #808074;box-shadow:3px 3px 10px rgba(0,0,0,.2);pointer-events:none}div.vis-network div.vis-navigation div.vis-button{width:34px;height:34px;-moz-border-radius:17px;border-radius:17px;position:absolute;display:inline-block;background-position:2px 2px;background-repeat:no-repeat;cursor:pointer;-webkit-touch-callout:none;-webkit-user-select:none;-khtml-user-select:none;-moz-user-select:none;-ms-user-select:none;user-select:none}div.vis-network div.vis-navigation div.vis-button:hover{box-shadow:0 0 3px 3px rgba(56,207,21,.3)}div.vis-network div.vis-navigation div.vis-button:active{box-shadow:0 0 1px 3px rgba(56,207,21,.95)}div.vis-network div.vis-navigation div.vis-button.vis-up{background-image:url(img/network/upArrow.png);bottom:50px;left:55px}div.vis-network div.vis-navigation div.vis-button.vis-down{background-image:url(img/network/downArrow.png);bottom:10px;left:55px}div.vis-network div.vis-navigation div.vis-button.vis-left{background-image:url(img/network/leftArrow.png);bottom:10px;left:15px}div.vis-network div.vis-navigation div.vis-button.vis-right{background-image:url(img/network/rightArrow.png);bottom:10px;left:95px}div.vis-network div.vis-navigation div.vis-button.vis-zoomIn{background-image:url(img/network/plus.png);bottom:10px;right:15px}div.vis-network div.vis-navigation div.vis-button.vis-zoomOut{background-image:url(img/network/minus.png);bottom:10px;right:55px}div.vis-network div.vis-navigation div.vis-button.vis-zoomExtends{background-image:url(img/network/zoomExtends.png);bottom:50px;right:15px}div.vis-color-picker{position:absolute;top:0;left:30px;margin-top:-140px;margin-left:30px;width:310px;height:444px;z-index:1;padding:10px;border-radius:15px;background-color:#fff;display:none;box-shadow:rgba(0,0,0,.5) 0 0 10px 0}div.vis-color-picker div.vis-arrow{position:absolute;top:147px;left:5px}div.vis-color-picker div.vis-arrow::after,div.vis-color-picker div.vis-arrow::before{right:100%;top:50%;border:solid transparent;content:" 
";height:0;width:0;position:absolute;pointer-events:none}div.vis-color-picker div.vis-arrow:after{border-color:rgba(255,255,255,0);border-right-color:#fff;border-width:30px;margin-top:-30px}div.vis-color-picker div.vis-color{position:absolute;width:289px;height:289px;cursor:pointer}div.vis-color-picker div.vis-brightness{position:absolute;top:313px}div.vis-color-picker div.vis-opacity{position:absolute;top:350px}div.vis-color-picker div.vis-selector{position:absolute;top:137px;left:137px;width:15px;height:15px;border-radius:15px;border:1px solid #fff;background:#4c4c4c;background:-moz-linear-gradient(top,#4c4c4c 0,#595959 12%,#666 25%,#474747 39%,#2c2c2c 50%,#000 51%,#111 60%,#2b2b2b 76%,#1c1c1c 91%,#131313 100%);background:-webkit-gradient(linear,left top,left bottom,color-stop(0,#4c4c4c),color-stop(12%,#595959),color-stop(25%,#666),color-stop(39%,#474747),color-stop(50%,#2c2c2c),color-stop(51%,#000),color-stop(60%,#111),color-stop(76%,#2b2b2b),color-stop(91%,#1c1c1c),color-stop(100%,#131313));background:-webkit-linear-gradient(top,#4c4c4c 0,#595959 12%,#666 25%,#474747 39%,#2c2c2c 50%,#000 51%,#111 60%,#2b2b2b 76%,#1c1c1c 91%,#131313 100%);background:-o-linear-gradient(top,#4c4c4c 0,#595959 12%,#666 25%,#474747 39%,#2c2c2c 50%,#000 51%,#111 60%,#2b2b2b 76%,#1c1c1c 91%,#131313 100%);background:-ms-linear-gradient(top,#4c4c4c 0,#595959 12%,#666 25%,#474747 39%,#2c2c2c 50%,#000 51%,#111 60%,#2b2b2b 76%,#1c1c1c 91%,#131313 100%);background:linear-gradient(to bottom,#4c4c4c 0,#595959 12%,#666 25%,#474747 39%,#2c2c2c 50%,#000 51%,#111 60%,#2b2b2b 76%,#1c1c1c 91%,#131313 100%);filter:progid:DXImageTransform.Microsoft.gradient( startColorstr='#4c4c4c', endColorstr='#131313', GradientType=0 )}div.vis-color-picker div.vis-initial-color,div.vis-color-picker div.vis-new-color{width:140px;height:20px;top:380px;font-size:10px;color:rgba(0,0,0,.4);line-height:20px;position:absolute;vertical-align:middle}div.vis-color-picker div.vis-new-color{border:1px solid rgba(0,0,0,.1);border-radius:5px;left:159px;text-align:right;padding-right:2px}div.vis-color-picker div.vis-initial-color{border:1px solid rgba(0,0,0,.1);border-radius:5px;left:10px;text-align:left;padding-left:2px}div.vis-color-picker div.vis-label{position:absolute;width:300px;left:10px}div.vis-color-picker div.vis-label.vis-brightness{top:300px}div.vis-color-picker div.vis-label.vis-opacity{top:338px}div.vis-color-picker div.vis-button{position:absolute;width:68px;height:25px;border-radius:10px;vertical-align:middle;text-align:center;line-height:25px;top:410px;border:2px solid #d9d9d9;background-color:#f7f7f7;cursor:pointer}div.vis-color-picker div.vis-button.vis-cancel{left:5px}div.vis-color-picker div.vis-button.vis-load{left:82px}div.vis-color-picker div.vis-button.vis-apply{left:159px}div.vis-color-picker div.vis-button.vis-save{left:236px}div.vis-color-picker input.vis-range{width:290px;height:20px} \ No newline at end of file diff --git a/core/src/main/resources/org/apache/spark/ui/static/vis.min.js b/core/src/main/resources/org/apache/spark/ui/static/vis.min.js index 2b3b1d60463f7..92b8ed75d85fc 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/vis.min.js +++ b/core/src/main/resources/org/apache/spark/ui/static/vis.min.js @@ -4,11 +4,11 @@ * * A dynamic, browser-based visualization library. 
* - * @version 3.9.0 - * @date 2015-01-16 + * @version 4.16.1 + * @date 2016-04-18 * * @license - * Copyright (C) 2011-2014 Almende B.V, http://almende.com + * Copyright (C) 2011-2016 Almende B.V, http://almende.com * * Vis.js is dual licensed under both * @@ -22,17 +22,24 @@ * * Vis.js may be distributed under either license. */ -"use strict";!function(t,e){"object"==typeof exports&&"object"==typeof module?module.exports=e():"function"==typeof define&&define.amd?define(e):"object"==typeof exports?exports.vis=e():t.vis=e()}(this,function(){return function(t){function e(s){if(i[s])return i[s].exports;var o=i[s]={exports:{},id:s,loaded:!1};return t[s].call(o.exports,o,o.exports,e),o.loaded=!0,o.exports}var i={};return e.m=t,e.c=i,e.p="",e(0)}([function(t,e,i){e.util=i(1),e.DOMutil=i(2),e.DataSet=i(3),e.DataView=i(4),e.Queue=i(5),e.Graph3d=i(6),e.graph3d={Camera:i(7),Filter:i(8),Point2d:i(9),Point3d:i(10),Slider:i(11),StepNumber:i(12)},e.Timeline=i(13),e.Graph2d=i(14),e.timeline={DateUtil:i(15),DataStep:i(16),Range:i(17),stack:i(18),TimeStep:i(19),components:{items:{Item:i(31),BackgroundItem:i(32),BoxItem:i(33),PointItem:i(34),RangeItem:i(35)},Component:i(20),CurrentTime:i(21),CustomTime:i(22),DataAxis:i(23),GraphGroup:i(24),Group:i(25),BackgroundGroup:i(26),ItemSet:i(27),Legend:i(28),LineGraph:i(29),TimeAxis:i(30)}},e.Network=i(36),e.network={Edge:i(37),Groups:i(38),Images:i(39),Node:i(40),Popup:i(41),dotparser:i(42),gephiParser:i(43)},e.Graph=function(){throw new Error("Graph is renamed to Network. Please create a graph as new vis.Network(...)")},e.moment=i(44),e.hammer=i(45),e.Hammer=i(45)},function(t,e,i){var s=i(44);e.isNumber=function(t){return t instanceof Number||"number"==typeof t},e.isString=function(t){return t instanceof String||"string"==typeof t},e.isDate=function(t){if(t instanceof Date)return!0;if(e.isString(t)){var i=o.exec(t);if(i)return!0;if(!isNaN(Date.parse(t)))return!0}return!1},e.isDataTable=function(t){return"undefined"!=typeof google&&google.visualization&&google.visualization.DataTable&&t instanceof google.visualization.DataTable},e.randomUUID=function(){var t=function(){return Math.floor(65536*Math.random()).toString(16)};return t()+t()+"-"+t()+"-"+t()+"-"+t()+"-"+t()+t()+t()},e.extend=function(t){for(var e=1,i=arguments.length;i>e;e++){var s=arguments[e];for(var o in s)s.hasOwnProperty(o)&&(t[o]=s[o])}return t},e.selectiveExtend=function(t,e){if(!Array.isArray(t))throw new Error("Array with property names expected as first argument");for(var i=2;ii;i++)if(t[i]!=e[i])return!1;return!0},e.convert=function(t,i){var n;if(void 0===t)return void 0;if(null===t)return null;if(!i)return t;if("string"!=typeof i&&!(i instanceof String))throw new Error("Type must be a string");switch(i){case"boolean":case"Boolean":return Boolean(t);case"number":case"Number":return Number(t.valueOf());case"string":case"String":return String(t);case"Date":if(e.isNumber(t))return new Date(t);if(t instanceof Date)return new Date(t.valueOf());if(s.isMoment(t))return new Date(t.valueOf());if(e.isString(t))return n=o.exec(t),n?new Date(Number(n[1])):s(t).toDate();throw new Error("Cannot convert object of type "+e.getType(t)+" to type Date");case"Moment":if(e.isNumber(t))return s(t);if(t instanceof Date)return s(t.valueOf());if(s.isMoment(t))return s(t);if(e.isString(t))return n=o.exec(t),s(n?Number(n[1]):t);throw new Error("Cannot convert object of type "+e.getType(t)+" to type Date");case"ISODate":if(e.isNumber(t))return new Date(t);if(t instanceof Date)return t.toISOString();if(s.isMoment(t))return 
t.toDate().toISOString();if(e.isString(t))return n=o.exec(t),n?new Date(Number(n[1])).toISOString():new Date(t).toISOString();throw new Error("Cannot convert object of type "+e.getType(t)+" to type ISODate");case"ASPDate":if(e.isNumber(t))return"/Date("+t+")/";if(t instanceof Date)return"/Date("+t.valueOf()+")/";if(e.isString(t)){n=o.exec(t);var r;return r=n?new Date(Number(n[1])).valueOf():new Date(t).valueOf(),"/Date("+r+")/"}throw new Error("Cannot convert object of type "+e.getType(t)+" to type ASPDate");default:throw new Error('Unknown type "'+i+'"')}};var o=/^\/?Date\((\-?\d+)/i;e.getType=function(t){var e=typeof t;return"object"==e?null==t?"null":t instanceof Boolean?"Boolean":t instanceof Number?"Number":t instanceof String?"String":Array.isArray(t)?"Array":t instanceof Date?"Date":"Object":"number"==e?"Number":"boolean"==e?"Boolean":"string"==e?"String":e},e.getAbsoluteLeft=function(t){return t.getBoundingClientRect().left},e.getAbsoluteTop=function(t){return t.getBoundingClientRect().top},e.addClassName=function(t,e){var i=t.className.split(" ");-1==i.indexOf(e)&&(i.push(e),t.className=i.join(" "))},e.removeClassName=function(t,e){var i=t.className.split(" "),s=i.indexOf(e);-1!=s&&(i.splice(s,1),t.className=i.join(" "))},e.forEach=function(t,e){var i,s;if(Array.isArray(t))for(i=0,s=t.length;s>i;i++)e(t[i],i,t);else for(i in t)t.hasOwnProperty(i)&&e(t[i],i,t)},e.toArray=function(t){var e=[];for(var i in t)t.hasOwnProperty(i)&&e.push(t[i]);return e},e.updateProperty=function(t,e,i){return t[e]!==i?(t[e]=i,!0):!1},e.addEventListener=function(t,e,i,s){t.addEventListener?(void 0===s&&(s=!1),"mousewheel"===e&&navigator.userAgent.indexOf("Firefox")>=0&&(e="DOMMouseScroll"),t.addEventListener(e,i,s)):t.attachEvent("on"+e,i)},e.removeEventListener=function(t,e,i,s){t.removeEventListener?(void 0===s&&(s=!1),"mousewheel"===e&&navigator.userAgent.indexOf("Firefox")>=0&&(e="DOMMouseScroll"),t.removeEventListener(e,i,s)):t.detachEvent("on"+e,i)},e.preventDefault=function(t){t||(t=window.event),t.preventDefault?t.preventDefault():t.returnValue=!1},e.getTarget=function(t){t||(t=window.event);var e;return t.target?e=t.target:t.srcElement&&(e=t.srcElement),void 0!=e.nodeType&&3==e.nodeType&&(e=e.parentNode),e},e.option={},e.option.asBoolean=function(t,e){return"function"==typeof t&&(t=t()),null!=t?0!=t:e||null},e.option.asNumber=function(t,e){return"function"==typeof t&&(t=t()),null!=t?Number(t)||e||null:e||null},e.option.asString=function(t,e){return"function"==typeof t&&(t=t()),null!=t?String(t):e||null},e.option.asSize=function(t,i){return"function"==typeof t&&(t=t()),e.isString(t)?t:e.isNumber(t)?t+"px":i||null},e.option.asElement=function(t,e){return"function"==typeof t&&(t=t()),t||e||null},e.hexToRGB=function(t){var e=/^#?([a-f\d])([a-f\d])([a-f\d])$/i;t=t.replace(e,function(t,e,i,s){return e+e+i+i+s+s});var i=/^#?([a-f\d]{2})([a-f\d]{2})([a-f\d]{2})$/i.exec(t);return i?{r:parseInt(i[1],16),g:parseInt(i[2],16),b:parseInt(i[3],16)}:null},e.RGBToHex=function(t,e,i){return"#"+((1<<24)+(t<<16)+(e<<8)+i).toString(16).slice(1)},e.parseColor=function(t){var i;if(e.isString(t)){if(e.isValidRGB(t)){var s=t.substr(4).substr(0,t.length-5).split(",");t=e.RGBToHex(s[0],s[1],s[2])}if(e.isValidHex(t)){var o=e.hexToHSV(t),n={h:o.h,s:.45*o.s,v:Math.min(1,1.05*o.v)},r={h:o.h,s:Math.min(1,1.25*o.v),v:.6*o.v},a=e.HSVToHex(r.h,r.h,r.v),h=e.HSVToHex(n.h,n.s,n.v);i={background:t,border:a,highlight:{background:h,border:a},hover:{background:h,border:a}}}else 
i={background:t,border:t,highlight:{background:t,border:t},hover:{background:t,border:t}}}else i={},i.background=t.background||"white",i.border=t.border||i.background,e.isString(t.highlight)?i.highlight={border:t.highlight,background:t.highlight}:(i.highlight={},i.highlight.background=t.highlight&&t.highlight.background||i.background,i.highlight.border=t.highlight&&t.highlight.border||i.border),e.isString(t.hover)?i.hover={border:t.hover,background:t.hover}:(i.hover={},i.hover.background=t.hover&&t.hover.background||i.background,i.hover.border=t.hover&&t.hover.border||i.border);return i},e.RGBToHSV=function(t,e,i){t/=255,e/=255,i/=255;var s=Math.min(t,Math.min(e,i)),o=Math.max(t,Math.max(e,i));if(s==o)return{h:0,s:0,v:s};var n=t==s?e-i:i==s?t-e:i-t,r=t==s?3:i==s?1:5,a=60*(r-n/(o-s))/360,h=(o-s)/o,d=o;return{h:a,s:h,v:d}};var n={split:function(t){var e={};return t.split(";").forEach(function(t){if(""!=t.trim()){var i=t.split(":"),s=i[0].trim(),o=i[1].trim();e[s]=o}}),e},join:function(t){return Object.keys(t).map(function(e){return e+": "+t[e]}).join("; ")}};e.addCssText=function(t,i){var s=n.split(t.style.cssText),o=n.split(i),r=e.extend(s,o);t.style.cssText=n.join(r)},e.removeCssText=function(t,e){var i=n.split(t.style.cssText),s=n.split(e);for(var o in s)s.hasOwnProperty(o)&&delete i[o];t.style.cssText=n.join(i)},e.HSVToRGB=function(t,e,i){var s,o,n,r=Math.floor(6*t),a=6*t-r,h=i*(1-e),d=i*(1-a*e),l=i*(1-(1-a)*e);switch(r%6){case 0:s=i,o=l,n=h;break;case 1:s=d,o=i,n=h;break;case 2:s=h,o=i,n=l;break;case 3:s=h,o=d,n=i;break;case 4:s=l,o=h,n=i;break;case 5:s=i,o=h,n=d}return{r:Math.floor(255*s),g:Math.floor(255*o),b:Math.floor(255*n)}},e.HSVToHex=function(t,i,s){var o=e.HSVToRGB(t,i,s);return e.RGBToHex(o.r,o.g,o.b)},e.hexToHSV=function(t){var i=e.hexToRGB(t);return e.RGBToHSV(i.r,i.g,i.b)},e.isValidHex=function(t){var e=/(^#[0-9A-F]{6}$)|(^#[0-9A-F]{3}$)/i.test(t);return e},e.isValidRGB=function(t){t=t.replace(" ","");var e=/rgb\((\d{1,3}),(\d{1,3}),(\d{1,3})\)/i.test(t);return e},e.selectiveBridgeObject=function(t,i){if("object"==typeof i){for(var s=Object.create(i),o=0;o=r&&o>n;){var h=Math.floor((r+a)/2),d=t[h],l=void 0===s?d[i]:d[i][s],c=e(l);if(0==c)return h;-1==c?r=h+1:a=h-1,n++}return-1},e.binarySearchValue=function(t,e,i,s){for(var o,n,r,a,h=1e4,d=0,l=0,c=t.length-1;c>=l&&h>d;){if(a=Math.floor(.5*(c+l)),o=t[Math.max(0,a-1)][i],n=t[a][i],r=t[Math.min(t.length-1,a+1)][i],n==e)return a;if(e>o&&n>e)return"before"==s?Math.max(0,a-1):a;if(e>n&&r>e)return"before"==s?a:Math.min(t.length-1,a+1);e>n?l=a+1:c=a-1,d++}return-1},e.easeInOutQuad=function(t,e,i,s){var o=i-e;return t/=s/2,1>t?o/2*t*t+e:(t--,-o/2*(t*(t-2)-1)+e)},e.easingFunctions={linear:function(t){return t},easeInQuad:function(t){return t*t},easeOutQuad:function(t){return t*(2-t)},easeInOutQuad:function(t){return.5>t?2*t*t:-1+(4-2*t)*t},easeInCubic:function(t){return t*t*t},easeOutCubic:function(t){return--t*t*t+1},easeInOutCubic:function(t){return.5>t?4*t*t*t:(t-1)*(2*t-2)*(2*t-2)+1},easeInQuart:function(t){return t*t*t*t},easeOutQuart:function(t){return 1- --t*t*t*t},easeInOutQuart:function(t){return.5>t?8*t*t*t*t:1-8*--t*t*t*t},easeInQuint:function(t){return t*t*t*t*t},easeOutQuint:function(t){return 1+--t*t*t*t*t},easeInOutQuint:function(t){return.5>t?16*t*t*t*t*t:1+16*--t*t*t*t*t}}},function(t,e){e.prepareElements=function(t){for(var e in t)t.hasOwnProperty(e)&&(t[e].redundant=t[e].used,t[e].used=[])},e.cleanupElements=function(t){for(var e in t)if(t.hasOwnProperty(e)&&t[e].redundant){for(var 
i=0;i0?(s=e[t].redundant[0],e[t].redundant.shift()):(s=document.createElementNS("http://www.w3.org/2000/svg",t),i.appendChild(s)):(s=document.createElementNS("http://www.w3.org/2000/svg",t),e[t]={used:[],redundant:[]},i.appendChild(s)),e[t].used.push(s),s},e.getDOMElement=function(t,e,i,s){var o;return e.hasOwnProperty(t)?e[t].redundant.length>0?(o=e[t].redundant[0],e[t].redundant.shift()):(o=document.createElement(t),void 0!==s?i.insertBefore(o,s):i.appendChild(o)):(o=document.createElement(t),e[t]={used:[],redundant:[]},void 0!==s?i.insertBefore(o,s):i.appendChild(o)),e[t].used.push(o),o},e.drawPoint=function(t,i,s,o,n){var r;return"circle"==s.options.drawPoints.style?(r=e.getSVGElement("circle",o,n),r.setAttributeNS(null,"cx",t),r.setAttributeNS(null,"cy",i),r.setAttributeNS(null,"r",.5*s.options.drawPoints.size)):(r=e.getSVGElement("rect",o,n),r.setAttributeNS(null,"x",t-.5*s.options.drawPoints.size),r.setAttributeNS(null,"y",i-.5*s.options.drawPoints.size),r.setAttributeNS(null,"width",s.options.drawPoints.size),r.setAttributeNS(null,"height",s.options.drawPoints.size)),void 0!==s.options.drawPoints.styles&&r.setAttributeNS(null,"style",s.group.options.drawPoints.styles),r.setAttributeNS(null,"class",s.className+" point"),r},e.drawBar=function(t,i,s,o,n,r,a){if(0!=o){0>o&&(o*=-1,i-=o);var h=e.getSVGElement("rect",r,a);h.setAttributeNS(null,"x",t-.5*s),h.setAttributeNS(null,"y",i),h.setAttributeNS(null,"width",s),h.setAttributeNS(null,"height",o),h.setAttributeNS(null,"class",n)}}},function(t,e,i){function s(t,e){if(!t||Array.isArray(t)||o.isDataTable(t)||(e=t,t=null),this._options=e||{},this._data={},this._fieldId=this._options.fieldId||"id",this._type={},this._options.type)for(var i in this._options.type)if(this._options.type.hasOwnProperty(i)){var s=this._options.type[i];this._type[i]="Date"==s||"ISODate"==s||"ASPDate"==s?"Date":s}if(this._options.convert)throw new Error('Option "convert" is deprecated. 
Use "type" instead.');this._subscribers={},t&&this.add(t),this.setOptions(e)}var o=i(1),n=i(5);s.prototype.setOptions=function(t){t&&void 0!==t.queue&&(t.queue===!1?this._queue&&(this._queue.destroy(),delete this._queue):(this._queue||(this._queue=n.extend(this,{replace:["add","update","remove"]})),"object"==typeof t.queue&&this._queue.setOptions(t.queue)))},s.prototype.on=function(t,e){var i=this._subscribers[t];i||(i=[],this._subscribers[t]=i),i.push({callback:e})},s.prototype.subscribe=s.prototype.on,s.prototype.off=function(t,e){var i=this._subscribers[t];i&&(this._subscribers[t]=i.filter(function(t){return t.callback!=e}))},s.prototype.unsubscribe=s.prototype.off,s.prototype._trigger=function(t,e,i){if("*"==t)throw new Error("Cannot trigger event *");var s=[];t in this._subscribers&&(s=s.concat(this._subscribers[t])),"*"in this._subscribers&&(s=s.concat(this._subscribers["*"]));for(var o=0;or;r++)i=n._addItem(t[r]),s.push(i);else if(o.isDataTable(t))for(var h=this._getColumnNames(t),d=0,l=t.getNumberOfRows();l>d;d++){for(var c={},p=0,u=h.length;u>p;p++){var m=h[p];c[m]=t.getValue(d,p)}i=n._addItem(c),s.push(i)}else{if(!(t instanceof Object))throw new Error("Unknown dataType");i=n._addItem(t),s.push(i)}return s.length&&this._trigger("add",{items:s},e),s},s.prototype.update=function(t,e){var i=[],s=[],n=[],r=this,a=r._fieldId,h=function(t){var e=t[a];r._data[e]?(e=r._updateItem(t),s.push(e),n.push(t)):(e=r._addItem(t),i.push(e))};if(Array.isArray(t))for(var d=0,l=t.length;l>d;d++)h(t[d]);else if(o.isDataTable(t))for(var c=this._getColumnNames(t),p=0,u=t.getNumberOfRows();u>p;p++){for(var m={},f=0,g=c.length;g>f;f++){var v=c[f];m[v]=t.getValue(p,f)}h(m)}else{if(!(t instanceof Object))throw new Error("Unknown dataType");h(t)}return i.length&&this._trigger("add",{items:i},e),s.length&&this._trigger("update",{items:s,data:n},e),i.concat(s)},s.prototype.get=function(){var t,e,i,s,n=this,r=o.getType(arguments[0]);"String"==r||"Number"==r?(t=arguments[0],i=arguments[1],s=arguments[2]):"Array"==r?(e=arguments[0],i=arguments[1],s=arguments[2]):(i=arguments[0],s=arguments[1]);var a;if(i&&i.returnType){var h=["DataTable","Array","Object"];if(a=-1==h.indexOf(i.returnType)?"Array":i.returnType,s&&a!=o.getType(s))throw new Error('Type of parameter "data" ('+o.getType(s)+") does not correspond with specified options.type ("+i.type+")");if("DataTable"==a&&!o.isDataTable(s))throw new Error('Parameter "data" must be a DataTable when options.type is "DataTable"')}else a=s&&"DataTable"==o.getType(s)?"DataTable":"Array";var d,l,c,p,u=i&&i.type||this._options.type,m=i&&i.filter,f=[];if(void 0!=t)d=n._getItem(t,u),m&&!m(d)&&(d=null);else if(void 0!=e)for(c=0,p=e.length;p>c;c++)d=n._getItem(e[c],u),(!m||m(d))&&f.push(d);else for(l in this._data)this._data.hasOwnProperty(l)&&(d=n._getItem(l,u),(!m||m(d))&&f.push(d));if(i&&i.order&&void 0==t&&this._sort(f,i.order),i&&i.fields){var g=i.fields;if(void 0!=t)d=this._filterFields(d,g);else for(c=0,p=f.length;p>c;c++)f[c]=this._filterFields(f[c],g)}if("DataTable"==a){var v=this._getColumnNames(s);if(void 0!=t)n._appendRow(s,v,d);else for(c=0;cc;c++)s.push(f[c]);return s}return f},s.prototype.getIds=function(t){var e,i,s,o,n,r=this._data,a=t&&t.filter,h=t&&t.order,d=t&&t.type||this._options.type,l=[];if(a)if(h){n=[];for(s in r)r.hasOwnProperty(s)&&(o=this._getItem(s,d),a(o)&&n.push(o));for(this._sort(n,h),e=0,i=n.length;i>e;e++)l[e]=n[e][this._fieldId]}else for(s in r)r.hasOwnProperty(s)&&(o=this._getItem(s,d),a(o)&&l.push(o[this._fieldId]));else if(h){n=[];for(s in 
r)r.hasOwnProperty(s)&&n.push(r[s]);for(this._sort(n,h),e=0,i=n.length;i>e;e++)l[e]=n[e][this._fieldId]}else for(s in r)r.hasOwnProperty(s)&&(o=r[s],l.push(o[this._fieldId]));return l},s.prototype.getDataSet=function(){return this},s.prototype.forEach=function(t,e){var i,s,o=e&&e.filter,n=e&&e.type||this._options.type,r=this._data;if(e&&e.order)for(var a=this.get(e),h=0,d=a.length;d>h;h++)i=a[h],s=i[this._fieldId],t(i,s);else for(s in r)r.hasOwnProperty(s)&&(i=this._getItem(s,n),(!o||o(i))&&t(i,s))},s.prototype.map=function(t,e){var i,s=e&&e.filter,o=e&&e.type||this._options.type,n=[],r=this._data;for(var a in r)r.hasOwnProperty(a)&&(i=this._getItem(a,o),(!s||s(i))&&n.push(t(i,a)));return e&&e.order&&this._sort(n,e.order),n},s.prototype._filterFields=function(t,e){var i={};for(var s in t)t.hasOwnProperty(s)&&-1!=e.indexOf(s)&&(i[s]=t[s]);return i},s.prototype._sort=function(t,e){if(o.isString(e)){var i=e;t.sort(function(t,e){var s=t[i],o=e[i];return s>o?1:o>s?-1:0})}else{if("function"!=typeof e)throw new TypeError("Order must be a function or a string");t.sort(e)}},s.prototype.remove=function(t,e){var i,s,o,n=[];if(Array.isArray(t))for(i=0,s=t.length;s>i;i++)o=this._remove(t[i]),null!=o&&n.push(o);else o=this._remove(t),null!=o&&n.push(o);return n.length&&this._trigger("remove",{items:n},e),n},s.prototype._remove=function(t){if(o.isNumber(t)||o.isString(t)){if(this._data[t])return delete this._data[t],t}else if(t instanceof Object){var e=t[this._fieldId];if(e&&this._data[e])return delete this._data[e],e}return null},s.prototype.clear=function(t){var e=Object.keys(this._data);return this._data={},this._trigger("remove",{items:e},t),e},s.prototype.max=function(t){var e=this._data,i=null,s=null;for(var o in e)if(e.hasOwnProperty(o)){var n=e[o],r=n[t];null!=r&&(!i||r>s)&&(i=n,s=r)}return i},s.prototype.min=function(t){var e=this._data,i=null,s=null;for(var o in e)if(e.hasOwnProperty(o)){var n=e[o],r=n[t];null!=r&&(!i||s>r)&&(i=n,s=r)}return i},s.prototype.distinct=function(t){var e,i=this._data,s=[],n=this._options.type&&this._options.type[t]||null,r=0;for(var a in i)if(i.hasOwnProperty(a)){var h=i[a],d=h[t],l=!1;for(e=0;r>e;e++)if(s[e]==d){l=!0;break}l||void 0===d||(s[r]=d,r++)}if(n)for(e=0;ei;i++)e[i]=t.getColumnId(i)||t.getColumnLabel(i);return e},s.prototype._appendRow=function(t,e,i){for(var s=t.addRow(),o=0,n=e.length;n>o;o++){var r=e[o];t.setValue(s,o,i[r])}},t.exports=s},function(t,e,i){function s(t,e){this._data=null,this._ids={},this._options=e||{},this._fieldId="id",this._subscribers={};var i=this;this.listener=function(){i._onEvent.apply(i,arguments)},this.setData(t)}var o=i(1),n=i(3);s.prototype.setData=function(t){var e,i,s;if(this._data){this._data.unsubscribe&&this._data.unsubscribe("*",this.listener),e=[];for(var o in this._ids)this._ids.hasOwnProperty(o)&&e.push(o);this._ids={},this._trigger("remove",{items:e})}if(this._data=t,this._data){for(this._fieldId=this._options.fieldId||this._data&&this._data.options&&this._data.options.fieldId||"id",e=this._data.getIds({filter:this._options&&this._options.filter}),i=0,s=e.length;s>i;i++)o=e[i],this._ids[o]=!0;this._trigger("add",{items:e}),this._data.on&&this._data.on("*",this.listener)}},s.prototype.get=function(){var t,e,i,s=this,n=o.getType(arguments[0]);"String"==n||"Number"==n||"Array"==n?(t=arguments[0],e=arguments[1],i=arguments[2]):(e=arguments[0],i=arguments[1]);var r=o.extend({},this._options,e);this._options.filter&&e&&e.filter&&(r.filter=function(t){return s._options.filter(t)&&e.filter(t)});var a=[];return void 
0!=t&&a.push(t),a.push(r),a.push(i),this._data&&this._data.get.apply(this._data,a)},s.prototype.getIds=function(t){var e;if(this._data){var i,s=this._options.filter;i=t&&t.filter?s?function(e){return s(e)&&t.filter(e)}:t.filter:s,e=this._data.getIds({filter:i,order:t&&t.order})}else e=[];return e},s.prototype.getDataSet=function(){for(var t=this;t instanceof s;)t=t._data;return t||null},s.prototype._onEvent=function(t,e,i){var s,o,n,r,a=e&&e.items,h=this._data,d=[],l=[],c=[];if(a&&h){switch(t){case"add":for(s=0,o=a.length;o>s;s++)n=a[s],r=this.get(n),r&&(this._ids[n]=!0,d.push(n));break;case"update":for(s=0,o=a.length;o>s;s++)n=a[s],r=this.get(n),r?this._ids[n]?l.push(n):(this._ids[n]=!0,d.push(n)):this._ids[n]&&(delete this._ids[n],c.push(n));break;case"remove":for(s=0,o=a.length;o>s;s++)n=a[s],this._ids[n]&&(delete this._ids[n],c.push(n))}d.length&&this._trigger("add",{items:d},i),l.length&&this._trigger("update",{items:l},i),c.length&&this._trigger("remove",{items:c},i)}},s.prototype.on=n.prototype.on,s.prototype.off=n.prototype.off,s.prototype._trigger=n.prototype._trigger,s.prototype.subscribe=s.prototype.on,s.prototype.unsubscribe=s.prototype.off,t.exports=s},function(t){function e(t){this.delay=null,this.max=1/0,this._queue=[],this._timeout=null,this._extended=null,this.setOptions(t)}e.prototype.setOptions=function(t){t&&"undefined"!=typeof t.delay&&(this.delay=t.delay),t&&"undefined"!=typeof t.max&&(this.max=t.max),this._flushIfNeeded()},e.extend=function(t,i){var s=new e(i);if(void 0!==t.flush)throw new Error("Target object already has a property flush");t.flush=function(){s.flush()};var o=[{name:"flush",original:void 0}];if(i&&i.replace)for(var n=0;nthis.max&&this.flush(),clearTimeout(this._timeout),this.queue.length>0&&"number"==typeof this.delay){var t=this;this._timeout=setTimeout(function(){t.flush()},this.delay)}},e.prototype.flush=function(){for(;this._queue.length>0;){var t=this._queue.shift();t.fn.apply(t.context||t.fn,t.args||[])}},t.exports=e},function(t,e,i){function s(t,e,i){if(!(this instanceof s))throw new SyntaxError("Constructor must be called with the new operator");this.containerElement=t,this.width="400px",this.height="400px",this.margin=10,this.defaultXCenter="55%",this.defaultYCenter="50%",this.xLabel="x",this.yLabel="y",this.zLabel="z";var o=function(t){return t};this.xValueLabel=o,this.yValueLabel=o,this.zValueLabel=o,this.filterLabel="time",this.legendLabel="value",this.style=s.STYLE.DOT,this.showPerspective=!0,this.showGrid=!0,this.keepAspectRatio=!0,this.showShadow=!1,this.showGrayBottom=!1,this.showTooltip=!1,this.verticalRatio=.5,this.animationInterval=1e3,this.animationPreload=!1,this.camera=new p,this.eye=new l(0,0,-1),this.dataTable=null,this.dataPoints=null,this.colX=void 0,this.colY=void 0,this.colZ=void 0,this.colValue=void 0,this.colFilter=void 0,this.xMin=0,this.xStep=void 0,this.xMax=1,this.yMin=0,this.yStep=void 0,this.yMax=1,this.zMin=0,this.zStep=void 0,this.zMax=1,this.valueMin=0,this.valueMax=1,this.xBarWidth=1,this.yBarWidth=1,this.colorAxis="#4D4D4D",this.colorGrid="#D3D3D3",this.colorDot="#7DC1FF",this.colorDotBorder="#3267D2",this.create(),this.setOptions(i),e&&this.setData(e)}function o(t){return"clientX"in t?t.clientX:t.targetTouches[0]&&t.targetTouches[0].clientX||0}function n(t){return"clientY"in t?t.clientY:t.targetTouches[0]&&t.targetTouches[0].clientY||0}var r=i(56),a=i(3),h=i(4),d=i(1),l=i(10),c=i(9),p=i(7),u=i(8),m=i(11),f=i(12);r(s.prototype),s.prototype._setScale=function(){this.scale=new 
l(1/(this.xMax-this.xMin),1/(this.yMax-this.yMin),1/(this.zMax-this.zMin)),this.keepAspectRatio&&(this.scale.x3&&(this.colFilter=3);else{if(this.style!==s.STYLE.DOTCOLOR&&this.style!==s.STYLE.DOTSIZE&&this.style!==s.STYLE.BARCOLOR&&this.style!==s.STYLE.BARSIZE)throw'Unknown style "'+this.style+'"';this.colX=0,this.colY=1,this.colZ=2,this.colValue=3,t.getNumberOfColumns()>4&&(this.colFilter=4)}},s.prototype.getNumberOfRows=function(t){return t.length},s.prototype.getNumberOfColumns=function(t){var e=0;for(var i in t[0])t[0].hasOwnProperty(i)&&e++;return e},s.prototype.getDistinctValues=function(t,e){for(var i=[],s=0;st[s][e]&&(i.min=t[s][e]),i.maxt;t++){var m=(t-p)/(u-p),g=240*m,v=this._hsv2rgb(g,1,1);c.strokeStyle=v,c.beginPath(),c.moveTo(h,r+t),c.lineTo(a,r+t),c.stroke()}c.strokeStyle=this.colorAxis,c.strokeRect(h,r,i,n)}if(this.style===s.STYLE.DOTSIZE&&(c.strokeStyle=this.colorAxis,c.fillStyle=this.colorDot,c.beginPath(),c.moveTo(h,r),c.lineTo(a,r),c.lineTo(a-i+e,d),c.lineTo(h,d),c.closePath(),c.fill(),c.stroke()),this.style===s.STYLE.DOTCOLOR||this.style===s.STYLE.DOTSIZE){var y=5,b=new f(this.valueMin,this.valueMax,(this.valueMax-this.valueMin)/5,!0);for(b.start(),b.getCurrent()0?this.yMin:this.yMax,o=this._convert3Dto2D(new l(x,r,this.zMin)),Math.cos(2*_)>0?(g.textAlign="center",g.textBaseline="top",o.y+=b):Math.sin(2*_)<0?(g.textAlign="right",g.textBaseline="middle"):(g.textAlign="left",g.textBaseline="middle"),g.fillStyle=this.colorAxis,g.fillText(" "+this.xValueLabel(i.getCurrent())+" ",o.x,o.y),i.next()}for(g.lineWidth=1,s=void 0===this.defaultYStep,i=new f(this.yMin,this.yMax,this.yStep,s),i.start(),i.getCurrent()0?this.xMin:this.xMax,o=this._convert3Dto2D(new l(n,i.getCurrent(),this.zMin)),Math.cos(2*_)<0?(g.textAlign="center",g.textBaseline="top",o.y+=b):Math.sin(2*_)>0?(g.textAlign="right",g.textBaseline="middle"):(g.textAlign="left",g.textBaseline="middle"),g.fillStyle=this.colorAxis,g.fillText(" "+this.yValueLabel(i.getCurrent())+" ",o.x,o.y),i.next();for(g.lineWidth=1,s=void 0===this.defaultZStep,i=new f(this.zMin,this.zMax,this.zStep,s),i.start(),i.getCurrent()0?this.xMin:this.xMax,r=Math.sin(_)<0?this.yMin:this.yMax;!i.end();)t=this._convert3Dto2D(new l(n,r,i.getCurrent())),g.strokeStyle=this.colorAxis,g.beginPath(),g.moveTo(t.x,t.y),g.lineTo(t.x-b,t.y),g.stroke(),g.textAlign="right",g.textBaseline="middle",g.fillStyle=this.colorAxis,g.fillText(this.zValueLabel(i.getCurrent())+" ",t.x-5,t.y),i.next();g.lineWidth=1,t=this._convert3Dto2D(new l(n,r,this.zMin)),e=this._convert3Dto2D(new l(n,r,this.zMax)),g.strokeStyle=this.colorAxis,g.beginPath(),g.moveTo(t.x,t.y),g.lineTo(e.x,e.y),g.stroke(),g.lineWidth=1,p=this._convert3Dto2D(new l(this.xMin,this.yMin,this.zMin)),u=this._convert3Dto2D(new l(this.xMax,this.yMin,this.zMin)),g.strokeStyle=this.colorAxis,g.beginPath(),g.moveTo(p.x,p.y),g.lineTo(u.x,u.y),g.stroke(),p=this._convert3Dto2D(new l(this.xMin,this.yMax,this.zMin)),u=this._convert3Dto2D(new l(this.xMax,this.yMax,this.zMin)),g.strokeStyle=this.colorAxis,g.beginPath(),g.moveTo(p.x,p.y),g.lineTo(u.x,u.y),g.stroke(),g.lineWidth=1,t=this._convert3Dto2D(new l(this.xMin,this.yMin,this.zMin)),e=this._convert3Dto2D(new l(this.xMin,this.yMax,this.zMin)),g.strokeStyle=this.colorAxis,g.beginPath(),g.moveTo(t.x,t.y),g.lineTo(e.x,e.y),g.stroke(),t=this._convert3Dto2D(new l(this.xMax,this.yMin,this.zMin)),e=this._convert3Dto2D(new l(this.xMax,this.yMax,this.zMin)),g.strokeStyle=this.colorAxis,g.beginPath(),g.moveTo(t.x,t.y),g.lineTo(e.x,e.y),g.stroke();var 
w=this.xLabel;w.length>0&&(c=.1/this.scale.y,n=(this.xMin+this.xMax)/2,r=Math.cos(_)>0?this.yMin-c:this.yMax+c,o=this._convert3Dto2D(new l(n,r,this.zMin)),Math.cos(2*_)>0?(g.textAlign="center",g.textBaseline="top"):Math.sin(2*_)<0?(g.textAlign="right",g.textBaseline="middle"):(g.textAlign="left",g.textBaseline="middle"),g.fillStyle=this.colorAxis,g.fillText(w,o.x,o.y));var S=this.yLabel;S.length>0&&(d=.1/this.scale.x,n=Math.sin(_)>0?this.xMin-d:this.xMax+d,r=(this.yMin+this.yMax)/2,o=this._convert3Dto2D(new l(n,r,this.zMin)),Math.cos(2*_)<0?(g.textAlign="center",g.textBaseline="top"):Math.sin(2*_)>0?(g.textAlign="right",g.textBaseline="middle"):(g.textAlign="left",g.textBaseline="middle"),g.fillStyle=this.colorAxis,g.fillText(S,o.x,o.y));var M=this.zLabel;M.length>0&&(h=30,n=Math.cos(_)>0?this.xMin:this.xMax,r=Math.sin(_)<0?this.yMin:this.yMax,a=(this.zMin+this.zMax)/2,o=this._convert3Dto2D(new l(n,r,a)),g.textAlign="right",g.textBaseline="middle",g.fillStyle=this.colorAxis,g.fillText(M,o.x-h,o.y))},s.prototype._hsv2rgb=function(t,e,i){var s,o,n,r,a,h;switch(r=i*e,a=Math.floor(t/60),h=r*(1-Math.abs(t/60%2-1)),a){case 0:s=r,o=h,n=0;break;case 1:s=h,o=r,n=0;break;case 2:s=0,o=r,n=h;break;case 3:s=0,o=h,n=r;break;case 4:s=h,o=0,n=r;break;case 5:s=r,o=0,n=h;break;default:s=0,o=0,n=0}return"RGB("+parseInt(255*s)+","+parseInt(255*o)+","+parseInt(255*n)+")"},s.prototype._redrawDataGrid=function(){var t,e,i,o,n,r,a,h,d,c,p,u,m,f=this.frame.canvas,g=f.getContext("2d");if(!(void 0===this.dataPoints||this.dataPoints.length<=0)){for(n=0;n0}else r=!0;r?(m=(t.point.z+e.point.z+i.point.z+o.point.z)/4,c=240*(1-(m-this.zMin)*this.scale.z/this.verticalRatio),p=1,this.showShadow?(u=Math.min(1+S.x/M/2,1),a=this._hsv2rgb(c,p,u),h=a):(u=1,a=this._hsv2rgb(c,p,u),h=this.colorAxis)):(a="gray",h=this.colorAxis),d=.5,g.lineWidth=d,g.fillStyle=a,g.strokeStyle=h,g.beginPath(),g.moveTo(t.screen.x,t.screen.y),g.lineTo(e.screen.x,e.screen.y),g.lineTo(o.screen.x,o.screen.y),g.lineTo(i.screen.x,i.screen.y),g.closePath(),g.fill(),g.stroke()}}else for(n=0;np&&(p=0);var u,m,f;this.style===s.STYLE.DOTCOLOR?(u=240*(1-(d.point.value-this.valueMin)*this.scale.value),m=this._hsv2rgb(u,1,1),f=this._hsv2rgb(u,1,.8)):this.style===s.STYLE.DOTSIZE?(m=this.colorDot,f=this.colorDotBorder):(u=240*(1-(d.point.z-this.zMin)*this.scale.z/this.verticalRatio),m=this._hsv2rgb(u,1,1),f=this._hsv2rgb(u,1,.8)),i.lineWidth=1,i.strokeStyle=f,i.fillStyle=m,i.beginPath(),i.arc(d.screen.x,d.screen.y,p,0,2*Math.PI,!0),i.fill(),i.stroke()}}},s.prototype._redrawDataBar=function(){var t,e,i,o,n=this.frame.canvas,r=n.getContext("2d");if(!(void 0===this.dataPoints||this.dataPoints.length<=0)){for(t=0;t0&&(t=this.dataPoints[0],s.lineWidth=1,s.strokeStyle="blue",s.beginPath(),s.moveTo(t.screen.x,t.screen.y)),e=1;e0&&s.stroke()}},s.prototype._onMouseDown=function(t){if(t=t||window.event,this.leftButtonDown&&this._onMouseUp(t),this.leftButtonDown=t.which?1===t.which:1===t.button,this.leftButtonDown||this.touchDown){this.startMouseX=o(t),this.startMouseY=n(t),this.startStart=new Date(this.start),this.startEnd=new Date(this.end),this.startArmRotation=this.camera.getArmRotation(),this.frame.style.cursor="move";var e=this;this.onmousemove=function(t){e._onMouseMove(t)},this.onmouseup=function(t){e._onMouseUp(t)},d.addEventListener(document,"mousemove",e.onmousemove),d.addEventListener(document,"mouseup",e.onmouseup),d.preventDefault(t)}},s.prototype._onMouseMove=function(t){t=t||window.event;var 
e=parseFloat(o(t))-this.startMouseX,i=parseFloat(n(t))-this.startMouseY,s=this.startArmRotation.horizontal+e/200,r=this.startArmRotation.vertical+i/200,a=4,h=Math.sin(a/360*2*Math.PI);Math.abs(Math.sin(s))0?1:0>t?-1:0}var s=e[0],o=e[1],n=e[2],r=i((o.x-s.x)*(t.y-s.y)-(o.y-s.y)*(t.x-s.x)),a=i((n.x-o.x)*(t.y-o.y)-(n.y-o.y)*(t.x-o.x)),h=i((s.x-n.x)*(t.y-n.y)-(s.y-n.y)*(t.x-n.x));return!(0!=r&&0!=a&&r!=a||0!=a&&0!=h&&a!=h||0!=r&&0!=h&&r!=h)},s.prototype._dataPointFromXY=function(t,e){var i,o=100,n=null,r=null,a=null,h=new c(t,e);if(this.style===s.STYLE.BAR||this.style===s.STYLE.BARCOLOR||this.style===s.STYLE.BARSIZE)for(i=this.dataPoints.length-1;i>=0;i--){n=this.dataPoints[i];var d=n.surfaces;if(d)for(var l=d.length-1;l>=0;l--){var p=d[l],u=p.corners,m=[u[0].screen,u[1].screen,u[2].screen],f=[u[2].screen,u[3].screen,u[0].screen];if(this._insideTriangle(h,m)||this._insideTriangle(h,f))return n}}else for(i=0;ib)&&o>b&&(a=b,r=n)}}return r},s.prototype._showTooltip=function(t){var e,i,s;this.tooltip?(e=this.tooltip.dom.content,i=this.tooltip.dom.line,s=this.tooltip.dom.dot):(e=document.createElement("div"),e.style.position="absolute",e.style.padding="10px",e.style.border="1px solid #4d4d4d",e.style.color="#1a1a1a",e.style.background="rgba(255,255,255,0.7)",e.style.borderRadius="2px",e.style.boxShadow="5px 5px 10px rgba(128,128,128,0.5)",i=document.createElement("div"),i.style.position="absolute",i.style.height="40px",i.style.width="0",i.style.borderLeft="1px solid #4d4d4d",s=document.createElement("div"),s.style.position="absolute",s.style.height="0",s.style.width="0",s.style.border="5px solid #4d4d4d",s.style.borderRadius="5px",this.tooltip={dataPoint:null,dom:{content:e,line:i,dot:s}}),this._hideTooltip(),this.tooltip.dataPoint=t,e.innerHTML="function"==typeof this.showTooltip?this.showTooltip(t.point):"
x:"+t.point.x+"
y:"+t.point.y+"
z:"+t.point.z+"
",e.style.left="0",e.style.top="0",this.frame.appendChild(e),this.frame.appendChild(i),this.frame.appendChild(s);var o=e.offsetWidth,n=e.offsetHeight,r=i.offsetHeight,a=s.offsetWidth,h=s.offsetHeight,d=t.screen.x-o/2;d=Math.min(Math.max(d,10),this.frame.clientWidth-10-o),i.style.left=t.screen.x+"px",i.style.top=t.screen.y-r+"px",e.style.left=d+"px",e.style.top=t.screen.y-r-n+"px",s.style.left=t.screen.x-a/2+"px",s.style.top=t.screen.y-h/2+"px"},s.prototype._hideTooltip=function(){if(this.tooltip){this.tooltip.dataPoint=null;for(var t in this.tooltip.dom)if(this.tooltip.dom.hasOwnProperty(t)){var e=this.tooltip.dom[t];e&&e.parentNode&&e.parentNode.removeChild(e)}}},t.exports=s},function(t,e,i){function s(){this.armLocation=new o,this.armRotation={},this.armRotation.horizontal=0,this.armRotation.vertical=0,this.armLength=1.7,this.cameraLocation=new o,this.cameraRotation=new o(.5*Math.PI,0,0),this.calculateCameraOrientation()}var o=i(10);s.prototype.setArmLocation=function(t,e,i){this.armLocation.x=t,this.armLocation.y=e,this.armLocation.z=i,this.calculateCameraOrientation()},s.prototype.setArmRotation=function(t,e){void 0!==t&&(this.armRotation.horizontal=t),void 0!==e&&(this.armRotation.vertical=e,this.armRotation.vertical<0&&(this.armRotation.vertical=0),this.armRotation.vertical>.5*Math.PI&&(this.armRotation.vertical=.5*Math.PI)),(void 0!==t||void 0!==e)&&this.calculateCameraOrientation()},s.prototype.getArmRotation=function(){var t={};return t.horizontal=this.armRotation.horizontal,t.vertical=this.armRotation.vertical,t},s.prototype.setArmLength=function(t){void 0!==t&&(this.armLength=t,this.armLength<.71&&(this.armLength=.71),this.armLength>5&&(this.armLength=5),this.calculateCameraOrientation())},s.prototype.getArmLength=function(){return this.armLength},s.prototype.getCameraLocation=function(){return this.cameraLocation},s.prototype.getCameraRotation=function(){return this.cameraRotation},s.prototype.calculateCameraOrientation=function(){this.cameraLocation.x=this.armLocation.x-this.armLength*Math.sin(this.armRotation.horizontal)*Math.cos(this.armRotation.vertical),this.cameraLocation.y=this.armLocation.y-this.armLength*Math.cos(this.armRotation.horizontal)*Math.cos(this.armRotation.vertical),this.cameraLocation.z=this.armLocation.z+this.armLength*Math.sin(this.armRotation.vertical),this.cameraRotation.x=Math.PI/2-this.armRotation.vertical,this.cameraRotation.y=0,this.cameraRotation.z=-this.armRotation.horizontal},t.exports=s},function(t,e,i){function s(t,e,i){this.data=t,this.column=e,this.graph=i,this.index=void 0,this.value=void 0,this.values=i.getDistinctValues(t.get(),this.column),this.values.sort(function(t,e){return t>e?1:e>t?-1:0}),this.values.length>0&&this.selectValue(0),this.dataPoints=[],this.loaded=!1,this.onLoadCallback=void 0,i.animationPreload?(this.loaded=!1,this.loadInBackground()):this.loaded=!0}var o=i(4);s.prototype.isLoaded=function(){return this.loaded},s.prototype.getLoadedProgress=function(){for(var t=this.values.length,e=0;this.dataPoints[e];)e++;return Math.round(e/t*100)},s.prototype.getLabel=function(){return this.graph.filterLabel},s.prototype.getColumn=function(){return this.column},s.prototype.getSelectedValue=function(){return void 0===this.index?void 0:this.values[this.index]},s.prototype.getValues=function(){return this.values},s.prototype.getValue=function(t){if(t>=this.values.length)throw"Error: index out of range";return this.values[t]},s.prototype._getDataPoints=function(t){if(void 0===t&&(t=this.index),void 0===t)return[];var 
/* vis.js — minified Timeline/Graph2d visualization bundle (Filter, Slider, StepNumber, Timeline, Graph2d, hidden-dates DateUtil, DataStep, Range, stacking utilities, TimeStep, Component, CurrentTime, CustomTime, DataAxis, Group, ItemSet, Legend, LineGraph, TimeAxis, and Item modules) */
o(i).on("tap",function(t){e.parent.removeFromDataSet(e),t.stopPropagation(),t.preventDefault()}),t.appendChild(i),this.dom.deleteButton=i}else!this.selected&&this.dom.deleteButton&&(this.dom.deleteButton.parentNode&&this.dom.deleteButton.parentNode.removeChild(this.dom.deleteButton),this.dom.deleteButton=null)},s.prototype._updateContents=function(t){var e;if(this.options.template){var i=this.parent.itemSet.itemsData.get(this.id);e=this.options.template(i)}else e=this.data.content;if(e!==this.content){if(e instanceof Element)t.innerHTML="",t.appendChild(e);else if(void 0!=e)t.innerHTML=e;else if("background"!=this.data.type||void 0!==this.data.content)throw new Error('Property "content" missing in item '+this.id);this.content=e}},s.prototype._updateTitle=function(t){null!=this.data.title?t.title=this.data.title||"":t.removeAttribute("title")},s.prototype._updateDataAttributes=function(t){if(this.options.dataAttributes&&this.options.dataAttributes.length>0){var e=[];if(Array.isArray(this.options.dataAttributes))e=this.options.dataAttributes;else{if("all"!=this.options.dataAttributes)return;e=Object.keys(this.data)}for(var i=0;it.start},s.prototype.redraw=function(){var t=this.dom;if(t||(this.dom={},t=this.dom,t.box=document.createElement("div"),t.content=document.createElement("div"),t.content.className="content",t.box.appendChild(t.content),this.dirty=!0),!this.parent)throw new Error("Cannot redraw item: no parent attached");if(!t.box.parentNode){var e=this.parent.dom.background;if(!e)throw new Error("Cannot redraw item: parent has no background container element");e.appendChild(t.box)}if(this.displayed=!0,this.dirty){this._updateContents(this.dom.content),this._updateTitle(this.dom.content),this._updateDataAttributes(this.dom.content),this._updateStyle(this.dom.box);var i=(this.data.className?" "+this.data.className:"")+(this.selected?" 
selected":"");t.box.className=this.baseClassName+i,this.overflow="hidden"!==window.getComputedStyle(t.content).overflow,this.props.content.width=this.dom.content.offsetWidth,this.height=0,this.dirty=!1}},s.prototype.show=r.prototype.show,s.prototype.hide=r.prototype.hide,s.prototype.repositionX=r.prototype.repositionX,s.prototype.repositionY=function(t){var e="top"===this.options.orientation;this.dom.content.style.top=e?"":"0",this.dom.content.style.bottom=e?"0":"";var i;if(void 0!==this.data.subgroup){var s=this.data.subgroup,o=this.parent.subgroups,r=o[s].index;if(1==e){i=this.parent.subgroups[s].height+t.item.vertical,i+=0==r?t.axis-.5*t.item.vertical:0;var a=this.parent.top;for(var h in o)o.hasOwnProperty(h)&&1==o[h].visible&&o[h].indexr&&(a+=o[h].height+t.item.vertical);i=this.parent.subgroups[s].height+t.item.vertical,this.dom.box.style.top=a+"px",this.dom.box.style.bottom=""}}else this.parent instanceof n?(i=Math.max(this.parent.height,this.parent.itemSet.body.domProps.center.height,this.parent.itemSet.body.domProps.centerContainer.height),this.dom.box.style.top=e?"0":"",this.dom.box.style.bottom=e?"":"0"):(i=this.parent.height,this.dom.box.style.top=this.parent.top+"px",this.dom.box.style.bottom="");this.dom.box.style.height=i+"px"},t.exports=s},function(t,e,i){function s(t,e,i){if(this.props={dot:{width:0,height:0},line:{width:0,height:0}},t&&void 0==t.start)throw new Error('Property "start" missing in item '+t);o.call(this,t,e,i)}{var o=i(31);i(1)}s.prototype=new o(null,null,null),s.prototype.isVisible=function(t){var e=(t.end-t.start)/4;return this.data.start>t.start-e&&this.data.startt.start-e&&this.data.startt.start},s.prototype.redraw=function(){var t=this.dom;if(t||(this.dom={},t=this.dom,t.box=document.createElement("div"),t.content=document.createElement("div"),t.content.className="content",t.box.appendChild(t.content),t.box["timeline-item"]=this,this.dirty=!0),!this.parent)throw new Error("Cannot redraw item: no parent attached");if(!t.box.parentNode){var e=this.parent.dom.foreground;if(!e)throw new Error("Cannot redraw item: parent has no foreground container element");e.appendChild(t.box)}if(this.displayed=!0,this.dirty){this._updateContents(this.dom.content),this._updateTitle(this.dom.box),this._updateDataAttributes(this.dom.box),this._updateStyle(this.dom.box);var i=(this.data.className?" "+this.data.className:"")+(this.selected?" 
selected":"");t.box.className=this.baseClassName+i,this.overflow="hidden"!==window.getComputedStyle(t.content).overflow,this.dom.content.style.maxWidth="none",this.props.content.width=this.dom.content.offsetWidth,this.height=this.dom.box.offsetHeight,this.dom.content.style.maxWidth="",this.dirty=!1}this._repaintDeleteButton(t.box),this._repaintDragLeft(),this._repaintDragRight()},s.prototype.show=function(){this.displayed||this.redraw()},s.prototype.hide=function(){if(this.displayed){var t=this.dom.box;t.parentNode&&t.parentNode.removeChild(t),this.top=null,this.left=null,this.displayed=!1}},s.prototype.repositionX=function(){var t,e,i=this.parent.width,s=this.conversion.toScreen(this.data.start),o=this.conversion.toScreen(this.data.end);-i>s&&(s=-i),o>2*i&&(o=2*i);var n=Math.max(o-s,1);switch(this.overflow?(this.left=s,this.width=n+this.props.content.width,e=this.props.content.width):(this.left=s,this.width=n,e=Math.min(o-s-2*this.options.padding,this.props.content.width)),this.dom.box.style.left=this.left+"px",this.dom.box.style.width=n+"px",this.options.align){case"left":this.dom.content.style.left="0";break;case"right":this.dom.content.style.left=Math.max(n-e-2*this.options.padding,0)+"px";break;case"center":this.dom.content.style.left=Math.max((n-e-2*this.options.padding)/2,0)+"px";break;default:t=this.overflow?o>0?Math.max(-s,0):-e:0>s?Math.min(-s,o-s-e-2*this.options.padding):0,this.dom.content.style.left=t+"px"}},s.prototype.repositionY=function(){var t=this.options.orientation,e=this.dom.box;e.style.top="top"==t?this.top+"px":this.parent.height-this.top-this.height+"px"},s.prototype._repaintDragLeft=function(){if(this.selected&&this.options.editable.updateTime&&!this.dom.dragLeft){var t=document.createElement("div");t.className="drag-left",t.dragLeftItem=this,this.dom.box.appendChild(t),this.dom.dragLeft=t}else!this.selected&&this.dom.dragLeft&&(this.dom.dragLeft.parentNode&&this.dom.dragLeft.parentNode.removeChild(this.dom.dragLeft),this.dom.dragLeft=null)},s.prototype._repaintDragRight=function(){if(this.selected&&this.options.editable.updateTime&&!this.dom.dragRight){var t=document.createElement("div");t.className="drag-right",t.dragRightItem=this,this.dom.box.appendChild(t),this.dom.dragRight=t}else!this.selected&&this.dom.dragRight&&(this.dom.dragRight.parentNode&&this.dom.dragRight.parentNode.removeChild(this.dom.dragRight),this.dom.dragRight=null)},t.exports=s},function(t,e,i){function s(t,e,i){if(!(this instanceof s))throw new SyntaxError("Constructor must be called with the new operator");this._determineBrowserMethod(),this._initializeMixinLoaders(),this.containerElement=t,this.renderRefreshRate=60,this.renderTimestep=1e3/this.renderRefreshRate,this.renderTime=0,this.physicsTime=0,this.runDoubleSpeed=!1,this.physicsDiscreteStepsize=.5,this.initializing=!0,this.triggerFunctions={add:null,edit:null,editEdge:null,connect:null,del:null},this.defaultOptions={nodes:{mass:1,radiusMin:10,radiusMax:30,radius:10,shape:"ellipse",image:void 0,widthMin:16,widthMax:64,fontColor:"black",fontSize:14,fontFace:"verdana",fontFill:void 0,fontStrokeWidth:0,fontStrokeColor:"white",level:-1,color:{border:"#2B7CE9",background:"#97C2FC",highlight:{border:"#2B7CE9",background:"#D2E5FF"},hover:{border:"#2B7CE9",background:"#D2E5FF"}},group:void 0,borderWidth:1,borderWidthSelected:void 
0},edges:{widthMin:1,widthMax:15,width:1,widthSelectionMultiplier:2,hoverWidth:1.5,style:"line",color:{color:"#848484",highlight:"#848484",hover:"#848484"},fontColor:"#343434",fontSize:14,fontFace:"arial",fontFill:"white",fontStrokeWidth:0,fontStrokeColor:"white",labelAlignment:"horizontal",arrowScaleFactor:1,dash:{length:10,gap:5,altLength:void 0},inheritColor:"from"},configurePhysics:!1,physics:{barnesHut:{enabled:!0,thetaInverted:2,gravitationalConstant:-2e3,centralGravity:.3,springLength:95,springConstant:.04,damping:.09},repulsion:{centralGravity:0,springLength:200,springConstant:.05,nodeDistance:100,damping:.09},hierarchicalRepulsion:{enabled:!1,centralGravity:0,springLength:100,springConstant:.01,nodeDistance:150,damping:.09},damping:null,centralGravity:null,springLength:null,springConstant:null},clustering:{enabled:!1,initialMaxNodes:100,clusterThreshold:500,reduceToNodes:300,chainThreshold:.4,clusterEdgeThreshold:20,sectorThreshold:100,screenSizeThreshold:.2,fontSizeMultiplier:4,maxFontSize:1e3,forceAmplification:.1,distanceAmplification:.1,edgeGrowth:20,nodeScaling:{width:1,height:1,radius:1},maxNodeSizeIncrements:600,activeAreaBoxSize:80,clusterLevelDifference:2},navigation:{enabled:!1},keyboard:{enabled:!1,speed:{x:10,y:10,zoom:.02}},dataManipulation:{enabled:!1,initiallyVisible:!1},hierarchicalLayout:{enabled:!1,levelSeparation:150,nodeSpacing:100,direction:"UD",layout:"hubsize"},freezeForStabilization:!1,smoothCurves:{enabled:!0,dynamic:!0,type:"continuous",roundness:.5},maxVelocity:30,minVelocity:.1,stabilize:!0,stabilizationIterations:1e3,zoomExtentOnStabilize:!0,locale:"en",locales:_,tooltip:{delay:300,fontColor:"black",fontSize:14,fontFace:"verdana",color:{border:"#666",background:"#FFFFC6"}},dragNetwork:!0,dragNodes:!0,zoomable:!0,hover:!1,hideEdgesOnDrag:!1,hideNodesOnDrag:!1,width:"100%",height:"100%",selectable:!0},this.constants=a.extend({},this.defaultOptions),this.pixelRatio=1,this.hoverObj={nodes:{},edges:{}},this.controlNodesActive=!1,this.navigationHammers={existing:[],_new:[]},this.animationSpeed=1/this.renderRefreshRate,this.animationEasingFunction="easeInOutQuint",this.easingTime=0,this.sourceScale=0,this.targetScale=0,this.sourceTranslation=0,this.targetTranslation=0,this.lockedOnNodeId=null,this.lockedOnNodeOffset=null,this.touchTime=0;var o=this;this.groups=new u,this.images=new 
m,this.images.setOnloadCallback(function(){o._redraw()}),this.xIncrement=0,this.yIncrement=0,this.zoomIncrement=0,this._loadPhysicsSystem(),this._create(),this._loadSectorSystem(),this._loadClusterSystem(),this._loadSelectionSystem(),this._loadHierarchySystem(),this._setTranslation(this.frame.clientWidth/2,this.frame.clientHeight/2),this._setScale(1),this.setOptions(i),this.freezeSimulation=!1,this.cachedFunctions={},this.startedStabilization=!1,this.stabilized=!1,this.stabilizationIterations=null,this.draggingNodes=!1,this.calculationNodes={},this.calculationNodeIndices=[],this.nodeIndices=[],this.nodes={},this.edges={},this.canvasTopLeft={x:0,y:0},this.canvasBottomRight={x:0,y:0},this.pointerPosition={x:0,y:0},this.areaCenter={},this.scale=1,this.previousScale=this.scale,this.nodesData=null,this.edgesData=null,this.nodesListeners={add:function(t,e){o._addNodes(e.items),o.start()},update:function(t,e){o._updateNodes(e.items,e.data),o.start()},remove:function(t,e){o._removeNodes(e.items),o.start()}},this.edgesListeners={add:function(t,e){o._addEdges(e.items),o.start()},update:function(t,e){o._updateEdges(e.items),o.start()},remove:function(t,e){o._removeEdges(e.items),o.start()}},this.moving=!0,this.timer=void 0,this.setData(e,this.constants.clustering.enabled||this.constants.hierarchicalLayout.enabled),this.initializing=!1,1==this.constants.hierarchicalLayout.enabled?this._setupHierarchicalLayout():0==this.constants.stabilize&&this.zoomExtent(void 0,!0,this.constants.clustering.enabled),this.constants.clustering.enabled&&this.startWithClustering()}var o=i(56),n=i(45),r=i(58),a=i(1),h=i(47),d=i(3),l=i(4),c=i(42),p=i(43),u=i(38),m=i(39),f=i(40),g=i(37),v=i(41),y=i(54),b=i(55),_=i(49);i(50),o(s.prototype),s.prototype._determineBrowserMethod=function(){var t=navigator.userAgent.toLowerCase();this.requiresTimeout=!1,-1!=t.indexOf("msie 9.0")?this.requiresTimeout=!0:-1!=t.indexOf("safari")&&t.indexOf("chrome")<=-1&&(this.requiresTimeout=!0)},s.prototype._getScriptPath=function(){for(var t=document.getElementsByTagName("script"),e=0;et.boundingBox.left&&(s=t.boundingBox.left),ot.boundingBox.bottom&&(e=t.boundingBox.bottom),i=this.constants.clustering.initialMaxNodes?49.07548/(n+142.05338)+91444e-8:12.662/(n+7.4147)+.0964822:1==this.constants.clustering.enabled&&n>=this.constants.clustering.initialMaxNodes?77.5271985/(n+187.266146)+476710517e-13:30.5062972/(n+19.93597763)+.08413486;var r=Math.min(this.frame.canvas.clientWidth/600,this.frame.canvas.clientHeight/600);s*=r}else{var a=1.1*Math.abs(o.maxX-o.minX),h=1.1*Math.abs(o.maxY-o.minY),d=this.frame.canvas.clientWidth/a,l=this.frame.canvas.clientHeight/h;s=l>=d?d:l}s>1&&(s=1);var c=this._findCenter(o);if(0==i){var p={position:c,scale:s,animation:t};this.moveTo(p),this.moving=!0,this.start()}else c.x*=s,c.y*=s,c.x-=.5*this.frame.canvas.clientWidth,c.y-=.5*this.frame.canvas.clientHeight,this._setScale(s),this._setTranslation(-c.x,-c.y)},s.prototype._updateNodeIndexList=function(){this._clearNodeIndexList();for(var t in this.nodes)this.nodes.hasOwnProperty(t)&&this.nodeIndices.push(t)},s.prototype.setData=function(t,e){if(void 0===e&&(e=!1),this.initializing=!0,t&&t.dot&&(t.nodes||t.edges))throw new SyntaxError('Data must contain either parameter "dot" or parameter pair "nodes" and "edges", but not both.');if(1==this.constants.dataManipulation.enabled&&this._createManipulatorBar(),this.setOptions(t&&t.options),t&&t.dot){if(t&&t.dot){var i=c.DOTToGraph(t.dot);return void this.setData(i)}}else if(t&&t.gephi){if(t&&t.gephi){var 
s=p.parseGephi(t.gephi);return void this.setData(s)}}else this._setNodes(t&&t.nodes),this._setEdges(t&&t.edges);this._putDataInSector(),0==e&&(1==this.constants.hierarchicalLayout.enabled?(this._resetLevels(),this._setupHierarchicalLayout()):this.constants.stabilize&&this._stabilize(),this.start()),this.initializing=!1},s.prototype.setOptions=function(t){if(t){var e,i=["nodes","edges","smoothCurves","hierarchicalLayout","clustering","navigation","keyboard","dataManipulation","onAdd","onEdit","onEditEdge","onConnect","onDelete","clickToUse"];if(a.selectiveNotDeepExtend(i,this.constants,t),a.selectiveNotDeepExtend(["color"],this.constants.nodes,t.nodes),a.selectiveNotDeepExtend(["color","length"],this.constants.edges,t.edges),t.physics&&(a.mergeOptions(this.constants.physics,t.physics,"barnesHut"),a.mergeOptions(this.constants.physics,t.physics,"repulsion"),t.physics.hierarchicalRepulsion)){this.constants.hierarchicalLayout.enabled=!0,this.constants.physics.hierarchicalRepulsion.enabled=!0,this.constants.physics.barnesHut.enabled=!1;for(e in t.physics.hierarchicalRepulsion)t.physics.hierarchicalRepulsion.hasOwnProperty(e)&&(this.constants.physics.hierarchicalRepulsion[e]=t.physics.hierarchicalRepulsion[e]) -}if(t.onAdd&&(this.triggerFunctions.add=t.onAdd),t.onEdit&&(this.triggerFunctions.edit=t.onEdit),t.onEditEdge&&(this.triggerFunctions.editEdge=t.onEditEdge),t.onConnect&&(this.triggerFunctions.connect=t.onConnect),t.onDelete&&(this.triggerFunctions.del=t.onDelete),a.mergeOptions(this.constants,t,"smoothCurves"),a.mergeOptions(this.constants,t,"hierarchicalLayout"),a.mergeOptions(this.constants,t,"clustering"),a.mergeOptions(this.constants,t,"navigation"),a.mergeOptions(this.constants,t,"keyboard"),a.mergeOptions(this.constants,t,"dataManipulation"),t.dataManipulation&&(this.editMode=this.constants.dataManipulation.initiallyVisible),t.edges&&(void 0!==t.edges.color&&(a.isString(t.edges.color)?(this.constants.edges.color={},this.constants.edges.color.color=t.edges.color,this.constants.edges.color.highlight=t.edges.color,this.constants.edges.color.hover=t.edges.color):(void 0!==t.edges.color.color&&(this.constants.edges.color.color=t.edges.color.color),void 0!==t.edges.color.highlight&&(this.constants.edges.color.highlight=t.edges.color.highlight),void 0!==t.edges.color.hover&&(this.constants.edges.color.hover=t.edges.color.hover)),this.constants.edges.inheritColor=!1),t.edges.fontColor||void 0!==t.edges.color&&(a.isString(t.edges.color)?this.constants.edges.fontColor=t.edges.color:void 0!==t.edges.color.color&&(this.constants.edges.fontColor=t.edges.color.color))),t.nodes&&t.nodes.color){var s=a.parseColor(t.nodes.color);this.constants.nodes.color.background=s.background,this.constants.nodes.color.border=s.border,this.constants.nodes.color.highlight.background=s.highlight.background,this.constants.nodes.color.highlight.border=s.highlight.border,this.constants.nodes.color.hover.background=s.hover.background,this.constants.nodes.color.hover.border=s.hover.border}if(t.groups)for(var o in t.groups)if(t.groups.hasOwnProperty(o)){var n=t.groups[o];this.groups.add(o,n)}if(t.tooltip){for(e in t.tooltip)t.tooltip.hasOwnProperty(e)&&(this.constants.tooltip[e]=t.tooltip[e]);t.tooltip.color&&(this.constants.tooltip.color=a.parseColor(t.tooltip.color))}if("clickToUse"in t&&(t.clickToUse?this.activator||(this.activator=new b(this.frame),this.activator.on("change",this._createKeyBinds.bind(this))):this.activator&&(this.activator.destroy(),delete this.activator)),t.labels)throw new Error('Option "labels" 
is deprecated. Use options "locale" and "locales" instead.');this._loadPhysicsSystem(),this._loadNavigationControls(),this._loadManipulationSystem(),this._configureSmoothCurves(),this._createKeyBinds(),this.setSize(this.constants.width,this.constants.height),this.moving=!0,this.start()}},s.prototype._create=function(){for(;this.containerElement.hasChildNodes();)this.containerElement.removeChild(this.containerElement.firstChild);if(this.frame=document.createElement("div"),this.frame.className="vis network-frame",this.frame.style.position="relative",this.frame.style.overflow="hidden",this.frame.canvas=document.createElement("canvas"),this.frame.canvas.style.position="relative",this.frame.appendChild(this.frame.canvas),this.frame.canvas.getContext){var t=this.frame.canvas.getContext("2d");this.pixelRatio=(window.devicePixelRatio||1)/(t.webkitBackingStorePixelRatio||t.mozBackingStorePixelRatio||t.msBackingStorePixelRatio||t.oBackingStorePixelRatio||t.backingStorePixelRatio||1),this.frame.canvas.getContext("2d").setTransform(this.pixelRatio,0,0,this.pixelRatio,0,0)}else{var e=document.createElement("DIV");e.style.color="red",e.style.fontWeight="bold",e.style.padding="10px",e.innerHTML="Error: your browser does not support HTML canvas",this.frame.canvas.appendChild(e)}var i=this;this.drag={},this.pinch={},this.hammer=new n(this.frame.canvas),this.hammer.get("pinch").set({enable:!0}),this.hammer.on("tap",i._onTap.bind(i)),this.hammer.on("doubletap",i._onDoubleTap.bind(i)),this.hammer.on("press",i._onHold.bind(i)),this.hammer.on("pinch",i._onPinch.bind(i)),h.onTouch(this.hammer,i._onTouch.bind(i)),this.hammer.on("panstart",i._onDragStart.bind(i)),this.hammer.on("panmove",i._onDrag.bind(i)),this.hammer.on("panend",i._onDragEnd.bind(i)),this.frame.canvas.addEventListener("mousemove",i._onMouseMoveTitle.bind(i)),this.frame.canvas.addEventListener("mousewheel",i._onMouseWheel.bind(i)),this.frame.canvas.addEventListener("DOMMouseScroll",i._onMouseWheel.bind(i)),this.containerElement.appendChild(this.frame)},s.prototype._createKeyBinds=function(){var t=this;void 
0!==this.keycharm&&this.keycharm.destroy(),this.keycharm=r(),this.keycharm.reset(),this.constants.keyboard.enabled&&this.isActive()&&(this.keycharm.bind("up",this._moveUp.bind(t),"keydown"),this.keycharm.bind("up",this._yStopMoving.bind(t),"keyup"),this.keycharm.bind("down",this._moveDown.bind(t),"keydown"),this.keycharm.bind("down",this._yStopMoving.bind(t),"keyup"),this.keycharm.bind("left",this._moveLeft.bind(t),"keydown"),this.keycharm.bind("left",this._xStopMoving.bind(t),"keyup"),this.keycharm.bind("right",this._moveRight.bind(t),"keydown"),this.keycharm.bind("right",this._xStopMoving.bind(t),"keyup"),this.keycharm.bind("=",this._zoomIn.bind(t),"keydown"),this.keycharm.bind("=",this._stopZoom.bind(t),"keyup"),this.keycharm.bind("num+",this._zoomIn.bind(t),"keydown"),this.keycharm.bind("num+",this._stopZoom.bind(t),"keyup"),this.keycharm.bind("num-",this._zoomOut.bind(t),"keydown"),this.keycharm.bind("num-",this._stopZoom.bind(t),"keyup"),this.keycharm.bind("-",this._zoomOut.bind(t),"keydown"),this.keycharm.bind("-",this._stopZoom.bind(t),"keyup"),this.keycharm.bind("[",this._zoomIn.bind(t),"keydown"),this.keycharm.bind("[",this._stopZoom.bind(t),"keyup"),this.keycharm.bind("]",this._zoomOut.bind(t),"keydown"),this.keycharm.bind("]",this._stopZoom.bind(t),"keyup"),this.keycharm.bind("pageup",this._zoomIn.bind(t),"keydown"),this.keycharm.bind("pageup",this._stopZoom.bind(t),"keyup"),this.keycharm.bind("pagedown",this._zoomOut.bind(t),"keydown"),this.keycharm.bind("pagedown",this._stopZoom.bind(t),"keyup")),1==this.constants.dataManipulation.enabled&&(this.keycharm.bind("esc",this._createManipulatorBar.bind(t)),this.keycharm.bind("delete",this._deleteSelected.bind(t)))},s.prototype.destroy=function(){this.start=function(){},this.redraw=function(){},this.timer=!1,this._cleanupPhysicsConfiguration(),this.keycharm.reset(),this.hammer.destroy(),this.off(),this._recursiveDOMDelete(this.containerElement)},s.prototype._recursiveDOMDelete=function(t){for(;1==t.hasChildNodes();)this._recursiveDOMDelete(t.firstChild),t.removeChild(t.firstChild)},s.prototype._getPointer=function(t){return{x:t.x-a.getAbsoluteLeft(this.frame.canvas),y:t.y-a.getAbsoluteTop(this.frame.canvas)}},s.prototype._onTouch=function(t){(new Date).valueOf()-this.touchTime>100&&(this.drag.pointer=this._getPointer(t.center),this.drag.pinched=!1,this.pinch.scale=this._getScale(),this.touchTime=(new Date).valueOf(),this._handleTouch(this.drag.pointer))},s.prototype._onDragStart=function(t){this._handleDragStart(t)},s.prototype._handleDragStart=function(t){void 0===this.drag.pointer&&this._onTouch(t);var e=this._getNodeAt(this.drag.pointer);if(this.drag.dragging=!0,this.drag.selection=[],this.drag.translation=this._getTranslation(),this.drag.nodeId=null,this.draggingNodes=!1,null!=e&&1==this.constants.dragNodes){this.draggingNodes=!0,this.drag.nodeId=e.id,e.isSelected()||this._selectObject(e,!1),this.emit("dragStart",{nodeIds:this.getSelection().nodes});for(var i in this.selectionObj.nodes)if(this.selectionObj.nodes.hasOwnProperty(i)){var s=this.selectionObj.nodes[i],o={id:s.id,node:s,x:s.x,y:s.y,xFixed:s.xFixed,yFixed:s.yFixed};s.xFixed=!0,s.yFixed=!0,this.drag.selection.push(o)}}t.preventDefault()},s.prototype._onDrag=function(t){this._handleOnDrag(t)},s.prototype._handleOnDrag=function(t){if(!this.drag.pinched){this.releaseNode();var e=this._getPointer(t.center),i=this,s=this.drag,o=s.selection;if(o&&o.length&&1==this.constants.dragNodes){var n=e.x-s.pointer.x,r=e.y-s.pointer.y;o.forEach(function(t){var 
e=t.node;t.xFixed||(e.x=i._XconvertDOMtoCanvas(i._XconvertCanvasToDOM(t.x)+n)),t.yFixed||(e.y=i._YconvertDOMtoCanvas(i._YconvertCanvasToDOM(t.y)+r))}),this.moving||(this.moving=!0,this.start())}else if(1==this.constants.dragNetwork){if(void 0===this.drag.pointer)return void this._handleDragStart(t);var a=e.x-this.drag.pointer.x,h=e.y-this.drag.pointer.y;this._setTranslation(this.drag.translation.x+a,this.drag.translation.y+h),this._redraw()}t.preventDefault()}},s.prototype._onDragEnd=function(t){this._handleDragEnd(t)},s.prototype._handleDragEnd=function(t){this.drag.dragging=!1;var e=this.drag.selection;e&&e.length?(e.forEach(function(t){t.node.xFixed=t.xFixed,t.node.yFixed=t.yFixed}),this.moving=!0,this.start()):this._redraw(),0==this.draggingNodes?this.emit("dragEnd",{nodeIds:[]}):this.emit("dragEnd",{nodeIds:this.getSelection().nodes}),t.preventDefault()},s.prototype._onTap=function(t){var e=this._getPointer(t.center);this.pointerPosition=e,this._handleTap(e)},s.prototype._onDoubleTap=function(t){var e=this._getPointer(t.center);this._handleDoubleTap(e)},s.prototype._onHold=function(t){var e=this._getPointer(t.center);this.pointerPosition=e,this._handleOnHold(e)},s.prototype._onRelease=function(t){var e=this._getPointer(t.center);this._handleOnRelease(e)},s.prototype._onPinch=function(t){var e=this._getPointer(t.center);this.drag.pinched=!0,"scale"in this.pinch||(this.pinch.scale=1);var i=this.pinch.scale*t.scale;this._zoom(i,e)},s.prototype._zoom=function(t,e){if(1==this.constants.zoomable){var i=this._getScale();1e-5>t&&(t=1e-5),t>10&&(t=10);var s=null;void 0!==this.drag&&1==this.drag.dragging&&(s=this.DOMtoCanvas(this.drag.pointer));var o=this._getTranslation(),n=t/i,r=(1-n)*e.x+o.x*n,a=(1-n)*e.y+o.y*n;if(this.areaCenter={x:this._XconvertDOMtoCanvas(e.x),y:this._YconvertDOMtoCanvas(e.y)},this._setScale(t),this._setTranslation(r,a),this.updateClustersDefault(),null!=s){var h=this.canvasToDOM(s);this.drag.pointer.x=h.x,this.drag.pointer.y=h.y}return this._redraw(),t>i?this.emit("zoom",{direction:"+"}):this.emit("zoom",{direction:"-"}),t}},s.prototype._onMouseWheel=function(t){var e=0;if(t.wheelDelta?e=t.wheelDelta/120:t.detail&&(e=-t.detail/3),e){var i=this._getScale(),s=e/10;0>e&&(s/=1-s),i*=1+s;var o=this._getPointer({x:t.pageX,y:t.pageY});this._zoom(i,o)}t.preventDefault()},s.prototype._onMouseMoveTitle=function(t){var e=this._getPointer({x:t.pageX,y:t.pageY});this.popupObj&&this._checkHidePopup(e);var i=this,s=function(){i._checkShowPopup(e)};if(this.popupTimer&&clearInterval(this.popupTimer),this.drag.dragging||(this.popupTimer=setTimeout(s,this.constants.tooltip.delay)),1==this.constants.hover){for(var o in this.hoverObj.edges)this.hoverObj.edges.hasOwnProperty(o)&&(this.hoverObj.edges[o].hover=!1,delete this.hoverObj.edges[o]);var n=this._getNodeAt(e);null==n&&(n=this._getEdgeAt(e)),null!=n&&this._hoverObject(n);for(var r in this.hoverObj.nodes)this.hoverObj.nodes.hasOwnProperty(r)&&(n instanceof f&&n.id!=r||n instanceof g||null==n)&&(this._blurObject(this.hoverObj.nodes[r]),delete this.hoverObj.nodes[r]);this.redraw()}},s.prototype._checkShowPopup=function(t){var e,i={left:this._XconvertDOMtoCanvas(t.x),top:this._YconvertDOMtoCanvas(t.y),right:this._XconvertDOMtoCanvas(t.x),bottom:this._YconvertDOMtoCanvas(t.y)},s=this.popupObj,o=!1;if(void 0==this.popupObj){var n=this.nodes,r=[];for(e in n)if(n.hasOwnProperty(e)){var a=n[e];a.isOverlappingWith(i)&&void 0!==a.getTitle()&&r.push(e)}r.length>0&&(this.popupObj=this.nodes[r[r.length-1]],o=!0)}if(void 0===this.popupObj&&0==o){var 
h=this.edges,d=[];for(e in h)if(h.hasOwnProperty(e)){var l=h[e];l.connected&&void 0!==l.getTitle()&&l.isOverlappingWith(i)&&d.push(e)}d.length>0&&(this.popupObj=this.edges[d[d.length-1]])}if(this.popupObj){if(this.popupObj!=s){var c=this;c.popup||(c.popup=new v(c.frame,c.constants.tooltip)),c.popup.setPosition(t.x-3,t.y-3),c.popup.setText(c.popupObj.getTitle()),c.popup.show()}}else this.popup&&this.popup.hide()},s.prototype._checkHidePopup=function(t){this.popupObj&&this._getNodeAt(t)||(this.popupObj=void 0,this.popup&&this.popup.hide())},s.prototype.setSize=function(t,e){var i=!1,s=this.frame.canvas.width,o=this.frame.canvas.height;t!=this.constants.width||e!=this.constants.height||this.frame.style.width!=t||this.frame.style.height!=e?(this.frame.style.width=t,this.frame.style.height=e,this.frame.canvas.style.width="100%",this.frame.canvas.style.height="100%",this.frame.canvas.width=this.frame.canvas.clientWidth*this.pixelRatio,this.frame.canvas.height=this.frame.canvas.clientHeight*this.pixelRatio,this.constants.width=t,this.constants.height=e,i=!0):(this.frame.canvas.width!=this.frame.canvas.clientWidth*this.pixelRatio&&(this.frame.canvas.width=this.frame.canvas.clientWidth*this.pixelRatio,i=!0),this.frame.canvas.height!=this.frame.canvas.clientHeight*this.pixelRatio&&(this.frame.canvas.height=this.frame.canvas.clientHeight*this.pixelRatio,i=!0)),1==i&&this.emit("resize",{width:this.frame.canvas.width*this.pixelRatio,height:this.frame.canvas.height*this.pixelRatio,oldWidth:s*this.pixelRatio,oldHeight:o*this.pixelRatio})},s.prototype._setNodes=function(t){var e=this.nodesData;if(t instanceof d||t instanceof l)this.nodesData=t;else if(Array.isArray(t))this.nodesData=new d,this.nodesData.add(t);else{if(t)throw new TypeError("Array or DataSet expected");this.nodesData=new d}if(e&&a.forEach(this.nodesListeners,function(t,i){e.off(i,t)}),this.nodes={},this.nodesData){var i=this;a.forEach(this.nodesListeners,function(t,e){i.nodesData.on(e,t)});var s=this.nodesData.getIds();this._addNodes(s)}this._updateSelection()},s.prototype._addNodes=function(t){for(var e,i=0,s=t.length;s>i;i++){e=t[i];var o=this.nodesData.get(e),n=new f(o,this.images,this.groups,this.constants);if(this.nodes[e]=n,!(0!=n.xFixed&&0!=n.yFixed||null!==n.x&&null!==n.y)){var r=1*t.length+10,a=2*Math.PI*Math.random();0==n.xFixed&&(n.x=r*Math.cos(a)),0==n.yFixed&&(n.y=r*Math.sin(a))}this.moving=!0}this._updateNodeIndexList(),1==this.constants.hierarchicalLayout.enabled&&0==this.initializing&&(this._resetLevels(),this._setupHierarchicalLayout()),this._updateCalculationNodes(),this._reconnectEdges(),this._updateValueRange(this.nodes),this.updateLabels()},s.prototype._updateNodes=function(t,e){for(var i=this.nodes,s=0,o=t.length;o>s;s++){var n=t[s],r=i[n],a=e[s];r?r.setProperties(a,this.constants):(r=new f(properties,this.images,this.groups,this.constants),i[n]=r)}this.moving=!0,1==this.constants.hierarchicalLayout.enabled&&0==this.initializing&&(this._resetLevels(),this._setupHierarchicalLayout()),this._updateNodeIndexList(),this._updateValueRange(i)},s.prototype._removeNodes=function(t){for(var e=this.nodes,i=0,s=t.length;s>i;i++){var o=t[i];delete e[o]}this._updateNodeIndexList(),1==this.constants.hierarchicalLayout.enabled&&0==this.initializing&&(this._resetLevels(),this._setupHierarchicalLayout()),this._updateCalculationNodes(),this._reconnectEdges(),this._updateSelection(),this._updateValueRange(e)},s.prototype._setEdges=function(t){var e=this.edgesData;if(t instanceof d||t instanceof l)this.edgesData=t;else 
if(Array.isArray(t))this.edgesData=new d,this.edgesData.add(t);else{if(t)throw new TypeError("Array or DataSet expected");this.edgesData=new d}if(e&&a.forEach(this.edgesListeners,function(t,i){e.off(i,t)}),this.edges={},this.edgesData){var i=this;a.forEach(this.edgesListeners,function(t,e){i.edgesData.on(e,t)});var s=this.edgesData.getIds();this._addEdges(s)}this._reconnectEdges()},s.prototype._addEdges=function(t){for(var e=this.edges,i=this.edgesData,s=0,o=t.length;o>s;s++){var n=t[s],r=e[n];r&&r.disconnect();var a=i.get(n,{showInternalIds:!0});e[n]=new g(a,this,this.constants)}this.moving=!0,this._updateValueRange(e),this._createBezierNodes(),this._updateCalculationNodes(),1==this.constants.hierarchicalLayout.enabled&&0==this.initializing&&(this._resetLevels(),this._setupHierarchicalLayout())},s.prototype._updateEdges=function(t){for(var e=this.edges,i=this.edgesData,s=0,o=t.length;o>s;s++){var n=t[s],r=i.get(n),a=e[n];a?(a.disconnect(),a.setProperties(r,this.constants),a.connect()):(a=new g(r,this,this.constants),this.edges[n]=a)}this._createBezierNodes(),1==this.constants.hierarchicalLayout.enabled&&0==this.initializing&&(this._resetLevels(),this._setupHierarchicalLayout()),this.moving=!0,this._updateValueRange(e)},s.prototype._removeEdges=function(t){for(var e=this.edges,i=0,s=t.length;s>i;i++){var o=t[i],n=e[o];n&&(null!=n.via&&delete this.sectors.support.nodes[n.via.id],n.disconnect(),delete e[o])}this.moving=!0,this._updateValueRange(e),1==this.constants.hierarchicalLayout.enabled&&0==this.initializing&&(this._resetLevels(),this._setupHierarchicalLayout()),this._updateCalculationNodes()},s.prototype._reconnectEdges=function(){var t,e=this.nodes,i=this.edges;for(t in e)e.hasOwnProperty(t)&&(e[t].edges=[],e[t].dynamicEdges=[]);for(t in i)if(i.hasOwnProperty(t)){var s=i[t];s.from=null,s.to=null,s.connect()}},s.prototype._updateValueRange=function(t){var e,i=void 0,s=void 0;for(e in t)if(t.hasOwnProperty(e)){var o=t[e].getValue();void 0!==o&&(i=void 0===i?o:Math.min(o,i),s=void 0===s?o:Math.max(o,s))}if(void 0!==i&&void 0!==s)for(e in t)t.hasOwnProperty(e)&&t[e].setValueRange(i,s)},s.prototype.redraw=function(){this.setSize(this.constants.width,this.constants.height),this._redraw()},s.prototype._redraw=function(t){var e=this.frame.canvas.getContext("2d");e.setTransform(this.pixelRatio,0,0,this.pixelRatio,0,0);var i=this.frame.canvas.width*this.pixelRatio,s=this.frame.canvas.height*this.pixelRatio;e.clearRect(0,0,i,s),e.save(),e.translate(this.translation.x,this.translation.y),e.scale(this.scale,this.scale),this.canvasTopLeft={x:this._XconvertDOMtoCanvas(0),y:this._YconvertDOMtoCanvas(0)},this.canvasBottomRight={x:this._XconvertDOMtoCanvas(this.frame.canvas.clientWidth*this.pixelRatio),y:this._YconvertDOMtoCanvas(this.frame.canvas.clientHeight*this.pixelRatio)},1!=t&&(this._doInAllSectors("_drawAllSectorNodes",e),(0==this.drag.dragging||void 0===this.drag.dragging||0==this.constants.hideEdgesOnDrag)&&this._doInAllSectors("_drawEdges",e)),(0==this.drag.dragging||void 0===this.drag.dragging||0==this.constants.hideNodesOnDrag)&&this._doInAllSectors("_drawNodes",e,!1),1!=t&&1==this.controlNodesActive&&this._doInAllSectors("_drawControlNodes",e),e.restore(),1==t&&e.clearRect(0,0,i,s)},s.prototype._setTranslation=function(t,e){void 0===this.translation&&(this.translation={x:0,y:0}),void 0!==t&&(this.translation.x=t),void 
0!==e&&(this.translation.y=e),this.emit("viewChanged")},s.prototype._getTranslation=function(){return{x:this.translation.x,y:this.translation.y}},s.prototype._setScale=function(t){this.scale=t},s.prototype._getScale=function(){return this.scale},s.prototype._XconvertDOMtoCanvas=function(t){return(t-this.translation.x)/this.scale},s.prototype._XconvertCanvasToDOM=function(t){return t*this.scale+this.translation.x},s.prototype._YconvertDOMtoCanvas=function(t){return(t-this.translation.y)/this.scale},s.prototype._YconvertCanvasToDOM=function(t){return t*this.scale+this.translation.y},s.prototype.canvasToDOM=function(t){return{x:this._XconvertCanvasToDOM(t.x),y:this._YconvertCanvasToDOM(t.y)}},s.prototype.DOMtoCanvas=function(t){return{x:this._XconvertDOMtoCanvas(t.x),y:this._YconvertDOMtoCanvas(t.y)}},s.prototype._drawNodes=function(t,e){void 0===e&&(e=!1);var i=this.nodes,s=[];for(var o in i)i.hasOwnProperty(o)&&(i[o].setScaleAndPos(this.scale,this.canvasTopLeft,this.canvasBottomRight),i[o].isSelected()?s.push(o):(i[o].inArea()||e)&&i[o].draw(t));for(var n=0,r=s.length;r>n;n++)(i[s[n]].inArea()||e)&&i[s[n]].draw(t)},s.prototype._drawEdges=function(t){var e=this.edges;for(var i in e)if(e.hasOwnProperty(i)){var s=e[i];s.setScale(this.scale),s.connected&&e[i].draw(t)}},s.prototype._drawControlNodes=function(t){var e=this.edges;for(var i in e)e.hasOwnProperty(i)&&e[i]._drawControlNodes(t)},s.prototype._stabilize=function(){1==this.constants.freezeForStabilization&&this._freezeDefinedNodes();for(var t=0;this.moving&&t0)for(t in i)i.hasOwnProperty(t)&&(i[t].discreteStepLimited(e,this.constants.maxVelocity),s=!0);else for(t in i)i.hasOwnProperty(t)&&(i[t].discreteStep(e),s=!0);if(1==s){var o=this.constants.minVelocity/Math.max(this.scale,.05);return o>.5*this.constants.maxVelocity?!0:this._isMoving(o)}return!1},s.prototype._revertPhysicsState=function(){var t=this.nodes;for(var e in t)t.hasOwnProperty(e)&&t[e].revertPosition()},s.prototype._revertPhysicsTick=function(){this._doInAllActiveSectors("_revertPhysicsState"),1==this.constants.smoothCurves.enabled&&1==this.constants.smoothCurves.dynamic&&this._doInSupportSector("_revertPhysicsState")},s.prototype._physicsTick=function(){if(!this.freezeSimulation&&1==this.moving){var t=!1,e=!1;this._doInAllActiveSectors("_initializeForceCalculation");var i=this._doInAllActiveSectors("_discreteStepNodes");1==this.constants.smoothCurves.enabled&&1==this.constants.smoothCurves.dynamic&&(e=this._doInSupportSector("_discreteStepNodes"));for(var s=0;s2*e||1==this.runDoubleSpeed)&&1==this.moving&&(this._physicsTick(),0!=this.renderTime&&(this.runDoubleSpeed=!0));var i=Date.now();this._redraw(),this.renderTime=Date.now()-i,this.start()},"undefined"!=typeof window&&(window.requestAnimationFrame=window.requestAnimationFrame||window.mozRequestAnimationFrame||window.webkitRequestAnimationFrame||window.msRequestAnimationFrame),s.prototype.start=function(){if(1==this.moving||0!=this.xIncrement||0!=this.yIncrement||0!=this.zoomIncrement)this.timer||(this.timer=1==this.requiresTimeout?window.setTimeout(this._animationStep.bind(this),this.renderTimestep):window.requestAnimationFrame(this._animationStep.bind(this)));else if(this._redraw(),this.stabilizationIterations>1){var t=this,e={iterations:t.stabilizationIterations};this.stabilizationIterations=0,this.startedStabilization=!1,setTimeout(function(){t.emit("stabilized",e)},0)}else this.stabilizationIterations=0},s.prototype._handleNavigation=function(){if(0!=this.xIncrement||0!=this.yIncrement){var 
t=this._getTranslation();this._setTranslation(t.x+this.xIncrement,t.y+this.yIncrement)}if(0!=this.zoomIncrement){var e={x:this.frame.canvas.clientWidth/2,y:this.frame.canvas.clientHeight/2};this._zoom(this.scale*(1+this.zoomIncrement),e)}},s.prototype.toggleFreeze=function(){0==this.freezeSimulation?this.freezeSimulation=!0:(this.freezeSimulation=!1,this.start())},s.prototype._configureSmoothCurves=function(t){if(void 0===t&&(t=!0),1==this.constants.smoothCurves.enabled&&1==this.constants.smoothCurves.dynamic){this._createBezierNodes();for(var e in this.sectors.support.nodes)this.sectors.support.nodes.hasOwnProperty(e)&&void 0===this.edges[this.sectors.support.nodes[e].parentEdgeId]&&delete this.sectors.support.nodes[e]}else{this.sectors.support.nodes={};for(var i in this.edges)this.edges.hasOwnProperty(i)&&(this.edges[i].via=null)}this._updateCalculationNodes(),t||(this.moving=!0,this.start())},s.prototype._createBezierNodes=function(){if(1==this.constants.smoothCurves.enabled&&1==this.constants.smoothCurves.dynamic)for(var t in this.edges)if(this.edges.hasOwnProperty(t)){var e=this.edges[t];if(null==e.via){var i="edgeId:".concat(e.id);this.sectors.support.nodes[i]=new f({id:i,mass:1,shape:"circle",image:"",internalMultiplier:1},{},{},this.constants),e.via=this.sectors.support.nodes[i],e.via.parentEdgeId=e.id,e.positionBezierNode()}}},s.prototype._initializeMixinLoaders=function(){for(var t in y)y.hasOwnProperty(t)&&(s.prototype[t]=y[t])},s.prototype.storePosition=function(){console.log("storePosition is deprecated: use .storePositions() from now on."),this.storePositions()},s.prototype.storePositions=function(){var t=[];for(var e in this.nodes)if(this.nodes.hasOwnProperty(e)){var i=this.nodes[e],s=!this.nodes.xFixed,o=!this.nodes.yFixed;(this.nodesData._data[e].x!=Math.round(i.x)||this.nodesData._data[e].y!=Math.round(i.y))&&t.push({id:e,x:Math.round(i.x),y:Math.round(i.y),allowedToMoveX:s,allowedToMoveY:o})}this.nodesData.update(t)},s.prototype.getPositions=function(t){var e={};if(void 0!==t){if(1==Array.isArray(t)){for(var i=0;i=1&&(this.easingTime=0,this._redraw=null!=this.lockedOnNodeId?this._lockedRedraw:this._classicRedraw,this.emit("animationFinished"))},s.prototype._classicRedraw=function(){},s.prototype.isActive=function(){return!this.activator||this.activator.active},s.prototype.setScale=function(){return this._setScale()},s.prototype.getScale=function(){return this._getScale()},s.prototype.getCenterCoordinates=function(){return this.DOMtoCanvas({x:.5*this.frame.canvas.clientWidth,y:.5*this.frame.canvas.clientHeight})},s.prototype.getBoundingBox=function(t){return void 0!==this.nodes[t]?this.nodes[t].boundingBox:void 0},t.exports=s},function(t,e,i){function s(t,e,i){if(!e)throw"No network provided";var s=["edges","physics"],n=o.selectiveBridgeObject(s,i);this.options=n.edges,this.physics=n.physics,this.options.smoothCurves=i.smoothCurves,this.network=e,this.id=void 0,this.fromId=void 0,this.toId=void 0,this.title=void 0,this.widthSelected=this.options.width*this.options.widthSelectionMultiplier,this.value=void 0,this.selected=!1,this.hover=!1,this.labelDimensions={top:0,left:0,width:0,height:0,yLine:0},this.dirtyLabel=!0,this.from=null,this.to=null,this.via=null,this.fromBackup=null,this.toBackup=null,this.originalFromId=[],this.originalToId=[],this.connected=!1,this.widthFixed=!1,this.lengthFixed=!1,this.setProperties(t),this.controlNodesEnabled=!1,this.controlNodes={from:null,to:null,positions:{}},this.connectedNode=null}var 
o=i(1),n=i(40);s.prototype.setProperties=function(t){if(t){var e=["style","fontSize","fontFace","fontColor","fontFill","fontStrokeWidth","fontStrokeColor","width","widthSelectionMultiplier","hoverWidth","arrowScaleFactor","dash","inheritColor","labelAlignment"];switch(o.selectiveDeepExtend(e,this.options,t),void 0!==t.from&&(this.fromId=t.from),void 0!==t.to&&(this.toId=t.to),void 0!==t.id&&(this.id=t.id),void 0!==t.label&&(this.label=t.label,this.dirtyLabel=!0),void 0!==t.title&&(this.title=t.title),void 0!==t.value&&(this.value=t.value),void 0!==t.length&&(this.physics.springLength=t.length),void 0!==t.color&&(this.options.inheritColor=!1,o.isString(t.color)?(this.options.color.color=t.color,this.options.color.highlight=t.color):(void 0!==t.color.color&&(this.options.color.color=t.color.color),void 0!==t.color.highlight&&(this.options.color.highlight=t.color.highlight),void 0!==t.color.hover&&(this.options.color.hover=t.color.hover))),this.connect(),this.widthFixed=this.widthFixed||void 0!==t.width,this.lengthFixed=this.lengthFixed||void 0!==t.length,this.widthSelected=this.options.width*this.options.widthSelectionMultiplier,this.options.style){case"line":this.draw=this._drawLine;break;case"arrow":this.draw=this._drawArrow;break;case"arrow-center":this.draw=this._drawArrowCenter;break;case"dash-line":this.draw=this._drawDashLine;break;default:this.draw=this._drawLine}}},s.prototype.connect=function(){this.disconnect(),this.from=this.network.nodes[this.fromId]||null,this.to=this.network.nodes[this.toId]||null,this.connected=this.from&&this.to,this.connected?(this.from.attachEdge(this),this.to.attachEdge(this)):(this.from&&this.from.detachEdge(this),this.to&&this.to.detachEdge(this))},s.prototype.disconnect=function(){this.from&&(this.from.detachEdge(this),this.from=null),this.to&&(this.to.detachEdge(this),this.to=null),this.connected=!1},s.prototype.getTitle=function(){return"function"==typeof this.title?this.title():this.title -},s.prototype.getValue=function(){return this.value},s.prototype.setValueRange=function(t,e){if(!this.widthFixed&&void 0!==this.value){var i=(this.options.widthMax-this.options.widthMin)/(e-t);this.options.width=(this.value-t)*i+this.options.widthMin,this.widthSelected=this.options.width*this.options.widthSelectionMultiplier}},s.prototype.draw=function(){throw"Method draw not initialized in edge"},s.prototype.isOverlappingWith=function(t){if(this.connected){var e=10,i=this.from.x,s=this.from.y,o=this.to.x,n=this.to.y,r=t.left,a=t.top,h=this._getDistanceToEdge(i,s,o,n,r,a);return e>h}return!1},s.prototype._getColor=function(){var t=this.options.color;return"to"==this.options.inheritColor?t={highlight:this.to.options.color.highlight.border,hover:this.to.options.color.hover.border,color:this.to.options.color.border}:("from"==this.options.inheritColor||1==this.options.inheritColor)&&(t={highlight:this.from.options.color.highlight.border,hover:this.from.options.color.hover.border,color:this.from.options.color.border}),1==this.selected?t.highlight:1==this.hover?t.hover:t.color},s.prototype._drawLine=function(t){if(t.strokeStyle=this._getColor(),t.lineWidth=this._getLineWidth(),this.from!=this.to){var e,i=this._line(t);if(this.label){if(1==this.options.smoothCurves.enabled&&null!=i){var s=.5*(.5*(this.from.x+i.x)+.5*(this.to.x+i.x)),o=.5*(.5*(this.from.y+i.y)+.5*(this.to.y+i.y));e={x:s,y:o}}else e=this._pointOnLine(.5);this._label(t,this.label,e.x,e.y)}}else{var 
n,r,a=this.physics.springLength/4,h=this.from;h.width||h.resize(t),h.width>h.height?(n=h.x+h.width/2,r=h.y-a):(n=h.x+a,r=h.y-h.height/2),this._circle(t,n,r,a),e=this._pointOnCircle(n,r,a,.5),this._label(t,this.label,e.x,e.y)}},s.prototype._getLineWidth=function(){return 1==this.selected?Math.max(Math.min(this.widthSelected,this.options.widthMax),.3*this.networkScaleInv):1==this.hover?Math.max(Math.min(this.options.hoverWidth,this.options.widthMax),.3*this.networkScaleInv):Math.max(this.options.width,.3*this.networkScaleInv)},s.prototype._getViaCoordinates=function(){if(1==this.options.smoothCurves.dynamic&&1==this.options.smoothCurves.enabled)return this.via;if(0==this.options.smoothCurves.enabled)return{x:0,y:0};var t=null,e=null,i=this.options.smoothCurves.roundness,s=this.options.smoothCurves.type,o=Math.abs(this.from.x-this.to.x),n=Math.abs(this.from.y-this.to.y);return"discrete"==s||"diagonalCross"==s?Math.abs(this.from.x-this.to.x)this.to.y?this.from.xthis.to.x&&(t=this.from.x-i*n,e=this.from.y-i*n):this.from.ythis.to.x&&(t=this.from.x-i*n,e=this.from.y+i*n)),"discrete"==s&&(t=i*n>o?this.from.x:t)):Math.abs(this.from.x-this.to.x)>Math.abs(this.from.y-this.to.y)&&(this.from.y>this.to.y?this.from.xthis.to.x&&(t=this.from.x-i*o,e=this.from.y-i*o):this.from.ythis.to.x&&(t=this.from.x-i*o,e=this.from.y+i*o)),"discrete"==s&&(e=i*o>n?this.from.y:e)):"straightCross"==s?Math.abs(this.from.x-this.to.x)Math.abs(this.from.y-this.to.y)&&(t=this.from.xthis.to.y?this.from.xthis.to.x&&(t=this.from.x-i*n,e=this.from.y-i*n,t=this.to.x>t?this.to.x:t):this.from.ythis.to.x&&(t=this.from.x-i*n,e=this.from.y+i*n,t=this.to.x>t?this.to.x:t)):Math.abs(this.from.x-this.to.x)>Math.abs(this.from.y-this.to.y)&&(this.from.y>this.to.y?this.from.xe?this.to.y:e):this.from.x>this.to.x&&(t=this.from.x-i*o,e=this.from.y-i*o,e=this.to.y>e?this.to.y:e):this.from.ythis.to.x&&(t=this.from.x-i*o,e=this.from.y+i*o,e=this.to.yd;d++){var l=t.measureText(n[d]).width;h=l>h?l:h}var c=this.options.fontSize*r,p=i-h/2,u=s-c/2;this.labelDimensions={top:u,left:p,width:h,height:c,yLine:o}}var o=this.labelDimensions.yLine;t.save(),"horizontal"!=this.options.labelAlignment&&(t.translate(i,o),this._rotateForLabelAlignment(t),i=0,o=0),this._drawLabelRect(t),this._drawLabelText(t,i,o,n,r,a),t.restore()}},s.prototype._rotateForLabelAlignment=function(t){var e=this.from.y-this.to.y,i=this.from.x-this.to.x,s=Math.atan2(e,i);(-1>s&&0>i||s>0&&0>i)&&(s+=Math.PI),t.rotate(s)},s.prototype._drawLabelRect=function(t){if(void 0!==this.options.fontFill&&null!==this.options.fontFill&&"none"!==this.options.fontFill){t.fillStyle=this.options.fontFill;var e=2;"line-center"==this.options.labelAlignment?t.fillRect(.5*-this.labelDimensions.width,.5*-this.labelDimensions.height,this.labelDimensions.width,this.labelDimensions.height):"line-above"==this.options.labelAlignment?t.fillRect(.5*-this.labelDimensions.width,-(this.labelDimensions.height+e),this.labelDimensions.width,this.labelDimensions.height):"line-below"==this.options.labelAlignment?t.fillRect(.5*-this.labelDimensions.width,e,this.labelDimensions.width,this.labelDimensions.height):t.fillRect(this.labelDimensions.left,this.labelDimensions.top,this.labelDimensions.width,this.labelDimensions.height)}},s.prototype._drawLabelText=function(t,e,i,s,o,n){if(t.fillStyle=this.options.fontColor||"black",t.textAlign="center","horizontal"!=this.options.labelAlignment){var 
r=2;"line-above"==this.options.labelAlignment?(t.textBaseline="alphabetic",i-=2*r):"line-below"==this.options.labelAlignment?(t.textBaseline="hanging",i+=2*r):t.textBaseline="middle"}else t.textBaseline="middle";this.options.fontStrokeWidth>0&&(t.lineWidth=this.options.fontStrokeWidth,t.strokeStyle=this.options.fontStrokeColor,t.lineJoin="round");for(var a=0;o>a;a++)this.options.fontStrokeWidth>0&&t.strokeText(s[a],e,i),t.fillText(s[a],e,i),i+=n},s.prototype._drawDashLine=function(t){t.strokeStyle=this._getColor(),t.lineWidth=this._getLineWidth();var e=null;if(void 0!==t.setLineDash){t.save();var i=[0];i=void 0!==this.options.dash.length&&void 0!==this.options.dash.gap?[this.options.dash.length,this.options.dash.gap]:[5,5],t.setLineDash(i),t.lineDashOffset=0,e=this._line(t),t.setLineDash([0]),t.lineDashOffset=0,t.restore()}else t.beginPath(),t.lineCap="round",void 0!==this.options.dash.altLength?t.dashedLine(this.from.x,this.from.y,this.to.x,this.to.y,[this.options.dash.length,this.options.dash.gap,this.options.dash.altLength,this.options.dash.gap]):void 0!==this.options.dash.length&&void 0!==this.options.dash.gap?t.dashedLine(this.from.x,this.from.y,this.to.x,this.to.y,[this.options.dash.length,this.options.dash.gap]):(t.moveTo(this.from.x,this.from.y),t.lineTo(this.to.x,this.to.y)),t.stroke();if(this.label){var s;if(1==this.options.smoothCurves.enabled&&null!=e){var o=.5*(.5*(this.from.x+e.x)+.5*(this.to.x+e.x)),n=.5*(.5*(this.from.y+e.y)+.5*(this.to.y+e.y));s={x:o,y:n}}else s=this._pointOnLine(.5);this._label(t,this.label,s.x,s.y)}},s.prototype._pointOnLine=function(t){return{x:(1-t)*this.from.x+t*this.to.x,y:(1-t)*this.from.y+t*this.to.y}},s.prototype._pointOnCircle=function(t,e,i,s){var o=2*(s-3/8)*Math.PI;return{x:t+i*Math.cos(o),y:e-i*Math.sin(o)}},s.prototype._drawArrowCenter=function(t){var e;if(t.strokeStyle=this._getColor(),t.fillStyle=t.strokeStyle,t.lineWidth=this._getLineWidth(),this.from!=this.to){var i=this._line(t),s=Math.atan2(this.to.y-this.from.y,this.to.x-this.from.x),o=(10+5*this.options.width)*this.options.arrowScaleFactor;if(1==this.options.smoothCurves.enabled&&null!=i){var n=.5*(.5*(this.from.x+i.x)+.5*(this.to.x+i.x)),r=.5*(.5*(this.from.y+i.y)+.5*(this.to.y+i.y));e={x:n,y:r}}else e=this._pointOnLine(.5);t.arrow(e.x,e.y,s,o),t.fill(),t.stroke(),this.label&&this._label(t,this.label,e.x,e.y)}else{var a,h,d=.25*Math.max(100,this.physics.springLength),l=this.from;l.width||l.resize(t),l.width>l.height?(a=l.x+.5*l.width,h=l.y-d):(a=l.x+d,h=l.y-.5*l.height),this._circle(t,a,h,d);var s=.2*Math.PI,o=(10+5*this.options.width)*this.options.arrowScaleFactor;e=this._pointOnCircle(a,h,d,.5),t.arrow(e.x,e.y,s,o),t.fill(),t.stroke(),this.label&&(e=this._pointOnCircle(a,h,d,.5),this._label(t,this.label,e.x,e.y))}},s.prototype._pointOnBezier=function(t){var e=this._getViaCoordinates(),i=Math.pow(1-t,2)*this.from.x+2*t*(1-t)*e.x+Math.pow(t,2)*this.to.x,s=Math.pow(1-t,2)*this.from.y+2*t*(1-t)*e.y+Math.pow(t,2)*this.to.y;return{x:i,y:s}},s.prototype._findBorderPosition=function(t,e){var i,s,o,n,r,a=10,h=0,d=0,l=1,c=.2,p=this.to;for(1==t&&(p=this.from);l>=d&&a>h;){var u=.5*(d+l);if(i=this._pointOnBezier(u),s=Math.atan2(p.y-i.y,p.x-i.x),o=p.distanceToBorder(e,s),n=Math.sqrt(Math.pow(i.x-p.x,2)+Math.pow(i.y-p.y,2)),r=o-n,Math.abs(r)r?0==t?d=u:l=u:0==t?l=u:d=u,h++}return i.t=u,i},s.prototype._drawArrow=function(t){t.strokeStyle=this._getColor(),t.fillStyle=t.strokeStyle,t.lineWidth=this._getLineWidth();var 
e,i,s;if(this.from!=this.to){if(this._line(t),1==this.options.smoothCurves.enabled){var o=this._getViaCoordinates();s=this._findBorderPosition(!1,t);var n=this._pointOnBezier(Math.max(0,s.t-.1));e=Math.atan2(s.y-n.y,s.x-n.x)}else{e=Math.atan2(this.to.y-this.from.y,this.to.x-this.from.x);var r=this.to.x-this.from.x,a=this.to.y-this.from.y,h=Math.sqrt(r*r+a*a),d=this.to.distanceToBorder(t,e),l=(h-d)/h;s={},s.x=(1-l)*this.from.x+l*this.to.x,s.y=(1-l)*this.from.y+l*this.to.y}if(i=(10+5*this.options.width)*this.options.arrowScaleFactor,t.arrow(s.x,s.y,e,i),t.fill(),t.stroke(),this.label){var c;c=1==this.options.smoothCurves.enabled&&null!=o?this._pointOnBezier(.5):this._pointOnLine(.5),this._label(t,this.label,c.x,c.y)}}else{var p,u,m,f=this.from,g=.25*Math.max(100,this.physics.springLength);f.width||f.resize(t),f.width>f.height?(p=f.x+.5*f.width,u=f.y-g,m={x:p,y:f.y,angle:.9*Math.PI}):(p=f.x+g,u=f.y-.5*f.height,m={x:f.x,y:u,angle:.6*Math.PI}),t.beginPath(),t.arc(p,u,g,0,2*Math.PI,!1),t.stroke();var i=(10+5*this.options.width)*this.options.arrowScaleFactor;t.arrow(m.x,m.y,m.angle,i),t.fill(),t.stroke(),this.label&&(c=this._pointOnCircle(p,u,g,.5),this._label(t,this.label,c.x,c.y))}},s.prototype._getDistanceToEdge=function(t,e,i,s,o,n){var r=0;if(this.from!=this.to)if(1==this.options.smoothCurves.enabled){var a,h;if(1==this.options.smoothCurves.enabled&&1==this.options.smoothCurves.dynamic)a=this.via.x,h=this.via.y;else{var d=this._getViaCoordinates();a=d.x,h=d.y}var l,c,p,u,m,f,g,v=1e9;for(c=0;10>c;c++)p=.1*c,u=Math.pow(1-p,2)*t+2*p*(1-p)*a+Math.pow(p,2)*i,m=Math.pow(1-p,2)*e+2*p*(1-p)*h+Math.pow(p,2)*s,c>0&&(l=this._getDistanceToLine(f,g,u,m,o,n),v=v>l?l:v),f=u,g=m;r=v}else r=this._getDistanceToLine(t,e,i,s,o,n);else{var u,m,y,b,_=.25*this.physics.springLength,x=this.from;x.width>x.height?(u=x.x+.5*x.width,m=x.y-_):(u=x.x+_,m=x.y-.5*x.height),y=u-o,b=m-n,r=Math.abs(Math.sqrt(y*y+b*b)-_)}return this.labelDimensions.lefto&&this.labelDimensions.topn?0:r},s.prototype._getDistanceToLine=function(t,e,i,s,o,n){var r=i-t,a=s-e,h=r*r+a*a,d=((o-t)*r+(n-e)*a)/h;d>1?d=1:0>d&&(d=0);var l=t+d*r,c=e+d*a,p=l-o,u=c-n;return Math.sqrt(p*p+u*u)},s.prototype.setScale=function(t){this.networkScaleInv=1/t},s.prototype.select=function(){this.selected=!0},s.prototype.unselect=function(){this.selected=!1},s.prototype.positionBezierNode=function(){null!==this.via&&null!==this.from&&null!==this.to?(this.via.x=.5*(this.from.x+this.to.x),this.via.y=.5*(this.from.y+this.to.y)):(this.via.x=0,this.via.y=0)},s.prototype._drawControlNodes=function(t){if(1==this.controlNodesEnabled){if(null===this.controlNodes.from&&null===this.controlNodes.to){var e="edgeIdFrom:".concat(this.id),i="edgeIdTo:".concat(this.id),s={nodes:{group:"",radius:7,borderWidth:2,borderWidthSelected:2},physics:{damping:0},clustering:{maxNodeSizeIncrements:0,nodeScaling:{width:0,height:0,radius:0}}};this.controlNodes.from=new n({id:e,shape:"dot",color:{background:"#ff0000",border:"#3c3c3c",highlight:{background:"#07f968"}}},{},{},s),this.controlNodes.to=new 
n({id:i,shape:"dot",color:{background:"#ff0000",border:"#3c3c3c",highlight:{background:"#07f968"}}},{},{},s)}this.controlNodes.positions={},0==this.controlNodes.from.selected&&(this.controlNodes.positions.from=this.getControlNodeFromPosition(t),this.controlNodes.from.x=this.controlNodes.positions.from.x,this.controlNodes.from.y=this.controlNodes.positions.from.y),0==this.controlNodes.to.selected&&(this.controlNodes.positions.to=this.getControlNodeToPosition(t),this.controlNodes.to.x=this.controlNodes.positions.to.x,this.controlNodes.to.y=this.controlNodes.positions.to.y),this.controlNodes.from.draw(t),this.controlNodes.to.draw(t)}else this.controlNodes={from:null,to:null,positions:{}}},s.prototype._enableControlNodes=function(){this.fromBackup=this.from,this.toBackup=this.to,this.controlNodesEnabled=!0},s.prototype._disableControlNodes=function(){this.fromId=this.from.id,this.toId=this.to.id,this.fromId!=this.fromBackup.id?this.fromBackup.detachEdge(this):this.toId!=this.toBackup.id&&this.toBackup.detachEdge(this),this.fromBackup=null,this.toBackup=null,this.controlNodesEnabled=!1},s.prototype._getSelectedControlNode=function(t,e){var i=this.controlNodes.positions,s=Math.sqrt(Math.pow(t-i.from.x,2)+Math.pow(e-i.from.y,2)),o=Math.sqrt(Math.pow(t-i.to.x,2)+Math.pow(e-i.to.y,2));return 15>s?(this.connectedNode=this.from,this.from=this.controlNodes.from,this.controlNodes.from):15>o?(this.connectedNode=this.to,this.to=this.controlNodes.to,this.controlNodes.to):null},s.prototype._restoreControlNodes=function(){1==this.controlNodes.from.selected?(this.from=this.connectedNode,this.connectedNode=null,this.controlNodes.from.unselect()):1==this.controlNodes.to.selected&&(this.to=this.connectedNode,this.connectedNode=null,this.controlNodes.to.unselect())},s.prototype.getControlNodeFromPosition=function(t){var e;if(1==this.options.smoothCurves.enabled)e=this._findBorderPosition(!0,t);else{var i=Math.atan2(this.to.y-this.from.y,this.to.x-this.from.x),s=this.to.x-this.from.x,o=this.to.y-this.from.y,n=Math.sqrt(s*s+o*o),r=this.from.distanceToBorder(t,i+Math.PI),a=(n-r)/n;e={},e.x=a*this.from.x+(1-a)*this.to.x,e.y=a*this.from.y+(1-a)*this.to.y}return e},s.prototype.getControlNodeToPosition=function(t){var e;if(1==this.options.smoothCurves.enabled)e=this._findBorderPosition(!1,t);else{var i=Math.atan2(this.to.y-this.from.y,this.to.x-this.from.x),s=this.to.x-this.from.x,o=this.to.y-this.from.y,n=Math.sqrt(s*s+o*o),r=this.to.distanceToBorder(t,i),a=(n-r)/n;e={},e.x=(1-a)*this.from.x+a*this.to.x,e.y=(1-a)*this.from.y+a*this.to.y}return e},t.exports=s},function(t,e,i){function 
s(){this.clear(),this.defaultIndex=0}i(1);s.DEFAULT=[{border:"#2B7CE9",background:"#97C2FC",highlight:{border:"#2B7CE9",background:"#D2E5FF"},hover:{border:"#2B7CE9",background:"#D2E5FF"}},{border:"#FFA500",background:"#FFFF00",highlight:{border:"#FFA500",background:"#FFFFA3"},hover:{border:"#FFA500",background:"#FFFFA3"}},{border:"#FA0A10",background:"#FB7E81",highlight:{border:"#FA0A10",background:"#FFAFB1"},hover:{border:"#FA0A10",background:"#FFAFB1"}},{border:"#41A906",background:"#7BE141",highlight:{border:"#41A906",background:"#A1EC76"},hover:{border:"#41A906",background:"#A1EC76"}},{border:"#E129F0",background:"#EB7DF4",highlight:{border:"#E129F0",background:"#F0B3F5"},hover:{border:"#E129F0",background:"#F0B3F5"}},{border:"#7C29F0",background:"#AD85E4",highlight:{border:"#7C29F0",background:"#D3BDF0"},hover:{border:"#7C29F0",background:"#D3BDF0"}},{border:"#C37F00",background:"#FFA807",highlight:{border:"#C37F00",background:"#FFCA66"},hover:{border:"#C37F00",background:"#FFCA66"}},{border:"#4220FB",background:"#6E6EFD",highlight:{border:"#4220FB",background:"#9B9BFD"},hover:{border:"#4220FB",background:"#9B9BFD"}},{border:"#FD5A77",background:"#FFC0CB",highlight:{border:"#FD5A77",background:"#FFD1D9"},hover:{border:"#FD5A77",background:"#FFD1D9"}},{border:"#4AD63A",background:"#C2FABC",highlight:{border:"#4AD63A",background:"#E6FFE3"},hover:{border:"#4AD63A",background:"#E6FFE3"}}],s.prototype.clear=function(){this.groups={},this.groups.length=function(){var t=0;for(var e in this)this.hasOwnProperty(e)&&t++;return t}},s.prototype.get=function(t){var e=this.groups[t];if(void 0==e){var i=this.defaultIndex%s.DEFAULT.length;this.defaultIndex++,e={},e.color=s.DEFAULT[i],this.groups[t]=e}return e},s.prototype.add=function(t,e){return this.groups[t]=e,e},t.exports=s},function(t){function e(){this.images={},this.imageBroken={},this.callback=void 0}e.prototype.setOnloadCallback=function(t){this.callback=t},e.prototype.load=function(t,e){var i=this.images[t];if(void 0===i){var s=this;i=new Image,i.onload=function(){0==this.width&&(document.body.appendChild(this),this.width=this.offsetWidth,this.height=this.offsetHeight,document.body.removeChild(this)),s.callback&&(s.images[t]=i,s.callback(this))},i.onerror=function(){void 0===e?(console.error("Could not load image:",t),delete this.src,s.callback&&s.callback(this)):s.imageBroken[t]===!0?(console.error("Could not load brokenImage:",e),delete this.src,s.callback&&s.callback(this)):(this.src=e,s.imageBroken[t]=!0)},i.src=t}return i},t.exports=e},function(t,e,i){function s(t,e,i,s){var n=o.selectiveBridgeObject(["nodes"],s);this.options=n.nodes,this.selected=!1,this.hover=!1,this.edges=[],this.dynamicEdges=[],this.reroutedEdges={},this.fontDrawThreshold=3,this.id=void 
0,this.allowedToMoveX=!1,this.allowedToMoveY=!1,this.xFixed=!1,this.yFixed=!1,this.horizontalAlignLeft=!0,this.verticalAlignTop=!0,this.baseRadiusValue=s.nodes.radius,this.radiusFixed=!1,this.level=-1,this.preassignedLevel=!1,this.hierarchyEnumerated=!1,this.labelDimensions={top:0,left:0,width:0,height:0,yLine:0},this.boundingBox={top:0,left:0,right:0,bottom:0},this.imagelist=e,this.grouplist=i,this.fx=0,this.fy=0,this.vx=0,this.vy=0,this.x=null,this.y=null,this.previousState={vx:0,vy:0,x:0,y:0},this.damping=s.physics.damping,this.fixedData={x:null,y:null},this.setProperties(t,n),this.resetCluster(),this.dynamicEdgesLength=0,this.clusterSession=0,this.clusterSizeWidthFactor=s.clustering.nodeScaling.width,this.clusterSizeHeightFactor=s.clustering.nodeScaling.height,this.clusterSizeRadiusFactor=s.clustering.nodeScaling.radius,this.maxNodeSizeIncrements=s.clustering.maxNodeSizeIncrements,this.growthIndicator=0,this.networkScaleInv=1,this.networkScale=1,this.canvasTopLeft={x:-300,y:-300},this.canvasBottomRight={x:300,y:300},this.parentEdgeId=null}var o=i(1);s.prototype.revertPosition=function(){this.x=this.previousState.x,this.y=this.previousState.y,this.vx=this.previousState.vx,this.vy=this.previousState.vy},s.prototype.resetCluster=function(){this.formationScale=void 0,this.clusterSize=1,this.containedNodes={},this.containedEdges={},this.clusterSessions=[]},s.prototype.attachEdge=function(t){-1==this.edges.indexOf(t)&&this.edges.push(t),-1==this.dynamicEdges.indexOf(t)&&this.dynamicEdges.push(t),this.dynamicEdgesLength=this.dynamicEdges.length},s.prototype.detachEdge=function(t){var e=this.edges.indexOf(t);-1!=e&&this.edges.splice(e,1),e=this.dynamicEdges.indexOf(t),-1!=e&&this.dynamicEdges.splice(e,1),this.dynamicEdgesLength=this.dynamicEdges.length},s.prototype.setProperties=function(t,e){if(t){var i=["borderWidth","borderWidthSelected","shape","image","brokenImage","radius","fontColor","fontSize","fontFace","fontFill","fontStrokeWidth","fontStrokeColor","group","mass"];if(o.selectiveDeepExtend(i,this.options,t),void 0!==t.id&&(this.id=t.id),void 0!==t.label&&(this.label=t.label,this.originalLabel=t.label),void 0!==t.title&&(this.title=t.title),void 0!==t.x&&(this.x=t.x),void 0!==t.y&&(this.y=t.y),void 0!==t.value&&(this.value=t.value),void 0!==t.level&&(this.level=t.level,this.preassignedLevel=!0),void 0!==t.horizontalAlignLeft&&(this.horizontalAlignLeft=t.horizontalAlignLeft),void 0!==t.verticalAlignTop&&(this.verticalAlignTop=t.verticalAlignTop),void 0!==t.triggerFunction&&(this.triggerFunction=t.triggerFunction),void 0===this.id)throw"Node must have an id";if("number"==typeof this.options.group||"string"==typeof this.options.group&&""!=this.options.group){var s=this.grouplist.get(this.options.group);o.deepExtend(this.options,s),this.options.color=o.parseColor(this.options.color)}if(void 0!==t.radius&&(this.baseRadiusValue=this.options.radius),void 0!==t.color&&(this.options.color=o.parseColor(t.color)),void 0!==this.options.image&&""!=this.options.image){if(!this.imagelist)throw"No imagelist provided";this.imageObj=this.imagelist.load(this.options.image,this.options.brokenImage)}switch(void 0!==t.allowedToMoveX?(this.xFixed=!t.allowedToMoveX,this.allowedToMoveX=t.allowedToMoveX):void 0!==t.x&&0==this.allowedToMoveX&&(this.xFixed=!0),void 0!==t.allowedToMoveY?(this.yFixed=!t.allowedToMoveY,this.allowedToMoveY=t.allowedToMoveY):void 0!==t.y&&0==this.allowedToMoveY&&(this.yFixed=!0),this.radiusFixed=this.radiusFixed||void 
0!==t.radius,("image"===this.options.shape||"circularImage"===this.options.shape)&&(this.options.radiusMin=e.nodes.widthMin,this.options.radiusMax=e.nodes.widthMax),this.options.shape){case"database":this.draw=this._drawDatabase,this.resize=this._resizeDatabase;break;case"box":this.draw=this._drawBox,this.resize=this._resizeBox;break;case"circle":this.draw=this._drawCircle,this.resize=this._resizeCircle;break;case"ellipse":this.draw=this._drawEllipse,this.resize=this._resizeEllipse;break;case"image":this.draw=this._drawImage,this.resize=this._resizeImage;break;case"circularImage":this.draw=this._drawCircularImage,this.resize=this._resizeCircularImage;break;case"text":this.draw=this._drawText,this.resize=this._resizeText;break;case"dot":this.draw=this._drawDot,this.resize=this._resizeShape;break;case"square":this.draw=this._drawSquare,this.resize=this._resizeShape;break;case"triangle":this.draw=this._drawTriangle,this.resize=this._resizeShape;break;case"triangleDown":this.draw=this._drawTriangleDown,this.resize=this._resizeShape;break;case"star":this.draw=this._drawStar,this.resize=this._resizeShape;break;default:this.draw=this._drawEllipse,this.resize=this._resizeEllipse}this._reset()}},s.prototype.select=function(){this.selected=!0,this._reset()},s.prototype.unselect=function(){this.selected=!1,this._reset()},s.prototype.clearSizeCache=function(){this._reset()},s.prototype._reset=function(){this.width=void 0,this.height=void 0},s.prototype.getTitle=function(){return"function"==typeof this.title?this.title():this.title},s.prototype.distanceToBorder=function(t,e){var i=1;switch(this.width||this.resize(t),this.options.shape){case"circle":case"dot":return this.options.radius+i;case"ellipse":var s=this.width/2,o=this.height/2,n=Math.sin(e)*s,r=Math.cos(e)*o;return s*o/Math.sqrt(n*n+r*r);case"box":case"image":case"text":default:return this.width?Math.min(Math.abs(this.width/2/Math.cos(e)),Math.abs(this.height/2/Math.sin(e)))+i:0}},s.prototype._setForce=function(t,e){this.fx=t,this.fy=e},s.prototype._addForce=function(t,e){this.fx+=t,this.fy+=e},s.prototype.storeState=function(){this.previousState.x=this.x,this.previousState.y=this.y,this.previousState.vx=this.vx,this.previousState.vy=this.vy},s.prototype.discreteStep=function(t){if(this.storeState(),this.xFixed)this.fx=0,this.vx=0;else{var e=this.damping*this.vx,i=(this.fx-e)/this.options.mass;this.vx+=i*t,this.x+=this.vx*t}if(this.yFixed)this.fy=0,this.vy=0;else{var s=this.damping*this.vy,o=(this.fy-s)/this.options.mass;this.vy+=o*t,this.y+=this.vy*t}},s.prototype.discreteStepLimited=function(t,e){if(this.storeState(),this.xFixed)this.fx=0,this.vx=0;else{var i=this.damping*this.vx,s=(this.fx-i)/this.options.mass;this.vx+=s*t,this.vx=Math.abs(this.vx)>e?this.vx>0?e:-e:this.vx,this.x+=this.vx*t}if(this.yFixed)this.fy=0,this.vy=0;else{var o=this.damping*this.vy,n=(this.fy-o)/this.options.mass;this.vy+=n*t,this.vy=Math.abs(this.vy)>e?this.vy>0?e:-e:this.vy,this.y+=this.vy*t}},s.prototype.isFixed=function(){return this.xFixed&&this.yFixed},s.prototype.isMoving=function(t){var e=Math.sqrt(Math.pow(this.vx,2)+Math.pow(this.vy,2));return e>t},s.prototype.isSelected=function(){return this.selected},s.prototype.getValue=function(){return this.value},s.prototype.getDistance=function(t,e){var i=this.x-t,s=this.y-e;return Math.sqrt(i*i+s*s)},s.prototype.setValueRange=function(t,e){if(!this.radiusFixed&&void 0!==this.value)if(e==t)this.options.radius=(this.options.radiusMin+this.options.radiusMax)/2;else{var 
i=(this.options.radiusMax-this.options.radiusMin)/(e-t);this.options.radius=(this.value-t)*i+this.options.radiusMin}this.baseRadiusValue=this.options.radius},s.prototype.draw=function(){throw"Draw method not initialized for node"},s.prototype.resize=function(){throw"Resize method not initialized for node"},s.prototype.isOverlappingWith=function(t){return this.leftt.left&&this.topt.top},s.prototype._resizeImage=function(){if(!this.width||!this.height){var t,e;if(this.value){this.options.radius=this.baseRadiusValue;var i=this.imageObj.height/this.imageObj.width;void 0!==i?(t=this.options.radius||this.imageObj.width,e=this.options.radius*i||this.imageObj.height):(t=0,e=0)}else t=this.imageObj.width,e=this.imageObj.height;this.width=t,this.height=e,this.growthIndicator=0,this.width>0&&this.height>0&&(this.width+=Math.min(this.clusterSize-1,this.maxNodeSizeIncrements)*this.clusterSizeWidthFactor,this.height+=Math.min(this.clusterSize-1,this.maxNodeSizeIncrements)*this.clusterSizeHeightFactor,this.options.radius+=Math.min(this.clusterSize-1,this.maxNodeSizeIncrements)*this.clusterSizeRadiusFactor,this.growthIndicator=this.width-t)}},s.prototype._drawImageAtPosition=function(t){if(0!=this.imageObj.width){if(this.clusterSize>1){var e=this.clusterSize>1?10:0;e*=this.networkScaleInv,e=Math.min(.2*this.width,e),t.globalAlpha=.5,t.drawImage(this.imageObj,this.left-e,this.top-e,this.width+2*e,this.height+2*e)}t.globalAlpha=1,t.drawImage(this.imageObj,this.left,this.top,this.width,this.height)}},s.prototype._drawImageLabel=function(t){var e,i=0;if(this.height){i=this.height/2;var s=this.getTextSize(t);s.lineCount>=1&&(i+=s.height/2,i+=3)}e=this.y+i,this._label(t,this.label,this.x,e,void 0)},s.prototype._drawImage=function(t){this._resizeImage(t),this.left=this.x-this.width/2,this.top=this.y-this.height/2,this._drawImageAtPosition(t),this.boundingBox.top=this.top,this.boundingBox.left=this.left,this.boundingBox.right=this.left+this.width,this.boundingBox.bottom=this.top+this.height,this._drawImageLabel(t),this.boundingBox.left=Math.min(this.boundingBox.left,this.labelDimensions.left),this.boundingBox.right=Math.max(this.boundingBox.right,this.labelDimensions.left+this.labelDimensions.width),this.boundingBox.bottom=Math.max(this.boundingBox.bottom,this.boundingBox.bottom+this.labelDimensions.height)},s.prototype._resizeCircularImage=function(t){if(this.imageObj.src&&this.imageObj.width&&this.imageObj.height)this._swapToImageResizeWhenImageLoaded&&(this.width=0,this.height=0,delete this._swapToImageResizeWhenImageLoaded),this._resizeImage(t);else if(!this.width){var e=2*this.options.radius;this.width=e,this.height=e,this.options.radius+=.5*Math.min(this.clusterSize-1,this.maxNodeSizeIncrements)*this.clusterSizeRadiusFactor,this.growthIndicator=this.options.radius-.5*e,this._swapToImageResizeWhenImageLoaded=!0}},s.prototype._drawCircularImage=function(t){this._resizeCircularImage(t),this.left=this.x-this.width/2,this.top=this.y-this.height/2;var 
e=this.left+this.width/2,i=this.top+this.height/2,s=Math.abs(this.height/2);this._drawRawCircle(t,e,i,s),t.save(),t.circle(this.x,this.y,s),t.stroke(),t.clip(),this._drawImageAtPosition(t),t.restore(),this.boundingBox.top=this.y-this.options.radius,this.boundingBox.left=this.x-this.options.radius,this.boundingBox.right=this.x+this.options.radius,this.boundingBox.bottom=this.y+this.options.radius,this._drawImageLabel(t),this.boundingBox.left=Math.min(this.boundingBox.left,this.labelDimensions.left),this.boundingBox.right=Math.max(this.boundingBox.right,this.labelDimensions.left+this.labelDimensions.width),this.boundingBox.bottom=Math.max(this.boundingBox.bottom,this.boundingBox.bottom+this.labelDimensions.height)},s.prototype._resizeBox=function(t){if(!this.width){var e=5,i=this.getTextSize(t);this.width=i.width+2*e,this.height=i.height+2*e,this.width+=.5*Math.min(this.clusterSize-1,this.maxNodeSizeIncrements)*this.clusterSizeWidthFactor,this.height+=.5*Math.min(this.clusterSize-1,this.maxNodeSizeIncrements)*this.clusterSizeHeightFactor,this.growthIndicator=this.width-(i.width+2*e)}},s.prototype._drawBox=function(t){this._resizeBox(t),this.left=this.x-this.width/2,this.top=this.y-this.height/2;var e=2.5,i=this.options.borderWidth,s=this.options.borderWidthSelected||2*this.options.borderWidth;t.strokeStyle=this.selected?this.options.color.highlight.border:this.hover?this.options.color.hover.border:this.options.color.border,this.clusterSize>1&&(t.lineWidth=(this.selected?s:i)+(this.clusterSize>1?e:0),t.lineWidth*=this.networkScaleInv,t.lineWidth=Math.min(this.width,t.lineWidth),t.roundRect(this.left-2*t.lineWidth,this.top-2*t.lineWidth,this.width+4*t.lineWidth,this.height+4*t.lineWidth,this.options.radius),t.stroke()),t.lineWidth=(this.selected?s:i)+(this.clusterSize>1?e:0),t.lineWidth*=this.networkScaleInv,t.lineWidth=Math.min(this.width,t.lineWidth),t.fillStyle=this.selected?this.options.color.highlight.background:this.hover?this.options.color.hover.background:this.options.color.background,t.roundRect(this.left,this.top,this.width,this.height,this.options.radius),t.fill(),t.stroke(),this.boundingBox.top=this.top,this.boundingBox.left=this.left,this.boundingBox.right=this.left+this.width,this.boundingBox.bottom=this.top+this.height,this._label(t,this.label,this.x,this.y)},s.prototype._resizeDatabase=function(t){if(!this.width){var e=5,i=this.getTextSize(t),s=i.width+2*e;this.width=s,this.height=s,this.width+=Math.min(this.clusterSize-1,this.maxNodeSizeIncrements)*this.clusterSizeWidthFactor,this.height+=Math.min(this.clusterSize-1,this.maxNodeSizeIncrements)*this.clusterSizeHeightFactor,this.options.radius+=Math.min(this.clusterSize-1,this.maxNodeSizeIncrements)*this.clusterSizeRadiusFactor,this.growthIndicator=this.width-s}},s.prototype._drawDatabase=function(t){this._resizeDatabase(t),this.left=this.x-this.width/2,this.top=this.y-this.height/2;var 
e=2.5,i=this.options.borderWidth,s=this.options.borderWidthSelected||2*this.options.borderWidth;t.strokeStyle=this.selected?this.options.color.highlight.border:this.hover?this.options.color.hover.border:this.options.color.border,this.clusterSize>1&&(t.lineWidth=(this.selected?s:i)+(this.clusterSize>1?e:0),t.lineWidth*=this.networkScaleInv,t.lineWidth=Math.min(this.width,t.lineWidth),t.database(this.x-this.width/2-2*t.lineWidth,this.y-.5*this.height-2*t.lineWidth,this.width+4*t.lineWidth,this.height+4*t.lineWidth),t.stroke()),t.lineWidth=(this.selected?s:i)+(this.clusterSize>1?e:0),t.lineWidth*=this.networkScaleInv,t.lineWidth=Math.min(this.width,t.lineWidth),t.fillStyle=this.selected?this.options.color.highlight.background:this.hover?this.options.color.hover.background:this.options.color.background,t.database(this.x-this.width/2,this.y-.5*this.height,this.width,this.height),t.fill(),t.stroke(),this.boundingBox.top=this.top,this.boundingBox.left=this.left,this.boundingBox.right=this.left+this.width,this.boundingBox.bottom=this.top+this.height,this._label(t,this.label,this.x,this.y) -},s.prototype._resizeCircle=function(t){if(!this.width){var e=5,i=this.getTextSize(t),s=Math.max(i.width,i.height)+2*e;this.options.radius=s/2,this.width=s,this.height=s,this.options.radius+=.5*Math.min(this.clusterSize-1,this.maxNodeSizeIncrements)*this.clusterSizeRadiusFactor,this.growthIndicator=this.options.radius-.5*s}},s.prototype._drawRawCircle=function(t,e,i,s){var o=2.5,n=this.options.borderWidth,r=this.options.borderWidthSelected||2*this.options.borderWidth;t.strokeStyle=this.selected?this.options.color.highlight.border:this.hover?this.options.color.hover.border:this.options.color.border,this.clusterSize>1&&(t.lineWidth=(this.selected?r:n)+(this.clusterSize>1?o:0),t.lineWidth*=this.networkScaleInv,t.lineWidth=Math.min(this.width,t.lineWidth),t.circle(e,i,s+2*t.lineWidth),t.stroke()),t.lineWidth=(this.selected?r:n)+(this.clusterSize>1?o:0),t.lineWidth*=this.networkScaleInv,t.lineWidth=Math.min(this.width,t.lineWidth),t.fillStyle=this.selected?this.options.color.highlight.background:this.hover?this.options.color.hover.background:this.options.color.background,t.circle(this.x,this.y,s),t.fill(),t.stroke()},s.prototype._drawCircle=function(t){this._resizeCircle(t),this.left=this.x-this.width/2,this.top=this.y-this.height/2,this._drawRawCircle(t,this.x,this.y,this.options.radius),this.boundingBox.top=this.y-this.options.radius,this.boundingBox.left=this.x-this.options.radius,this.boundingBox.right=this.x+this.options.radius,this.boundingBox.bottom=this.y+this.options.radius,this._label(t,this.label,this.x,this.y)},s.prototype._resizeEllipse=function(t){if(!this.width){var 
e=this.getTextSize(t);this.width=1.5*e.width,this.height=2*e.height,this.width1&&(t.lineWidth=(this.selected?s:i)+(this.clusterSize>1?e:0),t.lineWidth*=this.networkScaleInv,t.lineWidth=Math.min(this.width,t.lineWidth),t.ellipse(this.left-2*t.lineWidth,this.top-2*t.lineWidth,this.width+4*t.lineWidth,this.height+4*t.lineWidth),t.stroke()),t.lineWidth=(this.selected?s:i)+(this.clusterSize>1?e:0),t.lineWidth*=this.networkScaleInv,t.lineWidth=Math.min(this.width,t.lineWidth),t.fillStyle=this.selected?this.options.color.highlight.background:this.hover?this.options.color.hover.background:this.options.color.background,t.ellipse(this.left,this.top,this.width,this.height),t.fill(),t.stroke(),this.boundingBox.top=this.top,this.boundingBox.left=this.left,this.boundingBox.right=this.left+this.width,this.boundingBox.bottom=this.top+this.height,this._label(t,this.label,this.x,this.y)},s.prototype._drawDot=function(t){this._drawShape(t,"circle")},s.prototype._drawTriangle=function(t){this._drawShape(t,"triangle")},s.prototype._drawTriangleDown=function(t){this._drawShape(t,"triangleDown")},s.prototype._drawSquare=function(t){this._drawShape(t,"square")},s.prototype._drawStar=function(t){this._drawShape(t,"star")},s.prototype._resizeShape=function(){if(!this.width){this.options.radius=this.baseRadiusValue;var t=2*this.options.radius;this.width=t,this.height=t,this.width+=Math.min(this.clusterSize-1,this.maxNodeSizeIncrements)*this.clusterSizeWidthFactor,this.height+=Math.min(this.clusterSize-1,this.maxNodeSizeIncrements)*this.clusterSizeHeightFactor,this.options.radius+=.5*Math.min(this.clusterSize-1,this.maxNodeSizeIncrements)*this.clusterSizeRadiusFactor,this.growthIndicator=this.width-t}},s.prototype._drawShape=function(t,e){this._resizeShape(t),this.left=this.x-this.width/2,this.top=this.y-this.height/2;var i=2.5,s=this.options.borderWidth,o=this.options.borderWidthSelected||2*this.options.borderWidth,n=2;switch(e){case"dot":n=2;break;case"square":n=2;break;case"triangle":n=3;break;case"triangleDown":n=3;break;case"star":n=4}t.strokeStyle=this.selected?this.options.color.highlight.border:this.hover?this.options.color.hover.border:this.options.color.border,this.clusterSize>1&&(t.lineWidth=(this.selected?o:s)+(this.clusterSize>1?i:0),t.lineWidth*=this.networkScaleInv,t.lineWidth=Math.min(this.width,t.lineWidth),t[e](this.x,this.y,this.options.radius+n*t.lineWidth),t.stroke()),t.lineWidth=(this.selected?o:s)+(this.clusterSize>1?i:0),t.lineWidth*=this.networkScaleInv,t.lineWidth=Math.min(this.width,t.lineWidth),t.fillStyle=this.selected?this.options.color.highlight.background:this.hover?this.options.color.hover.background:this.options.color.background,t[e](this.x,this.y,this.options.radius),t.fill(),t.stroke(),this.boundingBox.top=this.y-this.options.radius,this.boundingBox.left=this.x-this.options.radius,this.boundingBox.right=this.x+this.options.radius,this.boundingBox.bottom=this.y+this.options.radius,this.label&&(this._label(t,this.label,this.x,this.y+this.height/2,void 0,"hanging",!0),this.boundingBox.left=Math.min(this.boundingBox.left,this.labelDimensions.left),this.boundingBox.right=Math.max(this.boundingBox.right,this.labelDimensions.left+this.labelDimensions.width),this.boundingBox.bottom=Math.max(this.boundingBox.bottom,this.boundingBox.bottom+this.labelDimensions.height))},s.prototype._resizeText=function(t){if(!this.width){var 
e=5,i=this.getTextSize(t);this.width=i.width+2*e,this.height=i.height+2*e,this.width+=Math.min(this.clusterSize-1,this.maxNodeSizeIncrements)*this.clusterSizeWidthFactor,this.height+=Math.min(this.clusterSize-1,this.maxNodeSizeIncrements)*this.clusterSizeHeightFactor,this.options.radius+=Math.min(this.clusterSize-1,this.maxNodeSizeIncrements)*this.clusterSizeRadiusFactor,this.growthIndicator=this.width-(i.width+2*e)}},s.prototype._drawText=function(t){this._resizeText(t),this.left=this.x-this.width/2,this.top=this.y-this.height/2,this._label(t,this.label,this.x,this.y),this.boundingBox.top=this.top,this.boundingBox.left=this.left,this.boundingBox.right=this.left+this.width,this.boundingBox.bottom=this.top+this.height},s.prototype._label=function(t,e,i,s,o,n,r){if(e&&Number(this.options.fontSize)*this.networkScale>this.fontDrawThreshold){t.font=(this.selected?"bold ":"")+this.options.fontSize+"px "+this.options.fontFace;var a=e.split("\n"),h=a.length,d=Number(this.options.fontSize),l=s+(1-h)/2*d;1==r&&(l=s+(1-h)/(2*d));for(var c=t.measureText(a[0]).width,p=1;h>p;p++){var u=t.measureText(a[p]).width;c=u>c?u:c}var m=this.options.fontSize*h,f=i-c/2,g=s-m/2;"hanging"==n&&(g+=.5*d,g+=4,l+=4),this.labelDimensions={top:g,left:f,width:c,height:m,yLine:l},void 0!==this.options.fontFill&&null!==this.options.fontFill&&"none"!==this.options.fontFill&&(t.fillStyle=this.options.fontFill,t.fillRect(f,g,c,m)),t.fillStyle=this.options.fontColor||"black",t.textAlign=o||"center",t.textBaseline=n||"middle",this.options.fontStrokeWidth>0&&(t.lineWidth=this.options.fontStrokeWidth,t.strokeStyle=this.options.fontStrokeColor,t.lineJoin="round");for(var p=0;h>p;p++)this.options.fontStrokeWidth&&t.strokeText(a[p],i,l),t.fillText(a[p],i,l),l+=d}},s.prototype.getTextSize=function(t){if(void 0!==this.label){t.font=(this.selected?"bold ":"")+this.options.fontSize+"px "+this.options.fontFace;for(var e=this.label.split("\n"),i=(Number(this.options.fontSize)+4)*e.length,s=0,o=0,n=e.length;n>o;o++)s=Math.max(s,t.measureText(e[o]).width);return{width:s,height:i,lineCount:e.length}}return{width:0,height:0,lineCount:0}},s.prototype.inArea=function(){return void 0!==this.width?this.x+this.width*this.networkScaleInv>=this.canvasTopLeft.x&&this.x-this.width*this.networkScaleInv=this.canvasTopLeft.y&&this.y-this.height*this.networkScaleInv=this.canvasTopLeft.x&&this.x=this.canvasTopLeft.y&&this.ys&&(n=s-e-this.padding),no&&(r=o-i-this.padding),ri;i++)if(e.id===r.nodes[i].id){o=r.nodes[i];break}for(o||(o={id:e.id},t.node&&(o.attr=a(o.attr,t.node))),i=n.length-1;i>=0;i--){var h=n[i];h.nodes||(h.nodes=[]),-1==h.nodes.indexOf(o)&&h.nodes.push(o)}e.attr&&(o.attr=a(o.attr,e.attr))}function l(t,e){if(t.edges||(t.edges=[]),t.edges.push(e),t.edge){var i=a({},t.edge);e.attr=a(i,e.attr)}}function c(t,e,i,s,o){var n={from:e,to:i,type:s};return t.edge&&(n.attr=a({},t.edge)),n.attr=a(n.attr||{},o),n}function p(){for(N=D.NULL,k="";" "==E||" "==E||"\n"==E||"\r"==E;)o();do{var t=!1;if("#"==E){for(var e=O-1;" "==T.charAt(e)||" "==T.charAt(e);)e--;if("\n"==T.charAt(e)||""==T.charAt(e)){for(;""!=E&&"\n"!=E;)o();t=!0}}if("/"==E&&"/"==n()){for(;""!=E&&"\n"!=E;)o();t=!0}if("/"==E&&"*"==n()){for(;""!=E;){if("*"==E&&"/"==n()){o(),o();break}o()}t=!0}for(;" "==E||" "==E||"\n"==E||"\r"==E;)o()}while(t);if(""==E)return void(N=D.DELIMITER);var i=E+n();if(C[i])return N=D.DELIMITER,k=i,o(),void o();if(C[E])return N=D.DELIMITER,k=E,void 
o();if(r(E)||"-"==E){for(k+=E,o();r(E);)k+=E,o();return"false"==k?k=!1:"true"==k?k=!0:isNaN(Number(k))||(k=Number(k)),void(N=D.IDENTIFIER)}if('"'==E){for(o();""!=E&&('"'!=E||'"'==E&&'"'==n());)k+=E,'"'==E&&o(),o();if('"'!=E)throw x('End of string " expected');return o(),void(N=D.IDENTIFIER)}for(N=D.UNKNOWN;""!=E;)k+=E,o();throw new SyntaxError('Syntax error in part "'+w(k,30)+'"')}function u(){var t={};if(s(),p(),"strict"==k&&(t.strict=!0,p()),("graph"==k||"digraph"==k)&&(t.type=k,p()),N==D.IDENTIFIER&&(t.id=k,p()),"{"!=k)throw x("Angle bracket { expected");if(p(),m(t),"}"!=k)throw x("Angle bracket } expected");if(p(),""!==k)throw x("End of file expected");return p(),delete t.node,delete t.edge,delete t.graph,t}function m(t){for(;""!==k&&"}"!=k;)f(t),";"==k&&p()}function f(t){var e=g(t);if(e)return void b(t,e);var i=v(t);if(!i){if(N!=D.IDENTIFIER)throw x("Identifier expected");var s=k;if(p(),"="==k){if(p(),N!=D.IDENTIFIER)throw x("Identifier expected");t[s]=k,p()}else y(t,s)}}function g(t){var e=null;if("subgraph"==k&&(e={},e.type="subgraph",p(),N==D.IDENTIFIER&&(e.id=k,p())),"{"==k){if(p(),e||(e={}),e.parent=t,e.node=t.node,e.edge=t.edge,e.graph=t.graph,m(e),"}"!=k)throw x("Angle bracket } expected");p(),delete e.node,delete e.edge,delete e.graph,delete e.parent,t.subgraphs||(t.subgraphs=[]),t.subgraphs.push(e)}return e}function v(t){return"node"==k?(p(),t.node=_(),"node"):"edge"==k?(p(),t.edge=_(),"edge"):"graph"==k?(p(),t.graph=_(),"graph"):null}function y(t,e){var i={id:e},s=_();s&&(i.attr=s),d(t,i),b(t,e)}function b(t,e){for(;"->"==k||"--"==k;){var i,s=k;p();var o=g(t);if(o)i=o;else{if(N!=D.IDENTIFIER)throw x("Identifier or subgraph expected");i=k,d(t,{id:i}),p()}var n=_(),r=c(t,e,i,s,n);l(t,r),e=i}}function _(){for(var t=null;"["==k;){for(p(),t={};""!==k&&"]"!=k;){if(N!=D.IDENTIFIER)throw x("Attribute name expected");var e=k;if(p(),"="!=k)throw x("Equal sign = expected");if(p(),N!=D.IDENTIFIER)throw x("Attribute value expected");var i=k;h(t,e,i),p(),","==k&&p()}if("]"!=k)throw x("Bracket ] expected");p()}return t}function x(t){return new SyntaxError(t+', got "'+w(k,30)+'" (char '+O+")")}function w(t,e){return t.length<=e?t:t.substr(0,27)+"..."}function S(t,e,i){Array.isArray(t)?t.forEach(function(t){Array.isArray(e)?e.forEach(function(e){i(t,e)}):i(t,e)}):Array.isArray(e)?e.forEach(function(e){i(t,e)}):i(t,e)}function M(t){var e=i(t),s={nodes:[],edges:[],options:{}};if(e.nodes&&e.nodes.forEach(function(t){var e={id:t.id,label:String(t.label||t.id)};a(e,t.attr),e.image&&(e.shape="image"),s.nodes.push(e)}),e.edges){var o=function(t){var e={from:t.from,to:t.to};return a(e,t.attr),e.style="->"==t.type?"arrow":"line",e};e.edges.forEach(function(t){var e,i;e=t.from instanceof Object?t.from.nodes:{id:t.from},i=t.to instanceof Object?t.to.nodes:{id:t.to},t.from instanceof Object&&t.from.edges&&t.from.edges.forEach(function(t){var e=o(t);s.edges.push(e)}),S(e,i,function(e,i){var n=c(s,e.id,i.id,t.type,t.attr),r=o(n);s.edges.push(r)}),t.to instanceof Object&&t.to.edges&&t.to.edges.forEach(function(t){var e=o(t);s.edges.push(e)})})}return e.attr&&(s.options=e.attr),s}var D={NULL:0,DELIMITER:1,IDENTIFIER:2,UNKNOWN:3},C={"{":!0,"}":!0,"[":!0,"]":!0,";":!0,"=":!0,",":!0,"->":!0,"--":!0},T="",O=0,E="",k="",N=D.NULL,I=/[a-zA-Z_0-9.:#]/;e.parseDOT=i,e.DOTToGraph=M},function(t,e){function i(t,e){var i=[],s=[];this.options={edges:{inheritColor:!0},nodes:{allowedToMove:!1,parseColor:!1}},void 
0!==e&&(this.options.nodes.allowedToMove=e.allowedToMove|!1,this.options.nodes.parseColor=e.parseColor|!1,this.options.edges.inheritColor=e.inheritColor|!0);for(var o=t.edges,n=t.nodes,r=0;r=s&&(s=864e5),e=new Date(e.valueOf()-.05*s),i=new Date(i.valueOf()+.05*s)}return{start:e,end:i}},s.prototype.setWindow=function(t,e,i){var s=i&&void 0!==i.animate?i.animate:!0;if(1==arguments.length){var o=arguments[0];this.range.setRange(o.start,o.end,s)}else this.range.setRange(t,e,s)},s.prototype.moveTo=function(t,e){var i=this.range.end-this.range.start,s=r.convert(t,"Date").valueOf(),o=s-i/2,n=s+i/2,a=e&&void 0!==e.animate?e.animate:!0;this.range.setRange(o,n,a)},s.prototype.getWindow=function(){var t=this.range.getRange();return{start:new Date(t.start),end:new Date(t.end)}},s.prototype.redraw=function(){var t=!1,e=this.options,i=this.props,s=this.dom;if(s){h.updateHiddenDates(this.body,this.options.hiddenDates),"top"==e.orientation?(r.addClassName(s.root,"top"),r.removeClassName(s.root,"bottom")):(r.removeClassName(s.root,"top"),r.addClassName(s.root,"bottom")),s.root.style.maxHeight=r.option.asSize(e.maxHeight,""),s.root.style.minHeight=r.option.asSize(e.minHeight,""),s.root.style.width=r.option.asSize(e.width,""),i.border.left=(s.centerContainer.offsetWidth-s.centerContainer.clientWidth)/2,i.border.right=i.border.left,i.border.top=(s.centerContainer.offsetHeight-s.centerContainer.clientHeight)/2,i.border.bottom=i.border.top;var o=s.root.offsetHeight-s.root.clientHeight,n=s.root.offsetWidth-s.root.clientWidth;0===s.centerContainer.clientHeight&&(i.border.left=i.border.top,i.border.right=i.border.left),0===s.root.clientHeight&&(n=o),i.center.height=s.center.offsetHeight,i.left.height=s.left.offsetHeight,i.right.height=s.right.offsetHeight,i.top.height=s.top.clientHeight||-i.border.top,i.bottom.height=s.bottom.clientHeight||-i.border.bottom;var a=Math.max(i.left.height,i.center.height,i.right.height),d=i.top.height+a+i.bottom.height+o+i.border.top+i.border.bottom;s.root.style.height=r.option.asSize(e.height,d+"px"),i.root.height=s.root.offsetHeight,i.background.height=i.root.height-o;var l=i.root.height-i.top.height-i.bottom.height-o;i.centerContainer.height=l,i.leftContainer.height=l,i.rightContainer.height=i.leftContainer.height,i.root.width=s.root.offsetWidth,i.background.width=i.root.width-n,i.left.width=s.leftContainer.clientWidth||-i.border.left,i.leftContainer.width=i.left.width,i.right.width=s.rightContainer.clientWidth||-i.border.right,i.rightContainer.width=i.right.width;var 
c=i.root.width-i.left.width-i.right.width-n;i.center.width=c,i.centerContainer.width=c,i.top.width=c,i.bottom.width=c,s.background.style.height=i.background.height+"px",s.backgroundVertical.style.height=i.background.height+"px",s.backgroundHorizontal.style.height=i.centerContainer.height+"px",s.centerContainer.style.height=i.centerContainer.height+"px",s.leftContainer.style.height=i.leftContainer.height+"px",s.rightContainer.style.height=i.rightContainer.height+"px",s.background.style.width=i.background.width+"px",s.backgroundVertical.style.width=i.centerContainer.width+"px",s.backgroundHorizontal.style.width=i.background.width+"px",s.centerContainer.style.width=i.center.width+"px",s.top.style.width=i.top.width+"px",s.bottom.style.width=i.bottom.width+"px",s.background.style.left="0",s.background.style.top="0",s.backgroundVertical.style.left=i.left.width+i.border.left+"px",s.backgroundVertical.style.top="0",s.backgroundHorizontal.style.left="0",s.backgroundHorizontal.style.top=i.top.height+"px",s.centerContainer.style.left=i.left.width+"px",s.centerContainer.style.top=i.top.height+"px",s.leftContainer.style.left="0",s.leftContainer.style.top=i.top.height+"px",s.rightContainer.style.left=i.left.width+i.center.width+"px",s.rightContainer.style.top=i.top.height+"px",s.top.style.left=i.left.width+"px",s.top.style.top="0",s.bottom.style.left=i.left.width+"px",s.bottom.style.top=i.top.height+i.centerContainer.height+"px",this._updateScrollTop();var p=this.props.scrollTop;"bottom"==e.orientation&&(p+=Math.max(this.props.centerContainer.height-this.props.center.height-this.props.border.top-this.props.border.bottom,0)),s.center.style.left="0",s.center.style.top=p+"px",s.left.style.left="0",s.left.style.top=p+"px",s.right.style.left="0",s.right.style.top=p+"px";var u=0==this.props.scrollTop?"hidden":"",m=this.props.scrollTop==this.props.scrollTopMin?"hidden":"";if(s.shadowTop.style.visibility=u,s.shadowBottom.style.visibility=m,s.shadowTopLeft.style.visibility=u,s.shadowBottomLeft.style.visibility=m,s.shadowTopRight.style.visibility=u,s.shadowBottomRight.style.visibility=m,this.components.forEach(function(e){t=e.redraw()||t}),t){var f=3;this.redrawCount0&&(this.props.scrollTop=0),this.props.scrollTops;s++){var o=s%2===0?1.3*i:.5*i;this.lineTo(t+o*Math.sin(2*s*Math.PI/10),e-o*Math.cos(2*s*Math.PI/10))}this.closePath()},CanvasRenderingContext2D.prototype.roundRect=function(t,e,i,s,o){var n=Math.PI/180;0>i-2*o&&(o=i/2),0>s-2*o&&(o=s/2),this.beginPath(),this.moveTo(t+o,e),this.lineTo(t+i-o,e),this.arc(t+i-o,e+o,o,270*n,360*n,!1),this.lineTo(t+i,e+s-o),this.arc(t+i-o,e+s-o,o,0,90*n,!1),this.lineTo(t+o,e+s),this.arc(t+o,e+s-o,o,90*n,180*n,!1),this.lineTo(t,e+o),this.arc(t+o,e+o,o,180*n,270*n,!1)},CanvasRenderingContext2D.prototype.ellipse=function(t,e,i,s){var o=.5522848,n=i/2*o,r=s/2*o,a=t+i,h=e+s,d=t+i/2,l=e+s/2; -this.beginPath(),this.moveTo(t,l),this.bezierCurveTo(t,l-r,d-n,e,d,e),this.bezierCurveTo(d+n,e,a,l-r,a,l),this.bezierCurveTo(a,l+r,d+n,h,d,h),this.bezierCurveTo(d-n,h,t,l+r,t,l)},CanvasRenderingContext2D.prototype.database=function(t,e,i,s){var o=1/3,n=i,r=s*o,a=.5522848,h=n/2*a,d=r/2*a,l=t+n,c=e+r,p=t+n/2,u=e+r/2,m=e+(s-r/2),f=e+s;this.beginPath(),this.moveTo(l,u),this.bezierCurveTo(l,u+d,p+h,c,p,c),this.bezierCurveTo(p-h,c,t,u+d,t,u),this.bezierCurveTo(t,u-d,p-h,e,p,e),this.bezierCurveTo(p+h,e,l,u-d,l,u),this.lineTo(l,m),this.bezierCurveTo(l,m+d,p+h,f,p,f),this.bezierCurveTo(p-h,f,t,m+d,t,m),this.lineTo(t,u)},CanvasRenderingContext2D.prototype.arrow=function(t,e,i,s){var 
o=t-s*Math.cos(i),n=e-s*Math.sin(i),r=t-.9*s*Math.cos(i),a=e-.9*s*Math.sin(i),h=o+s/3*Math.cos(i+.5*Math.PI),d=n+s/3*Math.sin(i+.5*Math.PI),l=o+s/3*Math.cos(i-.5*Math.PI),c=n+s/3*Math.sin(i-.5*Math.PI);this.beginPath(),this.moveTo(t,e),this.lineTo(h,d),this.lineTo(r,a),this.lineTo(l,c),this.closePath()},CanvasRenderingContext2D.prototype.dashedLine=function(t,e,i,s,o){o||(o=[10,5]),0==p&&(p=.001);var n=o.length;this.moveTo(t,e);for(var r=i-t,a=s-e,h=a/r,d=Math.sqrt(r*r+a*a),l=0,c=!0;d>=.1;){var p=o[l++%n];p>d&&(p=d);var u=Math.sqrt(p*p/(1+h*h));0>r&&(u=-u),t+=u,e+=h*u,this[c?"lineTo":"moveTo"](t,e),d-=p,c=!c}})},function(t,e,i){function s(t,e){this.groupId=t,this.options=e}var o=i(2),n=i(53);s.prototype.getYRange=function(t){for(var e=t[0].y,i=t[0].y,s=0;st[s].y?t[s].y:e,i=i0){var r,a,h=Number(i.svg.style.height.replace("px",""));if(r=o.getSVGElement("path",i.svgElements,i.svg),r.setAttributeNS(null,"class",e.className),void 0!==e.style&&r.setAttributeNS(null,"style",e.style),a=1==e.options.catmullRom.enabled?s._catmullRom(t,e):s._linear(t),1==e.options.shaded.enabled){var d,l=o.getSVGElement("path",i.svgElements,i.svg);d="top"==e.options.shaded.orientation?"M"+t[0].x+",0 "+a+"L"+t[t.length-1].x+",0":"M"+t[0].x+","+h+" "+a+"L"+t[t.length-1].x+","+h,l.setAttributeNS(null,"class",e.className+" fill"),void 0!==e.options.shaded.style&&l.setAttributeNS(null,"style",e.options.shaded.style),l.setAttributeNS(null,"d",d)}r.setAttributeNS(null,"d","M"+a),1==e.options.drawPoints.enabled&&n.draw(t,e,i)}},s._catmullRomUniform=function(t){for(var e,i,s,o,n,r,a=Math.round(t[0].x)+","+Math.round(t[0].y)+" ",h=1/6,d=t.length,l=0;d-1>l;l++)e=0==l?t[0]:t[l-1],i=t[l],s=t[l+1],o=d>l+2?t[l+2]:s,n={x:(-e.x+6*i.x+s.x)*h,y:(-e.y+6*i.y+s.y)*h},r={x:(i.x+6*s.x-o.x)*h,y:(i.y+6*s.y-o.y)*h},a+="C"+n.x+","+n.y+" "+r.x+","+r.y+" "+s.x+","+s.y+" ";return a},s._catmullRom=function(t,e){var i=e.options.catmullRom.alpha;if(0==i||void 0===i)return this._catmullRomUniform(t);for(var s,o,n,r,a,h,d,l,c,p,u,m,f,g,v,y,b,_,x,w=Math.round(t[0].x)+","+Math.round(t[0].y)+" ",S=t.length,M=0;S-1>M;M++)s=0==M?t[0]:t[M-1],o=t[M],n=t[M+1],r=S>M+2?t[M+2]:n,d=Math.sqrt(Math.pow(s.x-o.x,2)+Math.pow(s.y-o.y,2)),l=Math.sqrt(Math.pow(o.x-n.x,2)+Math.pow(o.y-n.y,2)),c=Math.sqrt(Math.pow(n.x-r.x,2)+Math.pow(n.y-r.y,2)),g=Math.pow(c,i),y=Math.pow(c,2*i),v=Math.pow(l,i),b=Math.pow(l,2*i),x=Math.pow(d,i),_=Math.pow(d,2*i),p=2*_+3*x*v+b,u=2*y+3*g*v+b,m=3*x*(x+v),m>0&&(m=1/m),f=3*g*(g+v),f>0&&(f=1/f),a={x:(-b*s.x+p*o.x+_*n.x)*m,y:(-b*s.y+p*o.y+_*n.y)*m},h={x:(y*o.x+u*n.x-b*r.x)*f,y:(y*o.y+u*n.y-b*r.y)*f},0==a.x&&0==a.y&&(a=o),0==h.x&&0==h.y&&(h=n),w+="C"+a.x+","+a.y+" "+h.x+","+h.y+" "+n.x+","+n.y+" ";return w},s._linear=function(t){for(var e="",i=0;it[s].y?t[s].y:e,i=i0&&(n=Math.min(n,Math.abs(c[d-1].x-r))),a=s._getSafeDrawData(n,h,m);else{var g=d+(p[r].amount-p[r].resolved),v=d-(p[r].resolved+1);g0&&(n=Math.min(n,Math.abs(c[v].x-r))),a=s._getSafeDrawData(n,h,m),p[r].resolved+=1,"stack"==h.options.barChart.handleOverlap?(f=p[r].accumulated,p[r].accumulated+=h.zeroPosition-c[d].y):"sideBySide"==h.options.barChart.handleOverlap&&(a.width=a.width/p[r].amount,a.offset+=p[r].resolved*a.width-.5*a.width*(p[r].amount+1),"left"==h.options.barChart.align?a.offset-=.5*a.width:"right"==h.options.barChart.align&&(a.offset+=.5*a.width))}o.drawBar(c[d].x+a.offset,c[d].y-f,a.width,h.zeroPosition-c[d].y,h.className+" 
bar",i.svgElements,i.svg),1==h.options.drawPoints.enabled&&o.drawPoint(c[d].x+a.offset,c[d].y,h,i.svgElements,i.svg)}},s._getDataIntersections=function(t,e){for(var i,s=0;s0&&(i=Math.min(i,Math.abs(e[s-1].x-e[s].x))),0==i&&(void 0===t[e[s].x]&&(t[e[s].x]={amount:0,resolved:0,accumulated:0}),t[e[s].x].amount+=1)},s._getSafeDrawData=function(t,e,i){var s,o;return t0?(s=i>t?i:t,o=0,"left"==e.options.barChart.align?o-=.5*t:"right"==e.options.barChart.align&&(o+=.5*t)):(s=e.options.barChart.width,o=0,"left"==e.options.barChart.align?o-=.5*e.options.barChart.width:"right"==e.options.barChart.align&&(o+=.5*e.options.barChart.width)),{width:s,offset:o}},s.getStackedBarYRange=function(t,e,i,o,n){if(t.length>0){t.sort(function(t,e){return t.x==e.x?t.groupId-e.groupId:t.x-e.x});var r={};s._getDataIntersections(r,t),e[o]=s._getStackedBarYRange(r,t),e[o].yAxisOrientation=n,i.push(o)}},s._getStackedBarYRange=function(t,e){for(var i,s=e[0].y,o=e[0].y,n=0;ne[n].y?e[n].y:s,o=ot[r].accumulated?t[r].accumulated:s,o=ot[s].y?t[s].y:e,i=is;++s)i[s].apply(this,e)}return this},e.prototype.listeners=function(t){return this._callbacks=this._callbacks||{},this._callbacks[t]||[]},e.prototype.hasListeners=function(t){return!!this.listeners(t).length}},function(t,e,i){var s;(function(t,o){(function(n){function r(t,e,i){switch(arguments.length){case 2:return null!=t?t:e;case 3:return null!=t?t:null!=e?e:i;default:throw new Error("Implement me")}}function a(t,e){return Ie.call(t,e)}function h(){return{empty:!1,unusedTokens:[],unusedInput:[],overflow:-2,charsLeftOver:0,nullInput:!1,invalidMonth:null,invalidFormat:!1,userInvalidated:!1,iso:!1}}function d(t){Ce.suppressDeprecationWarnings===!1&&"undefined"!=typeof console&&console.warn&&console.warn("Deprecation warning: "+t)}function l(t,e){var i=!0;return b(function(){return i&&(d(t),i=!1),e.apply(this,arguments)},e)}function c(t,e){Si[t]||(d(e),Si[t]=!0)}function p(t,e){return function(i){return w(t.call(this,i),e)}}function u(t,e){return function(i){return this.localeData().ordinal(t.call(this,i),e)}}function m(t,e){var i,s,o=12*(e.year()-t.year())+(e.month()-t.month()),n=t.clone().add(o,"months");return 0>e-n?(i=t.clone().add(o-1,"months"),s=(e-n)/(n-i)):(i=t.clone().add(o+1,"months"),s=(e-n)/(i-n)),-(o+s)}function f(t,e,i){var s;return null==i?e:null!=t.meridiemHour?t.meridiemHour(e,i):null!=t.isPM?(s=t.isPM(i),s&&12>e&&(e+=12),s||12!==e||(e=0),e):e}function g(){}function v(t,e){e!==!1&&F(t),_(this,t),this._d=new Date(+t._d),Di===!1&&(Di=!0,Ce.updateOffset(this),Di=!1)}function y(t){var e=N(t),i=e.year||0,s=e.quarter||0,o=e.month||0,n=e.week||0,r=e.day||0,a=e.hour||0,h=e.minute||0,d=e.second||0,l=e.millisecond||0;this._milliseconds=+l+1e3*d+6e4*h+36e5*a,this._days=+r+7*n,this._months=+o+3*s+12*i,this._data={},this._locale=Ce.localeData(),this._bubble()}function b(t,e){for(var i in e)a(e,i)&&(t[i]=e[i]);return a(e,"toString")&&(t.toString=e.toString),a(e,"valueOf")&&(t.valueOf=e.valueOf),t}function _(t,e){var i,s,o;if("undefined"!=typeof e._isAMomentObject&&(t._isAMomentObject=e._isAMomentObject),"undefined"!=typeof e._i&&(t._i=e._i),"undefined"!=typeof e._f&&(t._f=e._f),"undefined"!=typeof e._l&&(t._l=e._l),"undefined"!=typeof e._strict&&(t._strict=e._strict),"undefined"!=typeof e._tzm&&(t._tzm=e._tzm),"undefined"!=typeof e._isUTC&&(t._isUTC=e._isUTC),"undefined"!=typeof e._offset&&(t._offset=e._offset),"undefined"!=typeof e._pf&&(t._pf=e._pf),"undefined"!=typeof e._locale&&(t._locale=e._locale),Ye.length>0)for(i in Ye)s=Ye[i],o=e[s],"undefined"!=typeof 
o&&(t[s]=o);return t}function x(t){return 0>t?Math.ceil(t):Math.floor(t)}function w(t,e,i){for(var s=""+Math.abs(t),o=t>=0;s.lengths;s++)(i&&t[s]!==e[s]||!i&&L(t[s])!==L(e[s]))&&r++;return r+n}function k(t){if(t){var e=t.toLowerCase().replace(/(.)s$/,"$1");t=gi[t]||vi[e]||e}return t}function N(t){var e,i,s={};for(i in t)a(t,i)&&(e=k(i),e&&(s[e]=t[i]));return s}function I(t){var e,i;if(0===t.indexOf("week"))e=7,i="day";else{if(0!==t.indexOf("month"))return;e=12,i="month"}Ce[t]=function(s,o){var r,a,h=Ce._locale[t],d=[];if("number"==typeof s&&(o=s,s=n),a=function(t){var e=Ce().utc().set(i,t);return h.call(Ce._locale,e,s||"")},null!=o)return a(o);for(r=0;e>r;r++)d.push(a(r));return d}}function L(t){var e=+t,i=0;return 0!==e&&isFinite(e)&&(i=e>=0?Math.floor(e):Math.ceil(e)),i}function z(t,e){return new Date(Date.UTC(t,e+1,0)).getUTCDate()}function P(t,e,i){return me(Ce([t,11,31+e-i]),e,i).week}function A(t){return R(t)?366:365}function R(t){return t%4===0&&t%100!==0||t%400===0}function F(t){var e;t._a&&-2===t._pf.overflow&&(e=t._a[ze]<0||t._a[ze]>11?ze:t._a[Pe]<1||t._a[Pe]>z(t._a[Le],t._a[ze])?Pe:t._a[Ae]<0||t._a[Ae]>24||24===t._a[Ae]&&(0!==t._a[Re]||0!==t._a[Fe]||0!==t._a[He])?Ae:t._a[Re]<0||t._a[Re]>59?Re:t._a[Fe]<0||t._a[Fe]>59?Fe:t._a[He]<0||t._a[He]>999?He:-1,t._pf._overflowDayOfYear&&(Le>e||e>Pe)&&(e=Pe),t._pf.overflow=e)}function H(t){return null==t._isValid&&(t._isValid=!isNaN(t._d.getTime())&&t._pf.overflow<0&&!t._pf.empty&&!t._pf.invalidMonth&&!t._pf.nullInput&&!t._pf.invalidFormat&&!t._pf.userInvalidated,t._strict&&(t._isValid=t._isValid&&0===t._pf.charsLeftOver&&0===t._pf.unusedTokens.length&&t._pf.bigHour===n)),t._isValid}function B(t){return t?t.toLowerCase().replace("_","-"):t}function Y(t){for(var e,i,s,o,n=0;n0;){if(s=W(o.slice(0,e).join("-")))return s;if(i&&i.length>=e&&E(o,i,!0)>=e-1)break;e--}n++}return null}function W(t){var e=null;if(!Be[t]&&We)try{e=Ce.locale(),!function(){var t=new Error('Cannot find module "./locale"');throw t.code="MODULE_NOT_FOUND",t}(),Ce.locale(e)}catch(i){}return Be[t]}function G(t,e){var i,s;return e._isUTC?(i=e.clone(),s=(Ce.isMoment(t)||O(t)?+t:+Ce(t))-+i,i._d.setTime(+i._d+s),Ce.updateOffset(i,!1),i):Ce(t).local()}function j(t){return t.match(/\[[\s\S]/)?t.replace(/^\[|\]$/g,""):t.replace(/\\/g,"")}function U(t){var e,i,s=t.match(Ve);for(e=0,i=s.length;i>e;e++)s[e]=wi[s[e]]?wi[s[e]]:j(s[e]);return function(o){var n="";for(e=0;i>e;e++)n+=s[e]instanceof Function?s[e].call(o,t):s[e];return n}}function V(t,e){return t.isValid()?(e=X(e,t.localeData()),yi[e]||(yi[e]=U(e)),yi[e](t)):t.localeData().invalidDate()}function X(t,e){function i(t){return e.longDateFormat(t)||t}var s=5;for(Xe.lastIndex=0;s>=0&&Xe.test(t);)t=t.replace(Xe,i),Xe.lastIndex=0,s-=1;return t}function q(t,e){var i,s=e._strict;switch(t){case"Q":return oi;case"DDDD":return ri;case"YYYY":case"GGGG":case"gggg":return s?ai:Qe;case"Y":case"G":case"g":return di;case"YYYYYY":case"YYYYY":case"GGGGG":case"ggggg":return s?hi:Ke;case"S":if(s)return oi;case"SS":if(s)return ni;case"SSS":if(s)return ri;case"DDD":return Ze;case"MMM":case"MMMM":case"dd":case"ddd":case"dddd":return Je;case"a":case"A":return e._locale._meridiemParse;case"x":return ii;case"X":return si;case"Z":case"ZZ":return ti;case"T":return ei;case"SSSS":return $e;case"MM":case"DD":case"YY":case"GG":case"gg":case"HH":case"hh":case"mm":case"ss":case"ww":case"WW":return s?ni:qe;case"M":case"D":case"d":case"H":case"h":case"m":case"s":case"w":case"W":case"e":case"E":return qe;case"Do":return 
s?e._locale._ordinalParse:e._locale._ordinalParseLenient;default:return i=new RegExp(se(ie(t.replace("\\","")),"i"))}}function Z(t){t=t||"";var e=t.match(ti)||[],i=e[e.length-1]||[],s=(i+"").match(mi)||["-",0,0],o=+(60*s[1])+L(s[2]);return"+"===s[0]?o:-o}function Q(t,e,i){var s,o=i._a;switch(t){case"Q":null!=e&&(o[ze]=3*(L(e)-1));break;case"M":case"MM":null!=e&&(o[ze]=L(e)-1);break;case"MMM":case"MMMM":s=i._locale.monthsParse(e,t,i._strict),null!=s?o[ze]=s:i._pf.invalidMonth=e;break;case"D":case"DD":null!=e&&(o[Pe]=L(e));break;case"Do":null!=e&&(o[Pe]=L(parseInt(e.match(/\d{1,2}/)[0],10)));break;case"DDD":case"DDDD":null!=e&&(i._dayOfYear=L(e));break;case"YY":o[Le]=Ce.parseTwoDigitYear(e);break;case"YYYY":case"YYYYY":case"YYYYYY":o[Le]=L(e);break;case"a":case"A":i._meridiem=e;break;case"h":case"hh":i._pf.bigHour=!0;case"H":case"HH":o[Ae]=L(e);break;case"m":case"mm":o[Re]=L(e);break;case"s":case"ss":o[Fe]=L(e);break;case"S":case"SS":case"SSS":case"SSSS":o[He]=L(1e3*("0."+e));break;case"x":i._d=new Date(L(e));break;case"X":i._d=new Date(1e3*parseFloat(e));break;case"Z":case"ZZ":i._useUTC=!0,i._tzm=Z(e);break;case"dd":case"ddd":case"dddd":s=i._locale.weekdaysParse(e),null!=s?(i._w=i._w||{},i._w.d=s):i._pf.invalidWeekday=e;break;case"w":case"ww":case"W":case"WW":case"d":case"e":case"E":t=t.substr(0,1);case"gggg":case"GGGG":case"GGGGG":t=t.substr(0,2),e&&(i._w=i._w||{},i._w[t]=L(e));break;case"gg":case"GG":i._w=i._w||{},i._w[t]=Ce.parseTwoDigitYear(e)}}function K(t){var e,i,s,o,n,a,h;e=t._w,null!=e.GG||null!=e.W||null!=e.E?(n=1,a=4,i=r(e.GG,t._a[Le],me(Ce(),1,4).year),s=r(e.W,1),o=r(e.E,1)):(n=t._locale._week.dow,a=t._locale._week.doy,i=r(e.gg,t._a[Le],me(Ce(),n,a).year),s=r(e.w,1),null!=e.d?(o=e.d,n>o&&++s):o=null!=e.e?e.e+n:n),h=fe(i,s,o,a,n),t._a[Le]=h.year,t._dayOfYear=h.dayOfYear}function $(t){var e,i,s,o,n=[];if(!t._d){for(s=te(t),t._w&&null==t._a[Pe]&&null==t._a[ze]&&K(t),t._dayOfYear&&(o=r(t._a[Le],s[Le]),t._dayOfYear>A(o)&&(t._pf._overflowDayOfYear=!0),i=le(o,0,t._dayOfYear),t._a[ze]=i.getUTCMonth(),t._a[Pe]=i.getUTCDate()),e=0;3>e&&null==t._a[e];++e)t._a[e]=n[e]=s[e];for(;7>e;e++)t._a[e]=n[e]=null==t._a[e]?2===e?1:0:t._a[e];24===t._a[Ae]&&0===t._a[Re]&&0===t._a[Fe]&&0===t._a[He]&&(t._nextDay=!0,t._a[Ae]=0),t._d=(t._useUTC?le:de).apply(null,n),null!=t._tzm&&t._d.setUTCMinutes(t._d.getUTCMinutes()-t._tzm),t._nextDay&&(t._a[Ae]=24)}}function J(t){var e;t._d||(e=N(t._i),t._a=[e.year,e.month,e.day||e.date,e.hour,e.minute,e.second,e.millisecond],$(t))}function te(t){var e=new Date;return t._useUTC?[e.getUTCFullYear(),e.getUTCMonth(),e.getUTCDate()]:[e.getFullYear(),e.getMonth(),e.getDate()]}function ee(t){if(t._f===Ce.ISO_8601)return void ne(t);t._a=[],t._pf.empty=!0;var e,i,s,o,r,a=""+t._i,h=a.length,d=0;for(s=X(t._f,t._locale).match(Ve)||[],e=0;e0&&t._pf.unusedInput.push(r),a=a.slice(a.indexOf(i)+i.length),d+=i.length),wi[o]?(i?t._pf.empty=!1:t._pf.unusedTokens.push(o),Q(o,i,t)):t._strict&&!i&&t._pf.unusedTokens.push(o);t._pf.charsLeftOver=h-d,a.length>0&&t._pf.unusedInput.push(a),t._pf.bigHour===!0&&t._a[Ae]<=12&&(t._pf.bigHour=n),t._a[Ae]=f(t._locale,t._a[Ae],t._meridiem),$(t),F(t)}function ie(t){return t.replace(/\\(\[)|\\(\])|\[([^\]\[]*)\]|\\(.)/g,function(t,e,i,s,o){return e||i||s||o})}function se(t){return t.replace(/[-\/\\^$*+?.()|[\]{}]/g,"\\$&")}function oe(t){var e,i,s,o,n;if(0===t._f.length)return t._pf.invalidFormat=!0,void(t._d=new Date(0/0));for(o=0;on)&&(s=n,i=e));b(t,i||e)}function ne(t){var 
e,i,s=t._i,o=li.exec(s);if(o){for(t._pf.iso=!0,e=0,i=pi.length;i>e;e++)if(pi[e][1].exec(s)){t._f=pi[e][0]+(o[6]||" ");break}for(e=0,i=ui.length;i>e;e++)if(ui[e][1].exec(s)){t._f+=ui[e][0];break}s.match(ti)&&(t._f+="Z"),ee(t)}else t._isValid=!1}function re(t){ne(t),t._isValid===!1&&(delete t._isValid,Ce.createFromInputFallback(t))}function ae(t,e){var i,s=[];for(i=0;it&&a.setFullYear(t),a}function le(t){var e=new Date(Date.UTC.apply(null,arguments));return 1970>t&&e.setUTCFullYear(t),e}function ce(t,e){if("string"==typeof t)if(isNaN(t)){if(t=e.weekdaysParse(t),"number"!=typeof t)return null}else t=parseInt(t,10);return t}function pe(t,e,i,s,o){return o.relativeTime(e||1,!!i,t,s)}function ue(t,e,i){var s=Ce.duration(t).abs(),o=Ne(s.as("s")),n=Ne(s.as("m")),r=Ne(s.as("h")),a=Ne(s.as("d")),h=Ne(s.as("M")),d=Ne(s.as("y")),l=o0,l[4]=i,pe.apply({},l)}function me(t,e,i){var s,o=i-e,n=i-t.day();return n>o&&(n-=7),o-7>n&&(n+=7),s=Ce(t).add(n,"d"),{week:Math.ceil(s.dayOfYear()/7),year:s.year()}}function fe(t,e,i,s,o){var n,r,a=le(t,0,1).getUTCDay();return a=0===a?7:a,i=null!=i?i:o,n=o-a+(a>s?7:0)-(o>a?7:0),r=7*(e-1)+(i-o)+n+1,{year:r>0?t:t-1,dayOfYear:r>0?r:A(t-1)+r}}function ge(t){var e,i=t._i,s=t._f;return t._locale=t._locale||Ce.localeData(t._l),null===i||s===n&&""===i?Ce.invalid({nullInput:!0}):("string"==typeof i&&(t._i=i=t._locale.preparse(i)),Ce.isMoment(i)?new v(i,!0):(s?T(s)?oe(t):ee(t):he(t),e=new v(t),e._nextDay&&(e.add(1,"d"),e._nextDay=n),e))}function ve(t,e){var i,s;if(1===e.length&&T(e[0])&&(e=e[0]),!e.length)return Ce();for(i=e[0],s=1;s=0?"+":"-";return e+w(Math.abs(t),6)},gg:function(){return w(this.weekYear()%100,2)},gggg:function(){return w(this.weekYear(),4)},ggggg:function(){return w(this.weekYear(),5)},GG:function(){return w(this.isoWeekYear()%100,2)},GGGG:function(){return w(this.isoWeekYear(),4)},GGGGG:function(){return w(this.isoWeekYear(),5)},e:function(){return this.weekday()},E:function(){return this.isoWeekday()},a:function(){return this.localeData().meridiem(this.hours(),this.minutes(),!0)},A:function(){return this.localeData().meridiem(this.hours(),this.minutes(),!1)},H:function(){return this.hours()},h:function(){return this.hours()%12||12},m:function(){return this.minutes()},s:function(){return this.seconds()},S:function(){return L(this.milliseconds()/100)},SS:function(){return w(L(this.milliseconds()/10),2)},SSS:function(){return w(this.milliseconds(),3)},SSSS:function(){return w(this.milliseconds(),3)},Z:function(){var t=this.utcOffset(),e="+";return 0>t&&(t=-t,e="-"),e+w(L(t/60),2)+":"+w(L(t)%60,2)},ZZ:function(){var t=this.utcOffset(),e="+";return 0>t&&(t=-t,e="-"),e+w(L(t/60),2)+w(L(t)%60,2)},z:function(){return this.zoneAbbr()},zz:function(){return this.zoneName()},x:function(){return this.valueOf()},X:function(){return this.unix()},Q:function(){return this.quarter()}},Si={},Mi=["months","monthsShort","weekdays","weekdaysShort","weekdaysMin"],Di=!1;_i.length;)Oe=_i.pop(),wi[Oe+"o"]=u(wi[Oe],Oe);for(;xi.length;)Oe=xi.pop(),wi[Oe+Oe]=p(wi[Oe],2);wi.DDDD=p(wi.DDD,3),b(g.prototype,{set:function(t){var e,i;for(i in t)e=t[i],"function"==typeof e?this[i]=e:this["_"+i]=e;this._ordinalParseLenient=new RegExp(this._ordinalParse.source+"|"+/\d{1,2}/.source)},_months:"January_February_March_April_May_June_July_August_September_October_November_December".split("_"),months:function(t){return this._months[t.month()]},_monthsShort:"Jan_Feb_Mar_Apr_May_Jun_Jul_Aug_Sep_Oct_Nov_Dec".split("_"),monthsShort:function(t){return 
this._monthsShort[t.month()]},monthsParse:function(t,e,i){var s,o,n;for(this._monthsParse||(this._monthsParse=[],this._longMonthsParse=[],this._shortMonthsParse=[]),s=0;12>s;s++){if(o=Ce.utc([2e3,s]),i&&!this._longMonthsParse[s]&&(this._longMonthsParse[s]=new RegExp("^"+this.months(o,"").replace(".","")+"$","i"),this._shortMonthsParse[s]=new RegExp("^"+this.monthsShort(o,"").replace(".","")+"$","i")),i||this._monthsParse[s]||(n="^"+this.months(o,"")+"|^"+this.monthsShort(o,""),this._monthsParse[s]=new RegExp(n.replace(".",""),"i")),i&&"MMMM"===e&&this._longMonthsParse[s].test(t))return s;if(i&&"MMM"===e&&this._shortMonthsParse[s].test(t))return s;if(!i&&this._monthsParse[s].test(t))return s}},_weekdays:"Sunday_Monday_Tuesday_Wednesday_Thursday_Friday_Saturday".split("_"),weekdays:function(t){return this._weekdays[t.day()]},_weekdaysShort:"Sun_Mon_Tue_Wed_Thu_Fri_Sat".split("_"),weekdaysShort:function(t){return this._weekdaysShort[t.day()]},_weekdaysMin:"Su_Mo_Tu_We_Th_Fr_Sa".split("_"),weekdaysMin:function(t){return this._weekdaysMin[t.day()]},weekdaysParse:function(t){var e,i,s;for(this._weekdaysParse||(this._weekdaysParse=[]),e=0;7>e;e++)if(this._weekdaysParse[e]||(i=Ce([2e3,1]).day(e),s="^"+this.weekdays(i,"")+"|^"+this.weekdaysShort(i,"")+"|^"+this.weekdaysMin(i,""),this._weekdaysParse[e]=new RegExp(s.replace(".",""),"i")),this._weekdaysParse[e].test(t))return e},_longDateFormat:{LTS:"h:mm:ss A",LT:"h:mm A",L:"MM/DD/YYYY",LL:"MMMM D, YYYY",LLL:"MMMM D, YYYY LT",LLLL:"dddd, MMMM D, YYYY LT"},longDateFormat:function(t){var e=this._longDateFormat[t]; -return!e&&this._longDateFormat[t.toUpperCase()]&&(e=this._longDateFormat[t.toUpperCase()].replace(/MMMM|MM|DD|dddd/g,function(t){return t.slice(1)}),this._longDateFormat[t]=e),e},isPM:function(t){return"p"===(t+"").toLowerCase().charAt(0)},_meridiemParse:/[ap]\.?m?\.?/i,meridiem:function(t,e,i){return t>11?i?"pm":"PM":i?"am":"AM"},_calendar:{sameDay:"[Today at] LT",nextDay:"[Tomorrow at] LT",nextWeek:"dddd [at] LT",lastDay:"[Yesterday at] LT",lastWeek:"[Last] dddd [at] LT",sameElse:"L"},calendar:function(t,e,i){var s=this._calendar[t];return"function"==typeof s?s.apply(e,[i]):s},_relativeTime:{future:"in %s",past:"%s ago",s:"a few seconds",m:"a minute",mm:"%d minutes",h:"an hour",hh:"%d hours",d:"a day",dd:"%d days",M:"a month",MM:"%d months",y:"a year",yy:"%d years"},relativeTime:function(t,e,i,s){var o=this._relativeTime[i];return"function"==typeof o?o(t,e,i,s):o.replace(/%d/i,t)},pastFuture:function(t,e){var i=this._relativeTime[t>0?"future":"past"];return"function"==typeof i?i(e):i.replace(/%s/i,e)},ordinal:function(t){return this._ordinal.replace("%d",t)},_ordinal:"%d",_ordinalParse:/\d{1,2}/,preparse:function(t){return t},postformat:function(t){return t},week:function(t){return me(t,this._week.dow,this._week.doy).week},_week:{dow:0,doy:6},firstDayOfWeek:function(){return this._week.dow},firstDayOfYear:function(){return this._week.doy},_invalidDate:"Invalid date",invalidDate:function(){return this._invalidDate}}),Ce=function(t,e,i,s){var o;return"boolean"==typeof i&&(s=i,i=n),o={},o._isAMomentObject=!0,o._i=t,o._f=e,o._l=i,o._strict=s,o._isUTC=!1,o._pf=h(),ge(o)},Ce.suppressDeprecationWarnings=!1,Ce.createFromInputFallback=l("moment construction falls back to js Date. This is discouraged and will be removed in upcoming major release. Please refer to https://github.com/moment/moment/issues/1407 for more info.",function(t){t._d=new Date(t._i+(t._useUTC?" 
UTC":""))}),Ce.min=function(){var t=[].slice.call(arguments,0);return ve("isBefore",t)},Ce.max=function(){var t=[].slice.call(arguments,0);return ve("isAfter",t)},Ce.utc=function(t,e,i,s){var o;return"boolean"==typeof i&&(s=i,i=n),o={},o._isAMomentObject=!0,o._useUTC=!0,o._isUTC=!0,o._l=i,o._i=t,o._f=e,o._strict=s,o._pf=h(),ge(o).utc()},Ce.unix=function(t){return Ce(1e3*t)},Ce.duration=function(t,e){var i,s,o,n,r=t,h=null;return Ce.isDuration(t)?r={ms:t._milliseconds,d:t._days,M:t._months}:"number"==typeof t?(r={},e?r[e]=t:r.milliseconds=t):(h=je.exec(t))?(i="-"===h[1]?-1:1,r={y:0,d:L(h[Pe])*i,h:L(h[Ae])*i,m:L(h[Re])*i,s:L(h[Fe])*i,ms:L(h[He])*i}):(h=Ue.exec(t))?(i="-"===h[1]?-1:1,o=function(t){var e=t&&parseFloat(t.replace(",","."));return(isNaN(e)?0:e)*i},r={y:o(h[2]),M:o(h[3]),d:o(h[4]),h:o(h[5]),m:o(h[6]),s:o(h[7]),w:o(h[8])}):null==r?r={}:"object"==typeof r&&("from"in r||"to"in r)&&(n=M(Ce(r.from),Ce(r.to)),r={},r.ms=n.milliseconds,r.M=n.months),s=new y(r),Ce.isDuration(t)&&a(t,"_locale")&&(s._locale=t._locale),s},Ce.version=Ee,Ce.defaultFormat=ci,Ce.ISO_8601=function(){},Ce.momentProperties=Ye,Ce.updateOffset=function(){},Ce.relativeTimeThreshold=function(t,e){return bi[t]===n?!1:e===n?bi[t]:(bi[t]=e,!0)},Ce.lang=l("moment.lang is deprecated. Use moment.locale instead.",function(t,e){return Ce.locale(t,e)}),Ce.locale=function(t,e){var i;return t&&(i="undefined"!=typeof e?Ce.defineLocale(t,e):Ce.localeData(t),i&&(Ce.duration._locale=Ce._locale=i)),Ce._locale._abbr},Ce.defineLocale=function(t,e){return null!==e?(e.abbr=t,Be[t]||(Be[t]=new g),Be[t].set(e),Ce.locale(t),Be[t]):(delete Be[t],null)},Ce.langData=l("moment.langData is deprecated. Use moment.localeData instead.",function(t){return Ce.localeData(t)}),Ce.localeData=function(t){var e;if(t&&t._locale&&t._locale._abbr&&(t=t._locale._abbr),!t)return Ce._locale;if(!T(t)){if(e=W(t))return e;t=[t]}return Y(t)},Ce.isMoment=function(t){return t instanceof v||null!=t&&a(t,"_isAMomentObject")},Ce.isDuration=function(t){return t instanceof y};for(Oe=Mi.length-1;Oe>=0;--Oe)I(Mi[Oe]);Ce.normalizeUnits=function(t){return k(t)},Ce.invalid=function(t){var e=Ce.utc(0/0);return null!=t?b(e._pf,t):e._pf.userInvalidated=!0,e},Ce.parseZone=function(){return Ce.apply(null,arguments).parseZone()},Ce.parseTwoDigitYear=function(t){return L(t)+(L(t)>68?1900:2e3)},Ce.isDate=O,b(Ce.fn=v.prototype,{clone:function(){return Ce(this)},valueOf:function(){return+this._d-6e4*(this._offset||0)},unix:function(){return Math.floor(+this/1e3)},toString:function(){return this.clone().locale("en").format("ddd MMM DD YYYY HH:mm:ss [GMT]ZZ")},toDate:function(){return this._offset?new Date(+this):this._d},toISOString:function(){var t=Ce(this).utc();return 00:!1},parsingFlags:function(){return b({},this._pf)},invalidAt:function(){return this._pf.overflow},utc:function(t){return this.utcOffset(0,t)},local:function(t){return this._isUTC&&(this.utcOffset(0,t),this._isUTC=!1,t&&this.subtract(this._dateUtcOffset(),"m")),this},format:function(t){var e=V(this,t||Ce.defaultFormat);return this.localeData().postformat(e)},add:D(1,"add"),subtract:D(-1,"subtract"),diff:function(t,e,i){var s,o,n=G(t,this),r=6e4*(n.utcOffset()-this.utcOffset());return e=k(e),"year"===e||"month"===e||"quarter"===e?(o=m(this,n),"quarter"===e?o/=3:"year"===e&&(o/=12)):(s=this-n,o="second"===e?s/1e3:"minute"===e?s/6e4:"hour"===e?s/36e5:"day"===e?(s-r)/864e5:"week"===e?(s-r)/6048e5:s),i?o:x(o)},from:function(t,e){return 
Ce.duration({to:this,from:t}).locale(this.locale()).humanize(!e)},fromNow:function(t){return this.from(Ce(),t)},calendar:function(t){var e=t||Ce(),i=G(e,this).startOf("day"),s=this.diff(i,"days",!0),o=-6>s?"sameElse":-1>s?"lastWeek":0>s?"lastDay":1>s?"sameDay":2>s?"nextDay":7>s?"nextWeek":"sameElse";return this.format(this.localeData().calendar(o,this,Ce(e)))},isLeapYear:function(){return R(this.year())},isDST:function(){return this.utcOffset()>this.clone().month(0).utcOffset()||this.utcOffset()>this.clone().month(5).utcOffset()},day:function(t){var e=this._isUTC?this._d.getUTCDay():this._d.getDay();return null!=t?(t=ce(t,this.localeData()),this.add(t-e,"d")):e},month:xe("Month",!0),startOf:function(t){switch(t=k(t)){case"year":this.month(0);case"quarter":case"month":this.date(1);case"week":case"isoWeek":case"day":this.hours(0);case"hour":this.minutes(0);case"minute":this.seconds(0);case"second":this.milliseconds(0)}return"week"===t?this.weekday(0):"isoWeek"===t&&this.isoWeekday(1),"quarter"===t&&this.month(3*Math.floor(this.month()/3)),this},endOf:function(t){return t=k(t),t===n||"millisecond"===t?this:this.startOf(t).add(1,"isoWeek"===t?"week":t).subtract(1,"ms")},isAfter:function(t,e){var i;return e=k("undefined"!=typeof e?e:"millisecond"),"millisecond"===e?(t=Ce.isMoment(t)?t:Ce(t),+this>+t):(i=Ce.isMoment(t)?+t:+Ce(t),i<+this.clone().startOf(e))},isBefore:function(t,e){var i;return e=k("undefined"!=typeof e?e:"millisecond"),"millisecond"===e?(t=Ce.isMoment(t)?t:Ce(t),+t>+this):(i=Ce.isMoment(t)?+t:+Ce(t),+this.clone().endOf(e)t?this:t}),max:l("moment().max is deprecated, use moment.max instead. https://github.com/moment/moment/issues/1548",function(t){return t=Ce.apply(null,arguments),t>this?this:t}),zone:l("moment().zone is deprecated, use moment().utcOffset instead. 
https://github.com/moment/moment/issues/1779",function(t,e){return null!=t?("string"!=typeof t&&(t=-t),this.utcOffset(t,e),this):-this.utcOffset()}),utcOffset:function(t,e){var i,s=this._offset||0;return null!=t?("string"==typeof t&&(t=Z(t)),Math.abs(t)<16&&(t=60*t),!this._isUTC&&e&&(i=this._dateUtcOffset()),this._offset=t,this._isUTC=!0,null!=i&&this.add(i,"m"),s!==t&&(!e||this._changeInProgress?C(this,Ce.duration(t-s,"m"),1,!1):this._changeInProgress||(this._changeInProgress=!0,Ce.updateOffset(this,!0),this._changeInProgress=null)),this):this._isUTC?s:this._dateUtcOffset()},isLocal:function(){return!this._isUTC},isUtcOffset:function(){return this._isUTC},isUtc:function(){return this._isUTC&&0===this._offset},zoneAbbr:function(){return this._isUTC?"UTC":""},zoneName:function(){return this._isUTC?"Coordinated Universal Time":""},parseZone:function(){return this._tzm?this.utcOffset(this._tzm):"string"==typeof this._i&&this.utcOffset(Z(this._i)),this},hasAlignedHourOffset:function(t){return t=t?Ce(t).utcOffset():0,(this.utcOffset()-t)%60===0},daysInMonth:function(){return z(this.year(),this.month())},dayOfYear:function(t){var e=Ne((Ce(this).startOf("day")-Ce(this).startOf("year"))/864e5)+1;return null==t?e:this.add(t-e,"d")},quarter:function(t){return null==t?Math.ceil((this.month()+1)/3):this.month(3*(t-1)+this.month()%3)},weekYear:function(t){var e=me(this,this.localeData()._week.dow,this.localeData()._week.doy).year;return null==t?e:this.add(t-e,"y")},isoWeekYear:function(t){var e=me(this,1,4).year;return null==t?e:this.add(t-e,"y")},week:function(t){var e=this.localeData().week(this);return null==t?e:this.add(7*(t-e),"d")},isoWeek:function(t){var e=me(this,1,4).week;return null==t?e:this.add(7*(t-e),"d")},weekday:function(t){var e=(this.day()+7-this.localeData()._week.dow)%7;return null==t?e:this.add(t-e,"d")},isoWeekday:function(t){return null==t?this.day()||7:this.day(this.day()%7?t:t-7)},isoWeeksInYear:function(){return P(this.year(),1,4)},weeksInYear:function(){var t=this.localeData()._week;return P(this.year(),t.dow,t.doy)},get:function(t){return t=k(t),this[t]()},set:function(t,e){var i;if("object"==typeof t)for(i in t)this.set(i,t[i]);else t=k(t),"function"==typeof this[t]&&this[t](e);return this},locale:function(t){var e;return t===n?this._locale._abbr:(e=Ce.localeData(t),null!=e&&(this._locale=e),this)},lang:l("moment().lang() is deprecated. Instead, use moment().localeData() to get the language configuration. Use moment().locale() to change languages.",function(t){return t===n?this.localeData():this.locale(t)}),localeData:function(){return this._locale},_dateUtcOffset:function(){return 15*-Math.round(this._d.getTimezoneOffset()/15)}}),Ce.fn.millisecond=Ce.fn.milliseconds=xe("Milliseconds",!1),Ce.fn.second=Ce.fn.seconds=xe("Seconds",!1),Ce.fn.minute=Ce.fn.minutes=xe("Minutes",!1),Ce.fn.hour=Ce.fn.hours=xe("Hours",!0),Ce.fn.date=xe("Date",!0),Ce.fn.dates=l("dates accessor is deprecated. Use date instead.",xe("Date",!0)),Ce.fn.year=xe("FullYear",!0),Ce.fn.years=l("years accessor is deprecated. 
Use year instead.",xe("FullYear",!0)),Ce.fn.days=Ce.fn.day,Ce.fn.months=Ce.fn.month,Ce.fn.weeks=Ce.fn.week,Ce.fn.isoWeeks=Ce.fn.isoWeek,Ce.fn.quarters=Ce.fn.quarter,Ce.fn.toJSON=Ce.fn.toISOString,Ce.fn.isUTC=Ce.fn.isUtc,b(Ce.duration.fn=y.prototype,{_bubble:function(){var t,e,i,s=this._milliseconds,o=this._days,n=this._months,r=this._data,a=0;r.milliseconds=s%1e3,t=x(s/1e3),r.seconds=t%60,e=x(t/60),r.minutes=e%60,i=x(e/60),r.hours=i%24,o+=x(i/24),a=x(we(o)),o-=x(Se(a)),n+=x(o/30),o%=30,a+=x(n/12),n%=12,r.days=o,r.months=n,r.years=a},abs:function(){return this._milliseconds=Math.abs(this._milliseconds),this._days=Math.abs(this._days),this._months=Math.abs(this._months),this._data.milliseconds=Math.abs(this._data.milliseconds),this._data.seconds=Math.abs(this._data.seconds),this._data.minutes=Math.abs(this._data.minutes),this._data.hours=Math.abs(this._data.hours),this._data.months=Math.abs(this._data.months),this._data.years=Math.abs(this._data.years),this},weeks:function(){return x(this.days()/7)},valueOf:function(){return this._milliseconds+864e5*this._days+this._months%12*2592e6+31536e6*L(this._months/12)},humanize:function(t){var e=ue(this,!t,this.localeData());return t&&(e=this.localeData().pastFuture(+this,e)),this.localeData().postformat(e)},add:function(t,e){var i=Ce.duration(t,e);return this._milliseconds+=i._milliseconds,this._days+=i._days,this._months+=i._months,this._bubble(),this},subtract:function(t,e){var i=Ce.duration(t,e);return this._milliseconds-=i._milliseconds,this._days-=i._days,this._months-=i._months,this._bubble(),this},get:function(t){return t=k(t),this[t.toLowerCase()+"s"]()},as:function(t){var e,i;if(t=k(t),"month"===t||"year"===t)return e=this._days+this._milliseconds/864e5,i=this._months+12*we(e),"month"===t?i:i/12;switch(e=this._days+Math.round(Se(this._months/12)),t){case"week":return e/7+this._milliseconds/6048e5;case"day":return e+this._milliseconds/864e5;case"hour":return 24*e+this._milliseconds/36e5;case"minute":return 24*e*60+this._milliseconds/6e4;case"second":return 24*e*60*60+this._milliseconds/1e3;case"millisecond":return Math.floor(24*e*60*60*1e3)+this._milliseconds;default:throw new Error("Unknown unit "+t)}},lang:Ce.fn.lang,locale:Ce.fn.locale,toIsoString:l("toIsoString() is deprecated. 
Please use toISOString() instead (notice the capitals)",function(){return this.toISOString()}),toISOString:function(){var t=Math.abs(this.years()),e=Math.abs(this.months()),i=Math.abs(this.days()),s=Math.abs(this.hours()),o=Math.abs(this.minutes()),n=Math.abs(this.seconds()+this.milliseconds()/1e3);return this.asSeconds()?(this.asSeconds()<0?"-":"")+"P"+(t?t+"Y":"")+(e?e+"M":"")+(i?i+"D":"")+(s||o||n?"T":"")+(s?s+"H":"")+(o?o+"M":"")+(n?n+"S":""):"P0D"},localeData:function(){return this._locale},toJSON:function(){return this.toISOString()}}),Ce.duration.fn.toString=Ce.duration.fn.toISOString;for(Oe in fi)a(fi,Oe)&&Me(Oe.toLowerCase());Ce.duration.fn.asMilliseconds=function(){return this.as("ms")},Ce.duration.fn.asSeconds=function(){return this.as("s")},Ce.duration.fn.asMinutes=function(){return this.as("m")},Ce.duration.fn.asHours=function(){return this.as("h")},Ce.duration.fn.asDays=function(){return this.as("d")},Ce.duration.fn.asWeeks=function(){return this.as("weeks")},Ce.duration.fn.asMonths=function(){return this.as("M")},Ce.duration.fn.asYears=function(){return this.as("y")},Ce.locale("en",{ordinalParse:/\d{1,2}(th|st|nd|rd)/,ordinal:function(t){var e=t%10,i=1===L(t%100/10)?"th":1===e?"st":2===e?"nd":3===e?"rd":"th";return t+i}}),We?o.exports=Ce:(s=function(t,e,i){return i.config&&i.config()&&i.config().noGlobal===!0&&(ke.moment=Te),Ce}.call(e,i,e,o),!(s!==n&&(o.exports=s)),De(!0))}).call(this)}).call(e,function(){return this}(),i(72)(t))},function(t,e){var i,s,o;!function(n,r){s=[],i=r,o="function"==typeof i?i.apply(e,s):i,!(void 0!==o&&(t.exports=o))}(this,function(){function t(t){var e,i=t&&t.preventDefault||!1,s=t&&t.container||window,o={},n={keydown:{},keyup:{}},r={};for(e=97;122>=e;e++)r[String.fromCharCode(e)]={code:65+(e-97),shift:!1};for(e=65;90>=e;e++)r[String.fromCharCode(e)]={code:e,shift:!0};for(e=0;9>=e;e++)r[""+e]={code:48+e,shift:!1};for(e=1;12>=e;e++)r["F"+e]={code:111+e,shift:!1};for(e=0;9>=e;e++)r["num"+e]={code:96+e,shift:!1};r["num*"]={code:106,shift:!1},r["num+"]={code:107,shift:!1},r["num-"]={code:109,shift:!1},r["num/"]={code:111,shift:!1},r["num."]={code:110,shift:!1},r.left={code:37,shift:!1},r.up={code:38,shift:!1},r.right={code:39,shift:!1},r.down={code:40,shift:!1},r.space={code:32,shift:!1},r.enter={code:13,shift:!1},r.shift={code:16,shift:void 0},r.esc={code:27,shift:!1},r.backspace={code:8,shift:!1},r.tab={code:9,shift:!1},r.ctrl={code:17,shift:!1},r.alt={code:18,shift:!1},r["delete"]={code:46,shift:!1},r.pageup={code:33,shift:!1},r.pagedown={code:34,shift:!1},r["="]={code:187,shift:!1},r["-"]={code:189,shift:!1},r["]"]={code:221,shift:!1},r["["]={code:219,shift:!1};var a=function(t){d(t,"keydown")},h=function(t){d(t,"keyup")},d=function(t,e){if(void 0!==n[e][t.keyCode]){for(var s=n[e][t.keyCode],o=0;o0?i._handlers[t]=s:(i._off(t,o),delete i._handlers[t]))}),i},i.destroy=function(){var t=i.element;delete t.hammer,i._handlers={},i._destroy()},i}})},function(t,e,i){var s;!function(o,n,r,a){function h(t,e,i){return setTimeout(m(t,i),e)}function d(t,e,i){return Array.isArray(t)?(l(t,i[e],i),!0):!1}function l(t,e,i){var s;if(t)if(t.forEach)t.forEach(e,i);else if(t.length!==a)for(s=0;s-1}function x(t){return t.trim().split(/\s+/g)}function w(t,e,i){if(t.indexOf&&!i)return t.indexOf(e);for(var s=0;si[e]}):s.sort()),s}function D(t,e){for(var i,s,o=e[0].toUpperCase()+e.slice(1),n=0;n1&&!i.firstMultiple?i.firstMultiple=z(e):1===o&&(i.firstMultiple=!1);var 
n=i.firstInput,r=i.firstMultiple,a=r?r.center:n.center,h=e.center=P(s);e.timeStamp=ve(),e.deltaTime=e.timeStamp-n.timeStamp,e.angle=H(a,h),e.distance=F(a,h),I(i,e),e.offsetDirection=R(e.deltaX,e.deltaY),e.scale=r?Y(r.pointers,s):1,e.rotation=r?B(r.pointers,s):0,L(i,e);var d=t.element;b(e.srcEvent.target,d)&&(d=e.srcEvent.target),e.target=d}function I(t,e){var i=e.center,s=t.offsetDelta||{},o=t.prevDelta||{},n=t.prevInput||{};(e.eventType===Oe||n.eventType===ke)&&(o=t.prevDelta={x:n.deltaX||0,y:n.deltaY||0},s=t.offsetDelta={x:i.x,y:i.y}),e.deltaX=o.x+(i.x-s.x),e.deltaY=o.y+(i.y-s.y)}function L(t,e){var i,s,o,n,r=t.lastInterval||e,h=e.timeStamp-r.timeStamp;if(e.eventType!=Ne&&(h>Te||r.velocity===a)){var d=r.deltaX-e.deltaX,l=r.deltaY-e.deltaY,c=A(h,d,l);s=c.x,o=c.y,i=ge(c.x)>ge(c.y)?c.x:c.y,n=R(d,l),t.lastInterval=e}else i=r.velocity,s=r.velocityX,o=r.velocityY,n=r.direction;e.velocity=i,e.velocityX=s,e.velocityY=o,e.direction=n}function z(t){for(var e=[],i=0;io;)i+=t[o].clientX,s+=t[o].clientY,o++;return{x:fe(i/e),y:fe(s/e)}}function A(t,e,i){return{x:e/t||0,y:i/t||0}}function R(t,e){return t===e?Ie:ge(t)>=ge(e)?t>0?Le:ze:e>0?Pe:Ae}function F(t,e,i){i||(i=Be);var s=e[i[0]]-t[i[0]],o=e[i[1]]-t[i[1]];return Math.sqrt(s*s+o*o)}function H(t,e,i){i||(i=Be);var s=e[i[0]]-t[i[0]],o=e[i[1]]-t[i[1]];return 180*Math.atan2(o,s)/Math.PI}function B(t,e){return H(e[1],e[0],Ye)-H(t[1],t[0],Ye)}function Y(t,e){return F(e[0],e[1],Ye)/F(t[0],t[1],Ye)}function W(){this.evEl=Ge,this.evWin=je,this.allow=!0,this.pressed=!1,O.apply(this,arguments)}function G(){this.evEl=Xe,this.evWin=qe,O.apply(this,arguments),this.store=this.manager.session.pointerEvents=[]}function j(){this.evTarget=Qe,this.evWin=Ke,this.started=!1,O.apply(this,arguments)}function U(t,e){var i=S(t.touches),s=S(t.changedTouches);return e&(ke|Ne)&&(i=M(i.concat(s),"identifier",!0)),[i,s]}function V(){this.evTarget=Je,this.targetIds={},O.apply(this,arguments)}function X(t,e){var i=S(t.touches),s=this.targetIds;if(e&(Oe|Ee)&&1===i.length)return s[i[0].identifier]=!0,[i,i];var o,n,r=S(t.changedTouches),a=[],h=this.target;if(n=i.filter(function(t){return b(t.target,h)}),e===Oe)for(o=0;oa&&(e.push(t),a=e.length-1):o&(ke|Ne)&&(i=!0),0>a||(e[a]=t,this.callback(this.manager,o,{pointers:e,changedPointers:[t],pointerType:n,srcEvent:t}),i&&e.splice(a,1))}});var Ze={touchstart:Oe,touchmove:Ee,touchend:ke,touchcancel:Ne},Qe="touchstart",Ke="touchstart touchmove touchend touchcancel";u(j,O,{handler:function(t){var e=Ze[t.type];if(e===Oe&&(this.started=!0),this.started){var i=U.call(this,t,e);e&(ke|Ne)&&i[0].length-i[1].length===0&&(this.started=!1),this.callback(this.manager,e,{pointers:i[0],changedPointers:i[1],pointerType:Se,srcEvent:t})}}});var $e={touchstart:Oe,touchmove:Ee,touchend:ke,touchcancel:Ne},Je="touchstart touchmove touchend touchcancel";u(V,O,{handler:function(t){var e=$e[t.type],i=X.call(this,t,e);i&&this.callback(this.manager,e,{pointers:i[0],changedPointers:i[1],pointerType:Se,srcEvent:t})}}),u(q,O,{handler:function(t,e,i){var s=i.pointerType==Se,o=i.pointerType==De;if(s)this.mouse.allow=!1;else if(o&&!this.mouse.allow)return;e&(ke|Ne)&&(this.mouse.allow=!0),this.callback(t,e,i)},destroy:function(){this.touch.destroy(),this.mouse.destroy()}});var 
ti=D(ue.style,"touchAction"),ei=ti!==a,ii="compute",si="auto",oi="manipulation",ni="none",ri="pan-x",ai="pan-y";Z.prototype={set:function(t){t==ii&&(t=this.compute()),ei&&(this.manager.element.style[ti]=t),this.actions=t.toLowerCase().trim()},update:function(){this.set(this.manager.options.touchAction)},compute:function(){var t=[];return l(this.manager.recognizers,function(e){f(e.options.enable,[e])&&(t=t.concat(e.getTouchAction()))}),Q(t.join(" "))},preventDefaults:function(t){if(!ei){var e=t.srcEvent,i=t.offsetDirection;if(this.manager.session.prevented)return void e.preventDefault();var s=this.actions,o=_(s,ni),n=_(s,ai),r=_(s,ri);return o||n&&i&Re||r&&i&Fe?this.preventSrc(e):void 0}},preventSrc:function(t){this.manager.session.prevented=!0,t.preventDefault()}};var hi=1,di=2,li=4,ci=8,pi=ci,ui=16,mi=32;K.prototype={defaults:{},set:function(t){return c(this.options,t),this.manager&&this.manager.touchAction.update(),this},recognizeWith:function(t){if(d(t,"recognizeWith",this))return this;var e=this.simultaneous;return t=te(t,this),e[t.id]||(e[t.id]=t,t.recognizeWith(this)),this},dropRecognizeWith:function(t){return d(t,"dropRecognizeWith",this)?this:(t=te(t,this),delete this.simultaneous[t.id],this)},requireFailure:function(t){if(d(t,"requireFailure",this))return this;var e=this.requireFail;return t=te(t,this),-1===w(e,t)&&(e.push(t),t.requireFailure(this)),this},dropRequireFailure:function(t){if(d(t,"dropRequireFailure",this))return this;t=te(t,this);var e=w(this.requireFail,t);return e>-1&&this.requireFail.splice(e,1),this},hasRequireFailures:function(){return this.requireFail.length>0},canRecognizeWith:function(t){return!!this.simultaneous[t.id]},emit:function(t){function e(e){i.manager.emit(i.options.event+(e?$(s):""),t)}var i=this,s=this.state;ci>s&&e(!0),e(),s>=ci&&e(!0)},tryEmit:function(t){return this.canEmit()?this.emit(t):void(this.state=mi)},canEmit:function(){for(var t=0;tn?Le:ze,i=n!=this.pX,s=Math.abs(t.deltaX)):(o=0===r?Ie:0>r?Pe:Ae,i=r!=this.pY,s=Math.abs(t.deltaY))),t.direction=o,i&&s>e.threshold&&o&e.direction},attrTest:function(t){return ee.prototype.attrTest.call(this,t)&&(this.state&di||!(this.state&di)&&this.directionTest(t))},emit:function(t){this.pX=t.deltaX,this.pY=t.deltaY;var e=J(t.direction);e&&this.manager.emit(this.options.event+e,t),this._super.emit.call(this,t)}}),u(se,ee,{defaults:{event:"pinch",threshold:0,pointers:2},getTouchAction:function(){return[ni]},attrTest:function(t){return this._super.attrTest.call(this,t)&&(Math.abs(t.scale-1)>this.options.threshold||this.state&di)},emit:function(t){if(this._super.emit.call(this,t),1!==t.scale){var e=t.scale<1?"in":"out";this.manager.emit(this.options.event+e,t)}}}),u(oe,K,{defaults:{event:"press",pointers:1,time:500,threshold:5},getTouchAction:function(){return[si]},process:function(t){var e=this.options,i=t.pointers.length===e.pointers,s=t.distancee.time;if(this._input=t,!s||!i||t.eventType&(ke|Ne)&&!o)this.reset();else if(t.eventType&Oe)this.reset(),this._timer=h(function(){this.state=pi,this.tryEmit() -},e.time,this);else if(t.eventType&ke)return pi;return mi},reset:function(){clearTimeout(this._timer)},emit:function(t){this.state===pi&&(t&&t.eventType&ke?this.manager.emit(this.options.event+"up",t):(this._input.timeStamp=ve(),this.manager.emit(this.options.event,this._input)))}}),u(ne,ee,{defaults:{event:"rotate",threshold:0,pointers:2},getTouchAction:function(){return[ni]},attrTest:function(t){return 
this._super.attrTest.call(this,t)&&(Math.abs(t.rotation)>this.options.threshold||this.state&di)}}),u(re,ee,{defaults:{event:"swipe",threshold:10,velocity:.65,direction:Re|Fe,pointers:1},getTouchAction:function(){return ie.prototype.getTouchAction.call(this)},attrTest:function(t){var e,i=this.options.direction;return i&(Re|Fe)?e=t.velocity:i&Re?e=t.velocityX:i&Fe&&(e=t.velocityY),this._super.attrTest.call(this,t)&&i&t.direction&&t.distance>this.options.threshold&&ge(e)>this.options.velocity&&t.eventType&ke},emit:function(t){var e=J(t.direction);e&&this.manager.emit(this.options.event+e,t),this.manager.emit(this.options.event,t)}}),u(ae,K,{defaults:{event:"tap",pointers:1,taps:1,interval:300,time:250,threshold:2,posThreshold:10},getTouchAction:function(){return[oi]},process:function(t){var e=this.options,i=t.pointers.length===e.pointers,s=t.distancet&&s>o;)o%3==0?(this.forceAggregateHubs(!0),this.normalizeClusterLevels()):this.increaseClusterLevel(),i=this.nodeIndices.length,o+=1;o>0&&1==e&&this.repositionNodes(),this._updateCalculationNodes()},e.openCluster=function(t){var e=this.moving;if(t.clusterSize>this.constants.clustering.sectorThreshold&&this._nodeInActiveArea(t)&&("default"!=this._sector()||1!=this.nodeIndices.length)){this._addSector(t);for(var i=0;this.nodeIndices.lengthi;)this.decreaseClusterLevel(),i+=1}else this._expandClusterNode(t,!1,!0),this._updateNodeIndexList(),this._updateDynamicEdges(),this._updateCalculationNodes(),this.updateLabels();this.moving!=e&&this.start()},e.updateClustersDefault=function(){1==this.constants.clustering.enabled&&this.updateClusters(0,!1,!1)},e.increaseClusterLevel=function(){this.updateClusters(-1,!1,!0)},e.decreaseClusterLevel=function(){this.updateClusters(1,!1,!0)},e.updateClusters=function(t,e,i,s){var o=this.moving,n=this.nodeIndices.length;this.previousScale>this.scale&&0==t&&this._collapseSector(),this.previousScale>this.scale||-1==t?this._formClusters(i):(this.previousScalethis.scale||-1==t)&&(this._aggregateHubs(i),this._updateNodeIndexList()),(this.previousScale>this.scale||-1==t)&&(this.handleChains(),this._updateNodeIndexList()),this.previousScale=this.scale,this._updateDynamicEdges(),this.updateLabels(),this.nodeIndices.lengththis.constants.clustering.chainThreshold&&this._reduceAmountOfChains(1-this.constants.clustering.chainThreshold/t)},e._aggregateHubs=function(t){this._getHubSize(),this._formClustersByHub(t,!1)},e.forceAggregateHubs=function(t){var e=this.moving,i=this.nodeIndices.length;this._aggregateHubs(!0),this._updateNodeIndexList(),this._updateDynamicEdges(),this.updateLabels(),this.nodeIndices.length!=i&&(this.clusterSession+=1),(0==t||void 0===t)&&this.moving!=e&&this.start()},e._openClustersBySize=function(){for(var t in this.nodes)if(this.nodes.hasOwnProperty(t)){var e=this.nodes[t];1==e.inView()&&(e.width*this.scale>this.constants.clustering.screenSizeThreshold*this.frame.canvas.clientWidth||e.height*this.scale>this.constants.clustering.screenSizeThreshold*this.frame.canvas.clientHeight)&&this.openCluster(e)}},e._openClusters=function(t,e){for(var i=0;i1&&(t.clusterSizei)){var r=n.from,a=n.to;n.to.options.mass>n.from.options.mass&&(r=n.to,a=n.from),1==a.dynamicEdgesLength?this._addToCluster(r,a,!1):1==r.dynamicEdgesLength&&this._addToCluster(a,r,!1)}}},e._forceClustersByZoom=function(){for(var t in this.nodes)if(this.nodes.hasOwnProperty(t)){var e=this.nodes[t];if(1==e.dynamicEdgesLength&&0!=e.dynamicEdges.length){var 
i=e.dynamicEdges[0],s=i.toId==e.id?this.nodes[i.fromId]:this.nodes[i.toId];e.id!=s.id&&(s.options.mass>e.options.mass?this._addToCluster(s,e,!0):this._addToCluster(e,s,!0))}}},e._clusterToSmallestNeighbour=function(t){for(var e=-1,i=null,s=0;so.clusterSessions.length&&(e=o.clusterSessions.length,i=o)}null!=o&&void 0!==this.nodes[o.id]&&this._addToCluster(o,t,!0)},e._formClustersByHub=function(t,e){for(var i in this.nodes)this.nodes.hasOwnProperty(i)&&this._formClusterFromHub(this.nodes[i],t,e)},e._formClusterFromHub=function(t,e,i,s){if(void 0===s&&(s=0),t.dynamicEdgesLength>=this.hubThreshold&&0==i||t.dynamicEdgesLength==this.hubThreshold&&1==i){for(var o,n,r,a=this.constants.clustering.clusterEdgeThreshold/this.scale,h=!1,d=[],l=t.dynamicEdges.length,c=0;l>c;c++)d.push(t.dynamicEdges[c].id);if(0==e)for(h=!1,c=0;l>c;c++){var p=this.edges[d[c]];if(void 0!==p&&p.connected&&p.toId!=p.fromId&&(o=p.to.x-p.from.x,n=p.to.y-p.from.y,r=Math.sqrt(o*o+n*n),a>r)){h=!0;break}}if(!e&&h||e)for(c=0;l>c;c++)if(p=this.edges[d[c]],void 0!==p){var u=this.nodes[p.fromId==t.id?p.toId:p.fromId];u.dynamicEdges.length<=this.hubThreshold+s&&u.id!=t.id&&this._addToCluster(t,u,e)}}},e._addToCluster=function(t,e,i){t.containedNodes[e.id]=e;for(var s=0;s1)for(var s=0;s1&&(e.label="[".concat(String(e.clusterSize),"]"))}for(t in this.nodes)this.nodes.hasOwnProperty(t)&&(e=this.nodes[t],1==e.clusterSize&&(e.label=void 0!==e.originalLabel?e.originalLabel:String(e.id)))},e.normalizeClusterLevels=function(){var t,e=0,i=1e9,s=0;for(t in this.nodes)this.nodes.hasOwnProperty(t)&&(s=this.nodes[t].clusterSessions.length,s>e&&(e=s),i>s&&(i=s));if(e-i>this.constants.clustering.clusterLevelDifference){var o=this.nodeIndices.length,n=e-this.constants.clustering.clusterLevelDifference;for(t in this.nodes)this.nodes.hasOwnProperty(t)&&this.nodes[t].clusterSessions.lengths&&(s=n.dynamicEdgesLength),t+=n.dynamicEdgesLength,e+=Math.pow(n.dynamicEdgesLength,2),i+=1}t/=i,e/=i;var r=e-Math.pow(t,2),a=Math.sqrt(r);this.hubThreshold=Math.floor(t+2*a),this.hubThreshold>s&&(this.hubThreshold=s)},e._reduceAmountOfChains=function(t){this.hubThreshold=2;var e=Math.floor(this.nodeIndices.length*t);for(var i in this.nodes)this.nodes.hasOwnProperty(i)&&2==this.nodes[i].dynamicEdgesLength&&this.nodes[i].dynamicEdges.length>=2&&e>0&&(this._formClusterFromHub(this.nodes[i],!0,!0,1),e-=1)},e._getChainFraction=function(){var t=0,e=0;for(var i in this.nodes)this.nodes.hasOwnProperty(i)&&(2==this.nodes[i].dynamicEdgesLength&&this.nodes[i].dynamicEdges.length>=2&&(t+=1),e+=1);return t/e}},function(t,e,i){var s=i(1),o=i(40);e._putDataInSector=function(){this.sectors.active[this._sector()].nodes=this.nodes,this.sectors.active[this._sector()].edges=this.edges,this.sectors.active[this._sector()].nodeIndices=this.nodeIndices},e._switchToSector=function(t,e){void 0===e||"active"==e?this._switchToActiveSector(t):this._switchToFrozenSector(t)},e._switchToActiveSector=function(t){this.nodeIndices=this.sectors.active[t].nodeIndices,this.nodes=this.sectors.active[t].nodes,this.edges=this.sectors.active[t].edges},e._switchToSupportSector=function(){this.nodeIndices=this.sectors.support.nodeIndices,this.nodes=this.sectors.support.nodes,this.edges=this.sectors.support.edges},e._switchToFrozenSector=function(t){this.nodeIndices=this.sectors.frozen[t].nodeIndices,this.nodes=this.sectors.frozen[t].nodes,this.edges=this.sectors.frozen[t].edges},e._loadLatestSector=function(){this._switchToSector(this._sector())},e._sector=function(){return 
this.activeSector[this.activeSector.length-1]},e._previousSector=function(){if(this.activeSector.length>1)return this.activeSector[this.activeSector.length-2];throw new TypeError("there are not enough sectors in the this.activeSector array.")},e._setActiveSector=function(t){this.activeSector.push(t)},e._forgetLastSector=function(){this.activeSector.pop()},e._createNewSector=function(t){this.sectors.active[t]={nodes:{},edges:{},nodeIndices:[],formationScale:this.scale,drawingNode:void 0},this.sectors.active[t].drawingNode=new o({id:t,color:{background:"#eaefef",border:"495c5e"}},{},{},this.constants),this.sectors.active[t].drawingNode.clusterSize=2},e._deleteActiveSector=function(t){delete this.sectors.active[t]},e._deleteFrozenSector=function(t){delete this.sectors.frozen[t]},e._freezeSector=function(t){this.sectors.frozen[t]=this.sectors.active[t],this._deleteActiveSector(t)},e._activateSector=function(t){this.sectors.active[t]=this.sectors.frozen[t],this._deleteFrozenSector(t)},e._mergeThisWithFrozen=function(t){for(var e in this.nodes)this.nodes.hasOwnProperty(e)&&(this.sectors.frozen[t].nodes[e]=this.nodes[e]);for(var i in this.edges)this.edges.hasOwnProperty(i)&&(this.sectors.frozen[t].edges[i]=this.edges[i]);for(var s=0;s1?this[t](o[0],o[1]):this[t](e))}return this._loadLatestSector(),i},e._doInSupportSector=function(t,e){var i=!1;if(void 0===e)this._switchToSupportSector(),i=this[t]();else{this._switchToSupportSector();var s=Array.prototype.splice.call(arguments,1);i=s.length>1?this[t](s[0],s[1]):this[t](e)}return this._loadLatestSector(),i},e._doInAllFrozenSectors=function(t,e){if(void 0===e)for(var i in this.sectors.frozen)this.sectors.frozen.hasOwnProperty(i)&&(this._switchToFrozenSector(i),this[t]());else for(var i in this.sectors.frozen)if(this.sectors.frozen.hasOwnProperty(i)){this._switchToFrozenSector(i);var s=Array.prototype.splice.call(arguments,1);s.length>1?this[t](s[0],s[1]):this[t](e)}this._loadLatestSector()},e._doInAllSectors=function(t,e){var i=Array.prototype.splice.call(arguments,1);void 0===e?(this._doInAllActiveSectors(t),this._doInAllFrozenSectors(t)):i.length>1?(this._doInAllActiveSectors(t,i[0],i[1]),this._doInAllFrozenSectors(t,i[0],i[1])):(this._doInAllActiveSectors(t,e),this._doInAllFrozenSectors(t,e))},e._clearNodeIndexList=function(){var t=this._sector();this.sectors.active[t].nodeIndices=[],this.nodeIndices=this.sectors.active[t].nodeIndices},e._drawSectorNodes=function(t,e){var i,s=1e9,o=-1e9,n=1e9,r=-1e9;for(var a in this.sectors[e])if(this.sectors[e].hasOwnProperty(a)&&void 0!==this.sectors[e][a].drawingNode){this._switchToSector(a,e),s=1e9,o=-1e9,n=1e9,r=-1e9;for(var h in this.nodes)this.nodes.hasOwnProperty(h)&&(i=this.nodes[h],i.resize(t),n>i.x-.5*i.width&&(n=i.x-.5*i.width),ri.y-.5*i.height&&(s=i.y-.5*i.height),o0?this.nodes[i[i.length-1]]:null},e._getEdgesOverlappingWith=function(t,e){var i=this.edges;for(var s in i)i.hasOwnProperty(s)&&i[s].isOverlappingWith(t)&&e.push(s)},e._getAllEdgesOverlappingWith=function(t){var e=[];return this._doInAllActiveSectors("_getEdgesOverlappingWith",t,e),e},e._getEdgeAt=function(t){var e=this._pointerToPositionObject(t),i=this._getAllEdgesOverlappingWith(e);return i.length>0?this.edges[i[i.length-1]]:null},e._addToSelection=function(t){t instanceof s?this.selectionObj.nodes[t.id]=t:this.selectionObj.edges[t.id]=t},e._addToHover=function(t){t instanceof s?this.hoverObj.nodes[t.id]=t:this.hoverObj.edges[t.id]=t},e._removeFromSelection=function(t){t instanceof s?delete this.selectionObj.nodes[t.id]:delete 
this.selectionObj.edges[t.id]},e._unselectAll=function(t){void 0===t&&(t=!1);for(var e in this.selectionObj.nodes)this.selectionObj.nodes.hasOwnProperty(e)&&this.selectionObj.nodes[e].unselect();for(var i in this.selectionObj.edges)this.selectionObj.edges.hasOwnProperty(i)&&this.selectionObj.edges[i].unselect();this.selectionObj={nodes:{},edges:{}},0==t&&this.emit("select",this.getSelection())},e._unselectClusters=function(t){void 0===t&&(t=!1);for(var e in this.selectionObj.nodes)this.selectionObj.nodes.hasOwnProperty(e)&&this.selectionObj.nodes[e].clusterSize>1&&(this.selectionObj.nodes[e].unselect(),this._removeFromSelection(this.selectionObj.nodes[e]));0==t&&this.emit("select",this.getSelection())},e._getSelectedNodeCount=function(){var t=0;for(var e in this.selectionObj.nodes)this.selectionObj.nodes.hasOwnProperty(e)&&(t+=1);return t},e._getSelectedNode=function(){for(var t in this.selectionObj.nodes)if(this.selectionObj.nodes.hasOwnProperty(t))return this.selectionObj.nodes[t];return null},e._getSelectedEdge=function(){for(var t in this.selectionObj.edges)if(this.selectionObj.edges.hasOwnProperty(t))return this.selectionObj.edges[t];return null},e._getSelectedEdgeCount=function(){var t=0;for(var e in this.selectionObj.edges)this.selectionObj.edges.hasOwnProperty(e)&&(t+=1);return t},e._getSelectedObjectCount=function(){var t=0;for(var e in this.selectionObj.nodes)this.selectionObj.nodes.hasOwnProperty(e)&&(t+=1);for(var i in this.selectionObj.edges)this.selectionObj.edges.hasOwnProperty(i)&&(t+=1);return t},e._selectionIsEmpty=function(){for(var t in this.selectionObj.nodes)if(this.selectionObj.nodes.hasOwnProperty(t))return!1;for(var e in this.selectionObj.edges)if(this.selectionObj.edges.hasOwnProperty(e))return!1;return!0},e._clusterInSelection=function(){for(var t in this.selectionObj.nodes)if(this.selectionObj.nodes.hasOwnProperty(t)&&this.selectionObj.nodes[t].clusterSize>1)return!0;return!1},e._selectConnectedEdges=function(t){for(var e=0;ei;i++){o=t[i];var n=this.nodes[o];if(!n)throw new RangeError('Node with id "'+o+'" not found');this._selectObject(n,!0,!0,e,!0)}this.redraw()},e.selectEdges=function(t){var e,i,s;if(!t||void 0==t.length)throw"Selection must be an array with ids";for(this._unselectAll(!0),e=0,i=t.length;i>e;e++){s=t[e];var o=this.edges[s];if(!o)throw new RangeError('Edge with id "'+s+'" not found');this._selectObject(o,!0,!0,!1,!0)}this.redraw()},e._updateSelection=function(){for(var t in this.selectionObj.nodes)this.selectionObj.nodes.hasOwnProperty(t)&&(this.nodes.hasOwnProperty(t)||delete this.selectionObj.nodes[t]);for(var e in this.selectionObj.edges)this.selectionObj.edges.hasOwnProperty(e)&&(this.edges.hasOwnProperty(e)||delete this.selectionObj.edges[e])}},function(t,e,i){var s=i(1),o=i(40),n=i(37);e._clearManipulatorBar=function(){this._recursiveDOMDelete(this.manipulationDiv),this.manipulationDOM={},this._manipulationReleaseOverload=function(){},delete this.sectors.support.nodes.targetNode,delete this.sectors.support.nodes.targetViaNode,this.controlNodesActive=!1,this.freezeSimulation=!1},e._restoreOverloadedFunctions=function(){for(var t in this.cachedFunctions)this.cachedFunctions.hasOwnProperty(t)&&(this[t]=this.cachedFunctions[t],delete this.cachedFunctions[t])},e._toggleEditMode=function(){this.editMode=!this.editMode;var 
t=this.manipulationDiv,e=this.closeDiv,i=this.editModeDiv;1==this.editMode?(t.style.display="block",e.style.display="block",i.style.display="none",e.onclick=this._toggleEditMode.bind(this)):(t.style.display="none",e.style.display="none",i.style.display="block",e.onclick=null),this._createManipulatorBar()},e._createManipulatorBar=function(){this.boundFunction&&this.off("select",this.boundFunction);var t=this.constants.locales[this.constants.locale];if(void 0!==this.edgeBeingEdited&&(this.edgeBeingEdited._disableControlNodes(),this.edgeBeingEdited=void 0,this.selectedControlNode=null,this.controlNodesActive=!1,this._redraw()),this._restoreOverloadedFunctions(),this.freezeSimulation=!1,this.blockConnectingEdgeSelection=!1,this.forceAppendSelection=!1,this.manipulationDOM={},1==this.editMode){for(;this.manipulationDiv.hasChildNodes();)this.manipulationDiv.removeChild(this.manipulationDiv.firstChild);this.manipulationDOM.addNodeSpan=document.createElement("span"),this.manipulationDOM.addNodeSpan.className="network-manipulationUI add",this.manipulationDOM.addNodeLabelSpan=document.createElement("span"),this.manipulationDOM.addNodeLabelSpan.className="network-manipulationLabel",this.manipulationDOM.addNodeLabelSpan.innerHTML=t.addNode,this.manipulationDOM.addNodeSpan.appendChild(this.manipulationDOM.addNodeLabelSpan),this.manipulationDOM.seperatorLineDiv1=document.createElement("div"),this.manipulationDOM.seperatorLineDiv1.className="network-seperatorLine",this.manipulationDOM.addEdgeSpan=document.createElement("span"),this.manipulationDOM.addEdgeSpan.className="network-manipulationUI connect",this.manipulationDOM.addEdgeLabelSpan=document.createElement("span"),this.manipulationDOM.addEdgeLabelSpan.className="network-manipulationLabel",this.manipulationDOM.addEdgeLabelSpan.innerHTML=t.addEdge,this.manipulationDOM.addEdgeSpan.appendChild(this.manipulationDOM.addEdgeLabelSpan),this.manipulationDiv.appendChild(this.manipulationDOM.addNodeSpan),this.manipulationDiv.appendChild(this.manipulationDOM.seperatorLineDiv1),this.manipulationDiv.appendChild(this.manipulationDOM.addEdgeSpan),1==this._getSelectedNodeCount()&&this.triggerFunctions.edit?(this.manipulationDOM.seperatorLineDiv2=document.createElement("div"),this.manipulationDOM.seperatorLineDiv2.className="network-seperatorLine",this.manipulationDOM.editNodeSpan=document.createElement("span"),this.manipulationDOM.editNodeSpan.className="network-manipulationUI edit",this.manipulationDOM.editNodeLabelSpan=document.createElement("span"),this.manipulationDOM.editNodeLabelSpan.className="network-manipulationLabel",this.manipulationDOM.editNodeLabelSpan.innerHTML=t.editNode,this.manipulationDOM.editNodeSpan.appendChild(this.manipulationDOM.editNodeLabelSpan),this.manipulationDiv.appendChild(this.manipulationDOM.seperatorLineDiv2),this.manipulationDiv.appendChild(this.manipulationDOM.editNodeSpan)):1==this._getSelectedEdgeCount()&&0==this._getSelectedNodeCount()&&(this.manipulationDOM.seperatorLineDiv3=document.createElement("div"),this.manipulationDOM.seperatorLineDiv3.className="network-seperatorLine",this.manipulationDOM.editEdgeSpan=document.createElement("span"),this.manipulationDOM.editEdgeSpan.className="network-manipulationUI 
edit",this.manipulationDOM.editEdgeLabelSpan=document.createElement("span"),this.manipulationDOM.editEdgeLabelSpan.className="network-manipulationLabel",this.manipulationDOM.editEdgeLabelSpan.innerHTML=t.editEdge,this.manipulationDOM.editEdgeSpan.appendChild(this.manipulationDOM.editEdgeLabelSpan),this.manipulationDiv.appendChild(this.manipulationDOM.seperatorLineDiv3),this.manipulationDiv.appendChild(this.manipulationDOM.editEdgeSpan)),0==this._selectionIsEmpty()&&(this.manipulationDOM.seperatorLineDiv4=document.createElement("div"),this.manipulationDOM.seperatorLineDiv4.className="network-seperatorLine",this.manipulationDOM.deleteSpan=document.createElement("span"),this.manipulationDOM.deleteSpan.className="network-manipulationUI delete",this.manipulationDOM.deleteLabelSpan=document.createElement("span"),this.manipulationDOM.deleteLabelSpan.className="network-manipulationLabel",this.manipulationDOM.deleteLabelSpan.innerHTML=t.del,this.manipulationDOM.deleteSpan.appendChild(this.manipulationDOM.deleteLabelSpan),this.manipulationDiv.appendChild(this.manipulationDOM.seperatorLineDiv4),this.manipulationDiv.appendChild(this.manipulationDOM.deleteSpan)),this.manipulationDOM.addNodeSpan.onclick=this._createAddNodeToolbar.bind(this),this.manipulationDOM.addEdgeSpan.onclick=this._createAddEdgeToolbar.bind(this),1==this._getSelectedNodeCount()&&this.triggerFunctions.edit?this.manipulationDOM.editNodeSpan.onclick=this._editNode.bind(this):1==this._getSelectedEdgeCount()&&0==this._getSelectedNodeCount()&&(this.manipulationDOM.editEdgeSpan.onclick=this._createEditEdgeToolbar.bind(this)),0==this._selectionIsEmpty()&&(this.manipulationDOM.deleteSpan.onclick=this._deleteSelected.bind(this)),this.closeDiv.onclick=this._toggleEditMode.bind(this); -var e=this;this.boundFunction=e._createManipulatorBar,this.on("select",this.boundFunction)}else{for(;this.editModeDiv.hasChildNodes();)this.editModeDiv.removeChild(this.editModeDiv.firstChild);this.manipulationDOM.editModeSpan=document.createElement("span"),this.manipulationDOM.editModeSpan.className="network-manipulationUI edit editmode",this.manipulationDOM.editModeLabelSpan=document.createElement("span"),this.manipulationDOM.editModeLabelSpan.className="network-manipulationLabel",this.manipulationDOM.editModeLabelSpan.innerHTML=t.edit,this.manipulationDOM.editModeSpan.appendChild(this.manipulationDOM.editModeLabelSpan),this.editModeDiv.appendChild(this.manipulationDOM.editModeSpan),this.manipulationDOM.editModeSpan.onclick=this._toggleEditMode.bind(this)}},e._createAddNodeToolbar=function(){this._clearManipulatorBar(),this.boundFunction&&this.off("select",this.boundFunction);var t=this.constants.locales[this.constants.locale];this.manipulationDOM={},this.manipulationDOM.backSpan=document.createElement("span"),this.manipulationDOM.backSpan.className="network-manipulationUI back",this.manipulationDOM.backLabelSpan=document.createElement("span"),this.manipulationDOM.backLabelSpan.className="network-manipulationLabel",this.manipulationDOM.backLabelSpan.innerHTML=t.back,this.manipulationDOM.backSpan.appendChild(this.manipulationDOM.backLabelSpan),this.manipulationDOM.seperatorLineDiv1=document.createElement("div"),this.manipulationDOM.seperatorLineDiv1.className="network-seperatorLine",this.manipulationDOM.descriptionSpan=document.createElement("span"),this.manipulationDOM.descriptionSpan.className="network-manipulationUI 
none",this.manipulationDOM.descriptionLabelSpan=document.createElement("span"),this.manipulationDOM.descriptionLabelSpan.className="network-manipulationLabel",this.manipulationDOM.descriptionLabelSpan.innerHTML=t.addDescription,this.manipulationDOM.descriptionSpan.appendChild(this.manipulationDOM.descriptionLabelSpan),this.manipulationDiv.appendChild(this.manipulationDOM.backSpan),this.manipulationDiv.appendChild(this.manipulationDOM.seperatorLineDiv1),this.manipulationDiv.appendChild(this.manipulationDOM.descriptionSpan),this.manipulationDOM.backSpan.onclick=this._createManipulatorBar.bind(this);var e=this;this.boundFunction=e._addNode,this.on("select",this.boundFunction)},e._createAddEdgeToolbar=function(){this._clearManipulatorBar(),this._unselectAll(!0),this.freezeSimulation=!0,this.boundFunction&&this.off("select",this.boundFunction);var t=this.constants.locales[this.constants.locale];this._unselectAll(),this.forceAppendSelection=!1,this.blockConnectingEdgeSelection=!0,this.manipulationDOM={},this.manipulationDOM.backSpan=document.createElement("span"),this.manipulationDOM.backSpan.className="network-manipulationUI back",this.manipulationDOM.backLabelSpan=document.createElement("span"),this.manipulationDOM.backLabelSpan.className="network-manipulationLabel",this.manipulationDOM.backLabelSpan.innerHTML=t.back,this.manipulationDOM.backSpan.appendChild(this.manipulationDOM.backLabelSpan),this.manipulationDOM.seperatorLineDiv1=document.createElement("div"),this.manipulationDOM.seperatorLineDiv1.className="network-seperatorLine",this.manipulationDOM.descriptionSpan=document.createElement("span"),this.manipulationDOM.descriptionSpan.className="network-manipulationUI none",this.manipulationDOM.descriptionLabelSpan=document.createElement("span"),this.manipulationDOM.descriptionLabelSpan.className="network-manipulationLabel",this.manipulationDOM.descriptionLabelSpan.innerHTML=t.edgeDescription,this.manipulationDOM.descriptionSpan.appendChild(this.manipulationDOM.descriptionLabelSpan),this.manipulationDiv.appendChild(this.manipulationDOM.backSpan),this.manipulationDiv.appendChild(this.manipulationDOM.seperatorLineDiv1),this.manipulationDiv.appendChild(this.manipulationDOM.descriptionSpan),this.manipulationDOM.backSpan.onclick=this._createManipulatorBar.bind(this);var e=this;this.boundFunction=e._handleConnect,this.on("select",this.boundFunction),this.cachedFunctions._handleTouch=this._handleTouch,this.cachedFunctions._manipulationReleaseOverload=this._manipulationReleaseOverload,this.cachedFunctions._handleDragStart=this._handleDragStart,this.cachedFunctions._handleDragEnd=this._handleDragEnd,this._handleTouch=this._handleConnect,this._manipulationReleaseOverload=function(){},this._handleDragStart=function(){},this._handleDragEnd=this._finishConnect,this._redraw()},e._createEditEdgeToolbar=function(){this._clearManipulatorBar(),this.controlNodesActive=!0,this.boundFunction&&this.off("select",this.boundFunction),this.edgeBeingEdited=this._getSelectedEdge(),this.edgeBeingEdited._enableControlNodes();var t=this.constants.locales[this.constants.locale];this.manipulationDOM={},this.manipulationDOM.backSpan=document.createElement("span"),this.manipulationDOM.backSpan.className="network-manipulationUI 
back",this.manipulationDOM.backLabelSpan=document.createElement("span"),this.manipulationDOM.backLabelSpan.className="network-manipulationLabel",this.manipulationDOM.backLabelSpan.innerHTML=t.back,this.manipulationDOM.backSpan.appendChild(this.manipulationDOM.backLabelSpan),this.manipulationDOM.seperatorLineDiv1=document.createElement("div"),this.manipulationDOM.seperatorLineDiv1.className="network-seperatorLine",this.manipulationDOM.descriptionSpan=document.createElement("span"),this.manipulationDOM.descriptionSpan.className="network-manipulationUI none",this.manipulationDOM.descriptionLabelSpan=document.createElement("span"),this.manipulationDOM.descriptionLabelSpan.className="network-manipulationLabel",this.manipulationDOM.descriptionLabelSpan.innerHTML=t.editEdgeDescription,this.manipulationDOM.descriptionSpan.appendChild(this.manipulationDOM.descriptionLabelSpan),this.manipulationDiv.appendChild(this.manipulationDOM.backSpan),this.manipulationDiv.appendChild(this.manipulationDOM.seperatorLineDiv1),this.manipulationDiv.appendChild(this.manipulationDOM.descriptionSpan),this.manipulationDOM.backSpan.onclick=this._createManipulatorBar.bind(this),this.cachedFunctions._handleTouch=this._handleTouch,this.cachedFunctions._manipulationReleaseOverload=this._manipulationReleaseOverload,this.cachedFunctions._handleTap=this._handleTap,this.cachedFunctions._handleDragStart=this._handleDragStart,this.cachedFunctions._handleOnDrag=this._handleOnDrag,this._handleTouch=this._selectControlNode,this._handleTap=function(){},this._handleOnDrag=this._controlNodeDrag,this._handleDragStart=function(){},this._manipulationReleaseOverload=this._releaseControlNode,this._redraw()},e._selectControlNode=function(t){this.edgeBeingEdited.controlNodes.from.unselect(),this.edgeBeingEdited.controlNodes.to.unselect(),this.selectedControlNode=this.edgeBeingEdited._getSelectedControlNode(this._XconvertDOMtoCanvas(t.x),this._YconvertDOMtoCanvas(t.y)),null!==this.selectedControlNode&&(this.selectedControlNode.select(),this.freezeSimulation=!0),this._redraw()},e._controlNodeDrag=function(t){var e=this._getPointer(t.center);null!==this.selectedControlNode&&void 0!==this.selectedControlNode&&(this.selectedControlNode.x=this._XconvertDOMtoCanvas(e.x),this.selectedControlNode.y=this._YconvertDOMtoCanvas(e.y)),this._redraw()},e._releaseControlNode=function(t){var e=this._getNodeAt(t);null!==e?(1==this.edgeBeingEdited.controlNodes.from.selected&&(this.edgeBeingEdited._restoreControlNodes(),this._editEdge(e.id,this.edgeBeingEdited.to.id),this.edgeBeingEdited.controlNodes.from.unselect()),1==this.edgeBeingEdited.controlNodes.to.selected&&(this.edgeBeingEdited._restoreControlNodes(),this._editEdge(this.edgeBeingEdited.from.id,e.id),this.edgeBeingEdited.controlNodes.to.unselect())):this.edgeBeingEdited._restoreControlNodes(),this.freezeSimulation=!1,this._redraw()},e._handleConnect=function(t){if(0==this._getSelectedNodeCount()){var e=this._getNodeAt(t);if(null!=e)if(e.clusterSize>1)alert(this.constants.locales[this.constants.locale].createEdgeError);else{this._selectObject(e,!1);var i=this.sectors.support.nodes;i.targetNode=new o({id:"targetNode"},{},{},this.constants);var s=i.targetNode;s.x=e.x,s.y=e.y,this.edges.connectionEdge=new n({id:"connectionEdge",from:e.id,to:s.id},this,this.constants);var r=this.edges.connectionEdge;r.from=e,r.connected=!0,r.options.smoothCurves={enabled:!0,dynamic:!1,type:"continuous",roundness:.5},r.selected=!0,r.to=s,this.cachedFunctions._handleOnDrag=this._handleOnDrag,this._handleOnDrag=function(t){var 
e=this._getPointer(t.center),i=this.edges.connectionEdge;i.to.x=this._XconvertDOMtoCanvas(e.x),i.to.y=this._YconvertDOMtoCanvas(e.y)},this.moving=!0,this.start()}}},e._finishConnect=function(t){if(1==this._getSelectedNodeCount()){var e=this._getPointer(t.center);this._handleOnDrag=this.cachedFunctions._handleOnDrag,delete this.cachedFunctions._handleOnDrag;var i=this.edges.connectionEdge.fromId;delete this.edges.connectionEdge,delete this.sectors.support.nodes.targetNode,delete this.sectors.support.nodes.targetViaNode;var s=this._getNodeAt(e);null!=s&&(s.clusterSize>1?alert(this.constants.locales[this.constants.locale].createEdgeError):(this._createEdge(i,s.id),this._createManipulatorBar())),this._unselectAll()}},e._addNode=function(){if(this._selectionIsEmpty()&&1==this.editMode){var t=this._pointerToPositionObject(this.pointerPosition),e={id:s.randomUUID(),x:t.left,y:t.top,label:"new",allowedToMoveX:!0,allowedToMoveY:!0};if(this.triggerFunctions.add){if(2!=this.triggerFunctions.add.length)throw new Error("The function for add does not support two arguments (data,callback)");var i=this;this.triggerFunctions.add(e,function(t){i.nodesData.add(t),i._createManipulatorBar(),i.moving=!0,i.start()})}else this.nodesData.add(e),this._createManipulatorBar(),this.moving=!0,this.start()}},e._createEdge=function(t,e){if(1==this.editMode){var i={from:t,to:e};if(this.triggerFunctions.connect){if(2!=this.triggerFunctions.connect.length)throw new Error("The function for connect does not support two arguments (data,callback)");var s=this;this.triggerFunctions.connect(i,function(t){s.edgesData.add(t),s.moving=!0,s.start()})}else this.edgesData.add(i),this.moving=!0,this.start()}},e._editEdge=function(t,e){if(1==this.editMode){var i={id:this.edgeBeingEdited.id,from:t,to:e};if(this.triggerFunctions.editEdge){if(2!=this.triggerFunctions.editEdge.length)throw new Error("The function for edit does not support two arguments (data, callback)");var s=this;this.triggerFunctions.editEdge(i,function(t){s.edgesData.update(t),s.moving=!0,s.start()})}else this.edgesData.update(i),this.moving=!0,this.start()}},e._editNode=function(){if(!this.triggerFunctions.edit||1!=this.editMode)throw new Error("No edit function has been bound to this button");var t=this._getSelectedNode(),e={id:t.id,label:t.label,group:t.options.group,shape:t.options.shape,color:{background:t.options.color.background,border:t.options.color.border,highlight:{background:t.options.color.highlight.background,border:t.options.color.highlight.border}}};if(2!=this.triggerFunctions.edit.length)throw new Error("The function for edit does not support two arguments (data, callback)");var i=this;this.triggerFunctions.edit(e,function(t){i.nodesData.update(t),i._createManipulatorBar(),i.moving=!0,i.start()})},e._deleteSelected=function(){if(!this._selectionIsEmpty()&&1==this.editMode)if(this._clusterInSelection())alert(this.constants.locales[this.constants.locale].deleteClusterError);else{var t=this.getSelectedNodes(),e=this.getSelectedEdges();if(this.triggerFunctions.del){var i=this,s={nodes:t,edges:e};if(2!=this.triggerFunctions.del.length)throw new Error("The function for delete does not support two arguments (data, callback)");this.triggerFunctions.del(s,function(t){i.edgesData.remove(t.edges),i.nodesData.remove(t.nodes),i._unselectAll(),i.moving=!0,i.start()})}else this.edgesData.remove(e),this.nodesData.remove(t),this._unselectAll(),this.moving=!0,this.start()}}},function(t,e,i){var 
s=(i(1),i(47)),o=i(45);e._cleanNavigation=function(){if(0!=this.navigationHammers.existing.length){for(var t=0;t0){var t,e,i=0,s=!1,o=!1;for(e in this.nodes)this.nodes.hasOwnProperty(e)&&(t=this.nodes[e],-1!=t.level?s=!0:o=!0,is&&(n.xFixed=!1,n.x=i[n.level].minPos,r=!0):n.yFixed&&n.level>s&&(n.yFixed=!1,n.y=i[n.level].minPos,r=!0),1==r&&(i[n.level].minPos+=i[n.level].nodeSpacing,n.edges.length>1&&this._placeBranchNodes(n.edges,n.id,i,n.level))}},e._setLevel=function(t,e,i){for(var s=0;st)&&(o.level=t,o.edges.length>1&&this._setLevel(t+1,o.edges,o.id))}},e._setLevelDirected=function(t,e,i){this.nodes[i].hierarchyEnumerated=!0;for(var s,o,n=0;n1&&s.hierarchyEnumerated===!1&&this._setLevelDirected(s.level,s.edges,s.id)},e._restoreNodes=function(){for(var t in this.nodes)this.nodes.hasOwnProperty(t)&&(this.nodes[t].xFixed=!1,this.nodes[t].yFixed=!1)}},function(t,e,i){function s(){this.constants.smoothCurves.enabled=!this.constants.smoothCurves.enabled;var t=document.getElementById("graph_toggleSmooth");t.style.background=1==this.constants.smoothCurves.enabled?"#A4FF56":"#FF8532",this._configureSmoothCurves(!1)}function o(){for(var t in this.calculationNodes)this.calculationNodes.hasOwnProperty(t)&&(this.calculationNodes[t].vx=0,this.calculationNodes[t].vy=0,this.calculationNodes[t].fx=0,this.calculationNodes[t].fy=0);1==this.constants.hierarchicalLayout.enabled?(this._setupHierarchicalLayout(),a.call(this,"graph_H_nd",1,"physics_hierarchicalRepulsion_nodeDistance"),a.call(this,"graph_H_cg",1,"physics_centralGravity"),a.call(this,"graph_H_sc",1,"physics_springConstant"),a.call(this,"graph_H_sl",1,"physics_springLength"),a.call(this,"graph_H_damp",1,"physics_damping")):this.repositionNodes(),this.moving=!0,this.start()}function n(){var t="No options are required, default values used.",e=[],i=document.getElementById("graph_physicsMethod1"),s=document.getElementById("graph_physicsMethod2");if(1==i.checked){if(this.constants.physics.barnesHut.gravitationalConstant!=this.backupConstants.physics.barnesHut.gravitationalConstant&&e.push("gravitationalConstant: "+this.constants.physics.barnesHut.gravitationalConstant),this.constants.physics.centralGravity!=this.backupConstants.physics.barnesHut.centralGravity&&e.push("centralGravity: "+this.constants.physics.centralGravity),this.constants.physics.springLength!=this.backupConstants.physics.barnesHut.springLength&&e.push("springLength: "+this.constants.physics.springLength),this.constants.physics.springConstant!=this.backupConstants.physics.barnesHut.springConstant&&e.push("springConstant: "+this.constants.physics.springConstant),this.constants.physics.damping!=this.backupConstants.physics.barnesHut.damping&&e.push("damping: "+this.constants.physics.damping),0!=e.length){t="var options = {",t+="physics: {barnesHut: {";for(var 
o=0;othis.constants.clustering.clusterThreshold&&1==this.constants.clustering.enabled&&this.clusterToFit(this.constants.clustering.reduceToNodes,!1),this._calculateForces())},e._calculateForces=function(){this._calculateGravitationalForces(),this._calculateNodeForces(),this.constants.physics.springConstant>0&&(1==this.constants.smoothCurves.enabled&&1==this.constants.smoothCurves.dynamic?this._calculateSpringForcesWithSupport():1==this.constants.physics.hierarchicalRepulsion.enabled?this._calculateHierarchicalSpringForces():this._calculateSpringForces())},e._updateCalculationNodes=function(){if(1==this.constants.smoothCurves.enabled&&1==this.constants.smoothCurves.dynamic){this.calculationNodes={},this.calculationNodeIndices=[];for(var t in this.nodes)this.nodes.hasOwnProperty(t)&&(this.calculationNodes[t]=this.nodes[t]);var e=this.sectors.support.nodes;for(var i in e)e.hasOwnProperty(i)&&(this.edges.hasOwnProperty(e[i].parentEdgeId)?this.calculationNodes[i]=e[i]:e[i]._setForce(0,0));for(var s in this.calculationNodes)this.calculationNodes.hasOwnProperty(s)&&this.calculationNodeIndices.push(s)}else this.calculationNodes=this.nodes,this.calculationNodeIndices=this.nodeIndices},e._calculateGravitationalForces=function(){var t,e,i,s,o,n=this.calculationNodes,r=this.constants.physics.centralGravity,a=0;for(o=0;oSimulation Mode:Barnes HutRepulsionHierarchical
Options:
',this.containerElement.parentElement.insertBefore(this.physicsConfiguration,this.containerElement),this.optionsDiv=document.createElement("div"),this.optionsDiv.style.fontSize="14px",this.optionsDiv.style.fontFamily="verdana",this.containerElement.parentElement.insertBefore(this.optionsDiv,this.containerElement); -var e;e=document.getElementById("graph_BH_gc"),e.onchange=a.bind(this,"graph_BH_gc",-1,"physics_barnesHut_gravitationalConstant"),e=document.getElementById("graph_BH_cg"),e.onchange=a.bind(this,"graph_BH_cg",1,"physics_centralGravity"),e=document.getElementById("graph_BH_sc"),e.onchange=a.bind(this,"graph_BH_sc",1,"physics_springConstant"),e=document.getElementById("graph_BH_sl"),e.onchange=a.bind(this,"graph_BH_sl",1,"physics_springLength"),e=document.getElementById("graph_BH_damp"),e.onchange=a.bind(this,"graph_BH_damp",1,"physics_damping"),e=document.getElementById("graph_R_nd"),e.onchange=a.bind(this,"graph_R_nd",1,"physics_repulsion_nodeDistance"),e=document.getElementById("graph_R_cg"),e.onchange=a.bind(this,"graph_R_cg",1,"physics_centralGravity"),e=document.getElementById("graph_R_sc"),e.onchange=a.bind(this,"graph_R_sc",1,"physics_springConstant"),e=document.getElementById("graph_R_sl"),e.onchange=a.bind(this,"graph_R_sl",1,"physics_springLength"),e=document.getElementById("graph_R_damp"),e.onchange=a.bind(this,"graph_R_damp",1,"physics_damping"),e=document.getElementById("graph_H_nd"),e.onchange=a.bind(this,"graph_H_nd",1,"physics_hierarchicalRepulsion_nodeDistance"),e=document.getElementById("graph_H_cg"),e.onchange=a.bind(this,"graph_H_cg",1,"physics_centralGravity"),e=document.getElementById("graph_H_sc"),e.onchange=a.bind(this,"graph_H_sc",1,"physics_springConstant"),e=document.getElementById("graph_H_sl"),e.onchange=a.bind(this,"graph_H_sl",1,"physics_springLength"),e=document.getElementById("graph_H_damp"),e.onchange=a.bind(this,"graph_H_damp",1,"physics_damping"),e=document.getElementById("graph_H_direction"),e.onchange=a.bind(this,"graph_H_direction",t,"hierarchicalLayout_direction"),e=document.getElementById("graph_H_levsep"),e.onchange=a.bind(this,"graph_H_levsep",1,"hierarchicalLayout_levelSeparation"),e=document.getElementById("graph_H_nspac"),e.onchange=a.bind(this,"graph_H_nspac",1,"hierarchicalLayout_nodeSpacing");var i=document.getElementById("graph_physicsMethod1"),d=document.getElementById("graph_physicsMethod2"),l=document.getElementById("graph_physicsMethod3");d.checked=!0,this.constants.physics.barnesHut.enabled&&(i.checked=!0),this.constants.hierarchicalLayout.enabled&&(l.checked=!0);var c=document.getElementById("graph_toggleSmooth"),p=document.getElementById("graph_repositionNodes"),u=document.getElementById("graph_generateOptions");c.onclick=s.bind(this),p.onclick=o.bind(this),u.onclick=n.bind(this),c.style.background=1==this.constants.smoothCurves&&0==this.constants.dynamicSmoothCurves?"#A4FF56":"#FF8532",r.apply(this),i.onchange=r.bind(this),d.onchange=r.bind(this),l.onchange=r.bind(this)}},e._overWriteGraphConstants=function(t,e){var i=t.split("_");1==i.length?this.constants[i[0]]=e:2==i.length?this.constants[i[0]][i[1]]=e:3==i.length&&(this.constants[i[0]][i[1]][i[2]]=e)}},function(t){function e(t){throw new Error("Cannot find module '"+t+"'.")}e.keys=function(){return[]},e.resolve=e,t.exports=e,e.id=68},function(t,e){e._calculateNodeForces=function(){var 
t,e,i,s,o,n,r,a,h,d,l,c=this.calculationNodes,p=this.calculationNodeIndices,u=-2/3,m=4/3,f=this.constants.physics.repulsion.nodeDistance,g=f;for(d=0;di&&(r=.5*g>i?1:v*i+m,r*=0==n?1:1+n*this.constants.clustering.forceAmplification,r/=Math.max(i,.01*g),s=t*r,o=e*r,a.fx-=s,a.fy-=o,h.fx+=s,h.fy+=o)}}},function(t,e){e._calculateNodeForces=function(){var t,e,i,s,o,n,r,a,h,d,l=this.calculationNodes,c=this.calculationNodeIndices,p=this.constants.physics.hierarchicalRepulsion.nodeDistance;for(h=0;hi?-Math.pow(u*i,2)+Math.pow(u*p,2):0,0==i?i=.01:n/=i,s=t*n,o=e*n,r.fx-=s,r.fy-=o,a.fx+=s,a.fy+=o}},e._calculateHierarchicalSpringForces=function(){for(var t,e,i,s,o,n,r,a,h,d=this.edges,l=this.calculationNodes,c=this.calculationNodeIndices,p=0;pn;n++)t=e[i[n]],t.options.mass>0&&(this._getForceContribution(o.root.children.NW,t),this._getForceContribution(o.root.children.NE,t),this._getForceContribution(o.root.children.SW,t),this._getForceContribution(o.root.children.SE,t))}},e._getForceContribution=function(t,e){if(t.childrenCount>0){var i,s,o;if(i=t.centerOfMass.x-e.x,s=t.centerOfMass.y-e.y,o=Math.sqrt(i*i+s*s),o*t.calcSize>this.constants.physics.barnesHut.thetaInverted){0==o&&(o=.1*Math.random(),i=o);var n=this.constants.physics.barnesHut.gravitationalConstant*t.mass*e.options.mass/(o*o*o),r=i*n,a=s*n;e.fx+=r,e.fy+=a}else if(4==t.childrenCount)this._getForceContribution(t.children.NW,e),this._getForceContribution(t.children.NE,e),this._getForceContribution(t.children.SW,e),this._getForceContribution(t.children.SE,e);else if(t.children.data.id!=e.id){0==o&&(o=.5*Math.random(),i=o);var n=this.constants.physics.barnesHut.gravitationalConstant*t.mass*e.options.mass/(o*o*o),r=i*n,a=s*n;e.fx+=r,e.fy+=a}}},e._formBarnesHutTree=function(t,e){for(var i,s=e.length,o=Number.MAX_VALUE,n=Number.MAX_VALUE,r=-Number.MAX_VALUE,a=-Number.MAX_VALUE,h=0;s>h;h++){var d=t[e[h]].x,l=t[e[h]].y;t[e[h]].options.mass>0&&(o>d&&(o=d),d>r&&(r=d),n>l&&(n=l),l>a&&(a=l))}var c=Math.abs(r-o)-Math.abs(a-n);c>0?(n-=.5*c,a+=.5*c):(o+=.5*c,r-=.5*c);var p=1e-5,u=Math.max(p,Math.abs(r-o)),m=.5*u,f=.5*(o+r),g=.5*(n+a),v={root:{centerOfMass:{x:0,y:0},mass:0,range:{minX:f-m,maxX:f+m,minY:g-m,maxY:g+m},size:u,calcSize:1/u,children:{data:null},maxWidth:0,level:0,childrenCount:4}};for(this._splitBranch(v.root),h=0;s>h;h++)i=t[e[h]],i.options.mass>0&&this._placeInTree(v.root,i);this.barnesHutTree=v},e._updateBranchMass=function(t,e){var i=t.mass+e.options.mass,s=1/i;t.centerOfMass.x=t.centerOfMass.x*t.mass+e.x*e.options.mass,t.centerOfMass.x*=s,t.centerOfMass.y=t.centerOfMass.y*t.mass+e.y*e.options.mass,t.centerOfMass.y*=s,t.mass=i;var o=Math.max(Math.max(e.height,e.radius),e.width);t.maxWidth=t.maxWidthe.x?t.children.NW.range.maxY>e.y?this._placeInRegion(t,e,"NW"):this._placeInRegion(t,e,"SW"):t.children.NW.range.maxY>e.y?this._placeInRegion(t,e,"NE"):this._placeInRegion(t,e,"SE")},e._placeInRegion=function(t,e,i){switch(t.children[i].childrenCount){case 0:t.children[i].children.data=e,t.children[i].childrenCount=1,this._updateBranchMass(t.children[i],e);break;case 1:t.children[i].children.data.x==e.x&&t.children[i].children.data.y==e.y?(e.x+=Math.random(),e.y+=Math.random()):(this._splitBranch(t.children[i]),this._placeInTree(t.children[i],e));break;case 4:this._placeInTree(t.children[i],e)}},e._splitBranch=function(t){var 
e=null;1==t.childrenCount&&(e=t.children.data,t.mass=0,t.centerOfMass.x=0,t.centerOfMass.y=0),t.childrenCount=4,t.children.data=null,this._insertRegion(t,"NW"),this._insertRegion(t,"NE"),this._insertRegion(t,"SW"),this._insertRegion(t,"SE"),null!=e&&this._placeInTree(t,e)},e._insertRegion=function(t,e){var i,s,o,n,r=.5*t.size;switch(e){case"NW":i=t.range.minX,s=t.range.minX+r,o=t.range.minY,n=t.range.minY+r;break;case"NE":i=t.range.minX+r,s=t.range.maxX,o=t.range.minY,n=t.range.minY+r;break;case"SW":i=t.range.minX,s=t.range.minX+r,o=t.range.minY+r,n=t.range.maxY;break;case"SE":i=t.range.minX+r,s=t.range.maxX,o=t.range.minY+r,n=t.range.maxY}t.children[e]={centerOfMass:{x:0,y:0},mass:0,range:{minX:i,maxX:s,minY:o,maxY:n},size:.5*t.size,calcSize:2*t.calcSize,children:{data:null},maxWidth:0,level:t.level+1,childrenCount:0}},e._drawTree=function(t,e){void 0!==this.barnesHutTree&&(t.lineWidth=1,this._drawBranch(this.barnesHutTree.root,t,e))},e._drawBranch=function(t,e,i){void 0===i&&(i="#FF0000"),4==t.childrenCount&&(this._drawBranch(t.children.NW,e),this._drawBranch(t.children.NE,e),this._drawBranch(t.children.SE,e),this._drawBranch(t.children.SW,e)),e.strokeStyle=i,e.beginPath(),e.moveTo(t.range.minX,t.range.minY),e.lineTo(t.range.maxX,t.range.minY),e.stroke(),e.beginPath(),e.moveTo(t.range.maxX,t.range.minY),e.lineTo(t.range.maxX,t.range.maxY),e.stroke(),e.beginPath(),e.moveTo(t.range.maxX,t.range.maxY),e.lineTo(t.range.minX,t.range.maxY),e.stroke(),e.beginPath(),e.moveTo(t.range.minX,t.range.maxY),e.lineTo(t.range.minX,t.range.minY),e.stroke()}},function(t){t.exports=function(t){return t.webpackPolyfill||(t.deprecate=function(){},t.paths=[],t.children=[],t.webpackPolyfill=1),t}},function(t,e){(function(e){t.exports=e}).call(e,{})}])}); +"use strict";!function(t,e){"object"==typeof exports&&"object"==typeof module?module.exports=e():"function"==typeof define&&define.amd?define([],e):"object"==typeof exports?exports.vis=e():t.vis=e()}(this,function(){return function(t){function e(o){if(i[o])return i[o].exports;var n=i[o]={exports:{},id:o,loaded:!1};return t[o].call(n.exports,n,n.exports,e),n.loaded=!0,n.exports}var i={};return e.m=t,e.c=i,e.p="",e(0)}([function(t,e,i){var o=i(1);o.extend(e,i(7)),o.extend(e,i(24)),o.extend(e,i(60))},function(t,e,i){var o="function"==typeof Symbol&&"symbol"==typeof Symbol.iterator?function(t){return typeof t}:function(t){return t&&"function"==typeof Symbol&&t.constructor===Symbol?"symbol":typeof t},n=i(2),s=i(6);e.isNumber=function(t){return t instanceof Number||"number"==typeof t},e.recursiveDOMDelete=function(t){if(t)for(;t.hasChildNodes()===!0;)e.recursiveDOMDelete(t.firstChild),t.removeChild(t.firstChild)},e.giveRange=function(t,e,i,o){if(e==t)return.5;var n=1/(e-t);return Math.max(0,(o-t)*n)},e.isString=function(t){return t instanceof String||"string"==typeof t},e.isDate=function(t){if(t instanceof Date)return!0;if(e.isString(t)){var i=r.exec(t);if(i)return!0;if(!isNaN(Date.parse(t)))return!0}return!1},e.randomUUID=function(){return s.v4()},e.assignAllKeys=function(t,e){for(var i in t)t.hasOwnProperty(i)&&"object"!==o(t[i])&&(t[i]=e)},e.fillIfDefined=function(t,i){var n=arguments.length<=2||void 0===arguments[2]?!1:arguments[2];for(var s in t)void 0!==i[s]&&("object"!==o(i[s])?void 0!==i[s]&&null!==i[s]||void 0===t[s]||n!==!0?t[s]=i[s]:delete t[s]:"object"===o(t[s])&&e.fillIfDefined(t[s],i[s],n))},e.protoExtend=function(t,e){for(var i=1;ii;i++)if(t[i]!=e[i])return!1;return!0},e.convert=function(t,i){var o;if(void 0!==t){if(null===t)return null;if(!i)return 
t;if("string"!=typeof i&&!(i instanceof String))throw new Error("Type must be a string");switch(i){case"boolean":case"Boolean":return Boolean(t);case"number":case"Number":return Number(t.valueOf());case"string":case"String":return String(t);case"Date":if(e.isNumber(t))return new Date(t);if(t instanceof Date)return new Date(t.valueOf());if(n.isMoment(t))return new Date(t.valueOf());if(e.isString(t))return o=r.exec(t),o?new Date(Number(o[1])):n(t).toDate();throw new Error("Cannot convert object of type "+e.getType(t)+" to type Date");case"Moment":if(e.isNumber(t))return n(t);if(t instanceof Date)return n(t.valueOf());if(n.isMoment(t))return n(t);if(e.isString(t))return o=r.exec(t),n(o?Number(o[1]):t);throw new Error("Cannot convert object of type "+e.getType(t)+" to type Date");case"ISODate":if(e.isNumber(t))return new Date(t);if(t instanceof Date)return t.toISOString();if(n.isMoment(t))return t.toDate().toISOString();if(e.isString(t))return o=r.exec(t),o?new Date(Number(o[1])).toISOString():new Date(t).toISOString();throw new Error("Cannot convert object of type "+e.getType(t)+" to type ISODate");case"ASPDate":if(e.isNumber(t))return"/Date("+t+")/";if(t instanceof Date)return"/Date("+t.valueOf()+")/";if(e.isString(t)){o=r.exec(t);var s;return s=o?new Date(Number(o[1])).valueOf():new Date(t).valueOf(),"/Date("+s+")/"}throw new Error("Cannot convert object of type "+e.getType(t)+" to type ASPDate");default:throw new Error('Unknown type "'+i+'"')}}};var r=/^\/?Date\((\-?\d+)/i;e.getType=function(t){var e="undefined"==typeof t?"undefined":o(t);return"object"==e?null===t?"null":t instanceof Boolean?"Boolean":t instanceof Number?"Number":t instanceof String?"String":Array.isArray(t)?"Array":t instanceof Date?"Date":"Object":"number"==e?"Number":"boolean"==e?"Boolean":"string"==e?"String":void 0===e?"undefined":e},e.copyAndExtendArray=function(t,e){for(var i=[],o=0;oi;i++)e(t[i],i,t);else for(i in t)t.hasOwnProperty(i)&&e(t[i],i,t)},e.toArray=function(t){var e=[];for(var i in t)t.hasOwnProperty(i)&&e.push(t[i]);return e},e.updateProperty=function(t,e,i){return t[e]!==i?(t[e]=i,!0):!1},e.throttle=function(t,e){var i=null,o=!1;return function n(){i?o=!0:(o=!1,t(),i=setTimeout(function(){i=null,o&&n()},e))}},e.addEventListener=function(t,e,i,o){t.addEventListener?(void 0===o&&(o=!1),"mousewheel"===e&&navigator.userAgent.indexOf("Firefox")>=0&&(e="DOMMouseScroll"),t.addEventListener(e,i,o)):t.attachEvent("on"+e,i)},e.removeEventListener=function(t,e,i,o){t.removeEventListener?(void 0===o&&(o=!1),"mousewheel"===e&&navigator.userAgent.indexOf("Firefox")>=0&&(e="DOMMouseScroll"),t.removeEventListener(e,i,o)):t.detachEvent("on"+e,i)},e.preventDefault=function(t){t||(t=window.event),t.preventDefault?t.preventDefault():t.returnValue=!1},e.getTarget=function(t){t||(t=window.event);var e;return t.target?e=t.target:t.srcElement&&(e=t.srcElement),void 0!=e.nodeType&&3==e.nodeType&&(e=e.parentNode),e},e.hasParent=function(t,e){for(var i=t;i;){if(i===e)return!0;i=i.parentNode}return!1},e.option={},e.option.asBoolean=function(t,e){return"function"==typeof t&&(t=t()),null!=t?0!=t:e||null},e.option.asNumber=function(t,e){return"function"==typeof t&&(t=t()),null!=t?Number(t)||e||null:e||null},e.option.asString=function(t,e){return"function"==typeof t&&(t=t()),null!=t?String(t):e||null},e.option.asSize=function(t,i){return"function"==typeof t&&(t=t()),e.isString(t)?t:e.isNumber(t)?t+"px":i||null},e.option.asElement=function(t,e){return"function"==typeof t&&(t=t()),t||e||null},e.hexToRGB=function(t){var 
e=/^#?([a-f\d])([a-f\d])([a-f\d])$/i;t=t.replace(e,function(t,e,i,o){return e+e+i+i+o+o});var i=/^#?([a-f\d]{2})([a-f\d]{2})([a-f\d]{2})$/i.exec(t);return i?{r:parseInt(i[1],16),g:parseInt(i[2],16),b:parseInt(i[3],16)}:null},e.overrideOpacity=function(t,i){if(-1!=t.indexOf("rgba"))return t;if(-1!=t.indexOf("rgb")){var o=t.substr(t.indexOf("(")+1).replace(")","").split(",");return"rgba("+o[0]+","+o[1]+","+o[2]+","+i+")"}var o=e.hexToRGB(t);return null==o?t:"rgba("+o.r+","+o.g+","+o.b+","+i+")"},e.RGBToHex=function(t,e,i){return"#"+((1<<24)+(t<<16)+(e<<8)+i).toString(16).slice(1)},e.parseColor=function(t){var i;if(e.isString(t)===!0){if(e.isValidRGB(t)===!0){var o=t.substr(4).substr(0,t.length-5).split(",").map(function(t){return parseInt(t)});t=e.RGBToHex(o[0],o[1],o[2])}if(e.isValidHex(t)===!0){var n=e.hexToHSV(t),s={h:n.h,s:.8*n.s,v:Math.min(1,1.02*n.v)},r={h:n.h,s:Math.min(1,1.25*n.s),v:.8*n.v},a=e.HSVToHex(r.h,r.s,r.v),h=e.HSVToHex(s.h,s.s,s.v);i={background:t,border:a,highlight:{background:h,border:a},hover:{background:h,border:a}}}else i={background:t,border:t,highlight:{background:t,border:t},hover:{background:t,border:t}}}else i={},i.background=t.background||void 0,i.border=t.border||void 0,e.isString(t.highlight)?i.highlight={border:t.highlight,background:t.highlight}:(i.highlight={},i.highlight.background=t.highlight&&t.highlight.background||void 0,i.highlight.border=t.highlight&&t.highlight.border||void 0),e.isString(t.hover)?i.hover={border:t.hover,background:t.hover}:(i.hover={},i.hover.background=t.hover&&t.hover.background||void 0,i.hover.border=t.hover&&t.hover.border||void 0);return i},e.RGBToHSV=function(t,e,i){t/=255,e/=255,i/=255;var o=Math.min(t,Math.min(e,i)),n=Math.max(t,Math.max(e,i));if(o==n)return{h:0,s:0,v:o};var s=t==o?e-i:i==o?t-e:i-t,r=t==o?3:i==o?1:5,a=60*(r-s/(n-o))/360,h=(n-o)/n,d=n;return{h:a,s:h,v:d}};var a={split:function(t){var e={};return t.split(";").forEach(function(t){if(""!=t.trim()){var i=t.split(":"),o=i[0].trim(),n=i[1].trim();e[o]=n}}),e},join:function(t){return Object.keys(t).map(function(e){return e+": "+t[e]}).join("; ")}};e.addCssText=function(t,i){var o=a.split(t.style.cssText),n=a.split(i),s=e.extend(o,n);t.style.cssText=a.join(s)},e.removeCssText=function(t,e){var i=a.split(t.style.cssText),o=a.split(e);for(var n in o)o.hasOwnProperty(n)&&delete i[n];t.style.cssText=a.join(i)},e.HSVToRGB=function(t,e,i){var o,n,s,r=Math.floor(6*t),a=6*t-r,h=i*(1-e),d=i*(1-a*e),l=i*(1-(1-a)*e);switch(r%6){case 0:o=i,n=l,s=h;break;case 1:o=d,n=i,s=h;break;case 2:o=h,n=i,s=l;break;case 3:o=h,n=d,s=i;break;case 4:o=l,n=h,s=i;break;case 5:o=i,n=h,s=d}return{r:Math.floor(255*o),g:Math.floor(255*n),b:Math.floor(255*s)}},e.HSVToHex=function(t,i,o){var n=e.HSVToRGB(t,i,o);return e.RGBToHex(n.r,n.g,n.b)},e.hexToHSV=function(t){var i=e.hexToRGB(t);return e.RGBToHSV(i.r,i.g,i.b)},e.isValidHex=function(t){var e=/(^#[0-9A-F]{6}$)|(^#[0-9A-F]{3}$)/i.test(t);return e},e.isValidRGB=function(t){t=t.replace(" ","");var e=/rgb\((\d{1,3}),(\d{1,3}),(\d{1,3})\)/i.test(t);return e},e.isValidRGBA=function(t){t=t.replace(" ","");var e=/rgba\((\d{1,3}),(\d{1,3}),(\d{1,3}),(.{1,3})\)/i.test(t);return e},e.selectiveBridgeObject=function(t,i){if("object"==("undefined"==typeof i?"undefined":o(i))){for(var n=Object.create(i),s=0;s0&&e(o,t[n-1])<0;n--)t[n]=t[n-1];t[n]=o}return t},e.mergeOptions=function(t,e,i){var o=(arguments.length<=3||void 0===arguments[3]?!1:arguments[3],arguments.length<=4||void 0===arguments[4]?{}:arguments[4]);if(null===e[i])t[i]=Object.create(o[i]);else if(void 
0!==e[i])if("boolean"==typeof e[i])t[i].enabled=e[i];else{void 0===e[i].enabled&&(t[i].enabled=!0);for(var n in e[i])e[i].hasOwnProperty(n)&&(t[i][n]=e[i][n])}},e.binarySearchCustom=function(t,e,i,o){for(var n=1e4,s=0,r=0,a=t.length-1;a>=r&&n>s;){var h=Math.floor((r+a)/2),d=t[h],l=void 0===o?d[i]:d[i][o],c=e(l);if(0==c)return h;-1==c?r=h+1:a=h-1,s++}return-1},e.binarySearchValue=function(t,e,i,o,n){for(var s,r,a,h,d=1e4,l=0,c=0,u=t.length-1,n=void 0!=n?n:function(t,e){return t==e?0:e>t?-1:1};u>=c&&d>l;){if(h=Math.floor(.5*(u+c)),s=t[Math.max(0,h-1)][i],r=t[h][i],a=t[Math.min(t.length-1,h+1)][i],0==n(r,e))return h;if(n(s,e)<0&&n(r,e)>0)return"before"==o?Math.max(0,h-1):h;if(n(r,e)<0&&n(a,e)>0)return"before"==o?h:Math.min(t.length-1,h+1);n(r,e)<0?c=h+1:u=h-1,l++}return-1},e.easingFunctions={linear:function(t){return t},easeInQuad:function(t){return t*t},easeOutQuad:function(t){return t*(2-t)},easeInOutQuad:function(t){return.5>t?2*t*t:-1+(4-2*t)*t},easeInCubic:function(t){return t*t*t},easeOutCubic:function(t){return--t*t*t+1},easeInOutCubic:function(t){return.5>t?4*t*t*t:(t-1)*(2*t-2)*(2*t-2)+1},easeInQuart:function(t){return t*t*t*t},easeOutQuart:function(t){return 1- --t*t*t*t},easeInOutQuart:function(t){return.5>t?8*t*t*t*t:1-8*--t*t*t*t},easeInQuint:function(t){return t*t*t*t*t},easeOutQuint:function(t){return 1+--t*t*t*t*t},easeInOutQuint:function(t){return.5>t?16*t*t*t*t*t:1+16*--t*t*t*t*t}}},function(t,e,i){t.exports="undefined"!=typeof window&&window.moment||i(3)},function(t,e,i){(function(t){!function(e,i){t.exports=i()}(this,function(){function e(){return ro.apply(null,arguments)}function i(t){ro=t}function o(t){return t instanceof Array||"[object Array]"===Object.prototype.toString.call(t)}function n(t){return t instanceof Date||"[object Date]"===Object.prototype.toString.call(t)}function s(t,e){var i,o=[];for(i=0;i0)for(i in ho)o=ho[i],n=e[o],p(n)||(t[o]=n);return t}function m(t){f(this,t),this._d=new Date(null!=t._d?t._d.getTime():NaN),lo===!1&&(lo=!0,e.updateOffset(this),lo=!1)}function v(t){return t instanceof m||null!=t&&null!=t._isAMomentObject}function g(t){return 0>t?Math.ceil(t):Math.floor(t)}function y(t){var e=+t,i=0;return 0!==e&&isFinite(e)&&(i=g(e)),i}function b(t,e,i){var o,n=Math.min(t.length,e.length),s=Math.abs(t.length-e.length),r=0;for(o=0;n>o;o++)(i&&t[o]!==e[o]||!i&&y(t[o])!==y(e[o]))&&r++;return r+s}function w(t){e.suppressDeprecationWarnings===!1&&"undefined"!=typeof console&&console.warn&&console.warn("Deprecation warning: "+t)}function _(t,i){var o=!0;return a(function(){return null!=e.deprecationHandler&&e.deprecationHandler(null,t),o&&(w(t+"\nArguments: "+Array.prototype.slice.call(arguments).join(", ")+"\n"+(new Error).stack),o=!1),i.apply(this,arguments)},i)}function x(t,i){null!=e.deprecationHandler&&e.deprecationHandler(t,i),co[t]||(w(i),co[t]=!0)}function k(t){return t instanceof Function||"[object Function]"===Object.prototype.toString.call(t)}function O(t){return"[object Object]"===Object.prototype.toString.call(t)}function M(t){var e,i;for(i in t)e=t[i],k(e)?this[i]=e:this["_"+i]=e;this._config=t,this._ordinalParseLenient=new RegExp(this._ordinalParse.source+"|"+/\d{1,2}/.source)}function D(t,e){var i,o=a({},t);for(i in e)r(e,i)&&(O(t[i])&&O(e[i])?(o[i]={},a(o[i],t[i]),a(o[i],e[i])):null!=e[i]?o[i]=e[i]:delete o[i]);return o}function S(t){null!=t&&this.set(t)}function C(t){return t?t.toLowerCase().replace("_","-"):t}function T(t){for(var e,i,o,n,s=0;s0;){if(o=E(n.slice(0,e).join("-")))return 
o;if(i&&i.length>=e&&b(n,i,!0)>=e-1)break;e--}s++}return null}function E(e){var i=null;if(!mo[e]&&"undefined"!=typeof t&&t&&t.exports)try{i=po._abbr,!function(){var t=new Error('Cannot find module "./locale"');throw t.code="MODULE_NOT_FOUND",t}(),P(i)}catch(o){}return mo[e]}function P(t,e){var i;return t&&(i=p(e)?R(t):I(t,e),i&&(po=i)),po._abbr}function I(t,e){return null!==e?(e.abbr=t,null!=mo[t]?(x("defineLocaleOverride","use moment.updateLocale(localeName, config) to change an existing locale. moment.defineLocale(localeName, config) should only be used for creating a new locale"),e=D(mo[t]._config,e)):null!=e.parentLocale&&(null!=mo[e.parentLocale]?e=D(mo[e.parentLocale]._config,e):x("parentLocaleUndefined","specified parentLocale is not defined yet")),mo[t]=new S(e),P(t),mo[t]):(delete mo[t],null)}function N(t,e){if(null!=e){var i;null!=mo[t]&&(e=D(mo[t]._config,e)),i=new S(e),i.parentLocale=mo[t],mo[t]=i,P(t)}else null!=mo[t]&&(null!=mo[t].parentLocale?mo[t]=mo[t].parentLocale:null!=mo[t]&&delete mo[t]);return mo[t]}function R(t){var e;if(t&&t._locale&&t._locale._abbr&&(t=t._locale._abbr),!t)return po;if(!o(t)){if(e=E(t))return e;t=[t]}return T(t)}function z(){return uo(mo)}function L(t,e){var i=t.toLowerCase();vo[i]=vo[i+"s"]=vo[e]=t}function A(t){return"string"==typeof t?vo[t]||vo[t.toLowerCase()]:void 0}function B(t){var e,i,o={};for(i in t)r(t,i)&&(e=A(i),e&&(o[e]=t[i]));return o}function F(t,i){return function(o){return null!=o?(H(this,t,o),e.updateOffset(this,i),this):j(this,t)}}function j(t,e){return t.isValid()?t._d["get"+(t._isUTC?"UTC":"")+e]():NaN}function H(t,e,i){t.isValid()&&t._d["set"+(t._isUTC?"UTC":"")+e](i)}function W(t,e){var i;if("object"==typeof t)for(i in t)this.set(i,t[i]);else if(t=A(t),k(this[t]))return this[t](e);return this}function Y(t,e,i){var o=""+Math.abs(t),n=e-o.length,s=t>=0;return(s?i?"+":"":"-")+Math.pow(10,Math.max(0,n)).toString().substr(1)+o}function G(t,e,i,o){var n=o;"string"==typeof o&&(n=function(){return this[o]()}),t&&(wo[t]=n),e&&(wo[e[0]]=function(){return Y(n.apply(this,arguments),e[1],e[2])}),i&&(wo[i]=function(){return this.localeData().ordinal(n.apply(this,arguments),t)})}function V(t){return t.match(/\[[\s\S]/)?t.replace(/^\[|\]$/g,""):t.replace(/\\/g,"")}function U(t){var e,i,o=t.match(go);for(e=0,i=o.length;i>e;e++)wo[o[e]]?o[e]=wo[o[e]]:o[e]=V(o[e]);return function(e){var n,s="";for(n=0;i>n;n++)s+=o[n]instanceof Function?o[n].call(e,t):o[n];return s}}function q(t,e){return t.isValid()?(e=X(e,t.localeData()),bo[e]=bo[e]||U(e),bo[e](t)):t.localeData().invalidDate()}function X(t,e){function i(t){return e.longDateFormat(t)||t}var o=5;for(yo.lastIndex=0;o>=0&&yo.test(t);)t=t.replace(yo,i),yo.lastIndex=0,o-=1;return t}function Z(t,e,i){Bo[t]=k(e)?e:function(t,o){return t&&i?i:e}}function K(t,e){return r(Bo,t)?Bo[t](e._strict,e._locale):new RegExp(J(t))}function J(t){return Q(t.replace("\\","").replace(/\\(\[)|\\(\])|\[([^\]\[]*)\]|\\(.)/g,function(t,e,i,o,n){return e||i||o||n}))}function Q(t){return t.replace(/[-\/\\^$*+?.()|[\]{}]/g,"\\$&")}function $(t,e){var i,o=e;for("string"==typeof t&&(t=[t]),"number"==typeof e&&(o=function(t,i){i[e]=y(t)}),i=0;io;++o)s=h([2e3,o]),this._shortMonthsParse[o]=this.monthsShort(s,"").toLocaleLowerCase(),this._longMonthsParse[o]=this.months(s,"").toLocaleLowerCase();return 
i?"MMM"===e?(n=fo.call(this._shortMonthsParse,r),-1!==n?n:null):(n=fo.call(this._longMonthsParse,r),-1!==n?n:null):"MMM"===e?(n=fo.call(this._shortMonthsParse,r),-1!==n?n:(n=fo.call(this._longMonthsParse,r),-1!==n?n:null)):(n=fo.call(this._longMonthsParse,r),-1!==n?n:(n=fo.call(this._shortMonthsParse,r),-1!==n?n:null))}function rt(t,e,i){var o,n,s;if(this._monthsParseExact)return st.call(this,t,e,i);for(this._monthsParse||(this._monthsParse=[],this._longMonthsParse=[],this._shortMonthsParse=[]),o=0;12>o;o++){if(n=h([2e3,o]),i&&!this._longMonthsParse[o]&&(this._longMonthsParse[o]=new RegExp("^"+this.months(n,"").replace(".","")+"$","i"),this._shortMonthsParse[o]=new RegExp("^"+this.monthsShort(n,"").replace(".","")+"$","i")),i||this._monthsParse[o]||(s="^"+this.months(n,"")+"|^"+this.monthsShort(n,""),this._monthsParse[o]=new RegExp(s.replace(".",""),"i")),i&&"MMMM"===e&&this._longMonthsParse[o].test(t))return o;if(i&&"MMM"===e&&this._shortMonthsParse[o].test(t))return o;if(!i&&this._monthsParse[o].test(t))return o}}function at(t,e){var i;if(!t.isValid())return t;if("string"==typeof e)if(/^\d+$/.test(e))e=y(e);else if(e=t.localeData().monthsParse(e),"number"!=typeof e)return t;return i=Math.min(t.date(),it(t.year(),e)),t._d["set"+(t._isUTC?"UTC":"")+"Month"](e,i),t}function ht(t){return null!=t?(at(this,t),e.updateOffset(this,!0),this):j(this,"Month")}function dt(){return it(this.year(),this.month())}function lt(t){return this._monthsParseExact?(r(this,"_monthsRegex")||ut.call(this),t?this._monthsShortStrictRegex:this._monthsShortRegex):this._monthsShortStrictRegex&&t?this._monthsShortStrictRegex:this._monthsShortRegex}function ct(t){return this._monthsParseExact?(r(this,"_monthsRegex")||ut.call(this),t?this._monthsStrictRegex:this._monthsRegex):this._monthsStrictRegex&&t?this._monthsStrictRegex:this._monthsRegex}function ut(){function t(t,e){return e.length-t.length}var e,i,o=[],n=[],s=[];for(e=0;12>e;e++)i=h([2e3,e]),o.push(this.monthsShort(i,"")),n.push(this.months(i,"")),s.push(this.months(i,"")),s.push(this.monthsShort(i,""));for(o.sort(t),n.sort(t),s.sort(t),e=0;12>e;e++)o[e]=Q(o[e]),n[e]=Q(n[e]),s[e]=Q(s[e]);this._monthsRegex=new RegExp("^("+s.join("|")+")","i"),this._monthsShortRegex=this._monthsRegex,this._monthsStrictRegex=new RegExp("^("+n.join("|")+")","i"),this._monthsShortStrictRegex=new RegExp("^("+o.join("|")+")","i")}function pt(t){var e,i=t._a;return i&&-2===l(t).overflow&&(e=i[Ho]<0||i[Ho]>11?Ho:i[Wo]<1||i[Wo]>it(i[jo],i[Ho])?Wo:i[Yo]<0||i[Yo]>24||24===i[Yo]&&(0!==i[Go]||0!==i[Vo]||0!==i[Uo])?Yo:i[Go]<0||i[Go]>59?Go:i[Vo]<0||i[Vo]>59?Vo:i[Uo]<0||i[Uo]>999?Uo:-1,l(t)._overflowDayOfYear&&(jo>e||e>Wo)&&(e=Wo),l(t)._overflowWeeks&&-1===e&&(e=qo),l(t)._overflowWeekday&&-1===e&&(e=Xo),l(t).overflow=e),t}function ft(t){var e,i,o,n,s,r,a=t._i,h=tn.exec(a)||en.exec(a);if(h){for(l(t).iso=!0,e=0,i=nn.length;i>e;e++)if(nn[e][1].exec(h[1])){n=nn[e][0],o=nn[e][2]!==!1;break}if(null==n)return void(t._isValid=!1);if(h[3]){for(e=0,i=sn.length;i>e;e++)if(sn[e][1].exec(h[3])){s=(h[2]||" ")+sn[e][0];break}if(null==s)return void(t._isValid=!1)}if(!o&&null!=s)return void(t._isValid=!1);if(h[4]){if(!on.exec(h[4]))return void(t._isValid=!1);r="Z"}t._f=n+(s||"")+(r||""),Tt(t)}else t._isValid=!1}function mt(t){var i=rn.exec(t._i);return null!==i?void(t._d=new Date(+i[1])):(ft(t),void(t._isValid===!1&&(delete t._isValid,e.createFromInputFallback(t))))}function vt(t,e,i,o,n,s,r){var a=new Date(t,e,i,o,n,s,r);return 100>t&&t>=0&&isFinite(a.getFullYear())&&a.setFullYear(t),a}function gt(t){var e=new 
Date(Date.UTC.apply(null,arguments));return 100>t&&t>=0&&isFinite(e.getUTCFullYear())&&e.setUTCFullYear(t),e}function yt(t){return bt(t)?366:365}function bt(t){return t%4===0&&t%100!==0||t%400===0}function wt(){return bt(this.year())}function _t(t,e,i){var o=7+e-i,n=(7+gt(t,0,o).getUTCDay()-e)%7;return-n+o-1}function xt(t,e,i,o,n){var s,r,a=(7+i-o)%7,h=_t(t,o,n),d=1+7*(e-1)+a+h;return 0>=d?(s=t-1,r=yt(s)+d):d>yt(t)?(s=t+1,r=d-yt(t)):(s=t,r=d),{year:s,dayOfYear:r}}function kt(t,e,i){var o,n,s=_t(t.year(),e,i),r=Math.floor((t.dayOfYear()-s-1)/7)+1;return 1>r?(n=t.year()-1,o=r+Ot(n,e,i)):r>Ot(t.year(),e,i)?(o=r-Ot(t.year(),e,i),n=t.year()+1):(n=t.year(),o=r),{week:o,year:n}}function Ot(t,e,i){var o=_t(t,e,i),n=_t(t+1,e,i);return(yt(t)-o+n)/7}function Mt(t,e,i){return null!=t?t:null!=e?e:i}function Dt(t){var i=new Date(e.now());return t._useUTC?[i.getUTCFullYear(),i.getUTCMonth(),i.getUTCDate()]:[i.getFullYear(),i.getMonth(),i.getDate()]}function St(t){var e,i,o,n,s=[];if(!t._d){for(o=Dt(t),t._w&&null==t._a[Wo]&&null==t._a[Ho]&&Ct(t),t._dayOfYear&&(n=Mt(t._a[jo],o[jo]),t._dayOfYear>yt(n)&&(l(t)._overflowDayOfYear=!0),i=gt(n,0,t._dayOfYear),t._a[Ho]=i.getUTCMonth(),t._a[Wo]=i.getUTCDate()),e=0;3>e&&null==t._a[e];++e)t._a[e]=s[e]=o[e];for(;7>e;e++)t._a[e]=s[e]=null==t._a[e]?2===e?1:0:t._a[e];24===t._a[Yo]&&0===t._a[Go]&&0===t._a[Vo]&&0===t._a[Uo]&&(t._nextDay=!0,t._a[Yo]=0),t._d=(t._useUTC?gt:vt).apply(null,s),null!=t._tzm&&t._d.setUTCMinutes(t._d.getUTCMinutes()-t._tzm),t._nextDay&&(t._a[Yo]=24)}}function Ct(t){var e,i,o,n,s,r,a,h;e=t._w,null!=e.GG||null!=e.W||null!=e.E?(s=1,r=4,i=Mt(e.GG,t._a[jo],kt(At(),1,4).year),o=Mt(e.W,1),n=Mt(e.E,1),(1>n||n>7)&&(h=!0)):(s=t._locale._week.dow,r=t._locale._week.doy,i=Mt(e.gg,t._a[jo],kt(At(),s,r).year),o=Mt(e.w,1),null!=e.d?(n=e.d,(0>n||n>6)&&(h=!0)):null!=e.e?(n=e.e+s,(e.e<0||e.e>6)&&(h=!0)):n=s),1>o||o>Ot(i,s,r)?l(t)._overflowWeeks=!0:null!=h?l(t)._overflowWeekday=!0:(a=xt(i,o,n,s,r),t._a[jo]=a.year,t._dayOfYear=a.dayOfYear)}function Tt(t){if(t._f===e.ISO_8601)return void ft(t);t._a=[],l(t).empty=!0;var i,o,n,s,r,a=""+t._i,h=a.length,d=0;for(n=X(t._f,t._locale).match(go)||[],i=0;i0&&l(t).unusedInput.push(r),a=a.slice(a.indexOf(o)+o.length),d+=o.length),wo[s]?(o?l(t).empty=!1:l(t).unusedTokens.push(s),et(s,o,t)):t._strict&&!o&&l(t).unusedTokens.push(s);l(t).charsLeftOver=h-d,a.length>0&&l(t).unusedInput.push(a),l(t).bigHour===!0&&t._a[Yo]<=12&&t._a[Yo]>0&&(l(t).bigHour=void 0),l(t).parsedDateParts=t._a.slice(0),l(t).meridiem=t._meridiem,t._a[Yo]=Et(t._locale,t._a[Yo],t._meridiem),St(t),pt(t)}function Et(t,e,i){var o;return null==i?e:null!=t.meridiemHour?t.meridiemHour(e,i):null!=t.isPM?(o=t.isPM(i),o&&12>e&&(e+=12),o||12!==e||(e=0),e):e}function Pt(t){var e,i,o,n,s;if(0===t._f.length)return l(t).invalidFormat=!0,void(t._d=new Date(NaN));for(n=0;ns)&&(o=s,i=e));a(t,i||e)}function It(t){if(!t._d){var e=B(t._i);t._a=s([e.year,e.month,e.day||e.date,e.hour,e.minute,e.second,e.millisecond],function(t){return t&&parseInt(t,10)}),St(t)}}function Nt(t){var e=new m(pt(Rt(t)));return e._nextDay&&(e.add(1,"d"),e._nextDay=void 0),e}function Rt(t){var e=t._i,i=t._f;return t._locale=t._locale||R(t._l),null===e||void 0===i&&""===e?u({nullInput:!0}):("string"==typeof e&&(t._i=e=t._locale.preparse(e)),v(e)?new m(pt(e)):(o(i)?Pt(t):i?Tt(t):n(e)?t._d=e:zt(t),c(t)||(t._d=null),t))}function zt(t){var i=t._i;void 0===i?t._d=new Date(e.now()):n(i)?t._d=new Date(i.valueOf()):"string"==typeof i?mt(t):o(i)?(t._a=s(i.slice(0),function(t){return 
parseInt(t,10)}),St(t)):"object"==typeof i?It(t):"number"==typeof i?t._d=new Date(i):e.createFromInputFallback(t)}function Lt(t,e,i,o,n){var s={};return"boolean"==typeof i&&(o=i,i=void 0),s._isAMomentObject=!0,s._useUTC=s._isUTC=n,s._l=i,s._i=t,s._f=e,s._strict=o,Nt(s)}function At(t,e,i,o){return Lt(t,e,i,o,!1)}function Bt(t,e){var i,n;if(1===e.length&&o(e[0])&&(e=e[0]),!e.length)return At();for(i=e[0],n=1;nt&&(t=-t,i="-"),i+Y(~~(t/60),2)+e+Y(~~t%60,2)})}function Gt(t,e){var i=(e||"").match(t)||[],o=i[i.length-1]||[],n=(o+"").match(cn)||["-",0,0],s=+(60*n[1])+y(n[2]);return"+"===n[0]?s:-s}function Vt(t,i){var o,s;return i._isUTC?(o=i.clone(),s=(v(t)||n(t)?t.valueOf():At(t).valueOf())-o.valueOf(),o._d.setTime(o._d.valueOf()+s),e.updateOffset(o,!1),o):At(t).local()}function Ut(t){return 15*-Math.round(t._d.getTimezoneOffset()/15)}function qt(t,i){var o,n=this._offset||0;return this.isValid()?null!=t?("string"==typeof t?t=Gt(zo,t):Math.abs(t)<16&&(t=60*t),!this._isUTC&&i&&(o=Ut(this)),this._offset=t,this._isUTC=!0,null!=o&&this.add(o,"m"),n!==t&&(!i||this._changeInProgress?le(this,ne(t-n,"m"),1,!1):this._changeInProgress||(this._changeInProgress=!0,e.updateOffset(this,!0),this._changeInProgress=null)),this):this._isUTC?n:Ut(this):null!=t?this:NaN}function Xt(t,e){return null!=t?("string"!=typeof t&&(t=-t),this.utcOffset(t,e),this):-this.utcOffset()}function Zt(t){return this.utcOffset(0,t)}function Kt(t){return this._isUTC&&(this.utcOffset(0,t),this._isUTC=!1,t&&this.subtract(Ut(this),"m")),this}function Jt(){return this._tzm?this.utcOffset(this._tzm):"string"==typeof this._i&&this.utcOffset(Gt(Ro,this._i)),this}function Qt(t){return this.isValid()?(t=t?At(t).utcOffset():0,(this.utcOffset()-t)%60===0):!1}function $t(){return this.utcOffset()>this.clone().month(0).utcOffset()||this.utcOffset()>this.clone().month(5).utcOffset()}function te(){if(!p(this._isDSTShifted))return this._isDSTShifted;var t={};if(f(t,this),t=Rt(t),t._a){var e=t._isUTC?h(t._a):At(t._a);this._isDSTShifted=this.isValid()&&b(t._a,e.toArray())>0}else this._isDSTShifted=!1;return this._isDSTShifted}function ee(){return this.isValid()?!this._isUTC:!1}function ie(){return this.isValid()?this._isUTC:!1}function oe(){return this.isValid()?this._isUTC&&0===this._offset:!1}function ne(t,e){var i,o,n,s=t,a=null;return Wt(t)?s={ms:t._milliseconds,d:t._days,M:t._months}:"number"==typeof t?(s={},e?s[e]=t:s.milliseconds=t):(a=un.exec(t))?(i="-"===a[1]?-1:1,s={y:0,d:y(a[Wo])*i,h:y(a[Yo])*i,m:y(a[Go])*i,s:y(a[Vo])*i,ms:y(a[Uo])*i}):(a=pn.exec(t))?(i="-"===a[1]?-1:1,s={y:se(a[2],i),M:se(a[3],i),w:se(a[4],i),d:se(a[5],i),h:se(a[6],i),m:se(a[7],i),s:se(a[8],i)}):null==s?s={}:"object"==typeof s&&("from"in s||"to"in s)&&(n=ae(At(s.from),At(s.to)),s={},s.ms=n.milliseconds,s.M=n.months),o=new Ht(s),Wt(t)&&r(t,"_locale")&&(o._locale=t._locale),o}function se(t,e){var i=t&&parseFloat(t.replace(",","."));return(isNaN(i)?0:i)*e}function re(t,e){var i={milliseconds:0,months:0};return i.months=e.month()-t.month()+12*(e.year()-t.year()),t.clone().add(i.months,"M").isAfter(e)&&--i.months, +i.milliseconds=+e-+t.clone().add(i.months,"M"),i}function ae(t,e){var i;return t.isValid()&&e.isValid()?(e=Vt(e,t),t.isBefore(e)?i=re(t,e):(i=re(e,t),i.milliseconds=-i.milliseconds,i.months=-i.months),i):{milliseconds:0,months:0}}function he(t){return 0>t?-1*Math.round(-1*t):Math.round(t)}function de(t,e){return function(i,o){var n,s;return null===o||isNaN(+o)||(x(e,"moment()."+e+"(period, number) is deprecated. 
Please use moment()."+e+"(number, period)."),s=i,i=o,o=s),i="string"==typeof i?+i:i,n=ne(i,o),le(this,n,t),this}}function le(t,i,o,n){var s=i._milliseconds,r=he(i._days),a=he(i._months);t.isValid()&&(n=null==n?!0:n,s&&t._d.setTime(t._d.valueOf()+s*o),r&&H(t,"Date",j(t,"Date")+r*o),a&&at(t,j(t,"Month")+a*o),n&&e.updateOffset(t,r||a))}function ce(t,e){var i=t||At(),o=Vt(i,this).startOf("day"),n=this.diff(o,"days",!0),s=-6>n?"sameElse":-1>n?"lastWeek":0>n?"lastDay":1>n?"sameDay":2>n?"nextDay":7>n?"nextWeek":"sameElse",r=e&&(k(e[s])?e[s]():e[s]);return this.format(r||this.localeData().calendar(s,this,At(i)))}function ue(){return new m(this)}function pe(t,e){var i=v(t)?t:At(t);return this.isValid()&&i.isValid()?(e=A(p(e)?"millisecond":e),"millisecond"===e?this.valueOf()>i.valueOf():i.valueOf()e-s?(i=t.clone().add(n-1,"months"),o=(e-s)/(s-i)):(i=t.clone().add(n+1,"months"),o=(e-s)/(i-s)),-(n+o)||0}function _e(){return this.clone().locale("en").format("ddd MMM DD YYYY HH:mm:ss [GMT]ZZ")}function xe(){var t=this.clone().utc();return 0s&&(e=s),Xe.call(this,t,e,i,o,n))}function Xe(t,e,i,o,n){var s=xt(t,e,i,o,n),r=gt(s.year,0,s.dayOfYear);return this.year(r.getUTCFullYear()),this.month(r.getUTCMonth()),this.date(r.getUTCDate()),this}function Ze(t){return null==t?Math.ceil((this.month()+1)/3):this.month(3*(t-1)+this.month()%3)}function Ke(t){return kt(t,this._week.dow,this._week.doy).week}function Je(){return this._week.dow}function Qe(){return this._week.doy}function $e(t){var e=this.localeData().week(this);return null==t?e:this.add(7*(t-e),"d")}function ti(t){var e=kt(this,1,4).week;return null==t?e:this.add(7*(t-e),"d")}function ei(t,e){return"string"!=typeof t?t:isNaN(t)?(t=e.weekdaysParse(t),"number"==typeof t?t:null):parseInt(t,10)}function ii(t,e){return o(this._weekdays)?this._weekdays[t.day()]:this._weekdays[this._weekdays.isFormat.test(e)?"format":"standalone"][t.day()]}function oi(t){return this._weekdaysShort[t.day()]}function ni(t){return this._weekdaysMin[t.day()]}function si(t,e,i){var o,n,s,r=t.toLocaleLowerCase();if(!this._weekdaysParse)for(this._weekdaysParse=[],this._shortWeekdaysParse=[],this._minWeekdaysParse=[],o=0;7>o;++o)s=h([2e3,1]).day(o),this._minWeekdaysParse[o]=this.weekdaysMin(s,"").toLocaleLowerCase(),this._shortWeekdaysParse[o]=this.weekdaysShort(s,"").toLocaleLowerCase(),this._weekdaysParse[o]=this.weekdays(s,"").toLocaleLowerCase();return i?"dddd"===e?(n=fo.call(this._weekdaysParse,r),-1!==n?n:null):"ddd"===e?(n=fo.call(this._shortWeekdaysParse,r),-1!==n?n:null):(n=fo.call(this._minWeekdaysParse,r),-1!==n?n:null):"dddd"===e?(n=fo.call(this._weekdaysParse,r),-1!==n?n:(n=fo.call(this._shortWeekdaysParse,r),-1!==n?n:(n=fo.call(this._minWeekdaysParse,r),-1!==n?n:null))):"ddd"===e?(n=fo.call(this._shortWeekdaysParse,r),-1!==n?n:(n=fo.call(this._weekdaysParse,r),-1!==n?n:(n=fo.call(this._minWeekdaysParse,r),-1!==n?n:null))):(n=fo.call(this._minWeekdaysParse,r),-1!==n?n:(n=fo.call(this._weekdaysParse,r),-1!==n?n:(n=fo.call(this._shortWeekdaysParse,r),-1!==n?n:null)))}function ri(t,e,i){var o,n,s;if(this._weekdaysParseExact)return si.call(this,t,e,i);for(this._weekdaysParse||(this._weekdaysParse=[],this._minWeekdaysParse=[],this._shortWeekdaysParse=[],this._fullWeekdaysParse=[]),o=0;7>o;o++){if(n=h([2e3,1]).day(o),i&&!this._fullWeekdaysParse[o]&&(this._fullWeekdaysParse[o]=new RegExp("^"+this.weekdays(n,"").replace(".",".?")+"$","i"),this._shortWeekdaysParse[o]=new RegExp("^"+this.weekdaysShort(n,"").replace(".",".?")+"$","i"),this._minWeekdaysParse[o]=new 
RegExp("^"+this.weekdaysMin(n,"").replace(".",".?")+"$","i")),this._weekdaysParse[o]||(s="^"+this.weekdays(n,"")+"|^"+this.weekdaysShort(n,"")+"|^"+this.weekdaysMin(n,""),this._weekdaysParse[o]=new RegExp(s.replace(".",""),"i")),i&&"dddd"===e&&this._fullWeekdaysParse[o].test(t))return o;if(i&&"ddd"===e&&this._shortWeekdaysParse[o].test(t))return o;if(i&&"dd"===e&&this._minWeekdaysParse[o].test(t))return o;if(!i&&this._weekdaysParse[o].test(t))return o}}function ai(t){if(!this.isValid())return null!=t?this:NaN;var e=this._isUTC?this._d.getUTCDay():this._d.getDay();return null!=t?(t=ei(t,this.localeData()),this.add(t-e,"d")):e}function hi(t){if(!this.isValid())return null!=t?this:NaN;var e=(this.day()+7-this.localeData()._week.dow)%7;return null==t?e:this.add(t-e,"d")}function di(t){return this.isValid()?null==t?this.day()||7:this.day(this.day()%7?t:t-7):null!=t?this:NaN}function li(t){return this._weekdaysParseExact?(r(this,"_weekdaysRegex")||pi.call(this),t?this._weekdaysStrictRegex:this._weekdaysRegex):this._weekdaysStrictRegex&&t?this._weekdaysStrictRegex:this._weekdaysRegex}function ci(t){return this._weekdaysParseExact?(r(this,"_weekdaysRegex")||pi.call(this),t?this._weekdaysShortStrictRegex:this._weekdaysShortRegex):this._weekdaysShortStrictRegex&&t?this._weekdaysShortStrictRegex:this._weekdaysShortRegex}function ui(t){return this._weekdaysParseExact?(r(this,"_weekdaysRegex")||pi.call(this),t?this._weekdaysMinStrictRegex:this._weekdaysMinRegex):this._weekdaysMinStrictRegex&&t?this._weekdaysMinStrictRegex:this._weekdaysMinRegex}function pi(){function t(t,e){return e.length-t.length}var e,i,o,n,s,r=[],a=[],d=[],l=[];for(e=0;7>e;e++)i=h([2e3,1]).day(e),o=this.weekdaysMin(i,""),n=this.weekdaysShort(i,""),s=this.weekdays(i,""),r.push(o),a.push(n),d.push(s),l.push(o),l.push(n),l.push(s);for(r.sort(t),a.sort(t),d.sort(t),l.sort(t),e=0;7>e;e++)a[e]=Q(a[e]),d[e]=Q(d[e]),l[e]=Q(l[e]);this._weekdaysRegex=new RegExp("^("+l.join("|")+")","i"),this._weekdaysShortRegex=this._weekdaysRegex,this._weekdaysMinRegex=this._weekdaysRegex,this._weekdaysStrictRegex=new RegExp("^("+d.join("|")+")","i"),this._weekdaysShortStrictRegex=new RegExp("^("+a.join("|")+")","i"),this._weekdaysMinStrictRegex=new RegExp("^("+r.join("|")+")","i")}function fi(t){var e=Math.round((this.clone().startOf("day")-this.clone().startOf("year"))/864e5)+1;return null==t?e:this.add(t-e,"d")}function mi(){return this.hours()%12||12}function vi(){return this.hours()||24}function gi(t,e){G(t,0,0,function(){return this.localeData().meridiem(this.hours(),this.minutes(),e)})}function yi(t,e){return e._meridiemParse}function bi(t){return"p"===(t+"").toLowerCase().charAt(0)}function wi(t,e,i){return t>11?i?"pm":"PM":i?"am":"AM"}function _i(t,e){e[Uo]=y(1e3*("0."+t))}function xi(){return this._isUTC?"UTC":""}function ki(){return this._isUTC?"Coordinated Universal Time":""}function Oi(t){return At(1e3*t)}function Mi(){return At.apply(null,arguments).parseZone()}function Di(t,e,i){var o=this._calendar[t];return k(o)?o.call(e,i):o}function Si(t){var e=this._longDateFormat[t],i=this._longDateFormat[t.toUpperCase()];return e||!i?e:(this._longDateFormat[t]=i.replace(/MMMM|MM|DD|dddd/g,function(t){return t.slice(1)}),this._longDateFormat[t])}function Ci(){return this._invalidDate}function Ti(t){return this._ordinal.replace("%d",t)}function Ei(t){return t}function Pi(t,e,i,o){var n=this._relativeTime[i];return k(n)?n(t,e,i,o):n.replace(/%d/i,t)}function Ii(t,e){var i=this._relativeTime[t>0?"future":"past"];return k(i)?i(e):i.replace(/%s/i,e)}function 
Ni(t,e,i,o){var n=R(),s=h().set(o,e);return n[i](s,t)}function Ri(t,e,i){if("number"==typeof t&&(e=t,t=void 0),t=t||"",null!=e)return Ni(t,e,i,"month");var o,n=[];for(o=0;12>o;o++)n[o]=Ni(t,o,i,"month");return n}function zi(t,e,i,o){"boolean"==typeof t?("number"==typeof e&&(i=e,e=void 0),e=e||""):(e=t,i=e,t=!1,"number"==typeof e&&(i=e,e=void 0),e=e||"");var n=R(),s=t?n._week.dow:0;if(null!=i)return Ni(e,(i+s)%7,o,"day");var r,a=[];for(r=0;7>r;r++)a[r]=Ni(e,(r+s)%7,o,"day");return a}function Li(t,e){return Ri(t,e,"months")}function Ai(t,e){return Ri(t,e,"monthsShort")}function Bi(t,e,i){return zi(t,e,i,"weekdays")}function Fi(t,e,i){return zi(t,e,i,"weekdaysShort")}function ji(t,e,i){return zi(t,e,i,"weekdaysMin")}function Hi(){var t=this._data;return this._milliseconds=jn(this._milliseconds),this._days=jn(this._days),this._months=jn(this._months),t.milliseconds=jn(t.milliseconds),t.seconds=jn(t.seconds),t.minutes=jn(t.minutes),t.hours=jn(t.hours),t.months=jn(t.months),t.years=jn(t.years),this}function Wi(t,e,i,o){var n=ne(e,i);return t._milliseconds+=o*n._milliseconds,t._days+=o*n._days,t._months+=o*n._months,t._bubble()}function Yi(t,e){return Wi(this,t,e,1)}function Gi(t,e){return Wi(this,t,e,-1)}function Vi(t){return 0>t?Math.floor(t):Math.ceil(t)}function Ui(){var t,e,i,o,n,s=this._milliseconds,r=this._days,a=this._months,h=this._data;return s>=0&&r>=0&&a>=0||0>=s&&0>=r&&0>=a||(s+=864e5*Vi(Xi(a)+r),r=0,a=0),h.milliseconds=s%1e3,t=g(s/1e3),h.seconds=t%60,e=g(t/60),h.minutes=e%60,i=g(e/60),h.hours=i%24,r+=g(i/24),n=g(qi(r)),a+=n,r-=Vi(Xi(n)),o=g(a/12),a%=12,h.days=r,h.months=a,h.years=o,this}function qi(t){return 4800*t/146097}function Xi(t){return 146097*t/4800}function Zi(t){var e,i,o=this._milliseconds;if(t=A(t),"month"===t||"year"===t)return e=this._days+o/864e5,i=this._months+qi(e),"month"===t?i:i/12;switch(e=this._days+Math.round(Xi(this._months)),t){case"week":return e/7+o/6048e5;case"day":return e+o/864e5;case"hour":return 24*e+o/36e5;case"minute":return 1440*e+o/6e4;case"second":return 86400*e+o/1e3;case"millisecond":return Math.floor(864e5*e)+o;default:throw new Error("Unknown unit "+t)}}function Ki(){return this._milliseconds+864e5*this._days+this._months%12*2592e6+31536e6*y(this._months/12)}function Ji(t){return function(){return this.as(t)}}function Qi(t){return t=A(t),this[t+"s"]()}function $i(t){return function(){return this._data[t]}}function to(){return g(this.days()/7)}function eo(t,e,i,o,n){return n.relativeTime(e||1,!!i,t,o)}function io(t,e,i){var o=ne(t).abs(),n=is(o.as("s")),s=is(o.as("m")),r=is(o.as("h")),a=is(o.as("d")),h=is(o.as("M")),d=is(o.as("y")),l=n=s&&["m"]||s=r&&["h"]||r=a&&["d"]||a=h&&["M"]||h=d&&["y"]||["yy",d];return l[2]=e,l[3]=+t>0,l[4]=i,eo.apply(null,l)}function oo(t,e){return void 0===os[t]?!1:void 0===e?os[t]:(os[t]=e,!0)}function no(t){var e=this.localeData(),i=io(this,!t,e);return t&&(i=e.pastFuture(+this,i)),e.postformat(i)}function so(){var t,e,i,o=ns(this._milliseconds)/1e3,n=ns(this._days),s=ns(this._months);t=g(o/60),e=g(t/60),o%=60,t%=60,i=g(s/12),s%=12;var r=i,a=s,h=n,d=e,l=t,c=o,u=this.asSeconds();return u?(0>u?"-":"")+"P"+(r?r+"Y":"")+(a?a+"M":"")+(h?h+"D":"")+(d||l||c?"T":"")+(d?d+"H":"")+(l?l+"M":"")+(c?c+"S":""):"P0D"}var ro,ao;ao=Array.prototype.some?Array.prototype.some:function(t){for(var e=Object(this),i=e.length>>>0,o=0;i>o;o++)if(o in e&&t.call(this,e[o],o,e))return!0;return!1};var ho=e.momentProperties=[],lo=!1,co={};e.suppressDeprecationWarnings=!1,e.deprecationHandler=null;var uo;uo=Object.keys?Object.keys:function(t){var 
e,i=[];for(e in t)r(t,e)&&i.push(e);return i};var po,fo,mo={},vo={},go=/(\[[^\[]*\])|(\\)?([Hh]mm(ss)?|Mo|MM?M?M?|Do|DDDo|DD?D?D?|ddd?d?|do?|w[o|w]?|W[o|W]?|Qo?|YYYYYY|YYYYY|YYYY|YY|gg(ggg?)?|GG(GGG?)?|e|E|a|A|hh?|HH?|kk?|mm?|ss?|S{1,9}|x|X|zz?|ZZ?|.)/g,yo=/(\[[^\[]*\])|(\\)?(LTS|LT|LL?L?L?|l{1,4})/g,bo={},wo={},_o=/\d/,xo=/\d\d/,ko=/\d{3}/,Oo=/\d{4}/,Mo=/[+-]?\d{6}/,Do=/\d\d?/,So=/\d\d\d\d?/,Co=/\d\d\d\d\d\d?/,To=/\d{1,3}/,Eo=/\d{1,4}/,Po=/[+-]?\d{1,6}/,Io=/\d+/,No=/[+-]?\d+/,Ro=/Z|[+-]\d\d:?\d\d/gi,zo=/Z|[+-]\d\d(?::?\d\d)?/gi,Lo=/[+-]?\d+(\.\d{1,3})?/,Ao=/[0-9]*['a-z\u00A0-\u05FF\u0700-\uD7FF\uF900-\uFDCF\uFDF0-\uFFEF]+|[\u0600-\u06FF\/]+(\s*?[\u0600-\u06FF]+){1,2}/i,Bo={},Fo={},jo=0,Ho=1,Wo=2,Yo=3,Go=4,Vo=5,Uo=6,qo=7,Xo=8;fo=Array.prototype.indexOf?Array.prototype.indexOf:function(t){var e;for(e=0;e=t?""+t:"+"+t}),G(0,["YY",2],0,function(){return this.year()%100}),G(0,["YYYY",4],0,"year"),G(0,["YYYYY",5],0,"year"),G(0,["YYYYYY",6,!0],0,"year"),L("year","y"),Z("Y",No),Z("YY",Do,xo),Z("YYYY",Eo,Oo),Z("YYYYY",Po,Mo),Z("YYYYYY",Po,Mo),$(["YYYYY","YYYYYY"],jo),$("YYYY",function(t,i){i[jo]=2===t.length?e.parseTwoDigitYear(t):y(t)}),$("YY",function(t,i){i[jo]=e.parseTwoDigitYear(t)}),$("Y",function(t,e){e[jo]=parseInt(t,10)}),e.parseTwoDigitYear=function(t){return y(t)+(y(t)>68?1900:2e3)};var an=F("FullYear",!0);e.ISO_8601=function(){};var hn=_("moment().min is deprecated, use moment.max instead. https://github.com/moment/moment/issues/1548",function(){var t=At.apply(null,arguments);return this.isValid()&&t.isValid()?this>t?this:t:u()}),dn=_("moment().max is deprecated, use moment.min instead. https://github.com/moment/moment/issues/1548",function(){var t=At.apply(null,arguments);return this.isValid()&&t.isValid()?t>this?this:t:u()}),ln=function(){return Date.now?Date.now():+new Date};Yt("Z",":"),Yt("ZZ",""),Z("Z",zo),Z("ZZ",zo),$(["Z","ZZ"],function(t,e,i){i._useUTC=!0,i._tzm=Gt(zo,t)});var cn=/([\+\-]|\d\d)/gi;e.updateOffset=function(){};var un=/^(\-)?(?:(\d*)[. ])?(\d+)\:(\d+)(?:\:(\d+)\.?(\d{3})?\d*)?$/,pn=/^(-)?P(?:(-?[0-9,.]*)Y)?(?:(-?[0-9,.]*)M)?(?:(-?[0-9,.]*)W)?(?:(-?[0-9,.]*)D)?(?:T(?:(-?[0-9,.]*)H)?(?:(-?[0-9,.]*)M)?(?:(-?[0-9,.]*)S)?)?$/;ne.fn=Ht.prototype;var fn=de(1,"add"),mn=de(-1,"subtract");e.defaultFormat="YYYY-MM-DDTHH:mm:ssZ",e.defaultFormatUtc="YYYY-MM-DDTHH:mm:ss[Z]";var vn=_("moment().lang() is deprecated. Instead, use moment().localeData() to get the language configuration. 
Use moment().locale() to change languages.",function(t){return void 0===t?this.localeData():this.locale(t)});G(0,["gg",2],0,function(){return this.weekYear()%100}),G(0,["GG",2],0,function(){return this.isoWeekYear()%100}),We("gggg","weekYear"),We("ggggg","weekYear"),We("GGGG","isoWeekYear"),We("GGGGG","isoWeekYear"),L("weekYear","gg"),L("isoWeekYear","GG"),Z("G",No),Z("g",No),Z("GG",Do,xo),Z("gg",Do,xo),Z("GGGG",Eo,Oo),Z("gggg",Eo,Oo),Z("GGGGG",Po,Mo),Z("ggggg",Po,Mo),tt(["gggg","ggggg","GGGG","GGGGG"],function(t,e,i,o){e[o.substr(0,2)]=y(t)}),tt(["gg","GG"],function(t,i,o,n){i[n]=e.parseTwoDigitYear(t)}),G("Q",0,"Qo","quarter"),L("quarter","Q"),Z("Q",_o),$("Q",function(t,e){e[Ho]=3*(y(t)-1)}),G("w",["ww",2],"wo","week"),G("W",["WW",2],"Wo","isoWeek"),L("week","w"),L("isoWeek","W"),Z("w",Do),Z("ww",Do,xo),Z("W",Do),Z("WW",Do,xo),tt(["w","ww","W","WW"],function(t,e,i,o){e[o.substr(0,1)]=y(t)});var gn={dow:0,doy:6};G("D",["DD",2],"Do","date"),L("date","D"),Z("D",Do),Z("DD",Do,xo),Z("Do",function(t,e){return t?e._ordinalParse:e._ordinalParseLenient}),$(["D","DD"],Wo),$("Do",function(t,e){e[Wo]=y(t.match(Do)[0],10)});var yn=F("Date",!0);G("d",0,"do","day"),G("dd",0,0,function(t){return this.localeData().weekdaysMin(this,t)}),G("ddd",0,0,function(t){return this.localeData().weekdaysShort(this,t)}),G("dddd",0,0,function(t){return this.localeData().weekdays(this,t)}),G("e",0,0,"weekday"),G("E",0,0,"isoWeekday"),L("day","d"),L("weekday","e"),L("isoWeekday","E"),Z("d",Do),Z("e",Do),Z("E",Do),Z("dd",function(t,e){return e.weekdaysMinRegex(t)}),Z("ddd",function(t,e){return e.weekdaysShortRegex(t)}),Z("dddd",function(t,e){return e.weekdaysRegex(t)}),tt(["dd","ddd","dddd"],function(t,e,i,o){var n=i._locale.weekdaysParse(t,o,i._strict);null!=n?e.d=n:l(i).invalidWeekday=t}),tt(["d","e","E"],function(t,e,i,o){e[o]=y(t)});var bn="Sunday_Monday_Tuesday_Wednesday_Thursday_Friday_Saturday".split("_"),wn="Sun_Mon_Tue_Wed_Thu_Fri_Sat".split("_"),_n="Su_Mo_Tu_We_Th_Fr_Sa".split("_"),xn=Ao,kn=Ao,On=Ao;G("DDD",["DDDD",3],"DDDo","dayOfYear"),L("dayOfYear","DDD"),Z("DDD",To),Z("DDDD",ko),$(["DDD","DDDD"],function(t,e,i){i._dayOfYear=y(t)}),G("H",["HH",2],0,"hour"),G("h",["hh",2],0,mi),G("k",["kk",2],0,vi),G("hmm",0,0,function(){return""+mi.apply(this)+Y(this.minutes(),2)}),G("hmmss",0,0,function(){return""+mi.apply(this)+Y(this.minutes(),2)+Y(this.seconds(),2)}),G("Hmm",0,0,function(){return""+this.hours()+Y(this.minutes(),2)}),G("Hmmss",0,0,function(){return""+this.hours()+Y(this.minutes(),2)+Y(this.seconds(),2)}),gi("a",!0),gi("A",!1),L("hour","h"),Z("a",yi),Z("A",yi),Z("H",Do),Z("h",Do),Z("HH",Do,xo),Z("hh",Do,xo),Z("hmm",So),Z("hmmss",Co),Z("Hmm",So),Z("Hmmss",Co),$(["H","HH"],Yo),$(["a","A"],function(t,e,i){i._isPm=i._locale.isPM(t),i._meridiem=t}),$(["h","hh"],function(t,e,i){e[Yo]=y(t),l(i).bigHour=!0}),$("hmm",function(t,e,i){var o=t.length-2;e[Yo]=y(t.substr(0,o)),e[Go]=y(t.substr(o)),l(i).bigHour=!0}),$("hmmss",function(t,e,i){var o=t.length-4,n=t.length-2;e[Yo]=y(t.substr(0,o)),e[Go]=y(t.substr(o,2)),e[Vo]=y(t.substr(n)),l(i).bigHour=!0}),$("Hmm",function(t,e,i){var o=t.length-2;e[Yo]=y(t.substr(0,o)),e[Go]=y(t.substr(o))}),$("Hmmss",function(t,e,i){var o=t.length-4,n=t.length-2;e[Yo]=y(t.substr(0,o)),e[Go]=y(t.substr(o,2)),e[Vo]=y(t.substr(n))});var Mn=/[ap]\.?m?\.?/i,Dn=F("Hours",!0);G("m",["mm",2],0,"minute"),L("minute","m"),Z("m",Do),Z("mm",Do,xo),$(["m","mm"],Go);var Sn=F("Minutes",!1);G("s",["ss",2],0,"second"),L("second","s"),Z("s",Do),Z("ss",Do,xo),$(["s","ss"],Vo);var 
Cn=F("Seconds",!1);G("S",0,0,function(){return~~(this.millisecond()/100)}),G(0,["SS",2],0,function(){return~~(this.millisecond()/10)}),G(0,["SSS",3],0,"millisecond"),G(0,["SSSS",4],0,function(){return 10*this.millisecond()}),G(0,["SSSSS",5],0,function(){return 100*this.millisecond()}),G(0,["SSSSSS",6],0,function(){return 1e3*this.millisecond()}),G(0,["SSSSSSS",7],0,function(){return 1e4*this.millisecond()}),G(0,["SSSSSSSS",8],0,function(){return 1e5*this.millisecond()}),G(0,["SSSSSSSSS",9],0,function(){return 1e6*this.millisecond()}),L("millisecond","ms"),Z("S",To,_o),Z("SS",To,xo),Z("SSS",To,ko);var Tn;for(Tn="SSSS";Tn.length<=9;Tn+="S")Z(Tn,Io);for(Tn="S";Tn.length<=9;Tn+="S")$(Tn,_i);var En=F("Milliseconds",!1);G("z",0,0,"zoneAbbr"),G("zz",0,0,"zoneName");var Pn=m.prototype;Pn.add=fn,Pn.calendar=ce,Pn.clone=ue,Pn.diff=be,Pn.endOf=Pe,Pn.format=ke,Pn.from=Oe,Pn.fromNow=Me,Pn.to=De,Pn.toNow=Se,Pn.get=W,Pn.invalidAt=je,Pn.isAfter=pe,Pn.isBefore=fe,Pn.isBetween=me,Pn.isSame=ve,Pn.isSameOrAfter=ge,Pn.isSameOrBefore=ye,Pn.isValid=Be,Pn.lang=vn,Pn.locale=Ce,Pn.localeData=Te,Pn.max=dn,Pn.min=hn,Pn.parsingFlags=Fe,Pn.set=W,Pn.startOf=Ee,Pn.subtract=mn,Pn.toArray=ze,Pn.toObject=Le,Pn.toDate=Re,Pn.toISOString=xe,Pn.toJSON=Ae,Pn.toString=_e,Pn.unix=Ne,Pn.valueOf=Ie,Pn.creationData=He,Pn.year=an,Pn.isLeapYear=wt,Pn.weekYear=Ye,Pn.isoWeekYear=Ge,Pn.quarter=Pn.quarters=Ze,Pn.month=ht,Pn.daysInMonth=dt,Pn.week=Pn.weeks=$e,Pn.isoWeek=Pn.isoWeeks=ti,Pn.weeksInYear=Ue,Pn.isoWeeksInYear=Ve,Pn.date=yn,Pn.day=Pn.days=ai,Pn.weekday=hi,Pn.isoWeekday=di,Pn.dayOfYear=fi,Pn.hour=Pn.hours=Dn,Pn.minute=Pn.minutes=Sn,Pn.second=Pn.seconds=Cn,Pn.millisecond=Pn.milliseconds=En,Pn.utcOffset=qt,Pn.utc=Zt,Pn.local=Kt,Pn.parseZone=Jt,Pn.hasAlignedHourOffset=Qt,Pn.isDST=$t,Pn.isDSTShifted=te,Pn.isLocal=ee,Pn.isUtcOffset=ie,Pn.isUtc=oe,Pn.isUTC=oe,Pn.zoneAbbr=xi,Pn.zoneName=ki,Pn.dates=_("dates accessor is deprecated. Use date instead.",yn),Pn.months=_("months accessor is deprecated. Use month instead",ht),Pn.years=_("years accessor is deprecated. Use year instead",an),Pn.zone=_("moment().zone is deprecated, use moment().utcOffset instead. 
https://github.com/moment/moment/issues/1779",Xt);var In=Pn,Nn={sameDay:"[Today at] LT",nextDay:"[Tomorrow at] LT",nextWeek:"dddd [at] LT",lastDay:"[Yesterday at] LT",lastWeek:"[Last] dddd [at] LT",sameElse:"L"},Rn={LTS:"h:mm:ss A",LT:"h:mm A",L:"MM/DD/YYYY",LL:"MMMM D, YYYY",LLL:"MMMM D, YYYY h:mm A",LLLL:"dddd, MMMM D, YYYY h:mm A"},zn="Invalid date",Ln="%d",An=/\d{1,2}/,Bn={future:"in %s",past:"%s ago",s:"a few seconds",m:"a minute",mm:"%d minutes",h:"an hour",hh:"%d hours",d:"a day",dd:"%d days",M:"a month",MM:"%d months",y:"a year",yy:"%d years"},Fn=S.prototype;Fn._calendar=Nn,Fn.calendar=Di,Fn._longDateFormat=Rn,Fn.longDateFormat=Si,Fn._invalidDate=zn,Fn.invalidDate=Ci,Fn._ordinal=Ln,Fn.ordinal=Ti,Fn._ordinalParse=An,Fn.preparse=Ei,Fn.postformat=Ei,Fn._relativeTime=Bn,Fn.relativeTime=Pi,Fn.pastFuture=Ii,Fn.set=M,Fn.months=ot,Fn._months=Ko,Fn.monthsShort=nt,Fn._monthsShort=Jo,Fn.monthsParse=rt,Fn._monthsRegex=$o,Fn.monthsRegex=ct,Fn._monthsShortRegex=Qo,Fn.monthsShortRegex=lt,Fn.week=Ke,Fn._week=gn,Fn.firstDayOfYear=Qe,Fn.firstDayOfWeek=Je,Fn.weekdays=ii,Fn._weekdays=bn,Fn.weekdaysMin=ni,Fn._weekdaysMin=_n,Fn.weekdaysShort=oi,Fn._weekdaysShort=wn,Fn.weekdaysParse=ri,Fn._weekdaysRegex=xn,Fn.weekdaysRegex=li,Fn._weekdaysShortRegex=kn,Fn.weekdaysShortRegex=ci,Fn._weekdaysMinRegex=On,Fn.weekdaysMinRegex=ui,Fn.isPM=bi,Fn._meridiemParse=Mn,Fn.meridiem=wi,P("en",{ordinalParse:/\d{1,2}(th|st|nd|rd)/,ordinal:function(t){var e=t%10,i=1===y(t%100/10)?"th":1===e?"st":2===e?"nd":3===e?"rd":"th";return t+i}}),e.lang=_("moment.lang is deprecated. Use moment.locale instead.",P),e.langData=_("moment.langData is deprecated. Use moment.localeData instead.",R);var jn=Math.abs,Hn=Ji("ms"),Wn=Ji("s"),Yn=Ji("m"),Gn=Ji("h"),Vn=Ji("d"),Un=Ji("w"),qn=Ji("M"),Xn=Ji("y"),Zn=$i("milliseconds"),Kn=$i("seconds"),Jn=$i("minutes"),Qn=$i("hours"),$n=$i("days"),ts=$i("months"),es=$i("years"),is=Math.round,os={s:45,m:45,h:22,d:26,M:11},ns=Math.abs,ss=Ht.prototype;ss.abs=Hi,ss.add=Yi,ss.subtract=Gi,ss.as=Zi,ss.asMilliseconds=Hn,ss.asSeconds=Wn,ss.asMinutes=Yn,ss.asHours=Gn,ss.asDays=Vn,ss.asWeeks=Un,ss.asMonths=qn,ss.asYears=Xn,ss.valueOf=Ki,ss._bubble=Ui,ss.get=Qi,ss.milliseconds=Zn,ss.seconds=Kn,ss.minutes=Jn,ss.hours=Qn,ss.days=$n,ss.weeks=to,ss.months=ts,ss.years=es,ss.humanize=no,ss.toISOString=so,ss.toString=so,ss.toJSON=so,ss.locale=Ce,ss.localeData=Te,ss.toIsoString=_("toIsoString() is deprecated. 
Please use toISOString() instead (notice the capitals)",so),ss.lang=vn,G("X",0,0,"unix"),G("x",0,0,"valueOf"),Z("x",No),Z("X",Lo),$("X",function(t,e,i){i._d=new Date(1e3*parseFloat(t,10))}),$("x",function(t,e,i){i._d=new Date(y(t))}),e.version="2.13.0",i(At),e.fn=In,e.min=Ft,e.max=jt,e.now=ln,e.utc=h,e.unix=Oi,e.months=Li,e.isDate=n,e.locale=P,e.invalid=u,e.duration=ne,e.isMoment=v,e.weekdays=Bi,e.parseZone=Mi,e.localeData=R,e.isDuration=Wt,e.monthsShort=Ai,e.weekdaysMin=ji,e.defineLocale=I,e.updateLocale=N,e.locales=z,e.weekdaysShort=Fi,e.normalizeUnits=A,e.relativeTimeThreshold=oo,e.prototype=In;var rs=e;return rs})}).call(e,i(4)(t))},function(t,e){t.exports=function(t){return t.webpackPolyfill||(t.deprecate=function(){},t.paths=[],t.children=[],t.webpackPolyfill=1),t}},function(t,e){function i(t){throw new Error("Cannot find module '"+t+"'.")}i.keys=function(){return[]},i.resolve=i,t.exports=i,i.id=5},function(t,e){(function(e){function i(t,e,i){var o=e&&i||0,n=0;for(e=e||[],t.toLowerCase().replace(/[0-9a-f]{2}/g,function(t){16>n&&(e[o+n++]=c[t])});16>n;)e[o+n++]=0;return e}function o(t,e){var i=e||0,o=l;return o[t[i++]]+o[t[i++]]+o[t[i++]]+o[t[i++]]+"-"+o[t[i++]]+o[t[i++]]+"-"+o[t[i++]]+o[t[i++]]+"-"+o[t[i++]]+o[t[i++]]+"-"+o[t[i++]]+o[t[i++]]+o[t[i++]]+o[t[i++]]+o[t[i++]]+o[t[i++]]}function n(t,e,i){var n=e&&i||0,s=e||[];t=t||{};var r=void 0!==t.clockseq?t.clockseq:m,a=void 0!==t.msecs?t.msecs:(new Date).getTime(),h=void 0!==t.nsecs?t.nsecs:g+1,d=a-v+(h-g)/1e4;if(0>d&&void 0===t.clockseq&&(r=r+1&16383),(0>d||a>v)&&void 0===t.nsecs&&(h=0),h>=1e4)throw new Error("uuid.v1(): Can't create more than 10M uuids/sec");v=a,g=h,m=r,a+=122192928e5;var l=(1e4*(268435455&a)+h)%4294967296;s[n++]=l>>>24&255,s[n++]=l>>>16&255,s[n++]=l>>>8&255,s[n++]=255&l;var c=a/4294967296*1e4&268435455;s[n++]=c>>>8&255,s[n++]=255&c,s[n++]=c>>>24&15|16,s[n++]=c>>>16&255,s[n++]=r>>>8|128,s[n++]=255&r;for(var u=t.node||f,p=0;6>p;p++)s[n+p]=u[p];return e?e:o(s)}function s(t,e,i){var n=e&&i||0;"string"==typeof t&&(e="binary"==t?new Array(16):null,t=null),t=t||{};var s=t.random||(t.rng||r)();if(s[6]=15&s[6]|64,s[8]=63&s[8]|128,e)for(var a=0;16>a;a++)e[n+a]=s[a];return e||o(s)}var r,a="undefined"!=typeof window?window:"undefined"!=typeof e?e:null;if(a&&a.crypto&&crypto.getRandomValues){var h=new Uint8Array(16);r=function(){return crypto.getRandomValues(h),h}}if(!r){var d=new Array(16);r=function(){for(var t,e=0;16>e;e++)0===(3&e)&&(t=4294967296*Math.random()),d[e]=t>>>((3&e)<<3)&255;return d}}for(var l=[],c={},u=0;256>u;u++)l[u]=(u+256).toString(16).substr(1),c[l[u]]=u;var p=r(),f=[1|p[0],p[1],p[2],p[3],p[4],p[5]],m=16383&(p[6]<<8|p[7]),v=0,g=0,y=s;y.v1=n,y.v4=s,y.parse=i,y.unparse=o,t.exports=y}).call(e,function(){return this}())},function(t,e,i){e.util=i(1),e.DOMutil=i(8),e.DataSet=i(9),e.DataView=i(11),e.Queue=i(10),e.Graph3d=i(12),e.graph3d={Camera:i(16),Filter:i(17),Point2d:i(15),Point3d:i(14),Slider:i(18),StepNumber:i(19)},e.moment=i(2),e.Hammer=i(20),e.keycharm=i(23)},function(t,e){e.prepareElements=function(t){for(var e in t)t.hasOwnProperty(e)&&(t[e].redundant=t[e].used,t[e].used=[])},e.cleanupElements=function(t){for(var e in t)if(t.hasOwnProperty(e)&&t[e].redundant){for(var i=0;i0?(o=e[t].redundant[0],e[t].redundant.shift()):(o=document.createElementNS("http://www.w3.org/2000/svg",t),i.appendChild(o)):(o=document.createElementNS("http://www.w3.org/2000/svg",t),e[t]={used:[],redundant:[]},i.appendChild(o)),e[t].used.push(o),o},e.getDOMElement=function(t,e,i,o){var n;return 
e.hasOwnProperty(t)?e[t].redundant.length>0?(n=e[t].redundant[0],e[t].redundant.shift()):(n=document.createElement(t),void 0!==o?i.insertBefore(n,o):i.appendChild(n)):(n=document.createElement(t),e[t]={used:[],redundant:[]},void 0!==o?i.insertBefore(n,o):i.appendChild(n)),e[t].used.push(n),n},e.drawPoint=function(t,i,o,n,s,r){var a;if("circle"==o.style?(a=e.getSVGElement("circle",n,s),a.setAttributeNS(null,"cx",t),a.setAttributeNS(null,"cy",i),a.setAttributeNS(null,"r",.5*o.size)):(a=e.getSVGElement("rect",n,s),a.setAttributeNS(null,"x",t-.5*o.size),a.setAttributeNS(null,"y",i-.5*o.size),a.setAttributeNS(null,"width",o.size),a.setAttributeNS(null,"height",o.size)),void 0!==o.styles&&a.setAttributeNS(null,"style",o.styles),a.setAttributeNS(null,"class",o.className+" vis-point"),r){var h=e.getSVGElement("text",n,s); +r.xOffset&&(t+=r.xOffset),r.yOffset&&(i+=r.yOffset),r.content&&(h.textContent=r.content),r.className&&h.setAttributeNS(null,"class",r.className+" vis-label"),h.setAttributeNS(null,"x",t),h.setAttributeNS(null,"y",i)}return a},e.drawBar=function(t,i,o,n,s,r,a,h){if(0!=n){0>n&&(n*=-1,i-=n);var d=e.getSVGElement("rect",r,a);d.setAttributeNS(null,"x",t-.5*o),d.setAttributeNS(null,"y",i),d.setAttributeNS(null,"width",o),d.setAttributeNS(null,"height",n),d.setAttributeNS(null,"class",s),h&&d.setAttributeNS(null,"style",h)}}},function(t,e,i){function o(t,e){if(t&&!Array.isArray(t)&&(e=t,t=null),this._options=e||{},this._data={},this.length=0,this._fieldId=this._options.fieldId||"id",this._type={},this._options.type)for(var i=Object.keys(this._options.type),o=0,n=i.length;n>o;o++){var s=i[o],r=this._options.type[s];"Date"==r||"ISODate"==r||"ASPDate"==r?this._type[s]="Date":this._type[s]=r}if(this._options.convert)throw new Error('Option "convert" is deprecated. Use "type" instead.');this._subscribers={},t&&this.add(t),this.setOptions(e)}var n="function"==typeof Symbol&&"symbol"==typeof Symbol.iterator?function(t){return typeof t}:function(t){return t&&"function"==typeof Symbol&&t.constructor===Symbol?"symbol":typeof t},s=i(1),r=i(10);o.prototype.setOptions=function(t){t&&void 0!==t.queue&&(t.queue===!1?this._queue&&(this._queue.destroy(),delete this._queue):(this._queue||(this._queue=r.extend(this,{replace:["add","update","remove"]})),"object"===n(t.queue)&&this._queue.setOptions(t.queue)))},o.prototype.on=function(t,e){var i=this._subscribers[t];i||(i=[],this._subscribers[t]=i),i.push({callback:e})},o.prototype.subscribe=function(){throw new Error("DataSet.subscribe is deprecated. Use DataSet.on instead.")},o.prototype.off=function(t,e){var i=this._subscribers[t];i&&(this._subscribers[t]=i.filter(function(t){return t.callback!=e}))},o.prototype.unsubscribe=function(){throw new Error("DataSet.unsubscribe is deprecated. 
Use DataSet.off instead.")},o.prototype._trigger=function(t,e,i){if("*"==t)throw new Error("Cannot trigger event *");var o=[];t in this._subscribers&&(o=o.concat(this._subscribers[t])),"*"in this._subscribers&&(o=o.concat(this._subscribers["*"]));for(var n=0,s=o.length;s>n;n++){var r=o[n];r.callback&&r.callback(t,e,i||null)}},o.prototype.add=function(t,e){var i,o=[],n=this;if(Array.isArray(t))for(var s=0,r=t.length;r>s;s++)i=n._addItem(t[s]),o.push(i);else{if(!(t instanceof Object))throw new Error("Unknown dataType");i=n._addItem(t),o.push(i)}return o.length&&this._trigger("add",{items:o},e),o},o.prototype.update=function(t,e){var i=[],o=[],n=[],r=[],a=this,h=a._fieldId,d=function(t){var e=t[h];if(a._data[e]){var d=s.extend({},a._data[e]);e=a._updateItem(t),o.push(e),r.push(t),n.push(d)}else e=a._addItem(t),i.push(e)};if(Array.isArray(t))for(var l=0,c=t.length;c>l;l++)t[l]instanceof Object?d(t[l]):console.warn("Ignoring input item, which is not an object at index "+l);else{if(!(t instanceof Object))throw new Error("Unknown dataType");d(t)}if(i.length&&this._trigger("add",{items:i},e),o.length){var u={items:o,oldData:n,data:r};this._trigger("update",u,e)}return i.concat(o)},o.prototype.get=function(t){var e,i,o,n=this,r=s.getType(arguments[0]);"String"==r||"Number"==r?(e=arguments[0],o=arguments[1]):"Array"==r?(i=arguments[0],o=arguments[1]):o=arguments[0];var a;if(o&&o.returnType){var h=["Array","Object"];a=-1==h.indexOf(o.returnType)?"Array":o.returnType}else a="Array";var d,l,c,u,p,f=o&&o.type||this._options.type,m=o&&o.filter,v=[];if(void 0!=e)d=n._getItem(e,f),d&&m&&!m(d)&&(d=null);else if(void 0!=i)for(u=0,p=i.length;p>u;u++)d=n._getItem(i[u],f),m&&!m(d)||v.push(d);else for(l=Object.keys(this._data),u=0,p=l.length;p>u;u++)c=l[u],d=n._getItem(c,f),m&&!m(d)||v.push(d);if(o&&o.order&&void 0==e&&this._sort(v,o.order),o&&o.fields){var g=o.fields;if(void 0!=e)d=this._filterFields(d,g);else for(u=0,p=v.length;p>u;u++)v[u]=this._filterFields(v[u],g)}if("Object"==a){var y,b={};for(u=0,p=v.length;p>u;u++)y=v[u],b[y.id]=y;return b}return void 0!=e?d:v},o.prototype.getIds=function(t){var e,i,o,n,s,r=this._data,a=t&&t.filter,h=t&&t.order,d=t&&t.type||this._options.type,l=Object.keys(r),c=[];if(a)if(h){for(s=[],e=0,i=l.length;i>e;e++)o=l[e],n=this._getItem(o,d),a(n)&&s.push(n);for(this._sort(s,h),e=0,i=s.length;i>e;e++)c.push(s[e][this._fieldId])}else for(e=0,i=l.length;i>e;e++)o=l[e],n=this._getItem(o,d),a(n)&&c.push(n[this._fieldId]);else if(h){for(s=[],e=0,i=l.length;i>e;e++)o=l[e],s.push(r[o]);for(this._sort(s,h),e=0,i=s.length;i>e;e++)c.push(s[e][this._fieldId])}else for(e=0,i=l.length;i>e;e++)o=l[e],n=r[o],c.push(n[this._fieldId]);return c},o.prototype.getDataSet=function(){return this},o.prototype.forEach=function(t,e){var i,o,n,s,r=e&&e.filter,a=e&&e.type||this._options.type,h=this._data,d=Object.keys(h);if(e&&e.order){var l=this.get(e);for(i=0,o=l.length;o>i;i++)n=l[i],s=n[this._fieldId],t(n,s)}else for(i=0,o=d.length;o>i;i++)s=d[i],n=this._getItem(s,a),r&&!r(n)||t(n,s)},o.prototype.map=function(t,e){var i,o,n,s,r=e&&e.filter,a=e&&e.type||this._options.type,h=[],d=this._data,l=Object.keys(d);for(i=0,o=l.length;o>i;i++)n=l[i],s=this._getItem(n,a),r&&!r(s)||h.push(t(s,n));return e&&e.order&&this._sort(h,e.order),h},o.prototype._filterFields=function(t,e){if(!t)return t;var i,o,n={},s=Object.keys(t),r=s.length;if(Array.isArray(e))for(i=0;r>i;i++)o=s[i],-1!=e.indexOf(o)&&(n[o]=t[o]);else for(i=0;r>i;i++)o=s[i],e.hasOwnProperty(o)&&(n[e[o]]=t[o]);return 
n},o.prototype._sort=function(t,e){if(s.isString(e)){var i=e;t.sort(function(t,e){var o=t[i],n=e[i];return o>n?1:n>o?-1:0})}else{if("function"!=typeof e)throw new TypeError("Order must be a function or a string");t.sort(e)}},o.prototype.remove=function(t,e){var i,o,n,s=[];if(Array.isArray(t))for(i=0,o=t.length;o>i;i++)n=this._remove(t[i]),null!=n&&s.push(n);else n=this._remove(t),null!=n&&s.push(n);return s.length&&this._trigger("remove",{items:s},e),s},o.prototype._remove=function(t){if(s.isNumber(t)||s.isString(t)){if(this._data[t])return delete this._data[t],this.length--,t}else if(t instanceof Object){var e=t[this._fieldId];if(void 0!==e&&this._data[e])return delete this._data[e],this.length--,e}return null},o.prototype.clear=function(t){var e=Object.keys(this._data);return this._data={},this.length=0,this._trigger("remove",{items:e},t),e},o.prototype.max=function(t){var e,i,o=this._data,n=Object.keys(o),s=null,r=null;for(e=0,i=n.length;i>e;e++){var a=n[e],h=o[a],d=h[t];null!=d&&(!s||d>r)&&(s=h,r=d)}return s},o.prototype.min=function(t){var e,i,o=this._data,n=Object.keys(o),s=null,r=null;for(e=0,i=n.length;i>e;e++){var a=n[e],h=o[a],d=h[t];null!=d&&(!s||r>d)&&(s=h,r=d)}return s},o.prototype.distinct=function(t){var e,i,o,n=this._data,r=Object.keys(n),a=[],h=this._options.type&&this._options.type[t]||null,d=0;for(e=0,o=r.length;o>e;e++){var l=r[e],c=n[l],u=c[t],p=!1;for(i=0;d>i;i++)if(a[i]==u){p=!0;break}p||void 0===u||(a[d]=u,d++)}if(h)for(e=0,o=a.length;o>e;e++)a[e]=s.convert(a[e],h);return a},o.prototype._addItem=function(t){var e=t[this._fieldId];if(void 0!=e){if(this._data[e])throw new Error("Cannot add item: item with id "+e+" already exists")}else e=s.randomUUID(),t[this._fieldId]=e;var i,o,n={},r=Object.keys(t);for(i=0,o=r.length;o>i;i++){var a=r[i],h=this._type[a];n[a]=s.convert(t[a],h)}return this._data[e]=n,this.length++,e},o.prototype._getItem=function(t,e){var i,o,n,r,a=this._data[t];if(!a)return null;var h={},d=Object.keys(a);if(e)for(n=0,r=d.length;r>n;n++)i=d[n],o=a[i],h[i]=s.convert(o,e[i]);else for(n=0,r=d.length;r>n;n++)i=d[n],o=a[i],h[i]=o;return h},o.prototype._updateItem=function(t){var e=t[this._fieldId];if(void 0==e)throw new Error("Cannot update item: item has no id (item: "+JSON.stringify(t)+")");var i=this._data[e];if(!i)throw new Error("Cannot update item: no item with id "+e+" found");for(var o=Object.keys(t),n=0,r=o.length;r>n;n++){var a=o[n],h=this._type[a];i[a]=s.convert(t[a],h)}return e},t.exports=o},function(t,e){function i(t){this.delay=null,this.max=1/0,this._queue=[],this._timeout=null,this._extended=null,this.setOptions(t)}i.prototype.setOptions=function(t){t&&"undefined"!=typeof t.delay&&(this.delay=t.delay),t&&"undefined"!=typeof t.max&&(this.max=t.max),this._flushIfNeeded()},i.extend=function(t,e){var o=new i(e);if(void 0!==t.flush)throw new Error("Target object already has a property flush");t.flush=function(){o.flush()};var n=[{name:"flush",original:void 0}];if(e&&e.replace)for(var s=0;sthis.max&&this.flush(),clearTimeout(this._timeout),this.queue.length>0&&"number"==typeof this.delay){var t=this;this._timeout=setTimeout(function(){t.flush()},this.delay)}},i.prototype.flush=function(){for(;this._queue.length>0;){var t=this._queue.shift();t.fn.apply(t.context||t.fn,t.args||[])}},t.exports=i},function(t,e,i){function o(t,e){this._data=null,this._ids={},this.length=0,this._options=e||{},this._fieldId="id",this._subscribers={};var i=this;this.listener=function(){i._onEvent.apply(i,arguments)},this.setData(t)}var 
n=i(1),s=i(9);o.prototype.setData=function(t){var e,i,o,n;if(this._data&&(this._data.off&&this._data.off("*",this.listener),e=Object.keys(this._ids),this._ids={},this.length=0,this._trigger("remove",{items:e})),this._data=t,this._data){for(this._fieldId=this._options.fieldId||this._data&&this._data.options&&this._data.options.fieldId||"id",e=this._data.getIds({filter:this._options&&this._options.filter}),o=0,n=e.length;n>o;o++)i=e[o],this._ids[i]=!0;this.length=e.length,this._trigger("add",{items:e}),this._data.on&&this._data.on("*",this.listener)}},o.prototype.refresh=function(){var t,e,i,o=this._data.getIds({filter:this._options&&this._options.filter}),n=Object.keys(this._ids),s={},r=[],a=[];for(e=0,i=o.length;i>e;e++)t=o[e],s[t]=!0,this._ids[t]||(r.push(t),this._ids[t]=!0);for(e=0,i=n.length;i>e;e++)t=n[e],s[t]||(a.push(t),delete this._ids[t]);this.length+=r.length-a.length,r.length&&this._trigger("add",{items:r}),a.length&&this._trigger("remove",{items:a})},o.prototype.get=function(t){var e,i,o,s=this,r=n.getType(arguments[0]);"String"==r||"Number"==r||"Array"==r?(e=arguments[0],i=arguments[1],o=arguments[2]):(i=arguments[0],o=arguments[1]);var a=n.extend({},this._options,i);this._options.filter&&i&&i.filter&&(a.filter=function(t){return s._options.filter(t)&&i.filter(t)});var h=[];return void 0!=e&&h.push(e),h.push(a),h.push(o),this._data&&this._data.get.apply(this._data,h)},o.prototype.getIds=function(t){var e;if(this._data){var i,o=this._options.filter;i=t&&t.filter?o?function(e){return o(e)&&t.filter(e)}:t.filter:o,e=this._data.getIds({filter:i,order:t&&t.order})}else e=[];return e},o.prototype.map=function(t,e){var i=[];if(this._data){var o,n=this._options.filter;o=e&&e.filter?n?function(t){return n(t)&&e.filter(t)}:e.filter:n,i=this._data.map(t,{filter:o,order:e&&e.order})}else i=[];return i},o.prototype.getDataSet=function(){for(var t=this;t instanceof o;)t=t._data;return t||null},o.prototype._onEvent=function(t,e,i){var o,n,s,r,a=e&&e.items,h=this._data,d=[],l=[],c=[],u=[];if(a&&h){switch(t){case"add":for(o=0,n=a.length;n>o;o++)s=a[o],r=this.get(s),r&&(this._ids[s]=!0,l.push(s));break;case"update":for(o=0,n=a.length;n>o;o++)s=a[o],r=this.get(s),r?this._ids[s]?(c.push(s),d.push(e.data[o])):(this._ids[s]=!0,l.push(s)):this._ids[s]&&(delete this._ids[s],u.push(s));break;case"remove":for(o=0,n=a.length;n>o;o++)s=a[o],this._ids[s]&&(delete this._ids[s],u.push(s))}this.length+=l.length-u.length,l.length&&this._trigger("add",{items:l},i),c.length&&this._trigger("update",{items:c,data:d},i),u.length&&this._trigger("remove",{items:u},i)}},o.prototype.on=s.prototype.on,o.prototype.off=s.prototype.off,o.prototype._trigger=s.prototype._trigger,o.prototype.subscribe=o.prototype.on,o.prototype.unsubscribe=o.prototype.off,t.exports=o},function(t,e,i){function o(t,e,i){if(!(this instanceof o))throw new SyntaxError("Constructor must be called with the new operator");this.containerElement=t,this.width="400px",this.height="400px",this.margin=10,this.defaultXCenter="55%",this.defaultYCenter="50%",this.xLabel="x",this.yLabel="y",this.zLabel="z";var n=function(t){return t};this.xValueLabel=n,this.yValueLabel=n,this.zValueLabel=n,this.filterLabel="time",this.legendLabel="value",this.style=o.STYLE.DOT,this.showPerspective=!0,this.showGrid=!0,this.keepAspectRatio=!0,this.showShadow=!1,this.showGrayBottom=!1,this.showTooltip=!1,this.verticalRatio=.5,this.animationInterval=1e3,this.animationPreload=!1,this.camera=new p,this.camera.setArmRotation(1,.5),this.camera.setArmLength(1.7),this.eye=new 
c(0,0,-1),this.dataTable=null,this.dataPoints=null,this.colX=void 0,this.colY=void 0,this.colZ=void 0,this.colValue=void 0,this.colFilter=void 0,this.xMin=0,this.xStep=void 0,this.xMax=1,this.yMin=0,this.yStep=void 0,this.yMax=1,this.zMin=0,this.zStep=void 0,this.zMax=1,this.valueMin=0,this.valueMax=1,this.xBarWidth=1,this.yBarWidth=1,this.axisColor="#4D4D4D",this.gridColor="#D3D3D3",this.dataColor={fill:"#7DC1FF",stroke:"#3267D2",strokeWidth:1},this.dotSizeRatio=.02,this.create(),this.setOptions(i),e&&this.setData(e)}function n(t){return"clientX"in t?t.clientX:t.targetTouches[0]&&t.targetTouches[0].clientX||0}function s(t){return"clientY"in t?t.clientY:t.targetTouches[0]&&t.targetTouches[0].clientY||0}var r="function"==typeof Symbol&&"symbol"==typeof Symbol.iterator?function(t){return typeof t}:function(t){return t&&"function"==typeof Symbol&&t.constructor===Symbol?"symbol":typeof t},a=i(13),h=i(9),d=i(11),l=i(1),c=i(14),u=i(15),p=i(16),f=i(17),m=i(18),v=i(19);a(o.prototype),o.prototype._setScale=function(){this.scale=new c(1/(this.xMax-this.xMin),1/(this.yMax-this.yMin),1/(this.zMax-this.zMin)),this.keepAspectRatio&&(this.scale.x3&&(this.colFilter=3);else{if(this.style!==o.STYLE.DOTCOLOR&&this.style!==o.STYLE.DOTSIZE&&this.style!==o.STYLE.BARCOLOR&&this.style!==o.STYLE.BARSIZE)throw'Unknown style "'+this.style+'"';this.colX=0,this.colY=1,this.colZ=2,this.colValue=3,t.getNumberOfColumns()>4&&(this.colFilter=4)}},o.prototype.getNumberOfRows=function(t){return t.length},o.prototype.getNumberOfColumns=function(t){var e=0;for(var i in t[0])t[0].hasOwnProperty(i)&&e++;return e},o.prototype.getDistinctValues=function(t,e){for(var i=[],o=0;ot[o][e]&&(i.min=t[o][e]),i.maxt;t++){var f=(t-u)/(p-u),m=240*f,g=this._hsv2rgb(m,1,1);c.strokeStyle=g,c.beginPath(),c.moveTo(h,r+t),c.lineTo(a,r+t),c.stroke()}c.strokeStyle=this.axisColor,c.strokeRect(h,r,i,s)}if(this.style===o.STYLE.DOTSIZE&&(c.strokeStyle=this.axisColor,c.fillStyle=this.dataColor.fill,c.beginPath(),c.moveTo(h,r),c.lineTo(a,r),c.lineTo(a-i+e,d),c.lineTo(h,d),c.closePath(),c.fill(),c.stroke()),this.style===o.STYLE.DOTCOLOR||this.style===o.STYLE.DOTSIZE){var y=5,b=new v(this.valueMin,this.valueMax,(this.valueMax-this.valueMin)/5,!0);for(b.start(),b.getCurrent()0?this.yMin:this.yMax,n=this._convert3Dto2D(new c(_,r,this.zMin)),Math.cos(2*w)>0?(m.textAlign="center",m.textBaseline="top",n.y+=b):Math.sin(2*w)<0?(m.textAlign="right",m.textBaseline="middle"):(m.textAlign="left",m.textBaseline="middle"),m.fillStyle=this.axisColor,m.fillText(" "+this.xValueLabel(i.getCurrent())+" ",n.x,n.y),i.next()}for(m.lineWidth=1,o=void 0===this.defaultYStep,i=new v(this.yMin,this.yMax,this.yStep,o),i.start(),i.getCurrent()0?this.xMin:this.xMax,n=this._convert3Dto2D(new c(s,i.getCurrent(),this.zMin)),Math.cos(2*w)<0?(m.textAlign="center",m.textBaseline="top",n.y+=b):Math.sin(2*w)>0?(m.textAlign="right",m.textBaseline="middle"):(m.textAlign="left",m.textBaseline="middle"),m.fillStyle=this.axisColor,m.fillText(" "+this.yValueLabel(i.getCurrent())+" ",n.x,n.y),i.next();for(m.lineWidth=1,o=void 0===this.defaultZStep,i=new v(this.zMin,this.zMax,this.zStep,o),i.start(),i.getCurrent()0?this.xMin:this.xMax,r=Math.sin(w)<0?this.yMin:this.yMax;!i.end();)t=this._convert3Dto2D(new c(s,r,i.getCurrent())),m.strokeStyle=this.axisColor,m.beginPath(),m.moveTo(t.x,t.y),m.lineTo(t.x-b,t.y),m.stroke(),m.textAlign="right",m.textBaseline="middle",m.fillStyle=this.axisColor,m.fillText(this.zValueLabel(i.getCurrent())+" ",t.x-5,t.y),i.next();m.lineWidth=1,t=this._convert3Dto2D(new 
c(s,r,this.zMin)),e=this._convert3Dto2D(new c(s,r,this.zMax)),m.strokeStyle=this.axisColor,m.beginPath(),m.moveTo(t.x,t.y),m.lineTo(e.x,e.y),m.stroke(),m.lineWidth=1,u=this._convert3Dto2D(new c(this.xMin,this.yMin,this.zMin)),p=this._convert3Dto2D(new c(this.xMax,this.yMin,this.zMin)),m.strokeStyle=this.axisColor,m.beginPath(),m.moveTo(u.x,u.y),m.lineTo(p.x,p.y),m.stroke(),u=this._convert3Dto2D(new c(this.xMin,this.yMax,this.zMin)),p=this._convert3Dto2D(new c(this.xMax,this.yMax,this.zMin)),m.strokeStyle=this.axisColor,m.beginPath(),m.moveTo(u.x,u.y),m.lineTo(p.x,p.y),m.stroke(),m.lineWidth=1,t=this._convert3Dto2D(new c(this.xMin,this.yMin,this.zMin)),e=this._convert3Dto2D(new c(this.xMin,this.yMax,this.zMin)),m.strokeStyle=this.axisColor,m.beginPath(),m.moveTo(t.x,t.y),m.lineTo(e.x,e.y),m.stroke(),t=this._convert3Dto2D(new c(this.xMax,this.yMin,this.zMin)),e=this._convert3Dto2D(new c(this.xMax,this.yMax,this.zMin)),m.strokeStyle=this.axisColor,m.beginPath(),m.moveTo(t.x,t.y),m.lineTo(e.x,e.y),m.stroke();var x=this.xLabel;x.length>0&&(l=.1/this.scale.y,s=(this.xMin+this.xMax)/2,r=Math.cos(w)>0?this.yMin-l:this.yMax+l,n=this._convert3Dto2D(new c(s,r,this.zMin)),Math.cos(2*w)>0?(m.textAlign="center",m.textBaseline="top"):Math.sin(2*w)<0?(m.textAlign="right",m.textBaseline="middle"):(m.textAlign="left",m.textBaseline="middle"),m.fillStyle=this.axisColor,m.fillText(x,n.x,n.y));var k=this.yLabel;k.length>0&&(d=.1/this.scale.x,s=Math.sin(w)>0?this.xMin-d:this.xMax+d,r=(this.yMin+this.yMax)/2,n=this._convert3Dto2D(new c(s,r,this.zMin)),Math.cos(2*w)<0?(m.textAlign="center",m.textBaseline="top"):Math.sin(2*w)>0?(m.textAlign="right",m.textBaseline="middle"):(m.textAlign="left",m.textBaseline="middle"),m.fillStyle=this.axisColor,m.fillText(k,n.x,n.y));var O=this.zLabel;O.length>0&&(h=30,s=Math.cos(w)>0?this.xMin:this.xMax,r=Math.sin(w)<0?this.yMin:this.yMax,a=(this.zMin+this.zMax)/2,n=this._convert3Dto2D(new c(s,r,a)),m.textAlign="right",m.textBaseline="middle",m.fillStyle=this.axisColor,m.fillText(O,n.x-h,n.y))},o.prototype._hsv2rgb=function(t,e,i){var o,n,s,r,a,h;switch(r=i*e,a=Math.floor(t/60),h=r*(1-Math.abs(t/60%2-1)),a){case 0:o=r,n=h,s=0;break;case 1:o=h,n=r,s=0;break;case 2:o=0,n=r,s=h;break;case 3:o=0,n=h,s=r;break;case 4:o=h,n=0,s=r;break;case 5:o=r,n=0,s=h;break;default:o=0,n=0,s=0}return"RGB("+parseInt(255*o)+","+parseInt(255*n)+","+parseInt(255*s)+")"},o.prototype._redrawDataGrid=function(){var t,e,i,n,s,r,a,h,d,l,u,p,f=this.frame.canvas,m=f.getContext("2d");if(m.lineJoin="round",m.lineCap="round",!(void 0===this.dataPoints||this.dataPoints.length<=0)){for(s=0;s0}else r=!0;r?(p=(t.point.z+e.point.z+i.point.z+n.point.z)/4,d=240*(1-(p-this.zMin)*this.scale.z/this.verticalRatio),l=1,this.showShadow?(u=Math.min(1+x.x/k/2,1),a=this._hsv2rgb(d,l,u),h=a):(u=1,a=this._hsv2rgb(d,l,u),h=this.axisColor)):(a="gray",h=this.axisColor),m.lineWidth=this._getStrokeWidth(t),m.fillStyle=a,m.strokeStyle=h,m.beginPath(),m.moveTo(t.screen.x,t.screen.y),m.lineTo(e.screen.x,e.screen.y),m.lineTo(n.screen.x,n.screen.y),m.lineTo(i.screen.x,i.screen.y),m.closePath(),m.fill(),m.stroke()}}else for(s=0;su&&(u=0);var 
p,f,m;this.style===o.STYLE.DOTCOLOR?(p=240*(1-(d.point.value-this.valueMin)*this.scale.value),f=this._hsv2rgb(p,1,1),m=this._hsv2rgb(p,1,.8)):this.style===o.STYLE.DOTSIZE?(f=this.dataColor.fill,m=this.dataColor.stroke):(p=240*(1-(d.point.z-this.zMin)*this.scale.z/this.verticalRatio),f=this._hsv2rgb(p,1,1),m=this._hsv2rgb(p,1,.8)),i.lineWidth=this._getStrokeWidth(d),i.strokeStyle=m,i.fillStyle=f,i.beginPath(),i.arc(d.screen.x,d.screen.y,u,0,2*Math.PI,!0),i.fill(),i.stroke()}}},o.prototype._redrawDataBar=function(){var t,e,i,n,s=this.frame.canvas,r=s.getContext("2d");if(!(void 0===this.dataPoints||this.dataPoints.length<=0)){for(t=0;t0){for(t=this.dataPoints[0],o.lineWidth=this._getStrokeWidth(t),o.lineJoin="round",o.lineCap="round",o.strokeStyle=this.dataColor.stroke,o.beginPath(),o.moveTo(t.screen.x,t.screen.y),e=1;e0?1:0>t?-1:0}var o=e[0],n=e[1],s=e[2],r=i((n.x-o.x)*(t.y-o.y)-(n.y-o.y)*(t.x-o.x)),a=i((s.x-n.x)*(t.y-n.y)-(s.y-n.y)*(t.x-n.x)),h=i((o.x-s.x)*(t.y-s.y)-(o.y-s.y)*(t.x-s.x));return!(0!=r&&0!=a&&r!=a||0!=a&&0!=h&&a!=h||0!=r&&0!=h&&r!=h)},o.prototype._dataPointFromXY=function(t,e){var i,n=100,s=null,r=null,a=null,h=new u(t,e);if(this.style===o.STYLE.BAR||this.style===o.STYLE.BARCOLOR||this.style===o.STYLE.BARSIZE)for(i=this.dataPoints.length-1;i>=0;i--){s=this.dataPoints[i];var d=s.surfaces;if(d)for(var l=d.length-1;l>=0;l--){var c=d[l],p=c.corners,f=[p[0].screen,p[1].screen,p[2].screen],m=[p[2].screen,p[3].screen,p[0].screen];if(this._insideTriangle(h,f)||this._insideTriangle(h,m))return s}}else for(i=0;ib)&&n>b&&(a=b,r=s)}}return r},o.prototype._showTooltip=function(t){var e,i,o;this.tooltip?(e=this.tooltip.dom.content,i=this.tooltip.dom.line,o=this.tooltip.dom.dot):(e=document.createElement("div"),e.style.position="absolute",e.style.padding="10px",e.style.border="1px solid #4d4d4d",e.style.color="#1a1a1a",e.style.background="rgba(255,255,255,0.7)",e.style.borderRadius="2px",e.style.boxShadow="5px 5px 10px rgba(128,128,128,0.5)",i=document.createElement("div"),i.style.position="absolute",i.style.height="40px",i.style.width="0",i.style.borderLeft="1px solid #4d4d4d",o=document.createElement("div"),o.style.position="absolute",o.style.height="0",o.style.width="0",o.style.border="5px solid #4d4d4d",o.style.borderRadius="5px",this.tooltip={dataPoint:null,dom:{content:e,line:i,dot:o}}),this._hideTooltip(),this.tooltip.dataPoint=t,"function"==typeof this.showTooltip?e.innerHTML=this.showTooltip(t.point):e.innerHTML="
"+this.xLabel+":"+t.point.x+"
"+this.yLabel+":"+t.point.y+"
"+this.zLabel+":"+t.point.z+"
",e.style.left="0",e.style.top="0",this.frame.appendChild(e),this.frame.appendChild(i),this.frame.appendChild(o);var n=e.offsetWidth,s=e.offsetHeight,r=i.offsetHeight,a=o.offsetWidth,h=o.offsetHeight,d=t.screen.x-n/2;d=Math.min(Math.max(d,10),this.frame.clientWidth-10-n),i.style.left=t.screen.x+"px",i.style.top=t.screen.y-r+"px",e.style.left=d+"px",e.style.top=t.screen.y-r-s+"px",o.style.left=t.screen.x-a/2+"px",o.style.top=t.screen.y-h/2+"px"},o.prototype._hideTooltip=function(){if(this.tooltip){this.tooltip.dataPoint=null;for(var t in this.tooltip.dom)if(this.tooltip.dom.hasOwnProperty(t)){var e=this.tooltip.dom[t];e&&e.parentNode&&e.parentNode.removeChild(e)}}},t.exports=o},function(t,e){function i(t){return t?o(t):void 0}function o(t){for(var e in i.prototype)t[e]=i.prototype[e];return t}t.exports=i,i.prototype.on=i.prototype.addEventListener=function(t,e){return this._callbacks=this._callbacks||{},(this._callbacks[t]=this._callbacks[t]||[]).push(e),this},i.prototype.once=function(t,e){function i(){o.off(t,i),e.apply(this,arguments)}var o=this;return this._callbacks=this._callbacks||{},i.fn=e,this.on(t,i),this},i.prototype.off=i.prototype.removeListener=i.prototype.removeAllListeners=i.prototype.removeEventListener=function(t,e){if(this._callbacks=this._callbacks||{},0==arguments.length)return this._callbacks={},this;var i=this._callbacks[t];if(!i)return this;if(1==arguments.length)return delete this._callbacks[t],this;for(var o,n=0;no;++o)i[o].apply(this,e)}return this},i.prototype.listeners=function(t){return this._callbacks=this._callbacks||{},this._callbacks[t]||[]},i.prototype.hasListeners=function(t){return!!this.listeners(t).length}},function(t,e){function i(t,e,i){this.x=void 0!==t?t:0,this.y=void 0!==e?e:0,this.z=void 0!==i?i:0}i.subtract=function(t,e){var o=new i;return o.x=t.x-e.x,o.y=t.y-e.y,o.z=t.z-e.z,o},i.add=function(t,e){var o=new i;return o.x=t.x+e.x,o.y=t.y+e.y,o.z=t.z+e.z,o},i.avg=function(t,e){return new i((t.x+e.x)/2,(t.y+e.y)/2,(t.z+e.z)/2)},i.crossProduct=function(t,e){var o=new i;return o.x=t.y*e.z-t.z*e.y,o.y=t.z*e.x-t.x*e.z,o.z=t.x*e.y-t.y*e.x,o},i.prototype.length=function(){return Math.sqrt(this.x*this.x+this.y*this.y+this.z*this.z)},t.exports=i},function(t,e){function i(t,e){this.x=void 0!==t?t:0,this.y=void 0!==e?e:0}t.exports=i},function(t,e,i){function o(){this.armLocation=new n,this.armRotation={},this.armRotation.horizontal=0,this.armRotation.vertical=0,this.armLength=1.7,this.cameraLocation=new n,this.cameraRotation=new n(.5*Math.PI,0,0),this.calculateCameraOrientation()}var n=i(14);o.prototype.setArmLocation=function(t,e,i){this.armLocation.x=t,this.armLocation.y=e,this.armLocation.z=i,this.calculateCameraOrientation()},o.prototype.setArmRotation=function(t,e){void 0!==t&&(this.armRotation.horizontal=t),void 0!==e&&(this.armRotation.vertical=e,this.armRotation.vertical<0&&(this.armRotation.vertical=0),this.armRotation.vertical>.5*Math.PI&&(this.armRotation.vertical=.5*Math.PI)),void 0===t&&void 0===e||this.calculateCameraOrientation()},o.prototype.getArmRotation=function(){var t={};return t.horizontal=this.armRotation.horizontal,t.vertical=this.armRotation.vertical,t},o.prototype.setArmLength=function(t){void 0!==t&&(this.armLength=t,this.armLength<.71&&(this.armLength=.71),this.armLength>5&&(this.armLength=5),this.calculateCameraOrientation())},o.prototype.getArmLength=function(){return this.armLength},o.prototype.getCameraLocation=function(){return this.cameraLocation},o.prototype.getCameraRotation=function(){return 
this.cameraRotation},o.prototype.calculateCameraOrientation=function(){this.cameraLocation.x=this.armLocation.x-this.armLength*Math.sin(this.armRotation.horizontal)*Math.cos(this.armRotation.vertical),this.cameraLocation.y=this.armLocation.y-this.armLength*Math.cos(this.armRotation.horizontal)*Math.cos(this.armRotation.vertical),this.cameraLocation.z=this.armLocation.z+this.armLength*Math.sin(this.armRotation.vertical),this.cameraRotation.x=Math.PI/2-this.armRotation.vertical,this.cameraRotation.y=0,this.cameraRotation.z=-this.armRotation.horizontal},t.exports=o},function(t,e,i){function o(t,e,i){this.data=t,this.column=e,this.graph=i,this.index=void 0,this.value=void 0,this.values=i.getDistinctValues(t.get(),this.column),this.values.sort(function(t,e){return t>e?1:e>t?-1:0}),this.values.length>0&&this.selectValue(0),this.dataPoints=[],this.loaded=!1,this.onLoadCallback=void 0,i.animationPreload?(this.loaded=!1,this.loadInBackground()):this.loaded=!0}var n=i(11);o.prototype.isLoaded=function(){return this.loaded},o.prototype.getLoadedProgress=function(){for(var t=this.values.length,e=0;this.dataPoints[e];)e++;return Math.round(e/t*100)},o.prototype.getLabel=function(){return this.graph.filterLabel},o.prototype.getColumn=function(){return this.column},o.prototype.getSelectedValue=function(){return void 0!==this.index?this.values[this.index]:void 0},o.prototype.getValues=function(){return this.values},o.prototype.getValue=function(t){if(t>=this.values.length)throw"Error: index out of range";return this.values[t]},o.prototype._getDataPoints=function(t){if(void 0===t&&(t=this.index),void 0===t)return[];var e;if(this.dataPoints[t])e=this.dataPoints[t];else{var i={};i.column=this.column,i.value=this.values[t];var o=new n(this.data,{filter:function(t){return t[i.column]==i.value}}).get();e=this.graph._getDataPoints(o),this.dataPoints[t]=e}return e},o.prototype.setOnLoadCallback=function(t){this.onLoadCallback=t},o.prototype.selectValue=function(t){if(t>=this.values.length)throw"Error: index out of range";this.index=t,this.value=this.values[t]},o.prototype.loadInBackground=function(t){void 0===t&&(t=0);var e=this.graph.frame;if(t0&&(t--,this.setIndex(t))},o.prototype.next=function(){var t=this.getIndex();t0?this.setIndex(0):this.index=void 0},o.prototype.setIndex=function(t){if(!(to&&(o=0),o>this.values.length-1&&(o=this.values.length-1),o},o.prototype.indexToLeft=function(t){var e=parseFloat(this.frame.bar.style.width)-this.frame.slide.clientWidth-10,i=t/(this.values.length-1)*e,o=i+3;return o},o.prototype._onMouseMove=function(t){var e=t.clientX-this.startClientX,i=this.startSlideX+e,o=this.leftToIndex(i);this.setIndex(o),n.preventDefault()},o.prototype._onMouseUp=function(t){this.frame.style.cursor="auto",n.removeEventListener(document,"mousemove",this.onmousemove),n.removeEventListener(document,"mouseup",this.onmouseup),n.preventDefault()},t.exports=o},function(t,e){function i(t,e,i,o){this._start=0,this._end=0,this._step=1,this.prettyStep=!0,this.precision=5,this._current=0,this.setRange(t,e,i,o)}i.prototype.setRange=function(t,e,i,o){this._start=t?t:0,this._end=e?e:0,this.setStep(i,o)},i.prototype.setStep=function(t,e){void 0===t||0>=t||(void 0!==e&&(this.prettyStep=e),this.prettyStep===!0?this._step=i.calculatePrettyStep(t):this._step=t)},i.calculatePrettyStep=function(t){var e=function(t){return Math.log(t)/Math.LN10},i=Math.pow(10,Math.round(e(t))),o=2*Math.pow(10,Math.round(e(t/2))),n=5*Math.pow(10,Math.round(e(t/5))),s=i;return 
Math.abs(o-t)<=Math.abs(s-t)&&(s=o),Math.abs(n-t)<=Math.abs(s-t)&&(s=n),0>=s&&(s=1),s},i.prototype.getCurrent=function(){return parseFloat(this._current.toPrecision(this.precision))},i.prototype.getStep=function(){return this._step},i.prototype.start=function(){this._current=this._start-this._start%this._step},i.prototype.next=function(){this._current+=this._step},i.prototype.end=function(){return this._current>this._end},t.exports=i},function(t,e,i){if("undefined"!=typeof window){var o=i(21),n=window.Hammer||i(22);t.exports=o(n,{preventDefault:"mouse"})}else t.exports=function(){throw Error("hammer.js is only available in a browser, not in node.js.")}},function(t,e,i){var o,n,s;!function(i){n=[],o=i,s="function"==typeof o?o.apply(e,n):o,!(void 0!==s&&(t.exports=s))}(function(){var t=null;return function e(i,o){function n(t){return t.match(/[^ ]+/g)}function s(e){if("hammer.input"!==e.type){if(e.srcEvent._handled||(e.srcEvent._handled={}),e.srcEvent._handled[e.type])return;e.srcEvent._handled[e.type]=!0}var i=!1;e.stopPropagation=function(){i=!0};var o=e.srcEvent.stopPropagation.bind(e.srcEvent);"function"==typeof o&&(e.srcEvent.stopPropagation=function(){o(),e.stopPropagation()}),e.firstTarget=t;for(var n=t;n&&!i;){var s=n.hammer;if(s)for(var r,a=0;a0?d._handlers[t]=o:(i.off(t,s),delete d._handlers[t]))}),d},d.emit=function(e,o){t=o.target,i.emit(e,o)},d.destroy=function(){var t=i.element.hammer,e=t.indexOf(d);-1!==e&&t.splice(e,1),t.length||delete i.element.hammer,d._handlers={},i.destroy()},d}})},function(t,e,i){var o;!function(n,s,r,a){function h(t,e,i){return setTimeout(p(t,i),e)}function d(t,e,i){return Array.isArray(t)?(l(t,i[e],i),!0):!1}function l(t,e,i){var o;if(t)if(t.forEach)t.forEach(e,i);else if(t.length!==a)for(o=0;o\s*\(/gm,"{anonymous}()@"):"Unknown Stack Trace",s=n.console&&(n.console.warn||n.console.log);return s&&s.call(n.console,o,i),t.apply(this,arguments)}}function u(t,e,i){var o,n=e.prototype;o=t.prototype=Object.create(n),o.constructor=t,o._super=n,i&&ct(o,i)}function p(t,e){return function(){return t.apply(e,arguments)}}function f(t,e){return typeof t==ft?t.apply(e?e[0]||a:a,e):t}function m(t,e){return t===a?e:t}function v(t,e,i){l(w(e),function(e){t.addEventListener(e,i,!1)})}function g(t,e,i){l(w(e),function(e){t.removeEventListener(e,i,!1)})}function y(t,e){for(;t;){if(t==e)return!0;t=t.parentNode}return!1}function b(t,e){return t.indexOf(e)>-1}function w(t){return t.trim().split(/\s+/g)}function _(t,e,i){if(t.indexOf&&!i)return t.indexOf(e);for(var o=0;oi[e]}):o.sort()),o}function O(t,e){for(var i,o,n=e[0].toUpperCase()+e.slice(1),s=0;s1&&!i.firstMultiple?i.firstMultiple=N(e):1===n&&(i.firstMultiple=!1);var s=i.firstInput,r=i.firstMultiple,a=r?r.center:s.center,h=e.center=R(o);e.timeStamp=gt(),e.deltaTime=e.timeStamp-s.timeStamp,e.angle=B(a,h),e.distance=A(a,h),P(i,e),e.offsetDirection=L(e.deltaX,e.deltaY);var d=z(e.deltaTime,e.deltaX,e.deltaY);e.overallVelocityX=d.x,e.overallVelocityY=d.y,e.overallVelocity=vt(d.x)>vt(d.y)?d.x:d.y,e.scale=r?j(r.pointers,o):1,e.rotation=r?F(r.pointers,o):0,e.maxPointers=i.prevInput?e.pointers.length>i.prevInput.maxPointers?e.pointers.length:i.prevInput.maxPointers:e.pointers.length,I(i,e);var l=t.element;y(e.srcEvent.target,l)&&(l=e.srcEvent.target),e.target=l}function P(t,e){var i=e.center,o=t.offsetDelta||{},n=t.prevDelta||{},s=t.prevInput||{};e.eventType!==Et&&s.eventType!==It||(n=t.prevDelta={x:s.deltaX||0,y:s.deltaY||0},o=t.offsetDelta={x:i.x,y:i.y}),e.deltaX=n.x+(i.x-o.x),e.deltaY=n.y+(i.y-o.y)}function I(t,e){var 
i,o,n,s,r=t.lastInterval||e,h=e.timeStamp-r.timeStamp;if(e.eventType!=Nt&&(h>Tt||r.velocity===a)){var d=e.deltaX-r.deltaX,l=e.deltaY-r.deltaY,c=z(h,d,l);o=c.x,n=c.y,i=vt(c.x)>vt(c.y)?c.x:c.y,s=L(d,l),t.lastInterval=e}else i=r.velocity,o=r.velocityX,n=r.velocityY,s=r.direction;e.velocity=i,e.velocityX=o,e.velocityY=n,e.direction=s}function N(t){for(var e=[],i=0;in;)i+=t[n].clientX,o+=t[n].clientY,n++;return{x:mt(i/e),y:mt(o/e)}}function z(t,e,i){return{x:e/t||0,y:i/t||0}}function L(t,e){return t===e?Rt:vt(t)>=vt(e)?0>t?zt:Lt:0>e?At:Bt}function A(t,e,i){i||(i=Wt);var o=e[i[0]]-t[i[0]],n=e[i[1]]-t[i[1]];return Math.sqrt(o*o+n*n)}function B(t,e,i){i||(i=Wt);var o=e[i[0]]-t[i[0]],n=e[i[1]]-t[i[1]];return 180*Math.atan2(n,o)/Math.PI}function F(t,e){return B(e[1],e[0],Yt)+B(t[1],t[0],Yt)}function j(t,e){return A(e[0],e[1],Yt)/A(t[0],t[1],Yt)}function H(){this.evEl=Vt,this.evWin=Ut,this.allow=!0,this.pressed=!1,S.apply(this,arguments)}function W(){this.evEl=Zt,this.evWin=Kt,S.apply(this,arguments),this.store=this.manager.session.pointerEvents=[]}function Y(){this.evTarget=Qt,this.evWin=$t,this.started=!1,S.apply(this,arguments)}function G(t,e){var i=x(t.touches),o=x(t.changedTouches);return e&(It|Nt)&&(i=k(i.concat(o),"identifier",!0)),[i,o]}function V(){this.evTarget=ee,this.targetIds={},S.apply(this,arguments)}function U(t,e){var i=x(t.touches),o=this.targetIds;if(e&(Et|Pt)&&1===i.length)return o[i[0].identifier]=!0,[i,i];var n,s,r=x(t.changedTouches),a=[],h=this.target;if(s=i.filter(function(t){return y(t.target,h)}),e===Et)for(n=0;na&&(e.push(t),a=e.length-1):n&(It|Nt)&&(i=!0),0>a||(e[a]=t,this.callback(this.manager,n,{pointers:e,changedPointers:[t],pointerType:s,srcEvent:t}),i&&e.splice(a,1))}});var Jt={touchstart:Et,touchmove:Pt,touchend:It,touchcancel:Nt},Qt="touchstart",$t="touchstart touchmove touchend touchcancel";u(Y,S,{handler:function(t){var e=Jt[t.type];if(e===Et&&(this.started=!0),this.started){var i=G.call(this,t,e);e&(It|Nt)&&i[0].length-i[1].length===0&&(this.started=!1),this.callback(this.manager,e,{pointers:i[0],changedPointers:i[1],pointerType:Mt,srcEvent:t})}}});var te={touchstart:Et,touchmove:Pt,touchend:It,touchcancel:Nt},ee="touchstart touchmove touchend touchcancel";u(V,S,{handler:function(t){var e=te[t.type],i=U.call(this,t,e);i&&this.callback(this.manager,e,{pointers:i[0],changedPointers:i[1],pointerType:Mt,srcEvent:t})}}),u(q,S,{handler:function(t,e,i){var o=i.pointerType==Mt,n=i.pointerType==St;if(o)this.mouse.allow=!1;else if(n&&!this.mouse.allow)return;e&(It|Nt)&&(this.mouse.allow=!0),this.callback(t,e,i)},destroy:function(){this.touch.destroy(),this.mouse.destroy()}});var ie=O(pt.style,"touchAction"),oe=ie!==a,ne="compute",se="auto",re="manipulation",ae="none",he="pan-x",de="pan-y";X.prototype={set:function(t){t==ne&&(t=this.compute()),oe&&this.manager.element.style&&(this.manager.element.style[ie]=t),this.actions=t.toLowerCase().trim()},update:function(){this.set(this.manager.options.touchAction)},compute:function(){var t=[];return l(this.manager.recognizers,function(e){f(e.options.enable,[e])&&(t=t.concat(e.getTouchAction()))}),Z(t.join(" "))},preventDefaults:function(t){if(!oe){var e=t.srcEvent,i=t.offsetDirection;if(this.manager.session.prevented)return void e.preventDefault();var o=this.actions,n=b(o,ae),s=b(o,de),r=b(o,he);if(n){var a=1===t.pointers.length,h=t.distance<2,d=t.deltaTime<250;if(a&&h&&d)return}if(!r||!s)return n||s&&i&Ft||r&&i&jt?this.preventSrc(e):void 0}},preventSrc:function(t){this.manager.session.prevented=!0,t.preventDefault()}};var 
le=1,ce=2,ue=4,pe=8,fe=pe,me=16,ve=32;K.prototype={defaults:{},set:function(t){return ct(this.options,t),this.manager&&this.manager.touchAction.update(),this},recognizeWith:function(t){if(d(t,"recognizeWith",this))return this;var e=this.simultaneous;return t=$(t,this),e[t.id]||(e[t.id]=t,t.recognizeWith(this)),this},dropRecognizeWith:function(t){return d(t,"dropRecognizeWith",this)?this:(t=$(t,this),delete this.simultaneous[t.id],this)},requireFailure:function(t){if(d(t,"requireFailure",this))return this;var e=this.requireFail;return t=$(t,this),-1===_(e,t)&&(e.push(t),t.requireFailure(this)),this},dropRequireFailure:function(t){if(d(t,"dropRequireFailure",this))return this;t=$(t,this);var e=_(this.requireFail,t);return e>-1&&this.requireFail.splice(e,1),this},hasRequireFailures:function(){return this.requireFail.length>0},canRecognizeWith:function(t){return!!this.simultaneous[t.id]},emit:function(t){function e(e){i.manager.emit(e,t)}var i=this,o=this.state;pe>o&&e(i.options.event+J(o)),e(i.options.event),t.additionalEvent&&e(t.additionalEvent),o>=pe&&e(i.options.event+J(o))},tryEmit:function(t){return this.canEmit()?this.emit(t):void(this.state=ve)},canEmit:function(){for(var t=0;ts?zt:Lt,i=s!=this.pX,o=Math.abs(t.deltaX)):(n=0===r?Rt:0>r?At:Bt,i=r!=this.pY,o=Math.abs(t.deltaY))),t.direction=n,i&&o>e.threshold&&n&e.direction},attrTest:function(t){return tt.prototype.attrTest.call(this,t)&&(this.state&ce||!(this.state&ce)&&this.directionTest(t))},emit:function(t){this.pX=t.deltaX,this.pY=t.deltaY;var e=Q(t.direction);e&&(t.additionalEvent=this.options.event+e),this._super.emit.call(this,t)}}),u(it,tt,{defaults:{event:"pinch",threshold:0,pointers:2},getTouchAction:function(){return[ae]},attrTest:function(t){return this._super.attrTest.call(this,t)&&(Math.abs(t.scale-1)>this.options.threshold||this.state&ce)},emit:function(t){if(1!==t.scale){var e=t.scale<1?"in":"out";t.additionalEvent=this.options.event+e}this._super.emit.call(this,t)}}),u(ot,K,{defaults:{event:"press",pointers:1,time:251,threshold:9},getTouchAction:function(){return[se]},process:function(t){var e=this.options,i=t.pointers.length===e.pointers,o=t.distancee.time;if(this._input=t,!o||!i||t.eventType&(It|Nt)&&!n)this.reset();else if(t.eventType&Et)this.reset(),this._timer=h(function(){this.state=fe,this.tryEmit()},e.time,this);else if(t.eventType&It)return fe;return ve},reset:function(){clearTimeout(this._timer)},emit:function(t){this.state===fe&&(t&&t.eventType&It?this.manager.emit(this.options.event+"up",t):(this._input.timeStamp=gt(),this.manager.emit(this.options.event,this._input)))}}),u(nt,tt,{defaults:{event:"rotate",threshold:0,pointers:2},getTouchAction:function(){return[ae]},attrTest:function(t){return this._super.attrTest.call(this,t)&&(Math.abs(t.rotation)>this.options.threshold||this.state&ce)}}),u(st,tt,{defaults:{event:"swipe",threshold:10,velocity:.3,direction:Ft|jt,pointers:1},getTouchAction:function(){return et.prototype.getTouchAction.call(this)},attrTest:function(t){var e,i=this.options.direction;return i&(Ft|jt)?e=t.overallVelocity:i&Ft?e=t.overallVelocityX:i&jt&&(e=t.overallVelocityY),this._super.attrTest.call(this,t)&&i&t.offsetDirection&&t.distance>this.options.threshold&&t.maxPointers==this.options.pointers&&vt(e)>this.options.velocity&&t.eventType&It},emit:function(t){var 
e=Q(t.offsetDirection);e&&this.manager.emit(this.options.event+e,t),this.manager.emit(this.options.event,t)}}),u(rt,K,{defaults:{event:"tap",pointers:1,taps:1,interval:300,time:250,threshold:9,posThreshold:10},getTouchAction:function(){return[re]},process:function(t){var e=this.options,i=t.pointers.length===e.pointers,o=t.distance=e;e++)r[String.fromCharCode(e)]={code:65+(e-97),shift:!1};for(e=65;90>=e;e++)r[String.fromCharCode(e)]={code:e,shift:!0};for(e=0;9>=e;e++)r[""+e]={code:48+e,shift:!1};for(e=1;12>=e;e++)r["F"+e]={code:111+e,shift:!1};for(e=0;9>=e;e++)r["num"+e]={code:96+e,shift:!1};r["num*"]={code:106,shift:!1},r["num+"]={code:107,shift:!1},r["num-"]={code:109,shift:!1},r["num/"]={code:111,shift:!1},r["num."]={code:110,shift:!1},r.left={code:37,shift:!1},r.up={code:38,shift:!1},r.right={code:39,shift:!1},r.down={code:40,shift:!1},r.space={code:32,shift:!1},r.enter={code:13,shift:!1},r.shift={code:16,shift:void 0},r.esc={code:27,shift:!1},r.backspace={code:8,shift:!1},r.tab={code:9,shift:!1},r.ctrl={code:17,shift:!1},r.alt={code:18,shift:!1},r["delete"]={code:46,shift:!1},r.pageup={code:33,shift:!1},r.pagedown={code:34,shift:!1},r["="]={code:187,shift:!1},r["-"]={code:189,shift:!1},r["]"]={code:221,shift:!1},r["["]={code:219,shift:!1};var a=function(t){d(t,"keydown")},h=function(t){d(t,"keyup")},d=function(t,e){if(void 0!==s[e][t.keyCode]){for(var o=s[e][t.keyCode],n=0;ne)&&(n=e),(null===s||i>s)&&(s=i)}),null!==n&&null!==s){var r=(n+s)/2,a=Math.max(this.range.end-this.range.start,1.1*(s-n)),h=e&&void 0!==e.animation?e.animation:!0;this.range.setRange(r-a/2,r+a/2,h)}}},n.prototype.fit=function(t){var e,i=t&&void 0!==t.animation?t.animation:!0,o=this.itemsData&&this.itemsData.getDataSet();1===o.length&&void 0===o.get()[0].end?(e=this.getDataRange(),this.moveTo(e.min.valueOf(),{animation:i})):(e=this.getItemRange(),this.range.setRange(e.min,e.max,i))},n.prototype.getItemRange=function(){var t=this,e=this.getDataRange(),i=null!==e.min?e.min.valueOf():null,o=null!==e.max?e.max.valueOf():null,n=null,s=null;if(null!=i&&null!=o){var r,a,h,d,c;!function(){var e=function(t){return l.convert(t.data.start,"Date").valueOf()},u=function(t){var e=void 0!=t.data.end?t.data.end:t.data.start;return l.convert(e,"Date").valueOf()};r=o-i,0>=r&&(r=10),a=r/t.props.center.width,l.forEach(t.itemSet.items,function(t){t.show(),t.repositionX();var r=e(t),h=u(t);if(this.options.rtl)var d=r-(t.getWidthRight()+10)*a,l=h+(t.getWidthLeft()+10)*a;else var d=r-(t.getWidthLeft()+10)*a,l=h+(t.getWidthRight()+10)*a;i>d&&(i=d,n=t),l>o&&(o=l,s=t)}.bind(t)),n&&s&&(h=n.getWidthLeft()+10,d=s.getWidthRight()+10,c=t.props.center.width-h-d,c>0&&(t.options.rtl?(i=e(n)-d*r/c,o=u(s)+h*r/c):(i=e(n)-h*r/c,o=u(s)+d*r/c)))}()}return{min:null!=i?new Date(i):null,max:null!=o?new Date(o):null}},n.prototype.getDataRange=function(){var t=null,e=null,i=this.itemsData&&this.itemsData.getDataSet();return i&&i.forEach(function(i){var o=l.convert(i.start,"Date").valueOf(),n=l.convert(void 0!=i.end?i.end:i.start,"Date").valueOf();(null===t||t>o)&&(t=o),(null===e||n>e)&&(e=n)}),{min:null!=t?new Date(t):null,max:null!=e?new Date(e):null}},n.prototype.getEventProperties=function(t){var e=t.center?t.center.x:t.clientX,i=t.center?t.center.y:t.clientY;if(this.options.rtl)var o=l.getAbsoluteRight(this.dom.centerContainer)-e;else var o=e-l.getAbsoluteLeft(this.dom.centerContainer);var 
n=i-l.getAbsoluteTop(this.dom.centerContainer),s=this.itemSet.itemFromTarget(t),r=this.itemSet.groupFromTarget(t),a=g.customTimeFromTarget(t),h=this.itemSet.options.snap||null,d=this.body.util.getScale(),c=this.body.util.getStep(),u=this._toTime(o),p=h?h(u,d,c):u,f=l.getTarget(t),m=null;return null!=s?m="item":null!=a?m="custom-time":l.hasParent(f,this.timeAxis.dom.foreground)?m="axis":this.timeAxis2&&l.hasParent(f,this.timeAxis2.dom.foreground)?m="axis":l.hasParent(f,this.itemSet.dom.labelSet)?m="group-label":l.hasParent(f,this.currentTime.bar)?m="current-time":l.hasParent(f,this.dom.center)&&(m="background"),{event:t,item:s?s.id:null,group:r?r.groupId:null,what:m,pageX:t.srcEvent?t.srcEvent.pageX:t.pageX,pageY:t.srcEvent?t.srcEvent.pageY:t.pageY,x:o,y:n,time:u,snappedTime:p}},t.exports=n},function(t,e,i){function o(t){return t&&t.__esModule?t:{"default":t}}function n(t,e){if(!(t instanceof e))throw new TypeError("Cannot call a class as a function")}Object.defineProperty(e,"__esModule",{value:!0});var s="function"==typeof Symbol&&"symbol"==typeof Symbol.iterator?function(t){return typeof t}:function(t){return t&&"function"==typeof Symbol&&t.constructor===Symbol?"symbol":typeof t},r=function(){function t(t,e){for(var i=0;i0&&this._makeItem([]),this._makeHeader(n),this._handleObject(this.configureOptions[n],[n])),i++);this.options.showButton===!0&&!function(){var e=document.createElement("div");e.className="vis-configuration vis-config-button",e.innerHTML="generate options",e.onclick=function(){t._printOptions()},e.onmouseover=function(){e.className="vis-configuration vis-config-button hover"},e.onmouseout=function(){e.className="vis-configuration vis-config-button"},t.optionsContainer=document.createElement("div"),t.optionsContainer.className="vis-configuration vis-config-option-container",t.domElements.push(t.optionsContainer),t.domElements.push(e)}(),this._push()}},{key:"_push",value:function(){this.wrapper=document.createElement("div"),this.wrapper.className="vis-configuration-wrapper",this.container.appendChild(this.wrapper);for(var t=0;t1?o-1:0),r=1;o>r;r++)n[r-1]=e[r];return n.forEach(function(t){s.appendChild(t)}),i.domElements.push(s),{v:i.domElements.length}}();if("object"===("undefined"==typeof a?"undefined":s(a)))return a.v}return 0}},{key:"_makeHeader",value:function(t){var e=document.createElement("div");e.className="vis-configuration vis-config-header",e.innerHTML=t,this._makeItem([],e)}},{key:"_makeLabel",value:function(t,e){var i=arguments.length<=2||void 0===arguments[2]?!1:arguments[2],o=document.createElement("div");return o.className="vis-configuration vis-config-label vis-config-s"+e.length,i===!0?o.innerHTML=""+t+":":o.innerHTML=t+":",o}},{key:"_makeDropdown",value:function(t,e,i){var o=document.createElement("select");o.className="vis-configuration vis-config-select";var n=0;void 0!==e&&-1!==t.indexOf(e)&&(n=t.indexOf(e));for(var s=0;se&&n>e*c?(a.min=Math.ceil(e*c),l=a.min,d="range increased"):n>e/c&&(a.min=Math.ceil(e/c),l=a.min,d="range increased"),e*c>s&&1!==s&&(a.max=Math.ceil(e*c),l=a.max,d="range increased"),a.value=e}else a.value=o;var u=document.createElement("input");u.className="vis-configuration vis-config-rangeinput",u.value=a.value;var p=this;a.onchange=function(){u.value=this.value,p._update(Number(this.value),i)},a.oninput=function(){u.value=this.value};var f=this._makeLabel(i[i.length-1],i),m=this._makeItem(i,f,a,u);""!==d&&this.popupHistory[m]!==l&&(this.popupHistory[m]=l,this._setupPopup(d,m))}},{key:"_setupPopup",value:function(t,e){var 
i=this;if(this.initialized===!0&&this.allowCreation===!0&&this.popupCountervar options = "+JSON.stringify(t,null,2)+""}},{key:"getOptions",value:function(){for(var t={},e=0;es;s++)for(r=0;rp?p+1:p;var f=l/this.r,m=a.RGBToHSV(this.color.r,this.color.g,this.color.b);m.h=p,m.s=f;var v=a.HSVToRGB(m.h,m.s,m.v);v.a=this.color.a,this.color=v,this.initialColorDiv.style.backgroundColor="rgba("+this.initialColor.r+","+this.initialColor.g+","+this.initialColor.b+","+this.initialColor.a+")",this.newColorDiv.style.backgroundColor="rgba("+this.color.r+","+this.color.g+","+this.color.b+","+this.color.a+")"}}]),t}();e["default"]=h},function(t,e,i){i(20);e.onTouch=function(t,e){e.inputHandler=function(t){t.isFirst&&e(t)},t.on("hammer.input",e.inputHandler)},e.onRelease=function(t,e){return e.inputHandler=function(t){t.isFinal&&e(t)},t.on("hammer.input",e.inputHandler)},e.offTouch=function(t,e){t.off("hammer.input",e.inputHandler)},e.offRelease=e.offTouch,e.disablePreventDefaultVertically=function(t){var e="pan-y";return t.getTouchAction=function(){return[e]},t}},function(t,e,i){function o(t,e){if(!(t instanceof e))throw new TypeError("Cannot call a class as a function")}Object.defineProperty(e,"__esModule",{value:!0});var n="function"==typeof Symbol&&"symbol"==typeof Symbol.iterator?function(t){return typeof t}:function(t){return t&&"function"==typeof Symbol&&t.constructor===Symbol?"symbol":typeof t},s=function(){function t(t,e){for(var i=0;is.distance?console.log('%cUnknown option detected: "'+e+'" in '+t.printLocation(n.path,e,"")+"Perhaps it was misplaced? Matching option found at: "+t.printLocation(s.path,s.closestMatch,""),d):n.distance<=r?console.log('%cUnknown option detected: "'+e+'". Did you mean "'+n.closestMatch+'"?'+t.printLocation(n.path,e),d):console.log('%cUnknown option detected: "'+e+'". 
Did you mean one of these: '+t.print(Object.keys(i))+t.printLocation(o,e),d),a=!0}},{key:"findInOptions",value:function(e,i,o){var n=arguments.length<=3||void 0===arguments[3]?!1:arguments[3],s=1e9,a="",h=[],d=e.toLowerCase(),l=void 0;for(var c in i){var u=void 0;if(void 0!==i[c].__type__&&n===!0){var p=t.findInOptions(e,i[c],r.copyAndExtendArray(o,c));s>p.distance&&(a=p.closestMatch,h=p.path,s=p.distance,l=p.indexMatch)}else-1!==c.toLowerCase().indexOf(d)&&(l=c),u=t.levenshteinDistance(e,c),s>u&&(a=c,h=r.copyArray(o),s=u)}return{closestMatch:a,path:h,distance:s,indexMatch:l}}},{key:"printLocation",value:function(t,e){for(var i=arguments.length<=2||void 0===arguments[2]?"Problem value found at: \n":arguments[2],o="\n\n"+i+"options = {\n",n=0;ns;s++)o+=" ";o+=t[n]+": {\n"}for(var r=0;ru,r=s||null===n?n:l+(n-l)*i,p=s||null===a?a:c+(a-c)*i;y=h._applyRange(r,p),d.updateHiddenDates(h.options.moment,h.body,h.options.hiddenDates),v=v||y,y&&h.body.emitter.emit("rangechange",{start:new Date(h.start),end:new Date(h.end),byUser:o}),s?v&&h.body.emitter.emit("rangechanged",{start:new Date(h.start),end:new Date(h.end),byUser:o}):h.animationTimer=setTimeout(w,20)}};return g()}var y=this._applyRange(n,a);if(d.updateHiddenDates(this.options.moment,this.body,this.options.hiddenDates),y){var b={start:new Date(this.start),end:new Date(this.end),byUser:o};this.body.emitter.emit("rangechange",b),this.body.emitter.emit("rangechanged",b)}},o.prototype._cancelAnimation=function(){this.animationTimer&&(clearTimeout(this.animationTimer),this.animationTimer=null)},o.prototype._applyRange=function(t,e){var i,o=null!=t?r.convert(t,"Date").valueOf():this.start,n=null!=e?r.convert(e,"Date").valueOf():this.end,s=null!=this.options.max?r.convert(this.options.max,"Date").valueOf():null,a=null!=this.options.min?r.convert(this.options.min,"Date").valueOf():null;if(isNaN(o)||null===o)throw new Error('Invalid start "'+t+'"');if(isNaN(n)||null===n)throw new Error('Invalid end "'+e+'"');if(o>n&&(n=o),null!==a&&a>o&&(i=a-o,o+=i,n+=i,null!=s&&n>s&&(n=s)),null!==s&&n>s&&(i=n-s,o-=i,n-=i,null!=a&&a>o&&(o=a)),null!==this.options.zoomMin){var h=parseFloat(this.options.zoomMin);0>h&&(h=0),h>n-o&&(this.end-this.start===h&&o>this.start&&nd&&(d=0),n-o>d&&(this.end-this.start===d&&othis.end?(o=this.start,n=this.end):(i=n-o-d,o+=i/2,n-=i/2))}var l=this.start!=o||this.end!=n;return o>=this.start&&o<=this.end||n>=this.start&&n<=this.end||this.start>=o&&this.start<=n||this.end>=o&&this.end<=n||this.body.emitter.emit("checkRangedItems"),this.start=o,this.end=n,l},o.prototype.getRange=function(){return{start:this.start,end:this.end}},o.prototype.conversion=function(t,e){return o.conversion(this.start,this.end,t,e)},o.conversion=function(t,e,i,o){return void 0===o&&(o=0),0!=i&&e-t!=0?{offset:t,scale:i/(e-t-o)}:{offset:0,scale:1}},o.prototype._onDragStart=function(t){this.deltaDifference=0,this.previousDelta=0,this.options.moveable&&this._isInsideRange(t)&&this.props.touch.allowDragging&&(this.props.touch.start=this.start,this.props.touch.end=this.end,this.props.touch.dragging=!0,this.body.dom.root&&(this.body.dom.root.style.cursor="move"))},o.prototype._onDrag=function(t){if(this.props.touch.dragging&&this.options.moveable&&this.props.touch.allowDragging){var e=this.options.direction;n(e);var i="horizontal"==e?t.deltaX:t.deltaY;i-=this.deltaDifference;var o=this.props.touch.end-this.props.touch.start,s=d.getHiddenDurationBetween(this.body.hiddenDates,this.start,this.end);o-=s;var 
r="horizontal"==e?this.body.domProps.center.width:this.body.domProps.center.height;if(this.options.rtl)var a=i/r*o;else var a=-i/r*o;var h=this.props.touch.start+a,l=this.props.touch.end+a,c=d.snapAwayFromHidden(this.body.hiddenDates,h,this.previousDelta-i,!0),u=d.snapAwayFromHidden(this.body.hiddenDates,l,this.previousDelta-i,!0);if(c!=h||u!=l)return this.deltaDifference+=i,this.props.touch.start=c,this.props.touch.end=u,void this._onDrag(t);this.previousDelta=i,this._applyRange(h,l);var p=new Date(this.start),f=new Date(this.end);this.body.emitter.emit("rangechange",{start:p,end:f,byUser:!0})}},o.prototype._onDragEnd=function(t){this.props.touch.dragging&&this.options.moveable&&this.props.touch.allowDragging&&(this.props.touch.dragging=!1,this.body.dom.root&&(this.body.dom.root.style.cursor="auto"),this.body.emitter.emit("rangechanged",{start:new Date(this.start),end:new Date(this.end),byUser:!0}))},o.prototype._onMouseWheel=function(t){if(this.options.zoomable&&this.options.moveable&&this._isInsideRange(t)&&(!this.options.zoomKey||t[this.options.zoomKey])){var e=0;if(t.wheelDelta?e=t.wheelDelta/120:t.detail&&(e=-t.detail/3),e){var i;i=0>e?1-e/5:1/(1+e/5);var o=this.getPointer({x:t.clientX,y:t.clientY},this.body.dom.center),n=this._pointerToDate(o);this.zoom(i,n,e)}t.preventDefault()}},o.prototype._onTouch=function(t){this.props.touch.start=this.start,this.props.touch.end=this.end,this.props.touch.allowDragging=!0,this.props.touch.center=null,this.scaleOffset=0,this.deltaDifference=0},o.prototype._onPinch=function(t){if(this.options.zoomable&&this.options.moveable){this.props.touch.allowDragging=!1,this.props.touch.center||(this.props.touch.center=this.getPointer(t.center,this.body.dom.center));var e=1/(t.scale+this.scaleOffset),i=this._pointerToDate(this.props.touch.center),o=d.getHiddenDurationBetween(this.body.hiddenDates,this.start,this.end),n=d.getHiddenDurationBefore(this.options.moment,this.body.hiddenDates,this,i),s=o-n,r=i-n+(this.props.touch.start-(i-n))*e,a=i+s+(this.props.touch.end-(i+s))*e; +this.startToFront=0>=1-e,this.endToFront=0>=e-1;var h=d.snapAwayFromHidden(this.body.hiddenDates,r,1-e,!0),l=d.snapAwayFromHidden(this.body.hiddenDates,a,e-1,!0);h==r&&l==a||(this.props.touch.start=h,this.props.touch.end=l,this.scaleOffset=1-t.scale,r=h,a=l),this.setRange(r,a,!1,!0),this.startToFront=!1,this.endToFront=!0}},o.prototype._isInsideRange=function(t){var e=t.center?t.center.x:t.clientX;if(this.options.rtl)var i=e-r.getAbsoluteLeft(this.body.dom.centerContainer);else var i=r.getAbsoluteRight(this.body.dom.centerContainer)-e;var o=this.body.util.toTime(i);return o>=this.start&&o<=this.end},o.prototype._pointerToDate=function(t){var e,i=this.options.direction;if(n(i),"horizontal"==i)return this.body.util.toTime(t.x).valueOf();var o=this.body.domProps.center.height;return e=this.conversion(o),t.y/e.scale+e.offset},o.prototype.getPointer=function(t,e){return this.options.rtl?{x:r.getAbsoluteRight(e)-t.x,y:t.y-r.getAbsoluteTop(e)}:{x:t.x-r.getAbsoluteLeft(e),y:t.y-r.getAbsoluteTop(e)}},o.prototype.zoom=function(t,e,i){null==e&&(e=(this.start+this.end)/2);var o=d.getHiddenDurationBetween(this.body.hiddenDates,this.start,this.end),n=d.getHiddenDurationBefore(this.options.moment,this.body.hiddenDates,this,e),s=o-n,r=e-n+(this.start-(e-n))*t,a=e+s+(this.end-(e+s))*t;this.startToFront=!(i>0),this.endToFront=!(-i>0);var 
h=d.snapAwayFromHidden(this.body.hiddenDates,r,i,!0),l=d.snapAwayFromHidden(this.body.hiddenDates,a,-i,!0);h==r&&l==a||(r=h,a=l),this.setRange(r,a,!1,!0),this.startToFront=!1,this.endToFront=!0},o.prototype.move=function(t){var e=this.end-this.start,i=this.start+e*t,o=this.end+e*t;this.start=i,this.end=o},o.prototype.moveTo=function(t){var e=(this.start+this.end)/2,i=e-t,o=this.start-i,n=this.end-i;this.setRange(o,n)},t.exports=o},function(t,e){function i(t,e){this.options=null,this.props=null}i.prototype.setOptions=function(t){t&&util.extend(this.options,t)},i.prototype.redraw=function(){return!1},i.prototype.destroy=function(){},i.prototype._isResized=function(){var t=this.props._previousWidth!==this.props.width||this.props._previousHeight!==this.props.height;return this.props._previousWidth=this.props.width,this.props._previousHeight=this.props.height,t},t.exports=i},function(t,e){e.convertHiddenOptions=function(t,i,o){if(o&&!Array.isArray(o))return e.convertHiddenOptions(t,i,[o]);if(i.hiddenDates=[],o&&1==Array.isArray(o)){for(var n=0;n=4*a){var u=0,p=s.clone();switch(o[h].repeat){case"daily":d.day()!=l.day()&&(u=1),d.dayOfYear(n.dayOfYear()),d.year(n.year()),d.subtract(7,"days"),l.dayOfYear(n.dayOfYear()),l.year(n.year()),l.subtract(7-u,"days"),p.add(1,"weeks");break;case"weekly":var f=l.diff(d,"days"),m=d.day();d.date(n.date()),d.month(n.month()),d.year(n.year()),l=d.clone(),d.day(m),l.day(m),l.add(f,"days"),d.subtract(1,"weeks"),l.subtract(1,"weeks"),p.add(1,"weeks");break;case"monthly":d.month()!=l.month()&&(u=1),d.month(n.month()),d.year(n.year()),d.subtract(1,"months"),l.month(n.month()),l.year(n.year()),l.subtract(1,"months"),l.add(u,"months"),p.add(1,"months");break;case"yearly":d.year()!=l.year()&&(u=1),d.year(n.year()),d.subtract(1,"years"),l.year(n.year()),l.subtract(1,"years"),l.add(u,"years"),p.add(1,"years");break;default:return void console.log("Wrong repeat format, allowed are: daily, weekly, monthly, yearly. Given:",o[h].repeat)}for(;p>d;)switch(i.hiddenDates.push({start:d.valueOf(),end:l.valueOf()}),o[h].repeat){case"daily":d.add(1,"days"),l.add(1,"days");break;case"weekly":d.add(1,"weeks"),l.add(1,"weeks");break;case"monthly":d.add(1,"months"),l.add(1,"months");break;case"yearly":d.add(1,"y"),l.add(1,"y");break;default:return void console.log("Wrong repeat format, allowed are: daily, weekly, monthly, yearly. 
Given:",o[h].repeat)}i.hiddenDates.push({start:d.valueOf(),end:l.valueOf()})}}e.removeDuplicates(i);var v=e.isHidden(i.range.start,i.hiddenDates),g=e.isHidden(i.range.end,i.hiddenDates),y=i.range.start,b=i.range.end;1==v.hidden&&(y=1==i.range.startToFront?v.startDate-1:v.endDate+1),1==g.hidden&&(b=1==i.range.endToFront?g.startDate-1:g.endDate+1),1!=v.hidden&&1!=g.hidden||i.range._applyRange(y,b)}},e.removeDuplicates=function(t){for(var e=t.hiddenDates,i=[],o=0;o=e[o].start&&e[n].end<=e[o].end?e[n].remove=!0:e[n].start>=e[o].start&&e[n].start<=e[o].end?(e[o].end=e[n].end,e[n].remove=!0):e[n].end>=e[o].start&&e[n].end<=e[o].end&&(e[o].start=e[n].start,e[n].remove=!0));for(var o=0;o=r&&a>n){o=!0;break}}if(1==o&&n=e&&i>r&&(o+=r-s)}return o},e.correctTimeForHidden=function(t,i,o,n){return n=t(n).toDate().valueOf(),n-=e.getHiddenDurationBefore(t,i,o,n)},e.getHiddenDurationBefore=function(t,e,i,o){var n=0;o=t(o).toDate().valueOf();for(var s=0;s=i.start&&a=a&&(n+=a-r)}return n},e.getAccumulatedHiddenDuration=function(t,e,i){for(var o=0,n=0,s=e.start,r=0;r=e.start&&h=i)break;o+=h-a}}return o},e.snapAwayFromHidden=function(t,i,o,n){var s=e.isHidden(i,t);return 1==s.hidden?0>o?1==n?s.startDate-(s.endDate-i)-1:s.startDate-1:1==n?s.endDate+(i-s.startDate)+1:s.endDate+1:i},e.isHidden=function(t,e){for(var i=0;i=o&&n>t)return{hidden:!0,startDate:o,endDate:n}}return{hidden:!1,startDate:o,endDate:n}}},function(t,e,i){function o(){}var n="function"==typeof Symbol&&"symbol"==typeof Symbol.iterator?function(t){return typeof t}:function(t){return t&&"function"==typeof Symbol&&t.constructor===Symbol?"symbol":typeof t},s=i(13),r=i(20),a=i(28),h=i(1),d=(i(9),i(11),i(30),i(34),i(44)),l=i(45),c=i(32),u=i(46);s(o.prototype),o.prototype._create=function(t){function e(t){i.isActive()&&i.emit("mousewheel",t)}this.dom={},this.dom.container=t,this.dom.root=document.createElement("div"),this.dom.background=document.createElement("div"),this.dom.backgroundVertical=document.createElement("div"),this.dom.backgroundHorizontal=document.createElement("div"),this.dom.centerContainer=document.createElement("div"),this.dom.leftContainer=document.createElement("div"),this.dom.rightContainer=document.createElement("div"),this.dom.center=document.createElement("div"),this.dom.left=document.createElement("div"),this.dom.right=document.createElement("div"),this.dom.top=document.createElement("div"),this.dom.bottom=document.createElement("div"),this.dom.shadowTop=document.createElement("div"),this.dom.shadowBottom=document.createElement("div"),this.dom.shadowTopLeft=document.createElement("div"),this.dom.shadowBottomLeft=document.createElement("div"),this.dom.shadowTopRight=document.createElement("div"),this.dom.shadowBottomRight=document.createElement("div"),this.dom.root.className="vis-timeline",this.dom.background.className="vis-panel vis-background",this.dom.backgroundVertical.className="vis-panel vis-background vis-vertical",this.dom.backgroundHorizontal.className="vis-panel vis-background vis-horizontal",this.dom.centerContainer.className="vis-panel vis-center",this.dom.leftContainer.className="vis-panel vis-left",this.dom.rightContainer.className="vis-panel vis-right",this.dom.top.className="vis-panel vis-top",this.dom.bottom.className="vis-panel vis-bottom",this.dom.left.className="vis-content",this.dom.center.className="vis-content",this.dom.right.className="vis-content",this.dom.shadowTop.className="vis-shadow vis-top",this.dom.shadowBottom.className="vis-shadow vis-bottom",this.dom.shadowTopLeft.className="vis-shadow 
vis-top",this.dom.shadowBottomLeft.className="vis-shadow vis-bottom",this.dom.shadowTopRight.className="vis-shadow vis-top",this.dom.shadowBottomRight.className="vis-shadow vis-bottom",this.dom.root.appendChild(this.dom.background),this.dom.root.appendChild(this.dom.backgroundVertical),this.dom.root.appendChild(this.dom.backgroundHorizontal),this.dom.root.appendChild(this.dom.centerContainer),this.dom.root.appendChild(this.dom.leftContainer),this.dom.root.appendChild(this.dom.rightContainer),this.dom.root.appendChild(this.dom.top),this.dom.root.appendChild(this.dom.bottom),this.dom.centerContainer.appendChild(this.dom.center),this.dom.leftContainer.appendChild(this.dom.left),this.dom.rightContainer.appendChild(this.dom.right),this.dom.centerContainer.appendChild(this.dom.shadowTop),this.dom.centerContainer.appendChild(this.dom.shadowBottom),this.dom.leftContainer.appendChild(this.dom.shadowTopLeft),this.dom.leftContainer.appendChild(this.dom.shadowBottomLeft),this.dom.rightContainer.appendChild(this.dom.shadowTopRight),this.dom.rightContainer.appendChild(this.dom.shadowBottomRight),this.on("rangechange",function(){this.initialDrawDone===!0&&this._redraw()}.bind(this)),this.on("touch",this._onTouch.bind(this)),this.on("pan",this._onDrag.bind(this));var i=this;this.on("_change",function(t){t&&1==t.queue?i._redrawTimer||(i._redrawTimer=setTimeout(function(){i._redrawTimer=null,i._redraw()},0)):i._redraw()}),this.hammer=new r(this.dom.root);var o=this.hammer.get("pinch").set({enable:!0});a.disablePreventDefaultVertically(o),this.hammer.get("pan").set({threshold:5,direction:r.DIRECTION_HORIZONTAL}),this.listeners={};var n=["tap","doubletap","press","pinch","pan","panstart","panmove","panend"];if(n.forEach(function(t){var e=function(e){i.isActive()&&i.emit(t,e)};i.hammer.on(t,e),i.listeners[t]=e}),a.onTouch(this.hammer,function(t){i.emit("touch",t)}.bind(this)),a.onRelease(this.hammer,function(t){i.emit("release",t)}.bind(this)),this.dom.root.addEventListener("mousewheel",e),this.dom.root.addEventListener("DOMMouseScroll",e),this.props={root:{},background:{},centerContainer:{},leftContainer:{},rightContainer:{},center:{},left:{},right:{},top:{},bottom:{},border:{},scrollTop:0,scrollTopMin:0},this.customTimes=[],this.touch={},this.redrawCount=0,this.initialDrawDone=!1,!t)throw new Error("No container provided");t.appendChild(this.dom.root)},o.prototype.setOptions=function(t){if(t){var e=["width","height","minHeight","maxHeight","autoResize","start","end","clickToUse","dataAttributes","hiddenDates","locale","locales","moment","rtl","throttleRedraw"];if(h.selectiveExtend(e,this.options,t),this.options.rtl){var i=this.dom.leftContainer;this.dom.leftContainer=this.dom.rightContainer,this.dom.rightContainer=i,this.dom.container.style.direction="rtl",this.dom.backgroundVertical.className="vis-panel vis-background vis-vertical-rtl"}if(this.options.orientation={item:void 0,axis:void 0},"orientation"in t&&("string"==typeof t.orientation?this.options.orientation={item:t.orientation,axis:t.orientation}:"object"===n(t.orientation)&&("item"in t.orientation&&(this.options.orientation.item=t.orientation.item),"axis"in t.orientation&&(this.options.orientation.axis=t.orientation.axis))),"both"===this.options.orientation.axis){if(!this.timeAxis2){var o=this.timeAxis2=new d(this.body);o.setOptions=function(t){var e=t?h.extend({},t):{};e.orientation="top",d.prototype.setOptions.call(o,e)},this.components.push(o)}}else if(this.timeAxis2){var 
s=this.components.indexOf(this.timeAxis2);-1!==s&&this.components.splice(s,1),this.timeAxis2.destroy(),this.timeAxis2=null}if("function"==typeof t.drawPoints&&(t.drawPoints={onRender:t.drawPoints}),"hiddenDates"in this.options&&c.convertHiddenOptions(this.options.moment,this.body,this.options.hiddenDates),"clickToUse"in t&&(t.clickToUse?this.activator||(this.activator=new l(this.dom.root)):this.activator&&(this.activator.destroy(),delete this.activator)),"showCustomTime"in t)throw new Error("Option `showCustomTime` is deprecated. Create a custom time bar via timeline.addCustomTime(time [, id])");this._initAutoResize()}if(this.components.forEach(function(e){return e.setOptions(t)}),"configure"in t){this.configurator||(this.configurator=this._createConfigurator()),this.configurator.setOptions(t.configure);var r=h.deepExtend({},this.options);this.components.forEach(function(t){h.deepExtend(r,t.options)}),this.configurator.setModuleOptions({global:r})}this._origRedraw?this._redraw():(this._origRedraw=this._redraw.bind(this),this._redraw=h.throttle(this._origRedraw,this.options.throttleRedraw))},o.prototype.isActive=function(){return!this.activator||this.activator.active},o.prototype.destroy=function(){this.setItems(null),this.setGroups(null),this.off(),this._stopAutoResize(),this.dom.root.parentNode&&this.dom.root.parentNode.removeChild(this.dom.root),this.dom=null,this.activator&&(this.activator.destroy(),delete this.activator);for(var t in this.listeners)this.listeners.hasOwnProperty(t)&&delete this.listeners[t];this.listeners=null,this.hammer=null,this.components.forEach(function(t){return t.destroy()}),this.body=null},o.prototype.setCustomTime=function(t,e){var i=this.customTimes.filter(function(t){return e===t.options.id});if(0===i.length)throw new Error("No custom time bar found with id "+JSON.stringify(e));i.length>0&&i[0].setCustomTime(t)},o.prototype.getCustomTime=function(t){var e=this.customTimes.filter(function(e){return e.options.id===t});if(0===e.length)throw new Error("No custom time bar found with id "+JSON.stringify(t));return e[0].getCustomTime()},o.prototype.setCustomTimeTitle=function(t,e){var i=this.customTimes.filter(function(t){return t.options.id===e});if(0===i.length)throw new Error("No custom time bar found with id "+JSON.stringify(e));return i.length>0?i[0].setCustomTitle(t):void 0},o.prototype.getEventProperties=function(t){return{event:t}},o.prototype.addCustomTime=function(t,e){var i=void 0!==t?h.convert(t,"Date").valueOf():new Date,o=this.customTimes.some(function(t){return t.options.id===e});if(o)throw new Error("A custom time with id "+JSON.stringify(e)+" already exists");var n=new u(this.body,h.extend({},this.options,{time:i,id:e}));return this.customTimes.push(n),this.components.push(n),this._redraw(),e},o.prototype.removeCustomTime=function(t){var e=this.customTimes.filter(function(e){return e.options.id===t});if(0===e.length)throw new Error("No custom time bar found with id "+JSON.stringify(t));e.forEach(function(t){this.customTimes.splice(this.customTimes.indexOf(t),1),this.components.splice(this.components.indexOf(t),1),t.destroy()}.bind(this))},o.prototype.getVisibleItems=function(){return this.itemSet&&this.itemSet.getVisibleItems()||[]},o.prototype.fit=function(t){var e=this.getDataRange();if(null!==e.min||null!==e.max){var i=e.max-e.min,o=new Date(e.min.valueOf()-.01*i),n=new Date(e.max.valueOf()+.01*i),s=t&&void 0!==t.animation?t.animation:!0;this.range.setRange(o,n,s)}},o.prototype.getDataRange=function(){throw new Error("Cannot invoke abstract 
method getDataRange")},o.prototype.setWindow=function(t,e,i){var o;if(1==arguments.length){var n=arguments[0];o=void 0!==n.animation?n.animation:!0,this.range.setRange(n.start,n.end,o)}else o=i&&void 0!==i.animation?i.animation:!0,this.range.setRange(t,e,o)},o.prototype.moveTo=function(t,e){var i=this.range.end-this.range.start,o=h.convert(t,"Date").valueOf(),n=o-i/2,s=o+i/2,r=e&&void 0!==e.animation?e.animation:!0;this.range.setRange(n,s,r)},o.prototype.getWindow=function(){var t=this.range.getRange();return{start:new Date(t.start),end:new Date(t.end)}},o.prototype.redraw=function(){this._redraw()},o.prototype._redraw=function(){this.redrawCount++;var t=!1,e=this.options,i=this.props,o=this.dom;if(o&&o.container&&0!=o.root.offsetWidth){c.updateHiddenDates(this.options.moment,this.body,this.options.hiddenDates),"top"==e.orientation?(h.addClassName(o.root,"vis-top"),h.removeClassName(o.root,"vis-bottom")):(h.removeClassName(o.root,"vis-top"),h.addClassName(o.root,"vis-bottom")),o.root.style.maxHeight=h.option.asSize(e.maxHeight,""),o.root.style.minHeight=h.option.asSize(e.minHeight,""),o.root.style.width=h.option.asSize(e.width,""),i.border.left=(o.centerContainer.offsetWidth-o.centerContainer.clientWidth)/2,i.border.right=i.border.left,i.border.top=(o.centerContainer.offsetHeight-o.centerContainer.clientHeight)/2,i.border.bottom=i.border.top;var n=o.root.offsetHeight-o.root.clientHeight,s=o.root.offsetWidth-o.root.clientWidth;0===o.centerContainer.clientHeight&&(i.border.left=i.border.top,i.border.right=i.border.left),0===o.root.clientHeight&&(s=n),i.center.height=o.center.offsetHeight,i.left.height=o.left.offsetHeight,i.right.height=o.right.offsetHeight,i.top.height=o.top.clientHeight||-i.border.top,i.bottom.height=o.bottom.clientHeight||-i.border.bottom;var a=Math.max(i.left.height,i.center.height,i.right.height),d=i.top.height+a+i.bottom.height+n+i.border.top+i.border.bottom;o.root.style.height=h.option.asSize(e.height,d+"px"),i.root.height=o.root.offsetHeight,i.background.height=i.root.height-n;var l=i.root.height-i.top.height-i.bottom.height-n;i.centerContainer.height=l,i.leftContainer.height=l,i.rightContainer.height=i.leftContainer.height,i.root.width=o.root.offsetWidth,i.background.width=i.root.width-s,i.left.width=o.leftContainer.clientWidth||-i.border.left,i.leftContainer.width=i.left.width,i.right.width=o.rightContainer.clientWidth||-i.border.right,i.rightContainer.width=i.right.width;var 
u=i.root.width-i.left.width-i.right.width-s;i.center.width=u,i.centerContainer.width=u,i.top.width=u,i.bottom.width=u,o.background.style.height=i.background.height+"px",o.backgroundVertical.style.height=i.background.height+"px",o.backgroundHorizontal.style.height=i.centerContainer.height+"px",o.centerContainer.style.height=i.centerContainer.height+"px",o.leftContainer.style.height=i.leftContainer.height+"px",o.rightContainer.style.height=i.rightContainer.height+"px",o.background.style.width=i.background.width+"px",o.backgroundVertical.style.width=i.centerContainer.width+"px",o.backgroundHorizontal.style.width=i.background.width+"px",o.centerContainer.style.width=i.center.width+"px",o.top.style.width=i.top.width+"px",o.bottom.style.width=i.bottom.width+"px",o.background.style.left="0",o.background.style.top="0",o.backgroundVertical.style.left=i.left.width+i.border.left+"px",o.backgroundVertical.style.top="0",o.backgroundHorizontal.style.left="0",o.backgroundHorizontal.style.top=i.top.height+"px",o.centerContainer.style.left=i.left.width+"px",o.centerContainer.style.top=i.top.height+"px",o.leftContainer.style.left="0",o.leftContainer.style.top=i.top.height+"px",o.rightContainer.style.left=i.left.width+i.center.width+"px",o.rightContainer.style.top=i.top.height+"px",o.top.style.left=i.left.width+"px",o.top.style.top="0",o.bottom.style.left=i.left.width+"px",o.bottom.style.top=i.top.height+i.centerContainer.height+"px",this._updateScrollTop();var p=this.props.scrollTop;"top"!=e.orientation.item&&(p+=Math.max(this.props.centerContainer.height-this.props.center.height-this.props.border.top-this.props.border.bottom,0)),o.center.style.left="0",o.center.style.top=p+"px",o.left.style.left="0",o.left.style.top=p+"px",o.right.style.left="0",o.right.style.top=p+"px";var f=0==this.props.scrollTop?"hidden":"",m=this.props.scrollTop==this.props.scrollTopMin?"hidden":"";o.shadowTop.style.visibility=f,o.shadowBottom.style.visibility=m,o.shadowTopLeft.style.visibility=f,o.shadowBottomLeft.style.visibility=m,o.shadowTopRight.style.visibility=f,o.shadowBottomRight.style.visibility=m;var v=this.props.center.height>this.props.centerContainer.height;this.hammer.get("pan").set({direction:v?r.DIRECTION_ALL:r.DIRECTION_HORIZONTAL}),this.components.forEach(function(e){t=e.redraw()||t});var g=5;if(t){if(this.redrawCount0&&(this.props.scrollTop=0),this.props.scrollTope;e++)o=this.selection[e],n=this.items[o],n&&n.unselect();for(this.selection=[],e=0,i=t.length;i>e;e++)o=t[e],n=this.items[o],n&&(this.selection.push(o),n.select())},o.prototype.getSelection=function(){return this.selection.concat([])},o.prototype.getVisibleItems=function(){var t=this.body.range.getRange();if(this.options.rtl)var e=this.body.util.toScreen(t.start),i=this.body.util.toScreen(t.end);else var i=this.body.util.toScreen(t.start),e=this.body.util.toScreen(t.end);var o=[];for(var n in this.groups)if(this.groups.hasOwnProperty(n))for(var s=this.groups[n],r=s.visibleItems,a=0;ae&&o.push(h.id):h.lefti&&o.push(h.id)}return o},o.prototype._deselect=function(t){for(var e=this.selection,i=0,o=e.length;o>i;i++)if(e[i]==t){e.splice(i,1);break}},o.prototype.redraw=function(){var 
t=this.options.margin,e=this.body.range,i=r.option.asSize,o=this.options,n=o.orientation.item,s=!1,a=this.dom.frame;this.props.top=this.body.domProps.top.height+this.body.domProps.border.top,this.options.rtl?this.props.right=this.body.domProps.right.width+this.body.domProps.border.right:this.props.left=this.body.domProps.left.width+this.body.domProps.border.left,a.className="vis-itemset",s=this._orderGroups()||s;var h=e.end-e.start,d=h!=this.lastVisibleInterval||this.props.width!=this.props.lastWidth;d&&(this.stackDirty=!0), +this.lastVisibleInterval=h,this.props.lastWidth=this.props.width;var l=this.stackDirty,c=this._firstGroup(),u={item:t.item,axis:t.axis},p={item:t.item,axis:t.item.vertical/2},f=0,m=t.axis+t.item.vertical;return this.groups[y].redraw(e,p,l),r.forEach(this.groups,function(t){var i=t==c?u:p,o=t.redraw(e,i,l);s=o||s,f+=t.height}),f=Math.max(f,m),this.stackDirty=!1,a.style.height=i(f),this.props.width=a.offsetWidth,this.props.height=f,this.dom.axis.style.top=i("top"==n?this.body.domProps.top.height+this.body.domProps.border.top:this.body.domProps.top.height+this.body.domProps.centerContainer.height),this.options.rtl?this.dom.axis.style.right="0":this.dom.axis.style.left="0",s=this._isResized()||s},o.prototype._firstGroup=function(){var t="top"==this.options.orientation.item?0:this.groupIds.length-1,e=this.groupIds[t],i=this.groups[e]||this.groups[g];return i||null},o.prototype._updateUngrouped=function(){var t,e,i=this.groups[g];this.groups[y];if(this.groupsData){if(i){i.hide(),delete this.groups[g];for(e in this.items)if(this.items.hasOwnProperty(e)){t=this.items[e],t.parent&&t.parent.remove(t);var o=this._getGroupId(t.data),n=this.groups[o];n&&n.add(t)||t.hide()}}}else if(!i){var s=null,r=null;i=new c(s,r,this),this.groups[g]=i;for(e in this.items)this.items.hasOwnProperty(e)&&(t=this.items[e],i.add(t));i.show()}},o.prototype.getLabelSet=function(){return this.dom.labelSet},o.prototype.setItems=function(t){var e,i=this,o=this.itemsData;if(t){if(!(t instanceof a||t instanceof h))throw new TypeError("Data must be an instance of DataSet or DataView");this.itemsData=t}else this.itemsData=null;if(o&&(r.forEach(this.itemListeners,function(t,e){o.off(e,t)}),e=o.getIds(),this._onRemove(e)),this.itemsData){var n=this.id;r.forEach(this.itemListeners,function(t,e){i.itemsData.on(e,t,n)}),e=this.itemsData.getIds(),this._onAdd(e),this._updateUngrouped()}this.body.emitter.emit("_change",{queue:!0})},o.prototype.getItems=function(){return this.itemsData},o.prototype.setGroups=function(t){var e,i=this;if(this.groupsData&&(r.forEach(this.groupListeners,function(t,e){i.groupsData.off(e,t)}),e=this.groupsData.getIds(),this.groupsData=null,this._onRemoveGroups(e)),t){if(!(t instanceof a||t instanceof h))throw new TypeError("Data must be an instance of DataSet or DataView");this.groupsData=t}else this.groupsData=null;if(this.groupsData){var o=this.id;r.forEach(this.groupListeners,function(t,e){i.groupsData.on(e,t,o)}),e=this.groupsData.getIds(),this._onAddGroups(e)}this._updateUngrouped(),this._order(),this.body.emitter.emit("_change",{queue:!0})},o.prototype.getGroups=function(){return this.groupsData},o.prototype.removeItem=function(t){var e=this.itemsData.get(t),i=this.itemsData.getDataSet();e&&this.options.onRemove(e,function(e){e&&i.remove(t)})},o.prototype._getType=function(t){return t.type||this.options.type||(t.end?"range":"box")},o.prototype._getGroupId=function(t){var e=this._getType(t);return"background"==e&&void 
0==t.group?y:this.groupsData?t.group:g},o.prototype._onUpdate=function(t){var e=this;t.forEach(function(t){var i,n=e.itemsData.get(t,e.itemOptions),s=e.items[t],r=e._getType(n),a=o.types[r];if(s&&(a&&s instanceof a?e._updateItem(s,n):(i=s.selected,e._removeItem(s),s=null)),!s){if(!a)throw"rangeoverflow"==r?new TypeError('Item type "rangeoverflow" is deprecated. Use css styling instead: .vis-item.vis-range .vis-item-content {overflow: visible;}'):new TypeError('Unknown item type "'+r+'"');s=new a(n,e.conversion,e.options),s.id=t,e._addItem(s),i&&(this.selection.push(t),s.select())}}.bind(this)),this._order(),this.stackDirty=!0,this.body.emitter.emit("_change",{queue:!0})},o.prototype._onAdd=o.prototype._onUpdate,o.prototype._onRemove=function(t){var e=0,i=this;t.forEach(function(t){var o=i.items[t];o&&(e++,i._removeItem(o))}),e&&(this._order(),this.stackDirty=!0,this.body.emitter.emit("_change",{queue:!0}))},o.prototype._order=function(){r.forEach(this.groups,function(t){t.order()})},o.prototype._onUpdateGroups=function(t){this._onAddGroups(t)},o.prototype._onAddGroups=function(t){var e=this;t.forEach(function(t){var i=e.groupsData.get(t),o=e.groups[t];if(o)o.setData(i);else{if(t==g||t==y)throw new Error("Illegal group id. "+t+" is a reserved id.");var n=Object.create(e.options);r.extend(n,{height:null}),o=new c(t,i,e),e.groups[t]=o;for(var s in e.items)if(e.items.hasOwnProperty(s)){var a=e.items[s];a.data.group==t&&o.add(a)}o.order(),o.show()}}),this.body.emitter.emit("_change",{queue:!0})},o.prototype._onRemoveGroups=function(t){var e=this.groups;t.forEach(function(t){var i=e[t];i&&(i.hide(),delete e[t])}),this.markDirty(),this.body.emitter.emit("_change",{queue:!0})},o.prototype._orderGroups=function(){if(this.groupsData){var t=this.groupsData.getIds({order:this.options.groupOrder}),e=!r.equalArray(t,this.groupIds);if(e){var i=this.groups;t.forEach(function(t){i[t].hide()}),t.forEach(function(t){i[t].show()}),this.groupIds=t}return e}return!1},o.prototype._addItem=function(t){this.items[t.id]=t;var e=this._getGroupId(t.data),i=this.groups[e];i&&i.add(t)},o.prototype._updateItem=function(t,e){var i=t.data.group,o=t.data.subgroup;if(t.setData(e),i!=t.data.group||o!=t.data.subgroup){var n=this.groups[i];n&&n.remove(t);var s=this._getGroupId(t.data),r=this.groups[s];r&&r.add(t)}},o.prototype._removeItem=function(t){t.hide(),delete this.items[t.id];var e=this.selection.indexOf(t.id);-1!=e&&this.selection.splice(e,1),t.parent&&t.parent.remove(t)},o.prototype._constructByEndArray=function(t){for(var e=[],i=0;in+s)return}else{var a=e.height;if(n+a-s>o)return}}if(e&&e!=this.groupTouchParams.group){var h=this.groupsData,d=h.get(e.groupId),l=h.get(this.groupTouchParams.group.groupId);l&&d&&(this.options.groupOrderSwap(l,d,this.groupsData),this.groupsData.update(l),this.groupsData.update(d));var c=this.groupsData.getIds({order:this.options.groupOrder});if(!r.equalArray(c,this.groupTouchParams.originalOrder))for(var h=this.groupsData,u=this.groupTouchParams.originalOrder,p=this.groupTouchParams.group.groupId,f=Math.min(u.length,c.length),m=0,v=0,g=0;f>m;){for(;f>m+v&&f>m+g&&c[m+v]==u[m+g];)m++;if(m+v>=f)break;if(c[m+v]!=p)if(u[m+g]!=p){var y=c.indexOf(u[m+g]),b=h.get(c[m+v]),w=h.get(u[m+g]);this.options.groupOrderSwap(b,w,h),h.update(b),h.update(w);var _=c[m+v];c[m+v]=u[m+g],c[y]=_,m++}else g=1;else v=1}}}},o.prototype._onGroupDragEnd=function(t){if(this.options.groupEditable.order&&this.groupTouchParams.group){t.stopPropagation();var 
e=this,i=e.groupTouchParams.group.groupId,o=e.groupsData.getDataSet(),n=r.extend({},o.get(i));e.options.onMoveGroup(n,function(t){if(t)t[o._fieldId]=i,o.update(t);else{var n=o.getIds({order:e.options.groupOrder});if(!r.equalArray(n,e.groupTouchParams.originalOrder))for(var s=e.groupTouchParams.originalOrder,a=Math.min(s.length,n.length),h=0;a>h;){for(;a>h&&n[h]==s[h];)h++;if(h>=a)break;var d=n.indexOf(s[h]),l=o.get(n[h]),c=o.get(s[h]);e.options.groupOrderSwap(l,c,o),groupsData.update(l),groupsData.update(c);var u=n[h];n[h]=s[h],n[d]=u,h++}}}),e.body.emitter.emit("groupDragged",{groupId:i})}},o.prototype._onSelectItem=function(t){if(this.options.selectable){var e=t.srcEvent&&(t.srcEvent.ctrlKey||t.srcEvent.metaKey),i=t.srcEvent&&t.srcEvent.shiftKey;if(e||i)return void this._onMultiSelectItem(t);var o=this.getSelection(),n=this.itemFromTarget(t),s=n?[n.id]:[];this.setSelection(s);var r=this.getSelection();(r.length>0||o.length>0)&&this.body.emitter.emit("select",{items:r,event:t})}},o.prototype._onAddItem=function(t){if(this.options.selectable&&this.options.editable.add){var e=this,i=this.options.snap||null,o=this.itemFromTarget(t);if(o){var n=e.itemsData.get(o.id);this.options.onUpdate(n,function(t){t&&e.itemsData.getDataSet().update(t)})}else{if(this.options.rtl)var s=r.getAbsoluteRight(this.dom.frame),a=s-t.center.x;else var s=r.getAbsoluteLeft(this.dom.frame),a=t.center.x-s;var h=this.body.util.toTime(a),d=this.body.util.getScale(),l=this.body.util.getStep(),c={start:i?i(h,d,l):h,content:"new item"};if("range"===this.options.type){var u=this.body.util.toTime(a+this.props.width/5);c.end=i?i(u,d,l):u}c[this.itemsData._fieldId]=r.randomUUID();var p=this.groupFromTarget(t);p&&(c.group=p.groupId),c=this._cloneItemData(c),this.options.onAdd(c,function(t){t&&e.itemsData.getDataSet().add(t)})}}},o.prototype._onMultiSelectItem=function(t){if(this.options.selectable){var e=this.itemFromTarget(t);if(e){var i=this.options.multiselect?this.getSelection():[],n=t.srcEvent&&t.srcEvent.shiftKey||!1;if(n&&this.options.multiselect){var s=this.itemsData.get(e.id).group,r=void 0;this.options.multiselectPerGroup&&i.length>0&&(r=this.itemsData.get(i[0]).group),this.options.multiselectPerGroup&&void 0!=r&&r!=s||i.push(e.id);var a=o._getItemRange(this.itemsData.get(i,this.itemOptions));if(!this.options.multiselectPerGroup||r==s){i=[];for(var h in this.items)if(this.items.hasOwnProperty(h)){var d=this.items[h],l=d.data.start,c=void 0!==d.data.end?d.data.end:l;!(l>=a.min&&c<=a.max)||this.options.multiselectPerGroup&&r!=this.itemsData.get(d.id).group||d instanceof v||i.push(d.id)}}}else{var u=i.indexOf(e.id);-1==u?i.push(e.id):i.splice(u,1)}this.setSelection(i),this.body.emitter.emit("select",{items:this.getSelection(),event:t})}}},o._getItemRange=function(t){var e=null,i=null;return t.forEach(function(t){(null==i||t.starte)&&(e=t.end):(null==e||t.start>e)&&(e=t.start)}),{min:i,max:e}},o.prototype.itemFromTarget=function(t){for(var e=t.target;e;){if(e.hasOwnProperty("timeline-item"))return e["timeline-item"];e=e.parentNode}return null},o.prototype.groupFromTarget=function(t){for(var e=t.center?t.center.y:t.clientY,i=0;ia&&ea)return n}else if(0===i&&e0?t.step:1,this.autoScale=!1)},o.prototype.setAutoScale=function(t){this.autoScale=t},o.prototype.setMinimumStep=function(t){if(void 0!=t){var 
e=31104e6,i=2592e6,o=864e5,n=36e5,s=6e4,r=1e3,a=1;1e3*e>t&&(this.scale="year",this.step=1e3),500*e>t&&(this.scale="year",this.step=500),100*e>t&&(this.scale="year",this.step=100),50*e>t&&(this.scale="year",this.step=50),10*e>t&&(this.scale="year",this.step=10),5*e>t&&(this.scale="year",this.step=5),e>t&&(this.scale="year",this.step=1),3*i>t&&(this.scale="month",this.step=3),i>t&&(this.scale="month",this.step=1),5*o>t&&(this.scale="day",this.step=5),2*o>t&&(this.scale="day",this.step=2),o>t&&(this.scale="day",this.step=1),o/2>t&&(this.scale="weekday",this.step=1),4*n>t&&(this.scale="hour",this.step=4),n>t&&(this.scale="hour",this.step=1),15*s>t&&(this.scale="minute",this.step=15),10*s>t&&(this.scale="minute",this.step=10),5*s>t&&(this.scale="minute",this.step=5),s>t&&(this.scale="minute",this.step=1),15*r>t&&(this.scale="second",this.step=15),10*r>t&&(this.scale="second",this.step=10),5*r>t&&(this.scale="second",this.step=5),r>t&&(this.scale="second",this.step=1),200*a>t&&(this.scale="millisecond",this.step=200),100*a>t&&(this.scale="millisecond",this.step=100),50*a>t&&(this.scale="millisecond",this.step=50),10*a>t&&(this.scale="millisecond",this.step=10),5*a>t&&(this.scale="millisecond",this.step=5),a>t&&(this.scale="millisecond",this.step=1)}},o.snap=function(t,e,i){var o=n(t);if("year"==e){var s=o.year()+Math.round(o.month()/12);o.year(Math.round(s/i)*i),o.month(0),o.date(0),o.hours(0),o.minutes(0),o.seconds(0),o.milliseconds(0)}else if("month"==e)o.date()>15?(o.date(1),o.add(1,"month")):o.date(1),o.hours(0),o.minutes(0),o.seconds(0),o.milliseconds(0);else if("day"==e){switch(i){case 5:case 2:o.hours(24*Math.round(o.hours()/24));break;default:o.hours(12*Math.round(o.hours()/12))}o.minutes(0),o.seconds(0),o.milliseconds(0)}else if("weekday"==e){switch(i){case 5:case 2:o.hours(12*Math.round(o.hours()/12));break;default:o.hours(6*Math.round(o.hours()/6))}o.minutes(0),o.seconds(0),o.milliseconds(0)}else if("hour"==e){switch(i){case 4:o.minutes(60*Math.round(o.minutes()/60));break;default:o.minutes(30*Math.round(o.minutes()/30))}o.seconds(0),o.milliseconds(0)}else if("minute"==e){switch(i){case 15:case 10:o.minutes(5*Math.round(o.minutes()/5)),o.seconds(0);break;case 5:o.seconds(60*Math.round(o.seconds()/60));break;default:o.seconds(30*Math.round(o.seconds()/30))}o.milliseconds(0)}else if("second"==e)switch(i){case 15:case 10:o.seconds(5*Math.round(o.seconds()/5)),o.milliseconds(0);break;case 5:o.milliseconds(1e3*Math.round(o.milliseconds()/1e3));break;default:o.milliseconds(500*Math.round(o.milliseconds()/500))}else if("millisecond"==e){var r=i>5?i/2:1;o.milliseconds(Math.round(o.milliseconds()/r)*r)}return o},o.prototype.isMajor=function(){if(1==this.switchedYear)switch(this.switchedYear=!1,this.scale){case"year":case"month":case"weekday":case"day":case"hour":case"minute":case"second":case"millisecond":return!0;default:return!1}else if(1==this.switchedMonth)switch(this.switchedMonth=!1,this.scale){case"weekday":case"day":case"hour":case"minute":case"second":case"millisecond":return!0;default:return!1}else if(1==this.switchedDay)switch(this.switchedDay=!1,this.scale){case"millisecond":case"second":case"minute":case"hour":return!0;default:return!1}var t=this.moment(this.current);switch(this.scale){case"millisecond":return 0==t.milliseconds();case"second":return 0==t.seconds();case"minute":return 0==t.hours()&&0==t.minutes();case"hour":return 0==t.hours();case"weekday":case"day":return 1==t.date();case"month":return 
0==t.month();case"year":return!1;default:return!1}},o.prototype.getLabelMinor=function(t){void 0==t&&(t=this.current);var e=this.format.minorLabels[this.scale];return e&&e.length>0?this.moment(t).format(e):""},o.prototype.getLabelMajor=function(t){void 0==t&&(t=this.current);var e=this.format.majorLabels[this.scale];return e&&e.length>0?this.moment(t).format(e):""},o.prototype.getClassName=function(){function t(t){return t/h%2==0?" vis-even":" vis-odd"}function e(t){return t.isSame(new Date,"day")?" vis-today":t.isSame(s().add(1,"day"),"day")?" vis-tomorrow":t.isSame(s().add(-1,"day"),"day")?" vis-yesterday":""}function i(t){return t.isSame(new Date,"week")?" vis-current-week":""}function o(t){return t.isSame(new Date,"month")?" vis-current-month":""}function n(t){return t.isSame(new Date,"year")?" vis-current-year":""}var s=this.moment,r=this.moment(this.current),a=r.locale?r.locale("en"):r.lang("en"),h=this.step;switch(this.scale){case"millisecond":return t(a.milliseconds()).trim();case"second":return t(a.seconds()).trim();case"minute":return t(a.minutes()).trim();case"hour":var d=a.hours();return 4==this.step&&(d=d+"-h"+(d+4)),"vis-h"+d+e(a)+t(a.hours());case"weekday":return"vis-"+a.format("dddd").toLowerCase()+e(a)+i(a)+t(a.date());case"day":var l=a.date(),c=a.format("MMMM").toLowerCase();return"vis-day"+l+" vis-"+c+o(a)+t(l-1);case"month":return"vis-"+a.format("MMMM").toLowerCase()+o(a)+t(a.month());case"year":var u=a.year();return"vis-year"+u+n(a)+t(u);default:return""}},t.exports=o},function(t,e,i){function o(t,e,i){this.groupId=t,this.subgroups={},this.subgroupIndex=0,this.subgroupOrderer=e&&e.subgroupOrder,this.itemSet=i,this.dom={},this.props={label:{width:0,height:0}},this.className=null,this.items={},this.visibleItems=[],this.orderedItems={byStart:[],byEnd:[]},this.checkRangedItems=!1;var o=this;this.itemSet.body.emitter.on("checkRangedItems",function(){o.checkRangedItems=!0}),this._create(),this.setData(e)}var n=i(1),s=i(37);i(38);o.prototype._create=function(){var t=document.createElement("div");this.itemSet.options.groupEditable.order?t.className="vis-label draggable":t.className="vis-label",this.dom.label=t;var e=document.createElement("div");e.className="vis-inner",t.appendChild(e),this.dom.inner=e;var i=document.createElement("div");i.className="vis-group",i["timeline-group"]=this,this.dom.foreground=i,this.dom.background=document.createElement("div"),this.dom.background.className="vis-group",this.dom.axis=document.createElement("div"),this.dom.axis.className="vis-group",this.dom.marker=document.createElement("div"),this.dom.marker.style.visibility="hidden",this.dom.marker.innerHTML="?",this.dom.background.appendChild(this.dom.marker)},o.prototype.setData=function(t){var e;if(e=this.itemSet.options&&this.itemSet.options.groupTemplate?this.itemSet.options.groupTemplate(t):t&&t.content,e instanceof Element){for(this.dom.inner.appendChild(e);this.dom.inner.firstChild;)this.dom.inner.removeChild(this.dom.inner.firstChild);this.dom.inner.appendChild(e)}else void 0!==e&&null!==e?this.dom.inner.innerHTML=e:this.dom.inner.innerHTML=this.groupId||"";this.dom.label.title=t&&t.title||"",this.dom.inner.firstChild?n.removeClassName(this.dom.inner,"vis-hidden"):n.addClassName(this.dom.inner,"vis-hidden");var 
i=t&&t.className||null;i!=this.className&&(this.className&&(n.removeClassName(this.dom.label,this.className),n.removeClassName(this.dom.foreground,this.className),n.removeClassName(this.dom.background,this.className),n.removeClassName(this.dom.axis,this.className)),n.addClassName(this.dom.label,i),n.addClassName(this.dom.foreground,i),n.addClassName(this.dom.background,i),n.addClassName(this.dom.axis,i),this.className=i),this.style&&(n.removeCssText(this.dom.label,this.style),this.style=null),t&&t.style&&(n.addCssText(this.dom.label,t.style),this.style=t.style)},o.prototype.getLabelWidth=function(){return this.props.label.width},o.prototype.redraw=function(t,e,i){var o=!1,r=this.dom.marker.clientHeight;if(r!=this.lastMarkerHeight&&(this.lastMarkerHeight=r,n.forEach(this.items,function(t){t.dirty=!0,t.displayed&&t.redraw()}),i=!0),this._calculateSubGroupHeights(),"function"==typeof this.itemSet.options.order){if(i){var a=this,h=!1;n.forEach(this.items,function(t){t.displayed||(t.redraw(),a.visibleItems.push(t)),t.repositionX(h)});var d=this.orderedItems.byStart.slice().sort(function(t,e){return a.itemSet.options.order(t.data,e.data)});s.stack(d,e,!0)}this.visibleItems=this._updateVisibleItems(this.orderedItems,this.visibleItems,t)}else this.visibleItems=this._updateVisibleItems(this.orderedItems,this.visibleItems,t),this.itemSet.options.stack?s.stack(this.visibleItems,e,i):s.nostack(this.visibleItems,e,this.subgroups);var l=this._calculateHeight(e),c=this.dom.foreground;this.top=c.offsetTop,this.right=c.offsetLeft,this.width=c.offsetWidth,o=n.updateProperty(this,"height",l)||o,o=n.updateProperty(this.props.label,"width",this.dom.inner.clientWidth)||o,o=n.updateProperty(this.props.label,"height",this.dom.inner.clientHeight)||o,this.dom.background.style.height=l+"px",this.dom.foreground.style.height=l+"px",this.dom.label.style.height=l+"px";for(var u=0,p=this.visibleItems.length;p>u;u++){var f=this.visibleItems[u];f.repositionY(e)}return o},o.prototype._calculateSubGroupHeights=function(){if(Object.keys(this.subgroups).length>0){var t=this;this.resetSubgroups(),n.forEach(this.visibleItems,function(e){void 0!==e.data.subgroup&&(t.subgroups[e.data.subgroup].height=Math.max(t.subgroups[e.data.subgroup].height,e.height),t.subgroups[e.data.subgroup].visible=!0)})}},o.prototype._calculateHeight=function(t){var e,i=this.visibleItems;if(i.length>0){var o=i[0].top,s=i[0].top+i[0].height;if(n.forEach(i,function(t){o=Math.min(o,t.top),s=Math.max(s,t.top+t.height)}),o>t.axis){var r=o-t.axis;s-=r,n.forEach(i,function(t){t.top-=r})}e=s+t.item.vertical/2}else e=0;return e=Math.max(e,this.props.label.height)},o.prototype.show=function(){this.dom.label.parentNode||this.itemSet.dom.labelSet.appendChild(this.dom.label),this.dom.foreground.parentNode||this.itemSet.dom.foreground.appendChild(this.dom.foreground),this.dom.background.parentNode||this.itemSet.dom.background.appendChild(this.dom.background),this.dom.axis.parentNode||this.itemSet.dom.axis.appendChild(this.dom.axis)},o.prototype.hide=function(){var t=this.dom.label;t.parentNode&&t.parentNode.removeChild(t);var e=this.dom.foreground;e.parentNode&&e.parentNode.removeChild(e);var i=this.dom.background;i.parentNode&&i.parentNode.removeChild(i);var o=this.dom.axis;o.parentNode&&o.parentNode.removeChild(o)},o.prototype.add=function(t){if(this.items[t.id]=t,t.setParent(this),void 0!==t.data.subgroup&&(void 
0===this.subgroups[t.data.subgroup]&&(this.subgroups[t.data.subgroup]={height:0,visible:!1,index:this.subgroupIndex,items:[]},this.subgroupIndex++),this.subgroups[t.data.subgroup].items.push(t)),this.orderSubgroups(),-1==this.visibleItems.indexOf(t)){var e=this.itemSet.body.range;this._checkIfVisible(t,this.visibleItems,e)}},o.prototype.orderSubgroups=function(){if(void 0!==this.subgroupOrderer){var t=[];if("string"==typeof this.subgroupOrderer){for(var e in this.subgroups)t.push({subgroup:e,sortField:this.subgroups[e].items[0].data[this.subgroupOrderer]});t.sort(function(t,e){return t.sortField-e.sortField})}else if("function"==typeof this.subgroupOrderer){for(var e in this.subgroups)t.push(this.subgroups[e].items[0].data);t.sort(this.subgroupOrderer)}if(t.length>0)for(var i=0;it?-1:l>=t?0:1};if(e.length>0)for(s=0;sl}),1==this.checkRangedItems)for(this.checkRangedItems=!1,s=0;sl})}for(s=0;s=0&&(s=e[r],!n(s));r--)void 0===o[s.id]&&(o[s.id]=!0,i.push(s));for(r=t+1;rn;n++)t[n].top=null;for(n=0,s=t.length;s>n;n++){var r=t[n];if(r.stack&&null===r.top){r.top=i.axis;do{for(var a=null,h=0,d=t.length;d>h;h++){var l=t[h];if(null!==l.top&&l!==r&&l.stack&&e.collision(r,l,i.item,l.options.rtl)){a=l;break}}null!=a&&(r.top=a.top+a.height+i.item.vertical)}while(a)}}},e.nostack=function(t,e,i){var o,n,s;for(o=0,n=t.length;n>o;o++)if(void 0!==t[o].data.subgroup){s=e.axis;for(var r in i)i.hasOwnProperty(r)&&1==i[r].visible&&i[r].indexe.right&&t.top-o.vertical+ie.top:t.left-o.horizontal+ie.left&&t.top-o.vertical+ie.top}},function(t,e,i){function o(t,e,i){if(this.props={content:{width:0}},this.overflow=!1,this.options=i,t){if(void 0==t.start)throw new Error('Property "start" missing in item '+t.id);if(void 0==t.end)throw new Error('Property "end" missing in item '+t.id)}n.call(this,t,e,i)}var n=(i(20),i(39));o.prototype=new n(null,null,null),o.prototype.baseClassName="vis-item vis-range",o.prototype.isVisible=function(t){return this.data.startt.start},o.prototype.redraw=function(){var t=this.dom;if(t||(this.dom={},t=this.dom,t.box=document.createElement("div"),t.frame=document.createElement("div"),t.frame.className="vis-item-overflow",t.box.appendChild(t.frame),t.content=document.createElement("div"),t.content.className="vis-item-content",t.frame.appendChild(t.content),t.box["timeline-item"]=this,this.dirty=!0),!this.parent)throw new Error("Cannot redraw item: no parent attached");if(!t.box.parentNode){var e=this.parent.dom.foreground;if(!e)throw new Error("Cannot redraw item: parent has no foreground container element");e.appendChild(t.box)}if(this.displayed=!0,this.dirty){this._updateContents(this.dom.content),this._updateTitle(this.dom.box),this._updateDataAttributes(this.dom.box),this._updateStyle(this.dom.box);var i=(this.options.editable.updateTime||this.options.editable.updateGroup||this.editable===!0)&&this.editable!==!1,o=(this.data.className?" "+this.data.className:"")+(this.selected?" vis-selected":"")+(i?" 
vis-editable":" vis-readonly");t.box.className=this.baseClassName+o,this.overflow="hidden"!==window.getComputedStyle(t.frame).overflow,this.dom.content.style.maxWidth="none",this.props.content.width=this.dom.content.offsetWidth,this.height=this.dom.box.offsetHeight,this.dom.content.style.maxWidth="",this.dirty=!1}this._repaintDeleteButton(t.box),this._repaintDragLeft(),this._repaintDragRight()},o.prototype.show=function(){this.displayed||this.redraw()},o.prototype.hide=function(){if(this.displayed){var t=this.dom.box;t.parentNode&&t.parentNode.removeChild(t),this.displayed=!1}},o.prototype.repositionX=function(t){var e,i,o=this.parent.width,n=this.conversion.toScreen(this.data.start),s=this.conversion.toScreen(this.data.end);void 0!==t&&t!==!0||(-o>n&&(n=-o),s>2*o&&(s=2*o));var r=Math.max(s-n,1);switch(this.overflow?(this.options.rtl?this.right=n:this.left=n,this.width=r+this.props.content.width,i=this.props.content.width):(this.options.rtl?this.right=n:this.left=n,this.width=r,i=Math.min(s-n,this.props.content.width)),this.options.rtl?this.dom.box.style.right=this.right+"px":this.dom.box.style.left=this.left+"px",this.dom.box.style.width=r+"px",this.options.align){case"left":this.options.rtl?this.dom.content.style.right="0":this.dom.content.style.left="0";break;case"right":this.options.rtl?this.dom.content.style.right=Math.max(r-i,0)+"px":this.dom.content.style.left=Math.max(r-i,0)+"px";break;case"center":this.options.rtl?this.dom.content.style.right=Math.max((r-i)/2,0)+"px":this.dom.content.style.left=Math.max((r-i)/2,0)+"px";break;default:e=this.overflow?s>0?Math.max(-n,0):-i:0>n?-n:0,this.options.rtl?this.dom.content.style.right=e+"px":this.dom.content.style.left=e+"px"}},o.prototype.repositionY=function(){var t=this.options.orientation.item,e=this.dom.box;"top"==t?e.style.top=this.top+"px":e.style.top=this.parent.height-this.top-this.height+"px"},o.prototype._repaintDragLeft=function(){if(this.selected&&this.options.editable.updateTime&&!this.dom.dragLeft){var t=document.createElement("div");t.className="vis-drag-left",t.dragLeftItem=this,this.dom.box.appendChild(t),this.dom.dragLeft=t}else!this.selected&&this.dom.dragLeft&&(this.dom.dragLeft.parentNode&&this.dom.dragLeft.parentNode.removeChild(this.dom.dragLeft),this.dom.dragLeft=null)},o.prototype._repaintDragRight=function(){if(this.selected&&this.options.editable.updateTime&&!this.dom.dragRight){var t=document.createElement("div");t.className="vis-drag-right",t.dragRightItem=this,this.dom.box.appendChild(t),this.dom.dragRight=t}else!this.selected&&this.dom.dragRight&&(this.dom.dragRight.parentNode&&this.dom.dragRight.parentNode.removeChild(this.dom.dragRight),this.dom.dragRight=null)},t.exports=o},function(t,e,i){function o(t,e,i){this.id=null,this.parent=null,this.data=t,this.dom=null,this.conversion=e||{},this.options=i||{},this.selected=!1,this.displayed=!1,this.dirty=!0,this.top=null,this.right=null,this.left=null,this.width=null,this.height=null,this.editable=null,this.data&&this.data.hasOwnProperty("editable")&&"boolean"==typeof this.data.editable&&(this.editable=t.editable)}var n=i(20),s=i(1);o.prototype.stack=!0,o.prototype.select=function(){this.selected=!0,this.dirty=!0,this.displayed&&this.redraw()},o.prototype.unselect=function(){this.selected=!1,this.dirty=!0,this.displayed&&this.redraw()},o.prototype.setData=function(t){var e=void 0!=t.group&&this.data.group!=t.group;e&&this.parent.itemSet._moveToGroup(this,t.group),t.hasOwnProperty("editable")&&"boolean"==typeof 
t.editable&&(this.editable=t.editable),this.data=t,this.dirty=!0,this.displayed&&this.redraw()},o.prototype.setParent=function(t){this.displayed?(this.hide(),this.parent=t,this.parent&&this.show()):this.parent=t},o.prototype.isVisible=function(t){return!1},o.prototype.show=function(){return!1},o.prototype.hide=function(){return!1},o.prototype.redraw=function(){},o.prototype.repositionX=function(){},o.prototype.repositionY=function(){},o.prototype._repaintDeleteButton=function(t){var e=(this.options.editable.remove||this.data.editable===!0)&&this.data.editable!==!1;if(this.selected&&e&&!this.dom.deleteButton){var i=this,o=document.createElement("div");this.options.rtl?o.className="vis-delete-rtl":o.className="vis-delete",o.title="Delete this item",new n(o).on("tap",function(t){t.stopPropagation(),i.parent.removeFromDataSet(i)}),t.appendChild(o),this.dom.deleteButton=o}else!this.selected&&this.dom.deleteButton&&(this.dom.deleteButton.parentNode&&this.dom.deleteButton.parentNode.removeChild(this.dom.deleteButton),this.dom.deleteButton=null)},o.prototype._updateContents=function(t){var e;if(this.options.template){var i=this.parent.itemSet.itemsData.get(this.id);e=this.options.template(i)}else e=this.data.content;var o=this._contentToString(this.content)!==this._contentToString(e);if(o){if(e instanceof Element)t.innerHTML="",t.appendChild(e);else if(void 0!=e)t.innerHTML=e;else if("background"!=this.data.type||void 0!==this.data.content)throw new Error('Property "content" missing in item '+this.id);this.content=e}},o.prototype._updateTitle=function(t){null!=this.data.title?t.title=this.data.title||"":t.removeAttribute("vis-title")},o.prototype._updateDataAttributes=function(t){if(this.options.dataAttributes&&this.options.dataAttributes.length>0){var e=[];if(Array.isArray(this.options.dataAttributes))e=this.options.dataAttributes;else{if("all"!=this.options.dataAttributes)return;e=Object.keys(this.data)}for(var i=0;in;n++){var r=this.visibleItems[n];r.repositionY(e)}return o},o.prototype.show=function(){this.dom.background.parentNode||this.itemSet.dom.background.appendChild(this.dom.background)},t.exports=o},function(t,e,i){function o(t,e,i){if(this.props={dot:{width:0,height:0},line:{width:0,height:0}},this.options=i,t&&void 0==t.start)throw new Error('Property "start" missing in item '+t);n.call(this,t,e,i)}var n=i(39);i(1);o.prototype=new n(null,null,null),o.prototype.isVisible=function(t){var e=(t.end-t.start)/4;return this.data.start>t.start-e&&this.data.startt.start-e&&this.data.startt.start},o.prototype.redraw=function(){var t=this.dom;if(t||(this.dom={},t=this.dom,t.box=document.createElement("div"),t.frame=document.createElement("div"),t.frame.className="vis-item-overflow",t.box.appendChild(t.frame),t.content=document.createElement("div"),t.content.className="vis-item-content",t.frame.appendChild(t.content),this.dirty=!0),!this.parent)throw new Error("Cannot redraw item: no parent attached");if(!t.box.parentNode){var e=this.parent.dom.background;if(!e)throw new Error("Cannot redraw item: parent has no background container element");e.appendChild(t.box)}if(this.displayed=!0,this.dirty){this._updateContents(this.dom.content),this._updateTitle(this.dom.content),this._updateDataAttributes(this.dom.content),this._updateStyle(this.dom.box);var i=(this.data.className?" "+this.data.className:"")+(this.selected?" 
vis-selected":"");t.box.className=this.baseClassName+i,this.overflow="hidden"!==window.getComputedStyle(t.content).overflow,this.props.content.width=this.dom.content.offsetWidth,this.height=0,this.dirty=!1}},o.prototype.show=r.prototype.show,o.prototype.hide=r.prototype.hide,o.prototype.repositionX=r.prototype.repositionX,o.prototype.repositionY=function(t){var e="top"===this.options.orientation.item;this.dom.content.style.top=e?"":"0",this.dom.content.style.bottom=e?"0":"";var i;if(void 0!==this.data.subgroup){var o=this.data.subgroup,n=this.parent.subgroups,r=n[o].index;if(1==e){i=this.parent.subgroups[o].height+t.item.vertical,i+=0==r?t.axis-.5*t.item.vertical:0;var a=this.parent.top;for(var h in n)n.hasOwnProperty(h)&&1==n[h].visible&&n[h].indexr&&(a+=l)}i=this.parent.subgroups[o].height+t.item.vertical,this.dom.box.style.top=this.parent.height-d+a+"px",this.dom.box.style.bottom=""}}else this.parent instanceof s?(i=Math.max(this.parent.height,this.parent.itemSet.body.domProps.center.height,this.parent.itemSet.body.domProps.centerContainer.height),this.dom.box.style.top=e?"0":"",this.dom.box.style.bottom=e?"":"0"):(i=this.parent.height,this.dom.box.style.top=this.parent.top+"px",this.dom.box.style.bottom="");this.dom.box.style.height=i+"px"},t.exports=o},function(t,e,i){function o(t,e){this.dom={foreground:null,lines:[],majorTexts:[],minorTexts:[],redundant:{lines:[],majorTexts:[],minorTexts:[]}},this.props={range:{start:0,end:0,minimumStep:0},lineTop:0},this.defaultOptions={orientation:{axis:"bottom"},showMinorLabels:!0,showMajorLabels:!0,maxMinorChars:7,format:a.FORMAT,moment:d,timeAxis:null},this.options=s.extend({},this.defaultOptions),this.body=t,this._create(),this.setOptions(e)}var n="function"==typeof Symbol&&"symbol"==typeof Symbol.iterator?function(t){return typeof t}:function(t){return t&&"function"==typeof Symbol&&t.constructor===Symbol?"symbol":typeof t},s=i(1),r=i(31),a=i(35),h=i(32),d=i(2);o.prototype=new r,o.prototype.setOptions=function(t){t&&(s.selectiveExtend(["showMinorLabels","showMajorLabels","maxMinorChars","hiddenDates","timeAxis","moment","rtl"],this.options,t),s.selectiveDeepExtend(["format"],this.options,t),"orientation"in t&&("string"==typeof t.orientation?this.options.orientation.axis=t.orientation:"object"===n(t.orientation)&&"axis"in t.orientation&&(this.options.orientation.axis=t.orientation.axis)),"locale"in t&&("function"==typeof d.locale?d.locale(t.locale):d.lang(t.locale)))},o.prototype._create=function(){this.dom.foreground=document.createElement("div"),this.dom.background=document.createElement("div"),this.dom.foreground.className="vis-time-axis vis-foreground",this.dom.background.className="vis-time-axis vis-background"},o.prototype.destroy=function(){this.dom.foreground.parentNode&&this.dom.foreground.parentNode.removeChild(this.dom.foreground),this.dom.background.parentNode&&this.dom.background.parentNode.removeChild(this.dom.background),this.body=null},o.prototype.redraw=function(){var t=this.props,e=this.dom.foreground,i=this.dom.background,o="top"==this.options.orientation.axis?this.body.dom.top:this.body.dom.bottom,n=e.parentNode!==o;this._calculateCharSize();var 
s=this.options.showMinorLabels&&"none"!==this.options.orientation.axis,r=this.options.showMajorLabels&&"none"!==this.options.orientation.axis;t.minorLabelHeight=s?t.minorCharHeight:0,t.majorLabelHeight=r?t.majorCharHeight:0,t.height=t.minorLabelHeight+t.majorLabelHeight,t.width=e.offsetWidth,t.minorLineHeight=this.body.domProps.root.height-t.majorLabelHeight-("top"==this.options.orientation.axis?this.body.domProps.bottom.height:this.body.domProps.top.height),t.minorLineWidth=1,t.majorLineHeight=t.minorLineHeight+t.majorLabelHeight,t.majorLineWidth=1;var a=e.nextSibling,h=i.nextSibling;return e.parentNode&&e.parentNode.removeChild(e),i.parentNode&&i.parentNode.removeChild(i),e.style.height=this.props.height+"px",this._repaintLabels(),a?o.insertBefore(e,a):o.appendChild(e),h?this.body.dom.backgroundVertical.insertBefore(i,h):this.body.dom.backgroundVertical.appendChild(i),this._isResized()||n},o.prototype._repaintLabels=function(){var t=this.options.orientation.axis,e=s.convert(this.body.range.start,"Number"),i=s.convert(this.body.range.end,"Number"),o=this.body.util.toTime((this.props.minorCharWidth||10)*this.options.maxMinorChars).valueOf(),n=o-h.getHiddenDurationBefore(this.options.moment,this.body.hiddenDates,this.body.range,o);n-=this.body.util.toTime(0).valueOf();var r=new a(new Date(e),new Date(i),n,this.body.hiddenDates);r.setMoment(this.options.moment),this.options.format&&r.setFormat(this.options.format),this.options.timeAxis&&r.setScale(this.options.timeAxis),this.step=r;var d=this.dom;d.redundant.lines=d.lines,d.redundant.majorTexts=d.majorTexts,d.redundant.minorTexts=d.minorTexts,d.lines=[],d.majorTexts=[],d.minorTexts=[];var c,u,p,f,m,v,g,y,b,w,_=0,x=void 0,k=0,O=1e3;for(r.start(),u=r.getCurrent(),f=this.body.util.toScreen(u);r.hasNext()&&O>k;){k++,m=r.isMajor(),w=r.getClassName(),b=r.getLabelMinor(),c=u,p=f,r.next(),u=r.getCurrent(),v=r.isMajor(),f=this.body.util.toScreen(u),g=_,_=f-p;var M=_>=.4*g;if(this.options.showMinorLabels&&M){var D=this._repaintMinorText(p,b,t,w);D.style.width=_+"px"}m&&this.options.showMajorLabels?(p>0&&(void 0==x&&(x=p),D=this._repaintMajorText(p,r.getLabelMajor(),t,w)),y=this._repaintMajorLine(p,_,t,w)):M?y=this._repaintMinorLine(p,_,t,w):y&&(y.style.width=parseInt(y.style.width)+_+"px")}if(k!==O||l||(console.warn("Something is wrong with the Timeline scale. 
Limited drawing of grid lines to "+O+" lines."),l=!0),this.options.showMajorLabels){var S=this.body.util.toTime(0),C=r.getLabelMajor(S),T=C.length*(this.props.majorCharWidth||10)+10;(void 0==x||x>T)&&this._repaintMajorText(0,C,t,w)}s.forEach(this.dom.redundant,function(t){for(;t.length;){var e=t.pop();e&&e.parentNode&&e.parentNode.removeChild(e)}})},o.prototype._repaintMinorText=function(t,e,i,o){var n=this.dom.redundant.minorTexts.shift();if(!n){var s=document.createTextNode("");n=document.createElement("div"),n.appendChild(s),this.dom.foreground.appendChild(n)}return this.dom.minorTexts.push(n),n.childNodes[0].nodeValue=e,n.style.top="top"==i?this.props.majorLabelHeight+"px":"0",this.options.rtl?(n.style.left="",n.style.right=t+"px"):n.style.left=t+"px",n.className="vis-text vis-minor "+o,n},o.prototype._repaintMajorText=function(t,e,i,o){var n=this.dom.redundant.majorTexts.shift();if(!n){var s=document.createTextNode(e);n=document.createElement("div"),n.appendChild(s),this.dom.foreground.appendChild(n)}return this.dom.majorTexts.push(n),n.childNodes[0].nodeValue=e,n.className="vis-text vis-major "+o,n.style.top="top"==i?"0":this.props.minorLabelHeight+"px",this.options.rtl?(n.style.left="",n.style.right=t+"px"):n.style.left=t+"px",n},o.prototype._repaintMinorLine=function(t,e,i,o){var n=this.dom.redundant.lines.shift();n||(n=document.createElement("div"),this.dom.background.appendChild(n)),this.dom.lines.push(n);var s=this.props;return"top"==i?n.style.top=s.majorLabelHeight+"px":n.style.top=this.body.domProps.top.height+"px",n.style.height=s.minorLineHeight+"px",this.options.rtl?(n.style.left="",n.style.right=t-s.minorLineWidth/2+"px",n.className="vis-grid vis-vertical-rtl vis-minor "+o):(n.style.left=t-s.minorLineWidth/2+"px",n.className="vis-grid vis-vertical vis-minor "+o),n.style.width=e+"px",n},o.prototype._repaintMajorLine=function(t,e,i,o){var n=this.dom.redundant.lines.shift();n||(n=document.createElement("div"),this.dom.background.appendChild(n)),this.dom.lines.push(n);var s=this.props;return"top"==i?n.style.top="0":n.style.top=this.body.domProps.top.height+"px",this.options.rtl?(n.style.left="",n.style.right=t-s.majorLineWidth/2+"px",n.className="vis-grid vis-vertical-rtl vis-major "+o):(n.style.left=t-s.majorLineWidth/2+"px",n.className="vis-grid vis-vertical vis-major "+o),n.style.height=s.majorLineHeight+"px",n.style.width=e+"px",n},o.prototype._calculateCharSize=function(){this.dom.measureCharMinor||(this.dom.measureCharMinor=document.createElement("DIV"),this.dom.measureCharMinor.className="vis-text vis-minor vis-measure",this.dom.measureCharMinor.style.position="absolute",this.dom.measureCharMinor.appendChild(document.createTextNode("0")),this.dom.foreground.appendChild(this.dom.measureCharMinor)),this.props.minorCharHeight=this.dom.measureCharMinor.clientHeight,this.props.minorCharWidth=this.dom.measureCharMinor.clientWidth,this.dom.measureCharMajor||(this.dom.measureCharMajor=document.createElement("DIV"),this.dom.measureCharMajor.className="vis-text vis-major vis-measure",this.dom.measureCharMajor.style.position="absolute",this.dom.measureCharMajor.appendChild(document.createTextNode("0")),this.dom.foreground.appendChild(this.dom.measureCharMajor)),this.props.majorCharHeight=this.dom.measureCharMajor.clientHeight,this.props.majorCharWidth=this.dom.measureCharMajor.clientWidth};var l=!1;t.exports=o},function(t,e,i){function 
o(t){this.active=!1,this.dom={container:t},this.dom.overlay=document.createElement("div"),this.dom.overlay.className="vis-overlay",this.dom.container.appendChild(this.dom.overlay),this.hammer=a(this.dom.overlay),this.hammer.on("tap",this._onTapOverlay.bind(this));var e=this,i=["tap","doubletap","press","pinch","pan","panstart","panmove","panend"];i.forEach(function(t){e.hammer.on(t,function(t){t.stopPropagation()})}),document&&document.body&&(this.onClick=function(i){n(i.target,t)||e.deactivate()},document.body.addEventListener("click",this.onClick)),void 0!==this.keycharm&&this.keycharm.destroy(),this.keycharm=s(),this.escListener=this.deactivate.bind(this)}function n(t,e){for(;t;){if(t===e)return!0;t=t.parentNode}return!1}var s=i(23),r=i(13),a=i(20),h=i(1);r(o.prototype),o.current=null,o.prototype.destroy=function(){this.deactivate(),this.dom.overlay.parentNode.removeChild(this.dom.overlay),this.onClick&&document.body.removeEventListener("click",this.onClick),this.hammer.destroy(),this.hammer=null},o.prototype.activate=function(){o.current&&o.current.deactivate(),o.current=this,this.active=!0,this.dom.overlay.style.display="none",h.addClassName(this.dom.container,"vis-active"),this.emit("change"),this.emit("activate"),this.keycharm.bind("esc",this.escListener)},o.prototype.deactivate=function(){this.active=!1,this.dom.overlay.style.display="",h.removeClassName(this.dom.container,"vis-active"),this.keycharm.unbind("esc",this.escListener),this.emit("change"),this.emit("deactivate")},o.prototype._onTapOverlay=function(t){this.activate(),t.stopPropagation()},t.exports=o},function(t,e,i){function o(t,e){this.body=t,this.defaultOptions={moment:a,locales:h,locale:"en",id:void 0,title:void 0},this.options=s.extend({},this.defaultOptions),e&&e.time?this.customTime=e.time:this.customTime=new Date,this.eventParams={},this.setOptions(e),this._create()}var n=i(20),s=i(1),r=i(31),a=i(2),h=i(47);o.prototype=new r,o.prototype.setOptions=function(t){t&&s.selectiveExtend(["moment","locale","locales","id"],this.options,t)},o.prototype._create=function(){var t=document.createElement("div");t["custom-time"]=this,t.className="vis-custom-time "+(this.options.id||""),t.style.position="absolute",t.style.top="0px",t.style.height="100%",this.bar=t;var e=document.createElement("div");e.style.position="relative",e.style.top="0px",e.style.left="-10px",e.style.height="100%",e.style.width="20px",t.appendChild(e),this.hammer=new n(e),this.hammer.on("panstart",this._onDragStart.bind(this)),this.hammer.on("panmove",this._onDrag.bind(this)),this.hammer.on("panend",this._onDragEnd.bind(this)),this.hammer.get("pan").set({threshold:5,direction:n.DIRECTION_HORIZONTAL})},o.prototype.destroy=function(){this.hide(),this.hammer.destroy(),this.hammer=null,this.body=null},o.prototype.redraw=function(){var t=this.body.dom.backgroundVertical;this.bar.parentNode!=t&&(this.bar.parentNode&&this.bar.parentNode.removeChild(this.bar),t.appendChild(this.bar));var e=this.body.util.toScreen(this.customTime),i=this.options.locales[this.options.locale];i||(this.warned||(console.log("WARNING: options.locales['"+this.options.locale+"'] not found. 
See http://visjs.org/docs/timeline.html#Localization"),this.warned=!0),i=this.options.locales.en);var o=this.options.title;return void 0===o&&(o=i.time+": "+this.options.moment(this.customTime).format("dddd, MMMM Do YYYY, H:mm:ss"),o=o.charAt(0).toUpperCase()+o.substring(1)),this.bar.style.left=e+"px",this.bar.title=o,!1},o.prototype.hide=function(){this.bar.parentNode&&this.bar.parentNode.removeChild(this.bar)},o.prototype.setCustomTime=function(t){this.customTime=s.convert(t,"Date"),this.redraw()},o.prototype.getCustomTime=function(){return new Date(this.customTime.valueOf())},o.prototype.setCustomTitle=function(t){this.options.title=t},o.prototype._onDragStart=function(t){this.eventParams.dragging=!0,this.eventParams.customTime=this.customTime,t.stopPropagation()},o.prototype._onDrag=function(t){if(this.eventParams.dragging){var e=this.body.util.toScreen(this.eventParams.customTime)+t.deltaX,i=this.body.util.toTime(e);this.setCustomTime(i),this.body.emitter.emit("timechange",{id:this.options.id,time:new Date(this.customTime.valueOf()) +}),t.stopPropagation()}},o.prototype._onDragEnd=function(t){this.eventParams.dragging&&(this.body.emitter.emit("timechanged",{id:this.options.id,time:new Date(this.customTime.valueOf())}),t.stopPropagation())},o.customTimeFromTarget=function(t){for(var e=t.target;e;){if(e.hasOwnProperty("custom-time"))return e["custom-time"];e=e.parentNode}return null},t.exports=o},function(t,e){e.en={current:"current",time:"time"},e.en_EN=e.en,e.en_US=e.en,e.nl={current:"huidige",time:"tijd"},e.nl_NL=e.nl,e.nl_BE=e.nl},function(t,e,i){function o(t,e){this.body=t,this.defaultOptions={rtl:!1,showCurrentTime:!0,moment:r,locales:a,locale:"en"},this.options=n.extend({},this.defaultOptions),this.offset=0,this._create(),this.setOptions(e)}var n=i(1),s=i(31),r=i(2),a=i(47);o.prototype=new s,o.prototype._create=function(){var t=document.createElement("div");t.className="vis-current-time",t.style.position="absolute",t.style.top="0px",t.style.height="100%",this.bar=t},o.prototype.destroy=function(){this.options.showCurrentTime=!1,this.redraw(),this.body=null},o.prototype.setOptions=function(t){t&&n.selectiveExtend(["rtl","showCurrentTime","moment","locale","locales"],this.options,t)},o.prototype.redraw=function(){if(this.options.showCurrentTime){var t=this.body.dom.backgroundVertical;this.bar.parentNode!=t&&(this.bar.parentNode&&this.bar.parentNode.removeChild(this.bar),t.appendChild(this.bar),this.start());var e=this.options.moment((new Date).valueOf()+this.offset),i=this.body.util.toScreen(e),o=this.options.locales[this.options.locale];o||(this.warned||(console.log("WARNING: options.locales['"+this.options.locale+"'] not found. 
See http://visjs.org/docs/timeline/#Localization"),this.warned=!0),o=this.options.locales.en);var n=o.current+" "+o.time+": "+e.format("dddd, MMMM Do YYYY, H:mm:ss");n=n.charAt(0).toUpperCase()+n.substring(1),this.options.rtl?this.bar.style.right=i+"px":this.bar.style.left=i+"px",this.bar.title=n}else this.bar.parentNode&&this.bar.parentNode.removeChild(this.bar),this.stop();return!1},o.prototype.start=function(){function t(){e.stop();var i=e.body.range.conversion(e.body.domProps.center.width).scale,o=1/i/10;30>o&&(o=30),o>1e3&&(o=1e3),e.redraw(),e.body.emitter.emit("currentTimeTick"),e.currentTimeTimer=setTimeout(t,o)}var e=this;t()},o.prototype.stop=function(){void 0!==this.currentTimeTimer&&(clearTimeout(this.currentTimeTimer),delete this.currentTimeTimer)},o.prototype.setCurrentTime=function(t){var e=n.convert(t,"Date").valueOf(),i=(new Date).valueOf();this.offset=e-i,this.redraw()},o.prototype.getCurrentTime=function(){return new Date((new Date).valueOf()+this.offset)},t.exports=o},function(t,e){Object.defineProperty(e,"__esModule",{value:!0});var i="string",o="boolean",n="number",s="array",r="date",a="object",h="dom",d="moment",l="any",c={configure:{enabled:{"boolean":o},filter:{"boolean":o,"function":"function"},container:{dom:h},__type__:{object:a,"boolean":o,"function":"function"}},align:{string:i},rtl:{"boolean":o,undefined:"undefined"},autoResize:{"boolean":o},throttleRedraw:{number:n},clickToUse:{"boolean":o},dataAttributes:{string:i,array:s},editable:{add:{"boolean":o,undefined:"undefined"},remove:{"boolean":o,undefined:"undefined"},updateGroup:{"boolean":o,undefined:"undefined"},updateTime:{"boolean":o,undefined:"undefined"},__type__:{"boolean":o,object:a}},end:{number:n,date:r,string:i,moment:d},format:{minorLabels:{millisecond:{string:i,undefined:"undefined"},second:{string:i,undefined:"undefined"},minute:{string:i,undefined:"undefined"},hour:{string:i,undefined:"undefined"},weekday:{string:i,undefined:"undefined"},day:{string:i,undefined:"undefined"},month:{string:i,undefined:"undefined"},year:{string:i,undefined:"undefined"},__type__:{object:a}},majorLabels:{millisecond:{string:i,undefined:"undefined"},second:{string:i,undefined:"undefined"},minute:{string:i,undefined:"undefined"},hour:{string:i,undefined:"undefined"},weekday:{string:i,undefined:"undefined"},day:{string:i,undefined:"undefined"},month:{string:i,undefined:"undefined"},year:{string:i,undefined:"undefined"},__type__:{object:a}},__type__:{object:a}},moment:{"function":"function"},groupOrder:{string:i,"function":"function"},groupEditable:{add:{"boolean":o,undefined:"undefined"},remove:{"boolean":o,undefined:"undefined"},order:{"boolean":o,undefined:"undefined"},__type__:{"boolean":o,object:a}},groupOrderSwap:{"function":"function"},height:{string:i,number:n},hiddenDates:{start:{date:r,number:n,string:i,moment:d},end:{date:r,number:n,string:i,moment:d},repeat:{string:i},__type__:{object:a,array:s}},itemsAlwaysDraggable:{"boolean":o},locale:{string:i},locales:{__any__:{any:l},__type__:{object:a}},margin:{axis:{number:n},item:{horizontal:{number:n,undefined:"undefined"},vertical:{number:n,undefined:"undefined"},__type__:{object:a,number:n}},__type__:{object:a,number:n}},max:{date:r,number:n,string:i,moment:d},maxHeight:{number:n,string:i},maxMinorChars:{number:n},min:{date:r,number:n,string:i,moment:d},minHeight:{number:n,string:i},moveable:{"boolean":o},multiselect:{"boolean":o},multiselectPerGroup:{"boolean":o},onAdd:{"function":"function"},onUpdate:{"function":"function"},onMove:{"function":"function"},onMoving:
{"function":"function"},onRemove:{"function":"function"},onAddGroup:{"function":"function"},onMoveGroup:{"function":"function"},onRemoveGroup:{"function":"function"},order:{"function":"function"},orientation:{axis:{string:i,undefined:"undefined"},item:{string:i,undefined:"undefined"},__type__:{string:i,object:a}},selectable:{"boolean":o},showCurrentTime:{"boolean":o},showMajorLabels:{"boolean":o},showMinorLabels:{"boolean":o},stack:{"boolean":o},snap:{"function":"function","null":"null"},start:{date:r,number:n,string:i,moment:d},template:{"function":"function"},groupTemplate:{"function":"function"},timeAxis:{scale:{string:i,undefined:"undefined"},step:{number:n,undefined:"undefined"},__type__:{object:a}},type:{string:i},width:{string:i,number:n},zoomable:{"boolean":o},zoomKey:{string:["ctrlKey","altKey","metaKey",""]},zoomMax:{number:n},zoomMin:{number:n},__type__:{object:a}},u={global:{align:["center","left","right"],direction:!1,autoResize:!0,throttleRedraw:[10,0,1e3,10],clickToUse:!1,editable:{add:!1,remove:!1,updateGroup:!1,updateTime:!1},end:"",format:{minorLabels:{millisecond:"SSS",second:"s",minute:"HH:mm",hour:"HH:mm",weekday:"ddd D",day:"D",month:"MMM",year:"YYYY"},majorLabels:{millisecond:"HH:mm:ss",second:"D MMMM HH:mm",minute:"ddd D MMMM",hour:"ddd D MMMM",weekday:"MMMM YYYY",day:"MMMM YYYY",month:"YYYY",year:""}},groupsDraggable:!1,height:"",locale:"",margin:{axis:[20,0,100,1],item:{horizontal:[10,0,100,1],vertical:[10,0,100,1]}},max:"",maxHeight:"",maxMinorChars:[7,0,20,1],min:"",minHeight:"",moveable:!1,multiselect:!1,multiselectPerGroup:!1,orientation:{axis:["both","bottom","top"],item:["bottom","top"]},selectable:!0,showCurrentTime:!1,showMajorLabels:!0,showMinorLabels:!0,stack:!0,start:"",type:["box","point","range","background"],width:"100%",zoomable:!0,zoomKey:["ctrlKey","altKey","metaKey",""],zoomMax:[31536e10,10,31536e10,1],zoomMin:[10,10,31536e10,1]}};e.allOptions=c,e.configureOptions=u},function(t,e,i){function o(t){return t&&t.__esModule?t:{"default":t}}function n(t,e,i,o){if(!(Array.isArray(i)||i instanceof c||i instanceof u)&&i instanceof Object){var n=o;o=i,i=n}var s=this;this.defaultOptions={start:null,end:null,autoResize:!0,orientation:{axis:"bottom",item:"bottom"},moment:d,width:null,height:null,maxHeight:null,minHeight:null},this.options=l.deepExtend({},this.defaultOptions),this._create(t),this.components=[],this.body={dom:this.dom,domProps:this.props,emitter:{on:this.on.bind(this),off:this.off.bind(this),emit:this.emit.bind(this)},hiddenDates:[],util:{toScreen:s._toScreen.bind(s),toGlobalScreen:s._toGlobalScreen.bind(s),toTime:s._toTime.bind(s),toGlobalTime:s._toGlobalTime.bind(s)}},this.range=new p(this.body),this.components.push(this.range),this.body.range=this.range,this.timeAxis=new m(this.body),this.components.push(this.timeAxis),this.currentTime=new v(this.body),this.components.push(this.currentTime),this.linegraph=new y(this.body),this.components.push(this.linegraph),this.itemsData=null,this.groupsData=null,this.on("tap",function(t){s.emit("click",s.getEventProperties(t))}),this.on("doubletap",function(t){s.emit("doubleClick",s.getEventProperties(t))}),this.dom.root.oncontextmenu=function(t){s.emit("contextmenu",s.getEventProperties(t))},o&&this.setOptions(o),i&&this.setGroups(i),e&&this.setItems(e),this._redraw()}var s=i(26),r=o(s),a=i(29),h=o(a),d=(i(13),i(20),i(2)),l=i(1),c=i(9),u=i(11),p=i(30),f=i(33),m=i(44),v=i(48),g=i(46),y=i(51),b=i(29).printStyle,w=i(59).allOptions,_=i(59).configureOptions;n.prototype=new 
f,n.prototype.setOptions=function(t){var e=h["default"].validate(t,w);e===!0&&console.log("%cErrors have been found in the supplied options object.",b),f.prototype.setOptions.call(this,t)},n.prototype.setItems=function(t){var e,i=null==this.itemsData;if(e=t?t instanceof c||t instanceof u?t:new c(t,{type:{start:"Date",end:"Date"}}):null,this.itemsData=e,this.linegraph&&this.linegraph.setItems(e),i)if(void 0!=this.options.start||void 0!=this.options.end){var o=void 0!=this.options.start?this.options.start:null,n=void 0!=this.options.end?this.options.end:null;this.setWindow(o,n,{animation:!1})}else this.fit({animation:!1})},n.prototype.setGroups=function(t){var e;e=t?t instanceof c||t instanceof u?t:new c(t):null,this.groupsData=e,this.linegraph.setGroups(e)},n.prototype.getLegend=function(t,e,i){return void 0===e&&(e=15),void 0===i&&(i=15),void 0!==this.linegraph.groups[t]?this.linegraph.groups[t].getLegend(e,i):"cannot find group:'"+t+"'"},n.prototype.isGroupVisible=function(t){return void 0!==this.linegraph.groups[t]?this.linegraph.groups[t].visible&&(void 0===this.linegraph.options.groups.visibility[t]||1==this.linegraph.options.groups.visibility[t]):!1},n.prototype.getDataRange=function(){var t=null,e=null;for(var i in this.linegraph.groups)if(this.linegraph.groups.hasOwnProperty(i)&&1==this.linegraph.groups[i].visible)for(var o=0;os?s:t,e=null==e?s:s>e?s:e}return{min:null!=t?new Date(t):null,max:null!=e?new Date(e):null}},n.prototype.getEventProperties=function(t){var e=t.center?t.center.x:t.clientX,i=t.center?t.center.y:t.clientY,o=e-l.getAbsoluteLeft(this.dom.centerContainer),n=i-l.getAbsoluteTop(this.dom.centerContainer),s=this._toTime(o),r=g.customTimeFromTarget(t),a=l.getTarget(t),h=null;l.hasParent(a,this.timeAxis.dom.foreground)?h="axis":this.timeAxis2&&l.hasParent(a,this.timeAxis2.dom.foreground)?h="axis":l.hasParent(a,this.linegraph.yAxisLeft.dom.frame)?h="data-axis":l.hasParent(a,this.linegraph.yAxisRight.dom.frame)?h="data-axis":l.hasParent(a,this.linegraph.legendLeft.dom.frame)?h="legend":l.hasParent(a,this.linegraph.legendRight.dom.frame)?h="legend":null!=r?h="custom-time":l.hasParent(a,this.currentTime.bar)?h="current-time":l.hasParent(a,this.dom.center)&&(h="background");var d=[],c=this.linegraph.yAxisLeft,u=this.linegraph.yAxisRight;return c.hidden||d.push(c.screenToValue(n)),u.hidden||d.push(u.screenToValue(n)),{event:t,what:h,pageX:t.srcEvent?t.srcEvent.pageX:t.pageX,pageY:t.srcEvent?t.srcEvent.pageY:t.pageY,x:o,y:n,time:s,value:d}},n.prototype._createConfigurator=function(){return new r["default"](this,this.dom.container,_)},t.exports=n},function(t,e,i){function o(t,e){this.id=s.randomUUID(),this.body=t,this.defaultOptions={yAxisOrientation:"left",defaultGroup:"default",sort:!0,sampling:!0,stack:!1,graphHeight:"400px",shaded:{enabled:!1,orientation:"bottom"},style:"line",barChart:{width:50,sideBySide:!1,align:"center"},interpolation:{enabled:!0,parametrization:"centripetal",alpha:.5},drawPoints:{enabled:!0,size:6,style:"square"},dataAxis:{},legend:{},groups:{visibility:{}}},this.options=s.extend({},this.defaultOptions),this.dom={},this.props={},this.hammer=null,this.groups={},this.abortedGraphUpdate=!1,this.updateSVGheight=!1,this.updateSVGheightOnResize=!1,this.forceGraphUpdate=!0;var 
i=this;this.itemsData=null,this.groupsData=null,this.itemListeners={add:function(t,e,o){i._onAdd(e.items)},update:function(t,e,o){i._onUpdate(e.items)},remove:function(t,e,o){i._onRemove(e.items)}},this.groupListeners={add:function(t,e,o){i._onAddGroups(e.items)},update:function(t,e,o){i._onUpdateGroups(e.items)},remove:function(t,e,o){i._onRemoveGroups(e.items)}},this.items={},this.selection=[],this.lastStart=this.body.range.start,this.touchParams={},this.svgElements={},this.setOptions(e),this.groupsUsingDefaultStyles=[0],this.body.emitter.on("rangechanged",function(){i.lastStart=i.body.range.start,i.svg.style.left=s.option.asSize(-i.props.width),i.forceGraphUpdate=!0,i.redraw.call(i)}),this._create(),this.framework={svg:this.svg,svgElements:this.svgElements,options:this.options,groups:this.groups}}var n="function"==typeof Symbol&&"symbol"==typeof Symbol.iterator?function(t){return typeof t}:function(t){return t&&"function"==typeof Symbol&&t.constructor===Symbol?"symbol":typeof t},s=i(1),r=i(8),a=i(9),h=i(11),d=i(31),l=i(52),c=i(54),u=i(58),p=i(55),f=i(57),m=i(56),v="__ungrouped__";o.prototype=new d,o.prototype._create=function(){var t=document.createElement("div");t.className="vis-line-graph",this.dom.frame=t,this.svg=document.createElementNS("http://www.w3.org/2000/svg","svg"),this.svg.style.position="relative",this.svg.style.height=(""+this.options.graphHeight).replace("px","")+"px",this.svg.style.display="block",t.appendChild(this.svg),this.options.dataAxis.orientation="left",this.yAxisLeft=new l(this.body,this.options.dataAxis,this.svg,this.options.groups),this.options.dataAxis.orientation="right",this.yAxisRight=new l(this.body,this.options.dataAxis,this.svg,this.options.groups),delete this.options.dataAxis.orientation,this.legendLeft=new u(this.body,this.options.legend,"left",this.options.groups),this.legendRight=new u(this.body,this.options.legend,"right",this.options.groups),this.show()},o.prototype.setOptions=function(t){if(t){var e=["sampling","defaultGroup","stack","height","graphHeight","yAxisOrientation","style","barChart","dataAxis","sort","groups"];void 0===t.graphHeight&&void 0!==t.height?(this.updateSVGheight=!0,this.updateSVGheightOnResize=!0):void 0!==this.body.domProps.centerContainer.height&&void 0!==t.graphHeight&&parseInt((t.graphHeight+"").replace("px",""))i?-1:1});for(var o=new Array(t.length),n=0;n0){var h={};for(this._getRelevantData(a,h,n,s),this._applySampling(a,h),e=0;e0)switch(t.options.style){case"line":l.hasOwnProperty(a[e])||(l[a[e]]=f.calcPath(h[a[e]],t)),f.draw(l[a[e]],t,this.framework);case"point":case"points":"point"!=t.options.style&&"points"!=t.options.style&&1!=t.options.drawPoints.enabled||m.draw(h[a[e]],t,this.framework);break;case"bar":}}}return r.cleanupElements(this.svgElements),!1},o.prototype._stack=function(t,e){var i,o,n,s,r;i=0;for(var a=0;at[a].x){r=e[h],s=0==h?r:e[h-1],i=h;break}}void 0===r&&(s=e[e.length-1],r=e[e.length-1]),o=r.x-s.x,n=r.y-s.y,0==o?t[a].y=t[a].orginalY+r.y:t[a].y=t[a].orginalY+n/o*(t[a].x-s.x)+s.y}},o.prototype._getRelevantData=function(t,e,i,o){var n,r,a,h;if(t.length>0)for(r=0;rt?-1:1},c=Math.max(0,s.binarySearchValue(d,i,"x","before",l)),u=Math.min(d.length,s.binarySearchValue(d,o,"x","after",l)+1);0>=u&&(u=d.length);var p=new Array(u-c);for(a=c;u>a;a++)h=n.itemsData[a],p[a-c]=h;e[t[r]]=p}else e[t[r]]=n.itemsData}},o.prototype._applySampling=function(t,e){var i;if(t.length>0)for(var o=0;o0){var 
s=1,r=n.length,a=this.body.util.toGlobalScreen(n[n.length-1].x)-this.body.util.toGlobalScreen(n[0].x),h=r/a;s=Math.min(Math.ceil(.2*r),Math.max(1,Math.round(h)));for(var d=new Array(r),l=0;r>l;l+=s){var c=Math.round(l/s);d[c]=n[l]}e[t[o]]=d.splice(0,Math.round(r/s))}}},o.prototype._getYRanges=function(t,e,i){var o,n,s,r,a=[],h=[];if(t.length>0){for(s=0;s0&&(n=this.groups[t[s]],r.stack===!0&&"bar"===r.style?"left"===r.yAxisOrientation?a=a.concat(n.getItems()):h=h.concat(n.getItems()):i[t[s]]=n.getYRange(o,t[s]));p.getStackedYRange(a,i,t,"__barStackLeft","left"),p.getStackedYRange(h,i,t,"__barStackRight","right")}},o.prototype._updateYAxis=function(t,e){var i,o,n=!1,s=!1,r=!1,a=1e9,h=1e9,d=-1e9,l=-1e9;if(t.length>0){for(var c=0;ci?i:a,d=o>d?o:d):(r=!0,h=h>i?i:h,l=o>l?o:l));1==s&&this.yAxisLeft.setRange(a,d),1==r&&this.yAxisRight.setRange(h,l)}n=this._toggleAxisVisiblity(s,this.yAxisLeft)||n,n=this._toggleAxisVisiblity(r,this.yAxisRight)||n,1==r&&1==s?(this.yAxisLeft.drawIcons=!0,this.yAxisRight.drawIcons=!0):(this.yAxisLeft.drawIcons=!1,this.yAxisRight.drawIcons=!1),this.yAxisRight.master=!s,this.yAxisRight.masterAxis=this.yAxisLeft,0==this.yAxisRight.master?(1==r?this.yAxisLeft.lineOffset=this.yAxisRight.width:this.yAxisLeft.lineOffset=0,n=this.yAxisLeft.redraw()||n,n=this.yAxisRight.redraw()||n):n=this.yAxisRight.redraw()||n;for(var p=["__barStackLeft","__barStackRight","__lineStackLeft","__lineStackRight"],c=0;ct?-1:1});for(var a=0;a=0&&t._redrawLabel(o-2,e.val,i,"vis-y-axis vis-major",t.props.majorCharHeight),t.master===!0&&(n?t._redrawLine(o,i,"vis-grid vis-horizontal vis-major",t.options.majorLinesOffset,t.props.majorLineWidth):t._redrawLine(o,i,"vis-grid vis-horizontal vis-minor",t.options.minorLinesOffset,t.props.minorLineWidth))});var d=0;void 0!==this.options[i].title&&void 0!==this.options[i].title.text&&(d=this.props.titleCharHeight);var l=this.options.icons===!0?Math.max(this.options.iconWidth,d)+this.options.labelOffsetX+15:d+this.options.labelOffsetX+15;return this.maxLabelSize>this.width-l&&this.options.visible===!0?(this.width=this.maxLabelSize+l,this.options.width=this.width+"px",s.cleanupElements(this.DOMelements.lines),s.cleanupElements(this.DOMelements.labels),this.redraw(),e=!0):this.maxLabelSizethis.minWidth?(this.width=Math.max(this.minWidth,this.maxLabelSize+l),this.options.width=this.width+"px",s.cleanupElements(this.DOMelements.lines),s.cleanupElements(this.DOMelements.labels),this.redraw(),e=!0):(s.cleanupElements(this.DOMelements.lines),s.cleanupElements(this.DOMelements.labels),e=!1),e},o.prototype.convertValue=function(t){return this.scale.convertValue(t)},o.prototype.screenToValue=function(t){return this.scale.screenToValue(t)},o.prototype._redrawLabel=function(t,e,i,o,n){var r=s.getDOMElement("div",this.DOMelements.labels,this.dom.frame);r.className=o,r.innerHTML=e,"left"===i?(r.style.left="-"+this.options.labelOffsetX+"px",r.style.textAlign="right"):(r.style.right="-"+this.options.labelOffsetX+"px",r.style.textAlign="left"),r.style.top=t-.5*n+this.options.labelOffsetY+"px",e+="";var 
a=Math.max(this.props.majorCharWidth,this.props.minorCharWidth);this.maxLabelSize.5*(h.magnitudefactor*h.minorSteps[h.minorStepIdx])?e+h.magnitudefactor*h.minorSteps[h.minorStepIdx]:e};i&&(this._start-=2*this.magnitudefactor*this.minorSteps[this.minorStepIdx],this._start=d(this._start)),o&&(this._end+=this.magnitudefactor*this.minorSteps[this.minorStepIdx],this._end=d(this._end)),this.determineScale()}}i.prototype.setCharHeight=function(t){this.majorCharHeight=t},i.prototype.setHeight=function(t){this.containerHeight=t},i.prototype.determineScale=function(){var t=this._end-this._start;this.scale=this.containerHeight/t;var e=this.majorCharHeight/this.scale,i=t>0?Math.round(Math.log(t)/Math.LN10):0;this.minorStepIdx=-1,this.magnitudefactor=Math.pow(10,i);var o=0;0>i&&(o=i);for(var n=!1,s=o;Math.abs(s)<=Math.abs(i);s++){this.magnitudefactor=Math.pow(10,s);for(var r=0;r=e){n=!0,this.minorStepIdx=r;break}}if(n===!0)break}},i.prototype.is_major=function(t){return t%(this.magnitudefactor*this.majorSteps[this.minorStepIdx])===0},i.prototype.getStep=function(){return this.magnitudefactor*this.minorSteps[this.minorStepIdx]},i.prototype.getFirstMajor=function(){var t=this.magnitudefactor*this.majorSteps[this.minorStepIdx];return this.convertValue(this._start+(t-this._start%t)%t)},i.prototype.formatValue=function(t){var e=t.toPrecision(5);return"function"==typeof this.formattingFunction&&(e=this.formattingFunction(t)),"number"==typeof e?""+e:"string"==typeof e?e:t.toPrecision(5)},i.prototype.getLines=function(){for(var t=[],e=this.getStep(),i=(e-this._start%e)%e,o=this._start+i;this._end-o>1e-5;o+=e)o!=this._start&&t.push({major:this.is_major(o),y:this.convertValue(o),val:this.formatValue(o)});return t},i.prototype.followScale=function(t){var e=this.minorStepIdx,i=this._start,o=this._end,n=this,s=function(){n.magnitudefactor*=2},r=function(){n.magnitudefactor/=2};t.minorStepIdx<=1&&this.minorStepIdx<=1||t.minorStepIdx>1&&this.minorStepIdx>1||(t.minorStepIdxo+1e-5)r(),d=!1;else{if(!this.autoScaleStart&&this._start=0)){r(),d=!1;continue}console.warn("Can't adhere to given 'min' range, due to zeroalign")}this.autoScaleStart&&this.autoScaleEnd&&o-i>c?(s(),d=!1):d=!0}}},i.prototype.convertValue=function(t){return this.containerHeight-(t-this._start)*this.scale},i.prototype.screenToValue=function(t){return(this.containerHeight-t)/this.scale+this._start},t.exports=i},function(t,e,i){function o(t,e,i,o){this.id=e;var n=["sampling","style","sort","yAxisOrientation","barChart","drawPoints","shaded","interpolation","zIndex","excludeFromStacking","excludeFromLegend"];this.options=s.selectiveBridgeObject(n,i),this.usingDefaultStyle=void 0===t.className,this.groupsUsingDefaultStyles=o,this.zeroPosition=0,this.update(t),1==this.usingDefaultStyle&&(this.groupsUsingDefaultStyles[0]+=1),this.itemsData=[],this.visible=void 0===t.visible?!0:t.visible}var n="function"==typeof Symbol&&"symbol"==typeof Symbol.iterator?function(t){return typeof t}:function(t){return t&&"function"==typeof Symbol&&t.constructor===Symbol?"symbol":typeof t},s=i(1),r=(i(8),i(55)),a=i(57),h=i(56);o.prototype.setItems=function(t){null!=t?(this.itemsData=t,1==this.options.sort&&s.insertSort(this.itemsData,function(t,e){return t.x>e.x?1:-1})):this.itemsData=[]},o.prototype.getItems=function(){return this.itemsData},o.prototype.setZeroPosition=function(t){this.zeroPosition=t},o.prototype.setOptions=function(t){if(void 0!==t){var 
e=["sampling","style","sort","yAxisOrientation","barChart","zIndex","excludeFromStacking","excludeFromLegend"];s.selectiveDeepExtend(e,this.options,t),"function"==typeof t.drawPoints&&(t.drawPoints={onRender:t.drawPoints}),s.mergeOptions(this.options,t,"interpolation"),s.mergeOptions(this.options,t,"drawPoints"),s.mergeOptions(this.options,t,"shaded"),t.interpolation&&"object"==n(t.interpolation)&&t.interpolation.parametrization&&("uniform"==t.interpolation.parametrization?this.options.interpolation.alpha=0:"chordal"==t.interpolation.parametrization?this.options.interpolation.alpha=1:(this.options.interpolation.parametrization="centripetal",this.options.interpolation.alpha=.5))}},o.prototype.update=function(t){this.group=t,this.content=t.content||"graph",this.className=t.className||this.className||"vis-graph-group"+this.groupsUsingDefaultStyles[0]%10,this.visible=void 0===t.visible?!0:t.visible,this.style=t.style,this.setOptions(t.options)},o.prototype.getLegend=function(t,e,i,o,n){if(void 0==i||null==i){var s=document.createElementNS("http://www.w3.org/2000/svg","svg");i={svg:s,svgElements:{},options:this.options,groups:[this]}}switch(void 0!=o&&null!=o||(o=0),void 0!=n&&null!=n||(n=.5*e),this.options.style){case"line":a.drawIcon(this,o,n,t,e,i);break;case"points":case"point":h.drawIcon(this,o,n,t,e,i);break;case"bar":r.drawIcon(this,o,n,t,e,i)}return{icon:i.svg,label:this.content,orientation:this.options.yAxisOrientation}},o.prototype.getYRange=function(t){for(var e=t[0].y,i=t[0].y,o=0;ot[o].y?t[o].y:e,i=i0&&(i=Math.min(i,Math.abs(e[o-1].screen_x-e[o].screen_x))),0===i&&(void 0===t[e[o].screen_x]&&(t[e[o].screen_x]={amount:0,resolved:0,accumulatedPositive:0,accumulatedNegative:0}),t[e[o].screen_x].amount+=1)},o._getSafeDrawData=function(t,e,i){var o,n;return t0?(o=i>t?i:t,n=0,"left"===e.options.barChart.align?n-=.5*t:"right"===e.options.barChart.align&&(n+=.5*t)):(o=e.options.barChart.width,n=0,"left"===e.options.barChart.align?n-=.5*e.options.barChart.width:"right"===e.options.barChart.align&&(n+=.5*e.options.barChart.width)),{width:o,offset:n}},o.getStackedYRange=function(t,e,i,n,s){if(t.length>0){t.sort(function(t,e){return t.screen_x===e.screen_x?t.groupIde[s].screen_y?e[s].screen_y:o,n=nt[r].accumulatedNegative?t[r].accumulatedNegative:o,o=o>t[r].accumulatedPositive?t[r].accumulatedPositive:o,n=n0){var i=[];return i=1==e.options.interpolation.enabled?o._catmullRom(t,e):o._linear(t)}},o.drawIcon=function(t,e,i,o,s,r){var a,h,d=.5*s,l=n.getSVGElement("rect",r.svgElements,r.svg);if(l.setAttributeNS(null,"x",e),l.setAttributeNS(null,"y",i-d),l.setAttributeNS(null,"width",o),l.setAttributeNS(null,"height",2*d),l.setAttributeNS(null,"class","vis-outline"),a=n.getSVGElement("path",r.svgElements,r.svg),a.setAttributeNS(null,"class",t.className),void 0!==t.style&&a.setAttributeNS(null,"style",t.style),a.setAttributeNS(null,"d","M"+e+","+i+" L"+(e+o)+","+i),1==t.options.shaded.enabled&&(h=n.getSVGElement("path",r.svgElements,r.svg),"top"==t.options.shaded.orientation?h.setAttributeNS(null,"d","M"+e+", "+(i-d)+"L"+e+","+i+" L"+(e+o)+","+i+" L"+(e+o)+","+(i-d)):h.setAttributeNS(null,"d","M"+e+","+i+" L"+e+","+(i+d)+" L"+(e+o)+","+(i+d)+"L"+(e+o)+","+i),h.setAttributeNS(null,"class",t.className+" vis-icon-fill"),void 0!==t.options.shaded.style&&""!==t.options.shaded.style&&h.setAttributeNS(null,"style",t.options.shaded.style)),1==t.options.drawPoints.enabled){var 
c={style:t.options.drawPoints.style,styles:t.options.drawPoints.styles,size:t.options.drawPoints.size,className:t.className};n.drawPoint(e+.5*o,i,c,r.svgElements,r.svg)}},o.drawShading=function(t,e,i,o){if(1==e.options.shaded.enabled){var s=Number(o.svg.style.height.replace("px","")),r=n.getSVGElement("path",o.svgElements,o.svg),a="L";1==e.options.interpolation.enabled&&(a="C");var h,d=0;d="top"==e.options.shaded.orientation?0:"bottom"==e.options.shaded.orientation?s:Math.min(Math.max(0,e.zeroPosition),s),h="group"==e.options.shaded.orientation&&null!=i&&void 0!=i?"M"+t[0][0]+","+t[0][1]+" "+this.serializePath(t,a,!1)+" L"+i[i.length-1][0]+","+i[i.length-1][1]+" "+this.serializePath(i,a,!0)+i[0][0]+","+i[0][1]+" Z":"M"+t[0][0]+","+t[0][1]+" "+this.serializePath(t,a,!1)+" V"+d+" H"+t[0][0]+" Z",r.setAttributeNS(null,"class",e.className+" vis-fill"),void 0!==e.options.shaded.style&&r.setAttributeNS(null,"style",e.options.shaded.style),r.setAttributeNS(null,"d",h)}},o.draw=function(t,e,i){if(null!=t&&void 0!=t){var o=n.getSVGElement("path",i.svgElements,i.svg);o.setAttributeNS(null,"class",e.className),void 0!==e.style&&o.setAttributeNS(null,"style",e.style);var s="L";1==e.options.interpolation.enabled&&(s="C"),o.setAttributeNS(null,"d","M"+t[0][0]+","+t[0][1]+" "+this.serializePath(t,s,!1))}},o.serializePath=function(t,e,i){if(t.length<2)return"";var o=e;if(i)for(var n=t.length-2;n>0;n--)o+=t[n][0]+","+t[n][1]+" ";else for(var n=1;nl;l++)e=0==l?t[0]:t[l-1],i=t[l],o=t[l+1],n=d>l+2?t[l+2]:o,s={screen_x:(-e.screen_x+6*i.screen_x+o.screen_x)*h,screen_y:(-e.screen_y+6*i.screen_y+o.screen_y)*h},r={screen_x:(i.screen_x+6*o.screen_x-n.screen_x)*h,screen_y:(i.screen_y+6*o.screen_y-n.screen_y)*h},a.push([s.screen_x,s.screen_y]),a.push([r.screen_x,r.screen_y]),a.push([o.screen_x,o.screen_y]);return a},o._catmullRom=function(t,e){var i=e.options.interpolation.alpha;if(0==i||void 0===i)return this._catmullRomUniform(t);var o,n,s,r,a,h,d,l,c,u,p,f,m,v,g,y,b,w,_,x=[];x.push([Math.round(t[0].screen_x),Math.round(t[0].screen_y)]);for(var k=t.length,O=0;k-1>O;O++)o=0==O?t[0]:t[O-1],n=t[O],s=t[O+1],r=k>O+2?t[O+2]:s,d=Math.sqrt(Math.pow(o.screen_x-n.screen_x,2)+Math.pow(o.screen_y-n.screen_y,2)),l=Math.sqrt(Math.pow(n.screen_x-s.screen_x,2)+Math.pow(n.screen_y-s.screen_y,2)),c=Math.sqrt(Math.pow(s.screen_x-r.screen_x,2)+Math.pow(s.screen_y-r.screen_y,2)),v=Math.pow(c,i),y=Math.pow(c,2*i),g=Math.pow(l,i),b=Math.pow(l,2*i),_=Math.pow(d,i),w=Math.pow(d,2*i),u=2*w+3*_*g+b,p=2*y+3*v*g+b,f=3*_*(_+g),f>0&&(f=1/f),m=3*v*(v+g),m>0&&(m=1/m),a={screen_x:(-b*o.screen_x+u*n.screen_x+w*s.screen_x)*f,screen_y:(-b*o.screen_y+u*n.screen_y+w*s.screen_y)*f},h={screen_x:(y*n.screen_x+p*s.screen_x-b*r.screen_x)*m,screen_y:(y*n.screen_y+p*s.screen_y-b*r.screen_y)*m},0==a.screen_x&&0==a.screen_y&&(a=n),0==h.screen_x&&0==h.screen_y&&(h=s),x.push([a.screen_x,a.screen_y]),x.push([h.screen_x,h.screen_y]),x.push([s.screen_x,s.screen_y]);return x},o._linear=function(t){for(var e=[],i=0;it?-1:1});for(var i=0;i")}this.dom.textArea.innerHTML=s,this.dom.textArea.style.lineHeight=.75*this.options.iconSize+this.options.iconSpacing+"px"}},o.prototype.drawLegendIcons=function(){if(this.dom.frame.parentNode){var t=Object.keys(this.groups);t.sort(function(t,e){return e>t?-1:1}),s.resetElements(this.svgElements);var e=window.getComputedStyle(this.dom.frame).paddingTop,i=Number(e.replace("px","")),o=i,n=this.options.iconSize,r=.75*this.options.iconSize,a=i+.5*r+3;this.svg.style.width=n+5+i+"px";for(var h=0;h0){var 
i=this.groupIndex%this.groupsArray.length;this.groupIndex++,e={},e.color=this.groups[this.groupsArray[i]],this.groups[t]=e}else{var o=this.defaultIndex%this.defaultGroups.length;this.defaultIndex++,e={},e.color=this.defaultGroups[o],this.groups[t]=e}return e}},{key:"add",value:function(t,e){return this.groups[t]=e,this.groupsArray.push(t),e}}]),t}();e["default"]=r},function(t,e,i){function o(t){return t&&t.__esModule?t:{"default":t}}function n(t,e){if(!(t instanceof e))throw new TypeError("Cannot call a class as a function")}Object.defineProperty(e,"__esModule",{value:!0});var s=function(){function t(t,e){for(var i=0;it.left&&this.shape.topt.top}},{key:"isBoundingBoxOverlappingWith",value:function(t){return this.shape.boundingBox.leftt.left&&this.shape.boundingBox.topt.top}}],[{key:"parseOptions",value:function(t,e){var i=arguments.length<=2||void 0===arguments[2]?!1:arguments[2],o=arguments.length<=3||void 0===arguments[3]?{}:arguments[3],n=["color","font","fixed","shadow"];if(A.selectiveNotDeepExtend(n,t,e,i),A.mergeOptions(t,e,"shadow",i,o),void 0!==e.color&&null!==e.color){var s=A.parseColor(e.color);A.fillIfDefined(t.color,s)}else i===!0&&null===e.color&&(t.color=A.bridgeObject(o.color));void 0!==e.fixed&&null!==e.fixed&&("boolean"==typeof e.fixed?(t.fixed.x=e.fixed,t.fixed.y=e.fixed):(void 0!==e.fixed.x&&"boolean"==typeof e.fixed.x&&(t.fixed.x=e.fixed.x),void 0!==e.fixed.y&&"boolean"==typeof e.fixed.y&&(t.fixed.y=e.fixed.y))),void 0!==e.font&&null!==e.font?a["default"].parseOptions(t.font,e):i===!0&&null===e.font&&(t.font=A.bridgeObject(o.font)),void 0!==e.scaling&&A.mergeOptions(t.scaling,e.scaling,"label",i,o.scaling)}}]),t}();e["default"]=B},function(t,e,i){function o(t,e){if(!(t instanceof e))throw new TypeError("Cannot call a class as a function")}Object.defineProperty(e,"__esModule",{value:!0});var n=function(){function t(t,e){var i=[],o=!0,n=!1,s=void 0;try{for(var r,a=t[Symbol.iterator]();!(o=(r=a.next()).done)&&(i.push(r.value),!e||i.length!==e);o=!0);}catch(h){n=!0,s=h}finally{try{!o&&a["return"]&&a["return"]()}finally{if(n)throw s}}return i}return function(e,i){if(Array.isArray(e))return e;if(Symbol.iterator in Object(e))return t(e,i);throw new TypeError("Invalid attempt to destructure non-iterable instance")}}(),s="function"==typeof Symbol&&"symbol"==typeof Symbol.iterator?function(t){return typeof t}:function(t){return t&&"function"==typeof Symbol&&t.constructor===Symbol?"symbol":typeof t},r=function(){function t(t,e){for(var i=0;i=this.nodeOptions.scaling.label.maxVisible&&(r=Number(this.nodeOptions.scaling.label.maxVisible)/this.body.view.scale);var h=this.size.yLine,d=this._getColor(a),l=n(d,2),c=l[0],u=l[1],p=this._setAlignment(t,i,h,s),f=n(p,2);i=f[0],h=f[1],t.font=(e&&this.nodeOptions.labelHighlightBold?"bold ":"")+r+"px "+this.fontOptions.face,t.fillStyle=c,this.isEdgeLabel||"left"!==this.fontOptions.align?t.textAlign="center":(t.textAlign=this.fontOptions.align,i-=.5*this.size.width),this.fontOptions.strokeWidth>0&&(t.lineWidth=this.fontOptions.strokeWidth,t.strokeStyle=u,t.lineJoin="round");for(var m=0;m0&&t.strokeText(this.lines[m],i,h),t.fillText(this.lines[m],i,h),h+=r}},{key:"_setAlignment",value:function(t,e,i,o){if(this.isEdgeLabel&&"horizontal"!==this.fontOptions.align&&this.pointToSelf===!1){e=0,i=0;var n=2;"top"===this.fontOptions.align?(t.textBaseline="alphabetic",i-=2*n):"bottom"===this.fontOptions.align?(t.textBaseline="hanging",i+=2*n):t.textBaseline="middle"}else t.textBaseline=o;return[e,i]}},{key:"_getColor",value:function(t){var 
e=this.fontOptions.color||"#000000",i=this.fontOptions.strokeColor||"#ffffff";if(t<=this.nodeOptions.scaling.label.drawThreshold){var o=Math.max(0,Math.min(1,1-(this.nodeOptions.scaling.label.drawThreshold-t)));e=a.overrideOpacity(e,o),i=a.overrideOpacity(i,o)}return[e,i]}},{key:"getTextSize",value:function(t){var e=arguments.length<=1||void 0===arguments[1]?!1:arguments[1],i={width:this._processLabel(t,e),height:this.fontOptions.size*this.lineCount,lineCount:this.lineCount};return i}},{key:"calculateLabelSize",value:function(t,e){var i=arguments.length<=2||void 0===arguments[2]?0:arguments[2],o=arguments.length<=3||void 0===arguments[3]?0:arguments[3],n=arguments.length<=4||void 0===arguments[4]?"middle":arguments[4];this.labelDirty===!0&&(this.size.width=this._processLabel(t,e)),this.size.height=this.fontOptions.size*this.lineCount,this.size.left=i-.5*this.size.width,this.size.top=o-.5*this.size.height,this.size.yLine=o+.5*(1-this.lineCount)*this.fontOptions.size,"hanging"===n&&(this.size.top+=.5*this.fontOptions.size,this.size.top+=4,this.size.yLine+=4),this.labelDirty=!1}},{key:"_processLabel",value:function(t,e){var i=0,o=[""],n=0;if(void 0!==this.nodeOptions.label){o=String(this.nodeOptions.label).split("\n"),n=o.length,t.font=(e&&this.nodeOptions.labelHighlightBold?"bold ":"")+this.fontOptions.size+"px "+this.fontOptions.face,i=t.measureText(o[0]).width;for(var s=1;n>s;s++){var r=t.measureText(o[s]).width;i=r>i?r:i}}return this.lines=o,this.lineCount=n,i}}],[{key:"parseOptions",value:function(t,e){var i=arguments.length<=2||void 0===arguments[2]?!1:arguments[2];if("string"==typeof e.font){var o=e.font.split(" ");t.size=o[0].replace("px",""),t.face=o[1],t.color=o[2]}else"object"===s(e.font)&&a.fillIfDefined(t,e.font,i);t.size=Number(t.size)}}]),t}();e["default"]=h},function(t,e,i){function o(t){return t&&t.__esModule?t:{"default":t}}function n(t,e){if(!(t instanceof e))throw new TypeError("Cannot call a class as a function")}function s(t,e){if(!t)throw new ReferenceError("this hasn't been initialised - super() hasn't been called");return!e||"object"!=typeof e&&"function"!=typeof e?t:e}function r(t,e){if("function"!=typeof e&&null!==e)throw new TypeError("Super expression must either be null or a function, not "+typeof e);t.prototype=Object.create(e&&e.prototype,{constructor:{value:t,enumerable:!1,writable:!0,configurable:!0}}),e&&(Object.setPrototypeOf?Object.setPrototypeOf(t,e):t.__proto__=e)}Object.defineProperty(e,"__esModule",{value:!0});var a=function(){function t(t,e){for(var i=0;i0&&(this.enableBorderDashes(t),t.stroke(),this.disableBorderDashes(t)),t.restore(),this.updateBoundingBox(e,i,t,o),this.labelModule.draw(t,e,i,o)}},{key:"updateBoundingBox",value:function(t,e,i,o){this.resize(i,o),this.left=t-.5*this.width,this.top=e-.5*this.height;var n=this.options.shapeProperties.borderRadius;this.boundingBox.left=this.left-n,this.boundingBox.top=this.top-n,this.boundingBox.bottom=this.top+this.height+n,this.boundingBox.right=this.left+this.width+n}},{key:"distanceToBorder",value:function(t,e){this.resize(t);var i=this.options.borderWidth;return Math.min(Math.abs(this.width/2/Math.cos(e)),Math.abs(this.height/2/Math.sin(e)))+i}}]),e}(d["default"]);e["default"]=l},function(t,e){function i(t,e){if(!(t instanceof e))throw new TypeError("Cannot call a class as a function")}Object.defineProperty(e,"__esModule",{value:!0});var o=function(){function t(t,e){for(var 
i=0;ithis.imageObj.height?(o=this.imageObj.width/this.imageObj.height,e=2*this.options.size*o||this.imageObj.width,i=2*this.options.size||this.imageObj.height):(o=this.imageObj.width&&this.imageObj.height?this.imageObj.height/this.imageObj.width:1,e=2*this.options.size,i=2*this.options.size*o):(e=this.imageObj.width,i=this.imageObj.height),this.width=e,this.height=i,this.radius=.5*this.width}}},{key:"_drawRawCircle",value:function(t,e,i,o,n,s){var r=this.options.borderWidth,a=this.options.borderWidthSelected||2*this.options.borderWidth,h=(o?a:r)/this.body.view.scale;t.lineWidth=Math.min(this.width,h),t.strokeStyle=o?this.options.color.highlight.border:n?this.options.color.hover.border:this.options.color.border,t.fillStyle=o?this.options.color.highlight.background:n?this.options.color.hover.background:this.options.color.background,t.circle(e,i,s),this.enableShadow(t),t.fill(),this.disableShadow(t),t.save(),h>0&&(this.enableBorderDashes(t),t.stroke(),this.disableBorderDashes(t)),t.restore()}},{key:"_drawImageAtPosition",value:function(t){if(0!=this.imageObj.width){t.globalAlpha=1,this.enableShadow(t);var e=this.imageObj.width/this.width/this.body.view.scale;if(e>2&&this.options.shapeProperties.interpolation===!0){var i=this.imageObj.width,o=this.imageObj.height,n=document.createElement("canvas");n.width=i,n.height=i;var s=n.getContext("2d");e*=.5,i*=.5,o*=.5,s.drawImage(this.imageObj,0,0,i,o);for(var r=0,a=1;e>2&&4>a;)s.drawImage(n,r,0,i,o,r+i,0,i/2,o/2),r+=i,e*=.5,i*=.5,o*=.5,a+=1;t.drawImage(n,r,0,i,o,this.left,this.top,this.width,this.height)}else t.drawImage(this.imageObj,this.left,this.top,this.width,this.height);this.disableShadow(t)}}},{key:"_drawImageLabel",value:function(t,e,i,o){var n,s=0;if(void 0!==this.height){s=.5*this.height;var r=this.labelModule.getTextSize(t);r.lineCount>=1&&(s+=r.height/2)}n=i+s,this.options.label&&(this.labelOffset=s),this.labelModule.draw(t,e,n,o,"hanging")}}]),e}(d["default"]);e["default"]=l},function(t,e,i){function o(t){return t&&t.__esModule?t:{"default":t}}function n(t,e){if(!(t instanceof e))throw new TypeError("Cannot call a class as a function")}function s(t,e){if(!t)throw new ReferenceError("this hasn't been initialised - super() hasn't been called");return!e||"object"!=typeof e&&"function"!=typeof e?t:e}function r(t,e){if("function"!=typeof e&&null!==e)throw new TypeError("Super expression must either be null or a function, not "+typeof e);t.prototype=Object.create(e&&e.prototype,{constructor:{value:t,enumerable:!1,writable:!0,configurable:!0}}),e&&(Object.setPrototypeOf?Object.setPrototypeOf(t,e):t.__proto__=e)}Object.defineProperty(e,"__esModule",{value:!0});var a=function(){function t(t,e){for(var i=0;i0&&(this.enableBorderDashes(t),t.stroke(),this.disableBorderDashes(t)),t.restore(),this.updateBoundingBox(e,i,t,o),this.labelModule.draw(t,e,i,o)}},{key:"updateBoundingBox",value:function(t,e,i,o){this.resize(i,o),this.left=t-.5*this.width,this.top=e-.5*this.height,this.boundingBox.left=this.left,this.boundingBox.top=this.top,this.boundingBox.bottom=this.top+this.height,this.boundingBox.right=this.left+this.width}},{key:"distanceToBorder",value:function(t,e){return this._distanceToBorder(t,e)}}]),e}(d["default"]);e["default"]=l},function(t,e,i){function o(t){return t&&t.__esModule?t:{"default":t}}function n(t,e){if(!(t instanceof e))throw new TypeError("Cannot call a class as a function")}function s(t,e){if(!t)throw new ReferenceError("this hasn't been initialised - super() hasn't been called");return!e||"object"!=typeof e&&"function"!=typeof 
e?t:e}function r(t,e){if("function"!=typeof e&&null!==e)throw new TypeError("Super expression must either be null or a function, not "+typeof e);t.prototype=Object.create(e&&e.prototype,{constructor:{value:t,enumerable:!1,writable:!0,configurable:!0}}),e&&(Object.setPrototypeOf?Object.setPrototypeOf(t,e):t.__proto__=e)}Object.defineProperty(e,"__esModule",{value:!0});var a=function(){function t(t,e){for(var i=0;i0&&(this.enableBorderDashes(t),t.stroke(),this.disableBorderDashes(t)),t.restore(),void 0!==this.options.label){var l=n+.5*this.height+3;this.labelModule.draw(t,o,l,s,"hanging")}this.updateBoundingBox(o,n)}},{key:"updateBoundingBox",value:function(t,e){this.boundingBox.top=e-this.options.size,this.boundingBox.left=t-this.options.size,this.boundingBox.right=t+this.options.size,this.boundingBox.bottom=e+this.options.size,void 0!==this.options.label&&this.labelModule.size.width>0&&(this.boundingBox.left=Math.min(this.boundingBox.left,this.labelModule.size.left),this.boundingBox.right=Math.max(this.boundingBox.right,this.labelModule.size.left+this.labelModule.size.width),this.boundingBox.bottom=Math.max(this.boundingBox.bottom,this.boundingBox.bottom+this.labelModule.size.height+3))}}]),e}(d["default"]);e["default"]=l},function(t,e,i){function o(t){return t&&t.__esModule?t:{"default":t}}function n(t,e){if(!(t instanceof e))throw new TypeError("Cannot call a class as a function")}function s(t,e){if(!t)throw new ReferenceError("this hasn't been initialised - super() hasn't been called");return!e||"object"!=typeof e&&"function"!=typeof e?t:e}function r(t,e){if("function"!=typeof e&&null!==e)throw new TypeError("Super expression must either be null or a function, not "+typeof e);t.prototype=Object.create(e&&e.prototype,{constructor:{value:t,enumerable:!1,writable:!0,configurable:!0}}),e&&(Object.setPrototypeOf?Object.setPrototypeOf(t,e):t.__proto__=e)}Object.defineProperty(e,"__esModule",{value:!0});var a=function(){function t(t,e){for(var i=0;i0&&(this.enableBorderDashes(t),t.stroke(),this.disableBorderDashes(t)),t.restore(),this.updateBoundingBox(e,i,t,o),this.labelModule.draw(t,e,i,o)}},{key:"updateBoundingBox",value:function(t,e,i,o){this.resize(i,o),this.left=t-.5*this.width,this.top=e-.5*this.height,this.boundingBox.left=this.left,this.boundingBox.top=this.top,this.boundingBox.bottom=this.top+this.height,this.boundingBox.right=this.left+this.width}},{key:"distanceToBorder",value:function(t,e){this.resize(t);var i=.5*this.width,o=.5*this.height,n=Math.sin(e)*i,s=Math.cos(e)*o;return i*o/Math.sqrt(n*n+s*s)}}]),e}(d["default"]);e["default"]=l},function(t,e,i){function o(t){return t&&t.__esModule?t:{"default":t}}function n(t,e){if(!(t instanceof e))throw new TypeError("Cannot call a class as a function")}function s(t,e){if(!t)throw new ReferenceError("this hasn't been initialised - super() hasn't been called");return!e||"object"!=typeof e&&"function"!=typeof e?t:e}function r(t,e){if("function"!=typeof e&&null!==e)throw new TypeError("Super expression must either be null or a function, not "+typeof e);t.prototype=Object.create(e&&e.prototype,{constructor:{value:t,enumerable:!1,writable:!0,configurable:!0}}),e&&(Object.setPrototypeOf?Object.setPrototypeOf(t,e):t.__proto__=e)}Object.defineProperty(e,"__esModule",{value:!0});var a=function(){function t(t,e){for(var i=0;i0){var 
i=5;this.boundingBox.left=Math.min(this.boundingBox.left,this.labelModule.size.left),this.boundingBox.right=Math.max(this.boundingBox.right,this.labelModule.size.left+this.labelModule.size.width),this.boundingBox.bottom=Math.max(this.boundingBox.bottom,this.boundingBox.bottom+this.labelModule.size.height+i)}}},{key:"_icon",value:function(t,e,i,o){var n=Number(this.options.icon.size);void 0!==this.options.icon.code?(t.font=(o?"bold ":"")+n+"px "+this.options.icon.face,t.fillStyle=this.options.icon.color||"black",t.textAlign="center",t.textBaseline="middle",this.enableShadow(t),t.fillText(this.options.icon.code,e,i),this.disableShadow(t)):console.error("When using the icon shape, you need to define the code in the icon options object. This can be done per node or globally.")}},{key:"distanceToBorder",value:function(t,e){return this._distanceToBorder(t,e)}}]),e}(d["default"]);e["default"]=l},function(t,e,i){function o(t){return t&&t.__esModule?t:{"default":t}}function n(t,e){if(!(t instanceof e))throw new TypeError("Cannot call a class as a function")}function s(t,e){if(!t)throw new ReferenceError("this hasn't been initialised - super() hasn't been called");return!e||"object"!=typeof e&&"function"!=typeof e?t:e}function r(t,e){if("function"!=typeof e&&null!==e)throw new TypeError("Super expression must either be null or a function, not "+typeof e);t.prototype=Object.create(e&&e.prototype,{constructor:{value:t,enumerable:!1,writable:!0,configurable:!0}}),e&&(Object.setPrototypeOf?Object.setPrototypeOf(t,e):t.__proto__=e)}Object.defineProperty(e,"__esModule",{value:!0});var a=function(){function t(t,e){for(var i=0;i0&&(this.enableBorderDashes(t),t.stroke(),this.disableBorderDashes(t)),t.restore(),t.closePath()}this._drawImageAtPosition(t),this._drawImageLabel(t,e,i,o||n),this.updateBoundingBox(e,i)}},{key:"updateBoundingBox",value:function(t,e){this.resize(),this.left=t-this.width/2,this.top=e-this.height/2,this.boundingBox.top=this.top,this.boundingBox.left=this.left,this.boundingBox.right=this.left+this.width,this.boundingBox.bottom=this.top+this.height,void 0!==this.options.label&&this.labelModule.size.width>0&&(this.boundingBox.left=Math.min(this.boundingBox.left,this.labelModule.size.left),this.boundingBox.right=Math.max(this.boundingBox.right,this.labelModule.size.left+this.labelModule.size.width),this.boundingBox.bottom=Math.max(this.boundingBox.bottom,this.boundingBox.bottom+this.labelOffset))}},{key:"distanceToBorder", +value:function(t,e){return this._distanceToBorder(t,e)}}]),e}(d["default"]);e["default"]=l},function(t,e,i){function o(t){return t&&t.__esModule?t:{"default":t}}function n(t,e){if(!(t instanceof e))throw new TypeError("Cannot call a class as a function")}function s(t,e){if(!t)throw new ReferenceError("this hasn't been initialised - super() hasn't been called");return!e||"object"!=typeof e&&"function"!=typeof e?t:e}function r(t,e){if("function"!=typeof e&&null!==e)throw new TypeError("Super expression must either be null or a function, not "+typeof e);t.prototype=Object.create(e&&e.prototype,{constructor:{value:t,enumerable:!1,writable:!0,configurable:!0}}),e&&(Object.setPrototypeOf?Object.setPrototypeOf(t,e):t.__proto__=e)}Object.defineProperty(e,"__esModule",{value:!0});var a=function(){function t(t,e){for(var i=0;ii.shape.height?(r=i.x+.5*i.shape.width,a=i.y-h):(r=i.x+h,a=i.y-.5*i.shape.height),s=this._pointOnCircle(r,a,h,.125),this.labelModule.draw(t,s.x,s.y,n)}}}},{key:"isOverlappingWith",value:function(t){if(this.connected){var 
e=10,i=this.from.x,o=this.from.y,n=this.to.x,s=this.to.y,r=t.left,a=t.top,h=this.edgeType.getDistanceToEdge(i,o,n,s,r,a);return e>h}return!1}},{key:"_rotateForLabelAlignment",value:function(t){var e=this.from.y-this.to.y,i=this.from.x-this.to.x,o=Math.atan2(e,i);(-1>o&&0>i||o>0&&0>i)&&(o+=Math.PI),t.rotate(o)}},{key:"_pointOnCircle",value:function(t,e,i,o){var n=2*o*Math.PI;return{x:t+i*Math.cos(n),y:e-i*Math.sin(n)}}},{key:"select",value:function(){this.selected=!0}},{key:"unselect",value:function(){this.selected=!1}},{key:"cleanup",value:function(){return this.edgeType.cleanup()}}],[{key:"parseOptions",value:function(t,e){var i=arguments.length<=2||void 0===arguments[2]?!1:arguments[2],o=arguments.length<=3||void 0===arguments[3]?{}:arguments[3],n=["arrowStrikethrough","id","from","hidden","hoverWidth","label","labelHighlightBold","length","line","opacity","physics","scaling","selectionWidth","selfReferenceSize","to","title","value","width"];if(g.selectiveDeepExtend(n,t,e,i),g.mergeOptions(t,e,"smooth",i,o),g.mergeOptions(t,e,"shadow",i,o),void 0!==e.dashes&&null!==e.dashes?t.dashes=e.dashes:i===!0&&null===e.dashes&&(t.dashes=Object.create(o.dashes)),void 0!==e.scaling&&null!==e.scaling?(void 0!==e.scaling.min&&(t.scaling.min=e.scaling.min),void 0!==e.scaling.max&&(t.scaling.max=e.scaling.max),g.mergeOptions(t.scaling,e.scaling,"label",i,o.scaling)):i===!0&&null===e.scaling&&(t.scaling=Object.create(o.scaling)),void 0!==e.arrows&&null!==e.arrows)if("string"==typeof e.arrows){var r=e.arrows.toLowerCase();t.arrows.to.enabled=-1!=r.indexOf("to"),t.arrows.middle.enabled=-1!=r.indexOf("middle"),t.arrows.from.enabled=-1!=r.indexOf("from")}else{if("object"!==s(e.arrows))throw new Error("The arrow newOptions can only be an object or a string. Refer to the documentation. 
You used:"+JSON.stringify(e.arrows));g.mergeOptions(t.arrows,e.arrows,"to",i,o.arrows),g.mergeOptions(t.arrows,e.arrows,"middle",i,o.arrows),g.mergeOptions(t.arrows,e.arrows,"from",i,o.arrows)}else i===!0&&null===e.arrows&&(t.arrows=Object.create(o.arrows));if(void 0!==e.color&&null!==e.color)if(t.color=g.deepExtend({},t.color,!0),g.isString(e.color))t.color.color=e.color,t.color.highlight=e.color,t.color.hover=e.color,t.color.inherit=!1;else{var a=!1;void 0!==e.color.color&&(t.color.color=e.color.color,a=!0),void 0!==e.color.highlight&&(t.color.highlight=e.color.highlight,a=!0),void 0!==e.color.hover&&(t.color.hover=e.color.hover,a=!0),void 0!==e.color.inherit&&(t.color.inherit=e.color.inherit),void 0!==e.color.opacity&&(t.color.opacity=Math.min(1,Math.max(0,e.color.opacity))),void 0===e.color.inherit&&a===!0&&(t.color.inherit=!1)}else i===!0&&null===e.color&&(t.color=g.bridgeObject(o.color));void 0!==e.font&&null!==e.font?h["default"].parseOptions(t.font,e):i===!0&&null===e.font&&(t.font=g.bridgeObject(o.font))}}]),t}();e["default"]=y},function(t,e,i){function o(t){return t&&t.__esModule?t:{"default":t}}function n(t,e){if(!(t instanceof e))throw new TypeError("Cannot call a class as a function")}function s(t,e){if(!t)throw new ReferenceError("this hasn't been initialised - super() hasn't been called");return!e||"object"!=typeof e&&"function"!=typeof e?t:e}function r(t,e){if("function"!=typeof e&&null!==e)throw new TypeError("Super expression must either be null or a function, not "+typeof e);t.prototype=Object.create(e&&e.prototype,{constructor:{value:t,enumerable:!1,writable:!0,configurable:!0}}),e&&(Object.setPrototypeOf?Object.setPrototypeOf(t,e):t.__proto__=e)}Object.defineProperty(e,"__esModule",{value:!0});var a=function(){function t(t,e){var i=[],o=!0,n=!1,s=void 0;try{for(var r,a=t[Symbol.iterator]();!(o=(r=a.next()).done)&&(i.push(r.value),!e||i.length!==e);o=!0);}catch(h){n=!0,s=h}finally{try{!o&&a["return"]&&a["return"]()}finally{if(n)throw s}}return i}return function(e,i){if(Array.isArray(e))return e;if(Symbol.iterator in Object(e))return t(e,i);throw new TypeError("Invalid attempt to destructure non-iterable instance")}}(),h=function(){function t(t,e){for(var i=0;iMath.abs(e)||this.options.smooth.forceDirection===!0||"horizontal"===this.options.smooth.forceDirection)&&"vertical"!==this.options.smooth.forceDirection?(o=this.from.y,s=this.to.y,i=this.from.x-r*t,n=this.to.x+r*t):(o=this.from.y-r*e,s=this.to.y+r*e,i=this.from.x,n=this.to.x),[{x:i,y:o},{x:n,y:s}]}},{key:"getViaNode",value:function(){return this._getViaCoordinates()}},{key:"_findBorderPosition",value:function(t,e){return this._findBorderPositionBezier(t,e)}},{key:"_getDistanceToEdge",value:function(t,e,i,o,n,s){var r=arguments.length<=6||void 0===arguments[6]?this._getViaCoordinates():arguments[6],h=a(r,2),d=h[0],l=h[1];return this._getDistanceToBezierEdge(t,e,i,o,n,s,d,l)}},{key:"getPoint",value:function(t){var e=arguments.length<=1||void 0===arguments[1]?this._getViaCoordinates():arguments[1],i=a(e,2),o=i[0],n=i[1],s=t,r=[];r[0]=Math.pow(1-s,3),r[1]=3*s*Math.pow(1-s,2),r[2]=3*Math.pow(s,2)*(1-s),r[3]=Math.pow(s,3);var h=r[0]*this.fromPoint.x+r[1]*o.x+r[2]*n.x+r[3]*this.toPoint.x,d=r[0]*this.fromPoint.y+r[1]*o.y+r[2]*n.y+r[3]*this.toPoint.y;return{x:h,y:d}}}]),e}(l["default"]);e["default"]=c},function(t,e,i){function o(t){return t&&t.__esModule?t:{"default":t}}function n(t,e){if(!(t instanceof e))throw new TypeError("Cannot call a class as a function")}function s(t,e){if(!t)throw new ReferenceError("this hasn't 
been initialised - super() hasn't been called");return!e||"object"!=typeof e&&"function"!=typeof e?t:e}function r(t,e){if("function"!=typeof e&&null!==e)throw new TypeError("Super expression must either be null or a function, not "+typeof e);t.prototype=Object.create(e&&e.prototype,{constructor:{value:t,enumerable:!1,writable:!0,configurable:!0}}),e&&(Object.setPrototypeOf?Object.setPrototypeOf(t,e):t.__proto__=e)}Object.defineProperty(e,"__esModule",{value:!0});var a=function(){function t(t,e){for(var i=0;il;l++)c=.1*l,v[0]=Math.pow(1-c,3),v[1]=3*c*Math.pow(1-c,2),v[2]=3*Math.pow(c,2)*(1-c),v[3]=Math.pow(c,3),u=v[0]*t+v[1]*r.x+v[2]*a.x+v[3]*i,p=v[0]*e+v[1]*r.y+v[2]*a.y+v[3]*o,l>0&&(d=this._getDistanceToLine(f,m,u,p,n,s),h=h>d?d:h),f=u,m=p;return h}}]),e}(d["default"]);e["default"]=l},function(t,e,i){function o(t){return t&&t.__esModule?t:{"default":t}}function n(t,e){if(!(t instanceof e))throw new TypeError("Cannot call a class as a function")}function s(t,e){if(!t)throw new ReferenceError("this hasn't been initialised - super() hasn't been called");return!e||"object"!=typeof e&&"function"!=typeof e?t:e}function r(t,e){if("function"!=typeof e&&null!==e)throw new TypeError("Super expression must either be null or a function, not "+typeof e);t.prototype=Object.create(e&&e.prototype,{constructor:{value:t,enumerable:!1,writable:!0,configurable:!0}}),e&&(Object.setPrototypeOf?Object.setPrototypeOf(t,e):t.__proto__=e)}Object.defineProperty(e,"__esModule",{value:!0});var a=function(){function t(t,e){for(var i=0;i=l&&h>d;){var m=.5*(l+c);if(i=this.getPoint(m,a),o=Math.atan2(p.y-i.y,p.x-i.x),n=p.distanceToBorder(e,o),s=Math.sqrt(Math.pow(i.x-p.x,2)+Math.pow(i.y-p.y,2)),r=n-s,Math.abs(r)r?f===!1?l=m:c=m:f===!1?c=m:l=m,d++}return i.t=m,i}},{key:"_getDistanceToBezierEdge",value:function(t,e,i,o,n,s,r){var a=1e9,h=void 0,d=void 0,l=void 0,c=void 0,u=void 0,p=t,f=e;for(d=1;10>d;d++)l=.1*d,c=Math.pow(1-l,2)*t+2*l*(1-l)*r.x+Math.pow(l,2)*i,u=Math.pow(1-l,2)*e+2*l*(1-l)*r.y+Math.pow(l,2)*o,d>0&&(h=this._getDistanceToLine(p,f,c,u,n,s),a=a>h?h:a),p=c,f=u;return a}}]),e}(d["default"]);e["default"]=l},function(t,e,i){function o(t,e){if(!(t instanceof e))throw new TypeError("Cannot call a class as a function")}Object.defineProperty(e,"__esModule",{value:!0});var n=function(){function t(t,e){var i=[],o=!0,n=!1,s=void 0;try{for(var r,a=t[Symbol.iterator]();!(o=(r=a.next()).done)&&(i.push(r.value),!e||i.length!==e);o=!0);}catch(h){n=!0,s=h}finally{try{!o&&a["return"]&&a["return"]()}finally{if(n)throw s}}return i}return function(e,i){if(Array.isArray(e))return e;if(Symbol.iterator in Object(e))return t(e,i);throw new TypeError("Invalid attempt to destructure non-iterable instance")}}(),s=function(){function t(t,e){for(var i=0;io.shape.height?(e=o.x+.5*o.shape.width,i=o.y-n):(e=o.x+n,i=o.y-.5*o.shape.height),[e,i,n]}},{key:"_pointOnCircle",value:function(t,e,i,o){var n=2*o*Math.PI;return{x:t+i*Math.cos(n),y:e-i*Math.sin(n)}}},{key:"_findBorderPositionCircle",value:function(t,e,i){for(var o=i.x,n=i.y,s=i.low,r=i.high,a=i.direction,h=10,d=0,l=this.options.selfReferenceSize,c=void 0,u=void 0,p=void 0,f=void 0,m=void 0,v=.05,g=.5*(s+r);r>=s&&h>d&&(g=.5*(s+r),c=this._pointOnCircle(o,n,l,g),u=Math.atan2(t.y-c.y,t.x-c.x),p=t.distanceToBorder(e,u),f=Math.sqrt(Math.pow(c.x-t.x,2)+Math.pow(c.y-t.y,2)), +m=p-f,!(Math.abs(m)0?a>0?s=g:r=g:a>0?r=g:s=g,d++;return c.t=g,c}},{key:"getLineWidth",value:function(t,e){return 
t===!0?Math.max(this.selectionWidth,.3/this.body.view.scale):e===!0?Math.max(this.hoverWidth,.3/this.body.view.scale):Math.max(this.options.width,.3/this.body.view.scale)}},{key:"getColor",value:function(t,e,i){var o=this.options.color;if(o.inherit!==!1){if("both"===o.inherit&&this.from.id!==this.to.id){var n=t.createLinearGradient(this.from.x,this.from.y,this.to.x,this.to.y),s=void 0,a=void 0;return s=this.from.options.color.highlight.border,a=this.to.options.color.highlight.border,this.from.selected===!1&&this.to.selected===!1?(s=r.overrideOpacity(this.from.options.color.border,this.options.color.opacity),a=r.overrideOpacity(this.to.options.color.border,this.options.color.opacity)):this.from.selected===!0&&this.to.selected===!1?a=this.to.options.color.border:this.from.selected===!1&&this.to.selected===!0&&(s=this.from.options.color.border),n.addColorStop(0,s),n.addColorStop(1,a),n}this.colorDirty===!0&&("to"===o.inherit?(this.color.highlight=this.to.options.color.highlight.border,this.color.hover=this.to.options.color.hover.border,this.color.color=r.overrideOpacity(this.to.options.color.border,o.opacity)):(this.color.highlight=this.from.options.color.highlight.border,this.color.hover=this.from.options.color.hover.border,this.color.color=r.overrideOpacity(this.from.options.color.border,o.opacity)))}else this.colorDirty===!0&&(this.color.highlight=o.highlight,this.color.hover=o.hover,this.color.color=r.overrideOpacity(o.color,o.opacity));return this.colorDirty=!1,e===!0?this.color.highlight:i===!0?this.color.hover:this.color.color}},{key:"_circle",value:function(t,e,i,o){this.enableShadow(t),t.beginPath(),t.arc(e,i,o,0,2*Math.PI,!1),t.stroke(),this.disableShadow(t)}},{key:"getDistanceToEdge",value:function(t,e,i,o,s,r,a){var h=0;if(this.from!=this.to)h=this._getDistanceToEdge(t,e,i,o,s,r,a);else{var d=this._getCircleData(),l=n(d,3),c=l[0],u=l[1],p=l[2],f=c-s,m=u-r;h=Math.abs(Math.sqrt(f*f+m*m)-p)}return this.labelModule.size.lefts&&this.labelModule.size.topr?0:h}},{key:"_getDistanceToLine",value:function(t,e,i,o,n,s){var r=i-t,a=o-e,h=r*r+a*a,d=((n-t)*r+(s-e)*a)/h;d>1?d=1:0>d&&(d=0);var l=t+d*r,c=e+d*a,u=l-n,p=c-s;return Math.sqrt(u*u+p*p)}},{key:"getArrowData",value:function(t,e,i,o,s){var r=void 0,a=void 0,h=void 0,d=void 0,l=void 0,c=void 0,u=this.getLineWidth(o,s);if("from"===e?(h=this.from,d=this.to,l=.1,c=this.options.arrows.from.scaleFactor):"to"===e?(h=this.to,d=this.from,l=-.1,c=this.options.arrows.to.scaleFactor):(h=this.to,d=this.from,c=this.options.arrows.middle.scaleFactor),h!=d)if("middle"!==e)if(this.options.smooth.enabled===!0){a=this.findBorderPosition(h,t,{via:i});var p=this.getPoint(Math.max(0,Math.min(1,a.t+l)),i);r=Math.atan2(a.y-p.y,a.x-p.x)}else r=Math.atan2(h.y-d.y,h.x-d.x),a=this.findBorderPosition(h,t);else r=Math.atan2(h.y-d.y,h.x-d.x),a=this.getPoint(.5,i);else{var f=this._getCircleData(t),m=n(f,3),v=m[0],g=m[1],y=m[2];"from"===e?(a=this.findBorderPosition(this.from,t,{x:v,y:g,low:.25,high:.6,direction:-1}),r=-2*a.t*Math.PI+1.5*Math.PI+.1*Math.PI):"to"===e?(a=this.findBorderPosition(this.from,t,{x:v,y:g,low:.6,high:1,direction:1}),r=-2*a.t*Math.PI+1.5*Math.PI-1.1*Math.PI):(a=this._pointOnCircle(v,g,y,.175),r=3.9269908169872414)}var 
b=15*c+3*u,w=a.x-.9*b*Math.cos(r),_=a.y-.9*b*Math.sin(r),x={x:w,y:_};return{point:a,core:x,angle:r,length:b}}},{key:"drawArrowHead",value:function(t,e,i,o){t.strokeStyle=this.getColor(t,e,i),t.fillStyle=t.strokeStyle,t.lineWidth=this.getLineWidth(e,i),t.arrow(o.point.x,o.point.y,o.angle,o.length),this.enableShadow(t),t.fill(),this.disableShadow(t)}},{key:"enableShadow",value:function(t){this.options.shadow.enabled===!0&&(t.shadowColor=this.options.shadow.color,t.shadowBlur=this.options.shadow.size,t.shadowOffsetX=this.options.shadow.x,t.shadowOffsetY=this.options.shadow.y)}},{key:"disableShadow",value:function(t){this.options.shadow.enabled===!0&&(t.shadowColor="rgba(0,0,0,0)",t.shadowBlur=0,t.shadowOffsetX=0,t.shadowOffsetY=0)}}]),t}();e["default"]=a},function(t,e,i){function o(t){return t&&t.__esModule?t:{"default":t}}function n(t,e){if(!(t instanceof e))throw new TypeError("Cannot call a class as a function")}function s(t,e){if(!t)throw new ReferenceError("this hasn't been initialised - super() hasn't been called");return!e||"object"!=typeof e&&"function"!=typeof e?t:e}function r(t,e){if("function"!=typeof e&&null!==e)throw new TypeError("Super expression must either be null or a function, not "+typeof e);t.prototype=Object.create(e&&e.prototype,{constructor:{value:t,enumerable:!1,writable:!0,configurable:!0}}),e&&(Object.setPrototypeOf?Object.setPrototypeOf(t,e):t.__proto__=e)}Object.defineProperty(e,"__esModule",{value:!0});var a=function(){function t(t,e){for(var i=0;i=this.to.y?this.from.x<=this.to.x?(t=this.from.x+i*s,e=this.from.y-i*s):this.from.x>this.to.x&&(t=this.from.x-i*s,e=this.from.y-i*s):this.from.ythis.to.x&&(t=this.from.x-i*s,e=this.from.y+i*s)),"discrete"===o&&(t=i*s>n?this.from.x:t)):Math.abs(this.from.x-this.to.x)>Math.abs(this.from.y-this.to.y)&&(this.from.y>=this.to.y?this.from.x<=this.to.x?(t=this.from.x+i*n,e=this.from.y-i*n):this.from.x>this.to.x&&(t=this.from.x-i*n,e=this.from.y-i*n):this.from.ythis.to.x&&(t=this.from.x-i*n,e=this.from.y+i*n)),"discrete"===o&&(e=i*n>s?this.from.y:e));else if("straightCross"===o)Math.abs(this.from.x-this.to.x)<=Math.abs(this.from.y-this.to.y)?(t=this.from.x,e=this.from.yMath.abs(this.from.y-this.to.y)&&(t=this.from.x=this.to.y?this.from.x<=this.to.x?(t=this.from.x+i*s,e=this.from.y-i*s,t=this.to.xthis.to.x&&(t=this.from.x-i*s,e=this.from.y-i*s,t=this.to.x>t?this.to.x:t):this.from.ythis.to.x&&(t=this.from.x-i*s,e=this.from.y+i*s,t=this.to.x>t?this.to.x:t)):Math.abs(this.from.x-this.to.x)>Math.abs(this.from.y-this.to.y)&&(this.from.y>=this.to.y?this.from.x<=this.to.x?(t=this.from.x+i*n,e=this.from.y-i*n,e=this.to.y>e?this.to.y:e):this.from.x>this.to.x&&(t=this.from.x-i*n,e=this.from.y-i*n,e=this.to.y>e?this.to.y:e):this.from.ythis.to.x&&(t=this.from.x-i*n,e=this.from.y+i*n,e=this.to.y1||this.startedStabilization===!0)&&setTimeout(function(){t.body.emitter.emit("stabilized",{iterations:e}),t.startedStabilization=!1,t.stabilizationIterations=0},0)}},{key:"physicsTick",value:function(){if(this.startedStabilization===!1&&(this.body.emitter.emit("startStabilizing"),this.startedStabilization=!0),this.stabilized===!1){if(this.adaptiveTimestep===!0&&this.adaptiveTimestepEnabled===!0){var 
t=1.2;this.adaptiveCounter%this.adaptiveInterval===0?(this.timestep=2*this.timestep,this.calculateForces(),this.moveNodes(),this.revert(),this.timestep=.5*this.timestep,this.calculateForces(),this.moveNodes(),this.calculateForces(),this.moveNodes(),this._evaluateStepQuality()===!0?this.timestep=t*this.timestep:this.timestep/ts))return!1;return!0}},{key:"moveNodes",value:function(){for(var t=this.physicsBody.physicsNodeIndices,e=this.options.maxVelocity?this.options.maxVelocity:1e9,i=0,o=0,n=5,s=0;se?s[t].x>0?e:-e:s[t].x,i.x+=s[t].x*o}else n[t].x=0,s[t].x=0;if(i.options.fixed.y===!1){var h=this.modelOptions.damping*s[t].y,d=(n[t].y-h)/i.options.mass;s[t].y+=d*o,s[t].y=Math.abs(s[t].y)>e?s[t].y>0?e:-e:s[t].y,i.y+=s[t].y*o}else n[t].y=0,s[t].y=0;var l=Math.sqrt(Math.pow(s[t].x,2)+Math.pow(s[t].y,2));return l}},{key:"calculateForces",value:function(){this.gravitySolver.solve(),this.nodesSolver.solve(),this.edgesSolver.solve()}},{key:"_freezeNodes",value:function(){var t=this.body.nodes;for(var e in t)t.hasOwnProperty(e)&&t[e].x&&t[e].y&&(this.freezeCache[e]={x:t[e].options.fixed.x,y:t[e].options.fixed.y},t[e].options.fixed.x=!0,t[e].options.fixed.y=!0)}},{key:"_restoreFrozenNodes",value:function(){var t=this.body.nodes;for(var e in t)t.hasOwnProperty(e)&&void 0!==this.freezeCache[e]&&(t[e].options.fixed.x=this.freezeCache[e].x,t[e].options.fixed.y=this.freezeCache[e].y);this.freezeCache={}}},{key:"stabilize",value:function(){var t=this,e=arguments.length<=0||void 0===arguments[0]?this.options.stabilization.iterations:arguments[0];return"number"!=typeof e&&(console.log("The stabilize method needs a numeric amount of iterations. Switching to default: ",this.options.stabilization.iterations),e=this.options.stabilization.iterations),0===this.physicsBody.physicsNodeIndices.length?void(this.ready=!0):(this.adaptiveTimestep=this.options.adaptiveTimestep,this.body.emitter.emit("_resizeNodes"),this.stopSimulation(),this.stabilized=!1,this.body.emitter.emit("_blockRedraw"),this.targetIterations=e,this.options.stabilization.onlyDynamicEdges===!0&&this._freezeNodes(),this.stabilizationIterations=0,void setTimeout(function(){return t._stabilizationBatch()},0))}},{key:"_stabilizationBatch",value:function(){this.startedStabilization===!1&&(this.body.emitter.emit("startStabilizing"),this.startedStabilization=!0);for(var t=0;this.stabilized===!1&&t0){var t=void 0,e=this.body.nodes,i=this.physicsBody.physicsNodeIndices,o=i.length,n=this._formBarnesHutTree(e,i);this.barnesHutTree=n;for(var s=0;o>s;s++)t=e[i[s]],t.options.mass>0&&(this._getForceContribution(n.root.children.NW,t),this._getForceContribution(n.root.children.NE,t),this._getForceContribution(n.root.children.SW,t),this._getForceContribution(n.root.children.SE,t))}}},{key:"_getForceContribution",value:function(t,e){if(t.childrenCount>0){var i=void 0,o=void 0,n=void 0;i=t.centerOfMass.x-e.x,o=t.centerOfMass.y-e.y,n=Math.sqrt(i*i+o*o),n*t.calcSize>this.thetaInversed?this._calculateForces(n,i,o,e,t):4===t.childrenCount?(this._getForceContribution(t.children.NW,e),this._getForceContribution(t.children.NE,e),this._getForceContribution(t.children.SW,e),this._getForceContribution(t.children.SE,e)):t.children.data.id!=e.id&&this._calculateForces(n,i,o,e,t)}}},{key:"_calculateForces",value:function(t,e,i,o,n){0===t&&(t=.1,e=t),this.overlapAvoidanceFactor<1&&(t=Math.max(.1+this.overlapAvoidanceFactor*o.shape.radius,t-o.shape.radius));var 
s=this.options.gravitationalConstant*n.mass*o.options.mass/Math.pow(t,3),r=e*s,a=i*s;this.physicsBody.forces[o.id].x+=r,this.physicsBody.forces[o.id].y+=a}},{key:"_formBarnesHutTree",value:function(t,e){for(var i=void 0,o=e.length,n=t[e[0]].x,s=t[e[0]].y,r=t[e[0]].x,a=t[e[0]].y,h=1;o>h;h++){var d=t[e[h]].x,l=t[e[h]].y;t[e[h]].options.mass>0&&(n>d&&(n=d),d>r&&(r=d),s>l&&(s=l),l>a&&(a=l))}var c=Math.abs(r-n)-Math.abs(a-s);c>0?(s-=.5*c,a+=.5*c):(n+=.5*c,r-=.5*c);var u=1e-5,p=Math.max(u,Math.abs(r-n)),f=.5*p,m=.5*(n+r),v=.5*(s+a),g={root:{centerOfMass:{x:0,y:0},mass:0,range:{minX:m-f,maxX:m+f,minY:v-f,maxY:v+f},size:p,calcSize:1/p,children:{data:null},maxWidth:0,level:0,childrenCount:4}};this._splitBranch(g.root);for(var y=0;o>y;y++)i=t[e[y]],i.options.mass>0&&this._placeInTree(g.root,i);return g}},{key:"_updateBranchMass",value:function(t,e){var i=t.mass+e.options.mass,o=1/i;t.centerOfMass.x=t.centerOfMass.x*t.mass+e.x*e.options.mass,t.centerOfMass.x*=o,t.centerOfMass.y=t.centerOfMass.y*t.mass+e.y*e.options.mass,t.centerOfMass.y*=o,t.mass=i;var n=Math.max(Math.max(e.height,e.radius),e.width);t.maxWidth=t.maxWidthe.x?t.children.NW.range.maxY>e.y?this._placeInRegion(t,e,"NW"):this._placeInRegion(t,e,"SW"):t.children.NW.range.maxY>e.y?this._placeInRegion(t,e,"NE"):this._placeInRegion(t,e,"SE")}},{key:"_placeInRegion",value:function(t,e,i){switch(t.children[i].childrenCount){case 0:t.children[i].children.data=e,t.children[i].childrenCount=1,this._updateBranchMass(t.children[i],e);break;case 1:t.children[i].children.data.x===e.x&&t.children[i].children.data.y===e.y?(e.x+=this.seededRandom(),e.y+=this.seededRandom()):(this._splitBranch(t.children[i]),this._placeInTree(t.children[i],e));break;case 4:this._placeInTree(t.children[i],e)}}},{key:"_splitBranch",value:function(t){var e=null;1===t.childrenCount&&(e=t.children.data,t.mass=0,t.centerOfMass.x=0,t.centerOfMass.y=0),t.childrenCount=4,t.children.data=null,this._insertRegion(t,"NW"),this._insertRegion(t,"NE"),this._insertRegion(t,"SW"),this._insertRegion(t,"SE"),null!=e&&this._placeInTree(t,e)}},{key:"_insertRegion",value:function(t,e){var i=void 0,o=void 0,n=void 0,s=void 0,r=.5*t.size;switch(e){case"NW":i=t.range.minX,o=t.range.minX+r,n=t.range.minY,s=t.range.minY+r;break;case"NE":i=t.range.minX+r,o=t.range.maxX,n=t.range.minY,s=t.range.minY+r;break;case"SW":i=t.range.minX,o=t.range.minX+r,n=t.range.minY+r,s=t.range.maxY;break;case"SE":i=t.range.minX+r,o=t.range.maxX,n=t.range.minY+r,s=t.range.maxY}t.children[e]={centerOfMass:{x:0,y:0},mass:0,range:{minX:i,maxX:o,minY:n,maxY:s},size:.5*t.size,calcSize:2*t.calcSize,children:{data:null},maxWidth:0,level:t.level+1,childrenCount:0}}},{key:"_debug",value:function(t,e){void 0!==this.barnesHutTree&&(t.lineWidth=1,this._drawBranch(this.barnesHutTree.root,t,e))}},{key:"_drawBranch",value:function(t,e,i){void 0===i&&(i="#FF0000"),4===t.childrenCount&&(this._drawBranch(t.children.NW,e),this._drawBranch(t.children.NE,e),this._drawBranch(t.children.SE,e),this._drawBranch(t.children.SW,e)),e.strokeStyle=i,e.beginPath(),e.moveTo(t.range.minX,t.range.minY),e.lineTo(t.range.maxX,t.range.minY),e.stroke(),e.beginPath(),e.moveTo(t.range.maxX,t.range.minY),e.lineTo(t.range.maxX,t.range.maxY),e.stroke(),e.beginPath(),e.moveTo(t.range.maxX,t.range.maxY), +e.lineTo(t.range.minX,t.range.maxY),e.stroke(),e.beginPath(),e.moveTo(t.range.minX,t.range.maxY),e.lineTo(t.range.minX,t.range.minY),e.stroke()}}]),t}();e["default"]=n},function(t,e){function i(t,e){if(!(t instanceof e))throw new TypeError("Cannot call a class as a 
function")}Object.defineProperty(e,"__esModule",{value:!0});var o=function(){function t(t,e){for(var i=0;ii&&(s=.5*c>i?1:u*i+p,s/=i,o=t*s,n=e*s,l[r.id].x-=o,l[r.id].y-=n,l[a.id].x+=o,l[a.id].y+=n)}}}]),t}();e["default"]=n},function(t,e){function i(t,e){if(!(t instanceof e))throw new TypeError("Cannot call a class as a function")}Object.defineProperty(e,"__esModule",{value:!0});var o=function(){function t(t,e){for(var i=0;ii?-Math.pow(f*i,2)+Math.pow(f*p,2):0,0===i?i=.01:s/=i,o=t*s,n=e*s,u[r.id].x-=o,u[r.id].y-=n,u[a.id].x+=o,u[a.id].y+=n}}}]),t}();e["default"]=n},function(t,e){function i(t,e){if(!(t instanceof e))throw new TypeError("Cannot call a class as a function")}Object.defineProperty(e,"__esModule",{value:!0});var o=function(){function t(t,e){for(var i=0;i0){var s=n.edges.length+1,r=this.options.centralGravity*s*n.options.mass;o[n.id].x=e*r,o[n.id].y=i*r}}}]),e}(d["default"]);e["default"]=l},function(t,e,i){function o(t){return t&&t.__esModule?t:{"default":t}}function n(t,e){if(!(t instanceof e))throw new TypeError("Cannot call a class as a function")}Object.defineProperty(e,"__esModule",{value:!0});var s="function"==typeof Symbol&&"symbol"==typeof Symbol.iterator?function(t){return typeof t}:function(t){return t&&"function"==typeof Symbol&&t.constructor===Symbol?"symbol":typeof t},r=function(){function t(t,e){for(var i=0;i=t&&i.push(n.id)}for(var r=0;r0&&Object.keys(p).length>0&&m===!0&&o.push({nodes:u,edges:p})}}}for(var b=0;bo?r.x:o,n=r.ys?r.y:s;return{x:.5*(i+o),y:.5*(n+s)}}},{key:"openCluster",value:function(t,e){var i=arguments.length<=2||void 0===arguments[2]?!0:arguments[2];if(void 0===t)throw new Error("No clusterNodeId supplied to openCluster.");if(void 0===this.body.nodes[t])throw new Error("The clusterNodeId supplied to openCluster does not exist.");if(void 0===this.body.nodes[t].containedNodes)return void console.log("The node:"+t+" is not a cluster.");var o=this.body.nodes[t],n=o.containedNodes,s=o.containedEdges;if(void 0!==e&&void 0!==e.releaseFunction&&"function"==typeof e.releaseFunction){var r={},a={x:o.x,y:o.y};for(var d in n)if(n.hasOwnProperty(d)){var l=this.body.nodes[d];r[d]={x:l.x,y:l.y}}var u=e.releaseFunction(a,r);for(var p in n)if(n.hasOwnProperty(p)){var f=this.body.nodes[p];void 0!==u[p]&&(f.x=void 0===u[p].x?o.x:u[p].x,f.y=void 0===u[p].y?o.y:u[p].y)}}else for(var m in n)if(n.hasOwnProperty(m)){var v=this.body.nodes[m];v=n[m],v.options.fixed.x===!1&&(v.x=o.x),v.options.fixed.y===!1&&(v.y=o.y)}for(var g in n)if(n.hasOwnProperty(g)){var y=this.body.nodes[g];y.vx=o.vx,y.vy=o.vy,y.setOptions({hidden:!1,physics:!0}),delete this.clusteredNodes[g]}for(var b=[],w=0;wo;)e.push(this.body.nodes[t].id),t=this.clusteredNodes[t].clusterId,o++;return e.push(this.body.nodes[t].id),e.reverse(),e}},{key:"_getConnectedId",value:function(t,e){return t.toId!=e?t.toId:t.fromId!=e?t.fromId:t.fromId}},{key:"_getHubSize",value:function(){for(var t=0,e=0,i=0,o=0,n=0;no&&(o=s.edges.length),t+=s.edges.length,e+=Math.pow(s.edges.length,2),i+=1}t/=i,e/=i;var r=e-Math.pow(t,2),a=Math.sqrt(r),h=Math.floor(t+2*a);return h>o&&(h=o),h}}]),t}();e["default"]=u},function(t,e,i){function o(t,e){if(!(t instanceof e))throw new TypeError("Cannot call a class as a function")}Object.defineProperty(e,"__esModule",{value:!0});var n=function(){function t(t,e){for(var i=0;i0)for(var a=0;ae.shape.boundingBox.left&&(s=e.shape.boundingBox.left),re.shape.boundingBox.top&&(o=e.shape.boundingBox.top),n0)for(var a=0;ae.x&&(s=e.x),re.y&&(o=e.y),n0,t.renderTimer=void 
0}),this.body.emitter.on("destroy",function(){t.renderRequests=0,t.allowRedraw=!1,t.renderingActive=!1,t.requiresTimeout===!0?clearTimeout(t.renderTimer):cancelAnimationFrame(t.renderTimer),t.body.emitter.off()})}},{key:"setOptions",value:function(t){if(void 0!==t){var e=["hideEdgesOnDrag","hideNodesOnDrag"];s.selectiveDeepExtend(e,this.options,t)}}},{key:"_startRendering",value:function(){this.renderingActive===!0&&void 0===this.renderTimer&&(this.requiresTimeout===!0?this.renderTimer=window.setTimeout(this._renderStep.bind(this),this.simulationInterval):this.renderTimer=window.requestAnimationFrame(this._renderStep.bind(this)))}},{key:"_renderStep",value:function(){this.renderingActive===!0&&(this.renderTimer=void 0,this.requiresTimeout===!0&&this._startRendering(),this._redraw(),this.requiresTimeout===!1&&this._startRendering())}},{key:"redraw",value:function(){this.body.emitter.emit("setSize"),this._redraw()}},{key:"_requestRedraw",value:function(){var t=this;this.redrawRequested!==!0&&this.renderingActive===!1&&this.allowRedraw===!0&&(this.redrawRequested=!0,this.requiresTimeout===!0?window.setTimeout(function(){t._redraw(!1)},0):window.requestAnimationFrame(function(){t._redraw(!1)}))}},{key:"_redraw",value:function(){var t=arguments.length<=0||void 0===arguments[0]?!1:arguments[0];if(this.allowRedraw===!0){this.body.emitter.emit("initRedraw"),this.redrawRequested=!1;var e=this.canvas.frame.canvas.getContext("2d");0!==this.canvas.frame.canvas.width&&0!==this.canvas.frame.canvas.height||this.canvas.setSize(),this.pixelRatio=(window.devicePixelRatio||1)/(e.webkitBackingStorePixelRatio||e.mozBackingStorePixelRatio||e.msBackingStorePixelRatio||e.oBackingStorePixelRatio||e.backingStorePixelRatio||1),e.setTransform(this.pixelRatio,0,0,this.pixelRatio,0,0);var i=this.canvas.frame.canvas.clientWidth,o=this.canvas.frame.canvas.clientHeight;if(e.clearRect(0,0,i,o),0===this.canvas.frame.clientWidth)return;e.save(),e.translate(this.body.view.translation.x,this.body.view.translation.y),e.scale(this.body.view.scale,this.body.view.scale),e.beginPath(),this.body.emitter.emit("beforeDrawing",e),e.closePath(),t===!1&&(this.dragging===!1||this.dragging===!0&&this.options.hideEdgesOnDrag===!1)&&this._drawEdges(e),(this.dragging===!1||this.dragging===!0&&this.options.hideNodesOnDrag===!1)&&this._drawNodes(e,t),e.beginPath(),this.body.emitter.emit("afterDrawing",e),e.closePath(),e.restore(),t===!0&&e.clearRect(0,0,i,o)}}},{key:"_resizeNodes",value:function(){var t=this.canvas.frame.canvas.getContext("2d");void 0===this.pixelRatio&&(this.pixelRatio=(window.devicePixelRatio||1)/(t.webkitBackingStorePixelRatio||t.mozBackingStorePixelRatio||t.msBackingStorePixelRatio||t.oBackingStorePixelRatio||t.backingStorePixelRatio||1)),t.setTransform(this.pixelRatio,0,0,this.pixelRatio,0,0),t.save(),t.translate(this.body.view.translation.x,this.body.view.translation.y),t.scale(this.body.view.scale,this.body.view.scale);var e=this.body.nodes,i=void 0;for(var o in e)e.hasOwnProperty(o)&&(i=e[o],i.resize(t),i.updateBoundingBox(t,i.selected));t.restore()}},{key:"_drawNodes",value:function(t){for(var e=arguments.length<=1||void 0===arguments[1]?!1:arguments[1],i=this.body.nodes,o=this.body.nodeIndices,n=void 0,s=[],r=20,a=this.canvas.DOMtoCanvas({x:-r,y:-r}),h=this.canvas.DOMtoCanvas({x:this.canvas.frame.canvas.clientWidth+r,y:this.canvas.frame.canvas.clientHeight+r}),d={top:a.y,left:a.x,bottom:h.y,right:h.x},l=0;l0){var 
t=this.frame.canvas.width/this.pixelRatio/this.cameraState.previousWidth,e=this.frame.canvas.height/this.pixelRatio/this.cameraState.previousHeight,i=this.cameraState.scale;1!=t&&1!=e?i=.5*this.cameraState.scale*(t+e):1!=t?i=this.cameraState.scale*t:1!=e&&(i=this.cameraState.scale*e),this.body.view.scale=i;var o=this.DOMtoCanvas({x:.5*this.frame.canvas.clientWidth,y:.5*this.frame.canvas.clientHeight}),n={x:o.x-this.cameraState.position.x,y:o.y-this.cameraState.position.y};this.body.view.translation.x+=n.x*this.body.view.scale,this.body.view.translation.y+=n.y*this.body.view.scale}}},{key:"_prepareValue",value:function(t){if("number"==typeof t)return t+"px";if("string"==typeof t){if(-1!==t.indexOf("%")||-1!==t.indexOf("px"))return t;if(-1===t.indexOf("%"))return t+"px"}throw new Error("Could not use the value supplied for width or height:"+t)}},{key:"_create",value:function(){for(;this.body.container.hasChildNodes();)this.body.container.removeChild(this.body.container.firstChild);if(this.frame=document.createElement("div"),this.frame.className="vis-network",this.frame.style.position="relative",this.frame.style.overflow="hidden",this.frame.tabIndex=900,this.frame.canvas=document.createElement("canvas"),this.frame.canvas.style.position="relative",this.frame.appendChild(this.frame.canvas),this.frame.canvas.getContext){var t=this.frame.canvas.getContext("2d");this.pixelRatio=(window.devicePixelRatio||1)/(t.webkitBackingStorePixelRatio||t.mozBackingStorePixelRatio||t.msBackingStorePixelRatio||t.oBackingStorePixelRatio||t.backingStorePixelRatio||1), +this.frame.canvas.getContext("2d").setTransform(this.pixelRatio,0,0,this.pixelRatio,0,0)}else{var e=document.createElement("DIV");e.style.color="red",e.style.fontWeight="bold",e.style.padding="10px",e.innerHTML="Error: your browser does not support HTML canvas",this.frame.canvas.appendChild(e)}this.body.container.appendChild(this.frame),this.body.view.scale=1,this.body.view.translation={x:.5*this.frame.canvas.clientWidth,y:.5*this.frame.canvas.clientHeight},this._bindHammer()}},{key:"_bindHammer",value:function(){var t=this;void 0!==this.hammer&&this.hammer.destroy(),this.drag={},this.pinch={},this.hammer=new s(this.frame.canvas),this.hammer.get("pinch").set({enable:!0}),this.hammer.get("pan").set({threshold:5,direction:s.DIRECTION_ALL}),r.onTouch(this.hammer,function(e){t.body.eventListeners.onTouch(e)}),this.hammer.on("tap",function(e){t.body.eventListeners.onTap(e)}),this.hammer.on("doubletap",function(e){t.body.eventListeners.onDoubleTap(e)}),this.hammer.on("press",function(e){t.body.eventListeners.onHold(e)}),this.hammer.on("panstart",function(e){t.body.eventListeners.onDragStart(e)}),this.hammer.on("panmove",function(e){t.body.eventListeners.onDrag(e)}),this.hammer.on("panend",function(e){t.body.eventListeners.onDragEnd(e)}),this.hammer.on("pinch",function(e){t.body.eventListeners.onPinch(e)}),this.frame.canvas.addEventListener("mousewheel",function(e){t.body.eventListeners.onMouseWheel(e)}),this.frame.canvas.addEventListener("DOMMouseScroll",function(e){t.body.eventListeners.onMouseWheel(e)}),this.frame.canvas.addEventListener("mousemove",function(e){t.body.eventListeners.onMouseMove(e)}),this.frame.canvas.addEventListener("contextmenu",function(e){t.body.eventListeners.onContext(e)}),this.hammerFrame=new s(this.frame),r.onRelease(this.hammerFrame,function(e){t.body.eventListeners.onRelease(e)})}},{key:"setSize",value:function(){var t=arguments.length<=0||void 0===arguments[0]?this.options.width:arguments[0],e=arguments.length<=1||void 
0===arguments[1]?this.options.height:arguments[1];t=this._prepareValue(t),e=this._prepareValue(e);var i=!1,o=this.frame.canvas.width,n=this.frame.canvas.height,s=this.frame.canvas.getContext("2d"),r=this.pixelRatio;return this.pixelRatio=(window.devicePixelRatio||1)/(s.webkitBackingStorePixelRatio||s.mozBackingStorePixelRatio||s.msBackingStorePixelRatio||s.oBackingStorePixelRatio||s.backingStorePixelRatio||1),t!=this.options.width||e!=this.options.height||this.frame.style.width!=t||this.frame.style.height!=e?(this._getCameraState(r),this.frame.style.width=t,this.frame.style.height=e,this.frame.canvas.style.width="100%",this.frame.canvas.style.height="100%",this.frame.canvas.width=Math.round(this.frame.canvas.clientWidth*this.pixelRatio),this.frame.canvas.height=Math.round(this.frame.canvas.clientHeight*this.pixelRatio),this.options.width=t,this.options.height=e,i=!0):(this.frame.canvas.width==Math.round(this.frame.canvas.clientWidth*this.pixelRatio)&&this.frame.canvas.height==Math.round(this.frame.canvas.clientHeight*this.pixelRatio)||this._getCameraState(r),this.frame.canvas.width!=Math.round(this.frame.canvas.clientWidth*this.pixelRatio)&&(this.frame.canvas.width=Math.round(this.frame.canvas.clientWidth*this.pixelRatio),i=!0),this.frame.canvas.height!=Math.round(this.frame.canvas.clientHeight*this.pixelRatio)&&(this.frame.canvas.height=Math.round(this.frame.canvas.clientHeight*this.pixelRatio),i=!0)),i===!0&&(this.body.emitter.emit("resize",{width:Math.round(this.frame.canvas.width/this.pixelRatio),height:Math.round(this.frame.canvas.height/this.pixelRatio),oldWidth:Math.round(o/this.pixelRatio),oldHeight:Math.round(n/this.pixelRatio)}),this._setCameraState()),this.initialized=!0,i}},{key:"_XconvertDOMtoCanvas",value:function(t){return(t-this.body.view.translation.x)/this.body.view.scale}},{key:"_XconvertCanvasToDOM",value:function(t){return t*this.body.view.scale+this.body.view.translation.x}},{key:"_YconvertDOMtoCanvas",value:function(t){return(t-this.body.view.translation.y)/this.body.view.scale}},{key:"_YconvertCanvasToDOM",value:function(t){return t*this.body.view.scale+this.body.view.translation.y}},{key:"canvasToDOM",value:function(t){return{x:this._XconvertCanvasToDOM(t.x),y:this._YconvertCanvasToDOM(t.y)}}},{key:"DOMtoCanvas",value:function(t){return{x:this._XconvertDOMtoCanvas(t.x),y:this._YconvertDOMtoCanvas(t.y)}}}]),t}();e["default"]=h},function(t,e,i){function o(t){return t&&t.__esModule?t:{"default":t}}function n(t,e){if(!(t instanceof e))throw new TypeError("Cannot call a class as a function")}Object.defineProperty(e,"__esModule",{value:!0});var s=function(){function t(t,e){for(var i=0;i.5*this.body.nodeIndices.length)return void this.fit(t,!1);i=a["default"].getRange(this.body.nodes,t.nodes);var h=this.body.nodeIndices.length;o=12.662/(h+7.4147)+.0964822;var d=Math.min(this.canvas.frame.canvas.clientWidth/600,this.canvas.frame.canvas.clientHeight/600);o*=d}else{this.body.emitter.emit("_resizeNodes"),i=a["default"].getRange(this.body.nodes,t.nodes);var l=1.1*Math.abs(i.maxX-i.minX),c=1.1*Math.abs(i.maxY-i.minY),u=this.canvas.frame.canvas.clientWidth/l,p=this.canvas.frame.canvas.clientHeight/c;o=p>=u?u:p}o>1?o=1:0===o&&(o=1);var f=a["default"].findCenter(i),m={position:f,scale:o,animation:t.animation};this.moveTo(m)}},{key:"focus",value:function(t){var e=arguments.length<=1||void 0===arguments[1]?{}:arguments[1];if(void 0!==this.body.nodes[t]){var i={x:this.body.nodes[t].x,y:this.body.nodes[t].y};e.position=i,e.lockedOnNode=t,this.moveTo(e)}else console.log("Node: "+t+" 
cannot be found.")}},{key:"moveTo",value:function(t){return void 0===t?void(t={}):(void 0===t.offset&&(t.offset={x:0,y:0}),void 0===t.offset.x&&(t.offset.x=0),void 0===t.offset.y&&(t.offset.y=0),void 0===t.scale&&(t.scale=this.body.view.scale),void 0===t.position&&(t.position=this.getViewPosition()),void 0===t.animation&&(t.animation={duration:0}),t.animation===!1&&(t.animation={duration:0}),t.animation===!0&&(t.animation={}),void 0===t.animation.duration&&(t.animation.duration=1e3),void 0===t.animation.easingFunction&&(t.animation.easingFunction="easeInOutQuad"),void this.animateView(t))}},{key:"animateView",value:function(t){if(void 0!==t){this.animationEasingFunction=t.animation.easingFunction,this.releaseNode(),t.locked===!0&&(this.lockedOnNodeId=t.lockedOnNode,this.lockedOnNodeOffset=t.offset),0!=this.easingTime&&this._transitionRedraw(!0),this.sourceScale=this.body.view.scale,this.sourceTranslation=this.body.view.translation,this.targetScale=t.scale,this.body.view.scale=this.targetScale;var e=this.canvas.DOMtoCanvas({x:.5*this.canvas.frame.canvas.clientWidth,y:.5*this.canvas.frame.canvas.clientHeight}),i={x:e.x-t.position.x,y:e.y-t.position.y};this.targetTranslation={x:this.sourceTranslation.x+i.x*this.targetScale+t.offset.x,y:this.sourceTranslation.y+i.y*this.targetScale+t.offset.y},0===t.animation.duration?void 0!=this.lockedOnNodeId?(this.viewFunction=this._lockedRedraw.bind(this),this.body.emitter.on("initRedraw",this.viewFunction)):(this.body.view.scale=this.targetScale,this.body.view.translation=this.targetTranslation,this.body.emitter.emit("_requestRedraw")):(this.animationSpeed=1/(60*t.animation.duration*.001)||1/60,this.animationEasingFunction=t.animation.easingFunction,this.viewFunction=this._transitionRedraw.bind(this),this.body.emitter.on("initRedraw",this.viewFunction),this.body.emitter.emit("_startRendering"))}}},{key:"_lockedRedraw",value:function(){var t={x:this.body.nodes[this.lockedOnNodeId].x,y:this.body.nodes[this.lockedOnNodeId].y},e=this.canvas.DOMtoCanvas({x:.5*this.canvas.frame.canvas.clientWidth,y:.5*this.canvas.frame.canvas.clientHeight}),i={x:e.x-t.x,y:e.y-t.y},o=this.body.view.translation,n={x:o.x+i.x*this.body.view.scale+this.lockedOnNodeOffset.x,y:o.y+i.y*this.body.view.scale+this.lockedOnNodeOffset.y};this.body.view.translation=n}},{key:"releaseNode",value:function(){void 0!==this.lockedOnNodeId&&void 0!==this.viewFunction&&(this.body.emitter.off("initRedraw",this.viewFunction),this.lockedOnNodeId=void 0,this.lockedOnNodeOffset=void 0)}},{key:"_transitionRedraw",value:function(){var t=arguments.length<=0||void 0===arguments[0]?!1:arguments[0];this.easingTime+=this.animationSpeed,this.easingTime=t===!0?1:this.easingTime;var e=h.easingFunctions[this.animationEasingFunction](this.easingTime);this.body.view.scale=this.sourceScale+(this.targetScale-this.sourceScale)*e,this.body.view.translation={x:this.sourceTranslation.x+(this.targetTranslation.x-this.sourceTranslation.x)*e,y:this.sourceTranslation.y+(this.targetTranslation.y-this.sourceTranslation.y)*e},this.easingTime>=1&&(this.body.emitter.off("initRedraw",this.viewFunction),this.easingTime=0,void 0!=this.lockedOnNodeId&&(this.viewFunction=this._lockedRedraw.bind(this),this.body.emitter.on("initRedraw",this.viewFunction)),this.body.emitter.emit("animationFinished"))}},{key:"getScale",value:function(){return this.body.view.scale}},{key:"getViewPosition",value:function(){return 
this.canvas.DOMtoCanvas({x:.5*this.canvas.frame.canvas.clientWidth,y:.5*this.canvas.frame.canvas.clientHeight})}}]),t}();e["default"]=d},function(t,e,i){function o(t){return t&&t.__esModule?t:{"default":t}}function n(t,e){if(!(t instanceof e))throw new TypeError("Cannot call a class as a function")}Object.defineProperty(e,"__esModule",{value:!0});var s=function(){function t(t,e){for(var i=0;i50&&(this.drag.pointer=this.getPointer(t.center),this.drag.pinched=!1,this.pinch.scale=this.body.view.scale,this.touchTime=(new Date).valueOf())}},{key:"onTap",value:function(t){var e=this.getPointer(t.center),i=this.selectionHandler.options.multiselect&&(t.changedPointers[0].ctrlKey||t.changedPointers[0].metaKey);this.checkSelectionChanges(e,t,i),this.selectionHandler._generateClickEvent("click",t,e)}},{key:"onDoubleTap",value:function(t){var e=this.getPointer(t.center);this.selectionHandler._generateClickEvent("doubleClick",t,e)}},{key:"onHold",value:function(t){var e=this.getPointer(t.center),i=this.selectionHandler.options.multiselect;this.checkSelectionChanges(e,t,i),this.selectionHandler._generateClickEvent("click",t,e),this.selectionHandler._generateClickEvent("hold",t,e)}},{key:"onRelease",value:function(t){if((new Date).valueOf()-this.touchTime>10){var e=this.getPointer(t.center);this.selectionHandler._generateClickEvent("release",t,e),this.touchTime=(new Date).valueOf()}}},{key:"onContext",value:function(t){var e=this.getPointer({x:t.clientX,y:t.clientY});this.selectionHandler._generateClickEvent("oncontext",t,e)}},{key:"checkSelectionChanges",value:function(t,e){var i=arguments.length<=2||void 0===arguments[2]?!1:arguments[2],o=this.selectionHandler._getSelectedEdgeCount(),n=this.selectionHandler._getSelectedNodeCount(),s=this.selectionHandler.getSelection(),r=void 0;r=i===!0?this.selectionHandler.selectAdditionalOnPoint(t):this.selectionHandler.selectOnPoint(t);var a=this.selectionHandler._getSelectedEdgeCount(),h=this.selectionHandler._getSelectedNodeCount(),d=this.selectionHandler.getSelection(),l=this._determineIfDifferent(s,d),c=l.nodesChanged,u=l.edgesChanged,p=!1;h-n>0?(this.selectionHandler._generateClickEvent("selectNode",e,t),r=!0,p=!0):c===!0&&h>0?(this.selectionHandler._generateClickEvent("deselectNode",e,t,s),this.selectionHandler._generateClickEvent("selectNode",e,t),p=!0,r=!0):0>h-n&&(this.selectionHandler._generateClickEvent("deselectNode",e,t,s),r=!0),a-o>0&&p===!1?(this.selectionHandler._generateClickEvent("selectEdge",e,t),r=!0):a>0&&u===!0?(this.selectionHandler._generateClickEvent("deselectEdge",e,t,s),this.selectionHandler._generateClickEvent("selectEdge",e,t),r=!0):0>a-o&&(this.selectionHandler._generateClickEvent("deselectEdge",e,t,s),r=!0),r===!0&&this.selectionHandler._generateClickEvent("select",e,t)}},{key:"_determineIfDifferent",value:function(t,e){for(var i=!1,o=!1,n=0;nt&&(t=1e-5),t>10&&(t=10);var o=void 0;void 0!==this.drag&&this.drag.dragging===!0&&(o=this.canvas.DOMtoCanvas(this.drag.pointer));var n=this.body.view.translation,s=t/i,r=(1-s)*e.x+n.x*s,a=(1-s)*e.y+n.y*s;if(this.body.view.scale=t,this.body.view.translation={x:r,y:a},void 0!=o){var h=this.canvas.canvasToDOM(o);this.drag.pointer.x=h.x,this.drag.pointer.y=h.y}this.body.emitter.emit("_requestRedraw"),t>i?this.body.emitter.emit("zoom",{direction:"+",scale:this.body.view.scale}):this.body.emitter.emit("zoom",{direction:"-",scale:this.body.view.scale})}}},{key:"onMouseWheel",value:function(t){if(this.options.zoomView===!0){var e=0;if(t.wheelDelta?e=t.wheelDelta/120:t.detail&&(e=-t.detail/3),0!==e){var 
i=this.body.view.scale,o=e/10;0>e&&(o/=1-o),i*=1+o;var n=this.getPointer({x:t.clientX,y:t.clientY});this.zoom(i,n)}t.preventDefault()}}},{key:"onMouseMove",value:function(t){var e=this,i=this.getPointer({x:t.clientX,y:t.clientY}),o=!1;if(void 0!==this.popup&&(this.popup.hidden===!1&&this._checkHidePopup(i),this.popup.hidden===!1&&(o=!0,this.popup.setPosition(i.x+3,i.y-5),this.popup.show())),this.options.keyboard.bindToWindow===!1&&this.options.keyboard.enabled===!0&&this.canvas.frame.focus(),o===!1&&(void 0!==this.popupTimer&&(clearInterval(this.popupTimer),this.popupTimer=void 0),this.drag.dragging||(this.popupTimer=setTimeout(function(){return e._checkShowPopup(i)},this.options.tooltipDelay))),this.options.hover===!0){var n=this.selectionHandler.getNodeAt(i);void 0===n&&(n=this.selectionHandler.getEdgeAt(i)),this.selectionHandler.hoverObject(n)}}},{key:"_checkShowPopup",value:function(t){var e=this.canvas._XconvertDOMtoCanvas(t.x),i=this.canvas._YconvertDOMtoCanvas(t.y),o={left:e,top:i,right:e,bottom:i},n=void 0===this.popupObj?void 0:this.popupObj.id,s=!1,r="node";if(void 0===this.popupObj){for(var a=this.body.nodeIndices,h=this.body.nodes,l=void 0,c=[],u=0;u0&&(this.popupObj=h[c[c.length-1]],s=!0)}if(void 0===this.popupObj&&s===!1){for(var p=this.body.edgeIndices,f=this.body.edges,m=void 0,v=[],g=0;g0&&(this.popupObj=f[v[v.length-1]],r="edge")}void 0!==this.popupObj?this.popupObj.id!==n&&(void 0===this.popup&&(this.popup=new d["default"](this.canvas.frame)),this.popup.popupTargetType=r,this.popup.popupTargetId=this.popupObj.id,this.popup.setPosition(t.x+3,t.y-5),this.popup.setText(this.popupObj.getTitle()),this.popup.show(),this.body.emitter.emit("showPopup",this.popupObj.id)):void 0!==this.popup&&(this.popup.hide(),this.body.emitter.emit("hidePopup"))}},{key:"_checkHidePopup",value:function(t){var e=this.selectionHandler._pointerToPositionObject(t),i=!1;if("node"===this.popup.popupTargetType){if(void 0!==this.body.nodes[this.popup.popupTargetId]&&(i=this.body.nodes[this.popup.popupTargetId].isOverlappingWith(e),i===!0)){var o=this.selectionHandler.getNodeAt(t);i=o.id===this.popup.popupTargetId}}else void 0===this.selectionHandler.getNodeAt(t)&&void 0!==this.body.edges[this.popup.popupTargetId]&&(i=this.body.edges[this.popup.popupTargetId].isOverlappingWith(e));i===!1&&(this.popupObj=void 0,this.popup.hide(),this.body.emitter.emit("hidePopup"))}}]),t}();e["default"]=c},function(t,e,i){function o(t,e){if(!(t instanceof e))throw new TypeError("Cannot call a class as a function")}Object.defineProperty(e,"__esModule",{value:!0});var n=function(){function t(t,e){for(var i=0;i700&&(this.body.emitter.emit("fit",{duration:700}),this.touchTime=(new Date).valueOf())}},{key:"_stopMovement",value:function(){for(var t in 
this.boundFunctions)this.boundFunctions.hasOwnProperty(t)&&(this.body.emitter.off("initRedraw",this.boundFunctions[t]),this.body.emitter.emit("_stopRendering"));this.boundFunctions={}}},{key:"_moveUp",value:function(){this.body.view.translation.y+=this.options.keyboard.speed.y}},{key:"_moveDown",value:function(){this.body.view.translation.y-=this.options.keyboard.speed.y}},{key:"_moveLeft",value:function(){this.body.view.translation.x+=this.options.keyboard.speed.x}},{key:"_moveRight",value:function(){this.body.view.translation.x-=this.options.keyboard.speed.x}},{key:"_zoomIn",value:function(){this.body.view.scale*=1+this.options.keyboard.speed.zoom,this.body.emitter.emit("zoom",{direction:"+",scale:this.body.view.scale})}},{key:"_zoomOut",value:function(){this.body.view.scale/=1+this.options.keyboard.speed.zoom,this.body.emitter.emit("zoom",{direction:"-",scale:this.body.view.scale})}},{key:"configureKeyboardBindings",value:function(){var t=this;void 0!==this.keycharm&&this.keycharm.destroy(),this.options.keyboard.enabled===!0&&(this.options.keyboard.bindToWindow===!0?this.keycharm=a({container:window,preventDefault:!0}):this.keycharm=a({container:this.canvas.frame,preventDefault:!0}),this.keycharm.reset(),this.activated===!0&&(this.keycharm.bind("up",function(){t.bindToRedraw("_moveUp")},"keydown"),this.keycharm.bind("down",function(){t.bindToRedraw("_moveDown")},"keydown"),this.keycharm.bind("left",function(){t.bindToRedraw("_moveLeft")},"keydown"),this.keycharm.bind("right",function(){t.bindToRedraw("_moveRight")},"keydown"),this.keycharm.bind("=",function(){t.bindToRedraw("_zoomIn")},"keydown"),this.keycharm.bind("num+",function(){t.bindToRedraw("_zoomIn")},"keydown"),this.keycharm.bind("num-",function(){t.bindToRedraw("_zoomOut")},"keydown"),this.keycharm.bind("-",function(){t.bindToRedraw("_zoomOut")},"keydown"),this.keycharm.bind("[",function(){t.bindToRedraw("_zoomOut")},"keydown"),this.keycharm.bind("]",function(){t.bindToRedraw("_zoomIn")},"keydown"),this.keycharm.bind("pageup",function(){t.bindToRedraw("_zoomIn")},"keydown"),this.keycharm.bind("pagedown",function(){t.bindToRedraw("_zoomOut")},"keydown"),this.keycharm.bind("up",function(){t.unbindFromRedraw("_moveUp")},"keyup"),this.keycharm.bind("down",function(){t.unbindFromRedraw("_moveDown")},"keyup"),this.keycharm.bind("left",function(){t.unbindFromRedraw("_moveLeft")},"keyup"),this.keycharm.bind("right",function(){t.unbindFromRedraw("_moveRight")},"keyup"),this.keycharm.bind("=",function(){t.unbindFromRedraw("_zoomIn")},"keyup"),this.keycharm.bind("num+",function(){t.unbindFromRedraw("_zoomIn")},"keyup"),this.keycharm.bind("num-",function(){t.unbindFromRedraw("_zoomOut")},"keyup"),this.keycharm.bind("-",function(){t.unbindFromRedraw("_zoomOut")},"keyup"),this.keycharm.bind("[",function(){t.unbindFromRedraw("_zoomOut")},"keyup"),this.keycharm.bind("]",function(){t.unbindFromRedraw("_zoomIn")},"keyup"),this.keycharm.bind("pageup",function(){t.unbindFromRedraw("_zoomIn")},"keyup"),this.keycharm.bind("pagedown",function(){t.unbindFromRedraw("_zoomOut")},"keyup")))}}]),t}();e["default"]=h},function(t,e){function i(t,e){if(!(t instanceof e))throw new TypeError("Cannot call a class as a function")}Object.defineProperty(e,"__esModule",{value:!0});var o=function(){function t(t,e){for(var i=0;io&&(s=o-e-this.padding),sn&&(r=n-i-this.padding),r0?e===!0?this.body.nodes[o[o.length-1]]:o[o.length-1]:void 0}},{key:"_getEdgesOverlappingWith",value:function(t,e){for(var 
i=this.body.edges,o=0;o0?e===!0?this.body.edges[o[o.length-1]]:o[o.length-1]:void 0}},{key:"_addToSelection",value:function(t){t instanceof a["default"]?this.selectionObj.nodes[t.id]=t:this.selectionObj.edges[t.id]=t}},{key:"_addToHover",value:function(t){t instanceof a["default"]?this.hoverObj.nodes[t.id]=t:this.hoverObj.edges[t.id]=t}},{key:"_removeFromSelection",value:function(t){t instanceof a["default"]?(delete this.selectionObj.nodes[t.id],this._unselectConnectedEdges(t)):delete this.selectionObj.edges[t.id]}},{key:"unselectAll",value:function(){for(var t in this.selectionObj.nodes)this.selectionObj.nodes.hasOwnProperty(t)&&this.selectionObj.nodes[t].unselect();for(var e in this.selectionObj.edges)this.selectionObj.edges.hasOwnProperty(e)&&this.selectionObj.edges[e].unselect();this.selectionObj={nodes:{},edges:{}}}},{key:"_getSelectedNodeCount",value:function(){var t=0;for(var e in this.selectionObj.nodes)this.selectionObj.nodes.hasOwnProperty(e)&&(t+=1);return t}},{key:"_getSelectedNode",value:function(){for(var t in this.selectionObj.nodes)if(this.selectionObj.nodes.hasOwnProperty(t))return this.selectionObj.nodes[t]}},{key:"_getSelectedEdge",value:function(){for(var t in this.selectionObj.edges)if(this.selectionObj.edges.hasOwnProperty(t))return this.selectionObj.edges[t]}},{key:"_getSelectedEdgeCount",value:function(){var t=0;for(var e in this.selectionObj.edges)this.selectionObj.edges.hasOwnProperty(e)&&(t+=1);return t}},{key:"_getSelectedObjectCount",value:function(){var t=0;for(var e in this.selectionObj.nodes)this.selectionObj.nodes.hasOwnProperty(e)&&(t+=1);for(var i in this.selectionObj.edges)this.selectionObj.edges.hasOwnProperty(i)&&(t+=1);return t}},{key:"_selectionIsEmpty",value:function(){for(var t in this.selectionObj.nodes)if(this.selectionObj.nodes.hasOwnProperty(t))return!1;for(var e in this.selectionObj.edges)if(this.selectionObj.edges.hasOwnProperty(e))return!1;return!0}},{key:"_clusterInSelection",value:function(){for(var t in this.selectionObj.nodes)if(this.selectionObj.nodes.hasOwnProperty(t)&&this.selectionObj.nodes[t].clusterSize>1)return!0;return!1}},{key:"_selectConnectedEdges",value:function(t){for(var e=0;e0&&(this.options.hierarchical.levelSeparation*=-1):this.options.hierarchical.levelSeparation<0&&(this.options.hierarchical.levelSeparation*=-1),this.body.emitter.emit("_resetHierarchicalLayout"),this.adaptAllOptionsForHierarchicalLayout(e);if(i===!0)return this.body.emitter.emit("refresh"),l.deepExtend(e,this.optionsBackup)}return e}},{key:"adaptAllOptionsForHierarchicalLayout",value:function(t){if(this.options.hierarchical.enabled===!0){void 0===t.physics||t.physics===!0?(t.physics={enabled:void 0===this.optionsBackup.physics.enabled?!0:this.optionsBackup.physics.enabled,solver:"hierarchicalRepulsion"},this.optionsBackup.physics.enabled=void 0===this.optionsBackup.physics.enabled?!0:this.optionsBackup.physics.enabled,this.optionsBackup.physics.solver=this.optionsBackup.physics.solver||"barnesHut"):"object"===r(t.physics)?(this.optionsBackup.physics.enabled=void 0===t.physics.enabled?!0:t.physics.enabled,this.optionsBackup.physics.solver=t.physics.solver||"barnesHut",t.physics.solver="hierarchicalRepulsion"):t.physics!==!1&&(this.optionsBackup.physics.solver="barnesHut",t.physics={solver:"hierarchicalRepulsion"});var e="horizontal";"RL"!==this.options.hierarchical.direction&&"LR"!==this.options.hierarchical.direction||(e="vertical"),void 0===t.edges?(this.optionsBackup.edges={smooth:{enabled:!0,type:"dynamic"}},t.edges={smooth:!1}):void 
0===t.edges.smooth?(this.optionsBackup.edges={smooth:{enabled:!0,type:"dynamic"}},t.edges.smooth=!1):"boolean"==typeof t.edges.smooth?(this.optionsBackup.edges={smooth:t.edges.smooth},t.edges.smooth={enabled:t.edges.smooth,type:e}):(void 0!==t.edges.smooth.type&&"dynamic"!==t.edges.smooth.type&&(e=t.edges.smooth.type),this.optionsBackup.edges={smooth:void 0===t.edges.smooth.enabled?!0:t.edges.smooth.enabled,type:void 0===t.edges.smooth.type?"dynamic":t.edges.smooth.type,roundness:void 0===t.edges.smooth.roundness?.5:t.edges.smooth.roundness,forceDirection:void 0===t.edges.smooth.forceDirection?!1:t.edges.smooth.forceDirection},t.edges.smooth={enabled:void 0===t.edges.smooth.enabled?!0:t.edges.smooth.enabled,type:e,roundness:void 0===t.edges.smooth.roundness?.5:t.edges.smooth.roundness,forceDirection:void 0===t.edges.smooth.forceDirection?!1:t.edges.smooth.forceDirection}),this.body.emitter.emit("_forceDisableDynamicCurves",e)}return t}},{key:"seededRandom",value:function(){var t=1e4*Math.sin(this.randomSeed++);return t-Math.floor(t)}},{key:"positionInitially",value:function(t){if(this.options.hierarchical.enabled!==!0){this.randomSeed=this.initialRandomSeed;for(var e=0;es){for(var r=this.body.nodeIndices.length;this.body.nodeIndices.length>s;){n+=1;var a=this.body.nodeIndices.length;n%3===0?this.body.modules.clustering.clusterBridges():this.body.modules.clustering.clusterOutliers();var h=this.body.nodeIndices.length;if(a==h&&n%3!==0||n>o)return this._declusterAll(),this.body.emitter.emit("_layoutFailed"),void console.info("This network could not be positioned by this version of the improved layout algorithm. Please disable improvedLayout for better performance.")}this.body.modules.kamadaKawai.setOptions({springLength:Math.max(150,2*r)})}this.body.modules.kamadaKawai.solve(this.body.nodeIndices,this.body.edgeIndices,!0),this._shiftToCenter();for(var d=70,l=0;l0){var t=void 0,e=void 0,i=!1,o=!0,n=!1;this.hierarchicalLevels={},this.lastNodeOnLevel={},this.hierarchicalChildrenReference={},this.hierarchicalParentReference={},this.hierarchicalTrees={},this.treeIndex=-1,this.distributionOrdering={},this.distributionIndex={},this.distributionOrderingPresence={};for(e in this.body.nodes)this.body.nodes.hasOwnProperty(e)&&(t=this.body.nodes[e],void 0===t.options.x&&void 0===t.options.y&&(o=!1),void 0!==t.options.level?(i=!0,this.hierarchicalLevels[e]=t.options.level):n=!0);if(n===!0&&i===!0)throw new Error("To use the hierarchical layout, nodes require either no predefined levels or levels have to be defined for all nodes.");n===!0&&("hubsize"===this.options.hierarchical.sortMethod?this._determineLevelsByHubsize():"directed"===this.options.hierarchical.sortMethod?this._determineLevelsDirected():"custom"===this.options.hierarchical.sortMethod&&this._determineLevelsCustomCallback());for(var s in this.body.nodes)this.body.nodes.hasOwnProperty(s)&&void 0===this.hierarchicalLevels[s]&&(this.hierarchicalLevels[s]=0);var r=this._getDistribution();this._generateMap(),this._placeNodesByHierarchy(r),this._condenseHierarchy(),this._shiftToCenter()}}},{key:"_condenseHierarchy",value:function(){var t=this,e=!1,i={},o=function(){for(var e=a(),i=0;i0)for(var n=0;n=l&&(r=Math.min(c,r),a=Math.max(c,a))}return[r,a,o,n]},l=function _(e){var i=t.hierarchicalLevels[e];if(t.hierarchicalChildrenReference[e]){var o=t.hierarchicalChildrenReference[e];if(o.length>0)for(var n=0;n1)for(var a=0;at.options.hierarchical.nodeSpacing){var u={};u[i.id]=!0;var p={};p[o.id]=!0,h(i,u),h(o,p);var 
f=c(i,o),m=d(u,f),v=s(m,4),g=(v[0],v[1]),y=(v[2],v[3],d(p,f)),b=s(y,4),w=b[0],_=(b[1],b[2]),x=(b[3],Math.abs(g-w));if(x>t.options.hierarchical.nodeSpacing){var k=g-w+t.options.hierarchical.nodeSpacing;k<-_+t.options.hierarchical.nodeSpacing&&(k=-_+t.options.hierarchical.nodeSpacing),0>k&&(t._shiftBlock(o.id,k),e=!0,n===!0&&t._centerParent(o))}}},m=function(o,n){for(var r=n.id,a=n.edges,l=t.hierarchicalLevels[n.id],c=t.options.hierarchical.levelSeparation*t.options.hierarchical.levelSeparation,u={},p=[],f=0;fr;r++){var a=g(o,i),h=y(o,i),d=40,l=Math.max(-d,Math.min(d,Math.round(a/h)));if(o-=l,void 0!==s[o])break;s[o]=r}return o},w=function(o){var r=t._getPositionForHierarchy(n);if(void 0===i[n.id]){var a={};a[n.id]=!0,h(n,a),i[n.id]=a}var l=d(i[n.id]),c=s(l,4),u=(c[0],c[1],c[2]),p=c[3],f=o-r,m=0;f>0?m=Math.min(f,p-t.options.hierarchical.nodeSpacing):0>f&&(m=-Math.min(-f,u-t.options.hierarchical.nodeSpacing)),0!=m&&(t._shiftBlock(n.id,m),e=!0)},_=function(i){var o=t._getPositionForHierarchy(n),r=t._getSpaceAroundNode(n),a=s(r,2),h=a[0],d=a[1],l=i-o,c=o;l>0?c=Math.min(o+(d-t.options.hierarchical.nodeSpacing),i):0>l&&(c=Math.max(o-(h-t.options.hierarchical.nodeSpacing),i)),c!==o&&(t._setPositionForHierarchy(n,c,void 0,!0),e=!0)},x=b(o,p);w(x),x=b(o,a),_(x)},v=function(i){var o=Object.keys(t.distributionOrdering);o=o.reverse();for(var n=0;i>n;n++){e=!1;for(var s=0;sn&&(e=!1,p(f,o,!0),e===!0);n++);},y=function(){for(var e in t.body.nodes)t.body.nodes.hasOwnProperty(e)&&t._centerParent(t.body.nodes[e])},b=function(){var e=Object.keys(t.distributionOrdering);e=e.reverse();for(var i=0;i0)for(var d=0;dg&&Math.abs(g)0&&Math.abs(g)0&&(r=this._getPositionForHierarchy(i[n-1])+this.options.hierarchical.nodeSpacing),this._setPositionForHierarchy(s,r,e),this._validataPositionAndContinue(s,e,r),o++}}}}},{key:"_placeBranchNodes",value:function(t,e){if(void 0!==this.hierarchicalChildrenReference[t]){for(var i=[],o=0;oe&&void 0===this.positionedNodes[s.id]))return;var a=void 0;a=0===n?this._getPositionForHierarchy(this.body.nodes[t]):this._getPositionForHierarchy(i[n-1])+this.options.hierarchical.nodeSpacing,this._setPositionForHierarchy(s,a,r),this._validataPositionAndContinue(s,r,a)}for(var h=1e9,d=-1e9,l=0;l0&&(e=this._getHubSize(),0!==e);)for(var o in this.body.nodes)if(this.body.nodes.hasOwnProperty(o)){var n=this.body.nodes[o];n.edges.length===e&&this._crawlNetwork(i,o)}}},{key:"_determineLevelsCustomCallback",value:function(){var t=this,e=1e5,i=function(t,e,i){},o=function(o,n,s){var r=t.hierarchicalLevels[o.id];void 0===r&&(t.hierarchicalLevels[o.id]=e);var a=i(d["default"].cloneOptions(o,"node"),d["default"].cloneOptions(n,"node"),d["default"].cloneOptions(s,"edge"));t.hierarchicalLevels[n.id]=t.hierarchicalLevels[o.id]+a};this._crawlNetwork(o),this._setMinLevelToZero()}},{key:"_determineLevelsDirected",value:function(){var t=this,e=1e4,i=function(i,o,n){var s=t.hierarchicalLevels[i.id];void 0===s&&(t.hierarchicalLevels[i.id]=e),n.toId==o.id?t.hierarchicalLevels[o.id]=t.hierarchicalLevels[i.id]+1:t.hierarchicalLevels[o.id]=t.hierarchicalLevels[i.id]-1};this._crawlNetwork(i),this._setMinLevelToZero()}},{key:"_setMinLevelToZero",value:function(){var t=1e9;for(var e in this.body.nodes)this.body.nodes.hasOwnProperty(e)&&void 0!==this.hierarchicalLevels[e]&&(t=Math.min(this.hierarchicalLevels[e],t));for(var i in this.body.nodes)this.body.nodes.hasOwnProperty(i)&&void 0!==this.hierarchicalLevels[i]&&(this.hierarchicalLevels[i]-=t)}},{key:"_generateMap",value:function(){var 
t=this,e=function(e,i){if(t.hierarchicalLevels[i.id]>t.hierarchicalLevels[e.id]){var o=e.id,n=i.id;void 0===t.hierarchicalChildrenReference[o]&&(t.hierarchicalChildrenReference[o]=[]),t.hierarchicalChildrenReference[o].push(n),void 0===t.hierarchicalParentReference[n]&&(t.hierarchicalParentReference[n]=[]),t.hierarchicalParentReference[n].push(o)}};this._crawlNetwork(e)}},{key:"_crawlNetwork",value:function(){var t=this,e=arguments.length<=0||void 0===arguments[0]?function(){}:arguments[0],i=arguments[1],o={},n=0,s=function d(i,n){if(void 0===o[i.id]){void 0===t.hierarchicalTrees[i.id]&&(t.hierarchicalTrees[i.id]=n,t.treeIndex=Math.max(n,t.treeIndex)),o[i.id]=!0;for(var s=void 0,r=0;r1&&("UD"===this.options.hierarchical.direction||"DU"===this.options.hierarchical.direction?t.sort(function(t,e){return t.x-e.x}):t.sort(function(t,e){return t.y-e.y}))}}]),t}();e["default"]=c},function(t,e,i){function o(t,e){if(!(t instanceof e))throw new TypeError("Cannot call a class as a function")}Object.defineProperty(e,"__esModule",{value:!0});var n=function(){function t(t,e){for(var i=0;i0&&this.options.deleteNode!==!1?(n===!0&&this._createSeperator(4),this._createDeleteButton(o)):0===t&&this.options.deleteEdge!==!1&&(n===!0&&this._createSeperator(4),this._createDeleteButton(o))),this._bindHammerToDiv(this.closeDiv,this.toggleEditMode.bind(this)),this._temporaryBindEvent("select",this.showManipulatorToolbar.bind(this))}this.body.emitter.emit("_redraw")}},{key:"addNodeMode",value:function(){if(this.editMode!==!0&&this.enableEditMode(),this._clean(),this.inMode="addNode",this.guiEnabled===!0){var t=this.options.locales[this.options.locale];this.manipulationDOM={},this._createBackButton(t),this._createSeperator(),this._createDescription(t.addDescription||this.options.locales.en.addDescription),this._bindHammerToDiv(this.closeDiv,this.toggleEditMode.bind(this))}this._temporaryBindEvent("click",this._performAddNode.bind(this))}},{key:"editNode",value:function(){var t=this;this.editMode!==!0&&this.enableEditMode(),this._clean();var e=this.selectionHandler._getSelectedNode();if(void 0!==e){if(this.inMode="editNode","function"!=typeof this.options.editNode)throw new Error("No function has been configured to handle the editing of nodes.");if(e.isCluster!==!0){var i=s.deepExtend({},e.options,!1); +if(i.x=e.x,i.y=e.y,2!==this.options.editNode.length)throw new Error("The function for edit does not support two arguments (data, callback)");this.options.editNode(i,function(e){null!==e&&void 0!==e&&"editNode"===t.inMode&&t.body.data.nodes.getDataSet().update(e),t.showManipulatorToolbar()})}else alert(this.options.locales[this.options.locale].editClusterError||this.options.locales.en.editClusterError)}else this.showManipulatorToolbar()}},{key:"addEdgeMode",value:function(){if(this.editMode!==!0&&this.enableEditMode(),this._clean(),this.inMode="addEdge",this.guiEnabled===!0){var 
t=this.options.locales[this.options.locale];this.manipulationDOM={},this._createBackButton(t),this._createSeperator(),this._createDescription(t.edgeDescription||this.options.locales.en.edgeDescription),this._bindHammerToDiv(this.closeDiv,this.toggleEditMode.bind(this))}this._temporaryBindUI("onTouch",this._handleConnect.bind(this)),this._temporaryBindUI("onDragEnd",this._finishConnect.bind(this)),this._temporaryBindUI("onDrag",this._dragControlNode.bind(this)),this._temporaryBindUI("onRelease",this._finishConnect.bind(this)),this._temporaryBindUI("onDragStart",function(){}),this._temporaryBindUI("onHold",function(){})}},{key:"editEdgeMode",value:function(){var t=this;if(this.editMode!==!0&&this.enableEditMode(),this._clean(),this.inMode="editEdge",this.guiEnabled===!0){var e=this.options.locales[this.options.locale];this.manipulationDOM={},this._createBackButton(e),this._createSeperator(),this._createDescription(e.editEdgeDescription||this.options.locales.en.editEdgeDescription),this._bindHammerToDiv(this.closeDiv,this.toggleEditMode.bind(this))}this.edgeBeingEditedId=this.selectionHandler.getSelectedEdges()[0],void 0!==this.edgeBeingEditedId?!function(){var e=t.body.edges[t.edgeBeingEditedId],i=t._getNewTargetNode(e.from.x,e.from.y),o=t._getNewTargetNode(e.to.x,e.to.y);t.temporaryIds.nodes.push(i.id),t.temporaryIds.nodes.push(o.id),t.body.nodes[i.id]=i,t.body.nodeIndices.push(i.id),t.body.nodes[o.id]=o,t.body.nodeIndices.push(o.id),t._temporaryBindUI("onTouch",t._controlNodeTouch.bind(t)),t._temporaryBindUI("onTap",function(){}),t._temporaryBindUI("onHold",function(){}),t._temporaryBindUI("onDragStart",t._controlNodeDragStart.bind(t)),t._temporaryBindUI("onDrag",t._controlNodeDrag.bind(t)),t._temporaryBindUI("onDragEnd",t._controlNodeDragEnd.bind(t)),t._temporaryBindUI("onMouseMove",function(){}),t._temporaryBindEvent("beforeDrawing",function(t){var n=e.edgeType.findBorderPositions(t);i.selected===!1&&(i.x=n.from.x,i.y=n.from.y),o.selected===!1&&(o.x=n.to.x,o.y=n.to.y)}),t.body.emitter.emit("_redraw")}():this.showManipulatorToolbar()}},{key:"deleteSelected",value:function(){var t=this;this.editMode!==!0&&this.enableEditMode(),this._clean(),this.inMode="delete";var e=this.selectionHandler.getSelectedNodes(),i=this.selectionHandler.getSelectedEdges(),o=void 0;if(e.length>0){for(var n=0;n0&&"function"==typeof this.options.deleteEdge&&(o=this.options.deleteEdge);if("function"==typeof o){var s={nodes:e,edges:i};if(2!==o.length)throw new Error("The function for delete does not support two arguments (data, callback)");o(s,function(e){null!==e&&void 0!==e&&"delete"===t.inMode?(t.body.data.edges.getDataSet().remove(e.edges),t.body.data.nodes.getDataSet().remove(e.nodes),t.body.emitter.emit("startSimulation"),t.showManipulatorToolbar()):(t.body.emitter.emit("startSimulation"),t.showManipulatorToolbar())})}else this.body.data.edges.getDataSet().remove(i),this.body.data.nodes.getDataSet().remove(e),this.body.emitter.emit("startSimulation"),this.showManipulatorToolbar()}},{key:"_setup",value:function(){this.options.enabled===!0?(this.guiEnabled=!0,this._createWrappers(),this.editMode===!1?this._createEditButton():this.showManipulatorToolbar()):(this._removeManipulationDOM(),this.guiEnabled=!1)}},{key:"_createWrappers",value:function(){void 
0===this.manipulationDiv&&(this.manipulationDiv=document.createElement("div"),this.manipulationDiv.className="vis-manipulation",this.editMode===!0?this.manipulationDiv.style.display="block":this.manipulationDiv.style.display="none",this.canvas.frame.appendChild(this.manipulationDiv)),void 0===this.editModeDiv&&(this.editModeDiv=document.createElement("div"),this.editModeDiv.className="vis-edit-mode",this.editMode===!0?this.editModeDiv.style.display="none":this.editModeDiv.style.display="block",this.canvas.frame.appendChild(this.editModeDiv)),void 0===this.closeDiv&&(this.closeDiv=document.createElement("div"),this.closeDiv.className="vis-close",this.closeDiv.style.display=this.manipulationDiv.style.display,this.canvas.frame.appendChild(this.closeDiv))}},{key:"_getNewTargetNode",value:function(t,e){var i=s.deepExtend({},this.options.controlNodeStyle);i.id="targetNode"+s.randomUUID(),i.hidden=!1,i.physics=!1,i.x=t,i.y=e;var o=this.body.functions.createNode(i);return o.shape.boundingBox={left:t,right:t,top:e,bottom:e},o}},{key:"_createEditButton",value:function(){this._clean(),this.manipulationDOM={},s.recursiveDOMDelete(this.editModeDiv);var t=this.options.locales[this.options.locale],e=this._createButton("editMode","vis-button vis-edit vis-edit-mode",t.edit||this.options.locales.en.edit);this.editModeDiv.appendChild(e),this._bindHammerToDiv(e,this.toggleEditMode.bind(this))}},{key:"_clean",value:function(){this.inMode=!1,this.guiEnabled===!0&&(s.recursiveDOMDelete(this.editModeDiv),s.recursiveDOMDelete(this.manipulationDiv),this._cleanManipulatorHammers()),this._cleanupTemporaryNodesAndEdges(),this._unbindTemporaryUIs(),this._unbindTemporaryEvents(),this.body.emitter.emit("restorePhysics")}},{key:"_cleanManipulatorHammers",value:function(){if(0!=this.manipulationHammers.length){for(var t=0;t=0;r--)if(n[r]!==this.selectedControlNode.id){s=this.body.nodes[n[r]];break}if(void 0!==s&&void 0!==this.selectedControlNode)if(s.isCluster===!0)alert(this.options.locales[this.options.locale].createEdgeError||this.options.locales.en.createEdgeError);else{var a=this.body.nodes[this.temporaryIds.nodes[0]];this.selectedControlNode.id===a.id?this._performEditEdge(s.id,o.to.id):this._performEditEdge(o.from.id,s.id)}else o.updateEdgeType(),this.body.emitter.emit("restorePhysics");this.body.emitter.emit("_redraw")}}},{key:"_handleConnect",value:function(t){if((new Date).valueOf()-this.touchTime>100){this.lastTouch=this.body.functions.getPointer(t.center),this.lastTouch.translation=s.extend({},this.body.view.translation);var e=this.lastTouch,i=this.selectionHandler.getNodeAt(e);if(void 0!==i)if(i.isCluster===!0)alert(this.options.locales[this.options.locale].createEdgeError||this.options.locales.en.createEdgeError);else{var o=this._getNewTargetNode(i.x,i.y);this.body.nodes[o.id]=o,this.body.nodeIndices.push(o.id);var n=this.body.functions.createEdge({id:"connectionEdge"+s.randomUUID(),from:i.id,to:o.id,physics:!1,smooth:{enabled:!0,type:"continuous",roundness:.5}});this.body.edges[n.id]=n,this.body.edgeIndices.push(n.id),this.temporaryIds.nodes.push(o.id),this.temporaryIds.edges.push(n.id)}this.touchTime=(new Date).valueOf()}}},{key:"_dragControlNode",value:function(t){var e=this.body.functions.getPointer(t.center);if(void 0!==this.temporaryIds.nodes[0]){var i=this.body.nodes[this.temporaryIds.nodes[0]];i.x=this.canvas._XconvertDOMtoCanvas(e.x),i.y=this.canvas._YconvertDOMtoCanvas(e.y),this.body.emitter.emit("_redraw")}else{var 
o=e.x-this.lastTouch.x,n=e.y-this.lastTouch.y;this.body.view.translation={x:this.lastTouch.translation.x+o,y:this.lastTouch.translation.y+n}}}},{key:"_finishConnect",value:function(t){var e=this.body.functions.getPointer(t.center),i=this.selectionHandler._pointerToPositionObject(e),o=void 0;void 0!==this.temporaryIds.edges[0]&&(o=this.body.edges[this.temporaryIds.edges[0]].fromId);for(var n=this.selectionHandler._getAllNodesOverlappingWith(i),s=void 0,r=n.length-1;r>=0;r--)if(-1===this.temporaryIds.nodes.indexOf(n[r])){s=this.body.nodes[n[r]];break}this._cleanupTemporaryNodesAndEdges(),void 0!==s&&(s.isCluster===!0?alert(this.options.locales[this.options.locale].createEdgeError||this.options.locales.en.createEdgeError):void 0!==this.body.nodes[o]&&void 0!==this.body.nodes[s.id]&&this._performAddEdge(o,s.id)),this.body.emitter.emit("_redraw")}},{key:"_performAddNode",value:function(t){var e=this,i={id:s.randomUUID(),x:t.pointer.canvas.x,y:t.pointer.canvas.y,label:"new"};if("function"==typeof this.options.addNode){if(2!==this.options.addNode.length)throw new Error("The function for add does not support two arguments (data,callback)");this.options.addNode(i,function(t){null!==t&&void 0!==t&&"addNode"===e.inMode&&(e.body.data.nodes.getDataSet().add(t),e.showManipulatorToolbar())})}else this.body.data.nodes.getDataSet().add(i),this.showManipulatorToolbar()}},{key:"_performAddEdge",value:function(t,e){var i=this,o={from:t,to:e};if("function"==typeof this.options.addEdge){if(2!==this.options.addEdge.length)throw new Error("The function for connect does not support two arguments (data,callback)");this.options.addEdge(o,function(t){null!==t&&void 0!==t&&"addEdge"===i.inMode&&(i.body.data.edges.getDataSet().add(t),i.selectionHandler.unselectAll(),i.showManipulatorToolbar())})}else this.body.data.edges.getDataSet().add(o),this.selectionHandler.unselectAll(),this.showManipulatorToolbar()}},{key:"_performEditEdge",value:function(t,e){var i=this,o={id:this.edgeBeingEditedId,from:t,to:e};if("function"==typeof this.options.editEdge){if(2!==this.options.editEdge.length)throw new Error("The function for edit does not support two arguments (data, callback)");this.options.editEdge(o,function(t){null===t||void 0===t||"editEdge"!==i.inMode?(i.body.edges[o.id].updateEdgeType(),i.body.emitter.emit("_redraw")):(i.body.data.edges.getDataSet().update(t),i.selectionHandler.unselectAll(),i.showManipulatorToolbar())})}else this.body.data.edges.getDataSet().update(o),this.selectionHandler.unselectAll(),this.showManipulatorToolbar()}}]),t}();e["default"]=h},function(t,e){Object.defineProperty(e,"__esModule",{value:!0});var 
i="string",o="boolean",n="number",s="array",r="object",a="dom",h="any",d={configure:{enabled:{"boolean":o},filter:{"boolean":o,string:i,array:s,"function":"function"},container:{dom:a},showButton:{"boolean":o},__type__:{object:r,"boolean":o,string:i,array:s,"function":"function"}},edges:{arrows:{to:{enabled:{"boolean":o},scaleFactor:{number:n},__type__:{object:r,"boolean":o}},middle:{enabled:{"boolean":o},scaleFactor:{number:n},__type__:{object:r,"boolean":o}},from:{enabled:{"boolean":o},scaleFactor:{number:n},__type__:{object:r,"boolean":o}},__type__:{string:["from","to","middle"],object:r}},arrowStrikethrough:{"boolean":o},color:{color:{string:i},highlight:{string:i},hover:{string:i},inherit:{string:["from","to","both"],"boolean":o},opacity:{number:n},__type__:{object:r,string:i}},dashes:{"boolean":o,array:s},font:{color:{string:i},size:{number:n},face:{string:i},background:{string:i},strokeWidth:{number:n},strokeColor:{string:i},align:{string:["horizontal","top","middle","bottom"]},__type__:{object:r,string:i}},hidden:{"boolean":o},hoverWidth:{"function":"function",number:n},label:{string:i,undefined:"undefined"},labelHighlightBold:{"boolean":o},length:{number:n,undefined:"undefined"},physics:{"boolean":o},scaling:{min:{number:n},max:{number:n},label:{enabled:{"boolean":o},min:{number:n},max:{number:n},maxVisible:{number:n},drawThreshold:{number:n},__type__:{object:r,"boolean":o}},customScalingFunction:{"function":"function"},__type__:{object:r}},selectionWidth:{"function":"function",number:n},selfReferenceSize:{number:n},shadow:{enabled:{"boolean":o},color:{string:i},size:{number:n},x:{number:n},y:{number:n},__type__:{object:r,"boolean":o}},smooth:{enabled:{"boolean":o},type:{string:["dynamic","continuous","discrete","diagonalCross","straightCross","horizontal","vertical","curvedCW","curvedCCW","cubicBezier"]},roundness:{number:n},forceDirection:{string:["horizontal","vertical","none"],"boolean":o},__type__:{object:r,"boolean":o}},title:{string:i,undefined:"undefined"},width:{number:n},value:{number:n,undefined:"undefined"},__type__:{object:r}},groups:{useDefaultGroups:{"boolean":o},__any__:"get from nodes, will be overwritten below",__type__:{object:r}},interaction:{dragNodes:{"boolean":o},dragView:{"boolean":o},hideEdgesOnDrag:{"boolean":o},hideNodesOnDrag:{"boolean":o},hover:{"boolean":o},keyboard:{enabled:{"boolean":o},speed:{x:{number:n},y:{number:n},zoom:{number:n},__type__:{object:r}},bindToWindow:{"boolean":o},__type__:{object:r,"boolean":o}},multiselect:{"boolean":o},navigationButtons:{"boolean":o},selectable:{"boolean":o},selectConnectedEdges:{"boolean":o},hoverConnectedEdges:{"boolean":o},tooltipDelay:{number:n},zoomView:{"boolean":o},__type__:{object:r}},layout:{randomSeed:{undefined:"undefined",number:n},improvedLayout:{"boolean":o},hierarchical:{enabled:{"boolean":o},levelSeparation:{number:n},nodeSpacing:{number:n},treeSpacing:{number:n},blockShifting:{"boolean":o},edgeMinimization:{"boolean":o},parentCentralization:{"boolean":o},direction:{string:["UD","DU","LR","RL"]},sortMethod:{string:["hubsize","directed"]},__type__:{object:r,"boolean":o}},__type__:{object:r}},manipulation:{enabled:{"boolean":o},initiallyActive:{"boolean":o},addNode:{"boolean":o,"function":"function"},addEdge:{"boolean":o,"function":"function"},editNode:{"function":"function"},editEdge:{"boolean":o,"function":"function"},deleteNode:{"boolean":o,"function":"function"},deleteEdge:{"boolean":o,"function":"function"},controlNodeStyle:"get from nodes, will be overwritten 
below",__type__:{object:r,"boolean":o}},nodes:{borderWidth:{number:n},borderWidthSelected:{number:n,undefined:"undefined"},brokenImage:{string:i,undefined:"undefined"},color:{border:{string:i},background:{string:i},highlight:{border:{string:i},background:{string:i},__type__:{object:r,string:i}},hover:{border:{string:i},background:{string:i},__type__:{object:r,string:i}},__type__:{object:r,string:i}},fixed:{x:{"boolean":o},y:{"boolean":o},__type__:{object:r,"boolean":o}},font:{align:{string:i},color:{string:i},size:{number:n},face:{string:i},background:{string:i},strokeWidth:{number:n},strokeColor:{string:i},__type__:{object:r,string:i}},group:{string:i,number:n,undefined:"undefined"},hidden:{"boolean":o},icon:{face:{string:i},code:{string:i},size:{number:n},color:{string:i},__type__:{object:r}},id:{string:i,number:n},image:{string:i,undefined:"undefined"},label:{string:i,undefined:"undefined"},labelHighlightBold:{"boolean":o},level:{number:n,undefined:"undefined"},mass:{number:n},physics:{"boolean":o},scaling:{min:{number:n},max:{number:n},label:{enabled:{"boolean":o},min:{number:n},max:{number:n},maxVisible:{number:n},drawThreshold:{number:n},__type__:{object:r,"boolean":o}},customScalingFunction:{"function":"function"},__type__:{object:r}},shadow:{enabled:{"boolean":o},color:{string:i},size:{number:n},x:{number:n},y:{number:n},__type__:{object:r,"boolean":o}},shape:{string:["ellipse","circle","database","box","text","image","circularImage","diamond","dot","star","triangle","triangleDown","square","icon"]},shapeProperties:{borderDashes:{"boolean":o,array:s},borderRadius:{number:n},interpolation:{"boolean":o},useImageSize:{"boolean":o},useBorderWithImage:{"boolean":o},__type__:{object:r}},size:{number:n},title:{string:i,undefined:"undefined"},value:{number:n,undefined:"undefined"},x:{number:n},y:{number:n},__type__:{object:r}},physics:{enabled:{"boolean":o},barnesHut:{gravitationalConstant:{number:n},centralGravity:{number:n},springLength:{number:n},springConstant:{number:n},damping:{number:n},avoidOverlap:{number:n},__type__:{object:r}},forceAtlas2Based:{gravitationalConstant:{number:n},centralGravity:{number:n},springLength:{number:n},springConstant:{number:n},damping:{number:n},avoidOverlap:{number:n},__type__:{object:r}},repulsion:{centralGravity:{number:n},springLength:{number:n},springConstant:{number:n},nodeDistance:{number:n},damping:{number:n},__type__:{object:r}},hierarchicalRepulsion:{centralGravity:{number:n},springLength:{number:n},springConstant:{number:n},nodeDistance:{number:n},damping:{number:n},__type__:{object:r}},maxVelocity:{number:n},minVelocity:{number:n},solver:{string:["barnesHut","repulsion","hierarchicalRepulsion","forceAtlas2Based"]},stabilization:{enabled:{"boolean":o},iterations:{number:n},updateInterval:{number:n},onlyDynamicEdges:{"boolean":o},fit:{"boolean":o},__type__:{object:r,"boolean":o}},timestep:{number:n},adaptiveTimestep:{"boolean":o},__type__:{object:r,"boolean":o}},autoResize:{"boolean":o},clickToUse:{"boolean":o},locale:{string:i},locales:{__any__:{any:h},__type__:{object:r}},height:{string:i},width:{string:i},__type__:{object:r}};d.groups.__any__=d.nodes,d.manipulation.controlNodeStyle=d.nodes;var 
l={nodes:{borderWidth:[1,0,10,1],borderWidthSelected:[2,0,10,1],color:{border:["color","#2B7CE9"],background:["color","#97C2FC"],highlight:{border:["color","#2B7CE9"],background:["color","#D2E5FF"]},hover:{border:["color","#2B7CE9"],background:["color","#D2E5FF"]}},fixed:{x:!1,y:!1},font:{color:["color","#343434"],size:[14,0,100,1],face:["arial","verdana","tahoma"],background:["color","none"],strokeWidth:[0,0,50,1],strokeColor:["color","#ffffff"]},hidden:!1,labelHighlightBold:!0,physics:!0,scaling:{min:[10,0,200,1],max:[30,0,200,1],label:{enabled:!1,min:[14,0,200,1],max:[30,0,200,1],maxVisible:[30,0,200,1],drawThreshold:[5,0,20,1]}},shadow:{enabled:!1,color:"rgba(0,0,0,0.5)",size:[10,0,20,1],x:[5,-30,30,1],y:[5,-30,30,1]},shape:["ellipse","box","circle","database","diamond","dot","square","star","text","triangle","triangleDown"],shapeProperties:{borderDashes:!1,borderRadius:[6,0,20,1],interpolation:!0,useImageSize:!1},size:[25,0,200,1]},edges:{arrows:{to:{enabled:!1,scaleFactor:[1,0,3,.05]},middle:{enabled:!1,scaleFactor:[1,0,3,.05]},from:{enabled:!1,scaleFactor:[1,0,3,.05]}},arrowStrikethrough:!0,color:{color:["color","#848484"],highlight:["color","#848484"],hover:["color","#848484"],inherit:["from","to","both",!0,!1],opacity:[1,0,1,.05]},dashes:!1,font:{color:["color","#343434"],size:[14,0,100,1],face:["arial","verdana","tahoma"],background:["color","none"],strokeWidth:[2,0,50,1],strokeColor:["color","#ffffff"],align:["horizontal","top","middle","bottom"]},hidden:!1,hoverWidth:[1.5,0,5,.1],labelHighlightBold:!0,physics:!0,scaling:{min:[1,0,100,1],max:[15,0,100,1],label:{enabled:!0,min:[14,0,200,1],max:[30,0,200,1],maxVisible:[30,0,200,1],drawThreshold:[5,0,20,1]}},selectionWidth:[1.5,0,5,.1],selfReferenceSize:[20,0,200,1],shadow:{enabled:!1,color:"rgba(0,0,0,0.5)",size:[10,0,20,1],x:[5,-30,30,1],y:[5,-30,30,1]},smooth:{enabled:!0,type:["dynamic","continuous","discrete","diagonalCross","straightCross","horizontal","vertical","curvedCW","curvedCCW","cubicBezier"],forceDirection:["horizontal","vertical","none"],roundness:[.5,0,1,.05]},width:[1,0,30,1]},layout:{hierarchical:{enabled:!1,levelSeparation:[150,20,500,5],nodeSpacing:[100,20,500,5],treeSpacing:[200,20,500,5],blockShifting:!0,edgeMinimization:!0,parentCentralization:!0,direction:["UD","DU","LR","RL"],sortMethod:["hubsize","directed"]}},interaction:{dragNodes:!0,dragView:!0,hideEdgesOnDrag:!1,hideNodesOnDrag:!1,hover:!1,keyboard:{enabled:!1,speed:{x:[10,0,40,1],y:[10,0,40,1],zoom:[.02,0,.1,.005]},bindToWindow:!0},multiselect:!1,navigationButtons:!1,selectable:!0,selectConnectedEdges:!0,hoverConnectedEdges:!0,tooltipDelay:[300,0,1e3,25],zoomView:!0},manipulation:{enabled:!1,initiallyActive:!1},physics:{enabled:!0,barnesHut:{gravitationalConstant:[-2e3,-3e4,0,50],centralGravity:[.3,0,10,.05],springLength:[95,0,500,5],springConstant:[.04,0,1.2,.005],damping:[.09,0,1,.01],avoidOverlap:[0,0,1,.01]},forceAtlas2Based:{gravitationalConstant:[-50,-500,0,1],centralGravity:[.01,0,1,.005],springLength:[95,0,500,5],springConstant:[.08,0,1.2,.005],damping:[.4,0,1,.01],avoidOverlap:[0,0,1,.01]},repulsion:{centralGravity:[.2,0,10,.05],springLength:[200,0,500,5],springConstant:[.05,0,1.2,.005],nodeDistance:[100,0,500,5],damping:[.09,0,1,.01]},hierarchicalRepulsion:{centralGravity:[.2,0,10,.05],springLength:[100,0,500,5],springConstant:[.01,0,1.2,.005],nodeDistance:[120,0,500,5],damping:[.09,0,1,.01]},maxVelocity:[50,0,150,1],minVelocity:[.1,.01,.5,.01],solver:["barnesHut","forceAtlas2Based","repulsion","hierarchicalRepulsion"],timestep:[.5,.01,1,.01]
},global:{locale:["en","nl"]}};e.allOptions=d,e.configureOptions=l},function(t,e,i){function o(t){return t&&t.__esModule?t:{"default":t}}function n(t,e){if(!(t instanceof e))throw new TypeError("Cannot call a class as a function")}Object.defineProperty(e,"__esModule",{value:!0});var s=function(){function t(t,e){var i=[],o=!0,n=!1,s=void 0;try{for(var r,a=t[Symbol.iterator]();!(o=(r=a.next()).done)&&(i.push(r.value),!e||i.length!==e);o=!0);}catch(h){n=!0,s=h}finally{try{!o&&a["return"]&&a["return"]()}finally{if(n)throw s}}return i}return function(e,i){if(Array.isArray(e))return e;if(Symbol.iterator in Object(e))return t(e,i);throw new TypeError("Invalid attempt to destructure non-iterable instance")}}(),r=function(){function t(t,e){for(var i=0;in&&h>a;){a+=1;var v=this._getHighestEnergyNode(i),g=s(v,4);for(c=g[0],l=g[1],u=g[2],p=g[3],f=l,m=0;f>r&&d>m;){m+=1,this._moveNode(c,u,p);var y=this._getEnergy(c),b=s(y,3);f=b[0],u=b[1],p=b[2]}}}},{key:"_getHighestEnergyNode",value:function(t){for(var e=this.body.nodeIndices,i=this.body.nodes,o=0,n=e[0],r=0,a=0,h=0;ho&&(o=u,n=d,r=p,a=f)}}return[n,o,r,a]}},{key:"_getEnergy",value:function(t){for(var e=this.body.nodeIndices,i=this.body.nodes,o=i[t].x,n=i[t].y,s=0,r=0,a=0;al;l++)for(var c=0;d-1>c;c++)for(var u=c+1;d>u;u++)o[e[c]][e[u]]=Math.min(o[e[c]][e[u]],o[e[c]][e[l]]+o[e[l]][e[u]]),o[e[u]][e[c]]=o[e[c]][e[u]];return o}}]),t}();e["default"]=n},function(t,e){"undefined"!=typeof CanvasRenderingContext2D&&(CanvasRenderingContext2D.prototype.circle=function(t,e,i){this.beginPath(),this.arc(t,e,i,0,2*Math.PI,!1),this.closePath()},CanvasRenderingContext2D.prototype.square=function(t,e,i){this.beginPath(),this.rect(t-i,e-i,2*i,2*i),this.closePath()},CanvasRenderingContext2D.prototype.triangle=function(t,e,i){this.beginPath(),i*=1.15,e+=.275*i;var o=2*i,n=o/2,s=Math.sqrt(3)/6*o,r=Math.sqrt(o*o-n*n);this.moveTo(t,e-(r-s)),this.lineTo(t+n,e+s),this.lineTo(t-n,e+s),this.lineTo(t,e-(r-s)),this.closePath()},CanvasRenderingContext2D.prototype.triangleDown=function(t,e,i){this.beginPath(),i*=1.15,e-=.275*i;var o=2*i,n=o/2,s=Math.sqrt(3)/6*o,r=Math.sqrt(o*o-n*n);this.moveTo(t,e+(r-s)), +this.lineTo(t+n,e-s),this.lineTo(t-n,e-s),this.lineTo(t,e+(r-s)),this.closePath()},CanvasRenderingContext2D.prototype.star=function(t,e,i){this.beginPath(),i*=.82,e+=.1*i;for(var o=0;10>o;o++){var n=o%2===0?1.3*i:.5*i;this.lineTo(t+n*Math.sin(2*o*Math.PI/10),e-n*Math.cos(2*o*Math.PI/10))}this.closePath()},CanvasRenderingContext2D.prototype.diamond=function(t,e,i){this.beginPath(),this.lineTo(t,e+i),this.lineTo(t+i,e),this.lineTo(t,e-i),this.lineTo(t-i,e),this.closePath()},CanvasRenderingContext2D.prototype.roundRect=function(t,e,i,o,n){var s=Math.PI/180;0>i-2*n&&(n=i/2),0>o-2*n&&(n=o/2),this.beginPath(),this.moveTo(t+n,e),this.lineTo(t+i-n,e),this.arc(t+i-n,e+n,n,270*s,360*s,!1),this.lineTo(t+i,e+o-n),this.arc(t+i-n,e+o-n,n,0,90*s,!1),this.lineTo(t+n,e+o),this.arc(t+n,e+o-n,n,90*s,180*s,!1),this.lineTo(t,e+n),this.arc(t+n,e+n,n,180*s,270*s,!1),this.closePath()},CanvasRenderingContext2D.prototype.ellipse=function(t,e,i,o){var n=.5522848,s=i/2*n,r=o/2*n,a=t+i,h=e+o,d=t+i/2,l=e+o/2;this.beginPath(),this.moveTo(t,l),this.bezierCurveTo(t,l-r,d-s,e,d,e),this.bezierCurveTo(d+s,e,a,l-r,a,l),this.bezierCurveTo(a,l+r,d+s,h,d,h),this.bezierCurveTo(d-s,h,t,l+r,t,l),this.closePath()},CanvasRenderingContext2D.prototype.database=function(t,e,i,o){var 
n=1/3,s=i,r=o*n,a=.5522848,h=s/2*a,d=r/2*a,l=t+s,c=e+r,u=t+s/2,p=e+r/2,f=e+(o-r/2),m=e+o;this.beginPath(),this.moveTo(l,p),this.bezierCurveTo(l,p+d,u+h,c,u,c),this.bezierCurveTo(u-h,c,t,p+d,t,p),this.bezierCurveTo(t,p-d,u-h,e,u,e),this.bezierCurveTo(u+h,e,l,p-d,l,p),this.lineTo(l,f),this.bezierCurveTo(l,f+d,u+h,m,u,m),this.bezierCurveTo(u-h,m,t,f+d,t,f),this.lineTo(t,p)},CanvasRenderingContext2D.prototype.arrow=function(t,e,i,o){var n=t-o*Math.cos(i),s=e-o*Math.sin(i),r=t-.9*o*Math.cos(i),a=e-.9*o*Math.sin(i),h=n+o/3*Math.cos(i+.5*Math.PI),d=s+o/3*Math.sin(i+.5*Math.PI),l=n+o/3*Math.cos(i-.5*Math.PI),c=s+o/3*Math.sin(i-.5*Math.PI);this.beginPath(),this.moveTo(t,e),this.lineTo(h,d),this.lineTo(r,a),this.lineTo(l,c),this.closePath()},CanvasRenderingContext2D.prototype.dashedLine=function(t,e,i,o,n){this.beginPath(),this.moveTo(t,e);for(var s=n.length,r=i-t,a=o-e,h=a/r,d=Math.sqrt(r*r+a*a),l=0,c=!0,u=0,p=n[0];d>=.1;)p=n[l++%s],p>d&&(p=d),u=Math.sqrt(p*p/(1+h*h)),u=0>r?-u:u,t+=u,e+=h*u,c===!0?this.lineTo(t,e):this.moveTo(t,e),d-=p,c=!c})},function(t,e){function i(t){return P=t,p()}function o(){I=0,N=P.charAt(0)}function n(){I++,N=P.charAt(I)}function s(){return P.charAt(I+1)}function r(t){return L.test(t)}function a(t,e){if(t||(t={}),e)for(var i in e)e.hasOwnProperty(i)&&(t[i]=e[i]);return t}function h(t,e,i){for(var o=e.split("."),n=t;o.length;){var s=o.shift();o.length?(n[s]||(n[s]={}),n=n[s]):n[s]=i}}function d(t,e){for(var i,o,n=null,s=[t],r=t;r.parent;)s.push(r.parent),r=r.parent;if(r.nodes)for(i=0,o=r.nodes.length;o>i;i++)if(e.id===r.nodes[i].id){n=r.nodes[i];break}for(n||(n={id:e.id},t.node&&(n.attr=a(n.attr,t.node))),i=s.length-1;i>=0;i--){var h=s[i];h.nodes||(h.nodes=[]),-1===h.nodes.indexOf(n)&&h.nodes.push(n)}e.attr&&(n.attr=a(n.attr,e.attr))}function l(t,e){if(t.edges||(t.edges=[]),t.edges.push(e),t.edge){var i=a({},t.edge);e.attr=a(i,e.attr)}}function c(t,e,i,o,n){var s={from:e,to:i,type:o};return t.edge&&(s.attr=a({},t.edge)),s.attr=a(s.attr||{},n),s}function u(){for(z=T.NULL,R="";" "===N||" "===N||"\n"===N||"\r"===N;)n();do{var t=!1;if("#"===N){for(var e=I-1;" "===P.charAt(e)||" "===P.charAt(e);)e--;if("\n"===P.charAt(e)||""===P.charAt(e)){for(;""!=N&&"\n"!=N;)n();t=!0}}if("/"===N&&"/"===s()){for(;""!=N&&"\n"!=N;)n();t=!0}if("/"===N&&"*"===s()){for(;""!=N;){if("*"===N&&"/"===s()){n(),n();break}n()}t=!0}for(;" "===N||" "===N||"\n"===N||"\r"===N;)n()}while(t);if(""===N)return void(z=T.DELIMITER);var i=N+s();if(E[i])return z=T.DELIMITER,R=i,n(),void n();if(E[N])return z=T.DELIMITER,R=N,void n();if(r(N)||"-"===N){for(R+=N,n();r(N);)R+=N,n();return"false"===R?R=!1:"true"===R?R=!0:isNaN(Number(R))||(R=Number(R)),void(z=T.IDENTIFIER)}if('"'===N){for(n();""!=N&&('"'!=N||'"'===N&&'"'===s());)R+=N,'"'===N&&n(),n();if('"'!=N)throw _('End of string " expected');return n(),void(z=T.IDENTIFIER)}for(z=T.UNKNOWN;""!=N;)R+=N,n();throw new SyntaxError('Syntax error in part "'+x(R,30)+'"')}function p(){var t={};if(o(),u(),"strict"===R&&(t.strict=!0,u()),"graph"!==R&&"digraph"!==R||(t.type=R,u()),z===T.IDENTIFIER&&(t.id=R,u()),"{"!=R)throw _("Angle bracket { expected");if(u(),f(t),"}"!=R)throw _("Angle bracket } expected");if(u(),""!==R)throw _("End of file expected");return u(),delete t.node,delete t.edge,delete t.graph,t}function f(t){for(;""!==R&&"}"!=R;)m(t),";"===R&&u()}function m(t){var e=v(t);if(e)return void b(t,e);var i=g(t);if(!i){if(z!=T.IDENTIFIER)throw _("Identifier expected");var o=R;if(u(),"="===R){if(u(),z!=T.IDENTIFIER)throw _("Identifier expected");t[o]=R,u()}else y(t,o)}}function 
v(t){var e=null;if("subgraph"===R&&(e={},e.type="subgraph",u(),z===T.IDENTIFIER&&(e.id=R,u())),"{"===R){if(u(),e||(e={}),e.parent=t,e.node=t.node,e.edge=t.edge,e.graph=t.graph,f(e),"}"!=R)throw _("Angle bracket } expected");u(),delete e.node,delete e.edge,delete e.graph,delete e.parent,t.subgraphs||(t.subgraphs=[]),t.subgraphs.push(e)}return e}function g(t){return"node"===R?(u(),t.node=w(),"node"):"edge"===R?(u(),t.edge=w(),"edge"):"graph"===R?(u(),t.graph=w(),"graph"):null}function y(t,e){var i={id:e},o=w();o&&(i.attr=o),d(t,i),b(t,e)}function b(t,e){for(;"->"===R||"--"===R;){var i,o=R;u();var n=v(t);if(n)i=n;else{if(z!=T.IDENTIFIER)throw _("Identifier or subgraph expected");i=R,d(t,{id:i}),u()}var s=w(),r=c(t,e,i,o,s);l(t,r),e=i}}function w(){for(var t=null;"["===R;){for(u(),t={};""!==R&&"]"!=R;){if(z!=T.IDENTIFIER)throw _("Attribute name expected");var e=R;if(u(),"="!=R)throw _("Equal sign = expected");if(u(),z!=T.IDENTIFIER)throw _("Attribute value expected");var i=R;h(t,e,i),u(),","==R&&u()}if("]"!=R)throw _("Bracket ] expected");u()}return t}function _(t){return new SyntaxError(t+', got "'+x(R,30)+'" (char '+I+")")}function x(t,e){return t.length<=e?t:t.substr(0,27)+"..."}function k(t,e,i){Array.isArray(t)?t.forEach(function(t){Array.isArray(e)?e.forEach(function(e){i(t,e)}):i(t,e)}):Array.isArray(e)?e.forEach(function(e){i(t,e)}):i(t,e)}function O(t,e,i){for(var o=e.split("."),n=o.pop(),s=t,r=0;r":!0,"--":!0},P="",I=0,N="",R="",z=T.NULL,L=/[a-zA-Z_0-9.:#]/;e.parseDOT=i,e.DOTToGraph=D},function(t,e){function i(t,e){var i=[],o=[],n={edges:{inheritColor:!1},nodes:{fixed:!1,parseColor:!1}};void 0!==e&&(void 0!==e.fixed&&(n.nodes.fixed=e.fixed),void 0!==e.parseColor&&(n.nodes.parseColor=e.parseColor),void 0!==e.inheritColor&&(n.edges.inheritColor=e.inheritColor));for(var s=t.edges,r=t.nodes,a=0;a val accum = sc.accumulator(0) - * accum: spark.Accumulator[Int] = 0 + * accum: org.apache.spark.Accumulator[Int] = 0 * * scala> sc.parallelize(Array(1, 2, 3, 4)).foreach(x => accum += x) * ... @@ -58,7 +49,8 @@ import org.apache.spark.storage.{BlockId, BlockStatus} * @param name human-readable name associated with this accumulator * @param countFailedValues whether to accumulate values from failed tasks * @tparam T result type - */ +*/ +@deprecated("use AccumulatorV2", "2.0.0") class Accumulator[T] private[spark] ( // SI-8813: This must explicitly be a private val, or else scala 2.11 doesn't compile @transient private val initialValue: T, @@ -75,6 +67,7 @@ class Accumulator[T] private[spark] ( * * @tparam T type of value to accumulate */ +@deprecated("use AccumulatorV2", "2.0.0") trait AccumulatorParam[T] extends AccumulableParam[T, T] { def addAccumulator(t1: T, t2: T): T = { addInPlace(t1, t2) @@ -82,6 +75,7 @@ trait AccumulatorParam[T] extends AccumulableParam[T, T] { } +@deprecated("use AccumulatorV2", "2.0.0") object AccumulatorParam { // The following implicit objects were in SparkContext before 1.2 and users had to @@ -89,21 +83,25 @@ object AccumulatorParam { // them automatically. However, as there are duplicate codes in SparkContext for backward // compatibility, please update them accordingly if you modify the following implicit objects. 
+ @deprecated("use AccumulatorV2", "2.0.0") implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] { def addInPlace(t1: Double, t2: Double): Double = t1 + t2 def zero(initialValue: Double): Double = 0.0 } + @deprecated("use AccumulatorV2", "2.0.0") implicit object IntAccumulatorParam extends AccumulatorParam[Int] { def addInPlace(t1: Int, t2: Int): Int = t1 + t2 def zero(initialValue: Int): Int = 0 } + @deprecated("use AccumulatorV2", "2.0.0") implicit object LongAccumulatorParam extends AccumulatorParam[Long] { def addInPlace(t1: Long, t2: Long): Long = t1 + t2 def zero(initialValue: Long): Long = 0L } + @deprecated("use AccumulatorV2", "2.0.0") implicit object FloatAccumulatorParam extends AccumulatorParam[Float] { def addInPlace(t1: Float, t2: Float): Float = t1 + t2 def zero(initialValue: Float): Float = 0f @@ -112,20 +110,9 @@ object AccumulatorParam { // Note: when merging values, this param just adopts the newer value. This is used only // internally for things that shouldn't really be accumulated across tasks, like input // read method, which should be the same across all tasks in the same stage. + @deprecated("use AccumulatorV2", "2.0.0") private[spark] object StringAccumulatorParam extends AccumulatorParam[String] { def addInPlace(t1: String, t2: String): String = t2 def zero(initialValue: String): String = "" } - - // Note: this is expensive as it makes a copy of the list every time the caller adds an item. - // A better way to use this is to first accumulate the values yourself then them all at once. - private[spark] class ListAccumulatorParam[T] extends AccumulatorParam[Seq[T]] { - def addInPlace(t1: Seq[T], t2: Seq[T]): Seq[T] = t1 ++ t2 - def zero(initialValue: Seq[T]): Seq[T] = Seq.empty[T] - } - - // For the internal metric that records what blocks are updated in a particular task - private[spark] object UpdatedBlockStatusesAccumulatorParam - extends ListAccumulatorParam[(BlockId, BlockStatus)] - } diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index 63a00a84af3cd..5678d790e9e76 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -25,7 +25,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging import org.apache.spark.rdd.{RDD, ReliableRDDCheckpointData} -import org.apache.spark.util.{ThreadUtils, Utils} +import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, ThreadUtils, Utils} /** * Classes that represent cleaning tasks. @@ -144,7 +144,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { registerForCleanup(rdd, CleanRDD(rdd.id)) } - def registerAccumulatorForCleanup(a: NewAccumulator[_, _]): Unit = { + def registerAccumulatorForCleanup(a: AccumulatorV2[_, _]): Unit = { registerForCleanup(a, CleanAccum(a.id)) } @@ -212,7 +212,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { } } - /** Perform shuffle cleanup, asynchronously. */ + /** Perform shuffle cleanup. 
*/ def doCleanupShuffle(shuffleId: Int, blocking: Boolean): Unit = { try { logDebug("Cleaning shuffle " + shuffleId) diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 0926d05414ba7..6f320c524201c 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -25,9 +25,10 @@ import scala.util.control.ControlThrowable import com.codahale.metrics.{Gauge, MetricRegistry} import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.{DYN_ALLOCATION_MAX_EXECUTORS, DYN_ALLOCATION_MIN_EXECUTORS} import org.apache.spark.metrics.source.Source import org.apache.spark.scheduler._ -import org.apache.spark.util.{Clock, SystemClock, ThreadUtils} +import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils} /** * An agent that dynamically allocates and removes executors based on the workload. @@ -87,11 +88,9 @@ private[spark] class ExecutorAllocationManager( import ExecutorAllocationManager._ // Lower and upper bounds on the number of executors. - private val minNumExecutors = conf.getInt("spark.dynamicAllocation.minExecutors", 0) - private val maxNumExecutors = conf.getInt("spark.dynamicAllocation.maxExecutors", - Integer.MAX_VALUE) - private val initialNumExecutors = conf.getInt("spark.dynamicAllocation.initialExecutors", - minNumExecutors) + private val minNumExecutors = conf.get(DYN_ALLOCATION_MIN_EXECUTORS) + private val maxNumExecutors = conf.get(DYN_ALLOCATION_MAX_EXECUTORS) + private val initialNumExecutors = Utils.getDynamicAllocationInitialExecutors(conf) // How long there must be backlogged tasks for before an addition is triggered (seconds) private val schedulerBacklogTimeoutS = conf.getTimeAsSeconds( @@ -231,7 +230,7 @@ private[spark] class ExecutorAllocationManager( } } } - executor.scheduleAtFixedRate(scheduleTask, 0, intervalMillis, TimeUnit.MILLISECONDS) + executor.scheduleWithFixedDelay(scheduleTask, 0, intervalMillis, TimeUnit.MILLISECONDS) client.requestTotalExecutors(numExecutorsTarget, localityAwareTasks, hostToLocalTaskCount) } diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index 9eac05fdf9f3d..c3764ac671afb 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -26,7 +26,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.scheduler._ import org.apache.spark.storage.BlockManagerId -import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils} +import org.apache.spark.util._ /** * A heartbeat from executors to the driver. This is a shared message used by several internal @@ -35,7 +35,7 @@ import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils} */ private[spark] case class Heartbeat( executorId: String, - accumUpdates: Array[(Long, Seq[NewAccumulator[_, _]])], // taskId -> accumulator updates + accumUpdates: Array[(Long, Seq[AccumulatorV2[_, _]])], // taskId -> accumulator updates blockManagerId: BlockManagerId) /** @@ -166,7 +166,7 @@ private[spark] class HeartbeatReceiver(sc: SparkContext, clock: Clock) } /** - * Send ExecutorRemoved to the event loop to remove a executor. Only for test. + * Send ExecutorRemoved to the event loop to remove an executor. 
Only for test. * * @return if HeartbeatReceiver is stopped, return None. Otherwise, return a Some(Future) that * indicate if this operation is successful. diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 3a5caa3510eb0..6bd950205fadb 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -18,13 +18,15 @@ package org.apache.spark import java.io._ -import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, ThreadPoolExecutor} import java.util.zip.{GZIPInputStream, GZIPOutputStream} import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} import scala.reflect.ClassTag +import scala.util.control.NonFatal +import org.apache.spark.broadcast.{Broadcast, BroadcastManager} import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv} import org.apache.spark.scheduler.MapStatus @@ -37,31 +39,18 @@ private[spark] case class GetMapOutputStatuses(shuffleId: Int) extends MapOutputTrackerMessage private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage +private[spark] case class GetMapOutputMessage(shuffleId: Int, context: RpcCallContext) + /** RpcEndpoint class for MapOutputTrackerMaster */ private[spark] class MapOutputTrackerMasterEndpoint( override val rpcEnv: RpcEnv, tracker: MapOutputTrackerMaster, conf: SparkConf) extends RpcEndpoint with Logging { - val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf) override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case GetMapOutputStatuses(shuffleId: Int) => val hostPort = context.senderAddress.hostPort logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort) - val mapOutputStatuses = tracker.getSerializedMapOutputStatuses(shuffleId) - val serializedSize = mapOutputStatuses.length - if (serializedSize > maxRpcMessageSize) { - - val msg = s"Map output statuses were $serializedSize bytes which " + - s"exceeds spark.rpc.message.maxSize ($maxRpcMessageSize bytes)." - - /* For SPARK-1244 we'll opt for just logging an error and then sending it to the sender. - * A bigger refactoring (SPARK-1239) will ultimately remove this entire code path. */ - val exception = new SparkException(msg) - logError(msg, exception) - context.sendFailure(exception) - } else { - context.reply(mapOutputStatuses) - } + val mapOutputStatuses = tracker.post(new GetMapOutputMessage(shuffleId, context)) case StopMapOutputTracker => logInfo("MapOutputTrackerMasterEndpoint stopped!") @@ -270,12 +259,17 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging /** * MapOutputTracker for the driver. 
*/ -private[spark] class MapOutputTrackerMaster(conf: SparkConf) +private[spark] class MapOutputTrackerMaster(conf: SparkConf, + broadcastManager: BroadcastManager, isLocal: Boolean) extends MapOutputTracker(conf) { /** Cache a serialized version of the output statuses for each shuffle to send them out faster */ private var cacheEpoch = epoch + // The size at which we use Broadcast to send the map output statuses to the executors + private val minSizeForBroadcast = + conf.getSizeAsBytes("spark.shuffle.mapOutput.minSizeForBroadcast", "512k").toInt + /** Whether to compute locality preferences for reduce tasks */ private val shuffleLocalityEnabled = conf.getBoolean("spark.shuffle.reduceLocality.enabled", true) @@ -296,10 +290,86 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) protected val mapStatuses = new ConcurrentHashMap[Int, Array[MapStatus]]().asScala private val cachedSerializedStatuses = new ConcurrentHashMap[Int, Array[Byte]]().asScala + private val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf) + + // Kept in sync with cachedSerializedStatuses explicitly + // This is required so that the Broadcast variable remains in scope until we remove + // the shuffleId explicitly or implicitly. + private val cachedSerializedBroadcast = new HashMap[Int, Broadcast[Array[Byte]]]() + + // This is to prevent multiple serializations of the same shuffle - which happens when + // there is a request storm when shuffle start. + private val shuffleIdLocks = new ConcurrentHashMap[Int, AnyRef]() + + // requests for map output statuses + private val mapOutputRequests = new LinkedBlockingQueue[GetMapOutputMessage] + + // Thread pool used for handling map output status requests. This is a separate thread pool + // to ensure we don't block the normal dispatcher threads. + private val threadpool: ThreadPoolExecutor = { + val numThreads = conf.getInt("spark.shuffle.mapOutput.dispatcher.numThreads", 8) + val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "map-output-dispatcher") + for (i <- 0 until numThreads) { + pool.execute(new MessageLoop) + } + pool + } + + // Make sure that that we aren't going to exceed the max RPC message size by making sure + // we use broadcast to send large map output statuses. + if (minSizeForBroadcast > maxRpcMessageSize) { + val msg = s"spark.shuffle.mapOutput.minSizeForBroadcast ($minSizeForBroadcast bytes) must " + + s"be <= spark.rpc.message.maxSize ($maxRpcMessageSize bytes) to prevent sending an rpc " + + "message that is to large." + logError(msg) + throw new IllegalArgumentException(msg) + } + + def post(message: GetMapOutputMessage): Unit = { + mapOutputRequests.offer(message) + } + + /** Message loop used for dispatching messages. */ + private class MessageLoop extends Runnable { + override def run(): Unit = { + try { + while (true) { + try { + val data = mapOutputRequests.take() + if (data == PoisonPill) { + // Put PoisonPill back so that other MessageLoops can see it. 
+ mapOutputRequests.offer(PoisonPill) + return + } + val context = data.context + val shuffleId = data.shuffleId + val hostPort = context.senderAddress.hostPort + logDebug("Handling request to send map output locations for shuffle " + shuffleId + + " to " + hostPort) + val mapOutputStatuses = getSerializedMapOutputStatuses(shuffleId) + context.reply(mapOutputStatuses) + } catch { + case NonFatal(e) => logError(e.getMessage, e) + } + } + } catch { + case ie: InterruptedException => // exit + } + } + } + + /** A poison endpoint that indicates MessageLoop should exit its message loop. */ + private val PoisonPill = new GetMapOutputMessage(-99, null) + + // Exposed for testing + private[spark] def getNumCachedSerializedBroadcast = cachedSerializedBroadcast.size + def registerShuffle(shuffleId: Int, numMaps: Int) { if (mapStatuses.put(shuffleId, new Array[MapStatus](numMaps)).isDefined) { throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice") } + // add in advance + shuffleIdLocks.putIfAbsent(shuffleId, new Object()) } def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) { @@ -337,6 +407,8 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) override def unregisterShuffle(shuffleId: Int) { mapStatuses.remove(shuffleId) cachedSerializedStatuses.remove(shuffleId) + cachedSerializedBroadcast.remove(shuffleId).foreach(v => removeBroadcast(v)) + shuffleIdLocks.remove(shuffleId) } /** Check if the given shuffle is being tracked */ @@ -428,40 +500,89 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) } } + private def removeBroadcast(bcast: Broadcast[_]): Unit = { + if (null != bcast) { + broadcastManager.unbroadcast(bcast.id, + removeFromDriver = true, blocking = false) + } + } + + private def clearCachedBroadcast(): Unit = { + for (cached <- cachedSerializedBroadcast) removeBroadcast(cached._2) + cachedSerializedBroadcast.clear() + } + def getSerializedMapOutputStatuses(shuffleId: Int): Array[Byte] = { var statuses: Array[MapStatus] = null + var retBytes: Array[Byte] = null var epochGotten: Long = -1 - epochLock.synchronized { - if (epoch > cacheEpoch) { - cachedSerializedStatuses.clear() - cacheEpoch = epoch - } - cachedSerializedStatuses.get(shuffleId) match { - case Some(bytes) => - return bytes - case None => - statuses = mapStatuses.getOrElse(shuffleId, Array[MapStatus]()) - epochGotten = epoch + + // Check to see if we have a cached version, returns true if it does + // and has side effect of setting retBytes. 
If not returns false + // with side effect of setting statuses + def checkCachedStatuses(): Boolean = { + epochLock.synchronized { + if (epoch > cacheEpoch) { + cachedSerializedStatuses.clear() + clearCachedBroadcast() + cacheEpoch = epoch + } + cachedSerializedStatuses.get(shuffleId) match { + case Some(bytes) => + retBytes = bytes + true + case None => + logDebug("cached status not found for : " + shuffleId) + statuses = mapStatuses.getOrElse(shuffleId, Array[MapStatus]()) + epochGotten = epoch + false + } } } - // If we got here, we failed to find the serialized locations in the cache, so we pulled - // out a snapshot of the locations as "statuses"; let's serialize and return that - val bytes = MapOutputTracker.serializeMapStatuses(statuses) - logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length)) - // Add them into the table only if the epoch hasn't changed while we were working - epochLock.synchronized { - if (epoch == epochGotten) { - cachedSerializedStatuses(shuffleId) = bytes + + if (checkCachedStatuses()) return retBytes + var shuffleIdLock = shuffleIdLocks.get(shuffleId) + if (null == shuffleIdLock) { + val newLock = new Object() + // in general, this condition should be false - but good to be paranoid + val prevLock = shuffleIdLocks.putIfAbsent(shuffleId, newLock) + shuffleIdLock = if (null != prevLock) prevLock else newLock + } + // synchronize so we only serialize/broadcast it once since multiple threads call + // in parallel + shuffleIdLock.synchronized { + // double check to make sure someone else didn't serialize and cache the same + // mapstatus while we were waiting on the synchronize + if (checkCachedStatuses()) return retBytes + + // If we got here, we failed to find the serialized locations in the cache, so we pulled + // out a snapshot of the locations as "statuses"; let's serialize and return that + val (bytes, bcast) = MapOutputTracker.serializeMapStatuses(statuses, broadcastManager, + isLocal, minSizeForBroadcast) + logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length)) + // Add them into the table only if the epoch hasn't changed while we were working + epochLock.synchronized { + if (epoch == epochGotten) { + cachedSerializedStatuses(shuffleId) = bytes + if (null != bcast) cachedSerializedBroadcast(shuffleId) = bcast + } else { + logInfo("Epoch changed, not caching!") + removeBroadcast(bcast) + } } + bytes } - bytes } override def stop() { + mapOutputRequests.offer(PoisonPill) + threadpool.shutdown() sendTracker(StopMapOutputTracker) mapStatuses.clear() trackerEndpoint = null cachedSerializedStatuses.clear() + clearCachedBroadcast() + shuffleIdLocks.clear() } } @@ -477,12 +598,16 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr private[spark] object MapOutputTracker extends Logging { val ENDPOINT_NAME = "MapOutputTracker" + private val DIRECT = 0 + private val BROADCAST = 1 // Serialize an array of map output locations into an efficient byte format so that we can send // it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will // generally be pretty compressible because many map outputs will be on the same hostname. 
- def serializeMapStatuses(statuses: Array[MapStatus]): Array[Byte] = { + def serializeMapStatuses(statuses: Array[MapStatus], broadcastManager: BroadcastManager, + isLocal: Boolean, minBroadcastSize: Int): (Array[Byte], Broadcast[Array[Byte]]) = { val out = new ByteArrayOutputStream + out.write(DIRECT) val objOut = new ObjectOutputStream(new GZIPOutputStream(out)) Utils.tryWithSafeFinally { // Since statuses can be modified in parallel, sync on it @@ -492,16 +617,51 @@ private[spark] object MapOutputTracker extends Logging { } { objOut.close() } - out.toByteArray + val arr = out.toByteArray + if (arr.length >= minBroadcastSize) { + // Use broadcast instead. + // Important arr(0) is the tag == DIRECT, ignore that while deserializing ! + val bcast = broadcastManager.newBroadcast(arr, isLocal) + // toByteArray creates copy, so we can reuse out + out.reset() + out.write(BROADCAST) + val oos = new ObjectOutputStream(new GZIPOutputStream(out)) + oos.writeObject(bcast) + oos.close() + val outArr = out.toByteArray + logInfo("Broadcast mapstatuses size = " + outArr.length + ", actual size = " + arr.length) + (outArr, bcast) + } else { + (arr, null) + } } // Opposite of serializeMapStatuses. def deserializeMapStatuses(bytes: Array[Byte]): Array[MapStatus] = { - val objIn = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes))) - Utils.tryWithSafeFinally { - objIn.readObject().asInstanceOf[Array[MapStatus]] - } { - objIn.close() + assert (bytes.length > 0) + + def deserializeObject(arr: Array[Byte], off: Int, len: Int): AnyRef = { + val objIn = new ObjectInputStream(new GZIPInputStream( + new ByteArrayInputStream(arr, off, len))) + Utils.tryWithSafeFinally { + objIn.readObject() + } { + objIn.close() + } + } + + bytes(0) match { + case DIRECT => + deserializeObject(bytes, 1, bytes.length - 1).asInstanceOf[Array[MapStatus]] + case BROADCAST => + // deserialize the Broadcast, pull .value array out of it, and then deserialize that + val bcast = deserializeObject(bytes, 1, bytes.length - 1). + asInstanceOf[Broadcast[Array[Byte]]] + logInfo("Broadcast mapstatuses size = " + bytes.length + + ", actual size = " + bcast.value.length) + // Important - ignore the DIRECT tag ! 
Start from offset 1 + deserializeObject(bcast.value, 1, bcast.value.length - 1).asInstanceOf[Array[MapStatus]] + case _ => throw new IllegalArgumentException("Unexpected byte tag = " + bytes(0)) } } diff --git a/core/src/main/scala/org/apache/spark/SSLOptions.scala b/core/src/main/scala/org/apache/spark/SSLOptions.scala index 719905a2c901c..be19179b00a49 100644 --- a/core/src/main/scala/org/apache/spark/SSLOptions.scala +++ b/core/src/main/scala/org/apache/spark/SSLOptions.scala @@ -71,7 +71,7 @@ private[spark] case class SSLOptions( keyPassword.foreach(sslContextFactory.setKeyManagerPassword) keyStoreType.foreach(sslContextFactory.setKeyStoreType) if (needClientAuth) { - trustStore.foreach(file => sslContextFactory.setTrustStore(file.getAbsolutePath)) + trustStore.foreach(file => sslContextFactory.setTrustStorePath(file.getAbsolutePath)) trustStorePassword.foreach(sslContextFactory.setTrustStorePassword) trustStoreType.foreach(sslContextFactory.setTrustStoreType) } diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index e8f68224d5976..f72c7ded5ea52 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -50,17 +50,19 @@ import org.apache.spark.util.Utils * secure the UI if it has data that other users should not be allowed to see. The javax * servlet filter specified by the user can authenticate the user and then once the user * is logged in, Spark can compare that user versus the view acls to make sure they are - * authorized to view the UI. The configs 'spark.acls.enable' and 'spark.ui.view.acls' - * control the behavior of the acls. Note that the person who started the application - * always has view access to the UI. + * authorized to view the UI. The configs 'spark.acls.enable', 'spark.ui.view.acls' and + * 'spark.ui.view.acls.groups' control the behavior of the acls. Note that the person who + * started the application always has view access to the UI. * - * Spark has a set of modify acls (`spark.modify.acls`) that controls which users have permission - * to modify a single application. This would include things like killing the application. By - * default the person who started the application has modify access. For modify access through - * the UI, you must have a filter that does authentication in place for the modify acls to work - * properly. + * Spark has a set of individual and group modify acls (`spark.modify.acls`) and + * (`spark.modify.acls.groups`) that controls which users and groups have permission to + * modify a single application. This would include things like killing the application. + * By default the person who started the application has modify access. For modify access + * through the UI, you must have a filter that does authentication in place for the modify + * acls to work properly. * - * Spark also has a set of admin acls (`spark.admin.acls`) which is a set of users/administrators + * Spark also has a set of individual and group admin acls (`spark.admin.acls`) and + * (`spark.admin.acls.groups`) which is a set of users/administrators and admin groups * who always have permission to view or modify the Spark application. * * Starting from version 1.3, Spark has partial support for encrypted connections with SSL. 
@@ -184,6 +186,9 @@ private[spark] class SecurityManager(sparkConf: SparkConf) import SecurityManager._ + // allow all users/groups to have view/modify permissions + private val WILDCARD_ACL = "*" + private val authOn = sparkConf.getBoolean(SecurityManager.SPARK_AUTH_CONF, false) // keep spark.ui.acls.enable for backwards compatibility with 1.0 private var aclsOn = @@ -193,12 +198,20 @@ private[spark] class SecurityManager(sparkConf: SparkConf) private var adminAcls: Set[String] = stringToSet(sparkConf.get("spark.admin.acls", "")) + // admin group acls should be set before view or modify group acls + private var adminAclsGroups : Set[String] = + stringToSet(sparkConf.get("spark.admin.acls.groups", "")) + private var viewAcls: Set[String] = _ + private var viewAclsGroups: Set[String] = _ + // list of users who have permission to modify the application. This should // apply to both UI and CLI for things like killing the application. private var modifyAcls: Set[String] = _ + private var modifyAclsGroups: Set[String] = _ + // always add the current user and SPARK_USER to the viewAcls private val defaultAclUsers = Set[String](System.getProperty("user.name", ""), Utils.getCurrentUserName()) @@ -206,11 +219,16 @@ private[spark] class SecurityManager(sparkConf: SparkConf) setViewAcls(defaultAclUsers, sparkConf.get("spark.ui.view.acls", "")) setModifyAcls(defaultAclUsers, sparkConf.get("spark.modify.acls", "")) + setViewAclsGroups(sparkConf.get("spark.ui.view.acls.groups", "")); + setModifyAclsGroups(sparkConf.get("spark.modify.acls.groups", "")); + private val secretKey = generateSecretKey() logInfo("SecurityManager: authentication " + (if (authOn) "enabled" else "disabled") + "; ui acls " + (if (aclsOn) "enabled" else "disabled") + - "; users with view permissions: " + viewAcls.toString() + - "; users with modify permissions: " + modifyAcls.toString()) + "; users with view permissions: " + viewAcls.toString() + + "; groups with view permissions: " + viewAclsGroups.toString() + + "; users with modify permissions: " + modifyAcls.toString() + + "; groups with modify permissions: " + modifyAclsGroups.toString()) // Set our own authenticator to properly negotiate user/password for HTTP connections. // This is needed by the HTTP client fetching from the HttpServer. Put here so its @@ -302,17 +320,34 @@ private[spark] class SecurityManager(sparkConf: SparkConf) setViewAcls(Set[String](defaultUser), allowedUsers) } + /** + * Admin acls groups should be set before the view or modify acls groups. If you modify the admin + * acls groups you should also set the view and modify acls groups again to pick up the changes. + */ + def setViewAclsGroups(allowedUserGroups: String) { + viewAclsGroups = (adminAclsGroups ++ stringToSet(allowedUserGroups)); + logInfo("Changing view acls groups to: " + viewAclsGroups.mkString(",")) + } + /** * Checking the existence of "*" is necessary as YARN can't recognize the "*" in "defaultuser,*" */ def getViewAcls: String = { - if (viewAcls.contains("*")) { - "*" + if (viewAcls.contains(WILDCARD_ACL)) { + WILDCARD_ACL } else { viewAcls.mkString(",") } } + def getViewAclsGroups: String = { + if (viewAclsGroups.contains(WILDCARD_ACL)) { + WILDCARD_ACL + } else { + viewAclsGroups.mkString(",") + } + } + /** * Admin acls should be set before the view or modify acls. If you modify the admin * acls you should also set the view and modify acls again to pick up the changes. 
@@ -322,17 +357,34 @@ private[spark] class SecurityManager(sparkConf: SparkConf) logInfo("Changing modify acls to: " + modifyAcls.mkString(",")) } + /** + * Admin acls groups should be set before the view or modify acls groups. If you modify the admin + * acls groups you should also set the view and modify acls groups again to pick up the changes. + */ + def setModifyAclsGroups(allowedUserGroups: String) { + modifyAclsGroups = (adminAclsGroups ++ stringToSet(allowedUserGroups)); + logInfo("Changing modify acls groups to: " + modifyAclsGroups.mkString(",")) + } + /** * Checking the existence of "*" is necessary as YARN can't recognize the "*" in "defaultuser,*" */ def getModifyAcls: String = { - if (modifyAcls.contains("*")) { - "*" + if (modifyAcls.contains(WILDCARD_ACL)) { + WILDCARD_ACL } else { modifyAcls.mkString(",") } } + def getModifyAclsGroups: String = { + if (modifyAclsGroups.contains(WILDCARD_ACL)) { + WILDCARD_ACL + } else { + modifyAclsGroups.mkString(",") + } + } + /** * Admin acls should be set before the view or modify acls. If you modify the admin * acls you should also set the view and modify acls again to pick up the changes. @@ -342,6 +394,15 @@ private[spark] class SecurityManager(sparkConf: SparkConf) logInfo("Changing admin acls to: " + adminAcls.mkString(",")) } + /** + * Admin acls groups should be set before the view or modify acls groups. If you modify the admin + * acls groups you should also set the view and modify acls groups again to pick up the changes. + */ + def setAdminAclsGroups(adminUserGroups: String) { + adminAclsGroups = stringToSet(adminUserGroups) + logInfo("Changing admin acls groups to: " + adminAclsGroups.mkString(",")) + } + def setAcls(aclSetting: Boolean) { aclsOn = aclSetting logInfo("Changing acls enabled to: " + aclsOn) @@ -398,36 +459,49 @@ private[spark] class SecurityManager(sparkConf: SparkConf) def aclsEnabled(): Boolean = aclsOn /** - * Checks the given user against the view acl list to see if they have + * Checks the given user against the view acl and groups list to see if they have * authorization to view the UI. If the UI acls are disabled * via spark.acls.enable, all users have view access. If the user is null - * it is assumed authentication is off and all users have access. + * it is assumed authentication is off and all users have access. Also if any one of the + * UI acls or groups specify the WILDCARD(*) then all users have view access. * * @param user to see if is authorized * @return true is the user has permission, otherwise false */ def checkUIViewPermissions(user: String): Boolean = { logDebug("user=" + user + " aclsEnabled=" + aclsEnabled() + " viewAcls=" + - viewAcls.mkString(",")) - !aclsEnabled || user == null || viewAcls.contains(user) || viewAcls.contains("*") + viewAcls.mkString(",") + " viewAclsGroups=" + viewAclsGroups.mkString(",")) + if (!aclsEnabled || user == null || viewAcls.contains(user) || + viewAcls.contains(WILDCARD_ACL) || viewAclsGroups.contains(WILDCARD_ACL)) { + return true + } + val currentUserGroups = Utils.getCurrentUserGroups(sparkConf, user) + logDebug("userGroups=" + currentUserGroups.mkString(",")) + viewAclsGroups.exists(currentUserGroups.contains(_)) } /** - * Checks the given user against the modify acl list to see if they have - * authorization to modify the application. If the UI acls are disabled + * Checks the given user against the modify acl and groups list to see if they have + * authorization to modify the application. 
If the modify acls are disabled * via spark.acls.enable, all users have modify access. If the user is null - * it is assumed authentication isn't turned on and all users have access. + * it is assumed authentication isn't turned on and all users have access. Also if any one + * of the modify acls or groups specify the WILDCARD(*) then all users have modify access. * * @param user to see if is authorized * @return true is the user has permission, otherwise false */ def checkModifyPermissions(user: String): Boolean = { logDebug("user=" + user + " aclsEnabled=" + aclsEnabled() + " modifyAcls=" + - modifyAcls.mkString(",")) - !aclsEnabled || user == null || modifyAcls.contains(user) || modifyAcls.contains("*") + modifyAcls.mkString(",") + " modifyAclsGroups=" + modifyAclsGroups.mkString(",")) + if (!aclsEnabled || user == null || modifyAcls.contains(user) || + modifyAcls.contains(WILDCARD_ACL) || modifyAclsGroups.contains(WILDCARD_ACL)) { + return true + } + val currentUserGroups = Utils.getCurrentUserGroups(sparkConf, user) + logDebug("userGroups=" + currentUserGroups) + modifyAclsGroups.exists(currentUserGroups.contains(_)) } - /** * Check to see if authentication for the Spark communication protocols is enabled * @return true if authentication is enabled, otherwise false diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 2cb3ed0296a47..e9f9d7242618a 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -19,12 +19,11 @@ package org.apache.spark import java.io._ import java.lang.reflect.Constructor -import java.net.URI -import java.util.{Arrays, Properties, ServiceLoader, UUID} -import java.util.concurrent.ConcurrentMap +import java.net.{MalformedURLException, URI} +import java.util.{Arrays, Locale, Properties, ServiceLoader, UUID} +import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap} import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger, AtomicReference} -import scala.annotation.tailrec import scala.collection.JavaConverters._ import scala.collection.Map import scala.collection.generic.Growable @@ -34,13 +33,11 @@ import scala.reflect.{classTag, ClassTag} import scala.util.control.NonFatal import com.google.common.collect.MapMaker -import org.apache.commons.lang.SerializationUtils +import org.apache.commons.lang3.SerializationUtils import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.Path -import org.apache.hadoop.io.{ArrayWritable, BooleanWritable, BytesWritable, DoubleWritable, - FloatWritable, IntWritable, LongWritable, NullWritable, Text, Writable} -import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf, SequenceFileInputFormat, - TextInputFormat} +import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.io.{ArrayWritable, BooleanWritable, BytesWritable, DoubleWritable, FloatWritable, IntWritable, LongWritable, NullWritable, Text, Writable} +import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf, SequenceFileInputFormat, TextInputFormat} import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, Job => NewHadoopJob} import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat} import org.apache.mesos.MesosNativeLibrary @@ -48,18 +45,16 @@ import org.apache.mesos.MesosNativeLibrary import org.apache.spark.annotation.DeveloperApi import org.apache.spark.broadcast.Broadcast import 
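Both permission checks above share the same shape: the user ACLs and the `*` wildcard short-circuit the decision, and the comparatively expensive group lookup runs only afterwards. A standalone sketch of that ordering (illustrative only, not Spark's private API):

```scala
// Mirrors the order used in checkUIViewPermissions / checkModifyPermissions:
// cheap user/wildcard checks first, group membership lookup last.
def hasAccess(
    aclsEnabled: Boolean,
    user: String,
    userAcls: Set[String],
    groupAcls: Set[String],
    groupsOf: String => Set[String]): Boolean = {
  val WILDCARD_ACL = "*"
  if (!aclsEnabled || user == null || userAcls.contains(user) ||
      userAcls.contains(WILDCARD_ACL) || groupAcls.contains(WILDCARD_ACL)) {
    true
  } else {
    groupsOf(user).exists(groupAcls.contains)
  }
}

// "bob" is not listed directly but belongs to an allowed group.
hasAccess(aclsEnabled = true, user = "bob",
  userAcls = Set("alice"), groupAcls = Set("analysts"),
  groupsOf = _ => Set("analysts", "staff"))   // true
```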
org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil} -import org.apache.spark.input.{FixedLengthBinaryInputFormat, PortableDataStream, StreamInputFormat, - WholeTextFileInputFormat} +import org.apache.spark.input.{FixedLengthBinaryInputFormat, PortableDataStream, StreamInputFormat, WholeTextFileInputFormat} import org.apache.spark.internal.Logging import org.apache.spark.io.CompressionCodec import org.apache.spark.partial.{ApproximateEvaluator, PartialResult} import org.apache.spark.rdd._ import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.scheduler._ -import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, - SparkDeploySchedulerBackend} -import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} -import org.apache.spark.scheduler.local.LocalBackend +import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, StandaloneSchedulerBackend} +import org.apache.spark.scheduler.cluster.mesos.{MesosCoarseGrainedSchedulerBackend, MesosFineGrainedSchedulerBackend} +import org.apache.spark.scheduler.local.LocalSchedulerBackend import org.apache.spark.storage._ import org.apache.spark.storage.BlockManagerMessages.TriggerThreadDump import org.apache.spark.ui.{ConsoleProgressBar, SparkUI} @@ -94,7 +89,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli private[spark] val stopped: AtomicBoolean = new AtomicBoolean(false) - private def assertNotStopped(): Unit = { + private[spark] def assertNotStopped(): Unit = { if (stopped.get()) { val activeContext = SparkContext.activeContext.get() val activeCreationSite = @@ -251,7 +246,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli def isStopped: Boolean = stopped.get() // An asynchronous listener bus for Spark events - private[spark] val listenerBus = new LiveListenerBus + private[spark] val listenerBus = new LiveListenerBus(this) // This function allows components created by SparkEnv to be mocked in unit tests: private[spark] def createSparkEnv( @@ -264,8 +259,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli private[spark] def env: SparkEnv = _env // Used to store a URL for each static file/jar together with the file's local timestamp - private[spark] val addedFiles = HashMap[String, Long]() - private[spark] val addedJars = HashMap[String, Long]() + private[spark] val addedFiles = new ConcurrentHashMap[String, Long]().asScala + private[spark] val addedJars = new ConcurrentHashMap[String, Long]().asScala // Keeps track of all persisted RDDs private[spark] val persistentRdds = { @@ -335,7 +330,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli override protected def childValue(parent: Properties): Properties = { // Note: make a clone such that changes in the parent properties aren't reflected in // the those of the children threads, which has confusing semantics (SPARK-10563). 
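Among the SparkContext changes above, the mutable `HashMap`s backing `addedFiles`/`addedJars` become Java `ConcurrentHashMap`s viewed through `asScala`, so concurrent `addFile`/`addJar` calls register a key exactly once. A small sketch of that pattern in isolation (the path below is a placeholder):

```scala
import java.util.concurrent.ConcurrentHashMap
import scala.collection.JavaConverters._

// A Java ConcurrentHashMap seen as a scala.collection.concurrent.Map.
val added = new ConcurrentHashMap[String, Long]().asScala

val timestamp = System.currentTimeMillis
// putIfAbsent returns None only for the first caller, so the log line (and any
// follow-up work) runs once per key even under concurrent calls.
if (added.putIfAbsent("hdfs:///tmp/dep.jar", timestamp).isEmpty) {
  println(s"Added hdfs:///tmp/dep.jar with timestamp $timestamp")
}
```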
- SerializationUtils.clone(parent).asInstanceOf[Properties] + SerializationUtils.clone(parent) } override protected def initialValue(): Properties = new Properties() } @@ -357,12 +352,12 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * Valid log levels include: ALL, DEBUG, ERROR, FATAL, INFO, OFF, TRACE, WARN */ def setLogLevel(logLevel: String) { - val validLevels = Seq("ALL", "DEBUG", "ERROR", "FATAL", "INFO", "OFF", "TRACE", "WARN") - if (!validLevels.contains(logLevel)) { - throw new IllegalArgumentException( - s"Supplied level $logLevel did not match one of: ${validLevels.mkString(",")}") - } - Utils.setLogLevel(org.apache.log4j.Level.toLevel(logLevel)) + // let's allow lowcase or mixed case too + val upperCased = logLevel.toUpperCase(Locale.ENGLISH) + require(SparkContext.VALID_LOG_LEVELS.contains(upperCased), + s"Supplied level $logLevel did not match one of:" + + s" ${SparkContext.VALID_LOG_LEVELS.mkString(",")}") + Utils.setLogLevel(org.apache.log4j.Level.toLevel(upperCased)) } try { @@ -392,7 +387,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli _conf.set("spark.executor.id", SparkContext.DRIVER_IDENTIFIER) - _jars = _conf.getOption("spark.jars").map(_.split(",")).map(_.filter(_.nonEmpty)).toSeq.flatten + _jars = Utils.getUserJars(_conf) _files = _conf.getOption("spark.files").map(_.split(",")).map(_.filter(_.nonEmpty)) .toSeq.flatten @@ -608,6 +603,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * scheduler pool. User-defined properties may also be set here. These properties are propagated * through to worker tasks and can be accessed there via * [[org.apache.spark.TaskContext#getLocalProperty]]. + * + * These properties are inherited by child threads spawned from this thread. This + * may have unexpected consequences when working with thread pools. The standard java + * implementation of thread pools have worker threads spawn other worker threads. + * As a result, local properties may propagate unpredictably. */ def setLocalProperty(key: String, value: String) { if (value == null) { @@ -785,7 +785,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli def makeRDD[T: ClassTag](seq: Seq[(T, Seq[String])]): RDD[T] = withScope { assertNotStopped() val indexToPrefs = seq.zipWithIndex.map(t => (t._2, t._1._2)).toMap - new ParallelCollectionRDD[T](this, seq.map(_._1), seq.size, indexToPrefs) + new ParallelCollectionRDD[T](this, seq.map(_._1), math.max(seq.size, 1), indexToPrefs) } /** @@ -957,6 +957,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli valueClass: Class[V], minPartitions: Int = defaultMinPartitions): RDD[(K, V)] = withScope { assertNotStopped() + + // This is a hack to enforce loading hdfs-site.xml. + // See SPARK-11227 for details. + FileSystem.getLocal(conf) + // Add necessary security credentials to the JobConf before broadcasting it. SparkHadoopUtil.get.addCredentials(conf) new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minPartitions) @@ -977,6 +982,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli valueClass: Class[V], minPartitions: Int = defaultMinPartitions): RDD[(K, V)] = withScope { assertNotStopped() + + // This is a hack to enforce loading hdfs-site.xml. + // See SPARK-11227 for details. + FileSystem.getLocal(hadoopConfiguration) + // A Hadoop configuration can be about 10 KB, which is pretty big, so broadcast it. 
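With the `setLogLevel` rewrite above, the supplied level is upper-cased before validation, so mixed-case input is accepted. A quick sketch, assuming an already-constructed `SparkContext` named `sc`:

```scala
sc.setLogLevel("warn")   // accepted, normalized to WARN
sc.setLogLevel("InFo")   // accepted, normalized to INFO
// sc.setLogLevel("verbose") // would fail the require() and list the valid levels
```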
val confBroadcast = broadcast(new SerializableConfiguration(hadoopConfiguration)) val setInputPathsFunc = (jobConf: JobConf) => FileInputFormat.setInputPaths(jobConf, path) @@ -1061,6 +1071,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli vClass: Class[V], conf: Configuration = hadoopConfiguration): RDD[(K, V)] = withScope { assertNotStopped() + + // This is a hack to enforce loading hdfs-site.xml. + // See SPARK-11227 for details. + FileSystem.getLocal(hadoopConfiguration) + // The call to NewHadoopJob automatically adds security credentials to conf, // so we don't need to explicitly add them ourselves val job = NewHadoopJob.getInstance(conf) @@ -1095,6 +1110,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli kClass: Class[K], vClass: Class[V]): RDD[(K, V)] = withScope { assertNotStopped() + + // This is a hack to enforce loading hdfs-site.xml. + // See SPARK-11227 for details. + FileSystem.getLocal(conf) + // Add necessary security credentials to the JobConf. Required to access secure HDFS. val jconf = new JobConf(conf) SparkHadoopUtil.get.addCredentials(jconf) @@ -1219,6 +1239,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * Create an [[org.apache.spark.Accumulator]] variable of a given type, which tasks can "add" * values to using the `+=` method. Only the driver can access the accumulator's `value`. */ + @deprecated("use AccumulatorV2", "2.0.0") def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]): Accumulator[T] = { val acc = new Accumulator(initialValue, param) cleaner.foreach(_.registerAccumulatorForCleanup(acc.newAcc)) @@ -1230,6 +1251,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * in the Spark UI. Tasks can "add" values to the accumulator using the `+=` method. Only the * driver can access the accumulator's `value`. */ + @deprecated("use AccumulatorV2", "2.0.0") def accumulator[T](initialValue: T, name: String)(implicit param: AccumulatorParam[T]) : Accumulator[T] = { val acc = new Accumulator(initialValue, param, Some(name)) @@ -1239,10 +1261,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli /** * Create an [[org.apache.spark.Accumulable]] shared variable, to which tasks can add values - * with `+=`. Only the driver can access the accumuable's `value`. + * with `+=`. Only the driver can access the accumulable's `value`. * @tparam R accumulator result type * @tparam T type that can be added to the accumulator */ + @deprecated("use AccumulatorV2", "2.0.0") def accumulable[R, T](initialValue: R)(implicit param: AccumulableParam[R, T]) : Accumulable[R, T] = { val acc = new Accumulable(initialValue, param) @@ -1252,11 +1275,12 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli /** * Create an [[org.apache.spark.Accumulable]] shared variable, with a name for display in the - * Spark UI. Tasks can add values to the accumuable using the `+=` operator. Only the driver can - * access the accumuable's `value`. + * Spark UI. Tasks can add values to the accumulable using the `+=` operator. Only the driver can + * access the accumulable's `value`. 
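The doc updates above reflect that the `AccumulatorV2` family accumulates via `add` rather than `+=`. A minimal sketch of the long accumulator, assuming an existing `SparkContext` `sc`:

```scala
import scala.util.Try

// Count records that fail to parse; only the driver may read `value`.
val parseErrors = sc.longAccumulator("parse errors")
sc.parallelize(Seq("1", "2", "x", "4")).foreach { s =>
  if (Try(s.toInt).isFailure) parseErrors.add(1L)
}
println(parseErrors.value)   // 1
```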
* @tparam R accumulator result type * @tparam T type that can be added to the accumulator */ + @deprecated("use AccumulatorV2", "2.0.0") def accumulable[R, T](initialValue: R, name: String)(implicit param: AccumulableParam[R, T]) : Accumulable[R, T] = { val acc = new Accumulable(initialValue, param, Some(name)) @@ -1270,6 +1294,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * Growable and TraversableOnce are the standard APIs that guarantee += and ++=, implemented by * standard mutable collections. So you can use this with mutable Map, Set, etc. */ + @deprecated("use AccumulatorV2", "2.0.0") def accumulableCollection[R <% Growable[T] with TraversableOnce[T] with Serializable: ClassTag, T] (initialValue: R): Accumulable[R, T] = { val param = new GrowableAccumulableParam[R, T] @@ -1282,7 +1307,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * Register the given accumulator. Note that accumulators must be registered before use, or it * will throw exception. */ - def register(acc: NewAccumulator[_, _]): Unit = { + def register(acc: AccumulatorV2[_, _]): Unit = { acc.register(this) } @@ -1290,12 +1315,12 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * Register the given accumulator with given name. Note that accumulators must be registered * before use, or it will throw exception. */ - def register(acc: NewAccumulator[_, _], name: String): Unit = { + def register(acc: AccumulatorV2[_, _], name: String): Unit = { acc.register(this, name = Some(name)) } /** - * Create and register a long accumulator, which starts with 0 and accumulates inputs by `+=`. + * Create and register a long accumulator, which starts with 0 and accumulates inputs by `add`. */ def longAccumulator: LongAccumulator = { val acc = new LongAccumulator @@ -1304,7 +1329,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } /** - * Create and register a long accumulator, which starts with 0 and accumulates inputs by `+=`. + * Create and register a long accumulator, which starts with 0 and accumulates inputs by `add`. */ def longAccumulator(name: String): LongAccumulator = { val acc = new LongAccumulator @@ -1313,7 +1338,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } /** - * Create and register a double accumulator, which starts with 0 and accumulates inputs by `+=`. + * Create and register a double accumulator, which starts with 0 and accumulates inputs by `add`. */ def doubleAccumulator: DoubleAccumulator = { val acc = new DoubleAccumulator @@ -1322,7 +1347,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } /** - * Create and register a double accumulator, which starts with 0 and accumulates inputs by `+=`. + * Create and register a double accumulator, which starts with 0 and accumulates inputs by `add`. */ def doubleAccumulator(name: String): DoubleAccumulator = { val acc = new DoubleAccumulator @@ -1331,43 +1356,21 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } /** - * Create and register an average accumulator, which accumulates double inputs by recording the - * total sum and total count, and produce the output by sum / total. Note that Double.NaN will be - * returned if no input is added. 
- */ - def averageAccumulator: AverageAccumulator = { - val acc = new AverageAccumulator - register(acc) - acc - } - - /** - * Create and register an average accumulator, which accumulates double inputs by recording the - * total sum and total count, and produce the output by sum / total. Note that Double.NaN will be - * returned if no input is added. - */ - def averageAccumulator(name: String): AverageAccumulator = { - val acc = new AverageAccumulator - register(acc, name) - acc - } - - /** - * Create and register a list accumulator, which starts with empty list and accumulates inputs - * by adding them into the inner list. + * Create and register a [[CollectionAccumulator]], which starts with empty list and accumulates + * inputs by adding them into the list. */ - def listAccumulator[T]: ListAccumulator[T] = { - val acc = new ListAccumulator[T] + def collectionAccumulator[T]: CollectionAccumulator[T] = { + val acc = new CollectionAccumulator[T] register(acc) acc } /** - * Create and register a list accumulator, which starts with empty list and accumulates inputs - * by adding them into the inner list. + * Create and register a [[CollectionAccumulator]], which starts with empty list and accumulates + * inputs by adding them into the list. */ - def listAccumulator[T](name: String): ListAccumulator[T] = { - val acc = new ListAccumulator[T] + def collectionAccumulator[T](name: String): CollectionAccumulator[T] = { + val acc = new CollectionAccumulator[T] register(acc, name) acc } @@ -1398,6 +1401,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli addFile(path, false) } + /** + * Returns a list of file paths that are added to resources. + */ + def listFiles(): Seq[String] = addedFiles.keySet.toSeq + /** * Add a file to be downloaded with this Spark job on every node. * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported @@ -1408,7 +1416,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * supported for Hadoop-supported filesystems. */ def addFile(path: String, recursive: Boolean): Unit = { - val uri = new URI(path) + val uri = new Path(path).toUri val schemeCorrectedPath = uri.getScheme match { case null | "local" => new File(path).getCanonicalFile.toURI.toString case _ => path @@ -1430,6 +1438,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli throw new SparkException(s"Added file $hadoopPath is a directory and recursive is not " + "turned on.") } + } else { + // SPARK-17650: Make sure this is a valid URL before adding it to the list of dependencies + Utils.validateURL(uri) } val key = if (!isLocal && scheme == "file") { @@ -1438,14 +1449,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli schemeCorrectedPath } val timestamp = System.currentTimeMillis - addedFiles(key) = timestamp - - // Fetch the file locally in case a job is executed using DAGScheduler.runLocally(). - Utils.fetchFile(path, new File(SparkFiles.getRootDirectory()), conf, env.securityManager, - hadoopConfiguration, timestamp, useCache = false) - - logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key)) - postEnvironmentUpdate() + if (addedFiles.putIfAbsent(key, timestamp).isEmpty) { + logInfo(s"Added file $path at $key with timestamp $timestamp") + // Fetch the file locally so that closures which are run on the driver can still use the + // SparkFiles API to access files. 
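`listAccumulator` becomes `collectionAccumulator` above; apart from the name, usage is unchanged. A short sketch, again assuming an existing `sc`:

```scala
// Collect offending records on the driver; `value` is a java.util.List.
val badRecords = sc.collectionAccumulator[String]("bad records")
sc.parallelize(Seq("ok", "", "also ok")).foreach { rec =>
  if (rec.isEmpty) badRecords.add(rec)
}
println(badRecords.value)
```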
+ Utils.fetchFile(uri.toString, new File(SparkFiles.getRootDirectory()), conf, + env.securityManager, hadoopConfiguration, timestamp, useCache = false) + postEnvironmentUpdate() + } } /** @@ -1481,7 +1492,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * This includes running, pending, and completed tasks. * @return whether the request is acknowledged by the cluster manager. */ - private[spark] override def requestTotalExecutors( + @DeveloperApi + override def requestTotalExecutors( numExecutors: Int, localityAwareTasks: Int, hostToLocalTaskCount: scala.collection.immutable.Map[String, Int] @@ -1688,6 +1700,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli key = env.rpcEnv.fileServer.addJar(new File(path)) } else { val uri = new URI(path) + // SPARK-17650: Make sure this is a valid URL before adding it to the list of dependencies + Utils.validateURL(uri) key = uri.getScheme match { // A JAR file which exists only on the driver node case null | "file" => @@ -1713,12 +1727,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli case exc: FileNotFoundException => logError(s"Jar not found at $path") null - case e: Exception => - // For now just log an error but allow to go through so spark examples work. - // The spark examples don't really need the jar distributed since its also - // the app jar. - logError("Error adding jar (" + e + "), was the --addJars option used?") - null } } // A JAR file which exists locally on every worker node @@ -1729,13 +1737,20 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } } if (key != null) { - addedJars(key) = System.currentTimeMillis - logInfo("Added JAR " + path + " at " + key + " with timestamp " + addedJars(key)) + val timestamp = System.currentTimeMillis + if (addedJars.putIfAbsent(key, timestamp).isEmpty) { + logInfo(s"Added JAR $path at $key with timestamp $timestamp") + postEnvironmentUpdate() + } } } - postEnvironmentUpdate() } + /** + * Returns a list of jar files that are added to resources. + */ + def listJars(): Seq[String] = addedJars.keySet.toSeq + // Shut down the SparkContext. def stop() { if (LiveListenerBus.withinListenerThread.value) { @@ -2141,7 +2156,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } } - listenerBus.start(this) + listenerBus.start() _listenerBusStarted = true } @@ -2182,6 +2197,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * various Spark features. */ object SparkContext extends Logging { + private val VALID_LOG_LEVELS = + Set("ALL", "DEBUG", "ERROR", "FATAL", "INFO", "OFF", "TRACE", "WARN") /** * Lock that guards access to global variables that track SparkContext construction. @@ -2257,6 +2274,9 @@ object SparkContext extends Logging { if (activeContext.get() == null) { setActiveContext(new SparkContext(config), allowMultipleContexts = false) } + if (config.getAll.nonEmpty) { + logWarning("Use an existing SparkContext, some configuration may not take effect.") + } activeContext.get() } } @@ -2415,7 +2435,6 @@ object SparkContext extends Logging { * Create a task scheduler based on a given master URL. * Return a 2-tuple of the scheduler backend and the task scheduler. 
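`listJars()` (and the matching `listFiles()` added earlier) give callers a way to inspect what has been registered so far. A hedged sketch with placeholder paths, assuming `sc` exists and the files are present locally:

```scala
sc.addFile("/tmp/lookup.csv")   // placeholder path
sc.addJar("/tmp/udfs.jar")      // placeholder path

println(sc.listFiles())   // URIs of files that executors will fetch
println(sc.listJars())    // URIs of jars added to executor classpaths
```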
*/ - @tailrec private def createTaskScheduler( sc: SparkContext, master: String, @@ -2428,7 +2447,7 @@ object SparkContext extends Logging { master match { case "local" => val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true) - val backend = new LocalBackend(sc.getConf, scheduler, 1) + val backend = new LocalSchedulerBackend(sc.getConf, scheduler, 1) scheduler.initialize(backend) (backend, scheduler) @@ -2440,7 +2459,7 @@ object SparkContext extends Logging { throw new SparkException(s"Asked to run locally with $threadCount threads") } val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true) - val backend = new LocalBackend(sc.getConf, scheduler, threadCount) + val backend = new LocalSchedulerBackend(sc.getConf, scheduler, threadCount) scheduler.initialize(backend) (backend, scheduler) @@ -2450,14 +2469,14 @@ object SparkContext extends Logging { // local[N, M] means exactly N threads with M failures val threadCount = if (threads == "*") localCpuCount else threads.toInt val scheduler = new TaskSchedulerImpl(sc, maxFailures.toInt, isLocal = true) - val backend = new LocalBackend(sc.getConf, scheduler, threadCount) + val backend = new LocalSchedulerBackend(sc.getConf, scheduler, threadCount) scheduler.initialize(backend) (backend, scheduler) case SPARK_REGEX(sparkUrl) => val scheduler = new TaskSchedulerImpl(sc) val masterUrls = sparkUrl.split(",").map("spark://" + _) - val backend = new SparkDeploySchedulerBackend(scheduler, sc, masterUrls) + val backend = new StandaloneSchedulerBackend(scheduler, sc, masterUrls) scheduler.initialize(backend) (backend, scheduler) @@ -2474,9 +2493,9 @@ object SparkContext extends Logging { val localCluster = new LocalSparkCluster( numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt, sc.conf) val masterUrls = localCluster.start() - val backend = new SparkDeploySchedulerBackend(scheduler, sc, masterUrls) + val backend = new StandaloneSchedulerBackend(scheduler, sc, masterUrls) scheduler.initialize(backend) - backend.shutdownCallback = (backend: SparkDeploySchedulerBackend) => { + backend.shutdownCallback = (backend: StandaloneSchedulerBackend) => { localCluster.stop() } (backend, scheduler) @@ -2486,18 +2505,13 @@ object SparkContext extends Logging { val scheduler = new TaskSchedulerImpl(sc) val coarseGrained = sc.conf.getBoolean("spark.mesos.coarse", defaultValue = true) val backend = if (coarseGrained) { - new CoarseMesosSchedulerBackend(scheduler, sc, mesosUrl, sc.env.securityManager) + new MesosCoarseGrainedSchedulerBackend(scheduler, sc, mesosUrl, sc.env.securityManager) } else { - new MesosSchedulerBackend(scheduler, sc, mesosUrl) + new MesosFineGrainedSchedulerBackend(scheduler, sc, mesosUrl) } scheduler.initialize(backend) (backend, scheduler) - case zkUrl if zkUrl.startsWith("zk://") => - logWarning("Master URL for a multi-master Mesos cluster managed by ZooKeeper should be " + - "in the form mesos://zk://host:port. 
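After the renames above, the master URL still selects the backend inside `createTaskScheduler`; only the backend class names change. A rough mapping plus a runnable local example:

```scala
import org.apache.spark.{SparkConf, SparkContext}

// "local"              -> LocalSchedulerBackend, 1 thread
// "local[4]"           -> LocalSchedulerBackend, 4 threads
// "local[4,2]"         -> LocalSchedulerBackend, 4 threads, 2 task failures tolerated
// "spark://host:7077"  -> StandaloneSchedulerBackend
// "mesos://host:5050"  -> MesosCoarseGrainedSchedulerBackend, or the fine-grained
//                         variant when spark.mesos.coarse=false
val sc = new SparkContext(
  new SparkConf().setAppName("scheduler-demo").setMaster("local[4,2]"))
```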
Current Master URL will stop working in Spark 2.0.") - createTaskScheduler(sc, "mesos://" + zkUrl, deployMode) - case masterUrl => val cm = getClusterManager(masterUrl) match { case Some(clusterMgr) => clusterMgr diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 27497e21b829d..af50a6dc2d8d1 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -31,7 +31,6 @@ import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.internal.Logging import org.apache.spark.memory.{MemoryManager, StaticMemoryManager, UnifiedMemoryManager} import org.apache.spark.metrics.MetricsSystem -import org.apache.spark.network.BlockTransferService import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv} import org.apache.spark.scheduler.{LiveListenerBus, OutputCommitCoordinator} @@ -61,10 +60,8 @@ class SparkEnv ( val mapOutputTracker: MapOutputTracker, val shuffleManager: ShuffleManager, val broadcastManager: BroadcastManager, - val blockTransferService: BlockTransferService, val blockManager: BlockManager, val securityManager: SecurityManager, - val sparkFilesDir: String, val metricsSystem: MetricsSystem, val memoryManager: MemoryManager, val outputCommitCoordinator: OutputCommitCoordinator, @@ -77,7 +74,7 @@ class SparkEnv ( // (e.g., HadoopFileRDD uses this to cache JobConfs and InputFormats). private[spark] val hadoopJobMetadata = new MapMaker().softValues().makeMap[String, Any]() - private var driverTmpDirToDelete: Option[String] = None + private[spark] var driverTmpDir: Option[String] = None private[spark] def stop() { @@ -94,13 +91,10 @@ class SparkEnv ( rpcEnv.shutdown() rpcEnv.awaitTermination() - // Note that blockTransferService is stopped by BlockManager since it is started by it. - // If we only stop sc, but the driver process still run as a services then we need to delete // the tmp dir, if not, it will create too many tmp dirs. - // We only need to delete the tmp dir create by driver, because sparkFilesDir is point to the - // current working dir in executor which we do not need to delete. - driverTmpDirToDelete match { + // We only need to delete the tmp dir create by driver + driverTmpDir match { case Some(path) => try { Utils.deleteRecursively(new File(path)) @@ -284,8 +278,10 @@ object SparkEnv extends Logging { } } + val broadcastManager = new BroadcastManager(isDriver, conf, securityManager) + val mapOutputTracker = if (isDriver) { - new MapOutputTrackerMaster(conf) + new MapOutputTrackerMaster(conf, broadcastManager, isLocal) } else { new MapOutputTrackerWorker(conf) } @@ -325,8 +321,6 @@ object SparkEnv extends Logging { serializerManager, conf, memoryManager, mapOutputTracker, shuffleManager, blockTransferService, securityManager, numUsableCores) - val broadcastManager = new BroadcastManager(isDriver, conf, securityManager) - val metricsSystem = if (isDriver) { // Don't start metrics system right now for Driver. // We need to wait for the task scheduler to give us an app ID. @@ -342,15 +336,6 @@ object SparkEnv extends Logging { ms } - // Set the sparkFiles directory, used when downloading dependencies. In local mode, - // this is a temporary directory; in distributed mode, this is the executor's current working - // directory. 
- val sparkFilesDir: String = if (isDriver) { - Utils.createTempDir(Utils.getLocalDir(conf), "userFiles").getAbsolutePath - } else { - "." - } - val outputCommitCoordinator = mockOutputCommitCoordinator.getOrElse { new OutputCommitCoordinator(conf, isDriver) } @@ -367,10 +352,8 @@ object SparkEnv extends Logging { mapOutputTracker, shuffleManager, broadcastManager, - blockTransferService, blockManager, securityManager, - sparkFilesDir, metricsSystem, memoryManager, outputCommitCoordinator, @@ -380,7 +363,8 @@ object SparkEnv extends Logging { // called, and we only need to do it for driver. Because driver may run as a service, and if we // don't delete this tmp dir when sc is stopped, then will create too many tmp dirs. if (isDriver) { - envInstance.driverTmpDirToDelete = Some(sparkFilesDir) + val sparkFilesDir = Utils.createTempDir(Utils.getLocalDir(conf), "userFiles").getAbsolutePath + envInstance.driverTmpDir = Some(sparkFilesDir) } envInstance diff --git a/core/src/main/scala/org/apache/spark/SparkFiles.scala b/core/src/main/scala/org/apache/spark/SparkFiles.scala index e85b89fd014ef..44f4444a1fa8d 100644 --- a/core/src/main/scala/org/apache/spark/SparkFiles.scala +++ b/core/src/main/scala/org/apache/spark/SparkFiles.scala @@ -34,6 +34,6 @@ object SparkFiles { * Get the root directory that contains files added through `SparkContext.addFile()`. */ def getRootDirectory(): String = - SparkEnv.get.sparkFilesDir + SparkEnv.get.driverTmpDir.getOrElse(".") } diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 9e53257462867..27abccf5ac2a9 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -24,7 +24,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.metrics.source.Source -import org.apache.spark.util.{TaskCompletionListener, TaskFailureListener} +import org.apache.spark.util.{AccumulatorV2, TaskCompletionListener, TaskFailureListener} object TaskContext { @@ -188,6 +188,6 @@ abstract class TaskContext extends Serializable { * Register an accumulator that belongs to this task. Accumulators must call this method when * deserializing in executors. 
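`SparkFiles.getRootDirectory()` now resolves through `driverTmpDir`, falling back to the executor's working directory; from user code the behaviour is unchanged. A sketch assuming `sc` and a placeholder file path:

```scala
import org.apache.spark.SparkFiles

sc.addFile("/tmp/lookup.csv")              // placeholder path
val local = SparkFiles.get("lookup.csv")   // absolute path of the fetched copy on this node
println(s"root=${SparkFiles.getRootDirectory()} file=$local")
```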
*/ - private[spark] def registerAccumulator(a: NewAccumulator[_, _]): Unit + private[spark] def registerAccumulator(a: AccumulatorV2[_, _]): Unit } diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index bc3807f5db180..c904e083911cd 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -122,7 +122,7 @@ private[spark] class TaskContextImpl( override def getMetricsSources(sourceName: String): Seq[Source] = metricsSystem.getSourcesByName(sourceName) - private[spark] override def registerAccumulator(a: NewAccumulator[_, _]): Unit = { + private[spark] override def registerAccumulator(a: AccumulatorV2[_, _]): Unit = { taskMetrics.registerAccumulator(a) } diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index 82ba2d0c274be..42690844f9610 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -23,7 +23,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.internal.Logging import org.apache.spark.scheduler.AccumulableInfo import org.apache.spark.storage.BlockManagerId -import org.apache.spark.util.Utils +import org.apache.spark.util.{AccumulatorV2, Utils} // ============================================================================================== // NOTE: new task end reasons MUST be accompanied with serialization logic in util.JsonProtocol! @@ -118,7 +118,7 @@ case class ExceptionFailure( fullStackTrace: String, private val exceptionWrapper: Option[ThrowableSerializationWrapper], accumUpdates: Seq[AccumulableInfo] = Seq.empty, - private[spark] var accums: Seq[NewAccumulator[_, _]] = Nil) + private[spark] var accums: Seq[AccumulatorV2[_, _]] = Nil) extends TaskFailedReason { /** @@ -138,7 +138,7 @@ case class ExceptionFailure( this(e, accumUpdates, preserveCause = true) } - private[spark] def withAccums(accums: Seq[NewAccumulator[_, _]]): ExceptionFailure = { + private[spark] def withAccums(accums: Seq[AccumulatorV2[_, _]]): ExceptionFailure = { this.accums = accums this } diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index 43c89b258f2fa..871b9d1ad575b 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -22,6 +22,7 @@ import java.net.{URI, URL} import java.nio.charset.StandardCharsets import java.nio.file.Paths import java.util.Arrays +import java.util.concurrent.{CountDownLatch, TimeUnit} import java.util.jar.{JarEntry, JarOutputStream} import scala.collection.JavaConverters._ @@ -190,8 +191,14 @@ private[spark] object TestUtils { private class SpillListener extends SparkListener { private val stageIdToTaskMetrics = new mutable.HashMap[Int, ArrayBuffer[TaskMetrics]] private val spilledStageIds = new mutable.HashSet[Int] + private val stagesDone = new CountDownLatch(1) - def numSpilledStages: Int = spilledStageIds.size + def numSpilledStages: Int = { + // Long timeout, just in case somehow the job end isn't notified. 
+ // Fails if a timeout occurs + assert(stagesDone.await(10, TimeUnit.SECONDS)) + spilledStageIds.size + } override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { stageIdToTaskMetrics.getOrElseUpdate( @@ -206,4 +213,8 @@ private class SpillListener extends SparkListener { spilledStageIds += stageId } } + + override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { + stagesDone.countDown() + } } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index 6f3b8faf03b04..a37c52cbaf210 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -19,7 +19,7 @@ package org.apache.spark.api.java import java.{lang => jl} import java.lang.{Iterable => JIterable} -import java.util.{Comparator, Iterator => JIterator, List => JList} +import java.util.{Comparator, Iterator => JIterator, List => JList, Map => JMap} import scala.collection.JavaConverters._ import scala.reflect.ClassTag @@ -80,7 +80,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * This should ''not'' be called by users directly, but is available for implementors of custom * subclasses of RDD. */ - def iterator(split: Partition, taskContext: TaskContext): java.util.Iterator[T] = + def iterator(split: Partition, taskContext: TaskContext): JIterator[T] = rdd.iterator(split, taskContext).asJava // Transformations (return a new RDD) @@ -96,7 +96,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * of the original partition. */ def mapPartitionsWithIndex[R]( - f: JFunction2[jl.Integer, java.util.Iterator[T], java.util.Iterator[R]], + f: JFunction2[jl.Integer, JIterator[T], JIterator[R]], preservesPartitioning: Boolean = false): JavaRDD[R] = new JavaRDD(rdd.mapPartitionsWithIndex((a, b) => f.call(a, b.asJava).asScala, preservesPartitioning)(fakeClassTag))(fakeClassTag) @@ -147,7 +147,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { /** * Return a new RDD by applying a function to each partition of this RDD. */ - def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U]): JavaRDD[U] = { + def mapPartitions[U](f: FlatMapFunction[JIterator[T], U]): JavaRDD[U] = { def fn: (Iterator[T]) => Iterator[U] = { (x: Iterator[T]) => f.call(x.asJava).asScala } @@ -157,7 +157,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { /** * Return a new RDD by applying a function to each partition of this RDD. */ - def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U], + def mapPartitions[U](f: FlatMapFunction[JIterator[T], U], preservesPartitioning: Boolean): JavaRDD[U] = { def fn: (Iterator[T]) => Iterator[U] = { (x: Iterator[T]) => f.call(x.asJava).asScala @@ -169,7 +169,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { /** * Return a new RDD by applying a function to each partition of this RDD. */ - def mapPartitionsToDouble(f: DoubleFlatMapFunction[java.util.Iterator[T]]): JavaDoubleRDD = { + def mapPartitionsToDouble(f: DoubleFlatMapFunction[JIterator[T]]): JavaDoubleRDD = { def fn: (Iterator[T]) => Iterator[jl.Double] = { (x: Iterator[T]) => f.call(x.asJava).asScala } @@ -179,7 +179,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { /** * Return a new RDD by applying a function to each partition of this RDD. 
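The `CountDownLatch` added to `SpillListener` makes readers wait for `onJobEnd` instead of racing the asynchronous listener bus. The same pattern in a standalone listener (illustrative, not part of Spark):

```scala
import java.util.concurrent.{CountDownLatch, TimeUnit}
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd}

class JobEndWaiter extends SparkListener {
  private val done = new CountDownLatch(1)

  override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = done.countDown()

  // Returns false if no job end was observed within the timeout.
  def awaitJobEnd(timeoutSec: Long): Boolean = done.await(timeoutSec, TimeUnit.SECONDS)
}
```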
*/ - def mapPartitionsToPair[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2]): + def mapPartitionsToPair[K2, V2](f: PairFlatMapFunction[JIterator[T], K2, V2]): JavaPairRDD[K2, V2] = { def fn: (Iterator[T]) => Iterator[(K2, V2)] = { (x: Iterator[T]) => f.call(x.asJava).asScala @@ -190,7 +190,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { /** * Return a new RDD by applying a function to each partition of this RDD. */ - def mapPartitionsToDouble(f: DoubleFlatMapFunction[java.util.Iterator[T]], + def mapPartitionsToDouble(f: DoubleFlatMapFunction[JIterator[T]], preservesPartitioning: Boolean): JavaDoubleRDD = { def fn: (Iterator[T]) => Iterator[jl.Double] = { (x: Iterator[T]) => f.call(x.asJava).asScala @@ -202,7 +202,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { /** * Return a new RDD by applying a function to each partition of this RDD. */ - def mapPartitionsToPair[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2], + def mapPartitionsToPair[K2, V2](f: PairFlatMapFunction[JIterator[T], K2, V2], preservesPartitioning: Boolean): JavaPairRDD[K2, V2] = { def fn: (Iterator[T]) => Iterator[(K2, V2)] = { (x: Iterator[T]) => f.call(x.asJava).asScala @@ -214,7 +214,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { /** * Applies a function f to each partition of this RDD. */ - def foreachPartition(f: VoidFunction[java.util.Iterator[T]]) { + def foreachPartition(f: VoidFunction[JIterator[T]]): Unit = { rdd.foreachPartition(x => f.call(x.asJava)) } @@ -256,19 +256,44 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { /** * Return an RDD created by piping elements to a forked external process. */ - def pipe(command: String): JavaRDD[String] = rdd.pipe(command) + def pipe(command: String): JavaRDD[String] = { + rdd.pipe(command) + } /** * Return an RDD created by piping elements to a forked external process. */ - def pipe(command: JList[String]): JavaRDD[String] = + def pipe(command: JList[String]): JavaRDD[String] = { rdd.pipe(command.asScala) + } /** * Return an RDD created by piping elements to a forked external process. */ - def pipe(command: JList[String], env: java.util.Map[String, String]): JavaRDD[String] = + def pipe(command: JList[String], env: JMap[String, String]): JavaRDD[String] = { rdd.pipe(command.asScala, env.asScala) + } + + /** + * Return an RDD created by piping elements to a forked external process. + */ + def pipe(command: JList[String], + env: JMap[String, String], + separateWorkingDir: Boolean, + bufferSize: Int): JavaRDD[String] = { + rdd.pipe(command.asScala, env.asScala, null, null, separateWorkingDir, bufferSize) + } + + /** + * Return an RDD created by piping elements to a forked external process. 
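The new `pipe` overloads expose the working-directory, buffer-size and encoding knobs already available on the Scala RDD. The simplest form, assuming `sc` and that the `tr` utility is installed on every worker:

```scala
val upper = sc.parallelize(Seq("a", "b", "c"), 2).pipe("tr 'a-z' 'A-Z'")
println(upper.collect().toSeq)   // elements A, B, C
```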
+ */ + def pipe(command: JList[String], + env: JMap[String, String], + separateWorkingDir: Boolean, + bufferSize: Int, + encoding: String): JavaRDD[String] = { + rdd.pipe(command.asScala, env.asScala, null, null, separateWorkingDir, bufferSize, encoding) + } /** * Zips this RDD with another one, returning key-value pairs with the first element in each RDD, @@ -288,7 +313,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def zipPartitions[U, V]( other: JavaRDDLike[U, _], - f: FlatMapFunction2[java.util.Iterator[T], java.util.Iterator[U], V]): JavaRDD[V] = { + f: FlatMapFunction2[JIterator[T], JIterator[U], V]): JavaRDD[V] = { def fn: (Iterator[T], Iterator[U]) => Iterator[V] = { (x: Iterator[T], y: Iterator[U]) => f.call(x.asJava, y.asJava).asScala } @@ -431,6 +456,16 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { /** * Approximate version of count() that returns a potentially incomplete result * within a timeout, even if not all tasks have finished. + * + * The confidence is the probability that the error bounds of the result will + * contain the true value. That is, if countApprox were called repeatedly + * with confidence 0.9, we would expect 90% of the results to contain the + * true count. The confidence must be in the range [0,1] or an exception will + * be thrown. + * + * @param timeout maximum time to wait for the job, in milliseconds + * @param confidence the desired statistical confidence in the result + * @return a potentially incomplete result, with error bounds */ def countApprox(timeout: Long, confidence: Double): PartialResult[BoundedDouble] = rdd.countApprox(timeout, confidence) @@ -438,6 +473,8 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { /** * Approximate version of count() that returns a potentially incomplete result * within a timeout, even if not all tasks have finished. + * + * @param timeout maximum time to wait for the job, in milliseconds */ def countApprox(timeout: Long): PartialResult[BoundedDouble] = rdd.countApprox(timeout) @@ -446,22 +483,35 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Return the count of each unique value in this RDD as a map of (value, count) pairs. The final * combine step happens locally on the master, equivalent to running a single reduce task. */ - def countByValue(): java.util.Map[T, jl.Long] = - mapAsSerializableJavaMap(rdd.countByValue()).asInstanceOf[java.util.Map[T, jl.Long]] + def countByValue(): JMap[T, jl.Long] = + mapAsSerializableJavaMap(rdd.countByValue()).asInstanceOf[JMap[T, jl.Long]] /** - * (Experimental) Approximate version of countByValue(). + * Approximate version of countByValue(). + * + * The confidence is the probability that the error bounds of the result will + * contain the true value. That is, if countApprox were called repeatedly + * with confidence 0.9, we would expect 90% of the results to contain the + * true count. The confidence must be in the range [0,1] or an exception will + * be thrown. 
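The expanded `countApprox` documentation above describes `timeout` and `confidence`; a quick sketch assuming `sc`:

```scala
val big = sc.parallelize(1 to 1000000, 100)
// Wait at most 200 ms; the BoundedDouble carries low/high error bounds at 90%
// confidence if not all tasks finished in time.
val approx = big.countApprox(timeout = 200L, confidence = 0.9)
println(approx.initialValue)
```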
+ * + * @param timeout maximum time to wait for the job, in milliseconds + * @param confidence the desired statistical confidence in the result + * @return a potentially incomplete result, with error bounds */ def countByValueApprox( timeout: Long, confidence: Double - ): PartialResult[java.util.Map[T, BoundedDouble]] = + ): PartialResult[JMap[T, BoundedDouble]] = rdd.countByValueApprox(timeout, confidence).map(mapAsSerializableJavaMap) /** - * (Experimental) Approximate version of countByValue(). + * Approximate version of countByValue(). + * + * @param timeout maximum time to wait for the job, in milliseconds + * @return a potentially incomplete result, with error bounds */ - def countByValueApprox(timeout: Long): PartialResult[java.util.Map[T, BoundedDouble]] = + def countByValueApprox(timeout: Long): PartialResult[JMap[T, BoundedDouble]] = rdd.countByValueApprox(timeout).map(mapAsSerializableJavaMap) /** @@ -596,9 +646,10 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { /** * Returns the maximum element from this RDD as defined by the specified * Comparator[T]. + * * @param comp the comparator that defines ordering * @return the maximum of the RDD - * */ + */ def max(comp: Comparator[T]): T = { rdd.max()(Ordering.comparatorToOrdering(comp)) } @@ -606,9 +657,10 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { /** * Returns the minimum element from this RDD as defined by the specified * Comparator[T]. + * * @param comp the comparator that defines ordering * @return the minimum of the RDD - * */ + */ def min(comp: Comparator[T]): T = { rdd.min()(Ordering.comparatorToOrdering(comp)) } @@ -684,7 +736,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * The asynchronous version of the `foreachPartition` action, which * applies a function f to each partition of this RDD. */ - def foreachPartitionAsync(f: VoidFunction[java.util.Iterator[T]]): JavaFutureAction[Void] = { + def foreachPartitionAsync(f: VoidFunction[JIterator[T]]): JavaFutureAction[Void] = { new JavaFutureActionWrapper[Unit, Void](rdd.foreachPartitionAsync(x => f.call(x.asJava)), { x => null.asInstanceOf[Void] }) } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index dfd91ae338e89..131f36f5470f0 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -530,6 +530,7 @@ class JavaSparkContext(val sc: SparkContext) * Create an [[org.apache.spark.Accumulator]] integer variable, which tasks can "add" values * to using the `add` method. Only the master can access the accumulator's `value`. */ + @deprecated("use sc().longAccumulator()", "2.0.0") def intAccumulator(initialValue: Int): Accumulator[java.lang.Integer] = sc.accumulator(initialValue)(IntAccumulatorParam).asInstanceOf[Accumulator[java.lang.Integer]] @@ -539,6 +540,7 @@ class JavaSparkContext(val sc: SparkContext) * * This version supports naming the accumulator for display in Spark's web UI. 
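`max(comp)` and `min(comp)` in the Java API delegate to the Scala RDD via `Ordering.comparatorToOrdering`. The equivalent on the Scala side, assuming `sc`:

```scala
val words = sc.parallelize(Seq("spark", "rdd", "accumulator"))
val byLength = Ordering.by((s: String) => s.length)
println(words.max()(byLength))   // accumulator
println(words.min()(byLength))   // rdd
```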
*/ + @deprecated("use sc().longAccumulator(String)", "2.0.0") def intAccumulator(initialValue: Int, name: String): Accumulator[java.lang.Integer] = sc.accumulator(initialValue, name)(IntAccumulatorParam) .asInstanceOf[Accumulator[java.lang.Integer]] @@ -547,6 +549,7 @@ class JavaSparkContext(val sc: SparkContext) * Create an [[org.apache.spark.Accumulator]] double variable, which tasks can "add" values * to using the `add` method. Only the master can access the accumulator's `value`. */ + @deprecated("use sc().doubleAccumulator()", "2.0.0") def doubleAccumulator(initialValue: Double): Accumulator[java.lang.Double] = sc.accumulator(initialValue)(DoubleAccumulatorParam).asInstanceOf[Accumulator[java.lang.Double]] @@ -556,6 +559,7 @@ class JavaSparkContext(val sc: SparkContext) * * This version supports naming the accumulator for display in Spark's web UI. */ + @deprecated("use sc().doubleAccumulator(String)", "2.0.0") def doubleAccumulator(initialValue: Double, name: String): Accumulator[java.lang.Double] = sc.accumulator(initialValue, name)(DoubleAccumulatorParam) .asInstanceOf[Accumulator[java.lang.Double]] @@ -564,6 +568,7 @@ class JavaSparkContext(val sc: SparkContext) * Create an [[org.apache.spark.Accumulator]] integer variable, which tasks can "add" values * to using the `add` method. Only the master can access the accumulator's `value`. */ + @deprecated("use sc().longAccumulator()", "2.0.0") def accumulator(initialValue: Int): Accumulator[java.lang.Integer] = intAccumulator(initialValue) /** @@ -572,6 +577,7 @@ class JavaSparkContext(val sc: SparkContext) * * This version supports naming the accumulator for display in Spark's web UI. */ + @deprecated("use sc().longAccumulator(String)", "2.0.0") def accumulator(initialValue: Int, name: String): Accumulator[java.lang.Integer] = intAccumulator(initialValue, name) @@ -579,6 +585,7 @@ class JavaSparkContext(val sc: SparkContext) * Create an [[org.apache.spark.Accumulator]] double variable, which tasks can "add" values * to using the `add` method. Only the master can access the accumulator's `value`. */ + @deprecated("use sc().doubleAccumulator()", "2.0.0") def accumulator(initialValue: Double): Accumulator[java.lang.Double] = doubleAccumulator(initialValue) @@ -589,6 +596,7 @@ class JavaSparkContext(val sc: SparkContext) * * This version supports naming the accumulator for display in Spark's web UI. */ + @deprecated("use sc().doubleAccumulator(String)", "2.0.0") def accumulator(initialValue: Double, name: String): Accumulator[java.lang.Double] = doubleAccumulator(initialValue, name) @@ -596,6 +604,7 @@ class JavaSparkContext(val sc: SparkContext) * Create an [[org.apache.spark.Accumulator]] variable of a given type, which tasks can "add" * values to using the `add` method. Only the master can access the accumulator's `value`. */ + @deprecated("use AccumulatorV2", "2.0.0") def accumulator[T](initialValue: T, accumulatorParam: AccumulatorParam[T]): Accumulator[T] = sc.accumulator(initialValue)(accumulatorParam) @@ -605,23 +614,26 @@ class JavaSparkContext(val sc: SparkContext) * * This version supports naming the accumulator for display in Spark's web UI. */ + @deprecated("use AccumulatorV2", "2.0.0") def accumulator[T](initialValue: T, name: String, accumulatorParam: AccumulatorParam[T]) : Accumulator[T] = sc.accumulator(initialValue, name)(accumulatorParam) /** * Create an [[org.apache.spark.Accumulable]] shared variable of the given type, to which tasks - * can "add" values with `add`. Only the master can access the accumuable's `value`. 
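Each deprecation above points at an `AccumulatorV2` replacement. One possible migration, assuming a `JavaSparkContext` named `jsc` (the variable name is made up):

```scala
// Before (deprecated): val counter = jsc.intAccumulator(0)
// After: go through the underlying SparkContext and use the AccumulatorV2 API.
val counter = jsc.sc.longAccumulator("records")
counter.add(1L)
```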
+ * can "add" values with `add`. Only the master can access the accumulable's `value`. */ + @deprecated("use AccumulatorV2", "2.0.0") def accumulable[T, R](initialValue: T, param: AccumulableParam[T, R]): Accumulable[T, R] = sc.accumulable(initialValue)(param) /** * Create an [[org.apache.spark.Accumulable]] shared variable of the given type, to which tasks - * can "add" values with `add`. Only the master can access the accumuable's `value`. + * can "add" values with `add`. Only the master can access the accumulable's `value`. * * This version supports naming the accumulator for display in Spark's web UI. */ + @deprecated("use AccumulatorV2", "2.0.0") def accumulable[T, R](initialValue: T, name: String, param: AccumulableParam[T, R]) : Accumulable[T, R] = sc.accumulable(initialValue, name)(param) @@ -712,8 +724,13 @@ class JavaSparkContext(val sc: SparkContext) } /** - * Set a local property that affects jobs submitted from this thread, such as the - * Spark fair scheduler pool. + * Set a local property that affects jobs submitted from this thread, and all child + * threads, such as the Spark fair scheduler pool. + * + * These properties are inherited by child threads spawned from this thread. This + * may have unexpected consequences when working with thread pools. The standard java + * implementation of thread pools have worker threads spawn other worker threads. + * As a result, local properties may propagate unpredictably. */ def setLocalProperty(key: String, value: String): Unit = sc.setLocalProperty(key, value) @@ -780,7 +797,7 @@ class JavaSparkContext(val sc: SparkContext) def cancelAllJobs(): Unit = sc.cancelAllJobs() /** - * Returns an Java map of JavaRDDs that have marked themselves as persistent via cache() call. + * Returns a Java map of JavaRDDs that have marked themselves as persistent via cache() call. * Note that this does not necessarily mean the caching or computation was successful. */ def getPersistentRDDs: JMap[java.lang.Integer, JavaRDD[_]] = { diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index ab5b6c8380e8b..2822eb5d60024 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -919,7 +919,7 @@ private class PythonAccumulatorParam(@transient private val serverHost: String, } /** - * An Wrapper for Python Broadcast, which is written into disk by Python. It also will + * A Wrapper for Python Broadcast, which is written into disk by Python. It also will * write the data into disk after deserialization, then Python can read it from disks. 
*/ // scalastyle:off no.finalize diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala index 8bcd2903fe768..701097ace8974 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala @@ -32,7 +32,7 @@ private[spark] object PythonUtils { val pythonPath = new ArrayBuffer[String] for (sparkHome <- sys.env.get("SPARK_HOME")) { pythonPath += Seq(sparkHome, "python", "lib", "pyspark.zip").mkString(File.separator) - pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.9.2-src.zip").mkString(File.separator) + pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.10.3-src.zip").mkString(File.separator) } pythonPath ++= SparkContext.jarOfObject(this) pythonPath.mkString(File.pathSeparator) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index 3df87f62f2f85..6a5e6f7c5afb1 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -235,7 +235,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String } private def cleanupIdleWorkers() { - while (idleWorkers.length > 0) { + while (idleWorkers.nonEmpty) { val worker = idleWorkers.dequeue() try { // the worker will exit after closing the socket diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala index c416e835a9046..5c45c8493ef61 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala @@ -168,7 +168,7 @@ private[r] class RBackendHandler(server: RBackend) } } catch { case e: Exception => - logError(s"$methodName on $objId failed") + logError(s"$methodName on $objId failed", e) writeInt(dos, -1) // Writing the error message of the cause for the exception. This will be returned // to user in the R process. diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala index 59c8429c80172..a1a5eb8cf55e8 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala @@ -24,6 +24,7 @@ import scala.reflect.ClassTag import org.apache.spark._ import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext} +import org.apache.spark.api.python.PythonRDD import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD @@ -140,4 +141,16 @@ private[r] object RRDD { def createRDDFromArray(jsc: JavaSparkContext, arr: Array[Array[Byte]]): JavaRDD[Array[Byte]] = { JavaRDD.fromRDD(jsc.sc.parallelize(arr, arr.length)) } + + /** + * Create an RRDD given a temporary file name. This is used to create RRDD when parallelize is + * called on large R objects. 
+ * + * @param fileName name of temporary file on driver machine + * @param parallelism number of slices defaults to 4 + */ + def createRDDFromFile(jsc: JavaSparkContext, fileName: String, parallelism: Int): + JavaRDD[Array[Byte]] = { + PythonRDD.readRDDFromFile(jsc, fileName, parallelism) + } } diff --git a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala index 24ad689f8321c..496fdf851f7db 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala @@ -40,7 +40,8 @@ private[spark] class RRunner[U]( broadcastVars: Array[Broadcast[Object]], numPartitions: Int = -1, isDataFrame: Boolean = false, - colNames: Array[String] = null) + colNames: Array[String] = null, + mode: Int = RRunnerModes.RDD) extends Logging { private var bootTime: Double = _ private var dataStream: DataInputStream = _ @@ -148,8 +149,7 @@ private[spark] class RRunner[U]( } dataOut.writeInt(numPartitions) - - dataOut.writeInt(if (isDataFrame) 1 else 0) + dataOut.writeInt(mode) if (isDataFrame) { SerDe.writeObject(dataOut, colNames) @@ -180,6 +180,13 @@ private[spark] class RRunner[U]( for (elem <- iter) { elem match { + case (key, innerIter: Iterator[_]) => + for (innerElem <- innerIter) { + writeElem(innerElem) + } + // Writes key which can be used as a boundary in group-aggregate + dataOut.writeByte('r') + writeElem(key) case (key, value) => writeElem(key) writeElem(value) @@ -187,6 +194,7 @@ private[spark] class RRunner[U]( writeElem(elem) } } + stream.flush() } catch { // TODO: We should propagate this error to the task thread @@ -268,6 +276,12 @@ private object SpecialLengths { val TIMING_DATA = -1 } +private[spark] object RRunnerModes { + val RDD = 0 + val DATAFRAME_DAPPLY = 1 + val DATAFRAME_GAPPLY = 2 +} + private[r] class BufferedStreamThread( in: InputStream, name: String, diff --git a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala index 16157414fd120..77825e75e5136 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala @@ -37,6 +37,15 @@ private[spark] object RUtils { ) } + /** + * Check if SparkR is installed before running tests that use SparkR. + */ + def isSparkRInstalled: Boolean = { + localSparkRPackagePath.filter { pkgDir => + new File(Seq(pkgDir, "SparkR").mkString(File.separator)).exists + }.isDefined + } + /** * Get the list of paths for R packages in various deployment modes, of which the first * path is for the SparkR package itself. 
The second path is for R packages built as diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala index e4932a4192d39..550e075a95129 100644 --- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala +++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala @@ -125,15 +125,34 @@ private[spark] object SerDe { } def readDate(in: DataInputStream): Date = { - Date.valueOf(readString(in)) + try { + val inStr = readString(in) + if (inStr == "NA") { + null + } else { + Date.valueOf(inStr) + } + } catch { + // TODO: SPARK-18011 with some versions of R deserializing NA from R results in NASE + case _: NegativeArraySizeException => null + } } def readTime(in: DataInputStream): Timestamp = { - val seconds = in.readDouble() - val sec = Math.floor(seconds).toLong - val t = new Timestamp(sec * 1000L) - t.setNanos(((seconds - sec) * 1e9).toInt) - t + try { + val seconds = in.readDouble() + if (java.lang.Double.isNaN(seconds)) { + null + } else { + val sec = Math.floor(seconds).toLong + val t = new Timestamp(sec * 1000L) + t.setNanos(((seconds - sec) * 1e9).toInt) + t + } + } catch { + // TODO: SPARK-18011 with some versions of R deserializing NA from R results in NASE + case _: NegativeArraySizeException => null + } } def readBytesArr(in: DataInputStream): Array[Array[Byte]] = { diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 632b0ae9c2c37..e8d6d587b4824 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -232,7 +232,11 @@ private object TorrentBroadcast extends Logging { val out = compressionCodec.map(c => c.compressedOutputStream(cbbos)).getOrElse(cbbos) val ser = serializer.newInstance() val serOut = ser.serializeStream(out) - serOut.writeObject[T](obj).close() + Utils.tryWithSafeFinally { + serOut.writeObject[T](obj) + } { + serOut.close() + } cbbos.toChunkedByteBuffer.getChunks() } @@ -246,8 +250,11 @@ private object TorrentBroadcast extends Logging { val in: InputStream = compressionCodec.map(c => c.compressedInputStream(is)).getOrElse(is) val ser = serializer.newInstance() val serIn = ser.deserializeStream(in) - val obj = serIn.readObject[T]() - serIn.close() + val obj = Utils.tryWithSafeFinally { + serIn.readObject[T]() + } { + serIn.close() + } obj } diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala index 34c0696bfc4e5..ac09c6c497f8b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala @@ -135,7 +135,7 @@ private[deploy] object DeployMessages { } case class ExecutorUpdated(id: Int, state: ExecutorState, message: Option[String], - exitStatus: Option[Int]) + exitStatus: Option[Int], workerLost: Boolean) case class ApplicationRemoved(message: String) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 2e9e45a1550fd..90c71cc6cfab7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -17,10 +17,11 @@ package org.apache.spark.deploy -import java.io.{ByteArrayInputStream, DataInputStream} +import java.io.{ByteArrayInputStream, 
DataInputStream, IOException} import java.lang.reflect.Method import java.security.PrivilegedExceptionAction -import java.util.{Arrays, Comparator} +import java.text.DateFormat +import java.util.{Arrays, Comparator, Date} import scala.collection.JavaConverters._ import scala.concurrent.duration._ @@ -34,10 +35,13 @@ import org.apache.hadoop.fs.FileSystem.Statistics import org.apache.hadoop.hdfs.security.token.delegation.DelegationTokenIdentifier import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.security.{Credentials, UserGroupInformation} +import org.apache.hadoop.security.token.{Token, TokenIdentifier} +import org.apache.hadoop.security.token.delegation.AbstractDelegationTokenIdentifier import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ import org.apache.spark.util.Utils /** @@ -228,6 +232,10 @@ class SparkHadoopUtil extends Logging { recurse(baseStatus) } + def isGlobPath(pattern: Path): Boolean = { + pattern.toString.exists("{}[]*?\\".toSet.contains) + } + def globPath(pattern: Path): Seq[Path] = { val fs = pattern.getFileSystem(conf) Option(fs.globStatus(pattern)).map { statuses => @@ -236,11 +244,7 @@ class SparkHadoopUtil extends Logging { } def globPathIfNecessary(pattern: Path): Seq[Path] = { - if (pattern.toString.exists("{}[]*?\\".toSet.contains)) { - globPath(pattern) - } else { - Seq(pattern) - } + if (isGlobPath(pattern)) globPath(pattern) else Seq(pattern) } /** @@ -285,8 +289,7 @@ class SparkHadoopUtil extends Logging { credentials: Credentials): Long = { val now = System.currentTimeMillis() - val renewalInterval = - sparkConf.getLong("spark.yarn.token.renewal.interval", (24 hours).toMillis) + val renewalInterval = sparkConf.get(TOKEN_RENEWAL_INTERVAL).get credentials.getAllTokens.asScala .filter(_.getKind == DelegationTokenIdentifier.HDFS_DELEGATION_KIND) @@ -357,6 +360,50 @@ class SparkHadoopUtil extends Logging { newConf.setBoolean(confKey, true) newConf } + + /** + * Dump the credentials' tokens to string values. + * + * @param credentials credentials + * @return an iterator over the string values. If no credentials are passed in: an empty list + */ + private[spark] def dumpTokens(credentials: Credentials): Iterable[String] = { + if (credentials != null) { + credentials.getAllTokens.asScala.map(tokenToString) + } else { + Seq() + } + } + + /** + * Convert a token to a string for logging. + * If its an abstract delegation token, attempt to unmarshall it and then + * print more details, including timestamps in human-readable form. + * + * @param token token to convert to a string + * @return a printable string value. 
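To make the new `isGlobPath` / `globPathIfNecessary` split above concrete, here is a small sketch of the same glob-character test on its own. The paths are made up for illustration; only the character check is taken from the patch.

```scala
import org.apache.hadoop.fs.Path

object GlobCheckDemo {
  // The same metacharacter test used by SparkHadoopUtil.isGlobPath in the patch.
  private val globChars = "{}[]*?\\".toSet

  def looksLikeGlob(pattern: Path): Boolean = pattern.toString.exists(globChars.contains)

  def main(args: Array[String]): Unit = {
    println(looksLikeGlob(new Path("/data/logs/2016-*.json")))     // true: expand via globStatus
    println(looksLikeGlob(new Path("/data/logs/2016-07-01.json"))) // false: use the path as-is
  }
}
```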
+ */ + private[spark] def tokenToString(token: Token[_ <: TokenIdentifier]): String = { + val df = DateFormat.getDateTimeInstance(DateFormat.SHORT, DateFormat.SHORT) + val buffer = new StringBuilder(128) + buffer.append(token.toString) + try { + val ti = token.decodeIdentifier + buffer.append("; ").append(ti) + ti match { + case dt: AbstractDelegationTokenIdentifier => + // include human times and the renewer, which the HDFS tokens toString omits + buffer.append("; Renewer: ").append(dt.getRenewer) + buffer.append("; Issued: ").append(df.format(new Date(dt.getIssueDate))) + buffer.append("; Max Date: ").append(df.format(new Date(dt.getMaxDate))) + case _ => + } + } catch { + case e: IOException => + logDebug("Failed to decode $token: $e", e) + } + buffer.toString + } } object SparkHadoopUtil { diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 926e1ff7a874d..80611658a1640 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -40,9 +40,11 @@ import org.apache.ivy.plugins.matcher.GlobPatternMatcher import org.apache.ivy.plugins.repository.file.FileRepository import org.apache.ivy.plugins.resolver.{ChainResolver, FileSystemResolver, IBiblioResolver} -import org.apache.spark.{SPARK_VERSION, SparkException, SparkUserAppException} +import org.apache.spark.{SPARK_REVISION, SPARK_VERSION, SparkException, SparkUserAppException} +import org.apache.spark.{SPARK_BRANCH, SPARK_BUILD_DATE, SPARK_BUILD_USER, SPARK_REPO_URL} import org.apache.spark.api.r.RUtils import org.apache.spark.deploy.rest._ +import org.apache.spark.launcher.SparkLauncher import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils} @@ -75,10 +77,6 @@ object SparkSubmit { private val CLUSTER = 2 private val ALL_DEPLOY_MODES = CLIENT | CLUSTER - // A special jar name that indicates the class being run is inside of Spark itself, and therefore - // no user jar is needed. - private val SPARK_INTERNAL = "spark-internal" - // Special primary resource names that represent shells rather than application jars. private val SPARK_SHELL = "spark-shell" private val PYSPARK_SHELL = "pyspark-shell" @@ -106,6 +104,10 @@ object SparkSubmit { /___/ .__/\_,_/_/ /_/\_\ version %s /_/ """.format(SPARK_VERSION)) + printStream.println("Branch %s".format(SPARK_BRANCH)) + printStream.println("Compiled by user %s on %s".format(SPARK_BUILD_USER, SPARK_BUILD_DATE)) + printStream.println("Revision %s".format(SPARK_REVISION)) + printStream.println("Url %s".format(SPARK_REPO_URL)) printStream.println("Type --help for more information.") exitFn(0) } @@ -305,10 +307,11 @@ object SparkSubmit { } // Require all python files to be local, so we can add them to the PYTHONPATH - // In YARN cluster mode, python files are distributed as regular files, which can be non-local - if (args.isPython && !isYarnCluster) { + // In YARN cluster mode, python files are distributed as regular files, which can be non-local. + // In Mesos cluster mode, non-local python files are automatically downloaded by Mesos. 
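The `printErrorAndExit` fixes just below change `$args.primaryResource` to `${args.primaryResource}`. The difference is easy to miss, so here is a tiny standalone sketch (the `Args` case class and values are invented for illustration):

```scala
object InterpolationDemo {
  case class Args(primaryResource: String)

  def main(argv: Array[String]): Unit = {
    val args = Args("hdfs://host/app.py")

    // Without braces only the identifier `args` is interpolated and ".primaryResource"
    // is kept as literal text, e.g. "... supported: Args(hdfs://host/app.py).primaryResource"
    println(s"Only local python files are supported: $args.primaryResource")

    // With braces the whole expression is evaluated:
    // "Only local python files are supported: hdfs://host/app.py"
    println(s"Only local python files are supported: ${args.primaryResource}")
  }
}
```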
+ if (args.isPython && !isYarnCluster && !isMesosCluster) { if (Utils.nonLocalPaths(args.primaryResource).nonEmpty) { - printErrorAndExit(s"Only local python files are supported: $args.primaryResource") + printErrorAndExit(s"Only local python files are supported: ${args.primaryResource}") } val nonLocalPyFiles = Utils.nonLocalPaths(args.pyFiles).mkString(",") if (nonLocalPyFiles.nonEmpty) { @@ -319,7 +322,7 @@ object SparkSubmit { // Require all R files to be local if (args.isR && !isYarnCluster) { if (Utils.nonLocalPaths(args.primaryResource).nonEmpty) { - printErrorAndExit(s"Only local R files are supported: $args.primaryResource") + printErrorAndExit(s"Only local R files are supported: ${args.primaryResource}") } } @@ -410,12 +413,12 @@ object SparkSubmit { printErrorAndExit("SparkR is not supported for Mesos cluster.") } - // If we're running a R app, set the main class to our specific R runner + // If we're running an R app, set the main class to our specific R runner if (args.isR && deployMode == CLIENT) { if (args.primaryResource == SPARKR_SHELL) { args.mainClass = "org.apache.spark.api.r.RBackend" } else { - // If a R file is provided, add it to the child arguments and list of files to deploy. + // If an R file is provided, add it to the child arguments and list of files to deploy. // Usage: RRunner
<main R file>
[app arguments] args.mainClass = "org.apache.spark.deploy.RRunner" args.childArgs = ArrayBuffer(args.primaryResource) ++ args.childArgs @@ -424,7 +427,7 @@ object SparkSubmit { } if (isYarnCluster && args.isR) { - // In yarn-cluster mode for a R app, add primary resource to files + // In yarn-cluster mode for an R app, add primary resource to files // that can be distributed with the job args.files = mergeFileLists(args.files, args.primaryResource) } @@ -574,7 +577,7 @@ object SparkSubmit { childArgs += ("--primary-r-file", mainFile) childArgs += ("--class", "org.apache.spark.deploy.RRunner") } else { - if (args.primaryResource != SPARK_INTERNAL) { + if (args.primaryResource != SparkLauncher.NO_RESOURCE) { childArgs += ("--jar", args.primaryResource) } childArgs += ("--class", args.mainClass) @@ -630,7 +633,14 @@ object SparkSubmit { // explicitly sets `spark.submit.pyFiles` in his/her default properties file. sysProps.get("spark.submit.pyFiles").foreach { pyFiles => val resolvedPyFiles = Utils.resolveURIs(pyFiles) - val formattedPyFiles = PythonRunner.formatPaths(resolvedPyFiles).mkString(",") + val formattedPyFiles = if (!isYarnCluster && !isMesosCluster) { + PythonRunner.formatPaths(resolvedPyFiles).mkString(",") + } else { + // Ignoring formatting python path in yarn and mesos cluster mode, these two modes + // support dealing with remote python files, they could distribute and add python files + // locally. + resolvedPyFiles + } sysProps("spark.submit.pyFiles") = formattedPyFiles } @@ -794,7 +804,7 @@ object SparkSubmit { } private[deploy] def isInternal(res: String): Boolean = { - res == SPARK_INTERNAL + res == SparkLauncher.NO_RESOURCE } /** @@ -894,9 +904,12 @@ private[spark] object SparkSubmitUtils { val localIvyRoot = new File(ivySettings.getDefaultIvyUserDir, "local") localIvy.setLocal(true) localIvy.setRepository(new FileRepository(localIvyRoot)) - val ivyPattern = Seq("[organisation]", "[module]", "[revision]", "[type]s", - "[artifact](-[classifier]).[ext]").mkString(File.separator) - localIvy.addIvyPattern(localIvyRoot.getAbsolutePath + File.separator + ivyPattern) + val ivyPattern = Seq(localIvyRoot.getAbsolutePath, "[organisation]", "[module]", "[revision]", + "ivys", "ivy.xml").mkString(File.separator) + localIvy.addIvyPattern(ivyPattern) + val artifactPattern = Seq(localIvyRoot.getAbsolutePath, "[organisation]", "[module]", + "[revision]", "[type]s", "[artifact](-[classifier]).[ext]").mkString(File.separator) + localIvy.addArtifactPattern(artifactPattern) localIvy.setName("local-ivy-cache") cr.add(localIvy) @@ -941,7 +954,7 @@ private[spark] object SparkSubmitUtils { artifacts.foreach { mvn => val ri = ModuleRevisionId.newInstance(mvn.groupId, mvn.artifactId, mvn.version) val dd = new DefaultDependencyDescriptor(ri, false, false) - dd.addDependencyConfiguration(ivyConfName, ivyConfName) + dd.addDependencyConfiguration(ivyConfName, ivyConfName + "(runtime)") // scalastyle:off println printStream.println(s"${dd.getDependencyId} added as a dependency") // scalastyle:on println @@ -957,9 +970,9 @@ private[spark] object SparkSubmitUtils { // Add scala exclusion rule md.addExcludeRule(createExclusion("*:scala-library:*", ivySettings, ivyConfName)) - // We need to specify each component explicitly, otherwise we miss spark-streaming-kafka and + // We need to specify each component explicitly, otherwise we miss spark-streaming-kafka-0-8 and // other spark-streaming utility components. 
Underscore is there to differentiate between - // spark-streaming_2.1x and spark-streaming-kafka-assembly_2.1x + // spark-streaming_2.1x and spark-streaming-kafka-0-8-assembly_2.1x val components = Seq("catalyst_", "core_", "graphx_", "hive_", "mllib_", "repl_", "sql_", "streaming_", "yarn_", "network-common_", "network-shuffle_", "network-yarn_") diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 78da1b70c54a5..f1761e7c1ec92 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -173,6 +173,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S .orNull name = Option(name).orElse(sparkProperties.get("spark.app.name")).orNull jars = Option(jars).orElse(sparkProperties.get("spark.jars")).orNull + files = Option(files).orElse(sparkProperties.get("spark.files")).orNull ivyRepoPath = sparkProperties.get("spark.jars.ivy").orNull packages = Option(packages).orElse(sparkProperties.get("spark.jars.packages")).orNull packagesExclusions = Option(packagesExclusions) @@ -549,6 +550,8 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S | (Default: 1). | --queue QUEUE_NAME The YARN queue to submit to (Default: "default"). | --num-executors NUM Number of executors to launch (Default: 2). + | If dynamic allocation is enabled, the initial number of + | executors will be at least NUM. | --archives ARCHIVES Comma separated list of archives to be extracted into the | working directory of each executor. | --principal PRINCIPAL Principal to be used to login to KDC, while running on diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala similarity index 88% rename from core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala rename to core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala index 43b17e5d49bf1..93f58ce63799f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala @@ -21,6 +21,8 @@ import java.util.concurrent._ import java.util.concurrent.{Future => JFuture, ScheduledFuture => JScheduledFuture} import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference} +import scala.concurrent.Future +import scala.util.{Failure, Success} import scala.util.control.NonFatal import org.apache.spark.SparkConf @@ -32,17 +34,18 @@ import org.apache.spark.rpc._ import org.apache.spark.util.{RpcUtils, ThreadUtils} /** - * Interface allowing applications to speak with a Spark deploy cluster. Takes a master URL, - * an app description, and a listener for cluster events, and calls back the listener when various - * events occur. + * Interface allowing applications to speak with a Spark standalone cluster manager. + * + * Takes a master URL, an app description, and a listener for cluster events, and calls + * back the listener when various events occur. * * @param masterUrls Each url should look like spark://host:port. 
*/ -private[spark] class AppClient( +private[spark] class StandaloneAppClient( rpcEnv: RpcEnv, masterUrls: Array[String], appDescription: ApplicationDescription, - listener: AppClientListener, + listener: StandaloneAppClientListener, conf: SparkConf) extends Logging { @@ -78,11 +81,6 @@ private[spark] class AppClient( private val registrationRetryThread = ThreadUtils.newDaemonSingleThreadScheduledExecutor("appclient-registration-retry-thread") - // A thread pool to perform receive then reply actions in a thread so as not to block the - // event loop. - private val askAndReplyThreadPool = - ThreadUtils.newDaemonCachedThreadPool("appclient-receive-and-reply-threadpool") - override def onStart(): Unit = { try { registerWithMaster(1) @@ -176,12 +174,12 @@ private[spark] class AppClient( cores)) listener.executorAdded(fullId, workerId, hostPort, cores, memory) - case ExecutorUpdated(id, state, message, exitStatus) => + case ExecutorUpdated(id, state, message, exitStatus, workerLost) => val fullId = appId + "/" + id val messageText = message.map(s => " (" + s + ")").getOrElse("") logInfo("Executor updated: %s is now %s%s".format(fullId, state, messageText)) if (ExecutorState.isFinished(state)) { - listener.executorRemoved(fullId, message.getOrElse(""), exitStatus) + listener.executorRemoved(fullId, message.getOrElse(""), exitStatus, workerLost) } case MasterChanged(masterRef, masterWebUiUrl) => @@ -219,19 +217,13 @@ private[spark] class AppClient( endpointRef: RpcEndpointRef, context: RpcCallContext, msg: T): Unit = { - // Create a thread to ask a message and reply with the result. Allow thread to be + // Ask a message and create a thread to reply with the result. Allow thread to be // interrupted during shutdown, otherwise context must be notified of NonFatal errors. - askAndReplyThreadPool.execute(new Runnable { - override def run(): Unit = { - try { - context.reply(endpointRef.askWithRetry[Boolean](msg)) - } catch { - case ie: InterruptedException => // Cancelled - case NonFatal(t) => - context.sendFailure(t) - } - } - }) + endpointRef.ask[Boolean](msg).andThen { + case Success(b) => context.reply(b) + case Failure(ie: InterruptedException) => // Cancelled + case Failure(NonFatal(t)) => context.sendFailure(t) + }(ThreadUtils.sameThread) } override def onDisconnected(address: RpcAddress): Unit = { @@ -271,7 +263,6 @@ private[spark] class AppClient( registrationRetryThread.shutdownNow() registerMasterFutures.get.foreach(_.cancel(true)) registerMasterThreadPool.shutdownNow() - askAndReplyThreadPool.shutdownNow() } } @@ -300,12 +291,12 @@ private[spark] class AppClient( * * @return whether the request is acknowledged. */ - def requestTotalExecutors(requestedTotal: Int): Boolean = { + def requestTotalExecutors(requestedTotal: Int): Future[Boolean] = { if (endpoint.get != null && appId.get != null) { - endpoint.get.askWithRetry[Boolean](RequestExecutors(appId.get, requestedTotal)) + endpoint.get.ask[Boolean](RequestExecutors(appId.get, requestedTotal)) } else { logWarning("Attempted to request executors before driver fully initialized.") - false + Future.successful(false) } } @@ -313,12 +304,12 @@ private[spark] class AppClient( * Kill the given list of executors through the Master. * @return whether the kill request is acknowledged. 
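Because `requestTotalExecutors` above (and `killExecutors` just below) now return a `Future[Boolean]` from `ask` instead of blocking on `askWithRetry`, callers have to react to the acknowledgement asynchronously. A hedged sketch of a caller, assuming it lives inside the Spark source tree (both the client and `ThreadUtils` are `private[spark]`):

```scala
package org.apache.spark.deploy.client // needed: these APIs are Spark-internal

import scala.concurrent.Future
import scala.util.{Failure, Success}

import org.apache.spark.util.ThreadUtils

object AppClientCallerSketch {
  // `client` is assumed to be an already-started StandaloneAppClient.
  def adjustExecutors(client: StandaloneAppClient, total: Int): Unit = {
    val acked: Future[Boolean] = client.requestTotalExecutors(total)

    // React to the acknowledgement instead of blocking the calling thread.
    acked.onComplete {
      case Success(true)  => // the master acknowledged the new executor total
      case Success(false) => // rejected, or the client was not fully initialized yet
      case Failure(_)     => // the RPC itself failed; log and/or retry as appropriate
    }(ThreadUtils.sameThread)
  }
}
```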
*/ - def killExecutors(executorIds: Seq[String]): Boolean = { + def killExecutors(executorIds: Seq[String]): Future[Boolean] = { if (endpoint.get != null && appId.get != null) { - endpoint.get.askWithRetry[Boolean](KillExecutors(appId.get, executorIds)) + endpoint.get.ask[Boolean](KillExecutors(appId.get, executorIds)) } else { logWarning("Attempted to kill executors before driver fully initialized.") - false + Future.successful(false) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClientListener.scala b/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClientListener.scala similarity index 90% rename from core/src/main/scala/org/apache/spark/deploy/client/AppClientListener.scala rename to core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClientListener.scala index 94506a0cbb27c..64255ec92b72a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClientListener.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClientListener.scala @@ -24,7 +24,7 @@ package org.apache.spark.deploy.client * * Users of this API should *not* block inside the callback methods. */ -private[spark] trait AppClientListener { +private[spark] trait StandaloneAppClientListener { def connected(appId: String): Unit /** Disconnection may be a temporary state, as we fail over to a new Master. */ @@ -36,5 +36,6 @@ private[spark] trait AppClientListener { def executorAdded( fullId: String, workerId: String, hostPort: String, cores: Int, memory: Int): Unit - def executorRemoved(fullId: String, message: String, exitStatus: Option[Int]): Unit + def executorRemoved( + fullId: String, message: String, exitStatus: Option[Int], workerLost: Boolean): Unit } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala index 44661edfff90b..ba42b4862aa90 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala @@ -109,4 +109,9 @@ private[history] abstract class ApplicationHistoryProvider { @throws(classOf[SparkException]) def writeEventLogs(appId: String, attemptId: Option[String], zipStream: ZipOutputStream): Unit + /** + * @return the [[ApplicationHistoryInfo]] for the appId if it exists. 
+ */ + def getApplicationInfo(appId: String): Option[ApplicationHistoryInfo] + } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 07cbcec8e5f0e..cf4a401b64b38 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -222,6 +222,10 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) override def getListing(): Iterable[FsApplicationHistoryInfo] = applications.values + override def getApplicationInfo(appId: String): Option[FsApplicationHistoryInfo] = { + applications.get(appId) + } + override def getAppUI(appId: String, attemptId: Option[String]): Option[LoadedAppUI] = { try { applications.get(appId).flatMap { appInfo => @@ -245,6 +249,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) ui.getSecurityManager.setAdminAcls(appListener.adminAcls.getOrElse("")) ui.getSecurityManager.setViewAcls(attempt.sparkUser, appListener.viewAcls.getOrElse("")) + ui.getSecurityManager.setAdminAclsGroups(appListener.adminAclsGroups.getOrElse("")) + ui.getSecurityManager.setViewAclsGroups(appListener.viewAclsGroups.getOrElse("")) LoadedAppUI(ui, updateProbe(appId, attemptId, attempt.fileSize)) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index 2fad1120cdc8a..a120b6c5fcdff 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -44,7 +44,8 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") if (allAppsSize > 0) { ++ ++ - + ++ + } else if (requestedIncomplete) {
<h4>No incomplete applications found!</h4>
} else { diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index d821474bdb590..735aa43cfc994 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -28,6 +28,7 @@ import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder} import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationInfo, ApplicationsListResource, UIRoot} import org.apache.spark.ui.{SparkUI, UIUtils, WebUI} import org.apache.spark.ui.JettyUtils._ @@ -55,6 +56,9 @@ class HistoryServer( // How many applications to retain private val retainedApplications = conf.getInt("spark.history.retainedApplications", 50) + // How many applications the summary ui displays + private[history] val maxApplications = conf.get(HISTORY_UI_MAX_APPS); + // application private val appCache = new ApplicationCache(this, retainedApplications, new SystemClock()) @@ -178,6 +182,10 @@ class HistoryServer( getApplicationList().iterator.map(ApplicationsListResource.appHistoryInfoToPublicAppInfo) } + def getApplicationInfo(appId: String): Option[ApplicationInfo] = { + provider.getApplicationInfo(appId).map(ApplicationsListResource.appHistoryInfoToPublicAppInfo) + } + override def writeEventLogs( appId: String, attemptId: Option[String], diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala index 4ffb5283e99a4..53564d0e95152 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala @@ -41,7 +41,6 @@ private[spark] class ApplicationInfo( @transient var coresGranted: Int = _ @transient var endTime: Long = _ @transient var appSource: ApplicationSource = _ - @transient @volatile var appUIUrlAtHistoryServer: Option[String] = None // A cap on the number of executors this application can have at any given time. // By default, this is infinite. 
Only after the first allocation request is issued by the @@ -66,7 +65,6 @@ private[spark] class ApplicationInfo( nextExecutorId = 0 removedExecutors = new ArrayBuffer[ExecutorDesc] executorLimit = desc.initialExecutorLimit.getOrElse(Integer.MAX_VALUE) - appUIUrlAtHistoryServer = None } private def newExecutorId(useID: Option[Int] = None): Int = { @@ -136,11 +134,4 @@ private[spark] class ApplicationInfo( System.currentTimeMillis() - startTime } } - - /** - * Returns the original application UI url unless there is its address at history server - * is defined - */ - def curAppUIUrl: String = appUIUrlAtHistoryServer.getOrElse(desc.appUiUrl) - } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala index 37bfcdfdf4777..097728c821570 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala @@ -22,6 +22,4 @@ private[master] object ApplicationState extends Enumeration { type ApplicationState = Value val WAITING, RUNNING, FINISHED, FAILED, KILLED, UNKNOWN = Value - - val MAX_NUM_RETRY = 10 } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index edc9be2a8a8cb..dcf41638e7994 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -17,25 +17,17 @@ package org.apache.spark.deploy.master -import java.io.FileNotFoundException -import java.net.URLEncoder import java.text.SimpleDateFormat import java.util.Date -import java.util.concurrent.{ConcurrentHashMap, ScheduledFuture, TimeUnit} +import java.util.concurrent.{ScheduledFuture, TimeUnit} import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} -import scala.concurrent.{ExecutionContext, Future} -import scala.concurrent.duration.Duration -import scala.language.postfixOps import scala.util.Random -import org.apache.hadoop.fs.Path - import org.apache.spark.{SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.{ApplicationDescription, DriverDescription, ExecutorState, SparkHadoopUtil} import org.apache.spark.deploy.DeployMessages._ -import org.apache.spark.deploy.history.HistoryServer import org.apache.spark.deploy.master.DriverState.DriverState import org.apache.spark.deploy.master.MasterMessages._ import org.apache.spark.deploy.master.ui.MasterWebUI @@ -43,9 +35,7 @@ import org.apache.spark.deploy.rest.StandaloneRestServer import org.apache.spark.internal.Logging import org.apache.spark.metrics.MetricsSystem import org.apache.spark.rpc._ -import org.apache.spark.scheduler.{EventLoggingListener, ReplayListenerBus} import org.apache.spark.serializer.{JavaSerializer, Serializer} -import org.apache.spark.ui.SparkUI import org.apache.spark.util.{ThreadUtils, Utils} private[deploy] class Master( @@ -59,10 +49,6 @@ private[deploy] class Master( private val forwardMessageThread = ThreadUtils.newDaemonSingleThreadScheduledExecutor("master-forward-message-thread") - private val rebuildUIThread = - ThreadUtils.newDaemonSingleThreadExecutor("master-rebuild-ui-thread") - private val rebuildUIContext = ExecutionContext.fromExecutor(rebuildUIThread) - private val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs @@ -72,6 +58,7 @@ 
private[deploy] class Master( private val RETAINED_DRIVERS = conf.getInt("spark.deploy.retainedDrivers", 200) private val REAPER_ITERATIONS = conf.getInt("spark.dead.worker.persistence", 15) private val RECOVERY_MODE = conf.get("spark.deploy.recoveryMode", "NONE") + private val MAX_EXECUTOR_RETRIES = conf.getInt("spark.deploy.maxExecutorRetries", 10) val workers = new HashSet[WorkerInfo] val idToApp = new HashMap[String, ApplicationInfo] @@ -85,8 +72,6 @@ private[deploy] class Master( private val addressToApp = new HashMap[RpcAddress, ApplicationInfo] private val completedApps = new ArrayBuffer[ApplicationInfo] private var nextAppNumber = 0 - // Using ConcurrentHashMap so that master-rebuild-ui-thread can add a UI after asyncRebuildUI - private val appIdToUI = new ConcurrentHashMap[String, SparkUI] private val drivers = new HashSet[DriverInfo] private val completedDrivers = new ArrayBuffer[DriverInfo] @@ -199,7 +184,6 @@ private[deploy] class Master( checkForWorkerTimeOutTask.cancel(true) } forwardMessageThread.shutdownNow() - rebuildUIThread.shutdownNow() webUi.stop() restServer.foreach(_.stop()) masterMetricsSystem.stop() @@ -268,7 +252,7 @@ private[deploy] class Master( appInfo.resetRetryCount() } - exec.application.driver.send(ExecutorUpdated(execId, state, message, exitStatus)) + exec.application.driver.send(ExecutorUpdated(execId, state, message, exitStatus, false)) if (ExecutorState.isFinished(state)) { // Remove this executor from the worker and app @@ -282,19 +266,20 @@ private[deploy] class Master( val normalExit = exitStatus == Some(0) // Only retry certain number of times so we don't go into an infinite loop. - if (!normalExit) { - if (appInfo.incrementRetryCount() < ApplicationState.MAX_NUM_RETRY) { - schedule() - } else { - val execs = appInfo.executors.values - if (!execs.exists(_.state == ExecutorState.RUNNING)) { - logError(s"Application ${appInfo.desc.name} with ID ${appInfo.id} failed " + - s"${appInfo.retryCount} times; removing it") - removeApplication(appInfo, ApplicationState.FAILED) - } + // Important note: this code path is not exercised by tests, so be very careful when + // changing this `if` condition. 
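The retry handling above is now driven by the new `spark.deploy.maxExecutorRetries` setting (default 10, read by the master), and a negative value disables the application-killing path. A condensed restatement of the condition, purely for readability; the predicate below mirrors the `if` in `Master.scala` and is not part of the patch:

```scala
object ExecutorRetryPolicySketch {
  // Mirrors the new check: the app is removed only when the exit was abnormal, the retry
  // limit is non-negative and reached, and none of its executors is still RUNNING.
  def shouldRemoveApplication(
      normalExit: Boolean,
      retryCount: Int,
      maxExecutorRetries: Int,
      anyExecutorRunning: Boolean): Boolean = {
    !normalExit &&
      maxExecutorRetries >= 0 &&      // a negative limit (e.g. -1) disables removal entirely
      retryCount >= maxExecutorRetries &&
      !anyExecutorRunning
  }

  def main(args: Array[String]): Unit = {
    println(shouldRemoveApplication(normalExit = false, retryCount = 10,
      maxExecutorRetries = 10, anyExecutorRunning = false)) // true: limit reached
    println(shouldRemoveApplication(normalExit = false, retryCount = 50,
      maxExecutorRetries = -1, anyExecutorRunning = false)) // false: check disabled
  }
}
```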
+ if (!normalExit + && appInfo.incrementRetryCount() >= MAX_EXECUTOR_RETRIES + && MAX_EXECUTOR_RETRIES >= 0) { // < 0 disables this application-killing path + val execs = appInfo.executors.values + if (!execs.exists(_.state == ExecutorState.RUNNING)) { + logError(s"Application ${appInfo.desc.name} with ID ${appInfo.id} failed " + + s"${appInfo.retryCount} times; removing it") + removeApplication(appInfo, ApplicationState.FAILED) } } } + schedule() case None => logWarning(s"Got status update for unknown executor $appId/$execId") } @@ -391,9 +376,6 @@ private[deploy] class Master( case CheckForWorkerTimeOut => timeOutDeadWorkers() - case AttachCompletedRebuildUI(appId) => - // An asyncRebuildSparkUI has completed, so need to attach to master webUi - Option(appIdToUI.get(appId)).foreach { ui => webUi.attachSparkUI(ui) } } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { @@ -784,7 +766,7 @@ private[deploy] class Master( for (exec <- worker.executors.values) { logInfo("Telling app of lost executor: " + exec.id) exec.application.driver.send(ExecutorUpdated( - exec.id, ExecutorState.LOST, Some("worker lost"), None)) + exec.id, ExecutorState.LOST, Some("worker lost"), None, workerLost = true)) exec.state = ExecutorState.LOST exec.application.removeExecutor(exec) } @@ -844,7 +826,6 @@ private[deploy] class Master( if (completedApps.size >= RETAINED_APPLICATIONS) { val toRemove = math.max(RETAINED_APPLICATIONS / 10, 1) completedApps.take(toRemove).foreach { a => - Option(appIdToUI.remove(a.id)).foreach { ui => webUi.detachSparkUI(ui) } applicationMetricsSystem.removeSource(a.appSource) } completedApps.trimStart(toRemove) @@ -852,9 +833,6 @@ private[deploy] class Master( completedApps += app // Remember it in our history waitingApps -= app - // If application events are logged, use them to rebuild the UI - asyncRebuildSparkUI(app) - for (exec <- app.executors.values) { killExecutor(exec) } @@ -953,90 +931,7 @@ private[deploy] class Master( exec.state = ExecutorState.KILLED } - /** - * Rebuild a new SparkUI from the given application's event logs. 
- * Return the UI if successful, else None - */ - private[master] def rebuildSparkUI(app: ApplicationInfo): Option[SparkUI] = { - val futureUI = asyncRebuildSparkUI(app) - ThreadUtils.awaitResult(futureUI, Duration.Inf) - } - - /** Rebuild a new SparkUI asynchronously to not block RPC event loop */ - private[master] def asyncRebuildSparkUI(app: ApplicationInfo): Future[Option[SparkUI]] = { - val appName = app.desc.name - val notFoundBasePath = HistoryServer.UI_PATH_PREFIX + "/not-found" - val eventLogDir = app.desc.eventLogDir - .getOrElse { - // Event logging is disabled for this application - app.appUIUrlAtHistoryServer = Some(notFoundBasePath) - return Future.successful(None) - } - val futureUI = Future { - val eventLogFilePrefix = EventLoggingListener.getLogPath( - eventLogDir, app.id, appAttemptId = None, compressionCodecName = app.desc.eventLogCodec) - val fs = Utils.getHadoopFileSystem(eventLogDir, hadoopConf) - val inProgressExists = fs.exists(new Path(eventLogFilePrefix + - EventLoggingListener.IN_PROGRESS)) - - val eventLogFile = if (inProgressExists) { - // Event logging is enabled for this application, but the application is still in progress - logWarning(s"Application $appName is still in progress, it may be terminated abnormally.") - eventLogFilePrefix + EventLoggingListener.IN_PROGRESS - } else { - eventLogFilePrefix - } - - val logInput = EventLoggingListener.openEventLog(new Path(eventLogFile), fs) - val replayBus = new ReplayListenerBus() - val ui = SparkUI.createHistoryUI(new SparkConf, replayBus, new SecurityManager(conf), - appName, HistoryServer.UI_PATH_PREFIX + s"/${app.id}", app.startTime) - try { - replayBus.replay(logInput, eventLogFile, inProgressExists) - } finally { - logInput.close() - } - - Some(ui) - }(rebuildUIContext) - - futureUI.onSuccess { case Some(ui) => - appIdToUI.put(app.id, ui) - // `self` can be null if we are already in the process of shutting down - // This happens frequently in tests where `local-cluster` is used - if (self != null) { - self.send(AttachCompletedRebuildUI(app.id)) - } - // Application UI is successfully rebuilt, so link the Master UI to it - // NOTE - app.appUIUrlAtHistoryServer is volatile - app.appUIUrlAtHistoryServer = Some(ui.basePath) - }(ThreadUtils.sameThread) - - futureUI.onFailure { - case fnf: FileNotFoundException => - // Event logging is enabled for this application, but no event logs are found - val title = s"Application history not found (${app.id})" - var msg = s"No event logs found for application $appName in ${app.desc.eventLogDir.get}." - logWarning(msg) - msg += " Did you specify the correct logging directory?" - msg = URLEncoder.encode(msg, "UTF-8") - app.appUIUrlAtHistoryServer = Some(notFoundBasePath + s"?msg=$msg&title=$title") - - case e: Exception => - // Relay exception message to application UI page - val title = s"Application history load error (${app.id})" - val exception = URLEncoder.encode(Utils.exceptionString(e), "UTF-8") - var msg = s"Exception in replaying log for application $appName!" 
- logError(msg, e) - msg = URLEncoder.encode(msg, "UTF-8") - app.appUIUrlAtHistoryServer = - Some(notFoundBasePath + s"?msg=$msg&exception=$exception&title=$title") - }(ThreadUtils.sameThread) - - futureUI - } - - /** Generate a new app ID given a app's submission date */ + /** Generate a new app ID given an app's submission date */ private def newApplicationId(submitDate: Date): String = { val appId = "app-%s-%04d".format(createDateFormat.format(submitDate), nextAppNumber) nextAppNumber += 1 diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala index 585e0839d0fc1..c63793c16dcef 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala @@ -20,18 +20,24 @@ package org.apache.spark.deploy.master import scala.annotation.tailrec import org.apache.spark.SparkConf +import org.apache.spark.internal.Logging import org.apache.spark.util.{IntParam, Utils} /** * Command-line parser for the master. */ -private[master] class MasterArguments(args: Array[String], conf: SparkConf) { +private[master] class MasterArguments(args: Array[String], conf: SparkConf) extends Logging { var host = Utils.localHostName() var port = 7077 var webUiPort = 8080 var propertiesFile: String = null // Check for settings in environment variables + if (System.getenv("SPARK_MASTER_IP") != null) { + logWarning("SPARK_MASTER_IP is deprecated, please use SPARK_MASTER_HOST") + host = System.getenv("SPARK_MASTER_IP") + } + if (System.getenv("SPARK_MASTER_HOST") != null) { host = System.getenv("SPARK_MASTER_HOST") } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala index a055d097674ce..a952cee36eb44 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala @@ -39,6 +39,4 @@ private[master] object MasterMessages { case object BoundPortsRequest case class BoundPortsResponse(rpcEndpointPort: Int, webUIPort: Int, restPort: Option[Int]) - - case class AttachCompletedRebuildUI(appId: String) } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala index 96274958d1422..18c5d0bd01948 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala @@ -24,7 +24,7 @@ import scala.xml.Node import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState} import org.apache.spark.deploy.ExecutorState import org.apache.spark.deploy.master.ExecutorDesc -import org.apache.spark.ui.{UIUtils, WebUIPage} +import org.apache.spark.ui.{ToolTips, UIUtils, WebUIPage} import org.apache.spark.util.Utils private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app") { @@ -69,13 +69,27 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app") } } +
  • + + Executor Limit: + { + if (app.executorLimit == Int.MaxValue) "Unlimited" else app.executorLimit + } + ({app.executors.size} granted) + +
  • Executor Memory: {Utils.megabytesToString(app.desc.memoryPerExecutorMB)}
  • Submit Date: {app.submitDate}
  • State: {app.state}
  • -
  • Application Detail UI
  • + { + if (!app.isFinished) { +
  • Application Detail UI
  • + } + } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/HistoryNotFoundPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/HistoryNotFoundPage.scala deleted file mode 100644 index e021f1eef794f..0000000000000 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/HistoryNotFoundPage.scala +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.deploy.master.ui - -import java.net.URLDecoder -import javax.servlet.http.HttpServletRequest - -import scala.xml.Node - -import org.apache.spark.ui.{UIUtils, WebUIPage} - -private[ui] class HistoryNotFoundPage(parent: MasterWebUI) - extends WebUIPage("history/not-found") { - - /** - * Render a page that conveys failure in loading application history. - * - * This accepts 3 HTTP parameters: - * msg = message to display to the user - * title = title of the page - * exception = detailed description of the exception in loading application history (if any) - * - * Parameters "msg" and "exception" are assumed to be UTF-8 encoded. - */ - def render(request: HttpServletRequest): Seq[Node] = { - val titleParam = request.getParameter("title") - val msgParam = request.getParameter("msg") - val exceptionParam = request.getParameter("exception") - - // If no parameters are specified, assume the user did not enable event logging - val defaultTitle = "Event logging is not enabled" - val defaultContent = -
    -
    - No event logs were found for this application! To - enable event logging, - set spark.eventLog.enabled to true and - spark.eventLog.dir to the directory to which your - event logs are written. -
    -
    - - val title = Option(titleParam).getOrElse(defaultTitle) - val content = Option(msgParam) - .map { msg => URLDecoder.decode(msg, "UTF-8") } - .map { msg => -
    -
    {msg}
    -
    ++ - Option(exceptionParam) - .map { e => URLDecoder.decode(e, "UTF-8") } - .map { e =>
    {e}
    } - .getOrElse(Seq.empty) - }.getOrElse(defaultContent) - - UIUtils.basicSparkPage(content, title) - } -} diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index 363f4b84f885e..5ed3e39edc484 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -114,8 +114,8 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { {Utils.megabytesToString(aliveWorkers.map(_.memory).sum)} Total, {Utils.megabytesToString(aliveWorkers.map(_.memoryUsed).sum)} Used
  • Applications: - {state.activeApps.length} Running, - {state.completedApps.length} Completed
  • + {state.activeApps.length} Running, + {state.completedApps.length} Completed
  • Drivers: {state.activeDrivers.length} Running, {state.completedDrivers.length} Completed
  • @@ -133,7 +133,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
    -

    Running Applications

    +

    Running Applications

    {activeAppsTable}
    @@ -152,7 +152,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
    -

    Completed Applications

    +

    Completed Applications

    {completedAppsTable}
    @@ -206,7 +206,13 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { {killLink} - {app.desc.name} + { + if (app.isFinished) { + app.desc.name + } else { + {app.desc.name} + } + } {app.coresGranted} diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala index ae16ce90c84b7..a0727ad83fb66 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala @@ -19,8 +19,6 @@ package org.apache.spark.deploy.master.ui import org.apache.spark.deploy.master.Master import org.apache.spark.internal.Logging -import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationInfo, ApplicationsListResource, - UIRoot} import org.apache.spark.ui.{SparkUI, WebUI} import org.apache.spark.ui.JettyUtils._ @@ -30,60 +28,26 @@ import org.apache.spark.ui.JettyUtils._ private[master] class MasterWebUI( val master: Master, - requestedPort: Int, - customMasterPage: Option[MasterPage] = None) + requestedPort: Int) extends WebUI(master.securityMgr, master.securityMgr.getSSLOptions("standalone"), - requestedPort, master.conf, name = "MasterUI") with Logging with UIRoot { + requestedPort, master.conf, name = "MasterUI") with Logging { val masterEndpointRef = master.self val killEnabled = master.conf.getBoolean("spark.ui.killEnabled", true) - val masterPage = customMasterPage.getOrElse(new MasterPage(this)) - initialize() /** Initialize all components of the server. */ def initialize() { val masterPage = new MasterPage(this) attachPage(new ApplicationPage(this)) - attachPage(new HistoryNotFoundPage(this)) attachPage(masterPage) attachHandler(createStaticHandler(MasterWebUI.STATIC_RESOURCE_DIR, "/static")) - attachHandler(ApiRootResource.getServletHandler(this)) attachHandler(createRedirectHandler( "/app/kill", "/", masterPage.handleAppKillRequest, httpMethods = Set("POST"))) attachHandler(createRedirectHandler( "/driver/kill", "/", masterPage.handleDriverKillRequest, httpMethods = Set("POST"))) } - - /** Attach a reconstructed UI to this Master UI. Only valid after bind(). */ - def attachSparkUI(ui: SparkUI) { - assert(serverInfo.isDefined, "Master UI must be bound to a server before attaching SparkUIs") - ui.getHandlers.foreach(attachHandler) - } - - /** Detach a reconstructed UI from this Master UI. Only valid after bind(). 
*/ - def detachSparkUI(ui: SparkUI) { - assert(serverInfo.isDefined, "Master UI must be bound to a server before detaching SparkUIs") - ui.getHandlers.foreach(detachHandler) - } - - def getApplicationInfoList: Iterator[ApplicationInfo] = { - val state = masterPage.getMasterState - val activeApps = state.activeApps.sortBy(_.startTime).reverse - val completedApps = state.completedApps.sortBy(_.endTime).reverse - activeApps.iterator.map { ApplicationsListResource.convertApplicationInfo(_, false) } ++ - completedApps.iterator.map { ApplicationsListResource.convertApplicationInfo(_, true) } - } - - def getSparkUI(appId: String): Option[SparkUI] = { - val state = masterPage.getMasterState - val activeApps = state.activeApps.sortBy(_.startTime).reverse - val completedApps = state.completedApps.sortBy(_.endTime).reverse - (activeApps ++ completedApps).find { _.id == appId }.flatMap { - master.rebuildSparkUI - } - } } private[master] object MasterWebUI { diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala index 14244ea5714c6..b30c980e95a9a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala @@ -17,15 +17,14 @@ package org.apache.spark.deploy.rest -import java.net.InetSocketAddress import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse} import scala.io.Source import com.fasterxml.jackson.core.JsonProcessingException -import org.eclipse.jetty.server.Server +import org.eclipse.jetty.server.{HttpConnectionFactory, Server, ServerConnector} import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder} -import org.eclipse.jetty.util.thread.QueuedThreadPool +import org.eclipse.jetty.util.thread.{QueuedThreadPool, ScheduledExecutorScheduler} import org.json4s._ import org.json4s.jackson.JsonMethods._ @@ -80,18 +79,32 @@ private[spark] abstract class RestSubmissionServer( * Return a 2-tuple of the started server and the bound port. 
*/ private def doStart(startPort: Int): (Server, Int) = { - val server = new Server(new InetSocketAddress(host, startPort)) val threadPool = new QueuedThreadPool threadPool.setDaemon(true) - server.setThreadPool(threadPool) + val server = new Server(threadPool) + + val connector = new ServerConnector( + server, + null, + // Call this full constructor to set this, which forces daemon threads: + new ScheduledExecutorScheduler("RestSubmissionServer-JettyScheduler", true), + null, + -1, + -1, + new HttpConnectionFactory()) + connector.setHost(host) + connector.setPort(startPort) + server.addConnector(connector) + val mainHandler = new ServletContextHandler + mainHandler.setServer(server) mainHandler.setContextPath("/") contextToServlet.foreach { case (prefix, servlet) => mainHandler.addServlet(new ServletHolder(servlet), prefix) } server.setHandler(mainHandler) server.start() - val boundPort = server.getConnectors()(0).getLocalPort + val boundPort = connector.getLocalPort (server, boundPort) } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala index 3473c41b935fd..465c214362b25 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala @@ -22,6 +22,8 @@ import javax.servlet.http.HttpServletRequest import scala.xml.{Node, Unparsed} +import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache} + import org.apache.spark.internal.Logging import org.apache.spark.ui.{UIUtils, WebUIPage} import org.apache.spark.util.Utils @@ -138,7 +140,8 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with val files = RollingFileAppender.getSortedRolledOverFiles(logDirectory, logType) logDebug(s"Sorted log files of type $logType in $logDirectory:\n${files.mkString("\n")}") - val totalLength = files.map { _.length }.sum + val fileLengths: Seq[Long] = files.map(Utils.getFileLength(_, worker.conf)) + val totalLength = fileLengths.sum val offset = offsetOption.getOrElse(totalLength - byteLength) val startIndex = { if (offset < 0) { @@ -151,7 +154,7 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with } val endIndex = math.min(startIndex + byteLength, totalLength) logDebug(s"Getting log from $startIndex to $endIndex") - val logText = Utils.offsetBytes(files, startIndex, endIndex) + val logText = Utils.offsetBytes(files, fileLengths, startIndex, endIndex) logDebug(s"Got log of length ${logText.length} bytes") (logText, startIndex, endIndex, totalLength) } catch { diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index e08729510926b..e30839c49c04f 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -23,6 +23,7 @@ import java.util.concurrent.atomic.AtomicBoolean import scala.collection.mutable import scala.util.{Failure, Success} +import scala.util.control.NonFatal import org.apache.spark._ import org.apache.spark.TaskState.TaskState @@ -39,6 +40,7 @@ private[spark] class CoarseGrainedExecutorBackend( override val rpcEnv: RpcEnv, driverUrl: String, executorId: String, + hostname: String, cores: Int, userClassPath: Seq[URL], env: SparkEnv) @@ -57,14 +59,13 @@ private[spark] class 
CoarseGrainedExecutorBackend( rpcEnv.asyncSetupEndpointRefByURI(driverUrl).flatMap { ref => // This is a very fast action so we can use "ThreadUtils.sameThread" driver = Some(ref) - ref.ask[Boolean](RegisterExecutor(executorId, self, cores, extractLogUrls)) + ref.ask[Boolean](RegisterExecutor(executorId, self, hostname, cores, extractLogUrls)) }(ThreadUtils.sameThread).onComplete { // This is a very fast action so we can use "ThreadUtils.sameThread" case Success(msg) => // Always receive `true`. Just ignore it case Failure(e) => - logError(s"Cannot register with driver: $driverUrl", e) - exitExecutor(1) + exitExecutor(1, s"Cannot register with driver: $driverUrl", e) }(ThreadUtils.sameThread) } @@ -75,18 +76,21 @@ private[spark] class CoarseGrainedExecutorBackend( } override def receive: PartialFunction[Any, Unit] = { - case RegisteredExecutor(hostname) => + case RegisteredExecutor => logInfo("Successfully registered with driver") - executor = new Executor(executorId, hostname, env, userClassPath, isLocal = false) + try { + executor = new Executor(executorId, hostname, env, userClassPath, isLocal = false) + } catch { + case NonFatal(e) => + exitExecutor(1, "Unable to create executor due to " + e.getMessage, e) + } case RegisterExecutorFailed(message) => - logError("Slave registration failed: " + message) - exitExecutor(1) + exitExecutor(1, "Slave registration failed: " + message) case LaunchTask(data) => if (executor == null) { - logError("Received LaunchTask command but executor was null") - exitExecutor(1) + exitExecutor(1, "Received LaunchTask command but executor was null") } else { val taskDesc = ser.deserialize[TaskDescription](data.value) logInfo("Got assigned task " + taskDesc.taskId) @@ -96,8 +100,7 @@ private[spark] class CoarseGrainedExecutorBackend( case KillTask(taskId, _, interruptThread) => if (executor == null) { - logError("Received KillTask command but executor was null") - exitExecutor(1) + exitExecutor(1, "Received KillTask command but executor was null") } else { executor.killTask(taskId, interruptThread) } @@ -126,8 +129,7 @@ private[spark] class CoarseGrainedExecutorBackend( if (stopping.get()) { logInfo(s"Driver from $remoteAddress disconnected during shutdown") } else if (driver.exists(_.address == remoteAddress)) { - logError(s"Driver $remoteAddress disassociated! Shutting down.") - exitExecutor(1) + exitExecutor(1, s"Driver $remoteAddress disassociated! Shutting down.") } else { logWarning(s"An unknown ($remoteAddress) driver disconnected.") } @@ -146,7 +148,14 @@ private[spark] class CoarseGrainedExecutorBackend( * executor exits differently. For e.g. when an executor goes down, * back-end may not want to take the parent process down. 
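The comment above is the rationale for turning `exitExecutor` into an overridable hook rather than a bare `System.exit`: an embedding back-end may not want the whole JVM torn down. A hedged sketch of such an override; the subclass name and behavior are invented, and the package declaration is only there because the backend is a Spark-internal (`private[spark]`) class:

```scala
package org.apache.spark.executor // needed: CoarseGrainedExecutorBackend is Spark-internal

import java.net.URL

import org.apache.spark.SparkEnv
import org.apache.spark.rpc.RpcEnv

// Hypothetical backend that records a fatal condition instead of exiting the host process.
private[spark] class EmbeddedExecutorBackend(
    rpcEnv: RpcEnv,
    driverUrl: String,
    executorId: String,
    hostname: String,
    cores: Int,
    userClassPath: Seq[URL],
    env: SparkEnv)
  extends CoarseGrainedExecutorBackend(
    rpcEnv, driverUrl, executorId, hostname, cores, userClassPath, env) {

  override protected def exitExecutor(code: Int, reason: String, throwable: Throwable): Unit = {
    if (throwable != null) {
      logError(s"Executor requested exit ($code): $reason", throwable)
    } else {
      logError(s"Executor requested exit ($code): $reason")
    }
    // Deliberately no System.exit here; the embedding process decides how to react.
  }
}
```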
*/ - protected def exitExecutor(code: Int): Unit = System.exit(code) + protected def exitExecutor(code: Int, reason: String, throwable: Throwable = null) = { + if (throwable != null) { + logError(reason, throwable) + } else { + logError(reason) + } + System.exit(code) + } } private[spark] object CoarseGrainedExecutorBackend extends Logging { @@ -201,7 +210,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { driverConf, executorId, hostname, port, cores, isLocal = false) env.rpcEnv.setupEndpoint("Executor", new CoarseGrainedExecutorBackend( - env.rpcEnv, driverUrl, executorId, cores, userClassPath, env)) + env.rpcEnv, driverUrl, executorId, hostname, cores, userClassPath, env)) workerUrl.foreach { url => env.rpcEnv.setupEndpoint("WorkerWatcher", new WorkerWatcher(env.rpcEnv, url)) } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 4d61f7e23248b..9a017f29f7d21 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -23,6 +23,7 @@ import java.net.URL import java.nio.ByteBuffer import java.util.Properties import java.util.concurrent.{ConcurrentHashMap, TimeUnit} +import javax.annotation.concurrent.GuardedBy import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, HashMap} @@ -194,6 +195,10 @@ private[spark] class Executor( /** Whether this task has been killed. */ @volatile private var killed = false + /** Whether this task has been finished. */ + @GuardedBy("TaskRunner.this") + private var finished = false + /** How much the JVM process has spent in GC when the task starts to run. */ @volatile var startGCTime: Long = _ @@ -207,10 +212,25 @@ private[spark] class Executor( logInfo(s"Executor is trying to kill $taskName (TID $taskId)") killed = true if (task != null) { - task.kill(interruptThread) + synchronized { + if (!finished) { + task.kill(interruptThread) + } + } } } + /** + * Set the finished flag to true and clear the current thread's interrupt status + */ + private def setTaskFinishedAndClearInterruptStatus(): Unit = synchronized { + this.finished = true + // SPARK-14234 - Reset the interrupted status of the thread to avoid the + // ClosedByInterruptException during execBackend.statusUpdate which causes + // Executor to crash + Thread.interrupted() + } + override def run(): Unit = { val taskMemoryManager = new TaskMemoryManager(env.memoryManager, taskId) val deserializeStartTime = System.currentTimeMillis() @@ -261,20 +281,20 @@ private[spark] class Executor( val releasedLocks = env.blockManager.releaseAllLocksForTask(taskId) val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory() - if (freedMemory > 0) { + if (freedMemory > 0 && !threwException) { val errMsg = s"Managed memory leak detected; size = $freedMemory bytes, TID = $taskId" - if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false) && !threwException) { + if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false)) { throw new SparkException(errMsg) } else { - logError(errMsg) + logWarning(errMsg) } } - if (releasedLocks.nonEmpty) { + if (releasedLocks.nonEmpty && !threwException) { val errMsg = s"${releasedLocks.size} block locks were not released by TID = $taskId:\n" + releasedLocks.mkString("[", ", ", "]") - if (conf.getBoolean("spark.storage.exceptionOnPinLeak", false) && !threwException) { + if (conf.getBoolean("spark.storage.exceptionOnPinLeak", false)) { throw new 
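The new `finished` flag and `setTaskFinishedAndClearInterruptStatus` form a small handshake: the kill path only interrupts while the task is still running, and the finish path clears any pending interrupt before `statusUpdate`, so the update cannot fail with `ClosedByInterruptException` (SPARK-14234). A stripped-down sketch of that pattern, independent of the real `TaskRunner`:

```scala
// Minimal sketch, not the actual TaskRunner: one flag guarded by the runner's lock.
class KillableRunnerSketch(taskThread: Thread) {
  @volatile private var killed = false
  private var finished = false // guarded by this

  def kill(interruptThread: Boolean): Unit = {
    killed = true
    synchronized {
      if (!finished && interruptThread) {
        taskThread.interrupt() // only interrupt a task that has not finished yet
      }
    }
  }

  def setFinishedAndClearInterrupt(): Unit = synchronized {
    finished = true
    // Thread.interrupted() clears the current thread's interrupt flag, so the
    // following status update cannot be broken by a late kill() interrupt.
    Thread.interrupted()
  }
}
```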
SparkException(errMsg) } else { logWarning(errMsg) @@ -336,14 +356,17 @@ private[spark] class Executor( } catch { case ffe: FetchFailedException => val reason = ffe.toTaskEndReason + setTaskFinishedAndClearInterruptStatus() execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) case _: TaskKilledException | _: InterruptedException if task.killed => logInfo(s"Executor killed $taskName (TID $taskId)") + setTaskFinishedAndClearInterruptStatus() execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled)) case CausedBy(cDE: CommitDeniedException) => val reason = cDE.toTaskEndReason + setTaskFinishedAndClearInterruptStatus() execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) case t: Throwable => @@ -353,7 +376,7 @@ private[spark] class Executor( logError(s"Exception in $taskName (TID $taskId)", t) // Collect latest accumulator values to report back to the driver - val accums: Seq[NewAccumulator[_, _]] = + val accums: Seq[AccumulatorV2[_, _]] = if (task != null) { task.metrics.setExecutorRunTime(System.currentTimeMillis() - taskStart) task.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime) @@ -362,7 +385,7 @@ private[spark] class Executor( Seq.empty } - val accUpdates = accums.map(acc => acc.toInfo(Some(acc.localValue), None)) + val accUpdates = accums.map(acc => acc.toInfo(Some(acc.value), None)) val serializedTaskEndReason = { try { @@ -373,6 +396,7 @@ private[spark] class Executor( ser.serialize(new ExceptionFailure(t, accUpdates, false).withAccums(accums)) } } + setTaskFinishedAndClearInterruptStatus() execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskEndReason) // Don't forcibly exit unless the exception was inherently fatal, to avoid @@ -478,7 +502,7 @@ private[spark] class Executor( /** Reports heartbeat and metrics for active tasks to the driver. 
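`NewAccumulator` is renamed to `AccumulatorV2` and `localValue` to `value` throughout this file. For reference, a hedged sketch of what a custom accumulator looks like under the renamed API (an illustrative set-of-strings accumulator, not one this patch adds):

```scala
import org.apache.spark.util.AccumulatorV2

// Illustrative subclass showing the AccumulatorV2[IN, OUT] members referenced above.
class StringSetAccumulator extends AccumulatorV2[String, Set[String]] {
  private var set = Set.empty[String]

  override def isZero: Boolean = set.isEmpty

  override def copy(): StringSetAccumulator = {
    val copied = new StringSetAccumulator
    copied.set = set
    copied
  }

  override def reset(): Unit = set = Set.empty

  override def add(v: String): Unit = set += v

  override def merge(other: AccumulatorV2[String, Set[String]]): Unit =
    set ++= other.value

  override def value: Set[String] = set   // "value", not "localValue", after the rename
}
```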
*/ private def reportHeartBeat(): Unit = { // list of (task id, accumUpdates) to send back to the driver - val accumUpdates = new ArrayBuffer[(Long, Seq[NewAccumulator[_, _]])]() + val accumUpdates = new ArrayBuffer[(Long, Seq[AccumulatorV2[_, _]])]() val curGCTime = computeTotalGcTime() for (taskRunner <- runningTasks.values().asScala) { diff --git a/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala b/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala index 6f7160ac0d3a3..3d15f3a0396e1 100644 --- a/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala @@ -17,8 +17,8 @@ package org.apache.spark.executor -import org.apache.spark.LongAccumulator import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.util.LongAccumulator /** diff --git a/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala b/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala index db3924cb6937e..dada9697c1cf9 100644 --- a/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala @@ -17,8 +17,8 @@ package org.apache.spark.executor -import org.apache.spark.LongAccumulator import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.util.LongAccumulator /** diff --git a/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala b/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala index fa962108c3064..f7a991770d402 100644 --- a/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala @@ -17,8 +17,8 @@ package org.apache.spark.executor -import org.apache.spark.LongAccumulator import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.util.LongAccumulator /** diff --git a/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala b/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala index 0e70a4f522849..ada2e1bc08593 100644 --- a/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala @@ -17,8 +17,8 @@ package org.apache.spark.executor -import org.apache.spark.LongAccumulator import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.util.LongAccumulator /** diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 0b64917219a7e..47aec44bac019 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -17,6 +17,7 @@ package org.apache.spark.executor +import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, LinkedHashMap} import org.apache.spark._ @@ -24,6 +25,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.internal.Logging import org.apache.spark.scheduler.AccumulableInfo import org.apache.spark.storage.{BlockId, BlockStatus} +import org.apache.spark.util._ /** @@ -50,7 +52,7 @@ class TaskMetrics private[spark] () extends Serializable { private val _memoryBytesSpilled = new LongAccumulator private val _diskBytesSpilled = new LongAccumulator private val _peakExecutionMemory = new LongAccumulator - private val _updatedBlockStatuses = new BlockStatusesAccumulator + private val 
_updatedBlockStatuses = new CollectionAccumulator[(BlockId, BlockStatus)] /** * Time taken on the executor to deserialize this task. @@ -98,7 +100,11 @@ class TaskMetrics private[spark] () extends Serializable { /** * Storage statuses of any blocks that have been updated as a result of this task. */ - def updatedBlockStatuses: Seq[(BlockId, BlockStatus)] = _updatedBlockStatuses.localValue + def updatedBlockStatuses: Seq[(BlockId, BlockStatus)] = { + // This is called on driver. All accumulator updates have a fixed value. So it's safe to use + // `asScala` which accesses the internal values using `java.util.Iterator`. + _updatedBlockStatuses.value.asScala + } // Setters and increment-ers private[spark] def setExecutorDeserializeTime(v: Long): Unit = @@ -113,8 +119,10 @@ class TaskMetrics private[spark] () extends Serializable { private[spark] def incPeakExecutionMemory(v: Long): Unit = _peakExecutionMemory.add(v) private[spark] def incUpdatedBlockStatuses(v: (BlockId, BlockStatus)): Unit = _updatedBlockStatuses.add(v) - private[spark] def setUpdatedBlockStatuses(v: Seq[(BlockId, BlockStatus)]): Unit = + private[spark] def setUpdatedBlockStatuses(v: java.util.List[(BlockId, BlockStatus)]): Unit = _updatedBlockStatuses.setValue(v) + private[spark] def setUpdatedBlockStatuses(v: Seq[(BlockId, BlockStatus)]): Unit = + _updatedBlockStatuses.setValue(v.asJava) /** * Metrics related to reading data from a [[org.apache.spark.rdd.HadoopRDD]] or from persisted @@ -201,7 +209,7 @@ class TaskMetrics private[spark] () extends Serializable { output.RECORDS_WRITTEN -> outputMetrics._recordsWritten ) ++ testAccum.map(TEST_ACCUM -> _) - @transient private[spark] lazy val internalAccums: Seq[NewAccumulator[_, _]] = + @transient private[spark] lazy val internalAccums: Seq[AccumulatorV2[_, _]] = nameToAccums.values.toIndexedSeq /* ========================== * @@ -217,13 +225,22 @@ class TaskMetrics private[spark] () extends Serializable { /** * External accumulators registered with this task. */ - @transient private lazy val externalAccums = new ArrayBuffer[NewAccumulator[_, _]] + @transient private[spark] lazy val externalAccums = new ArrayBuffer[AccumulatorV2[_, _]] - private[spark] def registerAccumulator(a: NewAccumulator[_, _]): Unit = { + private[spark] def registerAccumulator(a: AccumulatorV2[_, _]): Unit = { externalAccums += a } - private[spark] def accumulators(): Seq[NewAccumulator[_, _]] = internalAccums ++ externalAccums + private[spark] def accumulators(): Seq[AccumulatorV2[_, _]] = internalAccums ++ externalAccums + + /** + * Looks for a registered accumulator by accumulator name. + */ + private[spark] def lookForAccumulatorByName(name: String): Option[AccumulatorV2[_, _]] = { + accumulators.find { acc => + acc.name.isDefined && acc.name.get == name + } + } } @@ -258,7 +275,7 @@ private[spark] object TaskMetrics extends Logging { val name = info.name.get val value = info.update.get if (name == UPDATED_BLOCK_STATUSES) { - tm.setUpdatedBlockStatuses(value.asInstanceOf[Seq[(BlockId, BlockStatus)]]) + tm.setUpdatedBlockStatuses(value.asInstanceOf[java.util.List[(BlockId, BlockStatus)]]) } else { tm.nameToAccums.get(name).foreach( _.asInstanceOf[LongAccumulator].setValue(value.asInstanceOf[Long]) @@ -271,44 +288,18 @@ private[spark] object TaskMetrics extends Logging { /** * Construct a [[TaskMetrics]] object from a list of accumulator updates, called on driver only. 
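The custom `BlockStatusesAccumulator` gives way to a stock `CollectionAccumulator`, whose `value` is a `java.util.List` that the driver converts with `asScala` once updates are final. A small hedged sketch of that pattern with a hypothetical element type standing in for `(BlockId, BlockStatus)`:

```scala
import scala.collection.JavaConverters._

import org.apache.spark.util.CollectionAccumulator

object CollectionAccumulatorSketch {
  def demo(): Seq[(String, Long)] = {
    // Hypothetical (blockId, size) pairs standing in for (BlockId, BlockStatus).
    val updated = new CollectionAccumulator[(String, Long)]
    updated.add(("rdd_0_0", 1024L))
    updated.add(("rdd_0_1", 2048L))
    // Safe on the driver once no more updates can arrive, as the comment above notes.
    updated.value.asScala
  }
}
```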
*/ - def fromAccumulators(accums: Seq[NewAccumulator[_, _]]): TaskMetrics = { + def fromAccumulators(accums: Seq[AccumulatorV2[_, _]]): TaskMetrics = { val tm = new TaskMetrics val (internalAccums, externalAccums) = accums.partition(a => a.name.isDefined && tm.nameToAccums.contains(a.name.get)) internalAccums.foreach { acc => - val tmAcc = tm.nameToAccums(acc.name.get).asInstanceOf[NewAccumulator[Any, Any]] + val tmAcc = tm.nameToAccums(acc.name.get).asInstanceOf[AccumulatorV2[Any, Any]] tmAcc.metadata = acc.metadata - tmAcc.merge(acc.asInstanceOf[NewAccumulator[Any, Any]]) + tmAcc.merge(acc.asInstanceOf[AccumulatorV2[Any, Any]]) } tm.externalAccums ++= externalAccums tm } } - - -private[spark] class BlockStatusesAccumulator - extends NewAccumulator[(BlockId, BlockStatus), Seq[(BlockId, BlockStatus)]] { - private[this] var _seq = ArrayBuffer.empty[(BlockId, BlockStatus)] - - override def isZero(): Boolean = _seq.isEmpty - - override def copyAndReset(): BlockStatusesAccumulator = new BlockStatusesAccumulator - - override def add(v: (BlockId, BlockStatus)): Unit = _seq += v - - override def merge(other: NewAccumulator[(BlockId, BlockStatus), Seq[(BlockId, BlockStatus)]]) - : Unit = other match { - case o: BlockStatusesAccumulator => _seq ++= o.localValue - case _ => throw new UnsupportedOperationException( - s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}") - } - - override def localValue: Seq[(BlockId, BlockStatus)] = _seq - - def setValue(newValue: Seq[(BlockId, BlockStatus)]): Unit = { - _seq.clear() - _seq ++= newValue - } -} diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 2c1e0b71e3613..29f812a2ce116 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -17,6 +17,8 @@ package org.apache.spark.internal +import java.util.concurrent.TimeUnit + import org.apache.spark.launcher.SparkLauncher import org.apache.spark.network.util.ByteUnit @@ -80,6 +82,11 @@ package object config { .doc("Name of the Kerberos principal.") .stringConf.createOptional + private[spark] val TOKEN_RENEWAL_INTERVAL = ConfigBuilder("spark.yarn.token.renewal.interval") + .internal() + .timeConf(TimeUnit.MILLISECONDS) + .createOptional + private[spark] val EXECUTOR_INSTANCES = ConfigBuilder("spark.executor.instances") .intConf .createOptional @@ -96,4 +103,18 @@ package object config { .stringConf .checkValues(Set("hive", "in-memory")) .createWithDefault("in-memory") + + // To limit memory usage, we only track information for a fixed number of tasks + private[spark] val UI_RETAINED_TASKS = ConfigBuilder("spark.ui.retainedTasks") + .intConf + .createWithDefault(100000) + + // To limit how many applications are shown in the History Server summary ui + private[spark] val HISTORY_UI_MAX_APPS = + ConfigBuilder("spark.history.ui.maxApplications").intConf.createWithDefault(Integer.MAX_VALUE) + + private[spark] val LISTENER_BUS_EVENT_QUEUE_SIZE = + ConfigBuilder("spark.scheduler.listenerbus.eventqueue.size") + .intConf + .createWithDefault(10000) } diff --git a/core/src/main/scala/org/apache/spark/memory/ExecutionMemoryPool.scala b/core/src/main/scala/org/apache/spark/memory/ExecutionMemoryPool.scala index f8167074c6dfa..f1915857ea43a 100644 --- a/core/src/main/scala/org/apache/spark/memory/ExecutionMemoryPool.scala +++ b/core/src/main/scala/org/apache/spark/memory/ExecutionMemoryPool.scala @@ -24,7 
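The new entries in `config/package.scala` follow the `ConfigBuilder(key).intConf.createWithDefault(...)` shape. `ConfigBuilder` and the typed `conf.get(entry)` reads are `private[spark]`, so from user code the same keys are read through the public `SparkConf` API; a hedged sketch using the `spark.ui.retainedTasks` entry added above:

```scala
import org.apache.spark.SparkConf

object RetainedTasksSketch {
  // Public-API read of the key defined above as UI_RETAINED_TASKS, using the same
  // default the entry declares via createWithDefault(100000).
  def retainedTasks(conf: SparkConf): Int =
    conf.getInt("spark.ui.retainedTasks", 100000)

  // Usage: retainedTasks(new SparkConf().set("spark.ui.retainedTasks", "50000")) == 50000
}
```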
+24,7 @@ import scala.collection.mutable import org.apache.spark.internal.Logging /** - * Implements policies and bookkeeping for sharing a adjustable-sized pool of memory between tasks. + * Implements policies and bookkeeping for sharing an adjustable-sized pool of memory between tasks. * * Tries to ensure that each task gets a reasonable share of memory, instead of some task ramping up * to a large amount first and then causing others to spill to disk repeatedly. diff --git a/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala b/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala index 0b552cabfc941..4c6b639015a90 100644 --- a/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala +++ b/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala @@ -116,13 +116,13 @@ private[memory] class StorageMemoryPool( } /** - * Try to shrink the size of this storage memory pool by `spaceToFree` bytes. Return the number - * of bytes removed from the pool's capacity. + * Free space to shrink the size of this storage memory pool by `spaceToFree` bytes. + * Note: this method doesn't actually reduce the pool size but relies on the caller to do so. + * + * @return number of bytes to be removed from the pool's capacity. */ - def shrinkPoolToFreeSpace(spaceToFree: Long): Long = lock.synchronized { - // First, shrink the pool by reclaiming free memory: + def freeSpaceToShrinkPool(spaceToFree: Long): Long = lock.synchronized { val spaceFreedByReleasingUnusedMemory = math.min(spaceToFree, memoryFree) - decrementPoolSize(spaceFreedByReleasingUnusedMemory) val remainingSpaceToFree = spaceToFree - spaceFreedByReleasingUnusedMemory if (remainingSpaceToFree > 0) { // If reclaiming free memory did not adequately shrink the pool, begin evicting blocks: @@ -130,7 +130,6 @@ private[memory] class StorageMemoryPool( memoryStore.evictBlocksToFreeSpace(None, remainingSpaceToFree, memoryMode) // When a block is released, BlockManager.dropFromMemory() calls releaseMemory(), so we do // not need to decrement _memoryUsed here. However, we do need to decrement the pool size. - decrementPoolSize(spaceFreedByEviction) spaceFreedByReleasingUnusedMemory + spaceFreedByEviction } else { spaceFreedByReleasingUnusedMemory diff --git a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala index 82023b533d660..c7b36be6027a5 100644 --- a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala @@ -25,9 +25,9 @@ import org.apache.spark.storage.BlockId * either side can borrow memory from the other. * * The region shared between execution and storage is a fraction of (the total heap space - 300MB) - * configurable through `spark.memory.fraction` (default 0.75). The position of the boundary + * configurable through `spark.memory.fraction` (default 0.6). The position of the boundary * within this space is further determined by `spark.memory.storageFraction` (default 0.5). - * This means the size of the storage region is 0.75 * 0.5 = 0.375 of the heap space by default. + * This means the size of the storage region is 0.6 * 0.5 = 0.3 of the heap space by default. * * Storage can borrow as much execution memory as is free until execution reclaims its space. 
* When this happens, cached blocks will be evicted from memory until sufficient borrowed @@ -113,9 +113,10 @@ private[spark] class UnifiedMemoryManager private[memory] ( storagePool.poolSize - storageRegionSize) if (memoryReclaimableFromStorage > 0) { // Only reclaim as much space as is necessary and available: - val spaceReclaimed = storagePool.shrinkPoolToFreeSpace( + val spaceToReclaim = storagePool.freeSpaceToShrinkPool( math.min(extraMemoryNeeded, memoryReclaimableFromStorage)) - executionPool.incrementPoolSize(spaceReclaimed) + storagePool.decrementPoolSize(spaceToReclaim) + executionPool.incrementPoolSize(spaceToReclaim) } } } @@ -186,7 +187,7 @@ object UnifiedMemoryManager { // Set aside a fixed amount of memory for non-storage, non-execution purposes. // This serves a function similar to `spark.memory.fraction`, but guarantees that we reserve // sufficient memory for the system even for small heaps. E.g. if we have a 1GB JVM, then - // the memory used for execution and storage will be (1024 - 300) * 0.75 = 543MB by default. + // the memory used for execution and storage will be (1024 - 300) * 0.6 = 434MB by default. private val RESERVED_SYSTEM_MEMORY_BYTES = 300 * 1024 * 1024 def apply(conf: SparkConf, numCores: Int): UnifiedMemoryManager = { @@ -222,7 +223,7 @@ object UnifiedMemoryManager { } } val usableMemory = systemMemory - reservedMemory - val memoryFraction = conf.getDouble("spark.memory.fraction", 0.75) + val memoryFraction = conf.getDouble("spark.memory.fraction", 0.6) (usableMemory * memoryFraction).toLong } } diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala index 0fed991049dd3..9b16c116ae5ae 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala @@ -28,7 +28,7 @@ import org.eclipse.jetty.servlet.ServletContextHandler import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.internal.Logging import org.apache.spark.metrics.sink.{MetricsServlet, Sink} -import org.apache.spark.metrics.source.Source +import org.apache.spark.metrics.source.{Source, StaticSources} import org.apache.spark.util.Utils /** @@ -96,6 +96,7 @@ private[spark] class MetricsSystem private ( def start() { require(!running, "Attempting to start a MetricsSystem that is already running") running = true + StaticSources.allSources.foreach(registerSource) registerSources() registerSinks() sinks.foreach(_.start) diff --git a/core/src/main/scala/org/apache/spark/metrics/source/StaticSources.scala b/core/src/main/scala/org/apache/spark/metrics/source/StaticSources.scala new file mode 100644 index 0000000000000..6bba259acc391 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/metrics/source/StaticSources.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
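With `spark.memory.fraction` dropping from 0.75 to 0.6, the documented numbers work out as follows; a small arithmetic sketch reproducing the 434MB figure from the comment above:

```scala
object MemoryDefaultsSketch {
  // Defaults after this change; RESERVED_SYSTEM_MEMORY_BYTES stays at 300MB.
  val heapMb = 1024L
  val reservedMb = 300L
  val memoryFraction = 0.6      // spark.memory.fraction
  val storageFraction = 0.5     // spark.memory.storageFraction

  val usableMb = heapMb - reservedMb                           // 724
  val unifiedMb = (usableMb * memoryFraction).toLong           // 434, as in the comment above
  val storageRegionMb = (unifiedMb * storageFraction).toLong   // 217, i.e. 0.6 * 0.5 = 0.3 of the usable space
}
```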
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.metrics.source + +import com.codahale.metrics.MetricRegistry + +import org.apache.spark.annotation.Experimental + +private[spark] object StaticSources { + /** + * The set of all static sources. These sources may be reported to from any class, including + * static classes, without requiring reference to a SparkEnv. + */ + val allSources = Seq(CodegenMetrics) +} + +/** + * :: Experimental :: + * Metrics for code generation. + */ +@Experimental +object CodegenMetrics extends Source { + override val sourceName: String = "CodeGenerator" + override val metricRegistry: MetricRegistry = new MetricRegistry() + + /** + * Histogram of the length of source code text compiled by CodeGenerator (in characters). + */ + val METRIC_SOURCE_CODE_SIZE = metricRegistry.histogram(MetricRegistry.name("sourceCodeSize")) + + /** + * Histogram of the time it took to compile source code text (in milliseconds). + */ + val METRIC_COMPILATION_TIME = metricRegistry.histogram(MetricRegistry.name("compilationTime")) + + /** + * Histogram of the bytecode size of each class generated by CodeGenerator. + */ + val METRIC_GENERATED_CLASS_BYTECODE_SIZE = + metricRegistry.histogram(MetricRegistry.name("generatedClassSize")) + + /** + * Histogram of the bytecode size of each method in classes generated by CodeGenerator. + */ + val METRIC_GENERATED_METHOD_BYTECODE_SIZE = + metricRegistry.histogram(MetricRegistry.name("generatedMethodSize")) +} diff --git a/core/src/main/scala/org/apache/spark/package.scala b/core/src/main/scala/org/apache/spark/package.scala index cc5e7ef3ae008..2610d6f6e45a2 100644 --- a/core/src/main/scala/org/apache/spark/package.scala +++ b/core/src/main/scala/org/apache/spark/package.scala @@ -41,7 +41,58 @@ package org.apache * level interfaces. These are subject to changes or removal in minor releases. */ +import java.util.Properties + package object spark { - // For package docs only - val SPARK_VERSION = "2.0.0-SNAPSHOT" + + private object SparkBuildInfo { + + val ( + spark_version: String, + spark_branch: String, + spark_revision: String, + spark_build_user: String, + spark_repo_url: String, + spark_build_date: String) = { + + val resourceStream = Thread.currentThread().getContextClassLoader. 
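Because `CodegenMetrics` is a static `Source` registered by `MetricsSystem.start()`, any class can record into its histograms without holding a `SparkEnv`. A hedged sketch of how a caller would feed the new histograms (the call site is illustrative, not part of this patch):

```scala
import org.apache.spark.metrics.source.CodegenMetrics

object CodegenMetricsSketch {
  // Record one compilation: source size in characters, compile time in milliseconds.
  def recordCompilation(sourceText: String, compileTimeMs: Long): Unit = {
    CodegenMetrics.METRIC_SOURCE_CODE_SIZE.update(sourceText.length)
    CodegenMetrics.METRIC_COMPILATION_TIME.update(compileTimeMs)
  }
}
```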
+ getResourceAsStream("spark-version-info.properties") + + try { + val unknownProp = "" + val props = new Properties() + props.load(resourceStream) + ( + props.getProperty("version", unknownProp), + props.getProperty("branch", unknownProp), + props.getProperty("revision", unknownProp), + props.getProperty("user", unknownProp), + props.getProperty("url", unknownProp), + props.getProperty("date", unknownProp) + ) + } catch { + case npe: NullPointerException => + throw new SparkException("Error while locating file spark-version-info.properties", npe) + case e: Exception => + throw new SparkException("Error loading properties from spark-version-info.properties", e) + } finally { + if (resourceStream != null) { + try { + resourceStream.close() + } catch { + case e: Exception => + throw new SparkException("Error closing spark build info resource stream", e) + } + } + } + } + } + + val SPARK_VERSION = SparkBuildInfo.spark_version + val SPARK_BRANCH = SparkBuildInfo.spark_branch + val SPARK_REVISION = SparkBuildInfo.spark_revision + val SPARK_BUILD_USER = SparkBuildInfo.spark_build_user + val SPARK_REPO_URL = SparkBuildInfo.spark_repo_url + val SPARK_BUILD_DATE = SparkBuildInfo.spark_build_date } + diff --git a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala index be0cb175f5340..41832e8354741 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.rdd -import org.apache.hadoop.conf.{ Configurable, Configuration } +import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.task.JobContextImpl diff --git a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala index 63d1d1767a8cb..d47b75544fdba 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala @@ -44,7 +44,7 @@ class BlockRDD[T: ClassTag](sc: SparkContext, @transient val blockIds: Array[Blo assertValid() val blockManager = SparkEnv.get.blockManager val blockId = split.asInstanceOf[BlockRDDPartition].blockId - blockManager.get(blockId) match { + blockManager.get[T](blockId) match { case Some(block) => block.data.asInstanceOf[Iterator[T]] case None => throw new Exception("Could not compute split, block " + blockId + " not found") diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala index e75f1dbf8107a..2ec9846e33f5a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala @@ -174,37 +174,37 @@ private class DefaultPartitionCoalescer(val balanceSlack: Double = 0.10) prev.context.getPreferredLocs(prev, part.index).map(tl => tl.host) } - // this class just keeps iterating and rotating infinitely over the partitions of the RDD - // next() returns the next preferred machine that a partition is replicated on - // the rotator first goes through the first replica copy of each partition, then second, third - // the iterators return type is a tuple: (replicaString, partition) - class LocationIterator(prev: RDD[_]) extends Iterator[(String, Partition)] { - - var it: Iterator[(String, Partition)] = resetIterator() - - override val isEmpty = !it.hasNext - - 
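Once the values are loaded from `spark-version-info.properties`, the extra fields are plain package-level vals next to `SPARK_VERSION`. A tiny hedged usage sketch:

```scala
import org.apache.spark.{SPARK_BRANCH, SPARK_BUILD_DATE, SPARK_REVISION, SPARK_VERSION}

object BuildInfoSketch {
  // Formats the build metadata introduced above into a single banner line.
  def banner: String =
    s"Spark $SPARK_VERSION (branch $SPARK_BRANCH, revision $SPARK_REVISION, built $SPARK_BUILD_DATE)"
}
```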
// initializes/resets to start iterating from the beginning - def resetIterator(): Iterator[(String, Partition)] = { - val iterators = (0 to 2).map { x => - prev.partitions.iterator.flatMap { p => - if (currPrefLocs(p, prev).size > x) Some((currPrefLocs(p, prev)(x), p)) else None + class PartitionLocations(prev: RDD[_]) { + + // contains all the partitions from the previous RDD that don't have preferred locations + val partsWithoutLocs = ArrayBuffer[Partition]() + // contains all the partitions from the previous RDD that have preferred locations + val partsWithLocs = ArrayBuffer[(String, Partition)]() + + getAllPrefLocs(prev) + + // gets all the preffered locations of the previous RDD and splits them into partitions + // with preferred locations and ones without + def getAllPrefLocs(prev: RDD[_]) { + val tmpPartsWithLocs = mutable.LinkedHashMap[Partition, Seq[String]]() + // first get the locations for each partition, only do this once since it can be expensive + prev.partitions.foreach(p => { + val locs = currPrefLocs(p, prev) + if (locs.size > 0) { + tmpPartsWithLocs.put(p, locs) + } else { + partsWithoutLocs += p + } } - } - iterators.reduceLeft((x, y) => x ++ y) - } - - // hasNext() is false iff there are no preferredLocations for any of the partitions of the RDD - override def hasNext: Boolean = { !isEmpty } - - // return the next preferredLocation of some partition of the RDD - override def next(): (String, Partition) = { - if (it.hasNext) { - it.next() - } else { - it = resetIterator() // ran out of preferred locations, reset and rotate to the beginning - it.next() - } + ) + // convert it into an array of host to partition + (0 to 2).map(x => + tmpPartsWithLocs.foreach(parts => { + val p = parts._1 + val locs = parts._2 + if (locs.size > x) partsWithLocs += ((locs(x), p)) + } ) + ) } } @@ -228,33 +228,32 @@ private class DefaultPartitionCoalescer(val balanceSlack: Double = 0.10) } /** - * Initializes targetLen partition groups and assigns a preferredLocation - * This uses coupon collector to estimate how many preferredLocations it must rotate through - * until it has seen most of the preferred locations (2 * n log(n)) + * Initializes targetLen partition groups. If there are preferred locations, each group + * is assigned a preferredLocation. This uses coupon collector to estimate how many + * preferredLocations it must rotate through until it has seen most of the preferred + * locations (2 * n log(n)) * @param targetLen */ - def setupGroups(targetLen: Int, prev: RDD[_]) { - val rotIt = new LocationIterator(prev) - + def setupGroups(targetLen: Int, partitionLocs: PartitionLocations) { // deal with empty case, just create targetLen partition groups with no preferred location - if (!rotIt.hasNext) { + if (partitionLocs.partsWithLocs.isEmpty) { (1 to targetLen).foreach(x => groupArr += new PartitionGroup()) return } noLocality = false - // number of iterations needed to be certain that we've seen most preferred locations val expectedCoupons2 = 2 * (math.log(targetLen)*targetLen + targetLen + 0.5).toInt var numCreated = 0 var tries = 0 // rotate through until either targetLen unique/distinct preferred locations have been created - // OR we've rotated expectedCoupons2, in which case we have likely seen all preferred locations, - // i.e. 
likely targetLen >> number of preferred locations (more buckets than there are machines) - while (numCreated < targetLen && tries < expectedCoupons2) { + // OR (we have went through either all partitions OR we've rotated expectedCoupons2 - in + // which case we have likely seen all preferred locations) + val numPartsToLookAt = math.min(expectedCoupons2, partitionLocs.partsWithLocs.length) + while (numCreated < targetLen && tries < numPartsToLookAt) { + val (nxt_replica, nxt_part) = partitionLocs.partsWithLocs(tries) tries += 1 - val (nxt_replica, nxt_part) = rotIt.next() if (!groupHash.contains(nxt_replica)) { val pgroup = new PartitionGroup(Some(nxt_replica)) groupArr += pgroup @@ -263,20 +262,18 @@ private class DefaultPartitionCoalescer(val balanceSlack: Double = 0.10) numCreated += 1 } } - - while (numCreated < targetLen) { // if we don't have enough partition groups, create duplicates - var (nxt_replica, nxt_part) = rotIt.next() + tries = 0 + // if we don't have enough partition groups, create duplicates + while (numCreated < targetLen) { + var (nxt_replica, nxt_part) = partitionLocs.partsWithLocs(tries) + tries += 1 val pgroup = new PartitionGroup(Some(nxt_replica)) groupArr += pgroup groupHash.getOrElseUpdate(nxt_replica, ArrayBuffer()) += pgroup - var tries = 0 - while (!addPartToPGroup(nxt_part, pgroup) && tries < targetLen) { // ensure at least one part - nxt_part = rotIt.next()._2 - tries += 1 - } + addPartToPGroup(nxt_part, pgroup) numCreated += 1 + if (tries >= partitionLocs.partsWithLocs.length) tries = 0 } - } /** @@ -289,7 +286,11 @@ private class DefaultPartitionCoalescer(val balanceSlack: Double = 0.10) * imbalance in favor of locality * @return partition group (bin to be put in) */ - def pickBin(p: Partition, prev: RDD[_], balanceSlack: Double): PartitionGroup = { + def pickBin( + p: Partition, + prev: RDD[_], + balanceSlack: Double, + partitionLocs: PartitionLocations): PartitionGroup = { val slack = (balanceSlack * prev.partitions.length).toInt // least loaded pref locs val pref = currPrefLocs(p, prev).map(getLeastGroupHash(_)).sortWith(compare) @@ -320,7 +321,10 @@ private class DefaultPartitionCoalescer(val balanceSlack: Double = 0.10) } } - def throwBalls(maxPartitions: Int, prev: RDD[_], balanceSlack: Double) { + def throwBalls( + maxPartitions: Int, + prev: RDD[_], + balanceSlack: Double, partitionLocs: PartitionLocations) { if (noLocality) { // no preferredLocations in parent RDD, no randomization needed if (maxPartitions > groupArr.size) { // just return prev.partitions for ((p, i) <- prev.partitions.zipWithIndex) { @@ -334,8 +338,39 @@ private class DefaultPartitionCoalescer(val balanceSlack: Double = 0.10) } } } else { + // It is possible to have unionRDD where one rdd has preferred locations and another rdd + // that doesn't. To make sure we end up with the requested number of partitions, + // make sure to put a partition in every group. 
+ + // if we don't have a partition assigned to every group first try to fill them + // with the partitions with preferred locations + val partIter = partitionLocs.partsWithLocs.iterator + groupArr.filter(pg => pg.numPartitions == 0).foreach { pg => + while (partIter.hasNext && pg.numPartitions == 0) { + var (nxt_replica, nxt_part) = partIter.next() + if (!initialHash.contains(nxt_part)) { + pg.partitions += nxt_part + initialHash += nxt_part + } + } + } + + // if we didn't get one partitions per group from partitions with preferred locations + // use partitions without preferred locations + val partNoLocIter = partitionLocs.partsWithoutLocs.iterator + groupArr.filter(pg => pg.numPartitions == 0).foreach { pg => + while (partNoLocIter.hasNext && pg.numPartitions == 0) { + var nxt_part = partNoLocIter.next() + if (!initialHash.contains(nxt_part)) { + pg.partitions += nxt_part + initialHash += nxt_part + } + } + } + + // finally pick bin for the rest for (p <- prev.partitions if (!initialHash.contains(p))) { // throw every partition into group - pickBin(p, prev, balanceSlack).partitions += p + pickBin(p, prev, balanceSlack, partitionLocs).partitions += p } } } @@ -349,8 +384,11 @@ private class DefaultPartitionCoalescer(val balanceSlack: Double = 0.10) * @return array of partition groups */ def coalesce(maxPartitions: Int, prev: RDD[_]): Array[PartitionGroup] = { - setupGroups(math.min(prev.partitions.length, maxPartitions), prev) // setup the groups (bins) - throwBalls(maxPartitions, prev, balanceSlack) // assign partitions (balls) to each group (bins) + val partitionLocs = new PartitionLocations(prev) + // setup the groups (bins) + setupGroups(math.min(prev.partitions.length, maxPartitions), partitionLocs) + // assign partitions (balls) to each group (bins) + throwBalls(maxPartitions, prev, balanceSlack, partitionLocs) getPartitions } } diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index b22134af45b30..297d95a731013 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -22,7 +22,6 @@ import java.text.SimpleDateFormat import java.util.Date import scala.collection.immutable.Map -import scala.collection.mutable.ListBuffer import scala.reflect.ClassTag import org.apache.hadoop.conf.{Configurable, Configuration} @@ -43,7 +42,6 @@ import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.executor.DataReadMethod import org.apache.spark.internal.Logging import org.apache.spark.rdd.HadoopRDD.HadoopMapPartitionsWithSplitRDD import org.apache.spark.scheduler.{HDFSCacheTaskLocation, HostTaskLocation} @@ -70,7 +68,7 @@ private[spark] class HadoopPartition(rddId: Int, override val index: Int, s: Inp val envVars: Map[String, String] = if (inputSplit.value.isInstanceOf[FileSplit]) { val is: FileSplit = inputSplit.value.asInstanceOf[FileSplit] // map_input_file is deprecated in favor of mapreduce_map_input_file but set both - // since its not removed yet + // since it's not removed yet Map("map_input_file" -> is.getPath().toString(), "mapreduce_map_input_file" -> is.getPath().toString()) } else { @@ -318,7 +316,7 @@ class HadoopRDD[K, V]( try { val lsplit = c.inputSplitWithLocationInfo.cast(hsplit) val infos = c.getLocationInfo.invoke(lsplit).asInstanceOf[Array[AnyRef]] - 
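The comment above explains the case this rewrite has to handle: a union of an RDD with preferred locations and one without must still coalesce down to the requested number of partitions. A hedged usage sketch of that scenario (the input path is a placeholder):

```scala
import org.apache.spark.SparkContext

object CoalesceUnionSketch {
  // `path` is hypothetical; any input whose splits carry block locations behaves the same.
  def coalescedPartitionCount(sc: SparkContext, path: String): Int = {
    val withLocs = sc.textFile(path, minPartitions = 8)              // partitions carry locations
    val withoutLocs = sc.parallelize(1 to 100, numSlices = 8).map(_.toString)
    (withLocs ++ withoutLocs).coalesce(4).getNumPartitions           // expected: 4
  }
}
```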
Some(HadoopRDD.convertSplitLocationInfo(infos)) + HadoopRDD.convertSplitLocationInfo(infos) } catch { case e: Exception => logDebug("Failed to use InputSplitWithLocations.", e) @@ -335,7 +333,7 @@ class HadoopRDD[K, V]( override def persist(storageLevel: StorageLevel): this.type = { if (storageLevel.deserialized) { - logWarning("Caching NewHadoopRDDs as deserialized objects usually leads to undesired" + + logWarning("Caching HadoopRDDs as deserialized objects usually leads to undesired" + " behavior because Hadoop's RecordReader reuses the same Writable object for all records." + " Use a map transformation to make copies of the records.") } @@ -420,21 +418,20 @@ private[spark] object HadoopRDD extends Logging { None } - private[spark] def convertSplitLocationInfo(infos: Array[AnyRef]): Seq[String] = { - val out = ListBuffer[String]() - infos.foreach { loc => - val locationStr = HadoopRDD.SPLIT_INFO_REFLECTIONS.get. - getLocation.invoke(loc).asInstanceOf[String] + private[spark] def convertSplitLocationInfo(infos: Array[AnyRef]): Option[Seq[String]] = { + Option(infos).map(_.flatMap { loc => + val reflections = HadoopRDD.SPLIT_INFO_REFLECTIONS.get + val locationStr = reflections.getLocation.invoke(loc).asInstanceOf[String] if (locationStr != "localhost") { - if (HadoopRDD.SPLIT_INFO_REFLECTIONS.get.isInMemory. - invoke(loc).asInstanceOf[Boolean]) { - logDebug("Partition " + locationStr + " is cached by Hadoop.") - out += new HDFSCacheTaskLocation(locationStr).toString + if (reflections.isInMemory.invoke(loc).asInstanceOf[Boolean]) { + logDebug(s"Partition $locationStr is cached by Hadoop.") + Some(HDFSCacheTaskLocation(locationStr).toString) } else { - out += new HostTaskLocation(locationStr).toString + Some(HostTaskLocation(locationStr).toString) } + } else { + None } - } - out.seq + }) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/InputFileNameHolder.scala b/core/src/main/scala/org/apache/spark/rdd/InputFileNameHolder.scala index 108e9d2558190..f40d4c8e0a4d0 100644 --- a/core/src/main/scala/org/apache/spark/rdd/InputFileNameHolder.scala +++ b/core/src/main/scala/org/apache/spark/rdd/InputFileNameHolder.scala @@ -21,7 +21,7 @@ import org.apache.spark.unsafe.types.UTF8String /** * This holds file names of the current Spark task. This is used in HadoopRDD, - * FileScanRDD and InputFileName function in Spark SQL. + * FileScanRDD, NewHadoopRDD and InputFileName function in Spark SQL. */ private[spark] object InputFileNameHolder { /** diff --git a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala index 5426bf80bafc5..2f42916439d29 100644 --- a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala @@ -34,7 +34,7 @@ private[spark] class JdbcPartition(idx: Int, val lower: Long, val upper: Long) e // TODO: Expose a jdbcRDD function in SparkContext and mark this as semi-private /** - * An RDD that executes an SQL query on a JDBC connection and reads results. + * An RDD that executes a SQL query on a JDBC connection and reads results. * For usage example, see test case JdbcRDDSuite. * * @param getConnection a function that returns an open Connection. @@ -138,7 +138,7 @@ object JdbcRDD { } /** - * Create an RDD that executes an SQL query on a JDBC connection and reads results. + * Create an RDD that executes a SQL query on a JDBC connection and reads results. * For usage example, see test case JavaAPISuite.testJavaJdbcRDD. 
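`convertSplitLocationInfo` now returns `Option[Seq[String]]` and wraps the possibly-null reflection result in `Option(...)` instead of building a `ListBuffer`. A hedged sketch of the same null-safe flat-mapping pattern on its own:

```scala
object NullSafeLocationsSketch {
  // A null array yields None; "localhost" entries are dropped rather than reported.
  def hosts(raw: Array[AnyRef]): Option[Seq[String]] =
    Option(raw).map(_.flatMap {
      case host: String if host != "localhost" => Some(host)
      case _ => None
    }.toSeq)
}
```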
* * @param connectionFactory a factory that returns an open Connection. @@ -178,7 +178,7 @@ object JdbcRDD { } /** - * Create an RDD that executes an SQL query on a JDBC connection and reads results. Each row is + * Create an RDD that executes a SQL query on a JDBC connection and reads results. Each row is * converted into a `Object` array. For usage example, see test case JavaAPISuite.testJavaJdbcRDD. * * @param connectionFactory a factory that returns an open Connection. diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index ad7c2216a042f..2d9d69dfb8fba 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -32,7 +32,6 @@ import org.apache.hadoop.mapreduce.task.{JobContextImpl, TaskAttemptContextImpl} import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.executor.DataReadMethod import org.apache.spark.internal.Logging import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD import org.apache.spark.storage.StorageLevel @@ -136,6 +135,12 @@ class NewHadoopRDD[K, V]( val inputMetrics = context.taskMetrics().inputMetrics val existingBytesRead = inputMetrics.bytesRead + // Sets the thread local variable for the file's name + split.serializableHadoopSplit.value match { + case fs: FileSplit => InputFileNameHolder.setInputFileName(fs.getPath.toString) + case _ => InputFileNameHolder.unsetInputFileName() + } + // Find a function that will return the FileSystem bytes read by this thread. Do this before // creating RecordReader, because RecordReader's constructor might read some bytes val getBytesReadCallback: Option[() => Long] = split.serializableHadoopSplit.value match { @@ -202,6 +207,7 @@ class NewHadoopRDD[K, V]( private def close() { if (reader != null) { + InputFileNameHolder.unsetInputFileName() // Close the reader and release it. Note: it's very important that we don't close the // reader more than once, since that exposes us to MAPREDUCE-5918 when running against // Hadoop 1.x and older Hadoop 2.x releases. 
That bug can lead to non-deterministic @@ -249,7 +255,7 @@ class NewHadoopRDD[K, V]( case Some(c) => try { val infos = c.newGetLocationInfo.invoke(split).asInstanceOf[Array[AnyRef]] - Some(HadoopRDD.convertSplitLocationInfo(infos)) + HadoopRDD.convertSplitLocationInfo(infos) } catch { case e : Exception => logDebug("Failed to use InputSplit#getLocationInfo.", e) diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 7936d8e1d45a2..104e0cb37155f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -40,7 +40,7 @@ import org.apache.spark._ import org.apache.spark.Partitioner.defaultPartitioner import org.apache.spark.annotation.Experimental import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.executor.{DataWriteMethod, OutputMetrics} +import org.apache.spark.executor.OutputMetrics import org.apache.spark.internal.Logging import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.serializer.Serializer @@ -375,6 +375,16 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) /** * Approximate version of countByKey that can return a partial result if it does * not finish within a timeout. + * + * The confidence is the probability that the error bounds of the result will + * contain the true value. That is, if countApprox were called repeatedly + * with confidence 0.9, we would expect 90% of the results to contain the + * true count. The confidence must be in the range [0,1] or an exception will + * be thrown. + * + * @param timeout maximum time to wait for the job, in milliseconds + * @param confidence the desired statistical confidence in the result + * @return a potentially incomplete result, with error bounds */ def countByKeyApprox(timeout: Long, confidence: Double = 0.95) : PartialResult[Map[K, BoundedDouble]] = self.withScope { @@ -1044,7 +1054,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val warningMessage = s"$outputCommitterClass may be an output committer that writes data directly to " + "the final location. Because speculation is enabled, this output committer may " + - "cause data loss (see the case in SPARK-10063). If possible, please use a output " + + "cause data loss (see the case in SPARK-10063). If possible, please use an output " + "committer that does not have this behavior (e.g. FileOutputCommitter)." logWarning(warningMessage) } @@ -1132,7 +1142,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val warningMessage = s"$outputCommitterClass may be an output committer that writes data directly to " + "the final location. Because speculation is enabled, this output committer may " + - "cause data loss (see the case in SPARK-10063). If possible, please use a output " + + "cause data loss (see the case in SPARK-10063). If possible, please use an output " + "committer that does not have this behavior (e.g. FileOutputCommitter)." 
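The new scaladoc above spells out the `timeout`/`confidence` semantics for `countByKeyApprox`. A hedged usage sketch:

```scala
import org.apache.spark.rdd.RDD

object ApproxCountSketch {
  // With confidence 0.9, roughly 90% of invocations should produce bounds that
  // contain the true per-key count; the call returns once the timeout elapses.
  def printApproxCounts(pairs: RDD[(String, Int)]): Unit = {
    val partial = pairs.countByKeyApprox(timeout = 1000L, confidence = 0.9)
    partial.initialValue.foreach { case (key, bound) =>
      println(s"$key: mean=${bound.mean} in [${bound.low}, ${bound.high}]")
    }
  }
}
```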
logWarning(warningMessage) } diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala index 34a1c112cbcd0..e9092739b298a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala @@ -32,8 +32,8 @@ import org.apache.spark.util.Utils private[spark] class ParallelCollectionPartition[T: ClassTag]( var rddId: Long, var slice: Int, - var values: Seq[T]) - extends Partition with Serializable { + var values: Seq[T] + ) extends Partition with Serializable { def iterator: Iterator[T] = values.iterator diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala index b6366f3e68df9..d744d67592545 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala @@ -60,7 +60,7 @@ class PartitionerAwareUnionRDD[T: ClassTag]( sc: SparkContext, var rdds: Seq[RDD[T]] ) extends RDD[T](sc, rdds.map(x => new OneToOneDependency(x))) { - require(rdds.length > 0) + require(rdds.nonEmpty) require(rdds.forall(_.partitioner.isDefined)) require(rdds.flatMap(_.partitioner).toSet.size == 1, "Parent RDDs have different partitioners: " + rdds.flatMap(_.partitioner)) diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala index dd8e46ba0f122..02b28b72fb0e7 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala @@ -17,9 +17,11 @@ package org.apache.spark.rdd +import java.io.BufferedWriter import java.io.File import java.io.FilenameFilter import java.io.IOException +import java.io.OutputStreamWriter import java.io.PrintWriter import java.util.StringTokenizer import java.util.concurrent.atomic.AtomicReference @@ -29,7 +31,6 @@ import scala.collection.Map import scala.collection.mutable.ArrayBuffer import scala.io.Source import scala.reflect.ClassTag -import scala.util.control.NonFatal import org.apache.spark.{Partition, SparkEnv, TaskContext} import org.apache.spark.util.Utils @@ -45,22 +46,11 @@ private[spark] class PipedRDD[T: ClassTag]( envVars: Map[String, String], printPipeContext: (String => Unit) => Unit, printRDDElement: (T, String => Unit) => Unit, - separateWorkingDir: Boolean) + separateWorkingDir: Boolean, + bufferSize: Int, + encoding: String) extends RDD[String](prev) { - // Similar to Runtime.exec(), if we are given a single string, split it into words - // using a standard StringTokenizer (i.e. 
by spaces) - def this( - prev: RDD[T], - command: String, - envVars: Map[String, String] = Map(), - printPipeContext: (String => Unit) => Unit = null, - printRDDElement: (T, String => Unit) => Unit = null, - separateWorkingDir: Boolean = false) = - this(prev, PipedRDD.tokenize(command), envVars, printPipeContext, printRDDElement, - separateWorkingDir) - - override def getPartitions: Array[Partition] = firstParent[T].partitions /** @@ -127,7 +117,7 @@ private[spark] class PipedRDD[T: ClassTag]( override def run(): Unit = { val err = proc.getErrorStream try { - for (line <- Source.fromInputStream(err).getLines) { + for (line <- Source.fromInputStream(err)(encoding).getLines) { // scalastyle:off println System.err.println(line) // scalastyle:on println @@ -144,7 +134,8 @@ private[spark] class PipedRDD[T: ClassTag]( new Thread(s"stdin writer for $command") { override def run(): Unit = { TaskContext.setTaskContext(context) - val out = new PrintWriter(proc.getOutputStream) + val out = new PrintWriter(new BufferedWriter( + new OutputStreamWriter(proc.getOutputStream, encoding), bufferSize)) try { // scalastyle:off println // input the pipe context firstly @@ -168,7 +159,7 @@ private[spark] class PipedRDD[T: ClassTag]( }.start() // Return an iterator that read lines from the process's stdout - val lines = Source.fromInputStream(proc.getInputStream).getLines() + val lines = Source.fromInputStream(proc.getInputStream)(encoding).getLines new Iterator[String] { def next(): String = { if (!hasNext()) { diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 499a8b9aa1a89..7013396392ae8 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -21,6 +21,7 @@ import java.util.Random import scala.collection.{mutable, Map} import scala.collection.mutable.ArrayBuffer +import scala.io.Codec import scala.language.implicitConversions import scala.reflect.{classTag, ClassTag} @@ -69,7 +70,7 @@ import org.apache.spark.util.random.{BernoulliCellSampler, BernoulliSampler, Poi * All of the scheduling and execution in Spark is done based on these methods, allowing each RDD * to implement its own way of computing itself. Indeed, users can implement custom RDDs (e.g. for * reading data from a new storage system) by overriding these functions. Please refer to the - * [[http://www.cs.berkeley.edu/~matei/papers/2012/nsdi_spark.pdf Spark paper]] for more details + * [[http://people.csail.mit.edu/matei/papers/2012/nsdi_spark.pdf Spark paper]] for more details * on RDD internals. */ abstract class RDD[T: ClassTag]( @@ -437,6 +438,7 @@ abstract class RDD[T: ClassTag]( partitionCoalescer: Option[PartitionCoalescer] = Option.empty) (implicit ord: Ordering[T] = null) : RDD[T] = withScope { + require(numPartitions > 0, s"Number of partitions ($numPartitions) must be positive.") if (shuffle) { /** Distributes elements evenly across output partitions, starting from a random partition. 
*/ val distributePartition = (index: Int, items: Iterator[T]) => { @@ -472,12 +474,17 @@ abstract class RDD[T: ClassTag]( def sample( withReplacement: Boolean, fraction: Double, - seed: Long = Utils.random.nextLong): RDD[T] = withScope { - require(fraction >= 0.0, "Negative fraction value: " + fraction) - if (withReplacement) { - new PartitionwiseSampledRDD[T, T](this, new PoissonSampler[T](fraction), true, seed) - } else { - new PartitionwiseSampledRDD[T, T](this, new BernoulliSampler[T](fraction), true, seed) + seed: Long = Utils.random.nextLong): RDD[T] = { + require(fraction >= 0, + s"Fraction must be nonnegative, but got ${fraction}") + + withScope { + require(fraction >= 0.0, "Negative fraction value: " + fraction) + if (withReplacement) { + new PartitionwiseSampledRDD[T, T](this, new PoissonSampler[T](fraction), true, seed) + } else { + new PartitionwiseSampledRDD[T, T](this, new BernoulliSampler[T](fraction), true, seed) + } } } @@ -491,14 +498,22 @@ abstract class RDD[T: ClassTag]( */ def randomSplit( weights: Array[Double], - seed: Long = Utils.random.nextLong): Array[RDD[T]] = withScope { - val sum = weights.sum - val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _) - normalizedCumWeights.sliding(2).map { x => - randomSampleWithRange(x(0), x(1), seed) - }.toArray + seed: Long = Utils.random.nextLong): Array[RDD[T]] = { + require(weights.forall(_ >= 0), + s"Weights must be nonnegative, but got ${weights.mkString("[", ",", "]")}") + require(weights.sum > 0, + s"Sum of weights must be positive, but got ${weights.mkString("[", ",", "]")}") + + withScope { + val sum = weights.sum + val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _) + normalizedCumWeights.sliding(2).map { x => + randomSampleWithRange(x(0), x(1), seed) + }.toArray + } } + /** * Internal method exposed for Random Splits in DataFrames. Samples an RDD given a probability * range. @@ -697,18 +712,28 @@ abstract class RDD[T: ClassTag]( * Return an RDD created by piping elements to a forked external process. */ def pipe(command: String): RDD[String] = withScope { - new PipedRDD(this, command) + // Similar to Runtime.exec(), if we are given a single string, split it into words + // using a standard StringTokenizer (i.e. by spaces) + pipe(PipedRDD.tokenize(command)) } /** * Return an RDD created by piping elements to a forked external process. */ def pipe(command: String, env: Map[String, String]): RDD[String] = withScope { - new PipedRDD(this, command, env) + // Similar to Runtime.exec(), if we are given a single string, split it into words + // using a standard StringTokenizer (i.e. by spaces) + pipe(PipedRDD.tokenize(command), env) } /** - * Return an RDD created by piping elements to a forked external process. + * Return an RDD created by piping elements to a forked external process. The resulting RDD + * is computed by executing the given process once per partition. All elements + * of each input partition are written to a process's stdin as lines of input separated + * by a newline. The resulting partition consists of the process's stdout output, with + * each line of stdout resulting in one element of the output partition. A process is invoked + * even for empty partitions. + * * The print behavior can be customized by providing two functions. * * @param command command to run in forked process. 
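`sample` and `randomSplit` now validate their arguments eagerly (before `withScope`), so a bad fraction or weight fails at call time rather than when the job runs. A hedged usage sketch of `randomSplit` with the normalisation its doc relies on:

```scala
import org.apache.spark.rdd.RDD

object RandomSplitSketch {
  // Weights are normalised internally, so (6, 2, 2) behaves like (0.6, 0.2, 0.2);
  // negative weights or an all-zero array now fail fast with IllegalArgumentException.
  def trainValTest(data: RDD[Int]): (RDD[Int], RDD[Int], RDD[Int]) = {
    val Array(train, validation, test) = data.randomSplit(Array(6.0, 2.0, 2.0), seed = 42L)
    (train, validation, test)
  }
}
```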
@@ -724,6 +749,9 @@ abstract class RDD[T: ClassTag]( * def printRDDElement(record:(String, Seq[String]), f:String=>Unit) = * for (e <- record._2) {f(e)} * @param separateWorkingDir Use separate working directories for each task. + * @param bufferSize Buffer size for the stdin writer for the piped process. + * @param encoding Char encoding used for interacting (via stdin, stdout and stderr) with + * the piped process * @return the result RDD */ def pipe( @@ -731,11 +759,15 @@ abstract class RDD[T: ClassTag]( env: Map[String, String] = Map(), printPipeContext: (String => Unit) => Unit = null, printRDDElement: (T, String => Unit) => Unit = null, - separateWorkingDir: Boolean = false): RDD[String] = withScope { + separateWorkingDir: Boolean = false, + bufferSize: Int = 8192, + encoding: String = Codec.defaultCharsetCodec.name): RDD[String] = withScope { new PipedRDD(this, command, env, if (printPipeContext ne null) sc.clean(printPipeContext) else null, if (printRDDElement ne null) sc.clean(printRDDElement) else null, - separateWorkingDir) + separateWorkingDir, + bufferSize, + encoding) } /** @@ -1104,10 +1136,21 @@ abstract class RDD[T: ClassTag]( /** * Approximate version of count() that returns a potentially incomplete result * within a timeout, even if not all tasks have finished. + * + * The confidence is the probability that the error bounds of the result will + * contain the true value. That is, if countApprox were called repeatedly + * with confidence 0.9, we would expect 90% of the results to contain the + * true count. The confidence must be in the range [0,1] or an exception will + * be thrown. + * + * @param timeout maximum time to wait for the job, in milliseconds + * @param confidence the desired statistical confidence in the result + * @return a potentially incomplete result, with error bounds */ def countApprox( timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = withScope { + require(0.0 <= confidence && confidence <= 1.0, s"confidence ($confidence) must be in [0,1]") val countElements: (TaskContext, Iterator[T]) => Long = { (ctx, iter) => var result = 0L while (iter.hasNext) { @@ -1134,10 +1177,15 @@ abstract class RDD[T: ClassTag]( /** * Approximate version of countByValue(). 
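With the two new parameters, the full `pipe` signature takes a stdin buffer size and the character encoding used for the forked process's stdin, stdout and stderr. A hedged usage sketch against that extended signature:

```scala
import org.apache.spark.rdd.RDD

object PipeSketch {
  // Pipes each partition through an external `sort`, writing stdin through an
  // 8 KB buffer and talking to the process in UTF-8.
  def sortedLines(lines: RDD[String]): RDD[String] =
    lines.pipe(
      command = Seq("sort"),
      env = Map("LC_ALL" -> "C"),
      printPipeContext = null,
      printRDDElement = null,
      separateWorkingDir = false,
      bufferSize = 8192,
      encoding = "UTF-8")
}
```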
+ * + * @param timeout maximum time to wait for the job, in milliseconds + * @param confidence the desired statistical confidence in the result + * @return a potentially incomplete result, with error bounds */ def countByValueApprox(timeout: Long, confidence: Double = 0.95) (implicit ord: Ordering[T] = null) : PartialResult[Map[T, BoundedDouble]] = withScope { + require(0.0 <= confidence && confidence <= 1.0, s"confidence ($confidence) must be in [0,1]") if (elementClassTag.runtimeClass.isArray) { throw new SparkException("countByValueApprox() does not support arrays") } @@ -1230,7 +1278,7 @@ abstract class RDD[T: ClassTag]( def zipWithUniqueId(): RDD[(T, Long)] = withScope { val n = this.partitions.length.toLong this.mapPartitionsWithIndex { case (k, iter) => - iter.zipWithIndex.map { case (item, i) => + Utils.getIteratorZipWithIndex(iter, 0L).map { case (item, i) => (item, i * n + k) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala index fddb9353018a8..b73214fe91c4e 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala @@ -69,10 +69,10 @@ private[spark] class ReliableCheckpointRDD[T: ClassTag]( val inputFiles = fs.listStatus(cpath) .map(_.getPath) .filter(_.getName.startsWith("part-")) - .sortBy(_.toString) + .sortBy(_.getName.stripPrefix("part-").toInt) // Fail fast if input files are invalid inputFiles.zipWithIndex.foreach { case (path, i) => - if (!path.toString.endsWith(ReliableCheckpointRDD.checkpointFileName(i))) { + if (path.getName != ReliableCheckpointRDD.checkpointFileName(i)) { throw new SparkException(s"Invalid checkpoint file: $path") } } diff --git a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala index 66cf4369da2ef..ad1fddbde7b00 100644 --- a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala @@ -20,6 +20,8 @@ package org.apache.spark.rdd import java.io.{IOException, ObjectOutputStream} import scala.collection.mutable.ArrayBuffer +import scala.collection.parallel.{ForkJoinTaskSupport, ThreadPoolTaskSupport} +import scala.concurrent.forkjoin.ForkJoinPool import scala.reflect.ClassTag import org.apache.spark.{Dependency, Partition, RangeDependency, SparkContext, TaskContext} @@ -56,14 +58,30 @@ private[spark] class UnionPartition[T: ClassTag]( } } +object UnionRDD { + private[spark] lazy val partitionEvalTaskSupport = + new ForkJoinTaskSupport(new ForkJoinPool(8)) +} + @DeveloperApi class UnionRDD[T: ClassTag]( sc: SparkContext, var rdds: Seq[RDD[T]]) extends RDD[T](sc, Nil) { // Nil since we implement getDependencies + // visible for testing + private[spark] val isPartitionListingParallel: Boolean = + rdds.length > conf.getInt("spark.rdd.parallelListingThreshold", 10) + override def getPartitions: Array[Partition] = { - val array = new Array[Partition](rdds.map(_.partitions.length).sum) + val parRDDs = if (isPartitionListingParallel) { + val parArray = rdds.par + parArray.tasksupport = UnionRDD.partitionEvalTaskSupport + parArray + } else { + rdds + } + val array = new Array[Partition](parRDDs.map(_.partitions.length).seq.sum) var pos = 0 for ((rdd, rddIndex) <- rdds.zipWithIndex; split <- rdd.partitions) { array(pos) = new UnionPartition(pos, rdd, rddIndex, split.index) diff --git 
a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala index 32931d59acb18..dff67371e93d5 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala @@ -64,8 +64,7 @@ class ZippedWithIndexRDD[T: ClassTag](prev: RDD[T]) extends RDD[(T, Long)](prev) override def compute(splitIn: Partition, context: TaskContext): Iterator[(T, Long)] = { val split = splitIn.asInstanceOf[ZippedWithIndexRDDPartition] - firstParent[T].iterator(split.prev, context).zipWithIndex.map { x => - (x._1, split.startIndex + x._2) - } + val parentIter = firstParent[T].iterator(split.prev, context) + Utils.getIteratorZipWithIndex(parentIter, split.startIndex) } } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala index 4f8fe018b432d..a02cf30a5d831 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala @@ -17,7 +17,7 @@ package org.apache.spark.rpc.netty -import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, ThreadPoolExecutor, TimeUnit} +import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, LinkedBlockingQueue, ThreadPoolExecutor, TimeUnit} import javax.annotation.concurrent.GuardedBy import scala.collection.JavaConverters._ @@ -42,8 +42,10 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { val inbox = new Inbox(ref, endpoint) } - private val endpoints = new ConcurrentHashMap[String, EndpointData] - private val endpointRefs = new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef] + private val endpoints: ConcurrentMap[String, EndpointData] = + new ConcurrentHashMap[String, EndpointData] + private val endpointRefs: ConcurrentMap[RpcEndpoint, RpcEndpointRef] = + new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef] // Track the receivers whose inboxes may contain messages. 
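Both `zipWithUniqueId` and `ZippedWithIndexRDD` above now index elements through `Utils.getIteratorZipWithIndex(iter, startIndex)` rather than `Iterator.zipWithIndex`, whose indices are `Int`s. Judging only from the two call sites, the helper pairs each element with a `Long` index starting at the given offset; below is a minimal sketch with that shape (the real implementation may differ), plus the id arithmetic used by `zipWithUniqueId`:

```scala
// Sketch of a helper with the same shape as the Utils.getIteratorZipWithIndex call sites
// above: Iterator[T] plus a Long start index, yielding (element, Long) pairs.
def zipWithIndexFrom[T](iter: Iterator[T], start: Long): Iterator[(T, Long)] = {
  var i = start - 1
  iter.map { item => i += 1; (item, i) }
}

// zipWithUniqueId arithmetic from the hunk above: in partition k of n partitions, the
// j-th element gets id j * n + k, so ids are unique across partitions but not consecutive.
val n = 3                                            // number of partitions
val idsInPartition1 = (0L until 4L).map(_ * n + 1)   // Vector(1, 4, 7, 10)
```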
private val receivers = new LinkedBlockingQueue[EndpointData] @@ -144,25 +146,20 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { endpointName: String, message: InboxMessage, callbackIfStopped: (Exception) => Unit): Unit = { - val shouldCallOnStop = synchronized { + val error = synchronized { val data = endpoints.get(endpointName) - if (stopped || data == null) { - true + if (stopped) { + Some(new RpcEnvStoppedException()) + } else if (data == null) { + Some(new SparkException(s"Could not find $endpointName.")) } else { data.inbox.post(message) receivers.offer(data) - false + None } } - if (shouldCallOnStop) { - // We don't need to call `onStop` in the `synchronized` block - val error = if (stopped) { - new RpcEnvStoppedException() - } else { - new SparkException(s"Could not find $endpointName or it has been stopped.") - } - callbackIfStopped(error) - } + // We don't need to call `onStop` in the `synchronized` block + error.foreach(callbackIfStopped) } def stop(): Unit = { diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala index fffbd5cd44a23..ae4a6003517cc 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala @@ -52,7 +52,7 @@ private[netty] case class RemoteProcessConnectionError(cause: Throwable, remoteA extends InboxMessage /** - * A inbox that stores messages for an [[RpcEndpoint]] and posts messages to it thread-safely. + * An inbox that stores messages for an [[RpcEndpoint]] and posts messages to it thread-safely. */ private[netty] class Inbox( val endpointRef: NettyRpcEndpointRef, diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 7d7b4c82fa392..89d2fb9b47971 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -574,7 +574,7 @@ private[netty] class NettyRpcHandler( private def internalReceive(client: TransportClient, message: ByteBuffer): RequestMessage = { val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress] assert(addr != null) - val clientAddr = RpcAddress(addr.getHostName, addr.getPort) + val clientAddr = RpcAddress(addr.getHostString, addr.getPort) val requestMessage = nettyEnv.deserialize[RequestMessage](client, message) if (requestMessage.senderAddress == null) { // Create a new message with the socket address of the client as the sender. 
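The repeated swap of `addr.getHostName` for `addr.getHostString` above (and in the handlers that follow) avoids reverse DNS lookups: per the JDK javadoc, `getHostName` may trigger a name-service reverse lookup when the address was built from a literal IP, while `getHostString` simply returns the literal, or the hostname if one was supplied. A small illustration:

```scala
import java.net.InetSocketAddress

// Built from a literal IP: getHostString returns the literal immediately, whereas
// getHostName may block on a reverse DNS lookup trying to find a name for it.
val byIp = new InetSocketAddress("10.1.2.3", 7077)
byIp.getHostString   // "10.1.2.3", no lookup

// Built unresolved from a name: both accessors return the name; nothing is resolved.
val byName = InetSocketAddress.createUnresolved("worker-1.example.com", 7077)
byName.getHostString // "worker-1.example.com"
```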
@@ -595,7 +595,7 @@ private[netty] class NettyRpcHandler( override def exceptionCaught(cause: Throwable, client: TransportClient): Unit = { val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress] if (addr != null) { - val clientAddr = RpcAddress(addr.getHostName, addr.getPort) + val clientAddr = RpcAddress(addr.getHostString, addr.getPort) dispatcher.postToAll(RemoteProcessConnectionError(cause, clientAddr)) // If the remove RpcEnv listens to some address, we should also fire a // RemoteProcessConnectionError for the remote RpcEnv listening address @@ -614,14 +614,14 @@ private[netty] class NettyRpcHandler( override def channelActive(client: TransportClient): Unit = { val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress] assert(addr != null) - val clientAddr = RpcAddress(addr.getHostName, addr.getPort) + val clientAddr = RpcAddress(addr.getHostString, addr.getPort) dispatcher.postToAll(RemoteProcessConnected(clientAddr)) } override def channelInactive(client: TransportClient): Unit = { val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress] if (addr != null) { - val clientAddr = RpcAddress(addr.getHostName, addr.getPort) + val clientAddr = RpcAddress(addr.getHostString, addr.getPort) nettyEnv.removeOutbox(clientAddr) dispatcher.postToAll(RemoteProcessDisconnected(clientAddr)) val remoteEnvAddress = remoteAddresses.remove(clientAddr) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala index afcb023a99daa..780fadd5bda8e 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala @@ -66,14 +66,18 @@ private[netty] class NettyStreamManager(rpcEnv: NettyRpcEnv) } override def addFile(file: File): String = { - require(files.putIfAbsent(file.getName(), file) == null, - s"File ${file.getName()} already registered.") + val existingPath = files.putIfAbsent(file.getName, file) + require(existingPath == null || existingPath == file, + s"File ${file.getName} was already registered with a different path " + + s"(old path = $existingPath, new path = $file") s"${rpcEnv.address.toSparkURL}/files/${Utils.encodeFileNameToURIRawPath(file.getName())}" } override def addJar(file: File): String = { - require(jars.putIfAbsent(file.getName(), file) == null, - s"JAR ${file.getName()} already registered.") + val existingPath = jars.putIfAbsent(file.getName, file) + require(existingPath == null || existingPath == file, + s"File ${file.getName} was already registered with a different path " + + s"(old path = $existingPath, new path = $file") s"${rpcEnv.address.toSparkURL}/jars/${Utils.encodeFileNameToURIRawPath(file.getName())}" } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala b/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala index 9f218c64cac2d..28c45d800ed06 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala @@ -32,6 +32,8 @@ private[spark] class ApplicationEventListener extends SparkListener { var endTime: Option[Long] = None var viewAcls: Option[String] = None var adminAcls: Option[String] = None + var viewAclsGroups: Option[String] = None + var adminAclsGroups: Option[String] = None override def onApplicationStart(applicationStart: 
SparkListenerApplicationStart) { appName = Some(applicationStart.appName) @@ -51,6 +53,8 @@ private[spark] class ApplicationEventListener extends SparkListener { val allProperties = environmentDetails("Spark Properties").toMap viewAcls = allProperties.get("spark.ui.view.acls") adminAcls = allProperties.get("spark.admin.acls") + viewAclsGroups = allProperties.get("spark.ui.view.acls.groups") + adminAclsGroups = allProperties.get("spark.admin.acls.groups") } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index a96d5f6fbf082..e7e2ff1718f22 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -209,7 +209,7 @@ class DAGScheduler( task: Task[_], reason: TaskEndReason, result: Any, - accumUpdates: Seq[NewAccumulator[_, _]], + accumUpdates: Seq[AccumulatorV2[_, _]], taskInfo: TaskInfo): Unit = { eventProcessLoop.post( CompletionEvent(task, reason, result, accumUpdates, taskInfo)) @@ -233,8 +233,8 @@ class DAGScheduler( /** * Called by TaskScheduler implementation when an executor fails. */ - def executorLost(execId: String): Unit = { - eventProcessLoop.post(ExecutorLost(execId)) + def executorLost(execId: String, reason: ExecutorLossReason): Unit = { + eventProcessLoop.post(ExecutorLost(execId, reason)) } /** @@ -286,7 +286,9 @@ class DAGScheduler( case None => // We are going to register ancestor shuffle dependencies getAncestorShuffleDependencies(shuffleDep.rdd).foreach { dep => - shuffleToMapStage(dep.shuffleId) = newOrUsedShuffleStage(dep, firstJobId) + if (!shuffleToMapStage.contains(dep.shuffleId)) { + shuffleToMapStage(dep.shuffleId) = newOrUsedShuffleStage(dep, firstJobId) + } } // Then register current shuffleDep val stage = newOrUsedShuffleStage(shuffleDep, firstJobId) @@ -1091,14 +1093,14 @@ class DAGScheduler( event.accumUpdates.foreach { updates => val id = updates.id // Find the corresponding accumulator on the driver and update it - val acc: NewAccumulator[Any, Any] = AccumulatorContext.get(id) match { - case Some(accum) => accum.asInstanceOf[NewAccumulator[Any, Any]] + val acc: AccumulatorV2[Any, Any] = AccumulatorContext.get(id) match { + case Some(accum) => accum.asInstanceOf[AccumulatorV2[Any, Any]] case None => throw new SparkException(s"attempted to access non-existent accumulator $id") } - acc.merge(updates.asInstanceOf[NewAccumulator[Any, Any]]) + acc.merge(updates.asInstanceOf[AccumulatorV2[Any, Any]]) // To avoid UI cruft, ignore cases where value wasn't updated - if (acc.name.isDefined && !updates.isZero()) { + if (acc.name.isDefined && !updates.isZero) { stage.latestInfo.accumulables(id) = acc.toInfo(None, Some(acc.value)) event.taskInfo.accumulables += acc.toInfo(Some(updates.value), Some(acc.value)) } @@ -1275,18 +1277,20 @@ class DAGScheduler( s"has failed the maximum allowable number of " + s"times: ${Stage.MAX_CONSECUTIVE_FETCH_FAILURES}. " + s"Most recent failure reason: ${failureMessage}", None) - } else if (failedStages.isEmpty) { - // Don't schedule an event to resubmit failed stages if failed isn't empty, because - // in that case the event will already have been scheduled. 
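The accumulator plumbing above (`acc.merge(updates)`, `updates.isZero`, `acc.value`) is the driver-side half of the `AccumulatorV2` contract that replaces `NewAccumulator` throughout the scheduler. A hedged sketch of a custom accumulator written against that contract; the member names are taken from Spark 2.0's public API and may differ slightly in this snapshot:

```scala
import org.apache.spark.util.AccumulatorV2

// Illustrative max-tracking accumulator against the isZero/copy/reset/add/merge/value
// contract the DAGScheduler relies on above (assumed member set, not verified here).
class MaxAccumulator extends AccumulatorV2[Long, Long] {
  private var _max = Long.MinValue
  override def isZero: Boolean = _max == Long.MinValue
  override def copy(): MaxAccumulator = { val c = new MaxAccumulator; c._max = _max; c }
  override def reset(): Unit = _max = Long.MinValue
  override def add(v: Long): Unit = _max = math.max(_max, v)
  override def merge(other: AccumulatorV2[Long, Long]): Unit =
    if (!other.isZero) add(other.value)
  override def value: Long = _max
}
```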
- // TODO: Cancel running tasks in the stage - logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " + - s"$failedStage (${failedStage.name}) due to fetch failure") - messageScheduler.schedule(new Runnable { - override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages) - }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS) + } else { + if (failedStages.isEmpty) { + // Don't schedule an event to resubmit failed stages if failed isn't empty, because + // in that case the event will already have been scheduled. + // TODO: Cancel running tasks in the stage + logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " + + s"$failedStage (${failedStage.name}) due to fetch failure") + messageScheduler.schedule(new Runnable { + override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages) + }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS) + } + failedStages += failedStage + failedStages += mapStage } - failedStages += failedStage - failedStages += mapStage // Mark the map whose fetch failed as broken in the map stage if (mapId != -1) { mapStage.removeOutputLoc(mapId, bmAddress) @@ -1295,7 +1299,7 @@ class DAGScheduler( // TODO: mark the executor as failed only if there were lots of fetch failures on it if (bmAddress != null) { - handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch)) + handleExecutorLost(bmAddress.executorId, filesLost = true, Some(task.epoch)) } } @@ -1321,15 +1325,16 @@ class DAGScheduler( * modify the scheduler's internal state. Use executorLost() to post a loss event from outside. * * We will also assume that we've lost all shuffle blocks associated with the executor if the - * executor serves its own blocks (i.e., we're not using external shuffle) OR a FetchFailed - * occurred, in which case we presume all shuffle data related to this executor to be lost. + * executor serves its own blocks (i.e., we're not using external shuffle), the entire slave + * is lost (likely including the shuffle service), or a FetchFailed occurred, in which case we + * presume all shuffle data related to this executor to be lost. * * Optionally the epoch during which the failure was caught can be passed to avoid allowing * stray fetch failures from possibly retriggering the detection of a node as lost. 
*/ private[scheduler] def handleExecutorLost( execId: String, - fetchFailed: Boolean, + filesLost: Boolean, maybeEpoch: Option[Long] = None) { val currentEpoch = maybeEpoch.getOrElse(mapOutputTracker.getEpoch) if (!failedEpoch.contains(execId) || failedEpoch(execId) < currentEpoch) { @@ -1337,7 +1342,8 @@ class DAGScheduler( logInfo("Executor lost: %s (epoch %d)".format(execId, currentEpoch)) blockManagerMaster.removeExecutor(execId) - if (!env.blockManager.externalShuffleServiceEnabled || fetchFailed) { + if (filesLost || !env.blockManager.externalShuffleServiceEnabled) { + logInfo("Shuffle files lost for executor: %s (epoch %d)".format(execId, currentEpoch)) // TODO: This will be really slow if we keep accumulating shuffle map stages for ((shuffleId, stage) <- shuffleToMapStage) { stage.removeOutputsOnExecutor(execId) @@ -1641,8 +1647,12 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler case ExecutorAdded(execId, host) => dagScheduler.handleExecutorAdded(execId, host) - case ExecutorLost(execId) => - dagScheduler.handleExecutorLost(execId, fetchFailed = false) + case ExecutorLost(execId, reason) => + val filesLost = reason match { + case SlaveLost(_, true) => true + case _ => false + } + dagScheduler.handleExecutorLost(execId, filesLost) case BeginEvent(task, taskInfo) => dagScheduler.handleBeginEvent(task, taskInfo) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala index e57a2246d8729..03781a2a2b56c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala @@ -23,7 +23,7 @@ import scala.language.existentials import org.apache.spark._ import org.apache.spark.rdd.RDD -import org.apache.spark.util.CallSite +import org.apache.spark.util.{AccumulatorV2, CallSite} /** * Types of events that can be handled by the DAGScheduler. The DAGScheduler uses an event queue @@ -71,13 +71,14 @@ private[scheduler] case class CompletionEvent( task: Task[_], reason: TaskEndReason, result: Any, - accumUpdates: Seq[NewAccumulator[_, _]], + accumUpdates: Seq[AccumulatorV2[_, _]], taskInfo: TaskInfo) extends DAGSchedulerEvent private[scheduler] case class ExecutorAdded(execId: String, host: String) extends DAGSchedulerEvent -private[scheduler] case class ExecutorLost(execId: String) extends DAGSchedulerEvent +private[scheduler] case class ExecutorLost(execId: String, reason: ExecutorLossReason) + extends DAGSchedulerEvent private[scheduler] case class TaskSetFailed(taskSet: TaskSet, reason: String, exception: Option[Throwable]) diff --git a/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala index 7e1197d742802..46a35b6a2eaf9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala @@ -20,7 +20,7 @@ package org.apache.spark.scheduler import org.apache.spark.executor.ExecutorExitCode /** - * Represents an explanation for a executor or whole slave failing or exiting. + * Represents an explanation for an executor or whole slave failing or exiting. 
*/ private[spark] class ExecutorLossReason(val message: String) extends Serializable { @@ -51,6 +51,10 @@ private[spark] object ExecutorKilled extends ExecutorLossReason("Executor killed */ private [spark] object LossReasonPending extends ExecutorLossReason("Pending loss reason.") +/** + * @param _message human readable loss reason + * @param workerLost whether the worker is confirmed lost too (i.e. including shuffle service) + */ private[spark] -case class SlaveLost(_message: String = "Slave lost") +case class SlaveLost(_message: String = "Slave lost", workerLost: Boolean = false) extends ExecutorLossReason(_message) diff --git a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala index 1c21313d1cb17..5533f7b1f2363 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala @@ -18,11 +18,12 @@ package org.apache.spark.scheduler import java.util.concurrent._ -import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong} import scala.util.DynamicVariable -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkException} +import org.apache.spark.internal.config._ import org.apache.spark.util.Utils /** @@ -32,24 +33,36 @@ import org.apache.spark.util.Utils * has started will events be actually propagated to all attached listeners. This listener bus * is stopped when `stop()` is called, and it will drop further events after stopping. */ -private[spark] class LiveListenerBus extends SparkListenerBus { +private[spark] class LiveListenerBus(val sparkContext: SparkContext) extends SparkListenerBus { self => import LiveListenerBus._ - private var sparkContext: SparkContext = null - // Cap the capacity of the event queue so we get an explicit error (rather than // an OOM exception) if it's perpetually being added to more quickly than it's being drained. - private val EVENT_QUEUE_CAPACITY = 10000 - private val eventQueue = new LinkedBlockingQueue[SparkListenerEvent](EVENT_QUEUE_CAPACITY) + private lazy val EVENT_QUEUE_CAPACITY = validateAndGetQueueSize() + private lazy val eventQueue = new LinkedBlockingQueue[SparkListenerEvent](EVENT_QUEUE_CAPACITY) + + private def validateAndGetQueueSize(): Int = { + val queueSize = sparkContext.conf.get(LISTENER_BUS_EVENT_QUEUE_SIZE) + if (queueSize <= 0) { + throw new SparkException("spark.scheduler.listenerbus.eventqueue.size must be > 0!") + } + queueSize + } // Indicate if `start()` is called private val started = new AtomicBoolean(false) // Indicate if `stop()` is called private val stopped = new AtomicBoolean(false) + /** A counter for dropped events. It will be reset every time we log it. */ + private val droppedEventsCounter = new AtomicLong(0L) + + /** When `droppedEventsCounter` was logged last time in milliseconds. */ + @volatile private var lastReportTimestamp = 0L + // Indicate if we are processing some event // Guarded by `self` private var processingEvent = false @@ -96,11 +109,9 @@ private[spark] class LiveListenerBus extends SparkListenerBus { * listens for any additional events asynchronously while the listener bus is still running. * This should only be called once. * - * @param sc Used to stop the SparkContext in case the listener thread dies. 
*/ - def start(sc: SparkContext): Unit = { + def start(): Unit = { if (started.compareAndSet(false, true)) { - sparkContext = sc listenerThread.start() } else { throw new IllegalStateException(s"$name already started!") @@ -118,6 +129,24 @@ private[spark] class LiveListenerBus extends SparkListenerBus { eventLock.release() } else { onDropEvent(event) + droppedEventsCounter.incrementAndGet() + } + + val droppedEvents = droppedEventsCounter.get + if (droppedEvents > 0) { + // Don't log too frequently + if (System.currentTimeMillis() - lastReportTimestamp >= 60 * 1000) { + // There may be multiple threads trying to decrease droppedEventsCounter. + // Use "compareAndSet" to make sure only one thread can win. + // And if another thread is increasing droppedEventsCounter, "compareAndSet" will fail and + // then that thread will update it. + if (droppedEventsCounter.compareAndSet(droppedEvents, 0)) { + val prevLastReportTimestamp = lastReportTimestamp + lastReportTimestamp = System.currentTimeMillis() + logWarning(s"Dropped $droppedEvents SparkListenerEvents since " + + new java.util.Date(prevLastReportTimestamp)) + } + } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala index a79e71ec7c9bf..266afaf8d73b9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala @@ -26,16 +26,14 @@ import org.apache.spark.internal.Logging import org.apache.spark.scheduler.SchedulingMode.SchedulingMode /** - * An Schedulable entity that represent collection of Pools or TaskSetManagers + * A Schedulable entity that represents collection of Pools or TaskSetManagers */ - private[spark] class Pool( val poolName: String, val schedulingMode: SchedulingMode, initMinShare: Int, initWeight: Int) - extends Schedulable - with Logging { + extends Schedulable with Logging { val schedulableQueue = new ConcurrentLinkedQueue[Schedulable] val schedulableNameToSchedulable = new ConcurrentHashMap[String, Schedulable] @@ -56,7 +54,8 @@ private[spark] class Pool( case SchedulingMode.FIFO => new FIFOSchedulingAlgorithm() case _ => - throw new IllegalArgumentException(s"Unsupported spark.scheduler.mode: $schedulingMode") + val msg = "Unsupported scheduling mode: $schedulingMode. Use FAIR or FIFO instead." 
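The `post()` change above (in LiveListenerBus, just before the Pool hunk) counts dropped events in an `AtomicLong` and logs a warning at most once a minute, with `compareAndSet` guaranteeing that only one of the racing threads resets the counter and logs. A self-contained sketch of that pattern; the class and method names here are mine, not Spark's:

```scala
import java.util.concurrent.atomic.AtomicLong

// Standalone sketch of the rate-limited "dropped events" warning used above.
class DropCounter(intervalMs: Long = 60L * 1000) {
  private val dropped = new AtomicLong(0L)
  @volatile private var lastReport = 0L

  def recordDrop(): Unit = dropped.incrementAndGet()

  def maybeReport(log: String => Unit): Unit = {
    val d = dropped.get
    // Log at most once per interval; compareAndSet lets exactly one racing thread win.
    if (d > 0 && System.currentTimeMillis() - lastReport >= intervalMs &&
        dropped.compareAndSet(d, 0)) {
      val prev = lastReport
      lastReport = System.currentTimeMillis()
      log(s"Dropped $d events since ${new java.util.Date(prev)}")
    }
  }
}
```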
+ throw new IllegalArgumentException(msg) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala index 100ed76ecb6d6..96325a0329f89 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala @@ -112,7 +112,8 @@ private[spark] class FairSchedulableBuilder(val rootPool: Pool, conf: SparkConf) schedulingMode = SchedulingMode.withName(xmlSchedulingMode) } catch { case e: NoSuchElementException => - logWarning("Error xml schedulingMode, using default schedulingMode") + logWarning(s"Unsupported schedulingMode: $xmlSchedulingMode, " + + s"using the default schedulingMode: $schedulingMode") } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala index 864941d468af9..18ebbbe78a5b4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala @@ -36,11 +36,7 @@ private[spark] class FIFOSchedulingAlgorithm extends SchedulingAlgorithm { val stageId2 = s2.stageId res = math.signum(stageId1 - stageId2) } - if (res < 0) { - true - } else { - false - } + res < 0 } } @@ -52,12 +48,12 @@ private[spark] class FairSchedulingAlgorithm extends SchedulingAlgorithm { val runningTasks2 = s2.runningTasks val s1Needy = runningTasks1 < minShare1 val s2Needy = runningTasks2 < minShare2 - val minShareRatio1 = runningTasks1.toDouble / math.max(minShare1, 1.0).toDouble - val minShareRatio2 = runningTasks2.toDouble / math.max(minShare2, 1.0).toDouble + val minShareRatio1 = runningTasks1.toDouble / math.max(minShare1, 1.0) + val minShareRatio2 = runningTasks2.toDouble / math.max(minShare2, 1.0) val taskToWeightRatio1 = runningTasks1.toDouble / s1.weight.toDouble val taskToWeightRatio2 = runningTasks2.toDouble / s2.weight.toDouble - var compare: Int = 0 + var compare = 0 if (s1Needy && !s2Needy) { return true } else if (!s1Needy && s2Needy) { @@ -67,7 +63,6 @@ private[spark] class FairSchedulingAlgorithm extends SchedulingAlgorithm { } else { compare = taskToWeightRatio1.compareTo(taskToWeightRatio2) } - if (compare < 0) { true } else if (compare > 0) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index e7ca6efd84aee..1ed36bf0692f8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -21,6 +21,7 @@ import java.io.{DataInputStream, DataOutputStream} import java.nio.ByteBuffer import java.util.Properties +import scala.collection.mutable import scala.collection.mutable.HashMap import org.apache.spark._ @@ -28,7 +29,7 @@ import org.apache.spark.executor.TaskMetrics import org.apache.spark.memory.{MemoryMode, TaskMemoryManager} import org.apache.spark.metrics.MetricsSystem import org.apache.spark.serializer.SerializerInstance -import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, Utils} +import org.apache.spark.util.{AccumulatorV2, ByteBufferInputStream, ByteBufferOutputStream, Utils} /** * A unit of execution. We have two kinds of Task's in Spark: @@ -153,9 +154,16 @@ private[spark] abstract class Task[T]( * Collect the latest values of accumulators used in this task. 
If the task failed, * filter out the accumulators whose values should not be included on failures. */ - def collectAccumulatorUpdates(taskFailed: Boolean = false): Seq[NewAccumulator[_, _]] = { + def collectAccumulatorUpdates(taskFailed: Boolean = false): Seq[AccumulatorV2[_, _]] = { if (context != null) { - context.taskMetrics.accumulators().filter { a => !taskFailed || a.countFailedValues } + context.taskMetrics.internalAccums.filter { a => + // RESULT_SIZE accumulator is always zero at executor, we need to send it back as its + // value will be updated at driver side. + // Note: internal accumulators representing task metrics always count failed values + !a.isZero || a.name == Some(InternalAccumulator.RESULT_SIZE) + // zero value external accumulators may still be useful, e.g. SQLMetrics, we should not filter + // them out. + } ++ context.taskMetrics.externalAccums.filter(a => !taskFailed || a.countFailedValues) } else { Seq.empty } @@ -191,8 +199,8 @@ private[spark] object Task { */ def serializeWithDependencies( task: Task[_], - currentFiles: HashMap[String, Long], - currentJars: HashMap[String, Long], + currentFiles: mutable.Map[String, Long], + currentJars: mutable.Map[String, Long], serializer: SerializerInstance) : ByteBuffer = { @@ -222,6 +230,7 @@ private[spark] object Task { dataOut.flush() val taskBytes = serializer.serialize(task) Utils.writeByteBuffer(taskBytes, out) + out.close() out.toByteBuffer } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala index 1eb6c1614fc0b..06b52935c696c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala @@ -64,18 +64,18 @@ private[spark] object TaskLocation { /** * Create a TaskLocation from a string returned by getPreferredLocations. - * These strings have the form [hostname] or hdfs_cache_[hostname], depending on whether the - * location is cached. + * These strings have the form executor_[hostname]_[executorid], [hostname], or + * hdfs_cache_[hostname], depending on whether the location is cached. */ def apply(str: String): TaskLocation = { val hstr = str.stripPrefix(inMemoryLocationTag) if (hstr.equals(str)) { if (str.startsWith(executorLocationTag)) { - val splits = str.split("_") - if (splits.length != 3) { - throw new IllegalArgumentException("Illegal executor location format: " + str) - } - new ExecutorCacheTaskLocation(splits(1), splits(2)) + val hostAndExecutorId = str.stripPrefix(executorLocationTag) + val splits = hostAndExecutorId.split("_", 2) + require(splits.length == 2, "Illegal executor location format: " + str) + val Array(host, executorId) = splits + new ExecutorCacheTaskLocation(host, executorId) } else { new HostTaskLocation(str) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala index b472c5511b738..80f2bf41224b5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala @@ -22,9 +22,9 @@ import java.nio.ByteBuffer import scala.collection.mutable.ArrayBuffer -import org.apache.spark.{NewAccumulator, SparkEnv} +import org.apache.spark.SparkEnv import org.apache.spark.storage.BlockId -import org.apache.spark.util.Utils +import org.apache.spark.util.{AccumulatorV2, Utils} // Task result. Also contains updates to accumulator variables. 
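The new `TaskLocation` parsing above strips the `executor_` prefix and then splits on the first underscore only, so executor IDs that themselves contain underscores are accepted; the previous `split("_")` with a required length of 3 would have rejected them. A worked example with a made-up location string:

```scala
// Worked example of the parsing above; the executor id may itself contain underscores.
val str = "executor_host-17_app-20160601_000023"        // hypothetical location string
val hostAndExecutorId = str.stripPrefix("executor_")    // "host-17_app-20160601_000023"
val Array(host, executorId) = hostAndExecutorId.split("_", 2)
// host == "host-17", executorId == "app-20160601_000023"
```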
private[spark] sealed trait TaskResult[T] @@ -36,7 +36,7 @@ private[spark] case class IndirectTaskResult[T](blockId: BlockId, size: Int) /** A TaskResult that contains the task's return value and accumulator updates. */ private[spark] class DirectTaskResult[T]( var valueBytes: ByteBuffer, - var accumUpdates: Seq[NewAccumulator[_, _]]) + var accumUpdates: Seq[AccumulatorV2[_, _]]) extends TaskResult[T] with Externalizable { private var valueObjectDeserialized = false @@ -61,9 +61,9 @@ private[spark] class DirectTaskResult[T]( if (numUpdates == 0) { accumUpdates = null } else { - val _accumUpdates = new ArrayBuffer[NewAccumulator[_, _]] + val _accumUpdates = new ArrayBuffer[AccumulatorV2[_, _]] for (i <- 0 until numUpdates) { - _accumUpdates += in.readObject.asInstanceOf[NewAccumulator[_, _]] + _accumUpdates += in.readObject.asInstanceOf[AccumulatorV2[_, _]] } accumUpdates = _accumUpdates } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala index b438c285fdf1f..685ef55c66876 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala @@ -27,7 +27,7 @@ import org.apache.spark._ import org.apache.spark.TaskState.TaskState import org.apache.spark.internal.Logging import org.apache.spark.serializer.SerializerInstance -import org.apache.spark.util.{ThreadUtils, Utils} +import org.apache.spark.util.{LongAccumulator, ThreadUtils, Utils} /** * Runs a thread pool that deserializes and remotely fetches (if necessary) task results. diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala index 75a0c56311977..cd13eebe74a99 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -17,9 +17,9 @@ package org.apache.spark.scheduler -import org.apache.spark.NewAccumulator import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.AccumulatorV2 /** * Low-level task scheduler interface, currently implemented exclusively by @@ -67,7 +67,7 @@ private[spark] trait TaskScheduler { */ def executorHeartbeatReceived( execId: String, - accumUpdates: Array[(Long, Seq[NewAccumulator[_, _]])], + accumUpdates: Array[(Long, Seq[AccumulatorV2[_, _]])], blockManagerId: BlockManagerId): Boolean /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 8fa4aa121c123..d22321b88fb8e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -33,21 +33,22 @@ import org.apache.spark.TaskState.TaskState import org.apache.spark.internal.Logging import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.scheduler.TaskLocality.TaskLocality +import org.apache.spark.scheduler.local.LocalSchedulerBackend import org.apache.spark.storage.BlockManagerId -import org.apache.spark.util.{ThreadUtils, Utils} +import org.apache.spark.util.{AccumulatorV2, ThreadUtils, Utils} /** * Schedules tasks for multiple types of clusters by acting through a SchedulerBackend. 
- * It can also work with a local setup by using a LocalBackend and setting isLocal to true. - * It handles common logic, like determining a scheduling order across jobs, waking up to launch - * speculative tasks, etc. + * It can also work with a local setup by using a [[LocalSchedulerBackend]] and setting + * isLocal to true. It handles common logic, like determining a scheduling order across jobs, waking + * up to launch speculative tasks, etc. * * Clients should first call initialize() and start(), then submit task sets through the * runTasks method. * - * THREADING: SchedulerBackends and task-submitting clients can call this class from multiple + * THREADING: [[SchedulerBackend]]s and task-submitting clients can call this class from multiple * threads, so it needs locks in public API methods to maintain its state. In addition, some - * SchedulerBackends synchronize on themselves when they want to send events here, and then + * [[SchedulerBackend]]s synchronize on themselves when they want to send events here, and then * acquire a lock on us, so we need to make sure that we don't try to lock the backend while * we are holding a lock on ourselves. */ @@ -331,6 +332,7 @@ private[spark] class TaskSchedulerImpl( def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) { var failedExecutor: Option[String] = None + var reason: Option[ExecutorLossReason] = None synchronized { try { if (state == TaskState.LOST && taskIdToExecutorId.contains(tid)) { @@ -338,8 +340,9 @@ private[spark] class TaskSchedulerImpl( val execId = taskIdToExecutorId(tid) if (executorIdToTaskCount.contains(execId)) { - removeExecutor(execId, + reason = Some( SlaveLost(s"Task $tid was lost, so marking the executor as lost as well.")) + removeExecutor(execId, reason.get) failedExecutor = Some(execId) } } @@ -372,7 +375,8 @@ private[spark] class TaskSchedulerImpl( } // Update the DAGScheduler without holding a lock on this, since that can deadlock if (failedExecutor.isDefined) { - dagScheduler.executorLost(failedExecutor.get) + assert(reason.isDefined) + dagScheduler.executorLost(failedExecutor.get, reason.get) backend.reviveOffers() } } @@ -384,17 +388,12 @@ private[spark] class TaskSchedulerImpl( */ override def executorHeartbeatReceived( execId: String, - accumUpdates: Array[(Long, Seq[NewAccumulator[_, _]])], + accumUpdates: Array[(Long, Seq[AccumulatorV2[_, _]])], blockManagerId: BlockManagerId): Boolean = { // (taskId, stageId, stageAttemptId, accumUpdates) val accumUpdatesWithTaskIds: Array[(Long, Int, Int, Seq[AccumulableInfo])] = synchronized { accumUpdates.flatMap { case (id, updates) => - // We should call `acc.value` here as we are at driver side now. However, the RPC framework - // optimizes local message delivery so that messages do not need to de serialized and - // deserialized. This brings trouble to the accumulator framework, which depends on - // serialization to set the `atDriverSide` flag. Here we call `acc.localValue` instead to - // be more robust about this issue. 
- val accInfos = updates.map(acc => acc.toInfo(Some(acc.localValue), None)) + val accInfos = updates.map(acc => acc.toInfo(Some(acc.value), None)) taskIdToTaskSetManager.get(id).map { taskSetMgr => (id, taskSetMgr.stageId, taskSetMgr.taskSet.stageAttemptId, accInfos) } @@ -503,7 +502,7 @@ private[spark] class TaskSchedulerImpl( } // Call dagScheduler.executorLost without holding the lock on this to prevent deadlock if (failedExecutor.isDefined) { - dagScheduler.executorLost(failedExecutor.get) + dagScheduler.executorLost(failedExecutor.get, reason) backend.reviveOffers() } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index b79f643a7481f..08d33f688a160 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -32,7 +32,7 @@ import org.apache.spark._ import org.apache.spark.internal.Logging import org.apache.spark.scheduler.SchedulingMode._ import org.apache.spark.TaskState.TaskState -import org.apache.spark.util.{Clock, SystemClock, Utils} +import org.apache.spark.util.{AccumulatorV2, Clock, SystemClock, Utils} /** * Schedules the tasks within a single TaskSet in the TaskSchedulerImpl. This class keeps track of @@ -479,7 +479,7 @@ private[spark] class TaskSetManager( // val timeTaken = clock.getTime() - startTime val taskName = s"task ${info.id} in stage ${taskSet.id}" logInfo(s"Starting $taskName (TID $taskId, $host, partition ${task.partitionId}," + - s"$taskLocality, ${serializedTask.limit} bytes)") + s" $taskLocality, ${serializedTask.limit} bytes)") sched.dagScheduler.taskStarted(task, info) return Some(new TaskDescription(taskId = taskId, attemptNumber = attemptNum, execId, @@ -647,7 +647,7 @@ private[spark] class TaskSetManager( info.markFailed() val index = info.index copiesRunning(index) -= 1 - var accumUpdates: Seq[NewAccumulator[_, _]] = Seq.empty + var accumUpdates: Seq[AccumulatorV2[_, _]] = Seq.empty val failureReason = s"Lost task ${info.id} in stage ${taskSet.id} (TID $tid, ${info.host}): " + reason.asInstanceOf[TaskFailedReason].toErrorString val failureException: Option[Throwable] = reason match { @@ -716,7 +716,16 @@ private[spark] class TaskSetManager( failedExecutors.getOrElseUpdate(index, new HashMap[String, Long]()). 
put(info.executorId, clock.getTimeMillis()) sched.dagScheduler.taskEnded(tasks(index), reason, null, accumUpdates, info) - addPendingTask(index) + + if (successful(index)) { + logInfo( + s"Task ${info.id} in stage ${taskSet.id} (TID $tid) failed, " + + "but another instance of the task has already succeeded, " + + "so not re-queuing the task to be re-executed.") + } else { + addPendingTask(index) + } + if (!isZombie && state != TaskState.KILLED && reason.isInstanceOf[TaskFailedReason] && reason.asInstanceOf[TaskFailedReason].countTowardsTaskFailures) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index 46a829114ec86..edc8aac5d1515 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -40,8 +40,7 @@ private[spark] object CoarseGrainedClusterMessages { sealed trait RegisterExecutorResponse - case class RegisteredExecutor(hostname: String) extends CoarseGrainedClusterMessage - with RegisterExecutorResponse + case object RegisteredExecutor extends CoarseGrainedClusterMessage with RegisterExecutorResponse case class RegisterExecutorFailed(message: String) extends CoarseGrainedClusterMessage with RegisterExecutorResponse @@ -50,6 +49,7 @@ private[spark] object CoarseGrainedClusterMessages { case class RegisterExecutor( executorId: String, executorRef: RpcEndpointRef, + hostname: String, cores: Int, logUrls: Map[String, String]) extends CoarseGrainedClusterMessage diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 8896391f9775f..2c173dbf77727 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -22,6 +22,8 @@ import java.util.concurrent.atomic.AtomicInteger import javax.annotation.concurrent.GuardedBy import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} +import scala.concurrent.Future +import scala.concurrent.duration.Duration import org.apache.spark.{ExecutorAllocationClient, SparkEnv, SparkException, TaskState} import org.apache.spark.internal.Logging @@ -49,6 +51,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp protected val totalRegisteredExecutors = new AtomicInteger(0) protected val conf = scheduler.sc.conf private val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf) + private val defaultAskTimeout = RpcUtils.askRpcTimeout(conf) // Submit tasks only after (registered resources / total expected resources) // is equal to at least this value, that is double between 0 and 1. 
private val _minRegisteredRatio = @@ -100,8 +103,6 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // instance across threads private val ser = SparkEnv.get.closureSerializer.newInstance() - override protected def log = CoarseGrainedSchedulerBackend.this.log - protected val addressToExecutorId = new HashMap[RpcAddress, String] private val reviveThread = @@ -148,7 +149,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case RegisterExecutor(executorId, executorRef, cores, logUrls) => + case RegisterExecutor(executorId, executorRef, hostname, cores, logUrls) => if (executorDataMap.contains(executorId)) { executorRef.send(RegisterExecutorFailed("Duplicate executor ID: " + executorId)) context.reply(true) @@ -164,7 +165,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp addressToExecutorId(executorAddress) = executorId totalCoreCount.addAndGet(cores) totalRegisteredExecutors.addAndGet(1) - val data = new ExecutorData(executorRef, executorRef.address, executorAddress.host, + val data = new ExecutorData(executorRef, executorRef.address, hostname, cores, cores, logUrls) // This must be synchronized because variables mutated // in this block are read when requesting executors @@ -178,7 +179,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp logDebug(s"Decremented number of pending executors ($numPendingExecutors left)") } } - executorRef.send(RegisteredExecutor(executorAddress.host)) + executorRef.send(RegisteredExecutor) // Note: some tests expect the reply to come after we put the executor in the map context.reply(true) listenerBus.post( @@ -274,6 +275,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // Remove a disconnected slave from the cluster private def removeExecutor(executorId: String, reason: ExecutorLossReason): Unit = { + logDebug(s"Asked to remove executor $executorId with reason $reason") executorDataMap.get(executorId) match { case Some(executorInfo) => // This must be synchronized because variables mutated @@ -289,7 +291,14 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp scheduler.executorLost(executorId, if (killed) ExecutorKilled else reason) listenerBus.post( SparkListenerExecutorRemoved(System.currentTimeMillis(), executorId, reason.toString)) - case None => logInfo(s"Asked to remove non-existent executor $executorId") + case None => + // SPARK-15262: If an executor is still alive even after the scheduler has removed + // its metadata, we may receive a heartbeat from that executor and tell its block + // manager to reregister itself. If that happens, the block manager master will know + // about the executor, but the scheduler will not. Therefore, we should remove the + // executor from the block manager when we hit this case. + scheduler.sc.env.blockManager.master.removeExecutorAsync(executorId) + logInfo(s"Asked to remove non-existent executor $executorId") } } @@ -377,15 +386,17 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp * Reset the state of CoarseGrainedSchedulerBackend to the initial state. Currently it will only * be called in the yarn-client mode when AM re-registers after a failure. 
* */ - protected def reset(): Unit = synchronized { - numPendingExecutors = 0 - executorsPendingToRemove.clear() + protected def reset(): Unit = { + val executors = synchronized { + numPendingExecutors = 0 + executorsPendingToRemove.clear() + Set() ++ executorDataMap.keys + } // Remove all the lingering executors that should be removed but not yet. The reason might be // because (1) disconnected event is not yet received; (2) executors die silently. - executorDataMap.toMap.foreach { case (eid, _) => - driverEndpoint.askWithRetry[Boolean]( - RemoveExecutor(eid, SlaveLost("Stale executor after cluster manager re-registered."))) + executors.foreach { eid => + removeExecutor(eid, SlaveLost("Stale executor after cluster manager re-registered.")) } } @@ -401,14 +412,15 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp conf.getInt("spark.default.parallelism", math.max(totalCoreCount.get(), 2)) } - // Called by subclasses when notified of a lost worker - def removeExecutor(executorId: String, reason: ExecutorLossReason) { - try { - driverEndpoint.askWithRetry[Boolean](RemoveExecutor(executorId, reason)) - } catch { - case e: Exception => - throw new SparkException("Error notifying standalone scheduler's driver endpoint", e) - } + /** + * Called by subclasses when notified of a lost worker. It just fires the message and returns + * at once. + */ + protected def removeExecutor(executorId: String, reason: ExecutorLossReason): Unit = { + // Only log the failure since we don't care about the result. + driverEndpoint.ask[Boolean](RemoveExecutor(executorId, reason)).onFailure { case t => + logError(t.getMessage, t) + }(ThreadUtils.sameThread) } def sufficientResourcesRegistered(): Boolean = true @@ -440,19 +452,24 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp * Request an additional number of executors from the cluster manager. * @return whether the request is acknowledged. */ - final override def requestExecutors(numAdditionalExecutors: Int): Boolean = synchronized { + final override def requestExecutors(numAdditionalExecutors: Int): Boolean = { if (numAdditionalExecutors < 0) { throw new IllegalArgumentException( "Attempted to request a negative number of additional executor(s) " + s"$numAdditionalExecutors from the cluster manager. Please specify a positive number!") } logInfo(s"Requesting $numAdditionalExecutors additional executor(s) from the cluster manager") - logDebug(s"Number of pending executors is now $numPendingExecutors") - numPendingExecutors += numAdditionalExecutors - // Account for executors pending to be added or removed - val newTotal = numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size - doRequestTotalExecutors(newTotal) + val response = synchronized { + numPendingExecutors += numAdditionalExecutors + logDebug(s"Number of pending executors is now $numPendingExecutors") + + // Account for executors pending to be added or removed + doRequestTotalExecutors( + numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size) + } + + defaultAskTimeout.awaitResult(response) } /** @@ -473,19 +490,24 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp numExecutors: Int, localityAwareTasks: Int, hostToLocalTaskCount: Map[String, Int] - ): Boolean = synchronized { + ): Boolean = { if (numExecutors < 0) { throw new IllegalArgumentException( "Attempted to request a negative number of executor(s) " + s"$numExecutors from the cluster manager. 
Please specify a positive number!") } - this.localityAwareTasks = localityAwareTasks - this.hostToLocalTaskCount = hostToLocalTaskCount + val response = synchronized { + this.localityAwareTasks = localityAwareTasks + this.hostToLocalTaskCount = hostToLocalTaskCount + + numPendingExecutors = + math.max(numExecutors - numExistingExecutors + executorsPendingToRemove.size, 0) - numPendingExecutors = - math.max(numExecutors - numExistingExecutors + executorsPendingToRemove.size, 0) - doRequestTotalExecutors(numExecutors) + doRequestTotalExecutors(numExecutors) + } + + defaultAskTimeout.awaitResult(response) } /** @@ -498,16 +520,17 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp * insufficient resources to satisfy the first request. We make the assumption here that the * cluster manager will eventually fulfill all requests when resources free up. * - * @return whether the request is acknowledged. + * @return a future whose evaluation indicates whether the request is acknowledged. */ - protected def doRequestTotalExecutors(requestedTotal: Int): Boolean = false + protected def doRequestTotalExecutors(requestedTotal: Int): Future[Boolean] = + Future.successful(false) /** * Request that the cluster manager kill the specified executors. * @return whether the kill request is acknowledged. If list to kill is empty, it will return * false. */ - final override def killExecutors(executorIds: Seq[String]): Boolean = synchronized { + final override def killExecutors(executorIds: Seq[String]): Boolean = { killExecutors(executorIds, replace = false, force = false) } @@ -527,39 +550,53 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp final def killExecutors( executorIds: Seq[String], replace: Boolean, - force: Boolean): Boolean = synchronized { + force: Boolean): Boolean = { logInfo(s"Requesting to kill executor(s) ${executorIds.mkString(", ")}") - val (knownExecutors, unknownExecutors) = executorIds.partition(executorDataMap.contains) - unknownExecutors.foreach { id => - logWarning(s"Executor to kill $id does not exist!") - } - // If an executor is already pending to be removed, do not kill it again (SPARK-9795) - // If this executor is busy, do not kill it unless we are told to force kill it (SPARK-9552) - val executorsToKill = knownExecutors - .filter { id => !executorsPendingToRemove.contains(id) } - .filter { id => force || !scheduler.isExecutorBusy(id) } - executorsToKill.foreach { id => executorsPendingToRemove(id) = !replace } - - // If we do not wish to replace the executors we kill, sync the target number of executors - // with the cluster manager to avoid allocating new ones. When computing the new target, - // take into account executors that are pending to be added or removed. 
- if (!replace) { - doRequestTotalExecutors( - numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size) - } else { - numPendingExecutors += knownExecutors.size + val response = synchronized { + val (knownExecutors, unknownExecutors) = executorIds.partition(executorDataMap.contains) + unknownExecutors.foreach { id => + logWarning(s"Executor to kill $id does not exist!") + } + + // If an executor is already pending to be removed, do not kill it again (SPARK-9795) + // If this executor is busy, do not kill it unless we are told to force kill it (SPARK-9552) + val executorsToKill = knownExecutors + .filter { id => !executorsPendingToRemove.contains(id) } + .filter { id => force || !scheduler.isExecutorBusy(id) } + executorsToKill.foreach { id => executorsPendingToRemove(id) = !replace } + + // If we do not wish to replace the executors we kill, sync the target number of executors + // with the cluster manager to avoid allocating new ones. When computing the new target, + // take into account executors that are pending to be added or removed. + val adjustTotalExecutors = + if (!replace) { + doRequestTotalExecutors( + numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size) + } else { + numPendingExecutors += knownExecutors.size + Future.successful(true) + } + + val killExecutors: Boolean => Future[Boolean] = + if (!executorsToKill.isEmpty) { + _ => doKillExecutors(executorsToKill) + } else { + _ => Future.successful(false) + } + + adjustTotalExecutors.flatMap(killExecutors)(ThreadUtils.sameThread) } - !executorsToKill.isEmpty && doKillExecutors(executorsToKill) + defaultAskTimeout.awaitResult(response) } /** * Kill the given list of executors through the cluster manager. * @return whether the kill request is acknowledged. */ - protected def doKillExecutors(executorIds: Seq[String]): Boolean = false - + protected def doKillExecutors(executorIds: Seq[String]): Future[Boolean] = + Future.successful(false) } private[spark] object CoarseGrainedSchedulerBackend { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala similarity index 90% rename from core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala rename to core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index 85d002011d64c..04d40e2907cff 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -19,30 +19,35 @@ package org.apache.spark.scheduler.cluster import java.util.concurrent.Semaphore +import scala.concurrent.Future + import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.deploy.{ApplicationDescription, Command} -import org.apache.spark.deploy.client.{AppClient, AppClientListener} +import org.apache.spark.deploy.client.{StandaloneAppClient, StandaloneAppClientListener} import org.apache.spark.internal.Logging import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle} import org.apache.spark.rpc.RpcEndpointAddress import org.apache.spark.scheduler._ import org.apache.spark.util.Utils -private[spark] class SparkDeploySchedulerBackend( +/** + * A [[SchedulerBackend]] implementation for Spark's standalone cluster manager. 
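The `requestExecutors`, `requestTotalExecutors`, and `killExecutors` rewrites above all share one shape: book-keeping is updated and the asynchronous cluster-manager call is started inside `synchronized`, and the resulting `Future[Boolean]` is awaited outside the lock, so the RPC round-trip never holds the backend's monitor. A condensed, self-contained sketch of that shape (names are mine, not Spark's):

```scala
import scala.concurrent.{Await, Future}
import scala.concurrent.duration._

// Mutate state and kick off the async request under the lock; block on the reply outside it.
class BackendSketch(askTimeout: FiniteDuration = 120.seconds) {
  private var pending = 0

  // Cluster-manager specific; the default declines, mirroring doRequestTotalExecutors above.
  protected def doRequestTotal(total: Int): Future[Boolean] = Future.successful(false)

  def requestExecutors(numAdditional: Int): Boolean = {
    require(numAdditional >= 0, "cannot request a negative number of executors")
    val response = synchronized {
      pending += numAdditional
      doRequestTotal(pending)
    }
    Await.result(response, askTimeout)   // the await happens outside the monitor
  }
}
```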
+ */ +private[spark] class StandaloneSchedulerBackend( scheduler: TaskSchedulerImpl, sc: SparkContext, masters: Array[String]) extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) - with AppClientListener + with StandaloneAppClientListener with Logging { - private var client: AppClient = null + private var client: StandaloneAppClient = null private var stopping = false private val launcherBackend = new LauncherBackend() { override protected def onStopRequest(): Unit = stop(SparkAppHandle.State.KILLED) } - @volatile var shutdownCallback: SparkDeploySchedulerBackend => Unit = _ + @volatile var shutdownCallback: StandaloneSchedulerBackend => Unit = _ @volatile private var appId: String = _ private val registrationBarrier = new Semaphore(0) @@ -100,7 +105,7 @@ private[spark] class SparkDeploySchedulerBackend( } val appDesc = new ApplicationDescription(sc.appName, maxCores, sc.executorMemory, command, appUIAddress, sc.eventLogDir, sc.eventLogCodec, coresPerExecutor, initialExecutorLimit) - client = new AppClient(sc.env.rpcEnv, masters, appDesc, this, conf) + client = new StandaloneAppClient(sc.env.rpcEnv, masters, appDesc, this, conf) client.start() launcherBackend.setState(SparkAppHandle.State.SUBMITTED) waitForRegistration() @@ -145,10 +150,11 @@ private[spark] class SparkDeploySchedulerBackend( fullId, hostPort, cores, Utils.megabytesToString(memory))) } - override def executorRemoved(fullId: String, message: String, exitStatus: Option[Int]) { + override def executorRemoved( + fullId: String, message: String, exitStatus: Option[Int], workerLost: Boolean) { val reason: ExecutorLossReason = exitStatus match { case Some(code) => ExecutorExited(code, exitCausedByApp = true, message) - case None => SlaveLost(message) + case None => SlaveLost(message, workerLost = workerLost) } logInfo("Executor %s removed: %s".format(fullId, message)) removeExecutor(fullId.split("/")(1), reason) @@ -170,12 +176,12 @@ private[spark] class SparkDeploySchedulerBackend( * * @return whether the request is acknowledged. */ - protected override def doRequestTotalExecutors(requestedTotal: Int): Boolean = { + protected override def doRequestTotalExecutors(requestedTotal: Int): Future[Boolean] = { Option(client) match { case Some(c) => c.requestTotalExecutors(requestedTotal) case None => logWarning("Attempted to request executors before driver fully initialized.") - false + Future.successful(false) } } @@ -183,12 +189,12 @@ private[spark] class SparkDeploySchedulerBackend( * Kill the given list of executors through the Master. * @return whether the kill request is acknowledged. 
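Both standalone overrides guard against a not-yet-initialized `client` with `Option(...)` and degrade to an already-failed acknowledgement. A reduced sketch of that guard (the stub client is hypothetical):

```scala
import scala.concurrent.Future

// Hypothetical stand-in for StandaloneAppClient, reduced to a single call.
class AppClientStub {
  def requestTotalExecutors(total: Int): Future[Boolean] = Future.successful(true)
}

class DriverNotReadySketch {
  // Initialized later in start(), so it can legitimately be null when an
  // early request arrives.
  @volatile private var client: AppClientStub = null

  def doRequestTotalExecutors(requestedTotal: Int): Future[Boolean] =
    Option(client) match {
      case Some(c) => c.requestTotalExecutors(requestedTotal)
      case None =>
        // Fail softly, as the patch does, instead of throwing an NPE.
        Future.successful(false)
    }
}
```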
*/ - protected override def doKillExecutors(executorIds: Seq[String]): Boolean = { + protected override def doKillExecutors(executorIds: Seq[String]): Future[Boolean] = { Option(client) match { case Some(c) => c.killExecutors(executorIds) case None => logWarning("Attempted to kill executors before driver fully initialized.") - false + Future.successful(false) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala similarity index 84% rename from core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala rename to core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index 50b452c72f8aa..473b1be4e20e7 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -24,14 +24,14 @@ import java.util.concurrent.locks.ReentrantLock import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.mutable.{Buffer, HashMap, HashSet} +import scala.concurrent.Future -import org.apache.mesos.{Scheduler => MScheduler, SchedulerDriver} import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} import org.apache.spark.{SecurityManager, SparkContext, SparkException, TaskState} import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.mesos.MesosExternalShuffleClient -import org.apache.spark.rpc.{RpcEndpointAddress} +import org.apache.spark.rpc.RpcEndpointAddress import org.apache.spark.scheduler.{SlaveLost, TaskSchedulerImpl} import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend import org.apache.spark.util.Utils @@ -43,16 +43,16 @@ import org.apache.spark.util.Utils * CoarseGrainedSchedulerBackend mechanism. This class is useful for lower and more predictable * latency. * - * Unfortunately this has a bit of duplication from MesosSchedulerBackend, but it seems hard to - * remove this. + * Unfortunately this has a bit of duplication from [[MesosFineGrainedSchedulerBackend]], + * but it seems hard to remove this. 
*/ -private[spark] class CoarseMesosSchedulerBackend( +private[spark] class MesosCoarseGrainedSchedulerBackend( scheduler: TaskSchedulerImpl, sc: SparkContext, master: String, securityManager: SecurityManager) extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) - with MScheduler + with org.apache.mesos.Scheduler with MesosSchedulerUtils { val MAX_SLAVE_FAILURES = 2 // Blacklist a slave after this many failures @@ -109,10 +109,14 @@ private[spark] class CoarseMesosSchedulerBackend( private val slaveOfferConstraints = parseConstraintString(sc.conf.get("spark.mesos.constraints", "")) - // reject offers with mismatched constraints in seconds + // Reject offers with mismatched constraints in seconds private val rejectOfferDurationForUnmetConstraints = getRejectOfferDurationForUnmetConstraints(sc) + // Reject offers when we reached the maximum number of cores for this framework + private val rejectOfferDurationForReachedMaxCores = + getRejectOfferDurationForReachedMaxCores(sc) + // A client for talking to the external shuffle service private val mesosExternalShuffleClient: Option[MesosExternalShuffleClient] = { if (shuffleServiceEnabled) { @@ -145,7 +149,7 @@ private[spark] class CoarseMesosSchedulerBackend( super.start() val driver = createSchedulerDriver( master, - CoarseMesosSchedulerBackend.this, + MesosCoarseGrainedSchedulerBackend.this, sc.sparkUser, sc.appName, sc.conf, @@ -235,12 +239,12 @@ private[spark] class CoarseMesosSchedulerBackend( } } - override def offerRescinded(d: SchedulerDriver, o: OfferID) {} + override def offerRescinded(d: org.apache.mesos.SchedulerDriver, o: OfferID) {} - override def registered(d: SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) { + override def registered( + d: org.apache.mesos.SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) { appId = frameworkId.getValue mesosExternalShuffleClient.foreach(_.init(appId)) - logInfo("Registered as framework ID " + appId) markRegistered() } @@ -248,15 +252,15 @@ private[spark] class CoarseMesosSchedulerBackend( totalCoresAcquired >= maxCores * minRegisteredRatio } - override def disconnected(d: SchedulerDriver) {} + override def disconnected(d: org.apache.mesos.SchedulerDriver) {} - override def reregistered(d: SchedulerDriver, masterInfo: MasterInfo) {} + override def reregistered(d: org.apache.mesos.SchedulerDriver, masterInfo: MasterInfo) {} /** * Method called by Mesos to offer resources on slaves. We respond by launching an executor, * unless we've already launched more than we wanted to. 
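The new `rejectOfferDurationForReachedMaxCores` knob is read as a time string. A hedged sketch of how such a setting would be set and parsed via `SparkConf.getTimeAsSeconds` (the `300s` override is illustrative; `120s` mirrors the default used by the getter added to `MesosSchedulerUtils` later in this patch):

```scala
import org.apache.spark.SparkConf

// Sketch of setting and reading the new knob; "300s" is an illustrative
// override and "120s" mirrors the default used by the new getter.
object RejectOfferConfigSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .set("spark.mesos.rejectOfferDurationForReachedMaxCores", "300s")

    // getTimeAsSeconds accepts suffixed durations such as "300s" or "5m".
    val rejectSeconds = conf.getTimeAsSeconds(
      "spark.mesos.rejectOfferDurationForReachedMaxCores", "120s")
    println(s"Offers are refused for $rejectSeconds seconds once max cores is reached")
  }
}
```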
*/ - override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) { + override def resourceOffers(d: org.apache.mesos.SchedulerDriver, offers: JList[Offer]) { stateLock.synchronized { if (stopCalled) { logDebug("Ignoring offers during shutdown") @@ -278,19 +282,34 @@ private[spark] class CoarseMesosSchedulerBackend( } } - private def declineUnmatchedOffers(d: SchedulerDriver, offers: Buffer[Offer]): Unit = { - for (offer <- offers) { - val id = offer.getId.getValue - val offerAttributes = toAttributeMap(offer.getAttributesList) - val mem = getResource(offer.getResourcesList, "mem") - val cpus = getResource(offer.getResourcesList, "cpus") - val filters = Filters.newBuilder() - .setRefuseSeconds(rejectOfferDurationForUnmetConstraints).build() - - logDebug(s"Declining offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus" - + s" for $rejectOfferDurationForUnmetConstraints seconds") + private def declineUnmatchedOffers( + d: org.apache.mesos.SchedulerDriver, offers: Buffer[Offer]): Unit = { + offers.foreach { offer => + declineOffer(d, offer, Some("unmet constraints"), + Some(rejectOfferDurationForUnmetConstraints)) + } + } - d.declineOffer(offer.getId, filters) + private def declineOffer( + d: org.apache.mesos.SchedulerDriver, + offer: Offer, + reason: Option[String] = None, + refuseSeconds: Option[Long] = None): Unit = { + + val id = offer.getId.getValue + val offerAttributes = toAttributeMap(offer.getAttributesList) + val mem = getResource(offer.getResourcesList, "mem") + val cpus = getResource(offer.getResourcesList, "cpus") + + logDebug(s"Declining offer: $id with attributes: $offerAttributes mem: $mem" + + s" cpu: $cpus for $refuseSeconds seconds" + + reason.map(r => s" (reason: $r)").getOrElse("")) + + refuseSeconds match { + case Some(seconds) => + val filters = Filters.newBuilder().setRefuseSeconds(seconds).build() + d.declineOffer(offer.getId, filters) + case _ => d.declineOffer(offer.getId) } } @@ -301,7 +320,8 @@ private[spark] class CoarseMesosSchedulerBackend( * @param d SchedulerDriver * @param offers Mesos offers that match attribute constraints */ - private def handleMatchedOffers(d: SchedulerDriver, offers: Buffer[Offer]): Unit = { + private def handleMatchedOffers( + d: org.apache.mesos.SchedulerDriver, offers: Buffer[Offer]): Unit = { val tasks = buildMesosTasks(offers) for (offer <- offers) { val offerAttributes = toAttributeMap(offer.getAttributesList) @@ -326,11 +346,12 @@ private[spark] class CoarseMesosSchedulerBackend( d.launchTasks( Collections.singleton(offer.getId), offerTasks.asJava) - } else { // decline - logDebug(s"Declining offer: $id with attributes: $offerAttributes " + - s"mem: $offerMem cpu: $offerCpus") - - d.declineOffer(offer.getId) + } else if (totalCoresAcquired >= maxCores) { + // Reject an offer for a configurable amount of time to avoid starving other frameworks + declineOffer(d, offer, Some("reached spark.cores.max"), + Some(rejectOfferDurationForReachedMaxCores)) + } else { + declineOffer(d, offer) } } } @@ -421,7 +442,7 @@ private[spark] class CoarseMesosSchedulerBackend( math.min(offerCPUs, maxCores - totalCoresAcquired)) } - override def statusUpdate(d: SchedulerDriver, status: TaskStatus) { + override def statusUpdate(d: org.apache.mesos.SchedulerDriver, status: TaskStatus) { val taskId = status.getTaskId.getValue val slaveId = status.getSlaveId.getValue val state = TaskState.fromMesos(status.getState) @@ -479,7 +500,7 @@ private[spark] class CoarseMesosSchedulerBackend( } } - override def error(d: SchedulerDriver, 
message: String) { + override def error(d: org.apache.mesos.SchedulerDriver, message: String) { logError(s"Mesos error: $message") scheduler.error(message) } @@ -519,28 +540,35 @@ private[spark] class CoarseMesosSchedulerBackend( } } - override def frameworkMessage(d: SchedulerDriver, e: ExecutorID, s: SlaveID, b: Array[Byte]) {} + override def frameworkMessage( + d: org.apache.mesos.SchedulerDriver, e: ExecutorID, s: SlaveID, b: Array[Byte]) {} /** * Called when a slave is lost or a Mesos task finishes. Updates local view on * what tasks are running. It also notifies the driver that an executor was removed. */ private def executorTerminated( - d: SchedulerDriver, + d: org.apache.mesos.SchedulerDriver, slaveId: String, taskId: String, reason: String): Unit = { stateLock.synchronized { - removeExecutor(taskId, SlaveLost(reason)) + // Do not call removeExecutor() after this scheduler backend has stopped: removeExecutor() + // internally sends a message to the driver endpoint, which is no longer available at that + // point, so the call would throw an exception. + if (!stopCalled) { + removeExecutor(taskId, SlaveLost(reason)) + } slaves(slaveId).taskIDs.remove(taskId) } } - override def slaveLost(d: SchedulerDriver, slaveId: SlaveID): Unit = { + override def slaveLost(d: org.apache.mesos.SchedulerDriver, slaveId: SlaveID): Unit = { logInfo(s"Mesos slave lost: ${slaveId.getValue}") } - override def executorLost(d: SchedulerDriver, e: ExecutorID, s: SlaveID, status: Int): Unit = { + override def executorLost( + d: org.apache.mesos.SchedulerDriver, e: ExecutorID, s: SlaveID, status: Int): Unit = { logInfo("Mesos executor lost: %s".format(e.getValue)) } @@ -550,7 +578,7 @@ private[spark] class CoarseMesosSchedulerBackend( super.applicationId } - override def doRequestTotalExecutors(requestedTotal: Int): Boolean = { + override def doRequestTotalExecutors(requestedTotal: Int): Future[Boolean] = Future.successful { // We don't truly know if we can fulfill the full amount of executors // since at coarse grain it depends on the amount of slaves available.
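`doRequestTotalExecutors` above keeps its synchronous body and simply wraps the result with `Future.successful { ... }`. As the following sketch (not Spark code) demonstrates, the block is evaluated eagerly on the calling thread and the returned future is already completed, so the override stays synchronous under the new signature:

```scala
import scala.concurrent.Future

// Not Spark code: demonstrates that Future.successful { ... } evaluates its
// argument eagerly on the calling thread and returns an already-completed
// Future.
object EagerFutureSketch {
  @volatile private var executorLimit = Int.MaxValue

  def doRequestTotalExecutors(requestedTotal: Int): Future[Boolean] =
    Future.successful {
      executorLimit = requestedTotal // runs immediately, no ExecutionContext
      true
    }

  def main(args: Array[String]): Unit = {
    val acknowledged = doRequestTotalExecutors(8)
    println(s"already completed: ${acknowledged.isCompleted}, limit: $executorLimit")
  }
}
```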
logInfo("Capping the total amount of executors to " + requestedTotal) @@ -558,7 +586,7 @@ private[spark] class CoarseMesosSchedulerBackend( true } - override def doKillExecutors(executorIds: Seq[String]): Boolean = { + override def doKillExecutors(executorIds: Seq[String]): Future[Boolean] = Future.successful { if (mesosDriver == null) { logWarning("Asked to kill executors before the Mesos driver was started.") false diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala similarity index 93% rename from core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala rename to core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala index 1a94aee2ca30c..e08dc3b5957bb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala @@ -23,7 +23,6 @@ import java.util.{ArrayList => JArrayList, Collections, List => JList} import scala.collection.JavaConverters._ import scala.collection.mutable.{HashMap, HashSet} -import org.apache.mesos.{Scheduler => MScheduler, _} import org.apache.mesos.Protos.{ExecutorInfo => MesosExecutorInfo, TaskInfo => MesosTaskInfo, _} import org.apache.mesos.protobuf.ByteString @@ -38,12 +37,12 @@ import org.apache.spark.util.Utils * separate Mesos task, allowing multiple applications to share cluster nodes both in space (tasks * from multiple apps can run on different cores) and in time (a core can switch ownership). */ -private[spark] class MesosSchedulerBackend( +private[spark] class MesosFineGrainedSchedulerBackend( scheduler: TaskSchedulerImpl, sc: SparkContext, master: String) extends SchedulerBackend - with MScheduler + with org.apache.mesos.Scheduler with MesosSchedulerUtils { // Stores the slave ids that has launched a Mesos executor. @@ -74,7 +73,7 @@ private[spark] class MesosSchedulerBackend( classLoader = Thread.currentThread.getContextClassLoader val driver = createSchedulerDriver( master, - MesosSchedulerBackend.this, + MesosFineGrainedSchedulerBackend.this, sc.sparkUser, sc.appName, sc.conf, @@ -175,9 +174,10 @@ private[spark] class MesosSchedulerBackend( execArgs } - override def offerRescinded(d: SchedulerDriver, o: OfferID) {} + override def offerRescinded(d: org.apache.mesos.SchedulerDriver, o: OfferID) {} - override def registered(d: SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) { + override def registered( + d: org.apache.mesos.SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) { inClassLoader() { appId = frameworkId.getValue logInfo("Registered as framework ID " + appId) @@ -195,9 +195,9 @@ private[spark] class MesosSchedulerBackend( } } - override def disconnected(d: SchedulerDriver) {} + override def disconnected(d: org.apache.mesos.SchedulerDriver) {} - override def reregistered(d: SchedulerDriver, masterInfo: MasterInfo) {} + override def reregistered(d: org.apache.mesos.SchedulerDriver, masterInfo: MasterInfo) {} private def getTasksSummary(tasks: JArrayList[MesosTaskInfo]): String = { val builder = new StringBuilder @@ -216,7 +216,7 @@ private[spark] class MesosSchedulerBackend( * for tasks in order of priority. We fill each node with tasks in a round-robin manner so that * tasks are balanced across the cluster. 
*/ - override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) { + override def resourceOffers(d: org.apache.mesos.SchedulerDriver, offers: JList[Offer]) { inClassLoader() { // Fail first on offers with unmet constraints val (offersMatchingConstraints, offersNotMatchingConstraints) = @@ -355,7 +355,7 @@ private[spark] class MesosSchedulerBackend( (taskInfo, finalResources.asJava) } - override def statusUpdate(d: SchedulerDriver, status: TaskStatus) { + override def statusUpdate(d: org.apache.mesos.SchedulerDriver, status: TaskStatus) { inClassLoader() { val tid = status.getTaskId.getValue.toLong val state = TaskState.fromMesos(status.getState) @@ -373,7 +373,7 @@ } } - override def error(d: SchedulerDriver, message: String) { + override def error(d: org.apache.mesos.SchedulerDriver, message: String) { inClassLoader() { logError("Mesos error: " + message) markErr() @@ -391,7 +391,8 @@ mesosDriver.reviveOffers() } - override def frameworkMessage(d: SchedulerDriver, e: ExecutorID, s: SlaveID, b: Array[Byte]) {} + override def frameworkMessage( + d: org.apache.mesos.SchedulerDriver, e: ExecutorID, s: SlaveID, b: Array[Byte]) {} /** * Remove executor associated with slaveId in a thread safe manner. @@ -403,7 +404,8 @@ } } - private def recordSlaveLost(d: SchedulerDriver, slaveId: SlaveID, reason: ExecutorLossReason) { + private def recordSlaveLost( + d: org.apache.mesos.SchedulerDriver, slaveId: SlaveID, reason: ExecutorLossReason) { inClassLoader() { logInfo("Mesos slave lost: " + slaveId.getValue) removeExecutor(slaveId.getValue, reason.toString) @@ -411,12 +413,12 @@ } } - override def slaveLost(d: SchedulerDriver, slaveId: SlaveID) { + override def slaveLost(d: org.apache.mesos.SchedulerDriver, slaveId: SlaveID) { recordSlaveLost(d, slaveId, SlaveLost()) } - override def executorLost(d: SchedulerDriver, executorId: ExecutorID, - slaveId: SlaveID, status: Int) { + override def executorLost( + d: org.apache.mesos.SchedulerDriver, executorId: ExecutorID, slaveId: SlaveID, status: Int) { logInfo("Executor lost: %s, marking slave %s as lost".format(executorId.getValue, slaveId.getValue)) recordSlaveLost(d, slaveId, ExecutorExited(status, exitCausedByApp = true)) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala index 1b7ac172defb9..05b2b08944098 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala @@ -25,7 +25,7 @@ import org.apache.spark.internal.Logging /** * A collection of utility functions which can be used by both the - * MesosSchedulerBackend and the CoarseMesosSchedulerBackend. + * [[MesosCoarseGrainedSchedulerBackend]] and the [[MesosFineGrainedSchedulerBackend]].
*/ private[mesos] object MesosSchedulerBackendUtil extends Logging { /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index 1e322ac679419..7355ba317d9a0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -352,4 +352,8 @@ private[mesos] trait MesosSchedulerUtils extends Logging { sc.conf.getTimeAsSeconds("spark.mesos.rejectOfferDurationForUnmetConstraints", "120s") } + protected def getRejectOfferDurationForReachedMaxCores(sc: SparkContext): Long = { + sc.conf.getTimeAsSeconds("spark.mesos.rejectOfferDurationForReachedMaxCores", "120s") + } + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala similarity index 87% rename from core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala rename to core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala index 3473ef21b39a4..e386052814039 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala @@ -39,15 +39,15 @@ private case class KillTask(taskId: Long, interruptThread: Boolean) private case class StopExecutor() /** - * Calls to LocalBackend are all serialized through LocalEndpoint. Using an RpcEndpoint makes the - * calls on LocalBackend asynchronous, which is necessary to prevent deadlock between LocalBackend - * and the TaskSchedulerImpl. + * Calls to [[LocalSchedulerBackend]] are all serialized through LocalEndpoint. Using an + * RpcEndpoint makes the calls on [[LocalSchedulerBackend]] asynchronous, which is necessary + * to prevent deadlock between [[LocalSchedulerBackend]] and the [[TaskSchedulerImpl]]. */ private[spark] class LocalEndpoint( override val rpcEnv: RpcEnv, userClassPath: Seq[URL], scheduler: TaskSchedulerImpl, - executorBackend: LocalBackend, + executorBackend: LocalSchedulerBackend, private val totalCores: Int) extends ThreadSafeRpcEndpoint with Logging { @@ -91,11 +91,11 @@ private[spark] class LocalEndpoint( } /** - * LocalBackend is used when running a local version of Spark where the executor, backend, and - * master all run in the same JVM. It sits behind a TaskSchedulerImpl and handles launching tasks - * on a single Executor (created by the LocalBackend) running locally. + * Used when running a local version of Spark where the executor, backend, and master all run in + * the same JVM. It sits behind a [[TaskSchedulerImpl]] and handles launching tasks on a single + * Executor (created by the [[LocalSchedulerBackend]]) running locally. 
*/ -private[spark] class LocalBackend( +private[spark] class LocalSchedulerBackend( conf: SparkConf, scheduler: TaskSchedulerImpl, val totalCores: Int) @@ -124,7 +124,7 @@ private[spark] class LocalBackend( override def start() { val rpcEnv = SparkEnv.get.rpcEnv val executorEndpoint = new LocalEndpoint(rpcEnv, userClassPath, scheduler, this, totalCores) - localEndpoint = rpcEnv.setupEndpoint("LocalBackendEndpoint", executorEndpoint) + localEndpoint = rpcEnv.setupEndpoint("LocalSchedulerBackendEndpoint", executorEndpoint) listenerBus.post(SparkListenerExecutorAdded( System.currentTimeMillis, executorEndpoint.localExecutorId, diff --git a/core/src/main/scala/org/apache/spark/security/GroupMappingServiceProvider.scala b/core/src/main/scala/org/apache/spark/security/GroupMappingServiceProvider.scala new file mode 100644 index 0000000000000..ea047a4f75d55 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/security/GroupMappingServiceProvider.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.security + +/** + * This Spark trait is used for mapping a given userName to a set of groups which it belongs to. + * This is useful for specifying a common group of admins/developers to provide them admin, modify + * and/or view access rights. Based on whether access control checks are enabled using + * spark.acls.enable, every time a user tries to access or modify the application, the + * SecurityManager gets the corresponding groups a user belongs to from the instance of the groups + * mapping provider specified by the entry spark.user.groups.mapping. + */ + +trait GroupMappingServiceProvider { + + /** + * Get the groups the user belongs to. + * @param userName User's Name + * @return set of groups that the user belongs to. Empty in case of an invalid user. + */ + def getGroups(userName : String) : Set[String] + +} diff --git a/core/src/main/scala/org/apache/spark/security/ShellBasedGroupsMappingProvider.scala b/core/src/main/scala/org/apache/spark/security/ShellBasedGroupsMappingProvider.scala new file mode 100644 index 0000000000000..f71dd08246b2f --- /dev/null +++ b/core/src/main/scala/org/apache/spark/security/ShellBasedGroupsMappingProvider.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.security + +import org.apache.spark.internal.Logging +import org.apache.spark.util.Utils + +/** + * This class is responsible for getting the groups for a particular user in Unix based + * environments. This implementation uses the Unix Shell based id command to fetch the user groups + * for the specified user. It does not cache the user groups as the invocations are expected + * to be infrequent. + */ + +private[spark] class ShellBasedGroupsMappingProvider extends GroupMappingServiceProvider + with Logging { + + override def getGroups(username: String): Set[String] = { + val userGroups = getUnixGroups(username) + logDebug("User: " + username + " Groups: " + userGroups.mkString(",")) + userGroups + } + + // shells out a "bash -c id -Gn username" to get user groups + private def getUnixGroups(username: String): Set[String] = { + val cmdSeq = Seq("bash", "-c", "id -Gn " + username) + // we need to get rid of the trailing "\n" from the result of command execution + Utils.executeAndGetOutput(cmdSeq).stripLineEnd.split(" ").toSet + } +} diff --git a/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala index d17a7894fd8a8..f0ed41f6903f4 100644 --- a/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala @@ -32,6 +32,7 @@ import org.apache.commons.io.IOUtils import org.apache.spark.{SparkEnv, SparkException} import org.apache.spark.io.CompressionCodec +import org.apache.spark.util.Utils /** * Custom serializer used for generic Avro records. 
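The `ShellBasedGroupsMappingProvider` above is only one implementation of the new `GroupMappingServiceProvider` trait; any class with a no-arg constructor can be plugged in through `spark.user.groups.mapping`. A hypothetical static provider, useful for tests or environments without a usable `id` command:

```scala
package org.apache.spark.security

// A hypothetical provider: returns a fixed mapping instead of shelling out to
// `id -Gn`. Enable it with spark.acls.enable=true and
// spark.user.groups.mapping=org.apache.spark.security.StaticGroupsMappingProvider.
class StaticGroupsMappingProvider extends GroupMappingServiceProvider {

  private val groupsByUser: Map[String, Set[String]] = Map(
    "alice" -> Set("admins", "developers"),
    "bob" -> Set("developers"))

  // Empty set for unknown users, matching the trait's documented contract.
  override def getGroups(userName: String): Set[String] =
    groupsByUser.getOrElse(userName, Set.empty)
}
```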
If the user registers the schemas @@ -72,8 +73,11 @@ private[serializer] class GenericAvroSerializer(schemas: Map[Long, String]) def compress(schema: Schema): Array[Byte] = compressCache.getOrElseUpdate(schema, { val bos = new ByteArrayOutputStream() val out = codec.compressedOutputStream(bos) - out.write(schema.toString.getBytes(StandardCharsets.UTF_8)) - out.close() + Utils.tryWithSafeFinally { + out.write(schema.toString.getBytes(StandardCharsets.UTF_8)) + } { + out.close() + } bos.toByteArray }) @@ -86,7 +90,12 @@ private[serializer] class GenericAvroSerializer(schemas: Map[Long, String]) schemaBytes.array(), schemaBytes.arrayOffset() + schemaBytes.position(), schemaBytes.remaining()) - val bytes = IOUtils.toByteArray(codec.compressedInputStream(bis)) + val in = codec.compressedInputStream(bis) + val bytes = Utils.tryWithSafeFinally { + IOUtils.toByteArray(in) + } { + in.close() + } new Schema.Parser().parse(new String(bytes, StandardCharsets.UTF_8)) }) diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala index c04b483831704..5e7a98c8aa89c 100644 --- a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala +++ b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala @@ -155,7 +155,7 @@ private[spark] object SerializationDebugger extends Logging { // If the object has been replaced using writeReplace(), // then call visit() on it again to test its type again. - if (!finalObj.eq(o)) { + if (finalObj.getClass != o.getClass) { return visit(finalObj, s"writeReplace data (class: ${finalObj.getClass.getName})" :: stack) } @@ -265,11 +265,10 @@ private[spark] object SerializationDebugger extends Logging { if (!desc.hasWriteReplaceMethod) { (o, desc) } else { - // write place val replaced = desc.invokeWriteReplace(o) - // `writeReplace` may return the same object. - if (replaced eq o) { - (o, desc) + // `writeReplace` recursion stops when the returned object has the same class. + if (replaced.getClass == o.getClass) { + (replaced, desc) } else { findObjectAndDescriptor(replaced) } diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala index 745ef126913f5..59bdc88464a84 100644 --- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala +++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala @@ -68,7 +68,7 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar * loaded yet. */ private lazy val compressionCodec: CompressionCodec = CompressionCodec.createCodec(conf) - private def canUseKryo(ct: ClassTag[_]): Boolean = { + def canUseKryo(ct: ClassTag[_]): Boolean = { primitiveAndPrimitiveArrayClassTags.contains(ct) || ct == stringClassTag } @@ -128,20 +128,31 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar /** Serializes into a chunked byte buffer. */ def dataSerialize[T: ClassTag](blockId: BlockId, values: Iterator[T]): ChunkedByteBuffer = { + dataSerializeWithExplicitClassTag(blockId, values, implicitly[ClassTag[T]]) + } + + /** Serializes into a chunked byte buffer. 
*/ + def dataSerializeWithExplicitClassTag( + blockId: BlockId, + values: Iterator[_], + classTag: ClassTag[_]): ChunkedByteBuffer = { val bbos = new ChunkedByteBufferOutputStream(1024 * 1024 * 4, ByteBuffer.allocate) - dataSerializeStream(blockId, bbos, values) + val byteStream = new BufferedOutputStream(bbos) + val ser = getSerializer(classTag).newInstance() + ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close() bbos.toChunkedByteBuffer } /** - * Deserializes a InputStream into an iterator of values and disposes of it when the end of + * Deserializes an InputStream into an iterator of values and disposes of it when the end of * the iterator is reached. */ - def dataDeserializeStream[T: ClassTag]( + def dataDeserializeStream[T]( blockId: BlockId, - inputStream: InputStream): Iterator[T] = { + inputStream: InputStream) + (classTag: ClassTag[T]): Iterator[T] = { val stream = new BufferedInputStream(inputStream) - getSerializer(implicitly[ClassTag[T]]) + getSerializer(classTag) .newInstance() .deserializeStream(wrapForCompression(blockId, stream)) .asIterator.asInstanceOf[Iterator[T]] diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index 94d8c0d0fd3e4..8d6396bededa9 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -139,48 +139,54 @@ private[spark] class IndexShuffleBlockResolver( dataTmp: File): Unit = { val indexFile = getIndexFile(shuffleId, mapId) val indexTmp = Utils.tempFileWith(indexFile) - val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexTmp))) - Utils.tryWithSafeFinally { - // We take in lengths of each block, need to convert it to offsets. - var offset = 0L - out.writeLong(offset) - for (length <- lengths) { - offset += length + try { + val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexTmp))) + Utils.tryWithSafeFinally { + // We take in lengths of each block, need to convert it to offsets. + var offset = 0L out.writeLong(offset) + for (length <- lengths) { + offset += length + out.writeLong(offset) + } + } { + out.close() } - } { - out.close() - } - val dataFile = getDataFile(shuffleId, mapId) - // There is only one IndexShuffleBlockResolver per executor, this synchronization make sure - // the following check and rename are atomic. - synchronized { - val existingLengths = checkIndexAndDataFile(indexFile, dataFile, lengths.length) - if (existingLengths != null) { - // Another attempt for the same task has already written our map outputs successfully, - // so just use the existing partition lengths and delete our temporary map outputs. - System.arraycopy(existingLengths, 0, lengths, 0, lengths.length) - if (dataTmp != null && dataTmp.exists()) { - dataTmp.delete() - } - indexTmp.delete() - } else { - // This is the first successful attempt in writing the map outputs for this task, - // so override any existing index and data files with the ones we wrote. 
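`dataSerializeWithExplicitClassTag` and the reworked `dataDeserializeStream` above take the `ClassTag` as an explicit value instead of a context bound, because the tag is typically recovered at runtime from block metadata rather than known statically. A toy illustration of why that signature is more convenient at such call sites (all names here are made up):

```scala
import scala.reflect.ClassTag

// Made-up names: shows why a runtime ClassTag value is easier to thread
// through an explicit parameter than through a [T: ClassTag] context bound.
object ExplicitTagSketch {
  def deserializeAs[T](lines: Seq[String])(classTag: ClassTag[T]): Iterator[T] = {
    // A toy "deserializer"; the real code builds a serializer instance from
    // the tag, roughly getSerializer(classTag).newInstance().
    println(s"deserializing as ${classTag.runtimeClass.getName}")
    lines.iterator.asInstanceOf[Iterator[T]]
  }

  def main(args: Array[String]): Unit = {
    // The tag arrives as a *value*, e.g. recovered from block metadata...
    val storedTag: ClassTag[_] = ClassTag(classOf[String])
    // ...and is passed straight through; no implicit resolution is needed.
    deserializeAs(Seq("a", "b"))(storedTag).foreach(println)
  }
}
```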
- if (indexFile.exists()) { - indexFile.delete() - } - if (dataFile.exists()) { - dataFile.delete() - } - if (!indexTmp.renameTo(indexFile)) { - throw new IOException("fail to rename file " + indexTmp + " to " + indexFile) - } - if (dataTmp != null && dataTmp.exists() && !dataTmp.renameTo(dataFile)) { - throw new IOException("fail to rename file " + dataTmp + " to " + dataFile) + val dataFile = getDataFile(shuffleId, mapId) + // There is only one IndexShuffleBlockResolver per executor, this synchronization makes sure + // the following check and rename are atomic. + synchronized { + val existingLengths = checkIndexAndDataFile(indexFile, dataFile, lengths.length) + if (existingLengths != null) { + // Another attempt for the same task has already written our map outputs successfully, + // so just use the existing partition lengths and delete our temporary map outputs. + System.arraycopy(existingLengths, 0, lengths, 0, lengths.length) + if (dataTmp != null && dataTmp.exists()) { + dataTmp.delete() + } + indexTmp.delete() + } else { + // This is the first successful attempt in writing the map outputs for this task, + // so overwrite any existing index and data files with the ones we wrote. + if (indexFile.exists()) { + indexFile.delete() + } + if (dataFile.exists()) { + dataFile.delete() + } + if (!indexTmp.renameTo(indexFile)) { + throw new IOException("fail to rename file " + indexTmp + " to " + indexFile) + } + if (dataTmp != null && dataTmp.exists() && !dataTmp.renameTo(dataFile)) { + throw new IOException("fail to rename file " + dataTmp + " to " + dataFile) + } } } + } finally { + if (indexTmp.exists() && !indexTmp.delete()) { + logError(s"Failed to delete temporary index file at ${indexTmp.getAbsolutePath}") + } } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 1adacabc86c05..e677270b84605 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -67,10 +67,16 @@ private[spark] class SortShuffleWriter[K, V, C]( // (see SPARK-3570).
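This hunk and the adjacent `SortShuffleWriter` change converge on the same commit protocol: write to a temporary file, rename it into place, and delete the temporary file in a `finally` block even when the attempt aborts. A standalone sketch of that protocol (simplified: no `synchronized` section or existing-output check):

```scala
import java.io.{BufferedOutputStream, DataOutputStream, File, FileOutputStream, IOException}

// Simplified: just the write-temp / rename / always-clean-up shape shared by
// the two hunks above and below.
object TempFileCommitSketch {
  def writeAndCommit(target: File, lengths: Array[Long]): Unit = {
    val tmp = File.createTempFile(target.getName, ".tmp", target.getParentFile)
    try {
      val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(tmp)))
      try {
        // Convert per-block lengths to cumulative offsets, as the index does.
        var offset = 0L
        out.writeLong(offset)
        for (length <- lengths) {
          offset += length
          out.writeLong(offset)
        }
      } finally {
        out.close()
      }
      if (target.exists() && !target.delete()) {
        throw new IOException(s"failed to delete $target")
      }
      if (!tmp.renameTo(target)) {
        throw new IOException(s"failed to rename $tmp to $target")
      }
    } finally {
      // Mirrors the new finally blocks: never leak the temporary file.
      if (tmp.exists() && !tmp.delete()) {
        System.err.println(s"failed to delete temporary file ${tmp.getAbsolutePath}")
      }
    }
  }
}
```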
val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId) val tmp = Utils.tempFileWith(output) - val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID) - val partitionLengths = sorter.writePartitionedFile(blockId, tmp) - shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp) - mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths) + try { + val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID) + val partitionLengths = sorter.writePartitionedFile(blockId, tmp) + shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp) + mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths) + } finally { + if (tmp.exists() && !tmp.delete()) { + logError(s"Error while deleting temp file ${tmp.getAbsolutePath}") + } + } } /** Close this writer, passing along whether the map completed */ diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllJobsResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllJobsResource.scala index 5783df5d8220c..d0d9ef1165e81 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/AllJobsResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllJobsResource.scala @@ -68,7 +68,12 @@ private[v1] object AllJobsResource { listener: JobProgressListener, includeStageDetails: Boolean): JobData = { listener.synchronized { - val lastStageInfo = listener.stageIdToInfo.get(job.stageIds.max) + val lastStageInfo = + if (job.stageIds.isEmpty) { + None + } else { + listener.stageIdToInfo.get(job.stageIds.max) + } val lastStageData = lastStageInfo.flatMap { s => listener.stageIdToData.get((s.stageId, s.attemptId)) } @@ -86,7 +91,7 @@ private[v1] object AllJobsResource { numTasks = job.numTasks, numActiveTasks = job.numActiveTasks, numCompletedTasks = job.numCompletedTasks, - numSkippedTasks = job.numCompletedTasks, + numSkippedTasks = job.numSkippedTasks, numFailedTasks = job.numFailedTasks, numActiveStages = job.numActiveStages, numCompletedStages = job.completedStageIndices.size, diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala index eddc36edc9611..7d63a8f734f0e 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala @@ -20,10 +20,10 @@ import java.util.{Arrays, Date, List => JList} import javax.ws.rs.{GET, Produces, QueryParam} import javax.ws.rs.core.MediaType -import org.apache.spark.executor.{InputMetrics => InternalInputMetrics, OutputMetrics => InternalOutputMetrics, ShuffleReadMetrics => InternalShuffleReadMetrics, ShuffleWriteMetrics => InternalShuffleWriteMetrics, TaskMetrics => InternalTaskMetrics} import org.apache.spark.scheduler.{AccumulableInfo => InternalAccumulableInfo, StageInfo} import org.apache.spark.ui.SparkUI import org.apache.spark.ui.jobs.UIData.{StageUIData, TaskUIData} +import org.apache.spark.ui.jobs.UIData.{InputMetricsUIData => InternalInputMetrics, OutputMetricsUIData => InternalOutputMetrics, ShuffleReadMetricsUIData => InternalShuffleReadMetrics, ShuffleWriteMetricsUIData => InternalShuffleWriteMetrics, TaskMetricsUIData => InternalTaskMetrics} import org.apache.spark.util.Distribution @Produces(Array(MediaType.APPLICATION_JSON)) diff --git 
a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala index ba9cd711f18e2..c4f1fd29f2970 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala @@ -21,10 +21,10 @@ import javax.servlet.ServletContext import javax.ws.rs._ import javax.ws.rs.core.{Context, Response} -import com.sun.jersey.api.core.ResourceConfig -import com.sun.jersey.spi.container.servlet.ServletContainer import org.eclipse.jetty.server.handler.ContextHandler import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder} +import org.glassfish.jersey.server.ServerProperties +import org.glassfish.jersey.servlet.ServletContainer import org.apache.spark.SecurityManager import org.apache.spark.ui.SparkUI @@ -191,12 +191,7 @@ private[spark] object ApiRootResource { val jerseyContext = new ServletContextHandler(ServletContextHandler.NO_SESSIONS) jerseyContext.setContextPath("/api") val holder: ServletHolder = new ServletHolder(classOf[ServletContainer]) - holder.setInitParameter("com.sun.jersey.config.property.resourceConfigClass", - "com.sun.jersey.api.core.PackagesResourceConfig") - holder.setInitParameter("com.sun.jersey.config.property.packages", - "org.apache.spark.status.api.v1") - holder.setInitParameter(ResourceConfig.PROPERTY_CONTAINER_REQUEST_FILTERS, - classOf[SecurityFilter].getCanonicalName) + holder.setInitParameter(ServerProperties.PROVIDER_PACKAGES, "org.apache.spark.status.api.v1") UIRootFromServletContext.setUiRoot(jerseyContext, uiRoot) jerseyContext.addServlet(holder, "/*") jerseyContext @@ -205,12 +200,13 @@ private[spark] object ApiRootResource { /** * This trait is shared by all the root containers for application UI information -- - * the HistoryServer, the Master UI, and the application UI. This provides the common + * the HistoryServer and the application UI. This provides the common * interface needed for them all to expose application info as json. */ private[spark] trait UIRoot { def getSparkUI(appKey: String): Option[SparkUI] def getApplicationInfoList: Iterator[ApplicationInfo] + def getApplicationInfo(appId: String): Option[ApplicationInfo] /** * Write the event logs for the given app to the [[ZipOutputStream]] instance.
If attemptId is diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala index 0f30183682469..075b9ba37dc84 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala @@ -21,7 +21,6 @@ import javax.ws.rs.{DefaultValue, GET, Produces, QueryParam} import javax.ws.rs.core.MediaType import org.apache.spark.deploy.history.ApplicationHistoryInfo -import org.apache.spark.deploy.master.{ApplicationInfo => InternalApplicationInfo} @Produces(Array(MediaType.APPLICATION_JSON)) private[v1] class ApplicationListResource(uiRoot: UIRoot) { @@ -30,7 +29,8 @@ private[v1] class ApplicationListResource(uiRoot: UIRoot) { def appList( @QueryParam("status") status: JList[ApplicationStatus], @DefaultValue("2010-01-01") @QueryParam("minDate") minDate: SimpleDateParam, - @DefaultValue("3000-01-01") @QueryParam("maxDate") maxDate: SimpleDateParam) + @DefaultValue("3000-01-01") @QueryParam("maxDate") maxDate: SimpleDateParam, + @QueryParam("limit") limit: Integer) : Iterator[ApplicationInfo] = { val allApps = uiRoot.getApplicationInfoList val adjStatus = { @@ -42,7 +42,7 @@ private[v1] class ApplicationListResource(uiRoot: UIRoot) { } val includeCompleted = adjStatus.contains(ApplicationStatus.COMPLETED) val includeRunning = adjStatus.contains(ApplicationStatus.RUNNING) - allApps.filter { app => + val appList = allApps.filter { app => val anyRunning = app.attempts.exists(!_.completed) // if any attempt is still running, we consider the app to also still be running val statusOk = (!anyRunning && includeCompleted) || @@ -54,6 +54,11 @@ private[v1] class ApplicationListResource(uiRoot: UIRoot) { } statusOk && dateOk } + if (limit != null) { + appList.take(limit) + } else { + appList + } } } @@ -84,33 +89,4 @@ private[spark] object ApplicationsListResource { } ) } - - def convertApplicationInfo( - internal: InternalApplicationInfo, - completed: Boolean): ApplicationInfo = { - // standalone application info always has just one attempt - new ApplicationInfo( - id = internal.id, - name = internal.desc.name, - coresGranted = Some(internal.coresGranted), - maxCores = internal.desc.maxCores, - coresPerExecutor = internal.desc.coresPerExecutor, - memoryPerExecutorMB = Some(internal.desc.memoryPerExecutorMB), - attempts = Seq(new ApplicationAttemptInfo( - attemptId = None, - startTime = new Date(internal.startTime), - endTime = new Date(internal.endTime), - duration = - if (internal.endTime > 0) { - internal.endTime - internal.startTime - } else { - 0 - }, - lastUpdated = new Date(internal.endTime), - sparkUser = internal.desc.user, - completed = completed - )) - ) - } - } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala index d7e6a8b589953..18c3e2f407360 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala @@ -24,7 +24,7 @@ private[v1] class OneApplicationResource(uiRoot: UIRoot) { @GET def getApp(@PathParam("appId") appId: String): ApplicationInfo = { - val apps = uiRoot.getApplicationInfoList.find { _.id == appId } + val apps = uiRoot.getApplicationInfo(appId) apps.getOrElse(throw new NotFoundException("unknown app: " + appId)) } diff --git 
a/core/src/main/scala/org/apache/spark/status/api/v1/SecurityFilter.scala b/core/src/main/scala/org/apache/spark/status/api/v1/SecurityFilter.scala index 95fbd96ade5ab..b4a991eda35f3 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/SecurityFilter.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/SecurityFilter.scala @@ -16,18 +16,16 @@ */ package org.apache.spark.status.api.v1 -import javax.ws.rs.WebApplicationException +import javax.ws.rs.container.{ContainerRequestContext, ContainerRequestFilter} import javax.ws.rs.core.Response +import javax.ws.rs.ext.Provider -import com.sun.jersey.spi.container.{ContainerRequest, ContainerRequestFilter} - +@Provider private[v1] class SecurityFilter extends ContainerRequestFilter with UIRootFromServletContext { - def filter(req: ContainerRequest): ContainerRequest = { - val user = Option(req.getUserPrincipal).map { _.getName }.orNull - if (uiRoot.securityManager.checkUIViewPermissions(user)) { - req - } else { - throw new WebApplicationException( + override def filter(req: ContainerRequestContext): Unit = { + val user = Option(req.getSecurityContext.getUserPrincipal).map { _.getName }.orNull + if (!uiRoot.securityManager.checkUIViewPermissions(user)) { + req.abortWith( Response .status(Response.Status.FORBIDDEN) .entity(raw"""user "$user" is not authorized""") diff --git a/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala index ca53534b61c4a..dd8f5bacb9f6e 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala @@ -211,9 +211,6 @@ private[storage] class BlockInfoManager extends Logging { * If another task has already locked this block for either reading or writing, then this call * will block until the other locks are released or will return immediately if `blocking = false`. * - * If this is called by a task which already holds the block's exclusive write lock, then this - * method will throw an exception. - * * @param blockId the block to lock. * @param blocking if true (default), this call will block until the lock is acquired. If false, * this call will return immediately if the lock acquisition fails.
@@ -228,10 +225,7 @@ private[storage] class BlockInfoManager extends Logging { infos.get(blockId) match { case None => return None case Some(info) => - if (info.writerTask == currentTaskAttemptId) { - throw new IllegalStateException( - s"Task $currentTaskAttemptId has already locked $blockId for writing") - } else if (info.writerTask == BlockInfo.NO_WRITER && info.readerCount == 0) { + if (info.writerTask == BlockInfo.NO_WRITER && info.readerCount == 0) { info.writerTask = currentTaskAttemptId writeLocksByTask.addBinding(currentTaskAttemptId, blockId) logTrace(s"Task $currentTaskAttemptId acquired write lock for $blockId") diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index f2d06c7ea8079..37dfbd6818ef7 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -65,7 +65,7 @@ private[spark] class BlockManager( memoryManager: MemoryManager, mapOutputTracker: MapOutputTracker, shuffleManager: ShuffleManager, - blockTransferService: BlockTransferService, + val blockTransferService: BlockTransferService, securityManager: SecurityManager, numUsableCores: Int) extends BlockDataManager with BlockEvictionHandler with Logging { @@ -92,9 +92,9 @@ private[spark] class BlockManager( private[spark] val diskStore = new DiskStore(conf, diskBlockManager) memoryManager.setMemoryStore(memoryStore) - // Note: depending on the memory manager, `maxStorageMemory` may actually vary over time. + // Note: depending on the memory manager, `maxMemory` may actually vary over time. // However, since we use this only for reporting and logging, what we actually want here is - // the absolute maximum value that `maxStorageMemory` can ever possibly reach. We may need + // the absolute maximum value that `maxMemory` can ever possibly reach. We may need // to revisit whether reporting this value as the "max" is intuitive to the user. private val maxMemory = memoryManager.maxOnHeapStorageMemory @@ -216,7 +216,7 @@ private[spark] class BlockManager( logInfo(s"Reporting ${blockInfoManager.size} blocks to the master.") for ((blockId, info) <- blockInfoManager.entries) { val status = getCurrentBlockStatus(blockId, info) - if (!tryToReportBlockStatus(blockId, info, status)) { + if (info.tellMaster && !tryToReportBlockStatus(blockId, status)) { logError(s"Failed to report $blockId to master; giving up.") return } @@ -231,7 +231,7 @@ private[spark] class BlockManager( */ def reregister(): Unit = { // TODO: We might need to rate limit re-registering. - logInfo("BlockManager re-registering with master") + logInfo(s"BlockManager $blockManagerId re-registering with master") master.registerBlockManager(blockManagerId, maxMemory, slaveEndpoint) reportAllBlocks() } @@ -279,7 +279,12 @@ private[spark] class BlockManager( } else { getLocalBytes(blockId) match { case Some(buffer) => new BlockManagerManagedBuffer(blockInfoManager, blockId, buffer) - case None => throw new BlockNotFoundException(blockId.toString) + case None => + // If this block manager receives a request for a block that it doesn't have then it's + // likely that the master has outdated block statuses for this block. Therefore, we send + // an RPC so that this block is marked as being unavailable from this block manager. 
+ reportBlockStatus(blockId, BlockStatus.empty) + throw new BlockNotFoundException(blockId.toString) } } } @@ -297,7 +302,7 @@ private[spark] class BlockManager( /** * Get the BlockStatus for the block identified by the given ID, if it exists. - * NOTE: This is mainly for testing, and it doesn't fetch information from external block store. + * NOTE: This is mainly for testing. */ def getStatus(blockId: BlockId): Option[BlockStatus] = { blockInfoManager.get(blockId).map { info => @@ -332,10 +337,9 @@ private[spark] class BlockManager( */ private def reportBlockStatus( blockId: BlockId, - info: BlockInfo, status: BlockStatus, droppedMemorySize: Long = 0L): Unit = { - val needReregister = !tryToReportBlockStatus(blockId, info, status, droppedMemorySize) + val needReregister = !tryToReportBlockStatus(blockId, status, droppedMemorySize) if (needReregister) { logInfo(s"Got told to re-register updating block $blockId") // Re-registering will report our new block for free. @@ -351,17 +355,12 @@ private[spark] class BlockManager( */ private def tryToReportBlockStatus( blockId: BlockId, - info: BlockInfo, status: BlockStatus, droppedMemorySize: Long = 0L): Boolean = { - if (info.tellMaster) { - val storageLevel = status.storageLevel - val inMemSize = Math.max(status.memSize, droppedMemorySize) - val onDiskSize = status.diskSize - master.updateBlockInfo(blockManagerId, blockId, storageLevel, inMemSize, onDiskSize) - } else { - true - } + val storageLevel = status.storageLevel + val inMemSize = Math.max(status.memSize, droppedMemorySize) + val onDiskSize = status.diskSize + master.updateBlockInfo(blockManagerId, blockId, storageLevel, inMemSize, onDiskSize) } /** @@ -373,7 +372,7 @@ private[spark] class BlockManager( info.synchronized { info.level match { case null => - BlockStatus(StorageLevel.NONE, memSize = 0L, diskSize = 0L) + BlockStatus.empty case level => val inMem = level.useMemory && memoryStore.contains(blockId) val onDisk = level.useDisk && diskStore.contains(blockId) @@ -402,6 +401,17 @@ private[spark] class BlockManager( locations } + /** + * Cleanup code run in response to a failed local read. + * Must be called while holding a read lock on the block. + */ + private def handleLocalReadFailure(blockId: BlockId): Nothing = { + releaseLock(blockId) + // Remove the missing block so that its unavailability is reported to the driver + removeBlock(blockId) + throw new SparkException(s"Block $blockId was not found even though it's read-locked") + } + /** * Get block from local block manager as an iterator of Java objects. 
*/ @@ -441,8 +451,7 @@ private[spark] class BlockManager( val ci = CompletionIterator[Any, Iterator[Any]](iterToReturn, releaseLock(blockId)) Some(new BlockResult(ci, DataReadMethod.Disk, info.size)) } else { - releaseLock(blockId) - throw new SparkException(s"Block $blockId was not found even though it's read-locked") + handleLocalReadFailure(blockId) } } } @@ -487,10 +496,10 @@ private[spark] class BlockManager( diskStore.getBytes(blockId) } else if (level.useMemory && memoryStore.contains(blockId)) { // The block was not found on disk, so serialize an in-memory copy: - serializerManager.dataSerialize(blockId, memoryStore.getValues(blockId).get) + serializerManager.dataSerializeWithExplicitClassTag( + blockId, memoryStore.getValues(blockId).get, info.classTag) } else { - releaseLock(blockId) - throw new SparkException(s"Block $blockId was not found even though it's read-locked") + handleLocalReadFailure(blockId) } } else { // storage level is serialized if (level.useMemory && memoryStore.contains(blockId)) { @@ -499,8 +508,7 @@ private[spark] class BlockManager( val diskBytes = diskStore.getBytes(blockId) maybeCacheDiskBytesInMemory(info, blockId, level, diskBytes).getOrElse(diskBytes) } else { - releaseLock(blockId) - throw new SparkException(s"Block $blockId was not found even though it's read-locked") + handleLocalReadFailure(blockId) } } } @@ -510,10 +518,11 @@ private[spark] class BlockManager( * * This does not acquire a lock on this block in this JVM. */ - private def getRemoteValues(blockId: BlockId): Option[BlockResult] = { + private def getRemoteValues[T: ClassTag](blockId: BlockId): Option[BlockResult] = { + val ct = implicitly[ClassTag[T]] getRemoteBytes(blockId).map { data => val values = - serializerManager.dataDeserializeStream(blockId, data.toInputStream(dispose = true)) + serializerManager.dataDeserializeStream(blockId, data.toInputStream(dispose = true))(ct) new BlockResult(values, DataReadMethod.Network, data.size) } } @@ -554,8 +563,9 @@ private[spark] class BlockManager( // Give up trying anymore locations. Either we've tried all of the original locations, // or we've refreshed the list of locations from the master, and have still // hit failures after trying locations from the refreshed list. - throw new BlockFetchException(s"Failed to fetch block after" + - s" ${totalFailureCount} fetch failures. Most recent failure cause:", e) + logWarning(s"Failed to fetch block after $totalFailureCount fetch failures. " + + s"Most recent failure cause:", e) + return None } logWarning(s"Failed to fetch remote block $blockId " + @@ -592,13 +602,13 @@ private[spark] class BlockManager( * any locks if the block was fetched from a remote block manager. The read lock will * automatically be freed once the result's `data` iterator is fully consumed. */ - def get(blockId: BlockId): Option[BlockResult] = { + def get[T: ClassTag](blockId: BlockId): Option[BlockResult] = { val local = getLocalValues(blockId) if (local.isDefined) { logInfo(s"Found block $blockId locally") return local } - val remote = getRemoteValues(blockId) + val remote = getRemoteValues[T](blockId) if (remote.isDefined) { logInfo(s"Found block $blockId remotely") return remote @@ -650,7 +660,7 @@ private[spark] class BlockManager( makeIterator: () => Iterator[T]): Either[BlockResult, Iterator[T]] = { // Attempt to read the block from local or remote storage. If it's present, then we don't need // to go through the local-get-or-put path. 
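The `getRemoteBytes` change above turns fetch exhaustion from a thrown `BlockFetchException` into a logged `None`, letting the caller fall back to recomputing the block. A minimal sketch of that bounded-fallback policy, with stubbed locations that always fail:

```scala
import scala.annotation.tailrec

// Stubbed fetch: every location fails, to exercise the give-up path.
object BoundedFetchSketch {
  private def fetchFrom(location: String, blockId: String): Option[Array[Byte]] = None

  @tailrec
  def fetchWithFallback(
      locations: List[String], blockId: String, failures: Int = 0): Option[Array[Byte]] =
    locations match {
      case Nil =>
        // Log and return None instead of throwing, as in the patch.
        println(s"Failed to fetch $blockId after $failures fetch failures; giving up")
        None
      case loc :: rest =>
        fetchFrom(loc, blockId) match {
          case found @ Some(_) => found
          case None => fetchWithFallback(rest, blockId, failures + 1)
        }
    }

  def main(args: Array[String]): Unit =
    println(fetchWithFallback(List("exec-1", "exec-2"), "rdd_0_0")) // None
}
```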
- get(blockId) match { + get[T](blockId)(classTag) match { case Some(block) => return Left(block) case _ => @@ -797,12 +807,10 @@ private[spark] class BlockManager( // Now that the block is in either the memory, externalBlockStore, or disk store, // tell the master about it. info.size = size - if (tellMaster) { - reportBlockStatus(blockId, info, putBlockStatus) - } - Option(TaskContext.get()).foreach { c => - c.taskMetrics().incUpdatedBlockStatuses(blockId -> putBlockStatus) + if (tellMaster && info.tellMaster) { + reportBlockStatus(blockId, putBlockStatus) } + addUpdatedBlockStatusToTaskMetrics(blockId, putBlockStatus) } logDebug("Put block %s locally took %s".format(blockId, Utils.getUsedTimeMs(startTimeMs))) if (level.replication > 1) { @@ -853,22 +861,38 @@ private[spark] class BlockManager( } val startTimeMs = System.currentTimeMillis - var blockWasSuccessfullyStored: Boolean = false + var exceptionWasThrown: Boolean = true val result: Option[T] = try { val res = putBody(putBlockInfo) - blockWasSuccessfullyStored = res.isEmpty - res - } finally { - if (blockWasSuccessfullyStored) { + exceptionWasThrown = false + if (res.isEmpty) { + // the block was successfully stored if (keepReadLock) { blockInfoManager.downgradeLock(blockId) } else { blockInfoManager.unlock(blockId) } } else { - blockInfoManager.removeBlock(blockId) + removeBlockInternal(blockId, tellMaster = false) logWarning(s"Putting block $blockId failed") } + res + } finally { + // This cleanup is performed in a finally block rather than a `catch` to avoid having to + // catch and properly re-throw InterruptedException. + if (exceptionWasThrown) { + logWarning(s"Putting block $blockId failed due to an exception") + // If an exception was thrown then it's possible that the code in `putBody` has already + // notified the master about the availability of this block, so we need to send an update + // to remove this block location. + removeBlockInternal(blockId, tellMaster = tellMaster) + // The `putBody` code may have also added a new block status to TaskMetrics, so we need + // to cancel that out by overwriting it with an empty block status. We only do this if + // the finally block was entered via an exception because doing this unconditionally would + // cause us to send empty block statuses for every block that failed to be cached due to + // a memory shortage (which is an expected failure, unlike an uncaught exception). + addUpdatedBlockStatusToTaskMetrics(blockId, BlockStatus.empty) + } } if (level.replication > 1) { logDebug("Putting block %s with replication took %s" @@ -951,21 +975,26 @@ private[spark] class BlockManager( val putBlockStatus = getCurrentBlockStatus(blockId, info) val blockWasSuccessfullyStored = putBlockStatus.storageLevel.isValid if (blockWasSuccessfullyStored) { - // Now that the block is in either the memory, externalBlockStore, or disk store, - // tell the master about it. + // Now that the block is in either the memory or disk store, tell the master about it. 
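On the `exceptionWasThrown` flag introduced in `doPut` above: the cleanup intentionally lives in `finally` rather than a `catch`, so `InterruptedException` (and anything else) propagates unchanged while failed puts are still rolled back. A small sketch of the idiom, with hypothetical `body`/`rollback` placeholders standing in for the block-manager calls:

```scala
// "Flag + finally" cleanup: nothing is caught, so interrupts and fatal errors
// propagate untouched, yet we still know whether the body completed normally.
def runWithRollback[T](body: => T)(rollback: () => Unit): T = {
  var exceptionWasThrown = true
  try {
    val result = body
    exceptionWasThrown = false
    result
  } finally {
    if (exceptionWasThrown) {
      rollback() // undo side effects only on failure
    }
  }
}

// Example: the rollback runs only if the body throws.
// runWithRollback { writeBlockSomehow() } (() => removeBlockSomehow())
```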
info.size = size - if (tellMaster) { - reportBlockStatus(blockId, info, putBlockStatus) - } - Option(TaskContext.get()).foreach { c => - c.taskMetrics().incUpdatedBlockStatuses(blockId -> putBlockStatus) + if (tellMaster && info.tellMaster) { + reportBlockStatus(blockId, putBlockStatus) } + addUpdatedBlockStatusToTaskMetrics(blockId, putBlockStatus) logDebug("Put block %s locally took %s".format(blockId, Utils.getUsedTimeMs(startTimeMs))) if (level.replication > 1) { val remoteStartTime = System.currentTimeMillis val bytesToReplicate = doGetLocalBytes(blockId, info) + // [SPARK-16550] Erase the typed classTag when using default serialization, since + // NettyBlockRpcServer crashes when deserializing repl-defined classes. + // TODO(ekl) remove this once the classloader issue on the remote end is fixed. + val remoteClassTag = if (!serializerManager.canUseKryo(classTag)) { + scala.reflect.classTag[Any] + } else { + classTag + } try { - replicate(blockId, bytesToReplicate, level, classTag) + replicate(blockId, bytesToReplicate, level, remoteClassTag) } finally { bytesToReplicate.dispose() } @@ -1162,7 +1191,7 @@ private[spark] class BlockManager( done = true // specified number of peers have been replicated to } } catch { - case e: Exception => + case NonFatal(e) => logWarning(s"Failed to replicate $blockId to $peer, failure #$failures", e) failures += 1 replicationFailed = true @@ -1187,8 +1216,8 @@ private[spark] class BlockManager( /** * Read a block consisting of a single object. */ - def getSingle(blockId: BlockId): Option[Any] = { - get(blockId).map(_.data.next()) + def getSingle[T: ClassTag](blockId: BlockId): Option[T] = { + get[T](blockId).map(_.data.next().asInstanceOf[T]) } /** @@ -1253,12 +1282,10 @@ private[spark] class BlockManager( val status = getCurrentBlockStatus(blockId, info) if (info.tellMaster) { - reportBlockStatus(blockId, info, status, droppedMemorySize) + reportBlockStatus(blockId, status, droppedMemorySize) } if (blockIsUpdated) { - Option(TaskContext.get()).foreach { c => - c.taskMetrics().incUpdatedBlockStatuses(blockId -> status) - } + addUpdatedBlockStatusToTaskMetrics(blockId, status) } status.storageLevel } @@ -1298,21 +1325,31 @@ private[spark] class BlockManager( // The block has already been removed; do nothing. logWarning(s"Asked to remove block $blockId, which does not exist") case Some(info) => - // Removals are idempotent in disk store and memory store. At worst, we get a warning. - val removedFromMemory = memoryStore.remove(blockId) - val removedFromDisk = diskStore.remove(blockId) - if (!removedFromMemory && !removedFromDisk) { - logWarning(s"Block $blockId could not be removed as it was not found in either " + - "the disk, memory, or external block store") - } - blockInfoManager.removeBlock(blockId) - val removeBlockStatus = getCurrentBlockStatus(blockId, info) - if (tellMaster && info.tellMaster) { - reportBlockStatus(blockId, info, removeBlockStatus) - } - Option(TaskContext.get()).foreach { c => - c.taskMetrics().incUpdatedBlockStatuses(blockId -> removeBlockStatus) - } + removeBlockInternal(blockId, tellMaster = tellMaster && info.tellMaster) + addUpdatedBlockStatusToTaskMetrics(blockId, BlockStatus.empty) + } + } + + /** + * Internal version of [[removeBlock()]] which assumes that the caller already holds a write + * lock on the block. + */ + private def removeBlockInternal(blockId: BlockId, tellMaster: Boolean): Unit = { + // Removals are idempotent in disk store and memory store. At worst, we get a warning. 
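A note on the switch from `case e: Exception` to `case NonFatal(e)` in the replication retry above: `NonFatal` still matches ordinary failures, but deliberately lets `InterruptedException`, Scala control-flow throwables and fatal JVM errors escape rather than counting them as one more replication failure. A tiny illustration:

```scala
import scala.util.control.NonFatal

// Swallow only "ordinary" failures; interrupts and fatal errors are rethrown
// so cancellation and JVM-level problems surface immediately.
def attempt[T](op: () => T): Option[T] = {
  try {
    Some(op())
  } catch {
    case NonFatal(e) =>
      println(s"attempt failed, treating as retryable: $e")
      None
  }
}
```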
+ val removedFromMemory = memoryStore.remove(blockId) + val removedFromDisk = diskStore.remove(blockId) + if (!removedFromMemory && !removedFromDisk) { + logWarning(s"Block $blockId could not be removed as it was not found on disk or in memory") + } + blockInfoManager.removeBlock(blockId) + if (tellMaster) { + reportBlockStatus(blockId, BlockStatus.empty) + } + } + + private def addUpdatedBlockStatusToTaskMetrics(blockId: BlockId, status: BlockStatus): Unit = { + Option(TaskContext.get()).foreach { c => + c.taskMetrics().incUpdatedBlockStatuses(blockId -> status) } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index c22d2e0fb61fa..8655cf10fc28f 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -42,12 +42,20 @@ class BlockManagerMaster( logInfo("Removed " + execId + " successfully in removeExecutor") } + /** Request removal of a dead executor from the driver endpoint. + * This is only called on the driver side. Non-blocking + */ + def removeExecutorAsync(execId: String) { + driverEndpoint.ask[Boolean](RemoveExecutor(execId)) + logInfo("Removal of executor " + execId + " requested") + } + /** Register the BlockManager's id with the driver. */ def registerBlockManager( blockManagerId: BlockManagerId, maxMemSize: Long, slaveEndpoint: RpcEndpointRef): Unit = { - logInfo("Trying to register BlockManager") + logInfo(s"Registering BlockManager $blockManagerId") tell(RegisterBlockManager(blockManagerId, maxMemSize, slaveEndpoint)) - logInfo("Registered BlockManager") + logInfo(s"Registered BlockManager $blockManagerId") } def updateBlockInfo( diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala index ab97d2e4b8b78..5b493f470b50a 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala @@ -203,8 +203,7 @@ private[spark] class DiskBlockObjectWriter( numRecordsWritten += 1 writeMetrics.incRecordsWritten(1) - // TODO: call updateBytesWritten() less frequently. 
- if (numRecordsWritten % 32 == 0) { + if (numRecordsWritten % 16384 == 0) { updateBytesWritten() } } diff --git a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala index 083d78b59ebee..e5abbf745cc41 100644 --- a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala +++ b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala @@ -24,7 +24,7 @@ import org.apache.spark.util.Utils @DeveloperApi class RDDInfo( val id: Int, - val name: String, + var name: String, val numPartitions: Int, var storageLevel: StorageLevel, val parentIds: Seq[Int], diff --git a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala index 216ec0793492f..fad0404bebc36 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala @@ -120,8 +120,14 @@ class StorageLevel private( private def readResolve(): Object = StorageLevel.getCachedStorageLevel(this) override def toString: String = { - s"StorageLevel(disk=$useDisk, memory=$useMemory, offheap=$useOffHeap, " + - s"deserialized=$deserialized, replication=$replication)" + val disk = if (useDisk) "disk" else "" + val memory = if (useMemory) "memory" else "" + val heap = if (useOffHeap) "offheap" else "" + val deserialize = if (deserialized) "deserialized" else "" + + val output = + Seq(disk, memory, heap, deserialize, s"$replication replicas").filter(_.nonEmpty) + s"StorageLevel(${output.mkString(", ")})" } override def hashCode(): Int = toInt * 41 + replication diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala index 99be4de0658cc..9b87c4236aa52 100644 --- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala @@ -33,7 +33,7 @@ import org.apache.spark.memory.{MemoryManager, MemoryMode} import org.apache.spark.serializer.{SerializationStream, SerializerManager} import org.apache.spark.storage.{BlockId, BlockInfoManager, StorageLevel} import org.apache.spark.unsafe.Platform -import org.apache.spark.util.{CompletionIterator, SizeEstimator, Utils} +import org.apache.spark.util.{SizeEstimator, Utils} import org.apache.spark.util.collection.SizeTrackingVector import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream} @@ -271,10 +271,11 @@ private[spark] class MemoryStore( blockId, Utils.bytesToString(size), Utils.bytesToString(maxMemory - blocksMemoryUsed))) Right(size) } else { - assert(currentUnrollMemoryForThisTask >= currentUnrollMemoryForThisTask, + assert(currentUnrollMemoryForThisTask >= unrollMemoryUsedByThisBlock, "released too much unroll memory") Left(new PartiallyUnrolledIterator( this, + MemoryMode.ON_HEAP, unrollMemoryUsedByThisBlock, unrolled = arrayValues.toIterator, rest = Iterator.empty)) @@ -283,7 +284,11 @@ private[spark] class MemoryStore( // We ran out of space while unrolling the values for this block logUnrollFailureMessage(blockId, vector.estimateSize()) Left(new PartiallyUnrolledIterator( - this, unrollMemoryUsedByThisBlock, unrolled = vector.iterator, rest = values)) + this, + MemoryMode.ON_HEAP, + unrollMemoryUsedByThisBlock, + unrolled = vector.iterator, + rest = values)) } } @@ -377,7 +382,8 @@ private[spark] class MemoryStore( entries.put(blockId, entry) } logInfo("Block %s stored as bytes in 
memory (estimated size %s, free %s)".format( - blockId, Utils.bytesToString(entry.size), Utils.bytesToString(blocksMemoryUsed))) + blockId, Utils.bytesToString(entry.size), + Utils.bytesToString(maxMemory - blocksMemoryUsed))) Right(entry.size) } else { // We ran out of space while unrolling the values for this block @@ -391,7 +397,7 @@ private[spark] class MemoryStore( redirectableStream, unrollMemoryUsedByThisBlock, memoryMode, - bbos.toChunkedByteBuffer, + bbos, values, classTag)) } @@ -590,11 +596,11 @@ private[spark] class MemoryStore( val memoryToRelease = math.min(memory, unrollMemoryMap(taskAttemptId)) if (memoryToRelease > 0) { unrollMemoryMap(taskAttemptId) -= memoryToRelease - if (unrollMemoryMap(taskAttemptId) == 0) { - unrollMemoryMap.remove(taskAttemptId) - } memoryManager.releaseUnrollMemory(memoryToRelease, memoryMode) } + if (unrollMemoryMap(taskAttemptId) == 0) { + unrollMemoryMap.remove(taskAttemptId) + } } } } @@ -652,6 +658,7 @@ private[spark] class MemoryStore( * The result of a failed [[MemoryStore.putIteratorAsValues()]] call. * * @param memoryStore the memoryStore, used for freeing memory. + * @param memoryMode the memory mode (on- or off-heap). * @param unrollMemory the amount of unroll memory used by the values in `unrolled`. * @param unrolled an iterator for the partially-unrolled values. * @param rest the rest of the original iterator passed to @@ -659,39 +666,52 @@ private[spark] class MemoryStore( */ private[storage] class PartiallyUnrolledIterator[T]( memoryStore: MemoryStore, + memoryMode: MemoryMode, unrollMemory: Long, - unrolled: Iterator[T], + private[this] var unrolled: Iterator[T], rest: Iterator[T]) extends Iterator[T] { - private[this] var unrolledIteratorIsConsumed: Boolean = false - private[this] var iter: Iterator[T] = { - val completionIterator = CompletionIterator[T, Iterator[T]](unrolled, { - unrolledIteratorIsConsumed = true - memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP, unrollMemory) - }) - completionIterator ++ rest + private def releaseUnrollMemory(): Unit = { + memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory) + // SPARK-17503: Garbage collects the unrolling memory before the life end of + // PartiallyUnrolledIterator. + unrolled = null + } + + override def hasNext: Boolean = { + if (unrolled == null) { + rest.hasNext + } else if (!unrolled.hasNext) { + releaseUnrollMemory() + rest.hasNext + } else { + true + } } - override def hasNext: Boolean = iter.hasNext - override def next(): T = iter.next() + override def next(): T = { + if (unrolled == null) { + rest.next() + } else { + unrolled.next() + } + } /** * Called to dispose of this iterator and free its memory. */ def close(): Unit = { - if (!unrolledIteratorIsConsumed) { - memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP, unrollMemory) - unrolledIteratorIsConsumed = true + if (unrolled != null) { + releaseUnrollMemory() } - iter = null } } /** * A wrapper which allows an open [[OutputStream]] to be redirected to a different sink. */ -private class RedirectableOutputStream extends OutputStream { +private[storage] class RedirectableOutputStream extends OutputStream { private[this] var os: OutputStream = _ def setOutputStream(s: OutputStream): Unit = { os = s } override def write(b: Int): Unit = os.write(b) @@ -711,7 +731,8 @@ private class RedirectableOutputStream extends OutputStream { * @param redirectableOutputStream an OutputStream which can be redirected to a different sink. 
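The reworked `PartiallyUnrolledIterator` above is a two-phase iterator: it serves the already-unrolled values first, releases their unroll memory exactly once when that half is exhausted (or on an early `close()`), and then falls through to the remaining input. A self-contained sketch of the same shape, where `releaseMemory` is a placeholder for the memory-store bookkeeping:

```scala
// Serves `unrolled` first, releases its backing resource once, then serves `rest`.
class TwoPhaseIterator[T](
    private var unrolled: Iterator[T],
    rest: Iterator[T],
    releaseMemory: () => Unit) extends Iterator[T] {

  private def release(): Unit = {
    releaseMemory()
    unrolled = null // drop the reference so the buffer can be garbage collected
  }

  override def hasNext: Boolean = {
    if (unrolled == null) {
      rest.hasNext
    } else if (!unrolled.hasNext) {
      release()
      rest.hasNext
    } else {
      true
    }
  }

  override def next(): T = if (unrolled == null) rest.next() else unrolled.next()

  // Frees the resource early if the caller abandons iteration.
  def close(): Unit = if (unrolled != null) release()
}
```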
* @param unrollMemory the amount of unroll memory used by the values in `unrolled`. * @param memoryMode whether the unroll memory is on- or off-heap - * @param unrolled a byte buffer containing the partially-serialized values. + * @param bbos byte buffer output stream containing the partially-serialized values. + * [[redirectableOutputStream]] initially points to this output stream. * @param rest the rest of the original iterator passed to * [[MemoryStore.putIteratorAsValues()]]. * @param classTag the [[ClassTag]] for the block. @@ -720,14 +741,19 @@ private[storage] class PartiallySerializedBlock[T]( memoryStore: MemoryStore, serializerManager: SerializerManager, blockId: BlockId, - serializationStream: SerializationStream, - redirectableOutputStream: RedirectableOutputStream, - unrollMemory: Long, + private val serializationStream: SerializationStream, + private val redirectableOutputStream: RedirectableOutputStream, + val unrollMemory: Long, memoryMode: MemoryMode, - unrolled: ChunkedByteBuffer, + bbos: ChunkedByteBufferOutputStream, rest: Iterator[T], classTag: ClassTag[T]) { + private lazy val unrolledBuffer: ChunkedByteBuffer = { + bbos.close() + bbos.toChunkedByteBuffer + } + // If the task does not fully consume `valuesIterator` or otherwise fails to consume or dispose of // this PartiallySerializedBlock then we risk leaking of direct buffers, so we use a task // completion listener here in order to ensure that `unrolled.dispose()` is called at least once. @@ -736,7 +762,23 @@ private[storage] class PartiallySerializedBlock[T]( taskContext.addTaskCompletionListener { _ => // When a task completes, its unroll memory will automatically be freed. Thus we do not call // releaseUnrollMemoryForThisTask() here because we want to avoid double-freeing. - unrolled.dispose() + unrolledBuffer.dispose() + } + } + + // Exposed for testing + private[storage] def getUnrolledChunkedByteBuffer: ChunkedByteBuffer = unrolledBuffer + + private[this] var discarded = false + private[this] var consumed = false + + private def verifyNotConsumedAndNotDiscarded(): Unit = { + if (consumed) { + throw new IllegalStateException( + "Can only call one of finishWritingToStream() or valuesIterator() and can only call once.") + } + if (discarded) { + throw new IllegalStateException("Cannot call methods on a discarded PartiallySerializedBlock") } } @@ -744,15 +786,18 @@ private[storage] class PartiallySerializedBlock[T]( * Called to dispose of this block and free its memory. */ def discard(): Unit = { - try { - // We want to close the output stream in order to free any resources associated with the - // serializer itself (such as Kryo's internal buffers). close() might cause data to be - // written, so redirect the output stream to discard that data. - redirectableOutputStream.setOutputStream(ByteStreams.nullOutputStream()) - serializationStream.close() - } finally { - unrolled.dispose() - memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory) + if (!discarded) { + try { + // We want to close the output stream in order to free any resources associated with the + // serializer itself (such as Kryo's internal buffers). close() might cause data to be + // written, so redirect the output stream to discard that data. 
+ redirectableOutputStream.setOutputStream(ByteStreams.nullOutputStream()) + serializationStream.close() + } finally { + discarded = true + unrolledBuffer.dispose() + memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory) + } } } @@ -761,8 +806,10 @@ private[storage] class PartiallySerializedBlock[T]( * and then serializing the values from the original input iterator. */ def finishWritingToStream(os: OutputStream): Unit = { + verifyNotConsumedAndNotDiscarded() + consumed = true // `unrolled`'s underlying buffers will be freed once this input stream is fully read: - ByteStreams.copy(unrolled.toInputStream(dispose = true), os) + ByteStreams.copy(unrolledBuffer.toInputStream(dispose = true), os) memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory) redirectableOutputStream.setOutputStream(os) while (rest.hasNext) { @@ -779,13 +826,22 @@ private[storage] class PartiallySerializedBlock[T]( * `close()` on it to free its resources. */ def valuesIterator: PartiallyUnrolledIterator[T] = { + verifyNotConsumedAndNotDiscarded() + consumed = true + // Close the serialization stream so that the serializer's internal buffers are freed and any + // "end-of-stream" markers can be written out so that `unrolled` is a valid serialized stream. + serializationStream.close() // `unrolled`'s underlying buffers will be freed once this input stream is fully read: val unrolledIter = serializerManager.dataDeserializeStream( - blockId, unrolled.toInputStream(dispose = true))(classTag) + blockId, unrolledBuffer.toInputStream(dispose = true))(classTag) + // The unroll memory will be freed once `unrolledIter` is fully consumed in + // PartiallyUnrolledIterator. If the iterator is not consumed by the end of the task then any + // extra unroll memory will automatically be freed by a `finally` block in `Task`. 
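The new `consumed`/`discarded` flags above make `PartiallySerializedBlock` a consume-once container: exactly one of `finishWritingToStream()` or `valuesIterator` may be used, and only if the block has not been discarded. A minimal sketch of that guard, with a hypothetical payload type:

```scala
// Consume-once holder: take() may be called at most once, and never after discard().
class ConsumeOnce[T](payload: T)(dispose: T => Unit) {
  private var consumed = false
  private var discarded = false

  private def verifyNotConsumedAndNotDiscarded(): Unit = {
    if (consumed) throw new IllegalStateException("payload was already consumed")
    if (discarded) throw new IllegalStateException("payload was already discarded")
  }

  def take(): T = {
    verifyNotConsumedAndNotDiscarded()
    consumed = true
    payload
  }

  // Idempotent: releases the payload unless it was already handed out.
  def discard(): Unit = {
    if (!discarded && !consumed) {
      discarded = true
      dispose(payload)
    }
  }
}
```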
new PartiallyUnrolledIterator( memoryStore, + memoryMode, unrollMemory, - unrolled = CompletionIterator[T, Iterator[T]](unrolledIter, discard()), + unrolled = unrolledIter, rest = rest) } } diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index db24f0319ba05..d980094a3be45 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -25,13 +25,12 @@ import scala.collection.mutable.ArrayBuffer import scala.language.implicitConversions import scala.xml.Node -import org.eclipse.jetty.server.{Connector, Request, Server} +import org.eclipse.jetty.server.{HttpConnectionFactory, Request, Server, ServerConnector} import org.eclipse.jetty.server.handler._ -import org.eclipse.jetty.server.nio.SelectChannelConnector -import org.eclipse.jetty.server.ssl.SslSelectChannelConnector import org.eclipse.jetty.servlet._ +import org.eclipse.jetty.servlets.gzip.GzipHandler import org.eclipse.jetty.util.component.LifeCycle -import org.eclipse.jetty.util.thread.QueuedThreadPool +import org.eclipse.jetty.util.thread.{QueuedThreadPool, ScheduledExecutorScheduler} import org.json4s.JValue import org.json4s.jackson.JsonMethods.{pretty, render} @@ -243,10 +242,24 @@ private[spark] object JettyUtils extends Logging { // Bind to the given port, or throw a java.net.BindException if the port is occupied def connect(currentPort: Int): (Server, Int) = { - val server = new Server - val connectors = new ArrayBuffer[Connector] + val pool = new QueuedThreadPool + if (serverName.nonEmpty) { + pool.setName(serverName) + } + pool.setDaemon(true) + + val server = new Server(pool) + val connectors = new ArrayBuffer[ServerConnector] // Create a connector on port currentPort to listen for HTTP requests - val httpConnector = new SelectChannelConnector() + val httpConnector = new ServerConnector( + server, + null, + // Call this full constructor to set this, which forces daemon threads: + new ScheduledExecutorScheduler(s"$serverName-JettyScheduler", true), + null, + -1, + -1, + new HttpConnectionFactory()) httpConnector.setPort(currentPort) connectors += httpConnector @@ -260,8 +273,9 @@ private[spark] object JettyUtils extends Logging { } val scheme = "https" // Create a connector on port securePort to listen for HTTPS requests - val connector = new SslSelectChannelConnector(factory) + val connector = new ServerConnector(server, factory) connector.setPort(securePort) + connectors += connector // redirect the HTTP requests to HTTPS port @@ -269,34 +283,28 @@ private[spark] object JettyUtils extends Logging { } gzipHandlers.foreach(collection.addHandler) - connectors.foreach(_.setHost(hostName)) // As each acceptor and each selector will use one thread, the number of threads should at // least be the number of acceptors and selectors plus 1. 
(See SPARK-13776) var minThreads = 1 - connectors.foreach { c => + connectors.foreach { connector => // Currently we only use "SelectChannelConnector" - val connector = c.asInstanceOf[SelectChannelConnector] // Limit the max acceptor number to 8 so that we don't waste a lot of threads - connector.setAcceptors(math.min(connector.getAcceptors, 8)) + connector.setAcceptQueueSize(math.min(connector.getAcceptors, 8)) + connector.setHost(hostName) // The number of selectors always equals to the number of acceptors minThreads += connector.getAcceptors * 2 } server.setConnectors(connectors.toArray) - - val pool = new QueuedThreadPool - if (serverName.nonEmpty) { - pool.setName(serverName) - } pool.setMaxThreads(math.max(pool.getMaxThreads, minThreads)) - pool.setDaemon(true) - server.setThreadPool(pool) + val errorHandler = new ErrorHandler() errorHandler.setShowStacks(true) + errorHandler.setServer(server) server.addBean(errorHandler) server.setHandler(collection) try { server.start() - (server, server.getConnectors.head.getLocalPort) + (server, httpConnector.getLocalPort) } catch { case e: Exception => server.stop() diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index 39155ff2649ec..ef71db89798f1 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -126,6 +126,10 @@ private[spark] class SparkUI private ( )) )) } + + def getApplicationInfo(appId: String): Option[ApplicationInfo] = { + getApplicationInfoList.find(_.id == appId) + } } private[spark] abstract class SparkUITab(parent: SparkUI, prefix: String) diff --git a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala index 2d2d80be4aabe..3cc5353f475f4 100644 --- a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala +++ b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala @@ -90,4 +90,10 @@ private[spark] object ToolTips { val TASK_TIME = "Shaded red when garbage collection (GC) time is over 10% of task time" + + val APPLICATION_EXECUTOR_LIMIT = + """Maximum number of executors that this application will use. This limit is finite only when + dynamic allocation is enabled. The number of granted executors may exceed the limit + ephemerally when executors are being killed. + """ } diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 6241593bba32f..1aa85d60ea815 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -19,7 +19,7 @@ package org.apache.spark.ui import java.net.URLDecoder import java.text.SimpleDateFormat -import java.util.{Date, Locale} +import java.util.{Date, Locale, TimeZone} import scala.util.control.NonFatal import scala.xml._ @@ -502,4 +502,7 @@ private[spark] object UIUtils extends Logging { } param } + + def getTimeZoneOffset() : Int = + TimeZone.getDefault().getOffset(System.currentTimeMillis()) / 1000 / 60 } diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala index 2c40e726992d3..adc4a4faccad1 100644 --- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala @@ -146,7 +146,10 @@ private[spark] abstract class WebUI( } /** Return the url of web interface. Only valid after bind(). 
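The `JettyUtils` hunk above ports the UI server from Jetty 8's `SelectChannelConnector` to Jetty 9's `ServerConnector`. Roughly, an embedded Jetty 9 server is wired as sketched below; this is only an illustration of the new API shape and omits the scheduler, acceptor-count and SSL handling that the Spark code configures:

```scala
import org.eclipse.jetty.server.{HttpConnectionFactory, Server, ServerConnector}
import org.eclipse.jetty.util.thread.QueuedThreadPool

object MinimalJetty9 {
  def main(args: Array[String]): Unit = {
    val pool = new QueuedThreadPool
    pool.setDaemon(true)            // daemon worker threads, as the UI server uses

    val server = new Server(pool)   // Jetty 9: the thread pool is a constructor argument

    // ServerConnector replaces SelectChannelConnector / SslSelectChannelConnector.
    val connector = new ServerConnector(server, new HttpConnectionFactory())
    connector.setPort(8080)
    connector.setHost("localhost")
    server.addConnector(connector)

    server.start()
    println(s"listening on port ${connector.getLocalPort}")
    server.stop()
  }
}
```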
*/ - def webUrl: String = s"http://$publicHostName:$boundPort" + def webUrl: String = { + val protocol = if (sslOptions.enabled) "https" else "http" + s"$protocol://$publicHostName:$boundPort" + } /** Return the actual port to which this server is bound. Only valid after bind(). */ def boundPort: Int = serverInfo.map(_.boundPort).getOrElse(-1) diff --git a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala index f0a1174a71d34..22136a6f10743 100644 --- a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala @@ -26,11 +26,15 @@ import org.apache.spark.ui.{UIUtils, WebUIPage} private[ui] class EnvironmentPage(parent: EnvironmentTab) extends WebUIPage("") { private val listener = parent.listener + private def removePass(kv: (String, String)): (String, String) = { + if (kv._1.toLowerCase.contains("password")) (kv._1, "******") else kv + } + def render(request: HttpServletRequest): Seq[Node] = { val runtimeInformationTable = UIUtils.listingTable( propertyHeader, jvmRow, listener.jvmInformation, fixedWidth = true) val sparkPropertiesTable = UIUtils.listingTable( - propertyHeader, propertyRow, listener.sparkProperties, fixedWidth = true) + propertyHeader, propertyRow, listener.sparkProperties.map(removePass), fixedWidth = true) val systemPropertiesTable = UIUtils.listingTable( propertyHeader, propertyRow, listener.systemProperties, fixedWidth = true) val classpathEntriesTable = UIUtils.listingTable( diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala index 791dbe5c272b5..67deb7b14bcb9 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala @@ -20,6 +20,7 @@ package org.apache.spark.ui.exec import java.net.URLEncoder import javax.servlet.http.HttpServletRequest +import scala.util.Try import scala.xml.Node import org.apache.spark.status.api.v1.ExecutorSummary @@ -53,6 +54,9 @@ private[ui] class ExecutorsPage( // When GCTimePercent is edited change ToolTips.TASK_TIME to match private val GCTimePercent = 0.1 + // a safe String to Int for sorting ids (converts non-numeric Strings to -1) + private def idStrToInt(str: String) : Int = Try(str.toInt).getOrElse(-1) + def render(request: HttpServletRequest): Seq[Node] = { val (activeExecutorInfo, deadExecutorInfo) = listener.synchronized { // The follow codes should be protected by `listener` to make sure no executors will be @@ -69,13 +73,14 @@ private[ui] class ExecutorsPage( } val execInfo = activeExecutorInfo ++ deadExecutorInfo + implicit val idOrder = Ordering[Int].on((s: String) => idStrToInt(s)).reverse val execInfoSorted = execInfo.sortBy(_.id) val logsExist = execInfo.filter(_.executorLogs.nonEmpty).nonEmpty val execTable = { - + @@ -136,7 +141,7 @@ private[ui] class ExecutorsPage( } - + + + your default properties file. @@ -244,7 +245,7 @@ Apart from these, the following properties are also available, and may be useful
    Note: In client mode, this config must not be set through the SparkConf directly in your application, because the driver JVM has already started at that point. Instead, please set this through the --driver-library-path command line option or in - your default properties file. + your default properties file. @@ -292,11 +293,19 @@ Apart from these, the following properties are also available, and may be useful Older log files will be deleted. Disabled by default. + + + + + @@ -308,7 +317,7 @@ Apart from these, the following properties are also available, and may be useful Set the strategy of rolling of executor logs. By default it is disabled. It can be set to "time" (time-based rolling) or "size" (size-based rolling). For "time", use spark.executor.logs.rolling.time.interval to set the rolling interval. - For "size", use spark.executor.logs.rolling.size.maxBytes to set + For "size", use spark.executor.logs.rolling.maxSize to set the maximum file size for rolling. @@ -380,6 +389,53 @@ Apart from these, the following properties are also available, and may be useful from JVM to Python worker for every task. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    Executor IDExecutor ID Address Status RDD Blocks
    {info.id}{info.id} {info.hostPort} {executorStatus} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index 07484c9550134..373c26be4c5f3 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -23,6 +23,8 @@ import javax.servlet.http.HttpServletRequest import scala.collection.mutable.{HashMap, ListBuffer} import scala.xml._ +import org.apache.commons.lang3.StringEscapeUtils + import org.apache.spark.JobExecutionStatus import org.apache.spark.ui.{ToolTips, UIUtils, WebUIPage} import org.apache.spark.ui.jobs.UIData.{ExecutorUIData, JobUIData} @@ -87,9 +89,10 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { case JobExecutionStatus.UNKNOWN => "unknown" } - // The timeline library treats contents as HTML, so we have to escape them; for the - // data-title attribute string we have to escape them twice since that's in a string. + // The timeline library treats contents as HTML, so we have to escape them. We need to add + // extra layers of escaping in order to embed this in a Javascript string literal. val escapedDesc = Utility.escape(displayJobDescription) + val jsEscapedDesc = StringEscapeUtils.escapeEcmaScript(escapedDesc) val jobEventJsonAsStr = s""" |{ @@ -99,7 +102,7 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { | 'end': new Date(${completionTime}), | 'content': '
    ' + | 'Status: ${status}
    ' + | 'Submitted: ${UIUtils.formatDate(new Date(submissionTime))}' + | '${ @@ -109,7 +112,7 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { "" } }">' + - | '${escapedDesc} (Job ${jobId})
    ' + | '${jsEscapedDesc} (Job ${jobId})' |} """.stripMargin jobEventJsonAsStr @@ -203,7 +206,7 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { ++ } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala index 645e2d2e360bb..99f2bd8bc1f22 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala @@ -23,6 +23,8 @@ import javax.servlet.http.HttpServletRequest import scala.collection.mutable.{Buffer, HashMap, ListBuffer} import scala.xml.{Node, NodeSeq, Unparsed, Utility} +import org.apache.commons.lang3.StringEscapeUtils + import org.apache.spark.JobExecutionStatus import org.apache.spark.scheduler.StageInfo import org.apache.spark.ui.{ToolTips, UIUtils, WebUIPage} @@ -63,9 +65,10 @@ private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") { val submissionTime = stage.submissionTime.get val completionTime = stage.completionTime.getOrElse(System.currentTimeMillis()) - // The timeline library treats contents as HTML, so we have to escape them; for the - // data-title attribute string we have to escape them twice since that's in a string. + // The timeline library treats contents as HTML, so we have to escape them. We need to add + // extra layers of escaping in order to embed this in a Javascript string literal. val escapedName = Utility.escape(name) + val jsEscapedName = StringEscapeUtils.escapeEcmaScript(escapedName) s""" |{ | 'className': 'stage job-timeline-object ${status}', @@ -74,7 +77,7 @@ private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") { | 'end': new Date(${completionTime}), | 'content': '
    ' + | 'Status: ${status.toUpperCase}
    ' + | 'Submitted: ${UIUtils.formatDate(new Date(submissionTime))}' + | '${ @@ -84,7 +87,7 @@ private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") { "" } }">' + - | '${escapedName} (Stage ${stageId}.${attemptId})
    ', + | '${jsEscapedName} (Stage ${stageId}.${attemptId})', |} """.stripMargin } @@ -176,7 +179,8 @@ private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") { ++ } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index 945830c8bf242..38ad6e985c4b5 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -19,12 +19,13 @@ package org.apache.spark.ui.jobs import java.util.concurrent.TimeoutException -import scala.collection.mutable.{HashMap, HashSet, ListBuffer} +import scala.collection.mutable.{HashMap, HashSet, LinkedHashMap, ListBuffer} import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ import org.apache.spark.scheduler._ import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.storage.BlockManagerId @@ -93,6 +94,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { val retainedStages = conf.getInt("spark.ui.retainedStages", SparkUI.DEFAULT_RETAINED_STAGES) val retainedJobs = conf.getInt("spark.ui.retainedJobs", SparkUI.DEFAULT_RETAINED_JOBS) + val retainedTasks = conf.get(UI_RETAINED_TASKS) // We can test for memory leaks by ensuring that collections that track non-active jobs and // stages do not grow without bound and that collections for active jobs/stages eventually become @@ -332,7 +334,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { new StageUIData }) stageData.numActiveTasks += 1 - stageData.taskData.put(taskInfo.taskId, new TaskUIData(taskInfo, Some(metrics))) + stageData.taskData.put(taskInfo.taskId, TaskUIData(taskInfo, Some(metrics))) } for ( activeJobsDependentOnStage <- stageIdToActiveJobIds.get(taskStart.stageId); @@ -395,11 +397,16 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { updateAggregateMetrics(stageData, info.executorId, m, oldMetrics) } - val taskData = stageData.taskData.getOrElseUpdate(info.taskId, new TaskUIData(info)) - taskData.taskInfo = info - taskData.metrics = taskMetrics + val taskData = stageData.taskData.getOrElseUpdate(info.taskId, TaskUIData(info, None)) + taskData.updateTaskInfo(info) + taskData.updateTaskMetrics(taskMetrics) taskData.errorMessage = errorMessage + // If Tasks is too large, remove and garbage collect old tasks + if (stageData.taskData.size > retainedTasks) { + stageData.taskData = stageData.taskData.drop(stageData.taskData.size - retainedTasks) + } + for ( activeJobsDependentOnStage <- stageIdToActiveJobIds.get(taskEnd.stageId); jobId <- activeJobsDependentOnStage; @@ -425,7 +432,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { stageData: StageUIData, execId: String, taskMetrics: TaskMetrics, - oldMetrics: Option[TaskMetrics]) { + oldMetrics: Option[TaskMetricsUIData]) { val execSummary = stageData.executorSummary.getOrElseUpdate(execId, new ExecutorSummary) val shuffleWriteDelta = @@ -503,7 +510,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { if (!t.taskInfo.finished) { updateAggregateMetrics(stageData, executorMetricsUpdate.execId, metrics, t.metrics) // Overwrite task metrics - t.metrics = Some(metrics) + t.updateTaskMetrics(Some(metrics)) } } } diff 
--git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 5d1928ac6b2ca..d93a660d85555 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -26,7 +26,7 @@ import scala.xml.{Elem, Node, Unparsed} import org.apache.commons.lang3.StringEscapeUtils -import org.apache.spark.{InternalAccumulator, SparkConf} +import org.apache.spark.SparkConf import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo, TaskLocality} import org.apache.spark.ui._ @@ -131,7 +131,14 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val stageData = stageDataOption.get val tasks = stageData.taskData.values.toSeq.sortBy(_.taskInfo.launchTime) - val numCompleted = tasks.count(_.taskInfo.finished) + val numCompleted = stageData.numCompleteTasks + val totalTasks = stageData.numActiveTasks + + stageData.numCompleteTasks + stageData.numFailedTasks + val totalTasksNumStr = if (totalTasks == tasks.size) { + s"$totalTasks" + } else { + s"$totalTasks, showing ${tasks.size}" + } val allAccumulables = progressListener.stageIdToData((stageId, stageAttemptId)).accumulables val externalAccumulables = allAccumulables.values.filter { acc => !acc.internal } @@ -576,7 +583,8 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
<div>{summaryTable.getOrElse("No tasks have reported metrics yet.")}</div> ++ <h4>Aggregated Metrics by Executor</h4> ++ executorTable.toNodeSeq ++ maybeAccumulableTable ++ - <h4 id="tasks-section">Tasks</h4> ++ taskTableHTML ++ jsForScrollingDownToTaskTable + <h4 id="tasks-section">Tasks ({totalTasksNumStr})</h4>
    ++ + taskTableHTML ++ jsForScrollingDownToTaskTable UIUtils.headerSparkPage(stageHeader, content, parent, showVisualization = true) } } @@ -628,9 +636,9 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { } val executorComputingTime = executorRunTime - shuffleReadTime - shuffleWriteTime val executorComputingTimeProportion = - (100 - schedulerDelayProportion - shuffleReadTimeProportion - + math.max(100 - schedulerDelayProportion - shuffleReadTimeProportion - shuffleWriteTimeProportion - serializationTimeProportion - - deserializationTimeProportion - gettingResultTimeProportion) + deserializationTimeProportion - gettingResultTimeProportion, 0) val schedulerDelayProportionPos = 0 val deserializationTimeProportionPos = @@ -746,7 +754,8 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { ++ } @@ -767,7 +776,7 @@ private[ui] object StagePage { } private[ui] def getSchedulerDelay( - info: TaskInfo, metrics: TaskMetrics, currentTime: Long): Long = { + info: TaskInfo, metrics: TaskMetricsUIData, currentTime: Long): Long = { if (info.finished) { val totalExecutionTime = info.finishTime - info.launchTime val executorOverhead = (metrics.executorDeserializeTime + diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala index b454ef1b204b2..818605003eaf4 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala @@ -18,11 +18,12 @@ package org.apache.spark.ui.jobs import scala.collection.mutable -import scala.collection.mutable.HashMap +import scala.collection.mutable.{HashMap, LinkedHashMap} import org.apache.spark.JobExecutionStatus -import org.apache.spark.executor.TaskMetrics +import org.apache.spark.executor.{ShuffleReadMetrics, ShuffleWriteMetrics, TaskMetrics} import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo} +import org.apache.spark.util.AccumulatorContext import org.apache.spark.util.collection.OpenHashSet private[spark] object UIData { @@ -92,7 +93,7 @@ private[spark] object UIData { var description: Option[String] = None var accumulables = new HashMap[Long, AccumulableInfo] - var taskData = new HashMap[Long, TaskUIData] + var taskData = new LinkedHashMap[Long, TaskUIData] var executorSummary = new HashMap[String, ExecutorSummary] def hasInput: Boolean = inputBytes > 0 @@ -105,13 +106,135 @@ private[spark] object UIData { /** * These are kept mutable and reused throughout a task's lifetime to avoid excessive reallocation. 
*/ - class TaskUIData( - var taskInfo: TaskInfo, - var metrics: Option[TaskMetrics] = None, - var errorMessage: Option[String] = None) + class TaskUIData private( + private var _taskInfo: TaskInfo, + private var _metrics: Option[TaskMetricsUIData]) { + + var errorMessage: Option[String] = None + + def taskInfo: TaskInfo = _taskInfo + + def metrics: Option[TaskMetricsUIData] = _metrics + + def updateTaskInfo(taskInfo: TaskInfo): Unit = { + _taskInfo = TaskUIData.dropInternalAndSQLAccumulables(taskInfo) + } + + def updateTaskMetrics(metrics: Option[TaskMetrics]): Unit = { + _metrics = TaskUIData.toTaskMetricsUIData(metrics) + } + } + + object TaskUIData { + def apply(taskInfo: TaskInfo, metrics: Option[TaskMetrics]): TaskUIData = { + new TaskUIData(dropInternalAndSQLAccumulables(taskInfo), toTaskMetricsUIData(metrics)) + } + + private def toTaskMetricsUIData(metrics: Option[TaskMetrics]): Option[TaskMetricsUIData] = { + metrics.map { m => + TaskMetricsUIData( + executorDeserializeTime = m.executorDeserializeTime, + executorRunTime = m.executorRunTime, + resultSize = m.resultSize, + jvmGCTime = m.jvmGCTime, + resultSerializationTime = m.resultSerializationTime, + memoryBytesSpilled = m.memoryBytesSpilled, + diskBytesSpilled = m.diskBytesSpilled, + peakExecutionMemory = m.peakExecutionMemory, + inputMetrics = InputMetricsUIData(m.inputMetrics.bytesRead, m.inputMetrics.recordsRead), + outputMetrics = + OutputMetricsUIData(m.outputMetrics.bytesWritten, m.outputMetrics.recordsWritten), + shuffleReadMetrics = ShuffleReadMetricsUIData(m.shuffleReadMetrics), + shuffleWriteMetrics = ShuffleWriteMetricsUIData(m.shuffleWriteMetrics)) + } + } + + /** + * We don't need to store internal or SQL accumulables as their values will be shown in other + * places, so drop them to reduce the memory usage. 
+ */ + private[spark] def dropInternalAndSQLAccumulables(taskInfo: TaskInfo): TaskInfo = { + val newTaskInfo = new TaskInfo( + taskId = taskInfo.taskId, + index = taskInfo.index, + attemptNumber = taskInfo.attemptNumber, + launchTime = taskInfo.launchTime, + executorId = taskInfo.executorId, + host = taskInfo.host, + taskLocality = taskInfo.taskLocality, + speculative = taskInfo.speculative + ) + newTaskInfo.gettingResultTime = taskInfo.gettingResultTime + newTaskInfo.accumulables ++= taskInfo.accumulables.filter { + accum => !accum.internal && accum.metadata != Some(AccumulatorContext.SQL_ACCUM_IDENTIFIER) + } + newTaskInfo.finishTime = taskInfo.finishTime + newTaskInfo.failed = taskInfo.failed + newTaskInfo + } + } class ExecutorUIData( val startTime: Long, var finishTime: Option[Long] = None, var finishReason: Option[String] = None) + + case class TaskMetricsUIData( + executorDeserializeTime: Long, + executorRunTime: Long, + resultSize: Long, + jvmGCTime: Long, + resultSerializationTime: Long, + memoryBytesSpilled: Long, + diskBytesSpilled: Long, + peakExecutionMemory: Long, + inputMetrics: InputMetricsUIData, + outputMetrics: OutputMetricsUIData, + shuffleReadMetrics: ShuffleReadMetricsUIData, + shuffleWriteMetrics: ShuffleWriteMetricsUIData) + + case class InputMetricsUIData(bytesRead: Long, recordsRead: Long) + + case class OutputMetricsUIData(bytesWritten: Long, recordsWritten: Long) + + case class ShuffleReadMetricsUIData( + remoteBlocksFetched: Long, + localBlocksFetched: Long, + remoteBytesRead: Long, + localBytesRead: Long, + fetchWaitTime: Long, + recordsRead: Long, + totalBytesRead: Long, + totalBlocksFetched: Long) + + object ShuffleReadMetricsUIData { + def apply(metrics: ShuffleReadMetrics): ShuffleReadMetricsUIData = { + new ShuffleReadMetricsUIData( + remoteBlocksFetched = metrics.remoteBlocksFetched, + localBlocksFetched = metrics.localBlocksFetched, + remoteBytesRead = metrics.remoteBytesRead, + localBytesRead = metrics.localBytesRead, + fetchWaitTime = metrics.fetchWaitTime, + recordsRead = metrics.recordsRead, + totalBytesRead = metrics.totalBytesRead, + totalBlocksFetched = metrics.totalBlocksFetched + ) + } + } + + case class ShuffleWriteMetricsUIData( + bytesWritten: Long, + recordsWritten: Long, + writeTime: Long) + + object ShuffleWriteMetricsUIData { + def apply(metrics: ShuffleWriteMetrics): ShuffleWriteMetricsUIData = { + new ShuffleWriteMetricsUIData( + bytesWritten = metrics.bytesWritten, + recordsWritten = metrics.recordsWritten, + writeTime = metrics.writeTime + ) + } + } + } diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala index 50095831b4a53..c212362557be6 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala @@ -59,7 +59,7 @@ class StorageListener(storageStatusListener: StorageStatusListener) extends Bloc override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = synchronized { val rddInfos = stageSubmitted.stageInfo.rddInfos - rddInfos.foreach { info => _rddInfoMap.getOrElseUpdate(info.id, info) } + rddInfos.foreach { info => _rddInfoMap.getOrElseUpdate(info.id, info).name = info.name } } override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = synchronized { diff --git a/core/src/main/scala/org/apache/spark/NewAccumulator.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala similarity index 52% 
rename from core/src/main/scala/org/apache/spark/NewAccumulator.scala rename to core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala index 1571e15b76ac2..d3ddd39131326 100644 --- a/core/src/main/scala/org/apache/spark/NewAccumulator.scala +++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala @@ -15,18 +15,18 @@ * limitations under the License. */ -package org.apache.spark +package org.apache.spark.util import java.{lang => jl} import java.io.ObjectInputStream +import java.util.{ArrayList, Collections} import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.atomic.AtomicLong -import javax.annotation.concurrent.GuardedBy import scala.collection.JavaConverters._ +import org.apache.spark.{InternalAccumulator, SparkContext, TaskContext} import org.apache.spark.scheduler.AccumulableInfo -import org.apache.spark.util.Utils private[spark] case class AccumulatorMetadata( @@ -38,8 +38,11 @@ private[spark] case class AccumulatorMetadata( /** * The base class for accumulators, that can accumulate inputs of type `IN`, and produce output of * type `OUT`. + * + * `OUT` should be a type that can be read atomically (e.g., Int, Long), or thread-safely + * (e.g., synchronized collections) because it will be read from other threads. */ -abstract class NewAccumulator[IN, OUT] extends Serializable { +abstract class AccumulatorV2[IN, OUT] extends Serializable { private[spark] var metadata: AccumulatorMetadata = _ private[this] var atDriverSide = true @@ -57,7 +60,7 @@ abstract class NewAccumulator[IN, OUT] extends Serializable { /** * Returns true if this accumulator has been registered. Note that all accumulators must be - * registered before ues, or it will throw exception. + * registered before use, or it will throw exception. */ final def isRegistered: Boolean = metadata != null && AccumulatorContext.get(metadata.id).isDefined @@ -95,7 +98,7 @@ abstract class NewAccumulator[IN, OUT] extends Serializable { } /** - * Creates an [[AccumulableInfo]] representation of this [[NewAccumulator]] with the provided + * Creates an [[AccumulableInfo]] representation of this [[AccumulatorV2]] with the provided * values. */ private[spark] def toInfo(update: Option[Any], value: Option[Any]): AccumulableInfo = { @@ -106,46 +109,47 @@ abstract class NewAccumulator[IN, OUT] extends Serializable { final private[spark] def isAtDriverSide: Boolean = atDriverSide /** - * Tells if this accumulator is zero value or not. e.g. for a counter accumulator, 0 is zero + * Returns if this accumulator is zero value or not. e.g. for a counter accumulator, 0 is zero * value; for a list accumulator, Nil is zero value. */ - def isZero(): Boolean + def isZero: Boolean /** * Creates a new copy of this accumulator, which is zero value. i.e. call `isZero` on the copy * must return true. */ - def copyAndReset(): NewAccumulator[IN, OUT] + def copyAndReset(): AccumulatorV2[IN, OUT] = { + val copyAcc = copy() + copyAcc.reset() + copyAcc + } /** - * Takes the inputs and accumulates. e.g. it can be a simple `+=` for counter accumulator. + * Creates a new copy of this accumulator. */ - def add(v: IN): Unit + def copy(): AccumulatorV2[IN, OUT] /** - * Merges another same-type accumulator into this one and update its state, i.e. this should be - * merge-in-place. + * Resets this accumulator, which is zero value. i.e. call `isZero` must + * return true. */ - def merge(other: NewAccumulator[IN, OUT]): Unit + def reset(): Unit /** - * Access this accumulator's current value; only allowed on driver. 
+ * Takes the inputs and accumulates. */ - final def value: OUT = { - if (atDriverSide) { - localValue - } else { - throw new UnsupportedOperationException("Can't read accumulator value in task") - } - } + def add(v: IN): Unit /** - * Defines the current value of this accumulator. - * - * This is NOT the global value of the accumulator. To get the global value after a - * completed operation on the dataset, call `value`. + * Merges another same-type accumulator into this one and update its state, i.e. this should be + * merge-in-place. */ - def localValue: OUT + def merge(other: AccumulatorV2[IN, OUT]): Unit + + /** + * Defines the current value of this accumulator + */ + def value: OUT // Called by Java when serializing an object final protected def writeReplace(): Any = { @@ -154,10 +158,10 @@ abstract class NewAccumulator[IN, OUT] extends Serializable { throw new UnsupportedOperationException( "Accumulator must be registered before send to executor") } - val copy = copyAndReset() - assert(copy.isZero(), "copyAndReset must return a zero value copy") - copy.metadata = metadata - copy + val copyAcc = copyAndReset() + assert(copyAcc.isZero, "copyAndReset must return a zero value copy") + copyAcc.metadata = metadata + copyAcc } else { this } @@ -185,12 +189,15 @@ abstract class NewAccumulator[IN, OUT] extends Serializable { if (metadata == null) { "Un-registered Accumulator: " + getClass.getSimpleName } else { - getClass.getSimpleName + s"(id: $id, name: $name, value: $localValue)" + getClass.getSimpleName + s"(id: $id, name: $name, value: $value)" } } } +/** + * An internal class used to track accumulators by Spark itself. + */ private[spark] object AccumulatorContext { /** @@ -199,44 +206,45 @@ private[spark] object AccumulatorContext { * once the RDDs and user-code that reference them are cleaned up. * TODO: Don't use a global map; these should be tied to a SparkContext (SPARK-13051). */ - private val originals = new ConcurrentHashMap[Long, jl.ref.WeakReference[NewAccumulator[_, _]]] + private val originals = new ConcurrentHashMap[Long, jl.ref.WeakReference[AccumulatorV2[_, _]]] private[this] val nextId = new AtomicLong(0L) /** - * Return a globally unique ID for a new [[Accumulator]]. - * Note: Once you copy the [[Accumulator]] the ID is no longer unique. + * Returns a globally unique ID for a new [[AccumulatorV2]]. + * Note: Once you copy the [[AccumulatorV2]] the ID is no longer unique. */ def newId(): Long = nextId.getAndIncrement + /** Returns the number of accumulators registered. Used in testing. */ def numAccums: Int = originals.size /** - * Register an [[Accumulator]] created on the driver such that it can be used on the executors. + * Registers an [[AccumulatorV2]] created on the driver such that it can be used on the executors. * * All accumulators registered here can later be used as a container for accumulating partial * values across multiple tasks. This is what [[org.apache.spark.scheduler.DAGScheduler]] does. * Note: if an accumulator is registered here, it should also be registered with the active * context cleaner for cleanup so as to avoid memory leaks. * - * If an [[Accumulator]] with the same ID was already registered, this does nothing instead + * If an [[AccumulatorV2]] with the same ID was already registered, this does nothing instead * of overwriting it. We will never register same accumulator twice, this is just a sanity check. 
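Given the `AccumulatorV2` contract spelled out above (`isZero`, `copy`, `reset`, `add`, `merge`, `value`, with `copyAndReset` defaulting to copy-then-reset), a user-defined accumulator is a small class. The following is a hedged sketch of a custom max-tracking accumulator; `MaxAccumulator` is an illustrative name, not part of Spark:

```scala
import org.apache.spark.util.AccumulatorV2

// Tracks the maximum Long observed across tasks.
class MaxAccumulator extends AccumulatorV2[Long, Long] {
  private var _max = Long.MinValue

  override def isZero: Boolean = _max == Long.MinValue

  override def copy(): MaxAccumulator = {
    val newAcc = new MaxAccumulator
    newAcc._max = _max
    newAcc
  }

  override def reset(): Unit = _max = Long.MinValue

  override def add(v: Long): Unit = _max = math.max(_max, v)

  override def merge(other: AccumulatorV2[Long, Long]): Unit = other match {
    case o: MaxAccumulator => _max = math.max(_max, o._max)
    case _ => throw new UnsupportedOperationException(
      s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
  }

  override def value: Long = _max
}
```

On the driver such an accumulator would be registered before use, for example with `sc.register(acc, "max")`, so that `writeReplace()` above can ship a zero-value copy to executors.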
*/ - def register(a: NewAccumulator[_, _]): Unit = { - originals.putIfAbsent(a.id, new jl.ref.WeakReference[NewAccumulator[_, _]](a)) + def register(a: AccumulatorV2[_, _]): Unit = { + originals.putIfAbsent(a.id, new jl.ref.WeakReference[AccumulatorV2[_, _]](a)) } /** - * Unregister the [[Accumulator]] with the given ID, if any. + * Unregisters the [[AccumulatorV2]] with the given ID, if any. */ def remove(id: Long): Unit = { originals.remove(id) } /** - * Return the [[Accumulator]] registered with the given ID, if any. + * Returns the [[AccumulatorV2]] registered with the given ID, if any. */ - def get(id: Long): Option[NewAccumulator[_, _]] = { + def get(id: Long): Option[AccumulatorV2[_, _]] = { Option(originals.get(id)).map { ref => // Since we are storing weak references, we must check whether the underlying data is valid. val acc = ref.get @@ -248,118 +256,213 @@ private[spark] object AccumulatorContext { } /** - * Clear all registered [[Accumulator]]s. For testing only. + * Clears all registered [[AccumulatorV2]]s. For testing only. */ def clear(): Unit = { originals.clear() } -} - -class LongAccumulator extends NewAccumulator[jl.Long, jl.Long] { - private[this] var _sum = 0L - - override def isZero(): Boolean = _sum == 0 + /** + * Looks for a registered accumulator by accumulator name. + */ + private[spark] def lookForAccumulatorByName(name: String): Option[AccumulatorV2[_, _]] = { + originals.values().asScala.find { ref => + val acc = ref.get + acc != null && acc.name.isDefined && acc.name.get == name + }.map(_.get) + } - override def copyAndReset(): LongAccumulator = new LongAccumulator + // Identifier for distinguishing SQL metrics from other accumulators + private[spark] val SQL_ACCUM_IDENTIFIER = "sql" +} - override def add(v: jl.Long): Unit = _sum += v - def add(v: Long): Unit = _sum += v +/** + * An [[AccumulatorV2 accumulator]] for computing sum, count, and averages for 64-bit integers. + * + * @since 2.0.0 + */ +class LongAccumulator extends AccumulatorV2[jl.Long, jl.Long] { + private var _sum = 0L + private var _count = 0L - def sum: Long = _sum + /** + * Adds v to the accumulator, i.e. increment sum by v and count by 1. + * @since 2.0.0 + */ + override def isZero: Boolean = _sum == 0L && _count == 0 - override def merge(other: NewAccumulator[jl.Long, jl.Long]): Unit = other match { - case o: LongAccumulator => _sum += o.sum - case _ => throw new UnsupportedOperationException( - s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}") + override def copy(): LongAccumulator = { + val newAcc = new LongAccumulator + newAcc._count = this._count + newAcc._sum = this._sum + newAcc } - private[spark] def setValue(newValue: Long): Unit = _sum = newValue - - override def localValue: jl.Long = _sum -} - - -class DoubleAccumulator extends NewAccumulator[jl.Double, jl.Double] { - private[this] var _sum = 0.0 + override def reset(): Unit = { + _sum = 0L + _count = 0L + } - override def isZero(): Boolean = _sum == 0.0 + /** + * Adds v to the accumulator, i.e. increment sum by v and count by 1. + * @since 2.0.0 + */ + override def add(v: jl.Long): Unit = { + _sum += v + _count += 1 + } - override def copyAndReset(): DoubleAccumulator = new DoubleAccumulator + /** + * Adds v to the accumulator, i.e. increment sum by v and count by 1. + * @since 2.0.0 + */ + def add(v: Long): Unit = { + _sum += v + _count += 1 + } - override def add(v: jl.Double): Unit = _sum += v + /** + * Returns the number of elements added to the accumulator. 
+ * @since 2.0.0 + */ + def count: Long = _count - def add(v: Double): Unit = _sum += v + /** + * Returns the sum of elements added to the accumulator. + * @since 2.0.0 + */ + def sum: Long = _sum - def sum: Double = _sum + /** + * Returns the average of elements added to the accumulator. + * @since 2.0.0 + */ + def avg: Double = _sum.toDouble / _count - override def merge(other: NewAccumulator[jl.Double, jl.Double]): Unit = other match { - case o: DoubleAccumulator => _sum += o.sum - case _ => throw new UnsupportedOperationException( - s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}") + override def merge(other: AccumulatorV2[jl.Long, jl.Long]): Unit = other match { + case o: LongAccumulator => + _sum += o.sum + _count += o.count + case _ => + throw new UnsupportedOperationException( + s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}") } - private[spark] def setValue(newValue: Double): Unit = _sum = newValue + private[spark] def setValue(newValue: Long): Unit = _sum = newValue - override def localValue: jl.Double = _sum + override def value: jl.Long = _sum } -class AverageAccumulator extends NewAccumulator[jl.Double, jl.Double] { - private[this] var _sum = 0.0 - private[this] var _count = 0L +/** + * An [[AccumulatorV2 accumulator]] for computing sum, count, and averages for double precision + * floating numbers. + * + * @since 2.0.0 + */ +class DoubleAccumulator extends AccumulatorV2[jl.Double, jl.Double] { + private var _sum = 0.0 + private var _count = 0L - override def isZero(): Boolean = _sum == 0.0 && _count == 0 + override def isZero: Boolean = _sum == 0.0 && _count == 0 - override def copyAndReset(): AverageAccumulator = new AverageAccumulator + override def copy(): DoubleAccumulator = { + val newAcc = new DoubleAccumulator + newAcc._count = this._count + newAcc._sum = this._sum + newAcc + } + override def reset(): Unit = { + _sum = 0.0 + _count = 0L + } + + /** + * Adds v to the accumulator, i.e. increment sum by v and count by 1. + * @since 2.0.0 + */ override def add(v: jl.Double): Unit = { _sum += v _count += 1 } - def add(d: Double): Unit = { - _sum += d + /** + * Adds v to the accumulator, i.e. increment sum by v and count by 1. + * @since 2.0.0 + */ + def add(v: Double): Unit = { + _sum += v _count += 1 } - override def merge(other: NewAccumulator[jl.Double, jl.Double]): Unit = other match { - case o: AverageAccumulator => + /** + * Returns the number of elements added to the accumulator. + * @since 2.0.0 + */ + def count: Long = _count + + /** + * Returns the sum of elements added to the accumulator. + * @since 2.0.0 + */ + def sum: Double = _sum + + /** + * Returns the average of elements added to the accumulator. 
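To show what `sum`, `count`, and `avg` give back in practice, here is a small driver-side sketch. It assumes a live `SparkContext` named `sc` and its `longAccumulator`/`doubleAccumulator` helpers (the Java test changes further below use the same helpers); the accumulator names are illustrative.

```scala
// Sketch: built-in numeric accumulators driven from user code.
val counts = sc.longAccumulator("record count")
val sizes = sc.doubleAccumulator("payload size")

sc.parallelize(1 to 100).foreach { i =>
  counts.add(1L)     // bumps sum by 1 and count by 1
  sizes.add(i * 0.5)
}

// Back on the driver, after the action has completed:
counts.value   // 100    (sum of the added values)
counts.count   // 100    (number of add() calls merged in)
counts.avg     // 1.0    (sum / count)
sizes.sum      // 2525.0
```
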
+ * @since 2.0.0 + */ + def avg: Double = _sum / _count + + override def merge(other: AccumulatorV2[jl.Double, jl.Double]): Unit = other match { + case o: DoubleAccumulator => _sum += o.sum _count += o.count - case _ => throw new UnsupportedOperationException( - s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}") - } - - override def localValue: jl.Double = if (_count == 0) { - Double.NaN - } else { - _sum / _count + case _ => + throw new UnsupportedOperationException( + s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}") } - def sum: Double = _sum + private[spark] def setValue(newValue: Double): Unit = _sum = newValue - def count: Long = _count + override def value: jl.Double = _sum } -class ListAccumulator[T] extends NewAccumulator[T, java.util.List[T]] { - private[this] val _list: java.util.List[T] = new java.util.ArrayList[T] +/** + * An [[AccumulatorV2 accumulator]] for collecting a list of elements. + * + * @since 2.0.0 + */ +class CollectionAccumulator[T] extends AccumulatorV2[T, java.util.List[T]] { + private val _list: java.util.List[T] = Collections.synchronizedList(new ArrayList[T]()) - override def isZero(): Boolean = _list.isEmpty + override def isZero: Boolean = _list.isEmpty - override def copyAndReset(): ListAccumulator[T] = new ListAccumulator + override def copyAndReset(): CollectionAccumulator[T] = new CollectionAccumulator + + override def copy(): CollectionAccumulator[T] = { + val newAcc = new CollectionAccumulator[T] + _list.synchronized { + newAcc._list.addAll(_list) + } + newAcc + } + + override def reset(): Unit = _list.clear() override def add(v: T): Unit = _list.add(v) - override def merge(other: NewAccumulator[T, java.util.List[T]]): Unit = other match { - case o: ListAccumulator[T] => _list.addAll(o.localValue) + override def merge(other: AccumulatorV2[T, java.util.List[T]]): Unit = other match { + case o: CollectionAccumulator[T] => _list.addAll(o.value) case _ => throw new UnsupportedOperationException( s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}") } - override def localValue: java.util.List[T] = java.util.Collections.unmodifiableList(_list) + override def value: java.util.List[T] = _list.synchronized { + java.util.Collections.unmodifiableList(new ArrayList[T](_list)) + } private[spark] def setValue(newValue: java.util.List[T]): Unit = { _list.clear() @@ -370,24 +473,28 @@ class ListAccumulator[T] extends NewAccumulator[T, java.util.List[T]] { class LegacyAccumulatorWrapper[R, T]( initialValue: R, - param: org.apache.spark.AccumulableParam[R, T]) extends NewAccumulator[T, R] { + param: org.apache.spark.AccumulableParam[R, T]) extends AccumulatorV2[T, R] { private[spark] var _value = initialValue // Current value on driver - override def isZero(): Boolean = _value == param.zero(initialValue) + override def isZero: Boolean = _value == param.zero(initialValue) - override def copyAndReset(): LegacyAccumulatorWrapper[R, T] = { + override def copy(): LegacyAccumulatorWrapper[R, T] = { val acc = new LegacyAccumulatorWrapper(initialValue, param) - acc._value = param.zero(initialValue) + acc._value = _value acc } + override def reset(): Unit = { + _value = param.zero(initialValue) + } + override def add(v: T): Unit = _value = param.addAccumulator(_value, v) - override def merge(other: NewAccumulator[T, R]): Unit = other match { - case o: LegacyAccumulatorWrapper[R, T] => _value = param.addInPlace(_value, o.localValue) + override def merge(other: AccumulatorV2[T, R]): Unit = other match { + case o: 
LegacyAccumulatorWrapper[R, T] => _value = param.addInPlace(_value, o.value) case _ => throw new UnsupportedOperationException( s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}") } - override def localValue: R = _value + override def value: R = _value } diff --git a/core/src/main/scala/org/apache/spark/util/Benchmark.scala b/core/src/main/scala/org/apache/spark/util/Benchmark.scala index 1fc0ad7a4d6d3..7def44bd2a2b1 100644 --- a/core/src/main/scala/org/apache/spark/util/Benchmark.scala +++ b/core/src/main/scala/org/apache/spark/util/Benchmark.scala @@ -17,10 +17,14 @@ package org.apache.spark.util +import java.io.{OutputStream, PrintStream} + import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import scala.concurrent.duration._ import scala.util.Try +import org.apache.commons.io.output.TeeOutputStream import org.apache.commons.lang3.SystemUtils /** @@ -33,21 +37,40 @@ import org.apache.commons.lang3.SystemUtils * * The benchmark function takes one argument that is the iteration that's being run. * - * If outputPerIteration is true, the timing for each run will be printed to stdout. + * @param name name of this benchmark. + * @param valuesPerIteration number of values used in the test case, used to compute rows/s. + * @param minNumIters the min number of iterations that will be run per case, not counting warm-up. + * @param warmupTime amount of time to spend running dummy case iterations for JIT warm-up. + * @param minTime further iterations will be run for each case until this time is used up. + * @param outputPerIteration if true, the timing for each run will be printed to stdout. + * @param output optional output stream to write benchmark results to */ private[spark] class Benchmark( name: String, valuesPerIteration: Long, - iters: Int = 5, - outputPerIteration: Boolean = false) { + minNumIters: Int = 2, + warmupTime: FiniteDuration = 2.seconds, + minTime: FiniteDuration = 2.seconds, + outputPerIteration: Boolean = false, + output: Option[OutputStream] = None) { + import Benchmark._ val benchmarks = mutable.ArrayBuffer.empty[Benchmark.Case] + val out = if (output.isDefined) { + new PrintStream(new TeeOutputStream(System.out, output.get)) + } else { + System.out + } + /** * Adds a case to run when run() is called. The given function will be run for several * iterations to collect timing statistics. + * + * @param name of the benchmark case + * @param numIters if non-zero, forces exactly this many iterations to be run */ - def addCase(name: String)(f: Int => Unit): Unit = { - addTimerCase(name) { timer => + def addCase(name: String, numIters: Int = 0)(f: Int => Unit): Unit = { + addTimerCase(name, numIters) { timer => timer.startTiming() f(timer.iteration) timer.stopTiming() @@ -58,9 +81,12 @@ private[spark] class Benchmark( * Adds a case with manual timing control. When the function is run, timing does not start * until timer.startTiming() is called within the given function. The corresponding * timer.stopTiming() method must be called before the function returns. 
+ * + * @param name of the benchmark case + * @param numIters if non-zero, forces exactly this many iterations to be run */ - def addTimerCase(name: String)(f: Benchmark.Timer => Unit): Unit = { - benchmarks += Benchmark.Case(name, f) + def addTimerCase(name: String, numIters: Int = 0)(f: Benchmark.Timer => Unit): Unit = { + benchmarks += Benchmark.Case(name, f, numIters) } /** @@ -75,29 +101,63 @@ private[spark] class Benchmark( val results = benchmarks.map { c => println(" Running case: " + c.name) - Benchmark.measure(valuesPerIteration, iters, outputPerIteration)(c.fn) + measure(valuesPerIteration, c.numIters)(c.fn) } println val firstBest = results.head.bestMs // The results are going to be processor specific so it is useful to include that. - println(Benchmark.getJVMOSInfo()) - println(Benchmark.getProcessorName()) - printf("%-35s %16s %12s %13s %10s\n", name + ":", "Best/Avg Time(ms)", "Rate(M/s)", + out.println(Benchmark.getJVMOSInfo()) + out.println(Benchmark.getProcessorName()) + out.printf("%-40s %16s %12s %13s %10s\n", name + ":", "Best/Avg Time(ms)", "Rate(M/s)", "Per Row(ns)", "Relative") - println("-----------------------------------------------------------------------------------" + - "--------") + out.println("-" * 96) results.zip(benchmarks).foreach { case (result, benchmark) => - printf("%-35s %16s %12s %13s %10s\n", + out.printf("%-40s %16s %12s %13s %10s\n", benchmark.name, "%5.0f / %4.0f" format (result.bestMs, result.avgMs), "%10.1f" format result.bestRate, "%6.1f" format (1000 / result.bestRate), "%3.1fX" format (firstBest / result.bestMs)) } - println + out.println // scalastyle:on } + + /** + * Runs a single function `f` for iters, returning the average time the function took and + * the rate of the function. + */ + def measure(num: Long, overrideNumIters: Int)(f: Timer => Unit): Result = { + System.gc() // ensures garbage from previous cases don't impact this one + val warmupDeadline = warmupTime.fromNow + while (!warmupDeadline.isOverdue) { + f(new Benchmark.Timer(-1)) + } + val minIters = if (overrideNumIters != 0) overrideNumIters else minNumIters + val minDuration = if (overrideNumIters != 0) 0 else minTime.toNanos + val runTimes = ArrayBuffer[Long]() + var i = 0 + while (i < minIters || runTimes.sum < minDuration) { + val timer = new Benchmark.Timer(i) + f(timer) + val runTime = timer.totalTime() + runTimes += runTime + + if (outputPerIteration) { + // scalastyle:off + println(s"Iteration $i took ${runTime / 1000} microseconds") + // scalastyle:on + } + i += 1 + } + // scalastyle:off + println(s" Stopped after $i iterations, ${runTimes.sum / 1000000} ms") + // scalastyle:on + val best = runTimes.min + val avg = runTimes.sum / runTimes.size + Result(avg / 1000000.0, num / (best / 1000.0), best / 1000000.0) + } } private[spark] object Benchmark { @@ -128,7 +188,7 @@ private[spark] object Benchmark { } } - case class Case(name: String, fn: Timer => Unit) + case class Case(name: String, fn: Timer => Unit, numIters: Int) case class Result(avgMs: Double, bestRate: Double, bestMs: Double) /** @@ -162,30 +222,4 @@ private[spark] object Benchmark { val osVersion = System.getProperty("os.version") s"${vmName} ${runtimeVersion} on ${osName} ${osVersion}" } - - /** - * Runs a single function `f` for iters, returning the average time the function took and - * the rate of the function. 
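For orientation, a sketch of how the reworked `Benchmark` is driven with the new warm-up and minimum-iteration knobs. This is illustrative only: the class is `private[spark]`, so real callers live inside Spark's own benchmark suites, and the case names and sizes here are made up.

```scala
// Sketch: warm up for 2s, then run each case at least 5 times or for at least 4s.
import scala.concurrent.duration._

val N = 1000000L
val benchmark = new Benchmark("sum of longs", valuesPerIteration = N,
  minNumIters = 5, warmupTime = 2.seconds, minTime = 4.seconds)

benchmark.addCase("while loop") { _ =>
  var i = 0L
  var sum = 0L
  while (i < N) { sum += i; i += 1 }
}

// numIters != 0 forces exactly that many timed iterations for this case.
benchmark.addCase("Range.sum", numIters = 10) { _ =>
  (0L until N).sum
}

benchmark.run()
```
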
- */ - def measure(num: Long, iters: Int, outputPerIteration: Boolean)(f: Timer => Unit): Result = { - val runTimes = ArrayBuffer[Long]() - for (i <- 0 until iters + 1) { - val timer = new Benchmark.Timer(i) - f(timer) - val runTime = timer.totalTime() - if (i > 0) { - runTimes += runTime - } - - if (outputPerIteration) { - // scalastyle:off - println(s"Iteration $i took ${runTime / 1000} microseconds") - // scalastyle:on - } - } - val best = runTimes.min - val avg = runTimes.sum / iters - Result(avg / 1000000.0, num / (best / 1000.0), best / 1000000.0) - } } - diff --git a/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala b/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala index 09e7579ae9606..9077b86f9ba1d 100644 --- a/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala +++ b/core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala @@ -29,7 +29,32 @@ private[spark] class ByteBufferOutputStream(capacity: Int) extends ByteArrayOutp def getCount(): Int = count + private[this] var closed: Boolean = false + + override def write(b: Int): Unit = { + require(!closed, "cannot write to a closed ByteBufferOutputStream") + super.write(b) + } + + override def write(b: Array[Byte], off: Int, len: Int): Unit = { + require(!closed, "cannot write to a closed ByteBufferOutputStream") + super.write(b, off, len) + } + + override def reset(): Unit = { + require(!closed, "cannot reset a closed ByteBufferOutputStream") + super.reset() + } + + override def close(): Unit = { + if (!closed) { + super.close() + closed = true + } + } + def toByteBuffer: ByteBuffer = { - return ByteBuffer.wrap(buf, 0, count) + require(closed, "can only call toByteBuffer() after ByteBufferOutputStream has been closed") + ByteBuffer.wrap(buf, 0, count) } } diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index aeab71d9df603..8861696cd97d7 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -280,7 +280,7 @@ private[spark] object JsonProtocol { ("Getting Result Time" -> taskInfo.gettingResultTime) ~ ("Finish Time" -> taskInfo.finishTime) ~ ("Failed" -> taskInfo.failed) ~ - ("Accumulables" -> JArray(taskInfo.accumulables.map(accumulableInfoToJson).toList)) + ("Accumulables" -> JArray(taskInfo.accumulables.toList.map(accumulableInfoToJson))) } def accumulableInfoToJson(accumulableInfo: AccumulableInfo): JValue = { @@ -309,11 +309,12 @@ private[spark] object JsonProtocol { case v: Int => JInt(v) case v: Long => JInt(v) // We only have 3 kind of internal accumulator types, so if it's not int or long, it must be - // the blocks accumulator, whose type is `Seq[(BlockId, BlockStatus)]` + // the blocks accumulator, whose type is `java.util.List[(BlockId, BlockStatus)]` case v => - JArray(v.asInstanceOf[Seq[(BlockId, BlockStatus)]].toList.map { case (id, status) => - ("Block ID" -> id.toString) ~ - ("Status" -> blockStatusToJson(status)) + JArray(v.asInstanceOf[java.util.List[(BlockId, BlockStatus)]].asScala.toList.map { + case (id, status) => + ("Block ID" -> id.toString) ~ + ("Status" -> blockStatusToJson(status)) }) } } else { @@ -740,7 +741,7 @@ private[spark] object JsonProtocol { val id = BlockId((blockJson \ "Block ID").extract[String]) val status = blockStatusFromJson(blockJson \ "Status") (id, status) - } + }.asJava case _ => throw new IllegalArgumentException(s"unexpected json value 
$value for " + "accumulator " + name.get) } @@ -841,7 +842,7 @@ private[spark] object JsonProtocol { val accumUpdates = Utils.jsonOption(json \ "Accumulator Updates") .map(_.extract[List[JValue]].map(accumulableInfoFromJson)) .getOrElse(taskMetricsFromJson(json \ "Metrics").accumulators().map(acc => { - acc.toInfo(Some(acc.localValue), None) + acc.toInfo(Some(acc.value), None) })) ExceptionFailure(className, description, stackTrace, fullStackTrace, None, accumUpdates) case `taskResultLost` => TaskResultLost diff --git a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala index 436c1951dee2f..79fc2e94599c7 100644 --- a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala @@ -54,7 +54,7 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { */ final def postToAll(event: E): Unit = { // JavaConverters can create a JIterableWrapper if we use asScala. - // However, this method will be called frequently. To avoid the wrapper cost, here ewe use + // However, this method will be called frequently. To avoid the wrapper cost, here we use // Java Iterator directly. val iter = listeners.iterator while (iter.hasNext) { @@ -70,7 +70,7 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { /** * Post an event to the specified listener. `onPostEvent` is guaranteed to be called in the same - * thread. + * thread for all listeners. */ protected def doPostEvent(listener: L, event: E): Unit diff --git a/core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala b/core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala index 0a3180da87987..034826c57ef1d 100644 --- a/core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala +++ b/core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala @@ -19,7 +19,6 @@ package org.apache.spark.util import java.net.{URL, URLClassLoader} import java.util.Enumeration -import java.util.concurrent.ConcurrentHashMap import scala.collection.JavaConverters._ @@ -48,32 +47,12 @@ private[spark] class ChildFirstURLClassLoader(urls: Array[URL], parent: ClassLoa private val parentClassLoader = new ParentClassLoader(parent) - /** - * Used to implement fine-grained class loading locks similar to what is done by Java 7. This - * prevents deadlock issues when using non-hierarchical class loaders. - * - * Note that due to some issues with implementing class loaders in - * Scala, Java 7's `ClassLoader.registerAsParallelCapable` method is not called. 
- */ - private val locks = new ConcurrentHashMap[String, Object]() - override def loadClass(name: String, resolve: Boolean): Class[_] = { - var lock = locks.get(name) - if (lock == null) { - val newLock = new Object() - lock = locks.putIfAbsent(name, newLock) - if (lock == null) { - lock = newLock - } - } - - lock.synchronized { - try { - super.loadClass(name, resolve) - } catch { - case e: ClassNotFoundException => - parentClassLoader.loadClass(name, resolve) - } + try { + super.loadClass(name, resolve) + } catch { + case e: ClassNotFoundException => + parentClassLoader.loadClass(name, resolve) } } diff --git a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala index bd26bfd848ff1..93ac67e5db0d7 100644 --- a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala +++ b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala @@ -170,9 +170,7 @@ private [util] class SparkShutdownHookManager { @volatile private var shuttingDown = false /** - * Install a hook to run at shutdown and run all registered hooks in order. Hadoop 1.x does not - * have `ShutdownHookManager`, so in that case we just use the JVM's `Runtime` object and hope for - * the best. + * Install a hook to run at shutdown and run all registered hooks in order. */ def install(): Unit = { val hookTask = new Runnable() { diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala index 5a6dbc830448a..d093e7bfc3dac 100644 --- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala @@ -194,4 +194,25 @@ private[spark] object ThreadUtils { throw new SparkException("Exception thrown in awaitResult: ", t) } } + + /** + * Calls [[Awaitable.result]] directly to avoid using `ForkJoinPool`'s `BlockingContext`, wraps + * and re-throws any exceptions with nice stack track. + * + * Codes running in the user's thread may be in a thread of Scala ForkJoinPool. As concurrent + * executions in ForkJoinPool may see some [[ThreadLocal]] value unexpectedly, this method + * basically prevents ForkJoinPool from running other tasks in the current waiting thread. + */ + @throws(classOf[SparkException]) + def awaitResultInForkJoinSafely[T](awaitable: Awaitable[T], atMost: Duration): T = { + try { + // `awaitPermission` is not actually used anywhere so it's safe to pass in null here. + // See SPARK-13747. + val awaitPermission = null.asInstanceOf[scala.concurrent.CanAwait] + awaitable.result(Duration.Inf)(awaitPermission) + } catch { + case NonFatal(t) => + throw new SparkException("Exception thrown in awaitResult: ", t) + } + } } diff --git a/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala b/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala index 4dcf95177aa78..f0b68f0cb7e29 100644 --- a/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala +++ b/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala @@ -89,13 +89,6 @@ private[spark] class UninterruptibleThread(name: String) extends Thread(name) { } } - /** - * Tests whether `interrupt()` has been called. - */ - override def isInterrupted: Boolean = { - super.isInterrupted || uninterruptibleLock.synchronized { shouldInterruptThread } - } - /** * Interrupt `this` thread if possible. 
If `this` is in the uninterruptible status, it won't be * interrupted until it enters into the interruptible status. diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index ea49991493fd7..b9cf721d9f54c 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -26,6 +26,8 @@ import java.nio.charset.StandardCharsets import java.nio.file.Files import java.util.{Locale, Properties, Random, UUID} import java.util.concurrent._ +import java.util.concurrent.atomic.AtomicBoolean +import java.util.zip.GZIPInputStream import javax.net.ssl.HttpsURLConnection import scala.annotation.tailrec @@ -37,6 +39,7 @@ import scala.reflect.ClassTag import scala.util.Try import scala.util.control.{ControlThrowable, NonFatal} +import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache} import com.google.common.io.{ByteStreams, Files => GFiles} import com.google.common.net.InetAddresses import org.apache.commons.lang3.SystemUtils @@ -51,8 +54,10 @@ import org.slf4j.Logger import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.{DYN_ALLOCATION_INITIAL_EXECUTORS, DYN_ALLOCATION_MIN_EXECUTORS, EXECUTOR_INSTANCES} import org.apache.spark.network.util.JavaUtils import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} +import org.apache.spark.util.logging.RollingFileAppender /** CallSite represents a place in user code. It can have a short and a long form. */ private[spark] case class CallSite(shortForm: String, longForm: String) @@ -78,6 +83,52 @@ private[spark] object Utils extends Logging { private val MAX_DIR_CREATION_ATTEMPTS: Int = 10 @volatile private var localRootDirs: Array[String] = null + /** + * The performance overhead of creating and logging strings for wide schemas can be large. To + * limit the impact, we bound the number of fields to include by default. This can be overriden + * by setting the 'spark.debug.maxToStringFields' conf in SparkEnv. + */ + val DEFAULT_MAX_TO_STRING_FIELDS = 25 + + private def maxNumToStringFields = { + if (SparkEnv.get != null) { + SparkEnv.get.conf.getInt("spark.debug.maxToStringFields", DEFAULT_MAX_TO_STRING_FIELDS) + } else { + DEFAULT_MAX_TO_STRING_FIELDS + } + } + + /** Whether we have warned about plan string truncation yet. */ + private val truncationWarningPrinted = new AtomicBoolean(false) + + /** + * Format a sequence with semantics similar to calling .mkString(). Any elements beyond + * maxNumToStringFields will be dropped and replaced by a "... N more fields" placeholder. + * + * @return the trimmed and formatted string. + */ + def truncatedString[T]( + seq: Seq[T], + start: String, + sep: String, + end: String, + maxNumFields: Int = maxNumToStringFields): String = { + if (seq.length > maxNumFields) { + if (truncationWarningPrinted.compareAndSet(false, true)) { + logWarning( + "Truncated the string representation of a plan since it was too large. This " + + "behavior can be adjusted by setting 'spark.debug.maxToStringFields' in SparkEnv.conf.") + } + val numFields = math.max(0, maxNumFields - 1) + seq.take(numFields).mkString( + start, sep, sep + "... " + (seq.length - numFields) + " more fields" + end) + } else { + seq.mkString(start, sep, end) + } + } + + /** Shorthand for calling truncatedString() without start or end strings. 
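A couple of worked cases for `truncatedString` as defined above (the default cap comes from `spark.debug.maxToStringFields`, 25 unless overridden):

```scala
// Short sequences are rendered in full, exactly like mkString.
Utils.truncatedString(Seq("a", "b", "c"), "[", ", ", "]")
// "[a, b, c]"

// Longer sequences keep maxNumFields - 1 elements and summarize the rest.
Utils.truncatedString(1 to 100, "[", ", ", "]", maxNumFields = 3)
// "[1, 2, ... 98 more fields]"  (also logs the truncation warning once per JVM)
```
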
*/ + def truncatedString[T](seq: Seq[T], sep: String): String = truncatedString(seq, "", sep, "") /** Serialize an object using Java serialization */ def serialize[T](o: T): Array[Byte] = { @@ -649,6 +700,26 @@ private[spark] object Utils extends Logging { } } + /** + * Validate that a given URI is actually a valid URL as well. + * @param uri The URI to validate + */ + @throws[MalformedURLException]("when the URI is an invalid URL") + def validateURL(uri: URI): Unit = { + Option(uri.getScheme).getOrElse("file") match { + case "http" | "https" | "ftp" => + try { + uri.toURL + } catch { + case e: MalformedURLException => + val ex = new MalformedURLException(s"URI (${uri.toString}) is not a valid URL.") + ex.initCause(e) + throw ex + } + case _ => // will not be turned into a URL anyway + } + } + /** * Get the path of a temporary directory. Spark's local directories can be configured through * multiple settings, which are used with the following precedence: @@ -776,7 +847,7 @@ private[spark] object Utils extends Logging { */ def randomizeInPlace[T](arr: Array[T], rand: Random = new Random): Array[T] = { for (i <- (arr.length - 1) to 1 by -1) { - val j = rand.nextInt(i) + val j = rand.nextInt(i + 1) val tmp = arr(j) arr(j) = arr(i) arr(i) = tmp @@ -1169,7 +1240,7 @@ private[spark] object Utils extends Logging { } /** - * Execute a block of code that evaluates to Unit, stop SparkContext is there is any uncaught + * Execute a block of code that evaluates to Unit, stop SparkContext if there is any uncaught * exception * * NOTE: This method is to be called by the driver-side components to avoid stopping the @@ -1380,14 +1451,72 @@ private[spark] object Utils extends Logging { CallSite(shortForm, longForm) } + private val UNCOMPRESSED_LOG_FILE_LENGTH_CACHE_SIZE_CONF = + "spark.worker.ui.compressedLogFileLengthCacheSize" + private val DEFAULT_UNCOMPRESSED_LOG_FILE_LENGTH_CACHE_SIZE = 100 + private var compressedLogFileLengthCache: LoadingCache[String, java.lang.Long] = null + private def getCompressedLogFileLengthCache( + sparkConf: SparkConf): LoadingCache[String, java.lang.Long] = this.synchronized { + if (compressedLogFileLengthCache == null) { + val compressedLogFileLengthCacheSize = sparkConf.getInt( + UNCOMPRESSED_LOG_FILE_LENGTH_CACHE_SIZE_CONF, + DEFAULT_UNCOMPRESSED_LOG_FILE_LENGTH_CACHE_SIZE) + compressedLogFileLengthCache = CacheBuilder.newBuilder() + .maximumSize(compressedLogFileLengthCacheSize) + .build[String, java.lang.Long](new CacheLoader[String, java.lang.Long]() { + override def load(path: String): java.lang.Long = { + Utils.getCompressedFileLength(new File(path)) + } + }) + } + compressedLogFileLengthCache + } + + /** + * Return the file length, if the file is compressed it returns the uncompressed file length. + * It also caches the uncompressed file size to avoid repeated decompression. The cache size is + * read from workerConf. + */ + def getFileLength(file: File, workConf: SparkConf): Long = { + if (file.getName.endsWith(".gz")) { + getCompressedLogFileLengthCache(workConf).get(file.getAbsolutePath) + } else { + file.length + } + } + + /** Return uncompressed file length of a compressed file. */ + private def getCompressedFileLength(file: File): Long = { + try { + // Uncompress .gz file to determine file size. 
+ var fileSize = 0L + val gzInputStream = new GZIPInputStream(new FileInputStream(file)) + val bufSize = 1024 + val buf = new Array[Byte](bufSize) + var numBytes = ByteStreams.read(gzInputStream, buf, 0, bufSize) + while (numBytes > 0) { + fileSize += numBytes + numBytes = ByteStreams.read(gzInputStream, buf, 0, bufSize) + } + fileSize + } catch { + case e: Throwable => + logError(s"Cannot get file length of ${file}", e) + throw e + } + } + /** Return a string containing part of a file from byte 'start' to 'end'. */ - def offsetBytes(path: String, start: Long, end: Long): String = { + def offsetBytes(path: String, length: Long, start: Long, end: Long): String = { val file = new File(path) - val length = file.length() val effectiveEnd = math.min(length, end) val effectiveStart = math.max(0, start) val buff = new Array[Byte]((effectiveEnd-effectiveStart).toInt) - val stream = new FileInputStream(file) + val stream = if (path.endsWith(".gz")) { + new GZIPInputStream(new FileInputStream(file)) + } else { + new FileInputStream(file) + } try { ByteStreams.skipFully(stream, effectiveStart) @@ -1403,8 +1532,8 @@ private[spark] object Utils extends Logging { * and `endIndex` is based on the cumulative size of all the files take in * the given order. See figure below for more details. */ - def offsetBytes(files: Seq[File], start: Long, end: Long): String = { - val fileLengths = files.map { _.length } + def offsetBytes(files: Seq[File], fileLengths: Seq[Long], start: Long, end: Long): String = { + assert(files.length == fileLengths.length) val startIndex = math.max(start, 0) val endIndex = math.min(end, fileLengths.sum) val fileToLength = files.zip(fileLengths).toMap @@ -1412,7 +1541,7 @@ private[spark] object Utils extends Logging { val stringBuffer = new StringBuffer((endIndex - startIndex).toInt) var sum = 0L - for (file <- files) { + files.zip(fileLengths).foreach { case (file, fileLength) => val startIndexOfFile = sum val endIndexOfFile = sum + fileToLength(file) logDebug(s"Processing file $file, " + @@ -1431,19 +1560,19 @@ private[spark] object Utils extends Logging { if (startIndex <= startIndexOfFile && endIndex >= endIndexOfFile) { // Case C: read the whole file - stringBuffer.append(offsetBytes(file.getAbsolutePath, 0, fileToLength(file))) + stringBuffer.append(offsetBytes(file.getAbsolutePath, fileLength, 0, fileToLength(file))) } else if (startIndex > startIndexOfFile && startIndex < endIndexOfFile) { // Case A and B: read from [start of required range] to [end of file / end of range] val effectiveStartIndex = startIndex - startIndexOfFile val effectiveEndIndex = math.min(endIndex - startIndexOfFile, fileToLength(file)) stringBuffer.append(Utils.offsetBytes( - file.getAbsolutePath, effectiveStartIndex, effectiveEndIndex)) + file.getAbsolutePath, fileLength, effectiveStartIndex, effectiveEndIndex)) } else if (endIndex > startIndexOfFile && endIndex < endIndexOfFile) { // Case D: read from [start of file] to [end of require range] val effectiveStartIndex = math.max(startIndex - startIndexOfFile, 0) val effectiveEndIndex = endIndex - startIndexOfFile stringBuffer.append(Utils.offsetBytes( - file.getAbsolutePath, effectiveStartIndex, effectiveEndIndex)) + file.getAbsolutePath, fileLength, effectiveStartIndex, effectiveEndIndex)) } sum += fileToLength(file) logDebug(s"After processing file $file, string built is ${stringBuffer.toString}") @@ -1638,6 +1767,21 @@ private[spark] object Utils extends Logging { count } + /** + * Generate a zipWithIndex iterator, avoid index value overflowing 
problem + * in scala's zipWithIndex + */ + def getIteratorZipWithIndex[T](iterator: Iterator[T], startIndex: Long): Iterator[(T, Long)] = { + new Iterator[(T, Long)] { + var index: Long = startIndex - 1L + def hasNext: Boolean = iterator.hasNext + def next(): (T, Long) = { + index += 1L + (iterator.next(), index) + } + } + } + /** * Creates a symlink. * @@ -1724,50 +1868,66 @@ private[spark] object Utils extends Logging { } /** - * Terminates a process waiting for at most the specified duration. Returns whether - * the process terminated. + * Terminates a process waiting for at most the specified duration. + * + * @return the process exit value if it was successfully terminated, else None */ def terminateProcess(process: Process, timeoutMs: Long): Option[Int] = { - try { - // Java8 added a new API which will more forcibly kill the process. Use that if available. - val destroyMethod = process.getClass().getMethod("destroyForcibly"); - destroyMethod.setAccessible(true) - destroyMethod.invoke(process) - } catch { - case NonFatal(e) => - if (!e.isInstanceOf[NoSuchMethodException]) { - logWarning("Exception when attempting to kill process", e) - } - process.destroy() - } + // Politely destroy first + process.destroy() + if (waitForProcess(process, timeoutMs)) { + // Successful exit Option(process.exitValue()) } else { - None + // Java 8 added a new API which will more forcibly kill the process. Use that if available. + try { + classOf[Process].getMethod("destroyForcibly").invoke(process) + } catch { + case _: NoSuchMethodException => return None // Not available; give up + case NonFatal(e) => logWarning("Exception when attempting to kill process", e) + } + // Wait, again, although this really should return almost immediately + if (waitForProcess(process, timeoutMs)) { + Option(process.exitValue()) + } else { + logWarning("Timed out waiting to forcibly kill process") + None + } } } /** * Wait for a process to terminate for at most the specified duration. - * Return whether the process actually terminated after the given timeout. + * + * @return whether the process actually terminated before the given timeout. */ def waitForProcess(process: Process, timeoutMs: Long): Boolean = { - var terminated = false - val startTime = System.currentTimeMillis - while (!terminated) { - try { - process.exitValue() - terminated = true - } catch { - case e: IllegalThreadStateException => - // Process not terminated yet - if (System.currentTimeMillis - startTime > timeoutMs) { - return false + try { + // Use Java 8 method if available + classOf[Process].getMethod("waitFor", java.lang.Long.TYPE, classOf[TimeUnit]) + .invoke(process, timeoutMs.asInstanceOf[java.lang.Long], TimeUnit.MILLISECONDS) + .asInstanceOf[Boolean] + } catch { + case _: NoSuchMethodException => + // Otherwise implement it manually + var terminated = false + val startTime = System.currentTimeMillis + while (!terminated) { + try { + process.exitValue() + terminated = true + } catch { + case e: IllegalThreadStateException => + // Process not terminated yet + if (System.currentTimeMillis - startTime > timeoutMs) { + return false + } + Thread.sleep(100) } - Thread.sleep(100) - } + } + true } - true } /** @@ -1817,7 +1977,11 @@ private[spark] object Utils extends Logging { /** Returns true if the given exception was fatal. See docs for scala.util.control.NonFatal. 
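For the `getIteratorZipWithIndex` helper added to `Utils` above, a quick sketch of why the `Long` index matters: the start index can already sit past `Int.MaxValue` without overflowing.

```scala
// Sketch: zipWithIndex with a 64-bit index, starting beyond Int.MaxValue.
val indexed = Utils.getIteratorZipWithIndex(Iterator("a", "b", "c"), Int.MaxValue + 2L)
indexed.toList
// List((a,2147483649), (b,2147483650), (c,2147483651))
```
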
*/ def isFatalError(e: Throwable): Boolean = { e match { - case NonFatal(_) | _: InterruptedException | _: NotImplementedError | _: ControlThrowable => + case NonFatal(_) | + _: InterruptedException | + _: NotImplementedError | + _: ControlThrowable | + _: LinkageError => false case _ => true @@ -2181,6 +2345,25 @@ private[spark] object Utils extends Logging { .getOrElse(UserGroupInformation.getCurrentUser().getShortUserName()) } + val EMPTY_USER_GROUPS = Set[String]() + + // Returns the groups to which the current user belongs. + def getCurrentUserGroups(sparkConf: SparkConf, username: String): Set[String] = { + val groupProviderClassName = sparkConf.get("spark.user.groups.mapping", + "org.apache.spark.security.ShellBasedGroupsMappingProvider") + if (groupProviderClassName != "") { + try { + val groupMappingServiceProvider = classForName(groupProviderClassName).newInstance. + asInstanceOf[org.apache.spark.security.GroupMappingServiceProvider] + val currentUserGroups = groupMappingServiceProvider.getGroups(username) + return currentUserGroups + } catch { + case e: Exception => logError(s"Error getting groups for user=$username", e) + } + } + EMPTY_USER_GROUPS + } + /** * Split the comma delimited string of master URLs into a list. * For instance, "spark://abc,def" becomes [spark://abc, spark://def]. @@ -2243,21 +2426,41 @@ private[spark] object Utils extends Logging { } /** - * Return whether dynamic allocation is enabled in the given conf - * Dynamic allocation and explicitly setting the number of executors are inherently - * incompatible. In environments where dynamic allocation is turned on by default, - * the latter should override the former (SPARK-9092). + * Return whether dynamic allocation is enabled in the given conf. */ def isDynamicAllocationEnabled(conf: SparkConf): Boolean = { - val numExecutor = conf.getInt("spark.executor.instances", 0) val dynamicAllocationEnabled = conf.getBoolean("spark.dynamicAllocation.enabled", false) - if (numExecutor != 0 && dynamicAllocationEnabled) { - logWarning("Dynamic Allocation and num executors both set, thus dynamic allocation disabled.") - } - numExecutor == 0 && dynamicAllocationEnabled && + dynamicAllocationEnabled && (!isLocalMaster(conf) || conf.getBoolean("spark.dynamicAllocation.testing", false)) } + /** + * Return the initial number of executors for dynamic allocation. 
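The initial-executor rule implemented by `getDynamicAllocationInitialExecutors` just below reduces to a max over three settings; as a worked example (conf keys as defined in `org.apache.spark.internal.config`):

```scala
// Sketch: max(minExecutors, initialExecutors, executor.instances) = max(2, 3, 10) = 10.
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.dynamicAllocation.enabled", "true")
  .set("spark.dynamicAllocation.minExecutors", "2")
  .set("spark.dynamicAllocation.initialExecutors", "3")
  .set("spark.executor.instances", "10")

Utils.getDynamicAllocationInitialExecutors(conf)  // 10, logged along with the keys it considered
```
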
+ */ + def getDynamicAllocationInitialExecutors(conf: SparkConf): Int = { + if (conf.get(DYN_ALLOCATION_INITIAL_EXECUTORS) < conf.get(DYN_ALLOCATION_MIN_EXECUTORS)) { + logWarning(s"${DYN_ALLOCATION_INITIAL_EXECUTORS.key} less than " + + s"${DYN_ALLOCATION_MIN_EXECUTORS.key} is invalid, ignoring its setting, " + + "please update your configs.") + } + + if (conf.get(EXECUTOR_INSTANCES).getOrElse(0) < conf.get(DYN_ALLOCATION_MIN_EXECUTORS)) { + logWarning(s"${EXECUTOR_INSTANCES.key} less than " + + s"${DYN_ALLOCATION_MIN_EXECUTORS.key} is invalid, ignoring its setting, " + + "please update your configs.") + } + + val initialExecutors = Seq( + conf.get(DYN_ALLOCATION_MIN_EXECUTORS), + conf.get(DYN_ALLOCATION_INITIAL_EXECUTORS), + conf.get(EXECUTOR_INSTANCES).getOrElse(0)).max + + logInfo(s"Using initial executors = $initialExecutors, max of " + + s"${DYN_ALLOCATION_INITIAL_EXECUTORS.key}, ${DYN_ALLOCATION_MIN_EXECUTORS.key} and " + + s"${EXECUTOR_INSTANCES.key}") + initialExecutors + } + def tryWithResource[R <: Closeable, T](createResource: => R)(f: R => T): T = { val resource = createResource try f.apply(resource) finally resource.close() @@ -2286,6 +2489,31 @@ private[spark] object Utils extends Logging { log.info(s"Started daemon with process name: ${Utils.getProcessName()}") SignalUtils.registerLogger(log) } + + /** + * Unions two comma-separated lists of files and filters out empty strings. + */ + def unionFileLists(leftList: Option[String], rightList: Option[String]): Set[String] = { + var allFiles = Set[String]() + leftList.foreach { value => allFiles ++= value.split(",") } + rightList.foreach { value => allFiles ++= value.split(",") } + allFiles.filter { _.nonEmpty } + } + + /** + * In YARN mode this method returns a union of the jar files pointed by "spark.jars" and the + * "spark.yarn.dist.jars" properties, while in other modes it returns the jar files pointed by + * only the "spark.jars" property. + */ + def getUserJars(conf: SparkConf, isShell: Boolean = false): Seq[String] = { + val sparkJars = conf.getOption("spark.jars") + if (conf.get("spark.master") == "yarn" && isShell) { + val yarnJars = conf.getOption("spark.yarn.dist.jars") + unionFileLists(sparkJars, yarnJars).toSeq + } else { + sparkJars.map(_.split(",")).map(_.filter(_.nonEmpty)).toSeq.flatten + } + } } /** @@ -2325,29 +2553,24 @@ private[spark] class RedirectThread( * the toString method. 
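The rewritten `CircularBuffer` below keeps only the last `sizeInBytes` bytes written, so `toString` reflects the tail of the stream. A tiny sketch, with the buffer size chosen only to make the wrap-around visible:

```scala
import java.nio.charset.StandardCharsets

// Sketch: a 4-byte buffer fed 5 bytes keeps the last 4.
val buf = new CircularBuffer(sizeInBytes = 4)
buf.write("hello".getBytes(StandardCharsets.UTF_8))  // OutputStream.write(Array[Byte]) loops over write(Int)
buf.toString  // "ello"
```
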
*/ private[spark] class CircularBuffer(sizeInBytes: Int = 10240) extends java.io.OutputStream { - var pos: Int = 0 - var buffer = new Array[Int](sizeInBytes) + private var pos: Int = 0 + private var isBufferFull = false + private val buffer = new Array[Byte](sizeInBytes) - def write(i: Int): Unit = { - buffer(pos) = i + def write(input: Int): Unit = { + buffer(pos) = input.toByte pos = (pos + 1) % buffer.length + isBufferFull = isBufferFull || (pos == 0) } override def toString: String = { - val (end, start) = buffer.splitAt(pos) - val input = new java.io.InputStream { - val iterator = (start ++ end).iterator - - def read(): Int = if (iterator.hasNext) iterator.next() else -1 - } - val reader = new BufferedReader(new InputStreamReader(input, StandardCharsets.UTF_8)) - val stringBuilder = new StringBuilder - var line = reader.readLine() - while (line != null) { - stringBuilder.append(line) - stringBuilder.append("\n") - line = reader.readLine() + if (!isBufferFull) { + return new String(buffer, 0, pos, StandardCharsets.UTF_8) } - stringBuilder.toString() + + val nonCircularBuffer = new Array[Byte](sizeInBytes) + System.arraycopy(buffer, pos, nonCircularBuffer, 0, buffer.length - pos) + System.arraycopy(buffer, 0, nonCircularBuffer, buffer.length - pos, pos) + new String(nonCircularBuffer, StandardCharsets.UTF_8) } } diff --git a/core/src/main/scala/org/apache/spark/util/VersionUtils.scala b/core/src/main/scala/org/apache/spark/util/VersionUtils.scala new file mode 100644 index 0000000000000..828153b868420 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/VersionUtils.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +/** + * Utilities for working with Spark version strings + */ +private[spark] object VersionUtils { + + private val majorMinorRegex = """^(\d+)\.(\d+)(\..*)?$""".r + + /** + * Given a Spark version string, return the major version number. + * E.g., for 2.0.1-SNAPSHOT, return 2. + */ + def majorVersion(sparkVersion: String): Int = majorMinorVersion(sparkVersion)._1 + + /** + * Given a Spark version string, return the minor version number. + * E.g., for 2.0.1-SNAPSHOT, return 0. + */ + def minorVersion(sparkVersion: String): Int = majorMinorVersion(sparkVersion)._2 + + /** + * Given a Spark version string, return the (major version number, minor version number). + * E.g., for 2.0.1-SNAPSHOT, return (2, 0). 
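A few worked cases for `VersionUtils`; the last, malformed input shows the error path of `majorMinorVersion`, defined just below:

```scala
VersionUtils.majorVersion("2.0.1-SNAPSHOT")   // 2
VersionUtils.minorVersion("2.0.1-SNAPSHOT")   // 0
VersionUtils.majorMinorVersion("2.1.0")       // (2, 1)
VersionUtils.majorMinorVersion("2")           // throws IllegalArgumentException: no minor version
```
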
+ */ + def majorMinorVersion(sparkVersion: String): (Int, Int) = { + majorMinorRegex.findFirstMatchIn(sparkVersion) match { + case Some(m) => + (m.group(1).toInt, m.group(2).toInt) + case None => + throw new IllegalArgumentException(s"Spark tried to parse '$sparkVersion' as a Spark" + + s" version string, but it could not find the major and minor version numbers.") + } + } +} diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index fc71f8365cd18..6ddc72afde270 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -375,14 +375,14 @@ class ExternalAppendOnlyMap[K, V, C]( /** * Return true if there exists an input stream that still has unvisited pairs. */ - override def hasNext: Boolean = mergeHeap.length > 0 + override def hasNext: Boolean = mergeHeap.nonEmpty /** * Select a key with the minimum hash, then combine all values with the same key from all * input streams. */ override def next(): (K, C) = { - if (mergeHeap.length == 0) { + if (mergeHeap.isEmpty) { throw new NoSuchElementException } // Select a key from the StreamBuffer that holds the lowest key hash @@ -397,7 +397,7 @@ class ExternalAppendOnlyMap[K, V, C]( // For all other streams that may have this key (i.e. have the same minimum key hash), // merge in the corresponding value (if any) from that stream val mergedBuffers = ArrayBuffer[StreamBuffer](minBuffer) - while (mergeHeap.length > 0 && mergeHeap.head.minKeyHash == minHash) { + while (mergeHeap.nonEmpty && mergeHeap.head.minKeyHash == minHash) { val newBuffer = mergeHeap.dequeue() minCombiner = mergeIfKeyExists(minKey, minCombiner, newBuffer) mergedBuffers += newBuffer diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 4067acee738ed..6ea7307c3c6e7 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -622,7 +622,9 @@ private[spark] class ExternalSorter[K, V, C]( val ds = deserializeStream deserializeStream = null fileStream = null - ds.close() + if (ds != null) { + ds.close() + } // NOTE: We don't do file.delete() here because that is done in ExternalSorter.stop(). // This should also be fixed in ExternalAppendOnlyMap. 
} diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala index aee6399eb0c8c..8183f825592c0 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala @@ -41,7 +41,7 @@ private[spark] abstract class Spillable[C](taskMemoryManager: TaskMemoryManager) protected def forceSpill(): Boolean // Number of elements read from input since last spill - @volatile protected def elementsRead: Long = _elementsRead + protected def elementsRead: Long = _elementsRead // Called by subclasses every time a record is read // It's used for checking spilling frequency @@ -83,7 +83,7 @@ private[spark] abstract class Spillable[C](taskMemoryManager: TaskMemoryManager) if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) { // Claim up to double our current memory from the shuffle memory pool val amountToRequest = 2 * currentMemory - myMemoryThreshold - val granted = acquireOnHeapMemory(amountToRequest) + val granted = acquireMemory(amountToRequest) myMemoryThreshold += granted // If we were granted too little memory to grow further (either tryToAcquire returned 0, // or we already had more memory than myMemoryThreshold), spill the current collection @@ -112,7 +112,6 @@ private[spark] abstract class Spillable[C](taskMemoryManager: TaskMemoryManager) if (!isSpilled) { 0L } else { - _elementsRead = 0 val freeMemory = myMemoryThreshold - initialMemoryThreshold _memoryBytesSpilled += freeMemory releaseMemory() @@ -132,7 +131,7 @@ private[spark] abstract class Spillable[C](taskMemoryManager: TaskMemoryManager) * Release our memory back to the execution pool so that other tasks can grab it. */ def releaseMemory(): Unit = { - freeOnHeapMemory(myMemoryThreshold - initialMemoryThreshold) + freeMemory(myMemoryThreshold - initialMemoryThreshold) myMemoryThreshold = initialMemoryThreshold } diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala index fb4706e78d38f..89b0874e3865a 100644 --- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala @@ -31,14 +31,13 @@ import org.apache.spark.storage.StorageUtils * Read-only byte buffer which is physically stored as multiple chunks rather than a single * contiguous array. * - * @param chunks an array of [[ByteBuffer]]s. Each buffer in this array must be non-empty and have - * position == 0. Ownership of these buffers is transferred to the ChunkedByteBuffer, - * so if these buffers may also be used elsewhere then the caller is responsible for - * copying them as needed. + * @param chunks an array of [[ByteBuffer]]s. Each buffer in this array must have position == 0. + * Ownership of these buffers is transferred to the ChunkedByteBuffer, so if these + * buffers may also be used elsewhere then the caller is responsible for copying + * them as needed. 
*/ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) { require(chunks != null, "chunks must not be null") - require(chunks.forall(_.limit() > 0), "chunks must be non-empty") require(chunks.forall(_.position() == 0), "chunks' positions must be 0") private[this] var disposed: Boolean = false diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStream.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStream.scala index 67b50d1e70437..a625b3289538a 100644 --- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStream.scala +++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStream.scala @@ -49,10 +49,19 @@ private[spark] class ChunkedByteBufferOutputStream( */ private[this] var position = chunkSize private[this] var _size = 0 + private[this] var closed: Boolean = false def size: Long = _size + override def close(): Unit = { + if (!closed) { + super.close() + closed = true + } + } + override def write(b: Int): Unit = { + require(!closed, "cannot write to a closed ChunkedByteBufferOutputStream") allocateNewChunkIfNeeded() chunks(lastChunkIndex).put(b.toByte) position += 1 @@ -60,6 +69,7 @@ private[spark] class ChunkedByteBufferOutputStream( } override def write(bytes: Array[Byte], off: Int, len: Int): Unit = { + require(!closed, "cannot write to a closed ChunkedByteBufferOutputStream") var written = 0 while (written < len) { allocateNewChunkIfNeeded() @@ -73,7 +83,6 @@ private[spark] class ChunkedByteBufferOutputStream( @inline private def allocateNewChunkIfNeeded(): Unit = { - require(!toChunkedByteBufferWasCalled, "cannot write after toChunkedByteBuffer() is called") if (position == chunkSize) { chunks += allocator(chunkSize) lastChunkIndex += 1 @@ -82,6 +91,7 @@ private[spark] class ChunkedByteBufferOutputStream( } def toChunkedByteBuffer: ChunkedByteBuffer = { + require(closed, "cannot call toChunkedByteBuffer() unless close() has been called") require(!toChunkedByteBufferWasCalled, "toChunkedByteBuffer() can only be called once") toChunkedByteBufferWasCalled = true if (lastChunkIndex == -1) { diff --git a/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala b/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala index a0eb05c7c0e82..5d8cec8447b53 100644 --- a/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala +++ b/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala @@ -17,9 +17,11 @@ package org.apache.spark.util.logging -import java.io.{File, FileFilter, InputStream} +import java.io._ +import java.util.zip.GZIPOutputStream import com.google.common.io.Files +import org.apache.commons.io.IOUtils import org.apache.spark.SparkConf @@ -45,6 +47,7 @@ private[spark] class RollingFileAppender( import RollingFileAppender._ private val maxRetainedFiles = conf.getInt(RETAINED_FILES_PROPERTY, -1) + private val enableCompression = conf.getBoolean(ENABLE_COMPRESSION, false) /** Stop the appender */ override def stop() { @@ -76,6 +79,33 @@ private[spark] class RollingFileAppender( } } + // Roll the log file and compress if enableCompression is true. 
+ private def rotateFile(activeFile: File, rolloverFile: File): Unit = { + if (enableCompression) { + val gzFile = new File(rolloverFile.getAbsolutePath + GZIP_LOG_SUFFIX) + var gzOutputStream: GZIPOutputStream = null + var inputStream: InputStream = null + try { + inputStream = new FileInputStream(activeFile) + gzOutputStream = new GZIPOutputStream(new FileOutputStream(gzFile)) + IOUtils.copy(inputStream, gzOutputStream) + inputStream.close() + gzOutputStream.close() + activeFile.delete() + } finally { + IOUtils.closeQuietly(inputStream) + IOUtils.closeQuietly(gzOutputStream) + } + } else { + Files.move(activeFile, rolloverFile) + } + } + + // Check if the rollover file already exists. + private def rolloverFileExist(file: File): Boolean = { + file.exists || new File(file.getAbsolutePath + GZIP_LOG_SUFFIX).exists + } + /** Move the active log file to a new rollover file */ private def moveFile() { val rolloverSuffix = rollingPolicy.generateRolledOverFileSuffix() @@ -83,8 +113,8 @@ private[spark] class RollingFileAppender( activeFile.getParentFile, activeFile.getName + rolloverSuffix).getAbsoluteFile logDebug(s"Attempting to rollover file $activeFile to file $rolloverFile") if (activeFile.exists) { - if (!rolloverFile.exists) { - Files.move(activeFile, rolloverFile) + if (!rolloverFileExist(rolloverFile)) { + rotateFile(activeFile, rolloverFile) logInfo(s"Rolled over $activeFile to $rolloverFile") } else { // In case the rollover file name clashes, make a unique file name. @@ -97,11 +127,11 @@ private[spark] class RollingFileAppender( altRolloverFile = new File(activeFile.getParent, s"${activeFile.getName}$rolloverSuffix--$i").getAbsoluteFile i += 1 - } while (i < 10000 && altRolloverFile.exists) + } while (i < 10000 && rolloverFileExist(altRolloverFile)) logWarning(s"Rollover file $rolloverFile already exists, " + s"rolled over $activeFile to file $altRolloverFile") - Files.move(activeFile, altRolloverFile) + rotateFile(activeFile, altRolloverFile) } } else { logWarning(s"File $activeFile does not exist") @@ -142,6 +172,9 @@ private[spark] object RollingFileAppender { val SIZE_DEFAULT = (1024 * 1024).toString val RETAINED_FILES_PROPERTY = "spark.executor.logs.rolling.maxRetainedFiles" val DEFAULT_BUFFER_SIZE = 8192 + val ENABLE_COMPRESSION = "spark.executor.logs.rolling.enableCompression" + + val GZIP_LOG_SUFFIX = ".gz" /** * Get the sorted list of rolled over files. 
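Putting the new knob together with the pre-existing rolling configuration, enabling compressed rollover might look like the sketch below. Only `enableCompression` is introduced by this patch; the other keys are existing executor-log-rolling confs.

```scala
// Sketch: size-based rolling with gzip compression of rolled files.
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.executor.logs.rolling.strategy", "size")
  .set("spark.executor.logs.rolling.maxSize", (1024 * 1024).toString)
  .set("spark.executor.logs.rolling.maxRetainedFiles", "5")
  .set("spark.executor.logs.rolling.enableCompression", "true")
// Rolled files gain a ".gz" suffix; getSortedRolledOverFiles strips that suffix when sorting,
// so compressed and uncompressed rollovers interleave in the right order.
```
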
This assumes that the all the rolled @@ -158,6 +191,6 @@ private[spark] object RollingFileAppender { val file = new File(directory, activeFileName).getAbsoluteFile if (file.exists) Some(file) else None } - rolledOverFiles ++ activeFile + rolledOverFiles.sortBy(_.getName.stripSuffix(GZIP_LOG_SUFFIX)) ++ activeFile } } diff --git a/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala b/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala index 6e80db2f51f9c..5c4238c0381a1 100644 --- a/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala +++ b/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala @@ -113,7 +113,7 @@ private[spark] class SizeBasedRollingPolicy( /** Should rollover if the next set of bytes is going to exceed the size limit */ def shouldRollover(bytesToBeWritten: Long): Boolean = { - logInfo(s"$bytesToBeWritten + $bytesWrittenSinceRollover > $rolloverSizeBytes") + logDebug(s"$bytesToBeWritten + $bytesWrittenSinceRollover > $rolloverSizeBytes") bytesToBeWritten + bytesWrittenSinceRollover > rolloverSizeBytes } diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index 04f92d60167d8..7bac0683212b3 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -70,6 +70,7 @@ import org.apache.spark.rdd.RDD; import org.apache.spark.serializer.KryoSerializer; import org.apache.spark.storage.StorageLevel; +import org.apache.spark.util.LongAccumulator; import org.apache.spark.util.StatCounter; // The test suite itself is Serializable so that anonymous Function implementations can be @@ -287,7 +288,7 @@ public Integer call(Tuple2 t) { @Test public void foreach() { - final Accumulator accum = sc.accumulator(0); + final LongAccumulator accum = sc.sc().longAccumulator(); JavaRDD rdd = sc.parallelize(Arrays.asList("Hello", "World")); rdd.foreach(new VoidFunction() { @Override @@ -300,7 +301,7 @@ public void call(String s) { @Test public void foreachPartition() { - final Accumulator accum = sc.accumulator(0); + final LongAccumulator accum = sc.sc().longAccumulator(); JavaRDD rdd = sc.parallelize(Arrays.asList("Hello", "World")); rdd.foreachPartition(new VoidFunction>() { @Override @@ -1377,6 +1378,7 @@ public Iterator call(Iterator i, Iterator s) { assertEquals("[3, 2, 3, 2]", sizes.collect().toString()); } + @SuppressWarnings("deprecation") @Test public void accumulators() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); diff --git a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java index 3e47bfc274cb1..8ca54b24d82ef 100644 --- a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java +++ b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java @@ -94,7 +94,7 @@ public void testChildProcLauncher() throws Exception { SparkLauncher launcher = new SparkLauncher(env) .setSparkHome(System.getProperty("spark.test.home")) .setMaster("local") - .setAppResource("spark-internal") + .setAppResource(SparkLauncher.NO_RESOURCE) .addSparkArg(opts.CONF, String.format("%s=-Dfoo=ShouldBeOverriddenBelow", SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS)) .setConf(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS, diff --git a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java index 
127789b632b44..ad755529dec64 100644 --- a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java +++ b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java @@ -34,7 +34,8 @@ public void leakedPageMemoryIsDetected() { Long.MAX_VALUE, 1), 0); - manager.allocatePage(4096, null); // leak memory + final MemoryConsumer c = new TestMemoryConsumer(manager); + manager.allocatePage(4096, c); // leak memory Assert.assertEquals(4096, manager.getMemoryConsumptionForThisTask()); Assert.assertEquals(4096, manager.cleanUpAllAllocatedMemory()); } @@ -45,7 +46,8 @@ public void encodePageNumberAndOffsetOffHeap() { .set("spark.memory.offHeap.enabled", "true") .set("spark.memory.offHeap.size", "1000"); final TaskMemoryManager manager = new TaskMemoryManager(new TestMemoryManager(conf), 0); - final MemoryBlock dataPage = manager.allocatePage(256, null); + final MemoryConsumer c = new TestMemoryConsumer(manager, MemoryMode.OFF_HEAP); + final MemoryBlock dataPage = manager.allocatePage(256, c); // In off-heap mode, an offset is an absolute address that may require more than 51 bits to // encode. This test exercises that corner-case: final long offset = ((1L << TaskMemoryManager.OFFSET_BITS) + 10); @@ -58,7 +60,8 @@ public void encodePageNumberAndOffsetOffHeap() { public void encodePageNumberAndOffsetOnHeap() { final TaskMemoryManager manager = new TaskMemoryManager( new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")), 0); - final MemoryBlock dataPage = manager.allocatePage(256, null); + final MemoryConsumer c = new TestMemoryConsumer(manager, MemoryMode.ON_HEAP); + final MemoryBlock dataPage = manager.allocatePage(256, c); final long encodedAddress = manager.encodePageNumberAndOffset(dataPage, 64); Assert.assertEquals(dataPage.getBaseObject(), manager.getPage(encodedAddress)); Assert.assertEquals(64, manager.getOffsetInPage(encodedAddress)); @@ -106,6 +109,25 @@ public void cooperativeSpilling() { Assert.assertEquals(0, manager.cleanUpAllAllocatedMemory()); } + @Test + public void shouldNotForceSpillingInDifferentModes() { + final TestMemoryManager memoryManager = new TestMemoryManager(new SparkConf()); + memoryManager.limit(100); + final TaskMemoryManager manager = new TaskMemoryManager(memoryManager, 0); + + TestMemoryConsumer c1 = new TestMemoryConsumer(manager, MemoryMode.ON_HEAP); + TestMemoryConsumer c2 = new TestMemoryConsumer(manager, MemoryMode.OFF_HEAP); + c1.use(80); + Assert.assertEquals(80, c1.getUsed()); + c2.use(80); + Assert.assertEquals(20, c2.getUsed()); // not enough memory + Assert.assertEquals(80, c1.getUsed()); // not spilled + + c2.use(10); + Assert.assertEquals(10, c2.getUsed()); // spilled + Assert.assertEquals(80, c1.getUsed()); // not spilled + } + @Test public void offHeapConfigurationBackwardsCompatibility() { // Tests backwards-compatibility with the old `spark.unsafe.offHeap` configuration, which diff --git a/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java b/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java index e6e16fff80401..db91329c94cb6 100644 --- a/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java +++ b/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java @@ -20,8 +20,11 @@ import java.io.IOException; public class TestMemoryConsumer extends MemoryConsumer { + public TestMemoryConsumer(TaskMemoryManager memoryManager, MemoryMode mode) { + super(memoryManager, 1024L, mode); + } public TestMemoryConsumer(TaskMemoryManager memoryManager) { - 
super(memoryManager); + this(memoryManager, MemoryMode.ON_HEAP); } @Override @@ -32,19 +35,13 @@ public long spill(long size, MemoryConsumer trigger) throws IOException { } void use(long size) { - long got = taskMemoryManager.acquireExecutionMemory( - size, - taskMemoryManager.tungstenMemoryMode, - this); + long got = taskMemoryManager.acquireExecutionMemory(size, this); used += got; } void free(long size) { used -= size; - taskMemoryManager.releaseExecutionMemory( - size, - taskMemoryManager.tungstenMemoryMode, - this); + taskMemoryManager.releaseExecutionMemory(size, this); } } diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java index fe5abc5c23049..354efe18dbde7 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java @@ -22,8 +22,7 @@ import org.junit.Test; import org.apache.spark.SparkConf; -import org.apache.spark.memory.TestMemoryManager; -import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.memory.*; import org.apache.spark.unsafe.memory.MemoryBlock; import static org.apache.spark.shuffle.sort.PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES; @@ -38,8 +37,9 @@ public void heap() throws IOException { final SparkConf conf = new SparkConf().set("spark.memory.offHeap.enabled", "false"); final TaskMemoryManager memoryManager = new TaskMemoryManager(new TestMemoryManager(conf), 0); - final MemoryBlock page0 = memoryManager.allocatePage(128, null); - final MemoryBlock page1 = memoryManager.allocatePage(128, null); + final MemoryConsumer c = new TestMemoryConsumer(memoryManager, MemoryMode.ON_HEAP); + final MemoryBlock page0 = memoryManager.allocatePage(128, c); + final MemoryBlock page1 = memoryManager.allocatePage(128, c); final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1, page1.getBaseOffset() + 42); PackedRecordPointer packedPointer = new PackedRecordPointer(); @@ -59,8 +59,9 @@ public void offHeap() throws IOException { .set("spark.memory.offHeap.size", "10000"); final TaskMemoryManager memoryManager = new TaskMemoryManager(new TestMemoryManager(conf), 0); - final MemoryBlock page0 = memoryManager.allocatePage(128, null); - final MemoryBlock page1 = memoryManager.allocatePage(128, null); + final MemoryConsumer c = new TestMemoryConsumer(memoryManager, MemoryMode.OFF_HEAP); + final MemoryBlock page0 = memoryManager.allocatePage(128, c); + final MemoryBlock page1 = memoryManager.allocatePage(128, c); final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1, page1.getBaseOffset() + 42); PackedRecordPointer packedPointer = new PackedRecordPointer(); diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java index 278a827644db7..694352ee2af44 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java @@ -26,6 +26,7 @@ import org.apache.spark.HashPartitioner; import org.apache.spark.SparkConf; +import org.apache.spark.memory.MemoryConsumer; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.memory.TestMemoryConsumer; import org.apache.spark.memory.TestMemoryManager; @@ -71,7 +72,8 @@ public void testBasicSorting() throws Exception { final 
SparkConf conf = new SparkConf().set("spark.memory.offHeap.enabled", "false"); final TaskMemoryManager memoryManager = new TaskMemoryManager(new TestMemoryManager(conf), 0); - final MemoryBlock dataPage = memoryManager.allocatePage(2048, null); + final MemoryConsumer c = new TestMemoryConsumer(memoryManager); + final MemoryBlock dataPage = memoryManager.allocatePage(2048, c); final Object baseObject = dataPage.getBaseObject(); final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter( consumer, 4, shouldUseRadixSort()); diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index f9dc20d8b751b..7dd61f85abefd 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -21,12 +21,15 @@ import java.nio.ByteBuffer; import java.util.*; -import scala.*; +import scala.Option; +import scala.Product2; +import scala.Tuple2; +import scala.Tuple2$; import scala.collection.Iterator; import scala.runtime.AbstractFunction1; -import com.google.common.collect.Iterators; import com.google.common.collect.HashMultiset; +import com.google.common.collect.Iterators; import com.google.common.io.ByteStreams; import org.junit.After; import org.junit.Before; @@ -35,29 +38,33 @@ import org.mockito.MockitoAnnotations; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.greaterThan; -import static org.hamcrest.Matchers.lessThan; -import static org.junit.Assert.*; -import static org.mockito.Answers.RETURNS_SMART_NULLS; -import static org.mockito.Mockito.*; -import org.apache.spark.*; +import org.apache.spark.HashPartitioner; +import org.apache.spark.ShuffleDependency; +import org.apache.spark.SparkConf; +import org.apache.spark.TaskContext; +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.executor.TaskMetrics; import org.apache.spark.io.CompressionCodec$; import org.apache.spark.io.LZ4CompressionCodec; import org.apache.spark.io.LZFCompressionCodec; import org.apache.spark.io.SnappyCompressionCodec; -import org.apache.spark.executor.ShuffleWriteMetrics; -import org.apache.spark.executor.TaskMetrics; +import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.memory.TestMemoryManager; import org.apache.spark.network.util.LimitedInputStream; -import org.apache.spark.serializer.*; import org.apache.spark.scheduler.MapStatus; +import org.apache.spark.serializer.*; import org.apache.spark.shuffle.IndexShuffleBlockResolver; import org.apache.spark.storage.*; -import org.apache.spark.memory.TestMemoryManager; -import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.util.Utils; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.lessThan; +import static org.junit.Assert.*; +import static org.mockito.Answers.RETURNS_SMART_NULLS; +import static org.mockito.Mockito.*; + public class UnsafeShuffleWriterSuite { static final int NUM_PARTITITONS = 4; diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index 84b82f5a4742c..fc127f07c8d69 100644 --- 
a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -589,7 +589,7 @@ public void spillInIterator() throws IOException { @Test public void multipleValuesForSameKey() { BytesToBytesMap map = - new BytesToBytesMap(taskMemoryManager, blockManager, serializerManager, 1, 0.75, 1024, false); + new BytesToBytesMap(taskMemoryManager, blockManager, serializerManager, 1, 0.5, 1024, false); try { int i; for (i = 0; i < 1024; i++) { diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index 60a40cc172114..960698f4ebac6 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -49,6 +49,7 @@ import org.apache.spark.unsafe.Platform; import org.apache.spark.util.Utils; +import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.junit.Assert.*; import static org.mockito.Answers.RETURNS_SMART_NULLS; @@ -175,6 +176,7 @@ private UnsafeExternalSorter newSorter() throws IOException { prefixComparator, /* initialSize */ 1024, pageSizeBytes, + UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD, shouldUseRadixSort()); } @@ -225,6 +227,25 @@ public void testSortingEmptyArrays() throws Exception { assertSpillFilesWereCleanedUp(); } + @Test + public void testSortTimeMetric() throws Exception { + final UnsafeExternalSorter sorter = newSorter(); + long prevSortTime = sorter.getSortTimeNanos(); + assertEquals(prevSortTime, 0); + + sorter.insertRecord(null, 0, 0, 0); + sorter.spill(); + assertThat(sorter.getSortTimeNanos(), greaterThan(prevSortTime)); + prevSortTime = sorter.getSortTimeNanos(); + + sorter.spill(); // no sort needed + assertEquals(sorter.getSortTimeNanos(), prevSortTime); + + sorter.insertRecord(null, 0, 0, 0); + UnsafeSorterIterator iter = sorter.getSortedIterator(); + assertThat(sorter.getSortTimeNanos(), greaterThan(prevSortTime)); + } + @Test public void spillingOccursInResponseToMemoryPressure() throws Exception { final UnsafeExternalSorter sorter = newSorter(); @@ -379,6 +400,7 @@ public void forcedSpillingWithoutComparator() throws Exception { null, /* initialSize */ 1024, pageSizeBytes, + UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD, shouldUseRadixSort()); long[] record = new long[100]; int recordSize = record.length * 8; @@ -415,6 +437,7 @@ public void testPeakMemoryUsed() throws Exception { prefixComparator, 1024, pageSizeBytes, + UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD, shouldUseRadixSort()); // Peak memory should be monotonically increasing. 
More specifically, every time diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java index 4a2f65a0ed2b2..383c5b3b0884a 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java @@ -78,7 +78,7 @@ public void testSortingOnlyByIntegerPrefix() throws Exception { final TaskMemoryManager memoryManager = new TaskMemoryManager( new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")), 0); final TestMemoryConsumer consumer = new TestMemoryConsumer(memoryManager); - final MemoryBlock dataPage = memoryManager.allocatePage(2048, null); + final MemoryBlock dataPage = memoryManager.allocatePage(2048, consumer); final Object baseObject = dataPage.getBaseObject(); // Write the records into the data page: long position = dataPage.getBaseOffset(); diff --git a/core/src/test/resources/HistoryServerExpectations/job_list_from_multi_attempt_app_json_1__expectation.json b/core/src/test/resources/HistoryServerExpectations/job_list_from_multi_attempt_app_json_1__expectation.json index bb6bf434be90b..c108fa61a4318 100644 --- a/core/src/test/resources/HistoryServerExpectations/job_list_from_multi_attempt_app_json_1__expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/job_list_from_multi_attempt_app_json_1__expectation.json @@ -6,7 +6,7 @@ "numTasks" : 8, "numActiveTasks" : 0, "numCompletedTasks" : 8, - "numSkippedTasks" : 8, + "numSkippedTasks" : 0, "numFailedTasks" : 0, "numActiveStages" : 0, "numCompletedStages" : 1, diff --git a/core/src/test/resources/HistoryServerExpectations/job_list_from_multi_attempt_app_json_2__expectation.json b/core/src/test/resources/HistoryServerExpectations/job_list_from_multi_attempt_app_json_2__expectation.json index bb6bf434be90b..c108fa61a4318 100644 --- a/core/src/test/resources/HistoryServerExpectations/job_list_from_multi_attempt_app_json_2__expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/job_list_from_multi_attempt_app_json_2__expectation.json @@ -6,7 +6,7 @@ "numTasks" : 8, "numActiveTasks" : 0, "numCompletedTasks" : 8, - "numSkippedTasks" : 8, + "numSkippedTasks" : 0, "numFailedTasks" : 0, "numActiveStages" : 0, "numCompletedStages" : 1, diff --git a/core/src/test/resources/HistoryServerExpectations/job_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/job_list_json_expectation.json index 1583e5ddef565..3d7407004d262 100644 --- a/core/src/test/resources/HistoryServerExpectations/job_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/job_list_json_expectation.json @@ -6,7 +6,7 @@ "numTasks" : 8, "numActiveTasks" : 0, "numCompletedTasks" : 8, - "numSkippedTasks" : 8, + "numSkippedTasks" : 0, "numFailedTasks" : 0, "numActiveStages" : 0, "numCompletedStages" : 1, @@ -20,7 +20,7 @@ "numTasks" : 16, "numActiveTasks" : 0, "numCompletedTasks" : 15, - "numSkippedTasks" : 15, + "numSkippedTasks" : 0, "numFailedTasks" : 1, "numActiveStages" : 0, "numCompletedStages" : 1, @@ -34,7 +34,7 @@ "numTasks" : 8, "numActiveTasks" : 0, "numCompletedTasks" : 8, - "numSkippedTasks" : 8, + "numSkippedTasks" : 0, "numFailedTasks" : 0, "numActiveStages" : 0, "numCompletedStages" : 1, diff --git 
a/core/src/test/resources/HistoryServerExpectations/limit_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/limit_app_list_json_expectation.json new file mode 100644 index 0000000000000..9165f549d7d25 --- /dev/null +++ b/core/src/test/resources/HistoryServerExpectations/limit_app_list_json_expectation.json @@ -0,0 +1,67 @@ +[ { + "id" : "local-1430917381534", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2015-05-06T13:03:00.893GMT", + "endTime" : "2015-05-06T13:03:11.398GMT", + "lastUpdated" : "", + "duration" : 10505, + "sparkUser" : "irashid", + "completed" : true, + "startTimeEpoch" : 1430917380893, + "endTimeEpoch" : 1430917391398, + "lastUpdatedEpoch" : 0 + } ] +}, { + "id" : "local-1430917381535", + "name" : "Spark shell", + "attempts" : [ { + "attemptId" : "2", + "startTime" : "2015-05-06T13:03:00.893GMT", + "endTime" : "2015-05-06T13:03:00.950GMT", + "lastUpdated" : "", + "duration" : 57, + "sparkUser" : "irashid", + "completed" : true, + "startTimeEpoch" : 1430917380893, + "endTimeEpoch" : 1430917380950, + "lastUpdatedEpoch" : 0 + }, { + "attemptId" : "1", + "startTime" : "2015-05-06T13:03:00.880GMT", + "endTime" : "2015-05-06T13:03:00.890GMT", + "lastUpdated" : "", + "duration" : 10, + "sparkUser" : "irashid", + "completed" : true, + "startTimeEpoch" : 1430917380880, + "endTimeEpoch" : 1430917380890, + "lastUpdatedEpoch" : 0 + } ] +}, { + "id" : "local-1426533911241", + "name" : "Spark shell", + "attempts" : [ { + "attemptId" : "2", + "startTime" : "2015-03-17T23:11:50.242GMT", + "endTime" : "2015-03-17T23:12:25.177GMT", + "lastUpdated" : "", + "duration" : 34935, + "sparkUser" : "irashid", + "completed" : true, + "startTimeEpoch" : 1426633910242, + "endTimeEpoch" : 1426633945177, + "lastUpdatedEpoch" : 0 + }, { + "attemptId" : "1", + "startTime" : "2015-03-16T19:25:10.242GMT", + "endTime" : "2015-03-16T19:25:45.177GMT", + "lastUpdated" : "", + "duration" : 34935, + "sparkUser" : "irashid", + "completed" : true, + "startTimeEpoch" : 1426533910242, + "endTimeEpoch" : 1426533945177, + "lastUpdatedEpoch" : 0 + } ] +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/one_job_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_job_json_expectation.json index f1f0ec885587b..10c7e1c0b36fd 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_job_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_job_json_expectation.json @@ -6,7 +6,7 @@ "numTasks" : 8, "numActiveTasks" : 0, "numCompletedTasks" : 8, - "numSkippedTasks" : 8, + "numSkippedTasks" : 0, "numFailedTasks" : 0, "numActiveStages" : 0, "numCompletedStages" : 1, diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_expectation.json index 11eec0b49c40b..96d86b7278ff1 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_expectation.json @@ -39,21 +39,21 @@ } } }, { - "taskId" : 5, - "index" : 5, + "taskId" : 1, + "index" : 1, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:06.505GMT", + "launchTime" : "2015-05-06T13:03:06.502GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 30, + "executorDeserializeTime" : 
31, "executorRunTime" : 350, "resultSize" : 2010, "jvmGcTime" : 7, - "resultSerializationTime" : 1, + "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, "inputMetrics" : { @@ -74,26 +74,26 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 3675510, + "writeTime" : 3934399, "recordsWritten" : 10 } } }, { - "taskId" : 1, - "index" : 1, + "taskId" : 5, + "index" : 5, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:06.502GMT", + "launchTime" : "2015-05-06T13:03:06.505GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 31, + "executorDeserializeTime" : 30, "executorRunTime" : 350, "resultSize" : 2010, "jvmGcTime" : 7, - "resultSerializationTime" : 0, + "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, "inputMetrics" : { @@ -114,22 +114,22 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 3934399, + "writeTime" : 3675510, "recordsWritten" : 10 } } }, { - "taskId" : 4, - "index" : 4, + "taskId" : 0, + "index" : 0, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:06.504GMT", + "launchTime" : "2015-05-06T13:03:06.494GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 31, + "executorDeserializeTime" : 32, "executorRunTime" : 349, "resultSize" : 2010, "jvmGcTime" : 7, @@ -137,7 +137,7 @@ "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, "inputMetrics" : { - "bytesRead" : 60488, + "bytesRead" : 49294, "recordsRead" : 10000 }, "outputMetrics" : { @@ -154,15 +154,15 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 83022, + "writeTime" : 3842811, "recordsWritten" : 10 } } }, { - "taskId" : 7, - "index" : 7, + "taskId" : 3, + "index" : 3, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:06.506GMT", + "launchTime" : "2015-05-06T13:03:06.504GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", @@ -173,7 +173,7 @@ "executorRunTime" : 349, "resultSize" : 2010, "jvmGcTime" : 7, - "resultSerializationTime" : 0, + "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, "inputMetrics" : { @@ -194,13 +194,13 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 2579051, + "writeTime" : 1311694, "recordsWritten" : 10 } } }, { - "taskId" : 3, - "index" : 3, + "taskId" : 4, + "index" : 4, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.504GMT", "executorId" : "driver", @@ -213,7 +213,7 @@ "executorRunTime" : 349, "resultSize" : 2010, "jvmGcTime" : 7, - "resultSerializationTime" : 2, + "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, "inputMetrics" : { @@ -234,30 +234,30 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 1311694, + "writeTime" : 83022, "recordsWritten" : 10 } } }, { - "taskId" : 0, - "index" : 0, + "taskId" : 7, + "index" : 7, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:06.494GMT", + "launchTime" : "2015-05-06T13:03:06.506GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 32, + "executorDeserializeTime" : 31, "executorRunTime" : 349, "resultSize" : 2010, "jvmGcTime" : 7, - "resultSerializationTime" : 1, + "resultSerializationTime" : 0, 
"memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, "inputMetrics" : { - "bytesRead" : 49294, + "bytesRead" : 60488, "recordsRead" : 10000 }, "outputMetrics" : { @@ -274,7 +274,7 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 3842811, + "writeTime" : 2579051, "recordsWritten" : 10 } } @@ -479,25 +479,25 @@ } } }, { - "taskId" : 16, - "index" : 16, + "taskId" : 9, + "index" : 9, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:07.001GMT", + "launchTime" : "2015-05-06T13:03:06.915GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 10, + "executorDeserializeTime" : 9, "executorRunTime" : 84, "resultSize" : 2010, - "jvmGcTime" : 5, + "jvmGcTime" : 0, "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, "inputMetrics" : { - "bytesRead" : 70564, + "bytesRead" : 60489, "recordsRead" : 10000 }, "outputMetrics" : { @@ -514,22 +514,22 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 108320, + "writeTime" : 101664, "recordsWritten" : 10 } } }, { - "taskId" : 19, - "index" : 19, + "taskId" : 16, + "index" : 16, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:07.012GMT", + "launchTime" : "2015-05-06T13:03:07.001GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 5, + "executorDeserializeTime" : 10, "executorRunTime" : 84, "resultSize" : 2010, "jvmGcTime" : 5, @@ -554,30 +554,30 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 95788, + "writeTime" : 108320, "recordsWritten" : 10 } } }, { - "taskId" : 9, - "index" : 9, + "taskId" : 19, + "index" : 19, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:06.915GMT", + "launchTime" : "2015-05-06T13:03:07.012GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 9, + "executorDeserializeTime" : 5, "executorRunTime" : 84, "resultSize" : 2010, - "jvmGcTime" : 0, + "jvmGcTime" : 5, "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, "inputMetrics" : { - "bytesRead" : 60489, + "bytesRead" : 70564, "recordsRead" : 10000 }, "outputMetrics" : { @@ -594,25 +594,25 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 101664, + "writeTime" : 95788, "recordsWritten" : 10 } } }, { - "taskId" : 20, - "index" : 20, + "taskId" : 14, + "index" : 14, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:07.014GMT", + "launchTime" : "2015-05-06T13:03:06.925GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 3, + "executorDeserializeTime" : 6, "executorRunTime" : 83, "resultSize" : 2010, - "jvmGcTime" : 5, + "jvmGcTime" : 0, "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, @@ -634,25 +634,25 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 97716, + "writeTime" : 95646, "recordsWritten" : 10 } } }, { - "taskId" : 14, - "index" : 14, + "taskId" : 20, + "index" : 20, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:06.925GMT", + "launchTime" : "2015-05-06T13:03:07.014GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", "speculative" : 
false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 6, + "executorDeserializeTime" : 3, "executorRunTime" : 83, "resultSize" : 2010, - "jvmGcTime" : 0, + "jvmGcTime" : 5, "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, @@ -674,7 +674,7 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 95646, + "writeTime" : 97716, "recordsWritten" : 10 } } diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names___runtime_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names___runtime_expectation.json index 11eec0b49c40b..96d86b7278ff1 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names___runtime_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names___runtime_expectation.json @@ -39,21 +39,21 @@ } } }, { - "taskId" : 5, - "index" : 5, + "taskId" : 1, + "index" : 1, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:06.505GMT", + "launchTime" : "2015-05-06T13:03:06.502GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 30, + "executorDeserializeTime" : 31, "executorRunTime" : 350, "resultSize" : 2010, "jvmGcTime" : 7, - "resultSerializationTime" : 1, + "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, "inputMetrics" : { @@ -74,26 +74,26 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 3675510, + "writeTime" : 3934399, "recordsWritten" : 10 } } }, { - "taskId" : 1, - "index" : 1, + "taskId" : 5, + "index" : 5, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:06.502GMT", + "launchTime" : "2015-05-06T13:03:06.505GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 31, + "executorDeserializeTime" : 30, "executorRunTime" : 350, "resultSize" : 2010, "jvmGcTime" : 7, - "resultSerializationTime" : 0, + "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, "inputMetrics" : { @@ -114,22 +114,22 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 3934399, + "writeTime" : 3675510, "recordsWritten" : 10 } } }, { - "taskId" : 4, - "index" : 4, + "taskId" : 0, + "index" : 0, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:06.504GMT", + "launchTime" : "2015-05-06T13:03:06.494GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 31, + "executorDeserializeTime" : 32, "executorRunTime" : 349, "resultSize" : 2010, "jvmGcTime" : 7, @@ -137,7 +137,7 @@ "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, "inputMetrics" : { - "bytesRead" : 60488, + "bytesRead" : 49294, "recordsRead" : 10000 }, "outputMetrics" : { @@ -154,15 +154,15 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 83022, + "writeTime" : 3842811, "recordsWritten" : 10 } } }, { - "taskId" : 7, - "index" : 7, + "taskId" : 3, + "index" : 3, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:06.506GMT", + "launchTime" : "2015-05-06T13:03:06.504GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", @@ -173,7 +173,7 @@ "executorRunTime" : 349, 
"resultSize" : 2010, "jvmGcTime" : 7, - "resultSerializationTime" : 0, + "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, "inputMetrics" : { @@ -194,13 +194,13 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 2579051, + "writeTime" : 1311694, "recordsWritten" : 10 } } }, { - "taskId" : 3, - "index" : 3, + "taskId" : 4, + "index" : 4, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.504GMT", "executorId" : "driver", @@ -213,7 +213,7 @@ "executorRunTime" : 349, "resultSize" : 2010, "jvmGcTime" : 7, - "resultSerializationTime" : 2, + "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, "inputMetrics" : { @@ -234,30 +234,30 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 1311694, + "writeTime" : 83022, "recordsWritten" : 10 } } }, { - "taskId" : 0, - "index" : 0, + "taskId" : 7, + "index" : 7, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:06.494GMT", + "launchTime" : "2015-05-06T13:03:06.506GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 32, + "executorDeserializeTime" : 31, "executorRunTime" : 349, "resultSize" : 2010, "jvmGcTime" : 7, - "resultSerializationTime" : 1, + "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, "inputMetrics" : { - "bytesRead" : 49294, + "bytesRead" : 60488, "recordsRead" : 10000 }, "outputMetrics" : { @@ -274,7 +274,7 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 3842811, + "writeTime" : 2579051, "recordsWritten" : 10 } } @@ -479,25 +479,25 @@ } } }, { - "taskId" : 16, - "index" : 16, + "taskId" : 9, + "index" : 9, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:07.001GMT", + "launchTime" : "2015-05-06T13:03:06.915GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 10, + "executorDeserializeTime" : 9, "executorRunTime" : 84, "resultSize" : 2010, - "jvmGcTime" : 5, + "jvmGcTime" : 0, "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, "inputMetrics" : { - "bytesRead" : 70564, + "bytesRead" : 60489, "recordsRead" : 10000 }, "outputMetrics" : { @@ -514,22 +514,22 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 108320, + "writeTime" : 101664, "recordsWritten" : 10 } } }, { - "taskId" : 19, - "index" : 19, + "taskId" : 16, + "index" : 16, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:07.012GMT", + "launchTime" : "2015-05-06T13:03:07.001GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 5, + "executorDeserializeTime" : 10, "executorRunTime" : 84, "resultSize" : 2010, "jvmGcTime" : 5, @@ -554,30 +554,30 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 95788, + "writeTime" : 108320, "recordsWritten" : 10 } } }, { - "taskId" : 9, - "index" : 9, + "taskId" : 19, + "index" : 19, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:06.915GMT", + "launchTime" : "2015-05-06T13:03:07.012GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 9, + "executorDeserializeTime" : 5, "executorRunTime" : 84, "resultSize" : 
2010, - "jvmGcTime" : 0, + "jvmGcTime" : 5, "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, "inputMetrics" : { - "bytesRead" : 60489, + "bytesRead" : 70564, "recordsRead" : 10000 }, "outputMetrics" : { @@ -594,25 +594,25 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 101664, + "writeTime" : 95788, "recordsWritten" : 10 } } }, { - "taskId" : 20, - "index" : 20, + "taskId" : 14, + "index" : 14, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:07.014GMT", + "launchTime" : "2015-05-06T13:03:06.925GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 3, + "executorDeserializeTime" : 6, "executorRunTime" : 83, "resultSize" : 2010, - "jvmGcTime" : 5, + "jvmGcTime" : 0, "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, @@ -634,25 +634,25 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 97716, + "writeTime" : 95646, "recordsWritten" : 10 } } }, { - "taskId" : 14, - "index" : 14, + "taskId" : 20, + "index" : 20, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:06.925GMT", + "launchTime" : "2015-05-06T13:03:07.014GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 6, + "executorDeserializeTime" : 3, "executorRunTime" : 83, "resultSize" : 2010, - "jvmGcTime" : 0, + "jvmGcTime" : 5, "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, @@ -674,7 +674,7 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 95646, + "writeTime" : 97716, "recordsWritten" : 10 } } diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names__runtime_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names__runtime_expectation.json index 9528d872ef731..e0e9e8140c717 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names__runtime_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names__runtime_expectation.json @@ -39,21 +39,21 @@ } } }, { - "taskId" : 86, - "index" : 86, + "taskId" : 41, + "index" : 41, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:07.374GMT", + "launchTime" : "2015-05-06T13:03:07.200GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 3, + "executorDeserializeTime" : 2, "executorRunTime" : 16, "resultSize" : 2065, "jvmGcTime" : 0, - "resultSerializationTime" : 1, + "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, "inputMetrics" : { @@ -74,15 +74,15 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 95848, + "writeTime" : 90765, "recordsWritten" : 10 } } }, { - "taskId" : 41, - "index" : 41, + "taskId" : 43, + "index" : 43, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:07.200GMT", + "launchTime" : "2015-05-06T13:03:07.204GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", @@ -114,22 +114,22 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 90765, + "writeTime" : 171516, "recordsWritten" : 10 } } }, { - "taskId" : 68, - "index" : 68, + "taskId" : 57, + "index" : 
57, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:07.306GMT", + "launchTime" : "2015-05-06T13:03:07.257GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 2, + "executorDeserializeTime" : 3, "executorRunTime" : 16, "resultSize" : 2065, "jvmGcTime" : 0, @@ -154,7 +154,7 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 101750, + "writeTime" : 96849, "recordsWritten" : 10 } } @@ -199,10 +199,10 @@ } } }, { - "taskId" : 43, - "index" : 43, + "taskId" : 68, + "index" : 68, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:07.204GMT", + "launchTime" : "2015-05-06T13:03:07.306GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", @@ -234,15 +234,15 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 171516, + "writeTime" : 101750, "recordsWritten" : 10 } } }, { - "taskId" : 57, - "index" : 57, + "taskId" : 86, + "index" : 86, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:07.257GMT", + "launchTime" : "2015-05-06T13:03:07.374GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", @@ -253,7 +253,7 @@ "executorRunTime" : 16, "resultSize" : 2065, "jvmGcTime" : 0, - "resultSerializationTime" : 0, + "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, "inputMetrics" : { @@ -274,15 +274,15 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 96849, + "writeTime" : 95848, "recordsWritten" : 10 } } }, { - "taskId" : 59, - "index" : 59, + "taskId" : 32, + "index" : 32, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:07.265GMT", + "launchTime" : "2015-05-06T13:03:07.148GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", @@ -314,22 +314,22 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 100753, + "writeTime" : 89603, "recordsWritten" : 10 } } }, { - "taskId" : 32, - "index" : 32, + "taskId" : 39, + "index" : 39, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:07.148GMT", + "launchTime" : "2015-05-06T13:03:07.180GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 3, + "executorDeserializeTime" : 2, "executorRunTime" : 17, "resultSize" : 2065, "jvmGcTime" : 0, @@ -354,22 +354,22 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 89603, + "writeTime" : 98748, "recordsWritten" : 10 } } }, { - "taskId" : 87, - "index" : 87, + "taskId" : 42, + "index" : 42, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:07.374GMT", + "launchTime" : "2015-05-06T13:03:07.203GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 12, + "executorDeserializeTime" : 10, "executorRunTime" : 17, "resultSize" : 2065, "jvmGcTime" : 0, @@ -394,15 +394,15 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 102159, + "writeTime" : 103713, "recordsWritten" : 10 } } }, { - "taskId" : 99, - "index" : 99, + "taskId" : 51, + "index" : 51, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:07.426GMT", + "launchTime" : "2015-05-06T13:03:07.242GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", @@ -417,7 +417,7 @@ "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, 
"inputMetrics" : { - "bytesRead" : 70565, + "bytesRead" : 70564, "recordsRead" : 10000 }, "outputMetrics" : { @@ -434,25 +434,25 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 133964, + "writeTime" : 96013, "recordsWritten" : 10 } } }, { - "taskId" : 63, - "index" : 63, + "taskId" : 59, + "index" : 59, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:07.276GMT", + "launchTime" : "2015-05-06T13:03:07.265GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 20, + "executorDeserializeTime" : 3, "executorRunTime" : 17, "resultSize" : 2065, - "jvmGcTime" : 5, + "jvmGcTime" : 0, "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, @@ -474,25 +474,25 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 102779, + "writeTime" : 100753, "recordsWritten" : 10 } } }, { - "taskId" : 90, - "index" : 90, + "taskId" : 63, + "index" : 63, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:07.385GMT", + "launchTime" : "2015-05-06T13:03:07.276GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 2, + "executorDeserializeTime" : 20, "executorRunTime" : 17, "resultSize" : 2065, - "jvmGcTime" : 0, + "jvmGcTime" : 5, "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, @@ -514,22 +514,22 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 98472, + "writeTime" : 102779, "recordsWritten" : 10 } } }, { - "taskId" : 39, - "index" : 39, + "taskId" : 87, + "index" : 87, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:07.180GMT", + "launchTime" : "2015-05-06T13:03:07.374GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 2, + "executorDeserializeTime" : 12, "executorRunTime" : 17, "resultSize" : 2065, "jvmGcTime" : 0, @@ -554,22 +554,22 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 98748, + "writeTime" : 102159, "recordsWritten" : 10 } } }, { - "taskId" : 42, - "index" : 42, + "taskId" : 90, + "index" : 90, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:07.203GMT", + "launchTime" : "2015-05-06T13:03:07.385GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 10, + "executorDeserializeTime" : 2, "executorRunTime" : 17, "resultSize" : 2065, "jvmGcTime" : 0, @@ -594,15 +594,15 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 103713, + "writeTime" : 98472, "recordsWritten" : 10 } } }, { - "taskId" : 51, - "index" : 51, + "taskId" : 99, + "index" : 99, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:07.242GMT", + "launchTime" : "2015-05-06T13:03:07.426GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", @@ -617,7 +617,7 @@ "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, "inputMetrics" : { - "bytesRead" : 70564, + "bytesRead" : 70565, "recordsRead" : 10000 }, "outputMetrics" : { @@ -634,22 +634,22 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 96013, + "writeTime" : 133964, "recordsWritten" : 10 } } }, { - "taskId" : 50, - "index" : 50, + "taskId" : 44, + "index" : 44, "attempt" 
: 0, - "launchTime" : "2015-05-06T13:03:07.240GMT", + "launchTime" : "2015-05-06T13:03:07.205GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 4, + "executorDeserializeTime" : 3, "executorRunTime" : 18, "resultSize" : 2065, "jvmGcTime" : 0, @@ -674,22 +674,22 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 90836, + "writeTime" : 98293, "recordsWritten" : 10 } } }, { - "taskId" : 53, - "index" : 53, + "taskId" : 47, + "index" : 47, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:07.244GMT", + "launchTime" : "2015-05-06T13:03:07.212GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 6, + "executorDeserializeTime" : 2, "executorRunTime" : 18, "resultSize" : 2065, "jvmGcTime" : 0, @@ -714,22 +714,22 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 92835, + "writeTime" : 103015, "recordsWritten" : 10 } } }, { - "taskId" : 44, - "index" : 44, + "taskId" : 50, + "index" : 50, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:07.205GMT", + "launchTime" : "2015-05-06T13:03:07.240GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 3, + "executorDeserializeTime" : 4, "executorRunTime" : 18, "resultSize" : 2065, "jvmGcTime" : 0, @@ -754,25 +754,25 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 98293, + "writeTime" : 90836, "recordsWritten" : 10 } } }, { - "taskId" : 80, - "index" : 80, + "taskId" : 52, + "index" : 52, "attempt" : 0, - "launchTime" : "2015-05-06T13:03:07.341GMT", + "launchTime" : "2015-05-06T13:03:07.243GMT", "executorId" : "driver", "host" : "localhost", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 13, + "executorDeserializeTime" : 5, "executorRunTime" : 18, "resultSize" : 2065, - "jvmGcTime" : 5, + "jvmGcTime" : 0, "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, @@ -794,7 +794,7 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, - "writeTime" : 98069, + "writeTime" : 89664, "recordsWritten" : 10 } } diff --git a/core/src/test/resources/HistoryServerExpectations/succeeded_failed_job_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/succeeded_failed_job_list_json_expectation.json index 1583e5ddef565..3d7407004d262 100644 --- a/core/src/test/resources/HistoryServerExpectations/succeeded_failed_job_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/succeeded_failed_job_list_json_expectation.json @@ -6,7 +6,7 @@ "numTasks" : 8, "numActiveTasks" : 0, "numCompletedTasks" : 8, - "numSkippedTasks" : 8, + "numSkippedTasks" : 0, "numFailedTasks" : 0, "numActiveStages" : 0, "numCompletedStages" : 1, @@ -20,7 +20,7 @@ "numTasks" : 16, "numActiveTasks" : 0, "numCompletedTasks" : 15, - "numSkippedTasks" : 15, + "numSkippedTasks" : 0, "numFailedTasks" : 1, "numActiveStages" : 0, "numCompletedStages" : 1, @@ -34,7 +34,7 @@ "numTasks" : 8, "numActiveTasks" : 0, "numCompletedTasks" : 8, - "numSkippedTasks" : 8, + "numSkippedTasks" : 0, "numFailedTasks" : 0, "numActiveStages" : 0, "numCompletedStages" : 1, diff --git 
a/core/src/test/resources/HistoryServerExpectations/succeeded_job_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/succeeded_job_list_json_expectation.json index c232c98323755..6a9bafd6b2191 100644 --- a/core/src/test/resources/HistoryServerExpectations/succeeded_job_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/succeeded_job_list_json_expectation.json @@ -6,7 +6,7 @@ "numTasks" : 8, "numActiveTasks" : 0, "numCompletedTasks" : 8, - "numSkippedTasks" : 8, + "numSkippedTasks" : 0, "numFailedTasks" : 0, "numActiveStages" : 0, "numCompletedStages" : 1, @@ -20,7 +20,7 @@ "numTasks" : 8, "numActiveTasks" : 0, "numCompletedTasks" : 8, - "numSkippedTasks" : 8, + "numSkippedTasks" : 0, "numFailedTasks" : 0, "numActiveStages" : 0, "numCompletedStages" : 1, diff --git a/core/src/test/resources/TestUDTF.jar b/core/src/test/resources/TestUDTF.jar new file mode 100644 index 0000000000000..514f2d5d26fd3 Binary files /dev/null and b/core/src/test/resources/TestUDTF.jar differ diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index 9c90049715ddc..6cbd5ae5d428a 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -28,9 +28,10 @@ import scala.util.control.NonFatal import org.scalatest.Matchers import org.scalatest.exceptions.TestFailedException -import org.apache.spark.AccumulatorParam.{ListAccumulatorParam, StringAccumulatorParam} +import org.apache.spark.AccumulatorParam.StringAccumulatorParam import org.apache.spark.scheduler._ import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.util.{AccumulatorContext, AccumulatorMetadata, AccumulatorV2, LongAccumulator} class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContext { @@ -69,7 +70,7 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex // serialize and de-serialize it, to simulate sending accumulator to executor. val acc2 = ser.deserialize[LongAccumulator](ser.serialize(acc)) // value is reset on the executors - assert(acc2.localValue == 0) + assert(acc2.value == 0) assert(!acc2.isAtDriverSide) acc2.add(10) @@ -234,21 +235,6 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex acc.merge("kindness") assert(acc.value === "kindness") } - - test("list accumulator param") { - val acc = new Accumulator(Seq.empty[Int], new ListAccumulatorParam[Int], Some("numbers")) - assert(acc.value === Seq.empty[Int]) - acc.add(Seq(1, 2)) - assert(acc.value === Seq(1, 2)) - acc += Seq(3, 4) - assert(acc.value === Seq(1, 2, 3, 4)) - acc ++= Seq(5, 6) - assert(acc.value === Seq(1, 2, 3, 4, 5, 6)) - acc.merge(Seq(7, 8)) - assert(acc.value === Seq(1, 2, 3, 4, 5, 6, 7, 8)) - acc.setValue(Seq(9, 10)) - assert(acc.value === Seq(9, 10)) - } } private[spark] object AccumulatorSuite { @@ -273,7 +259,7 @@ private[spark] object AccumulatorSuite { * Make an [[AccumulableInfo]] out of an [[Accumulable]] with the intent to use the * info as an accumulator update. 
*/ - def makeInfo(a: NewAccumulator[_, _]): AccumulableInfo = a.toInfo(Some(a.localValue), None) + def makeInfo(a: AccumulatorV2[_, _]): AccumulableInfo = a.toInfo(Some(a.value), None) /** * Run one or more Spark jobs and verify that in at least one job the peak execution memory diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index 69ff6c7c28ee9..6724af952505f 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -32,6 +32,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.rdd.{RDD, ReliableRDDCheckpointData} import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.storage._ +import org.apache.spark.util.Utils /** * An abstract base class for context cleaner tests, which sets up a context with a config @@ -206,8 +207,7 @@ class ContextCleanerSuite extends ContextCleanerSuiteBase { } test("automatically cleanup normal checkpoint") { - val checkpointDir = java.io.File.createTempFile("temp", "") - checkpointDir.deleteOnExit() + val checkpointDir = Utils.createTempDir() checkpointDir.delete() var rdd = newPairRDD() sc.setCheckpointDir(checkpointDir.toString) diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index a0086e18435b6..4e36adc8baf3f 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -92,8 +92,8 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex test("accumulators") { sc = new SparkContext(clusterUrl, "test") - val accum = sc.accumulator(0) - sc.parallelize(1 to 10, 10).foreach(x => accum += x) + val accum = sc.longAccumulator + sc.parallelize(1 to 10, 10).foreach(x => accum.add(x)) assert(accum.value === 55) } @@ -109,7 +109,6 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex test("repeatedly failing task") { sc = new SparkContext(clusterUrl, "test") - val accum = sc.accumulator(0) val thrown = intercept[SparkException] { // scalastyle:off println sc.parallelize(1 to 10, 10).foreach(x => println(x / 0)) @@ -135,76 +134,62 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex } } - test("caching") { + test("repeatedly failing task that crashes JVM with a zero exit code (SPARK-16925)") { + // Ensures that if a task which causes the JVM to exit with a zero exit code will cause the + // Spark job to eventually fail. 
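The accumulator-related hunks in these suites replace the deprecated `sc.accumulator(0)` / `+=` style with the `AccumulatorV2`-based `sc.longAccumulator` API. A hedged side-by-side sketch of the two forms, assuming an already-constructed `SparkContext`:

```scala
import org.apache.spark.SparkContext

// Sketch only: `sc` is assumed to be an existing SparkContext.
def sumWithAccumulator(sc: SparkContext): Long = {
  // Old, now-deprecated style removed by the patch:
  //   val accum = sc.accumulator(0)
  //   sc.parallelize(1 to 10, 10).foreach(x => accum += x)

  // New AccumulatorV2-based style added by the patch:
  val accum = sc.longAccumulator("sum")
  sc.parallelize(1 to 10, 10).foreach(x => accum.add(x))
  accum.value // driver-side value (55 here); executor-side copies start from a reset value
}
```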
sc = new SparkContext(clusterUrl, "test") - val data = sc.parallelize(1 to 1000, 10).cache() - assert(data.count() === 1000) - assert(data.count() === 1000) - assert(data.count() === 1000) - } - - test("caching on disk") { - sc = new SparkContext(clusterUrl, "test") - val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.DISK_ONLY) - assert(data.count() === 1000) - assert(data.count() === 1000) - assert(data.count() === 1000) - } - - test("caching in memory, replicated") { - sc = new SparkContext(clusterUrl, "test") - val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.MEMORY_ONLY_2) - assert(data.count() === 1000) - assert(data.count() === 1000) - assert(data.count() === 1000) - } - - test("caching in memory, serialized, replicated") { - sc = new SparkContext(clusterUrl, "test") - val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.MEMORY_ONLY_SER_2) - assert(data.count() === 1000) - assert(data.count() === 1000) - assert(data.count() === 1000) - } - - test("caching on disk, replicated") { - sc = new SparkContext(clusterUrl, "test") - val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.DISK_ONLY_2) - assert(data.count() === 1000) - assert(data.count() === 1000) - assert(data.count() === 1000) - } - - test("caching in memory and disk, replicated") { - sc = new SparkContext(clusterUrl, "test") - val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.MEMORY_AND_DISK_2) - assert(data.count() === 1000) - assert(data.count() === 1000) - assert(data.count() === 1000) + failAfter(Span(100000, Millis)) { + val thrown = intercept[SparkException] { + sc.parallelize(1 to 1, 1).foreachPartition { _ => System.exit(0) } + } + assert(thrown.getClass === classOf[SparkException]) + assert(thrown.getMessage.contains("failed 4 times")) + } + // Check that the cluster is still usable: + sc.parallelize(1 to 10).count() } - test("caching in memory and disk, serialized, replicated") { + private def testCaching(storageLevel: StorageLevel): Unit = { sc = new SparkContext(clusterUrl, "test") - val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.MEMORY_AND_DISK_SER_2) - - assert(data.count() === 1000) - assert(data.count() === 1000) - assert(data.count() === 1000) + sc.jobProgressListener.waitUntilExecutorsUp(2, 30000) + val data = sc.parallelize(1 to 1000, 10) + val cachedData = data.persist(storageLevel) + assert(cachedData.count === 1000) + assert(sc.getExecutorStorageStatus.map(_.rddBlocksById(cachedData.id).size).sum === + storageLevel.replication * data.getNumPartitions) + assert(cachedData.count === 1000) + assert(cachedData.count === 1000) // Get all the locations of the first partition and try to fetch the partitions // from those locations. 
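The refactored `testCaching` helper above consolidates the per-storage-level caching tests into one parameterized body and checks that the executors collectively report `replication * numPartitions` cached block replicas. A small worked example of that invariant, using values that appear in the suite:

```scala
import org.apache.spark.storage.StorageLevel

// MEMORY_ONLY_2 keeps two replicas of every block; the suite caches an RDD with
// 10 partitions, so the executors should report 2 * 10 = 20 block replicas in total.
val level = StorageLevel.MEMORY_ONLY_2
val numPartitions = 10
val expectedReplicas = level.replication * numPartitions
assert(expectedReplicas == 20)
```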
val blockIds = data.partitions.indices.map(index => RDDBlockId(data.id, index)).toArray val blockId = blockIds(0) val blockManager = SparkEnv.get.blockManager - val blockTransfer = SparkEnv.get.blockTransferService + val blockTransfer = blockManager.blockTransferService val serializerManager = SparkEnv.get.serializerManager blockManager.master.getLocations(blockId).foreach { cmId => val bytes = blockTransfer.fetchBlockSync(cmId.host, cmId.port, cmId.executorId, blockId.toString) - val deserialized = serializerManager.dataDeserializeStream[Int](blockId, - new ChunkedByteBuffer(bytes.nioByteBuffer()).toInputStream()).toList + val deserialized = serializerManager.dataDeserializeStream(blockId, + new ChunkedByteBuffer(bytes.nioByteBuffer()).toInputStream())(data.elementClassTag).toList assert(deserialized === (1 to 100).toList) } + // This will exercise the getRemoteBytes / getRemoteValues code paths: + assert(blockIds.flatMap(id => blockManager.get[Int](id).get.data).toSet === (1 to 1000).toSet) + } + + Seq( + "caching" -> StorageLevel.MEMORY_ONLY, + "caching on disk" -> StorageLevel.DISK_ONLY, + "caching in memory, replicated" -> StorageLevel.MEMORY_ONLY_2, + "caching in memory, serialized, replicated" -> StorageLevel.MEMORY_ONLY_SER_2, + "caching on disk, replicated" -> StorageLevel.DISK_ONLY_2, + "caching in memory and disk, replicated" -> StorageLevel.MEMORY_AND_DISK_2, + "caching in memory and disk, serialized, replicated" -> StorageLevel.MEMORY_AND_DISK_SER_2 + ).foreach { case (testName, storageLevel) => + test(testName) { + testCaching(storageLevel) + } } test("compute without caching when no partitions fit in memory") { @@ -224,7 +209,7 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex test("compute when only some partitions fit in memory") { val size = 10000 - val numPartitions = 10 + val numPartitions = 20 val conf = new SparkConf() .set("spark.storage.unrollMemoryThreshold", "1024") .set("spark.testing.memory", size.toString) diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala index 3def8b0b1850e..d805c67714ff8 100644 --- a/core/src/test/scala/org/apache/spark/FailureSuite.scala +++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark import java.io.{IOException, NotSerializableException, ObjectInputStream} +import org.apache.spark.memory.TestMemoryConsumer +import org.apache.spark.storage.StorageLevel import org.apache.spark.util.NonSerializable // Common state shared by FailureSuite-launched tasks. 
We use a global object @@ -149,7 +151,8 @@ class FailureSuite extends SparkFunSuite with LocalSparkContext { // cause is preserved val thrownDueToTaskFailure = intercept[SparkException] { sc.parallelize(Seq(0)).mapPartitions { iter => - TaskContext.get().taskMemoryManager().allocatePage(128, null) + val c = new TestMemoryConsumer(TaskContext.get().taskMemoryManager()) + TaskContext.get().taskMemoryManager().allocatePage(128, c) throw new Exception("intentional task failure") iter }.count() @@ -159,7 +162,8 @@ class FailureSuite extends SparkFunSuite with LocalSparkContext { // If the task succeeded but memory was leaked, then the task should fail due to that leak val thrownDueToMemoryLeak = intercept[SparkException] { sc.parallelize(Seq(0)).mapPartitions { iter => - TaskContext.get().taskMemoryManager().allocatePage(128, null) + val c = new TestMemoryConsumer(TaskContext.get().taskMemoryManager()) + TaskContext.get().taskMemoryManager().allocatePage(128, c) iter }.count() } @@ -238,6 +242,26 @@ class FailureSuite extends SparkFunSuite with LocalSparkContext { FailureSuiteState.clear() } + test("failure because cached RDD partitions are missing from DiskStore (SPARK-15736)") { + sc = new SparkContext("local[1,2]", "test") + val rdd = sc.parallelize(1 to 2, 2).persist(StorageLevel.DISK_ONLY) + rdd.count() + // Directly delete all files from the disk store, triggering failures when reading cached data: + SparkEnv.get.blockManager.diskBlockManager.getAllFiles().foreach(_.delete()) + // Each task should fail once due to missing cached data, but then should succeed on its second + // attempt because the missing cache locations will be purged and the blocks will be recomputed. + rdd.count() + } + + test("SPARK-16304: Link error should not crash executor") { + sc = new SparkContext("local[1,2]", "test") + intercept[SparkException] { + sc.parallelize(1 to 2).foreach { i => + throw new LinkageError() + } + } + } + // TODO: Need to add tests with shuffle fetch failures. 
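The `TestMemoryConsumer` substitution above is needed because `TaskMemoryManager.allocatePage` now attributes every page to a `MemoryConsumer` for spill accounting, so the old `allocatePage(128, null)` call is no longer safe. A rough sketch of the smallest possible consumer, modeled on Spark's test helper; the constructor and `spill` signatures follow the 2.x `MemoryConsumer` API and may differ in other versions:

```scala
import org.apache.spark.memory.{MemoryConsumer, TaskMemoryManager}

// A consumer that never releases memory when asked to spill; suitable only
// for tests that allocate and free pages explicitly.
class NoopMemoryConsumer(tmm: TaskMemoryManager) extends MemoryConsumer(tmm) {
  override def spill(size: Long, trigger: MemoryConsumer): Long = 0L
}

// Inside a running task, pages are then allocated against the consumer:
//   val tmm = TaskContext.get().taskMemoryManager()
//   val page = tmm.allocatePage(128, new NoopMemoryConsumer(tmm))
```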
} diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala index 1adc90ab1e9dd..e30349570b7ed 100644 --- a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala +++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala @@ -21,6 +21,7 @@ import java.util.concurrent.{ExecutorService, TimeUnit} import scala.collection.Map import scala.collection.mutable +import scala.concurrent.Future import scala.concurrent.duration._ import scala.language.postfixOps @@ -174,9 +175,9 @@ class HeartbeatReceiverSuite val dummyExecutorEndpointRef1 = rpcEnv.setupEndpoint("fake-executor-1", dummyExecutorEndpoint1) val dummyExecutorEndpointRef2 = rpcEnv.setupEndpoint("fake-executor-2", dummyExecutorEndpoint2) fakeSchedulerBackend.driverEndpoint.askWithRetry[Boolean]( - RegisterExecutor(executorId1, dummyExecutorEndpointRef1, 0, Map.empty)) + RegisterExecutor(executorId1, dummyExecutorEndpointRef1, "1.2.3.4", 0, Map.empty)) fakeSchedulerBackend.driverEndpoint.askWithRetry[Boolean]( - RegisterExecutor(executorId2, dummyExecutorEndpointRef2, 0, Map.empty)) + RegisterExecutor(executorId2, dummyExecutorEndpointRef2, "1.2.3.5", 0, Map.empty)) heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet) addExecutorAndVerify(executorId1) addExecutorAndVerify(executorId2) @@ -242,8 +243,8 @@ class HeartbeatReceiverSuite } private def getTrackedExecutors: Map[String, Long] = { - // We may receive undesired SparkListenerExecutorAdded from LocalBackend, so exclude it from - // the map. See SPARK-10800. + // We may receive undesired SparkListenerExecutorAdded from LocalSchedulerBackend, + // so exclude it from the map. See SPARK-10800. heartbeatReceiver.invokePrivate(_executorLastSeen()). 
filterKeys(_ != SparkContext.DRIVER_IDENTIFIER) } @@ -270,13 +271,13 @@ private class FakeSchedulerBackend( clusterManagerEndpoint: RpcEndpointRef) extends CoarseGrainedSchedulerBackend(scheduler, rpcEnv) { - protected override def doRequestTotalExecutors(requestedTotal: Int): Boolean = { - clusterManagerEndpoint.askWithRetry[Boolean]( + protected override def doRequestTotalExecutors(requestedTotal: Int): Future[Boolean] = { + clusterManagerEndpoint.ask[Boolean]( RequestExecutors(requestedTotal, localityAwareTasks, hostToLocalTaskCount)) } - protected override def doKillExecutors(executorIds: Seq[String]): Boolean = { - clusterManagerEndpoint.askWithRetry[Boolean](KillExecutors(executorIds)) + protected override def doKillExecutors(executorIds: Seq[String]): Future[Boolean] = { + clusterManagerEndpoint.ask[Boolean](KillExecutors(executorIds)) } } diff --git a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala index 688eb6bde9041..840f55ce2f6e5 100644 --- a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala @@ -23,6 +23,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler.AccumulableInfo import org.apache.spark.shuffle.FetchFailedException +import org.apache.spark.util.{AccumulatorContext, AccumulatorV2} class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext { @@ -213,7 +214,7 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext { private class SaveAccumContextCleaner(sc: SparkContext) extends ContextCleaner(sc) { private val accumsRegistered = new ArrayBuffer[Long] - override def registerAccumulatorForCleanup(a: NewAccumulator[_, _]): Unit = { + override def registerAccumulatorForCleanup(a: AccumulatorV2[_, _]): Unit = { accumsRegistered += a.id super.registerAccumulatorForCleanup(a) } diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index ddf48765ec30a..c6aebc19fd12d 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -22,6 +22,7 @@ import scala.collection.mutable.ArrayBuffer import org.mockito.Matchers.{any, isA} import org.mockito.Mockito._ +import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.rpc.{RpcAddress, RpcCallContext, RpcEnv} import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus} import org.apache.spark.shuffle.FetchFailedException @@ -30,6 +31,12 @@ import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId} class MapOutputTrackerSuite extends SparkFunSuite { private val conf = new SparkConf + private def newTrackerMaster(sparkConf: SparkConf = conf) = { + val broadcastManager = new BroadcastManager(true, sparkConf, + new SecurityManager(sparkConf)) + new MapOutputTrackerMaster(sparkConf, broadcastManager, true) + } + def createRpcEnv(name: String, host: String = "localhost", port: Int = 0, securityManager: SecurityManager = new SecurityManager(conf)): RpcEnv = { RpcEnv.create(name, host, port, conf, securityManager) @@ -37,7 +44,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { test("master start and stop") { val rpcEnv = createRpcEnv("test") - val tracker = new MapOutputTrackerMaster(conf) + val tracker = newTrackerMaster() 
tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) tracker.stop() @@ -46,7 +53,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { test("master register shuffle and fetch") { val rpcEnv = createRpcEnv("test") - val tracker = new MapOutputTrackerMaster(conf) + val tracker = newTrackerMaster() tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) tracker.registerShuffle(10, 2) @@ -62,13 +69,14 @@ class MapOutputTrackerSuite extends SparkFunSuite { Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))), (BlockManagerId("b", "hostB", 1000), ArrayBuffer((ShuffleBlockId(10, 1, 0), size10000)))) .toSet) + assert(0 == tracker.getNumCachedSerializedBroadcast) tracker.stop() rpcEnv.shutdown() } test("master register and unregister shuffle") { val rpcEnv = createRpcEnv("test") - val tracker = new MapOutputTrackerMaster(conf) + val tracker = newTrackerMaster() tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) tracker.registerShuffle(10, 2) @@ -80,6 +88,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { Array(compressedSize10000, compressedSize1000))) assert(tracker.containsShuffle(10)) assert(tracker.getMapSizesByExecutorId(10, 0).nonEmpty) + assert(0 == tracker.getNumCachedSerializedBroadcast) tracker.unregisterShuffle(10) assert(!tracker.containsShuffle(10)) assert(tracker.getMapSizesByExecutorId(10, 0).isEmpty) @@ -90,7 +99,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { test("master register shuffle and unregister map output and fetch") { val rpcEnv = createRpcEnv("test") - val tracker = new MapOutputTrackerMaster(conf) + val tracker = newTrackerMaster() tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) tracker.registerShuffle(10, 2) @@ -101,6 +110,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), Array(compressedSize10000, compressedSize1000, compressedSize1000))) + assert(0 == tracker.getNumCachedSerializedBroadcast) // As if we had two simultaneous fetch failures tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000)) tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000)) @@ -118,7 +128,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { val hostname = "localhost" val rpcEnv = createRpcEnv("spark", hostname, 0, new SecurityManager(conf)) - val masterTracker = new MapOutputTrackerMaster(conf) + val masterTracker = newTrackerMaster() masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) @@ -139,6 +149,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { slaveTracker.updateEpoch(masterTracker.getEpoch) assert(slaveTracker.getMapSizesByExecutorId(10, 0) === Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) + assert(0 == masterTracker.getNumCachedSerializedBroadcast) masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000)) masterTracker.incrementEpoch() @@ -147,6 +158,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { // failure should be cached intercept[FetchFailedException] { 
slaveTracker.getMapSizesByExecutorId(10, 0) } + assert(0 == masterTracker.getNumCachedSerializedBroadcast) masterTracker.stop() slaveTracker.stop() @@ -158,8 +170,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { val newConf = new SparkConf newConf.set("spark.rpc.message.maxSize", "1") newConf.set("spark.rpc.askTimeout", "1") // Fail fast + newConf.set("spark.shuffle.mapOutput.minSizeForBroadcast", "1048576") - val masterTracker = new MapOutputTrackerMaster(conf) + val masterTracker = newTrackerMaster(newConf) val rpcEnv = createRpcEnv("spark") val masterEndpoint = new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, newConf) rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, masterEndpoint) @@ -172,45 +185,27 @@ class MapOutputTrackerSuite extends SparkFunSuite { val rpcCallContext = mock(classOf[RpcCallContext]) when(rpcCallContext.senderAddress).thenReturn(senderAddress) masterEndpoint.receiveAndReply(rpcCallContext)(GetMapOutputStatuses(10)) - verify(rpcCallContext).reply(any()) - verify(rpcCallContext, never()).sendFailure(any()) + // The default broadcast size in this test suite is set to -1, so broadcast should not + // be used here. + verify(rpcCallContext, timeout(30000)).reply(any()) + assert(0 == masterTracker.getNumCachedSerializedBroadcast) // masterTracker.stop() // this throws an exception rpcEnv.shutdown() } - test("remote fetch exceeds max RPC message size") { + test("min broadcast size exceeds max RPC message size") { val newConf = new SparkConf newConf.set("spark.rpc.message.maxSize", "1") newConf.set("spark.rpc.askTimeout", "1") // Fail fast + newConf.set("spark.shuffle.mapOutput.minSizeForBroadcast", Int.MaxValue.toString) - val masterTracker = new MapOutputTrackerMaster(conf) - val rpcEnv = createRpcEnv("test") - val masterEndpoint = new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, newConf) - rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, masterEndpoint) - - // Message size should be ~1.1MB, and MapOutputTrackerMasterEndpoint should throw exception. - // Note that the size is hand-selected here because map output statuses are compressed before - // being sent.
- masterTracker.registerShuffle(20, 100) - (0 until 100).foreach { i => - masterTracker.registerMapOutput(20, i, new CompressedMapStatus( - BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0))) - } - val senderAddress = RpcAddress("localhost", 12345) - val rpcCallContext = mock(classOf[RpcCallContext]) - when(rpcCallContext.senderAddress).thenReturn(senderAddress) - masterEndpoint.receiveAndReply(rpcCallContext)(GetMapOutputStatuses(20)) - verify(rpcCallContext, never()).reply(any()) - verify(rpcCallContext).sendFailure(isA(classOf[SparkException])) - -// masterTracker.stop() // this throws an exception - rpcEnv.shutdown() + intercept[IllegalArgumentException] { newTrackerMaster(newConf) } } test("getLocationsWithLargestOutputs with multiple outputs in same machine") { val rpcEnv = createRpcEnv("test") - val tracker = new MapOutputTrackerMaster(conf) + val tracker = newTrackerMaster() tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) // Setup 3 map tasks @@ -242,4 +237,44 @@ class MapOutputTrackerSuite extends SparkFunSuite { tracker.stop() rpcEnv.shutdown() } + + test("remote fetch using broadcast") { + val newConf = new SparkConf + newConf.set("spark.rpc.message.maxSize", "1") + newConf.set("spark.rpc.askTimeout", "1") // Fail fast + newConf.set("spark.shuffle.mapOutput.minSizeForBroadcast", "10240") // 10 KB << 1MB framesize + + // needs TorrentBroadcast, so we need a SparkContext + val sc = new SparkContext("local", "MapOutputTrackerSuite", newConf) + try { + val masterTracker = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] + val rpcEnv = sc.env.rpcEnv + val masterEndpoint = new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, newConf) + rpcEnv.stop(masterTracker.trackerEndpoint) + rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, masterEndpoint) + + // Without broadcast, the serialized statuses would be ~1.1MB, which exceeds the frame size. + // Note that the size is hand-selected here because map output statuses are compressed before + // being sent.
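The threshold interplay being exercised here: once the serialized map output statuses exceed `spark.shuffle.mapOutput.minSizeForBroadcast`, the master ships them through a `TorrentBroadcast` instead of inlining them in the RPC reply, and that threshold must stay below `spark.rpc.message.maxSize` or the tracker master fails construction, as the `IllegalArgumentException` test above checks. A sketch of a conforming configuration, with illustrative values only:

```scala
import org.apache.spark.SparkConf

// minSizeForBroadcast (bytes) must stay below the RPC frame size (MB);
// otherwise constructing the MapOutputTrackerMaster fails fast.
val conf = new SparkConf()
  .set("spark.rpc.message.maxSize", "128")                      // 128 MB frames
  .set("spark.shuffle.mapOutput.minSizeForBroadcast", "524288") // broadcast above 512 KB
```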
+ masterTracker.registerShuffle(20, 100) + (0 until 100).foreach { i => + masterTracker.registerMapOutput(20, i, new CompressedMapStatus( + BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0))) + } + val senderAddress = RpcAddress("localhost", 12345) + val rpcCallContext = mock(classOf[RpcCallContext]) + when(rpcCallContext.senderAddress).thenReturn(senderAddress) + masterEndpoint.receiveAndReply(rpcCallContext)(GetMapOutputStatuses(20)) + // should succeed since majority of data is broadcast and actual serialized + // message size is small + verify(rpcCallContext, timeout(30000)).reply(any()) + assert(1 == masterTracker.getNumCachedSerializedBroadcast) + masterTracker.unregisterShuffle(20) + assert(0 == masterTracker.getNumCachedSerializedBroadcast) + + } finally { + LocalSparkContext.stop(sc) + } + } + } diff --git a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala index 159b448e05b02..2b8b1805bc83f 100644 --- a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala @@ -79,7 +79,7 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { conf.set("spark.ssl.protocol", "SSLv3") val defaultOpts = SSLOptions.parse(conf, "spark.ssl", defaults = None) - val opts = SSLOptions.parse(conf, "spark.ui.ssl", defaults = Some(defaultOpts)) + val opts = SSLOptions.parse(conf, "spark.ssl.ui", defaults = Some(defaultOpts)) assert(opts.enabled === true) assert(opts.trustStore.isDefined === true) @@ -102,20 +102,20 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { val conf = new SparkConf conf.set("spark.ssl.enabled", "true") - conf.set("spark.ui.ssl.enabled", "false") + conf.set("spark.ssl.ui.enabled", "false") conf.set("spark.ssl.keyStore", keyStorePath) conf.set("spark.ssl.keyStorePassword", "password") - conf.set("spark.ui.ssl.keyStorePassword", "12345") + conf.set("spark.ssl.ui.keyStorePassword", "12345") conf.set("spark.ssl.keyPassword", "password") conf.set("spark.ssl.trustStore", trustStorePath) conf.set("spark.ssl.trustStorePassword", "password") conf.set("spark.ssl.enabledAlgorithms", "TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_AES_256_CBC_SHA") - conf.set("spark.ui.ssl.enabledAlgorithms", "ABC, DEF") + conf.set("spark.ssl.ui.enabledAlgorithms", "ABC, DEF") conf.set("spark.ssl.protocol", "SSLv3") val defaultOpts = SSLOptions.parse(conf, "spark.ssl", defaults = None) - val opts = SSLOptions.parse(conf, "spark.ui.ssl", defaults = Some(defaultOpts)) + val opts = SSLOptions.parse(conf, "spark.ssl.ui", defaults = Some(defaultOpts)) assert(opts.enabled === false) assert(opts.trustStore.isDefined === true) diff --git a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala index 8bdb237c28f66..9801b2638cc15 100644 --- a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala @@ -19,8 +19,18 @@ package org.apache.spark import java.io.File +import org.apache.spark.security.GroupMappingServiceProvider import org.apache.spark.util.{ResetSystemProperties, SparkConfWithEnv, Utils} +class DummyGroupMappingServiceProvider extends GroupMappingServiceProvider { + + val userGroups: Set[String] = Set[String]("group1", "group2", "group3") + + override def getGroups(username: String): Set[String] = { + userGroups + } +} + class SecurityManagerSuite extends SparkFunSuite with 
ResetSystemProperties { test("set security with conf") { @@ -37,6 +47,45 @@ class SecurityManagerSuite extends SparkFunSuite with ResetSystemProperties { assert(securityManager.checkUIViewPermissions("user3") === false) } + test("set security with conf for groups") { + val conf = new SparkConf + conf.set("spark.authenticate", "true") + conf.set("spark.authenticate.secret", "good") + conf.set("spark.ui.acls.enable", "true") + conf.set("spark.ui.view.acls.groups", "group1,group2") + // default ShellBasedGroupsMappingProvider is used to resolve user groups + val securityManager = new SecurityManager(conf); + // assuming executing user does not belong to group1,group2 + assert(securityManager.checkUIViewPermissions("user1") === false) + assert(securityManager.checkUIViewPermissions("user2") === false) + + val conf2 = new SparkConf + conf2.set("spark.authenticate", "true") + conf2.set("spark.authenticate.secret", "good") + conf2.set("spark.ui.acls.enable", "true") + conf2.set("spark.ui.view.acls.groups", "group1,group2") + // explicitly specify a custom GroupsMappingServiceProvider + conf2.set("spark.user.groups.mapping", "org.apache.spark.DummyGroupMappingServiceProvider") + + val securityManager2 = new SecurityManager(conf2); + // group1,group2 match the groups returned by the custom provider + assert(securityManager2.checkUIViewPermissions("user1") === true) + assert(securityManager2.checkUIViewPermissions("user2") === true) + + val conf3 = new SparkConf + conf3.set("spark.authenticate", "true") + conf3.set("spark.authenticate.secret", "good") + conf3.set("spark.ui.acls.enable", "true") + conf3.set("spark.ui.view.acls.groups", "group4,group5") + // explicitly specify a bogus GroupsMappingServiceProvider + conf3.set("spark.user.groups.mapping", "BogusServiceProvider") + + val securityManager3 = new SecurityManager(conf3); + // BogusServiceProvider cannot be loaded and an error is logged returning an empty group set + assert(securityManager3.checkUIViewPermissions("user1") === false) + assert(securityManager3.checkUIViewPermissions("user2") === false) + } + test("set security with api") { val conf = new SparkConf conf.set("spark.ui.view.acls", "user1,user2") @@ -60,6 +109,40 @@ class SecurityManagerSuite extends SparkFunSuite with ResetSystemProperties { assert(securityManager.checkUIViewPermissions(null) === true) } + test("set security with api for groups") { + val conf = new SparkConf + conf.set("spark.user.groups.mapping", "org.apache.spark.DummyGroupMappingServiceProvider") + + val securityManager = new SecurityManager(conf); + securityManager.setAcls(true) + securityManager.setViewAclsGroups("group1,group2") + + // group1,group2 match + assert(securityManager.checkUIViewPermissions("user1") === true) + assert(securityManager.checkUIViewPermissions("user2") === true) + + // change groups so they do not match + securityManager.setViewAclsGroups("group4,group5") + assert(securityManager.checkUIViewPermissions("user1") === false) + assert(securityManager.checkUIViewPermissions("user2") === false) + + val conf2 = new SparkConf + conf2.set("spark.user.groups.mapping", "BogusServiceProvider") + + val securityManager2 = new SecurityManager(conf2) + securityManager2.setAcls(true) + securityManager2.setViewAclsGroups("group1,group2") + + // group1,group2 do not match because of BogusServiceProvider + assert(securityManager2.checkUIViewPermissions("user1") === false) + assert(securityManager2.checkUIViewPermissions("user2") === false) + + // setting viewAclsGroups to empty should still not match because of BogusServiceProvider +
securityManager2.setViewAclsGroups("") + assert(securityManager2.checkUIViewPermissions("user1") === false) + assert(securityManager2.checkUIViewPermissions("user2") === false) + } + test("set security modify acls") { val conf = new SparkConf conf.set("spark.modify.acls", "user1,user2") @@ -84,6 +167,29 @@ class SecurityManagerSuite extends SparkFunSuite with ResetSystemProperties { assert(securityManager.checkModifyPermissions(null) === true) } + test("set security modify acls for groups") { + val conf = new SparkConf + conf.set("spark.user.groups.mapping", "org.apache.spark.DummyGroupMappingServiceProvider") + + val securityManager = new SecurityManager(conf); + securityManager.setAcls(true) + securityManager.setModifyAclsGroups("group1,group2") + + // group1,group2 match + assert(securityManager.checkModifyPermissions("user1") === true) + assert(securityManager.checkModifyPermissions("user2") === true) + + // change groups so they do not match + securityManager.setModifyAclsGroups("group4,group5") + assert(securityManager.checkModifyPermissions("user1") === false) + assert(securityManager.checkModifyPermissions("user2") === false) + + // change so they match again + securityManager.setModifyAclsGroups("group2,group3") + assert(securityManager.checkModifyPermissions("user1") === true) + assert(securityManager.checkModifyPermissions("user2") === true) + } + test("set security admin acls") { val conf = new SparkConf conf.set("spark.admin.acls", "user1,user2") @@ -122,7 +228,48 @@ class SecurityManagerSuite extends SparkFunSuite with ResetSystemProperties { assert(securityManager.checkUIViewPermissions("user1") === false) assert(securityManager.checkUIViewPermissions("user3") === false) assert(securityManager.checkUIViewPermissions(null) === true) + } + + test("set security admin acls for groups") { + val conf = new SparkConf + conf.set("spark.admin.acls.groups", "group1") + conf.set("spark.ui.view.acls.groups", "group2") + conf.set("spark.modify.acls.groups", "group3") + conf.set("spark.user.groups.mapping", "org.apache.spark.DummyGroupMappingServiceProvider") + + val securityManager = new SecurityManager(conf); + securityManager.setAcls(true) + assert(securityManager.aclsEnabled() === true) + + // group1,group2,group3 match + assert(securityManager.checkModifyPermissions("user1") === true) + assert(securityManager.checkUIViewPermissions("user1") === true) + // change admin groups so they do not match.
view and modify groups are set to admin groups + securityManager.setAdminAclsGroups("group4,group5") + // invoke the set ui and modify to propagate the changes + securityManager.setViewAclsGroups("") + securityManager.setModifyAclsGroups("") + + assert(securityManager.checkModifyPermissions("user1") === false) + assert(securityManager.checkUIViewPermissions("user1") === false) + + // change modify groups so they match + securityManager.setModifyAclsGroups("group3") + assert(securityManager.checkModifyPermissions("user1") === true) + assert(securityManager.checkUIViewPermissions("user1") === false) + + // change view groups so they match + securityManager.setViewAclsGroups("group2") + securityManager.setModifyAclsGroups("group4") + assert(securityManager.checkModifyPermissions("user1") === false) + assert(securityManager.checkUIViewPermissions("user1") === true) + + // change modify and view groups so they do not match + securityManager.setViewAclsGroups("group7") + securityManager.setModifyAclsGroups("group8") + assert(securityManager.checkModifyPermissions("user1") === false) + assert(securityManager.checkUIViewPermissions("user1") === false) } test("set security with * in acls") { @@ -166,6 +313,57 @@ class SecurityManagerSuite extends SparkFunSuite with ResetSystemProperties { assert(securityManager.checkModifyPermissions("user8") === true) } + test("set security with * in acls for groups") { + val conf = new SparkConf + conf.set("spark.ui.acls.enable", "true") + conf.set("spark.admin.acls.groups", "group4,group5") + conf.set("spark.ui.view.acls.groups", "*") + conf.set("spark.modify.acls.groups", "group6") + + val securityManager = new SecurityManager(conf) + assert(securityManager.aclsEnabled() === true) + + // check for viewAclsGroups with * + assert(securityManager.checkUIViewPermissions("user1") === true) + assert(securityManager.checkUIViewPermissions("user2") === true) + assert(securityManager.checkModifyPermissions("user1") === false) + assert(securityManager.checkModifyPermissions("user2") === false) + + // check for modifyAcls with * + securityManager.setModifyAclsGroups("*") + securityManager.setViewAclsGroups("group6") + assert(securityManager.checkUIViewPermissions("user1") === false) + assert(securityManager.checkUIViewPermissions("user2") === false) + assert(securityManager.checkModifyPermissions("user1") === true) + assert(securityManager.checkModifyPermissions("user2") === true) + + // check for adminAcls with * + securityManager.setAdminAclsGroups("group9,*") + securityManager.setModifyAclsGroups("group4,group5") + securityManager.setViewAclsGroups("group6,group7") + assert(securityManager.checkUIViewPermissions("user5") === true) + assert(securityManager.checkUIViewPermissions("user6") === true) + assert(securityManager.checkModifyPermissions("user7") === true) + assert(securityManager.checkModifyPermissions("user8") === true) + } + + test("security for groups default behavior") { + // no groups or userToGroupsMapper provided + // this will default to the ShellBasedGroupsMappingProvider + val conf = new SparkConf + + val securityManager = new SecurityManager(conf) + securityManager.setAcls(true) + + assert(securityManager.checkUIViewPermissions("user1") === false) + assert(securityManager.checkModifyPermissions("user1") === false) + + // set groups only + securityManager.setAdminAclsGroups("group1,group2") + assert(securityManager.checkUIViewPermissions("user1") === false) + assert(securityManager.checkModifyPermissions("user1") === false) + } + test("ssl on setup") { 
val conf = SSLSampleConfigs.sparkSSLConfig() val expectedAlgorithms = Set( diff --git a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala index 213d70f4e5840..7d75a93ff6839 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala @@ -21,10 +21,10 @@ import org.scalatest.PrivateMethodTester import org.apache.spark.internal.Logging import org.apache.spark.scheduler.{SchedulerBackend, TaskScheduler, TaskSchedulerImpl} -import org.apache.spark.scheduler.cluster.SparkDeploySchedulerBackend -import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} -import org.apache.spark.scheduler.local.LocalBackend -import org.apache.spark.util.Utils +import org.apache.spark.scheduler.cluster.StandaloneSchedulerBackend +import org.apache.spark.scheduler.cluster.mesos.{MesosCoarseGrainedSchedulerBackend, MesosFineGrainedSchedulerBackend} +import org.apache.spark.scheduler.local.LocalSchedulerBackend + class SparkContextSchedulerCreationSuite extends SparkFunSuite with LocalSparkContext with PrivateMethodTester with Logging { @@ -58,7 +58,7 @@ class SparkContextSchedulerCreationSuite test("local") { val sched = createTaskScheduler("local") sched.backend match { - case s: LocalBackend => assert(s.totalCores === 1) + case s: LocalSchedulerBackend => assert(s.totalCores === 1) case _ => fail() } } @@ -66,7 +66,8 @@ class SparkContextSchedulerCreationSuite test("local-*") { val sched = createTaskScheduler("local[*]") sched.backend match { - case s: LocalBackend => assert(s.totalCores === Runtime.getRuntime.availableProcessors()) + case s: LocalSchedulerBackend => + assert(s.totalCores === Runtime.getRuntime.availableProcessors()) case _ => fail() } } @@ -75,7 +76,7 @@ class SparkContextSchedulerCreationSuite val sched = createTaskScheduler("local[5]") assert(sched.maxTaskFailures === 1) sched.backend match { - case s: LocalBackend => assert(s.totalCores === 5) + case s: LocalSchedulerBackend => assert(s.totalCores === 5) case _ => fail() } } @@ -84,7 +85,8 @@ class SparkContextSchedulerCreationSuite val sched = createTaskScheduler("local[* ,2]") assert(sched.maxTaskFailures === 2) sched.backend match { - case s: LocalBackend => assert(s.totalCores === Runtime.getRuntime.availableProcessors()) + case s: LocalSchedulerBackend => + assert(s.totalCores === Runtime.getRuntime.availableProcessors()) case _ => fail() } } @@ -93,7 +95,7 @@ class SparkContextSchedulerCreationSuite val sched = createTaskScheduler("local[4, 2]") assert(sched.maxTaskFailures === 2) sched.backend match { - case s: LocalBackend => assert(s.totalCores === 4) + case s: LocalSchedulerBackend => assert(s.totalCores === 4) case _ => fail() } } @@ -117,14 +119,14 @@ class SparkContextSchedulerCreationSuite val sched = createTaskScheduler("local", "client", conf) sched.backend match { - case s: LocalBackend => assert(s.defaultParallelism() === 16) + case s: LocalSchedulerBackend => assert(s.defaultParallelism() === 16) case _ => fail() } } test("local-cluster") { createTaskScheduler("local-cluster[3, 14, 1024]").backend match { - case s: SparkDeploySchedulerBackend => // OK + case s: StandaloneSchedulerBackend => // OK case _ => fail() } } @@ -143,19 +145,16 @@ class SparkContextSchedulerCreationSuite } test("mesos fine-grained") { - testMesos("mesos://localhost:1234", 
classOf[MesosSchedulerBackend], coarse = false) + testMesos("mesos://localhost:1234", classOf[MesosFineGrainedSchedulerBackend], coarse = false) } test("mesos coarse-grained") { - testMesos("mesos://localhost:1234", classOf[CoarseMesosSchedulerBackend], coarse = true) + testMesos("mesos://localhost:1234", classOf[MesosCoarseGrainedSchedulerBackend], coarse = true) } test("mesos with zookeeper") { testMesos("mesos://zk://localhost:1234,localhost:2345", - classOf[MesosSchedulerBackend], coarse = false) + classOf[MesosFineGrainedSchedulerBackend], coarse = false) } - test("mesos with zookeeper and Master URL starting with zk://") { - testMesos("zk://localhost:1234,localhost:2345", classOf[MesosSchedulerBackend], coarse = false) - } } diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index a759f364fe059..c451c596b069a 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark import java.io.File +import java.net.MalformedURLException import java.nio.charset.StandardCharsets import java.util.concurrent.TimeUnit @@ -108,7 +109,7 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext { assert(byteArray2.length === 0) } - test("addFile works") { + test("basic case for addFile and listFiles") { val dir = Utils.createTempDir() val file1 = File.createTempFile("someprefix1", "somesuffix1", dir) @@ -156,6 +157,39 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext { } x }).count() + assert(sc.listFiles().filter(_.contains("somesuffix1")).size == 1) + } finally { + sc.stop() + } + } + + test("add and list jar files") { + val jarPath = Thread.currentThread().getContextClassLoader.getResource("TestUDTF.jar") + try { + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + sc.addJar(jarPath.toString) + assert(sc.listJars().filter(_.contains("TestUDTF.jar")).size == 1) + } finally { + sc.stop() + } + } + + test("SPARK-17650: malformed url's throw exceptions before bricking Executors") { + try { + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + Seq("http", "https", "ftp").foreach { scheme => + val badURL = s"$scheme://user:pwd/path" + val e1 = intercept[MalformedURLException] { + sc.addFile(badURL) + } + assert(e1.getMessage.contains(badURL)) + val e2 = intercept[MalformedURLException] { + sc.addJar(badURL) + } + assert(e2.getMessage.contains(badURL)) + assert(sc.addedFiles.isEmpty) + assert(sc.addedJars.isEmpty) + } } finally { sc.stop() } @@ -204,6 +238,57 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext { } } + test("cannot call addFile with different paths that have the same filename") { + val dir = Utils.createTempDir() + try { + val subdir1 = new File(dir, "subdir1") + val subdir2 = new File(dir, "subdir2") + assert(subdir1.mkdir()) + assert(subdir2.mkdir()) + val file1 = new File(subdir1, "file") + val file2 = new File(subdir2, "file") + Files.write("old", file1, StandardCharsets.UTF_8) + Files.write("new", file2, StandardCharsets.UTF_8) + sc = new SparkContext("local-cluster[1,1,1024]", "test") + sc.addFile(file1.getAbsolutePath) + def getAddedFileContents(): String = { + sc.parallelize(Seq(0)).map { _ => + scala.io.Source.fromFile(SparkFiles.get("file")).mkString + }.first() + } + assert(getAddedFileContents() === "old") + intercept[IllegalArgumentException] { + 
sc.addFile(file2.getAbsolutePath) + } + assert(getAddedFileContents() === "old") + } finally { + Utils.deleteRecursively(dir) + } + } + + // Regression tests for SPARK-16787 + for ( + schedulingMode <- Seq("local-mode", "non-local-mode"); + method <- Seq("addJar", "addFile") + ) { + val jarPath = Thread.currentThread().getContextClassLoader.getResource("TestUDTF.jar").toString + val master = schedulingMode match { + case "local-mode" => "local" + case "non-local-mode" => "local-cluster[1,1,1024]" + } + test(s"$method can be called twice with same file in $schedulingMode (SPARK-16787)") { + sc = new SparkContext(master, "test") + method match { + case "addJar" => + sc.addJar(jarPath) + sc.addJar(jarPath) + case "addFile" => + sc.addFile(jarPath) + sc.addFile(jarPath) + } + } + } + test("Cancelling job group should not cause SparkContext to shutdown (SPARK-6414)") { try { sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) @@ -323,4 +408,47 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext { assert(sc.getConf.getInt("spark.executor.instances", 0) === 6) } } + + + test("localProperties are inherited by spawned threads.") { + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + sc.setLocalProperty("testProperty", "testValue") + var result = "unset"; + val thread = new Thread() { override def run() = {result = sc.getLocalProperty("testProperty")}} + thread.start() + thread.join() + sc.stop() + assert(result == "testValue") + } + + test("localProperties do not cross-talk between threads.") { + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + var result = "unset"; + val thread1 = new Thread() { + override def run() = {sc.setLocalProperty("testProperty", "testValue")}} + // testProperty should be unset and thus return null + val thread2 = new Thread() { + override def run() = {result = sc.getLocalProperty("testProperty")}} + thread1.start() + thread1.join() + thread2.start() + thread2.join() + sc.stop() + assert(result == null) + } + + test("log level case-insensitive and reset log level") { + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + val originalLevel = org.apache.log4j.Logger.getRootLogger().getLevel + try { + sc.setLogLevel("debug") + assert(org.apache.log4j.Logger.getRootLogger().getLevel === org.apache.log4j.Level.DEBUG) + sc.setLogLevel("INfo") + assert(org.apache.log4j.Logger.getRootLogger().getLevel === org.apache.log4j.Level.INFO) + } finally { + sc.setLogLevel(originalLevel.toString) + assert(org.apache.log4j.Logger.getRootLogger().getLevel === originalLevel) + sc.stop() + } + } } diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala index 4aae2c9b4a8e4..cd876807f890e 100644 --- a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala @@ -21,11 +21,12 @@ package org.apache.spark import org.scalatest.{BeforeAndAfterAll, FunSuite, Outcome} import org.apache.spark.internal.Logging +import org.apache.spark.util.AccumulatorContext /** * Base abstract class for all unit tests in Spark for handling common functionality. 
*/ -private[spark] abstract class SparkFunSuite +abstract class SparkFunSuite extends FunSuite with BeforeAndAfterAll with Logging { diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index 6657104823e71..973676398ae54 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -138,7 +138,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { } /** - * Verify the persistence of state associated with an TorrentBroadcast in a local-cluster. + * Verify the persistence of state associated with a TorrentBroadcast in a local-cluster. * * This test creates a broadcast variable, uses it on all executors, and then unpersists it. * In between each step, this test verifies that the broadcast blocks are present only on the diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 271897699201b..54693c1bf81ee 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -32,6 +32,7 @@ import org.apache.spark.api.r.RUtils import org.apache.spark.deploy.SparkSubmit._ import org.apache.spark.deploy.SparkSubmitUtils.MavenCoordinate import org.apache.spark.internal.Logging +import org.apache.spark.TestUtils.JavaSourceFromString import org.apache.spark.util.{ResetSystemProperties, Utils} // Note: this suite mixes in ResetSystemProperties because SparkSubmit.main() sets a bunch @@ -417,6 +418,8 @@ class SparkSubmitSuite // See https://gist.github.com/shivaram/3a2fecce60768a603dac for an error log ignore("correctly builds R packages included in a jar with --packages") { assume(RUtils.isRInstalled, "R isn't installed on this machine.") + // Check if the SparkR package is installed + assume(RUtils.isSparkRInstalled, "SparkR is not installed in this build.") val main = MavenCoordinate("my.great.lib", "mylib", "0.1") val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) val rScriptDir = @@ -435,6 +438,41 @@ class SparkSubmitSuite } } + test("include an external JAR in SparkR") { + assume(RUtils.isRInstalled, "R isn't installed on this machine.") + val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) + // Check if the SparkR package is installed + assume(RUtils.isSparkRInstalled, "SparkR is not installed in this build.") + val rScriptDir = + Seq(sparkHome, "R", "pkg", "inst", "tests", "testthat", "jarTest.R").mkString(File.separator) + assert(new File(rScriptDir).exists) + + // compile a small jar containing a class that will be called from R code.
+ val tempDir = Utils.createTempDir() + val srcDir = new File(tempDir, "sparkrtest") + srcDir.mkdirs() + val excSource = new JavaSourceFromString(new File(srcDir, "DummyClass").getAbsolutePath, + """package sparkrtest; + | + |public class DummyClass implements java.io.Serializable { + | public static String helloWorld(String arg) { return "Hello " + arg; } + | public static int addStuff(int arg1, int arg2) { return arg1 + arg2; } + |} + """.stripMargin) + val excFile = TestUtils.createCompiledClass("DummyClass", srcDir, excSource, Seq.empty) + val jarFile = new File(tempDir, "sparkRTestJar-%s.jar".format(System.currentTimeMillis())) + val jarURL = TestUtils.createJar(Seq(excFile), jarFile, directoryPrefix = Some("sparkrtest")) + + val args = Seq( + "--name", "testApp", + "--master", "local", + "--jars", jarURL.toString, + "--verbose", + "--conf", "spark.ui.enabled=false", + rScriptDir) + runSparkSubmit(args) + } + test("resolves command line argument paths correctly") { val jars = "/jar1,/jar2" // --jars val files = "hdfs:/file1,file2" // --files @@ -539,6 +577,25 @@ class SparkSubmitSuite val sysProps3 = SparkSubmit.prepareSubmitEnvironment(appArgs3)._3 sysProps3("spark.submit.pyFiles") should be( PythonRunner.formatPaths(Utils.resolveURIs(pyFiles)).mkString(",")) + + // Test remote python files + val f4 = File.createTempFile("test-submit-remote-python-files", "", tmpDir) + val writer4 = new PrintWriter(f4) + val remotePyFiles = "hdfs:///tmp/file1.py,hdfs:///tmp/file2.py" + writer4.println("spark.submit.pyFiles " + remotePyFiles) + writer4.close() + val clArgs4 = Seq( + "--master", "yarn", + "--deploy-mode", "cluster", + "--properties-file", f4.getPath, + "hdfs:///tmp/mister.py" + ) + val appArgs4 = new SparkSubmitArguments(clArgs4) + val sysProps4 = SparkSubmit.prepareSubmitEnvironment(appArgs4)._3 + // Should not format python path for yarn cluster mode + sysProps4("spark.submit.pyFiles") should be( + Utils.resolveURIs(remotePyFiles) + ) } test("user classpath first in driver") { @@ -570,6 +627,18 @@ class SparkSubmitSuite appArgs.executorMemory should be ("2.3g") } } + + test("comma separated list of files are unioned correctly") { + val left = Option("/tmp/a.jar,/tmp/b.jar") + val right = Option("/tmp/c.jar,/tmp/a.jar") + val emptyString = Option("") + Utils.unionFileLists(left, right) should be (Set("/tmp/a.jar", "/tmp/b.jar", "/tmp/c.jar")) + Utils.unionFileLists(emptyString, emptyString) should be (Set.empty) + Utils.unionFileLists(Option("/tmp/a.jar"), emptyString) should be (Set("/tmp/a.jar")) + Utils.unionFileLists(emptyString, Option("/tmp/a.jar")) should be (Set("/tmp/a.jar")) + Utils.unionFileLists(None, Option("/tmp/a.jar")) should be (Set("/tmp/a.jar")) + Utils.unionFileLists(Option("/tmp/a.jar"), None) should be (Set("/tmp/a.jar")) + } // scalastyle:on println // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. 
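The `unionFileLists` assertions above pin down the helper's contract: split both comma-separated lists, ignore empty entries, and return the deduplicated union. A standalone sketch with equivalent behavior, shown only to illustrate that contract rather than Spark's actual implementation:

```scala
// Splits two optional comma-separated lists, drops empty entries,
// and returns the union as a Set.
def unionFileLists(left: Option[String], right: Option[String]): Set[String] = {
  def entries(o: Option[String]): Seq[String] =
    o.toSeq.flatMap(_.split(",")).map(_.trim).filter(_.nonEmpty)
  (entries(left) ++ entries(right)).toSet
}

assert(unionFileLists(Some("/tmp/a.jar,/tmp/b.jar"), Some("/tmp/c.jar,/tmp/a.jar")) ==
  Set("/tmp/a.jar", "/tmp/b.jar", "/tmp/c.jar"))
assert(unionFileLists(Some(""), None).isEmpty)
```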
diff --git a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala index 3d39bd4a748cf..814027076d6fe 100644 --- a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala @@ -559,7 +559,7 @@ class StandaloneDynamicAllocationSuite val endpointRef = mock(classOf[RpcEndpointRef]) val mockAddress = mock(classOf[RpcAddress]) when(endpointRef.address).thenReturn(mockAddress) - val message = RegisterExecutor(id, endpointRef, 10, Map.empty) + val message = RegisterExecutor(id, endpointRef, "localhost", 10, Map.empty) val backend = sc.schedulerBackend.asInstanceOf[CoarseGrainedSchedulerBackend] backend.driverEndpoint.askWithRetry[Boolean](message) } diff --git a/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala b/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala index 7b46f9101d89b..bc58fb2a362a4 100644 --- a/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala @@ -22,7 +22,7 @@ import java.util.concurrent.ConcurrentLinkedQueue import scala.concurrent.duration._ import org.scalatest.BeforeAndAfterAll -import org.scalatest.concurrent.Eventually._ +import org.scalatest.concurrent.{Eventually, ScalaFutures} import org.apache.spark._ import org.apache.spark.deploy.{ApplicationDescription, Command} @@ -36,7 +36,12 @@ import org.apache.spark.util.Utils /** * End-to-end tests for application client in standalone mode. */ -class AppClientSuite extends SparkFunSuite with LocalSparkContext with BeforeAndAfterAll { +class AppClientSuite + extends SparkFunSuite + with LocalSparkContext + with BeforeAndAfterAll + with Eventually + with ScalaFutures { private val numWorkers = 2 private val conf = new SparkConf() private val securityManager = new SecurityManager(conf) @@ -93,7 +98,12 @@ class AppClientSuite extends SparkFunSuite with LocalSparkContext with BeforeAnd // Send message to Master to request Executors, verify request by change in executor limit val numExecutorsRequested = 1 - assert(ci.client.requestTotalExecutors(numExecutorsRequested)) + whenReady( + ci.client.requestTotalExecutors(numExecutorsRequested), + timeout(10.seconds), + interval(10.millis)) { acknowledged => + assert(acknowledged) + } eventually(timeout(10.seconds), interval(10.millis)) { val apps = getApplications() @@ -101,10 +111,12 @@ class AppClientSuite extends SparkFunSuite with LocalSparkContext with BeforeAnd } // Send request to kill executor, verify request was made - assert { - val apps = getApplications() - val executorId: String = apps.head.executors.head._2.fullId - ci.client.killExecutors(Seq(executorId)) + val executorId: String = getApplications().head.executors.head._2.fullId + whenReady( + ci.client.killExecutors(Seq(executorId)), + timeout(10.seconds), + interval(10.millis)) { acknowledged => + assert(acknowledged) } // Issue stop command for Client to disconnect from Master @@ -122,7 +134,9 @@ class AppClientSuite extends SparkFunSuite with LocalSparkContext with BeforeAnd val ci = new AppClientInst(masterRpcEnv.address.toSparkURL) // requests to master should fail immediately - assert(ci.client.requestTotalExecutors(3) === false) + whenReady(ci.client.requestTotalExecutors(3), timeout(1.seconds)) { success => + assert(success === false) + } } // 
=============================== @@ -166,7 +180,7 @@ class AppClientSuite extends SparkFunSuite with LocalSparkContext with BeforeAnd } /** Application Listener to collect events */ - private class AppClientCollector extends AppClientListener with Logging { + private class AppClientCollector extends StandaloneAppClientListener with Logging { val connectedIdList = new ConcurrentLinkedQueue[String]() @volatile var disconnectedCount: Int = 0 val deadReasonList = new ConcurrentLinkedQueue[String]() @@ -196,7 +210,8 @@ class AppClientSuite extends SparkFunSuite with LocalSparkContext with BeforeAnd execAddedList.add(id) } - def executorRemoved(id: String, message: String, exitStatus: Option[Int]): Unit = { + def executorRemoved( + id: String, message: String, exitStatus: Option[Int], workerLost: Boolean): Unit = { execRemovedList.add(id) } } @@ -208,7 +223,7 @@ class AppClientSuite extends SparkFunSuite with LocalSparkContext with BeforeAnd List(), Map(), Seq(), Seq(), Seq()) private val desc = new ApplicationDescription("AppClientSuite", Some(1), 512, cmd, "ignored") val listener = new AppClientCollector - val client = new AppClient(rpcEnv, Array(masterUrl), desc, listener, new SparkConf) + val client = new StandaloneAppClient(rpcEnv, Array(masterUrl), desc, listener, new SparkConf) } } diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 631a7cd9d5d7a..ae3f5d9c012ea 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -100,6 +100,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers "minDate app list json" -> "applications?minDate=2015-02-10", "maxDate app list json" -> "applications?maxDate=2015-02-10", "maxDate2 app list json" -> "applications?maxDate=2015-02-03T16:42:40.000GMT", + "limit app list json" -> "applications?limit=3", "one app json" -> "applications/local-1422981780767", "one app multi-attempt json" -> "applications/local-1426533911241", "job list json" -> "applications/local-1422981780767/jobs", diff --git a/core/src/test/scala/org/apache/spark/deploy/master/ui/MasterWebUISuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/ui/MasterWebUISuite.scala index 0c9382a92bcaf..69a460fbc7dba 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/ui/MasterWebUISuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/ui/MasterWebUISuite.scala @@ -17,74 +17,96 @@ package org.apache.spark.deploy.master.ui +import java.io.DataOutputStream +import java.net.{HttpURLConnection, URL} +import java.nio.charset.StandardCharsets import java.util.Date -import scala.io.Source -import scala.language.postfixOps +import scala.collection.mutable.HashMap -import org.json4s.jackson.JsonMethods._ -import org.json4s.JsonAST.{JInt, JNothing, JString} -import org.mockito.Mockito.{mock, when} -import org.scalatest.BeforeAndAfter +import org.mockito.Mockito.{mock, times, verify, when} +import org.scalatest.BeforeAndAfterAll import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} -import org.apache.spark.deploy.DeployMessages.MasterStateResponse +import org.apache.spark.deploy.DeployMessages.{KillDriverResponse, RequestKillDriver} import org.apache.spark.deploy.DeployTestUtils._ import org.apache.spark.deploy.master._ -import org.apache.spark.rpc.RpcEnv +import org.apache.spark.rpc.{RpcEndpointRef, 
RpcEnv} -class MasterWebUISuite extends SparkFunSuite with BeforeAndAfter { +class MasterWebUISuite extends SparkFunSuite with BeforeAndAfterAll { - val masterPage = mock(classOf[MasterPage]) - val master = { - val conf = new SparkConf - val securityMgr = new SecurityManager(conf) - val rpcEnv = RpcEnv.create(Master.SYSTEM_NAME, "localhost", 0, conf, securityMgr) - val master = new Master(rpcEnv, rpcEnv.address, 0, securityMgr, conf) - master - } - val masterWebUI = new MasterWebUI(master, 0, customMasterPage = Some(masterPage)) + val conf = new SparkConf + val securityMgr = new SecurityManager(conf) + val rpcEnv = mock(classOf[RpcEnv]) + val master = mock(classOf[Master]) + val masterEndpointRef = mock(classOf[RpcEndpointRef]) + when(master.securityMgr).thenReturn(securityMgr) + when(master.conf).thenReturn(conf) + when(master.rpcEnv).thenReturn(rpcEnv) + when(master.self).thenReturn(masterEndpointRef) + val masterWebUI = new MasterWebUI(master, 0) - before { + override def beforeAll() { + super.beforeAll() masterWebUI.bind() } - after { + override def afterAll() { masterWebUI.stop() + super.afterAll() } - test("list applications") { - val worker = createWorkerInfo() + test("kill application") { val appDesc = createAppDesc() // use new start date so it isn't filtered by UI val activeApp = new ApplicationInfo( - new Date().getTime, "id", appDesc, new Date(), null, Int.MaxValue) - activeApp.addExecutor(worker, 2) - - val workers = Array[WorkerInfo](worker) - val activeApps = Array(activeApp) - val completedApps = Array[ApplicationInfo]() - val activeDrivers = Array[DriverInfo]() - val completedDrivers = Array[DriverInfo]() - val stateResponse = new MasterStateResponse( - "host", 8080, None, workers, activeApps, completedApps, - activeDrivers, completedDrivers, RecoveryState.ALIVE) - - when(masterPage.getMasterState).thenReturn(stateResponse) - - val resultJson = Source.fromURL( - s"http://localhost:${masterWebUI.boundPort}/api/v1/applications") - .mkString - val parsedJson = parse(resultJson) - val firstApp = parsedJson(0) - - assert(firstApp \ "id" === JString(activeApp.id)) - assert(firstApp \ "name" === JString(activeApp.desc.name)) - assert(firstApp \ "coresGranted" === JInt(2)) - assert(firstApp \ "maxCores" === JInt(4)) - assert(firstApp \ "memoryPerExecutorMB" === JInt(1234)) - assert(firstApp \ "coresPerExecutor" === JNothing) + new Date().getTime, "app-0", appDesc, new Date(), null, Int.MaxValue) + + when(master.idToApp).thenReturn(HashMap[String, ApplicationInfo]((activeApp.id, activeApp))) + + val url = s"http://localhost:${masterWebUI.boundPort}/app/kill/" + val body = convPostDataToString(Map(("id", activeApp.id), ("terminate", "true"))) + val conn = sendHttpRequest(url, "POST", body) + conn.getResponseCode + + // Verify the master was called to remove the active app + verify(master, times(1)).removeApplication(activeApp, ApplicationState.KILLED) + } + + test("kill driver") { + val activeDriverId = "driver-0" + val url = s"http://localhost:${masterWebUI.boundPort}/driver/kill/" + val body = convPostDataToString(Map(("id", activeDriverId), ("terminate", "true"))) + val conn = sendHttpRequest(url, "POST", body) + conn.getResponseCode + + // Verify that master was asked to kill driver with the correct id + verify(masterEndpointRef, times(1)).ask[KillDriverResponse](RequestKillDriver(activeDriverId)) } + private def convPostDataToString(data: Map[String, String]): String = { + (for ((name, value) <- data) yield s"$name=$value").mkString("&") + } + + /** + * Send an HTTP request to 
the given URL using the method and the body specified. + * Return the connection object. + */ + private def sendHttpRequest( + url: String, + method: String, + body: String = ""): HttpURLConnection = { + val conn = new URL(url).openConnection().asInstanceOf[HttpURLConnection] + conn.setRequestMethod(method) + if (body.nonEmpty) { + conn.setDoOutput(true) + conn.setRequestProperty("Content-Type", "application/x-www-form-urlencoded") + conn.setRequestProperty("Content-Length", Integer.toString(body.length)) + val out = new DataOutputStream(conn.getOutputStream) + out.write(body.getBytes(StandardCharsets.UTF_8)) + out.close() + } + conn + } } diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala index a7bb9aa4686eb..dd50e33da30ac 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala @@ -408,7 +408,7 @@ class StandaloneRestSubmitSuite extends SparkFunSuite with BeforeAndAfterEach { /** * Start a [[StandaloneRestServer]] that communicates with the given endpoint. - * If `faulty` is true, start an [[FaultyStandaloneRestServer]] instead. + * If `faulty` is true, start a [[FaultyStandaloneRestServer]] instead. * Return the master URL that corresponds to the address of this server. */ private def startServer( diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/ui/LogPageSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/ui/LogPageSuite.scala index 72eaffb416981..4c3e96777940d 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/ui/LogPageSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/ui/LogPageSuite.scala @@ -22,16 +22,20 @@ import java.io.{File, FileWriter} import org.mockito.Mockito.{mock, when} import org.scalatest.PrivateMethodTester -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.worker.Worker class LogPageSuite extends SparkFunSuite with PrivateMethodTester { test("get logs simple") { val webui = mock(classOf[WorkerWebUI]) + val worker = mock(classOf[Worker]) val tmpDir = new File(sys.props("java.io.tmpdir")) val workDir = new File(tmpDir, "work-dir") workDir.mkdir() when(webui.workDir).thenReturn(workDir) + when(webui.worker).thenReturn(worker) + when(worker.conf).thenReturn(new SparkConf()) val logPage = new LogPage(webui) // Prepare some fake log files to read later diff --git a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala index 94f6e1a3a77c1..eae26fa742a23 100644 --- a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala @@ -21,6 +21,7 @@ import org.scalatest.Assertions import org.apache.spark._ import org.apache.spark.storage.{BlockStatus, StorageLevel, TestBlockId} +import org.apache.spark.util.AccumulatorV2 class TaskMetricsSuite extends SparkFunSuite { @@ -203,7 +204,7 @@ class TaskMetricsSuite extends SparkFunSuite { acc1.add(1) acc2.add(2) val newUpdates = tm.accumulators() - .map(a => (a.id, a.asInstanceOf[NewAccumulator[Any, Any]])).toMap + .map(a => (a.id, a.asInstanceOf[AccumulatorV2[Any, Any]])).toMap assert(newUpdates.contains(acc1.id)) assert(newUpdates.contains(acc2.id)) assert(newUpdates.contains(acc3.id)) @@ 
-230,8 +231,8 @@ private[spark] object TaskMetricsSuite extends Assertions { * Note: this does NOT check accumulator ID equality. */ def assertUpdatesEquals( - updates1: Seq[NewAccumulator[_, _]], - updates2: Seq[NewAccumulator[_, _]]): Unit = { + updates1: Seq[AccumulatorV2[_, _]], + updates2: Seq[AccumulatorV2[_, _]]): Unit = { assert(updates1.size === updates2.size) updates1.zip(updates2).foreach { case (acc1, acc2) => // do not assert ID equals here diff --git a/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala index f205d4f0d60b5..38b48a4c9e654 100644 --- a/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala +++ b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala @@ -38,12 +38,6 @@ class ChunkedByteBufferSuite extends SparkFunSuite { emptyChunkedByteBuffer.toInputStream(dispose = true).close() } - test("chunks must be non-empty") { - intercept[IllegalArgumentException] { - new ChunkedByteBuffer(Array(ByteBuffer.allocate(0))) - } - } - test("getChunks() duplicates chunks") { val chunkedByteBuffer = new ChunkedByteBuffer(Array(ByteBuffer.allocate(8))) chunkedByteBuffer.getChunks().head.position(4) @@ -63,8 +57,9 @@ class ChunkedByteBufferSuite extends SparkFunSuite { } test("toArray()") { + val empty = ByteBuffer.wrap(Array[Byte]()) val bytes = ByteBuffer.wrap(Array.tabulate(8)(_.toByte)) - val chunkedByteBuffer = new ChunkedByteBuffer(Array(bytes, bytes)) + val chunkedByteBuffer = new ChunkedByteBuffer(Array(bytes, bytes, empty)) assert(chunkedByteBuffer.toArray === bytes.array() ++ bytes.array()) } @@ -79,9 +74,10 @@ class ChunkedByteBufferSuite extends SparkFunSuite { } test("toInputStream()") { + val empty = ByteBuffer.wrap(Array[Byte]()) val bytes1 = ByteBuffer.wrap(Array.tabulate(256)(_.toByte)) val bytes2 = ByteBuffer.wrap(Array.tabulate(128)(_.toByte)) - val chunkedByteBuffer = new ChunkedByteBuffer(Array(bytes1, bytes2)) + val chunkedByteBuffer = new ChunkedByteBuffer(Array(empty, bytes1, bytes2)) assert(chunkedByteBuffer.size === bytes1.limit() + bytes2.limit()) val inputStream = chunkedByteBuffer.toInputStream(dispose = false) diff --git a/core/src/test/scala/org/apache/spark/launcher/LauncherBackendSuite.scala b/core/src/test/scala/org/apache/spark/launcher/LauncherBackendSuite.scala index 713560d3ddfa1..cac15a1dc4414 100644 --- a/core/src/test/scala/org/apache/spark/launcher/LauncherBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/launcher/LauncherBackendSuite.scala @@ -48,7 +48,7 @@ class LauncherBackendSuite extends SparkFunSuite with Matchers { .setConf("spark.ui.enabled", "false") .setConf(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS, s"-Dtest.appender=console") .setMaster(master) - .setAppResource("spark-internal") + .setAppResource(SparkLauncher.NO_RESOURCE) .setMainClass(TestApp.getClass.getName().stripSuffix("$")) .startApplication() diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala index a1286523a235d..38bf7e5e5aec3 100644 --- a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala @@ -78,6 +78,21 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite with BeforeAndAft ms } + /** + * Make a mocked [[MemoryStore]] whose [[MemoryStore.evictBlocksToFreeSpace]] method is + * stubbed to always throw [[RuntimeException]]. 
+ */ + protected def makeBadMemoryStore(mm: MemoryManager): MemoryStore = { + val ms = mock(classOf[MemoryStore], RETURNS_SMART_NULLS) + when(ms.evictBlocksToFreeSpace(any(), anyLong(), any())).thenAnswer(new Answer[Long] { + override def answer(invocation: InvocationOnMock): Long = { + throw new RuntimeException("bad memory store!") + } + }) + mm.setMemoryStore(ms) + ms + } + /** * Simulate the part of [[MemoryStore.evictBlocksToFreeSpace]] that releases storage memory. * @@ -147,39 +162,42 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite with BeforeAndAft test("single task requesting on-heap execution memory") { val manager = createMemoryManager(1000L) val taskMemoryManager = new TaskMemoryManager(manager, 0) + val c = new TestMemoryConsumer(taskMemoryManager) - assert(taskMemoryManager.acquireExecutionMemory(100L, MemoryMode.ON_HEAP, null) === 100L) - assert(taskMemoryManager.acquireExecutionMemory(400L, MemoryMode.ON_HEAP, null) === 400L) - assert(taskMemoryManager.acquireExecutionMemory(400L, MemoryMode.ON_HEAP, null) === 400L) - assert(taskMemoryManager.acquireExecutionMemory(200L, MemoryMode.ON_HEAP, null) === 100L) - assert(taskMemoryManager.acquireExecutionMemory(100L, MemoryMode.ON_HEAP, null) === 0L) - assert(taskMemoryManager.acquireExecutionMemory(100L, MemoryMode.ON_HEAP, null) === 0L) + assert(taskMemoryManager.acquireExecutionMemory(100L, c) === 100L) + assert(taskMemoryManager.acquireExecutionMemory(400L, c) === 400L) + assert(taskMemoryManager.acquireExecutionMemory(400L, c) === 400L) + assert(taskMemoryManager.acquireExecutionMemory(200L, c) === 100L) + assert(taskMemoryManager.acquireExecutionMemory(100L, c) === 0L) + assert(taskMemoryManager.acquireExecutionMemory(100L, c) === 0L) - taskMemoryManager.releaseExecutionMemory(500L, MemoryMode.ON_HEAP, null) - assert(taskMemoryManager.acquireExecutionMemory(300L, MemoryMode.ON_HEAP, null) === 300L) - assert(taskMemoryManager.acquireExecutionMemory(300L, MemoryMode.ON_HEAP, null) === 200L) + taskMemoryManager.releaseExecutionMemory(500L, c) + assert(taskMemoryManager.acquireExecutionMemory(300L, c) === 300L) + assert(taskMemoryManager.acquireExecutionMemory(300L, c) === 200L) taskMemoryManager.cleanUpAllAllocatedMemory() - assert(taskMemoryManager.acquireExecutionMemory(1000L, MemoryMode.ON_HEAP, null) === 1000L) - assert(taskMemoryManager.acquireExecutionMemory(100L, MemoryMode.ON_HEAP, null) === 0L) + assert(taskMemoryManager.acquireExecutionMemory(1000L, c) === 1000L) + assert(taskMemoryManager.acquireExecutionMemory(100L, c) === 0L) } test("two tasks requesting full on-heap execution memory") { val memoryManager = createMemoryManager(1000L) val t1MemManager = new TaskMemoryManager(memoryManager, 1) val t2MemManager = new TaskMemoryManager(memoryManager, 2) + val c1 = new TestMemoryConsumer(t1MemManager) + val c2 = new TestMemoryConsumer(t2MemManager) val futureTimeout: Duration = 20.seconds // Have both tasks request 500 bytes, then wait until both requests have been granted: - val t1Result1 = Future { t1MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } - val t2Result1 = Future { t2MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } + val t1Result1 = Future { t1MemManager.acquireExecutionMemory(500L, c1) } + val t2Result1 = Future { t2MemManager.acquireExecutionMemory(500L, c2) } assert(ThreadUtils.awaitResult(t1Result1, futureTimeout) === 500L) assert(ThreadUtils.awaitResult(t2Result1, futureTimeout) === 500L) // Have both tasks each request 500 bytes more; both should 
immediately return 0 as they are // both now at 1 / N - val t1Result2 = Future { t1MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } - val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } + val t1Result2 = Future { t1MemManager.acquireExecutionMemory(500L, c1) } + val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L, c2) } assert(ThreadUtils.awaitResult(t1Result2, 200.millis) === 0L) assert(ThreadUtils.awaitResult(t2Result2, 200.millis) === 0L) } @@ -188,18 +206,20 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite with BeforeAndAft val memoryManager = createMemoryManager(1000L) val t1MemManager = new TaskMemoryManager(memoryManager, 1) val t2MemManager = new TaskMemoryManager(memoryManager, 2) + val c1 = new TestMemoryConsumer(t1MemManager) + val c2 = new TestMemoryConsumer(t2MemManager) val futureTimeout: Duration = 20.seconds // Have both tasks request 250 bytes, then wait until both requests have been granted: - val t1Result1 = Future { t1MemManager.acquireExecutionMemory(250L, MemoryMode.ON_HEAP, null) } - val t2Result1 = Future { t2MemManager.acquireExecutionMemory(250L, MemoryMode.ON_HEAP, null) } + val t1Result1 = Future { t1MemManager.acquireExecutionMemory(250L, c1) } + val t2Result1 = Future { t2MemManager.acquireExecutionMemory(250L, c2) } assert(ThreadUtils.awaitResult(t1Result1, futureTimeout) === 250L) assert(ThreadUtils.awaitResult(t2Result1, futureTimeout) === 250L) // Have both tasks each request 500 bytes more. // We should only grant 250 bytes to each of them on this second request - val t1Result2 = Future { t1MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } - val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } + val t1Result2 = Future { t1MemManager.acquireExecutionMemory(500L, c1) } + val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L, c2) } assert(ThreadUtils.awaitResult(t1Result2, futureTimeout) === 250L) assert(ThreadUtils.awaitResult(t2Result2, futureTimeout) === 250L) } @@ -208,20 +228,22 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite with BeforeAndAft val memoryManager = createMemoryManager(1000L) val t1MemManager = new TaskMemoryManager(memoryManager, 1) val t2MemManager = new TaskMemoryManager(memoryManager, 2) + val c1 = new TestMemoryConsumer(t1MemManager) + val c2 = new TestMemoryConsumer(t2MemManager) val futureTimeout: Duration = 20.seconds // t1 grabs 1000 bytes and then waits until t2 is ready to make a request. - val t1Result1 = Future { t1MemManager.acquireExecutionMemory(1000L, MemoryMode.ON_HEAP, null) } + val t1Result1 = Future { t1MemManager.acquireExecutionMemory(1000L, c1) } assert(ThreadUtils.awaitResult(t1Result1, futureTimeout) === 1000L) - val t2Result1 = Future { t2MemManager.acquireExecutionMemory(250L, MemoryMode.ON_HEAP, null) } + val t2Result1 = Future { t2MemManager.acquireExecutionMemory(250L, c2) } // Make sure that t2 didn't grab the memory right away. This is hacky but it would be difficult // to make sure the other thread blocks for some time otherwise. Thread.sleep(300) - t1MemManager.releaseExecutionMemory(250L, MemoryMode.ON_HEAP, null) + t1MemManager.releaseExecutionMemory(250L, c1) // The memory freed from t1 should now be granted to t2. assert(ThreadUtils.awaitResult(t2Result1, futureTimeout) === 250L) // Further requests by t2 should be denied immediately because it now has 1 / 2N of the memory. 
- val t2Result2 = Future { t2MemManager.acquireExecutionMemory(100L, MemoryMode.ON_HEAP, null) } + val t2Result2 = Future { t2MemManager.acquireExecutionMemory(100L, c2) } assert(ThreadUtils.awaitResult(t2Result2, 200.millis) === 0L) } @@ -229,21 +251,23 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite with BeforeAndAft val memoryManager = createMemoryManager(1000L) val t1MemManager = new TaskMemoryManager(memoryManager, 1) val t2MemManager = new TaskMemoryManager(memoryManager, 2) + val c1 = new TestMemoryConsumer(t1MemManager) + val c2 = new TestMemoryConsumer(t2MemManager) val futureTimeout: Duration = 20.seconds // t1 grabs 1000 bytes and then waits until t2 is ready to make a request. - val t1Result1 = Future { t1MemManager.acquireExecutionMemory(1000L, MemoryMode.ON_HEAP, null) } + val t1Result1 = Future { t1MemManager.acquireExecutionMemory(1000L, c1) } assert(ThreadUtils.awaitResult(t1Result1, futureTimeout) === 1000L) - val t2Result1 = Future { t2MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } + val t2Result1 = Future { t2MemManager.acquireExecutionMemory(500L, c2) } // Make sure that t2 didn't grab the memory right away. This is hacky but it would be difficult // to make sure the other thread blocks for some time otherwise. Thread.sleep(300) // t1 releases all of its memory, so t2 should be able to grab all of the memory t1MemManager.cleanUpAllAllocatedMemory() assert(ThreadUtils.awaitResult(t2Result1, futureTimeout) === 500L) - val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } + val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L, c2) } assert(ThreadUtils.awaitResult(t2Result2, futureTimeout) === 500L) - val t2Result3 = Future { t2MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } + val t2Result3 = Future { t2MemManager.acquireExecutionMemory(500L, c2) } assert(ThreadUtils.awaitResult(t2Result3, 200.millis) === 0L) } @@ -252,15 +276,17 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite with BeforeAndAft val memoryManager = createMemoryManager(1000L) val t1MemManager = new TaskMemoryManager(memoryManager, 1) val t2MemManager = new TaskMemoryManager(memoryManager, 2) + val c1 = new TestMemoryConsumer(t1MemManager) + val c2 = new TestMemoryConsumer(t2MemManager) val futureTimeout: Duration = 20.seconds - val t1Result1 = Future { t1MemManager.acquireExecutionMemory(700L, MemoryMode.ON_HEAP, null) } + val t1Result1 = Future { t1MemManager.acquireExecutionMemory(700L, c1) } assert(ThreadUtils.awaitResult(t1Result1, futureTimeout) === 700L) - val t2Result1 = Future { t2MemManager.acquireExecutionMemory(300L, MemoryMode.ON_HEAP, null) } + val t2Result1 = Future { t2MemManager.acquireExecutionMemory(300L, c2) } assert(ThreadUtils.awaitResult(t2Result1, futureTimeout) === 300L) - val t1Result2 = Future { t1MemManager.acquireExecutionMemory(300L, MemoryMode.ON_HEAP, null) } + val t1Result2 = Future { t1MemManager.acquireExecutionMemory(300L, c1) } assert(ThreadUtils.awaitResult(t1Result2, 200.millis) === 0L) } @@ -270,17 +296,18 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite with BeforeAndAft maxOffHeapExecutionMemory = 1000L) val tMemManager = new TaskMemoryManager(memoryManager, 1) - val result1 = Future { tMemManager.acquireExecutionMemory(1000L, MemoryMode.OFF_HEAP, null) } + val c = new TestMemoryConsumer(tMemManager, MemoryMode.OFF_HEAP) + val result1 = Future { tMemManager.acquireExecutionMemory(1000L, c) } 
assert(ThreadUtils.awaitResult(result1, 200.millis) === 1000L) assert(tMemManager.getMemoryConsumptionForThisTask === 1000L) - val result2 = Future { tMemManager.acquireExecutionMemory(300L, MemoryMode.OFF_HEAP, null) } + val result2 = Future { tMemManager.acquireExecutionMemory(300L, c) } assert(ThreadUtils.awaitResult(result2, 200.millis) === 0L) assert(tMemManager.getMemoryConsumptionForThisTask === 1000L) - tMemManager.releaseExecutionMemory(500L, MemoryMode.OFF_HEAP, null) + tMemManager.releaseExecutionMemory(500L, c) assert(tMemManager.getMemoryConsumptionForThisTask === 500L) - tMemManager.releaseExecutionMemory(500L, MemoryMode.OFF_HEAP, null) + tMemManager.releaseExecutionMemory(500L, c) assert(tMemManager.getMemoryConsumptionForThisTask === 0L) } } diff --git a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala index 14255818c7b5e..c821054412d7d 100644 --- a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala @@ -280,4 +280,27 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes assert(evictedBlocks.nonEmpty) } + test("SPARK-15260: atomically resize memory pools") { + val conf = new SparkConf() + .set("spark.memory.fraction", "1") + .set("spark.memory.storageFraction", "0") + .set("spark.testing.memory", "1000") + val mm = UnifiedMemoryManager(conf, numCores = 2) + makeBadMemoryStore(mm) + val memoryMode = MemoryMode.ON_HEAP + // Acquire 1000 then release 600 bytes of storage memory, leaving the + // storage memory pool at 1000 bytes but only 400 bytes of which are used. + assert(mm.acquireStorageMemory(dummyBlock, 1000L, memoryMode)) + mm.releaseStorageMemory(600L, memoryMode) + // Before the fix for SPARK-15260, we would first shrink the storage pool by the amount of + // unused storage memory (600 bytes), try to evict blocks, then enlarge the execution pool + // by the same amount. If the eviction threw an exception, then we would shrink one pool + // without enlarging the other, resulting in an assertion failure. 
+ intercept[RuntimeException] { + mm.acquireExecutionMemory(1000L, 0, memoryMode) + } + val assertInvariants = PrivateMethod[Unit]('assertInvariants) + mm.invokePrivate[Unit](assertInvariants()) + } + } diff --git a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala index 5d8554229dbe1..2400832f6eea7 100644 --- a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala @@ -24,7 +24,7 @@ import org.scalatest.{BeforeAndAfter, PrivateMethodTester} import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.apache.spark.deploy.master.MasterSource -import org.apache.spark.metrics.source.Source +import org.apache.spark.metrics.source.{Source, StaticSources} class MetricsSystemSuite extends SparkFunSuite with BeforeAndAfter with PrivateMethodTester{ var filePath: String = _ @@ -43,7 +43,7 @@ class MetricsSystemSuite extends SparkFunSuite with BeforeAndAfter with PrivateM val sources = PrivateMethod[ArrayBuffer[Source]]('sources) val sinks = PrivateMethod[ArrayBuffer[Source]]('sinks) - assert(metricsSystem.invokePrivate(sources()).length === 0) + assert(metricsSystem.invokePrivate(sources()).length === StaticSources.allSources.length) assert(metricsSystem.invokePrivate(sinks()).length === 0) assert(metricsSystem.getServletHandlers.nonEmpty) } @@ -54,13 +54,13 @@ class MetricsSystemSuite extends SparkFunSuite with BeforeAndAfter with PrivateM val sources = PrivateMethod[ArrayBuffer[Source]]('sources) val sinks = PrivateMethod[ArrayBuffer[Source]]('sinks) - assert(metricsSystem.invokePrivate(sources()).length === 0) + assert(metricsSystem.invokePrivate(sources()).length === StaticSources.allSources.length) assert(metricsSystem.invokePrivate(sinks()).length === 1) assert(metricsSystem.getServletHandlers.nonEmpty) val source = new MasterSource(null) metricsSystem.registerSource(source) - assert(metricsSystem.invokePrivate(sources()).length === 1) + assert(metricsSystem.invokePrivate(sources()).length === StaticSources.allSources.length + 1) } test("MetricsSystem with Driver instance") { diff --git a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala index 8cb0a295b0773..58664e77d24a5 100644 --- a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala @@ -65,9 +65,9 @@ class AsyncRDDActionsSuite extends SparkFunSuite with BeforeAndAfterAll with Tim test("foreachAsync") { zeroPartRdd.foreachAsync(i => Unit).get() - val accum = sc.accumulator(0) + val accum = sc.longAccumulator sc.parallelize(1 to 1000, 3).foreachAsync { i => - accum += 1 + accum.add(1) }.get() assert(accum.value === 1000) } @@ -75,9 +75,9 @@ class AsyncRDDActionsSuite extends SparkFunSuite with BeforeAndAfterAll with Tim test("foreachPartitionAsync") { zeroPartRdd.foreachPartitionAsync(iter => Unit).get() - val accum = sc.accumulator(0) + val accum = sc.longAccumulator sc.parallelize(1 to 1000, 9).foreachPartitionAsync { iter => - accum += 1 + accum.add(1) }.get() assert(accum.value === 9) } diff --git a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala index e9cc8195240f0..f8d523fa2c6ae 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala +++ 
b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.rdd import java.io.File import scala.collection.Map +import scala.io.Codec import scala.language.postfixOps import scala.sys.process._ import scala.util.Try @@ -50,6 +51,22 @@ class PipedRDDSuite extends SparkFunSuite with SharedSparkContext { } } + test("basic pipe with tokenization") { + if (testCommandAvailable("wc")) { + val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + + // verify that both RDD.pipe(command: String) and RDD.pipe(command: String, env) work well + for (piped <- Seq(nums.pipe("wc -l"), nums.pipe("wc -l", Map[String, String]()))) { + val c = piped.collect() + assert(c.size === 2) + assert(c(0).trim === "2") + assert(c(1).trim === "2") + } + } else { + assert(true) + } + } + test("failure in iterating over pipe input") { if (testCommandAvailable("cat")) { val nums = @@ -121,6 +138,14 @@ class PipedRDDSuite extends SparkFunSuite with SharedSparkContext { } } + test("pipe with empty partition") { + val data = sc.parallelize(Seq("foo", "bing"), 8) + val piped = data.pipe("wc -c") + assert(piped.count == 8) + val charCounts = piped.map(_.trim.toInt).collect().toSet + assert(Set(0, 4, 5) == charCounts) + } + test("pipe with env variable") { if (testCommandAvailable("printenv")) { val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) @@ -171,7 +196,7 @@ class PipedRDDSuite extends SparkFunSuite with SharedSparkContext { val pipedPwd = nums.pipe(Seq("pwd"), separateWorkingDir = true) val collectPwd = pipedPwd.collect() assert(collectPwd(0).contains("tasks/")) - val pipedLs = nums.pipe(Seq("ls"), separateWorkingDir = true).collect() + val pipedLs = nums.pipe(Seq("ls"), separateWorkingDir = true, bufferSize = 16384).collect() // make sure symlinks were created assert(pipedLs.length > 0) // clean up top level tasks directory @@ -207,7 +232,16 @@ class PipedRDDSuite extends SparkFunSuite with SharedSparkContext { } } val hadoopPart1 = generateFakeHadoopPartition() - val pipedRdd = new PipedRDD(nums, "printenv " + varName) + val pipedRdd = + new PipedRDD( + nums, + PipedRDD.tokenize("printenv " + varName), + Map(), + null, + null, + false, + 4092, + Codec.defaultCharsetCodec.name) val tContext = TaskContext.empty() val rddIter = pipedRdd.compute(hadoopPart1, tContext) val arr = rddIter.toArray diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 8dc463d56d182..ad56715656c85 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -116,6 +116,23 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { assert(sc.union(Seq(nums, nums)).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4)) } + test("SparkContext.union parallel partition listing") { + val nums1 = sc.makeRDD(Array(1, 2, 3, 4), 2) + val nums2 = sc.makeRDD(Array(5, 6, 7, 8), 2) + val serialUnion = sc.union(nums1, nums2) + val expected = serialUnion.collect().toList + + assert(serialUnion.asInstanceOf[UnionRDD[Int]].isPartitionListingParallel === false) + + sc.conf.set("spark.rdd.parallelListingThreshold", "1") + val parallelUnion = sc.union(nums1, nums2) + val actual = parallelUnion.collect().toList + sc.conf.remove("spark.rdd.parallelListingThreshold") + + assert(parallelUnion.asInstanceOf[UnionRDD[Int]].isPartitionListingParallel === true) + assert(expected === actual) + } + test("SparkContext.union creates UnionRDD if at least one RDD has no partitioner") { val 
rddWithPartitioner = sc.parallelize(Seq(1 -> true)).partitionBy(new HashPartitioner(1)) val rddWithNoPartitioner = sc.parallelize(Seq(2 -> true)) @@ -259,6 +276,10 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { test("repartitioned RDDs") { val data = sc.parallelize(1 to 1000, 10) + intercept[IllegalArgumentException] { + data.repartition(0) + } + // Coalesce partitions val repartitioned1 = data.repartition(2) assert(repartitioned1.partitions.size == 2) @@ -312,6 +333,10 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { test("coalesced RDDs") { val data = sc.parallelize(1 to 10, 10) + intercept[IllegalArgumentException] { + data.coalesce(0) + } + val coalesced1 = data.coalesce(2) assert(coalesced1.collect().toList === (1 to 10).toList) assert(coalesced1.glom().collect().map(_.toList).toList === @@ -377,6 +402,33 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { map{x => List(x)}.toList, "Tried coalescing 9 partitions to 20 but didn't get 9 back") } + test("coalesced RDDs with partial locality") { + // Make an RDD that has some locality preferences and some without. This can happen + // with UnionRDD + val data = sc.makeRDD((1 to 9).map(i => { + if (i > 4) { + (i, (i to (i + 2)).map { j => "m" + (j % 6) }) + } else { + (i, Vector()) + } + })) + val coalesced1 = data.coalesce(3) + assert(coalesced1.collect().toList.sorted === (1 to 9).toList, "Data got *lost* in coalescing") + + val splits = coalesced1.glom().collect().map(_.toList).toList + assert(splits.length === 3, "Supposed to coalesce to 3 but got " + splits.length) + + assert(splits.forall(_.length >= 1) === true, "Some partitions were empty") + + // If we try to coalesce into more partitions than the original RDD, it should just + // keep the original number of partitions. + val coalesced4 = data.coalesce(20) + val listOfLists = coalesced4.glom().collect().map(_.toList).toList + val sortedList = listOfLists.sortWith{ (x, y) => !x.isEmpty && (y.isEmpty || (x(0) < y(0))) } + assert(sortedList === (1 to 9). 
+ map{x => List(x)}.toList, "Tried coalescing 9 partitions to 20 but didn't get 9 back") + } + test("coalesced RDDs with locality, large scale (10K partitions)") { // large scale experiment import collection.mutable @@ -418,6 +470,48 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { } } + test("coalesced RDDs with partial locality, large scale (10K partitions)") { + // large scale experiment + import collection.mutable + val halfpartitions = 5000 + val partitions = 10000 + val numMachines = 50 + val machines = mutable.ListBuffer[String]() + (1 to numMachines).foreach(machines += "m" + _) + val rnd = scala.util.Random + for (seed <- 1 to 5) { + rnd.setSeed(seed) + + val firstBlocks = (1 to halfpartitions).map { i => + (i, Array.fill(3)(machines(rnd.nextInt(machines.size))).toList) + } + val blocksNoLocality = (halfpartitions + 1 to partitions).map { i => + (i, List()) + } + val blocks = firstBlocks ++ blocksNoLocality + + val data2 = sc.makeRDD(blocks) + + // first try going to same number of partitions + val coalesced2 = data2.coalesce(partitions) + + // test that we have 10000 partitions + assert(coalesced2.partitions.size == 10000, "Expected 10000 partitions, but got " + + coalesced2.partitions.size) + + // test that we have 100 partitions + val coalesced3 = data2.coalesce(numMachines * 2) + assert(coalesced3.partitions.size == 100, "Expected 100 partitions, but got " + + coalesced3.partitions.size) + + // test that the groups are load balanced with 100 +/- 20 elements in each + val maxImbalance3 = coalesced3.partitions + .map(part => part.asInstanceOf[CoalescedRDDPartition].parents.size) + .foldLeft(0)((dev, curr) => math.max(math.abs(100 - curr), dev)) + assert(maxImbalance3 <= 20, "Expected 100 +/- 20 per partition, but got " + maxImbalance3) + } + } + // Test for SPARK-2412 -- ensure that the second pass of the algorithm does not throw an exception test("coalesced RDDs with locality, fail first pass") { val initialPartitions = 1000 @@ -592,27 +686,26 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { } { val sample = data.takeSample(withReplacement = true, num = 20) - assert(sample.size === 20) // Got exactly 100 elements - assert(sample.toSet.size <= 20, "sampling with replacement returned all distinct elements") + assert(sample.size === 20) // Got exactly 20 elements assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]") } { val sample = data.takeSample(withReplacement = true, num = n) - assert(sample.size === n) // Got exactly 100 elements - // Chance of getting all distinct elements is astronomically low, so test we got < 100 + assert(sample.size === n) // Got exactly n elements + // Chance of getting all distinct elements is astronomically low, so test we got < n assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements") assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]") } for (seed <- 1 to 5) { val sample = data.takeSample(withReplacement = true, n, seed) - assert(sample.size === n) // Got exactly 100 elements - // Chance of getting all distinct elements is astronomically low, so test we got < 100 + assert(sample.size === n) // Got exactly n elements + // Chance of getting all distinct elements is astronomically low, so test we got < n assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements") } for (seed <- 1 to 5) { val sample = data.takeSample(withReplacement = true, 2 * n, seed) - assert(sample.size === 2 * n) // Got exactly 200 elements - // 
Chance of getting all distinct elements is still quite low, so test we got < 100 + assert(sample.size === 2 * n) // Got exactly 2 * n elements + // Chance of getting all distinct elements is still quite low, so test we got < n assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements") } } diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 505cd476ff881..acdf21df9a161 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -489,7 +489,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { /** * Setup an [[RpcEndpoint]] to collect all network events. * - * @return the [[RpcEndpointRef]] and an `ConcurrentLinkedQueue` that contains network events. + * @return the [[RpcEndpointRef]] and a `ConcurrentLinkedQueue` that contains network events. */ private def setupNetworkEndpoint( _env: RpcEnv, diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 9912d1f3bc5a7..5c353021677ef 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.scheduler import java.util.Properties +import java.util.concurrent.atomic.AtomicBoolean import scala.annotation.meta.param import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} @@ -28,11 +29,12 @@ import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ import org.apache.spark._ -import org.apache.spark.executor.TaskMetrics +import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.SchedulingMode.SchedulingMode +import org.apache.spark.shuffle.{FetchFailedException, MetadataFetchFailedException} import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster} -import org.apache.spark.util.{CallSite, Utils} +import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, CallSite, LongAccumulator, Utils} class DAGSchedulerEventProcessLoopTester(dagScheduler: DAGScheduler) extends DAGSchedulerEventProcessLoop(dagScheduler) { @@ -112,7 +114,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou override def stop() = {} override def executorHeartbeatReceived( execId: String, - accumUpdates: Array[(Long, Seq[NewAccumulator[_, _]])], + accumUpdates: Array[(Long, Seq[AccumulatorV2[_, _]])], blockManagerId: BlockManagerId): Boolean = true override def submitTasks(taskSet: TaskSet) = { // normally done by TaskSetManager @@ -157,6 +159,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou } var mapOutputTracker: MapOutputTrackerMaster = null + var broadcastManager: BroadcastManager = null + var securityMgr: SecurityManager = null var scheduler: DAGScheduler = null var dagEventProcessLoopTester: DAGSchedulerEventProcessLoop = null @@ -197,7 +201,11 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou override def beforeEach(): Unit = { super.beforeEach() - sc = new SparkContext("local", "DAGSchedulerSuite") + init(new SparkConf()) + } + + private def init(testConf: SparkConf): Unit = { + sc = new SparkContext("local", "DAGSchedulerSuite", testConf) sparkListener.submittedStageInfos.clear() 
sparkListener.successfulStages.clear() sparkListener.failedStages.clear() @@ -208,7 +216,9 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou cancelledStages.clear() cacheLocations.clear() results.clear() - mapOutputTracker = new MapOutputTrackerMaster(conf) + securityMgr = new SecurityManager(conf) + broadcastManager = new BroadcastManager(true, conf, securityMgr) + mapOutputTracker = new MapOutputTrackerMaster(conf, broadcastManager, true) scheduler = new DAGScheduler( sc, taskScheduler, @@ -321,6 +331,53 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou assert(sparkListener.stageByOrderOfExecution(0) < sparkListener.stageByOrderOfExecution(1)) } + /** + * This test ensures that DAGScheduler builds the stage graph correctly. + * + * Suppose you have the following DAG: + * + * [A] <--(s_A)-- [B] <--(s_B)-- [C] <--(s_C)-- [D] + * \ / + * <------------- + * + * Here, RDD B has a shuffle dependency on RDD A, and RDD C has a shuffle dependency on both + * B and A. The shuffle dependency IDs are numbers in the DAGScheduler, but to make the example + * easier to understand, let's call the shuffled data from A shuffle dependency ID s_A and the + * shuffled data from B shuffle dependency ID s_B. + * + * Note: [] means an RDD, () means a shuffle dependency. + */ + test("[SPARK-13902] Ensure no duplicate stages are created") { + val rddA = new MyRDD(sc, 1, Nil) + val shuffleDepA = new ShuffleDependency(rddA, new HashPartitioner(1)) + val s_A = shuffleDepA.shuffleId + + val rddB = new MyRDD(sc, 1, List(shuffleDepA), tracker = mapOutputTracker) + val shuffleDepB = new ShuffleDependency(rddB, new HashPartitioner(1)) + val s_B = shuffleDepB.shuffleId + + val rddC = new MyRDD(sc, 1, List(shuffleDepA, shuffleDepB), tracker = mapOutputTracker) + val shuffleDepC = new ShuffleDependency(rddC, new HashPartitioner(1)) + val s_C = shuffleDepC.shuffleId + + val rddD = new MyRDD(sc, 1, List(shuffleDepC), tracker = mapOutputTracker) + + submit(rddD, Array(0)) + + assert(scheduler.shuffleToMapStage.size === 3) + assert(scheduler.activeJobs.size === 1) + + val mapStageA = scheduler.shuffleToMapStage(s_A) + val mapStageB = scheduler.shuffleToMapStage(s_B) + val mapStageC = scheduler.shuffleToMapStage(s_C) + val finalStage = scheduler.activeJobs.head.finalStage + + assert(mapStageA.parents.isEmpty) + assert(mapStageB.parents === List(mapStageA)) + assert(mapStageC.parents === List(mapStageA, mapStageB)) + assert(finalStage.parents === List(mapStageC)) + } + test("zero split job") { var numResults = 0 var failureReason: Option[Exception] = None @@ -483,7 +540,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou override def defaultParallelism(): Int = 2 override def executorHeartbeatReceived( execId: String, - accumUpdates: Array[(Long, Seq[NewAccumulator[_, _]])], + accumUpdates: Array[(Long, Seq[AccumulatorV2[_, _]])], blockManagerId: BlockManagerId): Boolean = true override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {} override def applicationAttemptId(): Option[String] = None @@ -554,6 +611,46 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou assertDataStructuresEmpty() } + private val shuffleFileLossTests = Seq( + ("slave lost with shuffle service", SlaveLost("", false), true, false), + ("worker lost with shuffle service", SlaveLost("", true), true, true), + ("worker lost without shuffle service", SlaveLost("", true), false, true), + ("executor failure with 
shuffle service", ExecutorKilled, true, false), + ("executor failure without shuffle service", ExecutorKilled, false, true)) + + for ((eventDescription, event, shuffleServiceOn, expectFileLoss) <- shuffleFileLossTests) { + val maybeLost = if (expectFileLoss) { + "lost" + } else { + "not lost" + } + test(s"shuffle files $maybeLost when $eventDescription") { + // reset the test context with the right shuffle service config + afterEach() + val conf = new SparkConf() + conf.set("spark.shuffle.service.enabled", shuffleServiceOn.toString) + init(conf) + assert(sc.env.blockManager.externalShuffleServiceEnabled == shuffleServiceOn) + + val shuffleMapRdd = new MyRDD(sc, 2, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(1)) + val shuffleId = shuffleDep.shuffleId + val reduceRdd = new MyRDD(sc, 1, List(shuffleDep), tracker = mapOutputTracker) + submit(reduceRdd, Array(0)) + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostA", 1)), + (Success, makeMapStatus("hostB", 1)))) + runEvent(ExecutorLost("exec-hostA", event)) + if (expectFileLoss) { + intercept[MetadataFetchFailedException] { + mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0) + } + } else { + assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === + HashSet(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) + } + } + } // Helper function to validate state when creating tests for task failures private def checkStageId(stageId: Int, attempt: Int, stageAttempt: TaskSet) { @@ -561,7 +658,6 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou assert(stageAttempt.stageAttemptId == attempt) } - // Helper functions to extract commonly used code in Fetch Failure test cases private def setupStageAbortTest(sc: SparkContext) { sc.listenerBus.addListener(new EndListener()) @@ -1043,7 +1139,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou // pretend we were told hostA went away val oldEpoch = mapOutputTracker.getEpoch - runEvent(ExecutorLost("exec-hostA")) + runEvent(ExecutorLost("exec-hostA", ExecutorKilled)) val newEpoch = mapOutputTracker.getEpoch assert(newEpoch > oldEpoch) @@ -1174,7 +1270,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou )) // then one executor dies, and a task fails in stage 1 - runEvent(ExecutorLost("exec-hostA")) + runEvent(ExecutorLost("exec-hostA", ExecutorKilled)) runEvent(makeCompletionEvent( taskSets(1).tasks(0), FetchFailed(null, firstShuffleId, 2, 0, "Fetch failed"), @@ -1272,7 +1368,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou makeMapStatus("hostA", reduceRdd.partitions.length))) // now that host goes down - runEvent(ExecutorLost("exec-hostA")) + runEvent(ExecutorLost("exec-hostA", ExecutorKilled)) // so we resubmit those tasks runEvent(makeCompletionEvent(taskSets(0).tasks(0), Resubmitted, null)) @@ -1465,7 +1561,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou submit(reduceRdd, Array(0)) // blockManagerMaster.removeExecutor("exec-hostA") // pretend we were told hostA went away - runEvent(ExecutorLost("exec-hostA")) + runEvent(ExecutorLost("exec-hostA", ExecutorKilled)) // DAGScheduler will immediately resubmit the stage after it appears to have no pending tasks // rather than marking it is as failed and waiting. 
complete(taskSets(0), Seq( @@ -1542,13 +1638,11 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou } test("misbehaved accumulator should not crash DAGScheduler and SparkContext") { - val acc = new Accumulator[Int](0, new AccumulatorParam[Int] { - override def addAccumulator(t1: Int, t2: Int): Int = t1 + t2 - override def zero(initialValue: Int): Int = 0 - override def addInPlace(r1: Int, r2: Int): Int = { - throw new DAGSchedulerSuiteDummyException - } - }) + val acc = new LongAccumulator { + override def add(v: java.lang.Long): Unit = throw new DAGSchedulerSuiteDummyException + override def add(v: Long): Unit = throw new DAGSchedulerSuiteDummyException + } + sc.register(acc) // Run this on executors sc.parallelize(1 to 10, 2).foreach { item => acc.add(1) } @@ -1652,7 +1746,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou } test("reduce tasks should be placed locally with map output") { - // Create an shuffleMapRdd with 1 partition + // Create a shuffleMapRdd with 1 partition val shuffleMapRdd = new MyRDD(sc, 1, Nil) val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) val shuffleId = shuffleDep.shuffleId @@ -1673,7 +1767,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou test("reduce task locality preferences should only include machines with largest map outputs") { val numMapTasks = 4 - // Create an shuffleMapRdd with more partitions + // Create a shuffleMapRdd with more partitions val shuffleMapRdd = new MyRDD(sc, numMapTasks, Nil) val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(1)) val shuffleId = shuffleDep.shuffleId @@ -1934,7 +2028,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou // Pretend host A was lost val oldEpoch = mapOutputTracker.getEpoch - runEvent(ExecutorLost("exec-hostA")) + runEvent(ExecutorLost("exec-hostA", ExecutorKilled)) val newEpoch = mapOutputTracker.getEpoch assert(newEpoch > oldEpoch) @@ -1965,6 +2059,61 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou assertDataStructuresEmpty() } + test("SPARK-17644: After one stage is aborted for too many failed attempts, subsequent stages" + + "still behave correctly on fetch failures") { + // Runs a job that always encounters a fetch failure, so should eventually be aborted + def runJobWithPersistentFetchFailure: Unit = { + val rdd1 = sc.makeRDD(Array(1, 2, 3, 4), 2).map(x => (x, 1)).groupByKey() + val shuffleHandle = + rdd1.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleHandle + rdd1.map { + case (x, _) if (x == 1) => + throw new FetchFailedException( + BlockManagerId("1", "1", 1), shuffleHandle.shuffleId, 0, 0, "test") + case (x, _) => x + }.count() + } + + // Runs a job that encounters a single fetch failure but succeeds on the second attempt + def runJobWithTemporaryFetchFailure: Unit = { + object FailThisAttempt { + val _fail = new AtomicBoolean(true) + } + val rdd1 = sc.makeRDD(Array(1, 2, 3, 4), 2).map(x => (x, 1)).groupByKey() + val shuffleHandle = + rdd1.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleHandle + rdd1.map { + case (x, _) if (x == 1) && FailThisAttempt._fail.getAndSet(false) => + throw new FetchFailedException( + BlockManagerId("1", "1", 1), shuffleHandle.shuffleId, 0, 0, "test") + } + } + + failAfter(10.seconds) { + val e = intercept[SparkException] { + runJobWithPersistentFetchFailure + } + 
assert(e.getMessage.contains("org.apache.spark.shuffle.FetchFailedException")) + } + + // Run a second job that will fail due to a fetch failure. + // This job will hang without the fix for SPARK-17644. + failAfter(10.seconds) { + val e = intercept[SparkException] { + runJobWithPersistentFetchFailure + } + assert(e.getMessage.contains("org.apache.spark.shuffle.FetchFailedException")) + } + + failAfter(10.seconds) { + try { + runJobWithTemporaryFetchFailure + } catch { + case e: Throwable => fail("A job with one fetch failure should eventually succeed") + } + } + } + /** * Assert that the supplied TaskSet has exactly the given hosts as its preferred locations. * Note that this checks only the host and not the executor ID. @@ -2012,7 +2161,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou task: Task[_], reason: TaskEndReason, result: Any, - extraAccumUpdates: Seq[NewAccumulator[_, _]] = Seq.empty, + extraAccumUpdates: Seq[AccumulatorV2[_, _]] = Seq.empty, taskInfo: TaskInfo = createFakeTaskInfo()): CompletionEvent = { val accumUpdates = reason match { case Success => task.metrics.accumulators() diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index 176d8930aad19..7f4859206e257 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -142,14 +142,14 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit extraConf.foreach { case (k, v) => conf.set(k, v) } val logName = compressionCodec.map("test-" + _).getOrElse("test") val eventLogger = new EventLoggingListener(logName, None, testDirPath.toUri(), conf) - val listenerBus = new LiveListenerBus + val listenerBus = new LiveListenerBus(sc) val applicationStart = SparkListenerApplicationStart("Greatest App (N)ever", None, 125L, "Mickey", None) val applicationEnd = SparkListenerApplicationEnd(1000L) // A comprehensive test on JSON de/serialization of all events is in JsonProtocolSuite eventLogger.start() - listenerBus.start(sc) + listenerBus.start() listenerBus.addListener(eventLogger) listenerBus.postToAll(applicationStart) listenerBus.postToAll(applicationEnd) @@ -181,7 +181,7 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit // into SPARK-6688. 
val conf = getLoggingConf(testDirPath, compressionCodec) .set("spark.hadoop.fs.defaultFS", "unsupported://example.com") - val sc = new SparkContext("local-cluster[2,2,1024]", "test", conf) + sc = new SparkContext("local-cluster[2,2,1024]", "test", conf) assert(sc.eventLogger.isDefined) val eventLogger = sc.eventLogger.get val eventLogPath = eventLogger.logPath diff --git a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala index 16027d944fdfd..59c1b359a780a 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala @@ -17,9 +17,10 @@ package org.apache.spark.scheduler -import org.apache.spark.{LocalSparkContext, NewAccumulator, SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.AccumulatorV2 class ExternalClusterManagerSuite extends SparkFunSuite with LocalSparkContext { test("launch of backend and scheduler") { @@ -67,6 +68,6 @@ private class DummyTaskScheduler extends TaskScheduler { override def applicationAttemptId(): Option[String] = None def executorHeartbeatReceived( execId: String, - accumUpdates: Array[(Long, Seq[NewAccumulator[_, _]])], + accumUpdates: Array[(Long, Seq[AccumulatorV2[_, _]])], blockManagerId: BlockManagerId): Boolean = true } diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala index 601f1c378c41f..32cdf16dd3318 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala @@ -37,7 +37,6 @@ class OutputCommitCoordinatorIntegrationSuite override def beforeAll(): Unit = { super.beforeAll() val conf = new SparkConf() - .set("master", "local[2,4]") .set("spark.hadoop.outputCommitCoordination.enabled", "true") .set("spark.hadoop.mapred.output.committer.class", classOf[ThrowExceptionOnFirstAttemptOutputCommitter].getCanonicalName) diff --git a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala index 35215c15ea805..1732aca9417ea 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala @@ -23,7 +23,7 @@ import java.net.URI import org.json4s.jackson.JsonMethods._ import org.scalatest.BeforeAndAfter -import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.io.CompressionCodec import org.apache.spark.util.{JsonProtocol, JsonProtocolSuite, Utils} @@ -31,7 +31,7 @@ import org.apache.spark.util.{JsonProtocol, JsonProtocolSuite, Utils} /** * Test whether ReplayListenerBus replays events from logs correctly. 
*/ -class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter { +class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext { private val fileSystem = Utils.getHadoopFileSystem("/", SparkHadoopUtil.get.newConfiguration(new SparkConf())) private var testDir: File = _ @@ -101,7 +101,7 @@ class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter { fileSystem.mkdirs(logDirPath) val conf = EventLoggingListenerSuite.getLoggingConf(logDirPath, codecName) - val sc = new SparkContext("local-cluster[2,1,1024]", "Test replay", conf) + sc = new SparkContext("local-cluster[2,1,1024]", "Test replay", conf) // Run a few jobs sc.parallelize(1 to 100, 1).count() diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 5ba67afc0cd62..e8a88d4909a83 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -37,13 +37,13 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match val jobCompletionTime = 1421191296660L test("don't call sc.stop in listener") { - sc = new SparkContext("local", "SparkListenerSuite") + sc = new SparkContext("local", "SparkListenerSuite", new SparkConf()) val listener = new SparkContextStoppingListener(sc) - val bus = new LiveListenerBus + val bus = new LiveListenerBus(sc) bus.addListener(listener) // Starting listener bus should flush all buffered events - bus.start(sc) + bus.start() bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) @@ -52,8 +52,9 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match } test("basic creation and shutdown of LiveListenerBus") { + sc = new SparkContext("local", "SparkListenerSuite", new SparkConf()) val counter = new BasicJobCounter - val bus = new LiveListenerBus + val bus = new LiveListenerBus(sc) bus.addListener(counter) // Listener bus hasn't started yet, so posting events should not increment counter @@ -61,7 +62,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match assert(counter.count === 0) // Starting listener bus should flush all buffered events - bus.start(sc) + bus.start() bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) assert(counter.count === 5) @@ -72,14 +73,14 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match // Listener bus must not be started twice intercept[IllegalStateException] { - val bus = new LiveListenerBus - bus.start(sc) - bus.start(sc) + val bus = new LiveListenerBus(sc) + bus.start() + bus.start() } // ... 
or stopped before starting intercept[IllegalStateException] { - val bus = new LiveListenerBus + val bus = new LiveListenerBus(sc) bus.stop() } } @@ -106,12 +107,12 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match drained = true } } - - val bus = new LiveListenerBus + sc = new SparkContext("local", "SparkListenerSuite", new SparkConf()) + val bus = new LiveListenerBus(sc) val blockingListener = new BlockingListener bus.addListener(blockingListener) - bus.start(sc) + bus.start() bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) listenerStarted.acquire() @@ -353,13 +354,14 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match val badListener = new BadListener val jobCounter1 = new BasicJobCounter val jobCounter2 = new BasicJobCounter - val bus = new LiveListenerBus + sc = new SparkContext("local", "SparkListenerSuite", new SparkConf()) + val bus = new LiveListenerBus(sc) // Propagate events to bad listener first bus.addListener(badListener) bus.addListener(jobCounter1) bus.addListener(jobCounter2) - bus.start(sc) + bus.start() // Post events to all listeners, and wait until the queue is drained (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index 9aca4dbc23644..9eda79ace18d0 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -146,14 +146,13 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark test("accumulators are updated on exception failures") { // This means use 1 core and 4 max task failures sc = new SparkContext("local[1,4]", "test") - val param = AccumulatorParam.LongAccumulatorParam // Create 2 accumulators, one that counts failed values and another that doesn't - val acc1 = new Accumulator(0L, param, Some("x"), countFailedValues = true) - val acc2 = new Accumulator(0L, param, Some("y"), countFailedValues = false) + val acc1 = AccumulatorSuite.createLongAccum("x", true) + val acc2 = AccumulatorSuite.createLongAccum("y", false) // Fail first 3 attempts of every task. This means each task should be run 4 times. sc.parallelize(1 to 10, 10).map { i => - acc1 += 1 - acc2 += 1 + acc1.add(1) + acc2.add(1) if (TaskContext.get.attemptNumber() <= 2) { throw new Exception("you did something wrong") } else { @@ -168,8 +167,10 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark test("failed tasks collect only accumulators whose values count during failures") { sc = new SparkContext("local", "test") - val acc1 = AccumulatorSuite.createLongAccum("x", true) - val acc2 = AccumulatorSuite.createLongAccum("y", false) + val acc1 = AccumulatorSuite.createLongAccum("x", false) + val acc2 = AccumulatorSuite.createLongAccum("y", true) + acc1.add(1) + acc2.add(1) // Create a dummy task. We won't end up running this; we just want to collect // accumulator updates from it. val taskMetrics = TaskMetrics.empty @@ -185,12 +186,33 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark } // First, simulate task success. This should give us all the accumulators. 
val accumUpdates1 = task.collectAccumulatorUpdates(taskFailed = false) - val accumUpdates2 = taskMetrics.internalAccums ++ Seq(acc1, acc2) - TaskMetricsSuite.assertUpdatesEquals(accumUpdates1, accumUpdates2) + TaskMetricsSuite.assertUpdatesEquals(accumUpdates1.takeRight(2), Seq(acc1, acc2)) // Now, simulate task failures. This should give us only the accums that count failed values. - val accumUpdates3 = task.collectAccumulatorUpdates(taskFailed = true) - val accumUpdates4 = taskMetrics.internalAccums ++ Seq(acc1) - TaskMetricsSuite.assertUpdatesEquals(accumUpdates3, accumUpdates4) + val accumUpdates2 = task.collectAccumulatorUpdates(taskFailed = true) + TaskMetricsSuite.assertUpdatesEquals(accumUpdates2.takeRight(1), Seq(acc2)) + } + + test("only updated internal accumulators will be sent back to driver") { + sc = new SparkContext("local", "test") + // Create a dummy task. We won't end up running this; we just want to collect + // accumulator updates from it. + val taskMetrics = TaskMetrics.empty + val task = new Task[Int](0, 0, 0) { + context = new TaskContextImpl(0, 0, 0L, 0, + new TaskMemoryManager(SparkEnv.get.memoryManager, 0L), + new Properties, + SparkEnv.get.metricsSystem, + taskMetrics) + taskMetrics.incMemoryBytesSpilled(10) + override def runTask(tc: TaskContext): Int = 0 + } + val updatedAccums = task.collectAccumulatorUpdates() + assert(updatedAccums.length == 2) + // the RESULT_SIZE accumulator will be sent back anyway. + assert(updatedAccums(0).name == Some(InternalAccumulator.RESULT_SIZE)) + assert(updatedAccums(0).value == 0) + assert(updatedAccums(1).name == Some(InternalAccumulator.MEMORY_BYTES_SPILLED)) + assert(updatedAccums(1).value == 10) } test("localProperties are propagated to executors correctly") { diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 339fc4254d53a..b98f945bac253 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -24,7 +24,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark._ import org.apache.spark.internal.Logging -import org.apache.spark.util.ManualClock +import org.apache.spark.util.{AccumulatorV2, ManualClock} class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler) extends DAGScheduler(sc) { @@ -37,14 +37,14 @@ class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler) task: Task[_], reason: TaskEndReason, result: Any, - accumUpdates: Seq[NewAccumulator[_, _]], + accumUpdates: Seq[AccumulatorV2[_, _]], taskInfo: TaskInfo) { taskScheduler.endedTasks(taskInfo.index) = reason } override def executorAdded(execId: String, host: String) {} - override def executorLost(execId: String) {} + override def executorLost(execId: String, reason: ExecutorLossReason) {} override def taskSetFailed( taskSet: TaskSet, @@ -184,7 +184,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg val sched = new FakeTaskScheduler(sc, ("exec1", "host1")) val taskSet = FakeTask.createTaskSet(3) val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES) - val accumUpdatesByTask: Array[Seq[NewAccumulator[_, _]]] = taskSet.tasks.map { task => + val accumUpdatesByTask: Array[Seq[AccumulatorV2[_, _]]] = taskSet.tasks.map { task => task.metrics.internalAccums } @@ -787,11 +787,13 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with 
Logg assert(TaskLocation("host1") === HostTaskLocation("host1")) assert(TaskLocation("hdfs_cache_host1") === HDFSCacheTaskLocation("host1")) assert(TaskLocation("executor_host1_3") === ExecutorCacheTaskLocation("host1", "3")) + assert(TaskLocation("executor_some.host1_executor_task_3") === + ExecutorCacheTaskLocation("some.host1", "executor_task_3")) } private def createTaskResult( id: Int, - accumUpdates: Seq[NewAccumulator[_, _]] = Seq.empty): DirectTaskResult[Int] = { + accumUpdates: Seq[AccumulatorV2[_, _]] = Seq.empty): DirectTaskResult[Int] = { val valueSer = SparkEnv.get.serializer.newInstance() new DirectTaskResult[Int](valueSer.serialize(id), accumUpdates) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala similarity index 82% rename from core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala rename to core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala index b18f0eb162b1d..3ffbe70e76bbb 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala @@ -18,9 +18,13 @@ package org.apache.spark.scheduler.cluster.mesos import java.util.Collections +import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer +import scala.concurrent.duration._ +import scala.concurrent.Promise +import scala.reflect.ClassTag import org.apache.mesos.{Protos, Scheduler, SchedulerDriver} import org.apache.mesos.Protos._ @@ -28,25 +32,33 @@ import org.apache.mesos.Protos.Value.Scalar import org.mockito.{ArgumentCaptor, Matchers} import org.mockito.Matchers._ import org.mockito.Mockito._ +import org.scalatest.concurrent.ScalaFutures import org.scalatest.mock.MockitoSugar import org.scalatest.BeforeAndAfter import org.apache.spark.{LocalSparkContext, SecurityManager, SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.network.shuffle.mesos.MesosExternalShuffleClient import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RemoveExecutor import org.apache.spark.scheduler.TaskSchedulerImpl -class CoarseMesosSchedulerBackendSuite extends SparkFunSuite +class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar - with BeforeAndAfter { + with BeforeAndAfter + with ScalaFutures { private var sparkConf: SparkConf = _ private var driver: SchedulerDriver = _ private var taskScheduler: TaskSchedulerImpl = _ - private var backend: CoarseMesosSchedulerBackend = _ + private var backend: MesosCoarseGrainedSchedulerBackend = _ private var externalShuffleClient: MesosExternalShuffleClient = _ private var driverEndpoint: RpcEndpointRef = _ + @volatile private var stopCalled = false + + // All 'requests' to the scheduler run immediately on the same thread, so + // demand that all futures have their value available immediately. 
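+ // (Concretely: with this zero-timeout PatienceConfig, an assertion such as
+ // backend.doKillExecutors(Seq("0")).futureValue fails fast unless the returned
+ // Future has already completed by the time it is inspected.)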
+ implicit override val patienceConfig = PatienceConfig(timeout = Duration(0, TimeUnit.SECONDS)) test("mesos supports killing and limiting executors") { setBackend() @@ -62,8 +74,8 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite verifyTaskLaunched("o1") // kills executors - backend.doRequestTotalExecutors(0) - assert(backend.doKillExecutors(Seq("0"))) + assert(backend.doRequestTotalExecutors(0).futureValue) + assert(backend.doKillExecutors(Seq("0")).futureValue) val taskID0 = createTaskId("0") verify(driver, times(1)).killTask(taskID0) @@ -147,6 +159,19 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite verifyDeclinedOffer(driver, createOfferId("o1"), true) } + test("mesos declines offers with a filter when reached spark.cores.max") { + val maxCores = 3 + setBackend(Map("spark.cores.max" -> maxCores.toString)) + + val executorMemory = backend.executorMemory(sc) + offerResources(List( + (executorMemory, maxCores + 1), + (executorMemory, maxCores + 1))) + + verifyTaskLaunched("o1") + verifyDeclinedOffer(driver, createOfferId("o2"), true) + } + test("mesos assigns tasks round-robin on offers") { val executorCores = 4 val maxCores = executorCores * 2 @@ -217,7 +242,8 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite when(driver.start()).thenReturn(Protos.Status.DRIVER_RUNNING) val securityManager = mock[SecurityManager] - val backend = new CoarseMesosSchedulerBackend(taskScheduler, sc, "master", securityManager) { + val backend = new MesosCoarseGrainedSchedulerBackend( + taskScheduler, sc, "master", securityManager) { override protected def createSchedulerDriver( masterUrl: String, scheduler: Scheduler, @@ -238,6 +264,32 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite backend.start() } + test("Do not call removeExecutor() after backend is stopped") { + setBackend() + + // launches a task on a valid offer + val offers = List((backend.executorMemory(sc), 1)) + offerResources(offers) + verifyTaskLaunched("o1") + + // launches a thread simulating status update + val statusUpdateThread = new Thread { + override def run(): Unit = { + while (!stopCalled) { + Thread.sleep(100) + } + + val status = createTaskStatus("0", "s1", TaskState.TASK_FINISHED) + backend.statusUpdate(driver, status) + } + }.start + + backend.stop() + // Any method of the backend involving sending messages to the driver endpoint should not + // be called after the backend is stopped. 
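+ // (Mockito's never() is shorthand for times(0): the verification below fails
+ // if askWithRetry was ever invoked with a RemoveExecutor message.)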
+ verify(driverEndpoint, never()).askWithRetry(isA(classOf[RemoveExecutor]))(any[ClassTag[_]]) + } + private def verifyDeclinedOffer(driver: SchedulerDriver, offerId: OfferID, filter: Boolean = false): Unit = { @@ -310,10 +362,11 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite taskScheduler: TaskSchedulerImpl, driver: SchedulerDriver, shuffleClient: MesosExternalShuffleClient, - endpoint: RpcEndpointRef): CoarseMesosSchedulerBackend = { + endpoint: RpcEndpointRef): MesosCoarseGrainedSchedulerBackend = { val securityManager = mock[SecurityManager] - val backend = new CoarseMesosSchedulerBackend(taskScheduler, sc, "master", securityManager) { + val backend = new MesosCoarseGrainedSchedulerBackend( + taskScheduler, sc, "master", securityManager) { override protected def createSchedulerDriver( masterUrl: String, scheduler: Scheduler, @@ -335,6 +388,10 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite mesosDriver = newDriver } + override def stopExecutors(): Unit = { + stopCalled = true + } + markRegistered() } backend.start() @@ -362,6 +419,7 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite when(taskScheduler.sc).thenReturn(sc) externalShuffleClient = mock[MesosExternalShuffleClient] driverEndpoint = mock[RpcEndpointRef] + when(driverEndpoint.ask(any())(any())).thenReturn(Promise().future) backend = createSchedulerBackend(taskScheduler, driver, externalShuffleClient, driverEndpoint) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala similarity index 95% rename from core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala rename to core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala index 7d6b7bde68253..41693b1191a3c 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala @@ -40,7 +40,8 @@ import org.apache.spark.scheduler.{LiveListenerBus, SparkListenerExecutorAdded, TaskDescription, TaskSchedulerImpl, WorkerOffer} import org.apache.spark.scheduler.cluster.ExecutorInfo -class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar { +class MesosFineGrainedSchedulerBackendSuite + extends SparkFunSuite with LocalSparkContext with MockitoSugar { test("weburi is set in created scheduler driver") { val conf = new SparkConf @@ -56,7 +57,7 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi val driver = mock[SchedulerDriver] when(driver.start()).thenReturn(Protos.Status.DRIVER_RUNNING) - val backend = new MesosSchedulerBackend(taskScheduler, sc, "master") { + val backend = new MesosFineGrainedSchedulerBackend(taskScheduler, sc, "master") { override protected def createSchedulerDriver( masterUrl: String, scheduler: Scheduler, @@ -96,7 +97,7 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi val taskScheduler = mock[TaskSchedulerImpl] when(taskScheduler.CPUS_PER_TASK).thenReturn(2) - val mesosSchedulerBackend = new MesosSchedulerBackend(taskScheduler, sc, "master") + val mesosSchedulerBackend = new MesosFineGrainedSchedulerBackend(taskScheduler, sc, "master") val resources = Arrays.asList( mesosSchedulerBackend.createResource("cpus", 4), @@ -127,7 
+128,7 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi val taskScheduler = mock[TaskSchedulerImpl] when(taskScheduler.CPUS_PER_TASK).thenReturn(2) - val mesosSchedulerBackend = new MesosSchedulerBackend(taskScheduler, sc, "master") + val mesosSchedulerBackend = new MesosFineGrainedSchedulerBackend(taskScheduler, sc, "master") val resources = Arrays.asList( mesosSchedulerBackend.createResource("cpus", 4), @@ -163,7 +164,7 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi when(sc.conf).thenReturn(conf) when(sc.listenerBus).thenReturn(listenerBus) - val backend = new MesosSchedulerBackend(taskScheduler, sc, "master") + val backend = new MesosFineGrainedSchedulerBackend(taskScheduler, sc, "master") val (execInfo, _) = backend.createExecutorInfo( Arrays.asList(backend.createResource("cpus", 4)), "mockExecutor") @@ -222,7 +223,7 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi when(sc.conf).thenReturn(new SparkConf) when(sc.listenerBus).thenReturn(listenerBus) - val backend = new MesosSchedulerBackend(taskScheduler, sc, "master") + val backend = new MesosFineGrainedSchedulerBackend(taskScheduler, sc, "master") val minMem = backend.executorMemory(sc) val minCpu = 4 @@ -333,7 +334,7 @@ class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext wi val mesosOffers = new java.util.ArrayList[Offer] mesosOffers.add(offer) - val backend = new MesosSchedulerBackend(taskScheduler, sc, "master") + val backend = new MesosFineGrainedSchedulerBackend(taskScheduler, sc, "master") val expectedWorkerOffers = new ArrayBuffer[WorkerOffer](1) expectedWorkerOffers.append(new WorkerOffer( diff --git a/core/src/test/scala/org/apache/spark/serializer/SerializationDebuggerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/SerializationDebuggerSuite.scala index f019b1e25900b..912a516dff0f4 100644 --- a/core/src/test/scala/org/apache/spark/serializer/SerializationDebuggerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/SerializationDebuggerSuite.scala @@ -126,7 +126,11 @@ class SerializationDebuggerSuite extends SparkFunSuite with BeforeAndAfterEach { assert(find(new SerializableClassWithWriteReplace(new SerializableClass1)).isEmpty) } - test("object containing writeObject() and not serializable field") { + test("no infinite loop with writeReplace() which returns class of its own type") { + assert(find(new SerializableClassWithRecursiveWriteReplace).isEmpty) + } + + test("object containing writeObject() and not serializable field") { val s = find(new SerializableClassWithWriteObject(new NotSerializable)) assert(s.size === 3) assert(s(0).contains("NotSerializable")) @@ -229,6 +233,13 @@ class SerializableClassWithWriteReplace(@(transient @param) replacementFieldObje } +class SerializableClassWithRecursiveWriteReplace extends Serializable { + private def writeReplace(): Object = { + new SerializableClassWithRecursiveWriteReplace + } +} + + class ExternalizableClass(objectField: Object) extends java.io.Externalizable { val serializableObjectField = new SerializableClass1 diff --git a/core/src/test/scala/org/apache/spark/status/api/v1/AllStagesResourceSuite.scala b/core/src/test/scala/org/apache/spark/status/api/v1/AllStagesResourceSuite.scala index d223af1496a4b..1bfb0c1547ec4 100644 --- a/core/src/test/scala/org/apache/spark/status/api/v1/AllStagesResourceSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/api/v1/AllStagesResourceSuite.scala @@ -19,7 +19,7 @@ 
package org.apache.spark.status.api.v1 import java.util.Date -import scala.collection.mutable.HashMap +import scala.collection.mutable.LinkedHashMap import org.apache.spark.SparkFunSuite import org.apache.spark.scheduler.{StageInfo, TaskInfo, TaskLocality} @@ -28,10 +28,10 @@ import org.apache.spark.ui.jobs.UIData.{StageUIData, TaskUIData} class AllStagesResourceSuite extends SparkFunSuite { def getFirstTaskLaunchTime(taskLaunchTimes: Seq[Long]): Option[Date] = { - val tasks = new HashMap[Long, TaskUIData] + val tasks = new LinkedHashMap[Long, TaskUIData] taskLaunchTimes.zipWithIndex.foreach { case (time, idx) => - tasks(idx.toLong) = new TaskUIData( - new TaskInfo(idx, idx, 1, time, "", "", TaskLocality.ANY, false), None, None) + tasks(idx.toLong) = TaskUIData( + new TaskInfo(idx, idx, 1, time, "", "", TaskLocality.ANY, false), None) } val stageUiData = new StageUIData() diff --git a/core/src/test/scala/org/apache/spark/status/api/v1/SimpleDateParamSuite.scala b/core/src/test/scala/org/apache/spark/status/api/v1/SimpleDateParamSuite.scala index 63b0e77629dde..18baeb1cb9c71 100644 --- a/core/src/test/scala/org/apache/spark/status/api/v1/SimpleDateParamSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/api/v1/SimpleDateParamSuite.scala @@ -26,7 +26,8 @@ class SimpleDateParamSuite extends SparkFunSuite with Matchers { test("date parsing") { new SimpleDateParam("2015-02-20T23:21:17.190GMT").timestamp should be (1424474477190L) - new SimpleDateParam("2015-02-20T17:21:17.190EST").timestamp should be (1424470877190L) + // don't use EST, it is ambiguous, use -0500 instead, see SPARK-15723 + new SimpleDateParam("2015-02-20T17:21:17.190-0500").timestamp should be (1424470877190L) new SimpleDateParam("2015-02-20").timestamp should be (1424390400000L) // GMT intercept[WebApplicationException] { new SimpleDateParam("invalid date") diff --git a/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala index 9ee83b76e71dc..1b325801e27fc 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala @@ -208,16 +208,14 @@ class BlockInfoManagerSuite extends SparkFunSuite with BeforeAndAfterEach { } } - test("cannot call lockForWriting while already holding a write lock") { + test("cannot grab a writer lock while already holding a write lock") { withTaskId(0) { assert(blockInfoManager.lockNewBlockForWriting("block", newBlockInfo())) blockInfoManager.unlock("block") } withTaskId(1) { assert(blockInfoManager.lockForWriting("block").isDefined) - intercept[IllegalStateException] { - blockInfoManager.lockForWriting("block") - } + assert(blockInfoManager.lockForWriting("block", false).isEmpty) blockInfoManager.assertBlockIsLockedForWriting("block") } } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index d14728cb50555..b9e3a364ee221 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -27,6 +27,7 @@ import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ import org.apache.spark._ +import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.memory.UnifiedMemoryManager import 
org.apache.spark.network.BlockTransferService import org.apache.spark.network.netty.NettyBlockTransferService @@ -37,13 +38,17 @@ import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.storage.StorageLevel._ /** Testsuite that tests block replication in BlockManager */ -class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with BeforeAndAfter { +class BlockManagerReplicationSuite extends SparkFunSuite + with Matchers + with BeforeAndAfter + with LocalSparkContext { private val conf = new SparkConf(false).set("spark.app.id", "test") private var rpcEnv: RpcEnv = null private var master: BlockManagerMaster = null private val securityMgr = new SecurityManager(conf) - private val mapOutputTracker = new MapOutputTrackerMaster(conf) + private val bcastManager = new BroadcastManager(true, conf, securityMgr) + private val mapOutputTracker = new MapOutputTrackerMaster(conf, bcastManager, true) private val shuffleManager = new SortShuffleManager(conf) // List of block manager created during an unit test, so that all of the them can be stopped @@ -89,8 +94,10 @@ class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with Befo // to make cached peers refresh frequently conf.set("spark.storage.cachedPeersTtl", "10") + sc = new SparkContext("local", "test", conf) master = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager", - new BlockManagerMasterEndpoint(rpcEnv, true, conf, new LiveListenerBus)), conf, true) + new BlockManagerMasterEndpoint(rpcEnv, true, conf, + new LiveListenerBus(sc))), conf, true) allStores.clear() } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index db1efaf2a20b8..1b3197a6921bd 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -33,6 +33,7 @@ import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.Timeouts._ import org.apache.spark._ +import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.executor.DataReadMethod import org.apache.spark.memory.UnifiedMemoryManager import org.apache.spark.network.{BlockDataManager, BlockTransferService} @@ -48,7 +49,7 @@ import org.apache.spark.util._ import org.apache.spark.util.io.ChunkedByteBuffer class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach - with PrivateMethodTester with ResetSystemProperties { + with PrivateMethodTester with LocalSparkContext with ResetSystemProperties { import BlockManagerSuite._ @@ -59,7 +60,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE var rpcEnv: RpcEnv = null var master: BlockManagerMaster = null val securityMgr = new SecurityManager(new SparkConf(false)) - val mapOutputTracker = new MapOutputTrackerMaster(new SparkConf(false)) + val bcastManager = new BroadcastManager(true, new SparkConf(false), securityMgr) + val mapOutputTracker = new MapOutputTrackerMaster(new SparkConf(false), bcastManager, true) val shuffleManager = new SortShuffleManager(new SparkConf(false)) // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test @@ -105,8 +107,13 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE rpcEnv = RpcEnv.create("test", "localhost", 0, conf, securityMgr) conf.set("spark.driver.port", rpcEnv.address.port.toString) + // Mock SparkContext to reduce the memory usage 
of tests. It's fine since the only reason we + // need to create a SparkContext is to initialize LiveListenerBus. + sc = mock(classOf[SparkContext]) + when(sc.conf).thenReturn(conf) master = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager", - new BlockManagerMasterEndpoint(rpcEnv, true, conf, new LiveListenerBus)), conf, true) + new BlockManagerMasterEndpoint(rpcEnv, true, conf, + new LiveListenerBus(sc))), conf, true) val initialize = PrivateMethod[Unit]('initialize) SizeEstimator invokePrivate initialize() @@ -509,10 +516,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(store.getRemoteBytes("list1").isDefined, "list1Get expected to be fetched") store3.stop() store3 = null - // exception throw because there is no locations - intercept[BlockFetchException] { - store.getRemoteBytes("list1") - } + // Should return None instead of throwing an exception: + assert(store.getRemoteBytes("list1").isEmpty) } test("SPARK-14252: getOrElseUpdate should still read from remote storage") { @@ -859,6 +864,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE serializerManager, conf, memoryManager, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) memoryManager.setMemoryStore(store.memoryStore) + store.initialize("app-id") // The put should fail since a1 is not serializable. class UnserializableClass @@ -1137,14 +1143,52 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(store.getSingle("a3").isDefined, "a3 was not in store") } + private def testReadWithLossOfOnDiskFiles( + storageLevel: StorageLevel, + readMethod: BlockManager => Option[_]): Unit = { + store = makeBlockManager(12000) + assert(store.putSingle("blockId", new Array[Byte](4000), storageLevel)) + assert(store.getStatus("blockId").isDefined) + // Directly delete all files from the disk store, triggering failures when reading blocks: + store.diskBlockManager.getAllFiles().foreach(_.delete()) + // The BlockManager still thinks that these blocks exist: + assert(store.getStatus("blockId").isDefined) + // Because the BlockManager's metadata claims that the block exists (i.e. that it's present + // in at least one store), the read attempts to read it and fails when the on-disk file is + // missing. 
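+ // The contract under file loss is thus two-phase: the first read throws and
+ // purges the stale metadata entry, and later reads cleanly return None.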
+ intercept[SparkException] { + readMethod(store) + } + // Subsequent read attempts will succeed; the block isn't present but we return an expected + // "block not found" response rather than a fatal error: + assert(readMethod(store).isEmpty) + // The reason why this second read succeeded is because the metadata entry for the missing + // block was removed as a result of the read failure: + assert(store.getStatus("blockId").isEmpty) + } + + test("remove block if a read fails due to missing DiskStore files (SPARK-15736)") { + val storageLevels = Seq( + StorageLevel(useDisk = true, useMemory = false, deserialized = false), + StorageLevel(useDisk = true, useMemory = false, deserialized = true)) + val readMethods = Map[String, BlockManager => Option[_]]( + "getLocalBytes" -> ((m: BlockManager) => m.getLocalBytes("blockId")), + "getLocalValues" -> ((m: BlockManager) => m.getLocalValues("blockId")) + ) + testReadWithLossOfOnDiskFiles(StorageLevel.DISK_ONLY, _.getLocalBytes("blockId")) + for ((readMethodName, readMethod) <- readMethods; storageLevel <- storageLevels) { + withClue(s"$readMethodName $storageLevel") { + testReadWithLossOfOnDiskFiles(storageLevel, readMethod) + } + } + } + test("SPARK-13328: refresh block locations (fetch should fail after hitting a threshold)") { val mockBlockTransferService = new MockBlockTransferService(conf.getInt("spark.block.failures.beforeLocationRefresh", 5)) store = makeBlockManager(8000, "executor1", transferService = Option(mockBlockTransferService)) store.putSingle("item", 999L, StorageLevel.MEMORY_ONLY, tellMaster = true) - intercept[BlockFetchException] { - store.getRemoteBytes("item") - } + assert(store.getRemoteBytes("item").isEmpty) } test("SPARK-13328: refresh block locations (fetch should succeed after location refresh)") { @@ -1166,6 +1210,39 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE verify(mockBlockManagerMaster, times(2)).getLocations("item") } + test("SPARK-17484: block status is properly updated following an exception in put()") { + val mockBlockTransferService = new MockBlockTransferService(maxFailures = 10) { + override def uploadBlock( + hostname: String, + port: Int, execId: String, + blockId: BlockId, + blockData: ManagedBuffer, + level: StorageLevel, + classTag: ClassTag[_]): Future[Unit] = { + throw new InterruptedException("Intentional interrupt") + } + } + store = makeBlockManager(8000, "executor1", transferService = Option(mockBlockTransferService)) + store2 = makeBlockManager(8000, "executor2", transferService = Option(mockBlockTransferService)) + intercept[InterruptedException] { + store.putSingle("item", "value", StorageLevel.MEMORY_ONLY_2, tellMaster = true) + } + assert(store.getLocalBytes("item").isEmpty) + assert(master.getLocations("item").isEmpty) + assert(store2.getRemoteBytes("item").isEmpty) + } + + test("SPARK-17484: master block locations are updated following an invalid remote block fetch") { + store = makeBlockManager(8000, "executor1") + store2 = makeBlockManager(8000, "executor2") + store.putSingle("item", "value", StorageLevel.MEMORY_ONLY, tellMaster = true) + assert(master.getLocations("item").nonEmpty) + store.removeBlock("item", tellMaster = false) + assert(master.getLocations("item").nonEmpty) + assert(store2.getRemoteBytes("item").isEmpty) + assert(master.getLocations("item").isEmpty) + } + class MockBlockTransferService(val maxFailures: Int) extends BlockTransferService { var numCalls = 0 diff --git 
a/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala index 8eff3c297035d..ec4ef4b2fcbf0 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala @@ -53,13 +53,13 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { assert(writeMetrics.recordsWritten === 1) // Metrics don't update on every write assert(writeMetrics.bytesWritten == 0) - // After 32 writes, metrics should update - for (i <- 0 until 32) { + // After 16384 writes, metrics should update + for (i <- 0 until 16384) { writer.flush() writer.write(Long.box(i), Long.box(i)) } assert(writeMetrics.bytesWritten > 0) - assert(writeMetrics.recordsWritten === 33) + assert(writeMetrics.recordsWritten === 16385) writer.commitAndClose() assert(file.length() == writeMetrics.bytesWritten) } @@ -75,13 +75,13 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { assert(writeMetrics.recordsWritten === 1) // Metrics don't update on every write assert(writeMetrics.bytesWritten == 0) - // After 32 writes, metrics should update - for (i <- 0 until 32) { + // After 16384 writes, metrics should update + for (i <- 0 until 16384) { writer.flush() writer.write(Long.box(i), Long.box(i)) } assert(writeMetrics.bytesWritten > 0) - assert(writeMetrics.recordsWritten === 33) + assert(writeMetrics.recordsWritten === 16385) writer.revertPartialWritesAndClose() assert(writeMetrics.bytesWritten == 0) assert(writeMetrics.recordsWritten == 0) diff --git a/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala b/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala index 145d432afe85e..9e10ee5601480 100644 --- a/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala @@ -80,6 +80,13 @@ class MemoryStoreSuite (memoryStore, blockInfoManager) } + private def assertSameContents[T](expected: Seq[T], actual: Seq[T], hint: String): Unit = { + assert(actual.length === expected.length, s"wrong number of values returned in $hint") + expected.iterator.zip(actual.iterator).foreach { case (e, a) => + assert(e === a, s"$hint did not return original values!") + } + } + test("reserve/release unroll memory") { val (memoryStore, _) = makeMemoryStore(12000) assert(memoryStore.currentUnrollMemory === 0) @@ -138,9 +145,7 @@ class MemoryStoreSuite var putResult = putIteratorAsValues("unroll", smallList.iterator, ClassTag.Any) assert(putResult.isRight) assert(memoryStore.currentUnrollMemoryForThisTask === 0) - smallList.iterator.zip(memoryStore.getValues("unroll").get).foreach { case (e, a) => - assert(e === a, "getValues() did not return original values!") - } + assertSameContents(smallList, memoryStore.getValues("unroll").get.toSeq, "getValues") blockInfoManager.lockForWriting("unroll") assert(memoryStore.remove("unroll")) blockInfoManager.removeBlock("unroll") @@ -153,9 +158,7 @@ class MemoryStoreSuite assert(memoryStore.currentUnrollMemoryForThisTask === 0) assert(memoryStore.contains("someBlock2")) assert(!memoryStore.contains("someBlock1")) - smallList.iterator.zip(memoryStore.getValues("unroll").get).foreach { case (e, a) => - assert(e === a, "getValues() did not return original values!") - } + assertSameContents(smallList, memoryStore.getValues("unroll").get.toSeq, "getValues") 
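+ // (assertSameContents checks the length before comparing element-wise; the old
+ // zip-based idiom stopped at the shorter iterator, so missing trailing values
+ // could go undetected.)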
blockInfoManager.lockForWriting("unroll") assert(memoryStore.remove("unroll")) blockInfoManager.removeBlock("unroll") @@ -168,9 +171,7 @@ class MemoryStoreSuite assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator assert(!memoryStore.contains("someBlock2")) assert(putResult.isLeft) - bigList.iterator.zip(putResult.left.get).foreach { case (e, a) => - assert(e === a, "putIterator() did not return original values!") - } + assertSameContents(bigList, putResult.left.get.toSeq, "putIterator") // The unroll memory was freed once the iterator returned by putIterator() was fully traversed. assert(memoryStore.currentUnrollMemoryForThisTask === 0) } @@ -317,9 +318,8 @@ class MemoryStoreSuite assert(res.isLeft) assert(memoryStore.currentUnrollMemoryForThisTask > 0) val valuesReturnedFromFailedPut = res.left.get.valuesIterator.toSeq // force materialization - valuesReturnedFromFailedPut.zip(bigList).foreach { case (e, a) => - assert(e === a, "PartiallySerializedBlock.valuesIterator() did not return original values!") - } + assertSameContents( + bigList, valuesReturnedFromFailedPut, "PartiallySerializedBlock.valuesIterator()") // The unroll memory was freed once the iterator was fully traversed. assert(memoryStore.currentUnrollMemoryForThisTask === 0) } @@ -341,12 +341,10 @@ class MemoryStoreSuite res.left.get.finishWritingToStream(bos) // The unroll memory was freed once the block was fully written. assert(memoryStore.currentUnrollMemoryForThisTask === 0) - val deserializationStream = serializerManager.dataDeserializeStream[Any]( - "b1", new ByteBufferInputStream(bos.toByteBuffer))(ClassTag.Any) - deserializationStream.zip(bigList.iterator).foreach { case (e, a) => - assert(e === a, - "PartiallySerializedBlock.finishWritingtoStream() did not write original values!") - } + val deserializedValues = serializerManager.dataDeserializeStream[Any]( + "b1", new ByteBufferInputStream(bos.toByteBuffer))(ClassTag.Any).toSeq + assertSameContents( + bigList, deserializedValues, "PartiallySerializedBlock.finishWritingToStream()") } test("multiple unrolls by the same thread") { diff --git a/core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala b/core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala new file mode 100644 index 0000000000000..ec4f2637fadd0 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala @@ -0,0 +1,215 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.storage + +import java.nio.ByteBuffer + +import scala.reflect.ClassTag + +import org.mockito.Mockito +import org.mockito.Mockito.atLeastOnce +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer +import org.scalatest.{BeforeAndAfterEach, PrivateMethodTester} + +import org.apache.spark.{SparkConf, SparkFunSuite, TaskContext, TaskContextImpl} +import org.apache.spark.memory.MemoryMode +import org.apache.spark.serializer.{JavaSerializer, SerializationStream, SerializerManager} +import org.apache.spark.storage.memory.{MemoryStore, PartiallySerializedBlock, RedirectableOutputStream} +import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream} +import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream} + +class PartiallySerializedBlockSuite + extends SparkFunSuite + with BeforeAndAfterEach + with PrivateMethodTester { + + private val blockId = new TestBlockId("test") + private val conf = new SparkConf() + private val memoryStore = Mockito.mock(classOf[MemoryStore], Mockito.RETURNS_SMART_NULLS) + private val serializerManager = new SerializerManager(new JavaSerializer(conf), conf) + + private val getSerializationStream = PrivateMethod[SerializationStream]('serializationStream) + private val getRedirectableOutputStream = + PrivateMethod[RedirectableOutputStream]('redirectableOutputStream) + + override protected def beforeEach(): Unit = { + super.beforeEach() + Mockito.reset(memoryStore) + } + + private def partiallyUnroll[T: ClassTag]( + iter: Iterator[T], + numItemsToBuffer: Int): PartiallySerializedBlock[T] = { + + val bbos: ChunkedByteBufferOutputStream = { + val spy = Mockito.spy(new ChunkedByteBufferOutputStream(128, ByteBuffer.allocate)) + Mockito.doAnswer(new Answer[ChunkedByteBuffer] { + override def answer(invocationOnMock: InvocationOnMock): ChunkedByteBuffer = { + Mockito.spy(invocationOnMock.callRealMethod().asInstanceOf[ChunkedByteBuffer]) + } + }).when(spy).toChunkedByteBuffer + spy + } + + val serializer = serializerManager.getSerializer(implicitly[ClassTag[T]]).newInstance() + val redirectableOutputStream = Mockito.spy(new RedirectableOutputStream) + redirectableOutputStream.setOutputStream(bbos) + val serializationStream = Mockito.spy(serializer.serializeStream(redirectableOutputStream)) + + (1 to numItemsToBuffer).foreach { _ => + assert(iter.hasNext) + serializationStream.writeObject[T](iter.next()) + } + + val unrollMemory = bbos.size + new PartiallySerializedBlock[T]( + memoryStore, + serializerManager, + blockId, + serializationStream = serializationStream, + redirectableOutputStream, + unrollMemory = unrollMemory, + memoryMode = MemoryMode.ON_HEAP, + bbos, + rest = iter, + classTag = implicitly[ClassTag[T]]) + } + + test("valuesIterator() and finishWritingToStream() cannot be called after discard() is called") { + val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2) + partiallySerializedBlock.discard() + intercept[IllegalStateException] { + partiallySerializedBlock.finishWritingToStream(null) + } + intercept[IllegalStateException] { + partiallySerializedBlock.valuesIterator + } + } + + test("discard() can be called more than once") { + val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2) + partiallySerializedBlock.discard() + partiallySerializedBlock.discard() + } + + test("cannot call valuesIterator() more than once") { + val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2) + partiallySerializedBlock.valuesIterator + 
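// (A second valuesIterator call below must fail: the buffered prefix is handed
+ // off to the first iterator, so it cannot be replayed.)
+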
intercept[IllegalStateException] { + partiallySerializedBlock.valuesIterator + } + } + + test("cannot call finishWritingToStream() more than once") { + val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2) + partiallySerializedBlock.finishWritingToStream(new ByteBufferOutputStream()) + intercept[IllegalStateException] { + partiallySerializedBlock.finishWritingToStream(new ByteBufferOutputStream()) + } + } + + test("cannot call finishWritingToStream() after valuesIterator()") { + val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2) + partiallySerializedBlock.valuesIterator + intercept[IllegalStateException] { + partiallySerializedBlock.finishWritingToStream(new ByteBufferOutputStream()) + } + } + + test("cannot call valuesIterator() after finishWritingToStream()") { + val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2) + partiallySerializedBlock.finishWritingToStream(new ByteBufferOutputStream()) + intercept[IllegalStateException] { + partiallySerializedBlock.valuesIterator + } + } + + test("buffers are deallocated in a TaskCompletionListener") { + try { + TaskContext.setTaskContext(TaskContext.empty()) + val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2) + TaskContext.get().asInstanceOf[TaskContextImpl].markTaskCompleted() + Mockito.verify(partiallySerializedBlock.getUnrolledChunkedByteBuffer).dispose() + Mockito.verifyNoMoreInteractions(memoryStore) + } finally { + TaskContext.unset() + } + } + + private def testUnroll[T: ClassTag]( + testCaseName: String, + items: Seq[T], + numItemsToBuffer: Int): Unit = { + + test(s"$testCaseName with discard() and numBuffered = $numItemsToBuffer") { + val partiallySerializedBlock = partiallyUnroll(items.iterator, numItemsToBuffer) + partiallySerializedBlock.discard() + + Mockito.verify(memoryStore).releaseUnrollMemoryForThisTask( + MemoryMode.ON_HEAP, partiallySerializedBlock.unrollMemory) + Mockito.verify(partiallySerializedBlock.invokePrivate(getSerializationStream())).close() + Mockito.verify(partiallySerializedBlock.invokePrivate(getRedirectableOutputStream())).close() + Mockito.verifyNoMoreInteractions(memoryStore) + Mockito.verify(partiallySerializedBlock.getUnrolledChunkedByteBuffer, atLeastOnce).dispose() + } + + test(s"$testCaseName with finishWritingToStream() and numBuffered = $numItemsToBuffer") { + val partiallySerializedBlock = partiallyUnroll(items.iterator, numItemsToBuffer) + val bbos = Mockito.spy(new ByteBufferOutputStream()) + partiallySerializedBlock.finishWritingToStream(bbos) + + Mockito.verify(memoryStore).releaseUnrollMemoryForThisTask( + MemoryMode.ON_HEAP, partiallySerializedBlock.unrollMemory) + Mockito.verify(partiallySerializedBlock.invokePrivate(getSerializationStream())).close() + Mockito.verify(partiallySerializedBlock.invokePrivate(getRedirectableOutputStream())).close() + Mockito.verify(bbos).close() + Mockito.verifyNoMoreInteractions(memoryStore) + Mockito.verify(partiallySerializedBlock.getUnrolledChunkedByteBuffer, atLeastOnce).dispose() + + val serializer = serializerManager.getSerializer(implicitly[ClassTag[T]]).newInstance() + val deserialized = + serializer.deserializeStream(new ByteBufferInputStream(bbos.toByteBuffer)).asIterator.toSeq + assert(deserialized === items) + } + + test(s"$testCaseName with valuesIterator() and numBuffered = $numItemsToBuffer") { + val partiallySerializedBlock = partiallyUnroll(items.iterator, numItemsToBuffer) + val valuesIterator = partiallySerializedBlock.valuesIterator + 
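// (Merely obtaining the iterator should close both underlying streams, which
+ // the verify calls below assert before any element is consumed.)
+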
Mockito.verify(partiallySerializedBlock.invokePrivate(getSerializationStream())).close() + Mockito.verify(partiallySerializedBlock.invokePrivate(getRedirectableOutputStream())).close() + + val deserializedItems = valuesIterator.toArray.toSeq + Mockito.verify(memoryStore).releaseUnrollMemoryForThisTask( + MemoryMode.ON_HEAP, partiallySerializedBlock.unrollMemory) + Mockito.verifyNoMoreInteractions(memoryStore) + Mockito.verify(partiallySerializedBlock.getUnrolledChunkedByteBuffer, atLeastOnce).dispose() + assert(deserializedItems === items) + } + } + + testUnroll("basic numbers", 1 to 1000, numItemsToBuffer = 50) + testUnroll("basic numbers", 1 to 1000, numItemsToBuffer = 0) + testUnroll("basic numbers", 1 to 1000, numItemsToBuffer = 1000) + testUnroll("case classes", (1 to 1000).map(x => MyCaseClass(x.toString)), numItemsToBuffer = 50) + testUnroll("case classes", (1 to 1000).map(x => MyCaseClass(x.toString)), numItemsToBuffer = 0) + testUnroll("case classes", (1 to 1000).map(x => MyCaseClass(x.toString)), numItemsToBuffer = 1000) + testUnroll("empty iterator", Seq.empty[String], numItemsToBuffer = 0) +} + +private case class MyCaseClass(str: String) diff --git a/core/src/test/scala/org/apache/spark/storage/PartiallyUnrolledIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/PartiallyUnrolledIteratorSuite.scala new file mode 100644 index 0000000000000..4253cc8ca4cd1 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/storage/PartiallyUnrolledIteratorSuite.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.storage + +import org.mockito.Matchers +import org.mockito.Mockito._ +import org.scalatest.mock.MockitoSugar + +import org.apache.spark.SparkFunSuite +import org.apache.spark.memory.MemoryMode.ON_HEAP +import org.apache.spark.storage.memory.{MemoryStore, PartiallyUnrolledIterator} + +class PartiallyUnrolledIteratorSuite extends SparkFunSuite with MockitoSugar { + test("join two iterators") { + val unrollSize = 1000 + val unroll = (0 until unrollSize).iterator + val restSize = 500 + val rest = (unrollSize until restSize + unrollSize).iterator + + val memoryStore = mock[MemoryStore] + val joinIterator = new PartiallyUnrolledIterator(memoryStore, ON_HEAP, unrollSize, unroll, rest) + + // Firstly iterate over unrolling memory iterator + (0 until unrollSize).foreach { value => + assert(joinIterator.hasNext) + assert(joinIterator.hasNext) + assert(joinIterator.next() == value) + } + + joinIterator.hasNext + joinIterator.hasNext + verify(memoryStore, times(1)) + .releaseUnrollMemoryForThisTask(Matchers.eq(ON_HEAP), Matchers.eq(unrollSize.toLong)) + + // Secondly, iterate over rest iterator + (unrollSize until unrollSize + restSize).foreach { value => + assert(joinIterator.hasNext) + assert(joinIterator.hasNext) + assert(joinIterator.next() == value) + } + + joinIterator.close() + // MemoryMode.releaseUnrollMemoryForThisTask is called only once + verifyNoMoreInteractions(memoryStore) + } +} diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index ce7d51d1c371b..1fa9b28edf4be 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -25,7 +25,8 @@ import org.apache.spark._ import org.apache.spark.{LocalSparkContext, SparkConf, Success} import org.apache.spark.executor._ import org.apache.spark.scheduler._ -import org.apache.spark.util.Utils +import org.apache.spark.ui.jobs.UIData.TaskUIData +import org.apache.spark.util.{AccumulatorContext, Utils} class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with Matchers { @@ -359,4 +360,30 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with assert( stage1Data.taskData.get(1237L).get.metrics.get.shuffleReadMetrics.totalBlocksFetched == 402) } + + test("drop internal and sql accumulators") { + val taskInfo = new TaskInfo(0, 0, 0, 0, "", "", TaskLocality.ANY, false) + val internalAccum = + AccumulableInfo(id = 1, name = Some("internal"), None, None, true, false, None) + val sqlAccum = AccumulableInfo( + id = 2, + name = Some("sql"), + update = None, + value = None, + internal = false, + countFailedValues = false, + metadata = Some(AccumulatorContext.SQL_ACCUM_IDENTIFIER)) + val userAccum = AccumulableInfo( + id = 3, + name = Some("user"), + update = None, + value = None, + internal = false, + countFailedValues = false, + metadata = None) + taskInfo.accumulables ++= Seq(internalAccum, sqlAccum, userAccum) + + val newTaskInfo = TaskUIData.dropInternalAndSQLAccumulables(taskInfo) + assert(newTaskInfo.accumulables === Seq(userAccum)) + } } diff --git a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala index 7d77deeb60618..f6c8418ba3ac4 100644 --- a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala +++ 
b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala @@ -19,15 +19,14 @@ package org.apache.spark.ui.storage import org.scalatest.BeforeAndAfter -import org.apache.spark.{SparkConf, SparkFunSuite, Success} -import org.apache.spark.executor.TaskMetrics +import org.apache.spark._ import org.apache.spark.scheduler._ import org.apache.spark.storage._ /** * Test various functionality in the StorageListener that supports the StorageTab. */ -class StorageTabSuite extends SparkFunSuite with BeforeAndAfter { +class StorageTabSuite extends SparkFunSuite with LocalSparkContext with BeforeAndAfter { private var bus: LiveListenerBus = _ private var storageStatusListener: StorageStatusListener = _ private var storageListener: StorageListener = _ @@ -43,8 +42,10 @@ class StorageTabSuite extends SparkFunSuite with BeforeAndAfter { private val bm1 = BlockManagerId("big", "dog", 1) before { - bus = new LiveListenerBus - storageStatusListener = new StorageStatusListener(new SparkConf()) + val conf = new SparkConf() + sc = new SparkContext("local", "test", conf) + bus = new LiveListenerBus(sc) + storageStatusListener = new StorageStatusListener(conf) storageListener = new StorageListener(storageStatusListener) bus.addListener(storageStatusListener) bus.addListener(storageListener) @@ -179,6 +180,23 @@ class StorageTabSuite extends SparkFunSuite with BeforeAndAfter { assert(storageListener.rddInfoList.size === 2) } + test("verify StorageTab still contains a renamed RDD") { + val rddInfo = new RDDInfo(0, "original_name", 1, memOnly, Seq(4)) + val stageInfo0 = new StageInfo(0, 0, "stage0", 1, Seq(rddInfo), Seq.empty, "details") + bus.postToAll(SparkListenerBlockManagerAdded(1L, bm1, 1000L)) + bus.postToAll(SparkListenerStageSubmitted(stageInfo0)) + val blockUpdateInfos1 = Seq(BlockUpdatedInfo(bm1, RDDBlockId(0, 1), memOnly, 100L, 0L)) + postUpdateBlocks(bus, blockUpdateInfos1) + assert(storageListener.rddInfoList.size == 1) + + val newName = "new_name" + val rddInfoRenamed = new RDDInfo(0, newName, 1, memOnly, Seq(4)) + val stageInfo1 = new StageInfo(1, 0, "stage1", 1, Seq(rddInfoRenamed), Seq.empty, "details") + bus.postToAll(SparkListenerStageSubmitted(stageInfo1)) + assert(storageListener.rddInfoList.size == 1) + assert(storageListener.rddInfoList.head.name == newName) + } + private def postUpdateBlocks( bus: SparkListenerBus, blockUpdateInfos: Seq[BlockUpdatedInfo]): Unit = { blockUpdateInfos.foreach { blockUpdateInfo => diff --git a/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala b/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala new file mode 100644 index 0000000000000..a04644d57ed88 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import org.apache.spark._ + +class AccumulatorV2Suite extends SparkFunSuite { + + test("LongAccumulator add/avg/sum/count/isZero") { + val acc = new LongAccumulator + assert(acc.isZero) + assert(acc.count == 0) + assert(acc.sum == 0) + assert(acc.avg.isNaN) + + acc.add(0) + assert(!acc.isZero) + assert(acc.count == 1) + assert(acc.sum == 0) + assert(acc.avg == 0.0) + + acc.add(1) + assert(acc.count == 2) + assert(acc.sum == 1) + assert(acc.avg == 0.5) + + // Also test add using non-specialized add function + acc.add(new java.lang.Long(2)) + assert(acc.count == 3) + assert(acc.sum == 3) + assert(acc.avg == 1.0) + + // Test merging + val acc2 = new LongAccumulator + acc2.add(2) + acc.merge(acc2) + assert(acc.count == 4) + assert(acc.sum == 5) + assert(acc.avg == 1.25) + } + + test("DoubleAccumulator add/avg/sum/count/isZero") { + val acc = new DoubleAccumulator + assert(acc.isZero) + assert(acc.count == 0) + assert(acc.sum == 0.0) + assert(acc.avg.isNaN) + + acc.add(0.0) + assert(!acc.isZero) + assert(acc.count == 1) + assert(acc.sum == 0.0) + assert(acc.avg == 0.0) + + acc.add(1.0) + assert(acc.count == 2) + assert(acc.sum == 1.0) + assert(acc.avg == 0.5) + + // Also test add using non-specialized add function + acc.add(new java.lang.Double(2.0)) + assert(acc.count == 3) + assert(acc.sum == 3.0) + assert(acc.avg == 1.0) + + // Test merging + val acc2 = new DoubleAccumulator + acc2.add(2.0) + acc.merge(acc2) + assert(acc.count == 4) + assert(acc.sum == 5.0) + assert(acc.avg == 1.25) + } + + test("ListAccumulator") { + val acc = new CollectionAccumulator[Double] + assert(acc.value.isEmpty) + assert(acc.isZero) + + acc.add(0.0) + assert(acc.value.contains(0.0)) + assert(!acc.isZero) + + acc.add(new java.lang.Double(1.0)) + + val acc2 = acc.copyAndReset() + assert(acc2.value.isEmpty) + assert(acc2.isZero) + + assert(acc.value.contains(1.0)) + assert(!acc.isZero) + assert(acc.value.size() === 2) + + acc2.add(2.0) + assert(acc2.value.contains(2.0)) + assert(!acc2.isZero) + assert(acc2.value.size() === 1) + + // Test merging + acc.merge(acc2) + assert(acc.value.contains(2.0)) + assert(!acc.isZero) + assert(acc.value.size() === 3) + + val acc3 = acc.copy() + assert(acc3.value.contains(2.0)) + assert(!acc3.isZero) + assert(acc3.value.size() === 3) + + acc3.reset() + assert(acc3.isZero) + assert(acc3.value.isEmpty) + } + + test("LegacyAccumulatorWrapper") { + val acc = new LegacyAccumulatorWrapper("default", AccumulatorParam.StringAccumulatorParam) + assert(acc.value === "default") + assert(!acc.isZero) + + acc.add("foo") + assert(acc.value === "foo") + assert(!acc.isZero) + + acc.add(new java.lang.String("bar")) + + val acc2 = acc.copyAndReset() + assert(acc2.value === "") + assert(acc2.isZero) + + assert(acc.value === "bar") + assert(!acc.isZero) + + acc2.add("baz") + assert(acc2.value === "baz") + assert(!acc2.isZero) + + // Test merging + acc.merge(acc2) + assert(acc.value === "baz") + assert(!acc.isZero) + + val acc3 = acc.copy() + assert(acc3.value === "baz") + assert(!acc3.isZero) + + acc3.reset() + assert(acc3.isZero) + assert(acc3.value === "") + } +} diff --git a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala index 4fa9f9a8f590f..7e2da8e141532 100644 --- a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala +++ 
b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala @@ -20,11 +20,13 @@ package org.apache.spark.util import java.io._ import java.nio.charset.StandardCharsets import java.util.concurrent.CountDownLatch +import java.util.zip.GZIPInputStream import scala.collection.mutable.HashSet import scala.reflect._ import com.google.common.io.Files +import org.apache.commons.io.IOUtils import org.apache.log4j.{Appender, Level, Logger} import org.apache.log4j.spi.LoggingEvent import org.mockito.ArgumentCaptor @@ -72,6 +74,25 @@ class FileAppenderSuite extends SparkFunSuite with BeforeAndAfter with Logging { testRolling(appender, testOutputStream, textToAppend, rolloverIntervalMillis) } + test("rolling file appender - time-based rolling (compressed)") { + // setup input stream and appender + val testOutputStream = new PipedOutputStream() + val testInputStream = new PipedInputStream(testOutputStream, 100 * 1000) + val rolloverIntervalMillis = 100 + val durationMillis = 1000 + val numRollovers = durationMillis / rolloverIntervalMillis + val textToAppend = (1 to numRollovers).map( _.toString * 10 ) + + val sparkConf = new SparkConf() + sparkConf.set("spark.executor.logs.rolling.enableCompression", "true") + val appender = new RollingFileAppender(testInputStream, testFile, + new TimeBasedRollingPolicy(rolloverIntervalMillis, s"--HH-mm-ss-SSSS", false), + sparkConf, 10) + + testRolling( + appender, testOutputStream, textToAppend, rolloverIntervalMillis, isCompressed = true) + } + test("rolling file appender - size-based rolling") { // setup input stream and appender val testOutputStream = new PipedOutputStream() @@ -89,6 +110,25 @@ class FileAppenderSuite extends SparkFunSuite with BeforeAndAfter with Logging { } } + test("rolling file appender - size-based rolling (compressed)") { + // setup input stream and appender + val testOutputStream = new PipedOutputStream() + val testInputStream = new PipedInputStream(testOutputStream, 100 * 1000) + val rolloverSize = 1000 + val textToAppend = (1 to 3).map( _.toString * 1000 ) + + val sparkConf = new SparkConf() + sparkConf.set("spark.executor.logs.rolling.enableCompression", "true") + val appender = new RollingFileAppender(testInputStream, testFile, + new SizeBasedRollingPolicy(rolloverSize, false), sparkConf, 99) + + val files = testRolling(appender, testOutputStream, textToAppend, 0, isCompressed = true) + files.foreach { file => + logInfo(file.toString + ": " + file.length + " bytes") + assert(file.length < rolloverSize) + } + } + test("rolling file appender - cleaning") { // setup input stream and appender val testOutputStream = new PipedOutputStream() @@ -273,7 +313,8 @@ class FileAppenderSuite extends SparkFunSuite with BeforeAndAfter with Logging { appender: FileAppender, outputStream: OutputStream, textToAppend: Seq[String], - sleepTimeBetweenTexts: Long + sleepTimeBetweenTexts: Long, + isCompressed: Boolean = false ): Seq[File] = { // send data to appender through the input stream, and wait for the data to be written val expectedText = textToAppend.mkString("") @@ -290,10 +331,23 @@ class FileAppenderSuite extends SparkFunSuite with BeforeAndAfter with Logging { // verify whether all the data written to rolled over files is same as expected val generatedFiles = RollingFileAppender.getSortedRolledOverFiles( testFile.getParentFile.toString, testFile.getName) - logInfo("Filtered files: \n" + generatedFiles.mkString("\n")) + logInfo("Generated files: \n" + generatedFiles.mkString("\n")) assert(generatedFiles.size > 1) + if (isCompressed) { + 
assert( + generatedFiles.filter(_.getName.endsWith(RollingFileAppender.GZIP_LOG_SUFFIX)).size > 0) + } val allText = generatedFiles.map { file => - Files.toString(file, StandardCharsets.UTF_8) + if (file.getName.endsWith(RollingFileAppender.GZIP_LOG_SUFFIX)) { + val inputStream = new GZIPInputStream(new FileInputStream(file)) + try { + IOUtils.toString(inputStream, StandardCharsets.UTF_8) + } finally { + IOUtils.closeQuietly(inputStream) + } + } else { + Files.toString(file, StandardCharsets.UTF_8) + } }.mkString("") assert(allText === expectedText) generatedFiles diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 6fda7378e6cef..cda3457523937 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.util import java.util.Properties +import scala.collection.JavaConverters._ import scala.collection.Map import org.json4s.jackson.JsonMethods._ @@ -258,7 +259,7 @@ class JsonProtocolSuite extends SparkFunSuite { } test("FetchFailed backwards compatibility") { - // FetchFailed in Spark 1.1.0 does not have an "Message" property. + // FetchFailed in Spark 1.1.0 does not have a "Message" property. val fetchFailed = FetchFailed(BlockManagerId("With or", "without you", 15), 17, 18, 19, "ignored") val oldEvent = JsonProtocol.taskEndReasonToJson(fetchFailed) @@ -415,7 +416,7 @@ class JsonProtocolSuite extends SparkFunSuite { }) testAccumValue(Some(RESULT_SIZE), 3L, JInt(3)) testAccumValue(Some(shuffleRead.REMOTE_BLOCKS_FETCHED), 2, JInt(2)) - testAccumValue(Some(UPDATED_BLOCK_STATUSES), blocks, blocksJson) + testAccumValue(Some(UPDATED_BLOCK_STATUSES), blocks.asJava, blocksJson) // For anything else, we just cast the value to a string testAccumValue(Some("anything"), blocks, JString(blocks.toString)) testAccumValue(Some("anything"), 123, JString("123")) diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 4aa4854c36f3a..b67482aeba334 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.util -import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File, FileOutputStream} +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File, FileOutputStream, PrintStream} import java.lang.{Double => JDouble, Float => JFloat} import java.net.{BindException, ServerSocket, URI} import java.nio.{ByteBuffer, ByteOrder} @@ -25,12 +25,15 @@ import java.nio.charset.StandardCharsets import java.text.DecimalFormatSymbols import java.util.Locale import java.util.concurrent.TimeUnit +import java.util.zip.GZIPOutputStream import scala.collection.mutable.ListBuffer import scala.util.Random import com.google.common.io.Files +import org.apache.commons.io.IOUtils import org.apache.commons.lang3.SystemUtils +import org.apache.commons.math3.stat.inference.ChiSquareTest import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path @@ -40,6 +43,14 @@ import org.apache.spark.network.util.ByteUnit class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { + test("truncatedString") { + assert(Utils.truncatedString(Nil, "[", ", ", "]", 2) == "[]") + assert(Utils.truncatedString(Seq(1, 2), "[", ", ", "]", 2) == "[1, 2]") + 
assert(Utils.truncatedString(Seq(1, 2, 3), "[", ", ", "]", 2) == "[1, ... 2 more fields]") + assert(Utils.truncatedString(Seq(1, 2, 3), "[", ", ", "]", -5) == "[, ... 3 more fields]") + assert(Utils.truncatedString(Seq(1, 2, 3), ", ") == "1, 2, 3") + } + test("timeConversion") { // Test -1 assert(Utils.timeStringAsSeconds("-1") === -1) @@ -265,65 +276,109 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { assert(str(10 * hour + 59 * minute + 59 * second + 999) === "11" + sep + "00 h") } - test("reading offset bytes of a file") { + def getSuffix(isCompressed: Boolean): String = { + if (isCompressed) { + ".gz" + } else { + "" + } + } + + def writeLogFile(path: String, content: Array[Byte]): Unit = { + val outputStream = if (path.endsWith(".gz")) { + new GZIPOutputStream(new FileOutputStream(path)) + } else { + new FileOutputStream(path) + } + IOUtils.write(content, outputStream) + outputStream.close() + content.size + } + + private val workerConf = new SparkConf() + + def testOffsetBytes(isCompressed: Boolean): Unit = { val tmpDir2 = Utils.createTempDir() - val f1Path = tmpDir2 + "/f1" - val f1 = new FileOutputStream(f1Path) - f1.write("1\n2\n3\n4\n5\n6\n7\n8\n9\n".getBytes(StandardCharsets.UTF_8)) - f1.close() + val suffix = getSuffix(isCompressed) + val f1Path = tmpDir2 + "/f1" + suffix + writeLogFile(f1Path, "1\n2\n3\n4\n5\n6\n7\n8\n9\n".getBytes(StandardCharsets.UTF_8)) + val f1Length = Utils.getFileLength(new File(f1Path), workerConf) // Read first few bytes - assert(Utils.offsetBytes(f1Path, 0, 5) === "1\n2\n3") + assert(Utils.offsetBytes(f1Path, f1Length, 0, 5) === "1\n2\n3") // Read some middle bytes - assert(Utils.offsetBytes(f1Path, 4, 11) === "3\n4\n5\n6") + assert(Utils.offsetBytes(f1Path, f1Length, 4, 11) === "3\n4\n5\n6") // Read last few bytes - assert(Utils.offsetBytes(f1Path, 12, 18) === "7\n8\n9\n") + assert(Utils.offsetBytes(f1Path, f1Length, 12, 18) === "7\n8\n9\n") // Read some nonexistent bytes in the beginning - assert(Utils.offsetBytes(f1Path, -5, 5) === "1\n2\n3") + assert(Utils.offsetBytes(f1Path, f1Length, -5, 5) === "1\n2\n3") // Read some nonexistent bytes at the end - assert(Utils.offsetBytes(f1Path, 12, 22) === "7\n8\n9\n") + assert(Utils.offsetBytes(f1Path, f1Length, 12, 22) === "7\n8\n9\n") // Read some nonexistent bytes on both ends - assert(Utils.offsetBytes(f1Path, -3, 25) === "1\n2\n3\n4\n5\n6\n7\n8\n9\n") + assert(Utils.offsetBytes(f1Path, f1Length, -3, 25) === "1\n2\n3\n4\n5\n6\n7\n8\n9\n") Utils.deleteRecursively(tmpDir2) } - test("reading offset bytes across multiple files") { + test("reading offset bytes of a file") { + testOffsetBytes(isCompressed = false) + } + + test("reading offset bytes of a file (compressed)") { + testOffsetBytes(isCompressed = true) + } + + def testOffsetBytesMultipleFiles(isCompressed: Boolean): Unit = { val tmpDir = Utils.createTempDir() - val files = (1 to 3).map(i => new File(tmpDir, i.toString)) - Files.write("0123456789", files(0), StandardCharsets.UTF_8) - Files.write("abcdefghij", files(1), StandardCharsets.UTF_8) - Files.write("ABCDEFGHIJ", files(2), StandardCharsets.UTF_8) + val suffix = getSuffix(isCompressed) + val files = (1 to 3).map(i => new File(tmpDir, i.toString + suffix)) :+ new File(tmpDir, "4") + writeLogFile(files(0).getAbsolutePath, "0123456789".getBytes(StandardCharsets.UTF_8)) + writeLogFile(files(1).getAbsolutePath, "abcdefghij".getBytes(StandardCharsets.UTF_8)) + writeLogFile(files(2).getAbsolutePath, "ABCDEFGHIJ".getBytes(StandardCharsets.UTF_8)) + 
writeLogFile(files(3).getAbsolutePath, "9876543210".getBytes(StandardCharsets.UTF_8)) + val fileLengths = files.map(Utils.getFileLength(_, workerConf)) // Read first few bytes in the 1st file - assert(Utils.offsetBytes(files, 0, 5) === "01234") + assert(Utils.offsetBytes(files, fileLengths, 0, 5) === "01234") // Read bytes within the 1st file - assert(Utils.offsetBytes(files, 5, 8) === "567") + assert(Utils.offsetBytes(files, fileLengths, 5, 8) === "567") // Read bytes across 1st and 2nd file - assert(Utils.offsetBytes(files, 8, 18) === "89abcdefgh") + assert(Utils.offsetBytes(files, fileLengths, 8, 18) === "89abcdefgh") // Read bytes across 1st, 2nd and 3rd file - assert(Utils.offsetBytes(files, 5, 24) === "56789abcdefghijABCD") + assert(Utils.offsetBytes(files, fileLengths, 5, 24) === "56789abcdefghijABCD") + + // Read bytes across 3rd and 4th file + assert(Utils.offsetBytes(files, fileLengths, 25, 35) === "FGHIJ98765") // Read some nonexistent bytes in the beginning - assert(Utils.offsetBytes(files, -5, 18) === "0123456789abcdefgh") + assert(Utils.offsetBytes(files, fileLengths, -5, 18) === "0123456789abcdefgh") // Read some nonexistent bytes at the end - assert(Utils.offsetBytes(files, 18, 35) === "ijABCDEFGHIJ") + assert(Utils.offsetBytes(files, fileLengths, 18, 45) === "ijABCDEFGHIJ9876543210") // Read some nonexistent bytes on both ends - assert(Utils.offsetBytes(files, -5, 35) === "0123456789abcdefghijABCDEFGHIJ") + assert(Utils.offsetBytes(files, fileLengths, -5, 45) === + "0123456789abcdefghijABCDEFGHIJ9876543210") Utils.deleteRecursively(tmpDir) } + test("reading offset bytes across multiple files") { + testOffsetBytesMultipleFiles(isCompressed = false) + } + + test("reading offset bytes across multiple files (compressed)") { + testOffsetBytesMultipleFiles(isCompressed = true) + } + test("deserialize long value") { val testval : Long = 9730889947L val bbuf = ByteBuffer.allocate(8) @@ -341,6 +396,13 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { assert(Utils.getIteratorSize(iterator) === 5L) } + test("getIteratorZipWithIndex") { + val iterator = Utils.getIteratorZipWithIndex(Iterator(0, 1, 2), -1L + Int.MaxValue) + assert(iterator.toArray === Array( + (0, -1L + Int.MaxValue), (1, 0L + Int.MaxValue), (2, 1L + Int.MaxValue) + )) + } + test("doesDirectoryContainFilesNewerThan") { // create some temporary directories and files val parent: File = Utils.createTempDir() @@ -681,14 +743,37 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { assert(!Utils.isInDirectory(nullFile, childFile3)) } - test("circular buffer") { + test("circular buffer: if nothing was written to the buffer, display nothing") { + val buffer = new CircularBuffer(4) + assert(buffer.toString === "") + } + + test("circular buffer: if the buffer isn't full, print only the contents written") { + val buffer = new CircularBuffer(10) + val stream = new PrintStream(buffer, true, "UTF-8") + stream.print("test") + assert(buffer.toString === "test") + } + + test("circular buffer: data written == size of the buffer") { + val buffer = new CircularBuffer(4) + val stream = new PrintStream(buffer, true, "UTF-8") + + // fill the buffer to its exact size so that it just hits overflow + stream.print("test") + assert(buffer.toString === "test") + + // add more data to the buffer + stream.print("12") + assert(buffer.toString === "st12") + } + + test("circular buffer: multiple overflow") { val buffer = new CircularBuffer(25) - val stream = new 
java.io.PrintStream(buffer, true, "UTF-8") + val stream = new PrintStream(buffer, true, "UTF-8") - // scalastyle:off println - stream.println("test circular test circular test circular test circular test circular") - // scalastyle:on println - assert(buffer.toString === "t circular test circular\n") + stream.print("test circular test circular test circular test circular test circular") + assert(buffer.toString === "st circular test circular") } test("nanSafeCompareDoubles") { @@ -723,20 +808,40 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { test("isDynamicAllocationEnabled") { val conf = new SparkConf() - conf.set("spark.master", "yarn-client") + conf.set("spark.master", "yarn") + conf.set("spark.submit.deployMode", "client") assert(Utils.isDynamicAllocationEnabled(conf) === false) assert(Utils.isDynamicAllocationEnabled( conf.set("spark.dynamicAllocation.enabled", "false")) === false) assert(Utils.isDynamicAllocationEnabled( conf.set("spark.dynamicAllocation.enabled", "true")) === true) assert(Utils.isDynamicAllocationEnabled( - conf.set("spark.executor.instances", "1")) === false) + conf.set("spark.executor.instances", "1")) === true) assert(Utils.isDynamicAllocationEnabled( conf.set("spark.executor.instances", "0")) === true) assert(Utils.isDynamicAllocationEnabled(conf.set("spark.master", "local")) === false) assert(Utils.isDynamicAllocationEnabled(conf.set("spark.dynamicAllocation.testing", "true"))) } + test("getDynamicAllocationInitialExecutors") { + val conf = new SparkConf() + assert(Utils.getDynamicAllocationInitialExecutors(conf) === 0) + assert(Utils.getDynamicAllocationInitialExecutors( + conf.set("spark.dynamicAllocation.minExecutors", "3")) === 3) + assert(Utils.getDynamicAllocationInitialExecutors( // should use minExecutors + conf.set("spark.executor.instances", "2")) === 3) + assert(Utils.getDynamicAllocationInitialExecutors( // should use executor.instances + conf.set("spark.executor.instances", "4")) === 4) + assert(Utils.getDynamicAllocationInitialExecutors( // should use executor.instances + conf.set("spark.dynamicAllocation.initialExecutors", "3")) === 4) + assert(Utils.getDynamicAllocationInitialExecutors( // should use initialExecutors + conf.set("spark.dynamicAllocation.initialExecutors", "5")) === 5) + assert(Utils.getDynamicAllocationInitialExecutors( // should use minExecutors + conf.set("spark.dynamicAllocation.initialExecutors", "2") + .set("spark.executor.instances", "1")) === 3) + } + + test("encodeFileNameToURIRawPath") { assert(Utils.encodeFileNameToURIRawPath("abc") === "abc") assert(Utils.encodeFileNameToURIRawPath("abc xyz") === "abc%20xyz") @@ -815,7 +920,7 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { assert(terminated.isDefined) Utils.waitForProcess(process, 5000) val duration = System.currentTimeMillis() - start - assert(duration < 5000) + assert(duration < 6000) // add a little extra time to allow a force kill to finish assert(!pidExists(pid)) } finally { signal(pid, "SIGKILL") @@ -823,4 +928,38 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { } } } + + test("chi square test of randomizeInPlace") { + // Parameters + val arraySize = 10 + val numTrials = 1000 + val threshold = 0.05 + val seed = 1L + + // results(i)(j): how many times Utils.randomize moves an element from position j to position i + val results = Array.ofDim[Long](arraySize, arraySize) + + // This must be seeded because even a fair random process will fail this test with + // 
probability equal to the value of `threshold`, which is inconvenient for a unit test. + val rand = new java.util.Random(seed) + val range = 0 until arraySize + + for { + _ <- 0 until numTrials + trial = Utils.randomizeInPlace(range.toArray, rand) + i <- range + } results(i)(trial(i)) += 1L + + val chi = new ChiSquareTest() + + // We expect an even distribution; this array will be rescaled by `chiSquareTest` + val expected = Array.fill(arraySize * arraySize)(1.0) + val observed = results.flatten + + // Performs Pearson's chi-squared test. Using the sum-of-squares as the test statistic, gives + // the probability of a uniform distribution producing results as extreme as `observed` + val pValue = chi.chiSquareTest(expected, observed) + + assert(pValue > threshold) + } } diff --git a/core/src/test/scala/org/apache/spark/util/VersionUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/VersionUtilsSuite.scala new file mode 100644 index 0000000000000..aaf79ebd4f9fc --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/VersionUtilsSuite.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.util + +import org.apache.spark.SparkFunSuite + +class VersionUtilsSuite extends SparkFunSuite { + + import org.apache.spark.util.VersionUtils._ + + test("Parse Spark major version") { + assert(majorVersion("2.0") === 2) + assert(majorVersion("12.10.11") === 12) + assert(majorVersion("2.0.1-SNAPSHOT") === 2) + assert(majorVersion("2.0.x") === 2) + withClue("majorVersion parsing should fail for invalid major version number") { + intercept[IllegalArgumentException] { + majorVersion("2z.0") + } + } + withClue("majorVersion parsing should fail for invalid minor version number") { + intercept[IllegalArgumentException] { + majorVersion("2.0z") + } + } + } + + test("Parse Spark minor version") { + assert(minorVersion("2.0") === 0) + assert(minorVersion("12.10.11") === 10) + assert(minorVersion("2.0.1-SNAPSHOT") === 0) + assert(minorVersion("2.0.x") === 0) + withClue("minorVersion parsing should fail for invalid major version number") { + intercept[IllegalArgumentException] { + minorVersion("2z.0") + } + } + withClue("minorVersion parsing should fail for invalid minor version number") { + intercept[IllegalArgumentException] { + minorVersion("2.0z") + } + } + } + + test("Parse Spark major and minor versions") { + assert(majorMinorVersion("2.0") === (2, 0)) + assert(majorMinorVersion("12.10.11") === (12, 10)) + assert(majorMinorVersion("2.0.1-SNAPSHOT") === (2, 0)) + assert(majorMinorVersion("2.0.x") === (2, 0)) + withClue("majorMinorVersion parsing should fail for invalid major version number") { + intercept[IllegalArgumentException] { + majorMinorVersion("2z.0") + } + } + withClue("majorMinorVersion parsing should fail for invalid minor version number") { + intercept[IllegalArgumentException] { + majorMinorVersion("2.0z") + } + } + } +} diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala index 4dd8e31c27351..6bcc601e13ecc 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala @@ -17,12 +17,17 @@ package org.apache.spark.util.collection +import java.util.Comparator + import scala.collection.mutable.ArrayBuffer import scala.util.Random import org.apache.spark._ import org.apache.spark.memory.MemoryTestingUtils import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} +import org.apache.spark.unsafe.array.LongArray +import org.apache.spark.unsafe.memory.MemoryBlock +import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, RecordPointerAndKeyPrefix, UnsafeSortDataFormat} class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { import TestUtils.{assertNotSpilled, assertSpilled} @@ -93,6 +98,27 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { testWithMultipleSer("sort without breaking sorting contracts", loadDefaults = true)( sortWithoutBreakingSortingContracts) + // This test is ignored by default as it requires a fairly large heap size (16GB) + ignore("sort without breaking timsort contracts for large arrays") { + val size = 300000000 + // To manifest the bug observed in SPARK-8428 and SPARK-13850, we explicitly use an array of + // the form [150000000, 150000001, 150000002, ...., 300000000, 0, 1, 2, ..., 149999999] + // that can trigger copyRange() in TimSort.mergeLo() or TimSort.mergeHi() + val ref = Array.tabulate[Long](size) { i => if (i < size / 2) size / 2 + 
i else i } + val buf = new LongArray(MemoryBlock.fromLongArray(ref)) + val tmp = new Array[Long](size/2) + val tmpBuf = new LongArray(MemoryBlock.fromLongArray(tmp)) + + new Sorter(new UnsafeSortDataFormat(tmpBuf)).sort( + buf, 0, size, new Comparator[RecordPointerAndKeyPrefix] { + override def compare( + r1: RecordPointerAndKeyPrefix, + r2: RecordPointerAndKeyPrefix): Int = { + PrefixComparators.LONG.compare(r1.keyPrefix, r2.keyPrefix) + } + }) + } + test("spilling with hash collisions") { val size = 1000 val conf = createSparkConf(loadDefaults = true, kryo = false) diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala index 52428634e5205..1d26d4a8307cf 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala @@ -26,7 +26,6 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.internal.Logging import org.apache.spark.unsafe.array.LongArray import org.apache.spark.unsafe.memory.MemoryBlock -import org.apache.spark.util.Benchmark import org.apache.spark.util.collection.Sorter import org.apache.spark.util.random.XORShiftRandom @@ -94,7 +93,8 @@ class RadixSortSuite extends SparkFunSuite with Logging { } private def referenceKeyPrefixSort(buf: LongArray, lo: Int, hi: Int, refCmp: PrefixComparator) { - new Sorter(UnsafeSortDataFormat.INSTANCE).sort( + val sortBuffer = new LongArray(MemoryBlock.fromLongArray(new Array[Long](buf.size().toInt))) + new Sorter(new UnsafeSortDataFormat(sortBuffer)).sort( buf, lo, hi, new Comparator[RecordPointerAndKeyPrefix] { override def compare( r1: RecordPointerAndKeyPrefix, @@ -184,81 +184,4 @@ class RadixSortSuite extends SparkFunSuite with Logging { assert(res1.view == res2.view) } } - - ignore("microbenchmarks") { - val size = 25000000 - val rand = new XORShiftRandom(123) - val benchmark = new Benchmark("radix sort " + size, size) - benchmark.addTimerCase("reference TimSort key prefix array") { timer => - val array = Array.tabulate[Long](size * 2) { i => rand.nextLong } - val buf = new LongArray(MemoryBlock.fromLongArray(array)) - timer.startTiming() - referenceKeyPrefixSort(buf, 0, size, PrefixComparators.BINARY) - timer.stopTiming() - } - benchmark.addTimerCase("reference Arrays.sort") { timer => - val ref = Array.tabulate[Long](size) { i => rand.nextLong } - timer.startTiming() - Arrays.sort(ref) - timer.stopTiming() - } - benchmark.addTimerCase("radix sort one byte") { timer => - val array = new Array[Long](size * 2) - var i = 0 - while (i < size) { - array(i) = rand.nextLong & 0xff - i += 1 - } - val buf = new LongArray(MemoryBlock.fromLongArray(array)) - timer.startTiming() - RadixSort.sort(buf, size, 0, 7, false, false) - timer.stopTiming() - } - benchmark.addTimerCase("radix sort two bytes") { timer => - val array = new Array[Long](size * 2) - var i = 0 - while (i < size) { - array(i) = rand.nextLong & 0xffff - i += 1 - } - val buf = new LongArray(MemoryBlock.fromLongArray(array)) - timer.startTiming() - RadixSort.sort(buf, size, 0, 7, false, false) - timer.stopTiming() - } - benchmark.addTimerCase("radix sort eight bytes") { timer => - val array = new Array[Long](size * 2) - var i = 0 - while (i < size) { - array(i) = rand.nextLong - i += 1 - } - val buf = new LongArray(MemoryBlock.fromLongArray(array)) - timer.startTiming() - RadixSort.sort(buf, size, 0, 7, false, false) - 
timer.stopTiming() - } - benchmark.addTimerCase("radix sort key prefix array") { timer => - val (_, buf2) = generateKeyPrefixTestData(size, rand.nextLong) - timer.startTiming() - RadixSort.sortKeyPrefixArray(buf2, size, 0, 7, false, false) - timer.stopTiming() - } - benchmark.run - - /** - Running benchmark: radix sort 25000000 - Java HotSpot(TM) 64-Bit Server VM 1.8.0_66-b17 on Linux 3.13.0-44-generic - Intel(R) Core(TM) i7-4600U CPU @ 2.10GHz - - radix sort 25000000: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - reference TimSort key prefix array 15546 / 15859 1.6 621.9 1.0X - reference Arrays.sort 2416 / 2446 10.3 96.6 6.4X - radix sort one byte 133 / 137 188.4 5.3 117.2X - radix sort two bytes 255 / 258 98.2 10.2 61.1X - radix sort eight bytes 991 / 997 25.2 39.6 15.7X - radix sort key prefix array 1540 / 1563 16.2 61.6 10.1X - */ - } } diff --git a/core/src/test/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStreamSuite.scala b/core/src/test/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStreamSuite.scala index 226622075a6cc..86961745673c6 100644 --- a/core/src/test/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStreamSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStreamSuite.scala @@ -28,12 +28,14 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite { test("empty output") { val o = new ChunkedByteBufferOutputStream(1024, ByteBuffer.allocate) + o.close() assert(o.toChunkedByteBuffer.size === 0) } test("write a single byte") { val o = new ChunkedByteBufferOutputStream(1024, ByteBuffer.allocate) o.write(10) + o.close() val chunkedByteBuffer = o.toChunkedByteBuffer assert(chunkedByteBuffer.getChunks().length === 1) assert(chunkedByteBuffer.getChunks().head.array().toSeq === Seq(10.toByte)) @@ -43,6 +45,7 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite { val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate) o.write(new Array[Byte](9)) o.write(99) + o.close() val chunkedByteBuffer = o.toChunkedByteBuffer assert(chunkedByteBuffer.getChunks().length === 1) assert(chunkedByteBuffer.getChunks().head.array()(9) === 99.toByte) @@ -52,6 +55,7 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite { val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate) o.write(new Array[Byte](10)) o.write(99) + o.close() val arrays = o.toChunkedByteBuffer.getChunks().map(_.array()) assert(arrays.length === 2) assert(arrays(1).length === 1) @@ -63,6 +67,7 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite { Random.nextBytes(ref) val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate) o.write(ref) + o.close() val arrays = o.toChunkedByteBuffer.getChunks().map(_.array()) assert(arrays.length === 1) assert(arrays.head.length === ref.length) @@ -74,6 +79,7 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite { Random.nextBytes(ref) val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate) o.write(ref) + o.close() val arrays = o.toChunkedByteBuffer.getChunks().map(_.array()) assert(arrays.length === 1) assert(arrays.head.length === ref.length) @@ -85,6 +91,7 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite { Random.nextBytes(ref) val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate) o.write(ref) + o.close() val arrays = o.toChunkedByteBuffer.getChunks().map(_.array()) assert(arrays.length === 3) assert(arrays(0).length 
=== 10) @@ -101,6 +108,7 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite { Random.nextBytes(ref) val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate) o.write(ref) + o.close() val arrays = o.toChunkedByteBuffer.getChunks().map(_.array()) assert(arrays.length === 3) assert(arrays(0).length === 10) diff --git a/graphx/data/followers.txt b/data/graphx/followers.txt similarity index 100% rename from graphx/data/followers.txt rename to data/graphx/followers.txt diff --git a/graphx/data/users.txt b/data/graphx/users.txt similarity index 100% rename from graphx/data/users.txt rename to data/graphx/users.txt diff --git a/data/mllib/sample_isotonic_regression_data.txt b/data/mllib/sample_isotonic_regression_data.txt deleted file mode 100644 index d257b509d4d37..0000000000000 --- a/data/mllib/sample_isotonic_regression_data.txt +++ /dev/null @@ -1,100 +0,0 @@ -0.24579296,0.01 -0.28505864,0.02 -0.31208567,0.03 -0.35900051,0.04 -0.35747068,0.05 -0.16675166,0.06 -0.17491076,0.07 -0.04181540,0.08 -0.04793473,0.09 -0.03926568,0.10 -0.12952575,0.11 -0.00000000,0.12 -0.01376849,0.13 -0.13105558,0.14 -0.08873024,0.15 -0.12595614,0.16 -0.15247323,0.17 -0.25956145,0.18 -0.20040796,0.19 -0.19581846,0.20 -0.15757267,0.21 -0.13717491,0.22 -0.19020908,0.23 -0.19581846,0.24 -0.20091790,0.25 -0.16879143,0.26 -0.18510964,0.27 -0.20040796,0.28 -0.29576747,0.29 -0.43396226,0.30 -0.53391127,0.31 -0.52116267,0.32 -0.48546660,0.33 -0.49209587,0.34 -0.54156043,0.35 -0.59765426,0.36 -0.56144824,0.37 -0.58592555,0.38 -0.52983172,0.39 -0.50178480,0.40 -0.52626211,0.41 -0.58286588,0.42 -0.64660887,0.43 -0.68077511,0.44 -0.74298827,0.45 -0.64864865,0.46 -0.67261601,0.47 -0.65782764,0.48 -0.69811321,0.49 -0.63029067,0.50 -0.61601224,0.51 -0.63233044,0.52 -0.65323814,0.53 -0.65323814,0.54 -0.67363590,0.55 -0.67006629,0.56 -0.51555329,0.57 -0.50892402,0.58 -0.33299337,0.59 -0.36206017,0.60 -0.43090260,0.61 -0.45996940,0.62 -0.56348802,0.63 -0.54920959,0.64 -0.48393677,0.65 -0.48495665,0.66 -0.46965834,0.67 -0.45181030,0.68 -0.45843957,0.69 -0.47118817,0.70 -0.51555329,0.71 -0.58031617,0.72 -0.55481897,0.73 -0.56297807,0.74 -0.56603774,0.75 -0.57929628,0.76 -0.64762876,0.77 -0.66241713,0.78 -0.69301377,0.79 -0.65119837,0.80 -0.68332483,0.81 -0.66598674,0.82 -0.73890872,0.83 -0.73992861,0.84 -0.84242733,0.85 -0.91330954,0.86 -0.88016318,0.87 -0.90719021,0.88 -0.93115757,0.89 -0.93115757,0.90 -0.91942886,0.91 -0.92911780,0.92 -0.95665477,0.93 -0.95002550,0.94 -0.96940337,0.95 -1.00000000,0.96 -0.89801122,0.97 -0.90311066,0.98 -0.90362060,0.99 -0.83477817,1.0 \ No newline at end of file diff --git a/data/mllib/sample_isotonic_regression_libsvm_data.txt b/data/mllib/sample_isotonic_regression_libsvm_data.txt new file mode 100644 index 0000000000000..f39fe0269c2fe --- /dev/null +++ b/data/mllib/sample_isotonic_regression_libsvm_data.txt @@ -0,0 +1,100 @@ +0.24579296 1:0.01 +0.28505864 1:0.02 +0.31208567 1:0.03 +0.35900051 1:0.04 +0.35747068 1:0.05 +0.16675166 1:0.06 +0.17491076 1:0.07 +0.04181540 1:0.08 +0.04793473 1:0.09 +0.03926568 1:0.10 +0.12952575 1:0.11 +0.00000000 1:0.12 +0.01376849 1:0.13 +0.13105558 1:0.14 +0.08873024 1:0.15 +0.12595614 1:0.16 +0.15247323 1:0.17 +0.25956145 1:0.18 +0.20040796 1:0.19 +0.19581846 1:0.20 +0.15757267 1:0.21 +0.13717491 1:0.22 +0.19020908 1:0.23 +0.19581846 1:0.24 +0.20091790 1:0.25 +0.16879143 1:0.26 +0.18510964 1:0.27 +0.20040796 1:0.28 +0.29576747 1:0.29 +0.43396226 1:0.30 +0.53391127 1:0.31 +0.52116267 1:0.32 +0.48546660 1:0.33 +0.49209587 1:0.34 +0.54156043 
1:0.35 +0.59765426 1:0.36 +0.56144824 1:0.37 +0.58592555 1:0.38 +0.52983172 1:0.39 +0.50178480 1:0.40 +0.52626211 1:0.41 +0.58286588 1:0.42 +0.64660887 1:0.43 +0.68077511 1:0.44 +0.74298827 1:0.45 +0.64864865 1:0.46 +0.67261601 1:0.47 +0.65782764 1:0.48 +0.69811321 1:0.49 +0.63029067 1:0.50 +0.61601224 1:0.51 +0.63233044 1:0.52 +0.65323814 1:0.53 +0.65323814 1:0.54 +0.67363590 1:0.55 +0.67006629 1:0.56 +0.51555329 1:0.57 +0.50892402 1:0.58 +0.33299337 1:0.59 +0.36206017 1:0.60 +0.43090260 1:0.61 +0.45996940 1:0.62 +0.56348802 1:0.63 +0.54920959 1:0.64 +0.48393677 1:0.65 +0.48495665 1:0.66 +0.46965834 1:0.67 +0.45181030 1:0.68 +0.45843957 1:0.69 +0.47118817 1:0.70 +0.51555329 1:0.71 +0.58031617 1:0.72 +0.55481897 1:0.73 +0.56297807 1:0.74 +0.56603774 1:0.75 +0.57929628 1:0.76 +0.64762876 1:0.77 +0.66241713 1:0.78 +0.69301377 1:0.79 +0.65119837 1:0.80 +0.68332483 1:0.81 +0.66598674 1:0.82 +0.73890872 1:0.83 +0.73992861 1:0.84 +0.84242733 1:0.85 +0.91330954 1:0.86 +0.88016318 1:0.87 +0.90719021 1:0.88 +0.93115757 1:0.89 +0.93115757 1:0.90 +0.91942886 1:0.91 +0.92911780 1:0.92 +0.95665477 1:0.93 +0.95002550 1:0.94 +0.96940337 1:0.95 +1.00000000 1:0.96 +0.89801122 1:0.97 +0.90311066 1:0.98 +0.90362060 1:0.99 +0.83477817 1:1.0 \ No newline at end of file diff --git a/data/mllib/sample_kmeans_data.txt b/data/mllib/sample_kmeans_data.txt new file mode 100644 index 0000000000000..50013776b182a --- /dev/null +++ b/data/mllib/sample_kmeans_data.txt @@ -0,0 +1,6 @@ +0 1:0.0 2:0.0 3:0.0 +1 1:0.1 2:0.1 3:0.1 +2 1:0.2 2:0.2 3:0.2 +3 1:9.0 2:9.0 3:9.0 +4 1:9.1 2:9.1 3:9.1 +5 1:9.2 2:9.2 3:9.2 diff --git a/data/mllib/sample_lda_libsvm_data.txt b/data/mllib/sample_lda_libsvm_data.txt new file mode 100644 index 0000000000000..bf118d7d5b20f --- /dev/null +++ b/data/mllib/sample_lda_libsvm_data.txt @@ -0,0 +1,12 @@ +0 1:1 2:2 3:6 4:0 5:2 6:3 7:1 8:1 9:0 10:0 11:3 +1 1:1 2:3 3:0 4:1 5:3 6:0 7:0 8:2 9:0 10:0 11:1 +2 1:1 2:4 3:1 4:0 5:0 6:4 7:9 8:0 9:1 10:2 11:0 +3 1:2 2:1 3:0 4:3 5:0 6:0 7:5 8:0 9:2 10:3 11:9 +4 1:3 2:1 3:1 4:9 5:3 6:0 7:2 8:0 9:0 10:1 11:3 +5 1:4 2:2 3:0 4:3 5:4 6:5 7:1 8:1 9:1 10:4 11:0 +6 1:2 2:1 3:0 4:3 5:0 6:0 7:5 8:0 9:2 10:2 11:9 +7 1:1 2:1 3:1 4:9 5:2 6:1 7:2 8:0 9:0 10:1 11:3 +8 1:4 2:4 3:0 4:3 5:4 6:2 7:1 8:3 9:0 10:0 11:0 +9 1:2 2:8 3:2 4:0 5:3 6:0 7:2 8:0 9:2 10:7 11:2 +10 1:1 2:1 3:1 4:9 5:0 6:2 7:2 8:0 9:0 10:3 11:3 +11 1:4 2:1 3:0 4:0 5:4 6:5 7:1 8:3 9:0 10:1 11:0 diff --git a/data/mllib/sample_naive_bayes_data.txt b/data/mllib/sample_naive_bayes_data.txt deleted file mode 100644 index bd22bea3a59d6..0000000000000 --- a/data/mllib/sample_naive_bayes_data.txt +++ /dev/null @@ -1,12 +0,0 @@ -0,1 0 0 -0,2 0 0 -0,3 0 0 -0,4 0 0 -1,0 1 0 -1,0 2 0 -1,0 3 0 -1,0 4 0 -2,0 0 1 -2,0 0 2 -2,0 0 3 -2,0 0 4 \ No newline at end of file diff --git a/dev/.rat-excludes b/dev/.rat-excludes index 67a241806d7cb..0c866717a3f43 100644 --- a/dev/.rat-excludes +++ b/dev/.rat-excludes @@ -98,3 +98,5 @@ spark-deps-.* .*csv .*tsv org.apache.spark.scheduler.ExternalClusterManager +.*\.sql +.Rbuildignore diff --git a/dev/audit-release/audit_release.py b/dev/audit-release/audit_release.py index ee72da4df0652..b28e7a427b8f6 100755 --- a/dev/audit-release/audit_release.py +++ b/dev/audit-release/audit_release.py @@ -116,7 +116,7 @@ def ensure_path_not_present(path): # dependencies within those projects. 
modules = [ "spark-core", "spark-mllib", "spark-streaming", "spark-repl", - "spark-graphx", "spark-streaming-flume", "spark-streaming-kafka", + "spark-graphx", "spark-streaming-flume", "spark-streaming-kafka-0-8", "spark-catalyst", "spark-sql", "spark-hive", "spark-streaming-kinesis-asl" ] modules = map(lambda m: "%s_%s" % (m, SCALA_BINARY_VERSION), modules) diff --git a/dev/audit-release/sbt_app_hive/src/main/scala/HiveApp.scala b/dev/audit-release/sbt_app_hive/src/main/scala/HiveApp.scala index f69d46cd17d0b..8cbfb9cd41b38 100644 --- a/dev/audit-release/sbt_app_hive/src/main/scala/HiveApp.scala +++ b/dev/audit-release/sbt_app_hive/src/main/scala/HiveApp.scala @@ -33,7 +33,9 @@ object SparkSqlExample { case None => new SparkConf().setAppName("Simple Sql App") } val sc = new SparkContext(conf) - val sparkSession = SparkSession.withHiveSupport(sc) + val sparkSession = SparkSession.builder + .enableHiveSupport() + .getOrCreate() import sparkSession._ sql("DROP TABLE IF EXISTS src") @@ -41,14 +43,14 @@ object SparkSqlExample { sql("LOAD DATA LOCAL INPATH 'data.txt' INTO TABLE src") val results = sql("FROM src SELECT key, value WHERE key >= 0 AND KEY < 5").collect() results.foreach(println) - + def test(f: => Boolean, failureMsg: String) = { if (!f) { println(failureMsg) System.exit(-1) } } - + test(results.size == 5, "Unexpected number of selected elements: " + results) println("Test succeeded") sc.stop() diff --git a/dev/audit-release/sbt_app_sql/src/main/scala/SqlApp.scala b/dev/audit-release/sbt_app_sql/src/main/scala/SqlApp.scala index 69c1154dc0955..10026314ef7ac 100644 --- a/dev/audit-release/sbt_app_sql/src/main/scala/SqlApp.scala +++ b/dev/audit-release/sbt_app_sql/src/main/scala/SqlApp.scala @@ -41,7 +41,7 @@ object SparkSqlExample { import sqlContext._ val people = sc.makeRDD(1 to 100, 10).map(x => Person(s"Name$x", x)).toDF() - people.registerTempTable("people") + people.createOrReplaceTempView("people") val teenagers = sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") val teenagerNames = teenagers.map(t => "Name: " + t(0)).collect() teenagerNames.foreach(println) @@ -52,7 +52,7 @@ object SparkSqlExample { System.exit(-1) } } - + test(teenagerNames.size == 7, "Unexpected number of selected elements: " + teenagerNames) println("Test succeeded") sc.stop() diff --git a/dev/checkstyle-suppressions.xml b/dev/checkstyle-suppressions.xml index a1a88ac8cdac5..31656ca0e5a60 100644 --- a/dev/checkstyle-suppressions.xml +++ b/dev/checkstyle-suppressions.xml @@ -36,4 +36,12 @@ files="src/test/java/org/apache/spark/sql/hive/test/Complex.java"/> + + + + diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index 65e80fc76056a..2833dc7651117 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -80,7 +80,7 @@ NEXUS_PROFILE=d63f592e7eac0 # Profile for Spark staging uploads BASE_DIR=$(pwd) MVN="build/mvn --force" -PUBLISH_PROFILES="-Pyarn -Phive -Phadoop-2.2" +PUBLISH_PROFILES="-Pyarn -Phive -Phive-thriftserver -Phadoop-2.2" PUBLISH_PROFILES="$PUBLISH_PROFILES -Pspark-ganglia-lgpl -Pkinesis-asl" rm -rf spark @@ -254,8 +254,7 @@ if [[ "$1" == "publish-snapshot" ]]; then # Generate random point for Zinc export ZINC_PORT=$(python -S -c "import random; print random.randrange(3030,4030)") - $MVN -DzincPort=$ZINC_PORT --settings $tmp_settings -DskipTests $PUBLISH_PROFILES \ - -Phive-thriftserver deploy + $MVN -DzincPort=$ZINC_PORT --settings $tmp_settings -DskipTests $PUBLISH_PROFILES deploy ./dev/change-scala-version.sh 
2.10 $MVN -DzincPort=$ZINC_PORT -Dscala-2.10 --settings $tmp_settings \ -DskipTests $PUBLISH_PROFILES clean deploy @@ -291,8 +290,7 @@ if [[ "$1" == "publish-release" ]]; then # Generate random point for Zinc export ZINC_PORT=$(python -S -c "import random; print random.randrange(3030,4030)") - $MVN -DzincPort=$ZINC_PORT -Dmaven.repo.local=$tmp_repo -DskipTests $PUBLISH_PROFILES \ - -Phive-thriftserver clean install + $MVN -DzincPort=$ZINC_PORT -Dmaven.repo.local=$tmp_repo -DskipTests $PUBLISH_PROFILES clean install ./dev/change-scala-version.sh 2.10 diff --git a/dev/create-release/release-tag.sh b/dev/create-release/release-tag.sh index d404939d1caee..b7e5100ca7408 100755 --- a/dev/create-release/release-tag.sh +++ b/dev/create-release/release-tag.sh @@ -60,12 +60,27 @@ git config user.email $GIT_EMAIL # Create release version $MVN versions:set -DnewVersion=$RELEASE_VERSION | grep -v "no value" # silence logs +# Set the release version in R/pkg/DESCRIPTION +sed -i".tmp1" 's/Version.*$/Version: '"$RELEASE_VERSION"'/g' R/pkg/DESCRIPTION +# Set the release version in docs +sed -i".tmp1" 's/SPARK_VERSION:.*$/SPARK_VERSION: '"$RELEASE_VERSION"'/g' docs/_config.yml +sed -i".tmp2" 's/SPARK_VERSION_SHORT:.*$/SPARK_VERSION_SHORT: '"$RELEASE_VERSION"'/g' docs/_config.yml + git commit -a -m "Preparing Spark release $RELEASE_TAG" echo "Creating tag $RELEASE_TAG at the head of $GIT_BRANCH" git tag $RELEASE_TAG # Create next version $MVN versions:set -DnewVersion=$NEXT_VERSION | grep -v "no value" # silence logs +# Remove -SNAPSHOT before setting the R version as R expects version strings to only have numbers +R_NEXT_VERSION=`echo $NEXT_VERSION | sed 's/-SNAPSHOT//g'` +sed -i".tmp2" 's/Version.*$/Version: '"$R_NEXT_VERSION"'/g' R/pkg/DESCRIPTION + +# Update docs with next version +sed -i".tmp3" 's/SPARK_VERSION:.*$/SPARK_VERSION: '"$NEXT_VERSION"'/g' docs/_config.yml +# Use R version for short version +sed -i".tmp4" 's/SPARK_VERSION_SHORT:.*$/SPARK_VERSION_SHORT: '"$R_NEXT_VERSION"'/g' docs/_config.yml + git commit -a -m "Preparing development version $NEXT_VERSION" # Push changes diff --git a/dev/deps/spark-deps-hadoop-2.2 b/dev/deps/spark-deps-hadoop-2.2 index a4ef9a9af2939..34cd4e6b15f84 100644 --- a/dev/deps/spark-deps-hadoop-2.2 +++ b/dev/deps/spark-deps-hadoop-2.2 @@ -1,14 +1,13 @@ JavaEWAH-0.3.2.jar RoaringBitmap-0.5.11.jar ST4-4.0.4.jar -activation-1.1.jar antlr-2.7.7.jar antlr-runtime-3.4.jar -antlr4-runtime-4.5.2-1.jar +antlr4-runtime-4.5.3.jar aopalliance-1.0.jar +aopalliance-repackaged-2.4.0-b34.jar apache-log4j-extras-1.2.17.jar arpack_combined_all-0.1.jar -asm-3.1.jar avro-1.7.7.jar avro-ipc-1.7.7.jar avro-mapred-1.7.7-hadoop2.jar @@ -31,7 +30,7 @@ commons-configuration-1.6.jar commons-dbcp-1.4.jar commons-digester-1.8.jar commons-httpclient-3.1.jar -commons-io-2.1.jar +commons-io-2.4.jar commons-lang-2.6.jar commons-lang3-3.3.2.jar commons-logging-1.1.3.jar @@ -47,14 +46,8 @@ curator-recipes-2.4.0.jar datanucleus-api-jdo-3.2.6.jar datanucleus-core-3.2.10.jar datanucleus-rdbms-3.2.9.jar -derby-10.10.1.1.jar +derby-10.12.1.1.jar eigenbase-properties-1.1.5.jar -gmbal-api-only-3.0.0-b023.jar -grizzly-framework-2.1.2.jar -grizzly-http-2.1.2.jar -grizzly-http-server-2.1.2.jar -grizzly-http-servlet-2.1.2.jar -grizzly-rcm-2.1.2.jar guava-14.0.1.jar guice-3.0.jar guice-servlet-3.0.jar @@ -73,46 +66,46 @@ hadoop-yarn-client-2.2.0.jar hadoop-yarn-common-2.2.0.jar hadoop-yarn-server-common-2.2.0.jar hadoop-yarn-server-web-proxy-2.2.0.jar -httpclient-4.3.2.jar -httpcore-4.3.2.jar 
+hk2-api-2.4.0-b34.jar +hk2-locator-2.4.0-b34.jar +hk2-utils-2.4.0-b34.jar +httpclient-4.5.2.jar +httpcore-4.4.4.jar ivy-2.4.0.jar -jackson-annotations-2.5.3.jar -jackson-core-2.5.3.jar +jackson-annotations-2.6.5.jar +jackson-core-2.6.5.jar jackson-core-asl-1.9.13.jar -jackson-databind-2.5.3.jar -jackson-jaxrs-1.9.13.jar +jackson-databind-2.6.5.jar jackson-mapper-asl-1.9.13.jar -jackson-module-scala_2.11-2.5.3.jar -jackson-xc-1.9.13.jar +jackson-module-paranamer-2.6.5.jar +jackson-module-scala_2.11-2.6.5.jar janino-2.7.8.jar +javassist-3.18.1-GA.jar +javax.annotation-api-1.2.jar javax.inject-1.jar -javax.servlet-3.0.0.v201112011016.jar -javax.servlet-3.1.jar -javax.servlet-api-3.0.1.jar +javax.inject-2.4.0-b34.jar +javax.servlet-api-3.1.0.jar +javax.ws.rs-api-2.0.1.jar javolution-5.5.1.jar -jaxb-api-2.2.2.jar -jaxb-impl-2.2.3-1.jar jcl-over-slf4j-1.7.16.jar jdo-api-3.0.1.jar -jersey-client-1.9.jar -jersey-core-1.9.jar -jersey-grizzly2-1.9.jar -jersey-guice-1.9.jar -jersey-json-1.9.jar -jersey-server-1.9.jar -jersey-test-framework-core-1.9.jar -jersey-test-framework-grizzly2-1.9.jar +jersey-client-2.22.2.jar +jersey-common-2.22.2.jar +jersey-container-servlet-2.22.2.jar +jersey-container-servlet-core-2.22.2.jar +jersey-guava-2.22.2.jar +jersey-media-jaxb-2.22.2.jar +jersey-server-2.22.2.jar jets3t-0.7.1.jar -jettison-1.1.jar jetty-util-6.1.26.jar -jline-2.12.jar +jline-2.12.1.jar joda-time-2.9.3.jar jodd-core-3.5.2.jar jpam-1.1.jar json-20090211.jar -json4s-ast_2.11-3.2.10.jar -json4s-core_2.11-3.2.10.jar -json4s-jackson_2.11-3.2.10.jar +json4s-ast_2.11-3.2.11.jar +json4s-core_2.11-3.2.11.jar +json4s-jackson_2.11-3.2.11.jar jsr305-1.3.9.jar jta-1.1.jar jtransforms-2.4.0.jar @@ -123,7 +116,6 @@ libfb303-0.9.2.jar libthrift-0.9.2.jar log4j-1.2.17.jar lz4-1.3.0.jar -management-api-3.0.0-b012.jar mesos-0.21.1-shaded-protobuf.jar metrics-core-3.1.2.jar metrics-graphite-3.1.2.jar @@ -135,7 +127,8 @@ netty-all-4.0.29.Final.jar objenesis-2.1.jar opencsv-2.3.jar oro-2.0.8.jar -paranamer-2.6.jar +osgi-resource-locator-1.0.1.jar +paranamer-2.3.jar parquet-column-1.7.0.jar parquet-common-1.7.0.jar parquet-encoding-1.7.0.jar @@ -144,12 +137,11 @@ parquet-generator-1.7.0.jar parquet-hadoop-1.7.0.jar parquet-hadoop-bundle-1.6.0.jar parquet-jackson-1.7.0.jar -pmml-agent-1.2.7.jar -pmml-model-1.2.7.jar -pmml-schema-1.2.7.jar +pmml-model-1.2.15.jar +pmml-schema-1.2.15.jar protobuf-java-2.5.0.jar -py4j-0.9.2.jar -pyrolite-4.9.jar +py4j-0.10.3.jar +pyrolite-4.13.jar scala-compiler-2.11.8.jar scala-library-2.11.8.jar scala-parser-combinators_2.11-1.0.4.jar @@ -159,15 +151,15 @@ scalap-2.11.8.jar slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar snappy-0.2.jar -snappy-java-1.1.2.4.jar +snappy-java-1.1.2.6.jar spire-macros_2.11-0.7.4.jar spire_2.11-0.7.4.jar -stax-api-1.0-2.jar stax-api-1.0.1.jar stream-2.7.0.jar stringtemplate-3.2.1.jar super-csv-2.2.0.jar -univocity-parsers-2.0.2.jar +univocity-parsers-2.1.1.jar +validation-api-1.1.0.Final.jar xbean-asm5-shaded-4.4.jar xmlenc-0.52.jar xz-1.0.jar diff --git a/dev/deps/spark-deps-hadoop-2.3 b/dev/deps/spark-deps-hadoop-2.3 index 19c8fad984b30..8ae3c5e60d5e2 100644 --- a/dev/deps/spark-deps-hadoop-2.3 +++ b/dev/deps/spark-deps-hadoop-2.3 @@ -4,11 +4,11 @@ ST4-4.0.4.jar activation-1.1.1.jar antlr-2.7.7.jar antlr-runtime-3.4.jar -antlr4-runtime-4.5.2-1.jar +antlr4-runtime-4.5.3.jar aopalliance-1.0.jar +aopalliance-repackaged-2.4.0-b34.jar apache-log4j-extras-1.2.17.jar arpack_combined_all-0.1.jar -asm-3.1.jar avro-1.7.7.jar avro-ipc-1.7.7.jar 
avro-mapred-1.7.7-hadoop2.jar @@ -48,7 +48,7 @@ curator-recipes-2.4.0.jar datanucleus-api-jdo-3.2.6.jar datanucleus-core-3.2.10.jar datanucleus-rdbms-3.2.9.jar -derby-10.10.1.1.jar +derby-10.12.1.1.jar eigenbase-properties-1.1.5.jar guava-14.0.1.jar guice-3.0.jar @@ -68,42 +68,49 @@ hadoop-yarn-client-2.3.0.jar hadoop-yarn-common-2.3.0.jar hadoop-yarn-server-common-2.3.0.jar hadoop-yarn-server-web-proxy-2.3.0.jar -httpclient-4.3.2.jar -httpcore-4.3.2.jar +hk2-api-2.4.0-b34.jar +hk2-locator-2.4.0-b34.jar +hk2-utils-2.4.0-b34.jar +httpclient-4.5.2.jar +httpcore-4.4.4.jar ivy-2.4.0.jar -jackson-annotations-2.5.3.jar -jackson-core-2.5.3.jar +jackson-annotations-2.6.5.jar +jackson-core-2.6.5.jar jackson-core-asl-1.9.13.jar -jackson-databind-2.5.3.jar -jackson-jaxrs-1.9.13.jar +jackson-databind-2.6.5.jar jackson-mapper-asl-1.9.13.jar -jackson-module-scala_2.11-2.5.3.jar -jackson-xc-1.9.13.jar +jackson-module-paranamer-2.6.5.jar +jackson-module-scala_2.11-2.6.5.jar janino-2.7.8.jar java-xmlbuilder-1.0.jar +javassist-3.18.1-GA.jar +javax.annotation-api-1.2.jar javax.inject-1.jar -javax.servlet-3.0.0.v201112011016.jar +javax.inject-2.4.0-b34.jar +javax.servlet-api-3.1.0.jar +javax.ws.rs-api-2.0.1.jar javolution-5.5.1.jar jaxb-api-2.2.2.jar -jaxb-impl-2.2.3-1.jar jcl-over-slf4j-1.7.16.jar jdo-api-3.0.1.jar -jersey-core-1.9.jar -jersey-guice-1.9.jar -jersey-json-1.9.jar -jersey-server-1.9.jar +jersey-client-2.22.2.jar +jersey-common-2.22.2.jar +jersey-container-servlet-2.22.2.jar +jersey-container-servlet-core-2.22.2.jar +jersey-guava-2.22.2.jar +jersey-media-jaxb-2.22.2.jar +jersey-server-2.22.2.jar jets3t-0.9.3.jar -jettison-1.1.jar jetty-6.1.26.jar jetty-util-6.1.26.jar -jline-2.12.jar +jline-2.12.1.jar joda-time-2.9.3.jar jodd-core-3.5.2.jar jpam-1.1.jar json-20090211.jar -json4s-ast_2.11-3.2.10.jar -json4s-core_2.11-3.2.10.jar -json4s-jackson_2.11-3.2.10.jar +json4s-ast_2.11-3.2.11.jar +json4s-core_2.11-3.2.11.jar +json4s-jackson_2.11-3.2.11.jar jsr305-1.3.9.jar jta-1.1.jar jtransforms-2.4.0.jar @@ -127,7 +134,8 @@ netty-all-4.0.29.Final.jar objenesis-2.1.jar opencsv-2.3.jar oro-2.0.8.jar -paranamer-2.6.jar +osgi-resource-locator-1.0.1.jar +paranamer-2.3.jar parquet-column-1.7.0.jar parquet-common-1.7.0.jar parquet-encoding-1.7.0.jar @@ -136,12 +144,11 @@ parquet-generator-1.7.0.jar parquet-hadoop-1.7.0.jar parquet-hadoop-bundle-1.6.0.jar parquet-jackson-1.7.0.jar -pmml-agent-1.2.7.jar -pmml-model-1.2.7.jar -pmml-schema-1.2.7.jar +pmml-model-1.2.15.jar +pmml-schema-1.2.15.jar protobuf-java-2.5.0.jar -py4j-0.9.2.jar -pyrolite-4.9.jar +py4j-0.10.3.jar +pyrolite-4.13.jar scala-compiler-2.11.8.jar scala-library-2.11.8.jar scala-parser-combinators_2.11-1.0.4.jar @@ -151,7 +158,7 @@ scalap-2.11.8.jar slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar snappy-0.2.jar -snappy-java-1.1.2.4.jar +snappy-java-1.1.2.6.jar spire-macros_2.11-0.7.4.jar spire_2.11-0.7.4.jar stax-api-1.0-2.jar @@ -159,7 +166,8 @@ stax-api-1.0.1.jar stream-2.7.0.jar stringtemplate-3.2.1.jar super-csv-2.2.0.jar -univocity-parsers-2.0.2.jar +univocity-parsers-2.1.1.jar +validation-api-1.1.0.Final.jar xbean-asm5-shaded-4.4.jar xmlenc-0.52.jar xz-1.0.jar diff --git a/dev/deps/spark-deps-hadoop-2.4 b/dev/deps/spark-deps-hadoop-2.4 index c2365f9cc5625..7c691029535c6 100644 --- a/dev/deps/spark-deps-hadoop-2.4 +++ b/dev/deps/spark-deps-hadoop-2.4 @@ -4,11 +4,11 @@ ST4-4.0.4.jar activation-1.1.1.jar antlr-2.7.7.jar antlr-runtime-3.4.jar -antlr4-runtime-4.5.2-1.jar +antlr4-runtime-4.5.3.jar aopalliance-1.0.jar 
+aopalliance-repackaged-2.4.0-b34.jar apache-log4j-extras-1.2.17.jar arpack_combined_all-0.1.jar -asm-3.1.jar avro-1.7.7.jar avro-ipc-1.7.7.jar avro-mapred-1.7.7-hadoop2.jar @@ -48,63 +48,69 @@ curator-recipes-2.4.0.jar datanucleus-api-jdo-3.2.6.jar datanucleus-core-3.2.10.jar datanucleus-rdbms-3.2.9.jar -derby-10.10.1.1.jar +derby-10.12.1.1.jar eigenbase-properties-1.1.5.jar guava-14.0.1.jar guice-3.0.jar guice-servlet-3.0.jar -hadoop-annotations-2.4.0.jar -hadoop-auth-2.4.0.jar -hadoop-client-2.4.0.jar -hadoop-common-2.4.0.jar -hadoop-hdfs-2.4.0.jar -hadoop-mapreduce-client-app-2.4.0.jar -hadoop-mapreduce-client-common-2.4.0.jar -hadoop-mapreduce-client-core-2.4.0.jar -hadoop-mapreduce-client-jobclient-2.4.0.jar -hadoop-mapreduce-client-shuffle-2.4.0.jar -hadoop-yarn-api-2.4.0.jar -hadoop-yarn-client-2.4.0.jar -hadoop-yarn-common-2.4.0.jar -hadoop-yarn-server-common-2.4.0.jar -hadoop-yarn-server-web-proxy-2.4.0.jar -httpclient-4.3.2.jar -httpcore-4.3.2.jar +hadoop-annotations-2.4.1.jar +hadoop-auth-2.4.1.jar +hadoop-client-2.4.1.jar +hadoop-common-2.4.1.jar +hadoop-hdfs-2.4.1.jar +hadoop-mapreduce-client-app-2.4.1.jar +hadoop-mapreduce-client-common-2.4.1.jar +hadoop-mapreduce-client-core-2.4.1.jar +hadoop-mapreduce-client-jobclient-2.4.1.jar +hadoop-mapreduce-client-shuffle-2.4.1.jar +hadoop-yarn-api-2.4.1.jar +hadoop-yarn-client-2.4.1.jar +hadoop-yarn-common-2.4.1.jar +hadoop-yarn-server-common-2.4.1.jar +hadoop-yarn-server-web-proxy-2.4.1.jar +hk2-api-2.4.0-b34.jar +hk2-locator-2.4.0-b34.jar +hk2-utils-2.4.0-b34.jar +httpclient-4.5.2.jar +httpcore-4.4.4.jar ivy-2.4.0.jar -jackson-annotations-2.5.3.jar -jackson-core-2.5.3.jar +jackson-annotations-2.6.5.jar +jackson-core-2.6.5.jar jackson-core-asl-1.9.13.jar -jackson-databind-2.5.3.jar -jackson-jaxrs-1.9.13.jar +jackson-databind-2.6.5.jar jackson-mapper-asl-1.9.13.jar -jackson-module-scala_2.11-2.5.3.jar -jackson-xc-1.9.13.jar +jackson-module-paranamer-2.6.5.jar +jackson-module-scala_2.11-2.6.5.jar janino-2.7.8.jar java-xmlbuilder-1.0.jar +javassist-3.18.1-GA.jar +javax.annotation-api-1.2.jar javax.inject-1.jar -javax.servlet-3.0.0.v201112011016.jar +javax.inject-2.4.0-b34.jar +javax.servlet-api-3.1.0.jar +javax.ws.rs-api-2.0.1.jar javolution-5.5.1.jar jaxb-api-2.2.2.jar -jaxb-impl-2.2.3-1.jar jcl-over-slf4j-1.7.16.jar jdo-api-3.0.1.jar -jersey-client-1.9.jar -jersey-core-1.9.jar -jersey-guice-1.9.jar -jersey-json-1.9.jar -jersey-server-1.9.jar +jersey-client-2.22.2.jar +jersey-common-2.22.2.jar +jersey-container-servlet-2.22.2.jar +jersey-container-servlet-core-2.22.2.jar +jersey-guava-2.22.2.jar +jersey-media-jaxb-2.22.2.jar +jersey-server-2.22.2.jar jets3t-0.9.3.jar -jettison-1.1.jar jetty-6.1.26.jar jetty-util-6.1.26.jar -jline-2.12.jar +jline-2.12.1.jar joda-time-2.9.3.jar jodd-core-3.5.2.jar jpam-1.1.jar json-20090211.jar -json4s-ast_2.11-3.2.10.jar -json4s-core_2.11-3.2.10.jar -json4s-jackson_2.11-3.2.10.jar +json4s-ast_2.11-3.2.11.jar +json4s-core_2.11-3.2.11.jar +json4s-jackson_2.11-3.2.11.jar jsr305-1.3.9.jar jta-1.1.jar jtransforms-2.4.0.jar @@ -128,7 +134,8 @@ netty-all-4.0.29.Final.jar objenesis-2.1.jar opencsv-2.3.jar oro-2.0.8.jar -paranamer-2.6.jar +osgi-resource-locator-1.0.1.jar +paranamer-2.3.jar parquet-column-1.7.0.jar parquet-common-1.7.0.jar parquet-encoding-1.7.0.jar @@ -137,12 +144,11 @@ parquet-generator-1.7.0.jar parquet-hadoop-1.7.0.jar parquet-hadoop-bundle-1.6.0.jar parquet-jackson-1.7.0.jar -pmml-agent-1.2.7.jar -pmml-model-1.2.7.jar -pmml-schema-1.2.7.jar +pmml-model-1.2.15.jar +pmml-schema-1.2.15.jar 
protobuf-java-2.5.0.jar -py4j-0.9.2.jar -pyrolite-4.9.jar +py4j-0.10.3.jar +pyrolite-4.13.jar scala-compiler-2.11.8.jar scala-library-2.11.8.jar scala-parser-combinators_2.11-1.0.4.jar @@ -152,7 +158,7 @@ scalap-2.11.8.jar slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar snappy-0.2.jar -snappy-java-1.1.2.4.jar +snappy-java-1.1.2.6.jar spire-macros_2.11-0.7.4.jar spire_2.11-0.7.4.jar stax-api-1.0-2.jar @@ -160,7 +166,8 @@ stax-api-1.0.1.jar stream-2.7.0.jar stringtemplate-3.2.1.jar super-csv-2.2.0.jar -univocity-parsers-2.0.2.jar +univocity-parsers-2.1.1.jar +validation-api-1.1.0.Final.jar xbean-asm5-shaded-4.4.jar xmlenc-0.52.jar xz-1.0.jar diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 0d8afd19c5616..041e01e9cb427 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -4,15 +4,15 @@ ST4-4.0.4.jar activation-1.1.1.jar antlr-2.7.7.jar antlr-runtime-3.4.jar -antlr4-runtime-4.5.2-1.jar +antlr4-runtime-4.5.3.jar aopalliance-1.0.jar +aopalliance-repackaged-2.4.0-b34.jar apache-log4j-extras-1.2.17.jar apacheds-i18n-2.0.0-M15.jar apacheds-kerberos-codec-2.0.0-M15.jar api-asn1-api-1.0.0-M20.jar api-util-1.0.0-M20.jar arpack_combined_all-0.1.jar -asm-3.1.jar avro-1.7.7.jar avro-ipc-1.7.7.jar avro-mapred-1.7.7-hadoop2.jar @@ -52,65 +52,73 @@ curator-recipes-2.6.0.jar datanucleus-api-jdo-3.2.6.jar datanucleus-core-3.2.10.jar datanucleus-rdbms-3.2.9.jar -derby-10.10.1.1.jar +derby-10.12.1.1.jar eigenbase-properties-1.1.5.jar gson-2.2.4.jar guava-14.0.1.jar guice-3.0.jar guice-servlet-3.0.jar -hadoop-annotations-2.6.0.jar -hadoop-auth-2.6.0.jar -hadoop-client-2.6.0.jar -hadoop-common-2.6.0.jar -hadoop-hdfs-2.6.0.jar -hadoop-mapreduce-client-app-2.6.0.jar -hadoop-mapreduce-client-common-2.6.0.jar -hadoop-mapreduce-client-core-2.6.0.jar -hadoop-mapreduce-client-jobclient-2.6.0.jar -hadoop-mapreduce-client-shuffle-2.6.0.jar -hadoop-yarn-api-2.6.0.jar -hadoop-yarn-client-2.6.0.jar -hadoop-yarn-common-2.6.0.jar -hadoop-yarn-server-common-2.6.0.jar -hadoop-yarn-server-web-proxy-2.6.0.jar +hadoop-annotations-2.6.4.jar +hadoop-auth-2.6.4.jar +hadoop-client-2.6.4.jar +hadoop-common-2.6.4.jar +hadoop-hdfs-2.6.4.jar +hadoop-mapreduce-client-app-2.6.4.jar +hadoop-mapreduce-client-common-2.6.4.jar +hadoop-mapreduce-client-core-2.6.4.jar +hadoop-mapreduce-client-jobclient-2.6.4.jar +hadoop-mapreduce-client-shuffle-2.6.4.jar +hadoop-yarn-api-2.6.4.jar +hadoop-yarn-client-2.6.4.jar +hadoop-yarn-common-2.6.4.jar +hadoop-yarn-server-common-2.6.4.jar +hadoop-yarn-server-web-proxy-2.6.4.jar +hk2-api-2.4.0-b34.jar +hk2-locator-2.4.0-b34.jar +hk2-utils-2.4.0-b34.jar htrace-core-3.0.4.jar -httpclient-4.3.2.jar -httpcore-4.3.2.jar +httpclient-4.5.2.jar +httpcore-4.4.4.jar ivy-2.4.0.jar -jackson-annotations-2.5.3.jar -jackson-core-2.5.3.jar +jackson-annotations-2.6.5.jar +jackson-core-2.6.5.jar jackson-core-asl-1.9.13.jar -jackson-databind-2.5.3.jar +jackson-databind-2.6.5.jar jackson-jaxrs-1.9.13.jar jackson-mapper-asl-1.9.13.jar -jackson-module-scala_2.11-2.5.3.jar +jackson-module-paranamer-2.6.5.jar +jackson-module-scala_2.11-2.6.5.jar jackson-xc-1.9.13.jar janino-2.7.8.jar java-xmlbuilder-1.0.jar +javassist-3.18.1-GA.jar +javax.annotation-api-1.2.jar javax.inject-1.jar -javax.servlet-3.0.0.v201112011016.jar +javax.inject-2.4.0-b34.jar +javax.servlet-api-3.1.0.jar +javax.ws.rs-api-2.0.1.jar javolution-5.5.1.jar jaxb-api-2.2.2.jar -jaxb-impl-2.2.3-1.jar jcl-over-slf4j-1.7.16.jar jdo-api-3.0.1.jar -jersey-client-1.9.jar -jersey-core-1.9.jar -jersey-guice-1.9.jar 
-jersey-json-1.9.jar -jersey-server-1.9.jar +jersey-client-2.22.2.jar +jersey-common-2.22.2.jar +jersey-container-servlet-2.22.2.jar +jersey-container-servlet-core-2.22.2.jar +jersey-guava-2.22.2.jar +jersey-media-jaxb-2.22.2.jar +jersey-server-2.22.2.jar jets3t-0.9.3.jar -jettison-1.1.jar jetty-6.1.26.jar jetty-util-6.1.26.jar -jline-2.12.jar +jline-2.12.1.jar joda-time-2.9.3.jar jodd-core-3.5.2.jar jpam-1.1.jar json-20090211.jar -json4s-ast_2.11-3.2.10.jar -json4s-core_2.11-3.2.10.jar -json4s-jackson_2.11-3.2.10.jar +json4s-ast_2.11-3.2.11.jar +json4s-core_2.11-3.2.11.jar +json4s-jackson_2.11-3.2.11.jar jsr305-1.3.9.jar jta-1.1.jar jtransforms-2.4.0.jar @@ -134,7 +142,8 @@ netty-all-4.0.29.Final.jar objenesis-2.1.jar opencsv-2.3.jar oro-2.0.8.jar -paranamer-2.6.jar +osgi-resource-locator-1.0.1.jar +paranamer-2.3.jar parquet-column-1.7.0.jar parquet-common-1.7.0.jar parquet-encoding-1.7.0.jar @@ -143,12 +152,11 @@ parquet-generator-1.7.0.jar parquet-hadoop-1.7.0.jar parquet-hadoop-bundle-1.6.0.jar parquet-jackson-1.7.0.jar -pmml-agent-1.2.7.jar -pmml-model-1.2.7.jar -pmml-schema-1.2.7.jar +pmml-model-1.2.15.jar +pmml-schema-1.2.15.jar protobuf-java-2.5.0.jar -py4j-0.9.2.jar -pyrolite-4.9.jar +py4j-0.10.3.jar +pyrolite-4.13.jar scala-compiler-2.11.8.jar scala-library-2.11.8.jar scala-parser-combinators_2.11-1.0.4.jar @@ -158,7 +166,7 @@ scalap-2.11.8.jar slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar snappy-0.2.jar -snappy-java-1.1.2.4.jar +snappy-java-1.1.2.6.jar spire-macros_2.11-0.7.4.jar spire_2.11-0.7.4.jar stax-api-1.0-2.jar @@ -166,7 +174,8 @@ stax-api-1.0.1.jar stream-2.7.0.jar stringtemplate-3.2.1.jar super-csv-2.2.0.jar -univocity-parsers-2.0.2.jar +univocity-parsers-2.1.1.jar +validation-api-1.1.0.Final.jar xbean-asm5-shaded-4.4.jar xercesImpl-2.9.1.jar xmlenc-0.52.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index f4274a9441a04..4f70bff405363 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -4,15 +4,15 @@ ST4-4.0.4.jar activation-1.1.1.jar antlr-2.7.7.jar antlr-runtime-3.4.jar -antlr4-runtime-4.5.2-1.jar +antlr4-runtime-4.5.3.jar aopalliance-1.0.jar +aopalliance-repackaged-2.4.0-b34.jar apache-log4j-extras-1.2.17.jar apacheds-i18n-2.0.0-M15.jar apacheds-kerberos-codec-2.0.0-M15.jar api-asn1-api-1.0.0-M20.jar api-util-1.0.0-M20.jar arpack_combined_all-0.1.jar -asm-3.1.jar avro-1.7.7.jar avro-ipc-1.7.7.jar avro-mapred-1.7.7-hadoop2.jar @@ -52,65 +52,73 @@ curator-recipes-2.6.0.jar datanucleus-api-jdo-3.2.6.jar datanucleus-core-3.2.10.jar datanucleus-rdbms-3.2.9.jar -derby-10.10.1.1.jar +derby-10.12.1.1.jar eigenbase-properties-1.1.5.jar gson-2.2.4.jar guava-14.0.1.jar guice-3.0.jar guice-servlet-3.0.jar -hadoop-annotations-2.7.0.jar -hadoop-auth-2.7.0.jar -hadoop-client-2.7.0.jar -hadoop-common-2.7.0.jar -hadoop-hdfs-2.7.0.jar -hadoop-mapreduce-client-app-2.7.0.jar -hadoop-mapreduce-client-common-2.7.0.jar -hadoop-mapreduce-client-core-2.7.0.jar -hadoop-mapreduce-client-jobclient-2.7.0.jar -hadoop-mapreduce-client-shuffle-2.7.0.jar -hadoop-yarn-api-2.7.0.jar -hadoop-yarn-client-2.7.0.jar -hadoop-yarn-common-2.7.0.jar -hadoop-yarn-server-common-2.7.0.jar -hadoop-yarn-server-web-proxy-2.7.0.jar +hadoop-annotations-2.7.3.jar +hadoop-auth-2.7.3.jar +hadoop-client-2.7.3.jar +hadoop-common-2.7.3.jar +hadoop-hdfs-2.7.3.jar +hadoop-mapreduce-client-app-2.7.3.jar +hadoop-mapreduce-client-common-2.7.3.jar +hadoop-mapreduce-client-core-2.7.3.jar +hadoop-mapreduce-client-jobclient-2.7.3.jar 
+hadoop-mapreduce-client-shuffle-2.7.3.jar +hadoop-yarn-api-2.7.3.jar +hadoop-yarn-client-2.7.3.jar +hadoop-yarn-common-2.7.3.jar +hadoop-yarn-server-common-2.7.3.jar +hadoop-yarn-server-web-proxy-2.7.3.jar +hk2-api-2.4.0-b34.jar +hk2-locator-2.4.0-b34.jar +hk2-utils-2.4.0-b34.jar htrace-core-3.1.0-incubating.jar -httpclient-4.3.2.jar -httpcore-4.3.2.jar +httpclient-4.5.2.jar +httpcore-4.4.4.jar ivy-2.4.0.jar -jackson-annotations-2.5.3.jar -jackson-core-2.5.3.jar +jackson-annotations-2.6.5.jar +jackson-core-2.6.5.jar jackson-core-asl-1.9.13.jar -jackson-databind-2.5.3.jar +jackson-databind-2.6.5.jar jackson-jaxrs-1.9.13.jar jackson-mapper-asl-1.9.13.jar -jackson-module-scala_2.11-2.5.3.jar +jackson-module-paranamer-2.6.5.jar +jackson-module-scala_2.11-2.6.5.jar jackson-xc-1.9.13.jar janino-2.7.8.jar java-xmlbuilder-1.0.jar +javassist-3.18.1-GA.jar +javax.annotation-api-1.2.jar javax.inject-1.jar -javax.servlet-3.0.0.v201112011016.jar +javax.inject-2.4.0-b34.jar +javax.servlet-api-3.1.0.jar +javax.ws.rs-api-2.0.1.jar javolution-5.5.1.jar jaxb-api-2.2.2.jar -jaxb-impl-2.2.3-1.jar jcl-over-slf4j-1.7.16.jar jdo-api-3.0.1.jar -jersey-client-1.9.jar -jersey-core-1.9.jar -jersey-guice-1.9.jar -jersey-json-1.9.jar -jersey-server-1.9.jar +jersey-client-2.22.2.jar +jersey-common-2.22.2.jar +jersey-container-servlet-2.22.2.jar +jersey-container-servlet-core-2.22.2.jar +jersey-guava-2.22.2.jar +jersey-media-jaxb-2.22.2.jar +jersey-server-2.22.2.jar jets3t-0.9.3.jar -jettison-1.1.jar jetty-6.1.26.jar jetty-util-6.1.26.jar -jline-2.12.jar +jline-2.12.1.jar joda-time-2.9.3.jar jodd-core-3.5.2.jar jpam-1.1.jar json-20090211.jar -json4s-ast_2.11-3.2.10.jar -json4s-core_2.11-3.2.10.jar -json4s-jackson_2.11-3.2.10.jar +json4s-ast_2.11-3.2.11.jar +json4s-core_2.11-3.2.11.jar +json4s-jackson_2.11-3.2.11.jar jsp-api-2.1.jar jsr305-1.3.9.jar jta-1.1.jar @@ -135,7 +143,8 @@ netty-all-4.0.29.Final.jar objenesis-2.1.jar opencsv-2.3.jar oro-2.0.8.jar -paranamer-2.6.jar +osgi-resource-locator-1.0.1.jar +paranamer-2.3.jar parquet-column-1.7.0.jar parquet-common-1.7.0.jar parquet-encoding-1.7.0.jar @@ -144,12 +153,11 @@ parquet-generator-1.7.0.jar parquet-hadoop-1.7.0.jar parquet-hadoop-bundle-1.6.0.jar parquet-jackson-1.7.0.jar -pmml-agent-1.2.7.jar -pmml-model-1.2.7.jar -pmml-schema-1.2.7.jar +pmml-model-1.2.15.jar +pmml-schema-1.2.15.jar protobuf-java-2.5.0.jar -py4j-0.9.2.jar -pyrolite-4.9.jar +py4j-0.10.3.jar +pyrolite-4.13.jar scala-compiler-2.11.8.jar scala-library-2.11.8.jar scala-parser-combinators_2.11-1.0.4.jar @@ -159,7 +167,7 @@ scalap-2.11.8.jar slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar snappy-0.2.jar -snappy-java-1.1.2.4.jar +snappy-java-1.1.2.6.jar spire-macros_2.11-0.7.4.jar spire_2.11-0.7.4.jar stax-api-1.0-2.jar @@ -167,7 +175,8 @@ stax-api-1.0.1.jar stream-2.7.0.jar stringtemplate-3.2.1.jar super-csv-2.2.0.jar -univocity-parsers-2.0.2.jar +univocity-parsers-2.1.1.jar +validation-api-1.1.0.Final.jar xbean-asm5-shaded-4.4.jar xercesImpl-2.9.1.jar xmlenc-0.52.jar diff --git a/dev/make-distribution.sh b/dev/make-distribution.sh index 4f7544f6ea78b..9be4fdfa51c93 100755 --- a/dev/make-distribution.sh +++ b/dev/make-distribution.sh @@ -53,7 +53,7 @@ while (( "$#" )); do --hadoop) echo "Error: '--hadoop' is no longer supported:" echo "Error: use Maven profiles and options -Dhadoop.version and -Dyarn.version instead." - echo "Error: Related profiles include hadoop-2.2, hadoop-2.3 and hadoop-2.4." + echo "Error: Related profiles include hadoop-2.2, hadoop-2.3, hadoop-2.4, hadoop-2.6 and hadoop-2.7." 
exit_with_usage ;; --with-yarn) @@ -150,7 +150,7 @@ export MAVEN_OPTS="${MAVEN_OPTS:--Xmx2g -XX:MaxPermSize=512M -XX:ReservedCodeCac # Store the command as an array because $MVN variable might have spaces in it. # Normal quoting tricks don't work. # See: http://mywiki.wooledge.org/BashFAQ/050 -BUILD_COMMAND=("$MVN" clean package -DskipTests $@) +BUILD_COMMAND=("$MVN" -T 1C clean package -DskipTests $@) # Actually build the jar echo -e "\nBuilding with..." diff --git a/dev/run-tests.py b/dev/run-tests.py index 291f821c7f80e..ad9d4ac7fa3c2 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -110,7 +110,7 @@ def determine_modules_to_test(changed_modules): ['graphx', 'examples'] >>> x = [x.name for x in determine_modules_to_test([modules.sql])] >>> x # doctest: +NORMALIZE_WHITESPACE - ['sql', 'hive', 'mllib', 'examples', 'hive-thriftserver', 'hivecontext-compatibility', + ['sql', 'hive', 'mllib', 'sql-kafka-0-10', 'examples', 'hive-thriftserver', 'pyspark-sql', 'sparkr', 'pyspark-mllib', 'pyspark-ml'] """ modules_to_test = set() @@ -294,7 +294,7 @@ def exec_sbt(sbt_args=()): print(line, end='') retcode = sbt_proc.wait() - if retcode > 0: + if retcode != 0: exit_from_command_with_retcode(sbt_cmd, retcode) @@ -335,8 +335,8 @@ def build_spark_maven(hadoop_version): def build_spark_sbt(hadoop_version): # Enable all of the profiles for the build: build_profiles = get_hadoop_profiles(hadoop_version) + modules.root.build_profile_flags - sbt_goals = ["package", - "streaming-kafka-assembly/assembly", + sbt_goals = ["test:package", # Build test jars as some tests depend on them + "streaming-kafka-0-8-assembly/assembly", "streaming-flume-assembly/assembly", "streaming-kinesis-asl-assembly/assembly"] profiles_and_goals = build_profiles + sbt_goals diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 5640928643952..c5d1a07782221 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -92,10 +92,17 @@ def __ne__(self, other): def __hash__(self): return hash(self.name) +tags = Module( + name="tags", + dependencies=[], + source_file_regexes=[ + "common/tags/", + ] +) catalyst = Module( name="catalyst", - dependencies=[], + dependencies=[tags], source_file_regexes=[ "sql/catalyst/", ], @@ -151,21 +158,21 @@ def __hash__(self): ) -hivecontext_compatibility = Module( - name="hivecontext-compatibility", - dependencies=[hive], +sql_kafka = Module( + name="sql-kafka-0-10", + dependencies=[sql], source_file_regexes=[ - "sql/hivecontext-compatibility/", + "external/kafka-0-10-sql", ], sbt_test_goals=[ - "hivecontext-compatibility/test" + "sql-kafka-0-10/test", ] ) sketch = Module( name="sketch", - dependencies=[], + dependencies=[tags], source_file_regexes=[ "common/sketch/", ], @@ -177,7 +184,7 @@ def __hash__(self): graphx = Module( name="graphx", - dependencies=[], + dependencies=[tags], source_file_regexes=[ "graphx/", ], @@ -189,7 +196,7 @@ def __hash__(self): streaming = Module( name="streaming", - dependencies=[], + dependencies=[tags], source_file_regexes=[ "streaming", ], @@ -205,7 +212,7 @@ def __hash__(self): # fail other PRs. 
streaming_kinesis_asl = Module( name="streaming-kinesis-asl", - dependencies=[], + dependencies=[tags], source_file_regexes=[ "external/kinesis-asl/", "external/kinesis-asl-assembly/", @@ -223,17 +230,28 @@ def __hash__(self): streaming_kafka = Module( - name="streaming-kafka", + name="streaming-kafka-0-8", dependencies=[streaming], source_file_regexes=[ - "external/kafka", - "external/kafka-assembly", + "external/kafka-0-8", + "external/kafka-0-8-assembly", ], sbt_test_goals=[ - "streaming-kafka/test", + "streaming-kafka-0-8/test", ] ) +streaming_kafka_0_10 = Module( + name="streaming-kafka-0-10", + dependencies=[streaming], + source_file_regexes=[ + "external/kafka-0-10", + "external/kafka-0-10-assembly", + ], + sbt_test_goals=[ + "streaming-kafka-0-10/test", + ] +) streaming_flume_sink = Module( name="streaming-flume-sink", @@ -270,7 +288,7 @@ def __hash__(self): mllib_local = Module( name="mllib-local", - dependencies=[], + dependencies=[tags], source_file_regexes=[ "mllib-local", ], @@ -342,6 +360,7 @@ def __hash__(self): "pyspark.sql.group", "pyspark.sql.functions", "pyspark.sql.readwriter", + "pyspark.sql.streaming", "pyspark.sql.window", "pyspark.sql.tests", ] @@ -406,6 +425,7 @@ def __hash__(self): "pyspark.ml.feature", "pyspark.ml.classification", "pyspark.ml.clustering", + "pyspark.ml.linalg.__init__", "pyspark.ml.recommendation", "pyspark.ml.regression", "pyspark.ml.tuning", diff --git a/dev/sparktestsupport/shellutils.py b/dev/sparktestsupport/shellutils.py index d280e797077d1..05af87189b18d 100644 --- a/dev/sparktestsupport/shellutils.py +++ b/dev/sparktestsupport/shellutils.py @@ -53,7 +53,10 @@ def subprocess_check_call(*popenargs, **kwargs): def exit_from_command_with_retcode(cmd, retcode): - print("[error] running", ' '.join(cmd), "; received return code", retcode) + if retcode < 0: + print("[error] running", ' '.join(cmd), "; process was terminated by signal", -retcode) + else: + print("[error] running", ' '.join(cmd), "; received return code", retcode) sys.exit(int(os.environ.get("CURRENT_BLOCK", 255))) diff --git a/dev/test-dependencies.sh b/dev/test-dependencies.sh index 5ea643ee48d98..28e3d4d8d4f00 100755 --- a/dev/test-dependencies.sh +++ b/dev/test-dependencies.sh @@ -79,7 +79,7 @@ for HADOOP_PROFILE in "${HADOOP_PROFILES[@]}"; do echo "Generating dependency manifest for $HADOOP_PROFILE" mkdir -p dev/pr-deps $MVN $HADOOP2_MODULE_PROFILES -P$HADOOP_PROFILE dependency:build-classpath -pl assembly \ - | grep "Building Spark Project Assembly" -A 5 \ + | grep "Dependencies classpath:" -A 1 \ | tail -n 1 | tr ":" "\n" | rev | cut -d "/" -f 1 | rev | sort \ | grep -v spark > dev/pr-deps/spark-deps-$HADOOP_PROFILE done diff --git a/docs/README.md b/docs/README.md index bcea93e1f3b6d..ffd3b5712b618 100644 --- a/docs/README.md +++ b/docs/README.md @@ -19,9 +19,11 @@ installed. 
Also install the following libraries: $ sudo gem install jekyll jekyll-redirect-from pygments.rb $ sudo pip install Pygments # Following is needed only for generating API docs - $ sudo pip install sphinx - $ Rscript -e 'install.packages(c("knitr", "devtools"), repos="http://cran.stat.ucla.edu/")' + $ sudo pip install sphinx pypandoc + $ sudo Rscript -e 'install.packages(c("knitr", "devtools", "roxygen2", "testthat", "rmarkdown"), repos="http://cran.stat.ucla.edu/")' ``` +(Note: If you are on a system with both Ruby 1.9 and Ruby 2.0 you may need to replace gem with gem2.0) + ## Generating the Documentation HTML We include the Spark documentation as part of the source (as opposed to using a hosted wiki, such as diff --git a/docs/_config.yml b/docs/_config.yml index 8bdc68aeeac7f..c94dc54c0fd02 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -14,8 +14,8 @@ include: # These allow the documentation to be updated with newer releases # of Spark, Scala, and Mesos. -SPARK_VERSION: 2.0.0-SNAPSHOT -SPARK_VERSION_SHORT: 2.0.0 +SPARK_VERSION: 2.0.3-SNAPSHOT +SPARK_VERSION_SHORT: 2.0.3 SCALA_BINARY_VERSION: "2.11" SCALA_VERSION: "2.11.7" MESOS_VERSION: 0.21.0 diff --git a/docs/_data/menu-ml.yaml b/docs/_data/menu-ml.yaml index 3fd3ee2823f75..0c6b9b20a6e4b 100644 --- a/docs/_data/menu-ml.yaml +++ b/docs/_data/menu-ml.yaml @@ -1,5 +1,5 @@ -- text: "Overview: estimators, transformers and pipelines" - url: ml-guide.html +- text: Pipelines + url: ml-pipeline.html - text: Extracting, transforming and selecting features url: ml-features.html - text: Classification and Regression @@ -8,5 +8,7 @@ url: ml-clustering.html - text: Collaborative filtering url: ml-collaborative-filtering.html +- text: Model selection and tuning + url: ml-tuning.html - text: Advanced topics url: ml-advanced.html diff --git a/docs/_includes/nav-left-wrapper-ml.html b/docs/_includes/nav-left-wrapper-ml.html index e2d7eda027c6e..00ac6cc0dbc7d 100644 --- a/docs/_includes/nav-left-wrapper-ml.html +++ b/docs/_includes/nav-left-wrapper-ml.html @@ -1,8 +1,8 @@
-        spark.ml package
+        MLlib: Main Guide
         {% include nav-left.html nav=include.nav-ml %}
-        spark.mllib package
+        MLlib: RDD-based API Guide
         {% include nav-left.html nav=include.nav-mllib %}
    \ No newline at end of file diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index d493f62f0e578..ad5b5c9adfac8 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -73,7 +73,8 @@
     Spark Streaming
     DataFrames, Datasets and SQL
-    MLlib (Machine Learning)
+    Structured Streaming
+    MLlib (Machine Learning)
     GraphX (Graph Processing)
     SparkR (R on Spark)
@@ -113,7 +114,7 @@
     Building Spark
     Contributing to Spark
-    Supplemental Projects
+    Third Party Projects
  • diff --git a/docs/_plugins/include_example.rb b/docs/_plugins/include_example.rb index f7485826a762d..6ea1d438f529e 100644 --- a/docs/_plugins/include_example.rb +++ b/docs/_plugins/include_example.rb @@ -32,16 +32,34 @@ def render(context) @code_dir = File.join(site.source, config_dir) clean_markup = @markup.strip - @file = File.join(@code_dir, clean_markup) - @lang = clean_markup.split('.').last - code = File.open(@file).read.encode("UTF-8") + parts = clean_markup.strip.split(' ') + if parts.length > 1 then + @snippet_label = ':' + parts[0] + snippet_file = parts[1] + else + @snippet_label = '' + snippet_file = parts[0] + end + + @file = File.join(@code_dir, snippet_file) + @lang = snippet_file.split('.').last + + begin + code = File.open(@file).read.encode("UTF-8") + rescue => e + # We need to explicitly exit on execptions here because Jekyll will silently swallow + # them, leading to silent build failures (see https://github.com/jekyll/jekyll/issues/5104) + puts(e) + puts(e.backtrace) + exit 1 + end code = select_lines(code) rendered_code = Pygments.highlight(code, :lexer => @lang) hint = "
    Find full example code at " \ - "\"examples/src/main/#{clean_markup}\" in the Spark repo.
    " + "\"examples/src/main/#{snippet_file}\" in the Spark repo." rendered_code + hint end @@ -66,13 +84,13 @@ def select_lines(code) # Select the array of start labels from code. startIndices = lines .each_with_index - .select { |l, i| l.include? "$example on$" } + .select { |l, i| l.include? "$example on#{@snippet_label}$" } .map { |l, i| i } # Select the array of end labels from code. endIndices = lines .each_with_index - .select { |l, i| l.include? "$example off$" } + .select { |l, i| l.include? "$example off#{@snippet_label}$" } .map { |l, i| i } raise "Start indices amount is not equal to end indices amount, see #{@file}." \ @@ -92,7 +110,10 @@ def select_lines(code) if start == endline lastIndex = endline range = Range.new(start + 1, endline - 1) - result += trim_codeblock(lines[range]).join + trimmed = trim_codeblock(lines[range]) + # Filter out possible example tags of overlapped labels. + taggs_filtered = trimmed.select { |l| !l.include? '$example ' } + result += taggs_filtered.join result += "\n" end result diff --git a/docs/building-spark.md b/docs/building-spark.md index fec442af95e1b..857f4d4fcbf9a 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -7,65 +7,62 @@ redirect_from: "building-with-maven.html" * This will become a table of contents (this text will be scraped). {:toc} -Building Spark using Maven requires Maven 3.3.9 or newer and Java 7+. -The Spark build can supply a suitable Maven binary; see below. +# Building Apache Spark -# Building with `build/mvn` +## Apache Maven -Spark now comes packaged with a self-contained Maven installation to ease building and deployment of Spark from source located under the `build/` directory. This script will automatically download and setup all necessary build requirements ([Maven](https://maven.apache.org/), [Scala](http://www.scala-lang.org/), and [Zinc](https://github.com/typesafehub/zinc)) locally within the `build/` directory itself. It honors any `mvn` binary if present already, however, will pull down its own copy of Scala and Zinc regardless to ensure proper version requirements are met. `build/mvn` execution acts as a pass through to the `mvn` call allowing easy transition from previous build methods. As an example, one can build a version of Spark as follows: +The Maven-based build is the build of reference for Apache Spark. +Building Spark using Maven requires Maven 3.3.9 or newer and Java 7+. -{% highlight bash %} -build/mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -DskipTests clean package -{% endhighlight %} +### Setting up Maven's Memory Usage -Other build examples can be found below. +You'll need to configure Maven to use more memory than usual by setting `MAVEN_OPTS`: -**Note:** When building on an encrypted filesystem (if your home directory is encrypted, for example), then the Spark build might fail with a "Filename too long" error. As a workaround, add the following in the configuration args of the `scala-maven-plugin` in the project `pom.xml`: + export MAVEN_OPTS="-Xmx2g -XX:ReservedCodeCacheSize=512m" - -Xmax-classfile-name - 128 +When compiling with Java 7, you will need to add the additional option "-XX:MaxPermSize=512M" to MAVEN_OPTS. -and in `project/SparkBuild.scala` add: +If you don't add these parameters to `MAVEN_OPTS`, you may see errors and warnings like the following: - scalacOptions in Compile ++= Seq("-Xmax-classfile-name", "128"), + [INFO] Compiling 203 Scala sources and 9 Java sources to /Users/me/Development/spark/core/target/scala-{{site.SCALA_BINARY_VERSION}}/classes... 
+ [ERROR] PermGen space -> [Help 1] -to the `sharedSettings` val. See also [this PR](https://github.com/apache/spark/pull/2883/files) if you are unsure of where to add these lines. + [INFO] Compiling 203 Scala sources and 9 Java sources to /Users/me/Development/spark/core/target/scala-{{site.SCALA_BINARY_VERSION}}/classes... + [ERROR] Java heap space -> [Help 1] -# Building a Runnable Distribution + [INFO] Compiling 233 Scala sources and 41 Java sources to /Users/me/Development/spark/sql/core/target/scala-{site.SCALA_BINARY_VERSION}/classes... + OpenJDK 64-Bit Server VM warning: CodeCache is full. Compiler has been disabled. + OpenJDK 64-Bit Server VM warning: Try increasing the code cache size using -XX:ReservedCodeCacheSize= -To create a Spark distribution like those distributed by the -[Spark Downloads](http://spark.apache.org/downloads.html) page, and that is laid out so as -to be runnable, use `./dev/make-distribution.sh` in the project root directory. It can be configured -with Maven profile settings and so on like the direct Maven build. Example: - - ./dev/make-distribution.sh --name custom-spark --tgz -Psparkr -Phadoop-2.4 -Phive -Phive-thriftserver -Pyarn +You can fix these problems by setting the `MAVEN_OPTS` variable as discussed before. -For more information on usage, run `./dev/make-distribution.sh --help` +**Note:** -# Setting up Maven's Memory Usage +* If using `build/mvn` with no `MAVEN_OPTS` set, the script will automatically add the above options to the `MAVEN_OPTS` environment variable. +* The `test` phase of the Spark build will automatically add these options to `MAVEN_OPTS`, even when not using `build/mvn`. +* You may see warnings like "ignoring option MaxPermSize=1g; support was removed in 8.0" when building or running tests with Java 8 and `build/mvn`. These warnings are harmless. + -You'll need to configure Maven to use more memory than usual by setting `MAVEN_OPTS`. We recommend the following settings: +### build/mvn -{% highlight bash %} -export MAVEN_OPTS="-Xmx2g -XX:MaxPermSize=512M -XX:ReservedCodeCacheSize=512m" -{% endhighlight %} +Spark now comes packaged with a self-contained Maven installation to ease building and deployment of Spark from source located under the `build/` directory. This script will automatically download and setup all necessary build requirements ([Maven](https://maven.apache.org/), [Scala](http://www.scala-lang.org/), and [Zinc](https://github.com/typesafehub/zinc)) locally within the `build/` directory itself. It honors any `mvn` binary if present already, however, will pull down its own copy of Scala and Zinc regardless to ensure proper version requirements are met. `build/mvn` execution acts as a pass through to the `mvn` call allowing easy transition from previous build methods. As an example, one can build a version of Spark as follows: -If you don't run this, you may see errors like the following: + ./build/mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -DskipTests clean package - [INFO] Compiling 203 Scala sources and 9 Java sources to /Users/me/Development/spark/core/target/scala-{{site.SCALA_BINARY_VERSION}}/classes... - [ERROR] PermGen space -> [Help 1] +Other build examples can be found below. - [INFO] Compiling 203 Scala sources and 9 Java sources to /Users/me/Development/spark/core/target/scala-{{site.SCALA_BINARY_VERSION}}/classes... - [ERROR] Java heap space -> [Help 1] +## Building a Runnable Distribution -You can fix this by setting the `MAVEN_OPTS` variable as discussed before. 
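For example, a minimal working setup (a sketch combining the options above; on Java 8 the `-XX:MaxPermSize` flag is simply ignored with a harmless warning) would be:

    # Recommended heap and code-cache sizes, plus the Java 7 PermGen option
    export MAVEN_OPTS="-Xmx2g -XX:ReservedCodeCacheSize=512m -XX:MaxPermSize=512M"
    ./build/mvn -DskipTests clean package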
+To create a Spark distribution like those distributed by the +[Spark Downloads](http://spark.apache.org/downloads.html) page, and that is laid out so as +to be runnable, use `./dev/make-distribution.sh` in the project root directory. It can be configured +with Maven profile settings and so on like the direct Maven build. Example: -**Note:** + ./dev/make-distribution.sh --name custom-spark --tgz -Psparkr -Phadoop-2.4 -Phive -Phive-thriftserver -Pyarn -* For Java 8 and above this step is not required. -* If using `build/mvn` with no `MAVEN_OPTS` set, the script will automate this for you. +For more information on usage, run `./dev/make-distribution.sh --help` -# Specifying the Hadoop Version +## Specifying the Hadoop Version Because HDFS is not protocol-compatible across versions, if you want to read from HDFS, you'll need to build Spark against the specific HDFS version in your environment. You can do this through the `hadoop.version` property. If unset, Spark will build against Hadoop 2.2.0 by default. Note that certain build profiles are required for particular Hadoop versions: @@ -87,72 +84,63 @@ You can enable the `yarn` profile and optionally set the `yarn.version` property Examples: -{% highlight bash %} + # Apache Hadoop 2.2.X + ./build/mvn -Pyarn -Phadoop-2.2 -DskipTests clean package -# Apache Hadoop 2.2.X -mvn -Pyarn -Phadoop-2.2 -DskipTests clean package + # Apache Hadoop 2.3.X + ./build/mvn -Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0 -DskipTests clean package -# Apache Hadoop 2.3.X -mvn -Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0 -DskipTests clean package + # Apache Hadoop 2.4.X or 2.5.X + ./build/mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -DskipTests clean package -# Apache Hadoop 2.4.X or 2.5.X -mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=VERSION -DskipTests clean package + # Apache Hadoop 2.6.X + ./build/mvn -Pyarn -Phadoop-2.6 -Dhadoop.version=2.6.0 -DskipTests clean package -# Apache Hadoop 2.6.X -mvn -Pyarn -Phadoop-2.6 -Dhadoop.version=2.6.0 -DskipTests clean package + # Apache Hadoop 2.7.X and later + ./build/mvn -Pyarn -Phadoop-2.7 -Dhadoop.version=2.7.0 -DskipTests clean package -# Apache Hadoop 2.7.X and later -mvn -Pyarn -Phadoop-2.7 -Dhadoop.version=VERSION -DskipTests clean package + # Different versions of HDFS and YARN. + ./build/mvn -Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0 -Dyarn.version=2.2.0 -DskipTests clean package -# Different versions of HDFS and YARN. -mvn -Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0 -Dyarn.version=2.2.0 -DskipTests clean package -{% endhighlight %} +## Building With Hive and JDBC Support -# Building With Hive and JDBC Support To enable Hive integration for Spark SQL along with its JDBC server and CLI, add the `-Phive` and `Phive-thriftserver` profiles to your existing build options. By default Spark will build with Hive 1.2.1 bindings. -{% highlight bash %} -# Apache Hadoop 2.4.X with Hive 1.2.1 support -mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -Phive -Phive-thriftserver -DskipTests clean package -{% endhighlight %} - -# Building for Scala 2.10 -To produce a Spark package compiled with Scala 2.10, use the `-Dscala-2.10` property: - - ./dev/change-scala-version.sh 2.10 - mvn -Pyarn -Phadoop-2.4 -Dscala-2.10 -DskipTests clean package - -# Spark Tests in Maven -Tests are run by default via the [ScalaTest Maven plugin](http://www.scalatest.org/user_guide/using_the_scalatest_maven_plugin). 
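Putting the pieces above together, a Hive-enabled build against a concrete Hadoop 2.6 release (the exact point version is illustrative; any 2.6.x works with the `hadoop-2.6` profile) looks like:

    ./build/mvn -Pyarn -Phadoop-2.6 -Dhadoop.version=2.6.4 -Phive -Phive-thriftserver -DskipTests clean package

The same profile flags can be passed to `./dev/make-distribution.sh` to produce a runnable, Hive-enabled distribution.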
+ # Apache Hadoop 2.4.X with Hive 1.2.1 support + ./build/mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -Phive -Phive-thriftserver -DskipTests clean package -Some of the tests require Spark to be packaged first, so always run `mvn package` with `-DskipTests` the first time. The following is an example of a correct (build, test) sequence: +## Packaging without Hadoop Dependencies for YARN - mvn -Pyarn -Phadoop-2.3 -DskipTests -Phive -Phive-thriftserver clean package - mvn -Pyarn -Phadoop-2.3 -Phive -Phive-thriftserver test +The assembly directory produced by `mvn package` will, by default, include all of Spark's +dependencies, including Hadoop and some of its ecosystem projects. On YARN deployments, this +causes multiple versions of these to appear on executor classpaths: the version packaged in +the Spark assembly and the version on each node, included with `yarn.application.classpath`. +The `hadoop-provided` profile builds the assembly without including Hadoop-ecosystem projects, +like ZooKeeper and Hadoop itself. -The ScalaTest plugin also supports running only a specific test suite as follows: +## Building for Scala 2.10 +To produce a Spark package compiled with Scala 2.10, use the `-Dscala-2.10` property: - mvn -Dhadoop.version=... -DwildcardSuites=org.apache.spark.repl.ReplSuite test + ./dev/change-scala-version.sh 2.10 + ./build/mvn -Pyarn -Phadoop-2.4 -Dscala-2.10 -DskipTests clean package -# Building submodules individually +## Building submodules individually It's possible to build Spark sub-modules using the `mvn -pl` option. For instance, you can build the Spark Streaming module using: -{% highlight bash %} -mvn -pl :spark-streaming_2.11 clean install -{% endhighlight %} + ./build/mvn -pl :spark-streaming_2.11 clean install where `spark-streaming_2.11` is the `artifactId` as defined in `streaming/pom.xml` file. -# Continuous Compilation +## Continuous Compilation We use the scala-maven-plugin which supports incremental and continuous compilation. E.g. - mvn scala:cc + ./build/mvn scala:cc should run continuous compilation (i.e. wait for changes). However, this has not been tested extensively. A couple of gotchas to note: @@ -167,34 +155,25 @@ the `spark-parent` module). Thus, the full flow for running continuous-compilation of the `core` submodule may look more like: - $ mvn install + $ ./build/mvn install $ cd core - $ mvn scala:cc - -# Building Spark with IntelliJ IDEA or Eclipse - -For help in setting up IntelliJ IDEA or Eclipse for Spark development, and troubleshooting, refer to the -[wiki page for IDE setup](https://cwiki.apache.org/confluence/display/SPARK/Useful+Developer+Tools#UsefulDeveloperTools-IDESetup). + $ ../build/mvn scala:cc -# Running Java 8 Test Suites +## Speeding up Compilation with Zinc -Running only Java 8 tests and nothing else. - - mvn install -DskipTests - mvn -pl :java8-tests_2.11 test - -or - - sbt java8-tests/test - -Java 8 tests are automatically enabled when a Java 8 JDK is detected. -If you have JDK 8 installed but it is not the system default, you can set JAVA_HOME to point to JDK 8 before running the tests. - -# Packaging without Hadoop Dependencies for YARN +[Zinc](https://github.com/typesafehub/zinc) is a long-running server version of SBT's incremental +compiler. When run locally as a background process, it speeds up builds of Scala-based projects +like Spark. Developers who regularly recompile Spark with Maven will be the most interested in +Zinc. 
The project site gives instructions for building and running `zinc`; OS X users can +install it using `brew install zinc`. -The assembly directory produced by `mvn package` will, by default, include all of Spark's dependencies, including Hadoop and some of its ecosystem projects. On YARN deployments, this causes multiple versions of these to appear on executor classpaths: the version packaged in the Spark assembly and the version on each node, included with `yarn.application.classpath`. The `hadoop-provided` profile builds the assembly without including Hadoop-ecosystem projects, like ZooKeeper and Hadoop itself. +If using the `build/mvn` package `zinc` will automatically be downloaded and leveraged for all +builds. This process will auto-start after the first time `build/mvn` is called and bind to port +3030 unless the `ZINC_PORT` environment variable is set. The `zinc` process can subsequently be +shut down at any time by running `build/zinc-/bin/zinc -shutdown` and will automatically +restart whenever `build/mvn` is called. -# Building with SBT +## Building with SBT Maven is the official build tool recommended for packaging Spark, and is the *build of reference*. But SBT is supported for day-to-day development since it can provide much faster iterative @@ -203,38 +182,111 @@ compilation. More advanced developers may wish to use SBT. The SBT build is derived from the Maven POM files, and so the same Maven profiles and variables can be set to control the SBT build. For example: - build/sbt -Pyarn -Phadoop-2.3 package + ./build/sbt -Pyarn -Phadoop-2.3 package To avoid the overhead of launching sbt each time you need to re-compile, you can launch sbt in interactive mode by running `build/sbt`, and then run all build commands at the command prompt. For more recommendations on reducing build time, refer to the [wiki page](https://cwiki.apache.org/confluence/display/SPARK/Useful+Developer+Tools#UsefulDeveloperTools-ReducingBuildTimes). -# Testing with SBT +## Encrypted Filesystems + +When building on an encrypted filesystem (if your home directory is encrypted, for example), then the Spark build might fail with a "Filename too long" error. As a workaround, add the following in the configuration args of the `scala-maven-plugin` in the project `pom.xml`: + + -Xmax-classfile-name + 128 + +and in `project/SparkBuild.scala` add: + + scalacOptions in Compile ++= Seq("-Xmax-classfile-name", "128"), + +to the `sharedSettings` val. See also [this PR](https://github.com/apache/spark/pull/2883/files) if you are unsure of where to add these lines. + +## IntelliJ IDEA or Eclipse + +For help in setting up IntelliJ IDEA or Eclipse for Spark development, and troubleshooting, refer to the +[wiki page for IDE setup](https://cwiki.apache.org/confluence/display/SPARK/Useful+Developer+Tools#UsefulDeveloperTools-IDESetup). + + +# Running Tests + +Tests are run by default via the [ScalaTest Maven plugin](http://www.scalatest.org/user_guide/using_the_scalatest_maven_plugin). + +Some of the tests require Spark to be packaged first, so always run `mvn package` with `-DskipTests` the first time. The following is an example of a correct (build, test) sequence: + + ./build/mvn -Pyarn -Phadoop-2.3 -DskipTests -Phive -Phive-thriftserver clean package + ./build/mvn -Pyarn -Phadoop-2.3 -Phive -Phive-thriftserver test + +The ScalaTest plugin also supports running only a specific Scala test suite as follows: + + ./build/mvn -P... -Dtest=none -DwildcardSuites=org.apache.spark.repl.ReplSuite test + ./build/mvn -P... 
-Dtest=none -DwildcardSuites=org.apache.spark.repl.* test + +or a Java test: + + ./build/mvn test -P... -DwildcardSuites=none -Dtest=org.apache.spark.streaming.JavaAPISuite + +## Testing with SBT Some of the tests require Spark to be packaged first, so always run `build/sbt package` the first time. The following is an example of a correct (build, test) sequence: - build/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-thriftserver package - build/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-thriftserver test + ./build/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-thriftserver package + ./build/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-thriftserver test To run only a specific test suite as follows: - build/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-thriftserver "test-only org.apache.spark.repl.ReplSuite" + ./build/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-thriftserver "test-only org.apache.spark.repl.ReplSuite" + ./build/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-thriftserver "test-only org.apache.spark.repl.*" To run test suites of a specific sub project as follows: - build/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-thriftserver core/test + ./build/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-thriftserver core/test -# Speeding up Compilation with Zinc +## Running Java 8 Test Suites -[Zinc](https://github.com/typesafehub/zinc) is a long-running server version of SBT's incremental -compiler. When run locally as a background process, it speeds up builds of Scala-based projects -like Spark. Developers who regularly recompile Spark with Maven will be the most interested in -Zinc. The project site gives instructions for building and running `zinc`; OS X users can -install it using `brew install zinc`. +Running only Java 8 tests and nothing else. -If using the `build/mvn` package `zinc` will automatically be downloaded and leveraged for all -builds. This process will auto-start after the first time `build/mvn` is called and bind to port -3030 unless the `ZINC_PORT` environment variable is set. The `zinc` process can subsequently be -shut down at any time by running `build/zinc-/bin/zinc -shutdown` and will automatically -restart whenever `build/mvn` is called. + ./build/mvn install -DskipTests + ./build/mvn -pl :java8-tests_2.11 test + +or + + ./build/sbt java8-tests/test + +Java 8 tests are automatically enabled when a Java 8 JDK is detected. +If you have JDK 8 installed but it is not the system default, you can set JAVA_HOME to point to JDK 8 before running the tests. + +## PySpark Tests with Maven + +If you are building PySpark and wish to run the PySpark tests you will need to build Spark with Hive support. + + ./build/mvn -DskipTests clean package -Phive + ./python/run-tests + +The run-tests script also can be limited to a specific Python version or a specific module + + ./python/run-tests --python-executables=python --modules=pyspark-sql + +**Note:** You can also run Python tests with an sbt build, provided you build Spark with Hive support. + +## Running R Tests + +To run the SparkR tests you will need to install the R package `testthat` +(run `install.packages(testthat)` from R shell). You can run just the SparkR tests using +the command: + + ./R/run-tests.sh + +## Running Docker-based Integration Test Suites + +In order to run Docker integration tests, you have to install the `docker` engine on your box. +The instructions for installation can be found at [the Docker site](https://docs.docker.com/engine/installation/). +Once installed, the `docker` service needs to be started, if not already running. 
+On Linux, this can be done by `sudo service docker start`. + + ./build/mvn install -DskipTests + ./build/mvn -Pdocker-integration-tests -pl :spark-docker-integration-tests_2.11 + +or + + ./build/sbt docker-integration-tests/test diff --git a/docs/configuration.md b/docs/configuration.md index 6512e16faf4c1..d9eddf9ba59b3 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -123,6 +123,7 @@ of the most common options to set are: Number of cores to use for the driver process, only in cluster mode.
    spark.driver.maxResultSize 1g @@ -217,7 +218,7 @@ Apart from these, the following properties are also available, and may be useful
    Note: In client mode, this config must not be set through the SparkConf directly in your application, because the driver JVM has already started at that point. Instead, please set this through the --driver-class-path command line option or in - your default properties file.
+<tr>
+  <td><code>spark.executor.logs.rolling.enableCompression</code></td>
+  <td>false</td>
+  <td>
+    Enable executor log compression. If it is enabled, the rolled executor logs will be compressed.
+    Disabled by default.
+  </td>
+</tr>
 <tr>
   <td><code>spark.executor.logs.rolling.maxSize</code></td>
   <td>(none)</td>
   <td>
-    Set the max size of the file by which the executor logs will be rolled over.
+    Set the max size of the file in bytes by which the executor logs will be rolled over.
     Rolling is disabled by default. See <code>spark.executor.logs.rolling.maxRetainedFiles</code>
     for automatic cleaning of old logs.
   </td>
 </tr>
+<tr>
+  <td><code>spark.files</code></td>
+  <td></td>
+  <td>
+    Comma-separated list of files to be placed in the working directory of each executor.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.submit.pyFiles</code></td>
+  <td></td>
+  <td>
+    Comma-separated list of .zip, .egg, or .py files to place on the PYTHONPATH for Python apps.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.jars</code></td>
+  <td></td>
+  <td>
+    Comma-separated list of local jars to include on the driver and executor classpaths.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.jars.packages</code></td>
+  <td></td>
+  <td>
+    Comma-separated list of maven coordinates of jars to include on the driver and executor
+    classpaths. Will search the local maven repo, then maven central and any additional remote
+    repositories given by <code>spark.jars.ivy</code>. The format for the coordinates should be
+    groupId:artifactId:version.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.jars.excludes</code></td>
+  <td></td>
+  <td>
+    Comma-separated list of groupId:artifactId, to exclude while resolving the dependencies
+    provided in <code>spark.jars.packages</code> to avoid dependency conflicts.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.jars.ivy</code></td>
+  <td></td>
+  <td>
+    Comma-separated list of additional remote repositories to search for the coordinates given
+    with <code>spark.jars.packages</code>.
+  </td>
+</tr>
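As a sketch of how these new dependency properties combine on a `spark-submit` command line (the coordinate, exclusion, repository URL, and application names below are illustrative only, not recommendations):

    ./bin/spark-submit \
      --conf spark.jars.packages=com.databricks:spark-csv_2.11:1.4.0 \
      --conf spark.jars.excludes=org.apache.commons:commons-csv \
      --conf spark.jars.ivy=https://repo.example.com/maven \
      --class com.example.MyApp \
      my-app.jar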
    #### Shuffle Behavior @@ -550,6 +606,14 @@ Apart from these, the following properties are also available, and may be useful collecting. + + spark.ui.retainedTasks + 100000 + + How many tasks the Spark UI and status APIs remember before garbage + collecting. + + spark.worker.ui.retainedExecutors 1000 @@ -731,14 +795,15 @@ Apart from these, the following properties are also available, and may be useful Property NameDefaultMeaning spark.memory.fraction - 0.75 + 0.6 Fraction of (heap space - 300MB) used for execution and storage. The lower this is, the more frequently spills and cached data eviction occur. The purpose of this config is to set aside memory for internal metadata, user data structures, and imprecise size estimation in the case of sparse, unusually large records. Leaving this at the default value is - recommended. For more detail, see - this description. + recommended. For more detail, including important information about correctly tuning JVM + garbage collection when increasing this value, see + this description. @@ -1126,7 +1191,7 @@ Apart from these, the following properties are also available, and may be useful spark.speculation.quantile 0.75 - Percentage of tasks which must be complete before speculation is enabled for a particular stage. + Fraction of tasks which must be complete before speculation is enabled for a particular stage. @@ -1140,7 +1205,9 @@ Apart from these, the following properties are also available, and may be useful spark.task.maxFailures 4 - Number of individual task failures before giving up on the job. + Number of failures of any particular task before giving up on the job. + The total number of failures spread across different tasks will not cause the job + to fail; a particular task has to fail this number of attempts. Should be greater than or equal to 1. Number of allowed retries = this value - 1. @@ -1154,7 +1221,7 @@ Apart from these, the following properties are also available, and may be useful false Whether to use dynamic resource allocation, which scales the number of executors registered - with this application up and down based on the workload. + with this application up and down based on the workload. For more detail, see the description here.

    @@ -1188,6 +1255,9 @@ Apart from these, the following properties are also available, and may be useful spark.dynamicAllocation.minExecutors Initial number of executors to run if dynamic allocation is enabled. +

    + If `--num-executors` (or `spark.executor.instances`) is set and larger than this value, it will + be used as the initial number of executors. @@ -1231,7 +1301,7 @@ Apart from these, the following properties are also available, and may be useful spark.acls.enable false - Whether Spark acls should are enabled. If enabled, this checks to see if the user has + Whether Spark acls should be enabled. If enabled, this checks to see if the user has access permissions to view or modify the job. Note this requires the user to be known, so if the user comes across as null no checks are done. Filters can be used with the UI to authenticate and set the user. @@ -1243,8 +1313,33 @@ Apart from these, the following properties are also available, and may be useful Comma separated list of users/administrators that have view and modify access to all Spark jobs. This can be used if you run on a shared cluster and have a set of administrators or devs who - help debug when things work. Putting a "*" in the list means any user can have the privilege - of admin. + help debug when things do not work. Putting a "*" in the list means any user can have the + privilege of admin. + + + + spark.admin.acls.groups + Empty + + Comma separated list of groups that have view and modify access to all Spark jobs. + This can be used if you have a set of administrators or developers who help maintain and debug + the underlying infrastructure. Putting a "*" in the list means any user in any group can have + the privilege of admin. The user groups are obtained from the instance of the groups mapping + provider specified by spark.user.groups.mapping. Check the entry + spark.user.groups.mapping for more details. + + + + spark.user.groups.mapping + org.apache.spark.security.ShellBasedGroupsMappingProvider + + The list of groups for a user are determined by a group mapping service defined by the trait + org.apache.spark.security.GroupMappingServiceProvider which can configured by this property. + A default unix shell based implementation is provided org.apache.spark.security.ShellBasedGroupsMappingProvider + which can be specified to resolve a list of groups for a user. + Note: This implementation supports only a Unix/Linux based environment. Windows environment is + currently not supported. However, a new platform/protocol can be supported by implementing + the trait org.apache.spark.security.GroupMappingServiceProvider. @@ -1267,8 +1362,9 @@ Apart from these, the following properties are also available, and may be useful spark.authenticate.enableSaslEncryption false - Enable encrypted communication when authentication is enabled. This option is currently - only supported by the block transfer service. + Enable encrypted communication when authentication is + enabled. This is supported by the block transfer service and the + RPC endpoints. @@ -1305,6 +1401,18 @@ Apart from these, the following properties are also available, and may be useful the list means any user can have access to modify it. + + spark.modify.acls.groups + Empty + + Comma separated list of groups that have modify access to the Spark job. This can be used if you + have a set of administrators or developers from the same team to have access to control the job. + Putting a "*" in the list means any user in any group has the access to modify the Spark job. + The user groups are obtained from the instance of the groups mapping provider specified by + spark.user.groups.mapping. Check the entry spark.user.groups.mapping + for more details. 
+ + spark.ui.filters None @@ -1328,6 +1436,18 @@ Apart from these, the following properties are also available, and may be useful have view access to this Spark job. + + spark.ui.view.acls.groups + Empty + + Comma separated list of groups that have view access to the Spark web ui to view the Spark Job + details. This can be used if you have a set of administrators or developers or users who can + monitor the Spark job submitted. Putting a "*" in the list means any user in any group can view + the Spark job details on the Spark web ui. The user groups are obtained from the instance of the + groups mapping provider specified by spark.user.groups.mapping. Check the entry + spark.user.groups.mapping for more details. + + #### Encryption @@ -1346,8 +1466,10 @@ Apart from these, the following properties are also available, and may be useful the properties must be overwritten in the protocol-specific namespace.

 Use spark.ssl.YYY.XXX settings to overwrite the global configuration for
-particular protocol denoted by YYY. Currently YYY can be
-only fs for file server.
+particular protocol denoted by YYY. Example values for YYY
+include fs, ui, standalone, and
+historyServer. See SSL
+Configuration for details on hierarchical SSL configuration for services.
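A sketch of what that hierarchy looks like in spark-defaults.conf (the paths are placeholders, and the per-namespace override shown assumes the standalone namespace documented above; check the SSL Configuration reference before copying):

    spark.ssl.enabled                true
    spark.ssl.keyStore               /path/to/keystore.jks
    # Override just the keystore for the standalone master/worker:
    spark.ssl.standalone.keyStore    /path/to/standalone-keystore.jks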

@@ -1431,6 +1553,48 @@ Apart from these, the following properties are also available, and may be useful
+#### Spark SQL
+
+Running the SET -v command will show the entire list of the SQL configuration.
+
+{% highlight scala %}
+// spark is an existing SparkSession
+spark.sql("SET -v").show(numRows = 200, truncate = false)
+{% endhighlight %}
+
+{% highlight java %}
+// spark is an existing SparkSession
+spark.sql("SET -v").show(200, false);
+{% endhighlight %}
+
+{% highlight python %}
+# spark is an existing SparkSession
+spark.sql("SET -v").show(n=200, truncate=False)
+{% endhighlight %}
+
+{% highlight r %}
+sparkR.session()
+properties <- sql("SET -v")
+showDF(properties, numRows = 200, truncate = FALSE)
+{% endhighlight %}
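The same listing can also be produced outside an interactive session through the SQL CLI (a sketch; the `-e` flag executes a quoted query, mirroring the Hive CLI):

    ./bin/spark-sql -e "SET -v"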
    + + #### Spark Streaming diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md index 9dea9b5904d2d..07b38d9cc9a8f 100644 --- a/docs/graphx-programming-guide.md +++ b/docs/graphx-programming-guide.md @@ -24,7 +24,6 @@ description: GraphX graph processing library guide for Spark SPARK_VERSION_SHORT [Graph.outerJoinVertices]: api/scala/index.html#org.apache.spark.graphx.Graph@outerJoinVertices[U,VD2](RDD[(VertexId,U)])((VertexId,VD,Option[U])⇒VD2)(ClassTag[U],ClassTag[VD2]):Graph[VD2,ED] [Graph.aggregateMessages]: api/scala/index.html#org.apache.spark.graphx.Graph@aggregateMessages[A]((EdgeContext[VD,ED,A])⇒Unit,(A,A)⇒A,TripletFields)(ClassTag[A]):VertexRDD[A] [EdgeContext]: api/scala/index.html#org.apache.spark.graphx.EdgeContext -[Graph.mapReduceTriplets]: api/scala/index.html#org.apache.spark.graphx.Graph@mapReduceTriplets[A](mapFunc:org.apache.spark.graphx.EdgeTriplet[VD,ED]=>Iterator[(org.apache.spark.graphx.VertexId,A)],reduceFunc:(A,A)=>A,activeSetOpt:Option[(org.apache.spark.graphx.VertexRDD[_],org.apache.spark.graphx.EdgeDirection)])(implicitevidence$10:scala.reflect.ClassTag[A]):org.apache.spark.graphx.VertexRDD[A] [GraphOps.collectNeighborIds]: api/scala/index.html#org.apache.spark.graphx.GraphOps@collectNeighborIds(EdgeDirection):VertexRDD[Array[VertexId]] [GraphOps.collectNeighbors]: api/scala/index.html#org.apache.spark.graphx.GraphOps@collectNeighbors(EdgeDirection):VertexRDD[Array[(VertexId,VD)]] [RDD Persistence]: programming-guide.html#rdd-persistence @@ -67,23 +66,6 @@ operators (e.g., [subgraph](#structural_operators), [joinVertices](#join_operato [aggregateMessages](#aggregateMessages)) as well as an optimized variant of the [Pregel](#pregel) API. In addition, GraphX includes a growing collection of graph [algorithms](#graph_algorithms) and [builders](#graph_builders) to simplify graph analytics tasks. - -## Migrating from Spark 1.1 - -GraphX in Spark 1.2 contains a few user facing API changes: - -1. To improve performance we have introduced a new version of -[`mapReduceTriplets`][Graph.mapReduceTriplets] called -[`aggregateMessages`][Graph.aggregateMessages] which takes the messages previously returned from -[`mapReduceTriplets`][Graph.mapReduceTriplets] through a callback ([`EdgeContext`][EdgeContext]) -rather than by return value. -We are deprecating [`mapReduceTriplets`][Graph.mapReduceTriplets] and encourage users to consult -the [transition guide](#mrTripletsTransition). - -2. In Spark 1.0 and 1.1, the type signature of [`EdgeRDD`][EdgeRDD] switched from -`EdgeRDD[ED]` to `EdgeRDD[ED, VD]` to enable some caching optimizations. We have since discovered -a more elegant solution and have restored the type signature to the more natural `EdgeRDD[ED]` type. - # Getting Started To get started you first need to import Spark and GraphX into your project, as follows: @@ -132,7 +114,7 @@ var graph: Graph[VertexProperty, String] = null Like RDDs, property graphs are immutable, distributed, and fault-tolerant. Changes to the values or structure of the graph are accomplished by producing a new graph with the desired changes. Note -that substantial parts of the original graph (i.e., unaffected structure, attributes, and indicies) +that substantial parts of the original graph (i.e., unaffected structure, attributes, and indices) are reused in the new graph reducing the cost of this inherently functional data structure. The graph is partitioned across the executors using a range of vertex partitioning heuristics. 
As with RDDs, each partition of the graph can be recreated on a different machine in the event of a failure. @@ -603,29 +585,7 @@ slightly unreliable and instead opted for more explicit user control. In the following example we use the [`aggregateMessages`][Graph.aggregateMessages] operator to compute the average age of the more senior followers of each user. -{% highlight scala %} -// Import random graph generation library -import org.apache.spark.graphx.util.GraphGenerators -// Create a graph with "age" as the vertex property. Here we use a random graph for simplicity. -val graph: Graph[Double, Int] = - GraphGenerators.logNormalGraph(sc, numVertices = 100).mapVertices( (id, _) => id.toDouble ) -// Compute the number of older followers and their total age -val olderFollowers: VertexRDD[(Int, Double)] = graph.aggregateMessages[(Int, Double)]( - triplet => { // Map Function - if (triplet.srcAttr > triplet.dstAttr) { - // Send message to destination vertex containing counter and age - triplet.sendToDst(1, triplet.srcAttr) - } - }, - // Add counter and age - (a, b) => (a._1 + b._1, a._2 + b._2) // Reduce Function -) -// Divide total age by number of older followers to get average age of older followers -val avgAgeOfOlderFollowers: VertexRDD[Double] = - olderFollowers.mapValues( (id, value) => value match { case (count, totalAge) => totalAge / count } ) -// Display the results -avgAgeOfOlderFollowers.collect.foreach(println(_)) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/graphx/AggregateMessagesExample.scala %} > The `aggregateMessages` operation performs optimally when the messages (and the sums of > messages) are constant sized (e.g., floats and addition instead of lists and concatenation). @@ -635,7 +595,7 @@ avgAgeOfOlderFollowers.collect.foreach(println(_)) ### Map Reduce Triplets Transition Guide (Legacy) In earlier versions of GraphX neighborhood aggregation was accomplished using the -[`mapReduceTriplets`][Graph.mapReduceTriplets] operator: +`mapReduceTriplets` operator: {% highlight scala %} class Graph[VD, ED] { @@ -646,7 +606,7 @@ class Graph[VD, ED] { } {% endhighlight %} -The [`mapReduceTriplets`][Graph.mapReduceTriplets] operator takes a user defined map function which +The `mapReduceTriplets` operator takes a user defined map function which is applied to each triplet and can yield *messages* which are aggregated using the user defined `reduce` function. However, we found the user of the returned iterator to be expensive and it inhibited our ability to @@ -793,29 +753,7 @@ second argument list contains the user defined functions for receiving messages We can use the Pregel operator to express computation such as single source shortest path in the following example. -{% highlight scala %} -import org.apache.spark.graphx._ -// Import random graph generation library -import org.apache.spark.graphx.util.GraphGenerators -// A graph with edge attributes containing distances -val graph: Graph[Long, Double] = - GraphGenerators.logNormalGraph(sc, numVertices = 100).mapEdges(e => e.attr.toDouble) -val sourceId: VertexId = 42 // The ultimate source -// Initialize the graph such that all vertices except the root have distance infinity. 
-val initialGraph = graph.mapVertices((id, _) => if (id == sourceId) 0.0 else Double.PositiveInfinity) -val sssp = initialGraph.pregel(Double.PositiveInfinity)( - (id, dist, newDist) => math.min(dist, newDist), // Vertex Program - triplet => { // Send Message - if (triplet.srcAttr + triplet.attr < triplet.dstAttr) { - Iterator((triplet.dstId, triplet.srcAttr + triplet.attr)) - } else { - Iterator.empty - } - }, - (a,b) => math.min(a,b) // Merge Message - ) -println(sssp.vertices.collect.mkString("\n")) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/graphx/SSSPExample.scala %} @@ -1007,66 +945,21 @@ PageRank measures the importance of each vertex in a graph, assuming an edge fro GraphX comes with static and dynamic implementations of PageRank as methods on the [`PageRank` object][PageRank]. Static PageRank runs for a fixed number of iterations, while dynamic PageRank runs until the ranks converge (i.e., stop changing by more than a specified tolerance). [`GraphOps`][GraphOps] allows calling these algorithms directly as methods on `Graph`. -GraphX also includes an example social network dataset that we can run PageRank on. A set of users is given in `graphx/data/users.txt`, and a set of relationships between users is given in `graphx/data/followers.txt`. We compute the PageRank of each user as follows: +GraphX also includes an example social network dataset that we can run PageRank on. A set of users is given in `data/graphx/users.txt`, and a set of relationships between users is given in `data/graphx/followers.txt`. We compute the PageRank of each user as follows: -{% highlight scala %} -// Load the edges as a graph -val graph = GraphLoader.edgeListFile(sc, "graphx/data/followers.txt") -// Run PageRank -val ranks = graph.pageRank(0.0001).vertices -// Join the ranks with the usernames -val users = sc.textFile("graphx/data/users.txt").map { line => - val fields = line.split(",") - (fields(0).toLong, fields(1)) -} -val ranksByUsername = users.join(ranks).map { - case (id, (username, rank)) => (username, rank) -} -// Print the result -println(ranksByUsername.collect().mkString("\n")) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/graphx/PageRankExample.scala %} ## Connected Components The connected components algorithm labels each connected component of the graph with the ID of its lowest-numbered vertex. For example, in a social network, connected components can approximate clusters. GraphX contains an implementation of the algorithm in the [`ConnectedComponents` object][ConnectedComponents], and we compute the connected components of the example social network dataset from the [PageRank section](#pagerank) as follows: -{% highlight scala %} -// Load the graph as in the PageRank example -val graph = GraphLoader.edgeListFile(sc, "graphx/data/followers.txt") -// Find the connected components -val cc = graph.connectedComponents().vertices -// Join the connected components with the usernames -val users = sc.textFile("graphx/data/users.txt").map { line => - val fields = line.split(",") - (fields(0).toLong, fields(1)) -} -val ccByUsername = users.join(cc).map { - case (id, (username, cc)) => (username, cc) -} -// Print the result -println(ccByUsername.collect().mkString("\n")) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/graphx/ConnectedComponentsExample.scala %} ## Triangle Counting A vertex is part of a triangle when it has two adjacent vertices with an edge between them. 
GraphX implements a triangle counting algorithm in the [`TriangleCount` object][TriangleCount] that determines the number of triangles passing through each vertex, providing a measure of clustering. We compute the triangle count of the social network dataset from the [PageRank section](#pagerank). *Note that `TriangleCount` requires the edges to be in canonical orientation (`srcId < dstId`) and the graph to be partitioned using [`Graph.partitionBy`][Graph.partitionBy].* -{% highlight scala %} -// Load the edges in canonical order and partition the graph for triangle count -val graph = GraphLoader.edgeListFile(sc, "graphx/data/followers.txt", true).partitionBy(PartitionStrategy.RandomVertexCut) -// Find the triangle count for each vertex -val triCounts = graph.triangleCount().vertices -// Join the triangle counts with the usernames -val users = sc.textFile("graphx/data/users.txt").map { line => - val fields = line.split(",") - (fields(0).toLong, fields(1)) -} -val triCountByUsername = users.join(triCounts).map { case (id, (username, tc)) => - (username, tc) -} -// Print the result -println(triCountByUsername.collect().mkString("\n")) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/graphx/TriangleCountingExample.scala %} # Examples @@ -1076,36 +969,4 @@ to important relationships and users, run page-rank on the sub-graph, and then finally return attributes associated with the top users. I can do all of this in just a few lines with GraphX: -{% highlight scala %} -// Connect to the Spark cluster -val sc = new SparkContext("spark://master.amplab.org", "research") - -// Load my user data and parse into tuples of user id and attribute list -val users = (sc.textFile("graphx/data/users.txt") - .map(line => line.split(",")).map( parts => (parts.head.toLong, parts.tail) )) - -// Parse the edge data which is already in userId -> userId format -val followerGraph = GraphLoader.edgeListFile(sc, "graphx/data/followers.txt") - -// Attach the user attributes -val graph = followerGraph.outerJoinVertices(users) { - case (uid, deg, Some(attrList)) => attrList - // Some users may not have attributes so we set them as empty - case (uid, deg, None) => Array.empty[String] -} - -// Restrict the graph to users with usernames and names -val subgraph = graph.subgraph(vpred = (vid, attr) => attr.size == 2) - -// Compute the PageRank -val pagerankGraph = subgraph.pageRank(0.001) - -// Get the attributes of the top pagerank users -val userInfoWithPageRank = subgraph.outerJoinVertices(pagerankGraph.vertices) { - case (uid, attrList, Some(pr)) => (pr, attrList.toList) - case (uid, attrList, None) => (0.0, attrList.toList) -} - -println(userInfoWithPageRank.vertices.top(5)(Ordering.by(_._2._1)).mkString("\n")) - -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/graphx/ComprehensiveExample.scala %} diff --git a/docs/hardware-provisioning.md b/docs/hardware-provisioning.md index 60ecb4f483afa..bb6f616b18a24 100644 --- a/docs/hardware-provisioning.md +++ b/docs/hardware-provisioning.md @@ -22,7 +22,7 @@ Hadoop and Spark on a common cluster manager like [Mesos](running-on-mesos.html) * If this is not possible, run Spark on different nodes in the same local-area network as HDFS. -* For low-latency data stores like HBase, it may be preferrable to run computing jobs on different +* For low-latency data stores like HBase, it may be preferable to run computing jobs on different nodes than the storage system to avoid interference. 
# Local Disks diff --git a/docs/img/cluster-overview.png b/docs/img/cluster-overview.png index 317554c5f2a5b..b1b7c1ab5a5fa 100644 Binary files a/docs/img/cluster-overview.png and b/docs/img/cluster-overview.png differ diff --git a/docs/img/edge-cut.png b/docs/img/edge-cut.png deleted file mode 100644 index 698f4ff181e43..0000000000000 Binary files a/docs/img/edge-cut.png and /dev/null differ diff --git a/docs/img/edge_cut_vs_vertex_cut.png b/docs/img/edge_cut_vs_vertex_cut.png index ae30396d3fe19..5b1ed78a42c06 100644 Binary files a/docs/img/edge_cut_vs_vertex_cut.png and b/docs/img/edge_cut_vs_vertex_cut.png differ diff --git a/docs/img/graph_parallel.png b/docs/img/graph_parallel.png deleted file mode 100644 index 330be5567cf97..0000000000000 Binary files a/docs/img/graph_parallel.png and /dev/null differ diff --git a/docs/img/graphx_logo.png b/docs/img/graphx_logo.png index 9869ac148cad5..04eb4d9217cde 100644 Binary files a/docs/img/graphx_logo.png and b/docs/img/graphx_logo.png differ diff --git a/docs/img/graphx_performance_comparison.png b/docs/img/graphx_performance_comparison.png deleted file mode 100644 index 62dcf098c904f..0000000000000 Binary files a/docs/img/graphx_performance_comparison.png and /dev/null differ diff --git a/docs/img/incubator-logo.png b/docs/img/incubator-logo.png deleted file mode 100644 index 33ca7f6227256..0000000000000 Binary files a/docs/img/incubator-logo.png and /dev/null differ diff --git a/docs/img/ml-Pipeline.png b/docs/img/ml-Pipeline.png index 607928906bedd..33e2150b24151 100644 Binary files a/docs/img/ml-Pipeline.png and b/docs/img/ml-Pipeline.png differ diff --git a/docs/img/ml-PipelineModel.png b/docs/img/ml-PipelineModel.png index 9ebc16719d365..4d6f5c27af1c4 100644 Binary files a/docs/img/ml-PipelineModel.png and b/docs/img/ml-PipelineModel.png differ diff --git a/docs/img/property_graph.png b/docs/img/property_graph.png index 6f3f89a010c5e..72e13b91aaa8c 100644 Binary files a/docs/img/property_graph.png and b/docs/img/property_graph.png differ diff --git a/docs/img/spark-logo-100x40px.png b/docs/img/spark-logo-100x40px.png deleted file mode 100644 index 54c3187bbd752..0000000000000 Binary files a/docs/img/spark-logo-100x40px.png and /dev/null differ diff --git a/docs/img/spark-logo-77x40px-hd.png b/docs/img/spark-logo-77x40px-hd.png deleted file mode 100644 index 270402f100354..0000000000000 Binary files a/docs/img/spark-logo-77x40px-hd.png and /dev/null differ diff --git a/docs/img/spark-logo-77x50px-hd.png b/docs/img/spark-logo-77x50px-hd.png deleted file mode 100644 index 6c5f0993c43f0..0000000000000 Binary files a/docs/img/spark-logo-77x50px-hd.png and /dev/null differ diff --git a/docs/img/spark-logo-hd.png b/docs/img/spark-logo-hd.png index 1381e3004dca9..464eda547ba91 100644 Binary files a/docs/img/spark-logo-hd.png and b/docs/img/spark-logo-hd.png differ diff --git a/docs/img/spark-webui-accumulators.png b/docs/img/spark-webui-accumulators.png index 237052d7b5db6..55c50d608b111 100644 Binary files a/docs/img/spark-webui-accumulators.png and b/docs/img/spark-webui-accumulators.png differ diff --git a/docs/img/streaming-arch.png b/docs/img/streaming-arch.png index ac35f1d34cf3d..f86fb6bcbbdd8 100644 Binary files a/docs/img/streaming-arch.png and b/docs/img/streaming-arch.png differ diff --git a/docs/img/streaming-dstream-ops.png b/docs/img/streaming-dstream-ops.png index a1c5634aa3c3a..73084ff1a1f0c 100644 Binary files a/docs/img/streaming-dstream-ops.png and b/docs/img/streaming-dstream-ops.png differ diff --git 
a/docs/img/streaming-dstream-window.png b/docs/img/streaming-dstream-window.png index 276d2fee5e30e..4db3c97c291f1 100644 Binary files a/docs/img/streaming-dstream-window.png and b/docs/img/streaming-dstream-window.png differ diff --git a/docs/img/streaming-dstream.png b/docs/img/streaming-dstream.png index 90f43b8c7138c..326a3aef0fa6e 100644 Binary files a/docs/img/streaming-dstream.png and b/docs/img/streaming-dstream.png differ diff --git a/docs/img/streaming-flow.png b/docs/img/streaming-flow.png index a870cb9b1839b..bf5dd40d3559e 100644 Binary files a/docs/img/streaming-flow.png and b/docs/img/streaming-flow.png differ diff --git a/docs/img/streaming-kinesis-arch.png b/docs/img/streaming-kinesis-arch.png index bea5fa88df985..4fb026064dadb 100644 Binary files a/docs/img/streaming-kinesis-arch.png and b/docs/img/streaming-kinesis-arch.png differ diff --git a/docs/img/structured-streaming-example-model.png b/docs/img/structured-streaming-example-model.png new file mode 100644 index 0000000000000..d63498fdec4db Binary files /dev/null and b/docs/img/structured-streaming-example-model.png differ diff --git a/docs/img/structured-streaming-late-data.png b/docs/img/structured-streaming-late-data.png new file mode 100644 index 0000000000000..f9389c60552c0 Binary files /dev/null and b/docs/img/structured-streaming-late-data.png differ diff --git a/docs/img/structured-streaming-model.png b/docs/img/structured-streaming-model.png new file mode 100644 index 0000000000000..02becb91ef6a0 Binary files /dev/null and b/docs/img/structured-streaming-model.png differ diff --git a/docs/img/structured-streaming-stream-as-a-table.png b/docs/img/structured-streaming-stream-as-a-table.png new file mode 100644 index 0000000000000..bc63524464099 Binary files /dev/null and b/docs/img/structured-streaming-stream-as-a-table.png differ diff --git a/docs/img/structured-streaming-window.png b/docs/img/structured-streaming-window.png new file mode 100644 index 0000000000000..875716c708685 Binary files /dev/null and b/docs/img/structured-streaming-window.png differ diff --git a/docs/img/structured-streaming.pptx b/docs/img/structured-streaming.pptx new file mode 100644 index 0000000000000..6aad2ed33e924 Binary files /dev/null and b/docs/img/structured-streaming.pptx differ diff --git a/docs/img/triplet.png b/docs/img/triplet.png index 8b82a09bed29f..5d38ccebd3f21 100644 Binary files a/docs/img/triplet.png and b/docs/img/triplet.png differ diff --git a/docs/img/vertex-cut.png b/docs/img/vertex-cut.png deleted file mode 100644 index 0a508dcee99e5..0000000000000 Binary files a/docs/img/vertex-cut.png and /dev/null differ diff --git a/docs/img/vertex_routing_edge_tables.png b/docs/img/vertex_routing_edge_tables.png index 4379becc87ee4..2d1f3808e7217 100644 Binary files a/docs/img/vertex_routing_edge_tables.png and b/docs/img/vertex_routing_edge_tables.png differ diff --git a/docs/index.md b/docs/index.md index 20eab567a50df..a7a92f6c4f6d7 100644 --- a/docs/index.md +++ b/docs/index.md @@ -8,7 +8,7 @@ description: Apache Spark SPARK_VERSION_SHORT documentation homepage Apache Spark is a fast and general-purpose cluster computing system. It provides high-level APIs in Java, Scala, Python and R, and an optimized engine that supports general execution graphs. 
-It also supports a rich set of higher-level tools including [Spark SQL](sql-programming-guide.html) for SQL and structured data processing, [MLlib](mllib-guide.html) for machine learning, [GraphX](graphx-programming-guide.html) for graph processing, and [Spark Streaming](streaming-programming-guide.html). +It also supports a rich set of higher-level tools including [Spark SQL](sql-programming-guide.html) for SQL and structured data processing, [MLlib](ml-guide.html) for machine learning, [GraphX](graphx-programming-guide.html) for graph processing, and [Spark Streaming](streaming-programming-guide.html). # Downloading @@ -24,8 +24,8 @@ Spark runs on both Windows and UNIX-like systems (e.g. Linux, Mac OS). It's easy locally on one machine --- all you need is to have `java` installed on your system `PATH`, or the `JAVA_HOME` environment variable pointing to a Java installation. -Spark runs on Java 7+, Python 2.6+ and R 3.1+. For the Scala API, Spark {{site.SPARK_VERSION}} uses -Scala {{site.SCALA_BINARY_VERSION}}. You will need to use a compatible Scala version +Spark runs on Java 7+, Python 2.6+/3.4+ and R 3.1+. For the Scala API, Spark {{site.SPARK_VERSION}} +uses Scala {{site.SCALA_BINARY_VERSION}}. You will need to use a compatible Scala version ({{site.SCALA_BINARY_VERSION}}.x). # Running the Examples and Shell @@ -87,7 +87,7 @@ options for deployment: * Modules built on Spark: * [Spark Streaming](streaming-programming-guide.html): processing real-time data streams * [Spark SQL, Datasets, and DataFrames](sql-programming-guide.html): support for structured data and relational queries - * [MLlib](mllib-guide.html): built-in machine learning library + * [MLlib](ml-guide.html): built-in machine learning library * [GraphX](graphx-programming-guide.html): Spark's new API for graph processing **API Docs:** @@ -120,7 +120,7 @@ options for deployment: * [OpenStack Swift](storage-openstack-swift.html) * [Building Spark](building-spark.html): build Spark using the Maven system * [Contributing to Spark](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark) -* [Supplemental Projects](https://cwiki.apache.org/confluence/display/SPARK/Supplemental+Spark+Projects): related third party Spark projects +* [Third Party Projects](https://cwiki.apache.org/confluence/display/SPARK/Third+Party+Projects): related third party Spark projects **External Resources:** diff --git a/docs/ml-advanced.md b/docs/ml-advanced.md index 91731d78a2d43..12a03d3c91984 100644 --- a/docs/ml-advanced.md +++ b/docs/ml-advanced.md @@ -1,13 +1,88 @@ --- layout: global -title: Advanced topics - spark.ml -displayTitle: Advanced topics - spark.ml +title: Advanced topics +displayTitle: Advanced topics --- -# Optimization of linear methods +* Table of contents +{:toc} + +`\[ +\newcommand{\R}{\mathbb{R}} +\newcommand{\E}{\mathbb{E}} +\newcommand{\x}{\mathbf{x}} +\newcommand{\y}{\mathbf{y}} +\newcommand{\wv}{\mathbf{w}} +\newcommand{\av}{\mathbf{\alpha}} +\newcommand{\bv}{\mathbf{b}} +\newcommand{\N}{\mathbb{N}} +\newcommand{\id}{\mathbf{I}} +\newcommand{\ind}{\mathbf{1}} +\newcommand{\0}{\mathbf{0}} +\newcommand{\unit}{\mathbf{e}} +\newcommand{\one}{\mathbf{1}} +\newcommand{\zero}{\mathbf{0}} +\]` + +# Optimization of linear methods (developer) + +## Limited-memory BFGS (L-BFGS) +[L-BFGS](http://en.wikipedia.org/wiki/Limited-memory_BFGS) is an optimization +algorithm in the family of quasi-Newton methods to solve the optimization problems of the form +`$\min_{\wv \in\R^d} \; f(\wv)$`. 
The L-BFGS method approximates the objective function locally as a
+quadratic without evaluating the second partial derivatives of the objective function to construct the
+Hessian matrix. The Hessian matrix is approximated by previous gradient evaluations, so it avoids the
+vertical scalability issue (in the number of training features) that arises when computing the Hessian
+matrix explicitly, as in Newton's method. As a result, L-BFGS often achieves faster convergence compared
+with other first-order optimization methods.
-The optimization algorithm underlying the implementation is called [Orthant-Wise Limited-memory
-QuasiNewton](http://research-srv.microsoft.com/en-us/um/people/jfgao/paper/icml07scalable.pdf)
-(OWL-QN). It is an extension of L-BFGS that can effectively handle L1
-regularization and elastic net.
+[Orthant-Wise Limited-memory
+Quasi-Newton](http://research-srv.microsoft.com/en-us/um/people/jfgao/paper/icml07scalable.pdf)
+(OWL-QN) is an extension of L-BFGS that can effectively handle L1 and elastic net regularization.
+
+L-BFGS is used as a solver for [LinearRegression](api/scala/index.html#org.apache.spark.ml.regression.LinearRegression),
+[LogisticRegression](api/scala/index.html#org.apache.spark.ml.classification.LogisticRegression),
+[AFTSurvivalRegression](api/scala/index.html#org.apache.spark.ml.regression.AFTSurvivalRegression)
+and [MultilayerPerceptronClassifier](api/scala/index.html#org.apache.spark.ml.classification.MultilayerPerceptronClassifier).
+
+MLlib's L-BFGS solver calls the corresponding implementation in [breeze](https://github.com/scalanlp/breeze/blob/master/math/src/main/scala/breeze/optimize/LBFGS.scala).
+
+## Normal equation solver for weighted least squares
+
+MLlib implements a normal equation solver for [weighted least squares](https://en.wikipedia.org/wiki/Least_squares#Weighted_least_squares) in [WeightedLeastSquares]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala).
+
+Given $n$ weighted observations $(w_i, a_i, b_i)$:
+
+* $w_i$ is the weight of the i-th observation
+* $a_i$ is the feature vector of the i-th observation
+* $b_i$ is the label of the i-th observation
+
+The number of features for each observation is $m$. We use the following weighted least squares formulation:
+`\[
+\min_{x}\frac{1}{2} \sum_{i=1}^n \frac{w_i(a_i^T x -b_i)^2}{\sum_{k=1}^n w_k} + \frac{1}{2}\frac{\lambda}{\delta}\sum_{j=1}^m(\sigma_{j} x_{j})^2
+\]`
+where $\lambda$ is the regularization parameter, $\delta$ is the population standard deviation of the label
+and $\sigma_j$ is the population standard deviation of the j-th feature column.
+
+This objective function has an analytic solution, and only a single pass over the data is required to collect the statistics needed to solve it.
+Unlike the original dataset, which can only be stored in a distributed system,
+these statistics can be loaded into memory on a single machine if the number of features is relatively small, and then we can solve the objective function through Cholesky factorization on the driver.
+
+WeightedLeastSquares supports only L2 regularization and provides options to enable or disable regularization and standardization.
+In order to make the normal equation approach efficient, WeightedLeastSquares requires that the number of features be no more than 4096. For larger problems, use L-BFGS instead.
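+
+To make the "collect statistics, then solve on the driver" idea concrete, here is a minimal
+single-machine sketch using [breeze](https://github.com/scalanlp/breeze) (the linear algebra
+library MLlib builds on). The object and parameter names are illustrative, and the plain ridge
+term stands in for the standardized penalty in the objective above; the actual
+WeightedLeastSquares implementation accumulates the same statistics distributedly in one pass.
+
+{% highlight scala %}
+import breeze.linalg.{diag, sum, DenseMatrix, DenseVector}
+
+object WlsSketch {
+  def solve(
+      a: DenseMatrix[Double],  // n x m feature matrix
+      b: DenseVector[Double],  // n labels
+      w: DenseVector[Double],  // n observation weights
+      lambda: Double): DenseVector[Double] = {
+    // The only statistics needed: A^T W A (m x m) and A^T W b (length m),
+    // normalized by the total weight.
+    val wSum = sum(w)
+    val ata = (a.t * diag(w) * a) * (1.0 / wSum)
+    val atb = (a.t * (diag(w) * b)) * (1.0 / wSum)
+    // Plain L2 term; the real solver scales it by the label and feature
+    // standard deviations, as in the objective above.
+    val reg = DenseMatrix.eye[Double](a.cols) * lambda
+    // Solving this small m x m system on the driver is cheap when m <= 4096;
+    // MLlib uses a Cholesky factorization since the matrix is positive definite.
+    (ata + reg) \ atb
+  }
+}
+{% endhighlight %}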
+
+## Iteratively reweighted least squares (IRLS)
+
+MLlib implements [iteratively reweighted least squares (IRLS)](https://en.wikipedia.org/wiki/Iteratively_reweighted_least_squares) in [IterativelyReweightedLeastSquares]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala).
+It can be used to find the maximum likelihood estimates of a generalized linear model (GLM), to find M-estimators in robust regression, and to solve other optimization problems.
+Refer to [Iteratively Reweighted Least Squares for Maximum Likelihood Estimation, and some Robust and Resistant Alternatives](http://www.jstor.org/stable/2345503) for more information.
+
+It solves certain optimization problems iteratively through the following procedure:
+
+* linearize the objective at the current solution and update the corresponding weights.
+* solve a weighted least squares (WLS) problem by WeightedLeastSquares.
+* repeat the above steps until convergence.
+
+Since it solves a weighted least squares (WLS) problem by WeightedLeastSquares in each iteration,
+it also requires the number of features to be no more than 4096.
+Currently IRLS is used as the default solver of [GeneralizedLinearRegression](api/scala/index.html#org.apache.spark.ml.regression.GeneralizedLinearRegression).
diff --git a/docs/ml-ann.md b/docs/ml-ann.md
index c2d9bd200f62f..7c460c4af6f41 100644
--- a/docs/ml-ann.md
+++ b/docs/ml-ann.md
@@ -1,7 +1,7 @@
 ---
 layout: global
-title: Multilayer perceptron classifier - spark.ml
-displayTitle: Multilayer perceptron classifier - spark.ml
+title: Multilayer perceptron classifier
+displayTitle: Multilayer perceptron classifier
 ---

 > This section has been moved into the
diff --git a/docs/ml-classification-regression.md b/docs/ml-classification-regression.md
index eaf4f6d843368..7c2437eacde3f 100644
--- a/docs/ml-classification-regression.md
+++ b/docs/ml-classification-regression.md
@@ -1,7 +1,7 @@
 ---
 layout: global
-title: Classification and regression - spark.ml
-displayTitle: Classification and regression - spark.ml
+title: Classification and regression
+displayTitle: Classification and regression
 ---

@@ -22,37 +22,14 @@ displayTitle: Classification and regression - spark.ml
 \newcommand{\zero}{\mathbf{0}}
 \]`

+This page covers algorithms for Classification and Regression. It also includes sections
+discussing specific classes of algorithms, such as linear methods, trees, and ensembles.
+
 **Table of Contents**

 * This will become a table of contents (this text will be scraped).
 {:toc}

-In `spark.ml`, we implement popular linear methods such as logistic
-regression and linear least squares with $L_1$ or $L_2$ regularization.
-Refer to [the linear methods in mllib](mllib-linear-methods.html) for
-details about implementation and tuning. We also include a DataFrame API for [Elastic
-net](http://en.wikipedia.org/wiki/Elastic_net_regularization), a hybrid
-of $L_1$ and $L_2$ regularization proposed in [Zou et al, Regularization
-and variable selection via the elastic
-net](http://users.stat.umn.edu/~zouxx019/Papers/elasticnet.pdf).
-Mathematically, it is defined as a convex combination of the $L_1$ and
-the $L_2$ regularization terms:
-`\[
-\alpha \left( \lambda \|\wv\|_1 \right) + (1-\alpha) \left( \frac{\lambda}{2}\|\wv\|_2^2 \right) , \alpha \in [0, 1], \lambda \geq 0
-\]`
-By setting $\alpha$ properly, elastic net contains both $L_1$ and $L_2$
-regularization as special cases.
For example, if a [linear -regression](https://en.wikipedia.org/wiki/Linear_regression) model is -trained with the elastic net parameter $\alpha$ set to $1$, it is -equivalent to a -[Lasso](http://en.wikipedia.org/wiki/Least_squares#Lasso_method) model. -On the other hand, if $\alpha$ is set to $0$, the trained model reduces -to a [ridge -regression](http://en.wikipedia.org/wiki/Tikhonov_regularization) model. -We implement Pipelines API for both linear regression and logistic -regression with elastic net regularization. - - # Classification ## Logistic regression @@ -62,6 +39,8 @@ For more background and more details about the implementation, refer to the docu > The current implementation of logistic regression in `spark.ml` only supports binary classes. Support for multiclass regression will be added in the future. + > When fitting LogisticRegressionModel without intercept on dataset with constant nonzero column, Spark MLlib outputs zero coefficients for constant nonzero columns. This behavior is the same as R glmnet but different from LIBSVM. + **Example** The following example shows how to train a logistic regression model @@ -236,9 +215,9 @@ Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.classificat Multilayer perceptron classifier (MLPC) is a classifier based on the [feedforward artificial neural network](https://en.wikipedia.org/wiki/Feedforward_neural_network). MLPC consists of multiple layers of nodes. -Each layer is fully connected to the next layer in the network. Nodes in the input layer represent the input data. All other nodes maps inputs to the outputs -by performing linear combination of the inputs with the node's weights `$\wv$` and bias `$\bv$` and applying an activation function. -It can be written in matrix form for MLPC with `$K+1$` layers as follows: +Each layer is fully connected to the next layer in the network. Nodes in the input layer represent the input data. All other nodes map inputs to outputs +by a linear combination of the inputs with the node's weights `$\wv$` and bias `$\bv$` and applying an activation function. +This can be written in matrix form for MLPC with `$K+1$` layers as follows: `\[ \mathrm{y}(\x) = \mathrm{f_K}(...\mathrm{f_2}(\wv_2^T\mathrm{f_1}(\wv_1^T \x+b_1)+b_2)...+b_K) \]` @@ -252,7 +231,7 @@ Nodes in the output layer use softmax function: \]` The number of nodes `$N$` in the output layer corresponds to the number of classes. -MLPC employs backpropagation for learning the model. We use logistic loss function for optimization and L-BFGS as optimization routine. +MLPC employs backpropagation for learning the model. We use the logistic loss function for optimization and L-BFGS as an optimization routine. **Example** @@ -300,13 +279,20 @@ Refer to the [Java API docs](api/java/org/apache/spark/ml/classification/OneVsRe {% include_example java/org/apache/spark/examples/ml/JavaOneVsRestExample.java %} + +
    + +Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.classification.OneVsRest) for more details. + +{% include_example python/ml/one_vs_rest_example.py %} +
    ## Naive Bayes -[Naive Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier) are a family of simple +[Naive Bayes classifiers](http://en.wikipedia.org/wiki/Naive_Bayes_classifier) are a family of simple probabilistic classifiers based on applying Bayes' theorem with strong (naive) independence -assumptions between the features. The spark.ml implementation currently supports both [multinomial +assumptions between the features. The `spark.ml` implementation currently supports both [multinomial naive Bayes](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html) and [Bernoulli naive Bayes](http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html). More information can be found in the section on [Naive Bayes in MLlib](mllib-naive-bayes.html#naive-bayes-sparkmllib). @@ -344,6 +330,8 @@ Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.classificat The interface for working with linear regression models and model summaries is similar to the logistic regression case. + > When fitting LinearRegressionModel without intercept on dataset with constant nonzero column by "l-bfgs" solver, Spark MLlib outputs zero coefficients for constant nonzero columns. This behavior is the same as R glmnet but different from LIBSVM. + **Example** The following @@ -367,6 +355,138 @@ regression model and extracting model summary statistics. +## Generalized linear regression + +Contrasted with linear regression where the output is assumed to follow a Gaussian +distribution, [generalized linear models](https://en.wikipedia.org/wiki/Generalized_linear_model) (GLMs) are specifications of linear models where the response variable $Y_i$ follows some +distribution from the [exponential family of distributions](https://en.wikipedia.org/wiki/Exponential_family). +Spark's `GeneralizedLinearRegression` interface +allows for flexible specification of GLMs which can be used for various types of +prediction problems including linear regression, Poisson regression, logistic regression, and others. +Currently in `spark.ml`, only a subset of the exponential family distributions are supported and they are listed +[below](#available-families). + +**NOTE**: Spark currently only supports up to 4096 features through its `GeneralizedLinearRegression` +interface, and will throw an exception if this constraint is exceeded. See the [advanced section](ml-advanced) for more details. + Still, for linear and logistic regression, models with an increased number of features can be trained + using the `LinearRegression` and `LogisticRegression` estimators. + +GLMs require exponential family distributions that can be written in their "canonical" or "natural" form, aka +[natural exponential family distributions](https://en.wikipedia.org/wiki/Natural_exponential_family). The form of a natural exponential family distribution is given as: + +$$ +f_Y(y|\theta, \tau) = h(y, \tau)\exp{\left( \frac{\theta \cdot y - A(\theta)}{d(\tau)} \right)} +$$ + +where $\theta$ is the parameter of interest and $\tau$ is a dispersion parameter. In a GLM the response variable $Y_i$ is assumed to be drawn from a natural exponential family distribution: + +$$ +Y_i \sim f\left(\cdot|\theta_i, \tau \right) +$$ + +where the parameter of interest $\theta_i$ is related to the expected value of the response variable $\mu_i$ by + +$$ +\mu_i = A'(\theta_i) +$$ + +Here, $A'(\theta_i)$ is defined by the form of the distribution selected. 
GLMs also allow specification
+of a link function, which defines the relationship between the expected value of the response variable $\mu_i$
+and the so-called _linear predictor_ $\eta_i$:
+
+$$
+g(\mu_i) = \eta_i = \vec{x_i}^T \cdot \vec{\beta}
+$$
+
+Often, the link function is chosen such that $A' = g^{-1}$, which yields a simplified relationship
+between the parameter of interest $\theta$ and the linear predictor $\eta$. In this case, the link
+function $g(\mu)$ is said to be the "canonical" link function.
+
+$$
+\theta_i = A'^{-1}(\mu_i) = g(g^{-1}(\eta_i)) = \eta_i
+$$
+
+A GLM finds the regression coefficients $\vec{\beta}$ which maximize the likelihood function:
+
+$$
+\max_{\vec{\beta}} \mathcal{L}(\vec{\theta}|\vec{y},X) =
+\prod_{i=1}^{N} h(y_i, \tau) \exp{\left(\frac{y_i\theta_i - A(\theta_i)}{d(\tau)}\right)}
+$$
+
+where the parameter of interest $\theta_i$ is related to the regression coefficients $\vec{\beta}$
+by
+
+$$
+\theta_i = A'^{-1}(g^{-1}(\vec{x_i} \cdot \vec{\beta}))
+$$
+
+Spark's generalized linear regression interface also provides summary statistics for diagnosing the
+fit of GLM models, including residuals, p-values, deviances, the Akaike information criterion, and
+others.
+
+[See here](http://data.princeton.edu/wws509/notes/) for a more comprehensive review of GLMs and their applications.
+
+### Available families
+
+<table class="table">
+  <thead>
+    <tr>
+      <th>Family</th>
+      <th>Response Type</th>
+      <th>Supported Links</th>
+    </tr>
+  </thead>
+  <tbody>
+    <tr>
+      <td>Gaussian</td>
+      <td>Continuous</td>
+      <td>Identity*, Log, Inverse</td>
+    </tr>
+    <tr>
+      <td>Binomial</td>
+      <td>Binary</td>
+      <td>Logit*, Probit, CLogLog</td>
+    </tr>
+    <tr>
+      <td>Poisson</td>
+      <td>Count</td>
+      <td>Log*, Identity, Sqrt</td>
+    </tr>
+    <tr>
+      <td>Gamma</td>
+      <td>Continuous</td>
+      <td>Inverse*, Identity, Log</td>
+    </tr>
+    <tr>
+      <td colspan="3">* Canonical Link</td>
+    </tr>
+  </tbody>
+</table>
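+
+As a concrete instance of the definitions above (a standard derivation, not specific to Spark),
+writing the Binomial family in natural exponential form recovers its canonical Logit link:
+
+$$
+f_Y(y|\mu) = \mu^{y}(1-\mu)^{1-y} = \exp{\left( y \log\frac{\mu}{1-\mu} + \log(1-\mu) \right)}
+$$
+
+so that $\theta = \log\frac{\mu}{1-\mu}$ and $A(\theta) = \log(1 + e^{\theta})$, and indeed
+$A'(\theta) = \frac{e^{\theta}}{1+e^{\theta}} = \mu$, i.e. the canonical link $g = A'^{-1}$ is the Logit.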
    + +**Example** + +The following example demonstrates training a GLM with a Gaussian response and identity link +function and extracting model summary statistics. + +
    + +
    +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.regression.GeneralizedLinearRegression) for more details. + +{% include_example scala/org/apache/spark/examples/ml/GeneralizedLinearRegressionExample.scala %} +
    + +
    +Refer to the [Java API docs](api/java/org/apache/spark/ml/regression/GeneralizedLinearRegression.html) for more details. + +{% include_example java/org/apache/spark/examples/ml/JavaGeneralizedLinearRegressionExample.java %} +
    + +
    +Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.regression.GeneralizedLinearRegression) for more details. + +{% include_example python/ml/generalized_linear_regression_example.py %} +
    + +
    + ## Decision tree regression @@ -475,11 +595,11 @@ Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.regression. In `spark.ml`, we implement the [Accelerated failure time (AFT)](https://en.wikipedia.org/wiki/Accelerated_failure_time_model) model which is a parametric survival regression model for censored data. -It describes a model for the log of survival time, so it's often called -log-linear model for survival analysis. Different from +It describes a model for the log of survival time, so it's often called a +log-linear model for survival analysis. Different from a [Proportional hazards](https://en.wikipedia.org/wiki/Proportional_hazards_model) model -designed for the same purpose, the AFT model is more easily to parallelize -because each instance contribute to the objective function independently. +designed for the same purpose, the AFT model is easier to parallelize +because each instance contributes to the objective function independently. Given the values of the covariates $x^{'}$, for random lifetime $t_{i}$ of subjects i = 1, ..., n, with possible right-censoring, @@ -494,10 +614,10 @@ assumes the form: \iota(\beta,\sigma)=\sum_{i=1}^{n}[-\delta_{i}\log\sigma+\delta_{i}\log{f_{0}}(\epsilon_{i})+(1-\delta_{i})\log{S_{0}(\epsilon_{i})}] \]` Where $S_{0}(\epsilon_{i})$ is the baseline survivor function, -and $f_{0}(\epsilon_{i})$ is corresponding density function. +and $f_{0}(\epsilon_{i})$ is the corresponding density function. The most commonly used AFT model is based on the Weibull distribution of the survival time. -The Weibull distribution for lifetime corresponding to extreme value distribution for +The Weibull distribution for lifetime corresponds to the extreme value distribution for the log of the lifetime, and the $S_{0}(\epsilon)$ function is: `\[ S_{0}(\epsilon_{i})=\exp(-e^{\epsilon_{i}}) @@ -506,7 +626,7 @@ the $f_{0}(\epsilon_{i})$ function is: `\[ f_{0}(\epsilon_{i})=e^{\epsilon_{i}}\exp(-e^{\epsilon_{i}}) \]` -The log-likelihood function for AFT model with Weibull distribution of lifetime is: +The log-likelihood function for AFT model with a Weibull distribution of lifetime is: `\[ \iota(\beta,\sigma)= -\sum_{i=1}^n[\delta_{i}\log\sigma-\delta_{i}\epsilon_{i}+e^{\epsilon_{i}}] \]` @@ -522,11 +642,13 @@ The gradient functions for $\beta$ and $\log\sigma$ respectively are: The AFT model can be formulated as a convex optimization problem, i.e. the task of finding a minimizer of a convex function $-\iota(\beta,\sigma)$ -that depends coefficients vector $\beta$ and the log of scale parameter $\log\sigma$. +that depends on the coefficients vector $\beta$ and the log of scale parameter $\log\sigma$. The optimization algorithm underlying the implementation is L-BFGS. The implementation matches the result from R's survival function [survreg](https://stat.ethz.ch/R-manual/R-devel/library/survival/html/survreg.html) + > When fitting AFTSurvivalRegressionModel without intercept on dataset with constant nonzero column, Spark MLlib outputs zero coefficients for constant nonzero columns. This behavior is different from R survival::survreg. + **Example**
    @@ -546,6 +668,103 @@ The implementation matches the result from R's survival function
+## Isotonic regression
+[Isotonic regression](http://en.wikipedia.org/wiki/Isotonic_regression)
+belongs to the family of regression algorithms. Formally, isotonic regression is a problem where,
+given a finite set of real numbers `$Y = {y_1, y_2, ..., y_n}$` representing observed responses
+and `$X = {x_1, x_2, ..., x_n}$` the unknown response values to be fitted,
+we find a function that minimises
+
+`\begin{equation}
+  f(x) = \sum_{i=1}^n w_i (y_i - x_i)^2
+\end{equation}`
+
+with respect to the complete order subject to
+`$x_1\le x_2\le ...\le x_n$`, where `$w_i$` are positive weights.
+The resulting function is called isotonic regression, and it is unique.
+It can be viewed as a least squares problem under an order restriction.
+Essentially isotonic regression is a
+[monotonic function](http://en.wikipedia.org/wiki/Monotonic_function)
+best fitting the original data points.
+
+We implement a
+[pool adjacent violators algorithm](http://doi.org/10.1198/TECH.2010.10111)
+which uses an approach to
+[parallelizing isotonic regression](http://doi.org/10.1007/978-3-642-99789-1_10).
+The training input is a DataFrame which contains three columns:
+label, features and weight. Additionally, the IsotonicRegression algorithm has one
+optional parameter called `isotonic`, defaulting to true.
+This argument specifies whether the isotonic regression is
+isotonic (monotonically increasing) or antitonic (monotonically decreasing).
+
+Training returns an IsotonicRegressionModel that can be used to predict
+labels for both known and unknown features. The result of isotonic regression
+is treated as a piecewise linear function. The rules for prediction therefore are
+as follows (a small sketch of these rules appears after the examples below):
+
+* If the prediction input exactly matches a training feature,
+  then the associated prediction is returned. In case there are multiple predictions with the same
+  feature, then one of them is returned. Which one is undefined
+  (same as java.util.Arrays.binarySearch).
+* If the prediction input is lower or higher than all training features,
+  then the prediction with the lowest or highest feature is returned respectively.
+  In case there are multiple predictions with the same feature,
+  then the lowest or highest is returned respectively.
+* If the prediction input falls between two training features, then the prediction is treated
+  as a piecewise linear function and the interpolated value is calculated from the
+  predictions of the two closest features. In case there are multiple values
+  with the same feature, then the same rules as in the previous point are used.
+
+### Examples
+
    +
    + +Refer to the [`IsotonicRegression` Scala docs](api/scala/index.html#org.apache.spark.ml.regression.IsotonicRegression) for details on the API. + +{% include_example scala/org/apache/spark/examples/ml/IsotonicRegressionExample.scala %} +
    +
    + +Refer to the [`IsotonicRegression` Java docs](api/java/org/apache/spark/ml/regression/IsotonicRegression.html) for details on the API. + +{% include_example java/org/apache/spark/examples/ml/JavaIsotonicRegressionExample.java %} +
    +
    + +Refer to the [`IsotonicRegression` Python docs](api/python/pyspark.ml.html#pyspark.ml.regression.IsotonicRegression) for more details on the API. + +{% include_example python/ml/isotonic_regression_example.py %} +
    +
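+
+The prediction rules above amount to a binary search over the model's sorted training features
+followed by linear interpolation. A minimal sketch (a hypothetical helper, not MLlib's actual
+implementation) with `boundaries` holding the sorted features and `predictions` the fitted values:
+
+{% highlight scala %}
+import java.util.Arrays
+
+def predict(boundaries: Array[Double], predictions: Array[Double], x: Double): Double = {
+  val i = Arrays.binarySearch(boundaries, x)
+  if (i >= 0) {
+    predictions(i)                 // exact match with a training feature
+  } else {
+    val insert = -i - 1            // insertion point returned by binarySearch
+    if (insert == 0) {
+      predictions.head             // lower than all training features
+    } else if (insert == boundaries.length) {
+      predictions.last             // higher than all training features
+    } else {
+      // Between two training features: interpolate linearly.
+      val (x0, x1) = (boundaries(insert - 1), boundaries(insert))
+      val (y0, y1) = (predictions(insert - 1), predictions(insert))
+      y0 + (y1 - y0) * (x - x0) / (x1 - x0)
+    }
+  }
+}
+{% endhighlight %}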
    + +# Linear methods + +We implement popular linear methods such as logistic +regression and linear least squares with $L_1$ or $L_2$ regularization. +Refer to [the linear methods guide for the RDD-based API](mllib-linear-methods.html) for +details about implementation and tuning; this information is still relevant. + +We also include a DataFrame API for [Elastic +net](http://en.wikipedia.org/wiki/Elastic_net_regularization), a hybrid +of $L_1$ and $L_2$ regularization proposed in [Zou et al, Regularization +and variable selection via the elastic +net](http://users.stat.umn.edu/~zouxx019/Papers/elasticnet.pdf). +Mathematically, it is defined as a convex combination of the $L_1$ and +the $L_2$ regularization terms: +`\[ +\alpha \left( \lambda \|\wv\|_1 \right) + (1-\alpha) \left( \frac{\lambda}{2}\|\wv\|_2^2 \right) , \alpha \in [0, 1], \lambda \geq 0 +\]` +By setting $\alpha$ properly, elastic net contains both $L_1$ and $L_2$ +regularization as special cases. For example, if a [linear +regression](https://en.wikipedia.org/wiki/Linear_regression) model is +trained with the elastic net parameter $\alpha$ set to $1$, it is +equivalent to a +[Lasso](http://en.wikipedia.org/wiki/Least_squares#Lasso_method) model. +On the other hand, if $\alpha$ is set to $0$, the trained model reduces +to a [ridge +regression](http://en.wikipedia.org/wiki/Tikhonov_regularization) model. +We implement Pipelines API for both linear regression and logistic +regression with elastic net regularization. # Decision trees @@ -670,7 +889,7 @@ The main differences between this API and the [original MLlib ensembles API](mll ## Random Forests [Random forests](http://en.wikipedia.org/wiki/Random_forest) -are ensembles of [decision trees](ml-decision-tree.html). +are ensembles of [decision trees](ml-classification-regression.html#decision-trees). Random forests combine many decision trees in order to reduce the risk of overfitting. The `spark.ml` implementation supports random forests for binary and multiclass classification and for regression, using both continuous and categorical features. @@ -751,7 +970,7 @@ All output columns are optional; to exclude an output column, set its correspond ## Gradient-Boosted Trees (GBTs) [Gradient-Boosted Trees (GBTs)](http://en.wikipedia.org/wiki/Gradient_boosting) -are ensembles of [decision trees](ml-decision-tree.html). +are ensembles of [decision trees](ml-classification-regression.html#decision-trees). GBTs iteratively train decision trees in order to minimize a loss function. The `spark.ml` implementation supports GBTs for binary classification and for regression, using both continuous and categorical features. diff --git a/docs/ml-clustering.md b/docs/ml-clustering.md index 440c455cd077c..8a0a61cb595e7 100644 --- a/docs/ml-clustering.md +++ b/docs/ml-clustering.md @@ -1,10 +1,12 @@ --- layout: global -title: Clustering - spark.ml -displayTitle: Clustering - spark.ml +title: Clustering +displayTitle: Clustering --- -In this section, we introduce the pipeline API for [clustering in mllib](mllib-clustering.html). +This page describes clustering algorithms in MLlib. +The [guide for clustering in the RDD-based API](mllib-clustering.html) also has relevant information +about these algorithms. **Table of Contents** @@ -79,13 +81,17 @@ Refer to the [Java API docs](api/java/org/apache/spark/ml/clustering/KMeans.html {% include_example java/org/apache/spark/examples/ml/JavaKMeansExample.java %} - +
    +Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.clustering.KMeans) for more details. +{% include_example python/ml/kmeans_example.py %} +
    + ## Latent Dirichlet allocation (LDA) `LDA` is implemented as an `Estimator` that supports both `EMLDAOptimizer` and `OnlineLDAOptimizer`, -and generates a `LDAModel` as the base models. Expert users may cast a `LDAModel` generated by +and generates a `LDAModel` as the base model. Expert users may cast a `LDAModel` generated by `EMLDAOptimizer` to a `DistributedLDAModel` if needed.
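+
+For example (a minimal sketch, assuming `dataset` is a DataFrame with a `features` column),
+a model fit with the EM optimizer can be downcast to access distributed-model functionality:
+
+{% highlight scala %}
+import org.apache.spark.ml.clustering.{DistributedLDAModel, LDA}
+
+val lda = new LDA().setK(10).setMaxIter(20).setOptimizer("em")
+val model = lda.fit(dataset)
+// The cast is safe only because the "em" optimizer was used above.
+val distributedModel = model.asInstanceOf[DistributedLDAModel]
+{% endhighlight %}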
    @@ -104,4 +110,125 @@ Refer to the [Java API docs](api/java/org/apache/spark/ml/clustering/LDA.html) f {% include_example java/org/apache/spark/examples/ml/JavaLDAExample.java %}
    - \ No newline at end of file +
    + +Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.clustering.LDA) for more details. + +{% include_example python/ml/lda_example.py %} +
    + + +## Bisecting k-means + +Bisecting k-means is a kind of [hierarchical clustering](https://en.wikipedia.org/wiki/Hierarchical_clustering) using a +divisive (or "top-down") approach: all observations start in one cluster, and splits are performed recursively as one +moves down the hierarchy. + +Bisecting K-means can often be much faster than regular K-means, but it will generally produce a different clustering. + +`BisectingKMeans` is implemented as an `Estimator` and generates a `BisectingKMeansModel` as the base model. + +### Example + +
    + +
    +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.clustering.BisectingKMeans) for more details. + +{% include_example scala/org/apache/spark/examples/ml/BisectingKMeansExample.scala %} +
    + +
    +Refer to the [Java API docs](api/java/org/apache/spark/ml/clustering/BisectingKMeans.html) for more details. + +{% include_example java/org/apache/spark/examples/ml/JavaBisectingKMeansExample.java %} +
    + +
    +Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.clustering.BisectingKMeans) for more details. + +{% include_example python/ml/bisecting_k_means_example.py %} +
    +
    + +## Gaussian Mixture Model (GMM) + +A [Gaussian Mixture Model](http://en.wikipedia.org/wiki/Mixture_model#Multivariate_Gaussian_mixture_model) +represents a composite distribution whereby points are drawn from one of *k* Gaussian sub-distributions, +each with its own probability. The `spark.ml` implementation uses the +[expectation-maximization](http://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm) +algorithm to induce the maximum-likelihood model given a set of samples. + +`GaussianMixture` is implemented as an `Estimator` and generates a `GaussianMixtureModel` as the base +model. + +### Input Columns + + + + + + + + + + + + + + + + + + +
+<table class="table">
+  <thead>
+    <tr>
+      <th>Param name</th>
+      <th>Type(s)</th>
+      <th>Default</th>
+      <th>Description</th>
+    </tr>
+  </thead>
+  <tbody>
+    <tr>
+      <td>featuresCol</td>
+      <td>Vector</td>
+      <td>"features"</td>
+      <td>Feature vector</td>
+    </tr>
+  </tbody>
+</table>
+
+### Output Columns
+
+<table class="table">
+  <thead>
+    <tr>
+      <th>Param name</th>
+      <th>Type(s)</th>
+      <th>Default</th>
+      <th>Description</th>
+    </tr>
+  </thead>
+  <tbody>
+    <tr>
+      <td>predictionCol</td>
+      <td>Int</td>
+      <td>"prediction"</td>
+      <td>Predicted cluster center</td>
+    </tr>
+    <tr>
+      <td>probabilityCol</td>
+      <td>Vector</td>
+      <td>"probability"</td>
+      <td>Probability of each cluster</td>
+    </tr>
+  </tbody>
+</table>
    + +### Example + +
    + +
    +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.clustering.GaussianMixture) for more details. + +{% include_example scala/org/apache/spark/examples/ml/GaussianMixtureExample.scala %} +
    + +
    +Refer to the [Java API docs](api/java/org/apache/spark/ml/clustering/GaussianMixture.html) for more details. + +{% include_example java/org/apache/spark/examples/ml/JavaGaussianMixtureExample.java %} +
    + +
    +Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.clustering.GaussianMixture) for more details. + +{% include_example python/ml/gaussian_mixture_example.py %} +
    +
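+
+For reference, the composite distribution described in the Gaussian Mixture Model section above is
+the standard mixture density (a textbook definition, not specific to MLlib):
+
+`\[
+p(x) = \sum_{k=1}^{K} \pi_k \, \mathcal{N}(x \mid \mu_k, \Sigma_k), \qquad \pi_k \ge 0, \quad \sum_{k=1}^{K} \pi_k = 1
+\]`
+
+where $\pi_k$ is the weight of the k-th Gaussian; EM alternates between computing each point's
+responsibilities under the current components and re-estimating $\pi_k$, $\mu_k$ and $\Sigma_k$ from them.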
    diff --git a/docs/ml-collaborative-filtering.md b/docs/ml-collaborative-filtering.md index 4514a358e12f2..1d02d6933cb48 100644 --- a/docs/ml-collaborative-filtering.md +++ b/docs/ml-collaborative-filtering.md @@ -1,7 +1,7 @@ --- layout: global -title: Collaborative Filtering - spark.ml -displayTitle: Collaborative Filtering - spark.ml +title: Collaborative Filtering +displayTitle: Collaborative Filtering --- * Table of contents @@ -29,6 +29,10 @@ following parameters: *baseline* confidence in preference observations (defaults to 1.0). * *nonnegative* specifies whether or not to use nonnegative constraints for least squares (defaults to `false`). +**Note:** The DataFrame-based API for ALS currently only supports integers for user and item ids. +Other numeric types are supported for the user and item id columns, +but the ids must be within the integer value range. + ### Explicit vs. implicit feedback The standard approach to matrix factorization based collaborative filtering treats @@ -36,7 +40,7 @@ the entries in the user-item matrix as *explicit* preferences given by the user for example, users giving ratings to movies. It is common in many real-world use cases to only have access to *implicit feedback* (e.g. views, -clicks, purchases, likes, shares etc.). The approach used in `spark.mllib` to deal with such data is taken +clicks, purchases, likes, shares etc.). The approach used in `spark.ml` to deal with such data is taken from [Collaborative Filtering for Implicit Feedback Datasets](http://dx.doi.org/10.1109/ICDM.2008.22). Essentially, instead of trying to model the matrix of ratings directly, this approach treats the data as numbers representing the *strength* in observations of user actions (such as the number of clicks, @@ -60,7 +64,7 @@ best parameter learned from a sampled subset to the full dataset and expect simi
    -In the following example, we load rating data from the +In the following example, we load ratings data from the [MovieLens dataset](http://grouplens.org/datasets/movielens/), each row consisting of a user, a movie, a rating and a timestamp. We then train an ALS model which assumes, by default, that the ratings are @@ -91,7 +95,7 @@ val als = new ALS()
    -In the following example, we load rating data from the +In the following example, we load ratings data from the [MovieLens dataset](http://grouplens.org/datasets/movielens/), each row consisting of a user, a movie, a rating and a timestamp. We then train an ALS model which assumes, by default, that the ratings are @@ -122,7 +126,7 @@ ALS als = new ALS()
    -In the following example, we load rating data from the +In the following example, we load ratings data from the [MovieLens dataset](http://grouplens.org/datasets/movielens/), each row consisting of a user, a movie, a rating and a timestamp. We then train an ALS model which assumes, by default, that the ratings are diff --git a/docs/ml-decision-tree.md b/docs/ml-decision-tree.md index a721d55bc675b..5e1eeb95e4724 100644 --- a/docs/ml-decision-tree.md +++ b/docs/ml-decision-tree.md @@ -1,7 +1,7 @@ --- layout: global -title: Decision trees - spark.ml -displayTitle: Decision trees - spark.ml +title: Decision trees +displayTitle: Decision trees --- > This section has been moved into the diff --git a/docs/ml-ensembles.md b/docs/ml-ensembles.md index 303773e8038fc..97f1bdc803d01 100644 --- a/docs/ml-ensembles.md +++ b/docs/ml-ensembles.md @@ -1,7 +1,7 @@ --- layout: global -title: Tree ensemble methods - spark.ml -displayTitle: Tree ensemble methods - spark.ml +title: Tree ensemble methods +displayTitle: Tree ensemble methods --- > This section has been moved into the diff --git a/docs/ml-features.md b/docs/ml-features.md index 0b8f2d773c2eb..6020114845486 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -1,7 +1,7 @@ --- layout: global -title: Extracting, transforming and selecting features - spark.ml -displayTitle: Extracting, transforming and selecting features - spark.ml +title: Extracting, transforming and selecting features +displayTitle: Extracting, transforming and selecting features --- This section covers algorithms for working with features, roughly divided into these groups: @@ -18,27 +18,64 @@ This section covers algorithms for working with features, roughly divided into t # Feature Extractors -## TF-IDF (HashingTF and IDF) - -[Term Frequency-Inverse Document Frequency (TF-IDF)](http://en.wikipedia.org/wiki/Tf%E2%80%93idf) is a common text pre-processing step. In Spark ML, TF-IDF is separate into two parts: TF (+hashing) and IDF. +## TF-IDF + +[Term frequency-inverse document frequency (TF-IDF)](http://en.wikipedia.org/wiki/Tf%E2%80%93idf) +is a feature vectorization method widely used in text mining to reflect the importance of a term +to a document in the corpus. Denote a term by `$t$`, a document by `$d$`, and the corpus by `$D$`. +Term frequency `$TF(t, d)$` is the number of times that term `$t$` appears in document `$d$`, while +document frequency `$DF(t, D)$` is the number of documents that contains term `$t$`. If we only use +term frequency to measure the importance, it is very easy to over-emphasize terms that appear very +often but carry little information about the document, e.g. "a", "the", and "of". If a term appears +very often across the corpus, it means it doesn't carry special information about a particular document. +Inverse document frequency is a numerical measure of how much information a term provides: +`\[ +IDF(t, D) = \log \frac{|D| + 1}{DF(t, D) + 1}, +\]` +where `$|D|$` is the total number of documents in the corpus. Since logarithm is used, if a term +appears in all documents, its IDF value becomes 0. Note that a smoothing term is applied to avoid +dividing by zero for terms outside the corpus. The TF-IDF measure is simply the product of TF and IDF: +`\[ +TFIDF(t, d, D) = TF(t, d) \cdot IDF(t, D). +\]` +There are several variants on the definition of term frequency and document frequency. +In MLlib, we separate TF and IDF to make them flexible. **TF**: Both `HashingTF` and `CountVectorizer` can be used to generate the term frequency vectors. 
`HashingTF` is a `Transformer` which takes sets of terms and converts those sets into fixed-length feature vectors. In text processing, a "set of terms" might be a bag of words. -The algorithm combines Term Frequency (TF) counts with the -[hashing trick](http://en.wikipedia.org/wiki/Feature_hashing) for dimensionality reduction. +`HashingTF` utilizes the [hashing trick](http://en.wikipedia.org/wiki/Feature_hashing). +A raw feature is mapped into an index (term) by applying a hash function. The hash function +used here is [MurmurHash 3](https://en.wikipedia.org/wiki/MurmurHash). Then term frequencies +are calculated based on the mapped indices. This approach avoids the need to compute a global +term-to-index map, which can be expensive for a large corpus, but it suffers from potential hash +collisions, where different raw features may become the same term after hashing. To reduce the +chance of collision, we can increase the target feature dimension, i.e. the number of buckets +of the hash table. Since a simple modulo is used to transform the hash function to a column index, +it is advisable to use a power of two as the feature dimension, otherwise the features will +not be mapped evenly to the columns. The default feature dimension is `$2^{18} = 262,144$`. +An optional binary toggle parameter controls term frequency counts. When set to true all nonzero +frequency counts are set to 1. This is especially useful for discrete probabilistic models that +model binary, rather than integer, counts. `CountVectorizer` converts text documents to vectors of term counts. Refer to [CountVectorizer ](ml-features.html#countvectorizer) for more details. **IDF**: `IDF` is an `Estimator` which is fit on a dataset and produces an `IDFModel`. The -`IDFModel` takes feature vectors (generally created from `HashingTF` or `CountVectorizer`) and scales each column. -Intuitively, it down-weights columns which appear frequently in a corpus. +`IDFModel` takes feature vectors (generally created from `HashingTF` or `CountVectorizer`) and +scales each column. Intuitively, it down-weights columns which appear frequently in a corpus. -Please refer to the [MLlib user guide on TF-IDF](mllib-feature-extraction.html#tf-idf) for more details on Term Frequency and Inverse Document Frequency. +**Note:** `spark.ml` doesn't provide tools for text segmentation. +We refer users to the [Stanford NLP Group](http://nlp.stanford.edu/) and +[scalanlp/chalk](https://github.com/scalanlp/chalk). -In the following code segment, we start with a set of sentences. We split each sentence into words using `Tokenizer`. For each sentence (bag of words), we use `HashingTF` to hash the sentence into a feature vector. We use `IDF` to rescale the feature vectors; this generally improves performance when using text as features. Our feature vectors could then be passed to a learning algorithm. +**Examples** + +In the following code segment, we start with a set of sentences. We split each sentence into words +using `Tokenizer`. For each sentence (bag of words), we use `HashingTF` to hash the sentence into +a feature vector. We use `IDF` to rescale the feature vectors; this generally improves performance +when using text as features. Our feature vectors could then be passed to a learning algorithm.
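+
+To illustrate the hashing trick described above, here is a toy sketch that uses Scala's
+built-in `hashCode` in place of the MurmurHash 3 function used by MLlib:
+
+{% highlight scala %}
+// Map each term to a bucket index in [0, numFeatures) and count occurrences.
+def toyHashingTF(terms: Seq[String], numFeatures: Int = 1 << 18): Map[Int, Double] = {
+  terms
+    .map(t => ((t.hashCode % numFeatures) + numFeatures) % numFeatures) // non-negative modulo
+    .groupBy(identity)
+    .map { case (index, hits) => index -> hits.size.toDouble }
+}
+{% endhighlight %}
+
+Distinct terms can collide into the same bucket; choosing a larger (power of two) `numFeatures`
+makes collisions less likely, as noted above.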
    @@ -71,7 +108,7 @@ the [IDF Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.IDF) for mor `Word2Vec` is an `Estimator` which takes sequences of words representing documents and trains a `Word2VecModel`. The model maps each word to a unique fixed-size vector. The `Word2VecModel` transforms each document into a vector using the average of all words in the document; this vector -can then be used for as features for prediction, document similarity calculations, etc. +can then be used as features for prediction, document similarity calculations, etc. Please refer to the [MLlib user guide on Word2Vec](mllib-feature-extraction.html#word2vec) for more details. @@ -107,14 +144,16 @@ for more details on the API. `CountVectorizer` and `CountVectorizerModel` aim to help convert a collection of text documents to vectors of token counts. When an a-priori dictionary is not available, `CountVectorizer` can - be used as an `Estimator` to extract the vocabulary and generates a `CountVectorizerModel`. The + be used as an `Estimator` to extract the vocabulary, and generates a `CountVectorizerModel`. The model produces sparse representations for the documents over the vocabulary, which can then be passed to other algorithms like LDA. During the fitting process, `CountVectorizer` will select the top `vocabSize` words ordered by - term frequency across the corpus. An optional parameter "minDF" also affect the fitting process + term frequency across the corpus. An optional parameter `minDF` also affects the fitting process by specifying the minimum number (or fraction if < 1.0) of documents a term must appear in to be - included in the vocabulary. + included in the vocabulary. Another optional binary toggle parameter controls the output vector. + If set to true all nonzero counts are set to 1. This is especially useful for discrete probabilistic + models that model binary, rather than integer, counts. **Examples** @@ -127,9 +166,9 @@ Assume that we have the following DataFrame with columns `id` and `texts`: 1 | Array("a", "b", "b", "c", "a") ~~~~ -each row in`texts` is a document of type Array[String]. -Invoking fit of `CountVectorizer` produces a `CountVectorizerModel` with vocabulary (a, b, c), -then the output column "vector" after transformation contains: +each row in `texts` is a document of type Array[String]. +Invoking fit of `CountVectorizer` produces a `CountVectorizerModel` with vocabulary (a, b, c). +Then the output column "vector" after transformation contains: ~~~~ id | texts | vector @@ -138,7 +177,7 @@ then the output column "vector" after transformation contains: 1 | Array("a", "b", "b", "c", "a") | (3,[0,1,2],[2.0,2.0,1.0]) ~~~~ -each vector represents the token counts of the document over the vocabulary. +Each vector represents the token counts of the document over the vocabulary.
    @@ -177,7 +216,7 @@ for more details on the API. [RegexTokenizer](api/scala/index.html#org.apache.spark.ml.feature.RegexTokenizer) allows more advanced tokenization based on regular expression (regex) matching. - By default, the parameter "pattern" (regex, default: \\s+) is used as delimiters to split the input text. + By default, the parameter "pattern" (regex, default: `"\\s+"`) is used as delimiters to split the input text. Alternatively, users can set parameter "gaps" to false indicating the regex "pattern" denotes "tokens" rather than splitting gaps, and find all matching occurrences as the tokenization result. @@ -185,7 +224,7 @@ for more details on the API.
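+
+For instance, the two modes just described can be configured as follows (a minimal sketch;
+the column names are illustrative):
+
+{% highlight scala %}
+import org.apache.spark.ml.feature.RegexTokenizer
+
+// gaps = true (the default): the pattern is used as a delimiter to split on.
+val splitter = new RegexTokenizer()
+  .setInputCol("sentence").setOutputCol("words")
+  .setPattern("\\s+")
+
+// gaps = false: the pattern matches the tokens themselves.
+val matcher = new RegexTokenizer()
+  .setInputCol("sentence").setOutputCol("words")
+  .setGaps(false).setPattern("\\w+")
+{% endhighlight %}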
    Refer to the [Tokenizer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Tokenizer) -and the [RegexTokenizer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Tokenizer) +and the [RegexTokenizer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.RegexTokenizer) for more details on the API. {% include_example scala/org/apache/spark/examples/ml/TokenizerExample.scala %} @@ -218,11 +257,12 @@ frequently and don't carry as much meaning. `StopWordsRemover` takes as input a sequence of strings (e.g. the output of a [Tokenizer](ml-features.html#tokenizer)) and drops all the stop words from the input sequences. The list of stopwords is specified by -the `stopWords` parameter. We provide [a list of stop -words](http://ir.dcs.gla.ac.uk/resources/linguistic_utils/stop_words) by -default, accessible by calling `getStopWords` on a newly instantiated -`StopWordsRemover` instance. A boolean parameter `caseSensitive` indicates -if the matches should be case sensitive (false by default). +the `stopWords` parameter. Default stop words for some languages are accessible +by calling `StopWordsRemover.loadDefaultStopWords(language)`, for which available +options are "danish", "dutch", "english", "finnish", "french", "german", "hungarian", +"italian", "norwegian", "portuguese", "russian", "spanish", "swedish" and "turkish". +A boolean parameter `caseSensitive` indicates if the matches should be case sensitive +(false by default). **Examples** @@ -313,7 +353,10 @@ for more details on the API. Binarization is the process of thresholding numerical features to binary (0/1) features. -`Binarizer` takes the common parameters `inputCol` and `outputCol`, as well as the `threshold` for binarization. Feature values greater than the threshold are binarized to 1.0; values equal to or less than the threshold are binarized to 0.0. +`Binarizer` takes the common parameters `inputCol` and `outputCol`, as well as the `threshold` +for binarization. Feature values greater than the threshold are binarized to 1.0; values equal +to or less than the threshold are binarized to 0.0. Both Vector and Double types are supported +for `inputCol`.
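+
+A minimal configuration sketch of the thresholding behaviour just described (the column names
+are illustrative):
+
+{% highlight scala %}
+import org.apache.spark.ml.feature.Binarizer
+
+val binarizer = new Binarizer()
+  .setInputCol("feature")
+  .setOutputCol("binarized_feature")
+  .setThreshold(0.5) // values > 0.5 become 1.0; values <= 0.5 become 0.0
+{% endhighlight %}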
    @@ -444,8 +487,7 @@ for more details on the API. ## StringIndexer `StringIndexer` encodes a string column of labels to a column of label indices. -The indices are in `[0, numLabels)`, ordered by label frequencies. -So the most frequent label gets index `0`. +The indices are in `[0, numLabels)`, ordered by label frequencies, so the most frequent label gets index `0`. If the input column is numeric, we cast it to string and index the string values. When downstream pipeline components such as `Estimator` or `Transformer` make use of this string-indexed label, you must set the input @@ -552,7 +594,7 @@ for more details on the API. ## IndexToString Symmetrically to `StringIndexer`, `IndexToString` maps a column of label indices -back to a column containing the original labels as strings. The common use case +back to a column containing the original labels as strings. A common use case is to produce indices from labels with `StringIndexer`, train a model with those indices and retrieve the original labels from the column of predicted indices with `IndexToString`. However, you are free to supply your own labels. @@ -619,7 +661,7 @@ for more details on the API. ## OneHotEncoder -[One-hot encoding](http://en.wikipedia.org/wiki/One-hot) maps a column of label indices to a column of binary vectors, with at most a single one-value. This encoding allows algorithms which expect continuous features, such as Logistic Regression, to use categorical features +[One-hot encoding](http://en.wikipedia.org/wiki/One-hot) maps a column of label indices to a column of binary vectors, with at most a single one-value. This encoding allows algorithms which expect continuous features, such as Logistic Regression, to use categorical features.
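+
+For example, with three category indices the mapping looks as follows (assuming the encoder's
+default behaviour of dropping the last category, so the output vectors have `numCategories - 1` entries):
+
+~~~~
+ index | encoded vector
+-------|--------------------------------
+ 0.0   | (2,[0],[1.0])  i.e. [1.0, 0.0]
+ 1.0   | (2,[1],[1.0])  i.e. [0.0, 1.0]
+ 2.0   | (2,[],[])      i.e. [0.0, 0.0]
+~~~~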
    @@ -773,9 +815,9 @@ The rescaled value for a feature E is calculated as, `\begin{equation} Rescaled(e_i) = \frac{e_i - E_{min}}{E_{max} - E_{min}} * (max - min) + min \end{equation}` -For the case `E_{max} == E_{min}`, `Rescaled(e_i) = 0.5 * (max + min)` +For the case `$E_{max} == E_{min}$`, `$Rescaled(e_i) = 0.5 * (max + min)$` -Note that since zero values will probably be transformed to non-zero values, output of the transformer will be DenseVector even for sparse input. +Note that since zero values will probably be transformed to non-zero values, output of the transformer will be `DenseVector` even for sparse input. The following example demonstrates how to load a dataset in libsvm format and then rescale each feature to [0, 1]. @@ -801,6 +843,7 @@ for more details on the API.
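In addition to the bundled example referenced above, a minimal inline sketch of `MinMaxScaler` rescaling each feature to the default `[0, 1]` range might look like the following (assuming an existing `SparkSession` named `spark`; the data are illustrative):

{% highlight scala %}
import org.apache.spark.ml.feature.MinMaxScaler
import org.apache.spark.ml.linalg.Vectors

val df = spark.createDataFrame(Seq(
  (0, Vectors.dense(1.0, 0.1, -1.0)),
  (1, Vectors.dense(2.0, 1.1, 1.0)),
  (2, Vectors.dense(3.0, 10.1, 3.0))
)).toDF("id", "features")

// fit() computes the per-feature minimum and maximum used in the formula above.
val scaler = new MinMaxScaler()
  .setInputCol("features")
  .setOutputCol("scaledFeatures")

val scalerModel = scaler.fit(df)
scalerModel.transform(df).show(truncate = false)
{% endhighlight %}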
    Refer to the [MinMaxScaler Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.MinMaxScaler) +and the [MinMaxScalerModel Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.MinMaxScalerModel) for more details on the API. {% include_example python/ml/min_max_scaler_example.py %} @@ -841,6 +884,7 @@ for more details on the API.
    Refer to the [MaxAbsScaler Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.MaxAbsScaler) +and the [MaxAbsScalerModel Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.MaxAbsScalerModel) for more details on the API. {% include_example python/ml/max_abs_scaler_example.py %} @@ -853,7 +897,7 @@ for more details on the API. * `splits`: Parameter for mapping continuous features into buckets. With n+1 splits, there are n buckets. A bucket defined by splits x,y holds values in the range [x,y) except the last bucket, which also includes y. Splits should be strictly increasing. Values at -inf, inf must be explicitly provided to cover all Double values; Otherwise, values outside the splits specified will be treated as errors. Two examples of `splits` are `Array(Double.NegativeInfinity, 0.0, 1.0, Double.PositiveInfinity)` and `Array(0.0, 1.0, 2.0)`. -Note that if you have no idea of the upper bound and lower bound of the targeted column, you would better add the `Double.NegativeInfinity` and `Double.PositiveInfinity` as the bounds of your splits to prevent a potential out of Bucketizer bounds exception. +Note that if you have no idea of the upper and lower bounds of the targeted column, you should add `Double.NegativeInfinity` and `Double.PositiveInfinity` as the bounds of your splits to prevent a potential out of Bucketizer bounds exception. Note also that the splits that you provided have to be in strictly increasing order, i.e. `s0 < s1 < s2 < ... < sn`. @@ -941,7 +985,7 @@ for more details on the API. Currently we only support SQL syntax like `"SELECT ... FROM __THIS__ ..."` where `"__THIS__"` represents the underlying table of the input dataset. The select clause specifies the fields, constants, and expressions to display in -the output, it can be any select clause that Spark SQL supports. Users can also +the output, and can be any select clause that Spark SQL supports. Users can also use Spark SQL built-in function and UDFs to operate on these selected columns. For example, `SQLTransformer` supports statements like: @@ -1058,14 +1102,13 @@ for more details on the API. ## QuantileDiscretizer `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned -categorical features. -The bin ranges are chosen by taking a sample of the data and dividing it into roughly equal parts. -The lower and upper bin bounds will be `-Infinity` and `+Infinity`, covering all real values. -This attempts to find `numBuckets` partitions based on a sample of the given input data, but it may -find fewer depending on the data sample values. - -Note that the result may be different every time you run it, since the sample strategy behind it is -non-deterministic. +categorical features. The number of bins is set by the `numBuckets` parameter. +The bin ranges are chosen using an approximate algorithm (see the documentation for +[approxQuantile](api/scala/index.html#org.apache.spark.sql.DataFrameStatFunctions) for a +detailed description). The precision of the approximation can be controlled with the +`relativeError` parameter. When set to zero, exact quantiles are calculated +(**Note:** Computing exact quantiles is an expensive operation). The lower and upper bin bounds +will be `-Infinity` and `+Infinity` covering all real values. **Examples** @@ -1086,7 +1129,7 @@ Assume that we have a DataFrame with the columns `id`, `hour`: ~~~ `hour` is a continuous feature with `Double` type. We want to turn the continuous feature into -categorical one. 
Given `numBuckets = 3`, we should get the following DataFrame: +a categorical one. Given `numBuckets = 3`, we should get the following DataFrame: ~~~ id | hour | result @@ -1118,6 +1161,15 @@ for more details on the API. {% include_example java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java %}
    + +
    + +Refer to the [QuantileDiscretizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.QuantileDiscretizer) +for more details on the API. + +{% include_example python/ml/quantile_discretizer_example.py %} +
    +
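A minimal Scala sketch of the `id`/`hour` example above might look as follows; it assumes an existing `SparkSession` named `spark`, and the `hour` values shown are illustrative:

{% highlight scala %}
import org.apache.spark.ml.feature.QuantileDiscretizer

val df = spark.createDataFrame(Seq(
  (0, 18.0), (1, 19.0), (2, 8.0), (3, 5.0), (4, 2.2)
)).toDF("id", "hour")

val discretizer = new QuantileDiscretizer()
  .setInputCol("hour")
  .setOutputCol("result")
  .setNumBuckets(3)
  .setRelativeError(0.0)  // 0.0 requests exact quantiles, which is more expensive

// fit() chooses the bucket boundaries; the returned model is a Bucketizer.
discretizer.fit(df).transform(df).show()
{% endhighlight %}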
    # Feature Selectors @@ -1127,19 +1179,19 @@ for more details on the API. `VectorSlicer` is a transformer that takes a feature vector and outputs a new feature vector with a sub-array of the original features. It is useful for extracting features from a vector column. -`VectorSlicer` accepts a vector column with a specified indices, then outputs a new vector column +`VectorSlicer` accepts a vector column with specified indices, then outputs a new vector column whose values are selected via those indices. There are two types of indices, - 1. Integer indices that represents the indices into the vector, `setIndices()`; + 1. Integer indices that represent the indices into the vector, `setIndices()`. - 2. String indices that represents the names of features into the vector, `setNames()`. + 2. String indices that represent the names of features into the vector, `setNames()`. *This requires the vector column to have an `AttributeGroup` since the implementation matches on the name field of an `Attribute`.* Specification by integer and string are both acceptable. Moreover, you can use integer index and string name simultaneously. At least one feature must be selected. Duplicate features are not allowed, so there can be no overlap between selected indices and names. Note that if names of -features are selected, an exception will be threw out when encountering with empty input attributes. +features are selected, an exception will be thrown if empty input attributes are encountered. The output vector will order features with the selected indices first (in the order given), followed by the selected names (in the order given). @@ -1154,8 +1206,8 @@ Suppose that we have a DataFrame with the column `userFeatures`: [0.0, 10.0, 0.5] ~~~ -`userFeatures` is a vector column that contains three user features. Assuming that the first column -of `userFeatures` are all zeros, so we want to remove it and only the last two columns are selected. +`userFeatures` is a vector column that contains three user features. Assume that the first column +of `userFeatures` are all zeros, so we want to remove it and select only the last two columns. The `VectorSlicer` selects the last two elements with `setIndices(1, 2)` then produces a new vector column named `features`: @@ -1165,7 +1217,7 @@ column named `features`: [0.0, 10.0, 0.5] | [10.0, 0.5] ~~~ -Suppose also that we have a potential input attributes for the `userFeatures`, i.e. +Suppose also that we have potential input attributes for the `userFeatures`, i.e. `["f1", "f2", "f3"]`, then we can use `setNames("f2", "f3")` to select them. ~~~ @@ -1293,8 +1345,8 @@ id | features | clicked 9 | [1.0, 0.0, 15.0, 0.1] | 0.0 ~~~ -If we use `ChiSqSelector` with a `numTopFeatures = 1`, then according to our label `clicked` the -last column in our `features` chosen as the most useful feature: +If we use `ChiSqSelector` with `numTopFeatures = 1`, then according to our label `clicked` the +last column in our `features` is chosen as the most useful feature: ~~~ id | features | clicked | selectedFeatures diff --git a/docs/ml-guide.md b/docs/ml-guide.md index 99167873cd02d..4607ad3ba681a 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -1,323 +1,214 @@ --- layout: global -title: "Overview: estimators, transformers and pipelines - spark.ml" -displayTitle: "Overview: estimators, transformers and pipelines - spark.ml" +title: "MLlib: Main Guide" +displayTitle: "Machine Learning Library (MLlib) Guide" --- +MLlib is Spark's machine learning (ML) library. 
+Its goal is to make practical machine learning scalable and easy. +At a high level, it provides tools such as: -`\[ -\newcommand{\R}{\mathbb{R}} -\newcommand{\E}{\mathbb{E}} -\newcommand{\x}{\mathbf{x}} -\newcommand{\y}{\mathbf{y}} -\newcommand{\wv}{\mathbf{w}} -\newcommand{\av}{\mathbf{\alpha}} -\newcommand{\bv}{\mathbf{b}} -\newcommand{\N}{\mathbb{N}} -\newcommand{\id}{\mathbf{I}} -\newcommand{\ind}{\mathbf{1}} -\newcommand{\0}{\mathbf{0}} -\newcommand{\unit}{\mathbf{e}} -\newcommand{\one}{\mathbf{1}} -\newcommand{\zero}{\mathbf{0}} -\]` +* ML Algorithms: common learning algorithms such as classification, regression, clustering, and collaborative filtering +* Featurization: feature extraction, transformation, dimensionality reduction, and selection +* Pipelines: tools for constructing, evaluating, and tuning ML Pipelines +* Persistence: saving and load algorithms, models, and Pipelines +* Utilities: linear algebra, statistics, data handling, etc. +# Announcement: DataFrame-based API is primary API -The `spark.ml` package aims to provide a uniform set of high-level APIs built on top of -[DataFrames](sql-programming-guide.html#dataframes) that help users create and tune practical -machine learning pipelines. -See the [algorithm guides](#algorithm-guides) section below for guides on sub-packages of -`spark.ml`, including feature transformers unique to the Pipelines API, ensembles, and more. +**The MLlib RDD-based API is now in maintenance mode.** -**Table of contents** +As of Spark 2.0, the [RDD](programming-guide.html#resilient-distributed-datasets-rdds)-based APIs in the `spark.mllib` package have entered maintenance mode. +The primary Machine Learning API for Spark is now the [DataFrame](sql-programming-guide.html)-based API in the `spark.ml` package. -* This will become a table of contents (this text will be scraped). -{:toc} +*What are the implications?* +* MLlib will still support the RDD-based API in `spark.mllib` with bug fixes. +* MLlib will not add new features to the RDD-based API. +* In the Spark 2.x releases, MLlib will add features to the DataFrames-based API to reach feature parity with the RDD-based API. +* After reaching feature parity (roughly estimated for Spark 2.2), the RDD-based API will be deprecated. +* The RDD-based API is expected to be removed in Spark 3.0. -# Main concepts in Pipelines +*Why is MLlib switching to the DataFrame-based API?* -Spark ML standardizes APIs for machine learning algorithms to make it easier to combine multiple -algorithms into a single pipeline, or workflow. -This section covers the key concepts introduced by the Spark ML API, where the pipeline concept is -mostly inspired by the [scikit-learn](http://scikit-learn.org/) project. +* DataFrames provide a more user-friendly API than RDDs. The many benefits of DataFrames include Spark Datasources, SQL/DataFrame queries, Tungsten and Catalyst optimizations, and uniform APIs across languages. +* The DataFrame-based API for MLlib provides a uniform API across ML algorithms and across multiple languages. +* DataFrames facilitate practical ML Pipelines, particularly feature transformations. See the [Pipelines guide](ml-pipeline.html) for details. -* **[`DataFrame`](ml-guide.html#dataframe)**: Spark ML uses `DataFrame` from Spark SQL as an ML - dataset, which can hold a variety of data types. - E.g., a `DataFrame` could have different columns storing text, feature vectors, true labels, and predictions. 
+# Dependencies -* **[`Transformer`](ml-guide.html#transformers)**: A `Transformer` is an algorithm which can transform one `DataFrame` into another `DataFrame`. -E.g., an ML model is a `Transformer` which transforms `DataFrame` with features into a `DataFrame` with predictions. +MLlib uses the linear algebra package [Breeze](http://www.scalanlp.org/), which depends on +[netlib-java](https://github.com/fommil/netlib-java) for optimised numerical processing. +If native libraries[^1] are not available at runtime, you will see a warning message and a pure JVM +implementation will be used instead. -* **[`Estimator`](ml-guide.html#estimators)**: An `Estimator` is an algorithm which can be fit on a `DataFrame` to produce a `Transformer`. -E.g., a learning algorithm is an `Estimator` which trains on a `DataFrame` and produces a model. +Due to licensing issues with runtime proprietary binaries, we do not include `netlib-java`'s native +proxies by default. +To configure `netlib-java` / Breeze to use system optimised binaries, include +`com.github.fommil.netlib:all:1.1.2` (or build Spark with `-Pnetlib-lgpl`) as a dependency of your +project and read the [netlib-java](https://github.com/fommil/netlib-java) documentation for your +platform's additional installation instructions. -* **[`Pipeline`](ml-guide.html#pipeline)**: A `Pipeline` chains multiple `Transformer`s and `Estimator`s together to specify an ML workflow. +To use MLlib in Python, you will need [NumPy](http://www.numpy.org) version 1.4 or newer. -* **[`Parameter`](ml-guide.html#parameters)**: All `Transformer`s and `Estimator`s now share a common API for specifying parameters. +[^1]: To learn more about the benefits and background of system optimised natives, you may wish to + watch Sam Halliday's ScalaX talk on [High Performance Linear Algebra in Scala](http://fommil.github.io/scalax14/#/). -## DataFrame +# Migration guide -Machine learning can be applied to a wide variety of data types, such as vectors, text, images, and structured data. -Spark ML adopts the `DataFrame` from Spark SQL in order to support a variety of data types. +MLlib is under active development. +The APIs marked `Experimental`/`DeveloperApi` may change in future releases, +and the migration guide below will explain all changes between releases. -`DataFrame` supports many basic and structured types; see the [Spark SQL datatype reference](sql-programming-guide.html#spark-sql-datatype-reference) for a list of supported types. -In addition to the types listed in the Spark SQL guide, `DataFrame` can use ML [`Vector`](mllib-data-types.html#local-vector) types. +## From 1.6 to 2.0 -A `DataFrame` can be created either implicitly or explicitly from a regular `RDD`. See the code examples below and the [Spark SQL programming guide](sql-programming-guide.html) for examples. +### Breaking changes -Columns in a `DataFrame` are named. The code examples below use names such as "text," "features," and "label." +There were several breaking changes in Spark 2.0, which are outlined below. -## Pipeline components +**Linear algebra classes for DataFrame-based APIs** -### Transformers +Spark's linear algebra dependencies were moved to a new project, `mllib-local` +(see [SPARK-13944](https://issues.apache.org/jira/browse/SPARK-13944)). +As part of this change, the linear algebra classes were copied to a new package, `spark.ml.linalg`. 
+The DataFrame-based APIs in `spark.ml` now depend on the `spark.ml.linalg` classes, +leading to a few breaking changes, predominantly in various model classes +(see [SPARK-14810](https://issues.apache.org/jira/browse/SPARK-14810) for a full list). -A `Transformer` is an abstraction that includes feature transformers and learned models. -Technically, a `Transformer` implements a method `transform()`, which converts one `DataFrame` into -another, generally by appending one or more columns. -For example: +**Note:** the RDD-based APIs in `spark.mllib` continue to depend on the previous package `spark.mllib.linalg`. -* A feature transformer might take a `DataFrame`, read a column (e.g., text), map it into a new - column (e.g., feature vectors), and output a new `DataFrame` with the mapped column appended. -* A learning model might take a `DataFrame`, read the column containing feature vectors, predict the - label for each feature vector, and output a new `DataFrame` with predicted labels appended as a - column. +_Converting vectors and matrices_ -### Estimators +While most pipeline components support backward compatibility for loading, +some existing `DataFrames` and pipelines in Spark versions prior to 2.0, that contain vector or matrix +columns, may need to be migrated to the new `spark.ml` vector and matrix types. +Utilities for converting `DataFrame` columns from `spark.mllib.linalg` to `spark.ml.linalg` types +(and vice versa) can be found in `spark.mllib.util.MLUtils`. -An `Estimator` abstracts the concept of a learning algorithm or any algorithm that fits or trains on -data. -Technically, an `Estimator` implements a method `fit()`, which accepts a `DataFrame` and produces a -`Model`, which is a `Transformer`. -For example, a learning algorithm such as `LogisticRegression` is an `Estimator`, and calling -`fit()` trains a `LogisticRegressionModel`, which is a `Model` and hence a `Transformer`. - -### Properties of pipeline components - -`Transformer.transform()`s and `Estimator.fit()`s are both stateless. In the future, stateful algorithms may be supported via alternative concepts. - -Each instance of a `Transformer` or `Estimator` has a unique ID, which is useful in specifying parameters (discussed below). - -## Pipeline - -In machine learning, it is common to run a sequence of algorithms to process and learn from data. -E.g., a simple text document processing workflow might include several stages: - -* Split each document's text into words. -* Convert each document's words into a numerical feature vector. -* Learn a prediction model using the feature vectors and labels. - -Spark ML represents such a workflow as a `Pipeline`, which consists of a sequence of -`PipelineStage`s (`Transformer`s and `Estimator`s) to be run in a specific order. -We will use this simple workflow as a running example in this section. - -### How it works - -A `Pipeline` is specified as a sequence of stages, and each stage is either a `Transformer` or an `Estimator`. -These stages are run in order, and the input `DataFrame` is transformed as it passes through each stage. -For `Transformer` stages, the `transform()` method is called on the `DataFrame`. -For `Estimator` stages, the `fit()` method is called to produce a `Transformer` (which becomes part of the `PipelineModel`, or fitted `Pipeline`), and that `Transformer`'s `transform()` method is called on the `DataFrame`. - -We illustrate this for the simple text document workflow. The figure below is for the *training time* usage of a `Pipeline`. - -

    - Spark ML Pipeline Example -

    - -Above, the top row represents a `Pipeline` with three stages. -The first two (`Tokenizer` and `HashingTF`) are `Transformer`s (blue), and the third (`LogisticRegression`) is an `Estimator` (red). -The bottom row represents data flowing through the pipeline, where cylinders indicate `DataFrame`s. -The `Pipeline.fit()` method is called on the original `DataFrame`, which has raw text documents and labels. -The `Tokenizer.transform()` method splits the raw text documents into words, adding a new column with words to the `DataFrame`. -The `HashingTF.transform()` method converts the words column into feature vectors, adding a new column with those vectors to the `DataFrame`. -Now, since `LogisticRegression` is an `Estimator`, the `Pipeline` first calls `LogisticRegression.fit()` to produce a `LogisticRegressionModel`. -If the `Pipeline` had more stages, it would call the `LogisticRegressionModel`'s `transform()` -method on the `DataFrame` before passing the `DataFrame` to the next stage. - -A `Pipeline` is an `Estimator`. -Thus, after a `Pipeline`'s `fit()` method runs, it produces a `PipelineModel`, which is a -`Transformer`. -This `PipelineModel` is used at *test time*; the figure below illustrates this usage. - -

    - Spark ML PipelineModel Example -

    - -In the figure above, the `PipelineModel` has the same number of stages as the original `Pipeline`, but all `Estimator`s in the original `Pipeline` have become `Transformer`s. -When the `PipelineModel`'s `transform()` method is called on a test dataset, the data are passed -through the fitted pipeline in order. -Each stage's `transform()` method updates the dataset and passes it to the next stage. - -`Pipeline`s and `PipelineModel`s help to ensure that training and test data go through identical feature processing steps. - -### Details - -*DAG `Pipeline`s*: A `Pipeline`'s stages are specified as an ordered array. The examples given here are all for linear `Pipeline`s, i.e., `Pipeline`s in which each stage uses data produced by the previous stage. It is possible to create non-linear `Pipeline`s as long as the data flow graph forms a Directed Acyclic Graph (DAG). This graph is currently specified implicitly based on the input and output column names of each stage (generally specified as parameters). If the `Pipeline` forms a DAG, then the stages must be specified in topological order. - -*Runtime checking*: Since `Pipeline`s can operate on `DataFrame`s with varied types, they cannot use -compile-time type checking. -`Pipeline`s and `PipelineModel`s instead do runtime checking before actually running the `Pipeline`. -This type checking is done using the `DataFrame` *schema*, a description of the data types of columns in the `DataFrame`. - -*Unique Pipeline stages*: A `Pipeline`'s stages should be unique instances. E.g., the same instance -`myHashingTF` should not be inserted into the `Pipeline` twice since `Pipeline` stages must have -unique IDs. However, different instances `myHashingTF1` and `myHashingTF2` (both of type `HashingTF`) -can be put into the same `Pipeline` since different instances will be created with different IDs. - -## Parameters - -Spark ML `Estimator`s and `Transformer`s use a uniform API for specifying parameters. - -A `Param` is a named parameter with self-contained documentation. -A `ParamMap` is a set of (parameter, value) pairs. - -There are two main ways to pass parameters to an algorithm: - -1. Set parameters for an instance. E.g., if `lr` is an instance of `LogisticRegression`, one could - call `lr.setMaxIter(10)` to make `lr.fit()` use at most 10 iterations. - This API resembles the API used in `spark.mllib` package. -2. Pass a `ParamMap` to `fit()` or `transform()`. Any parameters in the `ParamMap` will override parameters previously specified via setter methods. - -Parameters belong to specific instances of `Estimator`s and `Transformer`s. -For example, if we have two `LogisticRegression` instances `lr1` and `lr2`, then we can build a `ParamMap` with both `maxIter` parameters specified: `ParamMap(lr1.maxIter -> 10, lr2.maxIter -> 20)`. -This is useful if there are two algorithms with the `maxIter` parameter in a `Pipeline`. - -## Saving and Loading Pipelines - -Often times it is worth it to save a model or a pipeline to disk for later use. In Spark 1.6, a model import/export functionality was added to the Pipeline API. Most basic transformers are supported as well as some of the more basic ML models. Please refer to the algorithm's API documentation to see if saving and loading is supported. - -# Code examples - -This section gives code examples illustrating the functionality discussed above. 
-For more info, please refer to the API documentation -([Scala](api/scala/index.html#org.apache.spark.ml.package), -[Java](api/java/org/apache/spark/ml/package-summary.html), -and [Python](api/python/pyspark.ml.html)). -Some Spark ML algorithms are wrappers for `spark.mllib` algorithms, and the -[MLlib programming guide](mllib-guide.html) has details on specific algorithms. - -## Example: Estimator, Transformer, and Param - -This example covers the concepts of `Estimator`, `Transformer`, and `Param`. +There are also utility methods available for converting single instances of +vectors and matrices. Use the `asML` method on a `mllib.linalg.Vector` / `mllib.linalg.Matrix` +for converting to `ml.linalg` types, and +`mllib.linalg.Vectors.fromML` / `mllib.linalg.Matrices.fromML` +for converting to `mllib.linalg` types.
    +
    -
    -{% include_example scala/org/apache/spark/examples/ml/EstimatorTransformerParamExample.scala %} -
    +{% highlight scala %} +import org.apache.spark.mllib.util.MLUtils -
    -{% include_example java/org/apache/spark/examples/ml/JavaEstimatorTransformerParamExample.java %} -
    +// convert DataFrame columns +val convertedVecDF = MLUtils.convertVectorColumnsToML(vecDF) +val convertedMatrixDF = MLUtils.convertMatrixColumnsToML(matrixDF) +// convert a single vector or matrix +val mlVec: org.apache.spark.ml.linalg.Vector = mllibVec.asML +val mlMat: org.apache.spark.ml.linalg.Matrix = mllibMat.asML +{% endhighlight %} -
    -{% include_example python/ml/estimator_transformer_param_example.py %} -
    - -
    - -## Example: Pipeline - -This example follows the simple text document `Pipeline` illustrated in the figures above. - -
    - -
    -{% include_example scala/org/apache/spark/examples/ml/PipelineExample.scala %} -
    - -
    -{% include_example java/org/apache/spark/examples/ml/JavaPipelineExample.java %} -
    - -
    -{% include_example python/ml/pipeline_example.py %} -
    - -
    - -## Example: model selection via cross-validation - -An important task in ML is *model selection*, or using data to find the best model or parameters for a given task. This is also called *tuning*. -`Pipeline`s facilitate model selection by making it easy to tune an entire `Pipeline` at once, rather than tuning each element in the `Pipeline` separately. - -Currently, `spark.ml` supports model selection using the [`CrossValidator`](api/scala/index.html#org.apache.spark.ml.tuning.CrossValidator) class, which takes an `Estimator`, a set of `ParamMap`s, and an [`Evaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.Evaluator). -`CrossValidator` begins by splitting the dataset into a set of *folds* which are used as separate training and test datasets; e.g., with `$k=3$` folds, `CrossValidator` will generate 3 (training, test) dataset pairs, each of which uses 2/3 of the data for training and 1/3 for testing. -`CrossValidator` iterates through the set of `ParamMap`s. For each `ParamMap`, it trains the given `Estimator` and evaluates it using the given `Evaluator`. - -The `Evaluator` can be a [`RegressionEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.RegressionEvaluator) -for regression problems, a [`BinaryClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.BinaryClassificationEvaluator) -for binary data, or a [`MultiClassClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator) -for multiclass problems. The default metric used to choose the best `ParamMap` can be overridden by the `setMetricName` -method in each of these evaluators. - -The `ParamMap` which produces the best evaluation metric (averaged over the `$k$` folds) is selected as the best model. -`CrossValidator` finally fits the `Estimator` using the best `ParamMap` and the entire dataset. - -The following example demonstrates using `CrossValidator` to select from a grid of parameters. -To help construct the parameter grid, we use the [`ParamGridBuilder`](api/scala/index.html#org.apache.spark.ml.tuning.ParamGridBuilder) utility. - -Note that cross-validation over a grid of parameters is expensive. -E.g., in the example below, the parameter grid has 3 values for `hashingTF.numFeatures` and 2 values for `lr.regParam`, and `CrossValidator` uses 2 folds. This multiplies out to `$(3 \times 2) \times 2 = 12$` different models being trained. -In realistic settings, it can be common to try many more parameters and use more folds (`$k=3$` and `$k=10$` are common). -In other words, using `CrossValidator` can be very expensive. -However, it is also a well-established method for choosing parameters which is more statistically sound than heuristic hand-tuning. - -
    - -
    -{% include_example scala/org/apache/spark/examples/ml/ModelSelectionViaCrossValidationExample.scala %} -
    - -
    -{% include_example java/org/apache/spark/examples/ml/JavaModelSelectionViaCrossValidationExample.java %} -
    - -
    - -{% include_example python/ml/cross_validator.py %} -
    - -
    - -## Example: model selection via train validation split -In addition to `CrossValidator` Spark also offers `TrainValidationSplit` for hyper-parameter tuning. -`TrainValidationSplit` only evaluates each combination of parameters once as opposed to k times in - case of `CrossValidator`. It is therefore less expensive, - but will not produce as reliable results when the training dataset is not sufficiently large. - -`TrainValidationSplit` takes an `Estimator`, a set of `ParamMap`s provided in the `estimatorParamMaps` parameter, -and an `Evaluator`. -It begins by splitting the dataset into two parts using `trainRatio` parameter -which are used as separate training and test datasets. For example with `$trainRatio=0.75$` (default), -`TrainValidationSplit` will generate a training and test dataset pair where 75% of the data is used for training and 25% for validation. -Similar to `CrossValidator`, `TrainValidationSplit` also iterates through the set of `ParamMap`s. -For each combination of parameters, it trains the given `Estimator` and evaluates it using the given `Evaluator`. -The `ParamMap` which produces the best evaluation metric is selected as the best option. -`TrainValidationSplit` finally fits the `Estimator` using the best `ParamMap` and the entire dataset. - -
    - -
    -{% include_example scala/org/apache/spark/examples/ml/ModelSelectionViaTrainValidationSplitExample.scala %} +Refer to the [`MLUtils` Scala docs](api/scala/index.html#org.apache.spark.mllib.util.MLUtils$) for further detail.
    -{% include_example java/org/apache/spark/examples/ml/JavaModelSelectionViaTrainValidationSplitExample.java %} -
    -
    -{% include_example python/ml/train_validation_split.py %} -
    +{% highlight java %} +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.sql.Dataset; + +// convert DataFrame columns +Dataset convertedVecDF = MLUtils.convertVectorColumnsToML(vecDF); +Dataset convertedMatrixDF = MLUtils.convertMatrixColumnsToML(matrixDF); +// convert a single vector or matrix +org.apache.spark.ml.linalg.Vector mlVec = mllibVec.asML(); +org.apache.spark.ml.linalg.Matrix mlMat = mllibMat.asML(); +{% endhighlight %} + +Refer to the [`MLUtils` Java docs](api/java/org/apache/spark/mllib/util/MLUtils.html) for further detail. +
    + +
    + +{% highlight python %} +from pyspark.mllib.util import MLUtils + +# convert DataFrame columns +convertedVecDF = MLUtils.convertVectorColumnsToML(vecDF) +convertedMatrixDF = MLUtils.convertMatrixColumnsToML(matrixDF) +# convert a single vector or matrix +mlVec = mllibVec.asML() +mlMat = mllibMat.asML() +{% endhighlight %} + +Refer to the [`MLUtils` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.util.MLUtils) for further detail. +
    +
    + +**Deprecated methods removed** + +Several deprecated methods were removed in the `spark.mllib` and `spark.ml` packages: + +* `setScoreCol` in `ml.evaluation.BinaryClassificationEvaluator` +* `weights` in `LinearRegression` and `LogisticRegression` in `spark.ml` +* `setMaxNumIterations` in `mllib.optimization.LBFGS` (marked as `DeveloperApi`) +* `treeReduce` and `treeAggregate` in `mllib.rdd.RDDFunctions` (these functions are available on `RDD`s directly, and were marked as `DeveloperApi`) +* `defaultStategy` in `mllib.tree.configuration.Strategy` +* `build` in `mllib.tree.Node` +* libsvm loaders for multiclass and load/save labeledData methods in `mllib.util.MLUtils` + +A full list of breaking changes can be found at [SPARK-14810](https://issues.apache.org/jira/browse/SPARK-14810). + +### Deprecations and changes of behavior + +**Deprecations** + +Deprecations in the `spark.mllib` and `spark.ml` packages include: + +* [SPARK-14984](https://issues.apache.org/jira/browse/SPARK-14984): + In `spark.ml.regression.LinearRegressionSummary`, the `model` field has been deprecated. +* [SPARK-13784](https://issues.apache.org/jira/browse/SPARK-13784): + In `spark.ml.regression.RandomForestRegressionModel` and `spark.ml.classification.RandomForestClassificationModel`, + the `numTrees` parameter has been deprecated in favor of `getNumTrees` method. +* [SPARK-13761](https://issues.apache.org/jira/browse/SPARK-13761): + In `spark.ml.param.Params`, the `validateParams` method has been deprecated. + We move all functionality in overridden methods to the corresponding `transformSchema`. +* [SPARK-14829](https://issues.apache.org/jira/browse/SPARK-14829): + In `spark.mllib` package, `LinearRegressionWithSGD`, `LassoWithSGD`, `RidgeRegressionWithSGD` and `LogisticRegressionWithSGD` have been deprecated. + We encourage users to use `spark.ml.regression.LinearRegresson` and `spark.ml.classification.LogisticRegresson`. +* [SPARK-14900](https://issues.apache.org/jira/browse/SPARK-14900): + In `spark.mllib.evaluation.MulticlassMetrics`, the parameters `precision`, `recall` and `fMeasure` have been deprecated in favor of `accuracy`. +* [SPARK-15644](https://issues.apache.org/jira/browse/SPARK-15644): + In `spark.ml.util.MLReader` and `spark.ml.util.MLWriter`, the `context` method has been deprecated in favor of `session`. +* In `spark.ml.feature.ChiSqSelectorModel`, the `setLabelCol` method has been deprecated since it was not used by `ChiSqSelectorModel`. + +**Changes of behavior** + +Changes of behavior in the `spark.mllib` and `spark.ml` packages include: + +* [SPARK-7780](https://issues.apache.org/jira/browse/SPARK-7780): + `spark.mllib.classification.LogisticRegressionWithLBFGS` directly calls `spark.ml.classification.LogisticRegresson` for binary classification now. + This will introduce the following behavior changes for `spark.mllib.classification.LogisticRegressionWithLBFGS`: + * The intercept will not be regularized when training binary classification model with L1/L2 Updater. + * If users set without regularization, training with or without feature scaling will return the same solution by the same convergence rate. +* [SPARK-13429](https://issues.apache.org/jira/browse/SPARK-13429): + In order to provide better and consistent result with `spark.ml.classification.LogisticRegresson`, + the default value of `spark.mllib.classification.LogisticRegressionWithLBFGS`: `convergenceTol` has been changed from 1E-4 to 1E-6. 
+* [SPARK-12363](https://issues.apache.org/jira/browse/SPARK-12363): + Fix a bug of `PowerIterationClustering` which will likely change its result. +* [SPARK-13048](https://issues.apache.org/jira/browse/SPARK-13048): + `LDA` using the `EM` optimizer will keep the last checkpoint by default, if checkpointing is being used. +* [SPARK-12153](https://issues.apache.org/jira/browse/SPARK-12153): + `Word2Vec` now respects sentence boundaries. Previously, it did not handle them correctly. +* [SPARK-10574](https://issues.apache.org/jira/browse/SPARK-10574): + `HashingTF` uses `MurmurHash3` as default hash algorithm in both `spark.ml` and `spark.mllib`. +* [SPARK-14768](https://issues.apache.org/jira/browse/SPARK-14768): + The `expectedType` argument for PySpark `Param` was removed. +* [SPARK-14931](https://issues.apache.org/jira/browse/SPARK-14931): + Some default `Param` values, which were mismatched between pipelines in Scala and Python, have been changed. +* [SPARK-13600](https://issues.apache.org/jira/browse/SPARK-13600): + `QuantileDiscretizer` now uses `spark.sql.DataFrameStatFunctions.approxQuantile` to find splits (previously used custom sampling logic). + The output buckets will differ for same input data and params. + +## Previous Spark versions + +Earlier migration guides are archived [on this page](ml-migration-guides.html). -
    +--- diff --git a/docs/ml-linear-methods.md b/docs/ml-linear-methods.md index a8754835cab95..eb39173505aed 100644 --- a/docs/ml-linear-methods.md +++ b/docs/ml-linear-methods.md @@ -1,7 +1,7 @@ --- layout: global -title: Linear methods - spark.ml -displayTitle: Linear methods - spark.ml +title: Linear methods +displayTitle: Linear methods --- > This section has been moved into the diff --git a/docs/ml-migration-guides.md b/docs/ml-migration-guides.md new file mode 100644 index 0000000000000..82bf9d7760fb4 --- /dev/null +++ b/docs/ml-migration-guides.md @@ -0,0 +1,159 @@ +--- +layout: global +title: Old Migration Guides - MLlib +displayTitle: Old Migration Guides - MLlib +description: MLlib migration guides from before Spark SPARK_VERSION_SHORT +--- + +The migration guide for the current Spark version is kept on the [MLlib Guide main page](ml-guide.html#migration-guide). + +## From 1.5 to 1.6 + +There are no breaking API changes in the `spark.mllib` or `spark.ml` packages, but there are +deprecations and changes of behavior. + +Deprecations: + +* [SPARK-11358](https://issues.apache.org/jira/browse/SPARK-11358): + In `spark.mllib.clustering.KMeans`, the `runs` parameter has been deprecated. +* [SPARK-10592](https://issues.apache.org/jira/browse/SPARK-10592): + In `spark.ml.classification.LogisticRegressionModel` and + `spark.ml.regression.LinearRegressionModel`, the `weights` field has been deprecated in favor of + the new name `coefficients`. This helps disambiguate from instance (row) "weights" given to + algorithms. + +Changes of behavior: + +* [SPARK-7770](https://issues.apache.org/jira/browse/SPARK-7770): + `spark.mllib.tree.GradientBoostedTrees`: `validationTol` has changed semantics in 1.6. + Previously, it was a threshold for absolute change in error. Now, it resembles the behavior of + `GradientDescent`'s `convergenceTol`: For large errors, it uses relative error (relative to the + previous error); for small errors (`< 0.01`), it uses absolute error. +* [SPARK-11069](https://issues.apache.org/jira/browse/SPARK-11069): + `spark.ml.feature.RegexTokenizer`: Previously, it did not convert strings to lowercase before + tokenizing. Now, it converts to lowercase by default, with an option not to. This matches the + behavior of the simpler `Tokenizer` transformer. + +## From 1.4 to 1.5 + +In the `spark.mllib` package, there are no breaking API changes but several behavior changes: + +* [SPARK-9005](https://issues.apache.org/jira/browse/SPARK-9005): + `RegressionMetrics.explainedVariance` returns the average regression sum of squares. +* [SPARK-8600](https://issues.apache.org/jira/browse/SPARK-8600): `NaiveBayesModel.labels` become + sorted. +* [SPARK-3382](https://issues.apache.org/jira/browse/SPARK-3382): `GradientDescent` has a default + convergence tolerance `1e-3`, and hence iterations might end earlier than 1.4. + +In the `spark.ml` package, there exists one breaking API change and one behavior change: + +* [SPARK-9268](https://issues.apache.org/jira/browse/SPARK-9268): Java's varargs support is removed + from `Params.setDefault` due to a + [Scala compiler bug](https://issues.scala-lang.org/browse/SI-9013). +* [SPARK-10097](https://issues.apache.org/jira/browse/SPARK-10097): `Evaluator.isLargerBetter` is + added to indicate metric ordering. Metrics like RMSE no longer flip signs as in 1.4. 
+
+## From 1.3 to 1.4
+
+In the `spark.mllib` package, there were several breaking changes, but all in `DeveloperApi` or `Experimental` APIs:
+
+* Gradient-Boosted Trees
+  * *(Breaking change)* The signature of the [`Loss.gradient`](api/scala/index.html#org.apache.spark.mllib.tree.loss.Loss) method was changed. This is only an issue for users who wrote their own losses for GBTs.
+  * *(Breaking change)* The `apply` and `copy` methods for the case class [`BoostingStrategy`](api/scala/index.html#org.apache.spark.mllib.tree.configuration.BoostingStrategy) have been changed because of a modification to the case class fields. This could be an issue for users who use `BoostingStrategy` to set GBT parameters.
+* *(Breaking change)* The return value of [`LDA.run`](api/scala/index.html#org.apache.spark.mllib.clustering.LDA) has changed. It now returns an abstract class `LDAModel` instead of the concrete class `DistributedLDAModel`. The object of type `LDAModel` can still be cast to the appropriate concrete type, which depends on the optimization algorithm.
+
+In the `spark.ml` package, several major API changes occurred, including:
+
+* `Param` and other APIs for specifying parameters
+* `uid` unique IDs for Pipeline components
+* Reorganization of certain classes
+
+Since the `spark.ml` API was an alpha component in Spark 1.3, we do not list all changes here.
+However, since 1.4 `spark.ml` is no longer an alpha component, we will provide details on any API
+changes for future releases.
+
+## From 1.2 to 1.3
+
+In the `spark.mllib` package, there were several breaking changes. The first change (in `ALS`) is the only one in a component not marked as Alpha or Experimental.
+
+* *(Breaking change)* In [`ALS`](api/scala/index.html#org.apache.spark.mllib.recommendation.ALS), the extraneous method `solveLeastSquares` has been removed. The `DeveloperApi` method `analyzeBlocks` was also removed.
+* *(Breaking change)* [`StandardScalerModel`](api/scala/index.html#org.apache.spark.mllib.feature.StandardScalerModel) remains an Alpha component. In it, the `variance` method has been replaced with the `std` method. To compute the column variance values returned by the original `variance` method, simply square the standard deviation values returned by `std`.
+* *(Breaking change)* [`StreamingLinearRegressionWithSGD`](api/scala/index.html#org.apache.spark.mllib.regression.StreamingLinearRegressionWithSGD) remains an Experimental component. In it, there were two changes:
+  * The constructor taking arguments was removed in favor of a builder pattern using the default constructor plus parameter setter methods.
+  * Variable `model` is no longer public.
+* *(Breaking change)* [`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree) remains an Experimental component. In it and its associated classes, there were several changes:
+  * In `DecisionTree`, the deprecated class method `train` has been removed. (The object/static `train` methods remain.)
+  * In `Strategy`, the `checkpointDir` parameter has been removed. Checkpointing is still supported, but the checkpoint directory must be set before calling tree and tree ensemble training.
+* `PythonMLlibAPI` (the interface between Scala/Java and Python for MLlib) was a public API but is now private, declared `private[python]`. This was never meant for external use.
+* In linear regression (including Lasso and ridge regression), the squared loss is now divided by 2.
+ So in order to produce the same result as in 1.2, the regularization parameter needs to be divided by 2 and the step size needs to be multiplied by 2. + +In the `spark.ml` package, the main API changes are from Spark SQL. We list the most important changes here: + +* The old [SchemaRDD](http://spark.apache.org/docs/1.2.1/api/scala/index.html#org.apache.spark.sql.SchemaRDD) has been replaced with [DataFrame](api/scala/index.html#org.apache.spark.sql.DataFrame) with a somewhat modified API. All algorithms in `spark.ml` which used to use SchemaRDD now use DataFrame. +* In Spark 1.2, we used implicit conversions from `RDD`s of `LabeledPoint` into `SchemaRDD`s by calling `import sqlContext._` where `sqlContext` was an instance of `SQLContext`. These implicits have been moved, so we now call `import sqlContext.implicits._`. +* Java APIs for SQL have also changed accordingly. Please see the examples above and the [Spark SQL Programming Guide](sql-programming-guide.html) for details. + +Other changes were in `LogisticRegression`: + +* The `scoreCol` output column (with default value "score") was renamed to be `probabilityCol` (with default value "probability"). The type was originally `Double` (for the probability of class 1.0), but it is now `Vector` (for the probability of each class, to support multiclass classification in the future). +* In Spark 1.2, `LogisticRegressionModel` did not include an intercept. In Spark 1.3, it includes an intercept; however, it will always be 0.0 since it uses the default settings for [spark.mllib.LogisticRegressionWithLBFGS](api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS). The option to use an intercept will be added in the future. + +## From 1.1 to 1.2 + +The only API changes in MLlib v1.2 are in +[`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree), +which continues to be an experimental API in MLlib 1.2: + +1. *(Breaking change)* The Scala API for classification takes a named argument specifying the number +of classes. In MLlib v1.1, this argument was called `numClasses` in Python and +`numClassesForClassification` in Scala. In MLlib v1.2, the names are both set to `numClasses`. +This `numClasses` parameter is specified either via +[`Strategy`](api/scala/index.html#org.apache.spark.mllib.tree.configuration.Strategy) +or via [`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree) +static `trainClassifier` and `trainRegressor` methods. + +2. *(Breaking change)* The API for +[`Node`](api/scala/index.html#org.apache.spark.mllib.tree.model.Node) has changed. +This should generally not affect user code, unless the user manually constructs decision trees +(instead of using the `trainClassifier` or `trainRegressor` methods). +The tree `Node` now includes more information, including the probability of the predicted label +(for classification). + +3. Printing methods' output has changed. The `toString` (Scala/Java) and `__repr__` (Python) methods used to print the full model; they now print a summary. For the full model, use `toDebugString`. + +Examples in the Spark distribution and examples in the +[Decision Trees Guide](mllib-decision-tree.html#examples) have been updated accordingly. + +## From 1.0 to 1.1 + +The only API changes in MLlib v1.1 are in +[`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree), +which continues to be an experimental API in MLlib 1.1: + +1. 
*(Breaking change)* The meaning of tree depth has been changed by 1 in order to match +the implementations of trees in +[scikit-learn](http://scikit-learn.org/stable/modules/classes.html#module-sklearn.tree) +and in [rpart](http://cran.r-project.org/web/packages/rpart/index.html). +In MLlib v1.0, a depth-1 tree had 1 leaf node, and a depth-2 tree had 1 root node and 2 leaf nodes. +In MLlib v1.1, a depth-0 tree has 1 leaf node, and a depth-1 tree has 1 root node and 2 leaf nodes. +This depth is specified by the `maxDepth` parameter in +[`Strategy`](api/scala/index.html#org.apache.spark.mllib.tree.configuration.Strategy) +or via [`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree) +static `trainClassifier` and `trainRegressor` methods. + +2. *(Non-breaking change)* We recommend using the newly added `trainClassifier` and `trainRegressor` +methods to build a [`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree), +rather than using the old parameter class `Strategy`. These new training methods explicitly +separate classification and regression, and they replace specialized parameter types with +simple `String` types. + +Examples of the new, recommended `trainClassifier` and `trainRegressor` are given in the +[Decision Trees Guide](mllib-decision-tree.html#examples). + +## From 0.9 to 1.0 + +In MLlib v1.0, we support both dense and sparse input in a unified way, which introduces a few +breaking changes. If your data is sparse, please store it in a sparse format instead of dense to +take advantage of sparsity in both storage and computation. Details are described below. + diff --git a/docs/ml-pipeline.md b/docs/ml-pipeline.md new file mode 100644 index 0000000000000..adb057ba7e250 --- /dev/null +++ b/docs/ml-pipeline.md @@ -0,0 +1,245 @@ +--- +layout: global +title: ML Pipelines +displayTitle: ML Pipelines +--- + +`\[ +\newcommand{\R}{\mathbb{R}} +\newcommand{\E}{\mathbb{E}} +\newcommand{\x}{\mathbf{x}} +\newcommand{\y}{\mathbf{y}} +\newcommand{\wv}{\mathbf{w}} +\newcommand{\av}{\mathbf{\alpha}} +\newcommand{\bv}{\mathbf{b}} +\newcommand{\N}{\mathbb{N}} +\newcommand{\id}{\mathbf{I}} +\newcommand{\ind}{\mathbf{1}} +\newcommand{\0}{\mathbf{0}} +\newcommand{\unit}{\mathbf{e}} +\newcommand{\one}{\mathbf{1}} +\newcommand{\zero}{\mathbf{0}} +\]` + +In this section, we introduce the concept of ***ML Pipelines***. +ML Pipelines provide a uniform set of high-level APIs built on top of +[DataFrames](sql-programming-guide.html) that help users create and tune practical +machine learning pipelines. + +**Table of Contents** + +* This will become a table of contents (this text will be scraped). +{:toc} + +# Main concepts in Pipelines + +MLlib standardizes APIs for machine learning algorithms to make it easier to combine multiple +algorithms into a single pipeline, or workflow. +This section covers the key concepts introduced by the Pipelines API, where the pipeline concept is +mostly inspired by the [scikit-learn](http://scikit-learn.org/) project. + +* **[`DataFrame`](ml-guide.html#dataframe)**: This ML API uses `DataFrame` from Spark SQL as an ML + dataset, which can hold a variety of data types. + E.g., a `DataFrame` could have different columns storing text, feature vectors, true labels, and predictions. + +* **[`Transformer`](ml-guide.html#transformers)**: A `Transformer` is an algorithm which can transform one `DataFrame` into another `DataFrame`. 
+E.g., an ML model is a `Transformer` which transforms a `DataFrame` with features into a `DataFrame` with predictions. + +* **[`Estimator`](ml-guide.html#estimators)**: An `Estimator` is an algorithm which can be fit on a `DataFrame` to produce a `Transformer`. +E.g., a learning algorithm is an `Estimator` which trains on a `DataFrame` and produces a model. + +* **[`Pipeline`](ml-guide.html#pipeline)**: A `Pipeline` chains multiple `Transformer`s and `Estimator`s together to specify an ML workflow. + +* **[`Parameter`](ml-guide.html#parameters)**: All `Transformer`s and `Estimator`s now share a common API for specifying parameters. + +## DataFrame + +Machine learning can be applied to a wide variety of data types, such as vectors, text, images, and structured data. +This API adopts the `DataFrame` from Spark SQL in order to support a variety of data types. + +`DataFrame` supports many basic and structured types; see the [Spark SQL datatype reference](sql-programming-guide.html#spark-sql-datatype-reference) for a list of supported types. +In addition to the types listed in the Spark SQL guide, `DataFrame` can use ML [`Vector`](mllib-data-types.html#local-vector) types. + +A `DataFrame` can be created either implicitly or explicitly from a regular `RDD`. See the code examples below and the [Spark SQL programming guide](sql-programming-guide.html) for examples. + +Columns in a `DataFrame` are named. The code examples below use names such as "text," "features," and "label." + +## Pipeline components + +### Transformers + +A `Transformer` is an abstraction that includes feature transformers and learned models. +Technically, a `Transformer` implements a method `transform()`, which converts one `DataFrame` into +another, generally by appending one or more columns. +For example: + +* A feature transformer might take a `DataFrame`, read a column (e.g., text), map it into a new + column (e.g., feature vectors), and output a new `DataFrame` with the mapped column appended. +* A learning model might take a `DataFrame`, read the column containing feature vectors, predict the + label for each feature vector, and output a new `DataFrame` with predicted labels appended as a + column. + +### Estimators + +An `Estimator` abstracts the concept of a learning algorithm or any algorithm that fits or trains on +data. +Technically, an `Estimator` implements a method `fit()`, which accepts a `DataFrame` and produces a +`Model`, which is a `Transformer`. +For example, a learning algorithm such as `LogisticRegression` is an `Estimator`, and calling +`fit()` trains a `LogisticRegressionModel`, which is a `Model` and hence a `Transformer`. + +### Properties of pipeline components + +`Transformer.transform()`s and `Estimator.fit()`s are both stateless. In the future, stateful algorithms may be supported via alternative concepts. + +Each instance of a `Transformer` or `Estimator` has a unique ID, which is useful in specifying parameters (discussed below). + +## Pipeline + +In machine learning, it is common to run a sequence of algorithms to process and learn from data. +E.g., a simple text document processing workflow might include several stages: + +* Split each document's text into words. +* Convert each document's words into a numerical feature vector. +* Learn a prediction model using the feature vectors and labels. + +MLlib represents such a workflow as a `Pipeline`, which consists of a sequence of +`PipelineStage`s (`Transformer`s and `Estimator`s) to be run in a specific order. 
+We will use this simple workflow as a running example in this section. + +### How it works + +A `Pipeline` is specified as a sequence of stages, and each stage is either a `Transformer` or an `Estimator`. +These stages are run in order, and the input `DataFrame` is transformed as it passes through each stage. +For `Transformer` stages, the `transform()` method is called on the `DataFrame`. +For `Estimator` stages, the `fit()` method is called to produce a `Transformer` (which becomes part of the `PipelineModel`, or fitted `Pipeline`), and that `Transformer`'s `transform()` method is called on the `DataFrame`. + +We illustrate this for the simple text document workflow. The figure below is for the *training time* usage of a `Pipeline`. + +

    + ML Pipeline Example +

    + +Above, the top row represents a `Pipeline` with three stages. +The first two (`Tokenizer` and `HashingTF`) are `Transformer`s (blue), and the third (`LogisticRegression`) is an `Estimator` (red). +The bottom row represents data flowing through the pipeline, where cylinders indicate `DataFrame`s. +The `Pipeline.fit()` method is called on the original `DataFrame`, which has raw text documents and labels. +The `Tokenizer.transform()` method splits the raw text documents into words, adding a new column with words to the `DataFrame`. +The `HashingTF.transform()` method converts the words column into feature vectors, adding a new column with those vectors to the `DataFrame`. +Now, since `LogisticRegression` is an `Estimator`, the `Pipeline` first calls `LogisticRegression.fit()` to produce a `LogisticRegressionModel`. +If the `Pipeline` had more stages, it would call the `LogisticRegressionModel`'s `transform()` +method on the `DataFrame` before passing the `DataFrame` to the next stage. + +A `Pipeline` is an `Estimator`. +Thus, after a `Pipeline`'s `fit()` method runs, it produces a `PipelineModel`, which is a +`Transformer`. +This `PipelineModel` is used at *test time*; the figure below illustrates this usage. + +

    + ML PipelineModel Example +

    + +In the figure above, the `PipelineModel` has the same number of stages as the original `Pipeline`, but all `Estimator`s in the original `Pipeline` have become `Transformer`s. +When the `PipelineModel`'s `transform()` method is called on a test dataset, the data are passed +through the fitted pipeline in order. +Each stage's `transform()` method updates the dataset and passes it to the next stage. + +`Pipeline`s and `PipelineModel`s help to ensure that training and test data go through identical feature processing steps. + +### Details + +*DAG `Pipeline`s*: A `Pipeline`'s stages are specified as an ordered array. The examples given here are all for linear `Pipeline`s, i.e., `Pipeline`s in which each stage uses data produced by the previous stage. It is possible to create non-linear `Pipeline`s as long as the data flow graph forms a Directed Acyclic Graph (DAG). This graph is currently specified implicitly based on the input and output column names of each stage (generally specified as parameters). If the `Pipeline` forms a DAG, then the stages must be specified in topological order. + +*Runtime checking*: Since `Pipeline`s can operate on `DataFrame`s with varied types, they cannot use +compile-time type checking. +`Pipeline`s and `PipelineModel`s instead do runtime checking before actually running the `Pipeline`. +This type checking is done using the `DataFrame` *schema*, a description of the data types of columns in the `DataFrame`. + +*Unique Pipeline stages*: A `Pipeline`'s stages should be unique instances. E.g., the same instance +`myHashingTF` should not be inserted into the `Pipeline` twice since `Pipeline` stages must have +unique IDs. However, different instances `myHashingTF1` and `myHashingTF2` (both of type `HashingTF`) +can be put into the same `Pipeline` since different instances will be created with different IDs. + +## Parameters + +MLlib `Estimator`s and `Transformer`s use a uniform API for specifying parameters. + +A `Param` is a named parameter with self-contained documentation. +A `ParamMap` is a set of (parameter, value) pairs. + +There are two main ways to pass parameters to an algorithm: + +1. Set parameters for an instance. E.g., if `lr` is an instance of `LogisticRegression`, one could + call `lr.setMaxIter(10)` to make `lr.fit()` use at most 10 iterations. + This API resembles the API used in `spark.mllib` package. +2. Pass a `ParamMap` to `fit()` or `transform()`. Any parameters in the `ParamMap` will override parameters previously specified via setter methods. + +Parameters belong to specific instances of `Estimator`s and `Transformer`s. +For example, if we have two `LogisticRegression` instances `lr1` and `lr2`, then we can build a `ParamMap` with both `maxIter` parameters specified: `ParamMap(lr1.maxIter -> 10, lr2.maxIter -> 20)`. +This is useful if there are two algorithms with the `maxIter` parameter in a `Pipeline`. + +## Saving and Loading Pipelines + +Often times it is worth it to save a model or a pipeline to disk for later use. In Spark 1.6, a model import/export functionality was added to the Pipeline API. Most basic transformers are supported as well as some of the more basic ML models. Please refer to the algorithm's API documentation to see if saving and loading is supported. + +# Code examples + +This section gives code examples illustrating the functionality discussed above. 
+For more information, please refer to the API documentation
+([Scala](api/scala/index.html#org.apache.spark.ml.package),
+[Java](api/java/org/apache/spark/ml/package-summary.html),
+and [Python](api/python/pyspark.ml.html)).
+
+## Example: Estimator, Transformer, and Param
+
+This example covers the concepts of `Estimator`, `Transformer`, and `Param`.
+
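+In condensed form, the two ways of passing parameters look roughly like the sketch below
+(a hedged outline only; `training` is a hypothetical `DataFrame` of labeled data, and the
+complete runnable versions are in the language tabs that follow):
+
+{% highlight scala %}
+import org.apache.spark.ml.classification.LogisticRegression
+import org.apache.spark.ml.param.ParamMap
+
+val lr = new LogisticRegression()
+// 1. Set parameters for this instance: lr.fit() will now use at most 10 iterations.
+lr.setMaxIter(10).setRegParam(0.01)
+
+// 2. Override parameters at fit time by passing a ParamMap.
+val paramMap = ParamMap(lr.maxIter -> 20, lr.regParam -> 0.1)
+val model = lr.fit(training, paramMap)  // ParamMap values win over the setters above
+{% endhighlight %}
+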
+<div class="codetabs">
+
+<div data-lang="scala" markdown="1">
+{% include_example scala/org/apache/spark/examples/ml/EstimatorTransformerParamExample.scala %}
+</div>
+
+<div data-lang="java" markdown="1">
+{% include_example java/org/apache/spark/examples/ml/JavaEstimatorTransformerParamExample.java %}
+</div>
+
+<div data-lang="python" markdown="1">
+{% include_example python/ml/estimator_transformer_param_example.py %}
+</div>
+
+</div>
+
+## Example: Pipeline
+
+This example follows the simple text document `Pipeline` illustrated in the figures above.
+
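+The sketch below mirrors the figures: two `Transformer`s feeding an `Estimator`, assembled into a
+`Pipeline` (a hedged sketch; `training` and `test` are hypothetical `DataFrame`s with "text" and
+"label" columns, and the full runnable versions follow in the language tabs):
+
+{% highlight scala %}
+import org.apache.spark.ml.Pipeline
+import org.apache.spark.ml.classification.LogisticRegression
+import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
+
+// Transformer stages (blue in the figure) and the Estimator stage (red).
+val tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words")
+val hashingTF = new HashingTF().setInputCol(tokenizer.getOutputCol).setOutputCol("features")
+val lr = new LogisticRegression().setMaxIter(10)
+
+val pipeline = new Pipeline().setStages(Array(tokenizer, hashingTF, lr))
+
+// fit() produces a PipelineModel (a Transformer) that can be reused at test time.
+val model = pipeline.fit(training)
+val predictions = model.transform(test)
+{% endhighlight %}
+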
+<div class="codetabs">
+
+<div data-lang="scala" markdown="1">
+{% include_example scala/org/apache/spark/examples/ml/PipelineExample.scala %}
+</div>
+
+<div data-lang="java" markdown="1">
+{% include_example java/org/apache/spark/examples/ml/JavaPipelineExample.java %}
+</div>
+
+<div data-lang="python" markdown="1">
+{% include_example python/ml/pipeline_example.py %}
+</div>
+
+</div>
    + +## Model selection (hyperparameter tuning) + +A big benefit of using ML Pipelines is hyperparameter optimization. See the [ML Tuning Guide](ml-tuning.html) for more information on automatic model selection. diff --git a/docs/ml-survival-regression.md b/docs/ml-survival-regression.md index 856ceb2f4e7f6..efa3c21c7ca1b 100644 --- a/docs/ml-survival-regression.md +++ b/docs/ml-survival-regression.md @@ -1,7 +1,7 @@ --- layout: global -title: Survival Regression - spark.ml -displayTitle: Survival Regression - spark.ml +title: Survival Regression +displayTitle: Survival Regression --- > This section has been moved into the diff --git a/docs/ml-tuning.md b/docs/ml-tuning.md new file mode 100644 index 0000000000000..2ca90c7092fd3 --- /dev/null +++ b/docs/ml-tuning.md @@ -0,0 +1,121 @@ +--- +layout: global +title: "ML Tuning" +displayTitle: "ML Tuning: model selection and hyperparameter tuning" +--- + +`\[ +\newcommand{\R}{\mathbb{R}} +\newcommand{\E}{\mathbb{E}} +\newcommand{\x}{\mathbf{x}} +\newcommand{\y}{\mathbf{y}} +\newcommand{\wv}{\mathbf{w}} +\newcommand{\av}{\mathbf{\alpha}} +\newcommand{\bv}{\mathbf{b}} +\newcommand{\N}{\mathbb{N}} +\newcommand{\id}{\mathbf{I}} +\newcommand{\ind}{\mathbf{1}} +\newcommand{\0}{\mathbf{0}} +\newcommand{\unit}{\mathbf{e}} +\newcommand{\one}{\mathbf{1}} +\newcommand{\zero}{\mathbf{0}} +\]` + +This section describes how to use MLlib's tooling for tuning ML algorithms and Pipelines. +Built-in Cross-Validation and other tooling allow users to optimize hyperparameters in algorithms and Pipelines. + +**Table of contents** + +* This will become a table of contents (this text will be scraped). +{:toc} + +# Model selection (a.k.a. hyperparameter tuning) + +An important task in ML is *model selection*, or using data to find the best model or parameters for a given task. This is also called *tuning*. +Tuning may be done for individual `Estimator`s such as `LogisticRegression`, or for entire `Pipeline`s which include multiple algorithms, featurization, and other steps. Users can tune an entire `Pipeline` at once, rather than tuning each element in the `Pipeline` separately. + +MLlib supports model selection using tools such as [`CrossValidator`](api/scala/index.html#org.apache.spark.ml.tuning.CrossValidator) and [`TrainValidationSplit`](api/scala/index.html#org.apache.spark.ml.tuning.TrainValidationSplit). +These tools require the following items: + +* [`Estimator`](api/scala/index.html#org.apache.spark.ml.Estimator): algorithm or `Pipeline` to tune +* Set of `ParamMap`s: parameters to choose from, sometimes called a "parameter grid" to search over +* [`Evaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.Evaluator): metric to measure how well a fitted `Model` does on held-out test data + +At a high level, these model selection tools work as follows: + +* They split the input data into separate training and test datasets. +* For each (training, test) pair, they iterate through the set of `ParamMap`s: + * For each `ParamMap`, they fit the `Estimator` using those parameters, get the fitted `Model`, and evaluate the `Model`'s performance using the `Evaluator`. +* They select the `Model` produced by the best-performing set of parameters. 
+ +The `Evaluator` can be a [`RegressionEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.RegressionEvaluator) +for regression problems, a [`BinaryClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.BinaryClassificationEvaluator) +for binary data, or a [`MulticlassClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator) +for multiclass problems. The default metric used to choose the best `ParamMap` can be overridden by the `setMetricName` +method in each of these evaluators. + +To help construct the parameter grid, users can use the [`ParamGridBuilder`](api/scala/index.html#org.apache.spark.ml.tuning.ParamGridBuilder) utility. + +# Cross-Validation + +`CrossValidator` begins by splitting the dataset into a set of *folds* which are used as separate training and test datasets. E.g., with `$k=3$` folds, `CrossValidator` will generate 3 (training, test) dataset pairs, each of which uses 2/3 of the data for training and 1/3 for testing. To evaluate a particular `ParamMap`, `CrossValidator` computes the average evaluation metric for the 3 `Model`s produced by fitting the `Estimator` on the 3 different (training, test) dataset pairs. + +After identifying the best `ParamMap`, `CrossValidator` finally re-fits the `Estimator` using the best `ParamMap` and the entire dataset. + +## Example: model selection via cross-validation + +The following example demonstrates using `CrossValidator` to select from a grid of parameters. + +Note that cross-validation over a grid of parameters is expensive. +E.g., in the example below, the parameter grid has 3 values for `hashingTF.numFeatures` and 2 values for `lr.regParam`, and `CrossValidator` uses 2 folds. This multiplies out to `$(3 \times 2) \times 2 = 12$` different models being trained. +In realistic settings, it can be common to try many more parameters and use more folds (`$k=3$` and `$k=10$` are common). +In other words, using `CrossValidator` can be very expensive. +However, it is also a well-established method for choosing parameters which is more statistically sound than heuristic hand-tuning. + +
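+In outline, wiring the three required items into a `CrossValidator` looks like the hedged sketch
+below (`pipeline`, `hashingTF`, and `lr` are hypothetical names, e.g. from the Pipeline example
+earlier; the complete examples follow in the language tabs):
+
+{% highlight scala %}
+import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
+import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
+
+// 3 values of numFeatures x 2 values of regParam = 6 ParamMaps to try.
+val paramGrid = new ParamGridBuilder()
+  .addGrid(hashingTF.numFeatures, Array(10, 100, 1000))
+  .addGrid(lr.regParam, Array(0.1, 0.01))
+  .build()
+
+val cv = new CrossValidator()
+  .setEstimator(pipeline)                            // the Estimator (here, a Pipeline)
+  .setEstimatorParamMaps(paramGrid)                  // the parameter grid
+  .setEvaluator(new BinaryClassificationEvaluator()) // the metric on held-out data
+  .setNumFolds(2)
+
+val cvModel = cv.fit(training)  // trains (6 ParamMaps) x (2 folds) = 12 models
+{% endhighlight %}
+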
+<div class="codetabs">
+
+<div data-lang="scala" markdown="1">
+{% include_example scala/org/apache/spark/examples/ml/ModelSelectionViaCrossValidationExample.scala %}
+</div>
+
+<div data-lang="java" markdown="1">
+{% include_example java/org/apache/spark/examples/ml/JavaModelSelectionViaCrossValidationExample.java %}
+</div>
+
+<div data-lang="python" markdown="1">
+{% include_example python/ml/cross_validator.py %}
+</div>
+
+</div>
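+After fitting, the model trained with the best-performing `ParamMap` can be retrieved and applied
+directly (a brief hedged usage note; `cvModel` and `test` as in the sketch above):
+
+{% highlight scala %}
+val bestModel = cvModel.bestModel           // the model re-fit on the entire dataset
+val predictions = bestModel.transform(test)
+{% endhighlight %}
+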
+
+# Train-Validation Split
+
+In addition to `CrossValidator`, Spark also offers `TrainValidationSplit` for hyperparameter tuning.
+`TrainValidationSplit` only evaluates each combination of parameters once, as opposed to k times in
+the case of `CrossValidator`. It is therefore less expensive,
+but will not produce as reliable results when the training dataset is not sufficiently large.
+
+Unlike `CrossValidator`, `TrainValidationSplit` creates a single (training, test) dataset pair.
+It splits the dataset into these two parts using the `trainRatio` parameter. For example, with `$trainRatio=0.75$`,
+`TrainValidationSplit` will generate a training and test dataset pair where 75% of the data is used for training and 25% for validation.
+
+Like `CrossValidator`, `TrainValidationSplit` finally fits the `Estimator` using the best `ParamMap` and the entire dataset.
+
+## Example: model selection via train validation split
+
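+In condensed form (a hedged sketch; `lr`, `paramGrid`, and `data` are hypothetical names, and the
+full runnable versions follow in the language tabs):
+
+{% highlight scala %}
+import org.apache.spark.ml.evaluation.RegressionEvaluator
+import org.apache.spark.ml.tuning.TrainValidationSplit
+
+val tvs = new TrainValidationSplit()
+  .setEstimator(lr)                      // e.g. a LinearRegression instance
+  .setEstimatorParamMaps(paramGrid)      // a grid built with ParamGridBuilder
+  .setEvaluator(new RegressionEvaluator())
+  .setTrainRatio(0.75)                   // 75% training, 25% validation
+
+val tvsModel = tvs.fit(data)             // fits each ParamMap once, then re-fits the best
+{% endhighlight %}
+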
+<div class="codetabs">
+
+<div data-lang="scala" markdown="1">
+{% include_example scala/org/apache/spark/examples/ml/ModelSelectionViaTrainValidationSplitExample.scala %}
+</div>
+
+<div data-lang="java" markdown="1">
+{% include_example java/org/apache/spark/examples/ml/JavaModelSelectionViaTrainValidationSplitExample.java %}
+</div>
+
+<div data-lang="python" markdown="1">
+{% include_example python/ml/train_validation_split.py %}
+</div>
+
+</div>
    diff --git a/docs/mllib-classification-regression.md b/docs/mllib-classification-regression.md index aaf8bd465c9ab..a7b90de09369c 100644 --- a/docs/mllib-classification-regression.md +++ b/docs/mllib-classification-regression.md @@ -1,7 +1,7 @@ --- layout: global -title: Classification and Regression - spark.mllib -displayTitle: Classification and Regression - spark.mllib +title: Classification and Regression - RDD-based API +displayTitle: Classification and Regression - RDD-based API --- The `spark.mllib` package supports various methods for diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index 6897ba4a5d57d..d5f6ae379a85e 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -1,7 +1,7 @@ --- layout: global -title: Clustering - spark.mllib -displayTitle: Clustering - spark.mllib +title: Clustering - RDD-based API +displayTitle: Clustering - RDD-based API --- [Clustering](https://en.wikipedia.org/wiki/Cluster_analysis) is an unsupervised learning problem whereby we aim to group subsets @@ -170,10 +170,6 @@ which contains the computed clustering assignments. Refer to the [`PowerIterationClustering` Scala docs](api/scala/index.html#org.apache.spark.mllib.clustering.PowerIterationClustering) and [`PowerIterationClusteringModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.clustering.PowerIterationClusteringModel) for details on the API. {% include_example scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala %} - -A full example that produces the experiment described in the PIC paper can be found under -[`examples/`](https://github.com/apache/spark/blob/master/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala). -
    diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md index 5c33292aaf086..0f891a09a6e61 100644 --- a/docs/mllib-collaborative-filtering.md +++ b/docs/mllib-collaborative-filtering.md @@ -1,7 +1,7 @@ --- layout: global -title: Collaborative Filtering - spark.mllib -displayTitle: Collaborative Filtering - spark.mllib +title: Collaborative Filtering - RDD-based API +displayTitle: Collaborative Filtering - RDD-based API --- * Table of contents diff --git a/docs/mllib-data-types.md b/docs/mllib-data-types.md index 5e3ee472a72c3..7dd3c97a83e4d 100644 --- a/docs/mllib-data-types.md +++ b/docs/mllib-data-types.md @@ -1,7 +1,7 @@ --- layout: global -title: Data Types - MLlib -displayTitle: Data Types - MLlib +title: Data Types - RDD-based API +displayTitle: Data Types - RDD-based API --- * Table of contents @@ -33,7 +33,7 @@ implementations: [`DenseVector`](api/scala/index.html#org.apache.spark.mllib.lin using the factory methods implemented in [`Vectors`](api/scala/index.html#org.apache.spark.mllib.linalg.Vectors$) to create local vectors. -Refer to the [`Vector` Scala docs](api/scala/index.html#org.apache.spark.mllib.linalg.Vector) and [`Vectors` Scala docs](api/scala/index.html#org.apache.spark.mllib.linalg.Vectors) for details on the API. +Refer to the [`Vector` Scala docs](api/scala/index.html#org.apache.spark.mllib.linalg.Vector) and [`Vectors` Scala docs](api/scala/index.html#org.apache.spark.mllib.linalg.Vectors$) for details on the API. {% highlight scala %} import org.apache.spark.mllib.linalg.{Vector, Vectors} @@ -199,7 +199,7 @@ After loading, the feature indices are converted to zero-based. [`MLUtils.loadLibSVMFile`](api/scala/index.html#org.apache.spark.mllib.util.MLUtils$) reads training examples stored in LIBSVM format. -Refer to the [`MLUtils` Scala docs](api/scala/index.html#org.apache.spark.mllib.util.MLUtils) for details on the API. +Refer to the [`MLUtils` Scala docs](api/scala/index.html#org.apache.spark.mllib.util.MLUtils$) for details on the API. {% highlight scala %} import org.apache.spark.mllib.regression.LabeledPoint @@ -264,7 +264,7 @@ We recommend using the factory methods implemented in [`Matrices`](api/scala/index.html#org.apache.spark.mllib.linalg.Matrices$) to create local matrices. Remember, local matrices in MLlib are stored in column-major order. -Refer to the [`Matrix` Scala docs](api/scala/index.html#org.apache.spark.mllib.linalg.Matrix) and [`Matrices` Scala docs](api/scala/index.html#org.apache.spark.mllib.linalg.Matrices) for details on the API. +Refer to the [`Matrix` Scala docs](api/scala/index.html#org.apache.spark.mllib.linalg.Matrix) and [`Matrices` Scala docs](api/scala/index.html#org.apache.spark.mllib.linalg.Matrices$) for details on the API. {% highlight scala %} import org.apache.spark.mllib.linalg.{Matrix, Matrices} @@ -314,12 +314,12 @@ matrices. Remember, local matrices in MLlib are stored in column-major order. Refer to the [`Matrix` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.linalg.Matrix) and [`Matrices` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.linalg.Matrices) for more details on the API. 
{% highlight python %} -import org.apache.spark.mllib.linalg.{Matrix, Matrices} +from pyspark.mllib.linalg import Matrix, Matrices -// Create a dense matrix ((1.0, 2.0), (3.0, 4.0), (5.0, 6.0)) +# Create a dense matrix ((1.0, 2.0), (3.0, 4.0), (5.0, 6.0)) dm2 = Matrices.dense(3, 2, [1, 2, 3, 4, 5, 6]) -// Create a sparse matrix ((9.0, 0.0), (0.0, 8.0), (0.0, 6.0)) +# Create a sparse matrix ((9.0, 0.0), (0.0, 8.0), (0.0, 6.0)) sm = Matrices.sparse(3, 2, [0, 1, 3], [0, 2, 1], [9, 6, 8]) {% endhighlight %}
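+For readers decoding the positional arguments of `Matrices.sparse`, here is a hedged Scala sketch
+spelling out the compressed sparse column (CSC) layout used by local matrices (the arguments mirror
+the Python call above):
+
+{% highlight scala %}
+import org.apache.spark.mllib.linalg.{Matrices, Matrix}
+
+// CSC layout of ((9.0, 0.0), (0.0, 8.0), (0.0, 6.0)):
+//   colPtrs    = [0, 1, 3] -> column 0 owns stored entries [0, 1); column 1 owns [1, 3)
+//   rowIndices = [0, 2, 1] -> the row of each stored entry
+//   values     = [9, 6, 8] -> the stored entries, column by column
+val sm: Matrix = Matrices.sparse(3, 2, Array(0, 1, 3), Array(0, 2, 1), Array(9.0, 6.0, 8.0))
+{% endhighlight %}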
@@ -331,7 +331,7 @@ sm = Matrices.sparse(3, 2, [0, 1, 3], [0, 2, 1], [9, 6, 8])

A distributed matrix has long-typed row and column indices and double-typed values, stored
distributively in one or more RDDs. It is very important to choose the right format to store large
and distributed matrices. Converting a distributed matrix to a different format may require a
-global shuffle, which is quite expensive. Three types of distributed matrices have been implemented
+global shuffle, which is quite expensive. Four types of distributed matrices have been implemented
 so far.

The basic type is called `RowMatrix`. A `RowMatrix` is a row-oriented distributed
@@ -344,6 +344,8 @@ An `IndexedRowMatrix` is similar to a `RowMatrix` but with row indices,
which can be used for identifying rows and executing joins.
A `CoordinateMatrix` is a distributed matrix stored in
[coordinate list (COO)](https://en.wikipedia.org/wiki/Sparse_matrix#Coordinate_list_.28COO.29) format,
backed by an RDD of its entries.
+A `BlockMatrix` is a distributed matrix backed by an RDD of `MatrixBlock`s,
+where a `MatrixBlock` is a tuple of `((Int, Int), Matrix)`.

***Note***

@@ -535,12 +537,6 @@ rowsRDD = mat.rows

 # Convert to a RowMatrix by dropping the row indices.
 rowMat = mat.toRowMatrix()
-
-# Convert to a CoordinateMatrix.
-coordinateMat = mat.toCoordinateMatrix()
-
-# Convert to a BlockMatrix.
-blockMat = mat.toBlockMatrix()
{% endhighlight %}
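+A short Scala sketch of building the fourth type, `BlockMatrix`, from a `CoordinateMatrix`
+(a hedged illustration; assumes an existing `SparkContext` named `sc`):
+
+{% highlight scala %}
+import org.apache.spark.mllib.linalg.distributed.{BlockMatrix, CoordinateMatrix, MatrixEntry}
+
+val entries = sc.parallelize(Seq(
+  MatrixEntry(0, 0, 1.0), MatrixEntry(1, 1, 2.0), MatrixEntry(2, 0, 3.0)))
+val coordMat = new CoordinateMatrix(entries)
+
+val blockMat: BlockMatrix = coordMat.toBlockMatrix().cache()
+blockMat.validate()  // verify the blocks are set up correctly
+{% endhighlight %}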
    diff --git a/docs/mllib-decision-tree.md b/docs/mllib-decision-tree.md index 9af48357b3dfc..0e753b8dd04a2 100644 --- a/docs/mllib-decision-tree.md +++ b/docs/mllib-decision-tree.md @@ -1,7 +1,7 @@ --- layout: global -title: Decision Trees - spark.mllib -displayTitle: Decision Trees - spark.mllib +title: Decision Trees - RDD-based API +displayTitle: Decision Trees - RDD-based API --- * Table of contents @@ -136,7 +136,7 @@ When tuning these parameters, be careful to validate on held-out test data to av * **`maxDepth`**: Maximum depth of a tree. Deeper trees are more expressive (potentially allowing higher accuracy), but they are also more costly to train and are more likely to overfit. -* **`minInstancesPerNode`**: For a node to be split further, each of its children must receive at least this number of training instances. This is commonly used with [RandomForest](api/scala/index.html#org.apache.spark.mllib.tree.RandomForest) since those are often trained deeper than individual trees. +* **`minInstancesPerNode`**: For a node to be split further, each of its children must receive at least this number of training instances. This is commonly used with [RandomForest](api/scala/index.html#org.apache.spark.mllib.tree.RandomForest$) since those are often trained deeper than individual trees. * **`minInfoGain`**: For a node to be split further, the split must improve at least this much (in terms of information gain). @@ -152,13 +152,13 @@ These parameters may be tuned. Be careful to validate on held-out test data whe * The default value is conservatively chosen to be 256 MB to allow the decision algorithm to work in most scenarios. Increasing `maxMemoryInMB` can lead to faster training (if the memory is available) by allowing fewer passes over the data. However, there may be decreasing returns as `maxMemoryInMB` grows since the amount of communication on each iteration can be proportional to `maxMemoryInMB`. * *Implementation details*: For faster processing, the decision tree algorithm collects statistics about groups of nodes to split (rather than 1 node at a time). The number of nodes which can be handled in one group is determined by the memory requirements (which vary per features). The `maxMemoryInMB` parameter specifies the memory limit in terms of megabytes which each worker can use for these statistics. -* **`subsamplingRate`**: Fraction of the training data used for learning the decision tree. This parameter is most relevant for training ensembles of trees (using [`RandomForest`](api/scala/index.html#org.apache.spark.mllib.tree.RandomForest) and [`GradientBoostedTrees`](api/scala/index.html#org.apache.spark.mllib.tree.GradientBoostedTrees)), where it can be useful to subsample the original data. For training a single decision tree, this parameter is less useful since the number of training instances is generally not the main constraint. +* **`subsamplingRate`**: Fraction of the training data used for learning the decision tree. This parameter is most relevant for training ensembles of trees (using [`RandomForest`](api/scala/index.html#org.apache.spark.mllib.tree.RandomForest$) and [`GradientBoostedTrees`](api/scala/index.html#org.apache.spark.mllib.tree.GradientBoostedTrees)), where it can be useful to subsample the original data. For training a single decision tree, this parameter is less useful since the number of training instances is generally not the main constraint. * **`impurity`**: Impurity measure (discussed above) used to choose between candidate splits. 
This measure must match the `algo` parameter. ### Caching and checkpointing -MLlib 1.2 adds several features for scaling up to larger (deeper) trees and tree ensembles. When `maxDepth` is set to be large, it can be useful to turn on node ID caching and checkpointing. These parameters are also useful for [RandomForest](api/scala/index.html#org.apache.spark.mllib.tree.RandomForest) when `numTrees` is set to be large. +MLlib 1.2 adds several features for scaling up to larger (deeper) trees and tree ensembles. When `maxDepth` is set to be large, it can be useful to turn on node ID caching and checkpointing. These parameters are also useful for [RandomForest](api/scala/index.html#org.apache.spark.mllib.tree.RandomForest$) when `numTrees` is set to be large. * **`useNodeIdCache`**: If this is set to true, the algorithm will avoid passing the current model (tree or trees) to executors on each iteration. * This can be useful with deep trees (speeding up computation on workers) and for large Random Forests (reducing communication on each iteration). diff --git a/docs/mllib-dimensionality-reduction.md b/docs/mllib-dimensionality-reduction.md index cceddce9f79a6..539cbc1b3163a 100644 --- a/docs/mllib-dimensionality-reduction.md +++ b/docs/mllib-dimensionality-reduction.md @@ -1,7 +1,7 @@ --- layout: global -title: Dimensionality Reduction - spark.mllib -displayTitle: Dimensionality Reduction - spark.mllib +title: Dimensionality Reduction - RDD-based API +displayTitle: Dimensionality Reduction - RDD-based API --- * Table of contents diff --git a/docs/mllib-ensembles.md b/docs/mllib-ensembles.md index 2416b6fa0aeb3..e1984b6c8d5a5 100644 --- a/docs/mllib-ensembles.md +++ b/docs/mllib-ensembles.md @@ -1,7 +1,7 @@ --- layout: global -title: Ensembles - spark.mllib -displayTitle: Ensembles - spark.mllib +title: Ensembles - RDD-based API +displayTitle: Ensembles - RDD-based API --- * Table of contents @@ -9,7 +9,7 @@ displayTitle: Ensembles - spark.mllib An [ensemble method](http://en.wikipedia.org/wiki/Ensemble_learning) is a learning algorithm which creates a model composed of a set of other base models. -`spark.mllib` supports two major ensemble algorithms: [`GradientBoostedTrees`](api/scala/index.html#org.apache.spark.mllib.tree.GradientBoostedTrees) and [`RandomForest`](api/scala/index.html#org.apache.spark.mllib.tree.RandomForest). +`spark.mllib` supports two major ensemble algorithms: [`GradientBoostedTrees`](api/scala/index.html#org.apache.spark.mllib.tree.GradientBoostedTrees) and [`RandomForest`](api/scala/index.html#org.apache.spark.mllib.tree.RandomForest$). Both use [decision trees](mllib-decision-tree.html) as their base models. ## Gradient-Boosted Trees vs. Random Forests @@ -96,7 +96,7 @@ The test error is calculated to measure the algorithm accuracy.
    -Refer to the [`RandomForest` Scala docs](api/scala/index.html#org.apache.spark.mllib.tree.RandomForest) and [`RandomForestModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.tree.model.RandomForestModel) for details on the API. +Refer to the [`RandomForest` Scala docs](api/scala/index.html#org.apache.spark.mllib.tree.RandomForest$) and [`RandomForestModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.tree.model.RandomForestModel) for details on the API. {% include_example scala/org/apache/spark/examples/mllib/RandomForestClassificationExample.scala %}
    @@ -127,7 +127,7 @@ The Mean Squared Error (MSE) is computed at the end to evaluate
    -Refer to the [`RandomForest` Scala docs](api/scala/index.html#org.apache.spark.mllib.tree.RandomForest) and [`RandomForestModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.tree.model.RandomForestModel) for details on the API. +Refer to the [`RandomForest` Scala docs](api/scala/index.html#org.apache.spark.mllib.tree.RandomForest$) and [`RandomForestModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.tree.model.RandomForestModel) for details on the API. {% include_example scala/org/apache/spark/examples/mllib/RandomForestRegressionExample.scala %}
diff --git a/docs/mllib-evaluation-metrics.md b/docs/mllib-evaluation-metrics.md
index a269dbf030e7c..ac82f43cfb79d 100644
--- a/docs/mllib-evaluation-metrics.md
+++ b/docs/mllib-evaluation-metrics.md
@@ -1,7 +1,7 @@
 ---
 layout: global
-title: Evaluation Metrics - spark.mllib
-displayTitle: Evaluation Metrics - spark.mllib
+title: Evaluation Metrics - RDD-based API
+displayTitle: Evaluation Metrics - RDD-based API
 ---

 * Table of contents
@@ -140,7 +140,7 @@ definitions of positive and negative labels is straightforward.

 #### Label based metrics

 As opposed to binary classification, where there are only two possible labels, multiclass classification problems have many
-possible labels and so the concept of label-based metrics is introduced. Overall precision measures precision across all
+possible labels and so the concept of label-based metrics is introduced. Accuracy measures precision across all
 labels - the number of times any class was predicted correctly (true positives) normalized by the number of
 data points. Precision by label considers only one class, and measures the number of times a specific label was
 predicted correctly normalized by the number of times that label appears in the output.
@@ -182,20 +182,10 @@ $$\hat{\delta}(x) = \begin{cases}1 & \text{if $x = 0$}, \\ 0 & \text{otherwise}.
-      Overall Precision
-      $PPV = \frac{TP}{TP + FP} = \frac{1}{N}\sum_{i=0}^{N-1} \hat{\delta}\left(\hat{\mathbf{y}}_i -
-        \mathbf{y}_i\right)$
-
-      Overall Recall
-      $TPR = \frac{TP}{TP + FN} = \frac{1}{N}\sum_{i=0}^{N-1} \hat{\delta}\left(\hat{\mathbf{y}}_i -
+      Accuracy
+      $ACC = \frac{TP}{TP + FP} = \frac{1}{N}\sum_{i=0}^{N-1} \hat{\delta}\left(\hat{\mathbf{y}}_i -
         \mathbf{y}_i\right)$
-
-      Overall F1-measure
-      $F1 = 2 \cdot \left(\frac{PPV \cdot TPR}
-        {PPV + TPR}\right)$
-
       Precision by label
       $PPV(\ell) = \frac{TP}{TP + FP} =
diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md
index 7a97285032655..867be7f2932ed 100644
--- a/docs/mllib-feature-extraction.md
+++ b/docs/mllib-feature-extraction.md
@@ -1,7 +1,7 @@
 ---
 layout: global
-title: Feature Extraction and Transformation - spark.mllib
-displayTitle: Feature Extraction and Transformation - spark.mllib
+title: Feature Extraction and Transformation - RDD-based API
+displayTitle: Feature Extraction and Transformation - RDD-based API
 ---

 * Table of contents
@@ -10,6 +10,9 @@ displayTitle: Feature Extraction and Transformation - spark.mllib

 ## TF-IDF

+**Note:** We recommend using the DataFrame-based API, which is detailed in the [ML user guide on
+TF-IDF](ml-features.html#tf-idf).
+
 [Term frequency-inverse document frequency (TF-IDF)](http://en.wikipedia.org/wiki/Tf%E2%80%93idf) is a feature
 vectorization method widely used in text mining to reflect the importance of a term to a document in the corpus.
 Denote a term by `$t$`, a document by `$d$`, and the corpus by `$D$`.
@@ -330,7 +333,7 @@ Details you can read at [dimensionality reduction](mllib-dimensionality-reductio

 The following code demonstrates how to compute principal components on a `Vector`
 and use them to project the vectors into a low-dimensional space while keeping associated labels
-for calculation a [Linear Regression]((mllib-linear-methods.html))
+for calculating a [Linear Regression](mllib-linear-methods.html)
    diff --git a/docs/mllib-frequent-pattern-mining.md b/docs/mllib-frequent-pattern-mining.md index a7b55dc5e5668..93e3f0b2d2267 100644 --- a/docs/mllib-frequent-pattern-mining.md +++ b/docs/mllib-frequent-pattern-mining.md @@ -1,7 +1,7 @@ --- layout: global -title: Frequent Pattern Mining - spark.mllib -displayTitle: Frequent Pattern Mining - spark.mllib +title: Frequent Pattern Mining - RDD-based API +displayTitle: Frequent Pattern Mining - RDD-based API --- Mining frequent items, itemsets, subsequences, or other substructures is usually among the diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index fa5e90603505d..30112c72c9c31 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -1,32 +1,12 @@ --- layout: global -title: MLlib -displayTitle: Machine Learning Library (MLlib) Guide -description: MLlib machine learning library overview for Spark SPARK_VERSION_SHORT +title: "MLlib: RDD-based API" +displayTitle: "MLlib: RDD-based API" --- -MLlib is Spark's machine learning (ML) library. -Its goal is to make practical machine learning scalable and easy. -It consists of common learning algorithms and utilities, including classification, regression, -clustering, collaborative filtering, dimensionality reduction, as well as lower-level optimization -primitives and higher-level pipeline APIs. - -It divides into two packages: - -* [`spark.mllib`](mllib-guide.html#data-types-algorithms-and-utilities) contains the original API - built on top of [RDDs](programming-guide.html#resilient-distributed-datasets-rdds). -* [`spark.ml`](ml-guide.html) provides higher-level API - built on top of [DataFrames](sql-programming-guide.html#dataframes) for constructing ML pipelines. - -Using `spark.ml` is recommended because with DataFrames the API is more versatile and flexible. -But we will keep supporting `spark.mllib` along with the development of `spark.ml`. -Users should be comfortable using `spark.mllib` features and expect more features coming. -Developers should contribute new algorithms to `spark.ml` if they fit the ML pipeline concept well, -e.g., feature extractors and transformers. - -We list major functionality from both below, with links to detailed guides. - -# spark.mllib: data types, algorithms, and utilities +This page documents sections of the MLlib guide for the RDD-based API (the `spark.mllib` package). +Please see the [MLlib Main Guide](ml-guide.html) for the DataFrame-based API (the `spark.ml` package), +which is now the primary API for MLlib. * [Data types](mllib-data-types.html) * [Basic statistics](mllib-statistics.html) @@ -65,72 +45,3 @@ We list major functionality from both below, with links to detailed guides. * [stochastic gradient descent](mllib-optimization.html#stochastic-gradient-descent-sgd) * [limited-memory BFGS (L-BFGS)](mllib-optimization.html#limited-memory-bfgs-l-bfgs) -# spark.ml: high-level APIs for ML pipelines - -* [Overview: estimators, transformers and pipelines](ml-guide.html) -* [Extracting, transforming and selecting features](ml-features.html) -* [Classification and regression](ml-classification-regression.html) -* [Clustering](ml-clustering.html) -* [Collaborative filtering](ml-collaborative-filtering.html) -* [Advanced topics](ml-advanced.html) - -Some techniques are not available yet in spark.ml, most notably dimensionality reduction -Users can seamlessly combine the implementation of these techniques found in `spark.mllib` with the rest of the algorithms found in `spark.ml`. 
- -# Dependencies - -MLlib uses the linear algebra package [Breeze](http://www.scalanlp.org/), which depends on -[netlib-java](https://github.com/fommil/netlib-java) for optimised numerical processing. -If natives libraries[^1] are not available at runtime, you will see a warning message and a pure JVM -implementation will be used instead. - -Due to licensing issues with runtime proprietary binaries, we do not include `netlib-java`'s native -proxies by default. -To configure `netlib-java` / Breeze to use system optimised binaries, include -`com.github.fommil.netlib:all:1.1.2` (or build Spark with `-Pnetlib-lgpl`) as a dependency of your -project and read the [netlib-java](https://github.com/fommil/netlib-java) documentation for your -platform's additional installation instructions. - -To use MLlib in Python, you will need [NumPy](http://www.numpy.org) version 1.4 or newer. - -[^1]: To learn more about the benefits and background of system optimised natives, you may wish to - watch Sam Halliday's ScalaX talk on [High Performance Linear Algebra in Scala](http://fommil.github.io/scalax14/#/). - -# Migration guide - -MLlib is under active development. -The APIs marked `Experimental`/`DeveloperApi` may change in future releases, -and the migration guide below will explain all changes between releases. - -## From 1.5 to 1.6 - -There are no breaking API changes in the `spark.mllib` or `spark.ml` packages, but there are -deprecations and changes of behavior. - -Deprecations: - -* [SPARK-11358](https://issues.apache.org/jira/browse/SPARK-11358): - In `spark.mllib.clustering.KMeans`, the `runs` parameter has been deprecated. -* [SPARK-10592](https://issues.apache.org/jira/browse/SPARK-10592): - In `spark.ml.classification.LogisticRegressionModel` and - `spark.ml.regression.LinearRegressionModel`, the `weights` field has been deprecated in favor of - the new name `coefficients`. This helps disambiguate from instance (row) "weights" given to - algorithms. - -Changes of behavior: - -* [SPARK-7770](https://issues.apache.org/jira/browse/SPARK-7770): - `spark.mllib.tree.GradientBoostedTrees`: `validationTol` has changed semantics in 1.6. - Previously, it was a threshold for absolute change in error. Now, it resembles the behavior of - `GradientDescent`'s `convergenceTol`: For large errors, it uses relative error (relative to the - previous error); for small errors (`< 0.01`), it uses absolute error. -* [SPARK-11069](https://issues.apache.org/jira/browse/SPARK-11069): - `spark.ml.feature.RegexTokenizer`: Previously, it did not convert strings to lowercase before - tokenizing. Now, it converts to lowercase by default, with an option not to. This matches the - behavior of the simpler `Tokenizer` transformer. - -## Previous Spark versions - -Earlier migration guides are archived [on this page](mllib-migration-guides.html). 
- ---- diff --git a/docs/mllib-isotonic-regression.md b/docs/mllib-isotonic-regression.md index 8ede4407d5843..d90905a86ade9 100644 --- a/docs/mllib-isotonic-regression.md +++ b/docs/mllib-isotonic-regression.md @@ -1,7 +1,7 @@ --- layout: global -title: Isotonic regression - spark.mllib -displayTitle: Regression - spark.mllib +title: Isotonic regression - RDD-based API +displayTitle: Regression - RDD-based API --- ## Isotonic regression diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index 63665c49bc972..6fcd3ae85700c 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -1,7 +1,7 @@ --- layout: global -title: Linear Methods - spark.mllib -displayTitle: Linear Methods - spark.mllib +title: Linear Methods - RDD-based API +displayTitle: Linear Methods - RDD-based API --- * Table of contents @@ -185,10 +185,10 @@ algorithm for 200 iterations. import org.apache.spark.mllib.optimization.L1Updater val svmAlg = new SVMWithSGD() -svmAlg.optimizer. - setNumIterations(200). - setRegParam(0.1). - setUpdater(new L1Updater) +svmAlg.optimizer + .setNumIterations(200) + .setRegParam(0.1) + .setUpdater(new L1Updater) val modelL1 = svmAlg.run(training) {% endhighlight %} @@ -395,7 +395,7 @@ section of the Spark quick-start guide. Be sure to also include *spark-mllib* to your build file as a dependency. -###Streaming linear regression +### Streaming linear regression When data arrive in a streaming fashion, it is useful to fit regression models online, updating the parameters of the model as new data arrives. `spark.mllib` currently supports diff --git a/docs/mllib-migration-guides.md b/docs/mllib-migration-guides.md index f3daef2dbadbe..ea6f93fcf67f3 100644 --- a/docs/mllib-migration-guides.md +++ b/docs/mllib-migration-guides.md @@ -1,132 +1,9 @@ --- layout: global -title: Old Migration Guides - spark.mllib -displayTitle: Old Migration Guides - spark.mllib -description: MLlib migration guides from before Spark SPARK_VERSION_SHORT +title: Old Migration Guides - MLlib +displayTitle: Old Migration Guides - MLlib --- -The migration guide for the current Spark version is kept on the [MLlib Programming Guide main page](mllib-guide.html#migration-guide). - -## From 1.4 to 1.5 - -In the `spark.mllib` package, there are no breaking API changes but several behavior changes: - -* [SPARK-9005](https://issues.apache.org/jira/browse/SPARK-9005): - `RegressionMetrics.explainedVariance` returns the average regression sum of squares. -* [SPARK-8600](https://issues.apache.org/jira/browse/SPARK-8600): `NaiveBayesModel.labels` become - sorted. -* [SPARK-3382](https://issues.apache.org/jira/browse/SPARK-3382): `GradientDescent` has a default - convergence tolerance `1e-3`, and hence iterations might end earlier than 1.4. - -In the `spark.ml` package, there exists one breaking API change and one behavior change: - -* [SPARK-9268](https://issues.apache.org/jira/browse/SPARK-9268): Java's varargs support is removed - from `Params.setDefault` due to a - [Scala compiler bug](https://issues.scala-lang.org/browse/SI-9013). -* [SPARK-10097](https://issues.apache.org/jira/browse/SPARK-10097): `Evaluator.isLargerBetter` is - added to indicate metric ordering. Metrics like RMSE no longer flip signs as in 1.4. 
- -## From 1.3 to 1.4 - -In the `spark.mllib` package, there were several breaking changes, but all in `DeveloperApi` or `Experimental` APIs: - -* Gradient-Boosted Trees - * *(Breaking change)* The signature of the [`Loss.gradient`](api/scala/index.html#org.apache.spark.mllib.tree.loss.Loss) method was changed. This is only an issues for users who wrote their own losses for GBTs. - * *(Breaking change)* The `apply` and `copy` methods for the case class [`BoostingStrategy`](api/scala/index.html#org.apache.spark.mllib.tree.configuration.BoostingStrategy) have been changed because of a modification to the case class fields. This could be an issue for users who use `BoostingStrategy` to set GBT parameters. -* *(Breaking change)* The return value of [`LDA.run`](api/scala/index.html#org.apache.spark.mllib.clustering.LDA) has changed. It now returns an abstract class `LDAModel` instead of the concrete class `DistributedLDAModel`. The object of type `LDAModel` can still be cast to the appropriate concrete type, which depends on the optimization algorithm. - -In the `spark.ml` package, several major API changes occurred, including: - -* `Param` and other APIs for specifying parameters -* `uid` unique IDs for Pipeline components -* Reorganization of certain classes - -Since the `spark.ml` API was an alpha component in Spark 1.3, we do not list all changes here. -However, since 1.4 `spark.ml` is no longer an alpha component, we will provide details on any API -changes for future releases. - -## From 1.2 to 1.3 - -In the `spark.mllib` package, there were several breaking changes. The first change (in `ALS`) is the only one in a component not marked as Alpha or Experimental. - -* *(Breaking change)* In [`ALS`](api/scala/index.html#org.apache.spark.mllib.recommendation.ALS), the extraneous method `solveLeastSquares` has been removed. The `DeveloperApi` method `analyzeBlocks` was also removed. -* *(Breaking change)* [`StandardScalerModel`](api/scala/index.html#org.apache.spark.mllib.feature.StandardScalerModel) remains an Alpha component. In it, the `variance` method has been replaced with the `std` method. To compute the column variance values returned by the original `variance` method, simply square the standard deviation values returned by `std`. -* *(Breaking change)* [`StreamingLinearRegressionWithSGD`](api/scala/index.html#org.apache.spark.mllib.regression.StreamingLinearRegressionWithSGD) remains an Experimental component. In it, there were two changes: - * The constructor taking arguments was removed in favor of a builder pattern using the default constructor plus parameter setter methods. - * Variable `model` is no longer public. -* *(Breaking change)* [`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree) remains an Experimental component. In it and its associated classes, there were several changes: - * In `DecisionTree`, the deprecated class method `train` has been removed. (The object/static `train` methods remain.) - * In `Strategy`, the `checkpointDir` parameter has been removed. Checkpointing is still supported, but the checkpoint directory must be set before calling tree and tree ensemble training. -* `PythonMLlibAPI` (the interface between Scala/Java and Python for MLlib) was a public API but is now private, declared `private[python]`. This was never meant for external use. -* In linear regression (including Lasso and ridge regression), the squared loss is now divided by 2. 
- So in order to produce the same result as in 1.2, the regularization parameter needs to be divided by 2 and the step size needs to be multiplied by 2. - -In the `spark.ml` package, the main API changes are from Spark SQL. We list the most important changes here: - -* The old [SchemaRDD](http://spark.apache.org/docs/1.2.1/api/scala/index.html#org.apache.spark.sql.SchemaRDD) has been replaced with [DataFrame](api/scala/index.html#org.apache.spark.sql.DataFrame) with a somewhat modified API. All algorithms in Spark ML which used to use SchemaRDD now use DataFrame. -* In Spark 1.2, we used implicit conversions from `RDD`s of `LabeledPoint` into `SchemaRDD`s by calling `import sqlContext._` where `sqlContext` was an instance of `SQLContext`. These implicits have been moved, so we now call `import sqlContext.implicits._`. -* Java APIs for SQL have also changed accordingly. Please see the examples above and the [Spark SQL Programming Guide](sql-programming-guide.html) for details. - -Other changes were in `LogisticRegression`: - -* The `scoreCol` output column (with default value "score") was renamed to be `probabilityCol` (with default value "probability"). The type was originally `Double` (for the probability of class 1.0), but it is now `Vector` (for the probability of each class, to support multiclass classification in the future). -* In Spark 1.2, `LogisticRegressionModel` did not include an intercept. In Spark 1.3, it includes an intercept; however, it will always be 0.0 since it uses the default settings for [spark.mllib.LogisticRegressionWithLBFGS](api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS). The option to use an intercept will be added in the future. - -## From 1.1 to 1.2 - -The only API changes in MLlib v1.2 are in -[`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree), -which continues to be an experimental API in MLlib 1.2: - -1. *(Breaking change)* The Scala API for classification takes a named argument specifying the number -of classes. In MLlib v1.1, this argument was called `numClasses` in Python and -`numClassesForClassification` in Scala. In MLlib v1.2, the names are both set to `numClasses`. -This `numClasses` parameter is specified either via -[`Strategy`](api/scala/index.html#org.apache.spark.mllib.tree.configuration.Strategy) -or via [`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree) -static `trainClassifier` and `trainRegressor` methods. - -2. *(Breaking change)* The API for -[`Node`](api/scala/index.html#org.apache.spark.mllib.tree.model.Node) has changed. -This should generally not affect user code, unless the user manually constructs decision trees -(instead of using the `trainClassifier` or `trainRegressor` methods). -The tree `Node` now includes more information, including the probability of the predicted label -(for classification). - -3. Printing methods' output has changed. The `toString` (Scala/Java) and `__repr__` (Python) methods used to print the full model; they now print a summary. For the full model, use `toDebugString`. - -Examples in the Spark distribution and examples in the -[Decision Trees Guide](mllib-decision-tree.html#examples) have been updated accordingly. - -## From 1.0 to 1.1 - -The only API changes in MLlib v1.1 are in -[`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree), -which continues to be an experimental API in MLlib 1.1: - -1. 
*(Breaking change)* The meaning of tree depth has been changed by 1 in order to match -the implementations of trees in -[scikit-learn](http://scikit-learn.org/stable/modules/classes.html#module-sklearn.tree) -and in [rpart](http://cran.r-project.org/web/packages/rpart/index.html). -In MLlib v1.0, a depth-1 tree had 1 leaf node, and a depth-2 tree had 1 root node and 2 leaf nodes. -In MLlib v1.1, a depth-0 tree has 1 leaf node, and a depth-1 tree has 1 root node and 2 leaf nodes. -This depth is specified by the `maxDepth` parameter in -[`Strategy`](api/scala/index.html#org.apache.spark.mllib.tree.configuration.Strategy) -or via [`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree) -static `trainClassifier` and `trainRegressor` methods. - -2. *(Non-breaking change)* We recommend using the newly added `trainClassifier` and `trainRegressor` -methods to build a [`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree), -rather than using the old parameter class `Strategy`. These new training methods explicitly -separate classification and regression, and they replace specialized parameter types with -simple `String` types. - -Examples of the new, recommended `trainClassifier` and `trainRegressor` are given in the -[Decision Trees Guide](mllib-decision-tree.html#examples). - -## From 0.9 to 1.0 - -In MLlib v1.0, we support both dense and sparse input in a unified way, which introduces a few -breaking changes. If your data is sparse, please store it in a sparse format instead of dense to -take advantage of sparsity in both storage and computation. Details are described below. +The migration guide for the current Spark version is kept on the [MLlib Guide main page](ml-guide.html#migration-guide). +Past migration guides are now stored at [ml-migration-guides.html](ml-migration-guides.html). diff --git a/docs/mllib-naive-bayes.md b/docs/mllib-naive-bayes.md index d0d594af6a4ad..7471d18a0dddc 100644 --- a/docs/mllib-naive-bayes.md +++ b/docs/mllib-naive-bayes.md @@ -1,7 +1,7 @@ --- layout: global -title: Naive Bayes - spark.mllib -displayTitle: Naive Bayes - spark.mllib +title: Naive Bayes - RDD-based API +displayTitle: Naive Bayes - RDD-based API --- [Naive Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier) is a simple diff --git a/docs/mllib-optimization.md b/docs/mllib-optimization.md index f90b66f8e2c44..eefd7dcf1108b 100644 --- a/docs/mllib-optimization.md +++ b/docs/mllib-optimization.md @@ -1,7 +1,7 @@ --- layout: global -title: Optimization - spark.mllib -displayTitle: Optimization - spark.mllib +title: Optimization - RDD-based API +displayTitle: Optimization - RDD-based API --- * Table of contents diff --git a/docs/mllib-pmml-model-export.md b/docs/mllib-pmml-model-export.md index 58ed5a0e9d702..d3530908706d0 100644 --- a/docs/mllib-pmml-model-export.md +++ b/docs/mllib-pmml-model-export.md @@ -1,7 +1,7 @@ --- layout: global -title: PMML model export - spark.mllib -displayTitle: PMML model export - spark.mllib +title: PMML model export - RDD-based API +displayTitle: PMML model export - RDD-based API --- * Table of contents @@ -47,7 +47,7 @@ To export a supported `model` (see table above) to PMML, simply call `model.toPM As well as exporting the PMML model to a String (`model.toPMML` as in the example above), you can export the PMML model to other formats. 
-Refer to the [`KMeans` Scala docs](api/scala/index.html#org.apache.spark.mllib.clustering.KMeans) and [`Vectors` Scala docs](api/scala/index.html#org.apache.spark.mllib.linalg.Vectors) for details on the API.
+Refer to the [`KMeans` Scala docs](api/scala/index.html#org.apache.spark.mllib.clustering.KMeans) and [`Vectors` Scala docs](api/scala/index.html#org.apache.spark.mllib.linalg.Vectors$) for details on the API.

Here is a complete example of building a KMeansModel and printing it out in PMML format:
{% include_example scala/org/apache/spark/examples/mllib/PMMLModelExportExample.scala %}
diff --git a/docs/mllib-statistics.md b/docs/mllib-statistics.md
index 02b81f153bf7f..12797bd8688e1 100644
--- a/docs/mllib-statistics.md
+++ b/docs/mllib-statistics.md
@@ -1,7 +1,7 @@
 ---
 layout: global
-title: Basic Statistics - spark.mllib
-displayTitle: Basic Statistics - spark.mllib
+title: Basic Statistics - RDD-based API
+displayTitle: Basic Statistics - RDD-based API
 ---

 * Table of contents
@@ -80,7 +80,7 @@ correlation methods are currently Pearson's and Spearman's correlation.

 calculate correlations between series. Depending on the type of input, two `RDD[Double]`s or
 an `RDD[Vector]`, the output will be a `Double` or the correlation `Matrix`, respectively.

-Refer to the [`Statistics` Scala docs](api/scala/index.html#org.apache.spark.mllib.stat.Statistics) for details on the API.
+Refer to the [`Statistics` Scala docs](api/scala/index.html#org.apache.spark.mllib.stat.Statistics$) for details on the API.

{% include_example scala/org/apache/spark/examples/mllib/CorrelationsExample.scala %}
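+A hedged sketch of the correlation call described above (assumes an existing `SparkContext`
+named `sc`; the data values are illustrative only):
+
+{% highlight scala %}
+import org.apache.spark.mllib.stat.Statistics
+import org.apache.spark.rdd.RDD
+
+val seriesX: RDD[Double] = sc.parallelize(Array(1.0, 2.0, 3.0))
+val seriesY: RDD[Double] = sc.parallelize(Array(11.0, 22.0, 33.0))
+// Two RDD[Double] inputs yield a single Double; an RDD[Vector] would yield a Matrix.
+val correlation: Double = Statistics.corr(seriesX, seriesY, "pearson")
+{% endhighlight %}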
    @@ -210,7 +210,7 @@ message. run a 1-sample, 2-sided Kolmogorov-Smirnov test. The following example demonstrates how to run and interpret the hypothesis tests. -Refer to the [`Statistics` Scala docs](api/scala/index.html#org.apache.spark.mllib.stat.Statistics) for details on the API. +Refer to the [`Statistics` Scala docs](api/scala/index.html#org.apache.spark.mllib.stat.Statistics$) for details on the API. {% include_example scala/org/apache/spark/examples/mllib/HypothesisTestingKolmogorovSmirnovTestExample.scala %}
    @@ -277,12 +277,12 @@ uniform, standard normal, or Poisson.
    -[`RandomRDDs`](api/scala/index.html#org.apache.spark.mllib.random.RandomRDDs) provides factory +[`RandomRDDs`](api/scala/index.html#org.apache.spark.mllib.random.RandomRDDs$) provides factory methods to generate random double RDDs or vector RDDs. The following example generates a random double RDD, whose values follows the standard normal distribution `N(0, 1)`, and then map it to `N(1, 4)`. -Refer to the [`RandomRDDs` Scala docs](api/scala/index.html#org.apache.spark.mllib.random.RandomRDDs) for details on the API. +Refer to the [`RandomRDDs` Scala docs](api/scala/index.html#org.apache.spark.mllib.random.RandomRDDs$) for details on the API. {% highlight scala %} import org.apache.spark.SparkContext diff --git a/docs/monitoring.md b/docs/monitoring.md index 9912cde743a45..1bc3d266b66b4 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -27,11 +27,6 @@ in the UI to persisted storage. ## Viewing After the Fact -Spark's Standalone Mode cluster manager also has its own -[web UI](spark-standalone.html#monitoring-and-logging). If an application has logged events over -the course of its lifetime, then the Standalone master's web UI will automatically re-render the -application's UI after the application has finished. - If Spark is run on Mesos or YARN, it is still possible to construct the UI of an application through Spark's history server, provided that the application's event logs exist. You can start the history server by executing: @@ -119,8 +114,17 @@ The history server can be configured as follows: spark.history.retainedApplications 50 - The number of application UIs to retain. If this cap is exceeded, then the oldest - applications will be removed. + The number of applications to retain UI data for in the cache. If this cap is exceeded, then + the oldest applications will be removed from the cache. If an application is not in the cache, + it will have to be loaded from disk if its accessed from the UI. + + + + spark.history.ui.maxApplications + Int.MaxValue + + The number of applications to display on the history summary page. Application UIs are still + available by accessing their URLs directly even if they are not displayed on the history summary page. @@ -162,8 +166,8 @@ The history server can be configured as follows: If enabled, access control checks are made regardless of what the individual application had set for spark.ui.acls.enable when the application was run. The application owner will always have authorization to view their own application and any users specified via - spark.ui.view.acls when the application was run will also have authorization - to view that application. + spark.ui.view.acls and groups specified via spark.ui.view.acls.groups + when the application was run will also have authorization to view that application. If disabled, no access control checks are made. @@ -230,9 +234,10 @@ for the history server, they would typically be accessible at `http:// EndpointMeaning @@ -246,7 +251,8 @@ where `[base-app-id]` is the YARN application ID.
    Examples:
    ?minDate=2015-02-10
    ?minDate=2015-02-03T16:42:40.000GMT -
    ?maxDate=[date] latest date/time to list; uses same format as minDate. +
    ?maxDate=[date] latest date/time to list; uses same format as minDate. +
    ?limit=[limit] limits the number of applications listed. /applications/[app-id]/jobs @@ -341,7 +347,7 @@ keep the paths consistent in both modes. # Metrics Spark has a configurable metrics system based on the -[Coda Hale Metrics Library](http://metrics.codahale.com/). +[Dropwizard Metrics Library](http://metrics.dropwizard.io/). This allows users to report Spark metrics to a variety of sinks including HTTP, JMX, and CSV files. The metrics system is configured via a configuration file that Spark expects to be present at `$SPARK_HOME/conf/metrics.properties`. A custom file location can be specified via the diff --git a/docs/programming-guide.md b/docs/programming-guide.md index cf6f1d89147f0..204ad5e4ada22 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -328,7 +328,7 @@ Text file RDDs can be created using `SparkContext`'s `textFile` method. This met {% highlight scala %} scala> val distFile = sc.textFile("data.txt") -distFile: RDD[String] = MappedRDD@1d4cee08 +distFile: org.apache.spark.rdd.RDD[String] = data.txt MapPartitionsRDD[10] at textFile at :26 {% endhighlight %} Once created, `distFile` can be acted on by dataset operations. For example, we can add up the sizes of all the lines using the `map` and `reduce` operations as follows: `distFile.map(s => s.length).reduce((a, b) => a + b)`. @@ -491,7 +491,7 @@ for examples of using Cassandra / HBase ```InputFormat``` and ```OutputFormat``` RDDs support two types of operations: *transformations*, which create a new dataset from an existing one, and *actions*, which return a value to the driver program after running a computation on the dataset. For example, `map` is a transformation that passes each dataset element through a function and returns a new RDD representing the results. On the other hand, `reduce` is an action that aggregates all the elements of the RDD using some function and returns the final result to the driver program (although there is also a parallel `reduceByKey` that returns a distributed dataset). -All transformations in Spark are lazy, in that they do not compute their results right away. Instead, they just remember the transformations applied to some base dataset (e.g. a file). The transformations are only computed when an action requires a result to be returned to the driver program. This design enables Spark to run more efficiently -- for example, we can realize that a dataset created through `map` will be used in a `reduce` and return only the result of the `reduce` to the driver, rather than the larger mapped dataset. +All transformations in Spark are lazy, in that they do not compute their results right away. Instead, they just remember the transformations applied to some base dataset (e.g. a file). The transformations are only computed when an action requires a result to be returned to the driver program. This design enables Spark to run more efficiently. For example, we can realize that a dataset created through `map` will be used in a `reduce` and return only the result of the `reduce` to the driver, rather than the larger mapped dataset. By default, each transformed RDD may be recomputed each time you run an action on it. However, you may also *persist* an RDD in memory using the `persist` (or `cache`) method, in which case Spark will keep the elements around on the cluster for much faster access the next time you query it. There is also support for persisting RDDs on disk, or replicated across multiple nodes. 
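+As a minimal sketch of the laziness and persistence described above (assuming a `SparkContext`
+named `sc` and a local text file `data.txt`):
+
+{% highlight scala %}
+val lines = sc.textFile("data.txt")        // lazy: nothing is read yet
+val lineLengths = lines.map(s => s.length) // lazy: just records the transformation
+lineLengths.persist()                      // mark for reuse; .cache() is equivalent
+
+val totalLength = lineLengths.reduce((a, b) => a + b) // first action: computes and caches
+val maxLength = lineLengths.max()                     // second action: served from the cache
+{% endhighlight %}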
@@ -618,7 +618,7 @@ class MyClass {
 }
{% endhighlight %}

-Here, if we create a `new MyClass` and call `doStuff` on it, the `map` inside there references the
+Here, if we create a new `MyClass` instance and call `doStuff` on it, the `map` inside it references the
 `func1` method *of that `MyClass` instance*, so the whole object needs to be sent to the cluster. It is
 similar to writing `rdd.map(x => this.func1(x))`.

@@ -1097,10 +1097,13 @@ for details.
   foreach(func)
   Run a function func on each element of the dataset. This is usually done for side effects such as updating an Accumulator or interacting with external storage systems.
-
    Note: modifying variables other than Accumulators outside of the foreach() may result in undefined behavior. See Understanding closures for more details. +
    Note: modifying variables other than Accumulators outside of the foreach() may result in undefined behavior. See Understanding closures for more details. +The Spark RDD API also exposes asynchronous versions of some actions, like `foreachAsync` for `foreach`, which immediately return a `FutureAction` to the caller instead of blocking on completion of the action. This can be used to manage or wait for the asynchronous execution of the action. + + ### Shuffle operations Certain operations within Spark trigger an event known as the shuffle. The shuffle is Spark's @@ -1156,7 +1159,7 @@ to disk, incurring the additional overhead of disk I/O and increased garbage col Shuffle also generates a large number of intermediate files on disk. As of Spark 1.3, these files are preserved until the corresponding RDDs are no longer used and are garbage collected. This is done so the shuffle files don't need to be re-created if the lineage is re-computed. -Garbage collection may happen only after a long period time, if the application retains references +Garbage collection may happen only after a long period of time, if the application retains references to these RDDs or if GC does not kick in frequently. This means that long-running Spark jobs may consume a large amount of disk space. The temporary storage directory is specified by the `spark.local.dir` configuration parameter when configuring the Spark context. @@ -1220,6 +1223,11 @@ storage levels is: MEMORY_ONLY_2, MEMORY_AND_DISK_2, etc. Same as the levels above, but replicate each partition on two cluster nodes. + + OFF_HEAP (experimental) + Similar to MEMORY_ONLY_SER, but store the data in + off-heap memory. This requires off-heap memory to be enabled. + **Note:** *In Python, stored objects will always be serialized with the [Pickle](https://docs.python.org/2/library/pickle.html) library, @@ -1340,60 +1348,68 @@ running stages (NOTE: this is not yet supported in Python). Accumulators in the Spark UI

    -An accumulator is created from an initial value `v` by calling `SparkContext.accumulator(v)`. Tasks -running on a cluster can then add to it using the `add` method or the `+=` operator (in Scala and Python). -However, they cannot read its value. -Only the driver program can read the accumulator's value, using its `value` method. - -The code below shows an accumulator being used to add up the elements of an array: -
+A numeric accumulator can be created by calling `SparkContext.longAccumulator()` or `SparkContext.doubleAccumulator()`
+to accumulate values of type Long or Double, respectively. Tasks running on a cluster can then add to it using
+the `add` method. However, they cannot read its value. Only the driver program can read the accumulator's value,
+using its `value` method.
+
+The code below shows an accumulator being used to add up the elements of an array:
+
{% highlight scala %}
-scala> val accum = sc.accumulator(0, "My Accumulator")
-accum: org.apache.spark.Accumulator[Int] = 0
+scala> val accum = sc.longAccumulator("My Accumulator")
+accum: org.apache.spark.util.LongAccumulator = LongAccumulator(id: 0, name: Some(My Accumulator), value: 0)

-scala> sc.parallelize(Array(1, 2, 3, 4)).foreach(x => accum += x)
+scala> sc.parallelize(Array(1, 2, 3, 4)).foreach(x => accum.add(x))
...
10/09/29 18:41:08 INFO SparkContext: Tasks finished in 0.317106 s

scala> accum.value
-res2: Int = 10
+res2: Long = 10
{% endhighlight %}

-While this code used the built-in support for accumulators of type Int, programmers can also
-create their own types by subclassing [AccumulatorParam](api/scala/index.html#org.apache.spark.AccumulatorParam).
-The AccumulatorParam interface has two methods: `zero` for providing a "zero value" for your data
-type, and `addInPlace` for adding two values together. For example, supposing we had a `Vector` class
+While this code used the built-in support for accumulators of type Long, programmers can also
+create their own types by subclassing [AccumulatorV2](api/scala/index.html#org.apache.spark.util.AccumulatorV2).
+The AccumulatorV2 abstract class has several methods which one has to override:
+`reset` for resetting the accumulator to zero, `add` for adding another value into the accumulator, and `merge` for merging another same-type accumulator into this one. Other methods that need to be overridden are described in the Scala API documentation. For example, supposing we had a `MyVector` class
representing mathematical vectors, we could write:

{% highlight scala %}
-object VectorAccumulatorParam extends AccumulatorParam[Vector] {
-  def zero(initialValue: Vector): Vector = {
-    Vector.zeros(initialValue.size)
+class VectorAccumulatorV2 extends AccumulatorV2[MyVector, MyVector] {
+  val vec_ : MyVector = MyVector.createZeroVector
+  def reset(): Unit = {
+    vec_.reset()
  }
-  def addInPlace(v1: Vector, v2: Vector): Vector = {
-    v1 += v2
+  def add(v: MyVector): Unit = {
+    vec_.add(v)
  }
+  ...
}

// Then, create an Accumulator of this type:
-val vecAccum = sc.accumulator(new Vector(...))(VectorAccumulatorParam)
+val myVectorAcc = new VectorAccumulatorV2
+// Then, register it with the Spark context:
+sc.register(myVectorAcc, "MyVectorAcc1")
{% endhighlight %}

-In Scala, Spark also supports the more general [Accumulable](api/scala/index.html#org.apache.spark.Accumulable)
-interface to accumulate data where the resulting type is not the same as the elements added (e.g. build
-a list by collecting together elements), and the `SparkContext.accumulableCollection` method for accumulating
-common Scala collection types.
+Note that, when programmers define their own type of AccumulatorV2, the resulting type can be different from that of the elements added.
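As an illustration of an accumulator whose result type differs from the type of the elements added, here is a self-contained sketch (not taken from the Spark docs or codebase; class and variable names are made up) that collects individual strings into a `java.util.List[String]`, similar in spirit to Spark's built-in collection accumulator:

{% highlight scala %}
import java.util.{ArrayList, Collections, List => JList}
import org.apache.spark.util.AccumulatorV2

// Elements added are of type String, but the accumulated result is a java.util.List[String].
class StringListAccumulator extends AccumulatorV2[String, JList[String]] {
  private val list: JList[String] = Collections.synchronizedList(new ArrayList[String]())

  override def isZero: Boolean = list.isEmpty
  override def copy(): StringListAccumulator = {
    val newAcc = new StringListAccumulator
    newAcc.list.addAll(list)
    newAcc
  }
  override def reset(): Unit = list.clear()
  override def add(v: String): Unit = list.add(v)
  override def merge(other: AccumulatorV2[String, JList[String]]): Unit = list.addAll(other.value)
  override def value: JList[String] = list
}

// Register the accumulator before using it in tasks, then add to it from executors.
val namesAcc = new StringListAccumulator
sc.register(namesAcc, "names")
{% endhighlight %}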
    +A numeric accumulator can be created by calling `SparkContext.longAccumulator()` or `SparkContext.doubleAccumulator()` +to accumulate values of type Long or Double, respectively. Tasks running on a cluster can then add to it using +the `add` method. However, they cannot read its value. Only the driver program can read the accumulator's value, +using its `value` method. + +The code below shows an accumulator being used to add up the elements of an array: + {% highlight java %} -Accumulator accum = sc.accumulator(0); +LongAccumulator accum = jsc.sc().longAccumulator(); sc.parallelize(Arrays.asList(1, 2, 3, 4)).foreach(x -> accum.add(x)); // ... @@ -1403,8 +1419,8 @@ accum.value(); // returns 10 {% endhighlight %} -While this code used the built-in support for accumulators of type Integer, programmers can also -create their own types by subclassing [AccumulatorParam](api/java/index.html?org/apache/spark/AccumulatorParam.html). +Programmers can also create their own types by subclassing +[AccumulatorParam](api/java/index.html?org/apache/spark/AccumulatorParam.html). The AccumulatorParam interface has two methods: `zero` for providing a "zero value" for your data type, and `addInPlace` for adding two values together. For example, supposing we had a `Vector` class representing mathematical vectors, we could write: @@ -1431,6 +1447,12 @@ a list by collecting together elements).
    +An accumulator is created from an initial value `v` by calling `SparkContext.accumulator(v)`. Tasks +running on a cluster can then add to it using the `add` method or the `+=` operator. However, they cannot read its value. +Only the driver program can read the accumulator's value, using its `value` method. + +The code below shows an accumulator being used to add up the elements of an array: + {% highlight python %} >>> accum = sc.accumulator(0) Accumulator @@ -1476,15 +1498,15 @@ Accumulators do not change the lazy evaluation model of Spark. If they are being
    {% highlight scala %} -val accum = sc.accumulator(0) -data.map { x => accum += x; x } +val accum = sc.longAccumulator +data.map { x => accum.add(x); x } // Here, accum is still 0 because no actions have caused the map operation to be computed. {% endhighlight %}
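For completeness, a small sketch of how the update becomes visible once an action runs on the mapped RDD (continuing the snippet above; the usual caveat that accumulator updates inside transformations may be applied more than once on task re-execution still holds):

{% highlight scala %}
// Running an action forces the map to execute, so the accumulator is updated as a side effect.
data.map { x => accum.add(x); x }.count()
// accum.value now reflects the additions performed by the tasks.
{% endhighlight %}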
    {% highlight java %} -Accumulator accum = sc.accumulator(0); +LongAccumulator accum = jsc.sc().longAccumulator(); data.map(x -> { accum.add(x); return f(x); }); // Here, accum is still 0 because no actions have caused the `map` to be computed. {% endhighlight %} @@ -1522,49 +1544,6 @@ and then call `SparkContext.stop()` to tear it down. Make sure you stop the context within a `finally` block or the test framework's `tearDown` method, as Spark does not support two contexts running concurrently in the same program. -# Migrating from pre-1.0 Versions of Spark - -
    - -
    - -Spark 1.0 freezes the API of Spark Core for the 1.X series, in that any API available today that is -not marked "experimental" or "developer API" will be supported in future versions. -The only change for Scala users is that the grouping operations, e.g. `groupByKey`, `cogroup` and `join`, -have changed from returning `(Key, Seq[Value])` pairs to `(Key, Iterable[Value])`. - -
    - -
    - -Spark 1.0 freezes the API of Spark Core for the 1.X series, in that any API available today that is -not marked "experimental" or "developer API" will be supported in future versions. -Several changes were made to the Java API: - -* The Function classes in `org.apache.spark.api.java.function` became interfaces in 1.0, meaning that old - code that `extends Function` should `implement Function` instead. -* New variants of the `map` transformations, like `mapToPair` and `mapToDouble`, were added to create RDDs - of special data types. -* Grouping operations like `groupByKey`, `cogroup` and `join` have changed from returning - `(Key, List)` pairs to `(Key, Iterable)`. - -
    - -
    - -Spark 1.0 freezes the API of Spark Core for the 1.X series, in that any API available today that is -not marked "experimental" or "developer API" will be supported in future versions. -The only change for Python users is that the grouping operations, e.g. `groupByKey`, `cogroup` and `join`, -have changed from returning (key, list of values) pairs to (key, iterable of values). - -
    - -
    - -Migration guides are also available for [Spark Streaming](streaming-programming-guide.html#migration-guide-from-091-or-below-to-1x), -[MLlib](mllib-guide.html#migration-guide) and [GraphX](graphx-programming-guide.html#migrating-from-spark-091). - - # Where to Go from Here You can see some [example Spark programs](http://spark.apache.org/examples.html) on the Spark website. diff --git a/docs/quick-start.md b/docs/quick-start.md index d481fe0ea6d70..04b0f0a7096d3 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -33,7 +33,7 @@ Spark's primary abstraction is a distributed collection of items called a Resili {% highlight scala %} scala> val textFile = sc.textFile("README.md") -textFile: spark.RDD[String] = spark.MappedRDD@2ee9b6e3 +textFile: org.apache.spark.rdd.RDD[String] = README.md MapPartitionsRDD[1] at textFile at :25 {% endhighlight %} RDDs have _[actions](programming-guide.html#actions)_, which return values, and _[transformations](programming-guide.html#transformations)_, which return pointers to new RDDs. Let's start with a few actions: @@ -50,7 +50,7 @@ Now let's use a transformation. We will use the [`filter`](programming-guide.htm {% highlight scala %} scala> val linesWithSpark = textFile.filter(line => line.contains("Spark")) -linesWithSpark: spark.RDD[String] = spark.FilteredRDD@7dd4af09 +linesWithSpark: org.apache.spark.rdd.RDD[String] = MapPartitionsRDD[2] at filter at :27 {% endhighlight %} We can chain together transformations and actions: @@ -123,7 +123,7 @@ One common data flow pattern is MapReduce, as popularized by Hadoop. Spark can i {% highlight scala %} scala> val wordCounts = textFile.flatMap(line => line.split(" ")).map(word => (word, 1)).reduceByKey((a, b) => a + b) -wordCounts: spark.RDD[(String, Int)] = spark.ShuffledAggregatedRDD@71f027b8 +wordCounts: org.apache.spark.rdd.RDD[(String, Int)] = ShuffledRDD[8] at reduceByKey at :28 {% endhighlight %} Here, we combined the [`flatMap`](programming-guide.html#transformations), [`map`](programming-guide.html#transformations), and [`reduceByKey`](programming-guide.html#transformations) transformations to compute the per-word counts in the file as an RDD of (String, Int) pairs. To collect the word counts in our shell, we can use the [`collect`](programming-guide.html#actions) action: @@ -181,7 +181,7 @@ Spark also supports pulling data sets into a cluster-wide in-memory cache. This {% highlight scala %} scala> linesWithSpark.cache() -res7: spark.RDD[String] = spark.FilteredRDD@17e51082 +res7: linesWithSpark.type = MapPartitionsRDD[2] at filter at :27 scala> linesWithSpark.count() res8: Long = 19 @@ -240,7 +240,8 @@ object SimpleApp { val logData = sc.textFile(logFile, 2).cache() val numAs = logData.filter(line => line.contains("a")).count() val numBs = logData.filter(line => line.contains("b")).count() - println("Lines with a: %s, Lines with b: %s".format(numAs, numBs)) + println(s"Lines with a: $numAs, Lines with b: $numBs") + sc.stop() } } {% endhighlight %} @@ -289,13 +290,13 @@ $ find . # Package a jar containing your application $ sbt package ... 
-[info] Packaging {..}/{..}/target/scala-2.10/simple-project_2.10-1.0.jar +[info] Packaging {..}/{..}/target/scala-{{site.SCALA_BINARY_VERSION}}/simple-project_{{site.SCALA_BINARY_VERSION}}-1.0.jar # Use spark-submit to run your application $ YOUR_SPARK_HOME/bin/spark-submit \ --class "SimpleApp" \ --master local[4] \ - target/scala-2.10/simple-project_2.10-1.0.jar + target/scala-{{site.SCALA_BINARY_VERSION}}/simple-project_{{site.SCALA_BINARY_VERSION}}-1.0.jar ... Lines with a: 46, Lines with b: 23 {% endhighlight %} @@ -328,6 +329,8 @@ public class SimpleApp { }).count(); System.out.println("Lines with a: " + numAs + ", lines with b: " + numBs); + + sc.stop() } } {% endhighlight %} @@ -407,6 +410,8 @@ numAs = logData.filter(lambda s: 'a' in s).count() numBs = logData.filter(lambda s: 'b' in s).count() print("Lines with a: %i, lines with b: %i" % (numAs, numBs)) + +sc.stop() {% endhighlight %} diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 4a0ab623c1082..5219e99fee73e 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -180,30 +180,58 @@ Note that jars or python files that are passed to spark-submit should be URIs re # Mesos Run Modes -Spark can run over Mesos in two modes: "coarse-grained" (default) and "fine-grained". - -The "coarse-grained" mode will launch only *one* long-running Spark task on each Mesos -machine, and dynamically schedule its own "mini-tasks" within it. The benefit is much lower startup -overhead, but at the cost of reserving the Mesos resources for the complete duration of the -application. - -Coarse-grained is the default mode. You can also set `spark.mesos.coarse` property to true -to turn it on explicitly in [SparkConf](configuration.html#spark-properties): - -{% highlight scala %} -conf.set("spark.mesos.coarse", "true") -{% endhighlight %} - -In addition, for coarse-grained mode, you can control the maximum number of resources Spark will -acquire. By default, it will acquire *all* cores in the cluster (that get offered by Mesos), which -only makes sense if you run just one application at a time. You can cap the maximum number of cores -using `conf.set("spark.cores.max", "10")` (for example). - -In "fine-grained" mode, each Spark task runs as a separate Mesos task. This allows -multiple instances of Spark (and other frameworks) to share machines at a very fine granularity, -where each application gets more or fewer machines as it ramps up and down, but it comes with an -additional overhead in launching each task. This mode may be inappropriate for low-latency -requirements like interactive queries or serving web requests. +Spark can run over Mesos in two modes: "coarse-grained" (default) and +"fine-grained" (deprecated). + +## Coarse-Grained + +In "coarse-grained" mode, each Spark executor runs as a single Mesos +task. Spark executors are sized according to the following +configuration variables: + +* Executor memory: `spark.executor.memory` +* Executor cores: `spark.executor.cores` +* Number of executors: `spark.cores.max`/`spark.executor.cores` + +Please see the [Spark Configuration](configuration.html) page for +details and default values. + +Executors are brought up eagerly when the application starts, until +`spark.cores.max` is reached. If you don't set `spark.cores.max`, the +Spark application will reserve all resources offered to it by Mesos, +so we of course urge you to set this variable in any sort of +multi-tenant cluster, including one which runs multiple concurrent +Spark applications. 
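As a concrete sketch of the sizing rules described above (the numbers are hypothetical), the following configuration would give each executor 4 cores and 4g of memory and cap the application at 24 cores in total, i.e. at most 24 / 4 = 6 executors:

{% highlight scala %}
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.executor.memory", "4g")  // memory per executor
  .set("spark.executor.cores", "4")    // cores per executor
  .set("spark.cores.max", "24")        // total cores reserved for the application
{% endhighlight %}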
+ +The scheduler will start executors round-robin on the offers Mesos +gives it, but there are no spread guarantees, as Mesos does not +provide such guarantees on the offer stream. + +The benefit of coarse-grained mode is much lower startup overhead, but +at the cost of reserving Mesos resources for the complete duration of +the application. To configure your job to dynamically adjust to its +resource requirements, look into +[Dynamic Allocation](#dynamic-resource-allocation-with-mesos). + +## Fine-Grained (deprecated) + +**NOTE:** Fine-grained mode is deprecated as of Spark 2.0.0. Consider + using [Dynamic Allocation](#dynamic-resource-allocation-with-mesos) + for some of the benefits. For a full explanation see + [SPARK-11857](https://issues.apache.org/jira/browse/SPARK-11857) + +In "fine-grained" mode, each Spark task inside the Spark executor runs +as a separate Mesos task. This allows multiple instances of Spark (and +other frameworks) to share cores at a very fine granularity, where +each application gets more or fewer cores as it ramps up and down, but +it comes with an additional overhead in launching each task. This mode +may be inappropriate for low-latency requirements like interactive +queries or serving web requests. + +Note that while Spark tasks in fine-grained will relinquish cores as +they terminate, they will not relinquish memory, as the JVM does not +give memory back to the Operating System. Neither will executors +terminate when they're idle. To run in fine-grained mode, set the `spark.mesos.coarse` property to false in your [SparkConf](configuration.html#spark-properties): @@ -212,7 +240,9 @@ To run in fine-grained mode, set the `spark.mesos.coarse` property to false in y conf.set("spark.mesos.coarse", "false") {% endhighlight %} -You may also make use of `spark.mesos.constraints` to set attribute based constraints on mesos resource offers. By default, all resource offers will be accepted. +You may also make use of `spark.mesos.constraints` to set +attribute-based constraints on Mesos resource offers. By default, all +resource offers will be accepted. {% highlight scala %} conf.set("spark.mesos.constraints", "os:centos7;us-east-1:false") @@ -246,7 +276,7 @@ In either case, HDFS runs separately from Hadoop MapReduce, without being schedu # Dynamic Resource Allocation with Mesos -Mesos supports dynamic allocation only with coarse-grain mode, which can resize the number of +Mesos supports dynamic allocation only with coarse-grained mode, which can resize the number of executors based on statistics of the application. For general information, see [Dynamic Resource Allocation](job-scheduling.html#dynamic-resource-allocation). diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 3bd16bf60c818..dbd46cc48c14b 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -60,6 +60,8 @@ Running Spark on YARN requires a binary distribution of Spark which is built wit Binary distributions can be downloaded from the [downloads page](http://spark.apache.org/downloads.html) of the project website. To build Spark yourself, refer to [Building Spark](building-spark.html). +To make Spark runtime jars accessible from YARN side, you can specify `spark.yarn.archive` or `spark.yarn.jars`. For details please refer to [Spark Properties](running-on-yarn.html#spark-properties). If neither `spark.yarn.archive` nor `spark.yarn.jars` is specified, Spark will create a zip file with all jars under `$SPARK_HOME/jars` and upload it to the distributed cache. 
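For example, assuming the Spark jars have already been staged in an archive on HDFS (the path and file name below are hypothetical), a `spark-defaults.conf` entry along these lines would point YARN at them instead of uploading a fresh zip for every application:

```
spark.yarn.archive  hdfs:///user/spark/spark-archive.zip
```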
+ # Configuration Most of the configs are the same for Spark on YARN as for other deployment modes. See the [configuration page](configuration.html) for more information on those. These are configs that are specific to Spark on YARN. @@ -99,6 +101,8 @@ to the same log file). If you need a reference to the proper location to put log files in the YARN so that YARN can properly display and aggregate them, use `spark.yarn.app.container.log.dir` in your `log4j.properties`. For example, `log4j.appender.file_appender.File=${spark.yarn.app.container.log.dir}/spark.log`. For streaming applications, configuring `RollingFileAppender` and setting file location to YARN's log directory will avoid disk overflow caused by large log files, and logs can be accessed using YARN's log utility. +To use a custom metrics.properties for the application master and executors, update the `$SPARK_CONF_DIR/metrics.properties` file. It will automatically be uploaded with other configurations, so you don't need to specify it manually with `--files`. + #### Spark Properties @@ -240,7 +244,7 @@ If you need a reference to the proper location to put log files in the YARN so t @@ -476,3 +480,99 @@ If you need a reference to the proper location to put log files in the YARN so t - In `cluster` mode, the local directories used by the Spark executors and the Spark driver will be the local directories configured for YARN (Hadoop YARN config `yarn.nodemanager.local-dirs`). If the user specifies `spark.local.dir`, it will be ignored. In `client` mode, the Spark executors will use the local directories configured for YARN while the Spark driver will use those defined in `spark.local.dir`. This is because the Spark driver does not run on the YARN cluster in `client` mode, only the Spark executors do. - The `--files` and `--archives` options support specifying file names with the # similar to Hadoop. For example you can specify: `--files localtest.txt#appSees.txt` and this will upload the file you have locally named `localtest.txt` into HDFS but this will be linked to by the name `appSees.txt`, and your application should use the name as `appSees.txt` to reference it when running on YARN. - The `--jars` option allows the `SparkContext.addJar` function to work if you are using it with local files and running in `cluster` mode. It does not need to be used if you are using it with HDFS, HTTP, HTTPS, or FTP files. + +# Running in a Secure Cluster + +As covered in [security](security.html), Kerberos is used in a secure Hadoop cluster to +authenticate principals associated with services and clients. This allows clients to +make requests of these authenticated services; the services to grant rights +to the authenticated principals. + +Hadoop services issue *hadoop tokens* to grant access to the services and data. +Clients must first acquire tokens for the services they will access and pass them along with their +application as it is launched in the YARN cluster. + +For a Spark application to interact with HDFS, HBase and Hive, it must acquire the relevant tokens +using the Kerberos credentials of the user launching the application +—that is, the principal whose identity will become that of the launched Spark application. + +This is normally done at launch time: in a secure cluster Spark will automatically obtain a +token for the cluster's HDFS filesystem, and potentially for HBase and Hive. + +An HBase token will be obtained if HBase is in on classpath, the HBase configuration declares +the application is secure (i.e. 
`hbase-site.xml` sets `hbase.security.authentication` to `kerberos`), +and `spark.yarn.security.tokens.hbase.enabled` is not set to `false`. + +Similarly, a Hive token will be obtained if Hive is on the classpath, its configuration +includes a URI of the metadata store in `"hive.metastore.uris`, and +`spark.yarn.security.tokens.hive.enabled` is not set to `false`. + +If an application needs to interact with other secure HDFS clusters, then +the tokens needed to access these clusters must be explicitly requested at +launch time. This is done by listing them in the `spark.yarn.access.namenodes` property. + +``` +spark.yarn.access.namenodes hdfs://ireland.example.org:8020/,hdfs://frankfurt.example.org:8020/ +``` + +## Launching your application with Apache Oozie + +Apache Oozie can launch Spark applications as part of a workflow. +In a secure cluster, the launched application will need the relevant tokens to access the cluster's +services. If Spark is launched with a keytab, this is automatic. +However, if Spark is to be launched without a keytab, the responsibility for setting up security +must be handed over to Oozie. + +The details of configuring Oozie for secure clusters and obtaining +credentials for a job can be found on the [Oozie web site](http://oozie.apache.org/) +in the "Authentication" section of the specific release's documentation. + +For Spark applications, the Oozie workflow must be set up for Oozie to request all tokens which +the application needs, including: + +- The YARN resource manager. +- The local HDFS filesystem. +- Any remote HDFS filesystems used as a source or destination of I/O. +- Hive —if used. +- HBase —if used. +- The YARN timeline server, if the application interacts with this. + +To avoid Spark attempting —and then failing— to obtain Hive, HBase and remote HDFS tokens, +the Spark configuration must be set to disable token collection for the services. + +The Spark configuration must include the lines: + +``` +spark.yarn.security.tokens.hive.enabled false +spark.yarn.security.tokens.hbase.enabled false +``` + +The configuration option `spark.yarn.access.namenodes` must be unset. + +## Troubleshooting Kerberos + +Debugging Hadoop/Kerberos problems can be "difficult". One useful technique is to +enable extra logging of Kerberos operations in Hadoop by setting the `HADOOP_JAAS_DEBUG` +environment variable. + +```bash +export HADOOP_JAAS_DEBUG=true +``` + +The JDK classes can be configured to enable extra logging of their Kerberos and +SPNEGO/REST authentication via the system properties `sun.security.krb5.debug` +and `sun.security.spnego.debug=true` + +``` +-Dsun.security.krb5.debug=true -Dsun.security.spnego.debug=true +``` + +All these options can be enabled in the Application Master: + +``` +spark.yarn.appMasterEnv.HADOOP_JAAS_DEBUG true +spark.yarn.am.extraJavaOptions -Dsun.security.krb5.debug=true -Dsun.security.spnego.debug=true +``` + +Finally, if the log level for `org.apache.spark.deploy.yarn.Client` is set to `DEBUG`, the log +will include a list of all tokens obtained, and their expiry details diff --git a/docs/security.md b/docs/security.md index 32c33d285747a..baadfefbec826 100644 --- a/docs/security.md +++ b/docs/security.md @@ -16,10 +16,10 @@ and by using [https/SSL](http://en.wikipedia.org/wiki/HTTPS) via the `spark.ui.h ### Authentication -A user may want to secure the UI if it has data that other users should not be allowed to see. 
The javax servlet filter specified by the user can authenticate the user and then once the user is logged in, Spark can compare that user versus the view ACLs to make sure they are authorized to view the UI. The configs `spark.acls.enable` and `spark.ui.view.acls` control the behavior of the ACLs. Note that the user who started the application always has view access to the UI. On YARN, the Spark UI uses the standard YARN web application proxy mechanism and will authenticate via any installed Hadoop filters. +A user may want to secure the UI if it has data that other users should not be allowed to see. The javax servlet filter specified by the user can authenticate the user and then once the user is logged in, Spark can compare that user versus the view ACLs to make sure they are authorized to view the UI. The configs `spark.acls.enable`, `spark.ui.view.acls` and `spark.ui.view.acls.groups` control the behavior of the ACLs. Note that the user who started the application always has view access to the UI. On YARN, the Spark UI uses the standard YARN web application proxy mechanism and will authenticate via any installed Hadoop filters. -Spark also supports modify ACLs to control who has access to modify a running Spark application. This includes things like killing the application or a task. This is controlled by the configs `spark.acls.enable` and `spark.modify.acls`. Note that if you are authenticating the web UI, in order to use the kill button on the web UI it might be necessary to add the users in the modify acls to the view acls also. On YARN, the modify acls are passed in and control who has modify access via YARN interfaces. -Spark allows for a set of administrators to be specified in the acls who always have view and modify permissions to all the applications. is controlled by the config `spark.admin.acls`. This is useful on a shared cluster where you might have administrators or support staff who help users debug applications. +Spark also supports modify ACLs to control who has access to modify a running Spark application. This includes things like killing the application or a task. This is controlled by the configs `spark.acls.enable`, `spark.modify.acls` and `spark.modify.acls.groups`. Note that if you are authenticating the web UI, in order to use the kill button on the web UI it might be necessary to add the users in the modify acls to the view acls also. On YARN, the modify acls are passed in and control who has modify access via YARN interfaces. +Spark allows for a set of administrators to be specified in the acls who always have view and modify permissions to all the applications. is controlled by the configs `spark.admin.acls` and `spark.admin.acls.groups`. This is useful on a shared cluster where you might have administrators or support staff who help users debug applications. ## Event Logging @@ -27,7 +27,8 @@ If your applications are using event logging, the directory where the event logs ## Encryption -Spark supports SSL for HTTP protocols. SASL encryption is supported for the block transfer service. +Spark supports SSL for HTTP protocols. SASL encryption is supported for the block transfer service +and the RPC endpoints. Encryption is not yet supported for data stored by Spark in temporary local storage, such as shuffle files, cached data, and other application files. 
If encrypting this data is desired, a workaround is diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index fd94c34d1638d..9915487fb9d97 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -94,8 +94,8 @@ You can optionally configure the cluster further by setting environment variable
    spark.executor.instances 2 - The number of executors. Note that this property is incompatible with spark.dynamicAllocation.enabled. If both spark.dynamicAllocation.enabled and spark.executor.instances are specified, dynamic allocation is turned off and the specified number of spark.executor.instances is used. + The number of executors for static allocation. With spark.dynamicAllocation.enabled, the initial set of executors will be at least this large.
    - - + + @@ -133,15 +133,6 @@ You can optionally configure the cluster further by setting environment variable - - - - @@ -204,6 +195,21 @@ SPARK_MASTER_OPTS supports the following system properties: the whole cluster by default.
    + + + + + @@ -244,6 +250,15 @@ SPARK_WORKER_OPTS supports the following system properties: especially if you run jobs very frequently. + + + + +
    Environment VariableMeaning
    SPARK_MASTER_IPBind the master to a specific IP address, for example a public one.SPARK_MASTER_HOSTBind the master to a specific hostname or IP address, for example a public one.
    SPARK_MASTER_PORTSPARK_WORKER_WEBUI_PORT Port for the worker web UI (default: 8081).
    SPARK_WORKER_INSTANCES - Number of worker instances to run on each machine (default: 1). You can make this more than 1 if - you have have very large machines and would like multiple Spark worker processes. If you do set - this, make sure to also set SPARK_WORKER_CORES explicitly to limit the cores per worker, - or else each worker will try to use all the cores. -
    SPARK_WORKER_DIR Directory to run applications in, which will include both logs and scratch space (default: SPARK_HOME/work).
spark.deploy.maxExecutorRetries10 + Limit on the maximum number of back-to-back executor failures that can occur before the + standalone cluster manager removes a faulty application. An application will never be removed + if it has any running executors. If an application experiences more than + spark.deploy.maxExecutorRetries failures in a row, no executors + successfully start running in between those failures, and the application has no running + executors, then the standalone cluster manager will remove the application and mark it as failed. + To disable this automatic removal, set spark.deploy.maxExecutorRetries to + -1. +
    +
    spark.worker.timeout 60
spark.worker.ui.compressedLogFileLengthCacheSize100 + For compressed log files, the uncompressed file size can only be computed by uncompressing the files. + Spark caches the uncompressed file size of compressed log files. This property controls the cache + size. +
    # Connecting an Application to the Cluster @@ -342,7 +357,7 @@ Learn more about getting started with ZooKeeper [here](http://zookeeper.apache.o **Configuration** In order to enable this recovery mode, you can set SPARK_DAEMON_JAVA_OPTS in spark-env by configuring `spark.deploy.recoveryMode` and related spark.deploy.zookeeper.* configurations. -For more information about these configurations please refer to the configurations (doc)[configurations.html#deploy] +For more information about these configurations please refer to the [configuration doc](configuration.html#deploy) Possible gotcha: If you have multiple Masters in your cluster but fail to correctly configure the Masters to use ZooKeeper, the Masters will fail to discover each other and think they're all leaders. This will not lead to a healthy cluster state (as all Masters will schedule independently). diff --git a/docs/sparkr.md b/docs/sparkr.md index 760534ae145fa..340e7f7cb1a0b 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -14,29 +14,24 @@ supports operations like selection, filtering, aggregation etc. (similar to R da [dplyr](https://github.com/hadley/dplyr)) but on large datasets. SparkR also supports distributed machine learning using MLlib. -# SparkR DataFrames +# SparkDataFrame -A DataFrame is a distributed collection of data organized into named columns. It is conceptually +A SparkDataFrame is a distributed collection of data organized into named columns. It is conceptually equivalent to a table in a relational database or a data frame in R, but with richer -optimizations under the hood. DataFrames can be constructed from a wide array of sources such as: +optimizations under the hood. SparkDataFrames can be constructed from a wide array of sources such as: structured data files, tables in Hive, external databases, or existing local R data frames. All of the examples on this page use sample data included in R or the Spark distribution and can be run using the `./bin/sparkR` shell. -## Starting Up: SparkContext, SQLContext +## Starting Up: SparkSession
    -The entry point into SparkR is the `SparkContext` which connects your R program to a Spark cluster. -You can create a `SparkContext` using `sparkR.init` and pass in options such as the application name -, any spark packages depended on, etc. Further, to work with DataFrames we will need a `SQLContext`, -which can be created from the SparkContext. If you are working from the `sparkR` shell, the -`SQLContext` and `SparkContext` should already be created for you, and you would not need to call -`sparkR.init`. +The entry point into SparkR is the `SparkSession` which connects your R program to a Spark cluster. +You can create a `SparkSession` using `sparkR.session` and pass in options such as the application name, any spark packages depended on, etc. Further, you can also work with SparkDataFrames via `SparkSession`. If you are working from the `sparkR` shell, the `SparkSession` should already be created for you, and you would not need to call `sparkR.session`.
    {% highlight r %} -sc <- sparkR.init() -sqlContext <- sparkRSQL.init(sc) +sparkR.session() {% endhighlight %}
    @@ -45,13 +40,13 @@ sqlContext <- sparkRSQL.init(sc) You can also start SparkR from RStudio. You can connect your R program to a Spark cluster from RStudio, R shell, Rscript or other R IDEs. To start, make sure SPARK_HOME is set in environment (you can check [Sys.getenv](https://stat.ethz.ch/R-manual/R-devel/library/base/html/Sys.getenv.html)), -load the SparkR package, and call `sparkR.init` as below. In addition to calling `sparkR.init`, you -could also specify certain Spark driver properties. Normally these +load the SparkR package, and call `sparkR.session` as below. In addition to calling `sparkR.session`, + you could also specify certain Spark driver properties. Normally these [Application properties](configuration.html#application-properties) and [Runtime Environment](configuration.html#runtime-environment) cannot be set programmatically, as the driver JVM process would have been started, in this case SparkR takes care of this for you. To set -them, pass them as you would other configuration properties in the `sparkEnvir` argument to -`sparkR.init()`. +them, pass them as you would other configuration properties in the `sparkConfig` argument to +`sparkR.session()`.
    {% highlight r %} @@ -59,14 +54,29 @@ if (nchar(Sys.getenv("SPARK_HOME")) < 1) { Sys.setenv(SPARK_HOME = "/home/spark") } library(SparkR, lib.loc = c(file.path(Sys.getenv("SPARK_HOME"), "R", "lib"))) -sc <- sparkR.init(master = "local[*]", sparkEnvir = list(spark.driver.memory="2g")) +sparkR.session(master = "local[*]", sparkConfig = list(spark.driver.memory = "2g")) {% endhighlight %}
-The following options can be set in `sparkEnvir` with `sparkR.init` from RStudio:
+The following Spark driver properties can be set in `sparkConfig` with `sparkR.session` from RStudio:
+
+<table class="table">
+<tr><th>Property Name</th><th>Property group</th><th><code>spark-submit</code> equivalent</th></tr>
+<tr><td><code>spark.master</code></td><td>Application Properties</td><td><code>--master</code></td></tr>
+<tr><td><code>spark.yarn.keytab</code></td><td>Application Properties</td><td><code>--keytab</code></td></tr>
+<tr><td><code>spark.yarn.principal</code></td><td>Application Properties</td><td><code>--principal</code></td></tr>
+<tr><td><code>spark.driver.memory</code></td><td>Application Properties</td><td><code>--driver-memory</code></td></tr>
+</table>
+

@@ -91,17 +101,17 @@ The following options can be set in `sparkEnvir` with `sparkR.init` from RStudio

-## Creating DataFrames
-With a `SQLContext`, applications can create `DataFrame`s from a local R data frame, from a [Hive table](sql-programming-guide.html#hive-tables), or from other [data sources](sql-programming-guide.html#data-sources).
+## Creating SparkDataFrames
+With a `SparkSession`, applications can create `SparkDataFrame`s from a local R data frame, from a [Hive table](sql-programming-guide.html#hive-tables), or from other [data sources](sql-programming-guide.html#data-sources).

### From local data frames
-The simplest way to create a data frame is to convert a local R data frame into a SparkR DataFrame. Specifically we can use `createDataFrame` and pass in the local R data frame to create a SparkR DataFrame. As an example, the following creates a `DataFrame` based using the `faithful` dataset from R.
+The simplest way to create a data frame is to convert a local R data frame into a SparkDataFrame. Specifically we can use `as.DataFrame` or `createDataFrame` and pass in the local R data frame to create a SparkDataFrame. As an example, the following creates a `SparkDataFrame` using the `faithful` dataset from R.
    {% highlight r %} -df <- createDataFrame(sqlContext, faithful) +df <- as.DataFrame(faithful) -# Displays the content of the DataFrame to stdout +# Displays the first part of the SparkDataFrame head(df) ## eruptions waiting ##1 3.600 79 @@ -113,25 +123,23 @@ head(df) ### From Data Sources -SparkR supports operating on a variety of data sources through the `DataFrame` interface. This section describes the general methods for loading and saving data using Data Sources. You can check the Spark SQL programming guide for more [specific options](sql-programming-guide.html#manually-specifying-options) that are available for the built-in data sources. +SparkR supports operating on a variety of data sources through the `SparkDataFrame` interface. This section describes the general methods for loading and saving data using Data Sources. You can check the Spark SQL programming guide for more [specific options](sql-programming-guide.html#manually-specifying-options) that are available for the built-in data sources. -The general method for creating DataFrames from data sources is `read.df`. This method takes in the `SQLContext`, the path for the file to load and the type of data source. SparkR supports reading JSON and Parquet files natively and through [Spark Packages](http://spark-packages.org/) you can find data source connectors for popular file formats like [CSV](http://spark-packages.org/package/databricks/spark-csv) and [Avro](http://spark-packages.org/package/databricks/spark-avro). These packages can either be added by -specifying `--packages` with `spark-submit` or `sparkR` commands, or if creating context through `init` -you can specify the packages with the `packages` argument. +The general method for creating SparkDataFrames from data sources is `read.df`. This method takes in the path for the file to load and the type of data source, and the currently active SparkSession will be used automatically. +SparkR supports reading JSON, CSV and Parquet files natively, and through packages available from sources like [Third Party Projects](https://cwiki.apache.org/confluence/display/SPARK/Third+Party+Projects), you can find data source connectors for popular file formats like Avro. These packages can either be added by +specifying `--packages` with `spark-submit` or `sparkR` commands, or if initializing SparkSession with `sparkPackages` parameter when in an interactive R shell or from RStudio.
    {% highlight r %} -sc <- sparkR.init(sparkPackages="com.databricks:spark-csv_2.11:1.0.3") -sqlContext <- sparkRSQL.init(sc) +sparkR.session(sparkPackages = "com.databricks:spark-avro_2.11:3.0.0") {% endhighlight %}
    We can see how to use data sources using an example JSON input file. Note that the file that is used here is _not_ a typical JSON file. Each line in the file must contain a separate, self-contained valid JSON object. As a consequence, a regular multi-line JSON file will most often fail.
    - {% highlight r %} -people <- read.df(sqlContext, "./examples/src/main/resources/people.json", "json") +people <- read.df("./examples/src/main/resources/people.json", "json") head(people) ## age name ##1 NA Michael @@ -144,34 +152,45 @@ printSchema(people) # |-- age: long (nullable = true) # |-- name: string (nullable = true) +# Similarly, multiple files can be read with read.json +people <- read.json(c("./examples/src/main/resources/people.json", "./examples/src/main/resources/people2.json")) + {% endhighlight %}
    -The data sources API can also be used to save out DataFrames into multiple file formats. For example we can save the DataFrame from the previous example -to a Parquet file using `write.df` (Until Spark 1.6, the default mode for writes was `append`. It was changed in Spark 1.7 to `error` to match the Scala API) +The data sources API natively supports CSV formatted input files. For more information please refer to SparkR [read.df](api/R/read.df.html) API documentation.
    {% highlight r %} -write.df(people, path="people.parquet", source="parquet", mode="overwrite") +df <- read.df(csvPath, "csv", header = "true", inferSchema = "true", na.strings = "NA") + +{% endhighlight %} +
    + +The data sources API can also be used to save out SparkDataFrames into multiple file formats. For example we can save the SparkDataFrame from the previous example +to a Parquet file using `write.df`. + +
    +{% highlight r %} +write.df(people, path = "people.parquet", source = "parquet", mode = "overwrite") {% endhighlight %}
    ### From Hive tables -You can also create SparkR DataFrames from Hive tables. To do this we will need to create a HiveContext which can access tables in the Hive MetaStore. Note that Spark should have been built with [Hive support](building-spark.html#building-with-hive-and-jdbc-support) and more details on the difference between SQLContext and HiveContext can be found in the [SQL programming guide](sql-programming-guide.html#starting-point-sqlcontext). +You can also create SparkDataFrames from Hive tables. To do this we will need to create a SparkSession with Hive support which can access tables in the Hive MetaStore. Note that Spark should have been built with [Hive support](building-spark.html#building-with-hive-and-jdbc-support) and more details can be found in the [SQL programming guide](sql-programming-guide.html#starting-point-sparksession). In SparkR, by default it will attempt to create a SparkSession with Hive support enabled (`enableHiveSupport = TRUE`).
    {% highlight r %} -# sc is an existing SparkContext. -hiveContext <- sparkRHive.init(sc) +sparkR.session() -sql(hiveContext, "CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") -sql(hiveContext, "LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") +sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") +sql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") # Queries can be expressed in HiveQL. -results <- sql(hiveContext, "FROM src SELECT key, value") +results <- sql("FROM src SELECT key, value") -# results is now a DataFrame +# results is now a SparkDataFrame head(results) ## key value ## 1 238 val_238 @@ -181,19 +200,19 @@ head(results) {% endhighlight %}
    -## DataFrame Operations +## SparkDataFrame Operations -SparkR DataFrames support a number of functions to do structured data processing. +SparkDataFrames support a number of functions to do structured data processing. Here we include some basic examples and a complete list can be found in the [API](api/R/index.html) docs: ### Selecting rows, columns
    {% highlight r %} -# Create the DataFrame -df <- createDataFrame(sqlContext, faithful) +# Create the SparkDataFrame +df <- as.DataFrame(faithful) -# Get basic information about the DataFrame +# Get basic information about the SparkDataFrame df ## SparkDataFrame[eruptions:double, waiting:double] @@ -207,7 +226,7 @@ head(select(df, df$eruptions)) # You can also pass in column name as strings head(select(df, "eruptions")) -# Filter the DataFrame to only retain rows with wait times shorter than 50 mins +# Filter the SparkDataFrame to only retain rows with wait times shorter than 50 mins head(filter(df, df$waiting < 50)) ## eruptions waiting ##1 1.750 47 @@ -251,7 +270,7 @@ SparkR also provides a number of functions that can directly applied to columns {% highlight r %} # Convert waiting time from hours to seconds. -# Note that we can assign this to a new column in the same DataFrame +# Note that we can assign this to a new column in the same SparkDataFrame df$waiting_secs <- df$waiting * 60 head(df) ## eruptions waiting waiting_secs @@ -262,95 +281,269 @@ head(df) {% endhighlight %}
-## Running SQL Queries from SparkR
-A SparkR DataFrame can also be registered as a temporary table in Spark SQL and registering a DataFrame as a table allows you to run SQL queries over its data.
-The `sql` function enables applications to run SQL queries programmatically and returns the result as a `DataFrame`.
+### Applying User-Defined Function
+In SparkR, we support several kinds of User-Defined Functions:
+
+#### Run a given function on a large dataset using `dapply` or `dapplyCollect`
+
+##### dapply
+Apply a function to each partition of a `SparkDataFrame`. The function to be applied should have only one parameter, to which the `data.frame` corresponding to each partition will be passed. The output of the function should be a `data.frame`. The schema specifies the row format of the resulting `SparkDataFrame`, and it must match the [data types](#data-type-mapping-between-r-and-spark) of the returned value.
    {% highlight r %} -# Load a JSON file -people <- read.df(sqlContext, "./examples/src/main/resources/people.json", "json") -# Register this DataFrame as a table. -registerTempTable(people, "people") +# Convert waiting time from hours to seconds. +# Note that we can apply UDF to DataFrame. +schema <- structType(structField("eruptions", "double"), structField("waiting", "double"), + structField("waiting_secs", "double")) +df1 <- dapply(df, function(x) { x <- cbind(x, x$waiting * 60) }, schema) +head(collect(df1)) +## eruptions waiting waiting_secs +##1 3.600 79 4740 +##2 1.800 54 3240 +##3 3.333 74 4440 +##4 2.283 62 3720 +##5 4.533 85 5100 +##6 2.883 55 3300 +{% endhighlight %} +
-# SQL statements can be run by using the sql method
-teenagers <- sql(sqlContext, "SELECT name FROM people WHERE age >= 13 AND age <= 19")
-head(teenagers)
-## name
-##1 Justin
+##### dapplyCollect
+Like `dapply`, apply a function to each partition of a `SparkDataFrame` and collect the result back. The output of the function
+should be a `data.frame`, but a schema is not required to be passed. Note that `dapplyCollect` can fail if the output of the UDF run on all the partitions cannot be pulled to the driver and fit in driver memory.
+
    +{% highlight r %} + +# Convert waiting time from hours to seconds. +# Note that we can apply UDF to DataFrame and return a R's data.frame +ldf <- dapplyCollect( + df, + function(x) { + x <- cbind(x, "waiting_secs" = x$waiting * 60) + }) +head(ldf, 3) +## eruptions waiting waiting_secs +##1 3.600 79 4740 +##2 1.800 54 3240 +##3 3.333 74 4440 {% endhighlight %}
    -# Machine Learning +#### Run a given function on a large dataset grouping by input column(s) and using `gapply` or `gapplyCollect` -SparkR allows the fitting of generalized linear models over DataFrames using the [glm()](api/R/glm.html) function. Under the hood, SparkR uses MLlib to train a model of the specified family. Currently the gaussian and binomial families are supported. We support a subset of the available R formula operators for model fitting, including '~', '.', ':', '+', and '-'. +##### gapply +Apply a function to each group of a `SparkDataFrame`. The function is to be applied to each group of the `SparkDataFrame` and should have only two parameters: grouping key and R `data.frame` corresponding to +that key. The groups are chosen from `SparkDataFrame`s column(s). +The output of function should be a `data.frame`. Schema specifies the row format of the resulting +`SparkDataFrame`. It must represent R function's output schema on the basis of Spark [data types](#data-type-mapping-between-r-and-spark). The column names of the returned `data.frame` are set by user. -The [summary()](api/R/summary.html) function gives the summary of a model produced by [glm()](api/R/glm.html). +
    +{% highlight r %} -* For gaussian GLM model, it returns a list with 'devianceResiduals' and 'coefficients' components. The 'devianceResiduals' gives the min/max deviance residuals of the estimation; the 'coefficients' gives the estimated coefficients and their estimated standard errors, t values and p-values. (It only available when model fitted by normal solver.) -* For binomial GLM model, it returns a list with 'coefficients' component which gives the estimated coefficients. +# Determine six waiting times with the largest eruption time in minutes. +schema <- structType(structField("waiting", "double"), structField("max_eruption", "double")) +result <- gapply( + df, + "waiting", + function(key, x) { + y <- data.frame(key, max(x$eruptions)) + }, + schema) +head(collect(arrange(result, "max_eruption", decreasing = TRUE))) + +## waiting max_eruption +##1 64 5.100 +##2 69 5.067 +##3 71 5.033 +##4 87 5.000 +##5 63 4.933 +##6 89 4.900 +{% endhighlight %} +
-The examples below show the use of building gaussian GLM model and binomial GLM model using SparkR.
+##### gapplyCollect
+Like `gapply`, `gapplyCollect` applies a function to each group of a `SparkDataFrame` and collects the result back to an R `data.frame`. The output of the function should be a `data.frame`, but a schema is not required to be passed. Note that `gapplyCollect` can fail if the output of the UDF run on all the groups cannot be pulled to the driver and fit in driver memory.

-## Gaussian GLM model
    +{% highlight r %} + +# Determine six waiting times with the largest eruption time in minutes. +result <- gapplyCollect( + df, + "waiting", + function(key, x) { + y <- data.frame(key, max(x$eruptions)) + colnames(y) <- c("waiting", "max_eruption") + y + }) +head(result[order(result$max_eruption, decreasing = TRUE), ]) + +## waiting max_eruption +##1 64 5.100 +##2 69 5.067 +##3 71 5.033 +##4 87 5.000 +##5 63 4.933 +##6 89 4.900 + +{% endhighlight %} +
    + +#### Data type mapping between R and Spark +
+<table class="table">
+<tr><th>R</th><th>Spark</th></tr>
+<tr><td>byte</td><td>byte</td></tr>
+<tr><td>integer</td><td>integer</td></tr>
+<tr><td>float</td><td>float</td></tr>
+<tr><td>double</td><td>double</td></tr>
+<tr><td>numeric</td><td>double</td></tr>
+<tr><td>character</td><td>string</td></tr>
+<tr><td>string</td><td>string</td></tr>
+<tr><td>binary</td><td>binary</td></tr>
+<tr><td>raw</td><td>binary</td></tr>
+<tr><td>logical</td><td>boolean</td></tr>
+<tr><td>POSIXct</td><td>timestamp</td></tr>
+<tr><td>POSIXlt</td><td>timestamp</td></tr>
+<tr><td>Date</td><td>date</td></tr>
+<tr><td>array</td><td>array</td></tr>
+<tr><td>list</td><td>array</td></tr>
+<tr><td>env</td><td>map</td></tr>
+</table>
+
+
+#### Run local R functions distributed using `spark.lapply`
+
+##### spark.lapply
+Similar to `lapply` in native R, `spark.lapply` runs a function over a list of elements and distributes the computations with Spark.
+It applies a function to the elements of a list in a manner that is similar to `doParallel` or `lapply`. The results of all the computations
+should fit on a single machine. If that is not the case, you can do something like `df <- createDataFrame(list)` and then use
+`dapply`.
    {% highlight r %} -# Create the DataFrame -df <- createDataFrame(sqlContext, iris) - -# Fit a gaussian GLM model over the dataset. -model <- glm(Sepal_Length ~ Sepal_Width + Species, data = df, family = "gaussian") - -# Model summary are returned in a similar format to R's native glm(). -summary(model) -##$devianceResiduals -## Min Max -## -1.307112 1.412532 -## -##$coefficients -## Estimate Std. Error t value Pr(>|t|) -##(Intercept) 2.251393 0.3697543 6.08889 9.568102e-09 -##Sepal_Width 0.8035609 0.106339 7.556598 4.187317e-12 -##Species_versicolor 1.458743 0.1121079 13.01195 0 -##Species_virginica 1.946817 0.100015 19.46525 0 - -# Make predictions based on the model. -predictions <- predict(model, newData = df) -head(select(predictions, "Sepal_Length", "prediction")) -## Sepal_Length prediction -##1 5.1 5.063856 -##2 4.9 4.662076 -##3 4.7 4.822788 -##4 4.6 4.742432 -##5 5.0 5.144212 -##6 5.4 5.385281 +# Perform distributed training of multiple models with spark.lapply. Here, we pass +# a read-only list of arguments which specifies family the generalized linear model should be. +families <- c("gaussian", "poisson") +train <- function(family) { + model <- glm(Sepal.Length ~ Sepal.Width + Species, iris, family = family) + summary(model) +} +# Return a list of model's summaries +model.summaries <- spark.lapply(families, train) + +# Print the summary of each model +print(model.summaries) + {% endhighlight %}
    -## Binomial GLM model +## Running SQL Queries from SparkR +A SparkDataFrame can also be registered as a temporary view in Spark SQL and that allows you to run SQL queries over its data. +The `sql` function enables applications to run SQL queries programmatically and returns the result as a `SparkDataFrame`.
    {% highlight r %} -# Create the DataFrame -df <- createDataFrame(sqlContext, iris) -training <- filter(df, df$Species != "setosa") - -# Fit a binomial GLM model over the dataset. -model <- glm(Species ~ Sepal_Length + Sepal_Width, data = training, family = "binomial") - -# Model coefficients are returned in a similar format to R's native glm(). -summary(model) -##$coefficients -## Estimate -##(Intercept) -13.046005 -##Sepal_Length 1.902373 -##Sepal_Width 0.404655 +# Load a JSON file +people <- read.df("./examples/src/main/resources/people.json", "json") + +# Register this SparkDataFrame as a temporary view. +createOrReplaceTempView(people, "people") + +# SQL statements can be run by using the sql method +teenagers <- sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") +head(teenagers) +## name +##1 Justin + {% endhighlight %}
    +# Machine Learning + +SparkR supports the following machine learning algorithms currently: `Generalized Linear Model`, `Accelerated Failure Time (AFT) Survival Regression Model`, `Naive Bayes Model` and `KMeans Model`. +Under the hood, SparkR uses MLlib to train the model. +Users can call `summary` to print a summary of the fitted model, [predict](api/R/predict.html) to make predictions on new data, and [write.ml](api/R/write.ml.html)/[read.ml](api/R/read.ml.html) to save/load fitted models. +SparkR supports a subset of the available R formula operators for model fitting, including ‘~’, ‘.’, ‘:’, ‘+’, and ‘-‘. + +## Algorithms + +### Generalized Linear Model + +[spark.glm()](api/R/spark.glm.html) or [glm()](api/R/glm.html) fits generalized linear model against a Spark DataFrame. +Currently "gaussian", "binomial", "poisson" and "gamma" families are supported. +{% include_example glm r/ml.R %} + +### Accelerated Failure Time (AFT) Survival Regression Model + +[spark.survreg()](api/R/spark.survreg.html) fits an accelerated failure time (AFT) survival regression model on a SparkDataFrame. +Note that the formula of [spark.survreg()](api/R/spark.survreg.html) does not support operator '.' currently. +{% include_example survreg r/ml.R %} + +### Naive Bayes Model + +[spark.naiveBayes()](api/R/spark.naiveBayes.html) fits a Bernoulli naive Bayes model against a SparkDataFrame. Only categorical data is supported. +{% include_example naiveBayes r/ml.R %} + +### KMeans Model + +[spark.kmeans()](api/R/spark.kmeans.html) fits a k-means clustering model against a Spark DataFrame, similarly to R's kmeans(). +{% include_example kmeans r/ml.R %} + +## Model persistence + +The following example shows how to save/load a MLlib model by SparkR. +{% include_example read_write r/ml.R %} + # R Function Name Conflicts When loading and attaching a new package in R, it is possible to have a name [conflict](https://stat.ethz.ch/R-manual/R-devel/library/base/html/library.html), where a @@ -386,8 +579,15 @@ You can inspect the search path in R with [`search()`](https://stat.ethz.ch/R-ma ## Upgrading From SparkR 1.5.x to 1.6.x - Before Spark 1.6.0, the default mode for writes was `append`. It was changed in Spark 1.6.0 to `error` to match the Scala API. + - SparkSQL converts `NA` in R to `null` and vice-versa. ## Upgrading From SparkR 1.6.x to 2.0 - The method `table` has been removed and replaced by `tableToDF`. - The class `DataFrame` has been renamed to `SparkDataFrame` to avoid name conflicts. + - Spark's `SQLContext` and `HiveContext` have been deprecated to be replaced by `SparkSession`. Instead of `sparkR.init()`, call `sparkR.session()` in its place to instantiate the SparkSession. Once that is done, that currently active SparkSession will be used for SparkDataFrame operations. + - The parameter `sparkExecutorEnv` is not supported by `sparkR.session`. To set environment for the executors, set Spark config properties with the prefix "spark.executorEnv.VAR_NAME", for example, "spark.executorEnv.PATH" + - The `sqlContext` parameter is no longer required for these functions: `createDataFrame`, `as.DataFrame`, `read.json`, `jsonFile`, `read.parquet`, `parquetFile`, `read.text`, `sql`, `tables`, `tableNames`, `cacheTable`, `uncacheTable`, `clearCache`, `dropTempTable`, `read.df`, `loadDF`, `createExternalTable`. + - The method `registerTempTable` has been deprecated to be replaced by `createOrReplaceTempView`. + - The method `dropTempTable` has been deprecated to be replaced by `dropTempView`. 
+ - The `sc` SparkContext parameter is no longer required for these functions: `setJobGroup`, `clearJobGroup`, `cancelJobGroup` diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index a16a6bb1d93ef..bac0e8156913b 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -12,292 +12,156 @@ title: Spark SQL and DataFrames Spark SQL is a Spark module for structured data processing. Unlike the basic Spark RDD API, the interfaces provided by Spark SQL provide Spark with more information about the structure of both the data and the computation being performed. Internally, Spark SQL uses this extra information to perform extra optimizations. There are several ways to -interact with Spark SQL including SQL, the DataFrames API and the Datasets API. When computing a result +interact with Spark SQL including SQL and the Dataset API. When computing a result the same execution engine is used, independent of which API/language you are using to express the -computation. This unification means that developers can easily switch back and forth between the -various APIs based on which provides the most natural way to express a given transformation. +computation. This unification means that developers can easily switch back and forth between +different APIs based on which provides the most natural way to express a given transformation. All of the examples on this page use sample data included in the Spark distribution and can be run in the `spark-shell`, `pyspark` shell, or `sparkR` shell. ## SQL -One use of Spark SQL is to execute SQL queries written using either a basic SQL syntax or HiveQL. +One use of Spark SQL is to execute SQL queries. Spark SQL can also be used to read data from an existing Hive installation. For more on how to configure this feature, please refer to the [Hive Tables](#hive-tables) section. When running -SQL from within another programming language the results will be returned as a [DataFrame](#DataFrames). +SQL from within another programming language the results will be returned as a [Dataset/DataFrame](#datasets-and-dataframes). You can also interact with the SQL interface using the [command-line](#running-the-spark-sql-cli) or over [JDBC/ODBC](#running-the-thrift-jdbcodbc-server). -## DataFrames +## Datasets and DataFrames -A DataFrame is a distributed collection of data organized into named columns. It is conceptually +A Dataset is a distributed collection of data. +Dataset is a new interface added in Spark 1.6 that provides the benefits of RDDs (strong +typing, ability to use powerful lambda functions) with the benefits of Spark SQL's optimized +execution engine. A Dataset can be [constructed](#creating-datasets) from JVM objects and then +manipulated using functional transformations (`map`, `flatMap`, `filter`, etc.). +The Dataset API is available in [Scala][scala-datasets] and +[Java][java-datasets]. Python does not have support for the Dataset API. However, due to Python's dynamic nature, +many of the benefits of the Dataset API are already available (i.e. you can naturally access the fields of a row by name, +`row.columnName`). The case for R is similar. + +A DataFrame is a *Dataset* organized into named columns. It is conceptually equivalent to a table in a relational database or a data frame in R/Python, but with richer optimizations under the hood. DataFrames can be constructed from a wide array of [sources](#data-sources) such as: structured data files, tables in Hive, external databases, or existing RDDs. 
+The DataFrame API is available in Scala, +Java, [Python](api/python/pyspark.sql.html#pyspark.sql.DataFrame), and [R](api/R/index.html). +In Scala and Java, a DataFrame is represented by a Dataset of `Row`s. +In [the Scala API][scala-datasets], `DataFrame` is simply a type alias of `Dataset[Row]`. +However, in the [Java API][java-datasets], users need to use `Dataset<Row>` to represent a `DataFrame`. -The DataFrame API is available in [Scala](api/scala/index.html#org.apache.spark.sql.DataFrame), -[Java](api/java/index.html?org/apache/spark/sql/DataFrame.html), -[Python](api/python/pyspark.sql.html#pyspark.sql.DataFrame), and [R](api/R/index.html). - -## Datasets - -A Dataset is a new experimental interface added in Spark 1.6 that tries to provide the benefits of -RDDs (strong typing, ability to use powerful lambda functions) with the benefits of Spark SQL's -optimized execution engine. A Dataset can be [constructed](#creating-datasets) from JVM objects and then manipulated -using functional transformations (map, flatMap, filter, etc.). +[scala-datasets]: api/scala/index.html#org.apache.spark.sql.Dataset +[java-datasets]: api/java/index.html?org/apache/spark/sql/Dataset.html -The unified Dataset API can be used both in [Scala](api/scala/index.html#org.apache.spark.sql.Dataset) and -[Java](api/java/index.html?org/apache/spark/sql/Dataset.html). Python does not yet have support for -the Dataset API, but due to its dynamic nature many of the benefits are already available (i.e. you can -access the field of a row by name naturally `row.columnName`). Full python support will be added -in a future release. +Throughout this document, we will often refer to Scala/Java Datasets of `Row`s as DataFrames. # Getting Started -## Starting Point: SQLContext +## Starting Point: SparkSession
    -The entry point into all functionality in Spark SQL is the -[`SQLContext`](api/scala/index.html#org.apache.spark.sql.SQLContext) class, or one of its -descendants. To create a basic `SQLContext`, all you need is a SparkContext. - -{% highlight scala %} -val sc: SparkContext // An existing SparkContext. -val sqlContext = new org.apache.spark.sql.SQLContext(sc) - -// this is used to implicitly convert an RDD to a DataFrame. -import sqlContext.implicits._ -{% endhighlight %} +The entry point into all functionality in Spark is the [`SparkSession`](api/scala/index.html#org.apache.spark.sql.SparkSession) class. To create a basic `SparkSession`, just use `SparkSession.builder()`: +{% include_example init_session scala/org/apache/spark/examples/sql/SparkSQLExample.scala %}
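For readers viewing this diff without the example sources, a minimal Scala sketch of what an `init_session` snippet typically contains; the application name and config key below are placeholders rather than the exact contents of the bundled example:

{% highlight scala %}
import org.apache.spark.sql.SparkSession

val spark = SparkSession
  .builder()
  .appName("Spark SQL basic example")                 // placeholder application name
  .config("spark.some.config.option", "some-value")   // placeholder config key/value
  .getOrCreate()

// For implicit conversions, e.g. converting RDDs and Seqs to DataFrames/Datasets
import spark.implicits._
{% endhighlight %}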
    -The entry point into all functionality in Spark SQL is the -[`SQLContext`](api/java/index.html#org.apache.spark.sql.SQLContext) class, or one of its -descendants. To create a basic `SQLContext`, all you need is a SparkContext. - -{% highlight java %} -JavaSparkContext sc = ...; // An existing JavaSparkContext. -SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc); -{% endhighlight %} +The entry point into all functionality in Spark is the [`SparkSession`](api/java/index.html#org.apache.spark.sql.SparkSession) class. To create a basic `SparkSession`, just use `SparkSession.builder()`: +{% include_example init_session java/org/apache/spark/examples/sql/JavaSparkSQLExample.java %}
    -The entry point into all relational functionality in Spark is the -[`SQLContext`](api/python/pyspark.sql.html#pyspark.sql.SQLContext) class, or one -of its decedents. To create a basic `SQLContext`, all you need is a SparkContext. - -{% highlight python %} -from pyspark.sql import SQLContext -sqlContext = SQLContext(sc) -{% endhighlight %} +The entry point into all functionality in Spark is the [`SparkSession`](api/python/pyspark.sql.html#pyspark.sql.SparkSession) class. To create a basic `SparkSession`, just use `SparkSession.builder`: +{% include_example init_session python/sql/basic.py %}
    -The entry point into all relational functionality in Spark is the -`SQLContext` class, or one of its decedents. To create a basic `SQLContext`, all you need is a SparkContext. +The entry point into all functionality in Spark is the [`SparkSession`](api/R/sparkR.session.html) class. To initialize a basic `SparkSession`, just call `sparkR.session()`: -{% highlight r %} -sqlContext <- sparkRSQL.init(sc) -{% endhighlight %} +{% include_example init_session r/RSparkSQLExample.R %} +Note that when invoked for the first time, `sparkR.session()` initializes a global `SparkSession` singleton instance, and always returns a reference to this instance for successive invocations. In this way, users only need to initialize the `SparkSession` once, then SparkR functions like `read.df` will be able to access this global instance implicitly, and users don't need to pass the `SparkSession` instance around.
    -In addition to the basic `SQLContext`, you can also create a `HiveContext`, which provides a -superset of the functionality provided by the basic `SQLContext`. Additional features include -the ability to write queries using the more complete HiveQL parser, access to Hive UDFs, and the -ability to read data from Hive tables. To use a `HiveContext`, you do not need to have an -existing Hive setup, and all of the data sources available to a `SQLContext` are still available. -`HiveContext` is only packaged separately to avoid including all of Hive's dependencies in the default -Spark build. If these dependencies are not a problem for your application then using `HiveContext` -is recommended for the 1.3 release of Spark. Future releases will focus on bringing `SQLContext` up -to feature parity with a `HiveContext`. - +`SparkSession` in Spark 2.0 provides builtin support for Hive features including the ability to +write queries using HiveQL, access to Hive UDFs, and the ability to read data from Hive tables. +To use these features, you do not need to have an existing Hive setup. ## Creating DataFrames -With a `SQLContext`, applications can create `DataFrame`s from an existing `RDD`, from a Hive table, or from data sources. - -As an example, the following creates a `DataFrame` based on the content of a JSON file: -
    -{% highlight scala %} -val sc: SparkContext // An existing SparkContext. -val sqlContext = new org.apache.spark.sql.SQLContext(sc) - -val df = sqlContext.read.json("examples/src/main/resources/people.json") +With a `SparkSession`, applications can create DataFrames from an [existing `RDD`](#interoperating-with-rdds), +from a Hive table, or from [Spark data sources](#data-sources). -// Displays the content of the DataFrame to stdout -df.show() -{% endhighlight %} +As an example, the following creates a DataFrame based on the content of a JSON file: +{% include_example create_df scala/org/apache/spark/examples/sql/SparkSQLExample.scala %}
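As a rough sketch of what such a `create_df` example does (the path points at the sample data shipped with Spark; the variable name is arbitrary):

{% highlight scala %}
// Read the sample JSON file into a DataFrame
val df = spark.read.json("examples/src/main/resources/people.json")

// Displays the content of the DataFrame to stdout
df.show()
{% endhighlight %}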
    -{% highlight java %} -JavaSparkContext sc = ...; // An existing JavaSparkContext. -SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc); +With a `SparkSession`, applications can create DataFrames from an [existing `RDD`](#interoperating-with-rdds), +from a Hive table, or from [Spark data sources](#data-sources). -DataFrame df = sqlContext.read().json("examples/src/main/resources/people.json"); - -// Displays the content of the DataFrame to stdout -df.show(); -{% endhighlight %} +As an example, the following creates a DataFrame based on the content of a JSON file: +{% include_example create_df java/org/apache/spark/examples/sql/JavaSparkSQLExample.java %}
    -{% highlight python %} -from pyspark.sql import SQLContext -sqlContext = SQLContext(sc) +With a `SparkSession`, applications can create DataFrames from an [existing `RDD`](#interoperating-with-rdds), +from a Hive table, or from [Spark data sources](#data-sources). -df = sqlContext.read.json("examples/src/main/resources/people.json") - -# Displays the content of the DataFrame to stdout -df.show() -{% endhighlight %} +As an example, the following creates a DataFrame based on the content of a JSON file: +{% include_example create_df python/sql/basic.py %}
    -{% highlight r %} -sqlContext <- SQLContext(sc) +With a `SparkSession`, applications can create DataFrames from a local R data.frame, +from a Hive table, or from [Spark data sources](#data-sources). -df <- read.json(sqlContext, "examples/src/main/resources/people.json") +As an example, the following creates a DataFrame based on the content of a JSON file: -# Displays the content of the DataFrame to stdout -showDF(df) -{% endhighlight %} +{% include_example create_df r/RSparkSQLExample.R %}
    -
-## DataFrame Operations +## Untyped Dataset Operations (aka DataFrame Operations) + +DataFrames provide a domain-specific language for structured data manipulation in [Scala](api/scala/index.html#org.apache.spark.sql.Dataset), [Java](api/java/index.html?org/apache/spark/sql/Dataset.html), [Python](api/python/pyspark.sql.html#pyspark.sql.DataFrame) and [R](api/R/SparkDataFrame.html). -DataFrames provide a domain-specific language for structured data manipulation in [Scala](api/scala/index.html#org.apache.spark.sql.DataFrame), [Java](api/java/index.html?org/apache/spark/sql/DataFrame.html), [Python](api/python/pyspark.sql.html#pyspark.sql.DataFrame) and [R](api/R/DataFrame.html). +As mentioned above, in Spark 2.0, DataFrames are just Datasets of `Row`s in the Scala and Java APIs. These operations are also referred to as "untyped transformations", in contrast to the "typed transformations" that come with strongly typed Scala/Java Datasets. -Here we include some basic examples of structured data processing using DataFrames: +Here we include some basic examples of structured data processing using Datasets:
    -{% highlight scala %} -val sc: SparkContext // An existing SparkContext. -val sqlContext = new org.apache.spark.sql.SQLContext(sc) - -// Create the DataFrame -val df = sqlContext.read.json("examples/src/main/resources/people.json") - -// Show the content of the DataFrame -df.show() -// age name -// null Michael -// 30 Andy -// 19 Justin - -// Print the schema in a tree format -df.printSchema() -// root -// |-- age: long (nullable = true) -// |-- name: string (nullable = true) - -// Select only the "name" column -df.select("name").show() -// name -// Michael -// Andy -// Justin - -// Select everybody, but increment the age by 1 -df.select(df("name"), df("age") + 1).show() -// name (age + 1) -// Michael null -// Andy 31 -// Justin 20 - -// Select people older than 21 -df.filter(df("age") > 21).show() -// age name -// 30 Andy - -// Count people by age -df.groupBy("age").count().show() -// age count -// null 1 -// 19 1 -// 30 1 -{% endhighlight %} - -For a complete list of the types of operations that can be performed on a DataFrame refer to the [API Documentation](api/scala/index.html#org.apache.spark.sql.DataFrame). - -In addition to simple column references and expressions, DataFrames also have a rich library of functions including string manipulation, date arithmetic, common math operations and more. The complete list is available in the [DataFrame Function Reference](api/scala/index.html#org.apache.spark.sql.functions$). +{% include_example untyped_ops scala/org/apache/spark/examples/sql/SparkSQLExample.scala %} +For a complete list of the types of operations that can be performed on a Dataset refer to the [API Documentation](api/scala/index.html#org.apache.spark.sql.Dataset). +In addition to simple column references and expressions, Datasets also have a rich library of functions including string manipulation, date arithmetic, common math operations and more. The complete list is available in the [DataFrame Function Reference](api/scala/index.html#org.apache.spark.sql.functions$).
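A compact sketch of the kind of untyped operations the `untyped_ops` example walks through, assuming the `df` DataFrame created from `people.json` above:

{% highlight scala %}
// Needed for the $-notation on column names used below
import spark.implicits._

df.printSchema()                       // print the schema in a tree format
df.select("name").show()               // select only the "name" column
df.select($"name", $"age" + 1).show()  // select everybody, but increment the age by 1
df.filter($"age" > 21).show()          // select people older than 21
df.groupBy("age").count().show()       // count people by age
{% endhighlight %}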
    -{% highlight java %} -JavaSparkContext sc // An existing SparkContext. -SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc) - -// Create the DataFrame -DataFrame df = sqlContext.read().json("examples/src/main/resources/people.json"); - -// Show the content of the DataFrame -df.show(); -// age name -// null Michael -// 30 Andy -// 19 Justin - -// Print the schema in a tree format -df.printSchema(); -// root -// |-- age: long (nullable = true) -// |-- name: string (nullable = true) - -// Select only the "name" column -df.select("name").show(); -// name -// Michael -// Andy -// Justin - -// Select everybody, but increment the age by 1 -df.select(df.col("name"), df.col("age").plus(1)).show(); -// name (age + 1) -// Michael null -// Andy 31 -// Justin 20 - -// Select people older than 21 -df.filter(df.col("age").gt(21)).show(); -// age name -// 30 Andy - -// Count people by age -df.groupBy("age").count().show(); -// age count -// null 1 -// 19 1 -// 30 1 -{% endhighlight %} -For a complete list of the types of operations that can be performed on a DataFrame refer to the [API Documentation](api/java/org/apache/spark/sql/DataFrame.html). +{% include_example untyped_ops java/org/apache/spark/examples/sql/JavaSparkSQLExample.java %} -In addition to simple column references and expressions, DataFrames also have a rich library of functions including string manipulation, date arithmetic, common math operations and more. The complete list is available in the [DataFrame Function Reference](api/java/org/apache/spark/sql/functions.html). +For a complete list of the types of operations that can be performed on a Dataset refer to the [API Documentation](api/java/org/apache/spark/sql/Dataset.html). +In addition to simple column references and expressions, Datasets also have a rich library of functions including string manipulation, date arithmetic, common math operations and more. The complete list is available in the [DataFrame Function Reference](api/java/org/apache/spark/sql/functions.html).
    @@ -307,54 +171,7 @@ interactive data exploration, users are highly encouraged to use the latter form, which is future proof and won't break with column names that are also attributes on the DataFrame class. -{% highlight python %} -from pyspark.sql import SQLContext -sqlContext = SQLContext(sc) - -# Create the DataFrame -df = sqlContext.read.json("examples/src/main/resources/people.json") - -# Show the content of the DataFrame -df.show() -## age name -## null Michael -## 30 Andy -## 19 Justin - -# Print the schema in a tree format -df.printSchema() -## root -## |-- age: long (nullable = true) -## |-- name: string (nullable = true) - -# Select only the "name" column -df.select("name").show() -## name -## Michael -## Andy -## Justin - -# Select everybody, but increment the age by 1 -df.select(df['name'], df['age'] + 1).show() -## name (age + 1) -## Michael null -## Andy 31 -## Justin 20 - -# Select people older than 21 -df.filter(df['age'] > 21).show() -## age name -## 30 Andy - -# Count people by age -df.groupBy("age").count().show() -## age count -## null 1 -## 19 1 -## 30 1 - -{% endhighlight %} - +{% include_example untyped_ops python/sql/basic.py %} For a complete list of the types of operations that can be performed on a DataFrame refer to the [API Documentation](api/python/pyspark.sql.html#pyspark.sql.DataFrame). In addition to simple column references and expressions, DataFrames also have a rich library of functions including string manipulation, date arithmetic, common math operations and more. The complete list is available in the [DataFrame Function Reference](api/python/pyspark.sql.html#module-pyspark.sql.functions). @@ -362,56 +179,12 @@ In addition to simple column references and expressions, DataFrames also have a
    -{% highlight r %} -sqlContext <- sparkRSQL.init(sc) - -# Create the DataFrame -df <- read.json(sqlContext, "examples/src/main/resources/people.json") - -# Show the content of the DataFrame -showDF(df) -## age name -## null Michael -## 30 Andy -## 19 Justin - -# Print the schema in a tree format -printSchema(df) -## root -## |-- age: long (nullable = true) -## |-- name: string (nullable = true) - -# Select only the "name" column -showDF(select(df, "name")) -## name -## Michael -## Andy -## Justin - -# Select everybody, but increment the age by 1 -showDF(select(df, df$name, df$age + 1)) -## name (age + 1) -## Michael null -## Andy 31 -## Justin 20 - -# Select people older than 21 -showDF(where(df, df$age > 21)) -## age name -## 30 Andy - -# Count people by age -showDF(count(groupBy(df, "age"))) -## age count -## null 1 -## 19 1 -## 30 1 -{% endhighlight %} +{% include_example untyped_ops r/RSparkSQLExample.R %} For a complete list of the types of operations that can be performed on a DataFrame refer to the [API Documentation](api/R/index.html). -In addition to simple column references and expressions, DataFrames also have a rich library of functions including string manipulation, date arithmetic, common math operations and more. The complete list is available in the [DataFrame Function Reference](api/R/index.html). +In addition to simple column references and expressions, DataFrames also have a rich library of functions including string manipulation, date arithmetic, common math operations and more. The complete list is available in the [DataFrame Function Reference](api/R/SparkDataFrame.html).
    @@ -419,44 +192,37 @@ In addition to simple column references and expressions, DataFrames also have a ## Running SQL Queries Programmatically -The `sql` function on a `SQLContext` enables applications to run SQL queries programmatically and returns the result as a `DataFrame`. -
    -{% highlight scala %} -val sqlContext = ... // An existing SQLContext -val df = sqlContext.sql("SELECT * FROM table") -{% endhighlight %} +The `sql` function on a `SparkSession` enables applications to run SQL queries programmatically and returns the result as a `DataFrame`. + +{% include_example run_sql scala/org/apache/spark/examples/sql/SparkSQLExample.scala %}
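A minimal sketch, assuming the `df` DataFrame from the previous examples; the view name is arbitrary:

{% highlight scala %}
// Register the DataFrame as a SQL temporary view, then query it
df.createOrReplaceTempView("people")

val sqlDF = spark.sql("SELECT * FROM people")
sqlDF.show()
{% endhighlight %}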
    -{% highlight java %} -SQLContext sqlContext = ... // An existing SQLContext -DataFrame df = sqlContext.sql("SELECT * FROM table") -{% endhighlight %} +The `sql` function on a `SparkSession` enables applications to run SQL queries programmatically and returns the result as a `Dataset`. + +{% include_example run_sql java/org/apache/spark/examples/sql/JavaSparkSQLExample.java %}
    -{% highlight python %} -from pyspark.sql import SQLContext -sqlContext = SQLContext(sc) -df = sqlContext.sql("SELECT * FROM table") -{% endhighlight %} +The `sql` function on a `SparkSession` enables applications to run SQL queries programmatically and returns the result as a `DataFrame`. + +{% include_example run_sql python/sql/basic.py %}
    -{% highlight r %} -sqlContext <- sparkRSQL.init(sc) -df <- sql(sqlContext, "SELECT * FROM table") -{% endhighlight %} -
    +The `sql` function enables applications to run SQL queries programmatically and returns the result as a `SparkDataFrame`. + +{% include_example run_sql r/RSparkSQLExample.R %}
    +
    ## Creating Datasets -Datasets are similar to RDDs, however, instead of using Java Serialization or Kryo they use +Datasets are similar to RDDs, however, instead of using Java serialization or Kryo they use a specialized [Encoder](api/scala/index.html#org.apache.spark.sql.Encoder) to serialize the objects for processing or transmitting over the network. While both encoders and standard serialization are responsible for turning an object into bytes, encoders are code generated dynamically and use a format @@ -465,44 +231,24 @@ the bytes back into an object.
    - -{% highlight scala %} -// Encoders for most common types are automatically provided by importing sqlContext.implicits._ -val ds = Seq(1, 2, 3).toDS() -ds.map(_ + 1).collect() // Returns: Array(2, 3, 4) - -// Encoders are also created for case classes. -case class Person(name: String, age: Long) -val ds = Seq(Person("Andy", 32)).toDS() - -// DataFrames can be converted to a Dataset by providing a class. Mapping will be done by name. -val path = "examples/src/main/resources/people.json" -val people = sqlContext.read.json(path).as[Person] - -{% endhighlight %} - +{% include_example create_ds scala/org/apache/spark/examples/sql/SparkSQLExample.scala %}
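A sketch of Dataset creation with encoders, assuming `spark` and `import spark.implicits._` from the earlier examples; the `Person` case class is illustrative:

{% highlight scala %}
case class Person(name: String, age: Long)

// Encoders are created automatically for case classes
val caseClassDS = Seq(Person("Andy", 32)).toDS()

// Encoders for most common types are provided by importing spark.implicits._
val primitiveDS = Seq(1, 2, 3).toDS()
primitiveDS.map(_ + 1).collect()  // Returns: Array(2, 3, 4)

// DataFrames can be converted to a Dataset by providing a class; mapping is done by name
val path = "examples/src/main/resources/people.json"
val peopleDS = spark.read.json(path).as[Person]
{% endhighlight %}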
    - -{% highlight java %} -JavaSparkContext sc = ...; // An existing JavaSparkContext. -SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc); -{% endhighlight %} - +{% include_example create_ds java/org/apache/spark/examples/sql/JavaSparkSQLExample.java %}
    ## Interoperating with RDDs -Spark SQL supports two different methods for converting existing RDDs into DataFrames. The first +Spark SQL supports two different methods for converting existing RDDs into Datasets. The first method uses reflection to infer the schema of an RDD that contains specific types of objects. This reflection based approach leads to more concise code and works well when you already know the schema while writing your Spark application. -The second method for creating DataFrames is through a programmatic interface that allows you to +The second method for creating Datasets is through a programmatic interface that allows you to construct a schema and then apply it to an existing RDD. While this method is more verbose, it allows -you to construct DataFrames when the columns and their types are not known until runtime. +you to construct Datasets when the columns and their types are not known until runtime. ### Inferring the Schema Using Reflection
    @@ -513,147 +259,31 @@ The Scala interface for Spark SQL supports automatically converting an RDD conta to a DataFrame. The case class defines the schema of the table. The names of the arguments to the case class are read using reflection and become the names of the columns. Case classes can also be nested or contain complex -types such as Sequences or Arrays. This RDD can be implicitly converted to a DataFrame and then be +types such as `Seq`s or `Array`s. This RDD can be implicitly converted to a DataFrame and then be registered as a table. Tables can be used in subsequent SQL statements. -{% highlight scala %} -// sc is an existing SparkContext. -val sqlContext = new org.apache.spark.sql.SQLContext(sc) -// this is used to implicitly convert an RDD to a DataFrame. -import sqlContext.implicits._ - -// Define the schema using a case class. -// Note: Case classes in Scala 2.10 can support only up to 22 fields. To work around this limit, -// you can use custom classes that implement the Product interface. -case class Person(name: String, age: Int) - -// Create an RDD of Person objects and register it as a table. -val people = sc.textFile("examples/src/main/resources/people.txt").map(_.split(",")).map(p => Person(p(0), p(1).trim.toInt)).toDF() -people.registerTempTable("people") - -// SQL statements can be run by using the sql methods provided by sqlContext. -val teenagers = sqlContext.sql("SELECT name, age FROM people WHERE age >= 13 AND age <= 19") - -// The results of SQL queries are DataFrames and support all the normal RDD operations. -// The columns of a row in the result can be accessed by field index: -teenagers.map(t => "Name: " + t(0)).collect().foreach(println) - -// or by field name: -teenagers.map(t => "Name: " + t.getAs[String]("name")).collect().foreach(println) - -// row.getValuesMap[T] retrieves multiple columns at once into a Map[String, T] -teenagers.map(_.getValuesMap[Any](List("name", "age"))).collect().foreach(println) -// Map("name" -> "Justin", "age" -> 19) -{% endhighlight %} - +{% include_example schema_inferring scala/org/apache/spark/examples/sql/SparkSQLExample.scala %}
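A sketch of the reflection-based approach, assuming `spark` and `import spark.implicits._` are in scope; the `Person` case class and field names are illustrative:

{% highlight scala %}
case class Person(name: String, age: Long)

// Create an RDD of Person objects from a text file and convert it to a DataFrame
val peopleDF = spark.sparkContext
  .textFile("examples/src/main/resources/people.txt")
  .map(_.split(","))
  .map(attributes => Person(attributes(0), attributes(1).trim.toInt))
  .toDF()

// Register the DataFrame as a temporary view
peopleDF.createOrReplaceTempView("people")

// SQL statements can be run by using the sql method provided by the SparkSession
val teenagersDF = spark.sql("SELECT name, age FROM people WHERE age BETWEEN 13 AND 19")

// The columns of a row can be accessed by field index or by field name
teenagersDF.map(teenager => "Name: " + teenager(0)).show()
teenagersDF.map(teenager => "Name: " + teenager.getAs[String]("name")).show()
{% endhighlight %}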
    -Spark SQL supports automatically converting an RDD of [JavaBeans](http://stackoverflow.com/questions/3295496/what-is-a-javabean-exactly) -into a DataFrame. The BeanInfo, obtained using reflection, defines the schema of the table. -Currently, Spark SQL does not support JavaBeans that contain -nested or contain complex types such as Lists or Arrays. You can create a JavaBean by creating a -class that implements Serializable and has getters and setters for all of its fields. - -{% highlight java %} - -public static class Person implements Serializable { - private String name; - private int age; - - public String getName() { - return name; - } - - public void setName(String name) { - this.name = name; - } - - public int getAge() { - return age; - } - - public void setAge(int age) { - this.age = age; - } -} - -{% endhighlight %} - - -A schema can be applied to an existing RDD by calling `createDataFrame` and providing the Class object -for the JavaBean. - -{% highlight java %} -// sc is an existing JavaSparkContext. -SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc); - -// Load a text file and convert each line to a JavaBean. -JavaRDD people = sc.textFile("examples/src/main/resources/people.txt").map( - new Function() { - public Person call(String line) throws Exception { - String[] parts = line.split(","); - - Person person = new Person(); - person.setName(parts[0]); - person.setAge(Integer.parseInt(parts[1].trim())); - - return person; - } - }); - -// Apply a schema to an RDD of JavaBeans and register it as a table. -DataFrame schemaPeople = sqlContext.createDataFrame(people, Person.class); -schemaPeople.registerTempTable("people"); - -// SQL can be run over RDDs that have been registered as tables. -DataFrame teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") - -// The results of SQL queries are DataFrames and support all the normal RDD operations. -// The columns of a row in the result can be accessed by ordinal. -List teenagerNames = teenagers.javaRDD().map(new Function() { - public String call(Row row) { - return "Name: " + row.getString(0); - } -}).collect(); - -{% endhighlight %} +Spark SQL supports automatically converting an RDD of +[JavaBeans](http://stackoverflow.com/questions/3295496/what-is-a-javabean-exactly) into a DataFrame. +The `BeanInfo`, obtained using reflection, defines the schema of the table. Currently, Spark SQL +does not support JavaBeans that contain `Map` field(s). Nested JavaBeans and `List` or `Array` +fields are supported though. You can create a JavaBean by creating a class that implements +Serializable and has getters and setters for all of its fields. +{% include_example schema_inferring java/org/apache/spark/examples/sql/JavaSparkSQLExample.java %}
Spark SQL can convert an RDD of Row objects to a DataFrame, inferring the datatypes. Rows are constructed by passing a list of key/value pairs as kwargs to the Row class. The keys of this list define the column names of the table, -and the types are inferred by looking at the first row. Since we currently only look at the first -row, it is important that there is no missing data in the first row of the RDD. In future versions we -plan to more completely infer the schema by looking at more data, similar to the inference that is -performed on JSON files. - -{% highlight python %} -# sc is an existing SparkContext. -from pyspark.sql import SQLContext, Row -sqlContext = SQLContext(sc) - -# Load a text file and convert each line to a Row. -lines = sc.textFile("examples/src/main/resources/people.txt") -parts = lines.map(lambda l: l.split(",")) -people = parts.map(lambda p: Row(name=p[0], age=int(p[1]))) - -# Infer the schema, and register the DataFrame as a table. -schemaPeople = sqlContext.createDataFrame(people) -schemaPeople.registerTempTable("people") - -# SQL can be run over DataFrames that have been registered as a table. -teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") - -# The results of SQL queries are RDDs and support all the normal RDD operations. -teenNames = teenagers.map(lambda p: "Name: " + p.name) -for teenName in teenNames.collect(): - print(teenName) -{% endhighlight %} +and the types are inferred by sampling the whole dataset, similar to the inference that is performed on JSON files. +{% include_example schema_inferring python/sql/basic.py %}
    @@ -673,48 +303,11 @@ a `DataFrame` can be created programmatically with three steps. 2. Create the schema represented by a `StructType` matching the structure of `Row`s in the RDD created in Step 1. 3. Apply the schema to the RDD of `Row`s via `createDataFrame` method provided -by `SQLContext`. +by `SparkSession`. For example: -{% highlight scala %} -// sc is an existing SparkContext. -val sqlContext = new org.apache.spark.sql.SQLContext(sc) - -// Create an RDD -val people = sc.textFile("examples/src/main/resources/people.txt") - -// The schema is encoded in a string -val schemaString = "name age" - -// Import Row. -import org.apache.spark.sql.Row; - -// Import Spark SQL data types -import org.apache.spark.sql.types.{StructType,StructField,StringType}; - -// Generate the schema based on the string of schema -val schema = - StructType( - schemaString.split(" ").map(fieldName => StructField(fieldName, StringType, true))) - -// Convert records of the RDD (people) to Rows. -val rowRDD = people.map(_.split(",")).map(p => Row(p(0), p(1).trim)) - -// Apply the schema to the RDD. -val peopleDataFrame = sqlContext.createDataFrame(rowRDD, schema) - -// Register the DataFrames as a table. -peopleDataFrame.registerTempTable("people") - -// SQL statements can be run by using the sql methods provided by sqlContext. -val results = sqlContext.sql("SELECT name FROM people") - -// The results of SQL queries are DataFrames and support all the normal RDD operations. -// The columns of a row in the result can be accessed by field index or by field name. -results.map(t => "Name: " + t(0)).collect().foreach(println) -{% endhighlight %} - +{% include_example programmatic_schema scala/org/apache/spark/examples/sql/SparkSQLExample.scala %}
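A sketch of the three steps in Scala; the schema string and field types are illustrative:

{% highlight scala %}
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{StringType, StructField, StructType}

// 1. Create an RDD of Rows from the original RDD
val peopleRDD = spark.sparkContext.textFile("examples/src/main/resources/people.txt")
val rowRDD = peopleRDD
  .map(_.split(","))
  .map(attributes => Row(attributes(0), attributes(1).trim))

// 2. Create the schema represented by a StructType
val schemaString = "name age"
val fields = schemaString.split(" ")
  .map(fieldName => StructField(fieldName, StringType, nullable = true))
val schema = StructType(fields)

// 3. Apply the schema to the RDD of Rows via createDataFrame
val peopleDF = spark.createDataFrame(rowRDD, schema)

peopleDF.createOrReplaceTempView("people")
spark.sql("SELECT name FROM people").show()
{% endhighlight %}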
    @@ -722,71 +315,17 @@ results.map(t => "Name: " + t(0)).collect().foreach(println) When JavaBean classes cannot be defined ahead of time (for example, the structure of records is encoded in a string, or a text dataset will be parsed and fields will be projected differently for different users), -a `DataFrame` can be created programmatically with three steps. +a `Dataset` can be created programmatically with three steps. 1. Create an RDD of `Row`s from the original RDD; 2. Create the schema represented by a `StructType` matching the structure of `Row`s in the RDD created in Step 1. 3. Apply the schema to the RDD of `Row`s via `createDataFrame` method provided -by `SQLContext`. +by `SparkSession`. For example: -{% highlight java %} -import org.apache.spark.api.java.function.Function; -// Import factory methods provided by DataTypes. -import org.apache.spark.sql.types.DataTypes; -// Import StructType and StructField -import org.apache.spark.sql.types.StructType; -import org.apache.spark.sql.types.StructField; -// Import Row. -import org.apache.spark.sql.Row; -// Import RowFactory. -import org.apache.spark.sql.RowFactory; - -// sc is an existing JavaSparkContext. -SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc); - -// Load a text file and convert each line to a JavaBean. -JavaRDD people = sc.textFile("examples/src/main/resources/people.txt"); - -// The schema is encoded in a string -String schemaString = "name age"; - -// Generate the schema based on the string of schema -List fields = new ArrayList<>(); -for (String fieldName: schemaString.split(" ")) { - fields.add(DataTypes.createStructField(fieldName, DataTypes.StringType, true)); -} -StructType schema = DataTypes.createStructType(fields); - -// Convert records of the RDD (people) to Rows. -JavaRDD rowRDD = people.map( - new Function() { - public Row call(String record) throws Exception { - String[] fields = record.split(","); - return RowFactory.create(fields[0], fields[1].trim()); - } - }); - -// Apply the schema to the RDD. -DataFrame peopleDataFrame = sqlContext.createDataFrame(rowRDD, schema); - -// Register the DataFrame as a table. -peopleDataFrame.registerTempTable("people"); - -// SQL can be run over RDDs that have been registered as tables. -DataFrame results = sqlContext.sql("SELECT name FROM people"); - -// The results of SQL queries are DataFrames and support all the normal RDD operations. -// The columns of a row in the result can be accessed by ordinal. -List names = results.javaRDD().map(new Function() { - public String call(Row row) { - return "Name: " + row.getString(0); - } -}).collect(); - -{% endhighlight %} +{% include_example programmatic_schema java/org/apache/spark/examples/sql/JavaSparkSQLExample.java %}
    @@ -799,43 +338,11 @@ a `DataFrame` can be created programmatically with three steps. 1. Create an RDD of tuples or lists from the original RDD; 2. Create the schema represented by a `StructType` matching the structure of tuples or lists in the RDD created in the step 1. -3. Apply the schema to the RDD via `createDataFrame` method provided by `SQLContext`. +3. Apply the schema to the RDD via `createDataFrame` method provided by `SparkSession`. For example: -{% highlight python %} -# Import SQLContext and data types -from pyspark.sql import SQLContext -from pyspark.sql.types import * - -# sc is an existing SparkContext. -sqlContext = SQLContext(sc) - -# Load a text file and convert each line to a tuple. -lines = sc.textFile("examples/src/main/resources/people.txt") -parts = lines.map(lambda l: l.split(",")) -people = parts.map(lambda p: (p[0], p[1].strip())) - -# The schema is encoded in a string. -schemaString = "name age" - -fields = [StructField(field_name, StringType(), True) for field_name in schemaString.split()] -schema = StructType(fields) - -# Apply the schema to the RDD. -schemaPeople = sqlContext.createDataFrame(people, schema) - -# Register the DataFrame as a table. -schemaPeople.registerTempTable("people") - -# SQL can be run over DataFrames that have been registered as a table. -results = sqlContext.sql("SELECT name FROM people") - -# The results of SQL queries are RDDs and support all the normal RDD operations. -names = results.map(lambda p: "Name: " + p.name) -for name in names.collect(): - print(name) -{% endhighlight %} +{% include_example programmatic_schema python/sql/basic.py %}
    @@ -843,9 +350,9 @@ for name in names.collect(): # Data Sources -Spark SQL supports operating on a variety of data sources through the `DataFrame` interface. -A DataFrame can be operated on as normal RDDs and can also be registered as a temporary table. -Registering a DataFrame as a table allows you to run SQL queries over its data. This section +Spark SQL supports operating on a variety of data sources through the DataFrame interface. +A DataFrame can be operated on using relational transformations and can also be used to create a temporary view. +Registering a DataFrame as a temporary view allows you to run SQL queries over its data. This section describes the general methods for loading and saving data using the Spark Data Sources and then goes into specific options that are available for the built-in data sources. @@ -856,42 +363,21 @@ In the simplest form, the default data source (`parquet` unless otherwise config
    - -{% highlight scala %} -val df = sqlContext.read.load("examples/src/main/resources/users.parquet") -df.select("name", "favorite_color").write.save("namesAndFavColors.parquet") -{% endhighlight %} - +{% include_example generic_load_save_functions scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala %}
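A minimal sketch of the default (Parquet) load/save path; the output file name is a placeholder:

{% highlight scala %}
val usersDF = spark.read.load("examples/src/main/resources/users.parquet")
usersDF.select("name", "favorite_color").write.save("namesAndFavColors.parquet")
{% endhighlight %}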
    - -{% highlight java %} - -DataFrame df = sqlContext.read().load("examples/src/main/resources/users.parquet"); -df.select("name", "favorite_color").write().save("namesAndFavColors.parquet"); - -{% endhighlight %} - +{% include_example generic_load_save_functions java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java %}
    -{% highlight python %} - -df = sqlContext.read.load("examples/src/main/resources/users.parquet") -df.select("name", "favorite_color").write.save("namesAndFavColors.parquet") - -{% endhighlight %} - +{% include_example generic_load_save_functions python/sql/datasource.py %}
    -{% highlight r %} -df <- read.df(sqlContext, "examples/src/main/resources/users.parquet") -write.df(select(df, "name", "favorite_color"), "namesAndFavColors.parquet") -{% endhighlight %} +{% include_example generic_load_save_functions r/RSparkSQLExample.R %}
    @@ -901,49 +387,24 @@ write.df(select(df, "name", "favorite_color"), "namesAndFavColors.parquet") You can also manually specify the data source that will be used along with any extra options that you would like to pass to the data source. Data sources are specified by their fully qualified name (i.e., `org.apache.spark.sql.parquet`), but for built-in sources you can also use their short -names (`json`, `parquet`, `jdbc`). DataFrames of any type can be converted into other types -using this syntax. +names (`json`, `parquet`, `jdbc`, `orc`, `libsvm`, `csv`, `text`). DataFrames loaded from any data +source type can be converted into other types using this syntax.
    - -{% highlight scala %} -val df = sqlContext.read.format("json").load("examples/src/main/resources/people.json") -df.select("name", "age").write.format("parquet").save("namesAndAges.parquet") -{% endhighlight %} - +{% include_example manual_load_options scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala %}
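A minimal sketch of specifying the source format explicitly; the output file name is a placeholder:

{% highlight scala %}
val peopleDF = spark.read.format("json").load("examples/src/main/resources/people.json")
peopleDF.select("name", "age").write.format("parquet").save("namesAndAges.parquet")
{% endhighlight %}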
    - -{% highlight java %} - -DataFrame df = sqlContext.read().format("json").load("examples/src/main/resources/people.json"); -df.select("name", "age").write().format("parquet").save("namesAndAges.parquet"); - -{% endhighlight %} - +{% include_example manual_load_options java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java %}
    - -{% highlight python %} - -df = sqlContext.read.load("examples/src/main/resources/people.json", format="json") -df.select("name", "age").write.save("namesAndAges.parquet", format="parquet") - -{% endhighlight %} - +{% include_example manual_load_options python/sql/datasource.py %}
    -
    - -{% highlight r %} - -df <- read.df(sqlContext, "examples/src/main/resources/people.json", "json") -write.df(select(df, "name", "age"), "namesAndAges.parquet", "parquet") - -{% endhighlight %} +
    +{% include_example manual_load_options r/RSparkSQLExample.R %}
    @@ -954,33 +415,19 @@ file directly with SQL.
    - -{% highlight scala %} -val df = sqlContext.sql("SELECT * FROM parquet.`examples/src/main/resources/users.parquet`") -{% endhighlight %} - +{% include_example direct_sql scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala %}
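A minimal sketch of querying a file directly with SQL:

{% highlight scala %}
val sqlDF = spark.sql("SELECT * FROM parquet.`examples/src/main/resources/users.parquet`")
sqlDF.show()
{% endhighlight %}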
    - -{% highlight java %} -DataFrame df = sqlContext.sql("SELECT * FROM parquet.`examples/src/main/resources/users.parquet`"); -{% endhighlight %} +{% include_example direct_sql java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java %}
    - -{% highlight python %} -df = sqlContext.sql("SELECT * FROM parquet.`examples/src/main/resources/users.parquet`") -{% endhighlight %} - +{% include_example direct_sql python/sql/datasource.py %}
    - -{% highlight r %} -df <- sql(sqlContext, "SELECT * FROM parquet.`examples/src/main/resources/users.parquet`") -{% endhighlight %} +{% include_example direct_sql r/RSparkSQLExample.R %}
@@ -989,7 +436,7 @@ df <- sql(sqlContext, "SELECT * FROM parquet.`examples/src/main/resources/users. Save operations can optionally take a `SaveMode`, that specifies how to handle existing data if present. It is important to realize that these save modes do not utilize any locking and are not -atomic. Additionally, when performing a `Overwrite`, the data will be deleted before writing out the +atomic. Additionally, when performing an `Overwrite`, the data will be deleted before writing out the new data. @@ -1032,12 +479,13 @@ new data. ### Saving to Persistent Tables -When working with a `HiveContext`, `DataFrames` can also be saved as persistent tables using the -`saveAsTable` command. Unlike the `registerTempTable` command, `saveAsTable` will materialize the -contents of the dataframe and create a pointer to the data in the HiveMetastore. Persistent tables -will still exist even after your Spark program has restarted, as long as you maintain your connection -to the same metastore. A DataFrame for a persistent table can be created by calling the `table` -method on a `SQLContext` with the name of the table. +`DataFrames` can also be saved as persistent tables into the Hive metastore using the `saveAsTable` +command. Note that an existing Hive deployment is not necessary to use this feature. Spark will create a +default local Hive metastore (using Derby) for you. Unlike the `createOrReplaceTempView` command, +`saveAsTable` will materialize the contents of the DataFrame and create a pointer to the data in the +Hive metastore. Persistent tables will still exist even after your Spark program has restarted, as +long as you maintain your connection to the same metastore. A DataFrame for a persistent table can +be created by calling the `table` method on a `SparkSession` with the name of the table. By default `saveAsTable` will create a "managed table", meaning that the location of the data will be controlled by the metastore. Managed tables will also have their data deleted automatically @@ -1047,7 +495,7 @@ when a table is dropped. [Parquet](http://parquet.io) is a columnar format that is supported by many other data processing systems. Spark SQL provides support for both reading and writing Parquet files that automatically preserves the schema -of the original data. When writing Parquet files, all columns are automatically converted to be nullable for +of the original data. When writing Parquet files, all columns are automatically converted to be nullable for compatibility reasons. ### Loading Data Programmatically @@ -1057,102 +505,21 @@ Using the data from the above example:
    - -{% highlight scala %} -// sqlContext from the previous example is used in this example. -// This is used to implicitly convert an RDD to a DataFrame. -import sqlContext.implicits._ - -val people: RDD[Person] = ... // An RDD of case class objects, from the previous example. - -// The RDD is implicitly converted to a DataFrame by implicits, allowing it to be stored using Parquet. -people.write.parquet("people.parquet") - -// Read in the parquet file created above. Parquet files are self-describing so the schema is preserved. -// The result of loading a Parquet file is also a DataFrame. -val parquetFile = sqlContext.read.parquet("people.parquet") - -//Parquet files can also be registered as tables and then used in SQL statements. -parquetFile.registerTempTable("parquetFile") -val teenagers = sqlContext.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19") -teenagers.map(t => "Name: " + t(0)).collect().foreach(println) -{% endhighlight %} - +{% include_example basic_parquet_example scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala %}
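A sketch, assuming the `peopleDF` DataFrame from the earlier examples; `people.parquet` is a placeholder output path:

{% highlight scala %}
// DataFrames can be saved as Parquet files, maintaining the schema information
peopleDF.write.parquet("people.parquet")

// Parquet files are self-describing, so the schema is preserved on read;
// the result of loading a Parquet file is also a DataFrame
val parquetFileDF = spark.read.parquet("people.parquet")

// Parquet files can also be used to create a temporary view for SQL statements
parquetFileDF.createOrReplaceTempView("parquetFile")
spark.sql("SELECT name FROM parquetFile WHERE age BETWEEN 13 AND 19").show()
{% endhighlight %}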
    - -{% highlight java %} -// sqlContext from the previous example is used in this example. - -DataFrame schemaPeople = ... // The DataFrame from the previous example. - -// DataFrames can be saved as Parquet files, maintaining the schema information. -schemaPeople.write().parquet("people.parquet"); - -// Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. -// The result of loading a parquet file is also a DataFrame. -DataFrame parquetFile = sqlContext.read().parquet("people.parquet"); - -// Parquet files can also be registered as tables and then used in SQL statements. -parquetFile.registerTempTable("parquetFile"); -DataFrame teenagers = sqlContext.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"); -List teenagerNames = teenagers.javaRDD().map(new Function() { - public String call(Row row) { - return "Name: " + row.getString(0); - } -}).collect(); -{% endhighlight %} - +{% include_example basic_parquet_example java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java %}
    -{% highlight python %} -# sqlContext from the previous example is used in this example. - -schemaPeople # The DataFrame from the previous example. - -# DataFrames can be saved as Parquet files, maintaining the schema information. -schemaPeople.write.parquet("people.parquet") - -# Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. -# The result of loading a parquet file is also a DataFrame. -parquetFile = sqlContext.read.parquet("people.parquet") - -# Parquet files can also be registered as tables and then used in SQL statements. -parquetFile.registerTempTable("parquetFile"); -teenagers = sqlContext.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19") -teenNames = teenagers.map(lambda p: "Name: " + p.name) -for teenName in teenNames.collect(): - print(teenName) -{% endhighlight %} - +{% include_example basic_parquet_example python/sql/datasource.py %}
    -{% highlight r %} -# sqlContext from the previous example is used in this example. - -schemaPeople # The DataFrame from the previous example. - -# DataFrames can be saved as Parquet files, maintaining the schema information. -write.parquet(schemaPeople, "people.parquet") - -# Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. -# The result of loading a parquet file is also a DataFrame. -parquetFile <- read.parquet(sqlContext, "people.parquet") - -# Parquet files can also be registered as tables and then used in SQL statements. -registerTempTable(parquetFile, "parquetFile") -teenagers <- sql(sqlContext, "SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19") -schema <- structType(structField("name", "string")) -teenNames <- dapply(df, function(p) { cbind(paste("Name:", p$name)) }, schema) -for (teenName in collect(teenNames)$name) { - cat(teenName, "\n") -} -{% endhighlight %} +{% include_example basic_parquet_example r/RSparkSQLExample.R %}
    @@ -1160,7 +527,7 @@ for (teenName in collect(teenNames)$name) { {% highlight sql %} -CREATE TEMPORARY TABLE parquetTable +CREATE TEMPORARY VIEW parquetTable USING org.apache.spark.sql.parquet OPTIONS ( path "examples/src/main/resources/people.parquet" @@ -1207,7 +574,7 @@ path {% endhighlight %} -By passing `path/to/table` to either `SQLContext.read.parquet` or `SQLContext.read.load`, Spark SQL +By passing `path/to/table` to either `SparkSession.read.parquet` or `SparkSession.read.load`, Spark SQL will automatically extract the partitioning information from the paths. Now the schema of the returned DataFrame becomes: @@ -1228,8 +595,8 @@ can be configured by `spark.sql.sources.partitionColumnTypeInference.enabled`, w `true`. When type inference is disabled, string type will be used for the partitioning columns. Starting from Spark 1.6.0, partition discovery only finds partitions under the given paths -by default. For the above example, if users pass `path/to/table/gender=male` to either -`SQLContext.read.parquet` or `SQLContext.read.load`, `gender` will not be considered as a +by default. For the above example, if users pass `path/to/table/gender=male` to either +`SparkSession.read.parquet` or `SparkSession.read.load`, `gender` will not be considered as a partitioning column. If users need to specify the base path that partition discovery should start with, they can set `basePath` in the data source options. For example, when `path/to/table/gender=male` is the path of the data and @@ -1252,91 +619,21 @@ turned it off by default starting from 1.5.0. You may enable it by
    +{% include_example schema_merging scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala %} +
    -{% highlight scala %} -// sqlContext from the previous example is used in this example. -// This is used to implicitly convert an RDD to a DataFrame. -import sqlContext.implicits._ - -// Create a simple DataFrame, stored into a partition directory -val df1 = sc.makeRDD(1 to 5).map(i => (i, i * 2)).toDF("single", "double") -df1.write.parquet("data/test_table/key=1") - -// Create another DataFrame in a new partition directory, -// adding a new column and dropping an existing column -val df2 = sc.makeRDD(6 to 10).map(i => (i, i * 3)).toDF("single", "triple") -df2.write.parquet("data/test_table/key=2") - -// Read the partitioned table -val df3 = sqlContext.read.option("mergeSchema", "true").parquet("data/test_table") -df3.printSchema() - -// The final schema consists of all 3 columns in the Parquet files together -// with the partitioning column appeared in the partition directory paths. -// root -// |-- single: int (nullable = true) -// |-- double: int (nullable = true) -// |-- triple: int (nullable = true) -// |-- key : int (nullable = true) -{% endhighlight %} - +
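A sketch of schema merging, assuming `spark` and `import spark.implicits._` from the earlier examples; the column names and the `data/test_table` path are placeholders:

{% highlight scala %}
// Create a simple DataFrame, stored in a partition directory
val squaresDF = spark.sparkContext.makeRDD(1 to 5).map(i => (i, i * i)).toDF("value", "square")
squaresDF.write.parquet("data/test_table/key=1")

// Create another DataFrame in a new partition directory,
// adding a new column and dropping an existing column
val cubesDF = spark.sparkContext.makeRDD(6 to 10).map(i => (i, i * i * i)).toDF("value", "cube")
cubesDF.write.parquet("data/test_table/key=2")

// Read the partitioned table with schema merging enabled
val mergedDF = spark.read.option("mergeSchema", "true").parquet("data/test_table")
mergedDF.printSchema()
// The final schema consists of all three columns in the Parquet files,
// plus the partitioning column that appears in the partition directory paths
{% endhighlight %}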
    +{% include_example schema_merging java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java %}
    -{% highlight python %} -# sqlContext from the previous example is used in this example. - -# Create a simple DataFrame, stored into a partition directory -df1 = sqlContext.createDataFrame(sc.parallelize(range(1, 6))\ - .map(lambda i: Row(single=i, double=i * 2))) -df1.write.parquet("data/test_table/key=1") - -# Create another DataFrame in a new partition directory, -# adding a new column and dropping an existing column -df2 = sqlContext.createDataFrame(sc.parallelize(range(6, 11)) - .map(lambda i: Row(single=i, triple=i * 3))) -df2.write.parquet("data/test_table/key=2") - -# Read the partitioned table -df3 = sqlContext.read.option("mergeSchema", "true").parquet("data/test_table") -df3.printSchema() - -# The final schema consists of all 3 columns in the Parquet files together -# with the partitioning column appeared in the partition directory paths. -# root -# |-- single: int (nullable = true) -# |-- double: int (nullable = true) -# |-- triple: int (nullable = true) -# |-- key : int (nullable = true) -{% endhighlight %} - +{% include_example schema_merging python/sql/datasource.py %}
    -{% highlight r %} -# sqlContext from the previous example is used in this example. - -# Create a simple DataFrame, stored into a partition directory -write.df(df1, "data/test_table/key=1", "parquet", "overwrite") - -# Create another DataFrame in a new partition directory, -# adding a new column and dropping an existing column -write.df(df2, "data/test_table/key=2", "parquet", "overwrite") - -# Read the partitioned table -df3 <- read.df(sqlContext, "data/test_table", "parquet", mergeSchema="true") -printSchema(df3) - -# The final schema consists of all 3 columns in the Parquet files together -# with the partitioning column appeared in the partition directory paths. -# root -# |-- single: int (nullable = true) -# |-- double: int (nullable = true) -# |-- triple: int (nullable = true) -# |-- key : int (nullable = true) -{% endhighlight %} +{% include_example schema_merging r/RSparkSQLExample.R %}
    @@ -1381,8 +678,8 @@ metadata.
    {% highlight scala %} -// sqlContext is an existing HiveContext -sqlContext.refreshTable("my_table") +// spark is an existing SparkSession +spark.catalog.refreshTable("my_table") {% endhighlight %}
    @@ -1390,8 +687,8 @@ sqlContext.refreshTable("my_table")
    {% highlight java %} -// sqlContext is an existing HiveContext -sqlContext.refreshTable("my_table") +// spark is an existing SparkSession +spark.catalog().refreshTable("my_table"); {% endhighlight %}
    @@ -1399,8 +696,8 @@ sqlContext.refreshTable("my_table")
    {% highlight python %} -# sqlContext is an existing HiveContext -sqlContext.refreshTable("my_table") +# spark is an existing SparkSession +spark.catalog.refreshTable("my_table") {% endhighlight %}
    @@ -1417,7 +714,7 @@ REFRESH TABLE my_table; ### Configuration -Configuration of Parquet can be done using the `setConf` method on `SQLContext` or by running +Configuration of Parquet can be done using the `setConf` method on `SparkSession` or by running `SET key=value` commands using SQL.
@@ -1448,7 +745,7 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext` @@ -1880,54 +1028,26 @@ the Data Sources API. The following options are supported:
    - -{% highlight scala %} -val jdbcDF = sqlContext.read.format("jdbc").options( - Map("url" -> "jdbc:postgresql:dbserver", - "dbtable" -> "schema.tablename")).load() -{% endhighlight %} - +{% include_example jdbc_dataset scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala %}
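A minimal sketch of a JDBC read; the URL, table name and credentials are placeholders:

{% highlight scala %}
val jdbcDF = spark.read
  .format("jdbc")
  .option("url", "jdbc:postgresql:dbserver")
  .option("dbtable", "schema.tablename")
  .option("user", "username")
  .option("password", "password")
  .load()
{% endhighlight %}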
    - -{% highlight java %} - -Map options = new HashMap<>(); -options.put("url", "jdbc:postgresql:dbserver"); -options.put("dbtable", "schema.tablename"); - -DataFrame jdbcDF = sqlContext.read().format("jdbc"). options(options).load(); -{% endhighlight %} - - +{% include_example jdbc_dataset java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java %}
    - -{% highlight python %} - -df = sqlContext.read.format('jdbc').options(url='jdbc:postgresql:dbserver', dbtable='schema.tablename').load() - -{% endhighlight %} - +{% include_example jdbc_dataset python/sql/datasource.py %}
    - -{% highlight r %} - -df <- loadDF(sqlContext, source="jdbc", url="jdbc:postgresql:dbserver", dbtable="schema.tablename") - -{% endhighlight %} - +{% include_example jdbc_dataset r/RSparkSQLExample.R %}
{% highlight sql %} -CREATE TEMPORARY TABLE jdbcTable +CREATE TEMPORARY VIEW jdbcTable USING org.apache.spark.sql.jdbc OPTIONS ( url "jdbc:postgresql:dbserver", @@ -1952,11 +1072,11 @@ turning on some experimental options. ## Caching Data In Memory -Spark SQL can cache tables using an in-memory columnar format by calling `sqlContext.cacheTable("tableName")` or `dataFrame.cache()`. +Spark SQL can cache tables using an in-memory columnar format by calling `spark.catalog.cacheTable("tableName")` or `dataFrame.cache()`. Then Spark SQL will scan only required columns and will automatically tune compression to minimize -memory usage and GC pressure. You can call `sqlContext.uncacheTable("tableName")` to remove the table from memory. +memory usage and GC pressure. You can call `spark.catalog.uncacheTable("tableName")` to remove the table from memory. -Configuration of in-memory caching can be done using the `setConf` method on `SQLContext` or by running +Configuration of in-memory caching can be done using the `setConf` method on `SparkSession` or by running `SET key=value` commands using SQL.
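A minimal sketch of caching and uncaching a table through the catalog API, assuming a temporary view named `people` exists (the view name is a placeholder):

{% highlight scala %}
// Cache the table in the in-memory columnar format
spark.catalog.cacheTable("people")

// Subsequent queries scan only the required columns of the cached data
spark.sql("SELECT name FROM people WHERE age > 21").show()

// Remove the table from memory when it is no longer needed
spark.catalog.uncacheTable("people")
{% endhighlight %}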
spark.sql.parquet.compression.codec -gzip +snappy Sets the compression codec used when writing Parquet files. Acceptable values include: uncompressed, snappy, gzip, lzo. @@ -1469,7 +766,7 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext`
    spark.sql.parquet.mergeSchemafalsefalse

    When true, the Parquet data source merges schemas collected from all data files, otherwise the @@ -1483,157 +780,58 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext`

    -Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame. -This conversion can be done using `SQLContext.read.json()` on either an RDD of String, +Spark SQL can automatically infer the schema of a JSON dataset and load it as a `Dataset[Row]`. +This conversion can be done using `SparkSession.read.json()` on either an RDD of String, or a JSON file. Note that the file that is offered as _a json file_ is not a typical JSON file. Each line must contain a separate, self-contained valid JSON object. As a consequence, a regular multi-line JSON file will most often fail. -{% highlight scala %} -// sc is an existing SparkContext. -val sqlContext = new org.apache.spark.sql.SQLContext(sc) - -// A JSON dataset is pointed to by path. -// The path can be either a single text file or a directory storing text files. -val path = "examples/src/main/resources/people.json" -val people = sqlContext.read.json(path) - -// The inferred schema can be visualized using the printSchema() method. -people.printSchema() -// root -// |-- age: long (nullable = true) -// |-- name: string (nullable = true) - -// Register this DataFrame as a table. -people.registerTempTable("people") - -// SQL statements can be run by using the sql methods provided by sqlContext. -val teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") - -// Alternatively, a DataFrame can be created for a JSON dataset represented by -// an RDD[String] storing one JSON object per string. -val anotherPeopleRDD = sc.parallelize( - """{"name":"Yin","address":{"city":"Columbus","state":"Ohio"}}""" :: Nil) -val anotherPeople = sqlContext.read.json(anotherPeopleRDD) -{% endhighlight %} - +{% include_example json_dataset scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala %}
    -Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame. -This conversion can be done using `SQLContext.read().json()` on either an RDD of String, +Spark SQL can automatically infer the schema of a JSON dataset and load it as a `Dataset`. +This conversion can be done using `SparkSession.read().json()` on either an RDD of String, or a JSON file. Note that the file that is offered as _a json file_ is not a typical JSON file. Each line must contain a separate, self-contained valid JSON object. As a consequence, a regular multi-line JSON file will most often fail. -{% highlight java %} -// sc is an existing JavaSparkContext. -SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc); - -// A JSON dataset is pointed to by path. -// The path can be either a single text file or a directory storing text files. -DataFrame people = sqlContext.read().json("examples/src/main/resources/people.json"); - -// The inferred schema can be visualized using the printSchema() method. -people.printSchema(); -// root -// |-- age: long (nullable = true) -// |-- name: string (nullable = true) - -// Register this DataFrame as a table. -people.registerTempTable("people"); - -// SQL statements can be run by using the sql methods provided by sqlContext. -DataFrame teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); - -// Alternatively, a DataFrame can be created for a JSON dataset represented by -// an RDD[String] storing one JSON object per string. -List jsonData = Arrays.asList( - "{\"name\":\"Yin\",\"address\":{\"city\":\"Columbus\",\"state\":\"Ohio\"}}"); -JavaRDD anotherPeopleRDD = sc.parallelize(jsonData); -DataFrame anotherPeople = sqlContext.read().json(anotherPeopleRDD); -{% endhighlight %} +{% include_example json_dataset java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java %}
    Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame. -This conversion can be done using `SQLContext.read.json` on a JSON file. +This conversion can be done using `SparkSession.read.json` on a JSON file. Note that the file that is offered as _a json file_ is not a typical JSON file. Each line must contain a separate, self-contained valid JSON object. As a consequence, a regular multi-line JSON file will most often fail. -{% highlight python %} -# sc is an existing SparkContext. -from pyspark.sql import SQLContext -sqlContext = SQLContext(sc) - -# A JSON dataset is pointed to by path. -# The path can be either a single text file or a directory storing text files. -people = sqlContext.read.json("examples/src/main/resources/people.json") - -# The inferred schema can be visualized using the printSchema() method. -people.printSchema() -# root -# |-- age: long (nullable = true) -# |-- name: string (nullable = true) - -# Register this DataFrame as a table. -people.registerTempTable("people") - -# SQL statements can be run by using the sql methods provided by `sqlContext`. -teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") - -# Alternatively, a DataFrame can be created for a JSON dataset represented by -# an RDD[String] storing one JSON object per string. -anotherPeopleRDD = sc.parallelize([ - '{"name":"Yin","address":{"city":"Columbus","state":"Ohio"}}']) -anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD) -{% endhighlight %} +{% include_example json_dataset python/sql/datasource.py %}
    Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame. using -the `jsonFile` function, which loads data from a directory of JSON files where each line of the +the `read.json()` function, which loads data from a directory of JSON files where each line of the files is a JSON object. Note that the file that is offered as _a json file_ is not a typical JSON file. Each line must contain a separate, self-contained valid JSON object. As a consequence, a regular multi-line JSON file will most often fail. -{% highlight r %} -# sc is an existing SparkContext. -sqlContext <- sparkRSQL.init(sc) +{% include_example json_dataset r/RSparkSQLExample.R %} -# A JSON dataset is pointed to by path. -# The path can be either a single text file or a directory storing text files. -path <- "examples/src/main/resources/people.json" -# Create a DataFrame from the file(s) pointed to by path -people <- read.json(sqlContext, path) - -# The inferred schema can be visualized using the printSchema() method. -printSchema(people) -# root -# |-- age: long (nullable = true) -# |-- name: string (nullable = true) - -# Register this DataFrame as a table. -registerTempTable(people, "people") - -# SQL statements can be run by using the sql methods provided by `sqlContext`. -teenagers <- sql(sqlContext, "SELECT name FROM people WHERE age >= 13 AND age <= 19") -{% endhighlight %}
{% highlight sql %} -CREATE TEMPORARY TABLE jsonTable +CREATE TEMPORARY VIEW jsonTable USING org.apache.spark.sql.json OPTIONS ( path "examples/src/main/resources/people.json" @@ -1650,95 +848,45 @@ SELECT * FROM jsonTable ## Hive Tables Spark SQL also supports reading and writing data stored in [Apache Hive](http://hive.apache.org/). -However, since Hive has a large number of dependencies, it is not included in the default Spark assembly. -Hive support is enabled by adding the `-Phive` and `-Phive-thriftserver` flags to Spark's build. -This command builds a new assembly directory that includes Hive. Note that this Hive assembly directory must also be present -on all of the worker nodes, as they will need access to the Hive serialization and deserialization libraries -(SerDes) in order to access data stored in Hive. +However, since Hive has a large number of dependencies, these dependencies are not included in the +default Spark distribution. If Hive dependencies can be found on the classpath, Spark will load them +automatically. Note that these Hive dependencies must also be present on all of the worker nodes, as +they will need access to the Hive serialization and deserialization libraries (SerDes) in order to +access data stored in Hive. Configuration of Hive is done by placing your `hive-site.xml`, `core-site.xml` (for security configuration), -`hdfs-site.xml` (for HDFS configuration) file in `conf/`. +and `hdfs-site.xml` (for HDFS configuration) files in `conf/`. + +When working with Hive, one must instantiate `SparkSession` with Hive support, including +connectivity to a persistent Hive metastore, support for Hive serdes, and Hive user-defined functions. +Users who do not have an existing Hive deployment can still enable Hive support. When not configured +by the `hive-site.xml`, the context automatically creates `metastore_db` in the current directory and +creates a directory configured by `spark.sql.warehouse.dir`, which defaults to the directory +`spark-warehouse` in the current directory where the Spark application is started. Note that +the `hive.metastore.warehouse.dir` property in `hive-site.xml` is deprecated since Spark 2.0.0. +Instead, use `spark.sql.warehouse.dir` to specify the default location of databases in the warehouse. +You may need to grant write privilege to the user who starts the Spark application. A minimal builder sketch follows the per-language examples below.
    - -When working with Hive one must construct a `HiveContext`, which inherits from `SQLContext`, and -adds support for finding tables in the MetaStore and writing queries using HiveQL. Users who do -not have an existing Hive deployment can still create a `HiveContext`. When not configured by the -hive-site.xml, the context automatically creates `metastore_db` in the current directory and -creates `warehouse` directory indicated by HiveConf, which defaults to `/user/hive/warehouse`. -Note that you may need to grant write privilege on `/user/hive/warehouse` to the user who starts -the spark application. - -{% highlight scala %} -// sc is an existing SparkContext. -val sqlContext = new org.apache.spark.sql.hive.HiveContext(sc) - -sqlContext.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") -sqlContext.sql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") - -// Queries are expressed in HiveQL -sqlContext.sql("FROM src SELECT key, value").collect().foreach(println) -{% endhighlight %} - +{% include_example spark_hive scala/org/apache/spark/examples/sql/hive/SparkHiveExample.scala %}
    - -When working with Hive one must construct a `HiveContext`, which inherits from `SQLContext`, and -adds support for finding tables in the MetaStore and writing queries using HiveQL. In addition to -the `sql` method a `HiveContext` also provides an `hql` method, which allows queries to be -expressed in HiveQL. - -{% highlight java %} -// sc is an existing JavaSparkContext. -HiveContext sqlContext = new org.apache.spark.sql.hive.HiveContext(sc.sc); - -sqlContext.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)"); -sqlContext.sql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src"); - -// Queries are expressed in HiveQL. -Row[] results = sqlContext.sql("FROM src SELECT key, value").collect(); - -{% endhighlight %} - +{% include_example spark_hive java/org/apache/spark/examples/sql/hive/JavaSparkHiveExample.java %}
    - -When working with Hive one must construct a `HiveContext`, which inherits from `SQLContext`, and -adds support for finding tables in the MetaStore and writing queries using HiveQL. -{% highlight python %} -# sc is an existing SparkContext. -from pyspark.sql import HiveContext -sqlContext = HiveContext(sc) - -sqlContext.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") -sqlContext.sql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") - -# Queries can be expressed in HiveQL. -results = sqlContext.sql("FROM src SELECT key, value").collect() - -{% endhighlight %} - +{% include_example spark_hive python/sql/hive.py %}
    -When working with Hive one must construct a `HiveContext`, which inherits from `SQLContext`, and +When working with Hive one must instantiate `SparkSession` with Hive support. This adds support for finding tables in the MetaStore and writing queries using HiveQL. -{% highlight r %} -# sc is an existing SparkContext. -sqlContext <- sparkRHive.init(sc) - -sql(sqlContext, "CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") -sql(sqlContext, "LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") -# Queries can be expressed in HiveQL. -results <- collect(sql(sqlContext, "FROM src SELECT key, value")) - -{% endhighlight %} +{% include_example spark_hive r/RSparkSQLExample.R %}
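For reference, a minimal Scala sketch of building such a Hive-enabled session, assuming the Hive dependencies described above are on the classpath (the application name and warehouse path below are placeholders):

{% highlight scala %}
import org.apache.spark.sql.SparkSession

// A sketch: build a SparkSession with Hive support; spark.sql.warehouse.dir points
// at the warehouse location (a hypothetical local path is used for illustration)
val spark = SparkSession
  .builder()
  .appName("Spark Hive Example")
  .config("spark.sql.warehouse.dir", "/tmp/spark-warehouse")
  .enableHiveSupport()
  .getOrCreate()

// HiveQL statements can then be issued through the usual sql() entry point
spark.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
{% endhighlight %}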
    @@ -1856,7 +1004,7 @@ the Data Sources API. The following options are supported: The class name of the JDBC driver to use to connect to this URL.
    partitionColumn, lowerBound, upperBound, numPartitions @@ -1868,9 +1016,9 @@ the Data Sources API. The following options are supported: partitioned and returned.
-fetchSize +fetchsize The JDBC fetch size, which determines how many rows to fetch per round trip. This can help performance on JDBC drivers which default to low fetch size (e.g. Oracle with 10 rows).
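To show how the options in this table combine in practice, here is a hedged Scala sketch of a partitioned JDBC read (the connection URL, table, and column names are placeholders, and the JDBC driver JAR is assumed to be on the classpath):

{% highlight scala %}
// A sketch of a partitioned JDBC read: Spark issues numPartitions parallel queries,
// splitting the [lowerBound, upperBound] range of partitionColumn across them.
val jdbcDF = spark.read
  .format("jdbc")
  .option("url", "jdbc:postgresql:dbserver")   // placeholder connection URL
  .option("dbtable", "schema.tablename")       // placeholder table name
  .option("partitionColumn", "id")             // must be a numeric column
  .option("lowerBound", "1")
  .option("upperBound", "100000")
  .option("numPartitions", "10")
  .option("fetchsize", "1000")                 // rows fetched per round trip
  .load()
{% endhighlight %}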
    @@ -1987,6 +1107,32 @@ that these options will be deprecated in future release as more optimizations ar
@@ -1997,14 +1143,6 @@ that these options will be deprecated in future release as more optimizations ar ANALYZE TABLE <tableName> COMPUTE STATISTICS noscan has been run. @@ -2095,6 +1233,23 @@ options. # Migration Guide +## Upgrading From Spark SQL 1.6 to 2.0 + + - `SparkSession` is now the entry point of Spark that replaces the old `SQLContext` and + `HiveContext`. Note that the old SQLContext and HiveContext are kept for backward compatibility. A new `catalog` interface is accessible from `SparkSession` - the existing APIs for accessing databases and tables, such as `listTables`, `createExternalTable`, `dropTempView`, and `cacheTable`, have been moved here. + + - Dataset API and DataFrame API are unified. In Scala, `DataFrame` becomes a type alias for + `Dataset[Row]`, while Java API users must replace `DataFrame` with `Dataset`. Both the typed + transformations (e.g. `map`, `filter`, and `groupByKey`) and untyped transformations (e.g. + `select` and `groupBy`) are available on the Dataset class. Since compile-time type-safety in + Python and R is not a language feature, the concept of Dataset does not apply to these languages’ + APIs. Instead, `DataFrame` remains the primary programming abstraction, which is analogous to the + single-node data frame notion in these languages. + + - Dataset and DataFrame API `unionAll` has been deprecated and replaced by `union`. + - Dataset and DataFrame API `explode` has been deprecated; instead, use `functions.explode()` with `select` or `flatMap`. + - Dataset and DataFrame API `registerTempTable` has been deprecated and replaced by `createOrReplaceTempView`. + ## Upgrading From Spark SQL 1.5 to 1.6 - From Spark 1.6, by default the Thrift server runs in multi-session mode. Which means each JDBC/ODBC @@ -2125,7 +1280,7 @@ options. `spark.sql.parquet.mergeSchema` to `true`. - Resolution of strings to columns in python now supports using dots (`.`) to qualify the column or access nested values. For example `df['table.column.nestedField']`. However, this means that if - your column name contains any dots you must now escape them using backticks (e.g., ``table.`column.with.dots`.nested``). + your column name contains any dots you must now escape them using backticks (e.g., ``table.`column.with.dots`.nested``). - In-memory columnar storage partition pruning is on by default. It can be disabled by setting `spark.sql.inMemoryColumnarStorage.partitionPruning` to `false`. - Unlimited precision decimal columns are no longer supported, instead Spark SQL enforces a maximum @@ -2207,7 +1362,7 @@ import pyspark.sql.functions as func # In 1.3.x, in order for the grouping column "department" to show up, # it must be included explicitly as part of the agg function call. -df.groupBy("department").agg("department"), func.max("age"), func.sum("expense")) +df.groupBy("department").agg(df["department"], func.max("age"), func.sum("expense")) # In 1.4+, grouping column "department" is included automatically. df.groupBy("department").agg(func.max("age"), func.sum("expense")) @@ -2467,9 +1622,8 @@ Spark SQL and DataFrames support the following data types: All data types of Spark SQL are located in the package `org.apache.spark.sql.types`. You can access them by doing -{% highlight scala %} -import org.apache.spark.sql.types._ -{% endhighlight %} + +{% include_example data_types scala/org/apache/spark/examples/sql/SparkSQLExample.scala %}
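As a small, hedged illustration of working with these types (the field names and values below are invented; `spark` is an existing `SparkSession`), a schema can also be assembled by hand:

{% highlight scala %}
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

// A sketch: build an explicit schema from the types in org.apache.spark.sql.types
val schema = StructType(Seq(
  StructField("name", StringType, nullable = true),
  StructField("age", IntegerType, nullable = true)
))

// Apply the schema to an RDD of Rows to create a DataFrame
val rowRDD = spark.sparkContext.parallelize(Seq(Row("Alice", 30), Row("Bob", 25)))
val peopleDF = spark.createDataFrame(rowRDD, schema)
{% endhighlight %}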
Property Name Default Meaning
spark.sql.files.maxPartitionBytes 134217728 (128 MB) + The maximum number of bytes to pack into a single partition when reading files. +
spark.sql.files.openCostInBytes 4194304 (4 MB) + The estimated cost to open a file, measured by the number of bytes that could be scanned in the same + time. This is used when putting multiple files into a partition. It is better to over-estimate; + then the partitions with small files will be faster than partitions with bigger files (which are + scheduled first). +
spark.sql.broadcastTimeout 300 +

    + Timeout in seconds for the broadcast wait time in broadcast joins +

    +
    spark.sql.autoBroadcastJoinThreshold 10485760 (10 MB)
spark.sql.tungsten.enabled true - When true, use the optimized Tungsten physical execution backend which explicitly manages memory - and dynamically generates bytecode for expression evaluation. -
    spark.sql.shuffle.partitions 200
    diff --git a/docs/streaming-custom-receivers.md b/docs/streaming-custom-receivers.md index a4e17fd24eac2..f52bf348fcc99 100644 --- a/docs/streaming-custom-receivers.md +++ b/docs/streaming-custom-receivers.md @@ -36,7 +36,7 @@ Any exception in the receiving threads should be caught and handled properly to failures of the receiver. `restart()` will restart the receiver by asynchronously calling `onStop()` and then calling `onStart()` after a delay. `stop()` will call `onStop()` and terminate the receiver. Also, `reportError()` -reports a error message to the driver (visible in the logs and UI) without stopping / restarting +reports an error message to the driver (visible in the logs and UI) without stopping / restarting the receiver. The following is a custom receiver that receives a stream of text over a socket. It treats @@ -181,7 +181,7 @@ val words = lines.flatMap(_.split(" ")) ... {% endhighlight %} -The full source code is in the example [CustomReceiver.scala](https://github.com/apache/spark/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala). +The full source code is in the example [CustomReceiver.scala]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala).
    @@ -193,7 +193,7 @@ JavaDStream words = lines.flatMap(new FlatMapFunction() ... {% endhighlight %} -The full source code is in the example [JavaCustomReceiver.java](https://github.com/apache/spark/blob/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java). +The full source code is in the example [JavaCustomReceiver.java]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java).
    diff --git a/docs/streaming-flume-integration.md b/docs/streaming-flume-integration.md index 8eeeee75dbf40..767e1f9402e01 100644 --- a/docs/streaming-flume-integration.md +++ b/docs/streaming-flume-integration.md @@ -63,7 +63,7 @@ configuring Flume agents. By default, the Python API will decode Flume event body as UTF8 encoded strings. You can specify your custom decoding function to decode the body byte arrays in Flume events to any arbitrary data type. See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.flume.FlumeUtils) - and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/flume_wordcount.py). + and the [example]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/python/streaming/flume_wordcount.py). diff --git a/docs/streaming-kafka-0-10-integration.md b/docs/streaming-kafka-0-10-integration.md new file mode 100644 index 0000000000000..c1ef396907db7 --- /dev/null +++ b/docs/streaming-kafka-0-10-integration.md @@ -0,0 +1,307 @@ +--- +layout: global +title: Spark Streaming + Kafka Integration Guide (Kafka broker version 0.10.0 or higher) +--- + +The Spark Streaming integration for Kafka 0.10 is similar in design to the 0.8 [Direct Stream approach](streaming-kafka-0-8-integration.html#approach-2-direct-approach-no-receivers). It provides simple parallelism, 1:1 correspondence between Kafka partitions and Spark partitions, and access to offsets and metadata. However, because the newer integration uses the [new Kafka consumer API](http://kafka.apache.org/documentation.html#newconsumerapi) instead of the simple API, there are notable differences in usage. This version of the integration is marked as experimental, so the API is potentially subject to change. + +### Linking +For Scala/Java applications using SBT/Maven project definitions, link your streaming application with the following artifact (see [Linking section](streaming-programming-guide.html#linking) in the main programming guide for further information). + + groupId = org.apache.spark + artifactId = spark-streaming-kafka-0-10_{{site.SCALA_BINARY_VERSION}} + version = {{site.SPARK_VERSION_SHORT}} + +### Creating a Direct Stream + Note that the namespace for the import includes the version, org.apache.spark.streaming.kafka010 + +
    +
    + import org.apache.kafka.clients.consumer.ConsumerRecord + import org.apache.kafka.common.serialization.StringDeserializer + import org.apache.spark.streaming.kafka010._ + import org.apache.spark.streaming.kafka010.LocationStrategies.PreferConsistent + import org.apache.spark.streaming.kafka010.ConsumerStrategies.Subscribe + + val kafkaParams = Map[String, Object]( + "bootstrap.servers" -> "localhost:9092,anotherhost:9092", + "key.deserializer" -> classOf[StringDeserializer], + "value.deserializer" -> classOf[StringDeserializer], + "group.id" -> "use_a_separate_group_id_for_each_stream", + "auto.offset.reset" -> "latest", + "enable.auto.commit" -> (false: java.lang.Boolean) + ) + + val topics = Array("topicA", "topicB") + val stream = KafkaUtils.createDirectStream[String, String]( + streamingContext, + PreferConsistent, + Subscribe[String, String](topics, kafkaParams) + ) + + stream.map(record => (record.key, record.value)) + +Each item in the stream is a [ConsumerRecord](http://kafka.apache.org/0100/javadoc/org/apache/kafka/clients/consumer/ConsumerRecord.html) +
    +
    + import java.util.*; + import org.apache.spark.SparkConf; + import org.apache.spark.TaskContext; + import org.apache.spark.api.java.*; + import org.apache.spark.api.java.function.*; + import org.apache.spark.streaming.api.java.*; + import org.apache.spark.streaming.kafka010.*; + import org.apache.kafka.clients.consumer.ConsumerRecord; + import org.apache.kafka.common.TopicPartition; + import org.apache.kafka.common.serialization.StringDeserializer; + import scala.Tuple2; + + Map kafkaParams = new HashMap<>(); + kafkaParams.put("bootstrap.servers", "localhost:9092,anotherhost:9092"); + kafkaParams.put("key.deserializer", StringDeserializer.class); + kafkaParams.put("value.deserializer", StringDeserializer.class); + kafkaParams.put("group.id", "use_a_separate_group_id_for_each_stream"); + kafkaParams.put("auto.offset.reset", "latest"); + kafkaParams.put("enable.auto.commit", false); + + Collection topics = Arrays.asList("topicA", "topicB"); + + final JavaInputDStream> stream = + KafkaUtils.createDirectStream( + streamingContext, + LocationStrategies.PreferConsistent(), + ConsumerStrategies.Subscribe(topics, kafkaParams) + ); + + stream.mapToPair( + new PairFunction, String, String>() { + @Override + public Tuple2 call(ConsumerRecord record) { + return new Tuple2<>(record.key(), record.value()); + } + }) +
    +
    + +For possible kafkaParams, see [Kafka consumer config docs](http://kafka.apache.org/documentation.html#newconsumerconfigs). +If your Spark batch duration is larger than the default Kafka heartbeat session timeout (30 seconds), increase heartbeat.interval.ms and session.timeout.ms appropriately. For batches larger than 5 minutes, this will require changing group.max.session.timeout.ms on the broker. +Note that the example sets enable.auto.commit to false, for discussion see [Storing Offsets](streaming-kafka-0-10-integration.html#storing-offsets) below. + +### LocationStrategies +The new Kafka consumer API will pre-fetch messages into buffers. Therefore it is important for performance reasons that the Spark integration keep cached consumers on executors (rather than recreating them for each batch), and prefer to schedule partitions on the host locations that have the appropriate consumers. + +In most cases, you should use `LocationStrategies.PreferConsistent` as shown above. This will distribute partitions evenly across available executors. If your executors are on the same hosts as your Kafka brokers, use `PreferBrokers`, which will prefer to schedule partitions on the Kafka leader for that partition. Finally, if you have a significant skew in load among partitions, use `PreferFixed`. This allows you to specify an explicit mapping of partitions to hosts (any unspecified partitions will use a consistent location). + +The cache for consumers has a default maximum size of 64. If you expect to be handling more than (64 * number of executors) Kafka partitions, you can change this setting via `spark.streaming.kafka.consumer.cache.maxCapacity` + +The cache is keyed by topicpartition and group.id, so use a **separate** `group.id` for each call to `createDirectStream`. + + +### ConsumerStrategies +The new Kafka consumer API has a number of different ways to specify topics, some of which require considerable post-object-instantiation setup. `ConsumerStrategies` provides an abstraction that allows Spark to obtain properly configured consumers even after restart from checkpoint. + +`ConsumerStrategies.Subscribe`, as shown above, allows you to subscribe to a fixed collection of topics. `SubscribePattern` allows you to use a regex to specify topics of interest. Note that unlike the 0.8 integration, using `Subscribe` or `SubscribePattern` should respond to adding partitions during a running stream. Finally, `Assign` allows you to specify a fixed collection of partitions. All three strategies have overloaded constructors that allow you to specify the starting offset for a particular partition. + +If you have specific consumer setup needs that are not met by the options above, `ConsumerStrategy` is a public class that you can extend. + +### Creating an RDD +If you have a use case that is better suited to batch processing, you can create an RDD for a defined range of offsets. + +
    +
    + // Import dependencies and create kafka params as in Create Direct Stream above + + val offsetRanges = Array( + // topic, partition, inclusive starting offset, exclusive ending offset + OffsetRange("test", 0, 0, 100), + OffsetRange("test", 1, 0, 100) + ) + + val rdd = KafkaUtils.createRDD[String, String](sparkContext, kafkaParams, offsetRanges, PreferConsistent) + +
    +
    + // Import dependencies and create kafka params as in Create Direct Stream above + + OffsetRange[] offsetRanges = { + // topic, partition, inclusive starting offset, exclusive ending offset + OffsetRange.create("test", 0, 0, 100), + OffsetRange.create("test", 1, 0, 100) + }; + + JavaRDD> rdd = KafkaUtils.createRDD( + sparkContext, + kafkaParams, + offsetRanges, + LocationStrategies.PreferConsistent() + ); +
    +
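Before moving on, here is a hedged sketch of the `PreferFixed` location strategy described under LocationStrategies above (the host names and partition layout are invented; `streamingContext` and `kafkaParams` are assumed to be the ones from the earlier direct-stream example):

{% highlight scala %}
import org.apache.kafka.common.TopicPartition
import org.apache.spark.streaming.kafka010.KafkaUtils
import org.apache.spark.streaming.kafka010.LocationStrategies.PreferFixed
import org.apache.spark.streaming.kafka010.ConsumerStrategies.Subscribe

// A sketch: pin heavily loaded partitions to specific executor hosts;
// partitions not listed here fall back to a consistent assignment.
val hostMap = Map(
  new TopicPartition("topicA", 0) -> "host1.example.com",
  new TopicPartition("topicA", 1) -> "host2.example.com"
)

val skewAwareStream = KafkaUtils.createDirectStream[String, String](
  streamingContext,
  PreferFixed(hostMap),
  Subscribe[String, String](Array("topicA", "topicB"), kafkaParams)
)
{% endhighlight %}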
    + +Note that you cannot use `PreferBrokers`, because without the stream there is not a driver-side consumer to automatically look up broker metadata for you. Use `PreferFixed` with your own metadata lookups if necessary. + +### Obtaining Offsets + +
    +
    + stream.foreachRDD { rdd => + val offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges + rdd.foreachPartition { iter => + val o: OffsetRange = offsetRanges(TaskContext.get.partitionId) + println(s"${o.topic} ${o.partition} ${o.fromOffset} ${o.untilOffset}") + } + } +
    +
    + stream.foreachRDD(new VoidFunction>>() { + @Override + public void call(JavaRDD> rdd) { + final OffsetRange[] offsetRanges = ((HasOffsetRanges) rdd.rdd()).offsetRanges(); + rdd.foreachPartition(new VoidFunction>>() { + @Override + public void call(Iterator> consumerRecords) { + OffsetRange o = offsetRanges[TaskContext.get().partitionId()]; + System.out.println( + o.topic() + " " + o.partition() + " " + o.fromOffset() + " " + o.untilOffset()); + } + }); + } + }); +
    +
+ + +Note that the typecast to `HasOffsetRanges` will only succeed if it is done in the first method called on the result of `createDirectStream`, not later down a chain of methods. Be aware that the one-to-one mapping between RDD partition and Kafka partition does not remain after any methods that shuffle or repartition, e.g. reduceByKey() or window(). + +### Storing Offsets +Kafka delivery semantics in the case of failure depend on how and when offsets are stored. Spark output operations are [at-least-once](streaming-programming-guide.html#semantics-of-output-operations). So if you want the equivalent of exactly-once semantics, you must either store offsets after an idempotent output, or store offsets in an atomic transaction alongside output. With this integration, you have 3 options, in order of increasing reliability (and code complexity), for how to store offsets. + +#### Checkpoints +If you enable Spark [checkpointing](streaming-programming-guide.html#checkpointing), offsets will be stored in the checkpoint. This is easy to enable, but there are drawbacks. Your output operation must be idempotent, since you will get repeated outputs; transactions are not an option. Furthermore, you cannot recover from a checkpoint if your application code has changed. For planned upgrades, you can mitigate this by running the new code at the same time as the old code (since outputs need to be idempotent anyway, they should not clash). But for unplanned failures that require code changes, you will lose data unless you have another way to identify known good starting offsets. + +#### Kafka itself +Kafka has an offset commit API that stores offsets in a special Kafka topic. By default, the new consumer will periodically auto-commit offsets. This is almost certainly not what you want, because messages successfully polled by the consumer may not yet have resulted in a Spark output operation, resulting in undefined semantics. This is why the stream example above sets "enable.auto.commit" to false. However, you can commit offsets to Kafka after you know your output has been stored, using the `commitAsync` API. The benefit as compared to checkpoints is that Kafka is a durable store regardless of changes to your application code. However, Kafka is not transactional, so your outputs must still be idempotent. +
    +
    + stream.foreachRDD { rdd => + val offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges + + // some time later, after outputs have completed + stream.asInstanceOf[CanCommitOffsets].commitAsync(offsetRanges) + } + +As with HasOffsetRanges, the cast to CanCommitOffsets will only succeed if called on the result of createDirectStream, not after transformations. The commitAsync call is threadsafe, but must occur after outputs if you want meaningful semantics. +
    +
    + stream.foreachRDD(new VoidFunction>>() { + @Override + public void call(JavaRDD> rdd) { + OffsetRange[] offsetRanges = ((HasOffsetRanges) rdd.rdd()).offsetRanges(); + + // some time later, after outputs have completed + ((CanCommitOffsets) stream.inputDStream()).commitAsync(offsetRanges); + } + }); +
    +
    + +#### Your own data store +For data stores that support transactions, saving offsets in the same transaction as the results can keep the two in sync, even in failure situations. If you're careful about detecting repeated or skipped offset ranges, rolling back the transaction prevents duplicated or lost messages from affecting results. This gives the equivalent of exactly-once semantics. It is also possible to use this tactic even for outputs that result from aggregations, which are typically hard to make idempotent. + +
    +
+ // The details depend on your data store, but the general idea looks like this + + // begin from the offsets committed to the database + val fromOffsets = selectOffsetsFromYourDatabase.map { resultSet => + new TopicPartition(resultSet.string("topic"), resultSet.int("partition")) -> resultSet.long("offset") + }.toMap + + val stream = KafkaUtils.createDirectStream[String, String]( + streamingContext, + PreferConsistent, + Assign[String, String](fromOffsets.keys.toList, kafkaParams, fromOffsets) + ) + + stream.foreachRDD { rdd => + val offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges + + val results = yourCalculation(rdd) + + // begin your transaction + + // update results + // update offsets where the end of existing offsets matches the beginning of this batch of offsets + // assert that offsets were updated correctly + + // end your transaction + }
    +
+ // The details depend on your data store, but the general idea looks like this + + // begin from the offsets committed to the database + Map fromOffsets = new HashMap<>(); + for (resultSet : selectOffsetsFromYourDatabase) { + fromOffsets.put(new TopicPartition(resultSet.string("topic"), resultSet.int("partition")), resultSet.long("offset")); + } + + JavaInputDStream> stream = KafkaUtils.createDirectStream( + streamingContext, + LocationStrategies.PreferConsistent(), + ConsumerStrategies.Assign(fromOffsets.keySet(), kafkaParams, fromOffsets) + ); + + stream.foreachRDD(new VoidFunction>>() { + @Override + public void call(JavaRDD> rdd) { + OffsetRange[] offsetRanges = ((HasOffsetRanges) rdd.rdd()).offsetRanges(); + + Object results = yourCalculation(rdd); + + // begin your transaction + + // update results + // update offsets where the end of existing offsets matches the beginning of this batch of offsets + // assert that offsets were updated correctly + + // end your transaction + } + });
    +
    + +### SSL / TLS +The new Kafka consumer [supports SSL](http://kafka.apache.org/documentation.html#security_ssl). To enable it, set kafkaParams appropriately before passing to `createDirectStream` / `createRDD`. Note that this only applies to communication between Spark and Kafka brokers; you are still responsible for separately [securing](security.html) Spark inter-node communication. + + +
    +
    + val kafkaParams = Map[String, Object]( + // the usual params, make sure to change the port in bootstrap.servers if 9092 is not TLS + "security.protocol" -> "SSL", + "ssl.truststore.location" -> "/some-directory/kafka.client.truststore.jks", + "ssl.truststore.password" -> "test1234", + "ssl.keystore.location" -> "/some-directory/kafka.client.keystore.jks", + "ssl.keystore.password" -> "test1234", + "ssl.key.password" -> "test1234" + ) +
    +
    + Map kafkaParams = new HashMap(); + // the usual params, make sure to change the port in bootstrap.servers if 9092 is not TLS + kafkaParams.put("security.protocol", "SSL"); + kafkaParams.put("ssl.truststore.location", "/some-directory/kafka.client.truststore.jks"); + kafkaParams.put("ssl.truststore.password", "test1234"); + kafkaParams.put("ssl.keystore.location", "/some-directory/kafka.client.keystore.jks"); + kafkaParams.put("ssl.keystore.password", "test1234"); + kafkaParams.put("ssl.key.password", "test1234"); +
    +
    + +### Deploying + +As with any Spark applications, `spark-submit` is used to launch your application. + +For Scala and Java applications, if you are using SBT or Maven for project management, then package `spark-streaming-kafka-0-10_{{site.SCALA_BINARY_VERSION}}` and its dependencies into the application JAR. Make sure `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` are marked as `provided` dependencies as those are already present in a Spark installation. Then use `spark-submit` to launch your application (see [Deploying section](streaming-programming-guide.html#deploying-applications) in the main programming guide). + diff --git a/docs/streaming-kafka-0-8-integration.md b/docs/streaming-kafka-0-8-integration.md new file mode 100644 index 0000000000000..f8f7b95cf7458 --- /dev/null +++ b/docs/streaming-kafka-0-8-integration.md @@ -0,0 +1,210 @@ +--- +layout: global +title: Spark Streaming + Kafka Integration Guide (Kafka broker version 0.8.2.1 or higher) +--- +Here we explain how to configure Spark Streaming to receive data from Kafka. There are two approaches to this - the old approach using Receivers and Kafka's high-level API, and a new approach (introduced in Spark 1.3) without using Receivers. They have different programming models, performance characteristics, and semantics guarantees, so read on for more details. Both approaches are considered stable APIs as of the current version of Spark. + +## Approach 1: Receiver-based Approach +This approach uses a Receiver to receive the data. The Receiver is implemented using the Kafka high-level consumer API. As with all receivers, the data received from Kafka through a Receiver is stored in Spark executors, and then jobs launched by Spark Streaming processes the data. + +However, under default configuration, this approach can lose data under failures (see [receiver reliability](streaming-programming-guide.html#receiver-reliability). To ensure zero-data loss, you have to additionally enable Write Ahead Logs in Spark Streaming (introduced in Spark 1.2). This synchronously saves all the received Kafka data into write ahead logs on a distributed file system (e.g HDFS), so that all the data can be recovered on failure. See [Deploying section](streaming-programming-guide.html#deploying-applications) in the streaming programming guide for more details on Write Ahead Logs. + +Next, we discuss how to use this approach in your streaming application. + +1. **Linking:** For Scala/Java applications using SBT/Maven project definitions, link your streaming application with the following artifact (see [Linking section](streaming-programming-guide.html#linking) in the main programming guide for further information). + + groupId = org.apache.spark + artifactId = spark-streaming-kafka-0-8_{{site.SCALA_BINARY_VERSION}} + version = {{site.SPARK_VERSION_SHORT}} + + For Python applications, you will have to add this above library and its dependencies when deploying your application. See the *Deploying* subsection below. + +2. **Programming:** In the streaming application code, import `KafkaUtils` and create an input DStream as follows. + +
    +
    + import org.apache.spark.streaming.kafka._ + + val kafkaStream = KafkaUtils.createStream(streamingContext, + [ZK quorum], [consumer group id], [per-topic number of Kafka partitions to consume]) + + You can also specify the key and value classes and their corresponding decoder classes using variations of `createStream`. See the [API docs](api/scala/index.html#org.apache.spark.streaming.kafka.KafkaUtils$) + and the [example]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala). +
    +
    + import org.apache.spark.streaming.kafka.*; + + JavaPairReceiverInputDStream kafkaStream = + KafkaUtils.createStream(streamingContext, + [ZK quorum], [consumer group id], [per-topic number of Kafka partitions to consume]); + + You can also specify the key and value classes and their corresponding decoder classes using variations of `createStream`. See the [API docs](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html) + and the [example]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java). + +
    +
    + from pyspark.streaming.kafka import KafkaUtils + + kafkaStream = KafkaUtils.createStream(streamingContext, \ + [ZK quorum], [consumer group id], [per-topic number of Kafka partitions to consume]) + + By default, the Python API will decode Kafka data as UTF8 encoded strings. You can specify your custom decoding function to decode the byte arrays in Kafka records to any arbitrary data type. See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.kafka.KafkaUtils) + and the [example]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/python/streaming/kafka_wordcount.py). +
    +
+ + **Points to remember:** + + - Topic partitions in Kafka do not correlate to partitions of RDDs generated in Spark Streaming. So increasing the number of topic-specific partitions in the `KafkaUtils.createStream()` only increases the number of threads using which topics are consumed within a single receiver. It does not increase the parallelism of Spark in processing the data. Refer to the main document for more information on that. + + - Multiple Kafka input DStreams can be created with different groups and topics for parallel receiving of data using multiple receivers. + + - If you have enabled Write Ahead Logs with a replicated file system like HDFS, the received data is already being replicated in the log. Hence, set the storage level for the input stream to `StorageLevel.MEMORY_AND_DISK_SER` (that is, use +`KafkaUtils.createStream(..., StorageLevel.MEMORY_AND_DISK_SER)`). + +3. **Deploying:** As with all Spark applications, `spark-submit` is used to launch your application. However, the details are slightly different for Scala/Java applications and Python applications. + + For Scala and Java applications, if you are using SBT or Maven for project management, then package `spark-streaming-kafka-0-8_{{site.SCALA_BINARY_VERSION}}` and its dependencies into the application JAR. Make sure `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` are marked as `provided` dependencies as those are already present in a Spark installation. Then use `spark-submit` to launch your application (see [Deploying section](streaming-programming-guide.html#deploying-applications) in the main programming guide). + + For Python applications which lack SBT/Maven project management, `spark-streaming-kafka-0-8_{{site.SCALA_BINARY_VERSION}}` and its dependencies can be directly added to `spark-submit` using `--packages` (see [Application Submission Guide](submitting-applications.html)). That is, + + ./bin/spark-submit --packages org.apache.spark:spark-streaming-kafka-0-8_{{site.SCALA_BINARY_VERSION}}:{{site.SPARK_VERSION_SHORT}} ... + + Alternatively, you can also download the JAR of the Maven artifact `spark-streaming-kafka-0-8-assembly` from the + [Maven repository](http://search.maven.org/#search|ga|1|a%3A%22spark-streaming-kafka-0-8-assembly_{{site.SCALA_BINARY_VERSION}}%22%20AND%20v%3A%22{{site.SPARK_VERSION_SHORT}}%22) and add it to `spark-submit` with `--jars`. + +## Approach 2: Direct Approach (No Receivers) +This new receiver-less "direct" approach has been introduced in Spark 1.3 to ensure stronger end-to-end guarantees. Instead of using receivers to receive data, this approach periodically queries Kafka for the latest offsets in each topic+partition, and accordingly defines the offset ranges to process in each batch. When the jobs to process the data are launched, Kafka's simple consumer API is used to read the defined ranges of offsets from Kafka (similar to reading files from a file system). Note that this feature was introduced in Spark 1.3 for the Scala and Java API, in Spark 1.4 for the Python API. + +This approach has the following advantages over the receiver-based approach (i.e. Approach 1). + +- *Simplified Parallelism:* No need to create multiple input Kafka streams and union them. With `directStream`, Spark Streaming will create as many RDD partitions as there are Kafka partitions to consume, which will all read data from Kafka in parallel. 
So there is a one-to-one mapping between Kafka and RDD partitions, which is easier to understand and tune. + +- *Efficiency:* Achieving zero-data loss in the first approach required the data to be stored in a Write Ahead Log, which further replicated the data. This is actually inefficient as the data effectively gets replicated twice - once by Kafka, and a second time by the Write Ahead Log. This second approach eliminates the problem as there is no receiver, and hence no need for Write Ahead Logs. As long as you have sufficient Kafka retention, messages can be recovered from Kafka. + +- *Exactly-once semantics:* The first approach uses Kafka's high level API to store consumed offsets in Zookeeper. This is traditionally the way to consume data from Kafka. While this approach (in combination with write ahead logs) can ensure zero data loss (i.e. at-least once semantics), there is a small chance some records may get consumed twice under some failures. This occurs because of inconsistencies between data reliably received by Spark Streaming and offsets tracked by Zookeeper. Hence, in this second approach, we use simple Kafka API that does not use Zookeeper. Offsets are tracked by Spark Streaming within its checkpoints. This eliminates inconsistencies between Spark Streaming and Zookeeper/Kafka, and so each record is received by Spark Streaming effectively exactly once despite failures. In order to achieve exactly-once semantics for output of your results, your output operation that saves the data to an external data store must be either idempotent, or an atomic transaction that saves results and offsets (see [Semantics of output operations](streaming-programming-guide.html#semantics-of-output-operations) in the main programming guide for further information). + +Note that one disadvantage of this approach is that it does not update offsets in Zookeeper, hence Zookeeper-based Kafka monitoring tools will not show progress. However, you can access the offsets processed by this approach in each batch and update Zookeeper yourself (see below). + +Next, we discuss how to use this approach in your streaming application. + +1. **Linking:** This approach is supported only in Scala/Java application. Link your SBT/Maven project with the following artifact (see [Linking section](streaming-programming-guide.html#linking) in the main programming guide for further information). + + groupId = org.apache.spark + artifactId = spark-streaming-kafka-0-8_{{site.SCALA_BINARY_VERSION}} + version = {{site.SPARK_VERSION_SHORT}} + +2. **Programming:** In the streaming application code, import `KafkaUtils` and create an input DStream as follows. + +
    +
    + import org.apache.spark.streaming.kafka._ + + val directKafkaStream = KafkaUtils.createDirectStream[ + [key class], [value class], [key decoder class], [value decoder class] ]( + streamingContext, [map of Kafka parameters], [set of topics to consume]) + + You can also pass a `messageHandler` to `createDirectStream` to access `MessageAndMetadata` that contains metadata about the current message and transform it to any desired type. + See the [API docs](api/scala/index.html#org.apache.spark.streaming.kafka.KafkaUtils$) + and the [example]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala). +
    +
    + import org.apache.spark.streaming.kafka.*; + + JavaPairInputDStream directKafkaStream = + KafkaUtils.createDirectStream(streamingContext, + [key class], [value class], [key decoder class], [value decoder class], + [map of Kafka parameters], [set of topics to consume]); + + You can also pass a `messageHandler` to `createDirectStream` to access `MessageAndMetadata` that contains metadata about the current message and transform it to any desired type. + See the [API docs](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html) + and the [example]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java). + +
    +
    + from pyspark.streaming.kafka import KafkaUtils + directKafkaStream = KafkaUtils.createDirectStream(ssc, [topic], {"metadata.broker.list": brokers}) + + You can also pass a `messageHandler` to `createDirectStream` to access `KafkaMessageAndMetadata` that contains metadata about the current message and transform it to any desired type. + By default, the Python API will decode Kafka data as UTF8 encoded strings. You can specify your custom decoding function to decode the byte arrays in Kafka records to any arbitrary data type. See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.kafka.KafkaUtils) + and the [example]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/python/streaming/direct_kafka_wordcount.py). +
    +
    + + In the Kafka parameters, you must specify either `metadata.broker.list` or `bootstrap.servers`. + By default, it will start consuming from the latest offset of each Kafka partition. If you set configuration `auto.offset.reset` in Kafka parameters to `smallest`, then it will start consuming from the smallest offset. + + You can also start consuming from any arbitrary offset using other variations of `KafkaUtils.createDirectStream`. Furthermore, if you want to access the Kafka offsets consumed in each batch, you can do the following. + +
    +
    + // Hold a reference to the current offset ranges, so it can be used downstream + var offsetRanges = Array[OffsetRange]() + + directKafkaStream.transform { rdd => + offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges + rdd + }.map { + ... + }.foreachRDD { rdd => + for (o <- offsetRanges) { + println(s"${o.topic} ${o.partition} ${o.fromOffset} ${o.untilOffset}") + } + ... + } +
    +
    + // Hold a reference to the current offset ranges, so it can be used downstream + final AtomicReference offsetRanges = new AtomicReference<>(); + + directKafkaStream.transformToPair( + new Function, JavaPairRDD>() { + @Override + public JavaPairRDD call(JavaPairRDD rdd) throws Exception { + OffsetRange[] offsets = ((HasOffsetRanges) rdd.rdd()).offsetRanges(); + offsetRanges.set(offsets); + return rdd; + } + } + ).map( + ... + ).foreachRDD( + new Function, Void>() { + @Override + public Void call(JavaPairRDD rdd) throws IOException { + for (OffsetRange o : offsetRanges.get()) { + System.out.println( + o.topic() + " " + o.partition() + " " + o.fromOffset() + " " + o.untilOffset() + ); + } + ... + return null; + } + } + ); +
    +
    + offsetRanges = [] + + def storeOffsetRanges(rdd): + global offsetRanges + offsetRanges = rdd.offsetRanges() + return rdd + + def printOffsetRanges(rdd): + for o in offsetRanges: + print "%s %s %s %s" % (o.topic, o.partition, o.fromOffset, o.untilOffset) + + directKafkaStream\ + .transform(storeOffsetRanges)\ + .foreachRDD(printOffsetRanges) +
    +
+ + You can use this to update Zookeeper yourself if you want Zookeeper-based Kafka monitoring tools to show progress of the streaming application. + + Note that the typecast to HasOffsetRanges will only succeed if it is done in the first method called on the directKafkaStream, not later down a chain of methods. You can use transform() instead of foreachRDD() as your first method call in order to access offsets, then call further Spark methods. However, be aware that the one-to-one mapping between RDD partition and Kafka partition does not remain after any methods that shuffle or repartition, e.g. reduceByKey() or window(). + + Another thing to note is that since this approach does not use Receivers, the standard receiver-related [configurations](configuration.html) (that is, configurations of the form `spark.streaming.receiver.*`) will not apply to the input DStreams created by this approach (they will apply to other input DStreams though). Instead, use the [configurations](configuration.html) `spark.streaming.kafka.*`. An important one is `spark.streaming.kafka.maxRatePerPartition`, which is the maximum rate (in messages per second) at which each Kafka partition will be read by this direct API; a brief sketch of setting it follows below. + +3. **Deploying:** This is the same as the first approach.
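For illustration, a minimal hedged sketch of setting the per-partition rate limit mentioned above through `SparkConf` (the application name, batch interval, and rate value are arbitrary):

{% highlight scala %}
import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}

// A sketch: cap each Kafka partition at 1000 messages per second for the direct API
val conf = new SparkConf()
  .setAppName("DirectKafkaRateLimited")
  .set("spark.streaming.kafka.maxRatePerPartition", "1000")

val ssc = new StreamingContext(conf, Seconds(2))
{% endhighlight %}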
- - groupId = org.apache.spark - artifactId = spark-streaming-kafka_{{site.SCALA_BINARY_VERSION}} - version = {{site.SPARK_VERSION_SHORT}} - - For Python applications, you will have to add this above library and its dependencies when deploying your application. See the *Deploying* subsection below. - -2. **Programming:** In the streaming application code, import `KafkaUtils` and create an input DStream as follows. - -
    -
    - import org.apache.spark.streaming.kafka._ - - val kafkaStream = KafkaUtils.createStream(streamingContext, - [ZK quorum], [consumer group id], [per-topic number of Kafka partitions to consume]) - - You can also specify the key and value classes and their corresponding decoder classes using variations of `createStream`. See the [API docs](api/scala/index.html#org.apache.spark.streaming.kafka.KafkaUtils$) - and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala). -
    -
    - import org.apache.spark.streaming.kafka.*; - - JavaPairReceiverInputDStream kafkaStream = - KafkaUtils.createStream(streamingContext, - [ZK quorum], [consumer group id], [per-topic number of Kafka partitions to consume]); - - You can also specify the key and value classes and their corresponding decoder classes using variations of `createStream`. See the [API docs](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html) - and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java). - -
    -
    - from pyspark.streaming.kafka import KafkaUtils - - kafkaStream = KafkaUtils.createStream(streamingContext, \ - [ZK quorum], [consumer group id], [per-topic number of Kafka partitions to consume]) - - By default, the Python API will decode Kafka data as UTF8 encoded strings. You can specify your custom decoding function to decode the byte arrays in Kafka records to any arbitrary data type. See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.kafka.KafkaUtils) - and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/kafka_wordcount.py). -
    -
    - - **Points to remember:** - - - Topic partitions in Kafka does not correlate to partitions of RDDs generated in Spark Streaming. So increasing the number of topic-specific partitions in the `KafkaUtils.createStream()` only increases the number of threads using which topics that are consumed within a single receiver. It does not increase the parallelism of Spark in processing the data. Refer to the main document for more information on that. - - - Multiple Kafka input DStreams can be created with different groups and topics for parallel receiving of data using multiple receivers. - - - If you have enabled Write Ahead Logs with a replicated file system like HDFS, the received data is already being replicated in the log. Hence, the storage level in storage level for the input stream to `StorageLevel.MEMORY_AND_DISK_SER` (that is, use -`KafkaUtils.createStream(..., StorageLevel.MEMORY_AND_DISK_SER)`). - -3. **Deploying:** As with any Spark applications, `spark-submit` is used to launch your application. However, the details are slightly different for Scala/Java applications and Python applications. - - For Scala and Java applications, if you are using SBT or Maven for project management, then package `spark-streaming-kafka_{{site.SCALA_BINARY_VERSION}}` and its dependencies into the application JAR. Make sure `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` are marked as `provided` dependencies as those are already present in a Spark installation. Then use `spark-submit` to launch your application (see [Deploying section](streaming-programming-guide.html#deploying-applications) in the main programming guide). - - For Python applications which lack SBT/Maven project management, `spark-streaming-kafka_{{site.SCALA_BINARY_VERSION}}` and its dependencies can be directly added to `spark-submit` using `--packages` (see [Application Submission Guide](submitting-applications.html)). That is, - - ./bin/spark-submit --packages org.apache.spark:spark-streaming-kafka_{{site.SCALA_BINARY_VERSION}}:{{site.SPARK_VERSION_SHORT}} ... - - Alternatively, you can also download the JAR of the Maven artifact `spark-streaming-kafka-assembly` from the - [Maven repository](http://search.maven.org/#search|ga|1|a%3A%22spark-streaming-kafka-assembly_{{site.SCALA_BINARY_VERSION}}%22%20AND%20v%3A%22{{site.SPARK_VERSION_SHORT}}%22) and add it to `spark-submit` with `--jars`. - -## Approach 2: Direct Approach (No Receivers) -This new receiver-less "direct" approach has been introduced in Spark 1.3 to ensure stronger end-to-end guarantees. Instead of using receivers to receive data, this approach periodically queries Kafka for the latest offsets in each topic+partition, and accordingly defines the offset ranges to process in each batch. When the jobs to process the data are launched, Kafka's simple consumer API is used to read the defined ranges of offsets from Kafka (similar to read files from a file system). Note that this is an experimental feature introduced in Spark 1.3 for the Scala and Java API, in Spark 1.4 for the Python API. - -This approach has the following advantages over the receiver-based approach (i.e. Approach 1). - -- *Simplified Parallelism:* No need to create multiple input Kafka streams and union them. With `directStream`, Spark Streaming will create as many RDD partitions as there are Kafka partitions to consume, which will all read data from Kafka in parallel. 
So there is a one-to-one mapping between Kafka and RDD partitions, which is easier to understand and tune. - -- *Efficiency:* Achieving zero-data loss in the first approach required the data to be stored in a Write Ahead Log, which further replicated the data. This is actually inefficient as the data effectively gets replicated twice - once by Kafka, and a second time by the Write Ahead Log. This second approach eliminates the problem as there is no receiver, and hence no need for Write Ahead Logs. As long as you have sufficient Kafka retention, messages can be recovered from Kafka. - -- *Exactly-once semantics:* The first approach uses Kafka's high level API to store consumed offsets in Zookeeper. This is traditionally the way to consume data from Kafka. While this approach (in combination with write ahead logs) can ensure zero data loss (i.e. at-least once semantics), there is a small chance some records may get consumed twice under some failures. This occurs because of inconsistencies between data reliably received by Spark Streaming and offsets tracked by Zookeeper. Hence, in this second approach, we use simple Kafka API that does not use Zookeeper. Offsets are tracked by Spark Streaming within its checkpoints. This eliminates inconsistencies between Spark Streaming and Zookeeper/Kafka, and so each record is received by Spark Streaming effectively exactly once despite failures. In order to achieve exactly-once semantics for output of your results, your output operation that saves the data to an external data store must be either idempotent, or an atomic transaction that saves results and offsets (see [Semantics of output operations](streaming-programming-guide.html#semantics-of-output-operations) in the main programming guide for further information). - -Note that one disadvantage of this approach is that it does not update offsets in Zookeeper, hence Zookeeper-based Kafka monitoring tools will not show progress. However, you can access the offsets processed by this approach in each batch and update Zookeeper yourself (see below). - -Next, we discuss how to use this approach in your streaming application. - -1. **Linking:** This approach is supported only in Scala/Java application. Link your SBT/Maven project with the following artifact (see [Linking section](streaming-programming-guide.html#linking) in the main programming guide for further information). - - groupId = org.apache.spark - artifactId = spark-streaming-kafka_{{site.SCALA_BINARY_VERSION}} - version = {{site.SPARK_VERSION_SHORT}} - -2. **Programming:** In the streaming application code, import `KafkaUtils` and create an input DStream as follows. - -
    -
    - import org.apache.spark.streaming.kafka._ - - val directKafkaStream = KafkaUtils.createDirectStream[ - [key class], [value class], [key decoder class], [value decoder class] ]( - streamingContext, [map of Kafka parameters], [set of topics to consume]) - - You can also pass a `messageHandler` to `createDirectStream` to access `MessageAndMetadata` that contains metadata about the current message and transform it to any desired type. - See the [API docs](api/scala/index.html#org.apache.spark.streaming.kafka.KafkaUtils$) - and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala). -
    -
    - import org.apache.spark.streaming.kafka.*; - - JavaPairReceiverInputDStream directKafkaStream = - KafkaUtils.createDirectStream(streamingContext, - [key class], [value class], [key decoder class], [value decoder class], - [map of Kafka parameters], [set of topics to consume]); - - You can also pass a `messageHandler` to `createDirectStream` to access `MessageAndMetadata` that contains metadata about the current message and transform it to any desired type. - See the [API docs](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html) - and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java). - -
    -
    - from pyspark.streaming.kafka import KafkaUtils - directKafkaStream = KafkaUtils.createDirectStream(ssc, [topic], {"metadata.broker.list": brokers}) - - You can also pass a `messageHandler` to `createDirectStream` to access `KafkaMessageAndMetadata` that contains metadata about the current message and transform it to any desired type. - By default, the Python API will decode Kafka data as UTF8 encoded strings. You can specify your custom decoding function to decode the byte arrays in Kafka records to any arbitrary data type. See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.kafka.KafkaUtils) - and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/direct_kafka_wordcount.py). -
    -
    - - In the Kafka parameters, you must specify either `metadata.broker.list` or `bootstrap.servers`. - By default, it will start consuming from the latest offset of each Kafka partition. If you set configuration `auto.offset.reset` in Kafka parameters to `smallest`, then it will start consuming from the smallest offset. - - You can also start consuming from any arbitrary offset using other variations of `KafkaUtils.createDirectStream`. Furthermore, if you want to access the Kafka offsets consumed in each batch, you can do the following. - -
    -
    - // Hold a reference to the current offset ranges, so it can be used downstream - var offsetRanges = Array[OffsetRange]() - - directKafkaStream.transform { rdd => - offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges - rdd - }.map { - ... - }.foreachRDD { rdd => - for (o <- offsetRanges) { - println(s"${o.topic} ${o.partition} ${o.fromOffset} ${o.untilOffset}") - } - ... - } -
    -
    - // Hold a reference to the current offset ranges, so it can be used downstream - final AtomicReference offsetRanges = new AtomicReference<>(); - - directKafkaStream.transformToPair( - new Function, JavaPairRDD>() { - @Override - public JavaPairRDD call(JavaPairRDD rdd) throws Exception { - OffsetRange[] offsets = ((HasOffsetRanges) rdd.rdd()).offsetRanges(); - offsetRanges.set(offsets); - return rdd; - } - } - ).map( - ... - ).foreachRDD( - new Function, Void>() { - @Override - public Void call(JavaPairRDD rdd) throws IOException { - for (OffsetRange o : offsetRanges.get()) { - System.out.println( - o.topic() + " " + o.partition() + " " + o.fromOffset() + " " + o.untilOffset() - ); - } - ... - return null; - } - } - ); -
    -
    - offsetRanges = [] - - def storeOffsetRanges(rdd): - global offsetRanges - offsetRanges = rdd.offsetRanges() - return rdd - - def printOffsetRanges(rdd): - for o in offsetRanges: - print "%s %s %s %s" % (o.topic, o.partition, o.fromOffset, o.untilOffset) - - directKafkaStream\ - .transform(storeOffsetRanges)\ - .foreachRDD(printOffsetRanges) -
    -
    - - You can use this to update Zookeeper yourself if you want Zookeeper-based Kafka monitoring tools to show progress of the streaming application. - - Note that the typecast to HasOffsetRanges will only succeed if it is done in the first method called on the directKafkaStream, not later down a chain of methods. You can use transform() instead of foreachRDD() as your first method call in order to access offsets, then call further Spark methods. However, be aware that the one-to-one mapping between RDD partition and Kafka partition does not remain after any methods that shuffle or repartition, e.g. reduceByKey() or window(). - - Another thing to note is that since this approach does not use Receivers, the standard receiver-related (that is, [configurations](configuration.html) of the form `spark.streaming.receiver.*` ) will not apply to the input DStreams created by this approach (will apply to other input DStreams though). Instead, use the [configurations](configuration.html) `spark.streaming.kafka.*`. An important one is `spark.streaming.kafka.maxRatePerPartition` which is the maximum rate (in messages per second) at which each Kafka partition will be read by this direct API. - -3. **Deploying:** This is same as the first approach. +[Apache Kafka](http://kafka.apache.org/) is publish-subscribe messaging rethought as a distributed, partitioned, replicated commit log service. Please read the [Kafka documentation](http://kafka.apache.org/documentation.html) thoroughly before starting an integration using Spark. + +The Kafka project introduced a new consumer api between versions 0.8 and 0.10, so there are 2 separate corresponding Spark Streaming packages available. Please choose the correct package for your brokers and desired features; note that the 0.8 integration is compatible with later 0.9 and 0.10 brokers, but the 0.10 integration is not compatible with earlier brokers. + + +
+<table class="table">
+  <tr><th></th><th>spark-streaming-kafka-0-8</th><th>spark-streaming-kafka-0-10</th></tr>
+  <tr><td>Broker Version</td><td>0.8.2.1 or higher</td><td>0.10.0 or higher</td></tr>
+  <tr><td>Api Stability</td><td>Stable</td><td>Experimental</td></tr>
+  <tr><td>Language Support</td><td>Scala, Java, Python</td><td>Scala, Java</td></tr>
+  <tr><td>Receiver DStream</td><td>Yes</td><td>No</td></tr>
+  <tr><td>Direct DStream</td><td>Yes</td><td>Yes</td></tr>
+  <tr><td>SSL / TLS Support</td><td>No</td><td>Yes</td></tr>
+  <tr><td>Offset Commit Api</td><td>No</td><td>Yes</td></tr>
+  <tr><td>Dynamic Topic Subscription</td><td>No</td><td>Yes</td></tr>
+</table>
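To make the choice in the table above concrete for a build, here is a hedged sbt sketch. The artifact names come from the table; the `sparkVersion` value is a placeholder you would set to the Spark release you actually run against.

```scala
// build.sbt sketch (assumption: an sbt project; sparkVersion is a placeholder)
val sparkVersion = "2.0.0"  // replace with your Spark release

// Pick exactly one of the two artifacts from the table above, matching your broker version and features:
libraryDependencies += "org.apache.spark" %% "spark-streaming-kafka-0-8" % sparkVersion
// libraryDependencies += "org.apache.spark" %% "spark-streaming-kafka-0-10" % sparkVersion
```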
    diff --git a/docs/streaming-kinesis-integration.md b/docs/streaming-kinesis-integration.md index 5b9a7554d2e64..6be0b548bc62b 100644 --- a/docs/streaming-kinesis-integration.md +++ b/docs/streaming-kinesis-integration.md @@ -111,7 +111,7 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m - `[checkpoint interval]`: The interval (e.g., Duration(2000) = 2 seconds) at which the Kinesis Client Library saves its position in the stream. For starters, set it to the same as the batch interval of the streaming application. - - `[initial position]`: Can be either `InitialPositionInStream.TRIM_HORIZON` or `InitialPositionInStream.LATEST` (see Kinesis Checkpointing section and Amazon Kinesis API documentation for more details). + - `[initial position]`: Can be either `InitialPositionInStream.TRIM_HORIZON` or `InitialPositionInStream.LATEST` (see [`Kinesis Checkpointing`](#kinesis-checkpointing) section and [`Amazon Kinesis API documentation`](http://docs.aws.amazon.com/streams/latest/dev/developing-consumers-with-sdk.html) for more details). - `[message handler]`: A function that takes a Kinesis `Record` and outputs generic `T`. @@ -128,14 +128,6 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m Alternatively, you can also download the JAR of the Maven artifact `spark-streaming-kinesis-asl-assembly` from the [Maven repository](http://search.maven.org/#search|ga|1|a%3A%22spark-streaming-kinesis-asl-assembly_{{site.SCALA_BINARY_VERSION}}%22%20AND%20v%3A%22{{site.SPARK_VERSION_SHORT}}%22) and add it to `spark-submit` with `--jars`. - *Points to remember at runtime:* - - - Kinesis data processing is ordered per partition and occurs at-least once per message. - - - Multiple applications can read from the same Kinesis stream. Kinesis will maintain the application-specific shard and checkpoint info in DynamoDB. - - - A single Kinesis stream shard is processed by one input DStream at a time. -

    + *Points to remember at runtime:* + + - Kinesis data processing is ordered per partition and occurs at-least once per message. + + - Multiple applications can read from the same Kinesis stream. Kinesis will maintain the application-specific shard and checkpoint info in DynamoDB. + + - A single Kinesis stream shard is processed by one input DStream at a time. + - A single Kinesis input DStream can read from multiple shards of a Kinesis stream by creating multiple KinesisRecordProcessor threads. - Multiple input DStreams running in separate processes/instances can read from a Kinesis stream. @@ -166,26 +166,23 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m #### Running the Example To run the example, -- Download Spark source and follow the [instructions](building-spark.html) to build Spark with profile *-Pkinesis-asl*. - - mvn -Pkinesis-asl -DskipTests clean package - +- Download a Spark binary from the [download site](http://spark.apache.org/downloads.html). - Set up Kinesis stream (see earlier section) within AWS. Note the name of the Kinesis stream and the endpoint URL corresponding to the region where the stream was created. -- Set up the environment variables AWS_ACCESS_KEY_ID and AWS_SECRET_KEY with your AWS credentials. +- Set up the environment variables `AWS_ACCESS_KEY_ID` and `AWS_SECRET_KEY` with your AWS credentials. - In the Spark root directory, run the example as
    - bin/run-example streaming.KinesisWordCountASL [Kinesis app name] [Kinesis stream name] [endpoint URL] + bin/run-example --packages org.apache.spark:spark-streaming-kinesis-asl_{{site.SCALA_BINARY_VERSION}}:{{site.SPARK_VERSION_SHORT}} streaming.KinesisWordCountASL [Kinesis app name] [Kinesis stream name] [endpoint URL]
    - bin/run-example streaming.JavaKinesisWordCountASL [Kinesis app name] [Kinesis stream name] [endpoint URL] + bin/run-example --packages org.apache.spark:spark-streaming-kinesis-asl_{{site.SCALA_BINARY_VERSION}}:{{site.SPARK_VERSION_SHORT}} streaming.JavaKinesisWordCountASL [Kinesis app name] [Kinesis stream name] [endpoint URL]
    @@ -216,6 +213,6 @@ de-aggregate records during consumption. - Checkpointing too frequently will cause excess load on the AWS checkpoint storage layer and may lead to AWS throttling. The provided example handles this throttling with a random-backoff-retry strategy. -- If no Kinesis checkpoint info exists when the input DStream starts, it will start either from the oldest record available (InitialPositionInStream.TRIM_HORIZON) or from the latest tip (InitialPositionInStream.LATEST). This is configurable. -- InitialPositionInStream.LATEST could lead to missed records if data is added to the stream while no input DStreams are running (and no checkpoint info is being stored). -- InitialPositionInStream.TRIM_HORIZON may lead to duplicate processing of records where the impact is dependent on checkpoint frequency and processing idempotency. +- If no Kinesis checkpoint info exists when the input DStream starts, it will start either from the oldest record available (`InitialPositionInStream.TRIM_HORIZON`) or from the latest tip (`InitialPositionInStream.LATEST`). This is configurable. + - `InitialPositionInStream.LATEST` could lead to missed records if data is added to the stream while no input DStreams are running (and no checkpoint info is being stored). + - `InitialPositionInStream.TRIM_HORIZON` may lead to duplicate processing of records where the impact is dependent on checkpoint frequency and processing idempotency. diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 7f6c0ed6994ba..236ae5d649c4a 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -15,7 +15,7 @@ like Kafka, Flume, Kinesis, or TCP sockets, and can be processed using complex algorithms expressed with high-level functions like `map`, `reduce`, `join` and `window`. Finally, processed data can be pushed out to filesystems, databases, and live dashboards. In fact, you can apply Spark's -[machine learning](mllib-guide.html) and +[machine learning](ml-guide.html) and [graph processing](graphx-programming-guide.html) algorithms on data streams.

    @@ -126,7 +126,7 @@ ssc.awaitTermination() // Wait for the computation to terminate {% endhighlight %} The complete code can be found in the Spark Streaming example -[NetworkWordCount]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala). +[NetworkWordCount]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala).

    @@ -145,8 +145,8 @@ import org.apache.spark.streaming.api.java.*; import scala.Tuple2; // Create a local StreamingContext with two working thread and batch interval of 1 second -SparkConf conf = new SparkConf().setMaster("local[2]").setAppName("NetworkWordCount") -JavaStreamingContext jssc = new JavaStreamingContext(conf, Durations.seconds(1)) +SparkConf conf = new SparkConf().setMaster("local[2]").setAppName("NetworkWordCount"); +JavaStreamingContext jssc = new JavaStreamingContext(conf, Durations.seconds(1)); {% endhighlight %} Using this context, we can create a DStream that represents streaming data from a TCP @@ -216,7 +216,7 @@ jssc.awaitTermination(); // Wait for the computation to terminate {% endhighlight %} The complete code can be found in the Spark Streaming example -[JavaNetworkWordCount]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaNetworkWordCount.java). +[JavaNetworkWordCount]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/streaming/JavaNetworkWordCount.java).
    @@ -277,7 +277,7 @@ ssc.awaitTermination() # Wait for the computation to terminate {% endhighlight %} The complete code can be found in the Spark Streaming example -[NetworkWordCount]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/network_wordcount.py). +[NetworkWordCount]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/python/streaming/network_wordcount.py).
@@ -416,7 +416,7 @@ some of the common ones are as follows. <tr><th>Source</th><th>Artifact</th></tr> -<tr><td> Kafka </td><td> spark-streaming-kafka_{{site.SCALA_BINARY_VERSION}} </td></tr> +<tr><td> Kafka </td><td> spark-streaming-kafka-0-8_{{site.SCALA_BINARY_VERSION}} </td></tr> <tr><td> Flume </td><td> spark-streaming-flume_{{site.SCALA_BINARY_VERSION}} </td></tr> <tr><td> Kinesis<br/></td><td> spark-streaming-kinesis-asl_{{site.SCALA_BINARY_VERSION}} [Amazon Software License] </td></tr> @@ -477,7 +477,7 @@ import org.apache.spark.*; import org.apache.spark.streaming.api.java.*; SparkConf conf = new SparkConf().setAppName(appName).setMaster(master); -JavaStreamingContext ssc = new JavaStreamingContext(conf, Duration(1000)); +JavaStreamingContext ssc = new JavaStreamingContext(conf, new Duration(1000)); {% endhighlight %} The `appName` parameter is a name for your application to show on the cluster UI. @@ -612,7 +612,7 @@ as well as to run the receiver(s). - When running a Spark Streaming program locally, do not use "local" or "local[1]" as the master URL. Either of these means that only one thread will be used for running tasks locally. If you are using - a input DStream based on a receiver (e.g. sockets, Kafka, Flume, etc.), then the single thread will + an input DStream based on a receiver (e.g. sockets, Kafka, Flume, etc.), then the single thread will be used to run the receiver, leaving no thread for processing the received data. Hence, when running locally, always use "local[*n*]" as the master URL, where *n* > number of receivers to run (see [Spark Properties](configuration.html#spark-properties) for information on how to set @@ -656,7 +656,7 @@ methods for creating DStreams from files as input sources. Python API `fileStream` is not available in the Python API, only `textFileStream` is available. - **Streams based on Custom Receivers:** DStreams can be created with data streams received through custom receivers. See the [Custom Receiver - Guide](streaming-custom-receivers.html) and [DStream Akka](https://github.com/spark-packages/dstream-akka) for more details. + Guide](streaming-custom-receivers.html) for more details. - **Queue of RDDs as a Stream:** For testing a Spark Streaming application with test data, one can also create a DStream based on a queue of RDDs, using `streamingContext.queueStream(queueOfRDDs)`. Each RDD pushed into the queue will be treated as a batch of data in the DStream, and processed like a stream. @@ -683,7 +683,7 @@ and add it to the classpath. Some of these advanced sources are as follows. -- **Kafka:** Spark Streaming {{site.SPARK_VERSION_SHORT}} is compatible with Kafka 0.8.2.1. See the [Kafka Integration Guide](streaming-kafka-integration.html) for more details. +- **Kafka:** Spark Streaming {{site.SPARK_VERSION_SHORT}} is compatible with Kafka broker versions 0.8.2.1 or higher. See the [Kafka Integration Guide](streaming-kafka-integration.html) for more details. - **Flume:** Spark Streaming {{site.SPARK_VERSION_SHORT}} is compatible with Flume 1.6.0. See the [Flume Integration Guide](streaming-flume-integration.html) for more details. @@ -854,7 +854,7 @@ JavaPairDStream runningCounts = pairs.updateStateByKey(updateFu The update function will be called for each word, with `newValues` having a sequence of 1's (from the `(word, 1)` pairs) and the `runningCount` having the previous count. For the complete Java code, take a look at the example -[JavaStatefulNetworkWordCount.java]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/streaming +[JavaStatefulNetworkWordCount.java]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/streaming /JavaStatefulNetworkWordCount.java).
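For reference alongside the Java and Python hunks here, a minimal Scala sketch of the update function being described; `pairs` is assumed to be the DStream[(String, Int)] of (word, 1) records from the surrounding word-count example.

```scala
// Minimal sketch of the updateStateByKey pattern described above.
// Assumes `pairs` is a DStream[(String, Int)] of (word, 1) pairs from the word-count example.
def updateFunction(newValues: Seq[Int], runningCount: Option[Int]): Option[Int] = {
  // Add the new 1's from this batch to the previous running count (0 if the key is new)
  Some(runningCount.getOrElse(0) + newValues.sum)
}

val runningCounts = pairs.updateStateByKey[Int](updateFunction _)
```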
@@ -877,7 +877,7 @@ runningCounts = pairs.updateStateByKey(updateFunction) The update function will be called for each word, with `newValues` having a sequence of 1's (from the `(word, 1)` pairs) and the `runningCount` having the previous count. For the complete Python code, take a look at the example -[stateful_network_wordcount.py]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/stateful_network_wordcount.py). +[stateful_network_wordcount.py]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/python/streaming/stateful_network_wordcount.py). @@ -1259,7 +1259,7 @@ dstream.foreachRDD(sendRecord) This is incorrect as this requires the connection object to be serialized and sent from the -driver to the worker. Such connection objects are rarely transferrable across machines. This +driver to the worker. Such connection objects are rarely transferable across machines. This error may manifest as serialization errors (connection object not serializable), initialization errors (connection object needs to be initialized at the workers), etc. The correct solution is to create the connection object at the worker. @@ -1395,13 +1395,13 @@ object WordBlacklist { object DroppedWordsCounter { - @volatile private var instance: Accumulator[Long] = null + @volatile private var instance: LongAccumulator = null - def getInstance(sc: SparkContext): Accumulator[Long] = { + def getInstance(sc: SparkContext): LongAccumulator = { if (instance == null) { synchronized { if (instance == null) { - instance = sc.accumulator(0L, "WordsInBlacklistCounter") + instance = sc.longAccumulator("WordsInBlacklistCounter") } } } @@ -1409,7 +1409,7 @@ object DroppedWordsCounter { } } -wordCounts.foreachRDD((rdd: RDD[(String, Int)], time: Time) => { +wordCounts.foreachRDD { (rdd: RDD[(String, Int)], time: Time) => // Get or register the blacklist Broadcast val blacklist = WordBlacklist.getInstance(rdd.sparkContext) // Get or register the droppedWordsCounter Accumulator @@ -1417,18 +1417,18 @@ wordCounts.foreachRDD((rdd: RDD[(String, Int)], time: Time) => { // Use blacklist to drop words and use droppedWordsCounter to count them val counts = rdd.filter { case (word, count) => if (blacklist.value.contains(word)) { - droppedWordsCounter += count + droppedWordsCounter.add(count) false } else { true } - }.collect() + }.collect().mkString("[", ", ", "]") val output = "Counts at time " + time + " " + counts }) {% endhighlight %} -See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala). +See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala).
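Pulling the changed Scala pieces of the hunk above together, a minimal self-contained sketch of the lazily created accumulator singleton under the Spark 2.0 API (the counter name follows the example):

```scala
// Consolidated sketch of the accumulator singleton shown piecemeal in the diff above (Spark 2.0 API).
import org.apache.spark.SparkContext
import org.apache.spark.util.LongAccumulator

object DroppedWordsCounter {
  @volatile private var instance: LongAccumulator = null

  def getInstance(sc: SparkContext): LongAccumulator = {
    if (instance == null) {
      synchronized {
        if (instance == null) {
          // longAccumulator registers a named accumulator, visible in the web UI
          instance = sc.longAccumulator("WordsInBlacklistCounter")
        }
      }
    }
    instance
  }
}
```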
    {% highlight java %} @@ -1452,13 +1452,13 @@ class JavaWordBlacklist { class JavaDroppedWordsCounter { - private static volatile Accumulator instance = null; + private static volatile LongAccumulator instance = null; - public static Accumulator getInstance(JavaSparkContext jsc) { + public static LongAccumulator getInstance(JavaSparkContext jsc) { if (instance == null) { synchronized (JavaDroppedWordsCounter.class) { if (instance == null) { - instance = jsc.accumulator(0, "WordsInBlacklistCounter"); + instance = jsc.sc().longAccumulator("WordsInBlacklistCounter"); } } } @@ -1472,7 +1472,7 @@ wordCounts.foreachRDD(new Function2, Time, Void>() // Get or register the blacklist Broadcast final Broadcast> blacklist = JavaWordBlacklist.getInstance(new JavaSparkContext(rdd.context())); // Get or register the droppedWordsCounter Accumulator - final Accumulator droppedWordsCounter = JavaDroppedWordsCounter.getInstance(new JavaSparkContext(rdd.context())); + final LongAccumulator droppedWordsCounter = JavaDroppedWordsCounter.getInstance(new JavaSparkContext(rdd.context())); // Use blacklist to drop words and use droppedWordsCounter to count them String counts = rdd.filter(new Function, Boolean>() { @Override @@ -1491,7 +1491,7 @@ wordCounts.foreachRDD(new Function2, Time, Void>() {% endhighlight %} -See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java). +See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java).
    {% highlight python %} @@ -1526,7 +1526,7 @@ wordCounts.foreachRDD(echo) {% endhighlight %} -See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/recoverable_network_wordcount.py). +See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/python/streaming/recoverable_network_wordcount.py).
@@ -1534,7 +1534,7 @@ See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/ma *** ## DataFrame and SQL Operations -You can easily use [DataFrames and SQL](sql-programming-guide.html) operations on streaming data. You have to create a SQLContext using the SparkContext that the StreamingContext is using. Furthermore this has to done such that it can be restarted on driver failures. This is done by creating a lazily instantiated singleton instance of SQLContext. This is shown in the following example. It modifies the earlier [word count example](#a-quick-example) to generate word counts using DataFrames and SQL. Each RDD is converted to a DataFrame, registered as a temporary table and then queried using SQL. +You can easily use [DataFrames and SQL](sql-programming-guide.html) operations on streaming data. You have to create a SparkSession using the SparkContext that the StreamingContext is using. Furthermore, this has to be done such that it can be restarted on driver failures. This is done by creating a lazily instantiated singleton instance of SparkSession. This is shown in the following example. It modifies the earlier [word count example](#a-quick-example) to generate word counts using DataFrames and SQL. Each RDD is converted to a DataFrame, registered as a temporary table and then queried using SQL.
    @@ -1546,25 +1546,25 @@ val words: DStream[String] = ... words.foreachRDD { rdd => - // Get the singleton instance of SQLContext - val sqlContext = SQLContext.getOrCreate(rdd.sparkContext) - import sqlContext.implicits._ + // Get the singleton instance of SparkSession + val spark = SparkSession.builder.config(rdd.sparkContext.getConf).getOrCreate() + import spark.implicits._ // Convert RDD[String] to DataFrame val wordsDataFrame = rdd.toDF("word") - // Register as table - wordsDataFrame.registerTempTable("words") + // Create a temporary view + wordsDataFrame.createOrReplaceTempView("words") // Do word count on DataFrame using SQL and print it val wordCountsDataFrame = - sqlContext.sql("select word, count(*) as total from words group by word") + spark.sql("select word, count(*) as total from words group by word") wordCountsDataFrame.show() } {% endhighlight %} -See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala). +See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala).
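As a side note on the Scala example above, the same aggregation can be expressed with the DataFrame API instead of a temporary view plus SQL; a small sketch (the column rename is only there to match the SQL output, and `wordsDataFrame` is the DataFrame built inside foreachRDD in that example):

```scala
// Equivalent to the spark.sql(...) query in the example above, using the DataFrame API directly.
val wordCountsDataFrame = wordsDataFrame
  .groupBy("word")
  .count()
  .withColumnRenamed("count", "total")
wordCountsDataFrame.show()
```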
    {% highlight java %} @@ -1593,8 +1593,8 @@ words.foreachRDD( @Override public Void call(JavaRDD rdd, Time time) { - // Get the singleton instance of SQLContext - SQLContext sqlContext = SQLContext.getOrCreate(rdd.context()); + // Get the singleton instance of SparkSession + SparkSession spark = SparkSession.builder().config(rdd.sparkContext().getConf()).getOrCreate(); // Convert RDD[String] to RDD[case class] to DataFrame JavaRDD rowRDD = rdd.map(new Function() { @@ -1604,14 +1604,14 @@ words.foreachRDD( return record; } }); - DataFrame wordsDataFrame = sqlContext.createDataFrame(rowRDD, JavaRow.class); + DataFrame wordsDataFrame = spark.createDataFrame(rowRDD, JavaRow.class); - // Register as table - wordsDataFrame.registerTempTable("words"); + // Creates a temporary view using the DataFrame + wordsDataFrame.createOrReplaceTempView("words"); // Do word count on table using SQL and print it DataFrame wordCountsDataFrame = - sqlContext.sql("select word, count(*) as total from words group by word"); + spark.sql("select word, count(*) as total from words group by word"); wordCountsDataFrame.show(); return null; } @@ -1619,16 +1619,19 @@ words.foreachRDD( ); {% endhighlight %} -See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java). +See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java).
    {% highlight python %} -# Lazily instantiated global instance of SQLContext -def getSqlContextInstance(sparkContext): - if ('sqlContextSingletonInstance' not in globals()): - globals()['sqlContextSingletonInstance'] = SQLContext(sparkContext) - return globals()['sqlContextSingletonInstance'] +# Lazily instantiated global instance of SparkSession +def getSparkSessionInstance(sparkConf): + if ('sparkSessionSingletonInstance' not in globals()): + globals()['sparkSessionSingletonInstance'] = SparkSession\ + .builder\ + .config(conf=sparkConf)\ + .getOrCreate() + return globals()['sparkSessionSingletonInstance'] ... @@ -1639,18 +1642,18 @@ words = ... # DStream of strings def process(time, rdd): print("========= %s =========" % str(time)) try: - # Get the singleton instance of SQLContext - sqlContext = getSqlContextInstance(rdd.context) + # Get the singleton instance of SparkSession + spark = getSparkSessionInstance(rdd.context.getConf()) # Convert RDD[String] to RDD[Row] to DataFrame rowRdd = rdd.map(lambda w: Row(word=w)) - wordsDataFrame = sqlContext.createDataFrame(rowRdd) + wordsDataFrame = spark.createDataFrame(rowRdd) - # Register as table - wordsDataFrame.registerTempTable("words") + # Creates a temporary view using the DataFrame + wordsDataFrame.createOrReplaceTempView("words") # Do word count on table using SQL and print it - wordCountsDataFrame = sqlContext.sql("select word, count(*) as total from words group by word") + wordCountsDataFrame = spark.sql("select word, count(*) as total from words group by word") wordCountsDataFrame.show() except: pass @@ -1658,7 +1661,7 @@ def process(time, rdd): words.foreachRDD(process) {% endhighlight %} -See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/sql_network_wordcount.py). +See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/python/streaming/sql_network_wordcount.py).
    @@ -1670,7 +1673,7 @@ See the [DataFrames and SQL](sql-programming-guide.html) guide to learn more abo *** ## MLlib Operations -You can also easily use machine learning algorithms provided by [MLlib](mllib-guide.html). First of all, there are streaming machine learning algorithms (e.g. [Streaming Linear Regression](mllib-linear-methods.html#streaming-linear-regression), [Streaming KMeans](mllib-clustering.html#streaming-k-means), etc.) which can simultaneously learn from the streaming data as well as apply the model on the streaming data. Beyond these, for a much larger class of machine learning algorithms, you can learn a learning model offline (i.e. using historical data) and then apply the model online on streaming data. See the [MLlib](mllib-guide.html) guide for more details. +You can also easily use machine learning algorithms provided by [MLlib](ml-guide.html). First of all, there are streaming machine learning algorithms (e.g. [Streaming Linear Regression](mllib-linear-methods.html#streaming-linear-regression), [Streaming KMeans](mllib-clustering.html#streaming-k-means), etc.) which can simultaneously learn from the streaming data as well as apply the model on the streaming data. Beyond these, for a much larger class of machine learning algorithms, you can learn a learning model offline (i.e. using historical data) and then apply the model online on streaming data. See the [MLlib](ml-guide.html) guide for more details. *** @@ -1788,7 +1791,7 @@ This example appends the word counts of network data into a file. This behavior is made simple by using `JavaStreamingContext.getOrCreate`. This is used as follows. {% highlight java %} -// Create a factory object that can create a and setup a new JavaStreamingContext +// Create a factory object that can create and setup a new JavaStreamingContext JavaStreamingContextFactory contextFactory = new JavaStreamingContextFactory() { @Override public JavaStreamingContext create() { JavaStreamingContext jssc = new JavaStreamingContext(...); // new context @@ -1892,7 +1895,7 @@ To run a Spark Streaming applications, you need to have the following. if your application uses [advanced sources](#advanced-sources) (e.g. Kafka, Flume), then you will have to package the extra artifact they link to, along with their dependencies, in the JAR that is used to deploy the application. For example, an application using `KafkaUtils` - will have to include `spark-streaming-kafka_{{site.SCALA_BINARY_VERSION}}` and all its + will have to include `spark-streaming-kafka-0-8_{{site.SCALA_BINARY_VERSION}}` and all its transitive dependencies in the application JAR. - *Configuring sufficient memory for the executors* - Since the received data must be stored in @@ -2181,6 +2184,25 @@ consistent batch processing times. Make sure you set the CMS GC on both the driv - Persist RDDs using the `OFF_HEAP` storage level. See more detail in the [Spark Programming Guide](programming-guide.html#rdd-persistence). - Use more executors with smaller heap sizes. This will reduce the GC pressure within each JVM heap. +*** + +##### Important points to remember: +{:.no_toc} +- A DStream is associated with a single receiver. For attaining read parallelism multiple receivers i.e. multiple DStreams need to be created. A receiver is run within an executor. It occupies one core. Ensure that there are enough cores for processing after receiver slots are booked i.e. `spark.cores.max` should take the receiver slots into account. The receivers are allocated to executors in a round robin fashion. 
+ +- When data is received from a stream source, the receiver creates blocks of data. A new block of data is generated every blockInterval milliseconds. N blocks of data are created during the batchInterval where N = batchInterval/blockInterval. These blocks are distributed by the BlockManager of the current executor to the block managers of other executors. After that, the Network Input Tracker running on the driver is informed about the block locations for further processing. + +- An RDD is created on the driver for the blocks created during the batchInterval. The blocks generated during the batchInterval are partitions of the RDD. Each partition is a task in Spark. blockInterval == batchInterval would mean that a single partition is created and probably it is processed locally. + +- The map tasks on the blocks are processed in the executors (one that received the block, and another where the block was replicated) that have the blocks, irrespective of the block interval, unless non-local scheduling kicks in. +A bigger blockInterval means bigger blocks. A high value of `spark.locality.wait` increases the chance of processing a block on the local node. A balance needs to be found between these two parameters to ensure that the bigger blocks are processed locally. + +- Instead of relying on batchInterval and blockInterval, you can define the number of partitions by calling `inputDstream.repartition(n)`. This reshuffles the data in the RDD randomly to create n partitions, giving greater parallelism, though at the cost of a shuffle. An RDD's processing is scheduled by the driver's job scheduler as a job. At a given point of time only one job is active, so if one job is executing, the other jobs are queued. + +- If you have two dstreams there will be two RDDs formed, and two jobs will be created which will be scheduled one after the other. To avoid this, you can union the two dstreams. This will ensure that a single unionRDD is formed for the two RDDs of the dstreams. This unionRDD is then considered as a single job. However, the partitioning of the RDDs is not impacted. + +- If the batch processing time is more than the batchInterval, then the receiver's memory will start filling up and will end up throwing exceptions (most probably BlockNotFoundException). Currently there is no way to pause the receiver. The rate of a receiver can be limited using the SparkConf configuration `spark.streaming.receiver.maxRate`. + *************************************************************************************************** *************************************************************************************************** @@ -2328,7 +2350,7 @@ The following table summarizes the semantics under failures: ### With Kafka Direct API {:.no_toc} -In Spark 1.3, we have introduced a new Kafka Direct API, which can ensure that all the Kafka data is received by Spark Streaming exactly once. Along with this, if you implement exactly-once output operation, you can achieve end-to-end exactly-once guarantees. This approach (experimental as of Spark {{site.SPARK_VERSION_SHORT}}) is further discussed in the [Kafka Integration Guide](streaming-kafka-integration.html). +In Spark 1.3, we have introduced a new Kafka Direct API, which can ensure that all the Kafka data is received by Spark Streaming exactly once. Along with this, if you implement exactly-once output operation, you can achieve end-to-end exactly-once guarantees.
This approach is further discussed in the [Kafka Integration Guide](streaming-kafka-integration.html). ## Semantics of output operations {:.no_toc} @@ -2356,61 +2378,12 @@ additional effort may be necessary to achieve exactly-once semantics. There are *************************************************************************************************** *************************************************************************************************** -# Migration Guide from 0.9.1 or below to 1.x -Between Spark 0.9.1 and Spark 1.0, there were a few API changes made to ensure future API stability. -This section elaborates the steps required to migrate your existing code to 1.0. - -**Input DStreams**: All operations that create an input stream (e.g., `StreamingContext.socketStream`, `FlumeUtils.createStream`, etc.) now returns -[InputDStream](api/scala/index.html#org.apache.spark.streaming.dstream.InputDStream) / -[ReceiverInputDStream](api/scala/index.html#org.apache.spark.streaming.dstream.ReceiverInputDStream) -(instead of DStream) for Scala, and [JavaInputDStream](api/java/index.html?org/apache/spark/streaming/api/java/JavaInputDStream.html) / -[JavaPairInputDStream](api/java/index.html?org/apache/spark/streaming/api/java/JavaPairInputDStream.html) / -[JavaReceiverInputDStream](api/java/index.html?org/apache/spark/streaming/api/java/JavaReceiverInputDStream.html) / -[JavaPairReceiverInputDStream](api/java/index.html?org/apache/spark/streaming/api/java/JavaPairReceiverInputDStream.html) -(instead of JavaDStream) for Java. This ensures that functionality specific to input streams can -be added to these classes in the future without breaking binary compatibility. -Note that your existing Spark Streaming applications should not require any change -(as these new classes are subclasses of DStream/JavaDStream) but may require recompilation with Spark 1.0. - -**Custom Network Receivers**: Since the release to Spark Streaming, custom network receivers could be defined -in Scala using the class NetworkReceiver. However, the API was limited in terms of error handling -and reporting, and could not be used from Java. Starting Spark 1.0, this class has been -replaced by [Receiver](api/scala/index.html#org.apache.spark.streaming.receiver.Receiver) which has -the following advantages. - -* Methods like `stop` and `restart` have been added to for better control of the lifecycle of a receiver. See -the [custom receiver guide](streaming-custom-receivers.html) for more details. -* Custom receivers can be implemented using both Scala and Java. - -To migrate your existing custom receivers from the earlier NetworkReceiver to the new Receiver, you have -to do the following. - -* Make your custom receiver class extend -[`org.apache.spark.streaming.receiver.Receiver`](api/scala/index.html#org.apache.spark.streaming.receiver.Receiver) -instead of `org.apache.spark.streaming.dstream.NetworkReceiver`. -* Earlier, a BlockGenerator object had to be created by the custom receiver, to which received data was -added for being stored in Spark. It had to be explicitly started and stopped from `onStart()` and `onStop()` -methods. The new Receiver class makes this unnecessary as it adds a set of methods named `store()` -that can be called to store the data in Spark. So, to migrate your custom network receiver, remove any -BlockGenerator object (does not exist any more in Spark 1.0 anyway), and use `store(...)` methods on -received data. 
- -**Actor-based Receivers**: The Actor-based Receiver APIs have been moved to [DStream Akka](https://github.com/spark-packages/dstream-akka). -Please refer to the project for more details. - -*************************************************************************************************** -*************************************************************************************************** - # Where to Go from Here * Additional guides - [Kafka Integration Guide](streaming-kafka-integration.html) - [Kinesis Integration Guide](streaming-kinesis-integration.html) - [Custom Receiver Guide](streaming-custom-receivers.html) -* External DStream data sources: - - [DStream MQTT](https://github.com/spark-packages/dstream-mqtt) - - [DStream Twitter](https://github.com/spark-packages/dstream-twitter) - - [DStream Akka](https://github.com/spark-packages/dstream-akka) - - [DStream ZeroMQ](https://github.com/spark-packages/dstream-zeromq) +* Third-party DStream data sources can be found in [Third Party Projects](https://cwiki.apache.org/confluence/display/SPARK/Third+Party+Projects) * API documentation - Scala docs * [StreamingContext](api/scala/index.html#org.apache.spark.streaming.StreamingContext) and diff --git a/docs/structured-streaming-kafka-integration.md b/docs/structured-streaming-kafka-integration.md new file mode 100644 index 0000000000000..a6c3b3a9024d8 --- /dev/null +++ b/docs/structured-streaming-kafka-integration.md @@ -0,0 +1,259 @@ +--- +layout: global +title: Structured Streaming + Kafka Integration Guide (Kafka broker version 0.10.0 or higher) +--- + +Structured Streaming integration for Kafka 0.10 to poll data from Kafka. + +### Linking +For Scala/Java applications using SBT/Maven project definitions, link your application with the following artifact: + + groupId = org.apache.spark + artifactId = spark-sql-kafka-0-10_{{site.SCALA_BINARY_VERSION}} + version = {{site.SPARK_VERSION_SHORT}} + +For Python applications, you need to add this above library and its dependencies when deploying your +application. See the [Deploying](#deploying) subsection below. + +### Creating a Kafka Source Stream + +
    +
    + + // Subscribe to 1 topic + val ds1 = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("subscribe", "topic1") + .load() + ds1.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + + // Subscribe to multiple topics + val ds2 = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("subscribe", "topic1,topic2") + .load() + ds2.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + + // Subscribe to a pattern + val ds3 = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("subscribePattern", "topic.*") + .load() + ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + +
    +
+ + // Subscribe to 1 topic + Dataset<Row> ds1 = spark + .readStream() + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("subscribe", "topic1") + .load(); + ds1.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)"); + + // Subscribe to multiple topics + Dataset<Row> ds2 = spark + .readStream() + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("subscribe", "topic1,topic2") + .load(); + ds2.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)"); + + // Subscribe to a pattern + Dataset<Row> ds3 = spark + .readStream() + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("subscribePattern", "topic.*") + .load(); + ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)"); +
    +
+ + # Subscribe to 1 topic + ds1 = spark \ + .readStream \ + .format("kafka") \ + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ + .option("subscribe", "topic1") \ + .load() + ds1.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + + # Subscribe to multiple topics + ds2 = spark \ + .readStream \ + .format("kafka") \ + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ + .option("subscribe", "topic1,topic2") \ + .load() + ds2.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + + # Subscribe to a pattern + ds3 = spark \ + .readStream \ + .format("kafka") \ + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \ + .option("subscribePattern", "topic.*") \ + .load() + ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") +
    +
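The snippets above only define the streaming DataFrame. As an illustration (not part of the guide's example), one way to actually run a query over the `ds1` stream is to attach a sink and start it; the console sink and append output mode here are assumptions for demonstration only.

```scala
// Illustrative only: run the ds1 stream defined above, printing the cast key/value pairs to the console.
val query = ds1
  .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
  .writeStream
  .format("console")
  .outputMode("append")
  .start()

query.awaitTermination()  // block until the streaming query is stopped
```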
    + +Each row in the source has the following schema: +
+<table class="table">
+  <tr><th>Column</th><th>Type</th></tr>
+  <tr><td>key</td><td>binary</td></tr>
+  <tr><td>value</td><td>binary</td></tr>
+  <tr><td>topic</td><td>string</td></tr>
+  <tr><td>partition</td><td>int</td></tr>
+  <tr><td>offset</td><td>long</td></tr>
+  <tr><td>timestamp</td><td>long</td></tr>
+  <tr><td>timestampType</td><td>int</td></tr>
+</table>
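For example, a query will often project just a few of these columns; a small sketch using the schema above (it assumes `ds1` from the Scala example earlier, and that the message payload is a UTF-8 string):

```scala
// Sketch: project a few of the columns listed above from the Kafka source.
import org.apache.spark.sql.functions.col

val projected = ds1.select(
  col("topic"),
  col("partition"),
  col("offset"),
  col("value").cast("string").as("value_str")  // assumes string-encoded payloads
)
```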
+The following options must be set for the Kafka source.
+
+<table class="table">
+  <tr><th>Option</th><th>value</th><th>meaning</th></tr>
+  <tr><td>assign</td><td>json string {"topicA":[0,1],"topicB":[2,4]}</td>
+    <td>Specific TopicPartitions to consume. Only one of "assign", "subscribe" or "subscribePattern" options can be specified for the Kafka source.</td></tr>
+  <tr><td>subscribe</td><td>A comma-separated list of topics</td>
+    <td>The topic list to subscribe to. Only one of "assign", "subscribe" or "subscribePattern" options can be specified for the Kafka source.</td></tr>
+  <tr><td>subscribePattern</td><td>Java regex string</td>
+    <td>The pattern used to subscribe to topic(s). Only one of "assign", "subscribe" or "subscribePattern" options can be specified for the Kafka source.</td></tr>
+  <tr><td>kafka.bootstrap.servers</td><td>A comma-separated list of host:port</td>
+    <td>The Kafka "bootstrap.servers" configuration.</td></tr>
+</table>
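To make the `assign` format above concrete, a minimal sketch of passing the JSON partition list (the topic names and partition numbers are the placeholders from the table):

```scala
// Sketch: consume only specific partitions via the "assign" option described above.
val assigned = spark.readStream
  .format("kafka")
  .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
  .option("assign", """{"topicA":[0,1],"topicB":[2,4]}""")
  .load()
```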
+The following configurations are optional:
+
+<table class="table">
+  <tr><th>Option</th><th>value</th><th>default</th><th>meaning</th></tr>
+  <tr><td>startingOffsets</td><td>earliest, latest, or json string {"topicA":{"0":23,"1":-1},"topicB":{"0":-2}}</td><td>latest</td>
+    <td>The start point when a query is started, either "earliest" which is from the earliest offsets, "latest" which is just from the latest offsets, or a json string specifying a starting offset for each TopicPartition. In the json, -2 as an offset can be used to refer to earliest, -1 to latest. Note: this only applies when a new streaming query is started; resuming will always pick up from where the query left off. Newly discovered partitions during a query will start at earliest.</td></tr>
+  <tr><td>failOnDataLoss</td><td>true or false</td><td>true</td>
+    <td>Whether to fail the query when it's possible that data is lost (e.g., topics are deleted, or offsets are out of range). This may be a false alarm. You can disable it when it doesn't work as you expected.</td></tr>
+  <tr><td>kafkaConsumer.pollTimeoutMs</td><td>long</td><td>512</td>
+    <td>The timeout in milliseconds to poll data from Kafka in executors.</td></tr>
+  <tr><td>fetchOffset.numRetries</td><td>int</td><td>3</td>
+    <td>Number of times to retry before giving up fetching the latest Kafka offsets.</td></tr>
+  <tr><td>fetchOffset.retryIntervalMs</td><td>long</td><td>10</td>
+    <td>Milliseconds to wait before retrying to fetch Kafka offsets.</td></tr>
+  <tr><td>maxOffsetsPerTrigger</td><td>long</td><td>none</td>
+    <td>Rate limit on the maximum number of offsets processed per trigger interval. The specified total number of offsets will be proportionally split across topicPartitions of different volume.</td></tr>
+</table>
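A short sketch combining a few of the optional settings from the table above (all values are illustrative):

```scala
// Sketch: a Kafka source with some of the optional configurations above applied.
val tuned = spark.readStream
  .format("kafka")
  .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
  .option("subscribe", "topic1")
  .option("startingOffsets", "earliest")   // or a per-partition JSON string, as described above
  .option("failOnDataLoss", "false")
  .option("maxOffsetsPerTrigger", 10000L)
  .load()
```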
+ +Kafka's own configurations can be set via `DataStreamReader.option` with `kafka.` prefix, e.g., +`stream.option("kafka.bootstrap.servers", "host:port")`. For possible kafkaParams, see +[Kafka consumer config docs](http://kafka.apache.org/documentation.html#newconsumerconfigs). + +Note that the following Kafka params cannot be set and the Kafka source will throw an exception: +- **group.id**: Kafka source will create a unique group id for each query automatically. +- **auto.offset.reset**: Set the source option `startingOffsets` to specify + where to start instead. Structured Streaming manages which offsets are consumed internally, rather + than relying on the Kafka consumer to do it. This will ensure that no data is missed when new + topics/partitions are dynamically subscribed. Note that `startingOffsets` only applies when a new + Streaming query is started, and that resuming will always pick up from where the query left off. +- **key.deserializer**: Keys are always deserialized as byte arrays with ByteArrayDeserializer. Use + DataFrame operations to explicitly deserialize the keys. +- **value.deserializer**: Values are always deserialized as byte arrays with ByteArrayDeserializer. + Use DataFrame operations to explicitly deserialize the values. +- **enable.auto.commit**: Kafka source doesn't commit any offset. +- **interceptor.classes**: Kafka source always reads keys and values as byte arrays. It's not safe to + use ConsumerInterceptor as it may break the query. + +### Deploying + +As with any Spark application, `spark-submit` is used to launch your application. `spark-sql-kafka-0-10_{{site.SCALA_BINARY_VERSION}}` +and its dependencies can be directly added to `spark-submit` using `--packages`, such as, + + ./bin/spark-submit --packages org.apache.spark:spark-sql-kafka-0-10_{{site.SCALA_BINARY_VERSION}}:{{site.SPARK_VERSION_SHORT}} ... + +See [Application Submission Guide](submitting-applications.html) for more details about submitting +applications with external dependencies. diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md new file mode 100644 index 0000000000000..be730b8c1c2ee --- /dev/null +++ b/docs/structured-streaming-programming-guide.md @@ -0,0 +1,1154 @@ +--- +layout: global +displayTitle: Structured Streaming Programming Guide [Alpha] +title: Structured Streaming Programming Guide +--- + +* This will become a table of contents (this text will be scraped). +{:toc} + +# Overview +Structured Streaming is a scalable and fault-tolerant stream processing engine built on the Spark SQL engine. You can express your streaming computation the same way you would express a batch computation on static data. The Spark SQL engine will take care of running it incrementally and continuously and updating the final result as streaming data continues to arrive. You can use the [Dataset/DataFrame API](sql-programming-guide.html) in Scala, Java or Python to express streaming aggregations, event-time windows, stream-to-batch joins, etc. The computation is executed on the same optimized Spark SQL engine. Finally, the system ensures end-to-end exactly-once fault-tolerance guarantees through checkpointing and Write Ahead Logs. In short, *Structured Streaming provides fast, scalable, fault-tolerant, end-to-end exactly-once stream processing without the user having to reason about streaming.* + +**Spark 2.0 is the ALPHA RELEASE of Structured Streaming** and the APIs are still experimental.
In this guide, we are going to walk you through the programming model and the APIs. First, let's start with a simple example - a streaming word count. + +# Quick Example +Let’s say you want to maintain a running word count of text data received from a data server listening on a TCP socket. Let’s see how you can express this using Structured Streaming. You can see the full code in +[Scala]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredNetworkWordCount.scala)/ +[Java]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredNetworkWordCount.java)/ +[Python]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/python/sql/streaming/structured_network_wordcount.py). And if you +[download Spark](http://spark.apache.org/downloads.html), you can directly run the example. In any case, let’s walk through the example step-by-step and understand how it works. First, we have to import the necessary classes and create a local SparkSession, the starting point of all functionalities related to Spark. + +
    +
    + +{% highlight scala %} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.SparkSession + +val spark = SparkSession + .builder + .appName("StructuredNetworkWordCount") + .getOrCreate() + +import spark.implicits._ +{% endhighlight %} + +
    +
    + +{% highlight java %} +import org.apache.spark.api.java.function.FlatMapFunction; +import org.apache.spark.sql.*; +import org.apache.spark.sql.streaming.StreamingQuery; + +import java.util.Arrays; +import java.util.Iterator; + +SparkSession spark = SparkSession + .builder() + .appName("JavaStructuredNetworkWordCount") + .getOrCreate(); +{% endhighlight %} + +
    +
+
+{% highlight python %}
+from pyspark.sql import SparkSession
+from pyspark.sql.functions import explode
+from pyspark.sql.functions import split
+
+spark = SparkSession\
+    .builder\
+    .appName("StructuredNetworkWordCount")\
+    .getOrCreate()
+{% endhighlight %}
+
    +
    + +Next, let’s create a streaming DataFrame that represents text data received from a server listening on localhost:9999, and transform the DataFrame to calculate word counts. + +
    +
+
+{% highlight scala %}
+// Create DataFrame representing the stream of input lines from connection to localhost:9999
+val lines = spark.readStream
+  .format("socket")
+  .option("host", "localhost")
+  .option("port", 9999)
+  .load()
+
+// Split the lines into words
+val words = lines.as[String].flatMap(_.split(" "))
+
+// Generate running word count
+val wordCounts = words.groupBy("value").count()
+{% endhighlight %}
+
+This `lines` DataFrame represents an unbounded table containing the streaming text data. This table contains one column of strings named "value", and each line in the streaming text data becomes a row in the table. Note that this is not currently receiving any data as we are just setting up the transformation, and have not yet started it. Next, we have converted the DataFrame to a Dataset of String using `.as[String]`, so that we can apply the `flatMap` operation to split each line into multiple words. The resultant `words` Dataset contains all the words. Finally, we have defined the `wordCounts` DataFrame by grouping by the unique values in the Dataset and counting them. Note that this is a streaming DataFrame which represents the running word counts of the stream.
    +
+
+{% highlight java %}
+// Create DataFrame representing the stream of input lines from connection to localhost:9999
+Dataset<Row> lines = spark
+  .readStream()
+  .format("socket")
+  .option("host", "localhost")
+  .option("port", 9999)
+  .load();
+
+// Split the lines into words
+Dataset<String> words = lines
+  .as(Encoders.STRING())
+  .flatMap(
+    new FlatMapFunction<String, String>() {
+      @Override
+      public Iterator<String> call(String x) {
+        return Arrays.asList(x.split(" ")).iterator();
+      }
+    }, Encoders.STRING());
+
+// Generate running word count
+Dataset<Row> wordCounts = words.groupBy("value").count();
+{% endhighlight %}
+
+This `lines` DataFrame represents an unbounded table containing the streaming text data. This table contains one column of strings named "value", and each line in the streaming text data becomes a row in the table. Note that this is not currently receiving any data as we are just setting up the transformation, and have not yet started it. Next, we have converted the DataFrame to a Dataset of String using `.as(Encoders.STRING())`, so that we can apply the `flatMap` operation to split each line into multiple words. The resultant `words` Dataset contains all the words. Finally, we have defined the `wordCounts` DataFrame by grouping by the unique values in the Dataset and counting them. Note that this is a streaming DataFrame which represents the running word counts of the stream.
    +
+
+{% highlight python %}
+# Create DataFrame representing the stream of input lines from connection to localhost:9999
+lines = spark\
+    .readStream\
+    .format('socket')\
+    .option('host', 'localhost')\
+    .option('port', 9999)\
+    .load()
+
+# Split the lines into words
+words = lines.select(
+    explode(
+        split(lines.value, ' ')
+    ).alias('word')
+)
+
+# Generate running word count
+wordCounts = words.groupBy('word').count()
+{% endhighlight %}
+
+This `lines` DataFrame represents an unbounded table containing the streaming text data. This table contains one column of strings named "value", and each line in the streaming text data becomes a row in the table. Note that this is not currently receiving any data as we are just setting up the transformation, and have not yet started it. Next, we have used two built-in SQL functions - split and explode, to split each line into multiple rows with a word each. In addition, we use the function `alias` to name the new column as "word". Finally, we have defined the `wordCounts` DataFrame by grouping by the unique values in the Dataset and counting them. Note that this is a streaming DataFrame which represents the running word counts of the stream.
    +
+
+We have now set up the query on the streaming data. All that is left is to actually start receiving data and computing the counts. To do this, we set it up to print the complete set of counts (specified by `outputMode("complete")`) to the console every time they are updated, and then start the streaming computation using `start()`.
    +
    + +{% highlight scala %} +// Start running the query that prints the running counts to the console +val query = wordCounts.writeStream + .outputMode("complete") + .format("console") + .start() + +query.awaitTermination() +{% endhighlight %} + +
    +
    + +{% highlight java %} +// Start running the query that prints the running counts to the console +StreamingQuery query = wordCounts.writeStream() + .outputMode("complete") + .format("console") + .start(); + +query.awaitTermination(); +{% endhighlight %} + +
    +
+
+{% highlight python %}
+# Start running the query that prints the running counts to the console
+query = wordCounts\
+    .writeStream\
+    .outputMode('complete')\
+    .format('console')\
+    .start()
+
+query.awaitTermination()
+{% endhighlight %}
+
    +
    + +After this code is executed, the streaming computation will have started in the background. The `query` object is a handle to that active streaming query, and we have decided to wait for the termination of the query using `query.awaitTermination()` to prevent the process from exiting while the query is active. + +To actually execute this example code, you can either compile the code in your own +[Spark application](quick-start.html#self-contained-applications), or simply +[run the example](index.html#running-the-examples-and-shell) once you have downloaded Spark. We are showing the latter. You will first need to run Netcat (a small utility found in most Unix-like systems) as a data server by using + + + $ nc -lk 9999 + +Then, in a different terminal, you can start the example by using + +
    +
    +{% highlight bash %} +$ ./bin/run-example org.apache.spark.examples.sql.streaming.StructuredNetworkWordCount localhost 9999 +{% endhighlight %} +
    +
    +{% highlight bash %} +$ ./bin/run-example org.apache.spark.examples.sql.streaming.JavaStructuredNetworkWordCount localhost 9999 +{% endhighlight %} +
    +
    +{% highlight bash %} +$ ./bin/spark-submit examples/src/main/python/sql/streaming/structured_network_wordcount.py localhost 9999 +{% endhighlight %} +
    +
    + +Then, any lines typed in the terminal running the netcat server will be counted and printed on screen every second. It will look something like the following. + + + + + +
+{% highlight bash %}
+# TERMINAL 1:
+# Running Netcat
+
+$ nc -lk 9999
+apache spark
+apache hadoop
+
+...
+{% endhighlight %}
+
    + +
    +{% highlight bash %} +# TERMINAL 2: RUNNING StructuredNetworkWordCount + +$ ./bin/run-example org.apache.spark.examples.sql.streaming.StructuredNetworkWordCount localhost 9999 + +------------------------------------------- +Batch: 0 +------------------------------------------- ++------+-----+ +| value|count| ++------+-----+ +|apache| 1| +| spark| 1| ++------+-----+ + +------------------------------------------- +Batch: 1 +------------------------------------------- ++------+-----+ +| value|count| ++------+-----+ +|apache| 2| +| spark| 1| +|hadoop| 1| ++------+-----+ +... +{% endhighlight %} +
    + +
    +{% highlight bash %} +# TERMINAL 2: RUNNING JavaStructuredNetworkWordCount + +$ ./bin/run-example org.apache.spark.examples.sql.streaming.JavaStructuredNetworkWordCount localhost 9999 + +------------------------------------------- +Batch: 0 +------------------------------------------- ++------+-----+ +| value|count| ++------+-----+ +|apache| 1| +| spark| 1| ++------+-----+ + +------------------------------------------- +Batch: 1 +------------------------------------------- ++------+-----+ +| value|count| ++------+-----+ +|apache| 2| +| spark| 1| +|hadoop| 1| ++------+-----+ +... +{% endhighlight %} +
    +
    +{% highlight bash %} +# TERMINAL 2: RUNNING structured_network_wordcount.py + +$ ./bin/spark-submit examples/src/main/python/sql/streaming/structured_network_wordcount.py localhost 9999 + +------------------------------------------- +Batch: 0 +------------------------------------------- ++------+-----+ +| value|count| ++------+-----+ +|apache| 1| +| spark| 1| ++------+-----+ + +------------------------------------------- +Batch: 1 +------------------------------------------- ++------+-----+ +| value|count| ++------+-----+ +|apache| 2| +| spark| 1| +|hadoop| 1| ++------+-----+ +... +{% endhighlight %} +
    +
    +
+
+
+# Programming Model
+
+The key idea in Structured Streaming is to treat a live data stream as a
+table that is being continuously appended. This leads to a new stream
+processing model that is very similar to a batch processing model. You
+express your streaming computation as a standard batch-like query, as if on a static
+table, and Spark runs it as an *incremental* query on the *unbounded* input
+table. Let’s understand this model in more detail.
+
+## Basic Concepts
+Consider the input data stream as the "Input Table". Every data item that is
+arriving on the stream is like a new row being appended to the Input Table.
+
+![Stream as a Table](img/structured-streaming-stream-as-a-table.png "Stream as a Table")
+
+A query on the input will generate the "Result Table". Every trigger interval (say, every 1 second), new rows get appended to the Input Table, which eventually updates the Result Table. Whenever the result table gets updated, we would want to write the changed result rows to an external sink.
+
+![Model](img/structured-streaming-model.png)
+
+The "Output" is defined as what gets written out to the external storage. The output can be defined in one of several modes:
+
+  - *Complete Mode* - The entire updated Result Table will be written to the external storage. It is up to the storage connector to decide how to handle writing of the entire table.
+
+  - *Append Mode* - Only the new rows appended in the Result Table since the last trigger will be written to the external storage. This is applicable only on the queries where existing rows in the Result Table are not expected to change.
+
+  - *Update Mode* - Only the rows that were updated in the Result Table since the last trigger will be written to the external storage (not available yet in Spark 2.0). Note that this is different from the Complete Mode in that this mode does not output the rows that are not changed.
+
+Note that each mode is applicable on certain types of queries. This is discussed in detail [later](#output-modes).
+
+To illustrate the use of this model, let’s understand the model in the context of
+the [Quick Example](#quick-example) above. The first `lines` DataFrame is the input table, and
+the final `wordCounts` DataFrame is the result table. Note that the query on the
+streaming `lines` DataFrame to generate `wordCounts` is *exactly the same* as
+it would be for a static DataFrame. However, when this query is started, Spark
+will continuously check for new data from the socket connection. If there is
+new data, Spark will run an "incremental" query that combines the previous
+running counts with the new data to compute updated counts, as shown below.
+
+![Model](img/structured-streaming-example-model.png)
+
+This model is significantly different from many other stream processing
+engines. Many streaming systems require the user to maintain running
+aggregations themselves, thus having to reason about fault-tolerance and
+data consistency (at-least-once, or at-most-once, or exactly-once). In this
+model, Spark is responsible for updating the Result Table when there is new
+data, thus relieving the users from reasoning about it. As an example, let’s
+see how this model handles event-time based processing and late arriving data.
+
+## Handling Event-time and Late Data
+Event-time is the time embedded in the data itself. For many applications, you may want to operate on this event-time. For example, if you want to get the number of events generated by IoT devices every minute, then you probably want to use the time when the data was generated (that is, event-time in the data), rather than the time Spark receives them. This event-time is very naturally expressed in this model -- each event from the devices is a row in the table, and event-time is a column value in the row. This allows window-based aggregations (e.g. number of events every minute) to be just a special type of grouping and aggregation on the event-time column -- each time window is a group and each row can belong to multiple windows/groups. Therefore, such event-time-window-based aggregation queries can be defined consistently on both a static dataset (e.g. from collected device event logs) as well as on a data stream, making the life of the user much easier.
+
+Furthermore, this model naturally handles data that has arrived later than expected based on its event-time. Since Spark is updating the Result Table, it has full control over updating/cleaning up the aggregates when there is late data. While not yet implemented in Spark 2.0, event-time watermarking will be used to manage this data. This is explained in more detail later in the [Window Operations](#window-operations-on-event-time) section.
+
+## Fault Tolerance Semantics
+Delivering end-to-end exactly-once semantics was one of the key goals behind the design of Structured Streaming. To achieve that, we have designed the Structured Streaming sources, the sinks and the execution engine to reliably track the exact progress of the processing so that it can handle any kind of failure by restarting and/or reprocessing. Every streaming source is assumed to have offsets (similar to Kafka offsets, or Kinesis sequence numbers)
+to track the read position in the stream. The engine uses checkpointing and write ahead logs to record the offset range of the data being processed in each trigger. The streaming sinks are designed to be idempotent for handling reprocessing. Together, using replayable sources and idempotent sinks, Structured Streaming can ensure **end-to-end exactly-once semantics** under any failure.
+
+# API using Datasets and DataFrames
+Since Spark 2.0, DataFrames and Datasets can represent static, bounded data, as well as streaming, unbounded data. Similar to static Datasets/DataFrames, you can use the common entry point `SparkSession` ([Scala](api/scala/index.html#org.apache.spark.sql.SparkSession)/
+[Java](api/java/org/apache/spark/sql/SparkSession.html)/
+[Python](api/python/pyspark.sql.html#pyspark.sql.SparkSession) docs) to create streaming DataFrames/Datasets from streaming sources, and apply the same operations on them as static DataFrames/Datasets. If you are not familiar with Datasets/DataFrames, you are strongly advised to familiarize yourself with them using the
+[DataFrame/Dataset Programming Guide](sql-programming-guide.html).
+
+## Creating streaming DataFrames and streaming Datasets
+Streaming DataFrames can be created through the `DataStreamReader` interface
+([Scala](api/scala/index.html#org.apache.spark.sql.streaming.DataStreamReader)/
+[Java](api/java/org/apache/spark/sql/streaming/DataStreamReader.html)/
+[Python](api/python/pyspark.sql.html#pyspark.sql.streaming.DataStreamReader) docs) returned by `SparkSession.readStream()`. Similar to the read interface for creating a static DataFrame, you can specify the details of the source – data format, schema, options, etc.
+
+#### Data Sources
+In Spark 2.0, there are a few built-in sources.
+
+  - **File source** - Reads files written in a directory as a stream of data. Supported file formats are text, csv, json, parquet. See the docs of the DataStreamReader interface for a more up-to-date list, and supported options for each file format. Note that the files must be atomically placed in the given directory, which, in most file systems, can be achieved by file move operations.
+
+  - **Kafka source** - Polls data from Kafka. It's compatible with Kafka broker versions 0.10.0 or higher. See the [Kafka Integration Guide](structured-streaming-kafka-integration.html) for more details.
+
+  - **Socket source (for testing)** - Reads UTF8 text data from a socket connection. The listening server socket is at the driver. Note that this should be used only for testing as this does not provide end-to-end fault-tolerance guarantees.
+
+Here are some examples.
    +
+
+{% highlight scala %}
+import org.apache.spark.sql.types._  // for StructType below
+
+val spark: SparkSession = ...
+
+// Read text from socket
+val socketDF = spark
+  .readStream
+  .format("socket")
+  .option("host", "localhost")
+  .option("port", 9999)
+  .load()
+
+socketDF.isStreaming  // Returns true for DataFrames that have streaming sources
+
+socketDF.printSchema
+
+// Read all the csv files written atomically in a directory
+val userSchema = new StructType().add("name", "string").add("age", "integer")
+val csvDF = spark
+  .readStream
+  .option("sep", ";")
+  .schema(userSchema)         // Specify schema of the csv files
+  .csv("/path/to/directory")  // Equivalent to format("csv").load("/path/to/directory")
+{% endhighlight %}
+
    +
+
+{% highlight java %}
+SparkSession spark = ...;
+
+// Read text from socket
+Dataset<Row> socketDF = spark
+  .readStream()
+  .format("socket")
+  .option("host", "localhost")
+  .option("port", 9999)
+  .load();
+
+socketDF.isStreaming();  // Returns true for DataFrames that have streaming sources
+
+socketDF.printSchema();
+
+// Read all the csv files written atomically in a directory
+StructType userSchema = new StructType().add("name", "string").add("age", "integer");
+Dataset<Row> csvDF = spark
+  .readStream()
+  .option("sep", ";")
+  .schema(userSchema)          // Specify schema of the csv files
+  .csv("/path/to/directory");  // Equivalent to format("csv").load("/path/to/directory")
+{% endhighlight %}
+
    +
+
+{% highlight python %}
+from pyspark.sql.types import StructType  # for the csv schema below
+
+spark = SparkSession. ...
+
+# Read text from socket
+socketDF = spark \
+    .readStream \
+    .format("socket") \
+    .option("host", "localhost") \
+    .option("port", 9999) \
+    .load()
+
+socketDF.isStreaming  # Returns True for DataFrames that have streaming sources
+
+socketDF.printSchema()
+
+# Read all the csv files written atomically in a directory
+userSchema = StructType().add("name", "string").add("age", "integer")
+csvDF = spark \
+    .readStream \
+    .option("sep", ";") \
+    .schema(userSchema) \
+    .csv("/path/to/directory")  # Equivalent to format("csv").load("/path/to/directory")
+{% endhighlight %}
+
    +
+
+These examples generate streaming DataFrames that are untyped, meaning that the schema of the DataFrame is not checked at compile time, only checked at runtime when the query is submitted. Some operations like `map`, `flatMap`, etc. need the type to be known at compile time. To do those, you can convert these untyped streaming DataFrames to typed streaming Datasets using the same methods as static DataFrame. See the [SQL Programming Guide](sql-programming-guide.html) for more details. Additionally, more details on the supported streaming sources are discussed later in the document.
+
+### Schema inference and partitioning of streaming DataFrames/Datasets
+
+By default, Structured Streaming from file based sources requires you to specify the schema, rather than rely on Spark to infer it automatically. This restriction ensures a consistent schema will be used for the streaming query, even in the case of failures. For ad-hoc use cases, you can re-enable schema inference by setting `spark.sql.streaming.schemaInference` to `true`, as shown in the sketch below.
+
+Partition discovery does occur when subdirectories that are named `/key=value/` are present and listing will automatically recurse into these directories. If these columns appear in the user provided schema, they will be filled in by Spark based on the path of the file being read. The directories that make up the partitioning scheme must be present when the query starts and must remain static. For example, it is okay to add `/data/year=2016/` when `/data/year=2015/` was present, but it is invalid to change the partitioning column (i.e. by creating the directory `/data/date=2016-04-17/`).
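+
+As an editorial illustration of the schema inference setting above (a one-line sketch; `spark` is the usual `SparkSession`):
+
+{% highlight scala %}
+// Opt back in to schema inference for file-based streaming sources (ad-hoc use only).
+spark.conf.set("spark.sql.streaming.schemaInference", "true")
+{% endhighlight %}
+
+## Operations on streaming DataFrames/Datasets
+You can apply all kinds of operations on streaming DataFrames/Datasets – ranging from untyped, SQL-like operations (e.g. `select`, `where`, `groupBy`), to typed RDD-like operations (e.g. `map`, `filter`, `flatMap`). See the [SQL programming guide](sql-programming-guide.html) for more details. Let’s take a look at a few example operations that you can use.
+
+### Basic Operations - Selection, Projection, Aggregation
+Most of the common operations on DataFrame/Dataset are supported for streaming. The few operations that are not supported are [discussed later](#unsupported-operations) in this section.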
    +
+
+{% highlight scala %}
+// Note: `type` is a reserved word in Scala, so it must be written in backticks here.
+case class DeviceData(device: String, `type`: String, signal: Double, time: DateTime)
+
+val df: DataFrame = ... // streaming DataFrame with IOT device data with schema { device: string, type: string, signal: double, time: string }
+val ds: Dataset[DeviceData] = df.as[DeviceData]  // streaming Dataset with IOT device data
+
+// Select the devices which have signal more than 10
+df.select("device").where("signal > 10")  // using untyped APIs
+ds.filter(_.signal > 10).map(_.device)    // using typed APIs
+
+// Running count of the number of updates for each device type
+df.groupBy("type").count()                // using untyped API
+
+// Running average signal for each device type
+import org.apache.spark.sql.expressions.scalalang.typed
+ds.groupByKey(_.`type`).agg(typed.avg(_.signal))  // using typed API
+{% endhighlight %}
+
    +
+
+{% highlight java %}
+import org.apache.spark.api.java.function.*;
+import org.apache.spark.sql.*;
+import org.apache.spark.sql.expressions.javalang.typed;
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
+
+public class DeviceData {
+  private String device;
+  private String type;
+  private Double signal;
+  private java.sql.Date time;
+  ...
+  // Getter and setter methods for each field
+}
+
+Dataset<Row> df = ...; // streaming DataFrame with IOT device data with schema { device: string, type: string, signal: double, time: DateType }
+Dataset<DeviceData> ds = df.as(ExpressionEncoder.javaBean(DeviceData.class)); // streaming Dataset with IOT device data
+
+// Select the devices which have signal more than 10
+df.select("device").where("signal > 10"); // using untyped APIs
+ds.filter(new FilterFunction<DeviceData>() { // using typed APIs
+  @Override
+  public boolean call(DeviceData value) throws Exception {
+    return value.getSignal() > 10;
+  }
+}).map(new MapFunction<DeviceData, String>() {
+  @Override
+  public String call(DeviceData value) throws Exception {
+    return value.getDevice();
+  }
+}, Encoders.STRING());
+
+// Running count of the number of updates for each device type
+df.groupBy("type").count(); // using untyped API
+
+// Running average signal for each device type
+ds.groupByKey(new MapFunction<DeviceData, String>() { // using typed API
+  @Override
+  public String call(DeviceData value) throws Exception {
+    return value.getType();
+  }
+}, Encoders.STRING()).agg(typed.avg(new MapFunction<DeviceData, Double>() {
+  @Override
+  public Double call(DeviceData value) throws Exception {
+    return value.getSignal();
+  }
+}));
+{% endhighlight %}
+
+
    +
    + +{% highlight python %} + +df = ... # streaming DataFrame with IOT device data with schema { device: string, type: string, signal: double, time: DateType } + +# Select the devices which have signal more than 10 +df.select("device").where("signal > 10") + +# Running count of the number of updates for each device type +df.groupBy("type").count() +{% endhighlight %} +
    +
+
+### Window Operations on Event Time
+Aggregations over a sliding event-time window are straightforward with Structured Streaming. Window-based aggregations are conceptually very similar to grouped aggregations. In a grouped aggregation, aggregate values (e.g. counts) are maintained for each unique value in the user-specified grouping column. In the case of window-based aggregations, aggregate values are maintained for each window the event-time of a row falls into. Let's understand this with an illustration.
+
+Imagine our [quick example](#quick-example) is modified and the stream now contains lines along with the time when the line was generated. Instead of running word counts, we want to count words within 10 minute windows, updating every 5 minutes. That is, word counts for words received in the 10 minute windows 12:00 - 12:10, 12:05 - 12:15, 12:10 - 12:20, etc. Note that 12:00 - 12:10 means data that arrived after 12:00 but before 12:10. Now, consider a word that was received at 12:07. This word should increment the counts corresponding to two windows, 12:00 - 12:10 and 12:05 - 12:15. So the counts will be indexed by both the grouping key (i.e. the word) and the window (which can be calculated from the event-time).
+
+The result tables would look something like the following.
+
+![Window Operations](img/structured-streaming-window.png)
+
+Since this windowing is similar to grouping, in code, you can use `groupBy()` and `window()` operations to express windowed aggregations. You can see the full code for the below examples in
+[Scala]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredNetworkWordCountWindowed.scala)/
+[Java]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredNetworkWordCountWindowed.java)/
+[Python]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py).
    +
    + +{% highlight scala %} +import spark.implicits._ + +val words = ... // streaming DataFrame of schema { timestamp: Timestamp, word: String } + +// Group the data by window and word and compute the count of each group +val windowedCounts = words.groupBy( + window($"timestamp", "10 minutes", "5 minutes"), + $"word" +).count() +{% endhighlight %} + +
    +
+
+{% highlight java %}
+Dataset<Row> words = ... // streaming DataFrame of schema { timestamp: Timestamp, word: String }
+
+// Group the data by window and word and compute the count of each group
+Dataset<Row> windowedCounts = words.groupBy(
+  functions.window(words.col("timestamp"), "10 minutes", "5 minutes"),
+  words.col("word")
+).count();
+{% endhighlight %}
+
    +
    +{% highlight python %} +words = ... # streaming DataFrame of schema { timestamp: Timestamp, word: String } + +# Group the data by window and word and compute the count of each group +windowedCounts = words.groupBy( + window(words.timestamp, '10 minutes', '5 minutes'), + words.word +).count() +{% endhighlight %} + +
    +
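+
+As a brief editorial aside (not part of the original walkthrough): the `window(...)` expression used above produces a grouping column named `window` that is a struct with `start` and `end` timestamp fields, so the window bounds can be projected out like any other nested column. A minimal Scala sketch, reusing `windowedCounts` from the Scala tab:
+
+{% highlight scala %}
+// Pull the window bounds out of the struct column for easier inspection.
+windowedCounts.select($"window.start", $"window.end", $"word", $"count")
+{% endhighlight %}
+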
+
+
+Now consider what happens if one of the events arrives late to the application.
+For example, consider a word that was generated at 12:04 but was received at 12:11.
+Since this windowing is based on the time in the data, the time 12:04 should be considered for windowing. This occurs naturally in our window-based grouping – the late data is automatically placed in the proper windows and the correct aggregates are updated as illustrated below.
+
+![Handling Late Data](img/structured-streaming-late-data.png)
+
+### Join Operations
+Streaming DataFrames can be joined with static DataFrames to create new streaming DataFrames. Here are a few examples.
    +
+
+{% highlight scala %}
+val staticDf = spark.read. ...
+val streamingDf = spark.readStream. ...
+
+streamingDf.join(staticDf, "type")                 // inner equi-join with a static DF
+streamingDf.join(staticDf, "type", "right_outer")  // right outer join with a static DF
+{% endhighlight %}
+
    +
+
+{% highlight java %}
+Dataset<Row> staticDf = spark.read. ...;
+Dataset<Row> streamingDf = spark.readStream. ...;
+streamingDf.join(staticDf, "type");                 // inner equi-join with a static DF
+streamingDf.join(staticDf, "type", "right_outer");  // right outer join with a static DF
+{% endhighlight %}
+
+
    +
+
+{% highlight python %}
+staticDf = spark.read. ...
+streamingDf = spark.readStream. ...
+streamingDf.join(staticDf, "type")                 # inner equi-join with a static DF
+streamingDf.join(staticDf, "type", "right_outer")  # right outer join with a static DF
+{% endhighlight %}
+
    +
+
+### Unsupported Operations
+However, note that not all of the operations applicable on static DataFrames/Datasets are supported in streaming DataFrames/Datasets yet. While some of these unsupported operations will be supported in future releases of Spark, there are others which are fundamentally hard to implement on streaming data efficiently. For example, sorting is not supported on the input streaming Dataset, as it requires keeping track of all the data received in the stream. This is therefore fundamentally hard to execute efficiently. As of Spark 2.0, some of the unsupported operations are as follows.
+
+- Multiple streaming aggregations (i.e. a chain of aggregations on a streaming DF) are not yet supported on streaming Datasets.
+
+- Limit and taking the first N rows are not supported on streaming Datasets.
+
+- Distinct operations on streaming Datasets are not supported.
+
+- Sorting operations are supported on streaming Datasets only after an aggregation and in Complete Output Mode.
+
+- Outer joins between a streaming and a static Dataset are conditionally supported.
+
+    + Full outer join with a streaming Dataset is not supported.
+
+    + Left outer join with a streaming Dataset on the right is not supported.
+
+    + Right outer join with a streaming Dataset on the left is not supported.
+
+- Joins of any kind between two streaming Datasets are not yet supported.
+
+In addition, there are some Dataset methods that will not work on streaming Datasets. They are actions that will immediately run queries and return results, which does not make sense on a streaming Dataset. Rather, that functionality can be achieved by explicitly starting a streaming query (see the next section regarding that).
+
+- `count()` - Cannot return a single count from a streaming Dataset. Instead, use `ds.groupBy.count()` which returns a streaming Dataset containing a running count.
+
+- `foreach()` - Instead use `ds.writeStream.foreach(...)` (see next section).
+
+- `show()` - Instead use the console sink (see next section).
+
+If you try any of these operations, you will see an AnalysisException like "operation XYZ is not supported with streaming DataFrames/Datasets".
+
+## Starting Streaming Queries
+Once you have defined the final result DataFrame/Dataset, all that is left is for you to start the streaming computation. To do that, you have to use the
+`DataStreamWriter` ([Scala](api/scala/index.html#org.apache.spark.sql.streaming.DataStreamWriter)/
+[Java](api/java/org/apache/spark/sql/streaming/DataStreamWriter.html)/
+[Python](api/python/pyspark.sql.html#pyspark.sql.streaming.DataStreamWriter) docs) returned through `Dataset.writeStream()`. You will have to specify one or more of the following in this interface.
+
+- *Details of the output sink:* Data format, location, etc.
+
+- *Output mode:* Specify what gets written to the output sink.
+
+- *Query name:* Optionally, specify a unique name for the query for identification.
+
+- *Trigger interval:* Optionally, specify the trigger interval. If it is not specified, the system will check for availability of new data as soon as the previous processing has completed. If a trigger time is missed because the previous processing has not completed, then the system will attempt to trigger at the next trigger point, not immediately after the processing has completed.
+
+- *Checkpoint location:* For some output sinks where the end-to-end fault-tolerance can be guaranteed, specify the location where the system will write all the checkpoint information. This should be a directory in an HDFS-compatible fault-tolerant file system. The semantics of checkpointing are discussed in more detail in the next section.
+
+#### Output Modes
+There are two output modes currently implemented.
+
+- **Append mode (default)** - This is the default mode, where only the new rows added to the result table since the last trigger will be written to the sink. This is only applicable to queries that *do not have any aggregations* (e.g. queries with only `select`, `where`, `map`, `flatMap`, `filter`, `join`, etc.).
+
+- **Complete mode** - The whole result table will be written to the sink. This is only applicable to queries that *have aggregations*.
+
+#### Output Sinks
+There are a few types of built-in output sinks.
+
+- **File sink** - Stores the output to a directory. As of Spark 2.0, this only supports Parquet file format, and Append output mode.
+
+- **Foreach sink** - Runs arbitrary computation on the records in the output. See later in the section for more details.
+
+- **Console sink (for debugging)** - Prints the output to the console/stdout every time there is a trigger. Both Append and Complete output modes are supported. This should be used for debugging purposes on low data volumes as the entire output is collected and stored in the driver's memory after every trigger.
+
+- **Memory sink (for debugging)** - The output is stored in memory as an in-memory table. Both Append and Complete output modes are supported. This should be used for debugging purposes on low data volumes as the entire output is collected and stored in the driver's memory after every trigger.
+
+Here is a table of all the sinks, and the corresponding settings.
+
+| Sink | Supported Output Modes | Usage | Fault-tolerant | Notes |
+| ---- | ---------------------- | ----- | -------------- | ----- |
+| **File Sink** (only parquet in Spark 2.0) | Append | `writeStream.format("parquet").start()` | Yes | Supports writes to partitioned tables. Partitioning by time may be useful. |
+| **Foreach Sink** | All modes | `writeStream.foreach(...).start()` | Depends on ForeachWriter implementation | More details in the next section |
+| **Console Sink** | Append, Complete | `writeStream.format("console").start()` | No | |
+| **Memory Sink** | Append, Complete | `writeStream.format("memory").queryName("table").start()` | No | Saves the output data as a table, for interactive querying. Table name is the query name. |
    + +Finally, you have to call `start()` to actually start the execution of the query. This returns a StreamingQuery object which is a handle to the continuously running execution. You can use this object to manage the query, which we will discuss in the next subsection. For now, let’s understand all this with a few examples. + + +
    +
+
+{% highlight scala %}
+// ========== DF with no aggregations ==========
+val noAggDF = deviceDataDf.select("device").where("signal > 10")
+
+// Print new data to console
+noAggDF
+  .writeStream
+  .format("console")
+  .start()
+
+// Write new data to Parquet files
+noAggDF
+  .writeStream
+  .format("parquet")
+  .option("path", "path/to/destination/directory")
+  .start()
+
+// ========== DF with aggregation ==========
+val aggDF = df.groupBy("device").count()
+
+// Print updated aggregations to console
+aggDF
+  .writeStream
+  .outputMode("complete")
+  .format("console")
+  .start()
+
+// Have all the aggregates in an in-memory table
+aggDF
+  .writeStream
+  .queryName("aggregates")  // this query name will be the table name
+  .outputMode("complete")
+  .format("memory")
+  .start()
+
+spark.sql("select * from aggregates").show()  // interactively query in-memory table
+{% endhighlight %}
+
    +
+
+{% highlight java %}
+// ========== DF with no aggregations ==========
+Dataset<Row> noAggDF = deviceDataDf.select("device").where("signal > 10");
+
+// Print new data to console
+noAggDF
+  .writeStream()
+  .format("console")
+  .start();
+
+// Write new data to Parquet files
+noAggDF
+  .writeStream()
+  .format("parquet")
+  .option("path", "path/to/destination/directory")
+  .start();
+
+// ========== DF with aggregation ==========
+Dataset<Row> aggDF = df.groupBy("device").count();
+
+// Print updated aggregations to console
+aggDF
+  .writeStream()
+  .outputMode("complete")
+  .format("console")
+  .start();
+
+// Have all the aggregates in an in-memory table
+aggDF
+  .writeStream()
+  .queryName("aggregates")  // this query name will be the table name
+  .outputMode("complete")
+  .format("memory")
+  .start();
+
+spark.sql("select * from aggregates").show();  // interactively query in-memory table
+{% endhighlight %}
+
    +
+
+{% highlight python %}
+# ========== DF with no aggregations ==========
+noAggDF = deviceDataDf.select("device").where("signal > 10")
+
+# Print new data to console
+noAggDF\
+    .writeStream\
+    .format("console")\
+    .start()
+
+# Write new data to Parquet files
+noAggDF\
+    .writeStream\
+    .format("parquet")\
+    .option("path", "path/to/destination/directory")\
+    .start()
+
+# ========== DF with aggregation ==========
+aggDF = df.groupBy("device").count()
+
+# Print updated aggregations to console
+aggDF\
+    .writeStream\
+    .outputMode("complete")\
+    .format("console")\
+    .start()
+
+# Have all the aggregates in an in memory table. The query name will be the table name
+aggDF\
+    .writeStream\
+    .queryName("aggregates")\
+    .outputMode("complete")\
+    .format("memory")\
+    .start()
+
+spark.sql("select * from aggregates").show()  # interactively query in-memory table
+{% endhighlight %}
+
    +
+
+#### Using Foreach
+The `foreach` operation allows arbitrary operations to be computed on the output data. As of Spark 2.0, this is available only for Scala and Java. To use this, you will have to implement the interface `ForeachWriter` ([Scala](api/scala/index.html#org.apache.spark.sql.ForeachWriter)/
+[Java](api/java/org/apache/spark/sql/ForeachWriter.html) docs), which has methods that get called whenever there is a sequence of rows generated as output after a trigger. Note the following important points, and see the sketch after this list.
+
+- The writer must be serializable, as it will be serialized and sent to the executors for execution.
+
+- All three methods, `open`, `process` and `close`, will be called on the executors.
+
+- The writer must do all the initialization (e.g. opening connections, starting a transaction, etc.) only when the `open` method is called. Be aware that, if there is any initialization in the class as soon as the object is created, then that initialization will happen in the driver (because that is where the instance is being created), which may not be what you intend.
+
+- `version` and `partition` are two parameters in `open` that uniquely represent a set of rows that needs to be pushed out. `version` is a monotonically increasing id that increases with every trigger. `partition` is an id that represents a partition of the output, since the output is distributed and will be processed on multiple executors.
+
+- `open` can use the `version` and `partition` to choose whether it needs to write the sequence of rows. Accordingly, it can return `true` (proceed with writing), or `false` (no need to write). If `false` is returned, then `process` will not be called on any row. For example, after a partial failure, some of the output partitions of the failed trigger may have already been committed to a database. Based on metadata stored in the database, the writer can identify partitions that have already been committed and accordingly return false to skip committing them again.
+
+- Whenever `open` is called, `close` will also be called (unless the JVM exits due to some error). This is true even if `open` returns false. If there is any error in processing and writing the data, `close` will be called with the error. It is your responsibility to clean up state (e.g. connections, transactions, etc.) that has been created in `open` such that there are no resource leaks.
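+
+To make the points above concrete, here is a minimal editorial sketch of a `ForeachWriter` (the connection handling is a placeholder, and `streamingDF` is an assumed streaming DataFrame, not a name from this guide):
+
+{% highlight scala %}
+import org.apache.spark.sql.{ForeachWriter, Row}
+
+val writer = new ForeachWriter[Row] {
+  override def open(partitionId: Long, version: Long): Boolean = {
+    // Open connections/transactions here; return false to skip a
+    // (version, partition) pair that has already been committed.
+    true
+  }
+  override def process(record: Row): Unit = {
+    // Write the record using the resources opened above.
+  }
+  override def close(errorOrNull: Throwable): Unit = {
+    // Clean up; errorOrNull is non-null if processing or writing failed.
+  }
+}
+
+streamingDF.writeStream.foreach(writer).start()
+{% endhighlight %}
+
+## Managing Streaming Queries
+The `StreamingQuery` object created when a query is started can be used to monitor and manage the query.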
    +
+
+{% highlight scala %}
+val query = df.writeStream.format("console").start()  // get the query object
+
+query.id  // get the unique identifier of the running query
+
+query.name  // get the auto-generated or user-specified name of the query
+
+query.explain()  // print detailed explanations of the query
+
+query.stop()  // stop the query
+
+query.awaitTermination()  // block until query is terminated, with stop() or with error
+
+query.exception()  // the exception if the query has been terminated with error
+
+query.sourceStatus()  // progress information about data that has been read from the input sources
+
+query.sinkStatus()  // progress information about data written to the output sink
+{% endhighlight %}
+
+
    +
+
+{% highlight java %}
+StreamingQuery query = df.writeStream().format("console").start();  // get the query object
+
+query.id();  // get the unique identifier of the running query
+
+query.name();  // get the auto-generated or user-specified name of the query
+
+query.explain();  // print detailed explanations of the query
+
+query.stop();  // stop the query
+
+query.awaitTermination();  // block until query is terminated, with stop() or with error
+
+query.exception();  // the exception if the query has been terminated with error
+
+query.sourceStatus();  // progress information about data that has been read from the input sources
+
+query.sinkStatus();  // progress information about data written to the output sink
+
+{% endhighlight %}
+
    +
+
+{% highlight python %}
+query = df.writeStream.format("console").start()  # get the query object
+
+query.id  # get the unique identifier of the running query
+
+query.name  # get the auto-generated or user-specified name of the query
+
+query.explain()  # print detailed explanations of the query
+
+query.stop()  # stop the query
+
+query.awaitTermination()  # block until query is terminated, with stop() or with error
+
+query.exception()  # the exception if the query has been terminated with error
+
+query.sourceStatus()  # progress information about data that has been read from the input sources
+
+query.sinkStatus()  # progress information about data written to the output sink
+
+{% endhighlight %}
+
    +
    + +You can start any number of queries in a single SparkSession. They will all be running concurrently sharing the cluster resources. You can use `sparkSession.streams()` to get the `StreamingQueryManager` ([Scala](api/scala/index.html#org.apache.spark.sql.streaming.StreamingQueryManager)/ +[Java](api/java/org/apache/spark/sql/streaming/StreamingQueryManager.html)/ +[Python](api/python/pyspark.sql.html#pyspark.sql.streaming.StreamingQueryManager) docs) that can be used to manage the currently active queries. + +
    +
    + +{% highlight scala %} +val spark: SparkSession = ... + +spark.streams.active // get the list of currently active streaming queries + +spark.streams.get(id) // get a query object by its unique id + +spark.streams.awaitAnyTermination() // block until any one of them terminates +{% endhighlight %} + +
    +
    + +{% highlight java %} +SparkSession spark = ... + +spark.streams().active(); // get the list of currently active streaming queries + +spark.streams().get(id); // get a query object by its unique id + +spark.streams().awaitAnyTermination(); // block until any one of them terminates +{% endhighlight %} + +
    +
+
+{% highlight python %}
+spark = ...  # spark session
+
+spark.streams.active  # get the list of currently active streaming queries
+
+spark.streams.get(id)  # get a query object by its unique id
+
+spark.streams.awaitAnyTermination()  # block until any one of them terminates
+{% endhighlight %}
+
    +
+
+Finally, for asynchronous monitoring of streaming queries, you can create and attach a `StreamingQueryListener` ([Scala](api/scala/index.html#org.apache.spark.sql.streaming.StreamingQueryListener)/
+[Java](api/java/org/apache/spark/sql/streaming/StreamingQueryListener.html) docs), which will give you regular callback-based updates when queries are started and terminated.
+
+## Recovering from Failures with Checkpointing
+In case of a failure or intentional shutdown, you can recover the progress and state of a previous query, and continue where it left off. This is done using checkpointing and write ahead logs. You can configure a query with a checkpoint location, and the query will save all the progress information (i.e. range of offsets processed in each trigger) and the running aggregates (e.g. word counts in the [quick example](#quick-example)) to the checkpoint location. As of Spark 2.0, this checkpoint location has to be a path in an HDFS compatible file system, and can be set as an option in the DataStreamWriter when [starting a query](#starting-streaming-queries).
    +
    + +{% highlight scala %} +aggDF + .writeStream + .outputMode("complete") + .option("checkpointLocation", "path/to/HDFS/dir") + .format("memory") + .start() +{% endhighlight %} + +
    +
    + +{% highlight java %} +aggDF + .writeStream() + .outputMode("complete") + .option("checkpointLocation", "path/to/HDFS/dir") + .format("memory") + .start(); +{% endhighlight %} + +
    +
+
+{% highlight python %}
+aggDF\
+    .writeStream\
+    .outputMode("complete")\
+    .option("checkpointLocation", "path/to/HDFS/dir")\
+    .format("memory")\
+    .start()
+{% endhighlight %}
+
    +
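+
+As an editorial note on recovery (reusing the same `aggDF` and checkpoint directory as above): restarting after a failure simply means running the same query again with the same checkpoint location; the engine reads the offsets recorded under that location and resumes instead of reprocessing from scratch. A minimal Scala sketch:
+
+{% highlight scala %}
+// Restart the query after a failure: same code, same checkpoint location.
+val restarted = aggDF
+  .writeStream
+  .outputMode("complete")
+  .option("checkpointLocation", "path/to/HDFS/dir")  // must match the original run
+  .format("memory")
+  .start()
+{% endhighlight %}
+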
+
+# Where to go from here
+- Examples: See and run the
+[Scala]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples/sql/streaming)/[Java]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples/sql/streaming)/[Python]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python/sql/streaming)
+examples.
+- Spark Summit 2016 Talk - [A Deep Dive into Structured Streaming](https://spark-summit.org/2016/events/a-deep-dive-into-structured-streaming/)
+
diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md
index 100ff0b147efd..6fe3049995876 100644
--- a/docs/submitting-applications.md
+++ b/docs/submitting-applications.md
@@ -58,7 +58,8 @@ for applications that involve the REPL (e.g. Spark shell).
 Alternatively, if your application is submitted from a machine far from the worker machines (e.g.
 locally on your laptop), it is common to use `cluster` mode to minimize network latency between
-the drivers and the executors. Currently only YARN supports cluster mode for Python applications.
+the drivers and the executors. Currently, standalone mode does not support cluster mode for Python
+applications.
 For Python applications, simply pass a `.py` file in the place of `` instead of a JAR,
 and add Python `.zip`, `.egg` or `.py` files to the search path with `--py-files`.
diff --git a/docs/tuning.md b/docs/tuning.md
index e73ed69ffbbf8..9c43b315bbb9e 100644
--- a/docs/tuning.md
+++ b/docs/tuning.md
@@ -45,6 +45,7 @@ and calling `conf.set("spark.serializer", "org.apache.spark.serializer.KryoSeria
 This setting configures the serializer used for not only shuffling data between worker
 nodes but also when serializing RDDs to disk. The only reason Kryo is not the default is because of the custom
 registration requirement, but we recommend trying it in any network-intensive application.
+Since Spark 2.0.0, we internally use Kryo serializer when shuffling RDDs with simple types, arrays of simple types, or string type.
 Spark automatically includes Kryo serializers for the many commonly-used core Scala classes covered
 in the AllScalaRegistrar from the [Twitter chill](https://github.com/twitter/chill) library.
@@ -115,12 +116,15 @@ Although there are two relevant configurations, the typical user should not need to adjust them
 as the default values are applicable to most workloads:
 * `spark.memory.fraction` expresses the size of `M` as a fraction of the (JVM heap space - 300MB)
-(default 0.75). The rest of the space (25%) is reserved for user data structures, internal
+(default 0.6). The rest of the space (40%) is reserved for user data structures, internal
 metadata in Spark, and safeguarding against OOM errors in the case of sparse and unusually
 large records.
 * `spark.memory.storageFraction` expresses the size of `R` as a fraction of `M` (default 0.5).
 `R` is the storage space within `M` where cached blocks are immune to being evicted by execution.
+The value of `spark.memory.fraction` should be set in order to fit this amount of heap space
+comfortably within the JVM's old or "tenured" generation. See the discussion of advanced GC
+tuning below for details.
 ## Determining Memory Consumption
@@ -201,14 +205,22 @@ temporary objects created during task execution. Some steps which may be useful are:
 * Check if there are too many garbage collections by collecting GC stats. If a full GC is invoked multiple times
   before a task completes, it means that there isn't enough memory available for executing tasks.
-* In the GC stats that are printed, if the OldGen is close to being full, reduce the amount of - memory used for caching by lowering `spark.memory.storageFraction`; it is better to cache fewer - objects than to slow down task execution! - * If there are too many minor collections but not many major GCs, allocating more memory for Eden would help. You can set the size of the Eden to be an over-estimate of how much memory each task will need. If the size of Eden is determined to be `E`, then you can set the size of the Young generation using the option `-Xmn=4/3*E`. (The scaling up by 4/3 is to account for space used by survivor regions as well.) + +* In the GC stats that are printed, if the OldGen is close to being full, reduce the amount of + memory used for caching by lowering `spark.memory.fraction`; it is better to cache fewer + objects than to slow down task execution. Alternatively, consider decreasing the size of + the Young generation. This means lowering `-Xmn` if you've set it as above. If not, try changing the + value of the JVM's `NewRatio` parameter. Many JVMs default this to 2, meaning that the Old generation + occupies 2/3 of the heap. It should be large enough such that this fraction exceeds `spark.memory.fraction`. + +* Try the G1GC garbage collector with `-XX:+UseG1GC`. It can improve performance in some situations where + garbage collection is a bottleneck. Note that with large executor heap sizes, it may be important to + increase the [G1 region size](https://blogs.oracle.com/g1gc/entry/g1_gc_tuning_a_case) + with `-XX:G1HeapRegionSize` * As an example, if your task is reading data from HDFS, the amount of memory used by the task can be estimated using the size of the data block read from HDFS. Note that the size of a decompressed block is often 2 or 3 times the @@ -221,6 +233,9 @@ Our experience suggests that the effect of GC tuning depends on your application There are [many more tuning options](http://www.oracle.com/technetwork/java/javase/gc-tuning-6-140523.html) described online, but at a high level, managing how frequently full GC takes place can help in reducing the overhead. +GC tuning flags for executors can be specified by setting `spark.executor.extraJavaOptions` in +a job's configuration. + # Other Considerations ## Level of Parallelism diff --git a/examples/pom.xml b/examples/pom.xml index 4423d0fbe1df3..b4ba272884c36 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.0.0-SNAPSHOT + 2.0.3-SNAPSHOT ../pom.xml @@ -87,7 +87,7 @@ org.apache.spark - spark-streaming-kafka_${scala.binary.version} + spark-streaming-kafka-0-8_${scala.binary.version} ${project.version} provided diff --git a/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java b/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java index 31a79ddd3fff1..362bd4435ecb3 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java @@ -17,11 +17,10 @@ package org.apache.spark.examples; -import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.function.Function2; +import org.apache.spark.sql.SparkSession; import java.io.Serializable; import java.util.Arrays; @@ -32,8 +31,7 @@ * Logistic regression based classification. * * This is an example implementation for learning how to use Spark. 
For more conventional use, - * please refer to either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or - * org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS based on your needs. + * please refer to org.apache.spark.ml.classification.LogisticRegression. */ public final class JavaHdfsLR { @@ -43,8 +41,7 @@ public final class JavaHdfsLR { static void showWarning() { String warning = "WARN: This is a naive implementation of Logistic Regression " + "and is given as an example!\n" + - "Please use either org.apache.spark.mllib.classification.LogisticRegressionWithSGD " + - "or org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS " + + "Please use org.apache.spark.ml.classification.LogisticRegression " + "for more conventional use."; System.err.println(warning); } @@ -124,9 +121,12 @@ public static void main(String[] args) { showWarning(); - SparkConf sparkConf = new SparkConf().setAppName("JavaHdfsLR"); - JavaSparkContext sc = new JavaSparkContext(sparkConf); - JavaRDD lines = sc.textFile(args[0]); + SparkSession spark = SparkSession + .builder() + .appName("JavaHdfsLR") + .getOrCreate(); + + JavaRDD lines = spark.read().textFile(args[0]).javaRDD(); JavaRDD points = lines.map(new ParsePoint()).cache(); int ITERATIONS = Integer.parseInt(args[1]); @@ -154,6 +154,6 @@ public static void main(String[] args) { System.out.print("Final w: "); printWeights(w); - sc.stop(); + spark.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java b/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java index ebb0687b14ae0..7775443861661 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java @@ -20,12 +20,13 @@ import com.google.common.collect.Lists; import scala.Tuple2; import scala.Tuple3; -import org.apache.spark.SparkConf; + import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function2; import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.sql.SparkSession; import java.io.Serializable; import java.util.List; @@ -99,9 +100,12 @@ public static Stats extractStats(String line) { } public static void main(String[] args) { + SparkSession spark = SparkSession + .builder() + .appName("JavaLogQuery") + .getOrCreate(); - SparkConf sparkConf = new SparkConf().setAppName("JavaLogQuery"); - JavaSparkContext jsc = new JavaSparkContext(sparkConf); + JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext()); JavaRDD dataSet = (args.length == 1) ? 
jsc.textFile(args[0]) : jsc.parallelize(exampleApacheLogs); @@ -123,6 +127,6 @@ public Stats call(Stats stats, Stats stats2) { for (Tuple2 t : output) { System.out.println(t._1() + "\t" + t._2()); } - jsc.stop(); + spark.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java b/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java index 229d1234414e5..ed0bb876579ad 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java @@ -26,14 +26,13 @@ import com.google.common.collect.Iterables; -import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.function.Function2; import org.apache.spark.api.java.function.PairFlatMapFunction; import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.sql.SparkSession; /** * Computes the PageRank of URLs from an input file. Input file should @@ -73,15 +72,17 @@ public static void main(String[] args) throws Exception { showWarning(); - SparkConf sparkConf = new SparkConf().setAppName("JavaPageRank"); - JavaSparkContext ctx = new JavaSparkContext(sparkConf); + SparkSession spark = SparkSession + .builder() + .appName("JavaPageRank") + .getOrCreate(); // Loads in input file. It should be in format of: // URL neighbor URL // URL neighbor URL // URL neighbor URL // ... - JavaRDD lines = ctx.textFile(args[0], 1); + JavaRDD lines = spark.read().textFile(args[0]).javaRDD(); // Loads all URLs from input file and initialize their neighbors. JavaPairRDD> links = lines.mapToPair( @@ -132,6 +133,6 @@ public Double call(Double sum) { System.out.println(tuple._1() + " has rank: " + tuple._2() + "."); } - ctx.stop(); + spark.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java b/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java index 04a57a6bfb58b..7df145e3117b8 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java @@ -17,11 +17,11 @@ package org.apache.spark.examples; -import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.function.Function2; +import org.apache.spark.sql.SparkSession; import java.util.ArrayList; import java.util.List; @@ -33,8 +33,12 @@ public final class JavaSparkPi { public static void main(String[] args) throws Exception { - SparkConf sparkConf = new SparkConf().setAppName("JavaSparkPi"); - JavaSparkContext jsc = new JavaSparkContext(sparkConf); + SparkSession spark = SparkSession + .builder() + .appName("JavaSparkPi") + .getOrCreate(); + + JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext()); int slices = (args.length == 1) ? 
Integer.parseInt(args[0]) : 2; int n = 100000 * slices; @@ -61,6 +65,6 @@ public Integer call(Integer integer, Integer integer2) { System.out.println("Pi is roughly " + 4.0 * count / n); - jsc.stop(); + spark.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/JavaStatusTrackerDemo.java b/examples/src/main/java/org/apache/spark/examples/JavaStatusTrackerDemo.java index e68ec74c3ed54..6f899c772eb98 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaStatusTrackerDemo.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaStatusTrackerDemo.java @@ -17,13 +17,14 @@ package org.apache.spark.examples; -import org.apache.spark.SparkConf; import org.apache.spark.SparkJobInfo; import org.apache.spark.SparkStageInfo; import org.apache.spark.api.java.JavaFutureAction; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; +import org.apache.spark.sql.SparkSession; + import java.util.Arrays; import java.util.List; @@ -44,11 +45,15 @@ public T call(T x) throws Exception { } public static void main(String[] args) throws Exception { - SparkConf sparkConf = new SparkConf().setAppName(APP_NAME); - final JavaSparkContext sc = new JavaSparkContext(sparkConf); + SparkSession spark = SparkSession + .builder() + .appName(APP_NAME) + .getOrCreate(); + + final JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext()); // Example of implementing a progress reporter for a simple job. - JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 5).map( + JavaRDD rdd = jsc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 5).map( new IdentityWithDelay()); JavaFutureAction> jobFuture = rdd.collectAsync(); while (!jobFuture.isDone()) { @@ -58,13 +63,13 @@ public static void main(String[] args) throws Exception { continue; } int currentJobId = jobIds.get(jobIds.size() - 1); - SparkJobInfo jobInfo = sc.statusTracker().getJobInfo(currentJobId); - SparkStageInfo stageInfo = sc.statusTracker().getStageInfo(jobInfo.stageIds()[0]); + SparkJobInfo jobInfo = jsc.statusTracker().getJobInfo(currentJobId); + SparkStageInfo stageInfo = jsc.statusTracker().getStageInfo(jobInfo.stageIds()[0]); System.out.println(stageInfo.numTasks() + " tasks total: " + stageInfo.numActiveTasks() + " active, " + stageInfo.numCompletedTasks() + " complete"); } System.out.println("Job results are: " + jobFuture.get()); - sc.stop(); + spark.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/JavaTC.java b/examples/src/main/java/org/apache/spark/examples/JavaTC.java index ca10384212da2..f12ca77ed1eb0 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaTC.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaTC.java @@ -25,10 +25,10 @@ import scala.Tuple2; -import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.sql.SparkSession; /** * Transitive closure on a graph, implemented in Java. @@ -64,10 +64,15 @@ public Tuple2 call(Tuple2> t } public static void main(String[] args) { - SparkConf sparkConf = new SparkConf().setAppName("JavaHdfsLR"); - JavaSparkContext sc = new JavaSparkContext(sparkConf); + SparkSession spark = SparkSession + .builder() + .appName("JavaTC") + .getOrCreate(); + + JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext()); + Integer slices = (args.length > 0) ? 
Integer.parseInt(args[0]): 2; - JavaPairRDD tc = sc.parallelizePairs(generateGraph(), slices).cache(); + JavaPairRDD tc = jsc.parallelizePairs(generateGraph(), slices).cache(); // Linear transitive closure: each round grows paths by one edge, // by joining the graph's edges with the already-discovered paths. @@ -94,6 +99,6 @@ public Tuple2 call(Tuple2 e) { } while (nextCount != oldCount); System.out.println("TC has " + tc.count() + " edges."); - sc.stop(); + spark.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java b/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java index 3ff5412b934f0..8f18604c0750c 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java @@ -18,13 +18,13 @@ package org.apache.spark.examples; import scala.Tuple2; -import org.apache.spark.SparkConf; + import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.api.java.function.Function2; import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.sql.SparkSession; import java.util.Arrays; import java.util.Iterator; @@ -41,9 +41,12 @@ public static void main(String[] args) throws Exception { System.exit(1); } - SparkConf sparkConf = new SparkConf().setAppName("JavaWordCount"); - JavaSparkContext ctx = new JavaSparkContext(sparkConf); - JavaRDD lines = ctx.textFile(args[0], 1); + SparkSession spark = SparkSession + .builder() + .appName("JavaWordCount") + .getOrCreate(); + + JavaRDD lines = spark.read().textFile(args[0]).javaRDD(); JavaRDD words = lines.flatMap(new FlatMapFunction() { @Override @@ -72,6 +75,6 @@ public Integer call(Integer i1, Integer i2) { for (Tuple2 tuple : output) { System.out.println(tuple._1() + ": " + tuple._2()); } - ctx.stop(); + spark.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java index 22b93a3a85c52..3f034588c9527 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java @@ -21,23 +21,33 @@ import java.util.Arrays; import java.util.List; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.regression.AFTSurvivalRegression; import org.apache.spark.ml.regression.AFTSurvivalRegressionModel; -import org.apache.spark.mllib.linalg.*; +import org.apache.spark.ml.linalg.VectorUDT; +import org.apache.spark.ml.linalg.Vectors; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.types.*; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; // $example off$ +/** + * An example demonstrating AFTSurvivalRegression. + * Run with + *
+ * <pre>
+ * bin/run-example ml.JavaAFTSurvivalRegressionExample
+ * </pre>
    + */ public class JavaAFTSurvivalRegressionExample { public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaAFTSurvivalRegressionExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext jsql = new SQLContext(jsc); + SparkSession spark = SparkSession + .builder() + .appName("JavaAFTSurvivalRegressionExample") + .getOrCreate(); // $example on$ List data = Arrays.asList( @@ -52,7 +62,7 @@ public static void main(String[] args) { new StructField("censor", DataTypes.DoubleType, false, Metadata.empty()), new StructField("features", new VectorUDT(), false, Metadata.empty()) }); - Dataset training = jsql.createDataFrame(data, schema); + Dataset training = spark.createDataFrame(data, schema); double[] quantileProbabilities = new double[]{0.3, 0.6}; AFTSurvivalRegression aft = new AFTSurvivalRegression() .setQuantileProbabilities(quantileProbabilities) @@ -66,6 +76,6 @@ public static void main(String[] args) { model.transform(training).show(false); // $example off$ - jsc.stop(); + spark.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java index 088037d427f5b..739558e81ffb0 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java @@ -17,11 +17,9 @@ package org.apache.spark.examples.ml; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; // $example on$ import java.io.Serializable; @@ -31,7 +29,6 @@ import org.apache.spark.ml.evaluation.RegressionEvaluator; import org.apache.spark.ml.recommendation.ALS; import org.apache.spark.ml.recommendation.ALSModel; -import org.apache.spark.sql.types.DataTypes; // $example off$ public class JavaALSExample { @@ -83,18 +80,20 @@ public static Rating parseRating(String str) { // $example off$ public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaALSExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext sqlContext = new SQLContext(jsc); + SparkSession spark = SparkSession + .builder() + .appName("JavaALSExample") + .getOrCreate(); // $example on$ - JavaRDD ratingsRDD = jsc.textFile("data/mllib/als/sample_movielens_ratings.txt") + JavaRDD ratingsRDD = spark + .read().textFile("data/mllib/als/sample_movielens_ratings.txt").javaRDD() .map(new Function() { public Rating call(String str) { return Rating.parseRating(str); } }); - Dataset ratings = sqlContext.createDataFrame(ratingsRDD, Rating.class); + Dataset ratings = spark.createDataFrame(ratingsRDD, Rating.class); Dataset[] splits = ratings.randomSplit(new double[]{0.8, 0.2}); Dataset training = splits[0]; Dataset test = splits[1]; @@ -109,10 +108,7 @@ public Rating call(String str) { ALSModel model = als.fit(training); // Evaluate the model by computing the RMSE on the test data - Dataset rawPredictions = model.transform(test); - Dataset predictions = rawPredictions - .withColumn("rating", rawPredictions.col("rating").cast(DataTypes.DoubleType)) - .withColumn("prediction", rawPredictions.col("prediction").cast(DataTypes.DoubleType)); + Dataset predictions = model.transform(test); RegressionEvaluator evaluator = new RegressionEvaluator() .setMetricName("rmse") @@ -121,6 +117,6 @@ public Rating 
call(String str) { Double rmse = evaluator.evaluate(predictions); System.out.println("Root-mean-square error = " + rmse); // $example off$ - jsc.stop(); + spark.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java index 0a6e9c2a1f93c..a954dbd20c12f 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaBinarizerExample.java @@ -17,15 +17,13 @@ package org.apache.spark.examples.ml; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.Dataset; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; // $example on$ import java.util.Arrays; +import java.util.List; -import org.apache.spark.api.java.JavaRDD; import org.apache.spark.ml.feature.Binarizer; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; @@ -37,21 +35,22 @@ public class JavaBinarizerExample { public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaBinarizerExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext jsql = new SQLContext(jsc); + SparkSession spark = SparkSession + .builder() + .appName("JavaBinarizerExample") + .getOrCreate(); // $example on$ - JavaRDD jrdd = jsc.parallelize(Arrays.asList( + List data = Arrays.asList( RowFactory.create(0, 0.1), RowFactory.create(1, 0.8), RowFactory.create(2, 0.2) - )); + ); StructType schema = new StructType(new StructField[]{ - new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("id", DataTypes.IntegerType, false, Metadata.empty()), new StructField("feature", DataTypes.DoubleType, false, Metadata.empty()) }); - Dataset continuousDataFrame = jsql.createDataFrame(jrdd, schema); + Dataset continuousDataFrame = spark.createDataFrame(data, schema); Binarizer binarizer = new Binarizer() .setInputCol("feature") .setOutputCol("binarized_feature") @@ -63,6 +62,6 @@ public static void main(String[] args) { System.out.println(binarized_value); } // $example off$ - jsc.stop(); + spark.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaBisectingKMeansExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaBisectingKMeansExample.java index 1d1a518bbca12..8c82aaaacca38 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaBisectingKMeansExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaBisectingKMeansExample.java @@ -17,65 +17,51 @@ package org.apache.spark.examples.ml; -import java.util.Arrays; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; // $example on$ import org.apache.spark.ml.clustering.BisectingKMeans; import org.apache.spark.ml.clustering.BisectingKMeansModel; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.ml.linalg.Vector; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; // $example off$ +import org.apache.spark.sql.SparkSession; /** - 
* An example demonstrating a bisecting k-means clustering.
+ * An example demonstrating bisecting k-means clustering.
+ * Run with
+ * <pre>
+ * bin/run-example ml.JavaBisectingKMeansExample
+ * </pre>
    */ public class JavaBisectingKMeansExample { public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaBisectingKMeansExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext jsql = new SQLContext(jsc); + SparkSession spark = SparkSession + .builder() + .appName("JavaBisectingKMeansExample") + .getOrCreate(); // $example on$ - JavaRDD data = jsc.parallelize(Arrays.asList( - RowFactory.create(Vectors.dense(0.1, 0.1, 0.1)), - RowFactory.create(Vectors.dense(0.3, 0.3, 0.25)), - RowFactory.create(Vectors.dense(0.1, 0.1, -0.1)), - RowFactory.create(Vectors.dense(20.3, 20.1, 19.9)), - RowFactory.create(Vectors.dense(20.2, 20.1, 19.7)), - RowFactory.create(Vectors.dense(18.9, 20.0, 19.7)) - )); - - StructType schema = new StructType(new StructField[]{ - new StructField("features", new VectorUDT(), false, Metadata.empty()), - }); - - Dataset dataset = jsql.createDataFrame(data, schema); + // Loads data. + Dataset dataset = spark.read().format("libsvm").load("data/mllib/sample_kmeans_data.txt"); - BisectingKMeans bkm = new BisectingKMeans().setK(2); + // Trains a bisecting k-means model. + BisectingKMeans bkm = new BisectingKMeans().setK(2).setSeed(1); BisectingKMeansModel model = bkm.fit(dataset); - System.out.println("Compute Cost: " + model.computeCost(dataset)); + // Evaluate clustering. + double cost = model.computeCost(dataset); + System.out.println("Within Set Sum of Squared Errors = " + cost); - Vector[] clusterCenters = model.clusterCenters(); - for (int i = 0; i < clusterCenters.length; i++) { - Vector clusterCenter = clusterCenters[i]; - System.out.println("Cluster Center " + i + ": " + clusterCenter); + // Shows the result. + System.out.println("Cluster Centers: "); + Vector[] centers = model.clusterCenters(); + for (Vector center : centers) { + System.out.println(center); } // $example off$ - jsc.stop(); + spark.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java index 68ffa702ea5e2..691df3887a9bb 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java @@ -17,14 +17,12 @@ package org.apache.spark.examples.ml; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; // $example on$ import java.util.Arrays; +import java.util.List; -import org.apache.spark.api.java.JavaRDD; import org.apache.spark.ml.feature.Bucketizer; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; @@ -37,23 +35,24 @@ public class JavaBucketizerExample { public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaBucketizerExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext jsql = new SQLContext(jsc); + SparkSession spark = SparkSession + .builder() + .appName("JavaBucketizerExample") + .getOrCreate(); // $example on$ double[] splits = {Double.NEGATIVE_INFINITY, -0.5, 0.0, 0.5, Double.POSITIVE_INFINITY}; - JavaRDD data = jsc.parallelize(Arrays.asList( + List data = Arrays.asList( RowFactory.create(-0.5), RowFactory.create(-0.3), RowFactory.create(0.0), RowFactory.create(0.2) - )); + ); StructType schema = new StructType(new StructField[]{ new StructField("features", DataTypes.DoubleType, false, Metadata.empty()) }); - 
Dataset dataFrame = jsql.createDataFrame(data, schema); + Dataset dataFrame = spark.createDataFrame(data, schema); Bucketizer bucketizer = new Bucketizer() .setInputCol("features") @@ -64,7 +63,7 @@ public static void main(String[] args) { Dataset bucketedData = bucketizer.transform(dataFrame); bucketedData.show(); // $example off$ - jsc.stop(); + spark.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaChiSqSelectorExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaChiSqSelectorExample.java index b1bf1cfeb2153..fcf90d8d18748 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaChiSqSelectorExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaChiSqSelectorExample.java @@ -17,18 +17,16 @@ package org.apache.spark.examples.ml; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.Dataset; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; // $example on$ import java.util.Arrays; +import java.util.List; import org.apache.spark.ml.feature.ChiSqSelector; -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.ml.linalg.VectorUDT; +import org.apache.spark.ml.linalg.Vectors; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.types.DataTypes; @@ -39,23 +37,24 @@ public class JavaChiSqSelectorExample { public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaChiSqSelectorExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext sqlContext = new SQLContext(jsc); + SparkSession spark = SparkSession + .builder() + .appName("JavaChiSqSelectorExample") + .getOrCreate(); // $example on$ - JavaRDD jrdd = jsc.parallelize(Arrays.asList( + List data = Arrays.asList( RowFactory.create(7, Vectors.dense(0.0, 0.0, 18.0, 1.0), 1.0), RowFactory.create(8, Vectors.dense(0.0, 1.0, 12.0, 0.0), 0.0), RowFactory.create(9, Vectors.dense(1.0, 0.0, 15.0, 0.1), 0.0) - )); + ); StructType schema = new StructType(new StructField[]{ new StructField("id", DataTypes.IntegerType, false, Metadata.empty()), new StructField("features", new VectorUDT(), false, Metadata.empty()), new StructField("clicked", DataTypes.DoubleType, false, Metadata.empty()) }); - Dataset df = sqlContext.createDataFrame(jrdd, schema); + Dataset df = spark.createDataFrame(data, schema); ChiSqSelector selector = new ChiSqSelector() .setNumTopFeatures(1) @@ -66,6 +65,6 @@ public static void main(String[] args) { Dataset result = selector.fit(df).transform(df); result.show(); // $example off$ - jsc.stop(); + spark.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaCountVectorizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaCountVectorizerExample.java index ec3ac202bea4e..0a6b13601425b 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaCountVectorizerExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaCountVectorizerExample.java @@ -19,36 +19,34 @@ // $example on$ import java.util.Arrays; +import java.util.List; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.feature.CountVectorizer; import org.apache.spark.ml.feature.CountVectorizerModel; import 
org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.types.*; // $example off$ public class JavaCountVectorizerExample { public static void main(String[] args) { - - SparkConf conf = new SparkConf().setAppName("JavaCountVectorizerExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext sqlContext = new SQLContext(jsc); + SparkSession spark = SparkSession + .builder() + .appName("JavaCountVectorizerExample") + .getOrCreate(); // $example on$ // Input data: Each row is a bag of words from a sentence or document. - JavaRDD jrdd = jsc.parallelize(Arrays.asList( + List data = Arrays.asList( RowFactory.create(Arrays.asList("a", "b", "c")), RowFactory.create(Arrays.asList("a", "b", "b", "c", "a")) - )); + ); StructType schema = new StructType(new StructField [] { new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty()) }); - Dataset df = sqlContext.createDataFrame(jrdd, schema); + Dataset df = spark.createDataFrame(data, schema); // fit a CountVectorizerModel from the corpus CountVectorizerModel cvModel = new CountVectorizer() @@ -66,6 +64,6 @@ public static void main(String[] args) { cvModel.transform(df).show(); // $example off$ - jsc.stop(); + spark.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java index 4b15fde9c35fa..66ce23b49d361 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDCTExample.java @@ -17,18 +17,16 @@ package org.apache.spark.examples.ml; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.Dataset; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; // $example on$ import java.util.Arrays; +import java.util.List; -import org.apache.spark.api.java.JavaRDD; import org.apache.spark.ml.feature.DCT; -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.ml.linalg.VectorUDT; +import org.apache.spark.ml.linalg.Vectors; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.types.Metadata; @@ -38,20 +36,21 @@ public class JavaDCTExample { public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaDCTExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext jsql = new SQLContext(jsc); + SparkSession spark = SparkSession + .builder() + .appName("JavaDCTExample") + .getOrCreate(); // $example on$ - JavaRDD data = jsc.parallelize(Arrays.asList( + List data = Arrays.asList( RowFactory.create(Vectors.dense(0.0, 1.0, -2.0, 3.0)), RowFactory.create(Vectors.dense(-1.0, 2.0, 4.0, -7.0)), RowFactory.create(Vectors.dense(14.0, -2.0, -5.0, 1.0)) - )); + ); StructType schema = new StructType(new StructField[]{ new StructField("features", new VectorUDT(), false, Metadata.empty()), }); - Dataset df = jsql.createDataFrame(data, schema); + Dataset df = spark.createDataFrame(data, schema); DCT dct = new DCT() .setInputCol("features") .setOutputCol("featuresDCT") @@ -59,7 +58,7 @@ public static void main(String[] args) { Dataset dctDf = dct.transform(df); dctDf.select("featuresDCT").show(3); // $example off$ - jsc.stop(); + 
spark.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeClassificationExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeClassificationExample.java index 8214952f80695..a9c6e7f0bf6c5 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeClassificationExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeClassificationExample.java @@ -17,8 +17,6 @@ // scalastyle:off println package org.apache.spark.examples.ml; // $example on$ -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.Pipeline; import org.apache.spark.ml.PipelineModel; import org.apache.spark.ml.PipelineStage; @@ -28,18 +26,19 @@ import org.apache.spark.ml.feature.*; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; // $example off$ public class JavaDecisionTreeClassificationExample { public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaDecisionTreeClassificationExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext sqlContext = new SQLContext(jsc); + SparkSession spark = SparkSession + .builder() + .appName("JavaDecisionTreeClassificationExample") + .getOrCreate(); // $example on$ // Load the data stored in LIBSVM format as a DataFrame. - Dataset data = sqlContext + Dataset data = spark .read() .format("libsvm") .load("data/mllib/sample_libsvm_data.txt"); @@ -55,10 +54,10 @@ public static void main(String[] args) { VectorIndexerModel featureIndexer = new VectorIndexer() .setInputCol("features") .setOutputCol("indexedFeatures") - .setMaxCategories(4) // features with > 4 distinct values are treated as continuous + .setMaxCategories(4) // features with > 4 distinct values are treated as continuous. .fit(data); - // Split the data into training and test sets (30% held out for testing) + // Split the data into training and test sets (30% held out for testing). Dataset[] splits = data.randomSplit(new double[]{0.7, 0.3}); Dataset trainingData = splits[0]; Dataset testData = splits[1]; @@ -74,11 +73,11 @@ public static void main(String[] args) { .setOutputCol("predictedLabel") .setLabels(labelIndexer.labels()); - // Chain indexers and tree in a Pipeline + // Chain indexers and tree in a Pipeline. Pipeline pipeline = new Pipeline() .setStages(new PipelineStage[]{labelIndexer, featureIndexer, dt, labelConverter}); - // Train model. This also runs the indexers. + // Train model. This also runs the indexers. PipelineModel model = pipeline.fit(trainingData); // Make predictions. @@ -87,11 +86,11 @@ public static void main(String[] args) { // Select example rows to display. predictions.select("predictedLabel", "label", "features").show(5); - // Select (prediction, true label) and compute test error + // Select (prediction, true label) and compute test error. 
MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator() .setLabelCol("indexedLabel") .setPredictionCol("prediction") - .setMetricName("precision"); + .setMetricName("accuracy"); double accuracy = evaluator.evaluate(predictions); System.out.println("Test Error = " + (1.0 - accuracy)); @@ -100,6 +99,6 @@ public static void main(String[] args) { System.out.println("Learned classification tree model:\n" + treeModel.toDebugString()); // $example off$ - jsc.stop(); + spark.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeRegressionExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeRegressionExample.java index a4f3e97bf318a..cffb7139edccd 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeRegressionExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeRegressionExample.java @@ -17,8 +17,6 @@ // scalastyle:off println package org.apache.spark.examples.ml; // $example on$ -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.Pipeline; import org.apache.spark.ml.PipelineModel; import org.apache.spark.ml.PipelineStage; @@ -29,17 +27,18 @@ import org.apache.spark.ml.regression.DecisionTreeRegressor; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; // $example off$ public class JavaDecisionTreeRegressionExample { public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaDecisionTreeRegressionExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext sqlContext = new SQLContext(jsc); + SparkSession spark = SparkSession + .builder() + .appName("JavaDecisionTreeRegressionExample") + .getOrCreate(); // $example on$ // Load the data stored in LIBSVM format as a DataFrame. - Dataset data = sqlContext.read().format("libsvm") + Dataset data = spark.read().format("libsvm") .load("data/mllib/sample_libsvm_data.txt"); // Automatically identify categorical features, and index them. @@ -50,7 +49,7 @@ public static void main(String[] args) { .setMaxCategories(4) .fit(data); - // Split the data into training and test sets (30% held out for testing) + // Split the data into training and test sets (30% held out for testing). Dataset[] splits = data.randomSplit(new double[]{0.7, 0.3}); Dataset trainingData = splits[0]; Dataset testData = splits[1]; @@ -59,11 +58,11 @@ public static void main(String[] args) { DecisionTreeRegressor dt = new DecisionTreeRegressor() .setFeaturesCol("indexedFeatures"); - // Chain indexer and tree in a Pipeline + // Chain indexer and tree in a Pipeline. Pipeline pipeline = new Pipeline() .setStages(new PipelineStage[]{featureIndexer, dt}); - // Train model. This also runs the indexer. + // Train model. This also runs the indexer. PipelineModel model = pipeline.fit(trainingData); // Make predictions. @@ -72,7 +71,7 @@ public static void main(String[] args) { // Select example rows to display. predictions.select("label", "features").show(5); - // Select (prediction, true label) and compute test error + // Select (prediction, true label) and compute test error. 
RegressionEvaluator evaluator = new RegressionEvaluator() .setLabelCol("label") .setPredictionCol("prediction") @@ -85,6 +84,6 @@ public static void main(String[] args) { System.out.println("Learned regression tree model:\n" + treeModel.toDebugString()); // $example off$ - jsc.stop(); + spark.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java deleted file mode 100644 index 0ba94786d4e5f..0000000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java +++ /dev/null @@ -1,242 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.ml; - -import java.util.List; - -import com.google.common.collect.Lists; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.ml.classification.Classifier; -import org.apache.spark.ml.classification.ClassificationModel; -import org.apache.spark.ml.param.IntParam; -import org.apache.spark.ml.param.ParamMap; -import org.apache.spark.ml.util.Identifiable$; -import org.apache.spark.mllib.linalg.BLAS; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.Dataset; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; - - -/** - * A simple example demonstrating how to write your own learning algorithm using Estimator, - * Transformer, and other abstractions. - * This mimics {@link org.apache.spark.ml.classification.LogisticRegression}. - * - * Run with - *
- * <pre>
- * bin/run-example ml.JavaDeveloperApiExample
- * </pre>
    - */ -public class JavaDeveloperApiExample { - - public static void main(String[] args) throws Exception { - SparkConf conf = new SparkConf().setAppName("JavaDeveloperApiExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext jsql = new SQLContext(jsc); - - // Prepare training data. - List localTraining = Lists.newArrayList( - new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)), - new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), - new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)), - new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5))); - Dataset training = jsql.createDataFrame( - jsc.parallelize(localTraining), LabeledPoint.class); - - // Create a LogisticRegression instance. This instance is an Estimator. - MyJavaLogisticRegression lr = new MyJavaLogisticRegression(); - // Print out the parameters, documentation, and any default values. - System.out.println("MyJavaLogisticRegression parameters:\n" + lr.explainParams() + "\n"); - - // We may set parameters using setter methods. - lr.setMaxIter(10); - - // Learn a LogisticRegression model. This uses the parameters stored in lr. - MyJavaLogisticRegressionModel model = lr.fit(training); - - // Prepare test data. - List localTest = Lists.newArrayList( - new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), - new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), - new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))); - Dataset test = jsql.createDataFrame(jsc.parallelize(localTest), LabeledPoint.class); - - // Make predictions on test documents. cvModel uses the best model found (lrModel). - Dataset results = model.transform(test); - double sumPredictions = 0; - for (Row r : results.select("features", "label", "prediction").collectAsList()) { - sumPredictions += r.getDouble(2); - } - if (sumPredictions != 0.0) { - throw new Exception("MyJavaLogisticRegression predicted something other than 0," + - " even though all coefficients are 0!"); - } - - jsc.stop(); - } -} - -/** - * Example of defining a type of {@link Classifier}. - * - * Note: Some IDEs (e.g., IntelliJ) will complain that this will not compile due to - * {@link org.apache.spark.ml.param.Params#set} using incompatible return types. - * However, this should still compile and run successfully. - */ -class MyJavaLogisticRegression - extends Classifier { - - MyJavaLogisticRegression() { - init(); - } - - MyJavaLogisticRegression(String uid) { - this.uid_ = uid; - init(); - } - - private String uid_ = Identifiable$.MODULE$.randomUID("myJavaLogReg"); - - @Override - public String uid() { - return uid_; - } - - /** - * Param for max number of iterations - *

    - * NOTE: The usual way to add a parameter to a model or algorithm is to include: - * - val myParamName: ParamType - * - def getMyParamName - * - def setMyParamName - */ - IntParam maxIter = new IntParam(this, "maxIter", "max number of iterations"); - - int getMaxIter() { return (Integer) getOrDefault(maxIter); } - - private void init() { - setMaxIter(100); - } - - // The parameter setter is in this class since it should return type MyJavaLogisticRegression. - MyJavaLogisticRegression setMaxIter(int value) { - return (MyJavaLogisticRegression) set(maxIter, value); - } - - // This method is used by fit(). - // In Java, we have to make it public since Java does not understand Scala's protected modifier. - public MyJavaLogisticRegressionModel train(Dataset dataset) { - // Extract columns from data using helper method. - JavaRDD oldDataset = extractLabeledPoints(dataset).toJavaRDD(); - - // Do learning to estimate the coefficients vector. - int numFeatures = oldDataset.take(1).get(0).features().size(); - Vector coefficients = Vectors.zeros(numFeatures); // Learning would happen here. - - // Create a model, and return it. - return new MyJavaLogisticRegressionModel(uid(), coefficients).setParent(this); - } - - @Override - public MyJavaLogisticRegression copy(ParamMap extra) { - return defaultCopy(extra); - } -} - -/** - * Example of defining a type of {@link ClassificationModel}. - * - * Note: Some IDEs (e.g., IntelliJ) will complain that this will not compile due to - * {@link org.apache.spark.ml.param.Params#set} using incompatible return types. - * However, this should still compile and run successfully. - */ -class MyJavaLogisticRegressionModel - extends ClassificationModel { - - private Vector coefficients_; - public Vector coefficients() { return coefficients_; } - - MyJavaLogisticRegressionModel(String uid, Vector coefficients) { - this.uid_ = uid; - this.coefficients_ = coefficients; - } - - private String uid_ = Identifiable$.MODULE$.randomUID("myJavaLogReg"); - - @Override - public String uid() { - return uid_; - } - - // This uses the default implementation of transform(), which reads column "features" and outputs - // columns "prediction" and "rawPrediction." - - // This uses the default implementation of predict(), which chooses the label corresponding to - // the maximum value returned by [[predictRaw()]]. - - /** - * Raw prediction for each possible label. - * The meaning of a "raw" prediction may vary between algorithms, but it intuitively gives - * a measure of confidence in each possible label (where larger = more confident). - * This internal method is used to implement [[transform()]] and output [[rawPredictionCol]]. - * - * @return vector where element i is the raw prediction for label i. - * This raw prediction may be any real number, where a larger value indicates greater - * confidence for that label. - * - * In Java, we have to make this method public since Java does not understand Scala's protected - * modifier. - */ - public Vector predictRaw(Vector features) { - double margin = BLAS.dot(features, coefficients_); - // There are 2 classes (binary classification), so we return a length-2 vector, - // where index i corresponds to class i (i = 0, 1). - return Vectors.dense(-margin, margin); - } - - /** - * Number of classes the label can take. 2 indicates binary classification. - */ - public int numClasses() { return 2; } - - /** - * Number of features the model was trained on. 
- */ - public int numFeatures() { return coefficients_.size(); } - - /** - * Create a copy of the model. - * The copy is shallow, except for the embedded paramMap, which gets a deep copy. - *

    - * This is used for the default implementation of [[transform()]]. - * - * In Java, we have to make this method public since Java does not understand Scala's protected - * modifier. - */ - @Override - public MyJavaLogisticRegressionModel copy(ParamMap extra) { - return copyValues(new MyJavaLogisticRegressionModel(uid(), coefficients_), extra) - .setParent(parent()); - } -} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java index 37de9cf3596a9..d2e70c23babc7 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaElementwiseProductExample.java @@ -17,21 +17,18 @@ package org.apache.spark.examples.ml; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.Dataset; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; // $example on$ import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import org.apache.spark.api.java.JavaRDD; import org.apache.spark.ml.feature.ElementwiseProduct; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.ml.linalg.Vector; +import org.apache.spark.ml.linalg.VectorUDT; +import org.apache.spark.ml.linalg.Vectors; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.types.DataTypes; @@ -41,16 +38,17 @@ public class JavaElementwiseProductExample { public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaElementwiseProductExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext sqlContext = new SQLContext(jsc); + SparkSession spark = SparkSession + .builder() + .appName("JavaElementwiseProductExample") + .getOrCreate(); // $example on$ // Create some vector data; also works for sparse vectors - JavaRDD jrdd = jsc.parallelize(Arrays.asList( + List data = Arrays.asList( RowFactory.create("a", Vectors.dense(1.0, 2.0, 3.0)), RowFactory.create("b", Vectors.dense(4.0, 5.0, 6.0)) - )); + ); List fields = new ArrayList<>(2); fields.add(DataTypes.createStructField("id", DataTypes.StringType, false)); @@ -58,7 +56,7 @@ public static void main(String[] args) { StructType schema = DataTypes.createStructType(fields); - Dataset dataFrame = sqlContext.createDataFrame(jrdd, schema); + Dataset dataFrame = spark.createDataFrame(data, schema); Vector transformingVector = Vectors.dense(0.0, 1.0, 2.0); @@ -70,6 +68,6 @@ public static void main(String[] args) { // Batch transform the vectors to create new column: transformer.transform(dataFrame).show(); // $example off$ - jsc.stop(); + spark.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaEstimatorTransformerParamExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaEstimatorTransformerParamExample.java index 604b193dd489b..9e07a0c2f899a 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaEstimatorTransformerParamExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaEstimatorTransformerParamExample.java @@ -19,42 +19,46 @@ // $example on$ import java.util.Arrays; -// $example off$ +import java.util.List; -import org.apache.spark.SparkConf; -import org.apache.spark.SparkContext; -// 
$example on$ import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.ml.classification.LogisticRegressionModel; +import org.apache.spark.ml.linalg.VectorUDT; +import org.apache.spark.ml.linalg.Vectors; import org.apache.spark.ml.param.ParamMap; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; // $example off$ -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; /** * Java example for Estimator, Transformer, and Param. */ public class JavaEstimatorTransformerParamExample { public static void main(String[] args) { - SparkConf conf = new SparkConf() - .setAppName("JavaEstimatorTransformerParamExample"); - SparkContext sc = new SparkContext(conf); - SQLContext sqlContext = new SQLContext(sc); + SparkSession spark = SparkSession + .builder() + .appName("JavaEstimatorTransformerParamExample") + .getOrCreate(); // $example on$ // Prepare training data. - // We use LabeledPoint, which is a JavaBean. Spark SQL can convert RDDs of JavaBeans into - // DataFrames, where it uses the bean metadata to infer the schema. - Dataset training = sqlContext.createDataFrame( - Arrays.asList( - new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)), - new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), - new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)), - new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5)) - ), LabeledPoint.class); + List dataTraining = Arrays.asList( + RowFactory.create(1.0, Vectors.dense(0.0, 1.1, 0.1)), + RowFactory.create(0.0, Vectors.dense(2.0, 1.0, -1.0)), + RowFactory.create(0.0, Vectors.dense(2.0, 1.3, 1.0)), + RowFactory.create(1.0, Vectors.dense(0.0, 1.2, -0.5)) + ); + StructType schema = new StructType(new StructField[]{ + new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("features", new VectorUDT(), false, Metadata.empty()) + }); + Dataset training = spark.createDataFrame(dataTraining, schema); // Create a LogisticRegression instance. This instance is an Estimator. LogisticRegression lr = new LogisticRegression(); @@ -89,11 +93,12 @@ public static void main(String[] args) { System.out.println("Model 2 was fit using parameters: " + model2.parent().extractParamMap()); // Prepare test documents. - Dataset test = sqlContext.createDataFrame(Arrays.asList( - new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), - new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), - new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)) - ), LabeledPoint.class); + List dataTest = Arrays.asList( + RowFactory.create(1.0, Vectors.dense(-1.0, 1.5, 1.3)), + RowFactory.create(0.0, Vectors.dense(3.0, 2.0, -0.1)), + RowFactory.create(1.0, Vectors.dense(0.0, 2.2, -1.5)) + ); + Dataset test = spark.createDataFrame(dataTest, schema); // Make predictions on test documents using the Transformer.transform() method. // LogisticRegression.transform will only use the 'features' column. 
@@ -107,6 +112,6 @@ public static void main(String[] args) { } // $example off$ - sc.stop(); + spark.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaGaussianMixtureExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaGaussianMixtureExample.java new file mode 100644 index 0000000000000..526bed93fbd24 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaGaussianMixtureExample.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +// $example on$ +import org.apache.spark.ml.clustering.GaussianMixture; +import org.apache.spark.ml.clustering.GaussianMixtureModel; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +// $example off$ +import org.apache.spark.sql.SparkSession; + + +/** + * An example demonstrating Gaussian Mixture Model. + * Run with + *
+ * <pre>
+ * bin/run-example ml.JavaGaussianMixtureExample
+ * </pre>
    + */ +public class JavaGaussianMixtureExample { + + public static void main(String[] args) { + + // Creates a SparkSession + SparkSession spark = SparkSession + .builder() + .appName("JavaGaussianMixtureExample") + .getOrCreate(); + + // $example on$ + // Loads data + Dataset dataset = spark.read().format("libsvm").load("data/mllib/sample_kmeans_data.txt"); + + // Trains a GaussianMixture model + GaussianMixture gmm = new GaussianMixture() + .setK(2); + GaussianMixtureModel model = gmm.fit(dataset); + + // Output the parameters of the mixture model + for (int i = 0; i < model.getK(); i++) { + System.out.printf("weight=%f\nmu=%s\nsigma=\n%s\n", + model.weights()[i], model.gaussians()[i].mean(), model.gaussians()[i].cov()); + } + // $example off$ + + spark.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaGeneralizedLinearRegressionExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaGeneralizedLinearRegressionExample.java new file mode 100644 index 0000000000000..3f072d1e50eb5 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaGeneralizedLinearRegressionExample.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +// $example on$ +import java.util.Arrays; + +import org.apache.spark.ml.regression.GeneralizedLinearRegression; +import org.apache.spark.ml.regression.GeneralizedLinearRegressionModel; +import org.apache.spark.ml.regression.GeneralizedLinearRegressionTrainingSummary; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +// $example off$ +import org.apache.spark.sql.SparkSession; + +/** + * An example demonstrating generalized linear regression. + * Run with + *
+ * <pre>
+ * bin/run-example ml.JavaGeneralizedLinearRegressionExample
+ * </pre>
    + */ + +public class JavaGeneralizedLinearRegressionExample { + + public static void main(String[] args) { + SparkSession spark = SparkSession + .builder() + .appName("JavaGeneralizedLinearRegressionExample") + .getOrCreate(); + + // $example on$ + // Load training data + Dataset dataset = spark.read().format("libsvm") + .load("data/mllib/sample_linear_regression_data.txt"); + + GeneralizedLinearRegression glr = new GeneralizedLinearRegression() + .setFamily("gaussian") + .setLink("identity") + .setMaxIter(10) + .setRegParam(0.3); + + // Fit the model + GeneralizedLinearRegressionModel model = glr.fit(dataset); + + // Print the coefficients and intercept for generalized linear regression model + System.out.println("Coefficients: " + model.coefficients()); + System.out.println("Intercept: " + model.intercept()); + + // Summarize the model over the training set and print out some metrics + GeneralizedLinearRegressionTrainingSummary summary = model.summary(); + System.out.println("Coefficient Standard Errors: " + + Arrays.toString(summary.coefficientStandardErrors())); + System.out.println("T Values: " + Arrays.toString(summary.tValues())); + System.out.println("P Values: " + Arrays.toString(summary.pValues())); + System.out.println("Dispersion: " + summary.dispersion()); + System.out.println("Null Deviance: " + summary.nullDeviance()); + System.out.println("Residual Degree Of Freedom Null: " + summary.residualDegreeOfFreedomNull()); + System.out.println("Deviance: " + summary.deviance()); + System.out.println("Residual Degree Of Freedom: " + summary.residualDegreeOfFreedom()); + System.out.println("AIC: " + summary.aic()); + System.out.println("Deviance Residuals: "); + summary.residuals().show(); + // $example off$ + + spark.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaGradientBoostedTreeClassifierExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaGradientBoostedTreeClassifierExample.java index 553070dace882..3e9eb998c8e1f 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaGradientBoostedTreeClassifierExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaGradientBoostedTreeClassifierExample.java @@ -17,8 +17,6 @@ package org.apache.spark.examples.ml; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; // $example on$ import org.apache.spark.ml.Pipeline; import org.apache.spark.ml.PipelineModel; @@ -29,18 +27,21 @@ import org.apache.spark.ml.feature.*; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; // $example off$ public class JavaGradientBoostedTreeClassifierExample { public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaGradientBoostedTreeClassifierExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext sqlContext = new SQLContext(jsc); + SparkSession spark = SparkSession + .builder() + .appName("JavaGradientBoostedTreeClassifierExample") + .getOrCreate(); // $example on$ // Load and parse the data file, converting it to a DataFrame. - Dataset data = sqlContext.read().format("libsvm") + Dataset data = spark + .read() + .format("libsvm") .load("data/mllib/sample_libsvm_data.txt"); // Index labels, adding metadata to the label column. 
@@ -74,11 +75,11 @@ public static void main(String[] args) { .setOutputCol("predictedLabel") .setLabels(labelIndexer.labels()); - // Chain indexers and GBT in a Pipeline + // Chain indexers and GBT in a Pipeline. Pipeline pipeline = new Pipeline() .setStages(new PipelineStage[] {labelIndexer, featureIndexer, gbt, labelConverter}); - // Train model. This also runs the indexers. + // Train model. This also runs the indexers. PipelineModel model = pipeline.fit(trainingData); // Make predictions. @@ -87,11 +88,11 @@ public static void main(String[] args) { // Select example rows to display. predictions.select("predictedLabel", "label", "features").show(5); - // Select (prediction, true label) and compute test error + // Select (prediction, true label) and compute test error. MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator() .setLabelCol("indexedLabel") .setPredictionCol("prediction") - .setMetricName("precision"); + .setMetricName("accuracy"); double accuracy = evaluator.evaluate(predictions); System.out.println("Test Error = " + (1.0 - accuracy)); @@ -99,6 +100,6 @@ public static void main(String[] args) { System.out.println("Learned classification GBT model:\n" + gbtModel.toDebugString()); // $example off$ - jsc.stop(); + spark.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaGradientBoostedTreeRegressorExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaGradientBoostedTreeRegressorExample.java index 83fd89e3bd59b..769b5c3e85258 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaGradientBoostedTreeRegressorExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaGradientBoostedTreeRegressorExample.java @@ -17,8 +17,6 @@ package org.apache.spark.examples.ml; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; // $example on$ import org.apache.spark.ml.Pipeline; import org.apache.spark.ml.PipelineModel; @@ -30,19 +28,19 @@ import org.apache.spark.ml.regression.GBTRegressor; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; // $example off$ public class JavaGradientBoostedTreeRegressorExample { public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaGradientBoostedTreeRegressorExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext sqlContext = new SQLContext(jsc); + SparkSession spark = SparkSession + .builder() + .appName("JavaGradientBoostedTreeRegressorExample") + .getOrCreate(); // $example on$ // Load and parse the data file, converting it to a DataFrame. - Dataset data = - sqlContext.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); + Dataset data = spark.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); // Automatically identify categorical features, and index them. // Set maxCategories so features with > 4 distinct values are treated as continuous. @@ -52,7 +50,7 @@ public static void main(String[] args) { .setMaxCategories(4) .fit(data); - // Split the data into training and test sets (30% held out for testing) + // Split the data into training and test sets (30% held out for testing). 
Dataset[] splits = data.randomSplit(new double[] {0.7, 0.3}); Dataset trainingData = splits[0]; Dataset testData = splits[1]; @@ -63,10 +61,10 @@ public static void main(String[] args) { .setFeaturesCol("indexedFeatures") .setMaxIter(10); - // Chain indexer and GBT in a Pipeline + // Chain indexer and GBT in a Pipeline. Pipeline pipeline = new Pipeline().setStages(new PipelineStage[] {featureIndexer, gbt}); - // Train model. This also runs the indexer. + // Train model. This also runs the indexer. PipelineModel model = pipeline.fit(trainingData); // Make predictions. @@ -75,7 +73,7 @@ public static void main(String[] args) { // Select example rows to display. predictions.select("prediction", "label", "features").show(5); - // Select (prediction, true label) and compute test error + // Select (prediction, true label) and compute test error. RegressionEvaluator evaluator = new RegressionEvaluator() .setLabelCol("label") .setPredictionCol("prediction") @@ -87,6 +85,6 @@ public static void main(String[] args) { System.out.println("Learned regression GBT model:\n" + gbtModel.toDebugString()); // $example off$ - jsc.stop(); + spark.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaIndexToStringExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaIndexToStringExample.java index 9b8c22f3bdfde..0064beb8c8f33 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaIndexToStringExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaIndexToStringExample.java @@ -17,14 +17,12 @@ package org.apache.spark.examples.ml; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.Dataset; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; // $example on$ import java.util.Arrays; +import java.util.List; import org.apache.spark.ml.feature.IndexToString; import org.apache.spark.ml.feature.StringIndexer; @@ -39,24 +37,25 @@ public class JavaIndexToStringExample { public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaIndexToStringExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext sqlContext = new SQLContext(jsc); + SparkSession spark = SparkSession + .builder() + .appName("JavaIndexToStringExample") + .getOrCreate(); // $example on$ - JavaRDD jrdd = jsc.parallelize(Arrays.asList( + List data = Arrays.asList( RowFactory.create(0, "a"), RowFactory.create(1, "b"), RowFactory.create(2, "c"), RowFactory.create(3, "a"), RowFactory.create(4, "a"), RowFactory.create(5, "c") - )); + ); StructType schema = new StructType(new StructField[]{ new StructField("id", DataTypes.IntegerType, false, Metadata.empty()), new StructField("category", DataTypes.StringType, false, Metadata.empty()) }); - Dataset df = sqlContext.createDataFrame(jrdd, schema); + Dataset df = spark.createDataFrame(data, schema); StringIndexerModel indexer = new StringIndexer() .setInputCol("category") @@ -70,6 +69,6 @@ public static void main(String[] args) { Dataset converted = converter.transform(indexed); converted.select("id", "originalCategory").show(); // $example off$ - jsc.stop(); + spark.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaIsotonicRegressionExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaIsotonicRegressionExample.java new file mode 100644 index 0000000000000..0ec17b0471553 --- /dev/null +++ 
b/examples/src/main/java/org/apache/spark/examples/ml/JavaIsotonicRegressionExample.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.examples.ml; + +// $example on$ + +import org.apache.spark.ml.regression.IsotonicRegression; +import org.apache.spark.ml.regression.IsotonicRegressionModel; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +// $example off$ +import org.apache.spark.sql.SparkSession; + +/** + * An example demonstrating IsotonicRegression. + * Run with + *
    + * bin/run-example ml.JavaIsotonicRegressionExample
+ * </pre>
    + */ +public class JavaIsotonicRegressionExample { + + public static void main(String[] args) { + // Create a SparkSession. + SparkSession spark = SparkSession + .builder() + .appName("JavaIsotonicRegressionExample") + .getOrCreate(); + + // $example on$ + // Loads data. + Dataset dataset = spark.read().format("libsvm") + .load("data/mllib/sample_isotonic_regression_libsvm_data.txt"); + + // Trains an isotonic regression model. + IsotonicRegression ir = new IsotonicRegression(); + IsotonicRegressionModel model = ir.fit(dataset); + + System.out.println("Boundaries in increasing order: " + model.boundaries()); + System.out.println("Predictions associated with the boundaries: " + model.predictions()); + + // Makes predictions. + model.transform(dataset).show(); + // $example off$ + + spark.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java index c5022f4c0b8fe..d8f948ae38cb3 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java @@ -17,78 +17,45 @@ package org.apache.spark.examples.ml; -import java.util.regex.Pattern; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.sql.Dataset; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.catalyst.expressions.GenericRow; // $example on$ import org.apache.spark.ml.clustering.KMeansModel; import org.apache.spark.ml.clustering.KMeans; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.ml.linalg.Vector; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; // $example off$ +import org.apache.spark.sql.SparkSession; /** - * An example demonstrating a k-means clustering. + * An example demonstrating k-means clustering. * Run with *
    - * bin/run-example ml.JavaKMeansExample  
    + * bin/run-example ml.JavaKMeansExample
  * </pre>
    */ public class JavaKMeansExample { - private static class ParsePoint implements Function { - private static final Pattern separator = Pattern.compile(" "); - - @Override - public Row call(String line) { - String[] tok = separator.split(line); - double[] point = new double[tok.length]; - for (int i = 0; i < tok.length; ++i) { - point[i] = Double.parseDouble(tok[i]); - } - Vector[] points = {Vectors.dense(point)}; - return new GenericRow(points); - } - } - public static void main(String[] args) { - if (args.length != 2) { - System.err.println("Usage: ml.JavaKMeansExample "); - System.exit(1); - } - String inputFile = args[0]; - int k = Integer.parseInt(args[1]); - - // Parses the arguments - SparkConf conf = new SparkConf().setAppName("JavaKMeansExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext sqlContext = new SQLContext(jsc); + // Create a SparkSession. + SparkSession spark = SparkSession + .builder() + .appName("JavaKMeansExample") + .getOrCreate(); // $example on$ - // Loads data - JavaRDD points = jsc.textFile(inputFile).map(new ParsePoint()); - StructField[] fields = {new StructField("features", new VectorUDT(), false, Metadata.empty())}; - StructType schema = new StructType(fields); - Dataset dataset = sqlContext.createDataFrame(points, schema); + // Loads data. + Dataset dataset = spark.read().format("libsvm").load("data/mllib/sample_kmeans_data.txt"); - // Trains a k-means model - KMeans kmeans = new KMeans() - .setK(k); + // Trains a k-means model. + KMeans kmeans = new KMeans().setK(2).setSeed(1L); KMeansModel model = kmeans.fit(dataset); - // Shows the result + // Evaluate clustering by computing Within Set Sum of Squared Errors. + double WSSSE = model.computeCost(dataset); + System.out.println("Within Set Sum of Squared Errors = " + WSSSE); + + // Shows the result. Vector[] centers = model.clusterCenters(); System.out.println("Cluster Centers: "); for (Vector center: centers) { @@ -96,6 +63,6 @@ public static void main(String[] args) { } // $example off$ - jsc.stop(); + spark.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java index 351bc401180cc..9041244279079 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaLDAExample.java @@ -17,28 +17,15 @@ package org.apache.spark.examples.ml; // $example on$ -import java.util.regex.Pattern; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; import org.apache.spark.ml.clustering.LDA; import org.apache.spark.ml.clustering.LDAModel; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.catalyst.expressions.GenericRow; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.SparkSession; // $example off$ /** - * An example demonstrating LDA + * An example demonstrating LDA. * Run with *
      * bin/run-example ml.JavaLDAExample
    @@ -46,52 +33,37 @@
      */
     public class JavaLDAExample {
     
    -  // $example on$
    -  private static class ParseVector implements Function {
    -    private static final Pattern separator = Pattern.compile(" ");
    -
    -    @Override
    -    public Row call(String line) {
    -      String[] tok = separator.split(line);
    -      double[] point = new double[tok.length];
    -      for (int i = 0; i < tok.length; ++i) {
    -        point[i] = Double.parseDouble(tok[i]);
    -      }
    -      Vector[] points = {Vectors.dense(point)};
    -      return new GenericRow(points);
    -    }
    -  }
    -
       public static void main(String[] args) {
    +    // Creates a SparkSession
    +    SparkSession spark = SparkSession
    +      .builder()
    +      .appName("JavaLDAExample")
    +      .getOrCreate();
     
    -    String inputFile = "data/mllib/sample_lda_data.txt";
    +    // $example on$
    +    // Loads data.
    +    Dataset dataset = spark.read().format("libsvm")
    +      .load("data/mllib/sample_lda_libsvm_data.txt");
     
    -    // Parses the arguments
    -    SparkConf conf = new SparkConf().setAppName("JavaLDAExample");
    -    JavaSparkContext jsc = new JavaSparkContext(conf);
    -    SQLContext sqlContext = new SQLContext(jsc);
    -
    -    // Loads data
    -    JavaRDD points = jsc.textFile(inputFile).map(new ParseVector());
    -    StructField[] fields = {new StructField("features", new VectorUDT(), false, Metadata.empty())};
    -    StructType schema = new StructType(fields);
    -    Dataset dataset = sqlContext.createDataFrame(points, schema);
    -
    -    // Trains a LDA model
    -    LDA lda = new LDA()
    -      .setK(10)
    -      .setMaxIter(10);
+    // Trains an LDA model.
    +    LDA lda = new LDA().setK(10).setMaxIter(10);
         LDAModel model = lda.fit(dataset);
     
    -    System.out.println(model.logLikelihood(dataset));
    -    System.out.println(model.logPerplexity(dataset));
    +    double ll = model.logLikelihood(dataset);
    +    double lp = model.logPerplexity(dataset);
    +    System.out.println("The lower bound on the log likelihood of the entire corpus: " + ll);
+    System.out.println("The upper bound on perplexity: " + lp);
     
    -    // Shows the result
    +    // Describe topics.
         Dataset topics = model.describeTopics(3);
    +    System.out.println("The topics described by their top-weighted terms:");
         topics.show(false);
    -    model.transform(dataset).show(false);
     
    -    jsc.stop();
    +    // Shows the result.
    +    Dataset transformed = model.transform(dataset);
    +    transformed.show(false);
    +    // $example off$
    +
    +    spark.stop();
       }
    -  // $example off$
     }
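
Across these files the bootstrap change is always the same: the pre-2.0 `SparkConf` + `JavaSparkContext` + `SQLContext` trio is replaced by the unified `SparkSession` entry point. Below is a minimal sketch of the skeleton the refactored examples share; the class name and data path are illustrative, not part of this patch.

```java
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

// Hypothetical skeleton showing the SparkSession-based bootstrap used above.
public class SparkSessionSkeletonExample {
  public static void main(String[] args) {
    // SparkSession replaces SparkConf, JavaSparkContext and SQLContext.
    SparkSession spark = SparkSession
      .builder()
      .appName("SparkSessionSkeletonExample")
      .getOrCreate();

    // Readers hang off the session directly, e.g. the libsvm data source
    // used throughout these examples.
    Dataset<Row> data = spark.read().format("libsvm")
      .load("data/mllib/sample_libsvm_data.txt");
    data.show(5);

    // spark.stop() replaces jsc.stop() / sc.stop().
    spark.stop();
  }
}
```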
    diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java
    index 08fce89359fc5..a561b6d39ba83 100644
    --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java
    +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java
    @@ -17,27 +17,26 @@
     
     package org.apache.spark.examples.ml;
     
    -import org.apache.spark.SparkConf;
    -import org.apache.spark.api.java.JavaSparkContext;
     // $example on$
     import org.apache.spark.ml.regression.LinearRegression;
     import org.apache.spark.ml.regression.LinearRegressionModel;
     import org.apache.spark.ml.regression.LinearRegressionTrainingSummary;
    -import org.apache.spark.mllib.linalg.Vectors;
    +import org.apache.spark.ml.linalg.Vectors;
     import org.apache.spark.sql.Dataset;
     import org.apache.spark.sql.Row;
    -import org.apache.spark.sql.SQLContext;
    +import org.apache.spark.sql.SparkSession;
     // $example off$
     
     public class JavaLinearRegressionWithElasticNetExample {
       public static void main(String[] args) {
    -    SparkConf conf = new SparkConf().setAppName("JavaLinearRegressionWithElasticNetExample");
    -    JavaSparkContext jsc = new JavaSparkContext(conf);
    -    SQLContext sqlContext = new SQLContext(jsc);
    +    SparkSession spark = SparkSession
    +      .builder()
    +      .appName("JavaLinearRegressionWithElasticNetExample")
    +      .getOrCreate();
     
         // $example on$
    -    // Load training data
    -    Dataset training = sqlContext.read().format("libsvm")
    +    // Load training data.
    +    Dataset training = spark.read().format("libsvm")
           .load("data/mllib/sample_linear_regression_data.txt");
     
         LinearRegression lr = new LinearRegression()
    @@ -45,14 +44,14 @@ public static void main(String[] args) {
           .setRegParam(0.3)
           .setElasticNetParam(0.8);
     
    -    // Fit the model
    +    // Fit the model.
         LinearRegressionModel lrModel = lr.fit(training);
     
    -    // Print the coefficients and intercept for linear regression
    +    // Print the coefficients and intercept for linear regression.
         System.out.println("Coefficients: "
           + lrModel.coefficients() + " Intercept: " + lrModel.intercept());
     
    -    // Summarize the model over the training set and print out some metrics
    +    // Summarize the model over the training set and print out some metrics.
         LinearRegressionTrainingSummary trainingSummary = lrModel.summary();
         System.out.println("numIterations: " + trainingSummary.totalIterations());
         System.out.println("objectiveHistory: " + Vectors.dense(trainingSummary.objectiveHistory()));
    @@ -61,6 +60,6 @@ public static void main(String[] args) {
         System.out.println("r2: " + trainingSummary.r2());
         // $example off$
     
    -    jsc.stop();
    +    spark.stop();
       }
     }
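
One more recurring detail: the vector imports move from `org.apache.spark.mllib.linalg` to `org.apache.spark.ml.linalg`, the DataFrame-based API's own linear-algebra package. A small hedged sketch of the call the summary printout above relies on; the history values are made up for illustration.

```java
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors;

// Illustrative only: the ml.linalg Vectors factory used above.
public class MlLinalgSketch {
  public static void main(String[] args) {
    double[] objectiveHistory = {0.49, 0.42, 0.37};  // made-up values
    Vector v = Vectors.dense(objectiveHistory);
    System.out.println("objectiveHistory: " + v);
  }
}
```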
    diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java
    index 73b028fb44409..dee56799d8aee 100644
    --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java
    +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java
    @@ -17,8 +17,6 @@
     
     package org.apache.spark.examples.ml;
     
    -import org.apache.spark.SparkConf;
    -import org.apache.spark.api.java.JavaSparkContext;
     // $example on$
     import org.apache.spark.ml.classification.BinaryLogisticRegressionSummary;
     import org.apache.spark.ml.classification.LogisticRegression;
    @@ -26,18 +24,19 @@
     import org.apache.spark.ml.classification.LogisticRegressionTrainingSummary;
     import org.apache.spark.sql.Dataset;
     import org.apache.spark.sql.Row;
    -import org.apache.spark.sql.SQLContext;
    +import org.apache.spark.sql.SparkSession;
     import org.apache.spark.sql.functions;
     // $example off$
     
     public class JavaLogisticRegressionSummaryExample {
       public static void main(String[] args) {
    -    SparkConf conf = new SparkConf().setAppName("JavaLogisticRegressionSummaryExample");
    -    JavaSparkContext jsc = new JavaSparkContext(conf);
    -    SQLContext sqlContext = new SQLContext(jsc);
    +    SparkSession spark = SparkSession
    +      .builder()
    +      .appName("JavaLogisticRegressionSummaryExample")
    +      .getOrCreate();
     
         // Load training data
    -    Dataset training = sqlContext.read().format("libsvm")
    +    Dataset training = spark.read().format("libsvm")
           .load("data/mllib/sample_libsvm_data.txt");
     
         LogisticRegression lr = new LogisticRegression()
    @@ -80,6 +79,6 @@ public static void main(String[] args) {
         lrModel.setThreshold(bestThreshold);
         // $example off$
     
    -    jsc.stop();
    +    spark.stop();
       }
     }
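
The summary example above ends by picking a decision threshold from the training summary, though the hunk only shows the surrounding lines. A hedged sketch of that selection step using the 2.0-era binary summary API; `lrModel` is assumed to be an already fitted binary `LogisticRegressionModel`.

```java
import org.apache.spark.ml.classification.BinaryLogisticRegressionSummary;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.functions;

// Illustrative only: selects the threshold that maximizes F-measure.
public class ThresholdSelectionSketch {
  static void pickThreshold(LogisticRegressionModel lrModel) {
    // For binary problems the training summary can be cast to the binary variant.
    BinaryLogisticRegressionSummary binarySummary =
      (BinaryLogisticRegressionSummary) lrModel.summary();

    // F-measure per candidate threshold, then keep the best one.
    Dataset<Row> fMeasure = binarySummary.fMeasureByThreshold();
    double maxFMeasure = fMeasure.select(functions.max("F-Measure")).head().getDouble(0);
    double bestThreshold = fMeasure
      .where(fMeasure.col("F-Measure").equalTo(maxFMeasure))
      .select("threshold")
      .head()
      .getDouble(0);
    lrModel.setThreshold(bestThreshold);
  }
}
```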
    diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionWithElasticNetExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionWithElasticNetExample.java
    index 691166852206c..6101c79fb0c98 100644
    --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionWithElasticNetExample.java
    +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionWithElasticNetExample.java
    @@ -17,25 +17,24 @@
     
     package org.apache.spark.examples.ml;
     
    -import org.apache.spark.SparkConf;
    -import org.apache.spark.api.java.JavaSparkContext;
     // $example on$
     import org.apache.spark.ml.classification.LogisticRegression;
     import org.apache.spark.ml.classification.LogisticRegressionModel;
     import org.apache.spark.sql.Dataset;
     import org.apache.spark.sql.Row;
    -import org.apache.spark.sql.SQLContext;
    +import org.apache.spark.sql.SparkSession;
     // $example off$
     
     public class JavaLogisticRegressionWithElasticNetExample {
       public static void main(String[] args) {
    -    SparkConf conf = new SparkConf().setAppName("JavaLogisticRegressionWithElasticNetExample");
    -    JavaSparkContext jsc = new JavaSparkContext(conf);
    -    SQLContext sqlContext = new SQLContext(jsc);
    +    SparkSession spark = SparkSession
    +      .builder()
    +      .appName("JavaLogisticRegressionWithElasticNetExample")
    +      .getOrCreate();
     
         // $example on$
         // Load training data
    -    Dataset training = sqlContext.read().format("libsvm")
    +    Dataset training = spark.read().format("libsvm")
           .load("data/mllib/sample_libsvm_data.txt");
     
         LogisticRegression lr = new LogisticRegression()
    @@ -51,6 +50,6 @@ public static void main(String[] args) {
           + lrModel.coefficients() + " Intercept: " + lrModel.intercept());
         // $example off$
     
    -    jsc.stop();
    +    spark.stop();
       }
     }
    diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaMaxAbsScalerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaMaxAbsScalerExample.java
    index a2a072b253f39..9a27b0e9e23b7 100644
    --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaMaxAbsScalerExample.java
    +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaMaxAbsScalerExample.java
    @@ -17,28 +17,30 @@
     
     package org.apache.spark.examples.ml;
     
    -import org.apache.spark.SparkConf;
    -import org.apache.spark.api.java.JavaSparkContext;
     // $example on$
     import org.apache.spark.ml.feature.MaxAbsScaler;
     import org.apache.spark.ml.feature.MaxAbsScalerModel;
     import org.apache.spark.sql.Dataset;
     import org.apache.spark.sql.Row;
     // $example off$
    -import org.apache.spark.sql.SQLContext;
    +import org.apache.spark.sql.SparkSession;
     
     public class JavaMaxAbsScalerExample {
     
       public static void main(String[] args) {
    -    SparkConf conf = new SparkConf().setAppName("JavaMaxAbsScalerExample");
    -    JavaSparkContext jsc = new JavaSparkContext(conf);
    -    SQLContext jsql = new SQLContext(jsc);
    +    SparkSession spark = SparkSession
    +      .builder()
    +      .appName("JavaMaxAbsScalerExample")
    +      .getOrCreate();
     
         // $example on$
    -    Dataset dataFrame = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
    +    Dataset dataFrame = spark
    +      .read()
    +      .format("libsvm")
    +      .load("data/mllib/sample_libsvm_data.txt");
         MaxAbsScaler scaler = new MaxAbsScaler()
    -        .setInputCol("features")
    -        .setOutputCol("scaledFeatures");
    +      .setInputCol("features")
    +      .setOutputCol("scaledFeatures");
     
         // Compute summary statistics and generate MaxAbsScalerModel
         MaxAbsScalerModel scalerModel = scaler.fit(dataFrame);
    @@ -47,7 +49,7 @@ public static void main(String[] args) {
         Dataset scaledData = scalerModel.transform(dataFrame);
         scaledData.show();
         // $example off$
    -    jsc.stop();
    +    spark.stop();
       }
     
     }
    diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java
    index 4aee18eeabfcf..37fa1c5434ea6 100644
    --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java
    +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java
    @@ -17,9 +17,7 @@
     
     package org.apache.spark.examples.ml;
     
    -import org.apache.spark.SparkConf;
    -import org.apache.spark.api.java.JavaSparkContext;
    -import org.apache.spark.sql.SQLContext;
    +import org.apache.spark.sql.SparkSession;
     
     // $example on$
     import org.apache.spark.ml.feature.MinMaxScaler;
    @@ -30,12 +28,16 @@
     
     public class JavaMinMaxScalerExample {
       public static void main(String[] args) {
    -    SparkConf conf = new SparkConf().setAppName("JaveMinMaxScalerExample");
    -    JavaSparkContext jsc = new JavaSparkContext(conf);
    -    SQLContext jsql = new SQLContext(jsc);
    +    SparkSession spark = SparkSession
    +      .builder()
    +      .appName("JavaMinMaxScalerExample")
    +      .getOrCreate();
     
         // $example on$
    -    Dataset dataFrame = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
    +    Dataset dataFrame = spark
    +      .read()
    +      .format("libsvm")
    +      .load("data/mllib/sample_libsvm_data.txt");
         MinMaxScaler scaler = new MinMaxScaler()
           .setInputCol("features")
           .setOutputCol("scaledFeatures");
    @@ -47,6 +49,6 @@ public static void main(String[] args) {
         Dataset scaledData = scalerModel.transform(dataFrame);
         scaledData.show();
         // $example off$
    -    jsc.stop();
    +    spark.stop();
       }
     }
    diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaCrossValidationExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaCrossValidationExample.java
    index c4122d1247a94..975c65edc0ca6 100644
    --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaCrossValidationExample.java
    +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaCrossValidationExample.java
    @@ -21,8 +21,6 @@
     import java.util.Arrays;
     // $example off$
     
    -import org.apache.spark.SparkConf;
    -import org.apache.spark.SparkContext;
     // $example on$
     import org.apache.spark.ml.Pipeline;
     import org.apache.spark.ml.PipelineStage;
    @@ -37,21 +35,21 @@
     import org.apache.spark.sql.Dataset;
     import org.apache.spark.sql.Row;
     // $example off$
    -import org.apache.spark.sql.SQLContext;
    +import org.apache.spark.sql.SparkSession;
     
     /**
      * Java example for Model Selection via Cross Validation.
      */
     public class JavaModelSelectionViaCrossValidationExample {
       public static void main(String[] args) {
    -    SparkConf conf = new SparkConf()
    -      .setAppName("JavaModelSelectionViaCrossValidationExample");
    -    SparkContext sc = new SparkContext(conf);
    -    SQLContext sqlContext = new SQLContext(sc);
    +    SparkSession spark = SparkSession
    +      .builder()
    +      .appName("JavaModelSelectionViaCrossValidationExample")
    +      .getOrCreate();
     
         // $example on$
         // Prepare training documents, which are labeled.
    -    Dataset training = sqlContext.createDataFrame(Arrays.asList(
    +    Dataset training = spark.createDataFrame(Arrays.asList(
           new JavaLabeledDocument(0L, "a b c d e spark", 1.0),
           new JavaLabeledDocument(1L, "b d", 0.0),
           new JavaLabeledDocument(2L,"spark f g h", 1.0),
    @@ -102,7 +100,7 @@ public static void main(String[] args) {
         CrossValidatorModel cvModel = cv.fit(training);
     
         // Prepare test documents, which are unlabeled.
    -    Dataset test = sqlContext.createDataFrame(Arrays.asList(
    +    Dataset test = spark.createDataFrame(Arrays.asList(
           new JavaDocument(4L, "spark i j k"),
           new JavaDocument(5L, "l m n"),
           new JavaDocument(6L, "mapreduce spark"),
    @@ -117,6 +115,6 @@ public static void main(String[] args) {
         }
         // $example off$
     
    -    sc.stop();
    +    spark.stop();
       }
     }
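
Unlike the feature examples, the cross-validation example builds its DataFrames from Java beans (`JavaLabeledDocument`, `JavaDocument`) rather than `Row`s. `SparkSession.createDataFrame` also accepts a list of beans plus the bean class and infers the schema from the getters; the sketch below uses a hypothetical stand-in bean, since the real helper classes live in separate files.

```java
import java.io.Serializable;
import java.util.Arrays;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

// Illustrative only: bean-based DataFrame creation, as used above with JavaLabeledDocument.
public class BeanDataFrameSketch {
  // Hypothetical stand-in for JavaLabeledDocument.
  public static class LabeledDoc implements Serializable {
    private long id;
    private String text;
    private double label;

    public LabeledDoc(long id, String text, double label) {
      this.id = id;
      this.text = text;
      this.label = label;
    }

    public long getId() { return id; }
    public String getText() { return text; }
    public double getLabel() { return label; }
  }

  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder().appName("BeanDataFrameSketch").getOrCreate();

    // Schema (id, text, label) is inferred from the bean's getters.
    Dataset<Row> training = spark.createDataFrame(Arrays.asList(
      new LabeledDoc(0L, "a b c d e spark", 1.0),
      new LabeledDoc(1L, "b d", 0.0)
    ), LabeledDoc.class);

    training.show();
    spark.stop();
  }
}
```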
    diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaTrainValidationSplitExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaTrainValidationSplitExample.java
    index 4994f8f9fa857..0f96293f0348b 100644
    --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaTrainValidationSplitExample.java
    +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaTrainValidationSplitExample.java
    @@ -17,8 +17,6 @@
     
     package org.apache.spark.examples.ml;
     
    -import org.apache.spark.SparkConf;
    -import org.apache.spark.SparkContext;
     // $example on$
     import org.apache.spark.ml.evaluation.RegressionEvaluator;
     import org.apache.spark.ml.param.ParamMap;
    @@ -29,7 +27,7 @@
     import org.apache.spark.sql.Dataset;
     import org.apache.spark.sql.Row;
     // $example off$
    -import org.apache.spark.sql.SQLContext;
    +import org.apache.spark.sql.SparkSession;
     
     /**
      * Java example demonstrating model selection using TrainValidationSplit.
    @@ -44,13 +42,13 @@
      */
     public class JavaModelSelectionViaTrainValidationSplitExample {
       public static void main(String[] args) {
    -    SparkConf conf = new SparkConf()
    -      .setAppName("JavaModelSelectionViaTrainValidationSplitExample");
    -    SparkContext sc = new SparkContext(conf);
    -    SQLContext jsql = new SQLContext(sc);
    +    SparkSession spark = SparkSession
    +      .builder()
    +      .appName("JavaModelSelectionViaTrainValidationSplitExample")
    +      .getOrCreate();
     
         // $example on$
    -    Dataset data = jsql.read().format("libsvm")
    +    Dataset data = spark.read().format("libsvm")
           .load("data/mllib/sample_linear_regression_data.txt");
     
         // Prepare training and test data.
    @@ -87,6 +85,6 @@ public static void main(String[] args) {
           .show();
         // $example off$
     
    -    sc.stop();
    +    spark.stop();
       }
     }
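
The hunk above only touches the bootstrap of the train-validation-split example, so for context here is a hedged, generic sketch of what the file wires together: an estimator, an evaluator, a parameter grid, and a `TrainValidationSplit` with a training ratio. Parameter values are illustrative, not the example's.

```java
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.regression.LinearRegression;
import org.apache.spark.ml.tuning.ParamGridBuilder;
import org.apache.spark.ml.tuning.TrainValidationSplit;
import org.apache.spark.ml.tuning.TrainValidationSplitModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

// Illustrative only: model selection with TrainValidationSplit.
public class TrainValidationSplitSketch {
  static TrainValidationSplitModel tune(Dataset<Row> training) {
    LinearRegression lr = new LinearRegression();

    // Candidate parameter settings; values here are made up.
    ParamMap[] paramGrid = new ParamGridBuilder()
      .addGrid(lr.regParam(), new double[] {0.1, 0.01})
      .addGrid(lr.fitIntercept())
      .build();

    // 80% of the rows train each candidate, 20% validate it.
    TrainValidationSplit trainValidationSplit = new TrainValidationSplit()
      .setEstimator(lr)
      .setEvaluator(new RegressionEvaluator())
      .setEstimatorParamMaps(paramGrid)
      .setTrainRatio(0.8);

    return trainValidationSplit.fit(training);
  }
}
```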
    diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaMultilayerPerceptronClassifierExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaMultilayerPerceptronClassifierExample.java
    index 0ca528d8cd079..0f1d9c26345bd 100644
    --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaMultilayerPerceptronClassifierExample.java
    +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaMultilayerPerceptronClassifierExample.java
    @@ -18,11 +18,9 @@
     package org.apache.spark.examples.ml;
     
     // $example on$
    -import org.apache.spark.SparkConf;
    -import org.apache.spark.api.java.JavaSparkContext;
     import org.apache.spark.sql.Dataset;
     import org.apache.spark.sql.Row;
    -import org.apache.spark.sql.SQLContext;
    +import org.apache.spark.sql.SparkSession;
     import org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel;
     import org.apache.spark.ml.classification.MultilayerPerceptronClassifier;
     import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
    @@ -34,14 +32,15 @@
     public class JavaMultilayerPerceptronClassifierExample {
     
       public static void main(String[] args) {
    -    SparkConf conf = new SparkConf().setAppName("JavaMultilayerPerceptronClassifierExample");
    -    JavaSparkContext jsc = new JavaSparkContext(conf);
    -    SQLContext jsql = new SQLContext(jsc);
    +    SparkSession spark = SparkSession
    +      .builder()
    +      .appName("JavaMultilayerPerceptronClassifierExample")
    +      .getOrCreate();
     
         // $example on$
         // Load training data
         String path = "data/mllib/sample_multiclass_classification_data.txt";
    -    Dataset dataFrame = jsql.read().format("libsvm").load(path);
    +    Dataset dataFrame = spark.read().format("libsvm").load(path);
         // Split the data into train and test
         Dataset[] splits = dataFrame.randomSplit(new double[]{0.6, 0.4}, 1234L);
         Dataset train = splits[0];
    @@ -58,14 +57,14 @@ public static void main(String[] args) {
           .setMaxIter(100);
         // train the model
         MultilayerPerceptronClassificationModel model = trainer.fit(train);
    -    // compute precision on the test set
    +    // compute accuracy on the test set
         Dataset result = model.transform(test);
         Dataset predictionAndLabels = result.select("prediction", "label");
         MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
    -      .setMetricName("precision");
    -    System.out.println("Precision = " + evaluator.evaluate(predictionAndLabels));
    +      .setMetricName("accuracy");
    +    System.out.println("Accuracy = " + evaluator.evaluate(predictionAndLabels));
         // $example off$
     
    -    jsc.stop();
    +    spark.stop();
       }
     }
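
Alongside the session changes, the evaluator metric switches from "precision" to "accuracy" here and in the other classifier examples; in the 2.0-era `MulticlassClassificationEvaluator` the overall precision metric was superseded by "accuracy". A hedged sketch of the evaluation step in isolation; `predictionAndLabels` is assumed to hold "prediction" and "label" columns.

```java
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

// Illustrative only: accuracy-based evaluation and test error.
public class AccuracyEvaluationSketch {
  static double testError(Dataset<Row> predictionAndLabels) {
    MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("label")
      .setPredictionCol("prediction")
      .setMetricName("accuracy");
    double accuracy = evaluator.evaluate(predictionAndLabels);
    return 1.0 - accuracy;  // the "Test Error" printed by the examples above
  }
}
```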
    diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaNGramExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaNGramExample.java
    index 608bd80285655..899815f57c84b 100644
    --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaNGramExample.java
    +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaNGramExample.java
    @@ -17,15 +17,13 @@
     
     package org.apache.spark.examples.ml;
     
    -import org.apache.spark.SparkConf;
    -import org.apache.spark.api.java.JavaSparkContext;
     import org.apache.spark.sql.Dataset;
    -import org.apache.spark.sql.SQLContext;
    +import org.apache.spark.sql.SparkSession;
     
     // $example on$
     import java.util.Arrays;
    +import java.util.List;
     
    -import org.apache.spark.api.java.JavaRDD;
     import org.apache.spark.ml.feature.NGram;
     import org.apache.spark.sql.Row;
     import org.apache.spark.sql.RowFactory;
    @@ -37,16 +35,17 @@
     
     public class JavaNGramExample {
       public static void main(String[] args) {
    -    SparkConf conf = new SparkConf().setAppName("JavaNGramExample");
    -    JavaSparkContext jsc = new JavaSparkContext(conf);
    -    SQLContext sqlContext = new SQLContext(jsc);
    +    SparkSession spark = SparkSession
    +      .builder()
    +      .appName("JavaNGramExample")
    +      .getOrCreate();
     
         // $example on$
    -    JavaRDD jrdd = jsc.parallelize(Arrays.asList(
    +    List data = Arrays.asList(
           RowFactory.create(0.0, Arrays.asList("Hi", "I", "heard", "about", "Spark")),
           RowFactory.create(1.0, Arrays.asList("I", "wish", "Java", "could", "use", "case", "classes")),
           RowFactory.create(2.0, Arrays.asList("Logistic", "regression", "models", "are", "neat"))
    -    ));
    +    );
     
         StructType schema = new StructType(new StructField[]{
           new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
    @@ -54,7 +53,7 @@ public static void main(String[] args) {
             "words", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty())
         });
     
    -    Dataset wordDataFrame = sqlContext.createDataFrame(jrdd, schema);
    +    Dataset wordDataFrame = spark.createDataFrame(data, schema);
     
         NGram ngramTransformer = new NGram().setInputCol("words").setOutputCol("ngrams");
     
    @@ -66,6 +65,6 @@ public static void main(String[] args) {
           System.out.println();
         }
         // $example off$
    -    jsc.stop();
    +    spark.stop();
       }
     }
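
The other pattern these feature examples repeat is building a local DataFrame from a `java.util.List<Row>` and an explicit schema with `spark.createDataFrame(data, schema)`, instead of parallelizing a `JavaRDD` first. A minimal hedged sketch of that pattern; the column names and rows are illustrative.

```java
import java.util.Arrays;
import java.util.List;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

// Illustrative only: DataFrame from a local List<Row> plus a schema.
public class LocalRowsSketch {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder().appName("LocalRowsSketch").getOrCreate();

    List<Row> data = Arrays.asList(
      RowFactory.create(0, "a"),
      RowFactory.create(1, "b")
    );
    // Field types must match the runtime types of the Row values
    // (hence the IntegerType "id" column in the OneHotEncoder fix below).
    StructType schema = new StructType(new StructField[]{
      new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
      new StructField("category", DataTypes.StringType, false, Metadata.empty())
    });

    Dataset<Row> df = spark.createDataFrame(data, schema);
    df.show();

    spark.stop();
  }
}
```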
    diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaNaiveBayesExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaNaiveBayesExample.java
    index 41d7ad75b9d45..3226d5d2fab6f 100644
    --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaNaiveBayesExample.java
    +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaNaiveBayesExample.java
    @@ -17,16 +17,13 @@
     
     package org.apache.spark.examples.ml;
     
    -
    -import org.apache.spark.SparkConf;
    -import org.apache.spark.api.java.JavaSparkContext;
     // $example on$
     import org.apache.spark.ml.classification.NaiveBayes;
     import org.apache.spark.ml.classification.NaiveBayesModel;
     import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
     import org.apache.spark.sql.Dataset;
     import org.apache.spark.sql.Row;
    -import org.apache.spark.sql.SQLContext;
    +import org.apache.spark.sql.SparkSession;
     // $example off$
     
     /**
    @@ -35,13 +32,15 @@
     public class JavaNaiveBayesExample {
     
       public static void main(String[] args) {
    -    SparkConf conf = new SparkConf().setAppName("JavaNaiveBayesExample");
    -    JavaSparkContext jsc = new JavaSparkContext(conf);
    -    SQLContext jsql = new SQLContext(jsc);
    +    SparkSession spark = SparkSession
    +      .builder()
    +      .appName("JavaNaiveBayesExample")
    +      .getOrCreate();
     
         // $example on$
         // Load training data
    -    Dataset dataFrame = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
    +    Dataset dataFrame =
    +      spark.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
         // Split the data into train and test
         Dataset[] splits = dataFrame.randomSplit(new double[]{0.6, 0.4}, 1234L);
         Dataset train = splits[0];
    @@ -51,14 +50,14 @@ public static void main(String[] args) {
         NaiveBayes nb = new NaiveBayes();
         // train the model
         NaiveBayesModel model = nb.fit(train);
    -    // compute precision on the test set
    +    // compute accuracy on the test set
         Dataset result = model.transform(test);
         Dataset predictionAndLabels = result.select("prediction", "label");
         MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
    -      .setMetricName("precision");
    -    System.out.println("Precision = " + evaluator.evaluate(predictionAndLabels));
    +      .setMetricName("accuracy");
    +    System.out.println("Accuracy = " + evaluator.evaluate(predictionAndLabels));
         // $example off$
     
    -    jsc.stop();
    +    spark.stop();
       }
     }
    diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaNormalizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaNormalizerExample.java
    index 31cd752136689..abc38f85ea774 100644
    --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaNormalizerExample.java
    +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaNormalizerExample.java
    @@ -17,9 +17,7 @@
     
     package org.apache.spark.examples.ml;
     
    -import org.apache.spark.SparkConf;
    -import org.apache.spark.api.java.JavaSparkContext;
    -import org.apache.spark.sql.SQLContext;
    +import org.apache.spark.sql.SparkSession;
     
     // $example on$
     import org.apache.spark.ml.feature.Normalizer;
    @@ -29,12 +27,14 @@
     
     public class JavaNormalizerExample {
       public static void main(String[] args) {
    -    SparkConf conf = new SparkConf().setAppName("JavaNormalizerExample");
    -    JavaSparkContext jsc = new JavaSparkContext(conf);
    -    SQLContext jsql = new SQLContext(jsc);
    +    SparkSession spark = SparkSession
    +      .builder()
    +      .appName("JavaNormalizerExample")
    +      .getOrCreate();
     
         // $example on$
    -    Dataset dataFrame = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
    +    Dataset dataFrame =
    +      spark.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
     
         // Normalize each Vector using $L^1$ norm.
         Normalizer normalizer = new Normalizer()
    @@ -50,6 +50,6 @@ public static void main(String[] args) {
           normalizer.transform(dataFrame, normalizer.p().w(Double.POSITIVE_INFINITY));
         lInfNormData.show();
         // $example off$
    -    jsc.stop();
    +    spark.stop();
       }
     }
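
The normalizer hunk's context also shows a per-call parameter override: `normalizer.transform(dataFrame, normalizer.p().w(Double.POSITIVE_INFINITY))` applies an L^inf norm for that one call without changing the transformer's configured `p`. A hedged sketch of the two call styles; `dataFrame` is assumed to contain a "features" vector column.

```java
import org.apache.spark.ml.feature.Normalizer;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

// Illustrative only: instance-level vs. per-call parameter for Normalizer.
public class NormalizerParamSketch {
  static void normalize(Dataset<Row> dataFrame) {
    Normalizer normalizer = new Normalizer()
      .setInputCol("features")
      .setOutputCol("normFeatures")
      .setP(1.0);  // this instance normalizes with the L^1 norm

    // Uses p = 1.0 from the instance.
    Dataset<Row> l1NormData = normalizer.transform(dataFrame);
    l1NormData.show();

    // Overrides p for this call only: L^infinity norm.
    Dataset<Row> lInfNormData =
      normalizer.transform(dataFrame, normalizer.p().w(Double.POSITIVE_INFINITY));
    lInfNormData.show();
  }
}
```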
    diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java
    index 882438ca28eb7..a15e5f84a1871 100644
    --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java
    +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java
    @@ -17,14 +17,12 @@
     
     package org.apache.spark.examples.ml;
     
    -import org.apache.spark.SparkConf;
    -import org.apache.spark.api.java.JavaSparkContext;
    -import org.apache.spark.sql.SQLContext;
    +import org.apache.spark.sql.SparkSession;
     
     // $example on$
     import java.util.Arrays;
    +import java.util.List;
     
    -import org.apache.spark.api.java.JavaRDD;
     import org.apache.spark.ml.feature.OneHotEncoder;
     import org.apache.spark.ml.feature.StringIndexer;
     import org.apache.spark.ml.feature.StringIndexerModel;
    @@ -39,26 +37,27 @@
     
     public class JavaOneHotEncoderExample {
       public static void main(String[] args) {
    -    SparkConf conf = new SparkConf().setAppName("JavaOneHotEncoderExample");
    -    JavaSparkContext jsc = new JavaSparkContext(conf);
    -    SQLContext sqlContext = new SQLContext(jsc);
    +    SparkSession spark = SparkSession
    +      .builder()
    +      .appName("JavaOneHotEncoderExample")
    +      .getOrCreate();
     
         // $example on$
    -    JavaRDD jrdd = jsc.parallelize(Arrays.asList(
    +    List data = Arrays.asList(
           RowFactory.create(0, "a"),
           RowFactory.create(1, "b"),
           RowFactory.create(2, "c"),
           RowFactory.create(3, "a"),
           RowFactory.create(4, "a"),
           RowFactory.create(5, "c")
    -    ));
    +    );
     
         StructType schema = new StructType(new StructField[]{
    -      new StructField("id", DataTypes.DoubleType, false, Metadata.empty()),
    +      new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
           new StructField("category", DataTypes.StringType, false, Metadata.empty())
         });
     
    -    Dataset df = sqlContext.createDataFrame(jrdd, schema);
    +    Dataset df = spark.createDataFrame(data, schema);
     
         StringIndexerModel indexer = new StringIndexer()
           .setInputCol("category")
    @@ -72,7 +71,7 @@ public static void main(String[] args) {
         Dataset encoded = encoder.transform(indexed);
         encoded.select("id", "categoryVec").show();
         // $example off$
    -    jsc.stop();
    +    spark.stop();
       }
     }
     
    diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java
    index 1f13b48bf82ae..c6a083ddc984f 100644
    --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java
    +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java
    @@ -17,223 +17,68 @@
     
     package org.apache.spark.examples.ml;
     
    -import org.apache.commons.cli.*;
    -
    -import org.apache.spark.SparkConf;
    -import org.apache.spark.api.java.JavaSparkContext;
     // $example on$
     import org.apache.spark.ml.classification.LogisticRegression;
     import org.apache.spark.ml.classification.OneVsRest;
     import org.apache.spark.ml.classification.OneVsRestModel;
    -import org.apache.spark.ml.util.MetadataUtils;
    -import org.apache.spark.mllib.evaluation.MulticlassMetrics;
    -import org.apache.spark.mllib.linalg.Matrix;
    -import org.apache.spark.mllib.linalg.Vector;
    +import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
     import org.apache.spark.sql.Dataset;
     import org.apache.spark.sql.Row;
    -import org.apache.spark.sql.SQLContext;
    -import org.apache.spark.sql.types.StructField;
     // $example off$
    +import org.apache.spark.sql.SparkSession;
    +
     
     /**
    - * An example runner for Multiclass to Binary Reduction with One Vs Rest.
    - * The example uses Logistic Regression as the base classifier. All parameters that
    - * can be specified on the base classifier can be passed in to the runner options.
    + * An example of Multiclass to Binary Reduction with One Vs Rest,
    + * using Logistic Regression as the base classifier.
      * Run with
      * 
    - * bin/run-example ml.JavaOneVsRestExample [options]
    + * bin/run-example ml.JavaOneVsRestExample
      * 
    */ public class JavaOneVsRestExample { - - private static class Params { - String input; - String testInput = null; - Integer maxIter = 100; - double tol = 1E-6; - boolean fitIntercept = true; - Double regParam = null; - Double elasticNetParam = null; - double fracTest = 0.2; - } - public static void main(String[] args) { - // parse the arguments - Params params = parse(args); - SparkConf conf = new SparkConf().setAppName("JavaOneVsRestExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext jsql = new SQLContext(jsc); + SparkSession spark = SparkSession + .builder() + .appName("JavaOneVsRestExample") + .getOrCreate(); // $example on$ - // configure the base classifier - LogisticRegression classifier = new LogisticRegression() - .setMaxIter(params.maxIter) - .setTol(params.tol) - .setFitIntercept(params.fitIntercept); - - if (params.regParam != null) { - classifier.setRegParam(params.regParam); - } - if (params.elasticNetParam != null) { - classifier.setElasticNetParam(params.elasticNetParam); - } + // load data file. + Dataset inputData = spark.read().format("libsvm") + .load("data/mllib/sample_multiclass_classification_data.txt"); - // instantiate the One Vs Rest Classifier - OneVsRest ovr = new OneVsRest().setClassifier(classifier); + // generate the train/test split. + Dataset[] tmp = inputData.randomSplit(new double[]{0.8, 0.2}); + Dataset train = tmp[0]; + Dataset test = tmp[1]; - String input = params.input; - Dataset inputData = jsql.read().format("libsvm").load(input); - Dataset train; - Dataset test; + // configure the base classifier. + LogisticRegression classifier = new LogisticRegression() + .setMaxIter(10) + .setTol(1E-6) + .setFitIntercept(true); - // compute the train/ test split: if testInput is not provided use part of input - String testInput = params.testInput; - if (testInput != null) { - train = inputData; - // compute the number of features in the training set. - int numFeatures = inputData.first().getAs(1).size(); - test = jsql.read().format("libsvm").option("numFeatures", - String.valueOf(numFeatures)).load(testInput); - } else { - double f = params.fracTest; - Dataset[] tmp = inputData.randomSplit(new double[]{1 - f, f}, 12345); - train = tmp[0]; - test = tmp[1]; - } + // instantiate the One Vs Rest Classifier. + OneVsRest ovr = new OneVsRest().setClassifier(classifier); - // train the multiclass model - OneVsRestModel ovrModel = ovr.fit(train.cache()); + // train the multiclass model. + OneVsRestModel ovrModel = ovr.fit(train); - // score the model on test data - Dataset predictions = ovrModel.transform(test.cache()) + // score the model on test data. + Dataset predictions = ovrModel.transform(test) .select("prediction", "label"); - // obtain metrics - MulticlassMetrics metrics = new MulticlassMetrics(predictions); - StructField predictionColSchema = predictions.schema().apply("prediction"); - Integer numClasses = (Integer) MetadataUtils.getNumClasses(predictionColSchema).get(); + // obtain evaluator. 
+ MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator() + .setMetricName("accuracy"); - // compute the false positive rate per label - StringBuilder results = new StringBuilder(); - results.append("label\tfpr\n"); - for (int label = 0; label < numClasses; label++) { - results.append(label); - results.append("\t"); - results.append(metrics.falsePositiveRate((double) label)); - results.append("\n"); - } - - Matrix confusionMatrix = metrics.confusionMatrix(); - // output the Confusion Matrix - System.out.println("Confusion Matrix"); - System.out.println(confusionMatrix); - System.out.println(); - System.out.println(results); + // compute the classification error on test data. + double accuracy = evaluator.evaluate(predictions); + System.out.println("Test Error : " + (1 - accuracy)); // $example off$ - jsc.stop(); - } - - private static Params parse(String[] args) { - Options options = generateCommandlineOptions(); - CommandLineParser parser = new PosixParser(); - Params params = new Params(); - - try { - CommandLine cmd = parser.parse(options, args); - String value; - if (cmd.hasOption("input")) { - params.input = cmd.getOptionValue("input"); - } - if (cmd.hasOption("maxIter")) { - value = cmd.getOptionValue("maxIter"); - params.maxIter = Integer.parseInt(value); - } - if (cmd.hasOption("tol")) { - value = cmd.getOptionValue("tol"); - params.tol = Double.parseDouble(value); - } - if (cmd.hasOption("fitIntercept")) { - value = cmd.getOptionValue("fitIntercept"); - params.fitIntercept = Boolean.parseBoolean(value); - } - if (cmd.hasOption("regParam")) { - value = cmd.getOptionValue("regParam"); - params.regParam = Double.parseDouble(value); - } - if (cmd.hasOption("elasticNetParam")) { - value = cmd.getOptionValue("elasticNetParam"); - params.elasticNetParam = Double.parseDouble(value); - } - if (cmd.hasOption("testInput")) { - value = cmd.getOptionValue("testInput"); - params.testInput = value; - } - if (cmd.hasOption("fracTest")) { - value = cmd.getOptionValue("fracTest"); - params.fracTest = Double.parseDouble(value); - } - - } catch (ParseException e) { - printHelpAndQuit(options); - } - return params; + spark.stop(); } - @SuppressWarnings("static") - private static Options generateCommandlineOptions() { - Option input = OptionBuilder.withArgName("input") - .hasArg() - .isRequired() - .withDescription("input path to labeled examples. This path must be specified") - .create("input"); - Option testInput = OptionBuilder.withArgName("testInput") - .hasArg() - .withDescription("input path to test examples") - .create("testInput"); - Option fracTest = OptionBuilder.withArgName("testInput") - .hasArg() - .withDescription("fraction of data to hold out for testing." + - " If given option testInput, this option is ignored. default: 0.2") - .create("fracTest"); - Option maxIter = OptionBuilder.withArgName("maxIter") - .hasArg() - .withDescription("maximum number of iterations for Logistic Regression. default:100") - .create("maxIter"); - Option tol = OptionBuilder.withArgName("tol") - .hasArg() - .withDescription("the convergence tolerance of iterations " + - "for Logistic Regression. default: 1E-6") - .create("tol"); - Option fitIntercept = OptionBuilder.withArgName("fitIntercept") - .hasArg() - .withDescription("fit intercept for logistic regression. 
default true") - .create("fitIntercept"); - Option regParam = OptionBuilder.withArgName( "regParam" ) - .hasArg() - .withDescription("the regularization parameter for Logistic Regression.") - .create("regParam"); - Option elasticNetParam = OptionBuilder.withArgName("elasticNetParam" ) - .hasArg() - .withDescription("the ElasticNet mixing parameter for Logistic Regression.") - .create("elasticNetParam"); - - Options options = new Options() - .addOption(input) - .addOption(testInput) - .addOption(fracTest) - .addOption(maxIter) - .addOption(tol) - .addOption(fitIntercept) - .addOption(regParam) - .addOption(elasticNetParam); - - return options; - } - - private static void printHelpAndQuit(Options options) { - HelpFormatter formatter = new HelpFormatter(); - formatter.printHelp("JavaOneVsRestExample", options); - System.exit(-1); - } } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java index a792fd7d47cc9..d597a9a2ed0b7 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaPCAExample.java @@ -17,18 +17,16 @@ package org.apache.spark.examples.ml; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; // $example on$ import java.util.Arrays; +import java.util.List; -import org.apache.spark.api.java.JavaRDD; import org.apache.spark.ml.feature.PCA; import org.apache.spark.ml.feature.PCAModel; -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.ml.linalg.VectorUDT; +import org.apache.spark.ml.linalg.Vectors; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; @@ -39,22 +37,23 @@ public class JavaPCAExample { public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaPCAExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext jsql = new SQLContext(jsc); + SparkSession spark = SparkSession + .builder() + .appName("JavaPCAExample") + .getOrCreate(); // $example on$ - JavaRDD data = jsc.parallelize(Arrays.asList( + List data = Arrays.asList( RowFactory.create(Vectors.sparse(5, new int[]{1, 3}, new double[]{1.0, 7.0})), RowFactory.create(Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0)), RowFactory.create(Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)) - )); + ); StructType schema = new StructType(new StructField[]{ new StructField("features", new VectorUDT(), false, Metadata.empty()), }); - Dataset df = jsql.createDataFrame(data, schema); + Dataset df = spark.createDataFrame(data, schema); PCAModel pca = new PCA() .setInputCol("features") @@ -65,7 +64,7 @@ public static void main(String[] args) { Dataset result = pca.transform(df).select("pcaFeatures"); result.show(); // $example off$ - jsc.stop(); + spark.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaPipelineExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaPipelineExample.java index 305420f208b79..9a43189c91463 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaPipelineExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaPipelineExample.java @@ -19,11 +19,7 @@ // $example on$ import java.util.Arrays; -// $example off$ -import org.apache.spark.SparkConf; -import 
org.apache.spark.SparkContext; -// $example on$ import org.apache.spark.ml.Pipeline; import org.apache.spark.ml.PipelineModel; import org.apache.spark.ml.PipelineStage; @@ -33,20 +29,21 @@ import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; // $example off$ -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; /** * Java example for simple text document 'Pipeline'. */ public class JavaPipelineExample { public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaPipelineExample"); - SparkContext sc = new SparkContext(conf); - SQLContext sqlContext = new SQLContext(sc); + SparkSession spark = SparkSession + .builder() + .appName("JavaPipelineExample") + .getOrCreate(); // $example on$ // Prepare training documents, which are labeled. - Dataset training = sqlContext.createDataFrame(Arrays.asList( + Dataset training = spark.createDataFrame(Arrays.asList( new JavaLabeledDocument(0L, "a b c d e spark", 1.0), new JavaLabeledDocument(1L, "b d", 0.0), new JavaLabeledDocument(2L, "spark f g h", 1.0), @@ -71,7 +68,7 @@ public static void main(String[] args) { PipelineModel model = pipeline.fit(training); // Prepare test documents, which are unlabeled. - Dataset test = sqlContext.createDataFrame(Arrays.asList( + Dataset test = spark.createDataFrame(Arrays.asList( new JavaDocument(4L, "spark i j k"), new JavaDocument(5L, "l m n"), new JavaDocument(6L, "mapreduce spark"), @@ -86,6 +83,6 @@ public static void main(String[] args) { } // $example off$ - sc.stop(); + spark.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java index 48fc3c8acb0c0..67180df65c721 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaPolynomialExpansionExample.java @@ -17,18 +17,15 @@ package org.apache.spark.examples.ml; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; // $example on$ import java.util.Arrays; import java.util.List; -import org.apache.spark.api.java.JavaRDD; import org.apache.spark.ml.feature.PolynomialExpansion; -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.ml.linalg.VectorUDT; +import org.apache.spark.ml.linalg.Vectors; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; @@ -39,9 +36,10 @@ public class JavaPolynomialExpansionExample { public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaPolynomialExpansionExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext jsql = new SQLContext(jsc); + SparkSession spark = SparkSession + .builder() + .appName("JavaPolynomialExpansionExample") + .getOrCreate(); // $example on$ PolynomialExpansion polyExpansion = new PolynomialExpansion() @@ -49,17 +47,17 @@ public static void main(String[] args) { .setOutputCol("polyFeatures") .setDegree(3); - JavaRDD data = jsc.parallelize(Arrays.asList( + List data = Arrays.asList( RowFactory.create(Vectors.dense(-2.0, 2.3)), RowFactory.create(Vectors.dense(0.0, 0.0)), RowFactory.create(Vectors.dense(0.6, -1.1)) - )); + ); StructType schema = new StructType(new StructField[]{ new 
StructField("features", new VectorUDT(), false, Metadata.empty()), }); - Dataset df = jsql.createDataFrame(data, schema); + Dataset df = spark.createDataFrame(data, schema); Dataset polyDF = polyExpansion.transform(df); List rows = polyDF.select("polyFeatures").takeAsList(3); @@ -67,6 +65,6 @@ public static void main(String[] args) { System.out.println(r.get(0)); } // $example off$ - jsc.stop(); + spark.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java index 7b226fede9968..dd20cac621102 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java @@ -17,13 +17,11 @@ package org.apache.spark.examples.ml; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; // $example on$ import java.util.Arrays; +import java.util.List; -import org.apache.spark.api.java.JavaRDD; import org.apache.spark.ml.feature.QuantileDiscretizer; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; @@ -36,19 +34,18 @@ public class JavaQuantileDiscretizerExample { public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaQuantileDiscretizerExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext sqlContext = new SQLContext(jsc); + SparkSession spark = SparkSession + .builder() + .appName("JavaQuantileDiscretizerExample") + .getOrCreate(); // $example on$ - JavaRDD jrdd = jsc.parallelize( - Arrays.asList( - RowFactory.create(0, 18.0), - RowFactory.create(1, 19.0), - RowFactory.create(2, 8.0), - RowFactory.create(3, 5.0), - RowFactory.create(4, 2.2) - ) + List data = Arrays.asList( + RowFactory.create(0, 18.0), + RowFactory.create(1, 19.0), + RowFactory.create(2, 8.0), + RowFactory.create(3, 5.0), + RowFactory.create(4, 2.2) ); StructType schema = new StructType(new StructField[]{ @@ -56,8 +53,13 @@ public static void main(String[] args) { new StructField("hour", DataTypes.DoubleType, false, Metadata.empty()) }); - Dataset df = sqlContext.createDataFrame(jrdd, schema); - + Dataset df = spark.createDataFrame(data, schema); + // $example off$ + // Output of QuantileDiscretizer for such small datasets can depend on the number of + // partitions. Here we force a single partition to ensure consistent results. 
+ // Note this is not necessary for normal use cases + df = df.repartition(1); + // $example on$ QuantileDiscretizer discretizer = new QuantileDiscretizer() .setInputCol("hour") .setOutputCol("result") @@ -66,6 +68,6 @@ public static void main(String[] args) { Dataset result = discretizer.fit(df).transform(df); result.show(); // $example off$ - jsc.stop(); + spark.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaRFormulaExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaRFormulaExample.java index 8c453bf80d645..428067e0f7efe 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaRFormulaExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaRFormulaExample.java @@ -17,14 +17,12 @@ package org.apache.spark.examples.ml; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; // $example on$ import java.util.Arrays; +import java.util.List; -import org.apache.spark.api.java.JavaRDD; import org.apache.spark.ml.feature.RFormula; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; @@ -37,9 +35,10 @@ public class JavaRFormulaExample { public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaRFormulaExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext sqlContext = new SQLContext(jsc); + SparkSession spark = SparkSession + .builder() + .appName("JavaRFormulaExample") + .getOrCreate(); // $example on$ StructType schema = createStructType(new StructField[]{ @@ -49,13 +48,13 @@ public static void main(String[] args) { createStructField("clicked", DoubleType, false) }); - JavaRDD rdd = jsc.parallelize(Arrays.asList( + List data = Arrays.asList( RowFactory.create(7, "US", 18, 1.0), RowFactory.create(8, "CA", 12, 0.0), RowFactory.create(9, "NZ", 15, 0.0) - )); + ); - Dataset dataset = sqlContext.createDataFrame(rdd, schema); + Dataset dataset = spark.createDataFrame(data, schema); RFormula formula = new RFormula() .setFormula("clicked ~ country + hour") .setFeaturesCol("features") @@ -63,7 +62,7 @@ public static void main(String[] args) { Dataset output = formula.fit(dataset).transform(dataset); output.select("features", "label").show(); // $example off$ - jsc.stop(); + spark.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaRandomForestClassifierExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaRandomForestClassifierExample.java index 05c2bc9622e1b..da2633e8860ac 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaRandomForestClassifierExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaRandomForestClassifierExample.java @@ -17,8 +17,6 @@ package org.apache.spark.examples.ml; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; // $example on$ import org.apache.spark.ml.Pipeline; import org.apache.spark.ml.PipelineModel; @@ -29,19 +27,19 @@ import org.apache.spark.ml.feature.*; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; // $example off$ public class JavaRandomForestClassifierExample { public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaRandomForestClassifierExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext sqlContext = new 
SQLContext(jsc); + SparkSession spark = SparkSession + .builder() + .appName("JavaRandomForestClassifierExample") + .getOrCreate(); // $example on$ // Load and parse the data file, converting it to a DataFrame. - Dataset data = - sqlContext.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); + Dataset data = spark.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); // Index labels, adding metadata to the label column. // Fit on whole dataset to include all labels in index. @@ -90,7 +88,7 @@ public static void main(String[] args) { MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator() .setLabelCol("indexedLabel") .setPredictionCol("prediction") - .setMetricName("precision"); + .setMetricName("accuracy"); double accuracy = evaluator.evaluate(predictions); System.out.println("Test Error = " + (1.0 - accuracy)); @@ -98,6 +96,6 @@ public static void main(String[] args) { System.out.println("Learned classification forest model:\n" + rfModel.toDebugString()); // $example off$ - jsc.stop(); + spark.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaRandomForestRegressorExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaRandomForestRegressorExample.java index d366967083a19..a7078453deb8a 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaRandomForestRegressorExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaRandomForestRegressorExample.java @@ -17,8 +17,6 @@ package org.apache.spark.examples.ml; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; // $example on$ import org.apache.spark.ml.Pipeline; import org.apache.spark.ml.PipelineModel; @@ -30,19 +28,19 @@ import org.apache.spark.ml.regression.RandomForestRegressor; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; // $example off$ public class JavaRandomForestRegressorExample { public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaRandomForestRegressorExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext sqlContext = new SQLContext(jsc); + SparkSession spark = SparkSession + .builder() + .appName("JavaRandomForestRegressorExample") + .getOrCreate(); // $example on$ // Load and parse the data file, converting it to a DataFrame. - Dataset data = - sqlContext.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); + Dataset data = spark.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); // Automatically identify categorical features, and index them. // Set maxCategories so features with > 4 distinct values are treated as continuous. @@ -66,7 +64,7 @@ public static void main(String[] args) { Pipeline pipeline = new Pipeline() .setStages(new PipelineStage[] {featureIndexer, rf}); - // Train model. This also runs the indexer. + // Train model. This also runs the indexer. PipelineModel model = pipeline.fit(trainingData); // Make predictions. 
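The change repeated across these example files replaces the SparkConf / JavaSparkContext / SQLContext trio with the single SparkSession entry point. Below is a minimal, self-contained sketch of that pattern (the class name, app name, and input path are illustrative only, not part of the patch):

```java
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class SparkSessionSketch {
  public static void main(String[] args) {
    // One builder call replaces SparkConf + JavaSparkContext + SQLContext.
    SparkSession spark = SparkSession
      .builder()
      .appName("SparkSessionSketch")  // hypothetical app name
      .getOrCreate();

    // The session reads data sources directly; no separate SQLContext is needed.
    Dataset<Row> data = spark.read()
      .format("libsvm")
      .load("data/mllib/sample_libsvm_data.txt");
    data.show();

    // If a JavaSparkContext is still needed, it stays reachable via
    // new JavaSparkContext(spark.sparkContext()).
    spark.stop();
  }
}
```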
@@ -87,6 +85,6 @@ public static void main(String[] args) { System.out.println("Learned regression forest model:\n" + rfModel.toDebugString()); // $example off$ - jsc.stop(); + spark.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSQLTransformerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSQLTransformerExample.java index 7e3ca99d7cb93..2a3d62de41ab7 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSQLTransformerExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSQLTransformerExample.java @@ -19,36 +19,34 @@ // $example on$ import java.util.Arrays; +import java.util.List; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.feature.SQLTransformer; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.types.*; // $example off$ public class JavaSQLTransformerExample { public static void main(String[] args) { - - SparkConf conf = new SparkConf().setAppName("JavaSQLTransformerExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext sqlContext = new SQLContext(jsc); + SparkSession spark = SparkSession + .builder() + .appName("JavaSQLTransformerExample") + .getOrCreate(); // $example on$ - JavaRDD jrdd = jsc.parallelize(Arrays.asList( + List data = Arrays.asList( RowFactory.create(0, 1.0, 3.0), RowFactory.create(2, 2.0, 5.0) - )); + ); StructType schema = new StructType(new StructField [] { new StructField("id", DataTypes.IntegerType, false, Metadata.empty()), new StructField("v1", DataTypes.DoubleType, false, Metadata.empty()), new StructField("v2", DataTypes.DoubleType, false, Metadata.empty()) }); - Dataset df = sqlContext.createDataFrame(jrdd, schema); + Dataset df = spark.createDataFrame(data, schema); SQLTransformer sqlTrans = new SQLTransformer().setStatement( "SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__"); @@ -56,6 +54,6 @@ public static void main(String[] args) { sqlTrans.transform(df).show(); // $example off$ - jsc.stop(); + spark.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java index cb911ef5ef586..ca80d0d8bba57 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java @@ -21,8 +21,6 @@ import com.google.common.collect.Lists; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.classification.LogisticRegressionModel; import org.apache.spark.ml.param.ParamMap; import org.apache.spark.ml.classification.LogisticRegression; @@ -30,7 +28,7 @@ import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; /** * A simple example demonstrating ways to specify parameters for Estimators and Transformers. 
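A second pattern applied throughout: local data is no longer wrapped in a JavaRDD via jsc.parallelize(); a plain java.util.List of Rows plus an explicit schema is passed to createDataFrame() directly. A minimal sketch, with made-up column names and values:

```java
import java.util.Arrays;
import java.util.List;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class LocalRowsSketch {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder().appName("LocalRowsSketch").getOrCreate();

    // Local rows as a plain List; no jsc.parallelize() and no JavaRDD needed.
    List<Row> data = Arrays.asList(
      RowFactory.create(0, "a"),
      RowFactory.create(1, "b"),
      RowFactory.create(2, "c"));

    // An explicit schema describes the Row fields.
    StructType schema = new StructType(new StructField[]{
      new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
      new StructField("category", DataTypes.StringType, false, Metadata.empty())
    });

    // createDataFrame accepts the List of Rows and the schema directly.
    Dataset<Row> df = spark.createDataFrame(data, schema);
    df.show();

    spark.stop();
  }
}
```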
@@ -42,12 +40,13 @@ public class JavaSimpleParamsExample { public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaSimpleParamsExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext jsql = new SQLContext(jsc); + SparkSession spark = SparkSession + .builder() + .appName("JavaSimpleParamsExample") + .getOrCreate(); // Prepare training data. - // We use LabeledPoint, which is a JavaBean. Spark SQL can convert RDDs of JavaBeans + // We use LabeledPoint, which is a JavaBean. Spark SQL can convert RDDs of JavaBeans // into DataFrames, where it uses the bean metadata to infer the schema. List localTraining = Lists.newArrayList( new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)), @@ -55,9 +54,9 @@ public static void main(String[] args) { new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)), new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5))); Dataset training = - jsql.createDataFrame(jsc.parallelize(localTraining), LabeledPoint.class); + spark.createDataFrame(localTraining, LabeledPoint.class); - // Create a LogisticRegression instance. This instance is an Estimator. + // Create a LogisticRegression instance. This instance is an Estimator. LogisticRegression lr = new LogisticRegression(); // Print out the parameters, documentation, and any default values. System.out.println("LogisticRegression parameters:\n" + lr.explainParams() + "\n"); @@ -66,7 +65,7 @@ public static void main(String[] args) { lr.setMaxIter(10) .setRegParam(0.01); - // Learn a LogisticRegression model. This uses the parameters stored in lr. + // Learn a LogisticRegression model. This uses the parameters stored in lr. LogisticRegressionModel model1 = lr.fit(training); // Since model1 is a Model (i.e., a Transformer produced by an Estimator), // we can view the parameters it used during fit(). @@ -78,12 +77,12 @@ public static void main(String[] args) { ParamMap paramMap = new ParamMap(); paramMap.put(lr.maxIter().w(20)); // Specify 1 Param. paramMap.put(lr.maxIter(), 30); // This overwrites the original maxIter. - double[] thresholds = {0.45, 0.55}; + double[] thresholds = {0.5, 0.5}; paramMap.put(lr.regParam().w(0.1), lr.thresholds().w(thresholds)); // Specify multiple Params. // One can also combine ParamMaps. ParamMap paramMap2 = new ParamMap(); - paramMap2.put(lr.probabilityCol().w("myProbability")); // Change output column name + paramMap2.put(lr.probabilityCol().w("myProbability")); // Change output column name. ParamMap paramMapCombined = paramMap.$plus$plus(paramMap2); // Now learn a new model using the paramMapCombined parameters. @@ -96,7 +95,7 @@ public static void main(String[] args) { new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))); - Dataset test = jsql.createDataFrame(jsc.parallelize(localTest), LabeledPoint.class); + Dataset test = spark.createDataFrame(localTest, LabeledPoint.class); // Make predictions on test documents using the Transformer.transform() method. // LogisticRegressionModel.transform will only use the 'features' column. 
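The JavaBean-based overload changes the same way: a local List of beans goes straight to createDataFrame() together with the bean class, and Spark SQL infers the schema from the bean properties. A sketch using a hypothetical Point bean (not part of the patch):

```java
import java.io.Serializable;
import java.util.Arrays;
import java.util.List;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class BeanDataFrameSketch {
  // Hypothetical JavaBean; Spark SQL derives column names from its getters.
  public static class Point implements Serializable {
    private double label;
    private double value;

    public Point() {}

    public Point(double label, double value) {
      this.label = label;
      this.value = value;
    }

    public double getLabel() { return label; }
    public void setLabel(double label) { this.label = label; }
    public double getValue() { return value; }
    public void setValue(double value) { this.value = value; }
  }

  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder().appName("BeanDataFrameSketch").getOrCreate();

    // A local List of beans replaces jsc.parallelize(...) + createDataFrame(JavaRDD, Class).
    List<Point> points = Arrays.asList(new Point(1.0, 0.5), new Point(0.0, -1.3));
    Dataset<Row> df = spark.createDataFrame(points, Point.class);

    // Columns "label" and "value" come from the bean properties.
    df.show();

    spark.stop();
  }
}
```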
@@ -109,6 +108,6 @@ public static void main(String[] args) { + ", prediction=" + r.get(3)); } - jsc.stop(); + spark.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java index a18a60f448166..7c24c46d2e287 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java @@ -21,8 +21,6 @@ import com.google.common.collect.Lists; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.Pipeline; import org.apache.spark.ml.PipelineModel; import org.apache.spark.ml.PipelineStage; @@ -31,7 +29,7 @@ import org.apache.spark.ml.feature.Tokenizer; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; /** * A simple text classification pipeline that recognizes "spark" from input text. It uses the Java @@ -44,9 +42,10 @@ public class JavaSimpleTextClassificationPipeline { public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaSimpleTextClassificationPipeline"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext jsql = new SQLContext(jsc); + SparkSession spark = SparkSession + .builder() + .appName("JavaSimpleTextClassificationPipeline") + .getOrCreate(); // Prepare training documents, which are labeled. List localTraining = Lists.newArrayList( @@ -55,7 +54,7 @@ public static void main(String[] args) { new LabeledDocument(2L, "spark f g h", 1.0), new LabeledDocument(3L, "hadoop mapreduce", 0.0)); Dataset training = - jsql.createDataFrame(jsc.parallelize(localTraining), LabeledDocument.class); + spark.createDataFrame(localTraining, LabeledDocument.class); // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. Tokenizer tokenizer = new Tokenizer() @@ -80,7 +79,7 @@ public static void main(String[] args) { new Document(5L, "l m n"), new Document(6L, "spark hadoop spark"), new Document(7L, "apache hadoop")); - Dataset test = jsql.createDataFrame(jsc.parallelize(localTest), Document.class); + Dataset test = spark.createDataFrame(localTest, Document.class); // Make predictions on test documents. 
Dataset predictions = model.transform(test); @@ -89,6 +88,6 @@ public static void main(String[] args) { + ", prediction=" + r.get(3)); } - jsc.stop(); + spark.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaStandardScalerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaStandardScalerExample.java index e2dd759c0a40c..08ea285a0d53d 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaStandardScalerExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaStandardScalerExample.java @@ -17,9 +17,7 @@ package org.apache.spark.examples.ml; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; // $example on$ import org.apache.spark.ml.feature.StandardScaler; @@ -30,12 +28,14 @@ public class JavaStandardScalerExample { public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaStandardScalerExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext jsql = new SQLContext(jsc); + SparkSession spark = SparkSession + .builder() + .appName("JavaStandardScalerExample") + .getOrCreate(); // $example on$ - Dataset dataFrame = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); + Dataset dataFrame = + spark.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); StandardScaler scaler = new StandardScaler() .setInputCol("features") @@ -50,6 +50,6 @@ public static void main(String[] args) { Dataset scaledData = scalerModel.transform(dataFrame); scaledData.show(); // $example off$ - jsc.stop(); + spark.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java index 0ff3782cb3e90..def5994429124 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaStopWordsRemoverExample.java @@ -17,14 +17,12 @@ package org.apache.spark.examples.ml; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; // $example on$ import java.util.Arrays; +import java.util.List; -import org.apache.spark.api.java.JavaRDD; import org.apache.spark.ml.feature.StopWordsRemover; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; @@ -38,28 +36,29 @@ public class JavaStopWordsRemoverExample { public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaStopWordsRemoverExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext jsql = new SQLContext(jsc); + SparkSession spark = SparkSession + .builder() + .appName("JavaStopWordsRemoverExample") + .getOrCreate(); // $example on$ StopWordsRemover remover = new StopWordsRemover() .setInputCol("raw") .setOutputCol("filtered"); - JavaRDD rdd = jsc.parallelize(Arrays.asList( + List data = Arrays.asList( RowFactory.create(Arrays.asList("I", "saw", "the", "red", "baloon")), RowFactory.create(Arrays.asList("Mary", "had", "a", "little", "lamb")) - )); + ); StructType schema = new StructType(new StructField[]{ new StructField( "raw", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()) }); - Dataset dataset = jsql.createDataFrame(rdd, schema); + Dataset dataset = 
spark.createDataFrame(data, schema); remover.transform(dataset).show(); // $example off$ - jsc.stop(); + spark.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaStringIndexerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaStringIndexerExample.java index ceacbb4fb3f33..7533c1835e325 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaStringIndexerExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaStringIndexerExample.java @@ -17,14 +17,12 @@ package org.apache.spark.examples.ml; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; // $example on$ import java.util.Arrays; +import java.util.List; -import org.apache.spark.api.java.JavaRDD; import org.apache.spark.ml.feature.StringIndexer; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; @@ -37,30 +35,31 @@ public class JavaStringIndexerExample { public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaStringIndexerExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext sqlContext = new SQLContext(jsc); + SparkSession spark = SparkSession + .builder() + .appName("JavaStringIndexerExample") + .getOrCreate(); // $example on$ - JavaRDD jrdd = jsc.parallelize(Arrays.asList( + List data = Arrays.asList( RowFactory.create(0, "a"), RowFactory.create(1, "b"), RowFactory.create(2, "c"), RowFactory.create(3, "a"), RowFactory.create(4, "a"), RowFactory.create(5, "c") - )); + ); StructType schema = new StructType(new StructField[]{ createStructField("id", IntegerType, false), createStructField("category", StringType, false) }); - Dataset df = sqlContext.createDataFrame(jrdd, schema); + Dataset df = spark.createDataFrame(data, schema); StringIndexer indexer = new StringIndexer() .setInputCol("category") .setOutputCol("categoryIndex"); Dataset indexed = indexer.fit(df).transform(df); indexed.show(); // $example off$ - jsc.stop(); + spark.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaTfIdfExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaTfIdfExample.java index 107c835f2e01e..800e42c949cbe 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaTfIdfExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaTfIdfExample.java @@ -19,19 +19,17 @@ // $example on$ import java.util.Arrays; +import java.util.List; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.feature.HashingTF; import org.apache.spark.ml.feature.IDF; import org.apache.spark.ml.feature.IDFModel; import org.apache.spark.ml.feature.Tokenizer; -import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.ml.linalg.Vector; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.types.StructField; @@ -40,21 +38,22 @@ public class JavaTfIdfExample { public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaTfIdfExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext sqlContext = new SQLContext(jsc); + SparkSession 
spark = SparkSession + .builder() + .appName("JavaTfIdfExample") + .getOrCreate(); // $example on$ - JavaRDD jrdd = jsc.parallelize(Arrays.asList( - RowFactory.create(0, "Hi I heard about Spark"), - RowFactory.create(0, "I wish Java could use case classes"), - RowFactory.create(1, "Logistic regression models are neat") - )); + List data = Arrays.asList( + RowFactory.create(0.0, "Hi I heard about Spark"), + RowFactory.create(0.0, "I wish Java could use case classes"), + RowFactory.create(1.0, "Logistic regression models are neat") + ); StructType schema = new StructType(new StructField[]{ new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), new StructField("sentence", DataTypes.StringType, false, Metadata.empty()) }); - Dataset sentenceData = sqlContext.createDataFrame(jrdd, schema); + Dataset sentenceData = spark.createDataFrame(data, schema); Tokenizer tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words"); Dataset wordsData = tokenizer.transform(sentenceData); int numFeatures = 20; @@ -76,6 +75,6 @@ public static void main(String[] args) { } // $example off$ - jsc.stop(); + spark.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java index 9225fe2262f57..1cc16bb60d172 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaTokenizerExample.java @@ -17,14 +17,12 @@ package org.apache.spark.examples.ml; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; // $example on$ import java.util.Arrays; +import java.util.List; -import org.apache.spark.api.java.JavaRDD; import org.apache.spark.ml.feature.RegexTokenizer; import org.apache.spark.ml.feature.Tokenizer; import org.apache.spark.sql.Dataset; @@ -38,23 +36,24 @@ public class JavaTokenizerExample { public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaTokenizerExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext sqlContext = new SQLContext(jsc); + SparkSession spark = SparkSession + .builder() + .appName("JavaTokenizerExample") + .getOrCreate(); // $example on$ - JavaRDD jrdd = jsc.parallelize(Arrays.asList( + List data = Arrays.asList( RowFactory.create(0, "Hi I heard about Spark"), RowFactory.create(1, "I wish Java could use case classes"), RowFactory.create(2, "Logistic,regression,models,are,neat") - )); + ); StructType schema = new StructType(new StructField[]{ new StructField("label", DataTypes.IntegerType, false, Metadata.empty()), new StructField("sentence", DataTypes.StringType, false, Metadata.empty()) }); - Dataset sentenceDataFrame = sqlContext.createDataFrame(jrdd, schema); + Dataset sentenceDataFrame = spark.createDataFrame(data, schema); Tokenizer tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words"); @@ -70,6 +69,6 @@ public static void main(String[] args) { .setOutputCol("words") .setPattern("\\W"); // alternatively .setPattern("\\w+").setGaps(false); // $example off$ - jsc.stop(); + spark.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java index 953ad455b1dcd..9bb0f93d3a6a1 100644 --- 
a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorAssemblerExample.java @@ -17,17 +17,14 @@ package org.apache.spark.examples.ml; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; // $example on$ import java.util.Arrays; -import org.apache.spark.api.java.JavaRDD; import org.apache.spark.ml.feature.VectorAssembler; -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.ml.linalg.VectorUDT; +import org.apache.spark.ml.linalg.Vectors; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; @@ -38,9 +35,10 @@ public class JavaVectorAssemblerExample { public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaVectorAssemblerExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext sqlContext = new SQLContext(jsc); + SparkSession spark = SparkSession + .builder() + .appName("JavaVectorAssemblerExample") + .getOrCreate(); // $example on$ StructType schema = createStructType(new StructField[]{ @@ -51,8 +49,7 @@ public static void main(String[] args) { createStructField("clicked", DoubleType, false) }); Row row = RowFactory.create(0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0); - JavaRDD rdd = jsc.parallelize(Arrays.asList(row)); - Dataset dataset = sqlContext.createDataFrame(rdd, schema); + Dataset dataset = spark.createDataFrame(Arrays.asList(row), schema); VectorAssembler assembler = new VectorAssembler() .setInputCols(new String[]{"hour", "mobile", "userFeatures"}) @@ -61,7 +58,7 @@ public static void main(String[] args) { Dataset output = assembler.transform(dataset); System.out.println(output.select("features", "clicked").first()); // $example off$ - jsc.stop(); + spark.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java index b3b5953ee7bbe..dd9d757dd6831 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorIndexerExample.java @@ -17,9 +17,7 @@ package org.apache.spark.examples.ml; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; // $example on$ import java.util.Map; @@ -32,12 +30,13 @@ public class JavaVectorIndexerExample { public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaVectorIndexerExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext jsql = new SQLContext(jsc); + SparkSession spark = SparkSession + .builder() + .appName("JavaVectorIndexerExample") + .getOrCreate(); // $example on$ - Dataset data = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); + Dataset data = spark.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); VectorIndexer indexer = new VectorIndexer() .setInputCol("features") @@ -57,6 +56,6 @@ public static void main(String[] args) { Dataset indexedData = indexerModel.transform(data); indexedData.show(); // $example off$ - jsc.stop(); + spark.stop(); } } diff --git 
a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java index 2ae57c3577eff..19b8bc83be6e1 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaVectorSlicerExample.java @@ -17,19 +17,18 @@ package org.apache.spark.examples.ml; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; // $example on$ +import java.util.List; + import com.google.common.collect.Lists; -import org.apache.spark.api.java.JavaRDD; import org.apache.spark.ml.attribute.Attribute; import org.apache.spark.ml.attribute.AttributeGroup; import org.apache.spark.ml.attribute.NumericAttribute; import org.apache.spark.ml.feature.VectorSlicer; -import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.ml.linalg.Vectors; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; @@ -38,9 +37,10 @@ public class JavaVectorSlicerExample { public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaVectorSlicerExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext jsql = new SQLContext(jsc); + SparkSession spark = SparkSession + .builder() + .appName("JavaVectorSlicerExample") + .getOrCreate(); // $example on$ Attribute[] attrs = new Attribute[]{ @@ -50,13 +50,13 @@ public static void main(String[] args) { }; AttributeGroup group = new AttributeGroup("userFeatures", attrs); - JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( + List data = Lists.newArrayList( RowFactory.create(Vectors.sparse(3, new int[]{0, 1}, new double[]{-2.0, 2.3})), RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0)) - )); + ); Dataset dataset = - jsql.createDataFrame(jrdd, (new StructType()).add(group.toStructField())); + spark.createDataFrame(data, (new StructType()).add(group.toStructField())); VectorSlicer vectorSlicer = new VectorSlicer() .setInputCol("userFeatures").setOutputCol("features"); @@ -68,7 +68,7 @@ public static void main(String[] args) { System.out.println(output.select("userFeatures", "features").first()); // $example off$ - jsc.stop(); + spark.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaWord2VecExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaWord2VecExample.java index c5bb1eaaa3446..9be6e6353adcf 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaWord2VecExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaWord2VecExample.java @@ -19,37 +19,35 @@ // $example on$ import java.util.Arrays; +import java.util.List; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.feature.Word2Vec; import org.apache.spark.ml.feature.Word2VecModel; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.types.*; // $example off$ public class JavaWord2VecExample { public static void main(String[] args) { - - SparkConf conf = new SparkConf().setAppName("JavaWord2VecExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext sqlContext = new 
SQLContext(jsc); + SparkSession spark = SparkSession + .builder() + .appName("JavaWord2VecExample") + .getOrCreate(); // $example on$ // Input data: Each row is a bag of words from a sentence or document. - JavaRDD jrdd = jsc.parallelize(Arrays.asList( + List data = Arrays.asList( RowFactory.create(Arrays.asList("Hi I heard about Spark".split(" "))), RowFactory.create(Arrays.asList("I wish Java could use case classes".split(" "))), RowFactory.create(Arrays.asList("Logistic regression models are neat".split(" "))) - )); + ); StructType schema = new StructType(new StructField[]{ new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty()) }); - Dataset documentDF = sqlContext.createDataFrame(jrdd, schema); + Dataset documentDF = spark.createDataFrame(data, schema); // Learn a mapping from words to Vectors. Word2Vec word2Vec = new Word2Vec() @@ -64,6 +62,6 @@ public static void main(String[] args) { } // $example off$ - jsc.stop(); + spark.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaIsotonicRegressionExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaIsotonicRegressionExample.java index c6361a3729988..a30b5f1f73eaf 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaIsotonicRegressionExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaIsotonicRegressionExample.java @@ -17,6 +17,7 @@ package org.apache.spark.examples.mllib; // $example on$ + import scala.Tuple2; import scala.Tuple3; import org.apache.spark.api.java.function.Function; @@ -27,6 +28,8 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.mllib.regression.IsotonicRegression; import org.apache.spark.mllib.regression.IsotonicRegressionModel; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; // $example off$ import org.apache.spark.SparkConf; @@ -35,27 +38,29 @@ public static void main(String[] args) { SparkConf sparkConf = new SparkConf().setAppName("JavaIsotonicRegressionExample"); JavaSparkContext jsc = new JavaSparkContext(sparkConf); // $example on$ - JavaRDD data = jsc.textFile("data/mllib/sample_isotonic_regression_data.txt"); + JavaRDD data = MLUtils.loadLibSVMFile( + jsc.sc(), "data/mllib/sample_isotonic_regression_libsvm_data.txt").toJavaRDD(); // Create label, feature, weight tuples from input data with weight set to default value 1.0. JavaRDD> parsedData = data.map( - new Function>() { - public Tuple3 call(String line) { - String[] parts = line.split(","); - return new Tuple3<>(new Double(parts[0]), new Double(parts[1]), 1.0); + new Function>() { + public Tuple3 call(LabeledPoint point) { + return new Tuple3<>(new Double(point.label()), + new Double(point.features().apply(0)), 1.0); } } ); // Split data into training (60%) and test (40%) sets. JavaRDD>[] splits = - parsedData.randomSplit(new double[]{0.6, 0.4}, 11L); + parsedData.randomSplit(new double[]{0.6, 0.4}, 11L); JavaRDD> training = splits[0]; JavaRDD> test = splits[1]; // Create isotonic regression model from training data. // Isotonic parameter defaults to true so it is only shown for demonstration - final IsotonicRegressionModel model = new IsotonicRegression().setIsotonic(true).run(training); + final IsotonicRegressionModel model = + new IsotonicRegression().setIsotonic(true).run(training); // Create tuples of predicted and real labels. 
JavaPairRDD predictionAndLabel = test.mapToPair( diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLogisticRegressionWithLBFGSExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLogisticRegressionWithLBFGSExample.java index 9d8e4a90dbc99..7fc371ec0f990 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLogisticRegressionWithLBFGSExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLogisticRegressionWithLBFGSExample.java @@ -65,8 +65,8 @@ public Tuple2 call(LabeledPoint p) { // Get evaluation metrics. MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd()); - double precision = metrics.precision(); - System.out.println("Precision = " + precision); + double accuracy = metrics.accuracy(); + System.out.println("Accuracy = " + accuracy); // Save and load model model.save(sc, "target/tmp/javaLogisticRegressionWithLBFGSModel"); diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java index 5247c9c748618..e84a3a712df14 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java @@ -68,9 +68,7 @@ public Tuple2 call(LabeledPoint p) { System.out.println("Confusion matrix: \n" + confusion); // Overall statistics - System.out.println("Precision = " + metrics.precision()); - System.out.println("Recall = " + metrics.recall()); - System.out.println("F1 Score = " + metrics.fMeasure()); + System.out.println("Accuracy = " + metrics.accuracy()); // Stats by labels for (int i = 0; i < metrics.labels().length; i++) { diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaNaiveBayesExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaNaiveBayesExample.java index 2b17dbb96365e..f4ec04b0c677c 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaNaiveBayesExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaNaiveBayesExample.java @@ -36,9 +36,9 @@ public static void main(String[] args) { SparkConf sparkConf = new SparkConf().setAppName("JavaNaiveBayesExample"); JavaSparkContext jsc = new JavaSparkContext(sparkConf); // $example on$ - String path = "data/mllib/sample_naive_bayes_data.txt"; + String path = "data/mllib/sample_libsvm_data.txt"; JavaRDD inputData = MLUtils.loadLibSVMFile(jsc.sc(), path).toJavaRDD(); - JavaRDD[] tmp = inputData.randomSplit(new double[]{0.6, 0.4}, 12345); + JavaRDD[] tmp = inputData.randomSplit(new double[]{0.6, 0.4}); JavaRDD training = tmp[0]; // training set JavaRDD test = tmp[1]; // test set final NaiveBayesModel model = NaiveBayes.train(training.rdd(), 1.0); diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java new file mode 100644 index 0000000000000..52e3b62b79dd2 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java @@ -0,0 +1,230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.examples.sql; + +// $example on:schema_merging$ +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; +// $example off:schema_merging$ + +// $example on:basic_parquet_example$ +import org.apache.spark.api.java.function.MapFunction; +import org.apache.spark.sql.Encoders; +// $example on:schema_merging$ +// $example on:json_dataset$ +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +// $example off:json_dataset$ +// $example off:schema_merging$ +// $example off:basic_parquet_example$ +import org.apache.spark.sql.SparkSession; + +public class JavaSQLDataSourceExample { + + // $example on:schema_merging$ + public static class Square implements Serializable { + private int value; + private int square; + + // Getters and setters... + // $example off:schema_merging$ + public int getValue() { + return value; + } + + public void setValue(int value) { + this.value = value; + } + + public int getSquare() { + return square; + } + + public void setSquare(int square) { + this.square = square; + } + // $example on:schema_merging$ + } + // $example off:schema_merging$ + + // $example on:schema_merging$ + public static class Cube implements Serializable { + private int value; + private int cube; + + // Getters and setters... 
+ // $example off:schema_merging$ + public int getValue() { + return value; + } + + public void setValue(int value) { + this.value = value; + } + + public int getCube() { + return cube; + } + + public void setCube(int cube) { + this.cube = cube; + } + // $example on:schema_merging$ + } + // $example off:schema_merging$ + + public static void main(String[] args) { + SparkSession spark = SparkSession + .builder() + .appName("Java Spark SQL data sources example") + .config("spark.some.config.option", "some-value") + .getOrCreate(); + + runBasicDataSourceExample(spark); + runBasicParquetExample(spark); + runParquetSchemaMergingExample(spark); + runJsonDatasetExample(spark); + runJdbcDatasetExample(spark); + + spark.stop(); + } + + private static void runBasicDataSourceExample(SparkSession spark) { + // $example on:generic_load_save_functions$ + Dataset usersDF = spark.read().load("examples/src/main/resources/users.parquet"); + usersDF.select("name", "favorite_color").write().save("namesAndFavColors.parquet"); + // $example off:generic_load_save_functions$ + // $example on:manual_load_options$ + Dataset peopleDF = + spark.read().format("json").load("examples/src/main/resources/people.json"); + peopleDF.select("name", "age").write().format("parquet").save("namesAndAges.parquet"); + // $example off:manual_load_options$ + // $example on:direct_sql$ + Dataset sqlDF = + spark.sql("SELECT * FROM parquet.`examples/src/main/resources/users.parquet`"); + // $example off:direct_sql$ + } + + private static void runBasicParquetExample(SparkSession spark) { + // $example on:basic_parquet_example$ + Dataset peopleDF = spark.read().json("examples/src/main/resources/people.json"); + + // DataFrames can be saved as Parquet files, maintaining the schema information + peopleDF.write().parquet("people.parquet"); + + // Read in the Parquet file created above. 
+ // Parquet files are self-describing so the schema is preserved + // The result of loading a parquet file is also a DataFrame + Dataset parquetFileDF = spark.read().parquet("people.parquet"); + + // Parquet files can also be used to create a temporary view and then used in SQL statements + parquetFileDF.createOrReplaceTempView("parquetFile"); + Dataset namesDF = spark.sql("SELECT name FROM parquetFile WHERE age BETWEEN 13 AND 19"); + Dataset namesDS = namesDF.map(new MapFunction() { + public String call(Row row) { + return "Name: " + row.getString(0); + } + }, Encoders.STRING()); + namesDS.show(); + // +------------+ + // | value| + // +------------+ + // |Name: Justin| + // +------------+ + // $example off:basic_parquet_example$ + } + + private static void runParquetSchemaMergingExample(SparkSession spark) { + // $example on:schema_merging$ + List squares = new ArrayList<>(); + for (int value = 1; value <= 5; value++) { + Square square = new Square(); + square.setValue(value); + square.setSquare(value * value); + squares.add(square); + } + + // Create a simple DataFrame, store into a partition directory + Dataset squaresDF = spark.createDataFrame(squares, Square.class); + squaresDF.write().parquet("data/test_table/key=1"); + + List cubes = new ArrayList<>(); + for (int value = 6; value <= 10; value++) { + Cube cube = new Cube(); + cube.setValue(value); + cube.setCube(value * value * value); + cubes.add(cube); + } + + // Create another DataFrame in a new partition directory, + // adding a new column and dropping an existing column + Dataset cubesDF = spark.createDataFrame(cubes, Cube.class); + cubesDF.write().parquet("data/test_table/key=2"); + + // Read the partitioned table + Dataset mergedDF = spark.read().option("mergeSchema", true).parquet("data/test_table"); + mergedDF.printSchema(); + + // The final schema consists of all 3 columns in the Parquet files together + // with the partitioning column appeared in the partition directory paths + // root + // |-- value: int (nullable = true) + // |-- square: int (nullable = true) + // |-- cube: int (nullable = true) + // |-- key: int (nullable = true) + // $example off:schema_merging$ + } + + private static void runJsonDatasetExample(SparkSession spark) { + // $example on:json_dataset$ + // A JSON dataset is pointed to by path. 
+ // The path can be either a single text file or a directory storing text files + Dataset people = spark.read().json("examples/src/main/resources/people.json"); + + // The inferred schema can be visualized using the printSchema() method + people.printSchema(); + // root + // |-- age: long (nullable = true) + // |-- name: string (nullable = true) + + // Creates a temporary view using the DataFrame + people.createOrReplaceTempView("people"); + + // SQL statements can be run by using the sql methods provided by spark + Dataset namesDF = spark.sql("SELECT name FROM people WHERE age BETWEEN 13 AND 19"); + namesDF.show(); + // +------+ + // | name| + // +------+ + // |Justin| + // +------+ + // $example off:json_dataset$ + } + + private static void runJdbcDatasetExample(SparkSession spark) { + // $example on:jdbc_dataset$ + Dataset jdbcDF = spark.read() + .format("jdbc") + .option("url", "jdbc:postgresql:dbserver") + .option("dbtable", "schema.tablename") + .option("user", "username") + .option("password", "password") + .load(); + // $example off:jdbc_dataset$ + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java deleted file mode 100644 index 354a5306ed45f..0000000000000 --- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java +++ /dev/null @@ -1,182 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.sql; - -import java.io.Serializable; -import java.util.Arrays; -import java.util.List; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; - -import org.apache.spark.sql.Dataset; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; - -public class JavaSparkSQL { - public static class Person implements Serializable { - private String name; - private int age; - - public String getName() { - return name; - } - - public void setName(String name) { - this.name = name; - } - - public int getAge() { - return age; - } - - public void setAge(int age) { - this.age = age; - } - } - - public static void main(String[] args) throws Exception { - SparkConf sparkConf = new SparkConf().setAppName("JavaSparkSQL"); - JavaSparkContext ctx = new JavaSparkContext(sparkConf); - SQLContext sqlContext = new SQLContext(ctx); - - System.out.println("=== Data source: RDD ==="); - // Load a text file and convert each line to a Java Bean. 
- JavaRDD people = ctx.textFile("examples/src/main/resources/people.txt").map( - new Function() { - @Override - public Person call(String line) { - String[] parts = line.split(","); - - Person person = new Person(); - person.setName(parts[0]); - person.setAge(Integer.parseInt(parts[1].trim())); - - return person; - } - }); - - // Apply a schema to an RDD of Java Beans and register it as a table. - Dataset schemaPeople = sqlContext.createDataFrame(people, Person.class); - schemaPeople.registerTempTable("people"); - - // SQL can be run over RDDs that have been registered as tables. - Dataset teenagers = - sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); - - // The results of SQL queries are DataFrames and support all the normal RDD operations. - // The columns of a row in the result can be accessed by ordinal. - List teenagerNames = teenagers.toJavaRDD().map(new Function() { - @Override - public String call(Row row) { - return "Name: " + row.getString(0); - } - }).collect(); - for (String name: teenagerNames) { - System.out.println(name); - } - - System.out.println("=== Data source: Parquet File ==="); - // DataFrames can be saved as parquet files, maintaining the schema information. - schemaPeople.write().parquet("people.parquet"); - - // Read in the parquet file created above. - // Parquet files are self-describing so the schema is preserved. - // The result of loading a parquet file is also a DataFrame. - Dataset parquetFile = sqlContext.read().parquet("people.parquet"); - - //Parquet files can also be registered as tables and then used in SQL statements. - parquetFile.registerTempTable("parquetFile"); - Dataset teenagers2 = - sqlContext.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"); - teenagerNames = teenagers2.toJavaRDD().map(new Function() { - @Override - public String call(Row row) { - return "Name: " + row.getString(0); - } - }).collect(); - for (String name: teenagerNames) { - System.out.println(name); - } - - System.out.println("=== Data source: JSON Dataset ==="); - // A JSON dataset is pointed by path. - // The path can be either a single text file or a directory storing text files. - String path = "examples/src/main/resources/people.json"; - // Create a DataFrame from the file(s) pointed by path - Dataset peopleFromJsonFile = sqlContext.read().json(path); - - // Because the schema of a JSON dataset is automatically inferred, to write queries, - // it is better to take a look at what is the schema. - peopleFromJsonFile.printSchema(); - // The schema of people is ... - // root - // |-- age: IntegerType - // |-- name: StringType - - // Register this DataFrame as a table. - peopleFromJsonFile.registerTempTable("people"); - - // SQL statements can be run by using the sql methods provided by sqlContext. - Dataset teenagers3 = - sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); - - // The results of SQL queries are DataFrame and support all the normal RDD operations. - // The columns of a row in the result can be accessed by ordinal. - teenagerNames = teenagers3.toJavaRDD().map(new Function() { - @Override - public String call(Row row) { return "Name: " + row.getString(0); } - }).collect(); - for (String name: teenagerNames) { - System.out.println(name); - } - - // Alternatively, a DataFrame can be created for a JSON dataset represented by - // a RDD[String] storing one JSON object per string. 
- List jsonData = Arrays.asList( - "{\"name\":\"Yin\",\"address\":{\"city\":\"Columbus\",\"state\":\"Ohio\"}}"); - JavaRDD anotherPeopleRDD = ctx.parallelize(jsonData); - Dataset peopleFromJsonRDD = sqlContext.read().json(anotherPeopleRDD.rdd()); - - // Take a look at the schema of this new DataFrame. - peopleFromJsonRDD.printSchema(); - // The schema of anotherPeople is ... - // root - // |-- address: StructType - // | |-- city: StringType - // | |-- state: StringType - // |-- name: StringType - - peopleFromJsonRDD.registerTempTable("people2"); - - Dataset peopleWithCity = sqlContext.sql("SELECT name, address.city FROM people2"); - List nameAndCity = peopleWithCity.toJavaRDD().map(new Function() { - @Override - public String call(Row row) { - return "Name: " + row.getString(0) + ", City: " + row.getString(1); - } - }).collect(); - for (String name: nameAndCity) { - System.out.println(name); - } - - ctx.stop(); - } -} diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQLExample.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQLExample.java new file mode 100644 index 0000000000000..cff9032f52b5a --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQLExample.java @@ -0,0 +1,336 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.examples.sql; + +// $example on:programmatic_schema$ +import java.util.ArrayList; +import java.util.List; +// $example off:programmatic_schema$ +// $example on:create_ds$ +import java.util.Arrays; +import java.util.Collections; +import java.io.Serializable; +// $example off:create_ds$ + +// $example on:schema_inferring$ +// $example on:programmatic_schema$ +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.function.Function; +// $example off:programmatic_schema$ +// $example on:create_ds$ +import org.apache.spark.api.java.function.MapFunction; +// $example on:create_df$ +// $example on:run_sql$ +// $example on:programmatic_schema$ +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +// $example off:programmatic_schema$ +// $example off:create_df$ +// $example off:run_sql$ +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.Encoders; +// $example off:create_ds$ +// $example off:schema_inferring$ +import org.apache.spark.sql.RowFactory; +// $example on:init_session$ +import org.apache.spark.sql.SparkSession; +// $example off:init_session$ +// $example on:programmatic_schema$ +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +// $example off:programmatic_schema$ + +// $example on:untyped_ops$ +// col("...") is preferable to df.col("...") +import static org.apache.spark.sql.functions.col; +// $example off:untyped_ops$ + +public class JavaSparkSQLExample { + // $example on:create_ds$ + public static class Person implements Serializable { + private String name; + private int age; + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public int getAge() { + return age; + } + + public void setAge(int age) { + this.age = age; + } + } + // $example off:create_ds$ + + public static void main(String[] args) { + // $example on:init_session$ + SparkSession spark = SparkSession + .builder() + .appName("Java Spark SQL basic example") + .config("spark.some.config.option", "some-value") + .getOrCreate(); + // $example off:init_session$ + + runBasicDataFrameExample(spark); + runDatasetCreationExample(spark); + runInferSchemaExample(spark); + runProgrammaticSchemaExample(spark); + + spark.stop(); + } + + private static void runBasicDataFrameExample(SparkSession spark) { + // $example on:create_df$ + Dataset df = spark.read().json("examples/src/main/resources/people.json"); + + // Displays the content of the DataFrame to stdout + df.show(); + // +----+-------+ + // | age| name| + // +----+-------+ + // |null|Michael| + // | 30| Andy| + // | 19| Justin| + // +----+-------+ + // $example off:create_df$ + + // $example on:untyped_ops$ + // Print the schema in a tree format + df.printSchema(); + // root + // |-- age: long (nullable = true) + // |-- name: string (nullable = true) + + // Select only the "name" column + df.select("name").show(); + // +-------+ + // | name| + // +-------+ + // |Michael| + // | Andy| + // | Justin| + // +-------+ + + // Select everybody, but increment the age by 1 + df.select(col("name"), col("age").plus(1)).show(); + // +-------+---------+ + // | name|(age + 1)| + // +-------+---------+ + // |Michael| null| + // | Andy| 31| + // | Justin| 20| + // +-------+---------+ + + // Select people older than 21 + df.filter(col("age").gt(21)).show(); + // +---+----+ + // |age|name| + // +---+----+ + // | 30|Andy| + // +---+----+ + + // Count people by age 
+ df.groupBy("age").count().show(); + // +----+-----+ + // | age|count| + // +----+-----+ + // | 19| 1| + // |null| 1| + // | 30| 1| + // +----+-----+ + // $example off:untyped_ops$ + + // $example on:run_sql$ + // Register the DataFrame as a SQL temporary view + df.createOrReplaceTempView("people"); + + Dataset sqlDF = spark.sql("SELECT * FROM people"); + sqlDF.show(); + // +----+-------+ + // | age| name| + // +----+-------+ + // |null|Michael| + // | 30| Andy| + // | 19| Justin| + // +----+-------+ + // $example off:run_sql$ + } + + private static void runDatasetCreationExample(SparkSession spark) { + // $example on:create_ds$ + // Create an instance of a Bean class + Person person = new Person(); + person.setName("Andy"); + person.setAge(32); + + // Encoders are created for Java beans + Encoder personEncoder = Encoders.bean(Person.class); + Dataset javaBeanDS = spark.createDataset( + Collections.singletonList(person), + personEncoder + ); + javaBeanDS.show(); + // +---+----+ + // |age|name| + // +---+----+ + // | 32|Andy| + // +---+----+ + + // Encoders for most common types are provided in class Encoders + Encoder integerEncoder = Encoders.INT(); + Dataset primitiveDS = spark.createDataset(Arrays.asList(1, 2, 3), integerEncoder); + Dataset transformedDS = primitiveDS.map(new MapFunction() { + @Override + public Integer call(Integer value) throws Exception { + return value + 1; + } + }, integerEncoder); + transformedDS.collect(); // Returns [2, 3, 4] + + // DataFrames can be converted to a Dataset by providing a class. Mapping based on name + String path = "examples/src/main/resources/people.json"; + Dataset peopleDS = spark.read().json(path).as(personEncoder); + peopleDS.show(); + // +----+-------+ + // | age| name| + // +----+-------+ + // |null|Michael| + // | 30| Andy| + // | 19| Justin| + // +----+-------+ + // $example off:create_ds$ + } + + private static void runInferSchemaExample(SparkSession spark) { + // $example on:schema_inferring$ + // Create an RDD of Person objects from a text file + JavaRDD peopleRDD = spark.read() + .textFile("examples/src/main/resources/people.txt") + .javaRDD() + .map(new Function() { + @Override + public Person call(String line) throws Exception { + String[] parts = line.split(","); + Person person = new Person(); + person.setName(parts[0]); + person.setAge(Integer.parseInt(parts[1].trim())); + return person; + } + }); + + // Apply a schema to an RDD of JavaBeans to get a DataFrame + Dataset peopleDF = spark.createDataFrame(peopleRDD, Person.class); + // Register the DataFrame as a temporary view + peopleDF.createOrReplaceTempView("people"); + + // SQL statements can be run by using the sql methods provided by spark + Dataset teenagersDF = spark.sql("SELECT name FROM people WHERE age BETWEEN 13 AND 19"); + + // The columns of a row in the result can be accessed by field index + Encoder stringEncoder = Encoders.STRING(); + Dataset teenagerNamesByIndexDF = teenagersDF.map(new MapFunction() { + @Override + public String call(Row row) throws Exception { + return "Name: " + row.getString(0); + } + }, stringEncoder); + teenagerNamesByIndexDF.show(); + // +------------+ + // | value| + // +------------+ + // |Name: Justin| + // +------------+ + + // or by field name + Dataset teenagerNamesByFieldDF = teenagersDF.map(new MapFunction() { + @Override + public String call(Row row) throws Exception { + return "Name: " + row.getAs("name"); + } + }, stringEncoder); + teenagerNamesByFieldDF.show(); + // +------------+ + // | value| + // +------------+ + // |Name: 
Justin| + // +------------+ + // $example off:schema_inferring$ + } + + private static void runProgrammaticSchemaExample(SparkSession spark) { + // $example on:programmatic_schema$ + // Create an RDD + JavaRDD peopleRDD = spark.sparkContext() + .textFile("examples/src/main/resources/people.txt", 1) + .toJavaRDD(); + + // The schema is encoded in a string + String schemaString = "name age"; + + // Generate the schema based on the string of schema + List fields = new ArrayList<>(); + for (String fieldName : schemaString.split(" ")) { + StructField field = DataTypes.createStructField(fieldName, DataTypes.StringType, true); + fields.add(field); + } + StructType schema = DataTypes.createStructType(fields); + + // Convert records of the RDD (people) to Rows + JavaRDD rowRDD = peopleRDD.map(new Function() { + @Override + public Row call(String record) throws Exception { + String[] attributes = record.split(","); + return RowFactory.create(attributes[0], attributes[1].trim()); + } + }); + + // Apply the schema to the RDD + Dataset peopleDataFrame = spark.createDataFrame(rowRDD, schema); + + // Creates a temporary view using the DataFrame + peopleDataFrame.createOrReplaceTempView("people"); + + // SQL can be run over a temporary view created using DataFrames + Dataset results = spark.sql("SELECT name FROM people"); + + // The results of SQL queries are DataFrames and support all the normal RDD operations + // The columns of a row in the result can be accessed by field index or by field name + Dataset namesDS = results.map(new MapFunction() { + @Override + public String call(Row row) throws Exception { + return "Name: " + row.getString(0); + } + }, Encoders.STRING()); + namesDS.show(); + // +-------------+ + // | value| + // +-------------+ + // |Name: Michael| + // | Name: Andy| + // | Name: Justin| + // +-------------+ + // $example off:programmatic_schema$ + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/sql/hive/JavaSparkHiveExample.java b/examples/src/main/java/org/apache/spark/examples/sql/hive/JavaSparkHiveExample.java new file mode 100644 index 0000000000000..052153c9e9736 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/sql/hive/JavaSparkHiveExample.java @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.examples.sql.hive; + +// $example on:spark_hive$ +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; + +import org.apache.spark.api.java.function.MapFunction; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +// $example off:spark_hive$ + +public class JavaSparkHiveExample { + + // $example on:spark_hive$ + public static class Record implements Serializable { + private int key; + private String value; + + public int getKey() { + return key; + } + + public void setKey(int key) { + this.key = key; + } + + public String getValue() { + return value; + } + + public void setValue(String value) { + this.value = value; + } + } + // $example off:spark_hive$ + + public static void main(String[] args) { + // $example on:spark_hive$ + // warehouseLocation points to the default location for managed databases and tables + String warehouseLocation = "spark-warehouse"; + SparkSession spark = SparkSession + .builder() + .appName("Java Spark Hive Example") + .config("spark.sql.warehouse.dir", warehouseLocation) + .enableHiveSupport() + .getOrCreate(); + + spark.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)"); + spark.sql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src"); + + // Queries are expressed in HiveQL + spark.sql("SELECT * FROM src").show(); + // +---+-------+ + // |key| value| + // +---+-------+ + // |238|val_238| + // | 86| val_86| + // |311|val_311| + // ... + + // Aggregation queries are also supported. + spark.sql("SELECT COUNT(*) FROM src").show(); + // +--------+ + // |count(1)| + // +--------+ + // | 500 | + // +--------+ + + // The results of SQL queries are themselves DataFrames and support all normal functions. + Dataset sqlDF = spark.sql("SELECT key, value FROM src WHERE key < 10 ORDER BY key"); + + // The items in DataFrames are of type Row, which lets you access each column by ordinal. + Dataset stringsDS = sqlDF.map(new MapFunction() { + @Override + public String call(Row row) throws Exception { + return "Key: " + row.get(0) + ", Value: " + row.get(1); + } + }, Encoders.STRING()); + stringsDS.show(); + // +--------------------+ + // | value| + // +--------------------+ + // |Key: 0, Value: val_0| + // |Key: 0, Value: val_0| + // |Key: 0, Value: val_0| + // ... + + // You can also use DataFrames to create temporary views within a SparkSession. + List records = new ArrayList<>(); + for (int key = 1; key < 100; key++) { + Record record = new Record(); + record.setKey(key); + record.setValue("val_" + key); + records.add(record); + } + Dataset recordsDF = spark.createDataFrame(records, Record.class); + recordsDF.createOrReplaceTempView("records"); + + // Queries can then join DataFrames data with data stored in Hive. + spark.sql("SELECT * FROM records r JOIN src s ON r.key = s.key").show(); + // +---+------+---+------+ + // |key| value|key| value| + // +---+------+---+------+ + // | 2| val_2| 2| val_2| + // | 2| val_2| 2| val_2| + // | 4| val_4| 4| val_4| + // ...
+ // $example off:spark_hive$ + + spark.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredNetworkWordCount.java new file mode 100644 index 0000000000000..346d2182c70b0 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredNetworkWordCount.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.examples.sql.streaming; + +import org.apache.spark.api.java.function.FlatMapFunction; +import org.apache.spark.sql.*; +import org.apache.spark.sql.streaming.StreamingQuery; + +import java.util.Arrays; +import java.util.Iterator; + +/** + * Counts words in UTF8 encoded, '\n' delimited text received from the network. + * + * Usage: JavaStructuredNetworkWordCount + * and describe the TCP server that Structured Streaming + * would connect to receive data. + * + * To run this on your local machine, you need to first run a Netcat server + * `$ nc -lk 9999` + * and then run the example + * `$ bin/run-example sql.streaming.JavaStructuredNetworkWordCount + * localhost 9999` + */ +public final class JavaStructuredNetworkWordCount { + + public static void main(String[] args) throws Exception { + if (args.length < 2) { + System.err.println("Usage: JavaStructuredNetworkWordCount "); + System.exit(1); + } + + String host = args[0]; + int port = Integer.parseInt(args[1]); + + SparkSession spark = SparkSession + .builder() + .appName("JavaStructuredNetworkWordCount") + .getOrCreate(); + + // Create DataFrame representing the stream of input lines from connection to host:port + Dataset lines = spark + .readStream() + .format("socket") + .option("host", host) + .option("port", port) + .load().as(Encoders.STRING()); + + // Split the lines into words + Dataset words = lines.flatMap(new FlatMapFunction() { + @Override + public Iterator call(String x) { + return Arrays.asList(x.split(" ")).iterator(); + } + }, Encoders.STRING()); + + // Generate running word count + Dataset wordCounts = words.groupBy("value").count(); + + // Start running the query that prints the running counts to the console + StreamingQuery query = wordCounts.writeStream() + .outputMode("complete") + .format("console") + .start(); + + query.awaitTermination(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredNetworkWordCountWindowed.java b/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredNetworkWordCountWindowed.java new file mode 100644 index 0000000000000..557d36cff30d7 --- /dev/null +++ 
b/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredNetworkWordCountWindowed.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.examples.sql.streaming; + +import org.apache.spark.api.java.function.FlatMapFunction; +import org.apache.spark.sql.*; +import org.apache.spark.sql.functions; +import org.apache.spark.sql.streaming.StreamingQuery; +import scala.Tuple2; + +import java.sql.Timestamp; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; + +/** + * Counts words in UTF8 encoded, '\n' delimited text received from the network over a + * sliding window of configurable duration. Each line from the network is tagged + * with a timestamp that is used to determine the windows into which it falls. + * + * Usage: JavaStructuredNetworkWordCountWindowed + * [] + * and describe the TCP server that Structured Streaming + * would connect to receive data. + * gives the size of window, specified as integer number of seconds + * gives the amount of time successive windows are offset from one another, + * given in the same units as above. should be less than or equal to + * . If the two are equal, successive windows have no overlap. If + * is not provided, it defaults to . + * + * To run this on your local machine, you need to first run a Netcat server + * `$ nc -lk 9999` + * and then run the example + * `$ bin/run-example sql.streaming.JavaStructuredNetworkWordCountWindowed + * localhost 9999 []` + * + * One recommended , pair is 10, 5 + */ +public final class JavaStructuredNetworkWordCountWindowed { + + public static void main(String[] args) throws Exception { + if (args.length < 3) { + System.err.println("Usage: JavaStructuredNetworkWordCountWindowed " + + " []"); + System.exit(1); + } + + String host = args[0]; + int port = Integer.parseInt(args[1]); + int windowSize = Integer.parseInt(args[2]); + int slideSize = (args.length == 3) ? 
windowSize : Integer.parseInt(args[3]); + if (slideSize > windowSize) { + System.err.println(" must be less than or equal to "); + } + String windowDuration = windowSize + " seconds"; + String slideDuration = slideSize + " seconds"; + + SparkSession spark = SparkSession + .builder() + .appName("JavaStructuredNetworkWordCountWindowed") + .getOrCreate(); + + // Create DataFrame representing the stream of input lines from connection to host:port + Dataset> lines = spark + .readStream() + .format("socket") + .option("host", host) + .option("port", port) + .option("includeTimestamp", true) + .load().as(Encoders.tuple(Encoders.STRING(), Encoders.TIMESTAMP())); + + // Split the lines into words, retaining timestamps + Dataset words = lines.flatMap( + new FlatMapFunction, Tuple2>() { + @Override + public Iterator> call(Tuple2 t) { + List> result = new ArrayList<>(); + for (String word : t._1.split(" ")) { + result.add(new Tuple2<>(word, t._2)); + } + return result.iterator(); + } + }, + Encoders.tuple(Encoders.STRING(), Encoders.TIMESTAMP()) + ).toDF("word", "timestamp"); + + // Group the data by window and word and compute the count of each group + Dataset windowedCounts = words.groupBy( + functions.window(words.col("timestamp"), windowDuration, slideDuration), + words.col("word") + ).count().orderBy("window"); + + // Start running the query that prints the windowed word counts to the console + StreamingQuery query = windowedCounts.writeStream() + .outputMode("complete") + .format("console") + .option("truncate", "false") + .start(); + + query.awaitTermination(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java index 1cba565b38c2a..e20b94d5b03f2 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java @@ -70,7 +70,7 @@ public static void main(String[] args) throws Exception { SparkConf sparkConf = new SparkConf().setAppName("JavaCustomReceiver"); JavaStreamingContext ssc = new JavaStreamingContext(sparkConf, new Duration(1000)); - // Create a input stream with the custom receiver on target ip:port and count the + // Create an input stream with the custom receiver on target ip:port and count the // words in input stream of \n delimited text (eg. 
generated by 'nc') JavaReceiverInputDStream lines = ssc.receiverStream( new JavaCustomReceiver(args[0], Integer.parseInt(args[1]))); diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java index 05631494484be..acbc34524328b 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java @@ -29,7 +29,6 @@ import com.google.common.io.Files; -import org.apache.spark.Accumulator; import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaSparkContext; @@ -41,6 +40,7 @@ import org.apache.spark.streaming.api.java.JavaPairDStream; import org.apache.spark.streaming.api.java.JavaReceiverInputDStream; import org.apache.spark.streaming.api.java.JavaStreamingContext; +import org.apache.spark.util.LongAccumulator; /** * Use this singleton to get or register a Broadcast variable. @@ -67,13 +67,13 @@ public static Broadcast> getInstance(JavaSparkContext jsc) { */ class JavaDroppedWordsCounter { - private static volatile Accumulator instance = null; + private static volatile LongAccumulator instance = null; - public static Accumulator getInstance(JavaSparkContext jsc) { + public static LongAccumulator getInstance(JavaSparkContext jsc) { if (instance == null) { synchronized (JavaDroppedWordsCounter.class) { if (instance == null) { - instance = jsc.accumulator(0, "WordsInBlacklistCounter"); + instance = jsc.sc().longAccumulator("WordsInBlacklistCounter"); } } } @@ -158,7 +158,7 @@ public void call(JavaPairRDD rdd, Time time) throws IOException final Broadcast> blacklist = JavaWordBlacklist.getInstance(new JavaSparkContext(rdd.context())); // Get or register the droppedWordsCounter Accumulator - final Accumulator droppedWordsCounter = + final LongAccumulator droppedWordsCounter = JavaDroppedWordsCounter.getInstance(new JavaSparkContext(rdd.context())); // Use blacklist to drop words and use droppedWordsCounter to count them String counts = rdd.filter(new Function, Boolean>() { diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java index 7aa8862761d2b..b8e9e125ba596 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java @@ -22,14 +22,13 @@ import java.util.regex.Pattern; import org.apache.spark.SparkConf; -import org.apache.spark.SparkContext; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.function.VoidFunction2; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; import org.apache.spark.api.java.StorageLevels; import org.apache.spark.streaming.Durations; import org.apache.spark.streaming.Time; @@ -82,7 +81,7 @@ public Iterator call(String x) { words.foreachRDD(new VoidFunction2, Time>() { @Override public void call(JavaRDD rdd, Time time) { - SQLContext sqlContext = JavaSQLContextSingleton.getInstance(rdd.context()); + SparkSession spark = 
JavaSparkSessionSingleton.getInstance(rdd.context().getConf()); // Convert JavaRDD[String] to JavaRDD[bean class] to DataFrame JavaRDD rowRDD = rdd.map(new Function() { @@ -93,14 +92,14 @@ public JavaRecord call(String word) { return record; } }); - Dataset wordsDataFrame = sqlContext.createDataFrame(rowRDD, JavaRecord.class); + Dataset wordsDataFrame = spark.createDataFrame(rowRDD, JavaRecord.class); - // Register as table - wordsDataFrame.registerTempTable("words"); + // Creates a temporary view using the DataFrame + wordsDataFrame.createOrReplaceTempView("words"); // Do word count on table using SQL and print it Dataset wordCountsDataFrame = - sqlContext.sql("select word, count(*) as total from words group by word"); + spark.sql("select word, count(*) as total from words group by word"); System.out.println("========= " + time + "========="); wordCountsDataFrame.show(); } @@ -111,12 +110,15 @@ public JavaRecord call(String word) { } } -/** Lazily instantiated singleton instance of SQLContext */ -class JavaSQLContextSingleton { - private static transient SQLContext instance = null; - public static SQLContext getInstance(SparkContext sparkContext) { +/** Lazily instantiated singleton instance of SparkSession */ +class JavaSparkSessionSingleton { + private static transient SparkSession instance = null; + public static SparkSession getInstance(SparkConf sparkConf) { if (instance == null) { - instance = new SQLContext(sparkContext); + instance = SparkSession + .builder() + .config(sparkConf) + .getOrCreate(); } return instance; } diff --git a/examples/src/main/python/als.py b/examples/src/main/python/als.py index 205ca02962bee..80290e7de9b06 100755 --- a/examples/src/main/python/als.py +++ b/examples/src/main/python/als.py @@ -17,7 +17,7 @@ """ This is an example implementation of ALS for learning how to use Spark. Please refer to -ALS in pyspark.mllib.recommendation for more conventional use. +pyspark.ml.recommendation.ALS for more conventional use. This example requires numpy (http://www.numpy.org/) """ @@ -28,7 +28,7 @@ import numpy as np from numpy.random import rand from numpy import matrix -from pyspark import SparkContext +from pyspark.sql import SparkSession LAMBDA = 0.01 # regularization np.random.seed(42) @@ -59,10 +59,16 @@ def update(i, vec, mat, ratings): """ print("""WARN: This is a naive implementation of ALS and is given as an - example. Please use the ALS method found in pyspark.mllib.recommendation for more + example. 
Please use pyspark.ml.recommendation.ALS for more conventional use.""", file=sys.stderr) - sc = SparkContext(appName="PythonALS") + spark = SparkSession\ + .builder\ + .appName("PythonALS")\ + .getOrCreate() + + sc = spark.sparkContext + M = int(sys.argv[1]) if len(sys.argv) > 1 else 100 U = int(sys.argv[2]) if len(sys.argv) > 2 else 500 F = int(sys.argv[3]) if len(sys.argv) > 3 else 10 @@ -99,4 +105,4 @@ def update(i, vec, mat, ratings): print("Iteration %d:" % i) print("\nRMSE: %5.4f\n" % error) - sc.stop() + spark.stop() diff --git a/examples/src/main/python/avro_inputformat.py b/examples/src/main/python/avro_inputformat.py index da368ac628a49..4422f9e7a9589 100644 --- a/examples/src/main/python/avro_inputformat.py +++ b/examples/src/main/python/avro_inputformat.py @@ -19,8 +19,8 @@ import sys -from pyspark import SparkContext from functools import reduce +from pyspark.sql import SparkSession """ Read data file users.avro in local Spark distro: @@ -64,7 +64,13 @@ exit(-1) path = sys.argv[1] - sc = SparkContext(appName="AvroKeyInputFormat") + + spark = SparkSession\ + .builder\ + .appName("AvroKeyInputFormat")\ + .getOrCreate() + + sc = spark.sparkContext conf = None if len(sys.argv) == 3: @@ -82,4 +88,4 @@ for k in output: print(k) - sc.stop() + spark.stop() diff --git a/examples/src/main/python/kmeans.py b/examples/src/main/python/kmeans.py index 0ea7cfb7025a0..92e0a3ae2ee60 100755 --- a/examples/src/main/python/kmeans.py +++ b/examples/src/main/python/kmeans.py @@ -17,8 +17,8 @@ """ The K-means algorithm written from scratch against PySpark. In practice, -one may prefer to use the KMeans algorithm in MLlib, as shown in -examples/src/main/python/mllib/kmeans.py. +one may prefer to use the KMeans algorithm in ML, as shown in +examples/src/main/python/ml/kmeans_example.py. This example requires NumPy (http://www.numpy.org/). """ @@ -27,7 +27,7 @@ import sys import numpy as np -from pyspark import SparkContext +from pyspark.sql import SparkSession def parseVector(line): @@ -52,11 +52,15 @@ def closestPoint(p, centers): exit(-1) print("""WARN: This is a naive implementation of KMeans Clustering and is given - as an example! Please refer to examples/src/main/python/mllib/kmeans.py for an example on - how to use MLlib's KMeans implementation.""", file=sys.stderr) + as an example! Please refer to examples/src/main/python/ml/kmeans_example.py for an + example on how to use ML's KMeans implementation.""", file=sys.stderr) - sc = SparkContext(appName="PythonKMeans") - lines = sc.textFile(sys.argv[1]) + spark = SparkSession\ + .builder\ + .appName("PythonKMeans")\ + .getOrCreate() + + lines = spark.read.text(sys.argv[1]).rdd.map(lambda r: r[0]) data = lines.map(parseVector).cache() K = int(sys.argv[2]) convergeDist = float(sys.argv[3]) @@ -79,4 +83,4 @@ def closestPoint(p, centers): print("Final centers: " + str(kPoints)) - sc.stop() + spark.stop() diff --git a/examples/src/main/python/logistic_regression.py b/examples/src/main/python/logistic_regression.py index b318b7d87bfdc..01c938454b108 100755 --- a/examples/src/main/python/logistic_regression.py +++ b/examples/src/main/python/logistic_regression.py @@ -20,14 +20,14 @@ to act on batches of input data using efficient matrix operations. In practice, one may prefer to use the LogisticRegression algorithm in -MLlib, as shown in examples/src/main/python/mllib/logistic_regression.py. +ML, as shown in examples/src/main/python/ml/logistic_regression_with_elastic_net.py. 
""" from __future__ import print_function import sys import numpy as np -from pyspark import SparkContext +from pyspark.sql import SparkSession D = 10 # Number of dimensions @@ -51,11 +51,17 @@ def readPointBatch(iterator): exit(-1) print("""WARN: This is a naive implementation of Logistic Regression and is - given as an example! Please refer to examples/src/main/python/mllib/logistic_regression.py - to see how MLlib's implementation is used.""", file=sys.stderr) + given as an example! + Please refer to examples/src/main/python/ml/logistic_regression_with_elastic_net.py + to see how ML's implementation is used.""", file=sys.stderr) - sc = SparkContext(appName="PythonLR") - points = sc.textFile(sys.argv[1]).mapPartitions(readPointBatch).cache() + spark = SparkSession\ + .builder\ + .appName("PythonLR")\ + .getOrCreate() + + points = spark.read.text(sys.argv[1]).rdd.map(lambda r: r[0])\ + .mapPartitions(readPointBatch).cache() iterations = int(sys.argv[2]) # Initialize w to a random value @@ -79,4 +85,4 @@ def add(x, y): print("Final w: " + str(w)) - sc.stop() + spark.stop() diff --git a/examples/src/main/python/ml/aft_survival_regression.py b/examples/src/main/python/ml/aft_survival_regression.py index 0ee01fd8258df..060f0171ffdb5 100644 --- a/examples/src/main/python/ml/aft_survival_regression.py +++ b/examples/src/main/python/ml/aft_survival_regression.py @@ -17,19 +17,26 @@ from __future__ import print_function -from pyspark import SparkContext -from pyspark.sql import SQLContext # $example on$ from pyspark.ml.regression import AFTSurvivalRegression -from pyspark.mllib.linalg import Vectors +from pyspark.ml.linalg import Vectors # $example off$ +from pyspark.sql import SparkSession + +""" +An example demonstrating aft survival regression. +Run with: + bin/spark-submit examples/src/main/python/ml/aft_survival_regression.py +""" if __name__ == "__main__": - sc = SparkContext(appName="AFTSurvivalRegressionExample") - sqlContext = SQLContext(sc) + spark = SparkSession \ + .builder \ + .appName("PythonAFTSurvivalRegressionExample") \ + .getOrCreate() # $example on$ - training = sqlContext.createDataFrame([ + training = spark.createDataFrame([ (1.218, 1.0, Vectors.dense(1.560, -0.605)), (2.949, 0.0, Vectors.dense(0.346, 2.158)), (3.627, 0.0, Vectors.dense(1.380, 0.231)), @@ -48,4 +55,4 @@ model.transform(training).show(truncate=False) # $example off$ - sc.stop() + spark.stop() diff --git a/examples/src/main/python/ml/als_example.py b/examples/src/main/python/ml/als_example.py index 922173308c6aa..1a979ff5b5be2 100644 --- a/examples/src/main/python/ml/als_example.py +++ b/examples/src/main/python/ml/als_example.py @@ -17,8 +17,11 @@ from __future__ import print_function -from pyspark import SparkContext -from pyspark.sql import SQLContext +import sys +if sys.version >= '3': + long = int + +from pyspark.sql import SparkSession # $example on$ from pyspark.ml.evaluation import RegressionEvaluator @@ -27,15 +30,17 @@ # $example off$ if __name__ == "__main__": - sc = SparkContext(appName="ALSExample") - sqlContext = SQLContext(sc) + spark = SparkSession\ + .builder\ + .appName("ALSExample")\ + .getOrCreate() # $example on$ - lines = sc.textFile("data/mllib/als/sample_movielens_ratings.txt") - parts = lines.map(lambda l: l.split("::")) + lines = spark.read.text("data/mllib/als/sample_movielens_ratings.txt").rdd + parts = lines.map(lambda row: row.value.split("::")) ratingsRDD = parts.map(lambda p: Row(userId=int(p[0]), movieId=int(p[1]), rating=float(p[2]), timestamp=long(p[3]))) - ratings = 
sqlContext.createDataFrame(ratingsRDD) + ratings = spark.createDataFrame(ratingsRDD) (training, test) = ratings.randomSplit([0.8, 0.2]) # Build the recommendation model using ALS on the training data @@ -43,13 +48,10 @@ model = als.fit(training) # Evaluate the model by computing the RMSE on the test data - rawPredictions = model.transform(test) - predictions = rawPredictions\ - .withColumn("rating", rawPredictions.rating.cast("double"))\ - .withColumn("prediction", rawPredictions.prediction.cast("double")) - evaluator =\ - RegressionEvaluator(metricName="rmse", labelCol="rating", predictionCol="prediction") + predictions = model.transform(test) + evaluator = RegressionEvaluator(metricName="rmse", labelCol="rating", + predictionCol="prediction") rmse = evaluator.evaluate(predictions) print("Root-mean-square error = " + str(rmse)) # $example off$ - sc.stop() + spark.stop() diff --git a/examples/src/main/python/ml/binarizer_example.py b/examples/src/main/python/ml/binarizer_example.py index 317cfa638a5a9..4224a27dbef0c 100644 --- a/examples/src/main/python/ml/binarizer_example.py +++ b/examples/src/main/python/ml/binarizer_example.py @@ -17,18 +17,19 @@ from __future__ import print_function -from pyspark import SparkContext -from pyspark.sql import SQLContext +from pyspark.sql import SparkSession # $example on$ from pyspark.ml.feature import Binarizer # $example off$ if __name__ == "__main__": - sc = SparkContext(appName="BinarizerExample") - sqlContext = SQLContext(sc) + spark = SparkSession\ + .builder\ + .appName("BinarizerExample")\ + .getOrCreate() # $example on$ - continuousDataFrame = sqlContext.createDataFrame([ + continuousDataFrame = spark.createDataFrame([ (0, 0.1), (1, 0.8), (2, 0.2) @@ -40,4 +41,4 @@ print(binarized_feature) # $example off$ - sc.stop() + spark.stop() diff --git a/examples/src/main/python/ml/bisecting_k_means_example.py b/examples/src/main/python/ml/bisecting_k_means_example.py index e6f6bfd7e84ed..ee0399ac5eb20 100644 --- a/examples/src/main/python/ml/bisecting_k_means_example.py +++ b/examples/src/main/python/ml/bisecting_k_means_example.py @@ -17,41 +17,40 @@ from __future__ import print_function -from pyspark import SparkContext # $example on$ -from pyspark.ml.clustering import BisectingKMeans, BisectingKMeansModel -from pyspark.mllib.linalg import VectorUDT, _convert_to_vector, Vectors -from pyspark.mllib.linalg import Vectors -from pyspark.sql.types import Row +from pyspark.ml.clustering import BisectingKMeans # $example off$ -from pyspark.sql import SQLContext +from pyspark.sql import SparkSession """ -A simple example demonstrating a bisecting k-means clustering. +An example demonstrating bisecting k-means clustering. +Run with: + bin/spark-submit examples/src/main/python/ml/bisecting_k_means_example.py """ if __name__ == "__main__": - - sc = SparkContext(appName="PythonBisectingKMeansExample") - sqlContext = SQLContext(sc) + spark = SparkSession\ + .builder\ + .appName("PythonBisectingKMeansExample")\ + .getOrCreate() # $example on$ - data = sc.textFile("data/mllib/kmeans_data.txt") - parsed = data.map(lambda l: Row(features=Vectors.dense([float(x) for x in l.split(' ')]))) - training = sqlContext.createDataFrame(parsed) - - kmeans = BisectingKMeans().setK(2).setSeed(1).setFeaturesCol("features") + # Loads data. + dataset = spark.read.format("libsvm").load("data/mllib/sample_kmeans_data.txt") - model = kmeans.fit(training) + # Trains a bisecting k-means model. 
+ bkm = BisectingKMeans().setK(2).setSeed(1) + model = bkm.fit(dataset) - # Evaluate clustering - cost = model.computeCost(training) - print("Bisecting K-means Cost = " + str(cost)) + # Evaluate clustering. + cost = model.computeCost(dataset) + print("Within Set Sum of Squared Errors = " + str(cost)) - centers = model.clusterCenters() + # Shows the result. print("Cluster Centers: ") + centers = model.clusterCenters() for center in centers: print(center) # $example off$ - sc.stop() + spark.stop() diff --git a/examples/src/main/python/ml/bucketizer_example.py b/examples/src/main/python/ml/bucketizer_example.py index 4304255f350db..8177e560ddef1 100644 --- a/examples/src/main/python/ml/bucketizer_example.py +++ b/examples/src/main/python/ml/bucketizer_example.py @@ -17,21 +17,22 @@ from __future__ import print_function -from pyspark import SparkContext -from pyspark.sql import SQLContext +from pyspark.sql import SparkSession # $example on$ from pyspark.ml.feature import Bucketizer # $example off$ if __name__ == "__main__": - sc = SparkContext(appName="BucketizerExample") - sqlContext = SQLContext(sc) + spark = SparkSession\ + .builder\ + .appName("BucketizerExample")\ + .getOrCreate() # $example on$ splits = [-float("inf"), -0.5, 0.0, 0.5, float("inf")] data = [(-0.5,), (-0.3,), (0.0,), (0.2,)] - dataFrame = sqlContext.createDataFrame(data, ["features"]) + dataFrame = spark.createDataFrame(data, ["features"]) bucketizer = Bucketizer(splits=splits, inputCol="features", outputCol="bucketedFeatures") @@ -40,4 +41,4 @@ bucketedData.show() # $example off$ - sc.stop() + spark.stop() diff --git a/examples/src/main/python/ml/chisq_selector_example.py b/examples/src/main/python/ml/chisq_selector_example.py index 997a504735360..5e19ef1624c7e 100644 --- a/examples/src/main/python/ml/chisq_selector_example.py +++ b/examples/src/main/python/ml/chisq_selector_example.py @@ -17,19 +17,20 @@ from __future__ import print_function -from pyspark import SparkContext -from pyspark.sql import SQLContext +from pyspark.sql import SparkSession # $example on$ from pyspark.ml.feature import ChiSqSelector -from pyspark.mllib.linalg import Vectors +from pyspark.ml.linalg import Vectors # $example off$ if __name__ == "__main__": - sc = SparkContext(appName="ChiSqSelectorExample") - sqlContext = SQLContext(sc) + spark = SparkSession\ + .builder\ + .appName("ChiSqSelectorExample")\ + .getOrCreate() # $example on$ - df = sqlContext.createDataFrame([ + df = spark.createDataFrame([ (7, Vectors.dense([0.0, 0.0, 18.0, 1.0]), 1.0,), (8, Vectors.dense([0.0, 1.0, 12.0, 0.0]), 0.0,), (9, Vectors.dense([1.0, 0.0, 15.0, 0.1]), 0.0,)], ["id", "features", "clicked"]) @@ -41,4 +42,4 @@ result.show() # $example off$ - sc.stop() + spark.stop() diff --git a/examples/src/main/python/ml/count_vectorizer_example.py b/examples/src/main/python/ml/count_vectorizer_example.py index e839f645f70b5..38cfac82fbe20 100644 --- a/examples/src/main/python/ml/count_vectorizer_example.py +++ b/examples/src/main/python/ml/count_vectorizer_example.py @@ -17,19 +17,20 @@ from __future__ import print_function -from pyspark import SparkContext -from pyspark.sql import SQLContext +from pyspark.sql import SparkSession # $example on$ from pyspark.ml.feature import CountVectorizer # $example off$ if __name__ == "__main__": - sc = SparkContext(appName="CountVectorizerExample") - sqlContext = SQLContext(sc) + spark = SparkSession\ + .builder\ + .appName("CountVectorizerExample")\ + .getOrCreate() # $example on$ # Input data: Each row is a bag of words with a ID. 
- df = sqlContext.createDataFrame([ + df = spark.createDataFrame([ (0, "a b c".split(" ")), (1, "a b b c a".split(" ")) ], ["id", "words"]) @@ -41,4 +42,4 @@ result.show() # $example off$ - sc.stop() + spark.stop() diff --git a/examples/src/main/python/ml/cross_validator.py b/examples/src/main/python/ml/cross_validator.py index 5f0ef20218c4a..a41df6cf946fb 100644 --- a/examples/src/main/python/ml/cross_validator.py +++ b/examples/src/main/python/ml/cross_validator.py @@ -17,15 +17,14 @@ from __future__ import print_function -from pyspark import SparkContext # $example on$ from pyspark.ml import Pipeline from pyspark.ml.classification import LogisticRegression from pyspark.ml.evaluation import BinaryClassificationEvaluator from pyspark.ml.feature import HashingTF, Tokenizer from pyspark.ml.tuning import CrossValidator, ParamGridBuilder -from pyspark.sql import Row, SQLContext # $example off$ +from pyspark.sql import Row, SparkSession """ A simple example demonstrating model selection using CrossValidator. @@ -36,25 +35,26 @@ """ if __name__ == "__main__": - sc = SparkContext(appName="CrossValidatorExample") - sqlContext = SQLContext(sc) + spark = SparkSession\ + .builder\ + .appName("CrossValidatorExample")\ + .getOrCreate() # $example on$ # Prepare training documents, which are labeled. - LabeledDocument = Row("id", "text", "label") - training = sc.parallelize([(0, "a b c d e spark", 1.0), - (1, "b d", 0.0), - (2, "spark f g h", 1.0), - (3, "hadoop mapreduce", 0.0), - (4, "b spark who", 1.0), - (5, "g d a y", 0.0), - (6, "spark fly", 1.0), - (7, "was mapreduce", 0.0), - (8, "e spark program", 1.0), - (9, "a e c l", 0.0), - (10, "spark compile", 1.0), - (11, "hadoop software", 0.0) - ]) \ - .map(lambda x: LabeledDocument(*x)).toDF() + training = spark.createDataFrame([ + (0, "a b c d e spark", 1.0), + (1, "b d", 0.0), + (2, "spark f g h", 1.0), + (3, "hadoop mapreduce", 0.0), + (4, "b spark who", 1.0), + (5, "g d a y", 0.0), + (6, "spark fly", 1.0), + (7, "was mapreduce", 0.0), + (8, "e spark program", 1.0), + (9, "a e c l", 0.0), + (10, "spark compile", 1.0), + (11, "hadoop software", 0.0) + ], ["id", "text", "label"]) # Configure an ML pipeline, which consists of tree stages: tokenizer, hashingTF, and lr. tokenizer = Tokenizer(inputCol="text", outputCol="words") @@ -82,12 +82,12 @@ cvModel = crossval.fit(training) # Prepare test documents, which are unlabeled. - Document = Row("id", "text") - test = sc.parallelize([(4L, "spark i j k"), - (5L, "l m n"), - (6L, "mapreduce spark"), - (7L, "apache hadoop")]) \ - .map(lambda x: Document(*x)).toDF() + test = spark.createDataFrame([ + (4L, "spark i j k"), + (5L, "l m n"), + (6L, "mapreduce spark"), + (7L, "apache hadoop") + ], ["id", "text"]) # Make predictions on test documents. cvModel uses the best model found (lrModel). 
prediction = cvModel.transform(test) @@ -96,4 +96,4 @@ print(row) # $example off$ - sc.stop() + spark.stop() diff --git a/examples/src/main/python/ml/dataframe_example.py b/examples/src/main/python/ml/dataframe_example.py index d2644ca335654..c1818d72fe467 100644 --- a/examples/src/main/python/ml/dataframe_example.py +++ b/examples/src/main/python/ml/dataframe_example.py @@ -26,16 +26,18 @@ import tempfile import shutil -from pyspark import SparkContext -from pyspark.sql import SQLContext +from pyspark.sql import SparkSession from pyspark.mllib.stat import Statistics +from pyspark.mllib.util import MLUtils if __name__ == "__main__": if len(sys.argv) > 2: print("Usage: dataframe_example.py ", file=sys.stderr) exit(-1) - sc = SparkContext(appName="DataFrameExample") - sqlContext = SQLContext(sc) + spark = SparkSession\ + .builder\ + .appName("DataFrameExample")\ + .getOrCreate() if len(sys.argv) == 2: input = sys.argv[1] else: @@ -43,7 +45,7 @@ # Load input data print("Loading LIBSVM file with UDT from " + input + ".") - df = sqlContext.read.format("libsvm").load(input).cache() + df = spark.read.format("libsvm").load(input).cache() print("Schema from LIBSVM:") df.printSchema() print("Loaded training data as a DataFrame with " + @@ -54,7 +56,8 @@ labelSummary.show() # Convert features column to an RDD of vectors. - features = df.select("features").map(lambda r: r.features) + features = MLUtils.convertVectorColumnsFromML(df, "features") \ + .select("features").rdd.map(lambda r: r.features) summary = Statistics.colStats(features) print("Selected features column with average values:\n" + str(summary.mean())) @@ -67,9 +70,9 @@ # Load the records back. print("Loading Parquet file with UDT from " + tempdir) - newDF = sqlContext.read.parquet(tempdir) + newDF = spark.read.parquet(tempdir) print("Schema from Parquet:") newDF.printSchema() shutil.rmtree(tempdir) - sc.stop() + spark.stop() diff --git a/examples/src/main/python/ml/dct_example.py b/examples/src/main/python/ml/dct_example.py index 264d47f404cb1..a4f25df784886 100644 --- a/examples/src/main/python/ml/dct_example.py +++ b/examples/src/main/python/ml/dct_example.py @@ -17,19 +17,20 @@ from __future__ import print_function -from pyspark import SparkContext -from pyspark.sql import SQLContext # $example on$ from pyspark.ml.feature import DCT -from pyspark.mllib.linalg import Vectors +from pyspark.ml.linalg import Vectors # $example off$ +from pyspark.sql import SparkSession if __name__ == "__main__": - sc = SparkContext(appName="DCTExample") - sqlContext = SQLContext(sc) + spark = SparkSession\ + .builder\ + .appName("DCTExample")\ + .getOrCreate() # $example on$ - df = sqlContext.createDataFrame([ + df = spark.createDataFrame([ (Vectors.dense([0.0, 1.0, -2.0, 3.0]),), (Vectors.dense([-1.0, 2.0, 4.0, -7.0]),), (Vectors.dense([14.0, -2.0, -5.0, 1.0]),)], ["features"]) @@ -42,4 +43,4 @@ print(dcts) # $example off$ - sc.stop() + spark.stop() diff --git a/examples/src/main/python/ml/decision_tree_classification_example.py b/examples/src/main/python/ml/decision_tree_classification_example.py index 86bdc65392bbb..708f1af6cc6eb 100644 --- a/examples/src/main/python/ml/decision_tree_classification_example.py +++ b/examples/src/main/python/ml/decision_tree_classification_example.py @@ -21,20 +21,22 @@ from __future__ import print_function # $example on$ -from pyspark import SparkContext, SQLContext from pyspark.ml import Pipeline from pyspark.ml.classification import DecisionTreeClassifier from pyspark.ml.feature import StringIndexer, VectorIndexer from 
pyspark.ml.evaluation import MulticlassClassificationEvaluator # $example off$ +from pyspark.sql import SparkSession if __name__ == "__main__": - sc = SparkContext(appName="decision_tree_classification_example") - sqlContext = SQLContext(sc) + spark = SparkSession\ + .builder\ + .appName("decision_tree_classification_example")\ + .getOrCreate() # $example on$ # Load the data stored in LIBSVM format as a DataFrame. - data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + data = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") # Index labels, adding metadata to the label column. # Fit on whole dataset to include all labels in index. @@ -64,7 +66,7 @@ # Select (prediction, true label) and compute test error evaluator = MulticlassClassificationEvaluator( - labelCol="indexedLabel", predictionCol="prediction", metricName="precision") + labelCol="indexedLabel", predictionCol="prediction", metricName="accuracy") accuracy = evaluator.evaluate(predictions) print("Test Error = %g " % (1.0 - accuracy)) @@ -72,3 +74,5 @@ # summary only print(treeModel) # $example off$ + + spark.stop() diff --git a/examples/src/main/python/ml/decision_tree_regression_example.py b/examples/src/main/python/ml/decision_tree_regression_example.py index 8e20d5d8572a5..58d7ad921d8e0 100644 --- a/examples/src/main/python/ml/decision_tree_regression_example.py +++ b/examples/src/main/python/ml/decision_tree_regression_example.py @@ -20,21 +20,23 @@ """ from __future__ import print_function -from pyspark import SparkContext, SQLContext # $example on$ from pyspark.ml import Pipeline from pyspark.ml.regression import DecisionTreeRegressor from pyspark.ml.feature import VectorIndexer from pyspark.ml.evaluation import RegressionEvaluator # $example off$ +from pyspark.sql import SparkSession if __name__ == "__main__": - sc = SparkContext(appName="decision_tree_classification_example") - sqlContext = SQLContext(sc) + spark = SparkSession\ + .builder\ + .appName("DecisionTreeRegressionExample")\ + .getOrCreate() # $example on$ # Load the data stored in LIBSVM format as a DataFrame. - data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + data = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") # Automatically identify categorical features, and index them. # We specify maxCategories so features with > 4 distinct values are treated as continuous. 
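The hunks above and below migrate examples/src/main/python/ml/decision_tree_regression_example.py from SparkContext/SQLContext to SparkSession (builder, spark.read, spark.stop). A minimal sketch of the post-migration flow these hunks converge on; the 0.7/0.3 split, column names, and RMSE evaluation follow the shipped example file rather than the lines visible in this hunk, so treat them as illustrative:

```python
from pyspark.sql import SparkSession
from pyspark.ml import Pipeline
from pyspark.ml.feature import VectorIndexer
from pyspark.ml.regression import DecisionTreeRegressor
from pyspark.ml.evaluation import RegressionEvaluator

spark = SparkSession.builder.appName("DecisionTreeRegressionSketch").getOrCreate()

# Same bundled sample data the example loads after this patch
data = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")

# Index categorical features; features with > 4 distinct values are treated as continuous
featureIndexer = VectorIndexer(
    inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data)

trainingData, testData = data.randomSplit([0.7, 0.3])

dt = DecisionTreeRegressor(featuresCol="indexedFeatures")
model = Pipeline(stages=[featureIndexer, dt]).fit(trainingData)

predictions = model.transform(testData)
rmse = RegressionEvaluator(
    labelCol="label", predictionCol="prediction", metricName="rmse").evaluate(predictions)
print("Root Mean Squared Error (RMSE) on test data = %g" % rmse)

spark.stop()
```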
@@ -69,3 +71,5 @@ # summary only print(treeModel) # $example off$ + + spark.stop() diff --git a/examples/src/main/python/ml/elementwise_product_example.py b/examples/src/main/python/ml/elementwise_product_example.py index c85cb0d89543c..590053998bccc 100644 --- a/examples/src/main/python/ml/elementwise_product_example.py +++ b/examples/src/main/python/ml/elementwise_product_example.py @@ -17,23 +17,26 @@ from __future__ import print_function -from pyspark import SparkContext -from pyspark.sql import SQLContext # $example on$ from pyspark.ml.feature import ElementwiseProduct -from pyspark.mllib.linalg import Vectors +from pyspark.ml.linalg import Vectors # $example off$ +from pyspark.sql import SparkSession if __name__ == "__main__": - sc = SparkContext(appName="ElementwiseProductExample") - sqlContext = SQLContext(sc) + spark = SparkSession\ + .builder\ + .appName("ElementwiseProductExample")\ + .getOrCreate() # $example on$ + # Create some vector data; also works for sparse vectors data = [(Vectors.dense([1.0, 2.0, 3.0]),), (Vectors.dense([4.0, 5.0, 6.0]),)] - df = sqlContext.createDataFrame(data, ["vector"]) + df = spark.createDataFrame(data, ["vector"]) transformer = ElementwiseProduct(scalingVec=Vectors.dense([0.0, 1.0, 2.0]), inputCol="vector", outputCol="transformedVector") + # Batch transform the vectors to create new column: transformer.transform(df).show() # $example off$ - sc.stop() + spark.stop() diff --git a/examples/src/main/python/ml/estimator_transformer_param_example.py b/examples/src/main/python/ml/estimator_transformer_param_example.py index 9a8993dac4f65..3bd3fd30f8e57 100644 --- a/examples/src/main/python/ml/estimator_transformer_param_example.py +++ b/examples/src/main/python/ml/estimator_transformer_param_example.py @@ -18,20 +18,22 @@ """ Estimator Transformer Param Example. """ -from pyspark import SparkContext, SQLContext + # $example on$ -from pyspark.mllib.linalg import Vectors +from pyspark.ml.linalg import Vectors from pyspark.ml.classification import LogisticRegression # $example off$ +from pyspark.sql import SparkSession if __name__ == "__main__": - - sc = SparkContext(appName="EstimatorTransformerParamExample") - sqlContext = SQLContext(sc) + spark = SparkSession\ + .builder\ + .appName("EstimatorTransformerParamExample")\ + .getOrCreate() # $example on$ # Prepare training data from a list of (label, features) tuples. - training = sqlContext.createDataFrame([ + training = spark.createDataFrame([ (1.0, Vectors.dense([0.0, 1.1, 0.1])), (0.0, Vectors.dense([2.0, 1.0, -1.0])), (0.0, Vectors.dense([2.0, 1.3, 1.0])), @@ -69,7 +71,7 @@ print model2.extractParamMap() # Prepare test data - test = sqlContext.createDataFrame([ + test = spark.createDataFrame([ (1.0, Vectors.dense([-1.0, 1.5, 1.3])), (0.0, Vectors.dense([3.0, 2.0, -0.1])), (1.0, Vectors.dense([0.0, 2.2, -1.5]))], ["label", "features"]) @@ -84,4 +86,4 @@ print row # $example off$ - sc.stop() + spark.stop() diff --git a/examples/src/main/python/ml/gaussian_mixture_example.py b/examples/src/main/python/ml/gaussian_mixture_example.py new file mode 100644 index 0000000000000..2ca13d68f6890 --- /dev/null +++ b/examples/src/main/python/ml/gaussian_mixture_example.py @@ -0,0 +1,48 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +# $example on$ +from pyspark.ml.clustering import GaussianMixture +# $example off$ +from pyspark.sql import SparkSession + +""" +A simple example demonstrating Gaussian Mixture Model (GMM). +Run with: + bin/spark-submit examples/src/main/python/ml/gaussian_mixture_example.py +""" + +if __name__ == "__main__": + spark = SparkSession\ + .builder\ + .appName("PythonGuassianMixtureExample")\ + .getOrCreate() + + # $example on$ + # loads data + dataset = spark.read.format("libsvm").load("data/mllib/sample_kmeans_data.txt") + + gmm = GaussianMixture().setK(2) + model = gmm.fit(dataset) + + print("Gaussians: ") + model.gaussiansDF.show() + # $example off$ + + spark.stop() diff --git a/examples/src/main/python/ml/generalized_linear_regression_example.py b/examples/src/main/python/ml/generalized_linear_regression_example.py new file mode 100644 index 0000000000000..796752a60f3ab --- /dev/null +++ b/examples/src/main/python/ml/generalized_linear_regression_example.py @@ -0,0 +1,66 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark.sql import SparkSession +# $example on$ +from pyspark.ml.regression import GeneralizedLinearRegression +# $example off$ + +""" +An example demonstrating generalized linear regression. 
+Run with: + bin/spark-submit examples/src/main/python/ml/generalized_linear_regression_example.py +""" + +if __name__ == "__main__": + spark = SparkSession\ + .builder\ + .appName("GeneralizedLinearRegressionExample")\ + .getOrCreate() + + # $example on$ + # Load training data + dataset = spark.read.format("libsvm")\ + .load("data/mllib/sample_linear_regression_data.txt") + + glr = GeneralizedLinearRegression(family="gaussian", link="identity", maxIter=10, regParam=0.3) + + # Fit the model + model = glr.fit(dataset) + + # Print the coefficients and intercept for generalized linear regression model + print("Coefficients: " + str(model.coefficients)) + print("Intercept: " + str(model.intercept)) + + # Summarize the model over the training set and print out some metrics + summary = model.summary + print("Coefficient Standard Errors: " + str(summary.coefficientStandardErrors)) + print("T Values: " + str(summary.tValues)) + print("P Values: " + str(summary.pValues)) + print("Dispersion: " + str(summary.dispersion)) + print("Null Deviance: " + str(summary.nullDeviance)) + print("Residual Degree Of Freedom Null: " + str(summary.residualDegreeOfFreedomNull)) + print("Deviance: " + str(summary.deviance)) + print("Residual Degree Of Freedom: " + str(summary.residualDegreeOfFreedom)) + print("AIC: " + str(summary.aic)) + print("Deviance Residuals: ") + summary.residuals().show() + # $example off$ + + spark.stop() diff --git a/examples/src/main/python/ml/gradient_boosted_tree_classifier_example.py b/examples/src/main/python/ml/gradient_boosted_tree_classifier_example.py index f7e842f4b303a..6c2d7e7b810df 100644 --- a/examples/src/main/python/ml/gradient_boosted_tree_classifier_example.py +++ b/examples/src/main/python/ml/gradient_boosted_tree_classifier_example.py @@ -20,21 +20,23 @@ """ from __future__ import print_function -from pyspark import SparkContext, SQLContext # $example on$ from pyspark.ml import Pipeline from pyspark.ml.classification import GBTClassifier from pyspark.ml.feature import StringIndexer, VectorIndexer from pyspark.ml.evaluation import MulticlassClassificationEvaluator # $example off$ +from pyspark.sql import SparkSession if __name__ == "__main__": - sc = SparkContext(appName="gradient_boosted_tree_classifier_example") - sqlContext = SQLContext(sc) + spark = SparkSession\ + .builder\ + .appName("gradient_boosted_tree_classifier_example")\ + .getOrCreate() # $example on$ # Load and parse the data file, converting it to a DataFrame. - data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + data = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") # Index labels, adding metadata to the label column. # Fit on whole dataset to include all labels in index. 
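The tree classifier examples in this patch also switch the MulticlassClassificationEvaluator metric from "precision" to "accuracy" (see the next hunk), with test error printed as 1 - accuracy. A self-contained toy sketch of that evaluation step; the tiny hand-made frame stands in for model.transform(testData) and is illustrative only:

```python
from pyspark.sql import SparkSession
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

spark = SparkSession.builder.appName("EvaluatorMetricSketch").getOrCreate()

# Toy (indexedLabel, prediction) pairs standing in for the example's predictions DataFrame
predictions = spark.createDataFrame(
    [(0.0, 0.0), (1.0, 1.0), (1.0, 0.0), (0.0, 0.0)],
    ["indexedLabel", "prediction"])

evaluator = MulticlassClassificationEvaluator(
    labelCol="indexedLabel", predictionCol="prediction", metricName="accuracy")
accuracy = evaluator.evaluate(predictions)
print("Test Error = %g" % (1.0 - accuracy))  # 0.25 for this toy frame

spark.stop()
```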
@@ -64,7 +66,7 @@ # Select (prediction, true label) and compute test error evaluator = MulticlassClassificationEvaluator( - labelCol="indexedLabel", predictionCol="prediction", metricName="precision") + labelCol="indexedLabel", predictionCol="prediction", metricName="accuracy") accuracy = evaluator.evaluate(predictions) print("Test Error = %g" % (1.0 - accuracy)) @@ -72,4 +74,4 @@ print(gbtModel) # summary only # $example off$ - sc.stop() + spark.stop() diff --git a/examples/src/main/python/ml/gradient_boosted_tree_regressor_example.py b/examples/src/main/python/ml/gradient_boosted_tree_regressor_example.py index f8b4de651c768..5dd2272748d70 100644 --- a/examples/src/main/python/ml/gradient_boosted_tree_regressor_example.py +++ b/examples/src/main/python/ml/gradient_boosted_tree_regressor_example.py @@ -20,21 +20,23 @@ """ from __future__ import print_function -from pyspark import SparkContext, SQLContext # $example on$ from pyspark.ml import Pipeline from pyspark.ml.regression import GBTRegressor from pyspark.ml.feature import VectorIndexer from pyspark.ml.evaluation import RegressionEvaluator # $example off$ +from pyspark.sql import SparkSession if __name__ == "__main__": - sc = SparkContext(appName="gradient_boosted_tree_regressor_example") - sqlContext = SQLContext(sc) + spark = SparkSession\ + .builder\ + .appName("gradient_boosted_tree_regressor_example")\ + .getOrCreate() # $example on$ # Load and parse the data file, converting it to a DataFrame. - data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + data = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") # Automatically identify categorical features, and index them. # Set maxCategories so features with > 4 distinct values are treated as continuous. @@ -69,4 +71,4 @@ print(gbtModel) # summary only # $example off$ - sc.stop() + spark.stop() diff --git a/examples/src/main/python/ml/index_to_string_example.py b/examples/src/main/python/ml/index_to_string_example.py index fb0ba2950bbd6..523caac00c18a 100644 --- a/examples/src/main/python/ml/index_to_string_example.py +++ b/examples/src/main/python/ml/index_to_string_example.py @@ -17,18 +17,19 @@ from __future__ import print_function -from pyspark import SparkContext # $example on$ from pyspark.ml.feature import IndexToString, StringIndexer # $example off$ -from pyspark.sql import SQLContext +from pyspark.sql import SparkSession if __name__ == "__main__": - sc = SparkContext(appName="IndexToStringExample") - sqlContext = SQLContext(sc) + spark = SparkSession\ + .builder\ + .appName("IndexToStringExample")\ + .getOrCreate() # $example on$ - df = sqlContext.createDataFrame( + df = spark.createDataFrame( [(0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")], ["id", "category"]) @@ -42,4 +43,4 @@ converted.select("id", "originalCategory").show() # $example off$ - sc.stop() + spark.stop() diff --git a/examples/src/main/python/ml/isotonic_regression_example.py b/examples/src/main/python/ml/isotonic_regression_example.py new file mode 100644 index 0000000000000..1e61bd8eff143 --- /dev/null +++ b/examples/src/main/python/ml/isotonic_regression_example.py @@ -0,0 +1,54 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Isotonic Regression Example. +""" +from __future__ import print_function + +# $example on$ +from pyspark.ml.regression import IsotonicRegression, IsotonicRegressionModel +# $example off$ +from pyspark.sql import SparkSession + +""" +An example demonstrating isotonic regression. +Run with: + bin/spark-submit examples/src/main/python/ml/isotonic_regression_example.py +""" +if __name__ == "__main__": + + spark = SparkSession\ + .builder\ + .appName("PythonIsotonicRegressionExample")\ + .getOrCreate() + + # $example on$ + # Loads data. + dataset = spark.read.format("libsvm")\ + .load("data/mllib/sample_isotonic_regression_libsvm_data.txt") + + # Trains an isotonic regression model. + model = IsotonicRegression().fit(dataset) + print("Boundaries in increasing order: " + str(model.boundaries)) + print("Predictions associated with the boundaries: " + str(model.predictions)) + + # Makes predictions. + model.transform(dataset).show() + # $example off$ + + spark.stop() diff --git a/examples/src/main/python/ml/kmeans_example.py b/examples/src/main/python/ml/kmeans_example.py index fa57a4d3ada1b..4b8b7291f9188 100644 --- a/examples/src/main/python/ml/kmeans_example.py +++ b/examples/src/main/python/ml/kmeans_example.py @@ -17,54 +17,45 @@ from __future__ import print_function -import sys +# $example on$ +from pyspark.ml.clustering import KMeans +# $example off$ -import numpy as np -from pyspark import SparkContext -from pyspark.ml.clustering import KMeans, KMeansModel -from pyspark.mllib.linalg import VectorUDT, _convert_to_vector -from pyspark.sql import SQLContext -from pyspark.sql.types import Row, StructField, StructType +from pyspark.sql import SparkSession """ -A simple example demonstrating a k-means clustering. +An example demonstrating k-means clustering. Run with: - bin/spark-submit examples/src/main/python/ml/kmeans_example.py + bin/spark-submit examples/src/main/python/ml/kmeans_example.py This example requires NumPy (http://www.numpy.org/). """ -def parseVector(line): - array = np.array([float(x) for x in line.split(' ')]) - return _convert_to_vector(array) - - if __name__ == "__main__": - FEATURES_COL = "features" + spark = SparkSession\ + .builder\ + .appName("PythonKMeansExample")\ + .getOrCreate() - if len(sys.argv) != 3: - print("Usage: kmeans_example.py ", file=sys.stderr) - exit(-1) - path = sys.argv[1] - k = sys.argv[2] + # $example on$ + # Loads data. + dataset = spark.read.format("libsvm").load("data/mllib/sample_kmeans_data.txt") - sc = SparkContext(appName="PythonKMeansExample") - sqlContext = SQLContext(sc) + # Trains a k-means model. + kmeans = KMeans().setK(2).setSeed(1) + model = kmeans.fit(dataset) - lines = sc.textFile(path) - data = lines.map(parseVector) - row_rdd = data.map(lambda x: Row(x)) - schema = StructType([StructField(FEATURES_COL, VectorUDT(), False)]) - df = sqlContext.createDataFrame(row_rdd, schema) + # Evaluate clustering by computing Within Set Sum of Squared Errors. 
+    wssse = model.computeCost(dataset)
+    print("Within Set Sum of Squared Errors = " + str(wssse))
 
-    kmeans = KMeans().setK(2).setSeed(1).setFeaturesCol(FEATURES_COL)
-    model = kmeans.fit(df)
+    # Shows the result.
     centers = model.clusterCenters()
-
     print("Cluster Centers: ")
     for center in centers:
         print(center)
+    # $example off$
 
-    sc.stop()
+    spark.stop()
diff --git a/examples/src/main/python/ml/lda_example.py b/examples/src/main/python/ml/lda_example.py
new file mode 100644
index 0000000000000..5ce810fccc6fb
--- /dev/null
+++ b/examples/src/main/python/ml/lda_example.py
@@ -0,0 +1,64 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+from __future__ import print_function
+
+# $example on$
+from pyspark.ml.clustering import LDA
+# $example off$
+from pyspark.sql import SparkSession
+
+
+"""
+An example demonstrating LDA.
+Run with:
+  bin/spark-submit examples/src/main/python/ml/lda_example.py
+"""
+
+
+if __name__ == "__main__":
+    # Creates a SparkSession
+    spark = SparkSession \
+        .builder \
+        .appName("LDAExample") \
+        .getOrCreate()
+
+    # $example on$
+    # Loads data.
+    dataset = spark.read.format("libsvm").load("data/mllib/sample_lda_libsvm_data.txt")
+
+    # Trains an LDA model.
+    lda = LDA(k=10, maxIter=10)
+    model = lda.fit(dataset)
+
+    ll = model.logLikelihood(dataset)
+    lp = model.logPerplexity(dataset)
+    print("The lower bound on the log likelihood of the entire corpus: " + str(ll))
+    print("The upper bound on perplexity: " + str(lp))
+
+    # Describe topics.
+ topics = model.describeTopics(3) + print("The topics described by their top-weighted terms:") + topics.show(truncate=False) + + # Shows the result + transformed = model.transform(dataset) + transformed.show(truncate=False) + # $example off$ + + spark.stop() diff --git a/examples/src/main/python/ml/linear_regression_with_elastic_net.py b/examples/src/main/python/ml/linear_regression_with_elastic_net.py index a4cd40cf26726..620ab5b87e594 100644 --- a/examples/src/main/python/ml/linear_regression_with_elastic_net.py +++ b/examples/src/main/python/ml/linear_regression_with_elastic_net.py @@ -17,19 +17,20 @@ from __future__ import print_function -from pyspark import SparkContext -from pyspark.sql import SQLContext # $example on$ from pyspark.ml.regression import LinearRegression # $example off$ +from pyspark.sql import SparkSession if __name__ == "__main__": - sc = SparkContext(appName="LinearRegressionWithElasticNet") - sqlContext = SQLContext(sc) + spark = SparkSession\ + .builder\ + .appName("LinearRegressionWithElasticNet")\ + .getOrCreate() # $example on$ # Load training data - training = sqlContext.read.format("libsvm")\ + training = spark.read.format("libsvm")\ .load("data/mllib/sample_linear_regression_data.txt") lr = LinearRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8) @@ -42,4 +43,4 @@ print("Intercept: " + str(lrModel.intercept)) # $example off$ - sc.stop() + spark.stop() diff --git a/examples/src/main/python/ml/logistic_regression_with_elastic_net.py b/examples/src/main/python/ml/logistic_regression_with_elastic_net.py index b0b1d27e13bb0..33d0689f75cd5 100644 --- a/examples/src/main/python/ml/logistic_regression_with_elastic_net.py +++ b/examples/src/main/python/ml/logistic_regression_with_elastic_net.py @@ -17,19 +17,20 @@ from __future__ import print_function -from pyspark import SparkContext -from pyspark.sql import SQLContext # $example on$ from pyspark.ml.classification import LogisticRegression # $example off$ +from pyspark.sql import SparkSession if __name__ == "__main__": - sc = SparkContext(appName="LogisticRegressionWithElasticNet") - sqlContext = SQLContext(sc) + spark = SparkSession\ + .builder\ + .appName("LogisticRegressionWithElasticNet")\ + .getOrCreate() # $example on$ # Load training data - training = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + training = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") lr = LogisticRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8) @@ -41,4 +42,4 @@ print("Intercept: " + str(lrModel.intercept)) # $example off$ - sc.stop() + spark.stop() diff --git a/examples/src/main/python/ml/max_abs_scaler_example.py b/examples/src/main/python/ml/max_abs_scaler_example.py index d9b69eef1cd84..ab91198b083d1 100644 --- a/examples/src/main/python/ml/max_abs_scaler_example.py +++ b/examples/src/main/python/ml/max_abs_scaler_example.py @@ -17,18 +17,19 @@ from __future__ import print_function -from pyspark import SparkContext -from pyspark.sql import SQLContext # $example on$ from pyspark.ml.feature import MaxAbsScaler # $example off$ +from pyspark.sql import SparkSession if __name__ == "__main__": - sc = SparkContext(appName="MaxAbsScalerExample") - sqlContext = SQLContext(sc) + spark = SparkSession\ + .builder\ + .appName("MaxAbsScalerExample")\ + .getOrCreate() # $example on$ - dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + dataFrame = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") scaler = 
MaxAbsScaler(inputCol="features", outputCol="scaledFeatures") @@ -40,4 +41,4 @@ scaledData.show() # $example off$ - sc.stop() + spark.stop() diff --git a/examples/src/main/python/ml/min_max_scaler_example.py b/examples/src/main/python/ml/min_max_scaler_example.py index 2f8e4ade468b9..e3e7bc205b1ec 100644 --- a/examples/src/main/python/ml/min_max_scaler_example.py +++ b/examples/src/main/python/ml/min_max_scaler_example.py @@ -17,18 +17,19 @@ from __future__ import print_function -from pyspark import SparkContext -from pyspark.sql import SQLContext # $example on$ from pyspark.ml.feature import MinMaxScaler # $example off$ +from pyspark.sql import SparkSession if __name__ == "__main__": - sc = SparkContext(appName="MinMaxScalerExample") - sqlContext = SQLContext(sc) + spark = SparkSession\ + .builder\ + .appName("MinMaxScalerExample")\ + .getOrCreate() # $example on$ - dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + dataFrame = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") scaler = MinMaxScaler(inputCol="features", outputCol="scaledFeatures") @@ -40,4 +41,4 @@ scaledData.show() # $example off$ - sc.stop() + spark.stop() diff --git a/examples/src/main/python/ml/multilayer_perceptron_classification.py b/examples/src/main/python/ml/multilayer_perceptron_classification.py index f84588f547fff..aa33bef5a3ddd 100644 --- a/examples/src/main/python/ml/multilayer_perceptron_classification.py +++ b/examples/src/main/python/ml/multilayer_perceptron_classification.py @@ -17,21 +17,19 @@ from __future__ import print_function -from pyspark import SparkContext -from pyspark.sql import SQLContext # $example on$ from pyspark.ml.classification import MultilayerPerceptronClassifier from pyspark.ml.evaluation import MulticlassClassificationEvaluator # $example off$ +from pyspark.sql import SparkSession if __name__ == "__main__": - - sc = SparkContext(appName="multilayer_perceptron_classification_example") - sqlContext = SQLContext(sc) + spark = SparkSession\ + .builder.appName("multilayer_perceptron_classification_example").getOrCreate() # $example on$ # Load training data - data = sqlContext.read.format("libsvm")\ + data = spark.read.format("libsvm")\ .load("data/mllib/sample_multiclass_classification_data.txt") # Split the data into train and test splits = data.randomSplit([0.6, 0.4], 1234) @@ -45,11 +43,11 @@ trainer = MultilayerPerceptronClassifier(maxIter=100, layers=layers, blockSize=128, seed=1234) # train the model model = trainer.fit(train) - # compute precision on the test set + # compute accuracy on the test set result = model.transform(test) predictionAndLabels = result.select("prediction", "label") - evaluator = MulticlassClassificationEvaluator(metricName="precision") - print("Precision:" + str(evaluator.evaluate(predictionAndLabels))) + evaluator = MulticlassClassificationEvaluator(metricName="accuracy") + print("Accuracy: " + str(evaluator.evaluate(predictionAndLabels))) # $example off$ - sc.stop() + spark.stop() diff --git a/examples/src/main/python/ml/n_gram_example.py b/examples/src/main/python/ml/n_gram_example.py index f2d85f53e7219..9ac07f2c8ee20 100644 --- a/examples/src/main/python/ml/n_gram_example.py +++ b/examples/src/main/python/ml/n_gram_example.py @@ -17,18 +17,19 @@ from __future__ import print_function -from pyspark import SparkContext -from pyspark.sql import SQLContext # $example on$ from pyspark.ml.feature import NGram # $example off$ +from pyspark.sql import SparkSession if __name__ == "__main__": - sc = 
SparkContext(appName="NGramExample") - sqlContext = SQLContext(sc) + spark = SparkSession\ + .builder\ + .appName("NGramExample")\ + .getOrCreate() # $example on$ - wordDataFrame = sqlContext.createDataFrame([ + wordDataFrame = spark.createDataFrame([ (0, ["Hi", "I", "heard", "about", "Spark"]), (1, ["I", "wish", "Java", "could", "use", "case", "classes"]), (2, ["Logistic", "regression", "models", "are", "neat"]) @@ -39,4 +40,4 @@ print(ngrams_label) # $example off$ - sc.stop() + spark.stop() diff --git a/examples/src/main/python/ml/naive_bayes_example.py b/examples/src/main/python/ml/naive_bayes_example.py index db8fbea9bf9b1..8bc32222fe32b 100644 --- a/examples/src/main/python/ml/naive_bayes_example.py +++ b/examples/src/main/python/ml/naive_bayes_example.py @@ -17,21 +17,21 @@ from __future__ import print_function -from pyspark import SparkContext -from pyspark.sql import SQLContext # $example on$ from pyspark.ml.classification import NaiveBayes from pyspark.ml.evaluation import MulticlassClassificationEvaluator # $example off$ +from pyspark.sql import SparkSession if __name__ == "__main__": - - sc = SparkContext(appName="naive_bayes_example") - sqlContext = SQLContext(sc) + spark = SparkSession\ + .builder\ + .appName("naive_bayes_example")\ + .getOrCreate() # $example on$ # Load training data - data = sqlContext.read.format("libsvm") \ + data = spark.read.format("libsvm") \ .load("data/mllib/sample_libsvm_data.txt") # Split the data into train and test splits = data.randomSplit([0.6, 0.4], 1234) @@ -43,11 +43,11 @@ # train the model model = nb.fit(train) - # compute precision on the test set + # compute accuracy on the test set result = model.transform(test) predictionAndLabels = result.select("prediction", "label") - evaluator = MulticlassClassificationEvaluator(metricName="precision") - print("Precision:" + str(evaluator.evaluate(predictionAndLabels))) + evaluator = MulticlassClassificationEvaluator(metricName="accuracy") + print("Accuracy: " + str(evaluator.evaluate(predictionAndLabels))) # $example off$ - sc.stop() + spark.stop() diff --git a/examples/src/main/python/ml/normalizer_example.py b/examples/src/main/python/ml/normalizer_example.py index d490221474c24..19012f51f4023 100644 --- a/examples/src/main/python/ml/normalizer_example.py +++ b/examples/src/main/python/ml/normalizer_example.py @@ -17,18 +17,19 @@ from __future__ import print_function -from pyspark import SparkContext -from pyspark.sql import SQLContext # $example on$ from pyspark.ml.feature import Normalizer # $example off$ +from pyspark.sql import SparkSession if __name__ == "__main__": - sc = SparkContext(appName="NormalizerExample") - sqlContext = SQLContext(sc) + spark = SparkSession\ + .builder\ + .appName("NormalizerExample")\ + .getOrCreate() # $example on$ - dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + dataFrame = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") # Normalize each Vector using $L^1$ norm. normalizer = Normalizer(inputCol="features", outputCol="normFeatures", p=1.0) @@ -40,4 +41,4 @@ lInfNormData.show() # $example off$ - sc.stop() + spark.stop() diff --git a/examples/src/main/python/ml/one_vs_rest_example.py b/examples/src/main/python/ml/one_vs_rest_example.py new file mode 100644 index 0000000000000..b82087bebae8a --- /dev/null +++ b/examples/src/main/python/ml/one_vs_rest_example.py @@ -0,0 +1,68 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +# $example on$ +from pyspark.ml.classification import LogisticRegression, OneVsRest +from pyspark.ml.evaluation import MulticlassClassificationEvaluator +# $example off$ +from pyspark.sql import SparkSession + +""" +An example of Multiclass to Binary Reduction with One Vs Rest, +using Logistic Regression as the base classifier. +Run with: + bin/spark-submit examples/src/main/python/ml/one_vs_rest_example.py +""" + + +if __name__ == "__main__": + spark = SparkSession \ + .builder \ + .appName("PythonOneVsRestExample") \ + .getOrCreate() + + # $example on$ + # load data file. + inputData = spark.read.format("libsvm") \ + .load("data/mllib/sample_multiclass_classification_data.txt") + + # generate the train/test split. + (train, test) = inputData.randomSplit([0.8, 0.2]) + + # instantiate the base classifier. + lr = LogisticRegression(maxIter=10, tol=1E-6, fitIntercept=True) + + # instantiate the One Vs Rest Classifier. + ovr = OneVsRest(classifier=lr) + + # train the multiclass model. + ovrModel = ovr.fit(train) + + # score the model on test data. + predictions = ovrModel.transform(test) + + # obtain evaluator. + evaluator = MulticlassClassificationEvaluator(metricName="accuracy") + + # compute the classification error on test data. 
+ accuracy = evaluator.evaluate(predictions) + print("Test Error : " + str(1 - accuracy)) + # $example off$ + + spark.stop() diff --git a/examples/src/main/python/ml/onehot_encoder_example.py b/examples/src/main/python/ml/onehot_encoder_example.py index 0f94c26638d35..b9fceef68e703 100644 --- a/examples/src/main/python/ml/onehot_encoder_example.py +++ b/examples/src/main/python/ml/onehot_encoder_example.py @@ -17,18 +17,19 @@ from __future__ import print_function -from pyspark import SparkContext -from pyspark.sql import SQLContext # $example on$ from pyspark.ml.feature import OneHotEncoder, StringIndexer # $example off$ +from pyspark.sql import SparkSession if __name__ == "__main__": - sc = SparkContext(appName="OneHotEncoderExample") - sqlContext = SQLContext(sc) + spark = SparkSession\ + .builder\ + .appName("OneHotEncoderExample")\ + .getOrCreate() # $example on$ - df = sqlContext.createDataFrame([ + df = spark.createDataFrame([ (0, "a"), (1, "b"), (2, "c"), @@ -45,4 +46,4 @@ encoded.select("id", "categoryVec").show() # $example off$ - sc.stop() + spark.stop() diff --git a/examples/src/main/python/ml/pca_example.py b/examples/src/main/python/ml/pca_example.py index a17181f1b8a51..414629ff88bf9 100644 --- a/examples/src/main/python/ml/pca_example.py +++ b/examples/src/main/python/ml/pca_example.py @@ -17,26 +17,27 @@ from __future__ import print_function -from pyspark import SparkContext -from pyspark.sql import SQLContext # $example on$ from pyspark.ml.feature import PCA -from pyspark.mllib.linalg import Vectors +from pyspark.ml.linalg import Vectors # $example off$ +from pyspark.sql import SparkSession if __name__ == "__main__": - sc = SparkContext(appName="PCAExample") - sqlContext = SQLContext(sc) + spark = SparkSession\ + .builder\ + .appName("PCAExample")\ + .getOrCreate() # $example on$ data = [(Vectors.sparse(5, [(1, 1.0), (3, 7.0)]),), (Vectors.dense([2.0, 0.0, 3.0, 4.0, 5.0]),), (Vectors.dense([4.0, 0.0, 0.0, 6.0, 7.0]),)] - df = sqlContext.createDataFrame(data, ["features"]) + df = spark.createDataFrame(data, ["features"]) pca = PCA(k=3, inputCol="features", outputCol="pcaFeatures") model = pca.fit(df) result = model.transform(df).select("pcaFeatures") result.show(truncate=False) # $example off$ - sc.stop() + spark.stop() diff --git a/examples/src/main/python/ml/pipeline_example.py b/examples/src/main/python/ml/pipeline_example.py index 3288568f0c287..bd10cfd7a252b 100644 --- a/examples/src/main/python/ml/pipeline_example.py +++ b/examples/src/main/python/ml/pipeline_example.py @@ -18,21 +18,23 @@ """ Pipeline Example. """ -from pyspark import SparkContext, SQLContext + # $example on$ from pyspark.ml import Pipeline from pyspark.ml.classification import LogisticRegression from pyspark.ml.feature import HashingTF, Tokenizer # $example off$ +from pyspark.sql import SparkSession if __name__ == "__main__": - - sc = SparkContext(appName="PipelineExample") - sqlContext = SQLContext(sc) + spark = SparkSession\ + .builder\ + .appName("PipelineExample")\ + .getOrCreate() # $example on$ # Prepare training documents from a list of (id, text, label) tuples. - training = sqlContext.createDataFrame([ + training = spark.createDataFrame([ (0L, "a b c d e spark", 1.0), (1L, "b d", 0.0), (2L, "spark f g h", 1.0), @@ -48,7 +50,7 @@ model = pipeline.fit(training) # Prepare test documents, which are unlabeled (id, text) tuples. 
- test = sqlContext.createDataFrame([ + test = spark.createDataFrame([ (4L, "spark i j k"), (5L, "l m n"), (6L, "mapreduce spark"), @@ -61,4 +63,4 @@ print(row) # $example off$ - sc.stop() + spark.stop() diff --git a/examples/src/main/python/ml/polynomial_expansion_example.py b/examples/src/main/python/ml/polynomial_expansion_example.py index 89f5cbe8f2f41..b46c1ba2f4391 100644 --- a/examples/src/main/python/ml/polynomial_expansion_example.py +++ b/examples/src/main/python/ml/polynomial_expansion_example.py @@ -17,27 +17,28 @@ from __future__ import print_function -from pyspark import SparkContext -from pyspark.sql import SQLContext # $example on$ from pyspark.ml.feature import PolynomialExpansion -from pyspark.mllib.linalg import Vectors +from pyspark.ml.linalg import Vectors # $example off$ +from pyspark.sql import SparkSession if __name__ == "__main__": - sc = SparkContext(appName="PolynomialExpansionExample") - sqlContext = SQLContext(sc) + spark = SparkSession\ + .builder\ + .appName("PolynomialExpansionExample")\ + .getOrCreate() # $example on$ - df = sqlContext\ + df = spark\ .createDataFrame([(Vectors.dense([-2.0, 2.3]),), (Vectors.dense([0.0, 0.0]),), (Vectors.dense([0.6, -1.1]),)], ["features"]) - px = PolynomialExpansion(degree=2, inputCol="features", outputCol="polyFeatures") + px = PolynomialExpansion(degree=3, inputCol="features", outputCol="polyFeatures") polyDF = px.transform(df) for expanded in polyDF.select("polyFeatures").take(3): print(expanded) # $example off$ - sc.stop() + spark.stop() diff --git a/examples/src/main/python/ml/quantile_discretizer_example.py b/examples/src/main/python/ml/quantile_discretizer_example.py new file mode 100644 index 0000000000000..6f422f840ad28 --- /dev/null +++ b/examples/src/main/python/ml/quantile_discretizer_example.py @@ -0,0 +1,44 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +# $example on$ +from pyspark.ml.feature import QuantileDiscretizer +# $example off$ +from pyspark.sql import SparkSession + + +if __name__ == "__main__": + spark = SparkSession.builder.appName("QuantileDiscretizerExample").getOrCreate() + + # $example on$ + data = [(0, 18.0,), (1, 19.0,), (2, 8.0,), (3, 5.0,), (4, 2.2,)] + df = spark.createDataFrame(data, ["id", "hour"]) + # $example off$ + # Output of QuantileDiscretizer for such small datasets can depend on the number of + # partitions. Here we force a single partition to ensure consistent results. 
+ # Note this is not necessary for normal use cases + df = df.repartition(1) + # $example on$ + discretizer = QuantileDiscretizer(numBuckets=3, inputCol="hour", outputCol="result") + + result = discretizer.fit(df).transform(df) + result.show() + # $example off$ + + spark.stop() diff --git a/examples/src/main/python/ml/random_forest_classifier_example.py b/examples/src/main/python/ml/random_forest_classifier_example.py index c3570438c51d9..eb9ded9af555e 100644 --- a/examples/src/main/python/ml/random_forest_classifier_example.py +++ b/examples/src/main/python/ml/random_forest_classifier_example.py @@ -20,21 +20,23 @@ """ from __future__ import print_function -from pyspark import SparkContext, SQLContext # $example on$ from pyspark.ml import Pipeline from pyspark.ml.classification import RandomForestClassifier from pyspark.ml.feature import StringIndexer, VectorIndexer from pyspark.ml.evaluation import MulticlassClassificationEvaluator # $example off$ +from pyspark.sql import SparkSession if __name__ == "__main__": - sc = SparkContext(appName="random_forest_classifier_example") - sqlContext = SQLContext(sc) + spark = SparkSession\ + .builder\ + .appName("random_forest_classifier_example")\ + .getOrCreate() # $example on$ # Load and parse the data file, converting it to a DataFrame. - data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + data = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") # Index labels, adding metadata to the label column. # Fit on whole dataset to include all labels in index. @@ -48,7 +50,7 @@ (trainingData, testData) = data.randomSplit([0.7, 0.3]) # Train a RandomForest model. - rf = RandomForestClassifier(labelCol="indexedLabel", featuresCol="indexedFeatures") + rf = RandomForestClassifier(labelCol="indexedLabel", featuresCol="indexedFeatures", numTrees=10) # Chain indexers and forest in a Pipeline pipeline = Pipeline(stages=[labelIndexer, featureIndexer, rf]) @@ -64,7 +66,7 @@ # Select (prediction, true label) and compute test error evaluator = MulticlassClassificationEvaluator( - labelCol="indexedLabel", predictionCol="prediction", metricName="precision") + labelCol="indexedLabel", predictionCol="prediction", metricName="accuracy") accuracy = evaluator.evaluate(predictions) print("Test Error = %g" % (1.0 - accuracy)) @@ -72,4 +74,4 @@ print(rfModel) # summary only # $example off$ - sc.stop() + spark.stop() diff --git a/examples/src/main/python/ml/random_forest_regressor_example.py b/examples/src/main/python/ml/random_forest_regressor_example.py index b77014f379237..3a793737dba89 100644 --- a/examples/src/main/python/ml/random_forest_regressor_example.py +++ b/examples/src/main/python/ml/random_forest_regressor_example.py @@ -20,21 +20,23 @@ """ from __future__ import print_function -from pyspark import SparkContext, SQLContext # $example on$ from pyspark.ml import Pipeline from pyspark.ml.regression import RandomForestRegressor from pyspark.ml.feature import VectorIndexer from pyspark.ml.evaluation import RegressionEvaluator # $example off$ +from pyspark.sql import SparkSession if __name__ == "__main__": - sc = SparkContext(appName="random_forest_regressor_example") - sqlContext = SQLContext(sc) + spark = SparkSession\ + .builder\ + .appName("random_forest_regressor_example")\ + .getOrCreate() # $example on$ # Load and parse the data file, converting it to a DataFrame. 
- data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + data = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") # Automatically identify categorical features, and index them. # Set maxCategories so features with > 4 distinct values are treated as continuous. @@ -69,4 +71,4 @@ print(rfModel) # summary only # $example off$ - sc.stop() + spark.stop() diff --git a/examples/src/main/python/ml/rformula_example.py b/examples/src/main/python/ml/rformula_example.py index b544a14700762..d5df3ce4f5915 100644 --- a/examples/src/main/python/ml/rformula_example.py +++ b/examples/src/main/python/ml/rformula_example.py @@ -17,18 +17,19 @@ from __future__ import print_function -from pyspark import SparkContext -from pyspark.sql import SQLContext # $example on$ from pyspark.ml.feature import RFormula # $example off$ +from pyspark.sql import SparkSession if __name__ == "__main__": - sc = SparkContext(appName="RFormulaExample") - sqlContext = SQLContext(sc) + spark = SparkSession\ + .builder\ + .appName("RFormulaExample")\ + .getOrCreate() # $example on$ - dataset = sqlContext.createDataFrame( + dataset = spark.createDataFrame( [(7, "US", 18, 1.0), (8, "CA", 12, 0.0), (9, "NZ", 15, 0.0)], @@ -41,4 +42,4 @@ output.select("features", "label").show() # $example off$ - sc.stop() + spark.stop() diff --git a/examples/src/main/python/ml/simple_params_example.py b/examples/src/main/python/ml/simple_params_example.py index 2d6d115d54d02..2f1eaa6f947f0 100644 --- a/examples/src/main/python/ml/simple_params_example.py +++ b/examples/src/main/python/ml/simple_params_example.py @@ -20,11 +20,9 @@ import pprint import sys -from pyspark import SparkContext from pyspark.ml.classification import LogisticRegression -from pyspark.mllib.linalg import DenseVector -from pyspark.mllib.regression import LabeledPoint -from pyspark.sql import SQLContext +from pyspark.ml.linalg import DenseVector +from pyspark.sql import Row, SparkSession """ A simple example demonstrating ways to specify parameters for Estimators and Transformers. @@ -33,21 +31,20 @@ """ if __name__ == "__main__": - if len(sys.argv) > 1: - print("Usage: simple_params_example", file=sys.stderr) - exit(1) - sc = SparkContext(appName="PythonSimpleParamsExample") - sqlContext = SQLContext(sc) + spark = SparkSession \ + .builder \ + .appName("SimpleParamsExample") \ + .getOrCreate() # prepare training data. # We create an RDD of LabeledPoints and convert them into a DataFrame. # A LabeledPoint is an Object with two fields named label and features # and Spark SQL identifies these fields and creates the schema appropriately. - training = sc.parallelize([ - LabeledPoint(1.0, DenseVector([0.0, 1.1, 0.1])), - LabeledPoint(0.0, DenseVector([2.0, 1.0, -1.0])), - LabeledPoint(0.0, DenseVector([2.0, 1.3, 1.0])), - LabeledPoint(1.0, DenseVector([0.0, 1.2, -0.5]))]).toDF() + training = spark.createDataFrame([ + Row(label=1.0, features=DenseVector([0.0, 1.1, 0.1])), + Row(label=0.0, features=DenseVector([2.0, 1.0, -1.0])), + Row(label=0.0, features=DenseVector([2.0, 1.3, 1.0])), + Row(label=1.0, features=DenseVector([0.0, 1.2, -0.5]))]) # Create a LogisticRegression instance with maxIter = 10. # This instance is an Estimator. @@ -70,7 +67,7 @@ # We may alternatively specify parameters using a parameter map. # paramMap overrides all lr parameters set earlier. 
- paramMap = {lr.maxIter: 20, lr.thresholds: [0.45, 0.55], lr.probabilityCol: "myProbability"} + paramMap = {lr.maxIter: 20, lr.thresholds: [0.5, 0.5], lr.probabilityCol: "myProbability"} # Now learn a new model using the new parameters. model2 = lr.fit(training, paramMap) @@ -78,10 +75,10 @@ pprint.pprint(model2.extractParamMap()) # prepare test data. - test = sc.parallelize([ - LabeledPoint(1.0, DenseVector([-1.0, 1.5, 1.3])), - LabeledPoint(0.0, DenseVector([3.0, 2.0, -0.1])), - LabeledPoint(0.0, DenseVector([0.0, 2.2, -1.5]))]).toDF() + test = spark.createDataFrame([ + Row(label=1.0, features=DenseVector([-1.0, 1.5, 1.3])), + Row(label=0.0, features=DenseVector([3.0, 2.0, -0.1])), + Row(label=0.0, features=DenseVector([0.0, 2.2, -1.5]))]) # Make predictions on test data using the Transformer.transform() method. # LogisticRegressionModel.transform will only use the 'features' column. @@ -95,4 +92,4 @@ print("features=%s,label=%s -> prob=%s, prediction=%s" % (row.features, row.label, row.myProbability, row.prediction)) - sc.stop() + spark.stop() diff --git a/examples/src/main/python/ml/simple_text_classification_pipeline.py b/examples/src/main/python/ml/simple_text_classification_pipeline.py index b4f06bf888746..b528b59be9621 100644 --- a/examples/src/main/python/ml/simple_text_classification_pipeline.py +++ b/examples/src/main/python/ml/simple_text_classification_pipeline.py @@ -17,11 +17,10 @@ from __future__ import print_function -from pyspark import SparkContext from pyspark.ml import Pipeline from pyspark.ml.classification import LogisticRegression from pyspark.ml.feature import HashingTF, Tokenizer -from pyspark.sql import Row, SQLContext +from pyspark.sql import Row, SparkSession """ @@ -34,20 +33,22 @@ if __name__ == "__main__": - sc = SparkContext(appName="SimpleTextClassificationPipeline") - sqlContext = SQLContext(sc) + spark = SparkSession\ + .builder\ + .appName("SimpleTextClassificationPipeline")\ + .getOrCreate() # Prepare training documents, which are labeled. - LabeledDocument = Row("id", "text", "label") - training = sc.parallelize([(0, "a b c d e spark", 1.0), - (1, "b d", 0.0), - (2, "spark f g h", 1.0), - (3, "hadoop mapreduce", 0.0)]) \ - .map(lambda x: LabeledDocument(*x)).toDF() + training = spark.createDataFrame([ + (0, "a b c d e spark", 1.0), + (1, "b d", 0.0), + (2, "spark f g h", 1.0), + (3, "hadoop mapreduce", 0.0) + ], ["id", "text", "label"]) # Configure an ML pipeline, which consists of tree stages: tokenizer, hashingTF, and lr. tokenizer = Tokenizer(inputCol="text", outputCol="words") - hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features") + hashingTF = HashingTF(numFeatures=1000, inputCol=tokenizer.getOutputCol(), outputCol="features") lr = LogisticRegression(maxIter=10, regParam=0.001) pipeline = Pipeline(stages=[tokenizer, hashingTF, lr]) @@ -55,12 +56,12 @@ model = pipeline.fit(training) # Prepare test documents, which are unlabeled. - Document = Row("id", "text") - test = sc.parallelize([(4, "spark i j k"), - (5, "l m n"), - (6, "spark hadoop spark"), - (7, "apache hadoop")]) \ - .map(lambda x: Document(*x)).toDF() + test = spark.createDataFrame([ + (4, "spark i j k"), + (5, "l m n"), + (6, "spark hadoop spark"), + (7, "apache hadoop") + ], ["id", "text"]) # Make predictions on test documents and print columns of interest. 
prediction = model.transform(test) @@ -68,4 +69,4 @@ for row in selected.collect(): print(row) - sc.stop() + spark.stop() diff --git a/examples/src/main/python/ml/sql_transformer.py b/examples/src/main/python/ml/sql_transformer.py index 9575d728d8159..0bf8f35720c95 100644 --- a/examples/src/main/python/ml/sql_transformer.py +++ b/examples/src/main/python/ml/sql_transformer.py @@ -17,18 +17,19 @@ from __future__ import print_function -from pyspark import SparkContext # $example on$ from pyspark.ml.feature import SQLTransformer # $example off$ -from pyspark.sql import SQLContext +from pyspark.sql import SparkSession if __name__ == "__main__": - sc = SparkContext(appName="SQLTransformerExample") - sqlContext = SQLContext(sc) + spark = SparkSession\ + .builder\ + .appName("SQLTransformerExample")\ + .getOrCreate() # $example on$ - df = sqlContext.createDataFrame([ + df = spark.createDataFrame([ (0, 1.0, 3.0), (2, 2.0, 5.0) ], ["id", "v1", "v2"]) @@ -37,4 +38,4 @@ sqlTrans.transform(df).show() # $example off$ - sc.stop() + spark.stop() diff --git a/examples/src/main/python/ml/standard_scaler_example.py b/examples/src/main/python/ml/standard_scaler_example.py index ae7aa85005bcd..c0027480e69b3 100644 --- a/examples/src/main/python/ml/standard_scaler_example.py +++ b/examples/src/main/python/ml/standard_scaler_example.py @@ -17,18 +17,19 @@ from __future__ import print_function -from pyspark import SparkContext -from pyspark.sql import SQLContext # $example on$ from pyspark.ml.feature import StandardScaler # $example off$ +from pyspark.sql import SparkSession if __name__ == "__main__": - sc = SparkContext(appName="StandardScalerExample") - sqlContext = SQLContext(sc) + spark = SparkSession\ + .builder\ + .appName("StandardScalerExample")\ + .getOrCreate() # $example on$ - dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + dataFrame = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") scaler = StandardScaler(inputCol="features", outputCol="scaledFeatures", withStd=True, withMean=False) @@ -40,4 +41,4 @@ scaledData.show() # $example off$ - sc.stop() + spark.stop() diff --git a/examples/src/main/python/ml/stopwords_remover_example.py b/examples/src/main/python/ml/stopwords_remover_example.py index 01f94af8ca752..395fdeffc5379 100644 --- a/examples/src/main/python/ml/stopwords_remover_example.py +++ b/examples/src/main/python/ml/stopwords_remover_example.py @@ -17,18 +17,19 @@ from __future__ import print_function -from pyspark import SparkContext -from pyspark.sql import SQLContext # $example on$ from pyspark.ml.feature import StopWordsRemover # $example off$ +from pyspark.sql import SparkSession if __name__ == "__main__": - sc = SparkContext(appName="StopWordsRemoverExample") - sqlContext = SQLContext(sc) + spark = SparkSession\ + .builder\ + .appName("StopWordsRemoverExample")\ + .getOrCreate() # $example on$ - sentenceData = sqlContext.createDataFrame([ + sentenceData = spark.createDataFrame([ (0, ["I", "saw", "the", "red", "baloon"]), (1, ["Mary", "had", "a", "little", "lamb"]) ], ["label", "raw"]) @@ -37,4 +38,4 @@ remover.transform(sentenceData).show(truncate=False) # $example off$ - sc.stop() + spark.stop() diff --git a/examples/src/main/python/ml/string_indexer_example.py b/examples/src/main/python/ml/string_indexer_example.py index 58a8cb5d56b73..a328e040f5636 100644 --- a/examples/src/main/python/ml/string_indexer_example.py +++ b/examples/src/main/python/ml/string_indexer_example.py @@ -17,18 +17,19 @@ from __future__ import 
print_function -from pyspark import SparkContext -from pyspark.sql import SQLContext # $example on$ from pyspark.ml.feature import StringIndexer # $example off$ +from pyspark.sql import SparkSession if __name__ == "__main__": - sc = SparkContext(appName="StringIndexerExample") - sqlContext = SQLContext(sc) + spark = SparkSession\ + .builder\ + .appName("StringIndexerExample")\ + .getOrCreate() # $example on$ - df = sqlContext.createDataFrame( + df = spark.createDataFrame( [(0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")], ["id", "category"]) indexer = StringIndexer(inputCol="category", outputCol="categoryIndex") @@ -36,4 +37,4 @@ indexed.show() # $example off$ - sc.stop() + spark.stop() diff --git a/examples/src/main/python/ml/tf_idf_example.py b/examples/src/main/python/ml/tf_idf_example.py index 141324d458530..fb4ad992fb809 100644 --- a/examples/src/main/python/ml/tf_idf_example.py +++ b/examples/src/main/python/ml/tf_idf_example.py @@ -17,18 +17,19 @@ from __future__ import print_function -from pyspark import SparkContext # $example on$ from pyspark.ml.feature import HashingTF, IDF, Tokenizer # $example off$ -from pyspark.sql import SQLContext +from pyspark.sql import SparkSession if __name__ == "__main__": - sc = SparkContext(appName="TfIdfExample") - sqlContext = SQLContext(sc) + spark = SparkSession\ + .builder\ + .appName("TfIdfExample")\ + .getOrCreate() # $example on$ - sentenceData = sqlContext.createDataFrame([ + sentenceData = spark.createDataFrame([ (0, "Hi I heard about Spark"), (0, "I wish Java could use case classes"), (1, "Logistic regression models are neat") @@ -46,4 +47,4 @@ print(features_label) # $example off$ - sc.stop() + spark.stop() diff --git a/examples/src/main/python/ml/tokenizer_example.py b/examples/src/main/python/ml/tokenizer_example.py index ce9b225be5357..e61ec920d2281 100644 --- a/examples/src/main/python/ml/tokenizer_example.py +++ b/examples/src/main/python/ml/tokenizer_example.py @@ -17,18 +17,19 @@ from __future__ import print_function -from pyspark import SparkContext -from pyspark.sql import SQLContext # $example on$ from pyspark.ml.feature import Tokenizer, RegexTokenizer # $example off$ +from pyspark.sql import SparkSession if __name__ == "__main__": - sc = SparkContext(appName="TokenizerExample") - sqlContext = SQLContext(sc) + spark = SparkSession\ + .builder\ + .appName("TokenizerExample")\ + .getOrCreate() # $example on$ - sentenceDataFrame = sqlContext.createDataFrame([ + sentenceDataFrame = spark.createDataFrame([ (0, "Hi I heard about Spark"), (1, "I wish Java could use case classes"), (2, "Logistic,regression,models,are,neat") @@ -41,4 +42,4 @@ # alternatively, pattern="\\w+", gaps(False) # $example off$ - sc.stop() + spark.stop() diff --git a/examples/src/main/python/ml/train_validation_split.py b/examples/src/main/python/ml/train_validation_split.py index 161a200c61b6d..5f5c52aca8c42 100644 --- a/examples/src/main/python/ml/train_validation_split.py +++ b/examples/src/main/python/ml/train_validation_split.py @@ -15,13 +15,12 @@ # limitations under the License. 
# -from pyspark import SparkContext # $example on$ from pyspark.ml.evaluation import RegressionEvaluator from pyspark.ml.regression import LinearRegression from pyspark.ml.tuning import ParamGridBuilder, TrainValidationSplit -from pyspark.sql import SQLContext # $example off$ +from pyspark.sql import SparkSession """ This example demonstrates applying TrainValidationSplit to split data @@ -32,11 +31,13 @@ """ if __name__ == "__main__": - sc = SparkContext(appName="TrainValidationSplit") - sqlContext = SQLContext(sc) + spark = SparkSession\ + .builder\ + .appName("TrainValidationSplit")\ + .getOrCreate() # $example on$ # Prepare training and test data. - data = sqlContext.read.format("libsvm")\ + data = spark.read.format("libsvm")\ .load("data/mllib/sample_linear_regression_data.txt") train, test = data.randomSplit([0.7, 0.3]) lr = LinearRegression(maxIter=10, regParam=0.1) @@ -65,4 +66,4 @@ for row in prediction.take(5): print(row) # $example off$ - sc.stop() + spark.stop() diff --git a/examples/src/main/python/ml/vector_assembler_example.py b/examples/src/main/python/ml/vector_assembler_example.py index 04f64839f188d..bbfc316ff2d33 100644 --- a/examples/src/main/python/ml/vector_assembler_example.py +++ b/examples/src/main/python/ml/vector_assembler_example.py @@ -17,19 +17,20 @@ from __future__ import print_function -from pyspark import SparkContext -from pyspark.sql import SQLContext # $example on$ -from pyspark.mllib.linalg import Vectors +from pyspark.ml.linalg import Vectors from pyspark.ml.feature import VectorAssembler # $example off$ +from pyspark.sql import SparkSession if __name__ == "__main__": - sc = SparkContext(appName="VectorAssemblerExample") - sqlContext = SQLContext(sc) + spark = SparkSession\ + .builder\ + .appName("VectorAssemblerExample")\ + .getOrCreate() # $example on$ - dataset = sqlContext.createDataFrame( + dataset = spark.createDataFrame( [(0, 18, 1.0, Vectors.dense([0.0, 10.0, 0.5]), 1.0)], ["id", "hour", "mobile", "userFeatures", "clicked"]) assembler = VectorAssembler( @@ -39,4 +40,4 @@ print(output.select("features", "clicked").first()) # $example off$ - sc.stop() + spark.stop() diff --git a/examples/src/main/python/ml/vector_indexer_example.py b/examples/src/main/python/ml/vector_indexer_example.py index 146f41c1dd903..9b00e0f84136c 100644 --- a/examples/src/main/python/ml/vector_indexer_example.py +++ b/examples/src/main/python/ml/vector_indexer_example.py @@ -17,18 +17,19 @@ from __future__ import print_function -from pyspark import SparkContext -from pyspark.sql import SQLContext # $example on$ from pyspark.ml.feature import VectorIndexer # $example off$ +from pyspark.sql import SparkSession if __name__ == "__main__": - sc = SparkContext(appName="VectorIndexerExample") - sqlContext = SQLContext(sc) + spark = SparkSession\ + .builder\ + .appName("VectorIndexerExample")\ + .getOrCreate() # $example on$ - data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + data = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") indexer = VectorIndexer(inputCol="features", outputCol="indexed", maxCategories=10) indexerModel = indexer.fit(data) @@ -37,4 +38,4 @@ indexedData.show() # $example off$ - sc.stop() + spark.stop() diff --git a/examples/src/main/python/ml/vector_slicer_example.py b/examples/src/main/python/ml/vector_slicer_example.py index 31a753073c13c..d2f46b190f9a8 100644 --- a/examples/src/main/python/ml/vector_slicer_example.py +++ b/examples/src/main/python/ml/vector_slicer_example.py @@ -17,20 +17,21 @@ from 
__future__ import print_function -from pyspark import SparkContext -from pyspark.sql import SQLContext # $example on$ from pyspark.ml.feature import VectorSlicer -from pyspark.mllib.linalg import Vectors +from pyspark.ml.linalg import Vectors from pyspark.sql.types import Row # $example off$ +from pyspark.sql import SparkSession if __name__ == "__main__": - sc = SparkContext(appName="VectorSlicerExample") - sqlContext = SQLContext(sc) + spark = SparkSession\ + .builder\ + .appName("VectorSlicerExample")\ + .getOrCreate() # $example on$ - df = sqlContext.createDataFrame([ + df = spark.createDataFrame([ Row(userFeatures=Vectors.sparse(3, {0: -2.0, 1: 2.3}),), Row(userFeatures=Vectors.dense([-2.0, 2.3, 0.0]),)]) @@ -41,4 +42,4 @@ output.select("userFeatures", "features").show() # $example off$ - sc.stop() + spark.stop() diff --git a/examples/src/main/python/ml/word2vec_example.py b/examples/src/main/python/ml/word2vec_example.py index 53c77feb10145..66500bee152f7 100644 --- a/examples/src/main/python/ml/word2vec_example.py +++ b/examples/src/main/python/ml/word2vec_example.py @@ -17,19 +17,20 @@ from __future__ import print_function -from pyspark import SparkContext -from pyspark.sql import SQLContext # $example on$ from pyspark.ml.feature import Word2Vec # $example off$ +from pyspark.sql import SparkSession if __name__ == "__main__": - sc = SparkContext(appName="Word2VecExample") - sqlContext = SQLContext(sc) + spark = SparkSession\ + .builder\ + .appName("Word2VecExample")\ + .getOrCreate() # $example on$ # Input data: Each row is a bag of words from a sentence or document. - documentDF = sqlContext.createDataFrame([ + documentDF = spark.createDataFrame([ ("Hi I heard about Spark".split(" "), ), ("I wish Java could use case classes".split(" "), ), ("Logistic regression models are neat".split(" "), ) @@ -42,4 +43,4 @@ print(feature) # $example off$ - sc.stop() + spark.stop() diff --git a/examples/src/main/python/mllib/binary_classification_metrics_example.py b/examples/src/main/python/mllib/binary_classification_metrics_example.py index 4e7ea289b2532..daf000e38dcd0 100644 --- a/examples/src/main/python/mllib/binary_classification_metrics_example.py +++ b/examples/src/main/python/mllib/binary_classification_metrics_example.py @@ -18,20 +18,25 @@ Binary Classification Metrics Example. 
""" from __future__ import print_function -from pyspark import SparkContext, SQLContext +from pyspark.sql import SparkSession # $example on$ from pyspark.mllib.classification import LogisticRegressionWithLBFGS from pyspark.mllib.evaluation import BinaryClassificationMetrics -from pyspark.mllib.util import MLUtils +from pyspark.mllib.regression import LabeledPoint # $example off$ if __name__ == "__main__": - sc = SparkContext(appName="BinaryClassificationMetricsExample") - sqlContext = SQLContext(sc) + spark = SparkSession\ + .builder\ + .appName("BinaryClassificationMetricsExample")\ + .getOrCreate() + # $example on$ # Several of the methods available in scala are currently missing from pyspark # Load training data in LIBSVM format - data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_binary_classification_data.txt") + data = spark\ + .read.format("libsvm").load("data/mllib/sample_binary_classification_data.txt")\ + .rdd.map(lambda row: LabeledPoint(row[0], row[1])) # Split data into training (60%) and test (40%) training, test = data.randomSplit([0.6, 0.4], seed=11L) @@ -52,3 +57,5 @@ # Area under ROC curve print("Area under ROC = %s" % metrics.areaUnderROC) # $example off$ + + spark.stop() diff --git a/examples/src/main/python/mllib/gaussian_mixture_model.py b/examples/src/main/python/mllib/gaussian_mixture_model.py index 69e836fc1d06a..6b46e27ddaaa8 100644 --- a/examples/src/main/python/mllib/gaussian_mixture_model.py +++ b/examples/src/main/python/mllib/gaussian_mixture_model.py @@ -20,6 +20,10 @@ """ from __future__ import print_function +import sys +if sys.version >= '3': + long = int + import random import argparse import numpy as np diff --git a/examples/src/main/python/mllib/isotonic_regression_example.py b/examples/src/main/python/mllib/isotonic_regression_example.py index 89dc9f4b6611a..33d618ab48ea9 100644 --- a/examples/src/main/python/mllib/isotonic_regression_example.py +++ b/examples/src/main/python/mllib/isotonic_regression_example.py @@ -23,7 +23,8 @@ from pyspark import SparkContext # $example on$ import math -from pyspark.mllib.regression import IsotonicRegression, IsotonicRegressionModel +from pyspark.mllib.regression import LabeledPoint, IsotonicRegression, IsotonicRegressionModel +from pyspark.mllib.util import MLUtils # $example off$ if __name__ == "__main__": @@ -31,10 +32,14 @@ sc = SparkContext(appName="PythonIsotonicRegressionExample") # $example on$ - data = sc.textFile("data/mllib/sample_isotonic_regression_data.txt") + # Load and parse the data + def parsePoint(labeledData): + return (labeledData.label, labeledData.features[0], 1.0) + + data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_isotonic_regression_libsvm_data.txt") # Create label, feature, weight tuples from input data with weight set to default value 1.0. - parsedData = data.map(lambda line: tuple([float(x) for x in line.split(',')]) + (1.0,)) + parsedData = data.map(parsePoint) # Split data into training (60%) and test (40%) sets. 
training, test = parsedData.randomSplit([0.6, 0.4], 11) diff --git a/examples/src/main/python/mllib/naive_bayes_example.py b/examples/src/main/python/mllib/naive_bayes_example.py index 35724f7d6a92d..749353b20eb3e 100644 --- a/examples/src/main/python/mllib/naive_bayes_example.py +++ b/examples/src/main/python/mllib/naive_bayes_example.py @@ -29,15 +29,9 @@ from pyspark import SparkContext # $example on$ from pyspark.mllib.classification import NaiveBayes, NaiveBayesModel -from pyspark.mllib.linalg import Vectors -from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.util import MLUtils -def parseLine(line): - parts = line.split(',') - label = float(parts[0]) - features = Vectors.dense([float(x) for x in parts[1].split(' ')]) - return LabeledPoint(label, features) # $example off$ if __name__ == "__main__": @@ -45,10 +39,11 @@ def parseLine(line): sc = SparkContext(appName="PythonNaiveBayesExample") # $example on$ - data = sc.textFile('data/mllib/sample_naive_bayes_data.txt').map(parseLine) + # Load and parse the data file. + data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") # Split data approximately into training (60%) and test (40%) - training, test = data.randomSplit([0.6, 0.4], seed=0) + training, test = data.randomSplit([0.6, 0.4]) # Train a naive Bayes model. model = NaiveBayes.train(training, 1.0) diff --git a/examples/src/main/python/mllib/tf_idf_example.py b/examples/src/main/python/mllib/tf_idf_example.py index c4d53333a95a9..b66412b2334e7 100644 --- a/examples/src/main/python/mllib/tf_idf_example.py +++ b/examples/src/main/python/mllib/tf_idf_example.py @@ -43,7 +43,7 @@ # In such cases, the IDF for these terms is set to 0. # This feature can be used by passing the minDocFreq value to the IDF constructor. idfIgnore = IDF(minDocFreq=2).fit(tf) - tfidfIgnore = idf.transform(tf) + tfidfIgnore = idfIgnore.transform(tf) # $example off$ print("tfidf:") diff --git a/examples/src/main/python/pagerank.py b/examples/src/main/python/pagerank.py index 2fdc9773d4eb1..a399a9c37c5d5 100755 --- a/examples/src/main/python/pagerank.py +++ b/examples/src/main/python/pagerank.py @@ -25,7 +25,7 @@ import sys from operator import add -from pyspark import SparkContext +from pyspark.sql import SparkSession def computeContribs(urls, rank): @@ -51,14 +51,17 @@ def parseNeighbors(urls): file=sys.stderr) # Initialize the spark context. - sc = SparkContext(appName="PythonPageRank") + spark = SparkSession\ + .builder\ + .appName("PythonPageRank")\ + .getOrCreate() # Loads in input file. It should be in format of: # URL neighbor URL # URL neighbor URL # URL neighbor URL # ... - lines = sc.textFile(sys.argv[1], 1) + lines = spark.read.text(sys.argv[1]).rdd.map(lambda r: r[0]) # Loads all URLs from input file and initialize their neighbors. links = lines.map(lambda urls: parseNeighbors(urls)).distinct().groupByKey().cache() @@ -79,4 +82,4 @@ def parseNeighbors(urls): for (link, rank) in ranks.collect(): print("%s has rank: %s." 
% (link, rank)) - sc.stop() + spark.stop() diff --git a/examples/src/main/python/parquet_inputformat.py b/examples/src/main/python/parquet_inputformat.py index e1fd85b082c08..29a1ac274eccf 100644 --- a/examples/src/main/python/parquet_inputformat.py +++ b/examples/src/main/python/parquet_inputformat.py @@ -18,7 +18,7 @@ import sys -from pyspark import SparkContext +from pyspark.sql import SparkSession """ Read data file users.parquet in local Spark distro: @@ -47,7 +47,13 @@ exit(-1) path = sys.argv[1] - sc = SparkContext(appName="ParquetInputFormat") + + spark = SparkSession\ + .builder\ + .appName("ParquetInputFormat")\ + .getOrCreate() + + sc = spark.sparkContext parquet_rdd = sc.newAPIHadoopFile( path, @@ -59,4 +65,4 @@ for k in output: print(k) - sc.stop() + spark.stop() diff --git a/examples/src/main/python/pi.py b/examples/src/main/python/pi.py index 92e5cf45abc8b..e3f0c4aeef1b7 100755 --- a/examples/src/main/python/pi.py +++ b/examples/src/main/python/pi.py @@ -20,14 +20,18 @@ from random import random from operator import add -from pyspark import SparkContext +from pyspark.sql import SparkSession if __name__ == "__main__": """ Usage: pi [partitions] """ - sc = SparkContext(appName="PythonPi") + spark = SparkSession\ + .builder\ + .appName("PythonPi")\ + .getOrCreate() + partitions = int(sys.argv[1]) if len(sys.argv) > 1 else 2 n = 100000 * partitions @@ -36,7 +40,7 @@ def f(_): y = random() * 2 - 1 return 1 if x ** 2 + y ** 2 < 1 else 0 - count = sc.parallelize(range(1, n + 1), partitions).map(f).reduce(add) + count = spark.sparkContext.parallelize(range(1, n + 1), partitions).map(f).reduce(add) print("Pi is roughly %f" % (4.0 * count / n)) - sc.stop() + spark.stop() diff --git a/examples/src/main/python/sort.py b/examples/src/main/python/sort.py index b6c2916254056..81898cf6d5ce6 100755 --- a/examples/src/main/python/sort.py +++ b/examples/src/main/python/sort.py @@ -19,15 +19,20 @@ import sys -from pyspark import SparkContext +from pyspark.sql import SparkSession if __name__ == "__main__": if len(sys.argv) != 2: print("Usage: sort ", file=sys.stderr) exit(-1) - sc = SparkContext(appName="PythonSort") - lines = sc.textFile(sys.argv[1], 1) + + spark = SparkSession\ + .builder\ + .appName("PythonSort")\ + .getOrCreate() + + lines = spark.read.text(sys.argv[1]).rdd.map(lambda r: r[0]) sortedCount = lines.flatMap(lambda x: x.split(' ')) \ .map(lambda x: (int(x), 1)) \ .sortByKey() @@ -37,4 +42,4 @@ for (num, unitcount) in output: print(num) - sc.stop() + spark.stop() diff --git a/examples/src/main/python/sql.py b/examples/src/main/python/sql.py deleted file mode 100644 index 2c188759328f2..0000000000000 --- a/examples/src/main/python/sql.py +++ /dev/null @@ -1,80 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# - -from __future__ import print_function - -import os -import sys - -from pyspark import SparkContext -from pyspark.sql import SQLContext -from pyspark.sql.types import Row, StructField, StructType, StringType, IntegerType - - -if __name__ == "__main__": - sc = SparkContext(appName="PythonSQL") - sqlContext = SQLContext(sc) - - # RDD is created from a list of rows - some_rdd = sc.parallelize([Row(name="John", age=19), - Row(name="Smith", age=23), - Row(name="Sarah", age=18)]) - # Infer schema from the first row, create a DataFrame and print the schema - some_df = sqlContext.createDataFrame(some_rdd) - some_df.printSchema() - - # Another RDD is created from a list of tuples - another_rdd = sc.parallelize([("John", 19), ("Smith", 23), ("Sarah", 18)]) - # Schema with two fields - person_name and person_age - schema = StructType([StructField("person_name", StringType(), False), - StructField("person_age", IntegerType(), False)]) - # Create a DataFrame by applying the schema to the RDD and print the schema - another_df = sqlContext.createDataFrame(another_rdd, schema) - another_df.printSchema() - # root - # |-- age: integer (nullable = true) - # |-- name: string (nullable = true) - - # A JSON dataset is pointed to by path. - # The path can be either a single text file or a directory storing text files. - if len(sys.argv) < 2: - path = "file://" + \ - os.path.join(os.environ['SPARK_HOME'], "examples/src/main/resources/people.json") - else: - path = sys.argv[1] - # Create a DataFrame from the file(s) pointed to by path - people = sqlContext.jsonFile(path) - # root - # |-- person_name: string (nullable = false) - # |-- person_age: integer (nullable = false) - - # The inferred schema can be visualized using the printSchema() method. - people.printSchema() - # root - # |-- age: IntegerType - # |-- name: StringType - - # Register this DataFrame as a table. - people.registerAsTable("people") - - # SQL statements can be run by using the sql methods provided by sqlContext - teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") - - for each in teenagers.collect(): - print(each[0]) - - sc.stop() diff --git a/examples/src/main/python/sql/basic.py b/examples/src/main/python/sql/basic.py new file mode 100644 index 0000000000000..fdc017aed97c1 --- /dev/null +++ b/examples/src/main/python/sql/basic.py @@ -0,0 +1,194 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from __future__ import print_function + +# $example on:init_session$ +from pyspark.sql import SparkSession +# $example off:init_session$ + +# $example on:schema_inferring$ +from pyspark.sql import Row +# $example off:schema_inferring$ + +# $example on:programmatic_schema$ +# Import data types +from pyspark.sql.types import * +# $example off:programmatic_schema$ + +""" +A simple example demonstrating basic Spark SQL features. +Run with: + ./bin/spark-submit examples/src/main/python/sql/basic.py +""" + + +def basic_df_example(spark): + # $example on:create_df$ + # spark is an existing SparkSession + df = spark.read.json("examples/src/main/resources/people.json") + # Displays the content of the DataFrame to stdout + df.show() + # +----+-------+ + # | age| name| + # +----+-------+ + # |null|Michael| + # | 30| Andy| + # | 19| Justin| + # +----+-------+ + # $example off:create_df$ + + # $example on:untyped_ops$ + # spark, df are from the previous example + # Print the schema in a tree format + df.printSchema() + # root + # |-- age: long (nullable = true) + # |-- name: string (nullable = true) + + # Select only the "name" column + df.select("name").show() + # +-------+ + # | name| + # +-------+ + # |Michael| + # | Andy| + # | Justin| + # +-------+ + + # Select everybody, but increment the age by 1 + df.select(df['name'], df['age'] + 1).show() + # +-------+---------+ + # | name|(age + 1)| + # +-------+---------+ + # |Michael| null| + # | Andy| 31| + # | Justin| 20| + # +-------+---------+ + + # Select people older than 21 + df.filter(df['age'] > 21).show() + # +---+----+ + # |age|name| + # +---+----+ + # | 30|Andy| + # +---+----+ + + # Count people by age + df.groupBy("age").count().show() + # +----+-----+ + # | age|count| + # +----+-----+ + # | 19| 1| + # |null| 1| + # | 30| 1| + # +----+-----+ + # $example off:untyped_ops$ + + # $example on:run_sql$ + # Register the DataFrame as a SQL temporary view + df.createOrReplaceTempView("people") + + sqlDF = spark.sql("SELECT * FROM people") + sqlDF.show() + # +----+-------+ + # | age| name| + # +----+-------+ + # |null|Michael| + # | 30| Andy| + # | 19| Justin| + # +----+-------+ + # $example off:run_sql$ + + +def schema_inference_example(spark): + # $example on:schema_inferring$ + sc = spark.sparkContext + + # Load a text file and convert each line to a Row. + lines = sc.textFile("examples/src/main/resources/people.txt") + parts = lines.map(lambda l: l.split(",")) + people = parts.map(lambda p: Row(name=p[0], age=int(p[1]))) + + # Infer the schema, and register the DataFrame as a table. + schemaPeople = spark.createDataFrame(people) + schemaPeople.createOrReplaceTempView("people") + + # SQL can be run over DataFrames that have been registered as a table. + teenagers = spark.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") + + # The results of SQL queries are Dataframe objects. + # rdd returns the content as an :class:`pyspark.RDD` of :class:`Row`. + teenNames = teenagers.rdd.map(lambda p: "Name: " + p.name).collect() + for name in teenNames: + print(name) + # Name: Justin + # $example off:schema_inferring$ + + +def programmatic_schema_example(spark): + # $example on:programmatic_schema$ + sc = spark.sparkContext + + # Load a text file and convert each line to a Row. + lines = sc.textFile("examples/src/main/resources/people.txt") + parts = lines.map(lambda l: l.split(",")) + # Each line is converted to a tuple. + people = parts.map(lambda p: (p[0], p[1].strip())) + + # The schema is encoded in a string. 
+ schemaString = "name age" + + fields = [StructField(field_name, StringType(), True) for field_name in schemaString.split()] + schema = StructType(fields) + + # Apply the schema to the RDD. + schemaPeople = spark.createDataFrame(people, schema) + + # Creates a temporary view using the DataFrame + schemaPeople.createOrReplaceTempView("people") + + # Creates a temporary view using the DataFrame + schemaPeople.createOrReplaceTempView("people") + + # SQL can be run over DataFrames that have been registered as a table. + results = spark.sql("SELECT name FROM people") + + results.show() + # +-------+ + # | name| + # +-------+ + # |Michael| + # | Andy| + # | Justin| + # +-------+ + # $example off:programmatic_schema$ + +if __name__ == "__main__": + # $example on:init_session$ + spark = SparkSession \ + .builder \ + .appName("Python Spark SQL basic example") \ + .config("spark.some.config.option", "some-value") \ + .getOrCreate() + # $example off:init_session$ + + basic_df_example(spark) + schema_inference_example(spark) + programmatic_schema_example(spark) + + spark.stop() diff --git a/examples/src/main/python/sql/datasource.py b/examples/src/main/python/sql/datasource.py new file mode 100644 index 0000000000000..b36c901d2b403 --- /dev/null +++ b/examples/src/main/python/sql/datasource.py @@ -0,0 +1,168 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark.sql import SparkSession +# $example on:schema_merging$ +from pyspark.sql import Row +# $example off:schema_merging$ + +""" +A simple example demonstrating Spark SQL data sources. +Run with: + ./bin/spark-submit examples/src/main/python/sql/datasource.py +""" + + +def basic_datasource_example(spark): + # $example on:generic_load_save_functions$ + df = spark.read.load("examples/src/main/resources/users.parquet") + df.select("name", "favorite_color").write.save("namesAndFavColors.parquet") + # $example off:generic_load_save_functions$ + + # $example on:manual_load_options$ + df = spark.read.load("examples/src/main/resources/people.json", format="json") + df.select("name", "age").write.save("namesAndAges.parquet", format="parquet") + # $example off:manual_load_options$ + + # $example on:direct_sql$ + df = spark.sql("SELECT * FROM parquet.`examples/src/main/resources/users.parquet`") + # $example off:direct_sql$ + + +def parquet_example(spark): + # $example on:basic_parquet_example$ + peopleDF = spark.read.json("examples/src/main/resources/people.json") + + # DataFrames can be saved as Parquet files, maintaining the schema information. + peopleDF.write.parquet("people.parquet") + + # Read in the Parquet file created above. + # Parquet files are self-describing so the schema is preserved. + # The result of loading a parquet file is also a DataFrame. 
+ parquetFile = spark.read.parquet("people.parquet") + + # Parquet files can also be used to create a temporary view and then used in SQL statements. + parquetFile.createOrReplaceTempView("parquetFile") + teenagers = spark.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19") + teenagers.show() + # +------+ + # | name| + # +------+ + # |Justin| + # +------+ + # $example off:basic_parquet_example$ + + +def parquet_schema_merging_example(spark): + # $example on:schema_merging$ + # spark is from the previous example. + # Create a simple DataFrame, stored into a partition directory + sc = spark.sparkContext + + squaresDF = spark.createDataFrame(sc.parallelize(range(1, 6)) + .map(lambda i: Row(single=i, double=i ** 2))) + squaresDF.write.parquet("data/test_table/key=1") + + # Create another DataFrame in a new partition directory, + # adding a new column and dropping an existing column + cubesDF = spark.createDataFrame(sc.parallelize(range(6, 11)) + .map(lambda i: Row(single=i, triple=i ** 3))) + cubesDF.write.parquet("data/test_table/key=2") + + # Read the partitioned table + mergedDF = spark.read.option("mergeSchema", "true").parquet("data/test_table") + mergedDF.printSchema() + + # The final schema consists of all 3 columns in the Parquet files together + # with the partitioning column appeared in the partition directory paths. + # root + # |-- double: long (nullable = true) + # |-- single: long (nullable = true) + # |-- triple: long (nullable = true) + # |-- key: integer (nullable = true) + # $example off:schema_merging$ + + +def json_dataset_example(spark): + # $example on:json_dataset$ + # spark is from the previous example. + sc = spark.sparkContext + + # A JSON dataset is pointed to by path. + # The path can be either a single text file or a directory storing text files + path = "examples/src/main/resources/people.json" + peopleDF = spark.read.json(path) + + # The inferred schema can be visualized using the printSchema() method + peopleDF.printSchema() + # root + # |-- age: long (nullable = true) + # |-- name: string (nullable = true) + + # Creates a temporary view using the DataFrame + peopleDF.createOrReplaceTempView("people") + + # SQL statements can be run by using the sql methods provided by spark + teenagerNamesDF = spark.sql("SELECT name FROM people WHERE age BETWEEN 13 AND 19") + teenagerNamesDF.show() + # +------+ + # | name| + # +------+ + # |Justin| + # +------+ + + # Alternatively, a DataFrame can be created for a JSON dataset represented by + # an RDD[String] storing one JSON object per string + jsonStrings = ['{"name":"Yin","address":{"city":"Columbus","state":"Ohio"}}'] + otherPeopleRDD = sc.parallelize(jsonStrings) + otherPeople = spark.read.json(otherPeopleRDD) + otherPeople.show() + # +---------------+----+ + # | address|name| + # +---------------+----+ + # |[Columbus,Ohio]| Yin| + # +---------------+----+ + # $example off:json_dataset$ + + +def jdbc_dataset_example(spark): + # $example on:jdbc_dataset$ + jdbcDF = spark.read \ + .format("jdbc") \ + .option("url", "jdbc:postgresql:dbserver") \ + .option("dbtable", "schema.tablename") \ + .option("user", "username") \ + .option("password", "password") \ + .load() + # $example off:jdbc_dataset$ + + +if __name__ == "__main__": + spark = SparkSession \ + .builder \ + .appName("Python Spark SQL data source example") \ + .getOrCreate() + + basic_datasource_example(spark) + parquet_example(spark) + parquet_schema_merging_example(spark) + json_dataset_example(spark) + jdbc_dataset_example(spark) + + spark.stop() diff 
--git a/examples/src/main/python/sql/hive.py b/examples/src/main/python/sql/hive.py new file mode 100644 index 0000000000000..142213d3cfb1f --- /dev/null +++ b/examples/src/main/python/sql/hive.py @@ -0,0 +1,96 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +# $example on:spark_hive$ +from os.path import expanduser, join + +from pyspark.sql import SparkSession +from pyspark.sql import Row +# $example off:spark_hive$ + +""" +A simple example demonstrating Spark SQL Hive integration. +Run with: + ./bin/spark-submit examples/src/main/python/sql/hive.py +""" + + +if __name__ == "__main__": + # $example on:spark_hive$ + # warehouse_location points to the default location for managed databases and tables + warehouse_location = 'spark-warehouse' + + spark = SparkSession \ + .builder \ + .appName("Python Spark SQL Hive integration example") \ + .config("spark.sql.warehouse.dir", warehouse_location) \ + .enableHiveSupport() \ + .getOrCreate() + + # spark is an existing SparkSession + spark.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") + spark.sql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") + + # Queries are expressed in HiveQL + spark.sql("SELECT * FROM src").show() + # +---+-------+ + # |key| value| + # +---+-------+ + # |238|val_238| + # | 86| val_86| + # |311|val_311| + # ... + + # Aggregation queries are also supported. + spark.sql("SELECT COUNT(*) FROM src").show() + # +--------+ + # |count(1)| + # +--------+ + # | 500 | + # +--------+ + + # The results of SQL queries are themselves DataFrames and support all normal functions. + sqlDF = spark.sql("SELECT key, value FROM src WHERE key < 10 ORDER BY key") + + # The items in DataFrames are of type Row, which allows you to access each column by ordinal. + stringsDS = sqlDF.rdd.map(lambda row: "Key: %d, Value: %s" % (row.key, row.value)) + for record in stringsDS.collect(): + print(record) + # Key: 0, Value: val_0 + # Key: 0, Value: val_0 + # Key: 0, Value: val_0 + # ... + + # You can also use DataFrames to create temporary views within a SparkSession. + Record = Row("key", "value") + recordsDF = spark.createDataFrame(map(lambda i: Record(i, "val_" + str(i)), range(1, 101))) + recordsDF.createOrReplaceTempView("records") + + # Queries can then join DataFrame data with data stored in Hive. + spark.sql("SELECT * FROM records r JOIN src s ON r.key = s.key").show() + # +---+------+---+------+ + # |key| value|key| value| + # +---+------+---+------+ + # | 2| val_2| 2| val_2| + # | 4| val_4| 4| val_4| + # | 5| val_5| 5| val_5| + # ...
+ # $example off:spark_hive$ + + spark.stop() diff --git a/examples/src/main/python/sql/streaming/structured_network_wordcount.py b/examples/src/main/python/sql/streaming/structured_network_wordcount.py new file mode 100644 index 0000000000000..afde2550587ca --- /dev/null +++ b/examples/src/main/python/sql/streaming/structured_network_wordcount.py @@ -0,0 +1,77 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" + Counts words in UTF8 encoded, '\n' delimited text received from the network. + Usage: structured_network_wordcount.py <hostname> <port> + <hostname> and <port> describe the TCP server that Structured Streaming + would connect to receive data. + + To run this on your local machine, you need to first run a Netcat server + `$ nc -lk 9999` + and then run the example + `$ bin/spark-submit examples/src/main/python/sql/streaming/structured_network_wordcount.py + localhost 9999` +""" +from __future__ import print_function + +import sys + +from pyspark.sql import SparkSession +from pyspark.sql.functions import explode +from pyspark.sql.functions import split + +if __name__ == "__main__": + if len(sys.argv) != 3: + print("Usage: structured_network_wordcount.py <hostname> <port>", file=sys.stderr) + exit(-1) + + host = sys.argv[1] + port = int(sys.argv[2]) + + spark = SparkSession\ + .builder\ + .appName("StructuredNetworkWordCount")\ + .getOrCreate() + + # Create DataFrame representing the stream of input lines from connection to host:port + lines = spark\ + .readStream\ + .format('socket')\ + .option('host', host)\ + .option('port', port)\ + .load() + + # Split the lines into words + words = lines.select( + # explode turns each item in an array into a separate row + explode( + split(lines.value, ' ') + ).alias('word') + ) + + # Generate running word count + wordCounts = words.groupBy('word').count() + + # Start running the query that prints the running counts to the console + query = wordCounts\ + .writeStream\ + .outputMode('complete')\ + .format('console')\ + .start() + + query.awaitTermination() diff --git a/examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py b/examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py new file mode 100644 index 0000000000000..02a7d3363d780 --- /dev/null +++ b/examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py @@ -0,0 +1,102 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License.
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" + Counts words in UTF8 encoded, '\n' delimited text received from the network over a + sliding window of configurable duration. Each line from the network is tagged + with a timestamp that is used to determine the windows into which it falls. + + Usage: structured_network_wordcount_windowed.py <hostname> <port> <window duration> + [<slide duration>] + <hostname> and <port> describe the TCP server that Structured Streaming + would connect to receive data. + <window duration> gives the size of window, specified as integer number of seconds + <slide duration> gives the amount of time successive windows are offset from one another, + given in the same units as above. <slide duration> should be less than or equal to + <window duration>. If the two are equal, successive windows have no overlap. If + <slide duration> is not provided, it defaults to <window duration>. + + To run this on your local machine, you need to first run a Netcat server + `$ nc -lk 9999` + and then run the example + `$ bin/spark-submit + examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py + localhost 9999 <window duration> [<slide duration>]` + + One recommended <window duration>, <slide duration> pair is 10, 5 +""" +from __future__ import print_function + +import sys + +from pyspark.sql import SparkSession +from pyspark.sql.functions import explode +from pyspark.sql.functions import split +from pyspark.sql.functions import window + +if __name__ == "__main__": + if len(sys.argv) != 5 and len(sys.argv) != 4: + msg = ("Usage: structured_network_wordcount_windowed.py <hostname> <port> " + "<window duration> [<slide duration>]") + print(msg, file=sys.stderr) + exit(-1) + + host = sys.argv[1] + port = int(sys.argv[2]) + windowSize = int(sys.argv[3]) + slideSize = int(sys.argv[4]) if (len(sys.argv) == 5) else windowSize + if slideSize > windowSize: + print("<slide duration> must be less than or equal to <window duration>", file=sys.stderr) + windowDuration = '{} seconds'.format(windowSize) + slideDuration = '{} seconds'.format(slideSize) + + spark = SparkSession\ + .builder\ + .appName("StructuredNetworkWordCountWindowed")\ + .getOrCreate() + + # Create DataFrame representing the stream of input lines from connection to host:port + lines = spark\ + .readStream\ + .format('socket')\ + .option('host', host)\ + .option('port', port)\ + .option('includeTimestamp', 'true')\ + .load() + + # Split the lines into words, retaining timestamps + # split() splits each line into an array, and explode() turns the array into multiple rows + words = lines.select( + explode(split(lines.value, ' ')).alias('word'), + lines.timestamp + ) + + # Group the data by window and word and compute the count of each group + windowedCounts = words.groupBy( + window(words.timestamp, windowDuration, slideDuration), + words.word + ).count().orderBy('window') + + # Start running the query that prints the windowed word counts to the console + query = windowedCounts\ + .writeStream\ + .outputMode('complete')\ + .format('console')\ + .option('truncate', 'false')\ + .start() + + query.awaitTermination() diff --git a/examples/src/main/python/streaming/sql_network_wordcount.py b/examples/src/main/python/streaming/sql_network_wordcount.py index 1ba5e9fb78993..398ac8d2d8f5e 100644 --- a/examples/src/main/python/streaming/sql_network_wordcount.py +++ b/examples/src/main/python/streaming/sql_network_wordcount.py @@ -33,13 +33,16 @@ from pyspark
import SparkContext from pyspark.streaming import StreamingContext -from pyspark.sql import SQLContext, Row +from pyspark.sql import Row, SparkSession -def getSqlContextInstance(sparkContext): - if ('sqlContextSingletonInstance' not in globals()): - globals()['sqlContextSingletonInstance'] = SQLContext(sparkContext) - return globals()['sqlContextSingletonInstance'] +def getSparkSessionInstance(sparkConf): + if ('sparkSessionSingletonInstance' not in globals()): + globals()['sparkSessionSingletonInstance'] = SparkSession\ + .builder\ + .config(conf=sparkConf)\ + .getOrCreate() + return globals()['sparkSessionSingletonInstance'] if __name__ == "__main__": @@ -60,19 +63,19 @@ def process(time, rdd): print("========= %s =========" % str(time)) try: - # Get the singleton instance of SQLContext - sqlContext = getSqlContextInstance(rdd.context) + # Get the singleton instance of SparkSession + spark = getSparkSessionInstance(rdd.context.getConf()) # Convert RDD[String] to RDD[Row] to DataFrame rowRdd = rdd.map(lambda w: Row(word=w)) - wordsDataFrame = sqlContext.createDataFrame(rowRdd) + wordsDataFrame = spark.createDataFrame(rowRdd) - # Register as table - wordsDataFrame.registerTempTable("words") + # Creates a temporary view using the DataFrame. + wordsDataFrame.createOrReplaceTempView("words") # Do word count on table using SQL and print it wordCountsDataFrame = \ - sqlContext.sql("select word, count(*) as total from words group by word") + spark.sql("select word, count(*) as total from words group by word") wordCountsDataFrame.show() except: pass diff --git a/examples/src/main/python/transitive_closure.py b/examples/src/main/python/transitive_closure.py index 3d61250d8b230..49551d40851cc 100755 --- a/examples/src/main/python/transitive_closure.py +++ b/examples/src/main/python/transitive_closure.py @@ -20,7 +20,7 @@ import sys from random import Random -from pyspark import SparkContext +from pyspark.sql import SparkSession numEdges = 200 numVertices = 100 @@ -41,9 +41,13 @@ def generateGraph(): """ Usage: transitive_closure [partitions] """ - sc = SparkContext(appName="PythonTransitiveClosure") + spark = SparkSession\ + .builder\ + .appName("PythonTransitiveClosure")\ + .getOrCreate() + partitions = int(sys.argv[1]) if len(sys.argv) > 1 else 2 - tc = sc.parallelize(generateGraph(), partitions).cache() + tc = spark.sparkContext.parallelize(generateGraph(), partitions).cache() # Linear transitive closure: each round grows paths by one edge, # by joining the graph's edges with the already-discovered paths. 
@@ -67,4 +71,4 @@ def generateGraph(): print("TC has %i edges" % tc.count()) - sc.stop() + spark.stop() diff --git a/examples/src/main/python/wordcount.py b/examples/src/main/python/wordcount.py index 7c0143607b61d..3d5e44d5b2df1 100755 --- a/examples/src/main/python/wordcount.py +++ b/examples/src/main/python/wordcount.py @@ -20,15 +20,20 @@ import sys from operator import add -from pyspark import SparkContext +from pyspark.sql import SparkSession if __name__ == "__main__": if len(sys.argv) != 2: print("Usage: wordcount ", file=sys.stderr) exit(-1) - sc = SparkContext(appName="PythonWordCount") - lines = sc.textFile(sys.argv[1], 1) + + spark = SparkSession\ + .builder\ + .appName("PythonWordCount")\ + .getOrCreate() + + lines = spark.read.text(sys.argv[1]).rdd.map(lambda r: r[0]) counts = lines.flatMap(lambda x: x.split(' ')) \ .map(lambda x: (x, 1)) \ .reduceByKey(add) @@ -36,4 +41,4 @@ for (word, count) in output: print("%s: %i" % (word, count)) - sc.stop() + spark.stop() diff --git a/examples/src/main/r/RSparkSQLExample.R b/examples/src/main/r/RSparkSQLExample.R new file mode 100644 index 0000000000000..de489e1bda2c3 --- /dev/null +++ b/examples/src/main/r/RSparkSQLExample.R @@ -0,0 +1,211 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +library(SparkR) + +# $example on:init_session$ +sparkR.session(appName = "MyApp", sparkConfig = list(spark.some.config.option = "some-value")) +# $example off:init_session$ + + +# $example on:create_df$ +df <- read.json("examples/src/main/resources/people.json") + +# Displays the content of the DataFrame +head(df) +## age name +## 1 NA Michael +## 2 30 Andy +## 3 19 Justin + +# Another method to print the first few rows and optionally truncate the printing of long values +showDF(df) +## +----+-------+ +## | age| name| +## +----+-------+ +## |null|Michael| +## | 30| Andy| +## | 19| Justin| +## +----+-------+ +## $example off:create_df$ + + +# $example on:untyped_ops$ +# Create the DataFrame +df <- read.json("examples/src/main/resources/people.json") + +# Show the content of the DataFrame +head(df) +## age name +## 1 NA Michael +## 2 30 Andy +## 3 19 Justin + + +# Print the schema in a tree format +printSchema(df) +## root +## |-- age: long (nullable = true) +## |-- name: string (nullable = true) + +# Select only the "name" column +head(select(df, "name")) +## name +## 1 Michael +## 2 Andy +## 3 Justin + +# Select everybody, but increment the age by 1 +head(select(df, df$name, df$age + 1)) +## name (age + 1.0) +## 1 Michael NA +## 2 Andy 31 +## 3 Justin 20 + +# Select people older than 21 +head(where(df, df$age > 21)) +## age name +## 1 30 Andy + +# Count people by age +head(count(groupBy(df, "age"))) +## age count +## 1 19 1 +## 2 NA 1 +## 3 30 1 +# $example off:untyped_ops$ + + +# Register this DataFrame as a table. +createOrReplaceTempView(df, "table") +# $example on:run_sql$ +df <- sql("SELECT * FROM table") +# $example off:run_sql$ + + +# $example on:generic_load_save_functions$ +df <- read.df("examples/src/main/resources/users.parquet") +write.df(select(df, "name", "favorite_color"), "namesAndFavColors.parquet") +# $example off:generic_load_save_functions$ + + +# $example on:manual_load_options$ +df <- read.df("examples/src/main/resources/people.json", "json") +namesAndAges <- select(df, "name", "age") +write.df(namesAndAges, "namesAndAges.parquet", "parquet") +# $example off:manual_load_options$ + + +# $example on:direct_sql$ +df <- sql("SELECT * FROM parquet.`examples/src/main/resources/users.parquet`") +# $example off:direct_sql$ + + +# $example on:basic_parquet_example$ +df <- read.df("examples/src/main/resources/people.json", "json") + +# SparkDataFrame can be saved as Parquet files, maintaining the schema information. +write.parquet(df, "people.parquet") + +# Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. +# The result of loading a parquet file is also a DataFrame. +parquetFile <- read.parquet("people.parquet") + +# Parquet files can also be used to create a temporary view and then used in SQL statements. +createOrReplaceTempView(parquetFile, "parquetFile") +teenagers <- sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19") +head(teenagers) +## name +## 1 Justin + +# We can also run custom R-UDFs on Spark DataFrames. 
Here we prefix all the names with "Name:" +schema <- structType(structField("name", "string")) +teenNames <- dapply(df, function(p) { cbind(paste("Name:", p$name)) }, schema) +for (teenName in collect(teenNames)$name) { + cat(teenName, "\n") +} +## Name: Michael +## Name: Andy +## Name: Justin +# $example off:basic_parquet_example$ + + +# $example on:schema_merging$ +df1 <- createDataFrame(data.frame(single=c(12, 29), double=c(19, 23))) +df2 <- createDataFrame(data.frame(double=c(19, 23), triple=c(23, 18))) + +# Create a simple DataFrame, stored into a partition directory +write.df(df1, "data/test_table/key=1", "parquet", "overwrite") + +# Create another DataFrame in a new partition directory, +# adding a new column and dropping an existing column +write.df(df2, "data/test_table/key=2", "parquet", "overwrite") + +# Read the partitioned table +df3 <- read.df("data/test_table", "parquet", mergeSchema = "true") +printSchema(df3) +# The final schema consists of all 3 columns in the Parquet files together +# with the partitioning column appeared in the partition directory paths +## root +## |-- single: double (nullable = true) +## |-- double: double (nullable = true) +## |-- triple: double (nullable = true) +## |-- key: integer (nullable = true) +# $example off:schema_merging$ + + +# $example on:json_dataset$ +# A JSON dataset is pointed to by path. +# The path can be either a single text file or a directory storing text files. +path <- "examples/src/main/resources/people.json" +# Create a DataFrame from the file(s) pointed to by path +people <- read.json(path) + +# The inferred schema can be visualized using the printSchema() method. +printSchema(people) +## root +## |-- age: long (nullable = true) +## |-- name: string (nullable = true) + +# Register this DataFrame as a table. +createOrReplaceTempView(people, "people") + +# SQL statements can be run by using the sql methods. +teenagers <- sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") +head(teenagers) +## name +## 1 Justin +# $example off:json_dataset$ + + +# $example on:spark_hive$ +# enableHiveSupport defaults to TRUE +sparkR.session(enableHiveSupport = TRUE) +sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") +sql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") + +# Queries can be expressed in HiveQL. +results <- collect(sql("FROM src SELECT key, value")) +# $example off:spark_hive$ + + +# $example on:jdbc_dataset$ +df <- read.jdbc("jdbc:postgresql:dbserver", "schema.tablename", user = "username", password = "password") +# $example off:jdbc_dataset$ + +# Stop the SparkSession now +sparkR.session.stop() diff --git a/examples/src/main/r/data-manipulation.R b/examples/src/main/r/data-manipulation.R index 594bf49d60158..371335a62e92d 100644 --- a/examples/src/main/r/data-manipulation.R +++ b/examples/src/main/r/data-manipulation.R @@ -17,11 +17,10 @@ # For this example, we shall use the "flights" dataset # The dataset consists of every flight departing Houston in 2011. -# The data set is made up of 227,496 rows x 14 columns. +# The data set is made up of 227,496 rows x 14 columns. 
# To run this example use -# ./bin/sparkR --packages com.databricks:spark-csv_2.10:1.0.3 -# examples/src/main/r/data-manipulation.R +# ./bin/spark-submit examples/src/main/r/data-manipulation.R # Load SparkR library into your R session library(SparkR) @@ -29,16 +28,13 @@ library(SparkR) args <- commandArgs(trailing = TRUE) if (length(args) != 1) { - print("Usage: data-manipulation.R ") print("The data can be downloaded from: http://s3-us-west-2.amazonaws.com/sparkr-data/flights.csv") q("no") } -## Initialize SparkContext -sc <- sparkR.init(appName = "SparkR-data-manipulation-example") - -## Initialize SQLContext -sqlContext <- sparkRSQL.init(sc) +## Initialize SparkSession +sparkR.session(appName = "SparkR-data-manipulation-example") flightsCsvPath <- args[[1]] @@ -47,13 +43,13 @@ flights_df <- read.csv(flightsCsvPath, header = TRUE) flights_df$date <- as.Date(flights_df$date) ## Filter flights whose destination is San Francisco and write to a local data frame -SFO_df <- flights_df[flights_df$dest == "SFO", ] +SFO_df <- flights_df[flights_df$dest == "SFO", ] # Convert the local data frame into a SparkDataFrame -SFO_DF <- createDataFrame(sqlContext, SFO_df) +SFO_DF <- createDataFrame(SFO_df) # Directly create a SparkDataFrame from the source data -flightsDF <- read.df(sqlContext, flightsCsvPath, source = "com.databricks.spark.csv", header = "true") +flightsDF <- read.df(flightsCsvPath, source = "csv", header = "true") # Print the schema of this SparkDataFrame printSchema(flightsDF) @@ -76,8 +72,8 @@ destDF <- select(flightsDF, "dest", "cancelled") # Using SQL to select columns of data # First, register the flights SparkDataFrame as a table -registerTempTable(flightsDF, "flightsTable") -destDF <- sql(sqlContext, "SELECT dest, cancelled FROM flightsTable") +createOrReplaceTempView(flightsDF, "flightsTable") +destDF <- sql("SELECT dest, cancelled FROM flightsTable") # Use collect to create a local R data frame local_df <- collect(destDF) @@ -103,5 +99,5 @@ if("magrittr" %in% rownames(installed.packages())) { head(dailyDelayDF) } -# Stop the SparkContext now -sparkR.stop() +# Stop the SparkSession now +sparkR.session.stop() diff --git a/examples/src/main/r/dataframe.R b/examples/src/main/r/dataframe.R index 436bac6aaf455..82b85f2f590f6 100644 --- a/examples/src/main/r/dataframe.R +++ b/examples/src/main/r/dataframe.R @@ -17,15 +17,14 @@ library(SparkR) -# Initialize SparkContext and SQLContext -sc <- sparkR.init(appName="SparkR-DataFrame-example") -sqlContext <- sparkRSQL.init(sc) +# Initialize SparkSession +sparkR.session(appName = "SparkR-DataFrame-example") # Create a simple local data.frame localDF <- data.frame(name=c("John", "Smith", "Sarah"), age=c(19, 23, 18)) # Convert local data frame to a SparkDataFrame -df <- createDataFrame(sqlContext, localDF) +df <- createDataFrame(localDF) # Print its schema printSchema(df) @@ -35,20 +34,23 @@ printSchema(df) # Create a DataFrame from a JSON file path <- file.path(Sys.getenv("SPARK_HOME"), "examples/src/main/resources/people.json") -peopleDF <- read.json(sqlContext, path) +peopleDF <- read.json(path) printSchema(peopleDF) +# root +# |-- age: long (nullable = true) +# |-- name: string (nullable = true) # Register this DataFrame as a table. 
-registerTempTable(peopleDF, "people") +createOrReplaceTempView(peopleDF, "people") -# SQL statements can be run by using the sql methods provided by sqlContext -teenagers <- sql(sqlContext, "SELECT name FROM people WHERE age >= 13 AND age <= 19") +# SQL statements can be run by using the sql methods +teenagers <- sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") # Call collect to get a local data.frame teenagersLocalDF <- collect(teenagers) -# Print the teenagers in our dataset +# Print the teenagers in our dataset print(teenagersLocalDF) -# Stop the SparkContext now -sparkR.stop() +# Stop the SparkSession now +sparkR.session.stop() diff --git a/examples/src/main/r/ml.R b/examples/src/main/r/ml.R index a0c903939cbbb..a8a1274ac902a 100644 --- a/examples/src/main/r/ml.R +++ b/examples/src/main/r/ml.R @@ -16,39 +16,133 @@ # # To run this example use -# ./bin/sparkR examples/src/main/r/ml.R +# ./bin/spark-submit examples/src/main/r/ml.R # Load SparkR library into your R session library(SparkR) -# Initialize SparkContext and SQLContext -sc <- sparkR.init(appName="SparkR-ML-example") -sqlContext <- sparkRSQL.init(sc) +# Initialize SparkSession +sparkR.session(appName = "SparkR-ML-example") -# Train GLM of family 'gaussian' -training1 <- suppressWarnings(createDataFrame(sqlContext, iris)) -test1 <- training1 -model1 <- glm(Sepal_Length ~ Sepal_Width + Species, training1, family = "gaussian") +############################ spark.glm and glm ############################################## +# $example on:glm$ +irisDF <- suppressWarnings(createDataFrame(iris)) +# Fit a generalized linear model of family "gaussian" with spark.glm +gaussianDF <- irisDF +gaussianTestDF <- irisDF +gaussianGLM <- spark.glm(gaussianDF, Sepal_Length ~ Sepal_Width + Species, family = "gaussian") # Model summary -summary(model1) +summary(gaussianGLM) # Prediction -predictions1 <- predict(model1, test1) -head(select(predictions1, "Sepal_Length", "prediction")) +gaussianPredictions <- predict(gaussianGLM, gaussianTestDF) +showDF(gaussianPredictions) -# Train GLM of family 'binomial' -training2 <- filter(training1, training1$Species != "setosa") -test2 <- training2 -model2 <- glm(Species ~ Sepal_Length + Sepal_Width, data = training2, family = "binomial") +# Fit a generalized linear model with glm (R-compliant) +gaussianGLM2 <- glm(Sepal_Length ~ Sepal_Width + Species, gaussianDF, family = "gaussian") +summary(gaussianGLM2) + +# Fit a generalized linear model of family "binomial" with spark.glm +binomialDF <- filter(irisDF, irisDF$Species != "setosa") +binomialTestDF <- binomialDF +binomialGLM <- spark.glm(binomialDF, Species ~ Sepal_Length + Sepal_Width, family = "binomial") + +# Model summary +summary(binomialGLM) + +# Prediction +binomialPredictions <- predict(binomialGLM, binomialTestDF) +showDF(binomialPredictions) +# $example off:glm$ +############################ spark.survreg ############################################## +# $example on:survreg$ +# Use the ovarian dataset available in R survival package +library(survival) + +# Fit an accelerated failure time (AFT) survival regression model with spark.survreg +ovarianDF <- suppressWarnings(createDataFrame(ovarian)) +aftDF <- ovarianDF +aftTestDF <- ovarianDF +aftModel <- spark.survreg(aftDF, Surv(futime, fustat) ~ ecog_ps + rx) # Model summary -summary(model2) +summary(aftModel) + +# Prediction +aftPredictions <- predict(aftModel, aftTestDF) +showDF(aftPredictions) +# $example off:survreg$ +############################ spark.naiveBayes 
############################################## +# $example on:naiveBayes$ +# Fit a Bernoulli naive Bayes model with spark.naiveBayes +titanic <- as.data.frame(Titanic) +titanicDF <- createDataFrame(titanic[titanic$Freq > 0, -5]) +nbDF <- titanicDF +nbTestDF <- titanicDF +nbModel <- spark.naiveBayes(nbDF, Survived ~ Class + Sex + Age) + +# Model summary +summary(nbModel) + +# Prediction +nbPredictions <- predict(nbModel, nbTestDF) +showDF(nbPredictions) +# $example off:naiveBayes$ +############################ spark.kmeans ############################################## +# $example on:kmeans$ +# Fit a k-means model with spark.kmeans +irisDF <- suppressWarnings(createDataFrame(iris)) +kmeansDF <- irisDF +kmeansTestDF <- irisDF +kmeansModel <- spark.kmeans(kmeansDF, ~ Sepal_Length + Sepal_Width + Petal_Length + Petal_Width, + k = 3) + +# Model summary +summary(kmeansModel) + +# Get fitted result from the k-means model +showDF(fitted(kmeansModel)) + +# Prediction +kmeansPredictions <- predict(kmeansModel, kmeansTestDF) +showDF(kmeansPredictions) +# $example off:kmeans$ +############################ model read/write ############################################## +# $example on:read_write$ +irisDF <- suppressWarnings(createDataFrame(iris)) +# Fit a generalized linear model of family "gaussian" with spark.glm +gaussianDF <- irisDF +gaussianTestDF <- irisDF +gaussianGLM <- spark.glm(gaussianDF, Sepal_Length ~ Sepal_Width + Species, family = "gaussian") + +# Save and then load a fitted MLlib model +modelPath <- tempfile(pattern = "ml", fileext = ".tmp") +write.ml(gaussianGLM, modelPath) +gaussianGLM2 <- read.ml(modelPath) + +# Check model summary +summary(gaussianGLM2) + +# Check model prediction +gaussianPredictions <- predict(gaussianGLM2, gaussianTestDF) +showDF(gaussianPredictions) + +unlink(modelPath) +# $example off:read_write$ +############################ fit models with spark.lapply ##################################### + +# Perform distributed training of multiple models with spark.lapply +families <- c("gaussian", "poisson") +train <- function(family) { + model <- glm(Sepal.Length ~ Sepal.Width + Species, iris, family = family) + summary(model) +} +model.summaries <- spark.lapply(families, train) + +# Print the summary of each model +print(model.summaries) -# Prediction (Currently the output of prediction for binomial GLM is the indexed label, -# we need to transform back to the original string label later) -predictions2 <- predict(model2, test2) -head(select(predictions2, "Species", "prediction")) -# Stop the SparkContext now -sparkR.stop() +# Stop the SparkSession now +sparkR.session.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala index af5a815f6ec76..a68fd0285f567 100644 --- a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala @@ -18,7 +18,8 @@ // scalastyle:off println package org.apache.spark.examples -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.SparkConf +import org.apache.spark.sql.SparkSession /** * Usage: BroadcastTest [slices] [numElem] [blockSize] @@ -28,9 +29,13 @@ object BroadcastTest { val blockSize = if (args.length > 2) args(2) else "4096" - val sparkConf = new SparkConf().setAppName("Broadcast Test") - .set("spark.broadcast.blockSize", blockSize) - val sc = new SparkContext(sparkConf) + val spark = SparkSession + .builder() + 
.appName("Broadcast Test") + .config("spark.broadcast.blockSize", blockSize) + .getOrCreate() + + val sc = spark.sparkContext val slices = if (args.length > 0) args(0).toInt else 2 val num = if (args.length > 1) args(1).toInt else 1000000 @@ -48,7 +53,7 @@ object BroadcastTest { println("Iteration %d took %.0f milliseconds".format(i, (System.nanoTime - startTime) / 1E6)) } - sc.stop() + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala b/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala index 7bf023667dcae..3bff7ce736d08 100644 --- a/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala @@ -22,7 +22,7 @@ import java.io.File import scala.io.Source._ -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.sql.SparkSession /** * Simple test for reading and writing to a distributed @@ -101,19 +101,19 @@ object DFSReadWriteTest { val fileContents = readFile(localFilePath.toString()) val localWordCount = runLocalWordCount(fileContents) - println("Creating SparkConf") - val conf = new SparkConf().setAppName("DFS Read Write Test") - - println("Creating SparkContext") - val sc = new SparkContext(conf) + println("Creating SparkSession") + val spark = SparkSession + .builder + .appName("DFS Read Write Test") + .getOrCreate() println("Writing local file to DFS") val dfsFilename = dfsDirPath + "/dfs_read_write_test" - val fileRDD = sc.parallelize(fileContents) + val fileRDD = spark.sparkContext.parallelize(fileContents) fileRDD.saveAsTextFile(dfsFilename) println("Reading file from DFS and running Word Count") - val readFileRDD = sc.textFile(dfsFilename) + val readFileRDD = spark.sparkContext.textFile(dfsFilename) val dfsWordCount = readFileRDD .flatMap(_.split(" ")) @@ -124,7 +124,7 @@ object DFSReadWriteTest { .values .sum - sc.stop() + spark.stop() if (localWordCount == dfsWordCount) { println(s"Success! 
Local Word Count ($localWordCount) " + diff --git a/examples/src/main/scala/org/apache/spark/examples/ExceptionHandlingTest.scala b/examples/src/main/scala/org/apache/spark/examples/ExceptionHandlingTest.scala index d42f63e87052e..45c4953a84be2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ExceptionHandlingTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ExceptionHandlingTest.scala @@ -17,18 +17,21 @@ package org.apache.spark.examples -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.sql.SparkSession object ExceptionHandlingTest { def main(args: Array[String]) { - val sparkConf = new SparkConf().setAppName("ExceptionHandlingTest") - val sc = new SparkContext(sparkConf) - sc.parallelize(0 until sc.defaultParallelism).foreach { i => + val spark = SparkSession + .builder + .appName("ExceptionHandlingTest") + .getOrCreate() + + spark.sparkContext.parallelize(0 until spark.sparkContext.defaultParallelism).foreach { i => if (math.random > 0.75) { throw new Exception("Testing exception handling") } } - sc.stop() + spark.stop() } } diff --git a/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala index 4db229b5dec32..2f2bbb1275438 100644 --- a/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala @@ -20,24 +20,26 @@ package org.apache.spark.examples import java.util.Random -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.sql.SparkSession /** * Usage: GroupByTest [numMappers] [numKVPairs] [KeySize] [numReducers] */ object GroupByTest { def main(args: Array[String]) { - val sparkConf = new SparkConf().setAppName("GroupBy Test") - var numMappers = if (args.length > 0) args(0).toInt else 2 - var numKVPairs = if (args.length > 1) args(1).toInt else 1000 - var valSize = if (args.length > 2) args(2).toInt else 1000 - var numReducers = if (args.length > 3) args(3).toInt else numMappers + val spark = SparkSession + .builder + .appName("GroupBy Test") + .getOrCreate() - val sc = new SparkContext(sparkConf) + val numMappers = if (args.length > 0) args(0).toInt else 2 + val numKVPairs = if (args.length > 1) args(1).toInt else 1000 + val valSize = if (args.length > 2) args(2).toInt else 1000 + val numReducers = if (args.length > 3) args(3).toInt else numMappers - val pairs1 = sc.parallelize(0 until numMappers, numMappers).flatMap { p => + val pairs1 = spark.sparkContext.parallelize(0 until numMappers, numMappers).flatMap { p => val ranGen = new Random - var arr1 = new Array[(Int, Array[Byte])](numKVPairs) + val arr1 = new Array[(Int, Array[Byte])](numKVPairs) for (i <- 0 until numKVPairs) { val byteArr = new Array[Byte](valSize) ranGen.nextBytes(byteArr) @@ -50,7 +52,7 @@ object GroupByTest { println(pairs1.groupByKey(numReducers).count()) - sc.stop() + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala b/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala index 124dc9af6390f..aa8de69839e28 100644 --- a/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala @@ -18,7 +18,7 @@ // scalastyle:off println package org.apache.spark.examples -import org.apache.spark._ +import org.apache.spark.sql.SparkSession object HdfsTest { @@ -29,9 +29,11 @@ object HdfsTest { System.err.println("Usage: HdfsTest ") 
System.exit(1) } - val sparkConf = new SparkConf().setAppName("HdfsTest") - val sc = new SparkContext(sparkConf) - val file = sc.textFile(args(0)) + val spark = SparkSession + .builder + .appName("HdfsTest") + .getOrCreate() + val file = spark.read.text(args(0)).rdd val mapped = file.map(s => s.length).cache() for (iter <- 1 to 10) { val start = System.currentTimeMillis() @@ -39,7 +41,7 @@ object HdfsTest { val end = System.currentTimeMillis() println("Iteration " + iter + " took " + (end-start) + " ms") } - sc.stop() + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala index fa1010195551a..97aefac025e55 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala @@ -24,7 +24,7 @@ import org.apache.commons.math3.linear._ * Alternating least squares matrix factorization. * * This is an example implementation for learning how to use Spark. For more conventional use, - * please refer to org.apache.spark.mllib.recommendation.ALS + * please refer to org.apache.spark.ml.recommendation.ALS. */ object LocalALS { @@ -96,7 +96,7 @@ object LocalALS { def showWarning() { System.err.println( """WARN: This is a naive implementation of ALS and is given as an example! - |Please use the ALS method found in org.apache.spark.mllib.recommendation + |Please use org.apache.spark.ml.recommendation.ALS |for more conventional use. """.stripMargin) } diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala b/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala index bec89f7c3dff0..3d02ce05619ad 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala @@ -26,8 +26,7 @@ import breeze.linalg.{DenseVector, Vector} * Logistic regression based classification. * * This is an example implementation for learning how to use Spark. For more conventional use, - * please refer to either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or - * org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS based on your needs. + * please refer to org.apache.spark.ml.classification.LogisticRegression. */ object LocalFileLR { val D = 10 // Number of dimensions @@ -43,8 +42,7 @@ object LocalFileLR { def showWarning() { System.err.println( """WARN: This is a naive implementation of Logistic Regression and is given as an example! - |Please use either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or - |org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS + |Please use org.apache.spark.ml.classification.LogisticRegression |for more conventional use. """.stripMargin) } diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala index f8961847f3df2..fca585c2a362b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala @@ -29,7 +29,7 @@ import breeze.linalg.{squaredDistance, DenseVector, Vector} * K-means clustering. * * This is an example implementation for learning how to use Spark. For more conventional use, - * please refer to org.apache.spark.mllib.clustering.KMeans + * please refer to org.apache.spark.ml.clustering.KMeans. 
*/ object LocalKMeans { val N = 1000 @@ -66,7 +66,7 @@ object LocalKMeans { def showWarning() { System.err.println( """WARN: This is a naive implementation of KMeans Clustering and is given as an example! - |Please use the KMeans method found in org.apache.spark.mllib.clustering + |Please use org.apache.spark.ml.clustering.KMeans |for more conventional use. """.stripMargin) } diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala index 0baf6db607ad9..13ccc2ae7c3d8 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala @@ -26,8 +26,7 @@ import breeze.linalg.{DenseVector, Vector} * Logistic regression based classification. * * This is an example implementation for learning how to use Spark. For more conventional use, - * please refer to either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or - * org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS based on your needs. + * please refer to org.apache.spark.ml.classification.LogisticRegression. */ object LocalLR { val N = 10000 // Number of data points @@ -50,8 +49,7 @@ object LocalLR { def showWarning() { System.err.println( """WARN: This is a naive implementation of Logistic Regression and is given as an example! - |Please use either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or - |org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS + |Please use org.apache.spark.ml.classification.LogisticRegression |for more conventional use. """.stripMargin) } diff --git a/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala index 3eb0c2772337a..6495a86fcd77c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala @@ -18,8 +18,9 @@ // scalastyle:off println package org.apache.spark.examples -import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SparkSession + /** * Usage: MultiBroadcastTest [slices] [numElem] @@ -27,8 +28,10 @@ import org.apache.spark.rdd.RDD object MultiBroadcastTest { def main(args: Array[String]) { - val sparkConf = new SparkConf().setAppName("Multi-Broadcast Test") - val sc = new SparkContext(sparkConf) + val spark = SparkSession + .builder + .appName("Multi-Broadcast Test") + .getOrCreate() val slices = if (args.length > 0) args(0).toInt else 2 val num = if (args.length > 1) args(1).toInt else 1000000 @@ -43,15 +46,15 @@ object MultiBroadcastTest { arr2(i) = i } - val barr1 = sc.broadcast(arr1) - val barr2 = sc.broadcast(arr2) - val observedSizes: RDD[(Int, Int)] = sc.parallelize(1 to 10, slices).map { _ => + val barr1 = spark.sparkContext.broadcast(arr1) + val barr2 = spark.sparkContext.broadcast(arr2) + val observedSizes: RDD[(Int, Int)] = spark.sparkContext.parallelize(1 to 10, slices).map { _ => (barr1.value.length, barr2.value.length) } // Collect the small RDD so we can print the observed sizes locally. 
observedSizes.collect().foreach(i => println(i)) - sc.stop() + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala index ec07e6323ee9a..8e1a574c92221 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala @@ -20,26 +20,27 @@ package org.apache.spark.examples import java.util.Random -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.sql.SparkSession /** * Usage: SimpleSkewedGroupByTest [numMappers] [numKVPairs] [valSize] [numReducers] [ratio] */ object SimpleSkewedGroupByTest { def main(args: Array[String]) { + val spark = SparkSession + .builder + .appName("SimpleSkewedGroupByTest") + .getOrCreate() - val sparkConf = new SparkConf().setAppName("SimpleSkewedGroupByTest") - var numMappers = if (args.length > 0) args(0).toInt else 2 - var numKVPairs = if (args.length > 1) args(1).toInt else 1000 - var valSize = if (args.length > 2) args(2).toInt else 1000 - var numReducers = if (args.length > 3) args(3).toInt else numMappers - var ratio = if (args.length > 4) args(4).toInt else 5.0 + val numMappers = if (args.length > 0) args(0).toInt else 2 + val numKVPairs = if (args.length > 1) args(1).toInt else 1000 + val valSize = if (args.length > 2) args(2).toInt else 1000 + val numReducers = if (args.length > 3) args(3).toInt else numMappers + val ratio = if (args.length > 4) args(4).toInt else 5.0 - val sc = new SparkContext(sparkConf) - - val pairs1 = sc.parallelize(0 until numMappers, numMappers).flatMap { p => + val pairs1 = spark.sparkContext.parallelize(0 until numMappers, numMappers).flatMap { p => val ranGen = new Random - var result = new Array[(Int, Array[Byte])](numKVPairs) + val result = new Array[(Int, Array[Byte])](numKVPairs) for (i <- 0 until numKVPairs) { val byteArr = new Array[Byte](valSize) ranGen.nextBytes(byteArr) @@ -64,7 +65,7 @@ object SimpleSkewedGroupByTest { // .map{case (k,v) => (k, v.size)} // .collectAsMap) - sc.stop() + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala index 8e4c2b6229755..4d3c34041bc17 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala @@ -20,28 +20,30 @@ package org.apache.spark.examples import java.util.Random -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.sql.SparkSession /** * Usage: GroupByTest [numMappers] [numKVPairs] [KeySize] [numReducers] */ object SkewedGroupByTest { def main(args: Array[String]) { - val sparkConf = new SparkConf().setAppName("GroupBy Test") - var numMappers = if (args.length > 0) args(0).toInt else 2 - var numKVPairs = if (args.length > 1) args(1).toInt else 1000 - var valSize = if (args.length > 2) args(2).toInt else 1000 - var numReducers = if (args.length > 3) args(3).toInt else numMappers + val spark = SparkSession + .builder + .appName("GroupBy Test") + .getOrCreate() - val sc = new SparkContext(sparkConf) + val numMappers = if (args.length > 0) args(0).toInt else 2 + var numKVPairs = if (args.length > 1) args(1).toInt else 1000 + val valSize = if (args.length > 2) args(2).toInt else 1000 + val numReducers = if 
(args.length > 3) args(3).toInt else numMappers - val pairs1 = sc.parallelize(0 until numMappers, numMappers).flatMap { p => + val pairs1 = spark.sparkContext.parallelize(0 until numMappers, numMappers).flatMap { p => val ranGen = new Random // map output sizes linearly increase from the 1st to the last numKVPairs = (1.0 * (p + 1) / numMappers * numKVPairs).toInt - var arr1 = new Array[(Int, Array[Byte])](numKVPairs) + val arr1 = new Array[(Int, Array[Byte])](numKVPairs) for (i <- 0 until numKVPairs) { val byteArr = new Array[Byte](valSize) ranGen.nextBytes(byteArr) @@ -54,7 +56,7 @@ object SkewedGroupByTest { println(pairs1.groupByKey(numReducers).count()) - sc.stop() + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala index 4263680c6fde3..8a3d08f459783 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala @@ -20,13 +20,13 @@ package org.apache.spark.examples import org.apache.commons.math3.linear._ -import org.apache.spark._ +import org.apache.spark.sql.SparkSession /** * Alternating least squares matrix factorization. * * This is an example implementation for learning how to use Spark. For more conventional use, - * please refer to org.apache.spark.mllib.recommendation.ALS + * please refer to org.apache.spark.ml.recommendation.ALS. */ object SparkALS { @@ -81,7 +81,7 @@ object SparkALS { def showWarning() { System.err.println( """WARN: This is a naive implementation of ALS and is given as an example! - |Please use the ALS method found in org.apache.spark.mllib.recommendation + |Please use org.apache.spark.ml.recommendation.ALS |for more conventional use. """.stripMargin) } @@ -108,8 +108,12 @@ object SparkALS { println(s"Running with M=$M, U=$U, F=$F, iters=$ITERATIONS") - val sparkConf = new SparkConf().setAppName("SparkALS") - val sc = new SparkContext(sparkConf) + val spark = SparkSession + .builder + .appName("SparkALS") + .getOrCreate() + + val sc = spark.sparkContext val R = generateR() @@ -135,7 +139,7 @@ object SparkALS { println() } - sc.stop() + spark.stop() } private def randomVector(n: Int): RealVector = diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala index 7463b868ff19b..05ac6cbcb35bc 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala @@ -23,16 +23,14 @@ import java.util.Random import scala.math.exp import breeze.linalg.{DenseVector, Vector} -import org.apache.hadoop.conf.Configuration -import org.apache.spark._ +import org.apache.spark.sql.SparkSession /** * Logistic regression based classification. * * This is an example implementation for learning how to use Spark. For more conventional use, - * please refer to either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or - * org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS based on your needs. + * please refer to org.apache.spark.ml.classification.LogisticRegression. */ object SparkHdfsLR { val D = 10 // Number of dimensions @@ -54,8 +52,7 @@ object SparkHdfsLR { def showWarning() { System.err.println( """WARN: This is a naive implementation of Logistic Regression and is given as an example! 
- |Please use either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or - |org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS + |Please use org.apache.spark.ml.classification.LogisticRegression |for more conventional use. """.stripMargin) } @@ -69,11 +66,14 @@ object SparkHdfsLR { showWarning() - val sparkConf = new SparkConf().setAppName("SparkHdfsLR") + val spark = SparkSession + .builder + .appName("SparkHdfsLR") + .getOrCreate() + val inputPath = args(0) - val conf = new Configuration() - val sc = new SparkContext(sparkConf) - val lines = sc.textFile(inputPath) + val lines = spark.read.textFile(inputPath).rdd + val points = lines.map(parsePoint).cache() val ITERATIONS = args(1).toInt @@ -90,7 +90,7 @@ object SparkHdfsLR { } println("Final w: " + w) - sc.stop() + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala index d9f94a42b1a0b..fec3160e9f37b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala @@ -20,13 +20,13 @@ package org.apache.spark.examples import breeze.linalg.{squaredDistance, DenseVector, Vector} -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.sql.SparkSession /** * K-means clustering. * * This is an example implementation for learning how to use Spark. For more conventional use, - * please refer to org.apache.spark.mllib.clustering.KMeans + * please refer to org.apache.spark.ml.clustering.KMeans. */ object SparkKMeans { @@ -52,7 +52,7 @@ object SparkKMeans { def showWarning() { System.err.println( """WARN: This is a naive implementation of KMeans Clustering and is given as an example! - |Please use the KMeans method found in org.apache.spark.mllib.clustering + |Please use org.apache.spark.ml.clustering.KMeans |for more conventional use. """.stripMargin) } @@ -66,14 +66,17 @@ object SparkKMeans { showWarning() - val sparkConf = new SparkConf().setAppName("SparkKMeans") - val sc = new SparkContext(sparkConf) - val lines = sc.textFile(args(0)) + val spark = SparkSession + .builder + .appName("SparkKMeans") + .getOrCreate() + + val lines = spark.read.textFile(args(0)).rdd val data = lines.map(parseVector _).cache() val K = args(1).toInt val convergeDist = args(2).toDouble - val kPoints = data.takeSample(withReplacement = false, K, 42).toArray + val kPoints = data.takeSample(withReplacement = false, K, 42) var tempDist = 1.0 while(tempDist > convergeDist) { @@ -97,7 +100,7 @@ object SparkKMeans { println("Final centers:") kPoints.foreach(println) - sc.stop() + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala index acd8656b65a69..afa8f58c96e59 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala @@ -24,15 +24,14 @@ import scala.math.exp import breeze.linalg.{DenseVector, Vector} -import org.apache.spark._ +import org.apache.spark.sql.SparkSession /** * Logistic regression based classification. * Usage: SparkLR [slices] * * This is an example implementation for learning how to use Spark. 
For more conventional use, - * please refer to either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or - * org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS based on your needs. + * please refer to org.apache.spark.ml.classification.LogisticRegression. */ object SparkLR { val N = 10000 // Number of data points @@ -55,8 +54,7 @@ object SparkLR { def showWarning() { System.err.println( """WARN: This is a naive implementation of Logistic Regression and is given as an example! - |Please use either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or - |org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS + |Please use org.apache.spark.ml.classification.LogisticRegression |for more conventional use. """.stripMargin) } @@ -65,10 +63,13 @@ object SparkLR { showWarning() - val sparkConf = new SparkConf().setAppName("SparkLR") - val sc = new SparkContext(sparkConf) + val spark = SparkSession + .builder + .appName("SparkLR") + .getOrCreate() + val numSlices = if (args.length > 0) args(0).toInt else 2 - val points = sc.parallelize(generateData, numSlices).cache() + val points = spark.sparkContext.parallelize(generateData, numSlices).cache() // Initialize w to a random value var w = DenseVector.fill(D) {2 * rand.nextDouble - 1} @@ -84,7 +85,7 @@ object SparkLR { println("Final w: " + w) - sc.stop() + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala index 2664ddbb87d23..d0b874c48d00a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala @@ -18,7 +18,7 @@ // scalastyle:off println package org.apache.spark.examples -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.sql.SparkSession /** * Computes the PageRank of URLs from an input file. 
Input file should @@ -50,10 +50,13 @@ object SparkPageRank { showWarning() - val sparkConf = new SparkConf().setAppName("PageRank") + val spark = SparkSession + .builder + .appName("SparkPageRank") + .getOrCreate() + val iters = if (args.length > 1) args(1).toInt else 10 - val ctx = new SparkContext(sparkConf) - val lines = ctx.textFile(args(0), 1) + val lines = spark.read.textFile(args(0)).rdd val links = lines.map{ s => val parts = s.split("\\s+") (parts(0), parts(1)) @@ -71,7 +74,7 @@ object SparkPageRank { val output = ranks.collect() output.foreach(tup => println(tup._1 + " has rank: " + tup._2 + ".")) - ctx.stop() + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala index 818d4f2b81f82..272c1a4fc2f47 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala @@ -20,21 +20,23 @@ package org.apache.spark.examples import scala.math.random -import org.apache.spark._ +import org.apache.spark.sql.SparkSession /** Computes an approximation to pi */ object SparkPi { def main(args: Array[String]) { - val conf = new SparkConf().setAppName("Spark Pi") - val spark = new SparkContext(conf) + val spark = SparkSession + .builder + .appName("Spark Pi") + .getOrCreate() val slices = if (args.length > 0) args(0).toInt else 2 val n = math.min(100000L * slices, Int.MaxValue).toInt // avoid overflow - val count = spark.parallelize(1 until n, slices).map { i => + val count = spark.sparkContext.parallelize(1 until n, slices).map { i => val x = random * 2 - 1 val y = random * 2 - 1 if (x*x + y*y < 1) 1 else 0 }.reduce(_ + _) - println("Pi is roughly " + 4.0 * count / n) + println("Pi is roughly " + 4.0 * count / (n - 1)) spark.stop() } } diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala index fc7a1f859f602..558295ab928af 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala @@ -21,7 +21,7 @@ package org.apache.spark.examples import scala.collection.mutable import scala.util.Random -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.sql.SparkSession /** * Transitive closure on a graph. @@ -42,10 +42,12 @@ object SparkTC { } def main(args: Array[String]) { - val sparkConf = new SparkConf().setAppName("SparkTC") - val spark = new SparkContext(sparkConf) + val spark = SparkSession + .builder + .appName("SparkTC") + .getOrCreate() val slices = if (args.length > 0) args(0).toInt else 2 - var tc = spark.parallelize(generateGraph, slices).cache() + var tc = spark.sparkContext.parallelize(generateGraph, slices).cache() // Linear transitive closure: each round grows paths by one edge, // by joining the graph's edges with the already-discovered paths. diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/AggregateMessagesExample.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/AggregateMessagesExample.scala new file mode 100644 index 0000000000000..8f8262db374b8 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/AggregateMessagesExample.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.graphx + +// $example on$ +import org.apache.spark.graphx.{Graph, VertexRDD} +import org.apache.spark.graphx.util.GraphGenerators +// $example off$ +import org.apache.spark.sql.SparkSession + +/** + * An example that uses the [`aggregateMessages`][Graph.aggregateMessages] operator to + * compute the average age of the more senior followers of each user. + * Run with + * {{{ + * bin/run-example graphx.AggregateMessagesExample + * }}} + */ +object AggregateMessagesExample { + + def main(args: Array[String]): Unit = { + // Creates a SparkSession. + val spark = SparkSession + .builder + .appName(s"${this.getClass.getSimpleName}") + .getOrCreate() + val sc = spark.sparkContext + + // $example on$ + // Create a graph with "age" as the vertex property. + // Here we use a random graph for simplicity. + val graph: Graph[Double, Int] = + GraphGenerators.logNormalGraph(sc, numVertices = 100).mapVertices( (id, _) => id.toDouble ) + // Compute the number of older followers and their total age + val olderFollowers: VertexRDD[(Int, Double)] = graph.aggregateMessages[(Int, Double)]( + triplet => { // Map Function + if (triplet.srcAttr > triplet.dstAttr) { + // Send message to destination vertex containing counter and age + triplet.sendToDst(1, triplet.srcAttr) + } + }, + // Add counter and age + (a, b) => (a._1 + b._1, a._2 + b._2) // Reduce Function + ) + // Divide total age by number of older followers to get average age of older followers + val avgAgeOfOlderFollowers: VertexRDD[Double] = + olderFollowers.mapValues( (id, value) => + value match { case (count, totalAge) => totalAge / count } ) + // Display the results + avgAgeOfOlderFollowers.collect.foreach(println(_)) + // $example off$ + + spark.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/ComprehensiveExample.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/ComprehensiveExample.scala new file mode 100644 index 0000000000000..6598863bd2ea0 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/ComprehensiveExample.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.graphx + +// $example on$ +import org.apache.spark.graphx.GraphLoader +// $example off$ +import org.apache.spark.sql.SparkSession + +/** + * Suppose I want to build a graph from some text files, restrict the graph + * to important relationships and users, run page-rank on the sub-graph, and + * then finally return attributes associated with the top users. + * This example does all of this in just a few lines with GraphX. + * + * Run with + * {{{ + * bin/run-example graphx.ComprehensiveExample + * }}} + */ +object ComprehensiveExample { + + def main(args: Array[String]): Unit = { + // Creates a SparkSession. + val spark = SparkSession + .builder + .appName(s"${this.getClass.getSimpleName}") + .getOrCreate() + val sc = spark.sparkContext + + // $example on$ + // Load my user data and parse into tuples of user id and attribute list + val users = (sc.textFile("data/graphx/users.txt") + .map(line => line.split(",")).map( parts => (parts.head.toLong, parts.tail) )) + + // Parse the edge data which is already in userId -> userId format + val followerGraph = GraphLoader.edgeListFile(sc, "data/graphx/followers.txt") + + // Attach the user attributes + val graph = followerGraph.outerJoinVertices(users) { + case (uid, deg, Some(attrList)) => attrList + // Some users may not have attributes so we set them as empty + case (uid, deg, None) => Array.empty[String] + } + + // Restrict the graph to users with usernames and names + val subgraph = graph.subgraph(vpred = (vid, attr) => attr.size == 2) + + // Compute the PageRank + val pagerankGraph = subgraph.pageRank(0.001) + + // Get the attributes of the top pagerank users + val userInfoWithPageRank = subgraph.outerJoinVertices(pagerankGraph.vertices) { + case (uid, attrList, Some(pr)) => (pr, attrList.toList) + case (uid, attrList, None) => (0.0, attrList.toList) + } + + println(userInfoWithPageRank.vertices.top(5)(Ordering.by(_._2._1)).mkString("\n")) + // $example off$ + + spark.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/ConnectedComponentsExample.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/ConnectedComponentsExample.scala new file mode 100644 index 0000000000000..5377ddb3594b0 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/ConnectedComponentsExample.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.graphx + +// $example on$ +import org.apache.spark.graphx.GraphLoader +// $example off$ +import org.apache.spark.sql.SparkSession + +/** + * A connected components algorithm example. + * The connected components algorithm labels each connected component of the graph + * with the ID of its lowest-numbered vertex. + * For example, in a social network, connected components can approximate clusters. + * GraphX contains an implementation of the algorithm in the + * [`ConnectedComponents` object][ConnectedComponents], + * and we compute the connected components of the example social network dataset. + * + * Run with + * {{{ + * bin/run-example graphx.ConnectedComponentsExample + * }}} + */ +object ConnectedComponentsExample { + def main(args: Array[String]): Unit = { + // Creates a SparkSession. + val spark = SparkSession + .builder + .appName(s"${this.getClass.getSimpleName}") + .getOrCreate() + val sc = spark.sparkContext + + // $example on$ + // Load the graph as in the PageRank example + val graph = GraphLoader.edgeListFile(sc, "data/graphx/followers.txt") + // Find the connected components + val cc = graph.connectedComponents().vertices + // Join the connected components with the usernames + val users = sc.textFile("data/graphx/users.txt").map { line => + val fields = line.split(",") + (fields(0).toLong, fields(1)) + } + val ccByUsername = users.join(cc).map { + case (id, (username, cc)) => (username, cc) + } + // Print the result + println(ccByUsername.collect().mkString("\n")) + // $example off$ + spark.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/PageRankExample.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/PageRankExample.scala new file mode 100644 index 0000000000000..9e9affca07a18 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/PageRankExample.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.graphx + +// $example on$ +import org.apache.spark.graphx.GraphLoader +// $example off$ +import org.apache.spark.sql.SparkSession + +/** + * A PageRank example on social network dataset + * Run with + * {{{ + * bin/run-example graphx.PageRankExample + * }}} + */ +object PageRankExample { + def main(args: Array[String]): Unit = { + // Creates a SparkSession. 
+ val spark = SparkSession + .builder + .appName(s"${this.getClass.getSimpleName}") + .getOrCreate() + val sc = spark.sparkContext + + // $example on$ + // Load the edges as a graph + val graph = GraphLoader.edgeListFile(sc, "data/graphx/followers.txt") + // Run PageRank + val ranks = graph.pageRank(0.0001).vertices + // Join the ranks with the usernames + val users = sc.textFile("data/graphx/users.txt").map { line => + val fields = line.split(",") + (fields(0).toLong, fields(1)) + } + val ranksByUsername = users.join(ranks).map { + case (id, (username, rank)) => (username, rank) + } + // Print the result + println(ranksByUsername.collect().mkString("\n")) + // $example off$ + spark.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/SSSPExample.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/SSSPExample.scala new file mode 100644 index 0000000000000..5e8b19671de7a --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/SSSPExample.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.graphx + +// $example on$ +import org.apache.spark.graphx.{Graph, VertexId} +import org.apache.spark.graphx.util.GraphGenerators +// $example off$ +import org.apache.spark.sql.SparkSession + +/** + * An example that uses the Pregel operator to express computation + * such as single source shortest path. + * Run with + * {{{ + * bin/run-example graphx.SSSPExample + * }}} + */ +object SSSPExample { + def main(args: Array[String]): Unit = { + // Creates a SparkSession. + val spark = SparkSession + .builder + .appName(s"${this.getClass.getSimpleName}") + .getOrCreate() + val sc = spark.sparkContext + + // $example on$ + // A graph with edge attributes containing distances + val graph: Graph[Long, Double] = + GraphGenerators.logNormalGraph(sc, numVertices = 100).mapEdges(e => e.attr.toDouble) + val sourceId: VertexId = 42 // The ultimate source + // Initialize the graph such that all vertices except the root have distance infinity.
+ val initialGraph = graph.mapVertices((id, _) => + if (id == sourceId) 0.0 else Double.PositiveInfinity) + val sssp = initialGraph.pregel(Double.PositiveInfinity)( + (id, dist, newDist) => math.min(dist, newDist), // Vertex Program + triplet => { // Send Message + if (triplet.srcAttr + triplet.attr < triplet.dstAttr) { + Iterator((triplet.dstId, triplet.srcAttr + triplet.attr)) + } else { + Iterator.empty + } + }, + (a, b) => math.min(a, b) // Merge Message + ) + println(sssp.vertices.collect.mkString("\n")) + // $example off$ + + spark.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/TriangleCountingExample.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/TriangleCountingExample.scala new file mode 100644 index 0000000000000..b9bff69086cc1 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/TriangleCountingExample.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.graphx + +// $example on$ +import org.apache.spark.graphx.{GraphLoader, PartitionStrategy} +// $example off$ +import org.apache.spark.sql.SparkSession + +/** + * A vertex is part of a triangle when it has two adjacent vertices with an edge between them. + * GraphX implements a triangle counting algorithm in the [`TriangleCount` object][TriangleCount] + * that determines the number of triangles passing through each vertex, + * providing a measure of clustering. + * We compute the triangle count of the social network dataset. + * + * Note that `TriangleCount` requires the edges to be in canonical orientation (`srcId < dstId`) + * and the graph to be partitioned using [`Graph.partitionBy`][Graph.partitionBy]. + * + * Run with + * {{{ + * bin/run-example graphx.TriangleCountingExample + * }}} + */ +object TriangleCountingExample { + def main(args: Array[String]): Unit = { + // Creates a SparkSession. 
+ val spark = SparkSession + .builder + .appName(s"${this.getClass.getSimpleName}") + .getOrCreate() + val sc = spark.sparkContext + + // $example on$ + // Load the edges in canonical order and partition the graph for triangle count + val graph = GraphLoader.edgeListFile(sc, "data/graphx/followers.txt", true) + .partitionBy(PartitionStrategy.RandomVertexCut) + // Find the triangle count for each vertex + val triCounts = graph.triangleCount().vertices + // Join the triangle counts with the usernames + val users = sc.textFile("data/graphx/users.txt").map { line => + val fields = line.split(",") + (fields(0).toLong, fields(1)) + } + val triCountByUsername = users.join(triCounts).map { case (id, (username, tc)) => + (username, tc) + } + // Print the result + println(triCountByUsername.collect().mkString("\n")) + // $example off$ + spark.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala index 21f58ddf3cfb7..b6d7b369162db 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala @@ -18,25 +18,29 @@ // scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.{SparkConf, SparkContext} // $example on$ +import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.regression.AFTSurvivalRegression -import org.apache.spark.mllib.linalg.Vectors // $example off$ -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession /** - * An example for AFTSurvivalRegression. + * An example demonstrating AFTSurvivalRegression. + * Run with + * {{{ + * bin/run-example ml.AFTSurvivalRegressionExample + * }}} */ object AFTSurvivalRegressionExample { def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("AFTSurvivalRegressionExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) + val spark = SparkSession + .builder + .appName("AFTSurvivalRegressionExample") + .getOrCreate() // $example on$ - val training = sqlContext.createDataFrame(Seq( + val training = spark.createDataFrame(Seq( (1.218, 1.0, Vectors.dense(1.560, -0.605)), (2.949, 0.0, Vectors.dense(0.346, 2.158)), (3.627, 0.0, Vectors.dense(1.380, 0.231)), @@ -56,7 +60,7 @@ object AFTSurvivalRegressionExample { model.transform(training).show(false) // $example off$ - sc.stop() + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala index a79e15c767e1f..bb5d163608494 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala @@ -18,39 +18,40 @@ // scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.evaluation.RegressionEvaluator import org.apache.spark.ml.recommendation.ALS // $example off$ -import org.apache.spark.sql.SQLContext -// $example on$ -import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.DoubleType -// $example off$ +import org.apache.spark.sql.SparkSession +/** + * An example demonstrating ALS. 
+ * Run with + * {{{ + * bin/run-example ml.ALSExample + * }}} + */ object ALSExample { // $example on$ case class Rating(userId: Int, movieId: Int, rating: Float, timestamp: Long) - object Rating { - def parseRating(str: String): Rating = { - val fields = str.split("::") - assert(fields.size == 4) - Rating(fields(0).toInt, fields(1).toInt, fields(2).toFloat, fields(3).toLong) - } + def parseRating(str: String): Rating = { + val fields = str.split("::") + assert(fields.size == 4) + Rating(fields(0).toInt, fields(1).toInt, fields(2).toFloat, fields(3).toLong) } // $example off$ def main(args: Array[String]) { - val conf = new SparkConf().setAppName("ALSExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - import sqlContext.implicits._ + val spark = SparkSession + .builder + .appName("ALSExample") + .getOrCreate() + import spark.implicits._ // $example on$ - val ratings = sc.textFile("data/mllib/als/sample_movielens_ratings.txt") - .map(Rating.parseRating) + val ratings = spark.read.textFile("data/mllib/als/sample_movielens_ratings.txt") + .map(parseRating) .toDF() val Array(training, test) = ratings.randomSplit(Array(0.8, 0.2)) @@ -65,8 +66,6 @@ object ALSExample { // Evaluate the model by computing the RMSE on the test data val predictions = model.transform(test) - .withColumn("rating", col("rating").cast(DoubleType)) - .withColumn("prediction", col("prediction").cast(DoubleType)) val evaluator = new RegressionEvaluator() .setMetricName("rmse") @@ -75,7 +74,8 @@ object ALSExample { val rmse = evaluator.evaluate(predictions) println(s"Root-mean-square error = $rmse") // $example off$ - sc.stop() + + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala index 2ed8101c133cf..5cd13ad64ca44 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala @@ -18,20 +18,20 @@ // scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.Binarizer // $example off$ -import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.{SparkSession} object BinarizerExample { def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("BinarizerExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) + val spark = SparkSession + .builder + .appName("BinarizerExample") + .getOrCreate() // $example on$ val data = Array((0, 0.1), (1, 0.8), (2, 0.2)) - val dataFrame: DataFrame = sqlContext.createDataFrame(data).toDF("label", "feature") + val dataFrame = spark.createDataFrame(data).toDF("label", "feature") val binarizer: Binarizer = new Binarizer() .setInputCol("feature") @@ -42,7 +42,8 @@ object BinarizerExample { val binarizedFeatures = binarizedDataFrame.select("binarized_feature") binarizedFeatures.collect().foreach(println) // $example off$ - sc.stop() + + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/BisectingKMeansExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/BisectingKMeansExample.scala new file mode 100644 index 0000000000000..5f8f2c99cbaf4 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/BisectingKMeansExample.scala @@ -0,0 +1,65 @@ +/* + * 
Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +// scalastyle:off println + +// $example on$ +import org.apache.spark.ml.clustering.BisectingKMeans +// $example off$ +import org.apache.spark.sql.SparkSession + +/** + * An example demonstrating bisecting k-means clustering. + * Run with + * {{{ + * bin/run-example ml.BisectingKMeansExample + * }}} + */ +object BisectingKMeansExample { + + def main(args: Array[String]): Unit = { + // Creates a SparkSession + val spark = SparkSession + .builder + .appName("BisectingKMeansExample") + .getOrCreate() + + // $example on$ + // Loads data. + val dataset = spark.read.format("libsvm").load("data/mllib/sample_kmeans_data.txt") + + // Trains a bisecting k-means model. + val bkm = new BisectingKMeans().setK(2).setSeed(1) + val model = bkm.fit(dataset) + + // Evaluate clustering. + val cost = model.computeCost(dataset) + println(s"Within Set Sum of Squared Errors = $cost") + + // Shows the result. + println("Cluster Centers: ") + val centers = model.clusterCenters + centers.foreach(println) + // $example off$ + + spark.stop() + } +} +// scalastyle:on println + diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala index 6f6236a2b0588..38cce34bb5091 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala @@ -18,23 +18,23 @@ // scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.Bucketizer // $example off$ -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession object BucketizerExample { def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("BucketizerExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) + val spark = SparkSession + .builder + .appName("BucketizerExample") + .getOrCreate() // $example on$ val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity) val data = Array(-0.5, -0.3, 0.0, 0.2) - val dataFrame = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") + val dataFrame = spark.createDataFrame(data.map(Tuple1.apply)).toDF("features") val bucketizer = new Bucketizer() .setInputCol("features") @@ -45,7 +45,7 @@ object BucketizerExample { val bucketedData = bucketizer.transform(dataFrame) bucketedData.show() // $example off$ - sc.stop() + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ChiSqSelectorExample.scala 
b/examples/src/main/scala/org/apache/spark/examples/ml/ChiSqSelectorExample.scala index 2be61537e613a..c9394dd9c64b8 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/ChiSqSelectorExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/ChiSqSelectorExample.scala @@ -18,20 +18,19 @@ // scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.ChiSqSelector -import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.ml.linalg.Vectors // $example off$ -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession object ChiSqSelectorExample { def main(args: Array[String]) { - val conf = new SparkConf().setAppName("ChiSqSelectorExample") - val sc = new SparkContext(conf) - - val sqlContext = SQLContext.getOrCreate(sc) - import sqlContext.implicits._ + val spark = SparkSession + .builder + .appName("ChiSqSelectorExample") + .getOrCreate() + import spark.implicits._ // $example on$ val data = Seq( @@ -40,7 +39,7 @@ object ChiSqSelectorExample { (9, Vectors.dense(1.0, 0.0, 15.0, 0.1), 0.0) ) - val df = sc.parallelize(data).toDF("id", "features", "clicked") + val df = spark.createDataset(data).toDF("id", "features", "clicked") val selector = new ChiSqSelector() .setNumTopFeatures(1) @@ -51,7 +50,7 @@ object ChiSqSelectorExample { val result = selector.fit(df).transform(df) result.show() // $example off$ - sc.stop() + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/CountVectorizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/CountVectorizerExample.scala index 7d07fc7dd113a..988d8941a4ce7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/CountVectorizerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/CountVectorizerExample.scala @@ -18,20 +18,20 @@ // scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel} // $example off$ -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession object CountVectorizerExample { def main(args: Array[String]) { - val conf = new SparkConf().setAppName("CounterVectorizerExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) + val spark = SparkSession + .builder + .appName("CountVectorizerExample") + .getOrCreate() // $example on$ - val df = sqlContext.createDataFrame(Seq( + val df = spark.createDataFrame(Seq( (0, Array("a", "b", "c")), (1, Array("a", "b", "b", "c", "a")) )).toDF("id", "words") @@ -51,6 +51,8 @@ object CountVectorizerExample { cvModel.transform(df).select("features").show() // $example off$ + + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DCTExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DCTExample.scala index dc26b55a768a7..ddc671752872b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DCTExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DCTExample.scala @@ -18,18 +18,18 @@ // scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.DCT -import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.ml.linalg.Vectors // $example off$ -import 
org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession object DCTExample { def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("DCTExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) + val spark = SparkSession + .builder + .appName("DCTExample") + .getOrCreate() // $example on$ val data = Seq( @@ -37,7 +37,7 @@ object DCTExample { Vectors.dense(-1.0, 2.0, 4.0, -7.0), Vectors.dense(14.0, -2.0, -5.0, 1.0)) - val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") + val df = spark.createDataFrame(data.map(Tuple1.apply)).toDF("features") val dct = new DCT() .setInputCol("features") @@ -47,7 +47,8 @@ object DCTExample { val dctDf = dct.transform(df) dctDf.select("featuresDCT").show(3) // $example off$ - sc.stop() + + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala index 7e608a281203e..38c1c1c1865b0 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala @@ -20,14 +20,14 @@ package org.apache.spark.examples.ml import java.io.File -import com.google.common.io.Files import scopt.OptionParser -import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.examples.mllib.AbstractParams -import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.ml.linalg.Vector +import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer -import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.{DataFrame, Row, SparkSession} +import org.apache.spark.util.Utils /** * An example of how to use [[org.apache.spark.sql.DataFrame]] for ML. Run with @@ -62,14 +62,14 @@ object DataFrameExample { } def run(params: Params) { - - val conf = new SparkConf().setAppName(s"DataFrameExample with $params") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) + val spark = SparkSession + .builder + .appName(s"DataFrameExample with $params") + .getOrCreate() // Load input data println(s"Loading LIBSVM file with UDT from ${params.input}.") - val df: DataFrame = sqlContext.read.format("libsvm").load(params.input).cache() + val df: DataFrame = spark.read.format("libsvm").load(params.input).cache() println("Schema from LIBSVM:") df.printSchema() println(s"Loaded training data as a DataFrame with ${df.count()} records.") @@ -81,24 +81,23 @@ object DataFrameExample { // Convert features column to an RDD of vectors. val features = df.select("features").rdd.map { case Row(v: Vector) => v } val featureSummary = features.aggregate(new MultivariateOnlineSummarizer())( - (summary, feat) => summary.add(feat), + (summary, feat) => summary.add(Vectors.fromML(feat)), (sum1, sum2) => sum1.merge(sum2)) println(s"Selected features column with average values:\n ${featureSummary.mean.toString}") // Save the records in a parquet file. - val tmpDir = Files.createTempDir() - tmpDir.deleteOnExit() + val tmpDir = Utils.createTempDir() val outputDir = new File(tmpDir, "dataframe").toString println(s"Saving to $outputDir as Parquet file.") df.write.parquet(outputDir) // Load the records back. 
println(s"Loading Parquet file with UDT from $outputDir.") - val newDF = sqlContext.read.parquet(outputDir) + val newDF = spark.read.parquet(outputDir) println(s"Schema from Parquet:") newDF.printSchema() - sc.stop() + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala index 224d8da5f0ec3..bc6d3275933ea 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala @@ -18,7 +18,6 @@ // scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.Pipeline import org.apache.spark.ml.classification.DecisionTreeClassificationModel @@ -26,16 +25,17 @@ import org.apache.spark.ml.classification.DecisionTreeClassifier import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer} // $example off$ -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession object DecisionTreeClassificationExample { def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("DecisionTreeClassificationExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) + val spark = SparkSession + .builder + .appName("DecisionTreeClassificationExample") + .getOrCreate() // $example on$ // Load the data stored in LIBSVM format as a DataFrame. - val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + val data = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") // Index labels, adding metadata to the label column. // Fit on whole dataset to include all labels in index. @@ -47,10 +47,10 @@ object DecisionTreeClassificationExample { val featureIndexer = new VectorIndexer() .setInputCol("features") .setOutputCol("indexedFeatures") - .setMaxCategories(4) // features with > 4 distinct values are treated as continuous + .setMaxCategories(4) // features with > 4 distinct values are treated as continuous. .fit(data) - // Split the data into training and test sets (30% held out for testing) + // Split the data into training and test sets (30% held out for testing). val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) // Train a DecisionTree model. @@ -64,11 +64,11 @@ object DecisionTreeClassificationExample { .setOutputCol("predictedLabel") .setLabels(labelIndexer.labels) - // Chain indexers and tree in a Pipeline + // Chain indexers and tree in a Pipeline. val pipeline = new Pipeline() .setStages(Array(labelIndexer, featureIndexer, dt, labelConverter)) - // Train model. This also runs the indexers. + // Train model. This also runs the indexers. val model = pipeline.fit(trainingData) // Make predictions. @@ -77,17 +77,19 @@ object DecisionTreeClassificationExample { // Select example rows to display. predictions.select("predictedLabel", "label", "features").show(5) - // Select (prediction, true label) and compute test error + // Select (prediction, true label) and compute test error. 
val evaluator = new MulticlassClassificationEvaluator() .setLabelCol("indexedLabel") .setPredictionCol("prediction") - .setMetricName("precision") + .setMetricName("accuracy") val accuracy = evaluator.evaluate(predictions) println("Test Error = " + (1.0 - accuracy)) val treeModel = model.stages(2).asInstanceOf[DecisionTreeClassificationModel] println("Learned classification tree model:\n" + treeModel.toDebugString) // $example off$ + + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala index d2560cc00ba07..de4474555d2d3 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala @@ -23,24 +23,23 @@ import scala.language.reflectiveCalls import scopt.OptionParser -import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.examples.mllib.AbstractParams import org.apache.spark.ml.{Pipeline, PipelineStage, Transformer} import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier} import org.apache.spark.ml.feature.{StringIndexer, VectorIndexer} +import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor} import org.apache.spark.ml.util.MetadataUtils import org.apache.spark.mllib.evaluation.{MulticlassMetrics, RegressionMetrics} -import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.{DataFrame, SparkSession} /** * An example runner for decision trees. Run with * {{{ * ./bin/run-example ml.DecisionTreeExample [options] * }}} - * Note that Decision Trees can take a large amount of memory. If the run-example command above + * Note that Decision Trees can take a large amount of memory. If the run-example command above * fails, try running via spark-submit and specifying the amount of memory as at least 1g. * For local mode, run * {{{ @@ -87,7 +86,7 @@ object DecisionTreeExample { .text(s"min info gain required to create a split, default: ${defaultParams.minInfoGain}") .action((x, c) => c.copy(minInfoGain = x)) opt[Double]("fracTest") - .text(s"fraction of data to hold out for testing. If given option testInput, " + + .text(s"fraction of data to hold out for testing. If given option testInput, " + s"this option is ignored. default: ${defaultParams.fracTest}") .action((x, c) => c.copy(fracTest = x)) opt[Boolean]("cacheNodeIds") @@ -106,7 +105,7 @@ object DecisionTreeExample { s"default: ${defaultParams.checkpointInterval}") .action((x, c) => c.copy(checkpointInterval = x)) opt[String]("testInput") - .text(s"input path to test dataset. If given, option fracTest is ignored." + + .text(s"input path to test dataset. If given, option fracTest is ignored." 
+ s" default: ${defaultParams.testInput}") .action((x, c) => c.copy(testInput = x)) opt[String]("dataFormat") @@ -134,18 +133,18 @@ object DecisionTreeExample { /** Load a dataset from the given path, using the given format */ private[ml] def loadData( - sqlContext: SQLContext, + spark: SparkSession, path: String, format: String, expectedNumFeatures: Option[Int] = None): DataFrame = { - import sqlContext.implicits._ + import spark.implicits._ format match { - case "dense" => MLUtils.loadLabeledPoints(sqlContext.sparkContext, path).toDF() + case "dense" => MLUtils.loadLabeledPoints(spark.sparkContext, path).toDF() case "libsvm" => expectedNumFeatures match { - case Some(numFeatures) => sqlContext.read.option("numFeatures", numFeatures.toString) + case Some(numFeatures) => spark.read.option("numFeatures", numFeatures.toString) .format("libsvm").load(path) - case None => sqlContext.read.format("libsvm").load(path) + case None => spark.read.format("libsvm").load(path) } case _ => throw new IllegalArgumentException(s"Bad data format: $format") } @@ -157,27 +156,28 @@ object DecisionTreeExample { * @param dataFormat "libsvm" or "dense" * @param testInput Path to test dataset. * @param algo Classification or Regression - * @param fracTest Fraction of input data to hold out for testing. Ignored if testInput given. + * @param fracTest Fraction of input data to hold out for testing. Ignored if testInput given. * @return (training dataset, test dataset) */ private[ml] def loadDatasets( - sc: SparkContext, input: String, dataFormat: String, testInput: String, algo: String, fracTest: Double): (DataFrame, DataFrame) = { - val sqlContext = new SQLContext(sc) + val spark = SparkSession + .builder + .getOrCreate() // Load training data - val origExamples: DataFrame = loadData(sqlContext, input, dataFormat) + val origExamples: DataFrame = loadData(spark, input, dataFormat) // Load or create test set val dataframes: Array[DataFrame] = if (testInput != "") { // Load testInput. val numFeatures = origExamples.first().getAs[Vector](1).size val origTestExamples: DataFrame = - loadData(sqlContext, testInput, dataFormat, Some(numFeatures)) + loadData(spark, testInput, dataFormat, Some(numFeatures)) Array(origExamples, origTestExamples) } else { // Split input into training, test. @@ -198,18 +198,21 @@ object DecisionTreeExample { } def run(params: Params) { - val conf = new SparkConf().setAppName(s"DecisionTreeExample with $params") - val sc = new SparkContext(conf) - params.checkpointDir.foreach(sc.setCheckpointDir) + val spark = SparkSession + .builder + .appName(s"DecisionTreeExample with $params") + .getOrCreate() + + params.checkpointDir.foreach(spark.sparkContext.setCheckpointDir) val algo = params.algo.toLowerCase println(s"DecisionTreeExample with parameters:\n$params") // Load training and test data and cache it. val (training: DataFrame, test: DataFrame) = - loadDatasets(sc, params.input, params.dataFormat, params.testInput, algo, params.fracTest) + loadDatasets(params.input, params.dataFormat, params.testInput, algo, params.fracTest) - // Set up Pipeline + // Set up Pipeline. val stages = new mutable.ArrayBuffer[PipelineStage]() // (1) For classification, re-index classes. val labelColName = if (algo == "classification") "indexedLabel" else "label" @@ -226,7 +229,7 @@ object DecisionTreeExample { .setOutputCol("indexedFeatures") .setMaxCategories(10) stages += featuresIndexer - // (3) Learn Decision Tree + // (3) Learn Decision Tree. 
val dt = algo match { case "classification" => new DecisionTreeClassifier() @@ -253,13 +256,13 @@ object DecisionTreeExample { stages += dt val pipeline = new Pipeline().setStages(stages.toArray) - // Fit the Pipeline + // Fit the Pipeline. val startTime = System.nanoTime() val pipelineModel = pipeline.fit(training) val elapsedTime = (System.nanoTime() - startTime) / 1e9 println(s"Training time: $elapsedTime seconds") - // Get the trained Decision Tree from the fitted PipelineModel + // Get the trained Decision Tree from the fitted PipelineModel. algo match { case "classification" => val treeModel = pipelineModel.stages.last.asInstanceOf[DecisionTreeClassificationModel] @@ -278,7 +281,7 @@ object DecisionTreeExample { case _ => throw new IllegalArgumentException("Algo ${params.algo} not supported.") } - // Evaluate model on training, test data + // Evaluate model on training, test data. algo match { case "classification" => println("Training data results:") @@ -294,11 +297,11 @@ object DecisionTreeExample { throw new IllegalArgumentException("Algo ${params.algo} not supported.") } - sc.stop() + spark.stop() } /** - * Evaluate the given ClassificationModel on data. Print the results. + * Evaluate the given ClassificationModel on data. Print the results. * @param model Must fit ClassificationModel abstraction * @param data DataFrame with "prediction" and labelColName columns * @param labelColName Name of the labelCol parameter for the model @@ -312,18 +315,18 @@ object DecisionTreeExample { val fullPredictions = model.transform(data).cache() val predictions = fullPredictions.select("prediction").rdd.map(_.getDouble(0)) val labels = fullPredictions.select(labelColName).rdd.map(_.getDouble(0)) - // Print number of classes for reference + // Print number of classes for reference. val numClasses = MetadataUtils.getNumClasses(fullPredictions.schema(labelColName)) match { case Some(n) => n case None => throw new RuntimeException( "Unknown failure when indexing labels for classification.") } - val accuracy = new MulticlassMetrics(predictions.zip(labels)).precision + val accuracy = new MulticlassMetrics(predictions.zip(labels)).accuracy println(s" Accuracy ($numClasses classes): $accuracy") } /** - * Evaluate the given RegressionModel on data. Print the results. + * Evaluate the given RegressionModel on data. Print the results. 
* @param model Must fit RegressionModel abstraction * @param data DataFrame with "prediction" and labelColName columns * @param labelColName Name of the labelCol parameter for the model diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala index ad32e5635a3ea..ee61200ad1d0c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala @@ -18,7 +18,6 @@ // scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.Pipeline import org.apache.spark.ml.evaluation.RegressionEvaluator @@ -26,17 +25,18 @@ import org.apache.spark.ml.feature.VectorIndexer import org.apache.spark.ml.regression.DecisionTreeRegressionModel import org.apache.spark.ml.regression.DecisionTreeRegressor // $example off$ -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession object DecisionTreeRegressionExample { def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("DecisionTreeRegressionExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) + val spark = SparkSession + .builder + .appName("DecisionTreeRegressionExample") + .getOrCreate() // $example on$ // Load the data stored in LIBSVM format as a DataFrame. - val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + val data = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") // Automatically identify categorical features, and index them. // Here, we treat features with > 4 distinct values as continuous. @@ -46,7 +46,7 @@ object DecisionTreeRegressionExample { .setMaxCategories(4) .fit(data) - // Split the data into training and test sets (30% held out for testing) + // Split the data into training and test sets (30% held out for testing). val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) // Train a DecisionTree model. @@ -54,11 +54,11 @@ object DecisionTreeRegressionExample { .setLabelCol("label") .setFeaturesCol("indexedFeatures") - // Chain indexer and tree in a Pipeline + // Chain indexer and tree in a Pipeline. val pipeline = new Pipeline() .setStages(Array(featureIndexer, dt)) - // Train model. This also runs the indexer. + // Train model. This also runs the indexer. val model = pipeline.fit(trainingData) // Make predictions. @@ -67,7 +67,7 @@ object DecisionTreeRegressionExample { // Select example rows to display. predictions.select("prediction", "label", "features").show(5) - // Select (prediction, true label) and compute test error + // Select (prediction, true label) and compute test error. 
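The regression example that follows scores its predictions with a RegressionEvaluator; the metric line itself is cut off by the hunk boundary. A minimal sketch, assuming a predictions DataFrame with "label" and "prediction" columns as produced by the pipeline:

```scala
// Minimal RMSE evaluation sketch; column names match the defaults used in the example.
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.sql.DataFrame

object RmseSketch {
  def rmse(predictions: DataFrame): Double =
    new RegressionEvaluator()
      .setLabelCol("label")
      .setPredictionCol("prediction")
      .setMetricName("rmse") // "mse", "mae" and "r2" are also available
      .evaluate(predictions)
}
```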
val evaluator = new RegressionEvaluator() .setLabelCol("label") .setPredictionCol("prediction") @@ -78,6 +78,8 @@ object DecisionTreeRegressionExample { val treeModel = model.stages(1).asInstanceOf[DecisionTreeRegressionModel] println("Learned regression tree model:\n" + treeModel.toDebugString) // $example off$ + + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala index 8d127f9b35420..d94d837d10e96 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala @@ -18,13 +18,12 @@ // scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.ml.classification.{ClassificationModel, Classifier, ClassifierParams} +import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors} import org.apache.spark.ml.param.{IntParam, ParamMap} import org.apache.spark.ml.util.Identifiable -import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors} -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext} +import org.apache.spark.sql.{Dataset, Row, SparkSession} /** * A simple example demonstrating how to write your own learning algorithm using Estimator, @@ -38,19 +37,20 @@ import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext} object DeveloperApiExample { def main(args: Array[String]) { - val conf = new SparkConf().setAppName("DeveloperApiExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - import sqlContext.implicits._ + val spark = SparkSession + .builder + .appName("DeveloperApiExample") + .getOrCreate() + import spark.implicits._ // Prepare training data. - val training = sc.parallelize(Seq( + val training = spark.createDataFrame(Seq( LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)), LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)), LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5)))) - // Create a LogisticRegression instance. This instance is an Estimator. + // Create a LogisticRegression instance. This instance is an Estimator. val lr = new MyLogisticRegression() // Print out the parameters, documentation, and any default values. println("MyLogisticRegression parameters:\n" + lr.explainParams() + "\n") @@ -58,17 +58,17 @@ object DeveloperApiExample { // We may set parameters using setter methods. lr.setMaxIter(10) - // Learn a LogisticRegression model. This uses the parameters stored in lr. + // Learn a LogisticRegression model. This uses the parameters stored in lr. val model = lr.fit(training.toDF()) // Prepare test data. - val test = sc.parallelize(Seq( + val test = spark.createDataFrame(Seq( LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)))) // Make predictions on test data. 
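In the DeveloperApiExample hunk above, the training and test sets are now built with spark.createDataFrame over the new org.apache.spark.ml.feature.LabeledPoint instead of sc.parallelize(...).toDF(). A self-contained sketch of that construction, reusing values from the example:

```scala
// Sketch of building a DataFrame directly from a local Seq of the new
// ml.feature.LabeledPoint; no RDD detour or implicits import is required.
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

object CreateDataFrameSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder.appName("CreateDataFrameSketch").getOrCreate()
    val test = spark.createDataFrame(Seq(
      LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
      LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1))
    ))
    test.show()
    spark.stop()
  }
}
```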
- val sumPredictions: Double = model.transform(test.toDF()) + val sumPredictions: Double = model.transform(test) .select("features", "label", "prediction") .collect() .map { case Row(features: Vector, label: Double, prediction: Double) => @@ -77,14 +77,14 @@ object DeveloperApiExample { assert(sumPredictions == 0.0, "MyLogisticRegression predicted something other than 0, even though all coefficients are 0!") - sc.stop() + spark.stop() } } /** * Example of defining a parameter trait for a user-defined type of [[Classifier]]. * - * NOTE: This is private since it is an example. In practice, you may not want it to be private. + * NOTE: This is private since it is an example. In practice, you may not want it to be private. */ private trait MyLogisticRegressionParams extends ClassifierParams { @@ -96,7 +96,7 @@ private trait MyLogisticRegressionParams extends ClassifierParams { * - def getMyParamName * - def setMyParamName * Here, we have a trait to be mixed in with the Estimator and Model (MyLogisticRegression - * and MyLogisticRegressionModel). We place the setter (setMaxIter) method in the Estimator + * and MyLogisticRegressionModel). We place the setter (setMaxIter) method in the Estimator * class since the maxIter parameter is only used during training (not in the Model). */ val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations") @@ -106,7 +106,7 @@ private trait MyLogisticRegressionParams extends ClassifierParams { /** * Example of defining a type of [[Classifier]]. * - * NOTE: This is private since it is an example. In practice, you may not want it to be private. + * NOTE: This is private since it is an example. In practice, you may not want it to be private. */ private class MyLogisticRegression(override val uid: String) extends Classifier[Vector, MyLogisticRegression, MyLogisticRegressionModel] @@ -138,7 +138,7 @@ private class MyLogisticRegression(override val uid: String) /** * Example of defining a type of [[ClassificationModel]]. * - * NOTE: This is private since it is an example. In practice, you may not want it to be private. + * NOTE: This is private since it is an example. In practice, you may not want it to be private. */ private class MyLogisticRegressionModel( override val uid: String, @@ -169,7 +169,7 @@ private class MyLogisticRegressionModel( Vectors.dense(-margin, margin) } - /** Number of classes the label can take. 2 indicates binary classification. */ + /** Number of classes the label can take. 2 indicates binary classification. */ override val numClasses: Int = 2 /** Number of features the model was trained on. 
*/ diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ElementwiseProductExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ElementwiseProductExample.scala index 629d322c4357f..c0ffc01934b6f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/ElementwiseProductExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/ElementwiseProductExample.scala @@ -18,22 +18,22 @@ // scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.ElementwiseProduct -import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.ml.linalg.Vectors // $example off$ -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession object ElementwiseProductExample { def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("ElementwiseProductExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) + val spark = SparkSession + .builder + .appName("ElementwiseProductExample") + .getOrCreate() // $example on$ // Create some vector data; also works for sparse vectors - val dataFrame = sqlContext.createDataFrame(Seq( + val dataFrame = spark.createDataFrame(Seq( ("a", Vectors.dense(1.0, 2.0, 3.0)), ("b", Vectors.dense(4.0, 5.0, 6.0)))).toDF("id", "vector") @@ -46,7 +46,8 @@ object ElementwiseProductExample { // Batch transform the vectors to create new column: transformer.transform(dataFrame).show() // $example off$ - sc.stop() + + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/EstimatorTransformerParamExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/EstimatorTransformerParamExample.scala index 65e3c365abb3f..f18d86e1a6921 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/EstimatorTransformerParamExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/EstimatorTransformerParamExample.scala @@ -18,32 +18,32 @@ // scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.classification.LogisticRegression +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.sql.Row // $example off$ -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession object EstimatorTransformerParamExample { def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("EstimatorTransformerParamExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) + val spark = SparkSession + .builder + .appName("EstimatorTransformerParamExample") + .getOrCreate() // $example on$ // Prepare training data from a list of (label, features) tuples. - val training = sqlContext.createDataFrame(Seq( + val training = spark.createDataFrame(Seq( (1.0, Vectors.dense(0.0, 1.1, 0.1)), (0.0, Vectors.dense(2.0, 1.0, -1.0)), (0.0, Vectors.dense(2.0, 1.3, 1.0)), (1.0, Vectors.dense(0.0, 1.2, -0.5)) )).toDF("label", "features") - // Create a LogisticRegression instance. This instance is an Estimator. + // Create a LogisticRegression instance. This instance is an Estimator. val lr = new LogisticRegression() // Print out the parameters, documentation, and any default values. 
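Further down, EstimatorTransformerParamExample overrides fitting parameters with a ParamMap. A compact sketch of that mechanism; the estimator and the override values here are illustrative only:

```scala
// Compact sketch of ParamMap overrides: later put() calls win, and several
// parameters can be set in one call.
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.param.ParamMap

object ParamMapSketch {
  def main(args: Array[String]): Unit = {
    val lr = new LogisticRegression()
    val paramMap = ParamMap(lr.maxIter -> 20)
      .put(lr.maxIter, 30)                           // overwrites the earlier maxIter
      .put(lr.regParam -> 0.1, lr.threshold -> 0.55) // several params at once
    println("Overrides: " + paramMap)
    println("Defaults:\n" + lr.explainParams())
  }
}
```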
println("LogisticRegression parameters:\n" + lr.explainParams() + "\n") @@ -52,7 +52,7 @@ object EstimatorTransformerParamExample { lr.setMaxIter(10) .setRegParam(0.01) - // Learn a LogisticRegression model. This uses the parameters stored in lr. + // Learn a LogisticRegression model. This uses the parameters stored in lr. val model1 = lr.fit(training) // Since model1 is a Model (i.e., a Transformer produced by an Estimator), // we can view the parameters it used during fit(). @@ -63,11 +63,11 @@ object EstimatorTransformerParamExample { // We may alternatively specify parameters using a ParamMap, // which supports several methods for specifying parameters. val paramMap = ParamMap(lr.maxIter -> 20) - .put(lr.maxIter, 30) // Specify 1 Param. This overwrites the original maxIter. + .put(lr.maxIter, 30) // Specify 1 Param. This overwrites the original maxIter. .put(lr.regParam -> 0.1, lr.threshold -> 0.55) // Specify multiple Params. // One can also combine ParamMaps. - val paramMap2 = ParamMap(lr.probabilityCol -> "myProbability") // Change output column name + val paramMap2 = ParamMap(lr.probabilityCol -> "myProbability") // Change output column name. val paramMapCombined = paramMap ++ paramMap2 // Now learn a new model using the paramMapCombined parameters. @@ -76,7 +76,7 @@ object EstimatorTransformerParamExample { println("Model 2 was fit using parameters: " + model2.parent.extractParamMap) // Prepare test data. - val test = sqlContext.createDataFrame(Seq( + val test = spark.createDataFrame(Seq( (1.0, Vectors.dense(-1.0, 1.5, 1.3)), (0.0, Vectors.dense(3.0, 2.0, -0.1)), (1.0, Vectors.dense(0.0, 2.2, -1.5)) @@ -94,7 +94,7 @@ object EstimatorTransformerParamExample { } // $example off$ - sc.stop() + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala index 6b0be0f34e196..a4274ae95405e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala @@ -23,13 +23,12 @@ import scala.language.reflectiveCalls import scopt.OptionParser -import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.examples.mllib.AbstractParams import org.apache.spark.ml.{Pipeline, PipelineStage} import org.apache.spark.ml.classification.{GBTClassificationModel, GBTClassifier} import org.apache.spark.ml.feature.{StringIndexer, VectorIndexer} import org.apache.spark.ml.regression.{GBTRegressionModel, GBTRegressor} -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, SparkSession} /** @@ -37,7 +36,7 @@ import org.apache.spark.sql.DataFrame * {{{ * ./bin/run-example ml.GBTExample [options] * }}} - * Decision Trees and ensembles can take a large amount of memory. If the run-example command + * Decision Trees and ensembles can take a large amount of memory. If the run-example command * above fails, try running via spark-submit and specifying the amount of memory as at least 1g. * For local mode, run * {{{ @@ -88,7 +87,7 @@ object GBTExample { .text(s"number of trees in ensemble, default: ${defaultParams.maxIter}") .action((x, c) => c.copy(maxIter = x)) opt[Double]("fracTest") - .text(s"fraction of data to hold out for testing. If given option testInput, " + + .text(s"fraction of data to hold out for testing. If given option testInput, " + s"this option is ignored. 
default: ${defaultParams.fracTest}") .action((x, c) => c.copy(fracTest = x)) opt[Boolean]("cacheNodeIds") @@ -109,7 +108,7 @@ object GBTExample { s"default: ${defaultParams.checkpointInterval}") .action((x, c) => c.copy(checkpointInterval = x)) opt[String]("testInput") - .text(s"input path to test dataset. If given, option fracTest is ignored." + + .text(s"input path to test dataset. If given, option fracTest is ignored." + s" default: ${defaultParams.testInput}") .action((x, c) => c.copy(testInput = x)) opt[String]("dataFormat") @@ -136,15 +135,18 @@ object GBTExample { } def run(params: Params) { - val conf = new SparkConf().setAppName(s"GBTExample with $params") - val sc = new SparkContext(conf) - params.checkpointDir.foreach(sc.setCheckpointDir) + val spark = SparkSession + .builder + .appName(s"GBTExample with $params") + .getOrCreate() + + params.checkpointDir.foreach(spark.sparkContext.setCheckpointDir) val algo = params.algo.toLowerCase println(s"GBTExample with parameters:\n$params") // Load training and test data and cache it. - val (training: DataFrame, test: DataFrame) = DecisionTreeExample.loadDatasets(sc, params.input, + val (training: DataFrame, test: DataFrame) = DecisionTreeExample.loadDatasets(params.input, params.dataFormat, params.testInput, algo, params.fracTest) // Set up Pipeline @@ -164,7 +166,7 @@ object GBTExample { .setOutputCol("indexedFeatures") .setMaxCategories(10) stages += featuresIndexer - // (3) Learn GBT + // (3) Learn GBT. val dt = algo match { case "classification" => new GBTClassifier() @@ -193,13 +195,13 @@ object GBTExample { stages += dt val pipeline = new Pipeline().setStages(stages.toArray) - // Fit the Pipeline + // Fit the Pipeline. val startTime = System.nanoTime() val pipelineModel = pipeline.fit(training) val elapsedTime = (System.nanoTime() - startTime) / 1e9 println(s"Training time: $elapsedTime seconds") - // Get the trained GBT from the fitted PipelineModel + // Get the trained GBT from the fitted PipelineModel. algo match { case "classification" => val rfModel = pipelineModel.stages.last.asInstanceOf[GBTClassificationModel] @@ -218,7 +220,7 @@ object GBTExample { case _ => throw new IllegalArgumentException("Algo ${params.algo} not supported.") } - // Evaluate model on training, test data + // Evaluate model on training, test data. algo match { case "classification" => println("Training data results:") @@ -234,7 +236,7 @@ object GBTExample { throw new IllegalArgumentException("Algo ${params.algo} not supported.") } - sc.stop() + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GaussianMixtureExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GaussianMixtureExample.scala new file mode 100644 index 0000000000000..2c2bf421bc5d3 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/GaussianMixtureExample.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +// scalastyle:off println + +// $example on$ +import org.apache.spark.ml.clustering.GaussianMixture +// $example off$ +import org.apache.spark.sql.SparkSession + +/** + * An example demonstrating Gaussian Mixture Model (GMM). + * Run with + * {{{ + * bin/run-example ml.GaussianMixtureExample + * }}} + */ +object GaussianMixtureExample { + def main(args: Array[String]): Unit = { + // Creates a SparkSession + val spark = SparkSession.builder.appName(s"${this.getClass.getSimpleName}").getOrCreate() + + // $example on$ + // Loads data + val dataset = spark.read.format("libsvm").load("data/mllib/sample_kmeans_data.txt") + + // Trains a Gaussian Mixture Model + val gmm = new GaussianMixture() + .setK(2) + val model = gmm.fit(dataset) + + // Output the parameters of the mixture model + for (i <- 0 until model.getK) { + println("weight=%f\nmu=%s\nsigma=\n%s\n" format + (model.weights(i), model.gaussians(i).mean, model.gaussians(i).cov)) + } + // $example off$ + + spark.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GeneralizedLinearRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GeneralizedLinearRegressionExample.scala new file mode 100644 index 0000000000000..1b86d7cad0b38 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/GeneralizedLinearRegressionExample.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.regression.GeneralizedLinearRegression +// $example off$ +import org.apache.spark.sql.SparkSession + +/** + * An example demonstrating generalized linear regression.
+ * Run with + * {{{ + * bin/run-example ml.GeneralizedLinearRegressionExample + * }}} + */ + +object GeneralizedLinearRegressionExample { + + def main(args: Array[String]): Unit = { + val spark = SparkSession + .builder + .appName("GeneralizedLinearRegressionExample") + .getOrCreate() + + // $example on$ + // Load training data + val dataset = spark.read.format("libsvm") + .load("data/mllib/sample_linear_regression_data.txt") + + val glr = new GeneralizedLinearRegression() + .setFamily("gaussian") + .setLink("identity") + .setMaxIter(10) + .setRegParam(0.3) + + // Fit the model + val model = glr.fit(dataset) + + // Print the coefficients and intercept for generalized linear regression model + println(s"Coefficients: ${model.coefficients}") + println(s"Intercept: ${model.intercept}") + + // Summarize the model over the training set and print out some metrics + val summary = model.summary + println(s"Coefficient Standard Errors: ${summary.coefficientStandardErrors.mkString(",")}") + println(s"T Values: ${summary.tValues.mkString(",")}") + println(s"P Values: ${summary.pValues.mkString(",")}") + println(s"Dispersion: ${summary.dispersion}") + println(s"Null Deviance: ${summary.nullDeviance}") + println(s"Residual Degree Of Freedom Null: ${summary.residualDegreeOfFreedomNull}") + println(s"Deviance: ${summary.deviance}") + println(s"Residual Degree Of Freedom: ${summary.residualDegreeOfFreedom}") + println(s"AIC: ${summary.aic}") + println("Deviance Residuals: ") + summary.residuals().show() + // $example off$ + + spark.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala index cd62a803820cf..9a39acfbf37e5 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala @@ -18,24 +18,24 @@ // scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.Pipeline import org.apache.spark.ml.classification.{GBTClassificationModel, GBTClassifier} import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer} // $example off$ -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession object GradientBoostedTreeClassifierExample { def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("GradientBoostedTreeClassifierExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) + val spark = SparkSession + .builder + .appName("GradientBoostedTreeClassifierExample") + .getOrCreate() // $example on$ // Load and parse the data file, converting it to a DataFrame. - val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + val data = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") // Index labels, adding metadata to the label column. // Fit on whole dataset to include all labels in index. @@ -51,7 +51,7 @@ object GradientBoostedTreeClassifierExample { .setMaxCategories(4) .fit(data) - // Split the data into training and test sets (30% held out for testing) + // Split the data into training and test sets (30% held out for testing). 
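The 70/30 split referenced above is the same pattern every example in this patch uses: randomSplit takes relative weights and, optionally, a seed, and a fixed seed keeps the held-out set reproducible. A short sketch; the seed value is illustrative, not from the patch:

```scala
// Sketch of a reproducible 70/30 train/test split on the sample LIBSVM data.
import org.apache.spark.sql.SparkSession

object RandomSplitSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder.appName("RandomSplitSketch").getOrCreate()
    val data = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")
    val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3), seed = 1234L)
    println(s"training rows: ${trainingData.count()}, test rows: ${testData.count()}")
    spark.stop()
  }
}
```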
val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) // Train a GBT model. @@ -66,11 +66,11 @@ object GradientBoostedTreeClassifierExample { .setOutputCol("predictedLabel") .setLabels(labelIndexer.labels) - // Chain indexers and GBT in a Pipeline + // Chain indexers and GBT in a Pipeline. val pipeline = new Pipeline() .setStages(Array(labelIndexer, featureIndexer, gbt, labelConverter)) - // Train model. This also runs the indexers. + // Train model. This also runs the indexers. val model = pipeline.fit(trainingData) // Make predictions. @@ -79,11 +79,11 @@ object GradientBoostedTreeClassifierExample { // Select example rows to display. predictions.select("predictedLabel", "label", "features").show(5) - // Select (prediction, true label) and compute test error + // Select (prediction, true label) and compute test error. val evaluator = new MulticlassClassificationEvaluator() .setLabelCol("indexedLabel") .setPredictionCol("prediction") - .setMetricName("precision") + .setMetricName("accuracy") val accuracy = evaluator.evaluate(predictions) println("Test Error = " + (1.0 - accuracy)) @@ -91,7 +91,7 @@ object GradientBoostedTreeClassifierExample { println("Learned classification GBT model:\n" + gbtModel.toDebugString) // $example off$ - sc.stop() + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala index b8cf9629bbdab..e53aab7f326d3 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala @@ -18,24 +18,24 @@ // scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.Pipeline import org.apache.spark.ml.evaluation.RegressionEvaluator import org.apache.spark.ml.feature.VectorIndexer import org.apache.spark.ml.regression.{GBTRegressionModel, GBTRegressor} // $example off$ -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession object GradientBoostedTreeRegressorExample { def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("GradientBoostedTreeRegressorExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) + val spark = SparkSession + .builder + .appName("GradientBoostedTreeRegressorExample") + .getOrCreate() // $example on$ // Load and parse the data file, converting it to a DataFrame. - val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + val data = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") // Automatically identify categorical features, and index them. // Set maxCategories so features with > 4 distinct values are treated as continuous. @@ -45,7 +45,7 @@ object GradientBoostedTreeRegressorExample { .setMaxCategories(4) .fit(data) - // Split the data into training and test sets (30% held out for testing) + // Split the data into training and test sets (30% held out for testing). val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) // Train a GBT model. @@ -54,11 +54,11 @@ object GradientBoostedTreeRegressorExample { .setFeaturesCol("indexedFeatures") .setMaxIter(10) - // Chain indexer and GBT in a Pipeline + // Chain indexer and GBT in a Pipeline. 
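The Pipeline chaining mentioned above works the same way throughout these examples: any mix of transformers and estimators can be passed to setStages, and fit() runs them in order. A tiny sketch, with a hypothetical helper name:

```scala
// Tiny sketch of Pipeline assembly; buildPipeline is a hypothetical helper.
import org.apache.spark.ml.{Pipeline, PipelineStage}

object PipelineSketch {
  // Estimators in the stage list are fit in order; transformers are applied as-is.
  def buildPipeline(stages: PipelineStage*): Pipeline =
    new Pipeline().setStages(stages.toArray)
}
```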
val pipeline = new Pipeline() .setStages(Array(featureIndexer, gbt)) - // Train model. This also runs the indexer. + // Train model. This also runs the indexer. val model = pipeline.fit(trainingData) // Make predictions. @@ -67,7 +67,7 @@ object GradientBoostedTreeRegressorExample { // Select example rows to display. predictions.select("prediction", "label", "features").show(5) - // Select (prediction, true label) and compute test error + // Select (prediction, true label) and compute test error. val evaluator = new RegressionEvaluator() .setLabelCol("label") .setPredictionCol("prediction") @@ -79,7 +79,7 @@ object GradientBoostedTreeRegressorExample { println("Learned regression GBT model:\n" + gbtModel.toDebugString) // $example off$ - sc.stop() + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/IndexToStringExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/IndexToStringExample.scala index 4cea09ba12656..950733831c3d5 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/IndexToStringExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/IndexToStringExample.scala @@ -18,21 +18,20 @@ // scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.{IndexToString, StringIndexer} // $example off$ -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession object IndexToStringExample { def main(args: Array[String]) { - val conf = new SparkConf().setAppName("IndexToStringExample") - val sc = new SparkContext(conf) - - val sqlContext = SQLContext.getOrCreate(sc) + val spark = SparkSession + .builder + .appName("IndexToStringExample") + .getOrCreate() // $example on$ - val df = sqlContext.createDataFrame(Seq( + val df = spark.createDataFrame(Seq( (0, "a"), (1, "b"), (2, "c"), @@ -54,7 +53,8 @@ object IndexToStringExample { val converted = converter.transform(indexed) converted.select("id", "originalCategory").show() // $example off$ - sc.stop() + + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/IsotonicRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/IsotonicRegressionExample.scala new file mode 100644 index 0000000000000..7c5d3f23411f0 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/IsotonicRegressionExample.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.regression.IsotonicRegression +// $example off$ +import org.apache.spark.sql.SparkSession + +/** + * An example demonstrating Isotonic Regression. + * Run with + * {{{ + * bin/run-example ml.IsotonicRegressionExample + * }}} + */ +object IsotonicRegressionExample { + + def main(args: Array[String]): Unit = { + + // Creates a SparkSession. + val spark = SparkSession + .builder + .appName(s"${this.getClass.getSimpleName}") + .getOrCreate() + + // $example on$ + // Loads data. + val dataset = spark.read.format("libsvm") + .load("data/mllib/sample_isotonic_regression_libsvm_data.txt") + + // Trains an isotonic regression model. + val ir = new IsotonicRegression() + val model = ir.fit(dataset) + + println(s"Boundaries in increasing order: ${model.boundaries}") + println(s"Predictions associated with the boundaries: ${model.predictions}") + + // Makes predictions. + model.transform(dataset).show() + // $example off$ + + spark.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala index 7af011571f76e..2341b36db2400 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala @@ -19,15 +19,13 @@ package org.apache.spark.examples.ml // scalastyle:off println -import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.clustering.KMeans -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.sql.{DataFrame, SQLContext} // $example off$ +import org.apache.spark.sql.SparkSession /** - * An example demonstrating a k-means clustering. + * An example demonstrating k-means clustering. * Run with * {{{ * bin/run-example ml.KMeansExample @@ -36,35 +34,30 @@ import org.apache.spark.sql.{DataFrame, SQLContext} object KMeansExample { def main(args: Array[String]): Unit = { - // Creates a Spark context and a SQL context - val conf = new SparkConf().setAppName(s"${this.getClass.getSimpleName}") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) + // Creates a SparkSession. + val spark = SparkSession + .builder + .appName(s"${this.getClass.getSimpleName}") + .getOrCreate() // $example on$ - // Crates a DataFrame - val dataset: DataFrame = sqlContext.createDataFrame(Seq( - (1, Vectors.dense(0.0, 0.0, 0.0)), - (2, Vectors.dense(0.1, 0.1, 0.1)), - (3, Vectors.dense(0.2, 0.2, 0.2)), - (4, Vectors.dense(9.0, 9.0, 9.0)), - (5, Vectors.dense(9.1, 9.1, 9.1)), - (6, Vectors.dense(9.2, 9.2, 9.2)) - )).toDF("id", "features") + // Loads data. + val dataset = spark.read.format("libsvm").load("data/mllib/sample_kmeans_data.txt") - // Trains a k-means model - val kmeans = new KMeans() - .setK(2) - .setFeaturesCol("features") - .setPredictionCol("prediction") + // Trains a k-means model. + val kmeans = new KMeans().setK(2).setSeed(1L) val model = kmeans.fit(dataset) - // Shows the result - println("Final Centers: ") + // Evaluate clustering by computing Within Set Sum of Squared Errors. + val WSSSE = model.computeCost(dataset) + println(s"Within Set Sum of Squared Errors = $WSSSE") + + // Shows the result. 
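The rewritten KMeansExample above evaluates clustering with computeCost, the Within Set Sum of Squared Errors. As a hedged aside that is not part of the patch, the same call can be reused to compare several values of k on the sample data; the hunk then continues below by printing the cluster centers.

```scala
// Hedged sketch (not part of the patch): reuse computeCost to compare cluster counts.
import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.sql.SparkSession

object KMeansCostSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder.appName("KMeansCostSketch").getOrCreate()
    val dataset = spark.read.format("libsvm").load("data/mllib/sample_kmeans_data.txt")

    for (k <- 2 to 5) {
      val model = new KMeans().setK(k).setSeed(1L).fit(dataset)
      println(s"k = $k, WSSSE = ${model.computeCost(dataset)}")
    }
    spark.stop()
  }
}
```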
+ println("Cluster Centers: ") model.clusterCenters.foreach(println) // $example off$ - sc.stop() + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala index f9ddac77090ec..22b3b0e3ad9c1 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala @@ -18,60 +18,51 @@ package org.apache.spark.examples.ml // scalastyle:off println -import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.clustering.LDA -import org.apache.spark.mllib.linalg.{Vectors, VectorUDT} -import org.apache.spark.sql.{Row, SQLContext} -import org.apache.spark.sql.types.{StructField, StructType} // $example off$ +import org.apache.spark.sql.SparkSession /** - * An example demonstrating a LDA of ML pipeline. + * An example demonstrating LDA. * Run with * {{{ * bin/run-example ml.LDAExample * }}} */ object LDAExample { - - final val FEATURES_COL = "features" - def main(args: Array[String]): Unit = { - - val input = "data/mllib/sample_lda_data.txt" - // Creates a Spark context and a SQL context - val conf = new SparkConf().setAppName(s"${this.getClass.getSimpleName}") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) + // Creates a SparkSession + val spark = SparkSession + .builder + .appName(s"${this.getClass.getSimpleName}") + .getOrCreate() // $example on$ - // Loads data - val rowRDD = sc.textFile(input).filter(_.nonEmpty) - .map(_.split(" ").map(_.toDouble)).map(Vectors.dense).map(Row(_)) - val schema = StructType(Array(StructField(FEATURES_COL, new VectorUDT, false))) - val dataset = sqlContext.createDataFrame(rowRDD, schema) + // Loads data. + val dataset = spark.read.format("libsvm") + .load("data/mllib/sample_lda_libsvm_data.txt") - // Trains a LDA model - val lda = new LDA() - .setK(10) - .setMaxIter(10) - .setFeaturesCol(FEATURES_COL) + // Trains a LDA model. + val lda = new LDA().setK(10).setMaxIter(10) val model = lda.fit(dataset) - val transformed = model.transform(dataset) val ll = model.logLikelihood(dataset) val lp = model.logPerplexity(dataset) + println(s"The lower bound on the log likelihood of the entire corpus: $ll") + println(s"The upper bound bound on perplexity: $lp") - // describeTopics + // Describe topics. val topics = model.describeTopics(3) - - // Shows the result + println("The topics described by their top-weighted terms:") topics.show(false) - transformed.show(false) + // Shows the result. 
+ val transformed = model.transform(dataset) + transformed.show(false) // $example off$ - sc.stop() + + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala index 25be87811da90..de96fb2979ad1 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala @@ -22,10 +22,9 @@ import scala.language.reflectiveCalls import scopt.OptionParser -import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.examples.mllib.AbstractParams import org.apache.spark.ml.regression.LinearRegression -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, SparkSession} /** * An example runner for linear regression with elastic-net (mixing L1/L2) regularization. @@ -74,11 +73,11 @@ object LinearRegressionExample { s"to higher accuracy with the cost of more iterations, default: ${defaultParams.tol}") .action((x, c) => c.copy(tol = x)) opt[Double]("fracTest") - .text(s"fraction of data to hold out for testing. If given option testInput, " + + .text(s"fraction of data to hold out for testing. If given option testInput, " + s"this option is ignored. default: ${defaultParams.fracTest}") .action((x, c) => c.copy(fracTest = x)) opt[String]("testInput") - .text(s"input path to test dataset. If given, option fracTest is ignored." + + .text(s"input path to test dataset. If given, option fracTest is ignored." + s" default: ${defaultParams.testInput}") .action((x, c) => c.copy(testInput = x)) opt[String]("dataFormat") @@ -105,13 +104,15 @@ object LinearRegressionExample { } def run(params: Params) { - val conf = new SparkConf().setAppName(s"LinearRegressionExample with $params") - val sc = new SparkContext(conf) + val spark = SparkSession + .builder + .appName(s"LinearRegressionExample with $params") + .getOrCreate() println(s"LinearRegressionExample with parameters:\n$params") // Load training and test data and cache it. 
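The LinearRegressionExample runner being updated here configures LinearRegression from its elastic-net command-line options. A sketch of the estimator with illustrative values: regParam sets the overall regularization strength, and elasticNetParam mixes L1 (1.0) with L2 (0.0) penalties.

```scala
// Sketch with illustrative values; the runner reads these from its CLI options.
import org.apache.spark.ml.regression.LinearRegression

object ElasticNetSketch {
  val lir: LinearRegression = new LinearRegression()
    .setFeaturesCol("features")
    .setLabelCol("label")
    .setRegParam(0.3)        // overall regularization strength
    .setElasticNetParam(0.8) // 0.0 = pure L2 (ridge), 1.0 = pure L1 (lasso)
    .setMaxIter(100)
    .setTol(1e-6)
}
```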
- val (training: DataFrame, test: DataFrame) = DecisionTreeExample.loadDatasets(sc, params.input, + val (training: DataFrame, test: DataFrame) = DecisionTreeExample.loadDatasets(params.input, params.dataFormat, params.testInput, "regression", params.fracTest) val lir = new LinearRegression() @@ -136,7 +137,7 @@ object LinearRegressionExample { println("Test data results:") DecisionTreeExample.evaluateRegressionModel(lirModel, test, "label") - sc.stop() + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala index f68aef708201c..94cf2866238b9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala @@ -18,22 +18,22 @@ // scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.regression.LinearRegression // $example off$ -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession object LinearRegressionWithElasticNetExample { def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("LinearRegressionWithElasticNetExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) + val spark = SparkSession + .builder + .appName("LinearRegressionWithElasticNetExample") + .getOrCreate() // $example on$ // Load training data - val training = sqlContext.read.format("libsvm") + val training = spark.read.format("libsvm") .load("data/mllib/sample_linear_regression_data.txt") val lr = new LinearRegression() @@ -56,7 +56,7 @@ object LinearRegressionWithElasticNetExample { println(s"r2: ${trainingSummary.r2}") // $example off$ - sc.stop() + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala index a380c90662a50..c2a87e1ddfd55 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala @@ -23,12 +23,11 @@ import scala.language.reflectiveCalls import scopt.OptionParser -import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.examples.mllib.AbstractParams import org.apache.spark.ml.{Pipeline, PipelineStage} import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel} import org.apache.spark.ml.feature.StringIndexer -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, SparkSession} /** * An example runner for logistic regression with elastic-net (mixing L1/L2) regularization. @@ -81,11 +80,11 @@ object LogisticRegressionExample { s"to higher accuracy with the cost of more iterations, default: ${defaultParams.tol}") .action((x, c) => c.copy(tol = x)) opt[Double]("fracTest") - .text(s"fraction of data to hold out for testing. If given option testInput, " + + .text(s"fraction of data to hold out for testing. If given option testInput, " + s"this option is ignored. default: ${defaultParams.fracTest}") .action((x, c) => c.copy(fracTest = x)) opt[String]("testInput") - .text(s"input path to test dataset. If given, option fracTest is ignored." 
+ + .text(s"input path to test dataset. If given, option fracTest is ignored." + s" default: ${defaultParams.testInput}") .action((x, c) => c.copy(testInput = x)) opt[String]("dataFormat") @@ -112,16 +111,18 @@ object LogisticRegressionExample { } def run(params: Params) { - val conf = new SparkConf().setAppName(s"LogisticRegressionExample with $params") - val sc = new SparkContext(conf) + val spark = SparkSession + .builder + .appName(s"LogisticRegressionExample with $params") + .getOrCreate() println(s"LogisticRegressionExample with parameters:\n$params") // Load training and test data and cache it. - val (training: DataFrame, test: DataFrame) = DecisionTreeExample.loadDatasets(sc, params.input, + val (training: DataFrame, test: DataFrame) = DecisionTreeExample.loadDatasets(params.input, params.dataFormat, params.testInput, "classification", params.fracTest) - // Set up Pipeline + // Set up Pipeline. val stages = new mutable.ArrayBuffer[PipelineStage]() val labelIndexer = new StringIndexer() @@ -141,7 +142,7 @@ object LogisticRegressionExample { stages += lor val pipeline = new Pipeline().setStages(stages.toArray) - // Fit the Pipeline + // Fit the Pipeline. val startTime = System.nanoTime() val pipelineModel = pipeline.fit(training) val elapsedTime = (System.nanoTime() - startTime) / 1e9 @@ -156,7 +157,7 @@ object LogisticRegressionExample { println("Test data results:") DecisionTreeExample.evaluateClassificationModel(pipelineModel, test, "indexedLabel") - sc.stop() + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala index 89c5edf1ace9c..cd8775c942162 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala @@ -18,23 +18,23 @@ // scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.classification.{BinaryLogisticRegressionSummary, LogisticRegression} // $example off$ -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.functions.max object LogisticRegressionSummaryExample { def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("LogisticRegressionSummaryExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - import sqlContext.implicits._ + val spark = SparkSession + .builder + .appName("LogisticRegressionSummaryExample") + .getOrCreate() + import spark.implicits._ // Load training data - val training = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + val training = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") val lr = new LogisticRegression() .setMaxIter(10) @@ -71,7 +71,7 @@ object LogisticRegressionSummaryExample { lrModel.setThreshold(bestThreshold) // $example off$ - sc.stop() + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionWithElasticNetExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionWithElasticNetExample.scala index 6e27571f1dc16..616263b8e9f48 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionWithElasticNetExample.scala +++ 
b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionWithElasticNetExample.scala @@ -18,22 +18,22 @@ // scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.classification.LogisticRegression // $example off$ -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession object LogisticRegressionWithElasticNetExample { def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("LogisticRegressionWithElasticNetExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) + val spark = SparkSession + .builder + .appName("LogisticRegressionWithElasticNetExample") + .getOrCreate() // $example on$ // Load training data - val training = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + val training = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") val lr = new LogisticRegression() .setMaxIter(10) @@ -47,7 +47,7 @@ object LogisticRegressionWithElasticNetExample { println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}") // $example off$ - sc.stop() + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MaxAbsScalerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MaxAbsScalerExample.scala index aafb5efd698e4..572adce657081 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/MaxAbsScalerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MaxAbsScalerExample.scala @@ -15,23 +15,22 @@ * limitations under the License. */ -// scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.MaxAbsScaler // $example off$ -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession object MaxAbsScalerExample { def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("MaxAbsScalerExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) + val spark = SparkSession + .builder + .appName("MaxAbsScalerExample") + .getOrCreate() // $example on$ - val dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + val dataFrame = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") val scaler = new MaxAbsScaler() .setInputCol("features") .setOutputCol("scaledFeatures") @@ -43,7 +42,7 @@ object MaxAbsScalerExample { val scaledData = scalerModel.transform(dataFrame) scaledData.show() // $example off$ - sc.stop() + + spark.stop() } } -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala index 9a03f69f5af03..d728019a621d4 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala @@ -18,20 +18,20 @@ // scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.MinMaxScaler // $example off$ -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession object MinMaxScalerExample { def main(args: Array[String]): Unit = { - val conf = new 
SparkConf().setAppName("MinMaxScalerExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) + val spark = SparkSession + .builder + .appName("MinMaxScalerExample") + .getOrCreate() // $example on$ - val dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + val dataFrame = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") val scaler = new MinMaxScaler() .setInputCol("features") @@ -44,7 +44,8 @@ object MinMaxScalerExample { val scaledData = scalerModel.transform(dataFrame) scaledData.show() // $example off$ - sc.stop() + + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaCrossValidationExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaCrossValidationExample.scala index d1441b5497a86..c1ff9ef521706 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaCrossValidationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaCrossValidationExample.scala @@ -18,17 +18,16 @@ // scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.Pipeline import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator import org.apache.spark.ml.feature.{HashingTF, Tokenizer} +import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder} -import org.apache.spark.mllib.linalg.Vector import org.apache.spark.sql.Row // $example off$ -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession /** * A simple example demonstrating model selection using CrossValidator. @@ -42,13 +41,14 @@ import org.apache.spark.sql.SQLContext object ModelSelectionViaCrossValidationExample { def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("ModelSelectionViaCrossValidationExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) + val spark = SparkSession + .builder + .appName("ModelSelectionViaCrossValidationExample") + .getOrCreate() // $example on$ // Prepare training data from a list of (id, text, label) tuples. - val training = sqlContext.createDataFrame(Seq( + val training = spark.createDataFrame(Seq( (0L, "a b c d e spark", 1.0), (1L, "b d", 0.0), (2L, "spark f g h", 1.0), @@ -98,7 +98,7 @@ object ModelSelectionViaCrossValidationExample { val cvModel = cv.fit(training) // Prepare test documents, which are unlabeled (id, text) tuples. 
- val test = sqlContext.createDataFrame(Seq( + val test = spark.createDataFrame(Seq( (4L, "spark i j k"), (5L, "l m n"), (6L, "mapreduce spark"), @@ -114,7 +114,7 @@ object ModelSelectionViaCrossValidationExample { } // $example off$ - sc.stop() + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaTrainValidationSplitExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaTrainValidationSplitExample.scala index fcad17a817580..75fef2922adbd 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaTrainValidationSplitExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaTrainValidationSplitExample.scala @@ -17,13 +17,12 @@ package org.apache.spark.examples.ml -import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.evaluation.RegressionEvaluator import org.apache.spark.ml.regression.LinearRegression import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit} // $example off$ -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession /** * A simple example demonstrating model selection using TrainValidationSplit. @@ -36,13 +35,14 @@ import org.apache.spark.sql.SQLContext object ModelSelectionViaTrainValidationSplitExample { def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("ModelSelectionViaTrainValidationSplitExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) + val spark = SparkSession + .builder + .appName("ModelSelectionViaTrainValidationSplitExample") + .getOrCreate() // $example on$ // Prepare training and test data. - val data = sqlContext.read.format("libsvm").load("data/mllib/sample_linear_regression_data.txt") + val data = spark.read.format("libsvm").load("data/mllib/sample_linear_regression_data.txt") val Array(training, test) = data.randomSplit(Array(0.9, 0.1), seed = 12345) val lr = new LinearRegression() @@ -75,6 +75,6 @@ object ModelSelectionViaTrainValidationSplitExample { .show() // $example off$ - sc.stop() + spark.stop() } } diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala index d7d1e82f6f849..e8a9b32da9664 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala @@ -18,12 +18,11 @@ // scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.classification.MultilayerPerceptronClassifier import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator // $example off$ -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession /** * An example for Multilayer Perceptron Classification. 
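// Illustrative sketch (not part of the patch): the recurring change throughout this diff
// replaces the SparkConf / SparkContext / SQLContext trio with a single SparkSession entry
// point. A minimal, self-contained example of the new pattern, assuming Spark 2.x APIs:
//
//   import org.apache.spark.sql.SparkSession
//
//   object SparkSessionPatternSketch {
//     def main(args: Array[String]): Unit = {
//       val spark = SparkSession
//         .builder
//         .appName("SparkSessionPatternSketch")
//         .getOrCreate()
//       // DataFrame reads now hang off the session rather than a separate SQLContext.
//       val df = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")
//       df.show()
//       spark.stop()  // replaces sc.stop()
//     }
//   }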
@@ -31,13 +30,14 @@ import org.apache.spark.sql.SQLContext object MultilayerPerceptronClassifierExample { def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("MultilayerPerceptronClassifierExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) + val spark = SparkSession + .builder + .appName("MultilayerPerceptronClassifierExample") + .getOrCreate() // $example on$ // Load the data stored in LIBSVM format as a DataFrame. - val data = sqlContext.read.format("libsvm") + val data = spark.read.format("libsvm") .load("data/mllib/sample_multiclass_classification_data.txt") // Split the data into train and test val splits = data.randomSplit(Array(0.6, 0.4), seed = 1234L) @@ -55,15 +55,15 @@ object MultilayerPerceptronClassifierExample { .setMaxIter(100) // train the model val model = trainer.fit(train) - // compute precision on the test set + // compute accuracy on the test set val result = model.transform(test) val predictionAndLabels = result.select("prediction", "label") val evaluator = new MulticlassClassificationEvaluator() - .setMetricName("precision") - println("Precision:" + evaluator.evaluate(predictionAndLabels)) + .setMetricName("accuracy") + println("Accuracy: " + evaluator.evaluate(predictionAndLabels)) // $example off$ - sc.stop() + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/NGramExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/NGramExample.scala index 77b913aaa3fa0..e0b52e7a367fc 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/NGramExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/NGramExample.scala @@ -18,20 +18,20 @@ // scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.NGram // $example off$ -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession object NGramExample { def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("NGramExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) + val spark = SparkSession + .builder + .appName("NGramExample") + .getOrCreate() // $example on$ - val wordDataFrame = sqlContext.createDataFrame(Seq( + val wordDataFrame = spark.createDataFrame(Seq( (0, Array("Hi", "I", "heard", "about", "Spark")), (1, Array("I", "wish", "Java", "could", "use", "case", "classes")), (2, Array("Logistic", "regression", "models", "are", "neat")) @@ -41,7 +41,8 @@ object NGramExample { val ngramDataFrame = ngram.transform(wordDataFrame) ngramDataFrame.take(3).map(_.getAs[Stream[String]]("ngrams").toList).foreach(println) // $example off$ - sc.stop() + + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/NaiveBayesExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/NaiveBayesExample.scala index 5ea1270c9781c..7089a4bc87aaa 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/NaiveBayesExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/NaiveBayesExample.scala @@ -18,24 +18,24 @@ // scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.{SparkConf, SparkContext} // $example on$ -import org.apache.spark.ml.classification.{NaiveBayes} +import org.apache.spark.ml.classification.NaiveBayes import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator // $example off$ 
-import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession object NaiveBayesExample { def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("NaiveBayesExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) + val spark = SparkSession + .builder + .appName("NaiveBayesExample") + .getOrCreate() // $example on$ // Load the data stored in LIBSVM format as a DataFrame. - val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + val data = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") // Split the data into training and test sets (30% held out for testing) - val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) + val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3), seed = 1234L) // Train a NaiveBayes model. val model = new NaiveBayes() @@ -49,10 +49,12 @@ object NaiveBayesExample { val evaluator = new MulticlassClassificationEvaluator() .setLabelCol("label") .setPredictionCol("prediction") - .setMetricName("precision") - val precision = evaluator.evaluate(predictions) - println("Precision:" + precision) + .setMetricName("accuracy") + val accuracy = evaluator.evaluate(predictions) + println("Accuracy: " + accuracy) // $example off$ + + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/NormalizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/NormalizerExample.scala index 6b33c16c74037..75ba33a7e7fc1 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/NormalizerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/NormalizerExample.scala @@ -18,20 +18,20 @@ // scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.Normalizer // $example off$ -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession object NormalizerExample { def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("NormalizerExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) + val spark = SparkSession + .builder + .appName("NormalizerExample") + .getOrCreate() // $example on$ - val dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + val dataFrame = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") // Normalize each Vector using $L^1$ norm. 
val normalizer = new Normalizer() @@ -46,7 +46,8 @@ object NormalizerExample { val lInfNormData = normalizer.transform(dataFrame, normalizer.p -> Double.PositiveInfinity) lInfNormData.show() // $example off$ - sc.stop() + + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala index cb9fe65a85e86..4aa649b1332c6 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala @@ -18,20 +18,20 @@ // scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer} // $example off$ -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession object OneHotEncoderExample { def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("OneHotEncoderExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) + val spark = SparkSession + .builder + .appName("OneHotEncoderExample") + .getOrCreate() // $example on$ - val df = sqlContext.createDataFrame(Seq( + val df = spark.createDataFrame(Seq( (0, "a"), (1, "b"), (2, "c"), @@ -52,7 +52,8 @@ object OneHotEncoderExample { val encoded = encoder.transform(indexed) encoded.select("id", "categoryVec").show() // $example off$ - sc.stop() + + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala index 0b5d31c0ff90d..acde110683950 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala @@ -18,171 +18,62 @@ // scalastyle:off println package org.apache.spark.examples.ml -import java.util.concurrent.TimeUnit.{NANOSECONDS => NANO} - -import scopt.OptionParser - -import org.apache.spark.{SparkConf, SparkContext} // $example on$ -import org.apache.spark.examples.mllib.AbstractParams import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest} -import org.apache.spark.ml.util.MetadataUtils -import org.apache.spark.mllib.evaluation.MulticlassMetrics -import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.sql.DataFrame +import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator // $example off$ -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession /** - * An example runner for Multiclass to Binary Reduction with One Vs Rest. - * The example uses Logistic Regression as the base classifier. All parameters that - * can be specified on the base classifier can be passed in to the runner options. + * An example of Multiclass to Binary Reduction with One Vs Rest, + * using Logistic Regression as the base classifier. * Run with * {{{ - * ./bin/run-example ml.OneVsRestExample [options] - * }}} - * For local mode, run - * {{{ - * ./bin/spark-submit --class org.apache.spark.examples.ml.OneVsRestExample --driver-memory 1g - * [examples JAR path] [options] + * ./bin/run-example ml.OneVsRestExample * }}} - * If you use it as a template to create your own app, please use `spark-submit` to submit your app. 
*/ -object OneVsRestExample { - - case class Params private[ml] ( - input: String = null, - testInput: Option[String] = None, - maxIter: Int = 100, - tol: Double = 1E-6, - fitIntercept: Boolean = true, - regParam: Option[Double] = None, - elasticNetParam: Option[Double] = None, - fracTest: Double = 0.2) extends AbstractParams[Params] +object OneVsRestExample { def main(args: Array[String]) { - val defaultParams = Params() - - val parser = new OptionParser[Params]("OneVsRest Example") { - head("OneVsRest Example: multiclass to binary reduction using OneVsRest") - opt[String]("input") - .text("input path to labeled examples. This path must be specified") - .required() - .action((x, c) => c.copy(input = x)) - opt[Double]("fracTest") - .text(s"fraction of data to hold out for testing. If given option testInput, " + - s"this option is ignored. default: ${defaultParams.fracTest}") - .action((x, c) => c.copy(fracTest = x)) - opt[String]("testInput") - .text("input path to test dataset. If given, option fracTest is ignored") - .action((x, c) => c.copy(testInput = Some(x))) - opt[Int]("maxIter") - .text(s"maximum number of iterations for Logistic Regression." + - s" default: ${defaultParams.maxIter}") - .action((x, c) => c.copy(maxIter = x)) - opt[Double]("tol") - .text(s"the convergence tolerance of iterations for Logistic Regression." + - s" default: ${defaultParams.tol}") - .action((x, c) => c.copy(tol = x)) - opt[Boolean]("fitIntercept") - .text(s"fit intercept for Logistic Regression." + - s" default: ${defaultParams.fitIntercept}") - .action((x, c) => c.copy(fitIntercept = x)) - opt[Double]("regParam") - .text(s"the regularization parameter for Logistic Regression.") - .action((x, c) => c.copy(regParam = Some(x))) - opt[Double]("elasticNetParam") - .text(s"the ElasticNet mixing parameter for Logistic Regression.") - .action((x, c) => c.copy(elasticNetParam = Some(x))) - checkConfig { params => - if (params.fracTest < 0 || params.fracTest >= 1) { - failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).") - } else { - success - } - } - } - parser.parse(args, defaultParams).map { params => - run(params) - }.getOrElse { - sys.exit(1) - } - } - - private def run(params: Params) { - val conf = new SparkConf().setAppName(s"OneVsRestExample with $params") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) + val spark = SparkSession + .builder + .appName(s"OneVsRestExample") + .getOrCreate() // $example on$ - val inputData = sqlContext.read.format("libsvm").load(params.input) - // compute the train/test split: if testInput is not provided use part of input. - val data = params.testInput match { - case Some(t) => - // compute the number of features in the training set. - val numFeatures = inputData.first().getAs[Vector](1).size - val testData = sqlContext.read.option("numFeatures", numFeatures.toString) - .format("libsvm").load(t) - Array[DataFrame](inputData, testData) - case None => - val f = params.fracTest - inputData.randomSplit(Array(1 - f, f), seed = 12345) - } - val Array(train, test) = data.map(_.cache()) + // load data file. + val inputData = spark.read.format("libsvm") + .load("data/mllib/sample_multiclass_classification_data.txt") + + // generate the train/test split. 
+ val Array(train, test) = inputData.randomSplit(Array(0.8, 0.2)) // instantiate the base classifier val classifier = new LogisticRegression() - .setMaxIter(params.maxIter) - .setTol(params.tol) - .setFitIntercept(params.fitIntercept) - - // Set regParam, elasticNetParam if specified in params - params.regParam.foreach(classifier.setRegParam) - params.elasticNetParam.foreach(classifier.setElasticNetParam) + .setMaxIter(10) + .setTol(1E-6) + .setFitIntercept(true) // instantiate the One Vs Rest Classifier. - - val ovr = new OneVsRest() - ovr.setClassifier(classifier) + val ovr = new OneVsRest().setClassifier(classifier) // train the multiclass model. - val (trainingDuration, ovrModel) = time(ovr.fit(train)) + val ovrModel = ovr.fit(train) // score the model on test data. - val (predictionDuration, predictions) = time(ovrModel.transform(test)) - - // evaluate the model - val predictionsAndLabels = predictions.select("prediction", "label") - .rdd.map(row => (row.getDouble(0), row.getDouble(1))) - - val metrics = new MulticlassMetrics(predictionsAndLabels) - - val confusionMatrix = metrics.confusionMatrix + val predictions = ovrModel.transform(test) - // compute the false positive rate per label - val predictionColSchema = predictions.schema("prediction") - val numClasses = MetadataUtils.getNumClasses(predictionColSchema).get - val fprs = Range(0, numClasses).map(p => (p, metrics.falsePositiveRate(p.toDouble))) + // obtain evaluator. + val evaluator = new MulticlassClassificationEvaluator() + .setMetricName("accuracy") - println(s" Training Time ${trainingDuration} sec\n") - - println(s" Prediction Time ${predictionDuration} sec\n") - - println(s" Confusion Matrix\n ${confusionMatrix.toString}\n") - - println("label\tfpr") - - println(fprs.map {case (label, fpr) => label + "\t" + fpr}.mkString("\n")) + // compute the classification error on test data. 
+ val accuracy = evaluator.evaluate(predictions) + println(s"Test Error : ${1 - accuracy}") // $example off$ - sc.stop() + spark.stop() } - private def time[R](block: => R): (Long, R) = { - val t0 = System.nanoTime() - val result = block // call-by-name - val t1 = System.nanoTime() - (NANO.toSeconds(t1 - t0), result) - } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/PCAExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/PCAExample.scala index 535652ec6c793..dca96eea2ba4e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/PCAExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/PCAExample.scala @@ -18,18 +18,18 @@ // scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.PCA -import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.ml.linalg.Vectors // $example off$ -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession object PCAExample { def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("PCAExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) + val spark = SparkSession + .builder + .appName("PCAExample") + .getOrCreate() // $example on$ val data = Array( @@ -37,7 +37,7 @@ object PCAExample { Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0) ) - val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") + val df = spark.createDataFrame(data.map(Tuple1.apply)).toDF("features") val pca = new PCA() .setInputCol("features") .setOutputCol("pcaFeatures") @@ -47,7 +47,8 @@ object PCAExample { val result = pcaDF.select("pcaFeatures") result.show() // $example off$ - sc.stop() + + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/PipelineExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/PipelineExample.scala index 6c29063626bac..b16692b1fa36f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/PipelineExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/PipelineExample.scala @@ -18,26 +18,26 @@ // scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.{Pipeline, PipelineModel} import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.feature.{HashingTF, Tokenizer} -import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.ml.linalg.Vector import org.apache.spark.sql.Row // $example off$ -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession object PipelineExample { def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("PipelineExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) + val spark = SparkSession + .builder + .appName("PipelineExample") + .getOrCreate() // $example on$ // Prepare training documents from a list of (id, text, label) tuples. - val training = sqlContext.createDataFrame(Seq( + val training = spark.createDataFrame(Seq( (0L, "a b c d e spark", 1.0), (1L, "b d", 0.0), (2L, "spark f g h", 1.0), @@ -71,7 +71,7 @@ object PipelineExample { val sameModel = PipelineModel.load("/tmp/spark-logistic-regression-model") // Prepare test documents, which are unlabeled (id, text) tuples. 
- val test = sqlContext.createDataFrame(Seq( + val test = spark.createDataFrame(Seq( (4L, "spark i j k"), (5L, "l m n"), (6L, "mapreduce spark"), @@ -87,7 +87,7 @@ object PipelineExample { } // $example off$ - sc.stop() + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala index 3014008ea0ce4..54d2e6b36d149 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala @@ -18,18 +18,18 @@ // scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.PolynomialExpansion -import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.ml.linalg.Vectors // $example off$ -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession object PolynomialExpansionExample { def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("PolynomialExpansionExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) + val spark = SparkSession + .builder + .appName("PolynomialExpansionExample") + .getOrCreate() // $example on$ val data = Array( @@ -37,7 +37,7 @@ object PolynomialExpansionExample { Vectors.dense(0.0, 0.0), Vectors.dense(0.6, -1.1) ) - val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features") + val df = spark.createDataFrame(data.map(Tuple1.apply)).toDF("features") val polynomialExpansion = new PolynomialExpansion() .setInputCol("features") .setOutputCol("polyFeatures") @@ -45,7 +45,8 @@ object PolynomialExpansionExample { val polyDF = polynomialExpansion.transform(df) polyDF.select("polyFeatures").take(3).foreach(println) // $example off$ - sc.stop() + + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala index e64e673a485ed..2f7e217b8fe2d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala @@ -15,26 +15,30 @@ * limitations under the License. */ -// scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.QuantileDiscretizer // $example off$ -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession object QuantileDiscretizerExample { def main(args: Array[String]) { - val conf = new SparkConf().setAppName("QuantileDiscretizerExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - import sqlContext.implicits._ + val spark = SparkSession + .builder + .appName("QuantileDiscretizerExample") + .getOrCreate() + import spark.implicits._ // $example on$ val data = Array((0, 18.0), (1, 19.0), (2, 8.0), (3, 5.0), (4, 2.2)) - val df = sc.parallelize(data).toDF("id", "hour") - + var df = spark.createDataFrame(data).toDF("id", "hour") + // $example off$ + // Output of QuantileDiscretizer for such small datasets can depend on the number of + // partitions. Here we force a single partition to ensure consistent results. 
+ // Note this is not necessary for normal use cases + .repartition(1) + // $example on$ val discretizer = new QuantileDiscretizer() .setInputCol("hour") .setOutputCol("result") @@ -43,7 +47,7 @@ object QuantileDiscretizerExample { val result = discretizer.fit(df).transform(df) result.show() // $example off$ - sc.stop() + + spark.stop() } } -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RFormulaExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RFormulaExample.scala index bec831d51c581..9ea4920146448 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/RFormulaExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/RFormulaExample.scala @@ -18,20 +18,20 @@ // scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.RFormula // $example off$ -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession object RFormulaExample { def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("RFormulaExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) + val spark = SparkSession + .builder + .appName("RFormulaExample") + .getOrCreate() // $example on$ - val dataset = sqlContext.createDataFrame(Seq( + val dataset = spark.createDataFrame(Seq( (7, "US", 18, 1.0), (8, "CA", 12, 0.0), (9, "NZ", 15, 0.0) @@ -43,7 +43,8 @@ object RFormulaExample { val output = formula.fit(dataset).transform(dataset) output.select("features", "label").show() // $example off$ - sc.stop() + + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala index 6c9b52cf259e6..5eafda8ce4285 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala @@ -18,24 +18,24 @@ // scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.Pipeline import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier} import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer} // $example off$ -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession object RandomForestClassifierExample { def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("RandomForestClassifierExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) + val spark = SparkSession + .builder + .appName("RandomForestClassifierExample") + .getOrCreate() // $example on$ // Load and parse the data file, converting it to a DataFrame. - val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + val data = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") // Index labels, adding metadata to the label column. // Fit on whole dataset to include all labels in index. 
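// Illustrative sketch (not part of the patch): the hunks below switch the evaluator metric
// name from "precision" to "accuracy", in line with the metrics MulticlassClassificationEvaluator
// exposes in Spark 2.x. Minimal usage, assuming a `predictions` DataFrame that carries
// "indexedLabel" and "prediction" columns:
//
//   import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
//
//   val evaluator = new MulticlassClassificationEvaluator()
//     .setLabelCol("indexedLabel")
//     .setPredictionCol("prediction")
//     .setMetricName("accuracy")
//   val accuracy = evaluator.evaluate(predictions)
//   println("Test Error = " + (1.0 - accuracy))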
@@ -51,7 +51,7 @@ object RandomForestClassifierExample { .setMaxCategories(4) .fit(data) - // Split the data into training and test sets (30% held out for testing) + // Split the data into training and test sets (30% held out for testing). val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) // Train a RandomForest model. @@ -66,11 +66,11 @@ object RandomForestClassifierExample { .setOutputCol("predictedLabel") .setLabels(labelIndexer.labels) - // Chain indexers and forest in a Pipeline + // Chain indexers and forest in a Pipeline. val pipeline = new Pipeline() .setStages(Array(labelIndexer, featureIndexer, rf, labelConverter)) - // Train model. This also runs the indexers. + // Train model. This also runs the indexers. val model = pipeline.fit(trainingData) // Make predictions. @@ -79,11 +79,11 @@ object RandomForestClassifierExample { // Select example rows to display. predictions.select("predictedLabel", "label", "features").show(5) - // Select (prediction, true label) and compute test error + // Select (prediction, true label) and compute test error. val evaluator = new MulticlassClassificationEvaluator() .setLabelCol("indexedLabel") .setPredictionCol("prediction") - .setMetricName("precision") + .setMetricName("accuracy") val accuracy = evaluator.evaluate(predictions) println("Test Error = " + (1.0 - accuracy)) @@ -91,7 +91,7 @@ object RandomForestClassifierExample { println("Learned classification forest model:\n" + rfModel.toDebugString) // $example off$ - sc.stop() + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala index 7a00d99dfe53d..2419dc49cd51e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala @@ -23,13 +23,12 @@ import scala.language.reflectiveCalls import scopt.OptionParser -import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.examples.mllib.AbstractParams import org.apache.spark.ml.{Pipeline, PipelineStage} import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier} import org.apache.spark.ml.feature.{StringIndexer, VectorIndexer} import org.apache.spark.ml.regression.{RandomForestRegressionModel, RandomForestRegressor} -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, SparkSession} /** @@ -37,7 +36,7 @@ import org.apache.spark.sql.DataFrame * {{{ * ./bin/run-example ml.RandomForestExample [options] * }}} - * Decision Trees and ensembles can take a large amount of memory. If the run-example command + * Decision Trees and ensembles can take a large amount of memory. If the run-example command * above fails, try running via spark-submit and specifying the amount of memory as at least 1g. * For local mode, run * {{{ @@ -94,7 +93,7 @@ object RandomForestExample { s" default: ${defaultParams.numTrees}") .action((x, c) => c.copy(featureSubsetStrategy = x)) opt[Double]("fracTest") - .text(s"fraction of data to hold out for testing. If given option testInput, " + + .text(s"fraction of data to hold out for testing. If given option testInput, " + s"this option is ignored. 
default: ${defaultParams.fracTest}") .action((x, c) => c.copy(fracTest = x)) opt[Boolean]("cacheNodeIds") @@ -115,7 +114,7 @@ object RandomForestExample { s"default: ${defaultParams.checkpointInterval}") .action((x, c) => c.copy(checkpointInterval = x)) opt[String]("testInput") - .text(s"input path to test dataset. If given, option fracTest is ignored." + + .text(s"input path to test dataset. If given, option fracTest is ignored." + s" default: ${defaultParams.testInput}") .action((x, c) => c.copy(testInput = x)) opt[String]("dataFormat") @@ -142,18 +141,21 @@ object RandomForestExample { } def run(params: Params) { - val conf = new SparkConf().setAppName(s"RandomForestExample with $params") - val sc = new SparkContext(conf) - params.checkpointDir.foreach(sc.setCheckpointDir) + val spark = SparkSession + .builder + .appName(s"RandomForestExample with $params") + .getOrCreate() + + params.checkpointDir.foreach(spark.sparkContext.setCheckpointDir) val algo = params.algo.toLowerCase println(s"RandomForestExample with parameters:\n$params") // Load training and test data and cache it. - val (training: DataFrame, test: DataFrame) = DecisionTreeExample.loadDatasets(sc, params.input, + val (training: DataFrame, test: DataFrame) = DecisionTreeExample.loadDatasets(params.input, params.dataFormat, params.testInput, algo, params.fracTest) - // Set up Pipeline + // Set up Pipeline. val stages = new mutable.ArrayBuffer[PipelineStage]() // (1) For classification, re-index classes. val labelColName = if (algo == "classification") "indexedLabel" else "label" @@ -170,7 +172,7 @@ object RandomForestExample { .setOutputCol("indexedFeatures") .setMaxCategories(10) stages += featuresIndexer - // (3) Learn Random Forest + // (3) Learn Random Forest. val dt = algo match { case "classification" => new RandomForestClassifier() @@ -201,13 +203,13 @@ object RandomForestExample { stages += dt val pipeline = new Pipeline().setStages(stages.toArray) - // Fit the Pipeline + // Fit the Pipeline. val startTime = System.nanoTime() val pipelineModel = pipeline.fit(training) val elapsedTime = (System.nanoTime() - startTime) / 1e9 println(s"Training time: $elapsedTime seconds") - // Get the trained Random Forest from the fitted PipelineModel + // Get the trained Random Forest from the fitted PipelineModel. algo match { case "classification" => val rfModel = pipelineModel.stages.last.asInstanceOf[RandomForestClassificationModel] @@ -226,7 +228,7 @@ object RandomForestExample { case _ => throw new IllegalArgumentException("Algo ${params.algo} not supported.") } - // Evaluate model on training, test data + // Evaluate model on training, test data. 
algo match { case "classification" => println("Training data results:") @@ -242,7 +244,7 @@ object RandomForestExample { throw new IllegalArgumentException("Algo ${params.algo} not supported.") } - sc.stop() + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala index 4d2db017f346f..9a0a001c26ef5 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala @@ -18,24 +18,24 @@ // scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.Pipeline import org.apache.spark.ml.evaluation.RegressionEvaluator import org.apache.spark.ml.feature.VectorIndexer import org.apache.spark.ml.regression.{RandomForestRegressionModel, RandomForestRegressor} // $example off$ -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession object RandomForestRegressorExample { def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("RandomForestRegressorExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) + val spark = SparkSession + .builder + .appName("RandomForestRegressorExample") + .getOrCreate() // $example on$ // Load and parse the data file, converting it to a DataFrame. - val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + val data = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") // Automatically identify categorical features, and index them. // Set maxCategories so features with > 4 distinct values are treated as continuous. @@ -45,7 +45,7 @@ object RandomForestRegressorExample { .setMaxCategories(4) .fit(data) - // Split the data into training and test sets (30% held out for testing) + // Split the data into training and test sets (30% held out for testing). val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) // Train a RandomForest model. @@ -53,11 +53,11 @@ object RandomForestRegressorExample { .setLabelCol("label") .setFeaturesCol("indexedFeatures") - // Chain indexer and forest in a Pipeline + // Chain indexer and forest in a Pipeline. val pipeline = new Pipeline() .setStages(Array(featureIndexer, rf)) - // Train model. This also runs the indexer. + // Train model. This also runs the indexer. val model = pipeline.fit(trainingData) // Make predictions. @@ -66,7 +66,7 @@ object RandomForestRegressorExample { // Select example rows to display. predictions.select("prediction", "label", "features").show(5) - // Select (prediction, true label) and compute test error + // Select (prediction, true label) and compute test error. 
val evaluator = new RegressionEvaluator() .setLabelCol("label") .setPredictionCol("prediction") @@ -78,7 +78,7 @@ object RandomForestRegressorExample { println("Learned regression forest model:\n" + rfModel.toDebugString) // $example off$ - sc.stop() + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SQLTransformerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SQLTransformerExample.scala index 202925acadff2..bb4587b82cb37 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/SQLTransformerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SQLTransformerExample.scala @@ -18,20 +18,20 @@ // scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.SQLTransformer // $example off$ -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession object SQLTransformerExample { def main(args: Array[String]) { - val conf = new SparkConf().setAppName("SQLTransformerExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) + val spark = SparkSession + .builder + .appName("SQLTransformerExample") + .getOrCreate() // $example on$ - val df = sqlContext.createDataFrame( + val df = spark.createDataFrame( Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2") val sqlTrans = new SQLTransformer().setStatement( @@ -39,6 +39,8 @@ object SQLTransformerExample { sqlTrans.transform(df).show() // $example off$ + + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala index f4d1fe57856a1..29f1f509608a7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala @@ -18,12 +18,11 @@ // scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.ml.classification.LogisticRegression +import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.mllib.linalg.{Vector, Vectors} -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.{Row, SparkSession} /** * A simple example demonstrating ways to specify parameters for Estimators and Transformers. @@ -35,21 +34,22 @@ import org.apache.spark.sql.{Row, SQLContext} object SimpleParamsExample { def main(args: Array[String]) { - val conf = new SparkConf().setAppName("SimpleParamsExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - import sqlContext.implicits._ + val spark = SparkSession + .builder + .appName("SimpleParamsExample") + .getOrCreate() + import spark.implicits._ // Prepare training data. - // We use LabeledPoint, which is a case class. Spark SQL can convert RDDs of case classes + // We use LabeledPoint, which is a case class. Spark SQL can convert RDDs of case classes // into DataFrames, where it uses the case class metadata to infer the schema. 
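// Illustrative sketch (not part of the patch): a Seq of case-class instances can be turned
// into a DataFrame directly through SparkSession, which infers the schema from the case
// class fields. Minimal illustration, assuming an active `spark: SparkSession`:
//
//   import org.apache.spark.ml.feature.LabeledPoint
//   import org.apache.spark.ml.linalg.Vectors
//
//   val training = spark.createDataFrame(Seq(
//     LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)),
//     LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0))
//   ))  // resulting columns: "label", "features"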
- val training = sc.parallelize(Seq( + val training = spark.createDataFrame(Seq( LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)), LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)), LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5)))) - // Create a LogisticRegression instance. This instance is an Estimator. + // Create a LogisticRegression instance. This instance is an Estimator. val lr = new LogisticRegression() // Print out the parameters, documentation, and any default values. println("LogisticRegression parameters:\n" + lr.explainParams() + "\n") @@ -58,8 +58,8 @@ object SimpleParamsExample { lr.setMaxIter(10) .setRegParam(0.01) - // Learn a LogisticRegression model. This uses the parameters stored in lr. - val model1 = lr.fit(training.toDF()) + // Learn a LogisticRegression model. This uses the parameters stored in lr. + val model1 = lr.fit(training) // Since model1 is a Model (i.e., a Transformer produced by an Estimator), // we can view the parameters it used during fit(). // This prints the parameter (name: value) pairs, where names are unique IDs for this @@ -69,8 +69,8 @@ object SimpleParamsExample { // We may alternatively specify parameters using a ParamMap, // which supports several methods for specifying parameters. val paramMap = ParamMap(lr.maxIter -> 20) - paramMap.put(lr.maxIter, 30) // Specify 1 Param. This overwrites the original maxIter. - paramMap.put(lr.regParam -> 0.1, lr.thresholds -> Array(0.45, 0.55)) // Specify multiple Params. + paramMap.put(lr.maxIter, 30) // Specify 1 Param. This overwrites the original maxIter. + paramMap.put(lr.regParam -> 0.1, lr.thresholds -> Array(0.5, 0.5)) // Specify multiple Params. // One can also combine ParamMaps. val paramMap2 = ParamMap(lr.probabilityCol -> "myProbability") // Change output column name @@ -82,7 +82,7 @@ object SimpleParamsExample { println("Model 2 was fit using parameters: " + model2.parent.extractParamMap()) // Prepare test data. - val test = sc.parallelize(Seq( + val test = spark.createDataFrame(Seq( LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)))) @@ -91,14 +91,14 @@ object SimpleParamsExample { // LogisticRegressionModel.transform will only use the 'features' column. // Note that model2.transform() outputs a 'myProbability' column instead of the usual // 'probability' column since we renamed the lr.probabilityCol parameter previously. 
- model2.transform(test.toDF()) + model2.transform(test) .select("features", "label", "myProbability", "prediction") .collect() .foreach { case Row(features: Vector, label: Double, prob: Vector, prediction: Double) => println(s"($features, $label) -> prob=$prob, prediction=$prediction") } - sc.stop() + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala index 960280137cbf9..0b2a058bb61aa 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala @@ -20,12 +20,11 @@ package org.apache.spark.examples.ml import scala.beans.BeanInfo -import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.ml.Pipeline import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.feature.{HashingTF, Tokenizer} -import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.ml.linalg.Vector +import org.apache.spark.sql.{Row, SparkSession} @BeanInfo case class LabeledDocument(id: Long, text: String, label: Double) @@ -43,13 +42,14 @@ case class Document(id: Long, text: String) object SimpleTextClassificationPipeline { def main(args: Array[String]) { - val conf = new SparkConf().setAppName("SimpleTextClassificationPipeline") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - import sqlContext.implicits._ + val spark = SparkSession + .builder + .appName("SimpleTextClassificationPipeline") + .getOrCreate() + import spark.implicits._ // Prepare training documents, which are labeled. - val training = sc.parallelize(Seq( + val training = spark.createDataFrame(Seq( LabeledDocument(0L, "a b c d e spark", 1.0), LabeledDocument(1L, "b d", 0.0), LabeledDocument(2L, "spark f g h", 1.0), @@ -73,7 +73,7 @@ object SimpleTextClassificationPipeline { val model = pipeline.fit(training.toDF()) // Prepare test documents, which are unlabeled. 
- val test = sc.parallelize(Seq( + val test = spark.createDataFrame(Seq( Document(4L, "spark i j k"), Document(5L, "l m n"), Document(6L, "spark hadoop spark"), @@ -87,7 +87,7 @@ object SimpleTextClassificationPipeline { println(s"($id, $text) --> prob=$prob, prediction=$prediction") } - sc.stop() + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/StandardScalerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/StandardScalerExample.scala index e3439677e78d6..4d668e8ab9670 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/StandardScalerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/StandardScalerExample.scala @@ -18,20 +18,20 @@ // scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.StandardScaler // $example off$ -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession object StandardScalerExample { def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("StandardScalerExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) + val spark = SparkSession + .builder + .appName("StandardScalerExample") + .getOrCreate() // $example on$ - val dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + val dataFrame = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") val scaler = new StandardScaler() .setInputCol("features") @@ -46,7 +46,8 @@ object StandardScalerExample { val scaledData = scalerModel.transform(dataFrame) scaledData.show() // $example off$ - sc.stop() + + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala index 8199be12c155b..fb1a43e962cd5 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala @@ -18,31 +18,32 @@ // scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.StopWordsRemover // $example off$ -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession object StopWordsRemoverExample { def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("StopWordsRemoverExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) + val spark = SparkSession + .builder + .appName("StopWordsRemoverExample") + .getOrCreate() // $example on$ val remover = new StopWordsRemover() .setInputCol("raw") .setOutputCol("filtered") - val dataSet = sqlContext.createDataFrame(Seq( + val dataSet = spark.createDataFrame(Seq( (0, Seq("I", "saw", "the", "red", "baloon")), (1, Seq("Mary", "had", "a", "little", "lamb")) )).toDF("id", "raw") remover.transform(dataSet).show() // $example off$ - sc.stop() + + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/StringIndexerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/StringIndexerExample.scala index 3f0e870c8dc6b..63f273e87a209 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/StringIndexerExample.scala +++ 
b/examples/src/main/scala/org/apache/spark/examples/ml/StringIndexerExample.scala @@ -18,20 +18,20 @@ // scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.StringIndexer // $example off$ -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession object StringIndexerExample { def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("StringIndexerExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) + val spark = SparkSession + .builder + .appName("StringIndexerExample") + .getOrCreate() // $example on$ - val df = sqlContext.createDataFrame( + val df = spark.createDataFrame( Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")) ).toDF("id", "category") @@ -42,7 +42,8 @@ object StringIndexerExample { val indexed = indexer.fit(df).transform(df) indexed.show() // $example off$ - sc.stop() + + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/TfIdfExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/TfIdfExample.scala index 396f073e6b322..33b5daec59783 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/TfIdfExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/TfIdfExample.scala @@ -18,21 +18,21 @@ // scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.{HashingTF, IDF, Tokenizer} // $example off$ -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession object TfIdfExample { def main(args: Array[String]) { - val conf = new SparkConf().setAppName("TfIdfExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) + val spark = SparkSession + .builder + .appName("TfIdfExample") + .getOrCreate() // $example on$ - val sentenceData = sqlContext.createDataFrame(Seq( + val sentenceData = spark.createDataFrame(Seq( (0, "Hi I heard about Spark"), (0, "I wish Java could use case classes"), (1, "Logistic regression models are neat") @@ -50,6 +50,8 @@ object TfIdfExample { val rescaledData = idfModel.transform(featurizedData) rescaledData.select("features", "label").take(3).foreach(println) // $example off$ + + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/TokenizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/TokenizerExample.scala index c667728d6326d..1c70dc700b91c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/TokenizerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/TokenizerExample.scala @@ -18,20 +18,20 @@ // scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.{RegexTokenizer, Tokenizer} // $example off$ -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession object TokenizerExample { def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("TokenizerExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) + val spark = SparkSession + .builder + .appName("TokenizerExample") + .getOrCreate() // $example on$ - val sentenceDataFrame = sqlContext.createDataFrame(Seq( + val sentenceDataFrame = spark.createDataFrame(Seq( (0, "Hi I heard about Spark"), 
(1, "I wish Java could use case classes"), (2, "Logistic,regression,models,are,neat") @@ -48,7 +48,8 @@ object TokenizerExample { val regexTokenized = regexTokenizer.transform(sentenceDataFrame) regexTokenized.select("words", "label").take(3).foreach(println) // $example off$ - sc.stop() + + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/UnaryTransformerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/UnaryTransformerExample.scala new file mode 100644 index 0000000000000..13c72f88cc83b --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/UnaryTransformerExample.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.UnaryTransformer +import org.apache.spark.ml.param.DoubleParam +import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable} +import org.apache.spark.sql.functions.col +// $example off$ +import org.apache.spark.sql.SparkSession +// $example on$ +import org.apache.spark.sql.types.{DataType, DataTypes} +import org.apache.spark.util.Utils +// $example off$ + +/** + * An example demonstrating creating a custom [[org.apache.spark.ml.Transformer]] using + * the [[UnaryTransformer]] abstraction. + * + * Run with + * {{{ + * bin/run-example ml.UnaryTransformerExample + * }}} + */ +object UnaryTransformerExample { + + // $example on$ + /** + * Simple Transformer which adds a constant value to input Doubles. + * + * [[UnaryTransformer]] can be used to create a stage usable within Pipelines. + * It defines parameters for specifying input and output columns: + * [[UnaryTransformer.inputCol]] and [[UnaryTransformer.outputCol]]. + * It can optionally handle schema validation. + * + * [[DefaultParamsWritable]] provides a default implementation for persisting instances + * of this Transformer. + */ + class MyTransformer(override val uid: String) + extends UnaryTransformer[Double, Double, MyTransformer] with DefaultParamsWritable { + + final val shift: DoubleParam = new DoubleParam(this, "shift", "Value added to input") + + def getShift: Double = $(shift) + + def setShift(value: Double): this.type = set(shift, value) + + def this() = this(Identifiable.randomUID("myT")) + + override protected def createTransformFunc: Double => Double = (input: Double) => { + input + $(shift) + } + + override protected def outputDataType: DataType = DataTypes.DoubleType + + override protected def validateInputType(inputType: DataType): Unit = { + require(inputType == DataTypes.DoubleType, s"Bad input type: $inputType. Requires Double.") + } + } + + /** + * Companion object for our simple Transformer. 
+ * + * [[DefaultParamsReadable]] provides a default implementation for loading instances + * of this Transformer which were persisted using [[DefaultParamsWritable]]. + */ + object MyTransformer extends DefaultParamsReadable[MyTransformer] + // $example off$ + + def main(args: Array[String]) { + val spark = SparkSession + .builder() + .appName("UnaryTransformerExample") + .getOrCreate() + + // $example on$ + val myTransformer = new MyTransformer() + .setShift(0.5) + .setInputCol("input") + .setOutputCol("output") + + // Create data, transform, and display it. + val data = spark.range(0, 5).toDF("input") + .select(col("input").cast("double").as("input")) + val result = myTransformer.transform(data) + result.show() + + // Save and load the Transformer. + val tmpDir = Utils.createTempDir() + val dirName = tmpDir.getCanonicalPath + myTransformer.write.overwrite().save(dirName) + val sameTransformer = MyTransformer.load(dirName) + + // Transform the data to show the results are identical. + val sameResult = sameTransformer.transform(data) + sameResult.show() + + Utils.deleteRecursively(tmpDir) + // $example off$ + + spark.stop() + } +} +// scalastyle:on println + diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala index 768a8c0690477..8910470c1cf7a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala @@ -18,21 +18,21 @@ // scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.VectorAssembler -import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.ml.linalg.Vectors // $example off$ -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession object VectorAssemblerExample { def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("VectorAssemblerExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) + val spark = SparkSession + .builder + .appName("VectorAssemblerExample") + .getOrCreate() // $example on$ - val dataset = sqlContext.createDataFrame( + val dataset = spark.createDataFrame( Seq((0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5), 1.0)) ).toDF("id", "hour", "mobile", "userFeatures", "clicked") @@ -43,7 +43,8 @@ object VectorAssemblerExample { val output = assembler.transform(dataset) println(output.select("features", "clicked").first()) // $example off$ - sc.stop() + + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala index 3bef37ba360b9..afa761aee0b98 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala @@ -18,20 +18,20 @@ // scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.VectorIndexer // $example off$ -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession object VectorIndexerExample { def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("VectorIndexerExample") - val sc = new SparkContext(conf) - 
val sqlContext = new SQLContext(sc) + val spark = SparkSession + .builder + .appName("VectorIndexerExample") + .getOrCreate() // $example on$ - val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + val data = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") val indexer = new VectorIndexer() .setInputCol("features") @@ -48,7 +48,8 @@ object VectorIndexerExample { val indexedData = indexerModel.transform(data) indexedData.show() // $example off$ - sc.stop() + + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/VectorSlicerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/VectorSlicerExample.scala index 01377d80e7e5c..85dd5c27766c2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/VectorSlicerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/VectorSlicerExample.scala @@ -18,31 +18,32 @@ // scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.{SparkConf, SparkContext} // $example on$ +import java.util.Arrays + import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute} import org.apache.spark.ml.feature.VectorSlicer -import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.ml.linalg.Vectors import org.apache.spark.sql.Row import org.apache.spark.sql.types.StructType // $example off$ -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession object VectorSlicerExample { def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("VectorSlicerExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) + val spark = SparkSession + .builder + .appName("VectorSlicerExample") + .getOrCreate() // $example on$ - val data = Array(Row(Vectors.dense(-2.0, 2.3, 0.0))) + val data = Arrays.asList(Row(Vectors.dense(-2.0, 2.3, 0.0))) val defaultAttr = NumericAttribute.defaultAttr val attrs = Array("f1", "f2", "f3").map(defaultAttr.withName) val attrGroup = new AttributeGroup("userFeatures", attrs.asInstanceOf[Array[Attribute]]) - val dataRDD = sc.parallelize(data) - val dataset = sqlContext.createDataFrame(dataRDD, StructType(Array(attrGroup.toStructField()))) + val dataset = spark.createDataFrame(data, StructType(Array(attrGroup.toStructField()))) val slicer = new VectorSlicer().setInputCol("userFeatures").setOutputCol("features") @@ -52,7 +53,8 @@ object VectorSlicerExample { val output = slicer.transform(dataset) println(output.select("userFeatures", "features").first()) // $example off$ - sc.stop() + + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/Word2VecExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/Word2VecExample.scala index e77aa59ba32b2..9ac5623607296 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/Word2VecExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/Word2VecExample.scala @@ -18,21 +18,21 @@ // scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.Word2Vec // $example off$ -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession object Word2VecExample { def main(args: Array[String]) { - val conf = new SparkConf().setAppName("Word2Vec example") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) + val spark = SparkSession + 
.builder + .appName("Word2Vec example") + .getOrCreate() // $example on$ // Input data: Each row is a bag of words from a sentence or document. - val documentDF = sqlContext.createDataFrame(Seq( + val documentDF = spark.createDataFrame(Seq( "Hi I heard about Spark".split(" "), "I wish Java could use case classes".split(" "), "Logistic regression models are neat".split(" ") @@ -48,6 +48,8 @@ object Word2VecExample { val result = model.transform(documentDF) result.select("result").take(3).foreach(println) // $example off$ + + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index ee811d3aa1015..a85aa2cac9e1b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -295,11 +295,10 @@ object DecisionTreeRunner { } if (params.algo == Classification) { val trainAccuracy = - new MulticlassMetrics(training.map(lp => (model.predict(lp.features), lp.label))) - .precision + new MulticlassMetrics(training.map(lp => (model.predict(lp.features), lp.label))).accuracy println(s"Train accuracy = $trainAccuracy") val testAccuracy = - new MulticlassMetrics(test.map(lp => (model.predict(lp.features), lp.label))).precision + new MulticlassMetrics(test.map(lp => (model.predict(lp.features), lp.label))).accuracy println(s"Test accuracy = $testAccuracy") } if (params.algo == Regression) { @@ -322,11 +321,10 @@ object DecisionTreeRunner { println(model) // Print model summary. } val trainAccuracy = - new MulticlassMetrics(training.map(lp => (model.predict(lp.features), lp.label))) - .precision + new MulticlassMetrics(training.map(lp => (model.predict(lp.features), lp.label))).accuracy println(s"Train accuracy = $trainAccuracy") val testAccuracy = - new MulticlassMetrics(test.map(lp => (model.predict(lp.features), lp.label))).precision + new MulticlassMetrics(test.map(lp => (model.predict(lp.features), lp.label))).accuracy println(s"Test accuracy = $testAccuracy") } if (params.algo == Regression) { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala index b0144ef533133..90e4687c1f444 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala @@ -120,11 +120,10 @@ object GradientBoostedTreesRunner { println(model) // Print model summary. 
} val trainAccuracy = - new MulticlassMetrics(training.map(lp => (model.predict(lp.features), lp.label))) - .precision + new MulticlassMetrics(training.map(lp => (model.predict(lp.features), lp.label))).accuracy println(s"Train accuracy = $trainAccuracy") val testAccuracy = - new MulticlassMetrics(test.map(lp => (model.predict(lp.features), lp.label))).precision + new MulticlassMetrics(test.map(lp => (model.predict(lp.features), lp.label))).accuracy println(s"Test accuracy = $testAccuracy") } else if (params.algo == "Regression") { val startTime = System.nanoTime() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala index c4336639d7c0b..e5dea129c113d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala @@ -21,6 +21,7 @@ package org.apache.spark.examples.mllib import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.mllib.regression.{IsotonicRegression, IsotonicRegressionModel} +import org.apache.spark.mllib.util.MLUtils // $example off$ object IsotonicRegressionExample { @@ -30,12 +31,12 @@ object IsotonicRegressionExample { val conf = new SparkConf().setAppName("IsotonicRegressionExample") val sc = new SparkContext(conf) // $example on$ - val data = sc.textFile("data/mllib/sample_isotonic_regression_data.txt") + val data = MLUtils.loadLibSVMFile(sc, + "data/mllib/sample_isotonic_regression_libsvm_data.txt").cache() // Create label, feature, weight tuples from input data with weight set to default value 1.0. - val parsedData = data.map { line => - val parts = line.split(',').map(_.toDouble) - (parts(0), parts(1), 1.0) + val parsedData = data.map { labeledPoint => + (labeledPoint.label, labeledPoint.features(0), 1.0) } // Split data into training (60%) and test (40%) sets. diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala index e89d555884dd0..ef67841f0cbee 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala @@ -24,10 +24,11 @@ import scopt.OptionParser import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.ml.Pipeline import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel, RegexTokenizer, StopWordsRemover} +import org.apache.spark.ml.linalg.{Vector => MLVector} import org.apache.spark.mllib.clustering.{DistributedLDAModel, EMLDAOptimizer, LDA, OnlineLDAOptimizer} -import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.{Row, SparkSession} /** * An example Latent Dirichlet Allocation (LDA) app. Run with @@ -189,8 +190,11 @@ object LDAExample { vocabSize: Int, stopwordFile: String): (RDD[(Long, Vector)], Array[String], Long) = { - val sqlContext = SQLContext.getOrCreate(sc) - import sqlContext.implicits._ + val spark = SparkSession + .builder + .sparkContext(sc) + .getOrCreate() + import spark.implicits._ // Get dataset of document texts // One document per line in each text file. 
If the input consists of many small files, @@ -222,7 +226,7 @@ object LDAExample { val documents = model.transform(df) .select("features") .rdd - .map { case Row(features: Vector) => features } + .map { case Row(features: MLVector) => Vectors.fromML(features) } .zipWithIndex() .map(_.swap) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala index f87611f5d4613..a70203028c858 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala @@ -34,6 +34,7 @@ import org.apache.spark.mllib.util.MLUtils * A synthetic dataset can be found at `data/mllib/sample_linear_regression_data.txt`. * If you use it as a template to create your own app, please use `spark-submit` to submit your app. */ +@deprecated("Use ml.regression.LinearRegression or LBFGS", "2.0.0") object LinearRegression { object RegType extends Enumeration { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegressionWithSGDExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegressionWithSGDExample.scala index 669868787e8f0..d399618094487 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegressionWithSGDExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegressionWithSGDExample.scala @@ -26,6 +26,7 @@ import org.apache.spark.mllib.regression.LinearRegressionModel import org.apache.spark.mllib.regression.LinearRegressionWithSGD // $example off$ +@deprecated("Use ml.regression.LinearRegression or LBFGS", "2.0.0") object LinearRegressionWithSGDExample { def main(args: Array[String]): Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LogisticRegressionWithLBFGSExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LogisticRegressionWithLBFGSExample.scala index 632a2d537e5bc..31ba740ad4af0 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LogisticRegressionWithLBFGSExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LogisticRegressionWithLBFGSExample.scala @@ -54,8 +54,8 @@ object LogisticRegressionWithLBFGSExample { // Get evaluation metrics. 
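Several changes in this patch, such as the isotonic regression and Naive Bayes examples, switch from ad-hoc CSV parsing to `MLUtils.loadLibSVMFile`. A minimal sketch of that loader, assuming a local driver and the bundled sample file; each input line has the form `label index1:value1 index2:value2 ...` with 1-based indices:

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.util.MLUtils

val sc = new SparkContext(new SparkConf().setAppName("LibSVMLoadSketch").setMaster("local[*]"))

// Returns an RDD[LabeledPoint] with sparse feature vectors.
val points = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
points.take(3).foreach(println)

sc.stop()
```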
val metrics = new MulticlassMetrics(predictionAndLabels) - val precision = metrics.precision - println("Precision = " + precision) + val accuracy = metrics.accuracy + println(s"Accuracy = $accuracy") // Save and load model model.save(sc, "target/tmp/scalaLogisticRegressionWithLBFGSModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MulticlassMetricsExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MulticlassMetricsExample.scala index 4f925ede24d82..12394c867e973 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MulticlassMetricsExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MulticlassMetricsExample.scala @@ -59,13 +59,9 @@ object MulticlassMetricsExample { println(metrics.confusionMatrix) // Overall Statistics - val precision = metrics.precision - val recall = metrics.recall // same as true positive rate - val f1Score = metrics.fMeasure + val accuracy = metrics.accuracy println("Summary Statistics") - println(s"Precision = $precision") - println(s"Recall = $recall") - println(s"F1 Score = $f1Score") + println(s"Accuracy = $accuracy") // Precision by label val labels = metrics.labels diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/NaiveBayesExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/NaiveBayesExample.scala index 0187ad603a654..b321d8e127aa5 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/NaiveBayesExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/NaiveBayesExample.scala @@ -21,8 +21,7 @@ package org.apache.spark.examples.mllib import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.mllib.classification.{NaiveBayes, NaiveBayesModel} -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.MLUtils // $example off$ object NaiveBayesExample { @@ -31,16 +30,11 @@ object NaiveBayesExample { val conf = new SparkConf().setAppName("NaiveBayesExample") val sc = new SparkContext(conf) // $example on$ - val data = sc.textFile("data/mllib/sample_naive_bayes_data.txt") - val parsedData = data.map { line => - val parts = line.split(',') - LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble))) - } + // Load and parse the data file. + val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") // Split data into training (60%) and test (40%). - val splits = parsedData.randomSplit(Array(0.6, 0.4), seed = 11L) - val training = splits(0) - val test = splits(1) + val Array(training, test) = data.randomSplit(Array(0.6, 0.4)) val model = NaiveBayes.train(training, lambda = 1.0, modelType = "multinomial") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PCAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PCAExample.scala index f7a813695304f..eb36697d94ba1 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/PCAExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PCAExample.scala @@ -26,6 +26,7 @@ import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.{LabeledPoint, LinearRegressionWithSGD} // $example off$ +@deprecated("Deprecated since LinearRegressionWithSGD is deprecated. 
Use ml.feature.PCA", "2.0.0") object PCAExample { def main(args: Array[String]): Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RankingMetricsExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RankingMetricsExample.scala index fdb01b86dd787..d514891da78fc 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/RankingMetricsExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RankingMetricsExample.scala @@ -18,22 +18,22 @@ // scalastyle:off println package org.apache.spark.examples.mllib -import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.mllib.evaluation.{RankingMetrics, RegressionMetrics} import org.apache.spark.mllib.recommendation.{ALS, Rating} // $example off$ -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession object RankingMetricsExample { def main(args: Array[String]) { - val conf = new SparkConf().setAppName("RankingMetricsExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - import sqlContext.implicits._ + val spark = SparkSession + .builder + .appName("RankingMetricsExample") + .getOrCreate() + import spark.implicits._ // $example on$ // Read in the ratings data - val ratings = sc.textFile("data/mllib/sample_movielens_data.txt").map { line => + val ratings = spark.read.textFile("data/mllib/sample_movielens_data.txt").rdd.map { line => val fields = line.split("::") Rating(fields(0).toInt, fields(1).toInt, fields(2).toDouble - 2.5) }.cache() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala index add634c957b40..76cfb804e18f3 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala @@ -18,22 +18,27 @@ package org.apache.spark.examples.mllib -import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.mllib.evaluation.RegressionMetrics -import org.apache.spark.mllib.regression.LinearRegressionWithSGD -import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.regression.{LabeledPoint, LinearRegressionWithSGD} // $example off$ -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession +@deprecated("Use ml.regression.LinearRegression and the resulting model summary for metrics", + "2.0.0") object RegressionMetricsExample { def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("RegressionMetricsExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) + val spark = SparkSession + .builder + .appName("RegressionMetricsExample") + .getOrCreate() // $example on$ // Load the data - val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_linear_regression_data.txt").cache() + val data = spark + .read.format("libsvm").load("data/mllib/sample_linear_regression_data.txt") + .rdd.map(row => LabeledPoint(row.getDouble(0), row.get(1).asInstanceOf[Vector])) + .cache() // Build the model val numIterations = 100 @@ -61,6 +66,8 @@ object RegressionMetricsExample { // Explained variance println(s"Explained variance = ${metrics.explainedVariance}") // $example off$ + + spark.stop() } } // scalastyle:on println diff --git 
a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala index 8ce4427c53997..deaa9f252b9b0 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala @@ -18,8 +18,10 @@ // scalastyle:off println package org.apache.spark.examples.sql -import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.sql.{SaveMode, SparkSession} +import org.apache.spark.sql.SaveMode +// $example on:init_session$ +import org.apache.spark.sql.SparkSession +// $example off:init_session$ // One method for defining the schema of an RDD is to make a case class with the desired column // names and types. @@ -27,17 +29,21 @@ case class Record(key: Int, value: String) object RDDRelation { def main(args: Array[String]) { - val sparkConf = new SparkConf().setAppName("RDDRelation") - val sc = new SparkContext(sparkConf) - val spark = new SparkSession(sc) + // $example on:init_session$ + val spark = SparkSession + .builder + .appName("Spark Examples") + .config("spark.some.config.option", "some-value") + .getOrCreate() // Importing the SparkSession gives access to all the SQL functions and implicit conversions. import spark.implicits._ + // $example off:init_session$ - val df = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i"))).toDF() - // Any RDD containing case classes can be registered as a table. The schema of the table is - // automatically inferred using scala reflection. - df.registerTempTable("records") + val df = spark.createDataFrame((1 to 100).map(i => Record(i, s"val_$i"))) + // Any RDD containing case classes can be used to create a temporary view. The schema of the + // view is automatically inferred using scala reflection. + df.createOrReplaceTempView("records") // Once tables have been registered, you can run SQL queries over them. println("Result of SELECT *:") @@ -66,11 +72,11 @@ object RDDRelation { // Queries can be run using the DSL on parquet files just like the original RDD. parquetFile.where($"key" === 1).select($"value".as("a")).collect().foreach(println) - // These files can also be registered as tables. - parquetFile.registerTempTable("parquetFile") + // These files can also be used to create a temporary view. + parquetFile.createOrReplaceTempView("parquetFile") spark.sql("SELECT * FROM parquetFile").collect().foreach(println) - sc.stop() + spark.stop() } } // scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala b/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala new file mode 100644 index 0000000000000..dc3915a4882b0 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.examples.sql + +import org.apache.spark.sql.SparkSession + +object SQLDataSourceExample { + + case class Person(name: String, age: Long) + + def main(args: Array[String]) { + val spark = SparkSession + .builder() + .appName("Spark SQL data sources example") + .config("spark.some.config.option", "some-value") + .getOrCreate() + + runBasicDataSourceExample(spark) + runBasicParquetExample(spark) + runParquetSchemaMergingExample(spark) + runJsonDatasetExample(spark) + runJdbcDatasetExample(spark) + + spark.stop() + } + + private def runBasicDataSourceExample(spark: SparkSession): Unit = { + // $example on:generic_load_save_functions$ + val usersDF = spark.read.load("examples/src/main/resources/users.parquet") + usersDF.select("name", "favorite_color").write.save("namesAndFavColors.parquet") + // $example off:generic_load_save_functions$ + // $example on:manual_load_options$ + val peopleDF = spark.read.format("json").load("examples/src/main/resources/people.json") + peopleDF.select("name", "age").write.format("parquet").save("namesAndAges.parquet") + // $example off:manual_load_options$ + // $example on:direct_sql$ + val sqlDF = spark.sql("SELECT * FROM parquet.`examples/src/main/resources/users.parquet`") + // $example off:direct_sql$ + } + + private def runBasicParquetExample(spark: SparkSession): Unit = { + // $example on:basic_parquet_example$ + // Encoders for most common types are automatically provided by importing spark.implicits._ + import spark.implicits._ + + val peopleDF = spark.read.json("examples/src/main/resources/people.json") + + // DataFrames can be saved as Parquet files, maintaining the schema information + peopleDF.write.parquet("people.parquet") + + // Read in the parquet file created above + // Parquet files are self-describing so the schema is preserved + // The result of loading a Parquet file is also a DataFrame + val parquetFileDF = spark.read.parquet("people.parquet") + + // Parquet files can also be used to create a temporary view and then used in SQL statements + parquetFileDF.createOrReplaceTempView("parquetFile") + val namesDF = spark.sql("SELECT name FROM parquetFile WHERE age BETWEEN 13 AND 19") + namesDF.map(attributes => "Name: " + attributes(0)).show() + // +------------+ + // | value| + // +------------+ + // |Name: Justin| + // +------------+ + // $example off:basic_parquet_example$ + } + + private def runParquetSchemaMergingExample(spark: SparkSession): Unit = { + // $example on:schema_merging$ + // This is used to implicitly convert an RDD to a DataFrame. 
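One detail the generic load/save calls above leave implicit is the save mode: `save` fails if the target path already exists unless a mode is chosen. A hedged sketch of the same write with an explicit mode (the file names mirror the example; the session setup is illustrative only):

```scala
import org.apache.spark.sql.{SaveMode, SparkSession}

val spark = SparkSession.builder().appName("SaveModeSketch").getOrCreate()

val usersDF = spark.read.load("examples/src/main/resources/users.parquet")
usersDF.select("name", "favorite_color")
  .write
  .mode(SaveMode.Overwrite)  // alternatives: Append, Ignore, ErrorIfExists (the default)
  .save("namesAndFavColors.parquet")

spark.stop()
```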
+ import spark.implicits._ + + // Create a simple DataFrame, store into a partition directory + val squaresDF = spark.sparkContext.makeRDD(1 to 5).map(i => (i, i * i)).toDF("value", "square") + squaresDF.write.parquet("data/test_table/key=1") + + // Create another DataFrame in a new partition directory, + // adding a new column and dropping an existing column + val cubesDF = spark.sparkContext.makeRDD(6 to 10).map(i => (i, i * i * i)).toDF("value", "cube") + cubesDF.write.parquet("data/test_table/key=2") + + // Read the partitioned table + val mergedDF = spark.read.option("mergeSchema", "true").parquet("data/test_table") + mergedDF.printSchema() + + // The final schema consists of all 3 columns in the Parquet files together + // with the partitioning column appeared in the partition directory paths + // root + // |-- value: int (nullable = true) + // |-- square: int (nullable = true) + // |-- cube: int (nullable = true) + // |-- key: int (nullable = true) + // $example off:schema_merging$ + } + + private def runJsonDatasetExample(spark: SparkSession): Unit = { + // $example on:json_dataset$ + // A JSON dataset is pointed to by path. + // The path can be either a single text file or a directory storing text files + val path = "examples/src/main/resources/people.json" + val peopleDF = spark.read.json(path) + + // The inferred schema can be visualized using the printSchema() method + peopleDF.printSchema() + // root + // |-- age: long (nullable = true) + // |-- name: string (nullable = true) + + // Creates a temporary view using the DataFrame + peopleDF.createOrReplaceTempView("people") + + // SQL statements can be run by using the sql methods provided by spark + val teenagerNamesDF = spark.sql("SELECT name FROM people WHERE age BETWEEN 13 AND 19") + teenagerNamesDF.show() + // +------+ + // | name| + // +------+ + // |Justin| + // +------+ + + // Alternatively, a DataFrame can be created for a JSON dataset represented by + // an RDD[String] storing one JSON object per string + val otherPeopleRDD = spark.sparkContext.makeRDD( + """{"name":"Yin","address":{"city":"Columbus","state":"Ohio"}}""" :: Nil) + val otherPeople = spark.read.json(otherPeopleRDD) + otherPeople.show() + // +---------------+----+ + // | address|name| + // +---------------+----+ + // |[Columbus,Ohio]| Yin| + // +---------------+----+ + // $example off:json_dataset$ + } + + private def runJdbcDatasetExample(spark: SparkSession): Unit = { + // $example on:jdbc_dataset$ + val jdbcDF = spark.read + .format("jdbc") + .option("url", "jdbc:postgresql:dbserver") + .option("dbtable", "schema.tablename") + .option("user", "username") + .option("password", "password") + .load() + // $example off:jdbc_dataset$ + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/SparkSQLExample.scala b/examples/src/main/scala/org/apache/spark/examples/sql/SparkSQLExample.scala new file mode 100644 index 0000000000000..129b81d5fbbf3 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/sql/SparkSQLExample.scala @@ -0,0 +1,254 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.examples.sql + +// $example on:schema_inferring$ +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.Encoder +// $example off:schema_inferring$ +import org.apache.spark.sql.Row +// $example on:init_session$ +import org.apache.spark.sql.SparkSession +// $example off:init_session$ +// $example on:programmatic_schema$ +// $example on:data_types$ +import org.apache.spark.sql.types._ +// $example off:data_types$ +// $example off:programmatic_schema$ + +object SparkSQLExample { + + // $example on:create_ds$ + // Note: Case classes in Scala 2.10 can support only up to 22 fields. To work around this limit, + // you can use custom classes that implement the Product interface + case class Person(name: String, age: Long) + // $example off:create_ds$ + + def main(args: Array[String]) { + // $example on:init_session$ + val spark = SparkSession + .builder() + .appName("Spark SQL basic example") + .config("spark.some.config.option", "some-value") + .getOrCreate() + + // For implicit conversions like converting RDDs to DataFrames + import spark.implicits._ + // $example off:init_session$ + + runBasicDataFrameExample(spark) + runDatasetCreationExample(spark) + runInferSchemaExample(spark) + runProgrammaticSchemaExample(spark) + + spark.stop() + } + + private def runBasicDataFrameExample(spark: SparkSession): Unit = { + // $example on:create_df$ + val df = spark.read.json("examples/src/main/resources/people.json") + + // Displays the content of the DataFrame to stdout + df.show() + // +----+-------+ + // | age| name| + // +----+-------+ + // |null|Michael| + // | 30| Andy| + // | 19| Justin| + // +----+-------+ + // $example off:create_df$ + + // $example on:untyped_ops$ + // This import is needed to use the $-notation + import spark.implicits._ + // Print the schema in a tree format + df.printSchema() + // root + // |-- age: long (nullable = true) + // |-- name: string (nullable = true) + + // Select only the "name" column + df.select("name").show() + // +-------+ + // | name| + // +-------+ + // |Michael| + // | Andy| + // | Justin| + // +-------+ + + // Select everybody, but increment the age by 1 + df.select($"name", $"age" + 1).show() + // +-------+---------+ + // | name|(age + 1)| + // +-------+---------+ + // |Michael| null| + // | Andy| 31| + // | Justin| 20| + // +-------+---------+ + + // Select people older than 21 + df.filter($"age" > 21).show() + // +---+----+ + // |age|name| + // +---+----+ + // | 30|Andy| + // +---+----+ + + // Count people by age + df.groupBy("age").count().show() + // +----+-----+ + // | age|count| + // +----+-----+ + // | 19| 1| + // |null| 1| + // | 30| 1| + // +----+-----+ + // $example off:untyped_ops$ + + // $example on:run_sql$ + // Register the DataFrame as a SQL temporary view + df.createOrReplaceTempView("people") + + val sqlDF = spark.sql("SELECT * FROM people") + sqlDF.show() + // +----+-------+ + // | age| name| + // +----+-------+ + // |null|Michael| + // | 30| Andy| + // | 19| Justin| + // +----+-------+ + // $example off:run_sql$ + } + + private def 
runDatasetCreationExample(spark: SparkSession): Unit = { + import spark.implicits._ + // $example on:create_ds$ + // Encoders are created for case classes + val caseClassDS = Seq(Person("Andy", 32)).toDS() + caseClassDS.show() + // +----+---+ + // |name|age| + // +----+---+ + // |Andy| 32| + // +----+---+ + + // Encoders for most common types are automatically provided by importing spark.implicits._ + val primitiveDS = Seq(1, 2, 3).toDS() + primitiveDS.map(_ + 1).collect() // Returns: Array(2, 3, 4) + + // DataFrames can be converted to a Dataset by providing a class. Mapping will be done by name + val path = "examples/src/main/resources/people.json" + val peopleDS = spark.read.json(path).as[Person] + peopleDS.show() + // +----+-------+ + // | age| name| + // +----+-------+ + // |null|Michael| + // | 30| Andy| + // | 19| Justin| + // +----+-------+ + // $example off:create_ds$ + } + + private def runInferSchemaExample(spark: SparkSession): Unit = { + // $example on:schema_inferring$ + // For implicit conversions from RDDs to DataFrames + import spark.implicits._ + + // Create an RDD of Person objects from a text file, convert it to a Dataframe + val peopleDF = spark.sparkContext + .textFile("examples/src/main/resources/people.txt") + .map(_.split(",")) + .map(attributes => Person(attributes(0), attributes(1).trim.toInt)) + .toDF() + // Register the DataFrame as a temporary view + peopleDF.createOrReplaceTempView("people") + + // SQL statements can be run by using the sql methods provided by Spark + val teenagersDF = spark.sql("SELECT name, age FROM people WHERE age BETWEEN 13 AND 19") + + // The columns of a row in the result can be accessed by field index + teenagersDF.map(teenager => "Name: " + teenager(0)).show() + // +------------+ + // | value| + // +------------+ + // |Name: Justin| + // +------------+ + + // or by field name + teenagersDF.map(teenager => "Name: " + teenager.getAs[String]("name")).show() + // +------------+ + // | value| + // +------------+ + // |Name: Justin| + // +------------+ + + // No pre-defined encoders for Dataset[Map[K,V]], define explicitly + implicit val mapEncoder = org.apache.spark.sql.Encoders.kryo[Map[String, Any]] + // Primitive types and case classes can be also defined as + // implicit val stringIntMapEncoder: Encoder[Map[String, Any]] = ExpressionEncoder() + + // row.getValuesMap[T] retrieves multiple columns at once into a Map[String, T] + teenagersDF.map(teenager => teenager.getValuesMap[Any](List("name", "age"))).collect() + // Array(Map("name" -> "Justin", "age" -> 19)) + // $example off:schema_inferring$ + } + + private def runProgrammaticSchemaExample(spark: SparkSession): Unit = { + import spark.implicits._ + // $example on:programmatic_schema$ + // Create an RDD + val peopleRDD = spark.sparkContext.textFile("examples/src/main/resources/people.txt") + + // The schema is encoded in a string + val schemaString = "name age" + + // Generate the schema based on the string of schema + val fields = schemaString.split(" ") + .map(fieldName => StructField(fieldName, StringType, nullable = true)) + val schema = StructType(fields) + + // Convert records of the RDD (people) to Rows + val rowRDD = peopleRDD + .map(_.split(",")) + .map(attributes => Row(attributes(0), attributes(1).trim)) + + // Apply the schema to the RDD + val peopleDF = spark.createDataFrame(rowRDD, schema) + + // Creates a temporary view using the DataFrame + peopleDF.createOrReplaceTempView("people") + + // SQL can be run over a temporary view created using DataFrames + val results = 
spark.sql("SELECT name FROM people") + + // The results of SQL queries are DataFrames and support all the normal RDD operations + // The columns of a row in the result can be accessed by field index or by field name + results.map(attributes => "Name: " + attributes(0)).show() + // +-------------+ + // | value| + // +-------------+ + // |Name: Michael| + // | Name: Andy| + // | Name: Justin| + // +-------------+ + // $example off:programmatic_schema$ + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala deleted file mode 100644 index ff33091621c14..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.sql.hive - -import java.io.File - -import com.google.common.io.{ByteStreams, Files} - -import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.sql._ - -object HiveFromSpark { - case class Record(key: Int, value: String) - - // Copy kv1.txt file from classpath to temporary directory - val kv1Stream = HiveFromSpark.getClass.getResourceAsStream("/kv1.txt") - val kv1File = File.createTempFile("kv1", "txt") - kv1File.deleteOnExit() - ByteStreams.copy(kv1Stream, Files.newOutputStreamSupplier(kv1File)) - - def main(args: Array[String]) { - val sparkConf = new SparkConf().setAppName("HiveFromSpark") - val sc = new SparkContext(sparkConf) - - // A hive context adds support for finding tables in the MetaStore and writing queries - // using HiveQL. Users who do not have an existing Hive deployment can still create a - // HiveContext. When not configured by the hive-site.xml, the context automatically - // creates metastore_db and warehouse in the current directory. - val sparkSession = SparkSession.withHiveSupport(sc) - import sparkSession.implicits._ - import sparkSession.sql - - sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") - sql(s"LOAD DATA LOCAL INPATH '${kv1File.getAbsolutePath}' INTO TABLE src") - - // Queries are expressed in HiveQL - println("Result of 'SELECT *': ") - sql("SELECT * FROM src").collect().foreach(println) - - // Aggregation queries are also supported. - val count = sql("SELECT COUNT(*) FROM src").collect().head.getLong(0) - println(s"COUNT(*): $count") - - // The results of SQL queries are themselves RDDs and support all normal RDD functions. The - // items in the RDD are of type Row, which allows you to access each column by ordinal. 
- val rddFromSql = sql("SELECT key, value FROM src WHERE key < 10 ORDER BY key") - - println("Result of RDD.map:") - val rddAsStrings = rddFromSql.rdd.map { - case Row(key: Int, value: String) => s"Key: $key, Value: $value" - } - - // You can also register RDDs as temporary tables within a HiveContext. - val rdd = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i"))) - rdd.toDF().registerTempTable("records") - - // Queries can then join RDD data with data stored in Hive. - println("Result of SELECT *:") - sql("SELECT * FROM records r JOIN src s ON r.key = s.key").collect().foreach(println) - - sc.stop() - } -} -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/SparkHiveExample.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/SparkHiveExample.scala new file mode 100644 index 0000000000000..ded18dacf1fe3 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/SparkHiveExample.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.examples.sql.hive + +// $example on:spark_hive$ +import org.apache.spark.sql.Row +import org.apache.spark.sql.SparkSession +// $example off:spark_hive$ + +object SparkHiveExample { + + // $example on:spark_hive$ + case class Record(key: Int, value: String) + // $example off:spark_hive$ + + def main(args: Array[String]) { + // When working with Hive, one must instantiate `SparkSession` with Hive support, including + // connectivity to a persistent Hive metastore, support for Hive serdes, and Hive user-defined + // functions. Users who do not have an existing Hive deployment can still enable Hive support. + // When not configured by the hive-site.xml, the context automatically creates `metastore_db` + // in the current directory and creates a directory configured by `spark.sql.warehouse.dir`, + // which defaults to the directory `spark-warehouse` in the current directory that the spark + // application is started. + + // $example on:spark_hive$ + // warehouseLocation points to the default location for managed databases and tables + val warehouseLocation = "spark-warehouse" + + val spark = SparkSession + .builder() + .appName("Spark Hive Example") + .config("spark.sql.warehouse.dir", warehouseLocation) + .enableHiveSupport() + .getOrCreate() + + import spark.implicits._ + import spark.sql + + sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") + sql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") + + // Queries are expressed in HiveQL + sql("SELECT * FROM src").show() + // +---+-------+ + // |key| value| + // +---+-------+ + // |238|val_238| + // | 86| val_86| + // |311|val_311| + // ... + + // Aggregation queries are also supported. 
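The Hive-backed table created in this example can also be read back through the DataFrame API rather than SQL strings; a small sketch, assuming `spark` is the Hive-enabled session built above (the aggregation query itself follows next in the example):

```scala
// Assumes the Hive-enabled `spark` session from SparkHiveExample above.
val srcDF = spark.table("src")            // same rows as sql("SELECT * FROM src")
srcDF.filter(srcDF("key") < 10).show()    // DataFrame filter instead of a WHERE clause
```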
+ sql("SELECT COUNT(*) FROM src").show() + // +--------+ + // |count(1)| + // +--------+ + // | 500 | + // +--------+ + + // The results of SQL queries are themselves DataFrames and support all normal functions. + val sqlDF = sql("SELECT key, value FROM src WHERE key < 10 ORDER BY key") + + // The items in DaraFrames are of type Row, which allows you to access each column by ordinal. + val stringsDS = sqlDF.map { + case Row(key: Int, value: String) => s"Key: $key, Value: $value" + } + stringsDS.show() + // +--------------------+ + // | value| + // +--------------------+ + // |Key: 0, Value: val_0| + // |Key: 0, Value: val_0| + // |Key: 0, Value: val_0| + // ... + + // You can also use DataFrames to create temporary views within a SparkSession. + val recordsDF = spark.createDataFrame((1 to 100).map(i => Record(i, s"val_$i"))) + recordsDF.createOrReplaceTempView("records") + + // Queries can then join DataFrame data with data stored in Hive. + sql("SELECT * FROM records r JOIN src s ON r.key = s.key").show() + // +---+------+---+------+ + // |key| value|key| value| + // +---+------+---+------+ + // | 2| val_2| 2| val_2| + // | 4| val_4| 4| val_4| + // | 5| val_5| 5| val_5| + // ... + // $example off:spark_hive$ + + spark.stop() + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredNetworkWordCount.scala new file mode 100644 index 0000000000000..364bff227bc55 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredNetworkWordCount.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.sql.streaming + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.SparkSession + +/** + * Counts words in UTF8 encoded, '\n' delimited text received from the network. + * + * Usage: StructuredNetworkWordCount + * and describe the TCP server that Structured Streaming + * would connect to receive data. 
+ * + * To run this on your local machine, you need to first run a Netcat server + * `$ nc -lk 9999` + * and then run the example + * `$ bin/run-example sql.streaming.StructuredNetworkWordCount + * localhost 9999` + */ +object StructuredNetworkWordCount { + def main(args: Array[String]) { + if (args.length < 2) { + System.err.println("Usage: StructuredNetworkWordCount <hostname> <port>") + System.exit(1) + } + + val host = args(0) + val port = args(1).toInt + + val spark = SparkSession + .builder + .appName("StructuredNetworkWordCount") + .getOrCreate() + + import spark.implicits._ + + // Create DataFrame representing the stream of input lines from connection to host:port + val lines = spark.readStream + .format("socket") + .option("host", host) + .option("port", port) + .load().as[String] + + // Split the lines into words + val words = lines.flatMap(_.split(" ")) + + // Generate running word count + val wordCounts = words.groupBy("value").count() + + // Start running the query that prints the running counts to the console + val query = wordCounts.writeStream + .outputMode("complete") + .format("console") + .start() + + query.awaitTermination() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredNetworkWordCountWindowed.scala b/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredNetworkWordCountWindowed.scala new file mode 100644 index 0000000000000..333b0a9d24f40 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredNetworkWordCountWindowed.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.sql.streaming + +import java.sql.Timestamp + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.functions._ + +/** + * Counts words in UTF8 encoded, '\n' delimited text received from the network over a + * sliding window of configurable duration. Each line from the network is tagged + * with a timestamp that is used to determine the windows into which it falls. + * + * Usage: StructuredNetworkWordCountWindowed <hostname> <port> <window duration> + * [<slide duration>] + * <hostname> and <port> describe the TCP server that Structured Streaming + * would connect to receive data. + * <window duration> gives the size of window, specified as integer number of seconds + * <slide duration> gives the amount of time successive windows are offset from one another, + * given in the same units as above. <slide duration> should be less than or equal to + * <window duration>. If the two are equal, successive windows have no overlap. If + * <slide duration> is not provided, it defaults to <window duration>. 
+ * + * To run this on your local machine, you need to first run a Netcat server + * `$ nc -lk 9999` + * and then run the example + * `$ bin/run-example sql.streaming.StructuredNetworkWordCountWindowed + * localhost 9999 <window duration in seconds> [<slide duration in seconds>]` + * + * One recommended <window duration>, <slide duration> pair is 10, 5 + */ +object StructuredNetworkWordCountWindowed { + + def main(args: Array[String]) { + if (args.length < 3) { + System.err.println("Usage: StructuredNetworkWordCountWindowed <hostname> <port>" + + " <window duration in seconds> [<slide duration in seconds>]") + System.exit(1) + } + + val host = args(0) + val port = args(1).toInt + val windowSize = args(2).toInt + val slideSize = if (args.length == 3) windowSize else args(3).toInt + if (slideSize > windowSize) { + System.err.println("<slide duration> must be less than or equal to <window duration>") + } + val windowDuration = s"$windowSize seconds" + val slideDuration = s"$slideSize seconds" + + val spark = SparkSession + .builder + .appName("StructuredNetworkWordCountWindowed") + .getOrCreate() + + import spark.implicits._ + + // Create DataFrame representing the stream of input lines from connection to host:port + val lines = spark.readStream + .format("socket") + .option("host", host) + .option("port", port) + .option("includeTimestamp", true) + .load().as[(String, Timestamp)] + + // Split the lines into words, retaining timestamps + val words = lines.flatMap(line => + line._1.split(" ").map(word => (word, line._2)) + ).toDF("word", "timestamp") + + // Group the data by window and word and compute the count of each group + val windowedCounts = words.groupBy( + window($"timestamp", windowDuration, slideDuration), $"word" + ).count().orderBy("window") + + // Start running the query that prints the windowed word counts to the console + val query = windowedCounts.writeStream + .outputMode("complete") + .format("console") + .option("truncate", "false") + .start() + + query.awaitTermination() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala index 1d144db9864bd..43044d01b1204 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala @@ -29,7 +29,7 @@ import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.streaming.receiver.Receiver /** - * Custom Receiver that receives data over a socket. Received bytes is interpreted as + * Custom Receiver that receives data over a socket. Received bytes are interpreted as * text and \n delimited lines are considered as records. They are then counted and printed. * * To run this on your local machine, you need to first run a Netcat server @@ -50,7 +50,7 @@ object CustomReceiver { val sparkConf = new SparkConf().setAppName("CustomReceiver") val ssc = new StreamingContext(sparkConf, Seconds(1)) - // Create a input stream with the custom receiver on target ip:port and count the + // Create an input stream with the custom receiver on target ip:port and count the // words in input stream of \n delimited text (eg. 
generated by 'nc') val lines = ssc.receiverStream(new CustomReceiver(args(0), args(1).toInt)) val words = lines.flatMap(_.split(" ")) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/QueueStream.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/QueueStream.scala index 5455aed22085d..19bacd449787b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/QueueStream.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/QueueStream.scala @@ -43,7 +43,7 @@ object QueueStream { reducedStream.print() ssc.start() - // Create and push some RDDs into + // Create and push some RDDs into rddQueue for (i <- 1 to 30) { rddQueue.synchronized { rddQueue += ssc.sparkContext.makeRDD(1 to 1000, 10) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala index 1bcd85e1d533f..49c0427321133 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala @@ -23,11 +23,11 @@ import java.nio.charset.Charset import com.google.common.io.Files -import org.apache.spark.{Accumulator, SparkConf, SparkContext} +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.streaming.{Seconds, StreamingContext, Time} -import org.apache.spark.util.IntParam +import org.apache.spark.util.{IntParam, LongAccumulator} /** * Use this singleton to get or register a Broadcast variable. @@ -54,13 +54,13 @@ object WordBlacklist { */ object DroppedWordsCounter { - @volatile private var instance: Accumulator[Long] = null + @volatile private var instance: LongAccumulator = null - def getInstance(sc: SparkContext): Accumulator[Long] = { + def getInstance(sc: SparkContext): LongAccumulator = { if (instance == null) { synchronized { if (instance == null) { - instance = sc.accumulator(0L, "WordsInBlacklistCounter") + instance = sc.longAccumulator("WordsInBlacklistCounter") } } } @@ -124,7 +124,7 @@ object RecoverableNetworkWordCount { // Use blacklist to drop words and use droppedWordsCounter to count them val counts = rdd.filter { case (word, count) => if (blacklist.value.contains(word)) { - droppedWordsCounter += count + droppedWordsCounter.add(count) false } else { true diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala index 918e124065e4c..787bbec73b28f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala @@ -19,9 +19,8 @@ package org.apache.spark.examples.streaming import org.apache.spark.SparkConf -import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Seconds, StreamingContext, Time} @@ -60,19 +59,19 @@ object SqlNetworkWordCount { // Convert RDDs of the words DStream to DataFrame and run SQL query words.foreachRDD { (rdd: RDD[String], time: Time) => - // Get the singleton instance of SQLContext - val sqlContext = 
SQLContextSingleton.getInstance(rdd.sparkContext) - import sqlContext.implicits._ + // Get the singleton instance of SparkSession + val spark = SparkSessionSingleton.getInstance(rdd.sparkContext.getConf) + import spark.implicits._ // Convert RDD[String] to RDD[case class] to DataFrame val wordsDataFrame = rdd.map(w => Record(w)).toDF() - // Register as table - wordsDataFrame.registerTempTable("words") + // Creates a temporary view using the DataFrame + wordsDataFrame.createOrReplaceTempView("words") // Do word count on table using SQL and print it val wordCountsDataFrame = - sqlContext.sql("select word, count(*) as total from words group by word") + spark.sql("select word, count(*) as total from words group by word") println(s"========= $time =========") wordCountsDataFrame.show() } @@ -87,14 +86,17 @@ object SqlNetworkWordCount { case class Record(word: String) -/** Lazily instantiated singleton instance of SQLContext */ -object SQLContextSingleton { +/** Lazily instantiated singleton instance of SparkSession */ +object SparkSessionSingleton { - @transient private var instance: SQLContext = _ + @transient private var instance: SparkSession = _ - def getInstance(sparkContext: SparkContext): SQLContext = { + def getInstance(sparkConf: SparkConf): SparkSession = { if (instance == null) { - instance = new SQLContext(sparkContext) + instance = SparkSession + .builder + .config(sparkConf) + .getOrCreate() } instance } diff --git a/external/docker-integration-tests/pom.xml b/external/docker-integration-tests/pom.xml index 53a24f3e06e08..3ef4c928cb5fd 100644 --- a/external/docker-integration-tests/pom.xml +++ b/external/docker-integration-tests/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.0.0-SNAPSHOT + 2.0.3-SNAPSHOT ../../pom.xml @@ -49,49 +49,16 @@ com.spotify docker-client - shaded test - - - - com.fasterxml.jackson.jaxrs - jackson-jaxrs-json-provider - - - com.fasterxml.jackson.datatype - jackson-datatype-guava - - - com.fasterxml.jackson.core - jackson-databind - - - org.glassfish.jersey.core - jersey-client - - - org.glassfish.jersey.connectors - jersey-apache-connector - - - org.glassfish.jersey.media - jersey-media-json-jackson - - org.apache.httpcomponents httpclient - 4.5 test org.apache.httpcomponents httpcore - 4.4.1 test @@ -128,7 +95,7 @@ org.apache.spark - spark-test-tags_${scala.binary.version} + spark-tags_${scala.binary.version} ${project.version} test @@ -154,43 +121,6 @@ test - - - com.sun.jersey - jersey-server - 1.19 - test - - - com.sun.jersey - jersey-core - 1.19 - test - - - com.sun.jersey - jersey-servlet - 1.19 - test - - - com.sun.jersey - jersey-json - 1.19 - test - - - stax - stax-api - - - - - + + + 4.0.0 + + org.apache.spark + spark-parent_2.11 + 2.0.3-SNAPSHOT + ../../pom.xml + + + org.apache.spark + spark-streaming-kafka-0-10-assembly_2.11 + jar + Spark Integration for Kafka 0.10 Assembly + http://spark.apache.org/ + + + streaming-kafka-0-10-assembly + + + + + org.apache.spark + spark-streaming-kafka-0-10_${scala.binary.version} + ${project.version} + + + org.apache.spark + spark-streaming_${scala.binary.version} + ${project.version} + provided + + + + commons-codec + commons-codec + provided + + + commons-lang + commons-lang + provided + + + com.google.protobuf + protobuf-java + provided + + + net.jpountz.lz4 + lz4 + provided + + + org.apache.hadoop + hadoop-client + provided + + + org.apache.avro + avro-mapred + ${avro.mapred.classifier} + provided + + + org.apache.curator + curator-recipes + provided + + + org.apache.zookeeper + zookeeper 
+ provided + + + log4j + log4j + provided + + + net.java.dev.jets3t + jets3t + provided + + + org.scala-lang + scala-library + provided + + + org.slf4j + slf4j-api + provided + + + org.slf4j + slf4j-log4j12 + provided + + + org.xerial.snappy + snappy-java + provided + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + org.apache.maven.plugins + maven-shade-plugin + + false + + + *:* + + + + + *:* + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + + + + + package + + shade + + + + + + reference.conf + + + log4j.properties + + + + + + + + + + + + diff --git a/external/kafka-0-10-sql/pom.xml b/external/kafka-0-10-sql/pom.xml new file mode 100644 index 0000000000000..3278ee35ec23b --- /dev/null +++ b/external/kafka-0-10-sql/pom.xml @@ -0,0 +1,96 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.11 + 2.0.3-SNAPSHOT + ../../pom.xml + + + org.apache.spark + spark-sql-kafka-0-10_2.11 + + sql-kafka-0-10 + + jar + Kafka 0.10 Source for Structured Streaming + http://spark.apache.org/ + + + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + provided + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.spark + spark-catalyst_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.kafka + kafka-clients + 0.10.0.1 + + + org.apache.kafka + kafka_${scala.binary.version} + 0.10.0.1 + test + + + net.sf.jopt-simple + jopt-simple + 3.2 + test + + + org.scalacheck + scalacheck_${scala.binary.version} + test + + + org.apache.spark + spark-tags_${scala.binary.version} + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + diff --git a/external/kafka-0-10-sql/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/external/kafka-0-10-sql/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister new file mode 100644 index 0000000000000..2f9e9fc0396d5 --- /dev/null +++ b/external/kafka-0-10-sql/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -0,0 +1 @@ +org.apache.spark.sql.kafka010.KafkaSourceProvider diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala new file mode 100644 index 0000000000000..3b5a96534f9b6 --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.kafka010 + +import java.{util => ju} + +import org.apache.kafka.clients.consumer.{ConsumerConfig, ConsumerRecord, KafkaConsumer} +import org.apache.kafka.common.TopicPartition + +import org.apache.spark.{SparkEnv, SparkException, TaskContext} +import org.apache.spark.internal.Logging + + +/** + * Consumer of single topicpartition, intended for cached reuse. + * Underlying consumer is not threadsafe, so neither is this, + * but processing the same topicpartition and group id in multiple threads is usually bad anyway. + */ +private[kafka010] case class CachedKafkaConsumer private( + topicPartition: TopicPartition, + kafkaParams: ju.Map[String, Object]) extends Logging { + + private val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] + + private val consumer = { + val c = new KafkaConsumer[Array[Byte], Array[Byte]](kafkaParams) + val tps = new ju.ArrayList[TopicPartition]() + tps.add(topicPartition) + c.assign(tps) + c + } + + /** Iterator to the already fetch data */ + private var fetchedData = ju.Collections.emptyIterator[ConsumerRecord[Array[Byte], Array[Byte]]] + private var nextOffsetInFetchedData = -2L + + /** + * Get the record for the given offset, waiting up to timeout ms if IO is necessary. + * Sequential forward access will use buffers, but random access will be horribly inefficient. + */ + def get(offset: Long, pollTimeoutMs: Long): ConsumerRecord[Array[Byte], Array[Byte]] = { + logDebug(s"Get $groupId $topicPartition nextOffset $nextOffsetInFetchedData requested $offset") + if (offset != nextOffsetInFetchedData) { + logInfo(s"Initial fetch for $topicPartition $offset") + seek(offset) + poll(pollTimeoutMs) + } + + if (!fetchedData.hasNext()) { poll(pollTimeoutMs) } + assert(fetchedData.hasNext(), + s"Failed to get records for $groupId $topicPartition $offset " + + s"after polling for $pollTimeoutMs") + var record = fetchedData.next() + + if (record.offset != offset) { + logInfo(s"Buffer miss for $groupId $topicPartition $offset") + seek(offset) + poll(pollTimeoutMs) + assert(fetchedData.hasNext(), + s"Failed to get records for $groupId $topicPartition $offset " + + s"after polling for $pollTimeoutMs") + record = fetchedData.next() + assert(record.offset == offset, + s"Got wrong record for $groupId $topicPartition even after seeking to offset $offset") + } + + nextOffsetInFetchedData = offset + 1 + record + } + + private def close(): Unit = consumer.close() + + private def seek(offset: Long): Unit = { + logDebug(s"Seeking to $groupId $topicPartition $offset") + consumer.seek(topicPartition, offset) + } + + private def poll(pollTimeoutMs: Long): Unit = { + val p = consumer.poll(pollTimeoutMs) + val r = p.records(topicPartition) + logDebug(s"Polled $groupId ${p.partitions()} ${r.size}") + fetchedData = r.iterator + } +} + +private[kafka010] object CachedKafkaConsumer extends Logging { + + private case class CacheKey(groupId: String, topicPartition: TopicPartition) + + private lazy val cache = { + val conf = SparkEnv.get.conf + val capacity = conf.getInt("spark.sql.kafkaConsumerCache.capacity", 64) + new ju.LinkedHashMap[CacheKey, CachedKafkaConsumer](capacity, 0.75f, true) { + override def removeEldestEntry( + entry: ju.Map.Entry[CacheKey, CachedKafkaConsumer]): Boolean = { + if (this.size > capacity) { + logWarning(s"KafkaConsumer cache hitting max capacity of $capacity, " + + s"removing consumer for ${entry.getKey}") + try { + entry.getValue.close() + } catch { + case e: SparkException => + logError(s"Error closing 
earliest Kafka consumer for ${entry.getKey}", e) + } + true + } else { + false + } + } + } + } + + /** + * Get a cached consumer for groupId, assigned to topic and partition. + * If matching consumer doesn't already exist, will be created using kafkaParams. + */ + def getOrCreate( + topic: String, + partition: Int, + kafkaParams: ju.Map[String, Object]): CachedKafkaConsumer = synchronized { + val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] + val topicPartition = new TopicPartition(topic, partition) + val key = CacheKey(groupId, topicPartition) + + // If this is reattempt at running the task, then invalidate cache and start with + // a new consumer + if (TaskContext.get != null && TaskContext.get.attemptNumber > 1) { + cache.remove(key) + new CachedKafkaConsumer(topicPartition, kafkaParams) + } else { + if (!cache.containsKey(key)) { + cache.put(key, new CachedKafkaConsumer(topicPartition, kafkaParams)) + } + cache.get(key) + } + } +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/JsonUtils.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/JsonUtils.scala new file mode 100644 index 0000000000000..40d568a12c25d --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/JsonUtils.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.io.Writer + +import scala.collection.mutable.HashMap +import scala.util.control.NonFatal + +import org.apache.kafka.common.TopicPartition +import org.json4s.NoTypeHints +import org.json4s.jackson.Serialization + +/** + * Utilities for converting Kafka related objects to and from json. + */ +private object JsonUtils { + private implicit val formats = Serialization.formats(NoTypeHints) + + /** + * Read TopicPartitions from json string + */ + def partitions(str: String): Array[TopicPartition] = { + try { + Serialization.read[Map[String, Seq[Int]]](str).flatMap { case (topic, parts) => + parts.map { part => + new TopicPartition(topic, part) + } + }.toArray + } catch { + case NonFatal(x) => + throw new IllegalArgumentException( + s"""Expected e.g. 
{"topicA":[0,1],"topicB":[0,1]}, got $str""") + } + } + + /** + * Write TopicPartitions as json string + */ + def partitions(partitions: Iterable[TopicPartition]): String = { + val result = new HashMap[String, List[Int]] + partitions.foreach { tp => + val parts: List[Int] = result.getOrElse(tp.topic, Nil) + result += tp.topic -> (tp.partition::parts) + } + Serialization.write(result) + } + + /** + * Read per-TopicPartition offsets from json string + */ + def partitionOffsets(str: String): Map[TopicPartition, Long] = { + try { + Serialization.read[Map[String, Map[Int, Long]]](str).flatMap { case (topic, partOffsets) => + partOffsets.map { case (part, offset) => + new TopicPartition(topic, part) -> offset + } + }.toMap + } catch { + case NonFatal(x) => + throw new IllegalArgumentException( + s"""Expected e.g. {"topicA":{"0":23,"1":-1},"topicB":{"0":-2}}, got $str""") + } + } + + /** + * Write per-TopicPartition offsets as json string + */ + def partitionOffsets(partitionOffsets: Map[TopicPartition, Long]): String = { + val result = new HashMap[String, HashMap[Int, Long]]() + partitionOffsets.foreach { case (tp, off) => + val parts = result.getOrElse(tp.topic, new HashMap[Int, Long]) + parts += tp.partition -> off + result += tp.topic -> parts + } + Serialization.write(result) + } +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala new file mode 100644 index 0000000000000..61cba737d148a --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala @@ -0,0 +1,521 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.{util => ju} + +import scala.collection.JavaConverters._ +import scala.util.control.NonFatal + +import org.apache.kafka.clients.consumer.{Consumer, KafkaConsumer, OffsetOutOfRangeException} +import org.apache.kafka.clients.consumer.internals.NoOpConsumerRebalanceListener +import org.apache.kafka.common.TopicPartition + +import org.apache.spark.SparkContext +import org.apache.spark.internal.Logging +import org.apache.spark.scheduler.ExecutorCacheTaskLocation +import org.apache.spark.sql._ +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.kafka010.KafkaSource._ +import org.apache.spark.sql.types._ +import org.apache.spark.util.UninterruptibleThread + +/** + * A [[Source]] that uses Kafka's own [[KafkaConsumer]] API to reads data from Kafka. The design + * for this source is as follows. + * + * - The [[KafkaSourceOffset]] is the custom [[Offset]] defined for this source that contains + * a map of TopicPartition -> offset. Note that this offset is 1 + (available offset). 
For + * example if the last record in a Kafka topic "t", partition 2 is offset 5, then + * KafkaSourceOffset will contain TopicPartition("t", 2) -> 6. This is done to keep it consistent + * with the semantics of `KafkaConsumer.position()`. + * + * - The [[ConsumerStrategy]] class defines which Kafka topics and partitions should be read + * by this source. These strategies directly correspond to the different consumption options + * in Kafka. This class is designed to return a configured [[KafkaConsumer]] that is used by the + * [[KafkaSource]] to query for the offsets. See the docs on + * [[org.apache.spark.sql.kafka010.KafkaSource.ConsumerStrategy]] for more details. + * + * - The [[KafkaSource]] is written to do the following. + * + * - As soon as the source is created, the pre-configured KafkaConsumer returned by the + * [[ConsumerStrategy]] is used to query the initial offsets that this source should + * start reading from. This is used to create the first batch. + * + * - `getOffset()` uses the KafkaConsumer to query the latest available offsets, which are + * returned as a [[KafkaSourceOffset]]. + * + * - `getBatch()` returns a DF that reads from the 'start offset' until the 'end offset' + * for each partition. The end offset is excluded to be consistent with the semantics of + * [[KafkaSourceOffset]] and `KafkaConsumer.position()`. + * + * - The DF returned is based on [[KafkaSourceRDD]] which is constructed such that the + * data from Kafka topic + partition is consistently read by the same executors across + * batches, and cached KafkaConsumers in the executors can be reused efficiently. See the + * docs on [[KafkaSourceRDD]] for more details. + * + * Zero data loss is not guaranteed when topics are deleted. If zero data loss is critical, the user + * must make sure all messages in a topic have been processed when deleting a topic. + * + * There is a known issue caused by KAFKA-1894: a query using KafkaSource may not be able to stop. + * To avoid this issue, make sure to stop the query before stopping the Kafka brokers, + * and do not use incorrect broker addresses. + */ +private[kafka010] case class KafkaSource( + sqlContext: SQLContext, + consumerStrategy: ConsumerStrategy, + executorKafkaParams: ju.Map[String, Object], + sourceOptions: Map[String, String], + metadataPath: String, + startingOffsets: StartingOffsets, + failOnDataLoss: Boolean) + extends Source with Logging { + + private val sc = sqlContext.sparkContext + + private val pollTimeoutMs = sourceOptions.getOrElse("kafkaConsumer.pollTimeoutMs", "512").toLong + + private val maxOffsetFetchAttempts = + sourceOptions.getOrElse("fetchOffset.numRetries", "3").toInt + + private val offsetFetchAttemptIntervalMs = + sourceOptions.getOrElse("fetchOffset.retryIntervalMs", "10").toLong + + private val maxOffsetsPerTrigger = + sourceOptions.get("maxOffsetsPerTrigger").map(_.toLong) + + /** + * A KafkaConsumer used in the driver to query the latest Kafka offsets. This only queries the + * offsets and never commits them. + */ + private val consumer = consumerStrategy.createConsumer() + + /** + * Lazily initialize `initialPartitionOffsets` to make sure that `KafkaConsumer.poll` is only + * called in StreamExecutionThread. Otherwise, interrupting a thread while running + * `KafkaConsumer.poll` may hang forever (KAFKA-1894).
+ */ + private lazy val initialPartitionOffsets = { + val metadataLog = new HDFSMetadataLog[KafkaSourceOffset](sqlContext.sparkSession, metadataPath) + metadataLog.get(0).getOrElse { + val offsets = startingOffsets match { + case EarliestOffsets => KafkaSourceOffset(fetchEarliestOffsets()) + case LatestOffsets => KafkaSourceOffset(fetchLatestOffsets()) + case SpecificOffsets(p) => KafkaSourceOffset(fetchSpecificStartingOffsets(p)) + } + metadataLog.add(0, offsets) + logInfo(s"Initial offsets: $offsets") + offsets + }.partitionToOffsets + } + + private var currentPartitionOffsets: Option[Map[TopicPartition, Long]] = None + + override def schema: StructType = KafkaSource.kafkaSchema + + /** Returns the maximum available offset for this source. */ + override def getOffset: Option[Offset] = { + // Make sure initialPartitionOffsets is initialized + initialPartitionOffsets + + val latest = fetchLatestOffsets() + val offsets = maxOffsetsPerTrigger match { + case None => + latest + case Some(limit) if currentPartitionOffsets.isEmpty => + rateLimit(limit, initialPartitionOffsets, latest) + case Some(limit) => + rateLimit(limit, currentPartitionOffsets.get, latest) + } + + currentPartitionOffsets = Some(offsets) + logDebug(s"GetOffset: ${offsets.toSeq.map(_.toString).sorted}") + Some(KafkaSourceOffset(offsets)) + } + + /** Proportionally distribute limit number of offsets among topicpartitions */ + private def rateLimit( + limit: Long, + from: Map[TopicPartition, Long], + until: Map[TopicPartition, Long]): Map[TopicPartition, Long] = { + val fromNew = fetchNewPartitionEarliestOffsets(until.keySet.diff(from.keySet).toSeq) + val sizes = until.flatMap { + case (tp, end) => + // If begin isn't defined, something's wrong, but let alert logic in getBatch handle it + from.get(tp).orElse(fromNew.get(tp)).flatMap { begin => + val size = end - begin + logDebug(s"rateLimit $tp size is $size") + if (size > 0) Some(tp -> size) else None + } + } + val total = sizes.values.sum.toDouble + if (total < 1) { + until + } else { + until.map { + case (tp, end) => + tp -> sizes.get(tp).map { size => + val begin = from.get(tp).getOrElse(fromNew(tp)) + val prorate = limit * (size / total) + logDebug(s"rateLimit $tp prorated amount is $prorate") + // Don't completely starve small topicpartitions + val off = begin + (if (prorate < 1) Math.ceil(prorate) else Math.floor(prorate)).toLong + logDebug(s"rateLimit $tp new offset is $off") + // Paranoia, make sure not to return an offset that's past end + Math.min(end, off) + }.getOrElse(end) + } + } + } + + /** + * Returns the data that is between the offsets + * [`start.get.partitionToOffsets`, `end.partitionToOffsets`), i.e. end.partitionToOffsets is + * exclusive. + */ + override def getBatch(start: Option[Offset], end: Offset): DataFrame = { + // Make sure initialPartitionOffsets is initialized + initialPartitionOffsets + + logInfo(s"GetBatch called with start = $start, end = $end") + val untilPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(end) + val fromPartitionOffsets = start match { + case Some(prevBatchEndOffset) => + KafkaSourceOffset.getPartitionOffsets(prevBatchEndOffset) + case None => + initialPartitionOffsets + } + + // Find the new partitions, and get their earliest offsets + val newPartitions = untilPartitionOffsets.keySet.diff(fromPartitionOffsets.keySet) + val newPartitionOffsets = fetchNewPartitionEarliestOffsets(newPartitions.toSeq) + if (newPartitionOffsets.keySet != newPartitions) { + // We cannot get from offsets for some partitions. 
It means they got deleted. + val deletedPartitions = newPartitions.diff(newPartitionOffsets.keySet) + reportDataLoss( + s"Cannot find earliest offsets of ${deletedPartitions}. Some data may have been missed") + } + logInfo(s"Partitions added: $newPartitionOffsets") + newPartitionOffsets.filter(_._2 != 0).foreach { case (p, o) => + reportDataLoss( + s"Added partition $p starts from $o instead of 0. Some data may have been missed") + } + + val deletedPartitions = fromPartitionOffsets.keySet.diff(untilPartitionOffsets.keySet) + if (deletedPartitions.nonEmpty) { + reportDataLoss(s"$deletedPartitions are gone. Some data may have been missed") + } + + // Use the until partitions to calculate offset ranges to ignore partitions that have + // been deleted + val topicPartitions = untilPartitionOffsets.keySet.filter { tp => + // Ignore partitions that we don't know the from offsets. + newPartitionOffsets.contains(tp) || fromPartitionOffsets.contains(tp) + }.toSeq + logDebug("TopicPartitions: " + topicPartitions.mkString(", ")) + + val sortedExecutors = getSortedExecutorList(sc) + val numExecutors = sortedExecutors.length + logDebug("Sorted executors: " + sortedExecutors.mkString(", ")) + + // Calculate offset ranges + val offsetRanges = topicPartitions.map { tp => + val fromOffset = fromPartitionOffsets.get(tp).getOrElse { + newPartitionOffsets.getOrElse(tp, { + // This should not happen since newPartitionOffsets contains all partitions not in + // fromPartitionOffsets + throw new IllegalStateException(s"$tp doesn't have a from offset") + }) + } + val untilOffset = untilPartitionOffsets(tp) + val preferredLoc = if (numExecutors > 0) { + // This allows cached KafkaConsumers in the executors to be re-used to read the same + // partition in every batch. + Some(sortedExecutors(floorMod(tp.hashCode, numExecutors))) + } else None + KafkaSourceRDDOffsetRange(tp, fromOffset, untilOffset, preferredLoc) + }.filter { range => + if (range.untilOffset < range.fromOffset) { + reportDataLoss(s"Partition ${range.topicPartition}'s offset was changed from " + + s"${range.fromOffset} to ${range.untilOffset}, some data may have been missed") + false + } else { + true + } + }.toArray + + // Create a RDD that reads from Kafka and get the (key, value) pair as byte arrays. + val rdd = new KafkaSourceRDD( + sc, executorKafkaParams, offsetRanges, pollTimeoutMs).map { cr => + Row(cr.key, cr.value, cr.topic, cr.partition, cr.offset, cr.timestamp, cr.timestampType.id) + } + + logInfo("GetBatch generating RDD of offset range: " + + offsetRanges.sortBy(_.topicPartition.toString).mkString(", ")) + + // On recovery, getBatch will get called before getOffset + if (currentPartitionOffsets.isEmpty) { + currentPartitionOffsets = Some(untilPartitionOffsets) + } + + sqlContext.createDataFrame(rdd, schema) + } + + /** Stop this source and free any resources it has allocated. */ + override def stop(): Unit = synchronized { + consumer.close() + } + + override def toString(): String = s"KafkaSource[$consumerStrategy]" + + /** + * Set consumer position to specified offsets, making sure all assignments are set. 
+ */ + private def fetchSpecificStartingOffsets( + partitionOffsets: Map[TopicPartition, Long]): Map[TopicPartition, Long] = { + val result = withRetriesWithoutInterrupt { + // Poll to get the latest assigned partitions + consumer.poll(0) + val partitions = consumer.assignment() + consumer.pause(partitions) + assert(partitions.asScala == partitionOffsets.keySet, + "If startingOffsets contains specific offsets, you must specify all TopicPartitions.\n" + + "Use -1 for latest, -2 for earliest, if you don't care.\n" + + s"Specified: ${partitionOffsets.keySet} Assigned: ${partitions.asScala}") + logDebug(s"Partitions assigned to consumer: $partitions. Seeking to $partitionOffsets") + + partitionOffsets.foreach { + case (tp, -1) => consumer.seekToEnd(ju.Arrays.asList(tp)) + case (tp, -2) => consumer.seekToBeginning(ju.Arrays.asList(tp)) + case (tp, off) => consumer.seek(tp, off) + } + partitionOffsets.map { + case (tp, _) => tp -> consumer.position(tp) + } + } + partitionOffsets.foreach { + case (tp, off) if off != -1 && off != -2 => + if (result(tp) != off) { + reportDataLoss( + s"startingOffsets for $tp was $off but consumer reset to ${result(tp)}") + } + case _ => + // no real way to check that beginning or end is reasonable + } + result + } + + /** + * Fetch the earliest offsets of partitions. + */ + private def fetchEarliestOffsets(): Map[TopicPartition, Long] = withRetriesWithoutInterrupt { + // Poll to get the latest assigned partitions + consumer.poll(0) + val partitions = consumer.assignment() + consumer.pause(partitions) + logDebug(s"Partitions assigned to consumer: $partitions. Seeking to the beginning") + + consumer.seekToBeginning(partitions) + val partitionOffsets = partitions.asScala.map(p => p -> consumer.position(p)).toMap + logDebug(s"Got earliest offsets for partition : $partitionOffsets") + partitionOffsets + } + + /** + * Fetch the latest offset of partitions. + */ + private def fetchLatestOffsets(): Map[TopicPartition, Long] = withRetriesWithoutInterrupt { + // Poll to get the latest assigned partitions + consumer.poll(0) + val partitions = consumer.assignment() + consumer.pause(partitions) + logDebug(s"Partitions assigned to consumer: $partitions. Seeking to the end.") + + consumer.seekToEnd(partitions) + val partitionOffsets = partitions.asScala.map(p => p -> consumer.position(p)).toMap + logDebug(s"Got latest offsets for partition : $partitionOffsets") + partitionOffsets + } + + /** + * Fetch the earliest offsets for newly discovered partitions. The return result may not contain + * some partitions if they are deleted. + */ + private def fetchNewPartitionEarliestOffsets( + newPartitions: Seq[TopicPartition]): Map[TopicPartition, Long] = + if (newPartitions.isEmpty) { + Map.empty[TopicPartition, Long] + } else { + withRetriesWithoutInterrupt { + // Poll to get the latest assigned partitions + consumer.poll(0) + val partitions = consumer.assignment() + consumer.pause(partitions) + logDebug(s"\tPartitions assigned to consumer: $partitions") + + // Get the earliest offset of each partition + consumer.seekToBeginning(partitions) + val partitionOffsets = newPartitions.filter { p => + // When deleting topics happen at the same time, some partitions may not be in + // `partitions`. So we need to ignore them + partitions.contains(p) + }.map(p => p -> consumer.position(p)).toMap + logDebug(s"Got earliest offsets for new partitions: $partitionOffsets") + partitionOffsets + } + } + + /** + * Helper function that does multiple retries on the a body of code that returns offsets. 
+ * Retries are needed to handle transient failures. For e.g. race conditions between getting + * assignment and getting position while topics/partitions are deleted can cause NPEs. + * + * This method also makes sure `body` won't be interrupted to workaround a potential issue in + * `KafkaConsumer.poll`. (KAFKA-1894) + */ + private def withRetriesWithoutInterrupt( + body: => Map[TopicPartition, Long]): Map[TopicPartition, Long] = { + // Make sure `KafkaConsumer.poll` won't be interrupted (KAFKA-1894) + assert(Thread.currentThread().isInstanceOf[StreamExecutionThread]) + + synchronized { + var result: Option[Map[TopicPartition, Long]] = None + var attempt = 1 + var lastException: Throwable = null + while (result.isEmpty && attempt <= maxOffsetFetchAttempts + && !Thread.currentThread().isInterrupted) { + Thread.currentThread match { + case ut: UninterruptibleThread => + // "KafkaConsumer.poll" may hang forever if the thread is interrupted (E.g., the query + // is stopped)(KAFKA-1894). Hence, we just make sure we don't interrupt it. + // + // If the broker addresses are wrong, or Kafka cluster is down, "KafkaConsumer.poll" may + // hang forever as well. This cannot be resolved in KafkaSource until Kafka fixes the + // issue. + ut.runUninterruptibly { + try { + result = Some(body) + } catch { + case x: OffsetOutOfRangeException => + reportDataLoss(x.getMessage) + case NonFatal(e) => + lastException = e + logWarning(s"Error in attempt $attempt getting Kafka offsets: ", e) + attempt += 1 + Thread.sleep(offsetFetchAttemptIntervalMs) + } + } + case _ => + throw new IllegalStateException( + "Kafka APIs must be executed on a o.a.spark.util.UninterruptibleThread") + } + } + if (Thread.interrupted()) { + throw new InterruptedException() + } + if (result.isEmpty) { + assert(attempt > maxOffsetFetchAttempts) + assert(lastException != null) + throw lastException + } + result.get + } + } + + /** + * If `failOnDataLoss` is true, this method will throw an `IllegalStateException`. + * Otherwise, just log a warning. + */ + private def reportDataLoss(message: String): Unit = { + if (failOnDataLoss) { + throw new IllegalStateException(message + + ". Set the source option 'failOnDataLoss' to 'false' if you want to ignore these checks.") + } else { + logWarning(message) + } + } +} + + +/** Companion object for the [[KafkaSource]]. 
*/ +private[kafka010] object KafkaSource { + + def kafkaSchema: StructType = StructType(Seq( + StructField("key", BinaryType), + StructField("value", BinaryType), + StructField("topic", StringType), + StructField("partition", IntegerType), + StructField("offset", LongType), + StructField("timestamp", LongType), + StructField("timestampType", IntegerType) + )) + + sealed trait ConsumerStrategy { + def createConsumer(): Consumer[Array[Byte], Array[Byte]] + } + + case class AssignStrategy(partitions: Array[TopicPartition], kafkaParams: ju.Map[String, Object]) + extends ConsumerStrategy { + override def createConsumer(): Consumer[Array[Byte], Array[Byte]] = { + val consumer = new KafkaConsumer[Array[Byte], Array[Byte]](kafkaParams) + consumer.assign(ju.Arrays.asList(partitions: _*)) + consumer + } + + override def toString: String = s"Assign[${partitions.mkString(", ")}]" + } + + case class SubscribeStrategy(topics: Seq[String], kafkaParams: ju.Map[String, Object]) + extends ConsumerStrategy { + override def createConsumer(): Consumer[Array[Byte], Array[Byte]] = { + val consumer = new KafkaConsumer[Array[Byte], Array[Byte]](kafkaParams) + consumer.subscribe(topics.asJava) + consumer + } + + override def toString: String = s"Subscribe[${topics.mkString(", ")}]" + } + + case class SubscribePatternStrategy( + topicPattern: String, kafkaParams: ju.Map[String, Object]) + extends ConsumerStrategy { + override def createConsumer(): Consumer[Array[Byte], Array[Byte]] = { + val consumer = new KafkaConsumer[Array[Byte], Array[Byte]](kafkaParams) + consumer.subscribe( + ju.regex.Pattern.compile(topicPattern), + new NoOpConsumerRebalanceListener()) + consumer + } + + override def toString: String = s"SubscribePattern[$topicPattern]" + } + + private def getSortedExecutorList(sc: SparkContext): Array[String] = { + val bm = sc.env.blockManager + bm.master.getPeers(bm.blockManagerId).toArray + .map(x => ExecutorCacheTaskLocation(x.host, x.executorId)) + .sortWith(compare) + .map(_.toString) + } + + private def compare(a: ExecutorCacheTaskLocation, b: ExecutorCacheTaskLocation): Boolean = { + if (a.host == b.host) { a.executorId > b.executorId } else { a.host > b.host } + } + + private def floorMod(a: Long, b: Int): Int = ((a % b).toInt + b) % b +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala new file mode 100644 index 0000000000000..b5ade982515f0 --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.kafka010 + +import org.apache.kafka.common.TopicPartition + +import org.apache.spark.sql.execution.streaming.Offset + +/** + * An [[Offset]] for the [[KafkaSource]]. This one tracks all partitions of subscribed topics and + * their offsets. + */ +private[kafka010] +case class KafkaSourceOffset(partitionToOffsets: Map[TopicPartition, Long]) extends Offset { + override def toString(): String = { + partitionToOffsets.toSeq.sortBy(_._1.toString).mkString("[", ", ", "]") + } +} + +/** Companion object of the [[KafkaSourceOffset]] */ +private[kafka010] object KafkaSourceOffset { + + def getPartitionOffsets(offset: Offset): Map[TopicPartition, Long] = { + offset match { + case o: KafkaSourceOffset => o.partitionToOffsets + case _ => + throw new IllegalArgumentException( + s"Invalid conversion from offset of ${offset.getClass} to KafkaSourceOffset") + } + } + + /** + * Returns [[KafkaSourceOffset]] from a variable sequence of (topic, partitionId, offset) + * tuples. + */ + def apply(offsetTuples: (String, Int, Long)*): KafkaSourceOffset = { + KafkaSourceOffset(offsetTuples.map { case(t, p, o) => (new TopicPartition(t, p), o) }.toMap) + } +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala new file mode 100644 index 0000000000000..585ced875caa7 --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -0,0 +1,289 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.{util => ju} +import java.util.UUID + +import scala.collection.JavaConverters._ + +import org.apache.kafka.clients.consumer.ConsumerConfig +import org.apache.kafka.common.serialization.ByteArrayDeserializer + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.execution.streaming.Source +import org.apache.spark.sql.kafka010.KafkaSource._ +import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider} +import org.apache.spark.sql.types.StructType + +/** + * The provider class for the [[KafkaSource]]. This provider is designed such that it throws + * IllegalArgumentException when the Kafka Dataset is created, so that it can catch + * missing options even before the query is started. + */ +private[kafka010] class KafkaSourceProvider extends StreamSourceProvider + with DataSourceRegister with Logging { + + import KafkaSourceProvider._ + + /** + * Returns the name and schema of the source. 
In addition, it also verifies whether the options + * are correct and sufficient to create the [[KafkaSource]] when the query is started. + */ + override def sourceSchema( + sqlContext: SQLContext, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): (String, StructType) = { + require(schema.isEmpty, "Kafka source has a fixed schema and cannot be set with a custom one") + validateOptions(parameters) + ("kafka", KafkaSource.kafkaSchema) + } + + override def createSource( + sqlContext: SQLContext, + metadataPath: String, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): Source = { + validateOptions(parameters) + val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase, v) } + val specifiedKafkaParams = + parameters + .keySet + .filter(_.toLowerCase.startsWith("kafka.")) + .map { k => k.drop(6).toString -> parameters(k) } + .toMap + + val deserClassName = classOf[ByteArrayDeserializer].getName + // Each running query should use its own group id. Otherwise, the query may be only assigned + // partial data since Kafka will assign partitions to multiple consumers having the same group + // id. Hence, we should generate a unique id for each query. + val uniqueGroupId = s"spark-kafka-source-${UUID.randomUUID}-${metadataPath.hashCode}" + + val startingOffsets = + caseInsensitiveParams.get(STARTING_OFFSETS_OPTION_KEY).map(_.trim.toLowerCase) match { + case Some("latest") => LatestOffsets + case Some("earliest") => EarliestOffsets + case Some(json) => SpecificOffsets(JsonUtils.partitionOffsets(json)) + case None => LatestOffsets + } + + val kafkaParamsForStrategy = + ConfigUpdater("source", specifiedKafkaParams) + .set(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, deserClassName) + .set(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, deserClassName) + + // So that consumers in Kafka source do not mess with any existing group id + .set(ConsumerConfig.GROUP_ID_CONFIG, s"$uniqueGroupId-driver") + + // Set to "earliest" to avoid exceptions. However, KafkaSource will fetch the initial + // offsets by itself instead of counting on KafkaConsumer. + .set(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest") + + // So that consumers in the driver does not commit offsets unnecessarily + .set(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false") + + // So that the driver does not pull too much data + .set(ConsumerConfig.MAX_POLL_RECORDS_CONFIG, new java.lang.Integer(1)) + + // If buffer config is not set, set it to reasonable value to work around + // buffer issues (see KAFKA-3135) + .setIfUnset(ConsumerConfig.RECEIVE_BUFFER_CONFIG, 65536: java.lang.Integer) + .build() + + val kafkaParamsForExecutors = + ConfigUpdater("executor", specifiedKafkaParams) + .set(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, deserClassName) + .set(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, deserClassName) + + // Make sure executors do only what the driver tells them. 
+ .set(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "none") + + // So that consumers in executors do not mess with any existing group id + .set(ConsumerConfig.GROUP_ID_CONFIG, s"$uniqueGroupId-executor") + + // So that consumers in executors does not commit offsets unnecessarily + .set(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false") + + // If buffer config is not set, set it to reasonable value to work around + // buffer issues (see KAFKA-3135) + .setIfUnset(ConsumerConfig.RECEIVE_BUFFER_CONFIG, 65536: java.lang.Integer) + .build() + + val strategy = caseInsensitiveParams.find(x => STRATEGY_OPTION_KEYS.contains(x._1)).get match { + case ("assign", value) => + AssignStrategy( + JsonUtils.partitions(value), + kafkaParamsForStrategy) + case ("subscribe", value) => + SubscribeStrategy( + value.split(",").map(_.trim()).filter(_.nonEmpty), + kafkaParamsForStrategy) + case ("subscribepattern", value) => + SubscribePatternStrategy( + value.trim(), + kafkaParamsForStrategy) + case _ => + // Should never reach here as we are already matching on + // matched strategy names + throw new IllegalArgumentException("Unknown option") + } + + val failOnDataLoss = + caseInsensitiveParams.getOrElse(FAIL_ON_DATA_LOSS_OPTION_KEY, "true").toBoolean + + new KafkaSource( + sqlContext, + strategy, + kafkaParamsForExecutors, + parameters, + metadataPath, + startingOffsets, + failOnDataLoss) + } + + private def validateOptions(parameters: Map[String, String]): Unit = { + + // Validate source options + + val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase, v) } + val specifiedStrategies = + caseInsensitiveParams.filter { case (k, _) => STRATEGY_OPTION_KEYS.contains(k) }.toSeq + if (specifiedStrategies.isEmpty) { + throw new IllegalArgumentException( + "One of the following options must be specified for Kafka source: " + + STRATEGY_OPTION_KEYS.mkString(", ") + ". See the docs for more details.") + } else if (specifiedStrategies.size > 1) { + throw new IllegalArgumentException( + "Only one of the following options can be specified for Kafka source: " + + STRATEGY_OPTION_KEYS.mkString(", ") + ". 
See the docs for more details.") + } + + val strategy = caseInsensitiveParams.find(x => STRATEGY_OPTION_KEYS.contains(x._1)).get match { + case ("assign", value) => + if (!value.trim.startsWith("{")) { + throw new IllegalArgumentException( + "No topicpartitions to assign as specified value for option " + + s"'assign' is '$value'") + } + + case ("subscribe", value) => + val topics = value.split(",").map(_.trim).filter(_.nonEmpty) + if (topics.isEmpty) { + throw new IllegalArgumentException( + "No topics to subscribe to as specified value for option " + + s"'subscribe' is '$value'") + } + case ("subscribepattern", value) => + val pattern = caseInsensitiveParams("subscribepattern").trim() + if (pattern.isEmpty) { + throw new IllegalArgumentException( + "Pattern to subscribe is empty as specified value for option " + + s"'subscribePattern' is '$value'") + } + case _ => + // Should never reach here as we are already matching on + // matched strategy names + throw new IllegalArgumentException("Unknown option") + } + + // Validate user-specified Kafka options + + if (caseInsensitiveParams.contains(s"kafka.${ConsumerConfig.GROUP_ID_CONFIG}")) { + throw new IllegalArgumentException( + s"Kafka option '${ConsumerConfig.GROUP_ID_CONFIG}' is not supported as " + + s"user-specified consumer groups are not used to track offsets.") + } + + if (caseInsensitiveParams.contains(s"kafka.${ConsumerConfig.AUTO_OFFSET_RESET_CONFIG}")) { + throw new IllegalArgumentException( + s""" + |Kafka option '${ConsumerConfig.AUTO_OFFSET_RESET_CONFIG}' is not supported. + |Instead set the source option '$STARTING_OFFSETS_OPTION_KEY' to 'earliest' or 'latest' + |to specify where to start. Structured Streaming manages which offsets are consumed + |internally, rather than relying on the kafkaConsumer to do it. This will ensure that no + |data is missed when new topics/partitions are dynamically subscribed. Note that + |'$STARTING_OFFSETS_OPTION_KEY' only applies when a new Streaming query is started, and + |that resuming will always pick up from where the query left off. See the docs for more + |details. + """.stripMargin) + } + + if (caseInsensitiveParams.contains(s"kafka.${ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG}")) { + throw new IllegalArgumentException( + s"Kafka option '${ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG}' is not supported as keys " + + "are deserialized as byte arrays with ByteArrayDeserializer. Use DataFrame operations " + + "to explicitly deserialize the keys.") + } + + if (caseInsensitiveParams.contains(s"kafka.${ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG}")) + { + throw new IllegalArgumentException( + s"Kafka option '${ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG}' is not supported as " + + "values are deserialized as byte arrays with ByteArrayDeserializer. 
Use DataFrame " + + "operations to explicitly deserialize the values.") + } + + val otherUnsupportedConfigs = Seq( + ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, // committing correctly requires new APIs in Source + ConsumerConfig.INTERCEPTOR_CLASSES_CONFIG) // interceptors can modify payload, so not safe + + otherUnsupportedConfigs.foreach { c => + if (caseInsensitiveParams.contains(s"kafka.$c")) { + throw new IllegalArgumentException(s"Kafka option '$c' is not supported") + } + } + + if (!caseInsensitiveParams.contains(s"kafka.${ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG}")) { + throw new IllegalArgumentException( + s"Option 'kafka.${ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG}' must be specified for " + + s"configuring Kafka consumer") + } + } + + override def shortName(): String = "kafka" + + /** Class to conveniently update Kafka config params, while logging the changes */ + private case class ConfigUpdater(module: String, kafkaParams: Map[String, String]) { + private val map = new ju.HashMap[String, Object](kafkaParams.asJava) + + def set(key: String, value: Object): this.type = { + map.put(key, value) + logInfo(s"$module: Set $key to $value, earlier value: ${kafkaParams.get(key).getOrElse("")}") + this + } + + def setIfUnset(key: String, value: Object): ConfigUpdater = { + if (!map.containsKey(key)) { + map.put(key, value) + logInfo(s"$module: Set $key to $value") + } + this + } + + def build(): ju.Map[String, Object] = map + } +} + +private[kafka010] object KafkaSourceProvider { + private val STRATEGY_OPTION_KEYS = Set("subscribe", "subscribepattern", "assign") + private val STARTING_OFFSETS_OPTION_KEY = "startingoffsets" + private val FAIL_ON_DATA_LOSS_OPTION_KEY = "failondataloss" +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala new file mode 100644 index 0000000000000..802dd040aed93 --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.kafka010 + +import java.{util => ju} + +import scala.collection.mutable.ArrayBuffer + +import org.apache.kafka.clients.consumer.ConsumerRecord +import org.apache.kafka.common.TopicPartition + +import org.apache.spark.{Partition, SparkContext, TaskContext} +import org.apache.spark.partial.{BoundedDouble, PartialResult} +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel + + +/** Offset range that one partition of the KafkaSourceRDD has to read */ +private[kafka010] case class KafkaSourceRDDOffsetRange( + topicPartition: TopicPartition, + fromOffset: Long, + untilOffset: Long, + preferredLoc: Option[String]) { + def topic: String = topicPartition.topic + def partition: Int = topicPartition.partition + def size: Long = untilOffset - fromOffset +} + + +/** Partition of the KafkaSourceRDD */ +private[kafka010] case class KafkaSourceRDDPartition( + index: Int, offsetRange: KafkaSourceRDDOffsetRange) extends Partition + + +/** + * An RDD that reads data from Kafka based on offset ranges across multiple partitions. + * Additionally, it allows preferred locations to be set for each topic + partition, so that + * the [[KafkaSource]] can ensure the same executor always reads the same topic + partition + * and cached KafkaConsumers (see [[CachedKafkaConsumer]]) can be used to read data efficiently. + * + * @param sc the [[SparkContext]] + * @param executorKafkaParams Kafka configuration for creating KafkaConsumer on the executors + * @param offsetRanges Offset ranges that define the Kafka data belonging to this RDD + */ +private[kafka010] class KafkaSourceRDD( + sc: SparkContext, + executorKafkaParams: ju.Map[String, Object], + offsetRanges: Seq[KafkaSourceRDDOffsetRange], + pollTimeoutMs: Long) + extends RDD[ConsumerRecord[Array[Byte], Array[Byte]]](sc, Nil) { + + override def persist(newLevel: StorageLevel): this.type = { + logError("Kafka ConsumerRecord is not serializable. 
" + + "Use .map to extract fields before calling .persist or .window") + super.persist(newLevel) + } + + override def getPartitions: Array[Partition] = { + offsetRanges.zipWithIndex.map { case (o, i) => new KafkaSourceRDDPartition(i, o) }.toArray + } + + override def count(): Long = offsetRanges.map(_.size).sum + + override def countApprox(timeout: Long, confidence: Double): PartialResult[BoundedDouble] = { + val c = count + new PartialResult(new BoundedDouble(c, 1.0, c, c), true) + } + + override def isEmpty(): Boolean = count == 0L + + override def take(num: Int): Array[ConsumerRecord[Array[Byte], Array[Byte]]] = { + val nonEmptyPartitions = + this.partitions.map(_.asInstanceOf[KafkaSourceRDDPartition]).filter(_.offsetRange.size > 0) + + if (num < 1 || nonEmptyPartitions.isEmpty) { + return new Array[ConsumerRecord[Array[Byte], Array[Byte]]](0) + } + + // Determine in advance how many messages need to be taken from each partition + val parts = nonEmptyPartitions.foldLeft(Map[Int, Int]()) { (result, part) => + val remain = num - result.values.sum + if (remain > 0) { + val taken = Math.min(remain, part.offsetRange.size) + result + (part.index -> taken.toInt) + } else { + result + } + } + + val buf = new ArrayBuffer[ConsumerRecord[Array[Byte], Array[Byte]]] + val res = context.runJob( + this, + (tc: TaskContext, it: Iterator[ConsumerRecord[Array[Byte], Array[Byte]]]) => + it.take(parts(tc.partitionId)).toArray, parts.keys.toArray + ) + res.foreach(buf ++= _) + buf.toArray + } + + override def getPreferredLocations(split: Partition): Seq[String] = { + val part = split.asInstanceOf[KafkaSourceRDDPartition] + part.offsetRange.preferredLoc.map(Seq(_)).getOrElse(Seq.empty) + } + + override def compute( + thePart: Partition, + context: TaskContext): Iterator[ConsumerRecord[Array[Byte], Array[Byte]]] = { + val range = thePart.asInstanceOf[KafkaSourceRDDPartition].offsetRange + assert( + range.fromOffset <= range.untilOffset, + s"Beginning offset ${range.fromOffset} is after the ending offset ${range.untilOffset} " + + s"for topic ${range.topic} partition ${range.partition}. 
" + + "You either provided an invalid fromOffset, or the Kafka topic has been damaged") + if (range.fromOffset == range.untilOffset) { + logInfo(s"Beginning offset ${range.fromOffset} is the same as ending offset " + + s"skipping ${range.topic} ${range.partition}") + Iterator.empty + + } else { + + val consumer = CachedKafkaConsumer.getOrCreate( + range.topic, range.partition, executorKafkaParams) + var requestOffset = range.fromOffset + + logDebug(s"Creating iterator for $range") + + new Iterator[ConsumerRecord[Array[Byte], Array[Byte]]]() { + override def hasNext(): Boolean = requestOffset < range.untilOffset + override def next(): ConsumerRecord[Array[Byte], Array[Byte]] = { + assert(hasNext(), "Can't call next() once untilOffset has been reached") + val r = consumer.get(requestOffset, pollTimeoutMs) + requestOffset += 1 + r + } + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SourceStatus.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/StartingOffsets.scala similarity index 62% rename from sql/core/src/main/scala/org/apache/spark/sql/SourceStatus.scala rename to external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/StartingOffsets.scala index 2479e67e369ec..83959e597171a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SourceStatus.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/StartingOffsets.scala @@ -15,20 +15,18 @@ * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.kafka010 -import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.execution.streaming.{Offset, Source} +import org.apache.kafka.common.TopicPartition -/** - * :: Experimental :: - * Status and metrics of a streaming [[Source]]. - * - * @param description Description of the source corresponding to this status - * @param offset Current offset of the source, if known - * @since 2.0.0 +/* + * Values that can be specified for config startingOffsets */ -@Experimental -class SourceStatus private[sql] ( - val description: String, - val offset: Option[Offset]) +private[kafka010] sealed trait StartingOffsets + +private[kafka010] case object EarliestOffsets extends StartingOffsets + +private[kafka010] case object LatestOffsets extends StartingOffsets + +private[kafka010] case class SpecificOffsets( + partitionOffsets: Map[TopicPartition, Long]) extends StartingOffsets diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetchException.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/package-info.java similarity index 79% rename from core/src/main/scala/org/apache/spark/storage/BlockFetchException.scala rename to external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/package-info.java index f6e46ae9a481a..596f775c56dbc 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockFetchException.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/package-info.java @@ -15,10 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.storage - -import org.apache.spark.SparkException - -private[spark] -case class BlockFetchException(messages: String, throwable: Throwable) - extends SparkException(messages, throwable) +/** + * Structured Streaming Data Source for Kafka 0.10 + */ +package org.apache.spark.sql.kafka010; diff --git a/external/kafka-0-10-sql/src/test/resources/log4j.properties b/external/kafka-0-10-sql/src/test/resources/log4j.properties new file mode 100644 index 0000000000000..75e3b53a093f6 --- /dev/null +++ b/external/kafka-0-10-sql/src/test/resources/log4j.properties @@ -0,0 +1,28 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Set everything to be logged to the file target/unit-tests.log +log4j.rootCategory=INFO, file +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=true +log4j.appender.file.file=target/unit-tests.log +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n + +# Ignore messages below warning level from Jetty, because it's a bit verbose +log4j.logger.org.spark-project.jetty=WARN + diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/JsonUtilsSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/JsonUtilsSuite.scala new file mode 100644 index 0000000000000..54b980049d1a2 --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/JsonUtilsSuite.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.kafka010 + +import org.apache.kafka.common.TopicPartition + +import org.apache.spark.SparkFunSuite + +class JsonUtilsSuite extends SparkFunSuite { + + test("parsing partitions") { + val parsed = JsonUtils.partitions("""{"topicA":[0,1],"topicB":[4,6]}""") + val expected = Array( + new TopicPartition("topicA", 0), + new TopicPartition("topicA", 1), + new TopicPartition("topicB", 4), + new TopicPartition("topicB", 6) + ) + assert(parsed.toSeq === expected.toSeq) + } + + test("parsing partitionOffsets") { + val parsed = JsonUtils.partitionOffsets( + """{"topicA":{"0":23,"1":-1},"topicB":{"0":-2}}""") + + assert(parsed(new TopicPartition("topicA", 0)) === 23) + assert(parsed(new TopicPartition("topicA", 1)) === -1) + assert(parsed(new TopicPartition("topicB", 0)) === -2) + } +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceOffsetSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceOffsetSuite.scala new file mode 100644 index 0000000000000..7056a41b1751e --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceOffsetSuite.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import org.apache.spark.sql.streaming.OffsetSuite + +class KafkaSourceOffsetSuite extends OffsetSuite { + + compare( + one = KafkaSourceOffset(("t", 0, 1L)), + two = KafkaSourceOffset(("t", 0, 2L))) + + compare( + one = KafkaSourceOffset(("t", 0, 1L), ("t", 1, 0L)), + two = KafkaSourceOffset(("t", 0, 2L), ("t", 1, 1L))) + + compare( + one = KafkaSourceOffset(("t", 0, 1L), ("T", 0, 0L)), + two = KafkaSourceOffset(("t", 0, 2L), ("T", 0, 1L))) + + compare( + one = KafkaSourceOffset(("t", 0, 1L)), + two = KafkaSourceOffset(("t", 0, 2L), ("t", 1, 1L))) +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala new file mode 100644 index 0000000000000..ed4cc75920e8e --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -0,0 +1,604 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
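The `JsonUtilsSuite` tests above pin down the JSON shapes accepted by the `assign` and `startingOffsets` options: a topic-to-partition-array map for assignments, and a nested topic/partition-to-offset map for offsets, where -1 denotes latest and -2 denotes earliest (matching the comments in `testFromSpecificOffsets` later in this patch). A hedged sketch of building those option values by hand, assuming an existing `SparkSession` named `spark`:

```scala
// Illustrative only: the option values are plain JSON strings in the shapes the tests expect.
val assignJson = """{"topicA":[0,1],"topicB":[4,6]}"""

// -2 = earliest, -1 = latest, any other value is an absolute offset.
val startingOffsetsJson = """{"topicA":{"0":23,"1":-1},"topicB":{"0":-2}}"""

// Passed as reader options, e.g.:
// spark.readStream.format("kafka")
//   .option("assign", assignJson)
//   .option("startingOffsets", startingOffsetsJson)
```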
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.util.concurrent.atomic.AtomicInteger + +import scala.util.Random + +import org.apache.kafka.clients.producer.RecordMetadata +import org.apache.kafka.common.TopicPartition +import org.scalatest.concurrent.Eventually._ +import org.scalatest.concurrent.PatienceConfiguration.Timeout +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.streaming.{ ProcessingTime, StreamTest } +import org.apache.spark.sql.test.SharedSQLContext + +abstract class KafkaSourceTest extends StreamTest with SharedSQLContext { + + protected var testUtils: KafkaTestUtils = _ + + override val streamingTimeout = 30.seconds + + override def beforeAll(): Unit = { + super.beforeAll() + testUtils = new KafkaTestUtils + testUtils.setup() + } + + override def afterAll(): Unit = { + if (testUtils != null) { + testUtils.teardown() + testUtils = null + super.afterAll() + } + } + + protected def makeSureGetOffsetCalled = AssertOnQuery { q => + // Because KafkaSource's initialPartitionOffsets is set lazily, we need to make sure + // its "getOffset" is called before pushing any data. Otherwise, because of the race contion, + // we don't know which data should be fetched when `startingOffsets` is latest. + q.processAllAvailable() + true + } + + /** + * Add data to Kafka. + * + * `topicAction` can be used to run actions for each topic before inserting data. + */ + case class AddKafkaData(topics: Set[String], data: Int*) + (implicit ensureDataInMultiplePartition: Boolean = false, + concurrent: Boolean = false, + message: String = "", + topicAction: (String, Option[Int]) => Unit = (_, _) => {}) extends AddData { + + override def addData(query: Option[StreamExecution]): (Source, Offset) = { + if (query.get.isActive) { + // Make sure no Spark job is running when deleting a topic + query.get.processAllAvailable() + } + + val existingTopics = testUtils.getAllTopicsAndPartitionSize().toMap + val newTopics = topics.diff(existingTopics.keySet) + for (newTopic <- newTopics) { + topicAction(newTopic, None) + } + for (existingTopicPartitions <- existingTopics) { + topicAction(existingTopicPartitions._1, Some(existingTopicPartitions._2)) + } + + // Read all topics again in case some topics are delete. 
+ val allTopics = testUtils.getAllTopicsAndPartitionSize().toMap.keys + require( + query.nonEmpty, + "Cannot add data when there is no query for finding the active kafka source") + + val sources = query.get.logicalPlan.collect { + case StreamingExecutionRelation(source, _) if source.isInstanceOf[KafkaSource] => + source.asInstanceOf[KafkaSource] + } + if (sources.isEmpty) { + throw new Exception( + "Could not find Kafka source in the StreamExecution logical plan to add data to") + } else if (sources.size > 1) { + throw new Exception( + "Could not select the Kafka source in the StreamExecution logical plan as there" + + "are multiple Kafka sources:\n\t" + sources.mkString("\n\t")) + } + val kafkaSource = sources.head + val topic = topics.toSeq(Random.nextInt(topics.size)) + val sentMetadata = testUtils.sendMessages(topic, data.map { _.toString }.toArray) + + def metadataToStr(m: (String, RecordMetadata)): String = { + s"Sent ${m._1} to partition ${m._2.partition()}, offset ${m._2.offset()}" + } + // Verify that the test data gets inserted into multiple partitions + if (ensureDataInMultiplePartition) { + require( + sentMetadata.groupBy(_._2.partition).size > 1, + s"Added data does not test multiple partitions: ${sentMetadata.map(metadataToStr)}") + } + + val offset = KafkaSourceOffset(testUtils.getLatestOffsets(topics)) + logInfo(s"Added data, expected offset $offset") + (kafkaSource, offset) + } + + override def toString: String = + s"AddKafkaData(topics = $topics, data = $data, message = $message)" + } +} + + +class KafkaSourceSuite extends KafkaSourceTest { + + import testImplicits._ + + private val topicId = new AtomicInteger(0) + + test("maxOffsetsPerTrigger") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 3) + testUtils.sendMessages(topic, (100 to 200).map(_.toString).toArray, Some(0)) + testUtils.sendMessages(topic, (10 to 20).map(_.toString).toArray, Some(1)) + testUtils.sendMessages(topic, Array("1"), Some(2)) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("maxOffsetsPerTrigger", 10) + .option("subscribe", topic) + .option("startingOffsets", "earliest") + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val mapped: org.apache.spark.sql.Dataset[_] = kafka.map(kv => kv._2.toInt) + + val clock = new StreamManualClock + + val waitUntilBatchProcessed = AssertOnQuery { q => + eventually(Timeout(streamingTimeout)) { + if (!q.exception.isDefined) { + assert(clock.isStreamWaitingAt(clock.getTimeMillis())) + } + } + if (q.exception.isDefined) { + throw q.exception.get + } + true + } + + testStream(mapped)( + StartStream(ProcessingTime(100), clock), + waitUntilBatchProcessed, + // 1 from smallest, 1 from middle, 8 from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107), + AdvanceManualClock(100), + waitUntilBatchProcessed, + // smallest now empty, 1 more from middle, 9 more from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, + 11, 108, 109, 110, 111, 112, 113, 114, 115, 116 + ), + StopStream, + StartStream(ProcessingTime(100), clock), + waitUntilBatchProcessed, + AdvanceManualClock(100), + waitUntilBatchProcessed, + // smallest now empty, 1 more from middle, 9 more from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, + 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, + 12, 117, 118, 119, 120, 121, 122, 123, 124, 125 + ), + 
AdvanceManualClock(100), + waitUntilBatchProcessed, + // smallest now empty, 1 more from middle, 9 more from biggest + CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107, + 11, 108, 109, 110, 111, 112, 113, 114, 115, 116, + 12, 117, 118, 119, 120, 121, 122, 123, 124, 125, + 13, 126, 127, 128, 129, 130, 131, 132, 133, 134 + ) + ) + } + + test("cannot stop Kafka stream") { + val topic = newTopic() + testUtils.createTopic(newTopic(), partitions = 5) + testUtils.sendMessages(topic, (101 to 105).map { _.toString }.toArray) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("subscribePattern", s"topic-.*") + + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val mapped = kafka.map(kv => kv._2.toInt + 1) + + testStream(mapped)( + makeSureGetOffsetCalled, + StopStream + ) + } + + test("assign from latest offsets") { + val topic = newTopic() + testFromLatestOffsets(topic, false, "assign" -> assignString(topic, 0 to 4)) + } + + test("assign from earliest offsets") { + val topic = newTopic() + testFromEarliestOffsets(topic, false, "assign" -> assignString(topic, 0 to 4)) + } + + test("assign from specific offsets") { + val topic = newTopic() + testFromSpecificOffsets(topic, "assign" -> assignString(topic, 0 to 4)) + } + + test("subscribing topic by name from latest offsets") { + val topic = newTopic() + testFromLatestOffsets(topic, true, "subscribe" -> topic) + } + + test("subscribing topic by name from earliest offsets") { + val topic = newTopic() + testFromEarliestOffsets(topic, true, "subscribe" -> topic) + } + + test("subscribing topic by name from specific offsets") { + val topic = newTopic() + testFromSpecificOffsets(topic, "subscribe" -> topic) + } + + test("subscribing topic by pattern from latest offsets") { + val topicPrefix = newTopic() + val topic = topicPrefix + "-suffix" + testFromLatestOffsets(topic, true, "subscribePattern" -> s"$topicPrefix-.*") + } + + test("subscribing topic by pattern from earliest offsets") { + val topicPrefix = newTopic() + val topic = topicPrefix + "-suffix" + testFromEarliestOffsets(topic, true, "subscribePattern" -> s"$topicPrefix-.*") + } + + test("subscribing topic by pattern from specific offsets") { + val topicPrefix = newTopic() + val topic = topicPrefix + "-suffix" + testFromSpecificOffsets(topic, "subscribePattern" -> s"$topicPrefix-.*") + } + + test("subscribing topic by pattern with topic deletions") { + val topicPrefix = newTopic() + val topic = topicPrefix + "-seems" + val topic2 = topicPrefix + "-bad" + testUtils.createTopic(topic, partitions = 5) + testUtils.sendMessages(topic, Array("-1")) + require(testUtils.getLatestOffsets(Set(topic)).size === 5) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("subscribePattern", s"$topicPrefix-.*") + .option("failOnDataLoss", "false") + + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val mapped = kafka.map(kv => kv._2.toInt + 1) + + testStream(mapped)( + makeSureGetOffsetCalled, + AddKafkaData(Set(topic), 1, 2, 3), + CheckAnswer(2, 3, 4), + Assert { + testUtils.deleteTopic(topic) + testUtils.createTopic(topic2, partitions = 5) + true + }, + AddKafkaData(Set(topic2), 4, 5, 6), + CheckAnswer(2, 3, 4, 5, 6, 7) + ) + } + + test("bad 
source options") { + def testBadOptions(options: (String, String)*)(expectedMsgs: String*): Unit = { + val ex = intercept[IllegalArgumentException] { + val reader = spark + .readStream + .format("kafka") + options.foreach { case (k, v) => reader.option(k, v) } + reader.load() + } + expectedMsgs.foreach { m => + assert(ex.getMessage.toLowerCase.contains(m.toLowerCase)) + } + } + + // No strategy specified + testBadOptions()("options must be specified", "subscribe", "subscribePattern") + + // Multiple strategies specified + testBadOptions("subscribe" -> "t", "subscribePattern" -> "t.*")( + "only one", "options can be specified") + + testBadOptions("subscribe" -> "t", "assign" -> """{"a":[0]}""")( + "only one", "options can be specified") + + testBadOptions("assign" -> "")("no topicpartitions to assign") + testBadOptions("subscribe" -> "")("no topics to subscribe") + testBadOptions("subscribePattern" -> "")("pattern to subscribe is empty") + } + + test("unsupported kafka configs") { + def testUnsupportedConfig(key: String, value: String = "someValue"): Unit = { + val ex = intercept[IllegalArgumentException] { + val reader = spark + .readStream + .format("kafka") + .option("subscribe", "topic") + .option("kafka.bootstrap.servers", "somehost") + .option(s"$key", value) + reader.load() + } + assert(ex.getMessage.toLowerCase.contains("not supported")) + } + + testUnsupportedConfig("kafka.group.id") + testUnsupportedConfig("kafka.auto.offset.reset") + testUnsupportedConfig("kafka.enable.auto.commit") + testUnsupportedConfig("kafka.interceptor.classes") + testUnsupportedConfig("kafka.key.deserializer") + testUnsupportedConfig("kafka.value.deserializer") + + testUnsupportedConfig("kafka.auto.offset.reset", "none") + testUnsupportedConfig("kafka.auto.offset.reset", "someValue") + testUnsupportedConfig("kafka.auto.offset.reset", "earliest") + testUnsupportedConfig("kafka.auto.offset.reset", "latest") + } + + test("input row metrics") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 5) + testUtils.sendMessages(topic, Array("-1")) + require(testUtils.getLatestOffsets(Set(topic)).size === 5) + + val kafka = spark + .readStream + .format("kafka") + .option("subscribe", topic) + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + + val mapped = kafka.map(kv => kv._2.toInt + 1) + testStream(mapped)( + makeSureGetOffsetCalled, + AddKafkaData(Set(topic), 1, 2, 3), + CheckAnswer(2, 3, 4), + AssertOnLastQueryStatus { status => + assert(status.triggerDetails.get("numRows.input.total").toInt > 0) + assert(status.sourceStatuses(0).processingRate > 0.0) + } + ) + } + + private def newTopic(): String = s"topic-${topicId.getAndIncrement()}" + + private def assignString(topic: String, partitions: Iterable[Int]): String = { + JsonUtils.partitions(partitions.map(p => new TopicPartition(topic, p))) + } + + private def testFromSpecificOffsets(topic: String, options: (String, String)*): Unit = { + val partitionOffsets = Map( + new TopicPartition(topic, 0) -> -2L, + new TopicPartition(topic, 1) -> -1L, + new TopicPartition(topic, 2) -> 0L, + new TopicPartition(topic, 3) -> 1L, + new TopicPartition(topic, 4) -> 2L + ) + val startingOffsets = JsonUtils.partitionOffsets(partitionOffsets) + + testUtils.createTopic(topic, partitions = 5) + // part 0 starts at earliest, these should all be seen + testUtils.sendMessages(topic, Array(-20, -21, -22).map(_.toString), Some(0)) + // part 1 starts at latest, 
these should all be skipped + testUtils.sendMessages(topic, Array(-10, -11, -12).map(_.toString), Some(1)) + // part 2 starts at 0, these should all be seen + testUtils.sendMessages(topic, Array(0, 1, 2).map(_.toString), Some(2)) + // part 3 starts at 1, first should be skipped + testUtils.sendMessages(topic, Array(10, 11, 12).map(_.toString), Some(3)) + // part 4 starts at 2, first and second should be skipped + testUtils.sendMessages(topic, Array(20, 21, 22).map(_.toString), Some(4)) + require(testUtils.getLatestOffsets(Set(topic)).size === 5) + + val reader = spark + .readStream + .format("kafka") + .option("startingOffsets", startingOffsets) + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + options.foreach { case (k, v) => reader.option(k, v) } + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val mapped: org.apache.spark.sql.Dataset[_] = kafka.map(kv => kv._2.toInt) + + testStream(mapped)( + makeSureGetOffsetCalled, + CheckAnswer(-20, -21, -22, 0, 1, 2, 11, 12, 22), + StopStream, + StartStream(), + CheckAnswer(-20, -21, -22, 0, 1, 2, 11, 12, 22), // Should get the data back on recovery + AddKafkaData(Set(topic), 30, 31, 32, 33, 34)(ensureDataInMultiplePartition = true), + CheckAnswer(-20, -21, -22, 0, 1, 2, 11, 12, 22, 30, 31, 32, 33, 34), + StopStream + ) + } + + private def testFromLatestOffsets( + topic: String, + addPartitions: Boolean, + options: (String, String)*): Unit = { + testUtils.createTopic(topic, partitions = 5) + testUtils.sendMessages(topic, Array("-1")) + require(testUtils.getLatestOffsets(Set(topic)).size === 5) + + val reader = spark + .readStream + .format("kafka") + .option("startingOffsets", s"latest") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + options.foreach { case (k, v) => reader.option(k, v) } + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + val mapped = kafka.map(kv => kv._2.toInt + 1) + + testStream(mapped)( + makeSureGetOffsetCalled, + AddKafkaData(Set(topic), 1, 2, 3), + CheckAnswer(2, 3, 4), + StopStream, + StartStream(), + CheckAnswer(2, 3, 4), // Should get the data back on recovery + StopStream, + AddKafkaData(Set(topic), 4, 5, 6), // Add data when stream is stopped + StartStream(), + CheckAnswer(2, 3, 4, 5, 6, 7), // Should get the added data + AddKafkaData(Set(topic), 7, 8), + CheckAnswer(2, 3, 4, 5, 6, 7, 8, 9), + AssertOnQuery("Add partitions") { query: StreamExecution => + if (addPartitions) { + testUtils.addPartitions(topic, 10) + } + true + }, + AddKafkaData(Set(topic), 9, 10, 11, 12, 13, 14, 15, 16), + CheckAnswer(2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17) + ) + } + + private def testFromEarliestOffsets( + topic: String, + addPartitions: Boolean, + options: (String, String)*): Unit = { + testUtils.createTopic(topic, partitions = 5) + testUtils.sendMessages(topic, (1 to 3).map { _.toString }.toArray) + require(testUtils.getLatestOffsets(Set(topic)).size === 5) + + val reader = spark.readStream + reader + .format(classOf[KafkaSourceProvider].getCanonicalName.stripSuffix("$")) + .option("startingOffsets", s"earliest") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + options.foreach { case (k, v) => reader.option(k, v) } + val kafka = reader.load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, 
String)] + val mapped = kafka.map(kv => kv._2.toInt + 1) + + testStream(mapped)( + AddKafkaData(Set(topic), 4, 5, 6), // Add data when stream is stopped + CheckAnswer(2, 3, 4, 5, 6, 7), + StopStream, + StartStream(), + CheckAnswer(2, 3, 4, 5, 6, 7), + StopStream, + AddKafkaData(Set(topic), 7, 8), + StartStream(), + CheckAnswer(2, 3, 4, 5, 6, 7, 8, 9), + AssertOnQuery("Add partitions") { query: StreamExecution => + if (addPartitions) { + testUtils.addPartitions(topic, 10) + } + true + }, + AddKafkaData(Set(topic), 9, 10, 11, 12, 13, 14, 15, 16), + CheckAnswer(2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17) + ) + } +} + + +class KafkaSourceStressSuite extends KafkaSourceTest { + + import testImplicits._ + + val topicId = new AtomicInteger(1) + + @volatile var topics: Seq[String] = (1 to 5).map(_ => newStressTopic) + + def newStressTopic: String = s"stress${topicId.getAndIncrement()}" + + private def nextInt(start: Int, end: Int): Int = { + start + Random.nextInt(start + end - 1) + } + + test("stress test with multiple topics and partitions") { + topics.foreach { topic => + testUtils.createTopic(topic, partitions = nextInt(1, 6)) + testUtils.sendMessages(topic, (101 to 105).map { _.toString }.toArray) + } + + // Create Kafka source that reads from latest offset + val kafka = + spark.readStream + .format(classOf[KafkaSourceProvider].getCanonicalName.stripSuffix("$")) + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("kafka.metadata.max.age.ms", "1") + .option("subscribePattern", "stress.*") + .option("failOnDataLoss", "false") + .load() + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] + + val mapped = kafka.map(kv => kv._2.toInt + 1) + + runStressTest( + mapped, + Seq(makeSureGetOffsetCalled), + (d, running) => { + Random.nextInt(5) match { + case 0 => // Add a new topic + topics = topics ++ Seq(newStressTopic) + AddKafkaData(topics.toSet, d: _*)(message = s"Add topic $newStressTopic", + topicAction = (topic, partition) => { + if (partition.isEmpty) { + testUtils.createTopic(topic, partitions = nextInt(1, 6)) + } + }) + case 1 if running => + // Only delete a topic when the query is running. Otherwise, we may lost data and + // cannot check the correctness. + val deletedTopic = topics(Random.nextInt(topics.size)) + if (deletedTopic != topics.head) { + topics = topics.filterNot(_ == deletedTopic) + } + AddKafkaData(topics.toSet, d: _*)(message = s"Delete topic $deletedTopic", + topicAction = (topic, partition) => { + // Never remove the first topic to make sure we have at least one topic + if (topic == deletedTopic && deletedTopic != topics.head) { + testUtils.deleteTopic(deletedTopic) + } + }) + case 2 => // Add new partitions + AddKafkaData(topics.toSet, d: _*)(message = "Add partitiosn", + topicAction = (topic, partition) => { + testUtils.addPartitions(topic, partition.get + nextInt(1, 6)) + }) + case _ => // Just add new data + AddKafkaData(topics.toSet, d: _*) + } + }, + iterations = 50) + } +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala new file mode 100644 index 0000000000000..9b24ccdd560e8 --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala @@ -0,0 +1,351 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
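The suites above exercise the source entirely through the public `DataStreamReader` options (`subscribe` / `subscribePattern` / `assign`, `startingOffsets`, `maxOffsetsPerTrigger`, `failOnDataLoss`). A minimal, hedged sketch of how an application would wire the same options together, assuming a broker at `host:port` and an existing `SparkSession` named `spark`:

```scala
// Illustrative only: exactly one of "assign", "subscribe" or "subscribePattern" may be set,
// and "kafka."-prefixed configs such as group.id or auto.offset.reset are rejected,
// as the option-validation tests above show.
val stream = spark.readStream
  .format("kafka")
  .option("kafka.bootstrap.servers", "host:port")   // assumed broker address
  .option("subscribe", "topic-0")                   // or "subscribePattern" / "assign"
  .option("startingOffsets", "earliest")            // "latest" or a JSON offsets map
  .option("maxOffsetsPerTrigger", "10")             // optional per-trigger rate limit
  .option("failOnDataLoss", "false")                // tolerate deleted topics, as in the tests
  .load()
  .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
```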
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.io.File +import java.lang.{Integer => JInt} +import java.net.InetSocketAddress +import java.util.{Map => JMap, Properties} +import java.util.concurrent.TimeUnit + +import scala.collection.JavaConverters._ +import scala.language.postfixOps +import scala.util.Random + +import kafka.admin.AdminUtils +import kafka.api.Request +import kafka.common.TopicAndPartition +import kafka.server.{KafkaConfig, KafkaServer, OffsetCheckpoint} +import kafka.utils.ZkUtils +import org.apache.kafka.clients.consumer.KafkaConsumer +import org.apache.kafka.clients.producer._ +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.serialization.{StringDeserializer, StringSerializer} +import org.apache.zookeeper.server.{NIOServerCnxnFactory, ZooKeeperServer} +import org.scalatest.concurrent.Eventually._ +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.internal.Logging +import org.apache.spark.util.Utils +import org.apache.spark.SparkConf + +/** + * This is a helper class for Kafka test suites. This has the functionality to set up + * and tear down local Kafka servers, and to push data using Kafka producers. + * + * The reason to put Kafka test utility class in src is to test Python related Kafka APIs. 
+ */ +class KafkaTestUtils extends Logging { + + // Zookeeper related configurations + private val zkHost = "localhost" + private var zkPort: Int = 0 + private val zkConnectionTimeout = 60000 + private val zkSessionTimeout = 6000 + + private var zookeeper: EmbeddedZookeeper = _ + + private var zkUtils: ZkUtils = _ + + // Kafka broker related configurations + private val brokerHost = "localhost" + private var brokerPort = 0 + private var brokerConf: KafkaConfig = _ + + // Kafka broker server + private var server: KafkaServer = _ + + // Kafka producer + private var producer: Producer[String, String] = _ + + // Flag to test whether the system is correctly started + private var zkReady = false + private var brokerReady = false + + def zkAddress: String = { + assert(zkReady, "Zookeeper not setup yet or already torn down, cannot get zookeeper address") + s"$zkHost:$zkPort" + } + + def brokerAddress: String = { + assert(brokerReady, "Kafka not setup yet or already torn down, cannot get broker address") + s"$brokerHost:$brokerPort" + } + + def zookeeperClient: ZkUtils = { + assert(zkReady, "Zookeeper not setup yet or already torn down, cannot get zookeeper client") + Option(zkUtils).getOrElse( + throw new IllegalStateException("Zookeeper client is not yet initialized")) + } + + // Set up the Embedded Zookeeper server and get the proper Zookeeper port + private def setupEmbeddedZookeeper(): Unit = { + // Zookeeper server startup + zookeeper = new EmbeddedZookeeper(s"$zkHost:$zkPort") + // Get the actual zookeeper binding port + zkPort = zookeeper.actualPort + zkUtils = ZkUtils(s"$zkHost:$zkPort", zkSessionTimeout, zkConnectionTimeout, false) + zkReady = true + } + + // Set up the Embedded Kafka server + private def setupEmbeddedKafkaServer(): Unit = { + assert(zkReady, "Zookeeper should be set up beforehand") + + // Kafka broker startup + Utils.startServiceOnPort(brokerPort, port => { + brokerPort = port + brokerConf = new KafkaConfig(brokerConfiguration, doLog = false) + server = new KafkaServer(brokerConf) + server.startup() + brokerPort = server.boundPort() + (server, brokerPort) + }, new SparkConf(), "KafkaBroker") + + brokerReady = true + } + + /** setup the whole embedded servers, including Zookeeper and Kafka brokers */ + def setup(): Unit = { + setupEmbeddedZookeeper() + setupEmbeddedKafkaServer() + } + + /** Teardown the whole servers, including Kafka broker and Zookeeper */ + def teardown(): Unit = { + brokerReady = false + zkReady = false + + if (producer != null) { + producer.close() + producer = null + } + + if (server != null) { + server.shutdown() + server = null + } + + brokerConf.logDirs.foreach { f => Utils.deleteRecursively(new File(f)) } + + if (zkUtils != null) { + zkUtils.close() + zkUtils = null + } + + if (zookeeper != null) { + zookeeper.shutdown() + zookeeper = null + } + } + + /** Create a Kafka topic and wait until it is propagated to the whole cluster */ + def createTopic(topic: String, partitions: Int): Unit = { + AdminUtils.createTopic(zkUtils, topic, partitions, 1) + // wait until metadata is propagated + (0 until partitions).foreach { p => + waitUntilMetadataIsPropagated(topic, p) + } + } + + def getAllTopicsAndPartitionSize(): Seq[(String, Int)] = { + zkUtils.getPartitionsForTopics(zkUtils.getAllTopics()).mapValues(_.size).toSeq + } + + /** Create a Kafka topic and wait until it is propagated to the whole cluster */ + def createTopic(topic: String): Unit = { + createTopic(topic, 1) + } + + /** Delete a Kafka topic and wait until it is propagated to the whole 
cluster */ + def deleteTopic(topic: String): Unit = { + val partitions = zkUtils.getPartitionsForTopics(Seq(topic))(topic).size + AdminUtils.deleteTopic(zkUtils, topic) + verifyTopicDeletion(zkUtils, topic, partitions, List(this.server)) + } + + /** Add new paritions to a Kafka topic */ + def addPartitions(topic: String, partitions: Int): Unit = { + AdminUtils.addPartitions(zkUtils, topic, partitions) + // wait until metadata is propagated + (0 until partitions).foreach { p => + waitUntilMetadataIsPropagated(topic, p) + } + } + + /** Java-friendly function for sending messages to the Kafka broker */ + def sendMessages(topic: String, messageToFreq: JMap[String, JInt]): Unit = { + sendMessages(topic, Map(messageToFreq.asScala.mapValues(_.intValue()).toSeq: _*)) + } + + /** Send the messages to the Kafka broker */ + def sendMessages(topic: String, messageToFreq: Map[String, Int]): Unit = { + val messages = messageToFreq.flatMap { case (s, freq) => Seq.fill(freq)(s) }.toArray + sendMessages(topic, messages) + } + + /** Send the array of messages to the Kafka broker */ + def sendMessages(topic: String, messages: Array[String]): Seq[(String, RecordMetadata)] = { + sendMessages(topic, messages, None) + } + + /** Send the array of messages to the Kafka broker using specified partition */ + def sendMessages( + topic: String, + messages: Array[String], + partition: Option[Int]): Seq[(String, RecordMetadata)] = { + producer = new KafkaProducer[String, String](producerConfiguration) + val offsets = try { + messages.map { m => + val record = partition match { + case Some(p) => new ProducerRecord[String, String](topic, p, null, m) + case None => new ProducerRecord[String, String](topic, m) + } + val metadata = + producer.send(record).get(10, TimeUnit.SECONDS) + logInfo(s"\tSent $m to partition ${metadata.partition}, offset ${metadata.offset}") + (m, metadata) + } + } finally { + if (producer != null) { + producer.close() + producer = null + } + } + offsets + } + + def getLatestOffsets(topics: Set[String]): Map[TopicPartition, Long] = { + val kc = new KafkaConsumer[String, String](consumerConfiguration) + logInfo("Created consumer to get latest offsets") + kc.subscribe(topics.asJavaCollection) + kc.poll(0) + val partitions = kc.assignment() + kc.pause(partitions) + kc.seekToEnd(partitions) + val offsets = partitions.asScala.map(p => p -> kc.position(p)).toMap + kc.close() + logInfo("Closed consumer to get latest offsets") + offsets + } + + private def brokerConfiguration: Properties = { + val props = new Properties() + props.put("broker.id", "0") + props.put("host.name", "localhost") + props.put("advertised.host.name", "localhost") + props.put("port", brokerPort.toString) + props.put("log.dir", Utils.createTempDir().getAbsolutePath) + props.put("zookeeper.connect", zkAddress) + props.put("log.flush.interval.messages", "1") + props.put("replica.socket.timeout.ms", "1500") + props.put("delete.topic.enable", "true") + props + } + + private def producerConfiguration: Properties = { + val props = new Properties() + props.put("bootstrap.servers", brokerAddress) + props.put("value.serializer", classOf[StringSerializer].getName) + props.put("key.serializer", classOf[StringSerializer].getName) + // wait for all in-sync replicas to ack sends + props.put("acks", "all") + props + } + + private def consumerConfiguration: Properties = { + val props = new Properties() + props.put("bootstrap.servers", brokerAddress) + props.put("group.id", "group-KafkaTestUtils-" + Random.nextInt) + props.put("value.deserializer", 
classOf[StringDeserializer].getName) + props.put("key.deserializer", classOf[StringDeserializer].getName) + props.put("enable.auto.commit", "false") + props + } + + private def verifyTopicDeletion( + zkUtils: ZkUtils, + topic: String, + numPartitions: Int, + servers: Seq[KafkaServer]) { + import ZkUtils._ + val topicAndPartitions = (0 until numPartitions).map(TopicAndPartition(topic, _)) + def isDeleted(): Boolean = { + // wait until admin path for delete topic is deleted, signaling completion of topic deletion + val deletePath = !zkUtils.pathExists(getDeleteTopicPath(topic)) + val topicPath = !zkUtils.pathExists(getTopicPath(topic)) + // ensure that the topic-partition has been deleted from all brokers' replica managers + val replicaManager = servers.forall(server => topicAndPartitions.forall(tp => + server.replicaManager.getPartition(tp.topic, tp.partition) == None)) + // ensure that logs from all replicas are deleted if delete topic is marked successful + val logManager = servers.forall(server => topicAndPartitions.forall(tp => + server.getLogManager().getLog(tp).isEmpty)) + // ensure that topic is removed from all cleaner offsets + val cleaner = servers.forall(server => topicAndPartitions.forall { tp => + val checkpoints = server.getLogManager().logDirs.map { logDir => + new OffsetCheckpoint(new File(logDir, "cleaner-offset-checkpoint")).read() + } + checkpoints.forall(checkpointsPerLogDir => !checkpointsPerLogDir.contains(tp)) + }) + deletePath && topicPath && replicaManager && logManager && cleaner + } + eventually(timeout(10.seconds)) { + assert(isDeleted, s"$topic not deleted after timeout") + } + } + + private def waitUntilMetadataIsPropagated(topic: String, partition: Int): Unit = { + def isPropagated = server.apis.metadataCache.getPartitionInfo(topic, partition) match { + case Some(partitionState) => + val leaderAndInSyncReplicas = partitionState.leaderIsrAndControllerEpoch.leaderAndIsr + + zkUtils.getLeaderForPartition(topic, partition).isDefined && + Request.isValidBrokerId(leaderAndInSyncReplicas.leader) && + leaderAndInSyncReplicas.isr.size >= 1 + + case _ => + false + } + eventually(timeout(10.seconds)) { + assert(isPropagated, s"Partition [$topic, $partition] metadata not propagated after timeout") + } + } + + private class EmbeddedZookeeper(val zkConnect: String) { + val snapshotDir = Utils.createTempDir() + val logDir = Utils.createTempDir() + + val zookeeper = new ZooKeeperServer(snapshotDir, logDir, 500) + val (ip, port) = { + val splits = zkConnect.split(":") + (splits(0), splits(1).toInt) + } + val factory = new NIOServerCnxnFactory() + factory.configure(new InetSocketAddress(ip, port), 16) + factory.startup(zookeeper) + + val actualPort = factory.getLocalPort + + def shutdown() { + factory.shutdown() + Utils.deleteRecursively(snapshotDir) + Utils.deleteRecursively(logDir) + } + } +} + diff --git a/external/kafka-0-10/pom.xml b/external/kafka-0-10/pom.xml new file mode 100644 index 0000000000000..0cde10930272d --- /dev/null +++ b/external/kafka-0-10/pom.xml @@ -0,0 +1,98 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.11 + 2.0.3-SNAPSHOT + ../../pom.xml + + + org.apache.spark + spark-streaming-kafka-0-10_2.11 + + streaming-kafka-0-10 + + jar + Spark Integration for Kafka 0.10 + http://spark.apache.org/ + + + + org.apache.spark + spark-streaming_${scala.binary.version} + ${project.version} + provided + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.kafka + kafka_${scala.binary.version} + 
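`KafkaTestUtils` above bundles an embedded Zookeeper plus a single Kafka broker. A hedged sketch of the lifecycle a suite outside this file would follow; the method names come straight from the class, while the surrounding test framing is illustrative:

```scala
// Illustrative only: mirrors how KafkaSourceTest drives the helper above.
val testUtils = new KafkaTestUtils
testUtils.setup()                                    // starts embedded ZK + broker
try {
  testUtils.createTopic("example-topic", partitions = 3)
  testUtils.sendMessages("example-topic", Array("1", "2", "3"), Some(0))
  val latest = testUtils.getLatestOffsets(Set("example-topic"))
  assert(latest.values.sum == 3)                     // three records were written to partition 0
} finally {
  testUtils.teardown()                               // shuts the broker and ZK down again
}
```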
0.10.0.1 + + + com.sun.jmx + jmxri + + + com.sun.jdmk + jmxtools + + + net.sf.jopt-simple + jopt-simple + + + org.slf4j + slf4j-simple + + + org.apache.zookeeper + zookeeper + + + + + net.sf.jopt-simple + jopt-simple + 3.2 + test + + + org.scalacheck + scalacheck_${scala.binary.version} + test + + + org.apache.spark + spark-tags_${scala.binary.version} + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala new file mode 100644 index 0000000000000..fa3ea6131a507 --- /dev/null +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka010 + +import java.{ util => ju } + +import org.apache.kafka.clients.consumer.{ ConsumerConfig, ConsumerRecord, KafkaConsumer } +import org.apache.kafka.common.{ KafkaException, TopicPartition } + +import org.apache.spark.SparkConf +import org.apache.spark.internal.Logging + + +/** + * Consumer of single topicpartition, intended for cached reuse. + * Underlying consumer is not threadsafe, so neither is this, + * but processing the same topicpartition and group id in multiple threads is usually bad anyway. + */ +private[kafka010] +class CachedKafkaConsumer[K, V] private( + val groupId: String, + val topic: String, + val partition: Int, + val kafkaParams: ju.Map[String, Object]) extends Logging { + + assert(groupId == kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG), + "groupId used for cache key must match the groupId in kafkaParams") + + val topicPartition = new TopicPartition(topic, partition) + + protected val consumer = { + val c = new KafkaConsumer[K, V](kafkaParams) + val tps = new ju.ArrayList[TopicPartition]() + tps.add(topicPartition) + c.assign(tps) + c + } + + // TODO if the buffer was kept around as a random-access structure, + // could possibly optimize re-calculating of an RDD in the same batch + protected var buffer = ju.Collections.emptyList[ConsumerRecord[K, V]]().iterator + protected var nextOffset = -2L + + def close(): Unit = consumer.close() + + /** + * Get the record for the given offset, waiting up to timeout ms if IO is necessary. + * Sequential forward access will use buffers, but random access will be horribly inefficient. 
+ */ + def get(offset: Long, timeout: Long): ConsumerRecord[K, V] = { + logDebug(s"Get $groupId $topic $partition nextOffset $nextOffset requested $offset") + if (offset != nextOffset) { + logInfo(s"Initial fetch for $groupId $topic $partition $offset") + seek(offset) + poll(timeout) + } + + if (!buffer.hasNext()) { poll(timeout) } + assert(buffer.hasNext(), + s"Failed to get records for $groupId $topic $partition $offset after polling for $timeout") + var record = buffer.next() + + if (record.offset != offset) { + logInfo(s"Buffer miss for $groupId $topic $partition $offset") + seek(offset) + poll(timeout) + assert(buffer.hasNext(), + s"Failed to get records for $groupId $topic $partition $offset after polling for $timeout") + record = buffer.next() + assert(record.offset == offset, + s"Got wrong record for $groupId $topic $partition even after seeking to offset $offset") + } + + nextOffset = offset + 1 + record + } + + private def seek(offset: Long): Unit = { + logDebug(s"Seeking to $topicPartition $offset") + consumer.seek(topicPartition, offset) + } + + private def poll(timeout: Long): Unit = { + val p = consumer.poll(timeout) + val r = p.records(topicPartition) + logDebug(s"Polled ${p.partitions()} ${r.size}") + buffer = r.iterator + } + +} + +private[kafka010] +object CachedKafkaConsumer extends Logging { + + private case class CacheKey(groupId: String, topic: String, partition: Int) + + // Don't want to depend on guava, don't want a cleanup thread, use a simple LinkedHashMap + private var cache: ju.LinkedHashMap[CacheKey, CachedKafkaConsumer[_, _]] = null + + /** Must be called before get, once per JVM, to configure the cache. Further calls are ignored */ + def init( + initialCapacity: Int, + maxCapacity: Int, + loadFactor: Float): Unit = CachedKafkaConsumer.synchronized { + if (null == cache) { + logInfo(s"Initializing cache $initialCapacity $maxCapacity $loadFactor") + cache = new ju.LinkedHashMap[CacheKey, CachedKafkaConsumer[_, _]]( + initialCapacity, loadFactor, true) { + override def removeEldestEntry( + entry: ju.Map.Entry[CacheKey, CachedKafkaConsumer[_, _]]): Boolean = { + if (this.size > maxCapacity) { + try { + entry.getValue.consumer.close() + } catch { + case x: KafkaException => + logError("Error closing oldest Kafka consumer", x) + } + true + } else { + false + } + } + } + } + } + + /** + * Get a cached consumer for groupId, assigned to topic and partition. + * If matching consumer doesn't already exist, will be created using kafkaParams. + */ + def get[K, V]( + groupId: String, + topic: String, + partition: Int, + kafkaParams: ju.Map[String, Object]): CachedKafkaConsumer[K, V] = + CachedKafkaConsumer.synchronized { + val k = CacheKey(groupId, topic, partition) + val v = cache.get(k) + if (null == v) { + logInfo(s"Cache miss for $k") + logDebug(cache.keySet.toString) + val c = new CachedKafkaConsumer[K, V](groupId, topic, partition, kafkaParams) + cache.put(k, c) + c + } else { + // any given topicpartition should have a consistent key and value type + v.asInstanceOf[CachedKafkaConsumer[K, V]] + } + } + + /** + * Get a fresh new instance, unassociated with the global cache. 
+ * Caller is responsible for closing + */ + def getUncached[K, V]( + groupId: String, + topic: String, + partition: Int, + kafkaParams: ju.Map[String, Object]): CachedKafkaConsumer[K, V] = + new CachedKafkaConsumer[K, V](groupId, topic, partition, kafkaParams) + + /** remove consumer for given groupId, topic, and partition, if it exists */ + def remove(groupId: String, topic: String, partition: Int): Unit = { + val k = CacheKey(groupId, topic, partition) + logInfo(s"Removing $k from cache") + val v = CachedKafkaConsumer.synchronized { + cache.remove(k) + } + if (null != v) { + v.close() + logInfo(s"Removed $k from cache") + } + } +} diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/ConsumerStrategy.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/ConsumerStrategy.scala new file mode 100644 index 0000000000000..778c06ea16a2b --- /dev/null +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/ConsumerStrategy.scala @@ -0,0 +1,477 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka010 + +import java.{ lang => jl, util => ju } + +import scala.collection.JavaConverters._ + +import org.apache.kafka.clients.consumer._ +import org.apache.kafka.clients.consumer.internals.NoOpConsumerRebalanceListener +import org.apache.kafka.common.TopicPartition + +import org.apache.spark.annotation.Experimental +import org.apache.spark.internal.Logging + +/** + * :: Experimental :: + * Choice of how to create and configure underlying Kafka Consumers on driver and executors. + * See [[ConsumerStrategies]] to obtain instances. + * Kafka 0.10 consumers can require additional, sometimes complex, setup after object + * instantiation. This interface encapsulates that process, and allows it to be checkpointed. + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + */ +@Experimental +abstract class ConsumerStrategy[K, V] { + /** + * Kafka + * configuration parameters to be used on executors. Requires "bootstrap.servers" to be set + * with Kafka broker(s) specified in host1:port1,host2:port2 form. + */ + def executorKafkaParams: ju.Map[String, Object] + + /** + * Must return a fully configured Kafka Consumer, including subscribed or assigned topics. + * See Kafka docs. + * This consumer will be used on the driver to query for offsets only, not messages. + * The consumer must be returned in a state that it is safe to call poll(0) on. + * @param currentOffsets A map from TopicPartition to offset, indicating how far the driver + * has successfully read. Will be empty on initial start, possibly non-empty on restart from + * checkpoint. 
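`CachedKafkaConsumer` above is keyed by (groupId, topic, partition) and reused across tasks on an executor. A hedged sketch of the call pattern an RDD compute method would follow; the broker address, group id and topic names are placeholders, and the cache sizing values are arbitrary:

```scala
import java.{util => ju}

import org.apache.kafka.clients.consumer.ConsumerConfig
import org.apache.kafka.common.serialization.ByteArrayDeserializer

// Illustrative only: init is a JVM-wide one-off; get reuses a consumer per (group, topic, partition).
val kafkaParams = new ju.HashMap[String, Object]()
kafkaParams.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, "host:port")   // assumed broker
kafkaParams.put(ConsumerConfig.GROUP_ID_CONFIG, "example-group")
kafkaParams.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, classOf[ByteArrayDeserializer].getName)
kafkaParams.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, classOf[ByteArrayDeserializer].getName)

CachedKafkaConsumer.init(initialCapacity = 16, maxCapacity = 64, loadFactor = 0.75f)
val consumer = CachedKafkaConsumer.get[Array[Byte], Array[Byte]](
  "example-group", "example-topic", 0, kafkaParams)

// Sequential forward reads hit the internal buffer; each call waits at most 512 ms for IO.
val record = consumer.get(0L, 512L)
```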
+ */ + def onStart(currentOffsets: ju.Map[TopicPartition, jl.Long]): Consumer[K, V] +} + +/** + * Subscribe to a collection of topics. + * @param topics collection of topics to subscribe + * @param kafkaParams Kafka + * + * configuration parameters to be used on driver. The same params will be used on executors, + * with minor automatic modifications applied. + * Requires "bootstrap.servers" to be set + * with Kafka broker(s) specified in host1:port1,host2:port2 form. + * @param offsets: offsets to begin at on initial startup. If no offset is given for a + * TopicPartition, the committed offset (if applicable) or kafka param + * auto.offset.reset will be used. + */ +private case class Subscribe[K, V]( + topics: ju.Collection[jl.String], + kafkaParams: ju.Map[String, Object], + offsets: ju.Map[TopicPartition, jl.Long] + ) extends ConsumerStrategy[K, V] with Logging { + + def executorKafkaParams: ju.Map[String, Object] = kafkaParams + + def onStart(currentOffsets: ju.Map[TopicPartition, jl.Long]): Consumer[K, V] = { + val consumer = new KafkaConsumer[K, V](kafkaParams) + consumer.subscribe(topics) + val toSeek = if (currentOffsets.isEmpty) { + offsets + } else { + currentOffsets + } + if (!toSeek.isEmpty) { + // work around KAFKA-3370 when reset is none + // poll will throw if no position, i.e. auto offset reset none and no explicit position + // but cant seek to a position before poll, because poll is what gets subscription partitions + // So, poll, suppress the first exception, then seek + val aor = kafkaParams.get(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG) + val shouldSuppress = aor != null && aor.asInstanceOf[String].toUpperCase == "NONE" + try { + consumer.poll(0) + } catch { + case x: NoOffsetForPartitionException if shouldSuppress => + logWarning("Catching NoOffsetForPartitionException since " + + ConsumerConfig.AUTO_OFFSET_RESET_CONFIG + " is none. See KAFKA-3370") + } + toSeek.asScala.foreach { case (topicPartition, offset) => + consumer.seek(topicPartition, offset) + } + // we've called poll, we must pause or next poll may consume messages and set position + consumer.pause(consumer.assignment()) + } + + consumer + } +} + +/** + * Subscribe to all topics matching specified pattern to get dynamically assigned partitions. + * The pattern matching will be done periodically against topics existing at the time of check. + * @param pattern pattern to subscribe to + * @param kafkaParams Kafka + * + * configuration parameters to be used on driver. The same params will be used on executors, + * with minor automatic modifications applied. + * Requires "bootstrap.servers" to be set + * with Kafka broker(s) specified in host1:port1,host2:port2 form. + * @param offsets: offsets to begin at on initial startup. If no offset is given for a + * TopicPartition, the committed offset (if applicable) or kafka param + * auto.offset.reset will be used. 
+ */ +private case class SubscribePattern[K, V]( + pattern: ju.regex.Pattern, + kafkaParams: ju.Map[String, Object], + offsets: ju.Map[TopicPartition, jl.Long] + ) extends ConsumerStrategy[K, V] with Logging { + + def executorKafkaParams: ju.Map[String, Object] = kafkaParams + + def onStart(currentOffsets: ju.Map[TopicPartition, jl.Long]): Consumer[K, V] = { + val consumer = new KafkaConsumer[K, V](kafkaParams) + consumer.subscribe(pattern, new NoOpConsumerRebalanceListener()) + val toSeek = if (currentOffsets.isEmpty) { + offsets + } else { + currentOffsets + } + if (!toSeek.isEmpty) { + // work around KAFKA-3370 when reset is none, see explanation in Subscribe above + val aor = kafkaParams.get(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG) + val shouldSuppress = aor != null && aor.asInstanceOf[String].toUpperCase == "NONE" + try { + consumer.poll(0) + } catch { + case x: NoOffsetForPartitionException if shouldSuppress => + logWarning("Catching NoOffsetForPartitionException since " + + ConsumerConfig.AUTO_OFFSET_RESET_CONFIG + " is none. See KAFKA-3370") + } + toSeek.asScala.foreach { case (topicPartition, offset) => + consumer.seek(topicPartition, offset) + } + // we've called poll, we must pause or next poll may consume messages and set position + consumer.pause(consumer.assignment()) + } + + consumer + } +} + +/** + * Assign a fixed collection of TopicPartitions + * @param topicPartitions collection of TopicPartitions to assign + * @param kafkaParams Kafka + * + * configuration parameters to be used on driver. The same params will be used on executors, + * with minor automatic modifications applied. + * Requires "bootstrap.servers" to be set + * with Kafka broker(s) specified in host1:port1,host2:port2 form. + * @param offsets: offsets to begin at on initial startup. If no offset is given for a + * TopicPartition, the committed offset (if applicable) or kafka param + * auto.offset.reset will be used. + */ +private case class Assign[K, V]( + topicPartitions: ju.Collection[TopicPartition], + kafkaParams: ju.Map[String, Object], + offsets: ju.Map[TopicPartition, jl.Long] + ) extends ConsumerStrategy[K, V] { + + def executorKafkaParams: ju.Map[String, Object] = kafkaParams + + def onStart(currentOffsets: ju.Map[TopicPartition, jl.Long]): Consumer[K, V] = { + val consumer = new KafkaConsumer[K, V](kafkaParams) + consumer.assign(topicPartitions) + val toSeek = if (currentOffsets.isEmpty) { + offsets + } else { + currentOffsets + } + if (!toSeek.isEmpty) { + // this doesn't need a KAFKA-3370 workaround, because partitions are known, no poll needed + toSeek.asScala.foreach { case (topicPartition, offset) => + consumer.seek(topicPartition, offset) + } + } + + consumer + } +} + +/** + * :: Experimental :: + * object for obtaining instances of [[ConsumerStrategy]] + */ +@Experimental +object ConsumerStrategies { + /** + * :: Experimental :: + * Subscribe to a collection of topics. + * @param topics collection of topics to subscribe + * @param kafkaParams Kafka + * + * configuration parameters to be used on driver. The same params will be used on executors, + * with minor automatic modifications applied. + * Requires "bootstrap.servers" to be set + * with Kafka broker(s) specified in host1:port1,host2:port2 form. + * @param offsets: offsets to begin at on initial startup. If no offset is given for a + * TopicPartition, the committed offset (if applicable) or kafka param + * auto.offset.reset will be used. 
+ */ + @Experimental + def Subscribe[K, V]( + topics: Iterable[jl.String], + kafkaParams: collection.Map[String, Object], + offsets: collection.Map[TopicPartition, Long]): ConsumerStrategy[K, V] = { + new Subscribe[K, V]( + new ju.ArrayList(topics.asJavaCollection), + new ju.HashMap[String, Object](kafkaParams.asJava), + new ju.HashMap[TopicPartition, jl.Long](offsets.mapValues(l => new jl.Long(l)).asJava)) + } + + /** + * :: Experimental :: + * Subscribe to a collection of topics. + * @param topics collection of topics to subscribe + * @param kafkaParams Kafka + * + * configuration parameters to be used on driver. The same params will be used on executors, + * with minor automatic modifications applied. + * Requires "bootstrap.servers" to be set + * with Kafka broker(s) specified in host1:port1,host2:port2 form. + */ + @Experimental + def Subscribe[K, V]( + topics: Iterable[jl.String], + kafkaParams: collection.Map[String, Object]): ConsumerStrategy[K, V] = { + new Subscribe[K, V]( + new ju.ArrayList(topics.asJavaCollection), + new ju.HashMap[String, Object](kafkaParams.asJava), + ju.Collections.emptyMap[TopicPartition, jl.Long]()) + } + + /** + * :: Experimental :: + * Subscribe to a collection of topics. + * @param topics collection of topics to subscribe + * @param kafkaParams Kafka + * + * configuration parameters to be used on driver. The same params will be used on executors, + * with minor automatic modifications applied. + * Requires "bootstrap.servers" to be set + * with Kafka broker(s) specified in host1:port1,host2:port2 form. + * @param offsets: offsets to begin at on initial startup. If no offset is given for a + * TopicPartition, the committed offset (if applicable) or kafka param + * auto.offset.reset will be used. + */ + @Experimental + def Subscribe[K, V]( + topics: ju.Collection[jl.String], + kafkaParams: ju.Map[String, Object], + offsets: ju.Map[TopicPartition, jl.Long]): ConsumerStrategy[K, V] = { + new Subscribe[K, V](topics, kafkaParams, offsets) + } + + /** + * :: Experimental :: + * Subscribe to a collection of topics. + * @param topics collection of topics to subscribe + * @param kafkaParams Kafka + * + * configuration parameters to be used on driver. The same params will be used on executors, + * with minor automatic modifications applied. + * Requires "bootstrap.servers" to be set + * with Kafka broker(s) specified in host1:port1,host2:port2 form. + */ + @Experimental + def Subscribe[K, V]( + topics: ju.Collection[jl.String], + kafkaParams: ju.Map[String, Object]): ConsumerStrategy[K, V] = { + new Subscribe[K, V](topics, kafkaParams, ju.Collections.emptyMap[TopicPartition, jl.Long]()) + } + + /** :: Experimental :: + * Subscribe to all topics matching specified pattern to get dynamically assigned partitions. + * The pattern matching will be done periodically against topics existing at the time of check. + * @param pattern pattern to subscribe to + * @param kafkaParams Kafka + * + * configuration parameters to be used on driver. The same params will be used on executors, + * with minor automatic modifications applied. + * Requires "bootstrap.servers" to be set + * with Kafka broker(s) specified in host1:port1,host2:port2 form. + * @param offsets: offsets to begin at on initial startup. If no offset is given for a + * TopicPartition, the committed offset (if applicable) or kafka param + * auto.offset.reset will be used. 
+ */ + @Experimental + def SubscribePattern[K, V]( + pattern: ju.regex.Pattern, + kafkaParams: collection.Map[String, Object], + offsets: collection.Map[TopicPartition, Long]): ConsumerStrategy[K, V] = { + new SubscribePattern[K, V]( + pattern, + new ju.HashMap[String, Object](kafkaParams.asJava), + new ju.HashMap[TopicPartition, jl.Long](offsets.mapValues(l => new jl.Long(l)).asJava)) + } + + /** :: Experimental :: + * Subscribe to all topics matching specified pattern to get dynamically assigned partitions. + * The pattern matching will be done periodically against topics existing at the time of check. + * @param pattern pattern to subscribe to + * @param kafkaParams Kafka + * + * configuration parameters to be used on driver. The same params will be used on executors, + * with minor automatic modifications applied. + * Requires "bootstrap.servers" to be set + * with Kafka broker(s) specified in host1:port1,host2:port2 form. + */ + @Experimental + def SubscribePattern[K, V]( + pattern: ju.regex.Pattern, + kafkaParams: collection.Map[String, Object]): ConsumerStrategy[K, V] = { + new SubscribePattern[K, V]( + pattern, + new ju.HashMap[String, Object](kafkaParams.asJava), + ju.Collections.emptyMap[TopicPartition, jl.Long]()) + } + + /** :: Experimental :: + * Subscribe to all topics matching specified pattern to get dynamically assigned partitions. + * The pattern matching will be done periodically against topics existing at the time of check. + * @param pattern pattern to subscribe to + * @param kafkaParams Kafka + * + * configuration parameters to be used on driver. The same params will be used on executors, + * with minor automatic modifications applied. + * Requires "bootstrap.servers" to be set + * with Kafka broker(s) specified in host1:port1,host2:port2 form. + * @param offsets: offsets to begin at on initial startup. If no offset is given for a + * TopicPartition, the committed offset (if applicable) or kafka param + * auto.offset.reset will be used. + */ + @Experimental + def SubscribePattern[K, V]( + pattern: ju.regex.Pattern, + kafkaParams: ju.Map[String, Object], + offsets: ju.Map[TopicPartition, jl.Long]): ConsumerStrategy[K, V] = { + new SubscribePattern[K, V](pattern, kafkaParams, offsets) + } + + /** :: Experimental :: + * Subscribe to all topics matching specified pattern to get dynamically assigned partitions. + * The pattern matching will be done periodically against topics existing at the time of check. + * @param pattern pattern to subscribe to + * @param kafkaParams Kafka + * + * configuration parameters to be used on driver. The same params will be used on executors, + * with minor automatic modifications applied. + * Requires "bootstrap.servers" to be set + * with Kafka broker(s) specified in host1:port1,host2:port2 form. + */ + @Experimental + def SubscribePattern[K, V]( + pattern: ju.regex.Pattern, + kafkaParams: ju.Map[String, Object]): ConsumerStrategy[K, V] = { + new SubscribePattern[K, V]( + pattern, + kafkaParams, + ju.Collections.emptyMap[TopicPartition, jl.Long]()) + } + + /** + * :: Experimental :: + * Assign a fixed collection of TopicPartitions + * @param topicPartitions collection of TopicPartitions to assign + * @param kafkaParams Kafka + * + * configuration parameters to be used on driver. The same params will be used on executors, + * with minor automatic modifications applied. + * Requires "bootstrap.servers" to be set + * with Kafka broker(s) specified in host1:port1,host2:port2 form. + * @param offsets: offsets to begin at on initial startup. 
If no offset is given for a + * TopicPartition, the committed offset (if applicable) or kafka param + * auto.offset.reset will be used. + */ + @Experimental + def Assign[K, V]( + topicPartitions: Iterable[TopicPartition], + kafkaParams: collection.Map[String, Object], + offsets: collection.Map[TopicPartition, Long]): ConsumerStrategy[K, V] = { + new Assign[K, V]( + new ju.ArrayList(topicPartitions.asJavaCollection), + new ju.HashMap[String, Object](kafkaParams.asJava), + new ju.HashMap[TopicPartition, jl.Long](offsets.mapValues(l => new jl.Long(l)).asJava)) + } + + /** + * :: Experimental :: + * Assign a fixed collection of TopicPartitions + * @param topicPartitions collection of TopicPartitions to assign + * @param kafkaParams Kafka + * + * configuration parameters to be used on driver. The same params will be used on executors, + * with minor automatic modifications applied. + * Requires "bootstrap.servers" to be set + * with Kafka broker(s) specified in host1:port1,host2:port2 form. + */ + @Experimental + def Assign[K, V]( + topicPartitions: Iterable[TopicPartition], + kafkaParams: collection.Map[String, Object]): ConsumerStrategy[K, V] = { + new Assign[K, V]( + new ju.ArrayList(topicPartitions.asJavaCollection), + new ju.HashMap[String, Object](kafkaParams.asJava), + ju.Collections.emptyMap[TopicPartition, jl.Long]()) + } + + /** + * :: Experimental :: + * Assign a fixed collection of TopicPartitions + * @param topicPartitions collection of TopicPartitions to assign + * @param kafkaParams Kafka + * + * configuration parameters to be used on driver. The same params will be used on executors, + * with minor automatic modifications applied. + * Requires "bootstrap.servers" to be set + * with Kafka broker(s) specified in host1:port1,host2:port2 form. + * @param offsets: offsets to begin at on initial startup. If no offset is given for a + * TopicPartition, the committed offset (if applicable) or kafka param + * auto.offset.reset will be used. + */ + @Experimental + def Assign[K, V]( + topicPartitions: ju.Collection[TopicPartition], + kafkaParams: ju.Map[String, Object], + offsets: ju.Map[TopicPartition, jl.Long]): ConsumerStrategy[K, V] = { + new Assign[K, V](topicPartitions, kafkaParams, offsets) + } + + /** + * :: Experimental :: + * Assign a fixed collection of TopicPartitions + * @param topicPartitions collection of TopicPartitions to assign + * @param kafkaParams Kafka + * + * configuration parameters to be used on driver. The same params will be used on executors, + * with minor automatic modifications applied. + * Requires "bootstrap.servers" to be set + * with Kafka broker(s) specified in host1:port1,host2:port2 form. + */ + @Experimental + def Assign[K, V]( + topicPartitions: ju.Collection[TopicPartition], + kafkaParams: ju.Map[String, Object]): ConsumerStrategy[K, V] = { + new Assign[K, V]( + topicPartitions, + kafkaParams, + ju.Collections.emptyMap[TopicPartition, jl.Long]()) + } + +} diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala new file mode 100644 index 0000000000000..7e57bb18cbd50 --- /dev/null +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala @@ -0,0 +1,337 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka010 + +import java.{ util => ju } +import java.util.concurrent.ConcurrentLinkedQueue +import java.util.concurrent.atomic.AtomicReference + +import scala.annotation.tailrec +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.apache.kafka.clients.consumer._ +import org.apache.kafka.common.{ PartitionInfo, TopicPartition } + +import org.apache.spark.SparkException +import org.apache.spark.internal.Logging +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.{StreamingContext, Time} +import org.apache.spark.streaming.dstream._ +import org.apache.spark.streaming.scheduler.{RateController, StreamInputInfo} +import org.apache.spark.streaming.scheduler.rate.RateEstimator + +/** + * A DStream where + * each given Kafka topic/partition corresponds to an RDD partition. + * The spark configuration spark.streaming.kafka.maxRatePerPartition gives the maximum number + * of messages + * per second that each '''partition''' will accept. + * @param locationStrategy In most cases, pass in [[PreferConsistent]], + * see [[LocationStrategy]] for more details. + * @param executorKafkaParams Kafka + * + * configuration parameters. + * Requires "bootstrap.servers" to be set with Kafka broker(s), + * NOT zookeeper servers, specified in host1:port1,host2:port2 form. + * @param consumerStrategy In most cases, pass in [[Subscribe]], + * see [[ConsumerStrategy]] for more details + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + */ +private[spark] class DirectKafkaInputDStream[K, V]( + _ssc: StreamingContext, + locationStrategy: LocationStrategy, + consumerStrategy: ConsumerStrategy[K, V] + ) extends InputDStream[ConsumerRecord[K, V]](_ssc) with Logging with CanCommitOffsets { + + val executorKafkaParams = { + val ekp = new ju.HashMap[String, Object](consumerStrategy.executorKafkaParams) + KafkaUtils.fixKafkaParams(ekp) + ekp + } + + protected var currentOffsets = Map[TopicPartition, Long]() + + @transient private var kc: Consumer[K, V] = null + def consumer(): Consumer[K, V] = this.synchronized { + if (null == kc) { + kc = consumerStrategy.onStart(currentOffsets.mapValues(l => new java.lang.Long(l)).asJava) + } + kc + } + + override def persist(newLevel: StorageLevel): DStream[ConsumerRecord[K, V]] = { + logError("Kafka ConsumerRecord is not serializable. 
" + + "Use .map to extract fields before calling .persist or .window") + super.persist(newLevel) + } + + protected def getBrokers = { + val c = consumer + val result = new ju.HashMap[TopicPartition, String]() + val hosts = new ju.HashMap[TopicPartition, String]() + val assignments = c.assignment().iterator() + while (assignments.hasNext()) { + val tp: TopicPartition = assignments.next() + if (null == hosts.get(tp)) { + val infos = c.partitionsFor(tp.topic).iterator() + while (infos.hasNext()) { + val i = infos.next() + hosts.put(new TopicPartition(i.topic(), i.partition()), i.leader.host()) + } + } + result.put(tp, hosts.get(tp)) + } + result + } + + protected def getPreferredHosts: ju.Map[TopicPartition, String] = { + locationStrategy match { + case PreferBrokers => getBrokers + case PreferConsistent => ju.Collections.emptyMap[TopicPartition, String]() + case PreferFixed(hostMap) => hostMap + } + } + + // Keep this consistent with how other streams are named (e.g. "Flume polling stream [2]") + private[streaming] override def name: String = s"Kafka 0.10 direct stream [$id]" + + protected[streaming] override val checkpointData = + new DirectKafkaInputDStreamCheckpointData + + + /** + * Asynchronously maintains & sends new rate limits to the receiver through the receiver tracker. + */ + override protected[streaming] val rateController: Option[RateController] = { + if (RateController.isBackPressureEnabled(ssc.conf)) { + Some(new DirectKafkaRateController(id, + RateEstimator.create(ssc.conf, context.graph.batchDuration))) + } else { + None + } + } + + private val maxRateLimitPerPartition: Int = context.sparkContext.getConf.getInt( + "spark.streaming.kafka.maxRatePerPartition", 0) + + protected[streaming] def maxMessagesPerPartition( + offsets: Map[TopicPartition, Long]): Option[Map[TopicPartition, Long]] = { + val estimatedRateLimit = rateController.map(_.getLatestRate().toInt) + + // calculate a per-partition rate limit based on current lag + val effectiveRateLimitPerPartition = estimatedRateLimit.filter(_ > 0) match { + case Some(rate) => + val lagPerPartition = offsets.map { case (tp, offset) => + tp -> Math.max(offset - currentOffsets(tp), 0) + } + val totalLag = lagPerPartition.values.sum + + lagPerPartition.map { case (tp, lag) => + val backpressureRate = Math.round(lag / totalLag.toFloat * rate) + tp -> (if (maxRateLimitPerPartition > 0) { + Math.min(backpressureRate, maxRateLimitPerPartition)} else backpressureRate) + } + case None => offsets.map { case (tp, offset) => tp -> maxRateLimitPerPartition } + } + + if (effectiveRateLimitPerPartition.values.sum > 0) { + val secsPerBatch = context.graph.batchDuration.milliseconds.toDouble / 1000 + Some(effectiveRateLimitPerPartition.map { + case (tp, limit) => tp -> (secsPerBatch * limit).toLong + }) + } else { + None + } + } + + /** + * The concern here is that poll might consume messages despite being paused, + * which would throw off consumer position. Fix position if this happens. 
+ */ + private def paranoidPoll(c: Consumer[K, V]): Unit = { + val msgs = c.poll(0) + if (!msgs.isEmpty) { + // position should be minimum offset per topicpartition + msgs.asScala.foldLeft(Map[TopicPartition, Long]()) { (acc, m) => + val tp = new TopicPartition(m.topic, m.partition) + val off = acc.get(tp).map(o => Math.min(o, m.offset)).getOrElse(m.offset) + acc + (tp -> off) + }.foreach { case (tp, off) => + logInfo(s"poll(0) returned messages, seeking $tp to $off to compensate") + c.seek(tp, off) + } + } + } + + /** + * Returns the latest (highest) available offsets, taking new partitions into account. + */ + protected def latestOffsets(): Map[TopicPartition, Long] = { + val c = consumer + paranoidPoll(c) + val parts = c.assignment().asScala + + // make sure new partitions are reflected in currentOffsets + val newPartitions = parts.diff(currentOffsets.keySet) + // position for new partitions determined by auto.offset.reset if no commit + currentOffsets = currentOffsets ++ newPartitions.map(tp => tp -> c.position(tp)).toMap + // don't want to consume messages, so pause + c.pause(newPartitions.asJava) + // find latest available offsets + c.seekToEnd(currentOffsets.keySet.asJava) + parts.map(tp => tp -> c.position(tp)).toMap + } + + // limits the maximum number of messages per partition + protected def clamp( + offsets: Map[TopicPartition, Long]): Map[TopicPartition, Long] = { + + maxMessagesPerPartition(offsets).map { mmp => + mmp.map { case (tp, messages) => + val uo = offsets(tp) + tp -> Math.min(currentOffsets(tp) + messages, uo) + } + }.getOrElse(offsets) + } + + override def compute(validTime: Time): Option[KafkaRDD[K, V]] = { + val untilOffsets = clamp(latestOffsets()) + val offsetRanges = untilOffsets.map { case (tp, uo) => + val fo = currentOffsets(tp) + OffsetRange(tp.topic, tp.partition, fo, uo) + } + val rdd = new KafkaRDD[K, V]( + context.sparkContext, executorKafkaParams, offsetRanges.toArray, getPreferredHosts, true) + + // Report the record number and metadata of this batch interval to InputInfoTracker. + val description = offsetRanges.filter { offsetRange => + // Don't display empty ranges. + offsetRange.fromOffset != offsetRange.untilOffset + }.map { offsetRange => + s"topic: ${offsetRange.topic}\tpartition: ${offsetRange.partition}\t" + + s"offsets: ${offsetRange.fromOffset} to ${offsetRange.untilOffset}" + }.mkString("\n") + // Copy offsetRanges to immutable.List to prevent from being modified by the user + val metadata = Map( + "offsets" -> offsetRanges.toList, + StreamInputInfo.METADATA_KEY_DESCRIPTION -> description) + val inputInfo = StreamInputInfo(id, rdd.count, metadata) + ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo) + + currentOffsets = untilOffsets + commitAll() + Some(rdd) + } + + override def start(): Unit = { + val c = consumer + paranoidPoll(c) + if (currentOffsets.isEmpty) { + currentOffsets = c.assignment().asScala.map { tp => + tp -> c.position(tp) + }.toMap + } + + // don't actually want to consume any messages, so pause all partitions + c.pause(currentOffsets.keySet.asJava) + } + + override def stop(): Unit = this.synchronized { + if (kc != null) { + kc.close() + } + } + + protected val commitQueue = new ConcurrentLinkedQueue[OffsetRange] + protected val commitCallback = new AtomicReference[OffsetCommitCallback] + + /** + * Queue up offset ranges for commit to Kafka at a future time. Threadsafe. + * @param offsetRanges The maximum untilOffset for a given partition will be used at commit. 
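A hedged sketch of how the commitAsync hook above is typically driven from user code; `stream` is assumed to be the InputDStream returned by KafkaUtils.createDirectStream, and the per-batch processing step is elided.

```scala
import org.apache.kafka.clients.consumer.ConsumerRecord
import org.apache.spark.streaming.dstream.InputDStream
import org.apache.spark.streaming.kafka010.{CanCommitOffsets, HasOffsetRanges}

// Queue the offsets of each processed batch for commit back to Kafka.
def commitAfterEachBatch(stream: InputDStream[ConsumerRecord[String, String]]): Unit = {
  stream.foreachRDD { rdd =>
    val offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges
    // ... process the records here ...
    // This only enqueues the ranges; the driver issues the actual commit later.
    stream.asInstanceOf[CanCommitOffsets].commitAsync(offsetRanges)
  }
}
```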
+ */ + def commitAsync(offsetRanges: Array[OffsetRange]): Unit = { + commitAsync(offsetRanges, null) + } + + /** + * Queue up offset ranges for commit to Kafka at a future time. Threadsafe. + * @param offsetRanges The maximum untilOffset for a given partition will be used at commit. + * @param callback Only the most recently provided callback will be used at commit. + */ + def commitAsync(offsetRanges: Array[OffsetRange], callback: OffsetCommitCallback): Unit = { + commitCallback.set(callback) + commitQueue.addAll(ju.Arrays.asList(offsetRanges: _*)) + } + + protected def commitAll(): Unit = { + val m = new ju.HashMap[TopicPartition, OffsetAndMetadata]() + var osr = commitQueue.poll() + while (null != osr) { + val tp = osr.topicPartition + val x = m.get(tp) + val offset = if (null == x) { osr.untilOffset } else { Math.max(x.offset, osr.untilOffset) } + m.put(tp, new OffsetAndMetadata(offset)) + osr = commitQueue.poll() + } + if (!m.isEmpty) { + consumer.commitAsync(m, commitCallback.get) + } + } + + private[streaming] + class DirectKafkaInputDStreamCheckpointData extends DStreamCheckpointData(this) { + def batchForTime: mutable.HashMap[Time, Array[(String, Int, Long, Long)]] = { + data.asInstanceOf[mutable.HashMap[Time, Array[OffsetRange.OffsetRangeTuple]]] + } + + override def update(time: Time): Unit = { + batchForTime.clear() + generatedRDDs.foreach { kv => + val a = kv._2.asInstanceOf[KafkaRDD[K, V]].offsetRanges.map(_.toTuple).toArray + batchForTime += kv._1 -> a + } + } + + override def cleanup(time: Time): Unit = { } + + override def restore(): Unit = { + batchForTime.toSeq.sortBy(_._1)(Time.ordering).foreach { case (t, b) => + logInfo(s"Restoring KafkaRDD for time $t ${b.mkString("[", ", ", "]")}") + generatedRDDs += t -> new KafkaRDD[K, V]( + context.sparkContext, + executorKafkaParams, + b.map(OffsetRange(_)), + getPreferredHosts, + // during restore, it's possible same partition will be consumed from multiple + // threads, so dont use cache + false + ) + } + } + } + + /** + * A RateController to retrieve the rate from RateEstimator. + */ + private[streaming] class DirectKafkaRateController(id: Int, estimator: RateEstimator) + extends RateController(id, estimator) { + override def publish(rate: Long): Unit = () + } +} diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala new file mode 100644 index 0000000000000..5b5a9ac48c7ca --- /dev/null +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala @@ -0,0 +1,232 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
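Since the checkpoint data above stores only (topic, partition, from, until) tuples, recovering the stream is the usual StreamingContext checkpoint pattern; a sketch with a placeholder checkpoint directory:

```scala
import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}

// Placeholder path; production jobs should point this at a fault-tolerant store such
// as HDFS. On restart, getOrCreate rebuilds the KafkaRDDs from the offset ranges
// saved by DirectKafkaInputDStreamCheckpointData.
val checkpointDir = "/tmp/kafka-direct-checkpoint"

def createContext(): StreamingContext = {
  val conf = new SparkConf().setAppName("kafka-direct-checkpointed")
  val ssc = new StreamingContext(conf, Seconds(2))
  ssc.checkpoint(checkpointDir)
  // ... create the direct stream and output operations here ...
  ssc
}

val ssc = StreamingContext.getOrCreate(checkpointDir, createContext _)
```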
+ */ + +package org.apache.spark.streaming.kafka010 + +import java.{ util => ju } + +import scala.collection.mutable.ArrayBuffer + +import org.apache.kafka.clients.consumer.{ ConsumerConfig, ConsumerRecord } +import org.apache.kafka.common.TopicPartition + +import org.apache.spark.{Partition, SparkContext, SparkException, TaskContext} +import org.apache.spark.internal.Logging +import org.apache.spark.partial.{BoundedDouble, PartialResult} +import org.apache.spark.rdd.RDD +import org.apache.spark.scheduler.ExecutorCacheTaskLocation +import org.apache.spark.storage.StorageLevel + +/** + * A batch-oriented interface for consuming from Kafka. + * Starting and ending offsets are specified in advance, + * so that you can control exactly-once semantics. + * @param kafkaParams Kafka + * + * configuration parameters. Requires "bootstrap.servers" to be set + * with Kafka broker(s) specified in host1:port1,host2:port2 form. + * @param offsetRanges offset ranges that define the Kafka data belonging to this RDD + * @param preferredHosts map from TopicPartition to preferred host for processing that partition. + * In most cases, use [[DirectKafkaInputDStream.preferConsistent]] + * Use [[DirectKafkaInputDStream.preferBrokers]] if your executors are on same nodes as brokers. + * @param useConsumerCache whether to use a consumer from a per-jvm cache + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + */ +private[spark] class KafkaRDD[K, V]( + sc: SparkContext, + val kafkaParams: ju.Map[String, Object], + val offsetRanges: Array[OffsetRange], + val preferredHosts: ju.Map[TopicPartition, String], + useConsumerCache: Boolean +) extends RDD[ConsumerRecord[K, V]](sc, Nil) with Logging with HasOffsetRanges { + + assert("none" == + kafkaParams.get(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG).asInstanceOf[String], + ConsumerConfig.AUTO_OFFSET_RESET_CONFIG + + " must be set to none for executor kafka params, else messages may not match offsetRange") + + assert(false == + kafkaParams.get(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG).asInstanceOf[Boolean], + ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG + + " must be set to false for executor kafka params, else offsets may commit before processing") + + // TODO is it necessary to have separate configs for initial poll time vs ongoing poll time? + private val pollTimeout = conf.getLong("spark.streaming.kafka.consumer.poll.ms", 512) + private val cacheInitialCapacity = + conf.getInt("spark.streaming.kafka.consumer.cache.initialCapacity", 16) + private val cacheMaxCapacity = + conf.getInt("spark.streaming.kafka.consumer.cache.maxCapacity", 64) + private val cacheLoadFactor = + conf.getDouble("spark.streaming.kafka.consumer.cache.loadFactor", 0.75).toFloat + + override def persist(newLevel: StorageLevel): this.type = { + logError("Kafka ConsumerRecord is not serializable. 
" + + "Use .map to extract fields before calling .persist or .window") + super.persist(newLevel) + } + + override def getPartitions: Array[Partition] = { + offsetRanges.zipWithIndex.map { case (o, i) => + new KafkaRDDPartition(i, o.topic, o.partition, o.fromOffset, o.untilOffset) + }.toArray + } + + override def count(): Long = offsetRanges.map(_.count).sum + + override def countApprox( + timeout: Long, + confidence: Double = 0.95 + ): PartialResult[BoundedDouble] = { + val c = count + new PartialResult(new BoundedDouble(c, 1.0, c, c), true) + } + + override def isEmpty(): Boolean = count == 0L + + override def take(num: Int): Array[ConsumerRecord[K, V]] = { + val nonEmptyPartitions = this.partitions + .map(_.asInstanceOf[KafkaRDDPartition]) + .filter(_.count > 0) + + if (num < 1 || nonEmptyPartitions.isEmpty) { + return new Array[ConsumerRecord[K, V]](0) + } + + // Determine in advance how many messages need to be taken from each partition + val parts = nonEmptyPartitions.foldLeft(Map[Int, Int]()) { (result, part) => + val remain = num - result.values.sum + if (remain > 0) { + val taken = Math.min(remain, part.count) + result + (part.index -> taken.toInt) + } else { + result + } + } + + val buf = new ArrayBuffer[ConsumerRecord[K, V]] + val res = context.runJob( + this, + (tc: TaskContext, it: Iterator[ConsumerRecord[K, V]]) => + it.take(parts(tc.partitionId)).toArray, parts.keys.toArray + ) + res.foreach(buf ++= _) + buf.toArray + } + + private def executors(): Array[ExecutorCacheTaskLocation] = { + val bm = sparkContext.env.blockManager + bm.master.getPeers(bm.blockManagerId).toArray + .map(x => ExecutorCacheTaskLocation(x.host, x.executorId)) + .sortWith(compareExecutors) + } + + protected[kafka010] def compareExecutors( + a: ExecutorCacheTaskLocation, + b: ExecutorCacheTaskLocation): Boolean = + if (a.host == b.host) { + a.executorId > b.executorId + } else { + a.host > b.host + } + + /** + * Non-negative modulus, from java 8 math + */ + private def floorMod(a: Int, b: Int): Int = ((a % b) + b) % b + + override def getPreferredLocations(thePart: Partition): Seq[String] = { + // The intention is best-effort consistent executor for a given topicpartition, + // so that caching consumers can be effective. + // TODO what about hosts specified by ip vs name + val part = thePart.asInstanceOf[KafkaRDDPartition] + val allExecs = executors() + val tp = part.topicPartition + val prefHost = preferredHosts.get(tp) + val prefExecs = if (null == prefHost) allExecs else allExecs.filter(_.host == prefHost) + val execs = if (prefExecs.isEmpty) allExecs else prefExecs + if (execs.isEmpty) { + Seq() + } else { + // execs is sorted, tp.hashCode depends only on topic and partition, so consistent index + val index = this.floorMod(tp.hashCode, execs.length) + val chosen = execs(index) + Seq(chosen.toString) + } + } + + private def errBeginAfterEnd(part: KafkaRDDPartition): String = + s"Beginning offset ${part.fromOffset} is after the ending offset ${part.untilOffset} " + + s"for topic ${part.topic} partition ${part.partition}. 
" + + "You either provided an invalid fromOffset, or the Kafka topic has been damaged" + + override def compute(thePart: Partition, context: TaskContext): Iterator[ConsumerRecord[K, V]] = { + val part = thePart.asInstanceOf[KafkaRDDPartition] + assert(part.fromOffset <= part.untilOffset, errBeginAfterEnd(part)) + if (part.fromOffset == part.untilOffset) { + logInfo(s"Beginning offset ${part.fromOffset} is the same as ending offset " + + s"skipping ${part.topic} ${part.partition}") + Iterator.empty + } else { + new KafkaRDDIterator(part, context) + } + } + + /** + * An iterator that fetches messages directly from Kafka for the offsets in partition. + * Uses a cached consumer where possible to take advantage of prefetching + */ + private class KafkaRDDIterator( + part: KafkaRDDPartition, + context: TaskContext) extends Iterator[ConsumerRecord[K, V]] { + + logInfo(s"Computing topic ${part.topic}, partition ${part.partition} " + + s"offsets ${part.fromOffset} -> ${part.untilOffset}") + + val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] + + context.addTaskCompletionListener{ context => closeIfNeeded() } + + val consumer = if (useConsumerCache) { + CachedKafkaConsumer.init(cacheInitialCapacity, cacheMaxCapacity, cacheLoadFactor) + if (context.attemptNumber > 1) { + // just in case the prior attempt failures were cache related + CachedKafkaConsumer.remove(groupId, part.topic, part.partition) + } + CachedKafkaConsumer.get[K, V](groupId, part.topic, part.partition, kafkaParams) + } else { + CachedKafkaConsumer.getUncached[K, V](groupId, part.topic, part.partition, kafkaParams) + } + + var requestOffset = part.fromOffset + + def closeIfNeeded(): Unit = { + if (!useConsumerCache && consumer != null) { + consumer.close + } + } + + override def hasNext(): Boolean = requestOffset < part.untilOffset + + override def next(): ConsumerRecord[K, V] = { + assert(hasNext(), "Can't call getNext() once untilOffset has been reached") + val r = consumer.get(requestOffset, pollTimeout) + requestOffset += 1 + r + } + } +} diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDDPartition.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDDPartition.scala new file mode 100644 index 0000000000000..95569b109f30d --- /dev/null +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDDPartition.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.kafka010 + +import org.apache.kafka.common.TopicPartition + +import org.apache.spark.Partition + + +/** + * @param topic kafka topic name + * @param partition kafka partition id + * @param fromOffset inclusive starting offset + * @param untilOffset exclusive ending offset + */ +private[kafka010] +class KafkaRDDPartition( + val index: Int, + val topic: String, + val partition: Int, + val fromOffset: Long, + val untilOffset: Long +) extends Partition { + /** Number of messages this partition refers to */ + def count(): Long = untilOffset - fromOffset + + /** Kafka TopicPartition object, for convenience */ + def topicPartition(): TopicPartition = new TopicPartition(topic, partition) + +} diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala new file mode 100644 index 0000000000000..ecabe1c365b41 --- /dev/null +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala @@ -0,0 +1,282 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka010 + +import java.io.File +import java.lang.{Integer => JInt} +import java.net.InetSocketAddress +import java.util.{Map => JMap, Properties} +import java.util.concurrent.TimeoutException + +import scala.annotation.tailrec +import scala.collection.JavaConverters._ +import scala.language.postfixOps +import scala.util.control.NonFatal + +import kafka.admin.AdminUtils +import kafka.api.Request +import kafka.server.{KafkaConfig, KafkaServer} +import kafka.utils.ZkUtils +import org.apache.kafka.clients.producer.{KafkaProducer, ProducerRecord} +import org.apache.kafka.common.serialization.StringSerializer +import org.apache.zookeeper.server.{NIOServerCnxnFactory, ZooKeeperServer} + +import org.apache.spark.SparkConf +import org.apache.spark.internal.Logging +import org.apache.spark.streaming.Time +import org.apache.spark.util.Utils + +/** + * This is a helper class for Kafka test suites. This has the functionality to set up + * and tear down local Kafka servers, and to push data using Kafka producers. + * + * The reason to put Kafka test utility class in src is to test Python related Kafka APIs. 
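A rough sketch of how a suite might drive the helper defined below; because the class is private[kafka010], the caller must live in org.apache.spark.streaming.kafka010, and the topic name and messages here are placeholders.

```scala
// package org.apache.spark.streaming.kafka010  <- required, the helper is package-private

val testUtils = new KafkaTestUtils
testUtils.setup()
try {
  testUtils.createTopic("test-topic", 2)
  testUtils.sendMessages("test-topic", Array("a", "b", "c"))
  // ... exercise a KafkaRDD or direct stream against testUtils.brokerAddress ...
} finally {
  testUtils.teardown()
}
```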
+ */ +private[kafka010] class KafkaTestUtils extends Logging { + + // Zookeeper related configurations + private val zkHost = "localhost" + private var zkPort: Int = 0 + private val zkConnectionTimeout = 60000 + private val zkSessionTimeout = 6000 + + private var zookeeper: EmbeddedZookeeper = _ + + private var zkUtils: ZkUtils = _ + + // Kafka broker related configurations + private val brokerHost = "localhost" + private var brokerPort = 0 + private var brokerConf: KafkaConfig = _ + + // Kafka broker server + private var server: KafkaServer = _ + + // Kafka producer + private var producer: KafkaProducer[String, String] = _ + + // Flag to test whether the system is correctly started + private var zkReady = false + private var brokerReady = false + + def zkAddress: String = { + assert(zkReady, "Zookeeper not setup yet or already torn down, cannot get zookeeper address") + s"$zkHost:$zkPort" + } + + def brokerAddress: String = { + assert(brokerReady, "Kafka not setup yet or already torn down, cannot get broker address") + s"$brokerHost:$brokerPort" + } + + def zookeeperClient: ZkUtils = { + assert(zkReady, "Zookeeper not setup yet or already torn down, cannot get zookeeper client") + Option(zkUtils).getOrElse( + throw new IllegalStateException("Zookeeper client is not yet initialized")) + } + + // Set up the Embedded Zookeeper server and get the proper Zookeeper port + private def setupEmbeddedZookeeper(): Unit = { + // Zookeeper server startup + zookeeper = new EmbeddedZookeeper(s"$zkHost:$zkPort") + // Get the actual zookeeper binding port + zkPort = zookeeper.actualPort + zkUtils = ZkUtils(s"$zkHost:$zkPort", zkSessionTimeout, zkConnectionTimeout, false) + zkReady = true + } + + // Set up the Embedded Kafka server + private def setupEmbeddedKafkaServer(): Unit = { + assert(zkReady, "Zookeeper should be set up beforehand") + + // Kafka broker startup + Utils.startServiceOnPort(brokerPort, port => { + brokerPort = port + brokerConf = new KafkaConfig(brokerConfiguration, doLog = false) + server = new KafkaServer(brokerConf) + server.startup() + brokerPort = server.boundPort() + (server, brokerPort) + }, new SparkConf(), "KafkaBroker") + + brokerReady = true + } + + /** setup the whole embedded servers, including Zookeeper and Kafka brokers */ + def setup(): Unit = { + setupEmbeddedZookeeper() + setupEmbeddedKafkaServer() + } + + /** Teardown the whole servers, including Kafka broker and Zookeeper */ + def teardown(): Unit = { + brokerReady = false + zkReady = false + + if (producer != null) { + producer.close() + producer = null + } + + if (server != null) { + server.shutdown() + server = null + } + + brokerConf.logDirs.foreach { f => Utils.deleteRecursively(new File(f)) } + + if (zkUtils != null) { + zkUtils.close() + zkUtils = null + } + + if (zookeeper != null) { + zookeeper.shutdown() + zookeeper = null + } + } + + /** Create a Kafka topic and wait until it is propagated to the whole cluster */ + def createTopic(topic: String, partitions: Int): Unit = { + AdminUtils.createTopic(zkUtils, topic, partitions, 1) + // wait until metadata is propagated + (0 until partitions).foreach { p => + waitUntilMetadataIsPropagated(topic, p) + } + } + + /** Create a Kafka topic and wait until it is propagated to the whole cluster */ + def createTopic(topic: String): Unit = { + createTopic(topic, 1) + } + + /** Java-friendly function for sending messages to the Kafka broker */ + def sendMessages(topic: String, messageToFreq: JMap[String, JInt]): Unit = { + sendMessages(topic, 
Map(messageToFreq.asScala.mapValues(_.intValue()).toSeq: _*)) + } + + /** Send the messages to the Kafka broker */ + def sendMessages(topic: String, messageToFreq: Map[String, Int]): Unit = { + val messages = messageToFreq.flatMap { case (s, freq) => Seq.fill(freq)(s) }.toArray + sendMessages(topic, messages) + } + + /** Send the array of messages to the Kafka broker */ + def sendMessages(topic: String, messages: Array[String]): Unit = { + producer = new KafkaProducer[String, String](producerConfiguration) + messages.foreach { message => + producer.send(new ProducerRecord[String, String](topic, message)) + } + producer.close() + producer = null + } + + private def brokerConfiguration: Properties = { + val props = new Properties() + props.put("broker.id", "0") + props.put("host.name", "localhost") + props.put("port", brokerPort.toString) + props.put("log.dir", Utils.createTempDir().getAbsolutePath) + props.put("zookeeper.connect", zkAddress) + props.put("log.flush.interval.messages", "1") + props.put("replica.socket.timeout.ms", "1500") + props + } + + private def producerConfiguration: Properties = { + val props = new Properties() + props.put("bootstrap.servers", brokerAddress) + props.put("value.serializer", classOf[StringSerializer].getName) + // Key serializer is required. + props.put("key.serializer", classOf[StringSerializer].getName) + // wait for all in-sync replicas to ack sends + props.put("acks", "all") + props + } + + // A simplified version of scalatest eventually, rewritten here to avoid adding extra test + // dependency + def eventually[T](timeout: Time, interval: Time)(func: => T): T = { + def makeAttempt(): Either[Throwable, T] = { + try { + Right(func) + } catch { + case e if NonFatal(e) => Left(e) + } + } + + val startTime = System.currentTimeMillis() + @tailrec + def tryAgain(attempt: Int): T = { + makeAttempt() match { + case Right(result) => result + case Left(e) => + val duration = System.currentTimeMillis() - startTime + if (duration < timeout.milliseconds) { + Thread.sleep(interval.milliseconds) + } else { + throw new TimeoutException(e.getMessage) + } + + tryAgain(attempt + 1) + } + } + + tryAgain(1) + } + + private def waitUntilMetadataIsPropagated(topic: String, partition: Int): Unit = { + def isPropagated = server.apis.metadataCache.getPartitionInfo(topic, partition) match { + case Some(partitionState) => + val leaderAndInSyncReplicas = partitionState.leaderIsrAndControllerEpoch.leaderAndIsr + + zkUtils.getLeaderForPartition(topic, partition).isDefined && + Request.isValidBrokerId(leaderAndInSyncReplicas.leader) && + leaderAndInSyncReplicas.isr.size >= 1 + + case _ => + false + } + eventually(Time(10000), Time(100)) { + assert(isPropagated, s"Partition [$topic, $partition] metadata not propagated after timeout") + } + } + + private class EmbeddedZookeeper(val zkConnect: String) { + val snapshotDir = Utils.createTempDir() + val logDir = Utils.createTempDir() + + val zookeeper = new ZooKeeperServer(snapshotDir, logDir, 500) + val (ip, port) = { + val splits = zkConnect.split(":") + (splits(0), splits(1).toInt) + } + val factory = new NIOServerCnxnFactory() + factory.configure(new InetSocketAddress(ip, port), 16) + factory.startup(zookeeper) + + val actualPort = factory.getLocalPort + + def shutdown() { + factory.shutdown() + Utils.deleteRecursively(snapshotDir) + Utils.deleteRecursively(logDir) + } + } +} + diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaUtils.scala 
b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaUtils.scala new file mode 100644 index 0000000000000..b2190bfa05a3a --- /dev/null +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaUtils.scala @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka010 + +import java.{ util => ju } + +import org.apache.kafka.clients.consumer._ +import org.apache.kafka.common.TopicPartition + +import org.apache.spark.SparkContext +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.{ JavaRDD, JavaSparkContext } +import org.apache.spark.api.java.function.{ Function0 => JFunction0 } +import org.apache.spark.internal.Logging +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.StreamingContext +import org.apache.spark.streaming.api.java.{ JavaInputDStream, JavaStreamingContext } +import org.apache.spark.streaming.dstream._ + +/** + * :: Experimental :: + * object for constructing Kafka streams and RDDs + */ +@Experimental +object KafkaUtils extends Logging { + /** + * :: Experimental :: + * Scala constructor for a batch-oriented interface for consuming from Kafka. + * Starting and ending offsets are specified in advance, + * so that you can control exactly-once semantics. + * @param kafkaParams Kafka + * + * configuration parameters. Requires "bootstrap.servers" to be set + * with Kafka broker(s) specified in host1:port1,host2:port2 form. + * @param offsetRanges offset ranges that define the Kafka data belonging to this RDD + * @param locationStrategy In most cases, pass in LocationStrategies.preferConsistent, + * see [[LocationStrategies]] for more details. + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + */ + @Experimental + def createRDD[K, V]( + sc: SparkContext, + kafkaParams: ju.Map[String, Object], + offsetRanges: Array[OffsetRange], + locationStrategy: LocationStrategy + ): RDD[ConsumerRecord[K, V]] = { + val preferredHosts = locationStrategy match { + case PreferBrokers => + throw new AssertionError( + "If you want to prefer brokers, you must provide a mapping using PreferFixed " + + "A single KafkaRDD does not have a driver consumer and cannot look up brokers for you.") + case PreferConsistent => ju.Collections.emptyMap[TopicPartition, String]() + case PreferFixed(hostMap) => hostMap + } + val kp = new ju.HashMap[String, Object](kafkaParams) + fixKafkaParams(kp) + val osr = offsetRanges.clone() + + new KafkaRDD[K, V](sc, kp, osr, preferredHosts, true) + } + + /** + * :: Experimental :: + * Java constructor for a batch-oriented interface for consuming from Kafka. 
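A hedged end-to-end sketch of the Scala createRDD path documented above; the topic, broker address, group id, and offset range are placeholders.

```scala
import scala.collection.JavaConverters._

import org.apache.kafka.common.serialization.StringDeserializer
import org.apache.spark.SparkContext
import org.apache.spark.streaming.kafka010.{KafkaUtils, LocationStrategies, OffsetRange}

// Batch read of offsets [0, 100) from partition 0 of a placeholder topic.
def readBatch(sc: SparkContext): Array[(String, String)] = {
  val kafkaParams = Map[String, Object](
    "bootstrap.servers" -> "localhost:9092",
    "key.deserializer" -> classOf[StringDeserializer],
    "value.deserializer" -> classOf[StringDeserializer],
    "group.id" -> "example-batch-group"
  ).asJava
  val ranges = Array(OffsetRange("events", 0, 0L, 100L))
  KafkaUtils.createRDD[String, String](sc, kafkaParams, ranges,
    LocationStrategies.PreferConsistent)
    .map(r => (r.key, r.value))
    .collect()
}
```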
+ * Starting and ending offsets are specified in advance, + * so that you can control exactly-once semantics. + * @param keyClass Class of the keys in the Kafka records + * @param valueClass Class of the values in the Kafka records + * @param kafkaParams Kafka + * + * configuration parameters. Requires "bootstrap.servers" to be set + * with Kafka broker(s) specified in host1:port1,host2:port2 form. + * @param offsetRanges offset ranges that define the Kafka data belonging to this RDD + * @param locationStrategy In most cases, pass in LocationStrategies.preferConsistent, + * see [[LocationStrategies]] for more details. + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + */ + @Experimental + def createRDD[K, V]( + jsc: JavaSparkContext, + kafkaParams: ju.Map[String, Object], + offsetRanges: Array[OffsetRange], + locationStrategy: LocationStrategy + ): JavaRDD[ConsumerRecord[K, V]] = { + + new JavaRDD(createRDD[K, V](jsc.sc, kafkaParams, offsetRanges, locationStrategy)) + } + + /** + * :: Experimental :: + * Scala constructor for a DStream where + * each given Kafka topic/partition corresponds to an RDD partition. + * The spark configuration spark.streaming.kafka.maxRatePerPartition gives the maximum number + * of messages + * per second that each '''partition''' will accept. + * @param locationStrategy In most cases, pass in LocationStrategies.preferConsistent, + * see [[LocationStrategies]] for more details. + * @param consumerStrategy In most cases, pass in ConsumerStrategies.subscribe, + * see [[ConsumerStrategies]] for more details + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + */ + @Experimental + def createDirectStream[K, V]( + ssc: StreamingContext, + locationStrategy: LocationStrategy, + consumerStrategy: ConsumerStrategy[K, V] + ): InputDStream[ConsumerRecord[K, V]] = { + new DirectKafkaInputDStream[K, V](ssc, locationStrategy, consumerStrategy) + } + + /** + * :: Experimental :: + * Java constructor for a DStream where + * each given Kafka topic/partition corresponds to an RDD partition. + * @param keyClass Class of the keys in the Kafka records + * @param valueClass Class of the values in the Kafka records + * @param locationStrategy In most cases, pass in LocationStrategies.preferConsistent, + * see [[LocationStrategies]] for more details. 
+ * @param consumerStrategy In most cases, pass in ConsumerStrategies.subscribe, + * see [[ConsumerStrategies]] for more details + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + */ + @Experimental + def createDirectStream[K, V]( + jssc: JavaStreamingContext, + locationStrategy: LocationStrategy, + consumerStrategy: ConsumerStrategy[K, V] + ): JavaInputDStream[ConsumerRecord[K, V]] = { + new JavaInputDStream( + createDirectStream[K, V]( + jssc.ssc, locationStrategy, consumerStrategy)) + } + + /** + * Tweak kafka params to prevent issues on executors + */ + private[kafka010] def fixKafkaParams(kafkaParams: ju.HashMap[String, Object]): Unit = { + logWarning(s"overriding ${ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG} to false for executor") + kafkaParams.put(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, false: java.lang.Boolean) + + logWarning(s"overriding ${ConsumerConfig.AUTO_OFFSET_RESET_CONFIG} to none for executor") + kafkaParams.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "none") + + // driver and executor should be in different consumer groups + val originalGroupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG) + if (null == originalGroupId) { + logError(s"${ConsumerConfig.GROUP_ID_CONFIG} is null, you should probably set it") + } + val groupId = "spark-executor-" + originalGroupId + logWarning(s"overriding executor ${ConsumerConfig.GROUP_ID_CONFIG} to ${groupId}") + kafkaParams.put(ConsumerConfig.GROUP_ID_CONFIG, groupId) + + // possible workaround for KAFKA-3135 + val rbb = kafkaParams.get(ConsumerConfig.RECEIVE_BUFFER_CONFIG) + if (null == rbb || rbb.asInstanceOf[java.lang.Integer] < 65536) { + logWarning(s"overriding ${ConsumerConfig.RECEIVE_BUFFER_CONFIG} to 65536 see KAFKA-3135") + kafkaParams.put(ConsumerConfig.RECEIVE_BUFFER_CONFIG, 65536: java.lang.Integer) + } + } +} diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/LocationStrategy.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/LocationStrategy.scala new file mode 100644 index 0000000000000..c9a8a13f51c32 --- /dev/null +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/LocationStrategy.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka010 + +import java.{ util => ju } + +import scala.collection.JavaConverters._ + +import org.apache.kafka.common.TopicPartition + +import org.apache.spark.annotation.Experimental + + +/** + * :: Experimental :: + * Choice of how to schedule consumers for a given TopicPartition on an executor. + * See [[LocationStrategies]] to obtain instances. 
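Putting LocationStrategies and ConsumerStrategies together, a minimal streaming-job sketch of the createDirectStream entry point above; the broker address, group id, and topic are placeholders.

```scala
import org.apache.kafka.common.serialization.StringDeserializer
import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.kafka010.{ConsumerStrategies, KafkaUtils, LocationStrategies}

val conf = new SparkConf().setAppName("kafka-direct-example")
val ssc = new StreamingContext(conf, Seconds(5))

val kafkaParams = Map[String, Object](
  "bootstrap.servers" -> "localhost:9092",
  "key.deserializer" -> classOf[StringDeserializer],
  "value.deserializer" -> classOf[StringDeserializer],
  "group.id" -> "example-stream-group",
  "auto.offset.reset" -> "latest",
  "enable.auto.commit" -> (false: java.lang.Boolean)
)

// PreferConsistent spreads cached consumers evenly across executors, the sensible
// default unless executors run on the same hosts as the Kafka brokers.
val stream = KafkaUtils.createDirectStream[String, String](
  ssc,
  LocationStrategies.PreferConsistent,
  ConsumerStrategies.Subscribe[String, String](List("events"), kafkaParams))

stream.map(_.value).print()
ssc.start()
ssc.awaitTermination()
```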
+ * Kafka 0.10 consumers prefetch messages, so it's important for performance + * to keep cached consumers on appropriate executors, not recreate them for every partition. + * Choice of location is only a preference, not an absolute; partitions may be scheduled elsewhere. + */ +@Experimental +sealed abstract class LocationStrategy + +private case object PreferBrokers extends LocationStrategy + +private case object PreferConsistent extends LocationStrategy + +private case class PreferFixed(hostMap: ju.Map[TopicPartition, String]) extends LocationStrategy + +/** + * :: Experimental :: object to obtain instances of [[LocationStrategy]] + * + */ +@Experimental +object LocationStrategies { + /** + * :: Experimental :: + * Use this only if your executors are on the same nodes as your Kafka brokers. + */ + @Experimental + def PreferBrokers: LocationStrategy = + org.apache.spark.streaming.kafka010.PreferBrokers + + /** + * :: Experimental :: + * Use this in most cases, it will consistently distribute partitions across all executors. + */ + @Experimental + def PreferConsistent: LocationStrategy = + org.apache.spark.streaming.kafka010.PreferConsistent + + /** + * :: Experimental :: + * Use this to place particular TopicPartitions on particular hosts if your load is uneven. + * Any TopicPartition not specified in the map will use a consistent location. + */ + @Experimental + def PreferFixed(hostMap: collection.Map[TopicPartition, String]): LocationStrategy = + new PreferFixed(new ju.HashMap[TopicPartition, String](hostMap.asJava)) + + /** + * :: Experimental :: + * Use this to place particular TopicPartitions on particular hosts if your load is uneven. + * Any TopicPartition not specified in the map will use a consistent location. + */ + @Experimental + def PreferFixed(hostMap: ju.Map[TopicPartition, String]): LocationStrategy = + new PreferFixed(hostMap) +} diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/OffsetRange.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/OffsetRange.scala new file mode 100644 index 0000000000000..c66d3c9b8d229 --- /dev/null +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/OffsetRange.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka010 + +import org.apache.kafka.clients.consumer.OffsetCommitCallback +import org.apache.kafka.common.TopicPartition + +import org.apache.spark.annotation.Experimental + +/** + * Represents any object that has a collection of [[OffsetRange]]s. This can be used to access the + * offset ranges in RDDs generated by the direct Kafka DStream (see + * [[KafkaUtils.createDirectStream]]). 
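When load is skewed, the PreferFixed strategy above lets particular partitions be pinned to particular hosts; a small sketch with placeholder host names:

```scala
import org.apache.kafka.common.TopicPartition
import org.apache.spark.streaming.kafka010.LocationStrategies

// Pin two hot partitions to specific executor hosts; any partition not listed
// falls back to the consistent assignment.
val pinned = Map(
  new TopicPartition("events", 0) -> "worker-1.example.com",
  new TopicPartition("events", 1) -> "worker-2.example.com"
)
val strategy = LocationStrategies.PreferFixed(pinned)
```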
+ * {{{ + * KafkaUtils.createDirectStream(...).foreachRDD { rdd => + * val offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges + * ... + * } + * }}} + */ +trait HasOffsetRanges { + def offsetRanges: Array[OffsetRange] +} + +/** + * :: Experimental :: + * Represents any object that can commit a collection of [[OffsetRange]]s. + * The direct Kafka DStream implements this interface (see + * [[KafkaUtils.createDirectStream]]). + * {{{ + * val stream = KafkaUtils.createDirectStream(...) + * ... + * stream.asInstanceOf[CanCommitOffsets].commitAsync(offsets, new OffsetCommitCallback() { + * def onComplete(m: java.util.Map[TopicPartition, OffsetAndMetadata], e: Exception) { + * if (null != e) { + * // error + * } else { + * // success + * } + * } + * }) + * }}} + */ +@Experimental +trait CanCommitOffsets { + /** + * :: Experimental :: + * Queue up offset ranges for commit to Kafka at a future time. Threadsafe. + * This is only needed if you intend to store offsets in Kafka, instead of your own store. + * @param offsetRanges The maximum untilOffset for a given partition will be used at commit. + */ + @Experimental + def commitAsync(offsetRanges: Array[OffsetRange]): Unit + + /** + * :: Experimental :: + * Queue up offset ranges for commit to Kafka at a future time. Threadsafe. + * This is only needed if you intend to store offsets in Kafka, instead of your own store. + * @param offsetRanges The maximum untilOffset for a given partition will be used at commit. + * @param callback Only the most recently provided callback will be used at commit. + */ + @Experimental + def commitAsync(offsetRanges: Array[OffsetRange], callback: OffsetCommitCallback): Unit +} + +/** + * Represents a range of offsets from a single Kafka TopicPartition. Instances of this class + * can be created with `OffsetRange.create()`. + * @param topic Kafka topic name + * @param partition Kafka partition id + * @param fromOffset Inclusive starting offset + * @param untilOffset Exclusive ending offset + */ +final class OffsetRange private( + val topic: String, + val partition: Int, + val fromOffset: Long, + val untilOffset: Long) extends Serializable { + import OffsetRange.OffsetRangeTuple + + /** Kafka TopicPartition object, for convenience */ + def topicPartition(): TopicPartition = new TopicPartition(topic, partition) + + /** Number of messages this OffsetRange refers to */ + def count(): Long = untilOffset - fromOffset + + override def equals(obj: Any): Boolean = obj match { + case that: OffsetRange => + this.topic == that.topic && + this.partition == that.partition && + this.fromOffset == that.fromOffset && + this.untilOffset == that.untilOffset + case _ => false + } + + override def hashCode(): Int = { + toTuple.hashCode() + } + + override def toString(): String = { + s"OffsetRange(topic: '$topic', partition: $partition, range: [$fromOffset -> $untilOffset])" + } + + /** this is to avoid ClassNotFoundException during checkpoint restore */ + private[streaming] + def toTuple: OffsetRangeTuple = (topic, partition, fromOffset, untilOffset) +} + +/** + * Companion object the provides methods to create instances of [[OffsetRange]]. 
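The factory methods defined just below can be exercised as in this small sketch; the topic and offsets are arbitrary examples.

```scala
import org.apache.kafka.common.TopicPartition
import org.apache.spark.streaming.kafka010.OffsetRange

// Two equivalent ways to describe offsets [100, 200) of partition 3.
val byName = OffsetRange.create("events", 3, 100L, 200L)
val byTopicPartition = OffsetRange(new TopicPartition("events", 3), 100L, 200L)
assert(byName == byTopicPartition)
assert(byName.count() == 100L) // untilOffset is exclusive
```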
+ */ +object OffsetRange { + def create(topic: String, partition: Int, fromOffset: Long, untilOffset: Long): OffsetRange = + new OffsetRange(topic, partition, fromOffset, untilOffset) + + def create( + topicPartition: TopicPartition, + fromOffset: Long, + untilOffset: Long): OffsetRange = + new OffsetRange(topicPartition.topic, topicPartition.partition, fromOffset, untilOffset) + + def apply(topic: String, partition: Int, fromOffset: Long, untilOffset: Long): OffsetRange = + new OffsetRange(topic, partition, fromOffset, untilOffset) + + def apply( + topicPartition: TopicPartition, + fromOffset: Long, + untilOffset: Long): OffsetRange = + new OffsetRange(topicPartition.topic, topicPartition.partition, fromOffset, untilOffset) + + /** this is to avoid ClassNotFoundException during checkpoint restore */ + private[kafka010] + type OffsetRangeTuple = (String, Int, Long, Long) + + private[kafka010] + def apply(t: OffsetRangeTuple) = + new OffsetRange(t._1, t._2, t._3, t._4) +} diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/package-info.java b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/package-info.java new file mode 100644 index 0000000000000..ebfcf8764a328 --- /dev/null +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/package-info.java @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Spark Integration for Kafka 0.10 + */ +package org.apache.spark.streaming.kafka010; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/OutputMode.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/package.scala similarity index 84% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/OutputMode.scala rename to external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/package.scala index a4d387eae3c80..09db6d6062d82 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/OutputMode.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/package.scala @@ -15,9 +15,9 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.catalyst.analysis +package org.apache.spark.streaming -sealed trait OutputMode - -case object Append extends OutputMode -case object Update extends OutputMode +/** + * Spark Integration for Kafka 0.10 + */ +package object kafka010 //scalastyle:ignore diff --git a/external/kafka-0-10/src/test/java/org/apache/spark/streaming/kafka010/JavaConsumerStrategySuite.java b/external/kafka-0-10/src/test/java/org/apache/spark/streaming/kafka010/JavaConsumerStrategySuite.java new file mode 100644 index 0000000000000..ba57b6beb247d --- /dev/null +++ b/external/kafka-0-10/src/test/java/org/apache/spark/streaming/kafka010/JavaConsumerStrategySuite.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka010; + +import java.io.Serializable; +import java.util.*; +import java.util.regex.Pattern; + +import scala.collection.JavaConverters; + +import org.apache.kafka.common.TopicPartition; + +import org.junit.Assert; +import org.junit.Test; + +public class JavaConsumerStrategySuite implements Serializable { + + @Test + public void testConsumerStrategyConstructors() { + final String topic1 = "topic1"; + final Pattern pat = Pattern.compile("top.*"); + final Collection topics = Arrays.asList(topic1); + final scala.collection.Iterable sTopics = + JavaConverters.collectionAsScalaIterableConverter(topics).asScala(); + final TopicPartition tp1 = new TopicPartition(topic1, 0); + final TopicPartition tp2 = new TopicPartition(topic1, 1); + final Collection parts = Arrays.asList(tp1, tp2); + final scala.collection.Iterable sParts = + JavaConverters.collectionAsScalaIterableConverter(parts).asScala(); + final Map kafkaParams = new HashMap(); + kafkaParams.put("bootstrap.servers", "not used"); + final scala.collection.Map sKafkaParams = + JavaConverters.mapAsScalaMapConverter(kafkaParams).asScala(); + final Map offsets = new HashMap<>(); + offsets.put(tp1, 23L); + final scala.collection.Map sOffsets = + JavaConverters.mapAsScalaMapConverter(offsets).asScala().mapValues( + new scala.runtime.AbstractFunction1() { + @Override + public Object apply(Long x) { + return (Object) x; + } + } + ); + + final ConsumerStrategy sub1 = + ConsumerStrategies.Subscribe(sTopics, sKafkaParams, sOffsets); + final ConsumerStrategy sub2 = + ConsumerStrategies.Subscribe(sTopics, sKafkaParams); + final ConsumerStrategy sub3 = + ConsumerStrategies.Subscribe(topics, kafkaParams, offsets); + final ConsumerStrategy sub4 = + ConsumerStrategies.Subscribe(topics, kafkaParams); + + Assert.assertEquals( + sub1.executorKafkaParams().get("bootstrap.servers"), + sub3.executorKafkaParams().get("bootstrap.servers")); + + final ConsumerStrategy psub1 = + ConsumerStrategies.SubscribePattern(pat, sKafkaParams, sOffsets); + 
final ConsumerStrategy psub2 = + ConsumerStrategies.SubscribePattern(pat, sKafkaParams); + final ConsumerStrategy psub3 = + ConsumerStrategies.SubscribePattern(pat, kafkaParams, offsets); + final ConsumerStrategy psub4 = + ConsumerStrategies.SubscribePattern(pat, kafkaParams); + + Assert.assertEquals( + psub1.executorKafkaParams().get("bootstrap.servers"), + psub3.executorKafkaParams().get("bootstrap.servers")); + + final ConsumerStrategy asn1 = + ConsumerStrategies.Assign(sParts, sKafkaParams, sOffsets); + final ConsumerStrategy asn2 = + ConsumerStrategies.Assign(sParts, sKafkaParams); + final ConsumerStrategy asn3 = + ConsumerStrategies.Assign(parts, kafkaParams, offsets); + final ConsumerStrategy asn4 = + ConsumerStrategies.Assign(parts, kafkaParams); + + Assert.assertEquals( + asn1.executorKafkaParams().get("bootstrap.servers"), + asn3.executorKafkaParams().get("bootstrap.servers")); + } + +} diff --git a/external/kafka-0-10/src/test/java/org/apache/spark/streaming/kafka010/JavaDirectKafkaStreamSuite.java b/external/kafka-0-10/src/test/java/org/apache/spark/streaming/kafka010/JavaDirectKafkaStreamSuite.java new file mode 100644 index 0000000000000..dc9c13ba863ff --- /dev/null +++ b/external/kafka-0-10/src/test/java/org/apache/spark/streaming/kafka010/JavaDirectKafkaStreamSuite.java @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
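Editorial note: the Java suite above only checks that `Subscribe`, `SubscribePattern` and `Assign` can be constructed from both Java and Scala collections. A sketch of the three strategies from Scala; topic names, the group id, and the starting offset are placeholder values:

```scala
import java.util.regex.Pattern

import org.apache.kafka.common.TopicPartition
import org.apache.kafka.common.serialization.StringDeserializer

import org.apache.spark.streaming.kafka010.ConsumerStrategies

object ConsumerStrategyExamples {
  // Shared consumer configuration; broker address and group id are placeholders.
  val kafkaParams = Map[String, Object](
    "bootstrap.servers" -> "localhost:9092",
    "key.deserializer" -> classOf[StringDeserializer],
    "value.deserializer" -> classOf[StringDeserializer],
    "group.id" -> "strategy-example",
    "auto.offset.reset" -> "earliest")

  // Subscribe: fixed list of topics; Kafka assigns partitions within the group.
  val subscribe = ConsumerStrategies.Subscribe[String, String](
    List("events", "audit"), kafkaParams)

  // SubscribePattern: topics matching a regex, picking up new matches as they appear.
  val byPattern = ConsumerStrategies.SubscribePattern[String, String](
    Pattern.compile("events-.*"), kafkaParams)

  // Assign: explicit partitions, optionally with per-partition starting offsets.
  val assigned = ConsumerStrategies.Assign[String, String](
    List(new TopicPartition("events", 0), new TopicPartition("events", 1)),
    kafkaParams,
    Map(new TopicPartition("events", 0) -> 42L)) // starting offset is illustrative
}
```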
+ */ + +package org.apache.spark.streaming.kafka010; + +import java.io.Serializable; +import java.util.*; +import java.util.concurrent.atomic.AtomicReference; + +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.clients.consumer.ConsumerRecord; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.VoidFunction; +import org.apache.spark.streaming.Durations; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.streaming.api.java.JavaInputDStream; +import org.apache.spark.streaming.api.java.JavaStreamingContext; + +public class JavaDirectKafkaStreamSuite implements Serializable { + private transient JavaStreamingContext ssc = null; + private transient KafkaTestUtils kafkaTestUtils = null; + + @Before + public void setUp() { + kafkaTestUtils = new KafkaTestUtils(); + kafkaTestUtils.setup(); + SparkConf sparkConf = new SparkConf() + .setMaster("local[4]").setAppName(this.getClass().getSimpleName()); + ssc = new JavaStreamingContext(sparkConf, Durations.milliseconds(200)); + } + + @After + public void tearDown() { + if (ssc != null) { + ssc.stop(); + ssc = null; + } + + if (kafkaTestUtils != null) { + kafkaTestUtils.teardown(); + kafkaTestUtils = null; + } + } + + @Test + public void testKafkaStream() throws InterruptedException { + final String topic1 = "topic1"; + final String topic2 = "topic2"; + // hold a reference to the current offset ranges, so it can be used downstream + final AtomicReference offsetRanges = new AtomicReference<>(); + + String[] topic1data = createTopicAndSendData(topic1); + String[] topic2data = createTopicAndSendData(topic2); + + Set sent = new HashSet<>(); + sent.addAll(Arrays.asList(topic1data)); + sent.addAll(Arrays.asList(topic2data)); + + Random random = new Random(); + + final Map kafkaParams = new HashMap<>(); + kafkaParams.put("bootstrap.servers", kafkaTestUtils.brokerAddress()); + kafkaParams.put("key.deserializer", StringDeserializer.class); + kafkaParams.put("value.deserializer", StringDeserializer.class); + kafkaParams.put("auto.offset.reset", "earliest"); + kafkaParams.put("group.id", "java-test-consumer-" + random.nextInt() + + "-" + System.currentTimeMillis()); + + JavaInputDStream> istream1 = KafkaUtils.createDirectStream( + ssc, + LocationStrategies.PreferConsistent(), + ConsumerStrategies.Subscribe(Arrays.asList(topic1), kafkaParams) + ); + + JavaDStream stream1 = istream1.transform( + // Make sure you can get offset ranges from the rdd + new Function>, + JavaRDD>>() { + @Override + public JavaRDD> call( + JavaRDD> rdd + ) { + OffsetRange[] offsets = ((HasOffsetRanges) rdd.rdd()).offsetRanges(); + offsetRanges.set(offsets); + Assert.assertEquals(topic1, offsets[0].topic()); + return rdd; + } + } + ).map( + new Function, String>() { + @Override + public String call(ConsumerRecord r) { + return r.value(); + } + } + ); + + final Map kafkaParams2 = new HashMap<>(kafkaParams); + kafkaParams2.put("group.id", "java-test-consumer-" + random.nextInt() + + "-" + System.currentTimeMillis()); + + JavaInputDStream> istream2 = KafkaUtils.createDirectStream( + ssc, + LocationStrategies.PreferConsistent(), + ConsumerStrategies.Subscribe(Arrays.asList(topic2), kafkaParams2) + ); + + JavaDStream stream2 = istream2.transform( + // Make sure you can get offset ranges from the rdd + 
new Function>, + JavaRDD>>() { + @Override + public JavaRDD> call( + JavaRDD> rdd + ) { + OffsetRange[] offsets = ((HasOffsetRanges) rdd.rdd()).offsetRanges(); + offsetRanges.set(offsets); + Assert.assertEquals(topic2, offsets[0].topic()); + return rdd; + } + } + ).map( + new Function, String>() { + @Override + public String call(ConsumerRecord r) { + return r.value(); + } + } + ); + + JavaDStream unifiedStream = stream1.union(stream2); + + final Set result = Collections.synchronizedSet(new HashSet()); + unifiedStream.foreachRDD(new VoidFunction>() { + @Override + public void call(JavaRDD rdd) { + result.addAll(rdd.collect()); + } + } + ); + ssc.start(); + long startTime = System.currentTimeMillis(); + boolean matches = false; + while (!matches && System.currentTimeMillis() - startTime < 20000) { + matches = sent.size() == result.size(); + Thread.sleep(50); + } + Assert.assertEquals(sent, result); + ssc.stop(); + } + + private String[] createTopicAndSendData(String topic) { + String[] data = { topic + "-1", topic + "-2", topic + "-3"}; + kafkaTestUtils.createTopic(topic); + kafkaTestUtils.sendMessages(topic, data); + return data; + } +} diff --git a/external/kafka-0-10/src/test/java/org/apache/spark/streaming/kafka010/JavaKafkaRDDSuite.java b/external/kafka-0-10/src/test/java/org/apache/spark/streaming/kafka010/JavaKafkaRDDSuite.java new file mode 100644 index 0000000000000..87bfe1514e338 --- /dev/null +++ b/external/kafka-0-10/src/test/java/org/apache/spark/streaming/kafka010/JavaKafkaRDDSuite.java @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
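Editorial note: the Java stream suite above captures offset ranges in the first operation applied to the stream's own RDDs, before any mapping. A Scala sketch of the same driver-side pattern; the topic, group id, and broker address are placeholders:

```scala
import org.apache.kafka.common.serialization.StringDeserializer

import org.apache.spark.streaming.StreamingContext
import org.apache.spark.streaming.kafka010._

object DirectStreamSketch {
  def build(ssc: StreamingContext): Unit = {
    val kafkaParams = Map[String, Object](
      "bootstrap.servers" -> "localhost:9092", // placeholder
      "key.deserializer" -> classOf[StringDeserializer],
      "value.deserializer" -> classOf[StringDeserializer],
      "group.id" -> "direct-stream-sketch",
      "auto.offset.reset" -> "earliest")

    val stream = KafkaUtils.createDirectStream[String, String](
      ssc,
      LocationStrategies.PreferConsistent,
      ConsumerStrategies.Subscribe[String, String](List("events"), kafkaParams))

    // Offset ranges are only attached to the RDDs produced directly by the stream,
    // so read them before any transformation that changes the partitioning.
    stream.foreachRDD { rdd =>
      val ranges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges
      ranges.foreach(o => println(s"${o.topic} ${o.partition} ${o.fromOffset} ${o.untilOffset}"))
      rdd.map(_.value).take(10).foreach(println)
    }
  }
}
```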
+ */ + +package org.apache.spark.streaming.kafka010; + +import java.io.Serializable; +import java.util.HashMap; +import java.util.Map; +import java.util.Random; + +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; + +public class JavaKafkaRDDSuite implements Serializable { + private transient JavaSparkContext sc = null; + private transient KafkaTestUtils kafkaTestUtils = null; + + @Before + public void setUp() { + kafkaTestUtils = new KafkaTestUtils(); + kafkaTestUtils.setup(); + SparkConf sparkConf = new SparkConf() + .setMaster("local[4]").setAppName(this.getClass().getSimpleName()); + sc = new JavaSparkContext(sparkConf); + } + + @After + public void tearDown() { + if (sc != null) { + sc.stop(); + sc = null; + } + + if (kafkaTestUtils != null) { + kafkaTestUtils.teardown(); + kafkaTestUtils = null; + } + } + + @Test + public void testKafkaRDD() throws InterruptedException { + String topic1 = "topic1"; + String topic2 = "topic2"; + + Random random = new Random(); + + createTopicAndSendData(topic1); + createTopicAndSendData(topic2); + + Map kafkaParams = new HashMap<>(); + kafkaParams.put("bootstrap.servers", kafkaTestUtils.brokerAddress()); + kafkaParams.put("key.deserializer", StringDeserializer.class); + kafkaParams.put("value.deserializer", StringDeserializer.class); + kafkaParams.put("group.id", "java-test-consumer-" + random.nextInt() + + "-" + System.currentTimeMillis()); + + OffsetRange[] offsetRanges = { + OffsetRange.create(topic1, 0, 0, 1), + OffsetRange.create(topic2, 0, 0, 1) + }; + + Map leaders = new HashMap<>(); + String[] hostAndPort = kafkaTestUtils.brokerAddress().split(":"); + String broker = hostAndPort[0]; + leaders.put(offsetRanges[0].topicPartition(), broker); + leaders.put(offsetRanges[1].topicPartition(), broker); + + Function, String> handler = + new Function, String>() { + @Override + public String call(ConsumerRecord r) { + return r.value(); + } + }; + + JavaRDD rdd1 = KafkaUtils.createRDD( + sc, + kafkaParams, + offsetRanges, + LocationStrategies.PreferFixed(leaders) + ).map(handler); + + JavaRDD rdd2 = KafkaUtils.createRDD( + sc, + kafkaParams, + offsetRanges, + LocationStrategies.PreferConsistent() + ).map(handler); + + // just making sure the java user apis work; the scala tests handle logic corner cases + long count1 = rdd1.count(); + long count2 = rdd2.count(); + Assert.assertTrue(count1 > 0); + Assert.assertEquals(count1, count2); + } + + private String[] createTopicAndSendData(String topic) { + String[] data = { topic + "-1", topic + "-2", topic + "-3"}; + kafkaTestUtils.createTopic(topic); + kafkaTestUtils.sendMessages(topic, data); + return data; + } +} diff --git a/external/kafka-0-10/src/test/java/org/apache/spark/streaming/kafka010/JavaLocationStrategySuite.java b/external/kafka-0-10/src/test/java/org/apache/spark/streaming/kafka010/JavaLocationStrategySuite.java new file mode 100644 index 0000000000000..41ccb0ebe7bfa --- /dev/null +++ b/external/kafka-0-10/src/test/java/org/apache/spark/streaming/kafka010/JavaLocationStrategySuite.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor 
license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka010; + +import java.io.Serializable; +import java.util.*; + +import scala.collection.JavaConverters; + +import org.apache.kafka.common.TopicPartition; + +import org.junit.Assert; +import org.junit.Test; + +public class JavaLocationStrategySuite implements Serializable { + + @Test + public void testLocationStrategyConstructors() { + final String topic1 = "topic1"; + final TopicPartition tp1 = new TopicPartition(topic1, 0); + final TopicPartition tp2 = new TopicPartition(topic1, 1); + final Map hosts = new HashMap<>(); + hosts.put(tp1, "node1"); + hosts.put(tp2, "node2"); + final scala.collection.Map sHosts = + JavaConverters.mapAsScalaMapConverter(hosts).asScala(); + + // make sure constructors can be called from java + final LocationStrategy c1 = LocationStrategies.PreferConsistent(); + final LocationStrategy c2 = LocationStrategies.PreferConsistent(); + Assert.assertSame(c1, c2); + + final LocationStrategy c3 = LocationStrategies.PreferBrokers(); + final LocationStrategy c4 = LocationStrategies.PreferBrokers(); + Assert.assertSame(c3, c4); + + Assert.assertNotSame(c1, c3); + + final LocationStrategy c5 = LocationStrategies.PreferFixed(hosts); + final LocationStrategy c6 = LocationStrategies.PreferFixed(sHosts); + Assert.assertEquals(c5, c6); + } + +} diff --git a/external/kafka-0-10/src/test/resources/log4j.properties b/external/kafka-0-10/src/test/resources/log4j.properties new file mode 100644 index 0000000000000..75e3b53a093f6 --- /dev/null +++ b/external/kafka-0-10/src/test/resources/log4j.properties @@ -0,0 +1,28 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
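Editorial note: the location-strategy suite above only verifies the constructors are reachable from Java. A short sketch of when each strategy might be chosen; the host names and topic are illustrative:

```scala
import org.apache.kafka.common.TopicPartition

import org.apache.spark.streaming.kafka010.LocationStrategies

object LocationStrategyChoices {
  // Default choice: distribute partitions evenly across the available executors.
  val consistent = LocationStrategies.PreferConsistent

  // Only useful when executors run on the same nodes as the Kafka brokers.
  val brokers = LocationStrategies.PreferBrokers

  // Pin specific partitions to specific hosts, e.g. to keep a hot partition local.
  val fixed = LocationStrategies.PreferFixed(Map(
    new TopicPartition("events", 0) -> "node1",
    new TopicPartition("events", 1) -> "node2"))
}
```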
+# + +# Set everything to be logged to the file target/unit-tests.log +log4j.rootCategory=INFO, file +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=true +log4j.appender.file.file=target/unit-tests.log +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n + +# Ignore messages below warning level from Jetty, because it's a bit verbose +log4j.logger.org.spark-project.jetty=WARN + diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala new file mode 100644 index 0000000000000..860834df7bd7a --- /dev/null +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala @@ -0,0 +1,696 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka010 + +import java.io.File +import java.lang.{ Long => JLong } +import java.util.{ Arrays, HashMap => JHashMap, Map => JMap } +import java.util.concurrent.atomic.AtomicLong +import java.util.concurrent.ConcurrentLinkedQueue + +import scala.collection.JavaConverters._ +import scala.concurrent.duration._ +import scala.language.postfixOps +import scala.util.Random + +import org.apache.kafka.clients.consumer._ +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.serialization.StringDeserializer +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} +import org.scalatest.concurrent.Eventually + +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.internal.Logging +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.{Milliseconds, StreamingContext, Time} +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.streaming.scheduler._ +import org.apache.spark.streaming.scheduler.rate.RateEstimator +import org.apache.spark.util.Utils + +class DirectKafkaStreamSuite + extends SparkFunSuite + with BeforeAndAfter + with BeforeAndAfterAll + with Eventually + with Logging { + val sparkConf = new SparkConf() + .setMaster("local[4]") + .setAppName(this.getClass.getSimpleName) + + private var sc: SparkContext = _ + private var ssc: StreamingContext = _ + private var testDir: File = _ + + private var kafkaTestUtils: KafkaTestUtils = _ + + override def beforeAll { + kafkaTestUtils = new KafkaTestUtils + kafkaTestUtils.setup() + } + + override def afterAll { + if (kafkaTestUtils != null) { + kafkaTestUtils.teardown() + kafkaTestUtils = null + } + } + + after { + if (ssc != null) { + ssc.stop() + sc = null + } + if (sc != null) { + sc.stop() + } + if (testDir != null) { + 
Utils.deleteRecursively(testDir) + } + } + + def getKafkaParams(extra: (String, Object)*): JHashMap[String, Object] = { + val kp = new JHashMap[String, Object]() + kp.put("bootstrap.servers", kafkaTestUtils.brokerAddress) + kp.put("key.deserializer", classOf[StringDeserializer]) + kp.put("value.deserializer", classOf[StringDeserializer]) + kp.put("group.id", s"test-consumer-${Random.nextInt}-${System.currentTimeMillis}") + extra.foreach(e => kp.put(e._1, e._2)) + kp + } + + val preferredHosts = LocationStrategies.PreferConsistent + + test("basic stream receiving with multiple topics and smallest starting offset") { + val topics = List("basic1", "basic2", "basic3") + val data = Map("a" -> 7, "b" -> 9) + topics.foreach { t => + kafkaTestUtils.createTopic(t) + kafkaTestUtils.sendMessages(t, data) + } + val offsets = Map(new TopicPartition("basic3", 0) -> 2L) + // one topic is starting 2 messages later + val expectedTotal = (data.values.sum * topics.size) - 2 + val kafkaParams = getKafkaParams("auto.offset.reset" -> "earliest") + + ssc = new StreamingContext(sparkConf, Milliseconds(200)) + val stream = withClue("Error creating direct stream") { + KafkaUtils.createDirectStream[String, String]( + ssc, + preferredHosts, + ConsumerStrategies.Subscribe[String, String](topics, kafkaParams.asScala, offsets)) + } + val allReceived = new ConcurrentLinkedQueue[(String, String)]() + + // hold a reference to the current offset ranges, so it can be used downstream + var offsetRanges = Array[OffsetRange]() + val tf = stream.transform { rdd => + // Get the offset ranges in the RDD + offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges + rdd.map(r => (r.key, r.value)) + } + + tf.foreachRDD { rdd => + for (o <- offsetRanges) { + logInfo(s"${o.topic} ${o.partition} ${o.fromOffset} ${o.untilOffset}") + } + val collected = rdd.mapPartitionsWithIndex { (i, iter) => + // For each partition, get size of the range in the partition, + // and the number of items in the partition + val off = offsetRanges(i) + val all = iter.toSeq + val partSize = all.size + val rangeSize = off.untilOffset - off.fromOffset + Iterator((partSize, rangeSize)) + }.collect + + // Verify whether number of elements in each partition + // matches with the corresponding offset range + collected.foreach { case (partSize, rangeSize) => + assert(partSize === rangeSize, "offset ranges are wrong") + } + } + + stream.foreachRDD { rdd => + allReceived.addAll(Arrays.asList(rdd.map(r => (r.key, r.value)).collect(): _*)) + } + ssc.start() + eventually(timeout(20000.milliseconds), interval(200.milliseconds)) { + assert(allReceived.size === expectedTotal, + "didn't get expected number of messages, messages:\n" + + allReceived.asScala.mkString("\n")) + } + ssc.stop() + } + + test("pattern based subscription") { + val topics = List("pat1", "pat2", "pat3", "advanced3") + // Should match 3 out of 4 topics + val pat = """pat\d""".r.pattern + val data = Map("a" -> 7, "b" -> 9) + topics.foreach { t => + kafkaTestUtils.createTopic(t) + kafkaTestUtils.sendMessages(t, data) + } + val offsets = Map( + new TopicPartition("pat2", 0) -> 3L, + new TopicPartition("pat3", 0) -> 4L) + // 3 matching topics, two of which start a total of 7 messages later + val expectedTotal = (data.values.sum * 3) - 7 + val kafkaParams = getKafkaParams("auto.offset.reset" -> "earliest") + + ssc = new StreamingContext(sparkConf, Milliseconds(200)) + val stream = withClue("Error creating direct stream") { + KafkaUtils.createDirectStream[String, String]( + ssc, + preferredHosts, + 
ConsumerStrategies.SubscribePattern[String, String](pat, kafkaParams.asScala, offsets)) + } + val allReceived = new ConcurrentLinkedQueue[(String, String)]() + + // hold a reference to the current offset ranges, so it can be used downstream + var offsetRanges = Array[OffsetRange]() + val tf = stream.transform { rdd => + // Get the offset ranges in the RDD + offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges + rdd.map(r => (r.key, r.value)) + } + + tf.foreachRDD { rdd => + for (o <- offsetRanges) { + logInfo(s"${o.topic} ${o.partition} ${o.fromOffset} ${o.untilOffset}") + } + val collected = rdd.mapPartitionsWithIndex { (i, iter) => + // For each partition, get size of the range in the partition, + // and the number of items in the partition + val off = offsetRanges(i) + val all = iter.toSeq + val partSize = all.size + val rangeSize = off.untilOffset - off.fromOffset + Iterator((partSize, rangeSize)) + }.collect + + // Verify whether number of elements in each partition + // matches with the corresponding offset range + collected.foreach { case (partSize, rangeSize) => + assert(partSize === rangeSize, "offset ranges are wrong") + } + } + + stream.foreachRDD { rdd => + allReceived.addAll(Arrays.asList(rdd.map(r => (r.key, r.value)).collect(): _*)) + } + ssc.start() + eventually(timeout(20000.milliseconds), interval(200.milliseconds)) { + assert(allReceived.size === expectedTotal, + "didn't get expected number of messages, messages:\n" + + allReceived.asScala.mkString("\n")) + } + ssc.stop() + } + + + test("receiving from largest starting offset") { + val topic = "latest" + val topicPartition = new TopicPartition(topic, 0) + val data = Map("a" -> 10) + kafkaTestUtils.createTopic(topic) + val kafkaParams = getKafkaParams("auto.offset.reset" -> "latest") + val kc = new KafkaConsumer(kafkaParams) + kc.assign(Arrays.asList(topicPartition)) + def getLatestOffset(): Long = { + kc.seekToEnd(Arrays.asList(topicPartition)) + kc.position(topicPartition) + } + + // Send some initial messages before starting context + kafkaTestUtils.sendMessages(topic, data) + eventually(timeout(10 seconds), interval(20 milliseconds)) { + assert(getLatestOffset() > 3) + } + val offsetBeforeStart = getLatestOffset() + kc.close() + + // Setup context and kafka stream with largest offset + ssc = new StreamingContext(sparkConf, Milliseconds(200)) + val stream = withClue("Error creating direct stream") { + val s = new DirectKafkaInputDStream[String, String]( + ssc, + preferredHosts, + ConsumerStrategies.Subscribe[String, String](List(topic), kafkaParams.asScala)) + s.consumer.poll(0) + assert( + s.consumer.position(topicPartition) >= offsetBeforeStart, + "Start offset not from latest" + ) + s + } + + val collectedData = new ConcurrentLinkedQueue[String]() + stream.map { _.value }.foreachRDD { rdd => + collectedData.addAll(Arrays.asList(rdd.collect(): _*)) + } + ssc.start() + val newData = Map("b" -> 10) + kafkaTestUtils.sendMessages(topic, newData) + eventually(timeout(10 seconds), interval(50 milliseconds)) { + collectedData.contains("b") + } + assert(!collectedData.contains("a")) + } + + + test("creating stream by offset") { + val topic = "offset" + val topicPartition = new TopicPartition(topic, 0) + val data = Map("a" -> 10) + kafkaTestUtils.createTopic(topic) + val kafkaParams = getKafkaParams("auto.offset.reset" -> "latest") + val kc = new KafkaConsumer(kafkaParams) + kc.assign(Arrays.asList(topicPartition)) + def getLatestOffset(): Long = { + kc.seekToEnd(Arrays.asList(topicPartition)) + 
kc.position(topicPartition) + } + + // Send some initial messages before starting context + kafkaTestUtils.sendMessages(topic, data) + eventually(timeout(10 seconds), interval(20 milliseconds)) { + assert(getLatestOffset() >= 10) + } + val offsetBeforeStart = getLatestOffset() + kc.close() + + // Setup context and kafka stream with largest offset + kafkaParams.put("auto.offset.reset", "none") + ssc = new StreamingContext(sparkConf, Milliseconds(200)) + val stream = withClue("Error creating direct stream") { + val s = new DirectKafkaInputDStream[String, String]( + ssc, + preferredHosts, + ConsumerStrategies.Assign[String, String]( + List(topicPartition), + kafkaParams.asScala, + Map(topicPartition -> 11L))) + s.consumer.poll(0) + assert( + s.consumer.position(topicPartition) >= offsetBeforeStart, + "Start offset not from latest" + ) + s + } + + val collectedData = new ConcurrentLinkedQueue[String]() + stream.map(_.value).foreachRDD { rdd => collectedData.addAll(Arrays.asList(rdd.collect(): _*)) } + ssc.start() + val newData = Map("b" -> 10) + kafkaTestUtils.sendMessages(topic, newData) + eventually(timeout(10 seconds), interval(50 milliseconds)) { + collectedData.contains("b") + } + assert(!collectedData.contains("a")) + } + + // Test to verify the offset ranges can be recovered from the checkpoints + test("offset recovery") { + val topic = "recovery" + kafkaTestUtils.createTopic(topic) + testDir = Utils.createTempDir() + + val kafkaParams = getKafkaParams("auto.offset.reset" -> "earliest") + + // Send data to Kafka + def sendData(data: Seq[Int]) { + val strings = data.map { _.toString} + kafkaTestUtils.sendMessages(topic, strings.map { _ -> 1}.toMap) + } + + // Setup the streaming context + ssc = new StreamingContext(sparkConf, Milliseconds(100)) + val kafkaStream = withClue("Error creating direct stream") { + KafkaUtils.createDirectStream[String, String]( + ssc, + preferredHosts, + ConsumerStrategies.Subscribe[String, String](List(topic), kafkaParams.asScala)) + } + val keyedStream = kafkaStream.map { r => "key" -> r.value.toInt } + val stateStream = keyedStream.updateStateByKey { (values: Seq[Int], state: Option[Int]) => + Some(values.sum + state.getOrElse(0)) + } + ssc.checkpoint(testDir.getAbsolutePath) + + // This is ensure all the data is eventually receiving only once + stateStream.foreachRDD { (rdd: RDD[(String, Int)]) => + rdd.collect().headOption.foreach { x => + DirectKafkaStreamSuite.total.set(x._2) + } + } + + ssc.start() + + // Send some data + for (i <- (1 to 10).grouped(4)) { + sendData(i) + } + + eventually(timeout(10 seconds), interval(50 milliseconds)) { + assert(DirectKafkaStreamSuite.total.get === (1 to 10).sum) + } + + ssc.stop() + + // Verify that offset ranges were generated + val offsetRangesBeforeStop = getOffsetRanges(kafkaStream) + assert(offsetRangesBeforeStop.size >= 1, "No offset ranges generated") + assert( + offsetRangesBeforeStop.head._2.forall { _.fromOffset === 0 }, + "starting offset not zero" + ) + + logInfo("====== RESTARTING ========") + + // Recover context from checkpoints + ssc = new StreamingContext(testDir.getAbsolutePath) + val recoveredStream = + ssc.graph.getInputStreams().head.asInstanceOf[DStream[ConsumerRecord[String, String]]] + + // Verify offset ranges have been recovered + val recoveredOffsetRanges = getOffsetRanges(recoveredStream).map { x => (x._1, x._2.toSet) } + assert(recoveredOffsetRanges.size > 0, "No offset ranges recovered") + val earlierOffsetRanges = offsetRangesBeforeStop.map { x => (x._1, x._2.toSet) } + assert( + 
recoveredOffsetRanges.forall { or => + earlierOffsetRanges.contains((or._1, or._2)) + }, + "Recovered ranges are not the same as the ones generated\n" + + earlierOffsetRanges + "\n" + recoveredOffsetRanges + ) + // Restart context, give more data and verify the total at the end + // If the total is write that means each records has been received only once + ssc.start() + for (i <- (11 to 20).grouped(4)) { + sendData(i) + } + + eventually(timeout(10 seconds), interval(50 milliseconds)) { + assert(DirectKafkaStreamSuite.total.get === (1 to 20).sum) + } + ssc.stop() + } + + // Test to verify the offsets can be recovered from Kafka + test("offset recovery from kafka") { + val topic = "recoveryfromkafka" + kafkaTestUtils.createTopic(topic) + + val kafkaParams = getKafkaParams( + "auto.offset.reset" -> "earliest", + ("enable.auto.commit", false: java.lang.Boolean) + ) + + val collectedData = new ConcurrentLinkedQueue[String]() + val committed = new JHashMap[TopicPartition, OffsetAndMetadata]() + + // Send data to Kafka and wait for it to be received + def sendDataAndWaitForReceive(data: Seq[Int]) { + val strings = data.map { _.toString} + kafkaTestUtils.sendMessages(topic, strings.map { _ -> 1}.toMap) + eventually(timeout(10 seconds), interval(50 milliseconds)) { + assert(strings.forall { collectedData.contains }) + } + } + + // Setup the streaming context + ssc = new StreamingContext(sparkConf, Milliseconds(100)) + withClue("Error creating direct stream") { + val kafkaStream = KafkaUtils.createDirectStream[String, String]( + ssc, + preferredHosts, + ConsumerStrategies.Subscribe[String, String](List(topic), kafkaParams.asScala)) + kafkaStream.foreachRDD { (rdd: RDD[ConsumerRecord[String, String]], time: Time) => + val offsets = rdd.asInstanceOf[HasOffsetRanges].offsetRanges + val data = rdd.map(_.value).collect() + collectedData.addAll(Arrays.asList(data: _*)) + kafkaStream.asInstanceOf[CanCommitOffsets] + .commitAsync(offsets, new OffsetCommitCallback() { + def onComplete(m: JMap[TopicPartition, OffsetAndMetadata], e: Exception) { + if (null != e) { + logError("commit failed", e) + } else { + committed.putAll(m) + } + } + }) + } + } + ssc.start() + // Send some data and wait for them to be received + for (i <- (1 to 10).grouped(4)) { + sendDataAndWaitForReceive(i) + } + ssc.stop() + assert(! 
committed.isEmpty) + val consumer = new KafkaConsumer[String, String](kafkaParams) + consumer.subscribe(Arrays.asList(topic)) + consumer.poll(0) + committed.asScala.foreach { + case (k, v) => + // commits are async, not exactly once + assert(v.offset > 0) + assert(consumer.position(k) >= v.offset) + } + } + + + test("Direct Kafka stream report input information") { + val topic = "report-test" + val data = Map("a" -> 7, "b" -> 9) + kafkaTestUtils.createTopic(topic) + kafkaTestUtils.sendMessages(topic, data) + + val totalSent = data.values.sum + val kafkaParams = getKafkaParams("auto.offset.reset" -> "earliest") + + import DirectKafkaStreamSuite._ + ssc = new StreamingContext(sparkConf, Milliseconds(200)) + val collector = new InputInfoCollector + ssc.addStreamingListener(collector) + + val stream = withClue("Error creating direct stream") { + KafkaUtils.createDirectStream[String, String]( + ssc, + preferredHosts, + ConsumerStrategies.Subscribe[String, String](List(topic), kafkaParams.asScala)) + } + + val allReceived = new ConcurrentLinkedQueue[(String, String)] + + stream.map(r => (r.key, r.value)) + .foreachRDD { rdd => allReceived.addAll(Arrays.asList(rdd.collect(): _*)) } + ssc.start() + eventually(timeout(20000.milliseconds), interval(200.milliseconds)) { + assert(allReceived.size === totalSent, + "didn't get expected number of messages, messages:\n" + + allReceived.asScala.mkString("\n")) + + // Calculate all the record number collected in the StreamingListener. + assert(collector.numRecordsSubmitted.get() === totalSent) + assert(collector.numRecordsStarted.get() === totalSent) + assert(collector.numRecordsCompleted.get() === totalSent) + } + ssc.stop() + } + + test("maxMessagesPerPartition with backpressure disabled") { + val topic = "maxMessagesPerPartition" + val kafkaStream = getDirectKafkaStream(topic, None) + + val input = Map(new TopicPartition(topic, 0) -> 50L, new TopicPartition(topic, 1) -> 50L) + assert(kafkaStream.maxMessagesPerPartition(input).get == + Map(new TopicPartition(topic, 0) -> 10L, new TopicPartition(topic, 1) -> 10L)) + } + + test("maxMessagesPerPartition with no lag") { + val topic = "maxMessagesPerPartition" + val rateController = Some(new ConstantRateController(0, new ConstantEstimator(100), 100)) + val kafkaStream = getDirectKafkaStream(topic, rateController) + + val input = Map(new TopicPartition(topic, 0) -> 0L, new TopicPartition(topic, 1) -> 0L) + assert(kafkaStream.maxMessagesPerPartition(input).isEmpty) + } + + test("maxMessagesPerPartition respects max rate") { + val topic = "maxMessagesPerPartition" + val rateController = Some(new ConstantRateController(0, new ConstantEstimator(100), 1000)) + val kafkaStream = getDirectKafkaStream(topic, rateController) + + val input = Map(new TopicPartition(topic, 0) -> 1000L, new TopicPartition(topic, 1) -> 1000L) + assert(kafkaStream.maxMessagesPerPartition(input).get == + Map(new TopicPartition(topic, 0) -> 10L, new TopicPartition(topic, 1) -> 10L)) + } + + test("using rate controller") { + val topic = "backpressure" + kafkaTestUtils.createTopic(topic, 1) + val kafkaParams = getKafkaParams("auto.offset.reset" -> "earliest") + val executorKafkaParams = new JHashMap[String, Object](kafkaParams) + KafkaUtils.fixKafkaParams(executorKafkaParams) + + val batchIntervalMilliseconds = 500 + val estimator = new ConstantEstimator(100) + val messages = Map("foo" -> 5000) + kafkaTestUtils.sendMessages(topic, messages) + + val sparkConf = new SparkConf() + // Safe, even with streaming, because we're using the direct API. 
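Editorial note: the "offset recovery from kafka" test above drives `CanCommitOffsets.commitAsync`, which stores offsets in Kafka itself instead of in checkpoints. A sketch of that pattern in application code, assuming `enable.auto.commit` is false; the logging and ordering of work are illustrative:

```scala
import org.apache.kafka.clients.consumer.{ConsumerRecord, OffsetAndMetadata, OffsetCommitCallback}
import org.apache.kafka.common.TopicPartition

import org.apache.spark.streaming.dstream.InputDStream
import org.apache.spark.streaming.kafka010.{CanCommitOffsets, HasOffsetRanges}

object CommitAfterProcessing {
  def wire(stream: InputDStream[ConsumerRecord[String, String]]): Unit = {
    stream.foreachRDD { rdd =>
      val ranges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges

      // ... process rdd here; commit only after the batch's output is safely stored ...

      stream.asInstanceOf[CanCommitOffsets].commitAsync(ranges, new OffsetCommitCallback {
        override def onComplete(
            offsets: java.util.Map[TopicPartition, OffsetAndMetadata],
            e: Exception): Unit = {
          if (e != null) {
            // Commits are asynchronous and at-least-once: a failure here only means
            // the same range may be reprocessed after a restart.
            println(s"offset commit failed: $e")
          }
        }
      })
    }
  }
}
```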
+ // Using 1 core is useful to make the test more predictable. + .setMaster("local[1]") + .setAppName(this.getClass.getSimpleName) + .set("spark.streaming.kafka.maxRatePerPartition", "100") + + // Setup the streaming context + ssc = new StreamingContext(sparkConf, Milliseconds(batchIntervalMilliseconds)) + + val kafkaStream = withClue("Error creating direct stream") { + new DirectKafkaInputDStream[String, String]( + ssc, + preferredHosts, + ConsumerStrategies.Subscribe[String, String](List(topic), kafkaParams.asScala)) { + override protected[streaming] val rateController = + Some(new DirectKafkaRateController(id, estimator)) + }.map(r => (r.key, r.value)) + } + + val collectedData = new ConcurrentLinkedQueue[Array[String]]() + + // Used for assertion failure messages. + def dataToString: String = + collectedData.asScala.map(_.mkString("[", ",", "]")).mkString("{", ", ", "}") + + // This is to collect the raw data received from Kafka + kafkaStream.foreachRDD { (rdd: RDD[(String, String)], time: Time) => + val data = rdd.map { _._2 }.collect() + collectedData.add(data) + } + + ssc.start() + + // Try different rate limits. + // Wait for arrays of data to appear matching the rate. + Seq(100, 50, 20).foreach { rate => + collectedData.clear() // Empty this buffer on each pass. + estimator.updateRate(rate) // Set a new rate. + // Expect blocks of data equal to "rate", scaled by the interval length in secs. + val expectedSize = Math.round(rate * batchIntervalMilliseconds * 0.001) + eventually(timeout(5.seconds), interval(10 milliseconds)) { + // Assert that rate estimator values are used to determine maxMessagesPerPartition. + // Funky "-" in message makes the complete assertion message read better. + assert(collectedData.asScala.exists(_.size == expectedSize), + s" - No arrays of size $expectedSize for rate $rate found in $dataToString") + } + } + + ssc.stop() + } + + /** Get the generated offset ranges from the DirectKafkaStream */ + private def getOffsetRanges[K, V]( + kafkaStream: DStream[ConsumerRecord[K, V]]): Seq[(Time, Array[OffsetRange])] = { + kafkaStream.generatedRDDs.mapValues { rdd => + rdd.asInstanceOf[HasOffsetRanges].offsetRanges + }.toSeq.sortBy { _._1 } + } + + private def getDirectKafkaStream(topic: String, mockRateController: Option[RateController]) = { + val batchIntervalMilliseconds = 100 + + val sparkConf = new SparkConf() + .setMaster("local[1]") + .setAppName(this.getClass.getSimpleName) + .set("spark.streaming.kafka.maxRatePerPartition", "100") + + // Setup the streaming context + ssc = new StreamingContext(sparkConf, Milliseconds(batchIntervalMilliseconds)) + + val kafkaParams = getKafkaParams("auto.offset.reset" -> "earliest") + val ekp = new JHashMap[String, Object](kafkaParams) + KafkaUtils.fixKafkaParams(ekp) + + val s = new DirectKafkaInputDStream[String, String]( + ssc, + preferredHosts, + new ConsumerStrategy[String, String] { + def executorKafkaParams = ekp + def onStart(currentOffsets: JMap[TopicPartition, JLong]): Consumer[String, String] = { + val consumer = new KafkaConsumer[String, String](kafkaParams) + val tps = List(new TopicPartition(topic, 0), new TopicPartition(topic, 1)) + consumer.assign(Arrays.asList(tps: _*)) + tps.foreach(tp => consumer.seek(tp, 0)) + consumer + } + } + ) { + override protected[streaming] val rateController = mockRateController + } + // manual start necessary because we arent consuming the stream, just checking its state + s.start() + s + } +} + +object DirectKafkaStreamSuite { + val total = new AtomicLong(-1L) + + class 
InputInfoCollector extends StreamingListener { + val numRecordsSubmitted = new AtomicLong(0L) + val numRecordsStarted = new AtomicLong(0L) + val numRecordsCompleted = new AtomicLong(0L) + + override def onBatchSubmitted(batchSubmitted: StreamingListenerBatchSubmitted): Unit = { + numRecordsSubmitted.addAndGet(batchSubmitted.batchInfo.numRecords) + } + + override def onBatchStarted(batchStarted: StreamingListenerBatchStarted): Unit = { + numRecordsStarted.addAndGet(batchStarted.batchInfo.numRecords) + } + + override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted): Unit = { + numRecordsCompleted.addAndGet(batchCompleted.batchInfo.numRecords) + } + } +} + +private[streaming] class ConstantEstimator(@volatile private var rate: Long) + extends RateEstimator { + + def updateRate(newRate: Long): Unit = { + rate = newRate + } + + def compute( + time: Long, + elements: Long, + processingDelay: Long, + schedulingDelay: Long): Option[Double] = Some(rate) +} + +private[streaming] class ConstantRateController(id: Int, estimator: RateEstimator, rate: Long) + extends RateController(id, estimator) { + override def publish(rate: Long): Unit = () + override def getLatestRate(): Long = rate +} diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala new file mode 100644 index 0000000000000..be373af0599cc --- /dev/null +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
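Editorial note: `ConstantEstimator` and `ConstantRateController` above stand in for Spark's real backpressure machinery so the suite can assert on `maxMessagesPerPartition`. In an application the equivalent knobs are plain configuration; a sketch with illustrative values (the master is expected to come from spark-submit):

```scala
import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Milliseconds, StreamingContext}

object RateLimitedContext {
  def create(): StreamingContext = {
    val conf = new SparkConf()
      .setAppName("rate-limited-kafka")
      // Hard per-partition cap on records per second read by the direct stream.
      .set("spark.streaming.kafka.maxRatePerPartition", "100")
      // Let the rate estimator adjust batch sizes below that cap based on load.
      .set("spark.streaming.backpressure.enabled", "true")

    new StreamingContext(conf, Milliseconds(500))
  }
}
```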
+ */ + +package org.apache.spark.streaming.kafka010 + +import java.{ util => ju } + +import scala.collection.JavaConverters._ +import scala.util.Random + +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.serialization.StringDeserializer +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark._ +import org.apache.spark.scheduler.ExecutorCacheTaskLocation + +class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll { + + private var kafkaTestUtils: KafkaTestUtils = _ + + private val sparkConf = new SparkConf().setMaster("local[4]") + .setAppName(this.getClass.getSimpleName) + private var sc: SparkContext = _ + + override def beforeAll { + sc = new SparkContext(sparkConf) + kafkaTestUtils = new KafkaTestUtils + kafkaTestUtils.setup() + } + + override def afterAll { + if (sc != null) { + sc.stop + sc = null + } + + if (kafkaTestUtils != null) { + kafkaTestUtils.teardown() + kafkaTestUtils = null + } + } + + private def getKafkaParams() = Map[String, Object]( + "bootstrap.servers" -> kafkaTestUtils.brokerAddress, + "key.deserializer" -> classOf[StringDeserializer], + "value.deserializer" -> classOf[StringDeserializer], + "group.id" -> s"test-consumer-${Random.nextInt}-${System.currentTimeMillis}" + ).asJava + + private val preferredHosts = LocationStrategies.PreferConsistent + + test("basic usage") { + val topic = s"topicbasic-${Random.nextInt}-${System.currentTimeMillis}" + kafkaTestUtils.createTopic(topic) + val messages = Array("the", "quick", "brown", "fox") + kafkaTestUtils.sendMessages(topic, messages) + + val kafkaParams = getKafkaParams() + + val offsetRanges = Array(OffsetRange(topic, 0, 0, messages.size)) + + val rdd = KafkaUtils.createRDD[String, String](sc, kafkaParams, offsetRanges, preferredHosts) + .map(_.value) + + val received = rdd.collect.toSet + assert(received === messages.toSet) + + // size-related method optimizations return sane results + assert(rdd.count === messages.size) + assert(rdd.countApprox(0).getFinalValue.mean === messages.size) + assert(!rdd.isEmpty) + assert(rdd.take(1).size === 1) + assert(rdd.take(1).head === messages.head) + assert(rdd.take(messages.size + 10).size === messages.size) + + val emptyRdd = KafkaUtils.createRDD[String, String]( + sc, kafkaParams, Array(OffsetRange(topic, 0, 0, 0)), preferredHosts) + + assert(emptyRdd.isEmpty) + + // invalid offset ranges throw exceptions + val badRanges = Array(OffsetRange(topic, 0, 0, messages.size + 1)) + intercept[SparkException] { + val result = KafkaUtils.createRDD[String, String](sc, kafkaParams, badRanges, preferredHosts) + .map(_.value) + .collect() + } + } + + test("iterator boundary conditions") { + // the idea is to find e.g. 
off-by-one errors between what kafka has available and the rdd + val topic = s"topicboundary-${Random.nextInt}-${System.currentTimeMillis}" + val sent = Map("a" -> 5, "b" -> 3, "c" -> 10) + kafkaTestUtils.createTopic(topic) + + val kafkaParams = getKafkaParams() + + // this is the "lots of messages" case + kafkaTestUtils.sendMessages(topic, sent) + var sentCount = sent.values.sum + + val rdd = KafkaUtils.createRDD[String, String](sc, kafkaParams, + Array(OffsetRange(topic, 0, 0, sentCount)), preferredHosts) + + val ranges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges + val rangeCount = ranges.map(o => o.untilOffset - o.fromOffset).sum + + assert(rangeCount === sentCount, "offset range didn't include all sent messages") + assert(rdd.map(_.offset).collect.sorted === (0 until sentCount).toArray, + "didn't get all sent messages") + + // this is the "0 messages" case + val rdd2 = KafkaUtils.createRDD[String, String](sc, kafkaParams, + Array(OffsetRange(topic, 0, sentCount, sentCount)), preferredHosts) + + // shouldn't get anything, since message is sent after rdd was defined + val sentOnlyOne = Map("d" -> 1) + + kafkaTestUtils.sendMessages(topic, sentOnlyOne) + + assert(rdd2.map(_.value).collect.size === 0, "got messages when there shouldn't be any") + + // this is the "exactly 1 message" case, namely the single message from sentOnlyOne above + val rdd3 = KafkaUtils.createRDD[String, String](sc, kafkaParams, + Array(OffsetRange(topic, 0, sentCount, sentCount + 1)), preferredHosts) + + // send lots of messages after rdd was defined, they shouldn't show up + kafkaTestUtils.sendMessages(topic, Map("extra" -> 22)) + + assert(rdd3.map(_.value).collect.head === sentOnlyOne.keys.head, + "didn't get exactly one message") + } + + test("executor sorting") { + val kafkaParams = new ju.HashMap[String, Object](getKafkaParams()) + kafkaParams.put("auto.offset.reset", "none") + val rdd = new KafkaRDD[String, String]( + sc, + kafkaParams, + Array(OffsetRange("unused", 0, 1, 2)), + ju.Collections.emptyMap[TopicPartition, String](), + true) + val a3 = ExecutorCacheTaskLocation("a", "3") + val a4 = ExecutorCacheTaskLocation("a", "4") + val b1 = ExecutorCacheTaskLocation("b", "1") + val b2 = ExecutorCacheTaskLocation("b", "2") + + val correct = Array(b2, b1, a4, a3) + + correct.permutations.foreach { p => + assert(p.sortWith(rdd.compareExecutors) === correct) + } + } +} diff --git a/external/kafka-assembly/pom.xml b/external/kafka-0-8-assembly/pom.xml similarity index 91% rename from external/kafka-assembly/pom.xml rename to external/kafka-0-8-assembly/pom.xml index 62818f5e8f434..d6edd5a2c107e 100644 --- a/external/kafka-assembly/pom.xml +++ b/external/kafka-0-8-assembly/pom.xml @@ -21,24 +21,24 @@ org.apache.spark spark-parent_2.11 - 2.0.0-SNAPSHOT + 2.0.3-SNAPSHOT ../../pom.xml org.apache.spark - spark-streaming-kafka-assembly_2.11 + spark-streaming-kafka-0-8-assembly_2.11 jar Spark Project External Kafka Assembly http://spark.apache.org/ - streaming-kafka-assembly + streaming-kafka-0-8-assembly org.apache.spark - spark-streaming-kafka_${scala.binary.version} + spark-streaming-kafka-0-8_${scala.binary.version} ${project.version} @@ -65,16 +65,6 @@ protobuf-java provided - - com.sun.jersey - jersey-server - provided - - - com.sun.jersey - jersey-core - provided - net.jpountz.lz4 lz4 diff --git a/external/kafka/pom.xml b/external/kafka-0-8/pom.xml similarity index 92% rename from external/kafka/pom.xml rename to external/kafka-0-8/pom.xml index 68d52e9339b3d..d595c5552c6a9 100644 --- a/external/kafka/pom.xml 
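Editorial note: the boundary tests above rely on `OffsetRange` being a half-open interval: `fromOffset` is included, `untilOffset` is excluded, so an equal pair denotes an empty range. A small illustration of the counting arithmetic the suite asserts; the topic name and message count are placeholders:

```scala
import org.apache.spark.streaming.kafka010.OffsetRange

object OffsetRangeArithmetic {
  val sent = 18L // messages assumed already present in partition 0

  val all = OffsetRange("topicboundary", 0, 0L, sent)             // every message sent so far
  val none = OffsetRange("topicboundary", 0, sent, sent)          // empty: from == until
  val justOne = OffsetRange("topicboundary", 0, sent, sent + 1)   // exactly the next message

  // The record count of a range is always untilOffset - fromOffset.
  assert(all.untilOffset - all.fromOffset == sent)
  assert(none.untilOffset - none.fromOffset == 0)
  assert(justOne.untilOffset - justOne.fromOffset == 1)
}
```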
+++ b/external/kafka-0-8/pom.xml @@ -21,17 +21,17 @@ org.apache.spark spark-parent_2.11 - 2.0.0-SNAPSHOT + 2.0.3-SNAPSHOT ../../pom.xml org.apache.spark - spark-streaming-kafka_2.11 + spark-streaming-kafka-0-8_2.11 - streaming-kafka + streaming-kafka-0-8 jar - Spark Project External Kafka + Spark Integration for Kafka 0.8 http://spark.apache.org/ @@ -88,7 +88,7 @@ org.apache.spark - spark-test-tags_${scala.binary.version} + spark-tags_${scala.binary.version} diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/Broker.scala b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/Broker.scala similarity index 100% rename from external/kafka/src/main/scala/org/apache/spark/streaming/kafka/Broker.scala rename to external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/Broker.scala diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala similarity index 100% rename from external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala rename to external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala similarity index 100% rename from external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala rename to external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala similarity index 100% rename from external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala rename to external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala similarity index 100% rename from external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala rename to external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala similarity index 100% rename from external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala rename to external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala similarity index 97% rename from external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala rename to external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala index d9d4240c056a5..abfd7aad4c5c6 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala +++ b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala @@ -35,6 +35,7 @@ import 
kafka.serializer.StringEncoder import kafka.server.{KafkaConfig, KafkaServer} import kafka.utils.{ZKStringSerializer, ZkUtils} import org.I0Itec.zkclient.ZkClient +import org.apache.commons.lang3.RandomUtils import org.apache.zookeeper.server.{NIOServerCnxnFactory, ZooKeeperServer} import org.apache.spark.SparkConf @@ -62,7 +63,8 @@ private[kafka] class KafkaTestUtils extends Logging { // Kafka broker related configurations private val brokerHost = "localhost" - private var brokerPort = 9092 + // 0.8.2 server doesn't have a boundPort method, so can't use 0 for a random port + private var brokerPort = RandomUtils.nextInt(1024, 65536) private var brokerConf: KafkaConfig = _ // Kafka broker server @@ -112,7 +114,7 @@ private[kafka] class KafkaTestUtils extends Logging { brokerConf = new KafkaConfig(brokerConfiguration) server = new KafkaServer(brokerConf) server.startup() - (server, port) + (server, brokerPort) }, new SparkConf(), "KafkaBroker") brokerReady = true diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala similarity index 99% rename from external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala rename to external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala index edaafb912c5c5..b17e198077949 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala +++ b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala @@ -18,7 +18,7 @@ package org.apache.spark.streaming.kafka import java.io.OutputStream -import java.lang.{Integer => JInt, Long => JLong} +import java.lang.{Integer => JInt, Long => JLong, Number => JNumber} import java.nio.charset.StandardCharsets import java.util.{List => JList, Map => JMap, Set => JSet} @@ -682,7 +682,7 @@ private[kafka] class KafkaUtilsPythonHelper { jssc: JavaStreamingContext, kafkaParams: JMap[String, String], topics: JSet[String], - fromOffsets: JMap[TopicAndPartition, JLong]): JavaDStream[(Array[Byte], Array[Byte])] = { + fromOffsets: JMap[TopicAndPartition, JNumber]): JavaDStream[(Array[Byte], Array[Byte])] = { val messageHandler = (mmd: MessageAndMetadata[Array[Byte], Array[Byte]]) => (mmd.key, mmd.message) new JavaDStream(createDirectStream(jssc, kafkaParams, topics, fromOffsets, messageHandler)) @@ -692,7 +692,7 @@ private[kafka] class KafkaUtilsPythonHelper { jssc: JavaStreamingContext, kafkaParams: JMap[String, String], topics: JSet[String], - fromOffsets: JMap[TopicAndPartition, JLong]): JavaDStream[Array[Byte]] = { + fromOffsets: JMap[TopicAndPartition, JNumber]): JavaDStream[Array[Byte]] = { val messageHandler = (mmd: MessageAndMetadata[Array[Byte], Array[Byte]]) => new PythonMessageAndMetadata(mmd.topic, mmd.partition, mmd.offset, mmd.key(), mmd.message()) val stream = createDirectStream(jssc, kafkaParams, topics, fromOffsets, messageHandler). 
@@ -704,7 +704,7 @@ private[kafka] class KafkaUtilsPythonHelper { jssc: JavaStreamingContext, kafkaParams: JMap[String, String], topics: JSet[String], - fromOffsets: JMap[TopicAndPartition, JLong], + fromOffsets: JMap[TopicAndPartition, JNumber], messageHandler: MessageAndMetadata[Array[Byte], Array[Byte]] => V): DStream[V] = { val currentFromOffsets = if (!fromOffsets.isEmpty) { diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala similarity index 100% rename from external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala rename to external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala similarity index 100% rename from external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala rename to external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/package-info.java b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/package-info.java similarity index 100% rename from external/kafka/src/main/scala/org/apache/spark/streaming/kafka/package-info.java rename to external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/package-info.java diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/package.scala b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/package.scala similarity index 100% rename from external/kafka/src/main/scala/org/apache/spark/streaming/kafka/package.scala rename to external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/package.scala diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java b/external/kafka-0-8/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java similarity index 100% rename from external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java rename to external/kafka-0-8/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java b/external/kafka-0-8/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java similarity index 100% rename from external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java rename to external/kafka-0-8/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java b/external/kafka-0-8/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java similarity index 88% rename from external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java rename to external/kafka-0-8/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java index 868df64e8c944..98fe38e826afb 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java +++ b/external/kafka-0-8/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java @@ -122,14 +122,23 @@ public void call(JavaPairRDD rdd) { ssc.start(); long 
startTime = System.currentTimeMillis(); - boolean sizeMatches = false; - while (!sizeMatches && System.currentTimeMillis() - startTime < 20000) { - sizeMatches = sent.size() == result.size(); + AssertionError lastError = null; + while (System.currentTimeMillis() - startTime < 20000) { + try { + Assert.assertEquals(sent.size(), result.size()); + for (Map.Entry e : sent.entrySet()) { + Assert.assertEquals(e.getValue().intValue(), result.get(e.getKey()).intValue()); + } + return; + } catch (AssertionError e) { + lastError = e; + } Thread.sleep(200); } - Assert.assertEquals(sent.size(), result.size()); - for (Map.Entry e : sent.entrySet()) { - Assert.assertEquals(e.getValue().intValue(), result.get(e.getKey()).intValue()); + if (lastError != null) { + throw lastError; + } else { + Assert.fail("timeout"); } } } diff --git a/external/kafka/src/test/resources/log4j.properties b/external/kafka-0-8/src/test/resources/log4j.properties similarity index 100% rename from external/kafka/src/test/resources/log4j.properties rename to external/kafka-0-8/src/test/resources/log4j.properties diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala b/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala similarity index 93% rename from external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala rename to external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala index f14ff6705fd97..ab1c5055a253f 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala +++ b/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala @@ -244,12 +244,9 @@ class DirectKafkaStreamSuite ) // Send data to Kafka and wait for it to be received - def sendDataAndWaitForReceive(data: Seq[Int]) { + def sendData(data: Seq[Int]) { val strings = data.map { _.toString} kafkaTestUtils.sendMessages(topic, strings.map { _ -> 1}.toMap) - eventually(timeout(10 seconds), interval(50 milliseconds)) { - assert(strings.forall { DirectKafkaStreamSuite.collectedData.contains }) - } } // Setup the streaming context @@ -264,31 +261,37 @@ class DirectKafkaStreamSuite } ssc.checkpoint(testDir.getAbsolutePath) - // This is to collect the raw data received from Kafka - kafkaStream.foreachRDD { (rdd: RDD[(String, String)], time: Time) => - val data = rdd.map { _._2 }.collect() - DirectKafkaStreamSuite.collectedData.addAll(Arrays.asList(data: _*)) - } - // This is ensure all the data is eventually receiving only once stateStream.foreachRDD { (rdd: RDD[(String, Int)]) => - rdd.collect().headOption.foreach { x => DirectKafkaStreamSuite.total = x._2 } + rdd.collect().headOption.foreach { x => + DirectKafkaStreamSuite.total.set(x._2) + } } ssc.start() - // Send some data and wait for them to be received + // Send some data for (i <- (1 to 10).grouped(4)) { - sendDataAndWaitForReceive(i) + sendData(i) + } + + eventually(timeout(10 seconds), interval(50 milliseconds)) { + assert(DirectKafkaStreamSuite.total.get === (1 to 10).sum) } + ssc.stop() + // Verify that offset ranges were generated - val offsetRangesBeforeStop = getOffsetRanges(kafkaStream) - assert(offsetRangesBeforeStop.size >= 1, "No offset ranges generated") + // Since "offsetRangesAfterStop" will be used to compare with "recoveredOffsetRanges", we should + // collect offset ranges after stopping. 
Otherwise, because new RDDs keep being generated before + // stopping, we may not be able to get the latest RDDs, then "recoveredOffsetRanges" will + // contain something not in "offsetRangesAfterStop". + val offsetRangesAfterStop = getOffsetRanges(kafkaStream) + assert(offsetRangesAfterStop.size >= 1, "No offset ranges generated") assert( - offsetRangesBeforeStop.head._2.forall { _.fromOffset === 0 }, + offsetRangesAfterStop.head._2.forall { _.fromOffset === 0 }, "starting offset not zero" ) - ssc.stop() + logInfo("====== RESTARTING ========") // Recover context from checkpoints @@ -296,21 +299,26 @@ class DirectKafkaStreamSuite val recoveredStream = ssc.graph.getInputStreams().head.asInstanceOf[DStream[(String, String)]] // Verify offset ranges have been recovered - val recoveredOffsetRanges = getOffsetRanges(recoveredStream) + val recoveredOffsetRanges = getOffsetRanges(recoveredStream).map { x => (x._1, x._2.toSet) } assert(recoveredOffsetRanges.size > 0, "No offset ranges recovered") - val earlierOffsetRangesAsSets = offsetRangesBeforeStop.map { x => (x._1, x._2.toSet) } + val earlierOffsetRanges = offsetRangesAfterStop.map { x => (x._1, x._2.toSet) } assert( recoveredOffsetRanges.forall { or => - earlierOffsetRangesAsSets.contains((or._1, or._2.toSet)) + earlierOffsetRanges.contains((or._1, or._2)) }, - "Recovered ranges are not the same as the ones generated" + "Recovered ranges are not the same as the ones generated\n" + + s"recoveredOffsetRanges: $recoveredOffsetRanges\n" + + s"earlierOffsetRanges: $earlierOffsetRanges" ) // Restart context, give more data and verify the total at the end // If the total is write that means each records has been received only once ssc.start() - sendDataAndWaitForReceive(11 to 20) + for (i <- (11 to 20).grouped(4)) { + sendData(i) + } + eventually(timeout(10 seconds), interval(50 milliseconds)) { - assert(DirectKafkaStreamSuite.total === (1 to 20).sum) + assert(DirectKafkaStreamSuite.total.get === (1 to 20).sum) } ssc.stop() } @@ -480,8 +488,7 @@ class DirectKafkaStreamSuite } object DirectKafkaStreamSuite { - val collectedData = new ConcurrentLinkedQueue[String]() - @volatile var total = -1L + val total = new AtomicLong(-1L) class InputInfoCollector extends StreamingListener { val numRecordsSubmitted = new AtomicLong(0L) diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala b/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala similarity index 100% rename from external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala rename to external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala b/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala similarity index 95% rename from external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala rename to external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala index 5e539c1d790cc..809699a739962 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala +++ b/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala @@ -53,13 +53,13 @@ class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll { } test("basic usage") { - val topic = s"topicbasic-${Random.nextInt}" + val topic = 
s"topicbasic-${Random.nextInt}-${System.currentTimeMillis}" kafkaTestUtils.createTopic(topic) val messages = Array("the", "quick", "brown", "fox") kafkaTestUtils.sendMessages(topic, messages) val kafkaParams = Map("metadata.broker.list" -> kafkaTestUtils.brokerAddress, - "group.id" -> s"test-consumer-${Random.nextInt}") + "group.id" -> s"test-consumer-${Random.nextInt}-${System.currentTimeMillis}") val offsetRanges = Array(OffsetRange(topic, 0, 0, messages.size)) @@ -92,12 +92,12 @@ class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll { test("iterator boundary conditions") { // the idea is to find e.g. off-by-one errors between what kafka has available and the rdd - val topic = s"topicboundary-${Random.nextInt}" + val topic = s"topicboundary-${Random.nextInt}-${System.currentTimeMillis}" val sent = Map("a" -> 5, "b" -> 3, "c" -> 10) kafkaTestUtils.createTopic(topic) val kafkaParams = Map("metadata.broker.list" -> kafkaTestUtils.brokerAddress, - "group.id" -> s"test-consumer-${Random.nextInt}") + "group.id" -> s"test-consumer-${Random.nextInt}-${System.currentTimeMillis}") val kc = new KafkaCluster(kafkaParams) diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala b/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala similarity index 100% rename from external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala rename to external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala b/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala similarity index 100% rename from external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala rename to external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala diff --git a/external/kinesis-asl-assembly/pom.xml b/external/kinesis-asl-assembly/pom.xml index d1c38c7ca5d69..8076c899f21a1 100644 --- a/external/kinesis-asl-assembly/pom.xml +++ b/external/kinesis-asl-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.0.0-SNAPSHOT + 2.0.3-SNAPSHOT ../../pom.xml @@ -63,16 +63,26 @@ com.google.protobuf protobuf-java + 2.6.1 + + + + org.glassfish.jersey.core + jersey-client provided - com.sun.jersey - jersey-server + org.glassfish.jersey.core + jersey-common provided - com.sun.jersey - jersey-core + org.glassfish.jersey.core + jersey-server provided @@ -132,6 +142,21 @@ target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes + + + org.apache.maven.plugins + maven-deploy-plugin + + true + + + + org.apache.maven.plugins + maven-install-plugin + + true + + org.apache.maven.plugins maven-shade-plugin @@ -142,6 +167,15 @@ *:* + + + com.google.protobuf + kinesis.protobuf + + com.google.protobuf.** + + + *:* diff --git a/external/kinesis-asl/pom.xml b/external/kinesis-asl/pom.xml index 935155eb5d362..5a59235753ed5 100644 --- a/external/kinesis-asl/pom.xml +++ b/external/kinesis-asl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.0.0-SNAPSHOT + 2.0.3-SNAPSHOT ../../pom.xml @@ -77,7 +77,7 @@ org.apache.spark - spark-test-tags_${scala.binary.version} + spark-tags_${scala.binary.version} diff --git a/external/spark-ganglia-lgpl/pom.xml b/external/spark-ganglia-lgpl/pom.xml index bfb92791de3d8..9a0dbb1d4e2f4 100644 --- 
a/external/spark-ganglia-lgpl/pom.xml +++ b/external/spark-ganglia-lgpl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.0.0-SNAPSHOT + 2.0.3-SNAPSHOT ../../pom.xml diff --git a/graphx/pom.xml b/graphx/pom.xml index 1813f383cdcba..9dc26601d8e2d 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.0.0-SNAPSHOT + 2.0.3-SNAPSHOT ../pom.xml @@ -72,7 +72,7 @@ org.apache.spark - spark-test-tags_${scala.binary.version} + spark-tags_${scala.binary.version} diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala index 31373a53cf933..43594573cf013 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala @@ -27,7 +27,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.util.collection.BitSet /** - * An class containing additional operations for subclasses of VertexPartitionBase that provide + * A class containing additional operations for subclasses of VertexPartitionBase that provide * implicit evidence of membership in the `VertexPartitionBaseOpsConstructor` typeclass (for * example, [[VertexPartition.VertexPartitionOpsConstructor]]). */ diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala index 80c6b6838faf5..2b3e5f98c4fe5 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala @@ -119,7 +119,7 @@ object GraphGenerators extends Logging { * A random graph generator using the R-MAT model, proposed in * "R-MAT: A Recursive Model for Graph Mining" by Chakrabarti et al. * - * See [[http://www.cs.cmu.edu/~christos/PUBLICATIONS/siam04.pdf]]. + * See http://www.cs.cmu.edu/~christos/PUBLICATIONS/siam04.pdf. 
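Since the R-MAT generator's doc comment and its pickQuadrant helper are touched just below, a brief sketch of the recursive quadrant choice the model rests on may help readers of the patch. This is an illustrative reimplementation under the usual R-MAT assumption that the probabilities a, b, c, d sum to 1, not the code in GraphGenerators:

```scala
import scala.util.Random

// Illustrative only: map a uniform draw to one of the four R-MAT quadrants
// according to probabilities a, b, c, d. The generator applies this choice
// recursively to narrow down a single (src, dst) cell of the adjacency matrix.
def chooseQuadrant(a: Double, b: Double, c: Double, d: Double, rng: Random): Int = {
  require(math.abs(a + b + c + d - 1.0) < 1e-9, "quadrant probabilities must sum to 1")
  val u = rng.nextDouble()
  if (u < a) 0                // top-left
  else if (u < a + b) 1       // top-right
  else if (u < a + b + c) 2   // bottom-left
  else 3                      // bottom-right
}
```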
*/ def rmatGraph(sc: SparkContext, requestedNumVertices: Int, numEdges: Int): Graph[Int, Int] = { // let N = requestedNumVertices @@ -209,7 +209,6 @@ object GraphGenerators extends Logging { } } - // TODO(crankshaw) turn result into an enum (or case class for pattern matching} private def pickQuadrant(a: Double, b: Double, c: Double, d: Double): Int = { if (a + b + c + d != 1.0) { throw new IllegalArgumentException("R-MAT probability parameters sum to " + (a + b + c + d) diff --git a/launcher/pom.xml b/launcher/pom.xml index ef731948826ef..8e8610c37247d 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.0.0-SNAPSHOT + 2.0.3-SNAPSHOT ../pom.xml @@ -65,7 +65,7 @@ org.apache.spark - spark-test-tags_${scala.binary.version} + spark-tags_${scala.binary.version} diff --git a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java index 91586aad7b709..62a22008d0d5d 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java +++ b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java @@ -325,7 +325,7 @@ static void addPermGenSizeOpt(List cmd) { return; } for (String arg : cmd) { - if (arg.startsWith("-XX:MaxPermSize=")) { + if (arg.contains("-XX:MaxPermSize=")) { return; } } diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java index e3413fd6652d8..ae43f563e8b46 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java +++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java @@ -337,6 +337,10 @@ public void close() throws IOException { } super.close(); if (handle != null) { + if (!handle.getState().isFinal()) { + LOG.log(Level.WARNING, "Lost connection to spark application."); + handle.setState(SparkAppHandle.State.LOST); + } handle.disconnect(); } } diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/SparkAppHandle.java index 625d02632114a..0aa7bd197d16f 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkAppHandle.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkAppHandle.java @@ -46,7 +46,9 @@ enum State { /** The application finished with a failed status. */ FAILED(true), /** The application was killed. */ - KILLED(true); + KILLED(true), + /** The Spark Submit JVM exited with a unknown status. */ + LOST(true); private final boolean isFinal; diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java index a083f05a2a9f7..08873f5811238 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java @@ -66,6 +66,13 @@ public class SparkLauncher { /** Logger name to use when launching a child process. */ public static final String CHILD_PROCESS_LOGGER_NAME = "spark.launcher.childProcLoggerName"; + /** + * A special value for the resource that tells Spark to not try to process the app resource as a + * file. This is useful when the class being executed is added to the application using other + * means - for example, by adding jars using the package download feature. 
+ */ + public static final String NO_RESOURCE = "spark-internal"; + /** * Maximum time (in ms) to wait for a child process to connect back to the launcher server * when using @link{#start()}. diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java index 6941ca903cd0a..b3ccc4805f2c5 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java @@ -83,13 +83,13 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { static { specialClasses.put("org.apache.spark.repl.Main", "spark-shell"); specialClasses.put("org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver", - "spark-internal"); + SparkLauncher.NO_RESOURCE); specialClasses.put("org.apache.spark.sql.hive.thriftserver.HiveThriftServer2", - "spark-internal"); + SparkLauncher.NO_RESOURCE); } final List sparkArgs; - private final boolean printInfo; + private final boolean isAppResourceReq; private final boolean isExample; /** @@ -101,42 +101,51 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { SparkSubmitCommandBuilder() { this.sparkArgs = new ArrayList<>(); - this.printInfo = false; + this.isAppResourceReq = true; this.isExample = false; } SparkSubmitCommandBuilder(List args) { this.allowsMixedArguments = false; - + this.sparkArgs = new ArrayList<>(); boolean isExample = false; List submitArgs = args; - if (args.size() > 0 && args.get(0).equals(PYSPARK_SHELL)) { - this.allowsMixedArguments = true; - appResource = PYSPARK_SHELL_RESOURCE; - submitArgs = args.subList(1, args.size()); - } else if (args.size() > 0 && args.get(0).equals(SPARKR_SHELL)) { - this.allowsMixedArguments = true; - appResource = SPARKR_SHELL_RESOURCE; - submitArgs = args.subList(1, args.size()); - } else if (args.size() > 0 && args.get(0).equals(RUN_EXAMPLE)) { - isExample = true; - submitArgs = args.subList(1, args.size()); - } - this.sparkArgs = new ArrayList<>(); - this.isExample = isExample; + if (args.size() > 0) { + switch (args.get(0)) { + case PYSPARK_SHELL: + this.allowsMixedArguments = true; + appResource = PYSPARK_SHELL; + submitArgs = args.subList(1, args.size()); + break; + + case SPARKR_SHELL: + this.allowsMixedArguments = true; + appResource = SPARKR_SHELL; + submitArgs = args.subList(1, args.size()); + break; + + case RUN_EXAMPLE: + isExample = true; + submitArgs = args.subList(1, args.size()); + } - OptionParser parser = new OptionParser(); - parser.parse(submitArgs); - this.printInfo = parser.infoRequested; + this.isExample = isExample; + OptionParser parser = new OptionParser(); + parser.parse(submitArgs); + this.isAppResourceReq = parser.isAppResourceReq; + } else { + this.isExample = isExample; + this.isAppResourceReq = false; + } } @Override public List buildCommand(Map env) throws IOException, IllegalArgumentException { - if (PYSPARK_SHELL_RESOURCE.equals(appResource) && !printInfo) { + if (PYSPARK_SHELL.equals(appResource) && isAppResourceReq) { return buildPySparkShellCommand(env); - } else if (SPARKR_SHELL_RESOURCE.equals(appResource) && !printInfo) { + } else if (SPARKR_SHELL.equals(appResource) && isAppResourceReq) { return buildSparkRCommand(env); } else { return buildSparkSubmitCommand(env); @@ -147,6 +156,10 @@ List buildSparkSubmitArgs() { List args = new ArrayList<>(); SparkSubmitOptionParser parser = new SparkSubmitOptionParser(); + if (!allowsMixedArguments && 
isAppResourceReq) { + checkArgument(appResource != null, "Missing application resource."); + } + if (verbose) { args.add(parser.VERBOSE); } @@ -195,7 +208,7 @@ List buildSparkSubmitArgs() { args.add(join(",", pyFiles)); } - if (!printInfo) { + if (isAppResourceReq) { checkArgument(!isExample || mainClass != null, "Missing example class name."); } if (mainClass != null) { @@ -278,6 +291,7 @@ private List buildPySparkShellCommand(Map env) throws IO // When launching the pyspark shell, the spark-submit arguments should be stored in the // PYSPARK_SUBMIT_ARGS env variable. + appResource = PYSPARK_SHELL_RESOURCE; constructEnvVarArgs(env, "PYSPARK_SUBMIT_ARGS"); // The executable is the PYSPARK_DRIVER_PYTHON env variable set by the pyspark script, @@ -301,6 +315,7 @@ private List buildSparkRCommand(Map env) throws IOExcept } // When launching the SparkR shell, store the spark-submit arguments in the SPARKR_SUBMIT_ARGS // env variable. + appResource = SPARKR_SHELL_RESOURCE; constructEnvVarArgs(env, "SPARKR_SUBMIT_ARGS"); // Set shell.R as R_PROFILE_USER to load the SparkR package when the shell comes up. @@ -373,7 +388,7 @@ private List findExamplesJars() { private class OptionParser extends SparkSubmitOptionParser { - boolean infoRequested = false; + boolean isAppResourceReq = true; @Override protected boolean handle(String opt, String value) { @@ -405,11 +420,15 @@ protected boolean handle(String opt, String value) { allowsMixedArguments = true; appResource = specialClasses.get(value); } + } else if (opt.equals(KILL_SUBMISSION) || opt.equals(STATUS)) { + isAppResourceReq = false; + sparkArgs.add(opt); + sparkArgs.add(value); } else if (opt.equals(HELP) || opt.equals(USAGE_ERROR)) { - infoRequested = true; + isAppResourceReq = false; sparkArgs.add(opt); } else if (opt.equals(VERSION)) { - infoRequested = true; + isAppResourceReq = false; sparkArgs.add(opt); } else { sparkArgs.add(opt); @@ -435,22 +454,19 @@ protected boolean handleUnknown(String opt) { className = EXAMPLE_CLASS_PREFIX + className; } mainClass = className; - appResource = "spark-internal"; + appResource = SparkLauncher.NO_RESOURCE; return false; } else { checkArgument(!opt.startsWith("-"), "Unrecognized option: %s", opt); - sparkArgs.add(opt); + checkState(appResource == null, "Found unrecognized argument but resource is already set."); + appResource = opt; return false; } } @Override protected void handleExtraArgs(List extra) { - if (isExample) { - appArgs.addAll(extra); - } else { - sparkArgs.addAll(extra); - } + appArgs.addAll(extra); } } diff --git a/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java b/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java index 4fafc43ef293b..caeeea5ec6dd5 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java @@ -99,12 +99,48 @@ public void testJavaMajorVersion() { assertEquals(10, javaMajorVersion("10")); } - private void testOpt(String opts, List expected) { + @Test + public void testAddPermGenSizeOpt() { + List cmd = new ArrayList<>(); + + if (javaMajorVersion(System.getProperty("java.version")) > 7) { + // Does nothing in Java 8 + addPermGenSizeOpt(cmd); + assertEquals(0, cmd.size()); + cmd.clear(); + + } else { + addPermGenSizeOpt(cmd); + assertEquals(1, cmd.size()); + assertTrue(cmd.get(0).startsWith("-XX:MaxPermSize=")); + cmd.clear(); + + cmd.add("foo"); + addPermGenSizeOpt(cmd); + assertEquals(2, cmd.size()); + 
assertTrue(cmd.get(1).startsWith("-XX:MaxPermSize=")); + cmd.clear(); + + cmd.add("-XX:MaxPermSize=512m"); + addPermGenSizeOpt(cmd); + assertEquals(1, cmd.size()); + assertEquals("-XX:MaxPermSize=512m", cmd.get(0)); + cmd.clear(); + + cmd.add("'-XX:MaxPermSize=512m'"); + addPermGenSizeOpt(cmd); + assertEquals(1, cmd.size()); + assertEquals("'-XX:MaxPermSize=512m'", cmd.get(0)); + cmd.clear(); + } + } + + private static void testOpt(String opts, List expected) { assertEquals(String.format("test string failed to parse: [[ %s ]]", opts), expected, parseOptionString(opts)); } - private void testInvalidOpt(String opts) { + private static void testInvalidOpt(String opts) { try { parseOptionString(opts); fail("Expected exception for invalid option string."); diff --git a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java index bfe1fcc87fe35..12f1a0ce2d1b4 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java @@ -152,6 +152,37 @@ public void testTimeout() throws Exception { } } + @Test + public void testSparkSubmitVmShutsDown() throws Exception { + ChildProcAppHandle handle = LauncherServer.newAppHandle(); + TestClient client = null; + final Semaphore semaphore = new Semaphore(0); + try { + Socket s = new Socket(InetAddress.getLoopbackAddress(), + LauncherServer.getServerInstance().getPort()); + handle.addListener(new SparkAppHandle.Listener() { + public void stateChanged(SparkAppHandle handle) { + semaphore.release(); + } + public void infoChanged(SparkAppHandle handle) { + semaphore.release(); + } + }); + client = new TestClient(s); + client.send(new Hello(handle.getSecret(), "1.4.0")); + assertTrue(semaphore.tryAcquire(30, TimeUnit.SECONDS)); + // Make sure the server matched the client to the handle. 
+ assertNotNull(handle.getConnection()); + close(client); + assertTrue(semaphore.tryAcquire(30, TimeUnit.SECONDS)); + assertEquals(SparkAppHandle.State.LOST, handle.getState()); + } finally { + kill(handle); + close(client); + client.clientThread.join(); + } + } + private void kill(SparkAppHandle handle) { if (handle != null) { handle.kill(); diff --git a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java index c7e8b2e03a9fa..16e5a22401ca8 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java @@ -58,6 +58,26 @@ public void testClusterCmdBuilder() throws Exception { testCmdBuilder(false, false); } + @Test + public void testCliHelpAndNoArg() throws Exception { + List helpArgs = Arrays.asList(parser.HELP); + Map env = new HashMap<>(); + List cmd = buildCommand(helpArgs, env); + assertTrue("--help should be contained in the final cmd.", cmd.contains(parser.HELP)); + + List sparkEmptyArgs = Collections.emptyList(); + cmd = buildCommand(sparkEmptyArgs, env); + assertTrue( + "org.apache.spark.deploy.SparkSubmit should be contained in the final cmd of empty input.", + cmd.contains("org.apache.spark.deploy.SparkSubmit")); + } + + @Test + public void testCliKillAndStatus() throws Exception { + testCLIOpts(parser.STATUS); + testCLIOpts(parser.KILL_SUBMISSION); + } + @Test public void testCliParser() throws Exception { List sparkSubmitArgs = Arrays.asList( @@ -72,7 +92,8 @@ public void testCliParser() throws Exception { parser.CONF, "spark.randomOption=foo", parser.CONF, - SparkLauncher.DRIVER_EXTRA_LIBRARY_PATH + "=/driverLibPath"); + SparkLauncher.DRIVER_EXTRA_LIBRARY_PATH + "=/driverLibPath", + SparkLauncher.NO_RESOURCE); Map env = new HashMap<>(); List cmd = buildCommand(sparkSubmitArgs, env); @@ -109,7 +130,8 @@ public void testAlternateSyntaxParsing() throws Exception { List sparkSubmitArgs = Arrays.asList( parser.CLASS + "=org.my.Class", parser.MASTER + "=foo", - parser.DEPLOY_MODE + "=bar"); + parser.DEPLOY_MODE + "=bar", + SparkLauncher.NO_RESOURCE); List cmd = newCommandBuilder(sparkSubmitArgs).buildSparkSubmitArgs(); assertEquals("org.my.Class", findArgValue(cmd, parser.CLASS)); @@ -168,6 +190,11 @@ public void testExamplesRunner() throws Exception { assertEquals("42", cmd.get(cmd.size() - 1)); } + @Test(expected = IllegalArgumentException.class) + public void testMissingAppResource() { + new SparkSubmitCommandBuilder().buildSparkSubmitArgs(); + } + private void testCmdBuilder(boolean isDriver, boolean useDefaultPropertyFile) throws Exception { String deployMode = isDriver ? 
"client" : "cluster"; @@ -305,4 +332,12 @@ private List buildCommand(List args, Map env) th return newCommandBuilder(args).buildCommand(env); } + private void testCLIOpts(String opt) throws Exception { + List helpArgs = Arrays.asList(opt, "driver-20160531171222-0000"); + Map env = new HashMap<>(); + List cmd = buildCommand(helpArgs, env); + assertTrue(opt + " should be contained in the final cmd.", + cmd.contains(opt)); + } + } diff --git a/licenses/LICENSE-modernizr.txt b/licenses/LICENSE-modernizr.txt new file mode 100644 index 0000000000000..2bf24b9b9f848 --- /dev/null +++ b/licenses/LICENSE-modernizr.txt @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. \ No newline at end of file diff --git a/licenses/LICENSE-postgresql.txt b/licenses/LICENSE-postgresql.txt new file mode 100644 index 0000000000000..515bf9af4d432 --- /dev/null +++ b/licenses/LICENSE-postgresql.txt @@ -0,0 +1,24 @@ +PostgreSQL Database Management System +(formerly known as Postgres, then as Postgres95) + +Portions Copyright (c) 1996-2010, PostgreSQL Global Development Group + +Portions Copyright (c) 1994, The Regents of the University of California + +Permission to use, copy, modify, and distribute this software and its +documentation for any purpose, without fee, and without a written agreement +is hereby granted, provided that the above copyright notice and this +paragraph and the following two paragraphs appear in all copies. + +IN NO EVENT SHALL THE UNIVERSITY OF CALIFORNIA BE LIABLE TO ANY PARTY FOR +DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING +LOST PROFITS, ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS +DOCUMENTATION, EVEN IF THE UNIVERSITY OF CALIFORNIA HAS BEEN ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. + +THE UNIVERSITY OF CALIFORNIA SPECIFICALLY DISCLAIMS ANY WARRANTIES, +INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY +AND FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE PROVIDED HEREUNDER IS +ON AN "AS IS" BASIS, AND THE UNIVERSITY OF CALIFORNIA HAS NO OBLIGATIONS TO +PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS. 
+ diff --git a/mllib-local/pom.xml b/mllib-local/pom.xml index 68f15dd905028..0c0fbaebe8a81 100644 --- a/mllib-local/pom.xml +++ b/mllib-local/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.0.0-SNAPSHOT + 2.0.3-SNAPSHOT ../pom.xml @@ -53,6 +53,10 @@ mockito-core test + + org.apache.spark + spark-tags_${scala.binary.version} + diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala index 41b0c6c89a647..4ca19f3387f07 100644 --- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala +++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala @@ -638,12 +638,16 @@ private[spark] object BLAS extends Serializable { val indEnd = Arows(rowCounter + 1) var sum = 0.0 var k = 0 - while (k < xNnz && i < indEnd) { + while (i < indEnd && k < xNnz) { if (xIndices(k) == Acols(i)) { sum += Avals(i) * xValues(k) + k += 1 + i += 1 + } else if (xIndices(k) < Acols(i)) { + k += 1 + } else { i += 1 } - k += 1 } yValues(rowCounter) = sum * alpha + beta * yValues(rowCounter) rowCounter += 1 diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Matrices.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Matrices.scala index 8204b5af02cff..f1ecc65af1105 100644 --- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Matrices.scala +++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Matrices.scala @@ -24,21 +24,28 @@ import scala.collection.mutable.{ArrayBuffer, ArrayBuilder => MArrayBuilder, Has import breeze.linalg.{CSCMatrix => BSM, DenseMatrix => BDM, Matrix => BM} import com.github.fommil.netlib.BLAS.{getInstance => blas} +import org.apache.spark.annotation.Since + /** * Trait for a local matrix. */ +@Since("2.0.0") sealed trait Matrix extends Serializable { /** Number of rows. */ + @Since("2.0.0") def numRows: Int /** Number of columns. */ + @Since("2.0.0") def numCols: Int /** Flag that keeps track whether the matrix is transposed or not. False by default. */ + @Since("2.0.0") val isTransposed: Boolean = false /** Converts to a dense array in column major. */ + @Since("2.0.0") def toArray: Array[Double] = { val newArray = new Array[Double](numRows * numCols) foreachActive { (i, j, v) => @@ -51,18 +58,21 @@ sealed trait Matrix extends Serializable { * Returns an iterator of column vectors. * This operation could be expensive, depending on the underlying storage. */ + @Since("2.0.0") def colIter: Iterator[Vector] /** * Returns an iterator of row vectors. * This operation could be expensive, depending on the underlying storage. */ + @Since("2.0.0") def rowIter: Iterator[Vector] = this.transpose.colIter /** Converts to a breeze matrix. */ - private[ml] def toBreeze: BM[Double] + private[ml] def asBreeze: BM[Double] /** Gets the (i, j)-th element. */ + @Since("2.0.0") def apply(i: Int, j: Int): Double /** Return the index for the (i, j)-th element in the backing array. */ @@ -72,12 +82,15 @@ sealed trait Matrix extends Serializable { private[ml] def update(i: Int, j: Int, v: Double): Unit /** Get a deep copy of the matrix. */ + @Since("2.0.0") def copy: Matrix /** Transpose the Matrix. Returns a new `Matrix` instance sharing the same underlying data. */ + @Since("2.0.0") def transpose: Matrix /** Convenience method for `Matrix`-`DenseMatrix` multiplication. 
*/ + @Since("2.0.0") def multiply(y: DenseMatrix): DenseMatrix = { val C: DenseMatrix = DenseMatrix.zeros(numRows, y.numCols) BLAS.gemm(1.0, this, y, 0.0, C) @@ -85,11 +98,13 @@ sealed trait Matrix extends Serializable { } /** Convenience method for `Matrix`-`DenseVector` multiplication. For binary compatibility. */ + @Since("2.0.0") def multiply(y: DenseVector): DenseVector = { multiply(y.asInstanceOf[Vector]) } /** Convenience method for `Matrix`-`Vector` multiplication. */ + @Since("2.0.0") def multiply(y: Vector): DenseVector = { val output = new DenseVector(new Array[Double](numRows)) BLAS.gemv(1.0, this, y, 0.0, output) @@ -97,10 +112,11 @@ sealed trait Matrix extends Serializable { } /** A human readable representation of the matrix */ - override def toString: String = toBreeze.toString() + override def toString: String = asBreeze.toString() /** A human readable representation of the matrix with maximum lines and width */ - def toString(maxLines: Int, maxLineWidth: Int): String = toBreeze.toString(maxLines, maxLineWidth) + @Since("2.0.0") + def toString(maxLines: Int, maxLineWidth: Int): String = asBreeze.toString(maxLines, maxLineWidth) /** * Map the values of this matrix using a function. Generates a new matrix. Performs the @@ -129,11 +145,13 @@ sealed trait Matrix extends Serializable { /** * Find the number of non-zero active values. */ + @Since("2.0.0") def numNonzeros: Int /** * Find the number of values stored explicitly. These values can be zero as well. */ + @Since("2.0.0") def numActives: Int } @@ -154,10 +172,11 @@ sealed trait Matrix extends Serializable { * @param isTransposed whether the matrix is transposed. If true, `values` stores the matrix in * row major. */ -class DenseMatrix ( - val numRows: Int, - val numCols: Int, - val values: Array[Double], +@Since("2.0.0") +class DenseMatrix @Since("2.0.0") ( + @Since("2.0.0") val numRows: Int, + @Since("2.0.0") val numCols: Int, + @Since("2.0.0") val values: Array[Double], override val isTransposed: Boolean) extends Matrix { require(values.length == numRows * numCols, "The number of values supplied doesn't match the " + @@ -178,11 +197,12 @@ class DenseMatrix ( * @param numCols number of columns * @param values matrix entries in column major */ + @Since("2.0.0") def this(numRows: Int, numCols: Int, values: Array[Double]) = this(numRows, numCols, values, false) override def equals(o: Any): Boolean = o match { - case m: Matrix => toBreeze == m.toBreeze + case m: Matrix => asBreeze == m.asBreeze case _ => false } @@ -190,7 +210,7 @@ class DenseMatrix ( Seq(numRows, numCols, toArray).## } - private[ml] def toBreeze: BM[Double] = { + private[ml] def asBreeze: BM[Double] = { if (!isTransposed) { new BDM[Double](numRows, numCols, values) } else { @@ -266,6 +286,7 @@ class DenseMatrix ( * Generate a `SparseMatrix` from the given `DenseMatrix`. The new matrix will have isTransposed * set to false. */ + @Since("2.0.0") def toSparse: SparseMatrix = { val spVals: MArrayBuilder[Double] = new MArrayBuilder.ofDouble val colPtrs: Array[Int] = new Array[Int](numCols + 1) @@ -307,6 +328,7 @@ class DenseMatrix ( /** * Factory methods for [[org.apache.spark.ml.linalg.DenseMatrix]]. 
*/ +@Since("2.0.0") object DenseMatrix { /** @@ -315,6 +337,7 @@ object DenseMatrix { * @param numCols number of columns of the matrix * @return `DenseMatrix` with size `numRows` x `numCols` and values of zeros */ + @Since("2.0.0") def zeros(numRows: Int, numCols: Int): DenseMatrix = { require(numRows.toLong * numCols <= Int.MaxValue, s"$numRows x $numCols dense matrix is too large to allocate") @@ -327,6 +350,7 @@ object DenseMatrix { * @param numCols number of columns of the matrix * @return `DenseMatrix` with size `numRows` x `numCols` and values of ones */ + @Since("2.0.0") def ones(numRows: Int, numCols: Int): DenseMatrix = { require(numRows.toLong * numCols <= Int.MaxValue, s"$numRows x $numCols dense matrix is too large to allocate") @@ -338,6 +362,7 @@ object DenseMatrix { * @param n number of rows and columns of the matrix * @return `DenseMatrix` with size `n` x `n` and values of ones on the diagonal */ + @Since("2.0.0") def eye(n: Int): DenseMatrix = { val identity = DenseMatrix.zeros(n, n) var i = 0 @@ -355,6 +380,7 @@ object DenseMatrix { * @param rng a random number generator * @return `DenseMatrix` with size `numRows` x `numCols` and values in U(0, 1) */ + @Since("2.0.0") def rand(numRows: Int, numCols: Int, rng: Random): DenseMatrix = { require(numRows.toLong * numCols <= Int.MaxValue, s"$numRows x $numCols dense matrix is too large to allocate") @@ -368,6 +394,7 @@ object DenseMatrix { * @param rng a random number generator * @return `DenseMatrix` with size `numRows` x `numCols` and values in N(0, 1) */ + @Since("2.0.0") def randn(numRows: Int, numCols: Int, rng: Random): DenseMatrix = { require(numRows.toLong * numCols <= Int.MaxValue, s"$numRows x $numCols dense matrix is too large to allocate") @@ -380,6 +407,7 @@ object DenseMatrix { * @return Square `DenseMatrix` with size `values.length` x `values.length` and `values` * on the diagonal */ + @Since("2.0.0") def diag(vector: Vector): DenseMatrix = { val n = vector.size val matrix = DenseMatrix.zeros(n, n) @@ -415,20 +443,24 @@ object DenseMatrix { * Compressed Sparse Row (CSR) format, where `colPtrs` behaves as rowPtrs, * and `rowIndices` behave as colIndices, and `values` are stored in row major. */ -class SparseMatrix ( - val numRows: Int, - val numCols: Int, - val colPtrs: Array[Int], - val rowIndices: Array[Int], - val values: Array[Double], +@Since("2.0.0") +class SparseMatrix @Since("2.0.0") ( + @Since("2.0.0") val numRows: Int, + @Since("2.0.0") val numCols: Int, + @Since("2.0.0") val colPtrs: Array[Int], + @Since("2.0.0") val rowIndices: Array[Int], + @Since("2.0.0") val values: Array[Double], override val isTransposed: Boolean) extends Matrix { require(values.length == rowIndices.length, "The number of row indices and values don't match! " + s"values.length: ${values.length}, rowIndices.length: ${rowIndices.length}") - // The Or statement is for the case when the matrix is transposed - require(colPtrs.length == numCols + 1 || colPtrs.length == numRows + 1, "The length of the " + - "column indices should be the number of columns + 1. Currently, colPointers.length: " + - s"${colPtrs.length}, numCols: $numCols") + if (isTransposed) { + require(colPtrs.length == numRows + 1, + s"Expecting ${numRows + 1} colPtrs when numRows = $numRows but got ${colPtrs.length}") + } else { + require(colPtrs.length == numCols + 1, + s"Expecting ${numCols + 1} colPtrs when numCols = $numCols but got ${colPtrs.length}") + } require(values.length == colPtrs.last, "The last value of colPtrs must equal the number of " + s"elements. 
values.length: ${values.length}, colPtrs.last: ${colPtrs.last}") @@ -451,6 +483,7 @@ class SparseMatrix ( * order for each column * @param values non-zero matrix entries in column major */ + @Since("2.0.0") def this( numRows: Int, numCols: Int, @@ -458,14 +491,14 @@ class SparseMatrix ( rowIndices: Array[Int], values: Array[Double]) = this(numRows, numCols, colPtrs, rowIndices, values, false) - override def hashCode(): Int = toBreeze.hashCode() + override def hashCode(): Int = asBreeze.hashCode() override def equals(o: Any): Boolean = o match { - case m: Matrix => toBreeze == m.toBreeze + case m: Matrix => asBreeze == m.asBreeze case _ => false } - private[ml] def toBreeze: BM[Double] = { + private[ml] def asBreeze: BM[Double] = { if (!isTransposed) { new BSM[Double](values, numRows, numCols, colPtrs, rowIndices) } else { @@ -550,6 +583,7 @@ class SparseMatrix ( * Generate a `DenseMatrix` from the given `SparseMatrix`. The new matrix will have isTransposed * set to false. */ + @Since("2.0.0") def toDense: DenseMatrix = { new DenseMatrix(numRows, numCols, toArray) } @@ -594,6 +628,7 @@ class SparseMatrix ( /** * Factory methods for [[org.apache.spark.ml.linalg.SparseMatrix]]. */ +@Since("2.0.0") object SparseMatrix { /** @@ -605,6 +640,7 @@ object SparseMatrix { * @param entries Array of (i, j, value) tuples * @return The corresponding `SparseMatrix` */ + @Since("2.0.0") def fromCOO(numRows: Int, numCols: Int, entries: Iterable[(Int, Int, Double)]): SparseMatrix = { val sortedEntries = entries.toSeq.sortBy(v => (v._2, v._1)) val numEntries = sortedEntries.size @@ -653,6 +689,7 @@ object SparseMatrix { * @param n number of rows and columns of the matrix * @return `SparseMatrix` with size `n` x `n` and values of ones on the diagonal */ + @Since("2.0.0") def speye(n: Int): SparseMatrix = { new SparseMatrix(n, n, (0 to n).toArray, (0 until n).toArray, Array.fill(n)(1.0)) } @@ -722,6 +759,7 @@ object SparseMatrix { * @param rng a random number generator * @return `SparseMatrix` with size `numRows` x `numCols` and values in U(0, 1) */ + @Since("2.0.0") def sprand(numRows: Int, numCols: Int, density: Double, rng: Random): SparseMatrix = { val mat = genRandMatrix(numRows, numCols, density, rng) mat.update(i => rng.nextDouble()) @@ -735,6 +773,7 @@ object SparseMatrix { * @param rng a random number generator * @return `SparseMatrix` with size `numRows` x `numCols` and values in N(0, 1) */ + @Since("2.0.0") def sprandn(numRows: Int, numCols: Int, density: Double, rng: Random): SparseMatrix = { val mat = genRandMatrix(numRows, numCols, density, rng) mat.update(i => rng.nextGaussian()) @@ -746,6 +785,7 @@ object SparseMatrix { * @return Square `SparseMatrix` with size `values.length` x `values.length` and non-zero * `values` on the diagonal */ + @Since("2.0.0") def spdiag(vector: Vector): SparseMatrix = { val n = vector.size vector match { @@ -762,6 +802,7 @@ object SparseMatrix { /** * Factory methods for [[org.apache.spark.ml.linalg.Matrix]]. 
*/ +@Since("2.0.0") object Matrices { /** @@ -771,6 +812,7 @@ object Matrices { * @param numCols number of columns * @param values matrix entries in column major */ + @Since("2.0.0") def dense(numRows: Int, numCols: Int, values: Array[Double]): Matrix = { new DenseMatrix(numRows, numCols, values) } @@ -784,6 +826,7 @@ object Matrices { * @param rowIndices the row index of the entry * @param values non-zero matrix entries in column major */ + @Since("2.0.0") def sparse( numRows: Int, numCols: Int, @@ -825,6 +868,7 @@ object Matrices { * @param numCols number of columns of the matrix * @return `Matrix` with size `numRows` x `numCols` and values of zeros */ + @Since("2.0.0") def zeros(numRows: Int, numCols: Int): Matrix = DenseMatrix.zeros(numRows, numCols) /** @@ -833,6 +877,7 @@ object Matrices { * @param numCols number of columns of the matrix * @return `Matrix` with size `numRows` x `numCols` and values of ones */ + @Since("2.0.0") def ones(numRows: Int, numCols: Int): Matrix = DenseMatrix.ones(numRows, numCols) /** @@ -840,6 +885,7 @@ object Matrices { * @param n number of rows and columns of the matrix * @return `Matrix` with size `n` x `n` and values of ones on the diagonal */ + @Since("2.0.0") def eye(n: Int): Matrix = DenseMatrix.eye(n) /** @@ -847,6 +893,7 @@ object Matrices { * @param n number of rows and columns of the matrix * @return `Matrix` with size `n` x `n` and values of ones on the diagonal */ + @Since("2.0.0") def speye(n: Int): Matrix = SparseMatrix.speye(n) /** @@ -856,6 +903,7 @@ object Matrices { * @param rng a random number generator * @return `Matrix` with size `numRows` x `numCols` and values in U(0, 1) */ + @Since("2.0.0") def rand(numRows: Int, numCols: Int, rng: Random): Matrix = DenseMatrix.rand(numRows, numCols, rng) @@ -867,6 +915,7 @@ object Matrices { * @param rng a random number generator * @return `Matrix` with size `numRows` x `numCols` and values in U(0, 1) */ + @Since("2.0.0") def sprand(numRows: Int, numCols: Int, density: Double, rng: Random): Matrix = SparseMatrix.sprand(numRows, numCols, density, rng) @@ -877,6 +926,7 @@ object Matrices { * @param rng a random number generator * @return `Matrix` with size `numRows` x `numCols` and values in N(0, 1) */ + @Since("2.0.0") def randn(numRows: Int, numCols: Int, rng: Random): Matrix = DenseMatrix.randn(numRows, numCols, rng) @@ -888,6 +938,7 @@ object Matrices { * @param rng a random number generator * @return `Matrix` with size `numRows` x `numCols` and values in N(0, 1) */ + @Since("2.0.0") def sprandn(numRows: Int, numCols: Int, density: Double, rng: Random): Matrix = SparseMatrix.sprandn(numRows, numCols, density, rng) @@ -897,6 +948,7 @@ object Matrices { * @return Square `Matrix` with size `values.length` x `values.length` and `values` * on the diagonal */ + @Since("2.0.0") def diag(vector: Vector): Matrix = DenseMatrix.diag(vector) /** @@ -906,6 +958,7 @@ object Matrices { * @param matrices array of matrices * @return a single `Matrix` composed of the matrices that were horizontally concatenated */ + @Since("2.0.0") def horzcat(matrices: Array[Matrix]): Matrix = { if (matrices.isEmpty) { return new DenseMatrix(0, 0, Array[Double]()) @@ -964,6 +1017,7 @@ object Matrices { * @param matrices array of matrices * @return a single `Matrix` composed of the matrices that were vertically concatenated */ + @Since("2.0.0") def vertcat(matrices: Array[Matrix]): Matrix = { if (matrices.isEmpty) { return new DenseMatrix(0, 0, Array[Double]()) diff --git 
a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala index c0d112d2c53d8..0659324aad1fa 100644 --- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala +++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala @@ -25,21 +25,26 @@ import scala.collection.JavaConverters._ import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV} +import org.apache.spark.annotation.Since + /** * Represents a numeric vector, whose index type is Int and value type is Double. * * Note: Users should not implement this interface. */ +@Since("2.0.0") sealed trait Vector extends Serializable { /** * Size of the vector. */ + @Since("2.0.0") def size: Int /** * Converts the instance to a double array. */ + @Since("2.0.0") def toArray: Array[Double] override def equals(other: Any): Boolean = { @@ -61,7 +66,7 @@ sealed trait Vector extends Serializable { /** * Returns a hash code value for the vector. The hash code is based on its size and its first 128 - * nonzero entries, using a hash algorithm similar to [[java.util.Arrays.hashCode]]. + * nonzero entries, using a hash algorithm similar to `java.util.Arrays.hashCode`. */ override def hashCode(): Int = { // This is a reference implementation. It calls return in foreachActive, which is slow. @@ -87,17 +92,19 @@ sealed trait Vector extends Serializable { /** * Converts the instance to a breeze vector. */ - private[spark] def toBreeze: BV[Double] + private[spark] def asBreeze: BV[Double] /** * Gets the value of the ith element. * @param i index */ - def apply(i: Int): Double = toBreeze(i) + @Since("2.0.0") + def apply(i: Int): Double = asBreeze(i) /** * Makes a deep copy of this vector. */ + @Since("2.0.0") def copy: Vector = { throw new NotImplementedError(s"copy is not implemented for ${this.getClass}.") } @@ -109,32 +116,38 @@ sealed trait Vector extends Serializable { * the vector with type `Int`, and the second parameter is the corresponding value * with type `Double`. */ + @Since("2.0.0") def foreachActive(f: (Int, Double) => Unit): Unit /** * Number of active entries. An "active entry" is an element which is explicitly stored, * regardless of its value. Note that inactive entries have value 0. */ + @Since("2.0.0") def numActives: Int /** * Number of nonzero elements. This scans all active values and count nonzeros. */ + @Since("2.0.0") def numNonzeros: Int /** * Converts this vector to a sparse vector with all explicit zeros removed. */ + @Since("2.0.0") def toSparse: SparseVector /** * Converts this vector to a dense vector. */ + @Since("2.0.0") def toDense: DenseVector = new DenseVector(this.toArray) /** * Returns a vector in either dense or sparse format, whichever uses less storage. */ + @Since("2.0.0") def compressed: Vector = { val nnz = numNonzeros // A dense vector needs 8 * size + 8 bytes, while a sparse vector needs 12 * nnz + 20 bytes. @@ -149,6 +162,7 @@ sealed trait Vector extends Serializable { * Find the index of a maximal element. Returns the first maximal element in case of a tie. * Returns -1 if vector has length 0. */ + @Since("2.0.0") def argmax: Int } @@ -157,12 +171,14 @@ sealed trait Vector extends Serializable { * We don't use the name `Vector` because Scala imports * [[scala.collection.immutable.Vector]] by default. */ +@Since("2.0.0") object Vectors { /** * Creates a dense vector from its values. 
*/ @varargs + @Since("2.0.0") def dense(firstValue: Double, otherValues: Double*): Vector = new DenseVector((firstValue +: otherValues).toArray) @@ -170,6 +186,7 @@ object Vectors { /** * Creates a dense vector from a double array. */ + @Since("2.0.0") def dense(values: Array[Double]): Vector = new DenseVector(values) /** @@ -179,6 +196,7 @@ object Vectors { * @param indices index array, must be strictly increasing. * @param values value array, must have the same length as indices. */ + @Since("2.0.0") def sparse(size: Int, indices: Array[Int], values: Array[Double]): Vector = new SparseVector(size, indices, values) @@ -188,6 +206,7 @@ object Vectors { * @param size vector size. * @param elements vector elements in (index, value) pairs. */ + @Since("2.0.0") def sparse(size: Int, elements: Seq[(Int, Double)]): Vector = { require(size > 0, "The size of the requested sparse vector must be greater than 0.") @@ -209,6 +228,7 @@ object Vectors { * @param size vector size. * @param elements vector elements in (index, value) pairs. */ + @Since("2.0.0") def sparse(size: Int, elements: JavaIterable[(JavaInteger, JavaDouble)]): Vector = { sparse(size, elements.asScala.map { case (i, x) => (i.intValue(), x.doubleValue()) @@ -221,6 +241,7 @@ object Vectors { * @param size vector size * @return a zero vector */ + @Since("2.0.0") def zeros(size: Int): Vector = { new DenseVector(new Array[Double](size)) } @@ -253,6 +274,7 @@ object Vectors { * @param p norm. * @return norm in L^p^ space. */ + @Since("2.0.0") def norm(vector: Vector, p: Double): Double = { require(p >= 1.0, "To compute the p-norm of the vector, we require that you specify a p>=1. " + s"You specified p=$p.") @@ -305,6 +327,7 @@ object Vectors { * @param v2 second Vector. * @return squared distance between two Vectors. */ + @Since("2.0.0") def sqdist(v1: Vector, v2: Vector): Double = { require(v1.size == v2.size, s"Vector dimensions do not match: Dim(v1)=${v1.size} and Dim(v2)" + s"=${v2.size}.") @@ -421,7 +444,8 @@ object Vectors { /** * A dense vector represented by a value array. */ -class DenseVector (val values: Array[Double]) extends Vector { +@Since("2.0.0") +class DenseVector @Since("2.0.0") ( @Since("2.0.0") val values: Array[Double]) extends Vector { override def size: Int = values.length @@ -429,7 +453,7 @@ class DenseVector (val values: Array[Double]) extends Vector { override def toArray: Array[Double] = values - private[spark] override def toBreeze: BV[Double] = new BDV[Double](values) + private[spark] override def asBreeze: BV[Double] = new BDV[Double](values) override def apply(i: Int): Double = values(i) @@ -515,23 +539,26 @@ class DenseVector (val values: Array[Double]) extends Vector { } } +@Since("2.0.0") object DenseVector { /** Extracts the value array from a dense vector. */ + @Since("2.0.0") def unapply(dv: DenseVector): Option[Array[Double]] = Some(dv.values) } /** - * A sparse vector represented by an index array and an value array. + * A sparse vector represented by an index array and a value array. * * @param size size of the vector. * @param indices index array, assume to be strictly increasing. * @param values value array, must have the same length as the index array. 
*/ -class SparseVector ( +@Since("2.0.0") +class SparseVector @Since("2.0.0") ( override val size: Int, - val indices: Array[Int], - val values: Array[Double]) extends Vector { + @Since("2.0.0") val indices: Array[Int], + @Since("2.0.0") val values: Array[Double]) extends Vector { require(indices.length == values.length, "Sparse vectors require that the dimension of the" + s" indices match the dimension of the values. You provided ${indices.length} indices and " + @@ -557,7 +584,7 @@ class SparseVector ( new SparseVector(size, indices.clone(), values.clone()) } - private[spark] override def toBreeze: BV[Double] = new BSV[Double](indices, values, size) + private[spark] override def asBreeze: BV[Double] = new BSV[Double](indices, values, size) override def foreachActive(f: (Int, Double) => Unit): Unit = { var i = 0 @@ -693,7 +720,9 @@ class SparseVector ( } } +@Since("2.0.0") object SparseVector { + @Since("2.0.0") def unapply(sv: SparseVector): Option[(Int, Array[Int], Array[Double])] = Some((sv.size, sv.indices, sv.values)) } diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussian.scala b/mllib-local/src/main/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussian.scala index c62a1eab2016b..0be28677eff31 100644 --- a/mllib-local/src/main/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussian.scala +++ b/mllib-local/src/main/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussian.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.stat.distribution import breeze.linalg.{diag, eigSym, max, DenseMatrix => BDM, DenseVector => BDV, Vector => BV} +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.ml.impl.Utils import org.apache.spark.ml.linalg.{Matrices, Matrix, Vector, Vectors} @@ -32,9 +33,11 @@ import org.apache.spark.ml.linalg.{Matrices, Matrix, Vector, Vectors} * @param mean The mean vector of the distribution * @param cov The covariance matrix of the distribution */ -class MultivariateGaussian( - val mean: Vector, - val cov: Matrix) extends Serializable { +@Since("2.0.0") +@DeveloperApi +class MultivariateGaussian @Since("2.0.0") ( + @Since("2.0.0") val mean: Vector, + @Since("2.0.0") val cov: Matrix) extends Serializable { require(cov.numCols == cov.numRows, "Covariance matrix must be square") require(mean.size == cov.numCols, "Mean vector length must match covariance matrix size") @@ -44,7 +47,7 @@ class MultivariateGaussian( this(Vectors.fromBreeze(mean), Matrices.fromBreeze(cov)) } - private val breezeMu = mean.toBreeze.toDenseVector + private val breezeMu = mean.asBreeze.toDenseVector /** * Compute distribution dependent constants: @@ -56,15 +59,17 @@ class MultivariateGaussian( /** * Returns density of this multivariate Gaussian at given point, x */ + @Since("2.0.0") def pdf(x: Vector): Double = { - pdf(x.toBreeze) + pdf(x.asBreeze) } /** * Returns the log-density of this multivariate Gaussian at given point, x */ + @Since("2.0.0") def logpdf(x: Vector): Double = { - logpdf(x.toBreeze) + logpdf(x.asBreeze) } /** Returns density of this multivariate Gaussian at given point, x */ @@ -108,7 +113,7 @@ class MultivariateGaussian( * relation to the maximum singular value (same tolerance used by, e.g., Octave). 
*/ private def calculateCovarianceConstants: (BDM[Double], Double) = { - val eigSym.EigSym(d, u) = eigSym(cov.toBreeze.toDenseMatrix) // sigma = u * diag(d) * u.t + val eigSym.EigSym(d, u) = eigSym(cov.asBreeze.toDenseMatrix) // sigma = u * diag(d) * u.t // For numerical stability, values are considered to be non-zero only if they exceed tol. // This prevents any inverted value from exceeding (eps * n * max(d))^-1 diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/BLASSuite.scala b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/BLASSuite.scala index 8a9f49792c1cd..6e72a5fff0a91 100644 --- a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/BLASSuite.scala +++ b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/BLASSuite.scala @@ -392,6 +392,23 @@ class BLASSuite extends SparkMLFunSuite { } } + val y17 = new DenseVector(Array(0.0, 0.0)) + val y18 = y17.copy + + val sA3 = new SparseMatrix(3, 2, Array(0, 2, 4), Array(1, 2, 0, 1), Array(2.0, 1.0, 1.0, 2.0)) + .transpose + val sA4 = + new SparseMatrix(2, 3, Array(0, 1, 3, 4), Array(1, 0, 1, 0), Array(1.0, 2.0, 2.0, 1.0)) + val sx3 = new SparseVector(3, Array(1, 2), Array(2.0, 1.0)) + + val expected4 = new DenseVector(Array(5.0, 4.0)) + + gemv(1.0, sA3, sx3, 0.0, y17) + gemv(1.0, sA4, sx3, 0.0, y18) + + assert(y17 ~== expected4 absTol 1e-15) + assert(y18 ~== expected4 absTol 1e-15) + val dAT = new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0)) val sAT = diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/BreezeMatrixConversionSuite.scala b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/BreezeMatrixConversionSuite.scala index 70a21e41bfbd1..f07ed20cf0e77 100644 --- a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/BreezeMatrixConversionSuite.scala +++ b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/BreezeMatrixConversionSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.ml.SparkMLFunSuite class BreezeMatrixConversionSuite extends SparkMLFunSuite { test("dense matrix to breeze") { val mat = Matrices.dense(3, 2, Array(0.0, 1.0, 2.0, 3.0, 4.0, 5.0)) - val breeze = mat.toBreeze.asInstanceOf[BDM[Double]] + val breeze = mat.asBreeze.asInstanceOf[BDM[Double]] assert(breeze.rows === mat.numRows) assert(breeze.cols === mat.numCols) assert(breeze.data.eq(mat.asInstanceOf[DenseMatrix].values), "should not copy data") @@ -48,7 +48,7 @@ class BreezeMatrixConversionSuite extends SparkMLFunSuite { val colPtrs = Array(0, 2, 4) val rowIndices = Array(1, 2, 1, 2) val mat = Matrices.sparse(3, 2, colPtrs, rowIndices, values) - val breeze = mat.toBreeze.asInstanceOf[BSM[Double]] + val breeze = mat.asBreeze.asInstanceOf[BSM[Double]] assert(breeze.rows === mat.numRows) assert(breeze.cols === mat.numCols) assert(breeze.data.eq(mat.asInstanceOf[SparseMatrix].values), "should not copy data") diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/BreezeVectorConversionSuite.scala b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/BreezeVectorConversionSuite.scala index 00c9ee79eb175..4c9740b6bca76 100644 --- a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/BreezeVectorConversionSuite.scala +++ b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/BreezeVectorConversionSuite.scala @@ -33,12 +33,12 @@ class BreezeVectorConversionSuite extends SparkMLFunSuite { test("dense to breeze") { val vec = Vectors.dense(arr) - assert(vec.toBreeze === new BDV[Double](arr)) + assert(vec.asBreeze === new BDV[Double](arr)) } test("sparse to breeze") { val 
vec = Vectors.sparse(n, indices, values) - assert(vec.toBreeze === new BSV[Double](indices, values, n)) + assert(vec.asBreeze === new BSV[Double](indices, values, n)) } test("dense breeze to vector") { diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/MatricesSuite.scala b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/MatricesSuite.scala index 5c69c5ed7bdd2..2796fcf2cbc22 100644 --- a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/MatricesSuite.scala +++ b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/MatricesSuite.scala @@ -61,7 +61,7 @@ class MatricesSuite extends SparkMLFunSuite { (1, 2, 2.0), (2, 2, 2.0), (1, 2, 2.0), (0, 0, 0.0)) val mat2 = SparseMatrix.fromCOO(m, n, entries) - assert(mat.toBreeze === mat2.toBreeze) + assert(mat.asBreeze === mat2.asBreeze) assert(mat2.values.length == 4) } @@ -174,8 +174,8 @@ class MatricesSuite extends SparkMLFunSuite { val spMat2 = deMat1.toSparse val deMat2 = spMat1.toDense - assert(spMat1.toBreeze === spMat2.toBreeze) - assert(deMat1.toBreeze === deMat2.toBreeze) + assert(spMat1.asBreeze === spMat2.asBreeze) + assert(deMat1.asBreeze === deMat2.asBreeze) } test("map, update") { @@ -209,8 +209,8 @@ class MatricesSuite extends SparkMLFunSuite { val sATexpected = new SparseMatrix(3, 4, Array(0, 1, 2, 3, 4), Array(1, 0, 1, 2), Array(2.0, 1.0, 1.0, 3.0)) - assert(dAT.toBreeze === dATexpected.toBreeze) - assert(sAT.toBreeze === sATexpected.toBreeze) + assert(dAT.asBreeze === dATexpected.asBreeze) + assert(sAT.asBreeze === sATexpected.asBreeze) assert(dA(1, 0) === dAT(0, 1)) assert(dA(2, 1) === dAT(1, 2)) assert(sA(1, 0) === sAT(0, 1)) @@ -219,8 +219,8 @@ class MatricesSuite extends SparkMLFunSuite { assert(!dA.toArray.eq(dAT.toArray), "has to have a new array") assert(dA.values.eq(dAT.transpose.asInstanceOf[DenseMatrix].values), "should not copy array") - assert(dAT.toSparse.toBreeze === sATexpected.toBreeze) - assert(sAT.toDense.toBreeze === dATexpected.toBreeze) + assert(dAT.toSparse.asBreeze === sATexpected.asBreeze) + assert(sAT.toDense.asBreeze === dATexpected.asBreeze) } test("foreachActive") { diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala index 887814b5e731a..614be460a414a 100644 --- a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala +++ b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala @@ -230,7 +230,7 @@ class VectorsSuite extends SparkMLFunSuite { val denseVector1 = Vectors.dense(sparseVector1.toArray) val denseVector2 = Vectors.dense(sparseVector2.toArray) - val squaredDist = breezeSquaredDistance(sparseVector1.toBreeze, sparseVector2.toBreeze) + val squaredDist = breezeSquaredDistance(sparseVector1.asBreeze, sparseVector2.asBreeze) // SparseVector vs. 
SparseVector assert(Vectors.sqdist(sparseVector1, sparseVector2) ~== squaredDist relTol 1E-8) diff --git a/mllib/pom.xml b/mllib/pom.xml index 24d8274e2222f..f7bf11f48ca97 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.0.0-SNAPSHOT + 2.0.3-SNAPSHOT ../pom.xml @@ -102,21 +102,17 @@ org.jpmml pmml-model - 1.2.7 + 1.2.15 - com.sun.xml.fastinfoset - FastInfoset - - - com.sun.istack - istack-commons-runtime + org.jpmml + pmml-agent org.apache.spark - spark-test-tags_${scala.binary.version} + spark-tags_${scala.binary.version} diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/mllib/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index f632dd603c449..a865cbe19b184 100644 --- a/mllib/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -1 +1 @@ -org.apache.spark.ml.source.libsvm.DefaultSource +org.apache.spark.ml.source.libsvm.LibSVMFileFormat diff --git a/mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/README b/mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/README new file mode 100755 index 0000000000000..ec08a5080774d --- /dev/null +++ b/mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/README @@ -0,0 +1,12 @@ +Stopwords Corpus + +This corpus contains lists of stop words for several languages. These +are high-frequency grammatical words which are usually ignored in text +retrieval applications. + +They were obtained from: +http://anoncvs.postgresql.org/cvsweb.cgi/pgsql/src/backend/snowball/stopwords/ + +The English list has been augmented +https://github.com/nltk/nltk_data/issues/22 + diff --git a/mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/danish.txt b/mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/danish.txt new file mode 100644 index 0000000000000..ea9e2c4abe5b9 --- /dev/null +++ b/mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/danish.txt @@ -0,0 +1,94 @@ +og +i +jeg +det +at +en +den +til +er +som +på +de +med +han +af +for +ikke +der +var +mig +sig +men +et +har +om +vi +min +havde +ham +hun +nu +over +da +fra +du +ud +sin +dem +os +op +man +hans +hvor +eller +hvad +skal +selv +her +alle +vil +blev +kunne +ind +når +være +dog +noget +ville +jo +deres +efter +ned +skulle +denne +end +dette +mit +også +under +have +dig +anden +hende +mine +alt +meget +sit +sine +vor +mod +disse +hvis +din +nogle +hos +blive +mange +ad +bliver +hendes +været +thi +jer +sådan \ No newline at end of file diff --git a/mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/dutch.txt b/mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/dutch.txt new file mode 100644 index 0000000000000..023cc2c939b2a --- /dev/null +++ b/mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/dutch.txt @@ -0,0 +1,101 @@ +de +en +van +ik +te +dat +die +in +een +hij +het +niet +zijn +is +was +op +aan +met +als +voor +had +er +maar +om +hem +dan +zou +of +wat +mijn +men +dit +zo +door +over +ze +zich +bij +ook +tot +je +mij +uit +der +daar +haar +naar +heb +hoe +heeft +hebben +deze +u +want +nog +zal +me +zij +nu +ge +geen +omdat +iets +worden +toch +al +waren +veel +meer +doen +toen +moet +ben +zonder +kan +hun +dus +alles +onder +ja +eens +hier +wie +werd +altijd +doch +wordt +wezen +kunnen +ons +zelf +tegen +na +reeds +wil +kon +niets 
+uw +iemand +geweest +andere \ No newline at end of file diff --git a/mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/english.txt b/mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/english.txt new file mode 100644 index 0000000000000..d075cc0babc3e --- /dev/null +++ b/mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/english.txt @@ -0,0 +1,153 @@ +i +me +my +myself +we +our +ours +ourselves +you +your +yours +yourself +yourselves +he +him +his +himself +she +her +hers +herself +it +its +itself +they +them +their +theirs +themselves +what +which +who +whom +this +that +these +those +am +is +are +was +were +be +been +being +have +has +had +having +do +does +did +doing +a +an +the +and +but +if +or +because +as +until +while +of +at +by +for +with +about +against +between +into +through +during +before +after +above +below +to +from +up +down +in +out +on +off +over +under +again +further +then +once +here +there +when +where +why +how +all +any +both +each +few +more +most +other +some +such +no +nor +not +only +own +same +so +than +too +very +s +t +can +will +just +don +should +now +d +ll +m +o +re +ve +y +ain +aren +couldn +didn +doesn +hadn +hasn +haven +isn +ma +mightn +mustn +needn +shan +shouldn +wasn +weren +won +wouldn diff --git a/mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/finnish.txt b/mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/finnish.txt new file mode 100644 index 0000000000000..5b0eb10777d0e --- /dev/null +++ b/mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/finnish.txt @@ -0,0 +1,235 @@ +olla +olen +olet +on +olemme +olette +ovat +ole +oli +olisi +olisit +olisin +olisimme +olisitte +olisivat +olit +olin +olimme +olitte +olivat +ollut +olleet +en +et +ei +emme +ette +eivät +minä +minun +minut +minua +minussa +minusta +minuun +minulla +minulta +minulle +sinä +sinun +sinut +sinua +sinussa +sinusta +sinuun +sinulla +sinulta +sinulle +hän +hänen +hänet +häntä +hänessä +hänestä +häneen +hänellä +häneltä +hänelle +me +meidän +meidät +meitä +meissä +meistä +meihin +meillä +meiltä +meille +te +teidän +teidät +teitä +teissä +teistä +teihin +teillä +teiltä +teille +he +heidän +heidät +heitä +heissä +heistä +heihin +heillä +heiltä +heille +tämä +tämän +tätä +tässä +tästä +tähän +tallä +tältä +tälle +tänä +täksi +tuo +tuon +tuotä +tuossa +tuosta +tuohon +tuolla +tuolta +tuolle +tuona +tuoksi +se +sen +sitä +siinä +siitä +siihen +sillä +siltä +sille +sinä +siksi +nämä +näiden +näitä +näissä +näistä +näihin +näillä +näiltä +näille +näinä +näiksi +nuo +noiden +noita +noissa +noista +noihin +noilla +noilta +noille +noina +noiksi +ne +niiden +niitä +niissä +niistä +niihin +niillä +niiltä +niille +niinä +niiksi +kuka +kenen +kenet +ketä +kenessä +kenestä +keneen +kenellä +keneltä +kenelle +kenenä +keneksi +ketkä +keiden +ketkä +keitä +keissä +keistä +keihin +keillä +keiltä +keille +keinä +keiksi +mikä +minkä +minkä +mitä +missä +mistä +mihin +millä +miltä +mille +minä +miksi +mitkä +joka +jonka +jota +jossa +josta +johon +jolla +jolta +jolle +jona +joksi +jotka +joiden +joita +joissa +joista +joihin +joilla +joilta +joille +joina +joiksi +että +ja +jos +koska +kuin +mutta +niin +sekä +sillä +tai +vaan +vai +vaikka +kanssa +mukaan +noin +poikki +yli +kun +niin +nyt +itse \ No newline at end of file diff --git a/mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/french.txt b/mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/french.txt new file mode 100644 index 
0000000000000..94b8f8f39a3e1 --- /dev/null +++ b/mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/french.txt @@ -0,0 +1,155 @@ +au +aux +avec +ce +ces +dans +de +des +du +elle +en +et +eux +il +je +la +le +leur +lui +ma +mais +me +même +mes +moi +mon +ne +nos +notre +nous +on +ou +par +pas +pour +qu +que +qui +sa +se +ses +son +sur +ta +te +tes +toi +ton +tu +un +une +vos +votre +vous +c +d +j +l +à +m +n +s +t +y +été +étée +étées +étés +étant +étante +étants +étantes +suis +es +est +sommes +êtes +sont +serai +seras +sera +serons +serez +seront +serais +serait +serions +seriez +seraient +étais +était +étions +étiez +étaient +fus +fut +fûmes +fûtes +furent +sois +soit +soyons +soyez +soient +fusse +fusses +fût +fussions +fussiez +fussent +ayant +ayante +ayantes +ayants +eu +eue +eues +eus +ai +as +avons +avez +ont +aurai +auras +aura +aurons +aurez +auront +aurais +aurait +aurions +auriez +auraient +avais +avait +avions +aviez +avaient +eut +eûmes +eûtes +eurent +aie +aies +ait +ayons +ayez +aient +eusse +eusses +eût +eussions +eussiez +eussent \ No newline at end of file diff --git a/mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/german.txt b/mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/german.txt new file mode 100644 index 0000000000000..7e65190f8ba28 --- /dev/null +++ b/mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/german.txt @@ -0,0 +1,231 @@ +aber +alle +allem +allen +aller +alles +als +also +am +an +ander +andere +anderem +anderen +anderer +anderes +anderm +andern +anderr +anders +auch +auf +aus +bei +bin +bis +bist +da +damit +dann +der +den +des +dem +die +das +daß +derselbe +derselben +denselben +desselben +demselben +dieselbe +dieselben +dasselbe +dazu +dein +deine +deinem +deinen +deiner +deines +denn +derer +dessen +dich +dir +du +dies +diese +diesem +diesen +dieser +dieses +doch +dort +durch +ein +eine +einem +einen +einer +eines +einig +einige +einigem +einigen +einiger +einiges +einmal +er +ihn +ihm +es +etwas +euer +eure +eurem +euren +eurer +eures +für +gegen +gewesen +hab +habe +haben +hat +hatte +hatten +hier +hin +hinter +ich +mich +mir +ihr +ihre +ihrem +ihren +ihrer +ihres +euch +im +in +indem +ins +ist +jede +jedem +jeden +jeder +jedes +jene +jenem +jenen +jener +jenes +jetzt +kann +kein +keine +keinem +keinen +keiner +keines +können +könnte +machen +man +manche +manchem +manchen +mancher +manches +mein +meine +meinem +meinen +meiner +meines +mit +muss +musste +nach +nicht +nichts +noch +nun +nur +ob +oder +ohne +sehr +sein +seine +seinem +seinen +seiner +seines +selbst +sich +sie +ihnen +sind +so +solche +solchem +solchen +solcher +solches +soll +sollte +sondern +sonst +über +um +und +uns +unse +unsem +unsen +unser +unses +unter +viel +vom +von +vor +während +war +waren +warst +was +weg +weil +weiter +welche +welchem +welchen +welcher +welches +wenn +werde +werden +wie +wieder +will +wir +wird +wirst +wo +wollen +wollte +würde +würden +zu +zum +zur +zwar +zwischen \ No newline at end of file diff --git a/mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/hungarian.txt b/mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/hungarian.txt new file mode 100644 index 0000000000000..8d4543a0965d5 --- /dev/null +++ b/mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/hungarian.txt @@ -0,0 +1,199 @@ +a +ahogy +ahol +aki +akik +akkor +alatt +által +általában +amely +amelyek +amelyekben +amelyeket +amelyet +amelynek +ami +amit +amolyan +amíg +amikor +át +abban +ahhoz +annak 
+arra +arról +az +azok +azon +azt +azzal +azért +aztán +azután +azonban +bár +be +belül +benne +cikk +cikkek +cikkeket +csak +de +e +eddig +egész +egy +egyes +egyetlen +egyéb +egyik +egyre +ekkor +el +elég +ellen +elõ +elõször +elõtt +elsõ +én +éppen +ebben +ehhez +emilyen +ennek +erre +ez +ezt +ezek +ezen +ezzel +ezért +és +fel +felé +hanem +hiszen +hogy +hogyan +igen +így +illetve +ill. +ill +ilyen +ilyenkor +ison +ismét +itt +jó +jól +jobban +kell +kellett +keresztül +keressünk +ki +kívül +között +közül +legalább +lehet +lehetett +legyen +lenne +lenni +lesz +lett +maga +magát +majd +majd +már +más +másik +meg +még +mellett +mert +mely +melyek +mi +mit +míg +miért +milyen +mikor +minden +mindent +mindenki +mindig +mint +mintha +mivel +most +nagy +nagyobb +nagyon +ne +néha +nekem +neki +nem +néhány +nélkül +nincs +olyan +ott +össze +õ +õk +õket +pedig +persze +rá +s +saját +sem +semmi +sok +sokat +sokkal +számára +szemben +szerint +szinte +talán +tehát +teljes +tovább +továbbá +több +úgy +ugyanis +új +újabb +újra +után +utána +utolsó +vagy +vagyis +valaki +valami +valamint +való +vagyok +van +vannak +volt +voltam +voltak +voltunk +vissza +vele +viszont +volna \ No newline at end of file diff --git a/mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/italian.txt b/mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/italian.txt new file mode 100644 index 0000000000000..783b2e0cbfcd8 --- /dev/null +++ b/mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/italian.txt @@ -0,0 +1,279 @@ +ad +al +allo +ai +agli +all +agl +alla +alle +con +col +coi +da +dal +dallo +dai +dagli +dall +dagl +dalla +dalle +di +del +dello +dei +degli +dell +degl +della +delle +in +nel +nello +nei +negli +nell +negl +nella +nelle +su +sul +sullo +sui +sugli +sull +sugl +sulla +sulle +per +tra +contro +io +tu +lui +lei +noi +voi +loro +mio +mia +miei +mie +tuo +tua +tuoi +tue +suo +sua +suoi +sue +nostro +nostra +nostri +nostre +vostro +vostra +vostri +vostre +mi +ti +ci +vi +lo +la +li +le +gli +ne +il +un +uno +una +ma +ed +se +perché +anche +come +dov +dove +che +chi +cui +non +più +quale +quanto +quanti +quanta +quante +quello +quelli +quella +quelle +questo +questi +questa +queste +si +tutto +tutti +a +c +e +i +l +o +ho +hai +ha +abbiamo +avete +hanno +abbia +abbiate +abbiano +avrò +avrai +avrà +avremo +avrete +avranno +avrei +avresti +avrebbe +avremmo +avreste +avrebbero +avevo +avevi +aveva +avevamo +avevate +avevano +ebbi +avesti +ebbe +avemmo +aveste +ebbero +avessi +avesse +avessimo +avessero +avendo +avuto +avuta +avuti +avute +sono +sei +è +siamo +siete +sia +siate +siano +sarò +sarai +sarà +saremo +sarete +saranno +sarei +saresti +sarebbe +saremmo +sareste +sarebbero +ero +eri +era +eravamo +eravate +erano +fui +fosti +fu +fummo +foste +furono +fossi +fosse +fossimo +fossero +essendo +faccio +fai +facciamo +fanno +faccia +facciate +facciano +farò +farai +farà +faremo +farete +faranno +farei +faresti +farebbe +faremmo +fareste +farebbero +facevo +facevi +faceva +facevamo +facevate +facevano +feci +facesti +fece +facemmo +faceste +fecero +facessi +facesse +facessimo +facessero +facendo +sto +stai +sta +stiamo +stanno +stia +stiate +stiano +starò +starai +starà +staremo +starete +staranno +starei +staresti +starebbe +staremmo +stareste +starebbero +stavo +stavi +stava +stavamo +stavate +stavano +stetti +stesti +stette +stemmo +steste +stettero +stessi +stesse +stessimo +stessero +stando \ No newline at end of file diff --git 
a/mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/norwegian.txt b/mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/norwegian.txt new file mode 100644 index 0000000000000..cb91702c5e9a9 --- /dev/null +++ b/mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/norwegian.txt @@ -0,0 +1,176 @@ +og +i +jeg +det +at +en +et +den +til +er +som +på +de +med +han +av +ikke +ikkje +der +så +var +meg +seg +men +ett +har +om +vi +min +mitt +ha +hadde +hun +nå +over +da +ved +fra +du +ut +sin +dem +oss +opp +man +kan +hans +hvor +eller +hva +skal +selv +sjøl +her +alle +vil +bli +ble +blei +blitt +kunne +inn +når +være +kom +noen +noe +ville +dere +som +deres +kun +ja +etter +ned +skulle +denne +for +deg +si +sine +sitt +mot +å +meget +hvorfor +dette +disse +uten +hvordan +ingen +din +ditt +blir +samme +hvilken +hvilke +sånn +inni +mellom +vår +hver +hvem +vors +hvis +både +bare +enn +fordi +før +mange +også +slik +vært +være +båe +begge +siden +dykk +dykkar +dei +deira +deires +deim +di +då +eg +ein +eit +eitt +elles +honom +hjå +ho +hoe +henne +hennar +hennes +hoss +hossen +ikkje +ingi +inkje +korleis +korso +kva +kvar +kvarhelst +kven +kvi +kvifor +me +medan +mi +mine +mykje +no +nokon +noka +nokor +noko +nokre +si +sia +sidan +so +somt +somme +um +upp +vere +vore +verte +vort +varte +vart \ No newline at end of file diff --git a/mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/portuguese.txt b/mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/portuguese.txt new file mode 100644 index 0000000000000..98b4fdcdf7a20 --- /dev/null +++ b/mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/portuguese.txt @@ -0,0 +1,203 @@ +de +a +o +que +e +do +da +em +um +para +com +não +uma +os +no +se +na +por +mais +as +dos +como +mas +ao +ele +das +à +seu +sua +ou +quando +muito +nos +já +eu +também +só +pelo +pela +até +isso +ela +entre +depois +sem +mesmo +aos +seus +quem +nas +me +esse +eles +você +essa +num +nem +suas +meu +às +minha +numa +pelos +elas +qual +nós +lhe +deles +essas +esses +pelas +este +dele +tu +te +vocês +vos +lhes +meus +minhas +teu +tua +teus +tuas +nosso +nossa +nossos +nossas +dela +delas +esta +estes +estas +aquele +aquela +aqueles +aquelas +isto +aquilo +estou +está +estamos +estão +estive +esteve +estivemos +estiveram +estava +estávamos +estavam +estivera +estivéramos +esteja +estejamos +estejam +estivesse +estivéssemos +estivessem +estiver +estivermos +estiverem +hei +há +havemos +hão +houve +houvemos +houveram +houvera +houvéramos +haja +hajamos +hajam +houvesse +houvéssemos +houvessem +houver +houvermos +houverem +houverei +houverá +houveremos +houverão +houveria +houveríamos +houveriam +sou +somos +são +era +éramos +eram +fui +foi +fomos +foram +fora +fôramos +seja +sejamos +sejam +fosse +fôssemos +fossem +for +formos +forem +serei +será +seremos +serão +seria +seríamos +seriam +tenho +tem +temos +tém +tinha +tínhamos +tinham +tive +teve +tivemos +tiveram +tivera +tivéramos +tenha +tenhamos +tenham +tivesse +tivéssemos +tivessem +tiver +tivermos +tiverem +terei +terá +teremos +terão +teria +teríamos +teriam \ No newline at end of file diff --git a/mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/russian.txt b/mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/russian.txt new file mode 100644 index 0000000000000..8a800b74497dd --- /dev/null +++ b/mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/russian.txt @@ -0,0 +1,151 @@ +и +в +во +не +что +он +на +я +с +со +как +а +то 
+все +она +так +его +но +да +ты +к +у +же +вы +за +бы +по +только +ее +мне +было +вот +от +меня +еще +нет +о +из +ему +теперь +когда +даже +ну +вдруг +ли +если +уже +или +ни +быть +был +него +до +вас +нибудь +опять +уж +вам +ведь +там +потом +себя +ничего +ей +может +они +тут +где +есть +надо +ней +для +мы +тебя +их +чем +была +сам +чтоб +без +будто +чего +раз +тоже +себе +под +будет +ж +тогда +кто +этот +того +потому +этого +какой +совсем +ним +здесь +этом +один +почти +мой +тем +чтобы +нее +сейчас +были +куда +зачем +всех +никогда +можно +при +наконец +два +об +другой +хоть +после +над +больше +тот +через +эти +нас +про +всего +них +какая +много +разве +три +эту +моя +впрочем +хорошо +свою +этой +перед +иногда +лучше +чуть +том +нельзя +такой +им +более +всегда +конечно +всю +между \ No newline at end of file diff --git a/mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/spanish.txt b/mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/spanish.txt new file mode 100644 index 0000000000000..94f493a8d1e03 --- /dev/null +++ b/mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/spanish.txt @@ -0,0 +1,313 @@ +de +la +que +el +en +y +a +los +del +se +las +por +un +para +con +no +una +su +al +lo +como +más +pero +sus +le +ya +o +este +sí +porque +esta +entre +cuando +muy +sin +sobre +también +me +hasta +hay +donde +quien +desde +todo +nos +durante +todos +uno +les +ni +contra +otros +ese +eso +ante +ellos +e +esto +mí +antes +algunos +qué +unos +yo +otro +otras +otra +él +tanto +esa +estos +mucho +quienes +nada +muchos +cual +poco +ella +estar +estas +algunas +algo +nosotros +mi +mis +tú +te +ti +tu +tus +ellas +nosotras +vosostros +vosostras +os +mío +mía +míos +mías +tuyo +tuya +tuyos +tuyas +suyo +suya +suyos +suyas +nuestro +nuestra +nuestros +nuestras +vuestro +vuestra +vuestros +vuestras +esos +esas +estoy +estás +está +estamos +estáis +están +esté +estés +estemos +estéis +estén +estaré +estarás +estará +estaremos +estaréis +estarán +estaría +estarías +estaríamos +estaríais +estarían +estaba +estabas +estábamos +estabais +estaban +estuve +estuviste +estuvo +estuvimos +estuvisteis +estuvieron +estuviera +estuvieras +estuviéramos +estuvierais +estuvieran +estuviese +estuvieses +estuviésemos +estuvieseis +estuviesen +estando +estado +estada +estados +estadas +estad +he +has +ha +hemos +habéis +han +haya +hayas +hayamos +hayáis +hayan +habré +habrás +habrá +habremos +habréis +habrán +habría +habrías +habríamos +habríais +habrían +había +habías +habíamos +habíais +habían +hube +hubiste +hubo +hubimos +hubisteis +hubieron +hubiera +hubieras +hubiéramos +hubierais +hubieran +hubiese +hubieses +hubiésemos +hubieseis +hubiesen +habiendo +habido +habida +habidos +habidas +soy +eres +es +somos +sois +son +sea +seas +seamos +seáis +sean +seré +serás +será +seremos +seréis +serán +sería +serías +seríamos +seríais +serían +era +eras +éramos +erais +eran +fui +fuiste +fue +fuimos +fuisteis +fueron +fuera +fueras +fuéramos +fuerais +fueran +fuese +fueses +fuésemos +fueseis +fuesen +sintiendo +sentido +sentida +sentidos +sentidas +siente +sentid +tengo +tienes +tiene +tenemos +tenéis +tienen +tenga +tengas +tengamos +tengáis +tengan +tendré +tendrás +tendrá +tendremos +tendréis +tendrán +tendría +tendrías +tendríamos +tendríais +tendrían +tenía +tenías +teníamos +teníais +tenían +tuve +tuviste +tuvo +tuvimos +tuvisteis +tuvieron +tuviera +tuvieras +tuviéramos +tuvierais +tuvieran +tuviese +tuvieses +tuviésemos +tuvieseis +tuviesen +teniendo +tenido +tenida +tenidos 
+tenidas +tened \ No newline at end of file diff --git a/mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/swedish.txt b/mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/swedish.txt new file mode 100644 index 0000000000000..9fae31c1858a9 --- /dev/null +++ b/mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/swedish.txt @@ -0,0 +1,114 @@ +och +det +att +i +en +jag +hon +som +han +på +den +med +var +sig +för +så +till +är +men +ett +om +hade +de +av +icke +mig +du +henne +då +sin +nu +har +inte +hans +honom +skulle +hennes +där +min +man +ej +vid +kunde +något +från +ut +när +efter +upp +vi +dem +vara +vad +över +än +dig +kan +sina +här +ha +mot +alla +under +någon +eller +allt +mycket +sedan +ju +denna +själv +detta +åt +utan +varit +hur +ingen +mitt +ni +bli +blev +oss +din +dessa +några +deras +blir +mina +samma +vilken +er +sådan +vår +blivit +dess +inom +mellan +sådant +varför +varje +vilka +ditt +vem +vilket +sitta +sådana +vart +dina +vars +vårt +våra +ert +era +vilkas \ No newline at end of file diff --git a/mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/turkish.txt b/mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/turkish.txt new file mode 100644 index 0000000000000..4e9708d9d2c5e --- /dev/null +++ b/mllib/src/main/resources/org/apache/spark/ml/feature/stopwords/turkish.txt @@ -0,0 +1,53 @@ +acaba +ama +aslında +az +bazı +belki +biri +birkaç +birşey +biz +bu +çok +çünkü +da +daha +de +defa +diye +eğer +en +gibi +hem +hep +hepsi +her +hiç +için +ile +ise +kez +ki +kim +mı +mu +mü +nasıl +ne +neden +nerde +nerede +nereye +niçin +niye +o +sanki +şey +siz +şu +tüm +ve +veya +ya +yani \ No newline at end of file diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index b02aea92b72f9..195a93e086725 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -27,7 +27,7 @@ import org.json4s._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.SparkContext -import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.internal.Logging import org.apache.spark.ml.param.{Param, ParamMap, Params} import org.apache.spark.ml.util._ @@ -44,7 +44,10 @@ abstract class PipelineStage extends Params with Logging { /** * :: DeveloperApi :: * - * Derives the output schema from the input schema. + * Check transform validity and derive the output schema from the input schema. + * + * Typical implementation should first conduct verification on schema change and parameter + * validity, including complex parameter interaction checks. */ @DeveloperApi def transformSchema(schema: StructType): StructType @@ -75,19 +78,17 @@ abstract class PipelineStage extends Params with Logging { } /** - * :: Experimental :: * A simple pipeline, which acts as an estimator. A Pipeline consists of a sequence of stages, each * of which is either an [[Estimator]] or a [[Transformer]]. When [[Pipeline#fit]] is called, the * stages are executed in order. If a stage is an [[Estimator]], its [[Estimator#fit]] method will * be called on the input dataset to fit a model. Then the model, which is a transformer, will be * used to transform the dataset as the input to the next stage. If a stage is a [[Transformer]], * its [[Transformer#transform]] method will be called to produce the dataset for the next stage. 
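A minimal sketch of the fit/transform flow this comment describes; the `Tokenizer`/`HashingTF`/`LogisticRegression` stages and the `training`/`test` DataFrames are assumed placeholders, not taken from the patch:

```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}

// Stages run in order: each Estimator is fit, each Transformer just transforms.
val tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words")
val hashingTF = new HashingTF().setInputCol("words").setOutputCol("features")
val lr = new LogisticRegression().setMaxIter(10)

val pipeline = new Pipeline().setStages(Array(tokenizer, hashingTF, lr))
val model = pipeline.fit(training)       // training: DataFrame with "text" and "label" columns
val predictions = model.transform(test)  // model is a PipelineModel, itself a Transformer
```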
- * The fitted model from a [[Pipeline]] is an [[PipelineModel]], which consists of fitted models and + * The fitted model from a [[Pipeline]] is a [[PipelineModel]], which consists of fitted models and * transformers, corresponding to the pipeline stages. If there are no stages, the pipeline acts as * an identity transformer. */ @Since("1.2.0") -@Experimental class Pipeline @Since("1.4.0") ( @Since("1.4.0") override val uid: String) extends Estimator[PipelineModel] with MLWritable { @@ -211,7 +212,7 @@ object Pipeline extends MLReadable[Pipeline] { } } - /** Methods for [[MLReader]] and [[MLWriter]] shared between [[Pipeline]] and [[PipelineModel]] */ + /** Methods for `MLReader` and `MLWriter` shared between [[Pipeline]] and [[PipelineModel]] */ private[ml] object SharedReadWrite { import org.json4s.JsonDSL._ @@ -279,11 +280,9 @@ object Pipeline extends MLReadable[Pipeline] { } /** - * :: Experimental :: * Represents a fitted pipeline. */ @Since("1.2.0") -@Experimental class PipelineModel private[ml] ( @Since("1.4.0") override val uid: String, @Since("1.4.0") val stages: Array[Transformer]) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala index 81140d1f7b21f..e29d7f48a1d6b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala @@ -18,11 +18,11 @@ package org.apache.spark.ml import org.apache.spark.annotation.{DeveloperApi, Since} +import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.linalg.{Vector, VectorUDT} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.SchemaUtils -import org.apache.spark.mllib.linalg.{Vector, VectorUDT} -import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ @@ -165,7 +165,7 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, } /** - * Transforms dataset by reading from [[featuresCol]], calling [[predict()]], and storing + * Transforms dataset by reading from [[featuresCol]], calling `predict`, and storing * the predictions as a new column [[predictionCol]]. * * @param dataset input dataset diff --git a/mllib/src/main/scala/org/apache/spark/ml/ann/BreezeUtil.scala b/mllib/src/main/scala/org/apache/spark/ml/ann/BreezeUtil.scala index 7429f9d652ac5..6bbe7e1cb2134 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/ann/BreezeUtil.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/ann/BreezeUtil.scala @@ -26,38 +26,39 @@ import com.github.fommil.netlib.BLAS.{getInstance => NativeBLAS} private[ann] object BreezeUtil { // TODO: switch to MLlib BLAS interface - private def transposeString(a: BDM[Double]): String = if (a.isTranspose) "T" else "N" + private def transposeString(A: BDM[Double]): String = if (A.isTranspose) "T" else "N" /** * DGEMM: C := alpha * A * B + beta * C * @param alpha alpha - * @param a A - * @param b B + * @param A A + * @param B B * @param beta beta - * @param c C + * @param C C */ - def dgemm(alpha: Double, a: BDM[Double], b: BDM[Double], beta: Double, c: BDM[Double]): Unit = { + def dgemm(alpha: Double, A: BDM[Double], B: BDM[Double], beta: Double, C: BDM[Double]): Unit = { // TODO: add code if matrices isTranspose!!! 
- require(a.cols == b.rows, "A & B Dimension mismatch!") - require(a.rows == c.rows, "A & C Dimension mismatch!") - require(b.cols == c.cols, "A & C Dimension mismatch!") - NativeBLAS.dgemm(transposeString(a), transposeString(b), c.rows, c.cols, a.cols, - alpha, a.data, a.offset, a.majorStride, b.data, b.offset, b.majorStride, - beta, c.data, c.offset, c.rows) + require(A.cols == B.rows, "A & B Dimension mismatch!") + require(A.rows == C.rows, "A & C Dimension mismatch!") + require(B.cols == C.cols, "A & C Dimension mismatch!") + NativeBLAS.dgemm(transposeString(A), transposeString(B), C.rows, C.cols, A.cols, + alpha, A.data, A.offset, A.majorStride, B.data, B.offset, B.majorStride, + beta, C.data, C.offset, C.rows) } /** * DGEMV: y := alpha * A * x + beta * y * @param alpha alpha - * @param a A + * @param A A * @param x x * @param beta beta * @param y y */ - def dgemv(alpha: Double, a: BDM[Double], x: BDV[Double], beta: Double, y: BDV[Double]): Unit = { - require(a.cols == x.length, "A & b Dimension mismatch!") - NativeBLAS.dgemv(transposeString(a), a.rows, a.cols, - alpha, a.data, a.offset, a.majorStride, x.data, x.offset, x.stride, + def dgemv(alpha: Double, A: BDM[Double], x: BDV[Double], beta: Double, y: BDV[Double]): Unit = { + require(A.cols == x.length, "A & x Dimension mismatch!") + require(A.rows == y.length, "A & y Dimension mismatch!") + NativeBLAS.dgemv(transposeString(A), A.rows, A.cols, + alpha, A.data, A.offset, A.majorStride, x.data, x.offset, x.stride, beta, y.data, y.offset, y.stride) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala index 3588ac1e95bee..88909a9fb953f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala @@ -21,9 +21,12 @@ import java.util.Random import breeze.linalg.{*, axpy => Baxpy, DenseMatrix => BDM, DenseVector => BDV, Vector => BV} -import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.ml.linalg.{Vector, Vectors} +import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} +import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.mllib.optimization._ import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel import org.apache.spark.util.random.XORShiftRandom /** @@ -64,8 +67,9 @@ private[ann] trait Layer extends Serializable { * @return the layer model */ def createModel(initialWeights: BDV[Double]): LayerModel + /** - * Returns the instance of the layer with random generated weights + * Returns the instance of the layer with random generated weights. * * @param weights vector for weights initialization, must be equal to weightSize * @param random random number generator @@ -83,11 +87,11 @@ private[ann] trait LayerModel extends Serializable { val weights: BDV[Double] /** - * Evaluates the data (process the data through the layer) + * Evaluates the data (process the data through the layer). * Output is allocated based on the size provided by the - * LayerModel implementation and the stack (batch) size + * LayerModel implementation and the stack (batch) size. * Developer is responsible for checking the size of output - * when writing to it + * when writing to it. 
* * @param data data * @param output output (modified in place) @@ -95,11 +99,11 @@ private[ann] trait LayerModel extends Serializable { def eval(data: BDM[Double], output: BDM[Double]): Unit /** - * Computes the delta for back propagation + * Computes the delta for back propagation. * Delta is allocated based on the size provided by the - * LayerModel implementation and the stack (batch) size + * LayerModel implementation and the stack (batch) size. * Developer is responsible for checking the size of - * prevDelta when writing to it + * prevDelta when writing to it. * * @param delta delta of this layer * @param output output of this layer @@ -108,10 +112,10 @@ private[ann] trait LayerModel extends Serializable { def computePrevDelta(delta: BDM[Double], output: BDM[Double], prevDelta: BDM[Double]): Unit /** - * Computes the gradient - * cumGrad is a wrapper on the part of the weight vector - * size of cumGrad is based on weightSize provided by - * implementation of LayerModel + * Computes the gradient. + * cumGrad is a wrapper on the part of the weight vector. + * Size of cumGrad is based on weightSize provided by + * implementation of LayerModel. * * @param delta delta for this layer * @param input input data @@ -197,11 +201,11 @@ private[ann] object AffineLayerModel { } /** - * Initialize weights randomly in the interval - * Uses [Bottou-88] heuristic [-a/sqrt(in); a/sqrt(in)] - * where a is chosen in a such way that the weight variance corresponds + * Initialize weights randomly in the interval. + * Uses [Bottou-88] heuristic [-a/sqrt(in); a/sqrt(in)], + * where `a` is chosen in such a way that the weight variance corresponds * to the points to the maximal curvature of the activation function - * (which is approximately 2.38 for a standard sigmoid) + * (which is approximately 2.38 for a standard sigmoid). * * @param numIn number of inputs * @param numOut number of outputs @@ -306,7 +310,7 @@ private[ann] class FunctionalLayer (val activationFunction: ActivationFunction) /** * Functional layer model. Holds no weights. * - * @param layer functiona layer + * @param layer functional layer */ private[ann] class FunctionalLayerModel private[ann] (val layer: FunctionalLayer) extends LayerModel { @@ -352,6 +356,7 @@ private[ann] trait TopologyModel extends Serializable { * Array of layer models */ val layerModels: Array[LayerModel] + /** * Forward propagation * @@ -410,7 +415,7 @@ private[ml] object FeedForwardTopology { * Creates a multi-layer perceptron * * @param layerSizes sizes of layers including input and output size - * @param softmaxOnTop wether to use SoftMax or Sigmoid function for an output layer. + * @param softmaxOnTop whether to use SoftMax or Sigmoid function for an output layer. 
* Softmax is default * @return multilayer perceptron topology */ @@ -578,10 +583,10 @@ private[ann] object FeedForwardModel { */ private[ann] class ANNGradient(topology: Topology, dataStacker: DataStacker) extends Gradient { override def compute( - data: Vector, + data: OldVector, label: Double, - weights: Vector, - cumGradient: Vector): Double = { + weights: OldVector, + cumGradient: OldVector): Double = { val (input, target, realBatchSize) = dataStacker.unstack(data) val model = topology.model(weights) model.computeGradient(input, target, cumGradient, realBatchSize) @@ -612,8 +617,8 @@ private[ann] class DataStacker(stackSize: Int, inputSize: Int, outputSize: Int) data.map { v => (0.0, Vectors.fromBreeze(BDV.vertcat( - v._1.toBreeze.toDenseVector, - v._2.toBreeze.toDenseVector)) + v._1.asBreeze.toDenseVector, + v._2.asBreeze.toDenseVector)) ) } } else { data.mapPartitions { it => @@ -655,15 +660,15 @@ private[ann] class DataStacker(stackSize: Int, inputSize: Int, outputSize: Int) private[ann] class ANNUpdater extends Updater { override def compute( - weightsOld: Vector, - gradient: Vector, + weightsOld: OldVector, + gradient: OldVector, stepSize: Double, iter: Int, - regParam: Double): (Vector, Double) = { + regParam: Double): (OldVector, Double) = { val thisIterStepSize = stepSize - val brzWeights: BV[Double] = weightsOld.toBreeze.toDenseVector - Baxpy(-thisIterStepSize, gradient.toBreeze, brzWeights) - (Vectors.fromBreeze(brzWeights), 0) + val brzWeights: BV[Double] = weightsOld.asBreeze.toDenseVector + Baxpy(-thisIterStepSize, gradient.asBreeze, brzWeights) + (OldVectors.fromBreeze(brzWeights), 0) } } @@ -806,7 +811,13 @@ private[ml] class FeedForwardTrainer( getWeights } // TODO: deprecate standard optimizer because it needs Vector - val newWeights = optimizer.optimize(dataStacker.stack(data), w) + val trainData = dataStacker.stack(data).map { v => + (v._1, OldVectors.fromML(v._2)) + } + val handlePersistence = trainData.getStorageLevel == StorageLevel.NONE + if (handlePersistence) trainData.persist(StorageLevel.MEMORY_AND_DISK) + val newWeights = optimizer.optimize(trainData, w) + if (handlePersistence) trainData.unpersist() topology.model(newWeights) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala index 2c29eeb01a921..12b9732a4c3d2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.attribute import scala.collection.mutable.ArrayBuffer import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.mllib.linalg.VectorUDT +import org.apache.spark.ml.linalg.VectorUDT import org.apache.spark.sql.types.{Metadata, MetadataBuilder, StructField} /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeType.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeType.scala index 5c7089b491677..078fecf088282 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeType.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeType.scala @@ -27,6 +27,9 @@ import org.apache.spark.annotation.DeveloperApi @DeveloperApi sealed abstract class AttributeType(val name: String) +/** + * :: DeveloperApi :: + */ @DeveloperApi object AttributeType { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala 
b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index bc5fe35ad4a5c..6decea72719fd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -20,10 +20,10 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkException import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams} +import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.linalg.{Vector, VectorUDT} import org.apache.spark.ml.param.shared.HasRawPredictionCol import org.apache.spark.ml.util.{MetadataUtils, SchemaUtils} -import org.apache.spark.mllib.linalg.{Vector, VectorUDT} -import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ @@ -50,7 +50,7 @@ private[spark] trait ClassifierParams * Single-label binary or multiclass classification. * Classes are indexed {0, 1, ..., numClasses - 1}. * - * @tparam FeaturesType Type of input features. E.g., [[Vector]] + * @tparam FeaturesType Type of input features. E.g., `Vector` * @tparam E Concrete Estimator type * @tparam M Concrete Model type */ @@ -134,7 +134,7 @@ abstract class Classifier[ * Model produced by a [[Classifier]]. * Classes are indexed {0, 1, ..., numClasses - 1}. * - * @tparam FeaturesType Type of input features. E.g., [[Vector]] + * @tparam FeaturesType Type of input features. E.g., `Vector` * @tparam M Concrete Model type */ @DeveloperApi @@ -151,7 +151,7 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur * Transforms dataset by reading from [[featuresCol]], and appending new columns as specified by * parameters: * - predicted labels as [[predictionCol]] of type [[Double]] - * - raw predictions (confidences) as [[rawPredictionCol]] of type [[Vector]]. + * - raw predictions (confidences) as [[rawPredictionCol]] of type `Vector`. 
* * @param dataset input dataset * @return transformed dataset diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 31a69d49a078a..bb192ab5f25ab 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -21,14 +21,14 @@ import org.apache.hadoop.fs.Path import org.json4s.{DefaultFormats, JObject} import org.json4s.JsonDSL._ -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since +import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.tree._ import org.apache.spark.ml.tree.DecisionTreeModelReadWrite._ import org.apache.spark.ml.tree.impl.RandomForest import org.apache.spark.ml.util._ -import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} -import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel} import org.apache.spark.rdd.RDD @@ -36,14 +36,12 @@ import org.apache.spark.sql.Dataset /** - * :: Experimental :: - * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] learning algorithm + * Decision tree learning algorithm (http://en.wikipedia.org/wiki/Decision_tree_learning) * for classification. * It supports both binary and multiclass labels, as well as both continuous and categorical * features. */ @Since("1.4.0") -@Experimental class DecisionTreeClassifier @Since("1.4.0") ( @Since("1.4.0") override val uid: String) extends ProbabilisticClassifier[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel] @@ -86,6 +84,13 @@ class DecisionTreeClassifier @Since("1.4.0") ( val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val numClasses: Int = getNumClasses(dataset) + + if (isDefined(thresholds)) { + require($(thresholds).length == numClasses, this.getClass.getSimpleName + + ".train() called with non-matching numClasses and thresholds.length." + + s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}") + } + val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses) val strategy = getOldStrategy(categoricalFeatures, numClasses) @@ -127,7 +132,6 @@ class DecisionTreeClassifier @Since("1.4.0") ( } @Since("1.4.0") -@Experimental object DecisionTreeClassifier extends DefaultParamsReadable[DecisionTreeClassifier] { /** Accessor for supported impurities: entropy, gini */ @Since("1.4.0") @@ -138,13 +142,11 @@ object DecisionTreeClassifier extends DefaultParamsReadable[DecisionTreeClassifi } /** - * :: Experimental :: - * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] model for classification. + * Decision tree model (http://en.wikipedia.org/wiki/Decision_tree_learning) for classification. * It supports both binary and multiclass labels, as well as both continuous and categorical * features. 
*/ @Since("1.4.0") -@Experimental class DecisionTreeClassificationModel private[ml] ( @Since("1.4.0")override val uid: String, @Since("1.4.0")override val rootNode: Node, @@ -243,7 +245,7 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassifica DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata)) val (nodeData, _) = NodeData.build(instance.rootNode, 0) val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(nodeData).write.parquet(dataPath) + sparkSession.createDataFrame(nodeData).write.parquet(dataPath) } } @@ -258,7 +260,7 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassifica val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] val numClasses = (metadata.metadata \ "numClasses").extract[Int] - val root = loadTreeNodes(path, metadata, sqlContext) + val root = loadTreeNodes(path, metadata, sparkSession) val model = new DecisionTreeClassificationModel(metadata.uid, root, numFeatures, numClasses) DefaultParamsReader.getAndSetParams(model, metadata) model diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index acc04582b8698..ba70293273f94 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -21,17 +21,17 @@ import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.json4s.{DefaultFormats, JObject} import org.json4s.JsonDSL._ -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.internal.Logging import org.apache.spark.ml.{PredictionModel, Predictor} +import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.regression.DecisionTreeRegressionModel import org.apache.spark.ml.tree._ import org.apache.spark.ml.tree.impl.GradientBoostedTrees import org.apache.spark.ml.util._ import org.apache.spark.ml.util.DefaultParamsReader.Metadata -import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel} import org.apache.spark.rdd.RDD @@ -40,8 +40,7 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.DoubleType /** - * :: Experimental :: - * [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]] + * Gradient-Boosted Trees (GBTs) (http://en.wikipedia.org/wiki/Gradient_boosting) * learning algorithm for classification. * It supports binary labels, as well as both continuous and categorical features. * Note: Multiclass labels are not currently supported. 
@@ -57,7 +56,6 @@ import org.apache.spark.sql.types.DoubleType * [https://issues.apache.org/jira/browse/SPARK-4240] */ @Since("1.4.0") -@Experimental class GBTClassifier @Since("1.4.0") ( @Since("1.4.0") override val uid: String) extends Predictor[Vector, GBTClassifier, GBTClassificationModel] @@ -149,7 +147,6 @@ class GBTClassifier @Since("1.4.0") ( } @Since("1.4.0") -@Experimental object GBTClassifier extends DefaultParamsReadable[GBTClassifier] { /** Accessor for supported loss settings: logistic */ @@ -161,8 +158,7 @@ object GBTClassifier extends DefaultParamsReadable[GBTClassifier] { } /** - * :: Experimental :: - * [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]] + * Gradient-Boosted Trees (GBTs) (http://en.wikipedia.org/wiki/Gradient_boosting) * model for classification. * It supports binary labels, as well as both continuous and categorical features. * Note: Multiclass labels are not currently supported. @@ -171,7 +167,6 @@ object GBTClassifier extends DefaultParamsReadable[GBTClassifier] { * @param _treeWeights Weights for the decision trees in the ensemble. */ @Since("1.6.0") -@Experimental class GBTClassificationModel private[ml]( @Since("1.6.0") override val uid: String, private val _trees: Array[DecisionTreeRegressionModel], @@ -238,8 +233,8 @@ class GBTClassificationModel private[ml]( * The importance vector is normalized to sum to 1. This method is suggested by Hastie et al. * (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.) * and follows the implementation from scikit-learn. - * - * @see [[DecisionTreeClassificationModel.featureImportances]] + + * See `DecisionTreeClassificationModel.featureImportances` */ @Since("2.0.0") lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(trees, numFeatures) @@ -270,7 +265,7 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] { val extraMetadata: JObject = Map( "numFeatures" -> instance.numFeatures, "numTrees" -> instance.getNumTrees) - EnsembleModelReadWrite.saveImpl(instance, path, sqlContext, extraMetadata) + EnsembleModelReadWrite.saveImpl(instance, path, sparkSession, extraMetadata) } } @@ -283,7 +278,7 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] { override def load(path: String): GBTClassificationModel = { implicit val format = DefaultFormats val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) = - EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName) + EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName) val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] val numTrees = (metadata.metadata \ "numTrees").extract[Int] diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index d2d4e249b4208..c50ee5dba28b2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -27,12 +27,13 @@ import org.apache.spark.SparkException import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.internal.Logging import org.apache.spark.ml.feature.Instance +import org.apache.spark.ml.linalg._ +import org.apache.spark.ml.linalg.BLAS._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ 
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics -import org.apache.spark.mllib.linalg._ -import org.apache.spark.mllib.linalg.BLAS._ +import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD @@ -71,10 +72,9 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas /** * Get threshold for binary classification. * - * If [[threshold]] is set, returns that value. - * Otherwise, if [[thresholds]] is set with length 2 (i.e., binary classification), + * If [[thresholds]] is set with length 2 (i.e., binary classification), * this returns the equivalent threshold: {{{1 / (1 + thresholds(0) / thresholds(1))}}}. - * Otherwise, returns [[threshold]] default value. + * Otherwise, returns [[threshold]] if set, or its default value if unset. * * @group getParam * @throws IllegalArgumentException if [[thresholds]] is set to an array of length other than 2. @@ -151,13 +151,11 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas } /** - * :: Experimental :: * Logistic regression. * Currently, this class only supports binary classification. It will support multiclass * in the future. */ @Since("1.2.0") -@Experimental class LogisticRegression @Since("1.2.0") ( @Since("1.4.0") override val uid: String) extends ProbabilisticClassifier[Vector, LogisticRegression, LogisticRegressionModel] @@ -294,6 +292,12 @@ class LogisticRegression @Since("1.2.0") ( val numClasses = histogram.length val numFeatures = summarizer.mean.size + if (isDefined(thresholds)) { + require($(thresholds).length == numClasses, this.getClass.getSimpleName + + ".train() called with non-matching numClasses and thresholds.length." + + s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}") + } + instr.logNumClasses(numClasses) instr.logNumFeatures(numFeatures) @@ -332,6 +336,13 @@ class LogisticRegression @Since("1.2.0") ( val featuresMean = summarizer.mean.toArray val featuresStd = summarizer.variance.toArray.map(math.sqrt) + if (!$(fitIntercept) && (0 until numFeatures).exists { i => + featuresStd(i) == 0.0 && featuresMean(i) != 0.0 }) { + logWarning("Fitting LogisticRegressionModel without intercept on dataset with " + + "constant nonzero column, Spark MLlib outputs zero coefficients for constant " + + "nonzero columns. This behavior is the same as R glmnet but different from LIBSVM.") + } + val regParamL1 = $(elasticNetParam) * $(regParam) val regParamL2 = (1.0 - $(elasticNetParam)) * $(regParam) @@ -366,9 +377,10 @@ class LogisticRegression @Since("1.2.0") ( Vectors.zeros(if ($(fitIntercept)) numFeatures + 1 else numFeatures) if (optInitialModel.isDefined && optInitialModel.get.coefficients.size != numFeatures) { - val vec = optInitialModel.get.coefficients + val vecSize = optInitialModel.get.coefficients.size logWarning( - s"Initial coefficients provided $vec did not match the expected size $numFeatures") + s"Initial coefficients will be ignored!! 
As its size $vecSize did not match the " + + s"expected size $numFeatures") } if (optInitialModel.isDefined && optInitialModel.get.coefficients.size == numFeatures) { @@ -398,7 +410,7 @@ class LogisticRegression @Since("1.2.0") ( } val states = optimizer.iterations(new CachedDiffFunction(costFun), - initialCoefficientsWithIntercept.toBreeze.toDenseVector) + initialCoefficientsWithIntercept.asBreeze.toDenseVector) /* Note that in Logistic Regression, the objective history (loss + regularization) @@ -418,6 +430,11 @@ class LogisticRegression @Since("1.2.0") ( throw new SparkException(msg) } + if (!state.actuallyConverged) { + logWarning("LogisticRegression training finished but the result " + + s"is not converged because: ${state.convergedReason.get.reason}") + } + /* The coefficients are trained in the scaled space; we're converting them back to the original space. @@ -467,14 +484,12 @@ object LogisticRegression extends DefaultParamsReadable[LogisticRegression] { } /** - * :: Experimental :: * Model produced by [[LogisticRegression]]. */ @Since("1.4.0") -@Experimental class LogisticRegressionModel private[spark] ( @Since("1.4.0") override val uid: String, - @Since("1.6.0") val coefficients: Vector, + @Since("2.0.0") val coefficients: Vector, @Since("1.3.0") val intercept: Double) extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel] with LogisticRegressionParams with MLWritable { @@ -612,7 +627,7 @@ class LogisticRegressionModel private[spark] ( } /** - * Returns a [[MLWriter]] instance for this ML instance. + * Returns a [[org.apache.spark.ml.util.MLWriter]] instance for this ML instance. * * For [[LogisticRegressionModel]], this does NOT currently save the training [[summary]]. * An option to save [[summary]] may be added in the future. @@ -651,7 +666,7 @@ object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] { val data = Data(instance.numClasses, instance.numFeatures, instance.intercept, instance.coefficients) val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } } @@ -665,13 +680,13 @@ object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val data = sqlContext.read.format("parquet").load(dataPath) - .select("numClasses", "numFeatures", "intercept", "coefficients").head() + val data = sparkSession.read.format("parquet").load(dataPath) + // We will need numClasses, numFeatures in the future for multinomial logreg support. - // val numClasses = data.getInt(0) - // val numFeatures = data.getInt(1) - val intercept = data.getDouble(2) - val coefficients = data.getAs[Vector](3) + val Row(numClasses: Int, numFeatures: Int, intercept: Double, coefficients: Vector) = + MLUtils.convertVectorColumnsToML(data, "coefficients") + .select("numClasses", "numFeatures", "intercept", "coefficients") + .head() val model = new LogisticRegressionModel(metadata.uid, coefficients, intercept) DefaultParamsReader.getAndSetParams(model, metadata) @@ -684,7 +699,7 @@ object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] { /** * MultiClassSummarizer computes the number of distinct labels and corresponding counts, * and validates the data to see if the labels used for k class multi-label classification - * are in the range of {0, 1, ..., k - 1} in a online fashion. 
+ * are in the range of {0, 1, ..., k - 1} in an online fashion. * * Two MultilabelSummarizer can be merged together to have a statistical summary of the * corresponding joint dataset. @@ -744,7 +759,7 @@ private[classification] class MultiClassSummarizer extends Serializable { def countInvalid: Long = totalInvalidCnt /** @return The number of distinct labels in the input dataset. */ - def numClasses: Int = distinctMap.keySet.max + 1 + def numClasses: Int = if (distinctMap.isEmpty) 0 else distinctMap.keySet.max + 1 /** @return The weightSum of each label in the input dataset. */ def histogram: Array[Double] = { @@ -845,19 +860,19 @@ class BinaryLogisticRegressionSummary private[classification] ( // TODO: Allow the user to vary the number of bins using a setBins method in // BinaryClassificationMetrics. For now the default is set to 100. @transient private val binaryMetrics = new BinaryClassificationMetrics( - predictions.select(probabilityCol, labelCol).rdd.map { + predictions.select(col(probabilityCol), col(labelCol).cast(DoubleType)).rdd.map { case Row(score: Vector, label: Double) => (score(1), label) }, 100 ) /** * Returns the receiver operating characteristic (ROC) curve, - * which is an Dataframe having two fields (FPR, TPR) + * which is a Dataframe having two fields (FPR, TPR) * with (0.0, 0.0) prepended and (1.0, 1.0) appended to it. + * See http://en.wikipedia.org/wiki/Receiver_operating_characteristic * - * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]]. + * Note: This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. * This will change in later Spark versions. - * @see http://en.wikipedia.org/wiki/Receiver_operating_characteristic */ @Since("1.5.0") @transient lazy val roc: DataFrame = binaryMetrics.roc().toDF("FPR", "TPR") @@ -865,17 +880,17 @@ class BinaryLogisticRegressionSummary private[classification] ( /** * Computes the area under the receiver operating characteristic (ROC) curve. * - * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]]. + * Note: This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. * This will change in later Spark versions. */ @Since("1.5.0") lazy val areaUnderROC: Double = binaryMetrics.areaUnderROC() /** - * Returns the precision-recall curve, which is an Dataframe containing + * Returns the precision-recall curve, which is a Dataframe containing * two fields recall, precision with (0.0, 1.0) prepended to it. * - * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]]. + * Note: This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. * This will change in later Spark versions. */ @Since("1.5.0") @@ -884,7 +899,7 @@ class BinaryLogisticRegressionSummary private[classification] ( /** * Returns a dataframe with two fields (threshold, F-Measure) curve with beta = 1.0. * - * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]]. + * Note: This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. * This will change in later Spark versions. */ @Since("1.5.0") @@ -897,7 +912,7 @@ class BinaryLogisticRegressionSummary private[classification] ( * Every possible probability obtained in transforming the dataset are used * as thresholds used in calculating the precision. 
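For context, the binary metrics documented in this summary class are reached through the model's training summary. A minimal sketch, assuming a DataFrame `training` with "label" and "features" columns; the cast reflects that a binary model's summary is a BinaryLogisticRegressionSummary:

```scala
import org.apache.spark.ml.classification.{BinaryLogisticRegressionSummary, LogisticRegression}

val lr = new LogisticRegression().setMaxIter(10).setRegParam(0.01)
val model = lr.fit(training)

// The binary summary exposes the (currently unweighted) metrics described above.
val summary = model.summary.asInstanceOf[BinaryLogisticRegressionSummary]
println(s"areaUnderROC = ${summary.areaUnderROC}")
summary.roc.show()                   // DataFrame with columns FPR, TPR
summary.precisionByThreshold.show()  // DataFrame with columns threshold, precision
```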
* - * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]]. + * Note: This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. * This will change in later Spark versions. */ @Since("1.5.0") @@ -910,7 +925,7 @@ class BinaryLogisticRegressionSummary private[classification] ( * Every possible probability obtained in transforming the dataset are used * as thresholds used in calculating the recall. * - * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]]. + * Note: This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. * This will change in later Spark versions. */ @Since("1.5.0") @@ -921,57 +936,54 @@ class BinaryLogisticRegressionSummary private[classification] ( /** * LogisticAggregator computes the gradient and loss for binary logistic loss function, as used - * in binary classification for instances in sparse or dense vector in a online fashion. + * in binary classification for instances in sparse or dense vector in an online fashion. * * Note that multinomial logistic loss is not supported yet! * * Two LogisticAggregator can be merged together to have a summary of loss and gradient of * the corresponding joint dataset. * - * @param coefficients The coefficients corresponding to the features. * @param numClasses the number of possible outcomes for k classes classification problem in * Multinomial Logistic Regression. * @param fitIntercept Whether to fit an intercept term. - * @param featuresStd The standard deviation values of the features. - * @param featuresMean The mean values of the features. */ private class LogisticAggregator( - coefficients: Vector, + private val numFeatures: Int, numClasses: Int, - fitIntercept: Boolean, - featuresStd: Array[Double], - featuresMean: Array[Double]) extends Serializable { + fitIntercept: Boolean) extends Serializable { private var weightSum = 0.0 private var lossSum = 0.0 - private val coefficientsArray = coefficients match { - case dv: DenseVector => dv.values - case _ => - throw new IllegalArgumentException( - s"coefficients only supports dense vector but got type ${coefficients.getClass}.") - } - - private val dim = if (fitIntercept) coefficientsArray.length - 1 else coefficientsArray.length - - private val gradientSumArray = Array.ofDim[Double](coefficientsArray.length) + private val gradientSumArray = + Array.ofDim[Double](if (fitIntercept) numFeatures + 1 else numFeatures) /** * Add a new training instance to this LogisticAggregator, and update the loss and gradient * of the objective function. * * @param instance The instance of data point to be added. + * @param coefficients The coefficients corresponding to the features. + * @param featuresStd The standard deviation values of the features. * @return This LogisticAggregator object. */ - def add(instance: Instance): this.type = { + def add( + instance: Instance, + coefficients: Vector, + featuresStd: Array[Double]): this.type = { instance match { case Instance(label, weight, features) => - require(dim == features.size, s"Dimensions mismatch when adding new instance." + - s" Expecting $dim but got ${features.size}.") + require(numFeatures == features.size, s"Dimensions mismatch when adding new instance." 
+ + s" Expecting $numFeatures but got ${features.size}.") require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0") if (weight == 0.0) return this - val localCoefficientsArray = coefficientsArray + val coefficientsArray = coefficients match { + case dv: DenseVector => dv.values + case _ => + throw new IllegalArgumentException( + s"coefficients only supports dense vector but got type ${coefficients.getClass}.") + } val localGradientSumArray = gradientSumArray numClasses match { @@ -981,11 +993,11 @@ private class LogisticAggregator( var sum = 0.0 features.foreachActive { (index, value) => if (featuresStd(index) != 0.0 && value != 0.0) { - sum += localCoefficientsArray(index) * (value / featuresStd(index)) + sum += coefficientsArray(index) * (value / featuresStd(index)) } } sum + { - if (fitIntercept) localCoefficientsArray(dim) else 0.0 + if (fitIntercept) coefficientsArray(numFeatures) else 0.0 } } @@ -998,7 +1010,7 @@ private class LogisticAggregator( } if (fitIntercept) { - localGradientSumArray(dim) += multiplier + localGradientSumArray(numFeatures) += multiplier } if (label > 0) { @@ -1025,8 +1037,8 @@ private class LogisticAggregator( * @return This LogisticAggregator object. */ def merge(other: LogisticAggregator): this.type = { - require(dim == other.dim, s"Dimensions mismatch when merging with another " + - s"LeastSquaresAggregator. Expecting $dim but got ${other.dim}.") + require(numFeatures == other.numFeatures, s"Dimensions mismatch when merging with another " + + s"LeastSquaresAggregator. Expecting $numFeatures but got ${other.numFeatures}.") if (other.weightSum != 0.0) { weightSum += other.weightSum @@ -1077,13 +1089,17 @@ private class LogisticCostFun( override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = { val numFeatures = featuresStd.length val coeffs = Vectors.fromBreeze(coefficients) + val n = coeffs.size + val localFeaturesStd = featuresStd + val logisticAggregator = { - val seqOp = (c: LogisticAggregator, instance: Instance) => c.add(instance) + val seqOp = (c: LogisticAggregator, instance: Instance) => + c.add(instance, coeffs, localFeaturesStd) val combOp = (c1: LogisticAggregator, c2: LogisticAggregator) => c1.merge(c2) instances.treeAggregate( - new LogisticAggregator(coeffs, numClasses, fitIntercept, featuresStd, featuresMean) + new LogisticAggregator(numFeatures, numClasses, fitIntercept) )(seqOp, combOp) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index 72cf55f6bb997..76ef32aa3dc1d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -24,30 +24,30 @@ import org.apache.hadoop.fs.Path import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams} import org.apache.spark.ml.ann.{FeedForwardTopology, FeedForwardTrainer} +import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed, HasStepSize, HasTol} import org.apache.spark.ml.util._ -import org.apache.spark.mllib.linalg.{Vector, Vectors} -import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.sql.Dataset /** Params for Multilayer Perceptron. 
*/ -private[ml] trait MultilayerPerceptronParams extends PredictorParams +private[classification] trait MultilayerPerceptronParams extends PredictorParams with HasSeed with HasMaxIter with HasTol with HasStepSize { /** * Layer sizes including input size and output size. - * Default: Array(1, 1) * * @group param */ + @Since("1.5.0") final val layers: IntArrayParam = new IntArrayParam(this, "layers", - "Sizes of layers from input layer to output layer" + - " E.g., Array(780, 100, 10) means 780 inputs, " + + "Sizes of layers from input layer to output layer. " + + "E.g., Array(780, 100, 10) means 780 inputs, " + "one hidden layer with 100 neurons and output layer of 10 neurons.", - (t: Array[Int]) => t.forall(ParamValidators.gt(0)) && t.length > 1 - ) + (t: Array[Int]) => t.forall(ParamValidators.gt(0)) && t.length > 1) /** @group getParam */ + @Since("1.5.0") final def getLayers: Array[Int] = $(layers) /** @@ -59,42 +59,49 @@ private[ml] trait MultilayerPerceptronParams extends PredictorParams * * @group expertParam */ + @Since("1.5.0") final val blockSize: IntParam = new IntParam(this, "blockSize", "Block size for stacking input data in matrices. Data is stacked within partitions." + " If block size is more than remaining data in a partition then " + "it is adjusted to the size of this data. Recommended size is between 10 and 1000", ParamValidators.gt(0)) - /** @group getParam */ + /** @group expertGetParam */ + @Since("1.5.0") final def getBlockSize: Int = $(blockSize) /** - * Allows setting the solver: minibatch gradient descent (gd) or l-bfgs. - * l-bfgs is the default one. + * The solver algorithm for optimization. + * Supported options: "gd" (minibatch gradient descent) or "l-bfgs". + * Default: "l-bfgs" * * @group expertParam */ + @Since("2.0.0") final val solver: Param[String] = new Param[String](this, "solver", - " Allows setting the solver: minibatch gradient descent (gd) or l-bfgs. " + - " l-bfgs is the default one.", - ParamValidators.inArray[String](Array("gd", "l-bfgs"))) + "The solver algorithm for optimization. Supported options: " + + s"${MultilayerPerceptronClassifier.supportedSolvers.mkString(", ")}. (Default l-bfgs)", + ParamValidators.inArray[String](MultilayerPerceptronClassifier.supportedSolvers)) - /** @group getParam */ - final def getOptimizer: String = $(solver) + /** @group expertGetParam */ + @Since("2.0.0") + final def getSolver: String = $(solver) /** - * Model weights. Can be returned either after training or after explicit setting + * The initial weights of the model. * * @group expertParam */ - final val weights: Param[Vector] = new Param[Vector](this, "weights", - " Sets the weights of the model ") - - /** @group getParam */ - final def getWeights: Vector = $(weights) + @Since("2.0.0") + final val initialWeights: Param[Vector] = new Param[Vector](this, "initialWeights", + "The initial weights of the model") + /** @group expertGetParam */ + @Since("2.0.0") + final def getInitialWeights: Vector = $(initialWeights) - setDefault(maxIter -> 100, tol -> 1e-4, blockSize -> 128, solver -> "l-bfgs", stepSize -> 0.03) + setDefault(maxIter -> 100, tol -> 1e-4, blockSize -> 128, + solver -> MultilayerPerceptronClassifier.LBFGS, stepSize -> 0.03) } /** Label to vector converter. */ @@ -145,14 +152,32 @@ class MultilayerPerceptronClassifier @Since("1.5.0") ( @Since("1.5.0") def this() = this(Identifiable.randomUID("mlpc")) - /** @group setParam */ + /** + * Sets the value of param [[layers]]. 
+ * + * @group setParam + */ @Since("1.5.0") def setLayers(value: Array[Int]): this.type = set(layers, value) - /** @group setParam */ + /** + * Sets the value of param [[blockSize]]. + * Default is 128. + * + * @group expertSetParam + */ @Since("1.5.0") def setBlockSize(value: Int): this.type = set(blockSize, value) + /** + * Sets the value of param [[solver]]. + * Default is "l-bfgs". + * + * @group expertSetParam + */ + @Since("2.0.0") + def setSolver(value: String): this.type = set(solver, value) + /** * Set the maximum number of iterations. * Default is 100. @@ -181,12 +206,21 @@ class MultilayerPerceptronClassifier @Since("1.5.0") ( def setSeed(value: Long): this.type = set(seed, value) /** - * Sets the model weights. + * Sets the value of param [[initialWeights]]. * - * @group expertParam + * @group expertSetParam + */ + @Since("2.0.0") + def setInitialWeights(value: Vector): this.type = set(initialWeights, value) + + /** + * Sets the value of param [[stepSize]] (applicable only for solver "gd"). + * Default is 0.03. + * + * @group setParam */ @Since("2.0.0") - def setWeights(value: Vector): this.type = set(weights, value) + def setStepSize(value: Double): this.type = set(stepSize, value) @Since("1.5.0") override def copy(extra: ParamMap): MultilayerPerceptronClassifier = defaultCopy(extra) @@ -204,16 +238,26 @@ class MultilayerPerceptronClassifier @Since("1.5.0") ( val labels = myLayers.last val lpData = extractLabeledPoints(dataset) val data = lpData.map(lp => LabelConverter.encodeLabeledPoint(lp, labels)) - val topology = FeedForwardTopology.multiLayerPerceptron(myLayers, true) + val topology = FeedForwardTopology.multiLayerPerceptron(myLayers, softmaxOnTop = true) val trainer = new FeedForwardTrainer(topology, myLayers(0), myLayers.last) - if (isDefined(weights)) { - trainer.setWeights($(weights)) + if (isDefined(initialWeights)) { + trainer.setWeights($(initialWeights)) } else { trainer.setSeed($(seed)) } - trainer.LBFGSOptimizer - .setConvergenceTol($(tol)) - .setNumIterations($(maxIter)) + if ($(solver) == MultilayerPerceptronClassifier.LBFGS) { + trainer.LBFGSOptimizer + .setConvergenceTol($(tol)) + .setNumIterations($(maxIter)) + } else if ($(solver) == MultilayerPerceptronClassifier.GD) { + trainer.SGDOptimizer + .setNumIterations($(maxIter)) + .setConvergenceTol($(tol)) + .setStepSize($(stepSize)) + } else { + throw new IllegalArgumentException( + s"The solver $solver is not supported by MultilayerPerceptronClassifier.") + } trainer.setStackSize($(blockSize)) val mlpModel = trainer.train(data) new MultilayerPerceptronClassificationModel(uid, myLayers, mlpModel.weights) @@ -224,6 +268,15 @@ class MultilayerPerceptronClassifier @Since("1.5.0") ( object MultilayerPerceptronClassifier extends DefaultParamsReadable[MultilayerPerceptronClassifier] { + /** String name for "l-bfgs" solver. */ + private[classification] val LBFGS = "l-bfgs" + + /** String name for "gd" (minibatch gradient descent) solver. */ + private[classification] val GD = "gd" + + /** Set of solvers that MultilayerPerceptronClassifier supports. 
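To illustrate the parameters documented above (layers, solver, blockSize, stepSize), a minimal configuration sketch; `training` is an assumed DataFrame with "label" and "features" columns, and the layer sizes are illustrative:

```scala
import org.apache.spark.ml.classification.MultilayerPerceptronClassifier

val mlp = new MultilayerPerceptronClassifier()
  .setLayers(Array(4, 5, 3))  // 4 inputs, one hidden layer of 5 neurons, 3 output classes
  .setSolver("l-bfgs")        // or "gd", in which case setStepSize(0.03) applies
  .setBlockSize(128)
  .setMaxIter(100)
  .setSeed(1234L)

val model = mlp.fit(training)
```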
*/ + private[classification] val supportedSolvers = Array(LBFGS, GD) + @Since("2.0.0") override def load(path: String): MultilayerPerceptronClassifier = super.load(path) } @@ -243,14 +296,16 @@ object MultilayerPerceptronClassifier class MultilayerPerceptronClassificationModel private[ml] ( @Since("1.5.0") override val uid: String, @Since("1.5.0") val layers: Array[Int], - @Since("1.5.0") val weights: Vector) + @Since("2.0.0") val weights: Vector) extends PredictionModel[Vector, MultilayerPerceptronClassificationModel] with Serializable with MLWritable { @Since("1.6.0") override val numFeatures: Int = layers.head - private val mlpModel = FeedForwardTopology.multiLayerPerceptron(layers, true).model(weights) + private val mlpModel = FeedForwardTopology + .multiLayerPerceptron(layers, softmaxOnTop = true) + .model(weights) /** * Returns layers in a Java List. @@ -301,7 +356,7 @@ object MultilayerPerceptronClassificationModel // Save model data: layers, weights val data = Data(instance.layers, instance.weights) val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } } @@ -315,7 +370,7 @@ object MultilayerPerceptronClassificationModel val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val data = sqlContext.read.parquet(dataPath).select("layers", "weights").head() + val data = sparkSession.read.parquet(dataPath).select("layers", "weights").head() val layers = data.getAs[Seq[Int]](0).toArray val weights = data.getAs[Vector](1) val model = new MultilayerPerceptronClassificationModel(metadata.uid, layers, weights) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index 267d63b51eb6c..f939a1c6808e6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -20,16 +20,17 @@ package org.apache.spark.ml.classification import org.apache.hadoop.fs.Path import org.apache.spark.SparkException -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.ml.PredictorParams +import org.apache.spark.ml.linalg._ import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators} import org.apache.spark.ml.util._ import org.apache.spark.mllib.classification.{NaiveBayes => OldNaiveBayes} import org.apache.spark.mllib.classification.{NaiveBayesModel => OldNaiveBayesModel} -import org.apache.spark.mllib.linalg._ -import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} +import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.sql.{Dataset, Row} /** * Params for Naive Bayes Classifiers. @@ -62,7 +63,6 @@ private[ml] trait NaiveBayesParams extends PredictorParams { } /** - * :: Experimental :: * Naive Bayes Classifiers. * It supports both Multinomial NB * ([[http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html]]) @@ -73,7 +73,6 @@ private[ml] trait NaiveBayesParams extends PredictorParams { * The input feature values must be nonnegative. 
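A minimal usage sketch for the Naive Bayes estimator described here; `training` and `test` are assumed DataFrames with a "label" column and a nonnegative "features" vector column:

```scala
import org.apache.spark.ml.classification.NaiveBayes

val nb = new NaiveBayes()
  .setModelType("multinomial")  // or "bernoulli"
  .setSmoothing(1.0)

val model = nb.fit(training)
val predictions = model.transform(test)
predictions.select("prediction", "probability").show()
```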
*/ @Since("1.5.0") -@Experimental class NaiveBayes @Since("1.5.0") ( @Since("1.5.0") override val uid: String) extends ProbabilisticClassifier[Vector, NaiveBayes, NaiveBayesModel] @@ -102,7 +101,16 @@ class NaiveBayes @Since("1.5.0") ( setDefault(modelType -> OldNaiveBayes.Multinomial) override protected def train(dataset: Dataset[_]): NaiveBayesModel = { - val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) + val numClasses = getNumClasses(dataset) + + if (isDefined(thresholds)) { + require($(thresholds).length == numClasses, this.getClass.getSimpleName + + ".train() called with non-matching numClasses and thresholds.length." + + s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}") + } + + val oldDataset: RDD[OldLabeledPoint] = + extractLabeledPoints(dataset).map(OldLabeledPoint.fromML) val oldModel = OldNaiveBayes.train(oldDataset, $(smoothing), $(modelType)) NaiveBayesModel.fromOld(oldModel, this) } @@ -119,18 +127,16 @@ object NaiveBayes extends DefaultParamsReadable[NaiveBayes] { } /** - * :: Experimental :: * Model produced by [[NaiveBayes]] * @param pi log of class priors, whose dimension is C (number of classes) * @param theta log of class conditional probabilities, whose dimension is C (number of classes) * by D (number of features) */ @Since("1.5.0") -@Experimental class NaiveBayesModel private[ml] ( @Since("1.5.0") override val uid: String, - @Since("1.5.0") val pi: Vector, - @Since("1.5.0") val theta: Matrix) + @Since("2.0.0") val pi: Vector, + @Since("2.0.0") val theta: Matrix) extends ProbabilisticClassificationModel[Vector, NaiveBayesModel] with NaiveBayesParams with MLWritable { @@ -261,7 +267,7 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] { // Save model data: pi, theta val data = Data(instance.pi, instance.theta) val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } } @@ -274,9 +280,11 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val data = sqlContext.read.parquet(dataPath).select("pi", "theta").head() - val pi = data.getAs[Vector](0) - val theta = data.getAs[Matrix](1) + val data = sparkSession.read.parquet(dataPath) + val vecConverted = MLUtils.convertVectorColumnsToML(data, "pi") + val Row(pi: Vector, theta: Matrix) = MLUtils.convertMatrixColumnsToML(vecConverted, "theta") + .select("pi", "theta") + .head() val model = new NaiveBayesModel(metadata.uid, pi, theta) DefaultParamsReader.getAndSetParams(model, metadata) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index f10c60a78df1a..f4ab0a074c420 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -29,12 +29,12 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.SparkContext -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.ml._ import org.apache.spark.ml.attribute._ +import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params} import org.apache.spark.ml.util._ -import 
org.apache.spark.mllib.linalg.Vector import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -117,7 +117,6 @@ private[ml] object OneVsRestParams extends ClassifierTypeTrait { } /** - * :: Experimental :: * Model produced by [[OneVsRest]]. * This stores the models resulting from training k binary classifiers: one for each class. * Each example is scored against all k models, and the model with the highest score @@ -130,7 +129,6 @@ private[ml] object OneVsRestParams extends ClassifierTypeTrait { * (taking label 0). */ @Since("1.4.0") -@Experimental final class OneVsRestModel private[ml] ( @Since("1.4.0") override val uid: String, private[ml] val labelMetadata: Metadata, @@ -260,8 +258,6 @@ object OneVsRestModel extends MLReadable[OneVsRestModel] { } /** - * :: Experimental :: - * * Reduction of Multiclass Classification to Binary Classification. * Performs reduction using one against all strategy. * For a multiclass classification with k classes, train k models (one per class). @@ -269,7 +265,6 @@ object OneVsRestModel extends MLReadable[OneVsRestModel] { * is picked to label the example. */ @Since("1.4.0") -@Experimental final class OneVsRest @Since("1.4.0") ( @Since("1.4.0") override val uid: String) extends Estimator[OneVsRestModel] with OneVsRestParams with MLWritable { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala index d00fee12b08c0..88642abf63221 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -18,9 +18,9 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.ml.linalg.{DenseVector, Vector, Vectors, VectorUDT} import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.SchemaUtils -import org.apache.spark.mllib.linalg.{DenseVector, Vector, Vectors, VectorUDT} import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DataType, StructType} @@ -45,7 +45,7 @@ private[classification] trait ProbabilisticClassifierParams * * Single-label binary or multiclass classifier which can output class conditional probabilities. * - * @tparam FeaturesType Type of input features. E.g., [[Vector]] + * @tparam FeaturesType Type of input features. E.g., `Vector` * @tparam E Concrete Estimator type * @tparam M Concrete Model type */ @@ -70,7 +70,7 @@ abstract class ProbabilisticClassifier[ * Model produced by a [[ProbabilisticClassifier]]. * Classes are indexed {0, 1, ..., numClasses - 1}. * - * @tparam FeaturesType Type of input features. E.g., [[Vector]] + * @tparam FeaturesType Type of input features. E.g., `Vector` * @tparam M Concrete Model type */ @DeveloperApi @@ -89,8 +89,8 @@ abstract class ProbabilisticClassificationModel[ * Transforms dataset by reading from [[featuresCol]], and appending new columns as specified by * parameters: * - predicted labels as [[predictionCol]] of type [[Double]] - * - raw predictions (confidences) as [[rawPredictionCol]] of type [[Vector]] - * - probability of each class as [[probabilityCol]] of type [[Vector]]. + * - raw predictions (confidences) as [[rawPredictionCol]] of type `Vector` + * - probability of each class as [[probabilityCol]] of type `Vector`. 
* * @param dataset input dataset * @return transformed dataset @@ -210,7 +210,7 @@ private[ml] object ProbabilisticClassificationModel { /** * Normalize a vector of raw predictions to be a multinomial probability vector, in place. * - * The input raw predictions should be >= 0. + * The input raw predictions should be nonnegative. * The output vector sums to 1, unless the input vector is all-0 (in which case the output is * all-0 too). * diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 1d33ae83c211f..52345b0626c47 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -20,14 +20,14 @@ package org.apache.spark.ml.classification import org.json4s.{DefaultFormats, JObject} import org.json4s.JsonDSL._ -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since +import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.tree._ import org.apache.spark.ml.tree.impl.RandomForest import org.apache.spark.ml.util._ import org.apache.spark.ml.util.DefaultParamsReader.Metadata -import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} -import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel} import org.apache.spark.rdd.RDD @@ -36,14 +36,12 @@ import org.apache.spark.sql.functions._ /** - * :: Experimental :: * [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] learning algorithm for * classification. * It supports both binary and multiclass labels, as well as both continuous and categorical * features. */ @Since("1.4.0") -@Experimental class RandomForestClassifier @Since("1.4.0") ( @Since("1.4.0") override val uid: String) extends ProbabilisticClassifier[Vector, RandomForestClassifier, RandomForestClassificationModel] @@ -102,6 +100,13 @@ class RandomForestClassifier @Since("1.4.0") ( val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val numClasses: Int = getNumClasses(dataset) + + if (isDefined(thresholds)) { + require($(thresholds).length == numClasses, this.getClass.getSimpleName + + ".train() called with non-matching numClasses and thresholds.length." + + s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}") + } + val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses) val strategy = super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity) @@ -124,7 +129,6 @@ class RandomForestClassifier @Since("1.4.0") ( } @Since("1.4.0") -@Experimental object RandomForestClassifier extends DefaultParamsReadable[RandomForestClassifier] { /** Accessor for supported impurity settings: entropy, gini */ @Since("1.4.0") @@ -140,7 +144,6 @@ object RandomForestClassifier extends DefaultParamsReadable[RandomForestClassifi } /** - * :: Experimental :: * [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] model for classification. * It supports both binary and multiclass labels, as well as both continuous and categorical * features. 
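As an illustration of the random forest classifier referenced here, a minimal sketch; `training` is an assumed DataFrame with "label" and "features" columns, and the tree settings are illustrative:

```scala
import org.apache.spark.ml.classification.RandomForestClassifier

val rf = new RandomForestClassifier()
  .setNumTrees(20)
  .setMaxDepth(5)
  .setSeed(42L)

val model = rf.fit(training)
// Per-feature importances, aggregated over the ensemble and normalized to sum to 1.
println(model.featureImportances)
```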
@@ -149,7 +152,6 @@ object RandomForestClassifier extends DefaultParamsReadable[RandomForestClassifi * Warning: These have null parents. */ @Since("1.4.0") -@Experimental class RandomForestClassificationModel private[ml] ( @Since("1.5.0") override val uid: String, private val _trees: Array[DecisionTreeClassificationModel], @@ -282,7 +284,7 @@ object RandomForestClassificationModel extends MLReadable[RandomForestClassifica "numFeatures" -> instance.numFeatures, "numClasses" -> instance.numClasses, "numTrees" -> instance.getNumTrees) - EnsembleModelReadWrite.saveImpl(instance, path, sqlContext, extraMetadata) + EnsembleModelReadWrite.saveImpl(instance, path, sparkSession, extraMetadata) } } @@ -296,7 +298,7 @@ object RandomForestClassificationModel extends MLReadable[RandomForestClassifica override def load(path: String): RandomForestClassificationModel = { implicit val format = DefaultFormats val (metadata: Metadata, treesData: Array[(Metadata, Node)], _) = - EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName) + EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName) val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] val numClasses = (metadata.metadata \ "numClasses").extract[Int] val numTrees = (metadata.metadata \ "numTrees").extract[Int] diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index 6cc9117da3fea..a97bd0fb16fd7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -21,12 +21,14 @@ import org.apache.hadoop.fs.Path import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.linalg.{Vector, VectorUDT} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ -import org.apache.spark.mllib.clustering. - {BisectingKMeans => MLlibBisectingKMeans, BisectingKMeansModel => MLlibBisectingKMeansModel} -import org.apache.spark.mllib.linalg.{Vector, VectorUDT} +import org.apache.spark.mllib.clustering.{BisectingKMeans => MLlibBisectingKMeans, BisectingKMeansModel => MLlibBisectingKMeansModel} +import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} +import org.apache.spark.mllib.linalg.VectorImplicits._ +import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{IntegerType, StructType} @@ -39,23 +41,27 @@ private[clustering] trait BisectingKMeansParams extends Params with HasMaxIter with HasFeaturesCol with HasSeed with HasPredictionCol { /** - * Set the number of clusters to create (k). Must be > 1. Default: 2. + * The desired number of leaf clusters. Must be > 1. Default: 4. + * The actual number could be smaller if there are no divisible leaf clusters. * @group param */ @Since("2.0.0") - final val k = new IntParam(this, "k", "number of clusters to create", (x: Int) => x > 1) + final val k = new IntParam(this, "k", "The desired number of leaf clusters. " + + "Must be > 1.", ParamValidators.gt(1)) /** @group getParam */ @Since("2.0.0") def getK: Int = $(k) - /** @group expertParam */ + /** + * The minimum number of points (if >= 1.0) or the minimum proportion + * of points (if < 1.0) of a divisible cluster (default: 1.0). 
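A brief sketch of the bisecting k-means parameters described above; `dataset` is an assumed DataFrame with a "features" vector column:

```scala
import org.apache.spark.ml.clustering.BisectingKMeans

val bkm = new BisectingKMeans()
  .setK(4)                          // desired number of leaf clusters
  .setMinDivisibleClusterSize(1.0)  // clusters with fewer points are not split further
  .setMaxIter(20)
  .setSeed(1L)

val model = bkm.fit(dataset)
println(s"within-cluster SSE = ${model.computeCost(dataset)}")
model.clusterCenters.foreach(println)
```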
+ * @group expertParam + */ @Since("2.0.0") - final val minDivisibleClusterSize = new DoubleParam( - this, - "minDivisibleClusterSize", - "the minimum number of points (if >= 1.0) or the minimum proportion", - (value: Double) => value > 0) + final val minDivisibleClusterSize = new DoubleParam(this, "minDivisibleClusterSize", + "The minimum number of points (if >= 1.0) or the minimum proportion " + + "of points (if < 1.0) of a divisible cluster.", ParamValidators.gt(0.0)) /** @group expertGetParam */ @Since("2.0.0") @@ -76,7 +82,7 @@ private[clustering] trait BisectingKMeansParams extends Params * :: Experimental :: * Model fitted by BisectingKMeans. * - * @param parentModel a model trained by spark.mllib.clustering.BisectingKMeans. + * @param parentModel a model trained by [[org.apache.spark.mllib.clustering.BisectingKMeans]]. */ @Since("2.0.0") @Experimental @@ -93,6 +99,7 @@ class BisectingKMeansModel private[ml] ( @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { + transformSchema(dataset.schema, logging = true) val predictUDF = udf((vector: Vector) => predict(vector)) dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) } @@ -105,7 +112,7 @@ class BisectingKMeansModel private[ml] ( private[clustering] def predict(features: Vector): Int = parentModel.predict(features) @Since("2.0.0") - def clusterCenters: Array[Vector] = parentModel.clusterCenters + def clusterCenters: Array[Vector] = parentModel.clusterCenters.map(_.asML) /** * Computes the sum of squared distances between the input points and their corresponding cluster @@ -115,7 +122,7 @@ class BisectingKMeansModel private[ml] ( def computeCost(dataset: Dataset[_]): Double = { SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT) val data = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point } - parentModel.computeCost(data) + parentModel.computeCost(data.map(OldVectors.fromML)) } @Since("2.0.0") @@ -216,7 +223,10 @@ class BisectingKMeans @Since("2.0.0") ( @Since("2.0.0") override def fit(dataset: Dataset[_]): BisectingKMeansModel = { - val rdd = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point } + transformSchema(dataset.schema, logging = true) + val rdd: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map { + case Row(point: Vector) => OldVectors.fromML(point) + } val bkm = new MLlibBisectingKMeans() .setK($(k)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index ac86e4ce25e82..69f060ad7711e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -20,19 +20,19 @@ package org.apache.spark.ml.clustering import breeze.linalg.{DenseVector => BDV} import org.apache.hadoop.fs.Path -import org.apache.spark.SparkContext import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.impl.Utils.EPSILON import org.apache.spark.ml.linalg._ -import org.apache.spark.ml.param.{IntParam, ParamMap, Params} +import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.stat.distribution.MultivariateGaussian import org.apache.spark.ml.util._ import org.apache.spark.mllib.clustering.{GaussianMixture => MLlibGM} import org.apache.spark.mllib.linalg.{Matrices => OldMatrices, Matrix => 
OldMatrix, - Vector => OldVector, Vectors => OldVectors, VectorUDT => OldVectorUDT} -import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext} + Vector => OldVector, Vectors => OldVectors} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{IntegerType, StructType} @@ -44,11 +44,12 @@ private[clustering] trait GaussianMixtureParams extends Params with HasMaxIter w with HasSeed with HasPredictionCol with HasProbabilityCol with HasTol { /** - * Set the number of clusters to create (k). Must be > 1. Default: 2. + * Number of independent Gaussians in the mixture model. Must be > 1. Default: 2. * @group param */ @Since("2.0.0") - final val k = new IntParam(this, "k", "number of clusters to create", (x: Int) => x > 1) + final val k = new IntParam(this, "k", "Number of independent Gaussians in the mixture model. " + + "Must be > 1.", ParamValidators.gt(1)) /** @group getParam */ @Since("2.0.0") @@ -60,9 +61,9 @@ private[clustering] trait GaussianMixtureParams extends Params with HasMaxIter w * @return output schema */ protected def validateAndTransformSchema(schema: StructType): StructType = { - SchemaUtils.checkColumnType(schema, $(featuresCol), new OldVectorUDT) + SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) - SchemaUtils.appendColumn(schema, $(probabilityCol), new OldVectorUDT) + SchemaUtils.appendColumn(schema, $(probabilityCol), new VectorUDT) } } @@ -94,8 +95,9 @@ class GaussianMixtureModel private[ml] ( @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { - val predUDF = udf((vector: OldVector) => predict(vector.asML)) - val probUDF = udf((vector: OldVector) => OldVectors.fromML(predictProbability(vector.asML))) + transformSchema(dataset.schema, logging = true) + val predUDF = udf((vector: Vector) => predict(vector)) + val probUDF = udf((vector: Vector) => predictProbability(vector)) dataset.withColumn($(predictionCol), predUDF(col($(featuresCol)))) .withColumn($(probabilityCol), probUDF(col($(featuresCol)))) } @@ -112,7 +114,7 @@ class GaussianMixtureModel private[ml] ( private[clustering] def predictProbability(features: Vector): Vector = { val probs: Array[Double] = - GaussianMixtureModel.computeProbabilities(features.toBreeze.toDenseVector, gaussians, weights) + GaussianMixtureModel.computeProbabilities(features.asBreeze.toDenseVector, gaussians, weights) Vectors.dense(probs) } @@ -132,11 +134,16 @@ class GaussianMixtureModel private[ml] ( val modelGaussians = gaussians.map { gaussian => (OldVectors.fromML(gaussian.mean), OldMatrices.fromML(gaussian.cov)) } - val sc = SparkContext.getOrCreate() - val sqlContext = SQLContext.getOrCreate(sc) - sqlContext.createDataFrame(modelGaussians).toDF("mean", "cov") + SparkSession.builder().getOrCreate().createDataFrame(modelGaussians).toDF("mean", "cov") } + /** + * Returns a [[org.apache.spark.ml.util.MLWriter]] instance for this ML instance. + * + * For [[GaussianMixtureModel]], this does NOT currently save the training [[summary]]. + * An option to save [[summary]] may be added in the future. 
+ * + */ @Since("2.0.0") override def write: MLWriter = new GaussianMixtureModel.GaussianMixtureModelWriter(this) @@ -189,7 +196,7 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] { val sigmas = gaussians.map(c => OldMatrices.fromML(c.cov)) val data = Data(weights, mus, sigmas) val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } } @@ -202,7 +209,7 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val row = sqlContext.read.parquet(dataPath).select("weights", "mus", "sigmas").head() + val row = sparkSession.read.parquet(dataPath).select("weights", "mus", "sigmas").head() val weights = row.getSeq[Double](0).toArray val mus = row.getSeq[OldVector](1).toArray val sigmas = row.getSeq[OldMatrix](2).toArray @@ -248,6 +255,21 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] { /** * :: Experimental :: * Gaussian Mixture clustering. + * + * This class performs expectation maximization for multivariate Gaussian + * Mixture Models (GMMs). A GMM represents a composite distribution of + * independent Gaussian distributions with associated "mixing" weights + * specifying each's contribution to the composite. + * + * Given a set of sample points, this class will maximize the log-likelihood + * for a mixture of k Gaussians, iterating until the log-likelihood changes by + * less than convergenceTol, or until it has reached the max number of iterations. + * While this process is generally guaranteed to converge, it is not guaranteed + * to find a global optimum. + * + * Note: For high-dimensional data (with many features), this algorithm may perform poorly. + * This is due to high-dimensional data (a) making it difficult to cluster at all (based + * on statistical/theoretical arguments) and (b) numerical issues with Gaussian distributions. 
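To make the EM-based mixture fitting described above concrete, a minimal sketch; `dataset` is an assumed DataFrame with a "features" vector column, and k = 2 is illustrative:

```scala
import org.apache.spark.ml.clustering.GaussianMixture

val gmm = new GaussianMixture()
  .setK(2)
  .setMaxIter(100)
  .setTol(0.01)

val model = gmm.fit(dataset)
// Mixing weight, mean, and covariance of each fitted Gaussian component.
model.weights.zip(model.gaussians).foreach { case (w, g) =>
  println(s"weight=$w mean=${g.mean} cov=${g.cov}")
}
```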
*/ @Since("2.0.0") @Experimental @@ -296,7 +318,10 @@ class GaussianMixture @Since("2.0.0") ( @Since("2.0.0") override def fit(dataset: Dataset[_]): GaussianMixtureModel = { - val rdd = dataset.select(col($(featuresCol))).rdd.map { case Row(point: OldVector) => point } + transformSchema(dataset.schema, logging = true) + val rdd: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map { + case Row(point: Vector) => OldVectors.fromML(point) + } val algo = new MLlibGM() .setK($(k)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 7c9ac02521ff6..6c46be719674b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -22,11 +22,14 @@ import org.apache.hadoop.fs.Path import org.apache.spark.SparkException import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model} -import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params} +import org.apache.spark.ml.linalg.{Vector, VectorUDT} +import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel} -import org.apache.spark.mllib.linalg.{Vector, VectorUDT} +import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} +import org.apache.spark.mllib.linalg.VectorImplicits._ +import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{IntegerType, StructType} @@ -38,11 +41,12 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe with HasSeed with HasPredictionCol with HasTol { /** - * Set the number of clusters to create (k). Must be > 1. Default: 2. + * The number of clusters to create (k). Must be > 1. Default: 2. * @group param */ @Since("1.5.0") - final val k = new IntParam(this, "k", "number of clusters to create", (x: Int) => x > 1) + final val k = new IntParam(this, "k", "The number of clusters to create. " + + "Must be > 1.", ParamValidators.gt(1)) /** @group getParam */ @Since("1.5.0") @@ -55,7 +59,8 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe * @group expertParam */ @Since("1.5.0") - final val initMode = new Param[String](this, "initMode", "initialization algorithm", + final val initMode = new Param[String](this, "initMode", "The initialization algorithm. " + + "Supported options: 'random' and 'k-means||'.", (value: String) => MLlibKMeans.validateInitMode(value)) /** @group expertGetParam */ @@ -68,8 +73,8 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe * @group expertParam */ @Since("1.5.0") - final val initSteps = new IntParam(this, "initSteps", "number of steps for k-means||", - (value: Int) => value > 0) + final val initSteps = new IntParam(this, "initSteps", "The number of steps for k-means|| " + + "initialization mode. 
Must be > 0.", ParamValidators.gt(0)) /** @group expertGetParam */ @Since("1.5.0") @@ -105,8 +110,17 @@ class KMeansModel private[ml] ( copyValues(copied, extra) } + /** @group setParam */ + @Since("2.0.0") + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group setParam */ + @Since("2.0.0") + def setPredictionCol(value: String): this.type = set(predictionCol, value) + @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { + transformSchema(dataset.schema, logging = true) val predictUDF = udf((vector: Vector) => predict(vector)) dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) } @@ -118,8 +132,8 @@ class KMeansModel private[ml] ( private[clustering] def predict(features: Vector): Int = parentModel.predict(features) - @Since("1.5.0") - def clusterCenters: Array[Vector] = parentModel.clusterCenters + @Since("2.0.0") + def clusterCenters: Array[Vector] = parentModel.clusterCenters.map(_.asML) /** * Return the K-means cost (sum of squared distances of points to their nearest center) for this @@ -129,10 +143,19 @@ class KMeansModel private[ml] ( @Since("2.0.0") def computeCost(dataset: Dataset[_]): Double = { SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT) - val data = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point } + val data: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map { + case Row(point: Vector) => OldVectors.fromML(point) + } parentModel.computeCost(data) } + /** + * Returns a [[org.apache.spark.ml.util.MLWriter]] instance for this ML instance. + * + * For [[KMeansModel]], this does NOT currently save the training [[summary]]. + * An option to save [[summary]] may be added in the future. + * + */ @Since("1.6.0") override def write: MLWriter = new KMeansModel.KMeansModelWriter(this) @@ -172,6 +195,12 @@ object KMeansModel extends MLReadable[KMeansModel] { /** Helper class for storing model data */ private case class Data(clusterIdx: Int, clusterCenter: Vector) + /** + * We store all cluster centers in a single row and use this class to store model data by + * Spark 1.6 and earlier. A model can be loaded from such older data for backward compatibility. 
+ */ + private case class OldData(clusterCenters: Array[OldVector]) + /** [[MLWriter]] instance for [[KMeansModel]] */ private[KMeansModel] class KMeansModelWriter(instance: KMeansModel) extends MLWriter { @@ -183,7 +212,7 @@ object KMeansModel extends MLReadable[KMeansModel] { Data(idx, center) } val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(data).repartition(1).write.parquet(dataPath) + sparkSession.createDataFrame(data).repartition(1).write.parquet(dataPath) } } @@ -194,16 +223,23 @@ object KMeansModel extends MLReadable[KMeansModel] { override def load(path: String): KMeansModel = { // Import implicits for Dataset Encoder - val sqlContext = super.sqlContext - import sqlContext.implicits._ + val sparkSession = super.sparkSession + import sparkSession.implicits._ val metadata = DefaultParamsReader.loadMetadata(path, sc, className) - val dataPath = new Path(path, "data").toString - val data: Dataset[Data] = sqlContext.read.parquet(dataPath).as[Data] - val clusterCenters = data.collect().sortBy(_.clusterIdx).map(_.clusterCenter) - val model = new KMeansModel(metadata.uid, new MLlibKMeansModel(clusterCenters)) + val versionRegex = "([0-9]+)\\.(.+)".r + val versionRegex(major, _) = metadata.sparkVersion + + val clusterCenters = if (major.toInt >= 2) { + val data: Dataset[Data] = sparkSession.read.parquet(dataPath).as[Data] + data.collect().sortBy(_.clusterIdx).map(_.clusterCenter).map(OldVectors.fromML) + } else { + // Loads KMeansModel stored with the old format used by Spark 1.6 and earlier. + sparkSession.read.parquet(dataPath).as[OldData].head().clusterCenters + } + val model = new KMeansModel(metadata.uid, new MLlibKMeansModel(clusterCenters)) DefaultParamsReader.getAndSetParams(model, metadata) model } @@ -269,7 +305,10 @@ class KMeans @Since("1.5.0") ( @Since("2.0.0") override def fit(dataset: Dataset[_]): KMeansModel = { - val rdd = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point } + transformSchema(dataset.schema, logging = true) + val rdd: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map { + case Row(point: Vector) => OldVectors.fromML(point) + } val instr = Instrumentation.create(this, rdd) instr.logParams(featuresCol, predictionCol, k, initMode, initSteps, maxIter, seed, tol) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index 38ecc5a102c12..8e233255b4e24 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -17,24 +17,33 @@ package org.apache.spark.ml.clustering -import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.fs.Path +import org.json4s.DefaultFormats +import org.json4s.JsonAST.JObject +import org.json4s.jackson.JsonMethods._ import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} import org.apache.spark.internal.Logging import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.linalg.{Matrix, Vector, Vectors, VectorUDT} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasCheckpointInterval, HasFeaturesCol, HasMaxIter, HasSeed} import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.DefaultParamsReader.Metadata import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedLDAModel, EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel, LDAOptimizer => OldLDAOptimizer, LocalLDAModel => 
OldLocalLDAModel, OnlineLDAOptimizer => OldOnlineLDAOptimizer} import org.apache.spark.mllib.impl.PeriodicCheckpointer -import org.apache.spark.mllib.linalg.{Matrix, Vector, Vectors, VectorUDT} +import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} +import org.apache.spark.mllib.linalg.MatrixImplicits._ +import org.apache.spark.mllib.linalg.VectorImplicits._ +import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} -import org.apache.spark.sql.functions.{col, monotonicallyIncreasingId, udf} +import org.apache.spark.sql.functions.{col, monotonically_increasing_id, udf} import org.apache.spark.sql.types.StructType +import org.apache.spark.util.VersionUtils private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasMaxIter @@ -46,8 +55,8 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM * @group param */ @Since("1.6.0") - final val k = new IntParam(this, "k", "number of topics (clusters) to infer", - ParamValidators.gt(1)) + final val k = new IntParam(this, "k", "The number of topics (clusters) to infer. " + + "Must be > 1.", ParamValidators.gt(1)) /** @group getParam */ @Since("1.6.0") @@ -76,6 +85,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM * - Values should be >= 0 * - default = uniformly (1.0 / k), following the implementation from * [[https://github.com/Blei-Lab/onlineldavb]]. + * * @group param */ @Since("1.6.0") @@ -117,6 +127,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM * - Value should be >= 0 * - default = (1.0 / k), following the implementation from * [[https://github.com/Blei-Lab/onlineldavb]]. + * * @group param */ @Since("1.6.0") @@ -161,7 +172,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM */ @Since("1.6.0") final val optimizer = new Param[String](this, "optimizer", "Optimizer or inference" + - " algorithm used to estimate the LDA model. Supported: " + supportedOptimizers.mkString(", "), + " algorithm used to estimate the LDA model. Supported: " + supportedOptimizers.mkString(", "), (o: String) => ParamValidators.inArray(supportedOptimizers).apply(o.toLowerCase)) /** @group getParam */ @@ -350,6 +361,39 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM } } +private object LDAParams { + + /** + * Equivalent to [[DefaultParamsReader.getAndSetParams()]], but handles [[LDA]] and [[LDAModel]] + * formats saved with Spark 1.6, which differ from the formats in Spark 2.0+. + * + * @param model [[LDA]] or [[LDAModel]] instance. This instance will be modified with + * [[Param]] values extracted from metadata. 
+ * @param metadata Loaded model metadata + */ + def getAndSetParams(model: LDAParams, metadata: Metadata): Unit = { + VersionUtils.majorMinorVersion(metadata.sparkVersion) match { + case (1, 6) => + implicit val format = DefaultFormats + metadata.params match { + case JObject(pairs) => + pairs.foreach { case (paramName, jsonValue) => + val origParam = + if (paramName == "topicDistribution") "topicDistributionCol" else paramName + val param = model.getParam(origParam) + val value = param.jsonDecode(compact(render(jsonValue))) + model.set(param, value) + } + case _ => + throw new IllegalArgumentException( + s"Cannot recognize JSON metadata: ${metadata.metadataJson}.") + } + case _ => // 2.0+ + DefaultParamsReader.getAndSetParams(model, metadata) + } + } +} + /** * :: Experimental :: @@ -405,12 +449,16 @@ sealed abstract class LDAModel private[ml] ( @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { if ($(topicDistributionCol).nonEmpty) { - val t = udf(oldLocalModel.getTopicDistributionMethod(sparkSession.sparkContext)) - dataset.withColumn($(topicDistributionCol), t(col($(featuresCol)))).toDF + + // TODO: Make the transformer natively in ml framework to avoid extra conversion. + val transformer = oldLocalModel.getTopicDistributionMethod(sparkSession.sparkContext) + + val t = udf { (v: Vector) => transformer(OldVectors.fromML(v)).asML } + dataset.withColumn($(topicDistributionCol), t(col($(featuresCol)))).toDF() } else { logWarning("LDAModel.transform was called without any output columns. Set an output column" + " such as topicDistributionCol to produce results.") - dataset.toDF + dataset.toDF() } } @@ -424,7 +472,7 @@ sealed abstract class LDAModel private[ml] ( * If Online LDA was used and [[optimizeDocConcentration]] was set to false, * then this returns the fixed (given) value for the [[docConcentration]] parameter. */ - @Since("1.6.0") + @Since("2.0.0") def estimatedDocConcentration: Vector = getModel.docConcentration /** @@ -436,8 +484,8 @@ sealed abstract class LDAModel private[ml] ( * the Expectation-Maximization ("em") [[optimizer]], then this method could involve * collecting a large amount of data to the driver (on the order of vocabSize x k). 
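Both the KMeansModel reader and LDAParams.getAndSetParams above branch on the Spark version recorded in the saved metadata, falling back to the 1.6 layout when the major version is 1. A minimal, library-free sketch of that version-parsing pattern; the object and helper names here are illustrative, not part of the patch:

```scala
// Illustrative sketch of the "parse metadata.sparkVersion, branch on major version"
// pattern used by the KMeansModel and LDA readers in this patch. Names are hypothetical.
object VersionGatedLoadSketch {
  /** Extract (major, minor) from a version string such as "1.6.3" or "2.0.0-SNAPSHOT". */
  def majorMinor(sparkVersion: String): (Int, Int) = {
    val versionRegex = """^(\d+)\.(\d+)(\..*)?$""".r
    sparkVersion match {
      case versionRegex(major, minor, _) => (major.toInt, minor.toInt)
      case _ =>
        throw new IllegalArgumentException(s"Cannot parse Spark version string: $sparkVersion")
    }
  }

  def main(args: Array[String]): Unit = {
    assert(majorMinor("1.6.3") == (1, 6))          // old save format (Spark 1.x)
    assert(majorMinor("2.0.0-SNAPSHOT") == (2, 0)) // new save format (Spark 2.x)
    val (major, _) = majorMinor("2.0.0")
    val message =
      if (major >= 2) "load cluster centers from the 2.0+ layout"
      else "load cluster centers from the legacy 1.6 layout"
    println(message)
  }
}
```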
*/ - @Since("1.6.0") - def topicsMatrix: Matrix = oldLocalModel.topicsMatrix + @Since("2.0.0") + def topicsMatrix: Matrix = oldLocalModel.topicsMatrix.asML /** Indicates whether this instance is of type [[DistributedLDAModel]] */ @Since("1.6.0") @@ -566,18 +614,16 @@ object LocalLDAModel extends MLReadable[LocalLDAModel] { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString val data = sparkSession.read.parquet(dataPath) - .select("vocabSize", "topicsMatrix", "docConcentration", "topicConcentration", - "gammaShape") - .head() - val vocabSize = data.getAs[Int](0) - val topicsMatrix = data.getAs[Matrix](1) - val docConcentration = data.getAs[Vector](2) - val topicConcentration = data.getAs[Double](3) - val gammaShape = data.getAs[Double](4) + val vectorConverted = MLUtils.convertVectorColumnsToML(data, "docConcentration") + val matrixConverted = MLUtils.convertMatrixColumnsToML(vectorConverted, "topicsMatrix") + val Row(vocabSize: Int, topicsMatrix: Matrix, docConcentration: Vector, + topicConcentration: Double, gammaShape: Double) = + matrixConverted.select("vocabSize", "topicsMatrix", "docConcentration", + "topicConcentration", "gammaShape").head() val oldModel = new OldLocalLDAModel(topicsMatrix, docConcentration, topicConcentration, gammaShape) val model = new LocalLDAModel(metadata.uid, vocabSize, oldModel, sparkSession) - DefaultParamsReader.getAndSetParams(model, metadata) + LDAParams.getAndSetParams(model, metadata) model } } @@ -667,6 +713,8 @@ class DistributedLDAModel private[ml] ( private var _checkpointFiles: Array[String] = oldDistributedModel.checkpointFiles /** + * :: DeveloperApi :: + * * If using checkpointing and [[LDA.keepLastCheckpoint]] is set to true, then there may be * saved checkpoint files. This method is provided so that users can manage those files. * @@ -681,6 +729,8 @@ class DistributedLDAModel private[ml] ( def getCheckpointFiles: Array[String] = _checkpointFiles /** + * :: DeveloperApi :: + * * Remove any remaining checkpoint files from training. 
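Many of the hunks in this file bridge the old `org.apache.spark.mllib.linalg` types and the new `org.apache.spark.ml.linalg` types via `asML`, `OldVectors.fromML`, and `MLUtils.convertVectorColumnsToML`. A small sketch of those conversions, assuming a local SparkSession and toy data purely for illustration:

```scala
import org.apache.spark.ml.linalg.{Vectors => NewVectors}
import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql.SparkSession

object LinalgConversionSketch {
  def main(args: Array[String]): Unit = {
    // Single-vector conversions used throughout the readers/transformers above.
    val newVec = NewVectors.dense(1.0, 2.0, 3.0)
    val oldVec = OldVectors.fromML(newVec) // ml -> mllib
    val backToNew = oldVec.asML            // mllib -> ml
    assert(backToNew == newVec)

    // Column-level conversion used when loading data saved with the old vector UDT.
    val spark = SparkSession.builder().master("local[2]").appName("convert").getOrCreate()
    import spark.implicits._
    val df = Seq((0, OldVectors.dense(0.5, 0.5))).toDF("id", "features")
    val converted = MLUtils.convertVectorColumnsToML(df, "features")
    converted.printSchema()
    spark.stop()
  }
}
```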
* * @see [[getCheckpointFiles]] @@ -688,8 +738,8 @@ class DistributedLDAModel private[ml] ( @DeveloperApi @Since("2.0.0") def deleteCheckpointFiles(): Unit = { - val fs = FileSystem.get(sparkSession.sparkContext.hadoopConfiguration) - _checkpointFiles.foreach(PeriodicCheckpointer.removeCheckpointFile(_, fs)) + val hadoopConf = sparkSession.sparkContext.hadoopConfiguration + _checkpointFiles.foreach(PeriodicCheckpointer.removeCheckpointFile(_, hadoopConf)) _checkpointFiles = Array.empty[String] } @@ -719,9 +769,9 @@ object DistributedLDAModel extends MLReadable[DistributedLDAModel] { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val modelPath = new Path(path, "oldModel").toString val oldModel = OldDistributedLDAModel.load(sc, modelPath) - val model = new DistributedLDAModel( - metadata.uid, oldModel.vocabSize, oldModel, sparkSession, None) - DefaultParamsReader.getAndSetParams(model, metadata) + val model = new DistributedLDAModel(metadata.uid, oldModel.vocabSize, + oldModel, sparkSession, None) + LDAParams.getAndSetParams(model, metadata) model } } @@ -868,20 +918,36 @@ class LDA @Since("1.6.0") ( } } - -private[clustering] object LDA extends DefaultParamsReadable[LDA] { +@Since("2.0.0") +object LDA extends MLReadable[LDA] { /** Get dataset for spark.mllib LDA */ - def getOldDataset(dataset: Dataset[_], featuresCol: String): RDD[(Long, Vector)] = { + private[clustering] def getOldDataset( + dataset: Dataset[_], + featuresCol: String): RDD[(Long, OldVector)] = { dataset - .withColumn("docId", monotonicallyIncreasingId()) + .withColumn("docId", monotonically_increasing_id()) .select("docId", featuresCol) .rdd .map { case Row(docId: Long, features: Vector) => - (docId, features) + (docId, OldVectors.fromML(features)) } } - @Since("1.6.0") + private class LDAReader extends MLReader[LDA] { + + private val className = classOf[LDA].getName + + override def load(path: String): LDA = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val model = new LDA(metadata.uid) + LDAParams.getAndSetParams(model, metadata) + model + } + } + + override def read: MLReader[LDA] = new LDAReader + + @Since("2.0.0") override def load(path: String): LDA = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala index 0cbc391d96f8a..bff72b20e1c3f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -18,11 +18,11 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.linalg.{Vector, VectorUDT} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils} import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics -import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.{Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.DoubleType diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala index 5f765c071b9cd..e7b949ddce344 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala +++ 
b/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala @@ -30,7 +30,8 @@ import org.apache.spark.sql.Dataset abstract class Evaluator extends Params { /** - * Evaluates model output and returns a scalar metric (larger is better). + * Evaluates model output and returns a scalar metric. + * The value of [[isLargerBetter]] specifies whether larger values are better. * * @param dataset a dataset that contains labels/observations and predictions. * @param paramMap parameter map that specifies the input columns and output metrics @@ -42,7 +43,9 @@ abstract class Evaluator extends Params { } /** - * Evaluates the output. + * Evaluates model output and returns a scalar metric. + * The value of [[isLargerBetter]] specifies whether larger values are better. + * * @param dataset a dataset that contains labels/observations and predictions. * @return metric */ @@ -50,7 +53,7 @@ abstract class Evaluator extends Params { def evaluate(dataset: Dataset[_]): Double /** - * Indicates whether the metric returned by [[evaluate()]] should be maximized (true, default) + * Indicates whether the metric returned by `evaluate` should be maximized (true, default) * or minimized (false). * A given evaluator may support multiple metrics which may be maximized or minimized. */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala index 3d89843a0b711..794b1e7d9d881 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.types.DoubleType /** * :: Experimental :: - * Evaluator for multiclass classification, which expects two input columns: score and label. + * Evaluator for multiclass classification, which expects two input columns: prediction and label. 
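The reworded Evaluator docs above stress that `evaluate` is only meaningful together with `isLargerBetter`; model-selection code has to flip the comparison for minimized metrics. A tiny, library-free sketch of that selection logic, with made-up metric values:

```scala
// Library-free sketch of how selection code should honor isLargerBetter.
// The case class and sample numbers are purely illustrative.
object MetricSelectionSketch {
  final case class EvalResult(paramSetting: String, metric: Double)

  def selectBest(results: Seq[EvalResult], isLargerBetter: Boolean): EvalResult = {
    if (isLargerBetter) results.maxBy(_.metric) else results.minBy(_.metric)
  }

  def main(args: Array[String]): Unit = {
    val results = Seq(
      EvalResult("regParam=0.1", 0.83),
      EvalResult("regParam=0.01", 0.87),
      EvalResult("regParam=1.0", 0.79))
    // e.g. areaUnderROC (larger is better) vs. RMSE (smaller is better)
    println(selectBest(results, isLargerBetter = true))   // picks 0.87
    println(selectBest(results, isLargerBetter = false))  // picks 0.79
  }
}
```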
*/ @Since("1.5.0") @Experimental @@ -39,16 +39,16 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid def this() = this(Identifiable.randomUID("mcEval")) /** - * param for metric name in evaluation (supports `"f1"` (default), `"precision"`, `"recall"`, - * `"weightedPrecision"`, `"weightedRecall"`) + * param for metric name in evaluation (supports `"f1"` (default), `"weightedPrecision"`, + * `"weightedRecall"`, `"accuracy"`) * @group param */ @Since("1.5.0") val metricName: Param[String] = { - val allowedParams = ParamValidators.inArray(Array("f1", "precision", - "recall", "weightedPrecision", "weightedRecall")) + val allowedParams = ParamValidators.inArray(Array("f1", "weightedPrecision", + "weightedRecall", "accuracy")) new Param(this, "metricName", "metric name in evaluation " + - "(f1|precision|recall|weightedPrecision|weightedRecall)", allowedParams) + "(f1|weightedPrecision|weightedRecall|accuracy)", allowedParams) } /** @group getParam */ @@ -82,22 +82,15 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid val metrics = new MulticlassMetrics(predictionAndLabels) val metric = $(metricName) match { case "f1" => metrics.weightedFMeasure - case "precision" => metrics.precision - case "recall" => metrics.recall case "weightedPrecision" => metrics.weightedPrecision case "weightedRecall" => metrics.weightedRecall + case "accuracy" => metrics.accuracy } metric } @Since("1.5.0") - override def isLargerBetter: Boolean = $(metricName) match { - case "f1" => true - case "precision" => true - case "recall" => true - case "weightedPrecision" => true - case "weightedRecall" => true - } + override def isLargerBetter: Boolean = true @Since("1.5.0") override def copy(extra: ParamMap): MulticlassClassificationEvaluator = defaultCopy(extra) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala index 898ac2cc8941b..2b0862c60fdf7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala @@ -19,25 +19,25 @@ package org.apache.spark.ml.feature import scala.collection.mutable.ArrayBuilder -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.BinaryAttribute +import org.apache.spark.ml.linalg._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util._ -import org.apache.spark.mllib.linalg._ import org.apache.spark.sql._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ /** - * :: Experimental :: * Binarize a column of continuous features given a threshold. 
*/ -@Experimental -final class Binarizer(override val uid: String) +@Since("1.4.0") +final class Binarizer @Since("1.4.0") (@Since("1.4.0") override val uid: String) extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable { + @Since("1.4.0") def this() = this(Identifiable.randomUID("binarizer")) /** @@ -47,21 +47,26 @@ final class Binarizer(override val uid: String) * Default: 0.0 * @group param */ + @Since("1.4.0") val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold used to binarize continuous features") /** @group getParam */ + @Since("1.4.0") def getThreshold: Double = $(threshold) /** @group setParam */ + @Since("1.4.0") def setThreshold(value: Double): this.type = set(threshold, value) setDefault(threshold -> 0.0) /** @group setParam */ + @Since("1.4.0") def setInputCol(value: String): this.type = set(inputCol, value) /** @group setParam */ + @Since("1.4.0") def setOutputCol(value: String): this.type = set(outputCol, value) @Since("2.0.0") @@ -96,6 +101,7 @@ final class Binarizer(override val uid: String) } } + @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { val inputType = schema($(inputCol)).dataType val outputColName = $(outputCol) @@ -104,9 +110,9 @@ final class Binarizer(override val uid: String) case DoubleType => BinaryAttribute.defaultAttr.withName(outputColName).toStructField() case _: VectorUDT => - new StructField(outputColName, new VectorUDT, true) - case other => - throw new IllegalArgumentException(s"Data type $other is not supported.") + StructField(outputColName, new VectorUDT) + case _ => + throw new IllegalArgumentException(s"Data type $inputType is not supported.") } if (schema.fieldNames.contains(outputColName)) { @@ -115,6 +121,7 @@ final class Binarizer(override val uid: String) StructType(schema.fields :+ outCol) } + @Since("1.4.1") override def copy(extra: ParamMap): Binarizer = defaultCopy(extra) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index 10e622ace6d5e..100d9e7f6cbcc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.feature import java.{util => ju} import org.apache.spark.SparkException -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.ml.Model import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.param._ @@ -31,41 +31,46 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, StructField, StructType} /** - * :: Experimental :: * `Bucketizer` maps a column of continuous features to a column of feature buckets. */ -@Experimental -final class Bucketizer(override val uid: String) +@Since("1.4.0") +final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String) extends Model[Bucketizer] with HasInputCol with HasOutputCol with DefaultParamsWritable { + @Since("1.4.0") def this() = this(Identifiable.randomUID("bucketizer")) /** * Parameter for mapping continuous features into buckets. With n+1 splits, there are n buckets. * A bucket defined by splits x,y holds values in the range [x,y) except the last bucket, which - * also includes y. Splits should be strictly increasing. + * also includes y. Splits should be of length >= 3 and strictly increasing. 
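A short usage sketch for the now-stable Binarizer above; the column names, threshold, and sample values are made up for illustration:

```scala
import org.apache.spark.ml.feature.Binarizer
import org.apache.spark.sql.SparkSession

object BinarizerSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("binarizer").getOrCreate()
    import spark.implicits._
    val df = Seq((0, 0.1), (1, 0.8), (2, 0.3)).toDF("id", "feature")

    val binarizer = new Binarizer()
      .setInputCol("feature")
      .setOutputCol("binarized_feature")
      .setThreshold(0.5)   // values above 0.5 become 1.0, the rest 0.0

    binarizer.transform(df).show()
    spark.stop()
  }
}
```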
* Values at -inf, inf must be explicitly provided to cover all Double values; * otherwise, values outside the splits specified will be treated as errors. * @group param */ + @Since("1.4.0") val splits: DoubleArrayParam = new DoubleArrayParam(this, "splits", "Split points for mapping continuous features into buckets. With n+1 splits, there are n " + "buckets. A bucket defined by splits x,y holds values in the range [x,y) except the last " + - "bucket, which also includes y. The splits should be strictly increasing. " + - "Values at -inf, inf must be explicitly provided to cover all Double values; " + + "bucket, which also includes y. The splits should be of length >= 3 and strictly " + + "increasing. Values at -inf, inf must be explicitly provided to cover all Double values; " + "otherwise, values outside the splits specified will be treated as errors.", Bucketizer.checkSplits) /** @group getParam */ + @Since("1.4.0") def getSplits: Array[Double] = $(splits) /** @group setParam */ + @Since("1.4.0") def setSplits(value: Array[Double]): this.type = set(splits, value) /** @group setParam */ + @Since("1.4.0") def setInputCol(value: String): this.type = set(inputCol, value) /** @group setParam */ + @Since("1.4.0") def setOutputCol(value: String): this.type = set(outputCol, value) @Since("2.0.0") @@ -86,16 +91,19 @@ final class Bucketizer(override val uid: String) attr.toStructField() } + @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType) SchemaUtils.appendColumn(schema, prepOutputField(schema)) } + @Since("1.4.1") override def copy(extra: ParamMap): Bucketizer = { defaultCopy[Bucketizer](extra).setParent(parent) } } +@Since("1.6.0") object Bucketizer extends DefaultParamsReadable[Bucketizer] { /** We require splits to be of length >= 3 and to be in strictly increasing order. */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala index cfecae7e0b152..1482eb3d1f7a6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala @@ -19,15 +19,17 @@ package org.apache.spark.ml.feature import org.apache.hadoop.fs.Path -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.ml._ import org.apache.spark.ml.attribute.{AttributeGroup, _} +import org.apache.spark.ml.linalg.{Vector, VectorUDT} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.feature -import org.apache.spark.mllib.linalg.{Vector, VectorUDT} -import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.linalg.{Vectors => OldVectors} +import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} +import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, StructField, StructType} @@ -40,8 +42,8 @@ private[feature] trait ChiSqSelectorParams extends Params /** * Number of features that selector will select (ordered by statistic value descending). If the - * number of features is < numTopFeatures, then this will select all features. The default value - * of numTopFeatures is 50. + * number of features is less than numTopFeatures, then this will select all features. 
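A matching sketch for Bucketizer, showing the length >= 3, strictly increasing splits requirement documented above, with -inf/inf endpoints covering all Double values; the sample data is illustrative:

```scala
import org.apache.spark.ml.feature.Bucketizer
import org.apache.spark.sql.SparkSession

object BucketizerSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("bucketizer").getOrCreate()
    import spark.implicits._

    // With n+1 splits there are n buckets; splits must be >= 3 long and strictly increasing.
    val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity)
    val df = Seq(-999.9, -0.5, -0.3, 0.0, 0.2, 999.9).toDF("features")

    val bucketizer = new Bucketizer()
      .setInputCol("features")
      .setOutputCol("bucketedFeatures")
      .setSplits(splits)

    bucketizer.transform(df).show()
    spark.stop()
  }
}
```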
+ * The default value of numTopFeatures is 50. * @group param */ final val numTopFeatures = new IntParam(this, "numTopFeatures", @@ -55,45 +57,52 @@ private[feature] trait ChiSqSelectorParams extends Params } /** - * :: Experimental :: * Chi-Squared feature selection, which selects categorical features to use for predicting a * categorical label. */ -@Experimental -final class ChiSqSelector(override val uid: String) +@Since("1.6.0") +final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: String) extends Estimator[ChiSqSelectorModel] with ChiSqSelectorParams with DefaultParamsWritable { + @Since("1.6.0") def this() = this(Identifiable.randomUID("chiSqSelector")) /** @group setParam */ + @Since("1.6.0") def setNumTopFeatures(value: Int): this.type = set(numTopFeatures, value) /** @group setParam */ + @Since("1.6.0") def setFeaturesCol(value: String): this.type = set(featuresCol, value) /** @group setParam */ + @Since("1.6.0") def setOutputCol(value: String): this.type = set(outputCol, value) /** @group setParam */ + @Since("1.6.0") def setLabelCol(value: String): this.type = set(labelCol, value) @Since("2.0.0") override def fit(dataset: Dataset[_]): ChiSqSelectorModel = { transformSchema(dataset.schema, logging = true) - val input = dataset.select($(labelCol), $(featuresCol)).rdd.map { - case Row(label: Double, features: Vector) => - LabeledPoint(label, features) - } + val input: RDD[OldLabeledPoint] = + dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map { + case Row(label: Double, features: Vector) => + OldLabeledPoint(label, OldVectors.fromML(features)) + } val chiSqSelector = new feature.ChiSqSelector($(numTopFeatures)).fit(input) copyValues(new ChiSqSelectorModel(uid, chiSqSelector).setParent(this)) } + @Since("1.6.0") override def transformSchema(schema: StructType): StructType = { SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) - SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) + SchemaUtils.checkNumericType(schema, $(labelCol)) SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT) } + @Since("1.6.0") override def copy(extra: ParamMap): ChiSqSelector = defaultCopy(extra) } @@ -105,37 +114,48 @@ object ChiSqSelector extends DefaultParamsReadable[ChiSqSelector] { } /** - * :: Experimental :: * Model fitted by [[ChiSqSelector]]. */ -@Experimental +@Since("1.6.0") final class ChiSqSelectorModel private[ml] ( - override val uid: String, + @Since("1.6.0") override val uid: String, private val chiSqSelector: feature.ChiSqSelectorModel) extends Model[ChiSqSelectorModel] with ChiSqSelectorParams with MLWritable { import ChiSqSelectorModel._ /** list of indices to select (filter). 
Must be ordered asc */ + @Since("1.6.0") val selectedFeatures: Array[Int] = chiSqSelector.selectedFeatures /** @group setParam */ + @Since("1.6.0") def setFeaturesCol(value: String): this.type = set(featuresCol, value) /** @group setParam */ + @Since("1.6.0") def setOutputCol(value: String): this.type = set(outputCol, value) - /** @group setParam */ + /** + * @group setParam + */ + @Since("1.6.0") + @deprecated("labelCol is not used by ChiSqSelectorModel.", "2.0.0") def setLabelCol(value: String): this.type = set(labelCol, value) @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { val transformedSchema = transformSchema(dataset.schema, logging = true) val newField = transformedSchema.last - val selector = udf { chiSqSelector.transform _ } + + // TODO: Make the transformer natively in ml framework to avoid extra conversion. + val transformer: Vector => Vector = v => chiSqSelector.transform(OldVectors.fromML(v)).asML + + val selector = udf(transformer) dataset.withColumn($(outputCol), selector(col($(featuresCol))), newField.metadata) } + @Since("1.6.0") override def transformSchema(schema: StructType): StructType = { SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) val newField = prepOutputField(schema) @@ -158,6 +178,7 @@ final class ChiSqSelectorModel private[ml] ( newAttributeGroup.toStructField() } + @Since("1.6.0") override def copy(extra: ParamMap): ChiSqSelectorModel = { val copied = new ChiSqSelectorModel(uid, chiSqSelector) copyValues(copied, extra).setParent(parent) @@ -179,7 +200,7 @@ object ChiSqSelectorModel extends MLReadable[ChiSqSelectorModel] { DefaultParamsWriter.saveMetadata(instance, path, sc) val data = Data(instance.selectedFeatures.toSeq) val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } } @@ -190,7 +211,7 @@ object ChiSqSelectorModel extends MLReadable[ChiSqSelectorModel] { override def load(path: String): ChiSqSelectorModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val data = sqlContext.read.parquet(dataPath).select("selectedFeatures").head() + val data = sparkSession.read.parquet(dataPath).select("selectedFeatures").head() val selectedFeatures = data.getAs[Seq[Int]](0).toArray val oldModel = new feature.ChiSqSelectorModel(selectedFeatures) val model = new ChiSqSelectorModel(metadata.uid, oldModel) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index 3fbfce9d48dd2..6299f74a6bf96 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -18,13 +18,13 @@ package org.apache.spark.ml.feature import org.apache.hadoop.fs.Path -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.broadcast.Broadcast import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.linalg.{Vectors, VectorUDT} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util._ -import org.apache.spark.mllib.linalg.{Vectors, VectorUDT} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ @@ -56,7 
+56,7 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit * If this is an integer >= 1, this specifies the number of documents the term must appear in; * if this is a double in [0,1), then this specifies the fraction of documents. * - * Default: 1 + * Default: 1.0 * @group param */ val minDF: DoubleParam = new DoubleParam(this, "minDF", "Specifies the minimum number of" + @@ -86,7 +86,7 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit * Note that the parameter is only used in transform of [[CountVectorizerModel]] and does not * affect fitting. * - * Default: 1 + * Default: 1.0 * @group param */ val minTF: DoubleParam = new DoubleParam(this, "minTF", "Filter to ignore rare words in" + @@ -96,8 +96,6 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit " of the document's token count). Note that the parameter is only used in transform of" + " CountVectorizerModel and does not affect fitting.", ParamValidators.gtEq(0.0)) - setDefault(minTF -> 1) - /** @group getParam */ def getMinTF: Double = $(minTF) @@ -114,39 +112,43 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit /** @group getParam */ def getBinary: Boolean = $(binary) - setDefault(binary -> false) + setDefault(vocabSize -> (1 << 18), minDF -> 1.0, minTF -> 1.0, binary -> false) } /** - * :: Experimental :: * Extracts a vocabulary from document collections and generates a [[CountVectorizerModel]]. */ -@Experimental -class CountVectorizer(override val uid: String) +@Since("1.5.0") +class CountVectorizer @Since("1.5.0") (@Since("1.5.0") override val uid: String) extends Estimator[CountVectorizerModel] with CountVectorizerParams with DefaultParamsWritable { + @Since("1.5.0") def this() = this(Identifiable.randomUID("cntVec")) /** @group setParam */ + @Since("1.5.0") def setInputCol(value: String): this.type = set(inputCol, value) /** @group setParam */ + @Since("1.5.0") def setOutputCol(value: String): this.type = set(outputCol, value) /** @group setParam */ + @Since("1.5.0") def setVocabSize(value: Int): this.type = set(vocabSize, value) /** @group setParam */ + @Since("1.5.0") def setMinDF(value: Double): this.type = set(minDF, value) /** @group setParam */ + @Since("1.5.0") def setMinTF(value: Double): this.type = set(minTF, value) /** @group setParam */ + @Since("2.0.0") def setBinary(value: Boolean): this.type = set(binary, value) - setDefault(vocabSize -> (1 << 18), minDF -> 1) - @Since("2.0.0") override def fit(dataset: Dataset[_]): CountVectorizerModel = { transformSchema(dataset.schema, logging = true) @@ -180,10 +182,12 @@ class CountVectorizer(override val uid: String) copyValues(new CountVectorizerModel(uid, vocab).setParent(this)) } + @Since("1.5.0") override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema) } + @Since("1.5.0") override def copy(extra: ParamMap): CountVectorizer = defaultCopy(extra) } @@ -195,31 +199,37 @@ object CountVectorizer extends DefaultParamsReadable[CountVectorizer] { } /** - * :: Experimental :: * Converts a text document to a sparse vector of token counts. * @param vocabulary An Array over terms. Only the terms in the vocabulary will be counted. 
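A usage sketch for CountVectorizer with the defaults consolidated above (vocabSize = 1 << 18, minDF = 1.0, minTF = 1.0, binary = false); the toy documents and parameter values are illustrative:

```scala
import org.apache.spark.ml.feature.CountVectorizer
import org.apache.spark.sql.SparkSession

object CountVectorizerSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("cv").getOrCreate()
    import spark.implicits._
    val df = Seq(
      (0, Array("a", "b", "c")),
      (1, Array("a", "b", "b", "c", "a"))).toDF("id", "words")

    val cvModel = new CountVectorizer()
      .setInputCol("words")
      .setOutputCol("features")
      .setVocabSize(3)     // keep at most 3 terms
      .setMinDF(2.0)       // a term must appear in at least 2 documents
      .fit(df)

    println(cvModel.vocabulary.mkString(", "))
    cvModel.transform(df).show(truncate = false)
    spark.stop()
  }
}
```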
*/ -@Experimental -class CountVectorizerModel(override val uid: String, val vocabulary: Array[String]) +@Since("1.5.0") +class CountVectorizerModel( + @Since("1.5.0") override val uid: String, + @Since("1.5.0") val vocabulary: Array[String]) extends Model[CountVectorizerModel] with CountVectorizerParams with MLWritable { import CountVectorizerModel._ + @Since("1.5.0") def this(vocabulary: Array[String]) = { this(Identifiable.randomUID("cntVecModel"), vocabulary) set(vocabSize, vocabulary.length) } /** @group setParam */ + @Since("1.5.0") def setInputCol(value: String): this.type = set(inputCol, value) /** @group setParam */ + @Since("1.5.0") def setOutputCol(value: String): this.type = set(outputCol, value) /** @group setParam */ + @Since("1.5.0") def setMinTF(value: Double): this.type = set(minTF, value) /** @group setParam */ + @Since("2.0.0") def setBinary(value: Boolean): this.type = set(binary, value) /** Dictionary created from [[vocabulary]] and its indices, broadcast once for [[transform()]] */ @@ -256,10 +266,12 @@ class CountVectorizerModel(override val uid: String, val vocabulary: Array[Strin dataset.withColumn($(outputCol), vectorizer(col($(inputCol)))) } + @Since("1.5.0") override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema) } + @Since("1.5.0") override def copy(extra: ParamMap): CountVectorizerModel = { val copied = new CountVectorizerModel(uid, vocabulary).setParent(parent) copyValues(copied, extra) @@ -281,7 +293,7 @@ object CountVectorizerModel extends MLReadable[CountVectorizerModel] { DefaultParamsWriter.saveMetadata(instance, path, sc) val data = Data(instance.vocabulary) val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } } @@ -292,7 +304,7 @@ object CountVectorizerModel extends MLReadable[CountVectorizerModel] { override def load(path: String): CountVectorizerModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val data = sqlContext.read.parquet(dataPath) + val data = sparkSession.read.parquet(dataPath) .select("vocabulary") .head() val vocabulary = data.getAs[Seq[String]](0).toArray diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala index a6f878151de73..6ff36b35ca4c1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala @@ -19,15 +19,14 @@ package org.apache.spark.ml.feature import edu.emory.mathcs.jtransforms.dct._ -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.ml.UnaryTransformer +import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.ml.param.BooleanParam import org.apache.spark.ml.util._ -import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.sql.types.DataType /** - * :: Experimental :: * A feature transformer that takes the 1D discrete cosine transform of a real vector. No zero * padding is performed on the input vector. * It returns a real vector of the same length representing the DCT. The return vector is scaled @@ -35,10 +34,11 @@ import org.apache.spark.sql.types.DataType * * More information on [[https://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II Wikipedia]]. 
*/ -@Experimental -class DCT(override val uid: String) +@Since("1.5.0") +class DCT @Since("1.5.0") (@Since("1.5.0") override val uid: String) extends UnaryTransformer[Vector, Vector, DCT] with DefaultParamsWritable { + @Since("1.5.0") def this() = this(Identifiable.randomUID("dct")) /** @@ -46,13 +46,16 @@ class DCT(override val uid: String) * Default: false * @group param */ + @Since("1.5.0") def inverse: BooleanParam = new BooleanParam( this, "inverse", "Set transformer to perform inverse DCT") /** @group setParam */ + @Since("1.5.0") def setInverse(value: Boolean): this.type = set(inverse, value) /** @group getParam */ + @Since("1.5.0") def getInverse: Boolean = $(inverse) setDefault(inverse -> false) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala index 1b0a9a12e83bc..f860b3a787b4d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala @@ -17,42 +17,46 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.ml.UnaryTransformer +import org.apache.spark.ml.linalg.{Vector, VectorUDT} import org.apache.spark.ml.param.Param import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable} import org.apache.spark.mllib.feature -import org.apache.spark.mllib.linalg.{Vector, VectorUDT} +import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.sql.types.DataType /** - * :: Experimental :: * Outputs the Hadamard product (i.e., the element-wise product) of each input vector with a * provided "weight" vector. In other words, it scales each column of the dataset by a scalar * multiplier. 
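A brief sketch of the DCT transformer from the hunk above, now operating on ml.linalg vectors; the sample vectors are chosen arbitrarily:

```scala
import org.apache.spark.ml.feature.DCT
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

object DCTSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("dct").getOrCreate()
    import spark.implicits._
    val df = Seq(
      Vectors.dense(0.0, 1.0, -2.0, 3.0),
      Vectors.dense(-1.0, 2.0, 4.0, -7.0)).map(Tuple1.apply).toDF("features")

    val dct = new DCT()
      .setInputCol("features")
      .setOutputCol("featuresDCT")
      .setInverse(false)   // forward DCT-II; set to true for the inverse transform

    dct.transform(df).select("featuresDCT").show(truncate = false)
    spark.stop()
  }
}
```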
*/ -@Experimental -class ElementwiseProduct(override val uid: String) +@Since("1.4.0") +class ElementwiseProduct @Since("1.4.0") (@Since("1.4.0") override val uid: String) extends UnaryTransformer[Vector, Vector, ElementwiseProduct] with DefaultParamsWritable { + @Since("1.4.0") def this() = this(Identifiable.randomUID("elemProd")) /** * the vector to multiply with input vectors * @group param */ + @Since("2.0.0") val scalingVec: Param[Vector] = new Param(this, "scalingVec", "vector for hadamard product") /** @group setParam */ + @Since("2.0.0") def setScalingVec(value: Vector): this.type = set(scalingVec, value) /** @group getParam */ + @Since("2.0.0") def getScalingVec: Vector = getOrDefault(scalingVec) override protected def createTransformFunc: Vector => Vector = { require(params.contains(scalingVec), s"transformation requires a weight vector") val elemScaler = new feature.ElementwiseProduct($(scalingVec)) - elemScaler.transform + v => elemScaler.transform(v) } override protected def outputDataType: DataType = new VectorUDT() diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala index 66ae91cfc0970..a8792a35ff4ae 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.param._ @@ -29,7 +29,6 @@ import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{ArrayType, StructType} /** - * :: Experimental :: * Maps a sequence of terms to their term frequencies using the hashing trick. * Currently we use Austin Appleby's MurmurHash 3 algorithm (MurmurHash3_x86_32) * to calculate the hash code value for the term object. @@ -37,16 +36,19 @@ import org.apache.spark.sql.types.{ArrayType, StructType} * it is advisable to use a power of two as the numFeatures parameter; * otherwise the features will not be mapped evenly to the columns. */ -@Experimental -class HashingTF(override val uid: String) +@Since("1.2.0") +class HashingTF @Since("1.4.0") (@Since("1.4.0") override val uid: String) extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable { + @Since("1.2.0") def this() = this(Identifiable.randomUID("hashingTF")) /** @group setParam */ + @Since("1.4.0") def setInputCol(value: String): this.type = set(inputCol, value) /** @group setParam */ + @Since("1.4.0") def setOutputCol(value: String): this.type = set(outputCol, value) /** @@ -54,6 +56,7 @@ class HashingTF(override val uid: String) * (default = 2^18^) * @group param */ + @Since("1.2.0") val numFeatures = new IntParam(this, "numFeatures", "number of features (> 0)", ParamValidators.gt(0)) @@ -64,6 +67,7 @@ class HashingTF(override val uid: String) * (default = false) * @group param */ + @Since("2.0.0") val binary = new BooleanParam(this, "binary", "If true, all non zero counts are set to 1. 
" + "This is useful for discrete probabilistic models that model binary events rather " + "than integer counts") @@ -71,26 +75,32 @@ class HashingTF(override val uid: String) setDefault(numFeatures -> (1 << 18), binary -> false) /** @group getParam */ + @Since("1.2.0") def getNumFeatures: Int = $(numFeatures) /** @group setParam */ + @Since("1.2.0") def setNumFeatures(value: Int): this.type = set(numFeatures, value) /** @group getParam */ + @Since("2.0.0") def getBinary: Boolean = $(binary) /** @group setParam */ + @Since("2.0.0") def setBinary(value: Boolean): this.type = set(binary, value) @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { val outputSchema = transformSchema(dataset.schema) val hashingTF = new feature.HashingTF($(numFeatures)).setBinary($(binary)) - val t = udf { terms: Seq[_] => hashingTF.transform(terms) } + // TODO: Make the hashingTF.transform natively in ml framework to avoid extra conversion. + val t = udf { terms: Seq[_] => hashingTF.transform(terms).asML } val metadata = outputSchema($(outputCol)).metadata dataset.select(col("*"), t(col($(inputCol))).as($(outputCol), metadata)) } + @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { val inputType = schema($(inputCol)).dataType require(inputType.isInstanceOf[ArrayType], @@ -99,6 +109,7 @@ class HashingTF(override val uid: String) SchemaUtils.appendColumn(schema, attrGroup.toStructField()) } + @Since("1.4.1") override def copy(extra: ParamMap): HashingTF = defaultCopy(extra) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala index 5075b78c9856a..6386dd8a10801 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -19,13 +19,16 @@ package org.apache.spark.ml.feature import org.apache.hadoop.fs.Path -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.ml._ +import org.apache.spark.ml.linalg.{Vector, VectorUDT} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.feature -import org.apache.spark.mllib.linalg.{Vector, VectorUDT} +import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.StructType @@ -36,12 +39,12 @@ import org.apache.spark.sql.types.StructType private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol { /** - * The minimum of documents in which a term should appear. + * The minimum number of documents in which a term should appear. * Default: 0 * @group param */ final val minDocFreq = new IntParam( - this, "minDocFreq", "minimum of documents in which a term should appear for filtering") + this, "minDocFreq", "minimum number of documents in which a term should appear for filtering") setDefault(minDocFreq -> 0) @@ -58,36 +61,43 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol } /** - * :: Experimental :: * Compute the Inverse Document Frequency (IDF) given a collection of documents. 
*/ -@Experimental -final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBase - with DefaultParamsWritable { +@Since("1.4.0") +final class IDF @Since("1.4.0") (@Since("1.4.0") override val uid: String) + extends Estimator[IDFModel] with IDFBase with DefaultParamsWritable { + @Since("1.4.0") def this() = this(Identifiable.randomUID("idf")) /** @group setParam */ + @Since("1.4.0") def setInputCol(value: String): this.type = set(inputCol, value) /** @group setParam */ + @Since("1.4.0") def setOutputCol(value: String): this.type = set(outputCol, value) /** @group setParam */ + @Since("1.4.0") def setMinDocFreq(value: Int): this.type = set(minDocFreq, value) @Since("2.0.0") override def fit(dataset: Dataset[_]): IDFModel = { transformSchema(dataset.schema, logging = true) - val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v } + val input: RDD[OldVector] = dataset.select($(inputCol)).rdd.map { + case Row(v: Vector) => OldVectors.fromML(v) + } val idf = new feature.IDF($(minDocFreq)).fit(input) copyValues(new IDFModel(uid, idf).setParent(this)) } + @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema) } + @Since("1.4.1") override def copy(extra: ParamMap): IDF = defaultCopy(extra) } @@ -99,42 +109,46 @@ object IDF extends DefaultParamsReadable[IDF] { } /** - * :: Experimental :: * Model fitted by [[IDF]]. */ -@Experimental +@Since("1.4.0") class IDFModel private[ml] ( - override val uid: String, + @Since("1.4.0") override val uid: String, idfModel: feature.IDFModel) extends Model[IDFModel] with IDFBase with MLWritable { import IDFModel._ /** @group setParam */ + @Since("1.4.0") def setInputCol(value: String): this.type = set(inputCol, value) /** @group setParam */ + @Since("1.4.0") def setOutputCol(value: String): this.type = set(outputCol, value) @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) - val idf = udf { vec: Vector => idfModel.transform(vec) } + // TODO: Make the idfModel.transform natively in ml framework to avoid extra conversion. + val idf = udf { vec: Vector => idfModel.transform(OldVectors.fromML(vec)).asML } dataset.withColumn($(outputCol), idf(col($(inputCol)))) } + @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema) } + @Since("1.4.1") override def copy(extra: ParamMap): IDFModel = { val copied = new IDFModel(uid, idfModel) copyValues(copied, extra).setParent(parent) } /** Returns the IDF vector. 
*/ - @Since("1.6.0") - def idf: Vector = idfModel.idf + @Since("2.0.0") + def idf: Vector = idfModel.idf.asML @Since("1.6.0") override def write: MLWriter = new IDFModelWriter(this) @@ -151,7 +165,7 @@ object IDFModel extends MLReadable[IDFModel] { DefaultParamsWriter.saveMetadata(instance, path, sc) val data = Data(instance.idf) val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } } @@ -162,11 +176,11 @@ object IDFModel extends MLReadable[IDFModel] { override def load(path: String): IDFModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val data = sqlContext.read.parquet(dataPath) + val data = sparkSession.read.parquet(dataPath) + val Row(idf: Vector) = MLUtils.convertVectorColumnsToML(data, "idf") .select("idf") .head() - val idf = data.getAs[Vector](0) - val model = new IDFModel(metadata.uid, new feature.IDFModel(idf)) + val model = new IDFModel(metadata.uid, new feature.IDFModel(OldVectors.fromML(idf))) DefaultParamsReader.getAndSetParams(model, metadata) model } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Instance.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Instance.scala index 12176757aee3d..cce3ca45ccd8f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Instance.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Instance.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.feature -import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.ml.linalg.Vector /** * Class that represents an instance of weighted data point with label and features. diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala index 9ca34e9ae22f4..96d0bdee9e2b9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala @@ -20,19 +20,18 @@ package org.apache.spark.ml.feature import scala.collection.mutable.ArrayBuilder import org.apache.spark.SparkException -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.ml.Transformer -import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} +import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ /** - * :: Experimental :: * Implements the feature interaction transform. This transformer takes in Double and Vector type * columns and outputs a flattened vector of their feature interactions. To handle interaction, * we first one-hot encode any nominal features. Then, a vector of the feature cross-products is @@ -43,8 +42,7 @@ import org.apache.spark.sql.types._ * with four categories, the output would then be `Vector(0, 0, 0, 0, 3, 4, 0, 0)`. 
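The HashingTF and IDF/IDFModel hunks above keep the public API on ml.linalg vectors while delegating to the old mllib implementations internally, so the usual TF-IDF chain is unchanged from a user's point of view. A sketch with toy sentences and a power-of-two numFeatures, as the HashingTF doc recommends:

```scala
import org.apache.spark.ml.feature.{HashingTF, IDF, Tokenizer}
import org.apache.spark.sql.SparkSession

object TfIdfSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("tfidf").getOrCreate()
    import spark.implicits._
    val sentences = Seq(
      (0.0, "spark is a fast engine"),
      (1.0, "spark ml pipelines are composable"),
      (2.0, "hashing maps terms to indices")).toDF("label", "sentence")

    val tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words")
    val words = tokenizer.transform(sentences)

    val hashingTF = new HashingTF()
      .setInputCol("words").setOutputCol("rawFeatures")
      .setNumFeatures(1 << 10)   // power of two, per the class doc

    val featurized = hashingTF.transform(words)
    val idfModel = new IDF().setInputCol("rawFeatures").setOutputCol("features").fit(featurized)
    idfModel.transform(featurized).select("label", "features").show(truncate = false)
    spark.stop()
  }
}
```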
*/ @Since("1.6.0") -@Experimental -class Interaction @Since("1.6.0") (override val uid: String) extends Transformer +class Interaction @Since("1.6.0") (@Since("1.6.0") override val uid: String) extends Transformer with HasInputCols with HasOutputCol with DefaultParamsWritable { @Since("1.6.0") @@ -70,6 +68,7 @@ class Interaction @Since("1.6.0") (override val uid: String) extends Transformer @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { + transformSchema(dataset.schema, logging = true) val inputFeatures = $(inputCols).map(c => dataset.schema(c)) val featureEncoders = getFeatureEncoders(inputFeatures) val featureAttrs = getFeatureAttrs(inputFeatures) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/LabeledPoint.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/LabeledPoint.scala new file mode 100644 index 0000000000000..6cefa7086c881 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/LabeledPoint.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import scala.beans.BeanInfo + +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.linalg.Vector + +/** + * :: Experimental :: + * + * Class that represents the features and labels of a data point. + * + * @param label Label for this data point. + * @param features List of features for this data point. 
+ */ +@Since("2.0.0") +@Experimental +@BeanInfo +case class LabeledPoint(@Since("2.0.0") label: Double, @Since("2.0.0") features: Vector) { + override def toString: String = { + s"($label,$features)" + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala index e9df600c8a991..acabf0b892660 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala @@ -21,11 +21,13 @@ import org.apache.hadoop.fs.Path import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.ml.param.{ParamMap, Params} import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util._ -import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} +import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} import org.apache.spark.mllib.stat.Statistics +import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{StructField, StructType} @@ -37,9 +39,7 @@ private[feature] trait MaxAbsScalerParams extends Params with HasInputCol with H /** Validates and transforms the input schema. */ protected def validateAndTransformSchema(schema: StructType): StructType = { - val inputType = schema($(inputCol)).dataType - require(inputType.isInstanceOf[VectorUDT], - s"Input column ${$(inputCol)} must be a vector column") + SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT) require(!schema.fieldNames.contains($(outputCol)), s"Output column ${$(outputCol)} already exists.") val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false) @@ -54,22 +54,27 @@ private[feature] trait MaxAbsScalerParams extends Params with HasInputCol with H * any sparsity. 
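The new org.apache.spark.ml.feature.LabeledPoint added above mirrors the old mllib class but carries an ml.linalg.Vector. Constructing one is straightforward; the label and feature values below are illustrative:

```scala
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.Vectors

object LabeledPointSketch {
  def main(args: Array[String]): Unit = {
    // Dense and sparse feature vectors, both from the new ml.linalg package.
    val dense = LabeledPoint(1.0, Vectors.dense(0.5, -0.2, 3.0))
    val sparse = LabeledPoint(0.0, Vectors.sparse(3, Array(0, 2), Array(0.5, 3.0)))
    println(dense)   // "(1.0,[0.5,-0.2,3.0])" via the case class toString
    println(sparse)
  }
}
```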
*/ @Experimental -class MaxAbsScaler @Since("2.0.0") (override val uid: String) +@Since("2.0.0") +class MaxAbsScaler @Since("2.0.0") (@Since("2.0.0") override val uid: String) extends Estimator[MaxAbsScalerModel] with MaxAbsScalerParams with DefaultParamsWritable { @Since("2.0.0") def this() = this(Identifiable.randomUID("maxAbsScal")) /** @group setParam */ + @Since("2.0.0") def setInputCol(value: String): this.type = set(inputCol, value) /** @group setParam */ + @Since("2.0.0") def setOutputCol(value: String): this.type = set(outputCol, value) @Since("2.0.0") override def fit(dataset: Dataset[_]): MaxAbsScalerModel = { transformSchema(dataset.schema, logging = true) - val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v } + val input: RDD[OldVector] = dataset.select($(inputCol)).rdd.map { + case Row(v: Vector) => OldVectors.fromML(v) + } val summary = Statistics.colStats(input) val minVals = summary.min.toArray val maxVals = summary.max.toArray @@ -79,17 +84,19 @@ class MaxAbsScaler @Since("2.0.0") (override val uid: String) copyValues(new MaxAbsScalerModel(uid, Vectors.dense(maxAbs)).setParent(this)) } + @Since("2.0.0") override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema) } + @Since("2.0.0") override def copy(extra: ParamMap): MaxAbsScaler = defaultCopy(extra) } -@Since("1.6.0") +@Since("2.0.0") object MaxAbsScaler extends DefaultParamsReadable[MaxAbsScaler] { - @Since("1.6.0") + @Since("2.0.0") override def load(path: String): MaxAbsScaler = super.load(path) } @@ -99,17 +106,20 @@ object MaxAbsScaler extends DefaultParamsReadable[MaxAbsScaler] { * */ @Experimental +@Since("2.0.0") class MaxAbsScalerModel private[ml] ( - override val uid: String, - val maxAbs: Vector) + @Since("2.0.0") override val uid: String, + @Since("2.0.0") val maxAbs: Vector) extends Model[MaxAbsScalerModel] with MaxAbsScalerParams with MLWritable { import MaxAbsScalerModel._ /** @group setParam */ + @Since("2.0.0") def setInputCol(value: String): this.type = set(inputCol, value) /** @group setParam */ + @Since("2.0.0") def setOutputCol(value: String): this.type = set(outputCol, value) @Since("2.0.0") @@ -118,16 +128,18 @@ class MaxAbsScalerModel private[ml] ( // TODO: this looks hack, we may have to handle sparse and dense vectors separately. 
val maxAbsUnzero = Vectors.dense(maxAbs.toArray.map(x => if (x == 0) 1 else x)) val reScale = udf { (vector: Vector) => - val brz = vector.toBreeze / maxAbsUnzero.toBreeze + val brz = vector.asBreeze / maxAbsUnzero.asBreeze Vectors.fromBreeze(brz) } dataset.withColumn($(outputCol), reScale(col($(inputCol)))) } + @Since("2.0.0") override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema) } + @Since("2.0.0") override def copy(extra: ParamMap): MaxAbsScalerModel = { val copied = new MaxAbsScalerModel(uid, maxAbs) copyValues(copied, extra).setParent(parent) @@ -137,7 +149,7 @@ class MaxAbsScalerModel private[ml] ( override def write: MLWriter = new MaxAbsScalerModelWriter(this) } -@Since("1.6.0") +@Since("2.0.0") object MaxAbsScalerModel extends MLReadable[MaxAbsScalerModel] { private[MaxAbsScalerModel] @@ -149,7 +161,7 @@ object MaxAbsScalerModel extends MLReadable[MaxAbsScalerModel] { DefaultParamsWriter.saveMetadata(instance, path, sc) val data = new Data(instance.maxAbs) val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } } @@ -160,7 +172,7 @@ object MaxAbsScalerModel extends MLReadable[MaxAbsScalerModel] { override def load(path: String): MaxAbsScalerModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val Row(maxAbs: Vector) = sqlContext.read.parquet(dataPath) + val Row(maxAbs: Vector) = sparkSession.read.parquet(dataPath) .select("maxAbs") .head() val model = new MaxAbsScalerModel(metadata.uid, maxAbs) @@ -169,9 +181,9 @@ object MaxAbsScalerModel extends MLReadable[MaxAbsScalerModel] { } } - @Since("1.6.0") + @Since("2.0.0") override def read: MLReader[MaxAbsScalerModel] = new MaxAbsScalerModelReader - @Since("1.6.0") + @Since("2.0.0") override def load(path: String): MaxAbsScalerModel = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala index 125becbb8a5b5..068f11a2a573a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -19,13 +19,17 @@ package org.apache.spark.ml.feature import org.apache.hadoop.fs.Path -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.ml.param.{DoubleParam, ParamMap, Params} import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util._ -import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} +import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} +import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.mllib.stat.Statistics +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{StructField, StructType} @@ -60,9 +64,7 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H /** Validates and transforms the input schema. 
*/ protected def validateAndTransformSchema(schema: StructType): StructType = { require($(min) < $(max), s"The specified min(${$(min)}) is larger or equal to max(${$(max)})") - val inputType = schema($(inputCol)).dataType - require(inputType.isInstanceOf[VectorUDT], - s"Input column ${$(inputCol)} must be a vector column") + SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT) require(!schema.fieldNames.contains($(outputCol)), s"Output column ${$(outputCol)} already exists.") val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false) @@ -72,49 +74,57 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H } /** - * :: Experimental :: * Rescale each feature individually to a common range [min, max] linearly using column summary * statistics, which is also known as min-max normalization or Rescaling. The rescaled value for * feature E is calculated as, * - * Rescaled(e_i) = \frac{e_i - E_{min}}{E_{max} - E_{min}} * (max - min) + min + * `Rescaled(e_i) = \frac{e_i - E_{min}}{E_{max} - E_{min}} * (max - min) + min` * - * For the case E_{max} == E_{min}, Rescaled(e_i) = 0.5 * (max + min) + * For the case `E_{max} == E_{min}`, `Rescaled(e_i) = 0.5 * (max + min)`. * Note that since zero values will probably be transformed to non-zero values, output of the * transformer will be DenseVector even for sparse input. */ -@Experimental -class MinMaxScaler(override val uid: String) +@Since("1.5.0") +class MinMaxScaler @Since("1.5.0") (@Since("1.5.0") override val uid: String) extends Estimator[MinMaxScalerModel] with MinMaxScalerParams with DefaultParamsWritable { + @Since("1.5.0") def this() = this(Identifiable.randomUID("minMaxScal")) setDefault(min -> 0.0, max -> 1.0) /** @group setParam */ + @Since("1.5.0") def setInputCol(value: String): this.type = set(inputCol, value) /** @group setParam */ + @Since("1.5.0") def setOutputCol(value: String): this.type = set(outputCol, value) /** @group setParam */ + @Since("1.5.0") def setMin(value: Double): this.type = set(min, value) /** @group setParam */ + @Since("1.5.0") def setMax(value: Double): this.type = set(max, value) @Since("2.0.0") override def fit(dataset: Dataset[_]): MinMaxScalerModel = { transformSchema(dataset.schema, logging = true) - val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v } + val input: RDD[OldVector] = dataset.select($(inputCol)).rdd.map { + case Row(v: Vector) => OldVectors.fromML(v) + } val summary = Statistics.colStats(input) copyValues(new MinMaxScalerModel(uid, summary.min, summary.max).setParent(this)) } + @Since("1.5.0") override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema) } + @Since("1.5.0") override def copy(extra: ParamMap): MinMaxScaler = defaultCopy(extra) } @@ -126,7 +136,6 @@ object MinMaxScaler extends DefaultParamsReadable[MinMaxScaler] { } /** - * :: Experimental :: * Model fitted by [[MinMaxScaler]]. * * @param originalMin min value for each original column during fitting @@ -134,30 +143,35 @@ object MinMaxScaler extends DefaultParamsReadable[MinMaxScaler] { * * TODO: The transformer does not yet set the metadata in the output column (SPARK-8529). 
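
To make the rescaling formula above concrete, here is a hedged usage sketch of `MinMaxScaler` (again assuming a `SparkSession` named `spark` and made-up data):

```scala
import org.apache.spark.ml.feature.MinMaxScaler
import org.apache.spark.ml.linalg.Vectors

// Assumes a running SparkSession named `spark`; data is illustrative only.
val df = spark.createDataFrame(Seq(
  (0, Vectors.dense(1.0, 0.1, -1.0)),
  (1, Vectors.dense(2.0, 1.1, 1.0)),
  (2, Vectors.dense(3.0, 10.1, 3.0))
)).toDF("id", "features")

val scaler = new MinMaxScaler()
  .setInputCol("features")
  .setOutputCol("scaledFeatures")
  .setMin(0.0)
  .setMax(1.0)

// For the first feature, E_min = 1.0 and E_max = 3.0, so the value 2.0 is
// rescaled to (2.0 - 1.0) / (3.0 - 1.0) * (1.0 - 0.0) + 0.0 = 0.5.
val model = scaler.fit(df)
model.transform(df).show(truncate = false)
```
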
*/ -@Experimental +@Since("1.5.0") class MinMaxScalerModel private[ml] ( - override val uid: String, - val originalMin: Vector, - val originalMax: Vector) + @Since("1.5.0") override val uid: String, + @Since("2.0.0") val originalMin: Vector, + @Since("2.0.0") val originalMax: Vector) extends Model[MinMaxScalerModel] with MinMaxScalerParams with MLWritable { import MinMaxScalerModel._ /** @group setParam */ + @Since("1.5.0") def setInputCol(value: String): this.type = set(inputCol, value) /** @group setParam */ + @Since("1.5.0") def setOutputCol(value: String): this.type = set(outputCol, value) /** @group setParam */ + @Since("1.5.0") def setMin(value: Double): this.type = set(min, value) /** @group setParam */ + @Since("1.5.0") def setMax(value: Double): this.type = set(max, value) @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { - val originalRange = (originalMax.toBreeze - originalMin.toBreeze).toArray + transformSchema(dataset.schema, logging = true) + val originalRange = (originalMax.asBreeze - originalMin.asBreeze).toArray val minArray = originalMin.toArray val reScale = udf { (vector: Vector) => @@ -178,10 +192,12 @@ class MinMaxScalerModel private[ml] ( dataset.withColumn($(outputCol), reScale(col($(inputCol)))) } + @Since("1.5.0") override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema) } + @Since("1.5.0") override def copy(extra: ParamMap): MinMaxScalerModel = { val copied = new MinMaxScalerModel(uid, originalMin, originalMax) copyValues(copied, extra).setParent(parent) @@ -203,7 +219,7 @@ object MinMaxScalerModel extends MLReadable[MinMaxScalerModel] { DefaultParamsWriter.saveMetadata(instance, path, sc) val data = new Data(instance.originalMin, instance.originalMax) val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } } @@ -214,9 +230,11 @@ object MinMaxScalerModel extends MLReadable[MinMaxScalerModel] { override def load(path: String): MinMaxScalerModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val Row(originalMin: Vector, originalMax: Vector) = sqlContext.read.parquet(dataPath) - .select("originalMin", "originalMax") - .head() + val data = sparkSession.read.parquet(dataPath) + val Row(originalMin: Vector, originalMax: Vector) = + MLUtils.convertVectorColumnsToML(data, "originalMin", "originalMax") + .select("originalMin", "originalMax") + .head() val model = new MinMaxScalerModel(metadata.uid, originalMin, originalMax) DefaultParamsReader.getAndSetParams(model, metadata) model diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala index f8bc7e3f0c031..4463aea0097e2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala @@ -17,14 +17,13 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param._ import org.apache.spark.ml.util._ import org.apache.spark.sql.types.{ArrayType, DataType, StringType} /** - * :: Experimental :: * A feature transformer that converts the input array of strings into an array of n-grams. Null * values in the input array are ignored. 
* It returns an array of n-grams where each n-gram is represented by a space-separated string of @@ -34,10 +33,11 @@ import org.apache.spark.sql.types.{ArrayType, DataType, StringType} * When the input array length is less than n (number of elements per n-gram), no n-grams are * returned. */ -@Experimental -class NGram(override val uid: String) +@Since("1.5.0") +class NGram @Since("1.5.0") (@Since("1.5.0") override val uid: String) extends UnaryTransformer[Seq[String], Seq[String], NGram] with DefaultParamsWritable { + @Since("1.5.0") def this() = this(Identifiable.randomUID("ngram")) /** @@ -45,13 +45,16 @@ class NGram(override val uid: String) * Default: 2, bigram features * @group param */ + @Since("1.5.0") val n: IntParam = new IntParam(this, "n", "number elements per n-gram (>=1)", ParamValidators.gtEq(1)) /** @group setParam */ + @Since("1.5.0") def setN(value: Int): this.type = set(n, value) /** @group getParam */ + @Since("1.5.0") def getN: Int = $(n) setDefault(n -> 2) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala index a603b3f833202..eb0690058013f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala @@ -17,22 +17,23 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.ml.UnaryTransformer +import org.apache.spark.ml.linalg.{Vector, VectorUDT} import org.apache.spark.ml.param.{DoubleParam, ParamValidators} import org.apache.spark.ml.util._ import org.apache.spark.mllib.feature -import org.apache.spark.mllib.linalg.{Vector, VectorUDT} +import org.apache.spark.mllib.linalg.{Vectors => OldVectors} import org.apache.spark.sql.types.DataType /** - * :: Experimental :: * Normalize a vector to have unit norm using the given p-norm. 
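
A short, hedged sketch of the `NGram` transformer covered in the hunk above (assuming a `SparkSession` named `spark`; data is illustrative):

```scala
import org.apache.spark.ml.feature.NGram

// Assumes a running SparkSession named `spark`; data is illustrative only.
val df = spark.createDataFrame(Seq(
  (0, Seq("Hi", "I", "heard", "about", "Spark")),
  (1, Seq("Logistic", "regression", "models", "are", "neat"))
)).toDF("id", "words")

val ngram = new NGram()
  .setN(2)
  .setInputCol("words")
  .setOutputCol("ngrams")

// Produces space-separated bigrams such as "Hi I", "I heard", "heard about", ...
ngram.transform(df).select("ngrams").show(truncate = false)
```
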
*/ -@Experimental -class Normalizer(override val uid: String) +@Since("1.4.0") +class Normalizer @Since("1.4.0") (@Since("1.4.0") override val uid: String) extends UnaryTransformer[Vector, Vector, Normalizer] with DefaultParamsWritable { + @Since("1.4.0") def this() = this(Identifiable.randomUID("normalizer")) /** @@ -40,19 +41,22 @@ class Normalizer(override val uid: String) * (default: p = 2) * @group param */ + @Since("1.4.0") val p = new DoubleParam(this, "p", "the p norm value", ParamValidators.gtEq(1)) setDefault(p -> 2.0) /** @group getParam */ + @Since("1.4.0") def getP: Double = $(p) /** @group setParam */ + @Since("1.4.0") def setP(value: Double): this.type = set(p, value) override protected def createTransformFunc: Vector => Vector = { val normalizer = new feature.Normalizer($(p)) - normalizer.transform + vector => normalizer.transform(OldVectors.fromML(vector)).asML } override protected def outputDataType: DataType = new VectorUDT() diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index 99357793dbaeb..8b04b5de6fd2b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -17,19 +17,18 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute._ +import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util._ -import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{DoubleType, NumericType, StructType} /** - * :: Experimental :: * A one-hot encoder that maps a column of category indices to a column of binary vectors, with * at most a single one-value per row that indicates the input category index. 
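
For the `Normalizer` changes above, a minimal usage sketch (not part of the patch), assuming a `SparkSession` named `spark` and invented data:

```scala
import org.apache.spark.ml.feature.Normalizer
import org.apache.spark.ml.linalg.Vectors

// Assumes a running SparkSession named `spark`; data is illustrative only.
val df = spark.createDataFrame(Seq(
  (0, Vectors.dense(1.0, 0.5, -1.0)),
  (1, Vectors.dense(2.0, 1.0, 1.0))
)).toDF("id", "features")

// p = 1.0 gives L^1 normalization: each vector is divided by the sum of the
// absolute values of its components.
val normalizer = new Normalizer()
  .setInputCol("features")
  .setOutputCol("normFeatures")
  .setP(1.0)

normalizer.transform(df).show(truncate = false)
```
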
* For example with 5 categories, an input value of 2.0 would map to an output vector of @@ -42,29 +41,39 @@ import org.apache.spark.sql.types.{DoubleType, NumericType, StructType} * * @see [[StringIndexer]] for converting categorical values into category indices */ -@Experimental -class OneHotEncoder(override val uid: String) extends Transformer +@Since("1.4.0") +class OneHotEncoder @Since("1.4.0") (@Since("1.4.0") override val uid: String) extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable { + @Since("1.4.0") def this() = this(Identifiable.randomUID("oneHot")) /** * Whether to drop the last category in the encoded vector (default: true) * @group param */ + @Since("1.4.0") final val dropLast: BooleanParam = new BooleanParam(this, "dropLast", "whether to drop the last category") setDefault(dropLast -> true) + /** @group getParam */ + @Since("2.0.0") + def getDropLast: Boolean = $(dropLast) + /** @group setParam */ + @Since("1.4.0") def setDropLast(value: Boolean): this.type = set(dropLast, value) /** @group setParam */ + @Since("1.4.0") def setInputCol(value: String): this.type = set(inputCol, value) /** @group setParam */ + @Since("1.4.0") def setOutputCol(value: String): this.type = set(outputCol, value) + @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { val inputColName = $(inputCol) val outputColName = $(outputCol) @@ -168,6 +177,7 @@ class OneHotEncoder(override val uid: String) extends Transformer dataset.select(col("*"), encode(col(inputColName).cast(DoubleType)).as(outputColName, metadata)) } + @Since("1.4.1") override def copy(extra: ParamMap): OneHotEncoder = defaultCopy(extra) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala index 9cf722e121697..6b913480fdc28 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala @@ -19,13 +19,18 @@ package org.apache.spark.ml.feature import org.apache.hadoop.fs.Path -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.ml._ +import org.apache.spark.ml.linalg._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.feature -import org.apache.spark.mllib.linalg._ +import org.apache.spark.mllib.linalg.{DenseMatrix => OldDenseMatrix, DenseVector => OldDenseVector, + Matrices => OldMatrices, Vector => OldVector, Vectors => OldVectors} +import org.apache.spark.mllib.linalg.MatrixImplicits._ +import org.apache.spark.mllib.linalg.VectorImplicits._ +import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{StructField, StructType} @@ -44,25 +49,39 @@ private[feature] trait PCAParams extends Params with HasInputCol with HasOutputC /** @group getParam */ def getK: Int = $(k) + /** Validates and transforms the input schema. */ + protected def validateAndTransformSchema(schema: StructType): StructType = { + SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT) + require(!schema.fieldNames.contains($(outputCol)), + s"Output column ${$(outputCol)} already exists.") + val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false) + StructType(outputFields) + } + } /** - * :: Experimental :: - * PCA trains a model to project vectors to a low-dimensional space using PCA. 
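
A hedged sketch of the `OneHotEncoder` usage pattern described in the hunk above, paired with `StringIndexer` as the scaladoc suggests (SparkSession `spark` and data are assumptions for the example):

```scala
import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer}

// Assumes a running SparkSession named `spark`; data is illustrative only.
val df = spark.createDataFrame(Seq(
  (0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")
)).toDF("id", "category")

val indexerModel = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("categoryIndex")
  .fit(df)
val indexed = indexerModel.transform(df)

// With dropLast = true (the default), the last category index is encoded as
// the all-zero vector, so each row has at most a single one-value.
val encoder = new OneHotEncoder()
  .setInputCol("categoryIndex")
  .setOutputCol("categoryVec")
encoder.transform(indexed).show()
```
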
+ * PCA trains a model to project vectors to a lower dimensional space of the top [[PCA!.k]] + * principal components. */ -@Experimental -class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams - with DefaultParamsWritable { +@Since("1.5.0") +class PCA @Since("1.5.0") ( + @Since("1.5.0") override val uid: String) + extends Estimator[PCAModel] with PCAParams with DefaultParamsWritable { + @Since("1.5.0") def this() = this(Identifiable.randomUID("pca")) /** @group setParam */ + @Since("1.5.0") def setInputCol(value: String): this.type = set(inputCol, value) /** @group setParam */ + @Since("1.5.0") def setOutputCol(value: String): this.type = set(outputCol, value) /** @group setParam */ + @Since("1.5.0") def setK(value: Int): this.type = set(k, value) /** @@ -71,22 +90,20 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams @Since("2.0.0") override def fit(dataset: Dataset[_]): PCAModel = { transformSchema(dataset.schema, logging = true) - val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v} + val input: RDD[OldVector] = dataset.select($(inputCol)).rdd.map { + case Row(v: Vector) => OldVectors.fromML(v) + } val pca = new feature.PCA(k = $(k)) val pcaModel = pca.fit(input) copyValues(new PCAModel(uid, pcaModel.pc, pcaModel.explainedVariance).setParent(this)) } + @Since("1.5.0") override def transformSchema(schema: StructType): StructType = { - val inputType = schema($(inputCol)).dataType - require(inputType.isInstanceOf[VectorUDT], - s"Input column ${$(inputCol)} must be a vector column") - require(!schema.fieldNames.contains($(outputCol)), - s"Output column ${$(outputCol)} already exists.") - val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false) - StructType(outputFields) + validateAndTransformSchema(schema) } + @Since("1.5.0") override def copy(extra: ParamMap): PCA = defaultCopy(extra) } @@ -98,26 +115,27 @@ object PCA extends DefaultParamsReadable[PCA] { } /** - * :: Experimental :: - * Model fitted by [[PCA]]. + * Model fitted by [[PCA]]. Transforms vectors to a lower dimensional space. * * @param pc A principal components Matrix. Each column is one principal component. * @param explainedVariance A vector of proportions of variance explained by * each principal component. */ -@Experimental +@Since("1.5.0") class PCAModel private[ml] ( - override val uid: String, - val pc: DenseMatrix, - val explainedVariance: DenseVector) + @Since("1.5.0") override val uid: String, + @Since("2.0.0") val pc: DenseMatrix, + @Since("2.0.0") val explainedVariance: DenseVector) extends Model[PCAModel] with PCAParams with MLWritable { import PCAModel._ /** @group setParam */ + @Since("1.5.0") def setInputCol(value: String): this.type = set(inputCol, value) /** @group setParam */ + @Since("1.5.0") def setOutputCol(value: String): this.type = set(outputCol, value) /** @@ -128,21 +146,23 @@ class PCAModel private[ml] ( @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) - val pcaModel = new feature.PCAModel($(k), pc, explainedVariance) - val pcaOp = udf { pcaModel.transform _ } + val pcaModel = new feature.PCAModel($(k), + OldMatrices.fromML(pc).asInstanceOf[OldDenseMatrix], + OldVectors.fromML(explainedVariance).asInstanceOf[OldDenseVector]) + + // TODO: Make the transformer natively in ml framework to avoid extra conversion. 
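
As an aside, a minimal sketch exercising the `PCA` estimator and `PCAModel` updated above (assuming a `SparkSession` named `spark`; the vectors are toy data):

```scala
import org.apache.spark.ml.feature.PCA
import org.apache.spark.ml.linalg.Vectors

// Assumes a running SparkSession named `spark`; data is illustrative only.
val data = Seq(
  Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))),
  Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0),
  Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)
)
val df = spark.createDataFrame(data.map(Tuple1.apply)).toDF("features")

// Projects the 5-dimensional input vectors onto the top 3 principal components.
val pcaModel = new PCA()
  .setInputCol("features")
  .setOutputCol("pcaFeatures")
  .setK(3)
  .fit(df)

pcaModel.transform(df).select("pcaFeatures").show(truncate = false)
```
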
+ val transformer: Vector => Vector = v => pcaModel.transform(OldVectors.fromML(v)).asML + + val pcaOp = udf(transformer) dataset.withColumn($(outputCol), pcaOp(col($(inputCol)))) } + @Since("1.5.0") override def transformSchema(schema: StructType): StructType = { - val inputType = schema($(inputCol)).dataType - require(inputType.isInstanceOf[VectorUDT], - s"Input column ${$(inputCol)} must be a vector column") - require(!schema.fieldNames.contains($(outputCol)), - s"Output column ${$(outputCol)} already exists.") - val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false) - StructType(outputFields) + validateAndTransformSchema(schema) } + @Since("1.5.0") override def copy(extra: ParamMap): PCAModel = { val copied = new PCAModel(uid, pc, explainedVariance) copyValues(copied, extra).setParent(parent) @@ -163,7 +183,7 @@ object PCAModel extends MLReadable[PCAModel] { DefaultParamsWriter.saveMetadata(instance, path, sc) val data = Data(instance.pc, instance.explainedVariance) val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } } @@ -183,24 +203,22 @@ object PCAModel extends MLReadable[PCAModel] { override def load(path: String): PCAModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) - // explainedVariance field is not present in Spark <= 1.6 - val versionRegex = "([0-9]+)\\.([0-9]+).*".r - val hasExplainedVariance = metadata.sparkVersion match { - case versionRegex(major, minor) => - (major.toInt >= 2 || (major.toInt == 1 && minor.toInt > 6)) - case _ => false - } + val versionRegex = "([0-9]+)\\.(.+)".r + val versionRegex(major, _) = metadata.sparkVersion val dataPath = new Path(path, "data").toString - val model = if (hasExplainedVariance) { + val model = if (major.toInt >= 2) { val Row(pc: DenseMatrix, explainedVariance: DenseVector) = - sqlContext.read.parquet(dataPath) + sparkSession.read.parquet(dataPath) .select("pc", "explainedVariance") .head() new PCAModel(metadata.uid, pc, explainedVariance) } else { - val Row(pc: DenseMatrix) = sqlContext.read.parquet(dataPath).select("pc").head() - new PCAModel(metadata.uid, pc, Vectors.dense(Array.empty[Double]).asInstanceOf[DenseVector]) + // pc field is the old matrix format in Spark <= 1.6 + // explainedVariance field is not present in Spark <= 1.6 + val Row(pc: OldDenseMatrix) = sparkSession.read.parquet(dataPath).select("pc").head() + new PCAModel(metadata.uid, pc.asML, + Vectors.dense(Array.empty[Double]).asInstanceOf[DenseVector]) } DefaultParamsReader.getAndSetParams(model, metadata) model diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala index 0a9b9719c15d3..6e872c1f2cada 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala @@ -19,25 +19,27 @@ package org.apache.spark.ml.feature import scala.collection.mutable -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.commons.math3.util.CombinatoricsUtils + +import org.apache.spark.annotation.Since import org.apache.spark.ml.UnaryTransformer +import org.apache.spark.ml.linalg._ import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators} import org.apache.spark.ml.util._ -import org.apache.spark.mllib.linalg._ import 
org.apache.spark.sql.types.DataType /** - * :: Experimental :: * Perform feature expansion in a polynomial space. As said in wikipedia of Polynomial Expansion, * which is available at [[http://en.wikipedia.org/wiki/Polynomial_expansion]], "In mathematics, an * expansion of a product of sums expresses it as a sum of products by using the fact that * multiplication distributes over addition". Take a 2-variable feature vector as an example: * `(x, y)`, if we want to expand it with degree 2, then we get `(x, x * x, y, x * y, y * y)`. */ -@Experimental -class PolynomialExpansion(override val uid: String) +@Since("1.4.0") +class PolynomialExpansion @Since("1.4.0") (@Since("1.4.0") override val uid: String) extends UnaryTransformer[Vector, Vector, PolynomialExpansion] with DefaultParamsWritable { + @Since("1.4.0") def this() = this(Identifiable.randomUID("poly")) /** @@ -45,15 +47,18 @@ class PolynomialExpansion(override val uid: String) * Default: 2 * @group param */ + @Since("1.4.0") val degree = new IntParam(this, "degree", "the polynomial degree to expand (>= 1)", ParamValidators.gtEq(1)) setDefault(degree -> 2) /** @group getParam */ + @Since("1.4.0") def getDegree: Int = $(degree) /** @group setParam */ + @Since("1.4.0") def setDegree(value: Int): this.type = set(degree, value) override protected def createTransformFunc: Vector => Vector = { v => @@ -62,6 +67,7 @@ class PolynomialExpansion(override val uid: String) override protected def outputDataType: DataType = new VectorUDT() + @Since("1.4.1") override def copy(extra: ParamMap): PolynomialExpansion = defaultCopy(extra) } @@ -80,12 +86,12 @@ class PolynomialExpansion(override val uid: String) @Since("1.6.0") object PolynomialExpansion extends DefaultParamsReadable[PolynomialExpansion] { - private def choose(n: Int, k: Int): Int = { - Range(n, n - k, -1).product / Range(k, 1, -1).product + private def getPolySize(numFeatures: Int, degree: Int): Int = { + val n = CombinatoricsUtils.binomialCoefficient(numFeatures + degree, degree) + require(n <= Integer.MAX_VALUE) + n.toInt } - private def getPolySize(numFeatures: Int, degree: Int): Int = choose(numFeatures + degree, degree) - private def expandDense( values: Array[Double], lastIdx: Int, diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 5a6daa06ef546..e09800877c694 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -17,12 +17,12 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.internal.Logging import org.apache.spark.ml._ import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol, HasSeed} +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util._ import org.apache.spark.sql.Dataset import org.apache.spark.sql.types.{DoubleType, StructType} @@ -31,7 +31,7 @@ import org.apache.spark.sql.types.{DoubleType, StructType} * Params for [[QuantileDiscretizer]]. */ private[feature] trait QuantileDiscretizerBase extends Params - with HasInputCol with HasOutputCol with HasSeed { + with HasInputCol with HasOutputCol { /** * Number of buckets (quantiles, or categories) into which data points are grouped. 
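
For the `PolynomialExpansion` changes above, a hedged usage sketch (SparkSession `spark` and the sample vectors are assumptions for the example):

```scala
import org.apache.spark.ml.feature.PolynomialExpansion
import org.apache.spark.ml.linalg.Vectors

// Assumes a running SparkSession named `spark`; data is illustrative only.
val df = spark.createDataFrame(Seq(
  (0, Vectors.dense(-2.0, 2.3)),
  (1, Vectors.dense(0.0, 0.0)),
  (2, Vectors.dense(0.6, -1.1))
)).toDF("id", "features")

// Degree 2 expands a 2-variable vector (x, y) into (x, x*x, y, x*y, y*y),
// as described in the class scaladoc.
val polyExpansion = new PolynomialExpansion()
  .setInputCol("features")
  .setOutputCol("polyFeatures")
  .setDegree(2)

polyExpansion.transform(df).select("polyFeatures").show(truncate = false)
```
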
Must @@ -50,13 +50,13 @@ private[feature] trait QuantileDiscretizerBase extends Params /** * Relative error (see documentation for * [[org.apache.spark.sql.DataFrameStatFunctions.approxQuantile approxQuantile]] for description) - * Must be a number in [0, 1]. + * Must be in the range [0, 1]. * default: 0.001 * @group param */ val relativeError = new DoubleParam(this, "relativeError", "The relative target precision " + - "for approxQuantile", - ParamValidators.inRange(0.0, 1.0)) + "for the approximate quantile algorithm used to generate buckets. " + + "Must be in the range [0, 1].", ParamValidators.inRange(0.0, 1.0)) setDefault(relativeError -> 0.001) /** @group getParam */ @@ -64,35 +64,40 @@ private[feature] trait QuantileDiscretizerBase extends Params } /** - * :: Experimental :: * `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned - * categorical features. The bin ranges are chosen by taking a sample of the data and dividing it - * into roughly equal parts. The lower and upper bin bounds will be -Infinity and +Infinity, + * categorical features. The number of bins can be set using the `numBuckets` parameter. + * The bin ranges are chosen using an approximate algorithm (see the documentation for + * [[org.apache.spark.sql.DataFrameStatFunctions.approxQuantile approxQuantile]] + * for a detailed description). The precision of the approximation can be controlled with the + * `relativeError` parameter. The lower and upper bin bounds will be `-Infinity` and `+Infinity`, * covering all real values. */ -@Experimental -final class QuantileDiscretizer(override val uid: String) +@Since("1.6.0") +final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val uid: String) extends Estimator[Bucketizer] with QuantileDiscretizerBase with DefaultParamsWritable { + @Since("1.6.0") def this() = this(Identifiable.randomUID("quantileDiscretizer")) /** @group setParam */ + @Since("2.0.0") def setRelativeError(value: Double): this.type = set(relativeError, value) /** @group setParam */ + @Since("1.6.0") def setNumBuckets(value: Int): this.type = set(numBuckets, value) /** @group setParam */ + @Since("1.6.0") def setInputCol(value: String): this.type = set(inputCol, value) /** @group setParam */ + @Since("1.6.0") def setOutputCol(value: String): this.type = set(outputCol, value) - /** @group setParam */ - def setSeed(value: Long): this.type = set(seed, value) - + @Since("1.6.0") override def transformSchema(schema: StructType): StructType = { - SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType) + SchemaUtils.checkNumericType(schema, $(inputCol)) val inputFields = schema.fields require(inputFields.forall(_.name != $(outputCol)), s"Output column ${$(outputCol)} already exists.") @@ -103,15 +108,22 @@ final class QuantileDiscretizer(override val uid: String) @Since("2.0.0") override def fit(dataset: Dataset[_]): Bucketizer = { + transformSchema(dataset.schema, logging = true) val splits = dataset.stat.approxQuantile($(inputCol), (0.0 to 1.0 by 1.0/$(numBuckets)).toArray, $(relativeError)) splits(0) = Double.NegativeInfinity splits(splits.length - 1) = Double.PositiveInfinity - val bucketizer = new Bucketizer(uid).setSplits(splits) + val distinctSplits = splits.distinct + if (splits.length != distinctSplits.length) { + log.warn(s"Some quantiles were identical. 
Bucketing to ${distinctSplits.length - 1}" + + s" buckets as a result.") + } + val bucketizer = new Bucketizer(uid).setSplits(distinctSplits.sorted) copyValues(bucketizer.setParent(this)) } + @Since("1.6.0") override def copy(extra: ParamMap): QuantileDiscretizer = defaultCopy(extra) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 5219680be2dc8..2ee899bcca564 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -25,10 +25,10 @@ import org.apache.hadoop.fs.Path import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model, Pipeline, PipelineModel, PipelineStage, Transformer} import org.apache.spark.ml.attribute.AttributeGroup +import org.apache.spark.ml.linalg.VectorUDT import org.apache.spark.ml.param.{Param, ParamMap} import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol} import org.apache.spark.ml.util._ -import org.apache.spark.mllib.linalg.VectorUDT import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.types._ @@ -70,15 +70,18 @@ private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol { * will be created from the specified response variable in the formula. */ @Experimental -class RFormula(override val uid: String) +@Since("1.5.0") +class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) extends Estimator[RFormulaModel] with RFormulaBase with DefaultParamsWritable { + @Since("1.5.0") def this() = this(Identifiable.randomUID("rFormula")) /** * R formula parameter. The formula is provided in string form. * @group param */ + @Since("1.5.0") val formula: Param[String] = new Param(this, "formula", "R model formula") /** @@ -86,15 +89,19 @@ class RFormula(override val uid: String) * @group setParam * @param value an R formula in string form (e.g. "y ~ x + z") */ + @Since("1.5.0") def setFormula(value: String): this.type = set(formula, value) /** @group getParam */ + @Since("1.5.0") def getFormula: String = $(formula) /** @group setParam */ + @Since("1.5.0") def setFeaturesCol(value: String): this.type = set(featuresCol, value) /** @group setParam */ + @Since("1.5.0") def setLabelCol(value: String): this.type = set(labelCol, value) /** Whether the formula specifies fitting an intercept. 
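
To illustrate the `QuantileDiscretizer` behavior changed above (approximate quantile splits, duplicate-split handling), a minimal sketch under the usual assumptions (SparkSession `spark`, invented data):

```scala
import org.apache.spark.ml.feature.QuantileDiscretizer

// Assumes a running SparkSession named `spark`; data is illustrative only.
val df = spark.createDataFrame(
  Seq((0, 18.0), (1, 19.0), (2, 8.0), (3, 5.0), (4, 2.2))
).toDF("id", "hour")

val discretizer = new QuantileDiscretizer()
  .setInputCol("hour")
  .setOutputCol("result")
  .setNumBuckets(3)
  .setRelativeError(0.001) // precision of the approximate quantile algorithm

// fit() computes the bin splits and returns a Bucketizer that applies them.
val bucketizer = discretizer.fit(df)
bucketizer.transform(df).show()
```
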
*/ @@ -105,6 +112,7 @@ class RFormula(override val uid: String) @Since("2.0.0") override def fit(dataset: Dataset[_]): RFormulaModel = { + transformSchema(dataset.schema, logging = true) require(isDefined(formula), "Formula must be defined first.") val parsedFormula = RFormulaParser.parse($(formula)) val resolvedFormula = parsedFormula.resolve(dataset.schema) @@ -170,6 +178,7 @@ class RFormula(override val uid: String) copyValues(new RFormulaModel(uid, resolvedFormula, pipelineModel).setParent(this)) } + @Since("1.5.0") // optimistic schema; does not contain any ML attributes override def transformSchema(schema: StructType): StructType = { if (hasLabelCol(schema)) { @@ -180,9 +189,11 @@ class RFormula(override val uid: String) } } + @Since("1.5.0") override def copy(extra: ParamMap): RFormula = defaultCopy(extra) - override def toString: String = s"RFormula(${get(formula)}) (uid=$uid)" + @Since("2.0.0") + override def toString: String = s"RFormula(${get(formula).getOrElse("")}) (uid=$uid)" } @Since("2.0.0") @@ -194,13 +205,16 @@ object RFormula extends DefaultParamsReadable[RFormula] { /** * :: Experimental :: - * A fitted RFormula. Fitting is required to determine the factor levels of formula terms. + * Model fitted by [[RFormula]]. Fitting is required to determine the factor levels of + * formula terms. + * * @param resolvedFormula the fitted R formula. * @param pipelineModel the fitted feature model, including factor to index mappings. */ @Experimental +@Since("1.5.0") class RFormulaModel private[feature]( - override val uid: String, + @Since("1.5.0") override val uid: String, private[ml] val resolvedFormula: ResolvedRFormula, private[ml] val pipelineModel: PipelineModel) extends Model[RFormulaModel] with RFormulaBase with MLWritable { @@ -211,6 +225,7 @@ class RFormulaModel private[feature]( transformLabel(pipelineModel.transform(dataset)) } + @Since("1.5.0") override def transformSchema(schema: StructType): StructType = { checkCanTransform(schema) val withFeatures = pipelineModel.transformSchema(schema) @@ -229,9 +244,11 @@ class RFormulaModel private[feature]( } } + @Since("1.5.0") override def copy(extra: ParamMap): RFormulaModel = copyValues( new RFormulaModel(uid, resolvedFormula, pipelineModel)) + @Since("2.0.0") override def toString: String = s"RFormulaModel($resolvedFormula) (uid=$uid)" private def transformLabel(dataset: Dataset[_]): DataFrame = { @@ -256,8 +273,8 @@ class RFormulaModel private[feature]( val columnNames = schema.map(_.name) require(!columnNames.contains($(featuresCol)), "Features column already exists.") require( - !columnNames.contains($(labelCol)) || schema($(labelCol)).dataType == DoubleType, - "Label column already exists and is not of type DoubleType.") + !columnNames.contains($(labelCol)) || schema($(labelCol)).dataType.isInstanceOf[NumericType], + "Label column already exists and is not of type NumericType.") } @Since("2.0.0") @@ -281,7 +298,7 @@ object RFormulaModel extends MLReadable[RFormulaModel] { DefaultParamsWriter.saveMetadata(instance, path, sc) // Save model data: resolvedFormula val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(Seq(instance.resolvedFormula)) + sparkSession.createDataFrame(Seq(instance.resolvedFormula)) .repartition(1).write.parquet(dataPath) // Save pipeline model val pmPath = new Path(path, "pipelineModel").toString @@ -298,7 +315,7 @@ object RFormulaModel extends MLReadable[RFormulaModel] { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, 
"data").toString - val data = sqlContext.read.parquet(dataPath).select("label", "terms", "hasIntercept").head() + val data = sparkSession.read.parquet(dataPath).select("label", "terms", "hasIntercept").head() val label = data.getString(0) val terms = data.getAs[Seq[Seq[String]]](1) val hasIntercept = data.getBoolean(2) @@ -356,7 +373,7 @@ private object ColumnPruner extends MLReadable[ColumnPruner] { // Save model data: columnsToPrune val data = Data(instance.columnsToPrune.toSeq) val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } } @@ -369,7 +386,7 @@ private object ColumnPruner extends MLReadable[ColumnPruner] { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val data = sqlContext.read.parquet(dataPath).select("columnsToPrune").head() + val data = sparkSession.read.parquet(dataPath).select("columnsToPrune").head() val columnsToPrune = data.getAs[Seq[String]](0).toSet val pruner = new ColumnPruner(metadata.uid, columnsToPrune) @@ -447,7 +464,7 @@ private object VectorAttributeRewriter extends MLReadable[VectorAttributeRewrite // Save model data: vectorCol, prefixesToRewrite val data = Data(instance.vectorCol, instance.prefixesToRewrite) val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } } @@ -460,7 +477,7 @@ private object VectorAttributeRewriter extends MLReadable[VectorAttributeRewrite val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val data = sqlContext.read.parquet(dataPath).select("vectorCol", "prefixesToRewrite").head() + val data = sparkSession.read.parquet(dataPath).select("vectorCol", "prefixesToRewrite").head() val vectorCol = data.getString(0) val prefixesToRewrite = data.getAs[Map[String, String]](1) val rewriter = new VectorAttributeRewriter(metadata.uid, vectorCol, prefixesToRewrite) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala index cf52710ab8cbe..2dd565a782719 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.feature import scala.collection.mutable import scala.util.parsing.combinator.RegexParsers -import org.apache.spark.mllib.linalg.VectorUDT +import org.apache.spark.ml.linalg.VectorUDT import org.apache.spark.sql.types._ /** @@ -126,7 +126,19 @@ private[ml] case class ParsedRFormula(label: ColumnRef, terms: Seq[Term]) { * @param hasIntercept whether the formula specifies fitting with an intercept. */ private[ml] case class ResolvedRFormula( - label: String, terms: Seq[Seq[String]], hasIntercept: Boolean) + label: String, terms: Seq[Seq[String]], hasIntercept: Boolean) { + + override def toString: String = { + val ts = terms.map { + case t if t.length > 1 => + s"${t.mkString("{", ",", "}")}" + case t => + t.mkString + } + val termStr = ts.mkString("[", ",", "]") + s"ResolvedRFormula(label=$label, terms=$termStr, hasIntercept=$hasIntercept)" + } +} /** * R formula terms. 
See the R formula docs here for more information: diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala index 400435d7a9370..b25fff973c441 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala @@ -17,16 +17,14 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkContext -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.ml.param.{Param, ParamMap} import org.apache.spark.ml.Transformer import org.apache.spark.ml.util._ -import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext} +import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} import org.apache.spark.sql.types.StructType /** - * :: Experimental :: * Implements the transformations which are defined by SQL statement. * Currently we only support SQL syntax like 'SELECT ... FROM __THIS__ ...' * where '__THIS__' represents the underlying table of the input dataset. @@ -38,9 +36,8 @@ import org.apache.spark.sql.types.StructType * - SELECT a, SQRT(b) AS b_sqrt FROM __THIS__ where a > 5 * - SELECT a, b, SUM(c) AS c_sum FROM __THIS__ GROUP BY a, b */ -@Experimental @Since("1.6.0") -class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transformer +class SQLTransformer @Since("1.6.0") (@Since("1.6.0") override val uid: String) extends Transformer with DefaultParamsWritable { @Since("1.6.0") @@ -48,6 +45,7 @@ class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transfor /** * SQL statement parameter. The statement is provided in string form. + * * @group param */ @Since("1.6.0") @@ -65,23 +63,25 @@ class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transfor @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { + transformSchema(dataset.schema, logging = true) val tableName = Identifiable.randomUID(uid) - dataset.registerTempTable(tableName) + dataset.createOrReplaceTempView(tableName) val realStatement = $(statement).replace(tableIdentifier, tableName) - dataset.sparkSession.sql(realStatement) + val result = dataset.sparkSession.sql(realStatement) + dataset.sparkSession.catalog.dropTempView(tableName) + result } @Since("1.6.0") override def transformSchema(schema: StructType): StructType = { - val sc = SparkContext.getOrCreate() - val sqlContext = SQLContext.getOrCreate(sc) - val dummyRDD = sc.parallelize(Seq(Row.empty)) - val dummyDF = sqlContext.createDataFrame(dummyRDD, schema) + val spark = SparkSession.builder().getOrCreate() + val dummyRDD = spark.sparkContext.parallelize(Seq(Row.empty)) + val dummyDF = spark.createDataFrame(dummyRDD, schema) val tableName = Identifiable.randomUID(uid) val realStatement = $(statement).replace(tableIdentifier, tableName) - dummyDF.registerTempTable(tableName) - val outputSchema = sqlContext.sql(realStatement).schema - sqlContext.dropTempTable(tableName) + dummyDF.createOrReplaceTempView(tableName) + val outputSchema = spark.sql(realStatement).schema + spark.catalog.dropTempView(tableName) outputSchema } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 626e97efb47c6..2494cf51a2bd6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ 
b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -19,13 +19,17 @@ package org.apache.spark.ml.feature import org.apache.hadoop.fs.Path -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.ml._ +import org.apache.spark.ml.linalg.{Vector, VectorUDT} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.feature -import org.apache.spark.mllib.linalg.{Vector, VectorUDT} +import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} +import org.apache.spark.mllib.linalg.VectorImplicits._ +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{StructField, StructType} @@ -59,11 +63,19 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with /** @group getParam */ def getWithStd: Boolean = $(withStd) + /** Validates and transforms the input schema. */ + protected def validateAndTransformSchema(schema: StructType): StructType = { + SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT) + require(!schema.fieldNames.contains($(outputCol)), + s"Output column ${$(outputCol)} already exists.") + val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false) + StructType(outputFields) + } + setDefault(withMean -> false, withStd -> true) } /** - * :: Experimental :: * Standardizes features by removing the mean and scaling to unit variance using column summary * statistics on the samples in the training set. * @@ -72,43 +84,47 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with * corrected sample standard deviation]], * which is computed as the square root of the unbiased sample variance. 
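
Following the `StandardScaler` scaladoc above, a minimal usage sketch (not part of the patch; SparkSession `spark` and the data are assumptions):

```scala
import org.apache.spark.ml.feature.StandardScaler
import org.apache.spark.ml.linalg.Vectors

// Assumes a running SparkSession named `spark`; data is illustrative only.
val df = spark.createDataFrame(Seq(
  (0, Vectors.dense(1.0, 0.1, -8.0)),
  (1, Vectors.dense(2.0, 1.0, -4.0)),
  (2, Vectors.dense(4.0, 10.0, 8.0))
)).toDF("id", "features")

val scaler = new StandardScaler()
  .setInputCol("features")
  .setOutputCol("scaledFeatures")
  .setWithStd(true)   // divide by the corrected sample standard deviation
  .setWithMean(false) // do not center, which keeps sparse inputs sparse

scaler.fit(df).transform(df).show(truncate = false)
```
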
*/ -@Experimental -class StandardScaler(override val uid: String) extends Estimator[StandardScalerModel] - with StandardScalerParams with DefaultParamsWritable { +@Since("1.2.0") +class StandardScaler @Since("1.4.0") ( + @Since("1.4.0") override val uid: String) + extends Estimator[StandardScalerModel] with StandardScalerParams with DefaultParamsWritable { + @Since("1.2.0") def this() = this(Identifiable.randomUID("stdScal")) /** @group setParam */ + @Since("1.2.0") def setInputCol(value: String): this.type = set(inputCol, value) /** @group setParam */ + @Since("1.2.0") def setOutputCol(value: String): this.type = set(outputCol, value) /** @group setParam */ + @Since("1.4.0") def setWithMean(value: Boolean): this.type = set(withMean, value) /** @group setParam */ + @Since("1.4.0") def setWithStd(value: Boolean): this.type = set(withStd, value) @Since("2.0.0") override def fit(dataset: Dataset[_]): StandardScalerModel = { transformSchema(dataset.schema, logging = true) - val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v } + val input: RDD[OldVector] = dataset.select($(inputCol)).rdd.map { + case Row(v: Vector) => OldVectors.fromML(v) + } val scaler = new feature.StandardScaler(withMean = $(withMean), withStd = $(withStd)) val scalerModel = scaler.fit(input) copyValues(new StandardScalerModel(uid, scalerModel.std, scalerModel.mean).setParent(this)) } + @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { - val inputType = schema($(inputCol)).dataType - require(inputType.isInstanceOf[VectorUDT], - s"Input column ${$(inputCol)} must be a vector column") - require(!schema.fieldNames.contains($(outputCol)), - s"Output column ${$(outputCol)} already exists.") - val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false) - StructType(outputFields) + validateAndTransformSchema(schema) } + @Since("1.4.1") override def copy(extra: ParamMap): StandardScaler = defaultCopy(extra) } @@ -120,45 +136,46 @@ object StandardScaler extends DefaultParamsReadable[StandardScaler] { } /** - * :: Experimental :: * Model fitted by [[StandardScaler]]. * * @param std Standard deviation of the StandardScalerModel * @param mean Mean of the StandardScalerModel */ -@Experimental +@Since("1.2.0") class StandardScalerModel private[ml] ( - override val uid: String, - val std: Vector, - val mean: Vector) + @Since("1.4.0") override val uid: String, + @Since("2.0.0") val std: Vector, + @Since("2.0.0") val mean: Vector) extends Model[StandardScalerModel] with StandardScalerParams with MLWritable { import StandardScalerModel._ /** @group setParam */ + @Since("1.2.0") def setInputCol(value: String): this.type = set(inputCol, value) /** @group setParam */ + @Since("1.2.0") def setOutputCol(value: String): this.type = set(outputCol, value) @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) val scaler = new feature.StandardScalerModel(std, mean, $(withStd), $(withMean)) - val scale = udf { scaler.transform _ } + + // TODO: Make the transformer natively in ml framework to avoid extra conversion. 
+ val transformer: Vector => Vector = v => scaler.transform(OldVectors.fromML(v)).asML + + val scale = udf(transformer) dataset.withColumn($(outputCol), scale(col($(inputCol)))) } + @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { - val inputType = schema($(inputCol)).dataType - require(inputType.isInstanceOf[VectorUDT], - s"Input column ${$(inputCol)} must be a vector column") - require(!schema.fieldNames.contains($(outputCol)), - s"Output column ${$(outputCol)} already exists.") - val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false) - StructType(outputFields) + validateAndTransformSchema(schema) } + @Since("1.4.1") override def copy(extra: ParamMap): StandardScalerModel = { val copied = new StandardScalerModel(uid, std, mean) copyValues(copied, extra).setParent(parent) @@ -180,7 +197,7 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] { DefaultParamsWriter.saveMetadata(instance, path, sc) val data = Data(instance.std, instance.mean) val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } } @@ -191,7 +208,8 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] { override def load(path: String): StandardScalerModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val Row(std: Vector, mean: Vector) = sqlContext.read.parquet(dataPath) + val data = sparkSession.read.parquet(dataPath) + val Row(std: Vector, mean: Vector) = MLUtils.convertVectorColumnsToML(data, "std", "mean") .select("std", "mean") .head() val model = new StandardScalerModel(metadata.uid, std, mean) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala old mode 100644 new mode 100755 index b96bc48566fa7..666070037cdd8 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.ml.Transformer import org.apache.spark.ml.param.{BooleanParam, ParamMap, StringArrayParam} import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} @@ -27,124 +27,83 @@ import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{ArrayType, StringType, StructType} /** - * stop words list - */ -private[spark] object StopWords { - - /** - * Use the same default stopwords list as scikit-learn. 
- * The original list can be found from "Glasgow Information Retrieval Group" - * [[http://ir.dcs.gla.ac.uk/resources/linguistic_utils/stop_words]] - */ - val English = Array( "a", "about", "above", "across", "after", "afterwards", "again", - "against", "all", "almost", "alone", "along", "already", "also", "although", "always", - "am", "among", "amongst", "amoungst", "amount", "an", "and", "another", - "any", "anyhow", "anyone", "anything", "anyway", "anywhere", "are", - "around", "as", "at", "back", "be", "became", "because", "become", - "becomes", "becoming", "been", "before", "beforehand", "behind", "being", - "below", "beside", "besides", "between", "beyond", "bill", "both", - "bottom", "but", "by", "call", "can", "cannot", "cant", "co", "con", - "could", "couldnt", "cry", "de", "describe", "detail", "do", "done", - "down", "due", "during", "each", "eg", "eight", "either", "eleven", "else", - "elsewhere", "empty", "enough", "etc", "even", "ever", "every", "everyone", - "everything", "everywhere", "except", "few", "fifteen", "fify", "fill", - "find", "fire", "first", "five", "for", "former", "formerly", "forty", - "found", "four", "from", "front", "full", "further", "get", "give", "go", - "had", "has", "hasnt", "have", "he", "hence", "her", "here", "hereafter", - "hereby", "herein", "hereupon", "hers", "herself", "him", "himself", "his", - "how", "however", "hundred", "i", "ie", "if", "in", "inc", "indeed", - "interest", "into", "is", "it", "its", "itself", "keep", "last", "latter", - "latterly", "least", "less", "ltd", "made", "many", "may", "me", - "meanwhile", "might", "mill", "mine", "more", "moreover", "most", "mostly", - "move", "much", "must", "my", "myself", "name", "namely", "neither", - "never", "nevertheless", "next", "nine", "no", "nobody", "none", "noone", - "nor", "not", "nothing", "now", "nowhere", "of", "off", "often", "on", - "once", "one", "only", "onto", "or", "other", "others", "otherwise", "our", - "ours", "ourselves", "out", "over", "own", "part", "per", "perhaps", - "please", "put", "rather", "re", "same", "see", "seem", "seemed", - "seeming", "seems", "serious", "several", "she", "should", "show", "side", - "since", "sincere", "six", "sixty", "so", "some", "somehow", "someone", - "something", "sometime", "sometimes", "somewhere", "still", "such", - "system", "take", "ten", "than", "that", "the", "their", "them", - "themselves", "then", "thence", "there", "thereafter", "thereby", - "therefore", "therein", "thereupon", "these", "they", "thick", "thin", - "third", "this", "those", "though", "three", "through", "throughout", - "thru", "thus", "to", "together", "too", "top", "toward", "towards", - "twelve", "twenty", "two", "un", "under", "until", "up", "upon", "us", - "very", "via", "was", "we", "well", "were", "what", "whatever", "when", - "whence", "whenever", "where", "whereafter", "whereas", "whereby", - "wherein", "whereupon", "wherever", "whether", "which", "while", "whither", - "who", "whoever", "whole", "whom", "whose", "why", "will", "with", - "within", "without", "would", "yet", "you", "your", "yours", "yourself", "yourselves") -} - -/** - * :: Experimental :: * A feature transformer that filters out stop words from input. * Note: null values from input array are preserved unless adding null to stopWords explicitly. 
* @see [[http://en.wikipedia.org/wiki/Stop_words]] */ -@Experimental -class StopWordsRemover(override val uid: String) +@Since("1.5.0") +class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String) extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable { + @Since("1.5.0") def this() = this(Identifiable.randomUID("stopWords")) /** @group setParam */ + @Since("1.5.0") def setInputCol(value: String): this.type = set(inputCol, value) /** @group setParam */ + @Since("1.5.0") def setOutputCol(value: String): this.type = set(outputCol, value) /** - * the stop words set to be filtered out - * Default: [[StopWords.English]] + * The words to be filtered out. + * Default: English stop words + * @see [[StopWordsRemover.loadDefaultStopWords()]] * @group param */ - val stopWords: StringArrayParam = new StringArrayParam(this, "stopWords", "stop words") + @Since("1.5.0") + val stopWords: StringArrayParam = + new StringArrayParam(this, "stopWords", "the words to be filtered out") /** @group setParam */ + @Since("1.5.0") def setStopWords(value: Array[String]): this.type = set(stopWords, value) /** @group getParam */ + @Since("1.5.0") def getStopWords: Array[String] = $(stopWords) /** - * whether to do a case sensitive comparison over the stop words + * Whether to do a case sensitive comparison over the stop words. * Default: false * @group param */ + @Since("1.5.0") val caseSensitive: BooleanParam = new BooleanParam(this, "caseSensitive", - "whether to do case-sensitive comparison during filtering") + "whether to do a case-sensitive comparison over the stop words") /** @group setParam */ + @Since("1.5.0") def setCaseSensitive(value: Boolean): this.type = set(caseSensitive, value) /** @group getParam */ + @Since("1.5.0") def getCaseSensitive: Boolean = $(caseSensitive) - setDefault(stopWords -> StopWords.English, caseSensitive -> false) + setDefault(stopWords -> StopWordsRemover.loadDefaultStopWords("english"), caseSensitive -> false) @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { val outputSchema = transformSchema(dataset.schema) val t = if ($(caseSensitive)) { - val stopWordsSet = $(stopWords).toSet - udf { terms: Seq[String] => - terms.filter(s => !stopWordsSet.contains(s)) - } - } else { - val toLower = (s: String) => if (s != null) s.toLowerCase else s - val lowerStopWords = $(stopWords).map(toLower(_)).toSet - udf { terms: Seq[String] => - terms.filter(s => !lowerStopWords.contains(toLower(s))) - } + val stopWordsSet = $(stopWords).toSet + udf { terms: Seq[String] => + terms.filter(s => !stopWordsSet.contains(s)) + } + } else { + // TODO: support user locale (SPARK-15064) + val toLower = (s: String) => if (s != null) s.toLowerCase else s + val lowerStopWords = $(stopWords).map(toLower(_)).toSet + udf { terms: Seq[String] => + terms.filter(s => !lowerStopWords.contains(toLower(s))) + } } - val metadata = outputSchema($(outputCol)).metadata dataset.select(col("*"), t(col($(inputCol))).as($(outputCol), metadata)) } + @Since("1.5.0") override def transformSchema(schema: StructType): StructType = { val inputType = schema($(inputCol)).dataType require(inputType.sameType(ArrayType(StringType)), @@ -152,12 +111,31 @@ class StopWordsRemover(override val uid: String) SchemaUtils.appendColumn(schema, $(outputCol), inputType, schema($(inputCol)).nullable) } + @Since("1.5.0") override def copy(extra: ParamMap): StopWordsRemover = defaultCopy(extra) } @Since("1.6.0") object StopWordsRemover extends DefaultParamsReadable[StopWordsRemover] { + 
private[feature] + val supportedLanguages = Set("danish", "dutch", "english", "finnish", "french", "german", + "hungarian", "italian", "norwegian", "portuguese", "russian", "spanish", "swedish", "turkish") + @Since("1.6.0") override def load(path: String): StopWordsRemover = super.load(path) + + /** + * Loads the default stop words for the given language. + * Supported languages: danish, dutch, english, finnish, french, german, hungarian, + * italian, norwegian, portuguese, russian, spanish, swedish, turkish + * @see [[http://anoncvs.postgresql.org/cvsweb.cgi/pgsql/src/backend/snowball/stopwords/]] + */ + @Since("2.0.0") + def loadDefaultStopWords(language: String): Array[String] = { + require(supportedLanguages.contains(language), + s"$language is not in the supported language list: ${supportedLanguages.mkString(", ")}.") + val is = getClass.getResourceAsStream(s"/org/apache/spark/ml/feature/stopwords/$language.txt") + scala.io.Source.fromInputStream(is)(scala.io.Codec.UTF8).getLines().toArray + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index cc0571fd7e39b..80fe46796f807 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.feature import org.apache.hadoop.fs.Path import org.apache.spark.SparkException -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.ml.{Estimator, Model, Transformer} import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param._ @@ -55,7 +55,6 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha } /** - * :: Experimental :: * A label indexer that maps a string column of labels to an ML column of label indices. * If the input column is numeric, we cast it to string and index the string values. * The indices are in [0, numLabels), ordered by label frequencies. 
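
The new `StopWordsRemover.loadDefaultStopWords` API introduced above can be exercised as in this hedged sketch (SparkSession `spark` and sample sentences are assumptions; passing the English list explicitly is redundant with the default and only shown to illustrate the call):

```scala
import org.apache.spark.ml.feature.StopWordsRemover

// Assumes a running SparkSession named `spark`; data is illustrative only.
val df = spark.createDataFrame(Seq(
  (0, Seq("I", "saw", "the", "red", "balloon")),
  (1, Seq("Mary", "had", "a", "little", "lamb"))
)).toDF("id", "raw")

// Other supported languages (e.g. "french", "german") can be loaded the same way
// and supplied via setStopWords.
val remover = new StopWordsRemover()
  .setInputCol("raw")
  .setOutputCol("filtered")
  .setStopWords(StopWordsRemover.loadDefaultStopWords("english"))
  .setCaseSensitive(false)

remover.transform(df).show(truncate = false)
```
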
@@ -63,25 +62,30 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha * * @see [[IndexToString]] for the inverse transformation */ -@Experimental -class StringIndexer(override val uid: String) extends Estimator[StringIndexerModel] +@Since("1.4.0") +class StringIndexer @Since("1.4.0") ( + @Since("1.4.0") override val uid: String) extends Estimator[StringIndexerModel] with StringIndexerBase with DefaultParamsWritable { + @Since("1.4.0") def this() = this(Identifiable.randomUID("strIdx")) /** @group setParam */ + @Since("1.6.0") def setHandleInvalid(value: String): this.type = set(handleInvalid, value) setDefault(handleInvalid, "error") /** @group setParam */ + @Since("1.4.0") def setInputCol(value: String): this.type = set(inputCol, value) /** @group setParam */ + @Since("1.4.0") def setOutputCol(value: String): this.type = set(outputCol, value) - @Since("2.0.0") override def fit(dataset: Dataset[_]): StringIndexerModel = { + transformSchema(dataset.schema, logging = true) val counts = dataset.select(col($(inputCol)).cast(StringType)) .rdd .map(_.getString(0)) @@ -90,10 +94,12 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod copyValues(new StringIndexerModel(uid, labels).setParent(this)) } + @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema) } + @Since("1.4.1") override def copy(extra: ParamMap): StringIndexer = defaultCopy(extra) } @@ -105,7 +111,6 @@ object StringIndexer extends DefaultParamsReadable[StringIndexer] { } /** - * :: Experimental :: * Model fitted by [[StringIndexer]]. * * NOTE: During transformation, if the input column does not exist, @@ -114,14 +119,15 @@ object StringIndexer extends DefaultParamsReadable[StringIndexer] { * * @param labels Ordered list of labels, corresponding to indices to be assigned. 
*/ -@Experimental +@Since("1.4.0") class StringIndexerModel ( - override val uid: String, - val labels: Array[String]) + @Since("1.4.0") override val uid: String, + @Since("1.5.0") val labels: Array[String]) extends Model[StringIndexerModel] with StringIndexerBase with MLWritable { import StringIndexerModel._ + @Since("1.5.0") def this(labels: Array[String]) = this(Identifiable.randomUID("strIdx"), labels) private val labelToIndex: OpenHashMap[String, Double] = { @@ -136,13 +142,16 @@ class StringIndexerModel ( } /** @group setParam */ + @Since("1.6.0") def setHandleInvalid(value: String): this.type = set(handleInvalid, value) setDefault(handleInvalid, "error") /** @group setParam */ + @Since("1.4.0") def setInputCol(value: String): this.type = set(inputCol, value) /** @group setParam */ + @Since("1.4.0") def setOutputCol(value: String): this.type = set(outputCol, value) @Since("2.0.0") @@ -152,7 +161,7 @@ class StringIndexerModel ( "Skip StringIndexerModel.") return dataset.toDF } - validateAndTransformSchema(dataset.schema) + transformSchema(dataset.schema, logging = true) val indexer = udf { label: String => if (labelToIndex.contains(label)) { @@ -177,6 +186,7 @@ class StringIndexerModel ( indexer(dataset($(inputCol)).cast(StringType)).as($(outputCol), metadata)) } + @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { if (schema.fieldNames.contains($(inputCol))) { validateAndTransformSchema(schema) @@ -186,6 +196,7 @@ class StringIndexerModel ( } } + @Since("1.4.1") override def copy(extra: ParamMap): StringIndexerModel = { val copied = new StringIndexerModel(uid, labels) copyValues(copied, extra).setParent(parent) @@ -207,7 +218,7 @@ object StringIndexerModel extends MLReadable[StringIndexerModel] { DefaultParamsWriter.saveMetadata(instance, path, sc) val data = Data(instance.labels) val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } } @@ -218,7 +229,7 @@ object StringIndexerModel extends MLReadable[StringIndexerModel] { override def load(path: String): StringIndexerModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val data = sqlContext.read.parquet(dataPath) + val data = sparkSession.read.parquet(dataPath) .select("labels") .head() val labels = data.getAs[Seq[String]](0).toArray @@ -236,7 +247,6 @@ object StringIndexerModel extends MLReadable[StringIndexerModel] { } /** - * :: Experimental :: * A [[Transformer]] that maps a column of indices back to a new column of corresponding * string values. 
* The index-string mapping is either from the ML attributes of the input column, @@ -244,20 +254,24 @@ object StringIndexerModel extends MLReadable[StringIndexerModel] { * * @see [[StringIndexer]] for converting strings into indices */ -@Experimental -class IndexToString private[ml] (override val uid: String) +@Since("1.5.0") +class IndexToString private[ml] (@Since("1.5.0") override val uid: String) extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable { + @Since("1.5.0") def this() = this(Identifiable.randomUID("idxToStr")) /** @group setParam */ + @Since("1.5.0") def setInputCol(value: String): this.type = set(inputCol, value) /** @group setParam */ + @Since("1.5.0") def setOutputCol(value: String): this.type = set(outputCol, value) /** @group setParam */ + @Since("1.5.0") def setLabels(value: Array[String]): this.type = set(labels, value) /** @@ -266,13 +280,16 @@ class IndexToString private[ml] (override val uid: String) * Default: Not specified, in which case [[inputCol]] metadata is used for labels. * @group param */ + @Since("1.5.0") final val labels: StringArrayParam = new StringArrayParam(this, "labels", "Optional array of labels specifying index-string mapping." + " If not provided or if empty, then metadata from inputCol is used instead.") /** @group getParam */ + @Since("1.5.0") final def getLabels: Array[String] = $(labels) + @Since("1.5.0") override def transformSchema(schema: StructType): StructType = { val inputColName = $(inputCol) val inputDataType = schema(inputColName).dataType @@ -289,6 +306,7 @@ class IndexToString private[ml] (override val uid: String) @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { + transformSchema(dataset.schema, logging = true) val inputColSchema = dataset.schema($(inputCol)) // If the labels array is empty use column metadata val values = if (!isDefined(labels) || $(labels).isEmpty) { @@ -310,6 +328,7 @@ class IndexToString private[ml] (override val uid: String) indexer(dataset($(inputCol)).cast(DoubleType)).as(outputColName)) } + @Since("1.5.0") override def copy(extra: ParamMap): IndexToString = { defaultCopy(extra) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala index 8456a0e915804..45d8fa94a8f8f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -17,22 +17,22 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param._ import org.apache.spark.ml.util._ import org.apache.spark.sql.types.{ArrayType, DataType, StringType} /** - * :: Experimental :: * A tokenizer that converts the input string to lowercase and then splits it by white spaces. 
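A small round-trip sketch for the StringIndexer / IndexToString pair covered above; `spark` is assumed to be an existing SparkSession, and the ids, labels, and column names are invented:

```scala
import org.apache.spark.ml.feature.{IndexToString, StringIndexer}

val df = spark.createDataFrame(
  Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c"))
).toDF("id", "category")

// The most frequent label gets index 0.0; with handleInvalid = "error" (the default),
// transforming a dataset that contains unseen labels fails.
val indexed = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("categoryIndex")
  .setHandleInvalid("error")
  .fit(df)
  .transform(df)

// IndexToString reads the label metadata written by StringIndexerModel,
// so setLabels is not needed here.
val converter = new IndexToString()
  .setInputCol("categoryIndex")
  .setOutputCol("originalCategory")

converter.transform(indexed).select("id", "categoryIndex", "originalCategory").show()
```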
* * @see [[RegexTokenizer]] */ -@Experimental -class Tokenizer(override val uid: String) +@Since("1.2.0") +class Tokenizer @Since("1.4.0") (@Since("1.4.0") override val uid: String) extends UnaryTransformer[String, Seq[String], Tokenizer] with DefaultParamsWritable { + @Since("1.2.0") def this() = this(Identifiable.randomUID("tok")) override protected def createTransformFunc: String => Seq[String] = { @@ -45,6 +45,7 @@ class Tokenizer(override val uid: String) override protected def outputDataType: DataType = new ArrayType(StringType, true) + @Since("1.4.1") override def copy(extra: ParamMap): Tokenizer = defaultCopy(extra) } @@ -56,16 +57,16 @@ object Tokenizer extends DefaultParamsReadable[Tokenizer] { } /** - * :: Experimental :: * A regex based tokenizer that extracts tokens either by using the provided regex pattern to split * the text (default) or repeatedly matching the regex (if `gaps` is false). * Optional parameters also allow filtering tokens using a minimal length. * It returns an array of strings that can be empty. */ -@Experimental -class RegexTokenizer(override val uid: String) +@Since("1.4.0") +class RegexTokenizer @Since("1.4.0") (@Since("1.4.0") override val uid: String) extends UnaryTransformer[String, Seq[String], RegexTokenizer] with DefaultParamsWritable { + @Since("1.4.0") def this() = this(Identifiable.randomUID("regexTok")) /** @@ -73,13 +74,16 @@ class RegexTokenizer(override val uid: String) * Default: 1, to avoid returning empty strings * @group param */ + @Since("1.4.0") val minTokenLength: IntParam = new IntParam(this, "minTokenLength", "minimum token length (>= 0)", ParamValidators.gtEq(0)) /** @group setParam */ + @Since("1.4.0") def setMinTokenLength(value: Int): this.type = set(minTokenLength, value) /** @group getParam */ + @Since("1.4.0") def getMinTokenLength: Int = $(minTokenLength) /** @@ -87,12 +91,15 @@ class RegexTokenizer(override val uid: String) * Default: true * @group param */ + @Since("1.4.0") val gaps: BooleanParam = new BooleanParam(this, "gaps", "Set regex to match gaps or tokens") /** @group setParam */ + @Since("1.4.0") def setGaps(value: Boolean): this.type = set(gaps, value) /** @group getParam */ + @Since("1.4.0") def getGaps: Boolean = $(gaps) /** @@ -100,12 +107,15 @@ class RegexTokenizer(override val uid: String) * Default: `"\\s+"` * @group param */ + @Since("1.4.0") val pattern: Param[String] = new Param(this, "pattern", "regex pattern used for tokenizing") /** @group setParam */ + @Since("1.4.0") def setPattern(value: String): this.type = set(pattern, value) /** @group getParam */ + @Since("1.4.0") def getPattern: String = $(pattern) /** @@ -113,13 +123,16 @@ class RegexTokenizer(override val uid: String) * Default: true * @group param */ + @Since("1.6.0") final val toLowercase: BooleanParam = new BooleanParam(this, "toLowercase", "whether to convert all characters to lowercase before tokenizing.") /** @group setParam */ + @Since("1.6.0") def setToLowercase(value: Boolean): this.type = set(toLowercase, value) /** @group getParam */ + @Since("1.6.0") def getToLowercase: Boolean = $(toLowercase) setDefault(minTokenLength -> 1, gaps -> true, pattern -> "\\s+", toLowercase -> true) @@ -138,6 +151,7 @@ class RegexTokenizer(override val uid: String) override protected def outputDataType: DataType = new ArrayType(StringType, true) + @Since("1.4.1") override def copy(extra: ParamMap): RegexTokenizer = defaultCopy(extra) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 4d3e46e488c67..ca900536bc7b8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -20,35 +20,38 @@ package org.apache.spark.ml.feature import scala.collection.mutable.ArrayBuilder import org.apache.spark.SparkException -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute, UnresolvedAttribute} +import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ -import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ /** - * :: Experimental :: * A feature transformer that merges multiple columns into a vector column. */ -@Experimental -class VectorAssembler(override val uid: String) +@Since("1.4.0") +class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String) extends Transformer with HasInputCols with HasOutputCol with DefaultParamsWritable { + @Since("1.4.0") def this() = this(Identifiable.randomUID("vecAssembler")) /** @group setParam */ + @Since("1.4.0") def setInputCols(value: Array[String]): this.type = set(inputCols, value) /** @group setParam */ + @Since("1.4.0") def setOutputCol(value: String): this.type = set(outputCol, value) @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { + transformSchema(dataset.schema, logging = true) // Schema transformation. 
val schema = dataset.schema lazy val first = dataset.toDF.first() @@ -106,6 +109,7 @@ class VectorAssembler(override val uid: String) dataset.select(col("*"), assembleFunc(struct(args: _*)).as($(outputCol), metadata)) } + @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { val inputColNames = $(inputCols) val outputColName = $(outputCol) @@ -122,6 +126,7 @@ class VectorAssembler(override val uid: String) StructType(schema.fields :+ new StructField(outputColName, new VectorUDT, true)) } + @Since("1.4.1") override def copy(extra: ParamMap): VectorAssembler = defaultCopy(extra) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index 68b699d569c7d..d1a5c2e82581e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -24,13 +24,13 @@ import scala.collection.JavaConverters._ import org.apache.hadoop.fs.Path -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.attribute._ +import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, VectorUDT} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ -import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, VectorUDT} import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions.udf import org.apache.spark.sql.types.{StructField, StructType} @@ -59,7 +59,6 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu } /** - * :: Experimental :: * Class for indexing categorical feature columns in a dataset of [[Vector]]. * * This has 2 usage modes: @@ -93,19 +92,24 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu * - Add warning if a categorical feature has only 1 category. * - Add option for allowing unknown categories. */ -@Experimental -class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerModel] - with VectorIndexerParams with DefaultParamsWritable { +@Since("1.4.0") +class VectorIndexer @Since("1.4.0") ( + @Since("1.4.0") override val uid: String) + extends Estimator[VectorIndexerModel] with VectorIndexerParams with DefaultParamsWritable { + @Since("1.4.0") def this() = this(Identifiable.randomUID("vecIdx")) /** @group setParam */ + @Since("1.4.0") def setMaxCategories(value: Int): this.type = set(maxCategories, value) /** @group setParam */ + @Since("1.4.0") def setInputCol(value: String): this.type = set(inputCol, value) /** @group setParam */ + @Since("1.4.0") def setOutputCol(value: String): this.type = set(outputCol, value) @Since("2.0.0") @@ -126,6 +130,7 @@ class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerMod copyValues(model) } + @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { // We do not transfer feature metadata since we do not know what types of features we will // produce in transform(). 
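A hypothetical sketch chaining the VectorAssembler and VectorIndexer transformers touched above (both now use `org.apache.spark.ml.linalg` vectors); `spark`, the column names, and the values are assumptions:

```scala
import org.apache.spark.ml.feature.{VectorAssembler, VectorIndexer}

val df = spark.createDataFrame(Seq(
  (0, 18.0, 1.0, 0.0),
  (1, 22.0, 0.0, 1.0),
  (2, 35.0, 1.0, 1.0)
)).toDF("id", "age", "clicked", "premium")

// Merge the numeric columns into a single ml.linalg vector column.
val assembled = new VectorAssembler()
  .setInputCols(Array("age", "clicked", "premium"))
  .setOutputCol("features")
  .transform(df)

// Features with at most 2 distinct values ("clicked", "premium") become categorical;
// "age" has 3 distinct values here and is left as continuous.
val indexerModel = new VectorIndexer()
  .setInputCol("features")
  .setOutputCol("indexedFeatures")
  .setMaxCategories(2)
  .fit(assembled)

indexerModel.transform(assembled).show(false)
```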
@@ -136,6 +141,7 @@ class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerMod SchemaUtils.appendColumn(schema, $(outputCol), dataType) } + @Since("1.4.1") override def copy(extra: ParamMap): VectorIndexer = defaultCopy(extra) } @@ -239,8 +245,8 @@ object VectorIndexer extends DefaultParamsReadable[VectorIndexer] { } /** - * :: Experimental :: - * Transform categorical features to use 0-based indices instead of their original values. + * Model fitted by [[VectorIndexer]]. Transform categorical features to use 0-based indices + * instead of their original values. * - Categorical features are mapped to indices. * - Continuous features (columns) are left unchanged. * This also appends metadata to the output column, marking features as Numeric (continuous), @@ -254,16 +260,17 @@ object VectorIndexer extends DefaultParamsReadable[VectorIndexer] { * Values are maps from original features values to 0-based category indices. * If a feature is not in this map, it is treated as continuous. */ -@Experimental +@Since("1.4.0") class VectorIndexerModel private[ml] ( - override val uid: String, - val numFeatures: Int, - val categoryMaps: Map[Int, Map[Double, Int]]) + @Since("1.4.0") override val uid: String, + @Since("1.4.0") val numFeatures: Int, + @Since("1.4.0") val categoryMaps: Map[Int, Map[Double, Int]]) extends Model[VectorIndexerModel] with VectorIndexerParams with MLWritable { import VectorIndexerModel._ /** Java-friendly version of [[categoryMaps]] */ + @Since("1.4.0") def javaCategoryMaps: JMap[JInt, JMap[JDouble, JInt]] = { categoryMaps.mapValues(_.asJava).asJava.asInstanceOf[JMap[JInt, JMap[JDouble, JInt]]] } @@ -341,9 +348,11 @@ class VectorIndexerModel private[ml] ( } /** @group setParam */ + @Since("1.4.0") def setInputCol(value: String): this.type = set(inputCol, value) /** @group setParam */ + @Since("1.4.0") def setOutputCol(value: String): this.type = set(outputCol, value) @Since("2.0.0") @@ -355,6 +364,7 @@ class VectorIndexerModel private[ml] ( dataset.withColumn($(outputCol), newCol, newField.metadata) } + @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { val dataType = new VectorUDT require(isDefined(inputCol), @@ -414,6 +424,7 @@ class VectorIndexerModel private[ml] ( newAttributeGroup.toStructField() } + @Since("1.4.1") override def copy(extra: ParamMap): VectorIndexerModel = { val copied = new VectorIndexerModel(uid, numFeatures, categoryMaps) copyValues(copied, extra).setParent(parent) @@ -435,7 +446,7 @@ object VectorIndexerModel extends MLReadable[VectorIndexerModel] { DefaultParamsWriter.saveMetadata(instance, path, sc) val data = Data(instance.numFeatures, instance.categoryMaps) val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } } @@ -446,7 +457,7 @@ object VectorIndexerModel extends MLReadable[VectorIndexerModel] { override def load(path: String): VectorIndexerModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val data = sqlContext.read.parquet(dataPath) + val data = sparkSession.read.parquet(dataPath) .select("numFeatures", "categoryMaps") .head() val numFeatures = data.getAs[Int](0) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala index 7a9468b87b73e..966ccb85d0e0e 100644 --- 
a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala @@ -17,19 +17,18 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.{Attribute, AttributeGroup} +import org.apache.spark.ml.linalg._ import org.apache.spark.ml.param.{IntArrayParam, ParamMap, StringArrayParam} import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util._ -import org.apache.spark.mllib.linalg._ import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.StructType /** - * :: Experimental :: * This class takes a feature vector and outputs a new feature vector with a subarray of the * original features. * @@ -40,10 +39,11 @@ import org.apache.spark.sql.types.StructType * The output vector will order features with the selected indices first (in the order given), * followed by the selected names (in the order given). */ -@Experimental -final class VectorSlicer(override val uid: String) +@Since("1.5.0") +final class VectorSlicer @Since("1.5.0") (@Since("1.5.0") override val uid: String) extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable { + @Since("1.5.0") def this() = this(Identifiable.randomUID("vectorSlicer")) /** @@ -52,6 +52,7 @@ final class VectorSlicer(override val uid: String) * Default: Empty array * @group param */ + @Since("1.5.0") val indices = new IntArrayParam(this, "indices", "An array of indices to select features from a vector column." + " There can be no overlap with names.", VectorSlicer.validIndices) @@ -59,9 +60,11 @@ final class VectorSlicer(override val uid: String) setDefault(indices -> Array.empty[Int]) /** @group getParam */ + @Since("1.5.0") def getIndices: Array[Int] = $(indices) /** @group setParam */ + @Since("1.5.0") def setIndices(value: Array[Int]): this.type = set(indices, value) /** @@ -71,6 +74,7 @@ final class VectorSlicer(override val uid: String) * Default: Empty Array * @group param */ + @Since("1.5.0") val names = new StringArrayParam(this, "names", "An array of feature names to select features from a vector column." 
+ " There can be no overlap with indices.", VectorSlicer.validNames) @@ -78,15 +82,19 @@ final class VectorSlicer(override val uid: String) setDefault(names -> Array.empty[String]) /** @group getParam */ + @Since("1.5.0") def getNames: Array[String] = $(names) /** @group setParam */ + @Since("1.5.0") def setNames(value: Array[String]): this.type = set(names, value) /** @group setParam */ + @Since("1.5.0") def setInputCol(value: String): this.type = set(inputCol, value) /** @group setParam */ + @Since("1.5.0") def setOutputCol(value: String): this.type = set(outputCol, value) @Since("2.0.0") @@ -134,6 +142,7 @@ final class VectorSlicer(override val uid: String) indFeatures ++ nameFeatures } + @Since("1.5.0") override def transformSchema(schema: StructType): StructType = { require($(indices).length > 0 || $(names).length > 0, s"VectorSlicer requires that at least one feature be selected.") @@ -148,6 +157,7 @@ final class VectorSlicer(override val uid: String) StructType(outputFields) } + @Since("1.5.0") override def copy(extra: ParamMap): VectorSlicer = defaultCopy(extra) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index c49e263df0a6d..d53f3df514dff 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -19,15 +19,15 @@ package org.apache.spark.ml.feature import org.apache.hadoop.fs.Path -import org.apache.spark.SparkContext -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors, VectorUDT} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.feature -import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors, VectorUDT} -import org.apache.spark.sql.{DataFrame, Dataset, SQLContext} +import org.apache.spark.mllib.linalg.VectorImplicits._ +import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -50,7 +50,8 @@ private[feature] trait Word2VecBase extends Params def getVectorSize: Int = $(vectorSize) /** - * The window size (context words from [-window, window]) default 5. + * The window size (context words from [-window, window]). + * Default: 5 * @group expertParam */ final val windowSize = new IntParam( @@ -85,6 +86,21 @@ private[feature] trait Word2VecBase extends Params /** @group getParam */ def getMinCount: Int = $(minCount) + /** + * Sets the maximum length (in words) of each sentence in the input data. + * Any sentence longer than this threshold will be divided into chunks of + * up to `maxSentenceLength` size. + * Default: 1000 + * @group param + */ + final val maxSentenceLength = new IntParam(this, "maxSentenceLength", "Maximum length " + + "(in words) of each sentence in the input data. Any sentence longer than this threshold will " + + "be divided into chunks up to the size.") + setDefault(maxSentenceLength -> 1000) + + /** @group getParam */ + def getMaxSentenceLength: Int = $(maxSentenceLength) + setDefault(stepSize -> 0.025) setDefault(maxIter -> 1) @@ -92,49 +108,64 @@ private[feature] trait Word2VecBase extends Params * Validate and transform the input schema. 
*/ protected def validateAndTransformSchema(schema: StructType): StructType = { - SchemaUtils.checkColumnType(schema, $(inputCol), new ArrayType(StringType, true)) + val typeCandidates = List(new ArrayType(StringType, true), new ArrayType(StringType, false)) + SchemaUtils.checkColumnTypes(schema, $(inputCol), typeCandidates) SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT) } } /** - * :: Experimental :: * Word2Vec trains a model of `Map(String, Vector)`, i.e. transforms a word into a code for further * natural language processing or machine learning process. */ -@Experimental -final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] with Word2VecBase - with DefaultParamsWritable { +@Since("1.4.0") +final class Word2Vec @Since("1.4.0") ( + @Since("1.4.0") override val uid: String) + extends Estimator[Word2VecModel] with Word2VecBase with DefaultParamsWritable { + @Since("1.4.0") def this() = this(Identifiable.randomUID("w2v")) /** @group setParam */ + @Since("1.4.0") def setInputCol(value: String): this.type = set(inputCol, value) /** @group setParam */ + @Since("1.4.0") def setOutputCol(value: String): this.type = set(outputCol, value) /** @group setParam */ + @Since("1.4.0") def setVectorSize(value: Int): this.type = set(vectorSize, value) /** @group expertSetParam */ + @Since("1.6.0") def setWindowSize(value: Int): this.type = set(windowSize, value) /** @group setParam */ + @Since("1.4.0") def setStepSize(value: Double): this.type = set(stepSize, value) /** @group setParam */ + @Since("1.4.0") def setNumPartitions(value: Int): this.type = set(numPartitions, value) /** @group setParam */ + @Since("1.4.0") def setMaxIter(value: Int): this.type = set(maxIter, value) /** @group setParam */ + @Since("1.4.0") def setSeed(value: Long): this.type = set(seed, value) /** @group setParam */ + @Since("1.4.0") def setMinCount(value: Int): this.type = set(minCount, value) + /** @group setParam */ + @Since("2.0.0") + def setMaxSentenceLength(value: Int): this.type = set(maxSentenceLength, value) + @Since("2.0.0") override def fit(dataset: Dataset[_]): Word2VecModel = { transformSchema(dataset.schema, logging = true) @@ -147,14 +178,17 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] .setSeed($(seed)) .setVectorSize($(vectorSize)) .setWindowSize($(windowSize)) + .setMaxSentenceLength($(maxSentenceLength)) .fit(input) copyValues(new Word2VecModel(uid, wordVectors).setParent(this)) } + @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema) } + @Since("1.4.1") override def copy(extra: ParamMap): Word2Vec = defaultCopy(extra) } @@ -166,12 +200,11 @@ object Word2Vec extends DefaultParamsReadable[Word2Vec] { } /** - * :: Experimental :: * Model fitted by [[Word2Vec]]. */ -@Experimental +@Since("1.4.0") class Word2VecModel private[ml] ( - override val uid: String, + @Since("1.4.0") override val uid: String, @transient private val wordVectors: feature.Word2VecModel) extends Model[Word2VecModel] with Word2VecBase with MLWritable { @@ -181,39 +214,42 @@ class Word2VecModel private[ml] ( * Returns a dataframe with two fields, "word" and "vector", with "word" being a String and * and the vector the DenseVector that it is mapped to. 
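A toy sketch of the Word2Vec estimator above, including the new `setMaxSentenceLength` setter and the `findSynonyms` methods shown just below; `spark` and the three-line corpus are assumptions:

```scala
import org.apache.spark.ml.feature.Word2Vec

val docs = spark.createDataFrame(Seq(
  "Hi I heard about Spark".split(" "),
  "I wish Java could use case classes".split(" "),
  "Logistic regression models are neat".split(" ")
).map(Tuple1.apply)).toDF("text")

val model = new Word2Vec()
  .setInputCol("text")
  .setOutputCol("result")
  .setVectorSize(3)
  .setMinCount(0)
  .setMaxSentenceLength(1000) // sentences longer than this are split into chunks
  .fit(docs)

// Average word vectors per document.
model.transform(docs).show(false)

// Nearest neighbours by cosine similarity; the query word itself is excluded.
model.findSynonyms("Spark", 2).show()
```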
*/ + @Since("1.5.0") @transient lazy val getVectors: DataFrame = { - val sc = SparkContext.getOrCreate() - val sqlContext = SQLContext.getOrCreate(sc) - import sqlContext.implicits._ + val spark = SparkSession.builder().getOrCreate() val wordVec = wordVectors.getVectors.mapValues(vec => Vectors.dense(vec.map(_.toDouble))) - sc.parallelize(wordVec.toSeq).toDF("word", "vector") + spark.createDataFrame(wordVec.toSeq).toDF("word", "vector") } /** - * Find "num" number of words closest in similarity to the given word. - * Returns a dataframe with the words and the cosine similarities between the - * synonyms and the given word. + * Find "num" number of words closest in similarity to the given word, not + * including the word itself. Returns a dataframe with the words and the + * cosine similarities between the synonyms and the given word. */ + @Since("1.5.0") def findSynonyms(word: String, num: Int): DataFrame = { - findSynonyms(wordVectors.transform(word), num) + val spark = SparkSession.builder().getOrCreate() + spark.createDataFrame(wordVectors.findSynonyms(word, num)).toDF("word", "similarity") } /** - * Find "num" number of words closest to similarity to the given vector representation - * of the word. Returns a dataframe with the words and the cosine similarities between the - * synonyms and the given word vector. + * Find "num" number of words whose vector representation most similar to the supplied vector. + * If the supplied vector is the vector representation of a word in the model's vocabulary, + * that word will be in the results. Returns a dataframe with the words and the cosine + * similarities between the synonyms and the given word vector. */ - def findSynonyms(word: Vector, num: Int): DataFrame = { - val sc = SparkContext.getOrCreate() - val sqlContext = SQLContext.getOrCreate(sc) - import sqlContext.implicits._ - sc.parallelize(wordVectors.findSynonyms(word, num)).toDF("word", "similarity") + @Since("2.0.0") + def findSynonyms(vec: Vector, num: Int): DataFrame = { + val spark = SparkSession.builder().getOrCreate() + spark.createDataFrame(wordVectors.findSynonyms(vec, num)).toDF("word", "similarity") } /** @group setParam */ + @Since("1.4.0") def setInputCol(value: String): this.type = set(inputCol, value) /** @group setParam */ + @Since("1.4.0") def setOutputCol(value: String): this.type = set(outputCol, value) /** @@ -229,7 +265,7 @@ class Word2VecModel private[ml] ( val bVectors = dataset.sparkSession.sparkContext.broadcast(vectors) val d = $(vectorSize) val word2Vec = udf { sentence: Seq[String] => - if (sentence.size == 0) { + if (sentence.isEmpty) { Vectors.sparse(d, Array.empty[Int], Array.empty[Double]) } else { val sum = Vectors.zeros(d) @@ -245,10 +281,12 @@ class Word2VecModel private[ml] ( dataset.withColumn($(outputCol), word2Vec(col($(inputCol)))) } + @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema) } + @Since("1.4.1") override def copy(extra: ParamMap): Word2VecModel = { val copied = new Word2VecModel(uid, wordVectors) copyValues(copied, extra).setParent(parent) @@ -270,7 +308,7 @@ object Word2VecModel extends MLReadable[Word2VecModel] { DefaultParamsWriter.saveMetadata(instance, path, sc) val data = Data(instance.wordVectors.wordIndex, instance.wordVectors.wordVectors.toSeq) val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } } @@ -281,7 
+319,7 @@ object Word2VecModel extends MLReadable[Word2VecModel] { override def load(path: String): Word2VecModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val data = sqlContext.read.parquet(dataPath) + val data = sparkSession.read.parquet(dataPath) .select("wordIndex", "wordVectors") .head() val wordIndex = data.getAs[Map[String, Int]](0) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/package.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/package.scala index 4571ab26800c0..b94187ae787cc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/package.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/package.scala @@ -44,7 +44,7 @@ import org.apache.spark.sql.DataFrame * import org.apache.spark.ml.Pipeline * * // a DataFrame with three columns: id (integer), text (string), and rating (double). - * val df = sqlContext.createDataFrame(Seq( + * val df = spark.createDataFrame(Seq( * (0, "Hi I heard about Spark", 3.0), * (1, "I wish Java could use case classes", 4.0), * (2, "Logistic regression models are neat", 4.0) diff --git a/mllib/src/main/scala/org/apache/spark/ml/linalg/MatrixUDT.scala b/mllib/src/main/scala/org/apache/spark/ml/linalg/MatrixUDT.scala index 521a216c6785d..a1e53662f02a8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/linalg/MatrixUDT.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/linalg/MatrixUDT.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.types._ * User-defined type for [[Matrix]] in [[mllib-local]] which allows easy interaction with SQL * via [[org.apache.spark.sql.Dataset]]. */ -private[ml] class MatrixUDT extends UserDefinedType[Matrix] { +private[spark] class MatrixUDT extends UserDefinedType[Matrix] { override def sqlType: StructType = { // type: 0 = sparse, 1 = dense diff --git a/project/project/SparkPluginBuild.scala b/mllib/src/main/scala/org/apache/spark/ml/linalg/SQLDataTypes.scala similarity index 64% rename from project/project/SparkPluginBuild.scala rename to mllib/src/main/scala/org/apache/spark/ml/linalg/SQLDataTypes.scala index cbb88dc7dd1dd..a66ba27a7b9c5 100644 --- a/project/project/SparkPluginBuild.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/linalg/SQLDataTypes.scala @@ -15,14 +15,22 @@ * limitations under the License. */ -import sbt._ -import sbt.Keys._ +package org.apache.spark.ml.linalg + +import org.apache.spark.annotation.{DeveloperApi, Since} +import org.apache.spark.sql.types.DataType /** - * This plugin project is there because we use our custom fork of sbt-pom-reader plugin. This is - * a plugin project so that this gets compiled first and is available on the classpath for SBT build. + * :: DeveloperApi :: + * SQL data types for vectors and matrices. */ -object SparkPluginDef extends Build { - lazy val root = Project("plugins", file(".")) dependsOn(sbtPomReader) - lazy val sbtPomReader = uri("https://github.com/ScrapCodes/sbt-pom-reader.git#ignore_artifact_id") +@Since("2.0.0") +@DeveloperApi +object SQLDataTypes { + + /** Data type for [[Vector]]. */ + val VectorType: DataType = new VectorUDT + + /** Data type for [[Matrix]]. 
*/ + val MatrixType: DataType = new MatrixUDT } diff --git a/mllib/src/main/scala/org/apache/spark/ml/linalg/VectorUDT.scala b/mllib/src/main/scala/org/apache/spark/ml/linalg/VectorUDT.scala index c29f7f86e9f27..0b9b2ff5c5e26 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/linalg/VectorUDT.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/linalg/VectorUDT.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.types._ * User-defined type for [[Vector]] in [[mllib-local]] which allows easy interaction with SQL * via [[org.apache.spark.sql.Dataset]]. */ -private[ml] class VectorUDT extends UserDefinedType[Vector] { +private[spark] class VectorUDT extends UserDefinedType[Vector] { override def sqlType: StructType = { // type: 0 = sparse, 1 = dense diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala index a2b52835e177a..d732f53029e8c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquares.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.optim import org.apache.spark.internal.Logging import org.apache.spark.ml.feature.Instance -import org.apache.spark.mllib.linalg._ +import org.apache.spark.ml.linalg._ import org.apache.spark.rdd.RDD /** @@ -38,7 +38,7 @@ private[ml] class IterativelyReweightedLeastSquaresModel( /** * Implements the method of iteratively reweighted least squares (IRLS) which is used to solve * certain optimization problems by an iterative method. In each step of the iterations, it - * involves solving a weighted lease squares (WLS) problem by [[WeightedLeastSquares]]. + * involves solving a weighted least squares (WLS) problem by [[WeightedLeastSquares]]. * It can be used to find maximum likelihood estimates of a generalized linear model (GLM), * find M-estimator in robust regression and other optimization problems. * diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala index 7d21302f962bf..8f5f4427e1f4b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala @@ -19,7 +19,8 @@ package org.apache.spark.ml.optim import org.apache.spark.internal.Logging import org.apache.spark.ml.feature.Instance -import org.apache.spark.mllib.linalg._ +import org.apache.spark.ml.linalg._ +import org.apache.spark.mllib.linalg.CholeskyDecomposition import org.apache.spark.rdd.RDD /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/package-info.java b/mllib/src/main/scala/org/apache/spark/ml/package-info.java index 87f4223964ada..cb97382207b00 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/package-info.java +++ b/mllib/src/main/scala/org/apache/spark/ml/package-info.java @@ -16,10 +16,7 @@ */ /** - * Spark ML is a BETA component that adds a new set of machine learning APIs to let users quickly - * assemble and configure practical machine learning pipelines. + * DataFrame-based machine learning APIs to let users quickly assemble and configure practical + * machine learning pipelines. 
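The new `SQLDataTypes` object above exposes the now `private[spark]` UDTs as plain `DataType`s. A speculative sketch of using it to declare a DataFrame schema, assuming `spark` is in scope and the column names are invented:

```scala
import java.util.Arrays

import org.apache.spark.ml.linalg.{SQLDataTypes, Vectors}
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

val schema = StructType(Array(
  StructField("id", IntegerType, nullable = false),
  StructField("features", SQLDataTypes.VectorType, nullable = false)
))

val rows = Arrays.asList(
  Row(0, Vectors.dense(1.0, 0.0, 3.0)),
  Row(1, Vectors.sparse(3, Array(1), Array(2.0)))
)

val df = spark.createDataFrame(rows, schema)
df.printSchema() // "features" is reported as a vector column
```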
*/ -@Experimental package org.apache.spark.ml; - -import org.apache.spark.annotation.Experimental; diff --git a/mllib/src/main/scala/org/apache/spark/ml/package.scala b/mllib/src/main/scala/org/apache/spark/ml/package.scala index c589d06d9f7e4..a445c675e41e6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/package.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/package.scala @@ -18,8 +18,8 @@ package org.apache.spark /** - * Spark ML is a BETA component that adds a new set of machine learning APIs to let users quickly - * assemble and configure practical machine learning pipelines. + * DataFrame-based machine learning APIs to let users quickly assemble and configure practical + * machine learning pipelines. * * @groupname param Parameters * @groupdesc param A list of (hyper-)parameter keys this algorithm can take. Users can set and get diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index c368aadd23669..9245931b27ca6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -28,9 +28,10 @@ import scala.collection.JavaConverters._ import org.json4s._ import org.json4s.jackson.JsonMethods._ -import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} +import org.apache.spark.annotation.{DeveloperApi, Since} +import org.apache.spark.ml.linalg.JsonVectorConverter +import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.util.Identifiable -import org.apache.spark.mllib.linalg.{Vector, Vectors} /** * :: DeveloperApi :: @@ -92,7 +93,7 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali case x: String => compact(render(JString(x))) case v: Vector => - v.toJson + JsonVectorConverter.toJson(v) case _ => throw new NotImplementedError( "The default jsonEncode only supports string and vector. " + @@ -128,7 +129,7 @@ private[ml] object Param { val keys = v.map(_._1) assert(keys.contains("type") && keys.contains("values"), s"Expect a JSON serialized vector but cannot find fields 'type' and 'values' in $json.") - Vectors.fromJson(json).asInstanceOf[T] + JsonVectorConverter.fromJson(json).asInstanceOf[T] case _ => throw new NotImplementedError( "The default jsonDecode only supports string and vector. " + @@ -509,11 +510,9 @@ class IntArrayParam(parent: Params, name: String, doc: String, isValid: Array[In } /** - * :: Experimental :: * A param and its value. */ @Since("1.2.0") -@Experimental case class ParamPair[T] @Since("1.2.0") ( @Since("1.2.0") param: Param[T], @Since("1.2.0") value: T) { @@ -553,7 +552,7 @@ trait Params extends Identifiable with Serializable { * * This only needs to check for interactions between parameters. * Parameter value checks which do not depend on other parameters are handled by - * [[Param.validate()]]. This method does not handle input/output column parameters; + * `Param.validate()`. This method does not handle input/output column parameters; * those are checked during schema validation. * @deprecated Will be removed in 2.1.0. All the checks should be merged into transformSchema */ @@ -581,8 +580,7 @@ trait Params extends Identifiable with Serializable { } /** - * Explains all params of this instance. - * @see [[explainParam()]] + * Explains all params of this instance. See `explainParam()`. 
*/ def explainParams(): String = { params.map(explainParam).mkString("\n") @@ -679,7 +677,7 @@ trait Params extends Identifiable with Serializable { /** * Sets default values for a list of params. * - * Note: Java developers should use the single-parameter [[setDefault()]]. + * Note: Java developers should use the single-parameter `setDefault`. * Annotating this with varargs can cause compilation failures due to a Scala compiler bug. * See SPARK-9268. * @@ -713,8 +711,7 @@ trait Params extends Identifiable with Serializable { /** * Creates a copy of this instance with the same UID and some extra params. * Subclasses should implement this method and set the return type properly. - * - * @see [[defaultCopy()]] + * See `defaultCopy()`. */ def copy(extra: ParamMap): Params @@ -731,7 +728,8 @@ trait Params extends Identifiable with Serializable { /** * Extracts the embedded default param values and user-supplied values, and then merges them with * extra values from input into a flat param map, where the latter value is used if there exist - * conflicts, i.e., with ordering: default param values < user-supplied values < extra. + * conflicts, i.e., with ordering: + * default param values less than user-supplied values less than extra. */ final def extractParamMap(extra: ParamMap): ParamMap = { defaultParamMap ++ paramMap ++ extra @@ -789,18 +787,16 @@ trait Params extends Identifiable with Serializable { * :: DeveloperApi :: * Java-friendly wrapper for [[Params]]. * Java developers who need to extend [[Params]] should use this class instead. - * If you need to extend a abstract class which already extends [[Params]], then that abstract + * If you need to extend an abstract class which already extends [[Params]], then that abstract * class should be Java-friendly as well. */ @DeveloperApi abstract class JavaParams extends Params /** - * :: Experimental :: * A param to value map. */ @Since("1.2.0") -@Experimental final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) extends Serializable { @@ -951,7 +947,6 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) } @Since("1.2.0") -@Experimental object ParamMap { /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/python/MLSerDe.scala b/mllib/src/main/scala/org/apache/spark/ml/python/MLSerDe.scala new file mode 100644 index 0000000000000..1279c901c5c9e --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/python/MLSerDe.scala @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.python + +import java.io.OutputStream +import java.nio.{ByteBuffer, ByteOrder} +import java.util.{ArrayList => JArrayList} + +import scala.collection.JavaConverters._ + +import net.razorvine.pickle._ + +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.api.python.SerDeUtil +import org.apache.spark.ml.linalg._ +import org.apache.spark.mllib.api.python.SerDeBase +import org.apache.spark.rdd.RDD + +/** + * SerDe utility functions for pyspark.ml. + */ +private[spark] object MLSerDe extends SerDeBase with Serializable { + + override val PYSPARK_PACKAGE = "pyspark.ml" + + // Pickler for DenseVector + private[python] class DenseVectorPickler extends BasePickler[DenseVector] { + + def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = { + val vector: DenseVector = obj.asInstanceOf[DenseVector] + val bytes = new Array[Byte](8 * vector.size) + val bb = ByteBuffer.wrap(bytes) + bb.order(ByteOrder.nativeOrder()) + val db = bb.asDoubleBuffer() + db.put(vector.values) + + out.write(Opcodes.BINSTRING) + out.write(PickleUtils.integer_to_bytes(bytes.length)) + out.write(bytes) + out.write(Opcodes.TUPLE1) + } + + def construct(args: Array[Object]): Object = { + require(args.length == 1) + if (args.length != 1) { + throw new PickleException("should be 1") + } + val bytes = getBytes(args(0)) + val bb = ByteBuffer.wrap(bytes, 0, bytes.length) + bb.order(ByteOrder.nativeOrder()) + val db = bb.asDoubleBuffer() + val ans = new Array[Double](bytes.length / 8) + db.get(ans) + Vectors.dense(ans) + } + } + + // Pickler for DenseMatrix + private[python] class DenseMatrixPickler extends BasePickler[DenseMatrix] { + + def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = { + val m: DenseMatrix = obj.asInstanceOf[DenseMatrix] + val bytes = new Array[Byte](8 * m.values.length) + val order = ByteOrder.nativeOrder() + val isTransposed = if (m.isTransposed) 1 else 0 + ByteBuffer.wrap(bytes).order(order).asDoubleBuffer().put(m.values) + + out.write(Opcodes.MARK) + out.write(Opcodes.BININT) + out.write(PickleUtils.integer_to_bytes(m.numRows)) + out.write(Opcodes.BININT) + out.write(PickleUtils.integer_to_bytes(m.numCols)) + out.write(Opcodes.BINSTRING) + out.write(PickleUtils.integer_to_bytes(bytes.length)) + out.write(bytes) + out.write(Opcodes.BININT) + out.write(PickleUtils.integer_to_bytes(isTransposed)) + out.write(Opcodes.TUPLE) + } + + def construct(args: Array[Object]): Object = { + if (args.length != 4) { + throw new PickleException("should be 4") + } + val bytes = getBytes(args(2)) + val n = bytes.length / 8 + val values = new Array[Double](n) + val order = ByteOrder.nativeOrder() + ByteBuffer.wrap(bytes).order(order).asDoubleBuffer().get(values) + val isTransposed = args(3).asInstanceOf[Int] == 1 + new DenseMatrix(args(0).asInstanceOf[Int], args(1).asInstanceOf[Int], values, isTransposed) + } + } + + // Pickler for SparseMatrix + private[python] class SparseMatrixPickler extends BasePickler[SparseMatrix] { + + def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = { + val s = obj.asInstanceOf[SparseMatrix] + val order = ByteOrder.nativeOrder() + + val colPtrsBytes = new Array[Byte](4 * s.colPtrs.length) + val indicesBytes = new Array[Byte](4 * s.rowIndices.length) + val valuesBytes = new Array[Byte](8 * s.values.length) + val isTransposed = if (s.isTransposed) 1 else 0 + ByteBuffer.wrap(colPtrsBytes).order(order).asIntBuffer().put(s.colPtrs) + 
ByteBuffer.wrap(indicesBytes).order(order).asIntBuffer().put(s.rowIndices) + ByteBuffer.wrap(valuesBytes).order(order).asDoubleBuffer().put(s.values) + + out.write(Opcodes.MARK) + out.write(Opcodes.BININT) + out.write(PickleUtils.integer_to_bytes(s.numRows)) + out.write(Opcodes.BININT) + out.write(PickleUtils.integer_to_bytes(s.numCols)) + out.write(Opcodes.BINSTRING) + out.write(PickleUtils.integer_to_bytes(colPtrsBytes.length)) + out.write(colPtrsBytes) + out.write(Opcodes.BINSTRING) + out.write(PickleUtils.integer_to_bytes(indicesBytes.length)) + out.write(indicesBytes) + out.write(Opcodes.BINSTRING) + out.write(PickleUtils.integer_to_bytes(valuesBytes.length)) + out.write(valuesBytes) + out.write(Opcodes.BININT) + out.write(PickleUtils.integer_to_bytes(isTransposed)) + out.write(Opcodes.TUPLE) + } + + def construct(args: Array[Object]): Object = { + if (args.length != 6) { + throw new PickleException("should be 6") + } + val order = ByteOrder.nativeOrder() + val colPtrsBytes = getBytes(args(2)) + val indicesBytes = getBytes(args(3)) + val valuesBytes = getBytes(args(4)) + val colPtrs = new Array[Int](colPtrsBytes.length / 4) + val rowIndices = new Array[Int](indicesBytes.length / 4) + val values = new Array[Double](valuesBytes.length / 8) + ByteBuffer.wrap(colPtrsBytes).order(order).asIntBuffer().get(colPtrs) + ByteBuffer.wrap(indicesBytes).order(order).asIntBuffer().get(rowIndices) + ByteBuffer.wrap(valuesBytes).order(order).asDoubleBuffer().get(values) + val isTransposed = args(5).asInstanceOf[Int] == 1 + new SparseMatrix( + args(0).asInstanceOf[Int], args(1).asInstanceOf[Int], colPtrs, rowIndices, values, + isTransposed) + } + } + + // Pickler for SparseVector + private[python] class SparseVectorPickler extends BasePickler[SparseVector] { + + def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = { + val v: SparseVector = obj.asInstanceOf[SparseVector] + val n = v.indices.length + val indiceBytes = new Array[Byte](4 * n) + val order = ByteOrder.nativeOrder() + ByteBuffer.wrap(indiceBytes).order(order).asIntBuffer().put(v.indices) + val valueBytes = new Array[Byte](8 * n) + ByteBuffer.wrap(valueBytes).order(order).asDoubleBuffer().put(v.values) + + out.write(Opcodes.BININT) + out.write(PickleUtils.integer_to_bytes(v.size)) + out.write(Opcodes.BINSTRING) + out.write(PickleUtils.integer_to_bytes(indiceBytes.length)) + out.write(indiceBytes) + out.write(Opcodes.BINSTRING) + out.write(PickleUtils.integer_to_bytes(valueBytes.length)) + out.write(valueBytes) + out.write(Opcodes.TUPLE3) + } + + def construct(args: Array[Object]): Object = { + if (args.length != 3) { + throw new PickleException("should be 3") + } + val size = args(0).asInstanceOf[Int] + val indiceBytes = getBytes(args(1)) + val valueBytes = getBytes(args(2)) + val n = indiceBytes.length / 4 + val indices = new Array[Int](n) + val values = new Array[Double](n) + if (n > 0) { + val order = ByteOrder.nativeOrder() + ByteBuffer.wrap(indiceBytes).order(order).asIntBuffer().get(indices) + ByteBuffer.wrap(valueBytes).order(order).asDoubleBuffer().get(values) + } + new SparseVector(size, indices, values) + } + } + + var initialized = false + // This should be called before trying to serialize any above classes + // In cluster mode, this should be put in the closure + override def initialize(): Unit = { + SerDeUtil.initialize() + synchronized { + if (!initialized) { + new DenseVectorPickler().register() + new DenseMatrixPickler().register() + new SparseMatrixPickler().register() + new SparseVectorPickler().register() 
+ initialized = true + } + } + } + // will not be called in Executor automatically + initialize() +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala index 9618a3423e9a6..5642abc6450f1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala @@ -67,8 +67,8 @@ private[r] object GeneralizedLinearRegressionWrapper data: DataFrame, family: String, link: String, - epsilon: Double, - maxit: Int): GeneralizedLinearRegressionWrapper = { + tol: Double, + maxIter: Int): GeneralizedLinearRegressionWrapper = { val rFormula = new RFormula() .setFormula(formula) val rFormulaModel = rFormula.fit(data) @@ -82,8 +82,8 @@ .setFamily(family) .setLink(link) .setFitIntercept(rFormula.hasIntercept) - .setTol(epsilon) - .setMaxIter(maxit) + .setTol(tol) + .setMaxIter(maxIter) val pipeline = new Pipeline() .setStages(Array(rFormulaModel, glr)) .fit(data) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala index 28925c79da66e..1dac246b03329 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala @@ -56,7 +56,7 @@ private[r] object NaiveBayesWrapper extends MLReadable[NaiveBayesWrapper] { val PREDICTED_LABEL_INDEX_COL = "pred_label_idx" val PREDICTED_LABEL_COL = "prediction" - def fit(formula: String, data: DataFrame, laplace: Double): NaiveBayesWrapper = { + def fit(formula: String, data: DataFrame, smoothing: Double): NaiveBayesWrapper = { val rFormula = new RFormula() .setFormula(formula) .fit(data) @@ -70,7 +70,7 @@ private[r] object NaiveBayesWrapper extends MLReadable[NaiveBayesWrapper] { val features = featureAttrs.map(_.name.get) // assemble and fit the pipeline val naiveBayes = new NaiveBayes() - .setSmoothing(laplace) + .setSmoothing(smoothing) .setModelType("bernoulli") .setPredictionCol(PREDICTED_LABEL_INDEX_COL) val idxToStr = new IndexToString() diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 541923048a3fa..02e2384afe530 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -26,12 +26,12 @@ import scala.util.{Sorting, Try} import scala.util.hashing.byteswap64 import com.github.fommil.netlib.BLAS.{getInstance => blas} -import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.fs.Path import org.json4s.DefaultFormats import org.json4s.JsonDSL._ -import org.apache.spark.Partitioner -import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} +import org.apache.spark.{Dependency, Partitioner, ShuffleDependency, SparkContext} +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.internal.Logging import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ @@ -42,7 +42,7 @@ import org.apache.spark.mllib.optimization.NNLS import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructType} +import
org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet, SortDataFormat, Sorter} @@ -53,24 +53,43 @@ import org.apache.spark.util.random.XORShiftRandom */ private[recommendation] trait ALSModelParams extends Params with HasPredictionCol { /** - * Param for the column name for user ids. + * Param for the column name for user ids. Ids must be integers. Other + * numeric types are supported for this column, but will be cast to integers as long as they + * fall within the integer value range. * Default: "user" * @group param */ - val userCol = new Param[String](this, "userCol", "column name for user ids") + val userCol = new Param[String](this, "userCol", "column name for user ids. Ids must be within " + + "the integer value range.") /** @group getParam */ def getUserCol: String = $(userCol) /** - * Param for the column name for item ids. + * Param for the column name for item ids. Ids must be integers. Other + * numeric types are supported for this column, but will be cast to integers as long as they + * fall within the integer value range. * Default: "item" * @group param */ - val itemCol = new Param[String](this, "itemCol", "column name for item ids") + val itemCol = new Param[String](this, "itemCol", "column name for item ids. Ids must be within " + + "the integer value range.") /** @group getParam */ def getItemCol: String = $(itemCol) + + /** + * Attempts to safely cast a user/item id to an Int. Throws an exception if the value is + * out of integer range. + */ + protected val checkedCast = udf { (n: Double) => + if (n > Int.MaxValue || n < Int.MinValue) { + throw new IllegalArgumentException(s"ALS only supports values in Integer range for columns " + + s"${$(userCol)} and ${$(itemCol)}. Value $n was out of Integer range.") + } else { + n.toInt + } + } } /** @@ -80,7 +99,7 @@ private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter w with HasPredictionCol with HasCheckpointInterval with HasSeed { /** - * Param for rank of the matrix factorization (>= 1). + * Param for rank of the matrix factorization (positive). * Default: 10 * @group param */ @@ -90,7 +109,7 @@ private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter w def getRank: Int = $(rank) /** - * Param for number of user blocks (>= 1). + * Param for number of user blocks (positive). * Default: 10 * @group param */ @@ -101,7 +120,7 @@ private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter w def getNumUserBlocks: Int = $(numUserBlocks) /** - * Param for number of item blocks (>= 1). + * Param for number of item blocks (positive). * Default: 10 * @group param */ @@ -122,7 +141,7 @@ private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter w def getImplicitPrefs: Boolean = $(implicitPrefs) /** - * Param for the alpha parameter in the implicit preference formulation (>= 0). + * Param for the alpha parameter in the implicit preference formulation (nonnegative). * Default: 1.0 * @group param */ @@ -155,13 +174,13 @@ private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter w /** * Param for StorageLevel for intermediate datasets. Pass in a string representation of - * [[StorageLevel]]. Cannot be "NONE". + * `StorageLevel`. Cannot be "NONE". * Default: "MEMORY_AND_DISK". 
* * @group expertParam */ val intermediateStorageLevel = new Param[String](this, "intermediateStorageLevel", - "StorageLevel for intermediate datasets. Cannot be 'NONE'. Default: 'MEMORY_AND_DISK'.", + "StorageLevel for intermediate datasets. Cannot be 'NONE'.", (s: String) => Try(StorageLevel.fromString(s)).isSuccess && s != "NONE") /** @group expertGetParam */ @@ -169,13 +188,13 @@ private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter w /** * Param for StorageLevel for ALS model factors. Pass in a string representation of - * [[StorageLevel]]. + * `StorageLevel`. * Default: "MEMORY_AND_DISK". * * @group expertParam */ val finalStorageLevel = new Param[String](this, "finalStorageLevel", - "StorageLevel for ALS model factors. Default: 'MEMORY_AND_DISK'.", + "StorageLevel for ALS model factors.", (s: String) => Try(StorageLevel.fromString(s)).isSuccess) /** @group expertGetParam */ @@ -193,23 +212,22 @@ private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter w * @return output schema */ protected def validateAndTransformSchema(schema: StructType): StructType = { - SchemaUtils.checkColumnType(schema, $(userCol), IntegerType) - SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType) - val ratingType = schema($(ratingCol)).dataType - require(ratingType == FloatType || ratingType == DoubleType) + // user and item will be cast to Int + SchemaUtils.checkNumericType(schema, $(userCol)) + SchemaUtils.checkNumericType(schema, $(itemCol)) + // rating will be cast to Float + SchemaUtils.checkNumericType(schema, $(ratingCol)) SchemaUtils.appendColumn(schema, $(predictionCol), FloatType) } } /** - * :: Experimental :: * Model fitted by ALS. * * @param rank rank of the matrix factorization model * @param userFactors a DataFrame that stores user factors in two columns: `id` and `features` * @param itemFactors a DataFrame that stores item factors in two columns: `id` and `features` */ -@Experimental @Since("1.3.0") class ALSModel private[ml] ( @Since("1.4.0") override val uid: String, @@ -232,6 +250,7 @@ class ALSModel private[ml] ( @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { + transformSchema(dataset.schema) // Register a UDF for DataFrame, and then // create a new column named map(predictionCol) by running the predict UDF. 
val predict = udf { (userFeatures: Seq[Float], itemFeatures: Seq[Float]) => @@ -242,16 +261,19 @@ class ALSModel private[ml] ( } } dataset - .join(userFactors, dataset($(userCol)) === userFactors("id"), "left") - .join(itemFactors, dataset($(itemCol)) === itemFactors("id"), "left") + .join(userFactors, + checkedCast(dataset($(userCol)).cast(DoubleType)) === userFactors("id"), "left") + .join(itemFactors, + checkedCast(dataset($(itemCol)).cast(DoubleType)) === itemFactors("id"), "left") .select(dataset("*"), predict(userFactors("features"), itemFactors("features")).as($(predictionCol))) } @Since("1.3.0") override def transformSchema(schema: StructType): StructType = { - SchemaUtils.checkColumnType(schema, $(userCol), IntegerType) - SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType) + // user and item will be cast to Int + SchemaUtils.checkNumericType(schema, $(userCol)) + SchemaUtils.checkNumericType(schema, $(itemCol)) SchemaUtils.appendColumn(schema, $(predictionCol), FloatType) } @@ -296,9 +318,9 @@ object ALSModel extends MLReadable[ALSModel] { implicit val format = DefaultFormats val rank = (metadata.metadata \ "rank").extract[Int] val userPath = new Path(path, "userFactors").toString - val userFactors = sqlContext.read.format("parquet").load(userPath) + val userFactors = sparkSession.read.format("parquet").load(userPath) val itemPath = new Path(path, "itemFactors").toString - val itemFactors = sqlContext.read.format("parquet").load(itemPath) + val itemFactors = sparkSession.read.format("parquet").load(itemPath) val model = new ALSModel(metadata.uid, rank, userFactors, itemFactors) @@ -309,7 +331,6 @@ object ALSModel extends MLReadable[ALSModel] { } /** - * :: Experimental :: * Alternating Least Squares (ALS) matrix factorization. * * ALS attempts to estimate the ratings matrix `R` as the product of two lower-rank matrices, @@ -330,15 +351,14 @@ object ALSModel extends MLReadable[ALSModel] { * * For implicit preference data, the algorithm used is based on * "Collaborative Filtering for Implicit Feedback Datasets", available at - * [[http://dx.doi.org/10.1109/ICDM.2008.22]], adapted for the blocked approach used here. + * http://dx.doi.org/10.1109/ICDM.2008.22, adapted for the blocked approach used here. * * Essentially instead of finding the low-rank approximations to the rating matrix `R`, * this finds the approximations for a preference matrix `P` where the elements of `P` are 1 if - * r > 0 and 0 if r <= 0. The ratings then act as 'confidence' values related to strength of + * r > 0 and 0 if r <= 0. The ratings then act as 'confidence' values related to strength of * indicated user * preferences rather than explicit ratings given to items. */ -@Experimental @Since("1.3.0") class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel] with ALSParams with DefaultParamsWritable { @@ -406,15 +426,11 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel] /** @group expertSetParam */ @Since("2.0.0") - def setIntermediateStorageLevel(value: String): this.type = { - set(intermediateStorageLevel, value) - } + def setIntermediateStorageLevel(value: String): this.type = set(intermediateStorageLevel, value) /** @group expertSetParam */ @Since("2.0.0") - def setFinalStorageLevel(value: String): this.type = { - set(finalStorageLevel, value) - } + def setFinalStorageLevel(value: String): this.type = set(finalStorageLevel, value) /** * Sets both numUserBlocks and numItemBlocks to the specific value. 
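For illustration, a minimal sketch (not part of the patch; the column names, values, and the `spark` session are assumed, as in a spark-shell) of how ALS now accepts numeric id columns: ids are cast to Int internally and an out-of-range value fails fast with an IllegalArgumentException rather than misbehaving silently.

```scala
import org.apache.spark.ml.recommendation.ALS
import spark.implicits._

// User and item ids are Longs here; they are accepted as long as every value
// fits in the Int range. Ratings may be any numeric type and are cast to Float.
val ratings = Seq(
  (1L, 10L, 4.0), (1L, 11L, 1.0),
  (2L, 10L, 5.0), (2L, 12L, 2.0)
).toDF("user", "item", "rating")

val als = new ALS()
  .setUserCol("user")
  .setItemCol("item")
  .setRatingCol("rating")
  .setRank(5)
  .setMaxIter(5)

val model = als.fit(ratings)
model.transform(ratings).show()
```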
@@ -430,10 +446,13 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel] @Since("2.0.0") override def fit(dataset: Dataset[_]): ALSModel = { + transformSchema(dataset.schema) import dataset.sparkSession.implicits._ + val r = if ($(ratingCol) != "") col($(ratingCol)).cast(FloatType) else lit(1.0f) val ratings = dataset - .select(col($(userCol)).cast(IntegerType), col($(itemCol)).cast(IntegerType), r) + .select(checkedCast(col($(userCol)).cast(DoubleType)), + checkedCast(col($(itemCol)).cast(DoubleType)), r) .rdd .map { row => Rating(row.getInt(0), row.getInt(1), row.getFloat(2)) @@ -706,13 +725,15 @@ object ALS extends DefaultParamsReadable[ALS] with Logging { previousItemFactors.unpersist() itemFactors.setName(s"itemFactors-$iter").persist(intermediateRDDStorageLevel) // TODO: Generalize PeriodicGraphCheckpointer and use it here. + val deps = itemFactors.dependencies if (shouldCheckpoint(iter)) { - itemFactors.checkpoint() // itemFactors gets materialized in computeFactors. + itemFactors.checkpoint() // itemFactors gets materialized in computeFactors } val previousUserFactors = userFactors userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam, itemLocalIndexEncoder, implicitPrefs, alpha, solver) if (shouldCheckpoint(iter)) { + ALS.cleanShuffleDependencies(sc, deps) deletePreviousCheckpointFile() previousCheckpointFile = itemFactors.getCheckpointFile } @@ -723,8 +744,10 @@ object ALS extends DefaultParamsReadable[ALS] with Logging { itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam, userLocalIndexEncoder, solver = solver) if (shouldCheckpoint(iter)) { + val deps = itemFactors.dependencies itemFactors.checkpoint() itemFactors.count() // checkpoint item factors and cut lineage + ALS.cleanShuffleDependencies(sc, deps) deletePreviousCheckpointFile() previousCheckpointFile = itemFactors.getCheckpointFile } @@ -1355,4 +1378,31 @@ object ALS extends DefaultParamsReadable[ALS] with Logging { * satisfies this requirement, we simply use a type alias here. */ private[recommendation] type ALSPartitioner = org.apache.spark.HashPartitioner + + /** + * Private function to clean up all of the shuffle files from the dependencies and their parents. + */ + private[spark] def cleanShuffleDependencies[T]( + sc: SparkContext, + deps: Seq[Dependency[_]], + blocking: Boolean = false): Unit = { + // If there is no reference tracking we skip clean up. + sc.cleaner.foreach { cleaner => + /** + * Clean the shuffle and all of its parents. 
+ */ + def cleanEagerly(dep: Dependency[_]): Unit = { + if (dep.isInstanceOf[ShuffleDependency[_, _, _]]) { + val shuffleId = dep.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId + cleaner.doCleanupShuffle(shuffleId, blocking) + } + val rdd = dep.rdd + val rddDeps = rdd.dependencies + if (rdd.getStorageLevel == StorageLevel.NONE && rddDeps != null) { + rddDeps.foreach(cleanEagerly) + } + } + deps.foreach(cleanEagerly) + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index 89ba6ab5d2772..3ebaba6368837 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -27,11 +27,13 @@ import org.apache.spark.SparkException import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.internal.Logging import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors, VectorUDT} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ -import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors, VectorUDT} +import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer +import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ @@ -88,8 +90,8 @@ private[regression] trait AFTSurvivalRegressionParams extends Params def getQuantilesCol: String = $(quantilesCol) /** Checks whether the input has quantiles column name. */ - protected[regression] def hasQuantilesCol: Boolean = { - isDefined(quantilesCol) && $(quantilesCol) != "" + private[regression] def hasQuantilesCol: Boolean = { + isDefined(quantilesCol) && $(quantilesCol).nonEmpty } /** @@ -194,7 +196,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S @Since("2.0.0") override def fit(dataset: Dataset[_]): AFTSurvivalRegressionModel = { - validateAndTransformSchema(dataset.schema, fitting = true) + transformSchema(dataset.schema, logging = true) val instances = extractAFTPoints(dataset) val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) @@ -208,11 +210,18 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S } val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt) + val numFeatures = featuresStd.size + + if (!$(fitIntercept) && (0 until numFeatures).exists { i => + featuresStd(i) == 0.0 && featuresSummarizer.mean(i) != 0.0 }) { + logWarning("Fitting AFTSurvivalRegressionModel without intercept on dataset with " + + "constant nonzero column, Spark MLlib outputs zero coefficients for constant nonzero " + + "columns. 
This behavior is different from R survival::survreg.") + } val costFun = new AFTCostFun(instances, $(fitIntercept), featuresStd) val optimizer = new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol)) - val numFeatures = featuresStd.size /* The parameters vector has three parts: the first element: Double, log(sigma), the log of scale parameter @@ -222,7 +231,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S val initialParameters = Vectors.zeros(numFeatures + 2) val states = optimizer.iterations(new CachedDiffFunction(costFun), - initialParameters.toBreeze.toDenseVector) + initialParameters.asBreeze.toDenseVector) val parameters = { val arrayBuilder = mutable.ArrayBuilder.make[Double] @@ -236,6 +245,11 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S throw new SparkException(msg) } + if (!state.actuallyConverged) { + logWarning("AFTSurvivalRegression training finished but the result " + + s"is not converged because: ${state.convergedReason.get.reason}") + } + state.x.toArray.clone() } @@ -278,7 +292,7 @@ object AFTSurvivalRegression extends DefaultParamsReadable[AFTSurvivalRegression @Since("1.6.0") class AFTSurvivalRegressionModel private[ml] ( @Since("1.6.0") override val uid: String, - @Since("1.6.0") val coefficients: Vector, + @Since("2.0.0") val coefficients: Vector, @Since("1.6.0") val intercept: Double, @Since("1.6.0") val scale: Double) extends Model[AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams with MLWritable { @@ -299,7 +313,7 @@ class AFTSurvivalRegressionModel private[ml] ( @Since("1.6.0") def setQuantilesCol(value: String): this.type = set(quantilesCol, value) - @Since("1.6.0") + @Since("2.0.0") def predictQuantiles(features: Vector): Vector = { // scale parameter for the Weibull distribution of lifetime val lambda = math.exp(BLAS.dot(coefficients, features) + intercept) @@ -311,14 +325,14 @@ class AFTSurvivalRegressionModel private[ml] ( Vectors.dense(quantiles) } - @Since("1.6.0") + @Since("2.0.0") def predict(features: Vector): Double = { math.exp(BLAS.dot(coefficients, features) + intercept) } @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { - transformSchema(dataset.schema) + transformSchema(dataset.schema, logging = true) val predictUDF = udf { features: Vector => predict(features) } val predictQuantilesUDF = udf { features: Vector => predictQuantiles(features)} if (hasQuantilesCol) { @@ -367,7 +381,7 @@ object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel] // Save model data: coefficients, intercept, scale val data = Data(instance.coefficients, instance.intercept, instance.scale) val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } } @@ -380,11 +394,11 @@ object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel] val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val data = sqlContext.read.parquet(dataPath) - .select("coefficients", "intercept", "scale").head() - val coefficients = data.getAs[Vector](0) - val intercept = data.getDouble(1) - val scale = data.getDouble(2) + val data = sparkSession.read.parquet(dataPath) + val Row(coefficients: Vector, intercept: Double, scale: Double) = + MLUtils.convertVectorColumnsToML(data, "coefficients") + .select("coefficients", "intercept", "scale") + 
.head() val model = new AFTSurvivalRegressionModel(metadata.uid, coefficients, intercept, scale) DefaultParamsReader.getAndSetParams(model, metadata) @@ -395,7 +409,7 @@ object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel] /** * AFTAggregator computes the gradient and loss for a AFT loss function, - * as used in AFT survival regression for samples in sparse or dense vector in a online fashion. + * as used in AFT survival regression for samples in sparse or dense vector in an online fashion. * * The loss function and likelihood function under the AFT model based on: * Lawless, J. F., Statistical Models and Methods for Lifetime Data, @@ -530,7 +544,7 @@ private class AFTAggregator( * @return This AFTAggregator object. */ def merge(other: AFTAggregator): this.type = { - if (totalCnt != 0) { + if (other.count != 0) { totalCnt += other.totalCnt lossSum += other.lossSum diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 874d2a81db216..ebc6c12ddcf92 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -21,15 +21,15 @@ import org.apache.hadoop.fs.Path import org.json4s.{DefaultFormats, JObject} import org.json4s.JsonDSL._ -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.ml.{PredictionModel, Predictor} +import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.tree._ import org.apache.spark.ml.tree.DecisionTreeModelReadWrite._ import org.apache.spark.ml.tree.impl.RandomForest import org.apache.spark.ml.util._ -import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel} import org.apache.spark.rdd.RDD @@ -38,13 +38,11 @@ import org.apache.spark.sql.functions._ /** - * :: Experimental :: * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] learning algorithm * for regression. * It supports both continuous and categorical features. */ @Since("1.4.0") -@Experimental class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) extends Predictor[Vector, DecisionTreeRegressor, DecisionTreeRegressionModel] with DecisionTreeRegressorParams with DefaultParamsWritable { @@ -125,7 +123,6 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S } @Since("1.4.0") -@Experimental object DecisionTreeRegressor extends DefaultParamsReadable[DecisionTreeRegressor] { /** Accessor for supported impurities: variance */ final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities @@ -135,13 +132,11 @@ object DecisionTreeRegressor extends DefaultParamsReadable[DecisionTreeRegressor } /** - * :: Experimental :: * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] model for regression. * It supports both continuous and categorical features. 
* @param rootNode Root of the decision tree */ @Since("1.4.0") -@Experimental class DecisionTreeRegressionModel private[ml] ( override val uid: String, override val rootNode: Node, @@ -249,7 +244,7 @@ object DecisionTreeRegressionModel extends MLReadable[DecisionTreeRegressionMode DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata)) val (nodeData, _) = NodeData.build(instance.rootNode, 0) val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(nodeData).write.parquet(dataPath) + sparkSession.createDataFrame(nodeData).write.parquet(dataPath) } } @@ -263,7 +258,7 @@ object DecisionTreeRegressionModel extends MLReadable[DecisionTreeRegressionMode implicit val format = DefaultFormats val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] - val root = loadTreeNodes(path, metadata, sqlContext) + val root = loadTreeNodes(path, metadata, sparkSession) val model = new DecisionTreeRegressionModel(metadata.uid, root, numFeatures) DefaultParamsReader.getAndSetParams(model, metadata) model diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index c41fb4b0629b2..ce355938ec1c7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -24,13 +24,13 @@ import org.json4s.JsonDSL._ import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.internal.Logging import org.apache.spark.ml.{PredictionModel, Predictor} +import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.tree._ import org.apache.spark.ml.tree.impl.GradientBoostedTrees import org.apache.spark.ml.util._ import org.apache.spark.ml.util.DefaultParamsReader.Metadata -import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel} import org.apache.spark.rdd.RDD @@ -38,7 +38,6 @@ import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ /** - * :: Experimental :: * [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]] * learning algorithm for regression. * It supports both continuous and categorical features. @@ -56,7 +55,6 @@ import org.apache.spark.sql.functions._ * [https://issues.apache.org/jira/browse/SPARK-4240] */ @Since("1.4.0") -@Experimental class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) extends Predictor[Vector, GBTRegressor, GBTRegressionModel] with GBTRegressorParams with DefaultParamsWritable with Logging { @@ -135,7 +133,6 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) } @Since("1.4.0") -@Experimental object GBTRegressor extends DefaultParamsReadable[GBTRegressor] { /** Accessor for supported loss settings: squared (L2), absolute (L1) */ @@ -147,8 +144,6 @@ object GBTRegressor extends DefaultParamsReadable[GBTRegressor] { } /** - * :: Experimental :: - * * [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]] * model for regression. * It supports both continuous and categorical features. 
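For illustration, a minimal sketch (assumed spark-shell session with `spark` predefined; the data is made up) of fitting GBTRegressor against the new `org.apache.spark.ml.linalg` vector type that these imports now target:

```scala
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.GBTRegressor
import spark.implicits._

// Feature vectors are built with the ml.linalg factory, not mllib.linalg.
val data = Seq(
  (1.0, Vectors.dense(0.0, 1.1)),
  (0.0, Vectors.dense(2.0, 1.0)),
  (1.5, Vectors.dense(0.5, 0.3))
).toDF("label", "features")

val gbt = new GBTRegressor()
  .setMaxIter(10)
  .setMaxDepth(3)

val model = gbt.fit(data)
model.transform(data).select("label", "prediction").show()
```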
@@ -156,7 +151,6 @@ object GBTRegressor extends DefaultParamsReadable[GBTRegressor] { * @param _treeWeights Weights for the decision trees in the ensemble. */ @Since("1.4.0") -@Experimental class GBTRegressionModel private[ml]( override val uid: String, private val _trees: Array[DecisionTreeRegressionModel], @@ -252,7 +246,7 @@ object GBTRegressionModel extends MLReadable[GBTRegressionModel] { val extraMetadata: JObject = Map( "numFeatures" -> instance.numFeatures, "numTrees" -> instance.getNumTrees) - EnsembleModelReadWrite.saveImpl(instance, path, sqlContext, extraMetadata) + EnsembleModelReadWrite.saveImpl(instance, path, sparkSession, extraMetadata) } } @@ -265,7 +259,7 @@ object GBTRegressionModel extends MLReadable[GBTRegressionModel] { override def load(path: String): GBTRegressionModel = { implicit val format = DefaultFormats val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) = - EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName) + EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName) val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] val numTrees = (metadata.metadata \ "numTrees").extract[Int] diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index c294ef31f90de..7f88c12d7e26e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -25,11 +25,11 @@ import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.internal.Logging import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.feature.Instance +import org.apache.spark.ml.linalg.{BLAS, Vector} import org.apache.spark.ml.optim._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ -import org.apache.spark.mllib.linalg.{BLAS, Vector} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ @@ -43,6 +43,8 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam with HasFitIntercept with HasMaxIter with HasTol with HasRegParam with HasWeightCol with HasSolver with Logging { + import GeneralizedLinearRegression._ + /** * Param for the name of family which is a description of the error distribution * to be used in the model. @@ -54,8 +56,8 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam @Since("2.0.0") final val family: Param[String] = new Param(this, "family", "The name of family which is a description of the error distribution to be used in the " + - "model. Supported options: gaussian(default), binomial, poisson and gamma.", - ParamValidators.inArray[String](GeneralizedLinearRegression.supportedFamilyNames.toArray)) + s"model. Supported options: ${supportedFamilyNames.mkString(", ")}.", + ParamValidators.inArray[String](supportedFamilyNames.toArray)) /** @group getParam */ @Since("2.0.0") @@ -71,9 +73,8 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam @Since("2.0.0") final val link: Param[String] = new Param(this, "link", "The name of link function " + "which provides the relationship between the linear predictor and the mean of the " + - "distribution function. 
Supported options: identity, log, inverse, logit, probit, " + - "cloglog and sqrt.", - ParamValidators.inArray[String](GeneralizedLinearRegression.supportedLinkNames.toArray)) + s"distribution function. Supported options: ${supportedLinkNames.mkString(", ")}", + ParamValidators.inArray[String](supportedLinkNames.toArray)) /** @group getParam */ @Since("2.0.0") @@ -81,19 +82,23 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam /** * Param for link prediction (linear predictor) column name. - * Default is empty, which means we do not output link prediction. + * Default is not set, which means we do not output link prediction. * * @group param */ @Since("2.0.0") final val linkPredictionCol: Param[String] = new Param[String](this, "linkPredictionCol", "link prediction (linear predictor) column name") - setDefault(linkPredictionCol, "") /** @group getParam */ @Since("2.0.0") def getLinkPredictionCol: String = $(linkPredictionCol) + /** Checks whether we should output link prediction. */ + private[regression] def hasLinkPredictionCol: Boolean = { + isDefined(linkPredictionCol) && $(linkPredictionCol).nonEmpty + } + import GeneralizedLinearRegression._ @Since("2.0.0") @@ -107,7 +112,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam s"with ${$(family)} family does not support ${$(link)} link function.") } val newSchema = super.validateAndTransformSchema(schema, fitting, featuresDataType) - if ($(linkPredictionCol).nonEmpty) { + if (hasLinkPredictionCol) { SchemaUtils.appendColumn(newSchema, $(linkPredictionCol), DoubleType) } else { newSchema @@ -205,7 +210,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val /** * Sets the value of param [[weightCol]]. * If this is not set or empty, we treat all instance weights as 1.0. - * Default is empty, so all instances have weight one. + * Default is not set, so all instances have weight one. * * @group setParam */ @@ -214,7 +219,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val /** * Sets the solver algorithm used for optimization. - * Currently only support "irls" which is also the default solver. + * Currently only supports "irls" which is also the default solver. * * @group setParam */ @@ -239,10 +244,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val } val familyAndLink = new FamilyAndLink(familyObj, linkObj) - val numFeatures = dataset.select(col($(featuresCol))).limit(1).rdd - .map { case Row(features: Vector) => - features.size - }.first() + val numFeatures = dataset.select(col($(featuresCol))).first().getAs[Vector](0).size if (numFeatures > WeightedLeastSquares.MAX_NUM_FEATURES) { val msg = "Currently, GeneralizedLinearRegression only supports number of features" + s" <= ${WeightedLeastSquares.MAX_NUM_FEATURES}. Found $numFeatures in the input dataset." @@ -294,7 +296,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine override def load(path: String): GeneralizedLinearRegression = super.load(path) /** Set of family and link pairs that GeneralizedLinearRegression supports. 
*/ - private[ml] lazy val supportedFamilyAndLinkPairs = Set( + private[regression] lazy val supportedFamilyAndLinkPairs = Set( Gaussian -> Identity, Gaussian -> Log, Gaussian -> Inverse, Binomial -> Logit, Binomial -> Probit, Binomial -> CLogLog, Poisson -> Log, Poisson -> Identity, Poisson -> Sqrt, @@ -302,17 +304,17 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine ) /** Set of family names that GeneralizedLinearRegression supports. */ - private[ml] lazy val supportedFamilyNames = supportedFamilyAndLinkPairs.map(_._1.name) + private[regression] lazy val supportedFamilyNames = supportedFamilyAndLinkPairs.map(_._1.name) /** Set of link names that GeneralizedLinearRegression supports. */ - private[ml] lazy val supportedLinkNames = supportedFamilyAndLinkPairs.map(_._2.name) + private[regression] lazy val supportedLinkNames = supportedFamilyAndLinkPairs.map(_._2.name) - private[ml] val epsilon: Double = 1E-16 + private[regression] val epsilon: Double = 1E-16 /** * Wrapper of family and link combination used in the model. */ - private[ml] class FamilyAndLink(val family: Family, val link: Link) extends Serializable { + private[regression] class FamilyAndLink(val family: Family, val link: Link) extends Serializable { /** Linear predictor based on given mu. */ def predict(mu: Double): Double = link.link(family.project(mu)) @@ -359,7 +361,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine * * @param name the name of the family. */ - private[ml] abstract class Family(val name: String) extends Serializable { + private[regression] abstract class Family(val name: String) extends Serializable { /** The default link instance of this family. */ val defaultLink: Link @@ -374,7 +376,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine def deviance(y: Double, mu: Double, weight: Double): Double /** - * Akaike's 'An Information Criterion'(AIC) value of the family for a given dataset. + * Akaike Information Criterion (AIC) value of the family for a given dataset. * * @param predictions an RDD of (y, mu, weight) of instances in evaluation dataset * @param deviance the deviance for the fitted model in evaluation dataset @@ -391,7 +393,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine def project(mu: Double): Double = mu } - private[ml] object Family { + private[regression] object Family { /** * Gets the [[Family]] object from its name. @@ -412,7 +414,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine * Gaussian exponential family distribution. * The default link for the Gaussian family is the identity link. */ - private[ml] object Gaussian extends Family("gaussian") { + private[regression] object Gaussian extends Family("gaussian") { val defaultLink: Link = Identity @@ -448,7 +450,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine * Binomial exponential family distribution. * The default link for the Binomial family is the logit link. */ - private[ml] object Binomial extends Family("binomial") { + private[regression] object Binomial extends Family("binomial") { val defaultLink: Link = Logit @@ -492,7 +494,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine * Poisson exponential family distribution. * The default link for the Poisson family is the log link. 
*/ - private[ml] object Poisson extends Family("poisson") { + private[regression] object Poisson extends Family("poisson") { val defaultLink: Link = Log @@ -533,7 +535,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine * Gamma exponential family distribution. * The default link for the Gamma family is the inverse link. */ - private[ml] object Gamma extends Family("gamma") { + private[regression] object Gamma extends Family("gamma") { val defaultLink: Link = Inverse @@ -578,7 +580,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine * * @param name the name of link function. */ - private[ml] abstract class Link(val name: String) extends Serializable { + private[regression] abstract class Link(val name: String) extends Serializable { /** The link function. */ def link(mu: Double): Double @@ -590,7 +592,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine def unlink(eta: Double): Double } - private[ml] object Link { + private[regression] object Link { /** * Gets the [[Link]] object from its name. @@ -611,7 +613,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine } } - private[ml] object Identity extends Link("identity") { + private[regression] object Identity extends Link("identity") { override def link(mu: Double): Double = mu @@ -620,7 +622,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine override def unlink(eta: Double): Double = eta } - private[ml] object Logit extends Link("logit") { + private[regression] object Logit extends Link("logit") { override def link(mu: Double): Double = math.log(mu / (1.0 - mu)) @@ -629,7 +631,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine override def unlink(eta: Double): Double = 1.0 / (1.0 + math.exp(-1.0 * eta)) } - private[ml] object Log extends Link("log") { + private[regression] object Log extends Link("log") { override def link(mu: Double): Double = math.log(mu) @@ -638,7 +640,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine override def unlink(eta: Double): Double = math.exp(eta) } - private[ml] object Inverse extends Link("inverse") { + private[regression] object Inverse extends Link("inverse") { override def link(mu: Double): Double = 1.0 / mu @@ -647,7 +649,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine override def unlink(eta: Double): Double = 1.0 / eta } - private[ml] object Probit extends Link("probit") { + private[regression] object Probit extends Link("probit") { override def link(mu: Double): Double = dist.Gaussian(0.0, 1.0).icdf(mu) @@ -658,7 +660,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine override def unlink(eta: Double): Double = dist.Gaussian(0.0, 1.0).cdf(eta) } - private[ml] object CLogLog extends Link("cloglog") { + private[regression] object CLogLog extends Link("cloglog") { override def link(mu: Double): Double = math.log(-1.0 * math.log(1 - mu)) @@ -667,7 +669,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine override def unlink(eta: Double): Double = 1.0 - math.exp(-1.0 * math.exp(eta)) } - private[ml] object Sqrt extends Link("sqrt") { + private[regression] object Sqrt extends Link("sqrt") { override def link(mu: Double): Double = math.sqrt(mu) @@ -700,13 +702,13 @@ class GeneralizedLinearRegressionModel private[ml] ( import GeneralizedLinearRegression._ - lazy val 
familyObj = Family.fromName($(family)) - lazy val linkObj = if (isDefined(link)) { + private lazy val familyObj = Family.fromName($(family)) + private lazy val linkObj = if (isDefined(link)) { Link.fromName($(link)) } else { familyObj.defaultLink } - lazy val familyAndLink = new FamilyAndLink(familyObj, linkObj) + private lazy val familyAndLink = new FamilyAndLink(familyObj, linkObj) override protected def predict(features: Vector): Double = { val eta = predictLink(features) @@ -732,7 +734,7 @@ class GeneralizedLinearRegressionModel private[ml] ( if ($(predictionCol).nonEmpty) { output = output.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) } - if ($(linkPredictionCol).nonEmpty) { + if (hasLinkPredictionCol) { output = output.withColumn($(linkPredictionCol), predictLinkUDF(col($(featuresCol)))) } output.toDF() @@ -776,6 +778,13 @@ class GeneralizedLinearRegressionModel private[ml] ( .setParent(parent) } + /** + * Returns a [[org.apache.spark.ml.util.MLWriter]] instance for this ML instance. + * + * For [[GeneralizedLinearRegressionModel]], this does NOT currently save the + * training [[summary]]. An option to save [[summary]] may be added in the future. + * + */ @Since("2.0.0") override def write: MLWriter = new GeneralizedLinearRegressionModel.GeneralizedLinearRegressionModelWriter(this) @@ -804,7 +813,7 @@ object GeneralizedLinearRegressionModel extends MLReadable[GeneralizedLinearRegr // Save model data: intercept, coefficients val data = Data(instance.intercept, instance.coefficients) val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } } @@ -818,7 +827,7 @@ object GeneralizedLinearRegressionModel extends MLReadable[GeneralizedLinearRegr val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val data = sqlContext.read.parquet(dataPath) + val data = sparkSession.read.parquet(dataPath) .select("intercept", "coefficients").head() val intercept = data.getDouble(0) val coefficients = data.getAs[Vector](1) @@ -848,12 +857,12 @@ class GeneralizedLinearRegressionSummary private[regression] ( import GeneralizedLinearRegression._ /** - * Field in "predictions" which gives the prediction value of each instance. + * Field in "predictions" which gives the predicted value of each instance. * This is set to a new column name if the original model's `predictionCol` is not set. */ @Since("2.0.0") val predictionCol: String = { - if (origModel.isDefined(origModel.predictionCol) && origModel.getPredictionCol != "") { + if (origModel.isDefined(origModel.predictionCol) && origModel.getPredictionCol.nonEmpty) { origModel.getPredictionCol } else { "prediction_" + java.util.UUID.randomUUID.toString @@ -870,7 +879,7 @@ class GeneralizedLinearRegressionSummary private[regression] ( protected val model: GeneralizedLinearRegressionModel = origModel.copy(ParamMap.empty).setPredictionCol(predictionCol) - /** predictions output by the model's `transform` method */ + /** Predictions output by the model's `transform` method. 
*/ @Since("2.0.0") @transient val predictions: DataFrame = model.transform(dataset) private[regression] lazy val family: Family = Family.fromName(model.getFamily) @@ -880,10 +889,10 @@ class GeneralizedLinearRegressionSummary private[regression] ( family.defaultLink } - /** Number of instances in DataFrame predictions */ + /** Number of instances in DataFrame predictions. */ private[regression] lazy val numInstances: Long = predictions.count() - /** The numeric rank of the fitted linear model */ + /** The numeric rank of the fitted linear model. */ @Since("2.0.0") lazy val rank: Long = if (model.getFitIntercept) { model.coefficients.size + 1 @@ -891,17 +900,17 @@ class GeneralizedLinearRegressionSummary private[regression] ( model.coefficients.size } - /** Degrees of freedom */ + /** Degrees of freedom. */ @Since("2.0.0") lazy val degreesOfFreedom: Long = { numInstances - rank } - /** The residual degrees of freedom */ + /** The residual degrees of freedom. */ @Since("2.0.0") lazy val residualDegreeOfFreedom: Long = degreesOfFreedom - /** The residual degrees of freedom for the null model */ + /** The residual degrees of freedom for the null model. */ @Since("2.0.0") lazy val residualDegreeOfFreedomNull: Long = if (model.getFitIntercept) { numInstances - 1 @@ -944,7 +953,7 @@ class GeneralizedLinearRegressionSummary private[regression] ( } /** - * Get the default residuals(deviance residuals) of the fitted model. + * Get the default residuals (deviance residuals) of the fitted model. */ @Since("2.0.0") def residuals(): DataFrame = devianceResiduals @@ -979,7 +988,7 @@ class GeneralizedLinearRegressionSummary private[regression] ( } else { link.unlink(0.0) } - predictions.select(col(model.getLabelCol), w).rdd.map { + predictions.select(col(model.getLabelCol).cast(DoubleType), w).rdd.map { case Row(y: Double, weight: Double) => family.deviance(y, wtdmu, weight) }.sum() @@ -991,7 +1000,7 @@ class GeneralizedLinearRegressionSummary private[regression] ( @Since("2.0.0") lazy val deviance: Double = { val w = weightCol - predictions.select(col(model.getLabelCol), col(predictionCol), w).rdd.map { + predictions.select(col(model.getLabelCol).cast(DoubleType), col(predictionCol), w).rdd.map { case Row(label: Double, pred: Double, weight: Double) => family.deviance(label, pred, weight) }.sum() @@ -1000,7 +1009,7 @@ class GeneralizedLinearRegressionSummary private[regression] ( /** * The dispersion of the fitted model. * It is taken as 1.0 for the "binomial" and "poisson" families, and otherwise - * estimated by the residual Pearson's Chi-Squared statistic(which is defined as + * estimated by the residual Pearson's Chi-Squared statistic (which is defined as * sum of the squares of the Pearson residuals) divided by the residual degrees of freedom. */ @Since("2.0.0") @@ -1012,14 +1021,15 @@ class GeneralizedLinearRegressionSummary private[regression] ( rss / degreesOfFreedom } - /** Akaike's "An Information Criterion"(AIC) for the fitted model. */ + /** Akaike Information Criterion (AIC) for the fitted model. 
*/ @Since("2.0.0") lazy val aic: Double = { val w = weightCol val weightSum = predictions.select(w).agg(sum(w)).first().getDouble(0) - val t = predictions.select(col(model.getLabelCol), col(predictionCol), w).rdd.map { - case Row(label: Double, pred: Double, weight: Double) => - (label, pred, weight) + val t = predictions.select( + col(model.getLabelCol).cast(DoubleType), col(predictionCol), w).rdd.map { + case Row(label: Double, pred: Double, weight: Double) => + (label, pred, weight) } family.aic(t, deviance, numInstances, weightSum) + 2 * rank } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala index 7a78ecbdf16de..cd7b4f2a9c56e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -19,14 +19,14 @@ package org.apache.spark.ml.regression import org.apache.hadoop.fs.Path -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.internal.Logging import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.regression.IsotonicRegressionModel.IsotonicRegressionModelWriter import org.apache.spark.ml.util._ -import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.mllib.regression.{IsotonicRegression => MLlibIsotonicRegression} import org.apache.spark.mllib.regression.{IsotonicRegressionModel => MLlibIsotonicRegressionModel} import org.apache.spark.rdd.RDD @@ -69,8 +69,8 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures setDefault(isotonic -> true, featureIndex -> 0) /** Checks whether the input has weight column. */ - protected[ml] def hasWeightCol: Boolean = { - isDefined(weightCol) && $(weightCol) != "" + private[regression] def hasWeightCol: Boolean = { + isDefined(weightCol) && $(weightCol).nonEmpty } /** @@ -120,7 +120,6 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures } /** - * :: Experimental :: * Isotonic regression. * * Currently implemented using parallelized pool adjacent violators algorithm. @@ -129,7 +128,6 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures * Uses [[org.apache.spark.mllib.regression.IsotonicRegression]]. */ @Since("1.5.0") -@Experimental class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: String) extends Estimator[IsotonicRegressionModel] with IsotonicRegressionBase with DefaultParamsWritable { @@ -166,7 +164,7 @@ class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: Stri @Since("2.0.0") override def fit(dataset: Dataset[_]): IsotonicRegressionModel = { - validateAndTransformSchema(dataset.schema, fitting = true) + transformSchema(dataset.schema, logging = true) // Extract columns from data. If dataset is persisted, do not persist oldDataset. val instances = extractWeightedLabeledPoints(dataset) val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE @@ -192,7 +190,6 @@ object IsotonicRegression extends DefaultParamsReadable[IsotonicRegression] { } /** - * :: Experimental :: * Model fitted by IsotonicRegression. * Predicts using a piecewise linear function. 
* @@ -202,7 +199,6 @@ object IsotonicRegression extends DefaultParamsReadable[IsotonicRegression] { * model trained by [[org.apache.spark.mllib.regression.IsotonicRegression]]. */ @Since("1.5.0") -@Experimental class IsotonicRegressionModel private[ml] ( override val uid: String, private val oldModel: MLlibIsotonicRegressionModel) @@ -221,14 +217,14 @@ class IsotonicRegressionModel private[ml] ( def setFeatureIndex(value: Int): this.type = set(featureIndex, value) /** Boundaries in increasing order for which predictions are known. */ - @Since("1.5.0") + @Since("2.0.0") def boundaries: Vector = Vectors.dense(oldModel.boundaries) /** * Predictions associated with the boundaries at the same index, monotone because of isotonic * regression. */ - @Since("1.5.0") + @Since("2.0.0") def predictions: Vector = Vectors.dense(oldModel.predictions) @Since("1.5.0") @@ -238,6 +234,7 @@ class IsotonicRegressionModel private[ml] ( @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { + transformSchema(dataset.schema, logging = true) val predict = dataset.schema($(featuresCol)).dataType match { case DoubleType => udf { feature: Double => oldModel.predict(feature) } @@ -284,7 +281,7 @@ object IsotonicRegressionModel extends MLReadable[IsotonicRegressionModel] { val data = Data( instance.oldModel.boundaries, instance.oldModel.predictions, instance.oldModel.isotonic) val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } } @@ -297,7 +294,7 @@ object IsotonicRegressionModel extends MLReadable[IsotonicRegressionModel] { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val data = sqlContext.read.parquet(dataPath) + val data = sparkSession.read.parquet(dataPath) .select("boundaries", "predictions", "isotonic").head() val boundaries = data.getAs[Seq[Double]](0).toArray val predictions = data.getAs[Seq[Double]](1).toArray diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index d13b15fd82f05..600bbcbac5e81 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -28,15 +28,18 @@ import org.apache.spark.SparkException import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.internal.Logging import org.apache.spark.ml.feature.Instance +import org.apache.spark.ml.linalg.{Vector, Vectors} +import org.apache.spark.ml.linalg.BLAS._ import org.apache.spark.ml.optim.WeightedLeastSquares import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.evaluation.RegressionMetrics -import org.apache.spark.mllib.linalg.{Vector, Vectors} -import org.apache.spark.mllib.linalg.BLAS._ +import org.apache.spark.mllib.linalg.{Vectors => OldVectors} +import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer +import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ @@ -51,7 +54,6 @@ private[regression] trait LinearRegressionParams extends PredictorParams 
with HasFitIntercept with HasStandardization with HasWeightCol with HasSolver /** - * :: Experimental :: * Linear regression. * * The learning objective is to minimize the squared error, with regularization. @@ -65,7 +67,6 @@ private[regression] trait LinearRegressionParams extends PredictorParams * - L2 + L1 (elastic net) */ @Since("1.3.0") -@Experimental class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String) extends Regressor[Vector, LinearRegression, LinearRegressionModel] with LinearRegressionParams with DefaultParamsWritable with Logging { @@ -159,22 +160,21 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String override protected def train(dataset: Dataset[_]): LinearRegressionModel = { // Extract the number of features before deciding optimization solver. - val numFeatures = dataset.select(col($(featuresCol))).limit(1).rdd.map { - case Row(features: Vector) => features.size - }.first() + val numFeatures = dataset.select(col($(featuresCol))).first().getAs[Vector](0).size val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) + val instances: RDD[Instance] = dataset.select( + col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map { + case Row(label: Double, weight: Double, features: Vector) => + Instance(label, weight, features) + } + if (($(solver) == "auto" && $(elasticNetParam) == 0.0 && numFeatures <= WeightedLeastSquares.MAX_NUM_FEATURES) || $(solver) == "normal") { require($(elasticNetParam) == 0.0, "Only L2 regularization can be used when normal " + "solver is used.'") // For low dimensional data, WeightedLeastSquares is more efficient since the // training algorithm only requires one pass through the data. (SPARK-10668) val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam), $(standardization), true) @@ -197,12 +197,6 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String return lrModel.setSummary(trainingSummary) } - val instances: RDD[Instance] = - dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map { - case Row(label: Double, weight: Double, features: Vector) => - Instance(label, weight, features) - } - val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) @@ -240,7 +234,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String val coefficients = Vectors.sparse(numFeatures, Seq()) val intercept = yMean - val model = new LinearRegressionModel(uid, coefficients, intercept) + val model = copyValues(new LinearRegressionModel(uid, coefficients, intercept)) // Handle possible missing or invalid prediction columns val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol() @@ -252,7 +246,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String model, Array(0D), Array(0D)) - return copyValues(model.setSummary(trainingSummary)) + return model.setSummary(trainingSummary) } else { require($(regParam) == 0.0, "The standard deviation of the label is zero. 
" + "Model cannot be regularized.") @@ -262,11 +256,18 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String } // if y is constant (rawYStd is zero), then y cannot be scaled. In this case - // setting yStd=1.0 ensures that y is not scaled anymore in l-bfgs algorithm. + // setting yStd=abs(yMean) ensures that y is not scaled anymore in l-bfgs algorithm. val yStd = if (rawYStd > 0) rawYStd else math.abs(yMean) val featuresMean = featuresSummarizer.mean.toArray val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt) + if (!$(fitIntercept) && (0 until numFeatures).exists { i => + featuresStd(i) == 0.0 && featuresMean(i) != 0.0 }) { + logWarning("Fitting LinearRegressionModel without intercept on dataset with " + + "constant nonzero column, Spark MLlib outputs zero coefficients for constant nonzero " + + "columns. This behavior is the same as R glmnet but different from LIBSVM.") + } + // Since we implicitly do the feature scaling when we compute the cost function // to improve the convergence, the effective regParam will be changed. val effectiveRegParam = $(regParam) / yStd @@ -297,7 +298,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String val initialCoefficients = Vectors.zeros(numFeatures) val states = optimizer.iterations(new CachedDiffFunction(costFun), - initialCoefficients.toBreeze.toDenseVector) + initialCoefficients.asBreeze.toDenseVector) val (coefficients, objectiveHistory) = { /* @@ -319,6 +320,11 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String throw new SparkException(msg) } + if (!state.actuallyConverged) { + logWarning("LinearRegression training finished but the result " + + s"is not converged because: ${state.convergedReason.get.reason}") + } + /* The coefficients are trained in the scaled space; we're converting them back to the original space. @@ -374,15 +380,13 @@ object LinearRegression extends DefaultParamsReadable[LinearRegression] { } /** - * :: Experimental :: * Model produced by [[LinearRegression]]. */ @Since("1.3.0") -@Experimental class LinearRegressionModel private[ml] ( - override val uid: String, - val coefficients: Vector, - val intercept: Double) + @Since("1.4.0") override val uid: String, + @Since("2.0.0") val coefficients: Vector, + @Since("1.3.0") val intercept: Double) extends RegressionModel[Vector, LinearRegressionModel] with LinearRegressionParams with MLWritable { @@ -447,7 +451,7 @@ class LinearRegressionModel private[ml] ( } /** - * Returns a [[MLWriter]] instance for this ML instance. + * Returns a [[org.apache.spark.ml.util.MLWriter]] instance for this ML instance. * * For [[LinearRegressionModel]], this does NOT currently save the training [[summary]]. * An option to save [[summary]] may be added in the future. 
@@ -479,7 +483,7 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] { // Save model data: intercept, coefficients val data = Data(instance.intercept, instance.coefficients) val dataPath = new Path(path, "data").toString - sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) } } @@ -492,10 +496,11 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val data = sqlContext.read.format("parquet").load(dataPath) - .select("intercept", "coefficients").head() - val intercept = data.getDouble(0) - val coefficients = data.getAs[Vector](1) + val data = sparkSession.read.format("parquet").load(dataPath) + val Row(intercept: Double, coefficients: Vector) = + MLUtils.convertVectorColumnsToML(data, "coefficients") + .select("intercept", "coefficients") + .head() val model = new LinearRegressionModel(metadata.uid, coefficients, intercept) DefaultParamsReader.getAndSetParams(model, metadata) @@ -558,16 +563,18 @@ class LinearRegressionSummary private[regression] ( val predictionCol: String, val labelCol: String, val featuresCol: String, - @deprecated("The model field is deprecated and will be removed in 2.1.0.", "2.0.0") - val model: LinearRegressionModel, + private val privateModel: LinearRegressionModel, private val diagInvAtWA: Array[Double]) extends Serializable { + @deprecated("The model field is deprecated and will be removed in 2.1.0.", "2.0.0") + val model: LinearRegressionModel = privateModel + @transient private val metrics = new RegressionMetrics( predictions .select(col(predictionCol), col(labelCol).cast(DoubleType)) .rdd .map { case Row(pred: Double, label: Double) => (pred, label) }, - !model.getFitIntercept) + !privateModel.getFitIntercept) /** * Returns the explained variance regression score. @@ -631,10 +638,10 @@ class LinearRegressionSummary private[regression] ( lazy val numInstances: Long = predictions.count() /** Degrees of freedom */ - private val degreesOfFreedom: Long = if (model.getFitIntercept) { - numInstances - model.coefficients.size - 1 + private val degreesOfFreedom: Long = if (privateModel.getFitIntercept) { + numInstances - privateModel.coefficients.size - 1 } else { - numInstances - model.coefficients.size + numInstances - privateModel.coefficients.size } /** @@ -642,13 +649,15 @@ class LinearRegressionSummary private[regression] ( * the square root of the instance weights. */ lazy val devianceResiduals: Array[Double] = { - val weighted = if (!model.isDefined(model.weightCol) || model.getWeightCol.isEmpty) { - lit(1.0) - } else { - sqrt(col(model.getWeightCol)) - } - val dr = predictions.select(col(model.getLabelCol).minus(col(model.getPredictionCol)) - .multiply(weighted).as("weightedResiduals")) + val weighted = + if (!privateModel.isDefined(privateModel.weightCol) || privateModel.getWeightCol.isEmpty) { + lit(1.0) + } else { + sqrt(col(privateModel.getWeightCol)) + } + val dr = predictions + .select(col(privateModel.getLabelCol).minus(col(privateModel.getPredictionCol)) + .multiply(weighted).as("weightedResiduals")) .select(min(col("weightedResiduals")).as("min"), max(col("weightedResiduals")).as("max")) .first() Array(dr.getDouble(0), dr.getDouble(1)) @@ -668,14 +677,15 @@ class LinearRegressionSummary private[regression] ( throw new UnsupportedOperationException( "No Std. 
Error of coefficients available for this LinearRegressionModel") } else { - val rss = if (!model.isDefined(model.weightCol) || model.getWeightCol.isEmpty) { - meanSquaredError * numInstances - } else { - val t = udf { (pred: Double, label: Double, weight: Double) => - math.pow(label - pred, 2.0) * weight } - predictions.select(t(col(model.getPredictionCol), col(model.getLabelCol), - col(model.getWeightCol)).as("wse")).agg(sum(col("wse"))).first().getDouble(0) - } + val rss = + if (!privateModel.isDefined(privateModel.weightCol) || privateModel.getWeightCol.isEmpty) { + meanSquaredError * numInstances + } else { + val t = udf { (pred: Double, label: Double, weight: Double) => + math.pow(label - pred, 2.0) * weight } + predictions.select(t(col(privateModel.getPredictionCol), col(privateModel.getLabelCol), + col(privateModel.getWeightCol)).as("wse")).agg(sum(col("wse"))).first().getDouble(0) + } val sigma2 = rss / degreesOfFreedom diagInvAtWA.map(_ * sigma2).map(math.sqrt) } @@ -695,10 +705,10 @@ class LinearRegressionSummary private[regression] ( throw new UnsupportedOperationException( "No t-statistic available for this LinearRegressionModel") } else { - val estimate = if (model.getFitIntercept) { - Array.concat(model.coefficients.toArray, Array(model.intercept)) + val estimate = if (privateModel.getFitIntercept) { + Array.concat(privateModel.coefficients.toArray, Array(privateModel.intercept)) } else { - model.coefficients.toArray + privateModel.coefficients.toArray } estimate.zip(coefficientStandardErrors).map { x => x._1 / x._2 } } @@ -726,7 +736,7 @@ class LinearRegressionSummary private[regression] ( /** * LeastSquaresAggregator computes the gradient and loss for a Least-squared loss function, - * as used in linear regression for samples in sparse or dense vector in a online fashion. + * as used in linear regression for samples in sparse or dense vector in an online fashion. * * Two LeastSquaresAggregator can be merged together to have a summary of loss and gradient of * the corresponding joint dataset. @@ -779,27 +789,27 @@ class LinearRegressionSummary private[regression] ( * * Now, the first derivative of the objective function in scaled space is * {{{ - * \frac{\partial L}{\partial\w_i} = diff/N (x_i - \bar{x_i}) / \hat{x_i} + * \frac{\partial L}{\partial w_i} = diff/N (x_i - \bar{x_i}) / \hat{x_i} * }}} * However, ($x_i - \bar{x_i}$) will densify the computation, so it's not * an ideal formula when the training dataset is sparse format. * - * This can be addressed by adding the dense \bar{x_i} / \har{x_i} terms + * This can be addressed by adding the dense \bar{x_i} / \hat{x_i} terms * in the end by keeping the sum of diff. The first derivative of total * objective function from all the samples is * {{{ - * \frac{\partial L}{\partial\w_i} = + * \frac{\partial L}{\partial w_i} = * 1/N \sum_j diff_j (x_{ij} - \bar{x_i}) / \hat{x_i} - * = 1/N ((\sum_j diff_j x_{ij} / \hat{x_i}) - diffSum \bar{x_i}) / \hat{x_i}) + * = 1/N ((\sum_j diff_j x_{ij} / \hat{x_i}) - diffSum \bar{x_i} / \hat{x_i}) * = 1/N ((\sum_j diff_j x_{ij} / \hat{x_i}) + correction_i) * }}}, - * where correction_i = - diffSum \bar{x_i}) / \hat{x_i} + * where correction_i = - diffSum \bar{x_i} / \hat{x_i} * * A simple math can show that diffSum is actually zero, so we don't even * need to add the correction terms in the end. 
From the definition of diff, * {{{ * diffSum = \sum_j (\sum_i w_i(x_{ij} - \bar{x_i}) / \hat{x_i} - (y_j - \bar{y}) / \hat{y}) - * = N * (\sum_i w_i(\bar{x_i} - \bar{x_i}) / \hat{x_i} - (\bar{y_j} - \bar{y}) / \hat{y}) + * = N * (\sum_i w_i(\bar{x_i} - \bar{x_i}) / \hat{x_i} - (\bar{y} - \bar{y}) / \hat{y}) * = 0 * }}} * @@ -807,7 +817,7 @@ class LinearRegressionSummary private[regression] ( * the training dataset, which can be easily computed in distributed fashion, and is * sparse format friendly. * {{{ - * \frac{\partial L}{\partial\w_i} = 1/N ((\sum_j diff_j x_{ij} / \hat{x_i}) + * \frac{\partial L}{\partial w_i} = 1/N ((\sum_j diff_j x_{ij} / \hat{x_i}) * }}}, * * @param coefficients The coefficients corresponding to the features. diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 9605de72020f1..0ad00aa6f9280 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -20,15 +20,15 @@ package org.apache.spark.ml.regression import org.json4s.{DefaultFormats, JObject} import org.json4s.JsonDSL._ -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.ml.{PredictionModel, Predictor} +import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.tree._ import org.apache.spark.ml.tree.impl.RandomForest import org.apache.spark.ml.util._ import org.apache.spark.ml.util.DefaultParamsReader.Metadata -import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel} import org.apache.spark.rdd.RDD @@ -37,12 +37,10 @@ import org.apache.spark.sql.functions._ /** - * :: Experimental :: * [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] learning algorithm for regression. * It supports both continuous and categorical features. */ @Since("1.4.0") -@Experimental class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) extends Predictor[Vector, RandomForestRegressor, RandomForestRegressionModel] with RandomForestRegressorParams with DefaultParamsWritable { @@ -118,7 +116,6 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S } @Since("1.4.0") -@Experimental object RandomForestRegressor extends DefaultParamsReadable[RandomForestRegressor]{ /** Accessor for supported impurity settings: variance */ @Since("1.4.0") @@ -135,7 +132,6 @@ object RandomForestRegressor extends DefaultParamsReadable[RandomForestRegressor } /** - * :: Experimental :: * [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] model for regression. * It supports both continuous and categorical features. 
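As a usage illustration for the random forest regression estimator and model covered by these hunks (a sketch only; the data and column layout are assumptions):

```scala
import org.apache.spark.ml.regression.RandomForestRegressor

// Expects a DataFrame `training` with a "features" vector column and a "label" column.
val rf = new RandomForestRegressor()
  .setNumTrees(20)
  .setMaxDepth(5)

val rfModel = rf.fit(training)
val predictions = rfModel.transform(training)   // adds a "prediction" column
```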
* @@ -143,7 +139,6 @@ object RandomForestRegressor extends DefaultParamsReadable[RandomForestRegressor * @param numFeatures Number of features used by this model */ @Since("1.4.0") -@Experimental class RandomForestRegressionModel private[ml] ( override val uid: String, private val _trees: Array[DecisionTreeRegressionModel], @@ -244,7 +239,7 @@ object RandomForestRegressionModel extends MLReadable[RandomForestRegressionMode val extraMetadata: JObject = Map( "numFeatures" -> instance.numFeatures, "numTrees" -> instance.getNumTrees) - EnsembleModelReadWrite.saveImpl(instance, path, sqlContext, extraMetadata) + EnsembleModelReadWrite.saveImpl(instance, path, sparkSession, extraMetadata) } } @@ -257,7 +252,7 @@ object RandomForestRegressionModel extends MLReadable[RandomForestRegressionMode override def load(path: String): RandomForestRegressionModel = { implicit val format = DefaultFormats val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) = - EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName) + EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName) val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] val numTrees = (metadata.metadata \ "numTrees").extract[Int] diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMDataSource.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMDataSource.scala new file mode 100644 index 0000000000000..73d813064decb --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMDataSource.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.source.libsvm + +import org.apache.spark.ml.linalg.Vector +import org.apache.spark.sql.{DataFrame, DataFrameReader} + +/** + * `libsvm` package implements Spark SQL data source API for loading LIBSVM data as [[DataFrame]]. + * The loaded [[DataFrame]] has two columns: `label` containing labels stored as doubles and + * `features` containing feature vectors stored as [[Vector]]s. + * + * To use LIBSVM data source, you need to set "libsvm" as the format in [[DataFrameReader]] and + * optionally specify options, for example: + * {{{ + * // Scala + * val df = spark.read.format("libsvm") + * .option("numFeatures", "780") + * .load("data/mllib/sample_libsvm_data.txt") + * + * // Java + * Dataset df = spark.read().format("libsvm") + * .option("numFeatures, "780") + * .load("data/mllib/sample_libsvm_data.txt"); + * }}} + * + * LIBSVM data source supports the following options: + * - "numFeatures": number of features. + * If unspecified or nonpositive, the number of features will be determined automatically at the + * cost of one additional pass. 
+ * This is also useful when the dataset is already split into multiple files and you want to load + * them separately, because some features may not present in certain files, which leads to + * inconsistent feature dimensions. + * - "vectorType": feature vector type, "sparse" (default) or "dense". + * + * Note that this class is public for documentation purpose. Please don't use this class directly. + * Rather, use the data source API as illustrated above. + * + * @see [[https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/ LIBSVM datasets]] + */ +class LibSVMDataSource private() {} diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala index ba2e1e2bc269d..ac95b9272f1d8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -25,15 +25,16 @@ import org.apache.hadoop.io.{NullWritable, Text} import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat -import org.apache.spark.annotation.Since -import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} -import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.TaskContext +import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.sql.{DataFrame, DataFrameReader, Row, SparkSession} +import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.RowEncoder -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, JoinedRow} +import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.execution.command.CreateDataSourceTableUtils import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ @@ -51,7 +52,7 @@ private[libsvm] class LibSVMOutputWriter( new TextOutputFormat[NullWritable, Text]() { override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { val configuration = context.getConfiguration - val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID") + val uniqueWriteJobId = configuration.get(CreateDataSourceTableUtils.DATASOURCE_WRITEJOBUUID) val taskAttemptId = context.getTaskAttemptID val split = taskAttemptId.getTaskID.getId new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension") @@ -76,48 +77,21 @@ private[libsvm] class LibSVMOutputWriter( } } -/** - * `libsvm` package implements Spark SQL data source API for loading LIBSVM data as [[DataFrame]]. - * The loaded [[DataFrame]] has two columns: `label` containing labels stored as doubles and - * `features` containing feature vectors stored as [[Vector]]s. 
- * - * To use LIBSVM data source, you need to set "libsvm" as the format in [[DataFrameReader]] and - * optionally specify options, for example: - * {{{ - * // Scala - * val df = spark.read.format("libsvm") - * .option("numFeatures", "780") - * .load("data/mllib/sample_libsvm_data.txt") - * - * // Java - * DataFrame df = spark.read().format("libsvm") - * .option("numFeatures, "780") - * .load("data/mllib/sample_libsvm_data.txt"); - * }}} - * - * LIBSVM data source supports the following options: - * - "numFeatures": number of features. - * If unspecified or nonpositive, the number of features will be determined automatically at the - * cost of one additional pass. - * This is also useful when the dataset is already split into multiple files and you want to load - * them separately, because some features may not present in certain files, which leads to - * inconsistent feature dimensions. - * - "vectorType": feature vector type, "sparse" (default) or "dense". - * - * @see [[https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/ LIBSVM datasets]] - */ -@Since("1.6.0") -class DefaultSource extends FileFormat with DataSourceRegister { +/** @see [[LibSVMDataSource]] for public documentation. */ +// If this is moved or renamed, please update DataSource's backwardCompatibilityMap. +private[libsvm] class LibSVMFileFormat extends TextBasedFileFormat with DataSourceRegister { - @Since("1.6.0") override def shortName(): String = "libsvm" override def toString: String = "LibSVM" private def verifySchema(dataSchema: StructType): Unit = { - if (dataSchema.size != 2 || - (!dataSchema(0).dataType.sameType(DataTypes.DoubleType) - || !dataSchema(1).dataType.sameType(new VectorUDT()))) { + if ( + dataSchema.size != 2 || + !dataSchema(0).dataType.sameType(DataTypes.DoubleType) || + !dataSchema(1).dataType.sameType(new VectorUDT()) || + !(dataSchema(1).metadata.getLong("numFeatures").toInt > 0) + ) { throw new IOException(s"Illegal schema for libsvm data, schema=$dataSchema") } } @@ -126,17 +100,8 @@ class DefaultSource extends FileFormat with DataSourceRegister { sparkSession: SparkSession, options: Map[String, String], files: Seq[FileStatus]): Option[StructType] = { - Some( - StructType( - StructField("label", DoubleType, nullable = false) :: - StructField("features", new VectorUDT(), nullable = false) :: Nil)) - } - - override def prepareRead( - sparkSession: SparkSession, - options: Map[String, String], - files: Seq[FileStatus]): Map[String, String] = { - def computeNumFeatures(): Int = { + val numFeatures: Int = options.get("numFeatures").map(_.toInt).filter(_ > 0).getOrElse { + // Infers number of features if the user doesn't specify (a valid) one. 
val dataFiles = files.filterNot(_.getPath.getName startsWith "_") val path = if (dataFiles.length == 1) { dataFiles.head.getPath.toUri.toString @@ -151,11 +116,14 @@ class DefaultSource extends FileFormat with DataSourceRegister { MLUtils.computeNumFeatures(parsed) } - val numFeatures = options.get("numFeatures").filter(_.toInt > 0).getOrElse { - computeNumFeatures() - } + val featuresMetadata = new MetadataBuilder() + .putLong("numFeatures", numFeatures) + .build() - new CaseInsensitiveMap(options + ("numFeatures" -> numFeatures.toString)) + Some( + StructType( + StructField("label", DoubleType, nullable = false) :: + StructField("features", new VectorUDT(), nullable = false, featuresMetadata) :: Nil)) } override def prepareWrite( @@ -184,7 +152,7 @@ class DefaultSource extends FileFormat with DataSourceRegister { options: Map[String, String], hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { verifySchema(dataSchema) - val numFeatures = options("numFeatures").toInt + val numFeatures = dataSchema("features").metadata.getLong("numFeatures").toInt assert(numFeatures > 0) val sparse = options.getOrElse("vectorType", "sparse") == "sparse" @@ -193,8 +161,10 @@ class DefaultSource extends FileFormat with DataSourceRegister { sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) (file: PartitionedFile) => { - val points = - new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value) + val linesReader = new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) + + val points = linesReader .map(_.toString.trim) .filterNot(line => line.isEmpty || line.startsWith("#")) .map { line => @@ -203,25 +173,18 @@ class DefaultSource extends FileFormat with DataSourceRegister { } val converter = RowEncoder(dataSchema) - - val unsafeRowIterator = points.map { pt => - val features = if (sparse) pt.features.toSparse else pt.features.toDense - converter.toRow(Row(pt.label, features)) - } - - def toAttribute(f: StructField): AttributeReference = + val fullOutput = dataSchema.map { f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)() - - // Appends partition values - val fullOutput = (dataSchema ++ partitionSchema).map(toAttribute) + } val requiredOutput = fullOutput.filter { a => - requiredSchema.fieldNames.contains(a.name) || partitionSchema.fieldNames.contains(a.name) + requiredSchema.fieldNames.contains(a.name) } - val joinedRow = new JoinedRow() - val appendPartitionColumns = GenerateUnsafeProjection.generate(requiredOutput, fullOutput) - unsafeRowIterator.map { dataRow => - appendPartitionColumns(joinedRow(dataRow, file.partitionValues)) + val requiredColumns = GenerateUnsafeProjection.generate(requiredOutput, fullOutput) + + points.map { pt => + val features = if (sparse) pt.features.toSparse else pt.features.toDense + requiredColumns(converter.toRow(Row(pt.label, features))) } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala index f71d28cf59535..07e98a142b10e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala @@ -17,17 +17,14 @@ package org.apache.spark.ml.tree -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.ml.linalg.Vector import org.apache.spark.mllib.tree.impurity.ImpurityCalculator import 
org.apache.spark.mllib.tree.model.{ImpurityStats, InformationGainStats => OldInformationGainStats, Node => OldNode, Predict => OldPredict} /** - * :: DeveloperApi :: * Decision tree node interface. */ -@DeveloperApi sealed abstract class Node extends Serializable { // TODO: Add aggregate stats (once available). This will happen after we move the DecisionTree @@ -109,12 +106,10 @@ private[ml] object Node { } /** - * :: DeveloperApi :: * Decision tree leaf node. * @param prediction Prediction this node makes * @param impurity Impurity measure at this node (for training data) */ -@DeveloperApi class LeafNode private[ml] ( override val prediction: Double, override val impurity: Double, @@ -147,17 +142,15 @@ class LeafNode private[ml] ( } /** - * :: DeveloperApi :: * Internal Decision Tree node. * @param prediction Prediction this node would make if it were a leaf node * @param impurity Impurity measure at this node (for training data) - * @param gain Information gain value. - * Values < 0 indicate missing values; this quirk will be removed with future updates. + * @param gain Information gain value. Values less than 0 indicate missing values; + * this quirk will be removed with future updates. * @param leftChild Left-hand child node * @param rightChild Right-hand child node * @param split Information about the test used to split to the left or right child. */ -@DeveloperApi class InternalNode private[ml] ( override val prediction: Double, override val impurity: Double, @@ -167,6 +160,9 @@ class InternalNode private[ml] ( val split: Split, override private[ml] val impurityStats: ImpurityCalculator) extends Node { + // Note to developers: The constructor argument impurityStats should be reconsidered before we + // make the constructor public. We may be able to improve the representation. + override def toString: String = { s"InternalNode(prediction = $prediction, impurity = $impurity, split = $split)" } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala index a4287483d18ed..dff44e2d49ec8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala @@ -19,18 +19,16 @@ package org.apache.spark.ml.tree import java.util.Objects -import org.apache.spark.annotation.{DeveloperApi, Since} -import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.annotation.Since +import org.apache.spark.ml.linalg.Vector import org.apache.spark.mllib.tree.configuration.{FeatureType => OldFeatureType} import org.apache.spark.mllib.tree.model.{Split => OldSplit} /** - * :: DeveloperApi :: * Interface for a "Split," which specifies a test made at a decision tree node * to choose the left or right path. */ -@DeveloperApi sealed trait Split extends Serializable { /** Index of feature which this split tests */ @@ -67,14 +65,12 @@ private[tree] object Split { } /** - * :: DeveloperApi :: * Split which tests a categorical feature. * @param featureIndex Index of the feature to test * @param _leftCategories If the feature value is in this set of categories, then the split goes * left. Otherwise, it goes right. * @param numCategories Number of categories for this feature. */ -@DeveloperApi class CategoricalSplit private[ml] ( override val featureIndex: Int, _leftCategories: Array[Double], @@ -153,13 +149,11 @@ class CategoricalSplit private[ml] ( } /** - * :: DeveloperApi :: * Split which tests a continuous feature. 
* @param featureIndex Index of the feature to test - * @param threshold If the feature value is <= this threshold, then the split goes left. - * Otherwise, it goes right. + * @param threshold If the feature value is less than or equal to this threshold, then the + * split goes left. Otherwise, it goes right. */ -@DeveloperApi class ContinuousSplit private[ml] (override val featureIndex: Int, val threshold: Double) extends Split { diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala index 5f7c40f6071f6..442f52bf0231d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala @@ -21,8 +21,8 @@ import scala.collection.mutable import scala.util.Try import org.apache.spark.internal.Logging +import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.tree.RandomForestParams -import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ import org.apache.spark.mllib.tree.configuration.Strategy diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala index b6334762c7a7f..a0faff236e9d5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala @@ -18,10 +18,10 @@ package org.apache.spark.ml.tree.impl import org.apache.spark.internal.Logging +import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor} import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer -import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.configuration.{BoostingStrategy => OldBoostingStrategy} import org.apache.spark.mllib.tree.impurity.{Variance => OldVariance} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala index 800430f96c5b1..a7c5f489dea86 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala @@ -21,7 +21,7 @@ import java.io.IOException import scala.collection.mutable -import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.fs.Path import org.apache.spark.internal.Logging import org.apache.spark.ml.tree.{LearningNode, Split} @@ -77,8 +77,8 @@ private[spark] class NodeIdCache( // Indicates whether we can checkpoint private val canCheckpoint = nodeIdsForInstances.sparkContext.getCheckpointDir.nonEmpty - // FileSystem instance for deleting checkpoints as needed - private val fs = FileSystem.get(nodeIdsForInstances.sparkContext.hadoopConfiguration) + // Hadoop Configuration for deleting checkpoints as needed + private val hadoopConf = nodeIdsForInstances.sparkContext.hadoopConfiguration /** * Update the node index values in the cache. 
@@ -130,7 +130,9 @@ private[spark] class NodeIdCache( val old = checkpointQueue.dequeue() // Since the old checkpoint is not deleted by Spark, we'll manually delete it here. try { - fs.delete(new Path(old.getCheckpointFile.get), true) + val path = new Path(old.getCheckpointFile.get) + val fs = path.getFileSystem(hadoopConf) + fs.delete(path, true) } catch { case e: IOException => logError("Decision Tree learning using cacheNodeIds failed to remove checkpoint" + @@ -154,7 +156,9 @@ private[spark] class NodeIdCache( val old = checkpointQueue.dequeue() if (old.getCheckpointFile.isDefined) { try { - fs.delete(new Path(old.getCheckpointFile.get), true) + val path = new Path(old.getCheckpointFile.get) + val fs = path.getFileSystem(hadoopConf) + fs.delete(path, true) } catch { case e: IOException => logError("Decision Tree learning using cacheNodeIds failed to remove checkpoint" + diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 2038a6873db73..71c8c42ce5eba 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -24,10 +24,10 @@ import scala.util.Random import org.apache.spark.internal.Logging import org.apache.spark.ml.classification.DecisionTreeClassificationModel +import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.regression.DecisionTreeRegressionModel import org.apache.spark.ml.tree._ import org.apache.spark.ml.util.Instrumentation -import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} import org.apache.spark.mllib.tree.impurity.ImpurityCalculator import org.apache.spark.mllib.tree.model.ImpurityStats @@ -491,7 +491,7 @@ private[spark] object RandomForest extends Logging { timer.start("chooseSplits") // In each partition, iterate all instances and compute aggregate stats for each node, - // yield an (nodeIndex, nodeAggregateStats) pair for each node. + // yield a (nodeIndex, nodeAggregateStats) pair for each node. // After a `reduceByKey` operation, // stats of a node will be shuffled to a particular partition and be combined together, // then best splits for nodes are found there. 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala index 3a2bf3c725730..a6ac64a0463cc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala @@ -17,8 +17,8 @@ package org.apache.spark.ml.tree.impl +import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.tree.{ContinuousSplit, Split} -import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala index f38e1ec7c09a8..d3cbc363799a5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala @@ -23,15 +23,15 @@ import org.apache.hadoop.fs.Path import org.json4s._ import org.json4s.jackson.JsonMethods._ +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.{Param, Params} import org.apache.spark.ml.tree.DecisionTreeModelReadWrite.NodeData import org.apache.spark.ml.util.{DefaultParamsReader, DefaultParamsWriter} import org.apache.spark.ml.util.DefaultParamsReader.Metadata -import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.tree.impurity.ImpurityCalculator import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Dataset, SQLContext} +import org.apache.spark.sql.{Dataset, SparkSession} import org.apache.spark.util.collection.OpenHashMap /** @@ -332,8 +332,8 @@ private[ml] object DecisionTreeModelReadWrite { def loadTreeNodes( path: String, metadata: DefaultParamsReader.Metadata, - sqlContext: SQLContext): Node = { - import sqlContext.implicits._ + sparkSession: SparkSession): Node = { + import sparkSession.implicits._ implicit val format = DefaultFormats // Get impurity to construct ImpurityCalculator for each node @@ -343,7 +343,7 @@ private[ml] object DecisionTreeModelReadWrite { } val dataPath = new Path(path, "data").toString - val data = sqlContext.read.parquet(dataPath).as[NodeData] + val data = sparkSession.read.parquet(dataPath).as[NodeData] buildTreeFromNodes(data.collect(), impurityType) } @@ -393,7 +393,7 @@ private[ml] object EnsembleModelReadWrite { def saveImpl[M <: Params with TreeEnsembleModel[_ <: DecisionTreeModel]]( instance: M, path: String, - sql: SQLContext, + sql: SparkSession, extraMetadata: JObject): Unit = { DefaultParamsWriter.saveMetadata(instance, path, sql.sparkContext, Some(extraMetadata)) val treesMetadataWeights: Array[(Int, String, Double)] = instance.trees.zipWithIndex.map { @@ -415,16 +415,16 @@ private[ml] object EnsembleModelReadWrite { /** * Helper method for loading a tree ensemble from disk. * This reconstructs all trees, returning the root nodes. 
- * @param path Path given to [[saveImpl()]] + * @param path Path given to `saveImpl` * @param className Class name for ensemble model type * @param treeClassName Class name for tree model type in the ensemble * @return (ensemble metadata, array over trees of (tree metadata, root node)), * where the root node is linked with all descendents - * @see [[saveImpl()]] for how the model was saved + * @see `saveImpl` for how the model was saved */ def loadImpl( path: String, - sql: SQLContext, + sql: SparkSession, className: String, treeClassName: String): (Metadata, Array[(Metadata, Node)], Array[Double]) = { import sql.implicits._ diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index d7559f8950c3d..57c7e44e97607 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -151,7 +151,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams * [[org.apache.spark.SparkContext]]. * Must be >= 1. * (default = 10) - * @group expertSetParam + * @group setParam */ def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index a41d02cde755b..520557849b9e2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -25,12 +25,11 @@ import com.github.fommil.netlib.F2jBLAS import org.apache.hadoop.fs.Path import org.json4s.DefaultFormats -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.internal.Logging import org.apache.spark.ml._ import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared.HasSeed import org.apache.spark.ml.util._ import org.apache.spark.mllib.util.MLUtils import org.apache.spark.sql.{DataFrame, Dataset} @@ -39,7 +38,7 @@ import org.apache.spark.sql.types.StructType /** * Params for [[CrossValidator]] and [[CrossValidatorModel]]. */ -private[ml] trait CrossValidatorParams extends ValidatorParams with HasSeed { +private[ml] trait CrossValidatorParams extends ValidatorParams { /** * Param for number of folds for cross validation. Must be >= 2. * Default: 3 @@ -56,11 +55,9 @@ private[ml] trait CrossValidatorParams extends ValidatorParams with HasSeed { } /** - * :: Experimental :: * K-fold cross validation. */ @Since("1.2.0") -@Experimental class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) extends Estimator[CrossValidatorModel] with CrossValidatorParams with MLWritable with Logging { @@ -179,17 +176,18 @@ object CrossValidator extends MLReadable[CrossValidator] { val (metadata, estimator, evaluator, estimatorParamMaps) = ValidatorParams.loadImpl(path, sc, className) val numFolds = (metadata.params \ "numFolds").extract[Int] + val seed = (metadata.params \ "seed").extract[Long] new CrossValidator(metadata.uid) .setEstimator(estimator) .setEvaluator(evaluator) .setEstimatorParamMaps(estimatorParamMaps) .setNumFolds(numFolds) + .setSeed(seed) } } } /** - * :: Experimental :: * Model from k-fold cross validation. * * @param bestModel The best model selected from k-fold cross validation. 
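A short sketch of the tuning API touched by these hunks, including the seed that is now persisted with the validator; the estimator, evaluator, data and path are placeholders:

```scala
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

val lr = new LinearRegression()
val grid = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.01, 0.1))
  .addGrid(lr.fitIntercept)                 // boolean param expands to true/false
  .build()

val cv = new CrossValidator()
  .setEstimator(lr)
  .setEvaluator(new RegressionEvaluator())
  .setEstimatorParamMaps(grid)
  .setNumFolds(3)
  .setSeed(42L)                             // saved and restored along with the other params

val cvModel = cv.fit(training)              // `training` DataFrame assumed
cvModel.write.overwrite().save("/tmp/cv-model")
```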
@@ -197,7 +195,6 @@ object CrossValidator extends MLReadable[CrossValidator] { * [[CrossValidator.estimatorParamMaps]], in the corresponding order. */ @Since("1.2.0") -@Experimental class CrossValidatorModel private[ml] ( @Since("1.4.0") override val uid: String, @Since("1.2.0") val bestModel: Model[_], @@ -267,14 +264,16 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] { val (metadata, estimator, evaluator, estimatorParamMaps) = ValidatorParams.loadImpl(path, sc, className) val numFolds = (metadata.params \ "numFolds").extract[Int] + val seed = (metadata.params \ "seed").extract[Long] val bestModelPath = new Path(path, "bestModel").toString val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc) val avgMetrics = (metadata.metadata \ "avgMetrics").extract[Seq[Double]].toArray - val cv = new CrossValidatorModel(metadata.uid, bestModel, avgMetrics) - cv.set(cv.estimator, estimator) - .set(cv.evaluator, evaluator) - .set(cv.estimatorParamMaps, estimatorParamMaps) - .set(cv.numFolds, numFolds) + val model = new CrossValidatorModel(metadata.uid, bestModel, avgMetrics) + model.set(model.estimator, estimator) + .set(model.evaluator, evaluator) + .set(model.estimatorParamMaps, estimatorParamMaps) + .set(model.numFolds, numFolds) + .set(model.seed, seed) } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala index b836d2a2340e6..d369e7a61cdc5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala @@ -20,15 +20,13 @@ package org.apache.spark.ml.tuning import scala.annotation.varargs import scala.collection.mutable -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.ml.param._ /** - * :: Experimental :: * Builder for a param grid used in grid search-based model selection. */ @Since("1.2.0") -@Experimental class ParamGridBuilder @Since("1.2.0") { private val paramGrid = mutable.Map.empty[Param[_], Iterable[_]] @@ -74,7 +72,7 @@ class ParamGridBuilder @Since("1.2.0") { } /** - * Adds a int param with multiple values. + * Adds an int param with multiple values. */ @Since("1.2.0") def addGrid(param: IntParam, values: Array[Int]): this.type = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index f2b7badbe5132..0fdba1cb8814a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -25,12 +25,11 @@ import scala.language.existentials import org.apache.hadoop.fs.Path import org.json4s.DefaultFormats -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.internal.Logging import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators} -import org.apache.spark.ml.param.shared.HasSeed import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.types.StructType @@ -38,7 +37,7 @@ import org.apache.spark.sql.types.StructType /** * Params for [[TrainValidationSplit]] and [[TrainValidationSplitModel]]. 
*/ -private[ml] trait TrainValidationSplitParams extends ValidatorParams with HasSeed { +private[ml] trait TrainValidationSplitParams extends ValidatorParams { /** * Param for ratio between train and validation data. Must be between 0 and 1. * Default: 0.75 @@ -55,14 +54,12 @@ private[ml] trait TrainValidationSplitParams extends ValidatorParams with HasSee } /** - * :: Experimental :: * Validation for hyper-parameter tuning. * Randomly splits the input dataset into train and validation sets, * and uses evaluation metric on the validation set to select the best model. * Similar to [[CrossValidator]], but only splits the set once. */ @Since("1.5.0") -@Experimental class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: String) extends Estimator[TrainValidationSplitModel] with TrainValidationSplitParams with MLWritable with Logging { @@ -177,17 +174,18 @@ object TrainValidationSplit extends MLReadable[TrainValidationSplit] { val (metadata, estimator, evaluator, estimatorParamMaps) = ValidatorParams.loadImpl(path, sc, className) val trainRatio = (metadata.params \ "trainRatio").extract[Double] + val seed = (metadata.params \ "seed").extract[Long] new TrainValidationSplit(metadata.uid) .setEstimator(estimator) .setEvaluator(evaluator) .setEstimatorParamMaps(estimatorParamMaps) .setTrainRatio(trainRatio) + .setSeed(seed) } } } /** - * :: Experimental :: * Model from train validation split. * * @param uid Id. @@ -195,7 +193,6 @@ object TrainValidationSplit extends MLReadable[TrainValidationSplit] { * @param validationMetrics Evaluated validation metrics. */ @Since("1.5.0") -@Experimental class TrainValidationSplitModel private[ml] ( @Since("1.5.0") override val uid: String, @Since("1.5.0") val bestModel: Model[_], @@ -265,14 +262,16 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] { val (metadata, estimator, evaluator, estimatorParamMaps) = ValidatorParams.loadImpl(path, sc, className) val trainRatio = (metadata.params \ "trainRatio").extract[Double] + val seed = (metadata.params \ "seed").extract[Long] val bestModelPath = new Path(path, "bestModel").toString val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc) val validationMetrics = (metadata.metadata \ "validationMetrics").extract[Seq[Double]].toArray - val tvs = new TrainValidationSplitModel(metadata.uid, bestModel, validationMetrics) - tvs.set(tvs.estimator, estimator) - .set(tvs.evaluator, evaluator) - .set(tvs.estimatorParamMaps, estimatorParamMaps) - .set(tvs.trainRatio, trainRatio) + val model = new TrainValidationSplitModel(metadata.uid, bestModel, validationMetrics) + model.set(model.estimator, estimator) + .set(model.evaluator, evaluator) + .set(model.estimatorParamMaps, estimatorParamMaps) + .set(model.trainRatio, trainRatio) + .set(model.seed, seed) } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala index 7a4e106aeb999..26fd73814d70a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala @@ -25,15 +25,15 @@ import org.apache.spark.SparkContext import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params} -import org.apache.spark.ml.util.{DefaultParamsReader, DefaultParamsWriter, MetaAlgorithmReadWrite, - MLWritable} +import 
org.apache.spark.ml.param.shared.HasSeed +import org.apache.spark.ml.util.{DefaultParamsReader, DefaultParamsWriter, MetaAlgorithmReadWrite, MLWritable} import org.apache.spark.ml.util.DefaultParamsReader.Metadata import org.apache.spark.sql.types.StructType /** * Common params for [[TrainValidationSplitParams]] and [[CrossValidatorParams]]. */ -private[ml] trait ValidatorParams extends Params { +private[ml] trait ValidatorParams extends HasSeed with Params { /** * param for the estimator to be validated @@ -137,7 +137,8 @@ private[ml] object ValidatorParams { } val jsonParams = validatorSpecificParams ++ List( - "estimatorParamMaps" -> parse(estimatorParamMapsJson)) + "estimatorParamMaps" -> parse(estimatorParamMapsJson), + "seed" -> parse(instance.seed.jsonEncode(instance.getSeed))) DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, Some(jsonParams)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala index 96a38a3bde960..f34a8310ddf1c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.util import scala.collection.immutable.HashMap import org.apache.spark.ml.attribute._ -import org.apache.spark.mllib.linalg.VectorUDT +import org.apache.spark.ml.linalg.VectorUDT import org.apache.spark.sql.types.StructField diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index 94d1b83ec2531..4413fefdea3ca 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -26,7 +26,7 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.SparkContext -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} import org.apache.spark.internal.Logging import org.apache.spark.ml._ import org.apache.spark.ml.classification.{OneVsRest, OneVsRestModel} @@ -40,34 +40,49 @@ import org.apache.spark.util.Utils * Trait for [[MLWriter]] and [[MLReader]]. */ private[util] sealed trait BaseReadWrite { - private var optionSQLContext: Option[SQLContext] = None + private var optionSparkSession: Option[SparkSession] = None /** - * Sets the SQL context to use for saving/loading. + * Sets the Spark SQLContext to use for saving/loading. */ @Since("1.6.0") + @deprecated("Use session instead", "2.0.0") def context(sqlContext: SQLContext): this.type = { - optionSQLContext = Option(sqlContext) + optionSparkSession = Option(sqlContext.sparkSession) this } /** - * Returns the user-specified SQL context or the default. + * Sets the Spark Session to use for saving/loading. */ - protected final def sqlContext: SQLContext = { - if (optionSQLContext.isEmpty) { - optionSQLContext = Some(SQLContext.getOrCreate(SparkContext.getOrCreate())) + @Since("2.0.0") + def session(sparkSession: SparkSession): this.type = { + optionSparkSession = Option(sparkSession) + this + } + + /** + * Returns the user-specified Spark Session or the default. 
+ */ + protected final def sparkSession: SparkSession = { + if (optionSparkSession.isEmpty) { + optionSparkSession = Some(SparkSession.builder().getOrCreate()) } - optionSQLContext.get + optionSparkSession.get } - protected final def sparkSession: SparkSession = sqlContext.sparkSession + /** + * Returns the user-specified SQL context or the default. + */ + protected final def sqlContext: SQLContext = sparkSession.sqlContext /** Returns the underlying [[SparkContext]]. */ protected final def sc: SparkContext = sparkSession.sparkContext } /** + * :: Experimental :: + * * Abstract class for utility classes that can save ML instances. */ @Experimental @@ -116,12 +131,18 @@ abstract class MLWriter extends BaseReadWrite with Logging { } // override for Java compatibility - override def context(sqlContext: SQLContext): this.type = super.context(sqlContext) + override def session(sparkSession: SparkSession): this.type = super.session(sparkSession) + + // override for Java compatibility + override def context(sqlContext: SQLContext): this.type = super.session(sqlContext.sparkSession) } /** + * :: Experimental :: + * * Trait for classes that provide [[MLWriter]]. */ +@Experimental @Since("1.6.0") trait MLWritable { @@ -139,12 +160,26 @@ trait MLWritable { def save(path: String): Unit = write.save(path) } -private[ml] trait DefaultParamsWritable extends MLWritable { self: Params => +/** + * :: DeveloperApi :: + * + * Helper trait for making simple [[Params]] types writable. If a [[Params]] class stores + * all data as [[org.apache.spark.ml.param.Param]] values, then extending this trait will provide + * a default implementation of writing saved instances of the class. + * This only handles simple [[org.apache.spark.ml.param.Param]] types; e.g., it will not handle + * [[org.apache.spark.sql.Dataset]]. + * + * @see [[DefaultParamsReadable]], the counterpart to this trait + */ +@DeveloperApi +trait DefaultParamsWritable extends MLWritable { self: Params => override def write: MLWriter = new DefaultParamsWriter(this) } /** + * :: Experimental :: + * * Abstract class for utility classes that can load ML instances. * * @tparam T ML instance type @@ -160,10 +195,15 @@ abstract class MLReader[T] extends BaseReadWrite { def load(path: String): T // override for Java compatibility - override def context(sqlContext: SQLContext): this.type = super.context(sqlContext) + override def session(sparkSession: SparkSession): this.type = super.session(sparkSession) + + // override for Java compatibility + override def context(sqlContext: SQLContext): this.type = super.session(sqlContext.sparkSession) } /** + * :: Experimental :: + * * Trait for objects that provide [[MLReader]]. * * @tparam T ML instance type @@ -187,9 +227,23 @@ trait MLReadable[T] { def load(path: String): T = read.load(path) } -private[ml] trait DefaultParamsReadable[T] extends MLReadable[T] { - override def read: MLReader[T] = new DefaultParamsReader +/** + * :: DeveloperApi :: + * + * Helper trait for making simple [[Params]] types readable. If a [[Params]] class stores + * all data as [[org.apache.spark.ml.param.Param]] values, then extending this trait will provide + * a default implementation of reading saved instances of the class. + * This only handles simple [[org.apache.spark.ml.param.Param]] types; e.g., it will not handle + * [[org.apache.spark.sql.Dataset]]. 
+ * + * @tparam T ML instance type + * @see [[DefaultParamsWritable]], the counterpart to this trait + */ +@DeveloperApi +trait DefaultParamsReadable[T] extends MLReadable[T] { + + override def read: MLReader[T] = new DefaultParamsReader[T] } /** @@ -422,7 +476,7 @@ private[ml] object MetaAlgorithmReadWrite { case rformModel: RFormulaModel => Array(rformModel.pipelineModel) case _: Params => Array() } - val subStageMaps = subStages.map(getUidMapImpl).foldLeft(List.empty[(String, Params)])(_ ++ _) + val subStageMaps = subStages.flatMap(getUidMapImpl) List((instance.uid, instance)) ++ subStageMaps } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala b/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala index 8d4174124b5c4..e79b1f31643d0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala @@ -19,7 +19,8 @@ package org.apache.spark.ml.util import scala.collection.mutable -import org.apache.spark.{Accumulator, SparkContext} +import org.apache.spark.SparkContext +import org.apache.spark.util.LongAccumulator; /** * Abstract class for stopwatches. @@ -102,12 +103,12 @@ private[spark] class DistributedStopwatch( sc: SparkContext, override val name: String) extends Stopwatch { - private val elapsedTime: Accumulator[Long] = sc.accumulator(0L, s"DistributedStopwatch($name)") + private val elapsedTime: LongAccumulator = sc.longAccumulator(s"DistributedStopwatch($name)") override def elapsed(): Long = elapsedTime.value override protected def add(duration: Long): Unit = { - elapsedTime += duration + elapsedTime.add(duration) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SinkStatus.scala b/mllib/src/main/scala/org/apache/spark/mllib/JavaPackage.java similarity index 62% rename from sql/core/src/main/scala/org/apache/spark/sql/SinkStatus.scala rename to mllib/src/main/scala/org/apache/spark/mllib/JavaPackage.java index 5a9852809c0eb..22e34524aa592 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SinkStatus.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/JavaPackage.java @@ -15,20 +15,17 @@ * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.mllib; -import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.execution.streaming.{Offset, Sink} +import org.apache.spark.annotation.AlphaComponent; /** - * :: Experimental :: - * Status and metrics of a streaming [[Sink]]. - * - * @param description Description of the source corresponding to this status - * @param offset Current offset up to which data has been written by the sink - * @since 2.0.0 + * A dummy class as a workaround to show the package doc of spark.mllib in generated + * Java API docs. 
+ * @see + * JDK-4492654 */ -@Experimental -class SinkStatus private[sql]( - val description: String, - val offset: Offset) +@AlphaComponent +public class JavaPackage { + private JavaPackage() {} +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 8daee7b3aa1e9..a80cca70f4b28 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -41,8 +41,7 @@ import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.random.{RandomRDDs => RG} import org.apache.spark.mllib.recommendation._ import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.stat.{ - KernelDensity, MultivariateStatisticalSummary, Statistics} +import org.apache.spark.mllib.stat.{KernelDensity, MultivariateStatisticalSummary, Statistics} import org.apache.spark.mllib.stat.correlation.CorrelationNames import org.apache.spark.mllib.stat.distribution.MultivariateGaussian import org.apache.spark.mllib.stat.test.{ChiSqTestResult, KolmogorovSmirnovTestResult} @@ -54,7 +53,7 @@ import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTree RandomForestModel} import org.apache.spark.mllib.util.{LinearDataGenerator, MLUtils} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.{DataFrame, Row, SparkSession} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils @@ -150,7 +149,7 @@ private[python] class PythonMLLibAPI extends Serializable { intercept: Boolean, validateData: Boolean, convergenceTol: Double): JList[Object] = { - val lrAlg = new LinearRegressionWithSGD() + val lrAlg = new LinearRegressionWithSGD(1.0, 100, 0.0, 1.0) lrAlg.setIntercept(intercept) .setValidateData(validateData) lrAlg.optimizer @@ -179,7 +178,7 @@ private[python] class PythonMLLibAPI extends Serializable { intercept: Boolean, validateData: Boolean, convergenceTol: Double): JList[Object] = { - val lassoAlg = new LassoWithSGD() + val lassoAlg = new LassoWithSGD(1.0, 100, 0.01, 1.0) lassoAlg.setIntercept(intercept) .setValidateData(validateData) lassoAlg.optimizer @@ -207,7 +206,7 @@ private[python] class PythonMLLibAPI extends Serializable { intercept: Boolean, validateData: Boolean, convergenceTol: Double): JList[Object] = { - val ridgeAlg = new RidgeRegressionWithSGD() + val ridgeAlg = new RidgeRegressionWithSGD(1.0, 100, 0.01, 1.0) ridgeAlg.setIntercept(intercept) .setValidateData(validateData) ridgeAlg.optimizer @@ -266,7 +265,7 @@ private[python] class PythonMLLibAPI extends Serializable { intercept: Boolean, validateData: Boolean, convergenceTol: Double): JList[Object] = { - val LogRegAlg = new LogisticRegressionWithSGD() + val LogRegAlg = new LogisticRegressionWithSGD(1.0, 100, 0.01, 1.0) LogRegAlg.setIntercept(intercept) .setValidateData(validateData) LogRegAlg.optimizer @@ -1128,7 +1127,7 @@ private[python] class PythonMLLibAPI extends Serializable { * Wrapper around RowMatrix constructor. 
*/ def createRowMatrix(rows: JavaRDD[Vector], numRows: Long, numCols: Int): RowMatrix = { - new RowMatrix(rows.rdd.retag(classOf[Vector]), numRows, numCols) + new RowMatrix(rows.rdd, numRows, numCols) } /** @@ -1176,8 +1175,9 @@ private[python] class PythonMLLibAPI extends Serializable { def getIndexedRows(indexedRowMatrix: IndexedRowMatrix): DataFrame = { // We use DataFrames for serialization of IndexedRows to Python, // so return a DataFrame. - val sqlContext = SQLContext.getOrCreate(indexedRowMatrix.rows.sparkContext) - sqlContext.createDataFrame(indexedRowMatrix.rows) + val sc = indexedRowMatrix.rows.sparkContext + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() + spark.createDataFrame(indexedRowMatrix.rows) } /** @@ -1186,8 +1186,9 @@ private[python] class PythonMLLibAPI extends Serializable { def getMatrixEntries(coordinateMatrix: CoordinateMatrix): DataFrame = { // We use DataFrames for serialization of MatrixEntry entries to // Python, so return a DataFrame. - val sqlContext = SQLContext.getOrCreate(coordinateMatrix.entries.sparkContext) - sqlContext.createDataFrame(coordinateMatrix.entries) + val sc = coordinateMatrix.entries.sparkContext + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() + spark.createDataFrame(coordinateMatrix.entries) } /** @@ -1196,22 +1197,52 @@ private[python] class PythonMLLibAPI extends Serializable { def getMatrixBlocks(blockMatrix: BlockMatrix): DataFrame = { // We use DataFrames for serialization of sub-matrix blocks to // Python, so return a DataFrame. - val sqlContext = SQLContext.getOrCreate(blockMatrix.blocks.sparkContext) - sqlContext.createDataFrame(blockMatrix.blocks) + val sc = blockMatrix.blocks.sparkContext + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() + spark.createDataFrame(blockMatrix.blocks) + } + + /** + * Python-friendly version of [[MLUtils.convertVectorColumnsToML()]]. + */ + def convertVectorColumnsToML(dataset: DataFrame, cols: JArrayList[String]): DataFrame = { + MLUtils.convertVectorColumnsToML(dataset, cols.asScala: _*) + } + + /** + * Python-friendly version of [[MLUtils.convertVectorColumnsFromML()]] + */ + def convertVectorColumnsFromML(dataset: DataFrame, cols: JArrayList[String]): DataFrame = { + MLUtils.convertVectorColumnsFromML(dataset, cols.asScala: _*) + } + + /** + * Python-friendly version of [[MLUtils.convertMatrixColumnsToML()]]. + */ + def convertMatrixColumnsToML(dataset: DataFrame, cols: JArrayList[String]): DataFrame = { + MLUtils.convertMatrixColumnsToML(dataset, cols.asScala: _*) + } + + /** + * Python-friendly version of [[MLUtils.convertMatrixColumnsFromML()]] + */ + def convertMatrixColumnsFromML(dataset: DataFrame, cols: JArrayList[String]): DataFrame = { + MLUtils.convertMatrixColumnsFromML(dataset, cols.asScala: _*) } } /** - * SerDe utility functions for PythonMLLibAPI. + * Basic SerDe utility class. 
*/ -private[spark] object SerDe extends Serializable { +private[spark] abstract class SerDeBase { - val PYSPARK_PACKAGE = "pyspark.mllib" + val PYSPARK_PACKAGE: String + def initialize(): Unit /** * Base class used for pickle */ - private[python] abstract class BasePickler[T: ClassTag] + private[spark] abstract class BasePickler[T: ClassTag] extends IObjectPickler with IObjectConstructor { private val cls = implicitly[ClassTag[T]].runtimeClass @@ -1262,6 +1293,68 @@ private[spark] object SerDe extends Serializable { private[python] def saveState(obj: Object, out: OutputStream, pickler: Pickler) } + def dumps(obj: AnyRef): Array[Byte] = { + obj match { + // Pickler in Python side cannot deserialize Scala Array normally. See SPARK-12834. + case array: Array[_] => new Pickler().dumps(array.toSeq.asJava) + case _ => new Pickler().dumps(obj) + } + } + + def loads(bytes: Array[Byte]): AnyRef = { + new Unpickler().loads(bytes) + } + + /* convert object into Tuple */ + def asTupleRDD(rdd: RDD[Array[Any]]): RDD[(Int, Int)] = { + rdd.map(x => (x(0).asInstanceOf[Int], x(1).asInstanceOf[Int])) + } + + /* convert RDD[Tuple2[,]] to RDD[Array[Any]] */ + def fromTuple2RDD(rdd: RDD[(Any, Any)]): RDD[Array[Any]] = { + rdd.map(x => Array(x._1, x._2)) + } + + /** + * Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by + * PySpark. + */ + def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = { + jRDD.rdd.mapPartitions { iter => + initialize() // let it called in executor + new SerDeUtil.AutoBatchedPickler(iter) + } + } + + /** + * Convert an RDD of serialized Python objects to RDD of objects, that is usable by PySpark. + */ + def pythonToJava(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Any] = { + pyRDD.rdd.mapPartitions { iter => + initialize() // let it called in executor + val unpickle = new Unpickler + iter.flatMap { row => + val obj = unpickle.loads(row) + if (batched) { + obj match { + case list: JArrayList[_] => list.asScala + case arr: Array[_] => arr + } + } else { + Seq(obj) + } + } + }.toJavaRDD() + } +} + +/** + * SerDe utility functions for PythonMLLibAPI. + */ +private[spark] object SerDe extends SerDeBase with Serializable { + + override val PYSPARK_PACKAGE = "pyspark.mllib" + // Pickler for DenseVector private[python] class DenseVectorPickler extends BasePickler[DenseVector] { @@ -1428,7 +1521,7 @@ private[spark] object SerDe extends Serializable { } } - // Pickler for LabeledPoint + // Pickler for MLlib LabeledPoint private[python] class LabeledPointPickler extends BasePickler[LabeledPoint] { def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = { @@ -1474,7 +1567,7 @@ private[spark] object SerDe extends Serializable { var initialized = false // This should be called before trying to serialize any above classes // In cluster mode, this should be put in the closure - def initialize(): Unit = { + override def initialize(): Unit = { SerDeUtil.initialize() synchronized { if (!initialized) { @@ -1490,58 +1583,4 @@ private[spark] object SerDe extends Serializable { } // will not called in Executor automatically initialize() - - def dumps(obj: AnyRef): Array[Byte] = { - obj match { - // Pickler in Python side cannot deserialize Scala Array normally. See SPARK-12834. 
- case array: Array[_] => new Pickler().dumps(array.toSeq.asJava) - case _ => new Pickler().dumps(obj) - } - } - - def loads(bytes: Array[Byte]): AnyRef = { - new Unpickler().loads(bytes) - } - - /* convert object into Tuple */ - def asTupleRDD(rdd: RDD[Array[Any]]): RDD[(Int, Int)] = { - rdd.map(x => (x(0).asInstanceOf[Int], x(1).asInstanceOf[Int])) - } - - /* convert RDD[Tuple2[,]] to RDD[Array[Any]] */ - def fromTuple2RDD(rdd: RDD[(Any, Any)]): RDD[Array[Any]] = { - rdd.map(x => Array(x._1, x._2)) - } - - /** - * Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by - * PySpark. - */ - def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = { - jRDD.rdd.mapPartitions { iter => - initialize() // let it called in executor - new SerDeUtil.AutoBatchedPickler(iter) - } - } - - /** - * Convert an RDD of serialized Python objects to RDD of objects, that is usable by PySpark. - */ - def pythonToJava(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Any] = { - pyRDD.rdd.mapPartitions { iter => - initialize() // let it called in executor - val unpickle = new Unpickler - iter.flatMap { row => - val obj = unpickle.loads(row) - if (batched) { - obj match { - case list: JArrayList[_] => list.asScala - case arr: Array[_] => arr - } - } else { - Seq(obj) - } - } - }.toJavaRDD() - } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala index 4b4ed2291d139..5cbfbff3e4a62 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala @@ -43,18 +43,34 @@ private[python] class Word2VecModelWrapper(model: Word2VecModel) { rdd.rdd.map(model.transform) } + /** + * Finds synonyms of a word; do not include the word itself in results. + * @param word a word + * @param num number of synonyms to find + * @return a list consisting of a list of words and a vector of cosine similarities + */ def findSynonyms(word: String, num: Int): JList[Object] = { - val vec = transform(word) - findSynonyms(vec, num) + prepareResult(model.findSynonyms(word, num)) } + /** + * Finds words similar to the the vector representation of a word without + * filtering results. 
+ * @param vector a vector + * @param num number of synonyms to find + * @return a list consisting of a list of words and a vector of cosine similarities + */ def findSynonyms(vector: Vector, num: Int): JList[Object] = { - val result = model.findSynonyms(vector, num) + prepareResult(model.findSynonyms(vector, num)) + } + + private def prepareResult(result: Array[(String, Double)]) = { val similarity = Vectors.dense(result.map(_._2)) val words = result.map(_._1) List(words, similarity).map(_.asInstanceOf[Object]).asJava } + def getVectors: JMap[String, JList[Float]] = { model.getVectors.map { case (k, v) => (k, v.toList.asJava) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index 1d25a58e0f2e8..e4cbf5acbc11d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -28,7 +28,7 @@ import org.apache.spark.mllib.pmml.PMMLExportable import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.{DataValidators, Loader, Saveable} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession import org.apache.spark.storage.StorageLevel /** @@ -86,7 +86,7 @@ class LogisticRegressionModel @Since("1.3.0") ( /** * Sets the threshold that separates positive predictions from negative predictions * in Binary Logistic Regression. An example with prediction score greater than or equal to - * this threshold is identified as an positive, and negative otherwise. The default value is 0.5. + * this threshold is identified as a positive, and negative otherwise. The default value is 0.5. * It is only used for binary classification. */ @Since("1.0.0") @@ -200,13 +200,12 @@ object LogisticRegressionModel extends Loader[LogisticRegressionModel] { /** * Train a classification model for Binary Logistic Regression * using Stochastic Gradient Descent. By default L2 regularization is used, - * which can be changed via [[LogisticRegressionWithSGD.optimizer]]. + * which can be changed via `LogisticRegressionWithSGD.optimizer`. * NOTE: Labels used in Logistic Regression should be {0, 1, ..., k - 1} * for k classes multi-label classification problem. * Using [[LogisticRegressionWithLBFGS]] is recommended over this. */ @Since("0.8.0") -@deprecated("Use ml.classification.LogisticRegression or LogisticRegressionWithLBFGS", "2.0.0") class LogisticRegressionWithSGD private[mllib] ( private var stepSize: Double, private var numIterations: Int, @@ -229,6 +228,7 @@ class LogisticRegressionWithSGD private[mllib] ( * numIterations: 100, regParm: 0.01, miniBatchFraction: 1.0}. */ @Since("0.8.0") + @deprecated("Use ml.classification.LogisticRegression or LogisticRegressionWithLBFGS", "2.0.0") def this() = this(1.0, 100, 0.01, 1.0) override protected[mllib] def createModel(weights: Vector, intercept: Double) = { @@ -422,7 +422,7 @@ class LogisticRegressionWithLBFGS LogisticRegressionModel = { // ml's Logistic regression only supports binary classification currently. 
if (numOfLinearPredictor == 1) { - def runWithMlLogisitcRegression(elasticNetParam: Double) = { + def runWithMlLogisticRegression(elasticNetParam: Double) = { // Prepare the ml LogisticRegression based on our settings val lr = new org.apache.spark.ml.classification.LogisticRegression() lr.setRegParam(optimizer.getRegParam()) @@ -431,26 +431,25 @@ class LogisticRegressionWithLBFGS if (userSuppliedWeights) { val uid = Identifiable.randomUID("logreg-static") lr.setInitialModel(new org.apache.spark.ml.classification.LogisticRegressionModel( - uid, initialWeights, 1.0)) + uid, initialWeights.asML, 1.0)) } lr.setFitIntercept(addIntercept) lr.setMaxIter(optimizer.getNumIterations()) lr.setTol(optimizer.getConvergenceTol()) // Convert our input into a DataFrame - val sqlContext = new SQLContext(input.context) - import sqlContext.implicits._ - val df = input.toDF() + val spark = SparkSession.builder().sparkContext(input.context).getOrCreate() + val df = spark.createDataFrame(input.map(_.asML)) // Determine if we should cache the DF val handlePersistence = input.getStorageLevel == StorageLevel.NONE // Train our model - val mlLogisticRegresionModel = lr.train(df, handlePersistence) + val mlLogisticRegressionModel = lr.train(df, handlePersistence) // convert the model - val weights = Vectors.dense(mlLogisticRegresionModel.coefficients.toArray) - createModel(weights, mlLogisticRegresionModel.intercept) + val weights = Vectors.dense(mlLogisticRegressionModel.coefficients.toArray) + createModel(weights, mlLogisticRegressionModel.intercept) } optimizer.getUpdater() match { - case x: SquaredL2Updater => runWithMlLogisitcRegression(0.0) - case x: L1Updater => runWithMlLogisitcRegression(1.0) + case x: SquaredL2Updater => runWithMlLogisticRegression(0.0) + case x: L1Updater => runWithMlLogisticRegression(1.0) case _ => super.run(input, initialWeights) } } else { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index eb3ee41f7cf4f..593a86f69ad51 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -31,7 +31,7 @@ import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, DenseVector, SparseVect import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.SparkSession /** * Model for Naive Bayes Classifiers. @@ -193,8 +193,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { modelType: String) def save(sc: SparkContext, path: String, data: Data): Unit = { - val sqlContext = SQLContext.getOrCreate(sc) - import sqlContext.implicits._ + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() // Create JSON metadata. val metadata = compact(render( @@ -203,15 +202,14 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path)) // Create Parquet data. - val dataRDD: DataFrame = sc.parallelize(Seq(data), 1).toDF() - dataRDD.write.parquet(dataPath(path)) + spark.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath(path)) } @Since("1.3.0") def load(sc: SparkContext, path: String): NaiveBayesModel = { - val sqlContext = SQLContext.getOrCreate(sc) + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() // Load Parquet data. 
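The LogisticRegressionWithLBFGS hunk above now reaches the spark.ml implementation by turning the RDD of mllib labeled points into a DataFrame through SparkSession instead of SQLContext. A minimal sketch of that conversion, assuming Spark 2.0's `LabeledPoint.asML` and the internal `Builder.sparkContext` hook the patch itself relies on (the method name `toMLDataFrame` is illustrative, not part of the patch):

```scala
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SparkSession}

// Reuse (or create) the SparkSession bound to the RDD's SparkContext, then
// convert each mllib LabeledPoint to its ml counterpart before building the DataFrame.
def toMLDataFrame(input: RDD[LabeledPoint]): DataFrame = {
  val spark = SparkSession.builder().sparkContext(input.context).getOrCreate()
  spark.createDataFrame(input.map(_.asML)) // mllib.LabeledPoint -> ml.feature.LabeledPoint
}
```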
- val dataRDD = sqlContext.read.parquet(dataPath(path)) + val dataRDD = spark.read.parquet(dataPath(path)) // Check schema explicitly since erasure makes it hard to use match-case for checking. checkSchema[Data](dataRDD.schema) val dataArray = dataRDD.select("labels", "pi", "theta", "modelType").take(1) @@ -240,8 +238,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { theta: Array[Array[Double]]) def save(sc: SparkContext, path: String, data: Data): Unit = { - val sqlContext = SQLContext.getOrCreate(sc) - import sqlContext.implicits._ + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() // Create JSON metadata. val metadata = compact(render( @@ -250,14 +247,13 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path)) // Create Parquet data. - val dataRDD: DataFrame = sc.parallelize(Seq(data), 1).toDF() - dataRDD.write.parquet(dataPath(path)) + spark.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath(path)) } def load(sc: SparkContext, path: String): NaiveBayesModel = { - val sqlContext = SQLContext.getOrCreate(sc) + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() // Load Parquet data. - val dataRDD = sqlContext.read.parquet(dataPath(path)) + val dataRDD = spark.read.parquet(dataPath(path)) // Check schema explicitly since erasure makes it hard to use match-case for checking. checkSchema[Data](dataRDD.schema) val dataArray = dataRDD.select("labels", "pi", "theta").take(1) @@ -327,7 +323,7 @@ class NaiveBayes private ( @Since("0.9.0") def setLambda(lambda: Double): NaiveBayes = { require(lambda >= 0, - s"Smoothing parameter must be nonnegative but got ${lambda}") + s"Smoothing parameter must be nonnegative but got $lambda") this.lambda = lambda this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala index a8d3fd4177a23..7c3ccbb40b812 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala @@ -44,7 +44,7 @@ class SVMModel @Since("1.1.0") ( /** * Sets the threshold that separates positive predictions from negative predictions. An example - * with prediction score greater than or equal to this threshold is identified as an positive, + * with prediction score greater than or equal to this threshold is identified as a positive, * and negative otherwise. The default value is 0.0. 
*/ @Since("1.0.0") @@ -72,7 +72,7 @@ class SVMModel @Since("1.1.0") ( dataMatrix: Vector, weightMatrix: Vector, intercept: Double) = { - val margin = weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept + val margin = weightMatrix.asBreeze.dot(dataMatrix.asBreeze) + intercept threshold match { case Some(t) => if (margin > t) 1.0 else 0.0 case None => margin diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala index 4308ae04ee84d..84491181d077f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala @@ -23,7 +23,7 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.SparkContext import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.util.Loader -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.{Row, SparkSession} /** * Helper class for import/export of GLM classification models. @@ -51,8 +51,7 @@ private[classification] object GLMClassificationModel { weights: Vector, intercept: Double, threshold: Option[Double]): Unit = { - val sqlContext = SQLContext.getOrCreate(sc) - import sqlContext.implicits._ + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() // Create JSON metadata. val metadata = compact(render( @@ -62,7 +61,7 @@ private[classification] object GLMClassificationModel { // Create Parquet data. val data = Data(weights, intercept, threshold) - sc.parallelize(Seq(data), 1).toDF().write.parquet(Loader.dataPath(path)) + spark.createDataFrame(Seq(data)).repartition(1).write.parquet(Loader.dataPath(path)) } /** @@ -73,13 +72,13 @@ private[classification] object GLMClassificationModel { * @param modelClass String name for model class (used for error messages) */ def loadData(sc: SparkContext, path: String, modelClass: String): Data = { - val datapath = Loader.dataPath(path) - val sqlContext = SQLContext.getOrCreate(sc) - val dataRDD = sqlContext.read.parquet(datapath) + val dataPath = Loader.dataPath(path) + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() + val dataRDD = spark.read.parquet(dataPath) val dataArray = dataRDD.select("weights", "intercept", "threshold").take(1) - assert(dataArray.length == 1, s"Unable to load $modelClass data from: $datapath") + assert(dataArray.length == 1, s"Unable to load $modelClass data from: $dataPath") val data = dataArray(0) - assert(data.size == 3, s"Unable to load $modelClass data from: $datapath") + assert(data.size == 3, s"Unable to load $modelClass data from: $dataPath") val (weights, intercept) = data match { case Row(weights: Vector, intercept: Double, _) => (weights, intercept) @@ -92,5 +91,4 @@ private[classification] object GLMClassificationModel { Data(weights, intercept, threshold) } } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala index e4bd0dc25ee54..f1664ce4ab3f8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala @@ -22,7 +22,7 @@ import java.util.Random import scala.annotation.tailrec import scala.collection.mutable -import org.apache.spark.annotation.{Experimental, Since} +import 
org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors} @@ -52,7 +52,6 @@ import org.apache.spark.storage.StorageLevel * KDD Workshop on Text Mining, 2000.]] */ @Since("1.6.0") -@Experimental class BisectingKMeans private ( private var k: Int, private var maxIterations: Int, @@ -407,7 +406,6 @@ private object BisectingKMeans extends Serializable { * @param children children nodes */ @Since("1.6.0") -@Experimental private[clustering] class ClusteringTreeNode private[clustering] ( val index: Int, val size: Long, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala index c3b5b8b7900f5..8438015ccecea 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala @@ -23,13 +23,13 @@ import org.json4s.jackson.JsonMethods._ import org.json4s.JsonDSL._ import org.apache.spark.SparkContext -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.{Row, SparkSession} /** * Clustering model produced by [[BisectingKMeans]]. @@ -39,7 +39,6 @@ import org.apache.spark.sql.{Row, SQLContext} * @param root the root node of the clustering tree */ @Since("1.6.0") -@Experimental class BisectingKMeansModel private[clustering] ( private[clustering] val root: ClusteringTreeNode ) extends Serializable with Saveable with Logging { @@ -144,8 +143,7 @@ object BisectingKMeansModel extends Loader[BisectingKMeansModel] { val thisClassName = "org.apache.spark.mllib.clustering.BisectingKMeansModel" def save(sc: SparkContext, model: BisectingKMeansModel, path: String): Unit = { - val sqlContext = SQLContext.getOrCreate(sc) - import sqlContext.implicits._ + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val metadata = compact(render( ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("rootId" -> model.root.index))) @@ -154,8 +152,7 @@ object BisectingKMeansModel extends Loader[BisectingKMeansModel] { val data = getNodes(model.root).map(node => Data(node.index, node.size, node.centerWithNorm.vector, node.centerWithNorm.norm, node.cost, node.height, node.children.map(_.index))) - val dataRDD = sc.parallelize(data).toDF() - dataRDD.write.parquet(Loader.dataPath(path)) + spark.createDataFrame(data).write.parquet(Loader.dataPath(path)) } private def getNodes(node: ClusteringTreeNode): Array[ClusteringTreeNode] = { @@ -167,8 +164,8 @@ object BisectingKMeansModel extends Loader[BisectingKMeansModel] { } def load(sc: SparkContext, path: String, rootId: Int): BisectingKMeansModel = { - val sqlContext = SQLContext.getOrCreate(sc) - val rows = sqlContext.read.parquet(Loader.dataPath(path)) + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() + val rows = spark.read.parquet(Loader.dataPath(path)) Loader.checkSchema[Data](rows.schema) val data = rows.select("index", "size", "center", "norm", "cost", "height", "children") val nodes = data.rdd.map(Data.apply).collect().map(d => (d.index, d)).toMap diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala index f04c87259c941..a214b1a26f443 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala @@ -166,7 +166,7 @@ class GaussianMixture private ( val sc = data.sparkContext // we will operate on the data as breeze data - val breezeData = data.map(_.toBreeze).cache() + val breezeData = data.map(_.asBreeze).cache() // Get length of the input vectors val d = breezeData.first().length diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index f87613cc72f9a..c30cc3e2398e4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -29,7 +29,7 @@ import org.apache.spark.mllib.linalg.{Matrix, Vector} import org.apache.spark.mllib.stat.distribution.MultivariateGaussian import org.apache.spark.mllib.util.{Loader, MLUtils, Saveable} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.{Row, SparkSession} /** * Multivariate Gaussian Mixture Model (GMM) consisting of k Gaussians, where points @@ -96,7 +96,7 @@ class GaussianMixtureModel @Since("1.3.0") ( val bcDists = sc.broadcast(gaussians) val bcWeights = sc.broadcast(weights) points.map { x => - computeSoftAssignments(x.toBreeze.toDenseVector, bcDists.value, bcWeights.value, k) + computeSoftAssignments(x.asBreeze.toDenseVector, bcDists.value, bcWeights.value, k) } } @@ -105,7 +105,7 @@ class GaussianMixtureModel @Since("1.3.0") ( */ @Since("1.4.0") def predictSoft(point: Vector): Array[Double] = { - computeSoftAssignments(point.toBreeze.toDenseVector, gaussians, weights, k) + computeSoftAssignments(point.asBreeze.toDenseVector, gaussians, weights, k) } /** @@ -143,9 +143,7 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] { path: String, weights: Array[Double], gaussians: Array[MultivariateGaussian]): Unit = { - - val sqlContext = SQLContext.getOrCreate(sc) - import sqlContext.implicits._ + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() // Create JSON metadata. val metadata = compact(render @@ -156,13 +154,13 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] { val dataArray = Array.tabulate(weights.length) { i => Data(weights(i), gaussians(i).mu, gaussians(i).sigma) } - sc.parallelize(dataArray, 1).toDF().write.parquet(Loader.dataPath(path)) + spark.createDataFrame(dataArray).repartition(1).write.parquet(Loader.dataPath(path)) } def load(sc: SparkContext, path: String): GaussianMixtureModel = { val dataPath = Loader.dataPath(path) - val sqlContext = SQLContext.getOrCreate(sc) - val dataFrame = sqlContext.read.parquet(dataPath) + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() + val dataFrame = spark.read.parquet(dataPath) // Check schema explicitly since erasure makes it hard to use match-case for checking. 
Loader.checkSchema[Data](dataFrame.schema) val dataArray = dataFrame.select("weight", "mu", "sigma").collect() @@ -189,7 +187,7 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] { s"GaussianMixtureModel requires weights of length $k " + s"got weights of length ${model.weights.length}") require(model.gaussians.length == k, - s"GaussianMixtureModel requires gaussians of length $k" + + s"GaussianMixtureModel requires gaussians of length $k " + s"got gaussians of length ${model.gaussians.length}") model case _ => throw new Exception( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 60f13d27d0a6e..871b1c7d211c8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -279,7 +279,7 @@ class KMeans private ( } val activeCenters = activeRuns.map(r => centers(r)).toArray - val costAccums = activeRuns.map(_ => sc.accumulator(0.0)) + val costAccums = activeRuns.map(_ => sc.doubleAccumulator) val bcActiveCenters = sc.broadcast(activeCenters) @@ -296,7 +296,7 @@ class KMeans private ( points.foreach { point => (0 until runs).foreach { i => val (bestCenter, cost) = KMeans.findClosest(thisActiveCenters(i), point) - costAccums(i) += cost + costAccums(i).add(cost) val sum = sums(i)(bestCenter) axpy(1.0, point.vector, sum) counts(i)(bestCenter) += 1 @@ -441,7 +441,7 @@ class KMeans private ( val rs = (0 until runs).filter { r => rand.nextDouble() < 2.0 * c(r) * k / sumCosts(r) } - if (rs.length > 0) Some((p, rs)) else None + if (rs.nonEmpty) Some((p, rs)) else None } }.collect() mergeNewCenters() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala index 439e4f8672242..aa78149699a27 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala @@ -30,7 +30,7 @@ import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.pmml.PMMLExportable import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.{Row, SparkSession} /** * A clustering model for K-means. Each point belongs to the cluster with the closest center. 
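The model persistence hunks in this part of the patch all converge on one shape: obtain a SparkSession from the model's SparkContext, build a DataFrame from a small case class, and write or read Parquet, with `repartition(1)` standing in for the old single-partition `sc.parallelize(Seq(data), 1).toDF()` write. A minimal sketch of that shape, with a hypothetical `Data` payload in place of the per-model case classes:

```scala
import org.apache.spark.SparkContext
import org.apache.spark.sql.SparkSession

// Hypothetical payload; each model object in the patch defines its own Data case class.
case class Data(weights: Array[Double], intercept: Double)

def save(sc: SparkContext, path: String, data: Data): Unit = {
  val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
  spark.createDataFrame(Seq(data)).repartition(1).write.parquet(path)
}

def load(sc: SparkContext, path: String): Data = {
  val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
  val Array(row) = spark.read.parquet(path).select("weights", "intercept").take(1)
  Data(row.getAs[Seq[Double]]("weights").toArray, row.getDouble(1))
}
```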
@@ -123,25 +123,24 @@ object KMeansModel extends Loader[KMeansModel] { val thisClassName = "org.apache.spark.mllib.clustering.KMeansModel" def save(sc: SparkContext, model: KMeansModel, path: String): Unit = { - val sqlContext = SQLContext.getOrCreate(sc) - import sqlContext.implicits._ + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val metadata = compact(render( ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k))) sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) val dataRDD = sc.parallelize(model.clusterCenters.zipWithIndex).map { case (point, id) => Cluster(id, point) - }.toDF() - dataRDD.write.parquet(Loader.dataPath(path)) + } + spark.createDataFrame(dataRDD).write.parquet(Loader.dataPath(path)) } def load(sc: SparkContext, path: String): KMeansModel = { implicit val formats = DefaultFormats - val sqlContext = SQLContext.getOrCreate(sc) + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) assert(className == thisClassName) assert(formatVersion == thisFormatVersion) val k = (metadata \ "k").extract[Int] - val centroids = sqlContext.read.parquet(Loader.dataPath(path)) + val centroids = spark.read.parquet(Loader.dataPath(path)) Loader.checkSchema[Cluster](centroids.schema) val localCentroids = centroids.rdd.map(Cluster.apply).collect() assert(k == localCentroids.length) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 4913c0287a1d0..9ebba1de0dad4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -25,13 +25,13 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.SparkContext -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.api.java.{JavaPairRDD, JavaRDD} import org.apache.spark.graphx.{Edge, EdgeContext, Graph, VertexId} import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors} import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.util.BoundedPriorityQueue /** @@ -205,7 +205,7 @@ class LocalLDAModel private[spark] ( @Since("1.3.0") override def describeTopics(maxTermsPerTopic: Int): Array[(Array[Int], Array[Double])] = { - val brzTopics = topics.toBreeze.toDenseMatrix + val brzTopics = topics.asBreeze.toDenseMatrix Range(0, k).map { topicIndex => val topic = normalize(brzTopics(::, topicIndex), 1.0) val (termWeights, terms) = @@ -233,7 +233,7 @@ class LocalLDAModel private[spark] ( */ @Since("1.5.0") def logLikelihood(documents: RDD[(Long, Vector)]): Double = logLikelihoodBound(documents, - docConcentration, topicConcentration, topicsMatrix.toBreeze.toDenseMatrix, gammaShape, k, + docConcentration, topicConcentration, topicsMatrix.asBreeze.toDenseMatrix, gammaShape, k, vocabSize) /** @@ -291,7 +291,7 @@ class LocalLDAModel private[spark] ( gammaShape: Double, k: Int, vocabSize: Long): Double = { - val brzAlpha = alpha.toBreeze.toDenseVector + val brzAlpha = alpha.asBreeze.toDenseVector // transpose because dirichletExpectation normalizes by row and we need to normalize // by topic (columns of lambda) val Elogbeta = 
LDAUtils.dirichletExpectation(lambda.t).t @@ -344,9 +344,9 @@ class LocalLDAModel private[spark] ( def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = { // Double transpose because dirichletExpectation normalizes by row and we need to normalize // by topic (columns of lambda) - val expElogbeta = exp(LDAUtils.dirichletExpectation(topicsMatrix.toBreeze.toDenseMatrix.t).t) + val expElogbeta = exp(LDAUtils.dirichletExpectation(topicsMatrix.asBreeze.toDenseMatrix.t).t) val expElogbetaBc = documents.sparkContext.broadcast(expElogbeta) - val docConcentrationBrz = this.docConcentration.toBreeze + val docConcentrationBrz = this.docConcentration.asBreeze val gammaShape = this.gammaShape val k = this.k @@ -367,9 +367,9 @@ class LocalLDAModel private[spark] ( /** Get a method usable as a UDF for [[topicDistributions()]] */ private[spark] def getTopicDistributionMethod(sc: SparkContext): Vector => Vector = { - val expElogbeta = exp(LDAUtils.dirichletExpectation(topicsMatrix.toBreeze.toDenseMatrix.t).t) + val expElogbeta = exp(LDAUtils.dirichletExpectation(topicsMatrix.asBreeze.toDenseMatrix.t).t) val expElogbetaBc = sc.broadcast(expElogbeta) - val docConcentrationBrz = this.docConcentration.toBreeze + val docConcentrationBrz = this.docConcentration.asBreeze val gammaShape = this.gammaShape val k = this.k @@ -399,14 +399,14 @@ class LocalLDAModel private[spark] ( */ @Since("2.0.0") def topicDistribution(document: Vector): Vector = { - val expElogbeta = exp(LDAUtils.dirichletExpectation(topicsMatrix.toBreeze.toDenseMatrix.t).t) + val expElogbeta = exp(LDAUtils.dirichletExpectation(topicsMatrix.asBreeze.toDenseMatrix.t).t) if (document.numNonzeros == 0) { Vectors.zeros(this.k) } else { val (gamma, _, _) = OnlineLDAOptimizer.variationalTopicInference( document, expElogbeta, - this.docConcentration.toBreeze, + this.docConcentration.asBreeze, gammaShape, this.k) Vectors.dense(normalize(gamma, 1.0).toArray) @@ -425,7 +425,11 @@ class LocalLDAModel private[spark] ( } -@Experimental +/** + * Local (non-distributed) model fitted by [[LDA]]. + * + * This model stores the inferred topics only; it does not store info about the training dataset. 
+ */ @Since("1.5.0") object LocalLDAModel extends Loader[LocalLDAModel] { @@ -446,9 +450,7 @@ object LocalLDAModel extends Loader[LocalLDAModel] { docConcentration: Vector, topicConcentration: Double, gammaShape: Double): Unit = { - val sqlContext = SQLContext.getOrCreate(sc) - import sqlContext.implicits._ - + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val k = topicsMatrix.numCols val metadata = compact(render (("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ @@ -458,11 +460,11 @@ object LocalLDAModel extends Loader[LocalLDAModel] { ("gammaShape" -> gammaShape))) sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) - val topicsDenseMatrix = topicsMatrix.toBreeze.toDenseMatrix + val topicsDenseMatrix = topicsMatrix.asBreeze.toDenseMatrix val topics = Range(0, k).map { topicInd => Data(Vectors.dense((topicsDenseMatrix(::, topicInd).toArray)), topicInd) - }.toSeq - sc.parallelize(topics, 1).toDF().write.parquet(Loader.dataPath(path)) + } + spark.createDataFrame(topics).repartition(1).write.parquet(Loader.dataPath(path)) } def load( @@ -472,8 +474,8 @@ object LocalLDAModel extends Loader[LocalLDAModel] { topicConcentration: Double, gammaShape: Double): LocalLDAModel = { val dataPath = Loader.dataPath(path) - val sqlContext = SQLContext.getOrCreate(sc) - val dataFrame = sqlContext.read.parquet(dataPath) + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() + val dataFrame = spark.read.parquet(dataPath) Loader.checkSchema[Data](dataFrame.schema) val topics = dataFrame.collect() @@ -482,7 +484,7 @@ object LocalLDAModel extends Loader[LocalLDAModel] { val brzTopics = BDM.zeros[Double](vocabSize, k) topics.foreach { case Row(vec: Vector, ind: Int) => - brzTopics(::, ind) := vec.toBreeze + brzTopics(::, ind) := vec.asBreeze } val topicsMat = Matrices.fromBreeze(brzTopics) @@ -816,8 +818,13 @@ class DistributedLDAModel private[clustering] ( } } - -@Experimental +/** + * Distributed model fitted by [[LDA]]. + * This type of model is currently only produced by Expectation-Maximization (EM). + * + * This model stores the inferred topics, the full training dataset, and the topic distribution + * for each training document. 
+ */ @Since("1.5.0") object DistributedLDAModel extends Loader[DistributedLDAModel] { @@ -853,8 +860,7 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] { topicConcentration: Double, iterationTimes: Array[Double], gammaShape: Double): Unit = { - val sqlContext = SQLContext.getOrCreate(sc) - import sqlContext.implicits._ + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val metadata = compact(render (("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ @@ -866,18 +872,17 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] { sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) val newPath = new Path(Loader.dataPath(path), "globalTopicTotals").toUri.toString - sc.parallelize(Seq(Data(Vectors.fromBreeze(globalTopicTotals)))).toDF() - .write.parquet(newPath) + spark.createDataFrame(Seq(Data(Vectors.fromBreeze(globalTopicTotals)))).write.parquet(newPath) val verticesPath = new Path(Loader.dataPath(path), "topicCounts").toUri.toString - graph.vertices.map { case (ind, vertex) => + spark.createDataFrame(graph.vertices.map { case (ind, vertex) => VertexData(ind, Vectors.fromBreeze(vertex)) - }.toDF().write.parquet(verticesPath) + }).write.parquet(verticesPath) val edgesPath = new Path(Loader.dataPath(path), "tokenCounts").toUri.toString - graph.edges.map { case Edge(srcId, dstId, prop) => + spark.createDataFrame(graph.edges.map { case Edge(srcId, dstId, prop) => EdgeData(srcId, dstId, prop) - }.toDF().write.parquet(edgesPath) + }).write.parquet(edgesPath) } def load( @@ -891,18 +896,18 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] { val dataPath = new Path(Loader.dataPath(path), "globalTopicTotals").toUri.toString val vertexDataPath = new Path(Loader.dataPath(path), "topicCounts").toUri.toString val edgeDataPath = new Path(Loader.dataPath(path), "tokenCounts").toUri.toString - val sqlContext = SQLContext.getOrCreate(sc) - val dataFrame = sqlContext.read.parquet(dataPath) - val vertexDataFrame = sqlContext.read.parquet(vertexDataPath) - val edgeDataFrame = sqlContext.read.parquet(edgeDataPath) + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() + val dataFrame = spark.read.parquet(dataPath) + val vertexDataFrame = spark.read.parquet(vertexDataPath) + val edgeDataFrame = spark.read.parquet(edgeDataPath) Loader.checkSchema[Data](dataFrame.schema) Loader.checkSchema[VertexData](vertexDataFrame.schema) Loader.checkSchema[EdgeData](edgeDataFrame.schema) val globalTopicTotals: LDA.TopicCounts = - dataFrame.first().getAs[Vector](0).toBreeze.toDenseVector + dataFrame.first().getAs[Vector](0).asBreeze.toDenseVector val vertices: RDD[(VertexId, LDA.TopicCounts)] = vertexDataFrame.rdd.map { - case Row(ind: Long, vec: Vector) => (ind, vec.toBreeze.toDenseVector) + case Row(ind: Long, vec: Vector) => (ind, vec.asBreeze.toDenseVector) } val edges: RDD[Edge[LDA.TokenCount]] = edgeDataFrame.rdd.map { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index 1b3e2f600d028..2436efba32489 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -137,7 +137,7 @@ final class EMLDAOptimizer extends LDAOptimizer { // For each document, create an edge (Document -> Term) for each unique term in the document. 
val edges: RDD[Edge[TokenCount]] = docs.flatMap { case (docID: Long, termCounts: Vector) => // Add edges for terms with non-zero counts. - termCounts.toBreeze.activeIterator.filter(_._2 != 0.0).map { case (term, cnt) => + termCounts.asBreeze.activeIterator.filter(_._2 != 0.0).map { case (term, cnt) => Edge(docID, term2index(term), cnt) } } @@ -457,7 +457,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { val vocabSize = this.vocabSize val expElogbeta = exp(LDAUtils.dirichletExpectation(lambda)).t val expElogbetaBc = batch.sparkContext.broadcast(expElogbeta) - val alpha = this.alpha.toBreeze + val alpha = this.alpha.asBreeze val gammaShape = this.gammaShape val stats: RDD[(BDM[Double], List[BDV[Double]])] = batch.mapPartitions { docs => @@ -507,7 +507,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { private def updateAlpha(gammat: BDM[Double]): Unit = { val weight = rho() val N = gammat.rows.toDouble - val alpha = this.alpha.toBreeze.toDenseVector + val alpha = this.alpha.asBreeze.toDenseVector val logphat: BDM[Double] = sum(LDAUtils.dirichletExpectation(gammat)(::, breeze.linalg.*)) / N val gradf = N * (-LDAUtils.dirichletExpectation(alpha) + logphat.toDenseVector) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala index adf20dc4b8b16..53587670a5db0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala @@ -46,17 +46,15 @@ private[mllib] object LocalKMeans extends Logging { // Initialize centers by sampling using the k-means++ procedure. centers(0) = pickWeighted(rand, points, weights).toDense + val costArray = points.map(KMeans.fastSquaredDistance(_, centers(0))) + for (i <- 1 until k) { - // Pick the next center with a probability proportional to cost under current centers - val curCenters = centers.view.take(i) - val sum = points.view.zip(weights).map { case (p, w) => - w * KMeans.pointCost(curCenters, p) - }.sum + val sum = costArray.zip(weights).map(p => p._1 * p._2).sum val r = rand.nextDouble() * sum var cumulativeScore = 0.0 var j = 0 while (j < points.length && cumulativeScore < r) { - cumulativeScore += weights(j) * KMeans.pointCost(curCenters, points(j)) + cumulativeScore += weights(j) * costArray(j) j += 1 } if (j == 0) { @@ -66,6 +64,12 @@ private[mllib] object LocalKMeans extends Logging { } else { centers(i) = points(j - 1).toDense } + + // update costArray + for (p <- points.indices) { + costArray(p) = math.min(KMeans.fastSquaredDistance(points(p), centers(i)), costArray(p)) + } + } // Run up to maxIterations iterations of Lloyd's algorithm diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala index 2e257ff9b7def..c760ddd6ad40b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala @@ -29,7 +29,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{Loader, MLUtils, Saveable} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.util.random.XORShiftRandom /** @@ -70,28 +70,26 @@ object 
PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode @Since("1.4.0") def save(sc: SparkContext, model: PowerIterationClusteringModel, path: String): Unit = { - val sqlContext = SQLContext.getOrCreate(sc) - import sqlContext.implicits._ + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val metadata = compact(render( ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k))) sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) - val dataRDD = model.assignments.toDF() - dataRDD.write.parquet(Loader.dataPath(path)) + spark.createDataFrame(model.assignments).write.parquet(Loader.dataPath(path)) } @Since("1.4.0") def load(sc: SparkContext, path: String): PowerIterationClusteringModel = { implicit val formats = DefaultFormats - val sqlContext = SQLContext.getOrCreate(sc) + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) assert(className == thisClassName) assert(formatVersion == thisFormatVersion) val k = (metadata \ "k").extract[Int] - val assignments = sqlContext.read.parquet(Loader.dataPath(path)) + val assignments = spark.read.parquet(Loader.dataPath(path)) Loader.checkSchema[PowerIterationClustering.Assignment](assignments.schema) val assignmentsRDD = assignments.rdd.map { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala index 24e1cff0dcc6b..52bdccb919a61 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala @@ -135,8 +135,8 @@ class StreamingKMeansModel @Since("1.2.0") ( while (j < dim) { val x = largestClusterCenter(j) val p = 1e-14 * math.max(math.abs(x), 1.0) - largestClusterCenter.toBreeze(j) = x + p - smallestClusterCenter.toBreeze(j) = x - p + largestClusterCenter.asBreeze(j) = x + p + smallestClusterCenter.asBreeze(j) = x - p j += 1 } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala index 5dde2bdb17f3a..9a6a8dbdccbf3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala @@ -25,7 +25,6 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame /** - * ::Experimental:: * Evaluator for multiclass classification. * * @param predictionAndLabels an RDD of (prediction, label) pairs. 
@@ -139,7 +138,8 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl * Returns precision */ @Since("1.1.0") - lazy val precision: Double = tpByClass.values.sum.toDouble / labelCount + @deprecated("Use accuracy.", "2.0.0") + lazy val precision: Double = accuracy /** * Returns recall @@ -148,14 +148,24 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl * of all false negatives) */ @Since("1.1.0") - lazy val recall: Double = precision + @deprecated("Use accuracy.", "2.0.0") + lazy val recall: Double = accuracy /** * Returns f-measure * (equals to precision and recall because precision equals recall) */ @Since("1.1.0") - lazy val fMeasure: Double = precision + @deprecated("Use accuracy.", "2.0.0") + lazy val fMeasure: Double = accuracy + + /** + * Returns accuracy + * (equals to the total number of correctly classified instances + * out of the total number of instances.) + */ + @Since("2.0.0") + lazy val accuracy: Double = tpByClass.values.sum.toDouble / labelCount /** * Returns weighted true positive rate diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala index 4ed4a058945c3..e29b51c3a19da 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala @@ -28,7 +28,6 @@ import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD /** - * ::Experimental:: * Evaluator for ranking algorithms. * * Java users should use [[RankingMetrics$.of]] to create a [[RankingMetrics]] instance. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala index 4f0e13feae086..56fb2d33c2ca0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala @@ -30,7 +30,7 @@ import org.apache.spark.mllib.stat.Statistics import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD import org.apache.spark.SparkContext -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.{Row, SparkSession} /** * Chi Squared selector model. 
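With the MulticlassMetrics change above, the no-argument `precision`, `recall`, and `fMeasure` members become deprecated aliases of the new `accuracy`, which is well defined because micro-averaged precision and recall coincide with accuracy for single-label multiclass data. A small usage sketch (variable names are illustrative):

```scala
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.rdd.RDD

def overallAccuracy(predictionAndLabels: RDD[(Double, Double)]): Double = {
  val metrics = new MulticlassMetrics(predictionAndLabels)
  metrics.accuracy // correctly classified instances / total instances
}
```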
@@ -134,8 +134,8 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] { val thisClassName = "org.apache.spark.mllib.feature.ChiSqSelectorModel" def save(sc: SparkContext, model: ChiSqSelectorModel, path: String): Unit = { - val sqlContext = SQLContext.getOrCreate(sc) - import sqlContext.implicits._ + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() + val metadata = compact(render( ("class" -> thisClassName) ~ ("version" -> thisFormatVersion))) sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) @@ -144,18 +144,17 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] { val dataArray = Array.tabulate(model.selectedFeatures.length) { i => Data(model.selectedFeatures(i)) } - sc.parallelize(dataArray, 1).toDF().write.parquet(Loader.dataPath(path)) - + spark.createDataFrame(dataArray).repartition(1).write.parquet(Loader.dataPath(path)) } def load(sc: SparkContext, path: String): ChiSqSelectorModel = { implicit val formats = DefaultFormats - val sqlContext = SQLContext.getOrCreate(sc) + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) assert(className == thisClassName) assert(formatVersion == thisFormatVersion) - val dataFrame = sqlContext.read.parquet(Loader.dataPath(path)) + val dataFrame = spark.read.parquet(Loader.dataPath(path)) val dataArray = dataFrame.select("feature") // Check schema explicitly since erasure makes it hard to use match-case for checking. @@ -174,8 +173,8 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] { * Creates a ChiSquared feature selector. * @param numTopFeatures number of features that selector will select * (ordered by statistic value descending) - * Note that if the number of features is < numTopFeatures, then this will - * select all features. + * Note that if the number of features is less than numTopFeatures, + * then this will select all features. 
*/ @Since("1.3.0") class ChiSqSelector @Since("1.3.0") ( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala index 9457c6e9e35f2..bb4b37ef21a84 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala @@ -204,7 +204,7 @@ private object IDFModel { * Transforms a term frequency (TF) vector to a TF-IDF vector with a IDF vector * * @param idf an IDF vector - * @param v a term frequence vector + * @param v a term frequency vector * @return a TF-IDF vector */ def transform(idf: Vector, v: Vector): Vector = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala index 30c403e547bee..aaecfa8d45dc0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala @@ -40,8 +40,9 @@ class PCA @Since("1.4.0") (@Since("1.4.0") val k: Int) { */ @Since("1.4.0") def fit(sources: RDD[Vector]): PCAModel = { - require(k <= sources.first().size, - s"source vector size is ${sources.first().size} must be greater than k=$k") + val numFeatures = sources.first().size + require(k <= numFeatures, + s"source vector size $numFeatures must be no less than k=$k") val mat = new RowMatrix(sources) val (pc, explainedVariance) = mat.computePrincipalComponentsAndExplainedVariance(k) @@ -58,7 +59,6 @@ class PCA @Since("1.4.0") (@Since("1.4.0") val k: Int) { case m => throw new IllegalArgumentException("Unsupported matrix format. Expected " + s"SparseMatrix or DenseMatrix. Instead got: ${m.getClass}") - } val denseExplainedVariance = explainedVariance match { case dv: DenseVector => @@ -70,7 +70,7 @@ class PCA @Since("1.4.0") (@Since("1.4.0") val k: Int) { } /** - * Java-friendly version of [[fit()]] + * Java-friendly version of `fit()`. */ @Since("1.4.0") def fit(sources: JavaRDD[Vector]): PCAModel = fit(sources.rdd) @@ -91,7 +91,7 @@ class PCAModel private[spark] ( * Transform a vector by computed Principal Components. * * @param vector vector to be transformed. - * Vector must be the same length as the source vectors given to [[PCA.fit()]]. + * Vector must be the same length as the source vectors given to `PCA.fit()`. * @return transformed vector. Vector will be of length k. */ @Since("1.4.0") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala index ee97045f34dc8..3e86c6c59c953 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala @@ -27,9 +27,8 @@ import org.apache.spark.rdd.RDD * Standardizes features by removing the mean and scaling to unit std using column summary * statistics on the samples in the training set. * - * The "unit std" is computed using the - * [[https://en.wikipedia.org/wiki/Standard_deviation#Corrected_sample_standard_deviation - * corrected sample standard deviation]], + * The "unit std" is computed using the corrected sample standard deviation + * (https://en.wikipedia.org/wiki/Standard_deviation#Corrected_sample_standard_deviation), * which is computed as the square root of the unbiased sample variance. * * @param withMean False by default. Centers the data with mean before scaling. 
It will build a @@ -97,6 +96,9 @@ class StandardScalerModel @Since("1.3.0") ( @Since("1.3.0") def this(std: Vector) = this(std, null) + /** + * :: DeveloperApi :: + */ @Since("1.3.0") @DeveloperApi def setWithMean(withMean: Boolean): this.type = { @@ -105,6 +107,9 @@ class StandardScalerModel @Since("1.3.0") ( this } + /** + * :: DeveloperApi :: + */ @Since("1.3.0") @DeveloperApi def setWithStd(withStd: Boolean): this.type = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala index ca7385128d79a..9db725097ae90 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala @@ -53,7 +53,7 @@ trait VectorTransformer extends Serializable { } /** - * Applies transformation on an JavaRDD[Vector]. + * Applies transformation on a JavaRDD[Vector]. * * @param data JavaRDD[Vector] to be transformed. * @return transformed JavaRDD[Vector]. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 7e6c3679704c5..761996f44739a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -34,7 +34,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd._ -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom @@ -434,6 +434,9 @@ class Word2Vec extends Serializable with Logging { bcSyn1Global.unpersist(false) } newSentences.unpersist() + expTable.destroy() + bcVocab.destroy() + bcVocabHash.destroy() val wordArray = vocab.map(_.word) new Word2VecModel(wordArray.zipWithIndex.toMap, syn0Global) @@ -515,7 +518,7 @@ class Word2VecModel private[spark] ( } /** - * Find synonyms of a word + * Find synonyms of a word; do not include the word itself in results. * @param word a word * @param num number of synonyms to find * @return array of (word, cosineSimilarity) @@ -523,17 +526,34 @@ class Word2VecModel private[spark] ( @Since("1.1.0") def findSynonyms(word: String, num: Int): Array[(String, Double)] = { val vector = transform(word) - findSynonyms(vector, num) + findSynonyms(vector, num, Some(word)) } /** - * Find synonyms of the vector representation of a word + * Find synonyms of the vector representation of a word, possibly + * including any words in the model vocabulary whose vector respresentation + * is the supplied vector. * @param vector vector representation of a word * @param num number of synonyms to find * @return array of (word, cosineSimilarity) */ @Since("1.1.0") def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = { + findSynonyms(vector, num, None) + } + + /** + * Find synonyms of the vector representation of a word, rejecting + * words identical to the value of wordOpt, if one is supplied. 
+ * @param vector vector representation of a word + * @param num number of synonyms to find + * @param wordOpt optionally, a word to reject from the results list + * @return array of (word, cosineSimilarity) + */ + private def findSynonyms( + vector: Vector, + num: Int, + wordOpt: Option[String]): Array[(String, Double)] = { require(num > 0, "Number of similar words should > 0") // TODO: optimize top-k val fVector = vector.toArray.map(_.toFloat) @@ -560,12 +580,14 @@ class Word2VecModel private[spark] ( ind += 1 } - wordList.zip(cosVec) - .toSeq - .sortBy(-_._2) - .take(num + 1) - .tail - .toArray + val scored = wordList.zip(cosVec).toSeq.sortBy(-_._2) + + val filtered = wordOpt match { + case Some(w) => scored.take(num + 1).filter(tup => w != tup._1) + case None => scored + } + + filtered.take(num).toArray } /** @@ -609,9 +631,8 @@ object Word2VecModel extends Loader[Word2VecModel] { case class Data(word: String, vector: Array[Float]) def load(sc: SparkContext, path: String): Word2VecModel = { - val dataPath = Loader.dataPath(path) - val sqlContext = SQLContext.getOrCreate(sc) - val dataFrame = sqlContext.read.parquet(dataPath) + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() + val dataFrame = spark.read.parquet(Loader.dataPath(path)) // Check schema explicitly since erasure makes it hard to use match-case for checking. Loader.checkSchema[Data](dataFrame.schema) @@ -621,9 +642,7 @@ object Word2VecModel extends Loader[Word2VecModel] { } def save(sc: SparkContext, path: String, model: Map[String, Array[Float]]): Unit = { - - val sqlContext = SQLContext.getOrCreate(sc) - import sqlContext.implicits._ + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val vectorSize = model.values.head.length val numWords = model.size @@ -632,16 +651,18 @@ object Word2VecModel extends Loader[Word2VecModel] { ("vectorSize" -> vectorSize) ~ ("numWords" -> numWords))) sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) - // We want to partition the model in partitions of size 32MB - val partitionSize = (1L << 25) + // We want to partition the model in partitions smaller than + // spark.kryoserializer.buffer.max + val bufferSize = Utils.byteStringAsBytes( + spark.conf.get("spark.kryoserializer.buffer.max", "64m")) // We calculate the approximate size of the model - // We only calculate the array size, not considering - // the string size, the formula is: - // floatSize * numWords * vectorSize - val approxSize = 4L * numWords * vectorSize - val nPartitions = ((approxSize / partitionSize) + 1).toInt + // We only calculate the array size, considering an + // average string size of 15 bytes, the formula is: + // (floatSize * vectorSize + 15) * numWords + val approxSize = (4L * vectorSize + 15) * numWords + val nPartitions = ((approxSize / bufferSize) + 1).toInt val dataArray = model.toSeq.map { case (w, v) => Data(w, v) } - sc.parallelize(dataArray.toSeq, nPartitions).toDF().write.parquet(Loader.dataPath(path)) + spark.createDataFrame(dataArray).repartition(nPartitions).write.parquet(Loader.dataPath(path)) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala index 9a63cc29dacb5..3c26d2670841b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.fpm import scala.collection.JavaConverters._ import 
scala.reflect.ClassTag -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.internal.Logging @@ -28,14 +28,11 @@ import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset import org.apache.spark.rdd.RDD /** - * :: Experimental :: - * * Generates association rules from a [[RDD[FreqItemset[Item]]]. This method only generates * association rules which have a single item as the consequent. * */ @Since("1.5.0") -@Experimental class AssociationRules private[fpm] ( private var minConfidence: Double) extends Logging with Serializable { @@ -95,8 +92,6 @@ class AssociationRules private[fpm] ( object AssociationRules { /** - * :: Experimental :: - * * An association rule between sets of items. * @param antecedent hypotheses of the rule. Java users should call [[Rule#javaAntecedent]] * instead. @@ -106,7 +101,6 @@ object AssociationRules { * */ @Since("1.5.0") - @Experimental class Rule[Item] private[fpm] ( @Since("1.5.0") val antecedent: Array[Item], @Since("1.5.0") val consequent: Array[Item], diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala index 9166faa54de56..0f7fbe9556c5d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala @@ -37,7 +37,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.mllib.fpm.FPGrowth._ import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.{DataFrame, Row, SparkSession} import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel @@ -99,7 +99,7 @@ object FPGrowthModel extends Loader[FPGrowthModel[_]] { def save(model: FPGrowthModel[_], path: String): Unit = { val sc = model.freqItemsets.sparkContext - val sqlContext = SQLContext.getOrCreate(sc) + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val metadata = compact(render( ("class" -> thisClassName) ~ ("version" -> thisFormatVersion))) @@ -116,20 +116,20 @@ object FPGrowthModel extends Loader[FPGrowthModel[_]] { StructField("freq", LongType)) val schema = StructType(fields) val rowDataRDD = model.freqItemsets.map { x => - Row(x.items, x.freq) + Row(x.items.toSeq, x.freq) } - sqlContext.createDataFrame(rowDataRDD, schema).write.parquet(Loader.dataPath(path)) + spark.createDataFrame(rowDataRDD, schema).write.parquet(Loader.dataPath(path)) } def load(sc: SparkContext, path: String): FPGrowthModel[_] = { implicit val formats = DefaultFormats - val sqlContext = SQLContext.getOrCreate(sc) + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) assert(className == thisClassName) assert(formatVersion == thisFormatVersion) - val freqItemsets = sqlContext.read.parquet(Loader.dataPath(path)) + val freqItemsets = spark.read.parquet(Loader.dataPath(path)) val sample = freqItemsets.select("items").head().get(0) loadImpl(freqItemsets, sample) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPTree.scala index 1d2d777c00793..b0fa287473c3c 100644 --- 
a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPTree.scala @@ -126,7 +126,7 @@ private[fpm] object FPTree { def isRoot: Boolean = parent == null } - /** Summary of a item in an FP-Tree. */ + /** Summary of an item in an FP-Tree. */ private class Summary[T] extends Serializable { var count: Long = 0L val nodes: ListBuffer[Node[T]] = ListBuffer.empty diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index 4344ab1bade9a..7382000791cfb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -30,20 +30,18 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods.{compact, render} import org.apache.spark.SparkContext -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.internal.Logging import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.{DataFrame, Row, SparkSession} import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel /** - * :: Experimental :: - * * A parallel PrefixSpan algorithm to mine frequent sequential patterns. * The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns * Efficiently by Prefix-Projected Pattern Growth ([[http://doi.org/10.1109/ICDE.2001.914830]]). @@ -60,7 +58,6 @@ import org.apache.spark.storage.StorageLevel * @see [[https://en.wikipedia.org/wiki/Sequential_Pattern_Mining Sequential Pattern Mining * (Wikipedia)]] */ -@Experimental @Since("1.5.0") class PrefixSpan private ( private var minSupport: Double, @@ -230,7 +227,6 @@ class PrefixSpan private ( } -@Experimental @Since("1.5.0") object PrefixSpan extends Logging { @@ -616,7 +612,7 @@ object PrefixSpanModel extends Loader[PrefixSpanModel[_]] { def save(model: PrefixSpanModel[_], path: String): Unit = { val sc = model.freqSequences.sparkContext - val sqlContext = SQLContext.getOrCreate(sc) + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val metadata = compact(render( ("class" -> thisClassName) ~ ("version" -> thisFormatVersion))) @@ -635,18 +631,18 @@ object PrefixSpanModel extends Loader[PrefixSpanModel[_]] { val rowDataRDD = model.freqSequences.map { x => Row(x.sequence, x.freq) } - sqlContext.createDataFrame(rowDataRDD, schema).write.parquet(Loader.dataPath(path)) + spark.createDataFrame(rowDataRDD, schema).write.parquet(Loader.dataPath(path)) } def load(sc: SparkContext, path: String): PrefixSpanModel[_] = { implicit val formats = DefaultFormats - val sqlContext = SQLContext.getOrCreate(sc) + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) assert(className == thisClassName) assert(formatVersion == thisFormatVersion) - val freqSequences = sqlContext.read.parquet(Loader.dataPath(path)) + val freqSequences = spark.read.parquet(Loader.dataPath(path)) val sample = freqSequences.select("sequence").head().get(0) loadImpl(freqSequences, sample) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala index 5c12c9305b99c..4dd498cd91b4e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala @@ -19,7 +19,8 @@ package org.apache.spark.mllib.impl import scala.collection.mutable -import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path import org.apache.spark.SparkContext import org.apache.spark.internal.Logging @@ -160,21 +161,23 @@ private[mllib] abstract class PeriodicCheckpointer[T]( private def removeCheckpointFile(): Unit = { val old = checkpointQueue.dequeue() // Since the old checkpoint is not deleted by Spark, we manually delete it. - val fs = FileSystem.get(sc.hadoopConfiguration) - getCheckpointFiles(old).foreach(PeriodicCheckpointer.removeCheckpointFile(_, fs)) + getCheckpointFiles(old).foreach( + PeriodicCheckpointer.removeCheckpointFile(_, sc.hadoopConfiguration)) } } private[spark] object PeriodicCheckpointer extends Logging { /** Delete a checkpoint file, and log a warning if deletion fails. */ - def removeCheckpointFile(path: String, fs: FileSystem): Unit = { + def removeCheckpointFile(checkpointFile: String, conf: Configuration): Unit = { try { - fs.delete(new Path(path), true) + val path = new Path(checkpointFile) + val fs = path.getFileSystem(conf) + fs.delete(path, true) } catch { case e: Exception => logWarning("PeriodicCheckpointer could not remove old checkpoint file: " + - path) + checkpointFile) } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala index 20db6084d0e0d..80074897567eb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala @@ -87,7 +87,10 @@ private[mllib] class PeriodicGraphCheckpointer[VD, ED]( override protected def persist(data: Graph[VD, ED]): Unit = { if (data.vertices.getStorageLevel == StorageLevel.NONE) { - data.persist() + data.vertices.persist() + } + if (data.edges.getStorageLevel == StorageLevel.NONE) { + data.edges.persist() } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala index 6a85608706974..0cd68a633c0b5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala @@ -637,12 +637,16 @@ private[spark] object BLAS extends Serializable with Logging { val indEnd = Arows(rowCounter + 1) var sum = 0.0 var k = 0 - while (k < xNnz && i < indEnd) { + while (i < indEnd && k < xNnz) { if (xIndices(k) == Acols(i)) { sum += Avals(i) * xValues(k) + k += 1 + i += 1 + } else if (xIndices(k) < Acols(i)) { + k += 1 + } else { i += 1 } - k += 1 } yValues(rowCounter) = sum * alpha + beta * yValues(rowCounter) rowCounter += 1 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index 076cca6016ecb..4c39cf17f4271 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -20,6 +20,7 @@ package org.apache.spark.mllib.linalg import java.util.{Arrays, Random} import 
scala.collection.mutable.{ArrayBuffer, ArrayBuilder => MArrayBuilder, HashSet => MHashSet} +import scala.language.implicitConversions import breeze.linalg.{CSCMatrix => BSM, DenseMatrix => BDM, Matrix => BM} import com.github.fommil.netlib.BLAS.{getInstance => blas} @@ -74,7 +75,7 @@ sealed trait Matrix extends Serializable { def rowIter: Iterator[Vector] = this.transpose.colIter /** Converts to a breeze matrix. */ - private[mllib] def toBreeze: BM[Double] + private[mllib] def asBreeze: BM[Double] /** Gets the (i, j)-th element. */ @Since("1.3.0") @@ -117,11 +118,11 @@ sealed trait Matrix extends Serializable { } /** A human readable representation of the matrix */ - override def toString: String = toBreeze.toString() + override def toString: String = asBreeze.toString() /** A human readable representation of the matrix with maximum lines and width */ @Since("1.4.0") - def toString(maxLines: Int, maxLineWidth: Int): String = toBreeze.toString(maxLines, maxLineWidth) + def toString(maxLines: Int, maxLineWidth: Int): String = asBreeze.toString(maxLines, maxLineWidth) /** * Map the values of this matrix using a function. Generates a new matrix. Performs the @@ -163,7 +164,8 @@ sealed trait Matrix extends Serializable { * Convert this matrix to the new mllib-local representation. * This does NOT copy the data; it copies references. */ - private[spark] def asML: newlinalg.Matrix + @Since("2.0.0") + def asML: newlinalg.Matrix } private[spark] class MatrixUDT extends UserDefinedType[Matrix] { @@ -298,7 +300,7 @@ class DenseMatrix @Since("1.3.0") ( this(numRows, numCols, values, false) override def equals(o: Any): Boolean = o match { - case m: Matrix => toBreeze == m.toBreeze + case m: Matrix => asBreeze == m.asBreeze case _ => false } @@ -306,7 +308,7 @@ class DenseMatrix @Since("1.3.0") ( com.google.common.base.Objects.hashCode(numRows: Integer, numCols: Integer, toArray) } - private[mllib] def toBreeze: BM[Double] = { + private[mllib] def asBreeze: BM[Double] = { if (!isTransposed) { new BDM[Double](numRows, numCols, values) } else { @@ -426,7 +428,8 @@ class DenseMatrix @Since("1.3.0") ( } } - private[spark] override def asML: newlinalg.DenseMatrix = { + @Since("2.0.0") + override def asML: newlinalg.DenseMatrix = { new newlinalg.DenseMatrix(numRows, numCols, values, isTransposed) } } @@ -526,8 +529,11 @@ object DenseMatrix { matrix } - /** Convert new linalg type to spark.mllib type. Light copy; only copies references */ - private[spark] def fromML(m: newlinalg.DenseMatrix): DenseMatrix = { + /** + * Convert new linalg type to spark.mllib type. Light copy; only copies references + */ + @Since("2.0.0") + def fromML(m: newlinalg.DenseMatrix): DenseMatrix = { new DenseMatrix(m.numRows, m.numCols, m.values, m.isTransposed) } } @@ -566,10 +572,13 @@ class SparseMatrix @Since("1.3.0") ( require(values.length == rowIndices.length, "The number of row indices and values don't match! " + s"values.length: ${values.length}, rowIndices.length: ${rowIndices.length}") - // The Or statement is for the case when the matrix is transposed - require(colPtrs.length == numCols + 1 || colPtrs.length == numRows + 1, "The length of the " + - "column indices should be the number of columns + 1. 
Currently, colPointers.length: " + - s"${colPtrs.length}, numCols: $numCols") + if (isTransposed) { + require(colPtrs.length == numRows + 1, + s"Expecting ${numRows + 1} colPtrs when numRows = $numRows but got ${colPtrs.length}") + } else { + require(colPtrs.length == numCols + 1, + s"Expecting ${numCols + 1} colPtrs when numCols = $numCols but got ${colPtrs.length}") + } require(values.length == colPtrs.last, "The last value of colPtrs must equal the number of " + s"elements. values.length: ${values.length}, colPtrs.last: ${colPtrs.last}") @@ -601,13 +610,13 @@ class SparseMatrix @Since("1.3.0") ( values: Array[Double]) = this(numRows, numCols, colPtrs, rowIndices, values, false) override def equals(o: Any): Boolean = o match { - case m: Matrix => toBreeze == m.toBreeze + case m: Matrix => asBreeze == m.asBreeze case _ => false } - override def hashCode(): Int = toBreeze.hashCode + override def hashCode(): Int = asBreeze.hashCode - private[mllib] def toBreeze: BM[Double] = { + private[mllib] def asBreeze: BM[Double] = { if (!isTransposed) { new BSM[Double](values, numRows, numCols, colPtrs, rowIndices) } else { @@ -739,7 +748,8 @@ class SparseMatrix @Since("1.3.0") ( } } - private[spark] override def asML: newlinalg.SparseMatrix = { + @Since("2.0.0") + override def asML: newlinalg.SparseMatrix = { new newlinalg.SparseMatrix(numRows, numCols, colPtrs, rowIndices, values, isTransposed) } } @@ -917,8 +927,11 @@ object SparseMatrix { } } - /** Convert new linalg type to spark.mllib type. Light copy; only copies references */ - private[spark] def fromML(m: newlinalg.SparseMatrix): SparseMatrix = { + /** + * Convert new linalg type to spark.mllib type. Light copy; only copies references + */ + @Since("2.0.0") + def fromML(m: newlinalg.SparseMatrix): SparseMatrix = { new SparseMatrix(m.numRows, m.numCols, m.colPtrs, m.rowIndices, m.values, m.isTransposed) } } @@ -1204,11 +1217,35 @@ object Matrices { } } - /** Convert new linalg type to spark.mllib type. Light copy; only copies references */ - private[spark] def fromML(m: newlinalg.Matrix): Matrix = m match { + /** + * Convert new linalg type to spark.mllib type. Light copy; only copies references + */ + @Since("2.0.0") + def fromML(m: newlinalg.Matrix): Matrix = m match { case dm: newlinalg.DenseMatrix => DenseMatrix.fromML(dm) case sm: newlinalg.SparseMatrix => SparseMatrix.fromML(sm) } } + +/** + * Implicit methods available in Scala for converting [[org.apache.spark.mllib.linalg.Matrix]] to + * [[org.apache.spark.ml.linalg.Matrix]] and vice versa. 
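A short, hedged sketch of the matrix conversions this change makes public (`asML` and `Matrices.fromML`, both now annotated `@Since("2.0.0")`); the implicit conversions in `MatrixImplicits` just below remain `private[spark]`, so external code would call these methods explicitly.

```scala
import org.apache.spark.ml.{linalg => newlinalg}
import org.apache.spark.mllib.linalg.{Matrices, Matrix}

val oldMat: Matrix = Matrices.dense(2, 2, Array(1.0, 2.0, 3.0, 4.0))

// mllib -> ml: a light copy that only copies references to the underlying array.
val newMat: newlinalg.Matrix = oldMat.asML

// ml -> mllib: round-trip back to the spark.mllib type.
val roundTripped: Matrix = Matrices.fromML(newMat)
assert(roundTripped == oldMat)
```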
+ */ +private[spark] object MatrixImplicits { + + implicit def mllibMatrixToMLMatrix(m: Matrix): newlinalg.Matrix = m.asML + + implicit def mllibDenseMatrixToMLDenseMatrix(m: DenseMatrix): newlinalg.DenseMatrix = m.asML + + implicit def mllibSparseMatrixToMLSparseMatrix(m: SparseMatrix): newlinalg.SparseMatrix = m.asML + + implicit def mlMatrixToMLlibMatrix(m: newlinalg.Matrix): Matrix = Matrices.fromML(m) + + implicit def mlDenseMatrixToMLlibDenseMatrix(m: newlinalg.DenseMatrix): DenseMatrix = + Matrices.fromML(m).asInstanceOf[DenseMatrix] + + implicit def mlSparseMatrixToMLlibSparseMatrix(m: newlinalg.SparseMatrix): SparseMatrix = + Matrices.fromML(m).asInstanceOf[SparseMatrix] +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala index 4591cb88ef152..8024b1c0031fe 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.linalg -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since /** * Represents singular value decomposition (SVD) factors. @@ -26,10 +26,8 @@ import org.apache.spark.annotation.{Experimental, Since} case class SingularValueDecomposition[UType, VType](U: UType, s: Vector, V: VType) /** - * :: Experimental :: * Represents QR factors. */ @Since("1.5.0") -@Experimental case class QRDecomposition[QType, RType](Q: QType, R: RType) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 132e54a8c3de2..91f065831c804 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -22,6 +22,7 @@ import java.util import scala.annotation.varargs import scala.collection.JavaConverters._ +import scala.language.implicitConversions import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV} import org.json4s.DefaultFormats @@ -102,14 +103,14 @@ sealed trait Vector extends Serializable { /** * Converts the instance to a breeze vector. */ - private[spark] def toBreeze: BV[Double] + private[spark] def asBreeze: BV[Double] /** * Gets the value of the ith element. * @param i index */ @Since("1.1.0") - def apply(i: Int): Double = toBreeze(i) + def apply(i: Int): Double = asBreeze(i) /** * Makes a deep copy of this vector. @@ -185,7 +186,8 @@ sealed trait Vector extends Serializable { * Convert this vector to the new mllib-local representation. * This does NOT copy the data; it copies references. */ - private[spark] def asML: newlinalg.Vector + @Since("2.0.0") + def asML: newlinalg.Vector } /** @@ -580,8 +582,11 @@ object Vectors { /** Max number of nonzero entries used in computing hash code. */ private[linalg] val MAX_HASH_NNZ = 128 - /** Convert new linalg type to spark.mllib type. Light copy; only copies references */ - private[spark] def fromML(v: newlinalg.Vector): Vector = v match { + /** + * Convert new linalg type to spark.mllib type. 
Light copy; only copies references + */ + @Since("2.0.0") + def fromML(v: newlinalg.Vector): Vector = v match { case dv: newlinalg.DenseVector => DenseVector.fromML(dv) case sv: newlinalg.SparseVector => @@ -605,7 +610,7 @@ class DenseVector @Since("1.0.0") ( @Since("1.0.0") override def toArray: Array[Double] = values - private[spark] override def toBreeze: BV[Double] = new BDV[Double](values) + private[spark] override def asBreeze: BV[Double] = new BDV[Double](values) @Since("1.0.0") override def apply(i: Int): Double = values(i) @@ -703,7 +708,8 @@ class DenseVector @Since("1.0.0") ( compact(render(jValue)) } - private[spark] override def asML: newlinalg.DenseVector = { + @Since("2.0.0") + override def asML: newlinalg.DenseVector = { new newlinalg.DenseVector(values) } } @@ -715,14 +721,17 @@ object DenseVector { @Since("1.3.0") def unapply(dv: DenseVector): Option[Array[Double]] = Some(dv.values) - /** Convert new linalg type to spark.mllib type. Light copy; only copies references */ - private[spark] def fromML(v: newlinalg.DenseVector): DenseVector = { + /** + * Convert new linalg type to spark.mllib type. Light copy; only copies references + */ + @Since("2.0.0") + def fromML(v: newlinalg.DenseVector): DenseVector = { new DenseVector(v.values) } } /** - * A sparse vector represented by an index array and an value array. + * A sparse vector represented by an index array and a value array. * * @param size size of the vector. * @param indices index array, assume to be strictly increasing. @@ -761,7 +770,7 @@ class SparseVector @Since("1.0.0") ( new SparseVector(size, indices.clone(), values.clone()) } - private[spark] override def toBreeze: BV[Double] = new BSV[Double](indices, values, size) + private[spark] override def asBreeze: BV[Double] = new BSV[Double](indices, values, size) @Since("1.6.0") override def foreachActive(f: (Int, Double) => Unit): Unit = { @@ -910,7 +919,8 @@ class SparseVector @Since("1.0.0") ( compact(render(jValue)) } - private[spark] override def asML: newlinalg.SparseVector = { + @Since("2.0.0") + override def asML: newlinalg.SparseVector = { new newlinalg.SparseVector(size, indices, values) } } @@ -921,8 +931,32 @@ object SparseVector { def unapply(sv: SparseVector): Option[(Int, Array[Int], Array[Double])] = Some((sv.size, sv.indices, sv.values)) - /** Convert new linalg type to spark.mllib type. Light copy; only copies references */ - private[spark] def fromML(v: newlinalg.SparseVector): SparseVector = { + /** + * Convert new linalg type to spark.mllib type. Light copy; only copies references + */ + @Since("2.0.0") + def fromML(v: newlinalg.SparseVector): SparseVector = { new SparseVector(v.size, v.indices, v.values) } } + +/** + * Implicit methods available in Scala for converting [[org.apache.spark.mllib.linalg.Vector]] to + * [[org.apache.spark.ml.linalg.Vector]] and vice versa. 
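The analogous vector-side sketch, under the same assumptions (`asML` and `Vectors.fromML` become public here; `VectorImplicits` below stays `private[spark]`).

```scala
import org.apache.spark.ml.{linalg => newlinalg}
import org.apache.spark.mllib.linalg.{Vector, Vectors}

val oldVec: Vector = Vectors.sparse(4, Array(0, 3), Array(1.0, -2.0))

val newVec: newlinalg.Vector = oldVec.asML        // reference copy, no data copied
val roundTripped: Vector = Vectors.fromML(newVec) // back to the spark.mllib type
assert(roundTripped == oldVec)
```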
+ */ +private[spark] object VectorImplicits { + + implicit def mllibVectorToMLVector(v: Vector): newlinalg.Vector = v.asML + + implicit def mllibDenseVectorToMLDenseVector(v: DenseVector): newlinalg.DenseVector = v.asML + + implicit def mllibSparseVectorToMLSparseVector(v: SparseVector): newlinalg.SparseVector = v.asML + + implicit def mlVectorToMLlibVector(v: newlinalg.Vector): Vector = Vectors.fromML(v) + + implicit def mlDenseVectorToMLlibDenseVector(v: newlinalg.DenseVector): DenseVector = + Vectors.fromML(v).asInstanceOf[DenseVector] + + implicit def mlSparseVectorToMLlibSparseVector(v: newlinalg.SparseVector): SparseVector = + Vectors.fromML(v).asInstanceOf[SparseVector] +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala index 580d7a98fb362..9782350587061 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala @@ -275,7 +275,7 @@ class BlockMatrix @Since("1.3.0") ( val rows = blocks.flatMap { case ((blockRowIdx, blockColIdx), mat) => mat.rowIter.zipWithIndex.map { case (vector, rowIdx) => - blockRowIdx * rowsPerBlock + rowIdx -> (blockColIdx, vector.toBreeze) + blockRowIdx * rowsPerBlock + rowIdx -> (blockColIdx, vector.asBreeze) } }.groupByKey().map { case (rowIdx, vectors) => val numberNonZeroPerRow = vectors.map(_._2.activeSize).sum.toDouble / cols.toDouble @@ -288,7 +288,7 @@ class BlockMatrix @Since("1.3.0") ( vectors.foreach { case (blockColIdx: Int, vec: BV[Double]) => val offset = colsPerBlock * blockColIdx - wholeVector(offset until offset + colsPerBlock) := vec + wholeVector(offset until Math.min(cols, offset + colsPerBlock)) := vec } new IndexedRow(rowIdx, Vectors.fromBreeze(wholeVector)) } @@ -367,12 +367,12 @@ class BlockMatrix @Since("1.3.0") ( } if (a.isEmpty) { val zeroBlock = BM.zeros[Double](b.head.numRows, b.head.numCols) - val result = binMap(zeroBlock, b.head.toBreeze) + val result = binMap(zeroBlock, b.head.asBreeze) new MatrixBlock((blockRowIndex, blockColIndex), Matrices.fromBreeze(result)) } else if (b.isEmpty) { new MatrixBlock((blockRowIndex, blockColIndex), a.head) } else { - val result = binMap(a.head.toBreeze, b.head.toBreeze) + val result = binMap(a.head.asBreeze, b.head.asBreeze) new MatrixBlock((blockRowIndex, blockColIndex), Matrices.fromBreeze(result)) } } @@ -426,16 +426,21 @@ class BlockMatrix @Since("1.3.0") ( partitioner: GridPartitioner): (BlockDestinations, BlockDestinations) = { val leftMatrix = blockInfo.keys.collect() // blockInfo should already be cached val rightMatrix = other.blocks.keys.collect() + + val rightCounterpartsHelper = rightMatrix.groupBy(_._1).mapValues(_.map(_._2)) val leftDestinations = leftMatrix.map { case (rowIndex, colIndex) => - val rightCounterparts = rightMatrix.filter(_._1 == colIndex) - val partitions = rightCounterparts.map(b => partitioner.getPartition((rowIndex, b._2))) + val rightCounterparts = rightCounterpartsHelper.getOrElse(colIndex, Array()) + val partitions = rightCounterparts.map(b => partitioner.getPartition((rowIndex, b))) ((rowIndex, colIndex), partitions.toSet) }.toMap + + val leftCounterpartsHelper = leftMatrix.groupBy(_._2).mapValues(_.map(_._1)) val rightDestinations = rightMatrix.map { case (rowIndex, colIndex) => - val leftCounterparts = leftMatrix.filter(_._2 == rowIndex) - val partitions = leftCounterparts.map(b => 
partitioner.getPartition((b._1, colIndex))) + val leftCounterparts = leftCounterpartsHelper.getOrElse(rowIndex, Array()) + val partitions = leftCounterparts.map(b => partitioner.getPartition((b, colIndex))) ((rowIndex, colIndex), partitions.toSet) }.toMap + (leftDestinations, rightDestinations) } @@ -479,7 +484,7 @@ class BlockMatrix @Since("1.3.0") ( case _ => throw new SparkException(s"Unrecognized matrix type ${rightBlock.getClass}.") } - ((leftRowIndex, rightColIndex), C.toBreeze) + ((leftRowIndex, rightColIndex), C.asBreeze) } } }.reduceByKey(resultPartitioner, (a, b) => a + b).mapValues(Matrices.fromBreeze) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala index 97b03b340f20e..008b03d1cc334 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala @@ -24,7 +24,7 @@ import org.apache.spark.mllib.linalg.{Matrix, SparseMatrix, Vectors} import org.apache.spark.rdd.RDD /** - * Represents an entry in an distributed matrix. + * Represents an entry in a distributed matrix. * @param i row index * @param j column index * @param value value of the entry diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 4b8ed301eb3cb..ec32e37afb792 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -92,7 +92,7 @@ class RowMatrix @Since("1.0.0") ( val vbr = rows.context.broadcast(v) rows.treeAggregate(BDV.zeros[Double](n))( seqOp = (U, r) => { - val rBrz = r.toBreeze + val rBrz = r.asBreeze val a = rBrz.dot(vbr.value) rBrz match { // use specialized axpy for better performance @@ -250,12 +250,12 @@ class RowMatrix @Since("1.0.0") ( val (sigmaSquares: BDV[Double], u: BDM[Double]) = computeMode match { case SVDMode.LocalARPACK => require(k < n, s"k must be smaller than n in local-eigs mode but got k=$k and n=$n.") - val G = computeGramianMatrix().toBreeze.asInstanceOf[BDM[Double]] + val G = computeGramianMatrix().asBreeze.asInstanceOf[BDM[Double]] EigenValueDecomposition.symmetricEigs(v => G * v, n, k, tol, maxIter) case SVDMode.LocalLAPACK => // breeze (v0.10) svd latent constraint, 7 * n * n + 4 * n < Int.MaxValue require(n < 17515, s"$n exceeds the breeze svd capability") - val G = computeGramianMatrix().toBreeze.asInstanceOf[BDM[Double]] + val G = computeGramianMatrix().asBreeze.asInstanceOf[BDM[Double]] val brzSvd.SVD(uFull: BDM[Double], sigmaSquaresFull: BDV[Double], _) = brzSvd(G) (sigmaSquaresFull, uFull) case SVDMode.DistARPACK => @@ -338,7 +338,7 @@ class RowMatrix @Since("1.0.0") ( // large but Cov(X, Y) is small, but it is good for sparse computation. // TODO: find a fast and stable way for sparse data. 
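For context on the `RowMatrix` statistics touched here, a hedged usage sketch with toy data (`sc` is an existing SparkContext, assumed); the covariance is derived from the Gramian matrix computed above.

```scala
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.linalg.distributed.RowMatrix

// Toy 3 x 3 row matrix; sc is an existing SparkContext (assumed).
val rows = sc.parallelize(Seq(
  Vectors.dense(1.0, 2.0, 3.0),
  Vectors.dense(4.0, 5.0, 6.0),
  Vectors.dense(7.0, 8.0, 10.0)))
val mat = new RowMatrix(rows)

val cov = mat.computeCovariance()            // local Matrix built via the Gramian, as above
val svd = mat.computeSVD(2, computeU = true) // k = 2 singular values/vectors
val qr  = mat.tallSkinnyQR(computeQ = true)  // QR of a tall-and-skinny matrix
```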
- val G = computeGramianMatrix().toBreeze + val G = computeGramianMatrix().asBreeze var i = 0 var j = 0 @@ -381,7 +381,7 @@ class RowMatrix @Since("1.0.0") ( val n = numCols().toInt require(k > 0 && k <= n, s"k = $k out of range (0, n = $n]") - val Cov = computeCovariance().toBreeze.asInstanceOf[BDM[Double]] + val Cov = computeCovariance().asBreeze.asInstanceOf[BDM[Double]] val brzSvd.SVD(u: BDM[Double], s: BDV[Double], _) = brzSvd(Cov) @@ -436,14 +436,14 @@ class RowMatrix @Since("1.0.0") ( require(B.isInstanceOf[DenseMatrix], s"Only support dense matrix at this time but found ${B.getClass.getName}.") - val Bb = rows.context.broadcast(B.toBreeze.asInstanceOf[BDM[Double]].toDenseVector.toArray) + val Bb = rows.context.broadcast(B.asBreeze.asInstanceOf[BDM[Double]].toDenseVector.toArray) val AB = rows.mapPartitions { iter => val Bi = Bb.value iter.map { row => val v = BDV.zeros[Double](k) var i = 0 while (i < k) { - v(i) = row.toBreeze.dot(new BDV(Bi, i * n, 1, n)) + v(i) = row.asBreeze.dot(new BDV(Bi, i * n, 1, n)) i += 1 } Vectors.fromBreeze(v) @@ -537,21 +537,22 @@ class RowMatrix @Since("1.0.0") ( def tallSkinnyQR(computeQ: Boolean = false): QRDecomposition[RowMatrix, Matrix] = { val col = numCols().toInt // split rows horizontally into smaller matrices, and compute QR for each of them - val blockQRs = rows.glom().map { partRows => + val blockQRs = rows.retag(classOf[Vector]).glom().filter(_.length != 0).map { partRows => val bdm = BDM.zeros[Double](partRows.length, col) var i = 0 partRows.foreach { row => - bdm(i, ::) := row.toBreeze.t + bdm(i, ::) := row.asBreeze.t i += 1 } breeze.linalg.qr.reduced(bdm).r } // combine the R part from previous results vertically into a tall matrix - val combinedR = blockQRs.treeReduce{ (r1, r2) => + val combinedR = blockQRs.treeReduce { (r1, r2) => val stackedR = BDM.vertcat(r1, r2) breeze.linalg.qr.reduced(stackedR).r } + val finalR = Matrices.fromBreeze(combinedR.toDenseMatrix) val finalQ = if (computeQ) { try { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala index a67ea836e5681..f372355005656 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala @@ -21,7 +21,7 @@ import scala.collection.mutable.ArrayBuffer import breeze.linalg.{norm, DenseVector => BDV} -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.internal.Logging import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.rdd.RDD @@ -53,11 +53,9 @@ class GradientDescent private[spark] (private var gradient: Gradient, private va } /** - * :: Experimental :: * Set fraction of data to be used for each SGD iteration. 
* Default 1.0 (corresponding to deterministic/classical gradient descent) */ - @Experimental def setMiniBatchFraction(fraction: Double): this.type = { require(fraction > 0 && fraction <= 1.0, s"Fraction for mini-batch SGD must be in range (0, 1] but got ${fraction}") @@ -197,6 +195,11 @@ object GradientDescent extends Logging { "< 1.0 can be unstable because of the stochasticity in sampling.") } + if (numIterations * miniBatchFraction < 1.0) { + logWarning("Not all examples will be used if numIterations * miniBatchFraction < 1.0: " + + s"numIterations=$numIterations and miniBatchFraction=$miniBatchFraction") + } + val stochasticLossHistory = new ArrayBuffer[Double](numIterations) // Record previous weight and current one to calculate solution vector difference @@ -296,8 +299,8 @@ object GradientDescent extends Logging { currentWeights: Vector, convergenceTol: Double): Boolean = { // To compare with convergence tolerance. - val previousBDV = previousWeights.toBreeze.toDenseVector - val currentBDV = currentWeights.toBreeze.toDenseVector + val previousBDV = previousWeights.asBreeze.toDenseVector + val currentBDV = currentWeights.asBreeze.toDenseVector // This represents the difference of updated weights in the iteration. val solutionVecDiff: Double = norm(previousBDV - currentBDV) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala index 74e2cad76c8f5..fd09f35277a09 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala @@ -200,7 +200,7 @@ object LBFGS extends Logging { val lbfgs = new BreezeLBFGS[BDV[Double]](maxNumIterations, numCorrections, convergenceTol) val states = - lbfgs.iterations(new CachedDiffFunction(costFun), initialWeights.toBreeze.toDenseVector) + lbfgs.iterations(new CachedDiffFunction(costFun), initialWeights.asBreeze.toDenseVector) /** * NOTE: lossSum and loss is computed using the weights from the previous iteration @@ -212,6 +212,12 @@ object LBFGS extends Logging { state = states.next() } lossHistory += state.value + + if (!state.actuallyConverged) { + logWarning("LBFGS training finished but the result " + + s"is not converged because: ${state.convergedReason.get.reason}") + } + val weights = Vectors.fromBreeze(state.x) val lossHistoryArray = lossHistory.result() @@ -281,7 +287,7 @@ object LBFGS extends Logging { // gradientTotal = gradientSum / numExamples + gradientTotal axpy(1.0 / numExamples, gradientSum, gradientTotal) - (loss, gradientTotal.toBreeze.asInstanceOf[BDV[Double]]) + (loss, gradientTotal.asBreeze.asInstanceOf[BDV[Double]]) } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala index 03c01e0553d78..67d484575db52 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala @@ -75,8 +75,8 @@ class SimpleUpdater extends Updater { iter: Int, regParam: Double): (Vector, Double) = { val thisIterStepSize = stepSize / math.sqrt(iter) - val brzWeights: BV[Double] = weightsOld.toBreeze.toDenseVector - brzAxpy(-thisIterStepSize, gradient.toBreeze, brzWeights) + val brzWeights: BV[Double] = weightsOld.asBreeze.toDenseVector + brzAxpy(-thisIterStepSize, gradient.asBreeze, brzWeights) (Vectors.fromBreeze(brzWeights), 0) } @@ -87,7 +87,7 @@ class SimpleUpdater 
extends Updater { * Updater for L1 regularized problems. * R(w) = ||w||_1 * Uses a step-size decreasing with the square root of the number of iterations. - + * * Instead of subgradient of the regularizer, the proximal operator for the * L1 regularization is applied after the gradient step. This is known to * result in better sparsity of the intermediate solution. @@ -111,8 +111,8 @@ class L1Updater extends Updater { regParam: Double): (Vector, Double) = { val thisIterStepSize = stepSize / math.sqrt(iter) // Take gradient step - val brzWeights: BV[Double] = weightsOld.toBreeze.toDenseVector - brzAxpy(-thisIterStepSize, gradient.toBreeze, brzWeights) + val brzWeights: BV[Double] = weightsOld.asBreeze.toDenseVector + brzAxpy(-thisIterStepSize, gradient.asBreeze, brzWeights) // Apply proximal operator (soft thresholding) val shrinkageVal = regParam * thisIterStepSize var i = 0 @@ -146,9 +146,9 @@ class SquaredL2Updater extends Updater { // w' = w - thisIterStepSize * (gradient + regParam * w) // w' = (1 - thisIterStepSize * regParam) * w - thisIterStepSize * gradient val thisIterStepSize = stepSize / math.sqrt(iter) - val brzWeights: BV[Double] = weightsOld.toBreeze.toDenseVector + val brzWeights: BV[Double] = weightsOld.asBreeze.toDenseVector brzWeights :*= (1.0 - thisIterStepSize * regParam) - brzAxpy(-thisIterStepSize, gradient.toBreeze, brzWeights) + brzAxpy(-thisIterStepSize, gradient.asBreeze, brzWeights) val norm = brzNorm(brzWeights, 2.0) (Vectors.fromBreeze(brzWeights), 0.5 * regParam * norm * norm) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/package-info.java b/mllib/src/main/scala/org/apache/spark/mllib/package-info.java index 5962efa96baf3..72b71b7cd9b14 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/package-info.java +++ b/mllib/src/main/scala/org/apache/spark/mllib/package-info.java @@ -16,6 +16,26 @@ */ /** - * Spark's machine learning library. + * RDD-based machine learning APIs (in maintenance mode). + * + * The spark.mllib package is in maintenance mode as of the Spark 2.0.0 release to + * encourage migration to the DataFrame-based APIs under the spark.ml package. + * While in maintenance mode, + *
+ *   - no new features in the RDD-based spark.mllib package will be accepted, unless
+ *     they block implementing new features in the DataFrame-based spark.ml package;
+ *   - bug fixes in the RDD-based APIs will still be accepted.
    + * + * The developers will continue adding more features to the DataFrame-based APIs in the 2.x series + * to reach feature parity with the RDD-based APIs. + * And once we reach feature parity, this package will be deprecated. + * + * @see SPARK-4591 to + * track the progress of feature parity */ package org.apache.spark.mllib; diff --git a/mllib/src/main/scala/org/apache/spark/mllib/package.scala b/mllib/src/main/scala/org/apache/spark/mllib/package.scala index 5c2b2160c030e..9810b6f668064 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/package.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/package.scala @@ -18,6 +18,21 @@ package org.apache.spark /** - * Spark's machine learning library. + * RDD-based machine learning APIs (in maintenance mode). + * + * The `spark.mllib` package is in maintenance mode as of the Spark 2.0.0 release to encourage + * migration to the DataFrame-based APIs under the [[org.apache.spark.ml]] package. + * While in maintenance mode, + * + * - no new features in the RDD-based `spark.mllib` package will be accepted, unless they block + * implementing new features in the DataFrame-based `spark.ml` package; + * - bug fixes in the RDD-based APIs will still be accepted. + * + * The developers will continue adding more features to the DataFrame-based APIs in the 2.x series + * to reach feature parity with the RDD-based APIs. + * And once we reach feature parity, this package will be deprecated. + * + * @see [[https://issues.apache.org/jira/browse/SPARK-4591 SPARK-4591]] to track the progress of + * feature parity */ package object mllib diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala index 274ac7c99553b..5d61796f1de60 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala @@ -23,7 +23,7 @@ import javax.xml.transform.stream.StreamResult import org.jpmml.model.JAXBUtil import org.apache.spark.SparkContext -import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.pmml.export.PMMLModelExportFactory /** @@ -45,20 +45,16 @@ trait PMMLExportable { } /** - * :: Experimental :: * Export the model to a local file in PMML format */ - @Experimental @Since("1.4.0") def toPMML(localPath: String): Unit = { toPMML(new StreamResult(new File(localPath))) } /** - * :: Experimental :: * Export the model to a directory on a distributed file system in PMML format */ - @Experimental @Since("1.4.0") def toPMML(sc: SparkContext, path: String): Unit = { val pmml = toPMML() @@ -66,20 +62,16 @@ trait PMMLExportable { } /** - * :: Experimental :: * Export the model to the OutputStream in PMML format */ - @Experimental @Since("1.4.0") def toPMML(outputStream: OutputStream): Unit = { toPMML(new StreamResult(outputStream)) } /** - * :: Experimental :: * Export the model to a String in PMML format */ - @Experimental @Since("1.4.0") def toPMML(): String = { val writer = new StringWriter diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala index b0a716936ae6f..c2bc1f17ccd58 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala @@ -418,6 +418,7 @@ object RandomRDDs { } /** + * :: 
DeveloperApi :: * [[RandomRDDs#randomJavaRDD]] with the default seed. */ @DeveloperApi @@ -431,6 +432,7 @@ object RandomRDDs { } /** + * :: DeveloperApi :: * [[RandomRDDs#randomJavaRDD]] with the default seed & numPartitions */ @DeveloperApi @@ -854,6 +856,7 @@ object RandomRDDs { } /** + * :: DeveloperApi :: * Java-friendly version of [[RandomRDDs#randomVectorRDD]]. */ @DeveloperApi @@ -869,6 +872,7 @@ object RandomRDDs { } /** + * :: DeveloperApi :: * [[RandomRDDs#randomJavaVectorRDD]] with the default seed. */ @DeveloperApi @@ -883,6 +887,7 @@ object RandomRDDs { } /** + * :: DeveloperApi :: * [[RandomRDDs#randomJavaVectorRDD]] with the default number of partitions and the default seed. */ @DeveloperApi diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctions.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctions.scala index 1b93e2d764c69..e28e1af5b0a26 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctions.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctions.scala @@ -25,6 +25,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.util.BoundedPriorityQueue /** + * :: DeveloperApi :: * Machine learning specific Pair RDD functions. */ @DeveloperApi @@ -46,10 +47,13 @@ class MLPairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) extends Se combOp = (queue1, queue2) => { queue1 ++= queue2 } - ).mapValues(_.toArray.sorted(ord.reverse)) // This is an min-heap, so we reverse the order. + ).mapValues(_.toArray.sorted(ord.reverse)) // This is a min-heap, so we reverse the order. } } +/** + * :: DeveloperApi :: + */ @DeveloperApi object MLPairRDDFunctions { /** Implicit conversion from a pair RDD to MLPairRDDFunctions. */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala index e8a937ffcb96f..0f7857b8d8627 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala @@ -24,6 +24,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD /** + * :: DeveloperApi :: * Machine learning specific RDD functions. */ @DeveloperApi @@ -53,6 +54,9 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) extends Serializable { } +/** + * :: DeveloperApi :: + */ @DeveloperApi object RDDFunctions { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index 467cb83cd1662..cc9ee15738ad6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -216,6 +216,7 @@ class ALS private ( } /** + * :: DeveloperApi :: * Set period (in iterations) between checkpoints (default = 10). Checkpointing helps with * recovery (when nodes fail) and StackOverflow exceptions caused by long lineage. 
It also helps * with eliminating temporary shuffle files on disk, which can be important when there are many diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index 6f780b0da71f5..c642573ccba6d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -37,7 +37,7 @@ import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.rdd.MLPairRDDFunctions._ import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.storage.StorageLevel /** @@ -354,8 +354,8 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] { */ def save(model: MatrixFactorizationModel, path: String): Unit = { val sc = model.userFeatures.sparkContext - val sqlContext = SQLContext.getOrCreate(sc) - import sqlContext.implicits._ + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() + import spark.implicits._ val metadata = compact(render( ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("rank" -> model.rank))) sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path)) @@ -365,16 +365,16 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] { def load(sc: SparkContext, path: String): MatrixFactorizationModel = { implicit val formats = DefaultFormats - val sqlContext = SQLContext.getOrCreate(sc) + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val (className, formatVersion, metadata) = loadMetadata(sc, path) assert(className == thisClassName) assert(formatVersion == thisFormatVersion) val rank = (metadata \ "rank").extract[Int] - val userFeatures = sqlContext.read.parquet(userPath(path)).rdd.map { + val userFeatures = spark.read.parquet(userPath(path)).rdd.map { case Row(id: Int, features: Seq[_]) => (id, features.asInstanceOf[Seq[Double]].toArray) } - val productFeatures = sqlContext.read.parquet(productPath(path)).rdd.map { + val productFeatures = spark.read.parquet(productPath(path)).rdd.map { case Row(id: Int, features: Seq[_]) => (id, features.asInstanceOf[Seq[Double]].toArray) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala index abdd7981970fa..1cd6f2a8969a6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala @@ -34,7 +34,7 @@ import org.apache.spark.api.java.{JavaDoubleRDD, JavaRDD} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession /** * Regression model for isotonic regression. 
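The `SQLContext` to `SparkSession` migration above recurs across these save/load helpers; a condensed sketch of the pattern follows (the helper names and paths are placeholders, and `builder().sparkContext(sc)` is usable only because this code lives inside Spark itself).

```scala
import org.apache.spark.SparkContext
import org.apache.spark.sql.{Row, SparkSession}

// Hypothetical helpers mirroring the save/load changes in this patch.
def saveFeatures(sc: SparkContext, path: String, features: Seq[(Int, Array[Double])]): Unit = {
  // Before: val sqlContext = SQLContext.getOrCreate(sc); import sqlContext.implicits._
  val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
  import spark.implicits._
  sc.parallelize(features).toDF("id", "features").write.parquet(path)
}

def loadFeatures(sc: SparkContext, path: String): Array[(Int, Array[Double])] = {
  val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
  spark.read.parquet(path).rdd.map {
    case Row(id: Int, features: Seq[_]) => (id, features.asInstanceOf[Seq[Double]].toArray)
  }.collect()
}
```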
@@ -185,21 +185,21 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] { boundaries: Array[Double], predictions: Array[Double], isotonic: Boolean): Unit = { - val sqlContext = SQLContext.getOrCreate(sc) + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() val metadata = compact(render( ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("isotonic" -> isotonic))) sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path)) - sqlContext.createDataFrame( + spark.createDataFrame( boundaries.toSeq.zip(predictions).map { case (b, p) => Data(b, p) } ).write.parquet(dataPath(path)) } def load(sc: SparkContext, path: String): (Array[Double], Array[Double]) = { - val sqlContext = SQLContext.getOrCreate(sc) - val dataRDD = sqlContext.read.parquet(dataPath(path)) + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() + val dataRDD = spark.read.parquet(dataPath(path)) checkSchema[Data](dataRDD.schema) val dataArray = dataRDD.select("boundary", "prediction").collect() @@ -221,7 +221,7 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] { val (boundaries, predictions) = SaveLoadV1_0.load(sc, path) new IsotonicRegressionModel(boundaries, predictions, isotonic) case _ => throw new Exception( - s"IsotonicRegressionModel.load did not recognize model with (className, format version):" + + s"IsotonicRegressionModel.load did not recognize model with (className, format version): " + s"($loadedClassName, $version). Supported:\n" + s" ($classNameV1_0, 1.0)" ) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala index 45540f0c5c4ce..f082b16b95e81 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala @@ -20,6 +20,7 @@ package org.apache.spark.mllib.regression import scala.beans.BeanInfo import org.apache.spark.annotation.Since +import org.apache.spark.ml.feature.{LabeledPoint => NewLabeledPoint} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.NumericParser import org.apache.spark.SparkException @@ -38,6 +39,10 @@ case class LabeledPoint @Since("1.0.0") ( override def toString: String = { s"($label,$features)" } + + private[spark] def asML: NewLabeledPoint = { + NewLabeledPoint(label, features.asML) + } } /** @@ -67,4 +72,8 @@ object LabeledPoint { LabeledPoint(label, features) } } + + private[spark] def fromML(point: NewLabeledPoint): LabeledPoint = { + LabeledPoint(point.label, Vectors.fromML(point.features)) + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala index ef8c80f0cb807..cef1b4f51b843 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala @@ -44,7 +44,7 @@ class LassoModel @Since("1.1.0") ( dataMatrix: Vector, weightMatrix: Vector, intercept: Double): Double = { - weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept + weightMatrix.asBreeze.dot(dataMatrix.asBreeze) + intercept } @Since("1.3.0") @@ -85,9 +85,7 @@ object LassoModel extends Loader[LassoModel] { * See also the documentation for the precise formulation. */ @Since("0.8.0") -@deprecated("Use ml.regression.LinearRegression with elasticNetParam = 1.0. 
Note the default " + - "regParam is 0.01 for LassoWithSGD, but is 0.0 for LinearRegression.", "2.0.0") -class LassoWithSGD private ( +class LassoWithSGD private[mllib] ( private var stepSize: Double, private var numIterations: Int, private var regParam: Double, @@ -108,6 +106,8 @@ class LassoWithSGD private ( * regParam: 0.01, miniBatchFraction: 1.0}. */ @Since("0.8.0") + @deprecated("Use ml.regression.LinearRegression with elasticNetParam = 1.0. Note the default " + + "regParam is 0.01 for LassoWithSGD, but is 0.0 for LinearRegression.", "2.0.0") def this() = this(1.0, 100, 0.01, 1.0) override protected def createModel(weights: Vector, intercept: Double) = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala index 9e9d98bc5e41b..60262fdc497a6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala @@ -44,7 +44,7 @@ class LinearRegressionModel @Since("1.1.0") ( dataMatrix: Vector, weightMatrix: Vector, intercept: Double): Double = { - weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept + weightMatrix.asBreeze.dot(dataMatrix.asBreeze) + intercept } @Since("1.3.0") @@ -86,7 +86,6 @@ object LinearRegressionModel extends Loader[LinearRegressionModel] { * See also the documentation for the precise formulation. */ @Since("0.8.0") -@deprecated("Use ml.regression.LinearRegression or LBFGS", "2.0.0") class LinearRegressionWithSGD private[mllib] ( private var stepSize: Double, private var numIterations: Int, @@ -108,6 +107,7 @@ class LinearRegressionWithSGD private[mllib] ( * numIterations: 100, miniBatchFraction: 1.0}. */ @Since("0.8.0") + @deprecated("Use ml.regression.LinearRegression or LBFGS", "2.0.0") def this() = this(1.0, 100, 0.0, 1.0) override protected[mllib] def createModel(weights: Vector, intercept: Double) = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala index 512fb9a712b7a..52977ac4f062a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala @@ -45,7 +45,7 @@ class RidgeRegressionModel @Since("1.1.0") ( dataMatrix: Vector, weightMatrix: Vector, intercept: Double): Double = { - weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept + weightMatrix.asBreeze.dot(dataMatrix.asBreeze) + intercept } @Since("1.3.0") @@ -86,9 +86,7 @@ object RidgeRegressionModel extends Loader[RidgeRegressionModel] { * See also the documentation for the precise formulation. */ @Since("0.8.0") -@deprecated("Use ml.regression.LinearRegression with elasticNetParam = 0.0. Note the default " + - "regParam is 0.01 for RidgeRegressionWithSGD, but is 0.0 for LinearRegression.", "2.0.0") -class RidgeRegressionWithSGD private ( +class RidgeRegressionWithSGD private[mllib] ( private var stepSize: Double, private var numIterations: Int, private var regParam: Double, @@ -109,6 +107,8 @@ class RidgeRegressionWithSGD private ( * regParam: 0.01, miniBatchFraction: 1.0}. */ @Since("0.8.0") + @deprecated("Use ml.regression.LinearRegression with elasticNetParam = 0.0. 
Note the default " + + "regParam is 0.01 for RidgeRegressionWithSGD, but is 0.0 for LinearRegression.", "2.0.0") def this() = this(1.0, 100, 0.01, 1.0) override protected def createModel(weights: Vector, intercept: Double) = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala index 7696fdf2dcbed..cd90e97cc5388 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala @@ -23,7 +23,7 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.SparkContext import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.util.Loader -import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.{Row, SparkSession} /** * Helper methods for import/export of GLM regression models. @@ -47,7 +47,7 @@ private[regression] object GLMRegressionModel { modelClass: String, weights: Vector, intercept: Double): Unit = { - val sqlContext = SQLContext.getOrCreate(sc) + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() // Create JSON metadata. val metadata = compact(render( @@ -57,7 +57,7 @@ private[regression] object GLMRegressionModel { // Create Parquet data. val data = Data(weights, intercept) - sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(Loader.dataPath(path)) + spark.createDataFrame(Seq(data)).repartition(1).write.parquet(Loader.dataPath(path)) } /** @@ -67,17 +67,17 @@ private[regression] object GLMRegressionModel { * The length of the weights vector should equal numFeatures. */ def loadData(sc: SparkContext, path: String, modelClass: String, numFeatures: Int): Data = { - val datapath = Loader.dataPath(path) - val sqlContext = SQLContext.getOrCreate(sc) - val dataRDD = sqlContext.read.parquet(datapath) + val dataPath = Loader.dataPath(path) + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() + val dataRDD = spark.read.parquet(dataPath) val dataArray = dataRDD.select("weights", "intercept").take(1) - assert(dataArray.length == 1, s"Unable to load $modelClass data from: $datapath") + assert(dataArray.length == 1, s"Unable to load $modelClass data from: $dataPath") val data = dataArray(0) - assert(data.size == 2, s"Unable to load $modelClass data from: $datapath") + assert(data.size == 2, s"Unable to load $modelClass data from: $dataPath") data match { case Row(weights: Vector, intercept: Double) => assert(weights.size == numFeatures, s"Expected $numFeatures features, but" + - s" found ${weights.size} features when loading $modelClass weights from $datapath") + s" found ${weights.size} features when loading $modelClass weights from $dataPath") Data(weights, intercept) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index 98404be2603c7..d4de0fd7d5f4d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -24,7 +24,7 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors} * :: DeveloperApi :: * MultivariateOnlineSummarizer implements [[MultivariateStatisticalSummary]] to compute the mean, * variance, minimum, maximum, counts, and nonzero counts for instances in sparse or dense 
vector - * format in a online fashion. + * format in an online fashion. * * Two MultivariateOnlineSummarizer can be merged together to have a statistical summary of * the corresponding joint dataset. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/PearsonCorrelation.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/PearsonCorrelation.scala index 515be0b817835..e478c31bc9a05 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/PearsonCorrelation.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/PearsonCorrelation.scala @@ -55,7 +55,7 @@ private[stat] object PearsonCorrelation extends Correlation with Logging { * 0 variance results in a correlation value of Double.NaN. */ def computeCorrelationMatrixFromCovariance(covarianceMatrix: Matrix): Matrix = { - val cov = covarianceMatrix.toBreeze.asInstanceOf[BDM[Double]] + val cov = covarianceMatrix.asBreeze.asInstanceOf[BDM[Double]] val n = cov.cols // Compute the standard deviation on the diagonals first diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala index 6c6e9fb7c6b3d..39c3644450d6d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala @@ -42,7 +42,7 @@ class MultivariateGaussian @Since("1.3.0") ( require(sigma.numCols == sigma.numRows, "Covariance matrix must be square") require(mu.size == sigma.numCols, "Mean vector length must match covariance matrix size") - private val breezeMu = mu.toBreeze.toDenseVector + private val breezeMu = mu.asBreeze.toDenseVector /** * private[mllib] constructor @@ -64,17 +64,17 @@ class MultivariateGaussian @Since("1.3.0") ( /** * Returns density of this multivariate Gaussian at given point, x */ - @Since("1.3.0") + @Since("1.3.0") def pdf(x: Vector): Double = { - pdf(x.toBreeze) + pdf(x.asBreeze) } /** * Returns the log-density of this multivariate Gaussian at given point, x */ - @Since("1.3.0") + @Since("1.3.0") def logpdf(x: Vector): Double = { - logpdf(x.toBreeze) + logpdf(x.asBreeze) } /** Returns density of this multivariate Gaussian at given point, x */ @@ -118,7 +118,7 @@ class MultivariateGaussian @Since("1.3.0") ( * relation to the maximum singular value (same tolerance used by, e.g., Octave). */ private def calculateCovarianceConstants: (DBM[Double], Double) = { - val eigSym.EigSym(d, u) = eigSym(sigma.toBreeze.toDenseMatrix) // sigma = u * diag(d) * u.t + val eigSym.EigSym(d, u) = eigSym(sigma.asBreeze.toDenseMatrix) // sigma = u * diag(d) * u.t // For numerical stability, values are considered to be non-zero only if they exceed tol. 
// This prevents any inverted value from exceeding (eps * n * max(d))^-1 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala index 76ca6a8abd032..da5df9bf45e5a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala @@ -110,7 +110,7 @@ private[stat] object ChiSqTest extends Logging { } i += 1 distinctLabels += label - val brzFeatures = features.toBreeze + val brzFeatures = features.asBreeze (startCol until endCol).map { col => val feature = brzFeatures(col) allDistinctFeatures(col) += feature diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala index 9748fbf2c97b6..c3de5d75f4f7d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala @@ -45,7 +45,7 @@ import org.apache.spark.rdd.RDD * many elements are in each partition. Once these three values have been returned for every * partition, we can collect and operate locally. Locally, we can now adjust each distance by the * appropriate constant (the cumulative sum of number of elements in the prior partitions divided by - * thedata set size). Finally, we take the maximum absolute value, and this is the statistic. + * the data set size). Finally, we take the maximum absolute value, and this is the statistic. */ private[stat] object KolmogorovSmirnovTest extends Logging { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala index 4c382d7c2b791..97c032de7a813 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/StreamingTest.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.stat.test import scala.beans.BeanInfo -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.internal.Logging import org.apache.spark.streaming.api.java.JavaDStream import org.apache.spark.streaming.dstream.DStream @@ -42,7 +42,6 @@ case class BinarySample @Since("1.6.0") ( } /** - * :: Experimental :: * Performs online 2-sample significance testing for a stream of (Boolean, Double) pairs. The * Boolean identifies which sample each observation comes from, and the Double is the numeric value * of the observation. @@ -67,7 +66,6 @@ case class BinarySample @Since("1.6.0") ( * .registerStream(DStream) * }}} */ -@Experimental @Since("1.6.0") class StreamingTest @Since("1.6.0") () extends Logging with Serializable { private var peacePeriod: Int = 0 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala index 8a29fd39a9106..5cfc05a3dd2d2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.stat.test -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since /** * Trait for hypothesis test results. 
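The hunks above only remove the Experimental annotation from StreamingTest without changing its behaviour; for reviewers unfamiliar with the API, here is a minimal usage sketch. The DStream[BinarySample] input and the surrounding StreamingContext are assumed to exist already and are not part of this patch.

```scala
import org.apache.spark.mllib.stat.test.{BinarySample, StreamingTest}
import org.apache.spark.streaming.dstream.DStream

// `observations` is an assumed stream of (isExperimentGroup, value) samples.
def runWelchTTest(observations: DStream[BinarySample]) = {
  val test = new StreamingTest()
    .setPeacePeriod(0)      // number of initial batches to skip
    .setWindowSize(0)       // 0 means test over all batches seen so far
    .setTestMethod("welch") // Welch's two-sample t-test
  test.registerStream(observations) // yields a DStream of StreamingTestResult
}
```

The returned stream carries one StreamingTestResult per batch, the same result type whose Experimental tag is removed in the TestResult.scala hunk below.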
@@ -94,10 +94,8 @@ class ChiSqTestResult private[stat] (override val pValue: Double, } /** - * :: Experimental :: * Object containing the test results for the Kolmogorov-Smirnov test. */ -@Experimental @Since("1.5.0") class KolmogorovSmirnovTestResult private[stat] ( @Since("1.5.0") override val pValue: Double, @@ -113,10 +111,8 @@ class KolmogorovSmirnovTestResult private[stat] ( } /** - * :: Experimental :: * Object containing the test results for streaming testing. */ -@Experimental @Since("1.6.0") private[stat] class StreamingTestResult @Since("1.6.0") ( @Since("1.6.0") override val pValue: Double, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index 7fe60e2d99e4f..ece1e41d986d0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -20,6 +20,7 @@ package org.apache.spark.mllib.tree import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging +import org.apache.spark.ml.feature.{LabeledPoint => NewLabeledPoint} import org.apache.spark.ml.tree.impl.{GradientBoostedTrees => NewGBT} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.BoostingStrategy @@ -66,7 +67,9 @@ class GradientBoostedTrees private[spark] ( @Since("1.2.0") def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = { val algo = boostingStrategy.treeStrategy.algo - val (trees, treeWeights) = NewGBT.run(input, boostingStrategy, seed.toLong) + val (trees, treeWeights) = NewGBT.run(input.map { point => + NewLabeledPoint(point.label, point.features.asML) + }, boostingStrategy, seed.toLong) new GradientBoostedTreesModel(algo, trees.map(_.toOld), treeWeights) } @@ -94,8 +97,11 @@ class GradientBoostedTrees private[spark] ( input: RDD[LabeledPoint], validationInput: RDD[LabeledPoint]): GradientBoostedTreesModel = { val algo = boostingStrategy.treeStrategy.algo - val (trees, treeWeights) = NewGBT.runWithValidation(input, validationInput, boostingStrategy, - seed.toLong) + val (trees, treeWeights) = NewGBT.runWithValidation(input.map { point => + NewLabeledPoint(point.label, point.features.asML) + }, validationInput.map { point => + NewLabeledPoint(point.label, point.features.asML) + }, boostingStrategy, seed.toLong) new GradientBoostedTreesModel(algo, trees.map(_.toOld), treeWeights) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index c31ed9c1ce0b8..14f11ce51b878 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -90,8 +90,8 @@ private class RandomForest ( * @return RandomForestModel that can be used for prediction. 
*/ def run(input: RDD[LabeledPoint]): RandomForestModel = { - val trees: Array[NewDTModel] = - NewRandomForest.run(input, strategy, numTrees, featureSubsetStrategy, seed.toLong, None) + val trees: Array[NewDTModel] = NewRandomForest.run(input.map(_.asML), strategy, numTrees, + featureSubsetStrategy, seed.toLong, None) new RandomForestModel(strategy.algo, trees.map(_.toOld)) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala index 853c7319ec44d..2436ce40866e8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala @@ -17,14 +17,12 @@ package org.apache.spark.mllib.tree.configuration -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since /** - * :: Experimental :: * Enum to select the algorithm for the decision tree */ @Since("1.0.0") -@Experimental object Algo extends Enumeration { @Since("1.0.0") type Algo = Value diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index ff7700d2d1b7f..d4448da9eef51 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -17,15 +17,12 @@ package org.apache.spark.mllib.tree.impurity -import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} +import org.apache.spark.annotation.{DeveloperApi, Since} /** - * :: Experimental :: - * Class for calculating [[http://en.wikipedia.org/wiki/Binary_entropy_function entropy]] during - * binary classification. + * Class for calculating entropy during multiclass classification. */ @Since("1.0.0") -@Experimental object Entropy extends Impurity { private[tree] def log2(x: Double) = scala.math.log(x) / scala.math.log(2) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index 58dc79b7398e2..c5e34ffa4f2e5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -17,16 +17,14 @@ package org.apache.spark.mllib.tree.impurity -import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} +import org.apache.spark.annotation.{DeveloperApi, Since} /** - * :: Experimental :: - * Class for calculating the - * [[http://en.wikipedia.org/wiki/Decision_tree_learning#Gini_impurity Gini impurity]] - * during binary classification. + * Class for calculating the Gini impurity + * (http://en.wikipedia.org/wiki/Decision_tree_learning#Gini_impurity) + * during multiclass classification. 
*/ @Since("1.0.0") -@Experimental object Gini extends Impurity { /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index 65f0163ec6059..a5bdc2c6d2c94 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -17,17 +17,15 @@ package org.apache.spark.mllib.tree.impurity -import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} +import org.apache.spark.annotation.{DeveloperApi, Since} /** - * :: Experimental :: * Trait for calculating information gain. * This trait is used for * (a) setting the impurity parameter in [[org.apache.spark.mllib.tree.configuration.Strategy]] * (b) calculating impurity values from sufficient statistics. */ @Since("1.0.0") -@Experimental trait Impurity extends Serializable { /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index 2423516123b82..c9bf0db4de3c2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -17,14 +17,12 @@ package org.apache.spark.mllib.tree.impurity -import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} +import org.apache.spark.annotation.{DeveloperApi, Since} /** - * :: Experimental :: * Class for calculating variance during regression */ @Since("1.0.0") -@Experimental object Variance extends Impurity { /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index a87f8a6cde318..a1562384b0a7e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -32,7 +32,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo, FeatureType} import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.util.Utils /** @@ -75,8 +75,8 @@ class DecisionTreeModel @Since("1.0.0") ( * @return JavaRDD of predictions for each of the given data points */ @Since("1.2.0") - def predict(features: JavaRDD[Vector]): JavaRDD[Double] = { - predict(features.rdd) + def predict(features: JavaRDD[Vector]): JavaRDD[java.lang.Double] = { + predict(features.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]] } /** @@ -202,9 +202,6 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging { } def save(sc: SparkContext, path: String, model: DecisionTreeModel): Unit = { - val sqlContext = SQLContext.getOrCreate(sc) - import sqlContext.implicits._ - // SPARK-6120: We do a hacky check here so users understand why save() is failing // when they run the ML guide example. // TODO: Fix this issue for real. @@ -235,17 +232,16 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging { // Create Parquet data. 
val nodes = model.topNode.subtreeIterator.toSeq - val dataRDD: DataFrame = sc.parallelize(nodes) - .map(NodeData.apply(0, _)) - .toDF() - dataRDD.write.parquet(Loader.dataPath(path)) + val dataRDD = sc.parallelize(nodes).map(NodeData.apply(0, _)) + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() + spark.createDataFrame(dataRDD).write.parquet(Loader.dataPath(path)) } def load(sc: SparkContext, path: String, algo: String, numNodes: Int): DecisionTreeModel = { - val datapath = Loader.dataPath(path) - val sqlContext = SQLContext.getOrCreate(sc) // Load Parquet data. - val dataRDD = sqlContext.read.parquet(datapath) + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() + val dataPath = Loader.dataPath(path) + val dataRDD = spark.read.parquet(dataPath) // Check schema explicitly since erasure makes it hard to use match-case for checking. Loader.checkSchema[NodeData](dataRDD.schema) val nodes = dataRDD.rdd.map(NodeData.apply) @@ -254,7 +250,7 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging { assert(trees.length == 1, "Decision tree should contain exactly one tree but got ${trees.size} trees.") val model = new DecisionTreeModel(trees(0), Algo.fromString(algo)) - assert(model.numNodes == numNodes, s"Unable to load DecisionTreeModel data from: $datapath." + + assert(model.numNodes == numNodes, s"Unable to load DecisionTreeModel data from: $dataPath." + s" Expected $numNodes nodes but found ${model.numNodes}") model } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala index 06ceff19d8633..1dbdd2d860efd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala @@ -20,6 +20,7 @@ package org.apache.spark.mllib.tree.model import org.apache.spark.annotation.{DeveloperApi, Since} /** + * :: DeveloperApi :: * Predicted value for a node * @param predict predicted value * @param prob probability of the label (classification only) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala index cbf49b6d5821a..f7d9b22b6f424 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala @@ -36,7 +36,7 @@ import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy._ import org.apache.spark.mllib.tree.loss.Loss import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession import org.apache.spark.util.Utils /** @@ -413,8 +413,7 @@ private[tree] object TreeEnsembleModel extends Logging { case class EnsembleNodeData(treeId: Int, node: NodeData) def save(sc: SparkContext, path: String, model: TreeEnsembleModel, className: String): Unit = { - val sqlContext = SQLContext.getOrCreate(sc) - import sqlContext.implicits._ + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() // SPARK-6120: We do a hacky check here so users understand why save() is failing // when they run the ML guide example. @@ -450,8 +449,8 @@ private[tree] object TreeEnsembleModel extends Logging { // Create Parquet data. 
val dataRDD = sc.parallelize(model.trees.zipWithIndex).flatMap { case (tree, treeId) => tree.topNode.subtreeIterator.toSeq.map(node => NodeData(treeId, node)) - }.toDF() - dataRDD.write.parquet(Loader.dataPath(path)) + } + spark.createDataFrame(dataRDD).write.parquet(Loader.dataPath(path)) } /** @@ -472,10 +471,10 @@ private[tree] object TreeEnsembleModel extends Logging { sc: SparkContext, path: String, treeAlgo: String): Array[DecisionTreeModel] = { - val datapath = Loader.dataPath(path) - val sqlContext = SQLContext.getOrCreate(sc) - val nodes = sqlContext.read.parquet(datapath).rdd.map(NodeData.apply) - val trees = constructTrees(nodes) + val spark = SparkSession.builder().sparkContext(sc).getOrCreate() + import spark.implicits._ + val nodes = spark.read.parquet(Loader.dataPath(path)).map(NodeData.apply) + val trees = constructTrees(nodes.rdd) trees.map(new DecisionTreeModel(_, Algo.fromString(treeAlgo))) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index 774170ff401e9..e96c2bc6edfc3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -17,14 +17,19 @@ package org.apache.spark.mllib.util +import scala.annotation.varargs import scala.reflect.ClassTag import org.apache.spark.SparkContext import org.apache.spark.annotation.Since -import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} +import org.apache.spark.internal.Logging +import org.apache.spark.ml.linalg.{MatrixUDT => MLMatrixUDT, VectorUDT => MLVectorUDT} +import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.linalg.BLAS.dot import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD} +import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.random.BernoulliCellSampler @@ -32,7 +37,7 @@ import org.apache.spark.util.random.BernoulliCellSampler * Helper methods to load, save and pre-process data used in ML Lib. */ @Since("0.8.0") -object MLUtils { +object MLUtils extends Logging { private[mllib] lazy val EPSILON = { var eps = 1.0 @@ -50,7 +55,6 @@ object MLUtils { * where the indices are one-based and in ascending order. * This method parses each line into a [[org.apache.spark.mllib.regression.LabeledPoint]], * where the feature indices are converted to zero-based. - * * @param sc Spark context * @param path file or directory path in any Hadoop-supported file system URI * @param numFeatures number of features, which will be determined from the input data if a @@ -104,7 +108,7 @@ object MLUtils { val (indices, values) = items.tail.filter(_.nonEmpty).map { item => val indexAndValue = item.split(':') val index = indexAndValue(0).toInt - 1 // Convert 1-based indices to 0-based. - val value = indexAndValue(1).toDouble + val value = indexAndValue(1).toDouble (index, value) }.unzip @@ -119,7 +123,6 @@ object MLUtils { previous = current i += 1 } - (label, indices.toArray, values.toArray) } @@ -146,7 +149,6 @@ object MLUtils { * Save labeled data in LIBSVM format. 
* @param data an RDD of LabeledPoint to be saved * @param dir directory to save the data - * * @see [[org.apache.spark.mllib.util.MLUtils#loadLibSVMFile]] */ @Since("1.0.0") @@ -254,6 +256,211 @@ object MLUtils { } } + /** + * Converts vector columns in an input Dataset from the [[org.apache.spark.mllib.linalg.Vector]] + * type to the new [[org.apache.spark.ml.linalg.Vector]] type under the `spark.ml` package. + * @param dataset input dataset + * @param cols a list of vector columns to be converted. New vector columns will be ignored. If + * unspecified, all old vector columns will be converted except nested ones. + * @return the input [[DataFrame]] with old vector columns converted to the new vector type + */ + @Since("2.0.0") + @varargs + def convertVectorColumnsToML(dataset: Dataset[_], cols: String*): DataFrame = { + val schema = dataset.schema + val colSet = if (cols.nonEmpty) { + cols.flatMap { c => + val dataType = schema(c).dataType + if (dataType.getClass == classOf[VectorUDT]) { + Some(c) + } else { + // ignore new vector columns and raise an exception on other column types + require(dataType.getClass == classOf[MLVectorUDT], + s"Column $c must be old Vector type to be converted to new type but got $dataType.") + None + } + }.toSet + } else { + schema.fields + .filter(_.dataType.getClass == classOf[VectorUDT]) + .map(_.name) + .toSet + } + + if (colSet.isEmpty) { + return dataset.toDF() + } + + logWarning("Vector column conversion has serialization overhead. " + + "Please migrate your datasets and workflows to use the spark.ml package.") + + // TODO: This implementation has performance issues due to unnecessary serialization. + // TODO: It is better (but trickier) if we can cast the old vector type to new type directly. + val convertToML = udf { v: Vector => v.asML } + val exprs = schema.fields.map { field => + val c = field.name + if (colSet.contains(c)) { + convertToML(col(c)).as(c, field.metadata) + } else { + col(c) + } + } + dataset.select(exprs: _*) + } + + /** + * Converts vector columns in an input Dataset to the [[org.apache.spark.mllib.linalg.Vector]] + * type from the new [[org.apache.spark.ml.linalg.Vector]] type under the `spark.ml` package. + * @param dataset input dataset + * @param cols a list of vector columns to be converted. Old vector columns will be ignored. If + * unspecified, all new vector columns will be converted except nested ones. + * @return the input [[DataFrame]] with new vector columns converted to the old vector type + */ + @Since("2.0.0") + @varargs + def convertVectorColumnsFromML(dataset: Dataset[_], cols: String*): DataFrame = { + val schema = dataset.schema + val colSet = if (cols.nonEmpty) { + cols.flatMap { c => + val dataType = schema(c).dataType + if (dataType.getClass == classOf[MLVectorUDT]) { + Some(c) + } else { + // ignore old vector columns and raise an exception on other column types + require(dataType.getClass == classOf[VectorUDT], + s"Column $c must be new Vector type to be converted to old type but got $dataType.") + None + } + }.toSet + } else { + schema.fields + .filter(_.dataType.getClass == classOf[MLVectorUDT]) + .map(_.name) + .toSet + } + + if (colSet.isEmpty) { + return dataset.toDF() + } + + logWarning("Vector column conversion has serialization overhead. " + + "Please migrate your datasets and workflows to use the spark.ml package.") + + // TODO: This implementation has performance issues due to unnecessary serialization. + // TODO: It is better (but trickier) if we can cast the new vector type to old type directly. 
+ val convertFromML = udf { Vectors.fromML _ } + val exprs = schema.fields.map { field => + val c = field.name + if (colSet.contains(c)) { + convertFromML(col(c)).as(c, field.metadata) + } else { + col(c) + } + } + dataset.select(exprs: _*) + } + + /** + * Converts Matrix columns in an input Dataset from the [[org.apache.spark.mllib.linalg.Matrix]] + * type to the new [[org.apache.spark.ml.linalg.Matrix]] type under the `spark.ml` package. + * @param dataset input dataset + * @param cols a list of matrix columns to be converted. New matrix columns will be ignored. If + * unspecified, all old matrix columns will be converted except nested ones. + * @return the input [[DataFrame]] with old matrix columns converted to the new matrix type + */ + @Since("2.0.0") + @varargs + def convertMatrixColumnsToML(dataset: Dataset[_], cols: String*): DataFrame = { + val schema = dataset.schema + val colSet = if (cols.nonEmpty) { + cols.flatMap { c => + val dataType = schema(c).dataType + if (dataType.getClass == classOf[MatrixUDT]) { + Some(c) + } else { + // ignore new matrix columns and raise an exception on other column types + require(dataType.getClass == classOf[MLMatrixUDT], + s"Column $c must be old Matrix type to be converted to new type but got $dataType.") + None + } + }.toSet + } else { + schema.fields + .filter(_.dataType.getClass == classOf[MatrixUDT]) + .map(_.name) + .toSet + } + + if (colSet.isEmpty) { + return dataset.toDF() + } + + logWarning("Matrix column conversion has serialization overhead. " + + "Please migrate your datasets and workflows to use the spark.ml package.") + + val convertToML = udf { v: Matrix => v.asML } + val exprs = schema.fields.map { field => + val c = field.name + if (colSet.contains(c)) { + convertToML(col(c)).as(c, field.metadata) + } else { + col(c) + } + } + dataset.select(exprs: _*) + } + + /** + * Converts matrix columns in an input Dataset to the [[org.apache.spark.mllib.linalg.Matrix]] + * type from the new [[org.apache.spark.ml.linalg.Matrix]] type under the `spark.ml` package. + * @param dataset input dataset + * @param cols a list of matrix columns to be converted. Old matrix columns will be ignored. If + * unspecified, all new matrix columns will be converted except nested ones. + * @return the input [[DataFrame]] with new matrix columns converted to the old matrix type + */ + @Since("2.0.0") + @varargs + def convertMatrixColumnsFromML(dataset: Dataset[_], cols: String*): DataFrame = { + val schema = dataset.schema + val colSet = if (cols.nonEmpty) { + cols.flatMap { c => + val dataType = schema(c).dataType + if (dataType.getClass == classOf[MLMatrixUDT]) { + Some(c) + } else { + // ignore old matrix columns and raise an exception on other column types + require(dataType.getClass == classOf[MatrixUDT], + s"Column $c must be new Matrix type to be converted to old type but got $dataType.") + None + } + }.toSet + } else { + schema.fields + .filter(_.dataType.getClass == classOf[MLMatrixUDT]) + .map(_.name) + .toSet + } + + if (colSet.isEmpty) { + return dataset.toDF() + } + + logWarning("Matrix column conversion has serialization overhead. " + + "Please migrate your datasets and workflows to use the spark.ml package.") + + val convertFromML = udf { Matrices.fromML _ } + val exprs = schema.fields.map { field => + val c = field.name + if (colSet.contains(c)) { + convertFromML(col(c)).as(c, field.metadata) + } else { + col(c) + } + } + dataset.select(exprs: _*) + } + + /** * Returns the squared Euclidean distance between two vectors. 
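The four conversion helpers added to MLUtils above are the user-facing bridge between the old mllib linalg types and the new spark.ml ones. A minimal usage sketch follows; the column name "features" and the input path are illustrative assumptions, not part of the patch.

```scala
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql.{DataFrame, SparkSession}

val spark = SparkSession.builder().getOrCreate()
// df is assumed to hold an old org.apache.spark.mllib.linalg.Vector column named "features".
val df: DataFrame = spark.read.parquet("/tmp/old-vectors")   // illustrative path
val mlDF = MLUtils.convertVectorColumnsToML(df, "features")  // old -> new spark.ml vector type
val back = MLUtils.convertVectorColumnsFromML(mlDF)          // no cols given: convert every new vector column
val mlMat = MLUtils.convertMatrixColumnsToML(df)             // analogous helpers exist for Matrix columns
```

Both directions currently go through a UDF, which is why each helper logs the serialization-overhead warning seen in the implementations above.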
The following formula will be used * if it does not introduce too much numerical error: @@ -262,7 +469,6 @@ object MLUtils { *
    * When both vector norms are given, this is faster than computing the squared distance directly, * especially when one of the vectors is a sparse vector. - * * @param v1 the first vector * @param norm1 the norm of the first vector, non-negative * @param v2 the second vector @@ -315,7 +521,6 @@ object MLUtils { * When `x` is positive and large, computing `math.log(1 + math.exp(x))` will lead to arithmetic * overflow. This will happen when `x > 709.78` which is not a very large number. * It can be addressed by rewriting the formula into `x + math.log1p(math.exp(-x))` when `x > 0`. - * * @param x a floating-point value as input. * @return the result of `math.log(1 + math.exp(x))`. */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala index cde5979396178..c9468606544db 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala @@ -55,7 +55,7 @@ object SVMDataGenerator { val sc = new SparkContext(sparkMaster, "SVMGenerator") val globalRnd = new Random(94720) - val trueWeights = Array.fill[Double](nfeatures + 1)(globalRnd.nextGaussian()) + val trueWeights = Array.fill[Double](nfeatures)(globalRnd.nextGaussian()) val data: RDD[LabeledPoint] = sc.parallelize(0 until nexamples, parts).map { idx => val rnd = new Random(42 + idx) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala index 4d71d534a0774..c881c8ea50c09 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala @@ -45,7 +45,7 @@ trait Saveable { * - human-readable (JSON) model metadata to path/metadata/ * - Parquet formatted data to path/data/ * - * The model may be loaded using [[Loader.load]]. + * The model may be loaded using `Loader.load`. * * @param sc Spark context used to save model data. * @param path Path specifying the directory in which to save this model. diff --git a/mllib/src/test/java/org/apache/spark/SharedSparkSession.java b/mllib/src/test/java/org/apache/spark/SharedSparkSession.java new file mode 100644 index 0000000000000..43779878890db --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/SharedSparkSession.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark; + +import java.io.IOException; +import java.io.Serializable; + +import org.junit.After; +import org.junit.Before; + +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SparkSession; + +public abstract class SharedSparkSession implements Serializable { + + protected transient SparkSession spark; + protected transient JavaSparkContext jsc; + + @Before + public void setUp() throws IOException { + spark = SparkSession.builder() + .master("local[2]") + .appName(getClass().getSimpleName()) + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); + } + + @After + public void tearDown() { + spark.stop(); + spark = null; + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java index 60a4a1d2ea2af..9b209006bc369 100644 --- a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java @@ -17,42 +17,32 @@ package org.apache.spark.ml; -import org.apache.spark.sql.Dataset; -import org.apache.spark.sql.Row; -import org.junit.After; -import org.junit.Before; +import java.io.IOException; + import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.ml.classification.LogisticRegression; +import static org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInputAsList; +import org.apache.spark.ml.feature.LabeledPoint; import org.apache.spark.ml.feature.StandardScaler; -import org.apache.spark.sql.SQLContext; -import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; /** * Test Pipeline construction and fitting in Java. 
*/ -public class JavaPipelineSuite { +public class JavaPipelineSuite extends SharedSparkSession { - private transient JavaSparkContext jsc; - private transient SQLContext jsql; private transient Dataset dataset; - @Before - public void setUp() { - jsc = new JavaSparkContext("local", "JavaPipelineSuite"); - jsql = new SQLContext(jsc); + @Override + public void setUp() throws IOException { + super.setUp(); JavaRDD points = jsc.parallelize(generateLogisticInputAsList(1.0, 1.0, 100, 42), 2); - dataset = jsql.createDataFrame(points, LabeledPoint.class); - } - - @After - public void tearDown() { - jsc.stop(); - jsc = null; + dataset = spark.createDataFrame(points, LabeledPoint.class); } @Test @@ -63,10 +53,10 @@ public void pipeline() { LogisticRegression lr = new LogisticRegression() .setFeaturesCol("scaledFeatures"); Pipeline pipeline = new Pipeline() - .setStages(new PipelineStage[] {scaler, lr}); + .setStages(new PipelineStage[]{scaler, lr}); PipelineModel model = pipeline.fit(dataset); - model.transform(dataset).registerTempTable("prediction"); - Dataset predictions = jsql.sql("SELECT label, probability, prediction FROM prediction"); + model.transform(dataset).createOrReplaceTempView("prediction"); + Dataset predictions = spark.sql("SELECT label, probability, prediction FROM prediction"); predictions.collectAsList(); } } diff --git a/mllib/src/test/java/org/apache/spark/ml/attribute/JavaAttributeSuite.java b/mllib/src/test/java/org/apache/spark/ml/attribute/JavaAttributeSuite.java index b74bbed231434..15cde0d3c045f 100644 --- a/mllib/src/test/java/org/apache/spark/ml/attribute/JavaAttributeSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/attribute/JavaAttributeSuite.java @@ -17,8 +17,8 @@ package org.apache.spark.ml.attribute; -import org.junit.Test; import org.junit.Assert; +import org.junit.Test; public class JavaAttributeSuite { diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java index 1f23682621594..5aba4e8f7de07 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java @@ -17,37 +17,19 @@ package org.apache.spark.ml.classification; -import java.io.Serializable; import java.util.HashMap; import java.util.Map; -import org.apache.spark.sql.Dataset; -import org.apache.spark.sql.Row; -import org.junit.After; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.feature.LabeledPoint; import org.apache.spark.ml.tree.impl.TreeTests; -import org.apache.spark.mllib.classification.LogisticRegressionSuite; -import org.apache.spark.mllib.regression.LabeledPoint; - - -public class JavaDecisionTreeClassifierSuite implements Serializable { - - private transient JavaSparkContext sc; - - @Before - public void setUp() { - sc = new JavaSparkContext("local", "JavaDecisionTreeClassifierSuite"); - } +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; - @After - public void tearDown() { - sc.stop(); - sc = null; - } +public class JavaDecisionTreeClassifierSuite extends SharedSparkSession { @Test public void runDT() { @@ -55,7 +37,7 @@ public void runDT() { double A = 2.0; double B = -1.5; - JavaRDD data = sc.parallelize( + JavaRDD data 
= jsc.parallelize( LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); Map categoricalFeatures = new HashMap<>(); Dataset dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2); @@ -70,7 +52,7 @@ public void runDT() { .setCacheNodeIds(false) .setCheckpointInterval(10) .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern - for (String impurity: DecisionTreeClassifier.supportedImpurities()) { + for (String impurity : DecisionTreeClassifier.supportedImpurities()) { dt.setImpurity(impurity); } DecisionTreeClassificationModel model = dt.fit(dataFrame); diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java index 74841058a21b0..74bb46bd217a9 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java @@ -17,37 +17,19 @@ package org.apache.spark.ml.classification; -import java.io.Serializable; import java.util.HashMap; import java.util.Map; -import org.junit.After; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.feature.LabeledPoint; import org.apache.spark.ml.tree.impl.TreeTests; -import org.apache.spark.mllib.classification.LogisticRegressionSuite; -import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; - -public class JavaGBTClassifierSuite implements Serializable { - - private transient JavaSparkContext sc; - - @Before - public void setUp() { - sc = new JavaSparkContext("local", "JavaGBTClassifierSuite"); - } - - @After - public void tearDown() { - sc.stop(); - sc = null; - } +public class JavaGBTClassifierSuite extends SharedSparkSession { @Test public void runDT() { @@ -55,7 +37,7 @@ public void runDT() { double A = 2.0; double B = -1.5; - JavaRDD data = sc.parallelize( + JavaRDD data = jsc.parallelize( LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); Map categoricalFeatures = new HashMap<>(); Dataset dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2); @@ -74,7 +56,7 @@ public void runDT() { .setMaxIter(3) .setStepSize(0.1) .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern - for (String lossType: GBTClassifier.supportedLossTypes()) { + for (String lossType : GBTClassifier.supportedLossTypes()) { rf.setLossType(lossType); } GBTClassificationModel model = rf.fit(dataFrame); diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java index e160a5a47e304..004102103d52c 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java @@ -17,47 +17,34 @@ package org.apache.spark.ml.classification; -import java.io.Serializable; +import java.io.IOException; import java.util.List; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import static 
org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.regression.LabeledPoint; +import static org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInputAsList; +import org.apache.spark.ml.feature.LabeledPoint; +import org.apache.spark.ml.linalg.Vector; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; +public class JavaLogisticRegressionSuite extends SharedSparkSession { -public class JavaLogisticRegressionSuite implements Serializable { - - private transient JavaSparkContext jsc; - private transient SQLContext jsql; private transient Dataset dataset; private transient JavaRDD datasetRDD; private double eps = 1e-5; - @Before - public void setUp() { - jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite"); - jsql = new SQLContext(jsc); + @Override + public void setUp() throws IOException { + super.setUp(); List points = generateLogisticInputAsList(1.0, 1.0, 100, 42); datasetRDD = jsc.parallelize(points, 2); - dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class); - dataset.registerTempTable("dataset"); - } - - @After - public void tearDown() { - jsc.stop(); - jsc = null; + dataset = spark.createDataFrame(datasetRDD, LabeledPoint.class); + dataset.createOrReplaceTempView("dataset"); } @Test @@ -65,8 +52,8 @@ public void logisticRegressionDefaultParams() { LogisticRegression lr = new LogisticRegression(); Assert.assertEquals(lr.getLabelCol(), "label"); LogisticRegressionModel model = lr.fit(dataset); - model.transform(dataset).registerTempTable("prediction"); - Dataset predictions = jsql.sql("SELECT label, probability, prediction FROM prediction"); + model.transform(dataset).createOrReplaceTempView("prediction"); + Dataset predictions = spark.sql("SELECT label, probability, prediction FROM prediction"); predictions.collectAsList(); // Check defaults Assert.assertEquals(0.5, model.getThreshold(), eps); @@ -94,24 +81,24 @@ public void logisticRegressionWithSetters() { // Modify model params, and check that the params worked. model.setThreshold(1.0); - model.transform(dataset).registerTempTable("predAllZero"); - Dataset predAllZero = jsql.sql("SELECT prediction, myProbability FROM predAllZero"); - for (Row r: predAllZero.collectAsList()) { + model.transform(dataset).createOrReplaceTempView("predAllZero"); + Dataset predAllZero = spark.sql("SELECT prediction, myProbability FROM predAllZero"); + for (Row r : predAllZero.collectAsList()) { Assert.assertEquals(0.0, r.getDouble(0), eps); } // Call transform with params, and check that the params worked. model.transform(dataset, model.threshold().w(0.0), model.probabilityCol().w("myProb")) - .registerTempTable("predNotAllZero"); - Dataset predNotAllZero = jsql.sql("SELECT prediction, myProb FROM predNotAllZero"); + .createOrReplaceTempView("predNotAllZero"); + Dataset predNotAllZero = spark.sql("SELECT prediction, myProb FROM predNotAllZero"); boolean foundNonZero = false; - for (Row r: predNotAllZero.collectAsList()) { + for (Row r : predNotAllZero.collectAsList()) { if (r.getDouble(0) != 0.0) foundNonZero = true; } Assert.assertTrue(foundNonZero); // Call fit() with new params, and check as many params as we can. 
LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), - lr.threshold().w(0.4), lr.probabilityCol().w("theProb")); + lr.threshold().w(0.4), lr.probabilityCol().w("theProb")); LogisticRegression parent2 = (LogisticRegression) model2.parent(); Assert.assertEquals(5, parent2.getMaxIter()); Assert.assertEquals(0.1, parent2.getRegParam(), eps); @@ -127,11 +114,11 @@ public void logisticRegressionPredictorClassifierMethods() { LogisticRegressionModel model = lr.fit(dataset); Assert.assertEquals(2, model.numClasses()); - model.transform(dataset).registerTempTable("transformed"); - Dataset trans1 = jsql.sql("SELECT rawPrediction, probability FROM transformed"); - for (Row row: trans1.collectAsList()) { - Vector raw = (Vector)row.get(0); - Vector prob = (Vector)row.get(1); + model.transform(dataset).createOrReplaceTempView("transformed"); + Dataset trans1 = spark.sql("SELECT rawPrediction, probability FROM transformed"); + for (Row row : trans1.collectAsList()) { + Vector raw = (Vector) row.get(0); + Vector prob = (Vector) row.get(1); Assert.assertEquals(raw.size(), 2); Assert.assertEquals(prob.size(), 2); double probFromRaw1 = 1.0 / (1.0 + Math.exp(-raw.apply(1))); @@ -139,11 +126,11 @@ public void logisticRegressionPredictorClassifierMethods() { Assert.assertEquals(0, Math.abs(prob.apply(0) - (1.0 - probFromRaw1)), eps); } - Dataset trans2 = jsql.sql("SELECT prediction, probability FROM transformed"); - for (Row row: trans2.collectAsList()) { + Dataset trans2 = spark.sql("SELECT prediction, probability FROM transformed"); + for (Row row : trans2.collectAsList()) { double pred = row.getDouble(0); - Vector prob = (Vector)row.get(1); - double probOfPred = prob.apply((int)pred); + Vector prob = (Vector) row.get(1); + double probOfPred = prob.apply((int) pred); for (int i = 0; i < prob.size(); ++i) { Assert.assertTrue(probOfPred >= prob.apply(i)); } diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java index bc955f3cf6b0f..6d0604d8f9a5a 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java @@ -17,58 +17,39 @@ package org.apache.spark.ml.classification; -import java.io.Serializable; import java.util.Arrays; import java.util.List; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.SharedSparkSession; +import org.apache.spark.ml.feature.LabeledPoint; +import org.apache.spark.ml.linalg.Vectors; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; -public class JavaMultilayerPerceptronClassifierSuite implements Serializable { - - private transient JavaSparkContext jsc; - private transient SQLContext sqlContext; - - @Before - public void setUp() { - jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite"); - sqlContext = new SQLContext(jsc); - } - - @After - public void tearDown() { - jsc.stop(); - jsc = null; - sqlContext = null; - } +public class JavaMultilayerPerceptronClassifierSuite extends SharedSparkSession { @Test public void testMLPC() { - Dataset 
dataFrame = sqlContext.createDataFrame( - jsc.parallelize(Arrays.asList( - new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)), - new LabeledPoint(1.0, Vectors.dense(0.0, 1.0)), - new LabeledPoint(1.0, Vectors.dense(1.0, 0.0)), - new LabeledPoint(0.0, Vectors.dense(1.0, 1.0)))), - LabeledPoint.class); + List data = Arrays.asList( + new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)), + new LabeledPoint(1.0, Vectors.dense(0.0, 1.0)), + new LabeledPoint(1.0, Vectors.dense(1.0, 0.0)), + new LabeledPoint(0.0, Vectors.dense(1.0, 1.0)) + ); + Dataset dataFrame = spark.createDataFrame(data, LabeledPoint.class); + MultilayerPerceptronClassifier mlpc = new MultilayerPerceptronClassifier() - .setLayers(new int[] {2, 5, 2}) + .setLayers(new int[]{2, 5, 2}) .setBlockSize(1) .setSeed(123L) .setMaxIter(100); MultilayerPerceptronClassificationModel model = mlpc.fit(dataFrame); Dataset result = model.transform(dataFrame); List predictionAndLabels = result.select("prediction", "label").collectAsList(); - for (Row r: predictionAndLabels) { + for (Row r : predictionAndLabels) { Assert.assertEquals((int) r.getDouble(0), (int) r.getDouble(1)); } } diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java index 45101f286c6d2..c2a9e7b58b470 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java @@ -17,43 +17,24 @@ package org.apache.spark.ml.classification; -import java.io.Serializable; import java.util.Arrays; import java.util.List; -import org.junit.After; -import org.junit.Before; import org.junit.Test; import static org.junit.Assert.assertEquals; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.SharedSparkSession; +import org.apache.spark.ml.linalg.VectorUDT; +import org.apache.spark.ml.linalg.Vectors; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; -public class JavaNaiveBayesSuite implements Serializable { - - private transient JavaSparkContext jsc; - private transient SQLContext jsql; - - @Before - public void setUp() { - jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite"); - jsql = new SQLContext(jsc); - } - - @After - public void tearDown() { - jsc.stop(); - jsc = null; - } +public class JavaNaiveBayesSuite extends SharedSparkSession { public void validatePrediction(Dataset predictionAndLabels) { for (Row r : predictionAndLabels.collectAsList()) { @@ -88,7 +69,7 @@ public void testNaiveBayes() { new StructField("features", new VectorUDT(), false, Metadata.empty()) }); - Dataset dataset = jsql.createDataFrame(data, schema); + Dataset dataset = spark.createDataFrame(data, schema); NaiveBayes nb = new NaiveBayes().setSmoothing(0.5).setModelType("multinomial"); NaiveBayesModel model = nb.fit(dataset); diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java index 00f4476841af1..6194167bda354 100644 --- 
a/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java @@ -17,69 +17,57 @@ package org.apache.spark.ml.classification; -import java.io.Serializable; +import java.io.IOException; import java.util.List; -import org.apache.spark.sql.Row; import scala.collection.JavaConverters; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateMultinomialLogisticInput; -import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.ml.feature.LabeledPoint; import org.apache.spark.sql.Dataset; -import org.apache.spark.sql.SQLContext; - -public class JavaOneVsRestSuite implements Serializable { +import org.apache.spark.sql.Row; +import static org.apache.spark.ml.classification.LogisticRegressionSuite.generateMultinomialLogisticInput; - private transient JavaSparkContext jsc; - private transient SQLContext jsql; - private transient Dataset dataset; - private transient JavaRDD datasetRDD; +public class JavaOneVsRestSuite extends SharedSparkSession { - @Before - public void setUp() { - jsc = new JavaSparkContext("local", "JavaLOneVsRestSuite"); - jsql = new SQLContext(jsc); - int nPoints = 3; + private transient Dataset dataset; + private transient JavaRDD datasetRDD; - // The following coefficients and xMean/xVariance are computed from iris dataset with - // lambda=0.2. - // As a result, we are drawing samples from probability distribution of an actual model. - double[] coefficients = { - -0.57997, 0.912083, -0.371077, -0.819866, 2.688191, - -0.16624, -0.84355, -0.048509, -0.301789, 4.170682 }; + @Override + public void setUp() throws IOException { + super.setUp(); + int nPoints = 3; - double[] xMean = {5.843, 3.057, 3.758, 1.199}; - double[] xVariance = {0.6856, 0.1899, 3.116, 0.581}; - List points = JavaConverters.seqAsJavaListConverter( - generateMultinomialLogisticInput(coefficients, xMean, xVariance, true, nPoints, 42) - ).asJava(); - datasetRDD = jsc.parallelize(points, 2); - dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class); - } + // The following coefficients and xMean/xVariance are computed from iris dataset with + // lambda=0.2. + // As a result, we are drawing samples from probability distribution of an actual model. 
+ double[] coefficients = { + -0.57997, 0.912083, -0.371077, -0.819866, 2.688191, + -0.16624, -0.84355, -0.048509, -0.301789, 4.170682}; - @After - public void tearDown() { - jsc.stop(); - jsc = null; - } + double[] xMean = {5.843, 3.057, 3.758, 1.199}; + double[] xVariance = {0.6856, 0.1899, 3.116, 0.581}; + List points = JavaConverters.seqAsJavaListConverter( + generateMultinomialLogisticInput(coefficients, xMean, xVariance, true, nPoints, 42) + ).asJava(); + datasetRDD = jsc.parallelize(points, 2); + dataset = spark.createDataFrame(datasetRDD, LabeledPoint.class); + } - @Test - public void oneVsRestDefaultParams() { - OneVsRest ova = new OneVsRest(); - ova.setClassifier(new LogisticRegression()); - Assert.assertEquals(ova.getLabelCol() , "label"); - Assert.assertEquals(ova.getPredictionCol() , "prediction"); - OneVsRestModel ovaModel = ova.fit(dataset); - Dataset predictions = ovaModel.transform(dataset).select("label", "prediction"); - predictions.collectAsList(); - Assert.assertEquals(ovaModel.getLabelCol(), "label"); - Assert.assertEquals(ovaModel.getPredictionCol() , "prediction"); - } + @Test + public void oneVsRestDefaultParams() { + OneVsRest ova = new OneVsRest(); + ova.setClassifier(new LogisticRegression()); + Assert.assertEquals(ova.getLabelCol(), "label"); + Assert.assertEquals(ova.getPredictionCol(), "prediction"); + OneVsRestModel ovaModel = ova.fit(dataset); + Dataset predictions = ovaModel.transform(dataset).select("label", "prediction"); + predictions.collectAsList(); + Assert.assertEquals(ovaModel.getLabelCol(), "label"); + Assert.assertEquals(ovaModel.getPredictionCol(), "prediction"); + } } diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java index 4f40fd65b9f11..dd98513f37ecf 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java @@ -17,39 +17,21 @@ package org.apache.spark.ml.classification; -import java.io.Serializable; import java.util.HashMap; import java.util.Map; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.feature.LabeledPoint; +import org.apache.spark.ml.linalg.Vector; import org.apache.spark.ml.tree.impl.TreeTests; -import org.apache.spark.mllib.classification.LogisticRegressionSuite; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; - -public class JavaRandomForestClassifierSuite implements Serializable { - - private transient JavaSparkContext sc; - - @Before - public void setUp() { - sc = new JavaSparkContext("local", "JavaRandomForestClassifierSuite"); - } - - @After - public void tearDown() { - sc.stop(); - sc = null; - } +public class JavaRandomForestClassifierSuite extends SharedSparkSession { @Test public void runDT() { @@ -57,7 +39,7 @@ public void runDT() { double A = 2.0; double B = -1.5; - JavaRDD data = sc.parallelize( + JavaRDD data = jsc.parallelize( LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); Map categoricalFeatures = new HashMap<>(); Dataset dataFrame = 
TreeTests.setMetadata(data, categoricalFeatures, 2); @@ -75,22 +57,22 @@ public void runDT() { .setSeed(1234) .setNumTrees(3) .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern - for (String impurity: RandomForestClassifier.supportedImpurities()) { + for (String impurity : RandomForestClassifier.supportedImpurities()) { rf.setImpurity(impurity); } - for (String featureSubsetStrategy: RandomForestClassifier.supportedFeatureSubsetStrategies()) { + for (String featureSubsetStrategy : RandomForestClassifier.supportedFeatureSubsetStrategies()) { rf.setFeatureSubsetStrategy(featureSubsetStrategy); } String[] realStrategies = {".1", ".10", "0.10", "0.1", "0.9", "1.0"}; - for (String strategy: realStrategies) { + for (String strategy : realStrategies) { rf.setFeatureSubsetStrategy(strategy); } String[] integerStrategies = {"1", "10", "100", "1000", "10000"}; - for (String strategy: integerStrategies) { + for (String strategy : integerStrategies) { rf.setFeatureSubsetStrategy(strategy); } String[] invalidStrategies = {"-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0"}; - for (String strategy: invalidStrategies) { + for (String strategy : invalidStrategies) { try { rf.setFeatureSubsetStrategy(strategy); Assert.fail("Expected exception to be thrown for invalid strategies"); diff --git a/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java b/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java index a3fcdb54ee7ad..1be6f96f4c942 100644 --- a/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java @@ -17,41 +17,28 @@ package org.apache.spark.ml.clustering; -import java.io.Serializable; +import java.io.IOException; import java.util.Arrays; import java.util.List; -import org.junit.After; -import org.junit.Before; import org.junit.Test; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.SharedSparkSession; +import org.apache.spark.ml.linalg.Vector; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; -public class JavaKMeansSuite implements Serializable { +public class JavaKMeansSuite extends SharedSparkSession { private transient int k = 5; - private transient JavaSparkContext sc; private transient Dataset dataset; - private transient SQLContext sql; - @Before - public void setUp() { - sc = new JavaSparkContext("local", "JavaKMeansSuite"); - sql = new SQLContext(sc); - - dataset = KMeansSuite.generateKMeansData(sql, 50, 3, k); - } - - @After - public void tearDown() { - sc.stop(); - sc = null; + @Override + public void setUp() throws IOException { + super.setUp(); + dataset = KMeansSuite.generateKMeansData(spark, 50, 3, k); } @Test @@ -65,7 +52,7 @@ public void fitAndTransform() { Dataset transformed = model.transform(dataset); List columns = Arrays.asList(transformed.columns()); List expectedColumns = Arrays.asList("features", "prediction"); - for (String column: expectedColumns) { + for (String column : expectedColumns) { assertTrue(columns.contains(column)); } } diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java index 77e3a489a93aa..87639380bdcf4 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java +++ 
b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java @@ -20,45 +20,28 @@ import java.util.Arrays; import java.util.List; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.SharedSparkSession; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; -public class JavaBucketizerSuite { - private transient JavaSparkContext jsc; - private transient SQLContext jsql; - - @Before - public void setUp() { - jsc = new JavaSparkContext("local", "JavaBucketizerSuite"); - jsql = new SQLContext(jsc); - } - - @After - public void tearDown() { - jsc.stop(); - jsc = null; - } +public class JavaBucketizerSuite extends SharedSparkSession { @Test public void bucketizerTest() { double[] splits = {-0.5, 0.0, 0.5}; - StructType schema = new StructType(new StructField[] { + StructType schema = new StructType(new StructField[]{ new StructField("feature", DataTypes.DoubleType, false, Metadata.empty()) }); - Dataset dataset = jsql.createDataFrame( + Dataset dataset = spark.createDataFrame( Arrays.asList( RowFactory.create(-0.5), RowFactory.create(-0.3), diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java index ed1ad4c3a316a..b7956b6fd3e9a 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java @@ -21,43 +21,27 @@ import java.util.List; import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D; -import org.junit.After; + import org.junit.Assert; -import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.SharedSparkSession; +import org.apache.spark.ml.linalg.Vector; +import org.apache.spark.ml.linalg.VectorUDT; +import org.apache.spark.ml.linalg.Vectors; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; -public class JavaDCTSuite { - private transient JavaSparkContext jsc; - private transient SQLContext jsql; - - @Before - public void setUp() { - jsc = new JavaSparkContext("local", "JavaDCTSuite"); - jsql = new SQLContext(jsc); - } - - @After - public void tearDown() { - jsc.stop(); - jsc = null; - } +public class JavaDCTSuite extends SharedSparkSession { @Test public void javaCompatibilityTest() { - double[] input = new double[] {1D, 2D, 3D, 4D}; - Dataset dataset = jsql.createDataFrame( + double[] input = new double[]{1D, 2D, 3D, 4D}; + Dataset dataset = spark.createDataFrame( Arrays.asList(RowFactory.create(Vectors.dense(input))), new StructType(new StructField[]{ new StructField("vec", (new VectorUDT()), false, Metadata.empty()) diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java 
b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java index 6e2cc7e8877c6..57696d0150a8b 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java @@ -20,38 +20,21 @@ import java.util.Arrays; import java.util.List; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.SharedSparkSession; +import org.apache.spark.ml.linalg.Vector; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; -public class JavaHashingTFSuite { - private transient JavaSparkContext jsc; - private transient SQLContext jsql; - - @Before - public void setUp() { - jsc = new JavaSparkContext("local", "JavaHashingTFSuite"); - jsql = new SQLContext(jsc); - } - - @After - public void tearDown() { - jsc.stop(); - jsc = null; - } +public class JavaHashingTFSuite extends SharedSparkSession { @Test public void hashingTF() { @@ -65,7 +48,7 @@ public void hashingTF() { new StructField("sentence", DataTypes.StringType, false, Metadata.empty()) }); - Dataset sentenceData = jsql.createDataFrame(data, schema); + Dataset sentenceData = spark.createDataFrame(data, schema); Tokenizer tokenizer = new Tokenizer() .setInputCol("sentence") .setOutputCol("words"); diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java index 5bbd9634b2c27..6f877b566875c 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java @@ -19,32 +19,15 @@ import java.util.Arrays; -import org.junit.After; -import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.ml.linalg.Vectors; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; -public class JavaNormalizerSuite { - private transient JavaSparkContext jsc; - private transient SQLContext jsql; - - @Before - public void setUp() { - jsc = new JavaSparkContext("local", "JavaNormalizerSuite"); - jsql = new SQLContext(jsc); - } - - @After - public void tearDown() { - jsc.stop(); - jsc = null; - } +public class JavaNormalizerSuite extends SharedSparkSession { @Test public void normalizer() { @@ -54,7 +37,7 @@ public void normalizer() { new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)), new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0)) )); - Dataset dataFrame = jsql.createDataFrame(points, VectorIndexerSuite.FeatureData.class); + Dataset dataFrame = spark.createDataFrame(points, VectorIndexerSuite.FeatureData.class); Normalizer normalizer = new Normalizer() .setInputCol("features") .setOutputCol("normFeatures"); diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java index 
1389d17e7e07a..ac479c08418ce 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java @@ -23,37 +23,20 @@ import scala.Tuple2; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.function.Function; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.mllib.linalg.distributed.RowMatrix; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.ml.linalg.Vector; +import org.apache.spark.ml.linalg.Vectors; import org.apache.spark.mllib.linalg.Matrix; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.linalg.distributed.RowMatrix; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; - -public class JavaPCASuite implements Serializable { - private transient JavaSparkContext jsc; - private transient SQLContext sqlContext; - - @Before - public void setUp() { - jsc = new JavaSparkContext("local", "JavaPCASuite"); - sqlContext = new SQLContext(jsc); - } - @After - public void tearDown() { - jsc.stop(); - jsc = null; - } +public class JavaPCASuite extends SharedSparkSession { public static class VectorPair implements Serializable { private Vector features = Vectors.dense(0.0); @@ -85,9 +68,25 @@ public void testPCA() { ); JavaRDD dataRDD = jsc.parallelize(points, 2); - RowMatrix mat = new RowMatrix(dataRDD.rdd()); + RowMatrix mat = new RowMatrix(dataRDD.map( + new Function() { + public org.apache.spark.mllib.linalg.Vector call(Vector vector) { + return new org.apache.spark.mllib.linalg.DenseVector(vector.toArray()); + } + } + ).rdd()); + Matrix pc = mat.computePrincipalComponents(3); - JavaRDD expected = mat.multiply(pc).rows().toJavaRDD(); + + mat.multiply(pc).rows().toJavaRDD(); + + JavaRDD expected = mat.multiply(pc).rows().toJavaRDD().map( + new Function() { + public Vector call(org.apache.spark.mllib.linalg.Vector vector) { + return vector.asML(); + } + } + ); JavaRDD featuresExpected = dataRDD.zip(expected).map( new Function, VectorPair>() { @@ -100,7 +99,7 @@ public VectorPair call(Tuple2 pair) { } ); - Dataset df = sqlContext.createDataFrame(featuresExpected, VectorPair.class); + Dataset df = spark.createDataFrame(featuresExpected, VectorPair.class); PCAModel pca = new PCA() .setInputCol("features") .setOutputCol("pca_features") diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java index 6a8bb6480174a..df5d34fbe94e6 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java @@ -20,38 +20,21 @@ import java.util.Arrays; import java.util.List; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.SharedSparkSession; +import org.apache.spark.ml.linalg.Vector; +import org.apache.spark.ml.linalg.VectorUDT; +import org.apache.spark.ml.linalg.Vectors; import 
org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; -public class JavaPolynomialExpansionSuite { - private transient JavaSparkContext jsc; - private transient SQLContext jsql; - - @Before - public void setUp() { - jsc = new JavaSparkContext("local", "JavaPolynomialExpansionSuite"); - jsql = new SQLContext(jsc); - } - - @After - public void tearDown() { - jsc.stop(); - jsc = null; - } +public class JavaPolynomialExpansionSuite extends SharedSparkSession { @Test public void polynomialExpansionTest() { @@ -72,20 +55,20 @@ public void polynomialExpansionTest() { ) ); - StructType schema = new StructType(new StructField[] { + StructType schema = new StructType(new StructField[]{ new StructField("features", new VectorUDT(), false, Metadata.empty()), new StructField("expected", new VectorUDT(), false, Metadata.empty()) }); - Dataset dataset = jsql.createDataFrame(data, schema); + Dataset dataset = spark.createDataFrame(data, schema); List pairs = polyExpansion.transform(dataset) .select("polyFeatures", "expected") .collectAsList(); for (Row r : pairs) { - double[] polyFeatures = ((Vector)r.get(0)).toArray(); - double[] expected = ((Vector)r.get(1)).toArray(); + double[] polyFeatures = ((Vector) r.get(0)).toArray(); + double[] expected = ((Vector) r.get(1)).toArray(); Assert.assertArrayEquals(polyFeatures, expected, 1e-1); } } diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java index 3f6fc333e4e13..dbc0b1db5c002 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java @@ -20,31 +20,14 @@ import java.util.Arrays; import java.util.List; -import org.junit.After; -import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.SharedSparkSession; +import org.apache.spark.ml.linalg.Vectors; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; -public class JavaStandardScalerSuite { - private transient JavaSparkContext jsc; - private transient SQLContext jsql; - - @Before - public void setUp() { - jsc = new JavaSparkContext("local", "JavaStandardScalerSuite"); - jsql = new SQLContext(jsc); - } - - @After - public void tearDown() { - jsc.stop(); - jsc = null; - } +public class JavaStandardScalerSuite extends SharedSparkSession { @Test public void standardScaler() { @@ -54,7 +37,7 @@ public void standardScaler() { new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)), new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0)) ); - Dataset dataFrame = jsql.createDataFrame(jsc.parallelize(points, 2), + Dataset dataFrame = spark.createDataFrame(jsc.parallelize(points, 2), VectorIndexerSuite.FeatureData.class); StandardScaler scaler = new StandardScaler() .setInputCol("features") diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java index bdcbde5e26223..6480b57e1f796 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java +++ 
b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java @@ -20,37 +20,19 @@ import java.util.Arrays; import java.util.List; -import org.junit.After; -import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.SharedSparkSession; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; -public class JavaStopWordsRemoverSuite { - - private transient JavaSparkContext jsc; - private transient SQLContext jsql; - - @Before - public void setUp() { - jsc = new JavaSparkContext("local", "JavaStopWordsRemoverSuite"); - jsql = new SQLContext(jsc); - } - - @After - public void tearDown() { - jsc.stop(); - jsc = null; - } +public class JavaStopWordsRemoverSuite extends SharedSparkSession { @Test public void javaCompatibilityTest() { @@ -62,11 +44,11 @@ public void javaCompatibilityTest() { RowFactory.create(Arrays.asList("I", "saw", "the", "red", "baloon")), RowFactory.create(Arrays.asList("Mary", "had", "a", "little", "lamb")) ); - StructType schema = new StructType(new StructField[] { + StructType schema = new StructType(new StructField[]{ new StructField("raw", DataTypes.createArrayType(DataTypes.StringType), false, - Metadata.empty()) + Metadata.empty()) }); - Dataset dataset = jsql.createDataFrame(data, schema); + Dataset dataset = spark.createDataFrame(data, schema); remover.transform(dataset).collect(); } diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java index 431779cd2e72e..c1928a26b609e 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java @@ -20,45 +20,29 @@ import java.util.Arrays; import java.util.List; -import org.junit.After; +import static org.apache.spark.sql.types.DataTypes.*; + import org.junit.Assert; -import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.SharedSparkSession; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; -import static org.apache.spark.sql.types.DataTypes.*; - -public class JavaStringIndexerSuite { - private transient JavaSparkContext jsc; - private transient SQLContext sqlContext; - @Before - public void setUp() { - jsc = new JavaSparkContext("local", "JavaStringIndexerSuite"); - sqlContext = new SQLContext(jsc); - } - - @After - public void tearDown() { - jsc.stop(); - sqlContext = null; - } +public class JavaStringIndexerSuite extends SharedSparkSession { @Test public void testStringIndexer() { - StructType schema = createStructType(new StructField[] { + StructType schema = createStructType(new StructField[]{ createStructField("id", IntegerType, false), createStructField("label", StringType, false) }); List data = Arrays.asList( cr(0, "a"), cr(1, "b"), cr(2, "c"), cr(3, "a"), cr(4, "a"), cr(5, "c")); - Dataset dataset = sqlContext.createDataFrame(data, schema); + Dataset dataset = spark.createDataFrame(data, 
schema); StringIndexer indexer = new StringIndexer() .setInputCol("label") @@ -70,7 +54,9 @@ public void testStringIndexer() { output.orderBy("id").select("id", "labelIndex").collectAsList()); } - /** An alias for RowFactory.create. */ + /** + * An alias for RowFactory.create. + */ private Row cr(Object... values) { return RowFactory.create(values); } diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java index 83d16cbd0e7a1..27550a3d5c373 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java @@ -20,32 +20,15 @@ import java.util.Arrays; import java.util.List; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; -public class JavaTokenizerSuite { - private transient JavaSparkContext jsc; - private transient SQLContext jsql; - - @Before - public void setUp() { - jsc = new JavaSparkContext("local", "JavaTokenizerSuite"); - jsql = new SQLContext(jsc); - } - - @After - public void tearDown() { - jsc.stop(); - jsc = null; - } +public class JavaTokenizerSuite extends SharedSparkSession { @Test public void regexTokenizer() { @@ -59,10 +42,10 @@ public void regexTokenizer() { JavaRDD rdd = jsc.parallelize(Arrays.asList( - new TokenizerTestData("Test of tok.", new String[] {"Test", "tok."}), - new TokenizerTestData("Te,st. punct", new String[] {"Te,st.", "punct"}) + new TokenizerTestData("Test of tok.", new String[]{"Test", "tok."}), + new TokenizerTestData("Te,st. 
punct", new String[]{"Te,st.", "punct"}) )); - Dataset dataset = jsql.createDataFrame(rdd, TokenizerTestData.class); + Dataset dataset = spark.createDataFrame(rdd, TokenizerTestData.class); List pairs = myRegExTokenizer.transform(dataset) .select("tokens", "wantedTokens") diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java index e45e19804345a..583652badb8fb 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java @@ -19,41 +19,26 @@ import java.util.Arrays; -import org.junit.After; +import static org.apache.spark.sql.types.DataTypes.*; + import org.junit.Assert; -import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.SharedSparkSession; +import org.apache.spark.ml.linalg.Vector; +import org.apache.spark.ml.linalg.VectorUDT; +import org.apache.spark.ml.linalg.Vectors; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.types.*; -import static org.apache.spark.sql.types.DataTypes.*; - -public class JavaVectorAssemblerSuite { - private transient JavaSparkContext jsc; - private transient SQLContext sqlContext; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; - @Before - public void setUp() { - jsc = new JavaSparkContext("local", "JavaVectorAssemblerSuite"); - sqlContext = new SQLContext(jsc); - } - - @After - public void tearDown() { - jsc.stop(); - jsc = null; - } +public class JavaVectorAssemblerSuite extends SharedSparkSession { @Test public void testVectorAssembler() { - StructType schema = createStructType(new StructField[] { + StructType schema = createStructType(new StructField[]{ createStructField("id", IntegerType, false), createStructField("x", DoubleType, false), createStructField("y", new VectorUDT(), false), @@ -63,14 +48,14 @@ public void testVectorAssembler() { }); Row row = RowFactory.create( 0, 0.0, Vectors.dense(1.0, 2.0), "a", - Vectors.sparse(2, new int[] {1}, new double[] {3.0}), 10L); - Dataset dataset = sqlContext.createDataFrame(Arrays.asList(row), schema); + Vectors.sparse(2, new int[]{1}, new double[]{3.0}), 10L); + Dataset dataset = spark.createDataFrame(Arrays.asList(row), schema); VectorAssembler assembler = new VectorAssembler() - .setInputCols(new String[] {"x", "y", "z", "n"}) + .setInputCols(new String[]{"x", "y", "z", "n"}) .setOutputCol("features"); Dataset output = assembler.transform(dataset); Assert.assertEquals( - Vectors.sparse(6, new int[] {1, 2, 4, 5}, new double[] {1.0, 2.0, 3.0, 10.0}), + Vectors.sparse(6, new int[]{1, 2, 4, 5}, new double[]{1.0, 2.0, 3.0, 10.0}), output.select("features").first().getAs(0)); } } diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java index fec6cac8bec30..ca8fae3a48b9d 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java @@ -17,37 +17,21 @@ package org.apache.spark.ml.feature; 
-import java.io.Serializable; import java.util.Arrays; import java.util.List; import java.util.Map; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.SharedSparkSession; import org.apache.spark.ml.feature.VectorIndexerSuite.FeatureData; -import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.ml.linalg.Vectors; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; -public class JavaVectorIndexerSuite implements Serializable { - private transient JavaSparkContext sc; - - @Before - public void setUp() { - sc = new JavaSparkContext("local", "JavaVectorIndexerSuite"); - } - - @After - public void tearDown() { - sc.stop(); - sc = null; - } +public class JavaVectorIndexerSuite extends SharedSparkSession { @Test public void vectorIndexerAPI() { @@ -57,8 +41,7 @@ public void vectorIndexerAPI() { new FeatureData(Vectors.dense(1.0, 3.0)), new FeatureData(Vectors.dense(1.0, 4.0)) ); - SQLContext sqlContext = new SQLContext(sc); - Dataset data = sqlContext.createDataFrame(sc.parallelize(points, 2), FeatureData.class); + Dataset data = spark.createDataFrame(jsc.parallelize(points, 2), FeatureData.class); VectorIndexer indexer = new VectorIndexer() .setInputCol("features") .setOutputCol("indexed") diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java index e2da11183b93f..3dc2e1f896143 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java @@ -20,39 +20,22 @@ import java.util.Arrays; import java.util.List; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.SharedSparkSession; import org.apache.spark.ml.attribute.Attribute; import org.apache.spark.ml.attribute.AttributeGroup; import org.apache.spark.ml.attribute.NumericAttribute; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.ml.linalg.Vector; +import org.apache.spark.ml.linalg.Vectors; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.types.StructType; -public class JavaVectorSlicerSuite { - private transient JavaSparkContext jsc; - private transient SQLContext jsql; - - @Before - public void setUp() { - jsc = new JavaSparkContext("local", "JavaVectorSlicerSuite"); - jsql = new SQLContext(jsc); - } - - @After - public void tearDown() { - jsc.stop(); - jsc = null; - } +public class JavaVectorSlicerSuite extends SharedSparkSession { @Test public void vectorSlice() { @@ -69,7 +52,7 @@ public void vectorSlice() { ); Dataset dataset = - jsql.createDataFrame(data, (new StructType()).add(group.toStructField())); + spark.createDataFrame(data, (new StructType()).add(group.toStructField())); VectorSlicer vectorSlicer = new VectorSlicer() .setInputCol("userFeatures").setOutputCol("features"); diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java index 7517b70cc9bee..d0a849fd11c7e 100644 --- 
a/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java @@ -19,41 +19,24 @@ import java.util.Arrays; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.SharedSparkSession; +import org.apache.spark.ml.linalg.Vector; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.types.*; -public class JavaWord2VecSuite { - private transient JavaSparkContext jsc; - private transient SQLContext sqlContext; - - @Before - public void setUp() { - jsc = new JavaSparkContext("local", "JavaWord2VecSuite"); - sqlContext = new SQLContext(jsc); - } - - @After - public void tearDown() { - jsc.stop(); - jsc = null; - } +public class JavaWord2VecSuite extends SharedSparkSession { @Test public void testJavaWord2Vec() { StructType schema = new StructType(new StructField[]{ new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty()) }); - Dataset documentDF = sqlContext.createDataFrame( + Dataset documentDF = spark.createDataFrame( Arrays.asList( RowFactory.create(Arrays.asList("Hi I heard about Spark".split(" "))), RowFactory.create(Arrays.asList("I wish Java could use case classes".split(" "))), @@ -68,8 +51,8 @@ public void testJavaWord2Vec() { Word2VecModel model = word2Vec.fit(documentDF); Dataset result = model.transform(documentDF); - for (Row r: result.select("result").collectAsList()) { - double[] polyFeatures = ((Vector)r.get(0)).toArray(); + for (Row r : result.select("result").collectAsList()) { + double[] polyFeatures = ((Vector) r.get(0)).toArray(); Assert.assertEquals(polyFeatures.length, 3); } } diff --git a/mllib/src/test/java/org/apache/spark/ml/linalg/JavaSQLDataTypesSuite.java b/mllib/src/test/java/org/apache/spark/ml/linalg/JavaSQLDataTypesSuite.java new file mode 100644 index 0000000000000..bd64a7186eac0 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/linalg/JavaSQLDataTypesSuite.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.linalg; + +import org.junit.Assert; +import org.junit.Test; + +import static org.apache.spark.ml.linalg.SQLDataTypes.*; + +public class JavaSQLDataTypesSuite { + @Test + public void testSQLDataTypes() { + Assert.assertEquals(new VectorUDT(), VectorType()); + Assert.assertEquals(new MatrixUDT(), MatrixType()); + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java index fa777f3d42a9a..1077e103a3b89 100644 --- a/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java @@ -19,31 +19,14 @@ import java.util.Arrays; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaSparkContext; - /** * Test Param and related classes in Java */ public class JavaParamsSuite { - private transient JavaSparkContext jsc; - - @Before - public void setUp() { - jsc = new JavaSparkContext("local", "JavaParamsSuite"); - } - - @After - public void tearDown() { - jsc.stop(); - jsc = null; - } - @Test public void testParams() { JavaTestParams testParams = new JavaTestParams(); @@ -51,7 +34,7 @@ public void testParams() { testParams.setMyIntParam(2).setMyDoubleParam(0.4).setMyStringParam("a"); Assert.assertEquals(testParams.getMyDoubleParam(), 0.4, 0.0); Assert.assertEquals(testParams.getMyStringParam(), "a"); - Assert.assertArrayEquals(testParams.getMyDoubleArrayParam(), new double[] {1.0, 2.0}, 0.0); + Assert.assertArrayEquals(testParams.getMyDoubleArrayParam(), new double[]{1.0, 2.0}, 0.0); } @Test diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java index 06f7fbb86e88e..1ad5f7a442daa 100644 --- a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java +++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java @@ -45,9 +45,14 @@ public String uid() { } private IntParam myIntParam_; - public IntParam myIntParam() { return myIntParam_; } - public int getMyIntParam() { return (Integer)getOrDefault(myIntParam_); } + public IntParam myIntParam() { + return myIntParam_; + } + + public int getMyIntParam() { + return (Integer) getOrDefault(myIntParam_); + } public JavaTestParams setMyIntParam(int value) { set(myIntParam_, value); @@ -55,9 +60,14 @@ public JavaTestParams setMyIntParam(int value) { } private DoubleParam myDoubleParam_; - public DoubleParam myDoubleParam() { return myDoubleParam_; } - public double getMyDoubleParam() { return (Double)getOrDefault(myDoubleParam_); } + public DoubleParam myDoubleParam() { + return myDoubleParam_; + } + + public double getMyDoubleParam() { + return (Double) getOrDefault(myDoubleParam_); + } public JavaTestParams setMyDoubleParam(double value) { set(myDoubleParam_, value); @@ -65,9 +75,14 @@ public JavaTestParams setMyDoubleParam(double value) { } private Param myStringParam_; - public Param myStringParam() { return myStringParam_; } - public String getMyStringParam() { return getOrDefault(myStringParam_); } + public Param myStringParam() { + return myStringParam_; + } + + public String getMyStringParam() { + return getOrDefault(myStringParam_); + } public JavaTestParams setMyStringParam(String value) { set(myStringParam_, value); @@ -75,9 +90,14 @@ public JavaTestParams setMyStringParam(String value) { } private DoubleArrayParam myDoubleArrayParam_; - public 
DoubleArrayParam myDoubleArrayParam() { return myDoubleArrayParam_; } - public double[] getMyDoubleArrayParam() { return getOrDefault(myDoubleArrayParam_); } + public DoubleArrayParam myDoubleArrayParam() { + return myDoubleArrayParam_; + } + + public double[] getMyDoubleArrayParam() { + return getOrDefault(myDoubleArrayParam_); + } public JavaTestParams setMyDoubleArrayParam(double[] value) { set(myDoubleArrayParam_, value); @@ -96,7 +116,7 @@ private void init() { setDefault(myIntParam(), 1); setDefault(myDoubleParam(), 0.5); - setDefault(myDoubleArrayParam(), new double[] {1.0, 2.0}); + setDefault(myDoubleArrayParam(), new double[]{1.0, 2.0}); } @Override diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java index fa3b28ed4f302..1da85ed9dab4e 100644 --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java @@ -17,37 +17,21 @@ package org.apache.spark.ml.regression; -import java.io.Serializable; import java.util.HashMap; import java.util.Map; -import org.junit.After; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.classification.LogisticRegressionSuite; +import org.apache.spark.ml.feature.LabeledPoint; import org.apache.spark.ml.tree.impl.TreeTests; -import org.apache.spark.mllib.classification.LogisticRegressionSuite; -import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -public class JavaDecisionTreeRegressorSuite implements Serializable { - - private transient JavaSparkContext sc; - - @Before - public void setUp() { - sc = new JavaSparkContext("local", "JavaDecisionTreeRegressorSuite"); - } - - @After - public void tearDown() { - sc.stop(); - sc = null; - } +public class JavaDecisionTreeRegressorSuite extends SharedSparkSession { @Test public void runDT() { @@ -55,7 +39,7 @@ public void runDT() { double A = 2.0; double B = -1.5; - JavaRDD data = sc.parallelize( + JavaRDD data = jsc.parallelize( LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); Map categoricalFeatures = new HashMap<>(); Dataset dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0); @@ -70,7 +54,7 @@ public void runDT() { .setCacheNodeIds(false) .setCheckpointInterval(10) .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern - for (String impurity: DecisionTreeRegressor.supportedImpurities()) { + for (String impurity : DecisionTreeRegressor.supportedImpurities()) { dt.setImpurity(impurity); } DecisionTreeRegressionModel model = dt.fit(dataFrame); diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java index 8413ea0e0a940..7fd9b1feb7f83 100644 --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java @@ -17,37 +17,21 @@ package org.apache.spark.ml.regression; -import java.io.Serializable; import java.util.HashMap; import java.util.Map; -import org.junit.After; -import org.junit.Before; import org.junit.Test; +import 
org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.classification.LogisticRegressionSuite; +import org.apache.spark.ml.feature.LabeledPoint; import org.apache.spark.ml.tree.impl.TreeTests; -import org.apache.spark.mllib.classification.LogisticRegressionSuite; -import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -public class JavaGBTRegressorSuite implements Serializable { - - private transient JavaSparkContext sc; - - @Before - public void setUp() { - sc = new JavaSparkContext("local", "JavaGBTRegressorSuite"); - } - - @After - public void tearDown() { - sc.stop(); - sc = null; - } +public class JavaGBTRegressorSuite extends SharedSparkSession { @Test public void runDT() { @@ -55,7 +39,7 @@ public void runDT() { double A = 2.0; double B = -1.5; - JavaRDD data = sc.parallelize( + JavaRDD data = jsc.parallelize( LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); Map categoricalFeatures = new HashMap<>(); Dataset dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0); @@ -73,7 +57,7 @@ public void runDT() { .setMaxIter(3) .setStepSize(0.1) .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern - for (String lossType: GBTRegressor.supportedLossTypes()) { + for (String lossType : GBTRegressor.supportedLossTypes()) { rf.setLossType(lossType); } GBTRegressionModel model = rf.fit(dataFrame); diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java index 9f817515eb86d..6cdcdda1a6480 100644 --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java @@ -17,45 +17,30 @@ package org.apache.spark.ml.regression; -import java.io.Serializable; +import java.io.IOException; import java.util.List; -import org.junit.After; -import org.junit.Before; import org.junit.Test; import static org.junit.Assert.assertEquals; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.mllib.regression.LabeledPoint; +import static org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInputAsList; +import org.apache.spark.ml.feature.LabeledPoint; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; -import static org.apache.spark.mllib.classification.LogisticRegressionSuite - .generateLogisticInputAsList; - -public class JavaLinearRegressionSuite implements Serializable { - - private transient JavaSparkContext jsc; - private transient SQLContext jsql; +public class JavaLinearRegressionSuite extends SharedSparkSession { private transient Dataset dataset; private transient JavaRDD datasetRDD; - @Before - public void setUp() { - jsc = new JavaSparkContext("local", "JavaLinearRegressionSuite"); - jsql = new SQLContext(jsc); + @Override + public void setUp() throws IOException { + super.setUp(); List points = generateLogisticInputAsList(1.0, 1.0, 100, 42); datasetRDD = jsc.parallelize(points, 2); - dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class); - dataset.registerTempTable("dataset"); - } - - @After - public void tearDown() { - jsc.stop(); - jsc = null; + dataset = 
spark.createDataFrame(datasetRDD, LabeledPoint.class); + dataset.createOrReplaceTempView("dataset"); } @Test @@ -64,8 +49,8 @@ public void linearRegressionDefaultParams() { assertEquals("label", lr.getLabelCol()); assertEquals("auto", lr.getSolver()); LinearRegressionModel model = lr.fit(dataset); - model.transform(dataset).registerTempTable("prediction"); - Dataset predictions = jsql.sql("SELECT label, prediction FROM prediction"); + model.transform(dataset).createOrReplaceTempView("prediction"); + Dataset predictions = spark.sql("SELECT label, prediction FROM prediction"); predictions.collect(); // Check defaults assertEquals("features", model.getFeaturesCol()); @@ -76,8 +61,8 @@ public void linearRegressionDefaultParams() { public void linearRegressionWithSetters() { // Set params, train, and check as many params as we can. LinearRegression lr = new LinearRegression() - .setMaxIter(10) - .setRegParam(1.0).setSolver("l-bfgs"); + .setMaxIter(10) + .setRegParam(1.0).setSolver("l-bfgs"); LinearRegressionModel model = lr.fit(dataset); LinearRegression parent = (LinearRegression) model.parent(); assertEquals(10, parent.getMaxIter()); @@ -85,7 +70,7 @@ public void linearRegressionWithSetters() { // Call fit() with new params, and check as many params as we can. LinearRegressionModel model2 = - lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), lr.predictionCol().w("thePred")); + lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), lr.predictionCol().w("thePred")); LinearRegression parent2 = (LinearRegression) model2.parent(); assertEquals(5, parent2.getMaxIter()); assertEquals(0.1, parent2.getRegParam(), 0.0); diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java index 38b895f1fdd7d..4ba13e2e06c8d 100644 --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java @@ -17,39 +17,23 @@ package org.apache.spark.ml.regression; -import java.io.Serializable; import java.util.HashMap; import java.util.Map; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.mllib.classification.LogisticRegressionSuite; +import org.apache.spark.ml.classification.LogisticRegressionSuite; +import org.apache.spark.ml.feature.LabeledPoint; +import org.apache.spark.ml.linalg.Vector; import org.apache.spark.ml.tree.impl.TreeTests; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -public class JavaRandomForestRegressorSuite implements Serializable { - - private transient JavaSparkContext sc; - - @Before - public void setUp() { - sc = new JavaSparkContext("local", "JavaRandomForestRegressorSuite"); - } - - @After - public void tearDown() { - sc.stop(); - sc = null; - } +public class JavaRandomForestRegressorSuite extends SharedSparkSession { @Test public void runDT() { @@ -57,7 +41,7 @@ public void runDT() { double A = 2.0; double B = -1.5; - JavaRDD data = sc.parallelize( + JavaRDD data = jsc.parallelize( LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); Map categoricalFeatures = new HashMap<>(); 
Dataset dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0); @@ -75,22 +59,22 @@ public void runDT() { .setSeed(1234) .setNumTrees(3) .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern - for (String impurity: RandomForestRegressor.supportedImpurities()) { + for (String impurity : RandomForestRegressor.supportedImpurities()) { rf.setImpurity(impurity); } - for (String featureSubsetStrategy: RandomForestRegressor.supportedFeatureSubsetStrategies()) { + for (String featureSubsetStrategy : RandomForestRegressor.supportedFeatureSubsetStrategies()) { rf.setFeatureSubsetStrategy(featureSubsetStrategy); } String[] realStrategies = {".1", ".10", "0.10", "0.1", "0.9", "1.0"}; - for (String strategy: realStrategies) { + for (String strategy : realStrategies) { rf.setFeatureSubsetStrategy(strategy); } String[] integerStrategies = {"1", "10", "100", "1000", "10000"}; - for (String strategy: integerStrategies) { + for (String strategy : integerStrategies) { rf.setFeatureSubsetStrategy(strategy); } String[] invalidStrategies = {"-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0"}; - for (String strategy: invalidStrategies) { + for (String strategy : invalidStrategies) { try { rf.setFeatureSubsetStrategy(strategy); Assert.fail("Expected exception to be thrown for invalid strategies"); diff --git a/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java b/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java index 1c18b2b266fef..fa39f4560c8aa 100644 --- a/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java @@ -23,35 +23,28 @@ import com.google.common.io.Files; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.mllib.linalg.DenseVector; -import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.SharedSparkSession; +import org.apache.spark.ml.linalg.DenseVector; +import org.apache.spark.ml.linalg.Vectors; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; import org.apache.spark.util.Utils; /** * Test LibSVMRelation in Java. 
*/ -public class JavaLibSVMRelationSuite { - private transient JavaSparkContext jsc; - private transient SQLContext sqlContext; +public class JavaLibSVMRelationSuite extends SharedSparkSession { private File tempDir; private String path; - @Before + @Override public void setUp() throws IOException { - jsc = new JavaSparkContext("local", "JavaLibSVMRelationSuite"); - sqlContext = new SQLContext(jsc); - + super.setUp(); tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource"); File file = new File(tempDir, "part-00000"); String s = "1 1:1.0 3:2.0 5:3.0\n0\n0 2:4.0 4:5.0 6:6.0"; @@ -59,16 +52,15 @@ public void setUp() throws IOException { path = tempDir.toURI().toString(); } - @After + @Override public void tearDown() { - jsc.stop(); - jsc = null; + super.tearDown(); Utils.deleteRecursively(tempDir); } @Test public void verifyLibSVMDF() { - Dataset dataset = sqlContext.read().format("libsvm").option("vectorType", "dense") + Dataset dataset = spark.read().format("libsvm").option("vectorType", "dense") .load(path); Assert.assertEquals("label", dataset.columns()[0]); Assert.assertEquals("features", dataset.columns()[1]); diff --git a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java index 24b0097454fe0..692d5ad591e84 100644 --- a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java @@ -17,50 +17,39 @@ package org.apache.spark.ml.tuning; -import java.io.Serializable; +import java.io.IOException; import java.util.List; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.SharedSparkSession; import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator; +import org.apache.spark.ml.feature.LabeledPoint; import org.apache.spark.ml.param.ParamMap; -import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; -import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList; +import static org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInputAsList; -public class JavaCrossValidatorSuite implements Serializable { - private transient JavaSparkContext jsc; - private transient SQLContext jsql; +public class JavaCrossValidatorSuite extends SharedSparkSession { + private transient Dataset dataset; - @Before - public void setUp() { - jsc = new JavaSparkContext("local", "JavaCrossValidatorSuite"); - jsql = new SQLContext(jsc); + @Override + public void setUp() throws IOException { + super.setUp(); List points = generateLogisticInputAsList(1.0, 1.0, 100, 42); - dataset = jsql.createDataFrame(jsc.parallelize(points, 2), LabeledPoint.class); - } - - @After - public void tearDown() { - jsc.stop(); - jsc = null; + dataset = spark.createDataFrame(jsc.parallelize(points, 2), LabeledPoint.class); } @Test public void crossValidationWithLogisticRegression() { LogisticRegression lr = new LogisticRegression(); ParamMap[] lrParamMaps = new ParamGridBuilder() - .addGrid(lr.regParam(), new double[] {0.001, 1000.0}) - .addGrid(lr.maxIter(), new int[] {0, 10}) + .addGrid(lr.regParam(), new double[]{0.001, 1000.0}) + .addGrid(lr.maxIter(), new 
int[]{0, 10}) .build(); BinaryClassificationEvaluator eval = new BinaryClassificationEvaluator(); CrossValidator cv = new CrossValidator() diff --git a/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala b/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala index 928301523fba9..878bc66ee37c3 100644 --- a/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala +++ b/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala @@ -37,4 +37,5 @@ object IdentifiableSuite { class Test(override val uid: String) extends Identifiable { def this() = this(Identifiable.randomUID("test")) } + } diff --git a/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java index 01ff1ea658610..e4f678fef1d13 100644 --- a/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java @@ -20,39 +20,25 @@ import java.io.File; import java.io.IOException; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.SharedSparkSession; import org.apache.spark.util.Utils; -public class JavaDefaultReadWriteSuite { - - JavaSparkContext jsc = null; - SQLContext sqlContext = null; +public class JavaDefaultReadWriteSuite extends SharedSparkSession { File tempDir = null; - @Before - public void setUp() { - jsc = new JavaSparkContext("local[2]", "JavaDefaultReadWriteSuite"); - SQLContext.clearActive(); - sqlContext = new SQLContext(jsc); - SQLContext.setActive(sqlContext); + @Override + public void setUp() throws IOException { + super.setUp(); tempDir = Utils.createTempDir( System.getProperty("java.io.tmpdir"), "JavaDefaultReadWriteSuite"); } - @After + @Override public void tearDown() { - sqlContext = null; - SQLContext.clearActive(); - if (jsc != null) { - jsc.stop(); - jsc = null; - } + super.tearDown(); Utils.deleteRecursively(tempDir); } @@ -70,7 +56,7 @@ public void testDefaultReadWrite() throws IOException { } catch (IOException e) { // expected } - instance.write().context(sqlContext).overwrite().save(outputPath); + instance.write().session(spark).overwrite().save(outputPath); MyParams newInstance = MyParams.load(outputPath); Assert.assertEquals("UID should match.", instance.uid(), newInstance.uid()); Assert.assertEquals("Params should be preserved.", diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java index 862221d48798a..c04e2e69541ba 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java @@ -17,36 +17,20 @@ package org.apache.spark.mllib.classification; -import java.io.Serializable; import java.util.List; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; - import org.apache.spark.mllib.regression.LabeledPoint; -public class JavaLogisticRegressionSuite implements Serializable { - private transient JavaSparkContext sc; - - @Before - public void setUp() 
{ - sc = new JavaSparkContext("local", "JavaLogisticRegressionSuite"); - } - - @After - public void tearDown() { - sc.stop(); - sc = null; - } +public class JavaLogisticRegressionSuite extends SharedSparkSession { int validatePrediction(List validationData, LogisticRegressionModel model) { int numAccurate = 0; - for (LabeledPoint point: validationData) { + for (LabeledPoint point : validationData) { Double prediction = model.predict(point.features()); if (prediction == point.label()) { numAccurate++; @@ -61,16 +45,16 @@ public void runLRUsingConstructor() { double A = 2.0; double B = -1.5; - JavaRDD testRDD = sc.parallelize( - LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); + JavaRDD testRDD = jsc.parallelize( + LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); List validationData = - LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 17); + LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 17); LogisticRegressionWithSGD lrImpl = new LogisticRegressionWithSGD(); lrImpl.setIntercept(true); lrImpl.optimizer().setStepSize(1.0) - .setRegParam(1.0) - .setNumIterations(100); + .setRegParam(1.0) + .setNumIterations(100); LogisticRegressionModel model = lrImpl.run(testRDD.rdd()); int numAccurate = validatePrediction(validationData, model); @@ -83,13 +67,13 @@ public void runLRUsingStaticMethods() { double A = 0.0; double B = -2.5; - JavaRDD testRDD = sc.parallelize( - LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); + JavaRDD testRDD = jsc.parallelize( + LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); List validationData = - LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 17); + LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 17); LogisticRegressionModel model = LogisticRegressionWithSGD.train( - testRDD.rdd(), 100, 1.0, 1.0); + testRDD.rdd(), 100, 1.0, 1.0); int numAccurate = validatePrediction(validationData, model); Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0); diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java index 3771c0ea7ad83..6ded42e928250 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java @@ -17,36 +17,21 @@ package org.apache.spark.mllib.classification; -import java.io.Serializable; import java.util.Arrays; import java.util.List; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.regression.LabeledPoint; -public class JavaNaiveBayesSuite implements Serializable { - private transient JavaSparkContext sc; - - @Before - public void setUp() { - sc = new JavaSparkContext("local", "JavaNaiveBayesSuite"); - } - - @After - public void tearDown() { - sc.stop(); - sc = null; - } +public class JavaNaiveBayesSuite extends SharedSparkSession { private static final List POINTS = Arrays.asList( new LabeledPoint(0, Vectors.dense(1.0, 0.0, 0.0)), @@ -59,7 +44,7 @@ public void 
tearDown() { private int validatePrediction(List points, NaiveBayesModel model) { int correct = 0; - for (LabeledPoint p: points) { + for (LabeledPoint p : points) { if (model.predict(p.features()) == p.label()) { correct += 1; } @@ -69,7 +54,7 @@ private int validatePrediction(List points, NaiveBayesModel model) @Test public void runUsingConstructor() { - JavaRDD testRDD = sc.parallelize(POINTS, 2).cache(); + JavaRDD testRDD = jsc.parallelize(POINTS, 2).cache(); NaiveBayes nb = new NaiveBayes().setLambda(1.0); NaiveBayesModel model = nb.run(testRDD.rdd()); @@ -80,7 +65,7 @@ public void runUsingConstructor() { @Test public void runUsingStaticMethods() { - JavaRDD testRDD = sc.parallelize(POINTS, 2).cache(); + JavaRDD testRDD = jsc.parallelize(POINTS, 2).cache(); NaiveBayesModel model1 = NaiveBayes.train(testRDD.rdd()); int numAccurate1 = validatePrediction(POINTS, model1); @@ -93,13 +78,14 @@ public void runUsingStaticMethods() { @Test public void testPredictJavaRDD() { - JavaRDD examples = sc.parallelize(POINTS, 2).cache(); + JavaRDD examples = jsc.parallelize(POINTS, 2).cache(); NaiveBayesModel model = NaiveBayes.train(examples.rdd()); JavaRDD vectors = examples.map(new Function() { @Override public Vector call(LabeledPoint v) throws Exception { return v.features(); - }}); + } + }); JavaRDD predictions = model.predict(vectors); // Should be able to get the first prediction. predictions.first(); diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java index 31b9f3e8d438e..0f54e684e447d 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java @@ -17,35 +17,20 @@ package org.apache.spark.mllib.classification; -import java.io.Serializable; import java.util.List; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.regression.LabeledPoint; -public class JavaSVMSuite implements Serializable { - private transient JavaSparkContext sc; - - @Before - public void setUp() { - sc = new JavaSparkContext("local", "JavaSVMSuite"); - } - - @After - public void tearDown() { - sc.stop(); - sc = null; - } +public class JavaSVMSuite extends SharedSparkSession { int validatePrediction(List validationData, SVMModel model) { int numAccurate = 0; - for (LabeledPoint point: validationData) { + for (LabeledPoint point : validationData) { Double prediction = model.predict(point.features()); if (prediction == point.label()) { numAccurate++; @@ -60,16 +45,16 @@ public void runSVMUsingConstructor() { double A = 2.0; double[] weights = {-1.5, 1.0}; - JavaRDD testRDD = sc.parallelize(SVMSuite.generateSVMInputAsList(A, - weights, nPoints, 42), 2).cache(); + JavaRDD testRDD = jsc.parallelize(SVMSuite.generateSVMInputAsList(A, + weights, nPoints, 42), 2).cache(); List validationData = - SVMSuite.generateSVMInputAsList(A, weights, nPoints, 17); + SVMSuite.generateSVMInputAsList(A, weights, nPoints, 17); SVMWithSGD svmSGDImpl = new SVMWithSGD(); svmSGDImpl.setIntercept(true); svmSGDImpl.optimizer().setStepSize(1.0) - .setRegParam(1.0) - .setNumIterations(100); + .setRegParam(1.0) + .setNumIterations(100); SVMModel model = svmSGDImpl.run(testRDD.rdd()); int numAccurate = 
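The logistic-regression and SVM hunks keep the training pattern intact and only rename `sc` to `jsc`. For readers skimming the diff, here is a condensed, self-contained sketch of that pattern; the class name `SvmSgdSketch` and the toy data are invented for illustration, while the calls themselves (`setIntercept`, `optimizer().setStepSize(...)`, `run(rdd.rdd())`) are the ones exercised above.

```java
import java.util.Arrays;
import java.util.List;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.classification.SVMModel;
import org.apache.spark.mllib.classification.SVMWithSGD;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;

public final class SvmSgdSketch {
  // Trains an SVM with explicitly configured SGD optimizer settings,
  // mirroring runSVMUsingConstructor() in the patch.
  static SVMModel train(JavaSparkContext jsc) {
    List<LabeledPoint> points = Arrays.asList(
      new LabeledPoint(1.0, Vectors.dense(2.0, 1.0)),
      new LabeledPoint(0.0, Vectors.dense(-2.0, -1.0)));
    JavaRDD<LabeledPoint> data = jsc.parallelize(points, 2).cache();

    SVMWithSGD svm = new SVMWithSGD();
    svm.setIntercept(true);
    svm.optimizer()
      .setStepSize(1.0)
      .setRegParam(1.0)
      .setNumIterations(100);
    return svm.run(data.rdd());  // run() takes the underlying Scala RDD, not the JavaRDD
  }
}
```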
validatePrediction(validationData, model); @@ -82,10 +67,10 @@ public void runSVMUsingStaticMethods() { double A = 0.0; double[] weights = {-1.5, 1.0}; - JavaRDD testRDD = sc.parallelize(SVMSuite.generateSVMInputAsList(A, - weights, nPoints, 42), 2).cache(); + JavaRDD testRDD = jsc.parallelize(SVMSuite.generateSVMInputAsList(A, + weights, nPoints, 42), 2).cache(); List validationData = - SVMSuite.generateSVMInputAsList(A, weights, nPoints, 17); + SVMSuite.generateSVMInputAsList(A, weights, nPoints, 17); SVMModel model = SVMWithSGD.train(testRDD.rdd(), 100, 1.0, 1.0, 1.0); diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java index 62c6d9b7e390a..8c6bced52dd74 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java @@ -17,7 +17,6 @@ package org.apache.spark.mllib.classification; -import java.io.Serializable; import java.util.Arrays; import java.util.List; @@ -37,7 +36,7 @@ import org.apache.spark.streaming.api.java.JavaStreamingContext; import static org.apache.spark.streaming.JavaTestUtils.*; -public class JavaStreamingLogisticRegressionSuite implements Serializable { +public class JavaStreamingLogisticRegressionSuite { protected transient JavaStreamingContext ssc; diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java index a714620ff7e4b..3d62b273d2210 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java @@ -17,39 +17,24 @@ package org.apache.spark.mllib.clustering; -import java.io.Serializable; - import com.google.common.collect.Lists; -import org.junit.After; + import org.junit.Assert; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vectors; -public class JavaBisectingKMeansSuite implements Serializable { - private transient JavaSparkContext sc; - - @Before - public void setUp() { - sc = new JavaSparkContext("local", this.getClass().getSimpleName()); - } - - @After - public void tearDown() { - sc.stop(); - sc = null; - } +public class JavaBisectingKMeansSuite extends SharedSparkSession { @Test public void twoDimensionalData() { - JavaRDD points = sc.parallelize(Lists.newArrayList( + JavaRDD points = jsc.parallelize(Lists.newArrayList( Vectors.dense(4, -1), Vectors.dense(4, 1), - Vectors.sparse(2, new int[] {0}, new double[] {1.0}) + Vectors.sparse(2, new int[]{0}, new double[]{1.0}) ), 2); BisectingKMeans bkm = new BisectingKMeans() @@ -58,15 +43,15 @@ public void twoDimensionalData() { .setSeed(1L); BisectingKMeansModel model = bkm.run(points); Assert.assertEquals(3, model.k()); - Assert.assertArrayEquals(new double[] {3.0, 0.0}, model.root().center().toArray(), 1e-12); - for (ClusteringTreeNode child: model.root().children()) { + Assert.assertArrayEquals(new double[]{3.0, 0.0}, model.root().center().toArray(), 1e-12); + for (ClusteringTreeNode child : model.root().children()) { double[] 
center = child.center().toArray(); if (center[0] > 2) { Assert.assertEquals(2, child.size()); - Assert.assertArrayEquals(new double[] {4.0, 0.0}, center, 1e-12); + Assert.assertArrayEquals(new double[]{4.0, 0.0}, center, 1e-12); } else { Assert.assertEquals(1, child.size()); - Assert.assertArrayEquals(new double[] {1.0, 0.0}, center, 1e-12); + Assert.assertArrayEquals(new double[]{1.0, 0.0}, center, 1e-12); } } } diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java index 123f78da54e34..bf76719937772 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java @@ -17,34 +17,19 @@ package org.apache.spark.mllib.clustering; -import java.io.Serializable; import java.util.Arrays; import java.util.List; -import org.junit.After; -import org.junit.Before; -import org.junit.Test; - import static org.junit.Assert.assertEquals; +import org.junit.Test; + +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vectors; -public class JavaGaussianMixtureSuite implements Serializable { - private transient JavaSparkContext sc; - - @Before - public void setUp() { - sc = new JavaSparkContext("local", "JavaGaussianMixture"); - } - - @After - public void tearDown() { - sc.stop(); - sc = null; - } +public class JavaGaussianMixtureSuite extends SharedSparkSession { @Test public void runGaussianMixture() { @@ -54,7 +39,7 @@ public void runGaussianMixture() { Vectors.dense(1.0, 4.0, 6.0) ); - JavaRDD data = sc.parallelize(points, 2); + JavaRDD data = jsc.parallelize(points, 2); GaussianMixtureModel model = new GaussianMixture().setK(2).setMaxIterations(1).setSeed(1234) .run(data); assertEquals(model.gaussians().length, 2); diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java index ad06676c72ac6..270e636f82117 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java @@ -17,33 +17,19 @@ package org.apache.spark.mllib.clustering; -import java.io.Serializable; import java.util.Arrays; import java.util.List; -import org.junit.After; -import org.junit.Before; +import static org.junit.Assert.assertEquals; + import org.junit.Test; -import static org.junit.Assert.*; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vectors; -public class JavaKMeansSuite implements Serializable { - private transient JavaSparkContext sc; - - @Before - public void setUp() { - sc = new JavaSparkContext("local", "JavaKMeans"); - } - - @After - public void tearDown() { - sc.stop(); - sc = null; - } +public class JavaKMeansSuite extends SharedSparkSession { @Test public void runKMeansUsingStaticMethods() { @@ -55,7 +41,7 @@ public void runKMeansUsingStaticMethods() { Vector expectedCenter = Vectors.dense(1.0, 3.0, 4.0); - JavaRDD data = sc.parallelize(points, 2); + JavaRDD data = jsc.parallelize(points, 2); KMeansModel model = 
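The k-means hunks exercise both entry points with the renamed `jsc`; spelled out with the generics the flattened diff text drops, the usage is roughly the following sketch. The wrapper class name is hypothetical; the calls and sample points come from the suite.

```java
import java.util.Arrays;
import java.util.List;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.clustering.KMeans;
import org.apache.spark.mllib.clustering.KMeansModel;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;

public final class KMeansSketch {
  // Mirrors the two entry points exercised by JavaKMeansSuite: the static
  // train() helper and the builder-style constructor API.
  static void cluster(JavaSparkContext jsc) {
    List<Vector> points = Arrays.asList(
      Vectors.dense(1.0, 2.0, 6.0),
      Vectors.dense(1.0, 3.0, 0.0),
      Vectors.dense(1.0, 4.0, 6.0));
    JavaRDD<Vector> data = jsc.parallelize(points, 2);

    // Static helper: k = 1, maxIterations = 1, runs = 1, k-means|| initialization.
    KMeansModel m1 = KMeans.train(data.rdd(), 1, 1, 1, KMeans.K_MEANS_PARALLEL());

    // Constructor API, then per-point prediction on a JavaRDD.
    KMeansModel m2 = new KMeans().setK(1).setMaxIterations(5).run(data.rdd());
    JavaRDD<Integer> assignments = m2.predict(data);
    assignments.first();  // materialize at least one prediction
  }
}
```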
KMeans.train(data.rdd(), 1, 1, 1, KMeans.K_MEANS_PARALLEL()); assertEquals(1, model.clusterCenters().length); assertEquals(expectedCenter, model.clusterCenters()[0]); @@ -74,7 +60,7 @@ public void runKMeansUsingConstructor() { Vector expectedCenter = Vectors.dense(1.0, 3.0, 4.0); - JavaRDD data = sc.parallelize(points, 2); + JavaRDD data = jsc.parallelize(points, 2); KMeansModel model = new KMeans().setK(1).setMaxIterations(5).run(data.rdd()); assertEquals(1, model.clusterCenters().length); assertEquals(expectedCenter, model.clusterCenters()[0]); @@ -94,7 +80,7 @@ public void testPredictJavaRDD() { Vectors.dense(1.0, 3.0, 0.0), Vectors.dense(1.0, 4.0, 6.0) ); - JavaRDD data = sc.parallelize(points, 2); + JavaRDD data = jsc.parallelize(points, 2); KMeansModel model = new KMeans().setK(1).setMaxIterations(5).run(data.rdd()); JavaRDD predictions = model.predict(data); // Should be able to get the first prediction. diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java index db19b309f65ae..08d6713ab2bc3 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java @@ -17,49 +17,37 @@ package org.apache.spark.mllib.clustering; -import java.io.Serializable; +import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import scala.Tuple2; import scala.Tuple3; -import org.junit.After; -import org.junit.Before; import org.junit.Test; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.*; -import org.apache.spark.api.java.function.Function; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; import org.apache.spark.mllib.linalg.Matrix; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vectors; -public class JavaLDASuite implements Serializable { - private transient JavaSparkContext sc; - - @Before - public void setUp() { - sc = new JavaSparkContext("local", "JavaLDA"); +public class JavaLDASuite extends SharedSparkSession { + @Override + public void setUp() throws IOException { + super.setUp(); ArrayList> tinyCorpus = new ArrayList<>(); for (int i = 0; i < LDASuite.tinyCorpus().length; i++) { - tinyCorpus.add(new Tuple2<>((Long)LDASuite.tinyCorpus()[i]._1(), - LDASuite.tinyCorpus()[i]._2())); + tinyCorpus.add(new Tuple2<>((Long) LDASuite.tinyCorpus()[i]._1(), + LDASuite.tinyCorpus()[i]._2())); } - JavaRDD> tmpCorpus = sc.parallelize(tinyCorpus, 2); + JavaRDD> tmpCorpus = jsc.parallelize(tinyCorpus, 2); corpus = JavaPairRDD.fromJavaRDD(tmpCorpus); } - @After - public void tearDown() { - sc.stop(); - sc = null; - } - @Test public void localLDAModel() { Matrix topics = LDASuite.tinyTopics(); @@ -95,7 +83,7 @@ public void distributedLDAModel() { .setMaxIterations(5) .setSeed(12345); - DistributedLDAModel model = (DistributedLDAModel)lda.run(corpus); + DistributedLDAModel model = (DistributedLDAModel) lda.run(corpus); // Check: basic parameters LocalLDAModel localModel = model.toLocal(); @@ -124,7 +112,7 @@ public void distributedLDAModel() { public Boolean call(Tuple2 tuple2) { return Vectors.norm(tuple2._2(), 1.0) != 0.0; } - }); + }); 
assertEquals(topicDistributions.count(), nonEmptyCorpus.count()); // Check: javaTopTopicsPerDocuments @@ -179,7 +167,7 @@ public void onlineOptimizerCompatibility() { @Test public void localLdaMethods() { - JavaRDD> docs = sc.parallelize(toyData, 2); + JavaRDD> docs = jsc.parallelize(toyData, 2); JavaPairRDD pairedDocs = JavaPairRDD.fromJavaRDD(docs); // check: topicDistributions @@ -191,7 +179,7 @@ public void localLdaMethods() { // check: logLikelihood. ArrayList> docsSingleWord = new ArrayList<>(); docsSingleWord.add(new Tuple2<>(0L, Vectors.dense(1.0, 0.0, 0.0))); - JavaPairRDD single = JavaPairRDD.fromJavaRDD(sc.parallelize(docsSingleWord)); + JavaPairRDD single = JavaPairRDD.fromJavaRDD(jsc.parallelize(docsSingleWord)); double logLikelihood = toyModel.logLikelihood(single); } @@ -199,7 +187,7 @@ public void localLdaMethods() { private static int tinyVocabSize = LDASuite.tinyVocabSize(); private static Matrix tinyTopics = LDASuite.tinyTopics(); private static Tuple2[] tinyTopicDescription = - LDASuite.tinyTopicDescription(); + LDASuite.tinyTopicDescription(); private JavaPairRDD corpus; private LocalLDAModel toyModel = LDASuite.toyModel(); private ArrayList> toyData = LDASuite.javaToyData(); diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java index 62edbd3a298c0..d41fc0e4dca96 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java @@ -17,7 +17,6 @@ package org.apache.spark.mllib.clustering; -import java.io.Serializable; import java.util.Arrays; import java.util.List; @@ -27,8 +26,6 @@ import org.junit.Before; import org.junit.Test; -import static org.apache.spark.streaming.JavaTestUtils.*; - import org.apache.spark.SparkConf; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vectors; @@ -36,8 +33,9 @@ import org.apache.spark.streaming.api.java.JavaDStream; import org.apache.spark.streaming.api.java.JavaPairDStream; import org.apache.spark.streaming.api.java.JavaStreamingContext; +import static org.apache.spark.streaming.JavaTestUtils.*; -public class JavaStreamingKMeansSuite implements Serializable { +public class JavaStreamingKMeansSuite { protected transient JavaStreamingContext ssc; diff --git a/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java index fa4d334801ce4..e9d7e4fdbe8ce 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java @@ -17,41 +17,32 @@ package org.apache.spark.mllib.evaluation; -import java.io.Serializable; +import java.io.IOException; import java.util.Arrays; import java.util.List; import scala.Tuple2; import scala.Tuple2$; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -public class JavaRankingMetricsSuite implements Serializable { - private transient JavaSparkContext sc; +public class JavaRankingMetricsSuite extends SharedSparkSession { private transient JavaRDD, List>> predictionAndLabels; - @Before - public void setUp() { - sc = new 
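The LDA and ranking-metrics hunks show the one non-mechanical part of the migration: suites that used to build fixture RDDs in their own `@Before` now override `setUp()`, call `super.setUp()` first, and only then use `jsc`. A minimal sketch of that shape, with invented fixture data and a hypothetical suite name:

```java
import java.io.IOException;
import java.util.Arrays;

import org.junit.Assert;
import org.junit.Test;

import org.apache.spark.SharedSparkSession;
import org.apache.spark.api.java.JavaRDD;

public class JavaFixtureSuiteSketch extends SharedSparkSession {
  // Fields stay transient because the shared base class is serializable.
  private transient JavaRDD<String> words;

  @Override
  public void setUp() throws IOException {
    super.setUp();  // brings up spark and jsc before the fixture is built
    words = jsc.parallelize(Arrays.asList("a", "b", "c"), 2);
  }

  @Test
  public void fixtureIsAvailable() {
    Assert.assertEquals(3, words.count());
  }
}
```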
JavaSparkContext("local", "JavaRankingMetricsSuite"); - predictionAndLabels = sc.parallelize(Arrays.asList( + @Override + public void setUp() throws IOException { + super.setUp(); + predictionAndLabels = jsc.parallelize(Arrays.asList( Tuple2$.MODULE$.apply( Arrays.asList(1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Arrays.asList(1, 2, 3, 4, 5)), Tuple2$.MODULE$.apply( - Arrays.asList(4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Arrays.asList(1, 2, 3)), + Arrays.asList(4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Arrays.asList(1, 2, 3)), Tuple2$.MODULE$.apply( - Arrays.asList(1, 2, 3, 4, 5), Arrays.asList())), 2); - } - - @After - public void tearDown() { - sc.stop(); - sc = null; + Arrays.asList(1, 2, 3, 4, 5), Arrays.asList())), 2); } @Test diff --git a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java index 8a320afa4b13d..05128ea343420 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java @@ -17,39 +17,24 @@ package org.apache.spark.mllib.feature; -import java.io.Serializable; import java.util.Arrays; import java.util.List; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vector; -public class JavaTfIdfSuite implements Serializable { - private transient JavaSparkContext sc; - - @Before - public void setUp() { - sc = new JavaSparkContext("local", "JavaTfIdfSuite"); - } - - @After - public void tearDown() { - sc.stop(); - sc = null; - } +public class JavaTfIdfSuite extends SharedSparkSession { @Test public void tfIdf() { // The tests are to check Java compatibility. HashingTF tf = new HashingTF(); @SuppressWarnings("unchecked") - JavaRDD> documents = sc.parallelize(Arrays.asList( + JavaRDD> documents = jsc.parallelize(Arrays.asList( Arrays.asList("this is a sentence".split(" ")), Arrays.asList("this is another sentence".split(" ")), Arrays.asList("this is still a sentence".split(" "))), 2); @@ -59,7 +44,7 @@ public void tfIdf() { JavaRDD tfIdfs = idf.fit(termFreqs).transform(termFreqs); List localTfIdfs = tfIdfs.collect(); int indexOfThis = tf.indexOf("this"); - for (Vector v: localTfIdfs) { + for (Vector v : localTfIdfs) { Assert.assertEquals(0.0, v.apply(indexOfThis), 1e-15); } } @@ -69,7 +54,7 @@ public void tfIdfMinimumDocumentFrequency() { // The tests are to check Java compatibility. 
HashingTF tf = new HashingTF(); @SuppressWarnings("unchecked") - JavaRDD> documents = sc.parallelize(Arrays.asList( + JavaRDD> documents = jsc.parallelize(Arrays.asList( Arrays.asList("this is a sentence".split(" ")), Arrays.asList("this is another sentence".split(" ")), Arrays.asList("this is still a sentence".split(" "))), 2); @@ -79,7 +64,7 @@ public void tfIdfMinimumDocumentFrequency() { JavaRDD tfIdfs = idf.fit(termFreqs).transform(termFreqs); List localTfIdfs = tfIdfs.collect(); int indexOfThis = tf.indexOf("this"); - for (Vector v: localTfIdfs) { + for (Vector v : localTfIdfs) { Assert.assertEquals(0.0, v.apply(indexOfThis), 1e-15); } } diff --git a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java index e13ed07e283dd..3e3abddbee638 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java @@ -17,34 +17,20 @@ package org.apache.spark.mllib.feature; -import java.io.Serializable; import java.util.Arrays; import java.util.List; +import com.google.common.base.Strings; + import scala.Tuple2; -import com.google.common.base.Strings; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; - -public class JavaWord2VecSuite implements Serializable { - private transient JavaSparkContext sc; - @Before - public void setUp() { - sc = new JavaSparkContext("local", "JavaWord2VecSuite"); - } - - @After - public void tearDown() { - sc.stop(); - sc = null; - } +public class JavaWord2VecSuite extends SharedSparkSession { @Test @SuppressWarnings("unchecked") @@ -53,7 +39,7 @@ public void word2Vec() { String sentence = Strings.repeat("a b ", 100) + Strings.repeat("a c ", 10); List words = Arrays.asList(sentence.split(" ")); List> localDoc = Arrays.asList(words, words); - JavaRDD> doc = sc.parallelize(localDoc); + JavaRDD> doc = jsc.parallelize(localDoc); Word2Vec word2vec = new Word2Vec() .setVectorSize(10) .setSeed(42L); diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java index 2bef7a8609757..3451e0773759b 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java @@ -16,42 +16,26 @@ */ package org.apache.spark.mllib.fpm; -import java.io.Serializable; import java.util.Arrays; -import org.junit.After; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset; -public class JavaAssociationRulesSuite implements Serializable { - private transient JavaSparkContext sc; - - @Before - public void setUp() { - sc = new JavaSparkContext("local", "JavaFPGrowth"); - } - - @After - public void tearDown() { - sc.stop(); - sc = null; - } +public class JavaAssociationRulesSuite extends SharedSparkSession { @Test public void runAssociationRules() { @SuppressWarnings("unchecked") - JavaRDD> freqItemsets = sc.parallelize(Arrays.asList( - new FreqItemset(new String[] {"a"}, 15L), - new FreqItemset(new String[] {"b"}, 35L), - new 
FreqItemset(new String[] {"a", "b"}, 12L) + JavaRDD> freqItemsets = jsc.parallelize(Arrays.asList( + new FreqItemset(new String[]{"a"}, 15L), + new FreqItemset(new String[]{"b"}, 35L), + new FreqItemset(new String[]{"a", "b"}, 12L) )); JavaRDD> results = (new AssociationRules()).run(freqItemsets); } } - diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java index 916fff14a7214..46e9dd8b59828 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java @@ -18,38 +18,24 @@ package org.apache.spark.mllib.fpm; import java.io.File; -import java.io.Serializable; import java.util.Arrays; import java.util.List; -import org.junit.After; -import org.junit.Before; +import static org.junit.Assert.assertEquals; + import org.junit.Test; -import static org.junit.Assert.*; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.util.Utils; -public class JavaFPGrowthSuite implements Serializable { - private transient JavaSparkContext sc; - - @Before - public void setUp() { - sc = new JavaSparkContext("local", "JavaFPGrowth"); - } - - @After - public void tearDown() { - sc.stop(); - sc = null; - } +public class JavaFPGrowthSuite extends SharedSparkSession { @Test public void runFPGrowth() { @SuppressWarnings("unchecked") - JavaRDD> rdd = sc.parallelize(Arrays.asList( + JavaRDD> rdd = jsc.parallelize(Arrays.asList( Arrays.asList("r z h k p".split(" ")), Arrays.asList("z y x w v u t s".split(" ")), Arrays.asList("s x o n r".split(" ")), @@ -65,7 +51,7 @@ public void runFPGrowth() { List> freqItemsets = model.freqItemsets().toJavaRDD().collect(); assertEquals(18, freqItemsets.size()); - for (FPGrowth.FreqItemset itemset: freqItemsets) { + for (FPGrowth.FreqItemset itemset : freqItemsets) { // Test return types. List items = itemset.javaItems(); long freq = itemset.freq(); @@ -76,7 +62,7 @@ public void runFPGrowth() { public void runFPGrowthSaveLoad() { @SuppressWarnings("unchecked") - JavaRDD> rdd = sc.parallelize(Arrays.asList( + JavaRDD> rdd = jsc.parallelize(Arrays.asList( Arrays.asList("r z h k p".split(" ")), Arrays.asList("z y x w v u t s".split(" ")), Arrays.asList("s x o n r".split(" ")), @@ -94,15 +80,15 @@ public void runFPGrowthSaveLoad() { String outputPath = tempDir.getPath(); try { - model.save(sc.sc(), outputPath); + model.save(spark.sparkContext(), outputPath); @SuppressWarnings("unchecked") FPGrowthModel newModel = - (FPGrowthModel) FPGrowthModel.load(sc.sc(), outputPath); + (FPGrowthModel) FPGrowthModel.load(spark.sparkContext(), outputPath); List> freqItemsets = newModel.freqItemsets().toJavaRDD() .collect(); assertEquals(18, freqItemsets.size()); - for (FPGrowth.FreqItemset itemset: freqItemsets) { + for (FPGrowth.FreqItemset itemset : freqItemsets) { // Test return types. 
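Besides the `jsc` rename, the FP-growth (and, below, PrefixSpan) save/load tests now obtain the `SparkContext` from the shared session via `spark.sparkContext()` instead of `sc.sc()`. A sketch of that round trip; the temp-directory handling here uses plain `java.nio.file` rather than the suite's `Utils` helper, and the wrapper class is invented.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.util.Arrays;
import java.util.List;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.fpm.FPGrowth;
import org.apache.spark.mllib.fpm.FPGrowthModel;
import org.apache.spark.sql.SparkSession;

public final class FpGrowthSaveLoadSketch {
  static void roundTrip(SparkSession spark, JavaSparkContext jsc) throws IOException {
    JavaRDD<List<String>> transactions = jsc.parallelize(Arrays.asList(
      Arrays.asList("r z h k p".split(" ")),
      Arrays.asList("z y x w v u t s".split(" ")),
      Arrays.asList("s x o n r".split(" "))), 2);

    FPGrowthModel<String> model = new FPGrowth()
      .setMinSupport(0.5)
      .setNumPartitions(2)
      .run(transactions);

    String path = Files.createTempDirectory("fpgrowth-model").toString();
    model.save(spark.sparkContext(), path);  // SparkContext comes from the shared session

    @SuppressWarnings("unchecked")
    FPGrowthModel<String> loaded =
      (FPGrowthModel<String>) FPGrowthModel.load(spark.sparkContext(), path);
    loaded.freqItemsets().toJavaRDD().collect();  // materialize to verify the reload
  }
}
```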
List items = itemset.javaItems(); long freq = itemset.freq(); diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java index 8a67793abc142..32d3141149a74 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java @@ -21,33 +21,19 @@ import java.util.Arrays; import java.util.List; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.fpm.PrefixSpan.FreqSequence; import org.apache.spark.util.Utils; -public class JavaPrefixSpanSuite { - private transient JavaSparkContext sc; - - @Before - public void setUp() { - sc = new JavaSparkContext("local", "JavaPrefixSpan"); - } - - @After - public void tearDown() { - sc.stop(); - sc = null; - } +public class JavaPrefixSpanSuite extends SharedSparkSession { @Test public void runPrefixSpan() { - JavaRDD>> sequences = sc.parallelize(Arrays.asList( + JavaRDD>> sequences = jsc.parallelize(Arrays.asList( Arrays.asList(Arrays.asList(1, 2), Arrays.asList(3)), Arrays.asList(Arrays.asList(1), Arrays.asList(3, 2), Arrays.asList(1, 2)), Arrays.asList(Arrays.asList(1, 2), Arrays.asList(5)), @@ -61,7 +47,7 @@ public void runPrefixSpan() { List> localFreqSeqs = freqSeqs.collect(); Assert.assertEquals(5, localFreqSeqs.size()); // Check that each frequent sequence could be materialized. - for (PrefixSpan.FreqSequence freqSeq: localFreqSeqs) { + for (PrefixSpan.FreqSequence freqSeq : localFreqSeqs) { List> seq = freqSeq.javaSequence(); long freq = freqSeq.freq(); } @@ -69,7 +55,7 @@ public void runPrefixSpan() { @Test public void runPrefixSpanSaveLoad() { - JavaRDD>> sequences = sc.parallelize(Arrays.asList( + JavaRDD>> sequences = jsc.parallelize(Arrays.asList( Arrays.asList(Arrays.asList(1, 2), Arrays.asList(3)), Arrays.asList(Arrays.asList(1), Arrays.asList(3, 2), Arrays.asList(1, 2)), Arrays.asList(Arrays.asList(1, 2), Arrays.asList(5)), @@ -85,13 +71,15 @@ public void runPrefixSpanSaveLoad() { String outputPath = tempDir.getPath(); try { - model.save(sc.sc(), outputPath); - PrefixSpanModel newModel = PrefixSpanModel.load(sc.sc(), outputPath); + model.save(spark.sparkContext(), outputPath); + @SuppressWarnings("unchecked") + PrefixSpanModel newModel = + (PrefixSpanModel) PrefixSpanModel.load(spark.sparkContext(), outputPath); JavaRDD> freqSeqs = newModel.freqSequences().toJavaRDD(); List> localFreqSeqs = freqSeqs.collect(); Assert.assertEquals(5, localFreqSeqs.size()); // Check that each frequent sequence could be materialized. 
- for (PrefixSpan.FreqSequence freqSeq: localFreqSeqs) { + for (PrefixSpan.FreqSequence freqSeq : localFreqSeqs) { List> seq = freqSeq.javaSequence(); long freq = freqSeq.freq(); } diff --git a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java index 8beea102efd01..f427846b9ad10 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaMatricesSuite.java @@ -17,147 +17,148 @@ package org.apache.spark.mllib.linalg; -import static org.junit.Assert.*; -import org.junit.Test; - -import java.io.Serializable; import java.util.Random; -public class JavaMatricesSuite implements Serializable { - - @Test - public void randMatrixConstruction() { - Random rng = new Random(24); - Matrix r = Matrices.rand(3, 4, rng); - rng.setSeed(24); - DenseMatrix dr = DenseMatrix.rand(3, 4, rng); - assertArrayEquals(r.toArray(), dr.toArray(), 0.0); - - rng.setSeed(24); - Matrix rn = Matrices.randn(3, 4, rng); - rng.setSeed(24); - DenseMatrix drn = DenseMatrix.randn(3, 4, rng); - assertArrayEquals(rn.toArray(), drn.toArray(), 0.0); - - rng.setSeed(24); - Matrix s = Matrices.sprand(3, 4, 0.5, rng); - rng.setSeed(24); - SparseMatrix sr = SparseMatrix.sprand(3, 4, 0.5, rng); - assertArrayEquals(s.toArray(), sr.toArray(), 0.0); - - rng.setSeed(24); - Matrix sn = Matrices.sprandn(3, 4, 0.5, rng); - rng.setSeed(24); - SparseMatrix srn = SparseMatrix.sprandn(3, 4, 0.5, rng); - assertArrayEquals(sn.toArray(), srn.toArray(), 0.0); - } - - @Test - public void identityMatrixConstruction() { - Matrix r = Matrices.eye(2); - DenseMatrix dr = DenseMatrix.eye(2); - SparseMatrix sr = SparseMatrix.speye(2); - assertArrayEquals(r.toArray(), dr.toArray(), 0.0); - assertArrayEquals(sr.toArray(), dr.toArray(), 0.0); - assertArrayEquals(r.toArray(), new double[]{1.0, 0.0, 0.0, 1.0}, 0.0); - } - - @Test - public void diagonalMatrixConstruction() { - Vector v = Vectors.dense(1.0, 0.0, 2.0); - Vector sv = Vectors.sparse(3, new int[]{0, 2}, new double[]{1.0, 2.0}); - - Matrix m = Matrices.diag(v); - Matrix sm = Matrices.diag(sv); - DenseMatrix d = DenseMatrix.diag(v); - DenseMatrix sd = DenseMatrix.diag(sv); - SparseMatrix s = SparseMatrix.spdiag(v); - SparseMatrix ss = SparseMatrix.spdiag(sv); - - assertArrayEquals(m.toArray(), sm.toArray(), 0.0); - assertArrayEquals(d.toArray(), sm.toArray(), 0.0); - assertArrayEquals(d.toArray(), sd.toArray(), 0.0); - assertArrayEquals(sd.toArray(), s.toArray(), 0.0); - assertArrayEquals(s.toArray(), ss.toArray(), 0.0); - assertArrayEquals(s.values(), ss.values(), 0.0); - assertEquals(2, s.values().length); - assertEquals(2, ss.values().length); - assertEquals(4, s.colPtrs().length); - assertEquals(4, ss.colPtrs().length); - } - - @Test - public void zerosMatrixConstruction() { - Matrix z = Matrices.zeros(2, 2); - Matrix one = Matrices.ones(2, 2); - DenseMatrix dz = DenseMatrix.zeros(2, 2); - DenseMatrix done = DenseMatrix.ones(2, 2); - - assertArrayEquals(z.toArray(), new double[]{0.0, 0.0, 0.0, 0.0}, 0.0); - assertArrayEquals(dz.toArray(), new double[]{0.0, 0.0, 0.0, 0.0}, 0.0); - assertArrayEquals(one.toArray(), new double[]{1.0, 1.0, 1.0, 1.0}, 0.0); - assertArrayEquals(done.toArray(), new double[]{1.0, 1.0, 1.0, 1.0}, 0.0); - } - - @Test - public void sparseDenseConversion() { - int m = 3; - int n = 2; - double[] values = new double[]{1.0, 2.0, 4.0, 5.0}; - double[] allValues = new double[]{1.0, 2.0, 0.0, 0.0, 4.0, 5.0}; 
- int[] colPtrs = new int[]{0, 2, 4}; - int[] rowIndices = new int[]{0, 1, 1, 2}; - - SparseMatrix spMat1 = new SparseMatrix(m, n, colPtrs, rowIndices, values); - DenseMatrix deMat1 = new DenseMatrix(m, n, allValues); - - SparseMatrix spMat2 = deMat1.toSparse(); - DenseMatrix deMat2 = spMat1.toDense(); - - assertArrayEquals(spMat1.toArray(), spMat2.toArray(), 0.0); - assertArrayEquals(deMat1.toArray(), deMat2.toArray(), 0.0); - } - - @Test - public void concatenateMatrices() { - int m = 3; - int n = 2; - - Random rng = new Random(42); - SparseMatrix spMat1 = SparseMatrix.sprand(m, n, 0.5, rng); - rng.setSeed(42); - DenseMatrix deMat1 = DenseMatrix.rand(m, n, rng); - Matrix deMat2 = Matrices.eye(3); - Matrix spMat2 = Matrices.speye(3); - Matrix deMat3 = Matrices.eye(2); - Matrix spMat3 = Matrices.speye(2); - - Matrix spHorz = Matrices.horzcat(new Matrix[]{spMat1, spMat2}); - Matrix deHorz1 = Matrices.horzcat(new Matrix[]{deMat1, deMat2}); - Matrix deHorz2 = Matrices.horzcat(new Matrix[]{spMat1, deMat2}); - Matrix deHorz3 = Matrices.horzcat(new Matrix[]{deMat1, spMat2}); - - assertEquals(3, deHorz1.numRows()); - assertEquals(3, deHorz2.numRows()); - assertEquals(3, deHorz3.numRows()); - assertEquals(3, spHorz.numRows()); - assertEquals(5, deHorz1.numCols()); - assertEquals(5, deHorz2.numCols()); - assertEquals(5, deHorz3.numCols()); - assertEquals(5, spHorz.numCols()); - - Matrix spVert = Matrices.vertcat(new Matrix[]{spMat1, spMat3}); - Matrix deVert1 = Matrices.vertcat(new Matrix[]{deMat1, deMat3}); - Matrix deVert2 = Matrices.vertcat(new Matrix[]{spMat1, deMat3}); - Matrix deVert3 = Matrices.vertcat(new Matrix[]{deMat1, spMat3}); - - assertEquals(5, deVert1.numRows()); - assertEquals(5, deVert2.numRows()); - assertEquals(5, deVert3.numRows()); - assertEquals(5, spVert.numRows()); - assertEquals(2, deVert1.numCols()); - assertEquals(2, deVert2.numCols()); - assertEquals(2, deVert3.numCols()); - assertEquals(2, spVert.numCols()); - } +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; + +import org.junit.Test; + +public class JavaMatricesSuite { + + @Test + public void randMatrixConstruction() { + Random rng = new Random(24); + Matrix r = Matrices.rand(3, 4, rng); + rng.setSeed(24); + DenseMatrix dr = DenseMatrix.rand(3, 4, rng); + assertArrayEquals(r.toArray(), dr.toArray(), 0.0); + + rng.setSeed(24); + Matrix rn = Matrices.randn(3, 4, rng); + rng.setSeed(24); + DenseMatrix drn = DenseMatrix.randn(3, 4, rng); + assertArrayEquals(rn.toArray(), drn.toArray(), 0.0); + + rng.setSeed(24); + Matrix s = Matrices.sprand(3, 4, 0.5, rng); + rng.setSeed(24); + SparseMatrix sr = SparseMatrix.sprand(3, 4, 0.5, rng); + assertArrayEquals(s.toArray(), sr.toArray(), 0.0); + + rng.setSeed(24); + Matrix sn = Matrices.sprandn(3, 4, 0.5, rng); + rng.setSeed(24); + SparseMatrix srn = SparseMatrix.sprandn(3, 4, 0.5, rng); + assertArrayEquals(sn.toArray(), srn.toArray(), 0.0); + } + + @Test + public void identityMatrixConstruction() { + Matrix r = Matrices.eye(2); + DenseMatrix dr = DenseMatrix.eye(2); + SparseMatrix sr = SparseMatrix.speye(2); + assertArrayEquals(r.toArray(), dr.toArray(), 0.0); + assertArrayEquals(sr.toArray(), dr.toArray(), 0.0); + assertArrayEquals(r.toArray(), new double[]{1.0, 0.0, 0.0, 1.0}, 0.0); + } + + @Test + public void diagonalMatrixConstruction() { + Vector v = Vectors.dense(1.0, 0.0, 2.0); + Vector sv = Vectors.sparse(3, new int[]{0, 2}, new double[]{1.0, 2.0}); + + Matrix m = Matrices.diag(v); + Matrix sm = Matrices.diag(sv); + 
DenseMatrix d = DenseMatrix.diag(v); + DenseMatrix sd = DenseMatrix.diag(sv); + SparseMatrix s = SparseMatrix.spdiag(v); + SparseMatrix ss = SparseMatrix.spdiag(sv); + + assertArrayEquals(m.toArray(), sm.toArray(), 0.0); + assertArrayEquals(d.toArray(), sm.toArray(), 0.0); + assertArrayEquals(d.toArray(), sd.toArray(), 0.0); + assertArrayEquals(sd.toArray(), s.toArray(), 0.0); + assertArrayEquals(s.toArray(), ss.toArray(), 0.0); + assertArrayEquals(s.values(), ss.values(), 0.0); + assertEquals(2, s.values().length); + assertEquals(2, ss.values().length); + assertEquals(4, s.colPtrs().length); + assertEquals(4, ss.colPtrs().length); + } + + @Test + public void zerosMatrixConstruction() { + Matrix z = Matrices.zeros(2, 2); + Matrix one = Matrices.ones(2, 2); + DenseMatrix dz = DenseMatrix.zeros(2, 2); + DenseMatrix done = DenseMatrix.ones(2, 2); + + assertArrayEquals(z.toArray(), new double[]{0.0, 0.0, 0.0, 0.0}, 0.0); + assertArrayEquals(dz.toArray(), new double[]{0.0, 0.0, 0.0, 0.0}, 0.0); + assertArrayEquals(one.toArray(), new double[]{1.0, 1.0, 1.0, 1.0}, 0.0); + assertArrayEquals(done.toArray(), new double[]{1.0, 1.0, 1.0, 1.0}, 0.0); + } + + @Test + public void sparseDenseConversion() { + int m = 3; + int n = 2; + double[] values = new double[]{1.0, 2.0, 4.0, 5.0}; + double[] allValues = new double[]{1.0, 2.0, 0.0, 0.0, 4.0, 5.0}; + int[] colPtrs = new int[]{0, 2, 4}; + int[] rowIndices = new int[]{0, 1, 1, 2}; + + SparseMatrix spMat1 = new SparseMatrix(m, n, colPtrs, rowIndices, values); + DenseMatrix deMat1 = new DenseMatrix(m, n, allValues); + + SparseMatrix spMat2 = deMat1.toSparse(); + DenseMatrix deMat2 = spMat1.toDense(); + + assertArrayEquals(spMat1.toArray(), spMat2.toArray(), 0.0); + assertArrayEquals(deMat1.toArray(), deMat2.toArray(), 0.0); + } + + @Test + public void concatenateMatrices() { + int m = 3; + int n = 2; + + Random rng = new Random(42); + SparseMatrix spMat1 = SparseMatrix.sprand(m, n, 0.5, rng); + rng.setSeed(42); + DenseMatrix deMat1 = DenseMatrix.rand(m, n, rng); + Matrix deMat2 = Matrices.eye(3); + Matrix spMat2 = Matrices.speye(3); + Matrix deMat3 = Matrices.eye(2); + Matrix spMat3 = Matrices.speye(2); + + Matrix spHorz = Matrices.horzcat(new Matrix[]{spMat1, spMat2}); + Matrix deHorz1 = Matrices.horzcat(new Matrix[]{deMat1, deMat2}); + Matrix deHorz2 = Matrices.horzcat(new Matrix[]{spMat1, deMat2}); + Matrix deHorz3 = Matrices.horzcat(new Matrix[]{deMat1, spMat2}); + + assertEquals(3, deHorz1.numRows()); + assertEquals(3, deHorz2.numRows()); + assertEquals(3, deHorz3.numRows()); + assertEquals(3, spHorz.numRows()); + assertEquals(5, deHorz1.numCols()); + assertEquals(5, deHorz2.numCols()); + assertEquals(5, deHorz3.numCols()); + assertEquals(5, spHorz.numCols()); + + Matrix spVert = Matrices.vertcat(new Matrix[]{spMat1, spMat3}); + Matrix deVert1 = Matrices.vertcat(new Matrix[]{deMat1, deMat3}); + Matrix deVert2 = Matrices.vertcat(new Matrix[]{spMat1, deMat3}); + Matrix deVert3 = Matrices.vertcat(new Matrix[]{deMat1, spMat3}); + + assertEquals(5, deVert1.numRows()); + assertEquals(5, deVert2.numRows()); + assertEquals(5, deVert3.numRows()); + assertEquals(5, spVert.numRows()); + assertEquals(2, deVert1.numCols()); + assertEquals(2, deVert2.numCols()); + assertEquals(2, deVert3.numCols()); + assertEquals(2, spVert.numCols()); + } } diff --git a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java index 4ba8e543a9a6b..f67f555e418a7 100644 --- 
a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java @@ -17,15 +17,15 @@ package org.apache.spark.mllib.linalg; -import java.io.Serializable; import java.util.Arrays; +import static org.junit.Assert.assertArrayEquals; + import scala.Tuple2; import org.junit.Test; -import static org.junit.Assert.*; -public class JavaVectorsSuite implements Serializable { +public class JavaVectorsSuite { @Test public void denseArrayConstruction() { @@ -37,8 +37,8 @@ public void denseArrayConstruction() { public void sparseArrayConstruction() { @SuppressWarnings("unchecked") Vector v = Vectors.sparse(3, Arrays.asList( - new Tuple2<>(0, 2.0), - new Tuple2<>(2, 3.0))); + new Tuple2<>(0, 2.0), + new Tuple2<>(2, 3.0))); assertArrayEquals(new double[]{2.0, 0.0, 3.0}, v.toArray(), 0.0); } } diff --git a/mllib/src/test/java/org/apache/spark/mllib/linalg/distributed/JavaRowMatrixSuite.java b/mllib/src/test/java/org/apache/spark/mllib/linalg/distributed/JavaRowMatrixSuite.java new file mode 100644 index 0000000000000..c01af405491b7 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/linalg/distributed/JavaRowMatrixSuite.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.linalg.distributed; + +import java.util.Arrays; + +import org.junit.Test; + +import org.apache.spark.SharedSparkSession; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.mllib.linalg.Matrix; +import org.apache.spark.mllib.linalg.QRDecomposition; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; + +public class JavaRowMatrixSuite extends SharedSparkSession { + + @Test + public void rowMatrixQRDecomposition() { + Vector v1 = Vectors.dense(1.0, 10.0, 100.0); + Vector v2 = Vectors.dense(2.0, 20.0, 200.0); + Vector v3 = Vectors.dense(3.0, 30.0, 300.0); + + JavaRDD rows = jsc.parallelize(Arrays.asList(v1, v2, v3), 1); + RowMatrix mat = new RowMatrix(rows.rdd()); + + QRDecomposition result = mat.tallSkinnyQR(true); + } +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java index be58691f4d87e..6d114024c31be 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java @@ -20,40 +20,26 @@ import java.io.Serializable; import java.util.Arrays; -import org.apache.spark.api.java.JavaRDD; import org.junit.Assert; -import org.junit.After; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaDoubleRDD; -import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.JavaRDD; import org.apache.spark.mllib.linalg.Vector; import static org.apache.spark.mllib.random.RandomRDDs.*; -public class JavaRandomRDDsSuite { - private transient JavaSparkContext sc; - - @Before - public void setUp() { - sc = new JavaSparkContext("local", "JavaRandomRDDsSuite"); - } - - @After - public void tearDown() { - sc.stop(); - sc = null; - } +public class JavaRandomRDDsSuite extends SharedSparkSession { @Test public void testUniformRDD() { long m = 1000L; int p = 2; long seed = 1L; - JavaDoubleRDD rdd1 = uniformJavaRDD(sc, m); - JavaDoubleRDD rdd2 = uniformJavaRDD(sc, m, p); - JavaDoubleRDD rdd3 = uniformJavaRDD(sc, m, p, seed); - for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { + JavaDoubleRDD rdd1 = uniformJavaRDD(jsc, m); + JavaDoubleRDD rdd2 = uniformJavaRDD(jsc, m, p); + JavaDoubleRDD rdd3 = uniformJavaRDD(jsc, m, p, seed); + for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); } } @@ -63,10 +49,10 @@ public void testNormalRDD() { long m = 1000L; int p = 2; long seed = 1L; - JavaDoubleRDD rdd1 = normalJavaRDD(sc, m); - JavaDoubleRDD rdd2 = normalJavaRDD(sc, m, p); - JavaDoubleRDD rdd3 = normalJavaRDD(sc, m, p, seed); - for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { + JavaDoubleRDD rdd1 = normalJavaRDD(jsc, m); + JavaDoubleRDD rdd2 = normalJavaRDD(jsc, m, p); + JavaDoubleRDD rdd3 = normalJavaRDD(jsc, m, p, seed); + for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); } } @@ -78,10 +64,10 @@ public void testLNormalRDD() { long m = 1000L; int p = 2; long seed = 1L; - JavaDoubleRDD rdd1 = logNormalJavaRDD(sc, mean, std, m); - JavaDoubleRDD rdd2 = logNormalJavaRDD(sc, mean, std, m, p); - JavaDoubleRDD rdd3 = logNormalJavaRDD(sc, mean, std, m, p, seed); - for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { + JavaDoubleRDD rdd1 = logNormalJavaRDD(jsc, mean, std, m); + JavaDoubleRDD rdd2 = 
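The newly added `JavaRowMatrixSuite` above simply checks that the QR decomposition is callable from Java. With the generic parameters spelled out (they are dropped in the flattened diff text), the call is roughly the following; the wrapper class is hypothetical.

```java
import java.util.Arrays;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.linalg.QRDecomposition;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.linalg.distributed.RowMatrix;

public final class RowMatrixQrSketch {
  static void qr(JavaSparkContext jsc) {
    JavaRDD<Vector> rows = jsc.parallelize(Arrays.asList(
      Vectors.dense(1.0, 10.0, 100.0),
      Vectors.dense(2.0, 20.0, 200.0),
      Vectors.dense(3.0, 30.0, 300.0)), 1);
    RowMatrix mat = new RowMatrix(rows.rdd());

    // tallSkinnyQR(true) also computes Q; Q comes back as a distributed RowMatrix,
    // R as a small local upper-triangular Matrix.
    QRDecomposition<RowMatrix, Matrix> result = mat.tallSkinnyQR(true);
    Matrix r = result.R();
  }
}
```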
logNormalJavaRDD(jsc, mean, std, m, p); + JavaDoubleRDD rdd3 = logNormalJavaRDD(jsc, mean, std, m, p, seed); + for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); } } @@ -92,10 +78,10 @@ public void testPoissonRDD() { long m = 1000L; int p = 2; long seed = 1L; - JavaDoubleRDD rdd1 = poissonJavaRDD(sc, mean, m); - JavaDoubleRDD rdd2 = poissonJavaRDD(sc, mean, m, p); - JavaDoubleRDD rdd3 = poissonJavaRDD(sc, mean, m, p, seed); - for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { + JavaDoubleRDD rdd1 = poissonJavaRDD(jsc, mean, m); + JavaDoubleRDD rdd2 = poissonJavaRDD(jsc, mean, m, p); + JavaDoubleRDD rdd3 = poissonJavaRDD(jsc, mean, m, p, seed); + for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); } } @@ -106,10 +92,10 @@ public void testExponentialRDD() { long m = 1000L; int p = 2; long seed = 1L; - JavaDoubleRDD rdd1 = exponentialJavaRDD(sc, mean, m); - JavaDoubleRDD rdd2 = exponentialJavaRDD(sc, mean, m, p); - JavaDoubleRDD rdd3 = exponentialJavaRDD(sc, mean, m, p, seed); - for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { + JavaDoubleRDD rdd1 = exponentialJavaRDD(jsc, mean, m); + JavaDoubleRDD rdd2 = exponentialJavaRDD(jsc, mean, m, p); + JavaDoubleRDD rdd3 = exponentialJavaRDD(jsc, mean, m, p, seed); + for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); } } @@ -117,14 +103,14 @@ public void testExponentialRDD() { @Test public void testGammaRDD() { double shape = 1.0; - double scale = 2.0; + double jscale = 2.0; long m = 1000L; int p = 2; long seed = 1L; - JavaDoubleRDD rdd1 = gammaJavaRDD(sc, shape, scale, m); - JavaDoubleRDD rdd2 = gammaJavaRDD(sc, shape, scale, m, p); - JavaDoubleRDD rdd3 = gammaJavaRDD(sc, shape, scale, m, p, seed); - for (JavaDoubleRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { + JavaDoubleRDD rdd1 = gammaJavaRDD(jsc, shape, jscale, m); + JavaDoubleRDD rdd2 = gammaJavaRDD(jsc, shape, jscale, m, p); + JavaDoubleRDD rdd3 = gammaJavaRDD(jsc, shape, jscale, m, p, seed); + for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); } } @@ -137,10 +123,10 @@ public void testUniformVectorRDD() { int n = 10; int p = 2; long seed = 1L; - JavaRDD rdd1 = uniformJavaVectorRDD(sc, m, n); - JavaRDD rdd2 = uniformJavaVectorRDD(sc, m, n, p); - JavaRDD rdd3 = uniformJavaVectorRDD(sc, m, n, p, seed); - for (JavaRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { + JavaRDD rdd1 = uniformJavaVectorRDD(jsc, m, n); + JavaRDD rdd2 = uniformJavaVectorRDD(jsc, m, n, p); + JavaRDD rdd3 = uniformJavaVectorRDD(jsc, m, n, p, seed); + for (JavaRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); Assert.assertEquals(n, rdd.first().size()); } @@ -153,10 +139,10 @@ public void testNormalVectorRDD() { int n = 10; int p = 2; long seed = 1L; - JavaRDD rdd1 = normalJavaVectorRDD(sc, m, n); - JavaRDD rdd2 = normalJavaVectorRDD(sc, m, n, p); - JavaRDD rdd3 = normalJavaVectorRDD(sc, m, n, p, seed); - for (JavaRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { + JavaRDD rdd1 = normalJavaVectorRDD(jsc, m, n); + JavaRDD rdd2 = normalJavaVectorRDD(jsc, m, n, p); + JavaRDD rdd3 = normalJavaVectorRDD(jsc, m, n, p, seed); + for (JavaRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); Assert.assertEquals(n, rdd.first().size()); } @@ -171,10 +157,10 @@ public void testLogNormalVectorRDD() { int n = 10; int p = 2; long seed = 1L; - JavaRDD rdd1 = logNormalJavaVectorRDD(sc, mean, std, m, 
n); - JavaRDD rdd2 = logNormalJavaVectorRDD(sc, mean, std, m, n, p); - JavaRDD rdd3 = logNormalJavaVectorRDD(sc, mean, std, m, n, p, seed); - for (JavaRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { + JavaRDD rdd1 = logNormalJavaVectorRDD(jsc, mean, std, m, n); + JavaRDD rdd2 = logNormalJavaVectorRDD(jsc, mean, std, m, n, p); + JavaRDD rdd3 = logNormalJavaVectorRDD(jsc, mean, std, m, n, p, seed); + for (JavaRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); Assert.assertEquals(n, rdd.first().size()); } @@ -188,10 +174,10 @@ public void testPoissonVectorRDD() { int n = 10; int p = 2; long seed = 1L; - JavaRDD rdd1 = poissonJavaVectorRDD(sc, mean, m, n); - JavaRDD rdd2 = poissonJavaVectorRDD(sc, mean, m, n, p); - JavaRDD rdd3 = poissonJavaVectorRDD(sc, mean, m, n, p, seed); - for (JavaRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { + JavaRDD rdd1 = poissonJavaVectorRDD(jsc, mean, m, n); + JavaRDD rdd2 = poissonJavaVectorRDD(jsc, mean, m, n, p); + JavaRDD rdd3 = poissonJavaVectorRDD(jsc, mean, m, n, p, seed); + for (JavaRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); Assert.assertEquals(n, rdd.first().size()); } @@ -205,10 +191,10 @@ public void testExponentialVectorRDD() { int n = 10; int p = 2; long seed = 1L; - JavaRDD rdd1 = exponentialJavaVectorRDD(sc, mean, m, n); - JavaRDD rdd2 = exponentialJavaVectorRDD(sc, mean, m, n, p); - JavaRDD rdd3 = exponentialJavaVectorRDD(sc, mean, m, n, p, seed); - for (JavaRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { + JavaRDD rdd1 = exponentialJavaVectorRDD(jsc, mean, m, n); + JavaRDD rdd2 = exponentialJavaVectorRDD(jsc, mean, m, n, p); + JavaRDD rdd3 = exponentialJavaVectorRDD(jsc, mean, m, n, p, seed); + for (JavaRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); Assert.assertEquals(n, rdd.first().size()); } @@ -218,15 +204,15 @@ public void testExponentialVectorRDD() { @SuppressWarnings("unchecked") public void testGammaVectorRDD() { double shape = 1.0; - double scale = 2.0; + double jscale = 2.0; long m = 100L; int n = 10; int p = 2; long seed = 1L; - JavaRDD rdd1 = gammaJavaVectorRDD(sc, shape, scale, m, n); - JavaRDD rdd2 = gammaJavaVectorRDD(sc, shape, scale, m, n, p); - JavaRDD rdd3 = gammaJavaVectorRDD(sc, shape, scale, m, n, p, seed); - for (JavaRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { + JavaRDD rdd1 = gammaJavaVectorRDD(jsc, shape, jscale, m, n); + JavaRDD rdd2 = gammaJavaVectorRDD(jsc, shape, jscale, m, n, p); + JavaRDD rdd3 = gammaJavaVectorRDD(jsc, shape, jscale, m, n, p, seed); + for (JavaRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); Assert.assertEquals(n, rdd.first().size()); } @@ -238,10 +224,10 @@ public void testArbitrary() { long seed = 1L; int numPartitions = 0; StringGenerator gen = new StringGenerator(); - JavaRDD rdd1 = randomJavaRDD(sc, gen, size); - JavaRDD rdd2 = randomJavaRDD(sc, gen, size, numPartitions); - JavaRDD rdd3 = randomJavaRDD(sc, gen, size, numPartitions, seed); - for (JavaRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { + JavaRDD rdd1 = randomJavaRDD(jsc, gen, size); + JavaRDD rdd2 = randomJavaRDD(jsc, gen, size, numPartitions); + JavaRDD rdd3 = randomJavaRDD(jsc, gen, size, numPartitions, seed); + for (JavaRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(size, rdd.count()); Assert.assertEquals(2, rdd.first().length()); } @@ -255,10 +241,10 @@ public void testRandomVectorRDD() { int n = 10; int p = 2; long seed = 1L; - JavaRDD rdd1 = randomJavaVectorRDD(sc, generator, m, n); - 
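The `RandomRDDs` hunks are a mechanical `sc` to `jsc` swap across the statically imported generator helpers (a local `scale` variable is renamed to `jscale` along the way). A small sketch of the helper overloads being exercised; the sizes and seed are illustrative.

```java
import java.util.Arrays;

import org.apache.spark.api.java.JavaDoubleRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector;

import static org.apache.spark.mllib.random.RandomRDDs.*;

public final class RandomRddSketch {
  static void generate(JavaSparkContext jsc) {
    long size = 1000L;
    int numPartitions = 2;
    long seed = 1L;

    // Scalar RDDs: each overload adds partitioning and then an explicit seed.
    JavaDoubleRDD uniform = uniformJavaRDD(jsc, size, numPartitions, seed);
    JavaDoubleRDD normal = normalJavaRDD(jsc, size, numPartitions, seed);

    // Vector RDDs: 100 rows of 10-dimensional standard-normal vectors.
    JavaRDD<Vector> vectors = normalJavaVectorRDD(jsc, 100L, 10, numPartitions, seed);

    for (JavaDoubleRDD rdd : Arrays.asList(uniform, normal)) {
      rdd.count();
    }
    vectors.first().size();
  }
}
```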
JavaRDD rdd2 = randomJavaVectorRDD(sc, generator, m, n, p); - JavaRDD rdd3 = randomJavaVectorRDD(sc, generator, m, n, p, seed); - for (JavaRDD rdd: Arrays.asList(rdd1, rdd2, rdd3)) { + JavaRDD rdd1 = randomJavaVectorRDD(jsc, generator, m, n); + JavaRDD rdd2 = randomJavaVectorRDD(jsc, generator, m, n, p); + JavaRDD rdd3 = randomJavaVectorRDD(jsc, generator, m, n, p, seed); + for (JavaRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) { Assert.assertEquals(m, rdd.count()); Assert.assertEquals(n, rdd.first().size()); } @@ -271,10 +257,12 @@ class StringGenerator implements RandomDataGenerator, Serializable { public String nextValue() { return "42"; } + @Override public StringGenerator copy() { return new StringGenerator(); } + @Override public void setSeed(long seed) { } diff --git a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java index d0bf7f556dcc0..363ab42546d11 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java @@ -17,55 +17,40 @@ package org.apache.spark.mllib.recommendation; -import java.io.Serializable; import java.util.ArrayList; import java.util.List; import scala.Tuple2; import scala.Tuple3; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -public class JavaALSSuite implements Serializable { - private transient JavaSparkContext sc; - - @Before - public void setUp() { - sc = new JavaSparkContext("local", "JavaALS"); - } - - @After - public void tearDown() { - sc.stop(); - sc = null; - } +public class JavaALSSuite extends SharedSparkSession { private void validatePrediction( - MatrixFactorizationModel model, - int users, - int products, - double[] trueRatings, - double matchThreshold, - boolean implicitPrefs, - double[] truePrefs) { + MatrixFactorizationModel model, + int users, + int products, + double[] trueRatings, + double matchThreshold, + boolean implicitPrefs, + double[] truePrefs) { List> localUsersProducts = new ArrayList<>(users * products); - for (int u=0; u < users; ++u) { - for (int p=0; p < products; ++p) { + for (int u = 0; u < users; ++u) { + for (int p = 0; p < products; ++p) { localUsersProducts.add(new Tuple2<>(u, p)); } } - JavaPairRDD usersProducts = sc.parallelizePairs(localUsersProducts); + JavaPairRDD usersProducts = jsc.parallelizePairs(localUsersProducts); List predictedRatings = model.predict(usersProducts).collect(); Assert.assertEquals(users * products, predictedRatings.size()); if (!implicitPrefs) { - for (Rating r: predictedRatings) { + for (Rating r : predictedRatings) { double prediction = r.rating(); double correct = trueRatings[r.product() * users + r.user()]; Assert.assertTrue(String.format("Prediction=%2.4f not below match threshold of %2.2f", @@ -76,7 +61,7 @@ private void validatePrediction( // (ref Mahout's implicit ALS tests) double sqErr = 0.0; double denom = 0.0; - for (Rating r: predictedRatings) { + for (Rating r : predictedRatings) { double prediction = r.rating(); double truePref = truePrefs[r.product() * users + r.user()]; double confidence = 1.0 + @@ -98,9 +83,9 @@ public void runALSUsingStaticMethods() { int users = 50; int products = 100; Tuple3, double[], double[]> testData = - 
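The ALS hunks likewise touch only indentation and the `jsc` rename. The two training entry points the suite exercises are the static `ALS.train`/`trainImplicit` helpers and the builder-style constructor; a sketch with invented ratings and a hypothetical wrapper class:

```java
import java.util.Arrays;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.recommendation.ALS;
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
import org.apache.spark.mllib.recommendation.Rating;

public final class AlsSketch {
  static void train(JavaSparkContext jsc) {
    JavaRDD<Rating> ratings = jsc.parallelize(Arrays.asList(
      new Rating(0, 0, 5.0),
      new Rating(0, 1, 1.0),
      new Rating(1, 0, 1.0),
      new Rating(1, 1, 5.0)), 2);

    int rank = 2;
    int iterations = 5;

    // Static helper for explicit-feedback ALS.
    MatrixFactorizationModel m1 = ALS.train(ratings.rdd(), rank, iterations);

    // Constructor API with implicit preferences, as in runImplicitALSUsingConstructor().
    MatrixFactorizationModel m2 = new ALS()
      .setRank(rank)
      .setIterations(iterations)
      .setImplicitPrefs(true)
      .run(ratings.rdd());

    double prediction = m1.predict(0, 1);  // predicted rating for (user 0, product 1)
  }
}
```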
ALSSuite.generateRatingsAsJava(users, products, features, 0.7, false, false); + ALSSuite.generateRatingsAsJava(users, products, features, 0.7, false, false); - JavaRDD data = sc.parallelize(testData._1()); + JavaRDD data = jsc.parallelize(testData._1()); MatrixFactorizationModel model = ALS.train(data.rdd(), features, iterations); validatePrediction(model, users, products, testData._2(), 0.3, false, testData._3()); } @@ -112,9 +97,9 @@ public void runALSUsingConstructor() { int users = 100; int products = 200; Tuple3, double[], double[]> testData = - ALSSuite.generateRatingsAsJava(users, products, features, 0.7, false, false); + ALSSuite.generateRatingsAsJava(users, products, features, 0.7, false, false); - JavaRDD data = sc.parallelize(testData._1()); + JavaRDD data = jsc.parallelize(testData._1()); MatrixFactorizationModel model = new ALS().setRank(features) .setIterations(iterations) @@ -129,9 +114,9 @@ public void runImplicitALSUsingStaticMethods() { int users = 80; int products = 160; Tuple3, double[], double[]> testData = - ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, false); + ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, false); - JavaRDD data = sc.parallelize(testData._1()); + JavaRDD data = jsc.parallelize(testData._1()); MatrixFactorizationModel model = ALS.trainImplicit(data.rdd(), features, iterations); validatePrediction(model, users, products, testData._2(), 0.4, true, testData._3()); } @@ -143,9 +128,9 @@ public void runImplicitALSUsingConstructor() { int users = 100; int products = 200; Tuple3, double[], double[]> testData = - ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, false); + ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, false); - JavaRDD data = sc.parallelize(testData._1()); + JavaRDD data = jsc.parallelize(testData._1()); MatrixFactorizationModel model = new ALS().setRank(features) .setIterations(iterations) @@ -161,9 +146,9 @@ public void runImplicitALSWithNegativeWeight() { int users = 80; int products = 160; Tuple3, double[], double[]> testData = - ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, true); + ALSSuite.generateRatingsAsJava(users, products, features, 0.7, true, true); - JavaRDD data = sc.parallelize(testData._1()); + JavaRDD data = jsc.parallelize(testData._1()); MatrixFactorizationModel model = new ALS().setRank(features) .setIterations(iterations) .setImplicitPrefs(true) @@ -179,8 +164,8 @@ public void runRecommend() { int users = 200; int products = 50; List testData = ALSSuite.generateRatingsAsJava( - users, products, features, 0.7, true, false)._1(); - JavaRDD data = sc.parallelize(testData); + users, products, features, 0.7, true, false)._1(); + JavaRDD data = jsc.parallelize(testData); MatrixFactorizationModel model = new ALS().setRank(features) .setIterations(iterations) .setImplicitPrefs(true) @@ -193,7 +178,7 @@ public void runRecommend() { private static void validateRecommendations(Rating[] recommendations, int howMany) { Assert.assertEquals(howMany, recommendations.length); for (int i = 1; i < recommendations.length; i++) { - Assert.assertTrue(recommendations[i-1].rating() >= recommendations[i].rating()); + Assert.assertTrue(recommendations[i - 1].rating() >= recommendations[i].rating()); } Assert.assertTrue(recommendations[0].rating() > 0.7); } diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java 
b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java index 3db9b39e740e7..dbd4cbfd2b746 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java @@ -17,30 +17,26 @@ package org.apache.spark.mllib.regression; -import java.io.Serializable; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import scala.Tuple3; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaDoubleRDD; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -public class JavaIsotonicRegressionSuite implements Serializable { - private transient JavaSparkContext sc; +public class JavaIsotonicRegressionSuite extends SharedSparkSession { private static List> generateIsotonicInput(double[] labels) { List> input = new ArrayList<>(labels.length); for (int i = 1; i <= labels.length; i++) { - input.add(new Tuple3<>(labels[i-1], (double) i, 1.0)); + input.add(new Tuple3<>(labels[i - 1], (double) i, 1.0)); } return input; @@ -48,29 +44,18 @@ private static List> generateIsotonicInput(double private IsotonicRegressionModel runIsotonicRegression(double[] labels) { JavaRDD> trainRDD = - sc.parallelize(generateIsotonicInput(labels), 2).cache(); + jsc.parallelize(generateIsotonicInput(labels), 2).cache(); return new IsotonicRegression().run(trainRDD); } - @Before - public void setUp() { - sc = new JavaSparkContext("local", "JavaLinearRegressionSuite"); - } - - @After - public void tearDown() { - sc.stop(); - sc = null; - } - @Test public void testIsotonicRegressionJavaRDD() { IsotonicRegressionModel model = runIsotonicRegression(new double[]{1, 2, 3, 3, 1, 6, 7, 8, 11, 9, 10, 12}); Assert.assertArrayEquals( - new double[] {1, 2, 7.0/3, 7.0/3, 6, 7, 8, 10, 10, 12}, model.predictions(), 1.0e-14); + new double[]{1, 2, 7.0 / 3, 7.0 / 3, 6, 7, 8, 10, 10, 12}, model.predictions(), 1.0e-14); } @Test @@ -78,7 +63,7 @@ public void testIsotonicRegressionPredictionsJavaRDD() { IsotonicRegressionModel model = runIsotonicRegression(new double[]{1, 2, 3, 3, 1, 6, 7, 8, 11, 9, 10, 12}); - JavaDoubleRDD testRDD = sc.parallelizeDoubles(Arrays.asList(0.0, 1.0, 9.5, 12.0, 13.0)); + JavaDoubleRDD testRDD = jsc.parallelizeDoubles(Arrays.asList(0.0, 1.0, 9.5, 12.0, 13.0)); List predictions = model.predict(testRDD).collect(); Assert.assertEquals(1.0, predictions.get(0).doubleValue(), 1.0e-14); diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java index 8950b48888b74..1458cc72bc17f 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java @@ -17,35 +17,20 @@ package org.apache.spark.mllib.regression; -import java.io.Serializable; import java.util.List; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.util.LinearDataGenerator; -public class JavaLassoSuite implements Serializable { - private transient JavaSparkContext sc; - - @Before - public void setUp() { - sc = new 
JavaSparkContext("local", "JavaLassoSuite"); - } - - @After - public void tearDown() { - sc.stop(); - sc = null; - } +public class JavaLassoSuite extends SharedSparkSession { int validatePrediction(List validationData, LassoModel model) { int numAccurate = 0; - for (LabeledPoint point: validationData) { + for (LabeledPoint point : validationData) { Double prediction = model.predict(point.features()); // A prediction is off if the prediction is more than 0.5 away from expected value. if (Math.abs(prediction - point.label()) <= 0.5) { @@ -61,15 +46,15 @@ public void runLassoUsingConstructor() { double A = 0.0; double[] weights = {-1.5, 1.0e-2}; - JavaRDD testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A, - weights, nPoints, 42, 0.1), 2).cache(); + JavaRDD testRDD = jsc.parallelize(LinearDataGenerator.generateLinearInputAsList(A, + weights, nPoints, 42, 0.1), 2).cache(); List validationData = - LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1); + LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1); LassoWithSGD lassoSGDImpl = new LassoWithSGD(); lassoSGDImpl.optimizer().setStepSize(1.0) - .setRegParam(0.01) - .setNumIterations(20); + .setRegParam(0.01) + .setNumIterations(20); LassoModel model = lassoSGDImpl.run(testRDD.rdd()); int numAccurate = validatePrediction(validationData, model); @@ -82,10 +67,10 @@ public void runLassoUsingStaticMethods() { double A = 0.0; double[] weights = {-1.5, 1.0e-2}; - JavaRDD testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A, - weights, nPoints, 42, 0.1), 2).cache(); + JavaRDD testRDD = jsc.parallelize(LinearDataGenerator.generateLinearInputAsList(A, + weights, nPoints, 42, 0.1), 2).cache(); List validationData = - LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1); + LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1); LassoModel model = LassoWithSGD.train(testRDD.rdd(), 100, 1.0, 0.01, 1.0); diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java index 24c4c20d9af18..a46b1321b3ca2 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLinearRegressionSuite.java @@ -17,42 +17,27 @@ package org.apache.spark.mllib.regression; -import java.io.Serializable; import java.util.List; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.function.Function; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.util.LinearDataGenerator; -public class JavaLinearRegressionSuite implements Serializable { - private transient JavaSparkContext sc; - - @Before - public void setUp() { - sc = new JavaSparkContext("local", "JavaLinearRegressionSuite"); - } - - @After - public void tearDown() { - sc.stop(); - sc = null; - } +public class JavaLinearRegressionSuite extends SharedSparkSession { int validatePrediction(List validationData, LinearRegressionModel model) { int numAccurate = 0; - for (LabeledPoint point: validationData) { - Double prediction = model.predict(point.features()); - // A prediction is off if the 
prediction is more than 0.5 away from expected value. - if (Math.abs(prediction - point.label()) <= 0.5) { - numAccurate++; - } + for (LabeledPoint point : validationData) { + Double prediction = model.predict(point.features()); + // A prediction is off if the prediction is more than 0.5 away from expected value. + if (Math.abs(prediction - point.label()) <= 0.5) { + numAccurate++; + } } return numAccurate; } @@ -63,10 +48,10 @@ public void runLinearRegressionUsingConstructor() { double A = 3.0; double[] weights = {10, 10}; - JavaRDD testRDD = sc.parallelize( - LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache(); + JavaRDD testRDD = jsc.parallelize( + LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache(); List validationData = - LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1); + LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1); LinearRegressionWithSGD linSGDImpl = new LinearRegressionWithSGD(); linSGDImpl.setIntercept(true); @@ -82,10 +67,10 @@ public void runLinearRegressionUsingStaticMethods() { double A = 0.0; double[] weights = {10, 10}; - JavaRDD testRDD = sc.parallelize( - LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache(); + JavaRDD testRDD = jsc.parallelize( + LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache(); List validationData = - LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1); + LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1); LinearRegressionModel model = LinearRegressionWithSGD.train(testRDD.rdd(), 100); @@ -98,7 +83,7 @@ public void testPredictJavaRDD() { int nPoints = 100; double A = 0.0; double[] weights = {10, 10}; - JavaRDD testRDD = sc.parallelize( + JavaRDD testRDD = jsc.parallelize( LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache(); LinearRegressionWithSGD linSGDImpl = new LinearRegressionWithSGD(); LinearRegressionModel model = linSGDImpl.run(testRDD.rdd()); diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java index c56db703ea0b4..cb00977412345 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java @@ -17,37 +17,22 @@ package org.apache.spark.mllib.regression; -import java.io.Serializable; import java.util.List; import java.util.Random; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.util.LinearDataGenerator; -public class JavaRidgeRegressionSuite implements Serializable { - private transient JavaSparkContext sc; - - @Before - public void setUp() { - sc = new JavaSparkContext("local", "JavaRidgeRegressionSuite"); - } - - @After - public void tearDown() { - sc.stop(); - sc = null; - } +public class JavaRidgeRegressionSuite extends SharedSparkSession { private static double predictionError(List validationData, RidgeRegressionModel model) { double errorSum = 0; - for (LabeledPoint point: validationData) { + for (LabeledPoint point : validationData) { Double prediction = model.predict(point.features()); 
errorSum += (prediction - point.label()) * (prediction - point.label()); } @@ -68,9 +53,9 @@ private static List generateRidgeData(int numPoints, int numFeatur public void runRidgeRegressionUsingConstructor() { int numExamples = 50; int numFeatures = 20; - List data = generateRidgeData(2*numExamples, numFeatures, 10.0); + List data = generateRidgeData(2 * numExamples, numFeatures, 10.0); - JavaRDD testRDD = sc.parallelize(data.subList(0, numExamples)); + JavaRDD testRDD = jsc.parallelize(data.subList(0, numExamples)); List validationData = data.subList(numExamples, 2 * numExamples); RidgeRegressionWithSGD ridgeSGDImpl = new RidgeRegressionWithSGD(); @@ -94,7 +79,7 @@ public void runRidgeRegressionUsingStaticMethods() { int numFeatures = 20; List data = generateRidgeData(2 * numExamples, numFeatures, 10.0); - JavaRDD testRDD = sc.parallelize(data.subList(0, numExamples)); + JavaRDD testRDD = jsc.parallelize(data.subList(0, numExamples)); List validationData = data.subList(numExamples, 2 * numExamples); RidgeRegressionModel model = RidgeRegressionWithSGD.train(testRDD.rdd(), 200, 1.0, 0.0); diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java index ea0ccd7448986..ab554475d59a1 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java @@ -17,7 +17,6 @@ package org.apache.spark.mllib.regression; -import java.io.Serializable; import java.util.Arrays; import java.util.List; @@ -36,7 +35,7 @@ import org.apache.spark.streaming.api.java.JavaStreamingContext; import static org.apache.spark.streaming.JavaTestUtils.*; -public class JavaStreamingLinearRegressionSuite implements Serializable { +public class JavaStreamingLinearRegressionSuite { protected transient JavaStreamingContext ssc; diff --git a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java index 5f1d5987e8098..1abaa39eadc22 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/stat/JavaStatisticsSuite.java @@ -17,20 +17,17 @@ package org.apache.spark.mllib.stat; -import java.io.Serializable; import java.util.Arrays; import java.util.List; import org.junit.After; import org.junit.Before; import org.junit.Test; - -import static org.apache.spark.streaming.JavaTestUtils.*; import static org.junit.Assert.assertEquals; import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaDoubleRDD; +import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.regression.LabeledPoint; @@ -38,36 +35,42 @@ import org.apache.spark.mllib.stat.test.ChiSqTestResult; import org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult; import org.apache.spark.mllib.stat.test.StreamingTest; +import org.apache.spark.sql.SparkSession; import org.apache.spark.streaming.Duration; import org.apache.spark.streaming.api.java.JavaDStream; import org.apache.spark.streaming.api.java.JavaStreamingContext; +import static org.apache.spark.streaming.JavaTestUtils.*; -public class JavaStatisticsSuite implements Serializable { - private transient 
JavaSparkContext sc; +public class JavaStatisticsSuite { + private transient SparkSession spark; + private transient JavaSparkContext jsc; private transient JavaStreamingContext ssc; @Before public void setUp() { SparkConf conf = new SparkConf() - .setMaster("local[2]") - .setAppName("JavaStatistics") .set("spark.streaming.clock", "org.apache.spark.util.ManualClock"); - sc = new JavaSparkContext(conf); - ssc = new JavaStreamingContext(sc, new Duration(1000)); + spark = SparkSession.builder() + .master("local[2]") + .appName("JavaStatistics") + .config(conf) + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); + ssc = new JavaStreamingContext(jsc, new Duration(1000)); ssc.checkpoint("checkpoint"); } @After public void tearDown() { + spark.stop(); ssc.stop(); - ssc = null; - sc = null; + spark = null; } @Test public void testCorr() { - JavaRDD x = sc.parallelize(Arrays.asList(1.0, 2.0, 3.0, 4.0)); - JavaRDD y = sc.parallelize(Arrays.asList(1.1, 2.2, 3.1, 4.3)); + JavaRDD x = jsc.parallelize(Arrays.asList(1.0, 2.0, 3.0, 4.0)); + JavaRDD y = jsc.parallelize(Arrays.asList(1.1, 2.2, 3.1, 4.3)); Double corr1 = Statistics.corr(x, y); Double corr2 = Statistics.corr(x, y, "pearson"); @@ -77,7 +80,7 @@ public void testCorr() { @Test public void kolmogorovSmirnovTest() { - JavaDoubleRDD data = sc.parallelizeDoubles(Arrays.asList(0.2, 1.0, -1.0, 2.0)); + JavaDoubleRDD data = jsc.parallelizeDoubles(Arrays.asList(0.2, 1.0, -1.0, 2.0)); KolmogorovSmirnovTestResult testResult1 = Statistics.kolmogorovSmirnovTest(data, "norm"); KolmogorovSmirnovTestResult testResult2 = Statistics.kolmogorovSmirnovTest( data, "norm", 0.0, 1.0); @@ -85,7 +88,7 @@ public void kolmogorovSmirnovTest() { @Test public void chiSqTest() { - JavaRDD data = sc.parallelize(Arrays.asList( + JavaRDD data = jsc.parallelize(Arrays.asList( new LabeledPoint(0.0, Vectors.dense(0.1, 2.3)), new LabeledPoint(1.0, Vectors.dense(1.5, 5.1)), new LabeledPoint(0.0, Vectors.dense(2.4, 8.1)))); diff --git a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java index 8dd29061daaad..1dcbbcaa0223c 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java @@ -17,41 +17,27 @@ package org.apache.spark.mllib.tree; -import java.io.Serializable; import java.util.HashMap; import java.util.List; -import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.Test; +import org.apache.spark.SharedSparkSession; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.mllib.tree.configuration.Algo; import org.apache.spark.mllib.tree.configuration.Strategy; import org.apache.spark.mllib.tree.impurity.Gini; import org.apache.spark.mllib.tree.model.DecisionTreeModel; - -public class JavaDecisionTreeSuite implements Serializable { - private transient JavaSparkContext sc; - - @Before - public void setUp() { - sc = new JavaSparkContext("local", "JavaDecisionTreeSuite"); - } - - @After - public void tearDown() { - sc.stop(); - sc = null; - } +public class JavaDecisionTreeSuite extends SharedSparkSession { int validatePrediction(List validationData, DecisionTreeModel model) { int numCorrect = 0; - for (LabeledPoint 
point: validationData) { + for (LabeledPoint point : validationData) { Double prediction = model.predict(point.features()); if (prediction == point.label()) { numCorrect++; @@ -63,7 +49,7 @@ int validatePrediction(List validationData, DecisionTreeModel mode @Test public void runDTUsingConstructor() { List arr = DecisionTreeSuite.generateCategoricalDataPointsAsJavaList(); - JavaRDD rdd = sc.parallelize(arr); + JavaRDD rdd = jsc.parallelize(arr); HashMap categoricalFeaturesInfo = new HashMap<>(); categoricalFeaturesInfo.put(1, 2); // feature 1 has 2 categories @@ -71,7 +57,7 @@ public void runDTUsingConstructor() { int numClasses = 2; int maxBins = 100; Strategy strategy = new Strategy(Algo.Classification(), Gini.instance(), maxDepth, numClasses, - maxBins, categoricalFeaturesInfo); + maxBins, categoricalFeaturesInfo); DecisionTree learner = new DecisionTree(strategy); DecisionTreeModel model = learner.run(rdd.rdd()); @@ -83,7 +69,7 @@ public void runDTUsingConstructor() { @Test public void runDTUsingStaticMethods() { List arr = DecisionTreeSuite.generateCategoricalDataPointsAsJavaList(); - JavaRDD rdd = sc.parallelize(arr); + JavaRDD rdd = jsc.parallelize(arr); HashMap categoricalFeaturesInfo = new HashMap<>(); categoricalFeaturesInfo.put(1, 2); // feature 1 has 2 categories @@ -91,10 +77,18 @@ public void runDTUsingStaticMethods() { int numClasses = 2; int maxBins = 100; Strategy strategy = new Strategy(Algo.Classification(), Gini.instance(), maxDepth, numClasses, - maxBins, categoricalFeaturesInfo); + maxBins, categoricalFeaturesInfo); DecisionTreeModel model = DecisionTree$.MODULE$.train(rdd.rdd(), strategy); + // java compatibility test + JavaRDD predictions = model.predict(rdd.map(new Function() { + @Override + public Vector call(LabeledPoint v1) { + return v1.features(); + } + })); + int numCorrect = validatePrediction(arr, model); Assert.assertTrue(numCorrect == rdd.count()); } diff --git a/mllib/src/test/java/org/apache/spark/mllib/util/JavaMLUtilsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/util/JavaMLUtilsSuite.java new file mode 100644 index 0000000000000..e271a0a77c782 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/util/JavaMLUtilsSuite.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.util; + +import java.util.Arrays; +import java.util.Collections; + +import org.junit.Assert; +import org.junit.Test; + +import org.apache.spark.SharedSparkSession; +import org.apache.spark.mllib.linalg.*; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +public class JavaMLUtilsSuite extends SharedSparkSession { + + @Test + public void testConvertVectorColumnsToAndFromML() { + Vector x = Vectors.dense(2.0); + Dataset dataset = spark.createDataFrame( + Collections.singletonList(new LabeledPoint(1.0, x)), LabeledPoint.class + ).select("label", "features"); + Dataset newDataset1 = MLUtils.convertVectorColumnsToML(dataset); + Row new1 = newDataset1.first(); + Assert.assertEquals(RowFactory.create(1.0, x.asML()), new1); + Row new2 = MLUtils.convertVectorColumnsToML(dataset, "features").first(); + Assert.assertEquals(new1, new2); + Row old1 = MLUtils.convertVectorColumnsFromML(newDataset1).first(); + Assert.assertEquals(RowFactory.create(1.0, x), old1); + } + + @Test + public void testConvertMatrixColumnsToAndFromML() { + Matrix x = Matrices.dense(2, 1, new double[]{1.0, 2.0}); + StructType schema = new StructType(new StructField[]{ + new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("features", new MatrixUDT(), false, Metadata.empty()) + }); + Dataset dataset = spark.createDataFrame( + Arrays.asList( + RowFactory.create(1.0, x)), + schema); + + Dataset newDataset1 = MLUtils.convertMatrixColumnsToML(dataset); + Row new1 = newDataset1.first(); + Assert.assertEquals(RowFactory.create(1.0, x.asML()), new1); + Row new2 = MLUtils.convertMatrixColumnsToML(dataset, "features").first(); + Assert.assertEquals(new1, new2); + Row old1 = MLUtils.convertMatrixColumnsFromML(newDataset1).first(); + Assert.assertEquals(RowFactory.create(1.0, x), old1); + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 1de638f245635..3b490cdf56018 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -27,9 +27,9 @@ import org.scalatest.mock.MockitoSugar.mock import org.apache.spark.SparkFunSuite import org.apache.spark.ml.Pipeline.SharedReadWrite import org.apache.spark.ml.feature.{HashingTF, MinMaxScaler} -import org.apache.spark.ml.param.{IntParam, ParamMap, ParamPair} +import org.apache.spark.ml.linalg.Vectors +import org.apache.spark.ml.param.{IntParam, ParamMap} import org.apache.spark.ml.util._ -import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.types.StructType @@ -183,7 +183,7 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul } test("pipeline validateParams") { - val df = sqlContext.createDataFrame( + val df = spark.createDataFrame( Seq( (1, Vectors.dense(0.0, 1.0, 4.0), 1.0), (2, Vectors.dense(1.0, 0.0, 4.0), 2.0), diff --git a/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala index dc91fc5f9e458..35586320cb82b 100644 --- 
a/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala @@ -18,10 +18,9 @@ package org.apache.spark.ml.ann import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.ml.linalg.Vectors +import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.mllib.util.TestingUtils._ - class ANNSuite extends SparkFunSuite with MLlibTestSparkContext { diff --git a/mllib/src/test/scala/org/apache/spark/ml/ann/GradientSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/ann/GradientSuite.scala index 04cc426c40b5e..f0c0183323c92 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/ann/GradientSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/ann/GradientSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.ann import breeze.linalg.{DenseMatrix => BDM} import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.ml.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext class GradientSuite extends SparkFunSuite with MLlibTestSparkContext { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala index 89afb94b0f963..4db5f03fb00b4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala @@ -19,10 +19,10 @@ package org.apache.spark.ml.classification import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.classification.ClassifierSuite.MockClassifier +import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.util.Identifiable -import org.apache.spark.mllib.linalg.{Vector, Vectors} -import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset} @@ -32,7 +32,7 @@ class ClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { test("extractLabeledPoints") { def getTestData(labels: Seq[Double]): DataFrame = { val data = labels.map { label: Double => LabeledPoint(label, Vectors.dense(0.0)) } - sqlContext.createDataFrame(data) + spark.createDataFrame(data) } val c = new MockClassifier @@ -72,7 +72,7 @@ class ClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { test("getNumClasses") { def getTestData(labels: Seq[Double]): DataFrame = { val data = labels.map { label: Double => LabeledPoint(label, Vectors.dense(0.0)) } - sqlContext.createDataFrame(data) + spark.createDataFrame(data) } val c = new MockClassifier diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 29845b5554bf8..089d30abb5ef9 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -18,12 +18,13 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.linalg.{Vector, Vectors} import 
org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.tree.{CategoricalSplit, InternalNode, LeafNode} import org.apache.spark.ml.tree.impl.TreeTests import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.linalg.{Vector, Vectors} -import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD @@ -44,17 +45,18 @@ class DecisionTreeClassifierSuite override def beforeAll() { super.beforeAll() categoricalDataPointsRDD = - sc.parallelize(OldDecisionTreeSuite.generateCategoricalDataPoints()) + sc.parallelize(OldDecisionTreeSuite.generateCategoricalDataPoints()).map(_.asML) orderedLabeledPointsWithLabel0RDD = - sc.parallelize(OldDecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()) + sc.parallelize(OldDecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()).map(_.asML) orderedLabeledPointsWithLabel1RDD = - sc.parallelize(OldDecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()) + sc.parallelize(OldDecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()).map(_.asML) categoricalDataPointsForMulticlassRDD = - sc.parallelize(OldDecisionTreeSuite.generateCategoricalDataPointsForMulticlass()) + sc.parallelize(OldDecisionTreeSuite.generateCategoricalDataPointsForMulticlass()).map(_.asML) continuousDataPointsForMulticlassRDD = - sc.parallelize(OldDecisionTreeSuite.generateContinuousDataPointsForMulticlass()) + sc.parallelize(OldDecisionTreeSuite.generateContinuousDataPointsForMulticlass()).map(_.asML) categoricalDataPointsForMulticlassForOrderedFeaturesRDD = sc.parallelize( OldDecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()) + .map(_.asML) } test("params") { @@ -337,13 +339,13 @@ class DecisionTreeClassifierSuite test("should support all NumericType labels and not support other types") { val dt = new DecisionTreeClassifier().setMaxDepth(1) MLTestingUtils.checkNumericTypes[DecisionTreeClassificationModel, DecisionTreeClassifier]( - dt, isClassification = true, sqlContext) { (expected, actual) => + dt, spark) { (expected, actual) => TreeTests.checkEqual(expected, actual) } } test("Fitting without numClasses in metadata") { - val df: DataFrame = sqlContext.createDataFrame(TreeTests.featureImportanceData(sc)) + val df: DataFrame = spark.createDataFrame(TreeTests.featureImportanceData(sc)) val dt = new DecisionTreeClassifier().setMaxDepth(1) dt.fit(df) } @@ -395,7 +397,7 @@ private[ml] object DecisionTreeClassifierSuite extends SparkFunSuite { numClasses: Int): Unit = { val numFeatures = data.first().features.size val oldStrategy = dt.getOldStrategy(categoricalFeatures, numClasses) - val oldTree = OldDecisionTree.train(data, oldStrategy) + val oldTree = OldDecisionTree.train(data.map(OldLabeledPoint.fromML), oldStrategy) val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses) val newTree = dt.fit(newData) // Use parent from newTree since this is not checked anyways. 
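
The recurring `.map(_.asML)` and `LabeledPoint.fromML` edits in the suites above bridge the old `spark.mllib` data types and the new `spark.ml` ones. As a rough, stand-alone sketch of that conversion (illustrative only, not part of the patch):

```scala
import org.apache.spark.ml.feature.{LabeledPoint => NewLabeledPoint}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}

// Build a point against the new spark.ml linalg types.
val mlPoint = NewLabeledPoint(1.0, Vectors.dense(0.5, 1.5))

// Convert to the legacy spark.mllib type before calling the old trainers,
// then back with asML -- the same bridging the updated tests perform.
val mllibPoint: OldLabeledPoint = OldLabeledPoint.fromML(mlPoint)
val roundTripped: NewLabeledPoint = mllibPoint.asML
```
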
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 087e201234b14..8d588ccfd3545 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -18,13 +18,14 @@ package org.apache.spark.ml.classification import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.regression.DecisionTreeRegressionModel import org.apache.spark.ml.tree.LeafNode import org.apache.spark.ml.tree.impl.TreeTests import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -51,10 +52,13 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext override def beforeAll() { super.beforeAll() data = sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100), 2) + .map(_.asML) trainData = sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 120), 2) + .map(_.asML) validationData = sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 80), 2) + .map(_.asML) } test("params") { @@ -106,7 +110,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext test("should support all NumericType labels and not support other types") { val gbt = new GBTClassifier().setMaxDepth(1) MLTestingUtils.checkNumericTypes[GBTClassificationModel, GBTClassifier]( - gbt, isClassification = true, sqlContext) { (expected, actual) => + gbt, spark) { (expected, actual) => TreeTests.checkEqual(expected, actual) } } @@ -130,7 +134,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext */ test("Fitting without numClasses in metadata") { - val df: DataFrame = sqlContext.createDataFrame(TreeTests.featureImportanceData(sc)) + val df: DataFrame = spark.createDataFrame(TreeTests.featureImportanceData(sc)) val gbt = new GBTClassifier().setMaxDepth(1).setMaxIter(1) gbt.fit(df) } @@ -138,7 +142,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext test("extractLabeledPoints with bad data") { def getTestData(labels: Seq[Double]): DataFrame = { val data = labels.map { label: Double => LabeledPoint(label, Vectors.dense(0.0)) } - sqlContext.createDataFrame(data) + spark.createDataFrame(data) } val gbt = new GBTClassifier().setMaxDepth(1).setMaxIter(1) @@ -229,7 +233,7 @@ private object GBTClassifierSuite extends SparkFunSuite { val oldBoostingStrategy = gbt.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification) val oldGBT = new OldGBT(oldBoostingStrategy, gbt.getSeed.toInt) - val oldModel = oldGBT.run(data) + val oldModel = oldGBT.run(data.map(OldLabeledPoint.fromML)) val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 2) val newModel = gbt.fit(newData) // Use parent from newTree since this is not checked anyways. 
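
Several of the Scala suites below likewise replace `sqlContext.createDataFrame(...)` with `spark.createDataFrame(...)`. A minimal sketch of that pattern, assuming a locally built session (in the actual suites, `spark` is provided by the shared test trait rather than constructed inline):

```scala
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

// Stand-alone session for illustration only.
val spark = SparkSession.builder()
  .master("local[2]")
  .appName("CreateDataFrameSketch")
  .getOrCreate()

// DataFrames are now created from the SparkSession rather than a SQLContext.
val df = spark.createDataFrame(Seq(
  LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
  LabeledPoint(1.0, Vectors.dense(1.0, 0.0))
))
df.show()

spark.stop()
```
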
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 73e961dbbca27..27c872a82b371 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -17,20 +17,22 @@ package org.apache.spark.ml.classification +import scala.collection.JavaConverters._ import scala.language.existentials import scala.util.Random +import scala.util.control.Breaks._ import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.feature.Instance +import org.apache.spark.ml.classification.LogisticRegressionSuite._ +import org.apache.spark.ml.feature.{Instance, LabeledPoint} +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.classification.LogisticRegressionSuite._ -import org.apache.spark.mllib.linalg.{Vector, Vectors} -import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Dataset, Row} -import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.functions.{col, lit} +import org.apache.spark.sql.types.LongType class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -42,7 +44,7 @@ class LogisticRegressionSuite override def beforeAll(): Unit = { super.beforeAll() - dataset = sqlContext.createDataFrame(generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42)) + dataset = spark.createDataFrame(generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42)) binaryDataset = { val nPoints = 10000 @@ -54,7 +56,7 @@ class LogisticRegressionSuite generateMultinomialLogisticInput(coefficients, xMean, xVariance, addIntercept = true, nPoints, 42) - sqlContext.createDataFrame(sc.parallelize(testData, 4)) + spark.createDataFrame(sc.parallelize(testData, 4)) } } @@ -202,7 +204,7 @@ class LogisticRegressionSuite } test("logistic regression: Predictor, Classifier methods") { - val sqlContext = this.sqlContext + val spark = this.spark val lr = new LogisticRegression val model = lr.fit(dataset) @@ -256,6 +258,10 @@ class LogisticRegressionSuite assert(summarizer4.countInvalid === 2) assert(summarizer4.numClasses === 4) + val summarizer5 = new MultiClassSummarizer + assert(summarizer5.histogram.isEmpty) + assert(summarizer5.numClasses === 0) + // small map merges large one val summarizerA = summarizer1.merge(summarizer2) assert(summarizerA.hashCode() === summarizer2.hashCode()) @@ -807,6 +813,21 @@ class LogisticRegressionSuite summary.precisionByThreshold.collect() === sameSummary.precisionByThreshold.collect()) } + test("evaluate with labels that are not doubles") { + // Evaluate a test set with Label that is a numeric type other than Double + val lr = new LogisticRegression() + .setMaxIter(1) + .setRegParam(1.0) + val model = lr.fit(dataset) + val summary = model.evaluate(dataset).asInstanceOf[BinaryLogisticRegressionSummary] + + val longLabelData = dataset.select(col(model.getLabelCol).cast(LongType), + col(model.getFeaturesCol)) + val longSummary = model.evaluate(longLabelData).asInstanceOf[BinaryLogisticRegressionSummary] + + assert(summary.areaUnderROC ~== 
longSummary.areaUnderROC absTol 1E-10) + } + test("statistics on training data") { // Test that loss is monotonically decreasing. val lr = new LogisticRegression() @@ -864,8 +885,8 @@ class LogisticRegressionSuite } } - (sqlContext.createDataFrame(sc.parallelize(data1, 4)), - sqlContext.createDataFrame(sc.parallelize(data2, 4))) + (spark.createDataFrame(sc.parallelize(data1, 4)), + spark.createDataFrame(sc.parallelize(data2, 4))) } val trainer1a = (new LogisticRegression).setFitIntercept(true) @@ -938,7 +959,7 @@ class LogisticRegressionSuite test("should support all NumericType labels and not support other types") { val lr = new LogisticRegression().setMaxIter(1) MLTestingUtils.checkNumericTypes[LogisticRegressionModel, LogisticRegression]( - lr, isClassification = true, sqlContext) { (expected, actual) => + lr, spark) { (expected, actual) => assert(expected.intercept === actual.intercept) assert(expected.coefficients.toArray === actual.coefficients.toArray) } @@ -963,4 +984,122 @@ object LogisticRegressionSuite { "standardization" -> false, "threshold" -> 0.6 ) + + def generateLogisticInputAsList( + offset: Double, + scale: Double, + nPoints: Int, + seed: Int): java.util.List[LabeledPoint] = { + generateLogisticInput(offset, scale, nPoints, seed).asJava + } + + // Generate input of the form Y = logistic(offset + scale*X) + def generateLogisticInput( + offset: Double, + scale: Double, + nPoints: Int, + seed: Int): Seq[LabeledPoint] = { + val rnd = new Random(seed) + val x1 = Array.fill[Double](nPoints)(rnd.nextGaussian()) + + val y = (0 until nPoints).map { i => + val p = 1.0 / (1.0 + math.exp(-(offset + scale * x1(i)))) + if (rnd.nextDouble() < p) 1.0 else 0.0 + } + + val testData = (0 until nPoints).map(i => LabeledPoint(y(i), Vectors.dense(Array(x1(i))))) + testData + } + + /** + * Generates `k` classes multinomial synthetic logistic input in `n` dimensional space given the + * model weights and mean/variance of the features. The synthetic data will be drawn from + * the probability distribution constructed by weights using the following formula. + * + * P(y = 0 | x) = 1 / norm + * P(y = 1 | x) = exp(x * w_1) / norm + * P(y = 2 | x) = exp(x * w_2) / norm + * ... + * P(y = k-1 | x) = exp(x * w_{k-1}) / norm + * where norm = 1 + exp(x * w_1) + exp(x * w_2) + ... + exp(x * w_{k-1}) + * + * @param weights matrix is flatten into a vector; as a result, the dimension of weights vector + * will be (k - 1) * (n + 1) if `addIntercept == true`, and + * if `addIntercept != true`, the dimension will be (k - 1) * n. + * @param xMean the mean of the generated features. Lots of time, if the features are not properly + * standardized, the algorithm with poor implementation will have difficulty + * to converge. + * @param xVariance the variance of the generated features. + * @param addIntercept whether to add intercept. + * @param nPoints the number of instance of generated data. + * @param seed the seed for random generator. For consistent testing result, it will be fixed. 
+ */ + def generateMultinomialLogisticInput( + weights: Array[Double], + xMean: Array[Double], + xVariance: Array[Double], + addIntercept: Boolean, + nPoints: Int, + seed: Int): Seq[LabeledPoint] = { + val rnd = new Random(seed) + + val xDim = xMean.length + val xWithInterceptsDim = if (addIntercept) xDim + 1 else xDim + val nClasses = weights.length / xWithInterceptsDim + 1 + + val x = Array.fill[Vector](nPoints)(Vectors.dense(Array.fill[Double](xDim)(rnd.nextGaussian()))) + + x.foreach { vector => + // This doesn't work if `vector` is a sparse vector. + val vectorArray = vector.toArray + var i = 0 + val len = vectorArray.length + while (i < len) { + vectorArray(i) = vectorArray(i) * math.sqrt(xVariance(i)) + xMean(i) + i += 1 + } + } + + val y = (0 until nPoints).map { idx => + val xArray = x(idx).toArray + val margins = Array.ofDim[Double](nClasses) + val probs = Array.ofDim[Double](nClasses) + + for (i <- 0 until nClasses - 1) { + for (j <- 0 until xDim) margins(i + 1) += weights(i * xWithInterceptsDim + j) * xArray(j) + if (addIntercept) margins(i + 1) += weights((i + 1) * xWithInterceptsDim - 1) + } + // Preventing the overflow when we compute the probability + val maxMargin = margins.max + if (maxMargin > 0) for (i <- 0 until nClasses) margins(i) -= maxMargin + + // Computing the probabilities for each class from the margins. + val norm = { + var temp = 0.0 + for (i <- 0 until nClasses) { + probs(i) = math.exp(margins(i)) + temp += probs(i) + } + temp + } + for (i <- 0 until nClasses) probs(i) /= norm + + // Compute the cumulative probability so we can generate a random number and assign a label. + for (i <- 1 until nClasses) probs(i) += probs(i - 1) + val p = rnd.nextDouble() + var y = 0 + breakable { + for (i <- 0 until nClasses) { + if (p < probs(i)) { + y = i + break + } + } + } + y + } + + val testData = (0 until nPoints).map(i => LabeledPoint(y(i), x(i))) + testData + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala index f41db31f1e7e2..e809dd4092afa 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala @@ -18,14 +18,16 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.classification.LogisticRegressionSuite._ +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.ml.util.MLTestingUtils -import org.apache.spark.mllib.classification.LogisticRegressionSuite._ +import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS import org.apache.spark.mllib.evaluation.MulticlassMetrics -import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.linalg.{Vectors => OldVectors} +import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{Dataset, Row} class MultilayerPerceptronClassifierSuite @@ -36,7 +38,7 @@ class MultilayerPerceptronClassifierSuite override def beforeAll(): Unit = { super.beforeAll() - dataset = sqlContext.createDataFrame(Seq( + dataset = spark.createDataFrame(Seq( 
(Vectors.dense(0.0, 0.0), 0.0), (Vectors.dense(0.0, 1.0), 1.0), (Vectors.dense(1.0, 0.0), 1.0), @@ -68,6 +70,7 @@ class MultilayerPerceptronClassifierSuite .setBlockSize(1) .setSeed(123L) .setMaxIter(100) + .setSolver("l-bfgs") val model = trainer.fit(dataset) val result = model.transform(dataset) val predictionAndLabels = result.select("prediction", "label").collect() @@ -77,7 +80,7 @@ class MultilayerPerceptronClassifierSuite } test("Test setWeights by training restart") { - val dataFrame = sqlContext.createDataFrame(Seq( + val dataFrame = spark.createDataFrame(Seq( (Vectors.dense(0.0, 0.0), 0.0), (Vectors.dense(0.0, 1.0), 1.0), (Vectors.dense(1.0, 0.0), 1.0), @@ -91,9 +94,9 @@ class MultilayerPerceptronClassifierSuite .setMaxIter(1) .setTol(1e-6) val initialWeights = trainer.fit(dataFrame).weights - trainer.setWeights(initialWeights.copy) + trainer.setInitialWeights(initialWeights.copy) val weights1 = trainer.fit(dataFrame).weights - trainer.setWeights(initialWeights.copy) + trainer.setInitialWeights(initialWeights.copy) val weights2 = trainer.fit(dataFrame).weights assert(weights1 ~== weights2 absTol 10e-5, "Training should produce the same weights given equal initial weights and number of steps") @@ -113,7 +116,7 @@ class MultilayerPerceptronClassifierSuite // the input seed is somewhat magic, to make this test pass val rdd = sc.parallelize(generateMultinomialLogisticInput( coefficients, xMean, xVariance, true, nPoints, 1), 2) - val dataFrame = sqlContext.createDataFrame(rdd).toDF("label", "features") + val dataFrame = spark.createDataFrame(rdd).toDF("label", "features") val numClasses = 3 val numIterations = 100 val layers = Array[Int](4, 5, 4, numClasses) @@ -134,12 +137,13 @@ class MultilayerPerceptronClassifierSuite .setNumClasses(numClasses) lr.optimizer.setRegParam(0.0) .setNumIterations(numIterations) - val lrModel = lr.run(rdd) - val lrPredictionAndLabels = lrModel.predict(rdd.map(_.features)).zip(rdd.map(_.label)) + val lrModel = lr.run(rdd.map(OldLabeledPoint.fromML)) + val lrPredictionAndLabels = + lrModel.predict(rdd.map(p => OldVectors.fromML(p.features))).zip(rdd.map(_.label)) // MLP's predictions should not differ a lot from LR's. 
val lrMetrics = new MulticlassMetrics(lrPredictionAndLabels) val mlpMetrics = new MulticlassMetrics(mlpPredictionAndLabels) - assert(mlpMetrics.confusionMatrix ~== lrMetrics.confusionMatrix absTol 100) + assert(mlpMetrics.confusionMatrix.asML ~== lrMetrics.confusionMatrix.asML absTol 100) } test("read/write: MultilayerPerceptronClassifier") { @@ -169,7 +173,7 @@ class MultilayerPerceptronClassifierSuite val mpc = new MultilayerPerceptronClassifier().setLayers(layers).setMaxIter(1) MLTestingUtils.checkNumericTypes[ MultilayerPerceptronClassificationModel, MultilayerPerceptronClassifier]( - mpc, isClassification = true, sqlContext) { (expected, actual) => + mpc, spark) { (expected, actual) => assert(expected.layers === actual.layers) assert(expected.weights === actual.weights) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index 80a46fc70c75d..04c010bd13e1e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -17,16 +17,20 @@ package org.apache.spark.ml.classification -import breeze.linalg.{Vector => BV} +import scala.util.Random + +import breeze.linalg.{DenseVector => BDV, Vector => BV} +import breeze.stats.distributions.{Multinomial => BrzMultinomial} import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.classification.NaiveBayesSuite._ +import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.linalg._ import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.classification.NaiveBayes.{Bernoulli, Multinomial} -import org.apache.spark.mllib.classification.NaiveBayesSuite._ -import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Dataset, Row} class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -43,7 +47,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa Array(0.10, 0.10, 0.70, 0.10) // label 2 ).map(_.map(math.log)) - dataset = sqlContext.createDataFrame(generateNaiveBayesInput(pi, theta, 100, 42)) + dataset = spark.createDataFrame(generateNaiveBayesInput(pi, theta, 100, 42)) } def validatePrediction(predictionAndLabels: DataFrame): Unit = { @@ -65,7 +69,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa } def expectedMultinomialProbabilities(model: NaiveBayesModel, feature: Vector): Vector = { - val logClassProbs: BV[Double] = model.pi.toBreeze + model.theta.multiply(feature).toBreeze + val logClassProbs: BV[Double] = model.pi.asBreeze + model.theta.multiply(feature).asBreeze val classProbs = logClassProbs.toArray.map(math.exp) val classProbsSum = classProbs.sum Vectors.dense(classProbs.map(_ / classProbsSum)) @@ -74,8 +78,8 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa def expectedBernoulliProbabilities(model: NaiveBayesModel, feature: Vector): Vector = { val negThetaMatrix = model.theta.map(v => math.log(1.0 - math.exp(v))) val negFeature = Vectors.dense(feature.toArray.map(v => 1.0 - v)) - val piTheta: BV[Double] = model.pi.toBreeze + model.theta.multiply(feature).toBreeze - val logClassProbs: 
BV[Double] = piTheta + negThetaMatrix.multiply(negFeature).toBreeze + val piTheta: BV[Double] = model.pi.asBreeze + model.theta.multiply(feature).asBreeze + val logClassProbs: BV[Double] = piTheta + negThetaMatrix.multiply(negFeature).asBreeze val classProbs = logClassProbs.toArray.map(math.exp) val classProbsSum = classProbs.sum Vectors.dense(classProbs.map(_ / classProbsSum)) @@ -127,7 +131,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa val pi = Vectors.dense(piArray) val theta = new DenseMatrix(3, 4, thetaArray.flatten, true) - val testDataset = sqlContext.createDataFrame(generateNaiveBayesInput( + val testDataset = spark.createDataFrame(generateNaiveBayesInput( piArray, thetaArray, nPoints, 42, "multinomial")) val nb = new NaiveBayes().setSmoothing(1.0).setModelType("multinomial") val model = nb.fit(testDataset) @@ -135,7 +139,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa validateModelFit(pi, theta, model) assert(model.hasParent) - val validationDataset = sqlContext.createDataFrame(generateNaiveBayesInput( + val validationDataset = spark.createDataFrame(generateNaiveBayesInput( piArray, thetaArray, nPoints, 17, "multinomial")) val predictionAndLabels = model.transform(validationDataset).select("prediction", "label") @@ -157,7 +161,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa val pi = Vectors.dense(piArray) val theta = new DenseMatrix(3, 12, thetaArray.flatten, true) - val testDataset = sqlContext.createDataFrame(generateNaiveBayesInput( + val testDataset = spark.createDataFrame(generateNaiveBayesInput( piArray, thetaArray, nPoints, 45, "bernoulli")) val nb = new NaiveBayes().setSmoothing(1.0).setModelType("bernoulli") val model = nb.fit(testDataset) @@ -165,7 +169,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa validateModelFit(pi, theta, model) assert(model.hasParent) - val validationDataset = sqlContext.createDataFrame(generateNaiveBayesInput( + val validationDataset = spark.createDataFrame(generateNaiveBayesInput( piArray, thetaArray, nPoints, 20, "bernoulli")) val predictionAndLabels = model.transform(validationDataset).select("prediction", "label") @@ -188,7 +192,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa test("should support all NumericType labels and not support other types") { val nb = new NaiveBayes() MLTestingUtils.checkNumericTypes[NaiveBayesModel, NaiveBayes]( - nb, isClassification = true, sqlContext) { (expected, actual) => + nb, spark) { (expected, actual) => assert(expected.pi === actual.pi) assert(expected.theta === actual.theta) } @@ -206,4 +210,48 @@ object NaiveBayesSuite { "predictionCol" -> "myPrediction", "smoothing" -> 0.1 ) + + private def calcLabel(p: Double, pi: Array[Double]): Int = { + var sum = 0.0 + for (j <- 0 until pi.length) { + sum += pi(j) + if (p < sum) return j + } + -1 + } + + // Generate input of the form Y = (theta * x).argmax() + def generateNaiveBayesInput( + pi: Array[Double], // 1XC + theta: Array[Array[Double]], // CXD + nPoints: Int, + seed: Int, + modelType: String = Multinomial, + sample: Int = 10): Seq[LabeledPoint] = { + val D = theta(0).length + val rnd = new Random(seed) + val _pi = pi.map(math.pow(math.E, _)) + val _theta = theta.map(row => row.map(math.pow(math.E, _))) + + for (i <- 0 until nPoints) yield { + val y = calcLabel(rnd.nextDouble(), _pi) + val xi = modelType match { + case Bernoulli => Array.tabulate[Double] (D) { j => + 
if (rnd.nextDouble () < _theta(y)(j) ) 1 else 0 + } + case Multinomial => + val mult = BrzMultinomial(BDV(_theta(y))) + val emptyMap = (0 until D).map(x => (x, 0.0)).toMap + val counts = emptyMap ++ mult.sample(sample).groupBy(x => x).map { + case (index, reps) => (index, reps.size.toDouble) + } + counts.toArray.sortBy(_._1).map(_._2) + case _ => + // This should never happen. + throw new UnknownError(s"Invalid modelType: $modelType.") + } + + LabeledPoint(y, Vectors.dense(xi)) + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index 51871a9babe45..361dd74cb082e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -19,14 +19,16 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.NominalAttribute +import org.apache.spark.ml.classification.LogisticRegressionSuite._ +import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.feature.StringIndexer +import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.param.{ParamMap, ParamsSuite} import org.apache.spark.ml.util.{DefaultReadWriteTest, MetadataUtils, MLTestingUtils} -import org.apache.spark.mllib.classification.LogisticRegressionSuite._ import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS import org.apache.spark.mllib.evaluation.MulticlassMetrics -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.linalg.{Vectors => OldVectors} +import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.rdd.RDD @@ -53,7 +55,7 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) rdd = sc.parallelize(generateMultinomialLogisticInput( coefficients, xMean, xVariance, true, nPoints, 42), 2) - dataset = sqlContext.createDataFrame(rdd) + dataset = spark.createDataFrame(rdd) } test("params") { @@ -88,8 +90,8 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau val lr = new LogisticRegressionWithLBFGS().setIntercept(true).setNumClasses(numClasses) lr.optimizer.setRegParam(0.1).setNumIterations(100) - val model = lr.run(rdd) - val results = model.predict(rdd.map(_.features)).zip(rdd.map(_.label)) + val model = lr.run(rdd.map(OldLabeledPoint.fromML)) + val results = model.predict(rdd.map(p => OldVectors.fromML(p.features))).zip(rdd.map(_.label)) // determine the #confusion matrix in each class. // bound how much error we allow compared to multinomial logistic regression. 
val expectedMetrics = new MulticlassMetrics(results) @@ -228,7 +230,7 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau test("should support all NumericType labels and not support other types") { val ovr = new OneVsRest().setClassifier(new LogisticRegression().setMaxIter(1)) MLTestingUtils.checkNumericTypes[OneVsRestModel, OneVsRest]( - ovr, isClassification = true, sqlContext) { (expected, actual) => + ovr, spark) { (expected, actual) => val expectedModels = expected.models.map(m => m.asInstanceOf[LogisticRegressionModel]) val actualModels = actual.models.map(m => m.asInstanceOf[LogisticRegressionModel]) assert(expectedModels.length === actualModels.length) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala index cfa75ecf387cd..b3bd2b3e57b36 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.ml.linalg.{Vector, Vectors} final class TestProbabilisticClassificationModel( override val uid: String, diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index 90744353d9d3b..2e99ee157ae95 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -18,12 +18,13 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.tree.LeafNode import org.apache.spark.ml.tree.impl.TreeTests import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.linalg.{Vector, Vectors} -import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -46,8 +47,10 @@ class RandomForestClassifierSuite super.beforeAll() orderedLabeledPoints50_1000 = sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)) + .map(_.asML) orderedLabeledPoints5_20 = sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 5, 20)) + .map(_.asML) } ///////////////////////////////////////////////////////////////////////////// @@ -155,7 +158,7 @@ class RandomForestClassifierSuite } test("Fitting without numClasses in metadata") { - val df: DataFrame = sqlContext.createDataFrame(TreeTests.featureImportanceData(sc)) + val df: DataFrame = spark.createDataFrame(TreeTests.featureImportanceData(sc)) val rf = new RandomForestClassifier().setMaxDepth(1).setNumTrees(1) rf.fit(df) } @@ -189,7 +192,7 @@ class RandomForestClassifierSuite test("should support all NumericType labels and not support other types") { val rf 
= new RandomForestClassifier().setMaxDepth(1) MLTestingUtils.checkNumericTypes[RandomForestClassificationModel, RandomForestClassifier]( - rf, isClassification = true, sqlContext) { (expected, actual) => + rf, spark) { (expected, actual) => TreeTests.checkEqual(expected, actual) } } @@ -233,7 +236,8 @@ private object RandomForestClassifierSuite extends SparkFunSuite { val oldStrategy = rf.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, rf.getOldImpurity) val oldModel = OldRandomForest.trainClassifier( - data, oldStrategy, rf.getNumTrees, rf.getFeatureSubsetStrategy, rf.getSeed.toInt) + data.map(OldLabeledPoint.fromML), oldStrategy, rf.getNumTrees, rf.getFeatureSubsetStrategy, + rf.getSeed.toInt) val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses) val newModel = rf.fit(newData) // Use parent from newTree since this is not checked anyways. diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala index 212ea7a0a960f..4f7d4418a8d09 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala @@ -30,7 +30,7 @@ class BisectingKMeansSuite override def beforeAll(): Unit = { super.beforeAll() - dataset = KMeansSuite.generateKMeansData(sqlContext, 50, 3, k) + dataset = KMeansSuite.generateKMeansData(spark, 50, 3, k) } test("default parameters") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala index 9d868174c1723..04366f5250287 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala @@ -32,7 +32,7 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext override def beforeAll(): Unit = { super.beforeAll() - dataset = KMeansSuite.generateKMeansData(sqlContext, 50, 3, k) + dataset = KMeansSuite.generateKMeansData(spark, 50, 3, k) } test("default parameters") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 2ca386e4229ca..88f31a1cd26fb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -18,11 +18,11 @@ package org.apache.spark.ml.clustering import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans} -import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, Dataset, SQLContext} +import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} private[clustering] case class TestRow(features: Vector) @@ -34,7 +34,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR override def beforeAll(): Unit = { super.beforeAll() - dataset = KMeansSuite.generateKMeansData(sqlContext, 50, 3, k) + dataset = KMeansSuite.generateKMeansData(spark, 50, 3, k) } test("default parameters") { @@ -117,6 +117,21 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with 
DefaultR assert(clusterSizes.forall(_ >= 0)) } + test("KMeansModel transform with non-default feature and prediction cols") { + val featuresColName = "kmeans_model_features" + val predictionColName = "kmeans_model_prediction" + + val model = new KMeans().setK(k).setSeed(1).fit(dataset) + model.setFeaturesCol(featuresColName).setPredictionCol(predictionColName) + + val transformed = model.transform(dataset.withColumnRenamed("features", featuresColName)) + Seq(featuresColName, predictionColName).foreach { column => + assert(transformed.columns.contains(column)) + } + assert(model.getFeaturesCol == featuresColName) + assert(model.getPredictionCol == predictionColName) + } + test("read/write") { def checkModelData(model: KMeansModel, model2: KMeansModel): Unit = { assert(model.clusterCenters === model2.clusterCenters) @@ -127,11 +142,11 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR } object KMeansSuite { - def generateKMeansData(sql: SQLContext, rows: Int, dim: Int, k: Int): DataFrame = { - val sc = sql.sparkContext + def generateKMeansData(spark: SparkSession, rows: Int, dim: Int, k: Int): DataFrame = { + val sc = spark.sparkContext val rdd = sc.parallelize(1 to rows).map(i => Vectors.dense(Array.fill(dim)((i % k).toDouble))) .map(v => new TestRow(v)) - sql.createDataFrame(rdd) + spark.createDataFrame(rdd) } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala index 6cb07aecb952c..ddfa87555427b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala @@ -17,30 +17,30 @@ package org.apache.spark.ml.clustering -import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.fs.Path import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext} +import org.apache.spark.sql._ object LDASuite { def generateLDAData( - sql: SQLContext, + spark: SparkSession, rows: Int, k: Int, vocabSize: Int): DataFrame = { val avgWC = 1 // average instances of each word in a doc - val sc = sql.sparkContext + val sc = spark.sparkContext val rng = new java.util.Random() rng.setSeed(1) val rdd = sc.parallelize(1 to rows).map { i => Vectors.dense(Array.fill(vocabSize)(rng.nextInt(2 * avgWC).toDouble)) }.map(v => new TestRow(v)) - sql.createDataFrame(rdd) + spark.createDataFrame(rdd) } /** @@ -68,7 +68,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead override def beforeAll(): Unit = { super.beforeAll() - dataset = LDASuite.generateLDAData(sqlContext, 50, k, vocabSize) + dataset = LDASuite.generateLDAData(spark, 50, k, vocabSize) } test("default parameters") { @@ -140,7 +140,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead new LDA().setTopicConcentration(-1.1) } - val dummyDF = sqlContext.createDataFrame(Seq( + val dummyDF = spark.createDataFrame(Seq( (1, Vectors.dense(1.0, 2.0)))).toDF("id", "features") // validate parameters lda.transformSchema(dummyDF.schema) @@ -274,7 +274,7 @@ class LDASuite extends SparkFunSuite with 
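The KMeansSuite changes above thread a `SparkSession` through `generateKMeansData` instead of a `SQLContext`. A minimal sketch of the same pattern, assuming a local SparkSession and a hypothetical `FeatureRow` case class standing in for the suite's `TestRow`:

```scala
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.sql.{DataFrame, SparkSession}

// Hypothetical row type mirroring TestRow(features) in KMeansSuite.
case class FeatureRow(features: Vector)

object KMeansDataSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]").appName("kmeans-data-sketch").getOrCreate()

    val rows = 50; val dim = 3; val k = 5
    val rdd = spark.sparkContext.parallelize(1 to rows)
      .map(i => Vectors.dense(Array.fill(dim)((i % k).toDouble)))
      .map(v => FeatureRow(v))

    // SparkSession.createDataFrame replaces the old SQLContext.createDataFrame call.
    val df: DataFrame = spark.createDataFrame(rdd)
    df.show(5)
    spark.stop()
  }
}
```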
MLlibTestSparkContext with DefaultRead // There should be 1 checkpoint remaining. assert(model.getCheckpointFiles.length === 1) val checkpointFile = new Path(model.getCheckpointFiles.head) - val fs = checkpointFile.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) + val fs = checkpointFile.getFileSystem(spark.sparkContext.hadoopConfiguration) assert(fs.exists(checkpointFile)) model.deleteCheckpointFiles() assert(model.getCheckpointFiles.isEmpty) diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala index ff34522178084..9ee3df5eb5e33 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala @@ -18,9 +18,9 @@ package org.apache.spark.ml.evaluation import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext class BinaryClassificationEvaluatorSuite @@ -42,21 +42,21 @@ class BinaryClassificationEvaluatorSuite val evaluator = new BinaryClassificationEvaluator() .setMetricName("areaUnderPR") - val vectorDF = sqlContext.createDataFrame(Seq( + val vectorDF = spark.createDataFrame(Seq( (0d, Vectors.dense(12, 2.5)), (1d, Vectors.dense(1, 3)), (0d, Vectors.dense(10, 2)) )).toDF("label", "rawPrediction") assert(evaluator.evaluate(vectorDF) === 1.0) - val doubleDF = sqlContext.createDataFrame(Seq( + val doubleDF = spark.createDataFrame(Seq( (0d, 0d), (1d, 1d), (0d, 0d) )).toDF("label", "rawPrediction") assert(evaluator.evaluate(doubleDF) === 1.0) - val stringDF = sqlContext.createDataFrame(Seq( + val stringDF = spark.createDataFrame(Seq( (0d, "0d"), (1d, "1d"), (0d, "0d") @@ -71,6 +71,6 @@ class BinaryClassificationEvaluatorSuite test("should support all NumericType labels and not support other types") { val evaluator = new BinaryClassificationEvaluator().setRawPredictionCol("prediction") - MLTestingUtils.checkNumericTypes(evaluator, sqlContext) + MLTestingUtils.checkNumericTypes(evaluator, spark) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala index 87e511a368a29..1a3a8a13a2d09 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala @@ -33,11 +33,11 @@ class MulticlassClassificationEvaluatorSuite val evaluator = new MulticlassClassificationEvaluator() .setPredictionCol("myPrediction") .setLabelCol("myLabel") - .setMetricName("recall") + .setMetricName("accuracy") testDefaultReadWrite(evaluator) } test("should support all NumericType labels and not support other types") { - MLTestingUtils.checkNumericTypes(new MulticlassClassificationEvaluator, sqlContext) + MLTestingUtils.checkNumericTypes(new MulticlassClassificationEvaluator, spark) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala index c7b94830696c0..42ff8adf6bd65 
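The MulticlassClassificationEvaluator hunk above swaps the metric name from "recall" to "accuracy". A toy sketch of evaluating that metric, assuming a local SparkSession and made-up predictions:

```scala
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.sql.SparkSession

object AccuracyMetricSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]").appName("accuracy-sketch").getOrCreate()

    // Toy (prediction, label) pairs; column names follow the evaluator defaults.
    val predictions = spark.createDataFrame(Seq(
      (0.0, 0.0), (1.0, 1.0), (2.0, 1.0), (1.0, 1.0)
    )).toDF("prediction", "label")

    val evaluator = new MulticlassClassificationEvaluator()
      .setPredictionCol("prediction")
      .setLabelCol("label")
      .setMetricName("accuracy")

    // Three of four predictions match the label, so this prints 0.75.
    println(s"accuracy = ${evaluator.evaluate(predictions)}")
    spark.stop()
  }
}
```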
100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala @@ -42,9 +42,9 @@ class RegressionEvaluatorSuite * data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1)) * .saveAsTextFile("path") */ - val dataset = sqlContext.createDataFrame( + val dataset = spark.createDataFrame( sc.parallelize(LinearDataGenerator.generateLinearInput( - 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2)) + 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2).map(_.asML)) /** * Using the following R code to load the data, train the model and evaluate metrics. @@ -85,6 +85,6 @@ class RegressionEvaluatorSuite } test("should support all NumericType labels and not support other types") { - MLTestingUtils.checkNumericTypes(new RegressionEvaluator, sqlContext) + MLTestingUtils.checkNumericTypes(new RegressionEvaluator, spark) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala index 714b9db3aa19f..9cb84a6ee9b87 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala @@ -18,9 +18,9 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} @@ -39,7 +39,7 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defau test("Binarize continuous features with default parameter") { val defaultBinarized: Array[Double] = data.map(x => if (x > 0.0) 1.0 else 0.0) - val dataFrame: DataFrame = sqlContext.createDataFrame( + val dataFrame: DataFrame = spark.createDataFrame( data.zip(defaultBinarized)).toDF("feature", "expected") val binarizer: Binarizer = new Binarizer() @@ -55,7 +55,7 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defau test("Binarize continuous features with setter") { val threshold: Double = 0.2 val thresholdBinarized: Array[Double] = data.map(x => if (x > threshold) 1.0 else 0.0) - val dataFrame: DataFrame = sqlContext.createDataFrame( + val dataFrame: DataFrame = spark.createDataFrame( data.zip(thresholdBinarized)).toDF("feature", "expected") val binarizer: Binarizer = new Binarizer() @@ -71,7 +71,7 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defau test("Binarize vector of continuous features with default parameter") { val defaultBinarized: Array[Double] = data.map(x => if (x > 0.0) 1.0 else 0.0) - val dataFrame: DataFrame = sqlContext.createDataFrame(Seq( + val dataFrame: DataFrame = spark.createDataFrame(Seq( (Vectors.dense(data), Vectors.dense(defaultBinarized)) )).toDF("feature", "expected") @@ -88,7 +88,7 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defau test("Binarize vector of continuous features with setter") { val threshold: Double = 0.2 val defaultBinarized: Array[Double] = data.map(x => if (x > threshold) 1.0 else 0.0) - val dataFrame: DataFrame = sqlContext.createDataFrame(Seq( + val dataFrame: DataFrame = spark.createDataFrame(Seq( (Vectors.dense(data), 
Vectors.dense(defaultBinarized)) )).toDF("feature", "expected") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala index 9ea7d431763a1..cd10c78311e1c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala @@ -20,11 +20,11 @@ package org.apache.spark.ml.feature import scala.util.Random import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -39,7 +39,7 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa val validData = Array(-0.5, -0.3, 0.0, 0.2) val expectedBuckets = Array(0.0, 0.0, 1.0, 1.0) val dataFrame: DataFrame = - sqlContext.createDataFrame(validData.zip(expectedBuckets)).toDF("feature", "expected") + spark.createDataFrame(validData.zip(expectedBuckets)).toDF("feature", "expected") val bucketizer: Bucketizer = new Bucketizer() .setInputCol("feature") @@ -55,13 +55,13 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa // Check for exceptions when using a set of invalid feature values. val invalidData1: Array[Double] = Array(-0.9) ++ validData val invalidData2 = Array(0.51) ++ validData - val badDF1 = sqlContext.createDataFrame(invalidData1.zipWithIndex).toDF("feature", "idx") + val badDF1 = spark.createDataFrame(invalidData1.zipWithIndex).toDF("feature", "idx") withClue("Invalid feature value -0.9 was not caught as an invalid feature!") { intercept[SparkException] { bucketizer.transform(badDF1).collect() } } - val badDF2 = sqlContext.createDataFrame(invalidData2.zipWithIndex).toDF("feature", "idx") + val badDF2 = spark.createDataFrame(invalidData2.zipWithIndex).toDF("feature", "idx") withClue("Invalid feature value 0.51 was not caught as an invalid feature!") { intercept[SparkException] { bucketizer.transform(badDF2).collect() @@ -74,7 +74,7 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa val validData = Array(-0.9, -0.5, -0.3, 0.0, 0.2, 0.5, 0.9) val expectedBuckets = Array(0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0) val dataFrame: DataFrame = - sqlContext.createDataFrame(validData.zip(expectedBuckets)).toDF("feature", "expected") + spark.createDataFrame(validData.zip(expectedBuckets)).toDF("feature", "expected") val bucketizer: Bucketizer = new Bucketizer() .setInputCol("feature") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala index 7827db2794cf3..3558290b23ae0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala @@ -18,21 +18,19 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.linalg.{Vector, Vectors} +import 
org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.feature -import org.apache.spark.mllib.linalg.{Vector, Vectors} -import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.Row class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("Test Chi-Square selector") { - val sqlContext = SQLContext.getOrCreate(sc) - import sqlContext.implicits._ - + val spark = this.spark + import spark.implicits._ val data = Seq( LabeledPoint(0.0, Vectors.sparse(3, Array((0, 8.0), (1, 7.0)))), LabeledPoint(1.0, Vectors.sparse(3, Array((1, 9.0), (2, 6.0)))), @@ -78,4 +76,12 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext val newInstance = testDefaultReadWrite(instance) assert(newInstance.selectedFeatures === instance.selectedFeatures) } + + test("should support all NumericType labels and not support other types") { + val css = new ChiSqSelector() + MLTestingUtils.checkNumericTypes[ChiSqSelectorModel, ChiSqSelector]( + css, spark) { (expected, actual) => + assert(expected.selectedFeatures === actual.selectedFeatures) + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala index 7641e3b8cf668..a59203c33d814 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala @@ -17,11 +17,11 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.Row class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext @@ -35,7 +35,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext private def split(s: String): Seq[String] = s.split("\\s+") test("CountVectorizerModel common cases") { - val df = sqlContext.createDataFrame(Seq( + val df = spark.createDataFrame(Seq( (0, split("a b c d"), Vectors.sparse(4, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0)))), (1, split("a b b c d a"), @@ -55,7 +55,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext } test("CountVectorizer common cases") { - val df = sqlContext.createDataFrame(Seq( + val df = spark.createDataFrame(Seq( (0, split("a b c d e"), Vectors.sparse(5, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0), (4, 1.0)))), (1, split("a a a a a a"), Vectors.sparse(5, Seq((0, 6.0)))), @@ -76,7 +76,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext } test("CountVectorizer vocabSize and minDF") { - val df = sqlContext.createDataFrame(Seq( + val df = spark.createDataFrame(Seq( (0, split("a b c d"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))), (1, split("a b c"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))), (2, split("a b"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))), @@ -118,7 +118,7 @@ class CountVectorizerSuite extends 
SparkFunSuite with MLlibTestSparkContext test("CountVectorizer throws exception when vocab is empty") { intercept[IllegalArgumentException] { - val df = sqlContext.createDataFrame(Seq( + val df = spark.createDataFrame(Seq( (0, split("a a b b c c")), (1, split("aa bb cc"))) ).toDF("id", "words") @@ -132,7 +132,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext } test("CountVectorizerModel with minTF count") { - val df = sqlContext.createDataFrame(Seq( + val df = spark.createDataFrame(Seq( (0, split("a a a b b c c c d "), Vectors.sparse(4, Seq((0, 3.0), (2, 3.0)))), (1, split("c c c c c c"), Vectors.sparse(4, Seq((2, 6.0)))), (2, split("a"), Vectors.sparse(4, Seq())), @@ -151,7 +151,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext } test("CountVectorizerModel with minTF freq") { - val df = sqlContext.createDataFrame(Seq( + val df = spark.createDataFrame(Seq( (0, split("a a a b b c c c d "), Vectors.sparse(4, Seq((0, 3.0), (2, 3.0)))), (1, split("c c c c c c"), Vectors.sparse(4, Seq((2, 6.0)))), (2, split("a"), Vectors.sparse(4, Seq((0, 1.0)))), @@ -170,7 +170,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext } test("CountVectorizerModel and CountVectorizer with binary") { - val df = sqlContext.createDataFrame(Seq( + val df = spark.createDataFrame(Seq( (0, split("a a a a b b b b c d"), Vectors.sparse(4, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0)))), (1, split("c c c"), Vectors.sparse(4, Seq((2, 1.0)))), diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala index 36cafa290f083..c02e9610418bf 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala @@ -22,8 +22,8 @@ import scala.beans.BeanInfo import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Row @@ -63,7 +63,7 @@ class DCTSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead } val expectedResult = Vectors.dense(expectedResultBuffer) - val dataset = sqlContext.createDataFrame(Seq( + val dataset = spark.createDataFrame(Seq( DCTTestData(data, expectedResult) )) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ElementwiseProductSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ElementwiseProductSuite.scala index fc1c05de233ea..a4cca27be7815 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ElementwiseProductSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ElementwiseProductSuite.scala @@ -18,8 +18,8 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext class ElementwiseProductSuite diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala index 44bad4aba4dd3..99b800776bb64 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala +++ 
b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala @@ -19,12 +19,12 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.AttributeGroup +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.feature.{HashingTF => MLlibHashingTF} -import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.util.Utils class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -34,7 +34,7 @@ class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext with Defau } test("hashingTF") { - val df = sqlContext.createDataFrame(Seq( + val df = spark.createDataFrame(Seq( (0, "a a b b c d".split(" ").toSeq) )).toDF("id", "words") val n = 100 @@ -54,7 +54,7 @@ class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext with Defau } test("applying binary term freqs") { - val df = sqlContext.createDataFrame(Seq( + val df = spark.createDataFrame(Seq( (0, "a a b c c c".split(" ").toSeq) )).toDF("id", "words") val n = 100 diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala index bc958c15857ba..09dc8b9b932fd 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala @@ -18,12 +18,13 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.feature.{IDFModel => OldIDFModel} -import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} +import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.Row class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -60,7 +61,7 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead }) val expected = scaleDataWithIDF(data, idf) - val df = sqlContext.createDataFrame(data.zip(expected)).toDF("features", "expected") + val df = spark.createDataFrame(data.zip(expected)).toDF("features", "expected") val idfModel = new IDF() .setInputCol("features") @@ -86,7 +87,7 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead }) val expected = scaleDataWithIDF(data, idf) - val df = sqlContext.createDataFrame(data.zip(expected)).toDF("features", "expected") + val df = spark.createDataFrame(data.zip(expected)).toDF("features", "expected") val idfModel = new IDF() .setInputCol("features") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala index 0d4e00668ddb8..3429172a8c903 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala @@ -21,9 +21,9 @@ import 
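The HashingTF test above exercises binary term frequencies. A small usage sketch with the binary flag enabled (the local SparkSession, object name, and column names are assumptions, not part of the patch):

```scala
import org.apache.spark.ml.feature.HashingTF
import org.apache.spark.sql.SparkSession

object BinaryHashingTFSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]").appName("binary-tf-sketch").getOrCreate()

    val df = spark.createDataFrame(Seq(
      (0, "a a b c c c".split(" ").toSeq)
    )).toDF("id", "words")

    val tf = new HashingTF()
      .setInputCol("words")
      .setOutputCol("features")
      .setNumFeatures(100)
      .setBinary(true) // record term presence (0/1) instead of raw counts

    tf.transform(df).select("features").show(truncate = false)
    spark.stop()
  }
}
```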
scala.collection.mutable.ArrayBuilder import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute._ +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.functions.col @@ -59,7 +59,7 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def } test("numeric interaction") { - val data = sqlContext.createDataFrame( + val data = spark.createDataFrame( Seq( (2, Vectors.dense(3.0, 4.0)), (1, Vectors.dense(1.0, 5.0))) @@ -74,7 +74,7 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def col("b").as("b", groupAttr.toMetadata())) val trans = new Interaction().setInputCols(Array("a", "b")).setOutputCol("features") val res = trans.transform(df) - val expected = sqlContext.createDataFrame( + val expected = spark.createDataFrame( Seq( (2, Vectors.dense(3.0, 4.0), Vectors.dense(6.0, 8.0)), (1, Vectors.dense(1.0, 5.0), Vectors.dense(1.0, 5.0))) @@ -90,7 +90,7 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def } test("nominal interaction") { - val data = sqlContext.createDataFrame( + val data = spark.createDataFrame( Seq( (2, Vectors.dense(3.0, 4.0)), (1, Vectors.dense(1.0, 5.0))) @@ -106,7 +106,7 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def col("b").as("b", groupAttr.toMetadata())) val trans = new Interaction().setInputCols(Array("a", "b")).setOutputCol("features") val res = trans.transform(df) - val expected = sqlContext.createDataFrame( + val expected = spark.createDataFrame( Seq( (2, Vectors.dense(3.0, 4.0), Vectors.dense(0, 0, 0, 0, 3, 4)), (1, Vectors.dense(1.0, 5.0), Vectors.dense(0, 0, 1, 5, 0, 0))) @@ -126,7 +126,7 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def } test("default attr names") { - val data = sqlContext.createDataFrame( + val data = spark.createDataFrame( Seq( (2, Vectors.dense(0.0, 4.0), 1.0), (1, Vectors.dense(1.0, 5.0), 10.0)) @@ -142,7 +142,7 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with Def col("c").as("c", NumericAttribute.defaultAttr.toMetadata())) val trans = new Interaction().setInputCols(Array("a", "b", "c")).setOutputCol("features") val res = trans.transform(df) - val expected = sqlContext.createDataFrame( + val expected = spark.createDataFrame( Seq( (2, Vectors.dense(0.0, 4.0), 1.0, Vectors.dense(0, 0, 0, 0, 0, 0, 1, 0, 4)), (1, Vectors.dense(1.0, 5.0), 10.0, Vectors.dense(0, 0, 0, 0, 10, 50, 0, 0, 0))) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala index e083d4713680e..d6400ee02f951 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MaxAbsScalerSuite.scala @@ -17,8 +17,8 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Row @@ -36,7 +36,7 @@ class MaxAbsScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De 
Vectors.sparse(3, Array(0, 2), Array(-1, -1)), Vectors.sparse(3, Array(0), Array(-0.75))) - val df = sqlContext.createDataFrame(data.zip(expected)).toDF("features", "expected") + val df = spark.createDataFrame(data.zip(expected)).toDF("features", "expected") val scaler = new MaxAbsScaler() .setInputCol("features") .setOutputCol("scaled") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala index 87206c777e352..5da84711758c6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala @@ -18,8 +18,8 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Row @@ -38,7 +38,7 @@ class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De Vectors.sparse(3, Array(0, 2), Array(5, 5)), Vectors.sparse(3, Array(0), Array(-2.5))) - val df = sqlContext.createDataFrame(data.zip(expected)).toDF("features", "expected") + val df = spark.createDataFrame(data.zip(expected)).toDF("features", "expected") val scaler = new MinMaxScaler() .setInputCol("features") .setOutputCol("scaled") @@ -57,7 +57,7 @@ class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De test("MinMaxScaler arguments max must be larger than min") { withClue("arguments max must be larger than min") { - val dummyDF = sqlContext.createDataFrame(Seq( + val dummyDF = spark.createDataFrame(Seq( (1, Vectors.dense(1.0, 2.0)))).toDF("id", "feature") intercept[IllegalArgumentException] { val scaler = new MinMaxScaler().setMin(10).setMax(0).setInputCol("feature") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala index a9421e682519d..e5288d9259d3c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala @@ -34,7 +34,7 @@ class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRe val nGram = new NGram() .setInputCol("inputTokens") .setOutputCol("nGrams") - val dataset = sqlContext.createDataFrame(Seq( + val dataset = spark.createDataFrame(Seq( NGramTestData( Array("Test", "for", "ngram", "."), Array("Test for", "for ngram", "ngram .") @@ -47,7 +47,7 @@ class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRe .setInputCol("inputTokens") .setOutputCol("nGrams") .setN(4) - val dataset = sqlContext.createDataFrame(Seq( + val dataset = spark.createDataFrame(Seq( NGramTestData( Array("a", "b", "c", "d", "e"), Array("a b c d", "b c d e") @@ -60,7 +60,7 @@ class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRe .setInputCol("inputTokens") .setOutputCol("nGrams") .setN(4) - val dataset = sqlContext.createDataFrame(Seq( + val dataset = spark.createDataFrame(Seq( NGramTestData( Array(), Array() @@ -73,7 +73,7 @@ class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRe .setInputCol("inputTokens") .setOutputCol("nGrams") .setN(6) - val dataset = sqlContext.createDataFrame(Seq( + val dataset = spark.createDataFrame(Seq( NGramTestData( Array("a", "b", 
"c", "d", "e"), Array() diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala index 468833901995a..b692831714466 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala @@ -18,10 +18,10 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} +import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} @@ -61,7 +61,7 @@ class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa Vectors.sparse(3, Seq()) ) - dataFrame = sqlContext.createDataFrame(sc.parallelize(data, 2).map(NormalizerSuite.FeatureData)) + dataFrame = spark.createDataFrame(sc.parallelize(data, 2).map(NormalizerSuite.FeatureData)) normalizer = new Normalizer() .setInputCol("features") .setOutputCol("normalized_features") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala index 49803aef71587..d41eeec1329c5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala @@ -19,9 +19,9 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.{AttributeGroup, BinaryAttribute, NominalAttribute} +import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions.col @@ -32,7 +32,7 @@ class OneHotEncoderSuite def stringIndexed(): DataFrame = { val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2) - val df = sqlContext.createDataFrame(data).toDF("id", "label") + val df = spark.createDataFrame(data).toDF("id", "label") val indexer = new StringIndexer() .setInputCol("label") .setOutputCol("labelIndex") @@ -49,7 +49,9 @@ class OneHotEncoderSuite val encoder = new OneHotEncoder() .setInputCol("labelIndex") .setOutputCol("labelVec") - .setDropLast(false) + assert(encoder.getDropLast === true) + encoder.setDropLast(false) + assert(encoder.getDropLast === false) val encoded = encoder.transform(transformed) val output = encoded.select("id", "labelVec").rdd.map { r => @@ -81,7 +83,7 @@ class OneHotEncoderSuite test("input column with ML attribute") { val attr = NominalAttribute.defaultAttr.withValues("small", "medium", "large") - val df = sqlContext.createDataFrame(Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply)).toDF("size") + val df = spark.createDataFrame(Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply)).toDF("size") .select(col("size").as("size", attr.toMetadata())) val encoder = new OneHotEncoder() .setInputCol("size") @@ -94,7 +96,7 @@ class OneHotEncoderSuite } test("input column without ML attribute") { - val df = sqlContext.createDataFrame(Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply)).toDF("index") + val 
df = spark.createDataFrame(Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply)).toDF("index") val encoder = new OneHotEncoder() .setInputCol("index") .setOutputCol("encoded") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala index f372ec58269e4..ddb51fb1706a7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala @@ -18,12 +18,13 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.linalg._ import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.linalg._ +import org.apache.spark.ml.util.TestingUtils._ +import org.apache.spark.mllib.linalg.{Vectors => OldVectors} import org.apache.spark.mllib.linalg.distributed.RowMatrix import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.Row class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -45,11 +46,11 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead val dataRDD = sc.parallelize(data, 2) - val mat = new RowMatrix(dataRDD) + val mat = new RowMatrix(dataRDD.map(OldVectors.fromML)) val pc = mat.computePrincipalComponents(3) - val expected = mat.multiply(pc).rows + val expected = mat.multiply(pc).rows.map(_.asML) - val df = sqlContext.createDataFrame(dataRDD.zip(expected)).toDF("features", "expected") + val df = spark.createDataFrame(dataRDD.zip(expected)).toDF("features", "expected") val pca = new PCA() .setInputCol("features") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala index 86dbee1cf4a5a..9ecd321b128f6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala @@ -20,11 +20,11 @@ package org.apache.spark.ml.feature import org.scalatest.exceptions.TestFailedException import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} +import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.Row class PolynomialExpansionSuite @@ -59,7 +59,7 @@ class PolynomialExpansionSuite Vectors.sparse(19, Array.empty, Array.empty)) test("Polynomial expansion with default parameter") { - val df = sqlContext.createDataFrame(data.zip(twoDegreeExpansion)).toDF("features", "expected") + val df = spark.createDataFrame(data.zip(twoDegreeExpansion)).toDF("features", "expected") val polynomialExpansion = new PolynomialExpansion() .setInputCol("features") @@ -76,7 +76,7 @@ class PolynomialExpansionSuite } test("Polynomial expansion with setter") { - val df = sqlContext.createDataFrame(data.zip(threeDegreeExpansion)).toDF("features", "expected") + val df = spark.createDataFrame(data.zip(threeDegreeExpansion)).toDF("features", "expected") val polynomialExpansion = new PolynomialExpansion() 
.setInputCol("features") @@ -94,7 +94,7 @@ class PolynomialExpansionSuite } test("Polynomial expansion with degree 1 is identity on vectors") { - val df = sqlContext.createDataFrame(data.zip(data)).toDF("features", "expected") + val df = spark.createDataFrame(data.zip(data)).toDF("features", "expected") val polynomialExpansion = new PolynomialExpansion() .setInputCol("features") @@ -116,5 +116,29 @@ class PolynomialExpansionSuite .setDegree(3) testDefaultReadWrite(t) } + + test("SPARK-17027. Integer overflow in PolynomialExpansion.getPolySize") { + val data: Array[(Vector, Int, Int)] = Array( + (Vectors.dense(1.0, 2.0, 3.0, 4.0, 5.0), 3002, 4367), + (Vectors.sparse(5, Seq((0, 1.0), (4, 5.0))), 3002, 4367), + (Vectors.dense(1.0, 2.0, 3.0, 4.0, 5.0, 6.0), 8007, 12375) + ) + + val df = spark.createDataFrame(data) + .toDF("features", "expectedPoly10size", "expectedPoly11size") + + val t = new PolynomialExpansion() + .setInputCol("features") + .setOutputCol("polyFeatures") + + for (i <- Seq(10, 11)) { + val transformed = t.setDegree(i) + .transform(df) + .select(s"expectedPoly${i}size", "polyFeatures") + .rdd.map { case Row(expected: Int, v: Vector) => expected == v.size } + + assert(transformed.collect.forall(identity)) + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala index 8895d630a0879..18f1e89ee8148 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala @@ -20,15 +20,15 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.functions.udf class QuantileDiscretizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("Test observed number of buckets and their sizes match expected values") { - val sqlCtx = SQLContext.getOrCreate(sc) - import sqlCtx.implicits._ + val spark = this.spark + import spark.implicits._ val datasetSize = 100000 val numBuckets = 5 @@ -52,9 +52,28 @@ class QuantileDiscretizerSuite "Bucket sizes are not within expected relative error tolerance.") } + test("Test Bucketizer on duplicated splits") { + val spark = this.spark + import spark.implicits._ + + val datasetSize = 12 + val numBuckets = 5 + val df = sc.parallelize(Array(1.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 2.0, 2.0, 2.0, 1.0, 3.0)) + .map(Tuple1.apply).toDF("input") + val discretizer = new QuantileDiscretizer() + .setInputCol("input") + .setOutputCol("result") + .setNumBuckets(numBuckets) + val result = discretizer.fit(df).transform(df) + + val observedNumBuckets = result.select("result").distinct.count + assert(2 <= observedNumBuckets && observedNumBuckets <= numBuckets, + "Observed number of buckets are not within expected range.") + } + test("Test transform method on unseen data") { - val sqlCtx = SQLContext.getOrCreate(sc) - import sqlCtx.implicits._ + val spark = this.spark + import spark.implicits._ val trainDF = sc.parallelize(1.0 to 100.0 by 1.0).map(Tuple1.apply).toDF("input") val testDF = sc.parallelize(-10.0 to 110.0 by 1.0).map(Tuple1.apply).toDF("input") @@ -82,8 +101,8 @@ class QuantileDiscretizerSuite } test("Verify resulting model has parent") { - val sqlCtx = 
SQLContext.getOrCreate(sc) - import sqlCtx.implicits._ + val spark = this.spark + import spark.implicits._ val df = sc.parallelize(1 to 100).map(Tuple1.apply).toDF("input") val discretizer = new QuantileDiscretizer() diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index f8476953d8d2f..c12ab8fe9efe7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -19,11 +19,11 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute._ +import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.Row +import org.apache.spark.sql.types.DoubleType class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { @@ -32,12 +32,12 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("transform numeric data") { val formula = new RFormula().setFormula("id ~ v1 + v2") - val original = sqlContext.createDataFrame( + val original = spark.createDataFrame( Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2") val model = formula.fit(original) val result = model.transform(original) val resultSchema = model.transformSchema(original.schema) - val expected = sqlContext.createDataFrame( + val expected = spark.createDataFrame( Seq( (0, 1.0, 3.0, Vectors.dense(1.0, 3.0), 0.0), (2, 2.0, 5.0, Vectors.dense(2.0, 5.0), 2.0)) @@ -50,7 +50,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("features column already exists") { val formula = new RFormula().setFormula("y ~ x").setFeaturesCol("x") - val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y") + val original = spark.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y") intercept[IllegalArgumentException] { formula.fit(original) } @@ -61,16 +61,16 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("label column already exists") { val formula = new RFormula().setFormula("y ~ x").setLabelCol("y") - val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y") + val original = spark.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y") val model = formula.fit(original) val resultSchema = model.transformSchema(original.schema) assert(resultSchema.length == 3) assert(resultSchema.toString == model.transform(original).schema.toString) } - test("label column already exists but is not double type") { + test("label column already exists but is not numeric type") { val formula = new RFormula().setFormula("y ~ x").setLabelCol("y") - val original = sqlContext.createDataFrame(Seq((0, 1), (2, 2))).toDF("x", "y") + val original = spark.createDataFrame(Seq((0, true), (2, false))).toDF("x", "y") val model = formula.fit(original) intercept[IllegalArgumentException] { model.transformSchema(original.schema) @@ -82,7 +82,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("allow missing label column for test datasets") { val formula = new RFormula().setFormula("y ~ x").setLabelCol("label") - val 
original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "_not_y") + val original = spark.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "_not_y") val model = formula.fit(original) val resultSchema = model.transformSchema(original.schema) assert(resultSchema.length == 3) @@ -91,14 +91,14 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul } test("allow empty label") { - val original = sqlContext.createDataFrame( + val original = spark.createDataFrame( Seq((1, 2.0, 3.0), (4, 5.0, 6.0), (7, 8.0, 9.0)) ).toDF("id", "a", "b") val formula = new RFormula().setFormula("~ a + b") val model = formula.fit(original) val result = model.transform(original) val resultSchema = model.transformSchema(original.schema) - val expected = sqlContext.createDataFrame( + val expected = spark.createDataFrame( Seq( (1, 2.0, 3.0, Vectors.dense(2.0, 3.0)), (4, 5.0, 6.0, Vectors.dense(5.0, 6.0)), @@ -110,13 +110,13 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("encodes string terms") { val formula = new RFormula().setFormula("id ~ a + b") - val original = sqlContext.createDataFrame( + val original = spark.createDataFrame( Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5)) ).toDF("id", "a", "b") val model = formula.fit(original) val result = model.transform(original) val resultSchema = model.transformSchema(original.schema) - val expected = sqlContext.createDataFrame( + val expected = spark.createDataFrame( Seq( (1, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0), (2, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 2.0), @@ -129,13 +129,12 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("index string label") { val formula = new RFormula().setFormula("id ~ a + b") - val original = sqlContext.createDataFrame( + val original = spark.createDataFrame( Seq(("male", "foo", 4), ("female", "bar", 4), ("female", "bar", 5), ("male", "baz", 5)) ).toDF("id", "a", "b") val model = formula.fit(original) val result = model.transform(original) - val resultSchema = model.transformSchema(original.schema) - val expected = sqlContext.createDataFrame( + val expected = spark.createDataFrame( Seq( ("male", "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0), ("female", "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 0.0), @@ -148,7 +147,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("attribute generation") { val formula = new RFormula().setFormula("id ~ a + b") - val original = sqlContext.createDataFrame( + val original = spark.createDataFrame( Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5)) ).toDF("id", "a", "b") val model = formula.fit(original) @@ -165,7 +164,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("vector attribute generation") { val formula = new RFormula().setFormula("id ~ vec") - val original = sqlContext.createDataFrame( + val original = spark.createDataFrame( Seq((1, Vectors.dense(0.0, 1.0)), (2, Vectors.dense(1.0, 2.0))) ).toDF("id", "vec") val model = formula.fit(original) @@ -181,14 +180,14 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("vector attribute generation with unnamed input attrs") { val formula = new RFormula().setFormula("id ~ vec2") - val base = sqlContext.createDataFrame( + val base = spark.createDataFrame( Seq((1, Vectors.dense(0.0, 1.0)), (2, Vectors.dense(1.0, 2.0))) ).toDF("id", "vec") val metadata = new AttributeGroup( "vec2", 
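The RFormula hunks above rely on string terms being indexed and one-hot encoded with a reference level dropped, and on the label being taken from the left-hand side of the formula. A brief usage sketch mirroring the test data (local SparkSession assumed; not part of the patch):

```scala
import org.apache.spark.ml.feature.RFormula
import org.apache.spark.sql.SparkSession

object RFormulaSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]").appName("rformula-sketch").getOrCreate()

    val df = spark.createDataFrame(
      Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5))
    ).toDF("id", "a", "b")

    // "id ~ a + b": id becomes the label, string column a is one-hot encoded
    // with a reference level dropped, and numeric b is passed through.
    val model = new RFormula().setFormula("id ~ a + b").fit(df)
    model.transform(df).select("features", "label").show(truncate = false)
    spark.stop()
  }
}
```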
Array[Attribute]( NumericAttribute.defaultAttr, - NumericAttribute.defaultAttr)).toMetadata + NumericAttribute.defaultAttr)).toMetadata() val original = base.select(base.col("id"), base.col("vec").as("vec2", metadata)) val model = formula.fit(original) val result = model.transform(original) @@ -203,12 +202,12 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("numeric interaction") { val formula = new RFormula().setFormula("a ~ b:c:d") - val original = sqlContext.createDataFrame( + val original = spark.createDataFrame( Seq((1, 2, 4, 2), (2, 3, 4, 1)) ).toDF("a", "b", "c", "d") val model = formula.fit(original) val result = model.transform(original) - val expected = sqlContext.createDataFrame( + val expected = spark.createDataFrame( Seq( (1, 2, 4, 2, Vectors.dense(16.0), 1.0), (2, 3, 4, 1, Vectors.dense(12.0), 2.0)) @@ -223,12 +222,12 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("factor numeric interaction") { val formula = new RFormula().setFormula("id ~ a:b") - val original = sqlContext.createDataFrame( + val original = spark.createDataFrame( Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5), (4, "baz", 5), (4, "baz", 5)) ).toDF("id", "a", "b") val model = formula.fit(original) val result = model.transform(original) - val expected = sqlContext.createDataFrame( + val expected = spark.createDataFrame( Seq( (1, "foo", 4, Vectors.dense(0.0, 0.0, 4.0), 1.0), (2, "bar", 4, Vectors.dense(0.0, 4.0, 0.0), 2.0), @@ -250,12 +249,12 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("factor factor interaction") { val formula = new RFormula().setFormula("id ~ a:b") - val original = sqlContext.createDataFrame( + val original = spark.createDataFrame( Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz")) ).toDF("id", "a", "b") val model = formula.fit(original) val result = model.transform(original) - val expected = sqlContext.createDataFrame( + val expected = spark.createDataFrame( Seq( (1, "foo", "zq", Vectors.dense(0.0, 0.0, 1.0, 0.0), 1.0), (2, "bar", "zq", Vectors.dense(1.0, 0.0, 0.0, 0.0), 2.0), @@ -299,7 +298,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul } } - val dataset = sqlContext.createDataFrame( + val dataset = spark.createDataFrame( Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz")) ).toDF("id", "a", "b") @@ -309,4 +308,23 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul val newModel = testDefaultReadWrite(model) checkModelData(model, newModel) } + + test("should support all NumericType labels") { + val formula = new RFormula().setFormula("label ~ features") + .setLabelCol("x") + .setFeaturesCol("y") + val dfs = MLTestingUtils.genRegressionDFWithNumericLabelCol(spark) + val expected = formula.fit(dfs(DoubleType)) + val actuals = dfs.keys.filter(_ != DoubleType).map(t => formula.fit(dfs(t))) + actuals.foreach { actual => + assert(expected.pipelineModel.stages.length === actual.pipelineModel.stages.length) + expected.pipelineModel.stages.zip(actual.pipelineModel.stages).foreach { + case (exTransformer, acTransformer) => + assert(exTransformer.params === acTransformer.params) + } + assert(expected.resolvedFormula.label === actual.resolvedFormula.label) + assert(expected.resolvedFormula.terms === actual.resolvedFormula.terms) + assert(expected.resolvedFormula.hasIntercept === actual.resolvedFormula.hasIntercept) + } + } } diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala index e213e17d0d9de..9d3c00707039c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala @@ -31,18 +31,19 @@ class SQLTransformerSuite } test("transform numeric data") { - val original = sqlContext.createDataFrame( + val original = spark.createDataFrame( Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2") val sqlTrans = new SQLTransformer().setStatement( "SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__") val result = sqlTrans.transform(original) val resultSchema = sqlTrans.transformSchema(original.schema) - val expected = sqlContext.createDataFrame( + val expected = spark.createDataFrame( Seq((0, 1.0, 3.0, 4.0, 3.0), (2, 2.0, 5.0, 7.0, 10.0))) .toDF("id", "v1", "v2", "v3", "v4") assert(result.schema.toString == resultSchema.toString) assert(resultSchema == expected.schema) assert(result.collect().toSeq == expected.collect().toSeq) + assert(original.sparkSession.catalog.listTables().count() == 0) } test("read/write") { @@ -52,7 +53,7 @@ class SQLTransformerSuite } test("transformSchema") { - val df = sqlContext.range(10) + val df = spark.range(10) val outputSchema = new SQLTransformer() .setStatement("SELECT id + 1 AS id1 FROM __THIS__") .transformSchema(df.schema) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala index 8c5e47a22c969..2243a0f972d32 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala @@ -18,11 +18,11 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext @@ -73,7 +73,7 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext } test("Standardization with default parameter") { - val df0 = sqlContext.createDataFrame(data.zip(resWithStd)).toDF("features", "expected") + val df0 = spark.createDataFrame(data.zip(resWithStd)).toDF("features", "expected") val standardScaler0 = new StandardScaler() .setInputCol("features") @@ -84,9 +84,9 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext } test("Standardization with setter") { - val df1 = sqlContext.createDataFrame(data.zip(resWithBoth)).toDF("features", "expected") - val df2 = sqlContext.createDataFrame(data.zip(resWithMean)).toDF("features", "expected") - val df3 = sqlContext.createDataFrame(data.zip(data)).toDF("features", "expected") + val df1 = spark.createDataFrame(data.zip(resWithBoth)).toDF("features", "expected") + val df2 = spark.createDataFrame(data.zip(resWithMean)).toDF("features", "expected") + val df3 = spark.createDataFrame(data.zip(data)).toDF("features", "expected") val standardScaler1 = new StandardScaler() .setInputCol("features") diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala old mode 100644 new mode 100755 index 3505befdf8e37..125ad02ebcc02 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.{Dataset, Row} object StopWordsRemoverSuite extends SparkFunSuite { def testStopWordsRemover(t: StopWordsRemover, dataset: Dataset[_]): Unit = { @@ -42,8 +42,26 @@ class StopWordsRemoverSuite val remover = new StopWordsRemover() .setInputCol("raw") .setOutputCol("filtered") - val dataSet = sqlContext.createDataFrame(Seq( + val dataSet = spark.createDataFrame(Seq( (Seq("test", "test"), Seq("test", "test")), + (Seq("a", "b", "c", "d"), Seq("b", "c")), + (Seq("a", "the", "an"), Seq()), + (Seq("A", "The", "AN"), Seq()), + (Seq(null), Seq(null)), + (Seq(), Seq()) + )).toDF("raw", "expected") + + testStopWordsRemover(remover, dataSet) + } + + test("StopWordsRemover with particular stop words list") { + val stopWords = Array("test", "a", "an", "the") + val remover = new StopWordsRemover() + .setInputCol("raw") + .setOutputCol("filtered") + .setStopWords(stopWords) + val dataSet = spark.createDataFrame(Seq( + (Seq("test", "test"), Seq()), (Seq("a", "b", "c", "d"), Seq("b", "c", "d")), (Seq("a", "the", "an"), Seq()), (Seq("A", "The", "AN"), Seq()), @@ -59,7 +77,7 @@ class StopWordsRemoverSuite .setInputCol("raw") .setOutputCol("filtered") .setCaseSensitive(true) - val dataSet = sqlContext.createDataFrame(Seq( + val dataSet = spark.createDataFrame(Seq( (Seq("A"), Seq("A")), (Seq("The", "the"), Seq("The")) )).toDF("raw", "expected") @@ -67,13 +85,48 @@ class StopWordsRemoverSuite testStopWordsRemover(remover, dataSet) } - test("StopWordsRemover with additional words") { - val stopWords = StopWords.English ++ Array("python", "scala") + test("default stop words of supported languages are not empty") { + StopWordsRemover.supportedLanguages.foreach { lang => + assert(StopWordsRemover.loadDefaultStopWords(lang).nonEmpty, + s"The default stop words of $lang cannot be empty.") + } + } + + test("StopWordsRemover with language selection") { + val stopWords = StopWordsRemover.loadDefaultStopWords("turkish") val remover = new StopWordsRemover() .setInputCol("raw") .setOutputCol("filtered") .setStopWords(stopWords) - val dataSet = sqlContext.createDataFrame(Seq( + val dataSet = spark.createDataFrame(Seq( + (Seq("acaba", "ama", "biri"), Seq()), + (Seq("hep", "her", "scala"), Seq("scala")) + )).toDF("raw", "expected") + + testStopWordsRemover(remover, dataSet) + } + + test("StopWordsRemover with ignored words") { + val stopWords = StopWordsRemover.loadDefaultStopWords("english").toSet -- Set("a") + val remover = new StopWordsRemover() + .setInputCol("raw") + .setOutputCol("filtered") + .setStopWords(stopWords.toArray) + val dataSet = spark.createDataFrame(Seq( + (Seq("python", "scala", "a"), Seq("python", "scala", "a")), + (Seq("Python", "Scala", "swift"), Seq("Python", "Scala", "swift")) + )).toDF("raw", "expected") + + testStopWordsRemover(remover, dataSet) + } + + test("StopWordsRemover with additional words") { + val stopWords = 
StopWordsRemover.loadDefaultStopWords("english").toSet ++ Set("python", "scala") + val remover = new StopWordsRemover() + .setInputCol("raw") + .setOutputCol("filtered") + .setStopWords(stopWords.toArray) + val dataSet = spark.createDataFrame(Seq( (Seq("python", "scala", "a"), Seq()), (Seq("Python", "Scala", "swift"), Seq("swift")) )).toDF("raw", "expected") @@ -95,7 +148,7 @@ class StopWordsRemoverSuite val remover = new StopWordsRemover() .setInputCol("raw") .setOutputCol(outputCol) - val dataSet = sqlContext.createDataFrame(Seq( + val dataSet = spark.createDataFrame(Seq( (Seq("The", "the", "swift"), Seq("swift")) )).toDF("raw", outputCol) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index d0f3cdc841d11..b478fea5e74ec 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -39,7 +39,7 @@ class StringIndexerSuite test("StringIndexer") { val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2) - val df = sqlContext.createDataFrame(data).toDF("id", "label") + val df = spark.createDataFrame(data).toDF("id", "label") val indexer = new StringIndexer() .setInputCol("label") .setOutputCol("labelIndex") @@ -63,8 +63,8 @@ class StringIndexerSuite test("StringIndexerUnseen") { val data = sc.parallelize(Seq((0, "a"), (1, "b"), (4, "b")), 2) val data2 = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c")), 2) - val df = sqlContext.createDataFrame(data).toDF("id", "label") - val df2 = sqlContext.createDataFrame(data2).toDF("id", "label") + val df = spark.createDataFrame(data).toDF("id", "label") + val df2 = spark.createDataFrame(data2).toDF("id", "label") val indexer = new StringIndexer() .setInputCol("label") .setOutputCol("labelIndex") @@ -93,7 +93,7 @@ class StringIndexerSuite test("StringIndexer with a numeric input column") { val data = sc.parallelize(Seq((0, 100), (1, 200), (2, 300), (3, 100), (4, 100), (5, 300)), 2) - val df = sqlContext.createDataFrame(data).toDF("id", "label") + val df = spark.createDataFrame(data).toDF("id", "label") val indexer = new StringIndexer() .setInputCol("label") .setOutputCol("labelIndex") @@ -114,18 +114,26 @@ class StringIndexerSuite val indexerModel = new StringIndexerModel("indexer", Array("a", "b", "c")) .setInputCol("label") .setOutputCol("labelIndex") - val df = sqlContext.range(0L, 10L).toDF() + val df = spark.range(0L, 10L).toDF() assert(indexerModel.transform(df).collect().toSet === df.collect().toSet) } test("StringIndexerModel can't overwrite output column") { - val df = sqlContext.createDataFrame(Seq((1, 2), (3, 4))).toDF("input", "output") + val df = spark.createDataFrame(Seq((1, 2), (3, 4))).toDF("input", "output") + intercept[IllegalArgumentException] { + new StringIndexer() + .setInputCol("input") + .setOutputCol("output") + .fit(df) + } + val indexer = new StringIndexer() .setInputCol("input") - .setOutputCol("output") + .setOutputCol("indexedInput") .fit(df) + intercept[IllegalArgumentException] { - indexer.transform(df) + indexer.setOutputCol("output").transform(df) } } @@ -153,7 +161,7 @@ class StringIndexerSuite test("IndexToString.transform") { val labels = Array("a", "b", "c") - val df0 = sqlContext.createDataFrame(Seq( + val df0 = spark.createDataFrame(Seq( (0, "a"), (1, "b"), (2, "c"), (0, "a") )).toDF("index", "expected") @@ -180,7 +188,7 @@ class StringIndexerSuite 
test("StringIndexer, IndexToString are inverses") { val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2) - val df = sqlContext.createDataFrame(data).toDF("id", "label") + val df = spark.createDataFrame(data).toDF("id", "label") val indexer = new StringIndexer() .setInputCol("label") .setOutputCol("labelIndex") @@ -213,7 +221,7 @@ class StringIndexerSuite test("StringIndexer metadata") { val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2) - val df = sqlContext.createDataFrame(data).toDF("id", "label") + val df = spark.createDataFrame(data).toDF("id", "label") val indexer = new StringIndexer() .setInputCol("label") .setOutputCol("labelIndex") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala index 123ddfe42c958..f30bdc3ddc0d7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala @@ -57,13 +57,13 @@ class RegexTokenizerSuite .setPattern("\\w+|\\p{Punct}") .setInputCol("rawText") .setOutputCol("tokens") - val dataset0 = sqlContext.createDataFrame(Seq( + val dataset0 = spark.createDataFrame(Seq( TokenizerTestData("Test for tokenization.", Array("test", "for", "tokenization", ".")), TokenizerTestData("Te,st. punct", Array("te", ",", "st", ".", "punct")) )) testRegexTokenizer(tokenizer0, dataset0) - val dataset1 = sqlContext.createDataFrame(Seq( + val dataset1 = spark.createDataFrame(Seq( TokenizerTestData("Test for tokenization.", Array("test", "for", "tokenization")), TokenizerTestData("Te,st. punct", Array("punct")) )) @@ -73,7 +73,7 @@ class RegexTokenizerSuite val tokenizer2 = new RegexTokenizer() .setInputCol("rawText") .setOutputCol("tokens") - val dataset2 = sqlContext.createDataFrame(Seq( + val dataset2 = spark.createDataFrame(Seq( TokenizerTestData("Test for tokenization.", Array("test", "for", "tokenization.")), TokenizerTestData("Te,st. 
punct", Array("te,st.", "punct")) )) @@ -85,7 +85,7 @@ class RegexTokenizerSuite .setInputCol("rawText") .setOutputCol("tokens") .setToLowercase(false) - val dataset = sqlContext.createDataFrame(Seq( + val dataset = spark.createDataFrame(Seq( TokenizerTestData("JAVA SCALA", Array("JAVA", "SCALA")), TokenizerTestData("java scala", Array("java", "scala")) )) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala index dce994fdbd056..561493fbafd6c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala @@ -19,9 +19,9 @@ package org.apache.spark.ml.feature import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute} +import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Row import org.apache.spark.sql.functions.col @@ -57,7 +57,7 @@ class VectorAssemblerSuite } test("VectorAssembler") { - val df = sqlContext.createDataFrame(Seq( + val df = spark.createDataFrame(Seq( (0, 0.0, Vectors.dense(1.0, 2.0), "a", Vectors.sparse(2, Array(1), Array(3.0)), 10L) )).toDF("id", "x", "y", "name", "z", "n") val assembler = new VectorAssembler() @@ -70,14 +70,14 @@ class VectorAssemblerSuite } test("transform should throw an exception in case of unsupported type") { - val df = sqlContext.createDataFrame(Seq(("a", "b", "c"))).toDF("a", "b", "c") + val df = spark.createDataFrame(Seq(("a", "b", "c"))).toDF("a", "b", "c") val assembler = new VectorAssembler() .setInputCols(Array("a", "b", "c")) .setOutputCol("features") - val thrown = intercept[SparkException] { + val thrown = intercept[IllegalArgumentException] { assembler.transform(df) } - assert(thrown.getMessage contains "VectorAssembler does not support the StringType type") + assert(thrown.getMessage contains "Data type StringType is not supported") } test("ML attributes") { @@ -87,7 +87,7 @@ class VectorAssemblerSuite NominalAttribute.defaultAttr.withName("gender").withValues("male", "female"), NumericAttribute.defaultAttr.withName("salary"))) val row = (1.0, 0.5, 1, Vectors.dense(1.0, 1000.0), Vectors.sparse(2, Array(1), Array(2.0))) - val df = sqlContext.createDataFrame(Seq(row)).toDF("browser", "hour", "count", "user", "ad") + val df = spark.createDataFrame(Seq(row)).toDF("browser", "hour", "count", "user", "ad") .select( col("browser").as("browser", browser.toMetadata()), col("hour").as("hour", hour.toMetadata()), diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala index 1ffc62b38e856..707142332349c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala @@ -22,9 +22,9 @@ import scala.beans.{BeanInfo, BeanProperty} import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.internal.Logging import org.apache.spark.ml.attribute._ +import org.apache.spark.ml.linalg.{SparseVector, Vector, Vectors} import 
org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame @@ -85,11 +85,11 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext checkPair(densePoints1Seq, sparsePoints1Seq) checkPair(densePoints2Seq, sparsePoints2Seq) - densePoints1 = sqlContext.createDataFrame(sc.parallelize(densePoints1Seq, 2).map(FeatureData)) - sparsePoints1 = sqlContext.createDataFrame(sc.parallelize(sparsePoints1Seq, 2).map(FeatureData)) - densePoints2 = sqlContext.createDataFrame(sc.parallelize(densePoints2Seq, 2).map(FeatureData)) - sparsePoints2 = sqlContext.createDataFrame(sc.parallelize(sparsePoints2Seq, 2).map(FeatureData)) - badPoints = sqlContext.createDataFrame(sc.parallelize(badPointsSeq, 2).map(FeatureData)) + densePoints1 = spark.createDataFrame(sc.parallelize(densePoints1Seq, 2).map(FeatureData)) + sparsePoints1 = spark.createDataFrame(sc.parallelize(sparsePoints1Seq, 2).map(FeatureData)) + densePoints2 = spark.createDataFrame(sc.parallelize(densePoints2Seq, 2).map(FeatureData)) + sparsePoints2 = spark.createDataFrame(sc.parallelize(sparsePoints2Seq, 2).map(FeatureData)) + badPoints = spark.createDataFrame(sc.parallelize(badPointsSeq, 2).map(FeatureData)) } private def getIndexer: VectorIndexer = @@ -102,7 +102,7 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext } test("Cannot fit an empty DataFrame") { - val rdd = sqlContext.createDataFrame(sc.parallelize(Array.empty[Vector], 2).map(FeatureData)) + val rdd = spark.createDataFrame(sc.parallelize(Array.empty[Vector], 2).map(FeatureData)) val vectorIndexer = getIndexer intercept[IllegalArgumentException] { vectorIndexer.fit(rdd) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala index 6bb4678dc5f97..1746ce53107c4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala @@ -19,9 +19,9 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute} +import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.types.{StructField, StructType} @@ -79,7 +79,7 @@ class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext with De val resultAttrGroup = new AttributeGroup("expected", resultAttrs.asInstanceOf[Array[Attribute]]) val rdd = sc.parallelize(data.zip(expected)).map { case (a, b) => Row(a, b) } - val df = sqlContext.createDataFrame(rdd, + val df = spark.createDataFrame(rdd, StructType(Array(attrGroup.toStructField(), resultAttrGroup.toStructField()))) val vectorSlicer = new VectorSlicer().setInputCol("features").setOutputCol("result") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index 80c177b8d3188..c8f1311538956 100644 --- 
a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -18,12 +18,12 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.feature.{Word2VecModel => OldWord2VecModel} -import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.Row class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -36,8 +36,8 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("Word2Vec") { - val sqlContext = this.sqlContext - import sqlContext.implicits._ + val spark = this.spark + import spark.implicits._ val sentence = "a b " * 100 + "a c " * 10 val numOfWords = sentence.split(" ").size @@ -78,8 +78,8 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("getVectors") { - val sqlContext = this.sqlContext - import sqlContext.implicits._ + val spark = this.spark + import spark.implicits._ val sentence = "a b " * 100 + "a c " * 10 val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) @@ -119,8 +119,8 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("findSynonyms") { - val sqlContext = this.sqlContext - import sqlContext.implicits._ + val spark = this.spark + import spark.implicits._ val sentence = "a b " * 100 + "a c " * 10 val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) @@ -146,8 +146,8 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul test("window size") { - val sqlContext = this.sqlContext - import sqlContext.implicits._ + val spark = this.spark + import spark.implicits._ val sentence = "a q s t q s t b b b s t m s t m q " * 100 + "a c " * 10 val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) @@ -191,6 +191,7 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul .setSeed(42L) .setStepSize(0.01) .setVectorSize(100) + .setMaxSentenceLength(500) testDefaultReadWrite(t) } @@ -206,5 +207,26 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul val newInstance = testDefaultReadWrite(instance) assert(newInstance.getVectors.collect() === instance.getVectors.collect()) } + + test("Word2Vec works with input that is non-nullable (NGram)") { + val spark = this.spark + import spark.implicits._ + + val sentence = "a q s t q s t b b b s t m s t m q " + val docDF = sc.parallelize(Seq(sentence, sentence)).map(_.split(" ")).toDF("text") + + val ngram = new NGram().setN(2).setInputCol("text").setOutputCol("ngrams") + val ngramDF = ngram.transform(docDF) + + val model = new Word2Vec() + .setVectorSize(2) + .setInputCol("ngrams") + .setOutputCol("result") + .fit(ngramDF) + + // Just test that this transformation succeeds + model.transform(ngramDF).collect() + } + } diff --git a/mllib/src/test/scala/org/apache/spark/ml/linalg/SQLDataTypesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/linalg/SQLDataTypesSuite.scala new file mode 100644 index 0000000000000..0bd0c32f19d04 --- /dev/null +++ 
b/mllib/src/test/scala/org/apache/spark/ml/linalg/SQLDataTypesSuite.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.linalg + +import org.apache.spark.SparkFunSuite + +class SQLDataTypesSuite extends SparkFunSuite { + test("sqlDataTypes") { + assert(SQLDataTypes.VectorType === new VectorUDT) + assert(SQLDataTypes.MatrixType === new MatrixUDT) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/linalg/VectorUDTSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/linalg/VectorUDTSuite.scala index 6d01d8f2828e1..6ddb12cb76aac 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/linalg/VectorUDTSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/linalg/VectorUDTSuite.scala @@ -18,6 +18,8 @@ package org.apache.spark.ml.linalg import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.sql.catalyst.JavaTypeInference import org.apache.spark.sql.types._ class VectorUDTSuite extends SparkFunSuite { @@ -36,4 +38,10 @@ class VectorUDTSuite extends SparkFunSuite { assert(udt.simpleString == "vector") } } + + test("JavaTypeInference with VectorUDT") { + val (dataType, _) = JavaTypeInference.inferDataType(classOf[LabeledPoint]) + assert(dataType.asInstanceOf[StructType].fields.map(_.dataType) + === Seq(new VectorUDT, DoubleType)) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquaresSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquaresSuite.scala index 604021220a139..b30d995794d4c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquaresSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/optim/IterativelyReweightedLeastSquaresSuite.scala @@ -19,9 +19,9 @@ package org.apache.spark.ml.optim import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.Instance -import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.ml.linalg.Vectors +import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.rdd.RDD class IterativelyReweightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext { diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala index 0b58a9821f57b..c8de796b2de87 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala @@ -19,9 +19,9 @@ package org.apache.spark.ml.optim import org.apache.spark.SparkFunSuite 
import org.apache.spark.ml.feature.Instance -import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.ml.linalg.Vectors +import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.rdd.RDD class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext { diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index a3366c0e5934c..aa9c53ca30eee 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -17,11 +17,11 @@ package org.apache.spark.ml.param -import java.io.{ByteArrayOutputStream, NotSerializableException, ObjectOutputStream} +import java.io.{ByteArrayOutputStream, ObjectOutputStream} import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.util.MyParams -import org.apache.spark.mllib.linalg.{Vector, Vectors} class ParamsSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/python/MLSerDeSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/python/MLSerDeSuite.scala new file mode 100644 index 0000000000000..5eaef9aabda50 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/python/MLSerDeSuite.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.python + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.linalg.{DenseMatrix, Matrices, SparseMatrix, Vectors} + +class MLSerDeSuite extends SparkFunSuite { + + MLSerDe.initialize() + + test("pickle vector") { + val vectors = Seq( + Vectors.dense(Array.empty[Double]), + Vectors.dense(0.0), + Vectors.dense(0.0, -2.0), + Vectors.sparse(0, Array.empty[Int], Array.empty[Double]), + Vectors.sparse(1, Array.empty[Int], Array.empty[Double]), + Vectors.sparse(2, Array(1), Array(-2.0))) + vectors.foreach { v => + val u = MLSerDe.loads(MLSerDe.dumps(v)) + assert(u.getClass === v.getClass) + assert(u === v) + } + } + + test("pickle double") { + for (x <- List(123.0, -10.0, 0.0, Double.MaxValue, Double.MinValue, Double.NaN)) { + val deser = MLSerDe.loads(MLSerDe.dumps(x.asInstanceOf[AnyRef])).asInstanceOf[Double] + // We use `equals` here for comparison because we cannot use `==` for NaN + assert(x.equals(deser)) + } + } + + test("pickle matrix") { + val values = Array[Double](0, 1.2, 3, 4.56, 7, 8) + val matrix = Matrices.dense(2, 3, values) + val nm = MLSerDe.loads(MLSerDe.dumps(matrix)).asInstanceOf[DenseMatrix] + assert(matrix === nm) + + // Test conversion for empty matrix + val empty = Array[Double]() + val emptyMatrix = Matrices.dense(0, 0, empty) + val ne = MLSerDe.loads(MLSerDe.dumps(emptyMatrix)).asInstanceOf[DenseMatrix] + assert(emptyMatrix == ne) + + val sm = new SparseMatrix(3, 2, Array(0, 1, 3), Array(1, 0, 2), Array(0.9, 1.2, 3.4)) + val nsm = MLSerDe.loads(MLSerDe.dumps(sm)).asInstanceOf[SparseMatrix] + assert(sm.toArray === nsm.toArray) + + val smt = new SparseMatrix( + 3, 3, Array(0, 2, 3, 5), Array(0, 2, 1, 0, 2), Array(0.9, 1.2, 3.4, 5.7, 8.9), + isTransposed = true) + val nsmt = MLSerDe.loads(MLSerDe.dumps(smt)).asInstanceOf[SparseMatrix] + assert(smt.toArray === nsmt.toArray) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index 4c4eb72cd16e7..d0aa2cdfe0fd1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -17,25 +17,31 @@ package org.apache.spark.ml.recommendation +import java.io.File import java.util.Random import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import scala.collection.JavaConverters._ import scala.language.existentials import com.github.fommil.netlib.BLAS.{getInstance => blas} +import org.apache.commons.io.FileUtils +import org.apache.commons.io.filefilter.TrueFileFilter -import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark._ import org.apache.spark.internal.Logging +import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.recommendation.ALS._ import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted} -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Row, SparkSession} +import org.apache.spark.sql.types.{FloatType, IntegerType} import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.Utils class ALSSuite extends SparkFunSuite with 
MLlibTestSparkContext with DefaultReadWriteTest with Logging { @@ -200,7 +206,6 @@ class ALSSuite /** * Generates an explicit feedback dataset for testing ALS. - * * @param numUsers number of users * @param numItems number of items * @param rank rank @@ -241,7 +246,6 @@ class ALSSuite /** * Generates an implicit feedback dataset for testing ALS. - * * @param numUsers number of users * @param numItems number of items * @param rank rank @@ -255,42 +259,11 @@ class ALSSuite rank: Int, noiseStd: Double = 0.0, seed: Long = 11L): (RDD[Rating[Int]], RDD[Rating[Int]]) = { - // The assumption of the implicit feedback model is that unobserved ratings are more likely to - // be negatives. - val positiveFraction = 0.8 - val negativeFraction = 1.0 - positiveFraction - val trainingFraction = 0.6 - val testFraction = 0.3 - val totalFraction = trainingFraction + testFraction - val random = new Random(seed) - val userFactors = genFactors(numUsers, rank, random) - val itemFactors = genFactors(numItems, rank, random) - val training = ArrayBuffer.empty[Rating[Int]] - val test = ArrayBuffer.empty[Rating[Int]] - for ((userId, userFactor) <- userFactors; (itemId, itemFactor) <- itemFactors) { - val rating = blas.sdot(rank, userFactor, 1, itemFactor, 1) - val threshold = if (rating > 0) positiveFraction else negativeFraction - val observed = random.nextDouble() < threshold - if (observed) { - val x = random.nextDouble() - if (x < totalFraction) { - if (x < trainingFraction) { - val noise = noiseStd * random.nextGaussian() - training += Rating(userId, itemId, rating + noise.toFloat) - } else { - test += Rating(userId, itemId, rating) - } - } - } - } - logInfo(s"Generated an implicit feedback dataset with ${training.size} ratings for training " + - s"and ${test.size} for test.") - (sc.parallelize(training, 2), sc.parallelize(test, 2)) + ALSSuite.genImplicitTestData(sc, numUsers, numItems, rank, noiseStd, seed) } /** * Generates random user/item factors, with i.i.d. values drawn from U(a, b). - * * @param size number of users/items * @param rank number of features * @param random random number generator @@ -304,19 +277,11 @@ class ALSSuite random: Random, a: Float = -1.0f, b: Float = 1.0f): Seq[(Int, Array[Float])] = { - require(size > 0 && size < Int.MaxValue / 3) - require(b > a) - val ids = mutable.Set.empty[Int] - while (ids.size < size) { - ids += random.nextInt() - } - val width = b - a - ids.toSeq.sorted.map(id => (id, Array.fill(rank)(a + random.nextFloat() * width))) + ALSSuite.genFactors(size, rank, random, a, b) } /** * Test ALS using the given training/test splits and parameters. 
- * * @param training training dataset * @param test test dataset * @param rank rank of the matrix factorization @@ -337,8 +302,8 @@ class ALSSuite numUserBlocks: Int = 2, numItemBlocks: Int = 3, targetRMSE: Double = 0.05): Unit = { - val sqlContext = this.sqlContext - import sqlContext.implicits._ + val spark = this.spark + import spark.implicits._ val als = new ALS() .setRank(rank) .setRegParam(regParam) @@ -492,8 +457,8 @@ class ALSSuite allEstimatorParamSettings.foreach { case (p, v) => als.set(als.getParam(p), v) } - val sqlContext = this.sqlContext - import sqlContext.implicits._ + val spark = this.spark + import spark.implicits._ val model = als.fit(ratings.toDF()) // Test Estimator save/load @@ -518,6 +483,139 @@ class ALSSuite assert(getFactors(model.userFactors) === getFactors(model2.userFactors)) assert(getFactors(model.itemFactors) === getFactors(model2.itemFactors)) } + + test("input type validation") { + val spark = this.spark + import spark.implicits._ + + // check that ALS can handle all numeric types for rating column + // and user/item columns (when the user/item ids are within Int range) + val als = new ALS().setMaxIter(1).setRank(1) + Seq(("user", IntegerType), ("item", IntegerType), ("rating", FloatType)).foreach { + case (colName, sqlType) => + MLTestingUtils.checkNumericTypesALS(als, spark, colName, sqlType) { + (ex, act) => + ex.userFactors.first().getSeq[Float](1) === act.userFactors.first.getSeq[Float](1) + } { (ex, act, _) => + ex.transform(_: DataFrame).select("prediction").first.getFloat(0) ~== + act.transform(_: DataFrame).select("prediction").first.getFloat(0) absTol 1e-6 + } + } + // check user/item ids falling outside of Int range + val big = Int.MaxValue.toLong + 1 + val small = Int.MinValue.toDouble - 1 + val df = Seq( + (0, 0L, 0d, 1, 1L, 1d, 3.0), + (0, big, small, 0, big, small, 2.0), + (1, 1L, 1d, 0, 0L, 0d, 5.0) + ).toDF("user", "user_big", "user_small", "item", "item_big", "item_small", "rating") + withClue("fit should fail when ids exceed integer range. ") { + assert(intercept[SparkException] { + als.fit(df.select(df("user_big").as("user"), df("item"), df("rating"))) + }.getCause.getMessage.contains("was out of Integer range")) + assert(intercept[SparkException] { + als.fit(df.select(df("user_small").as("user"), df("item"), df("rating"))) + }.getCause.getMessage.contains("was out of Integer range")) + assert(intercept[SparkException] { + als.fit(df.select(df("item_big").as("item"), df("user"), df("rating"))) + }.getCause.getMessage.contains("was out of Integer range")) + assert(intercept[SparkException] { + als.fit(df.select(df("item_small").as("item"), df("user"), df("rating"))) + }.getCause.getMessage.contains("was out of Integer range")) + } + withClue("transform should fail when ids exceed integer range. 
") { + val model = als.fit(df) + assert(intercept[SparkException] { + model.transform(df.select(df("user_big").as("user"), df("item"))).first + }.getMessage.contains("was out of Integer range")) + assert(intercept[SparkException] { + model.transform(df.select(df("user_small").as("user"), df("item"))).first + }.getMessage.contains("was out of Integer range")) + assert(intercept[SparkException] { + model.transform(df.select(df("item_big").as("item"), df("user"))).first + }.getMessage.contains("was out of Integer range")) + assert(intercept[SparkException] { + model.transform(df.select(df("item_small").as("item"), df("user"))).first + }.getMessage.contains("was out of Integer range")) + } + } +} + +class ALSCleanerSuite extends SparkFunSuite { + test("ALS shuffle cleanup standalone") { + val conf = new SparkConf() + val localDir = Utils.createTempDir() + val checkpointDir = Utils.createTempDir() + def getAllFiles: Set[File] = + FileUtils.listFiles(localDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet + try { + conf.set("spark.local.dir", localDir.getAbsolutePath) + val sc = new SparkContext("local[2]", "test", conf) + try { + sc.setCheckpointDir(checkpointDir.getAbsolutePath) + // Test checkpoint and clean parents + val input = sc.parallelize(1 to 1000) + val keyed = input.map(x => (x % 20, 1)) + val shuffled = keyed.reduceByKey(_ + _) + val keysOnly = shuffled.keys + val deps = keysOnly.dependencies + keysOnly.count() + ALS.cleanShuffleDependencies(sc, deps, true) + val resultingFiles = getAllFiles + assert(resultingFiles === Set()) + // Ensure running count again works fine even if we kill the shuffle files. + keysOnly.count() + } finally { + sc.stop() + } + } finally { + Utils.deleteRecursively(localDir) + Utils.deleteRecursively(checkpointDir) + } + } + + test("ALS shuffle cleanup in algorithm") { + val conf = new SparkConf() + val localDir = Utils.createTempDir() + val checkpointDir = Utils.createTempDir() + def getAllFiles: Set[File] = + FileUtils.listFiles(localDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet + try { + conf.set("spark.local.dir", localDir.getAbsolutePath) + val sc = new SparkContext("local[2]", "test", conf) + try { + sc.setCheckpointDir(checkpointDir.getAbsolutePath) + // Generate test data + val (training, _) = ALSSuite.genImplicitTestData(sc, 20, 5, 1, 0.2, 0) + // Implicitly test the cleaning of parents during ALS training + val spark = SparkSession.builder + .master("local[2]") + .appName("ALSCleanerSuite") + .sparkContext(sc) + .getOrCreate() + import spark.implicits._ + val als = new ALS() + .setRank(1) + .setRegParam(1e-5) + .setSeed(0) + .setCheckpointInterval(1) + .setMaxIter(7) + val model = als.fit(training.toDF()) + val resultingFiles = getAllFiles + // We expect the last shuffles files, block ratings, user factors, and item factors to be + // around but no more. 
+ val pattern = "shuffle_(\\d+)_.+\\.data".r + val rddIds = resultingFiles.flatMap { f => + pattern.findAllIn(f.getName()).matchData.map { _.group(1) } } + assert(rddIds.size === 4) + } finally { + sc.stop() + } + } finally { + Utils.deleteRecursively(localDir) + Utils.deleteRecursively(checkpointDir) + } + } } class ALSStorageSuite @@ -536,8 +634,8 @@ class ALSStorageSuite } test("default and non-default storage params set correct RDD StorageLevels") { - val sqlContext = this.sqlContext - import sqlContext.implicits._ + val spark = this.spark + import spark.implicits._ val data = Seq( (0, 0, 1.0), (0, 1, 2.0), @@ -591,7 +689,7 @@ private class IntermediateRDDStorageListener extends SparkListener { } -object ALSSuite { +object ALSSuite extends Logging { /** * Mapping from all Params to valid settings which differ from the defaults. @@ -620,4 +718,82 @@ object ALSSuite { "intermediateStorageLevel" -> "MEMORY_ONLY", "finalStorageLevel" -> "MEMORY_AND_DISK_SER" ) + + // Helper functions to generate test data we share between ALS test suites + + /** + * Generates random user/item factors, with i.i.d. values drawn from U(a, b). + * @param size number of users/items + * @param rank number of features + * @param random random number generator + * @param a min value of the support (default: -1) + * @param b max value of the support (default: 1) + * @return a sequence of (ID, factors) pairs + */ + private def genFactors( + size: Int, + rank: Int, + random: Random, + a: Float = -1.0f, + b: Float = 1.0f): Seq[(Int, Array[Float])] = { + require(size > 0 && size < Int.MaxValue / 3) + require(b > a) + val ids = mutable.Set.empty[Int] + while (ids.size < size) { + ids += random.nextInt() + } + val width = b - a + ids.toSeq.sorted.map(id => (id, Array.fill(rank)(a + random.nextFloat() * width))) + } + + /** + * Generates an implicit feedback dataset for testing ALS. + * + * @param sc SparkContext + * @param numUsers number of users + * @param numItems number of items + * @param rank rank + * @param noiseStd the standard deviation of additive Gaussian noise on training data + * @param seed random seed + * @return (training, test) + */ + def genImplicitTestData( + sc: SparkContext, + numUsers: Int, + numItems: Int, + rank: Int, + noiseStd: Double = 0.0, + seed: Long = 11L): (RDD[Rating[Int]], RDD[Rating[Int]]) = { + // The assumption of the implicit feedback model is that unobserved ratings are more likely to + // be negatives. 
+ val positiveFraction = 0.8 + val negativeFraction = 1.0 - positiveFraction + val trainingFraction = 0.6 + val testFraction = 0.3 + val totalFraction = trainingFraction + testFraction + val random = new Random(seed) + val userFactors = genFactors(numUsers, rank, random) + val itemFactors = genFactors(numItems, rank, random) + val training = ArrayBuffer.empty[Rating[Int]] + val test = ArrayBuffer.empty[Rating[Int]] + for ((userId, userFactor) <- userFactors; (itemId, itemFactor) <- itemFactors) { + val rating = blas.sdot(rank, userFactor, 1, itemFactor, 1) + val threshold = if (rating > 0) positiveFraction else negativeFraction + val observed = random.nextDouble() < threshold + if (observed) { + val x = random.nextDouble() + if (x < totalFraction) { + if (x < trainingFraction) { + val noise = noiseStd * random.nextGaussian() + training += Rating(userId, itemId, rating + noise.toFloat) + } else { + test += Rating(userId, itemId, rating) + } + } + } + } + logInfo(s"Generated an implicit feedback dataset with ${training.size} ratings for training " + + s"and ${test.size} for test.") + (sc.parallelize(training, 2), sc.parallelize(test, 2)) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala index 76891ad562811..1c70b702de063 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -20,12 +20,12 @@ package org.apache.spark.ml.regression import scala.util.Random import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.random.{ExponentialGenerator, WeibullGenerator} import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} class AFTSurvivalRegressionSuite @@ -37,13 +37,13 @@ class AFTSurvivalRegressionSuite override def beforeAll(): Unit = { super.beforeAll() - datasetUnivariate = sqlContext.createDataFrame( + datasetUnivariate = spark.createDataFrame( sc.parallelize(generateAFTInput( 1, Array(5.5), Array(0.8), 1000, 42, 1.0, 2.0, 2.0))) - datasetMultivariate = sqlContext.createDataFrame( + datasetMultivariate = spark.createDataFrame( sc.parallelize(generateAFTInput( 2, Array(0.9, -1.3), Array(0.7, 1.2), 1000, 42, 1.5, 2.5, 2.0))) - datasetUnivariateScaled = sqlContext.createDataFrame( + datasetUnivariateScaled = spark.createDataFrame( sc.parallelize(generateAFTInput( 1, Array(5.5), Array(0.8), 1000, 42, 1.0, 2.0, 2.0)).map { x => AFTPoint(Vectors.dense(x.features(0) * 1.0E3), x.label, x.censor) @@ -356,7 +356,7 @@ class AFTSurvivalRegressionSuite test("should support all NumericType labels") { val aft = new AFTSurvivalRegression().setMaxIter(1) MLTestingUtils.checkNumericTypes[AFTSurvivalRegressionModel, AFTSurvivalRegression]( - aft, isClassification = false, sqlContext) { (expected, actual) => + aft, spark, isClassification = false) { (expected, actual) => assert(expected.intercept === actual.intercept) assert(expected.coefficients === actual.coefficients) } @@ -390,6 +390,18 @@ class AFTSurvivalRegressionSuite testEstimatorAndModelReadWrite(aft, 
datasetMultivariate, AFTSurvivalRegressionSuite.allParamSettings, checkModelData) } + + test("SPARK-15892: Incorrectly merged AFTAggregator with zero total count") { + // This `dataset` will contain an empty partition because it has two rows but + // the parallelism is bigger than that. Because the issue was about `AFTAggregator`s + // being merged incorrectly when it has an empty partition, running the codes below + // should not throw an exception. + val dataset = spark.createDataFrame( + sc.parallelize(generateAFTInput( + 1, Array(5.5), Array(0.8), 2, 42, 1.0, 2.0, 2.0), numSlices = 3)) + val trainer = new AFTSurvivalRegression() + trainer.fit(dataset) + } } object AFTSurvivalRegressionSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index e9fb2677b215b..9afb742406ec8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -18,10 +18,11 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.tree.impl.TreeTests import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite} import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -38,7 +39,7 @@ class DecisionTreeRegressorSuite override def beforeAll() { super.beforeAll() categoricalDataPointsRDD = - sc.parallelize(OldDecisionTreeSuite.generateCategoricalDataPoints()) + sc.parallelize(OldDecisionTreeSuite.generateCategoricalDataPoints().map(_.asML)) } ///////////////////////////////////////////////////////////////////////////// @@ -120,7 +121,7 @@ class DecisionTreeRegressorSuite test("should support all NumericType labels and not support other types") { val dt = new DecisionTreeRegressor().setMaxDepth(1) MLTestingUtils.checkNumericTypes[DecisionTreeRegressionModel, DecisionTreeRegressor]( - dt, isClassification = false, sqlContext) { (expected, actual) => + dt, spark, isClassification = false) { (expected, actual) => TreeTests.checkEqual(expected, actual) } } @@ -170,7 +171,7 @@ private[ml] object DecisionTreeRegressorSuite extends SparkFunSuite { categoricalFeatures: Map[Int, Int]): Unit = { val numFeatures = data.first().features.size val oldStrategy = dt.getOldStrategy(categoricalFeatures) - val oldTree = OldDecisionTree.train(data, oldStrategy) + val oldTree = OldDecisionTree.train(data.map(OldLabeledPoint.fromML), oldStrategy) val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0) val newTree = dt.fit(newData) // Use parent from newTree since this is not checked anyways. 
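
The regression suites above now build their fixtures with `org.apache.spark.ml.feature.LabeledPoint` and `org.apache.spark.ml.linalg.Vectors`, and create DataFrames through the shared `spark` (SparkSession) fixture rather than `sqlContext`. A minimal standalone sketch of that pattern, with made-up toy rows and a hypothetical object name, not taken from the patch:

```scala
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.DecisionTreeRegressor
import org.apache.spark.sql.SparkSession

// Standalone sketch (not part of the patch): build a test-style DataFrame with the
// new ml.feature.LabeledPoint / ml.linalg.Vectors types via SparkSession, the same
// shape the migrated suites use, then fit a small regressor on it.
object NewLabeledPointSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("NewLabeledPointSketch")
      .getOrCreate()

    // spark.createDataFrame replaces the old sqlContext.createDataFrame calls;
    // the columns default to "label" and "features", which DecisionTreeRegressor expects.
    val df = spark.createDataFrame(Seq(
      LabeledPoint(10.0, Vectors.dense(1.0, 2.0, 3.0, 4.0)),
      LabeledPoint(-5.0, Vectors.dense(6.0, 3.0, 2.0, 1.0)),
      LabeledPoint(11.0, Vectors.dense(2.0, 2.0, 3.0, 4.0))
    ))

    val model = new DecisionTreeRegressor().setMaxDepth(2).fit(df)
    model.transform(df).select("label", "prediction").show()

    spark.stop()
  }
}
```

Only the entry point changes relative to the old suites; the default `label`/`features` column names are unchanged, so the estimators pick them up as before.
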
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index 216377959e090..7b5df8f31bb38 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -18,10 +18,11 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.tree.impl.TreeTests import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -48,10 +49,13 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext override def beforeAll() { super.beforeAll() data = sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100), 2) + .map(_.asML) trainData = sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 120), 2) + .map(_.asML) validationData = sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 80), 2) + .map(_.asML) } test("Regression with continuous features") { @@ -72,7 +76,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext } test("GBTRegressor behaves reasonably on toy data") { - val df = sqlContext.createDataFrame(Seq( + val df = spark.createDataFrame(Seq( LabeledPoint(10, Vectors.dense(1, 2, 3, 4)), LabeledPoint(-5, Vectors.dense(6, 3, 2, 1)), LabeledPoint(11, Vectors.dense(2, 2, 3, 4)), @@ -99,7 +103,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext val path = tempDir.toURI.toString sc.setCheckpointDir(path) - val df = sqlContext.createDataFrame(data) + val df = spark.createDataFrame(data) val gbt = new GBTRegressor() .setMaxDepth(2) .setMaxIter(5) @@ -115,7 +119,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext test("should support all NumericType labels and not support other types") { val gbt = new GBTRegressor().setMaxDepth(1) MLTestingUtils.checkNumericTypes[GBTRegressionModel, GBTRegressor]( - gbt, isClassification = false, sqlContext) { (expected, actual) => + gbt, spark, isClassification = false) { (expected, actual) => TreeTests.checkEqual(expected, actual) } } @@ -197,7 +201,7 @@ private object GBTRegressorSuite extends SparkFunSuite { val numFeatures = data.first().features.size val oldBoostingStrategy = gbt.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression) val oldGBT = new OldGBT(oldBoostingStrategy, gbt.getSeed.toInt) - val oldModel = oldGBT.run(data) + val oldModel = oldGBT.run(data.map(OldLabeledPoint.fromML)) val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0) val newModel = gbt.fit(newData) // Use parent from newTree since this is not checked anyways. 
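
Where the migrated tests still delegate to the legacy RDD-based implementations (`OldDecisionTree.train`, `oldGBT.run`), the patch converts between old and new types with `_.asML` and `OldLabeledPoint.fromML`. The same bridge exists at the vector level; a small illustrative sketch under the assumption that the public `asML`/`Vectors.fromML` conversions are available (names and values below are hypothetical):

```scala
import org.apache.spark.ml.linalg.{Vector => NewVector}
import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}

// Standalone sketch: round-trip a vector between the legacy mllib.linalg type and
// the new ml.linalg type. The values are arbitrary; the point is the direction of
// each conversion.
object VectorConversionSketch {
  def main(args: Array[String]): Unit = {
    val oldVec: OldVector = OldVectors.dense(1.0, 2.0, 3.0)
    val newVec: NewVector = oldVec.asML             // old -> new, analogous to the _.asML calls in the suites
    val back: OldVector = OldVectors.fromML(newVec) // new -> old, for the legacy RDD-based APIs
    println(back == oldVec)                         // prints true
  }
}
```
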
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index b854be2f1fbc1..9d102153f222b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -20,17 +20,18 @@ package org.apache.spark.ml.regression import scala.util.Random import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.classification.LogisticRegressionSuite._ import org.apache.spark.ml.feature.Instance +import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.linalg.{BLAS, DenseVector, Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.classification.LogisticRegressionSuite._ -import org.apache.spark.mllib.linalg.{BLAS, DenseVector, Vector, Vectors} +import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.random._ -import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.FloatType class GeneralizedLinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -52,19 +53,19 @@ class GeneralizedLinearRegressionSuite import GeneralizedLinearRegressionSuite._ - datasetGaussianIdentity = sqlContext.createDataFrame( + datasetGaussianIdentity = spark.createDataFrame( sc.parallelize(generateGeneralizedLinearRegressionInput( intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, family = "gaussian", link = "identity"), 2)) - datasetGaussianLog = sqlContext.createDataFrame( + datasetGaussianLog = spark.createDataFrame( sc.parallelize(generateGeneralizedLinearRegressionInput( intercept = 0.25, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5), xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, family = "gaussian", link = "log"), 2)) - datasetGaussianInverse = sqlContext.createDataFrame( + datasetGaussianInverse = spark.createDataFrame( sc.parallelize(generateGeneralizedLinearRegressionInput( intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, @@ -80,40 +81,40 @@ class GeneralizedLinearRegressionSuite generateMultinomialLogisticInput(coefficients, xMean, xVariance, addIntercept = true, nPoints, seed) - sqlContext.createDataFrame(sc.parallelize(testData, 2)) + spark.createDataFrame(sc.parallelize(testData, 2)) } - datasetPoissonLog = sqlContext.createDataFrame( + datasetPoissonLog = spark.createDataFrame( sc.parallelize(generateGeneralizedLinearRegressionInput( intercept = 0.25, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5), xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, family = "poisson", link = "log"), 2)) - datasetPoissonIdentity = sqlContext.createDataFrame( + datasetPoissonIdentity = spark.createDataFrame( sc.parallelize(generateGeneralizedLinearRegressionInput( intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), xVariance = Array(0.7, 1.2), nPoints = 10000, 
seed, noiseLevel = 0.01, family = "poisson", link = "identity"), 2)) - datasetPoissonSqrt = sqlContext.createDataFrame( + datasetPoissonSqrt = spark.createDataFrame( sc.parallelize(generateGeneralizedLinearRegressionInput( intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, family = "poisson", link = "sqrt"), 2)) - datasetGammaInverse = sqlContext.createDataFrame( + datasetGammaInverse = spark.createDataFrame( sc.parallelize(generateGeneralizedLinearRegressionInput( intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, family = "gamma", link = "inverse"), 2)) - datasetGammaIdentity = sqlContext.createDataFrame( + datasetGammaIdentity = spark.createDataFrame( sc.parallelize(generateGeneralizedLinearRegressionInput( intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5), xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, family = "gamma", link = "identity"), 2)) - datasetGammaLog = sqlContext.createDataFrame( + datasetGammaLog = spark.createDataFrame( sc.parallelize(generateGeneralizedLinearRegressionInput( intercept = 0.25, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5), xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, @@ -540,7 +541,7 @@ class GeneralizedLinearRegressionSuite w <- c(1, 2, 3, 4) df <- as.data.frame(cbind(A, b)) */ - val datasetWithWeight = sqlContext.createDataFrame(sc.parallelize(Seq( + val datasetWithWeight = spark.createDataFrame(sc.parallelize(Seq( Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), Instance(19.0, 2.0, Vectors.dense(1.0, 7.0)), Instance(23.0, 3.0, Vectors.dense(2.0, 11.0)), @@ -668,7 +669,7 @@ class GeneralizedLinearRegressionSuite w <- c(1, 2, 3, 4) df <- as.data.frame(cbind(A, b)) */ - val datasetWithWeight = sqlContext.createDataFrame(sc.parallelize(Seq( + val datasetWithWeight = spark.createDataFrame(sc.parallelize(Seq( Instance(1.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), Instance(0.0, 2.0, Vectors.dense(1.0, 2.0)), Instance(1.0, 3.0, Vectors.dense(2.0, 1.0)), @@ -782,7 +783,7 @@ class GeneralizedLinearRegressionSuite w <- c(1, 2, 3, 4) df <- as.data.frame(cbind(A, b)) */ - val datasetWithWeight = sqlContext.createDataFrame(sc.parallelize(Seq( + val datasetWithWeight = spark.createDataFrame(sc.parallelize(Seq( Instance(2.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), Instance(8.0, 2.0, Vectors.dense(1.0, 7.0)), Instance(3.0, 3.0, Vectors.dense(2.0, 11.0)), @@ -899,7 +900,7 @@ class GeneralizedLinearRegressionSuite w <- c(1, 2, 3, 4) df <- as.data.frame(cbind(A, b)) */ - val datasetWithWeight = sqlContext.createDataFrame(sc.parallelize(Seq( + val datasetWithWeight = spark.createDataFrame(sc.parallelize(Seq( Instance(2.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), Instance(8.0, 2.0, Vectors.dense(1.0, 7.0)), Instance(3.0, 3.0, Vectors.dense(2.0, 11.0)), @@ -1021,19 +1022,43 @@ class GeneralizedLinearRegressionSuite val glr = new GeneralizedLinearRegression().setMaxIter(1) MLTestingUtils.checkNumericTypes[ GeneralizedLinearRegressionModel, GeneralizedLinearRegression]( - glr, isClassification = false, sqlContext) { (expected, actual) => + glr, spark, isClassification = false) { (expected, actual) => assert(expected.intercept === actual.intercept) assert(expected.coefficients === actual.coefficients) } } test("glm accepts Dataset[LabeledPoint]") { - val context = sqlContext + val context = spark import 
context.implicits._ new GeneralizedLinearRegression() .setFamily("gaussian") .fit(datasetGaussianIdentity.as[LabeledPoint]) } + + test("evaluate with labels that are not doubles") { + // Evaluate with a dataset whose labels are not doubles, to verify correct casting + val dataset = spark.createDataFrame(sc.parallelize(Seq( + Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), + Instance(19.0, 1.0, Vectors.dense(1.0, 7.0)), + Instance(23.0, 1.0, Vectors.dense(2.0, 11.0)), + Instance(29.0, 1.0, Vectors.dense(3.0, 13.0)) + ), 2)) + + val trainer = new GeneralizedLinearRegression() + .setMaxIter(1) + val model = trainer.fit(dataset) + assert(model.hasSummary) + val summary = model.summary + + val floatLabelDataset = dataset.select(col(model.getLabelCol).cast(FloatType), + col(model.getFeaturesCol)) + val evalSummary = model.evaluate(floatLabelDataset) + // The calculations below involve pattern matching on the label as a double + assert(evalSummary.nullDeviance === summary.nullDeviance) + assert(evalSummary.deviance === summary.deviance) + assert(evalSummary.aic === summary.aic) + } } object GeneralizedLinearRegressionSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala index 3a10ad7ed060a..14d8a4e4e3345 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala @@ -18,9 +18,9 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} @@ -28,13 +28,13 @@ class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { private def generateIsotonicInput(labels: Seq[Double]): DataFrame = { - sqlContext.createDataFrame( + spark.createDataFrame( labels.zipWithIndex.map { case (label, i) => (label, i.toDouble, 1.0) } ).toDF("label", "features", "weight") } private def generatePredictionInput(features: Seq[Double]): DataFrame = { - sqlContext.createDataFrame(features.map(Tuple1.apply)) + spark.createDataFrame(features.map(Tuple1.apply)) .toDF("features") } @@ -145,7 +145,7 @@ class IsotonicRegressionSuite } test("vector features column with feature index") { - val dataset = sqlContext.createDataFrame(Seq( + val dataset = spark.createDataFrame(Seq( (4.0, Vectors.dense(0.0, 1.0)), (3.0, Vectors.dense(0.0, 2.0)), (5.0, Vectors.sparse(2, Array(1), Array(3.0)))) @@ -184,7 +184,7 @@ class IsotonicRegressionSuite test("should support all NumericType labels and not support other types") { val ir = new IsotonicRegression() MLTestingUtils.checkNumericTypes[IsotonicRegressionModel, IsotonicRegression]( - ir, isClassification = false, sqlContext) { (expected, actual) => + ir, spark, isClassification = false) { (expected, actual) => assert(expected.boundaries === actual.boundaries) assert(expected.predictions === actual.predictions) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index eb19d130939e4..df67a3a354610 100644 --- 
a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -21,12 +21,12 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.Instance +import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.linalg.{DenseVector, Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.linalg.{DenseVector, Vector, Vectors} -import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} -import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} class LinearRegressionSuite @@ -42,29 +42,29 @@ class LinearRegressionSuite override def beforeAll(): Unit = { super.beforeAll() - datasetWithDenseFeature = sqlContext.createDataFrame( + datasetWithDenseFeature = spark.createDataFrame( sc.parallelize(LinearDataGenerator.generateLinearInput( intercept = 6.3, weights = Array(4.7, 7.2), xMean = Array(0.9, -1.3), - xVariance = Array(0.7, 1.2), nPoints = 10000, seed, eps = 0.1), 2)) + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, eps = 0.1), 2).map(_.asML)) /* datasetWithDenseFeatureWithoutIntercept is not needed for correctness testing but is useful for illustrating training a model without an intercept */ - datasetWithDenseFeatureWithoutIntercept = sqlContext.createDataFrame( + datasetWithDenseFeatureWithoutIntercept = spark.createDataFrame( sc.parallelize(LinearDataGenerator.generateLinearInput( intercept = 0.0, weights = Array(4.7, 7.2), xMean = Array(0.9, -1.3), - xVariance = Array(0.7, 1.2), nPoints = 10000, seed, eps = 0.1), 2)) + xVariance = Array(0.7, 1.2), nPoints = 10000, seed, eps = 0.1), 2).map(_.asML)) val r = new Random(seed) // When feature size is larger than 4096, normal optimizer is chosen // as the solver of linear regression in the case of "auto" mode. 
val featureSize = 4100 - datasetWithSparseFeature = sqlContext.createDataFrame( + datasetWithSparseFeature = spark.createDataFrame( sc.parallelize(LinearDataGenerator.generateLinearInput( intercept = 0.0, weights = Seq.fill(featureSize)(r.nextDouble()).toArray, xMean = Seq.fill(featureSize)(r.nextDouble()).toArray, xVariance = Seq.fill(featureSize)(r.nextDouble()).toArray, nPoints = 200, - seed, eps = 0.1, sparsity = 0.7), 2)) + seed, eps = 0.1, sparsity = 0.7), 2).map(_.asML)) /* R code: @@ -74,7 +74,7 @@ class LinearRegressionSuite w <- c(1, 2, 3, 4) df <- as.data.frame(cbind(A, b)) */ - datasetWithWeight = sqlContext.createDataFrame( + datasetWithWeight = spark.createDataFrame( sc.parallelize(Seq( Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), Instance(19.0, 2.0, Vectors.dense(1.0, 7.0)), @@ -90,14 +90,14 @@ class LinearRegressionSuite w <- c(1, 2, 3, 4) df.const.label <- as.data.frame(cbind(A, b.const)) */ - datasetWithWeightConstantLabel = sqlContext.createDataFrame( + datasetWithWeightConstantLabel = spark.createDataFrame( sc.parallelize(Seq( Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), Instance(17.0, 2.0, Vectors.dense(1.0, 7.0)), Instance(17.0, 3.0, Vectors.dense(2.0, 11.0)), Instance(17.0, 4.0, Vectors.dense(3.0, 13.0)) ), 2)) - datasetWithWeightZeroLabel = sqlContext.createDataFrame( + datasetWithWeightZeroLabel = spark.createDataFrame( sc.parallelize(Seq( Instance(0.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), Instance(0.0, 2.0, Vectors.dense(1.0, 7.0)), @@ -610,20 +610,31 @@ class LinearRegressionSuite val model1 = new LinearRegression() .setFitIntercept(fitIntercept) .setWeightCol("weight") + .setPredictionCol("myPrediction") .setSolver(solver) .fit(datasetWithWeightConstantLabel) val actual1 = Vectors.dense(model1.intercept, model1.coefficients(0), model1.coefficients(1)) assert(actual1 ~== expected(idx) absTol 1e-4) + // Schema of summary.predictions should be a superset of the input dataset + assert((datasetWithWeightConstantLabel.schema.fieldNames.toSet + model1.getPredictionCol) + .subsetOf(model1.summary.predictions.schema.fieldNames.toSet)) + val model2 = new LinearRegression() .setFitIntercept(fitIntercept) .setWeightCol("weight") + .setPredictionCol("myPrediction") .setSolver(solver) .fit(datasetWithWeightZeroLabel) val actual2 = Vectors.dense(model2.intercept, model2.coefficients(0), model2.coefficients(1)) assert(actual2 ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1e-4) + + // Schema of summary.predictions should be a superset of the input dataset + assert((datasetWithWeightZeroLabel.schema.fieldNames.toSet + model2.getPredictionCol) + .subsetOf(model2.summary.predictions.schema.fieldNames.toSet)) + idx += 1 } } @@ -672,7 +683,7 @@ class LinearRegressionSuite test("linear regression model training summary") { Seq("auto", "l-bfgs", "normal").foreach { solver => - val trainer = new LinearRegression().setSolver(solver) + val trainer = new LinearRegression().setSolver(solver).setPredictionCol("myPrediction") val model = trainer.fit(datasetWithDenseFeature) val trainerNoPredictionCol = trainer.setPredictionCol("") val modelNoPredictionCol = trainerNoPredictionCol.fit(datasetWithDenseFeature) @@ -682,7 +693,7 @@ class LinearRegressionSuite assert(modelNoPredictionCol.hasSummary) // Schema should be a superset of the input dataset - assert((datasetWithDenseFeature.schema.fieldNames.toSet + "prediction").subsetOf( + assert((datasetWithDenseFeature.schema.fieldNames.toSet + model.getPredictionCol).subsetOf( model.summary.predictions.schema.fieldNames.toSet)) // 
Validate that we re-insert a prediction column for evaluation val modelNoPredictionColFieldNames @@ -795,7 +806,7 @@ class LinearRegressionSuite Seq("auto", "l-bfgs", "normal").foreach { solver => val (data, weightedData) = { val activeData = LinearDataGenerator.generateLinearInput( - 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 500, 1, 0.1) + 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 500, 1, 0.1).map(_.asML) val rnd = new Random(8392) val signedData = activeData.map { case p: LabeledPoint => @@ -822,14 +833,14 @@ class LinearRegressionSuite } val noiseData = LinearDataGenerator.generateLinearInput( - 2, Array(1, 3), Array(0.9, -1.3), Array(0.7, 1.2), 500, 1, 0.1) + 2, Array(1, 3), Array(0.9, -1.3), Array(0.7, 1.2), 500, 1, 0.1).map(_.asML) val weightedNoiseData = noiseData.map { case LabeledPoint(label, features) => Instance(label, weight = 0, features) } val data2 = weightedSignedData ++ weightedNoiseData - (sqlContext.createDataFrame(sc.parallelize(data1, 4)), - sqlContext.createDataFrame(sc.parallelize(data2, 4))) + (spark.createDataFrame(sc.parallelize(data1, 4)), + spark.createDataFrame(sc.parallelize(data2, 4))) } val trainer1a = (new LinearRegression).setFitIntercept(true) @@ -1008,12 +1019,14 @@ class LinearRegressionSuite } test("should support all NumericType labels and not support other types") { - val lr = new LinearRegression().setMaxIter(1) - MLTestingUtils.checkNumericTypes[LinearRegressionModel, LinearRegression]( - lr, isClassification = false, sqlContext) { (expected, actual) => + for (solver <- Seq("auto", "l-bfgs", "normal")) { + val lr = new LinearRegression().setMaxIter(1).setSolver(solver) + MLTestingUtils.checkNumericTypes[LinearRegressionModel, LinearRegression]( + lr, spark, isClassification = false) { (expected, actual) => assert(expected.intercept === actual.intercept) assert(expected.coefficients === actual.coefficients) } + } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala index ca400e1914518..c08335f9f84af 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala @@ -18,9 +18,10 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.tree.impl.TreeTests import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -40,7 +41,8 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex override def beforeAll() { super.beforeAll() orderedLabeledPoints50_1000 = - sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)) + sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000) + .map(_.asML)) } ///////////////////////////////////////////////////////////////////////////// @@ -98,7 +100,7 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex test("should support all NumericType labels and not support other 
types") { val rf = new RandomForestRegressor().setMaxDepth(1) MLTestingUtils.checkNumericTypes[RandomForestRegressionModel, RandomForestRegressor]( - rf, isClassification = false, sqlContext) { (expected, actual) => + rf, spark, isClassification = false) { (expected, actual) => TreeTests.checkEqual(expected, actual) } } @@ -139,8 +141,8 @@ private object RandomForestRegressorSuite extends SparkFunSuite { val numFeatures = data.first().features.size val oldStrategy = rf.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, rf.getOldImpurity) - val oldModel = OldRandomForest.trainRegressor( - data, oldStrategy, rf.getNumTrees, rf.getFeatureSubsetStrategy, rf.getSeed.toInt) + val oldModel = OldRandomForest.trainRegressor(data.map(OldLabeledPoint.fromML), oldStrategy, + rf.getNumTrees, rf.getFeatureSubsetStrategy, rf.getSeed.toInt) val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0) val newModel = rf.fit(newData) // Use parent from newTree since this is not checked anyways. diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala index e52fbd74a7b41..2517de59fed63 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala @@ -23,7 +23,7 @@ import java.nio.charset.StandardCharsets import com.google.common.io.Files import org.apache.spark.{SparkException, SparkFunSuite} -import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} +import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{Row, SaveMode} import org.apache.spark.util.Utils @@ -56,7 +56,7 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext { } test("select as sparse vector") { - val df = sqlContext.read.format("libsvm").load(path) + val df = spark.read.format("libsvm").load(path) assert(df.columns(0) == "label") assert(df.columns(1) == "features") val row1 = df.first() @@ -66,7 +66,7 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext { } test("select as dense vector") { - val df = sqlContext.read.format("libsvm").options(Map("vectorType" -> "dense")) + val df = spark.read.format("libsvm").options(Map("vectorType" -> "dense")) .load(path) assert(df.columns(0) == "label") assert(df.columns(1) == "features") @@ -78,7 +78,7 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext { } test("select a vector with specifying the longer dimension") { - val df = sqlContext.read.option("numFeatures", "100").format("libsvm") + val df = spark.read.option("numFeatures", "100").format("libsvm") .load(path) val row1 = df.first() val v = row1.getAs[SparseVector](1) @@ -86,27 +86,28 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext { } test("write libsvm data and read it again") { - val df = sqlContext.read.format("libsvm").load(path) + val df = spark.read.format("libsvm").load(path) val tempDir2 = new File(tempDir, "read_write_test") val writepath = tempDir2.toURI.toString // TODO: Remove requirement to coalesce by supporting multiple reads. 
df.coalesce(1).write.format("libsvm").mode(SaveMode.Overwrite).save(writepath) - val df2 = sqlContext.read.format("libsvm").load(writepath) + val df2 = spark.read.format("libsvm").load(writepath) val row1 = df2.first() val v = row1.getAs[SparseVector](1) assert(v == Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0)))) } test("write libsvm data failed due to invalid schema") { - val df = sqlContext.read.format("text").load(path) + val df = spark.read.format("text").load(path) intercept[SparkException] { df.write.format("libsvm").save(path + "_2") } } test("select features from libsvm relation") { - val df = sqlContext.read.format("libsvm").load(path) + val df = spark.read.format("libsvm").load(path) df.select("features").rdd.map { case Row(d: Vector) => d }.first + df.select("features").collect } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala index fecf372c3d843..5c50a88c8314a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.tree.impl import org.apache.spark.SparkFunSuite import org.apache.spark.internal.Logging -import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.mllib.tree.{GradientBoostedTreesSuite => OldGBTSuite} import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy} import org.apache.spark.mllib.tree.configuration.Algo._ @@ -35,10 +35,10 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext test("runWithValidation stops early and performs better on a validation dataset") { // Set numIterations large enough so that it stops early. 
val numIterations = 20 - val trainRdd = sc.parallelize(OldGBTSuite.trainData, 2) - val validateRdd = sc.parallelize(OldGBTSuite.validateData, 2) - val trainDF = sqlContext.createDataFrame(trainRdd) - val validateDF = sqlContext.createDataFrame(validateRdd) + val trainRdd = sc.parallelize(OldGBTSuite.trainData, 2).map(_.asML) + val validateRdd = sc.parallelize(OldGBTSuite.validateData, 2).map(_.asML) + val trainDF = spark.createDataFrame(trainRdd) + val validateDF = spark.createDataFrame(validateRdd) val algos = Array(Regression, Regression, Classification) val losses = Array(SquaredError, AbsoluteError, LogLoss) diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index 9739e6c05dcbc..dcc2f305df75a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -21,14 +21,14 @@ import scala.collection.mutable import org.apache.spark.SparkFunSuite import org.apache.spark.ml.classification.DecisionTreeClassificationModel +import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.tree._ -import org.apache.spark.mllib.linalg.{Vector, Vectors} -import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.tree.{DecisionTreeSuite => OldDTSuite, EnsembleTestHelper} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, QuantileStrategy, Strategy => OldStrategy} import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, GiniCalculator} import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.util.collection.OpenHashMap /** @@ -43,7 +43,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { ///////////////////////////////////////////////////////////////////////////// test("Binary classification with continuous features: split calculation") { - val arr = OldDTSuite.generateOrderedLabeledPointsWithLabel1() + val arr = OldDTSuite.generateOrderedLabeledPointsWithLabel1().map(_.asML) assert(arr.length === 1000) val rdd = sc.parallelize(arr) val strategy = new OldStrategy(OldAlgo.Classification, Gini, 3, 2, 100) @@ -55,7 +55,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { } test("Binary classification with binary (ordered) categorical features: split calculation") { - val arr = OldDTSuite.generateCategoricalDataPoints() + val arr = OldDTSuite.generateCategoricalDataPoints().map(_.asML) assert(arr.length === 1000) val rdd = sc.parallelize(arr) val strategy = new OldStrategy(OldAlgo.Classification, Gini, maxDepth = 2, numClasses = 2, @@ -72,7 +72,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { test("Binary classification with 3-ary (ordered) categorical features," + " with no samples for one category: split calculation") { - val arr = OldDTSuite.generateCategoricalDataPoints() + val arr = OldDTSuite.generateCategoricalDataPoints().map(_.asML) assert(arr.length === 1000) val rdd = sc.parallelize(arr) val strategy = new OldStrategy(OldAlgo.Classification, Gini, maxDepth = 2, numClasses = 2, @@ -148,7 +148,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { } test("Multiclass classification with unordered categorical features: split calculations") { - val 
arr = OldDTSuite.generateCategoricalDataPoints() + val arr = OldDTSuite.generateCategoricalDataPoints().map(_.asML) assert(arr.length === 1000) val rdd = sc.parallelize(arr) val strategy = new OldStrategy( @@ -189,7 +189,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { } test("Multiclass classification with ordered categorical features: split calculations") { - val arr = OldDTSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures() + val arr = OldDTSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures().map(_.asML) assert(arr.length === 3000) val rdd = sc.parallelize(arr) val strategy = new OldStrategy(OldAlgo.Classification, Gini, maxDepth = 2, numClasses = 100, @@ -334,7 +334,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { } test("Second level node building with vs. without groups") { - val arr = OldDTSuite.generateOrderedLabeledPoints() + val arr = OldDTSuite.generateOrderedLabeledPoints().map(_.asML) assert(arr.length === 1000) val rdd = sc.parallelize(arr) // For tree with 1 group @@ -378,7 +378,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { def binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy: OldStrategy) { val numFeatures = 50 val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures, 1000) - val rdd = sc.parallelize(arr) + val rdd = sc.parallelize(arr).map(_.asML) // Select feature subset for top nodes. Return true if OK. def checkFeatureSubsetStrategy( diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala index e3f09899d7693..d2fa8d0d6335d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala @@ -22,11 +22,11 @@ import scala.collection.JavaConverters._ import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.api.java.JavaRDD import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute} +import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.tree._ -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.{DataFrame, SparkSession} private[ml] object TreeTests extends SparkFunSuite { @@ -42,8 +42,13 @@ private[ml] object TreeTests extends SparkFunSuite { data: RDD[LabeledPoint], categoricalFeatures: Map[Int, Int], numClasses: Int): DataFrame = { - val sqlContext = SQLContext.getOrCreate(data.sparkContext) - import sqlContext.implicits._ + val spark = SparkSession.builder() + .master("local[2]") + .appName("TreeTests") + .sparkContext(data.sparkContext) + .getOrCreate() + import spark.implicits._ + val df = data.toDF() val numFeatures = data.first().features.size val featuresAttributes = Range(0, numFeatures).map { feature => diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 3e734aabc5544..30bd390381e97 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -20,17 +20,17 @@ package org.apache.spark.ml.tuning import 
org.apache.spark.SparkFunSuite import org.apache.spark.ml.{Estimator, Model, Pipeline} import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel} +import org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator} import org.apache.spark.ml.feature.HashingTF +import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.param.{ParamMap, ParamPair} import org.apache.spark.ml.param.shared.HasInputCol import org.apache.spark.ml.regression.LinearRegression import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput -import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} -import org.apache.spark.sql.{DataFrame, Dataset} -import org.apache.spark.sql.types.{StructField, StructType} +import org.apache.spark.sql.Dataset +import org.apache.spark.sql.types.StructType class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -39,7 +39,7 @@ class CrossValidatorSuite override def beforeAll(): Unit = { super.beforeAll() - dataset = sqlContext.createDataFrame( + dataset = spark.createDataFrame( sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2)) } @@ -67,9 +67,9 @@ class CrossValidatorSuite } test("cross validation with linear regression") { - val dataset = sqlContext.createDataFrame( + val dataset = spark.createDataFrame( sc.parallelize(LinearDataGenerator.generateLinearInput( - 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2)) + 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2).map(_.asML)) val trainer = new LinearRegression().setSolver("l-bfgs") val lrParamMaps = new ParamGridBuilder() @@ -136,6 +136,7 @@ class CrossValidatorSuite assert(cv.uid === cv2.uid) assert(cv.getNumFolds === cv2.getNumFolds) + assert(cv.getSeed === cv2.getSeed) assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator]) val evaluator2 = cv2.getEvaluator.asInstanceOf[BinaryClassificationEvaluator] @@ -186,6 +187,7 @@ class CrossValidatorSuite assert(cv.uid === cv2.uid) assert(cv.getNumFolds === cv2.getNumFolds) + assert(cv.getSeed === cv2.getSeed) assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator]) assert(cv.getEvaluator.uid === cv2.getEvaluator.uid) @@ -259,6 +261,7 @@ class CrossValidatorSuite assert(cv.uid === cv2.uid) assert(cv.getNumFolds === cv2.getNumFolds) + assert(cv.getSeed === cv2.getSeed) assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator]) val evaluator2 = cv2.getEvaluator.asInstanceOf[BinaryClassificationEvaluator] diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala index dbee47c8475d7..c1e9c2fc1dc11 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ -20,21 +20,21 @@ package org.apache.spark.ml.tuning import org.apache.spark.SparkFunSuite import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel} +import org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput import 
org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator} +import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared.HasInputCol import org.apache.spark.ml.regression.LinearRegression import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput -import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} -import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.sql.Dataset import org.apache.spark.sql.types.StructType class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("train validation with logistic regression") { - val dataset = sqlContext.createDataFrame( + val dataset = spark.createDataFrame( sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2)) val lr = new LogisticRegression @@ -58,9 +58,9 @@ class TrainValidationSplitSuite } test("train validation with linear regression") { - val dataset = sqlContext.createDataFrame( + val dataset = spark.createDataFrame( sc.parallelize(LinearDataGenerator.generateLinearInput( - 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2)) + 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2).map(_.asML)) val trainer = new LinearRegression().setSolver("l-bfgs") val lrParamMaps = new ParamGridBuilder() @@ -127,6 +127,7 @@ class TrainValidationSplitSuite val tvs2 = testDefaultReadWrite(tvs, testParams = false) assert(tvs.getTrainRatio === tvs2.getTrainRatio) + assert(tvs.getSeed === tvs2.getSeed) } test("read/write: TrainValidationSplitModel") { @@ -149,6 +150,7 @@ class TrainValidationSplitSuite assert(tvs.getTrainRatio === tvs2.getTrainRatio) assert(tvs.validationMetrics === tvs2.validationMetrics) + assert(tvs.getSeed === tvs2.getSeed) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala index d9e6fd5aae677..80b976914cbdf 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala @@ -20,10 +20,11 @@ package org.apache.spark.ml.util import org.apache.spark.SparkFunSuite import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.evaluation.Evaluator +import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.recommendation.{ALS, ALSModel} import org.apache.spark.ml.tree.impl.TreeTests -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -37,18 +38,18 @@ object MLTestingUtils extends SparkFunSuite { def checkNumericTypes[M <: Model[M], T <: Estimator[M]]( estimator: T, - isClassification: Boolean, - sqlContext: SQLContext)(check: (M, M) => Unit): Unit = { + spark: SparkSession, + isClassification: Boolean = true)(check: (M, M) => Unit): Unit = { val dfs = if (isClassification) { - genClassifDFWithNumericLabelCol(sqlContext) + genClassifDFWithNumericLabelCol(spark) } else { - genRegressionDFWithNumericLabelCol(sqlContext) + genRegressionDFWithNumericLabelCol(spark) } val expected = estimator.fit(dfs(DoubleType)) val actuals = dfs.keys.filter(_ != 
DoubleType).map(t => estimator.fit(dfs(t))) actuals.foreach(actual => check(expected, actual)) - val dfWithStringLabels = sqlContext.createDataFrame(Seq( + val dfWithStringLabels = spark.createDataFrame(Seq( ("0", Vectors.dense(0, 2, 3), 0.0) )).toDF("label", "features", "censor") val thrown = intercept[IllegalArgumentException] { @@ -58,13 +59,37 @@ object MLTestingUtils extends SparkFunSuite { "Column label must be of type NumericType but was actually of type StringType")) } - def checkNumericTypes[T <: Evaluator](evaluator: T, sqlContext: SQLContext): Unit = { - val dfs = genEvaluatorDFWithNumericLabelCol(sqlContext, "label", "prediction") + def checkNumericTypesALS( + estimator: ALS, + spark: SparkSession, + column: String, + baseType: NumericType) + (check: (ALSModel, ALSModel) => Unit) + (check2: (ALSModel, ALSModel, DataFrame) => Unit): Unit = { + val dfs = genRatingsDFWithNumericCols(spark, column) + val expected = estimator.fit(dfs(baseType)) + val actuals = dfs.keys.filter(_ != baseType).map(t => (t, estimator.fit(dfs(t)))) + actuals.foreach { case (_, actual) => check(expected, actual) } + actuals.foreach { case (t, actual) => check2(expected, actual, dfs(t)) } + + val baseDF = dfs(baseType) + val others = baseDF.columns.toSeq.diff(Seq(column)).map(col(_)) + val cols = Seq(col(column).cast(StringType)) ++ others + val strDF = baseDF.select(cols: _*) + val thrown = intercept[IllegalArgumentException] { + estimator.fit(strDF) + } + assert(thrown.getMessage.contains( + s"$column must be of type NumericType but was actually of type StringType")) + } + + def checkNumericTypes[T <: Evaluator](evaluator: T, spark: SparkSession): Unit = { + val dfs = genEvaluatorDFWithNumericLabelCol(spark, "label", "prediction") val expected = evaluator.evaluate(dfs(DoubleType)) val actuals = dfs.keys.filter(_ != DoubleType).map(t => evaluator.evaluate(dfs(t))) actuals.foreach(actual => assert(expected === actual)) - val dfWithStringLabels = sqlContext.createDataFrame(Seq( + val dfWithStringLabels = spark.createDataFrame(Seq( ("0", 0d) )).toDF("label", "prediction") val thrown = intercept[IllegalArgumentException] { @@ -75,10 +100,10 @@ object MLTestingUtils extends SparkFunSuite { } def genClassifDFWithNumericLabelCol( - sqlContext: SQLContext, + spark: SparkSession, labelColName: String = "label", featuresColName: String = "features"): Map[NumericType, DataFrame] = { - val df = sqlContext.createDataFrame(Seq( + val df = spark.createDataFrame(Seq( (0, Vectors.dense(0, 2, 3)), (1, Vectors.dense(0, 3, 1)), (0, Vectors.dense(0, 2, 2)), @@ -95,11 +120,11 @@ object MLTestingUtils extends SparkFunSuite { } def genRegressionDFWithNumericLabelCol( - sqlContext: SQLContext, + spark: SparkSession, labelColName: String = "label", featuresColName: String = "features", censorColName: String = "censor"): Map[NumericType, DataFrame] = { - val df = sqlContext.createDataFrame(Seq( + val df = spark.createDataFrame(Seq( (0, Vectors.dense(0)), (1, Vectors.dense(1)), (2, Vectors.dense(2)), @@ -116,11 +141,31 @@ object MLTestingUtils extends SparkFunSuite { }.toMap } + def genRatingsDFWithNumericCols( + spark: SparkSession, + column: String): Map[NumericType, DataFrame] = { + val df = spark.createDataFrame(Seq( + (0, 10, 1.0), + (1, 20, 2.0), + (2, 30, 3.0), + (3, 40, 4.0), + (4, 50, 5.0) + )).toDF("user", "item", "rating") + + val others = df.columns.toSeq.diff(Seq(column)).map(col(_)) + val types: Seq[NumericType] = + Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0)) + types.map { 
t => + val cols = Seq(col(column).cast(t)) ++ others + t -> df.select(cols: _*) + }.toMap + } + def genEvaluatorDFWithNumericLabelCol( - sqlContext: SQLContext, + spark: SparkSession, labelColName: String = "label", predictionColName: String = "prediction"): Map[NumericType, DataFrame] = { - val df = sqlContext.createDataFrame(Seq( + val df = spark.createDataFrame(Seq( (0, 0d), (1, 1d), (2, 2d), diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala index 9e6bc7193c13b..141249a427a4c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala @@ -60,9 +60,9 @@ class StopwatchSuite extends SparkFunSuite with MLlibTestSparkContext { test("DistributedStopwatch on executors") { val sw = new DistributedStopwatch(sc, "sw") val rdd = sc.parallelize(0 until 4, 4) - val acc = sc.accumulator(0L) + val acc = sc.longAccumulator rdd.foreach { i => - acc += checkStopwatch(sw) + acc.add(checkStopwatch(sw)) } assert(!sw.isRunning) val elapsed = sw.elapsed() @@ -88,12 +88,12 @@ class StopwatchSuite extends SparkFunSuite with MLlibTestSparkContext { assert(sw.toString === s"{\n local: ${localElapsed}ms,\n spark: ${sparkElapsed}ms\n}") val rdd = sc.parallelize(0 until 4, 4) - val acc = sc.accumulator(0L) + val acc = sc.longAccumulator rdd.foreach { i => sw("local").start() val duration = checkStopwatch(sw("spark")) sw("local").stop() - acc += duration + acc.add(duration) } val localElapsed2 = sw("local").elapsed() assert(localElapsed2 === localElapsed) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index 28fada7053d65..5cf4377768516 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -411,10 +411,10 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w val testRDD1 = sc.parallelize(testData, 2) val testRDD2 = sc.parallelize( - testData.map(x => LabeledPoint(x.label, Vectors.fromBreeze(x.features.toBreeze * 1.0E3))), 2) + testData.map(x => LabeledPoint(x.label, Vectors.fromBreeze(x.features.asBreeze * 1.0E3))), 2) val testRDD3 = sc.parallelize( - testData.map(x => LabeledPoint(x.label, Vectors.fromBreeze(x.features.toBreeze * 1.0E6))), 2) + testData.map(x => LabeledPoint(x.label, Vectors.fromBreeze(x.features.asBreeze * 1.0E6))), 2) testRDD1.cache() testRDD2.cache() diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index ab54cb06d5aab..0c0aefc52b9bf 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -182,7 +182,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { val piVector = new BDV(model.pi) // model.theta is row-major; treat it as col-major representation of transpose, and transpose: val thetaMatrix = new BDM(model.theta(0).length, model.theta.length, model.theta.flatten).t - val logClassProbs: BV[Double] = piVector + (thetaMatrix * testData.toBreeze) + val logClassProbs: BV[Double] = piVector + (thetaMatrix * 
testData.asBreeze) val classProbs = logClassProbs.toArray.map(math.exp) val classProbsSum = classProbs.sum classProbs.map(_ / classProbsSum) @@ -234,7 +234,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { val thetaMatrix = new BDM(model.theta(0).length, model.theta.length, model.theta.flatten).t val negThetaMatrix = new BDM(model.theta(0).length, model.theta.length, model.theta.flatten.map(v => math.log(1.0 - math.exp(v)))).t - val testBreeze = testData.toBreeze + val testBreeze = testData.asBreeze val negTestBreeze = new BDV(Array.fill(testBreeze.size)(1.0)) - testBreeze val piTheta: BV[Double] = piVector + (thetaMatrix * testBreeze) val logClassProbs: BV[Double] = piTheta + (negThetaMatrix * negTestBreeze) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index ea23196d2c801..eb050158d48fe 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -116,7 +116,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { case (docId, (topicDistribution, (indices, weights))) => assert(indices.length == 2) assert(weights.length == 2) - val bdvTopicDist = topicDistribution.toBreeze + val bdvTopicDist = topicDistribution.asBreeze val top2Indices = argtopk(bdvTopicDist, 2) assert(top2Indices.toArray === indices) assert(bdvTopicDist(top2Indices).toArray === weights) @@ -369,7 +369,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { val actualPredictions = ldaModel.topicDistributions(docs).cache() val topTopics = actualPredictions.map { case (id, topics) => // convert results to expectedPredictions format, which only has highest probability topic - val topicsBz = topics.toBreeze.toDenseVector + val topicsBz = topics.asBreeze.toDenseVector (id, (argmax(topicsBz), max(topicsBz))) }.sortByKey() .values diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala index 65e37c64d404e..fdaa098345d13 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala @@ -67,7 +67,7 @@ class StreamingKMeansSuite extends SparkFunSuite with TestSuiteBase { // estimated center from streaming should exactly match the arithmetic mean of all data points // because the decay factor is set to 1.0 val grandMean = - input.flatten.map(x => x.toBreeze).reduce(_ + _) / (numBatches * numPoints).toDouble + input.flatten.map(x => x.asBreeze).reduce(_ + _) / (numBatches * numPoints).toDouble assert(model.latestModel().clusterCenters(0) ~== Vectors.dense(grandMean.toArray) absTol 1E-5) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala index d55bc8c3ec09f..f316c67234f18 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala @@ -69,11 +69,12 @@ class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { assert(math.abs(metrics.fMeasure(1.0, 2.0) - f2measure1) < delta) assert(math.abs(metrics.fMeasure(2.0, 2.0) - f2measure2) < delta) - 
assert(math.abs(metrics.recall - + assert(math.abs(metrics.accuracy - (2.0 + 3.0 + 1.0) / ((2 + 3 + 1) + (1 + 1 + 1))) < delta) - assert(math.abs(metrics.recall - metrics.precision) < delta) - assert(math.abs(metrics.recall - metrics.fMeasure) < delta) - assert(math.abs(metrics.recall - metrics.weightedRecall) < delta) + assert(math.abs(metrics.accuracy - metrics.precision) < delta) + assert(math.abs(metrics.accuracy - metrics.recall) < delta) + assert(math.abs(metrics.accuracy - metrics.fMeasure) < delta) + assert(math.abs(metrics.accuracy - metrics.weightedRecall) < delta) assert(math.abs(metrics.weightedFalsePositiveRate - ((4.0 / 9) * fpRate0 + (4.0 / 9) * fpRate1 + (1.0 / 9) * fpRate2)) < delta) assert(math.abs(metrics.weightedPrecision - diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala index 34122d6ed2e95..10f7bafd6cf5b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala @@ -51,10 +51,10 @@ class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext { assert((data1, data1RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5)) - assert(brzNorm(data1(0).toBreeze, 1) ~== 1.0 absTol 1E-5) - assert(brzNorm(data1(2).toBreeze, 1) ~== 1.0 absTol 1E-5) - assert(brzNorm(data1(3).toBreeze, 1) ~== 1.0 absTol 1E-5) - assert(brzNorm(data1(4).toBreeze, 1) ~== 1.0 absTol 1E-5) + assert(brzNorm(data1(0).asBreeze, 1) ~== 1.0 absTol 1E-5) + assert(brzNorm(data1(2).asBreeze, 1) ~== 1.0 absTol 1E-5) + assert(brzNorm(data1(3).asBreeze, 1) ~== 1.0 absTol 1E-5) + assert(brzNorm(data1(4).asBreeze, 1) ~== 1.0 absTol 1E-5) assert(data1(0) ~== Vectors.sparse(3, Seq((0, -0.465116279), (1, 0.53488372))) absTol 1E-5) assert(data1(1) ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) @@ -78,10 +78,10 @@ class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext { assert((data2, data2RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5)) - assert(brzNorm(data2(0).toBreeze, 2) ~== 1.0 absTol 1E-5) - assert(brzNorm(data2(2).toBreeze, 2) ~== 1.0 absTol 1E-5) - assert(brzNorm(data2(3).toBreeze, 2) ~== 1.0 absTol 1E-5) - assert(brzNorm(data2(4).toBreeze, 2) ~== 1.0 absTol 1E-5) + assert(brzNorm(data2(0).asBreeze, 2) ~== 1.0 absTol 1E-5) + assert(brzNorm(data2(2).asBreeze, 2) ~== 1.0 absTol 1E-5) + assert(brzNorm(data2(3).asBreeze, 2) ~== 1.0 absTol 1E-5) + assert(brzNorm(data2(4).asBreeze, 2) ~== 1.0 absTol 1E-5) assert(data2(0) ~== Vectors.sparse(3, Seq((0, -0.65617871), (1, 0.75460552))) absTol 1E-5) assert(data2(1) ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala index 6d699440f2f2e..f4fa216b8eba0 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.feature import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.util.Utils @@ -68,6 +69,21 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { assert(syms(1)._1 == "japan") } + test("findSynonyms doesn't reject similar word vectors when called with a vector") { + val num = 2 
+ val word2VecMap = Map( + ("china", Array(0.50f, 0.50f, 0.50f, 0.50f)), + ("japan", Array(0.40f, 0.50f, 0.50f, 0.50f)), + ("taiwan", Array(0.60f, 0.50f, 0.50f, 0.50f)), + ("korea", Array(0.45f, 0.60f, 0.60f, 0.60f)) + ) + val model = new Word2VecModel(word2VecMap) + val syms = model.findSynonyms(Vectors.dense(Array(0.52, 0.5, 0.5, 0.5)), num) + assert(syms.length == num) + assert(syms(0)._1 == "china") + assert(syms(1)._1 == "taiwan") + + test("model load / save") { val word2VecMap = Map( @@ -92,10 +108,22 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { } test("big model load / save") { - // create a model bigger than 32MB since 9000 * 1000 * 4 > 2^25 - val word2VecMap = Map((0 to 9000).map(i => s"$i" -> Array.fill(1000)(0.1f)): _*) + // back up the old values + val oldBufferConfValue = spark.conf.get("spark.kryoserializer.buffer", "64k") + val oldBufferMaxConfValue = spark.conf.get("spark.kryoserializer.buffer.max", "64m") + + // setting test values to trigger partitioning + spark.conf.set("spark.kryoserializer.buffer", "50b") + spark.conf.set("spark.kryoserializer.buffer.max", "50b") + + // create a model bigger than 50 bytes + val word2VecMap = Map((0 to 10).map(i => s"$i" -> Array.fill(10)(0.1f)): _*) + val model = new Word2VecModel(word2VecMap) + // est. size of this model, given the formula: + // (floatSize * vectorSize + 15) * numWords + // (4 * 10 + 15) * 10 = 550 + // therefore it should generate multiple partitions val tempDir = Utils.createTempDir() val path = tempDir.toURI.toString @@ -103,9 +131,16 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { model.save(sc, path) val sameModel = Word2VecModel.load(sc, path) assert(sameModel.getVectors.mapValues(_.toSeq) === model.getVectors.mapValues(_.toSeq)) + } + catch { + case t: Throwable => fail("exception thrown persisting a model " + + "that spans multiple partitions", t) } finally { Utils.deleteRecursively(tempDir) + spark.conf.set("spark.kryoserializer.buffer", oldBufferConfValue) + spark.conf.set("spark.kryoserializer.buffer.max", oldBufferMaxConfValue) } + } test("test similarity for word vectors with large values is not Infinity or NaN") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala index 6d8c7b47d8373..4c2376376dd2a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.mllib.fpm +import scala.language.existentials + import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.util.Utils diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala index e331c75989187..a13e7f63a9296 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.impl -import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.fs.Path import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.graphx.{Edge, Graph} @@ -140,9 +140,11 @@ private object PeriodicGraphCheckpointerSuite { // Instead, we check for the presence of the checkpoint files. 
// This test should continue to work even after this graph.isCheckpointed issue // is fixed (though it can then be simplified and not look for the files). - val fs = FileSystem.get(graph.vertices.sparkContext.hadoopConfiguration) + val hadoopConf = graph.vertices.sparkContext.hadoopConfiguration graph.getCheckpointFiles.foreach { checkpointFile => - assert(!fs.exists(new Path(checkpointFile)), + val path = new Path(checkpointFile) + val fs = path.getFileSystem(hadoopConf) + assert(!fs.exists(path), "Graph checkpoint file should have been removed") } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala index b2a459a68b5fa..14adf8c29fc6b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.impl -import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.fs.Path import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -127,9 +127,11 @@ private object PeriodicRDDCheckpointerSuite { // Instead, we check for the presence of the checkpoint files. // This test should continue to work even after this rdd.isCheckpointed issue // is fixed (though it can then be simplified and not look for the files). - val fs = FileSystem.get(rdd.sparkContext.hadoopConfiguration) + val hadoopConf = rdd.sparkContext.hadoopConfiguration rdd.getCheckpointFile.foreach { checkpointFile => - assert(!fs.exists(new Path(checkpointFile)), "RDD checkpoint file should have been removed") + val path = new Path(checkpointFile) + val fs = path.getFileSystem(hadoopConf) + assert(!fs.exists(path), "RDD checkpoint file should have been removed") } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala index 80da03cc2efeb..6e68c1c9d36c8 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala @@ -392,6 +392,23 @@ class BLASSuite extends SparkFunSuite { } } + val y17 = new DenseVector(Array(0.0, 0.0)) + val y18 = y17.copy + + val sA3 = new SparseMatrix(3, 2, Array(0, 2, 4), Array(1, 2, 0, 1), Array(2.0, 1.0, 1.0, 2.0)) + .transpose + val sA4 = + new SparseMatrix(2, 3, Array(0, 1, 3, 4), Array(1, 0, 1, 0), Array(1.0, 2.0, 2.0, 1.0)) + val sx3 = new SparseVector(3, Array(1, 2), Array(2.0, 1.0)) + + val expected4 = new DenseVector(Array(5.0, 4.0)) + + gemv(1.0, sA3, sx3, 0.0, y17) + gemv(1.0, sA4, sx3, 0.0, y18) + + assert(y17 ~== expected4 absTol 1e-15) + assert(y18 ~== expected4 absTol 1e-15) + val dAT = new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0)) val sAT = diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala index de2c3c13bd923..9e4735afdd59f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.SparkFunSuite class BreezeMatrixConversionSuite extends SparkFunSuite { test("dense matrix to breeze") { val mat = 
Matrices.dense(3, 2, Array(0.0, 1.0, 2.0, 3.0, 4.0, 5.0)) - val breeze = mat.toBreeze.asInstanceOf[BDM[Double]] + val breeze = mat.asBreeze.asInstanceOf[BDM[Double]] assert(breeze.rows === mat.numRows) assert(breeze.cols === mat.numCols) assert(breeze.data.eq(mat.asInstanceOf[DenseMatrix].values), "should not copy data") @@ -48,7 +48,7 @@ class BreezeMatrixConversionSuite extends SparkFunSuite { val colPtrs = Array(0, 2, 4) val rowIndices = Array(1, 2, 1, 2) val mat = Matrices.sparse(3, 2, colPtrs, rowIndices, values) - val breeze = mat.toBreeze.asInstanceOf[BSM[Double]] + val breeze = mat.asBreeze.asInstanceOf[BSM[Double]] assert(breeze.rows === mat.numRows) assert(breeze.cols === mat.numCols) assert(breeze.data.eq(mat.asInstanceOf[SparseMatrix].values), "should not copy data") diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeVectorConversionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeVectorConversionSuite.scala index 3772c9235ad3a..996f621f18c80 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeVectorConversionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeVectorConversionSuite.scala @@ -33,12 +33,12 @@ class BreezeVectorConversionSuite extends SparkFunSuite { test("dense to breeze") { val vec = Vectors.dense(arr) - assert(vec.toBreeze === new BDV[Double](arr)) + assert(vec.asBreeze === new BDV[Double](arr)) } test("sparse to breeze") { val vec = Vectors.sparse(n, indices, values) - assert(vec.toBreeze === new BSV[Double](indices, values, n)) + assert(vec.asBreeze === new BSV[Double](indices, values, n)) } test("dense breeze to vector") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala index b7df02e6c098a..d0c4dd28e14ee 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala @@ -63,7 +63,7 @@ class MatricesSuite extends SparkFunSuite { (1, 2, 2.0), (2, 2, 2.0), (1, 2, 2.0), (0, 0, 0.0)) val mat2 = SparseMatrix.fromCOO(m, n, entries) - assert(mat.toBreeze === mat2.toBreeze) + assert(mat.asBreeze === mat2.asBreeze) assert(mat2.values.length == 4) } @@ -176,8 +176,8 @@ class MatricesSuite extends SparkFunSuite { val spMat2 = deMat1.toSparse val deMat2 = spMat1.toDense - assert(spMat1.toBreeze === spMat2.toBreeze) - assert(deMat1.toBreeze === deMat2.toBreeze) + assert(spMat1.asBreeze === spMat2.asBreeze) + assert(deMat1.asBreeze === deMat2.asBreeze) } test("map, update") { @@ -211,8 +211,8 @@ class MatricesSuite extends SparkFunSuite { val sATexpected = new SparseMatrix(3, 4, Array(0, 1, 2, 3, 4), Array(1, 0, 1, 2), Array(2.0, 1.0, 1.0, 3.0)) - assert(dAT.toBreeze === dATexpected.toBreeze) - assert(sAT.toBreeze === sATexpected.toBreeze) + assert(dAT.asBreeze === dATexpected.asBreeze) + assert(sAT.asBreeze === sATexpected.asBreeze) assert(dA(1, 0) === dAT(0, 1)) assert(dA(2, 1) === dAT(1, 2)) assert(sA(1, 0) === sAT(0, 1)) @@ -221,8 +221,8 @@ class MatricesSuite extends SparkFunSuite { assert(!dA.toArray.eq(dAT.toArray), "has to have a new array") assert(dA.values.eq(dAT.transpose.asInstanceOf[DenseMatrix].values), "should not copy array") - assert(dAT.toSparse.toBreeze === sATexpected.toBreeze) - assert(sAT.toDense.toBreeze === dATexpected.toBreeze) + assert(dAT.toSparse.asBreeze === sATexpected.asBreeze) + assert(sAT.toDense.asBreeze === dATexpected.asBreeze) } test("foreachActive") { 
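// Editor's aside (illustrative sketch, not part of the patch): most of the churn in these linalg test files
// is the mechanical rename of the internal Breeze conversion from toBreeze to asBreeze. The public way to move
// between the old mllib types and the new ml ("newlinalg") types is asML/fromML, which the implicit-conversion
// tests added just below exercise. A minimal sketch of that round trip, assuming only spark-mllib on the
// classpath (object and value names here are made up):
import org.apache.spark.ml.{linalg => newlinalg}
import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices}

object LinalgConversionSketch {
  def main(args: Array[String]): Unit = {
    val oldMat: DenseMatrix = new DenseMatrix(3, 2, Array(0.0, 0.0, 1.0, 0.0, 2.0, 3.5))
    val newMat: newlinalg.DenseMatrix = oldMat.asML   // mllib.linalg -> ml.linalg
    val roundTrip = Matrices.fromML(newMat)           // ml.linalg -> mllib.linalg
    assert(roundTrip.toArray.sameElements(oldMat.toArray))
  }
}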
@@ -560,4 +560,55 @@ class MatricesSuite extends SparkFunSuite { compare(oldSM0, newSM0) compare(oldDM0, newDM0) } + + test("implicit conversions between new local linalg and mllib linalg") { + + def mllibMatrixToTriple(m: Matrix): (Array[Double], Int, Int) = + (m.toArray, m.numCols, m.numRows) + + def mllibDenseMatrixToTriple(m: DenseMatrix): (Array[Double], Int, Int) = + (m.toArray, m.numCols, m.numRows) + + def mllibSparseMatrixToTriple(m: SparseMatrix): (Array[Double], Int, Int) = + (m.toArray, m.numCols, m.numRows) + + def mlMatrixToTriple(m: newlinalg.Matrix): (Array[Double], Int, Int) = + (m.toArray, m.numCols, m.numRows) + + def mlDenseMatrixToTriple(m: newlinalg.DenseMatrix): (Array[Double], Int, Int) = + (m.toArray, m.numCols, m.numRows) + + def mlSparseMatrixToTriple(m: newlinalg.SparseMatrix): (Array[Double], Int, Int) = + (m.toArray, m.numCols, m.numRows) + + def compare(m1: (Array[Double], Int, Int), m2: (Array[Double], Int, Int)): Unit = { + assert(m1._1 === m2._1) + assert(m1._2 === m2._2) + assert(m1._3 === m2._3) + } + + val dm: DenseMatrix = new DenseMatrix(3, 2, Array(0.0, 0.0, 1.0, 0.0, 2.0, 3.5)) + val sm: SparseMatrix = dm.toSparse + val sm0: Matrix = sm.asInstanceOf[Matrix] + val dm0: Matrix = dm.asInstanceOf[Matrix] + + val newSM: newlinalg.SparseMatrix = sm.asML + val newDM: newlinalg.DenseMatrix = dm.asML + val newSM0: newlinalg.Matrix = sm0.asML + val newDM0: newlinalg.Matrix = dm0.asML + + import org.apache.spark.mllib.linalg.MatrixImplicits._ + + compare(mllibMatrixToTriple(dm0), mllibMatrixToTriple(newDM0)) + compare(mllibMatrixToTriple(sm0), mllibMatrixToTriple(newSM0)) + + compare(mllibDenseMatrixToTriple(dm), mllibDenseMatrixToTriple(newDM)) + compare(mllibSparseMatrixToTriple(sm), mllibSparseMatrixToTriple(newSM)) + + compare(mlMatrixToTriple(dm0), mlMatrixToTriple(newDM)) + compare(mlMatrixToTriple(sm0), mlMatrixToTriple(newSM0)) + + compare(mlDenseMatrixToTriple(dm), mlDenseMatrixToTriple(newDM)) + compare(mlSparseMatrixToTriple(sm), mlSparseMatrixToTriple(newSM)) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala index be7110ad6bbf0..8b439e6b7a017 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala @@ -29,7 +29,7 @@ object UDTSerializationBenchmark { val iters = 1e2.toInt val numRows = 1e3.toInt - val encoder = ExpressionEncoder[Vector].defaultBinding + val encoder = ExpressionEncoder[Vector].resolveAndBind() val vectors = (1 to numRows).map { i => Vectors.dense(Array.fill(1e5.toInt)(1.0 * i)) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index a7c1a076044ee..71a3ceac1b947 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -269,7 +269,7 @@ class VectorsSuite extends SparkFunSuite with Logging { val denseVector1 = Vectors.dense(sparseVector1.toArray) val denseVector2 = Vectors.dense(sparseVector2.toArray) - val squaredDist = breezeSquaredDistance(sparseVector1.toBreeze, sparseVector2.toBreeze) + val squaredDist = breezeSquaredDistance(sparseVector1.asBreeze, sparseVector2.asBreeze) // SparseVector vs. 
SparseVector assert(Vectors.sqdist(sparseVector1, sparseVector2) ~== squaredDist relTol 1E-8) @@ -422,4 +422,43 @@ class VectorsSuite extends SparkFunSuite with Logging { assert(oldSV0.toArray === newSV0.toArray) assert(oldDV0.toArray === newDV0.toArray) } + + test("implicit conversions between new local linalg and mllib linalg") { + + def mllibVectorToArray(v: Vector): Array[Double] = v.toArray + + def mllibDenseVectorToArray(v: DenseVector): Array[Double] = v.toArray + + def mllibSparseVectorToArray(v: SparseVector): Array[Double] = v.toArray + + def mlVectorToArray(v: newlinalg.Vector): Array[Double] = v.toArray + + def mlDenseVectorToArray(v: newlinalg.DenseVector): Array[Double] = v.toArray + + def mlSparseVectorToArray(v: newlinalg.SparseVector): Array[Double] = v.toArray + + val dv: DenseVector = new DenseVector(Array(1.0, 2.0, 3.5)) + val sv: SparseVector = new SparseVector(5, Array(1, 2, 4), Array(1.1, 2.2, 4.4)) + val sv0: Vector = sv.asInstanceOf[Vector] + val dv0: Vector = dv.asInstanceOf[Vector] + + val newSV: newlinalg.SparseVector = sv.asML + val newDV: newlinalg.DenseVector = dv.asML + val newSV0: newlinalg.Vector = sv0.asML + val newDV0: newlinalg.Vector = dv0.asML + + import org.apache.spark.mllib.linalg.VectorImplicits._ + + assert(mllibVectorToArray(dv0) === mllibVectorToArray(newDV0)) + assert(mllibVectorToArray(sv0) === mllibVectorToArray(newSV0)) + + assert(mllibDenseVectorToArray(dv) === mllibDenseVectorToArray(newDV)) + assert(mllibSparseVectorToArray(sv) === mllibSparseVectorToArray(newSV)) + + assert(mlVectorToArray(dv0) === mlVectorToArray(newDV0)) + assert(mlVectorToArray(sv0) === mlVectorToArray(newSV0)) + + assert(mlDenseVectorToArray(dv) === mlDenseVectorToArray(newDV)) + assert(mlSparseVectorToArray(sv) === mlSparseVectorToArray(newSV)) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala index f37eaf225ab88..61266f3c78dbc 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala @@ -135,6 +135,11 @@ class BlockMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { assert(rowMat.numCols() === n) assert(rowMat.toBreeze() === gridBasedMat.toBreeze()) + // SPARK-15922: BlockMatrix to IndexedRowMatrix throws an error" + val bmat = rowMat.toBlockMatrix + val imat = bmat.toIndexedRowMatrix + imat.rows.collect + val rows = 1 val cols = 10 @@ -152,7 +157,7 @@ class BlockMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { val C = B.toIndexedRowMatrix.rows.collect - (C(0).vector.toBreeze, C(1).vector.toBreeze) match { + (C(0).vector.asBreeze, C(1).vector.asBreeze) match { case (denseVector: BDV[Double], sparseVector: BSV[Double]) => assert(denseVector.length === sparseVector.length) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala index 5b7ccb90158b0..99af5fa10d999 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala @@ -108,7 +108,7 @@ class IndexedRowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { val C = A.multiply(B) val localA = A.toBreeze() val localC = C.toBreeze() - 
val expected = localA * B.toBreeze.asInstanceOf[BDM[Double]] + val expected = localA * B.asBreeze.asInstanceOf[BDM[Double]] assert(localC === expected) } @@ -119,7 +119,7 @@ class IndexedRowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { (90.0, 12.0, 24.0), (12.0, 17.0, 22.0), (24.0, 22.0, 30.0)) - assert(G.toBreeze === expected) + assert(G.asBreeze === expected) } test("svd") { @@ -128,8 +128,8 @@ class IndexedRowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { assert(svd.U.isInstanceOf[IndexedRowMatrix]) val localA = A.toBreeze() val U = svd.U.toBreeze() - val s = svd.s.toBreeze.asInstanceOf[BDV[Double]] - val V = svd.V.toBreeze.asInstanceOf[BDM[Double]] + val s = svd.s.asBreeze.asInstanceOf[BDV[Double]] + val V = svd.V.asBreeze.asInstanceOf[BDM[Double]] assert(closeToZero(U.t * U - BDM.eye[Double](n))) assert(closeToZero(V.t * V - BDM.eye[Double](n))) assert(closeToZero(U * brzDiag(s) * V.t - localA)) @@ -155,7 +155,7 @@ class IndexedRowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { test("similar columns") { val A = new IndexedRowMatrix(indexedRows) - val gram = A.computeGramianMatrix().toBreeze.toDenseMatrix + val gram = A.computeGramianMatrix().asBreeze.toDenseMatrix val G = A.columnSimilarities().toBreeze() diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala index 2dff52c601d81..7c9e14f8cee70 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{Matrices, Vector, Vectors} import org.apache.spark.mllib.random.RandomRDDs import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} +import org.apache.spark.mllib.util.TestingUtils._ class RowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -96,7 +97,7 @@ class RowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { Matrices.dense(n, n, Array(126.0, 54.0, 72.0, 54.0, 66.0, 78.0, 72.0, 78.0, 94.0)) for (mat <- Seq(denseMat, sparseMat)) { val G = mat.computeGramianMatrix() - assert(G.toBreeze === expected.toBreeze) + assert(G.asBreeze === expected.asBreeze) } } @@ -153,8 +154,8 @@ class RowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { assert(V.numRows === n) assert(V.numCols === k) assertColumnEqualUpToSign(U.toBreeze(), localU, k) - assertColumnEqualUpToSign(V.toBreeze.asInstanceOf[BDM[Double]], localV, k) - assert(closeToZero(s.toBreeze.asInstanceOf[BDV[Double]] - localSigma(0 until k))) + assertColumnEqualUpToSign(V.asBreeze.asInstanceOf[BDM[Double]], localV, k) + assert(closeToZero(s.asBreeze.asInstanceOf[BDV[Double]] - localSigma(0 until k))) } } val svdWithoutU = mat.computeSVD(1, computeU = false, 1e-9, 300, 1e-10, mode) @@ -207,7 +208,7 @@ class RowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { val (pc, expVariance) = mat.computePrincipalComponentsAndExplainedVariance(k) assert(pc.numRows === n) assert(pc.numCols === k) - assertColumnEqualUpToSign(pc.toBreeze.asInstanceOf[BDM[Double]], principalComponents, k) + assertColumnEqualUpToSign(pc.asBreeze.asInstanceOf[BDM[Double]], principalComponents, k) assert( closeToZero(BDV(expVariance.toArray) - BDV(Arrays.copyOfRange(explainedVariance.data, 0, k)))) @@ -256,12 +257,12 @@ class RowMatrixSuite extends 
SparkFunSuite with MLlibTestSparkContext { val calcQ = result.Q val calcR = result.R assert(closeToZero(abs(expected.q) - abs(calcQ.toBreeze()))) - assert(closeToZero(abs(expected.r) - abs(calcR.toBreeze.asInstanceOf[BDM[Double]]))) + assert(closeToZero(abs(expected.r) - abs(calcR.asBreeze.asInstanceOf[BDM[Double]]))) assert(closeToZero(calcQ.multiply(calcR).toBreeze - mat.toBreeze())) // Decomposition without computing Q val rOnly = mat.tallSkinnyQR(computeQ = false) assert(rOnly.Q == null) - assert(closeToZero(abs(expected.r) - abs(rOnly.R.toBreeze.asInstanceOf[BDM[Double]]))) + assert(closeToZero(abs(expected.r) - abs(rOnly.R.asBreeze.asInstanceOf[BDM[Double]]))) } } @@ -269,7 +270,7 @@ class RowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { for (mat <- Seq(denseMat, sparseMat)) { val result = mat.computeCovariance() val expected = breeze.linalg.cov(mat.toBreeze()) - assert(closeToZero(abs(expected) - abs(result.toBreeze.asInstanceOf[BDM[Double]]))) + assert(closeToZero(abs(expected) - abs(result.asBreeze.asInstanceOf[BDM[Double]]))) } } @@ -281,6 +282,22 @@ class RowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { assert(cov(i, j) === cov(j, i)) } } + + test("QR decomposition should aware of empty partition (SPARK-16369)") { + val mat: RowMatrix = new RowMatrix(sc.parallelize(denseData, 1)) + val qrResult = mat.tallSkinnyQR(true) + + val matWithEmptyPartition = new RowMatrix(sc.parallelize(denseData, 8)) + val qrResult2 = matWithEmptyPartition.tallSkinnyQR(true) + + assert(qrResult.Q.numCols() === qrResult2.Q.numCols(), "Q matrix ncol not match") + assert(qrResult.Q.numRows() === qrResult2.Q.numRows(), "Q matrix nrow not match") + qrResult.Q.rows.collect().zip(qrResult2.Q.rows.collect()) + .foreach(x => assert(x._1 ~== x._2 relTol 1E-8, "Q matrix not match")) + + qrResult.R.toArray.zip(qrResult2.R.toArray) + .foreach(x => assert(x._1 ~== x._2 relTol 1E-8, "R matrix not match")) + } } class RowMatrixClusterSuite extends SparkFunSuite with LocalClusterSparkContext { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala index f8d0af8820e64..252a068dcd72f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.regression import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.feature.{LabeledPoint => NewLabeledPoint} import org.apache.spark.mllib.linalg.Vectors class LabeledPointSuite extends SparkFunSuite { @@ -40,4 +41,16 @@ class LabeledPointSuite extends SparkFunSuite { val point = LabeledPoint.parse("1.0,1.0 0.0 -2.0") assert(point === LabeledPoint(1.0, Vectors.dense(1.0, 0.0, -2.0))) } + + test("conversions between new ml LabeledPoint and mllib LabeledPoint") { + val points: Seq[LabeledPoint] = Seq( + LabeledPoint(1.0, Vectors.dense(1.0, 0.0)), + LabeledPoint(0.0, Vectors.sparse(2, Array(1), Array(-1.0)))) + + val newPoints: Seq[NewLabeledPoint] = points.map(_.asML) + + points.zip(newPoints).foreach { case (p1, p2) => + assert(p1 === LabeledPoint.fromML(p2)) + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala index 700f803490c08..e32767edb17a8 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala +++ 
b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala @@ -104,8 +104,8 @@ class CorrelationSuite extends SparkFunSuite with MLlibTestSparkContext with Log (Double.NaN, Double.NaN, 1.00000000, Double.NaN), (0.40047142, 0.91359586, Double.NaN, 1.0000000)) // scalastyle:on - assert(matrixApproxEqual(defaultMat.toBreeze, expected)) - assert(matrixApproxEqual(pearsonMat.toBreeze, expected)) + assert(matrixApproxEqual(defaultMat.asBreeze, expected)) + assert(matrixApproxEqual(pearsonMat.asBreeze, expected)) } test("corr(X) spearman") { @@ -118,7 +118,7 @@ class CorrelationSuite extends SparkFunSuite with MLlibTestSparkContext with Log (Double.NaN, Double.NaN, 1.00000000, Double.NaN), (0.4000000, 0.9486833, Double.NaN, 1.0000000)) // scalastyle:on - assert(matrixApproxEqual(spearmanMat.toBreeze, expected)) + assert(matrixApproxEqual(spearmanMat.asBreeze, expected)) } test("method identification") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 49cb7e1f24e35..441d0f7614bf6 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -73,7 +73,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) - val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy) assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) @@ -100,7 +100,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { maxDepth = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 2, 1 -> 2)) - val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy) assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) @@ -116,7 +116,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification, Gini, maxDepth = 3, numClasses = 2, maxBins = 100) - val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy) assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) @@ -133,7 +133,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification, Gini, maxDepth = 3, numClasses = 2, maxBins = 100) - val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy) assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) @@ -150,7 +150,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification, Entropy, maxDepth = 3, numClasses = 2, maxBins = 100) - val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy) assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) @@ -167,7 +167,7 @@ class DecisionTreeSuite extends SparkFunSuite 
with MLlibTestSparkContext { val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification, Entropy, maxDepth = 3, numClasses = 2, maxBins = 100) - val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy) assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) @@ -183,7 +183,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClasses = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) - val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy) assert(strategy.isMulticlassClassification) assert(metadata.isUnordered(featureIndex = 0)) assert(metadata.isUnordered(featureIndex = 1)) @@ -240,7 +240,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { numClasses = 3, maxBins = maxBins, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) assert(strategy.isMulticlassClassification) - val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy) assert(metadata.isUnordered(featureIndex = 0)) assert(metadata.isUnordered(featureIndex = 1)) @@ -288,7 +288,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClasses = 3, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3)) assert(strategy.isMulticlassClassification) - val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy) assert(metadata.isUnordered(featureIndex = 0)) val model = DecisionTree.train(rdd, strategy) @@ -310,7 +310,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { numClasses = 3, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10)) assert(strategy.isMulticlassClassification) - val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy) assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala index 0c6aabf1926e9..6aa93c9076007 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala @@ -25,12 +25,14 @@ import scala.io.Source import breeze.linalg.{squaredDistance => breezeSquaredDistance} import com.google.common.io.Files -import org.apache.spark.SparkException -import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors} +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.mllib.linalg.{DenseVector, Matrices, SparseVector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLUtils._ import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.Row +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.types.MetadataBuilder import org.apache.spark.util.Utils class MLUtilsSuite extends SparkFunSuite with 
MLlibTestSparkContext { @@ -53,13 +55,13 @@ class MLUtilsSuite extends SparkFunSuite with MLlibTestSparkContext { val norm2 = Vectors.norm(v2, 2.0) val v3 = Vectors.sparse(n, indices, indices.map(i => a(i) + 0.5)) val norm3 = Vectors.norm(v3, 2.0) - val squaredDist = breezeSquaredDistance(v1.toBreeze, v2.toBreeze) + val squaredDist = breezeSquaredDistance(v1.asBreeze, v2.asBreeze) val fastSquaredDist1 = fastSquaredDistance(v1, norm1, v2, norm2, precision) assert((fastSquaredDist1 - squaredDist) <= precision * squaredDist, s"failed with m = $m") val fastSquaredDist2 = fastSquaredDistance(v1, norm1, Vectors.dense(v2.toArray), norm2, precision) assert((fastSquaredDist2 - squaredDist) <= precision * squaredDist, s"failed with m = $m") - val squaredDist2 = breezeSquaredDistance(v2.toBreeze, v3.toBreeze) + val squaredDist2 = breezeSquaredDistance(v2.asBreeze, v3.asBreeze) val fastSquaredDist3 = fastSquaredDistance(v2, norm2, v3, norm3, precision) assert((fastSquaredDist3 - squaredDist2) <= precision * squaredDist2, s"failed with m = $m") @@ -67,7 +69,7 @@ class MLUtilsSuite extends SparkFunSuite with MLlibTestSparkContext { val v4 = Vectors.sparse(n, indices.slice(0, m - 10), indices.map(i => a(i) + 0.5).slice(0, m - 10)) val norm4 = Vectors.norm(v4, 2.0) - val squaredDist = breezeSquaredDistance(v2.toBreeze, v4.toBreeze) + val squaredDist = breezeSquaredDistance(v2.asBreeze, v4.asBreeze) val fastSquaredDist = fastSquaredDistance(v2, norm2, v4, norm4, precision) assert((fastSquaredDist - squaredDist) <= precision * squaredDist, s"failed with m = $m") @@ -245,4 +247,112 @@ class MLUtilsSuite extends SparkFunSuite with MLlibTestSparkContext { assert(log1pExp(-13.8) ~== math.log1p(math.exp(-13.8)) absTol 1E-10) assert(log1pExp(-238423789.865) ~== math.log1p(math.exp(-238423789.865)) absTol 1E-10) } + + test("convertVectorColumnsToML") { + val x = Vectors.sparse(2, Array(1), Array(1.0)) + val metadata = new MetadataBuilder().putLong("numFeatures", 2L).build() + val y = Vectors.dense(2.0, 3.0) + val z = Vectors.dense(4.0) + val p = (5.0, z) + val w = Vectors.dense(6.0).asML + val df = spark.createDataFrame(Seq( + (0, x, y, p, w) + )).toDF("id", "x", "y", "p", "w") + .withColumn("x", col("x"), metadata) + val newDF1 = convertVectorColumnsToML(df) + assert(newDF1.schema("x").metadata === metadata, "Metadata should be preserved.") + val new1 = newDF1.first() + assert(new1 === Row(0, x.asML, y.asML, Row(5.0, z), w)) + val new2 = convertVectorColumnsToML(df, "x", "y").first() + assert(new2 === new1) + val new3 = convertVectorColumnsToML(df, "y", "w").first() + assert(new3 === Row(0, x, y.asML, Row(5.0, z), w)) + intercept[IllegalArgumentException] { + convertVectorColumnsToML(df, "p") + } + intercept[IllegalArgumentException] { + convertVectorColumnsToML(df, "p._2") + } + } + + test("convertVectorColumnsFromML") { + val x = Vectors.sparse(2, Array(1), Array(1.0)).asML + val metadata = new MetadataBuilder().putLong("numFeatures", 2L).build() + val y = Vectors.dense(2.0, 3.0).asML + val z = Vectors.dense(4.0).asML + val p = (5.0, z) + val w = Vectors.dense(6.0) + val df = spark.createDataFrame(Seq( + (0, x, y, p, w) + )).toDF("id", "x", "y", "p", "w") + .withColumn("x", col("x"), metadata) + val newDF1 = convertVectorColumnsFromML(df) + assert(newDF1.schema("x").metadata === metadata, "Metadata should be preserved.") + val new1 = newDF1.first() + assert(new1 === Row(0, Vectors.fromML(x), Vectors.fromML(y), Row(5.0, z), w)) + val new2 = convertVectorColumnsFromML(df, "x", "y").first() + assert(new2 === new1) 
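// Editor's aside (illustrative, not part of the patch): convertVectorColumnsToML / convertVectorColumnsFromML
// (and the matrix variants tested further below) are public MLUtils helpers that rewrite DataFrame columns
// between the old mllib vector/matrix UDTs and the new spark.ml ones, preserving column metadata. Called with
// no column names they convert every matching column; with names they convert only those. A minimal usage
// sketch, assuming a DataFrame `df` with an old-style vector column named "features" (names are made up):
//   import org.apache.spark.mllib.util.MLUtils
//   val mlDF   = MLUtils.convertVectorColumnsToML(df, "features")      // mllib.linalg -> ml.linalg
//   val backDF = MLUtils.convertVectorColumnsFromML(mlDF, "features")  // and back again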
+ val new3 = convertVectorColumnsFromML(df, "y", "w").first() + assert(new3 === Row(0, x, Vectors.fromML(y), Row(5.0, z), w)) + intercept[IllegalArgumentException] { + convertVectorColumnsFromML(df, "p") + } + intercept[IllegalArgumentException] { + convertVectorColumnsFromML(df, "p._2") + } + } + + test("convertMatrixColumnsToML") { + val x = Matrices.sparse(3, 2, Array(0, 2, 3), Array(0, 2, 1), Array(0.0, -1.2, 0.0)) + val metadata = new MetadataBuilder().putLong("numFeatures", 2L).build() + val y = Matrices.dense(2, 1, Array(0.2, 1.3)) + val z = Matrices.ones(1, 1) + val p = (5.0, z) + val w = Matrices.dense(1, 1, Array(4.5)).asML + val df = spark.createDataFrame(Seq( + (0, x, y, p, w) + )).toDF("id", "x", "y", "p", "w") + .withColumn("x", col("x"), metadata) + val newDF1 = convertMatrixColumnsToML(df) + assert(newDF1.schema("x").metadata === metadata, "Metadata should be preserved.") + val new1 = newDF1.first() + assert(new1 === Row(0, x.asML, y.asML, Row(5.0, z), w)) + val new2 = convertMatrixColumnsToML(df, "x", "y").first() + assert(new2 === new1) + val new3 = convertMatrixColumnsToML(df, "y", "w").first() + assert(new3 === Row(0, x, y.asML, Row(5.0, z), w)) + intercept[IllegalArgumentException] { + convertMatrixColumnsToML(df, "p") + } + intercept[IllegalArgumentException] { + convertMatrixColumnsToML(df, "p._2") + } + } + + test("convertMatrixColumnsFromML") { + val x = Matrices.sparse(3, 2, Array(0, 2, 3), Array(0, 2, 1), Array(0.0, -1.2, 0.0)).asML + val metadata = new MetadataBuilder().putLong("numFeatures", 2L).build() + val y = Matrices.dense(2, 1, Array(0.2, 1.3)).asML + val z = Matrices.ones(1, 1).asML + val p = (5.0, z) + val w = Matrices.dense(1, 1, Array(4.5)) + val df = spark.createDataFrame(Seq( + (0, x, y, p, w) + )).toDF("id", "x", "y", "p", "w") + .withColumn("x", col("x"), metadata) + val newDF1 = convertMatrixColumnsFromML(df) + assert(newDF1.schema("x").metadata === metadata, "Metadata should be preserved.") + val new1 = newDF1.first() + assert(new1 === Row(0, Matrices.fromML(x), Matrices.fromML(y), Row(5.0, z), w)) + val new2 = convertMatrixColumnsFromML(df, "x", "y").first() + assert(new2 === new1) + val new3 = convertMatrixColumnsFromML(df, "y", "w").first() + assert(new3 === Row(0, x, Matrices.fromML(y), Row(5.0, z), w)) + intercept[IllegalArgumentException] { + convertMatrixColumnsFromML(df, "p") + } + intercept[IllegalArgumentException] { + convertMatrixColumnsFromML(df, "p._2") + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala index 7f9e340f54b69..db56aff63102c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala @@ -21,25 +21,24 @@ import java.io.File import org.scalatest.Suite -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.SparkContext import org.apache.spark.ml.util.TempDirectory -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession import org.apache.spark.util.Utils trait MLlibTestSparkContext extends TempDirectory { self: Suite => + @transient var spark: SparkSession = _ @transient var sc: SparkContext = _ - @transient var sqlContext: SQLContext = _ @transient var checkpointDir: String = _ override def beforeAll() { super.beforeAll() - val conf = new SparkConf() - .setMaster("local[2]") - .setAppName("MLlibUnitTest") - sc = new SparkContext(conf) - 
SQLContext.clearActive() - sqlContext = new SQLContext(sc) - SQLContext.setActive(sqlContext) + spark = SparkSession.builder + .master("local[2]") + .appName("MLlibUnitTest") + .getOrCreate() + sc = spark.sparkContext + checkpointDir = Utils.createDirectory(tempDir.getCanonicalPath, "checkpoints").toString sc.setCheckpointDir(checkpointDir) } @@ -47,12 +46,11 @@ trait MLlibTestSparkContext extends TempDirectory { self: Suite => override def afterAll() { try { Utils.deleteRecursively(new File(checkpointDir)) - sqlContext = null - SQLContext.clearActive() - if (sc != null) { - sc.stop() + SparkSession.clearActiveSession() + if (spark != null) { + spark.stop() } - sc = null + spark = null } finally { super.afterAll() } diff --git a/pom.xml b/pom.xml index 66f1d8ea902ac..d97f348f6d545 100644 --- a/pom.xml +++ b/pom.xml @@ -26,7 +26,7 @@ org.apache.spark spark-parent_2.11 - 2.0.0-SNAPSHOT + 2.0.3-SNAPSHOT pom Spark Project Parent POM http://spark.apache.org/ @@ -100,8 +100,6 @@ sql/catalyst sql/core sql/hive - sql/hivecontext-compatibility - external/docker-integration-tests assembly external/flume external/flume-sink @@ -109,8 +107,11 @@ examples repl launcher - external/kafka - external/kafka-assembly + external/kafka-0-8 + external/kafka-0-8-assembly + external/kafka-0-10 + external/kafka-0-10-assembly + external/kafka-0-10-sql @@ -134,11 +135,11 @@ 1.2.1.spark2 1.2.1 - 10.10.1.1 + 10.12.1.1 1.7.0 1.6.0 - 8.1.14.v20131031 - 3.0.0.v201112011016 + 9.2.16.v20160414 + 3.1.0 0.8.0 2.4.0 2.0.8 @@ -150,8 +151,8 @@ 0.10.2 - 4.3.2 - 4.3.2 + 4.5.2 + 4.4.4 3.1 3.4.1 @@ -159,31 +160,36 @@ 3.2.2 2.11.8 2.11 - ${scala.version} - org.scala-lang 1.9.13 - 2.5.3 - 1.1.2.4 + 2.6.5 + 1.1.2.6 1.1.2 1.2.0-incubating 1.10 + 2.4 2.6 3.3.2 3.2.10 2.7.8 - 1.9 + 2.22.2 2.9.3 3.5.2 1.3.9 0.9.2 - 4.5.2-1 + 4.5.3 1.1 + 2.52.0 + 2.8 + 1.8 ${java.home} + + true + org.spark_project @@ -285,9 +291,8 @@ org.apache.spark - spark-test-tags_${scala.binary.version} + spark-tags_${scala.binary.version} ${project.version} - test com.twitter @@ -329,6 +334,12 @@ ${jetty.version} provided + + org.eclipse.jetty + jetty-servlets + ${jetty.version} + provided + org.eclipse.jetty jetty-util @@ -360,7 +371,6 @@ provided - org.apache.commons commons-lang3 @@ -371,6 +381,11 @@ commons-lang ${commons-lang2.version} + + commons-io + commons-io + ${commons-io.version} + commons-codec commons-codec @@ -406,6 +421,11 @@ httpclient ${commons.httpclient.version} + + org.apache.httpcomponents + httpmime + ${commons.httpclient.version} + org.apache.httpcomponents httpcore @@ -414,7 +434,7 @@ org.seleniumhq.selenium selenium-java - 2.45.0 + ${selenium.version} test @@ -427,6 +447,12 @@ + + org.seleniumhq.selenium + selenium-htmlunit-driver + ${selenium.version} + test + xml-apis @@ -588,16 +614,44 @@ - com.sun.jersey + com.fasterxml.jackson.module + jackson-module-jaxb-annotations + ${fasterxml.jackson.version} + + + org.glassfish.jersey.core jersey-server ${jersey.version} - ${hadoop.deps.scope} - com.sun.jersey - jersey-core + org.glassfish.jersey.core + jersey-common ${jersey.version} - ${hadoop.deps.scope} + + + org.glassfish.jersey.core + jersey-client + ${jersey.version} + + + org.glassfish.jersey.containers + jersey-container-servlet + ${jersey.version} + + + org.glassfish.jersey.containers + jersey-container-servlet-core + ${jersey.version} + + + org.glassfish.jersey + jersey-client + ${jersey.version} + + + javax.ws.rs + javax.ws.rs-api + 2.0.1 org.scalanlp @@ -619,18 +673,7 @@ org.json4s json4s-jackson_${scala.binary.version} - 3.2.10 - - - 
com.sun.jersey - jersey-json - ${jersey.version} - - - stax - stax-api - - + 3.2.11 org.scala-lang @@ -702,26 +745,13 @@ com.spotify docker-client - shaded - 3.6.6 + 5.0.2 test guava com.google.guava - - org.apache.httpcomponents - httpclient - - - org.apache.httpcomponents - httpcore - - - commons-logging - httpclient - commons-logging commons-logging @@ -814,6 +844,18 @@ junit junit + + com.sun.jersey + * + + + com.sun.jersey.jersey-test-framework + * + + + com.sun.jersey.contribs + * + @@ -926,6 +968,18 @@ commons-logging commons-logging + + com.sun.jersey + * + + + com.sun.jersey.jersey-test-framework + * + + + com.sun.jersey.contribs + * + @@ -954,6 +1008,18 @@ commons-logging commons-logging + + com.sun.jersey + * + + + com.sun.jersey.jersey-test-framework + * + + + com.sun.jersey.contribs + * + @@ -983,6 +1049,18 @@ commons-logging commons-logging + + com.sun.jersey + * + + + com.sun.jersey.jersey-test-framework + * + + + com.sun.jersey.contribs + * + @@ -1011,6 +1089,18 @@ commons-logging commons-logging + + com.sun.jersey + * + + + com.sun.jersey.jersey-test-framework + * + + + com.sun.jersey.contribs + * + @@ -1039,6 +1129,18 @@ commons-logging commons-logging + + com.sun.jersey + * + + + com.sun.jersey.jersey-test-framework + * + + + com.sun.jersey.contribs + * + @@ -1324,6 +1426,10 @@ org.codehaus.groovy groovy-all + + jline + jline + @@ -1355,14 +1461,6 @@ ${hive.group} hive-shims - - org.apache.httpcomponents - httpclient - - - org.apache.httpcomponents - httpcore - org.apache.curator curator-framework @@ -1714,14 +1812,6 @@ libthrift ${libthrift.version} - - org.apache.httpcomponents - httpclient - - - org.apache.httpcomponents - httpcore - org.slf4j slf4j-api @@ -1733,14 +1823,6 @@ libfb303 ${libthrift.version} - - org.apache.httpcomponents - httpclient - - - org.apache.httpcomponents - httpcore - org.slf4j slf4j-api @@ -1752,6 +1834,11 @@ antlr4-runtime ${antlr4.version} + + ${jline.groupid} + jline + ${jline.version} + @@ -1873,11 +1960,6 @@ - - org.antlr - antlr3-maven-plugin - 3.5.2 - org.antlr antlr4-maven-plugin @@ -1990,7 +2072,7 @@ org.apache.maven.plugins maven-antrun-plugin - 1.8 + ${maven-antrun.version} org.apache.maven.plugins @@ -2113,7 +2195,7 @@ org.apache.maven.plugins maven-antrun-plugin - [1.8,) + [${maven-antrun.version},) run @@ -2177,6 +2259,7 @@ org.eclipse.jetty:jetty-http org.eclipse.jetty:jetty-continuation org.eclipse.jetty:jetty-servlet + org.eclipse.jetty:jetty-servlets org.eclipse.jetty:jetty-plus org.eclipse.jetty:jetty-security org.eclipse.jetty:jetty-util @@ -2381,6 +2464,13 @@ + + docker-integration-tests + + external/docker-integration-tests + + + + + + scala-compile-first + + + -javabootclasspath + ${env.JAVA_7_HOME}/jre/lib/rt.jar + + + + + scala-test-compile-first + + + -javabootclasspath + ${env.JAVA_7_HOME}/jre/lib/rt.jar + + + + + + + + + + scala-2.11 @@ -2508,6 +2642,8 @@ 2.11.8 2.11 + 2.12.1 + jline diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala index 3dc1ceacde19a..77397eab81ede 100644 --- a/project/MimaBuild.scala +++ b/project/MimaBuild.scala @@ -88,8 +88,9 @@ object MimaBuild { def mimaSettings(sparkHome: File, projectRef: ProjectRef) = { val organization = "org.apache.spark" - val previousSparkVersion = "1.6.0" - val fullId = "spark-" + projectRef.project + "_2.11" + val previousSparkVersion = "2.0.0" + val project = projectRef.project + val fullId = "spark-" + project + "_2.11" mimaDefaultSettings ++ Seq(previousArtifact := Some(organization % fullId % previousSparkVersion), binaryIssueFilters ++= 
ignoredABIProblems(sparkHome, version.value)) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 33e0db606c2ff..ee6e31a0ea50f 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -33,1456 +33,786 @@ import com.typesafe.tools.mima.core.ProblemFilters._ * For a new Spark version, please update MimaBuild.scala to reflect the previous version. */ object MimaExcludes { - def excludes(version: String) = version match { - case v if v.startsWith("2.0") => - Seq( - excludePackage("org.apache.spark.rpc"), - excludePackage("org.spark-project.jetty"), - excludePackage("org.apache.spark.unused"), - excludePackage("org.apache.spark.unsafe"), - excludePackage("org.apache.spark.util.collection.unsafe"), - excludePackage("org.apache.spark.sql.catalyst"), - excludePackage("org.apache.spark.sql.execution"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.feature.PCAModel.this"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.status.api.v1.StageData.this"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.status.api.v1.ApplicationAttemptInfo.this"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.status.api.v1.ApplicationAttemptInfo.$default$5"), - // SPARK-14042 Add custom coalescer support - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.rdd.RDD.coalesce"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.rdd.PartitionCoalescer$LocationIterator"), - ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.rdd.PartitionCoalescer"), - // SPARK-12600 Remove SQL deprecated methods - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLContext$QueryExecution"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLContext$SparkPlanner"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.applySchema"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.parquetFile"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.jdbc"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.jsonFile"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.jsonRDD"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.load"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.dialectClassName"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.getSQLDialect"), - // SPARK-13664 Replace HadoopFsRelation with FileFormat - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ml.source.libsvm.LibSVMRelation"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.HadoopFsRelationProvider"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.HadoopFsRelation$FileStatusCache") - ) ++ Seq( - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.SparkContext.emptyRDD"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.broadcast.HttpBroadcastFactory"), - // SPARK-14358 SparkListener from trait to abstract class - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.SparkContext.addSparkListener"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.JavaSparkListener"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.SparkFirehoseListener"), - 
ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.scheduler.SparkListener"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ui.jobs.JobProgressListener"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ui.exec.ExecutorsListener"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ui.env.EnvironmentListener"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ui.storage.StorageListener"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.storage.StorageStatusListener") - ) ++ - Seq( - // SPARK-3369 Fix Iterable/Iterator in Java API - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.api.java.function.FlatMapFunction.call"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.function.FlatMapFunction.call"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.api.java.function.DoubleFlatMapFunction.call"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.function.DoubleFlatMapFunction.call"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.api.java.function.FlatMapFunction2.call"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.function.FlatMapFunction2.call"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.api.java.function.PairFlatMapFunction.call"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.function.PairFlatMapFunction.call"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.api.java.function.CoGroupFunction.call"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.function.CoGroupFunction.call"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.api.java.function.MapPartitionsFunction.call"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.function.MapPartitionsFunction.call"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.api.java.function.FlatMapGroupsFunction.call"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.function.FlatMapGroupsFunction.call") - ) ++ - Seq( - // [SPARK-6429] Implement hashCode and equals together - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.Partition.org$apache$spark$Partition$$super=uals") - ) ++ - Seq( - // SPARK-4819 replace Guava Optional - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.api.java.JavaSparkContext.getCheckpointDir"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.api.java.JavaSparkContext.getSparkHome"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.getCheckpointFile"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.partitioner"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.getCheckpointFile"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.partitioner") - ) ++ - Seq( - // SPARK-12481 Remove Hadoop 1.x - ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.mapred.SparkHadoopMapRedUtil"), - // SPARK-12615 Remove deprecated APIs in core - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.$default$6"), - 
ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.numericRDDToDoubleRDDFunctions"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.intToIntWritable"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.intWritableConverter"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.writableWritableConverter"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.rddToPairRDDFunctions"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.rddToAsyncRDDActions"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.boolToBoolWritable"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.longToLongWritable"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.doubleWritableConverter"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.rddToOrderedRDDFunctions"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.floatWritableConverter"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.booleanWritableConverter"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.stringToText"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.doubleRDDToDoubleRDDFunctions"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.doubleToDoubleWritable"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.bytesWritableConverter"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.rddToSequenceFileRDDFunctions"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.bytesToBytesWritable"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.longWritableConverter"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.stringWritableConverter"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.floatToFloatWritable"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.rddToPairRDDFunctions$default$4"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.addOnCompleteCallback"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.runningLocally"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.attemptId"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.defaultMinSplits"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.SparkContext.runJob"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.runJob"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.tachyonFolderName"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.initLocalProperties"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.clearJars"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.clearFiles"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.SparkContext.this"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.flatMapWith$default$2"), - 
ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.toArray"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.mapWith$default$2"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.mapPartitionsWithSplit"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.flatMapWith"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.filterWith"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.foreachWith"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.mapWith"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.mapPartitionsWithSplit$default$2"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.SequenceFileRDDFunctions.this"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.splits"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.toArray"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaSparkContext.defaultMinSplits"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaSparkContext.clearJars"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaSparkContext.clearFiles"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.externalBlockStoreFolderName"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.storage.ExternalBlockStore$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.storage.ExternalBlockManager"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.storage.ExternalBlockStore") - ) ++ Seq( - // SPARK-12149 Added new fields to ExecutorSummary - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.status.api.v1.ExecutorSummary.this") - ) ++ - // SPARK-12665 Remove deprecated and unused classes - Seq( - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.graphx.GraphKryoRegistrator"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.Vector"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.Vector$Multiplier"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.Vector$") - ) ++ Seq( - // SPARK-12591 Register OpenHashMapBasedStateMap for Kryo - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.serializer.KryoInputDataInputBridge"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.serializer.KryoOutputDataOutputBridge") - ) ++ Seq( - // SPARK-12510 Refactor ActorReceiver to support Java - ProblemFilters.exclude[AbstractClassProblem]("org.apache.spark.streaming.receiver.ActorReceiver") - ) ++ Seq( - // SPARK-12895 Implement TaskMetrics using accumulators - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.internalMetricsToAccumulators"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.collectInternalAccumulators"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.collectAccumulators") - ) ++ Seq( - // SPARK-12896 Send only accumulator updates to driver, not TaskMetrics - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.Accumulable.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.Accumulator.this"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.Accumulator.initialValue") - ) ++ Seq( - // SPARK-12692 Scala style: 
Fix the style violation (Space before "," or ":") - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.SparkSink.org$apache$spark$streaming$flume$sink$Logging$$log_"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.SparkSink.org$apache$spark$streaming$flume$sink$Logging$$log__="), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.SparkAvroCallbackHandler.org$apache$spark$streaming$flume$sink$Logging$$log_"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.SparkAvroCallbackHandler.org$apache$spark$streaming$flume$sink$Logging$$log__="), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.Logging.org$apache$spark$streaming$flume$sink$Logging$$log__="), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.Logging.org$apache$spark$streaming$flume$sink$Logging$$log_"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.Logging.org$apache$spark$streaming$flume$sink$Logging$$_log"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.Logging.org$apache$spark$streaming$flume$sink$Logging$$_log_="), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.TransactionProcessor.org$apache$spark$streaming$flume$sink$Logging$$log_"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.TransactionProcessor.org$apache$spark$streaming$flume$sink$Logging$$log__=") - ) ++ Seq( - // SPARK-12689 Migrate DDL parsing to the newly absorbed parser - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.execution.datasources.DDLParser"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.execution.datasources.DDLException"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.ddlParser") - ) ++ Seq( - // SPARK-7799 Add "streaming-akka" project - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.zeromq.ZeroMQUtils.createStream"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.streaming.zeromq.ZeroMQUtils.createStream"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.streaming.zeromq.ZeroMQUtils.createStream$default$6"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.zeromq.ZeroMQUtils.createStream$default$5"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.actorStream$default$4"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.actorStream$default$3"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.actorStream"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.api.java.JavaStreamingContext.actorStream"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.streaming.zeromq.ZeroMQReceiver"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.streaming.receiver.ActorReceiver$Supervisor") - ) ++ Seq( - // SPARK-12348 Remove deprecated Streaming APIs. 
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.streaming.dstream.DStream.foreach"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.toPairDStreamFunctions"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.toPairDStreamFunctions$default$4"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.awaitTermination"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.networkStream"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.streaming.api.java.JavaStreamingContextFactory"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.api.java.JavaStreamingContext.awaitTermination"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.api.java.JavaStreamingContext.sc"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.api.java.JavaDStreamLike.reduceByWindow"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.api.java.JavaDStreamLike.foreachRDD"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.api.java.JavaDStreamLike.foreach"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.streaming.api.java.JavaStreamingContext.getOrCreate") - ) ++ Seq( - // SPARK-12847 Remove StreamingListenerBus and post all Streaming events to the same thread as Spark events - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.AsynchronousListenerBus$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.AsynchronousListenerBus") - ) ++ Seq( - // SPARK-11622 Make LibSVMRelation extends HadoopFsRelation and Add LibSVMOutputWriter - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.source.libsvm.DefaultSource"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.source.libsvm.DefaultSource.createRelation") - ) ++ Seq( - // SPARK-6363 Make Scala 2.11 the default Scala version - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.cleanup"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.metadataCleaner"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.scheduler.cluster.YarnSchedulerBackend$YarnDriverEndpoint"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.scheduler.cluster.YarnSchedulerBackend$YarnSchedulerEndpoint") - ) ++ Seq( - // SPARK-7889 - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.deploy.history.HistoryServer.org$apache$spark$deploy$history$HistoryServer$@tachSparkUI"), - // SPARK-13296 - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.UDFRegistration.register"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UserDefinedPythonFunction$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UserDefinedPythonFunction"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UserDefinedFunction"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UserDefinedFunction$") - ) ++ Seq( - // SPARK-12995 Remove deprecated APIs in graphx - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.lib.SVDPlusPlus.runSVDPlusPlus"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.Graph.mapReduceTriplets"), - 
ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.Graph.mapReduceTriplets$default$3"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.impl.GraphImpl.mapReduceTriplets") - ) ++ Seq( - // SPARK-13426 Remove the support of SIMR - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkMasterRegex.SIMR_REGEX") - ) ++ Seq( - // SPARK-13413 Remove SparkContext.metricsSystem/schedulerBackend_ setter - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.metricsSystem"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.schedulerBackend_=") - ) ++ Seq( - // SPARK-13220 Deprecate yarn-client and yarn-cluster mode - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.SparkContext.org$apache$spark$SparkContext$$createTaskScheduler") - ) ++ Seq( - // SPARK-13465 TaskContext. - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.addTaskFailureListener") - ) ++ Seq ( - // SPARK-7729 Executor which has been killed should also be displayed on Executor Tab - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.status.api.v1.ExecutorSummary.this") - ) ++ Seq( - // SPARK-13526 Move SQLContext per-session states to new class - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.sql.UDFRegistration.this") - ) ++ Seq( - // [SPARK-13486][SQL] Move SQLConf into an internal package - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLConf"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLConf$SQLConfEntry"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLConf$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLConf$SQLConfEntry$") - ) ++ Seq( - //SPARK-11011 UserDefinedType serialization should be strongly typed - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.mllib.linalg.VectorUDT.serialize"), - // SPARK-12073: backpressure rate controller consumes events preferentially from lagging partitions - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.kafka.KafkaTestUtils.createTopic"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.kafka.DirectKafkaInputDStream.maxMessagesPerPartition") - ) ++ Seq( - // [SPARK-13244][SQL] Migrates DataFrame to Dataset - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.tables"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.sql"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.baseRelationToDataFrame"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.table"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrame.apply"), - - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DataFrame"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DataFrame$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.LegacyFunctions"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DataFrameHolder"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DataFrameHolder$"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLImplicits.localSeqToDataFrameHolder"), - 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLImplicits.stringRddToDataFrameHolder"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLImplicits.rddToDataFrameHolder"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLImplicits.longRddToDataFrameHolder"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLImplicits.intRddToDataFrameHolder"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.GroupedDataset"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.Dataset.subtract"), - - // [SPARK-14451][SQL] Move encoder definition into Aggregator interface - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.Aggregator.toColumn"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.expressions.Aggregator.bufferEncoder"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.expressions.Aggregator.outputEncoder"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.evaluation.MultilabelMetrics.this"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.predictions"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.predictions") - ) ++ Seq( - // [SPARK-13686][MLLIB][STREAMING] Add a constructor parameter `reqParam` to (Streaming)LinearRegressionWithSGD - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.regression.LinearRegressionWithSGD.this") - ) ++ Seq( - // SPARK-13920: MIMA checks should apply to @Experimental and @DeveloperAPI APIs - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.Aggregator.combineCombinersByKey"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.Aggregator.combineValuesByKey"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ComplexFutureAction.run"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ComplexFutureAction.runJob"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ComplexFutureAction.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkEnv.actorSystem"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkEnv.cacheManager"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkEnv.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.getConfigurationFromJobContext"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.getTaskAttemptIDFromTaskAttemptContext"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.newConfiguration"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.bytesReadCallback"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.bytesReadCallback_="), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.canEqual"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.copy"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.productArity"), - 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.productElement"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.productIterator"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.productPrefix"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.setBytesReadCallback"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.updateBytesRead"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.OutputMetrics.canEqual"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.OutputMetrics.copy"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.OutputMetrics.productArity"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.OutputMetrics.productElement"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.OutputMetrics.productIterator"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.OutputMetrics.productPrefix"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleReadMetrics.decFetchWaitTime"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleReadMetrics.decLocalBlocksFetched"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleReadMetrics.decRecordsRead"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleReadMetrics.decRemoteBlocksFetched"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleReadMetrics.decRemoteBytesRead"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleWriteMetrics.decShuffleBytesWritten"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleWriteMetrics.decShuffleRecordsWritten"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleWriteMetrics.decShuffleWriteTime"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleWriteMetrics.incShuffleBytesWritten"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleWriteMetrics.incShuffleRecordsWritten"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleWriteMetrics.incShuffleWriteTime"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleWriteMetrics.setShuffleRecordsWritten"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.feature.PCAModel.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.regression.StreamingLinearRegressionWithSGD.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.rdd.RDD.mapPartitionsWithContext"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.AccumulableInfo.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerExecutorMetricsUpdate.taskMetrics"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.TaskInfo.attempt"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.ExperimentalMethods.this"), - 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.callUDF"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.callUdf"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.cumeDist"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.denseRank"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.inputFileName"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.isNaN"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.percentRank"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.rowNumber"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.sparkPartitionId"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.BlockStatus.apply"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.BlockStatus.copy"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.BlockStatus.externalBlockStoreSize"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.BlockStatus.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.offHeapUsed"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.offHeapUsedByRdd"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatusListener.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.streaming.scheduler.BatchInfo.streamIdToNumRecords"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.storageStatusList"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.storage.StorageListener.storageStatusList"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ExceptionFailure.apply"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ExceptionFailure.copy"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ExceptionFailure.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.executor.InputMetrics.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.executor.OutputMetrics.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Estimator.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Pipeline.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.PipelineModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.PredictionModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.PredictionModel.transformImpl"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Predictor.extractLabeledPoints"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Predictor.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Predictor.train"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Transformer.transform"), - 
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.ClassificationModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.GBTClassifier.train"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.MultilayerPerceptronClassifier.train"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.NaiveBayes.train"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.OneVsRest.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.OneVsRestModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.RandomForestClassifier.train"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.clustering.KMeans.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.clustering.KMeansModel.computeCost"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.clustering.KMeansModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.clustering.LDAModel.logLikelihood"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.clustering.LDAModel.logPerplexity"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.clustering.LDAModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.evaluation.BinaryClassificationEvaluator.evaluate"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.evaluation.Evaluator.evaluate"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator.evaluate"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.evaluation.RegressionEvaluator.evaluate"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.Binarizer.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.Bucketizer.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.ChiSqSelector.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.ChiSqSelectorModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.CountVectorizer.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.CountVectorizerModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.HashingTF.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.IDF.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.IDFModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.IndexToString.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.Interaction.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.MinMaxScaler.fit"), - 
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.MinMaxScalerModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.OneHotEncoder.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.PCA.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.PCAModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.QuantileDiscretizer.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.RFormula.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.RFormulaModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.SQLTransformer.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StandardScaler.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StandardScalerModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StopWordsRemover.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StringIndexer.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StringIndexerModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.VectorAssembler.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.VectorIndexer.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.VectorIndexerModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.VectorSlicer.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.Word2Vec.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.Word2VecModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.recommendation.ALS.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.recommendation.ALSModel.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.recommendation.ALSModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.AFTSurvivalRegression.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.AFTSurvivalRegressionModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.GBTRegressor.train"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.IsotonicRegression.extractWeightedLabeledPoints"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.IsotonicRegression.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.IsotonicRegressionModel.extractWeightedLabeledPoints"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.IsotonicRegressionModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.LinearRegression.train"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.LinearRegressionSummary.this"), - 
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.LinearRegressionTrainingSummary.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.RandomForestRegressor.train"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.tuning.CrossValidator.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.tuning.CrossValidatorModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.tuning.TrainValidationSplit.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.tuning.TrainValidationSplitModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.evaluation.BinaryClassificationMetrics.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.evaluation.MulticlassMetrics.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.evaluation.RegressionMetrics.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.DataFrameNaFunctions.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.DataFrameStatFunctions.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.DataFrameWriter.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.functions.broadcast"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.functions.callUDF"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.sources.CreatableRelationProvider.createRelation"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.sources.InsertableRelation.insert"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.fMeasureByThreshold"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.pr"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.precisionByThreshold"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.predictions"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.recallByThreshold"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.roc"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.clustering.LDAModel.describeTopics"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.feature.Word2VecModel.findSynonyms"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.feature.Word2VecModel.getVectors"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.recommendation.ALSModel.itemFactors"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.recommendation.ALSModel.userFactors"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.LinearRegressionSummary.predictions"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.LinearRegressionSummary.residuals"), - 
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.scheduler.AccumulableInfo.name"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.scheduler.AccumulableInfo.value"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameNaFunctions.drop"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameNaFunctions.fill"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameNaFunctions.replace"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameReader.jdbc"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameReader.json"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameReader.load"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameReader.orc"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameReader.parquet"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameReader.table"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameReader.text"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameStatFunctions.crosstab"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameStatFunctions.freqItems"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameStatFunctions.sampleBy"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.createExternalTable"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.emptyDataFrame"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.range"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.functions.udf"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.scheduler.JobLogger"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.streaming.receiver.ActorHelper"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.streaming.receiver.ActorSupervisorStrategy"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.streaming.receiver.ActorSupervisorStrategy$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.streaming.receiver.Statistics"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.streaming.receiver.Statistics$"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.executor.InputMetrics"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.executor.InputMetrics$"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.executor.OutputMetrics"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.executor.OutputMetrics$"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.functions$"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.Estimator.fit"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.Predictor.train"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.Transformer.transform"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.evaluation.Evaluator.evaluate"), - 
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.scheduler.SparkListener.onOtherEvent"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.CreatableRelationProvider.createRelation"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.InsertableRelation.insert") - ) ++ Seq( - // [SPARK-13926] Automatically use Kryo serializer when shuffling RDDs with simple types - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ShuffleDependency.this"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ShuffleDependency.serializer"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.serializer.Serializer$") - ) ++ Seq( - // SPARK-13927: add row/column iterator to local matrices - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.linalg.Matrix.rowIter"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.linalg.Matrix.colIter") - ) ++ Seq( - // SPARK-13948: MiMa Check should catch if the visibility change to `private` - // TODO(josh): Some of these may be legitimate incompatibilities; we should follow up before the 2.0.0 release - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.Dataset.toDS"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.sources.OutputWriterFactory.newInstance"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.util.RpcUtils.askTimeout"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.util.RpcUtils.lookupTimeout"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.UnaryTransformer.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.DecisionTreeClassifier.train"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegression.train"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.DecisionTreeRegressor.train"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.Dataset.groupBy"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.Dataset.groupBy"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.Dataset.select"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.Dataset.toDF"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.Logging.initializeLogIfNecessary"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerEvent.logEvent"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.OutputWriterFactory.newInstance") - ) ++ Seq( - // [SPARK-14014] Replace existing analysis.Catalog with SessionCatalog - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLContext.this") - ) ++ Seq( - // [SPARK-13928] Move org.apache.spark.Logging into org.apache.spark.internal.Logging - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.Logging"), - (problem: Problem) => problem match { - case MissingTypesProblem(_, missing) - if missing.map(_.fullName).sameElements(Seq("org.apache.spark.Logging")) => false - case _ => true - } - ) ++ Seq( - // [SPARK-13990] Automatically pick serializer when caching RDDs - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.network.netty.NettyBlockTransferService.uploadBlock") 
- ) ++ Seq( - // [SPARK-14089][CORE][MLLIB] Remove methods that has been deprecated since 1.1, 1.2, 1.3, 1.4, and 1.5 - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkEnv.getThreadLocal"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.rdd.RDDFunctions.treeReduce"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.rdd.RDDFunctions.treeAggregate"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.tree.configuration.Strategy.defaultStategy"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.util.MLUtils.loadLibSVMFile"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.util.MLUtils.loadLibSVMFile"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.util.MLUtils.loadLibSVMFile"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.util.MLUtils.saveLabeledData"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.util.MLUtils.loadLabeledData"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.optimization.LBFGS.setMaxNumIterations"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.evaluation.BinaryClassificationEvaluator.setScoreCol") - ) ++ Seq( - // [SPARK-14205][SQL] remove trait Queryable - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.Dataset") - ) ++ Seq( - // [SPARK-11262][ML] Unit test for gradient, loss layers, memory management - // for multilayer perceptron. - // This class is marked as `private`. - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ml.ann.SoftmaxFunction") - ) ++ Seq( - // [SPARK-13674][SQL] Add wholestage codegen support to Sample - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.util.random.PoissonSampler.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.util.random.PoissonSampler.this") - ) ++ Seq( - // [SPARK-13430][ML] moved featureCol from LinearRegressionModelSummary to LinearRegressionSummary - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.regression.LinearRegressionSummary.this") - ) ++ Seq( - // [SPARK-14437][Core] Use the address that NettyBlockTransferService listens to create BlockManagerId - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.network.netty.NettyBlockTransferService.this") - ) ++ Seq( - // [SPARK-13048][ML][MLLIB] keepLastCheckpoint option for LDA EM optimizer - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.DistributedLDAModel.this") - ) ++ Seq( - // [SPARK-14475] Propagate user-defined context from driver to executors - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.getLocalProperty"), - // [SPARK-14617] Remove deprecated APIs in TaskMetrics - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.executor.InputMetrics$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.executor.OutputMetrics$"), - // [SPARK-14628] Simplify task metrics by always tracking read/write metrics - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.readMethod"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.OutputMetrics.writeMethod") - ) ++ Seq( - // SPARK-14628: Always track input/output/shuffle metrics - 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.ShuffleReadMetrics.totalBlocksFetched"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.status.api.v1.ShuffleReadMetrics.this"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.TaskMetrics.inputMetrics"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.TaskMetrics.outputMetrics"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.TaskMetrics.shuffleWriteMetrics"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.TaskMetrics.shuffleReadMetrics"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.status.api.v1.TaskMetrics.this"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.TaskMetricDistributions.inputMetrics"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.TaskMetricDistributions.outputMetrics"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.TaskMetricDistributions.shuffleWriteMetrics"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.TaskMetricDistributions.shuffleReadMetrics"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.status.api.v1.TaskMetricDistributions.this") - ) ++ Seq( - // SPARK-13643: Move functionality from SQLContext to SparkSession - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLContext.getSchema") - ) ++ Seq( - // [SPARK-14407] Hides HadoopFsRelation related data source API into execution package - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.OutputWriter"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.OutputWriterFactory") - ) ++ Seq( - // SPARK-14734: Add conversions between mllib and ml Vector, Matrix types - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.mllib.linalg.Vector.asML"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.mllib.linalg.Matrix.asML") - ) ++ Seq( - // SPARK-14704: Create accumulators in TaskMetrics - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.OutputMetrics.this") - ) ++ Seq( - // SPARK-14861: Replace internal usages of SQLContext with SparkSession - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.ml.clustering.LocalLDAModel.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.ml.clustering.DistributedLDAModel.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.ml.clustering.LDAModel.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.ml.clustering.LDAModel.sqlContext"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.sql.Dataset.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.sql.DataFrameReader.this") - ) ++ Seq( - // [SPARK-4452][Core]Shuffle data structures can starve others on the same thread for memory - ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.util.collection.Spillable") - ) ++ Seq( - // [SPARK-14952][Core][ML] Remove methods deprecated in 1.6 - 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.input.PortableDataStream.close"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.weights"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.LinearRegressionModel.weights") - ) ++ Seq( - // SPARK-14654: New accumulator API - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ExceptionFailure$"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ExceptionFailure.apply"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ExceptionFailure.metrics"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ExceptionFailure.copy"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ExceptionFailure.this"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.executor.ShuffleReadMetrics.remoteBlocksFetched"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.executor.ShuffleReadMetrics.totalBlocksFetched"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.executor.ShuffleReadMetrics.localBlocksFetched"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.ShuffleReadMetrics.remoteBlocksFetched"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.ShuffleReadMetrics.localBlocksFetched") - ) - case v if v.startsWith("1.6") => - Seq( - MimaBuild.excludeSparkPackage("deploy"), - MimaBuild.excludeSparkPackage("network"), - MimaBuild.excludeSparkPackage("unsafe"), - // These are needed if checking against the sbt build, since they are part of - // the maven-generated artifacts in 1.3. - excludePackage("org.spark-project.jetty"), - MimaBuild.excludeSparkPackage("unused"), - // SQL execution is considered private. - excludePackage("org.apache.spark.sql.execution"), - // SQL columnar is considered private. - excludePackage("org.apache.spark.sql.columnar"), - // The shuffle package is considered private. - excludePackage("org.apache.spark.shuffle"), - // The collections utilities are considered private. - excludePackage("org.apache.spark.util.collection") - ) ++ - MimaBuild.excludeSparkClass("streaming.flume.FlumeTestUtils") ++ - MimaBuild.excludeSparkClass("streaming.flume.PollingFlumeTestUtils") ++ - Seq( - // MiMa does not deal properly with sealed traits - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.ml.classification.LogisticRegressionSummary.featuresCol") - ) ++ Seq( - // SPARK-11530 - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.feature.PCAModel.this") - ) ++ Seq( - // SPARK-10381 Fix types / units in private AskPermissionToCommitOutput RPC message. - // This class is marked as `private` but MiMa still seems to be confused by the change. 
- ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.scheduler.AskPermissionToCommitOutput.task"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.scheduler.AskPermissionToCommitOutput.copy$default$2"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.scheduler.AskPermissionToCommitOutput.copy"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.scheduler.AskPermissionToCommitOutput.taskAttempt"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.scheduler.AskPermissionToCommitOutput.copy$default$3"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.scheduler.AskPermissionToCommitOutput.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.scheduler.AskPermissionToCommitOutput.apply") - ) ++ Seq( - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.shuffle.FileShuffleBlockResolver$ShuffleFileGroup") - ) ++ Seq( - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.ml.regression.LeastSquaresAggregator.add"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.ml.regression.LeastSquaresCostFun.this"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.SQLContext.clearLastInstantiatedContext"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.SQLContext.setLastInstantiatedContext"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.SQLContext$SQLSession"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.SQLContext.detachSession"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.SQLContext.tlSession"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.SQLContext.defaultSession"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.SQLContext.currentSession"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.SQLContext.openSession"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.SQLContext.setSession"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.SQLContext.createSession") - ) ++ Seq( - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.SparkContext.preferredNodeLocationData_="), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.rdd.MapPartitionsWithPreparationRDD"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.rdd.MapPartitionsWithPreparationRDD$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SparkSQLParser") - ) ++ Seq( - // SPARK-11485 - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.DataFrameHolder.df"), - // SPARK-11541 mark various JDBC dialects as private - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.NoopDialect.productElement"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.NoopDialect.productArity"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.NoopDialect.canEqual"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.NoopDialect.productIterator"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.NoopDialect.productPrefix"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.NoopDialect.toString"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.NoopDialect.hashCode"), - 
ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.jdbc.PostgresDialect$"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.productElement"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.productArity"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.canEqual"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.productIterator"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.productPrefix"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.toString"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.hashCode"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.jdbc.NoopDialect$") - ) ++ Seq ( - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.status.api.v1.ApplicationInfo.this"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.status.api.v1.StageData.this") - ) ++ Seq( - // SPARK-11766 add toJson to Vector - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Vector.toJson") - ) ++ Seq( - // SPARK-9065 Support message handler in Kafka Python API - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper.createDirectStream"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper.createRDD") - ) ++ Seq( - // SPARK-4557 Changed foreachRDD to use VoidFunction - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.streaming.api.java.JavaDStreamLike.foreachRDD") - ) ++ Seq( - // SPARK-11996 Make the executor thread dump work again - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.executor.ExecutorEndpoint"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.executor.ExecutorEndpoint$"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.storage.BlockManagerMessages$GetRpcHostPortForExecutor"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.storage.BlockManagerMessages$GetRpcHostPortForExecutor$") - ) ++ Seq( - // SPARK-3580 Add getNumPartitions method to JavaRDD - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDDLike.getNumPartitions") - ) ++ Seq( - // SPARK-12149 Added new fields to ExecutorSummary - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.status.api.v1.ExecutorSummary.this") - ) ++ - // SPARK-11314: YARN backend moved to yarn sub-module and MiMA complains even though it's a - // private class. - MimaBuild.excludeSparkClass("scheduler.cluster.YarnSchedulerBackend$YarnSchedulerEndpoint") - case v if v.startsWith("1.5") => - Seq( - MimaBuild.excludeSparkPackage("network"), - MimaBuild.excludeSparkPackage("deploy"), - // These are needed if checking against the sbt build, since they are part of - // the maven-generated artifacts in 1.3. 
- excludePackage("org.spark-project.jetty"), - MimaBuild.excludeSparkPackage("unused"), - // JavaRDDLike is not meant to be extended by user programs - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDDLike.partitioner"), - // Modification of private static method - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.streaming.kafka.KafkaUtils.org$apache$spark$streaming$kafka$KafkaUtils$$leadersForRanges"), - // Mima false positive (was a private[spark] class) - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.util.collection.PairIterator"), - // Removing a testing method from a private class - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.streaming.kafka.KafkaTestUtils.waitUntilLeaderOffset"), - // While private MiMa is still not happy about the changes, - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.ml.regression.LeastSquaresAggregator.this"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.ml.regression.LeastSquaresCostFun.this"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.ml.classification.LogisticCostFun.this"), - // SQL execution is considered private. - excludePackage("org.apache.spark.sql.execution"), - // The old JSON RDD is removed in favor of streaming Jackson - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.json.JsonRDD$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.json.JsonRDD"), - // local function inside a method - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.SQLContext.org$apache$spark$sql$SQLContext$$needsConversion$1"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$24") - ) ++ Seq( - // SPARK-8479 Add numNonzeros and numActives to Matrix. - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Matrix.numNonzeros"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Matrix.numActives") - ) ++ Seq( - // SPARK-8914 Remove RDDApi - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.RDDApi") - ) ++ Seq( - // SPARK-7292 Provide operator to truncate lineage cheaply - ProblemFilters.exclude[AbstractClassProblem]( - "org.apache.spark.rdd.RDDCheckpointData"), - ProblemFilters.exclude[AbstractClassProblem]( - "org.apache.spark.rdd.CheckpointRDD") - ) ++ Seq( - // SPARK-8701 Add input metadata in the batch page. 
- ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.streaming.scheduler.InputInfo$"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.streaming.scheduler.InputInfo") - ) ++ Seq( - // SPARK-6797 Support YARN modes for SparkR - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.r.PairwiseRRDD.this"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.r.RRDD.createRWorker"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.r.RRDD.this"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.r.StringRRDD.this"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.r.BaseRRDD.this") - ) ++ Seq( - // SPARK-7422 add argmax for sparse vectors - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Vector.argmax") - ) ++ Seq( - // SPARK-8906 Move all internal data source classes into execution.datasources - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.ResolvedDataSource"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreInsertCastAndRename$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsingAsSelect$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.InsertIntoDataSource$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils$PartitionValues$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DefaultWriterContainer"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils$PartitionValues"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.RefreshTable$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTempTableUsing$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitionSpec"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DynamicPartitionWriterContainer"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsingAsSelect"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DescribeCommand$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreInsertCastAndRename"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.Partition$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.LogicalRelation$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.LogicalRelation"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.Partition"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.BaseWriterContainer"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreWriteCheck"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsing"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.RefreshTable"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DataSourceStrategy$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTempTableUsing"), - 
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTempTableUsingAsSelect$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTempTableUsingAsSelect"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsing$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.ResolvedDataSource$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreWriteCheck$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.InsertIntoDataSource"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.InsertIntoHadoopFsRelation"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DDLParser"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CaseInsensitiveMap"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.InsertIntoHadoopFsRelation$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DataSourceStrategy"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitionSpec$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DescribeCommand"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DDLException"), - // SPARK-9763 Minimize exposure of internal SQL classes - excludePackage("org.apache.spark.sql.parquet"), - excludePackage("org.apache.spark.sql.json"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRDD$DecimalConversion$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCPartition"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JdbcUtils$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRDD$DecimalConversion"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCPartitioningInfo$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCPartition$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.package"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRDD$JDBCConversion"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRDD$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.package$DriverWrapper"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRDD"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCPartitioningInfo"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JdbcUtils"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.DefaultSource"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRelation$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.package$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.JDBCRelation") - ) ++ Seq( - // SPARK-4751 Dynamic allocation for standalone mode - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.SparkContext.supportDynamicAllocation") - ) ++ Seq( - // SPARK-9580: Remove SQL test singletons - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.test.LocalSQLContext$SQLSession"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.test.LocalSQLContext"), - 
ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.test.TestSQLContext"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.test.TestSQLContext$") - ) ++ Seq( - // SPARK-9704 Made ProbabilisticClassifier, Identifiable, VectorUDT public APIs - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.mllib.linalg.VectorUDT.serialize") - ) ++ Seq( - // SPARK-10381 Fix types / units in private AskPermissionToCommitOutput RPC message. - // This class is marked as `private` but MiMa still seems to be confused by the change. - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.scheduler.AskPermissionToCommitOutput.task"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.scheduler.AskPermissionToCommitOutput.copy$default$2"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.scheduler.AskPermissionToCommitOutput.copy"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.scheduler.AskPermissionToCommitOutput.taskAttempt"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.scheduler.AskPermissionToCommitOutput.copy$default$3"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.scheduler.AskPermissionToCommitOutput.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.scheduler.AskPermissionToCommitOutput.apply") - ) + // Exclude rules for 2.1.x + lazy val v21excludes = v20excludes - case v if v.startsWith("1.4") => - Seq( - MimaBuild.excludeSparkPackage("deploy"), - MimaBuild.excludeSparkPackage("ml"), - // SPARK-7910 Adding a method to get the partitioner to JavaRDD, - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.partitioner"), - // SPARK-5922 Adding a generalized diff(other: RDD[(VertexId, VD)]) to VertexRDD - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.VertexRDD.diff"), - // These are needed if checking against the sbt build, since they are part of - // the maven-generated artifacts in 1.3. 
- excludePackage("org.spark-project.jetty"), - MimaBuild.excludeSparkPackage("unused"), - ProblemFilters.exclude[MissingClassProblem]("com.google.common.base.Optional"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.rdd.JdbcRDD.compute"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.broadcast.HttpBroadcastFactory.newBroadcast"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.broadcast.TorrentBroadcastFactory.newBroadcast"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.scheduler.OutputCommitCoordinator$OutputCommitCoordinatorEndpoint") - ) ++ Seq( - // SPARK-4655 - Making Stage an Abstract class broke binary compatibility even though - // the stage class is defined as private[spark] - ProblemFilters.exclude[AbstractClassProblem]("org.apache.spark.scheduler.Stage") - ) ++ Seq( - // SPARK-6510 Add a Graph#minus method acting as Set#difference - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.VertexRDD.minus") - ) ++ Seq( - // SPARK-6492 Fix deadlock in SparkContext.stop() - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.org$" + - "apache$spark$SparkContext$$SPARK_CONTEXT_CONSTRUCTOR_LOCK") - )++ Seq( - // SPARK-6693 add tostring with max lines and width for matrix - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Matrix.toString") - )++ Seq( - // SPARK-6703 Add getOrCreate method to SparkContext - ProblemFilters.exclude[IncompatibleResultTypeProblem] - ("org.apache.spark.SparkContext.org$apache$spark$SparkContext$$activeContext") - )++ Seq( - // SPARK-7090 Introduce LDAOptimizer to LDA to further improve extensibility - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.mllib.clustering.LDA$EMOptimizer") - ) ++ Seq( - // SPARK-6756 add toSparse, toDense, numActives, numNonzeros, and compressed to Vector - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Vector.compressed"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Vector.toDense"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Vector.numNonzeros"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Vector.toSparse"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Vector.numActives"), - // SPARK-7681 add SparseVector support for gemv - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Matrix.multiply"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.DenseMatrix.multiply"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.SparseMatrix.multiply") - ) ++ Seq( - // Execution should never be included as its always internal. 
- MimaBuild.excludeSparkPackage("sql.execution"), - // This `protected[sql]` method was removed in 1.3.1 - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.SQLContext.checkAnalysis"), - // These `private[sql]` class were removed in 1.4.0: - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.execution.AddExchange"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.execution.AddExchange$"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.PartitionSpec"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.PartitionSpec$"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.Partition"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.Partition$"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.ParquetRelation2$PartitionValues"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.ParquetRelation2$PartitionValues$"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.ParquetRelation2"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.ParquetRelation2$"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.ParquetRelation2$MetadataCache"), - // These test support classes were moved out of src/main and into src/test: - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.ParquetTestData"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.ParquetTestData$"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.TestGroupWriteSupport"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.CachedData"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.CachedData$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.CacheManager"), - // TODO: Remove the following rule once ParquetTest has been moved to src/test. - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.ParquetTest") - ) ++ Seq( - // SPARK-7530 Added StreamingContext.getState() - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.streaming.StreamingContext.state_=") - ) ++ Seq( - // SPARK-7081 changed ShuffleWriter from a trait to an abstract class and removed some - // unnecessary type bounds in order to fix some compiler warnings that occurred when - // implementing this interface in Java. Note that ShuffleWriter is private[spark]. - ProblemFilters.exclude[IncompatibleTemplateDefProblem]( - "org.apache.spark.shuffle.ShuffleWriter") - ) ++ Seq( - // SPARK-6888 make jdbc driver handling user definable - // This patch renames some classes to API friendly names. 
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.DriverQuirks$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.DriverQuirks"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.PostgresQuirks"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.NoQuirks"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.MySQLQuirks") - ) + // Exclude rules for 2.0.x + lazy val v20excludes = { + Seq( + excludePackage("org.apache.spark.rpc"), + excludePackage("org.spark-project.jetty"), + excludePackage("org.spark_project.jetty"), + excludePackage("org.apache.spark.internal"), + excludePackage("org.apache.spark.unused"), + excludePackage("org.apache.spark.unsafe"), + excludePackage("org.apache.spark.memory"), + excludePackage("org.apache.spark.util.collection.unsafe"), + excludePackage("org.apache.spark.sql.catalyst"), + excludePackage("org.apache.spark.sql.execution"), + excludePackage("org.apache.spark.sql.internal"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.feature.PCAModel.this"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.status.api.v1.StageData.this"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.status.api.v1.ApplicationAttemptInfo.this"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.status.api.v1.ApplicationAttemptInfo.<init>$default$5"), + // SPARK-14042 Add custom coalescer support + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.rdd.RDD.coalesce"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.rdd.PartitionCoalescer$LocationIterator"), + ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.rdd.PartitionCoalescer"), + // SPARK-15532 Remove isRootContext flag from SQLContext. 
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLContext.isRootContext"), + // SPARK-12600 Remove SQL deprecated methods + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLContext$QueryExecution"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLContext$SparkPlanner"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.applySchema"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.parquetFile"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.jdbc"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.jsonFile"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.jsonRDD"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.load"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.dialectClassName"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.getSQLDialect"), + // SPARK-13664 Replace HadoopFsRelation with FileFormat + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ml.source.libsvm.LibSVMRelation"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.HadoopFsRelationProvider"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.HadoopFsRelation$FileStatusCache"), + // SPARK-15543 Rename DefaultSources to make them more self-describing + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ml.source.libsvm.DefaultSource") + ) ++ Seq( + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.SparkContext.emptyRDD"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.broadcast.HttpBroadcastFactory"), + // SPARK-14358 SparkListener from trait to abstract class + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.SparkContext.addSparkListener"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.JavaSparkListener"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.SparkFirehoseListener"), + ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.scheduler.SparkListener"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ui.jobs.JobProgressListener"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ui.exec.ExecutorsListener"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ui.env.EnvironmentListener"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ui.storage.StorageListener"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.storage.StorageStatusListener") + ) ++ + Seq( + // SPARK-3369 Fix Iterable/Iterator in Java API + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.api.java.function.FlatMapFunction.call"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.function.FlatMapFunction.call"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.api.java.function.DoubleFlatMapFunction.call"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.function.DoubleFlatMapFunction.call"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.api.java.function.FlatMapFunction2.call"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.function.FlatMapFunction2.call"), + 
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.api.java.function.PairFlatMapFunction.call"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.function.PairFlatMapFunction.call"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.api.java.function.CoGroupFunction.call"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.function.CoGroupFunction.call"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.api.java.function.MapPartitionsFunction.call"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.function.MapPartitionsFunction.call"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.api.java.function.FlatMapGroupsFunction.call"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.function.FlatMapGroupsFunction.call") + ) ++ + Seq( + // [SPARK-6429] Implement hashCode and equals together + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.Partition.org$apache$spark$Partition$$super=uals") + ) ++ + Seq( + // SPARK-4819 replace Guava Optional + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.api.java.JavaSparkContext.getCheckpointDir"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.api.java.JavaSparkContext.getSparkHome"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.getCheckpointFile"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.partitioner"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.getCheckpointFile"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.partitioner") + ) ++ + Seq( + // SPARK-12481 Remove Hadoop 1.x + ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.mapred.SparkHadoopMapRedUtil"), + // SPARK-12615 Remove deprecated APIs in core + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.<init>$default$6"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.numericRDDToDoubleRDDFunctions"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.intToIntWritable"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.intWritableConverter"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.writableWritableConverter"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.rddToPairRDDFunctions"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.rddToAsyncRDDActions"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.boolToBoolWritable"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.longToLongWritable"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.doubleWritableConverter"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.rddToOrderedRDDFunctions"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.floatWritableConverter"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.booleanWritableConverter"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.stringToText"), + 
ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.doubleRDDToDoubleRDDFunctions"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.doubleToDoubleWritable"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.bytesWritableConverter"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.rddToSequenceFileRDDFunctions"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.bytesToBytesWritable"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.longWritableConverter"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.stringWritableConverter"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.floatToFloatWritable"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.rddToPairRDDFunctions$default$4"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.addOnCompleteCallback"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.runningLocally"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.attemptId"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.defaultMinSplits"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.SparkContext.runJob"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.runJob"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.tachyonFolderName"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.initLocalProperties"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.clearJars"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.clearFiles"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.SparkContext.this"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.flatMapWith$default$2"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.toArray"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.mapWith$default$2"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.mapPartitionsWithSplit"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.flatMapWith"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.filterWith"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.foreachWith"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.mapWith"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.mapPartitionsWithSplit$default$2"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.SequenceFileRDDFunctions.this"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.splits"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.toArray"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaSparkContext.defaultMinSplits"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaSparkContext.clearJars"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaSparkContext.clearFiles"), + 
ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.externalBlockStoreFolderName"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.storage.ExternalBlockStore$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.storage.ExternalBlockManager"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.storage.ExternalBlockStore") + ) ++ Seq( + // SPARK-12149 Added new fields to ExecutorSummary + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.status.api.v1.ExecutorSummary.this") + ) ++ + // SPARK-12665 Remove deprecated and unused classes + Seq( + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.graphx.GraphKryoRegistrator"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.Vector"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.Vector$Multiplier"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.Vector$") + ) ++ Seq( + // SPARK-12591 Register OpenHashMapBasedStateMap for Kryo + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.serializer.KryoInputDataInputBridge"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.serializer.KryoOutputDataOutputBridge") + ) ++ Seq( + // SPARK-12510 Refactor ActorReceiver to support Java + ProblemFilters.exclude[AbstractClassProblem]("org.apache.spark.streaming.receiver.ActorReceiver") + ) ++ Seq( + // SPARK-12895 Implement TaskMetrics using accumulators + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.internalMetricsToAccumulators"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.collectInternalAccumulators"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.collectAccumulators") + ) ++ Seq( + // SPARK-12896 Send only accumulator updates to driver, not TaskMetrics + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.Accumulable.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.Accumulator.this"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.Accumulator.initialValue") + ) ++ Seq( + // SPARK-12692 Scala style: Fix the style violation (Space before "," or ":") + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.SparkSink.org$apache$spark$streaming$flume$sink$Logging$$log_"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.SparkSink.org$apache$spark$streaming$flume$sink$Logging$$log__="), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.SparkAvroCallbackHandler.org$apache$spark$streaming$flume$sink$Logging$$log_"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.SparkAvroCallbackHandler.org$apache$spark$streaming$flume$sink$Logging$$log__="), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.Logging.org$apache$spark$streaming$flume$sink$Logging$$log__="), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.Logging.org$apache$spark$streaming$flume$sink$Logging$$log_"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.Logging.org$apache$spark$streaming$flume$sink$Logging$$_log"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.Logging.org$apache$spark$streaming$flume$sink$Logging$$_log_="), + 
ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.TransactionProcessor.org$apache$spark$streaming$flume$sink$Logging$$log_"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.TransactionProcessor.org$apache$spark$streaming$flume$sink$Logging$$log__=") + ) ++ Seq( + // SPARK-12689 Migrate DDL parsing to the newly absorbed parser + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.execution.datasources.DDLParser"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.execution.datasources.DDLException"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.ddlParser") + ) ++ Seq( + // SPARK-7799 Add "streaming-akka" project + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.zeromq.ZeroMQUtils.createStream"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.streaming.zeromq.ZeroMQUtils.createStream"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.streaming.zeromq.ZeroMQUtils.createStream$default$6"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.zeromq.ZeroMQUtils.createStream$default$5"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.actorStream$default$4"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.actorStream$default$3"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.actorStream"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.api.java.JavaStreamingContext.actorStream"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.streaming.zeromq.ZeroMQReceiver"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.streaming.receiver.ActorReceiver$Supervisor") + ) ++ Seq( + // SPARK-12348 Remove deprecated Streaming APIs. 
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.streaming.dstream.DStream.foreach"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.toPairDStreamFunctions"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.toPairDStreamFunctions$default$4"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.awaitTermination"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.networkStream"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.streaming.api.java.JavaStreamingContextFactory"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.api.java.JavaStreamingContext.awaitTermination"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.api.java.JavaStreamingContext.sc"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.api.java.JavaDStreamLike.reduceByWindow"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.api.java.JavaDStreamLike.foreachRDD"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.api.java.JavaDStreamLike.foreach"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.streaming.api.java.JavaStreamingContext.getOrCreate") + ) ++ Seq( + // SPARK-12847 Remove StreamingListenerBus and post all Streaming events to the same thread as Spark events + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.AsynchronousListenerBus$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.AsynchronousListenerBus") + ) ++ Seq( + // SPARK-11622 Make LibSVMRelation extends HadoopFsRelation and Add LibSVMOutputWriter + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.source.libsvm.DefaultSource"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.source.libsvm.DefaultSource.createRelation") + ) ++ Seq( + // SPARK-6363 Make Scala 2.11 the default Scala version + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.cleanup"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.metadataCleaner"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.scheduler.cluster.YarnSchedulerBackend$YarnDriverEndpoint"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.scheduler.cluster.YarnSchedulerBackend$YarnSchedulerEndpoint") + ) ++ Seq( + // SPARK-7889 + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.deploy.history.HistoryServer.org$apache$spark$deploy$history$HistoryServer$@tachSparkUI"), + // SPARK-13296 + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.UDFRegistration.register"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UserDefinedPythonFunction$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UserDefinedPythonFunction"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UserDefinedFunction"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UserDefinedFunction$") + ) ++ Seq( + // SPARK-12995 Remove deprecated APIs in graphx + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.lib.SVDPlusPlus.runSVDPlusPlus"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.Graph.mapReduceTriplets"), + 
ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.Graph.mapReduceTriplets$default$3"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.impl.GraphImpl.mapReduceTriplets") + ) ++ Seq( + // SPARK-13426 Remove the support of SIMR + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkMasterRegex.SIMR_REGEX") + ) ++ Seq( + // SPARK-13413 Remove SparkContext.metricsSystem/schedulerBackend_ setter + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.metricsSystem"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.schedulerBackend_=") + ) ++ Seq( + // SPARK-13220 Deprecate yarn-client and yarn-cluster mode + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.SparkContext.org$apache$spark$SparkContext$$createTaskScheduler") + ) ++ Seq( + // SPARK-13465 TaskContext. + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.addTaskFailureListener") + ) ++ Seq ( + // SPARK-7729 Executor which has been killed should also be displayed on Executor Tab + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.status.api.v1.ExecutorSummary.this") + ) ++ Seq( + // SPARK-13526 Move SQLContext per-session states to new class + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.sql.UDFRegistration.this") + ) ++ Seq( + // [SPARK-13486][SQL] Move SQLConf into an internal package + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLConf"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLConf$SQLConfEntry"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLConf$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLConf$SQLConfEntry$") + ) ++ Seq( + //SPARK-11011 UserDefinedType serialization should be strongly typed + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.mllib.linalg.VectorUDT.serialize"), + // SPARK-12073: backpressure rate controller consumes events preferentially from lagging partitions + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.kafka.KafkaTestUtils.createTopic"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.kafka.DirectKafkaInputDStream.maxMessagesPerPartition") + ) ++ Seq( + // [SPARK-13244][SQL] Migrates DataFrame to Dataset + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.tables"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.sql"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.baseRelationToDataFrame"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.table"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrame.apply"), - case v if v.startsWith("1.3") => - Seq( - MimaBuild.excludeSparkPackage("deploy"), - MimaBuild.excludeSparkPackage("ml"), - // These are needed if checking against the sbt build, since they are part of - // the maven-generated artifacts in the 1.2 build. 
- MimaBuild.excludeSparkPackage("unused"), - ProblemFilters.exclude[MissingClassProblem]("com.google.common.base.Optional") - ) ++ Seq( - // SPARK-2321 - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.SparkStageInfoImpl.this"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.SparkStageInfo.submissionTime") - ) ++ Seq( - // SPARK-4614 - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Matrices.randn"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Matrices.rand") - ) ++ Seq( - // SPARK-5321 - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.SparseMatrix.transposeMultiply"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Matrix.transpose"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.DenseMatrix.transposeMultiply"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.linalg.Matrix." + - "org$apache$spark$mllib$linalg$Matrix$_setter_$isTransposed_="), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Matrix.isTransposed"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Matrix.foreachActive") - ) ++ Seq( - // SPARK-5540 - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.solveLeastSquares"), - // SPARK-5536 - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateFeatures"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateBlock") - ) ++ Seq( - // SPARK-3325 - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.streaming.api.java.JavaDStreamLike.print"), - // SPARK-2757 - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.streaming.flume.sink.SparkAvroCallbackHandler." 
+ - "removeAndGetProcessor") - ) ++ Seq( - // SPARK-5123 (SparkSQL data type change) - alpha component only - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.ml.feature.HashingTF.outputDataType"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.ml.feature.Tokenizer.outputDataType"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.ml.feature.Tokenizer.validateInputType"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.ml.classification.LogisticRegressionModel.validateAndTransformSchema"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.ml.classification.LogisticRegression.validateAndTransformSchema") - ) ++ Seq( - // SPARK-4014 - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.TaskContext.taskAttemptId"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.TaskContext.attemptNumber") - ) ++ Seq( - // SPARK-5166 Spark SQL API stabilization - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Transformer.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Estimator.fit"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.Transformer.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Pipeline.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.PipelineModel.transform"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.Estimator.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Evaluator.evaluate"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.Evaluator.evaluate"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.tuning.CrossValidator.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.tuning.CrossValidatorModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StandardScaler.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StandardScalerModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegression.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.evaluation.BinaryClassificationEvaluator.evaluate") - ) ++ Seq( - // SPARK-5270 - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDDLike.isEmpty") - ) ++ Seq( - // SPARK-5430 - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDDLike.treeReduce"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDDLike.treeAggregate") - ) ++ Seq( - // SPARK-5297 Java FileStream do not work with custom key/values - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.streaming.api.java.JavaStreamingContext.fileStream") - ) ++ Seq( - // SPARK-5315 Spark Streaming Java API returns Scala DStream - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.streaming.api.java.JavaDStreamLike.reduceByWindow") - ) ++ Seq( - // SPARK-5461 Graph should have isCheckpointed, getCheckpointFiles methods - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.graphx.Graph.getCheckpointFiles"), - 
ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.graphx.Graph.isCheckpointed") - ) ++ Seq( - // SPARK-4789 Standardize ML Prediction APIs - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.mllib.linalg.VectorUDT"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.mllib.linalg.VectorUDT.serialize"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.mllib.linalg.VectorUDT.sqlType") - ) ++ Seq( - // SPARK-5814 - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$wrapDoubleArray"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$fillFullMatrix"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$iterations"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$makeOutLinkBlock"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$computeYtY"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$makeLinkRDDs"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$alpha"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$randomFactor"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$makeInLinkBlock"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$dspr"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$lambda"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$implicitPrefs"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$rank") - ) ++ Seq( - // SPARK-4682 - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.RealClock"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.Clock"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.TestClock") - ) ++ Seq( - // SPARK-5922 Adding a generalized diff(other: RDD[(VertexId, VD)]) to VertexRDD - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.VertexRDD.diff") - ) + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DataFrame"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DataFrame$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.LegacyFunctions"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DataFrameHolder"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DataFrameHolder$"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLImplicits.localSeqToDataFrameHolder"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLImplicits.stringRddToDataFrameHolder"), + 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLImplicits.rddToDataFrameHolder"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLImplicits.longRddToDataFrameHolder"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLImplicits.intRddToDataFrameHolder"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.GroupedDataset"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.Dataset.subtract"), - case v if v.startsWith("1.2") => - Seq( - MimaBuild.excludeSparkPackage("deploy"), - MimaBuild.excludeSparkPackage("graphx") - ) ++ - MimaBuild.excludeSparkClass("mllib.linalg.Matrix") ++ - MimaBuild.excludeSparkClass("mllib.linalg.Vector") ++ - Seq( - ProblemFilters.exclude[IncompatibleTemplateDefProblem]( - "org.apache.spark.scheduler.TaskLocation"), - // Added normL1 and normL2 to trait MultivariateStatisticalSummary - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.stat.MultivariateStatisticalSummary.normL1"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.stat.MultivariateStatisticalSummary.normL2"), - // MapStatus should be private[spark] - ProblemFilters.exclude[IncompatibleTemplateDefProblem]( - "org.apache.spark.scheduler.MapStatus"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.network.netty.PathResolver"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.network.netty.client.BlockClientListener"), + // [SPARK-14451][SQL] Move encoder definition into Aggregator interface + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.Aggregator.toColumn"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.expressions.Aggregator.bufferEncoder"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.expressions.Aggregator.outputEncoder"), - // TaskContext was promoted to Abstract class - ProblemFilters.exclude[AbstractClassProblem]( - "org.apache.spark.TaskContext"), - ProblemFilters.exclude[IncompatibleTemplateDefProblem]( - "org.apache.spark.util.collection.SortDataFormat") - ) ++ Seq( - // Adding new methods to the JavaRDDLike trait: - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDDLike.takeAsync"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDDLike.foreachPartitionAsync"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDDLike.countAsync"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDDLike.foreachAsync"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDDLike.collectAsync") - ) ++ Seq( - // SPARK-3822 - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.SparkContext.org$apache$spark$SparkContext$$createTaskScheduler") - ) ++ Seq( - // SPARK-1209 - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.hadoop.mapreduce.SparkHadoopMapReduceUtil"), - ProblemFilters.exclude[MissingClassProblem]( - "org.apache.hadoop.mapred.SparkHadoopMapRedUtil"), - ProblemFilters.exclude[MissingTypesProblem]( - "org.apache.spark.rdd.PairRDDFunctions") - ) ++ Seq( - // SPARK-4062 - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.streaming.kafka.KafkaReceiver#MessageHandler.this") - ) + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.evaluation.MultilabelMetrics.this"), + 
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.predictions"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.predictions") + ) ++ Seq( + // [SPARK-13686][MLLIB][STREAMING] Add a constructor parameter `reqParam` to (Streaming)LinearRegressionWithSGD + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.regression.LinearRegressionWithSGD.this") + ) ++ Seq( + // SPARK-15250 Remove deprecated json API in DataFrameReader + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.DataFrameReader.json") + ) ++ Seq( + // SPARK-13920: MIMA checks should apply to @Experimental and @DeveloperAPI APIs + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.Aggregator.combineCombinersByKey"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.Aggregator.combineValuesByKey"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ComplexFutureAction.run"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ComplexFutureAction.runJob"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ComplexFutureAction.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkEnv.actorSystem"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkEnv.cacheManager"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkEnv.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.getConfigurationFromJobContext"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.getTaskAttemptIDFromTaskAttemptContext"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.newConfiguration"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.bytesReadCallback"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.bytesReadCallback_="), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.canEqual"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.copy"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.productArity"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.productElement"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.productIterator"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.productPrefix"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.setBytesReadCallback"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.updateBytesRead"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.OutputMetrics.canEqual"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.OutputMetrics.copy"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.OutputMetrics.productArity"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.OutputMetrics.productElement"), + 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.OutputMetrics.productIterator"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.OutputMetrics.productPrefix"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleReadMetrics.decFetchWaitTime"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleReadMetrics.decLocalBlocksFetched"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleReadMetrics.decRecordsRead"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleReadMetrics.decRemoteBlocksFetched"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleReadMetrics.decRemoteBytesRead"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleWriteMetrics.decShuffleBytesWritten"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleWriteMetrics.decShuffleRecordsWritten"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleWriteMetrics.decShuffleWriteTime"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleWriteMetrics.incShuffleBytesWritten"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleWriteMetrics.incShuffleRecordsWritten"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleWriteMetrics.incShuffleWriteTime"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleWriteMetrics.setShuffleRecordsWritten"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.feature.PCAModel.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.regression.StreamingLinearRegressionWithSGD.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.rdd.RDD.mapPartitionsWithContext"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.AccumulableInfo.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerExecutorMetricsUpdate.taskMetrics"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.TaskInfo.attempt"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.ExperimentalMethods.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.callUDF"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.callUdf"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.cumeDist"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.denseRank"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.inputFileName"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.isNaN"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.percentRank"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.rowNumber"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.sparkPartitionId"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.BlockStatus.apply"), + 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.BlockStatus.copy"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.BlockStatus.externalBlockStoreSize"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.BlockStatus.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.offHeapUsed"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.offHeapUsedByRdd"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatusListener.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.streaming.scheduler.BatchInfo.streamIdToNumRecords"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.storageStatusList"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.storage.StorageListener.storageStatusList"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ExceptionFailure.apply"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ExceptionFailure.copy"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ExceptionFailure.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.executor.InputMetrics.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.executor.OutputMetrics.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Estimator.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Pipeline.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.PipelineModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.PredictionModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.PredictionModel.transformImpl"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Predictor.extractLabeledPoints"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Predictor.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Predictor.train"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Transformer.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.ClassificationModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.GBTClassifier.train"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.MultilayerPerceptronClassifier.train"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.NaiveBayes.train"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.OneVsRest.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.OneVsRestModel.transform"), + 
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.RandomForestClassifier.train"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.clustering.KMeans.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.clustering.KMeansModel.computeCost"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.clustering.KMeansModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.clustering.LDAModel.logLikelihood"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.clustering.LDAModel.logPerplexity"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.clustering.LDAModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.evaluation.BinaryClassificationEvaluator.evaluate"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.evaluation.Evaluator.evaluate"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator.evaluate"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.evaluation.RegressionEvaluator.evaluate"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.Binarizer.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.Bucketizer.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.ChiSqSelector.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.ChiSqSelectorModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.CountVectorizer.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.CountVectorizerModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.HashingTF.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.IDF.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.IDFModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.IndexToString.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.Interaction.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.MinMaxScaler.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.MinMaxScalerModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.OneHotEncoder.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.PCA.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.PCAModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.QuantileDiscretizer.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.RFormula.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.RFormulaModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.SQLTransformer.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StandardScaler.fit"), + 
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StandardScalerModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StopWordsRemover.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StringIndexer.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StringIndexerModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.VectorAssembler.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.VectorIndexer.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.VectorIndexerModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.VectorSlicer.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.Word2Vec.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.Word2VecModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.recommendation.ALS.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.recommendation.ALSModel.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.recommendation.ALSModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.AFTSurvivalRegression.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.AFTSurvivalRegressionModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.GBTRegressor.train"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.IsotonicRegression.extractWeightedLabeledPoints"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.IsotonicRegression.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.IsotonicRegressionModel.extractWeightedLabeledPoints"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.IsotonicRegressionModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.LinearRegression.train"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.LinearRegressionSummary.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.LinearRegressionTrainingSummary.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.RandomForestRegressor.train"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.tuning.CrossValidator.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.tuning.CrossValidatorModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.tuning.TrainValidationSplit.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.tuning.TrainValidationSplitModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.evaluation.BinaryClassificationMetrics.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.evaluation.MulticlassMetrics.this"), + 
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.evaluation.RegressionMetrics.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.DataFrameNaFunctions.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.DataFrameStatFunctions.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.DataFrameWriter.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.functions.broadcast"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.functions.callUDF"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.sources.CreatableRelationProvider.createRelation"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.sources.InsertableRelation.insert"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.fMeasureByThreshold"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.pr"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.precisionByThreshold"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.predictions"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.recallByThreshold"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.roc"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.clustering.LDAModel.describeTopics"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.feature.Word2VecModel.findSynonyms"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.feature.Word2VecModel.getVectors"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.recommendation.ALSModel.itemFactors"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.recommendation.ALSModel.userFactors"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.LinearRegressionSummary.predictions"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.LinearRegressionSummary.residuals"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.scheduler.AccumulableInfo.name"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.scheduler.AccumulableInfo.value"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameNaFunctions.drop"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameNaFunctions.fill"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameNaFunctions.replace"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameReader.jdbc"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameReader.json"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameReader.load"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameReader.orc"), + 
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameReader.parquet"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameReader.table"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameReader.text"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameStatFunctions.crosstab"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameStatFunctions.freqItems"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameStatFunctions.sampleBy"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.createExternalTable"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.emptyDataFrame"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.range"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.functions.udf"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.scheduler.JobLogger"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.streaming.receiver.ActorHelper"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.streaming.receiver.ActorSupervisorStrategy"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.streaming.receiver.ActorSupervisorStrategy$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.streaming.receiver.Statistics"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.streaming.receiver.Statistics$"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.executor.InputMetrics"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.executor.InputMetrics$"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.executor.OutputMetrics"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.executor.OutputMetrics$"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.functions$"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.Estimator.fit"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.Predictor.train"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.Transformer.transform"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.evaluation.Evaluator.evaluate"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.scheduler.SparkListener.onOtherEvent"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.CreatableRelationProvider.createRelation"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.InsertableRelation.insert") + ) ++ Seq( + // [SPARK-13926] Automatically use Kryo serializer when shuffling RDDs with simple types + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ShuffleDependency.this"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ShuffleDependency.serializer"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.serializer.Serializer$") + ) ++ Seq( + // SPARK-13927: add row/column iterator to local matrices + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.linalg.Matrix.rowIter"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.linalg.Matrix.colIter") + ) ++ Seq( + // SPARK-13948: MiMa 
Check should catch if the visibility change to `private` + // TODO(josh): Some of these may be legitimate incompatibilities; we should follow up before the 2.0.0 release + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.Dataset.toDS"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.sources.OutputWriterFactory.newInstance"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.util.RpcUtils.askTimeout"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.util.RpcUtils.lookupTimeout"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.UnaryTransformer.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.DecisionTreeClassifier.train"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegression.train"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.DecisionTreeRegressor.train"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.Dataset.groupBy"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.Dataset.groupBy"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.Dataset.select"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.Dataset.toDF"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.Logging.initializeLogIfNecessary"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerEvent.logEvent"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.OutputWriterFactory.newInstance") + ) ++ Seq( + // [SPARK-14014] Replace existing analysis.Catalog with SessionCatalog + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLContext.this") + ) ++ Seq( + // [SPARK-13928] Move org.apache.spark.Logging into org.apache.spark.internal.Logging + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.Logging"), + (problem: Problem) => problem match { + case MissingTypesProblem(_, missing) + if missing.map(_.fullName).sameElements(Seq("org.apache.spark.Logging")) => false + case _ => true + } + ) ++ Seq( + // [SPARK-13990] Automatically pick serializer when caching RDDs + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.network.netty.NettyBlockTransferService.uploadBlock") + ) ++ Seq( + // [SPARK-14089][CORE][MLLIB] Remove methods that has been deprecated since 1.1, 1.2, 1.3, 1.4, and 1.5 + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkEnv.getThreadLocal"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.rdd.RDDFunctions.treeReduce"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.rdd.RDDFunctions.treeAggregate"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.tree.configuration.Strategy.defaultStategy"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.util.MLUtils.loadLibSVMFile"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.util.MLUtils.loadLibSVMFile"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.util.MLUtils.loadLibSVMFile"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.util.MLUtils.saveLabeledData"), + 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.util.MLUtils.loadLabeledData"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.optimization.LBFGS.setMaxNumIterations"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.evaluation.BinaryClassificationEvaluator.setScoreCol") + ) ++ Seq( + // [SPARK-14205][SQL] remove trait Queryable + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.Dataset") + ) ++ Seq( + // [SPARK-11262][ML] Unit test for gradient, loss layers, memory management + // for multilayer perceptron. + // This class is marked as `private`. + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ml.ann.SoftmaxFunction") + ) ++ Seq( + // [SPARK-13674][SQL] Add wholestage codegen support to Sample + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.util.random.PoissonSampler.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.util.random.PoissonSampler.this") + ) ++ Seq( + // [SPARK-13430][ML] moved featureCol from LinearRegressionModelSummary to LinearRegressionSummary + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.regression.LinearRegressionSummary.this") + ) ++ Seq( + // [SPARK-14437][Core] Use the address that NettyBlockTransferService listens to create BlockManagerId + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.network.netty.NettyBlockTransferService.this") + ) ++ Seq( + // [SPARK-13048][ML][MLLIB] keepLastCheckpoint option for LDA EM optimizer + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.DistributedLDAModel.this") + ) ++ Seq( + // [SPARK-14475] Propagate user-defined context from driver to executors + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.getLocalProperty"), + // [SPARK-14617] Remove deprecated APIs in TaskMetrics + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.executor.InputMetrics$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.executor.OutputMetrics$"), + // [SPARK-14628] Simplify task metrics by always tracking read/write metrics + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.readMethod"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.OutputMetrics.writeMethod") + ) ++ Seq( + // SPARK-14628: Always track input/output/shuffle metrics + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.ShuffleReadMetrics.totalBlocksFetched"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.status.api.v1.ShuffleReadMetrics.this"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.TaskMetrics.inputMetrics"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.TaskMetrics.outputMetrics"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.TaskMetrics.shuffleWriteMetrics"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.TaskMetrics.shuffleReadMetrics"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.status.api.v1.TaskMetrics.this"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.TaskMetricDistributions.inputMetrics"), + 
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.TaskMetricDistributions.outputMetrics"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.TaskMetricDistributions.shuffleWriteMetrics"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.TaskMetricDistributions.shuffleReadMetrics"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.status.api.v1.TaskMetricDistributions.this") + ) ++ Seq( + // SPARK-13643: Move functionality from SQLContext to SparkSession + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLContext.getSchema") + ) ++ Seq( + // [SPARK-14407] Hides HadoopFsRelation related data source API into execution package + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.OutputWriter"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.OutputWriterFactory") + ) ++ Seq( + // SPARK-14734: Add conversions between mllib and ml Vector, Matrix types + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.mllib.linalg.Vector.asML"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.mllib.linalg.Matrix.asML") + ) ++ Seq( + // SPARK-14704: Create accumulators in TaskMetrics + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.OutputMetrics.this") + ) ++ Seq( + // SPARK-14861: Replace internal usages of SQLContext with SparkSession + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.ml.clustering.LocalLDAModel.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.ml.clustering.DistributedLDAModel.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.ml.clustering.LDAModel.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]( + "org.apache.spark.ml.clustering.LDAModel.sqlContext"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.sql.Dataset.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.sql.DataFrameReader.this") + ) ++ Seq( + // SPARK-14542 configurable buffer size for pipe RDD + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.rdd.RDD.pipe"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.pipe") + ) ++ Seq( + // [SPARK-4452][Core]Shuffle data structures can starve others on the same thread for memory + ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.util.collection.Spillable") + ) ++ Seq( + // [SPARK-14952][Core][ML] Remove methods deprecated in 1.6 + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.input.PortableDataStream.close"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.weights"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.LinearRegressionModel.weights") + ) ++ Seq( + // [SPARK-10653] [Core] Remove unnecessary things from SparkEnv + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkEnv.sparkFilesDir"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkEnv.blockTransferService") + ) ++ Seq( + // SPARK-14654: New accumulator API + 
ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ExceptionFailure$"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ExceptionFailure.apply"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ExceptionFailure.metrics"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ExceptionFailure.copy"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ExceptionFailure.this"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.executor.ShuffleReadMetrics.remoteBlocksFetched"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.executor.ShuffleReadMetrics.totalBlocksFetched"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.executor.ShuffleReadMetrics.localBlocksFetched"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.ShuffleReadMetrics.remoteBlocksFetched"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.ShuffleReadMetrics.localBlocksFetched") + ) ++ Seq( + // [SPARK-14615][ML] Use the new ML Vector and Matrix in the ML pipeline based algorithms + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.clustering.LDAModel.getOldDocConcentration"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.clustering.LDAModel.estimatedDocConcentration"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.clustering.LDAModel.topicsMatrix"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.clustering.KMeansModel.clusterCenters"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LabelConverter.decodeLabel"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LabelConverter.encodeLabeledPoint"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel.weights"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel.predict"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.NaiveBayesModel.predictRaw"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.NaiveBayesModel.raw2probabilityInPlace"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.NaiveBayesModel.theta"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.NaiveBayesModel.pi"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.NaiveBayesModel.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.probability2prediction"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.predictRaw"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.raw2prediction"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.raw2probabilityInPlace"), + 
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.predict"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.coefficients"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.ClassificationModel.raw2prediction"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.ClassificationModel.predictRaw"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.ClassificationModel.predictRaw"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.feature.ElementwiseProduct.getScalingVec"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.ElementwiseProduct.setScalingVec"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.feature.PCAModel.pc"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.feature.MinMaxScalerModel.originalMax"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.feature.MinMaxScalerModel.originalMin"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.MinMaxScalerModel.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.Word2VecModel.findSynonyms"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.feature.IDFModel.idf"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.feature.StandardScalerModel.mean"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StandardScalerModel.this"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.feature.StandardScalerModel.std"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.AFTSurvivalRegressionModel.predict"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.AFTSurvivalRegressionModel.coefficients"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.AFTSurvivalRegressionModel.predictQuantiles"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.AFTSurvivalRegressionModel.this"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.IsotonicRegressionModel.predictions"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.IsotonicRegressionModel.boundaries"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.LinearRegressionModel.predict"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.LinearRegressionModel.coefficients"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.LinearRegressionModel.this") + ) ++ Seq( + // [SPARK-15290] Move annotations, like @Since / @DeveloperApi, into spark-tags + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.annotation.package$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.annotation.package"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.annotation.Private"), + 
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.annotation.AlphaComponent"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.annotation.Experimental"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.annotation.DeveloperApi") + ) ++ Seq( + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.mllib.linalg.Vector.asBreeze"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.mllib.linalg.Matrix.asBreeze") + ) ++ Seq( + // [SPARK-15914] Binary compatibility is broken since consolidation of Dataset and DataFrame + // in Spark 2.0. However, source level compatibility is still maintained. + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.load"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.jsonRDD"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.jsonFile"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.jdbc"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.parquetFile"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.applySchema") + ) ++ Seq( + // SPARK-17096: Improve exception string reported through the StreamingQueryListener + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryListener#QueryTerminated.stackTrace"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryListener#QueryTerminated.this") + ) ++ Seq( + // SPARK-16240: ML persistence backward compatibility for LDA + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.clustering.LDA$") + ) ++ Seq( + // [SPARK-17731][SQL][Streaming] Metrics for structured streaming + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.StreamingQueryInfo"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryListener#QueryStarted.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryListener#QueryStarted.queryInfo"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryListener#QueryProgress.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryListener#QueryProgress.queryInfo"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryListener#QueryTerminated.queryInfo"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.streaming.SourceStatus.offsetDesc"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.SourceStatus.this"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQuery.status"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.StreamingQueryListener$QueryStarted"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.StreamingQueryListener$QueryProgress"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.StreamingQueryListener$QueryTerminated"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryListener.onQueryStarted"), + 
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryListener.onQueryStarted"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryListener.onQueryProgress"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryListener.onQueryProgress"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryListener.onQueryTerminated"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryListener.onQueryTerminated") + ) + } - case v if v.startsWith("1.1") => - Seq( - MimaBuild.excludeSparkPackage("deploy"), - MimaBuild.excludeSparkPackage("graphx") - ) ++ - Seq( - // Adding new method to JavaRDLike trait - we should probably mark this as a developer API. - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.partitions"), - // Should probably mark this as Experimental - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDDLike.foreachAsync"), - // We made a mistake earlier (ed06500d3) in the Java API to use default parameter values - // for countApproxDistinct* functions, which does not work in Java. We later removed - // them, and use the following to tell Mima to not care about them. - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.api.java.JavaPairRDD.countApproxDistinctByKey"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.api.java.JavaPairRDD.countApproxDistinctByKey"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaPairRDD.countApproxDistinct$default$1"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaPairRDD.countApproxDistinctByKey$default$1"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDD.countApproxDistinct$default$1"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDDLike.countApproxDistinct$default$1"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaDoubleRDD.countApproxDistinct$default$1"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.storage.DiskStore.getValues"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.storage.MemoryStore.Entry") - ) ++ - Seq( - // Serializer interface change. See SPARK-3045. 
- ProblemFilters.exclude[IncompatibleTemplateDefProblem]( - "org.apache.spark.serializer.DeserializationStream"), - ProblemFilters.exclude[IncompatibleTemplateDefProblem]( - "org.apache.spark.serializer.Serializer"), - ProblemFilters.exclude[IncompatibleTemplateDefProblem]( - "org.apache.spark.serializer.SerializationStream"), - ProblemFilters.exclude[IncompatibleTemplateDefProblem]( - "org.apache.spark.serializer.SerializerInstance") - )++ - Seq( - // Renamed putValues -> putArray + putIterator - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.storage.MemoryStore.putValues"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.storage.DiskStore.putValues"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.storage.TachyonStore.putValues") - ) ++ - Seq( - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.streaming.flume.FlumeReceiver.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.streaming.kafka.KafkaUtils.createStream"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.streaming.kafka.KafkaReceiver.this") - ) ++ - Seq( // Ignore some private methods in ALS. - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateFeatures"), - ProblemFilters.exclude[MissingMethodProblem]( // The only public constructor is the one without arguments. - "org.apache.spark.mllib.recommendation.ALS.this"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$$default$7"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateFeatures") - ) ++ - MimaBuild.excludeSparkClass("mllib.linalg.distributed.ColumnStatisticsAggregator") ++ - MimaBuild.excludeSparkClass("rdd.ZippedRDD") ++ - MimaBuild.excludeSparkClass("rdd.ZippedPartition") ++ - MimaBuild.excludeSparkClass("util.SerializableHyperLogLog") ++ - MimaBuild.excludeSparkClass("storage.Values") ++ - MimaBuild.excludeSparkClass("storage.Entry") ++ - MimaBuild.excludeSparkClass("storage.MemoryStore$Entry") ++ - // Class was missing "@DeveloperApi" annotation in 1.0. 
- MimaBuild.excludeSparkClass("scheduler.SparkListenerApplicationStart") ++ - Seq( - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.mllib.tree.impurity.Gini.calculate"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.mllib.tree.impurity.Entropy.calculate"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.mllib.tree.impurity.Variance.calculate") - ) ++ - Seq( // Package-private classes removed in SPARK-2341 - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.BinaryLabelParser"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.BinaryLabelParser$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.LabelParser"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.LabelParser$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.MulticlassLabelParser"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.MulticlassLabelParser$") - ) ++ - Seq( // package-private classes removed in MLlib - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.regression.GeneralizedLinearAlgorithm.org$apache$spark$mllib$regression$GeneralizedLinearAlgorithm$$prependOne") - ) ++ - Seq( // new Vector methods in MLlib (binary compatible assuming users do not implement Vector) - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.linalg.Vector.copy") - ) ++ - Seq( // synthetic methods generated in LabeledPoint - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.mllib.regression.LabeledPoint$"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.regression.LabeledPoint.apply"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.regression.LabeledPoint.toString") - ) ++ - Seq ( // Scala 2.11 compatibility fix - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.$default$2") - ) - case v if v.startsWith("1.0") => - Seq( - MimaBuild.excludeSparkPackage("api.java"), - MimaBuild.excludeSparkPackage("mllib"), - MimaBuild.excludeSparkPackage("streaming") - ) ++ - MimaBuild.excludeSparkClass("rdd.ClassTags") ++ - MimaBuild.excludeSparkClass("util.XORShiftRandom") ++ - MimaBuild.excludeSparkClass("graphx.EdgeRDD") ++ - MimaBuild.excludeSparkClass("graphx.VertexRDD") ++ - MimaBuild.excludeSparkClass("graphx.impl.GraphImpl") ++ - MimaBuild.excludeSparkClass("graphx.impl.RoutingTable") ++ - MimaBuild.excludeSparkClass("graphx.util.collection.PrimitiveKeyOpenHashMap") ++ - MimaBuild.excludeSparkClass("graphx.util.collection.GraphXPrimitiveKeyOpenHashMap") ++ - MimaBuild.excludeSparkClass("mllib.recommendation.MFDataGenerator") ++ - MimaBuild.excludeSparkClass("mllib.optimization.SquaredGradient") ++ - MimaBuild.excludeSparkClass("mllib.regression.RidgeRegressionWithSGD") ++ - MimaBuild.excludeSparkClass("mllib.regression.LassoWithSGD") ++ - MimaBuild.excludeSparkClass("mllib.regression.LinearRegressionWithSGD") + def excludes(version: String) = version match { + case v if v.startsWith("2.1") => v21excludes + case v if v.startsWith("2.0") => v20excludes case _ => Seq() } } diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index f7781e0ef12b4..98f1e23a9fd77 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -39,21 +39,21 @@ object BuildCommons { private val buildLocation = file(".").getAbsoluteFile.getParentFile - val sqlProjects@Seq(catalyst, 
sql, hive, hiveThriftServer, hiveCompatibility) = Seq( - "catalyst", "sql", "hive", "hive-thriftserver", "hivecontext-compatibility" + val sqlProjects@Seq(catalyst, sql, hive, hiveThriftServer, sqlKafka010) = Seq( + "catalyst", "sql", "hive", "hive-thriftserver", "sql-kafka-0-10" ).map(ProjectRef(buildLocation, _)) val streamingProjects@Seq( - streaming, streamingFlumeSink, streamingFlume, streamingKafka + streaming, streamingFlumeSink, streamingFlume, streamingKafka, streamingKafka010 ) = Seq( - "streaming", "streaming-flume-sink", "streaming-flume", "streaming-kafka" + "streaming", "streaming-flume-sink", "streaming-flume", "streaming-kafka-0-8", "streaming-kafka-0-10" ).map(ProjectRef(buildLocation, _)) val allProjects@Seq( - core, graphx, mllib, mllibLocal, repl, networkCommon, networkShuffle, launcher, unsafe, testTags, sketch, _* + core, graphx, mllib, mllibLocal, repl, networkCommon, networkShuffle, launcher, unsafe, tags, sketch, _* ) = Seq( "core", "graphx", "mllib", "mllib-local", "repl", "network-common", "network-shuffle", "launcher", "unsafe", - "test-tags", "sketch" + "tags", "sketch" ).map(ProjectRef(buildLocation, _)) ++ sqlProjects ++ streamingProjects val optionallyEnabledProjects@Seq(yarn, java8Tests, sparkGangliaLgpl, @@ -61,8 +61,8 @@ object BuildCommons { Seq("yarn", "java8-tests", "ganglia-lgpl", "streaming-kinesis-asl", "docker-integration-tests").map(ProjectRef(buildLocation, _)) - val assemblyProjects@Seq(networkYarn, streamingFlumeAssembly, streamingKafkaAssembly, streamingKinesisAslAssembly) = - Seq("network-yarn", "streaming-flume-assembly", "streaming-kafka-assembly", "streaming-kinesis-asl-assembly") + val assemblyProjects@Seq(networkYarn, streamingFlumeAssembly, streamingKafkaAssembly, streamingKafka010Assembly, streamingKinesisAslAssembly) = + Seq("network-yarn", "streaming-flume-assembly", "streaming-kafka-0-8-assembly", "streaming-kafka-0-10-assembly", "streaming-kinesis-asl-assembly") .map(ProjectRef(buildLocation, _)) val copyJarsProjects@Seq(assembly, examples) = Seq("assembly", "examples") @@ -277,12 +277,24 @@ object SparkBuild extends PomBuild { // additional discussion and explanation. javacOptions in (Compile, compile) ++= Seq( "-target", javacJVMVersion.value - ), + ) ++ sys.env.get("JAVA_7_HOME").toSeq.flatMap { jdk7 => + if (javacJVMVersion.value == "1.7") { + Seq("-bootclasspath", s"$jdk7/jre/lib/rt.jar") + } else { + Nil + } + }, scalacOptions in Compile ++= Seq( s"-target:jvm-${scalacJVMVersion.value}", "-sourcepath", (baseDirectory in ThisBuild).value.getAbsolutePath // Required for relative source links in scaladoc - ), + ) ++ sys.env.get("JAVA_7_HOME").toSeq.flatMap { jdk7 => + if (javacJVMVersion.value == "1.7") { + Seq("-javabootclasspath", s"$jdk7/jre/lib/rt.jar") + } else { + Nil + } + }, // Implements -Xfatal-warnings, ignoring deprecation warnings. // Code snippet taken from https://issues.scala-lang.org/browse/SI-8410. 
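The hunk above adds the JDK 7 bootclasspath flags to `javacOptions`/`scalacOptions` only when the `JAVA_7_HOME` environment variable is set and the javac target is `1.7`; builds without that variable are unchanged. A minimal standalone sketch of that guard follows — the object name and the `/opt/jdk1.7.0` path are illustrative assumptions, not part of the patch:

```scala
object BootClasspathFlagsSketch {
  // Mirrors the guard added above: emit the bootclasspath flags only when
  // JAVA_7_HOME is provided and the compile target is Java 7.
  def bootClasspathFlags(javaSevenHome: Option[String], javacTarget: String): Seq[String] =
    javaSevenHome.toSeq.flatMap { jdk7 =>
      if (javacTarget == "1.7") Seq("-bootclasspath", s"$jdk7/jre/lib/rt.jar") else Nil
    }

  def main(args: Array[String]): Unit = {
    println(bootClasspathFlags(None, "1.7"))                  // List()
    println(bootClasspathFlags(Some("/opt/jdk1.7.0"), "1.7")) // List(-bootclasspath, /opt/jdk1.7.0/jre/lib/rt.jar)
    println(bootClasspathFlags(Some("/opt/jdk1.7.0"), "1.8")) // List()
  }
}
```

Note that the patch appends `jre/lib/rt.jar` beneath the variable, so `JAVA_7_HOME` is expected to point at a JDK 7 installation root rather than a bare JRE.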
@@ -339,8 +351,8 @@ object SparkBuild extends PomBuild { val mimaProjects = allProjects.filterNot { x => Seq( - spark, hive, hiveThriftServer, hiveCompatibility, catalyst, repl, networkCommon, networkShuffle, networkYarn, - unsafe, testTags, sketch, mllibLocal + spark, hive, hiveThriftServer, catalyst, repl, networkCommon, networkShuffle, networkYarn, + unsafe, tags, sqlKafka010 ).contains(x) } @@ -348,6 +360,9 @@ object SparkBuild extends PomBuild { enable(MimaBuild.mimaSettings(sparkHome, x))(x) } + /* Generate and pick the spark build info from extra-resources */ + enable(Core.settings)(core) + /* Unsafe settings */ enable(Unsafe.settings)(unsafe) @@ -382,7 +397,8 @@ object SparkBuild extends PomBuild { enable(Java8TestSettings.settings)(java8Tests) - enable(DockerIntegrationTests.settings)(dockerIntegrationTests) + // SPARK-14738 - Remove docker tests from main Spark build + // enable(DockerIntegrationTests.settings)(dockerIntegrationTests) /** * Adds the ability to run the spark shell directly from SBT without building an assembly @@ -435,7 +451,19 @@ object SparkBuild extends PomBuild { else x.settings(Seq[Setting[_]](): _*) } ++ Seq[Project](OldDeps.project) } +} +object Core { + lazy val settings = Seq( + resourceGenerators in Compile += Def.task { + val buildScript = baseDirectory.value + "/../build/spark-build-info" + val targetDir = baseDirectory.value + "/target/extra-resources/" + val command = Seq("bash", buildScript, targetDir, version.value) + Process(command).!! + val propsFile = baseDirectory.value / "target" / "extra-resources" / "spark-version-info.properties" + Seq(propsFile) + }.taskValue + ) } object Unsafe { @@ -453,7 +481,7 @@ object DockerIntegrationTests { // This serves to override the override specified in DependencyOverrides: lazy val settings = Seq( dependencyOverrides += "com.google.guava" % "guava" % "18.0", - resolvers ++= Seq("DB2" at "https://app.camunda.com/nexus/content/repositories/public/"), + resolvers += "DB2" at "https://app.camunda.com/nexus/content/repositories/public/", libraryDependencies += "com.oracle" % "ojdbc6" % "11.2.0.1.0" from "https://app.camunda.com/nexus/content/repositories/public/com/oracle/ojdbc6/11.2.0.1.0/ojdbc6-11.2.0.1.0.jar" // scalastyle:ignore ) } @@ -467,9 +495,9 @@ object DependencyOverrides { } /** - This excludes library dependencies in sbt, which are specified in maven but are - not needed by sbt build. - */ + * This excludes library dependencies in sbt, which are specified in maven but are + * not needed by sbt build. 
+ */ object ExcludedDependencies { lazy val settings = Seq( libraryDependencies ~= { libs => libs.filterNot(_.name == "groovy-all") } @@ -580,8 +608,8 @@ object Assembly { .getOrElse(SbtPomKeys.effectivePom.value.getProperties.get("hadoop.version").asInstanceOf[String]) }, jarName in assembly <<= (version, moduleName, hadoopVersion) map { (v, mName, hv) => - if (mName.contains("streaming-flume-assembly") || mName.contains("streaming-kafka-assembly") || mName.contains("streaming-kinesis-asl-assembly")) { - // This must match the same name used in maven (see external/kafka-assembly/pom.xml) + if (mName.contains("streaming-flume-assembly") || mName.contains("streaming-kafka-0-8-assembly") || mName.contains("streaming-kafka-0-10-assembly") || mName.contains("streaming-kinesis-asl-assembly")) { + // This must match the same name used in maven (see external/kafka-0-8-assembly/pom.xml) s"${mName}-${v}.jar" } else { s"${mName}-${v}-hadoop${hv}.jar" @@ -656,11 +684,6 @@ object Unidoc { import sbtunidoc.Plugin._ import UnidocKeys._ - // for easier specification of JavaDoc package groups - private def packageList(names: String*): String = { - names.map(s => "org.apache.spark." + s).mkString(":") - } - private def ignoreUndocumentedPackages(packages: Seq[Seq[File]]): Seq[Seq[File]] = { packages .map(_.filterNot(_.getName.contains("$"))) @@ -678,15 +701,29 @@ object Unidoc { .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/hive/test"))) } + private def ignoreClasspaths(classpaths: Seq[Classpath]): Seq[Classpath] = { + classpaths + .map(_.filterNot(_.data.getCanonicalPath.matches(""".*kafka-clients-0\.10.*"""))) + .map(_.filterNot(_.data.getCanonicalPath.matches(""".*kafka_2\..*-0\.10.*"""))) + } + val unidocSourceBase = settingKey[String]("Base URL of source links in Scaladoc.") lazy val settings = scalaJavaUnidocSettings ++ Seq ( publish := {}, unidocProjectFilter in(ScalaUnidoc, unidoc) := - inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, yarn, testTags), + inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, yarn, tags, streamingKafka010, sqlKafka010), unidocProjectFilter in(JavaUnidoc, unidoc) := - inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, yarn, testTags), + inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, yarn, tags, streamingKafka010, sqlKafka010), + + unidocAllClasspaths in (ScalaUnidoc, unidoc) := { + ignoreClasspaths((unidocAllClasspaths in (ScalaUnidoc, unidoc)).value) + }, + + unidocAllClasspaths in (JavaUnidoc, unidoc) := { + ignoreClasspaths((unidocAllClasspaths in (JavaUnidoc, unidoc)).value) + }, // Skip actual catalyst, but include the subproject. // Catalyst is not public API and contains quasiquotes which break scaladoc. 
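The `ignoreClasspaths` helper introduced above drops Kafka 0.10 jars from the classpaths used by the unified Scaladoc/Javadoc build by matching each entry's canonical path against two regexes. A small self-contained check of those patterns, assuming made-up repository paths that are not taken from the patch:

```scala
object KafkaClasspathFilterSketch {
  // The same two patterns used by ignoreClasspaths in the Unidoc settings.
  private val kafkaPatterns = Seq(""".*kafka-clients-0\.10.*""", """.*kafka_2\..*-0\.10.*""")

  // An entry is kept only if neither pattern matches its full path.
  def keep(path: String): Boolean = !kafkaPatterns.exists(pat => path.matches(pat))

  def main(args: Array[String]): Unit = {
    val samples = Seq(
      "/repo/org/apache/kafka/kafka-clients-0.10.0.1/kafka-clients-0.10.0.1.jar", // dropped
      "/repo/org/apache/kafka/kafka_2.11-0.10.0.1/kafka_2.11-0.10.0.1.jar",       // dropped
      "/repo/org/apache/spark/spark-core_2.11/spark-core_2.11-2.1.0.jar"          // kept
    )
    samples.foreach(p => println(s"keep=${keep(p)}  $p"))
  }
}
```

Only the Kafka 0.10 client and broker artifacts match, so every other dependency remains on the unidoc classpaths.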
@@ -697,27 +734,12 @@ object Unidoc { // Skip class names containing $ and some internal packages in Javadocs unidocAllSources in (JavaUnidoc, unidoc) := { ignoreUndocumentedPackages((unidocAllSources in (JavaUnidoc, unidoc)).value) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/hadoop"))) }, - // Javadoc options: create a window title, and group key packages on index page - javacOptions in doc := Seq( + javacOptions in (JavaUnidoc, unidoc) := Seq( "-windowtitle", "Spark " + version.value.replaceAll("-SNAPSHOT", "") + " JavaDoc", "-public", - "-group", "Core Java API", packageList("api.java", "api.java.function"), - "-group", "Spark Streaming", packageList( - "streaming.api.java", "streaming.flume", "streaming.kafka", "streaming.kinesis" - ), - "-group", "MLlib", packageList( - "mllib.classification", "mllib.clustering", "mllib.evaluation.binary", "mllib.linalg", - "mllib.linalg.distributed", "mllib.optimization", "mllib.rdd", "mllib.recommendation", - "mllib.regression", "mllib.stat", "mllib.tree", "mllib.tree.configuration", - "mllib.tree.impurity", "mllib.tree.model", "mllib.util", - "mllib.evaluation", "mllib.feature", "mllib.random", "mllib.stat.correlation", - "mllib.stat.test", "mllib.tree.impl", "mllib.tree.loss", - "ml", "ml.attribute", "ml.classification", "ml.clustering", "ml.evaluation", "ml.feature", - "ml.param", "ml.recommendation", "ml.regression", "ml.tuning" - ), - "-group", "Spark SQL", packageList("sql.api.java", "sql.api.java.types", "sql.hive.api.java"), "-noqualifier", "java.lang" ), @@ -725,7 +747,8 @@ object Unidoc { unidocSourceBase := s"https://github.com/apache/spark/tree/v${version.value}", scalacOptions in (ScalaUnidoc, unidoc) ++= Seq( - "-groups" // Group similar methods together based on the @group annotation. + "-groups", // Group similar methods together based on the @group annotation. + "-skip-packages", "org.apache.hadoop" ) ++ ( // Add links to sources when generating Scaladoc for a non-snapshot release if (!isSnapshot.value) { @@ -805,7 +828,7 @@ object TestSettings { javaOptions in Test += "-Dspark.ui.enabled=false", javaOptions in Test += "-Dspark.ui.showConsoleProgress=false", javaOptions in Test += "-Dspark.unsafe.exceptionOnMemoryLeak=true", - javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=true", + javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=false", javaOptions in Test += "-Dderby.system.durability=test", javaOptions in Test ++= System.getProperties.asScala.filter(_._1.startsWith("spark")) .map { case (k,v) => s"-D$k=$v" }.toSeq, diff --git a/project/plugins.sbt b/project/plugins.sbt index 44ec3a12ae709..76597d27292ea 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -6,7 +6,7 @@ addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.8.2") addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.8.0") -addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.1.9") +addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.1.11") addSbtPlugin("com.alpinenow" % "junit_xml_listener" % "0.5.1") @@ -20,7 +20,13 @@ libraryDependencies += "org.ow2.asm" % "asm" % "5.0.3" libraryDependencies += "org.ow2.asm" % "asm-commons" % "5.0.3" -// TODO I am not sure we want such a dep. 
-resolvers += "simplytyped" at "http://simplytyped.github.io/repo/releases" +addSbtPlugin("com.simplytyped" % "sbt-antlr4" % "0.7.11") -addSbtPlugin("com.simplytyped" % "sbt-antlr4" % "0.7.10") +// Spark uses a custom fork of the sbt-pom-reader plugin which contains a patch to fix issues +// related to test-jar dependencies (https://github.com/sbt/sbt-pom-reader/pull/14). The source for +// this fork is published at https://github.com/JoshRosen/sbt-pom-reader/tree/v1.0.0-spark +// and corresponds to commit b160317fcb0b9d1009635a7c5aa05d0f3be61936 in that repository. +// In the long run, we should try to merge our patch upstream and switch to an upstream version of +// the plugin; this is tracked at SPARK-14401. + +addSbtPlugin("org.spark-project" % "sbt-pom-reader" % "1.0.0-spark") diff --git a/python/docs/Makefile b/python/docs/Makefile index 905e0215c20c2..de86e97d862f0 100644 --- a/python/docs/Makefile +++ b/python/docs/Makefile @@ -7,7 +7,7 @@ SPHINXBUILD ?= sphinx-build PAPER ?= BUILDDIR ?= _build -export PYTHONPATH=$(realpath ..):$(realpath ../lib/py4j-0.9.2-src.zip) +export PYTHONPATH=$(realpath ..):$(realpath ../lib/py4j-0.10.3-src.zip) # User-friendly check for sphinx-build ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) diff --git a/python/docs/conf.py b/python/docs/conf.py index d35bf73c30510..50fb3175a7dc4 100644 --- a/python/docs/conf.py +++ b/python/docs/conf.py @@ -32,6 +32,7 @@ 'sphinx.ext.autodoc', 'sphinx.ext.viewcode', 'epytext', + 'sphinx.ext.mathjax', ] # Add any paths that contain templates here, relative to this directory. diff --git a/python/docs/index.rst b/python/docs/index.rst index 306ffdb0e0f13..421c8de86a3cc 100644 --- a/python/docs/index.rst +++ b/python/docs/index.rst @@ -50,4 +50,3 @@ Indices and tables ================== * :ref:`search` - diff --git a/python/docs/pyspark.ml.rst b/python/docs/pyspark.ml.rst index 86d4186a2c798..26f7415e1a423 100644 --- a/python/docs/pyspark.ml.rst +++ b/python/docs/pyspark.ml.rst @@ -41,6 +41,14 @@ pyspark.ml.clustering module :undoc-members: :inherited-members: +pyspark.ml.linalg module +---------------------------- + +.. automodule:: pyspark.ml.linalg + :members: + :undoc-members: + :inherited-members: + pyspark.ml.recommendation module -------------------------------- diff --git a/python/docs/pyspark.sql.rst b/python/docs/pyspark.sql.rst index 6259379ed05b7..09848b880194d 100644 --- a/python/docs/pyspark.sql.rst +++ b/python/docs/pyspark.sql.rst @@ -8,16 +8,20 @@ Module Context :members: :undoc-members: - pyspark.sql.types module ------------------------ .. automodule:: pyspark.sql.types :members: :undoc-members: - pyspark.sql.functions module ---------------------------- .. automodule:: pyspark.sql.functions :members: :undoc-members: + +pyspark.sql.streaming module +---------------------------- +.. 
automodule:: pyspark.sql.streaming + :members: + :undoc-members: diff --git a/python/lib/py4j-0.10.3-src.zip b/python/lib/py4j-0.10.3-src.zip new file mode 100644 index 0000000000000..bc54f33af1515 Binary files /dev/null and b/python/lib/py4j-0.10.3-src.zip differ diff --git a/python/lib/py4j-0.9.2-src.zip b/python/lib/py4j-0.9.2-src.zip deleted file mode 100644 index 881bb759d7823..0000000000000 Binary files a/python/lib/py4j-0.9.2-src.zip and /dev/null differ diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py index e56e22a9b920e..822ae46e45111 100644 --- a/python/pyspark/cloudpickle.py +++ b/python/pyspark/cloudpickle.py @@ -169,7 +169,12 @@ def save_function(self, obj, name=None): if name is None: name = obj.__name__ - modname = pickle.whichmodule(obj, name) + try: + # whichmodule() could fail, see + # https://bitbucket.org/gutworth/six/issues/63/importing-six-breaks-pickling + modname = pickle.whichmodule(obj, name) + except Exception: + modname = None # print('which gives %s %s %s' % (modname, obj, name)) try: themodule = sys.modules[modname] @@ -326,7 +331,12 @@ def save_global(self, obj, name=None, pack=struct.pack): modname = getattr(obj, "__module__", None) if modname is None: - modname = pickle.whichmodule(obj, name) + try: + # whichmodule() could fail, see + # https://bitbucket.org/gutworth/six/issues/63/importing-six-breaks-pickling + modname = pickle.whichmodule(obj, name) + except Exception: + modname = '__main__' if modname == '__main__': themodule = None diff --git a/python/pyspark/context.py b/python/pyspark/context.py index cb15b4b91f913..2744bb9ec04e5 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -155,10 +155,6 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, self.appName = self._conf.get("spark.app.name") self.sparkHome = self._conf.get("spark.home", None) - # Let YARN know it's a pyspark app, so it distributes needed libraries. - if self.master == "yarn-client": - self._conf.set("spark.yarn.isPython", "true") - for (k, v) in self._conf.getAll(): if k.startswith("spark.executorEnv."): varName = k[len("spark.executorEnv."):] @@ -170,6 +166,8 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, # Create the Java SparkContext through Py4J self._jsc = jsc or self._initialize_context(self._conf._jconf) + # Reset the SparkConf to the one actually used by the SparkContext in JVM. + self._conf = SparkConf(_jconf=self._jsc.sc().conf()) # Create a single Accumulator in Java that we'll send all our updates through; # they will be passed back to us through a TCP server @@ -789,14 +787,6 @@ def addFile(self, path): """ self._jsc.sc().addFile(path) - def clearFiles(self): - """ - Clear the job's list of files added by L{addFile} or L{addPyFile} so - that they do not get downloaded to any new nodes. - """ - # TODO: remove added .py or .zip files from the PYTHONPATH? 
- self._jsc.sc().clearFiles() - def addPyFile(self, path): """ Add a .py or .zip dependency for all tasks to be executed on this @@ -952,6 +942,11 @@ def dump_profiles(self, path): """ self.profiler_collector.dump_profiles(path) + def getConf(self): + conf = SparkConf() + conf.setAll(self._conf.getAll()) + return conf + def _test(): import atexit diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index cd4c55f79f18c..527ca82d31f1b 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -116,6 +116,7 @@ def killChild(): java_import(gateway.jvm, "org.apache.spark.SparkConf") java_import(gateway.jvm, "org.apache.spark.api.java.*") java_import(gateway.jvm, "org.apache.spark.api.python.*") + java_import(gateway.jvm, "org.apache.spark.ml.python.*") java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*") # TODO(davies): move into sql java_import(gateway.jvm, "org.apache.spark.sql.*") diff --git a/python/pyspark/ml/__init__.py b/python/pyspark/ml/__init__.py index 25cfac02f383d..1d42d49a8816b 100644 --- a/python/pyspark/ml/__init__.py +++ b/python/pyspark/ml/__init__.py @@ -15,6 +15,10 @@ # limitations under the License. # +""" +DataFrame-based machine learning APIs to let users quickly assemble and configure practical +machine learning pipelines. +""" from pyspark.ml.base import Estimator, Model, Transformer from pyspark.ml.pipeline import Pipeline, PipelineModel diff --git a/python/pyspark/ml/base.py b/python/pyspark/ml/base.py index a7a58e17a43ed..339e5d6af52a7 100644 --- a/python/pyspark/ml/base.py +++ b/python/pyspark/ml/base.py @@ -19,7 +19,7 @@ from pyspark import since from pyspark.ml.param import Params -from pyspark.mllib.common import inherit_doc +from pyspark.ml.common import inherit_doc @inherit_doc diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 4331f73b73253..3c4af90acac85 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -21,12 +21,12 @@ from pyspark import since, keyword_only from pyspark.ml import Estimator, Model from pyspark.ml.param.shared import * -from pyspark.ml.regression import ( - RandomForestParams, TreeEnsembleParams, DecisionTreeModel, TreeEnsembleModels) +from pyspark.ml.regression import DecisionTreeModel, DecisionTreeRegressionModel, \ + RandomForestParams, TreeEnsembleModels, TreeEnsembleParams from pyspark.ml.util import * from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams from pyspark.ml.wrapper import JavaWrapper -from pyspark.mllib.common import inherit_doc +from pyspark.ml.common import inherit_doc from pyspark.sql import DataFrame from pyspark.sql.functions import udf, when from pyspark.sql.types import ArrayType, DoubleType @@ -53,7 +53,7 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti Currently, this class only supports binary classification. >>> from pyspark.sql import Row - >>> from pyspark.mllib.linalg import Vectors + >>> from pyspark.ml.linalg import Vectors >>> df = sc.parallelize([ ... Row(label=1.0, weight=2.0, features=Vectors.dense(1.0)), ... Row(label=0.0, weight=2.0, features=Vectors.sparse(1, [], []))]).toDF() @@ -96,7 +96,8 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti threshold = Param(Params._dummy(), "threshold", "Threshold in binary classification prediction, in range [0, 1]." 
+ - " If threshold and thresholds are both set, they must match.", + " If threshold and thresholds are both set, they must match." + + "e.g. if threshold is p, then thresholds must be equal to [1-p, p].", typeConverter=TypeConverters.toFloat) @keyword_only @@ -154,7 +155,12 @@ def setThreshold(self, value): @since("1.4.0") def getThreshold(self): """ - Gets the value of threshold or its default value. + Get threshold for binary classification. + + If :py:attr:`thresholds` is set with length 2 (i.e., binary classification), + this returns the equivalent threshold: + :math:`\\frac{1}{1 + \\frac{thresholds(0)}{thresholds(1)}}`. + Otherwise, returns :py:attr:`threshold` if set or its default value if unset. """ self._checkThresholdConsistency() if self.isSet(self.thresholds): @@ -214,7 +220,7 @@ class LogisticRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable): """ @property - @since("1.6.0") + @since("2.0.0") def coefficients(self): """ Model coefficients. @@ -267,6 +273,8 @@ def evaluate(self, dataset): class LogisticRegressionSummary(JavaWrapper): """ + .. note:: Experimental + Abstraction for Logistic Regression Results for a given model. .. versionadded:: 2.0.0 @@ -311,6 +319,8 @@ def featuresCol(self): @inherit_doc class LogisticRegressionTrainingSummary(LogisticRegressionSummary): """ + .. note:: Experimental + Abstraction for multinomial Logistic Regression Training results. Currently, the training summary ignores the training weights except for the objective trace. @@ -351,9 +361,11 @@ class BinaryLogisticRegressionSummary(LogisticRegressionSummary): def roc(self): """ Returns the receiver operating characteristic (ROC) curve, - which is an Dataframe having two fields (FPR, TPR) with + which is a Dataframe having two fields (FPR, TPR) with (0.0, 0.0) prepended and (1.0, 1.0) appended to it. - Reference: http://en.wikipedia.org/wiki/Receiver_operating_characteristic + + .. seealso:: `Wikipedia reference \ + `_ Note: This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. This will change in later Spark @@ -378,7 +390,7 @@ def areaUnderROC(self): @since("2.0.0") def pr(self): """ - Returns the precision-recall curve, which is an Dataframe + Returns the precision-recall curve, which is a Dataframe containing two fields recall, precision with (0.0, 1.0) prepended to it. @@ -464,8 +476,7 @@ def setImpurity(self, value): """ Sets the value of :py:attr:`impurity`. """ - self._set(impurity=value) - return self + return self._set(impurity=value) @since("1.6.0") def getImpurity(self): @@ -490,14 +501,14 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred TreeClassifierParams, HasCheckpointInterval, HasSeed, JavaMLWritable, JavaMLReadable): """ - `http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree` + `Decision tree `_ learning algorithm for classification. It supports both binary and multiclass labels, as well as both continuous and categorical features. - >>> from pyspark.mllib.linalg import Vectors + >>> from pyspark.ml.linalg import Vectors >>> from pyspark.ml.feature import StringIndexer - >>> df = sqlContext.createDataFrame([ + >>> df = spark.createDataFrame([ ... (1.0, Vectors.dense(1.0)), ... 
(0.0, Vectors.sparse(1, [], []))], ["label", "features"]) >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed") @@ -511,7 +522,9 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred 1 >>> model.featureImportances SparseVector(1, {0: 1.0}) - >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) + >>> print(model.toDebugString) + DecisionTreeClassificationModel (uid=...) of depth 1 with 3 nodes... + >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) >>> result = model.transform(test0).head() >>> result.prediction 0.0 @@ -519,7 +532,7 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred DenseVector([1.0, 0.0]) >>> result.rawPrediction DenseVector([1.0, 0.0]) - >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) + >>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 1.0 @@ -617,16 +630,16 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred RandomForestParams, TreeClassifierParams, HasCheckpointInterval, JavaMLWritable, JavaMLReadable): """ - `http://en.wikipedia.org/wiki/Random_forest Random Forest` + `Random Forest `_ learning algorithm for classification. It supports both binary and multiclass labels, as well as both continuous and categorical features. >>> import numpy >>> from numpy import allclose - >>> from pyspark.mllib.linalg import Vectors + >>> from pyspark.ml.linalg import Vectors >>> from pyspark.ml.feature import StringIndexer - >>> df = sqlContext.createDataFrame([ + >>> df = spark.createDataFrame([ ... (1.0, Vectors.dense(1.0)), ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"]) >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed") @@ -638,7 +651,7 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred SparseVector(1, {0: 1.0}) >>> allclose(model.treeWeights, [1.0, 1.0, 1.0]) True - >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) + >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) >>> result = model.transform(test0).head() >>> result.prediction 0.0 @@ -646,9 +659,11 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred 0 >>> numpy.argmax(result.rawPrediction) 0 - >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) + >>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 1.0 + >>> model.trees + [DecisionTreeClassificationModel (uid=...) of depth..., DecisionTreeClassificationModel...] 
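(Editorial aside, not one of the diff hunks.) The expanded `getThreshold()` docstring earlier in this `classification.py` hunk states that a length-2 `thresholds` vector is equivalent to the scalar `threshold` via 1/(1 + thresholds(0)/thresholds(1)). A minimal sketch in plain Python, using a hypothetical helper name rather than any PySpark API, checks that arithmetic:

```python
# Illustrative check of the threshold/thresholds equivalence described in the
# getThreshold() docstring above; `equivalent_threshold` is a hypothetical helper,
# not a PySpark API.
def equivalent_threshold(thresholds):
    """Recover the scalar threshold from a length-2 thresholds vector."""
    t0, t1 = thresholds
    return 1.0 / (1.0 + t0 / t1)

p = 0.7
thresholds = [1.0 - p, p]  # setThreshold(p) implies thresholds = [1 - p, p]
assert abs(equivalent_threshold(thresholds) - p) < 1e-12
```

Algebraically, 1/(1 + (1 - p)/p) simplifies to p, which is why the two parameters must agree when both are set.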
>>> rfc_path = temp_path + "/rfc" >>> rf.save(rfc_path) >>> rf2 = RandomForestClassifier.load(rfc_path) @@ -680,7 +695,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self._java_obj = self._new_java_obj( "org.apache.spark.ml.classification.RandomForestClassifier", self.uid) self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, - maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", numTrees=20, featureSubsetStrategy="auto") kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -729,21 +744,37 @@ def featureImportances(self): """ return self._call_java("featureImportances") + @property + @since("2.0.0") + def trees(self): + """Trees in this ensemble. Warning: These have null parent Estimators.""" + return [DecisionTreeClassificationModel(m) for m in list(self._call_java("trees"))] + @inherit_doc class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, GBTParams, HasCheckpointInterval, HasStepSize, HasSeed, JavaMLWritable, JavaMLReadable): """ - `http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)` + `Gradient-Boosted Trees (GBTs) `_ learning algorithm for classification. It supports binary labels, as well as both continuous and categorical features. Note: Multiclass labels are not currently supported. + The implementation is based upon: J.H. Friedman. "Stochastic Gradient Boosting." 1999. + + Notes on Gradient Boosting vs. TreeBoost: + - This implementation is for Stochastic Gradient Boosting, not for TreeBoost. + - Both algorithms learn tree ensembles by minimizing loss functions. + - TreeBoost (Friedman, 1999) additionally modifies the outputs at tree leaf nodes + based on the loss function, whereas the original gradient boosting method does not. + - We expect to implement TreeBoost in the future: + `SPARK-4240 `_ + >>> from numpy import allclose - >>> from pyspark.mllib.linalg import Vectors + >>> from pyspark.ml.linalg import Vectors >>> from pyspark.ml.feature import StringIndexer - >>> df = sqlContext.createDataFrame([ + >>> df = spark.createDataFrame([ ... (1.0, Vectors.dense(1.0)), ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"]) >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed") @@ -755,12 +786,16 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol SparseVector(1, {0: 1.0}) >>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1]) True - >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) + >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) >>> model.transform(test0).head().prediction 0.0 - >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) + >>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 1.0 + >>> model.totalNumNodes + 15 + >>> print(model.toDebugString) + GBTClassificationModel (uid=...)...with 5 trees... >>> gbtc_path = temp_path + "gbtc" >>> gbt.save(gbtc_path) >>> gbt2 = GBTClassifier.load(gbtc_path) @@ -773,6 +808,8 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol True >>> model.treeWeights == model2.treeWeights True + >>> model.trees + [DecisionTreeRegressionModel (uid=...) of depth..., DecisionTreeRegressionModel...] .. 
versionadded:: 1.4.0 """ @@ -798,7 +835,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred "org.apache.spark.ml.classification.GBTClassifier", self.uid) self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, - lossType="logistic", maxIter=20, stepSize=0.1, seed=None) + lossType="logistic", maxIter=20, stepSize=0.1) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -826,8 +863,7 @@ def setLossType(self, value): """ Sets the value of :py:attr:`lossType`. """ - self._set(lossType=value) - return self + return self._set(lossType=value) @since("1.4.0") def getLossType(self): @@ -859,23 +895,29 @@ def featureImportances(self): """ return self._call_java("featureImportances") + @property + @since("2.0.0") + def trees(self): + """Trees in this ensemble. Warning: These have null parent Estimators.""" + return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))] + @inherit_doc class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol, - HasRawPredictionCol, JavaMLWritable, JavaMLReadable): + HasRawPredictionCol, HasThresholds, JavaMLWritable, JavaMLReadable): """ Naive Bayes Classifiers. - It supports both Multinomial and Bernoulli NB. Multinomial NB - (`http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html`) + It supports both Multinomial and Bernoulli NB. `Multinomial NB + `_ can handle finitely supported discrete data. For example, by converting documents into TF-IDF vectors, it can be used for document classification. By making every vector a - binary (0/1) data, it can also be used as Bernoulli NB - (`http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html`). + binary (0/1) data, it can also be used as `Bernoulli NB + `_. The input feature values must be nonnegative. >>> from pyspark.sql import Row - >>> from pyspark.mllib.linalg import Vectors - >>> df = sqlContext.createDataFrame([ + >>> from pyspark.ml.linalg import Vectors + >>> df = spark.createDataFrame([ ... Row(label=0.0, features=Vectors.dense([0.0, 0.0])), ... Row(label=0.0, features=Vectors.dense([0.0, 1.0])), ... Row(label=1.0, features=Vectors.dense([1.0, 0.0]))]) @@ -908,6 +950,11 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H True >>> model.theta == model2.theta True + >>> nb = nb.setThresholds([0.01, 10.00]) + >>> model3 = nb.fit(df) + >>> result = model3.transform(test0).head() + >>> result.prediction + 0.0 .. 
versionadded:: 1.5.0 """ @@ -921,11 +968,11 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, - modelType="multinomial"): + modelType="multinomial", thresholds=None): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, \ - modelType="multinomial") + modelType="multinomial", thresholds=None) """ super(NaiveBayes, self).__init__() self._java_obj = self._new_java_obj( @@ -938,11 +985,11 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred @since("1.5.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, - modelType="multinomial"): + modelType="multinomial", thresholds=None): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, \ - modelType="multinomial") + modelType="multinomial", thresholds=None) Sets params for Naive Bayes. """ kwargs = self.setParams._input_kwargs @@ -956,8 +1003,7 @@ def setSmoothing(self, value): """ Sets the value of :py:attr:`smoothing`. """ - self._set(smoothing=value) - return self + return self._set(smoothing=value) @since("1.5.0") def getSmoothing(self): @@ -971,8 +1017,7 @@ def setModelType(self, value): """ Sets the value of :py:attr:`modelType`. """ - self._set(modelType=value) - return self + return self._set(modelType=value) @since("1.5.0") def getModelType(self): @@ -990,7 +1035,7 @@ class NaiveBayesModel(JavaModel, JavaMLWritable, JavaMLReadable): """ @property - @since("1.5.0") + @since("2.0.0") def pi(self): """ log of class priors. @@ -998,7 +1043,7 @@ def pi(self): return self._call_java("pi") @property - @since("1.5.0") + @since("2.0.0") def theta(self): """ log of class conditional probabilities. @@ -1008,26 +1053,29 @@ def theta(self): @inherit_doc class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, - HasMaxIter, HasTol, HasSeed, JavaMLWritable, JavaMLReadable): + HasMaxIter, HasTol, HasSeed, HasStepSize, JavaMLWritable, + JavaMLReadable): """ + .. note:: Experimental + Classifier trainer based on the Multilayer Perceptron. Each layer has sigmoid activation function, output layer has softmax. Number of inputs has to be equal to the size of feature vectors. Number of outputs has to be equal to the total number of labels. - >>> from pyspark.mllib.linalg import Vectors - >>> df = sqlContext.createDataFrame([ + >>> from pyspark.ml.linalg import Vectors + >>> df = spark.createDataFrame([ ... (0.0, Vectors.dense([0.0, 0.0])), ... (1.0, Vectors.dense([0.0, 1.0])), ... (1.0, Vectors.dense([1.0, 0.0])), ... (0.0, Vectors.dense([1.0, 1.0]))], ["label", "features"]) - >>> mlp = MultilayerPerceptronClassifier(maxIter=100, layers=[2, 5, 2], blockSize=1, seed=123) + >>> mlp = MultilayerPerceptronClassifier(maxIter=100, layers=[2, 2, 2], blockSize=1, seed=123) >>> model = mlp.fit(df) >>> model.layers - [2, 5, 2] + [2, 2, 2] >>> model.weights.size - 27 - >>> testDF = sqlContext.createDataFrame([ + 12 + >>> testDF = spark.createDataFrame([ ... (Vectors.dense([1.0, 0.0]),), ... 
(Vectors.dense([0.0, 0.0]),)], ["features"]) >>> model.transform(testDF).show() @@ -1050,6 +1098,12 @@ class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, True >>> model.weights == model2.weights True + >>> mlp2 = mlp2.setInitialWeights(list(range(0, 12))) + >>> model3 = mlp2.fit(df) + >>> model3.weights != model2.weights + True + >>> model3.layers == model.layers + True .. versionadded:: 1.6.0 """ @@ -1063,28 +1117,36 @@ class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, "remaining data in a partition then it is adjusted to the size of this " + "data. Recommended size is between 10 and 1000, default is 128.", typeConverter=TypeConverters.toInt) + solver = Param(Params._dummy(), "solver", "The solver algorithm for optimization. Supported " + + "options: l-bfgs, gd.", typeConverter=TypeConverters.toString) + initialWeights = Param(Params._dummy(), "initialWeights", "The initial weights of the model.", + typeConverter=TypeConverters.toVector) @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", - maxIter=100, tol=1e-4, seed=None, layers=None, blockSize=128): + maxIter=100, tol=1e-4, seed=None, layers=None, blockSize=128, stepSize=0.03, + solver="l-bfgs", initialWeights=None): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ - maxIter=100, tol=1e-4, seed=None, layers=None, blockSize=128) + maxIter=100, tol=1e-4, seed=None, layers=None, blockSize=128, stepSize=0.03, \ + solver="l-bfgs", initialWeights=None) """ super(MultilayerPerceptronClassifier, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.classification.MultilayerPerceptronClassifier", self.uid) - self._setDefault(maxIter=100, tol=1E-4, blockSize=128) + self._setDefault(maxIter=100, tol=1E-4, blockSize=128, stepSize=0.03, solver="l-bfgs") kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @keyword_only @since("1.6.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", - maxIter=100, tol=1e-4, seed=None, layers=None, blockSize=128): + maxIter=100, tol=1e-4, seed=None, layers=None, blockSize=128, stepSize=0.03, + solver="l-bfgs", initialWeights=None): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ - maxIter=100, tol=1e-4, seed=None, layers=None, blockSize=128) + maxIter=100, tol=1e-4, seed=None, layers=None, blockSize=128, stepSize=0.03, \ + solver="l-bfgs", initialWeights=None) Sets params for MultilayerPerceptronClassifier. """ kwargs = self.setParams._input_kwargs @@ -1098,8 +1160,7 @@ def setLayers(self, value): """ Sets the value of :py:attr:`layers`. """ - self._set(layers=value) - return self + return self._set(layers=value) @since("1.6.0") def getLayers(self): @@ -1113,8 +1174,7 @@ def setBlockSize(self, value): """ Sets the value of :py:attr:`blockSize`. """ - self._set(blockSize=value) - return self + return self._set(blockSize=value) @since("1.6.0") def getBlockSize(self): @@ -1123,9 +1183,53 @@ def getBlockSize(self): """ return self.getOrDefault(self.blockSize) + @since("2.0.0") + def setStepSize(self, value): + """ + Sets the value of :py:attr:`stepSize`. + """ + return self._set(stepSize=value) + + @since("2.0.0") + def getStepSize(self): + """ + Gets the value of stepSize or its default value. + """ + return self.getOrDefault(self.stepSize) + + @since("2.0.0") + def setSolver(self, value): + """ + Sets the value of :py:attr:`solver`. 
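(Editorial aside, not one of the diff hunks.) The hunk above adds `solver`, `stepSize` and `initialWeights` to `MultilayerPerceptronClassifier`. A minimal usage sketch, assuming a `df` of (label, features) rows like the one in the doctest, might look like this:

```python
from pyspark.ml.classification import MultilayerPerceptronClassifier

# Gradient-descent solver with an explicit step size instead of the default l-bfgs.
mlp = MultilayerPerceptronClassifier(layers=[2, 2, 2], blockSize=1, seed=123,
                                     maxIter=100, solver="gd", stepSize=0.05)
# Warm start from explicit weights; for layers [2, 2, 2] the weight vector has size 12,
# matching model.weights.size in the updated doctest.
mlp = mlp.setInitialWeights(list(range(0, 12)))
model = mlp.fit(df)  # df assumed to be a DataFrame of (label, features) rows
```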
+ """ + return self._set(solver=value) + + @since("2.0.0") + def getSolver(self): + """ + Gets the value of solver or its default value. + """ + return self.getOrDefault(self.solver) + + @since("2.0.0") + def setInitialWeights(self, value): + """ + Sets the value of :py:attr:`initialWeights`. + """ + return self._set(initialWeights=value) + + @since("2.0.0") + def getInitialWeights(self): + """ + Gets the value of initialWeights or its default value. + """ + return self.getOrDefault(self.initialWeights) + class MultilayerPerceptronClassificationModel(JavaModel, JavaMLWritable, JavaMLReadable): """ + .. note:: Experimental + Model fitted by MultilayerPerceptronClassifier. .. versionadded:: 1.6.0 @@ -1140,7 +1244,7 @@ def layers(self): return self._call_java("javaLayers") @property - @since("1.6.0") + @since("2.0.0") def weights(self): """ vector of initial weights for the model that consists of the weights of layers. @@ -1162,8 +1266,7 @@ def setClassifier(self, value): .. note:: Only LogisticRegression and NaiveBayes are supported now. """ - self._set(classifier=value) - return self + return self._set(classifier=value) @since("2.0.0") def getClassifier(self): @@ -1176,6 +1279,8 @@ def getClassifier(self): @inherit_doc class OneVsRest(Estimator, OneVsRestParams, MLReadable, MLWritable): """ + .. note:: Experimental + Reduction of Multiclass Classification to Binary Classification. Performs reduction using one against all strategy. For a multiclass classification with k classes, train k models (one per class). @@ -1183,7 +1288,7 @@ class OneVsRest(Estimator, OneVsRestParams, MLReadable, MLWritable): is picked to label the example. >>> from pyspark.sql import Row - >>> from pyspark.mllib.linalg import Vectors + >>> from pyspark.ml.linalg import Vectors >>> df = sc.parallelize([ ... Row(label=0.0, features=Vectors.dense(1.0, 0.8)), ... Row(label=1.0, features=Vectors.sparse(2, [], [])), @@ -1330,6 +1435,8 @@ def _to_java(self): class OneVsRestModel(Model, OneVsRestParams, MLReadable, MLWritable): """ + .. note:: Experimental + Model fitted by OneVsRest. This stores the models resulting from training k binary classifiers: one for each class. 
Each example is scored against all k models, and the model with the highest score @@ -1457,21 +1564,23 @@ def _to_java(self): if __name__ == "__main__": import doctest import pyspark.ml.classification - from pyspark.context import SparkContext - from pyspark.sql import SQLContext + from pyspark.sql import SparkSession globs = pyspark.ml.classification.__dict__.copy() # The small batch size here ensures that we see multiple batches, # even in these small test examples: - sc = SparkContext("local[2]", "ml.classification tests") - sqlContext = SQLContext(sc) + spark = SparkSession.builder\ + .master("local[2]")\ + .appName("ml.classification tests")\ + .getOrCreate() + sc = spark.sparkContext globs['sc'] = sc - globs['sqlContext'] = sqlContext + globs['spark'] = spark import tempfile temp_path = tempfile.mkdtemp() globs['temp_path'] = temp_path try: (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) - sc.stop() + spark.stop() finally: from shutil import rmtree try: diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index 50ebf4fde1cf5..4dab83362a0a4 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -19,7 +19,7 @@ from pyspark.ml.util import * from pyspark.ml.wrapper import JavaEstimator, JavaModel from pyspark.ml.param.shared import * -from pyspark.mllib.common import inherit_doc +from pyspark.ml.common import inherit_doc __all__ = ['BisectingKMeans', 'BisectingKMeansModel', 'KMeans', 'KMeansModel', @@ -64,8 +64,23 @@ class GaussianMixture(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte .. note:: Experimental GaussianMixture clustering. + This class performs expectation maximization for multivariate Gaussian + Mixture Models (GMMs). A GMM represents a composite distribution of + independent Gaussian distributions with associated "mixing" weights + specifying each's contribution to the composite. - >>> from pyspark.mllib.linalg import Vectors + Given a set of sample points, this class will maximize the log-likelihood + for a mixture of k Gaussians, iterating until the log-likelihood changes by + less than convergenceTol, or until it has reached the max number of iterations. + While this process is generally guaranteed to converge, it is not guaranteed + to find a global optimum. + + Note: For high-dimensional data (with many features), this algorithm may perform poorly. + This is due to high-dimensional data (a) making it difficult to cluster at all + (based on statistical/theoretical arguments) and (b) numerical issues with + Gaussian distributions. + + >>> from pyspark.ml.linalg import Vectors >>> data = [(Vectors.dense([-0.1, -0.05 ]),), ... (Vectors.dense([-0.01, -0.1]),), @@ -73,7 +88,7 @@ class GaussianMixture(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte ... (Vectors.dense([0.75, 0.935]),), ... (Vectors.dense([-0.83, -0.68]),), ... (Vectors.dense([-0.91, -0.76]),)] - >>> df = sqlContext.createDataFrame(data, ["features"]) + >>> df = spark.createDataFrame(data, ["features"]) >>> gm = GaussianMixture(k=3, tol=0.0001, ... 
maxIter=10, seed=10) >>> model = gm.fit(df) @@ -84,9 +99,9 @@ class GaussianMixture(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte +--------------------+--------------------+ | mean| cov| +--------------------+--------------------+ - |[-0.0550000000000...|0.002025000000000...| - |[0.82499999999999...|0.005625000000000...| - |[-0.87,-0.7200000...|0.001600000000000...| + |[0.82500000140229...|0.005625000000006...| + |[-0.4777098016092...|0.167969502720916...| + |[-0.4472625243352...|0.167304119758233...| +--------------------+--------------------+ ... >>> transformed = model.transform(df).select("features", "prediction") @@ -109,17 +124,17 @@ class GaussianMixture(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte +--------------------+--------------------+ | mean| cov| +--------------------+--------------------+ - |[-0.0550000000000...|0.002025000000000...| - |[0.82499999999999...|0.005625000000000...| - |[-0.87,-0.7200000...|0.001600000000000...| + |[0.82500000140229...|0.005625000000006...| + |[-0.4777098016092...|0.167969502720916...| + |[-0.4472625243352...|0.167304119758233...| +--------------------+--------------------+ ... .. versionadded:: 2.0.0 """ - k = Param(Params._dummy(), "k", "number of clusters to create", - typeConverter=TypeConverters.toInt) + k = Param(Params._dummy(), "k", "Number of independent Gaussians in the mixture model. " + + "Must be > 1.", typeConverter=TypeConverters.toInt) @keyword_only def __init__(self, featuresCol="features", predictionCol="prediction", k=2, @@ -156,8 +171,7 @@ def setK(self, value): """ Sets the value of :py:attr:`k`. """ - self._set(k=value) - return self + return self._set(k=value) @since("2.0.0") def getK(self): @@ -195,10 +209,10 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol K-means clustering with a k-means++ like initialization mode (the k-means|| algorithm by Bahmani et al). - >>> from pyspark.mllib.linalg import Vectors + >>> from pyspark.ml.linalg import Vectors >>> data = [(Vectors.dense([0.0, 0.0]),), (Vectors.dense([1.0, 1.0]),), ... (Vectors.dense([9.0, 8.0]),), (Vectors.dense([8.0, 9.0]),)] - >>> df = sqlContext.createDataFrame(data, ["features"]) + >>> df = spark.createDataFrame(data, ["features"]) >>> kmeans = KMeans(k=2, seed=1) >>> model = kmeans.fit(df) >>> centers = model.clusterCenters() @@ -228,15 +242,15 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol .. versionadded:: 1.5.0 """ - k = Param(Params._dummy(), "k", "number of clusters to create", + k = Param(Params._dummy(), "k", "The number of clusters to create. Must be > 1.", typeConverter=TypeConverters.toInt) initMode = Param(Params._dummy(), "initMode", - "the initialization algorithm. This can be either \"random\" to " + + "The initialization algorithm. This can be either \"random\" to " + "choose random points as initial cluster centers, or \"k-means||\" " + "to use a parallel variant of k-means++", typeConverter=TypeConverters.toString) - initSteps = Param(Params._dummy(), "initSteps", "steps for k-means initialization mode", - typeConverter=TypeConverters.toInt) + initSteps = Param(Params._dummy(), "initSteps", "The number of steps for k-means|| " + + "initialization mode. Must be > 0.", typeConverter=TypeConverters.toInt) @keyword_only def __init__(self, featuresCol="features", predictionCol="prediction", k=2, @@ -272,8 +286,7 @@ def setK(self, value): """ Sets the value of :py:attr:`k`. 
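(Editorial aside, not one of the diff hunks.) The expanded `GaussianMixture` docstring above describes a GMM as a weighted sum of Gaussian densities whose log-likelihood EM tries to maximize. A minimal pure-Python sketch of that quantity for one-dimensional data, with names of my own choosing:

```python
import math

def gmm_log_likelihood(xs, weights, means, variances):
    """Log-likelihood of a sample under p(x) = sum_k w_k * N(x; mu_k, var_k)."""
    total = 0.0
    for x in xs:
        p = sum(w * math.exp(-(x - m) ** 2 / (2.0 * v)) / math.sqrt(2.0 * math.pi * v)
                for w, m, v in zip(weights, means, variances))
        total += math.log(p)
    return total

# Two equally weighted, well-separated components; EM would adjust the weights, means
# and variances to make this number as large as possible for the training sample.
print(gmm_log_likelihood([-0.83, 0.8], [0.5, 0.5], [-0.8, 0.8], [0.01, 0.01]))
```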
""" - self._set(k=value) - return self + return self._set(k=value) @since("1.5.0") def getK(self): @@ -287,8 +300,7 @@ def setInitMode(self, value): """ Sets the value of :py:attr:`initMode`. """ - self._set(initMode=value) - return self + return self._set(initMode=value) @since("1.5.0") def getInitMode(self): @@ -302,8 +314,7 @@ def setInitSteps(self, value): """ Sets the value of :py:attr:`initSteps`. """ - self._set(initSteps=value) - return self + return self._set(initSteps=value) @since("1.5.0") def getInitSteps(self): @@ -351,10 +362,10 @@ class BisectingKMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte If bisecting all divisible clusters on the bottom level would result more than `k` leaf clusters, larger clusters get higher priority. - >>> from pyspark.mllib.linalg import Vectors + >>> from pyspark.ml.linalg import Vectors >>> data = [(Vectors.dense([0.0, 0.0]),), (Vectors.dense([1.0, 1.0]),), ... (Vectors.dense([9.0, 8.0]),), (Vectors.dense([8.0, 9.0]),)] - >>> df = sqlContext.createDataFrame(data, ["features"]) + >>> df = spark.createDataFrame(data, ["features"]) >>> bkm = BisectingKMeans(k=2, minDivisibleClusterSize=1.0) >>> model = bkm.fit(df) >>> centers = model.clusterCenters() @@ -384,11 +395,11 @@ class BisectingKMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte .. versionadded:: 2.0.0 """ - k = Param(Params._dummy(), "k", "number of clusters to create", + k = Param(Params._dummy(), "k", "The desired number of leaf clusters. Must be > 1.", typeConverter=TypeConverters.toInt) minDivisibleClusterSize = Param(Params._dummy(), "minDivisibleClusterSize", - "the minimum number of points (if >= 1.0) " + - "or the minimum proportion", + "The minimum number of points (if >= 1.0) or the minimum " + + "proportion of points (if < 1.0) of a divisible cluster.", typeConverter=TypeConverters.toFloat) @keyword_only @@ -422,8 +433,7 @@ def setK(self, value): """ Sets the value of :py:attr:`k`. """ - self._set(k=value) - return self + return self._set(k=value) @since("2.0.0") def getK(self): @@ -437,8 +447,7 @@ def setMinDivisibleClusterSize(self, value): """ Sets the value of :py:attr:`minDivisibleClusterSize`. """ - self._set(minDivisibleClusterSize=value) - return self + return self._set(minDivisibleClusterSize=value) @since("2.0.0") def getMinDivisibleClusterSize(self): @@ -523,7 +532,7 @@ def describeTopics(self, maxTermsPerTopic=10): def estimatedDocConcentration(self): """ Value for :py:attr:`LDA.docConcentration` estimated from data. - If Online LDA was used and :py:attr::`LDA.optimizeDocConcentration` was set to false, + If Online LDA was used and :py:attr:`LDA.optimizeDocConcentration` was set to false, then this returns the fixed (given) value for the :py:attr:`LDA.docConcentration` parameter. """ return self._call_java("estimatedDocConcentration") @@ -631,9 +640,9 @@ class LDA(JavaEstimator, HasFeaturesCol, HasMaxIter, HasSeed, HasCheckpointInter :py:class:`pyspark.ml.feature.Tokenizer` and :py:class:`pyspark.ml.feature.CountVectorizer` can be useful for converting text to word count vectors. - >>> from pyspark.mllib.linalg import Vectors, SparseVector + >>> from pyspark.ml.linalg import Vectors, SparseVector >>> from pyspark.ml.clustering import LDA - >>> df = sqlContext.createDataFrame([[1, Vectors.dense([0.0, 1.0])], + >>> df = spark.createDataFrame([[1, Vectors.dense([0.0, 1.0])], ... 
[2, SparseVector(2, {0: 1.0})],], ["id", "features"]) >>> lda = LDA(k=2, seed=1, optimizer="em") >>> model = lda.fit(df) @@ -667,7 +676,7 @@ class LDA(JavaEstimator, HasFeaturesCol, HasMaxIter, HasSeed, HasCheckpointInter .. versionadded:: 2.0.0 """ - k = Param(Params._dummy(), "k", "number of topics (clusters) to infer", + k = Param(Params._dummy(), "k", "The number of topics (clusters) to infer. Must be > 1.", typeConverter=TypeConverters.toInt) optimizer = Param(Params._dummy(), "optimizer", "Optimizer or inference algorithm used to estimate the LDA model. " @@ -939,21 +948,23 @@ def getKeepLastCheckpoint(self): if __name__ == "__main__": import doctest import pyspark.ml.clustering - from pyspark.context import SparkContext - from pyspark.sql import SQLContext + from pyspark.sql import SparkSession globs = pyspark.ml.clustering.__dict__.copy() # The small batch size here ensures that we see multiple batches, # even in these small test examples: - sc = SparkContext("local[2]", "ml.clustering tests") - sqlContext = SQLContext(sc) + spark = SparkSession.builder\ + .master("local[2]")\ + .appName("ml.clustering tests")\ + .getOrCreate() + sc = spark.sparkContext globs['sc'] = sc - globs['sqlContext'] = sqlContext + globs['spark'] = spark import tempfile temp_path = tempfile.mkdtemp() globs['temp_path'] = temp_path try: (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) - sc.stop() + spark.stop() finally: from shutil import rmtree try: diff --git a/python/pyspark/ml/common.py b/python/pyspark/ml/common.py new file mode 100644 index 0000000000000..7d449aaccb44f --- /dev/null +++ b/python/pyspark/ml/common.py @@ -0,0 +1,137 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import sys +if sys.version >= '3': + long = int + unicode = str + +import py4j.protocol +from py4j.protocol import Py4JJavaError +from py4j.java_gateway import JavaObject +from py4j.java_collections import ListConverter, JavaArray, JavaList + +from pyspark import RDD, SparkContext +from pyspark.serializers import PickleSerializer, AutoBatchedSerializer +from pyspark.sql import DataFrame, SQLContext + +# Hack for support float('inf') in Py4j +_old_smart_decode = py4j.protocol.smart_decode + +_float_str_mapping = { + 'nan': 'NaN', + 'inf': 'Infinity', + '-inf': '-Infinity', +} + + +def _new_smart_decode(obj): + if isinstance(obj, float): + s = str(obj) + return _float_str_mapping.get(s, s) + return _old_smart_decode(obj) + +py4j.protocol.smart_decode = _new_smart_decode + + +_picklable_classes = [ + 'SparseVector', + 'DenseVector', + 'DenseMatrix', +] + + +# this will call the ML version of pythonToJava() +def _to_java_object_rdd(rdd): + """ Return an JavaRDD of Object by unpickling + + It will convert each Python object into Java object by Pyrolite, whenever the + RDD is serialized in batch or not. + """ + rdd = rdd._reserialize(AutoBatchedSerializer(PickleSerializer())) + return rdd.ctx._jvm.org.apache.spark.ml.python.MLSerDe.pythonToJava(rdd._jrdd, True) + + +def _py2java(sc, obj): + """ Convert Python object into Java """ + if isinstance(obj, RDD): + obj = _to_java_object_rdd(obj) + elif isinstance(obj, DataFrame): + obj = obj._jdf + elif isinstance(obj, SparkContext): + obj = obj._jsc + elif isinstance(obj, list): + obj = ListConverter().convert([_py2java(sc, x) for x in obj], sc._gateway._gateway_client) + elif isinstance(obj, JavaObject): + pass + elif isinstance(obj, (int, long, float, bool, bytes, unicode)): + pass + else: + data = bytearray(PickleSerializer().dumps(obj)) + obj = sc._jvm.org.apache.spark.ml.python.MLSerDe.loads(data) + return obj + + +def _java2py(sc, r, encoding="bytes"): + if isinstance(r, JavaObject): + clsName = r.getClass().getSimpleName() + # convert RDD into JavaRDD + if clsName != 'JavaRDD' and clsName.endswith("RDD"): + r = r.toJavaRDD() + clsName = 'JavaRDD' + + if clsName == 'JavaRDD': + jrdd = sc._jvm.org.apache.spark.ml.python.MLSerDe.javaToPython(r) + return RDD(jrdd, sc) + + if clsName == 'Dataset': + return DataFrame(r, SQLContext.getOrCreate(sc)) + + if clsName in _picklable_classes: + r = sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(r) + elif isinstance(r, (JavaArray, JavaList)): + try: + r = sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(r) + except Py4JJavaError: + pass # not pickable + + if isinstance(r, (bytearray, bytes)): + r = PickleSerializer().loads(bytes(r), encoding=encoding) + return r + + +def callJavaFunc(sc, func, *args): + """ Call Java Function """ + args = [_py2java(sc, a) for a in args] + return _java2py(sc, func(*args)) + + +def inherit_doc(cls): + """ + A decorator that makes a class inherit documentation from its parents. 
+ """ + for name, func in vars(cls).items(): + # only inherit docstring for public functions + if name.startswith("_"): + continue + if not func.__doc__: + for parent in cls.__bases__: + parent_func = getattr(parent, name, None) + if parent_func and getattr(parent_func, "__doc__", None): + func.__doc__ = parent_func.__doc__ + break + return cls diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index 455795f9a083a..1fe8772da772a 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -21,7 +21,7 @@ from pyspark.ml.wrapper import JavaParams from pyspark.ml.param import Param, Params, TypeConverters from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol, HasRawPredictionCol -from pyspark.mllib.common import inherit_doc +from pyspark.ml.common import inherit_doc __all__ = ['Evaluator', 'BinaryClassificationEvaluator', 'RegressionEvaluator', 'MulticlassClassificationEvaluator'] @@ -105,14 +105,16 @@ def isLargerBetter(self): @inherit_doc class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPredictionCol): """ + .. note:: Experimental + Evaluator for binary classification, which expects two input columns: rawPrediction and label. The rawPrediction column can be of type double (binary 0/1 prediction, or probability of label 1) or of type vector (length-2 vector of raw predictions, scores, or label probabilities). - >>> from pyspark.mllib.linalg import Vectors + >>> from pyspark.ml.linalg import Vectors >>> scoreAndLabels = map(lambda x: (Vectors.dense([1.0 - x[0], x[0]]), x[1]), ... [(0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)]) - >>> dataset = sqlContext.createDataFrame(scoreAndLabels, ["raw", "label"]) + >>> dataset = spark.createDataFrame(scoreAndLabels, ["raw", "label"]) ... >>> evaluator = BinaryClassificationEvaluator(rawPredictionCol="raw") >>> evaluator.evaluate(dataset) @@ -147,8 +149,7 @@ def setMetricName(self, value): """ Sets the value of :py:attr:`metricName`. """ - self._set(metricName=value) - return self + return self._set(metricName=value) @since("1.4.0") def getMetricName(self): @@ -173,12 +174,14 @@ def setParams(self, rawPredictionCol="rawPrediction", labelCol="label", @inherit_doc class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol): """ + .. note:: Experimental + Evaluator for Regression, which expects two input columns: prediction and label. >>> scoreAndLabels = [(-28.98343821, -27.0), (20.21491975, 21.5), ... (-25.98418959, -22.0), (30.69731842, 33.0), (74.69283752, 71.0)] - >>> dataset = sqlContext.createDataFrame(scoreAndLabels, ["raw", "label"]) + >>> dataset = spark.createDataFrame(scoreAndLabels, ["raw", "label"]) ... >>> evaluator = RegressionEvaluator(predictionCol="raw") >>> evaluator.evaluate(dataset) @@ -190,11 +193,12 @@ class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol): .. versionadded:: 1.4.0 """ - # Because we will maximize evaluation value (ref: `CrossValidator`), - # when we evaluate a metric that is needed to minimize (e.g., `"rmse"`, `"mse"`, `"mae"`), - # we take and output the negative of this metric. 
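(Editorial aside, not one of the diff hunks.) The new `pyspark/ml/common.py` above defines `inherit_doc`. A minimal sketch of its effect, assuming the patched `pyspark` is importable: a public method that overrides a parent method but has no docstring of its own picks up the parent's docstring.

```python
from pyspark.ml.common import inherit_doc

class Base(object):
    def fit(self):
        """Fit the estimator and return a model."""
        raise NotImplementedError

@inherit_doc
class Child(Base):
    def fit(self):  # no docstring of its own ...
        return "model"

# ... so after decoration it carries the parent's: "Fit the estimator and return a model."
print(Child.fit.__doc__)
```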
metricName = Param(Params._dummy(), "metricName", - "metric name in evaluation (mse|rmse|r2|mae)", + """metric name in evaluation - one of: + rmse - root mean squared error (default) + mse - mean squared error + r2 - r^2 metric + mae - mean absolute error.""", typeConverter=TypeConverters.toString) @keyword_only @@ -217,8 +221,7 @@ def setMetricName(self, value): """ Sets the value of :py:attr:`metricName`. """ - self._set(metricName=value) - return self + return self._set(metricName=value) @since("1.4.0") def getMetricName(self): @@ -243,25 +246,26 @@ def setParams(self, predictionCol="prediction", labelCol="label", @inherit_doc class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol): """ + .. note:: Experimental + Evaluator for Multiclass Classification, which expects two input columns: prediction and label. + >>> scoreAndLabels = [(0.0, 0.0), (0.0, 1.0), (0.0, 0.0), ... (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)] - >>> dataset = sqlContext.createDataFrame(scoreAndLabels, ["prediction", "label"]) + >>> dataset = spark.createDataFrame(scoreAndLabels, ["prediction", "label"]) ... >>> evaluator = MulticlassClassificationEvaluator(predictionCol="prediction") >>> evaluator.evaluate(dataset) 0.66... - >>> evaluator.evaluate(dataset, {evaluator.metricName: "precision"}) - 0.66... - >>> evaluator.evaluate(dataset, {evaluator.metricName: "recall"}) + >>> evaluator.evaluate(dataset, {evaluator.metricName: "accuracy"}) 0.66... .. versionadded:: 1.5.0 """ metricName = Param(Params._dummy(), "metricName", "metric name in evaluation " - "(f1|precision|recall|weightedPrecision|weightedRecall)", + "(f1|weightedPrecision|weightedRecall|accuracy)", typeConverter=TypeConverters.toString) @keyword_only @@ -284,8 +288,7 @@ def setMetricName(self, value): """ Sets the value of :py:attr:`metricName`. 
""" - self._set(metricName=value) - return self + return self._set(metricName=value) @since("1.5.0") def getMetricName(self): @@ -308,17 +311,19 @@ def setParams(self, predictionCol="prediction", labelCol="label", if __name__ == "__main__": import doctest - from pyspark.context import SparkContext - from pyspark.sql import SQLContext + from pyspark.sql import SparkSession globs = globals().copy() # The small batch size here ensures that we see multiple batches, # even in these small test examples: - sc = SparkContext("local[2]", "ml.evaluation tests") - sqlContext = SQLContext(sc) + spark = SparkSession.builder\ + .master("local[2]")\ + .appName("ml.evaluation tests")\ + .getOrCreate() + sc = spark.sparkContext globs['sc'] = sc - globs['sqlContext'] = sqlContext + globs['spark'] = spark (failure_count, test_count) = doctest.testmod( globs=globs, optionflags=doctest.ELLIPSIS) - sc.stop() + spark.stop() if failure_count: exit(-1) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py old mode 100644 new mode 100755 index b95d288198b53..2881380152c8d --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -19,15 +19,13 @@ if sys.version > '3': basestring = str -from py4j.java_collections import JavaArray - from pyspark import since, keyword_only from pyspark.rdd import ignore_unicode_prefix +from pyspark.ml.linalg import _convert_to_vector from pyspark.ml.param.shared import * from pyspark.ml.util import JavaMLReadable, JavaMLWritable from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer, _jvm -from pyspark.mllib.common import inherit_doc -from pyspark.mllib.linalg import _convert_to_vector +from pyspark.ml.common import inherit_doc __all__ = ['Binarizer', 'Bucketizer', @@ -62,11 +60,9 @@ @inherit_doc class Binarizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ - .. note:: Experimental - Binarize a column of continuous features given a threshold. - >>> df = sqlContext.createDataFrame([(0.5,)], ["values"]) + >>> df = spark.createDataFrame([(0.5,)], ["values"]) >>> binarizer = Binarizer(threshold=1.0, inputCol="values", outputCol="features") >>> binarizer.transform(df).head().features 0.0 @@ -114,8 +110,7 @@ def setThreshold(self, value): """ Sets the value of :py:attr:`threshold`. """ - self._set(threshold=value) - return self + return self._set(threshold=value) @since("1.4.0") def getThreshold(self): @@ -128,11 +123,9 @@ def getThreshold(self): @inherit_doc class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ - .. note:: Experimental - Maps a column of continuous features to a column of feature buckets. - >>> df = sqlContext.createDataFrame([(0.1,), (0.4,), (1.2,), (1.5,)], ["values"]) + >>> df = spark.createDataFrame([(0.1,), (0.4,), (1.2,), (1.5,)], ["values"]) >>> bucketizer = Bucketizer(splits=[-float("inf"), 0.5, 1.4, float("inf")], ... inputCol="values", outputCol="buckets") >>> bucketed = bucketizer.transform(df).collect() @@ -152,7 +145,7 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, Jav >>> loadedBucketizer.getSplits() == bucketizer.getSplits() True - .. versionadded:: 1.3.0 + .. versionadded:: 1.4.0 """ splits = \ @@ -160,9 +153,9 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, Jav "Split points for mapping continuous features into buckets. With n+1 splits, " + "there are n buckets. 
A bucket defined by splits x,y holds values in the " + "range [x,y) except the last bucket, which also includes y. The splits " + - "should be strictly increasing. Values at -inf, inf must be explicitly " + - "provided to cover all Double values; otherwise, values outside the splits " + - "specified will be treated as errors.", + "should be of length >= 3 and strictly increasing. Values at -inf, inf must be " + + "explicitly provided to cover all Double values; otherwise, values outside the " + + "splits specified will be treated as errors.", typeConverter=TypeConverters.toListFloat) @keyword_only @@ -190,8 +183,7 @@ def setSplits(self, value): """ Sets the value of :py:attr:`splits`. """ - self._set(splits=value) - return self + return self._set(splits=value) @since("1.4.0") def getSplits(self): @@ -204,11 +196,9 @@ def getSplits(self): @inherit_doc class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ - .. note:: Experimental - Extracts a vocabulary from document collections and generates a :py:attr:`CountVectorizerModel`. - >>> df = sqlContext.createDataFrame( + >>> df = spark.createDataFrame( ... [(0, ["a", "b", "c"]), (1, ["a", "b", "b", "c", "a"])], ... ["label", "raw"]) >>> cv = CountVectorizer(inputCol="raw", outputCol="vectors") @@ -295,8 +285,7 @@ def setMinTF(self, value): """ Sets the value of :py:attr:`minTF`. """ - self._set(minTF=value) - return self + return self._set(minTF=value) @since("1.6.0") def getMinTF(self): @@ -310,8 +299,7 @@ def setMinDF(self, value): """ Sets the value of :py:attr:`minDF`. """ - self._set(minDF=value) - return self + return self._set(minDF=value) @since("1.6.0") def getMinDF(self): @@ -325,8 +313,7 @@ def setVocabSize(self, value): """ Sets the value of :py:attr:`vocabSize`. """ - self._set(vocabSize=value) - return self + return self._set(vocabSize=value) @since("1.6.0") def getVocabSize(self): @@ -340,8 +327,7 @@ def setBinary(self, value): """ Sets the value of :py:attr:`binary`. """ - self._set(binary=value) - return self + return self._set(binary=value) @since("2.0.0") def getBinary(self): @@ -356,9 +342,7 @@ def _create_model(self, java_model): class CountVectorizerModel(JavaModel, JavaMLReadable, JavaMLWritable): """ - .. note:: Experimental - - Model fitted by CountVectorizer. + Model fitted by :py:class:`CountVectorizer`. .. versionadded:: 1.6.0 """ @@ -375,19 +359,17 @@ def vocabulary(self): @inherit_doc class DCT(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ - .. note:: Experimental - A feature transformer that takes the 1D discrete cosine transform of a real vector. No zero padding is performed on the input vector. It returns a real vector of the same length representing the DCT. The return vector is scaled such that the transform matrix is unitary (aka scaled DCT-II). - More information on - `https://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II Wikipedia`. + .. seealso:: `More information on Wikipedia \ + `_. - >>> from pyspark.mllib.linalg import Vectors - >>> df1 = sqlContext.createDataFrame([(Vectors.dense([5.0, 8.0, 6.0]),)], ["vec"]) + >>> from pyspark.ml.linalg import Vectors + >>> df1 = spark.createDataFrame([(Vectors.dense([5.0, 8.0, 6.0]),)], ["vec"]) >>> dct = DCT(inverse=False, inputCol="vec", outputCol="resultVec") >>> df2 = dct.transform(df1) >>> df2.head().resultVec @@ -433,8 +415,7 @@ def setInverse(self, value): """ Sets the value of :py:attr:`inverse`. 
""" - self._set(inverse=value) - return self + return self._set(inverse=value) @since("1.6.0") def getInverse(self): @@ -448,14 +429,12 @@ def getInverse(self): class ElementwiseProduct(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ - .. note:: Experimental - Outputs the Hadamard product (i.e., the element-wise product) of each input vector with a provided "weight" vector. In other words, it scales each column of the dataset by a scalar multiplier. - >>> from pyspark.mllib.linalg import Vectors - >>> df = sqlContext.createDataFrame([(Vectors.dense([2.0, 1.0, 3.0]),)], ["values"]) + >>> from pyspark.ml.linalg import Vectors + >>> df = spark.createDataFrame([(Vectors.dense([2.0, 1.0, 3.0]),)], ["values"]) >>> ep = ElementwiseProduct(scalingVec=Vectors.dense([1.0, 2.0, 3.0]), ... inputCol="values", outputCol="eprod") >>> ep.transform(df).head().eprod @@ -495,15 +474,14 @@ def setParams(self, scalingVec=None, inputCol=None, outputCol=None): kwargs = self.setParams._input_kwargs return self._set(**kwargs) - @since("1.5.0") + @since("2.0.0") def setScalingVec(self, value): """ Sets the value of :py:attr:`scalingVec`. """ - self._set(scalingVec=value) - return self + return self._set(scalingVec=value) - @since("1.5.0") + @since("2.0.0") def getScalingVec(self): """ Gets the value of scalingVec or its default value. @@ -515,8 +493,6 @@ def getScalingVec(self): class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures, JavaMLReadable, JavaMLWritable): """ - .. note:: Experimental - Maps a sequence of terms to their term frequencies using the hashing trick. Currently we use Austin Appleby's MurmurHash 3 algorithm (MurmurHash3_x86_32) to calculate the hash code value for the term object. @@ -524,7 +500,7 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures, Java it is advisable to use a power of two as the numFeatures parameter; otherwise the features will not be mapped evenly to the columns. - >>> df = sqlContext.createDataFrame([(["a", "b", "c"],)], ["words"]) + >>> df = spark.createDataFrame([(["a", "b", "c"],)], ["words"]) >>> hashingTF = HashingTF(numFeatures=10, inputCol="words", outputCol="features") >>> hashingTF.transform(df).head().features SparseVector(10, {0: 1.0, 1: 1.0, 2: 1.0}) @@ -573,8 +549,7 @@ def setBinary(self, value): """ Sets the value of :py:attr:`binary`. """ - self._set(binary=value) - return self + return self._set(binary=value) @since("2.0.0") def getBinary(self): @@ -587,15 +562,15 @@ def getBinary(self): @inherit_doc class IDF(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ - .. note:: Experimental - Compute the Inverse Document Frequency (IDF) given a collection of documents. - >>> from pyspark.mllib.linalg import DenseVector - >>> df = sqlContext.createDataFrame([(DenseVector([1.0, 2.0]),), + >>> from pyspark.ml.linalg import DenseVector + >>> df = spark.createDataFrame([(DenseVector([1.0, 2.0]),), ... 
(DenseVector([0.0, 1.0]),), (DenseVector([3.0, 0.2]),)], ["tf"]) >>> idf = IDF(minDocFreq=3, inputCol="tf", outputCol="idf") >>> model = idf.fit(df) + >>> model.idf + DenseVector([0.0, 0.0]) >>> model.transform(df).head().idf DenseVector([0.0, 0.0]) >>> idf.setParams(outputCol="freqs").fit(df).transform(df).collect()[1].freqs @@ -618,7 +593,7 @@ class IDF(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritab """ minDocFreq = Param(Params._dummy(), "minDocFreq", - "minimum of documents in which a term should appear for filtering", + "minimum number of documents in which a term should appear for filtering", typeConverter=TypeConverters.toInt) @keyword_only @@ -647,8 +622,7 @@ def setMinDocFreq(self, value): """ Sets the value of :py:attr:`minDocFreq`. """ - self._set(minDocFreq=value) - return self + return self._set(minDocFreq=value) @since("1.4.0") def getMinDocFreq(self): @@ -663,13 +637,19 @@ def _create_model(self, java_model): class IDFModel(JavaModel, JavaMLReadable, JavaMLWritable): """ - .. note:: Experimental - - Model fitted by IDF. + Model fitted by :py:class:`IDF`. .. versionadded:: 1.4.0 """ + @property + @since("2.0.0") + def idf(self): + """ + Returns the IDF vector. + """ + return self._call_java("idf") + @inherit_doc class MaxAbsScaler(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): @@ -680,8 +660,8 @@ class MaxAbsScaler(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Jav absolute value in each feature. It does not shift/center the data, and thus does not destroy any sparsity. - >>> from pyspark.mllib.linalg import Vectors - >>> df = sqlContext.createDataFrame([(Vectors.dense([1.0]),), (Vectors.dense([2.0]),)], ["a"]) + >>> from pyspark.ml.linalg import Vectors + >>> df = spark.createDataFrame([(Vectors.dense([1.0]),), (Vectors.dense([2.0]),)], ["a"]) >>> maScaler = MaxAbsScaler(inputCol="a", outputCol="scaled") >>> model = maScaler.fit(df) >>> model.transform(df).show() @@ -754,8 +734,6 @@ def maxAbs(self): @inherit_doc class MinMaxScaler(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ - .. note:: Experimental - Rescale each feature individually to a common range [min, max] linearly using column summary statistics, which is also known as min-max normalization or Rescaling. The rescaled value for feature E is calculated as, @@ -767,8 +745,8 @@ class MinMaxScaler(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Jav Note that since zero values will probably be transformed to non-zero values, output of the transformer will be DenseVector even for sparse input. - >>> from pyspark.mllib.linalg import Vectors - >>> df = sqlContext.createDataFrame([(Vectors.dense([0.0]),), (Vectors.dense([2.0]),)], ["a"]) + >>> from pyspark.ml.linalg import Vectors + >>> df = spark.createDataFrame([(Vectors.dense([0.0]),), (Vectors.dense([2.0]),)], ["a"]) >>> mmScaler = MinMaxScaler(inputCol="a", outputCol="scaled") >>> model = mmScaler.fit(df) >>> model.originalMin @@ -832,8 +810,7 @@ def setMin(self, value): """ Sets the value of :py:attr:`min`. """ - self._set(min=value) - return self + return self._set(min=value) @since("1.6.0") def getMin(self): @@ -847,8 +824,7 @@ def setMax(self, value): """ Sets the value of :py:attr:`max`. """ - self._set(max=value) - return self + return self._set(max=value) @since("1.6.0") def getMax(self): @@ -863,8 +839,6 @@ def _create_model(self, java_model): class MinMaxScalerModel(JavaModel, JavaMLReadable, JavaMLWritable): """ - .. 
note:: Experimental - Model fitted by :py:class:`MinMaxScaler`. .. versionadded:: 1.6.0 @@ -891,8 +865,6 @@ def originalMax(self): @ignore_unicode_prefix class NGram(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ - .. note:: Experimental - A feature transformer that converts the input array of strings into an array of n-grams. Null values in the input array are ignored. It returns an array of n-grams where each n-gram is represented by a space-separated string of @@ -901,7 +873,7 @@ class NGram(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWr When the input array length is less than n (number of elements per n-gram), no n-grams are returned. - >>> df = sqlContext.createDataFrame([Row(inputTokens=["a", "b", "c", "d", "e"])]) + >>> df = spark.createDataFrame([Row(inputTokens=["a", "b", "c", "d", "e"])]) >>> ngram = NGram(n=2, inputCol="inputTokens", outputCol="nGrams") >>> ngram.transform(df).head() Row(inputTokens=[u'a', u'b', u'c', u'd', u'e'], nGrams=[u'a b', u'b c', u'c d', u'd e']) @@ -956,8 +928,7 @@ def setN(self, value): """ Sets the value of :py:attr:`n`. """ - self._set(n=value) - return self + return self._set(n=value) @since("1.5.0") def getN(self): @@ -970,13 +941,11 @@ def getN(self): @inherit_doc class Normalizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ - .. note:: Experimental - Normalize a vector to have unit norm using the given p-norm. - >>> from pyspark.mllib.linalg import Vectors + >>> from pyspark.ml.linalg import Vectors >>> svec = Vectors.sparse(4, {1: 4.0, 3: 3.0}) - >>> df = sqlContext.createDataFrame([(Vectors.dense([3.0, -4.0]), svec)], ["dense", "sparse"]) + >>> df = spark.createDataFrame([(Vectors.dense([3.0, -4.0]), svec)], ["dense", "sparse"]) >>> normalizer = Normalizer(p=2.0, inputCol="dense", outputCol="features") >>> normalizer.transform(df).head().features DenseVector([0.6, -0.8]) @@ -1023,8 +992,7 @@ def setP(self, value): """ Sets the value of :py:attr:`p`. """ - self._set(p=value) - return self + return self._set(p=value) @since("1.4.0") def getP(self): @@ -1037,8 +1005,6 @@ def getP(self): @inherit_doc class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ - .. note:: Experimental - A one-hot encoder that maps a column of category indices to a column of binary vectors, with at most a single one-value per row that indicates the input category index. @@ -1106,8 +1072,7 @@ def setDropLast(self, value): """ Sets the value of :py:attr:`dropLast`. """ - self._set(dropLast=value) - return self + return self._set(dropLast=value) @since("1.4.0") def getDropLast(self): @@ -1121,16 +1086,14 @@ def getDropLast(self): class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ - .. note:: Experimental - - Perform feature expansion in a polynomial space. As said in wikipedia of Polynomial Expansion, - which is available at `http://en.wikipedia.org/wiki/Polynomial_expansion`, "In mathematics, an + Perform feature expansion in a polynomial space. As said in `wikipedia of Polynomial Expansion + `_, "In mathematics, an expansion of a product of sums expresses it as a sum of products by using the fact that multiplication distributes over addition". Take a 2-variable feature vector as an example: `(x, y)`, if we want to expand it with degree 2, then we get `(x, x * x, y, x * y, y * y)`. 
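As a quick check of that ordering (plain Python, not part of this patch): for x = 0.5 and y = 2.0 the degree-2 expansion is

    x, y = 0.5, 2.0
    expanded = [x, x * x, y, x * y, y * y]
    # [0.5, 0.25, 2.0, 1.0, 4.0], exactly the vector produced in the doctest below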
- >>> from pyspark.mllib.linalg import Vectors - >>> df = sqlContext.createDataFrame([(Vectors.dense([0.5, 2.0]),)], ["dense"]) + >>> from pyspark.ml.linalg import Vectors + >>> df = spark.createDataFrame([(Vectors.dense([0.5, 2.0]),)], ["dense"]) >>> px = PolynomialExpansion(degree=2, inputCol="dense", outputCol="expanded") >>> px.transform(df).head().expanded DenseVector([0.5, 0.25, 2.0, 1.0, 4.0]) @@ -1175,8 +1138,7 @@ def setDegree(self, value): """ Sets the value of :py:attr:`degree`. """ - self._set(degree=value) - return self + return self._set(degree=value) @since("1.4.0") def getDegree(self): @@ -1187,22 +1149,23 @@ def getDegree(self): @inherit_doc -class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, HasSeed, JavaMLReadable, - JavaMLWritable): +class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned - categorical features. The bin ranges are chosen by taking a sample of the data and dividing it - into roughly equal parts. The lower and upper bin bounds will be -Infinity and +Infinity, - covering all real values. This attempts to find numBuckets partitions based on a sample of data, - but it may find fewer depending on the data sample values. - - >>> df = sqlContext.createDataFrame([(0.1,), (0.4,), (1.2,), (1.5,)], ["values"]) + categorical features. The number of bins can be set using the :py:attr:`numBuckets` parameter. + The bin ranges are chosen using an approximate algorithm (see the documentation for + :py:meth:`~.DataFrameStatFunctions.approxQuantile` for a detailed description). + The precision of the approximation can be controlled with the + :py:attr:`relativeError` parameter. + The lower and upper bin bounds will be `-Infinity` and `+Infinity`, covering all real values. + + >>> df = spark.createDataFrame([(0.1,), (0.4,), (1.2,), (1.5,)], ["values"]) >>> qds = QuantileDiscretizer(numBuckets=2, - ... inputCol="values", outputCol="buckets", seed=123) - >>> qds.getSeed() - 123 + ... inputCol="values", outputCol="buckets", relativeError=0.01) + >>> qds.getRelativeError() + 0.01 >>> bucketizer = qds.fit(df) >>> splits = bucketizer.getSplits() >>> splits[0] @@ -1221,32 +1184,33 @@ class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, HasSeed, Jav .. versionadded:: 2.0.0 """ - # a placeholder to make it appear in the generated doc numBuckets = Param(Params._dummy(), "numBuckets", "Maximum number of buckets (quantiles, or " + - "categories) into which data points are grouped. Must be >= 2. Default 2.", + "categories) into which data points are grouped. Must be >= 2.", typeConverter=TypeConverters.toInt) + relativeError = Param(Params._dummy(), "relativeError", "The relative target precision for " + + "the approximate quantile algorithm used to generate buckets. 
" + + "Must be in the range [0, 1].", + typeConverter=TypeConverters.toFloat) + @keyword_only - def __init__(self, numBuckets=2, inputCol=None, outputCol=None, seed=None): + def __init__(self, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001): """ - __init__(self, numBuckets=2, inputCol=None, outputCol=None, seed=None) + __init__(self, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001) """ super(QuantileDiscretizer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.QuantileDiscretizer", self.uid) - self.numBuckets = Param(self, "numBuckets", - "Maximum number of buckets (quantiles, or " + - "categories) into which data points are grouped. Must be >= 2.") - self._setDefault(numBuckets=2) + self._setDefault(numBuckets=2, relativeError=0.001) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @keyword_only @since("2.0.0") - def setParams(self, numBuckets=2, inputCol=None, outputCol=None, seed=None): + def setParams(self, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001): """ - setParams(self, numBuckets=2, inputCol=None, outputCol=None, seed=None) + setParams(self, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001) Set the params for the QuantileDiscretizer """ kwargs = self.setParams._input_kwargs @@ -1257,8 +1221,7 @@ def setNumBuckets(self, value): """ Sets the value of :py:attr:`numBuckets`. """ - self._set(numBuckets=value) - return self + return self._set(numBuckets=value) @since("2.0.0") def getNumBuckets(self): @@ -1267,6 +1230,20 @@ def getNumBuckets(self): """ return self.getOrDefault(self.numBuckets) + @since("2.0.0") + def setRelativeError(self, value): + """ + Sets the value of :py:attr:`relativeError`. + """ + return self._set(relativeError=value) + + @since("2.0.0") + def getRelativeError(self): + """ + Gets the value of relativeError or its default value. + """ + return self.getOrDefault(self.relativeError) + def _create_model(self, java_model): """ Private method to convert the java_model to a Python model. @@ -1280,8 +1257,6 @@ def _create_model(self, java_model): @ignore_unicode_prefix class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ - .. note:: Experimental - A regex based tokenizer that extracts tokens either by using the provided regex pattern (in Java dialect) to split the text (default) or repeatedly matching the regex (if gaps is false). @@ -1289,7 +1264,7 @@ class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, length. It returns an array of strings that can be empty. 
- >>> df = sqlContext.createDataFrame([("A B c",)], ["text"]) + >>> df = spark.createDataFrame([("A B c",)], ["text"]) >>> reTokenizer = RegexTokenizer(inputCol="text", outputCol="words") >>> reTokenizer.transform(df).head() Row(text=u'A B c', words=[u'a', u'b', u'c']) @@ -1319,7 +1294,8 @@ class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, minTokenLength = Param(Params._dummy(), "minTokenLength", "minimum token length (>= 0)", typeConverter=TypeConverters.toInt) - gaps = Param(Params._dummy(), "gaps", "whether regex splits on gaps (True) or matches tokens") + gaps = Param(Params._dummy(), "gaps", "whether regex splits on gaps (True) or matches tokens " + + "(False)") pattern = Param(Params._dummy(), "pattern", "regex pattern (Java dialect) used for tokenizing", typeConverter=TypeConverters.toString) toLowercase = Param(Params._dummy(), "toLowercase", "whether to convert all characters to " + @@ -1355,8 +1331,7 @@ def setMinTokenLength(self, value): """ Sets the value of :py:attr:`minTokenLength`. """ - self._set(minTokenLength=value) - return self + return self._set(minTokenLength=value) @since("1.4.0") def getMinTokenLength(self): @@ -1370,8 +1345,7 @@ def setGaps(self, value): """ Sets the value of :py:attr:`gaps`. """ - self._set(gaps=value) - return self + return self._set(gaps=value) @since("1.4.0") def getGaps(self): @@ -1385,8 +1359,7 @@ def setPattern(self, value): """ Sets the value of :py:attr:`pattern`. """ - self._set(pattern=value) - return self + return self._set(pattern=value) @since("1.4.0") def getPattern(self): @@ -1400,8 +1373,7 @@ def setToLowercase(self, value): """ Sets the value of :py:attr:`toLowercase`. """ - self._set(toLowercase=value) - return self + return self._set(toLowercase=value) @since("2.0.0") def getToLowercase(self): @@ -1414,13 +1386,11 @@ def getToLowercase(self): @inherit_doc class SQLTransformer(JavaTransformer, JavaMLReadable, JavaMLWritable): """ - .. note:: Experimental - Implements the transforms which are defined by SQL statement. Currently we only support SQL syntax like 'SELECT ... FROM __THIS__' where '__THIS__' represents the underlying table of the input dataset. - >>> df = sqlContext.createDataFrame([(0, 1.0, 3.0), (2, 2.0, 5.0)], ["id", "v1", "v2"]) + >>> df = spark.createDataFrame([(0, 1.0, 3.0), (2, 2.0, 5.0)], ["id", "v1", "v2"]) >>> sqlTrans = SQLTransformer( ... statement="SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__") >>> sqlTrans.transform(df).head() @@ -1462,8 +1432,7 @@ def setStatement(self, value): """ Sets the value of :py:attr:`statement`. """ - self._set(statement=value) - return self + return self._set(statement=value) @since("1.6.0") def getStatement(self): @@ -1476,13 +1445,15 @@ def getStatement(self): @inherit_doc class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ - .. note:: Experimental - Standardizes features by removing the mean and scaling to unit variance using column summary statistics on the samples in the training set. - >>> from pyspark.mllib.linalg import Vectors - >>> df = sqlContext.createDataFrame([(Vectors.dense([0.0]),), (Vectors.dense([2.0]),)], ["a"]) + The "unit std" is computed using the `corrected sample standard deviation \ + `_, + which is computed as the square root of the unbiased sample variance. 
+ + >>> from pyspark.ml.linalg import Vectors + >>> df = spark.createDataFrame([(Vectors.dense([0.0]),), (Vectors.dense([2.0]),)], ["a"]) >>> standardScaler = StandardScaler(inputCol="a", outputCol="scaled") >>> model = standardScaler.fit(df) >>> model.mean @@ -1540,8 +1511,7 @@ def setWithMean(self, value): """ Sets the value of :py:attr:`withMean`. """ - self._set(withMean=value) - return self + return self._set(withMean=value) @since("1.4.0") def getWithMean(self): @@ -1555,8 +1525,7 @@ def setWithStd(self, value): """ Sets the value of :py:attr:`withStd`. """ - self._set(withStd=value) - return self + return self._set(withStd=value) @since("1.4.0") def getWithStd(self): @@ -1571,15 +1540,13 @@ def _create_model(self, java_model): class StandardScalerModel(JavaModel, JavaMLReadable, JavaMLWritable): """ - .. note:: Experimental - - Model fitted by StandardScaler. + Model fitted by :py:class:`StandardScaler`. .. versionadded:: 1.4.0 """ @property - @since("1.5.0") + @since("2.0.0") def std(self): """ Standard deviation of the StandardScalerModel. @@ -1587,7 +1554,7 @@ def std(self): return self._call_java("std") @property - @since("1.5.0") + @since("2.0.0") def mean(self): """ Mean of the StandardScalerModel. @@ -1599,8 +1566,6 @@ def mean(self): class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, JavaMLReadable, JavaMLWritable): """ - .. note:: Experimental - A label indexer that maps a string column of labels to an ML column of label indices. If the input column is numeric, we cast it to string and index the string values. The indices are in [0, numLabels), ordered by label frequencies. @@ -1663,9 +1628,7 @@ def _create_model(self, java_model): class StringIndexerModel(JavaModel, JavaMLReadable, JavaMLWritable): """ - .. note:: Experimental - - Model fitted by StringIndexer. + Model fitted by :py:class:`StringIndexer`. .. versionadded:: 1.4.0 """ @@ -1682,8 +1645,6 @@ def labels(self): @inherit_doc class IndexToString(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ - .. note:: Experimental - A :py:class:`Transformer` that maps a column of indices back to a new column of corresponding string values. The index-string mapping is either from the ML attributes of the input column, @@ -1724,8 +1685,7 @@ def setLabels(self, value): """ Sets the value of :py:attr:`labels`. """ - self._set(labels=value) - return self + return self._set(labels=value) @since("1.6.0") def getLabels(self): @@ -1737,12 +1697,10 @@ def getLabels(self): class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ - .. note:: Experimental - A feature transformer that filters out stop words from input. Note: null values from input array are preserved unless adding null to stopWords explicitly. 
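For illustration (a sketch, not part of this patch; it assumes a SparkSession named `spark` and a tokenized input column `raw`), the default English stop-word list exposed by the `loadDefaultStopWords` helper added elsewhere in this patch can be extended rather than replaced:

    from pyspark.ml.feature import StopWordsRemover
    stop_words = StopWordsRemover.loadDefaultStopWords("english") + ["spark"]
    remover = StopWordsRemover(inputCol="raw", outputCol="filtered", stopWords=stop_words)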
- >>> df = sqlContext.createDataFrame([(["a", "b", "c"],)], ["text"]) + >>> df = spark.createDataFrame([(["a", "b", "c"],)], ["text"]) >>> remover = StopWordsRemover(inputCol="text", outputCol="words", stopWords=["b"]) >>> remover.transform(df).head().words == ['a', 'c'] True @@ -1763,28 +1721,23 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadabl "comparison over the stop words", typeConverter=TypeConverters.toBoolean) @keyword_only - def __init__(self, inputCol=None, outputCol=None, stopWords=None, - caseSensitive=False): + def __init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False): """ - __init__(self, inputCol=None, outputCol=None, stopWords=None,\ - caseSensitive=false) + __init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false) """ super(StopWordsRemover, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StopWordsRemover", self.uid) - stopWordsObj = _jvm().org.apache.spark.ml.feature.StopWords - defaultStopWords = list(stopWordsObj.English()) - self._setDefault(stopWords=defaultStopWords, caseSensitive=False) + self._setDefault(stopWords=StopWordsRemover.loadDefaultStopWords("english"), + caseSensitive=False) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @keyword_only @since("1.6.0") - def setParams(self, inputCol=None, outputCol=None, stopWords=None, - caseSensitive=False): + def setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False): """ - setParams(self, inputCol="input", outputCol="output", stopWords=None,\ - caseSensitive=false) + setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false) Sets params for this StopWordRemover. """ kwargs = self.setParams._input_kwargs @@ -1793,44 +1746,51 @@ def setParams(self, inputCol=None, outputCol=None, stopWords=None, @since("1.6.0") def setStopWords(self, value): """ - Specify the stopwords to be filtered. + Sets the value of :py:attr:`stopWords`. """ - self._set(stopWords=value) - return self + return self._set(stopWords=value) @since("1.6.0") def getStopWords(self): """ - Get the stopwords. + Gets the value of :py:attr:`stopWords` or its default value. """ return self.getOrDefault(self.stopWords) @since("1.6.0") def setCaseSensitive(self, value): """ - Set whether to do a case sensitive comparison over the stop words + Sets the value of :py:attr:`caseSensitive`. """ - self._set(caseSensitive=value) - return self + return self._set(caseSensitive=value) @since("1.6.0") def getCaseSensitive(self): """ - Get whether to do a case sensitive comparison over the stop words. + Gets the value of :py:attr:`caseSensitive` or its default value. """ return self.getOrDefault(self.caseSensitive) + @staticmethod + @since("2.0.0") + def loadDefaultStopWords(language): + """ + Loads the default stop words for the given language. + Supported languages: danish, dutch, english, finnish, french, german, hungarian, + italian, norwegian, portuguese, russian, spanish, swedish, turkish + """ + stopWordsObj = _jvm().org.apache.spark.ml.feature.StopWordsRemover + return list(stopWordsObj.loadDefaultStopWords(language)) + @inherit_doc @ignore_unicode_prefix class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ - .. note:: Experimental - A tokenizer that converts the input string to lowercase and then splits it by white spaces. 
- >>> df = sqlContext.createDataFrame([("a b c",)], ["text"]) + >>> df = spark.createDataFrame([("a b c",)], ["text"]) >>> tokenizer = Tokenizer(inputCol="text", outputCol="words") >>> tokenizer.transform(df).head() Row(text=u'a b c', words=[u'a', u'b', u'c']) @@ -1870,7 +1830,7 @@ def __init__(self, inputCol=None, outputCol=None): @since("1.3.0") def setParams(self, inputCol=None, outputCol=None): """ - setParams(self, inputCol="input", outputCol="output") + setParams(self, inputCol=None, outputCol=None) Sets params for this Tokenizer. """ kwargs = self.setParams._input_kwargs @@ -1880,11 +1840,9 @@ def setParams(self, inputCol=None, outputCol=None): @inherit_doc class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, JavaMLReadable, JavaMLWritable): """ - .. note:: Experimental - A feature transformer that merges multiple columns into a vector column. - >>> df = sqlContext.createDataFrame([(1, 0, 3)], ["a", "b", "c"]) + >>> df = spark.createDataFrame([(1, 0, 3)], ["a", "b", "c"]) >>> vecAssembler = VectorAssembler(inputCols=["a", "b", "c"], outputCol="features") >>> vecAssembler.transform(df).head().features DenseVector([1.0, 0.0, 3.0]) @@ -1926,9 +1884,7 @@ def setParams(self, inputCols=None, outputCol=None): @inherit_doc class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ - .. note:: Experimental - - Class for indexing categorical feature columns in a dataset of [[Vector]]. + Class for indexing categorical feature columns in a dataset of `Vector`. This has 2 usage modes: - Automatically identify categorical features (default behavior) @@ -1963,8 +1919,8 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Ja - Add warning if a categorical feature has only 1 category. - Add option for allowing unknown categories. - >>> from pyspark.mllib.linalg import Vectors - >>> df = sqlContext.createDataFrame([(Vectors.dense([-1.0, 0.0]),), + >>> from pyspark.ml.linalg import Vectors + >>> df = spark.createDataFrame([(Vectors.dense([-1.0, 0.0]),), ... (Vectors.dense([0.0, 1.0]),), (Vectors.dense([0.0, 2.0]),)], ["a"]) >>> indexer = VectorIndexer(maxCategories=2, inputCol="a", outputCol="indexed") >>> model = indexer.fit(df) @@ -2027,8 +1983,7 @@ def setMaxCategories(self, value): """ Sets the value of :py:attr:`maxCategories`. """ - self._set(maxCategories=value) - return self + return self._set(maxCategories=value) @since("1.4.0") def getMaxCategories(self): @@ -2043,9 +1998,17 @@ def _create_model(self, java_model): class VectorIndexerModel(JavaModel, JavaMLReadable, JavaMLWritable): """ - .. note:: Experimental + Model fitted by :py:class:`VectorIndexer`. + + Transform categorical features to use 0-based indices instead of their original values. + - Categorical features are mapped to indices. + - Continuous features (columns) are left unchanged. + + This also appends metadata to the output column, marking features as Numeric (continuous), + Nominal (categorical), or Binary (either continuous or categorical). + Non-ML metadata is not carried over from the input to the output column. - Model fitted by VectorIndexer. + This maintains vector sparsity. .. versionadded:: 1.4.0 """ @@ -2072,8 +2035,6 @@ def categoryMaps(self): @inherit_doc class VectorSlicer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ - .. note:: Experimental - This class takes a feature vector and outputs a new feature vector with a subarray of the original features. 
@@ -2084,8 +2045,8 @@ class VectorSlicer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, J The output vector will order features with the selected indices first (in the order given), followed by the selected names (in the order given). - >>> from pyspark.mllib.linalg import Vectors - >>> df = sqlContext.createDataFrame([ + >>> from pyspark.ml.linalg import Vectors + >>> df = spark.createDataFrame([ ... (Vectors.dense([-2.0, 2.3, 0.0, 0.0, 1.0]),), ... (Vectors.dense([0.0, 0.0, 0.0, 0.0, 0.0]),), ... (Vectors.dense([0.6, -1.1, -3.0, 4.5, 3.3]),)], ["features"]) @@ -2137,8 +2098,7 @@ def setIndices(self, value): """ Sets the value of :py:attr:`indices`. """ - self._set(indices=value) - return self + return self._set(indices=value) @since("1.6.0") def getIndices(self): @@ -2152,8 +2112,7 @@ def setNames(self, value): """ Sets the value of :py:attr:`names`. """ - self._set(names=value) - return self + return self._set(names=value) @since("1.6.0") def getNames(self): @@ -2168,13 +2127,11 @@ def getNames(self): class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ - .. note:: Experimental - Word2Vec trains a model of `Map(String, Vector)`, i.e. transforms a word into a code for further natural language processing or machine learning process. >>> sent = ("a b " * 100 + "a c " * 10).split(" ") - >>> doc = sqlContext.createDataFrame([(sent,), (sent,)], ["sentence"]) + >>> doc = spark.createDataFrame([(sent,), (sent,)], ["sentence"]) >>> word2Vec = Word2Vec(vectorSize=5, seed=42, inputCol="sentence", outputCol="model") >>> model = word2Vec.fit(doc) >>> model.getVectors().show() @@ -2229,28 +2186,33 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has windowSize = Param(Params._dummy(), "windowSize", "the window size (context words from [-window, window]). Default value is 5", typeConverter=TypeConverters.toInt) + maxSentenceLength = Param(Params._dummy(), "maxSentenceLength", + "Maximum length (in words) of each sentence in the input data. " + + "Any sentence longer than this threshold will " + + "be divided into chunks up to the size.", + typeConverter=TypeConverters.toInt) @keyword_only def __init__(self, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1, - seed=None, inputCol=None, outputCol=None, windowSize=5): + seed=None, inputCol=None, outputCol=None, windowSize=5, maxSentenceLength=1000): """ __init__(self, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1, \ - seed=None, inputCol=None, outputCol=None, windowSize=5) + seed=None, inputCol=None, outputCol=None, windowSize=5, maxSentenceLength=1000) """ super(Word2Vec, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Word2Vec", self.uid) self._setDefault(vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1, - seed=None, windowSize=5) + windowSize=5, maxSentenceLength=1000) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @keyword_only @since("1.4.0") def setParams(self, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1, - seed=None, inputCol=None, outputCol=None, windowSize=5): + seed=None, inputCol=None, outputCol=None, windowSize=5, maxSentenceLength=1000): """ setParams(self, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1, seed=None, \ - inputCol=None, outputCol=None, windowSize=5) + inputCol=None, outputCol=None, windowSize=5, maxSentenceLength=1000) Sets params for this Word2Vec. 
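As an aside on the new maxSentenceLength parameter introduced above (an illustrative sketch, not part of this patch):

    from pyspark.ml.feature import Word2Vec
    w2v = Word2Vec(vectorSize=5, inputCol="sentence", outputCol="vec")
    w2v.setMaxSentenceLength(500)   # sentences longer than 500 words are split into chunks
    w2v.getMaxSentenceLength()      # 500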
""" kwargs = self.setParams._input_kwargs @@ -2261,8 +2223,7 @@ def setVectorSize(self, value): """ Sets the value of :py:attr:`vectorSize`. """ - self._set(vectorSize=value) - return self + return self._set(vectorSize=value) @since("1.4.0") def getVectorSize(self): @@ -2276,8 +2237,7 @@ def setNumPartitions(self, value): """ Sets the value of :py:attr:`numPartitions`. """ - self._set(numPartitions=value) - return self + return self._set(numPartitions=value) @since("1.4.0") def getNumPartitions(self): @@ -2291,8 +2251,7 @@ def setMinCount(self, value): """ Sets the value of :py:attr:`minCount`. """ - self._set(minCount=value) - return self + return self._set(minCount=value) @since("1.4.0") def getMinCount(self): @@ -2306,8 +2265,7 @@ def setWindowSize(self, value): """ Sets the value of :py:attr:`windowSize`. """ - self._set(windowSize=value) - return self + return self._set(windowSize=value) @since("2.0.0") def getWindowSize(self): @@ -2316,15 +2274,27 @@ def getWindowSize(self): """ return self.getOrDefault(self.windowSize) + @since("2.0.0") + def setMaxSentenceLength(self, value): + """ + Sets the value of :py:attr:`maxSentenceLength`. + """ + return self._set(maxSentenceLength=value) + + @since("2.0.0") + def getMaxSentenceLength(self): + """ + Gets the value of maxSentenceLength or its default value. + """ + return self.getOrDefault(self.maxSentenceLength) + def _create_model(self, java_model): return Word2VecModel(java_model) class Word2VecModel(JavaModel, JavaMLReadable, JavaMLWritable): """ - .. note:: Experimental - - Model fitted by Word2Vec. + Model fitted by :py:class:`Word2Vec`. .. versionadded:: 1.4.0 """ @@ -2353,15 +2323,14 @@ def findSynonyms(self, word, num): @inherit_doc class PCA(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ - .. note:: Experimental - - PCA trains a model to project vectors to a low-dimensional space using PCA. + PCA trains a model to project vectors to a lower dimensional space of the + top :py:attr:`k` principal components. - >>> from pyspark.mllib.linalg import Vectors + >>> from pyspark.ml.linalg import Vectors >>> data = [(Vectors.sparse(5, [(1, 1.0), (3, 7.0)]),), ... (Vectors.dense([2.0, 0.0, 3.0, 4.0, 5.0]),), ... (Vectors.dense([4.0, 0.0, 0.0, 6.0, 7.0]),)] - >>> df = sqlContext.createDataFrame(data,["features"]) + >>> df = spark.createDataFrame(data,["features"]) >>> pca = PCA(k=2, inputCol="features", outputCol="pca_features") >>> model = pca.fit(df) >>> model.transform(df).collect()[0].pca_features @@ -2412,8 +2381,7 @@ def setK(self, value): """ Sets the value of :py:attr:`k`. """ - self._set(k=value) - return self + return self._set(k=value) @since("1.5.0") def getK(self): @@ -2428,9 +2396,7 @@ def _create_model(self, java_model): class PCAModel(JavaModel, JavaMLReadable, JavaMLWritable): """ - .. note:: Experimental - - Model fitted by PCA. + Model fitted by :py:class:`PCA`. Transforms vectors to a lower dimensional space. .. versionadded:: 1.5.0 """ @@ -2461,11 +2427,10 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaM Implements the transforms required for fitting a dataset against an R model formula. Currently we support a limited subset of the R - operators, including '~', '.', ':', '+', and '-'. Also see the R formula - docs: - http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html + operators, including '~', '.', ':', '+', and '-'. Also see the `R formula docs + `_. - >>> df = sqlContext.createDataFrame([ + >>> df = spark.createDataFrame([ ... 
(1.0, 1.0, "a"), ... (0.0, 2.0, "b"), ... (0.0, 0.0, "a") @@ -2499,6 +2464,8 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaM True >>> loadedRF.getLabelCol() == rf.getLabelCol() True + >>> str(loadedRF) + 'RFormula(y ~ x + s) (uid=...)' >>> modelPath = temp_path + "/rFormulaModel" >>> model.save(modelPath) >>> loadedModel = RFormulaModel.load(modelPath) @@ -2513,6 +2480,8 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaM |0.0|0.0| a|[0.0,1.0]| 0.0| +---+---+---+---------+-----+ ... + >>> str(loadedModel) + 'RFormulaModel(ResolvedRFormula(label=y, terms=[x,s], hasIntercept=true)) (uid=...)' .. versionadded:: 1.5.0 """ @@ -2545,8 +2514,7 @@ def setFormula(self, value): """ Sets the value of :py:attr:`formula`. """ - self._set(formula=value) - return self + return self._set(formula=value) @since("1.5.0") def getFormula(self): @@ -2558,16 +2526,25 @@ def getFormula(self): def _create_model(self, java_model): return RFormulaModel(java_model) + def __str__(self): + formulaStr = self.getFormula() if self.isDefined(self.formula) else "" + return "RFormula(%s) (uid=%s)" % (formulaStr, self.uid) + class RFormulaModel(JavaModel, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental - Model fitted by :py:class:`RFormula`. + Model fitted by :py:class:`RFormula`. Fitting is required to determine the + factor levels of formula terms. .. versionadded:: 1.5.0 """ + def __str__(self): + resolvedFormula = self._call_java("resolvedFormula") + return "RFormulaModel(%s) (uid=%s)" % (resolvedFormula, self.uid) + @inherit_doc class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, JavaMLReadable, @@ -2578,8 +2555,8 @@ class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, Ja Chi-Squared feature selection, which selects categorical features to use for predicting a categorical label. - >>> from pyspark.mllib.linalg import Vectors - >>> df = sqlContext.createDataFrame( + >>> from pyspark.ml.linalg import Vectors + >>> df = spark.createDataFrame( ... [(Vectors.dense([0.0, 0.0, 18.0, 1.0]), 1.0), ... (Vectors.dense([0.0, 1.0, 12.0, 0.0]), 0.0), ... (Vectors.dense([1.0, 0.0, 15.0, 0.1]), 0.0)], @@ -2638,8 +2615,7 @@ def setNumTopFeatures(self, value): """ Sets the value of :py:attr:`numTopFeatures`. """ - self._set(numTopFeatures=value) - return self + return self._set(numTopFeatures=value) @since("2.0.0") def getNumTopFeatures(self): @@ -2656,7 +2632,7 @@ class ChiSqSelectorModel(JavaModel, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental - Model fitted by ChiSqSelector. + Model fitted by :py:class:`ChiSqSelector`. .. 
versionadded:: 2.0.0 """ @@ -2675,8 +2651,7 @@ def selectedFeatures(self): import tempfile import pyspark.ml.feature - from pyspark.context import SparkContext - from pyspark.sql import Row, SQLContext + from pyspark.sql import Row, SparkSession globs = globals().copy() features = pyspark.ml.feature.__dict__.copy() @@ -2684,19 +2659,22 @@ def selectedFeatures(self): # The small batch size here ensures that we see multiple batches, # even in these small test examples: - sc = SparkContext("local[2]", "ml.feature tests") - sqlContext = SQLContext(sc) + spark = SparkSession.builder\ + .master("local[2]")\ + .appName("ml.feature tests")\ + .getOrCreate() + sc = spark.sparkContext globs['sc'] = sc - globs['sqlContext'] = sqlContext + globs['spark'] = spark testData = sc.parallelize([Row(id=0, label="a"), Row(id=1, label="b"), Row(id=2, label="c"), Row(id=3, label="a"), Row(id=4, label="a"), Row(id=5, label="c")], 2) - globs['stringIndDf'] = sqlContext.createDataFrame(testData) + globs['stringIndDf'] = spark.createDataFrame(testData) temp_path = tempfile.mkdtemp() globs['temp_path'] = temp_path try: (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) - sc.stop() + spark.stop() finally: from shutil import rmtree try: diff --git a/python/pyspark/ml/linalg/__init__.py b/python/pyspark/ml/linalg/__init__.py new file mode 100644 index 0000000000000..bd0e186e7fcb1 --- /dev/null +++ b/python/pyspark/ml/linalg/__init__.py @@ -0,0 +1,1145 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +MLlib utilities for linear algebra. For dense vectors, MLlib +uses the NumPy C{array} type, so you can simply pass NumPy arrays +around. For sparse vectors, users can construct a L{SparseVector} +object from MLlib or pass SciPy C{scipy.sparse} column vectors if +SciPy is available in their environment. +""" + +import sys +import array +import struct + +if sys.version >= '3': + basestring = str + xrange = range + import copyreg as copy_reg + long = int +else: + from itertools import izip as zip + import copy_reg + +import numpy as np + +from pyspark import since +from pyspark.sql.types import UserDefinedType, StructField, StructType, ArrayType, DoubleType, \ + IntegerType, ByteType, BooleanType + + +__all__ = ['Vector', 'DenseVector', 'SparseVector', 'Vectors', + 'Matrix', 'DenseMatrix', 'SparseMatrix', 'Matrices'] + + +if sys.version_info[:2] == (2, 7): + # speed up pickling array in Python 2.7 + def fast_pickle_array(ar): + return array.array, (ar.typecode, ar.tostring()) + copy_reg.pickle(array.array, fast_pickle_array) + + +# Check whether we have SciPy. MLlib works without it too, but if we have it, some methods, +# such as _dot and _serialize_double_vector, start to support scipy.sparse matrices. 
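For example (an illustrative sketch, not part of this patch; it uses the module-private `_convert_to_vector` defined just below and requires SciPy):

    import scipy.sparse as sps
    from pyspark.ml.linalg import SparseVector, _convert_to_vector
    col = sps.csc_matrix(([1.0, 5.5], ([1, 3], [0, 0])), shape=(4, 1))  # a 4x1 column vector
    _convert_to_vector(col) == SparseVector(4, [1, 3], [1.0, 5.5])      # True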
+ +try: + import scipy.sparse + _have_scipy = True +except: + # No SciPy in environment, but that's okay + _have_scipy = False + + +def _convert_to_vector(l): + if isinstance(l, Vector): + return l + elif type(l) in (array.array, np.array, np.ndarray, list, tuple, xrange): + return DenseVector(l) + elif _have_scipy and scipy.sparse.issparse(l): + assert l.shape[1] == 1, "Expected column vector" + csc = l.tocsc() + return SparseVector(l.shape[0], csc.indices, csc.data) + else: + raise TypeError("Cannot convert type %s into Vector" % type(l)) + + +def _vector_size(v): + """ + Returns the size of the vector. + + >>> _vector_size([1., 2., 3.]) + 3 + >>> _vector_size((1., 2., 3.)) + 3 + >>> _vector_size(array.array('d', [1., 2., 3.])) + 3 + >>> _vector_size(np.zeros(3)) + 3 + >>> _vector_size(np.zeros((3, 1))) + 3 + >>> _vector_size(np.zeros((1, 3))) + Traceback (most recent call last): + ... + ValueError: Cannot treat an ndarray of shape (1, 3) as a vector + """ + if isinstance(v, Vector): + return len(v) + elif type(v) in (array.array, list, tuple, xrange): + return len(v) + elif type(v) == np.ndarray: + if v.ndim == 1 or (v.ndim == 2 and v.shape[1] == 1): + return len(v) + else: + raise ValueError("Cannot treat an ndarray of shape %s as a vector" % str(v.shape)) + elif _have_scipy and scipy.sparse.issparse(v): + assert v.shape[1] == 1, "Expected column vector" + return v.shape[0] + else: + raise TypeError("Cannot treat type %s as a vector" % type(v)) + + +def _format_float(f, digits=4): + s = str(round(f, digits)) + if '.' in s: + s = s[:s.index('.') + 1 + digits] + return s + + +def _format_float_list(l): + return [_format_float(x) for x in l] + + +def _double_to_long_bits(value): + if np.isnan(value): + value = float('nan') + # pack double into 64 bits, then unpack as long int + return struct.unpack('Q', struct.pack('d', value))[0] + + +class VectorUDT(UserDefinedType): + """ + SQL user-defined type (UDT) for Vector. + """ + + @classmethod + def sqlType(cls): + return StructType([ + StructField("type", ByteType(), False), + StructField("size", IntegerType(), True), + StructField("indices", ArrayType(IntegerType(), False), True), + StructField("values", ArrayType(DoubleType(), False), True)]) + + @classmethod + def module(cls): + return "pyspark.ml.linalg" + + @classmethod + def scalaUDT(cls): + return "org.apache.spark.ml.linalg.VectorUDT" + + def serialize(self, obj): + if isinstance(obj, SparseVector): + indices = [int(i) for i in obj.indices] + values = [float(v) for v in obj.values] + return (0, obj.size, indices, values) + elif isinstance(obj, DenseVector): + values = [float(v) for v in obj] + return (1, None, None, values) + else: + raise TypeError("cannot serialize %r of type %r" % (obj, type(obj))) + + def deserialize(self, datum): + assert len(datum) == 4, \ + "VectorUDT.deserialize given row with length %d but requires 4" % len(datum) + tpe = datum[0] + if tpe == 0: + return SparseVector(datum[1], datum[2], datum[3]) + elif tpe == 1: + return DenseVector(datum[3]) + else: + raise ValueError("do not recognize type %r" % tpe) + + def simpleString(self): + return "vector" + + +class MatrixUDT(UserDefinedType): + """ + SQL user-defined type (UDT) for Matrix. 
+ """ + + @classmethod + def sqlType(cls): + return StructType([ + StructField("type", ByteType(), False), + StructField("numRows", IntegerType(), False), + StructField("numCols", IntegerType(), False), + StructField("colPtrs", ArrayType(IntegerType(), False), True), + StructField("rowIndices", ArrayType(IntegerType(), False), True), + StructField("values", ArrayType(DoubleType(), False), True), + StructField("isTransposed", BooleanType(), False)]) + + @classmethod + def module(cls): + return "pyspark.ml.linalg" + + @classmethod + def scalaUDT(cls): + return "org.apache.spark.ml.linalg.MatrixUDT" + + def serialize(self, obj): + if isinstance(obj, SparseMatrix): + colPtrs = [int(i) for i in obj.colPtrs] + rowIndices = [int(i) for i in obj.rowIndices] + values = [float(v) for v in obj.values] + return (0, obj.numRows, obj.numCols, colPtrs, + rowIndices, values, bool(obj.isTransposed)) + elif isinstance(obj, DenseMatrix): + values = [float(v) for v in obj.values] + return (1, obj.numRows, obj.numCols, None, None, values, + bool(obj.isTransposed)) + else: + raise TypeError("cannot serialize type %r" % (type(obj))) + + def deserialize(self, datum): + assert len(datum) == 7, \ + "MatrixUDT.deserialize given row with length %d but requires 7" % len(datum) + tpe = datum[0] + if tpe == 0: + return SparseMatrix(*datum[1:]) + elif tpe == 1: + return DenseMatrix(datum[1], datum[2], datum[5], datum[6]) + else: + raise ValueError("do not recognize type %r" % tpe) + + def simpleString(self): + return "matrix" + + +class Vector(object): + + __UDT__ = VectorUDT() + + """ + Abstract class for DenseVector and SparseVector + """ + def toArray(self): + """ + Convert the vector into an numpy.ndarray + + :return: numpy.ndarray + """ + raise NotImplementedError + + +class DenseVector(Vector): + """ + A dense vector represented by a value array. We use numpy array for + storage and arithmetics will be delegated to the underlying numpy + array. + + >>> v = Vectors.dense([1.0, 2.0]) + >>> u = Vectors.dense([3.0, 4.0]) + >>> v + u + DenseVector([4.0, 6.0]) + >>> 2 - v + DenseVector([1.0, 0.0]) + >>> v / 2 + DenseVector([0.5, 1.0]) + >>> v * u + DenseVector([3.0, 8.0]) + >>> u / v + DenseVector([3.0, 2.0]) + >>> u % 2 + DenseVector([1.0, 0.0]) + """ + def __init__(self, ar): + if isinstance(ar, bytes): + ar = np.frombuffer(ar, dtype=np.float64) + elif not isinstance(ar, np.ndarray): + ar = np.array(ar, dtype=np.float64) + if ar.dtype != np.float64: + ar = ar.astype(np.float64) + self.array = ar + + def __reduce__(self): + return DenseVector, (self.array.tostring(),) + + def numNonzeros(self): + """ + Number of nonzero elements. This scans all active values and count non zeros + """ + return np.count_nonzero(self.array) + + def norm(self, p): + """ + Calculates the norm of a DenseVector. + + >>> a = DenseVector([0, -1, 2, -3]) + >>> a.norm(2) + 3.7... + >>> a.norm(1) + 6.0 + """ + return np.linalg.norm(self.array, p) + + def dot(self, other): + """ + Compute the dot product of two Vectors. We support + (Numpy array, list, SparseVector, or SciPy sparse) + and a target NumPy array that is either 1- or 2-dimensional. + Equivalent to calling numpy.dot of the two vectors. + + >>> dense = DenseVector(array.array('d', [1., 2.])) + >>> dense.dot(dense) + 5.0 + >>> dense.dot(SparseVector(2, [0, 1], [2., 1.])) + 4.0 + >>> dense.dot(range(1, 3)) + 5.0 + >>> dense.dot(np.array(range(1, 3))) + 5.0 + >>> dense.dot([1.,]) + Traceback (most recent call last): + ... 
+ AssertionError: dimension mismatch + >>> dense.dot(np.reshape([1., 2., 3., 4.], (2, 2), order='F')) + array([ 5., 11.]) + >>> dense.dot(np.reshape([1., 2., 3.], (3, 1), order='F')) + Traceback (most recent call last): + ... + AssertionError: dimension mismatch + """ + if type(other) == np.ndarray: + if other.ndim > 1: + assert len(self) == other.shape[0], "dimension mismatch" + return np.dot(self.array, other) + elif _have_scipy and scipy.sparse.issparse(other): + assert len(self) == other.shape[0], "dimension mismatch" + return other.transpose().dot(self.toArray()) + else: + assert len(self) == _vector_size(other), "dimension mismatch" + if isinstance(other, SparseVector): + return other.dot(self) + elif isinstance(other, Vector): + return np.dot(self.toArray(), other.toArray()) + else: + return np.dot(self.toArray(), other) + + def squared_distance(self, other): + """ + Squared distance of two Vectors. + + >>> dense1 = DenseVector(array.array('d', [1., 2.])) + >>> dense1.squared_distance(dense1) + 0.0 + >>> dense2 = np.array([2., 1.]) + >>> dense1.squared_distance(dense2) + 2.0 + >>> dense3 = [2., 1.] + >>> dense1.squared_distance(dense3) + 2.0 + >>> sparse1 = SparseVector(2, [0, 1], [2., 1.]) + >>> dense1.squared_distance(sparse1) + 2.0 + >>> dense1.squared_distance([1.,]) + Traceback (most recent call last): + ... + AssertionError: dimension mismatch + >>> dense1.squared_distance(SparseVector(1, [0,], [1.,])) + Traceback (most recent call last): + ... + AssertionError: dimension mismatch + """ + assert len(self) == _vector_size(other), "dimension mismatch" + if isinstance(other, SparseVector): + return other.squared_distance(self) + elif _have_scipy and scipy.sparse.issparse(other): + return _convert_to_vector(other).squared_distance(self) + + if isinstance(other, Vector): + other = other.toArray() + elif not isinstance(other, np.ndarray): + other = np.array(other) + diff = self.toArray() - other + return np.dot(diff, diff) + + def toArray(self): + """ + Returns an numpy.ndarray + """ + return self.array + + @property + def values(self): + """ + Returns a list of values + """ + return self.array + + def __getitem__(self, item): + return self.array[item] + + def __len__(self): + return len(self.array) + + def __str__(self): + return "[" + ",".join([str(v) for v in self.array]) + "]" + + def __repr__(self): + return "DenseVector([%s])" % (', '.join(_format_float(i) for i in self.array)) + + def __eq__(self, other): + if isinstance(other, DenseVector): + return np.array_equal(self.array, other.array) + elif isinstance(other, SparseVector): + if len(self) != other.size: + return False + return Vectors._equals(list(xrange(len(self))), self.array, other.indices, other.values) + return False + + def __ne__(self, other): + return not self == other + + def __hash__(self): + size = len(self) + result = 31 + size + nnz = 0 + i = 0 + while i < size and nnz < 128: + if self.array[i] != 0: + result = 31 * result + i + bits = _double_to_long_bits(self.array[i]) + result = 31 * result + (bits ^ (bits >> 32)) + nnz += 1 + i += 1 + return result + + def __getattr__(self, item): + return getattr(self.array, item) + + def _delegate(op): + def func(self, other): + if isinstance(other, DenseVector): + other = other.array + return DenseVector(getattr(self.array, op)(other)) + return func + + __neg__ = _delegate("__neg__") + __add__ = _delegate("__add__") + __sub__ = _delegate("__sub__") + __mul__ = _delegate("__mul__") + __div__ = _delegate("__div__") + __truediv__ = _delegate("__truediv__") + __mod__ = 
_delegate("__mod__") + __radd__ = _delegate("__radd__") + __rsub__ = _delegate("__rsub__") + __rmul__ = _delegate("__rmul__") + __rdiv__ = _delegate("__rdiv__") + __rtruediv__ = _delegate("__rtruediv__") + __rmod__ = _delegate("__rmod__") + + +class SparseVector(Vector): + """ + A simple sparse vector class for passing data to MLlib. Users may + alternatively pass SciPy's {scipy.sparse} data types. + """ + def __init__(self, size, *args): + """ + Create a sparse vector, using either a dictionary, a list of + (index, value) pairs, or two separate arrays of indices and + values (sorted by index). + + :param size: Size of the vector. + :param args: Active entries, as a dictionary {index: value, ...}, + a list of tuples [(index, value), ...], or a list of strictly + increasing indices and a list of corresponding values [index, ...], + [value, ...]. Inactive entries are treated as zeros. + + >>> SparseVector(4, {1: 1.0, 3: 5.5}) + SparseVector(4, {1: 1.0, 3: 5.5}) + >>> SparseVector(4, [(1, 1.0), (3, 5.5)]) + SparseVector(4, {1: 1.0, 3: 5.5}) + >>> SparseVector(4, [1, 3], [1.0, 5.5]) + SparseVector(4, {1: 1.0, 3: 5.5}) + """ + self.size = int(size) + """ Size of the vector. """ + assert 1 <= len(args) <= 2, "must pass either 2 or 3 arguments" + if len(args) == 1: + pairs = args[0] + if type(pairs) == dict: + pairs = pairs.items() + pairs = sorted(pairs) + self.indices = np.array([p[0] for p in pairs], dtype=np.int32) + """ A list of indices corresponding to active entries. """ + self.values = np.array([p[1] for p in pairs], dtype=np.float64) + """ A list of values corresponding to active entries. """ + else: + if isinstance(args[0], bytes): + assert isinstance(args[1], bytes), "values should be string too" + if args[0]: + self.indices = np.frombuffer(args[0], np.int32) + self.values = np.frombuffer(args[1], np.float64) + else: + # np.frombuffer() doesn't work well with empty string in older version + self.indices = np.array([], dtype=np.int32) + self.values = np.array([], dtype=np.float64) + else: + self.indices = np.array(args[0], dtype=np.int32) + self.values = np.array(args[1], dtype=np.float64) + assert len(self.indices) == len(self.values), "index and value arrays not same length" + for i in xrange(len(self.indices) - 1): + if self.indices[i] >= self.indices[i + 1]: + raise TypeError( + "Indices %s and %s are not strictly increasing" + % (self.indices[i], self.indices[i + 1])) + + def numNonzeros(self): + """ + Number of nonzero elements. This scans all active values and count non zeros. + """ + return np.count_nonzero(self.values) + + def norm(self, p): + """ + Calculates the norm of a SparseVector. + + >>> a = SparseVector(4, [0, 1], [3., -4.]) + >>> a.norm(1) + 7.0 + >>> a.norm(2) + 5.0 + """ + return np.linalg.norm(self.values, p) + + def __reduce__(self): + return ( + SparseVector, + (self.size, self.indices.tostring(), self.values.tostring())) + + def dot(self, other): + """ + Dot product with a SparseVector or 1- or 2-dimensional Numpy array. + + >>> a = SparseVector(4, [1, 3], [3.0, 4.0]) + >>> a.dot(a) + 25.0 + >>> a.dot(array.array('d', [1., 2., 3., 4.])) + 22.0 + >>> b = SparseVector(4, [2], [1.0]) + >>> a.dot(b) + 0.0 + >>> a.dot(np.array([[1, 1], [2, 2], [3, 3], [4, 4]])) + array([ 22., 22.]) + >>> a.dot([1., 2., 3.]) + Traceback (most recent call last): + ... + AssertionError: dimension mismatch + >>> a.dot(np.array([1., 2.])) + Traceback (most recent call last): + ... 
+ AssertionError: dimension mismatch + >>> a.dot(DenseVector([1., 2.])) + Traceback (most recent call last): + ... + AssertionError: dimension mismatch + >>> a.dot(np.zeros((3, 2))) + Traceback (most recent call last): + ... + AssertionError: dimension mismatch + """ + + if isinstance(other, np.ndarray): + if other.ndim not in [2, 1]: + raise ValueError("Cannot call dot with %d-dimensional array" % other.ndim) + assert len(self) == other.shape[0], "dimension mismatch" + return np.dot(self.values, other[self.indices]) + + assert len(self) == _vector_size(other), "dimension mismatch" + + if isinstance(other, DenseVector): + return np.dot(other.array[self.indices], self.values) + + elif isinstance(other, SparseVector): + # Find out common indices. + self_cmind = np.in1d(self.indices, other.indices, assume_unique=True) + self_values = self.values[self_cmind] + if self_values.size == 0: + return 0.0 + else: + other_cmind = np.in1d(other.indices, self.indices, assume_unique=True) + return np.dot(self_values, other.values[other_cmind]) + + else: + return self.dot(_convert_to_vector(other)) + + def squared_distance(self, other): + """ + Squared distance from a SparseVector or 1-dimensional NumPy array. + + >>> a = SparseVector(4, [1, 3], [3.0, 4.0]) + >>> a.squared_distance(a) + 0.0 + >>> a.squared_distance(array.array('d', [1., 2., 3., 4.])) + 11.0 + >>> a.squared_distance(np.array([1., 2., 3., 4.])) + 11.0 + >>> b = SparseVector(4, [2], [1.0]) + >>> a.squared_distance(b) + 26.0 + >>> b.squared_distance(a) + 26.0 + >>> b.squared_distance([1., 2.]) + Traceback (most recent call last): + ... + AssertionError: dimension mismatch + >>> b.squared_distance(SparseVector(3, [1,], [1.0,])) + Traceback (most recent call last): + ... + AssertionError: dimension mismatch + """ + assert len(self) == _vector_size(other), "dimension mismatch" + + if isinstance(other, np.ndarray) or isinstance(other, DenseVector): + if isinstance(other, np.ndarray) and other.ndim != 1: + raise Exception("Cannot call squared_distance with %d-dimensional array" % + other.ndim) + if isinstance(other, DenseVector): + other = other.array + sparse_ind = np.zeros(other.size, dtype=bool) + sparse_ind[self.indices] = True + dist = other[sparse_ind] - self.values + result = np.dot(dist, dist) + + other_ind = other[~sparse_ind] + result += np.dot(other_ind, other_ind) + return result + + elif isinstance(other, SparseVector): + result = 0.0 + i, j = 0, 0 + while i < len(self.indices) and j < len(other.indices): + if self.indices[i] == other.indices[j]: + diff = self.values[i] - other.values[j] + result += diff * diff + i += 1 + j += 1 + elif self.indices[i] < other.indices[j]: + result += self.values[i] * self.values[i] + i += 1 + else: + result += other.values[j] * other.values[j] + j += 1 + while i < len(self.indices): + result += self.values[i] * self.values[i] + i += 1 + while j < len(other.indices): + result += other.values[j] * other.values[j] + j += 1 + return result + else: + return self.squared_distance(_convert_to_vector(other)) + + def toArray(self): + """ + Returns a copy of this SparseVector as a 1-dimensional NumPy array. 
+ """ + arr = np.zeros((self.size,), dtype=np.float64) + arr[self.indices] = self.values + return arr + + def __len__(self): + return self.size + + def __str__(self): + inds = "[" + ",".join([str(i) for i in self.indices]) + "]" + vals = "[" + ",".join([str(v) for v in self.values]) + "]" + return "(" + ",".join((str(self.size), inds, vals)) + ")" + + def __repr__(self): + inds = self.indices + vals = self.values + entries = ", ".join(["{0}: {1}".format(inds[i], _format_float(vals[i])) + for i in xrange(len(inds))]) + return "SparseVector({0}, {{{1}}})".format(self.size, entries) + + def __eq__(self, other): + if isinstance(other, SparseVector): + return other.size == self.size and np.array_equal(other.indices, self.indices) \ + and np.array_equal(other.values, self.values) + elif isinstance(other, DenseVector): + if self.size != len(other): + return False + return Vectors._equals(self.indices, self.values, list(xrange(len(other))), other.array) + return False + + def __getitem__(self, index): + inds = self.indices + vals = self.values + if not isinstance(index, int): + raise TypeError( + "Indices must be of type integer, got type %s" % type(index)) + + if index >= self.size or index < -self.size: + raise IndexError("Index %d out of bounds." % index) + if index < 0: + index += self.size + + if (inds.size == 0) or (index > inds.item(-1)): + return 0. + + insert_index = np.searchsorted(inds, index) + row_ind = inds[insert_index] + if row_ind == index: + return vals[insert_index] + return 0. + + def __ne__(self, other): + return not self.__eq__(other) + + def __hash__(self): + result = 31 + self.size + nnz = 0 + i = 0 + while i < len(self.values) and nnz < 128: + if self.values[i] != 0: + result = 31 * result + int(self.indices[i]) + bits = _double_to_long_bits(self.values[i]) + result = 31 * result + (bits ^ (bits >> 32)) + nnz += 1 + i += 1 + return result + + +class Vectors(object): + + """ + Factory methods for working with vectors. Note that dense vectors + are simply represented as NumPy array objects, so there is no need + to covert them for use in MLlib. For sparse vectors, the factory + methods in this class create an MLlib-compatible type, or users + can pass in SciPy's C{scipy.sparse} column vectors. + """ + + @staticmethod + def sparse(size, *args): + """ + Create a sparse vector, using either a dictionary, a list of + (index, value) pairs, or two separate arrays of indices and + values (sorted by index). + + :param size: Size of the vector. + :param args: Non-zero entries, as a dictionary, list of tuples, + or two sorted lists containing indices and values. + + >>> Vectors.sparse(4, {1: 1.0, 3: 5.5}) + SparseVector(4, {1: 1.0, 3: 5.5}) + >>> Vectors.sparse(4, [(1, 1.0), (3, 5.5)]) + SparseVector(4, {1: 1.0, 3: 5.5}) + >>> Vectors.sparse(4, [1, 3], [1.0, 5.5]) + SparseVector(4, {1: 1.0, 3: 5.5}) + """ + return SparseVector(size, *args) + + @staticmethod + def dense(*elements): + """ + Create a dense vector of 64-bit floats from a Python list or numbers. + + >>> Vectors.dense([1, 2, 3]) + DenseVector([1.0, 2.0, 3.0]) + >>> Vectors.dense(1.0, 2.0) + DenseVector([1.0, 2.0]) + """ + if len(elements) == 1 and not isinstance(elements[0], (float, int, long)): + # it's list, numpy.array or other iterable object. + elements = elements[0] + return DenseVector(elements) + + @staticmethod + def squared_distance(v1, v2): + """ + Squared distance between two vectors. + a and b can be of type SparseVector, DenseVector, np.ndarray + or array.array. 
+ + >>> a = Vectors.sparse(4, [(0, 1), (3, 4)]) + >>> b = Vectors.dense([2, 5, 4, 1]) + >>> a.squared_distance(b) + 51.0 + """ + v1, v2 = _convert_to_vector(v1), _convert_to_vector(v2) + return v1.squared_distance(v2) + + @staticmethod + def norm(vector, p): + """ + Find norm of the given vector. + """ + return _convert_to_vector(vector).norm(p) + + @staticmethod + def zeros(size): + return DenseVector(np.zeros(size)) + + @staticmethod + def _equals(v1_indices, v1_values, v2_indices, v2_values): + """ + Check equality between sparse/dense vectors, + v1_indices and v2_indices assume to be strictly increasing. + """ + v1_size = len(v1_values) + v2_size = len(v2_values) + k1 = 0 + k2 = 0 + all_equal = True + while all_equal: + while k1 < v1_size and v1_values[k1] == 0: + k1 += 1 + while k2 < v2_size and v2_values[k2] == 0: + k2 += 1 + + if k1 >= v1_size or k2 >= v2_size: + return k1 >= v1_size and k2 >= v2_size + + all_equal = v1_indices[k1] == v2_indices[k2] and v1_values[k1] == v2_values[k2] + k1 += 1 + k2 += 1 + return all_equal + + +class Matrix(object): + + __UDT__ = MatrixUDT() + + """ + Represents a local matrix. + """ + def __init__(self, numRows, numCols, isTransposed=False): + self.numRows = numRows + self.numCols = numCols + self.isTransposed = isTransposed + + def toArray(self): + """ + Returns its elements in a NumPy ndarray. + """ + raise NotImplementedError + + @staticmethod + def _convert_to_array(array_like, dtype): + """ + Convert Matrix attributes which are array-like or buffer to array. + """ + if isinstance(array_like, bytes): + return np.frombuffer(array_like, dtype=dtype) + return np.asarray(array_like, dtype=dtype) + + +class DenseMatrix(Matrix): + """ + Column-major dense matrix. + """ + def __init__(self, numRows, numCols, values, isTransposed=False): + Matrix.__init__(self, numRows, numCols, isTransposed) + values = self._convert_to_array(values, np.float64) + assert len(values) == numRows * numCols + self.values = values + + def __reduce__(self): + return DenseMatrix, ( + self.numRows, self.numCols, self.values.tostring(), + int(self.isTransposed)) + + def __str__(self): + """ + Pretty printing of a DenseMatrix + + >>> dm = DenseMatrix(2, 2, range(4)) + >>> print(dm) + DenseMatrix([[ 0., 2.], + [ 1., 3.]]) + >>> dm = DenseMatrix(2, 2, range(4), isTransposed=True) + >>> print(dm) + DenseMatrix([[ 0., 1.], + [ 2., 3.]]) + """ + # Inspired by __repr__ in scipy matrices. + array_lines = repr(self.toArray()).splitlines() + + # We need to adjust six spaces which is the difference in number + # of letters between "DenseMatrix" and "array" + x = '\n'.join([(" " * 6 + line) for line in array_lines[1:]]) + return array_lines[0].replace("array", "DenseMatrix") + "\n" + x + + def __repr__(self): + """ + Representation of a DenseMatrix + + >>> dm = DenseMatrix(2, 2, range(4)) + >>> dm + DenseMatrix(2, 2, [0.0, 1.0, 2.0, 3.0], False) + """ + # If the number of values are less than seventeen then return as it is. + # Else return first eight values and last eight values. 
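# (e.g. a matrix holding 25 values is rendered as its first 8 values, an ellipsis,
# then its last 8 values)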
+ if len(self.values) < 17: + entries = _format_float_list(self.values) + else: + entries = ( + _format_float_list(self.values[:8]) + + ["..."] + + _format_float_list(self.values[-8:]) + ) + + entries = ", ".join(entries) + return "DenseMatrix({0}, {1}, [{2}], {3})".format( + self.numRows, self.numCols, entries, self.isTransposed) + + def toArray(self): + """ + Return an numpy.ndarray + + >>> m = DenseMatrix(2, 2, range(4)) + >>> m.toArray() + array([[ 0., 2.], + [ 1., 3.]]) + """ + if self.isTransposed: + return np.asfortranarray( + self.values.reshape((self.numRows, self.numCols))) + else: + return self.values.reshape((self.numRows, self.numCols), order='F') + + def toSparse(self): + """Convert to SparseMatrix""" + if self.isTransposed: + values = np.ravel(self.toArray(), order='F') + else: + values = self.values + indices = np.nonzero(values)[0] + colCounts = np.bincount(indices // self.numRows) + colPtrs = np.cumsum(np.hstack( + (0, colCounts, np.zeros(self.numCols - colCounts.size)))) + values = values[indices] + rowIndices = indices % self.numRows + + return SparseMatrix(self.numRows, self.numCols, colPtrs, rowIndices, values) + + def __getitem__(self, indices): + i, j = indices + if i < 0 or i >= self.numRows: + raise IndexError("Row index %d is out of range [0, %d)" + % (i, self.numRows)) + if j >= self.numCols or j < 0: + raise IndexError("Column index %d is out of range [0, %d)" + % (j, self.numCols)) + + if self.isTransposed: + return self.values[i * self.numCols + j] + else: + return self.values[i + j * self.numRows] + + def __eq__(self, other): + if (not isinstance(other, DenseMatrix) or + self.numRows != other.numRows or + self.numCols != other.numCols): + return False + + self_values = np.ravel(self.toArray(), order='F') + other_values = np.ravel(other.toArray(), order='F') + return all(self_values == other_values) + + +class SparseMatrix(Matrix): + """Sparse Matrix stored in CSC format.""" + def __init__(self, numRows, numCols, colPtrs, rowIndices, values, + isTransposed=False): + Matrix.__init__(self, numRows, numCols, isTransposed) + self.colPtrs = self._convert_to_array(colPtrs, np.int32) + self.rowIndices = self._convert_to_array(rowIndices, np.int32) + self.values = self._convert_to_array(values, np.float64) + + if self.isTransposed: + if self.colPtrs.size != numRows + 1: + raise ValueError("Expected colPtrs of size %d, got %d." + % (numRows + 1, self.colPtrs.size)) + else: + if self.colPtrs.size != numCols + 1: + raise ValueError("Expected colPtrs of size %d, got %d." + % (numCols + 1, self.colPtrs.size)) + if self.rowIndices.size != self.values.size: + raise ValueError("Expected rowIndices of length %d, got %d." + % (self.rowIndices.size, self.values.size)) + + def __str__(self): + """ + Pretty printing of a SparseMatrix + + >>> sm1 = SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4]) + >>> print(sm1) + 2 X 2 CSCMatrix + (0,0) 2.0 + (1,0) 3.0 + (1,1) 4.0 + >>> sm1 = SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4], True) + >>> print(sm1) + 2 X 2 CSRMatrix + (0,0) 2.0 + (0,1) 3.0 + (1,1) 4.0 + """ + spstr = "{0} X {1} ".format(self.numRows, self.numCols) + if self.isTransposed: + spstr += "CSRMatrix\n" + else: + spstr += "CSCMatrix\n" + + cur_col = 0 + smlist = [] + + # Display first 16 values. 
+ if len(self.values) <= 16: + zipindval = zip(self.rowIndices, self.values) + else: + zipindval = zip(self.rowIndices[:16], self.values[:16]) + for i, (rowInd, value) in enumerate(zipindval): + if self.colPtrs[cur_col + 1] <= i: + cur_col += 1 + if self.isTransposed: + smlist.append('({0},{1}) {2}'.format( + cur_col, rowInd, _format_float(value))) + else: + smlist.append('({0},{1}) {2}'.format( + rowInd, cur_col, _format_float(value))) + spstr += "\n".join(smlist) + + if len(self.values) > 16: + spstr += "\n.." * 2 + return spstr + + def __repr__(self): + """ + Representation of a SparseMatrix + + >>> sm1 = SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4]) + >>> sm1 + SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2.0, 3.0, 4.0], False) + """ + rowIndices = list(self.rowIndices) + colPtrs = list(self.colPtrs) + + if len(self.values) <= 16: + values = _format_float_list(self.values) + + else: + values = ( + _format_float_list(self.values[:8]) + + ["..."] + + _format_float_list(self.values[-8:]) + ) + rowIndices = rowIndices[:8] + ["..."] + rowIndices[-8:] + + if len(self.colPtrs) > 16: + colPtrs = colPtrs[:8] + ["..."] + colPtrs[-8:] + + values = ", ".join(values) + rowIndices = ", ".join([str(ind) for ind in rowIndices]) + colPtrs = ", ".join([str(ptr) for ptr in colPtrs]) + return "SparseMatrix({0}, {1}, [{2}], [{3}], [{4}], {5})".format( + self.numRows, self.numCols, colPtrs, rowIndices, + values, self.isTransposed) + + def __reduce__(self): + return SparseMatrix, ( + self.numRows, self.numCols, self.colPtrs.tostring(), + self.rowIndices.tostring(), self.values.tostring(), + int(self.isTransposed)) + + def __getitem__(self, indices): + i, j = indices + if i < 0 or i >= self.numRows: + raise IndexError("Row index %d is out of range [0, %d)" + % (i, self.numRows)) + if j < 0 or j >= self.numCols: + raise IndexError("Column index %d is out of range [0, %d)" + % (j, self.numCols)) + + # If a CSR matrix is given, then the row index should be searched + # for in ColPtrs, and the column index should be searched for in the + # corresponding slice obtained from rowIndices. 
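# Illustrative note (not part of the patch), using the sm1 example from the docstring above:
# for SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4]), column j's nonzero rows live in
# rowIndices[colPtrs[j]:colPtrs[j + 1]]. So sm1[1, 1] slices rowIndices[2:3] == [1], finds row 1
# there, and returns values[2] == 4.0, while sm1[0, 1] finds no entry for row 0 and returns 0.0.
# When isTransposed is True the same layout is read as CSR, which is why i and j are swapped below.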
+ if self.isTransposed: + j, i = i, j + + colStart = self.colPtrs[j] + colEnd = self.colPtrs[j + 1] + nz = self.rowIndices[colStart: colEnd] + ind = np.searchsorted(nz, i) + colStart + if ind < colEnd and self.rowIndices[ind] == i: + return self.values[ind] + else: + return 0.0 + + def toArray(self): + """ + Return an numpy.ndarray + """ + A = np.zeros((self.numRows, self.numCols), dtype=np.float64, order='F') + for k in xrange(self.colPtrs.size - 1): + startptr = self.colPtrs[k] + endptr = self.colPtrs[k + 1] + if self.isTransposed: + A[k, self.rowIndices[startptr:endptr]] = self.values[startptr:endptr] + else: + A[self.rowIndices[startptr:endptr], k] = self.values[startptr:endptr] + return A + + def toDense(self): + densevals = np.ravel(self.toArray(), order='F') + return DenseMatrix(self.numRows, self.numCols, densevals) + + # TODO: More efficient implementation: + def __eq__(self, other): + return np.all(self.toArray() == other.toArray()) + + +class Matrices(object): + @staticmethod + def dense(numRows, numCols, values): + """ + Create a DenseMatrix + """ + return DenseMatrix(numRows, numCols, values) + + @staticmethod + def sparse(numRows, numCols, colPtrs, rowIndices, values): + """ + Create a SparseMatrix + """ + return SparseMatrix(numRows, numCols, colPtrs, rowIndices, values) + + +def _test(): + import doctest + (failure_count, test_count) = doctest.testmod(optionflags=doctest.ELLIPSIS) + if failure_count: + exit(-1) + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index d9513ca5b273d..ade4864e1d785 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -29,8 +29,8 @@ from py4j.java_gateway import JavaObject from pyspark import since +from pyspark.ml.linalg import DenseVector, Vector from pyspark.ml.util import Identifiable -from pyspark.mllib.linalg import DenseVector, Vector __all__ = ['Param', 'Params', 'TypeConverters'] diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py index a2acf956bc2ad..c32dcc467d492 100644 --- a/python/pyspark/ml/param/_shared_params_code_gen.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -85,8 +85,7 @@ def set$Name(self, value): """ Sets the value of :py:attr:`$name`. """ - self._set($name=value) - return self + return self._set($name=value) def get$Name(self): """ @@ -125,12 +124,12 @@ def get$Name(self): "E.g. 10 means that the cache will get checkpointed every 10 iterations.", None, "TypeConverters.toInt"), ("seed", "random seed.", "hash(type(self).__name__)", "TypeConverters.toInt"), - ("tol", "the convergence tolerance for iterative algorithms.", None, + ("tol", "the convergence tolerance for iterative algorithms (>= 0).", None, "TypeConverters.toFloat"), - ("stepSize", "Step size to be used for each iteration of optimization.", None, + ("stepSize", "Step size to be used for each iteration of optimization (>= 0).", None, "TypeConverters.toFloat"), ("handleInvalid", "how to handle invalid entries. Options are skip (which will filter " + - "out rows with bad values), or error (which will throw an errror). More options may be " + + "out rows with bad values), or error (which will throw an error). More options may be " + "added later.", None, "TypeConverters.toString"), ("elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " + "the penalty is an L2 penalty. 
For alpha = 1, it is an L1 penalty.", "0.0", diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index 538c0b718ad9b..c5ccf81540d58 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -34,8 +34,7 @@ def setMaxIter(self, value): """ Sets the value of :py:attr:`maxIter`. """ - self._set(maxIter=value) - return self + return self._set(maxIter=value) def getMaxIter(self): """ @@ -58,8 +57,7 @@ def setRegParam(self, value): """ Sets the value of :py:attr:`regParam`. """ - self._set(regParam=value) - return self + return self._set(regParam=value) def getRegParam(self): """ @@ -83,8 +81,7 @@ def setFeaturesCol(self, value): """ Sets the value of :py:attr:`featuresCol`. """ - self._set(featuresCol=value) - return self + return self._set(featuresCol=value) def getFeaturesCol(self): """ @@ -108,8 +105,7 @@ def setLabelCol(self, value): """ Sets the value of :py:attr:`labelCol`. """ - self._set(labelCol=value) - return self + return self._set(labelCol=value) def getLabelCol(self): """ @@ -133,8 +129,7 @@ def setPredictionCol(self, value): """ Sets the value of :py:attr:`predictionCol`. """ - self._set(predictionCol=value) - return self + return self._set(predictionCol=value) def getPredictionCol(self): """ @@ -158,8 +153,7 @@ def setProbabilityCol(self, value): """ Sets the value of :py:attr:`probabilityCol`. """ - self._set(probabilityCol=value) - return self + return self._set(probabilityCol=value) def getProbabilityCol(self): """ @@ -183,8 +177,7 @@ def setRawPredictionCol(self, value): """ Sets the value of :py:attr:`rawPredictionCol`. """ - self._set(rawPredictionCol=value) - return self + return self._set(rawPredictionCol=value) def getRawPredictionCol(self): """ @@ -207,8 +200,7 @@ def setInputCol(self, value): """ Sets the value of :py:attr:`inputCol`. """ - self._set(inputCol=value) - return self + return self._set(inputCol=value) def getInputCol(self): """ @@ -231,8 +223,7 @@ def setInputCols(self, value): """ Sets the value of :py:attr:`inputCols`. """ - self._set(inputCols=value) - return self + return self._set(inputCols=value) def getInputCols(self): """ @@ -256,8 +247,7 @@ def setOutputCol(self, value): """ Sets the value of :py:attr:`outputCol`. """ - self._set(outputCol=value) - return self + return self._set(outputCol=value) def getOutputCol(self): """ @@ -280,8 +270,7 @@ def setNumFeatures(self, value): """ Sets the value of :py:attr:`numFeatures`. """ - self._set(numFeatures=value) - return self + return self._set(numFeatures=value) def getNumFeatures(self): """ @@ -304,8 +293,7 @@ def setCheckpointInterval(self, value): """ Sets the value of :py:attr:`checkpointInterval`. """ - self._set(checkpointInterval=value) - return self + return self._set(checkpointInterval=value) def getCheckpointInterval(self): """ @@ -329,8 +317,7 @@ def setSeed(self, value): """ Sets the value of :py:attr:`seed`. """ - self._set(seed=value) - return self + return self._set(seed=value) def getSeed(self): """ @@ -341,10 +328,10 @@ def getSeed(self): class HasTol(Params): """ - Mixin for param tol: the convergence tolerance for iterative algorithms. + Mixin for param tol: the convergence tolerance for iterative algorithms (>= 0). 
""" - tol = Param(Params._dummy(), "tol", "the convergence tolerance for iterative algorithms.", typeConverter=TypeConverters.toFloat) + tol = Param(Params._dummy(), "tol", "the convergence tolerance for iterative algorithms (>= 0).", typeConverter=TypeConverters.toFloat) def __init__(self): super(HasTol, self).__init__() @@ -353,8 +340,7 @@ def setTol(self, value): """ Sets the value of :py:attr:`tol`. """ - self._set(tol=value) - return self + return self._set(tol=value) def getTol(self): """ @@ -365,10 +351,10 @@ def getTol(self): class HasStepSize(Params): """ - Mixin for param stepSize: Step size to be used for each iteration of optimization. + Mixin for param stepSize: Step size to be used for each iteration of optimization (>= 0). """ - stepSize = Param(Params._dummy(), "stepSize", "Step size to be used for each iteration of optimization.", typeConverter=TypeConverters.toFloat) + stepSize = Param(Params._dummy(), "stepSize", "Step size to be used for each iteration of optimization (>= 0).", typeConverter=TypeConverters.toFloat) def __init__(self): super(HasStepSize, self).__init__() @@ -377,8 +363,7 @@ def setStepSize(self, value): """ Sets the value of :py:attr:`stepSize`. """ - self._set(stepSize=value) - return self + return self._set(stepSize=value) def getStepSize(self): """ @@ -389,10 +374,10 @@ def getStepSize(self): class HasHandleInvalid(Params): """ - Mixin for param handleInvalid: how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later. + Mixin for param handleInvalid: how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an error). More options may be added later. """ - handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later.", typeConverter=TypeConverters.toString) + handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an error). More options may be added later.", typeConverter=TypeConverters.toString) def __init__(self): super(HasHandleInvalid, self).__init__() @@ -401,8 +386,7 @@ def setHandleInvalid(self, value): """ Sets the value of :py:attr:`handleInvalid`. """ - self._set(handleInvalid=value) - return self + return self._set(handleInvalid=value) def getHandleInvalid(self): """ @@ -426,8 +410,7 @@ def setElasticNetParam(self, value): """ Sets the value of :py:attr:`elasticNetParam`. """ - self._set(elasticNetParam=value) - return self + return self._set(elasticNetParam=value) def getElasticNetParam(self): """ @@ -451,8 +434,7 @@ def setFitIntercept(self, value): """ Sets the value of :py:attr:`fitIntercept`. """ - self._set(fitIntercept=value) - return self + return self._set(fitIntercept=value) def getFitIntercept(self): """ @@ -476,8 +458,7 @@ def setStandardization(self, value): """ Sets the value of :py:attr:`standardization`. """ - self._set(standardization=value) - return self + return self._set(standardization=value) def getStandardization(self): """ @@ -500,8 +481,7 @@ def setThresholds(self, value): """ Sets the value of :py:attr:`thresholds`. 
""" - self._set(thresholds=value) - return self + return self._set(thresholds=value) def getThresholds(self): """ @@ -524,8 +504,7 @@ def setWeightCol(self, value): """ Sets the value of :py:attr:`weightCol`. """ - self._set(weightCol=value) - return self + return self._set(weightCol=value) def getWeightCol(self): """ @@ -549,8 +528,7 @@ def setSolver(self, value): """ Sets the value of :py:attr:`solver`. """ - self._set(solver=value) - return self + return self._set(solver=value) def getSolver(self): """ @@ -573,8 +551,7 @@ def setVarianceCol(self, value): """ Sets the value of :py:attr:`varianceCol`. """ - self._set(varianceCol=value) - return self + return self._set(varianceCol=value) def getVarianceCol(self): """ @@ -603,8 +580,7 @@ def setMaxDepth(self, value): """ Sets the value of :py:attr:`maxDepth`. """ - self._set(maxDepth=value) - return self + return self._set(maxDepth=value) def getMaxDepth(self): """ @@ -616,8 +592,7 @@ def setMaxBins(self, value): """ Sets the value of :py:attr:`maxBins`. """ - self._set(maxBins=value) - return self + return self._set(maxBins=value) def getMaxBins(self): """ @@ -629,8 +604,7 @@ def setMinInstancesPerNode(self, value): """ Sets the value of :py:attr:`minInstancesPerNode`. """ - self._set(minInstancesPerNode=value) - return self + return self._set(minInstancesPerNode=value) def getMinInstancesPerNode(self): """ @@ -642,8 +616,7 @@ def setMinInfoGain(self, value): """ Sets the value of :py:attr:`minInfoGain`. """ - self._set(minInfoGain=value) - return self + return self._set(minInfoGain=value) def getMinInfoGain(self): """ @@ -655,8 +628,7 @@ def setMaxMemoryInMB(self, value): """ Sets the value of :py:attr:`maxMemoryInMB`. """ - self._set(maxMemoryInMB=value) - return self + return self._set(maxMemoryInMB=value) def getMaxMemoryInMB(self): """ @@ -668,8 +640,7 @@ def setCacheNodeIds(self, value): """ Sets the value of :py:attr:`cacheNodeIds`. """ - self._set(cacheNodeIds=value) - return self + return self._set(cacheNodeIds=value) def getCacheNodeIds(self): """ diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index 146e403a8f97b..a48f4bb2ad1ba 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -25,7 +25,7 @@ from pyspark.ml.param import Param, Params from pyspark.ml.util import JavaMLWriter, JavaMLReader, MLReadable, MLWritable from pyspark.ml.wrapper import JavaParams -from pyspark.mllib.common import inherit_doc +from pyspark.ml.common import inherit_doc @inherit_doc @@ -42,7 +42,7 @@ class Pipeline(Estimator, MLReadable, MLWritable): stage. If a stage is a :py:class:`Transformer`, its :py:meth:`Transformer.transform` method will be called to produce the dataset for the next stage. The fitted model from a - :py:class:`Pipeline` is an :py:class:`PipelineModel`, which + :py:class:`Pipeline` is a :py:class:`PipelineModel`, which consists of fitted models and transformers, corresponding to the pipeline stages. If there are no stages, the pipeline acts as an identity transformer. 
@@ -71,8 +71,7 @@ def setStages(self, value): :param value: a list of transformers or estimators :return: the pipeline instance """ - self._set(stages=value) - return self + return self._set(stages=value) @since("1.3.0") def getStages(self): diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py index 08770d9981277..e28d38bd19f80 100644 --- a/python/pyspark/ml/recommendation.py +++ b/python/pyspark/ml/recommendation.py @@ -19,7 +19,7 @@ from pyspark.ml.util import * from pyspark.ml.wrapper import JavaEstimator, JavaModel from pyspark.ml.param.shared import * -from pyspark.mllib.common import inherit_doc +from pyspark.ml.common import inherit_doc __all__ = ['ALS', 'ALSModel'] @@ -54,8 +54,8 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha and update the products based on these messages. For implicit preference data, the algorithm used is based on - "Collaborative Filtering for Implicit Feedback Datasets", available - at `http://dx.doi.org/10.1109/ICDM.2008.22`, adapted for the blocked + `"Collaborative Filtering for Implicit Feedback Datasets", + <http://dx.doi.org/10.1109/ICDM.2008.22>`_, adapted for the blocked approach used here. Essentially instead of finding the low-rank approximations to the @@ -65,16 +65,16 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha indicated user preferences rather than explicit ratings given to items. - >>> df = sqlContext.createDataFrame( + >>> df = spark.createDataFrame( ... [(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0), (2, 1, 1.0), (2, 2, 5.0)], ... ["user", "item", "rating"]) - >>> als = ALS(rank=10, maxIter=5) + >>> als = ALS(rank=10, maxIter=5, seed=0) >>> model = als.fit(df) >>> model.rank 10 >>> model.userFactors.orderBy("id").collect() [Row(id=0, features=[...]), Row(id=1, ...), Row(id=2, ...)] - >>> test = sqlContext.createDataFrame([(0, 2), (1, 0), (2, 0)], ["user", "item"]) + >>> test = spark.createDataFrame([(0, 2), (1, 0), (2, 0)], ["user", "item"]) >>> predictions = sorted(model.transform(test).collect(), key=lambda r: r[0]) >>> predictions[0] Row(user=0, item=2, prediction=-0.13807615637779236) @@ -110,22 +110,20 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha typeConverter=TypeConverters.toBoolean) alpha = Param(Params._dummy(), "alpha", "alpha for implicit preference", typeConverter=TypeConverters.toFloat) - userCol = Param(Params._dummy(), "userCol", "column name for user ids", - typeConverter=TypeConverters.toString) - itemCol = Param(Params._dummy(), "itemCol", "column name for item ids", - typeConverter=TypeConverters.toString) + userCol = Param(Params._dummy(), "userCol", "column name for user ids. Ids must be within " + + "the integer value range.", typeConverter=TypeConverters.toString) + itemCol = Param(Params._dummy(), "itemCol", "column name for item ids. Ids must be within " + + "the integer value range.", typeConverter=TypeConverters.toString) ratingCol = Param(Params._dummy(), "ratingCol", "column name for ratings", typeConverter=TypeConverters.toString) nonnegative = Param(Params._dummy(), "nonnegative", "whether to use nonnegative constraint for least squares", typeConverter=TypeConverters.toBoolean) intermediateStorageLevel = Param(Params._dummy(), "intermediateStorageLevel", - "StorageLevel for intermediate datasets. Cannot be 'NONE'. " + - "Default: 'MEMORY_AND_DISK'.", + "StorageLevel for intermediate datasets. 
Cannot be 'NONE'.", typeConverter=TypeConverters.toString) finalStorageLevel = Param(Params._dummy(), "finalStorageLevel", - "StorageLevel for ALS model factors. " + - "Default: 'MEMORY_AND_DISK'.", + "StorageLevel for ALS model factors.", typeConverter=TypeConverters.toString) @keyword_only @@ -144,7 +142,7 @@ def __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemB super(ALS, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.recommendation.ALS", self.uid) self._setDefault(rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, - implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None, + implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", ratingCol="rating", nonnegative=False, checkpointInterval=10, intermediateStorageLevel="MEMORY_AND_DISK", finalStorageLevel="MEMORY_AND_DISK") @@ -177,8 +175,7 @@ def setRank(self, value): """ Sets the value of :py:attr:`rank`. """ - self._set(rank=value) - return self + return self._set(rank=value) @since("1.4.0") def getRank(self): @@ -192,8 +189,7 @@ def setNumUserBlocks(self, value): """ Sets the value of :py:attr:`numUserBlocks`. """ - self._set(numUserBlocks=value) - return self + return self._set(numUserBlocks=value) @since("1.4.0") def getNumUserBlocks(self): @@ -207,8 +203,7 @@ def setNumItemBlocks(self, value): """ Sets the value of :py:attr:`numItemBlocks`. """ - self._set(numItemBlocks=value) - return self + return self._set(numItemBlocks=value) @since("1.4.0") def getNumItemBlocks(self): @@ -223,15 +218,14 @@ def setNumBlocks(self, value): Sets both :py:attr:`numUserBlocks` and :py:attr:`numItemBlocks` to the specific value. """ self._set(numUserBlocks=value) - self._set(numItemBlocks=value) + return self._set(numItemBlocks=value) @since("1.4.0") def setImplicitPrefs(self, value): """ Sets the value of :py:attr:`implicitPrefs`. """ - self._set(implicitPrefs=value) - return self + return self._set(implicitPrefs=value) @since("1.4.0") def getImplicitPrefs(self): @@ -245,8 +239,7 @@ def setAlpha(self, value): """ Sets the value of :py:attr:`alpha`. """ - self._set(alpha=value) - return self + return self._set(alpha=value) @since("1.4.0") def getAlpha(self): @@ -260,8 +253,7 @@ def setUserCol(self, value): """ Sets the value of :py:attr:`userCol`. """ - self._set(userCol=value) - return self + return self._set(userCol=value) @since("1.4.0") def getUserCol(self): @@ -275,8 +267,7 @@ def setItemCol(self, value): """ Sets the value of :py:attr:`itemCol`. """ - self._set(itemCol=value) - return self + return self._set(itemCol=value) @since("1.4.0") def getItemCol(self): @@ -290,8 +281,7 @@ def setRatingCol(self, value): """ Sets the value of :py:attr:`ratingCol`. """ - self._set(ratingCol=value) - return self + return self._set(ratingCol=value) @since("1.4.0") def getRatingCol(self): @@ -305,8 +295,7 @@ def setNonnegative(self, value): """ Sets the value of :py:attr:`nonnegative`. """ - self._set(nonnegative=value) - return self + return self._set(nonnegative=value) @since("1.4.0") def getNonnegative(self): @@ -320,8 +309,7 @@ def setIntermediateStorageLevel(self, value): """ Sets the value of :py:attr:`intermediateStorageLevel`. """ - self._set(intermediateStorageLevel=value) - return self + return self._set(intermediateStorageLevel=value) @since("2.0.0") def getIntermediateStorageLevel(self): @@ -335,8 +323,7 @@ def setFinalStorageLevel(self, value): """ Sets the value of :py:attr:`finalStorageLevel`. 
""" - self._set(finalStorageLevel=value) - return self + return self._set(finalStorageLevel=value) @since("2.0.0") def getFinalStorageLevel(self): @@ -381,21 +368,23 @@ def itemFactors(self): if __name__ == "__main__": import doctest import pyspark.ml.recommendation - from pyspark.context import SparkContext - from pyspark.sql import SQLContext + from pyspark.sql import SparkSession globs = pyspark.ml.recommendation.__dict__.copy() # The small batch size here ensures that we see multiple batches, # even in these small test examples: - sc = SparkContext("local[2]", "ml.recommendation tests") - sqlContext = SQLContext(sc) + spark = SparkSession.builder\ + .master("local[2]")\ + .appName("ml.recommendation tests")\ + .getOrCreate() + sc = spark.sparkContext globs['sc'] = sc - globs['sqlContext'] = sqlContext + globs['spark'] = spark import tempfile temp_path = tempfile.mkdtemp() globs['temp_path'] = temp_path try: (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) - sc.stop() + spark.stop() finally: from shutil import rmtree try: diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 0f08f9b9737e2..d88dc75353598 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -21,7 +21,7 @@ from pyspark.ml.param.shared import * from pyspark.ml.util import * from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaWrapper -from pyspark.mllib.common import inherit_doc +from pyspark.ml.common import inherit_doc from pyspark.sql import DataFrame @@ -29,6 +29,7 @@ 'DecisionTreeRegressor', 'DecisionTreeRegressionModel', 'GBTRegressor', 'GBTRegressionModel', 'GeneralizedLinearRegression', 'GeneralizedLinearRegressionModel', + 'GeneralizedLinearRegressionSummary', 'GeneralizedLinearRegressionTrainingSummary', 'IsotonicRegression', 'IsotonicRegressionModel', 'LinearRegression', 'LinearRegressionModel', 'LinearRegressionSummary', 'LinearRegressionTrainingSummary', @@ -45,26 +46,30 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction The learning objective is to minimize the squared error, with regularization. The specific squared error loss function used is: L = 1/2n ||A coefficients - y||^2^ - This support multiple types of regularization: - - none (a.k.a. ordinary least squares) - - L2 (ridge regression) - - L1 (Lasso) - - L2 + L1 (elastic net) + This supports multiple types of regularization: - >>> from pyspark.mllib.linalg import Vectors - >>> df = sqlContext.createDataFrame([ + * none (a.k.a. ordinary least squares) + + * L2 (ridge regression) + + * L1 (Lasso) + + * L2 + L1 (elastic net) + + >>> from pyspark.ml.linalg import Vectors + >>> df = spark.createDataFrame([ ... (1.0, 2.0, Vectors.dense(1.0)), ... 
(0.0, 2.0, Vectors.sparse(1, [], []))], ["label", "weight", "features"]) >>> lr = LinearRegression(maxIter=5, regParam=0.0, solver="normal", weightCol="weight") >>> model = lr.fit(df) - >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) + >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) >>> abs(model.transform(test0).head().prediction - (-1.0)) < 0.001 True >>> abs(model.coefficients[0] - 1.0) < 0.001 True >>> abs(model.intercept - 0.0) < 0.001 True - >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) + >>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> abs(model.transform(test1).head().prediction - 1.0) < 0.001 True >>> lr.setParams("vector") @@ -123,13 +128,13 @@ def _create_model(self, java_model): class LinearRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable): """ - Model fitted by LinearRegression. + Model fitted by :class:`LinearRegression`. .. versionadded:: 1.4.0 """ @property - @since("1.6.0") + @since("2.0.0") def coefficients(self): """ Model coefficients. @@ -229,7 +234,9 @@ def explainedVariance(self): """ Returns the explained variance regression score. explainedVariance = 1 - variance(y - \hat{y}) / variance(y) - Reference: http://en.wikipedia.org/wiki/Explained_variation + + .. seealso:: `Wikipedia explain variation \ + <http://en.wikipedia.org/wiki/Explained_variation>`_ Note: This ignores instance weights (setting all to 1.0) from `LinearRegression.weightCol`. This will change in later Spark @@ -283,7 +290,9 @@ def rootMeanSquaredError(self): def r2(self): """ Returns R^2^, the coefficient of determination. - Reference: http://en.wikipedia.org/wiki/Coefficient_of_determination + + .. seealso:: `Wikipedia coefficient of determination \ + <http://en.wikipedia.org/wiki/Coefficient_of_determination>` Note: This ignores instance weights (setting all to 1.0) from `LinearRegression.weightCol`. This will change in later Spark @@ -398,18 +407,16 @@ def totalIterations(self): class IsotonicRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasWeightCol, JavaMLWritable, JavaMLReadable): """ - .. note:: Experimental - Currently implemented using parallelized pool adjacent violators algorithm. Only univariate (single feature) algorithm supported. - >>> from pyspark.mllib.linalg import Vectors - >>> df = sqlContext.createDataFrame([ + >>> from pyspark.ml.linalg import Vectors + >>> df = spark.createDataFrame([ ... (1.0, Vectors.dense(1.0)), ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"]) >>> ir = IsotonicRegression() >>> model = ir.fit(df) - >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) + >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) >>> model.transform(test0).head().prediction 0.0 >>> model.boundaries @@ -426,6 +433,8 @@ class IsotonicRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti True >>> model.predictions == model2.predictions True + + .. versionadded:: 1.6.0 """ isotonic = \ @@ -469,8 +478,7 @@ def setIsotonic(self, value): """ Sets the value of :py:attr:`isotonic`. """ - self._set(isotonic=value) - return self + return self._set(isotonic=value) def getIsotonic(self): """ @@ -482,8 +490,7 @@ def setFeatureIndex(self, value): """ Sets the value of :py:attr:`featureIndex`. """ - self._set(featureIndex=value) - return self + return self._set(featureIndex=value) def getFeatureIndex(self): """ @@ -494,19 +501,21 @@ def getFeatureIndex(self): class IsotonicRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable): """ - .. 
note:: Experimental + Model fitted by :class:`IsotonicRegression`. - Model fitted by IsotonicRegression. + .. versionadded:: 1.6.0 """ @property + @since("1.6.0") def boundaries(self): """ - Model boundaries. + Boundaries in increasing order for which predictions are known. """ return self._call_java("boundaries") @property + @since("1.6.0") def predictions(self): """ Predictions associated with the boundaries at the same index, monotone because of isotonic @@ -532,8 +541,7 @@ def setSubsamplingRate(self, value): """ Sets the value of :py:attr:`subsamplingRate`. """ - self._set(subsamplingRate=value) - return self + return self._set(subsamplingRate=value) @since("1.4.0") def getSubsamplingRate(self): @@ -562,8 +570,7 @@ def setImpurity(self, value): """ Sets the value of :py:attr:`impurity`. """ - self._set(impurity=value) - return self + return self._set(impurity=value) @since("1.4.0") def getImpurity(self): @@ -584,7 +591,7 @@ class RandomForestParams(TreeEnsembleParams): featureSubsetStrategy = \ Param(Params._dummy(), "featureSubsetStrategy", "The number of features to consider for splits at each tree node. Supported " + - "options: " + ", ".join(supportedFeatureSubsetStrategies), + "options: " + ", ".join(supportedFeatureSubsetStrategies) + " (0.0-1.0], [1-n].", typeConverter=TypeConverters.toString) def __init__(self): @@ -595,8 +602,7 @@ def setNumTrees(self, value): """ Sets the value of :py:attr:`numTrees`. """ - self._set(numTrees=value) - return self + return self._set(numTrees=value) @since("1.4.0") def getNumTrees(self): @@ -610,8 +616,7 @@ def setFeatureSubsetStrategy(self, value): """ Sets the value of :py:attr:`featureSubsetStrategy`. """ - self._set(featureSubsetStrategy=value) - return self + return self._set(featureSubsetStrategy=value) @since("1.4.0") def getFeatureSubsetStrategy(self): @@ -633,12 +638,12 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi DecisionTreeParams, TreeRegressorParams, HasCheckpointInterval, HasSeed, JavaMLWritable, JavaMLReadable, HasVarianceCol): """ - `http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree` + `Decision tree <http://en.wikipedia.org/wiki/Decision_tree_learning>`_ learning algorithm for regression. It supports both continuous and categorical features. - >>> from pyspark.mllib.linalg import Vectors - >>> df = sqlContext.createDataFrame([ + >>> from pyspark.ml.linalg import Vectors + >>> df = spark.createDataFrame([ ... (1.0, Vectors.dense(1.0)), ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"]) >>> dt = DecisionTreeRegressor(maxDepth=2, varianceCol="variance") @@ -649,10 +654,10 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi 3 >>> model.featureImportances SparseVector(1, {0: 1.0}) - >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) + >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) >>> model.transform(test0).head().prediction 0.0 - >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) + >>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 1.0 >>> dtr_path = temp_path + "/dtr" @@ -715,7 +720,8 @@ def _create_model(self, java_model): @inherit_doc class DecisionTreeModel(JavaModel): - """Abstraction for Decision Tree models. + """ + Abstraction for Decision Tree models. .. 
versionadded:: 1.5.0 """ @@ -732,23 +738,54 @@ def depth(self): """Return depth of the decision tree.""" return self._call_java("depth") + @property + @since("2.0.0") + def toDebugString(self): + """Full description of model.""" + return self._call_java("toDebugString") + def __repr__(self): return self._call_java("toString") @inherit_doc class TreeEnsembleModels(JavaModel): - """Represents a tree ensemble model. + """ + (private abstraction) - .. versionadded:: 1.5.0 + Represents a tree ensemble model. """ + @property + @since("2.0.0") + def trees(self): + """Trees in this ensemble. Warning: These have null parent Estimators.""" + return [DecisionTreeModel(m) for m in list(self._call_java("trees"))] + + @property + @since("2.0.0") + def getNumTrees(self): + """Number of trees in ensemble.""" + return self._call_java("getNumTrees") + @property @since("1.5.0") def treeWeights(self): """Return the weights for each tree""" return list(self._call_java("javaTreeWeights")) + @property + @since("2.0.0") + def totalNumNodes(self): + """Total number of nodes, summed over all trees in the ensemble.""" + return self._call_java("totalNumNodes") + + @property + @since("2.0.0") + def toDebugString(self): + """Full description of model.""" + return self._call_java("toDebugString") + def __repr__(self): return self._call_java("toString") @@ -756,7 +793,7 @@ def __repr__(self): @inherit_doc class DecisionTreeRegressionModel(DecisionTreeModel, JavaMLWritable, JavaMLReadable): """ - Model fitted by DecisionTreeRegressor. + Model fitted by :class:`DecisionTreeRegressor`. .. versionadded:: 1.4.0 """ @@ -788,13 +825,13 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi RandomForestParams, TreeRegressorParams, HasCheckpointInterval, JavaMLWritable, JavaMLReadable): """ - `http://en.wikipedia.org/wiki/Random_forest Random Forest` + `Random Forest <http://en.wikipedia.org/wiki/Random_forest>`_ learning algorithm for regression. It supports both continuous and categorical features. >>> from numpy import allclose - >>> from pyspark.mllib.linalg import Vectors - >>> df = sqlContext.createDataFrame([ + >>> from pyspark.ml.linalg import Vectors + >>> df = spark.createDataFrame([ ... (1.0, Vectors.dense(1.0)), ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"]) >>> rf = RandomForestRegressor(numTrees=2, maxDepth=2, seed=42) @@ -803,10 +840,14 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi SparseVector(1, {0: 1.0}) >>> allclose(model.treeWeights, [1.0, 1.0]) True - >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) + >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) >>> model.transform(test0).head().prediction 0.0 + >>> model.trees + [DecisionTreeRegressionModel (uid=...) of depth..., DecisionTreeRegressionModel...] 
+ >>> model.getNumTrees + 2 + >>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 0.5 >>> rfr_path = temp_path + "/rfr" @@ -841,7 +882,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred "org.apache.spark.ml.regression.RandomForestRegressor", self.uid) self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, - impurity="variance", subsamplingRate=1.0, seed=None, numTrees=20, + impurity="variance", subsamplingRate=1.0, numTrees=20, featureSubsetStrategy="auto") kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -870,11 +911,17 @@ def _create_model(self, java_model): class RandomForestRegressionModel(TreeEnsembleModels, JavaMLWritable, JavaMLReadable): """ - Model fitted by RandomForestRegressor. + Model fitted by :class:`RandomForestRegressor`. .. versionadded:: 1.4.0 """ + @property + @since("2.0.0") + def trees(self): + """Trees in this ensemble. Warning: These have null parent Estimators.""" + return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))] + @property @since("2.0.0") def featureImportances(self): @@ -894,27 +941,29 @@ def featureImportances(self): @inherit_doc class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, GBTParams, HasCheckpointInterval, HasStepSize, HasSeed, JavaMLWritable, - JavaMLReadable): + JavaMLReadable, TreeRegressorParams): """ - `http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)` + `Gradient-Boosted Trees (GBTs) <http://en.wikipedia.org/wiki/Gradient_boosting>`_ learning algorithm for regression. It supports both continuous and categorical features. >>> from numpy import allclose - >>> from pyspark.mllib.linalg import Vectors - >>> df = sqlContext.createDataFrame([ + >>> from pyspark.ml.linalg import Vectors + >>> df = spark.createDataFrame([ ... (1.0, Vectors.dense(1.0)), ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"]) >>> gbt = GBTRegressor(maxIter=5, maxDepth=2, seed=42) + >>> print(gbt.getImpurity()) + variance >>> model = gbt.fit(df) >>> model.featureImportances SparseVector(1, {0: 1.0}) >>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1]) True - >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) + >>> test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) >>> model.transform(test0).head().prediction 0.0 - >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) + >>> test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 1.0 >>> gbtr_path = temp_path + "gbtr" @@ -929,6 +978,8 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, True >>> model.treeWeights == model2.treeWeights True + >>> model.trees + [DecisionTreeRegressionModel (uid=...) of depth..., DecisionTreeRegressionModel...] .. 
versionadded:: 1.4.0 """ @@ -942,19 +993,21 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, - checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None): + checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None, + impurity="variance"): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, \ - checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None) + checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None, \ + impurity="variance") """ super(GBTRegressor, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.regression.GBTRegressor", self.uid) self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, - seed=None) + impurity="variance") kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -963,12 +1016,14 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, - checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None): + checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None, + impurity="variance"): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, \ - checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None) + checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None, \ + impurity="variance") Sets params for Gradient Boosted Tree Regression. """ kwargs = self.setParams._input_kwargs @@ -982,8 +1037,7 @@ def setLossType(self, value): """ Sets the value of :py:attr:`lossType`. """ - self._set(lossType=value) - return self + return self._set(lossType=value) @since("1.4.0") def getLossType(self): @@ -995,7 +1049,7 @@ def getLossType(self): class GBTRegressionModel(TreeEnsembleModels, JavaMLWritable, JavaMLReadable): """ - Model fitted by GBTRegressor. + Model fitted by :class:`GBTRegressor`. .. versionadded:: 1.4.0 """ @@ -1015,11 +1069,19 @@ def featureImportances(self): """ return self._call_java("featureImportances") + @property + @since("2.0.0") + def trees(self): + """Trees in this ensemble. Warning: These have null parent Estimators.""" + return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))] + @inherit_doc class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasFitIntercept, HasMaxIter, HasTol, JavaMLWritable, JavaMLReadable): """ + .. 
note:: Experimental + Accelerated Failure Time (AFT) Model Survival Regression Fit a parametric AFT survival regression model based on the Weibull distribution @@ -1027,8 +1089,8 @@ class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi .. seealso:: `AFT Model `_ - >>> from pyspark.mllib.linalg import Vectors - >>> df = sqlContext.createDataFrame([ + >>> from pyspark.ml.linalg import Vectors + >>> df = spark.createDataFrame([ ... (1.0, Vectors.dense(1.0), 1.0), ... (0.0, Vectors.sparse(1, [], []), 0.0)], ["label", "features", "censor"]) >>> aftsr = AFTSurvivalRegression() @@ -1120,8 +1182,7 @@ def setCensorCol(self, value): """ Sets the value of :py:attr:`censorCol`. """ - self._set(censorCol=value) - return self + return self._set(censorCol=value) @since("1.6.0") def getCensorCol(self): @@ -1135,8 +1196,7 @@ def setQuantileProbabilities(self, value): """ Sets the value of :py:attr:`quantileProbabilities`. """ - self._set(quantileProbabilities=value) - return self + return self._set(quantileProbabilities=value) @since("1.6.0") def getQuantileProbabilities(self): @@ -1150,8 +1210,7 @@ def setQuantilesCol(self, value): """ Sets the value of :py:attr:`quantilesCol`. """ - self._set(quantilesCol=value) - return self + return self._set(quantilesCol=value) @since("1.6.0") def getQuantilesCol(self): @@ -1163,13 +1222,15 @@ def getQuantilesCol(self): class AFTSurvivalRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable): """ - Model fitted by AFTSurvivalRegression. + .. note:: Experimental + + Model fitted by :class:`AFTSurvivalRegression`. .. versionadded:: 1.6.0 """ @property - @since("1.6.0") + @since("2.0.0") def coefficients(self): """ Model coefficients. @@ -1192,12 +1253,14 @@ def scale(self): """ return self._call_java("scale") + @since("2.0.0") def predictQuantiles(self, features): """ Predicted Quantiles """ return self._call_java("predictQuantiles", features) + @since("2.0.0") def predict(self, features): """ Predicted value @@ -1210,28 +1273,37 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha HasFitIntercept, HasMaxIter, HasTol, HasRegParam, HasWeightCol, HasSolver, JavaMLWritable, JavaMLReadable): """ + .. note:: Experimental + Generalized Linear Regression. Fit a Generalized Linear Model specified by giving a symbolic description of the linear predictor (link function) and a description of the error distribution (family). It supports "gaussian", "binomial", "poisson" and "gamma" as family. Valid link functions for each family is listed below. The first link function of each family is the default one. - - "gaussian" -> "identity", "log", "inverse" - - "binomial" -> "logit", "probit", "cloglog" - - "poisson" -> "log", "identity", "sqrt" - - "gamma" -> "inverse", "identity", "log" + + * "gaussian" -> "identity", "log", "inverse" + + * "binomial" -> "logit", "probit", "cloglog" + + * "poisson" -> "log", "identity", "sqrt" + + * "gamma" -> "inverse", "identity", "log" .. seealso:: `GLM `_ - >>> from pyspark.mllib.linalg import Vectors - >>> df = sqlContext.createDataFrame([ + >>> from pyspark.ml.linalg import Vectors + >>> df = spark.createDataFrame([ ... (1.0, Vectors.dense(0.0, 0.0)), ... (1.0, Vectors.dense(1.0, 2.0)), ... (2.0, Vectors.dense(0.0, 0.0)), ... 
(2.0, Vectors.dense(1.0, 1.0)),], ["label", "features"]) - >>> glr = GeneralizedLinearRegression(family="gaussian", link="identity") + >>> glr = GeneralizedLinearRegression(family="gaussian", link="identity", linkPredictionCol="p") >>> model = glr.fit(df) - >>> abs(model.transform(df).head().prediction - 1.5) < 0.001 + >>> transformed = model.transform(df) + >>> abs(transformed.head().prediction - 1.5) < 0.001 + True + >>> abs(transformed.head().p - 1.5) < 0.001 True >>> model.coefficients DenseVector([1.5..., -1.0...]) @@ -1255,21 +1327,23 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha family = Param(Params._dummy(), "family", "The name of family which is a description of " + "the error distribution to be used in the model. Supported options: " + - "gaussian(default), binomial, poisson and gamma.", + "gaussian (default), binomial, poisson and gamma.", typeConverter=TypeConverters.toString) link = Param(Params._dummy(), "link", "The name of link function which provides the " + "relationship between the linear predictor and the mean of the distribution " + "function. Supported options: identity, log, inverse, logit, probit, cloglog " + "and sqrt.", typeConverter=TypeConverters.toString) + linkPredictionCol = Param(Params._dummy(), "linkPredictionCol", "link prediction (linear " + + "predictor) column name", typeConverter=TypeConverters.toString) @keyword_only def __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction", family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, - regParam=0.0, weightCol=None, solver="irls"): + regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None): """ __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction", \ family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \ - regParam=0.0, weightCol=None, solver="irls") + regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None) """ super(GeneralizedLinearRegression, self).__init__() self._java_obj = self._new_java_obj( @@ -1282,11 +1356,11 @@ def __init__(self, labelCol="label", featuresCol="features", predictionCol="pred @since("2.0.0") def setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction", family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, - regParam=0.0, weightCol=None, solver="irls"): + regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None): """ setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction", \ family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \ - regParam=0.0, weightCol=None, solver="irls") + regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None) Sets params for generalized linear regression. """ kwargs = self.setParams._input_kwargs @@ -1300,8 +1374,7 @@ def setFamily(self, value): """ Sets the value of :py:attr:`family`. """ - self._set(family=value) - return self + return self._set(family=value) @since("2.0.0") def getFamily(self): @@ -1310,13 +1383,26 @@ def getFamily(self): """ return self.getOrDefault(self.family) + @since("2.0.0") + def setLinkPredictionCol(self, value): + """ + Sets the value of :py:attr:`linkPredictionCol`. + """ + return self._set(linkPredictionCol=value) + + @since("2.0.0") + def getLinkPredictionCol(self): + """ + Gets the value of linkPredictionCol or its default value. 
+ """ + return self.getOrDefault(self.linkPredictionCol) + @since("2.0.0") def setLink(self, value): """ Sets the value of :py:attr:`link`. """ - self._set(link=value) - return self + return self._set(link=value) @since("2.0.0") def getLink(self): @@ -1328,7 +1414,9 @@ def getLink(self): class GeneralizedLinearRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable): """ - Model fitted by GeneralizedLinearRegression. + .. note:: Experimental + + Model fitted by :class:`GeneralizedLinearRegression`. .. versionadded:: 2.0.0 """ @@ -1349,25 +1437,225 @@ def intercept(self): """ return self._call_java("intercept") + @property + @since("2.0.0") + def summary(self): + """ + Gets summary (e.g. residuals, deviance, pValues) of model on + training set. An exception is thrown if + `trainingSummary is None`. + """ + java_glrt_summary = self._call_java("summary") + return GeneralizedLinearRegressionTrainingSummary(java_glrt_summary) + + @property + @since("2.0.0") + def hasSummary(self): + """ + Indicates whether a training summary exists for this model + instance. + """ + return self._call_java("hasSummary") + + @since("2.0.0") + def evaluate(self, dataset): + """ + Evaluates the model on a test dataset. + + :param dataset: + Test dataset to evaluate model on, where dataset is an + instance of :py:class:`pyspark.sql.DataFrame` + """ + if not isinstance(dataset, DataFrame): + raise ValueError("dataset must be a DataFrame but got %s." % type(dataset)) + java_glr_summary = self._call_java("evaluate", dataset) + return GeneralizedLinearRegressionSummary(java_glr_summary) + + +class GeneralizedLinearRegressionSummary(JavaWrapper): + """ + .. note:: Experimental + + Generalized linear regression results evaluated on a dataset. + + .. versionadded:: 2.0.0 + """ + + @property + @since("2.0.0") + def predictions(self): + """ + Predictions output by the model's `transform` method. + """ + return self._call_java("predictions") + + @property + @since("2.0.0") + def predictionCol(self): + """ + Field in :py:attr:`predictions` which gives the predicted value of each instance. + This is set to a new column name if the original model's `predictionCol` is not set. + """ + return self._call_java("predictionCol") + + @property + @since("2.0.0") + def rank(self): + """ + The numeric rank of the fitted linear model. + """ + return self._call_java("rank") + + @property + @since("2.0.0") + def degreesOfFreedom(self): + """ + Degrees of freedom. + """ + return self._call_java("degreesOfFreedom") + + @property + @since("2.0.0") + def residualDegreeOfFreedom(self): + """ + The residual degrees of freedom. + """ + return self._call_java("residualDegreeOfFreedom") + + @property + @since("2.0.0") + def residualDegreeOfFreedomNull(self): + """ + The residual degrees of freedom for the null model. + """ + return self._call_java("residualDegreeOfFreedomNull") + + @since("2.0.0") + def residuals(self, residualsType="deviance"): + """ + Get the residuals of the fitted model by type. + + :param residualsType: The type of residuals which should be returned. + Supported options: deviance (default), pearson, working, and response. + """ + return self._call_java("residuals", residualsType) + + @property + @since("2.0.0") + def nullDeviance(self): + """ + The deviance for the null model. + """ + return self._call_java("nullDeviance") + + @property + @since("2.0.0") + def deviance(self): + """ + The deviance for the fitted model. 
+ """ + return self._call_java("deviance") + + @property + @since("2.0.0") + def dispersion(self): + """ + The dispersion of the fitted model. + It is taken as 1.0 for the "binomial" and "poisson" families, and otherwise + estimated by the residual Pearson's Chi-Squared statistic (which is defined as + sum of the squares of the Pearson residuals) divided by the residual degrees of freedom. + """ + return self._call_java("dispersion") + + @property + @since("2.0.0") + def aic(self): + """ + Akaike's "An Information Criterion"(AIC) for the fitted model. + """ + return self._call_java("aic") + + +@inherit_doc +class GeneralizedLinearRegressionTrainingSummary(GeneralizedLinearRegressionSummary): + """ + .. note:: Experimental + + Generalized linear regression training results. + + .. versionadded:: 2.0.0 + """ + + @property + @since("2.0.0") + def numIterations(self): + """ + Number of training iterations. + """ + return self._call_java("numIterations") + + @property + @since("2.0.0") + def solver(self): + """ + The numeric solver used for training. + """ + return self._call_java("solver") + + @property + @since("2.0.0") + def coefficientStandardErrors(self): + """ + Standard error of estimated coefficients and intercept. + + If :py:attr:`GeneralizedLinearRegression.fitIntercept` is set to True, + then the last element returned corresponds to the intercept. + """ + return self._call_java("coefficientStandardErrors") + + @property + @since("2.0.0") + def tValues(self): + """ + T-statistic of estimated coefficients and intercept. + + If :py:attr:`GeneralizedLinearRegression.fitIntercept` is set to True, + then the last element returned corresponds to the intercept. + """ + return self._call_java("tValues") + + @property + @since("2.0.0") + def pValues(self): + """ + Two-sided p-value of estimated coefficients and intercept. + + If :py:attr:`GeneralizedLinearRegression.fitIntercept` is set to True, + then the last element returned corresponds to the intercept. + """ + return self._call_java("pValues") + if __name__ == "__main__": import doctest import pyspark.ml.regression - from pyspark.context import SparkContext - from pyspark.sql import SQLContext + from pyspark.sql import SparkSession globs = pyspark.ml.regression.__dict__.copy() # The small batch size here ensures that we see multiple batches, # even in these small test examples: - sc = SparkContext("local[2]", "ml.regression tests") - sqlContext = SQLContext(sc) + spark = SparkSession.builder\ + .master("local[2]")\ + .appName("ml.regression tests")\ + .getOrCreate() + sc = spark.sparkContext globs['sc'] = sc - globs['sqlContext'] = sqlContext + globs['spark'] = spark import tempfile temp_path = tempfile.mkdtemp() globs['temp_path'] = temp_path try: (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) - sc.stop() + spark.stop() finally: from shutil import rmtree try: diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py old mode 100644 new mode 100755 index 78ec96af8aa90..de95a47a2b8aa --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -16,9 +16,8 @@ # """ -Unit tests for Spark ML Python APIs. +Unit tests for MLlib Python DataFrame-based APIs. 
""" -import array import sys if sys.version > '3': xrange = range @@ -40,29 +39,59 @@ from shutil import rmtree import tempfile +import array as pyarray import numpy as np +from numpy import ( + array, array_equal, zeros, inf, random, exp, dot, all, mean, abs, arange, tile, ones) +from numpy import sum as array_sum import inspect -from pyspark import keyword_only +from pyspark import keyword_only, SparkContext from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer from pyspark.ml.classification import * from pyspark.ml.clustering import * from pyspark.ml.evaluation import BinaryClassificationEvaluator, RegressionEvaluator from pyspark.ml.feature import * +from pyspark.ml.linalg import Vector, SparseVector, DenseVector, VectorUDT,\ + DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT, _convert_to_vector from pyspark.ml.param import Param, Params, TypeConverters from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed from pyspark.ml.recommendation import ALS -from pyspark.ml.regression import LinearRegression, DecisionTreeRegressor +from pyspark.ml.regression import LinearRegression, DecisionTreeRegressor, \ + GeneralizedLinearRegression from pyspark.ml.tuning import * from pyspark.ml.wrapper import JavaParams -from pyspark.mllib.common import _java2py -from pyspark.mllib.linalg import Vectors, DenseVector, SparseVector -from pyspark.sql import DataFrame, SQLContext, Row +from pyspark.ml.common import _java2py +from pyspark.serializers import PickleSerializer +from pyspark.sql import DataFrame, Row, SparkSession from pyspark.sql.functions import rand from pyspark.sql.utils import IllegalArgumentException from pyspark.storagelevel import * from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase +ser = PickleSerializer() + + +class MLlibTestCase(unittest.TestCase): + def setUp(self): + self.sc = SparkContext('local[4]', "MLlib tests") + self.spark = SparkSession(self.sc) + + def tearDown(self): + self.spark.stop() + + +class SparkSessionTestCase(PySparkTestCase): + @classmethod + def setUpClass(cls): + PySparkTestCase.setUpClass() + cls.spark = SparkSession(cls.sc) + + @classmethod + def tearDownClass(cls): + PySparkTestCase.tearDownClass() + cls.spark.stop() + class MockDataset(DataFrame): @@ -136,8 +165,8 @@ def test_vector(self): def test_list(self): l = [0, 1] - for lst_like in [l, np.array(l), DenseVector(l), SparseVector(len(l), range(len(l)), l), - array.array('l', l), xrange(2), tuple(l)]: + for lst_like in [l, np.array(l), DenseVector(l), SparseVector(len(l), + range(len(l)), l), pyarray.array('l', l), xrange(2), tuple(l)]: converted = TypeConverters.toList(lst_like) self.assertEqual(type(converted), list) self.assertListEqual(converted, l) @@ -145,7 +174,7 @@ def test_list(self): def test_list_int(self): for indices in [[1.0, 2.0], np.array([1.0, 2.0]), DenseVector([1.0, 2.0]), SparseVector(2, {0: 1.0, 1: 2.0}), xrange(1, 3), (1.0, 2.0), - array.array('d', [1.0, 2.0])]: + pyarray.array('d', [1.0, 2.0])]: vs = VectorSlicer(indices=indices) self.assertListEqual(vs.getIndices(), [1, 2]) self.assertTrue(all([type(v) == int for v in vs.getIndices()])) @@ -350,7 +379,7 @@ def test_word2vec_param(self): self.assertEqual(model.getWindowSize(), 6) -class FeatureTests(PySparkTestCase): +class FeatureTests(SparkSessionTestCase): def test_binarizer(self): b0 = Binarizer() @@ -376,8 +405,7 @@ def test_binarizer(self): self.assertEqual(b1.getOutputCol(), "output") def test_idf(self): - sqlContext = SQLContext(self.sc) - dataset = 
sqlContext.createDataFrame([ + dataset = self.spark.createDataFrame([ (DenseVector([1.0, 2.0]),), (DenseVector([0.0, 1.0]),), (DenseVector([3.0, 0.2]),)], ["tf"]) @@ -390,8 +418,7 @@ def test_idf(self): self.assertIsNotNone(output.head().idf) def test_ngram(self): - sqlContext = SQLContext(self.sc) - dataset = sqlContext.createDataFrame([ + dataset = self.spark.createDataFrame([ Row(input=["a", "b", "c", "d", "e"])]) ngram0 = NGram(n=4, inputCol="input", outputCol="output") self.assertEqual(ngram0.getN(), 4) @@ -401,8 +428,7 @@ def test_ngram(self): self.assertEqual(transformedDF.head().output, ["a b c d", "b c d e"]) def test_stopwordsremover(self): - sqlContext = SQLContext(self.sc) - dataset = sqlContext.createDataFrame([Row(input=["a", "panda"])]) + dataset = self.spark.createDataFrame([Row(input=["a", "panda"])]) stopWordRemover = StopWordsRemover(inputCol="input", outputCol="output") # Default self.assertEqual(stopWordRemover.getInputCol(), "input") @@ -417,10 +443,16 @@ def test_stopwordsremover(self): self.assertEqual(stopWordRemover.getStopWords(), stopwords) transformedDF = stopWordRemover.transform(dataset) self.assertEqual(transformedDF.head().output, ["a"]) + # with language selection + stopwords = StopWordsRemover.loadDefaultStopWords("turkish") + dataset = self.spark.createDataFrame([Row(input=["acaba", "ama", "biri"])]) + stopWordRemover.setStopWords(stopwords) + self.assertEqual(stopWordRemover.getStopWords(), stopwords) + transformedDF = stopWordRemover.transform(dataset) + self.assertEqual(transformedDF.head().output, []) def test_count_vectorizer_with_binary(self): - sqlContext = SQLContext(self.sc) - dataset = sqlContext.createDataFrame([ + dataset = self.spark.createDataFrame([ (0, "a a a b b c".split(' '), SparseVector(3, {0: 1.0, 1: 1.0, 2: 1.0}),), (1, "a a".split(' '), SparseVector(3, {0: 1.0}),), (2, "a b".split(' '), SparseVector(3, {0: 1.0, 1: 1.0}),), @@ -468,11 +500,10 @@ def _fit(self, dataset): return model -class CrossValidatorTests(PySparkTestCase): +class CrossValidatorTests(SparkSessionTestCase): def test_copy(self): - sqlContext = SQLContext(self.sc) - dataset = sqlContext.createDataFrame([ + dataset = self.spark.createDataFrame([ (10, 10.0), (50, 50.0), (100, 100.0), @@ -496,8 +527,7 @@ def test_copy(self): < 0.0001) def test_fit_minimize_metric(self): - sqlContext = SQLContext(self.sc) - dataset = sqlContext.createDataFrame([ + dataset = self.spark.createDataFrame([ (10, 10.0), (50, 50.0), (100, 100.0), @@ -520,8 +550,7 @@ def test_fit_minimize_metric(self): self.assertEqual(0.0, bestModelMetric, "Best model has RMSE of 0") def test_fit_maximize_metric(self): - sqlContext = SQLContext(self.sc) - dataset = sqlContext.createDataFrame([ + dataset = self.spark.createDataFrame([ (10, 10.0), (50, 50.0), (100, 100.0), @@ -547,8 +576,7 @@ def test_save_load(self): # This tests saving and loading the trained model only. 
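# --- Editorial sketch (not part of the patch): the language-aware default stop word
# --- lists exercised by the new StopWordsRemover test above; column names here are
# --- illustrative assumptions.
from pyspark.ml.feature import StopWordsRemover

turkish_stopwords = StopWordsRemover.loadDefaultStopWords("turkish")
remover = StopWordsRemover(inputCol="words", outputCol="filtered",
                           stopWords=turkish_stopwords)
# remover.transform(df) now drops Turkish stop words such as "acaba" and "ama".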
# Save/load for CrossValidator will be added later: SPARK-13786 temp_path = tempfile.mkdtemp() - sqlContext = SQLContext(self.sc) - dataset = sqlContext.createDataFrame( + dataset = self.spark.createDataFrame( [(Vectors.dense([0.0]), 0.0), (Vectors.dense([0.4]), 1.0), (Vectors.dense([0.5]), 0.0), @@ -569,11 +597,10 @@ def test_save_load(self): self.assertEqual(loadedLrModel.intercept, lrModel.intercept) -class TrainValidationSplitTests(PySparkTestCase): +class TrainValidationSplitTests(SparkSessionTestCase): def test_fit_minimize_metric(self): - sqlContext = SQLContext(self.sc) - dataset = sqlContext.createDataFrame([ + dataset = self.spark.createDataFrame([ (10, 10.0), (50, 50.0), (100, 100.0), @@ -583,21 +610,24 @@ def test_fit_minimize_metric(self): iee = InducedErrorEstimator() evaluator = RegressionEvaluator(metricName="rmse") - grid = (ParamGridBuilder() - .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) - .build()) + grid = ParamGridBuilder() \ + .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) \ + .build() tvs = TrainValidationSplit(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator) tvsModel = tvs.fit(dataset) bestModel = tvsModel.bestModel bestModelMetric = evaluator.evaluate(bestModel.transform(dataset)) + validationMetrics = tvsModel.validationMetrics self.assertEqual(0.0, bestModel.getOrDefault('inducedError'), "Best model should have zero induced error") self.assertEqual(0.0, bestModelMetric, "Best model has RMSE of 0") + self.assertEqual(len(grid), len(validationMetrics), + "validationMetrics has the same size of grid parameter") + self.assertEqual(0.0, min(validationMetrics)) def test_fit_maximize_metric(self): - sqlContext = SQLContext(self.sc) - dataset = sqlContext.createDataFrame([ + dataset = self.spark.createDataFrame([ (10, 10.0), (50, 50.0), (100, 100.0), @@ -607,24 +637,27 @@ def test_fit_maximize_metric(self): iee = InducedErrorEstimator() evaluator = RegressionEvaluator(metricName="r2") - grid = (ParamGridBuilder() - .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) - .build()) + grid = ParamGridBuilder() \ + .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) \ + .build() tvs = TrainValidationSplit(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator) tvsModel = tvs.fit(dataset) bestModel = tvsModel.bestModel bestModelMetric = evaluator.evaluate(bestModel.transform(dataset)) + validationMetrics = tvsModel.validationMetrics self.assertEqual(0.0, bestModel.getOrDefault('inducedError'), "Best model should have zero induced error") self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1") + self.assertEqual(len(grid), len(validationMetrics), + "validationMetrics has the same size of grid parameter") + self.assertEqual(1.0, max(validationMetrics)) def test_save_load(self): # This tests saving and loading the trained model only. 
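# --- Editorial sketch (not part of the patch): the validationMetrics list that
# --- TrainValidationSplitModel now exposes, one entry per ParamMap in the grid.
# --- `train_df` is an assumed DataFrame with "features" and "label" columns.
from pyspark.ml.regression import LinearRegression
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.tuning import ParamGridBuilder, TrainValidationSplit

lr = LinearRegression()
grid = ParamGridBuilder().addGrid(lr.regParam, [0.0, 0.1]).build()
tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid,
                           evaluator=RegressionEvaluator(metricName="rmse"))
tvsModel = tvs.fit(train_df)
assert len(tvsModel.validationMetrics) == len(grid)  # one metric per candidate ParamMap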
# Save/load for TrainValidationSplit will be added later: SPARK-13786 temp_path = tempfile.mkdtemp() - sqlContext = SQLContext(self.sc) - dataset = sqlContext.createDataFrame( + dataset = self.spark.createDataFrame( [(Vectors.dense([0.0]), 0.0), (Vectors.dense([0.4]), 1.0), (Vectors.dense([0.5]), 0.0), @@ -644,8 +677,38 @@ def test_save_load(self): self.assertEqual(loadedLrModel.uid, lrModel.uid) self.assertEqual(loadedLrModel.intercept, lrModel.intercept) + def test_copy(self): + dataset = self.spark.createDataFrame([ + (10, 10.0), + (50, 50.0), + (100, 100.0), + (500, 500.0)] * 10, + ["feature", "label"]) -class PersistenceTest(PySparkTestCase): + iee = InducedErrorEstimator() + evaluator = RegressionEvaluator(metricName="r2") + + grid = ParamGridBuilder() \ + .addGrid(iee.inducedError, [100.0, 0.0, 10000.0]) \ + .build() + tvs = TrainValidationSplit(estimator=iee, estimatorParamMaps=grid, evaluator=evaluator) + tvsModel = tvs.fit(dataset) + tvsCopied = tvs.copy() + tvsModelCopied = tvsModel.copy() + + self.assertEqual(tvs.getEstimator().uid, tvsCopied.getEstimator().uid, + "Copied TrainValidationSplit has the same uid of Estimator") + + self.assertEqual(tvsModel.bestModel.uid, tvsModelCopied.bestModel.uid) + self.assertEqual(len(tvsModel.validationMetrics), + len(tvsModelCopied.validationMetrics), + "Copied validationMetrics has the same size of the original") + for index in range(len(tvsModel.validationMetrics)): + self.assertEqual(tvsModel.validationMetrics[index], + tvsModelCopied.validationMetrics[index]) + + +class PersistenceTest(SparkSessionTestCase): def test_linear_regression(self): lr = LinearRegression(maxIter=1) @@ -684,12 +747,32 @@ def test_logistic_regression(self): except OSError: pass + def _compare_params(self, m1, m2, param): + """ + Compare 2 ML Params instances for the given param, and assert both have the same param value + and parent. The param must be a parameter of m1. + """ + # Prevent key not found error in case of some param in neither paramMap nor defaultParamMap. + if m1.isDefined(param): + paramValue1 = m1.getOrDefault(param) + paramValue2 = m2.getOrDefault(m2.getParam(param.name)) + if isinstance(paramValue1, Params): + self._compare_pipelines(paramValue1, paramValue2) + else: + self.assertEqual(paramValue1, paramValue2) # for general types param + # Assert parents are equal + self.assertEqual(param.parent, m2.getParam(param.name).parent) + else: + # If m1 is not defined param, then m2 should not, too. See SPARK-14931. + self.assertFalse(m2.isDefined(m2.getParam(param.name))) + def _compare_pipelines(self, m1, m2): """ Compare 2 ML types, asserting that they are equivalent. 
This currently supports: - basic types - Pipeline, PipelineModel + - OneVsRest, OneVsRestModel This checks: - uid - type @@ -700,8 +783,7 @@ def _compare_pipelines(self, m1, m2): if isinstance(m1, JavaParams): self.assertEqual(len(m1.params), len(m2.params)) for p in m1.params: - self.assertEqual(m1.getOrDefault(p), m2.getOrDefault(p)) - self.assertEqual(p.parent, m2.getParam(p.name).parent) + self._compare_params(m1, m2, p) elif isinstance(m1, Pipeline): self.assertEqual(len(m1.getStages()), len(m2.getStages())) for s1, s2 in zip(m1.getStages(), m2.getStages()): @@ -710,6 +792,13 @@ def _compare_pipelines(self, m1, m2): self.assertEqual(len(m1.stages), len(m2.stages)) for s1, s2 in zip(m1.stages, m2.stages): self._compare_pipelines(s1, s2) + elif isinstance(m1, OneVsRest) or isinstance(m1, OneVsRestModel): + for p in m1.params: + self._compare_params(m1, m2, p) + if isinstance(m1, OneVsRestModel): + self.assertEqual(len(m1.models), len(m2.models)) + for x, y in zip(m1.models, m2.models): + self._compare_pipelines(x, y) else: raise RuntimeError("_compare_pipelines does not yet support type: %s" % type(m1)) @@ -717,11 +806,10 @@ def test_pipeline_persistence(self): """ Pipeline[HashingTF, PCA] """ - sqlContext = SQLContext(self.sc) temp_path = tempfile.mkdtemp() try: - df = sqlContext.createDataFrame([(["a", "b", "c"],), (["c", "d", "e"],)], ["words"]) + df = self.spark.createDataFrame([(["a", "b", "c"],), (["c", "d", "e"],)], ["words"]) tf = HashingTF(numFeatures=10, inputCol="words", outputCol="features") pca = PCA(k=2, inputCol="features", outputCol="pca_features") pl = Pipeline(stages=[tf, pca]) @@ -746,11 +834,10 @@ def test_nested_pipeline_persistence(self): """ Pipeline[HashingTF, Pipeline[PCA]] """ - sqlContext = SQLContext(self.sc) temp_path = tempfile.mkdtemp() try: - df = sqlContext.createDataFrame([(["a", "b", "c"],), (["c", "d", "e"],)], ["words"]) + df = self.spark.createDataFrame([(["a", "b", "c"],), (["c", "d", "e"],)], ["words"]) tf = HashingTF(numFeatures=10, inputCol="words", outputCol="features") pca = PCA(k=2, inputCol="features", outputCol="pca_features") p0 = Pipeline(stages=[pca]) @@ -772,6 +859,24 @@ def test_nested_pipeline_persistence(self): except OSError: pass + def test_onevsrest(self): + temp_path = tempfile.mkdtemp() + df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)), + (1.0, Vectors.sparse(2, [], [])), + (2.0, Vectors.dense(0.5, 0.5))] * 10, + ["label", "features"]) + lr = LogisticRegression(maxIter=5, regParam=0.01) + ovr = OneVsRest(classifier=lr) + model = ovr.fit(df) + ovrPath = temp_path + "/ovr" + ovr.save(ovrPath) + loadedOvr = OneVsRest.load(ovrPath) + self._compare_pipelines(ovr, loadedOvr) + modelPath = temp_path + "/ovrModel" + model.save(modelPath) + loadedModel = OneVsRestModel.load(modelPath) + self._compare_pipelines(model, loadedModel) + def test_decisiontree_classifier(self): dt = DecisionTreeClassifier(maxDepth=1) path = tempfile.mkdtemp() @@ -809,7 +914,7 @@ def test_decisiontree_regressor(self): pass -class LDATest(PySparkTestCase): +class LDATest(SparkSessionTestCase): def _compare(self, m1, m2): """ @@ -829,8 +934,7 @@ def _compare(self, m1, m2): def test_persistence(self): # Test save/load for LDA, LocalLDAModel, DistributedLDAModel. 
- sqlContext = SQLContext(self.sc) - df = sqlContext.createDataFrame([ + df = self.spark.createDataFrame([ [1, Vectors.dense([0.0, 1.0])], [2, Vectors.sparse(2, {0: 1.0})], ], ["id", "features"]) @@ -864,12 +968,10 @@ def test_persistence(self): pass -class TrainingSummaryTest(PySparkTestCase): +class TrainingSummaryTest(SparkSessionTestCase): def test_linear_regression_summary(self): - from pyspark.mllib.linalg import Vectors - sqlContext = SQLContext(self.sc) - df = sqlContext.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), + df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), (0.0, 2.0, Vectors.sparse(1, [], []))], ["label", "weight", "features"]) lr = LinearRegression(maxIter=5, regParam=0.0, solver="normal", weightCol="weight", @@ -905,10 +1007,44 @@ def test_linear_regression_summary(self): sameSummary = model.evaluate(df) self.assertAlmostEqual(sameSummary.explainedVariance, s.explainedVariance) + def test_glr_summary(self): + from pyspark.ml.linalg import Vectors + df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), + (0.0, 2.0, Vectors.sparse(1, [], []))], + ["label", "weight", "features"]) + glr = GeneralizedLinearRegression(family="gaussian", link="identity", weightCol="weight", + fitIntercept=False) + model = glr.fit(df) + self.assertTrue(model.hasSummary) + s = model.summary + # test that api is callable and returns expected types + self.assertEqual(s.numIterations, 1) # this should default to a single iteration of WLS + self.assertTrue(isinstance(s.predictions, DataFrame)) + self.assertEqual(s.predictionCol, "prediction") + self.assertTrue(isinstance(s.residuals(), DataFrame)) + self.assertTrue(isinstance(s.residuals("pearson"), DataFrame)) + coefStdErr = s.coefficientStandardErrors + self.assertTrue(isinstance(coefStdErr, list) and isinstance(coefStdErr[0], float)) + tValues = s.tValues + self.assertTrue(isinstance(tValues, list) and isinstance(tValues[0], float)) + pValues = s.pValues + self.assertTrue(isinstance(pValues, list) and isinstance(pValues[0], float)) + self.assertEqual(s.degreesOfFreedom, 1) + self.assertEqual(s.residualDegreeOfFreedom, 1) + self.assertEqual(s.residualDegreeOfFreedomNull, 2) + self.assertEqual(s.rank, 1) + self.assertTrue(isinstance(s.solver, basestring)) + self.assertTrue(isinstance(s.aic, float)) + self.assertTrue(isinstance(s.deviance, float)) + self.assertTrue(isinstance(s.nullDeviance, float)) + self.assertTrue(isinstance(s.dispersion, float)) + # test evaluation (with training dataset) produces a summary with same values + # one check is enough to verify a summary is returned, Scala version runs full test + sameSummary = model.evaluate(df) + self.assertAlmostEqual(sameSummary.deviance, s.deviance) + def test_logistic_regression_summary(self): - from pyspark.mllib.linalg import Vectors - sqlContext = SQLContext(self.sc) - df = sqlContext.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), + df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), (0.0, 2.0, Vectors.sparse(1, [], []))], ["label", "weight", "features"]) lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight", fitIntercept=False) @@ -935,11 +1071,10 @@ def test_logistic_regression_summary(self): self.assertAlmostEqual(sameSummary.areaUnderROC, s.areaUnderROC) -class OneVsRestTests(PySparkTestCase): +class OneVsRestTests(SparkSessionTestCase): def test_copy(self): - sqlContext = SQLContext(self.sc) - df = sqlContext.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)), + df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)), 
(1.0, Vectors.sparse(2, [], [])), (2.0, Vectors.dense(0.5, 0.5))], ["label", "features"]) @@ -953,8 +1088,7 @@ def test_copy(self): self.assertEqual(model1.getPredictionCol(), "indexed") def test_output_columns(self): - sqlContext = SQLContext(self.sc) - df = sqlContext.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)), + df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)), (1.0, Vectors.sparse(2, [], [])), (2.0, Vectors.dense(0.5, 0.5))], ["label", "features"]) @@ -964,35 +1098,12 @@ def test_output_columns(self): output = model.transform(df) self.assertEqual(output.columns, ["label", "features", "prediction"]) - def test_save_load(self): - temp_path = tempfile.mkdtemp() - sqlContext = SQLContext(self.sc) - df = sqlContext.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)), - (1.0, Vectors.sparse(2, [], [])), - (2.0, Vectors.dense(0.5, 0.5))], - ["label", "features"]) - lr = LogisticRegression(maxIter=5, regParam=0.01) - ovr = OneVsRest(classifier=lr) - model = ovr.fit(df) - ovrPath = temp_path + "/ovr" - ovr.save(ovrPath) - loadedOvr = OneVsRest.load(ovrPath) - self.assertEqual(loadedOvr.getFeaturesCol(), ovr.getFeaturesCol()) - self.assertEqual(loadedOvr.getLabelCol(), ovr.getLabelCol()) - self.assertEqual(loadedOvr.getClassifier().uid, ovr.getClassifier().uid) - modelPath = temp_path + "/ovrModel" - model.save(modelPath) - loadedModel = OneVsRestModel.load(modelPath) - for m, n in zip(model.models, loadedModel.models): - self.assertEqual(m.uid, n.uid) - -class HashingTFTest(PySparkTestCase): +class HashingTFTest(SparkSessionTestCase): def test_apply_binary_term_freqs(self): - sqlContext = SQLContext(self.sc) - df = sqlContext.createDataFrame([(0, ["a", "a", "b", "c", "c", "c"])], ["id", "words"]) + df = self.spark.createDataFrame([(0, ["a", "a", "b", "c", "c", "c"])], ["id", "words"]) n = 10 hashingTF = HashingTF() hashingTF.setInputCol("words").setOutputCol("features").setNumFeatures(n).setBinary(True) @@ -1004,11 +1115,10 @@ def test_apply_binary_term_freqs(self): ": expected " + str(expected[i]) + ", got " + str(features[i])) -class ALSTest(PySparkTestCase): +class ALSTest(SparkSessionTestCase): def test_storage_levels(self): - sqlContext = SQLContext(self.sc) - df = sqlContext.createDataFrame( + df = self.spark.createDataFrame( [(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0), (2, 1, 1.0), (2, 2, 5.0)], ["user", "item", "rating"]) als = ALS().setMaxIter(1).setRank(1) @@ -1074,6 +1184,350 @@ def test_java_params(self): self.check_params(cls()) +def _squared_distance(a, b): + if isinstance(a, Vector): + return a.squared_distance(b) + else: + return b.squared_distance(a) + + +class VectorTests(MLlibTestCase): + + def _test_serialize(self, v): + self.assertEqual(v, ser.loads(ser.dumps(v))) + jvec = self.sc._jvm.org.apache.spark.ml.python.MLSerDe.loads(bytearray(ser.dumps(v))) + nv = ser.loads(bytes(self.sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(jvec))) + self.assertEqual(v, nv) + vs = [v] * 100 + jvecs = self.sc._jvm.org.apache.spark.ml.python.MLSerDe.loads(bytearray(ser.dumps(vs))) + nvs = ser.loads(bytes(self.sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(jvecs))) + self.assertEqual(vs, nvs) + + def test_serialize(self): + self._test_serialize(DenseVector(range(10))) + self._test_serialize(DenseVector(array([1., 2., 3., 4.]))) + self._test_serialize(DenseVector(pyarray.array('d', range(10)))) + self._test_serialize(SparseVector(4, {1: 1, 3: 2})) + self._test_serialize(SparseVector(3, {})) + self._test_serialize(DenseMatrix(2, 3, range(6))) + sm1 = SparseMatrix( + 
3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0]) + self._test_serialize(sm1) + + def test_dot(self): + sv = SparseVector(4, {1: 1, 3: 2}) + dv = DenseVector(array([1., 2., 3., 4.])) + lst = DenseVector([1, 2, 3, 4]) + mat = array([[1., 2., 3., 4.], + [1., 2., 3., 4.], + [1., 2., 3., 4.], + [1., 2., 3., 4.]]) + arr = pyarray.array('d', [0, 1, 2, 3]) + self.assertEqual(10.0, sv.dot(dv)) + self.assertTrue(array_equal(array([3., 6., 9., 12.]), sv.dot(mat))) + self.assertEqual(30.0, dv.dot(dv)) + self.assertTrue(array_equal(array([10., 20., 30., 40.]), dv.dot(mat))) + self.assertEqual(30.0, lst.dot(dv)) + self.assertTrue(array_equal(array([10., 20., 30., 40.]), lst.dot(mat))) + self.assertEqual(7.0, sv.dot(arr)) + + def test_squared_distance(self): + sv = SparseVector(4, {1: 1, 3: 2}) + dv = DenseVector(array([1., 2., 3., 4.])) + lst = DenseVector([4, 3, 2, 1]) + lst1 = [4, 3, 2, 1] + arr = pyarray.array('d', [0, 2, 1, 3]) + narr = array([0, 2, 1, 3]) + self.assertEqual(15.0, _squared_distance(sv, dv)) + self.assertEqual(25.0, _squared_distance(sv, lst)) + self.assertEqual(20.0, _squared_distance(dv, lst)) + self.assertEqual(15.0, _squared_distance(dv, sv)) + self.assertEqual(25.0, _squared_distance(lst, sv)) + self.assertEqual(20.0, _squared_distance(lst, dv)) + self.assertEqual(0.0, _squared_distance(sv, sv)) + self.assertEqual(0.0, _squared_distance(dv, dv)) + self.assertEqual(0.0, _squared_distance(lst, lst)) + self.assertEqual(25.0, _squared_distance(sv, lst1)) + self.assertEqual(3.0, _squared_distance(sv, arr)) + self.assertEqual(3.0, _squared_distance(sv, narr)) + + def test_hash(self): + v1 = DenseVector([0.0, 1.0, 0.0, 5.5]) + v2 = SparseVector(4, [(1, 1.0), (3, 5.5)]) + v3 = DenseVector([0.0, 1.0, 0.0, 5.5]) + v4 = SparseVector(4, [(1, 1.0), (3, 2.5)]) + self.assertEqual(hash(v1), hash(v2)) + self.assertEqual(hash(v1), hash(v3)) + self.assertEqual(hash(v2), hash(v3)) + self.assertFalse(hash(v1) == hash(v4)) + self.assertFalse(hash(v2) == hash(v4)) + + def test_eq(self): + v1 = DenseVector([0.0, 1.0, 0.0, 5.5]) + v2 = SparseVector(4, [(1, 1.0), (3, 5.5)]) + v3 = DenseVector([0.0, 1.0, 0.0, 5.5]) + v4 = SparseVector(6, [(1, 1.0), (3, 5.5)]) + v5 = DenseVector([0.0, 1.0, 0.0, 2.5]) + v6 = SparseVector(4, [(1, 1.0), (3, 2.5)]) + self.assertEqual(v1, v2) + self.assertEqual(v1, v3) + self.assertFalse(v2 == v4) + self.assertFalse(v1 == v5) + self.assertFalse(v1 == v6) + + def test_equals(self): + indices = [1, 2, 4] + values = [1., 3., 2.] + self.assertTrue(Vectors._equals(indices, values, list(range(5)), [0., 1., 3., 0., 2.])) + self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 3., 1., 0., 2.])) + self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 3., 0., 2.])) + self.assertFalse(Vectors._equals(indices, values, list(range(5)), [0., 1., 3., 2., 2.])) + + def test_conversion(self): + # numpy arrays should be automatically upcast to float64 + # tests for fix of [SPARK-5089] + v = array([1, 2, 3, 4], dtype='float64') + dv = DenseVector(v) + self.assertTrue(dv.array.dtype == 'float64') + v = array([1, 2, 3, 4], dtype='float32') + dv = DenseVector(v) + self.assertTrue(dv.array.dtype == 'float64') + + def test_sparse_vector_indexing(self): + sv = SparseVector(5, {1: 1, 3: 2}) + self.assertEqual(sv[0], 0.) + self.assertEqual(sv[3], 2.) + self.assertEqual(sv[1], 1.) + self.assertEqual(sv[2], 0.) + self.assertEqual(sv[4], 0.) + self.assertEqual(sv[-1], 0.) + self.assertEqual(sv[-2], 2.) + self.assertEqual(sv[-3], 0.) + self.assertEqual(sv[-5], 0.) 
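# --- Editorial sketch (not part of the patch): the pyspark.ml.linalg types covered by
# --- the vector tests above behave like their old mllib counterparts for dot products
# --- and squared distances.
from pyspark.ml.linalg import DenseVector, SparseVector

dv = DenseVector([1.0, 2.0, 3.0, 4.0])
sv = SparseVector(4, {1: 1.0, 3: 2.0})
print(sv.dot(dv))               # 1.0*2.0 + 2.0*4.0 = 10.0
print(dv.squared_distance(sv))  # 1 + 1 + 9 + 4 = 15.0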
+ for ind in [5, -6]: + self.assertRaises(IndexError, sv.__getitem__, ind) + for ind in [7.8, '1']: + self.assertRaises(TypeError, sv.__getitem__, ind) + + zeros = SparseVector(4, {}) + self.assertEqual(zeros[0], 0.0) + self.assertEqual(zeros[3], 0.0) + for ind in [4, -5]: + self.assertRaises(IndexError, zeros.__getitem__, ind) + + empty = SparseVector(0, {}) + for ind in [-1, 0, 1]: + self.assertRaises(IndexError, empty.__getitem__, ind) + + def test_sparse_vector_iteration(self): + self.assertListEqual(list(SparseVector(3, [], [])), [0.0, 0.0, 0.0]) + self.assertListEqual(list(SparseVector(5, [0, 3], [1.0, 2.0])), [1.0, 0.0, 0.0, 2.0, 0.0]) + + def test_matrix_indexing(self): + mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10]) + expected = [[0, 6], [1, 8], [4, 10]] + for i in range(3): + for j in range(2): + self.assertEqual(mat[i, j], expected[i][j]) + + for i, j in [(-1, 0), (4, 1), (3, 4)]: + self.assertRaises(IndexError, mat.__getitem__, (i, j)) + + def test_repr_dense_matrix(self): + mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10]) + self.assertTrue( + repr(mat), + 'DenseMatrix(3, 2, [0.0, 1.0, 4.0, 6.0, 8.0, 10.0], False)') + + mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10], True) + self.assertTrue( + repr(mat), + 'DenseMatrix(3, 2, [0.0, 1.0, 4.0, 6.0, 8.0, 10.0], False)') + + mat = DenseMatrix(6, 3, zeros(18)) + self.assertTrue( + repr(mat), + 'DenseMatrix(6, 3, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..., \ + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], False)') + + def test_repr_sparse_matrix(self): + sm1t = SparseMatrix( + 3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], + isTransposed=True) + self.assertTrue( + repr(sm1t), + 'SparseMatrix(3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], True)') + + indices = tile(arange(6), 3) + values = ones(18) + sm = SparseMatrix(6, 3, [0, 6, 12, 18], indices, values) + self.assertTrue( + repr(sm), "SparseMatrix(6, 3, [0, 6, 12, 18], \ + [0, 1, 2, 3, 4, 5, 0, 1, ..., 4, 5, 0, 1, 2, 3, 4, 5], \ + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ..., \ + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], False)") + + self.assertTrue( + str(sm), + "6 X 3 CSCMatrix\n\ + (0,0) 1.0\n(1,0) 1.0\n(2,0) 1.0\n(3,0) 1.0\n(4,0) 1.0\n(5,0) 1.0\n\ + (0,1) 1.0\n(1,1) 1.0\n(2,1) 1.0\n(3,1) 1.0\n(4,1) 1.0\n(5,1) 1.0\n\ + (0,2) 1.0\n(1,2) 1.0\n(2,2) 1.0\n(3,2) 1.0\n..\n..") + + sm = SparseMatrix(1, 18, zeros(19), [], []) + self.assertTrue( + repr(sm), + 'SparseMatrix(1, 18, \ + [0, 0, 0, 0, 0, 0, 0, 0, ..., 0, 0, 0, 0, 0, 0, 0, 0], [], [], False)') + + def test_sparse_matrix(self): + # Test sparse matrix creation. + sm1 = SparseMatrix( + 3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0]) + self.assertEqual(sm1.numRows, 3) + self.assertEqual(sm1.numCols, 4) + self.assertEqual(sm1.colPtrs.tolist(), [0, 2, 2, 4, 4]) + self.assertEqual(sm1.rowIndices.tolist(), [1, 2, 1, 2]) + self.assertEqual(sm1.values.tolist(), [1.0, 2.0, 4.0, 5.0]) + self.assertTrue( + repr(sm1), + 'SparseMatrix(3, 4, [0, 2, 2, 4, 4], [1, 2, 1, 2], [1.0, 2.0, 4.0, 5.0], False)') + + # Test indexing + expected = [ + [0, 0, 0, 0], + [1, 0, 4, 0], + [2, 0, 5, 0]] + + for i in range(3): + for j in range(4): + self.assertEqual(expected[i][j], sm1[i, j]) + self.assertTrue(array_equal(sm1.toArray(), expected)) + + for i, j in [(-1, 1), (4, 3), (3, 5)]: + self.assertRaises(IndexError, sm1.__getitem__, (i, j)) + + # Test conversion to dense and sparse. 
+ smnew = sm1.toDense().toSparse() + self.assertEqual(sm1.numRows, smnew.numRows) + self.assertEqual(sm1.numCols, smnew.numCols) + self.assertTrue(array_equal(sm1.colPtrs, smnew.colPtrs)) + self.assertTrue(array_equal(sm1.rowIndices, smnew.rowIndices)) + self.assertTrue(array_equal(sm1.values, smnew.values)) + + sm1t = SparseMatrix( + 3, 4, [0, 2, 3, 5], [0, 1, 2, 0, 2], [3.0, 2.0, 4.0, 9.0, 8.0], + isTransposed=True) + self.assertEqual(sm1t.numRows, 3) + self.assertEqual(sm1t.numCols, 4) + self.assertEqual(sm1t.colPtrs.tolist(), [0, 2, 3, 5]) + self.assertEqual(sm1t.rowIndices.tolist(), [0, 1, 2, 0, 2]) + self.assertEqual(sm1t.values.tolist(), [3.0, 2.0, 4.0, 9.0, 8.0]) + + expected = [ + [3, 2, 0, 0], + [0, 0, 4, 0], + [9, 0, 8, 0]] + + for i in range(3): + for j in range(4): + self.assertEqual(expected[i][j], sm1t[i, j]) + self.assertTrue(array_equal(sm1t.toArray(), expected)) + + def test_dense_matrix_is_transposed(self): + mat1 = DenseMatrix(3, 2, [0, 4, 1, 6, 3, 9], isTransposed=True) + mat = DenseMatrix(3, 2, [0, 1, 3, 4, 6, 9]) + self.assertEqual(mat1, mat) + + expected = [[0, 4], [1, 6], [3, 9]] + for i in range(3): + for j in range(2): + self.assertEqual(mat1[i, j], expected[i][j]) + self.assertTrue(array_equal(mat1.toArray(), expected)) + + sm = mat1.toSparse() + self.assertTrue(array_equal(sm.rowIndices, [1, 2, 0, 1, 2])) + self.assertTrue(array_equal(sm.colPtrs, [0, 2, 5])) + self.assertTrue(array_equal(sm.values, [1, 3, 4, 6, 9])) + + def test_norms(self): + a = DenseVector([0, 2, 3, -1]) + self.assertAlmostEqual(a.norm(2), 3.742, 3) + self.assertTrue(a.norm(1), 6) + self.assertTrue(a.norm(inf), 3) + a = SparseVector(4, [0, 2], [3, -4]) + self.assertAlmostEqual(a.norm(2), 5) + self.assertTrue(a.norm(1), 7) + self.assertTrue(a.norm(inf), 4) + + tmp = SparseVector(4, [0, 2], [3, 0]) + self.assertEqual(tmp.numNonzeros(), 1) + + +class VectorUDTTests(MLlibTestCase): + + dv0 = DenseVector([]) + dv1 = DenseVector([1.0, 2.0]) + sv0 = SparseVector(2, [], []) + sv1 = SparseVector(2, [1], [2.0]) + udt = VectorUDT() + + def test_json_schema(self): + self.assertEqual(VectorUDT.fromJson(self.udt.jsonValue()), self.udt) + + def test_serialization(self): + for v in [self.dv0, self.dv1, self.sv0, self.sv1]: + self.assertEqual(v, self.udt.deserialize(self.udt.serialize(v))) + + def test_infer_schema(self): + rdd = self.sc.parallelize([Row(label=1.0, features=self.dv1), + Row(label=0.0, features=self.sv1)]) + df = rdd.toDF() + schema = df.schema + field = [f for f in schema.fields if f.name == "features"][0] + self.assertEqual(field.dataType, self.udt) + vectors = df.rdd.map(lambda p: p.features).collect() + self.assertEqual(len(vectors), 2) + for v in vectors: + if isinstance(v, SparseVector): + self.assertEqual(v, self.sv1) + elif isinstance(v, DenseVector): + self.assertEqual(v, self.dv1) + else: + raise TypeError("expecting a vector but got %r of type %r" % (v, type(v))) + + +class MatrixUDTTests(MLlibTestCase): + + dm1 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10]) + dm2 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10], isTransposed=True) + sm1 = SparseMatrix(1, 1, [0, 1], [0], [2.0]) + sm2 = SparseMatrix(2, 1, [0, 0, 1], [0], [5.0], isTransposed=True) + udt = MatrixUDT() + + def test_json_schema(self): + self.assertEqual(MatrixUDT.fromJson(self.udt.jsonValue()), self.udt) + + def test_serialization(self): + for m in [self.dm1, self.dm2, self.sm1, self.sm2]: + self.assertEqual(m, self.udt.deserialize(self.udt.serialize(m))) + + def test_infer_schema(self): + rdd = self.sc.parallelize([("dense", 
self.dm1), ("sparse", self.sm1)]) + df = rdd.toDF() + schema = df.schema + self.assertTrue(schema.fields[1].dataType, self.udt) + matrices = df.rdd.map(lambda x: x._2).collect() + self.assertEqual(len(matrices), 2) + for m in matrices: + if isinstance(m, DenseMatrix): + self.assertTrue(m, self.dm1) + elif isinstance(m, SparseMatrix): + self.assertTrue(m, self.sm1) + else: + raise ValueError("Expected a matrix but got type %r" % type(m)) + + if __name__ == "__main__": from pyspark.ml.tests import * if xmlrunner: diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index eb1f029ebb4e3..e17d13d4d2bd2 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -25,7 +25,7 @@ from pyspark.ml.param.shared import HasSeed from pyspark.ml.wrapper import JavaParams from pyspark.sql.functions import rand -from pyspark.mllib.common import inherit_doc, _py2java +from pyspark.ml.common import inherit_doc, _py2java __all__ = ['ParamGridBuilder', 'CrossValidator', 'CrossValidatorModel', 'TrainValidationSplit', 'TrainValidationSplitModel'] @@ -147,8 +147,8 @@ class CrossValidator(Estimator, ValidatorParams): >>> from pyspark.ml.classification import LogisticRegression >>> from pyspark.ml.evaluation import BinaryClassificationEvaluator - >>> from pyspark.mllib.linalg import Vectors - >>> dataset = sqlContext.createDataFrame( + >>> from pyspark.ml.linalg import Vectors + >>> dataset = spark.createDataFrame( ... [(Vectors.dense([0.0]), 0.0), ... (Vectors.dense([0.4]), 1.0), ... (Vectors.dense([0.5]), 0.0), @@ -160,6 +160,8 @@ class CrossValidator(Estimator, ValidatorParams): >>> evaluator = BinaryClassificationEvaluator() >>> cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) >>> cvModel = cv.fit(dataset) + >>> cvModel.avgMetrics[0] + 0.5 >>> evaluator.evaluate(cvModel.transform(dataset)) 0.8333... @@ -198,8 +200,7 @@ def setNumFolds(self, value): """ Sets the value of :py:attr:`numFolds`. """ - self._set(numFolds=value) - return self + return self._set(numFolds=value) @since("1.4.0") def getNumFolds(self): @@ -229,7 +230,7 @@ def _fit(self, dataset): model = est.fit(train, epm[j]) # TODO: duplicate evaluator to take extra params from input metric = eva.evaluate(model.transform(validation, epm[j])) - metrics[j] += metric + metrics[j] += metric/nFolds if eva.isLargerBetter(): bestIndex = np.argmax(metrics) @@ -270,6 +271,8 @@ def __init__(self, bestModel, avgMetrics=[]): super(CrossValidatorModel, self).__init__() #: best model from cross validation self.bestModel = bestModel + #: Average cross-validation metrics for each paramMap in + #: CrossValidator.estimatorParamMaps, in the corresponding order. self.avgMetrics = avgMetrics def _transform(self, dataset): @@ -295,12 +298,16 @@ def copy(self, extra=None): class TrainValidationSplit(Estimator, ValidatorParams): """ - Train-Validation-Split. + .. note:: Experimental + + Validation for hyper-parameter tuning. Randomly splits the input dataset into train and + validation sets, and uses evaluation metric on the validation set to select the best model. + Similar to :class:`CrossValidator`, but only splits the set once. >>> from pyspark.ml.classification import LogisticRegression >>> from pyspark.ml.evaluation import BinaryClassificationEvaluator - >>> from pyspark.mllib.linalg import Vectors - >>> dataset = sqlContext.createDataFrame( + >>> from pyspark.ml.linalg import Vectors + >>> dataset = spark.createDataFrame( ... [(Vectors.dense([0.0]), 0.0), ... (Vectors.dense([0.4]), 1.0), ... 
(Vectors.dense([0.5]), 0.0), @@ -350,8 +357,7 @@ def setTrainRatio(self, value): """ Sets the value of :py:attr:`trainRatio`. """ - self._set(trainRatio=value) - return self + return self._set(trainRatio=value) @since("2.0.0") def getTrainRatio(self): @@ -369,7 +375,7 @@ def _fit(self, dataset): seed = self.getOrDefault(self.seed) randCol = self.uid + "_rand" df = dataset.select("*", rand(seed).alias(randCol)) - metrics = np.zeros(numModels) + metrics = [0.0] * numModels condition = (df[randCol] >= tRatio) validation = df.filter(condition) train = df.filter(~condition) @@ -382,7 +388,7 @@ def _fit(self, dataset): else: bestIndex = np.argmin(metrics) bestModel = est.fit(dataset, epm[bestIndex]) - return self._copyValues(TrainValidationSplitModel(bestModel)) + return self._copyValues(TrainValidationSplitModel(bestModel, metrics)) @since("2.0.0") def copy(self, extra=None): @@ -407,15 +413,19 @@ def copy(self, extra=None): class TrainValidationSplitModel(Model, ValidatorParams): """ + .. note:: Experimental + Model from train validation split. .. versionadded:: 2.0.0 """ - def __init__(self, bestModel): + def __init__(self, bestModel, validationMetrics=[]): super(TrainValidationSplitModel, self).__init__() #: best model from cross validation self.bestModel = bestModel + #: evaluated validation metrics + self.validationMetrics = validationMetrics def _transform(self, dataset): return self.bestModel.transform(dataset) @@ -427,29 +437,34 @@ def copy(self, extra=None): and some extra params. This copies the underlying bestModel, creates a deep copy of the embedded paramMap, and copies the embedded and extra parameters over. + And, this creates a shallow copy of the validationMetrics. :param extra: Extra parameters to copy to the new instance :return: Copy of this instance """ if extra is None: extra = dict() - return TrainValidationSplitModel(self.bestModel.copy(extra)) + bestModel = self.bestModel.copy(extra) + validationMetrics = list(self.validationMetrics) + return TrainValidationSplitModel(bestModel, validationMetrics) if __name__ == "__main__": import doctest - from pyspark.context import SparkContext - from pyspark.sql import SQLContext + from pyspark.sql import SparkSession globs = globals().copy() # The small batch size here ensures that we see multiple batches, # even in these small test examples: - sc = SparkContext("local[2]", "ml.tuning tests") - sqlContext = SQLContext(sc) + spark = SparkSession.builder\ + .master("local[2]")\ + .appName("ml.tuning tests")\ + .getOrCreate() + sc = spark.sparkContext globs['sc'] = sc - globs['sqlContext'] = sqlContext + globs['spark'] = spark (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) - sc.stop() + spark.stop() if failure_count: exit(-1) diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index 9d28823196216..4a31a298096fc 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -23,7 +23,7 @@ unicode = str from pyspark import SparkContext, since -from pyspark.mllib.common import inherit_doc +from pyspark.ml.common import inherit_doc def _jvm(): diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index fef0040faf86e..25c44b7533c77 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -22,7 +22,7 @@ from pyspark.ml import Estimator, Transformer, Model from pyspark.ml.param import Params from pyspark.ml.util import _jvm -from pyspark.mllib.common import inherit_doc, _java2py, _py2java +from pyspark.ml.common import inherit_doc, 
_java2py, _py2java class JavaWrapper(object): diff --git a/python/pyspark/mllib/__init__.py b/python/pyspark/mllib/__init__.py index acba3a717d21a..ae26521ea96bf 100644 --- a/python/pyspark/mllib/__init__.py +++ b/python/pyspark/mllib/__init__.py @@ -16,7 +16,10 @@ # """ -Python bindings for MLlib. +RDD-based machine learning APIs for Python (in maintenance mode). + +The `pyspark.mllib` package is in maintenance mode as of the Spark 2.0.0 release to encourage +migration to the DataFrame-based APIs under the `pyspark.ml` package. """ from __future__ import absolute_import diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index fe5b6844bfceb..9f53ed098202b 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -48,11 +48,9 @@ def __init__(self, weights, intercept): @since('1.4.0') def setThreshold(self, value): """ - .. note:: Experimental - Sets the threshold that separates positive predictions from negative predictions. An example with prediction score greater - than or equal to this threshold is identified as an positive, + than or equal to this threshold is identified as a positive, and negative otherwise. It is used for binary classification only. """ @@ -62,8 +60,6 @@ def setThreshold(self, value): @since('1.4.0') def threshold(self): """ - .. note:: Experimental - Returns the threshold (if any) used for converting raw prediction scores into 0/1 predictions. It is used for binary classification only. @@ -73,8 +69,6 @@ def threshold(self): @since('1.4.0') def clearThreshold(self): """ - .. note:: Experimental - Clears the threshold so that `predict` will output raw prediction scores. It is used for binary classification only. """ @@ -756,12 +750,16 @@ def update(rdd): def _test(): import doctest - from pyspark import SparkContext + from pyspark.sql import SparkSession import pyspark.mllib.classification globs = pyspark.mllib.classification.__dict__.copy() - globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) + spark = SparkSession.builder\ + .master("local[4]")\ + .appName("mllib.classification tests")\ + .getOrCreate() + globs['sc'] = spark.sparkContext (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) - globs['sc'].stop() + spark.stop() if failure_count: exit(-1) diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index 95f7278dc64ce..29aa615125770 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -47,8 +47,6 @@ @inherit_doc class BisectingKMeansModel(JavaModelWrapper): """ - .. note:: Experimental - A clustering model derived from the bisecting k-means method. >>> data = array([0.0,0.0, 1.0,1.0, 9.0,8.0, 8.0,9.0]).reshape(4, 2) @@ -120,8 +118,6 @@ def computeCost(self, x): class BisectingKMeans(object): """ - .. note:: Experimental - A bisecting k-means algorithm based on the paper "A comparison of document clustering techniques" by Steinbach, Karypis, and Kumar, with modification to fit Spark. @@ -366,8 +362,6 @@ def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||" class GaussianMixtureModel(JavaModelWrapper, JavaSaveable, JavaLoader): """ - .. note:: Experimental - A clustering model derived from the Gaussian Mixture Model method. >>> from pyspark.mllib.linalg import Vectors, DenseMatrix @@ -422,7 +416,7 @@ class GaussianMixtureModel(JavaModelWrapper, JavaSaveable, JavaLoader): ... 
4.5605, 5.2043, 6.2734]) >>> clusterdata_2 = sc.parallelize(data.reshape(5,3)) >>> model = GaussianMixture.train(clusterdata_2, 2, convergenceTol=0.0001, - ... maxIterations=150, seed=10) + ... maxIterations=150, seed=4) >>> labels = model.predict(clusterdata_2).collect() >>> labels[0]==labels[1] True @@ -507,14 +501,12 @@ def load(cls, sc, path): Path to where the model is stored. """ model = cls._load_java(sc, path) - wrapper = sc._jvm.GaussianMixtureModelWrapper(model) + wrapper = sc._jvm.org.apache.spark.mllib.api.python.GaussianMixtureModelWrapper(model) return cls(wrapper) class GaussianMixture(object): """ - .. note:: Experimental - Learning algorithm for Gaussian Mixtures using the expectation-maximization algorithm. .. versionadded:: 1.3.0 @@ -565,20 +557,18 @@ def train(cls, rdd, k, convergenceTol=1e-3, maxIterations=100, seed=None, initia class PowerIterationClusteringModel(JavaModelWrapper, JavaSaveable, JavaLoader): """ - .. note:: Experimental - Model produced by [[PowerIterationClustering]]. >>> import math >>> def genCircle(r, n): - ... points = [] - ... for i in range(0, n): - ... theta = 2.0 * math.pi * i / n - ... points.append((r * math.cos(theta), r * math.sin(theta))) - ... return points + ... points = [] + ... for i in range(0, n): + ... theta = 2.0 * math.pi * i / n + ... points.append((r * math.cos(theta), r * math.sin(theta))) + ... return points >>> def sim(x, y): - ... dist2 = (x[0] - y[0]) * (x[0] - y[0]) + (x[1] - y[1]) * (x[1] - y[1]) - ... return math.exp(-dist2 / 2.0) + ... dist2 = (x[0] - y[0]) * (x[0] - y[0]) + (x[1] - y[1]) * (x[1] - y[1]) + ... return math.exp(-dist2 / 2.0) >>> r1 = 1.0 >>> n1 = 10 >>> r2 = 4.0 @@ -638,14 +628,13 @@ def load(cls, sc, path): Load a model from the given path. """ model = cls._load_java(sc, path) - wrapper = sc._jvm.PowerIterationClusteringModelWrapper(model) + wrapper =\ + sc._jvm.org.apache.spark.mllib.api.python.PowerIterationClusteringModelWrapper(model) return PowerIterationClusteringModel(wrapper) class PowerIterationClustering(object): """ - .. note:: Experimental - Power Iteration Clustering (PIC), a scalable graph clustering algorithm developed by [[http://www.icml2010.org/papers/387.pdf Lin and Cohen]]. From the abstract: PIC finds a very low-dimensional embedding of a @@ -692,8 +681,6 @@ class Assignment(namedtuple("Assignment", ["id", "cluster"])): class StreamingKMeansModel(KMeansModel): """ - .. note:: Experimental - Clustering model which can perform an online update of the centroids. The update formula for each centroid is given by @@ -793,8 +780,6 @@ def update(self, data, decayFactor, timeUnit): class StreamingKMeans(object): """ - .. note:: Experimental - Provides methods to set k, decayFactor, timeUnit to configure the KMeans algorithm for fitting and predicting on incoming dstreams. More details on how the centroids are updated are provided under the diff --git a/python/pyspark/mllib/common.py b/python/pyspark/mllib/common.py index 6bc2b1e64651e..21f0e09ea7742 100644 --- a/python/pyspark/mllib/common.py +++ b/python/pyspark/mllib/common.py @@ -60,13 +60,13 @@ def _new_smart_decode(obj): # this will call the MLlib version of pythonToJava() def _to_java_object_rdd(rdd): - """ Return an JavaRDD of Object by unpickling + """ Return a JavaRDD of Object by unpickling It will convert each Python object into Java object by Pyrolite, whenever the RDD is serialized in batch or not. 
""" rdd = rdd._reserialize(AutoBatchedSerializer(PickleSerializer())) - return rdd.ctx._jvm.SerDe.pythonToJava(rdd._jrdd, True) + return rdd.ctx._jvm.org.apache.spark.mllib.api.python.SerDe.pythonToJava(rdd._jrdd, True) def _py2java(sc, obj): @@ -85,7 +85,7 @@ def _py2java(sc, obj): pass else: data = bytearray(PickleSerializer().dumps(obj)) - obj = sc._jvm.SerDe.loads(data) + obj = sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(data) return obj @@ -98,17 +98,17 @@ def _java2py(sc, r, encoding="bytes"): clsName = 'JavaRDD' if clsName == 'JavaRDD': - jrdd = sc._jvm.SerDe.javaToPython(r) + jrdd = sc._jvm.org.apache.spark.mllib.api.python.SerDe.javaToPython(r) return RDD(jrdd, sc) if clsName == 'Dataset': return DataFrame(r, SQLContext.getOrCreate(sc)) if clsName in _picklable_classes: - r = sc._jvm.SerDe.dumps(r) + r = sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(r) elif isinstance(r, (JavaArray, JavaList)): try: - r = sc._jvm.SerDe.dumps(r) + r = sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(r) except Py4JJavaError: pass # not pickable diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py index 22e68ea5b4511..fc2a0b3b5038a 100644 --- a/python/pyspark/mllib/evaluation.py +++ b/python/pyspark/mllib/evaluation.py @@ -15,6 +15,8 @@ # limitations under the License. # +import warnings + from pyspark import since from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc from pyspark.sql import SQLContext @@ -177,9 +179,7 @@ class MulticlassMetrics(JavaModelWrapper): 1.0... >>> metrics.fMeasure(0.0, 2.0) 0.52... - >>> metrics.precision() - 0.66... - >>> metrics.recall() + >>> metrics.accuracy 0.66... >>> metrics.weightedFalsePositiveRate 0.19... @@ -233,6 +233,8 @@ def precision(self, label=None): Returns precision or precision for a given label (category) if specified. """ if label is None: + # note:: Deprecated in 2.0.0. Use accuracy. + warnings.warn("Deprecated in 2.0.0. Use accuracy.") return self.call("precision") else: return self.call("precision", float(label)) @@ -243,6 +245,8 @@ def recall(self, label=None): Returns recall or recall for a given label (category) if specified. """ if label is None: + # note:: Deprecated in 2.0.0. Use accuracy. + warnings.warn("Deprecated in 2.0.0. Use accuracy.") return self.call("recall") else: return self.call("recall", float(label)) @@ -254,6 +258,8 @@ def fMeasure(self, label=None, beta=None): """ if beta is None: if label is None: + # note:: Deprecated in 2.0.0. Use accuracy. + warnings.warn("Deprecated in 2.0.0. Use accuracy.") return self.call("fMeasure") else: return self.call("fMeasure", label) @@ -263,6 +269,15 @@ def fMeasure(self, label=None, beta=None): else: return self.call("fMeasure", label, beta) + @property + @since('2.0.0') + def accuracy(self): + """ + Returns accuracy (equals to the total number of correctly classified instances + out of the total number of instances). 
+ """ + return self.call("accuracy") + @property @since('1.4.0') def weightedTruePositiveRate(self): @@ -516,12 +531,16 @@ def accuracy(self): def _test(): import doctest - from pyspark import SparkContext + from pyspark.sql import SparkSession import pyspark.mllib.evaluation globs = pyspark.mllib.evaluation.__dict__.copy() - globs['sc'] = SparkContext('local[4]', 'PythonTest') + spark = SparkSession.builder\ + .master("local[4]")\ + .appName("mllib.evaluation tests")\ + .getOrCreate() + globs['sc'] = spark.sparkContext (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) - globs['sc'].stop() + spark.stop() if failure_count: exit(-1) diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index 90559f6cfbe43..929531862d18f 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -60,8 +60,6 @@ def transform(self, vector): class Normalizer(VectorTransformer): """ - .. note:: Experimental - Normalizes samples individually to unit L\ :sup:`p`\ norm For any 1 <= `p` < float('inf'), normalizes samples using @@ -131,8 +129,6 @@ def transform(self, vector): class StandardScalerModel(JavaVectorTransformer): """ - .. note:: Experimental - Represents a StandardScaler model that can transform vectors. .. versionadded:: 1.2.0 @@ -207,8 +203,6 @@ def mean(self): class StandardScaler(object): """ - .. note:: Experimental - Standardizes features by removing the mean and scaling to unit variance using column summary statistics on the samples in the training set. @@ -262,8 +256,6 @@ def fit(self, dataset): class ChiSqSelectorModel(JavaVectorTransformer): """ - .. note:: Experimental - Represents a Chi Squared selector model. .. versionadded:: 1.4.0 @@ -282,8 +274,6 @@ def transform(self, vector): class ChiSqSelector(object): """ - .. note:: Experimental - Creates a ChiSquared feature selector. :param numTopFeatures: number of features that selector will select. @@ -361,8 +351,6 @@ def fit(self, data): class HashingTF(object): """ - .. note:: Experimental - Maps a sequence of terms to their term frequencies using the hashing trick. @@ -448,8 +436,6 @@ def idf(self): class IDF(object): """ - .. note:: Experimental - Inverse document frequency (IDF). The standard formulation is used: `idf = log((m + 1) / (d(t) + 1))`, @@ -553,14 +539,13 @@ def load(cls, sc, path): """ jmodel = sc._jvm.org.apache.spark.mllib.feature \ .Word2VecModel.load(sc._jsc.sc(), path) - model = sc._jvm.Word2VecModelWrapper(jmodel) + model = sc._jvm.org.apache.spark.mllib.api.python.Word2VecModelWrapper(jmodel) return Word2VecModel(model) @ignore_unicode_prefix class Word2Vec(object): - """ - Word2Vec creates vector representation of words in a text corpus. + """Word2Vec creates vector representation of words in a text corpus. The algorithm first constructs a vocabulary from the corpus and then learns vector representation of words in the vocabulary. 
The vector representation can be used as features in @@ -582,13 +567,19 @@ class Word2Vec(object): >>> doc = sc.parallelize(localDoc).map(lambda line: line.split(" ")) >>> model = Word2Vec().setVectorSize(10).setSeed(42).fit(doc) + Querying for synonyms of a word will not return that word: + >>> syms = model.findSynonyms("a", 2) >>> [s[0] for s in syms] [u'b', u'c'] + + But querying for synonyms of a vector may return the word whose + representation is that vector: + >>> vec = model.transform("a") >>> syms = model.findSynonyms(vec, 2) >>> [s[0] for s in syms] - [u'b', u'c'] + [u'a', u'b'] >>> import os, tempfile >>> path = tempfile.mkdtemp() @@ -606,6 +597,7 @@ class Word2Vec(object): ... pass .. versionadded:: 1.2.0 + """ def __init__(self): """ @@ -697,8 +689,6 @@ def fit(self, data): class ElementwiseProduct(VectorTransformer): """ - .. note:: Experimental - Scales each column of the vector, with the supplied weight vector. i.e the elementwise product. @@ -732,11 +722,15 @@ def transform(self, vector): def _test(): import doctest - from pyspark import SparkContext + from pyspark.sql import SparkSession globs = globals().copy() - globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) + spark = SparkSession.builder\ + .master("local[4]")\ + .appName("mllib.feature tests")\ + .getOrCreate() + globs['sc'] = spark.sparkContext (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) - globs['sc'].stop() + spark.stop() if failure_count: exit(-1) diff --git a/python/pyspark/mllib/fpm.py b/python/pyspark/mllib/fpm.py index f339e50891166..f58ea5dfb0874 100644 --- a/python/pyspark/mllib/fpm.py +++ b/python/pyspark/mllib/fpm.py @@ -31,8 +31,6 @@ @ignore_unicode_prefix class FPGrowthModel(JavaModelWrapper, JavaSaveable, JavaLoader): """ - .. note:: Experimental - A FP-Growth model for mining frequent itemsets using the Parallel FP-Growth algorithm. @@ -64,14 +62,12 @@ def load(cls, sc, path): Load a model from the given path. """ model = cls._load_java(sc, path) - wrapper = sc._jvm.FPGrowthModelWrapper(model) + wrapper = sc._jvm.org.apache.spark.mllib.api.python.FPGrowthModelWrapper(model) return FPGrowthModel(wrapper) class FPGrowth(object): """ - .. note:: Experimental - A Parallel FP-growth algorithm to mine frequent itemsets. .. versionadded:: 1.4.0 @@ -108,8 +104,6 @@ class FreqItemset(namedtuple("FreqItemset", ["items", "freq"])): @ignore_unicode_prefix class PrefixSpanModel(JavaModelWrapper): """ - .. note:: Experimental - Model fitted by PrefixSpan >>> data = [ @@ -133,8 +127,6 @@ def freqSequences(self): class PrefixSpan(object): """ - .. note:: Experimental - A parallel PrefixSpan algorithm to mine frequent sequential patterns. The PrefixSpan algorithm is described in J. 
Pei, et al., PrefixSpan: Mining Sequential Patterns Efficiently by Prefix-Projected Pattern Growth @@ -183,16 +175,21 @@ class FreqSequence(namedtuple("FreqSequence", ["sequence", "freq"])): def _test(): import doctest + from pyspark.sql import SparkSession import pyspark.mllib.fpm globs = pyspark.mllib.fpm.__dict__.copy() - globs['sc'] = SparkContext('local[4]', 'PythonTest') + spark = SparkSession.builder\ + .master("local[4]")\ + .appName("mllib.fpm tests")\ + .getOrCreate() + globs['sc'] = spark.sparkContext import tempfile temp_path = tempfile.mkdtemp() globs['temp_path'] = temp_path try: (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) - globs['sc'].stop() + spark.stop() finally: from shutil import rmtree try: diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py index 70509a6d9bece..d37e715c8d8ec 100644 --- a/python/pyspark/mllib/linalg/__init__.py +++ b/python/pyspark/mllib/linalg/__init__.py @@ -39,6 +39,7 @@ import numpy as np from pyspark import since +from pyspark.ml import linalg as newlinalg from pyspark.sql.types import UserDefinedType, StructField, StructType, ArrayType, DoubleType, \ IntegerType, ByteType, BooleanType @@ -247,6 +248,15 @@ def toArray(self): """ raise NotImplementedError + def asML(self): + """ + Convert this vector to the new mllib-local representation. + This does NOT copy the data; it copies references. + + :return: :py:class:`pyspark.ml.linalg.Vector` + """ + raise NotImplementedError + class DenseVector(Vector): """ @@ -408,6 +418,17 @@ def toArray(self): """ return self.array + def asML(self): + """ + Convert this vector to the new mllib-local representation. + This does NOT copy the data; it copies references. + + :return: :py:class:`pyspark.ml.linalg.DenseVector` + + .. versionadded:: 2.0.0 + """ + return newlinalg.DenseVector(self.array) + @property def values(self): """ @@ -569,7 +590,7 @@ def parse(s): if start == -1: raise ValueError("Tuple should start with '('") end = s.find(')') - if start == -1: + if end == -1: raise ValueError("Tuple should end with ')'") s = s[start + 1: end].strip() @@ -737,6 +758,17 @@ def toArray(self): arr[self.indices] = self.values return arr + def asML(self): + """ + Convert this vector to the new mllib-local representation. + This does NOT copy the data; it copies references. + + :return: :py:class:`pyspark.ml.linalg.SparseVector` + + .. versionadded:: 2.0.0 + """ + return newlinalg.SparseVector(self.size, self.indices, self.values) + def __len__(self): return self.size @@ -770,7 +802,7 @@ def __getitem__(self, index): "Indices must be of type integer, got type %s" % type(index)) if index >= self.size or index < -self.size: - raise ValueError("Index %d out of bounds." % index) + raise IndexError("Index %d out of bounds." % index) if index < 0: index += self.size @@ -845,6 +877,24 @@ def dense(*elements): elements = elements[0] return DenseVector(elements) + @staticmethod + def fromML(vec): + """ + Convert a vector from the new mllib-local representation. + This does NOT copy the data; it copies references. + + :param vec: a :py:class:`pyspark.ml.linalg.Vector` + :return: a :py:class:`pyspark.mllib.linalg.Vector` + + .. 
versionadded:: 2.0.0 + """ + if isinstance(vec, newlinalg.DenseVector): + return DenseVector(vec.array) + elif isinstance(vec, newlinalg.SparseVector): + return SparseVector(vec.size, vec.indices, vec.values) + else: + raise TypeError("Unsupported vector type %s" % type(vec)) + @staticmethod def stringify(vector): """ @@ -945,6 +995,13 @@ def toArray(self): """ raise NotImplementedError + def asML(self): + """ + Convert this matrix to the new mllib-local representation. + This does NOT copy the data; it copies references. + """ + raise NotImplementedError + @staticmethod def _convert_to_array(array_like, dtype): """ @@ -1044,13 +1101,24 @@ def toSparse(self): return SparseMatrix(self.numRows, self.numCols, colPtrs, rowIndices, values) + def asML(self): + """ + Convert this matrix to the new mllib-local representation. + This does NOT copy the data; it copies references. + + :return: :py:class:`pyspark.ml.linalg.DenseMatrix` + + .. versionadded:: 2.0.0 + """ + return newlinalg.DenseMatrix(self.numRows, self.numCols, self.values, self.isTransposed) + def __getitem__(self, indices): i, j = indices if i < 0 or i >= self.numRows: - raise ValueError("Row index %d is out of range [0, %d)" + raise IndexError("Row index %d is out of range [0, %d)" % (i, self.numRows)) if j >= self.numCols or j < 0: - raise ValueError("Column index %d is out of range [0, %d)" + raise IndexError("Column index %d is out of range [0, %d)" % (j, self.numCols)) if self.isTransposed: @@ -1177,10 +1245,10 @@ def __reduce__(self): def __getitem__(self, indices): i, j = indices if i < 0 or i >= self.numRows: - raise ValueError("Row index %d is out of range [0, %d)" + raise IndexError("Row index %d is out of range [0, %d)" % (i, self.numRows)) if j < 0 or j >= self.numCols: - raise ValueError("Column index %d is out of range [0, %d)" + raise IndexError("Column index %d is out of range [0, %d)" % (j, self.numCols)) # If a CSR matrix is given, then the row index should be searched @@ -1216,6 +1284,18 @@ def toDense(self): densevals = np.ravel(self.toArray(), order='F') return DenseMatrix(self.numRows, self.numCols, densevals) + def asML(self): + """ + Convert this matrix to the new mllib-local representation. + This does NOT copy the data; it copies references. + + :return: :py:class:`pyspark.ml.linalg.SparseMatrix` + + .. versionadded:: 2.0.0 + """ + return newlinalg.SparseMatrix(self.numRows, self.numCols, self.colPtrs, self.rowIndices, + self.values, self.isTransposed) + # TODO: More efficient implementation: def __eq__(self, other): return np.all(self.toArray() == other.toArray()) @@ -1236,11 +1316,28 @@ def sparse(numRows, numCols, colPtrs, rowIndices, values): """ return SparseMatrix(numRows, numCols, colPtrs, rowIndices, values) + @staticmethod + def fromML(mat): + """ + Convert a matrix from the new mllib-local representation. + This does NOT copy the data; it copies references. + + :param mat: a :py:class:`pyspark.ml.linalg.Matrix` + :return: a :py:class:`pyspark.mllib.linalg.Matrix` + + .. versionadded:: 2.0.0 + """ + if isinstance(mat, newlinalg.DenseMatrix): + return DenseMatrix(mat.numRows, mat.numCols, mat.values, mat.isTransposed) + elif isinstance(mat, newlinalg.SparseMatrix): + return SparseMatrix(mat.numRows, mat.numCols, mat.colPtrs, mat.rowIndices, + mat.values, mat.isTransposed) + else: + raise TypeError("Unsupported matrix type %s" % type(mat)) + class QRDecomposition(object): """ - .. note:: Experimental - Represents QR factors. 
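The `asML`/`fromML` additions above bridge the old `pyspark.mllib.linalg` local types and the new `pyspark.ml.linalg` ones without copying data. A minimal round-trip sketch (assuming Spark 2.0+ with both packages available; the variable names and values are illustrative):

```python
from pyspark.mllib.linalg import Vectors as MLlibVectors, Matrices as MLlibMatrices
from pyspark.ml.linalg import SparseVector as MLSparseVector, DenseMatrix as MLDenseMatrix

# Old-style mllib local types.
old_vec = MLlibVectors.sparse(4, {1: 1.0, 3: 5.5})
old_mat = MLlibMatrices.dense(2, 2, [1, 2, 3, 4])

# asML() wraps the same underlying data in the new ml-local classes.
new_vec = old_vec.asML()
new_mat = old_mat.asML()
assert isinstance(new_vec, MLSparseVector)
assert isinstance(new_mat, MLDenseMatrix)

# fromML() converts back, again by reference rather than by copy.
assert MLlibVectors.fromML(new_vec) == old_vec
assert MLlibMatrices.fromML(new_mat) == old_mat
```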
""" def __init__(self, Q, R): diff --git a/python/pyspark/mllib/linalg/distributed.py b/python/pyspark/mllib/linalg/distributed.py index af34ce346b0ca..538cada7d163d 100644 --- a/python/pyspark/mllib/linalg/distributed.py +++ b/python/pyspark/mllib/linalg/distributed.py @@ -40,8 +40,6 @@ class DistributedMatrix(object): """ - .. note:: Experimental - Represents a distributively stored matrix backed by one or more RDDs. @@ -57,8 +55,6 @@ def numCols(self): class RowMatrix(DistributedMatrix): """ - .. note:: Experimental - Represents a row-oriented distributed Matrix with no meaningful row indices. @@ -306,8 +302,6 @@ def tallSkinnyQR(self, computeQ=False): class IndexedRow(object): """ - .. note:: Experimental - Represents a row of an IndexedRowMatrix. Just a wrapper over a (long, vector) tuple. @@ -334,8 +328,6 @@ def _convert_to_indexed_row(row): class IndexedRowMatrix(DistributedMatrix): """ - .. note:: Experimental - Represents a row-oriented distributed Matrix with indexed rows. :param rows: An RDD of IndexedRows or (long, vector) tuples. @@ -536,8 +528,6 @@ def toBlockMatrix(self, rowsPerBlock=1024, colsPerBlock=1024): class MatrixEntry(object): """ - .. note:: Experimental - Represents an entry of a CoordinateMatrix. Just a wrapper over a (long, long, float) tuple. @@ -566,8 +556,6 @@ def _convert_to_matrix_entry(entry): class CoordinateMatrix(DistributedMatrix): """ - .. note:: Experimental - Represents a matrix in coordinate format. :param entries: An RDD of MatrixEntry inputs or @@ -795,8 +783,6 @@ def _convert_to_matrix_block_tuple(block): class BlockMatrix(DistributedMatrix): """ - .. note:: Experimental - Represents a distributed matrix in blocks of local matrices. :param blocks: An RDD of sub-matrix blocks @@ -1184,16 +1170,18 @@ def toCoordinateMatrix(self): def _test(): import doctest - from pyspark import SparkContext - from pyspark.sql import SQLContext + from pyspark.sql import SparkSession from pyspark.mllib.linalg import Matrices import pyspark.mllib.linalg.distributed globs = pyspark.mllib.linalg.distributed.__dict__.copy() - globs['sc'] = SparkContext('local[2]', 'PythonTest', batchSize=2) - globs['sqlContext'] = SQLContext(globs['sc']) + spark = SparkSession.builder\ + .master("local[2]")\ + .appName("mllib.linalg.distributed tests")\ + .getOrCreate() + globs['sc'] = spark.sparkContext globs['Matrices'] = Matrices (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) - globs['sc'].stop() + spark.stop() if failure_count: exit(-1) diff --git a/python/pyspark/mllib/random.py b/python/pyspark/mllib/random.py index 6a3c643b66417..61213ddf62e8b 100644 --- a/python/pyspark/mllib/random.py +++ b/python/pyspark/mllib/random.py @@ -409,13 +409,17 @@ def gammaVectorRDD(sc, shape, scale, numRows, numCols, numPartitions=None, seed= def _test(): import doctest - from pyspark.context import SparkContext + from pyspark.sql import SparkSession globs = globals().copy() # The small batch size here ensures that we see multiple batches, # even in these small test examples: - globs['sc'] = SparkContext('local[2]', 'PythonTest', batchSize=2) + spark = SparkSession.builder\ + .master("local[2]")\ + .appName("mllib.random tests")\ + .getOrCreate() + globs['sc'] = spark.sparkContext (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) - globs['sc'].stop() + spark.stop() if failure_count: exit(-1) diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index 7e60255d43ead..732300ee9c2c9 
100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -207,7 +207,7 @@ def rank(self): def load(cls, sc, path): """Load a model from the given path""" model = cls._load_java(sc, path) - wrapper = sc._jvm.MatrixFactorizationModelWrapper(model) + wrapper = sc._jvm.org.apache.spark.mllib.api.python.MatrixFactorizationModelWrapper(model) return MatrixFactorizationModel(wrapper) diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 639c5eabaa23b..705022934e41b 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -648,7 +648,7 @@ def predict(self, x): @since("1.4.0") def save(self, sc, path): - """Save a IsotonicRegressionModel.""" + """Save an IsotonicRegressionModel.""" java_boundaries = _py2java(sc, self.boundaries.tolist()) java_predictions = _py2java(sc, self.predictions.tolist()) java_model = sc._jvm.org.apache.spark.mllib.regression.IsotonicRegressionModel( @@ -658,7 +658,7 @@ def save(self, sc, path): @classmethod @since("1.4.0") def load(cls, sc, path): - """Load a IsotonicRegressionModel.""" + """Load an IsotonicRegressionModel.""" java_model = sc._jvm.org.apache.spark.mllib.regression.IsotonicRegressionModel.load( sc._jsc.sc(), path) py_boundaries = _java2py(sc, java_model.boundaryVector()).toArray() @@ -694,7 +694,7 @@ class IsotonicRegression(object): @since("1.4.0") def train(cls, data, isotonic=True): """ - Train a isotonic regression model on the given data. + Train an isotonic regression model on the given data. :param data: RDD of (label, feature, weight) tuples. @@ -824,12 +824,16 @@ def update(rdd): def _test(): import doctest - from pyspark import SparkContext + from pyspark.sql import SparkSession import pyspark.mllib.regression globs = pyspark.mllib.regression.__dict__.copy() - globs['sc'] = SparkContext('local[2]', 'PythonTest', batchSize=2) + spark = SparkSession.builder\ + .master("local[2]")\ + .appName("mllib.regression tests")\ + .getOrCreate() + globs['sc'] = spark.sparkContext (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) - globs['sc'].stop() + spark.stop() if failure_count: exit(-1) diff --git a/python/pyspark/mllib/stat/KernelDensity.py b/python/pyspark/mllib/stat/KernelDensity.py index 7da921976d4d2..3b1c5519bd87e 100644 --- a/python/pyspark/mllib/stat/KernelDensity.py +++ b/python/pyspark/mllib/stat/KernelDensity.py @@ -28,8 +28,6 @@ class KernelDensity(object): """ - .. note:: Experimental - Estimate probability density at required points given a RDD of samples from the population. diff --git a/python/pyspark/mllib/stat/_statistics.py b/python/pyspark/mllib/stat/_statistics.py index 36c8f48a4a882..67d5f0e44f41c 100644 --- a/python/pyspark/mllib/stat/_statistics.py +++ b/python/pyspark/mllib/stat/_statistics.py @@ -160,8 +160,6 @@ def corr(x, y=None, method=None): @ignore_unicode_prefix def chiSqTest(observed, expected=None): """ - .. note:: Experimental - If `observed` is Vector, conduct Pearson's chi-squared goodness of fit test of the observed data against the expected distribution, or againt the uniform distribution (by default), with each category @@ -246,8 +244,6 @@ def chiSqTest(observed, expected=None): @ignore_unicode_prefix def kolmogorovSmirnovTest(data, distName="norm", *params): """ - .. note:: Experimental - Performs the Kolmogorov-Smirnov (KS) test for data sampled from a continuous distribution. It tests the null hypothesis that the data is generated from a particular distribution. 
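The chi-squared and Kolmogorov-Smirnov tests touched above are exposed through `pyspark.mllib.stat.Statistics`. A short usage sketch, assuming an active SparkContext `sc`; the sample data is made up for illustration:

```python
from pyspark.mllib.linalg import Vectors, Matrices
from pyspark.mllib.stat import Statistics

# Pearson goodness-of-fit test of observed counts against the (default) uniform distribution.
observed = Vectors.dense([4.0, 6.0, 5.0])
gof = Statistics.chiSqTest(observed)
print(gof.statistic, gof.pValue, gof.degreesOfFreedom)

# Independence test on a 2x2 contingency matrix.
counts = Matrices.dense(2, 2, [40.0, 24.0, 29.0, 56.0])
print(Statistics.chiSqTest(counts).pValue)

# One-sample, two-sided Kolmogorov-Smirnov test against a standard normal N(0, 1).
samples = sc.parallelize([0.1, 0.15, 0.2, 0.3, 0.25])
ks = Statistics.kolmogorovSmirnovTest(samples, "norm", 0.0, 1.0)
print(ks.statistic, ks.pValue)
```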
@@ -306,11 +302,15 @@ def kolmogorovSmirnovTest(data, distName="norm", *params): def _test(): import doctest - from pyspark import SparkContext + from pyspark.sql import SparkSession globs = globals().copy() - globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) + spark = SparkSession.builder\ + .master("local[4]")\ + .appName("mllib.stat.statistics tests")\ + .getOrCreate() + globs['sc'] = spark.sparkContext (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) - globs['sc'].stop() + spark.stop() if failure_count: exit(-1) diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 53a1d2c59cb2f..c519883cdd73b 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -49,6 +49,7 @@ import unittest from pyspark import SparkContext +import pyspark.ml.linalg as newlinalg from pyspark.mllib.common import _to_java_object_rdd from pyspark.mllib.clustering import StreamingKMeans, StreamingKMeansModel from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\ @@ -66,7 +67,8 @@ from pyspark.mllib.util import MLUtils from pyspark.serializers import PickleSerializer from pyspark.streaming import StreamingContext -from pyspark.sql import SQLContext +from pyspark.sql import SparkSession +from pyspark.sql.utils import IllegalArgumentException from pyspark.streaming import StreamingContext _have_scipy = False @@ -83,9 +85,10 @@ class MLlibTestCase(unittest.TestCase): def setUp(self): self.sc = SparkContext('local[4]', "MLlib tests") + self.spark = SparkSession(self.sc) def tearDown(self): - self.sc.stop() + self.spark.stop() class MLLibStreamingTestCase(unittest.TestCase): @@ -147,12 +150,12 @@ class VectorTests(MLlibTestCase): def _test_serialize(self, v): self.assertEqual(v, ser.loads(ser.dumps(v))) - jvec = self.sc._jvm.SerDe.loads(bytearray(ser.dumps(v))) - nv = ser.loads(bytes(self.sc._jvm.SerDe.dumps(jvec))) + jvec = self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(bytearray(ser.dumps(v))) + nv = ser.loads(bytes(self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(jvec))) self.assertEqual(v, nv) vs = [v] * 100 - jvecs = self.sc._jvm.SerDe.loads(bytearray(ser.dumps(vs))) - nvs = ser.loads(bytes(self.sc._jvm.SerDe.dumps(jvecs))) + jvecs = self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(bytearray(ser.dumps(vs))) + nvs = ser.loads(bytes(self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(jvecs))) self.assertEqual(vs, nvs) def test_serialize(self): @@ -257,7 +260,7 @@ def test_sparse_vector_indexing(self): self.assertEqual(sv[-3], 0.) self.assertEqual(sv[-5], 0.) 
for ind in [5, -6]: - self.assertRaises(ValueError, sv.__getitem__, ind) + self.assertRaises(IndexError, sv.__getitem__, ind) for ind in [7.8, '1']: self.assertRaises(TypeError, sv.__getitem__, ind) @@ -265,11 +268,15 @@ def test_sparse_vector_indexing(self): self.assertEqual(zeros[0], 0.0) self.assertEqual(zeros[3], 0.0) for ind in [4, -5]: - self.assertRaises(ValueError, zeros.__getitem__, ind) + self.assertRaises(IndexError, zeros.__getitem__, ind) empty = SparseVector(0, {}) for ind in [-1, 0, 1]: - self.assertRaises(ValueError, empty.__getitem__, ind) + self.assertRaises(IndexError, empty.__getitem__, ind) + + def test_sparse_vector_iteration(self): + self.assertListEqual(list(SparseVector(3, [], [])), [0.0, 0.0, 0.0]) + self.assertListEqual(list(SparseVector(5, [0, 3], [1.0, 2.0])), [1.0, 0.0, 0.0, 2.0, 0.0]) def test_matrix_indexing(self): mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10]) @@ -278,6 +285,9 @@ def test_matrix_indexing(self): for j in range(2): self.assertEqual(mat[i, j], expected[i][j]) + for i, j in [(-1, 0), (4, 1), (3, 4)]: + self.assertRaises(IndexError, mat.__getitem__, (i, j)) + def test_repr_dense_matrix(self): mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10]) self.assertTrue( @@ -349,6 +359,9 @@ def test_sparse_matrix(self): self.assertEqual(expected[i][j], sm1[i, j]) self.assertTrue(array_equal(sm1.toArray(), expected)) + for i, j in [(-1, 1), (4, 3), (3, 5)]: + self.assertRaises(IndexError, sm1.__getitem__, (i, j)) + # Test conversion to dense and sparse. smnew = sm1.toDense().toSparse() self.assertEqual(sm1.numRows, smnew.numRows) @@ -421,6 +434,74 @@ def test_norms(self): tmp = SparseVector(4, [0, 2], [3, 0]) self.assertEqual(tmp.numNonzeros(), 1) + def test_ml_mllib_vector_conversion(self): + # to ml + # dense + mllibDV = Vectors.dense([1, 2, 3]) + mlDV1 = newlinalg.Vectors.dense([1, 2, 3]) + mlDV2 = mllibDV.asML() + self.assertEqual(mlDV2, mlDV1) + # sparse + mllibSV = Vectors.sparse(4, {1: 1.0, 3: 5.5}) + mlSV1 = newlinalg.Vectors.sparse(4, {1: 1.0, 3: 5.5}) + mlSV2 = mllibSV.asML() + self.assertEqual(mlSV2, mlSV1) + # from ml + # dense + mllibDV1 = Vectors.dense([1, 2, 3]) + mlDV = newlinalg.Vectors.dense([1, 2, 3]) + mllibDV2 = Vectors.fromML(mlDV) + self.assertEqual(mllibDV1, mllibDV2) + # sparse + mllibSV1 = Vectors.sparse(4, {1: 1.0, 3: 5.5}) + mlSV = newlinalg.Vectors.sparse(4, {1: 1.0, 3: 5.5}) + mllibSV2 = Vectors.fromML(mlSV) + self.assertEqual(mllibSV1, mllibSV2) + + def test_ml_mllib_matrix_conversion(self): + # to ml + # dense + mllibDM = Matrices.dense(2, 2, [0, 1, 2, 3]) + mlDM1 = newlinalg.Matrices.dense(2, 2, [0, 1, 2, 3]) + mlDM2 = mllibDM.asML() + self.assertEqual(mlDM2, mlDM1) + # transposed + mllibDMt = DenseMatrix(2, 2, [0, 1, 2, 3], True) + mlDMt1 = newlinalg.DenseMatrix(2, 2, [0, 1, 2, 3], True) + mlDMt2 = mllibDMt.asML() + self.assertEqual(mlDMt2, mlDMt1) + # sparse + mllibSM = Matrices.sparse(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4]) + mlSM1 = newlinalg.Matrices.sparse(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4]) + mlSM2 = mllibSM.asML() + self.assertEqual(mlSM2, mlSM1) + # transposed + mllibSMt = SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4], True) + mlSMt1 = newlinalg.SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4], True) + mlSMt2 = mllibSMt.asML() + self.assertEqual(mlSMt2, mlSMt1) + # from ml + # dense + mllibDM1 = Matrices.dense(2, 2, [1, 2, 3, 4]) + mlDM = newlinalg.Matrices.dense(2, 2, [1, 2, 3, 4]) + mllibDM2 = Matrices.fromML(mlDM) + self.assertEqual(mllibDM1, mllibDM2) + # transposed + mllibDMt1 = DenseMatrix(2, 2, [1, 2, 3, 4], 
True) + mlDMt = newlinalg.DenseMatrix(2, 2, [1, 2, 3, 4], True) + mllibDMt2 = Matrices.fromML(mlDMt) + self.assertEqual(mllibDMt1, mllibDMt2) + # sparse + mllibSM1 = Matrices.sparse(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4]) + mlSM = newlinalg.Matrices.sparse(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4]) + mllibSM2 = Matrices.fromML(mlSM) + self.assertEqual(mllibSM1, mllibSM2) + # transposed + mllibSMt1 = SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4], True) + mlSMt = newlinalg.SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4], True) + mllibSMt2 = Matrices.fromML(mlSMt) + self.assertEqual(mllibSMt1, mllibSMt2) + class ListTests(MLlibTestCase): @@ -479,7 +560,7 @@ def test_gmm(self): [-6, -7], ]) clusters = GaussianMixture.train(data, 2, convergenceTol=0.001, - maxIterations=10, seed=56) + maxIterations=10, seed=1) labels = clusters.predict(data).collect() self.assertEqual(labels[0], labels[1]) self.assertEqual(labels[2], labels[3]) @@ -698,7 +779,6 @@ def test_serialization(self): self.assertEqual(v, self.udt.deserialize(self.udt.serialize(v))) def test_infer_schema(self): - sqlCtx = SQLContext(self.sc) rdd = self.sc.parallelize([LabeledPoint(1.0, self.dv1), LabeledPoint(0.0, self.sv1)]) df = rdd.toDF() schema = df.schema @@ -731,7 +811,6 @@ def test_serialization(self): self.assertEqual(m, self.udt.deserialize(self.udt.serialize(m))) def test_infer_schema(self): - sqlCtx = SQLContext(self.sc) rdd = self.sc.parallelize([("dense", self.dm1), ("sparse", self.sm1)]) df = rdd.toDF() schema = df.schema @@ -919,7 +998,7 @@ def test_goodness_of_fit(self): # Negative counts in observed neg_obs = Vectors.dense([1.0, 2.0, 3.0, -4.0]) - self.assertRaises(Py4JJavaError, Statistics.chiSqTest, neg_obs, expected1) + self.assertRaises(IllegalArgumentException, Statistics.chiSqTest, neg_obs, expected1) # Count = 0.0 in expected but not observed zero_expected = Vectors.dense([1.0, 0.0, 3.0]) @@ -930,7 +1009,8 @@ def test_goodness_of_fit(self): # 0.0 in expected and observed simultaneously zero_observed = Vectors.dense([2.0, 0.0, 1.0]) - self.assertRaises(Py4JJavaError, Statistics.chiSqTest, zero_observed, zero_expected) + self.assertRaises( + IllegalArgumentException, Statistics.chiSqTest, zero_observed, zero_expected) def test_matrix_independence(self): data = [40.0, 24.0, 29.0, 56.0, 32.0, 42.0, 31.0, 10.0, 0.0, 30.0, 15.0, 12.0] @@ -944,15 +1024,15 @@ def test_matrix_independence(self): # Negative counts neg_counts = Matrices.dense(2, 2, [4.0, 5.0, 3.0, -3.0]) - self.assertRaises(Py4JJavaError, Statistics.chiSqTest, neg_counts) + self.assertRaises(IllegalArgumentException, Statistics.chiSqTest, neg_counts) # Row sum = 0.0 row_zero = Matrices.dense(2, 2, [0.0, 1.0, 0.0, 2.0]) - self.assertRaises(Py4JJavaError, Statistics.chiSqTest, row_zero) + self.assertRaises(IllegalArgumentException, Statistics.chiSqTest, row_zero) # Column sum = 0.0 col_zero = Matrices.dense(2, 2, [0.0, 0.0, 2.0, 2.0]) - self.assertRaises(Py4JJavaError, Statistics.chiSqTest, col_zero) + self.assertRaises(IllegalArgumentException, Statistics.chiSqTest, col_zero) def test_chi_sq_pearson(self): data = [ @@ -1580,8 +1660,8 @@ class ALSTests(MLlibTestCase): def test_als_ratings_serialize(self): r = Rating(7, 1123, 3.14) - jr = self.sc._jvm.SerDe.loads(bytearray(ser.dumps(r))) - nr = ser.loads(bytes(self.sc._jvm.SerDe.dumps(jr))) + jr = self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(bytearray(ser.dumps(r))) + nr = ser.loads(bytes(self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(jr))) self.assertEqual(r.user, nr.user) 
self.assertEqual(r.product, nr.product) self.assertAlmostEqual(r.rating, nr.rating, 2) @@ -1589,7 +1669,8 @@ def test_als_ratings_serialize(self): def test_als_ratings_id_long_error(self): r = Rating(1205640308657491975, 50233468418, 1.0) # rating user id exceeds max int value, should fail when pickled - self.assertRaises(Py4JJavaError, self.sc._jvm.SerDe.loads, bytearray(ser.dumps(r))) + self.assertRaises(Py4JJavaError, self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads, + bytearray(ser.dumps(r))) class HashingTFTest(MLlibTestCase): diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index f7ea466b43291..b3011d42e56af 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -76,8 +76,6 @@ def toDebugString(self): class DecisionTreeModel(JavaModelWrapper, JavaSaveable, JavaLoader): """ - .. note:: Experimental - A decision tree model for classification or regression. .. versionadded:: 1.1.0 @@ -130,8 +128,6 @@ def _java_loader_class(cls): class DecisionTree(object): """ - .. note:: Experimental - Learning algorithm for a decision tree model for classification or regression. @@ -283,8 +279,6 @@ def trainRegressor(cls, data, categoricalFeaturesInfo, @inherit_doc class RandomForestModel(TreeEnsembleModel, JavaLoader): """ - .. note:: Experimental - Represents a random forest model. .. versionadded:: 1.2.0 @@ -297,8 +291,6 @@ def _java_loader_class(cls): class RandomForest(object): """ - .. note:: Experimental - Learning algorithm for a random forest model for classification or regression. @@ -486,8 +478,6 @@ def trainRegressor(cls, data, categoricalFeaturesInfo, numTrees, featureSubsetSt @inherit_doc class GradientBoostedTreesModel(TreeEnsembleModel, JavaLoader): """ - .. note:: Experimental - Represents a gradient-boosted tree model. .. versionadded:: 1.3.0 @@ -500,8 +490,6 @@ def _java_loader_class(cls): class GradientBoostedTrees(object): """ - .. note:: Experimental - Learning algorithm for a gradient boosted trees model for classification or regression. @@ -657,9 +645,14 @@ def trainRegressor(cls, data, categoricalFeaturesInfo, def _test(): import doctest globs = globals().copy() - globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) + from pyspark.sql import SparkSession + spark = SparkSession.builder\ + .master("local[4]")\ + .appName("mllib.tree tests")\ + .getOrCreate() + globs['sc'] = spark.sparkContext (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) - globs['sc'].stop() + spark.stop() if failure_count: exit(-1) diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index 39bc6586dd582..ed6fd4bca4c54 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -26,6 +26,7 @@ from pyspark import SparkContext, since from pyspark.mllib.common import callMLlibFunc, inherit_doc from pyspark.mllib.linalg import Vectors, SparseVector, _convert_to_vector +from pyspark.sql import DataFrame class MLUtils(object): @@ -139,8 +140,8 @@ def saveAsLibSVMFile(data, dir): >>> from pyspark.mllib.regression import LabeledPoint >>> from glob import glob >>> from pyspark.mllib.util import MLUtils - >>> examples = [LabeledPoint(1.1, Vectors.sparse(3, [(0, 1.23), (2, 4.56)])), \ - LabeledPoint(0.0, Vectors.dense([1.01, 2.02, 3.03]))] + >>> examples = [LabeledPoint(1.1, Vectors.sparse(3, [(0, 1.23), (2, 4.56)])), + ... 
LabeledPoint(0.0, Vectors.dense([1.01, 2.02, 3.03]))] >>> tempFile = NamedTemporaryFile(delete=True) >>> tempFile.close() >>> MLUtils.saveAsLibSVMFile(sc.parallelize(examples), tempFile.name) @@ -165,8 +166,8 @@ def loadLabeledPoints(sc, path, minPartitions=None): >>> from tempfile import NamedTemporaryFile >>> from pyspark.mllib.util import MLUtils >>> from pyspark.mllib.regression import LabeledPoint - >>> examples = [LabeledPoint(1.1, Vectors.sparse(3, [(0, -1.23), (2, 4.56e-7)])), \ - LabeledPoint(0.0, Vectors.dense([1.01, 2.02, 3.03]))] + >>> examples = [LabeledPoint(1.1, Vectors.sparse(3, [(0, -1.23), (2, 4.56e-7)])), + ... LabeledPoint(0.0, Vectors.dense([1.01, 2.02, 3.03]))] >>> tempFile = NamedTemporaryFile(delete=True) >>> tempFile.close() >>> sc.parallelize(examples, 1).saveAsTextFile(tempFile.name) @@ -200,6 +201,166 @@ def loadVectors(sc, path): """ return callMLlibFunc("loadVectors", sc, path) + @staticmethod + @since("2.0.0") + def convertVectorColumnsToML(dataset, *cols): + """ + Converts vector columns in an input DataFrame from the + :py:class:`pyspark.mllib.linalg.Vector` type to the new + :py:class:`pyspark.ml.linalg.Vector` type under the `spark.ml` + package. + + :param dataset: + input dataset + :param cols: + a list of vector columns to be converted. + New vector columns will be ignored. If unspecified, all old + vector columns will be converted excepted nested ones. + :return: + the input dataset with old vector columns converted to the + new vector type + + >>> import pyspark + >>> from pyspark.mllib.linalg import Vectors + >>> from pyspark.mllib.util import MLUtils + >>> df = spark.createDataFrame( + ... [(0, Vectors.sparse(2, [1], [1.0]), Vectors.dense(2.0, 3.0))], + ... ["id", "x", "y"]) + >>> r1 = MLUtils.convertVectorColumnsToML(df).first() + >>> isinstance(r1.x, pyspark.ml.linalg.SparseVector) + True + >>> isinstance(r1.y, pyspark.ml.linalg.DenseVector) + True + >>> r2 = MLUtils.convertVectorColumnsToML(df, "x").first() + >>> isinstance(r2.x, pyspark.ml.linalg.SparseVector) + True + >>> isinstance(r2.y, pyspark.mllib.linalg.DenseVector) + True + """ + if not isinstance(dataset, DataFrame): + raise TypeError("Input dataset must be a DataFrame but got {}.".format(type(dataset))) + return callMLlibFunc("convertVectorColumnsToML", dataset, list(cols)) + + @staticmethod + @since("2.0.0") + def convertVectorColumnsFromML(dataset, *cols): + """ + Converts vector columns in an input DataFrame to the + :py:class:`pyspark.mllib.linalg.Vector` type from the new + :py:class:`pyspark.ml.linalg.Vector` type under the `spark.ml` + package. + + :param dataset: + input dataset + :param cols: + a list of vector columns to be converted. + Old vector columns will be ignored. If unspecified, all new + vector columns will be converted except nested ones. + :return: + the input dataset with new vector columns converted to the + old vector type + + >>> import pyspark + >>> from pyspark.ml.linalg import Vectors + >>> from pyspark.mllib.util import MLUtils + >>> df = spark.createDataFrame( + ... [(0, Vectors.sparse(2, [1], [1.0]), Vectors.dense(2.0, 3.0))], + ... 
["id", "x", "y"]) + >>> r1 = MLUtils.convertVectorColumnsFromML(df).first() + >>> isinstance(r1.x, pyspark.mllib.linalg.SparseVector) + True + >>> isinstance(r1.y, pyspark.mllib.linalg.DenseVector) + True + >>> r2 = MLUtils.convertVectorColumnsFromML(df, "x").first() + >>> isinstance(r2.x, pyspark.mllib.linalg.SparseVector) + True + >>> isinstance(r2.y, pyspark.ml.linalg.DenseVector) + True + """ + if not isinstance(dataset, DataFrame): + raise TypeError("Input dataset must be a DataFrame but got {}.".format(type(dataset))) + return callMLlibFunc("convertVectorColumnsFromML", dataset, list(cols)) + + @staticmethod + @since("2.0.0") + def convertMatrixColumnsToML(dataset, *cols): + """ + Converts matrix columns in an input DataFrame from the + :py:class:`pyspark.mllib.linalg.Matrix` type to the new + :py:class:`pyspark.ml.linalg.Matrix` type under the `spark.ml` + package. + + :param dataset: + input dataset + :param cols: + a list of matrix columns to be converted. + New matrix columns will be ignored. If unspecified, all old + matrix columns will be converted excepted nested ones. + :return: + the input dataset with old matrix columns converted to the + new matrix type + + >>> import pyspark + >>> from pyspark.mllib.linalg import Matrices + >>> from pyspark.mllib.util import MLUtils + >>> df = spark.createDataFrame( + ... [(0, Matrices.sparse(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4]), + ... Matrices.dense(2, 2, range(4)))], ["id", "x", "y"]) + >>> r1 = MLUtils.convertMatrixColumnsToML(df).first() + >>> isinstance(r1.x, pyspark.ml.linalg.SparseMatrix) + True + >>> isinstance(r1.y, pyspark.ml.linalg.DenseMatrix) + True + >>> r2 = MLUtils.convertMatrixColumnsToML(df, "x").first() + >>> isinstance(r2.x, pyspark.ml.linalg.SparseMatrix) + True + >>> isinstance(r2.y, pyspark.mllib.linalg.DenseMatrix) + True + """ + if not isinstance(dataset, DataFrame): + raise TypeError("Input dataset must be a DataFrame but got {}.".format(type(dataset))) + return callMLlibFunc("convertMatrixColumnsToML", dataset, list(cols)) + + @staticmethod + @since("2.0.0") + def convertMatrixColumnsFromML(dataset, *cols): + """ + Converts matrix columns in an input DataFrame to the + :py:class:`pyspark.mllib.linalg.Matrix` type from the new + :py:class:`pyspark.ml.linalg.Matrix` type under the `spark.ml` + package. + + :param dataset: + input dataset + :param cols: + a list of matrix columns to be converted. + Old matrix columns will be ignored. If unspecified, all new + matrix columns will be converted except nested ones. + :return: + the input dataset with new matrix columns converted to the + old matrix type + + >>> import pyspark + >>> from pyspark.ml.linalg import Matrices + >>> from pyspark.mllib.util import MLUtils + >>> df = spark.createDataFrame( + ... [(0, Matrices.sparse(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4]), + ... 
Matrices.dense(2, 2, range(4)))], ["id", "x", "y"]) + >>> r1 = MLUtils.convertMatrixColumnsFromML(df).first() + >>> isinstance(r1.x, pyspark.mllib.linalg.SparseMatrix) + True + >>> isinstance(r1.y, pyspark.mllib.linalg.DenseMatrix) + True + >>> r2 = MLUtils.convertMatrixColumnsFromML(df, "x").first() + >>> isinstance(r2.x, pyspark.mllib.linalg.SparseMatrix) + True + >>> isinstance(r2.y, pyspark.ml.linalg.DenseMatrix) + True + """ + if not isinstance(dataset, DataFrame): + raise TypeError("Input dataset must be a DataFrame but got {}.".format(type(dataset))) + return callMLlibFunc("convertMatrixColumnsFromML", dataset, list(cols)) + class Saveable(object): """ @@ -347,13 +508,18 @@ def generateLinearRDD(sc, nexamples, nfeatures, eps, def _test(): import doctest - from pyspark.context import SparkContext + from pyspark.sql import SparkSession globs = globals().copy() # The small batch size here ensures that we see multiple batches, # even in these small test examples: - globs['sc'] = SparkContext('local[2]', 'PythonTest', batchSize=2) + spark = SparkSession.builder\ + .master("local[2]")\ + .appName("mllib.util tests")\ + .getOrCreate() + globs['spark'] = spark + globs['sc'] = spark.sparkContext (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) - globs['sc'].stop() + spark.stop() if failure_count: exit(-1) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 8978f028c5928..5fb10f86f4692 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -754,8 +754,8 @@ def foreachPartition(self, f): Applies a function to each partition of this RDD. >>> def f(iterator): - ... for x in iterator: - ... print(x) + ... for x in iterator: + ... print(x) >>> sc.parallelize([1, 2, 3, 4, 5]).foreachPartition(f) """ def func(it): @@ -1027,20 +1027,20 @@ def histogram(self, buckets): If your histogram is evenly spaced (e.g. [0, 10, 20, 30]), this can be switched from an O(log n) inseration to O(1) per - element(where n = # buckets). + element (where n is the number of buckets). - Buckets must be sorted and not contain any duplicates, must be + Buckets must be sorted, not contain any duplicates, and have at least two elements. - If `buckets` is a number, it will generates buckets which are + If `buckets` is a number, it will generate buckets which are evenly spaced between the minimum and maximum of the RDD. For - example, if the min value is 0 and the max is 100, given buckets - as 2, the resulting buckets will be [0,50) [50,100]. buckets must - be at least 1 If the RDD contains infinity, NaN throws an exception - If the elements in RDD do not vary (max == min) always returns - a single bucket. + example, if the min value is 0 and the max is 100, given `buckets` + as 2, the resulting buckets will be [0,50) [50,100]. `buckets` must + be at least 1. An exception is raised if the RDD contains infinity. + If the elements in the RDD do not vary (max == min), a single bucket + will be used. - It will return an tuple of buckets and histogram. + The return value is a tuple of buckets and histogram. >>> rdd = sc.parallelize(range(51)) >>> rdd.histogram(2) @@ -2211,7 +2211,7 @@ def lookup(self, key): return values.collect() def _to_java_object_rdd(self): - """ Return an JavaRDD of Object by unpickling + """ Return a JavaRDD of Object by unpickling It will convert each Python object into Java object by Pyrolite, whenever the RDD is serialized in batch or not. 
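The new `MLUtils` converters added above work column-wise on DataFrames. A small round-trip sketch, closely following the doctests in the hunks and assuming an active SparkSession `spark`:

```python
from pyspark.mllib.linalg import Vectors
from pyspark.mllib.util import MLUtils

df = spark.createDataFrame(
    [(0, Vectors.sparse(2, [1], [1.0]), Vectors.dense(2.0, 3.0))],
    ["id", "x", "y"])

# Convert every old mllib vector column to the new ml vector type ...
ml_df = MLUtils.convertVectorColumnsToML(df)
# ... or only the named columns, leaving "y" as an mllib vector.
ml_df_x = MLUtils.convertVectorColumnsToML(df, "x")

# convertVectorColumnsFromML goes back to the old type; the matrix
# variants work the same way on matrix columns.
roundtrip = MLUtils.convertVectorColumnsFromML(ml_df)
```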
@@ -2274,9 +2274,9 @@ def countApproxDistinct(self, relativeSD=0.05): Return approximate number of distinct elements in the RDD. The algorithm used is based on streamlib's implementation of - "HyperLogLog in Practice: Algorithmic Engineering of a State - of The Art Cardinality Estimation Algorithm", available - here. + `"HyperLogLog in Practice: Algorithmic Engineering of a State + of The Art Cardinality Estimation Algorithm", available here + `_. :param relativeSD: Relative accuracy. Smaller values create counters that require more space. diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py index c6b0eda996168..ac5ce87a3f0fd 100644 --- a/python/pyspark/shell.py +++ b/python/pyspark/shell.py @@ -35,17 +35,21 @@ if os.environ.get("SPARK_EXECUTOR_URI"): SparkContext.setSystemProperty("spark.executor.uri", os.environ["SPARK_EXECUTOR_URI"]) -sc = SparkContext() -atexit.register(lambda: sc.stop()) +SparkContext._ensure_initialized() try: # Try to access HiveConf, it will raise exception if Hive is not added - sc._jvm.org.apache.hadoop.hive.conf.HiveConf() - spark = SparkSession.withHiveSupport(sc) + SparkContext._jvm.org.apache.hadoop.hive.conf.HiveConf() + spark = SparkSession.builder\ + .enableHiveSupport()\ + .getOrCreate() except py4j.protocol.Py4JError: - spark = SparkSession(sc) + spark = SparkSession.builder.getOrCreate() except TypeError: - spark = SparkSession(sc) + spark = SparkSession.builder.getOrCreate() + +sc = spark.sparkContext +atexit.register(lambda: sc.stop()) # for compatibility sqlContext = spark._wrapped diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py index cff73ff192e51..22ec416f6c584 100644 --- a/python/pyspark/sql/__init__.py +++ b/python/pyspark/sql/__init__.py @@ -18,7 +18,7 @@ """ Important classes of Spark SQL and DataFrames: - - :class:`pyspark.sql.SQLContext` + - :class:`pyspark.sql.SparkSession` Main entry point for :class:`DataFrame` and SQL functionality. - :class:`pyspark.sql.DataFrame` A distributed collection of data grouped into named columns. @@ -26,8 +26,6 @@ A column expression in a :class:`DataFrame`. - :class:`pyspark.sql.Row` A row of data in a :class:`DataFrame`. - - :class:`pyspark.sql.HiveContext` - Main entry point for accessing data stored in Apache Hive. - :class:`pyspark.sql.GroupedData` Aggregation methods, returned by :func:`DataFrame.groupBy`. 
- :class:`pyspark.sql.DataFrameNaFunctions` @@ -45,7 +43,7 @@ from pyspark.sql.types import Row -from pyspark.sql.context import SQLContext, HiveContext +from pyspark.sql.context import SQLContext, HiveContext, UDFRegistration from pyspark.sql.session import SparkSession from pyspark.sql.column import Column from pyspark.sql.dataframe import DataFrame, DataFrameNaFunctions, DataFrameStatFunctions @@ -55,7 +53,8 @@ __all__ = [ - 'SparkSession', 'SQLContext', 'HiveContext', 'DataFrame', 'GroupedData', 'Column', - 'Row', 'DataFrameNaFunctions', 'DataFrameStatFunctions', 'Window', 'WindowSpec', + 'SparkSession', 'SQLContext', 'HiveContext', 'UDFRegistration', + 'DataFrame', 'GroupedData', 'Column', 'Row', + 'DataFrameNaFunctions', 'DataFrameStatFunctions', 'Window', 'WindowSpec', 'DataFrameReader', 'DataFrameWriter' ] diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py index 9cfdd0a99f241..3c5030722f307 100644 --- a/python/pyspark/sql/catalog.py +++ b/python/pyspark/sql/catalog.py @@ -166,34 +166,20 @@ def createExternalTable(self, tableName, path=None, source=None, schema=None, ** return DataFrame(df, self._sparkSession._wrapped) @since(2.0) - def dropTempTable(self, tableName): - """Drops the temporary table with the given table name in the catalog. - If the table has been cached before, then it will also be uncached. + def dropTempView(self, viewName): + """Drops the temporary view with the given view name in the catalog. + If the view has been cached before, then it will also be uncached. - >>> spark.createDataFrame([(1, 1)]).registerTempTable("my_table") + >>> spark.createDataFrame([(1, 1)]).createTempView("my_table") >>> spark.table("my_table").collect() [Row(_1=1, _2=1)] - >>> spark.catalog.dropTempTable("my_table") + >>> spark.catalog.dropTempView("my_table") >>> spark.table("my_table") # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... AnalysisException: ... """ - self._jcatalog.dropTempTable(tableName) - - @since(2.0) - def registerTable(self, df, tableName): - """Registers the given :class:`DataFrame` as a temporary table in the catalog. - - >>> df = spark.createDataFrame([(2, 1), (3, 1)]) - >>> spark.catalog.registerTable(df, "my_cool_table") - >>> spark.table("my_cool_table").collect() - [Row(_1=2, _2=1), Row(_1=3, _2=1)] - """ - if isinstance(df, DataFrame): - self._jsparkSession.registerTable(df._jdf, tableName) - else: - raise ValueError("Can only register DataFrame as table") + self._jcatalog.dropTempView(viewName) @ignore_unicode_prefix @since(2.0) @@ -207,7 +193,7 @@ def registerFunction(self, name, f, returnType=StringType()): :param name: name of the UDF :param f: python function - :param returnType: a :class:`DataType` object + :param returnType: a :class:`pyspark.sql.types.DataType` object >>> spark.catalog.registerFunction("stringLengthString", lambda x: len(x)) >>> spark.sql("SELECT stringLengthString('test')").collect() @@ -246,6 +232,11 @@ def clearCache(self): """Removes all cached tables from the in-memory cache.""" self._jcatalog.clearCache() + @since(2.0) + def refreshTable(self, tableName): + """Invalidate and refresh all the cached metadata of the given table.""" + self._jcatalog.refreshTable(tableName) + def _reset(self): """(Internal use only) Drop all existing databases (except "default"), tables, partitions and functions, and set the current database to "default". 
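With `dropTempTable`/`registerTable` removed from the catalog above, the view lifecycle runs through `createOrReplaceTempView` and `dropTempView`. A minimal sketch, assuming an active SparkSession `spark`; the table name and data are illustrative:

```python
df = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "value"])

# createOrReplaceTempView registers the DataFrame under a name ...
df.createOrReplaceTempView("my_table")
spark.sql("SELECT id FROM my_table WHERE value = 'b'").show()

# ... and the catalog drops the temporary view again with dropTempView.
spark.catalog.dropTempView("my_table")
```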
@@ -258,21 +249,23 @@ def _reset(self): def _test(): import os import doctest - from pyspark.context import SparkContext - from pyspark.sql.session import SparkSession + from pyspark.sql import SparkSession import pyspark.sql.catalog os.chdir(os.environ["SPARK_HOME"]) globs = pyspark.sql.catalog.__dict__.copy() - sc = SparkContext('local[4]', 'PythonTest') - globs['sc'] = sc - globs['spark'] = SparkSession(sc) + spark = SparkSession.builder\ + .master("local[4]")\ + .appName("sql.catalog tests")\ + .getOrCreate() + globs['sc'] = spark.sparkContext + globs['spark'] = spark (failure_count, test_count) = doctest.testmod( pyspark.sql.catalog, globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) - globs['sc'].stop() + spark.stop() if failure_count: exit(-1) diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 43e9baece2de9..4b99f3058b75c 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -139,8 +139,6 @@ class Column(object): df.colName + 1 1 / df.colName - .. note:: Experimental - .. versionadded:: 1.3 """ @@ -418,8 +416,6 @@ def over(self, window): >>> window = Window.partitionBy("name").orderBy("age").rowsBetween(-1, 1) >>> from pyspark.sql.functions import rank, min >>> # df.select(rank().over(window), min('age').over(window)) - - .. note:: Window functions is only supported with HiveContext in 1.4 """ from pyspark.sql.window import WindowSpec if not isinstance(window, WindowSpec): @@ -438,13 +434,15 @@ def __repr__(self): def _test(): import doctest - from pyspark.context import SparkContext - from pyspark.sql import SQLContext + from pyspark.sql import SparkSession import pyspark.sql.column globs = pyspark.sql.column.__dict__.copy() - sc = SparkContext('local[4]', 'PythonTest') + spark = SparkSession.builder\ + .master("local[4]")\ + .appName("sql.column tests")\ + .getOrCreate() + sc = spark.sparkContext globs['sc'] = sc - globs['sqlContext'] = SQLContext(sc) globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')]) \ .toDF(StructType([StructField('age', IntegerType()), StructField('name', StringType())])) @@ -452,7 +450,7 @@ def _test(): (failure_count, test_count) = doctest.testmod( pyspark.sql.column, globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) - globs['sc'].stop() + spark.stop() if failure_count: exit(-1) diff --git a/python/pyspark/sql/conf.py b/python/pyspark/sql/conf.py index 7428c919915f3..792c420ca6386 100644 --- a/python/pyspark/sql/conf.py +++ b/python/pyspark/sql/conf.py @@ -23,7 +23,6 @@ class RuntimeConfig(object): """User-facing configuration API, accessible through `SparkSession.conf`. Options set here are automatically propagated to the Hadoop configuration during I/O. - This a thin wrapper around its Scala implementation org.apache.spark.sql.RuntimeConfig. 
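The doctest harnesses throughout these hunks switch from constructing a bare SparkContext to the SparkSession builder, and runtime SQL configuration is reached through `SparkSession.conf` (the `RuntimeConfig` wrapper above). A sketch of that pattern; the app name and config value are illustrative:

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder \
    .master("local[2]") \
    .appName("runtime-conf sketch") \
    .getOrCreate()

# Runtime SQL configuration is set and read through SparkSession.conf.
spark.conf.set("spark.sql.shuffle.partitions", "4")
print(spark.conf.get("spark.sql.shuffle.partitions"))

spark.stop()
```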
""" def __init__(self, jconf): @@ -72,11 +71,14 @@ def _test(): os.chdir(os.environ["SPARK_HOME"]) globs = pyspark.sql.conf.__dict__.copy() - sc = SparkContext('local[4]', 'PythonTest') - globs['sc'] = sc - globs['spark'] = SparkSession(sc) + spark = SparkSession.builder\ + .master("local[4]")\ + .appName("sql.conf tests")\ + .getOrCreate() + globs['sc'] = spark.sparkContext + globs['spark'] = spark (failure_count, test_count) = doctest.testmod(pyspark.sql.conf, globs=globs) - globs['sc'].stop() + spark.stop() if failure_count: exit(-1) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 2096236d7f36f..8cdf37188e665 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -17,6 +17,7 @@ from __future__ import print_function import sys +import warnings if sys.version >= '3': basestring = unicode = str @@ -26,6 +27,7 @@ from pyspark.sql.session import _monkey_patch_RDD, SparkSession from pyspark.sql.dataframe import DataFrame from pyspark.sql.readwriter import DataFrameReader +from pyspark.sql.streaming import DataStreamReader from pyspark.sql.types import Row, StringType from pyspark.sql.utils import install_exception_handler @@ -33,7 +35,10 @@ class SQLContext(object): - """Wrapper around :class:`SparkSession`, the main entry point to Spark SQL functionality. + """The entry point for working with structured data (rows and columns) in Spark, in Spark 1.x. + + As of Spark 2.0, this is replaced by :class:`SparkSession`. However, we are keeping the class + here for backward compatibility. A SQLContext can be used create :class:`DataFrame`, register :class:`DataFrame` as tables, execute SQL over tables, cache tables, and read parquet files. @@ -56,7 +61,7 @@ def __init__(self, sparkContext, sparkSession=None, jsqlContext=None): ... b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1), ... time=datetime(2014, 8, 1, 14, 1, 5))]) >>> df = allTypes.toDF() - >>> df.registerTempTable("allTypes") + >>> df.createOrReplaceTempView("allTypes") >>> sqlContext.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a ' ... 'from allTypes where b and i > 0').collect() [Row((i + CAST(1 AS BIGINT))=2, (d + CAST(1 AS DOUBLE))=2.0, (NOT b)=False, list[1]=2, \ @@ -105,7 +110,7 @@ def getOrCreate(cls, sc): def newSession(self): """ Returns a new SQLContext as new session, that has separate SQLConf, - registered temporary tables and UDFs, but shared SparkContext and + registered temporary views and UDFs, but shared SparkContext and table cache. """ return self.__class__(self._sc, self.sparkSession.newSession()) @@ -147,9 +152,9 @@ def udf(self): @since(1.4) def range(self, start, end=None, step=1, numPartitions=None): """ - Create a :class:`DataFrame` with single LongType column named `id`, - containing elements in a range from `start` to `end` (exclusive) with - step value `step`. + Create a :class:`DataFrame` with single :class:`pyspark.sql.types.LongType` column named + ``id``, containing elements in a range from ``start`` to ``end`` (exclusive) with + step value ``step``. 
:param start: the start value :param end: the end value (exclusive) @@ -179,7 +184,7 @@ def registerFunction(self, name, f, returnType=StringType()): :param name: name of the UDF :param f: python function - :param returnType: a :class:`DataType` object + :param returnType: a :class:`pyspark.sql.types.DataType` object >>> sqlContext.registerFunction("stringLengthString", lambda x: len(x)) >>> sqlContext.sql("SELECT stringLengthString('test')").collect() @@ -204,13 +209,13 @@ def _inferSchema(self, rdd, samplingRatio=None): :param rdd: an RDD of Row or tuple :param samplingRatio: sampling ratio, or no sampling (default) - :return: StructType + :return: :class:`pyspark.sql.types.StructType` """ return self.sparkSession._inferSchema(rdd, samplingRatio) @since(1.3) @ignore_unicode_prefix - def createDataFrame(self, data, schema=None, samplingRatio=None): + def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True): """ Creates a :class:`DataFrame` from an :class:`RDD`, a list or a :class:`pandas.DataFrame`. @@ -221,28 +226,38 @@ def createDataFrame(self, data, schema=None, samplingRatio=None): from ``data``, which should be an RDD of :class:`Row`, or :class:`namedtuple`, or :class:`dict`. - When ``schema`` is :class:`DataType` or datatype string, it must match the real data, or - exception will be thrown at runtime. If the given schema is not StructType, it will be - wrapped into a StructType as its only field, and the field name will be "value", each record - will also be wrapped into a tuple, which can be converted to row later. + When ``schema`` is :class:`pyspark.sql.types.DataType` or + :class:`pyspark.sql.types.StringType`, it must match the + real data, or an exception will be thrown at runtime. If the given schema is not + :class:`pyspark.sql.types.StructType`, it will be wrapped into a + :class:`pyspark.sql.types.StructType` as its only field, and the field name will be "value", + each record will also be wrapped into a tuple, which can be converted to row later. If schema inference is needed, ``samplingRatio`` is used to determined the ratio of rows used for schema inference. The first row will be used if ``samplingRatio`` is ``None``. - :param data: an RDD of any kind of SQL data representation(e.g. row, tuple, int, boolean, - etc.), or :class:`list`, or :class:`pandas.DataFrame`. - :param schema: a :class:`DataType` or a datatype string or a list of column names, default - is None. The data type string format equals to `DataType.simpleString`, except that - top level struct type can omit the `struct<>` and atomic types use `typeName()` as - their format, e.g. use `byte` instead of `tinyint` for ByteType. We can also use `int` - as a short name for IntegerType. + :param data: an RDD of any kind of SQL data representation(e.g. :class:`Row`, + :class:`tuple`, ``int``, ``boolean``, etc.), or :class:`list`, or + :class:`pandas.DataFrame`. + :param schema: a :class:`pyspark.sql.types.DataType` or a + :class:`pyspark.sql.types.StringType` or a list of + column names, default is None. The data type string format equals to + :class:`pyspark.sql.types.DataType.simpleString`, except that top level struct type can + omit the ``struct<>`` and atomic types use ``typeName()`` as their format, e.g. use + ``byte`` instead of ``tinyint`` for :class:`pyspark.sql.types.ByteType`. + We can also use ``int`` as a short name for :class:`pyspark.sql.types.IntegerType`. 
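The schema-handling rules spelled out above (datatype strings, wrapping non-struct types into a "value" field, and the new `verifySchema` flag) can be seen in a short sketch, assuming an active SparkSession `spark`:

```python
# Atomic datatype string: each record is wrapped into a single field named "value".
spark.createDataFrame([1, 2, 3], "int").collect()
# [Row(value=1), Row(value=2), Row(value=3)]

# Struct-style datatype string; the top-level struct<> wrapper may be omitted.
spark.createDataFrame([("Alice", 1)], "name: string, age: int").collect()
# [Row(name=u'Alice', age=1)]

# verifySchema (added in 2.0.1 per the hunk above) toggles per-row type checking.
spark.createDataFrame([("Alice", 1)], "name: string, age: int", verifySchema=False)
```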
:param samplingRatio: the sample ratio of rows used for inferring + :param verifySchema: verify data types of every row against schema. :return: :class:`DataFrame` .. versionchanged:: 2.0 - The schema parameter can be a DataType or a datatype string after 2.0. If it's not a - StructType, it will be wrapped into a StructType and each record will also be wrapped - into a tuple. + The ``schema`` parameter can be a :class:`pyspark.sql.types.DataType` or a + :class:`pyspark.sql.types.StringType` after 2.0. + If it's not a :class:`pyspark.sql.types.StructType`, it will be wrapped into a + :class:`pyspark.sql.types.StructType` and each record will also be wrapped into a tuple. + + .. versionchanged:: 2.0.1 + Added verifySchema. >>> l = [('Alice', 1)] >>> sqlContext.createDataFrame(l).collect() @@ -291,7 +306,7 @@ def createDataFrame(self, data, schema=None, samplingRatio=None): ... Py4JJavaError: ... """ - return self.sparkSession.createDataFrame(data, schema, samplingRatio) + return self.sparkSession.createDataFrame(data, schema, samplingRatio, verifySchema) @since(1.3) def registerDataFrameAsTable(self, df, tableName): @@ -301,7 +316,7 @@ def registerDataFrameAsTable(self, df, tableName): >>> sqlContext.registerDataFrameAsTable(df, "table1") """ - self.sparkSession.catalog.registerTable(df, tableName) + df.createOrReplaceTempView(tableName) @since(1.6) def dropTempTable(self, tableName): @@ -310,7 +325,7 @@ def dropTempTable(self, tableName): >>> sqlContext.registerDataFrameAsTable(df, "table1") >>> sqlContext.dropTempTable("table1") """ - self.sparkSession.catalog.dropTempTable(tableName) + self.sparkSession.catalog.dropTempView(tableName) @since(1.3) def createExternalTable(self, tableName, path=None, source=None, schema=None, **options): @@ -424,17 +439,35 @@ def read(self): """ return DataFrameReader(self) + @property + @since(2.0) + def readStream(self): + """ + Returns a :class:`DataStreamReader` that can be used to read data streams + as a streaming :class:`DataFrame`. + + .. note:: Experimental. + + :return: :class:`DataStreamReader` + + >>> text_sdf = sqlContext.readStream.text(tempfile.mkdtemp()) + >>> text_sdf.isStreaming + True + """ + return DataStreamReader(self) + @property @since(2.0) def streams(self): - """Returns a :class:`ContinuousQueryManager` that allows managing all the - :class:`ContinuousQuery` ContinuousQueries active on `this` context. + """Returns a :class:`StreamingQueryManager` that allows managing all the + :class:`StreamingQuery` StreamingQueries active on `this` context. + + .. note:: Experimental. """ - from pyspark.sql.streaming import ContinuousQueryManager - return ContinuousQueryManager(self._ssql_ctx.streams()) + from pyspark.sql.streaming import StreamingQueryManager + return StreamingQueryManager(self._ssql_ctx.streams()) -# TODO(andrew): deprecate this class HiveContext(SQLContext): """A variant of Spark SQL that integrates with data stored in Hive. @@ -444,11 +477,18 @@ class HiveContext(SQLContext): :param sparkContext: The SparkContext to wrap. :param jhiveContext: An optional JVM Scala HiveContext. If set, we do not instantiate a new :class:`HiveContext` in the JVM, instead we make all calls to this object. + + .. note:: Deprecated in 2.0.0. Use SparkSession.builder.enableHiveSupport().getOrCreate(). """ + warnings.warn( + "HiveContext is deprecated in Spark 2.0.0. 
Please use " + + "SparkSession.builder.enableHiveSupport().getOrCreate() instead.", + DeprecationWarning) + def __init__(self, sparkContext, jhiveContext=None): if jhiveContext is None: - sparkSession = SparkSession.withHiveSupport(sparkContext) + sparkSession = SparkSession.builder.enableHiveSupport().getOrCreate() else: sparkSession = SparkSession(sparkContext, jhiveContext.sparkSession()) SQLContext.__init__(self, sparkContext, sparkSession, jhiveContext) @@ -462,7 +502,7 @@ def _createForTesting(cls, sparkContext): confusing error messages. """ jsc = sparkContext._jsc.sc() - jtestHive = sparkContext._jvm.org.apache.spark.sql.hive.test.TestHiveContext(jsc) + jtestHive = sparkContext._jvm.org.apache.spark.sql.hive.test.TestHiveContext(jsc, False) return cls(sparkContext, jtestHive) def refreshTable(self, tableName): @@ -490,6 +530,7 @@ def register(self, name, f, returnType=StringType()): def _test(): import os import doctest + import tempfile from pyspark.context import SparkContext from pyspark.sql import Row, SQLContext import pyspark.sql.context @@ -498,6 +539,8 @@ def _test(): globs = pyspark.sql.context.__dict__.copy() sc = SparkContext('local[4]', 'PythonTest') + globs['tempfile'] = tempfile + globs['os'] = os globs['sc'] = sc globs['sqlContext'] = SQLContext(sc) globs['rdd'] = rdd = sc.parallelize( diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index bbe15f5f900da..64d7a20757432 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -34,6 +34,7 @@ from pyspark.sql.types import _parse_datatype_json_string from pyspark.sql.column import Column, _to_seq, _to_list, _to_java_column from pyspark.sql.readwriter import DataFrameWriter +from pyspark.sql.streaming import DataStreamWriter from pyspark.sql.types import * __all__ = ["DataFrame", "DataFrameNaFunctions", "DataFrameStatFunctions"] @@ -60,11 +61,9 @@ class DataFrame(object): people = sqlContext.read.parquet("...") department = sqlContext.read.parquet("...") - people.filter(people.age > 30).join(department, people.deptId == department.id)\ + people.filter(people.age > 30).join(department, people.deptId == department.id) \\ .groupBy(department.name, "gender").agg({"salary": "avg", "age": "max"}) - .. note:: Experimental - .. versionadded:: 1.3 """ @@ -121,26 +120,83 @@ def registerTempTable(self, name): that was used to create this :class:`DataFrame`. >>> df.registerTempTable("people") - >>> df2 = sqlContext.sql("select * from people") + >>> df2 = spark.sql("select * from people") + >>> sorted(df.collect()) == sorted(df2.collect()) + True + >>> spark.catalog.dropTempView("people") + + .. note:: Deprecated in 2.0, use createOrReplaceTempView instead. + """ + self._jdf.createOrReplaceTempView(name) + + @since(2.0) + def createTempView(self, name): + """Creates a temporary view with this DataFrame. + + The lifetime of this temporary table is tied to the :class:`SparkSession` + that was used to create this :class:`DataFrame`. + throws :class:`TempTableAlreadyExistsException`, if the view name already exists in the + catalog. + + >>> df.createTempView("people") + >>> df2 = spark.sql("select * from people") >>> sorted(df.collect()) == sorted(df2.collect()) True + >>> df.createTempView("people") # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... 
+ AnalysisException: u"Temporary table 'people' already exists;" + >>> spark.catalog.dropTempView("people") + + """ + self._jdf.createTempView(name) + + @since(2.0) + def createOrReplaceTempView(self, name): + """Creates or replaces a temporary view with this DataFrame. + + The lifetime of this temporary table is tied to the :class:`SparkSession` + that was used to create this :class:`DataFrame`. + + >>> df.createOrReplaceTempView("people") + >>> df2 = df.filter(df.age > 3) + >>> df2.createOrReplaceTempView("people") + >>> df3 = spark.sql("select * from people") + >>> sorted(df3.collect()) == sorted(df2.collect()) + True + >>> spark.catalog.dropTempView("people") + """ - self._jdf.registerTempTable(name) + self._jdf.createOrReplaceTempView(name) @property @since(1.4) def write(self): """ - Interface for saving the content of the :class:`DataFrame` out into external storage. + Interface for saving the content of the non-streaming :class:`DataFrame` out into external + storage. :return: :class:`DataFrameWriter` """ return DataFrameWriter(self) + @property + @since(2.0) + def writeStream(self): + """ + Interface for saving the content of the streaming :class:`DataFrame` out into external + storage. + + .. note:: Experimental. + + :return: :class:`DataStreamWriter` + """ + return DataStreamWriter(self) + @property @since(1.3) def schema(self): - """Returns the schema of this :class:`DataFrame` as a :class:`types.StructType`. + """Returns the schema of this :class:`DataFrame` as a :class:`pyspark.sql.types.StructType`. >>> df.schema StructType(List(StructField(age,IntegerType,true),StructField(name,StringType,true))) @@ -202,10 +258,12 @@ def isLocal(self): def isStreaming(self): """Returns true if this :class:`Dataset` contains one or more sources that continuously return data as it arrives. A :class:`Dataset` that reads data from a streaming source - must be executed as a :class:`ContinuousQuery` using the :func:`startStream` method in - :class:`DataFrameWriter`. Methods that return a single answer, (e.g., :func:`count` or + must be executed as a :class:`StreamingQuery` using the :func:`start` method in + :class:`DataStreamWriter`. Methods that return a single answer, (e.g., :func:`count` or :func:`collect`) will throw an :class:`AnalysisException` when there is a streaming source present. + + .. note:: Experimental """ return self._jdf.isStreaming() @@ -287,10 +345,7 @@ def take(self, num): >>> df.take(2) [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] """ - with SCCallSiteSync(self._sc) as css: - port = self._sc._jvm.org.apache.spark.sql.execution.python.EvaluatePython.takeAndServe( - self._jdf, num) - return list(_load_from_socket(port, BatchedSerializer(PickleSerializer()))) + return self.limit(num).collect() @since(1.3) def foreach(self, f): @@ -363,15 +418,6 @@ def coalesce(self, numPartitions): """ return DataFrame(self._jdf.coalesce(numPartitions), self.sql_ctx) - @since(1.3) - def repartition(self, numPartitions): - """Returns a new :class:`DataFrame` that has exactly ``numPartitions`` partitions. 
- - >>> df.repartition(10).rdd.getNumPartitions() - 10 - """ - return DataFrame(self._jdf.repartition(numPartitions), self.sql_ctx) - @since(1.3) def repartition(self, numPartitions, *cols): """ @@ -542,7 +588,7 @@ def alias(self, alias): >>> df_as2 = df.alias("df_as2") >>> joined_df = df_as1.join(df_as2, col("df_as1.name") == col("df_as2.name"), 'inner') >>> joined_df.select("df_as1.name", "df_as2.name", "df_as2.age").collect() - [Row(name=u'Alice', name=u'Alice', age=2), Row(name=u'Bob', name=u'Bob', age=5)] + [Row(name=u'Bob', name=u'Bob', age=5), Row(name=u'Alice', name=u'Alice', age=2)] """ assert isinstance(alias, basestring), "alias should be a string" return DataFrame(getattr(self._jdf, "as")(alias), self.sql_ctx) @@ -552,16 +598,16 @@ def alias(self, alias): def join(self, other, on=None, how=None): """Joins with another :class:`DataFrame`, using the given join expression. - The following performs a full outer join between ``df1`` and ``df2``. - :param other: Right side of the join - :param on: a string for join column name, a list of column names, - , a join expression (Column) or a list of Columns. - If `on` is a string or a list of string indicating the name of the join column(s), + :param on: a string for the join column name, a list of column names, + a join expression (Column), or a list of Columns. + If `on` is a string or a list of strings indicating the name of the join column(s), the column(s) must exist on both sides, and this performs an equi-join. :param how: str, default 'inner'. One of `inner`, `outer`, `left_outer`, `right_outer`, `leftsemi`. + The following performs a full outer join between ``df1`` and ``df2``. + >>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect() [Row(name=None, height=80), Row(name=u'Bob', height=85), Row(name=u'Alice', height=None)] @@ -984,10 +1030,10 @@ def dropDuplicates(self, subset=None): :func:`drop_duplicates` is an alias for :func:`dropDuplicates`. >>> from pyspark.sql import Row - >>> df = sc.parallelize([ \ - Row(name='Alice', age=5, height=80), \ - Row(name='Alice', age=5, height=80), \ - Row(name='Alice', age=10, height=80)]).toDF() + >>> df = sc.parallelize([ \\ + ... Row(name='Alice', age=5, height=80), \\ + ... Row(name='Alice', age=5, height=80), \\ + ... Row(name='Alice', age=10, height=80)]).toDF() >>> df.dropDuplicates().show() +---+------+-----+ |age|height| name| @@ -1327,6 +1373,7 @@ def withColumn(self, colName, col): @since(1.3) def withColumnRenamed(self, existing, new): """Returns a new :class:`DataFrame` by renaming an existing column. + This is a no-op if schema doesn't contain the given column name. :param existing: string, name of the existing column to rename. :param col: string, new name of the column. @@ -1340,6 +1387,7 @@ def withColumnRenamed(self, existing, new): @ignore_unicode_prefix def drop(self, col): """Returns a new :class:`DataFrame` that drops the specified column. + This is a no-op if schema doesn't contain the given column name(s). :param col: a string name of the column to drop, or a :class:`Column` to drop. 
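The reworked `join()` docstring above distinguishes the accepted forms of the `on` argument. A small sketch of each form, assuming two toy DataFrames built in a local session (names and data are illustrative):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("join-forms").getOrCreate()
left = spark.createDataFrame([(1, "Alice"), (2, "Bob")], ["id", "name"])
right = spark.createDataFrame([(1, 80), (3, 85)], ["id", "height"])

# 1. A column name (or list of names): an equi-join on columns that exist on
#    both sides, so the key column appears only once in the result.
left.join(right, "id").show()
left.join(right, ["id"], "left_outer").show()

# 2. A join expression (Column), useful when the keys are named differently or
#    the condition is not a plain equality.
left.join(right, left.id == right.id, "outer").show()

# 3. how defaults to 'inner'; 'outer', 'left_outer', 'right_outer' and
#    'leftsemi' are the other values listed in the docstring.
left.join(right, left.id == right.id, "leftsemi").select("name").show()
```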
@@ -1488,12 +1536,13 @@ def sampleBy(self, col, fractions, seed=None): def _test(): import doctest from pyspark.context import SparkContext - from pyspark.sql import Row, SQLContext + from pyspark.sql import Row, SQLContext, SparkSession import pyspark.sql.dataframe globs = pyspark.sql.dataframe.__dict__.copy() sc = SparkContext('local[4]', 'PythonTest') globs['sc'] = sc globs['sqlContext'] = SQLContext(sc) + globs['spark'] = SparkSession(sc) globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')])\ .toDF(StructType([StructField('age', IntegerType()), StructField('name', StringType())])) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index dac842c0ce8c0..4ea83e24bbc9a 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -142,7 +142,7 @@ def _(): _binary_mathfunctions = { 'atan2': 'Returns the angle theta from the conversion of rectangular coordinates (x, y) to' + 'polar coordinates (r, theta).', - 'hypot': 'Computes `sqrt(a^2 + b^2)` without intermediate overflow or underflow.', + 'hypot': 'Computes ``sqrt(a^2 + b^2)`` without intermediate overflow or underflow.', 'pow': 'Returns the value of the first argument raised to the power of the second argument.', } @@ -212,7 +212,7 @@ def broadcast(df): def coalesce(*cols): """Returns the first column that is not null. - >>> cDf = sqlContext.createDataFrame([(None, None), (1, None), (None, 2)], ("a", "b")) + >>> cDf = spark.createDataFrame([(None, None), (1, None), (None, 2)], ("a", "b")) >>> cDf.show() +----+----+ | a| b| @@ -252,7 +252,7 @@ def corr(col1, col2): >>> a = range(20) >>> b = [2 * x for x in range(20)] - >>> df = sqlContext.createDataFrame(zip(a, b), ["a", "b"]) + >>> df = spark.createDataFrame(zip(a, b), ["a", "b"]) >>> df.agg(corr("a", "b").alias('c')).collect() [Row(c=1.0)] """ @@ -267,7 +267,7 @@ def covar_pop(col1, col2): >>> a = [1] * 10 >>> b = [1] * 10 - >>> df = sqlContext.createDataFrame(zip(a, b), ["a", "b"]) + >>> df = spark.createDataFrame(zip(a, b), ["a", "b"]) >>> df.agg(covar_pop("a", "b").alias('c')).collect() [Row(c=0.0)] """ @@ -282,7 +282,7 @@ def covar_samp(col1, col2): >>> a = [1] * 10 >>> b = [1] * 10 - >>> df = sqlContext.createDataFrame(zip(a, b), ["a", "b"]) + >>> df = spark.createDataFrame(zip(a, b), ["a", "b"]) >>> df.agg(covar_samp("a", "b").alias('c')).collect() [Row(c=0.0)] """ @@ -373,7 +373,7 @@ def input_file_name(): def isnan(col): """An expression that returns true iff the column is NaN. - >>> df = sqlContext.createDataFrame([(1.0, float('nan')), (float('nan'), 2.0)], ("a", "b")) + >>> df = spark.createDataFrame([(1.0, float('nan')), (float('nan'), 2.0)], ("a", "b")) >>> df.select(isnan("a").alias("r1"), isnan(df.a).alias("r2")).collect() [Row(r1=False, r2=False), Row(r1=True, r2=True)] """ @@ -385,7 +385,7 @@ def isnan(col): def isnull(col): """An expression that returns true iff the column is null. - >>> df = sqlContext.createDataFrame([(1, None), (None, 2)], ("a", "b")) + >>> df = spark.createDataFrame([(1, None), (None, 2)], ("a", "b")) >>> df.select(isnull("a").alias("r1"), isnull(df.a).alias("r2")).collect() [Row(r1=False, r2=False), Row(r1=True, r2=True)] """ @@ -432,7 +432,7 @@ def nanvl(col1, col2): Both inputs should be floating point columns (DoubleType or FloatType). 
- >>> df = sqlContext.createDataFrame([(1.0, float('nan')), (float('nan'), 2.0)], ("a", "b")) + >>> df = spark.createDataFrame([(1.0, float('nan')), (float('nan'), 2.0)], ("a", "b")) >>> df.select(nanvl("a", "b").alias("r1"), nanvl(df.a, df.b).alias("r2")).collect() [Row(r1=1.0, r2=1.0), Row(r1=2.0, r2=2.0)] """ @@ -470,7 +470,7 @@ def round(col, scale=0): Round the given value to `scale` decimal places using HALF_UP rounding mode if `scale` >= 0 or at integral part when `scale` < 0. - >>> sqlContext.createDataFrame([(2.5,)], ['a']).select(round('a', 0).alias('r')).collect() + >>> spark.createDataFrame([(2.5,)], ['a']).select(round('a', 0).alias('r')).collect() [Row(r=3.0)] """ sc = SparkContext._active_spark_context @@ -483,7 +483,7 @@ def bround(col, scale=0): Round the given value to `scale` decimal places using HALF_EVEN rounding mode if `scale` >= 0 or at integral part when `scale` < 0. - >>> sqlContext.createDataFrame([(2.5,)], ['a']).select(bround('a', 0).alias('r')).collect() + >>> spark.createDataFrame([(2.5,)], ['a']).select(bround('a', 0).alias('r')).collect() [Row(r=2.0)] """ sc = SparkContext._active_spark_context @@ -494,7 +494,7 @@ def bround(col, scale=0): def shiftLeft(col, numBits): """Shift the given value numBits left. - >>> sqlContext.createDataFrame([(21,)], ['a']).select(shiftLeft('a', 1).alias('r')).collect() + >>> spark.createDataFrame([(21,)], ['a']).select(shiftLeft('a', 1).alias('r')).collect() [Row(r=42)] """ sc = SparkContext._active_spark_context @@ -505,7 +505,7 @@ def shiftLeft(col, numBits): def shiftRight(col, numBits): """Shift the given value numBits right. - >>> sqlContext.createDataFrame([(42,)], ['a']).select(shiftRight('a', 1).alias('r')).collect() + >>> spark.createDataFrame([(42,)], ['a']).select(shiftRight('a', 1).alias('r')).collect() [Row(r=21)] """ sc = SparkContext._active_spark_context @@ -517,7 +517,7 @@ def shiftRight(col, numBits): def shiftRightUnsigned(col, numBits): """Unsigned shift the given value numBits right. - >>> df = sqlContext.createDataFrame([(-42,)], ['a']) + >>> df = spark.createDataFrame([(-42,)], ['a']) >>> df.select(shiftRightUnsigned('a', 1).alias('r')).collect() [Row(r=9223372036854775787)] """ @@ -575,7 +575,7 @@ def greatest(*cols): Returns the greatest value of the list of column names, skipping null values. This function takes at least 2 parameters. It will return null iff all parameters are null. - >>> df = sqlContext.createDataFrame([(1, 4, 3)], ['a', 'b', 'c']) + >>> df = spark.createDataFrame([(1, 4, 3)], ['a', 'b', 'c']) >>> df.select(greatest(df.a, df.b, df.c).alias("greatest")).collect() [Row(greatest=4)] """ @@ -591,7 +591,7 @@ def least(*cols): Returns the least value of the list of column names, skipping null values. This function takes at least 2 parameters. It will return null iff all parameters are null. - >>> df = sqlContext.createDataFrame([(1, 4, 3)], ['a', 'b', 'c']) + >>> df = spark.createDataFrame([(1, 4, 3)], ['a', 'b', 'c']) >>> df.select(least(df.a, df.b, df.c).alias("least")).collect() [Row(least=1)] """ @@ -647,7 +647,7 @@ def log(arg1, arg2=None): def log2(col): """Returns the base-2 logarithm of the argument. - >>> sqlContext.createDataFrame([(4,)], ['a']).select(log2('a').alias('log2')).collect() + >>> spark.createDataFrame([(4,)], ['a']).select(log2('a').alias('log2')).collect() [Row(log2=2.0)] """ sc = SparkContext._active_spark_context @@ -660,7 +660,7 @@ def conv(col, fromBase, toBase): """ Convert a number in a string column from one base to another. 
- >>> df = sqlContext.createDataFrame([("010101",)], ['n']) + >>> df = spark.createDataFrame([("010101",)], ['n']) >>> df.select(conv(df.n, 2, 16).alias('hex')).collect() [Row(hex=u'15')] """ @@ -673,7 +673,7 @@ def factorial(col): """ Computes the factorial of the given value. - >>> df = sqlContext.createDataFrame([(5,)], ['n']) + >>> df = spark.createDataFrame([(5,)], ['n']) >>> df.select(factorial(df.n).alias('f')).collect() [Row(f=120)] """ @@ -765,7 +765,7 @@ def date_format(date, format): NOTE: Use when ever possible specialized functions like `year`. These benefit from a specialized implementation. - >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a']) + >>> df = spark.createDataFrame([('2015-04-08',)], ['a']) >>> df.select(date_format('a', 'MM/dd/yyy').alias('date')).collect() [Row(date=u'04/08/2015')] """ @@ -778,7 +778,7 @@ def year(col): """ Extract the year of a given date as integer. - >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a']) + >>> df = spark.createDataFrame([('2015-04-08',)], ['a']) >>> df.select(year('a').alias('year')).collect() [Row(year=2015)] """ @@ -791,7 +791,7 @@ def quarter(col): """ Extract the quarter of a given date as integer. - >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a']) + >>> df = spark.createDataFrame([('2015-04-08',)], ['a']) >>> df.select(quarter('a').alias('quarter')).collect() [Row(quarter=2)] """ @@ -804,7 +804,7 @@ def month(col): """ Extract the month of a given date as integer. - >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a']) + >>> df = spark.createDataFrame([('2015-04-08',)], ['a']) >>> df.select(month('a').alias('month')).collect() [Row(month=4)] """ @@ -817,7 +817,7 @@ def dayofmonth(col): """ Extract the day of the month of a given date as integer. - >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a']) + >>> df = spark.createDataFrame([('2015-04-08',)], ['a']) >>> df.select(dayofmonth('a').alias('day')).collect() [Row(day=8)] """ @@ -830,7 +830,7 @@ def dayofyear(col): """ Extract the day of the year of a given date as integer. - >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a']) + >>> df = spark.createDataFrame([('2015-04-08',)], ['a']) >>> df.select(dayofyear('a').alias('day')).collect() [Row(day=98)] """ @@ -843,7 +843,7 @@ def hour(col): """ Extract the hours of a given date as integer. - >>> df = sqlContext.createDataFrame([('2015-04-08 13:08:15',)], ['a']) + >>> df = spark.createDataFrame([('2015-04-08 13:08:15',)], ['a']) >>> df.select(hour('a').alias('hour')).collect() [Row(hour=13)] """ @@ -856,7 +856,7 @@ def minute(col): """ Extract the minutes of a given date as integer. - >>> df = sqlContext.createDataFrame([('2015-04-08 13:08:15',)], ['a']) + >>> df = spark.createDataFrame([('2015-04-08 13:08:15',)], ['a']) >>> df.select(minute('a').alias('minute')).collect() [Row(minute=8)] """ @@ -869,7 +869,7 @@ def second(col): """ Extract the seconds of a given date as integer. - >>> df = sqlContext.createDataFrame([('2015-04-08 13:08:15',)], ['a']) + >>> df = spark.createDataFrame([('2015-04-08 13:08:15',)], ['a']) >>> df.select(second('a').alias('second')).collect() [Row(second=15)] """ @@ -882,7 +882,7 @@ def weekofyear(col): """ Extract the week number of a given date as integer. 
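The date/time extractors above (`year`, `quarter`, `month`, `dayofmonth`, `dayofyear`, `hour`, `minute`, `second`, `weekofyear`) each carry a one-row doctest; a short combined sketch, on an assumed timestamp-string column, showing that they compose in a single `select`:

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import year, month, dayofmonth, hour, minute, second

spark = SparkSession.builder.appName("date-parts").getOrCreate()
df = spark.createDataFrame([("2015-04-08 13:08:15",)], ["t"])

# Each extractor returns an integer column; string input is cast for us, as in
# the doctests above.
df.select(
    year("t").alias("year"),
    month("t").alias("month"),
    dayofmonth("t").alias("day"),
    hour("t").alias("hour"),
    minute("t").alias("minute"),
    second("t").alias("second"),
).show()
```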
- >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a']) + >>> df = spark.createDataFrame([('2015-04-08',)], ['a']) >>> df.select(weekofyear(df.a).alias('week')).collect() [Row(week=15)] """ @@ -895,7 +895,7 @@ def date_add(start, days): """ Returns the date that is `days` days after `start` - >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['d']) + >>> df = spark.createDataFrame([('2015-04-08',)], ['d']) >>> df.select(date_add(df.d, 1).alias('d')).collect() [Row(d=datetime.date(2015, 4, 9))] """ @@ -908,7 +908,7 @@ def date_sub(start, days): """ Returns the date that is `days` days before `start` - >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['d']) + >>> df = spark.createDataFrame([('2015-04-08',)], ['d']) >>> df.select(date_sub(df.d, 1).alias('d')).collect() [Row(d=datetime.date(2015, 4, 7))] """ @@ -921,7 +921,7 @@ def datediff(end, start): """ Returns the number of days from `start` to `end`. - >>> df = sqlContext.createDataFrame([('2015-04-08','2015-05-10')], ['d1', 'd2']) + >>> df = spark.createDataFrame([('2015-04-08','2015-05-10')], ['d1', 'd2']) >>> df.select(datediff(df.d2, df.d1).alias('diff')).collect() [Row(diff=32)] """ @@ -934,7 +934,7 @@ def add_months(start, months): """ Returns the date that is `months` months after `start` - >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['d']) + >>> df = spark.createDataFrame([('2015-04-08',)], ['d']) >>> df.select(add_months(df.d, 1).alias('d')).collect() [Row(d=datetime.date(2015, 5, 8))] """ @@ -947,7 +947,7 @@ def months_between(date1, date2): """ Returns the number of months between date1 and date2. - >>> df = sqlContext.createDataFrame([('1997-02-28 10:30:00', '1996-10-30')], ['t', 'd']) + >>> df = spark.createDataFrame([('1997-02-28 10:30:00', '1996-10-30')], ['t', 'd']) >>> df.select(months_between(df.t, df.d).alias('months')).collect() [Row(months=3.9495967...)] """ @@ -958,9 +958,10 @@ def months_between(date1, date2): @since(1.5) def to_date(col): """ - Converts the column of StringType or TimestampType into DateType. + Converts the column of :class:`pyspark.sql.types.StringType` or + :class:`pyspark.sql.types.TimestampType` into :class:`pyspark.sql.types.DateType`. - >>> df = sqlContext.createDataFrame([('1997-02-28 10:30:00',)], ['t']) + >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t']) >>> df.select(to_date(df.t).alias('date')).collect() [Row(date=datetime.date(1997, 2, 28))] """ @@ -975,7 +976,7 @@ def trunc(date, format): :param format: 'year', 'YYYY', 'yy' or 'month', 'mon', 'mm' - >>> df = sqlContext.createDataFrame([('1997-02-28',)], ['d']) + >>> df = spark.createDataFrame([('1997-02-28',)], ['d']) >>> df.select(trunc(df.d, 'year').alias('year')).collect() [Row(year=datetime.date(1997, 1, 1))] >>> df.select(trunc(df.d, 'mon').alias('month')).collect() @@ -993,7 +994,7 @@ def next_day(date, dayOfWeek): Day of the week parameter is case insensitive, and accepts: "Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun". - >>> df = sqlContext.createDataFrame([('2015-07-27',)], ['d']) + >>> df = spark.createDataFrame([('2015-07-27',)], ['d']) >>> df.select(next_day(df.d, 'Sun').alias('date')).collect() [Row(date=datetime.date(2015, 8, 2))] """ @@ -1006,7 +1007,7 @@ def last_day(date): """ Returns the last day of the month which the given date belongs to. 
- >>> df = sqlContext.createDataFrame([('1997-02-10',)], ['d']) + >>> df = spark.createDataFrame([('1997-02-10',)], ['d']) >>> df.select(last_day(df.d).alias('date')).collect() [Row(date=datetime.date(1997, 2, 28))] """ @@ -1045,7 +1046,7 @@ def from_utc_timestamp(timestamp, tz): """ Assumes given timestamp is UTC and converts to given timezone. - >>> df = sqlContext.createDataFrame([('1997-02-28 10:30:00',)], ['t']) + >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t']) >>> df.select(from_utc_timestamp(df.t, "PST").alias('t')).collect() [Row(t=datetime.datetime(1997, 2, 28, 2, 30))] """ @@ -1058,7 +1059,7 @@ def to_utc_timestamp(timestamp, tz): """ Assumes given timestamp is in given timezone and converts to UTC. - >>> df = sqlContext.createDataFrame([('1997-02-28 10:30:00',)], ['t']) + >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t']) >>> df.select(to_utc_timestamp(df.t, "PST").alias('t')).collect() [Row(t=datetime.datetime(1997, 2, 28, 18, 30))] """ @@ -1074,20 +1075,20 @@ def window(timeColumn, windowDuration, slideDuration=None, startTime=None): [12:05,12:10) but not in [12:00,12:05). Windows can support microsecond precision. Windows in the order of months are not supported. - The time column must be of TimestampType. + The time column must be of :class:`pyspark.sql.types.TimestampType`. Durations are provided as strings, e.g. '1 second', '1 day 12 hours', '2 minutes'. Valid interval strings are 'week', 'day', 'hour', 'minute', 'second', 'millisecond', 'microsecond'. - If the `slideDuration` is not provided, the windows will be tumbling windows. + If the ``slideDuration`` is not provided, the windows will be tumbling windows. The startTime is the offset with respect to 1970-01-01 00:00:00 UTC with which to start window intervals. For example, in order to have hourly tumbling windows that start 15 minutes past the hour, e.g. 12:15-13:15, 13:15-14:15... provide `startTime` as `15 minutes`. The output column will be a struct called 'window' by default with the nested columns 'start' - and 'end', where 'start' and 'end' will be of `TimestampType`. + and 'end', where 'start' and 'end' will be of :class:`pyspark.sql.types.TimestampType`. - >>> df = sqlContext.createDataFrame([("2016-03-11 09:00:07", 1)]).toDF("date", "val") + >>> df = spark.createDataFrame([("2016-03-11 09:00:07", 1)]).toDF("date", "val") >>> w = df.groupBy(window("date", "5 seconds")).agg(sum("val").alias("sum")) >>> w.select(w.window.start.cast("string").alias("start"), ... w.window.end.cast("string").alias("end"), "sum").collect() @@ -1124,7 +1125,7 @@ def crc32(col): Calculates the cyclic redundancy check value (CRC32) of a binary column and returns the value as a bigint. - >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(crc32('a').alias('crc32')).collect() + >>> spark.createDataFrame([('ABC',)], ['a']).select(crc32('a').alias('crc32')).collect() [Row(crc32=2743272264)] """ sc = SparkContext._active_spark_context @@ -1136,7 +1137,7 @@ def crc32(col): def md5(col): """Calculates the MD5 digest and returns the value as a 32 character hex string. - >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(md5('a').alias('hash')).collect() + >>> spark.createDataFrame([('ABC',)], ['a']).select(md5('a').alias('hash')).collect() [Row(hash=u'902fbdd2b1df0c4f70b4a5d23525e932')] """ sc = SparkContext._active_spark_context @@ -1149,7 +1150,7 @@ def md5(col): def sha1(col): """Returns the hex string result of SHA-1. 
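The `window()` docstring above describes both tumbling and sliding windows, but its doctest only exercises the tumbling case. A sketch of a sliding window, assuming the same toy data shape as the doctest (column names and durations are illustrative):

```python
from pyspark.sql import SparkSession
import pyspark.sql.functions as F

spark = SparkSession.builder.appName("window-example").getOrCreate()
df = spark.createDataFrame(
    [("2016-03-11 09:00:07", 1), ("2016-03-11 09:00:11", 2)], ["date", "val"])

# Omitting slideDuration gives tumbling windows (as in the doctest above);
# passing a smaller slideDuration produces overlapping, sliding windows, so a
# row can contribute to more than one window.
w = df.groupBy(F.window("date", "10 seconds", "5 seconds")) \
      .agg(F.sum("val").alias("sum"))
w.select(
    w.window.start.cast("string").alias("start"),
    w.window.end.cast("string").alias("end"),
    "sum",
).orderBy("start").show(truncate=False)
```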
- >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(sha1('a').alias('hash')).collect() + >>> spark.createDataFrame([('ABC',)], ['a']).select(sha1('a').alias('hash')).collect() [Row(hash=u'3c01bdbb26f358bab27f267924aa2c9a03fcfdb8')] """ sc = SparkContext._active_spark_context @@ -1177,9 +1178,9 @@ def sha2(col, numBits): @since(2.0) def hash(*cols): - """Calculates the hash code of given columns, and returns the result as a int column. + """Calculates the hash code of given columns, and returns the result as an int column. - >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(hash('a').alias('hash')).collect() + >>> spark.createDataFrame([('ABC',)], ['a']).select(hash('a').alias('hash')).collect() [Row(hash=-757602832)] """ sc = SparkContext._active_spark_context @@ -1215,7 +1216,7 @@ def concat(*cols): """ Concatenates multiple input string columns together into a single string column. - >>> df = sqlContext.createDataFrame([('abcd','123')], ['s', 'd']) + >>> df = spark.createDataFrame([('abcd','123')], ['s', 'd']) >>> df.select(concat(df.s, df.d).alias('s')).collect() [Row(s=u'abcd123')] """ @@ -1230,7 +1231,7 @@ def concat_ws(sep, *cols): Concatenates multiple input string columns together into a single string column, using the given separator. - >>> df = sqlContext.createDataFrame([('abcd','123')], ['s', 'd']) + >>> df = spark.createDataFrame([('abcd','123')], ['s', 'd']) >>> df.select(concat_ws('-', df.s, df.d).alias('s')).collect() [Row(s=u'abcd-123')] """ @@ -1268,7 +1269,7 @@ def format_number(col, d): :param col: the column name of the numeric value to be formatted :param d: the N decimal places - >>> sqlContext.createDataFrame([(5,)], ['a']).select(format_number('a', 4).alias('v')).collect() + >>> spark.createDataFrame([(5,)], ['a']).select(format_number('a', 4).alias('v')).collect() [Row(v=u'5.0000')] """ sc = SparkContext._active_spark_context @@ -1284,7 +1285,7 @@ def format_string(format, *cols): :param col: the column name of the numeric value to be formatted :param d: the N decimal places - >>> df = sqlContext.createDataFrame([(5, "hello")], ['a', 'b']) + >>> df = spark.createDataFrame([(5, "hello")], ['a', 'b']) >>> df.select(format_string('%d %s', df.a, df.b).alias('v')).collect() [Row(v=u'5 hello')] """ @@ -1301,7 +1302,7 @@ def instr(str, substr): NOTE: The position is not zero based, but 1 based index, returns 0 if substr could not be found in str. - >>> df = sqlContext.createDataFrame([('abcd',)], ['s',]) + >>> df = spark.createDataFrame([('abcd',)], ['s',]) >>> df.select(instr(df.s, 'b').alias('s')).collect() [Row(s=2)] """ @@ -1317,7 +1318,7 @@ def substring(str, pos, len): returns the slice of byte array that starts at `pos` in byte and is of length `len` when str is Binary type - >>> df = sqlContext.createDataFrame([('abcd',)], ['s',]) + >>> df = spark.createDataFrame([('abcd',)], ['s',]) >>> df.select(substring(df.s, 1, 2).alias('s')).collect() [Row(s=u'ab')] """ @@ -1334,7 +1335,7 @@ def substring_index(str, delim, count): returned. If count is negative, every to the right of the final delimiter (counting from the right) is returned. substring_index performs a case-sensitive match when searching for delim. 
- >>> df = sqlContext.createDataFrame([('a.b.c.d',)], ['s']) + >>> df = spark.createDataFrame([('a.b.c.d',)], ['s']) >>> df.select(substring_index(df.s, '.', 2).alias('s')).collect() [Row(s=u'a.b')] >>> df.select(substring_index(df.s, '.', -3).alias('s')).collect() @@ -1349,7 +1350,7 @@ def substring_index(str, delim, count): def levenshtein(left, right): """Computes the Levenshtein distance of the two given strings. - >>> df0 = sqlContext.createDataFrame([('kitten', 'sitting',)], ['l', 'r']) + >>> df0 = spark.createDataFrame([('kitten', 'sitting',)], ['l', 'r']) >>> df0.select(levenshtein('l', 'r').alias('d')).collect() [Row(d=3)] """ @@ -1359,7 +1360,7 @@ def levenshtein(left, right): @since(1.5) -def locate(substr, str, pos=0): +def locate(substr, str, pos=1): """ Locate the position of the first occurrence of substr in a string column, after position pos. @@ -1367,10 +1368,10 @@ def locate(substr, str, pos=0): could not be found in str. :param substr: a string - :param str: a Column of StringType + :param str: a Column of :class:`pyspark.sql.types.StringType` :param pos: start position (zero based) - >>> df = sqlContext.createDataFrame([('abcd',)], ['s',]) + >>> df = spark.createDataFrame([('abcd',)], ['s',]) >>> df.select(locate('b', df.s, 1).alias('s')).collect() [Row(s=2)] """ @@ -1384,7 +1385,7 @@ def lpad(col, len, pad): """ Left-pad the string column to width `len` with `pad`. - >>> df = sqlContext.createDataFrame([('abcd',)], ['s',]) + >>> df = spark.createDataFrame([('abcd',)], ['s',]) >>> df.select(lpad(df.s, 6, '#').alias('s')).collect() [Row(s=u'##abcd')] """ @@ -1398,7 +1399,7 @@ def rpad(col, len, pad): """ Right-pad the string column to width `len` with `pad`. - >>> df = sqlContext.createDataFrame([('abcd',)], ['s',]) + >>> df = spark.createDataFrame([('abcd',)], ['s',]) >>> df.select(rpad(df.s, 6, '#').alias('s')).collect() [Row(s=u'abcd##')] """ @@ -1412,7 +1413,7 @@ def repeat(col, n): """ Repeats a string column n times, and returns it as a new string column. - >>> df = sqlContext.createDataFrame([('ab',)], ['s',]) + >>> df = spark.createDataFrame([('ab',)], ['s',]) >>> df.select(repeat(df.s, 3).alias('s')).collect() [Row(s=u'ababab')] """ @@ -1428,7 +1429,7 @@ def split(str, pattern): NOTE: pattern is a string represent the regular expression. - >>> df = sqlContext.createDataFrame([('ab12cd',)], ['s',]) + >>> df = spark.createDataFrame([('ab12cd',)], ['s',]) >>> df.select(split(df.s, '[0-9]+').alias('s')).collect() [Row(s=[u'ab', u'cd'])] """ @@ -1439,11 +1440,18 @@ def split(str, pattern): @ignore_unicode_prefix @since(1.5) def regexp_extract(str, pattern, idx): - """Extract a specific(idx) group identified by a java regex, from the specified string column. + """Extract a specific group matched by a Java regex, from the specified string column. + If the regex did not match, or the specified group did not match, an empty string is returned. 
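One of the behavioural fixes above changes the default `pos` of `locate()` from 0 to 1, matching its documented 1-based indexing. A quick sketch of the semantics (the DataFrame here is illustrative):

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import locate, instr

spark = SparkSession.builder.appName("locate-example").getOrCreate()
df = spark.createDataFrame([("abcd",)], ["s"])

# Positions are 1-based: 'b' sits at position 2, and 0 means "not found".
df.select(
    locate("b", "s").alias("default_pos"),    # search starts at position 1
    locate("b", "s", 3).alias("from_pos_3"),  # no 'b' at or after position 3 -> 0
    instr(df.s, "b").alias("instr"),          # instr is likewise 1-based
).show()
```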
- >>> df = sqlContext.createDataFrame([('100-200',)], ['str']) + >>> df = spark.createDataFrame([('100-200',)], ['str']) >>> df.select(regexp_extract('str', '(\d+)-(\d+)', 1).alias('d')).collect() [Row(d=u'100')] + >>> df = spark.createDataFrame([('foo',)], ['str']) + >>> df.select(regexp_extract('str', '(\d+)', 1).alias('d')).collect() + [Row(d=u'')] + >>> df = spark.createDataFrame([('aaaac',)], ['str']) + >>> df.select(regexp_extract('str', '(a+)(b)?(c)', 2).alias('d')).collect() + [Row(d=u'')] """ sc = SparkContext._active_spark_context jc = sc._jvm.functions.regexp_extract(_to_java_column(str), pattern, idx) @@ -1455,7 +1463,7 @@ def regexp_extract(str, pattern, idx): def regexp_replace(str, pattern, replacement): """Replace all substrings of the specified string value that match regexp with rep. - >>> df = sqlContext.createDataFrame([('100-200',)], ['str']) + >>> df = spark.createDataFrame([('100-200',)], ['str']) >>> df.select(regexp_replace('str', '(\\d+)', '--').alias('d')).collect() [Row(d=u'-----')] """ @@ -1469,7 +1477,7 @@ def regexp_replace(str, pattern, replacement): def initcap(col): """Translate the first letter of each word to upper case in the sentence. - >>> sqlContext.createDataFrame([('ab cd',)], ['a']).select(initcap("a").alias('v')).collect() + >>> spark.createDataFrame([('ab cd',)], ['a']).select(initcap("a").alias('v')).collect() [Row(v=u'Ab Cd')] """ sc = SparkContext._active_spark_context @@ -1482,7 +1490,7 @@ def soundex(col): """ Returns the SoundEx encoding for a string - >>> df = sqlContext.createDataFrame([("Peters",),("Uhrbach",)], ['name']) + >>> df = spark.createDataFrame([("Peters",),("Uhrbach",)], ['name']) >>> df.select(soundex(df.name).alias("soundex")).collect() [Row(soundex=u'P362'), Row(soundex=u'U612')] """ @@ -1506,10 +1514,11 @@ def bin(col): @ignore_unicode_prefix @since(1.5) def hex(col): - """Computes hex value of the given column, which could be StringType, - BinaryType, IntegerType or LongType. + """Computes hex value of the given column, which could be :class:`pyspark.sql.types.StringType`, + :class:`pyspark.sql.types.BinaryType`, :class:`pyspark.sql.types.IntegerType` or + :class:`pyspark.sql.types.LongType`. - >>> sqlContext.createDataFrame([('ABC', 3)], ['a', 'b']).select(hex('a'), hex('b')).collect() + >>> spark.createDataFrame([('ABC', 3)], ['a', 'b']).select(hex('a'), hex('b')).collect() [Row(hex(a)=u'414243', hex(b)=u'3')] """ sc = SparkContext._active_spark_context @@ -1523,7 +1532,7 @@ def unhex(col): """Inverse of hex. Interprets each pair of characters as a hexadecimal number and converts to the byte representation of number. - >>> sqlContext.createDataFrame([('414243',)], ['a']).select(unhex('a')).collect() + >>> spark.createDataFrame([('414243',)], ['a']).select(unhex('a')).collect() [Row(unhex(a)=bytearray(b'ABC'))] """ sc = SparkContext._active_spark_context @@ -1535,7 +1544,7 @@ def unhex(col): def length(col): """Calculates the length of a string or binary expression. - >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(length('a').alias('length')).collect() + >>> spark.createDataFrame([('ABC',)], ['a']).select(length('a').alias('length')).collect() [Row(length=3)] """ sc = SparkContext._active_spark_context @@ -1550,8 +1559,8 @@ def translate(srcCol, matching, replace): The translate will happen when any character in the string matching with the character in the `matching`. 
- >>> sqlContext.createDataFrame([('translate',)], ['a']).select(translate('a', "rnlt", "123")\ - .alias('r')).collect() + >>> spark.createDataFrame([('translate',)], ['a']).select(translate('a', "rnlt", "123") \\ + ... .alias('r')).collect() [Row(r=u'1a2s3ae')] """ sc = SparkContext._active_spark_context @@ -1608,7 +1617,7 @@ def array_contains(col, value): :param col: name of column containing array :param value: value to check for in array - >>> df = sqlContext.createDataFrame([(["a", "b", "c"],), ([],)], ['data']) + >>> df = spark.createDataFrame([(["a", "b", "c"],), ([],)], ['data']) >>> df.select(array_contains(df.data, "a")).collect() [Row(array_contains(data, a)=True), Row(array_contains(data, a)=False)] """ @@ -1621,7 +1630,7 @@ def explode(col): """Returns a new row for each element in the given array or map. >>> from pyspark.sql import Row - >>> eDF = sqlContext.createDataFrame([Row(a=1, intlist=[1,2,3], mapfield={"a": "b"})]) + >>> eDF = spark.createDataFrame([Row(a=1, intlist=[1,2,3], mapfield={"a": "b"})]) >>> eDF.select(explode(eDF.intlist).alias("anInt")).collect() [Row(anInt=1), Row(anInt=2), Row(anInt=3)] @@ -1637,6 +1646,27 @@ def explode(col): return Column(jc) +@since(2.1) +def posexplode(col): + """Returns a new row for each element with position in the given array or map. + + >>> from pyspark.sql import Row + >>> eDF = spark.createDataFrame([Row(a=1, intlist=[1,2,3], mapfield={"a": "b"})]) + >>> eDF.select(posexplode(eDF.intlist)).collect() + [Row(pos=0, col=1), Row(pos=1, col=2), Row(pos=2, col=3)] + + >>> eDF.select(posexplode(eDF.mapfield)).show() + +---+---+-----+ + |pos|key|value| + +---+---+-----+ + | 0| a| b| + +---+---+-----+ + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.posexplode(_to_java_column(col)) + return Column(jc) + + @ignore_unicode_prefix @since(1.6) def get_json_object(col, path): @@ -1648,9 +1678,9 @@ def get_json_object(col, path): :param path: path to the json object to extract >>> data = [("1", '''{"f1": "value1", "f2": "value2"}'''), ("2", '''{"f1": "value12"}''')] - >>> df = sqlContext.createDataFrame(data, ("key", "jstring")) - >>> df.select(df.key, get_json_object(df.jstring, '$.f1').alias("c0"), \ - get_json_object(df.jstring, '$.f2').alias("c1") ).collect() + >>> df = spark.createDataFrame(data, ("key", "jstring")) + >>> df.select(df.key, get_json_object(df.jstring, '$.f1').alias("c0"), \\ + ... 
get_json_object(df.jstring, '$.f2').alias("c1") ).collect() [Row(key=u'1', c0=u'value1', c1=u'value2'), Row(key=u'2', c0=u'value12', c1=None)] """ sc = SparkContext._active_spark_context @@ -1667,7 +1697,7 @@ def json_tuple(col, *fields): :param fields: list of fields to extract >>> data = [("1", '''{"f1": "value1", "f2": "value2"}'''), ("2", '''{"f1": "value12"}''')] - >>> df = sqlContext.createDataFrame(data, ("key", "jstring")) + >>> df = spark.createDataFrame(data, ("key", "jstring")) >>> df.select(df.key, json_tuple(df.jstring, 'f1', 'f2')).collect() [Row(key=u'1', c0=u'value1', c1=u'value2'), Row(key=u'2', c0=u'value12', c1=None)] """ @@ -1683,7 +1713,7 @@ def size(col): :param col: name of column or expression - >>> df = sqlContext.createDataFrame([([1, 2, 3],),([1],),([],)], ['data']) + >>> df = spark.createDataFrame([([1, 2, 3],),([1],),([],)], ['data']) >>> df.select(size(df.data)).collect() [Row(size(data)=3), Row(size(data)=1), Row(size(data)=0)] """ @@ -1698,7 +1728,7 @@ def sort_array(col, asc=True): :param col: name of column or expression - >>> df = sqlContext.createDataFrame([([2, 1, 3],),([1],),([],)], ['data']) + >>> df = spark.createDataFrame([([2, 1, 3],),([1],),([],)], ['data']) >>> df.select(sort_array(df.data).alias('r')).collect() [Row(r=[1, 2, 3]), Row(r=[1]), Row(r=[])] >>> df.select(sort_array(df.data, asc=False).alias('r')).collect() @@ -1756,6 +1786,12 @@ def __call__(self, *cols): @since(1.3) def udf(f, returnType=StringType()): """Creates a :class:`Column` expression representing a user defined function (UDF). + Note that the user-defined functions must be deterministic. Due to optimization, + duplicate invocations may be eliminated or the function may even be invoked more times than + it is present in the query. + + :param f: python function + :param returnType: a :class:`pyspark.sql.types.DataType` object >>> from pyspark.sql.types import IntegerType >>> slen = udf(lambda s: len(s), IntegerType()) @@ -1772,18 +1808,21 @@ def udf(f, returnType=StringType()): def _test(): import doctest - from pyspark.context import SparkContext - from pyspark.sql import Row, SQLContext + from pyspark.sql import Row, SparkSession import pyspark.sql.functions globs = pyspark.sql.functions.__dict__.copy() - sc = SparkContext('local[4]', 'PythonTest') + spark = SparkSession.builder\ + .master("local[4]")\ + .appName("sql.functions tests")\ + .getOrCreate() + sc = spark.sparkContext globs['sc'] = sc - globs['sqlContext'] = SQLContext(sc) + globs['spark'] = spark globs['df'] = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]).toDF() (failure_count, test_count) = doctest.testmod( pyspark.sql.functions, globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) - globs['sc'].stop() + spark.stop() if failure_count: exit(-1) diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index ee734cb439287..f2092f9c63054 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -27,7 +27,7 @@ def dfapi(f): def _api(self): name = f.__name__ - jdf = getattr(self._jdf, name)() + jdf = getattr(self._jgd, name)() return DataFrame(jdf, self.sql_ctx) _api.__name__ = f.__name__ _api.__doc__ = f.__doc__ @@ -35,9 +35,9 @@ def _api(self): def df_varargs_api(f): - def _api(self, *args): + def _api(self, *cols): name = f.__name__ - jdf = getattr(self._jdf, name)(_to_seq(self.sql_ctx._sc, args)) + jdf = getattr(self._jgd, name)(_to_seq(self.sql_ctx._sc, cols)) return DataFrame(jdf, self.sql_ctx) _api.__name__ = f.__name__ _api.__doc__ = 
f.__doc__ @@ -54,8 +54,8 @@ class GroupedData(object): .. versionadded:: 1.3 """ - def __init__(self, jdf, sql_ctx): - self._jdf = jdf + def __init__(self, jgd, sql_ctx): + self._jgd = jgd self.sql_ctx = sql_ctx @ignore_unicode_prefix @@ -83,11 +83,11 @@ def agg(self, *exprs): """ assert exprs, "exprs should not be empty" if len(exprs) == 1 and isinstance(exprs[0], dict): - jdf = self._jdf.agg(exprs[0]) + jdf = self._jgd.agg(exprs[0]) else: # Columns assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column" - jdf = self._jdf.agg(exprs[0]._jc, + jdf = self._jgd.agg(exprs[0]._jc, _to_seq(self.sql_ctx._sc, [c._jc for c in exprs[1:]])) return DataFrame(jdf, self.sql_ctx) @@ -178,30 +178,34 @@ def pivot(self, pivot_col, values=None): :param pivot_col: Name of the column to pivot. :param values: List of values that will be translated to columns in the output DataFrame. - // Compute the sum of earnings for each year by course with each course as a separate column + # Compute the sum of earnings for each year by course with each course as a separate column + >>> df4.groupBy("year").pivot("course", ["dotNET", "Java"]).sum("earnings").collect() [Row(year=2012, dotNET=15000, Java=20000), Row(year=2013, dotNET=48000, Java=30000)] - // Or without specifying column values (less efficient) + # Or without specifying column values (less efficient) + >>> df4.groupBy("year").pivot("course").sum("earnings").collect() [Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000, dotNET=48000)] """ if values is None: - jgd = self._jdf.pivot(pivot_col) + jgd = self._jgd.pivot(pivot_col) else: - jgd = self._jdf.pivot(pivot_col, values) + jgd = self._jgd.pivot(pivot_col, values) return GroupedData(jgd, self.sql_ctx) def _test(): import doctest - from pyspark.context import SparkContext - from pyspark.sql import Row, SQLContext + from pyspark.sql import Row, SparkSession import pyspark.sql.group globs = pyspark.sql.group.__dict__.copy() - sc = SparkContext('local[4]', 'PythonTest') + spark = SparkSession.builder\ + .master("local[4]")\ + .appName("sql.group tests")\ + .getOrCreate() + sc = spark.sparkContext globs['sc'] = sc - globs['sqlContext'] = SQLContext(sc) globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')]) \ .toDF(StructType([StructField('age', IntegerType()), StructField('name', StringType())])) @@ -216,7 +220,7 @@ def _test(): (failure_count, test_count) = doctest.testmod( pyspark.sql.group, globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) - globs['sc'].stop() + spark.stop() if failure_count: exit(-1) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index ed9e716ab78e3..e62f4832fa184 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -44,24 +44,35 @@ def to_str(value): return str(value) -class DataFrameReader(object): +class OptionUtils(object): + + def _set_opts(self, schema=None, **options): + """ + Set named options (filter out those the value is None) + """ + if schema is not None: + self.schema(schema) + for k, v in options.items(): + if v is not None: + self.option(k, v) + + +class DataFrameReader(OptionUtils): """ Interface used to load a :class:`DataFrame` from external storage systems - (e.g. file systems, key-value stores, etc). Use :func:`SQLContext.read` + (e.g. file systems, key-value stores, etc). Use :func:`spark.read` to access this. - ::Note: Experimental - .. 
versionadded:: 1.4 """ - def __init__(self, sqlContext): - self._jreader = sqlContext._ssql_ctx.read() - self._sqlContext = sqlContext + def __init__(self, spark): + self._jreader = spark._ssql_ctx.read() + self._spark = spark def _df(self, jdf): from pyspark.sql.dataframe import DataFrame - return DataFrame(jdf, self._sqlContext) + return DataFrame(jdf, self._spark) @since(1.4) def format(self, source): @@ -69,7 +80,7 @@ def format(self, source): :param source: string, name of the data source, e.g. 'json', 'parquet'. - >>> df = sqlContext.read.format('json').load('python/test_support/sql/people.json') + >>> df = spark.read.format('json').load('python/test_support/sql/people.json') >>> df.dtypes [('age', 'bigint'), ('name', 'string')] @@ -85,11 +96,11 @@ def schema(self, schema): By specifying the schema here, the underlying data source can skip the schema inference step, and thus speed up data loading. - :param schema: a StructType object + :param schema: a :class:`pyspark.sql.types.StructType` object """ if not isinstance(schema, StructType): raise TypeError("schema should be StructType") - jschema = self._sqlContext._ssql_ctx.parseDataType(schema.json()) + jschema = self._spark._ssql_ctx.parseDataType(schema.json()) self._jreader = self._jreader.schema(jschema) return self @@ -114,15 +125,15 @@ def load(self, path=None, format=None, schema=None, **options): :param path: optional string or a list of string for file-system backed data sources. :param format: optional string for format of the data source. Default to 'parquet'. - :param schema: optional :class:`StructType` for the input schema. + :param schema: optional :class:`pyspark.sql.types.StructType` for the input schema. :param options: all other string options - >>> df = sqlContext.read.load('python/test_support/sql/parquet_partitioned', opt1=True, + >>> df = spark.read.load('python/test_support/sql/parquet_partitioned', opt1=True, ... opt2=1, opt3='str') >>> df.dtypes [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')] - >>> df = sqlContext.read.format('json').load(['python/test_support/sql/people.json', + >>> df = spark.read.format('json').load(['python/test_support/sql/people.json', ... 'python/test_support/sql/people1.json']) >>> df.dtypes [('age', 'bigint'), ('aka', 'string'), ('name', 'string')] @@ -132,41 +143,20 @@ def load(self, path=None, format=None, schema=None, **options): if schema is not None: self.schema(schema) self.options(**options) - if path is not None: + if isinstance(path, basestring): + return self._df(self._jreader.load(path)) + elif path is not None: if type(path) != list: path = [path] - return self._df(self._jreader.load(self._sqlContext._sc._jvm.PythonUtils.toSeq(path))) + return self._df(self._jreader.load(self._spark._sc._jvm.PythonUtils.toSeq(path))) else: return self._df(self._jreader.load()) - @since(2.0) - def stream(self, path=None, format=None, schema=None, **options): - """Loads a data stream from a data source and returns it as a :class`DataFrame`. - - :param path: optional string for file-system backed data sources. - :param format: optional string for format of the data source. Default to 'parquet'. - :param schema: optional :class:`StructType` for the input schema. 
- :param options: all other string options - - >>> df = sqlContext.read.format('text').stream('python/test_support/sql/streaming') - >>> df.isStreaming - True - """ - if format is not None: - self.format(format) - if schema is not None: - self.schema(schema) - self.options(**options) - if path is not None: - if type(path) != str or len(path.strip()) == 0: - raise ValueError("If the path is provided for stream, it needs to be a " + - "non-empty string. List of paths are not supported.") - return self._df(self._jreader.stream(path)) - else: - return self._df(self._jreader.stream()) - @since(1.4) - def json(self, path, schema=None): + def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, + allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None, + allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, + mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None): """ Loads a JSON file (one object per line) or an RDD of Strings storing JSON objects (one object per record) and returns the result as a :class`DataFrame`. @@ -176,48 +166,67 @@ def json(self, path, schema=None): :param path: string represents path to the JSON dataset, or RDD of Strings storing JSON objects. - :param schema: an optional :class:`StructType` for the input schema. - - You can set the following JSON-specific options to deal with non-standard JSON files: - * ``primitivesAsString`` (default ``false``): infers all primitive values as a string \ - type - * `prefersDecimal` (default `false`): infers all floating-point values as a decimal \ - type. If the values do not fit in decimal, then it infers them as doubles. - * ``allowComments`` (default ``false``): ignores Java/C++ style comment in JSON records - * ``allowUnquotedFieldNames`` (default ``false``): allows unquoted JSON field names - * ``allowSingleQuotes`` (default ``true``): allows single quotes in addition to double \ - quotes - * ``allowNumericLeadingZeros`` (default ``false``): allows leading zeros in numbers \ - (e.g. 00012) - * ``allowBackslashEscapingAnyCharacter`` (default ``false``): allows accepting quoting \ - of all character using backslash quoting mechanism - * ``mode`` (default ``PERMISSIVE``): allows a mode for dealing with corrupt records \ - during parsing. + :param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema. + :param primitivesAsString: infers all primitive values as a string type. If None is set, + it uses the default value, ``false``. + :param prefersDecimal: infers all floating-point values as a decimal type. If the values + do not fit in decimal, then it infers them as doubles. If None is + set, it uses the default value, ``false``. + :param allowComments: ignores Java/C++ style comment in JSON records. If None is set, + it uses the default value, ``false``. + :param allowUnquotedFieldNames: allows unquoted JSON field names. If None is set, + it uses the default value, ``false``. + :param allowSingleQuotes: allows single quotes in addition to double quotes. If None is + set, it uses the default value, ``true``. + :param allowNumericLeadingZero: allows leading zeros in numbers (e.g. 00012). If None is + set, it uses the default value, ``false``. + :param allowBackslashEscapingAnyCharacter: allows accepting quoting of all character + using backslash quoting mechanism. If None is + set, it uses the default value, ``false``. + :param mode: allows a mode for dealing with corrupt records during parsing. 
If None is + set, it uses the default value, ``PERMISSIVE``. + * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted \ record and puts the malformed string into a new field configured by \ ``columnNameOfCorruptRecord``. When a schema is set by user, it sets \ ``null`` for extra fields. * ``DROPMALFORMED`` : ignores the whole corrupted records. * ``FAILFAST`` : throws an exception when it meets corrupted records. - * ``columnNameOfCorruptRecord`` (default ``_corrupt_record``): allows renaming the \ - new field having malformed string created by ``PERMISSIVE`` mode. \ - This overrides ``spark.sql.columnNameOfCorruptRecord``. - >>> df1 = sqlContext.read.json('python/test_support/sql/people.json') + :param columnNameOfCorruptRecord: allows renaming the new field having malformed string + created by ``PERMISSIVE`` mode. This overrides + ``spark.sql.columnNameOfCorruptRecord``. If None is set, + it uses the value specified in + ``spark.sql.columnNameOfCorruptRecord``. + :param dateFormat: sets the string that indicates a date format. Custom date formats + follow the formats at ``java.text.SimpleDateFormat``. This + applies to date type. If None is set, it uses the + default value value, ``yyyy-MM-dd``. + :param timestampFormat: sets the string that indicates a timestamp format. Custom date + formats follow the formats at ``java.text.SimpleDateFormat``. + This applies to timestamp type. If None is set, it uses the + default value value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``. + + >>> df1 = spark.read.json('python/test_support/sql/people.json') >>> df1.dtypes [('age', 'bigint'), ('name', 'string')] >>> rdd = sc.textFile('python/test_support/sql/people.json') - >>> df2 = sqlContext.read.json(rdd) + >>> df2 = spark.read.json(rdd) >>> df2.dtypes [('age', 'bigint'), ('name', 'string')] """ - if schema is not None: - self.schema(schema) + self._set_opts( + schema=schema, primitivesAsString=primitivesAsString, prefersDecimal=prefersDecimal, + allowComments=allowComments, allowUnquotedFieldNames=allowUnquotedFieldNames, + allowSingleQuotes=allowSingleQuotes, allowNumericLeadingZero=allowNumericLeadingZero, + allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter, + mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat, + timestampFormat=timestampFormat) if isinstance(path, basestring): - return self._df(self._jreader.json(path)) - elif type(path) == list: - return self._df(self._jreader.json(self._sqlContext._sc._jvm.PythonUtils.toSeq(path))) + path = [path] + if type(path) == list: + return self._df(self._jreader.json(self._spark._sc._jvm.PythonUtils.toSeq(path))) elif isinstance(path, RDD): def func(iterator): for x in iterator: @@ -228,7 +237,7 @@ def func(iterator): yield x keyed = path.mapPartitions(func) keyed._bypass_serializer = True - jrdd = keyed._jrdd.map(self._sqlContext._jvm.BytesToString()) + jrdd = keyed._jrdd.map(self._spark._jvm.BytesToString()) return self._df(self._jreader.json(jrdd)) else: raise TypeError("path can be only string or RDD") @@ -239,9 +248,9 @@ def table(self, tableName): :param tableName: string, name of the table. 
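With the keyword parameters added to `DataFrameReader.json()` above, per-call options no longer have to go through `.option()`; `_set_opts` skips any argument left as `None` so the data source falls back to its documented defaults. A sketch of both styles, reusing the test-support path from the doctest:

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("json-reader-options").getOrCreate()

# New style: options are plain keyword arguments on json().
df = spark.read.json(
    "python/test_support/sql/people.json",
    mode="DROPMALFORMED",
    allowComments=True,
)

# Equivalent builder style, which still works for any option.
df2 = (spark.read
       .option("mode", "DROPMALFORMED")
       .option("allowComments", "true")
       .json("python/test_support/sql/people.json"))

df.printSchema()
```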
- >>> df = sqlContext.read.parquet('python/test_support/sql/parquet_partitioned') - >>> df.registerTempTable('tmpTable') - >>> sqlContext.read.table('tmpTable').dtypes + >>> df = spark.read.parquet('python/test_support/sql/parquet_partitioned') + >>> df.createOrReplaceTempView('tmpTable') + >>> spark.read.table('tmpTable').dtypes [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')] """ return self._df(self._jreader.table(tableName)) @@ -250,54 +259,130 @@ def table(self, tableName): def parquet(self, *paths): """Loads a Parquet file, returning the result as a :class:`DataFrame`. - >>> df = sqlContext.read.parquet('python/test_support/sql/parquet_partitioned') + You can set the following Parquet-specific option(s) for reading Parquet files: + * ``mergeSchema``: sets whether we should merge schemas collected from all \ + Parquet part-files. This will override ``spark.sql.parquet.mergeSchema``. \ + The default value is specified in ``spark.sql.parquet.mergeSchema``. + + >>> df = spark.read.parquet('python/test_support/sql/parquet_partitioned') >>> df.dtypes [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')] """ - return self._df(self._jreader.parquet(_to_seq(self._sqlContext._sc, paths))) + return self._df(self._jreader.parquet(_to_seq(self._spark._sc, paths))) @ignore_unicode_prefix @since(1.6) def text(self, paths): - """Loads a text file and returns a [[DataFrame]] with a single string column named "value". + """ + Loads text files and returns a :class:`DataFrame` whose schema starts with a + string column named "value", and followed by partitioned columns if there + are any. Each line in the text file is a new row in the resulting DataFrame. :param paths: string, or list of strings, for input path(s). - >>> df = sqlContext.read.text('python/test_support/sql/text-test.txt') + >>> df = spark.read.text('python/test_support/sql/text-test.txt') >>> df.collect() [Row(value=u'hello'), Row(value=u'this')] """ if isinstance(paths, basestring): paths = [paths] - return self._df(self._jreader.text(self._sqlContext._sc._jvm.PythonUtils.toSeq(paths))) + return self._df(self._jreader.text(self._spark._sc._jvm.PythonUtils.toSeq(paths))) @since(2.0) - def csv(self, paths): - """Loads a CSV file and returns the result as a [[DataFrame]]. - - This function goes through the input once to determine the input schema. To avoid going - through the entire data once, specify the schema explicitly using [[schema]]. - - :param paths: string, or list of strings, for input path(s). - - >>> df = sqlContext.read.csv('python/test_support/sql/ages.csv') + def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=None, + comment=None, header=None, inferSchema=None, ignoreLeadingWhiteSpace=None, + ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None, + negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, + maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None): + """Loads a CSV file and returns the result as a :class:`DataFrame`. + + This function will go through the input once to determine the input schema if + ``inferSchema`` is enabled. To avoid going through the entire data once, disable + ``inferSchema`` option or specify the schema explicitly using ``schema``. + + :param path: string, or list of strings, for input path(s). + :param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema. + :param sep: sets the single character as a separator for each field and value. 
+ If None is set, it uses the default value, ``,``. + :param encoding: decodes the CSV files by the given encoding type. If None is set, + it uses the default value, ``UTF-8``. + :param quote: sets the single character used for escaping quoted values where the + separator can be part of the value. If None is set, it uses the default + value, ``"``. If you would like to turn off quotations, you need to set an + empty string. + :param escape: sets the single character used for escaping quotes inside an already + quoted value. If None is set, it uses the default value, ``\``. + :param comment: sets the single character used for skipping lines beginning with this + character. By default (None), it is disabled. + :param header: uses the first line as names of columns. If None is set, it uses the + default value, ``false``. + :param inferSchema: infers the input schema automatically from data. It requires one extra + pass over the data. If None is set, it uses the default value, ``false``. + :param ignoreLeadingWhiteSpace: defines whether or not leading whitespaces from values + being read should be skipped. If None is set, it uses + the default value, ``false``. + :param ignoreTrailingWhiteSpace: defines whether or not trailing whitespaces from values + being read should be skipped. If None is set, it uses + the default value, ``false``. + :param nullValue: sets the string representation of a null value. If None is set, it uses + the default value, empty string. Since 2.0.1, this ``nullValue`` param + applies to all supported types including the string type. + :param nanValue: sets the string representation of a non-number value. If None is set, it + uses the default value, ``NaN``. + :param positiveInf: sets the string representation of a positive infinity value. If None + is set, it uses the default value, ``Inf``. + :param negativeInf: sets the string representation of a negative infinity value. If None + is set, it uses the default value, ``Inf``. + :param dateFormat: sets the string that indicates a date format. Custom date formats + follow the formats at ``java.text.SimpleDateFormat``. This + applies to date type. If None is set, it uses the + default value value, ``yyyy-MM-dd``. + :param timestampFormat: sets the string that indicates a timestamp format. Custom date + formats follow the formats at ``java.text.SimpleDateFormat``. + This applies to timestamp type. If None is set, it uses the + default value value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``. + :param maxColumns: defines a hard limit of how many columns a record can have. If None is + set, it uses the default value, ``20480``. + :param maxCharsPerColumn: defines the maximum number of characters allowed for any given + value being read. If None is set, it uses the default value, + ``1000000``. + :param maxMalformedLogPerPartition: sets the maximum number of malformed rows Spark will + log for each partition. Malformed records beyond this + number will be ignored. If None is set, it + uses the default value, ``10``. + :param mode: allows a mode for dealing with corrupt records during parsing. If None is + set, it uses the default value, ``PERMISSIVE``. + + * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted record. + When a schema is set by user, it sets ``null`` for extra fields. + * ``DROPMALFORMED`` : ignores the whole corrupted records. + * ``FAILFAST`` : throws an exception when it meets corrupted records. 
+ + >>> df = spark.read.csv('python/test_support/sql/ages.csv') >>> df.dtypes - [('C0', 'string'), ('C1', 'string')] - """ - if isinstance(paths, basestring): - paths = [paths] - return self._df(self._jreader.csv(self._sqlContext._sc._jvm.PythonUtils.toSeq(paths))) + [('_c0', 'string'), ('_c1', 'string')] + """ + self._set_opts( + schema=schema, sep=sep, encoding=encoding, quote=quote, escape=escape, comment=comment, + header=header, inferSchema=inferSchema, ignoreLeadingWhiteSpace=ignoreLeadingWhiteSpace, + ignoreTrailingWhiteSpace=ignoreTrailingWhiteSpace, nullValue=nullValue, + nanValue=nanValue, positiveInf=positiveInf, negativeInf=negativeInf, + dateFormat=dateFormat, timestampFormat=timestampFormat, maxColumns=maxColumns, + maxCharsPerColumn=maxCharsPerColumn, + maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode) + if isinstance(path, basestring): + path = [path] + return self._df(self._jreader.csv(self._spark._sc._jvm.PythonUtils.toSeq(path))) @since(1.5) def orc(self, path): """Loads an ORC file, returning the result as a :class:`DataFrame`. - ::Note: Currently ORC support is only available together with - :class:`HiveContext`. + .. note:: Currently ORC support is only available together with Hive support. - >>> df = hiveContext.read.orc('python/test_support/sql/orc_partitioned') + >>> df = spark.read.orc('python/test_support/sql/orc_partitioned') >>> df.dtypes [('a', 'bigint'), ('b', 'int'), ('c', 'int')] """ @@ -307,65 +392,67 @@ def orc(self, path): def jdbc(self, url, table, column=None, lowerBound=None, upperBound=None, numPartitions=None, predicates=None, properties=None): """ - Construct a :class:`DataFrame` representing the database table accessible - via JDBC URL `url` named `table` and connection `properties`. + Construct a :class:`DataFrame` representing the database table named ``table`` + accessible via JDBC URL ``url`` and connection ``properties``. - The `column` parameter could be used to partition the table, then it will - be retrieved in parallel based on the parameters passed to this function. + Partitions of the table will be retrieved in parallel if either ``column`` or + ``predicates`` is specified. - The `predicates` parameter gives a list expressions suitable for inclusion - in WHERE clauses; each one defines one partition of the :class:`DataFrame`. + If both ``column`` and ``predicates`` are specified, ``column`` will be used. - ::Note: Don't create too many partitions in parallel on a large cluster; + .. note:: Don't create too many partitions in parallel on a large cluster; \ otherwise Spark might crash your external database systems. 
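Following the `csv()` parameters documented above, a short sketch of a typical read with and without schema inference, reusing the test-support path from the doctest (option values are illustrative):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("csv-reader-options").getOrCreate()

# With no options the columns come back as _c0, _c1, ... of string type, as the
# doctest above shows; inferSchema makes one extra pass to pick better types.
raw = spark.read.csv("python/test_support/sql/ages.csv")
typed = spark.read.csv(
    "python/test_support/sql/ages.csv",
    sep=",",           # the documented default, shown explicitly
    header=False,      # this sample file has no header row
    inferSchema=True,  # requires one extra pass over the data
)

print(raw.dtypes)
print(typed.dtypes)
```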
- :param url: a JDBC URL - :param table: name of table - :param column: the column used to partition - :param lowerBound: the lower bound of partition column - :param upperBound: the upper bound of the partition column + :param url: a JDBC URL of the form ``jdbc:subprotocol:subname`` + :param table: the name of the table + :param column: the name of an integer column that will be used for partitioning; + if this parameter is specified, then ``numPartitions``, ``lowerBound`` + (inclusive), and ``upperBound`` (exclusive) will form partition strides + for generated WHERE clause expressions used to split the column + ``column`` evenly + :param lowerBound: the minimum value of ``column`` used to decide partition stride + :param upperBound: the maximum value of ``column`` used to decide partition stride :param numPartitions: the number of partitions - :param predicates: a list of expressions - :param properties: JDBC database connection arguments, a list of arbitrary string - tag/value. Normally at least a "user" and "password" property - should be included. + :param predicates: a list of expressions suitable for inclusion in WHERE clauses; + each one defines one partition of the :class:`DataFrame` + :param properties: a dictionary of JDBC database connection arguments. Normally at + least properties "user" and "password" with their corresponding values. + For example { 'user' : 'SYSTEM', 'password' : 'mypassword' } :return: a DataFrame """ if properties is None: properties = dict() - jprop = JavaClass("java.util.Properties", self._sqlContext._sc._gateway._gateway_client)() + jprop = JavaClass("java.util.Properties", self._spark._sc._gateway._gateway_client)() for k in properties: jprop.setProperty(k, properties[k]) if column is not None: if numPartitions is None: - numPartitions = self._sqlContext._sc.defaultParallelism + numPartitions = self._spark._sc.defaultParallelism return self._df(self._jreader.jdbc(url, table, column, int(lowerBound), int(upperBound), int(numPartitions), jprop)) if predicates is not None: - gateway = self._sqlContext._sc._gateway + gateway = self._spark._sc._gateway jpredicates = utils.toJArray(gateway, gateway.jvm.java.lang.String, predicates) return self._df(self._jreader.jdbc(url, table, jpredicates, jprop)) return self._df(self._jreader.jdbc(url, table, jprop)) -class DataFrameWriter(object): +class DataFrameWriter(OptionUtils): """ - Interface used to write a [[DataFrame]] to external storage systems + Interface used to write a :class:`DataFrame` to external storage systems (e.g. file systems, key-value stores, etc). Use :func:`DataFrame.write` to access this. - ::Note: Experimental - .. versionadded:: 1.4 """ def __init__(self, df): self._df = df - self._sqlContext = df.sql_ctx + self._spark = df.sql_ctx self._jwrite = df._jdf.write() - def _cq(self, jcq): - from pyspark.sql.streaming import ContinuousQuery - return ContinuousQuery(jcq) + def _sq(self, jsq): + from pyspark.sql.streaming import StreamingQuery + return StreamingQuery(jsq) @since(1.4) def mode(self, saveMode): @@ -425,45 +512,7 @@ def partitionBy(self, *cols): """ if len(cols) == 1 and isinstance(cols[0], (list, tuple)): cols = cols[0] - self._jwrite = self._jwrite.partitionBy(_to_seq(self._sqlContext._sc, cols)) - return self - - @since(2.0) - def queryName(self, queryName): - """Specifies the name of the :class:`ContinuousQuery` that can be started with - :func:`startStream`. This name must be unique among all the currently active queries - in the associated SQLContext. 
- - :param queryName: unique name for the query - - >>> writer = sdf.write.queryName('streaming_query') - """ - if not queryName or type(queryName) != str or len(queryName.strip()) == 0: - raise ValueError('The queryName must be a non-empty string. Got: %s' % queryName) - self._jwrite = self._jwrite.queryName(queryName) - return self - - @keyword_only - @since(2.0) - def trigger(self, processingTime=None): - """Set the trigger for the stream query. If this is not set it will run the query as fast - as possible, which is equivalent to setting the trigger to ``processingTime='0 seconds'``. - - :param processingTime: a processing time interval as a string, e.g. '5 seconds', '1 minute'. - - >>> # trigger the query for execution every 5 seconds - >>> writer = sdf.write.trigger(processingTime='5 seconds') - """ - from pyspark.sql.streaming import ProcessingTime - trigger = None - if processingTime is not None: - if type(processingTime) != str or len(processingTime.strip()) == 0: - raise ValueError('The processing time must be a non empty string. Got: %s' % - processingTime) - trigger = ProcessingTime(processingTime) - if trigger is None: - raise ValueError('A trigger was not provided. Supported triggers: processingTime.') - self._jwrite = self._jwrite.trigger(trigger._to_java_trigger(self._sqlContext)) + self._jwrite = self._jwrite.partitionBy(_to_seq(self._spark._sc, cols)) return self @since(1.4) @@ -497,55 +546,6 @@ def save(self, path=None, format=None, mode=None, partitionBy=None, **options): else: self._jwrite.save(path) - @ignore_unicode_prefix - @since(2.0) - def startStream(self, path=None, format=None, partitionBy=None, queryName=None, **options): - """Streams the contents of the :class:`DataFrame` to a data source. - - The data source is specified by the ``format`` and a set of ``options``. - If ``format`` is not specified, the default data source configured by - ``spark.sql.sources.default`` will be used. - - :param path: the path in a Hadoop supported file system - :param format: the format used to save - - * ``append``: Append contents of this :class:`DataFrame` to existing data. - * ``overwrite``: Overwrite existing data. - * ``ignore``: Silently ignore this operation if data already exists. - * ``error`` (default case): Throw an exception if data already exists. - :param partitionBy: names of partitioning columns - :param queryName: unique name for the query - :param options: All other string options. You may want to provide a `checkpointLocation` - for most streams, however it is not required for a `memory` stream. - - >>> cq = sdf.write.format('memory').queryName('this_query').startStream() - >>> cq.isActive - True - >>> cq.name - u'this_query' - >>> cq.stop() - >>> cq.isActive - False - >>> cq = sdf.write.trigger(processingTime='5 seconds').startStream( - ... queryName='that_query', format='memory') - >>> cq.name - u'that_query' - >>> cq.isActive - True - >>> cq.stop() - """ - self.options(**options) - if partitionBy is not None: - self.partitionBy(partitionBy) - if format is not None: - self.format(format) - if queryName is not None: - self.queryName(queryName) - if path is None: - return self._cq(self._jwrite.startStream()) - else: - return self._cq(self._jwrite.startStream(path)) - @since(1.4) def insertInto(self, tableName, overwrite=False): """Inserts the content of the :class:`DataFrame` to the specified table. 
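A short usage sketch for the two table-writing paths touched here, ``insertInto`` and ``saveAsTable`` (both table names are hypothetical, and ``people_backup`` is assumed to already exist with a compatible schema):

```python
# Hedged sketch: write an existing DataFrame `df` into tables via the DataFrameWriter API.
df.write.insertInto("people_backup")                        # append into an existing table
df.write.mode("overwrite").saveAsTable("people_snapshot")   # create or replace a managed table
```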
@@ -563,7 +563,7 @@ def saveAsTable(self, name, format=None, mode=None, partitionBy=None, **options) In the case the table already exists, behavior of this function depends on the save mode, specified by the `mode` function (default to throwing an exception). - When `mode` is `Overwrite`, the schema of the [[DataFrame]] does not need to be + When `mode` is `Overwrite`, the schema of the :class:`DataFrame` does not need to be the same as that of the existing table. * `append`: Append contents of this :class:`DataFrame` to existing data. @@ -585,7 +585,7 @@ def saveAsTable(self, name, format=None, mode=None, partitionBy=None, **options) self._jwrite.saveAsTable(name) @since(1.4) - def json(self, path, mode=None, compression=None): + def json(self, path, mode=None, compression=None, dateFormat=None, timestampFormat=None): """Saves the content of the :class:`DataFrame` in JSON format at the specified path. :param path: the path in any Hadoop supported file system @@ -598,12 +598,20 @@ def json(self, path, mode=None, compression=None): :param compression: compression codec to use when saving to file. This can be one of the known case-insensitive shorten names (none, bzip2, gzip, lz4, snappy and deflate). + :param dateFormat: sets the string that indicates a date format. Custom date formats + follow the formats at ``java.text.SimpleDateFormat``. This + applies to date type. If None is set, it uses the + default value, ``yyyy-MM-dd``. + :param timestampFormat: sets the string that indicates a timestamp format. Custom date + formats follow the formats at ``java.text.SimpleDateFormat``. + This applies to timestamp type. If None is set, it uses the + default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``. >>> df.write.json(os.path.join(tempfile.mkdtemp(), 'data')) """ self.mode(mode) - if compression is not None: - self.option("compression", compression) + self._set_opts( + compression=compression, dateFormat=dateFormat, timestampFormat=timestampFormat) self._jwrite.json(path) @since(1.4) @@ -620,15 +628,16 @@ def parquet(self, path, mode=None, partitionBy=None, compression=None): :param partitionBy: names of partitioning columns :param compression: compression codec to use when saving to file. This can be one of the known case-insensitive shorten names (none, snappy, gzip, and lzo). - This will overwrite ``spark.sql.parquet.compression.codec``. + This will override ``spark.sql.parquet.compression.codec``. If None + is set, it uses the value specified in + ``spark.sql.parquet.compression.codec``. >>> df.write.parquet(os.path.join(tempfile.mkdtemp(), 'data')) """ self.mode(mode) if partitionBy is not None: self.partitionBy(partitionBy) - if compression is not None: - self.option("compression", compression) + self._set_opts(compression=compression) self._jwrite.parquet(path) @since(1.6) @@ -643,13 +652,14 @@ def text(self, path, compression=None): The DataFrame must have only one column that is of string type. Each row becomes a new line in the output file. """ - if compression is not None: - self.option("compression", compression) + self._set_opts(compression=compression) self._jwrite.text(path) @since(2.0) - def csv(self, path, mode=None, compression=None): - """Saves the content of the [[DataFrame]] in CSV format at the specified path.
+ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=None, + header=None, nullValue=None, escapeQuotes=None, quoteAll=None, dateFormat=None, + timestampFormat=None): + """Saves the content of the :class:`DataFrame` in CSV format at the specified path. :param path: the path in any Hadoop supported file system :param mode: specifies the behavior of the save operation when data already exists. @@ -662,20 +672,46 @@ def csv(self, path, mode=None, compression=None): :param compression: compression codec to use when saving to file. This can be one of the known case-insensitive shorten names (none, bzip2, gzip, lz4, snappy and deflate). + :param sep: sets the single character as a separator for each field and value. If None is + set, it uses the default value, ``,``. + :param quote: sets the single character used for escaping quoted values where the + separator can be part of the value. If None is set, it uses the default + value, ``"``. If you would like to turn off quotations, you need to set an + empty string. + :param escape: sets the single character used for escaping quotes inside an already + quoted value. If None is set, it uses the default value, ``\``. + :param escapeQuotes: A flag indicating whether values containing quotes should always + be enclosed in quotes. If None is set, it uses the default value + ``true``, escaping all values containing a quote character. + :param quoteAll: A flag indicating whether all values should always be enclosed in + quotes. If None is set, it uses the default value ``false``, + only escaping values containing a quote character. + :param header: writes the names of columns as the first line. If None is set, it uses + the default value, ``false``. + :param nullValue: sets the string representation of a null value. If None is set, it uses + the default value, empty string. + :param dateFormat: sets the string that indicates a date format. Custom date formats + follow the formats at ``java.text.SimpleDateFormat``. This + applies to date type. If None is set, it uses the + default value, ``yyyy-MM-dd``. + :param timestampFormat: sets the string that indicates a timestamp format. Custom date + formats follow the formats at ``java.text.SimpleDateFormat``. + This applies to timestamp type. If None is set, it uses the + default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``. >>> df.write.csv(os.path.join(tempfile.mkdtemp(), 'data')) """ self.mode(mode) - if compression is not None: - self.option("compression", compression) + self._set_opts(compression=compression, sep=sep, quote=quote, escape=escape, header=header, + nullValue=nullValue, escapeQuotes=escapeQuotes, quoteAll=quoteAll, + dateFormat=dateFormat, timestampFormat=timestampFormat) + self._jwrite.csv(path) @since(1.5) def orc(self, path, mode=None, partitionBy=None, compression=None): """Saves the content of the :class:`DataFrame` in ORC format at the specified path. - ::Note: Currently ORC support is only available together with - :class:`HiveContext`. + .. note:: Currently ORC support is only available together with Hive support. :param path: the path in any Hadoop supported file system :param mode: specifies the behavior of the save operation when data already exists. @@ -687,23 +723,23 @@ def orc(self, path, mode=None, partitionBy=None, compression=None): :param partitionBy: names of partitioning columns :param compression: compression codec to use when saving to file. This can be one of the known case-insensitive shorten names (none, snappy, zlib, and lzo).
- This will overwrite ``orc.compress``. + This will override ``orc.compress``. If None is set, it uses the + default value, ``snappy``. - >>> orc_df = hiveContext.read.orc('python/test_support/sql/orc_partitioned') + >>> orc_df = spark.read.orc('python/test_support/sql/orc_partitioned') >>> orc_df.write.orc(os.path.join(tempfile.mkdtemp(), 'data')) """ self.mode(mode) if partitionBy is not None: self.partitionBy(partitionBy) - if compression is not None: - self.option("compression", compression) + self._set_opts(compression=compression) self._jwrite.orc(path) @since(1.4) def jdbc(self, url, table, mode=None, properties=None): - """Saves the content of the :class:`DataFrame` to a external database table via JDBC. + """Saves the content of the :class:`DataFrame` to an external database table via JDBC. - .. note:: Don't create too many partitions in parallel on a large cluster;\ + .. note:: Don't create too many partitions in parallel on a large cluster; \ otherwise Spark might crash your external database systems. :param url: a JDBC URL of the form ``jdbc:subprotocol:subname`` @@ -714,13 +750,13 @@ def jdbc(self, url, table, mode=None, properties=None): * ``overwrite``: Overwrite existing data. * ``ignore``: Silently ignore this operation if data already exists. * ``error`` (default case): Throw an exception if data already exists. - :param properties: JDBC database connection arguments, a list of - arbitrary string tag/value. Normally at least a - "user" and "password" property should be included. + :param properties: a dictionary of JDBC database connection arguments. Normally at + least properties "user" and "password" with their corresponding values. + For example { 'user' : 'SYSTEM', 'password' : 'mypassword' } """ if properties is None: properties = dict() - jprop = JavaClass("java.util.Properties", self._sqlContext._sc._gateway._gateway_client)() + jprop = JavaClass("java.util.Properties", self._spark._sc._gateway._gateway_client)() for k in properties: jprop.setProperty(k, properties[k]) self._jwrite.mode(mode).jdbc(url, table, jprop) @@ -730,28 +766,29 @@ def _test(): import doctest import os import tempfile + import py4j from pyspark.context import SparkContext - from pyspark.sql import Row, SQLContext, HiveContext + from pyspark.sql import SparkSession, Row import pyspark.sql.readwriter os.chdir(os.environ["SPARK_HOME"]) globs = pyspark.sql.readwriter.__dict__.copy() sc = SparkContext('local[4]', 'PythonTest') + try: + spark = SparkSession.builder.enableHiveSupport().getOrCreate() + except py4j.protocol.Py4JError: + spark = SparkSession(sc) globs['tempfile'] = tempfile globs['os'] = os globs['sc'] = sc - globs['sqlContext'] = SQLContext(sc) - globs['hiveContext'] = HiveContext._createForTesting(sc) - globs['df'] = globs['sqlContext'].read.parquet('python/test_support/sql/parquet_partitioned') - globs['sdf'] =\ - globs['sqlContext'].read.format('text').stream('python/test_support/sql/streaming') - + globs['spark'] = spark + globs['df'] = spark.read.parquet('python/test_support/sql/parquet_partitioned') (failure_count, test_count) = doctest.testmod( pyspark.sql.readwriter, globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) - globs['sc'].stop() + sc.stop() if failure_count: exit(-1) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 35c36b4935090..d25823dfcacd3 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -19,6 +19,7 @@ import sys import warnings from functools import 
reduce +from threading import RLock if sys.version >= '3': basestring = unicode = str @@ -31,6 +32,7 @@ from pyspark.sql.conf import RuntimeConfig from pyspark.sql.dataframe import DataFrame from pyspark.sql.readwriter import DataFrameReader +from pyspark.sql.streaming import DataStreamReader from pyspark.sql.types import Row, DataType, StringType, StructType, _verify_type, \ _infer_schema, _has_nulltype, _merge_type, _create_converter, _parse_datatype_string from pyspark.sql.utils import install_exception_handler @@ -45,7 +47,7 @@ def toDF(self, schema=None, sampleRatio=None): This is a shorthand for ``spark.createDataFrame(rdd, schema, sampleRatio)`` - :param schema: a StructType or list of names of columns + :param schema: a :class:`pyspark.sql.types.StructType` or list of names of columns :param samplingRatio: the sample ratio of rows used for inferring :return: a DataFrame @@ -58,16 +60,129 @@ def toDF(self, schema=None, sampleRatio=None): class SparkSession(object): - """Main entry point for Spark SQL functionality. + """The entry point to programming Spark with the Dataset and DataFrame API. A SparkSession can be used create :class:`DataFrame`, register :class:`DataFrame` as tables, execute SQL over tables, cache tables, and read parquet files. + To create a SparkSession, use the following builder pattern: - :param sparkContext: The :class:`SparkContext` backing this SparkSession. - :param jsparkSession: An optional JVM Scala SparkSession. If set, we do not instantiate a new - SparkSession in the JVM, instead we make all calls to this object. + >>> spark = SparkSession.builder \\ + ... .master("local") \\ + ... .appName("Word Count") \\ + ... .config("spark.some.config.option", "some-value") \\ + ... .getOrCreate() """ + class Builder(object): + """Builder for :class:`SparkSession`. + """ + + _lock = RLock() + _options = {} + + @since(2.0) + def config(self, key=None, value=None, conf=None): + """Sets a config option. Options set using this method are automatically propagated to + both :class:`SparkConf` and :class:`SparkSession`'s own configuration. + + For an existing SparkConf, use `conf` parameter. + + >>> from pyspark.conf import SparkConf + >>> SparkSession.builder.config(conf=SparkConf()) + >> SparkSession.builder.config("spark.some.config.option", "some-value") + >> s1 = SparkSession.builder.config("k1", "v1").getOrCreate() + >>> s1.conf.get("k1") == s1.sparkContext.getConf().get("k1") == "v1" + True + + In case an existing SparkSession is returned, the config options specified + in this builder will be applied to the existing SparkSession. + + >>> s2 = SparkSession.builder.config("k2", "v2").getOrCreate() + >>> s1.conf.get("k1") == s2.conf.get("k1") + True + >>> s1.conf.get("k2") == s2.conf.get("k2") + True + """ + with self._lock: + from pyspark.context import SparkContext + from pyspark.conf import SparkConf + session = SparkSession._instantiatedContext + if session is None: + sparkConf = SparkConf() + for key, value in self._options.items(): + sparkConf.set(key, value) + sc = SparkContext.getOrCreate(sparkConf) + # This SparkContext may be an existing one. + for key, value in self._options.items(): + # we need to propagate the confs + # before we create the SparkSession. Otherwise, confs like + # warehouse path and metastore url will not be set correctly ( + # these confs cannot be changed once the SparkSession is created). 
+ sc._conf.set(key, value) + session = SparkSession(sc) + for key, value in self._options.items(): + session.conf.set(key, value) + for key, value in self._options.items(): + session.sparkContext._conf.set(key, value) + return session + + builder = Builder() + _instantiatedContext = None @ignore_unicode_prefix @@ -80,7 +195,7 @@ def __init__(self, sparkContext, jsparkSession=None): ... b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1), ... time=datetime(2014, 8, 1, 14, 1, 5))]) >>> df = allTypes.toDF() - >>> df.registerTempTable("allTypes") + >>> df.createOrReplaceTempView("allTypes") >>> spark.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a ' ... 'from allTypes where b and i > 0').collect() [Row((i + CAST(1 AS BIGINT))=2, (d + CAST(1 AS DOUBLE))=2.0, (NOT b)=False, list[1]=2, \ @@ -95,32 +210,34 @@ def __init__(self, sparkContext, jsparkSession=None): if jsparkSession is None: jsparkSession = self._jvm.SparkSession(self._jsc.sc()) self._jsparkSession = jsparkSession - self._jwrapped = self._jsparkSession.wrapped() + self._jwrapped = self._jsparkSession.sqlContext() self._wrapped = SQLContext(self._sc, self, self._jwrapped) _monkey_patch_RDD(self) install_exception_handler() if SparkSession._instantiatedContext is None: SparkSession._instantiatedContext = self - @classmethod - @since(2.0) - def withHiveSupport(cls, sparkContext): - """Returns a new SparkSession with a catalog backed by Hive. - - :param sparkContext: The underlying :class:`SparkContext`. - """ - jsparkSession = sparkContext._jvm.SparkSession.withHiveSupport(sparkContext._jsc.sc()) - return cls(sparkContext, jsparkSession) - @since(2.0) def newSession(self): """ Returns a new SparkSession as new session, that has separate SQLConf, - registered temporary tables and UDFs, but shared SparkContext and + registered temporary views and UDFs, but shared SparkContext and table cache. """ return self.__class__(self._sc, self._jsparkSession.newSession()) + @property + @since(2.0) + def sparkContext(self): + """Returns the underlying :class:`SparkContext`.""" + return self._sc + + @property + @since(2.0) + def version(self): + """The version of Spark on which this application is running.""" + return self._jsparkSession.version() + @property @since(2.0) def conf(self): @@ -157,9 +274,9 @@ def udf(self): @since(2.0) def range(self, start, end=None, step=1, numPartitions=None): """ - Create a :class:`DataFrame` with single LongType column named `id`, - containing elements in a range from `start` to `end` (exclusive) with - step value `step`. + Create a :class:`DataFrame` with single :class:`pyspark.sql.types.LongType` column named + ``id``, containing elements in a range from ``start`` to ``end`` (exclusive) with + step value ``step``. :param start: the start value :param end: the end value (exclusive) @@ -190,7 +307,7 @@ def _inferSchemaFromList(self, data): Infer schema from list of Row or tuple. 
:param data: list of Row or tuple - :return: StructType + :return: :class:`pyspark.sql.types.StructType` """ if not data: raise ValueError("can not infer schema from empty dataset") @@ -209,7 +326,7 @@ def _inferSchema(self, rdd, samplingRatio=None): :param rdd: an RDD of Row or tuple :param samplingRatio: sampling ratio, or no sampling (default) - :return: StructType + :return: :class:`pyspark.sql.types.StructType` """ first = rdd.first() if not first: @@ -258,7 +375,7 @@ def _createFromRDD(self, rdd, schema, samplingRatio): def _createFromLocal(self, data, schema): """ - Create an RDD for DataFrame from an list or pandas.DataFrame, returns + Create an RDD for DataFrame from a list or pandas.DataFrame, returns the RDD and schema. """ # make sure data could consumed multiple times @@ -267,17 +384,15 @@ def _createFromLocal(self, data, schema): if schema is None or isinstance(schema, (list, tuple)): struct = self._inferSchemaFromList(data) + converter = _create_converter(struct) + data = map(converter, data) if isinstance(schema, (list, tuple)): for i, name in enumerate(schema): struct.fields[i].name = name struct.names[i] = name schema = struct - elif isinstance(schema, StructType): - for row in data: - _verify_type(row, schema) - - else: + elif not isinstance(schema, StructType): raise TypeError("schema should be StructType or list or None, but got: %s" % schema) # convert python objects to sql data @@ -286,7 +401,7 @@ def _createFromLocal(self, data, schema): @since(2.0) @ignore_unicode_prefix - def createDataFrame(self, data, schema=None, samplingRatio=None): + def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True): """ Creates a :class:`DataFrame` from an :class:`RDD`, a list or a :class:`pandas.DataFrame`. @@ -297,28 +412,31 @@ def createDataFrame(self, data, schema=None, samplingRatio=None): from ``data``, which should be an RDD of :class:`Row`, or :class:`namedtuple`, or :class:`dict`. - When ``schema`` is :class:`DataType` or datatype string, it must match the real data, or - exception will be thrown at runtime. If the given schema is not StructType, it will be - wrapped into a StructType as its only field, and the field name will be "value", each record - will also be wrapped into a tuple, which can be converted to row later. + When ``schema`` is :class:`pyspark.sql.types.DataType` or + :class:`pyspark.sql.types.StringType`, it must match the + real data, or an exception will be thrown at runtime. If the given schema is not + :class:`pyspark.sql.types.StructType`, it will be wrapped into a + :class:`pyspark.sql.types.StructType` as its only field, and the field name will be "value", + each record will also be wrapped into a tuple, which can be converted to row later. If schema inference is needed, ``samplingRatio`` is used to determined the ratio of rows used for schema inference. The first row will be used if ``samplingRatio`` is ``None``. :param data: an RDD of any kind of SQL data representation(e.g. row, tuple, int, boolean, etc.), or :class:`list`, or :class:`pandas.DataFrame`. - :param schema: a :class:`DataType` or a datatype string or a list of column names, default - is None. The data type string format equals to `DataType.simpleString`, except that - top level struct type can omit the `struct<>` and atomic types use `typeName()` as - their format, e.g. use `byte` instead of `tinyint` for ByteType. We can also use `int` - as a short name for IntegerType. 
+ :param schema: a :class:`pyspark.sql.types.DataType` or a + :class:`pyspark.sql.types.StringType` or a list of + column names, default is ``None``. The data type string format equals to + :class:`pyspark.sql.types.DataType.simpleString`, except that top level struct type can + omit the ``struct<>`` and atomic types use ``typeName()`` as their format, e.g. use + ``byte`` instead of ``tinyint`` for :class:`pyspark.sql.types.ByteType`. We can also use + ``int`` as a short name for ``IntegerType``. :param samplingRatio: the sample ratio of rows used for inferring + :param verifySchema: verify data types of every row against schema. :return: :class:`DataFrame` - .. versionchanged:: 2.0 - The schema parameter can be a DataType or a datatype string after 2.0. If it's not a - StructType, it will be wrapped into a StructType and each record will also be wrapped - into a tuple. + .. versionchanged:: 2.0.1 + Added verifySchema. >>> l = [('Alice', 1)] >>> spark.createDataFrame(l).collect() @@ -383,18 +501,21 @@ def createDataFrame(self, data, schema=None, samplingRatio=None): schema = [str(x) for x in data.columns] data = [r.tolist() for r in data.to_records(index=False)] + verify_func = _verify_type if verifySchema else lambda _, t: True if isinstance(schema, StructType): def prepare(obj): - _verify_type(obj, schema) + verify_func(obj, schema) return obj elif isinstance(schema, DataType): - datatype = schema + dataType = schema + schema = StructType().add("value", schema) def prepare(obj): - _verify_type(obj, datatype) - return (obj, ) - schema = StructType().add("value", datatype) + verify_func(obj, dataType) + return obj, else: + if isinstance(schema, list): + schema = [x.encode('utf-8') if not isinstance(x, str) else x for x in schema] prepare = lambda obj: obj if isinstance(data, RDD): @@ -414,7 +535,7 @@ def sql(self, sqlQuery): :return: :class:`DataFrame` - >>> spark.catalog.registerTable(df, "table1") + >>> df.createOrReplaceTempView("table1") >>> df2 = spark.sql("SELECT field1 AS f1, field2 as f2 from table1") >>> df2.collect() [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')] @@ -427,7 +548,7 @@ def table(self, tableName): :return: :class:`DataFrame` - >>> spark.catalog.registerTable(df, "table1") + >>> df.createOrReplaceTempView("table1") >>> df2 = spark.table("table1") >>> sorted(df.collect()) == sorted(df2.collect()) True @@ -445,6 +566,55 @@ def read(self): """ return DataFrameReader(self._wrapped) + @property + @since(2.0) + def readStream(self): + """ + Returns a :class:`DataStreamReader` that can be used to read data streams + as a streaming :class:`DataFrame`. + + .. note:: Experimental. + + :return: :class:`DataStreamReader` + """ + return DataStreamReader(self._wrapped) + + @property + @since(2.0) + def streams(self): + """Returns a :class:`StreamingQueryManager` that allows managing all the + :class:`StreamingQuery` StreamingQueries active on `this` context. + + .. note:: Experimental. + + :return: :class:`StreamingQueryManager` + """ + from pyspark.sql.streaming import StreamingQueryManager + return StreamingQueryManager(self._jsparkSession.streams()) + + @since(2.0) + def stop(self): + """Stop the underlying :class:`SparkContext`. + """ + self._sc.stop() + SparkSession._instantiatedContext = None + + @since(2.0) + def __enter__(self): + """ + Enable 'with SparkSession.builder.(...).getOrCreate() as session: app' syntax. 
+ """ + return self + + @since(2.0) + def __exit__(self, exc_type, exc_val, exc_tb): + """ + Enable 'with SparkSession.builder.(...).getOrCreate() as session: app' syntax. + + Specifically stop the SparkSession on exit of the with block. + """ + self.stop() + def _test(): import os diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index bf03fdca91394..cfe917bd12fc9 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -15,15 +15,24 @@ # limitations under the License. # +import sys +if sys.version >= '3': + intlike = int + basestring = unicode = str +else: + intlike = (int, long) + from abc import ABCMeta, abstractmethod -from pyspark import since +from pyspark import since, keyword_only from pyspark.rdd import ignore_unicode_prefix +from pyspark.sql.readwriter import OptionUtils, to_str +from pyspark.sql.types import * -__all__ = ["ContinuousQuery"] +__all__ = ["StreamingQuery", "StreamingQueryManager", "DataStreamReader", "DataStreamWriter"] -class ContinuousQuery(object): +class StreamingQuery(object): """ A handle to a query that is executing continuously in the background as new data arrives. All these methods are thread-safe. @@ -33,22 +42,30 @@ class ContinuousQuery(object): .. versionadded:: 2.0 """ - def __init__(self, jcq): - self._jcq = jcq + def __init__(self, jsq): + self._jsq = jsq + + @property + @since(2.0) + def id(self): + """The id of the streaming query. This id is unique across all queries that have been + started in the current process. + """ + return self._jsq.id() @property @since(2.0) def name(self): - """The name of the continuous query. + """The name of the streaming query. This name is unique across all active queries. """ - return self._jcq.name() + return self._jsq.name() @property @since(2.0) def isActive(self): - """Whether this continuous query is currently active or not. + """Whether this streaming query is currently active or not. """ - return self._jcq.isActive() + return self._jsq.isActive() @since(2.0) def awaitTermination(self, timeout=None): @@ -61,43 +78,42 @@ def awaitTermination(self, timeout=None): immediately (if the query was terminated by :func:`stop()`), or throw the exception immediately (if the query has terminated with exception). - throws :class:`ContinuousQueryException`, if `this` query has terminated with an exception + throws :class:`StreamingQueryException`, if `this` query has terminated with an exception """ if timeout is not None: if not isinstance(timeout, (int, float)) or timeout < 0: raise ValueError("timeout must be a positive integer or float. Got %s" % timeout) - return self._jcq.awaitTermination(int(timeout * 1000)) + return self._jsq.awaitTermination(int(timeout * 1000)) else: - return self._jcq.awaitTermination() + return self._jsq.awaitTermination() @since(2.0) def processAllAvailable(self): - """Blocks until all available data in the source has been processed an committed to the + """Blocks until all available data in the source has been processed and committed to the sink. This method is intended for testing. Note that in the case of continually arriving data, this method may block forever. Additionally, this method is only guaranteed to block until data that has been synchronously appended data to a stream source prior to invocation. (i.e. `getOffset` must immediately reflect the addition). """ - return self._jcq.processAllAvailable() + return self._jsq.processAllAvailable() @since(2.0) def stop(self): - """Stop this continuous query. 
+ """Stop this streaming query. """ - self._jcq.stop() + self._jsq.stop() -class ContinuousQueryManager(object): - """A class to manage all the :class:`ContinuousQuery` ContinuousQueries active - on a :class:`SQLContext`. +class StreamingQueryManager(object): + """A class to manage all the :class:`StreamingQuery` StreamingQueries active. .. note:: Experimental .. versionadded:: 2.0 """ - def __init__(self, jcqm): - self._jcqm = jcqm + def __init__(self, jsqm): + self._jsqm = jsqm @property @ignore_unicode_prefix @@ -105,29 +121,35 @@ def __init__(self, jcqm): def active(self): """Returns a list of active queries associated with this SQLContext - >>> cq = df.write.format('memory').queryName('this_query').startStream() - >>> cqm = sqlContext.streams - >>> # get the list of active continuous queries - >>> [q.name for q in cqm.active] + >>> sq = sdf.writeStream.format('memory').queryName('this_query').start() + >>> sqm = spark.streams + >>> # get the list of active streaming queries + >>> [q.name for q in sqm.active] [u'this_query'] - >>> cq.stop() + >>> sq.stop() """ - return [ContinuousQuery(jcq) for jcq in self._jcqm.active()] + return [StreamingQuery(jsq) for jsq in self._jsqm.active()] + @ignore_unicode_prefix @since(2.0) - def get(self, name): + def get(self, id): """Returns an active query from this SQLContext or throws exception if an active query with this name doesn't exist. - >>> df.write.format('memory').queryName('this_query').startStream() - >>> cq = sqlContext.streams.get('this_query') - >>> cq.isActive + >>> sq = sdf.writeStream.format('memory').queryName('this_query').start() + >>> sq.name + u'this_query' + >>> sq = spark.streams.get(sq.id) + >>> sq.isActive + True + >>> sq = sqlContext.streams.get(sq.id) + >>> sq.isActive True - >>> cq.stop() + >>> sq.stop() """ - if type(name) != str or len(name.strip()) == 0: - raise ValueError("The name for the query must be a non-empty string. Got: %s" % name) - return ContinuousQuery(self._jcqm.get(name)) + if not isinstance(id, intlike): + raise ValueError("The id for the query must be an integer. Got: %s" % id) + return StreamingQuery(self._jsqm.get(id)) @since(2.0) def awaitAnyTermination(self, timeout=None): @@ -148,27 +170,324 @@ def awaitAnyTermination(self, timeout=None): queries, users need to stop all of them after any of them terminates with exception, and then check the `query.exception()` for each query. - throws :class:`ContinuousQueryException`, if `this` query has terminated with an exception + throws :class:`StreamingQueryException`, if `this` query has terminated with an exception """ if timeout is not None: if not isinstance(timeout, (int, float)) or timeout < 0: raise ValueError("timeout must be a positive integer or float. Got %s" % timeout) - return self._jcqm.awaitAnyTermination(int(timeout * 1000)) + return self._jsqm.awaitAnyTermination(int(timeout * 1000)) else: - return self._jcqm.awaitAnyTermination() + return self._jsqm.awaitAnyTermination() @since(2.0) def resetTerminated(self): """Forget about past terminated queries so that :func:`awaitAnyTermination()` can be used again to wait for new terminations. - >>> sqlContext.streams.resetTerminated() + >>> spark.streams.resetTerminated() + """ + self._jsqm.resetTerminated() + + +class StreamingQueryStatus(object): + """A class used to report information about the progress of a StreamingQuery. + + .. note:: Experimental + + .. versionadded:: 2.1 + """ + + def __init__(self, jsqs): + self._jsqs = jsqs + + def __str__(self): + """ + Pretty string of this query status. 
+ + >>> print(sqs) + Status of query 'query' + Query id: 1 + Status timestamp: 123 + Input rate: 15.5 rows/sec + Processing rate 23.5 rows/sec + Latency: 345.0 ms + Trigger details: + isDataPresentInTrigger: true + isTriggerActive: true + latency.getBatch.total: 20 + latency.getOffset.total: 10 + numRows.input.total: 100 + triggerId: 5 + Source statuses [1 source]: + Source 1 - MySource1 + Available offset: #0 + Input rate: 15.5 rows/sec + Processing rate: 23.5 rows/sec + Trigger details: + numRows.input.source: 100 + latency.getOffset.source: 10 + latency.getBatch.source: 20 + Sink status - MySink + Committed offsets: [#1, -] """ - self._jcqm.resetTerminated() + return self._jsqs.toString() + + @property + @ignore_unicode_prefix + @since(2.1) + def name(self): + """ + Name of the query. This name is unique across all active queries. + + >>> sqs.name + u'query' + """ + return self._jsqs.name() + + @property + @since(2.1) + def id(self): + """ + Id of the query. This id is unique across all queries that have been started in + the current process. + + >>> int(sqs.id) + 1 + """ + return self._jsqs.id() + + @property + @since(2.1) + def timestamp(self): + """ + Timestamp (ms) of when this query was generated. + + >>> int(sqs.timestamp) + 123 + """ + return self._jsqs.timestamp() + + @property + @since(2.1) + def inputRate(self): + """ + Current total rate (rows/sec) at which data is being generated by all the sources. + + >>> sqs.inputRate + 15.5 + """ + return self._jsqs.inputRate() + + @property + @since(2.1) + def processingRate(self): + """ + Current rate (rows/sec) at which the query is processing data from all the sources. + + >>> sqs.processingRate + 23.5 + """ + return self._jsqs.processingRate() + + @property + @since(2.1) + def latency(self): + """ + Current average latency between the data being available in source and the sink + writing the corresponding output. + + >>> sqs.latency + 345.0 + """ + if (self._jsqs.latency().nonEmpty()): + return self._jsqs.latency().get() + else: + return None + + @property + @ignore_unicode_prefix + @since(2.1) + def sourceStatuses(self): + """ + Current statuses of the sources as a list. + + >>> len(sqs.sourceStatuses) + 1 + >>> sqs.sourceStatuses[0].description + u'MySource1' + """ + return [SourceStatus(ss) for ss in self._jsqs.sourceStatuses()] + + @property + @ignore_unicode_prefix + @since(2.1) + def sinkStatus(self): + """ + Current status of the sink. + + >>> sqs.sinkStatus.description + u'MySink' + """ + return SinkStatus(self._jsqs.sinkStatus()) + + @property + @ignore_unicode_prefix + @since(2.1) + def triggerDetails(self): + """ + Low-level details of the currently active trigger (e.g. number of rows processed + in trigger, latency of intermediate steps, etc.). + + If no trigger is currently active, then it will have details of the last completed trigger. + + >>> sqs.triggerDetails + {u'triggerId': u'5', u'latency.getBatch.total': u'20', u'numRows.input.total': u'100', + u'isTriggerActive': u'true', u'latency.getOffset.total': u'10', + u'isDataPresentInTrigger': u'true'} + """ + return self._jsqs.triggerDetails() + + +class SourceStatus(object): + """ + Status and metrics of a streaming Source. + + .. note:: Experimental + + .. versionadded:: 2.1 + """ + + def __init__(self, jss): + self._jss = jss + + def __str__(self): + """ + Pretty string of this source status. 
+ + >>> print(sqs.sourceStatuses[0]) + Status of source MySource1 + Available offset: #0 + Input rate: 15.5 rows/sec + Processing rate: 23.5 rows/sec + Trigger details: + numRows.input.source: 100 + latency.getOffset.source: 10 + latency.getBatch.source: 20 + """ + return self._jss.toString() + + @property + @ignore_unicode_prefix + @since(2.1) + def description(self): + """ + Description of the source corresponding to this status. + + >>> sqs.sourceStatuses[0].description + u'MySource1' + """ + return self._jss.description() + + @property + @ignore_unicode_prefix + @since(2.1) + def offsetDesc(self): + """ + Description of the current offset if known. + + >>> sqs.sourceStatuses[0].offsetDesc + u'#0' + """ + return self._jss.offsetDesc() + + @property + @since(2.1) + def inputRate(self): + """ + Current rate (rows/sec) at which data is being generated by the source. + + >>> sqs.sourceStatuses[0].inputRate + 15.5 + """ + return self._jss.inputRate() + + @property + @since(2.1) + def processingRate(self): + """ + Current rate (rows/sec) at which the query is processing data from the source. + + >>> sqs.sourceStatuses[0].processingRate + 23.5 + """ + return self._jss.processingRate() + + @property + @ignore_unicode_prefix + @since(2.1) + def triggerDetails(self): + """ + Low-level details of the currently active trigger (e.g. number of rows processed + in trigger, latency of intermediate steps, etc.). + + If no trigger is currently active, then it will have details of the last completed trigger. + + >>> sqs.sourceStatuses[0].triggerDetails + {u'numRows.input.source': u'100', u'latency.getOffset.source': u'10', + u'latency.getBatch.source': u'20'} + """ + return self._jss.triggerDetails() + + +class SinkStatus(object): + """ + Status and metrics of a streaming Sink. + + .. note:: Experimental + + .. versionadded:: 2.1 + """ + + def __init__(self, jss): + self._jss = jss + + def __str__(self): + """ + Pretty string of this source status. + + >>> print(sqs.sinkStatus) + Status of sink MySink + Committed offsets: [#1, -] + """ + return self._jss.toString() + + @property + @ignore_unicode_prefix + @since(2.1) + def description(self): + """ + Description of the source corresponding to this status. + + >>> sqs.sinkStatus.description + u'MySink' + """ + return self._jss.description() + + @property + @ignore_unicode_prefix + @since(2.1) + def offsetDesc(self): + """ + Description of the current offsets up to which data has been written by the sink. + + >>> sqs.sinkStatus.offsetDesc + u'[#1, -]' + """ + return self._jss.offsetDesc() class Trigger(object): - """Used to indicate how often results should be produced by a :class:`ContinuousQuery`. + """Used to indicate how often results should be produced by a :class:`StreamingQuery`. .. note:: Experimental @@ -201,34 +520,542 @@ def __init__(self, interval): self.interval = interval def _to_java_trigger(self, sqlContext): - return sqlContext._sc._jvm.org.apache.spark.sql.ProcessingTime.create(self.interval) + return sqlContext._sc._jvm.org.apache.spark.sql.streaming.ProcessingTime.create( + self.interval) + + +class DataStreamReader(OptionUtils): + """ + Interface used to load a streaming :class:`DataFrame` from external storage systems + (e.g. file systems, key-value stores, etc). Use :func:`spark.readStream` + to access this. + + .. note:: Experimental. + + .. 
versionadded:: 2.0 + """ + + def __init__(self, spark): + self._jreader = spark._ssql_ctx.readStream() + self._spark = spark + + def _df(self, jdf): + from pyspark.sql.dataframe import DataFrame + return DataFrame(jdf, self._spark) + + @since(2.0) + def format(self, source): + """Specifies the input data source format. + + .. note:: Experimental. + + :param source: string, name of the data source, e.g. 'json', 'parquet'. + + >>> s = spark.readStream.format("text") + """ + self._jreader = self._jreader.format(source) + return self + + @since(2.0) + def schema(self, schema): + """Specifies the input schema. + + Some data sources (e.g. JSON) can infer the input schema automatically from data. + By specifying the schema here, the underlying data source can skip the schema + inference step, and thus speed up data loading. + + .. note:: Experimental. + + :param schema: a :class:`pyspark.sql.types.StructType` object + + >>> s = spark.readStream.schema(sdf_schema) + """ + if not isinstance(schema, StructType): + raise TypeError("schema should be StructType") + jschema = self._spark._ssql_ctx.parseDataType(schema.json()) + self._jreader = self._jreader.schema(jschema) + return self + + @since(2.0) + def option(self, key, value): + """Adds an input option for the underlying data source. + + .. note:: Experimental. + + >>> s = spark.readStream.option("x", 1) + """ + self._jreader = self._jreader.option(key, to_str(value)) + return self + + @since(2.0) + def options(self, **options): + """Adds input options for the underlying data source. + + .. note:: Experimental. + + >>> s = spark.readStream.options(x="1", y=2) + """ + for k in options: + self._jreader = self._jreader.option(k, to_str(options[k])) + return self + + @since(2.0) + def load(self, path=None, format=None, schema=None, **options): + """Loads a data stream from a data source and returns it as a :class`DataFrame`. + + .. note:: Experimental. + + :param path: optional string for file-system backed data sources. + :param format: optional string for format of the data source. Default to 'parquet'. + :param schema: optional :class:`pyspark.sql.types.StructType` for the input schema. + :param options: all other string options + + >>> json_sdf = spark.readStream.format("json") \\ + ... .schema(sdf_schema) \\ + ... .load(tempfile.mkdtemp()) + >>> json_sdf.isStreaming + True + >>> json_sdf.schema == sdf_schema + True + """ + if format is not None: + self.format(format) + if schema is not None: + self.schema(schema) + self.options(**options) + if path is not None: + if type(path) != str or len(path.strip()) == 0: + raise ValueError("If the path is provided for stream, it needs to be a " + + "non-empty string. List of paths are not supported.") + return self._df(self._jreader.load(path)) + else: + return self._df(self._jreader.load()) + + @since(2.0) + def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, + allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None, + allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, + mode=None, columnNameOfCorruptRecord=None, dateFormat=None, + timestampFormat=None): + """ + Loads a JSON file stream (one object per line) and returns a :class`DataFrame`. + + If the ``schema`` parameter is not specified, this function goes + through the input once to determine the input schema. + + .. note:: Experimental. + + :param path: string represents path to the JSON dataset, + or RDD of Strings storing JSON objects. 
+ :param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema. + :param primitivesAsString: infers all primitive values as a string type. If None is set, + it uses the default value, ``false``. + :param prefersDecimal: infers all floating-point values as a decimal type. If the values + do not fit in decimal, then it infers them as doubles. If None is + set, it uses the default value, ``false``. + :param allowComments: ignores Java/C++ style comment in JSON records. If None is set, + it uses the default value, ``false``. + :param allowUnquotedFieldNames: allows unquoted JSON field names. If None is set, + it uses the default value, ``false``. + :param allowSingleQuotes: allows single quotes in addition to double quotes. If None is + set, it uses the default value, ``true``. + :param allowNumericLeadingZero: allows leading zeros in numbers (e.g. 00012). If None is + set, it uses the default value, ``false``. + :param allowBackslashEscapingAnyCharacter: allows accepting quoting of all character + using backslash quoting mechanism. If None is + set, it uses the default value, ``false``. + :param mode: allows a mode for dealing with corrupt records during parsing. If None is + set, it uses the default value, ``PERMISSIVE``. + + * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted \ + record and puts the malformed string into a new field configured by \ + ``columnNameOfCorruptRecord``. When a schema is set by user, it sets \ + ``null`` for extra fields. + * ``DROPMALFORMED`` : ignores the whole corrupted records. + * ``FAILFAST`` : throws an exception when it meets corrupted records. + + :param columnNameOfCorruptRecord: allows renaming the new field having malformed string + created by ``PERMISSIVE`` mode. This overrides + ``spark.sql.columnNameOfCorruptRecord``. If None is set, + it uses the value specified in + ``spark.sql.columnNameOfCorruptRecord``. + :param dateFormat: sets the string that indicates a date format. Custom date formats + follow the formats at ``java.text.SimpleDateFormat``. This + applies to date type. If None is set, it uses the + default value value, ``yyyy-MM-dd``. + :param timestampFormat: sets the string that indicates a timestamp format. Custom date + formats follow the formats at ``java.text.SimpleDateFormat``. + This applies to timestamp type. If None is set, it uses the + default value value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``. + + >>> json_sdf = spark.readStream.json(tempfile.mkdtemp(), schema = sdf_schema) + >>> json_sdf.isStreaming + True + >>> json_sdf.schema == sdf_schema + True + """ + self._set_opts( + schema=schema, primitivesAsString=primitivesAsString, prefersDecimal=prefersDecimal, + allowComments=allowComments, allowUnquotedFieldNames=allowUnquotedFieldNames, + allowSingleQuotes=allowSingleQuotes, allowNumericLeadingZero=allowNumericLeadingZero, + allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter, + mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat, + timestampFormat=timestampFormat) + if isinstance(path, basestring): + return self._df(self._jreader.json(path)) + else: + raise TypeError("path can be only a single string") + + @since(2.0) + def parquet(self, path): + """Loads a Parquet file stream, returning the result as a :class:`DataFrame`. + + You can set the following Parquet-specific option(s) for reading Parquet files: + * ``mergeSchema``: sets whether we should merge schemas collected from all \ + Parquet part-files. 
This will override ``spark.sql.parquet.mergeSchema``. \ + The default value is specified in ``spark.sql.parquet.mergeSchema``. + + .. note:: Experimental. + + >>> parquet_sdf = spark.readStream.schema(sdf_schema).parquet(tempfile.mkdtemp()) + >>> parquet_sdf.isStreaming + True + >>> parquet_sdf.schema == sdf_schema + True + """ + if isinstance(path, basestring): + return self._df(self._jreader.parquet(path)) + else: + raise TypeError("path can be only a single string") + + @ignore_unicode_prefix + @since(2.0) + def text(self, path): + """ + Loads a text file stream and returns a :class:`DataFrame` whose schema starts with a + string column named "value", and followed by partitioned columns if there + are any. + + Each line in the text file is a new row in the resulting DataFrame. + + .. note:: Experimental. + + :param paths: string, or list of strings, for input path(s). + + >>> text_sdf = spark.readStream.text(tempfile.mkdtemp()) + >>> text_sdf.isStreaming + True + >>> "value" in str(text_sdf.schema) + True + """ + if isinstance(path, basestring): + return self._df(self._jreader.text(path)) + else: + raise TypeError("path can be only a single string") + + @since(2.0) + def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=None, + comment=None, header=None, inferSchema=None, ignoreLeadingWhiteSpace=None, + ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None, + negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None, + maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None): + """Loads a CSV file stream and returns the result as a :class:`DataFrame`. + + This function will go through the input once to determine the input schema if + ``inferSchema`` is enabled. To avoid going through the entire data once, disable + ``inferSchema`` option or specify the schema explicitly using ``schema``. + + .. note:: Experimental. + + :param path: string, or list of strings, for input path(s). + :param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema. + :param sep: sets the single character as a separator for each field and value. + If None is set, it uses the default value, ``,``. + :param encoding: decodes the CSV files by the given encoding type. If None is set, + it uses the default value, ``UTF-8``. + :param quote: sets the single character used for escaping quoted values where the + separator can be part of the value. If None is set, it uses the default + value, ``"``. If you would like to turn off quotations, you need to set an + empty string. + :param escape: sets the single character used for escaping quotes inside an already + quoted value. If None is set, it uses the default value, ``\``. + :param comment: sets the single character used for skipping lines beginning with this + character. By default (None), it is disabled. + :param header: uses the first line as names of columns. If None is set, it uses the + default value, ``false``. + :param inferSchema: infers the input schema automatically from data. It requires one extra + pass over the data. If None is set, it uses the default value, ``false``. + :param ignoreLeadingWhiteSpace: defines whether or not leading whitespaces from values + being read should be skipped. If None is set, it uses + the default value, ``false``. + :param ignoreTrailingWhiteSpace: defines whether or not trailing whitespaces from values + being read should be skipped. If None is set, it uses + the default value, ``false``. 
+ :param nullValue: sets the string representation of a null value. If None is set, it uses + the default value, empty string. Since 2.0.1, this ``nullValue`` param + applies to all supported types including the string type. + :param nanValue: sets the string representation of a non-number value. If None is set, it + uses the default value, ``NaN``. + :param positiveInf: sets the string representation of a positive infinity value. If None + is set, it uses the default value, ``Inf``. + :param negativeInf: sets the string representation of a negative infinity value. If None + is set, it uses the default value, ``Inf``. + :param dateFormat: sets the string that indicates a date format. Custom date formats + follow the formats at ``java.text.SimpleDateFormat``. This + applies to date type. If None is set, it uses the + default value value, ``yyyy-MM-dd``. + :param timestampFormat: sets the string that indicates a timestamp format. Custom date + formats follow the formats at ``java.text.SimpleDateFormat``. + This applies to timestamp type. If None is set, it uses the + default value value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``. + :param maxColumns: defines a hard limit of how many columns a record can have. If None is + set, it uses the default value, ``20480``. + :param maxCharsPerColumn: defines the maximum number of characters allowed for any given + value being read. If None is set, it uses the default value, + ``1000000``. + :param mode: allows a mode for dealing with corrupt records during parsing. If None is + set, it uses the default value, ``PERMISSIVE``. + + * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted record. + When a schema is set by user, it sets ``null`` for extra fields. + * ``DROPMALFORMED`` : ignores the whole corrupted records. + * ``FAILFAST`` : throws an exception when it meets corrupted records. + + >>> csv_sdf = spark.readStream.csv(tempfile.mkdtemp(), schema = sdf_schema) + >>> csv_sdf.isStreaming + True + >>> csv_sdf.schema == sdf_schema + True + """ + self._set_opts( + schema=schema, sep=sep, encoding=encoding, quote=quote, escape=escape, comment=comment, + header=header, inferSchema=inferSchema, ignoreLeadingWhiteSpace=ignoreLeadingWhiteSpace, + ignoreTrailingWhiteSpace=ignoreTrailingWhiteSpace, nullValue=nullValue, + nanValue=nanValue, positiveInf=positiveInf, negativeInf=negativeInf, + dateFormat=dateFormat, timestampFormat=timestampFormat, maxColumns=maxColumns, + maxCharsPerColumn=maxCharsPerColumn, + maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode) + if isinstance(path, basestring): + return self._df(self._jreader.csv(path)) + else: + raise TypeError("path can be only a single string") + + +class DataStreamWriter(object): + """ + Interface used to write a streaming :class:`DataFrame` to external storage systems + (e.g. file systems, key-value stores, etc). Use :func:`DataFrame.writeStream` + to access this. + + .. note:: Experimental. + + .. versionadded:: 2.0 + """ + + def __init__(self, df): + self._df = df + self._spark = df.sql_ctx + self._jwrite = df._jdf.writeStream() + + def _sq(self, jsq): + from pyspark.sql.streaming import StreamingQuery + return StreamingQuery(jsq) + + @since(2.0) + def outputMode(self, outputMode): + """Specifies how data of a streaming DataFrame/Dataset is written to a streaming sink. 
+ + Options include: + + * `append`: Only the new rows in the streaming DataFrame/Dataset will be written to + the sink + * `complete`: All the rows in the streaming DataFrame/Dataset will be written to the sink + every time there are some updates + + .. note:: Experimental. + + >>> writer = sdf.writeStream.outputMode('append') + """ + if not outputMode or type(outputMode) != str or len(outputMode.strip()) == 0: + raise ValueError('The output mode must be a non-empty string. Got: %s' % outputMode) + self._jwrite = self._jwrite.outputMode(outputMode) + return self + + @since(2.0) + def format(self, source): + """Specifies the underlying output data source. + + .. note:: Experimental. + + :param source: string, name of the data source, which for now can be 'parquet'. + + >>> writer = sdf.writeStream.format('json') + """ + self._jwrite = self._jwrite.format(source) + return self + + @since(2.0) + def option(self, key, value): + """Adds an output option for the underlying data source. + + .. note:: Experimental. + """ + self._jwrite = self._jwrite.option(key, to_str(value)) + return self + + @since(2.0) + def options(self, **options): + """Adds output options for the underlying data source. + + .. note:: Experimental. + """ + for k in options: + self._jwrite = self._jwrite.option(k, to_str(options[k])) + return self + + @since(2.0) + def partitionBy(self, *cols): + """Partitions the output by the given columns on the file system. + + If specified, the output is laid out on the file system similar + to Hive's partitioning scheme. + + .. note:: Experimental. + + :param cols: name of columns + + """ + if len(cols) == 1 and isinstance(cols[0], (list, tuple)): + cols = cols[0] + self._jwrite = self._jwrite.partitionBy(_to_seq(self._spark._sc, cols)) + return self + + @since(2.0) + def queryName(self, queryName): + """Specifies the name of the :class:`StreamingQuery` that can be started with + :func:`start`. This name must be unique among all the currently active queries + in the associated SparkSession. + + .. note:: Experimental. + + :param queryName: unique name for the query + + >>> writer = sdf.writeStream.queryName('streaming_query') + """ + if not queryName or type(queryName) != str or len(queryName.strip()) == 0: + raise ValueError('The queryName must be a non-empty string. Got: %s' % queryName) + self._jwrite = self._jwrite.queryName(queryName) + return self + + @keyword_only + @since(2.0) + def trigger(self, processingTime=None): + """Set the trigger for the stream query. If this is not set it will run the query as fast + as possible, which is equivalent to setting the trigger to ``processingTime='0 seconds'``. + + .. note:: Experimental. + + :param processingTime: a processing time interval as a string, e.g. '5 seconds', '1 minute'. + + >>> # trigger the query for execution every 5 seconds + >>> writer = sdf.writeStream.trigger(processingTime='5 seconds') + """ + from pyspark.sql.streaming import ProcessingTime + trigger = None + if processingTime is not None: + if type(processingTime) != str or len(processingTime.strip()) == 0: + raise ValueError('The processing time must be a non-empty string. Got: %s' % + processingTime) + trigger = ProcessingTime(processingTime) + if trigger is None: + raise ValueError('A trigger was not provided.
Supported triggers: processingTime.') + self._jwrite = self._jwrite.trigger(trigger._to_java_trigger(self._spark)) + return self + + @ignore_unicode_prefix + @since(2.0) + def start(self, path=None, format=None, partitionBy=None, queryName=None, **options): + """Streams the contents of the :class:`DataFrame` to a data source. + + The data source is specified by the ``format`` and a set of ``options``. + If ``format`` is not specified, the default data source configured by + ``spark.sql.sources.default`` will be used. + + .. note:: Experimental. + + :param path: the path in a Hadoop supported file system + :param format: the format used to save + + * ``append``: Append contents of this :class:`DataFrame` to existing data. + * ``overwrite``: Overwrite existing data. + * ``ignore``: Silently ignore this operation if data already exists. + * ``error`` (default case): Throw an exception if data already exists. + :param partitionBy: names of partitioning columns + :param queryName: unique name for the query + :param options: All other string options. You may want to provide a `checkpointLocation` + for most streams, however it is not required for a `memory` stream. + + >>> sq = sdf.writeStream.format('memory').queryName('this_query').start() + >>> sq.isActive + True + >>> sq.name + u'this_query' + >>> sq.stop() + >>> sq.isActive + False + >>> sq = sdf.writeStream.trigger(processingTime='5 seconds').start( + ... queryName='that_query', format='memory') + >>> sq.name + u'that_query' + >>> sq.isActive + True + >>> sq.stop() + """ + self.options(**options) + if partitionBy is not None: + self.partitionBy(partitionBy) + if format is not None: + self.format(format) + if queryName is not None: + self.queryName(queryName) + if path is None: + return self._sq(self._jwrite.start()) + else: + return self._sq(self._jwrite.start(path)) def _test(): import doctest import os import tempfile - from pyspark.context import SparkContext - from pyspark.sql import Row, SQLContext, HiveContext - import pyspark.sql.readwriter + from pyspark.sql import Row, SparkSession, SQLContext + import pyspark.sql.streaming os.chdir(os.environ["SPARK_HOME"]) - globs = pyspark.sql.readwriter.__dict__.copy() - sc = SparkContext('local[4]', 'PythonTest') + globs = pyspark.sql.streaming.__dict__.copy() + try: + spark = SparkSession.builder.getOrCreate() + except py4j.protocol.Py4JError: + spark = SparkSession(sc) globs['tempfile'] = tempfile globs['os'] = os - globs['sc'] = sc - globs['sqlContext'] = SQLContext(sc) - globs['hiveContext'] = HiveContext(sc) + globs['spark'] = spark + globs['sqlContext'] = SQLContext.getOrCreate(spark.sparkContext) + globs['sdf'] = \ + spark.readStream.format('text').load('python/test_support/sql/streaming') + globs['sdf_schema'] = StructType([StructField("data", StringType(), False)]) globs['df'] = \ - globs['sqlContext'].read.format('text').stream('python/test_support/sql/streaming') + globs['spark'].readStream.format('text').load('python/test_support/sql/streaming') + globs['sqs'] = StreamingQueryStatus( + spark.sparkContext._jvm.org.apache.spark.sql.streaming.StreamingQueryStatus.testStatus()) (failure_count, test_count) = doctest.testmod( - pyspark.sql.readwriter, globs=globs, + pyspark.sql.streaming, globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) - globs['sc'].stop() + globs['spark'].stop() + if failure_count: exit(-1) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 4995b263e1930..a71457a076bf8 100644 --- 
a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -22,6 +22,7 @@ """ import os import sys +import subprocess import pydoc import shutil import tempfile @@ -45,10 +46,10 @@ else: import unittest -from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row +from pyspark.sql import SparkSession, HiveContext, Column, Row from pyspark.sql.types import * from pyspark.sql.types import UserDefinedType, _infer_type -from pyspark.tests import ReusedPySparkTestCase +from pyspark.tests import ReusedPySparkTestCase, SparkSubmitTests from pyspark.sql.functions import UserDefinedFunction, sha2 from pyspark.sql.window import Window from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException @@ -177,19 +178,9 @@ def test_datetype_equal_zero(self): dt = DateType() self.assertEqual(dt.fromInternal(0), datetime.date(1970, 1, 1)) - -class SQLContextTests(ReusedPySparkTestCase): - def test_get_or_create(self): - sqlCtx = SQLContext.getOrCreate(self.sc) - self.assertTrue(SQLContext.getOrCreate(self.sc) is sqlCtx) - - def test_new_session(self): - sqlCtx = SQLContext.getOrCreate(self.sc) - sqlCtx.setConf("test_key", "a") - sqlCtx2 = sqlCtx.newSession() - sqlCtx2.setConf("test_key", "b") - self.assertEqual(sqlCtx.getConf("test_key", ""), "a") - self.assertEqual(sqlCtx2.getConf("test_key", ""), "b") + def test_empty_row(self): + row = Row() + self.assertEqual(len(row), 0) class SQLTests(ReusedPySparkTestCase): @@ -199,15 +190,14 @@ def setUpClass(cls): ReusedPySparkTestCase.setUpClass() cls.tempdir = tempfile.NamedTemporaryFile(delete=False) os.unlink(cls.tempdir.name) - cls.sparkSession = SparkSession(cls.sc) - cls.sqlCtx = cls.sparkSession._wrapped + cls.spark = SparkSession(cls.sc) cls.testData = [Row(key=i, value=str(i)) for i in range(100)] - rdd = cls.sc.parallelize(cls.testData, 2) - cls.df = rdd.toDF() + cls.df = cls.spark.createDataFrame(cls.testData) @classmethod def tearDownClass(cls): ReusedPySparkTestCase.tearDownClass() + cls.spark.stop() shutil.rmtree(cls.tempdir.name, ignore_errors=True) def test_row_should_be_read_only(self): @@ -218,7 +208,7 @@ def foo(): row.a = 3 self.assertRaises(Exception, foo) - row2 = self.sqlCtx.range(10).first() + row2 = self.spark.range(10).first() self.assertEqual(0, row2.id) def foo2(): @@ -226,14 +216,14 @@ def foo2(): self.assertRaises(Exception, foo2) def test_range(self): - self.assertEqual(self.sqlCtx.range(1, 1).count(), 0) - self.assertEqual(self.sqlCtx.range(1, 0, -1).count(), 1) - self.assertEqual(self.sqlCtx.range(0, 1 << 40, 1 << 39).count(), 2) - self.assertEqual(self.sqlCtx.range(-2).count(), 0) - self.assertEqual(self.sqlCtx.range(3).count(), 3) + self.assertEqual(self.spark.range(1, 1).count(), 0) + self.assertEqual(self.spark.range(1, 0, -1).count(), 1) + self.assertEqual(self.spark.range(0, 1 << 40, 1 << 39).count(), 2) + self.assertEqual(self.spark.range(-2).count(), 0) + self.assertEqual(self.spark.range(3).count(), 3) def test_duplicated_column_names(self): - df = self.sqlCtx.createDataFrame([(1, 2)], ["c", "c"]) + df = self.spark.createDataFrame([(1, 2)], ["c", "c"]) row = df.select('*').first() self.assertEqual(1, row[0]) self.assertEqual(2, row[1]) @@ -243,11 +233,18 @@ def test_duplicated_column_names(self): self.assertRaises(AnalysisException, lambda: df.select(df.c).first()) self.assertRaises(AnalysisException, lambda: df.select(df["c"]).first()) + def test_column_name_encoding(self): + """Ensure that created columns has `str` type consistently.""" + columns = 
self.spark.createDataFrame([('Alice', 1)], ['name', u'age']).columns + self.assertEqual(columns, ['name', 'age']) + self.assertTrue(isinstance(columns[0], str)) + self.assertTrue(isinstance(columns[1], str)) + def test_explode(self): from pyspark.sql.functions import explode d = [Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"})] rdd = self.sc.parallelize(d) - data = self.sqlCtx.createDataFrame(rdd) + data = self.spark.createDataFrame(rdd) result = data.select(explode(data.intlist).alias("a")).select("a").collect() self.assertEqual(result[0][0], 1) @@ -269,7 +266,7 @@ def test_and_in_expression(self): def test_udf_with_callable(self): d = [Row(number=i, squared=i**2) for i in range(10)] rdd = self.sc.parallelize(d) - data = self.sqlCtx.createDataFrame(rdd) + data = self.spark.createDataFrame(rdd) class PlusFour: def __call__(self, col): @@ -284,7 +281,7 @@ def __call__(self, col): def test_udf_with_partial_function(self): d = [Row(number=i, squared=i**2) for i in range(10)] rdd = self.sc.parallelize(d) - data = self.sqlCtx.createDataFrame(rdd) + data = self.spark.createDataFrame(rdd) def some_func(col, param): if col is not None: @@ -296,66 +293,103 @@ def some_func(col, param): self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85) def test_udf(self): - self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType()) - [row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect() + self.spark.catalog.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType()) + [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect() self.assertEqual(row[0], 5) def test_udf2(self): - self.sqlCtx.registerFunction("strlen", lambda string: len(string), IntegerType()) - self.sqlCtx.createDataFrame(self.sc.parallelize([Row(a="test")])).registerTempTable("test") - [res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect() + self.spark.catalog.registerFunction("strlen", lambda string: len(string), IntegerType()) + self.spark.createDataFrame(self.sc.parallelize([Row(a="test")]))\ + .createOrReplaceTempView("test") + [res] = self.spark.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect() self.assertEqual(4, res[0]) def test_chained_udf(self): - self.sqlCtx.registerFunction("double", lambda x: x + x, IntegerType()) - [row] = self.sqlCtx.sql("SELECT double(1)").collect() + self.spark.catalog.registerFunction("double", lambda x: x + x, IntegerType()) + [row] = self.spark.sql("SELECT double(1)").collect() self.assertEqual(row[0], 2) - [row] = self.sqlCtx.sql("SELECT double(double(1))").collect() + [row] = self.spark.sql("SELECT double(double(1))").collect() self.assertEqual(row[0], 4) - [row] = self.sqlCtx.sql("SELECT double(double(1) + 1)").collect() + [row] = self.spark.sql("SELECT double(double(1) + 1)").collect() self.assertEqual(row[0], 6) def test_multiple_udfs(self): - self.sqlCtx.registerFunction("double", lambda x: x * 2, IntegerType()) - [row] = self.sqlCtx.sql("SELECT double(1), double(2)").collect() + self.spark.catalog.registerFunction("double", lambda x: x * 2, IntegerType()) + [row] = self.spark.sql("SELECT double(1), double(2)").collect() self.assertEqual(tuple(row), (2, 4)) - [row] = self.sqlCtx.sql("SELECT double(double(1)), double(double(2) + 2)").collect() + [row] = self.spark.sql("SELECT double(double(1)), double(double(2) + 2)").collect() self.assertEqual(tuple(row), (4, 12)) - self.sqlCtx.registerFunction("add", lambda x, y: x + y, IntegerType()) - [row] = self.sqlCtx.sql("SELECT double(add(1, 2)), add(double(2), 
1)").collect() + self.spark.catalog.registerFunction("add", lambda x, y: x + y, IntegerType()) + [row] = self.spark.sql("SELECT double(add(1, 2)), add(double(2), 1)").collect() self.assertEqual(tuple(row), (6, 5)) + def test_udf_in_filter_on_top_of_outer_join(self): + from pyspark.sql.functions import udf + left = self.spark.createDataFrame([Row(a=1)]) + right = self.spark.createDataFrame([Row(a=1)]) + df = left.join(right, on='a', how='left_outer') + df = df.withColumn('b', udf(lambda x: 'x')(df.a)) + self.assertEqual(df.filter('b = "x"').collect(), [Row(a=1, b='x')]) + + def test_udf_without_arguments(self): + self.spark.catalog.registerFunction("foo", lambda: "bar") + [row] = self.spark.sql("SELECT foo()").collect() + self.assertEqual(row[0], "bar") + def test_udf_with_array_type(self): d = [Row(l=list(range(3)), d={"key": list(range(5))})] rdd = self.sc.parallelize(d) - self.sqlCtx.createDataFrame(rdd).registerTempTable("test") - self.sqlCtx.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType())) - self.sqlCtx.registerFunction("maplen", lambda d: len(d), IntegerType()) - [(l1, l2)] = self.sqlCtx.sql("select copylist(l), maplen(d) from test").collect() + self.spark.createDataFrame(rdd).createOrReplaceTempView("test") + self.spark.catalog.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType())) + self.spark.catalog.registerFunction("maplen", lambda d: len(d), IntegerType()) + [(l1, l2)] = self.spark.sql("select copylist(l), maplen(d) from test").collect() self.assertEqual(list(range(3)), l1) self.assertEqual(1, l2) def test_broadcast_in_udf(self): bar = {"a": "aa", "b": "bb", "c": "abc"} foo = self.sc.broadcast(bar) - self.sqlCtx.registerFunction("MYUDF", lambda x: foo.value[x] if x else '') - [res] = self.sqlCtx.sql("SELECT MYUDF('c')").collect() + self.spark.catalog.registerFunction("MYUDF", lambda x: foo.value[x] if x else '') + [res] = self.spark.sql("SELECT MYUDF('c')").collect() self.assertEqual("abc", res[0]) - [res] = self.sqlCtx.sql("SELECT MYUDF('')").collect() + [res] = self.spark.sql("SELECT MYUDF('')").collect() self.assertEqual("", res[0]) def test_udf_with_aggregate_function(self): - df = self.sqlCtx.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) - from pyspark.sql.functions import udf, col + df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) + from pyspark.sql.functions import udf, col, sum from pyspark.sql.types import BooleanType my_filter = udf(lambda a: a == 1, BooleanType()) sel = df.select(col("key")).distinct().filter(my_filter(col("key"))) self.assertEqual(sel.collect(), [Row(key=1)]) + my_copy = udf(lambda x: x, IntegerType()) + my_add = udf(lambda a, b: int(a + b), IntegerType()) + my_strlen = udf(lambda x: len(x), IntegerType()) + sel = df.groupBy(my_copy(col("key")).alias("k"))\ + .agg(sum(my_strlen(col("value"))).alias("s"))\ + .select(my_add(col("k"), col("s")).alias("t")) + self.assertEqual(sel.collect(), [Row(t=4), Row(t=3)]) + + def test_udf_in_generate(self): + from pyspark.sql.functions import udf, explode + df = self.spark.range(5) + f = udf(lambda x: list(range(x)), ArrayType(LongType())) + row = df.select(explode(f(*df))).groupBy().sum().first() + self.assertEqual(row[0], 10) + + def test_udf_with_order_by_and_limit(self): + from pyspark.sql.functions import udf + my_copy = udf(lambda x: x, IntegerType()) + df = self.spark.range(10).orderBy("id") + res = df.select(df.id, my_copy(df.id).alias("copy")).limit(1) + res.explain(True) + 
self.assertEqual(res.collect(), [Row(id=0, copy=0)]) + def test_basic_functions(self): rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) - df = self.sqlCtx.read.json(rdd) + df = self.spark.read.json(rdd) df.count() df.collect() df.schema @@ -368,42 +402,58 @@ def test_basic_functions(self): self.assertTrue(df.is_cached) self.assertEqual(2, df.count()) - df.registerTempTable("temp") - df = self.sqlCtx.sql("select foo from temp") + df.createOrReplaceTempView("temp") + df = self.spark.sql("select foo from temp") df.count() df.collect() def test_apply_schema_to_row(self): - df = self.sqlCtx.read.json(self.sc.parallelize(["""{"a":2}"""])) - df2 = self.sqlCtx.createDataFrame(df.rdd.map(lambda x: x), df.schema) + df = self.spark.read.json(self.sc.parallelize(["""{"a":2}"""])) + df2 = self.spark.createDataFrame(df.rdd.map(lambda x: x), df.schema) self.assertEqual(df.collect(), df2.collect()) rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x)) - df3 = self.sqlCtx.createDataFrame(rdd, df.schema) + df3 = self.spark.createDataFrame(rdd, df.schema) self.assertEqual(10, df3.count()) def test_infer_schema_to_local(self): input = [{"a": 1}, {"b": "coffee"}] rdd = self.sc.parallelize(input) - df = self.sqlCtx.createDataFrame(input) - df2 = self.sqlCtx.createDataFrame(rdd, samplingRatio=1.0) + df = self.spark.createDataFrame(input) + df2 = self.spark.createDataFrame(rdd, samplingRatio=1.0) self.assertEqual(df.schema, df2.schema) rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x, b=None)) - df3 = self.sqlCtx.createDataFrame(rdd, df.schema) + df3 = self.spark.createDataFrame(rdd, df.schema) self.assertEqual(10, df3.count()) + def test_apply_schema_to_dict_and_rows(self): + schema = StructType().add("b", StringType()).add("a", IntegerType()) + input = [{"a": 1}, {"b": "coffee"}] + rdd = self.sc.parallelize(input) + for verify in [False, True]: + df = self.spark.createDataFrame(input, schema, verifySchema=verify) + df2 = self.spark.createDataFrame(rdd, schema, verifySchema=verify) + self.assertEqual(df.schema, df2.schema) + + rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x, b=None)) + df3 = self.spark.createDataFrame(rdd, schema, verifySchema=verify) + self.assertEqual(10, df3.count()) + input = [Row(a=x, b=str(x)) for x in range(10)] + df4 = self.spark.createDataFrame(input, schema, verifySchema=verify) + self.assertEqual(10, df4.count()) + def test_create_dataframe_schema_mismatch(self): input = [Row(a=1)] rdd = self.sc.parallelize(range(3)).map(lambda i: Row(a=i)) schema = StructType([StructField("a", IntegerType()), StructField("b", StringType())]) - df = self.sqlCtx.createDataFrame(rdd, schema) + df = self.spark.createDataFrame(rdd, schema) self.assertRaises(Exception, lambda: df.show()) def test_serialize_nested_array_and_map(self): d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})] rdd = self.sc.parallelize(d) - df = self.sqlCtx.createDataFrame(rdd) + df = self.spark.createDataFrame(rdd) row = df.head() self.assertEqual(1, len(row.l)) self.assertEqual(1, row.l[0].a) @@ -425,31 +475,31 @@ def test_infer_schema(self): d = [Row(l=[], d={}, s=None), Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")] rdd = self.sc.parallelize(d) - df = self.sqlCtx.createDataFrame(rdd) + df = self.spark.createDataFrame(rdd) self.assertEqual([], df.rdd.map(lambda r: r.l).first()) self.assertEqual([None, ""], df.rdd.map(lambda r: r.s).collect()) - df.registerTempTable("test") - result = self.sqlCtx.sql("SELECT l[0].a from test where d['key'].d = '2'") + 
df.createOrReplaceTempView("test") + result = self.spark.sql("SELECT l[0].a from test where d['key'].d = '2'") self.assertEqual(1, result.head()[0]) - df2 = self.sqlCtx.createDataFrame(rdd, samplingRatio=1.0) + df2 = self.spark.createDataFrame(rdd, samplingRatio=1.0) self.assertEqual(df.schema, df2.schema) self.assertEqual({}, df2.rdd.map(lambda r: r.d).first()) self.assertEqual([None, ""], df2.rdd.map(lambda r: r.s).collect()) - df2.registerTempTable("test2") - result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'") + df2.createOrReplaceTempView("test2") + result = self.spark.sql("SELECT l[0].a from test2 where d['key'].d = '2'") self.assertEqual(1, result.head()[0]) def test_infer_nested_schema(self): NestedRow = Row("f1", "f2") nestedRdd1 = self.sc.parallelize([NestedRow([1, 2], {"row1": 1.0}), NestedRow([2, 3], {"row2": 2.0})]) - df = self.sqlCtx.createDataFrame(nestedRdd1) + df = self.spark.createDataFrame(nestedRdd1) self.assertEqual(Row(f1=[1, 2], f2={u'row1': 1.0}), df.collect()[0]) nestedRdd2 = self.sc.parallelize([NestedRow([[1, 2], [2, 3]], [1, 2]), NestedRow([[2, 3], [3, 4]], [2, 3])]) - df = self.sqlCtx.createDataFrame(nestedRdd2) + df = self.spark.createDataFrame(nestedRdd2) self.assertEqual(Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), df.collect()[0]) from collections import namedtuple @@ -457,17 +507,17 @@ def test_infer_nested_schema(self): rdd = self.sc.parallelize([CustomRow(field1=1, field2="row1"), CustomRow(field1=2, field2="row2"), CustomRow(field1=3, field2="row3")]) - df = self.sqlCtx.createDataFrame(rdd) + df = self.spark.createDataFrame(rdd) self.assertEqual(Row(field1=1, field2=u'row1'), df.first()) def test_create_dataframe_from_objects(self): data = [MyObject(1, "1"), MyObject(2, "2")] - df = self.sqlCtx.createDataFrame(data) + df = self.spark.createDataFrame(data) self.assertEqual(df.dtypes, [("key", "bigint"), ("value", "string")]) self.assertEqual(df.first(), Row(key=1, value="1")) def test_select_null_literal(self): - df = self.sqlCtx.sql("select null as col") + df = self.spark.sql("select null as col") self.assertEqual(Row(col=None), df.first()) def test_apply_schema(self): @@ -488,17 +538,17 @@ def test_apply_schema(self): StructField("struct1", StructType([StructField("b", ShortType(), False)]), False), StructField("list1", ArrayType(ByteType(), False), False), StructField("null1", DoubleType(), True)]) - df = self.sqlCtx.createDataFrame(rdd, schema) + df = self.spark.createDataFrame(rdd, schema) results = df.rdd.map(lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int1, x.float1, x.date1, x.time1, x.map1["a"], x.struct1.b, x.list1, x.null1)) r = (127, -128, -32768, 32767, 2147483647, 1.0, date(2010, 1, 1), datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None) self.assertEqual(r, results.first()) - df.registerTempTable("table2") - r = self.sqlCtx.sql("SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " + - "short1 + 1 AS short1, short2 - 1 AS short2, int1 - 1 AS int1, " + - "float1 + 1.5 as float1 FROM table2").first() + df.createOrReplaceTempView("table2") + r = self.spark.sql("SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " + + "short1 + 1 AS short1, short2 - 1 AS short2, int1 - 1 AS int1, " + + "float1 + 1.5 as float1 FROM table2").first() self.assertEqual((126, -127, -32767, 32766, 2147483646, 2.5), tuple(r)) @@ -508,7 +558,7 @@ def test_apply_schema(self): abstract = "byte1 short1 float1 time1 map1{} struct1(b) list1[]" schema = _parse_schema_abstract(abstract) typedSchema = _infer_schema_type(rdd.first(), schema) - df = 
self.sqlCtx.createDataFrame(rdd, typedSchema) + df = self.spark.createDataFrame(rdd, typedSchema) r = (127, -32768, 1.0, datetime(2010, 1, 1, 1, 1, 1), {"a": 1}, Row(b=2), [1, 2, 3]) self.assertEqual(r, tuple(df.first())) @@ -523,8 +573,8 @@ def test_convert_row_to_dict(self): row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}) self.assertEqual(1, row.asDict()['l'][0].a) df = self.sc.parallelize([row]).toDF() - df.registerTempTable("test") - row = self.sqlCtx.sql("select l, d from test").head() + df.createOrReplaceTempView("test") + row = self.spark.sql("select l, d from test").head() self.assertEqual(1, row.asDict()["l"][0].a) self.assertEqual(1.0, row.asDict()['d']['key'].c) @@ -535,7 +585,7 @@ def test_udt(self): def check_datatype(datatype): pickled = pickle.loads(pickle.dumps(datatype)) assert datatype == pickled - scala_datatype = self.sqlCtx._ssql_ctx.parseDataType(datatype.json()) + scala_datatype = self.spark._wrapped._ssql_ctx.parseDataType(datatype.json()) python_datatype = _parse_datatype_json_string(scala_datatype.json()) assert datatype == python_datatype @@ -557,24 +607,70 @@ def check_datatype(datatype): _verify_type(PythonOnlyPoint(1.0, 2.0), PythonOnlyUDT()) self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], PythonOnlyUDT())) + def test_simple_udt_in_df(self): + schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT()) + df = self.spark.createDataFrame( + [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)], + schema=schema) + df.show() + + def test_nested_udt_in_df(self): + schema = StructType().add("key", LongType()).add("val", ArrayType(PythonOnlyUDT())) + df = self.spark.createDataFrame( + [(i % 3, [PythonOnlyPoint(float(i), float(i))]) for i in range(10)], + schema=schema) + df.collect() + + schema = StructType().add("key", LongType()).add("val", + MapType(LongType(), PythonOnlyUDT())) + df = self.spark.createDataFrame( + [(i % 3, {i % 3: PythonOnlyPoint(float(i + 1), float(i + 1))}) for i in range(10)], + schema=schema) + df.collect() + + def test_complex_nested_udt_in_df(self): + from pyspark.sql.functions import udf + + schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT()) + df = self.spark.createDataFrame( + [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)], + schema=schema) + df.collect() + + gd = df.groupby("key").agg({"val": "collect_list"}) + gd.collect() + udf = udf(lambda k, v: [(k, v[0])], ArrayType(df.schema)) + gd.select(udf(*gd)).collect() + + def test_udt_with_none(self): + df = self.spark.range(0, 10, 1, 1) + + def myudf(x): + if x > 0: + return PythonOnlyPoint(float(x), float(x)) + + self.spark.catalog.registerFunction("udf", myudf, PythonOnlyUDT()) + rows = [r[0] for r in df.selectExpr("udf(id)").take(2)] + self.assertEqual(rows, [None, PythonOnlyPoint(1, 1)]) + def test_infer_schema_with_udt(self): from pyspark.sql.tests import ExamplePoint, ExamplePointUDT row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) - df = self.sqlCtx.createDataFrame([row]) + df = self.spark.createDataFrame([row]) schema = df.schema field = [f for f in schema.fields if f.name == "point"][0] self.assertEqual(type(field.dataType), ExamplePointUDT) - df.registerTempTable("labeled_point") - point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point + df.createOrReplaceTempView("labeled_point") + point = self.spark.sql("SELECT point FROM labeled_point").head().point self.assertEqual(point, ExamplePoint(1.0, 2.0)) row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0)) - df = 
self.sqlCtx.createDataFrame([row]) + df = self.spark.createDataFrame([row]) schema = df.schema field = [f for f in schema.fields if f.name == "point"][0] self.assertEqual(type(field.dataType), PythonOnlyUDT) - df.registerTempTable("labeled_point") - point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point + df.createOrReplaceTempView("labeled_point") + point = self.spark.sql("SELECT point FROM labeled_point").head().point self.assertEqual(point, PythonOnlyPoint(1.0, 2.0)) def test_apply_schema_with_udt(self): @@ -582,21 +678,21 @@ def test_apply_schema_with_udt(self): row = (1.0, ExamplePoint(1.0, 2.0)) schema = StructType([StructField("label", DoubleType(), False), StructField("point", ExamplePointUDT(), False)]) - df = self.sqlCtx.createDataFrame([row], schema) + df = self.spark.createDataFrame([row], schema) point = df.head().point self.assertEqual(point, ExamplePoint(1.0, 2.0)) row = (1.0, PythonOnlyPoint(1.0, 2.0)) schema = StructType([StructField("label", DoubleType(), False), StructField("point", PythonOnlyUDT(), False)]) - df = self.sqlCtx.createDataFrame([row], schema) + df = self.spark.createDataFrame([row], schema) point = df.head().point self.assertEqual(point, PythonOnlyPoint(1.0, 2.0)) def test_udf_with_udt(self): from pyspark.sql.tests import ExamplePoint, ExamplePointUDT row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) - df = self.sqlCtx.createDataFrame([row]) + df = self.spark.createDataFrame([row]) self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first()) udf = UserDefinedFunction(lambda p: p.y, DoubleType()) self.assertEqual(2.0, df.select(udf(df.point)).first()[0]) @@ -604,7 +700,7 @@ def test_udf_with_udt(self): self.assertEqual(ExamplePoint(2.0, 3.0), df.select(udf2(df.point)).first()[0]) row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0)) - df = self.sqlCtx.createDataFrame([row]) + df = self.spark.createDataFrame([row]) self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first()) udf = UserDefinedFunction(lambda p: p.y, DoubleType()) self.assertEqual(2.0, df.select(udf(df.point)).first()[0]) @@ -614,17 +710,17 @@ def test_udf_with_udt(self): def test_parquet_with_udt(self): from pyspark.sql.tests import ExamplePoint, ExamplePointUDT row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) - df0 = self.sqlCtx.createDataFrame([row]) + df0 = self.spark.createDataFrame([row]) output_dir = os.path.join(self.tempdir.name, "labeled_point") df0.write.parquet(output_dir) - df1 = self.sqlCtx.read.parquet(output_dir) + df1 = self.spark.read.parquet(output_dir) point = df1.head().point self.assertEqual(point, ExamplePoint(1.0, 2.0)) row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0)) - df0 = self.sqlCtx.createDataFrame([row]) + df0 = self.spark.createDataFrame([row]) df0.write.parquet(output_dir, mode='overwrite') - df1 = self.sqlCtx.read.parquet(output_dir) + df1 = self.spark.read.parquet(output_dir) point = df1.head().point self.assertEqual(point, PythonOnlyPoint(1.0, 2.0)) @@ -634,8 +730,8 @@ def test_union_with_udt(self): row2 = (2.0, ExamplePoint(3.0, 4.0)) schema = StructType([StructField("label", DoubleType(), False), StructField("point", ExamplePointUDT(), False)]) - df1 = self.sqlCtx.createDataFrame([row1], schema) - df2 = self.sqlCtx.createDataFrame([row2], schema) + df1 = self.spark.createDataFrame([row1], schema) + df2 = self.spark.createDataFrame([row2], schema) result = df1.union(df2).orderBy("label").collect() self.assertEqual( @@ -688,7 +784,7 @@ def test_aggregator(self): def test_first_last_ignorenulls(self): from pyspark.sql import 
functions - df = self.sqlCtx.range(0, 100) + df = self.spark.range(0, 100) df2 = df.select(functions.when(df.id % 3 == 0, None).otherwise(df.id).alias("id")) df3 = df2.select(functions.first(df2.id, False).alias('a'), functions.first(df2.id, True).alias('b'), @@ -829,36 +925,36 @@ def test_metadata_null(self): schema = StructType([StructField("f1", StringType(), True, None), StructField("f2", StringType(), True, {'a': None})]) rdd = self.sc.parallelize([["a", "b"], ["c", "d"]]) - self.sqlCtx.createDataFrame(rdd, schema) + self.spark.createDataFrame(rdd, schema) def test_save_and_load(self): df = self.df tmpPath = tempfile.mkdtemp() shutil.rmtree(tmpPath) df.write.json(tmpPath) - actual = self.sqlCtx.read.json(tmpPath) + actual = self.spark.read.json(tmpPath) self.assertEqual(sorted(df.collect()), sorted(actual.collect())) schema = StructType([StructField("value", StringType(), True)]) - actual = self.sqlCtx.read.json(tmpPath, schema) + actual = self.spark.read.json(tmpPath, schema) self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect())) df.write.json(tmpPath, "overwrite") - actual = self.sqlCtx.read.json(tmpPath) + actual = self.spark.read.json(tmpPath) self.assertEqual(sorted(df.collect()), sorted(actual.collect())) df.write.save(format="json", mode="overwrite", path=tmpPath, noUse="this options will not be used in save.") - actual = self.sqlCtx.read.load(format="json", path=tmpPath, - noUse="this options will not be used in load.") + actual = self.spark.read.load(format="json", path=tmpPath, + noUse="this options will not be used in load.") self.assertEqual(sorted(df.collect()), sorted(actual.collect())) - defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default", + defaultDataSourceName = self.spark.conf.get("spark.sql.sources.default", "org.apache.spark.sql.parquet") - self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json") - actual = self.sqlCtx.read.load(path=tmpPath) + self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json") + actual = self.spark.read.load(path=tmpPath) self.assertEqual(sorted(df.collect()), sorted(actual.collect())) - self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName) + self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName) csvpath = os.path.join(tempfile.mkdtemp(), 'data') df.write.option('quote', None).format('csv').save(csvpath) @@ -870,38 +966,38 @@ def test_save_and_load_builder(self): tmpPath = tempfile.mkdtemp() shutil.rmtree(tmpPath) df.write.json(tmpPath) - actual = self.sqlCtx.read.json(tmpPath) + actual = self.spark.read.json(tmpPath) self.assertEqual(sorted(df.collect()), sorted(actual.collect())) schema = StructType([StructField("value", StringType(), True)]) - actual = self.sqlCtx.read.json(tmpPath, schema) + actual = self.spark.read.json(tmpPath, schema) self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect())) df.write.mode("overwrite").json(tmpPath) - actual = self.sqlCtx.read.json(tmpPath) + actual = self.spark.read.json(tmpPath) self.assertEqual(sorted(df.collect()), sorted(actual.collect())) df.write.mode("overwrite").options(noUse="this options will not be used in save.")\ .option("noUse", "this option will not be used in save.")\ .format("json").save(path=tmpPath) actual =\ - self.sqlCtx.read.format("json")\ - .load(path=tmpPath, noUse="this options will not be used in load.") + self.spark.read.format("json")\ + .load(path=tmpPath, noUse="this options will not be used in load.") 
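The save/load tests in this hunk all follow the same round-trip pattern: write with an explicit format and mode, then read the data back through the generic `load()` entry point. A rough standalone sketch of that pattern, with a scratch directory that is illustrative rather than anything from this patch:

```python
# Illustrative sketch (not part of this patch): write a DataFrame with an explicit
# format/mode and read it back through spark.read.load().
import tempfile
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[2]").appName("save-load-sketch").getOrCreate()
df = spark.createDataFrame([(i, str(i)) for i in range(5)], ["key", "value"])

path = tempfile.mkdtemp()  # hypothetical scratch directory
df.write.mode("overwrite").format("json").save(path)

# Unrecognized options (like the "noUse" option in the tests) are passed through to
# the data source and ignored if it does not understand them.
loaded = spark.read.format("json").load(path)
assert sorted(loaded.collect()) == sorted(df.collect())

spark.stop()
```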
self.assertEqual(sorted(df.collect()), sorted(actual.collect())) - defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default", + defaultDataSourceName = self.spark.conf.get("spark.sql.sources.default", "org.apache.spark.sql.parquet") - self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json") - actual = self.sqlCtx.read.load(path=tmpPath) + self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json") + actual = self.spark.read.load(path=tmpPath) self.assertEqual(sorted(df.collect()), sorted(actual.collect())) - self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName) + self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName) shutil.rmtree(tmpPath) def test_stream_trigger_takes_keyword_args(self): - df = self.sqlCtx.read.format('text').stream('python/test_support/sql/streaming') + df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') try: - df.write.trigger('5 seconds') + df.writeStream.trigger('5 seconds') self.fail("Should have thrown an exception") except TypeError: # should throw error @@ -909,48 +1005,51 @@ def test_stream_trigger_takes_keyword_args(self): def test_stream_read_options(self): schema = StructType([StructField("data", StringType(), False)]) - df = self.sqlCtx.read.format('text').option('path', 'python/test_support/sql/streaming')\ - .schema(schema).stream() + df = self.spark.readStream\ + .format('text')\ + .option('path', 'python/test_support/sql/streaming')\ + .schema(schema)\ + .load() self.assertTrue(df.isStreaming) self.assertEqual(df.schema.simpleString(), "struct<data:string>") def test_stream_read_options_overwrite(self): bad_schema = StructType([StructField("test", IntegerType(), False)]) schema = StructType([StructField("data", StringType(), False)]) - df = self.sqlCtx.read.format('csv').option('path', 'python/test_support/sql/fake') \ - .schema(bad_schema).stream(path='python/test_support/sql/streaming', - schema=schema, format='text') + df = self.spark.readStream.format('csv').option('path', 'python/test_support/sql/fake') \ + .schema(bad_schema)\ + .load(path='python/test_support/sql/streaming', schema=schema, format='text') self.assertTrue(df.isStreaming) self.assertEqual(df.schema.simpleString(), "struct<data:string>") def test_stream_save_options(self): - df = self.sqlCtx.read.format('text').stream('python/test_support/sql/streaming') - for cq in self.sqlCtx.streams.active: - cq.stop() + df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') + for q in self.spark._wrapped.streams.active: + q.stop() tmpPath = tempfile.mkdtemp() shutil.rmtree(tmpPath) self.assertTrue(df.isStreaming) out = os.path.join(tmpPath, 'out') chk = os.path.join(tmpPath, 'chk') - cq = df.write.option('checkpointLocation', chk).queryName('this_query') \ - .format('parquet').option('path', out).startStream() + q = df.writeStream.option('checkpointLocation', chk).queryName('this_query') \ + .format('parquet').outputMode('append').option('path', out).start() try: - self.assertEqual(cq.name, 'this_query') - self.assertTrue(cq.isActive) - cq.processAllAvailable() + self.assertEqual(q.name, 'this_query') + self.assertTrue(q.isActive) + q.processAllAvailable() output_files = [] for _, _, files in os.walk(out): - output_files.extend([f for f in files if 'parquet' in f and not f.startswith('.')]) + output_files.extend([f for f in files if not f.startswith('.')]) self.assertTrue(len(output_files) > 0) self.assertTrue(len(os.listdir(chk)) > 0) finally: - cq.stop() + q.stop()
shutil.rmtree(tmpPath) def test_stream_save_options_overwrite(self): - df = self.sqlCtx.read.format('text').stream('python/test_support/sql/streaming') - for cq in self.sqlCtx.streams.active: - cq.stop() + df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') + for q in self.spark._wrapped.streams.active: + q.stop() tmpPath = tempfile.mkdtemp() shutil.rmtree(tmpPath) self.assertTrue(df.isStreaming) @@ -958,84 +1057,86 @@ def test_stream_save_options_overwrite(self): chk = os.path.join(tmpPath, 'chk') fake1 = os.path.join(tmpPath, 'fake1') fake2 = os.path.join(tmpPath, 'fake2') - cq = df.write.option('checkpointLocation', fake1).format('memory').option('path', fake2) \ - .queryName('fake_query').startStream(path=out, format='parquet', queryName='this_query', - checkpointLocation=chk) + q = df.writeStream.option('checkpointLocation', fake1)\ + .format('memory').option('path', fake2) \ + .queryName('fake_query').outputMode('append') \ + .start(path=out, format='parquet', queryName='this_query', checkpointLocation=chk) + try: - self.assertEqual(cq.name, 'this_query') - self.assertTrue(cq.isActive) - cq.processAllAvailable() + self.assertEqual(q.name, 'this_query') + self.assertTrue(q.isActive) + q.processAllAvailable() output_files = [] for _, _, files in os.walk(out): - output_files.extend([f for f in files if 'parquet' in f and not f.startswith('.')]) + output_files.extend([f for f in files if not f.startswith('.')]) self.assertTrue(len(output_files) > 0) self.assertTrue(len(os.listdir(chk)) > 0) self.assertFalse(os.path.isdir(fake1)) # should not have been created self.assertFalse(os.path.isdir(fake2)) # should not have been created finally: - cq.stop() + q.stop() shutil.rmtree(tmpPath) def test_stream_await_termination(self): - df = self.sqlCtx.read.format('text').stream('python/test_support/sql/streaming') - for cq in self.sqlCtx.streams.active: - cq.stop() + df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') + for q in self.spark._wrapped.streams.active: + q.stop() tmpPath = tempfile.mkdtemp() shutil.rmtree(tmpPath) self.assertTrue(df.isStreaming) out = os.path.join(tmpPath, 'out') chk = os.path.join(tmpPath, 'chk') - cq = df.write.startStream(path=out, format='parquet', queryName='this_query', - checkpointLocation=chk) + q = df.writeStream\ + .start(path=out, format='parquet', queryName='this_query', checkpointLocation=chk) try: - self.assertTrue(cq.isActive) + self.assertTrue(q.isActive) try: - cq.awaitTermination("hello") + q.awaitTermination("hello") self.fail("Expected a value exception") except ValueError: pass now = time.time() # test should take at least 2 seconds - res = cq.awaitTermination(2.6) + res = q.awaitTermination(2.6) duration = time.time() - now self.assertTrue(duration >= 2) self.assertFalse(res) finally: - cq.stop() + q.stop() shutil.rmtree(tmpPath) def test_query_manager_await_termination(self): - df = self.sqlCtx.read.format('text').stream('python/test_support/sql/streaming') - for cq in self.sqlCtx.streams.active: - cq.stop() + df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') + for q in self.spark._wrapped.streams.active: + q.stop() tmpPath = tempfile.mkdtemp() shutil.rmtree(tmpPath) self.assertTrue(df.isStreaming) out = os.path.join(tmpPath, 'out') chk = os.path.join(tmpPath, 'chk') - cq = df.write.startStream(path=out, format='parquet', queryName='this_query', - checkpointLocation=chk) + q = df.writeStream\ + .start(path=out, format='parquet', 
queryName='this_query', checkpointLocation=chk) try: - self.assertTrue(cq.isActive) + self.assertTrue(q.isActive) try: - self.sqlCtx.streams.awaitAnyTermination("hello") + self.spark._wrapped.streams.awaitAnyTermination("hello") self.fail("Expected a value exception") except ValueError: pass now = time.time() # test should take at least 2 seconds - res = self.sqlCtx.streams.awaitAnyTermination(2.6) + res = self.spark._wrapped.streams.awaitAnyTermination(2.6) duration = time.time() - now self.assertTrue(duration >= 2) self.assertFalse(res) finally: - cq.stop() + q.stop() shutil.rmtree(tmpPath) def test_help_command(self): # Regression test for SPARK-5464 rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) - df = self.sqlCtx.read.json(rdd) + df = self.spark.read.json(rdd) # render_doc() reproduces the help() exception without printing output pydoc.render_doc(df) pydoc.render_doc(df.foo) @@ -1051,8 +1152,15 @@ def test_access_column(self): self.assertRaises(TypeError, lambda: df[{}]) def test_column_name_with_non_ascii(self): - df = self.sqlCtx.createDataFrame([(1,)], ["数量"]) - self.assertEqual(StructType([StructField("数量", LongType(), True)]), df.schema) + if sys.version >= '3': + columnName = "数量" + self.assertTrue(isinstance(columnName, str)) + else: + columnName = unicode("数量", "utf-8") + self.assertTrue(isinstance(columnName, unicode)) + schema = StructType([StructField(columnName, LongType(), True)]) + df = self.spark.createDataFrame([(1,)], schema) + self.assertEqual(schema, df.schema) self.assertEqual("DataFrame[数量: bigint]", str(df)) self.assertEqual([("数量", 'bigint')], df.dtypes) self.assertEqual(1, df.select("数量").first()[0]) @@ -1084,7 +1192,7 @@ def test_infer_long_type(self): # this saving as Parquet caused issues as well. output_dir = os.path.join(self.tempdir.name, "infer_long_type") df.write.parquet(output_dir) - df1 = self.sqlCtx.read.parquet(output_dir) + df1 = self.spark.read.parquet(output_dir) self.assertEqual('a', df1.first().f1) self.assertEqual(100000000000000, df1.first().f2) @@ -1100,7 +1208,7 @@ def test_filter_with_datetime(self): time = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000) date = time.date() row = Row(date=date, time=time) - df = self.sqlCtx.createDataFrame([row]) + df = self.spark.createDataFrame([row]) self.assertEqual(1, df.filter(df.date == date).count()) self.assertEqual(1, df.filter(df.time == time).count()) self.assertEqual(0, df.filter(df.date > date).count()) @@ -1110,7 +1218,7 @@ def test_filter_with_datetime_timezone(self): dt1 = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000, tzinfo=UTCOffsetTimezone(0)) dt2 = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000, tzinfo=UTCOffsetTimezone(1)) row = Row(date=dt1) - df = self.sqlCtx.createDataFrame([row]) + df = self.spark.createDataFrame([row]) self.assertEqual(0, df.filter(df.date == dt2).count()) self.assertEqual(1, df.filter(df.date > dt2).count()) self.assertEqual(0, df.filter(df.date < dt2).count()) @@ -1125,7 +1233,7 @@ def test_time_with_timezone(self): utcnow = datetime.datetime.utcfromtimestamp(ts) # without microseconds # add microseconds to utcnow (keeping year,month,day,hour,minute,second) utcnow = datetime.datetime(*(utcnow.timetuple()[:6] + (now.microsecond, utc))) - df = self.sqlCtx.createDataFrame([(day, now, utcnow)]) + df = self.spark.createDataFrame([(day, now, utcnow)]) day1, now1, utcnow1 = df.first() self.assertEqual(day1, day) self.assertEqual(now, now1) @@ -1134,13 +1242,13 @@ def test_time_with_timezone(self): def test_decimal(self): from decimal import Decimal 
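For the `DecimalType(10, 5)` used in `test_decimal` just below: the first argument is the precision (total number of digits) and the second the scale (digits after the decimal point). A rough standalone sketch, not part of the patch:

```python
# Illustrative sketch (not part of this patch): DecimalType(precision, scale) —
# here up to 10 digits in total, 5 of them after the decimal point.
from decimal import Decimal
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, DecimalType

spark = SparkSession.builder.master("local[2]").appName("decimal-sketch").getOrCreate()
schema = StructType([StructField("amount", DecimalType(10, 5))])
df = spark.createDataFrame([(Decimal("3.14159"),)], schema)
df.printSchema()                              # amount: decimal(10,5)
print(df.select(df.amount + 1).first()[0])    # Decimal('4.14159')
spark.stop()
```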
schema = StructType([StructField("decimal", DecimalType(10, 5))]) - df = self.sqlCtx.createDataFrame([(Decimal("3.14159"),)], schema) + df = self.spark.createDataFrame([(Decimal("3.14159"),)], schema) row = df.select(df.decimal + 1).first() self.assertEqual(row[0], Decimal("4.14159")) tmpPath = tempfile.mkdtemp() shutil.rmtree(tmpPath) df.write.parquet(tmpPath) - df2 = self.sqlCtx.read.parquet(tmpPath) + df2 = self.spark.read.parquet(tmpPath) row = df2.first() self.assertEqual(row[0], Decimal("3.14159")) @@ -1151,52 +1259,52 @@ def test_dropna(self): StructField("height", DoubleType(), True)]) # shouldn't drop a non-null row - self.assertEqual(self.sqlCtx.createDataFrame( + self.assertEqual(self.spark.createDataFrame( [(u'Alice', 50, 80.1)], schema).dropna().count(), 1) # dropping rows with a single null value - self.assertEqual(self.sqlCtx.createDataFrame( + self.assertEqual(self.spark.createDataFrame( [(u'Alice', None, 80.1)], schema).dropna().count(), 0) - self.assertEqual(self.sqlCtx.createDataFrame( + self.assertEqual(self.spark.createDataFrame( [(u'Alice', None, 80.1)], schema).dropna(how='any').count(), 0) # if how = 'all', only drop rows if all values are null - self.assertEqual(self.sqlCtx.createDataFrame( + self.assertEqual(self.spark.createDataFrame( [(u'Alice', None, 80.1)], schema).dropna(how='all').count(), 1) - self.assertEqual(self.sqlCtx.createDataFrame( + self.assertEqual(self.spark.createDataFrame( [(None, None, None)], schema).dropna(how='all').count(), 0) # how and subset - self.assertEqual(self.sqlCtx.createDataFrame( + self.assertEqual(self.spark.createDataFrame( [(u'Alice', 50, None)], schema).dropna(how='any', subset=['name', 'age']).count(), 1) - self.assertEqual(self.sqlCtx.createDataFrame( + self.assertEqual(self.spark.createDataFrame( [(u'Alice', None, None)], schema).dropna(how='any', subset=['name', 'age']).count(), 0) # threshold - self.assertEqual(self.sqlCtx.createDataFrame( + self.assertEqual(self.spark.createDataFrame( [(u'Alice', None, 80.1)], schema).dropna(thresh=2).count(), 1) - self.assertEqual(self.sqlCtx.createDataFrame( + self.assertEqual(self.spark.createDataFrame( [(u'Alice', None, None)], schema).dropna(thresh=2).count(), 0) # threshold and subset - self.assertEqual(self.sqlCtx.createDataFrame( + self.assertEqual(self.spark.createDataFrame( [(u'Alice', 50, None)], schema).dropna(thresh=2, subset=['name', 'age']).count(), 1) - self.assertEqual(self.sqlCtx.createDataFrame( + self.assertEqual(self.spark.createDataFrame( [(u'Alice', None, 180.9)], schema).dropna(thresh=2, subset=['name', 'age']).count(), 0) # thresh should take precedence over how - self.assertEqual(self.sqlCtx.createDataFrame( + self.assertEqual(self.spark.createDataFrame( [(u'Alice', 50, None)], schema).dropna( how='any', thresh=2, subset=['name', 'age']).count(), 1) @@ -1208,33 +1316,33 @@ def test_fillna(self): StructField("height", DoubleType(), True)]) # fillna shouldn't change non-null values - row = self.sqlCtx.createDataFrame([(u'Alice', 10, 80.1)], schema).fillna(50).first() + row = self.spark.createDataFrame([(u'Alice', 10, 80.1)], schema).fillna(50).first() self.assertEqual(row.age, 10) # fillna with int - row = self.sqlCtx.createDataFrame([(u'Alice', None, None)], schema).fillna(50).first() + row = self.spark.createDataFrame([(u'Alice', None, None)], schema).fillna(50).first() self.assertEqual(row.age, 50) self.assertEqual(row.height, 50.0) # fillna with double - row = self.sqlCtx.createDataFrame([(u'Alice', None, None)], schema).fillna(50.1).first() + row = 
self.spark.createDataFrame([(u'Alice', None, None)], schema).fillna(50.1).first() self.assertEqual(row.age, 50) self.assertEqual(row.height, 50.1) # fillna with string - row = self.sqlCtx.createDataFrame([(None, None, None)], schema).fillna("hello").first() + row = self.spark.createDataFrame([(None, None, None)], schema).fillna("hello").first() self.assertEqual(row.name, u"hello") self.assertEqual(row.age, None) # fillna with subset specified for numeric cols - row = self.sqlCtx.createDataFrame( + row = self.spark.createDataFrame( [(None, None, None)], schema).fillna(50, subset=['name', 'age']).first() self.assertEqual(row.name, None) self.assertEqual(row.age, 50) self.assertEqual(row.height, None) # fillna with subset specified for numeric cols - row = self.sqlCtx.createDataFrame( + row = self.spark.createDataFrame( [(None, None, None)], schema).fillna("haha", subset=['name', 'age']).first() self.assertEqual(row.name, "haha") self.assertEqual(row.age, None) @@ -1243,7 +1351,7 @@ def test_fillna(self): def test_bitwise_operations(self): from pyspark.sql import functions row = Row(a=170, b=75) - df = self.sqlCtx.createDataFrame([row]) + df = self.spark.createDataFrame([row]) result = df.select(df.a.bitwiseAND(df.b)).collect()[0].asDict() self.assertEqual(170 & 75, result['(a & b)']) result = df.select(df.a.bitwiseOR(df.b)).collect()[0].asDict() @@ -1256,7 +1364,7 @@ def test_bitwise_operations(self): def test_expr(self): from pyspark.sql import functions row = Row(a="length string", b=75) - df = self.sqlCtx.createDataFrame([row]) + df = self.spark.createDataFrame([row]) result = df.select(functions.expr("length(a)")).collect()[0].asDict() self.assertEqual(13, result["length(a)"]) @@ -1267,58 +1375,58 @@ def test_replace(self): StructField("height", DoubleType(), True)]) # replace with int - row = self.sqlCtx.createDataFrame([(u'Alice', 10, 10.0)], schema).replace(10, 20).first() + row = self.spark.createDataFrame([(u'Alice', 10, 10.0)], schema).replace(10, 20).first() self.assertEqual(row.age, 20) self.assertEqual(row.height, 20.0) # replace with double - row = self.sqlCtx.createDataFrame( + row = self.spark.createDataFrame( [(u'Alice', 80, 80.0)], schema).replace(80.0, 82.1).first() self.assertEqual(row.age, 82) self.assertEqual(row.height, 82.1) # replace with string - row = self.sqlCtx.createDataFrame( + row = self.spark.createDataFrame( [(u'Alice', 10, 80.1)], schema).replace(u'Alice', u'Ann').first() self.assertEqual(row.name, u"Ann") self.assertEqual(row.age, 10) # replace with subset specified by a string of a column name w/ actual change - row = self.sqlCtx.createDataFrame( + row = self.spark.createDataFrame( [(u'Alice', 10, 80.1)], schema).replace(10, 20, subset='age').first() self.assertEqual(row.age, 20) # replace with subset specified by a string of a column name w/o actual change - row = self.sqlCtx.createDataFrame( + row = self.spark.createDataFrame( [(u'Alice', 10, 80.1)], schema).replace(10, 20, subset='height').first() self.assertEqual(row.age, 10) # replace with subset specified with one column replaced, another column not in subset # stays unchanged. 
- row = self.sqlCtx.createDataFrame( + row = self.spark.createDataFrame( [(u'Alice', 10, 10.0)], schema).replace(10, 20, subset=['name', 'age']).first() self.assertEqual(row.name, u'Alice') self.assertEqual(row.age, 20) self.assertEqual(row.height, 10.0) # replace with subset specified but no column will be replaced - row = self.sqlCtx.createDataFrame( + row = self.spark.createDataFrame( [(u'Alice', 10, None)], schema).replace(10, 20, subset=['name', 'height']).first() self.assertEqual(row.name, u'Alice') self.assertEqual(row.age, 10) self.assertEqual(row.height, None) def test_capture_analysis_exception(self): - self.assertRaises(AnalysisException, lambda: self.sqlCtx.sql("select abc")) + self.assertRaises(AnalysisException, lambda: self.spark.sql("select abc")) self.assertRaises(AnalysisException, lambda: self.df.selectExpr("a + b")) def test_capture_parse_exception(self): - self.assertRaises(ParseException, lambda: self.sqlCtx.sql("abc")) + self.assertRaises(ParseException, lambda: self.spark.sql("abc")) def test_capture_illegalargument_exception(self): self.assertRaisesRegexp(IllegalArgumentException, "Setting negative mapred.reduce.tasks", - lambda: self.sqlCtx.sql("SET mapred.reduce.tasks=-1")) - df = self.sqlCtx.createDataFrame([(1, 2)], ["a", "b"]) + lambda: self.spark.sql("SET mapred.reduce.tasks=-1")) + df = self.spark.createDataFrame([(1, 2)], ["a", "b"]) self.assertRaisesRegexp(IllegalArgumentException, "1024 is not in the permitted values", lambda: df.select(sha2(df.a, 1024)).collect()) try: @@ -1345,8 +1453,8 @@ def foo(): def test_functions_broadcast(self): from pyspark.sql.functions import broadcast - df1 = self.sqlCtx.createDataFrame([(1, "1"), (2, "2")], ("key", "value")) - df2 = self.sqlCtx.createDataFrame([(1, "1"), (2, "2")], ("key", "value")) + df1 = self.spark.createDataFrame([(1, "1"), (2, "2")], ("key", "value")) + df2 = self.spark.createDataFrame([(1, "1"), (2, "2")], ("key", "value")) # equijoin - should be converted into broadcast join plan1 = df1.join(broadcast(df2), "key")._jdf.queryExecution().executedPlan() @@ -1396,9 +1504,9 @@ def test_toDF_with_schema_string(self): self.assertEqual(df.collect(), [Row(key=i) for i in range(100)]) def test_conf(self): - spark = self.sparkSession + spark = self.spark spark.conf.set("bogo", "sipeo") - self.assertEqual(self.sparkSession.conf.get("bogo"), "sipeo") + self.assertEqual(spark.conf.get("bogo"), "sipeo") spark.conf.set("bogo", "ta") self.assertEqual(spark.conf.get("bogo"), "ta") self.assertEqual(spark.conf.get("bogo", "not.read"), "ta") @@ -1408,7 +1516,7 @@ def test_conf(self): self.assertEqual(spark.conf.get("bogo", "colombia"), "colombia") def test_current_database(self): - spark = self.sparkSession + spark = self.spark spark.catalog._reset() self.assertEquals(spark.catalog.currentDatabase(), "default") spark.sql("CREATE DATABASE some_db") @@ -1420,7 +1528,7 @@ def test_current_database(self): lambda: spark.catalog.setCurrentDatabase("does_not_exist")) def test_list_databases(self): - spark = self.sparkSession + spark = self.spark spark.catalog._reset() databases = [db.name for db in spark.catalog.listDatabases()] self.assertEquals(databases, ["default"]) @@ -1430,12 +1538,12 @@ def test_list_databases(self): def test_list_tables(self): from pyspark.sql.catalog import Table - spark = self.sparkSession + spark = self.spark spark.catalog._reset() spark.sql("CREATE DATABASE some_db") self.assertEquals(spark.catalog.listTables(), []) self.assertEquals(spark.catalog.listTables("some_db"), []) - 
spark.createDataFrame([(1, 1)]).registerTempTable("temp_tab") + spark.createDataFrame([(1, 1)]).createOrReplaceTempView("temp_tab") spark.sql("CREATE TABLE tab1 (name STRING, age INT)") spark.sql("CREATE TABLE some_db.tab2 (name STRING, age INT)") tables = sorted(spark.catalog.listTables(), key=lambda t: t.name) @@ -1475,7 +1583,7 @@ def test_list_tables(self): def test_list_functions(self): from pyspark.sql.catalog import Function - spark = self.sparkSession + spark = self.spark spark.catalog._reset() spark.sql("CREATE DATABASE some_db") functions = dict((f.name, f) for f in spark.catalog.listFunctions()) @@ -1512,7 +1620,7 @@ def test_list_functions(self): def test_list_columns(self): from pyspark.sql.catalog import Column - spark = self.sparkSession + spark = self.spark spark.catalog._reset() spark.sql("CREATE DATABASE some_db") spark.sql("CREATE TABLE tab1 (name STRING, age INT)") @@ -1561,9 +1669,9 @@ def test_list_columns(self): lambda: spark.catalog.listColumns("does_not_exist")) def test_cache(self): - spark = self.sparkSession - spark.createDataFrame([(2, 2), (3, 3)]).registerTempTable("tab1") - spark.createDataFrame([(2, 2), (3, 3)]).registerTempTable("tab2") + spark = self.spark + spark.createDataFrame([(2, 2), (3, 3)]).createOrReplaceTempView("tab1") + spark.createDataFrame([(2, 2), (3, 3)]).createOrReplaceTempView("tab2") self.assertFalse(spark.catalog.isCached("tab1")) self.assertFalse(spark.catalog.isCached("tab2")) spark.catalog.cacheTable("tab1") @@ -1589,6 +1697,60 @@ def test_cache(self): "does_not_exist", lambda: spark.catalog.uncacheTable("does_not_exist")) + def test_read_text_file_list(self): + df = self.spark.read.text(['python/test_support/sql/text-test.txt', + 'python/test_support/sql/text-test.txt']) + count = df.count() + self.assertEquals(count, 4) + + def test_BinaryType_serialization(self): + # Pyrolite version <= 4.9 could not serialize BinaryType with Python3 SPARK-17808 + schema = StructType([StructField('mybytes', BinaryType())]) + data = [[bytearray(b'here is my data')], + [bytearray(b'and here is some more')]] + df = self.spark.createDataFrame(data, schema=schema) + df.collect() + + +class HiveSparkSubmitTests(SparkSubmitTests): + + def test_hivecontext(self): + # This test checks that HiveContext is using Hive metastore (SPARK-16224). + # It sets a metastore url and checks if there is a derby dir created by + # Hive metastore. If this derby dir exists, HiveContext is using + # Hive metastore. 
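The mechanism the test relies on — shipping a hive-site.xml to the driver via `--driver-class-path` so the configured metastore URL takes effect — can be sketched on its own roughly as follows; the paths and the spark-submit location are assumptions for illustration only:

```python
# Illustrative sketch (not part of this patch): launching a script through spark-submit
# with a conf directory on the driver classpath so that hive-site.xml is picked up.
import os
import subprocess

spark_home = os.environ["SPARK_HOME"]
spark_submit = os.path.join(spark_home, "bin", "spark-submit")
hive_site_dir = "/tmp/conf"           # hypothetical dir containing hive-site.xml
script = "/tmp/show_databases.py"     # hypothetical driver script using HiveContext

proc = subprocess.Popen(
    [spark_submit, "--master", "local-cluster[1,1,1024]",
     "--driver-class-path", hive_site_dir, script],
    stdout=subprocess.PIPE)
out, _ = proc.communicate()
print(proc.returncode, out.decode("utf-8"))
```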
+ metastore_path = os.path.join(tempfile.mkdtemp(), "spark16224_metastore_db") + metastore_URL = "jdbc:derby:;databaseName=" + metastore_path + ";create=true" + hive_site_dir = os.path.join(self.programDir, "conf") + hive_site_file = self.createTempFile("hive-site.xml", (""" + | + | + | javax.jdo.option.ConnectionURL + | %s + | + | + """ % metastore_URL).lstrip(), "conf") + script = self.createTempFile("test.py", """ + |import os + | + |from pyspark.conf import SparkConf + |from pyspark.context import SparkContext + |from pyspark.sql import HiveContext + | + |conf = SparkConf() + |sc = SparkContext(conf=conf) + |hive_context = HiveContext(sc) + |print(hive_context.sql("show databases").collect()) + """) + proc = subprocess.Popen( + [self.sparkSubmit, "--master", "local-cluster[1,1,1024]", + "--driver-class-path", hive_site_dir, script], + stdout=subprocess.PIPE) + out, err = proc.communicate() + self.assertEqual(0, proc.returncode) + self.assertIn("default", out.decode('utf-8')) + self.assertTrue(os.path.exists(metastore_path)) + class HiveContextSQLTests(ReusedPySparkTestCase): @@ -1605,7 +1767,7 @@ def setUpClass(cls): cls.tearDownClass() raise unittest.SkipTest("Hive is not available") os.unlink(cls.tempdir.name) - cls.sqlCtx = HiveContext._createForTesting(cls.sc) + cls.spark = HiveContext._createForTesting(cls.sc) cls.testData = [Row(key=i, value=str(i)) for i in range(100)] cls.df = cls.sc.parallelize(cls.testData).toDF() @@ -1619,45 +1781,45 @@ def test_save_and_load_table(self): tmpPath = tempfile.mkdtemp() shutil.rmtree(tmpPath) df.write.saveAsTable("savedJsonTable", "json", "append", path=tmpPath) - actual = self.sqlCtx.createExternalTable("externalJsonTable", tmpPath, "json") + actual = self.spark.createExternalTable("externalJsonTable", tmpPath, "json") self.assertEqual(sorted(df.collect()), - sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect())) + sorted(self.spark.sql("SELECT * FROM savedJsonTable").collect())) self.assertEqual(sorted(df.collect()), - sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect())) + sorted(self.spark.sql("SELECT * FROM externalJsonTable").collect())) self.assertEqual(sorted(df.collect()), sorted(actual.collect())) - self.sqlCtx.sql("DROP TABLE externalJsonTable") + self.spark.sql("DROP TABLE externalJsonTable") df.write.saveAsTable("savedJsonTable", "json", "overwrite", path=tmpPath) schema = StructType([StructField("value", StringType(), True)]) - actual = self.sqlCtx.createExternalTable("externalJsonTable", source="json", - schema=schema, path=tmpPath, - noUse="this options will not be used") + actual = self.spark.createExternalTable("externalJsonTable", source="json", + schema=schema, path=tmpPath, + noUse="this options will not be used") self.assertEqual(sorted(df.collect()), - sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect())) + sorted(self.spark.sql("SELECT * FROM savedJsonTable").collect())) self.assertEqual(sorted(df.select("value").collect()), - sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect())) + sorted(self.spark.sql("SELECT * FROM externalJsonTable").collect())) self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect())) - self.sqlCtx.sql("DROP TABLE savedJsonTable") - self.sqlCtx.sql("DROP TABLE externalJsonTable") + self.spark.sql("DROP TABLE savedJsonTable") + self.spark.sql("DROP TABLE externalJsonTable") - defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default", - "org.apache.spark.sql.parquet") - self.sqlCtx.sql("SET 
spark.sql.sources.default=org.apache.spark.sql.json") + defaultDataSourceName = self.spark.getConf("spark.sql.sources.default", + "org.apache.spark.sql.parquet") + self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json") df.write.saveAsTable("savedJsonTable", path=tmpPath, mode="overwrite") - actual = self.sqlCtx.createExternalTable("externalJsonTable", path=tmpPath) + actual = self.spark.createExternalTable("externalJsonTable", path=tmpPath) self.assertEqual(sorted(df.collect()), - sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect())) + sorted(self.spark.sql("SELECT * FROM savedJsonTable").collect())) self.assertEqual(sorted(df.collect()), - sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect())) + sorted(self.spark.sql("SELECT * FROM externalJsonTable").collect())) self.assertEqual(sorted(df.collect()), sorted(actual.collect())) - self.sqlCtx.sql("DROP TABLE savedJsonTable") - self.sqlCtx.sql("DROP TABLE externalJsonTable") - self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName) + self.spark.sql("DROP TABLE savedJsonTable") + self.spark.sql("DROP TABLE externalJsonTable") + self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName) shutil.rmtree(tmpPath) def test_window_functions(self): - df = self.sqlCtx.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) + df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) w = Window.partitionBy("value").orderBy("key") from pyspark.sql import functions as F sel = df.select(df.value, df.key, @@ -1679,7 +1841,7 @@ def test_window_functions(self): self.assertEqual(tuple(r), ex[:len(r)]) def test_window_functions_without_partitionBy(self): - df = self.sqlCtx.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) + df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) w = Window.orderBy("key", df.value) from pyspark.sql import functions as F sel = df.select(df.value, df.key, @@ -1701,7 +1863,7 @@ def test_window_functions_without_partitionBy(self): self.assertEqual(tuple(r), ex[:len(r)]) def test_collect_functions(self): - df = self.sqlCtx.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) + df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) from pyspark.sql import functions self.assertEqual( @@ -1717,6 +1879,24 @@ def test_collect_functions(self): sorted(df.select(functions.collect_list(df.value).alias('r')).collect()[0].r), ["1", "2", "2", "2"]) + def test_limit_and_take(self): + df = self.spark.range(1, 1000, numPartitions=10) + + def assert_runs_only_one_job_stage_and_task(job_group_name, f): + tracker = self.sc.statusTracker() + self.sc.setJobGroup(job_group_name, description="") + f() + jobs = tracker.getJobIdsForGroup(job_group_name) + self.assertEqual(1, len(jobs)) + stages = tracker.getJobInfo(jobs[0]).stageIds + self.assertEqual(1, len(stages)) + self.assertEqual(1, tracker.getStageInfo(stages[0]).numTasks) + + # Regression test for SPARK-10731: take should delegate to Scala implementation + assert_runs_only_one_job_stage_and_task("take", lambda: df.take(1)) + # Regression test for SPARK-17514: limit(n).collect() should the perform same as take(n) + assert_runs_only_one_job_stage_and_task("collect_limit", lambda: df.limit(1).collect()) + if __name__ == "__main__": from pyspark.sql.tests import * diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 
f7cd4b80ca91d..b765472d6edbc 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -27,7 +27,7 @@ if sys.version >= "3": long = int - unicode = str + basestring = unicode = str from py4j.protocol import register_input_converter from py4j.java_gateway import JavaClass @@ -401,6 +401,7 @@ def __init__(self, name, dataType, nullable=True, metadata=None): False """ assert isinstance(dataType, DataType), "dataType should be DataType" + assert isinstance(name, basestring), "field name should be string" if not isinstance(name, str): name = name.encode('utf-8') self.name = name @@ -485,8 +486,8 @@ def add(self, field, data_type=None, nullable=True, metadata=None): DataType object. >>> struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) - >>> struct2 = StructType([StructField("f1", StringType(), True),\ - StructField("f2", StringType(), True, None)]) + >>> struct2 = StructType([StructField("f1", StringType(), True), \\ + ... StructField("f2", StringType(), True, None)]) >>> struct1 == struct2 True >>> struct1 = StructType().add(StructField("f1", StringType(), True)) @@ -581,6 +582,8 @@ def toInternal(self, obj): else: if isinstance(obj, dict): return tuple(obj.get(n) for n in self.names) + elif isinstance(obj, Row) and getattr(obj, "__from_dict__", False): + return tuple(obj[n] for n in self.names) elif isinstance(obj, (list, tuple)): return tuple(obj) elif hasattr(obj, "__dict__"): @@ -647,10 +650,13 @@ def _cachedSqlType(cls): return cls._cached_sql_type def toInternal(self, obj): - return self._cachedSqlType().toInternal(self.serialize(obj)) + if obj is not None: + return self._cachedSqlType().toInternal(self.serialize(obj)) def fromInternal(self, obj): - return self.deserialize(self._cachedSqlType().fromInternal(obj)) + v = self._cachedSqlType().fromInternal(obj) + if v is not None: + return self.deserialize(v) def serialize(self, obj): """ @@ -782,9 +788,10 @@ def _parse_struct_fields_string(s): def _parse_datatype_string(s): """ Parses the given data type string to a :class:`DataType`. The data type string format equals - to `DataType.simpleString`, except that top level struct type can omit the `struct<>` and - atomic types use `typeName()` as their format, e.g. use `byte` instead of `tinyint` for - ByteType. We can also use `int` as a short name for IntegerType. + to :class:`DataType.simpleString`, except that top level struct type can omit + the ``struct<>`` and atomic types use ``typeName()`` as their format, e.g. use ``byte`` instead + of ``tinyint`` for :class:`ByteType`. We can also use ``int`` as a short name + for :class:`IntegerType`. 
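The `toInternal`/`fromInternal` change above adds a None guard to user-defined types so that nullable UDT columns no longer fail on missing values. A stripped-down, pure-Python sketch of the same pattern; the class names are made up and the SQL-type round trip of the real `UserDefinedType` is elided for brevity.

```python
# Pure-Python stand-in (not pyspark's UserDefinedType) showing why the None
# guards matter: a nullable UDT column can carry None, and the converters
# should pass it through rather than call serialize/deserialize on it.
class Point(object):
    def __init__(self, x, y):
        self.x, self.y = x, y


class PointUDTSketch(object):
    def serialize(self, obj):
        return [obj.x, obj.y]

    def deserialize(self, datum):
        return Point(datum[0], datum[1])

    def toInternal(self, obj):
        if obj is not None:        # None stays None instead of raising
            return self.serialize(obj)

    def fromInternal(self, obj):
        if obj is not None:
            return self.deserialize(obj)


udt = PointUDTSketch()
assert udt.toInternal(None) is None
assert udt.fromInternal(None) is None
assert udt.fromInternal([1, 2]).x == 1
```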
>>> _parse_datatype_string("int ") IntegerType @@ -1045,7 +1052,7 @@ def _need_converter(dataType): def _create_converter(dataType): - """Create an converter to drop the names of fields in obj """ + """Create a converter to drop the names of fields in obj """ if not _need_converter(dataType): return lambda x: x @@ -1238,7 +1245,7 @@ def _infer_schema_type(obj, dataType): TimestampType: (datetime.datetime,), ArrayType: (list, tuple, array), MapType: (dict,), - StructType: (tuple, list), + StructType: (tuple, list, dict), } @@ -1309,10 +1316,10 @@ def _verify_type(obj, dataType, nullable=True): assert _type in _acceptable_types, "unknown datatype: %s for object %r" % (dataType, obj) if _type is StructType: - if not isinstance(obj, (tuple, list)): - raise TypeError("StructType can not accept object %r in type %s" % (obj, type(obj))) + # check the type and fields later + pass else: - # subclass of them can not be fromInternald in JVM + # subclass of them can not be fromInternal in JVM if type(obj) not in _acceptable_types[_type]: raise TypeError("%s can not accept object %r in type %s" % (dataType, obj, type(obj))) @@ -1338,11 +1345,25 @@ def _verify_type(obj, dataType, nullable=True): _verify_type(v, dataType.valueType, dataType.valueContainsNull) elif isinstance(dataType, StructType): - if len(obj) != len(dataType.fields): - raise ValueError("Length of object (%d) does not match with " - "length of fields (%d)" % (len(obj), len(dataType.fields))) - for v, f in zip(obj, dataType.fields): - _verify_type(v, f.dataType, f.nullable) + if isinstance(obj, dict): + for f in dataType.fields: + _verify_type(obj.get(f.name), f.dataType, f.nullable) + elif isinstance(obj, Row) and getattr(obj, "__from_dict__", False): + # the order in obj could be different than dataType.fields + for f in dataType.fields: + _verify_type(obj[f.name], f.dataType, f.nullable) + elif isinstance(obj, (tuple, list)): + if len(obj) != len(dataType.fields): + raise ValueError("Length of object (%d) does not match with " + "length of fields (%d)" % (len(obj), len(dataType.fields))) + for v, f in zip(obj, dataType.fields): + _verify_type(v, f.dataType, f.nullable) + elif hasattr(obj, "__dict__"): + d = obj.__dict__ + for f in dataType.fields: + _verify_type(d.get(f.name), f.dataType, f.nullable) + else: + raise TypeError("StructType can not accept object %r in type %s" % (obj, type(obj))) # This is used to unpickle a Row from JVM @@ -1359,7 +1380,13 @@ def _create_row(fields, values): class Row(tuple): """ - A row in L{DataFrame}. The fields in it can be accessed like attributes. + A row in L{DataFrame}. + The fields in it can be accessed: + + * like attributes (``row.key``) + * like dictionary values (``row[key]``) + + ``key in row`` will search through row keys. Row can be used to create a row object by using named arguments, the fields will be sorted by names. 
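The new `Row` docstring above documents attribute access, dict-style access, and membership tests. A quick illustration, runnable without a SparkContext and assuming only that pyspark is importable:

```python
from pyspark.sql import Row

row = Row(name="Alice", age=11)    # keyword fields are sorted by name
print(row.name)                    # attribute-style access -> 'Alice'
print(row["age"])                  # dict-style access -> 11
print("name" in row)               # True: membership checks the field names
print("wrong_key" in row)          # False

Person = Row("name", "age")        # the positional form acts as a row "class"
print("name" in Person)            # True, per the doctests above
print(Person("Alice", 11))         # Row(name='Alice', age=11)
```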
@@ -1371,6 +1398,10 @@ class Row(tuple): ('Alice', 11) >>> row.name, row.age ('Alice', 11) + >>> 'name' in row + True + >>> 'wrong_key' in row + False Row also can be used to create another Row like class, then it could be used to create Row objects, such as @@ -1378,6 +1409,10 @@ class Row(tuple): >>> Person = Row("name", "age") >>> Person + >>> 'name' in Person + True + >>> 'wrong_key' in Person + False >>> Person("Alice", 11) Row(name='Alice', age=11) """ @@ -1386,19 +1421,17 @@ def __new__(self, *args, **kwargs): if args and kwargs: raise ValueError("Can not use both args " "and kwargs to create Row") - if args: - # create row class or objects - return tuple.__new__(self, args) - - elif kwargs: + if kwargs: # create row objects names = sorted(kwargs.keys()) row = tuple.__new__(self, [kwargs[n] for n in names]) row.__fields__ = names + row.__from_dict__ = True return row else: - raise ValueError("No args or kwargs") + # create row class or objects + return tuple.__new__(self, args) def asDict(self, recursive=False): """ @@ -1431,6 +1464,12 @@ def conv(obj): else: return dict(zip(self.__fields__, self)) + def __contains__(self, item): + if hasattr(self, "__fields__"): + return item in self.__fields__ + else: + return super(Row, self).__contains__(item) + # let object acts like class def __call__(self, *args): """create new Row object""" @@ -1463,7 +1502,7 @@ def __getattr__(self, item): raise AttributeError(item) def __setattr__(self, key, value): - if key != '__fields__': + if key != '__fields__' and key != "__from_dict__": raise Exception("Row is read-only") self.__dict__[key] = value diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index cb172d21f3b1d..2a85ec01bc92a 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -45,9 +45,15 @@ class IllegalArgumentException(CapturedException): """ -class ContinuousQueryException(CapturedException): +class StreamingQueryException(CapturedException): """ - Exception that stopped a :class:`ContinuousQuery`. + Exception that stopped a :class:`StreamingQuery`. + """ + + +class QueryExecutionException(CapturedException): + """ + Failed to execute a query. """ @@ -61,10 +67,14 @@ def deco(*a, **kw): e.java_exception.getStackTrace())) if s.startswith('org.apache.spark.sql.AnalysisException: '): raise AnalysisException(s.split(': ', 1)[1], stackTrace) + if s.startswith('org.apache.spark.sql.catalyst.analysis'): + raise AnalysisException(s.split(': ', 1)[1], stackTrace) if s.startswith('org.apache.spark.sql.catalyst.parser.ParseException: '): raise ParseException(s.split(': ', 1)[1], stackTrace) - if s.startswith('org.apache.spark.sql.ContinuousQueryException: '): - raise ContinuousQueryException(s.split(': ', 1)[1], stackTrace) + if s.startswith('org.apache.spark.sql.streaming.StreamingQueryException: '): + raise StreamingQueryException(s.split(': ', 1)[1], stackTrace) + if s.startswith('org.apache.spark.sql.execution.QueryExecutionException: '): + raise QueryExecutionException(s.split(': ', 1)[1], stackTrace) if s.startswith('java.lang.IllegalArgumentException: '): raise IllegalArgumentException(s.split(': ', 1)[1], stackTrace) raise diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index 2056663872198..59977dcb435a8 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -454,7 +454,9 @@ def reduceByWindow(self, reduceFunc, invReduceFunc, windowDuration, slideDuratio This is more efficient than `invReduceFunc` is None. 
@param reduceFunc: associative and commutative reduce function - @param invReduceFunc: inverse reduce function of `reduceFunc` + @param invReduceFunc: inverse reduce function of `reduceFunc`; such that for all y, + and invertible x: + `invReduceFunc(reduceFunc(x, y), x) = y` @param windowDuration: width of the window; must be a multiple of this DStream's batching interval @param slideDuration: sliding interval of the window (i.e., the interval after which @@ -606,8 +608,8 @@ def reduceFunc(t, a, b): class TransformedDStream(DStream): """ - TransformedDStream is an DStream generated by an Python function - transforming each RDD of an DStream to another RDDs. + TransformedDStream is a DStream generated by an Python function + transforming each RDD of a DStream to another RDDs. Multiple continuous transformations of DStream can be combined into one transformation. @@ -621,7 +623,7 @@ def __init__(self, prev, func): self._jdstream_val = None # Using type() to avoid folding the functions and compacting the DStreams which is not - # not strictly a object of TransformedDStream. + # not strictly an object of TransformedDStream. # Changed here is to avoid bug in KafkaTransformedDStream when calling offsetRanges(). if (type(prev) is TransformedDStream and not prev.is_cached and not prev.is_checkpointed): diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py index 02a88699a2886..bf27d8047a753 100644 --- a/python/pyspark/streaming/kafka.py +++ b/python/pyspark/streaming/kafka.py @@ -208,13 +208,13 @@ def _printErrorMsg(sc): 1. Include the Kafka library and its dependencies with in the spark-submit command as - $ bin/spark-submit --packages org.apache.spark:spark-streaming-kafka:%s ... + $ bin/spark-submit --packages org.apache.spark:spark-streaming-kafka-0-8:%s ... 2. Download the JAR of the artifact from Maven Central http://search.maven.org/, - Group Id = org.apache.spark, Artifact Id = spark-streaming-kafka-assembly, Version = %s. + Group Id = org.apache.spark, Artifact Id = spark-streaming-kafka-0-8-assembly, Version = %s. Then, include the jar in the spark-submit command as - $ bin/spark-submit --jars ... + $ bin/spark-submit --jars ... ________________________________________________________________________________________________ @@ -228,7 +228,7 @@ class OffsetRange(object): def __init__(self, topic, partition, fromOffset, untilOffset): """ - Create a OffsetRange to represent range of offsets + Create an OffsetRange to represent range of offsets :param topic: Kafka topic name. :param partition: Kafka partition id. :param fromOffset: Inclusive starting offset. 
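For context on the renamed Kafka 0.8 artifact and the `OffsetRange` constructor above, here is a hedged PySpark sketch of reading a fixed offset range with `KafkaUtils.createRDD`. The topic name, offsets, and broker address are placeholders; an actual Kafka 0.8 broker must be running, and the job would be submitted with the `spark-streaming-kafka-0-8` package named in the error message above.

```python
from pyspark import SparkContext
from pyspark.streaming.kafka import KafkaUtils, OffsetRange

sc = SparkContext("local[2]", "offset-range-example")

# Partition 0 of a hypothetical topic, offsets [0, 100); the broker address is
# a placeholder for a running Kafka 0.8 broker.
offset_ranges = [OffsetRange("my_topic", 0, 0, 100)]
kafka_params = {"metadata.broker.list": "localhost:9092"}

rdd = KafkaUtils.createRDD(sc, kafka_params, offset_ranges)
print(rdd.count())
sc.stop()
```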
@@ -287,6 +287,9 @@ def __eq__(self, other): def __ne__(self, other): return not self.__eq__(other) + def __hash__(self): + return (self._topic, self._partition).__hash__() + class Broker(object): """ diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 148bf7e8ff5ce..5ac007cd598b9 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -41,6 +41,9 @@ else: import unittest +if sys.version >= "3": + long = int + from pyspark.context import SparkConf, SparkContext, RDD from pyspark.storagelevel import StorageLevel from pyspark.streaming.context import StreamingContext @@ -1058,7 +1061,6 @@ def test_kafka_direct_stream(self): stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams) self._validateStreamResult(sendData, stream) - @unittest.skipIf(sys.version >= "3", "long type not support") def test_kafka_direct_stream_from_offset(self): """Test the Python direct Kafka stream API with start offset specified.""" topic = self._randomTopic() @@ -1072,7 +1074,6 @@ def test_kafka_direct_stream_from_offset(self): stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams, fromOffsets) self._validateStreamResult(sendData, stream) - @unittest.skipIf(sys.version >= "3", "long type not support") def test_kafka_rdd(self): """Test the Python direct Kafka RDD API.""" topic = self._randomTopic() @@ -1085,7 +1086,6 @@ def test_kafka_rdd(self): rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges) self._validateRddResult(sendData, rdd) - @unittest.skipIf(sys.version >= "3", "long type not support") def test_kafka_rdd_with_leaders(self): """Test the Python direct Kafka RDD API with leaders.""" topic = self._randomTopic() @@ -1100,7 +1100,6 @@ def test_kafka_rdd_with_leaders(self): rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges, leaders) self._validateRddResult(sendData, rdd) - @unittest.skipIf(sys.version >= "3", "long type not support") def test_kafka_rdd_get_offsetRanges(self): """Test Python direct Kafka RDD get OffsetRanges.""" topic = self._randomTopic() @@ -1113,7 +1112,6 @@ def test_kafka_rdd_get_offsetRanges(self): rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges) self.assertEqual(offsetRanges, rdd.offsetRanges()) - @unittest.skipIf(sys.version >= "3", "long type not support") def test_kafka_direct_stream_foreach_get_offsetRanges(self): """Test the Python direct Kafka stream foreachRDD get offsetRanges.""" topic = self._randomTopic() @@ -1138,7 +1136,6 @@ def getOffsetRanges(_, rdd): self.assertEqual(offsetRanges, [OffsetRange(topic, 0, long(0), long(6))]) - @unittest.skipIf(sys.version >= "3", "long type not support") def test_kafka_direct_stream_transform_get_offsetRanges(self): """Test the Python direct Kafka stream transform get offsetRanges.""" topic = self._randomTopic() @@ -1176,7 +1173,6 @@ def test_topic_and_partition_equality(self): self.assertNotEqual(topic_and_partition_a, topic_and_partition_c) self.assertNotEqual(topic_and_partition_a, topic_and_partition_d) - @unittest.skipIf(sys.version >= "3", "long type not support") def test_kafka_direct_stream_transform_with_checkpoint(self): """Test the Python direct Kafka stream transform with checkpoint correctly recovered.""" topic = self._randomTopic() @@ -1225,7 +1221,6 @@ def setup(): finally: shutil.rmtree(tmpdir) - @unittest.skipIf(sys.version >= "3", "long type not support") def test_kafka_rdd_message_handler(self): """Test Python direct Kafka RDD MessageHandler.""" topic = self._randomTopic() @@ -1242,7 
+1237,6 @@ def getKeyAndDoubleMessage(m): messageHandler=getKeyAndDoubleMessage) self._validateRddResult({"aa": 1, "bb": 1, "cc": 2}, rdd) - @unittest.skipIf(sys.version >= "3", "long type not support") def test_kafka_direct_stream_message_handler(self): """Test the Python direct Kafka stream MessageHandler.""" topic = self._randomTopic() @@ -1357,7 +1351,7 @@ def get_output(_, rdd): dstream.foreachRDD(get_output) ssc.start() - self._utils.sendDatAndEnsureAllDataHasBeenReceived() + self._utils.sendDataAndEnsureAllDataHasBeenReceived() self.wait_for(outputBuffer, self._utils.getTotalEvents()) outputHeaders = [event[0] for event in outputBuffer] @@ -1476,13 +1470,13 @@ def search_jar(dir, name_prefix): def search_kafka_assembly_jar(): SPARK_HOME = os.environ["SPARK_HOME"] - kafka_assembly_dir = os.path.join(SPARK_HOME, "external/kafka-assembly") - jars = search_jar(kafka_assembly_dir, "spark-streaming-kafka-assembly") + kafka_assembly_dir = os.path.join(SPARK_HOME, "external/kafka-0-8-assembly") + jars = search_jar(kafka_assembly_dir, "spark-streaming-kafka-0-8-assembly") if not jars: raise Exception( ("Failed to find Spark Streaming kafka assembly jar in %s. " % kafka_assembly_dir) + "You need to build Spark with " - "'build/sbt assembly/package streaming-kafka-assembly/assembly' or " + "'build/sbt assembly/package streaming-kafka-0-8-assembly/assembly' or " "'build/mvn package' before running this test.") elif len(jars) > 1: raise Exception(("Found multiple Spark Streaming Kafka assembly JARs: %s; please " diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 97ea39dde05fa..b1e92e13a1fa3 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -1921,6 +1921,14 @@ def test_parallelize_eager_cleanup(self): post_parallalize_temp_files = os.listdir(sc._temp_dir) self.assertEqual(temp_files, post_parallalize_temp_files) + def test_set_conf(self): + # This is for an internal use case. When there is an existing SparkContext, + # SparkSession's builder needs to set configs into SparkContext's conf. + sc = SparkContext() + sc._conf.set("spark.test.SPARK16224", "SPARK16224") + self.assertEqual(sc._jsc.sc().conf().get("spark.test.SPARK16224"), "SPARK16224") + sc.stop() + def test_stop(self): sc = SparkContext() self.assertNotEqual(SparkContext._active_spark_context, None) diff --git a/repl/pom.xml b/repl/pom.xml index 0f396c9b809bd..00fadf2b73dd6 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.0.0-SNAPSHOT + 2.0.3-SNAPSHOT ../pom.xml @@ -72,6 +72,10 @@ ${scala.version} + ${jline.groupid} + jline + + org.slf4j jul-to-slf4j @@ -87,7 +91,7 @@ org.apache.spark - spark-test-tags_${scala.binary.version} + spark-tags_${scala.binary.version} org.apache.xbean @@ -161,13 +165,6 @@ scala-2.10 - - - ${jline.groupid} - jline - ${jline.version} - - diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala index 6a811adcf9b7a..e017aa42a4c18 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -201,10 +201,10 @@ class SparkILoop( if (Utils.isWindows) { // Strip any URI scheme prefix so we can add the correct path to the classpath // e.g. 
file:/C:/my/path.jar -> C:/my/path.jar - SparkILoop.getAddedJars.map { jar => new URI(jar).getPath.stripPrefix("/") } + getAddedJars().map { jar => new URI(jar).getPath.stripPrefix("/") } } else { // We need new URI(jar).getPath here for the case that `jar` includes encoded white space (%20). - SparkILoop.getAddedJars.map { jar => new URI(jar).getPath } + getAddedJars().map { jar => new URI(jar).getPath } } // work around for Scala bug val totalClassPath = addedJars.foldLeft( @@ -943,8 +943,6 @@ class SparkILoop( }) private def process(settings: Settings): Boolean = savingContextLoader { - if (getMaster() == "yarn-client") System.setProperty("SPARK_YARN_MODE", "true") - this.settings = settings createInterpreter() @@ -1003,9 +1001,9 @@ class SparkILoop( // NOTE: Must be public for visibility @DeveloperApi - def createSparkContext(): SparkContext = { + def createSparkSession(): SparkSession = { val execUri = System.getenv("SPARK_EXECUTOR_URI") - val jars = SparkILoop.getAddedJars + val jars = getAddedJars() val conf = new SparkConf() .setMaster(getMaster()) .setJars(jars) @@ -1019,22 +1017,18 @@ class SparkILoop( if (execUri != null) { conf.set("spark.executor.uri", execUri) } - sparkContext = new SparkContext(conf) - logInfo("Created spark context..") - Signaling.cancelOnInterrupt(sparkContext) - sparkContext - } - @DeveloperApi - // TODO: don't duplicate this code - def createSparkSession(): SparkSession = { - if (SparkSession.hiveClassesArePresent) { + val builder = SparkSession.builder.config(conf) + val sparkSession = if (SparkSession.hiveClassesArePresent) { logInfo("Creating Spark session with Hive support") - SparkSession.withHiveSupport(sparkContext) + builder.enableHiveSupport().getOrCreate() } else { logInfo("Creating Spark session") - new SparkSession(sparkContext) + builder.getOrCreate() } + sparkContext = sparkSession.sparkContext + Signaling.cancelOnInterrupt(sparkContext) + sparkSession } private def getMaster(): String = { @@ -1064,22 +1058,31 @@ class SparkILoop( @deprecated("Use `process` instead", "2.9.0") private def main(settings: Settings): Unit = process(settings) -} - -object SparkILoop extends Logging { - implicit def loopToInterpreter(repl: SparkILoop): SparkIMain = repl.intp - private def echo(msg: String) = Console println msg - def getAddedJars: Array[String] = { + @DeveloperApi + def getAddedJars(): Array[String] = { + val conf = new SparkConf().setMaster(getMaster()) val envJars = sys.env.get("ADD_JARS") if (envJars.isDefined) { logWarning("ADD_JARS environment variable is deprecated, use --jar spark submit argument instead") } - val propJars = sys.props.get("spark.jars").flatMap { p => if (p == "") None else Some(p) } - val jars = propJars.orElse(envJars).getOrElse("") + val jars = { + val userJars = Utils.getUserJars(conf, isShell = true) + if (userJars.isEmpty) { + envJars.getOrElse("") + } else { + userJars.mkString(",") + } + } Utils.resolveURIs(jars).split(",").filter(_.nonEmpty) } +} + +object SparkILoop extends Logging { + implicit def loopToInterpreter(repl: SparkILoop): SparkIMain = repl.intp + private def echo(msg: String) = Console println msg + // Designed primarily for use by test code: take a String with a // bunch of code, and prints out a transcript of what it would look // like if you'd just typed it into the repl. 
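The Scala 2.10 shell changes above build a `SparkSession` first (enabling Hive support only when the Hive classes are present) and derive the `SparkContext` from it. The same builder pattern, written in PySpark purely for illustration; drop `enableHiveSupport()` if Hive classes are not on the classpath.

```python
from pyspark.sql import SparkSession

# Build the session first, then take the SparkContext from it, mirroring
# `sparkContext = sparkSession.sparkContext` in the REPL code above.
spark = (SparkSession.builder
         .master("local[2]")
         .appName("repl-style-session")
         .enableHiveSupport()   # only when Hive classes are available
         .getOrCreate())
sc = spark.sparkContext
print(sc.master, sc.appName)
spark.stop()
```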
@@ -1113,7 +1116,7 @@ object SparkILoop extends Logging { if (settings.classpath.isDefault) settings.classpath.value = sys.props("java.class.path") - getAddedJars.map(jar => new URI(jar).getPath).foreach(settings.classpath.append(_)) + repl.getAddedJars().map(jar => new URI(jar).getPath).foreach(settings.classpath.append(_)) repl process settings } diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala index f1febb9497c78..29f63de8a0fa1 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala @@ -123,19 +123,14 @@ private[repl] trait SparkILoopInit { def initializeSpark() { intp.beQuietDuring { command(""" + @transient val spark = org.apache.spark.repl.Main.interp.createSparkSession() @transient val sc = { - val _sc = org.apache.spark.repl.Main.interp.createSparkContext() + val _sc = spark.sparkContext _sc.uiWebUrl.foreach(webUrl => println(s"Spark context Web UI available at ${webUrl}")) println("Spark context available as 'sc' " + s"(master = ${_sc.master}, app id = ${_sc.applicationId}).") - _sc - } - """) - command(""" - @transient val spark = { - val _session = org.apache.spark.repl.Main.interp.createSparkSession() println("Spark session available as 'spark'.") - _session + _sc } """) command("import org.apache.spark.SparkContext._") diff --git a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala index 547da8f713ac7..26b8600c32c11 100644 --- a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -107,13 +107,13 @@ class ReplSuite extends SparkFunSuite { test("simple foreach with accumulator") { val output = runInterpreter("local", """ - |val accum = sc.accumulator(0) - |sc.parallelize(1 to 10).foreach(x => accum += x) + |val accum = sc.longAccumulator + |sc.parallelize(1 to 10).foreach(x => accum.add(x)) |accum.value """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) - assertContains("res1: Int = 55", output) + assertContains("res1: Long = 55", output) } test("external vars") { @@ -233,7 +233,7 @@ class ReplSuite extends SparkFunSuite { } test("SPARK-1199 two instances of same class don't type check.") { - val output = runInterpreter("local-cluster[1,1,1024]", + val output = runInterpreter("local", """ |case class Sum(exp: String, exp2: String) |val a = Sum("A", "B") @@ -305,7 +305,7 @@ class ReplSuite extends SparkFunSuite { } test("SPARK-2632 importing a method from non serializable class and not using it.") { - val output = runInterpreter("local", + val output = runInterpreter("local-cluster[1,1,1024]", """ |class TestClass() { def testMethod = 3 } |val t = new TestClass diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala index 8e381ff6ae5a1..5dfe18ad49822 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala @@ -22,6 +22,7 @@ import java.io.File import scala.tools.nsc.GenericRunnerSettings import org.apache.spark._ +import org.apache.spark.internal.config.CATALOG_IMPLEMENTATION import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession 
import org.apache.spark.util.Utils @@ -53,9 +54,7 @@ object Main extends Logging { // Visible for testing private[repl] def doMain(args: Array[String], _interp: SparkILoop): Unit = { interp = _interp - val jars = conf.getOption("spark.jars") - .map(_.replace(",", File.pathSeparator)) - .getOrElse("") + val jars = Utils.getUserJars(conf, isShell = true).mkString(File.pathSeparator) val interpArguments = List( "-Yrepl-class-based", "-Yrepl-outdir", s"${outputDir.getAbsolutePath}", @@ -71,35 +70,45 @@ object Main extends Logging { } } - def createSparkContext(): SparkContext = { + def createSparkSession(): SparkSession = { val execUri = System.getenv("SPARK_EXECUTOR_URI") conf.setIfMissing("spark.app.name", "Spark shell") - // SparkContext will detect this configuration and register it with the RpcEnv's - // file server, setting spark.repl.class.uri to the actual URI for executors to - // use. This is sort of ugly but since executors are started as part of SparkContext - // initialization in certain cases, there's an initialization order issue that prevents - // this from being set after SparkContext is instantiated. - .set("spark.repl.class.outputDir", outputDir.getAbsolutePath()) + // SparkContext will detect this configuration and register it with the RpcEnv's + // file server, setting spark.repl.class.uri to the actual URI for executors to + // use. This is sort of ugly but since executors are started as part of SparkContext + // initialization in certain cases, there's an initialization order issue that prevents + // this from being set after SparkContext is instantiated. + conf.set("spark.repl.class.outputDir", outputDir.getAbsolutePath()) if (execUri != null) { conf.set("spark.executor.uri", execUri) } if (System.getenv("SPARK_HOME") != null) { conf.setSparkHome(System.getenv("SPARK_HOME")) } - sparkContext = new SparkContext(conf) - logInfo("Created spark context..") - Signaling.cancelOnInterrupt(sparkContext) - sparkContext - } - def createSparkSession(): SparkSession = { - if (SparkSession.hiveClassesArePresent) { - sparkSession = SparkSession.withHiveSupport(sparkContext) - logInfo("Created Spark session with Hive support") + val builder = SparkSession.builder.config(conf) + if (conf.get(CATALOG_IMPLEMENTATION.key, "hive").toLowerCase == "hive") { + if (SparkSession.hiveClassesArePresent) { + // In the case that the property is not set at all, builder's config + // does not have this value set to 'hive' yet. The original default + // behavior is that when there are hive classes, we use hive catalog. + sparkSession = builder.enableHiveSupport().getOrCreate() + logInfo("Created Spark session with Hive support") + } else { + // Need to change it back to 'in-memory' if no hive classes are found + // in the case that the property is set to hive in spark-defaults.conf + builder.config(CATALOG_IMPLEMENTATION.key, "in-memory") + sparkSession = builder.getOrCreate() + logInfo("Created Spark session") + } } else { - sparkSession = new SparkSession(sparkContext) + // In the case that the property is set but not to 'hive', the internal + // default is 'in-memory'. So the sparkSession will use in-memory catalog. 
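The branching above keys off `spark.sql.catalogImplementation` to choose between the Hive and the in-memory catalog. A PySpark sketch of setting it explicitly when building a session (in the shell this decision is made by `repl.Main` before the session exists):

```python
from pyspark.sql import SparkSession

# Force the in-memory catalog for this session; leaving the setting at "hive"
# selects the Hive catalog when Hive classes are present, as described above.
spark = (SparkSession.builder
         .master("local[2]")
         .config("spark.sql.catalogImplementation", "in-memory")
         .getOrCreate())
print([db.name for db in spark.catalog.listDatabases()])   # ['default']
spark.stop()
```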
+ sparkSession = builder.getOrCreate() logInfo("Created Spark session") } + sparkContext = sparkSession.sparkContext + Signaling.cancelOnInterrupt(sparkContext) sparkSession } diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala index d74b7965316fc..2707b0847aefc 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -36,25 +36,25 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) def initializeSpark() { intp.beQuietDuring { processLine(""" + @transient val spark = if (org.apache.spark.repl.Main.sparkSession != null) { + org.apache.spark.repl.Main.sparkSession + } else { + org.apache.spark.repl.Main.createSparkSession() + } @transient val sc = { - val _sc = org.apache.spark.repl.Main.createSparkContext() + val _sc = spark.sparkContext _sc.uiWebUrl.foreach(webUrl => println(s"Spark context Web UI available at ${webUrl}")) println("Spark context available as 'sc' " + s"(master = ${_sc.master}, app id = ${_sc.applicationId}).") - _sc - } - """) - processLine(""" - @transient val spark = { - val _session = org.apache.spark.repl.Main.createSparkSession() println("Spark session available as 'spark'.") - _session + _sc } """) processLine("import org.apache.spark.SparkContext._") processLine("import spark.implicits._") processLine("import spark.sql") processLine("import org.apache.spark.sql.functions._") + replayCommandStack = Nil // remove above commands from session history. } } @@ -75,7 +75,8 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) echo("Type :help for more information.") } - private val blockedCommands = Set("implicits", "javap", "power", "type", "kind", "reset") + /** Add repl commands that needs to be blocked. e.g. reset */ + private val blockedCommands = Set[String]() /** Standard commands */ lazy val sparkStandardCommands: List[SparkILoop.this.LoopCommand] = @@ -93,6 +94,12 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) initializeSpark() super.loadFiles(settings) } + + override def resetCommand(line: String): Unit = { + super.resetCommand(line) + initializeSpark() + echo("Note that after :reset, state of SparkSession and SparkContext is unchanged.") + } } object SparkILoop { diff --git a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala index af82e7a111fa8..f7d7a4f041315 100644 --- a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -21,9 +21,11 @@ import java.io._ import java.net.URLClassLoader import scala.collection.mutable.ArrayBuffer - import org.apache.commons.lang3.StringEscapeUtils +import org.apache.log4j.{Level, LogManager} import org.apache.spark.{SparkContext, SparkFunSuite} +import org.apache.spark.internal.config._ +import org.apache.spark.sql.SparkSession import org.apache.spark.util.Utils class ReplSuite extends SparkFunSuite { @@ -47,7 +49,8 @@ class ReplSuite extends SparkFunSuite { val oldExecutorClasspath = System.getProperty(CONF_EXECUTOR_CLASSPATH) System.setProperty(CONF_EXECUTOR_CLASSPATH, classpath) - + Main.sparkContext = null + Main.sparkSession = null // causes recreation of SparkContext for each test. 
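The `initializeSpark` block above reuses `Main.sparkSession` when one already exists and only calls `createSparkSession()` otherwise. PySpark's builder behaves analogously, which the following sketch illustrates; the identity check reflects the expected reuse, stated here as an assumption rather than a guarantee.

```python
from pyspark.sql import SparkSession

first = SparkSession.builder.master("local[2]").appName("reuse").getOrCreate()
second = SparkSession.builder.getOrCreate()
print(first is second)   # True: the already-created session is handed back
first.stop()
```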
Main.conf.set("spark.master", master) Main.doMain(Array("-classpath", classpath), new SparkILoop(in, new PrintWriter(out))) @@ -99,16 +102,62 @@ class ReplSuite extends SparkFunSuite { System.clearProperty("spark.driver.port") } + test("SPARK-15236: use Hive catalog") { + // turn on the INFO log so that it is possible the code will dump INFO + // entry for using "HiveMetastore" + val rootLogger = LogManager.getRootLogger() + val logLevel = rootLogger.getLevel + rootLogger.setLevel(Level.INFO) + try { + Main.conf.set(CATALOG_IMPLEMENTATION.key, "hive") + val output = runInterpreter("local", + """ + |spark.sql("drop table if exists t_15236") + """.stripMargin) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + // only when the config is set to hive and + // hive classes are built, we will use hive catalog. + // Then log INFO entry will show things using HiveMetastore + if (SparkSession.hiveClassesArePresent) { + assertContains("HiveMetaStore", output) + } else { + // If hive classes are not built, in-memory catalog will be used + assertDoesNotContain("HiveMetaStore", output) + } + } finally { + rootLogger.setLevel(logLevel) + } + } + + test("SPARK-15236: use in-memory catalog") { + val rootLogger = LogManager.getRootLogger() + val logLevel = rootLogger.getLevel + rootLogger.setLevel(Level.INFO) + try { + Main.conf.set(CATALOG_IMPLEMENTATION.key, "in-memory") + val output = runInterpreter("local", + """ + |spark.sql("drop table if exists t_16236") + """.stripMargin) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + assertDoesNotContain("HiveMetaStore", output) + } finally { + rootLogger.setLevel(logLevel) + } + } + test("simple foreach with accumulator") { val output = runInterpreter("local", """ - |val accum = sc.accumulator(0) - |sc.parallelize(1 to 10).foreach(x => accum += x) + |val accum = sc.longAccumulator + |sc.parallelize(1 to 10).foreach(x => accum.add(x)) |accum.value """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) - assertContains("res1: Int = 55", output) + assertContains("res1: Long = 55", output) } test("external vars") { @@ -228,7 +277,7 @@ class ReplSuite extends SparkFunSuite { } test("SPARK-1199 two instances of same class don't type check.") { - val output = runInterpreter("local-cluster[1,1,1024]", + val output = runInterpreter("local", """ |case class Sum(exp: String, exp2: String) |val a = Sum("A", "B") @@ -288,7 +337,7 @@ class ReplSuite extends SparkFunSuite { } test("SPARK-2632 importing a method from non serializable class and not using it.") { - val output = runInterpreter("local", + val output = runInterpreter("local-cluster[1,1,1024]", """ |class TestClass() { def testMethod = 3 } |val t = new TestClass @@ -347,6 +396,29 @@ class ReplSuite extends SparkFunSuite { assertContains("ret: Array[(Int, Iterable[Foo])] = Array((1,", output) } + test("replicating blocks of object with class defined in repl") { + val output = runInterpreter("local-cluster[2,1,1024]", + """ + |val timeout = 60000 // 60 seconds + |val start = System.currentTimeMillis + |while(sc.getExecutorStorageStatus.size != 3 && + | (System.currentTimeMillis - start) < timeout) { + | Thread.sleep(10) + |} + |if (System.currentTimeMillis - start >= timeout) { + | throw new java.util.concurrent.TimeoutException("Executors were not up in 60 seconds") + |} + |import org.apache.spark.storage.StorageLevel._ + |case class Foo(i: Int) + |val ret = sc.parallelize((1 to 100).map(Foo), 
10).persist(MEMORY_AND_DISK_2) + |ret.count() + |sc.getExecutorStorageStatus.map(s => s.rddBlocksById(ret.id).size).sum + """.stripMargin) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + assertContains(": Int = 20", output) + } + test("line wrapper only initialized once when used as encoder outer scope") { val output = runInterpreter("local", """ diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala index 4a15d52b570a4..2f07395edf8d1 100644 --- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala +++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala @@ -79,13 +79,7 @@ class ExecutorClassLoader( case e: ClassNotFoundException => val classOption = findClassLocally(name) classOption match { - case None => - // If this class has a cause, it will break the internal assumption of Janino - // (the compiler used for Spark SQL code-gen). - // See org.codehaus.janino.ClassLoaderIClassLoader's findIClass, you will see - // its behavior will be changed if there is a cause and the compilation - // of generated class will fail. - throw new ClassNotFoundException(name) + case None => throw new ClassNotFoundException(name, e) case Some(a) => a } } diff --git a/sbin/spark-config.sh b/sbin/spark-config.sh index 97df433a0b675..b7284487c511d 100755 --- a/sbin/spark-config.sh +++ b/sbin/spark-config.sh @@ -26,5 +26,8 @@ fi export SPARK_CONF_DIR="${SPARK_CONF_DIR:-"${SPARK_HOME}/conf"}" # Add the PySpark classes to the PYTHONPATH: -export PYTHONPATH="${SPARK_HOME}/python:${PYTHONPATH}" -export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.9.2-src.zip:${PYTHONPATH}" +if [ -z "${PYSPARK_PYTHONPATH_SET}" ]; then + export PYTHONPATH="${SPARK_HOME}/python:${PYTHONPATH}" + export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.3-src.zip:${PYTHONPATH}" + export PYSPARK_PYTHONPATH_SET=1 +fi diff --git a/sbin/spark-daemon.sh b/sbin/spark-daemon.sh index 6ab57df409529..59823571124f1 100755 --- a/sbin/spark-daemon.sh +++ b/sbin/spark-daemon.sh @@ -162,6 +162,16 @@ run_command() { esac echo "$newpid" > "$pid" + + #Poll for up to 5 seconds for the java process to start + for i in {1..10} + do + if [[ $(ps -p "$newpid" -o comm=) =~ "java" ]]; then + break + fi + sleep 0.5 + done + sleep 2 # Check if the process has died; in that case we'll tail the log so the user can see if [[ ! 
$(ps -p "$newpid" -o comm=) =~ "java" ]]; then diff --git a/sbin/start-master.sh b/sbin/start-master.sh index ce7f17795997e..d970fcc45e2c1 100755 --- a/sbin/start-master.sh +++ b/sbin/start-master.sh @@ -47,8 +47,8 @@ if [ "$SPARK_MASTER_PORT" = "" ]; then SPARK_MASTER_PORT=7077 fi -if [ "$SPARK_MASTER_IP" = "" ]; then - SPARK_MASTER_IP=`hostname` +if [ "$SPARK_MASTER_HOST" = "" ]; then + SPARK_MASTER_HOST=`hostname -f` fi if [ "$SPARK_MASTER_WEBUI_PORT" = "" ]; then @@ -56,5 +56,5 @@ if [ "$SPARK_MASTER_WEBUI_PORT" = "" ]; then fi "${SPARK_HOME}/sbin"/spark-daemon.sh start $CLASS 1 \ - --ip $SPARK_MASTER_IP --port $SPARK_MASTER_PORT --webui-port $SPARK_MASTER_WEBUI_PORT \ + --host $SPARK_MASTER_HOST --port $SPARK_MASTER_PORT --webui-port $SPARK_MASTER_WEBUI_PORT \ $ORIGINAL_ARGS diff --git a/sbin/start-mesos-dispatcher.sh b/sbin/start-mesos-dispatcher.sh index 06a966d1c20b4..ef65fb9539146 100755 --- a/sbin/start-mesos-dispatcher.sh +++ b/sbin/start-mesos-dispatcher.sh @@ -34,7 +34,7 @@ if [ "$SPARK_MESOS_DISPATCHER_PORT" = "" ]; then fi if [ "$SPARK_MESOS_DISPATCHER_HOST" = "" ]; then - SPARK_MESOS_DISPATCHER_HOST=`hostname` + SPARK_MESOS_DISPATCHER_HOST=`hostname -f` fi if [ "$SPARK_MESOS_DISPATCHER_NUM" = "" ]; then diff --git a/sbin/start-slaves.sh b/sbin/start-slaves.sh index 5bf2b83b42ce4..7d8871251f81b 100755 --- a/sbin/start-slaves.sh +++ b/sbin/start-slaves.sh @@ -31,9 +31,9 @@ if [ "$SPARK_MASTER_PORT" = "" ]; then SPARK_MASTER_PORT=7077 fi -if [ "$SPARK_MASTER_IP" = "" ]; then - SPARK_MASTER_IP="`hostname`" +if [ "$SPARK_MASTER_HOST" = "" ]; then + SPARK_MASTER_HOST="`hostname -f`" fi # Launch the slaves -"${SPARK_HOME}/sbin/slaves.sh" cd "${SPARK_HOME}" \; "${SPARK_HOME}/sbin/start-slave.sh" "spark://$SPARK_MASTER_IP:$SPARK_MASTER_PORT" +"${SPARK_HOME}/sbin/slaves.sh" cd "${SPARK_HOME}" \; "${SPARK_HOME}/sbin/start-slave.sh" "spark://$SPARK_MASTER_HOST:$SPARK_MASTER_PORT" diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 270104f85b838..81d57d723a720 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -200,6 +200,7 @@ This file is divided into 3 sections: // scalastyle:off awaitresult Await.result(...) // scalastyle:on awaitresult + If your codes use ThreadLocal and may run in threads created by the user, use ThreadUtils.awaitResultInForkJoinSafely instead. ]]> @@ -210,6 +211,12 @@ This file is divided into 3 sections: scala.collection.JavaConverters._ and use .asScala / .asJava methods + + org\.apache\.commons\.lang\. + Use Commons Lang 3 classes (package org.apache.commons.lang3.*) instead + of Commons Lang 2 (package org.apache.commons.lang.*) + + java,scala,3rdParty,spark @@ -244,6 +251,14 @@ This file is divided into 3 sections: Omit braces in case clauses. + + + ^Override$ + override modifier should be used instead of @java.lang.Override. + + + + diff --git a/sql/README.md b/sql/README.md index b0903980a59f3..58e9097ed4db1 100644 --- a/sql/README.md +++ b/sql/README.md @@ -1,83 +1,10 @@ Spark SQL ========= -This module provides support for executing relational queries expressed in either SQL or a LINQ-like Scala DSL. +This module provides support for executing relational queries expressed in either SQL or the DataFrame/Dataset API. Spark SQL is broken up into four subprojects: - Catalyst (sql/catalyst) - An implementation-agnostic framework for manipulating trees of relational operators and expressions. - Execution (sql/core) - A query planner / execution engine for translating Catalyst's logical query plans into Spark RDDs. 
This component also includes a new public interface, SQLContext, that allows users to execute SQL or LINQ statements against existing RDDs and Parquet files. - Hive Support (sql/hive) - Includes an extension of SQLContext called HiveContext that allows users to write queries using a subset of HiveQL and access data from a Hive Metastore using Hive SerDes. There are also wrappers that allows users to run queries that include Hive UDFs, UDAFs, and UDTFs. - HiveServer and CLI support (sql/hive-thriftserver) - Includes support for the SQL CLI (bin/spark-sql) and a HiveServer2 (for JDBC/ODBC) compatible server. - - -Other dependencies for developers ---------------------------------- -In order to create new hive test cases (i.e. a test suite based on `HiveComparisonTest`), -you will need to setup your development environment based on the following instructions. - -If you are working with Hive 0.12.0, you will need to set several environmental variables as follows. - -``` -export HIVE_HOME="/hive/build/dist" -export HIVE_DEV_HOME="/hive/" -export HADOOP_HOME="/hadoop" -``` - -If you are working with Hive 0.13.1, the following steps are needed: - -1. Download Hive's [0.13.1](https://archive.apache.org/dist/hive/hive-0.13.1) and set `HIVE_HOME` with `export HIVE_HOME=""`. Please do not set `HIVE_DEV_HOME` (See [SPARK-4119](https://issues.apache.org/jira/browse/SPARK-4119)). -2. Set `HADOOP_HOME` with `export HADOOP_HOME=""` -3. Download all Hive 0.13.1a jars (Hive jars actually used by Spark) from [here](http://mvnrepository.com/artifact/org.spark-project.hive) and replace corresponding original 0.13.1 jars in `$HIVE_HOME/lib`. -4. Download [Kryo 2.21 jar](http://mvnrepository.com/artifact/com.esotericsoftware.kryo/kryo/2.21) (Note: 2.22 jar does not work) and [Javolution 5.5.1 jar](http://mvnrepository.com/artifact/javolution/javolution/5.5.1) to `$HIVE_HOME/lib`. -5. This step is optional. But, when generating golden answer files, if a Hive query fails and you find that Hive tries to talk to HDFS or you find weird runtime NPEs, set the following in your test suite... - -``` -val testTempDir = Utils.createTempDir() -// We have to use kryo to let Hive correctly serialize some plans. -sql("set hive.plan.serialization.format=kryo") -// Explicitly set fs to local fs. -sql(s"set fs.default.name=file://$testTempDir/") -// Ask Hive to run jobs in-process as a single map and reduce task. -sql("set mapred.job.tracker=local") -``` - -Using the console -================= -An interactive scala console can be invoked by running `build/sbt hive/console`. -From here you can execute queries with HiveQl and manipulate DataFrame by using DSL. - -```scala -$ build/sbt hive/console - -[info] Starting scala interpreter... -import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.dsl._ -import org.apache.spark.sql.catalyst.errors._ -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.execution -import org.apache.spark.sql.functions._ -import org.apache.spark.sql.hive._ -import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.hive.test.TestHive.implicits._ -import org.apache.spark.sql.types._ -Type in expressions to have them evaluated. -Type :help for more information. 
- -scala> val query = sql("SELECT * FROM (SELECT * FROM src) a") -query: org.apache.spark.sql.DataFrame = [key: int, value: string] -``` - -Query results are `DataFrames` and can be operated as such. -``` -scala> query.collect() -res0: Array[org.apache.spark.sql.Row] = Array([238,val_238], [86,val_86], [311,val_311], [27,val_27]... -``` - -You can also build further queries on top of these `DataFrames` using the query DSL. -``` -scala> query.where(query("key") > 30).select(avg(query("key"))).collect() -res1: Array[org.apache.spark.sql.Row] = Array([274.79025423728814]) -``` diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 1748fa2778d6a..4b039122e9151 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.0.0-SNAPSHOT + 2.0.3-SNAPSHOT ../../pom.xml @@ -55,7 +55,7 @@ org.apache.spark - spark-test-tags_${scala.binary.version} + spark-tags_${scala.binary.version} org.apache.spark diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 4d5d125ecdd7e..e405d478ca334 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -16,6 +16,30 @@ grammar SqlBase; +@members { + /** + * Verify whether current token is a valid decimal token (which contains dot). + * Returns true if the character that follows the token is not a digit or letter or underscore. + * + * For example: + * For char stream "2.3", "2." is not a valid decimal token, because it is followed by digit '3'. + * For char stream "2.3_", "2.3" is not a valid decimal token, because it is followed by '_'. + * For char stream "2.3W", "2.3" is not a valid decimal token, because it is followed by 'W'. + * For char stream "12.0D 34.E2+0.12 " 12.0D is a valid decimal token because it is folllowed + * by a space. 34.E2 is a valid decimal token because it is followed by symbol '+' + * which is not a digit or letter or underscore. + */ + public boolean isValidDecimal() { + int nextChar = _input.LA(1); + if (nextChar >= 'A' && nextChar <= 'Z' || nextChar >= '0' && nextChar <= '9' || + nextChar == '_') { + return false; + } else { + return true; + } + } +} + tokens { DELIMITER } @@ -45,7 +69,9 @@ statement | ALTER DATABASE identifier SET DBPROPERTIES tablePropertyList #setDatabaseProperties | DROP DATABASE (IF EXISTS)? identifier (RESTRICT | CASCADE)? #dropDatabase | createTableHeader ('(' colTypeList ')')? tableProvider - (OPTIONS tablePropertyList)? #createTableUsing + (OPTIONS tablePropertyList)? + (PARTITIONED BY partitionColumnNames=identifierList)? + bucketSpec? #createTableUsing | createTableHeader tableProvider (OPTIONS tablePropertyList)? (PARTITIONED BY partitionColumnNames=identifierList)? @@ -81,28 +107,22 @@ statement DROP (IF EXISTS)? partitionSpec (',' partitionSpec)* PURGE? #dropTablePartitions | ALTER VIEW tableIdentifier DROP (IF EXISTS)? partitionSpec (',' partitionSpec)* #dropTablePartitions - | ALTER TABLE tableIdentifier partitionSpec? - SET FILEFORMAT fileFormat #setTableFileFormat | ALTER TABLE tableIdentifier partitionSpec? SET locationSpec #setTableLocation - | ALTER TABLE tableIdentifier partitionSpec? - CHANGE COLUMN? oldName=identifier colType - (FIRST | AFTER after=identifier)? (CASCADE | RESTRICT)? #changeColumn - | ALTER TABLE tableIdentifier partitionSpec? - ADD COLUMNS '(' colTypeList ')' (CASCADE | RESTRICT)? 
#addColumns - | ALTER TABLE tableIdentifier partitionSpec? - REPLACE COLUMNS '(' colTypeList ')' (CASCADE | RESTRICT)? #replaceColumns - | DROP TABLE (IF EXISTS)? tableIdentifier PURGE? - (FOR METADATA? REPLICATION '(' STRING ')')? #dropTable + | ALTER TABLE tableIdentifier RECOVER PARTITIONS #recoverPartitions + | DROP TABLE (IF EXISTS)? tableIdentifier PURGE? #dropTable | DROP VIEW (IF EXISTS)? tableIdentifier #dropTable - | CREATE (OR REPLACE)? VIEW (IF NOT EXISTS)? tableIdentifier + | CREATE (OR REPLACE)? TEMPORARY? VIEW (IF NOT EXISTS)? tableIdentifier identifierCommentList? (COMMENT STRING)? (PARTITIONED ON identifierList)? (TBLPROPERTIES tablePropertyList)? AS query #createView + | CREATE (OR REPLACE)? TEMPORARY VIEW + tableIdentifier ('(' colTypeList ')')? tableProvider + (OPTIONS tablePropertyList)? #createTempViewUsing | ALTER VIEW tableIdentifier AS? query #alterViewQuery | CREATE TEMPORARY? FUNCTION qualifiedName AS className=STRING (USING resource (',' resource)*)? #createFunction | DROP TEMPORARY? FUNCTION (IF EXISTS)? qualifiedName #dropFunction - | EXPLAIN explainOption* statement #explain + | EXPLAIN (LOGICAL | FORMATTED | EXTENDED | CODEGEN)? statement #explain | SHOW TABLES ((FROM | IN) db=identifier)? (LIKE? pattern=STRING)? #showTables | SHOW DATABASES (LIKE pattern=STRING)? #showDatabases @@ -111,27 +131,27 @@ statement | SHOW COLUMNS (FROM | IN) tableIdentifier ((FROM | IN) db=identifier)? #showColumns | SHOW PARTITIONS tableIdentifier partitionSpec? #showPartitions - | SHOW FUNCTIONS (LIKE? (qualifiedName | pattern=STRING))? #showFunctions + | SHOW identifier? FUNCTIONS + (LIKE? (qualifiedName | pattern=STRING))? #showFunctions + | SHOW CREATE TABLE tableIdentifier #showCreateTable | (DESC | DESCRIBE) FUNCTION EXTENDED? describeFuncName #describeFunction + | (DESC | DESCRIBE) DATABASE EXTENDED? identifier #describeDatabase | (DESC | DESCRIBE) option=(EXTENDED | FORMATTED)? tableIdentifier partitionSpec? describeColName? #describeTable - | (DESC | DESCRIBE) DATABASE EXTENDED? identifier #describeDatabase | REFRESH TABLE tableIdentifier #refreshTable - | CACHE LAZY? TABLE identifier (AS? query)? #cacheTable - | UNCACHE TABLE identifier #uncacheTable + | REFRESH .*? #refreshResource + | CACHE LAZY? TABLE tableIdentifier (AS? query)? #cacheTable + | UNCACHE TABLE tableIdentifier #uncacheTable | CLEAR CACHE #clearCache | LOAD DATA LOCAL? INPATH path=STRING OVERWRITE? INTO TABLE - tableIdentifier partitionSpec? #loadData - | ADD identifier .*? #addResource + tableIdentifier partitionSpec? #loadData + | TRUNCATE TABLE tableIdentifier partitionSpec? #truncateTable + | MSCK REPAIR TABLE tableIdentifier #repairTable + | op=(ADD | LIST) identifier .*? #manageResource | SET ROLE .*? #failNativeCommand | SET .*? #setConfiguration - | kws=unsupportedHiveNativeCommands .*? #failNativeCommand - | hiveNativeCommands #executeNativeCommand - ; - -hiveNativeCommands - : TRUNCATE TABLE tableIdentifier partitionSpec? - (COLUMNS identifierList)? + | RESET #resetConfiguration + | unsupportedHiveNativeCommands .*? 
#failNativeCommand ; unsupportedHiveNativeCommands @@ -160,7 +180,6 @@ unsupportedHiveNativeCommands | kw1=UNLOCK kw2=DATABASE | kw1=CREATE kw2=TEMPORARY kw3=MACRO | kw1=DROP kw2=TEMPORARY kw3=MACRO - | kw1=MSCK kw2=REPAIR kw3=TABLE | kw1=ALTER kw2=TABLE tableIdentifier kw3=NOT kw4=CLUSTERED | kw1=ALTER kw2=TABLE tableIdentifier kw3=CLUSTERED kw4=BY | kw1=ALTER kw2=TABLE tableIdentifier kw3=NOT kw4=SORTED @@ -174,6 +193,10 @@ unsupportedHiveNativeCommands | kw1=ALTER kw2=TABLE tableIdentifier kw3=TOUCH | kw1=ALTER kw2=TABLE tableIdentifier partitionSpec? kw3=COMPACT | kw1=ALTER kw2=TABLE tableIdentifier partitionSpec? kw3=CONCATENATE + | kw1=ALTER kw2=TABLE tableIdentifier partitionSpec? kw3=SET kw4=FILEFORMAT + | kw1=ALTER kw2=TABLE tableIdentifier partitionSpec? kw3=ADD kw4=COLUMNS + | kw1=ALTER kw2=TABLE tableIdentifier partitionSpec? kw3=CHANGE kw4=COLUMNS? + | kw1=ALTER kw2=TABLE tableIdentifier partitionSpec? kw3=REPLACE kw4=COLUMNS | kw1=START kw2=TRANSACTION | kw1=COMMIT | kw1=ROLLBACK @@ -206,7 +229,7 @@ query ; insertInto - : INSERT OVERWRITE TABLE tableIdentifier partitionSpec? (IF NOT EXISTS)? + : INSERT OVERWRITE TABLE tableIdentifier (partitionSpec (IF NOT EXISTS)?)? | INSERT INTO TABLE? tableIdentifier partitionSpec? ; @@ -255,7 +278,7 @@ tableProperty ; tablePropertyKey - : looseIdentifier ('.' looseIdentifier)* + : identifier ('.' identifier)* | STRING ; @@ -267,22 +290,14 @@ nestedConstantList : '(' constantList (',' constantList)* ')' ; -skewedLocation - : (constant | constantList) EQ STRING - ; - -skewedLocationList - : '(' skewedLocation (',' skewedLocation)* ')' - ; - createFileFormat : STORED AS fileFormat | STORED BY storageHandler ; fileFormat - : INPUTFORMAT inFmt=STRING OUTPUTFORMAT outFmt=STRING (SERDE serdeCls=STRING)? #tableFileFormat - | identifier #genericFileFormat + : INPUTFORMAT inFmt=STRING OUTPUTFORMAT outFmt=STRING #tableFileFormat + | identifier #genericFileFormat ; storageHandler @@ -376,11 +391,12 @@ setQuantifier ; relation - : left=relation - ((CROSS | joinType) JOIN right=relation joinCriteria? - | NATURAL joinType JOIN right=relation - ) #joinRelation - | relationPrimary #relationDefault + : relationPrimary joinRelation* + ; + +joinRelation + : (CROSS | joinType) JOIN right=relationPrimary joinCriteria? + | NATURAL joinType JOIN right=relationPrimary ; joinType @@ -401,7 +417,8 @@ sample : TABLESAMPLE '(' ( (percentage=(INTEGER_VALUE | DECIMAL_VALUE) sampleType=PERCENTLIT) | (expression sampleType=ROWS) - | (sampleType=BUCKET numerator=INTEGER_VALUE OUT OF denominator=INTEGER_VALUE (ON identifier)?)) + | sampleType=BYTELENGTH_LITERAL + | (sampleType=BUCKET numerator=INTEGER_VALUE OUT OF denominator=INTEGER_VALUE (ON (identifier | qualifiedName '(' ')'))?)) ')' ; @@ -430,10 +447,11 @@ identifierComment ; relationPrimary - : tableIdentifier sample? (AS? identifier)? #tableName - | '(' queryNoWith ')' sample? (AS? identifier)? #aliasedQuery - | '(' relation ')' sample? (AS? identifier)? #aliasedRelation + : tableIdentifier sample? (AS? strictIdentifier)? #tableName + | '(' queryNoWith ')' sample? (AS? strictIdentifier)? #aliasedQuery + | '(' relation ')' sample? (AS? strictIdentifier)? #aliasedRelation | inlineTable #inlineTableDefault2 + | identifier '(' (expression (',' expression)*)? 
')' #tableValuedFunction ; inlineTable @@ -467,8 +485,8 @@ expression ; booleanExpression - : predicated #booleanDefault - | NOT booleanExpression #logicalNot + : NOT booleanExpression #logicalNot + | predicated #booleanDefault | left=booleanExpression operator=AND right=booleanExpression #logicalBinary | left=booleanExpression operator=OR right=booleanExpression #logicalBinary | EXISTS '(' query ')' #exists @@ -501,15 +519,16 @@ valueExpression ; primaryExpression - : constant #constantDefault + : name=(CURRENT_DATE | CURRENT_TIMESTAMP) #timeFunctionCall + | CASE value=expression whenClause+ (ELSE elseExpression=expression)? END #simpleCase + | CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase + | CAST '(' expression AS dataType ')' #cast + | constant #constantDefault | ASTERISK #star | qualifiedName '.' ASTERISK #star | '(' expression (',' expression)+ ')' #rowConstructor - | qualifiedName '(' (setQuantifier? expression (',' expression)*)? ')' (OVER windowSpec)? #functionCall | '(' query ')' #subqueryExpression - | CASE valueExpression whenClause+ (ELSE elseExpression=expression)? END #simpleCase - | CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase - | CAST '(' expression AS dataType ')' #cast + | qualifiedName '(' (setQuantifier? expression (',' expression)*)? ')' (OVER windowSpec)? #functionCall | value=primaryExpression '[' index=valueExpression ']' #subscript | identifier #columnReference | base=primaryExpression '.' fieldName=identifier #dereference @@ -604,30 +623,17 @@ frameBound | expression boundType=(PRECEDING | FOLLOWING) ; - -explainOption - : LOGICAL | FORMATTED | EXTENDED | CODEGEN - ; - -transactionMode - : ISOLATION LEVEL SNAPSHOT #isolationLevel - | READ accessMode=(ONLY | WRITE) #transactionAccessMode - ; - qualifiedName : identifier ('.' identifier)* ; -// Identifier that also allows the use of a number of SQL keywords (mainly for backwards compatibility). -looseIdentifier - : identifier - | FROM - | TO - | TABLE - | WITH +identifier + : strictIdentifier + | ANTI | FULL | INNER | LEFT | SEMI | RIGHT | NATURAL | JOIN | CROSS | ON + | UNION | INTERSECT | EXCEPT ; -identifier +strictIdentifier : IDENTIFIER #unquotedIdentifier | quotedIdentifier #quotedIdentifierAlternative | nonReserved #unquotedIdentifier @@ -638,13 +644,14 @@ quotedIdentifier ; number - : DECIMAL_VALUE #decimalLiteral - | SCIENTIFIC_DECIMAL_VALUE #scientificDecimalLiteral - | INTEGER_VALUE #integerLiteral - | BIGINT_LITERAL #bigIntLiteral - | SMALLINT_LITERAL #smallIntLiteral - | TINYINT_LITERAL #tinyIntLiteral - | DOUBLE_LITERAL #doubleLiteral + : MINUS? DECIMAL_VALUE #decimalLiteral + | MINUS? SCIENTIFIC_DECIMAL_VALUE #scientificDecimalLiteral + | MINUS? INTEGER_VALUE #integerLiteral + | MINUS? BIGINT_LITERAL #bigIntLiteral + | MINUS? SMALLINT_LITERAL #smallIntLiteral + | MINUS? TINYINT_LITERAL #tinyIntLiteral + | MINUS? DOUBLE_LITERAL #doubleLiteral + | MINUS? 
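A short illustration of the expression and relation changes above; `spark` is an assumed `SparkSession`. The new `relationPrimary` alternative allows a function call in the `FROM` clause, `CURRENT_DATE` / `CURRENT_TIMESTAMP` parse without parentheses, and numeric literals may carry a leading sign or a `BD` (BigDecimal) suffix.

```scala
spark.sql("SELECT id, CURRENT_DATE FROM range(0, 10, 2)")  // table-valued function in FROM
spark.sql("SELECT -2.5, 3.14BD")                            // signed literal and BigDecimal literal
```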
BIGDECIMAL_LITERAL #bigDecimalLiteral ; nonReserved @@ -657,23 +664,25 @@ nonReserved | GROUPING | CUBE | ROLLUP | EXPLAIN | FORMAT | LOGICAL | FORMATTED | CODEGEN | TABLESAMPLE | USE | TO | BUCKET | PERCENTLIT | OUT | OF - | SET + | SET | RESET | VIEW | REPLACE | IF | NO | DATA - | START | TRANSACTION | COMMIT | ROLLBACK | WORK | ISOLATION | LEVEL - | SNAPSHOT | READ | WRITE | ONLY + | START | TRANSACTION | COMMIT | ROLLBACK | SORT | CLUSTER | DISTRIBUTE | UNSET | TBLPROPERTIES | SKEWED | STORED | DIRECTORIES | LOCATION - | EXCHANGE | ARCHIVE | UNARCHIVE | FILEFORMAT | TOUCH | COMPACT | CONCATENATE | CHANGE | FIRST - | AFTER | CASCADE | RESTRICT | BUCKETS | CLUSTERED | SORTED | PURGE | INPUTFORMAT | OUTPUTFORMAT - | DBPROPERTIES | DFS | TRUNCATE | METADATA | REPLICATION | COMPUTE + | EXCHANGE | ARCHIVE | UNARCHIVE | FILEFORMAT | TOUCH | COMPACT | CONCATENATE | CHANGE + | CASCADE | RESTRICT | BUCKETS | CLUSTERED | SORTED | PURGE | INPUTFORMAT | OUTPUTFORMAT + | DBPROPERTIES | DFS | TRUNCATE | COMPUTE | LIST | STATISTICS | ANALYZE | PARTITIONED | EXTERNAL | DEFINED | RECORDWRITER - | REVOKE | GRANT | LOCK | UNLOCK | MSCK | REPAIR | EXPORT | IMPORT | LOAD | VALUES | COMMENT | ROLE + | REVOKE | GRANT | LOCK | UNLOCK | MSCK | REPAIR | RECOVER | EXPORT | IMPORT | LOAD | VALUES | COMMENT | ROLE | ROLES | COMPACTIONS | PRINCIPALS | TRANSACTIONS | INDEX | INDEXES | LOCKS | OPTION | LOCAL | INPATH | ASC | DESC | LIMIT | RENAME | SETS | AT | NULLS | OVERWRITE | ALL | ALTER | AS | BETWEEN | BY | CREATE | DELETE | DESCRIBE | DROP | EXISTS | FALSE | FOR | GROUP | IN | INSERT | INTO | IS |LIKE | NULL | ORDER | OUTER | TABLE | TRUE | WITH | RLIKE + | AND | CASE | CAST | DISTINCT | DIV | ELSE | END | FUNCTION | INTERVAL | MACRO | OR | STRATIFY | THEN + | UNBOUNDED | WHEN + | DATABASE | SELECT | FROM | WHERE | HAVING | TO | TABLE | WITH | NOT | CURRENT_DATE | CURRENT_TIMESTAMP ; SELECT: 'SELECT'; @@ -773,18 +782,12 @@ MAP: 'MAP'; STRUCT: 'STRUCT'; COMMENT: 'COMMENT'; SET: 'SET'; +RESET: 'RESET'; DATA: 'DATA'; START: 'START'; TRANSACTION: 'TRANSACTION'; COMMIT: 'COMMIT'; ROLLBACK: 'ROLLBACK'; -WORK: 'WORK'; -ISOLATION: 'ISOLATION'; -LEVEL: 'LEVEL'; -SNAPSHOT: 'SNAPSHOT'; -READ: 'READ'; -WRITE: 'WRITE'; -ONLY: 'ONLY'; MACRO: 'MACRO'; IF: 'IF'; @@ -860,8 +863,6 @@ TOUCH: 'TOUCH'; COMPACT: 'COMPACT'; CONCATENATE: 'CONCATENATE'; CHANGE: 'CHANGE'; -FIRST: 'FIRST'; -AFTER: 'AFTER'; CASCADE: 'CASCADE'; RESTRICT: 'RESTRICT'; CLUSTERED: 'CLUSTERED'; @@ -873,10 +874,9 @@ DATABASE: 'DATABASE' | 'SCHEMA'; DATABASES: 'DATABASES' | 'SCHEMAS'; DFS: 'DFS'; TRUNCATE: 'TRUNCATE'; -METADATA: 'METADATA'; -REPLICATION: 'REPLICATION'; ANALYZE: 'ANALYZE'; COMPUTE: 'COMPUTE'; +LIST: 'LIST'; STATISTICS: 'STATISTICS'; PARTITIONED: 'PARTITIONED'; EXTERNAL: 'EXTERNAL'; @@ -887,6 +887,7 @@ LOCK: 'LOCK'; UNLOCK: 'UNLOCK'; MSCK: 'MSCK'; REPAIR: 'REPAIR'; +RECOVER: 'RECOVER'; EXPORT: 'EXPORT'; IMPORT: 'IMPORT'; LOAD: 'LOAD'; @@ -902,6 +903,8 @@ OPTION: 'OPTION'; ANTI: 'ANTI'; LOCAL: 'LOCAL'; INPATH: 'INPATH'; +CURRENT_DATE: 'CURRENT_DATE'; +CURRENT_TIMESTAMP: 'CURRENT_TIMESTAMP'; STRING : '\'' ( ~('\''|'\\') | ('\\' .) )* '\'' @@ -920,23 +923,31 @@ TINYINT_LITERAL : DIGIT+ 'Y' ; +BYTELENGTH_LITERAL + : DIGIT+ ('B' | 'K' | 'M' | 'G') + ; + INTEGER_VALUE : DIGIT+ ; DECIMAL_VALUE - : DIGIT+ '.' DIGIT* - | '.' DIGIT+ + : DECIMAL_DIGITS {isValidDecimal()}? ; SCIENTIFIC_DECIMAL_VALUE - : DIGIT+ ('.' DIGIT*)? EXPONENT - | '.' DIGIT+ EXPONENT + : DIGIT+ EXPONENT + | DECIMAL_DIGITS EXPONENT {isValidDecimal()}? 
; DOUBLE_LITERAL - : - (INTEGER_VALUE | DECIMAL_VALUE | SCIENTIFIC_DECIMAL_VALUE) 'D' + : DIGIT+ EXPONENT? 'D' + | DECIMAL_DIGITS EXPONENT? 'D' {isValidDecimal()}? + ; + +BIGDECIMAL_LITERAL + : DIGIT+ EXPONENT? 'BD' + | DECIMAL_DIGITS EXPONENT? 'BD' {isValidDecimal()}? ; IDENTIFIER @@ -947,6 +958,11 @@ BACKQUOTED_IDENTIFIER : '`' ( ~'`' | '``' )* '`' ; +fragment DECIMAL_DIGITS + : DIGIT+ '.' DIGIT* + | '.' DIGIT+ + ; + fragment EXPONENT : 'E' [+-]? DIGIT+ ; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index 02a863b2bb498..6302660548ec1 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -298,6 +298,10 @@ public UnsafeMapData getMap(int ordinal) { return map; } + // This `hashCode` computation could consume much processor time for large data. + // If the computation becomes a bottleneck, we can use a light-weight logic; the first fixed bytes + // are used to compute `hashCode` (See `Vector.hashCode`). + // The same issue exists in `UnsafeRow.hashCode`. @Override public int hashCode() { return Murmur3_x86_32.hashUnsafeBytes(baseObject, baseOffset, sizeInBytes, 42); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index dd2f39eb816f2..9027652d57f14 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -31,6 +31,7 @@ import com.esotericsoftware.kryo.io.Input; import com.esotericsoftware.kryo.io.Output; +import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; @@ -577,8 +578,12 @@ public boolean equals(Object other) { return (sizeInBytes == o.sizeInBytes) && ByteArrayMethods.arrayEquals(baseObject, baseOffset, o.baseObject, o.baseOffset, sizeInBytes); + } else if (!(other instanceof InternalRow)) { + return false; + } else { + throw new IllegalArgumentException( + "Cannot compare UnsafeRow to " + other.getClass().getName()); } - return false; } /** diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java index af61e2011f400..0e4264fe8dfb5 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java @@ -45,7 +45,13 @@ public BufferHolder(UnsafeRow row) { } public BufferHolder(UnsafeRow row, int initialSize) { - this.fixedSize = UnsafeRow.calculateBitSetWidthInBytes(row.numFields()) + 8 * row.numFields(); + int bitsetWidthInBytes = UnsafeRow.calculateBitSetWidthInBytes(row.numFields()); + if (row.numFields() > (Integer.MAX_VALUE - initialSize - bitsetWidthInBytes) / 8) { + throw new UnsupportedOperationException( + "Cannot create BufferHolder for input UnsafeRow because there are " + + "too many fields (number of fields: " + row.numFields() + ")"); + } + this.fixedSize = bitsetWidthInBytes + 8 * row.numFields(); this.buffer 
= new byte[fixedSize + initialSize]; this.row = row; this.row.pointTo(buffer, buffer.length); @@ -55,10 +61,16 @@ public BufferHolder(UnsafeRow row, int initialSize) { * Grows the buffer by at least neededSize and points the row to the buffer. */ public void grow(int neededSize) { + if (neededSize > Integer.MAX_VALUE - totalSize()) { + throw new UnsupportedOperationException( + "Cannot grow BufferHolder by size " + neededSize + " because the size after growing " + + "exceeds size limitation " + Integer.MAX_VALUE); + } final int length = totalSize() + neededSize; if (buffer.length < length) { // This will not happen frequently, because the buffer is re-used. - final byte[] tmp = new byte[length * 2]; + int newLength = length < Integer.MAX_VALUE / 2 ? length * 2 : Integer.MAX_VALUE; + final byte[] tmp = new byte[newLength]; Platform.copyMemory( buffer, Platform.BYTE_ARRAY_OFFSET, diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtil.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtil.java new file mode 100644 index 0000000000000..d224332d8a6c9 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtil.java @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.xml; + +import java.io.IOException; +import java.io.Reader; + +import javax.xml.namespace.QName; +import javax.xml.xpath.XPath; +import javax.xml.xpath.XPathConstants; +import javax.xml.xpath.XPathExpression; +import javax.xml.xpath.XPathExpressionException; +import javax.xml.xpath.XPathFactory; + +import org.w3c.dom.Node; +import org.w3c.dom.NodeList; +import org.xml.sax.InputSource; + +/** + * Utility class for all XPath UDFs. Each UDF instance should keep an instance of this class. + * + * This is based on Hive's UDFXPathUtil implementation. 
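A minimal Scala sketch (not part of the patch) of the sizing policy the `BufferHolder.grow` change above follows: refuse requests that would overflow `Int.MaxValue`, otherwise double the buffer but cap it at `Int.MaxValue` so `length * 2` cannot wrap around.

```scala
def newBufferLength(totalSize: Int, neededSize: Int): Int = {
  if (neededSize > Int.MaxValue - totalSize) {
    throw new UnsupportedOperationException(
      s"Cannot grow BufferHolder by size $neededSize: result exceeds ${Int.MaxValue}")
  }
  val length = totalSize + neededSize
  // Double the buffer, but never beyond Int.MaxValue.
  if (length < Int.MaxValue / 2) length * 2 else Int.MaxValue
}
```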
+ */ +public class UDFXPathUtil { + private XPath xpath = XPathFactory.newInstance().newXPath(); + private ReusableStringReader reader = new ReusableStringReader(); + private InputSource inputSource = new InputSource(reader); + private XPathExpression expression = null; + private String oldPath = null; + + public Object eval(String xml, String path, QName qname) throws XPathExpressionException { + if (xml == null || path == null || qname == null) { + return null; + } + + if (xml.length() == 0 || path.length() == 0) { + return null; + } + + if (!path.equals(oldPath)) { + try { + expression = xpath.compile(path); + } catch (XPathExpressionException e) { + throw new RuntimeException("Invalid XPath '" + path + "'" + e.getMessage(), e); + } + oldPath = path; + } + + if (expression == null) { + return null; + } + + reader.set(xml); + try { + return expression.evaluate(inputSource, qname); + } catch (XPathExpressionException e) { + throw new RuntimeException("Invalid XML document: " + e.getMessage() + "\n" + xml, e); + } + } + + public Boolean evalBoolean(String xml, String path) throws XPathExpressionException { + return (Boolean) eval(xml, path, XPathConstants.BOOLEAN); + } + + public String evalString(String xml, String path) throws XPathExpressionException { + return (String) eval(xml, path, XPathConstants.STRING); + } + + public Double evalNumber(String xml, String path) throws XPathExpressionException { + return (Double) eval(xml, path, XPathConstants.NUMBER); + } + + public Node evalNode(String xml, String path) throws XPathExpressionException { + return (Node) eval(xml, path, XPathConstants.NODE); + } + + public NodeList evalNodeList(String xml, String path) throws XPathExpressionException { + return (NodeList) eval(xml, path, XPathConstants.NODESET); + } + + /** + * Reusable, non-threadsafe version of {@link java.io.StringReader}. 
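A small usage example for the `eval*` helpers above (illustrative only). A single `UDFXPathUtil` instance is meant to be reused across calls: it caches the compiled XPath expression and the reusable reader defined below.

```scala
import org.apache.spark.sql.catalyst.expressions.xml.UDFXPathUtil

val util = new UDFXPathUtil
val xml = "<a><b>1</b><b>2</b></a>"
val first   = util.evalString(xml, "//b[1]/text()")       // "1"
val count   = util.evalNumber(xml, "count(//b)")          // 2.0
val matched = util.evalBoolean(xml, "//b[text()='2']")    // true
```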
+ */ + public static class ReusableStringReader extends Reader { + + private String str = null; + private int length = -1; + private int next = 0; + private int mark = 0; + + public ReusableStringReader() { + } + + public void set(String s) { + this.str = s; + this.length = s.length(); + this.mark = 0; + this.next = 0; + } + + /** Check to make sure that the stream has not been closed */ + private void ensureOpen() throws IOException { + if (str == null) { + throw new IOException("Stream closed"); + } + } + + @Override + public int read() throws IOException { + ensureOpen(); + if (next >= length) { + return -1; + } + return str.charAt(next++); + } + + @Override + public int read(char[] cbuf, int off, int len) throws IOException { + ensureOpen(); + if ((off < 0) || (off > cbuf.length) || (len < 0) + || ((off + len) > cbuf.length) || ((off + len) < 0)) { + throw new IndexOutOfBoundsException(); + } else if (len == 0) { + return 0; + } + if (next >= length) { + return -1; + } + int n = Math.min(length - next, len); + str.getChars(next, next + n, cbuf, off); + next += n; + return n; + } + + @Override + public long skip(long ns) throws IOException { + ensureOpen(); + if (next >= length) { + return 0; + } + // Bound skip by beginning and end of the source + long n = Math.min(length - next, ns); + n = Math.max(-next, n); + next += n; + return n; + } + + @Override + public boolean ready() throws IOException { + ensureOpen(); + return true; + } + + @Override + public boolean markSupported() { + return true; + } + + @Override + public void mark(int readAheadLimit) throws IOException { + if (readAheadLimit < 0) { + throw new IllegalArgumentException("Read-ahead limit < 0"); + } + ensureOpen(); + mark = next; + } + + @Override + public void reset() throws IOException { + ensureOpen(); + next = mark; + } + + @Override + public void close() { + str = null; + } + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 8d9906da7e461..c9a1f2293a6c6 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -74,6 +74,8 @@ public UnsafeExternalRowSorter( prefixComparator, /* initialSize */ 4096, pageSizeBytes, + SparkEnv.get().conf().getLong("spark.shuffle.spill.numElementsForceSpillThreshold", + UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD), canUseRadixSort ); } @@ -108,6 +110,13 @@ public long getPeakMemoryUsage() { return sorter.getPeakMemoryUsedBytes(); } + /** + * @return the total amount of time spent sorting data (in-memory only). + */ + public long getSortTimeNanos() { + return sorter.getSortTimeNanos(); + } + private void cleanupResources() { sorter.cleanupResources(); } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java new file mode 100644 index 0000000000000..41e2582921198 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/OutputMode.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming; + +import org.apache.spark.annotation.Experimental; +import org.apache.spark.sql.InternalOutputModes; + +/** + * :: Experimental :: + * + * OutputMode is used to what data will be written to a streaming sink when there is + * new data available in a streaming DataFrame/Dataset. + * + * @since 2.0.0 + */ +@Experimental +public class OutputMode { + + /** + * OutputMode in which only the new rows in the streaming DataFrame/Dataset will be + * written to the sink. This output mode can be only be used in queries that do not + * contain any aggregation. + * + * @since 2.0.0 + */ + public static OutputMode Append() { + return InternalOutputModes.Append$.MODULE$; + } + + /** + * OutputMode in which all the rows in the streaming DataFrame/Dataset will be written + * to the sink every time these is some updates. This output mode can only be used in queries + * that contain aggregations. + * + * @since 2.0.0 + */ + public static OutputMode Complete() { + return InternalOutputModes.Complete$.MODULE$; + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala index d2003fd6892e1..d3ee8f7a77445 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala @@ -32,8 +32,9 @@ class AnalysisException protected[sql] ( val message: String, val line: Option[Int] = None, val startPosition: Option[Int] = None, - val plan: Option[LogicalPlan] = None) - extends Exception with Serializable { + val plan: Option[LogicalPlan] = None, + val cause: Option[Throwable] = None) + extends Exception(message, cause.orNull) with Serializable { def withPosition(line: Option[Int], startPosition: Option[Int]): AnalysisException = { val newException = new AnalysisException(message, line, startPosition) @@ -42,6 +43,13 @@ class AnalysisException protected[sql] ( } override def getMessage: String = { + val planAnnotation = plan.map(p => s";\n$p").getOrElse("") + getSimpleMessage + planAnnotation + } + + // Outputs an exception without the logical plan. + // For testing only + def getSimpleMessage: String = { val lineAnnotation = line.map(l => s" line $l").getOrElse("") val positionAnnotation = startPosition.map(p => s" pos $p").getOrElse("") s"$message;$lineAnnotation$positionAnnotation" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index ffa694fcdc07a..501c1304dbedb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -29,13 +29,13 @@ import org.apache.spark.sql.types._ * Used to convert a JVM object of type `T` to and from the internal Spark SQL representation. 
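A hypothetical sketch of the `AnalysisException` change above: the new `cause` parameter is chained into the normal `Throwable` hierarchy, and `getSimpleMessage` keeps the short form that `getMessage` extends with the plan. The example lives in the `org.apache.spark.sql` package because the constructor is `protected[sql]`.

```scala
package org.apache.spark.sql

object AnalysisExceptionExample {
  def main(args: Array[String]): Unit = {
    val root = new IllegalStateException("metastore lookup failed")
    val ae = new AnalysisException("cannot resolve relation `t`", cause = Some(root))
    assert(ae.getCause eq root)                                   // cause is reachable
    assert(ae.getSimpleMessage == "cannot resolve relation `t`;") // short form, no plan
  }
}
```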
* * == Scala == - * Encoders are generally created automatically through implicits from a `SQLContext`, or can be + * Encoders are generally created automatically through implicits from a `SparkSession`, or can be * explicitly created by calling static methods on [[Encoders]]. * * {{{ - * import sqlContext.implicits._ + * import spark.implicits._ * - * val ds = Seq(1, 2, 3).toDS() // implicitly provided (sqlContext.implicits.newIntEncoder) + * val ds = Seq(1, 2, 3).toDS() // implicitly provided (spark.implicits.newIntEncoder) * }}} * * == Java == @@ -69,7 +69,7 @@ import org.apache.spark.sql.types._ @Experimental @implicitNotFound("Unable to find encoder for type stored in a Dataset. Primitive types " + "(Int, String, etc) and Product types (case classes) are supported by importing " + - "sqlContext.implicits._ Support for serializing other types will be added in future " + + "spark.implicits._ Support for serializing other types will be added in future " + "releases.") trait Encoder[T] extends Serializable { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala index 3f4df704db755..e72f67c48a296 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala @@ -23,8 +23,10 @@ import scala.reflect.{classTag, ClassTag} import scala.reflect.runtime.universe.TypeTag import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} -import org.apache.spark.sql.catalyst.expressions.{BoundReference, DecodeUsingSerializer, EncodeUsingSerializer} +import org.apache.spark.sql.catalyst.expressions.{BoundReference, Cast} +import org.apache.spark.sql.catalyst.expressions.objects.{DecodeUsingSerializer, EncodeUsingSerializer} import org.apache.spark.sql.types._ /** @@ -207,7 +209,9 @@ object Encoders { BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true), kryo = useKryo)), deserializer = DecodeUsingSerializer[T]( - BoundReference(0, BinaryType, nullable = true), classTag[T], kryo = useKryo), + Cast(GetColumnByOrdinal(0, BinaryType), BinaryType), + classTag[T], + kryo = useKryo), clsTag = classTag[T] ) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/InternalOutputModes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/InternalOutputModes.scala new file mode 100644 index 0000000000000..153f9f57faf42 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/InternalOutputModes.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
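For context on the `Encoders` change above, a hedged usage sketch: `spark` is an assumed `SparkSession`, and `Point` is a hypothetical class with no built-in encoder. The patch only changes how the deserializer locates the binary column (`GetColumnByOrdinal` instead of `BoundReference`); usage of the Kryo encoder is unchanged.

```scala
import org.apache.spark.sql.{Dataset, Encoder, Encoders}

class Point(val x: Double, val y: Double) extends Serializable

implicit val pointEncoder: Encoder[Point] = Encoders.kryo[Point]
val points: Dataset[Point] =
  spark.createDataset(Seq(new Point(1.0, 2.0), new Point(3.0, 4.0)))
```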
+ */ + +package org.apache.spark.sql + +import org.apache.spark.sql.streaming.OutputMode + +/** + * Internal helper class to generate objects representing various [[OutputMode]]s, + */ +private[sql] object InternalOutputModes { + + /** + * OutputMode in which only the new rows in the streaming DataFrame/Dataset will be + * written to the sink. This output mode can be only be used in queries that do not + * contain any aggregation. + */ + case object Append extends OutputMode + + /** + * OutputMode in which all the rows in the streaming DataFrame/Dataset will be written + * to the sink every time these is some updates. This output mode can only be used in queries + * that contain aggregations. + */ + case object Complete extends OutputMode + + /** + * OutputMode in which only the rows in the streaming DataFrame/Dataset that were updated will be + * written to the sink every time these is some updates. This output mode can only be used in + * queries that contain aggregations. + */ + case object Update extends OutputMode +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index 726291b96c29d..9b0a3c7e4ce5c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -151,7 +151,7 @@ trait Row extends Serializable { * BinaryType -> byte array * ArrayType -> scala.collection.Seq (use getList for java.util.List) * MapType -> scala.collection.Map (use getJavaMap for java.util.Map) - * StructType -> org.apache.spark.sql.Row (or Product) + * StructType -> org.apache.spark.sql.Row * }}} */ def apply(i: Int): Any = get(i) @@ -176,7 +176,7 @@ trait Row extends Serializable { * BinaryType -> byte array * ArrayType -> scala.collection.Seq (use getList for java.util.List) * MapType -> scala.collection.Map (use getJavaMap for java.util.Map) - * StructType -> org.apache.spark.sql.Row (or Product) + * StructType -> org.apache.spark.sql.Row * }}} */ def get(i: Int): Any @@ -300,19 +300,11 @@ trait Row extends Serializable { getMap[K, V](i).asJava /** - * Returns the value at position i of struct type as an [[Row]] object. + * Returns the value at position i of struct type as a [[Row]] object. * * @throws ClassCastException when data type does not match. */ - def getStruct(i: Int): Row = { - // Product and Row both are recognized as StructType in a Row - val t = get(i) - if (t.isInstanceOf[Product]) { - Row.fromTuple(t.asInstanceOf[Product]) - } else { - t.asInstanceOf[Row] - } - } + def getStruct(i: Int): Row = getAs[Row](i) /** * Returns the value at position i. @@ -464,13 +456,13 @@ trait Row extends Serializable { def mkString(start: String, sep: String, end: String): String = toSeq.mkString(start, sep, end) /** - * Returns the value of a given fieldName. + * Returns the value at position i. * * @throws UnsupportedOperationException when schema is not defined. * @throws ClassCastException when data type does not match. * @throws NullPointerException when value is null. 
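A hypothetical sketch relating the two APIs above: the Java-friendly factory methods on `OutputMode` return the case objects defined in `InternalOutputModes`, so the two representations compare equal. The example sits in the `org.apache.spark.sql` package because `InternalOutputModes` is `private[sql]`.

```scala
package org.apache.spark.sql

import org.apache.spark.sql.streaming.OutputMode

object OutputModeExample {
  def main(args: Array[String]): Unit = {
    assert(OutputMode.Append() == InternalOutputModes.Append)
    assert(OutputMode.Complete() == InternalOutputModes.Complete)
  }
}
```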
*/ private def getAnyValAs[T <: AnyVal](i: Int): T = - if (isNullAt(i)) throw new NullPointerException(s"Value at index $i in null") + if (isNullAt(i)) throw new NullPointerException(s"Value at index $i is null") else getAs[T](i) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 9bfc381639140..f542f5cf40506 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst import java.lang.{Iterable => JavaIterable} import java.math.{BigDecimal => JavaBigDecimal} +import java.math.{BigInteger => JavaBigInteger} import java.sql.{Date, Timestamp} import java.util.{Map => JavaMap} import javax.annotation.Nullable @@ -326,6 +327,7 @@ object CatalystTypeConverters { val decimal = scalaValue match { case d: BigDecimal => Decimal(d) case d: JavaBigDecimal => Decimal(d) + case d: JavaBigInteger => Decimal(d) case d: Decimal => d } if (decimal.changePrecision(dataType.precision, dataType.scale)) { @@ -380,7 +382,7 @@ object CatalystTypeConverters { * Typical use case would be converting a collection of rows that have the same schema. You will * call this function once to get a converter, and apply it to every row. */ - private[sql] def createToCatalystConverter(dataType: DataType): Any => Any = { + def createToCatalystConverter(dataType: DataType): Any => Any = { if (isPrimitive(dataType)) { // Although the `else` branch here is capable of handling inbound conversion of primitives, // we add some special-case handling for those types here. The motivation for this relates to @@ -407,7 +409,7 @@ object CatalystTypeConverters { * Typical use case would be converting a collection of rows that have the same schema. You will * call this function once to get a converter, and apply it to every row. 
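A short, illustrative sketch of the `CatalystTypeConverters` changes above: `createToCatalystConverter` is now public, and `java.math.BigInteger` is accepted wherever a decimal column previously required a `BigDecimal`.

```scala
import java.math.BigInteger

import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.types.{Decimal, DecimalType}

val toCatalyst = CatalystTypeConverters.createToCatalystConverter(DecimalType(38, 0))
val converted = toCatalyst(new BigInteger("12345678901234567890")).asInstanceOf[Decimal]
```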
*/ - private[sql] def createToScalaConverter(dataType: DataType): Any => Any = { + def createToScalaConverter(dataType: DataType): Any => Any = { if (isPrimitive(dataType)) { identity } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 6f9fbbbead474..b3a233ae390ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -25,8 +25,9 @@ import scala.language.existentials import com.google.common.reflect.TypeToken -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedExtractValue} +import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -63,6 +64,11 @@ object JavaTypeInference { case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) => (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true) + case c: Class[_] if UDTRegistration.exists(c.getName) => + val udt = UDTRegistration.getUDTFor(c.getName).get.newInstance() + .asInstanceOf[UserDefinedType[_ >: Null]] + (udt, true) + case c: Class[_] if c == classOf[java.lang.String] => (StringType, true) case c: Class[_] if c == classOf[Array[Byte]] => (BinaryType, true) @@ -83,6 +89,7 @@ object JavaTypeInference { case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true) case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType.SYSTEM_DEFAULT, true) + case c: Class[_] if c == classOf[java.math.BigInteger] => (DecimalType.BigIntDecimal, true) case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true) case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true) @@ -170,8 +177,8 @@ object JavaTypeInference { .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) .getOrElse(UnresolvedAttribute(part)) - /** Returns the current path or `BoundReference`. */ - def getPath: Expression = path.getOrElse(BoundReference(0, inferDataType(typeToken)._1, true)) + /** Returns the current path or `GetColumnByOrdinal`. 
*/ + def getPath: Expression = path.getOrElse(GetColumnByOrdinal(0, inferDataType(typeToken)._1)) typeToken.getRawType match { case c if !inferExternalType(c).isInstanceOf[ObjectType] => getPath diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index be0d75a8304e9..dd36468583b1b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -17,12 +17,12 @@ package org.apache.spark.sql.catalyst -import org.apache.spark.SparkException -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedExtractValue} +import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData} import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} /** @@ -30,7 +30,7 @@ import org.apache.spark.unsafe.types.UTF8String * for classes whose fields are entirely defined by constructor params but should not be * case classes. */ -private[sql] trait DefinedByConstructorParams +trait DefinedByConstructorParams /** @@ -72,6 +72,7 @@ object ScalaReflection extends ScalaReflection { case t if t <:< definitions.ByteTpe => ByteType case t if t <:< definitions.BooleanTpe => BooleanType case t if t <:< localTypeOf[Array[Byte]] => BinaryType + case t if t <:< localTypeOf[CalendarInterval] => CalendarIntervalType case t if t <:< localTypeOf[Decimal] => DecimalType.SYSTEM_DEFAULT case _ => val className = getClassNameFromType(tpe) @@ -113,8 +114,8 @@ object ScalaReflection extends ScalaReflection { * Returns true if the value of this data type is same between internal and external. */ def isNativeType(dt: DataType): Boolean = dt match { - case BooleanType | ByteType | ShortType | IntegerType | LongType | - FloatType | DoubleType | BinaryType => true + case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType | + FloatType | DoubleType | BinaryType | CalendarIntervalType => true case _ => false } @@ -155,17 +156,17 @@ object ScalaReflection extends ScalaReflection { walkedTypePath: Seq[String]): Expression = { val newPath = path .map(p => GetStructField(p, ordinal)) - .getOrElse(BoundReference(ordinal, dataType, false)) + .getOrElse(GetColumnByOrdinal(ordinal, dataType)) upCastToExpectedType(newPath, dataType, walkedTypePath) } - /** Returns the current path or `BoundReference`. */ + /** Returns the current path or `GetColumnByOrdinal`. 
*/ def getPath: Expression = { val dataType = schemaFor(tpe).dataType if (path.isDefined) { path.get } else { - upCastToExpectedType(BoundReference(0, dataType, true), dataType, walkedTypePath) + upCastToExpectedType(GetColumnByOrdinal(0, dataType), dataType, walkedTypePath) } } @@ -189,7 +190,6 @@ object ScalaReflection extends ScalaReflection { case _ => UpCast(expr, expected, walkedTypePath) } - val className = getClassNameFromType(tpe) tpe match { case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath @@ -239,16 +239,14 @@ object ScalaReflection extends ScalaReflection { DateTimeUtils.getClass, ObjectType(classOf[java.sql.Date]), "toJavaDate", - getPath :: Nil, - propagateNull = true) + getPath :: Nil) case t if t <:< localTypeOf[java.sql.Timestamp] => StaticInvoke( DateTimeUtils.getClass, ObjectType(classOf[java.sql.Timestamp]), "toJavaTimestamp", - getPath :: Nil, - propagateNull = true) + getPath :: Nil) case t if t <:< localTypeOf[java.lang.String] => Invoke(getPath, "toString", ObjectType(classOf[String])) @@ -259,6 +257,12 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[BigDecimal] => Invoke(getPath, "toBigDecimal", ObjectType(classOf[BigDecimal])) + case t if t <:< localTypeOf[java.math.BigInteger] => + Invoke(getPath, "toJavaBigInteger", ObjectType(classOf[java.math.BigInteger])) + + case t if t <:< localTypeOf[scala.math.BigInt] => + Invoke(getPath, "toScalaBigInt", ObjectType(classOf[scala.math.BigInt])) + case t if t <:< localTypeOf[Array[_]] => val TypeRef(_, _, Seq(elementType)) = t @@ -342,6 +346,23 @@ object ScalaReflection extends ScalaReflection { "toScalaMap", keyData :: valueData :: Nil) + case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) => + val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() + val obj = NewInstance( + udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), + Nil, + dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt())) + Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil) + + case t if UDTRegistration.exists(getClassNameFromType(t)) => + val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance() + .asInstanceOf[UserDefinedType[_]] + val obj = NewInstance( + udt.getClass, + Nil, + dataType = ObjectType(udt.getClass)) + Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil) + case t if definedByConstructorParams(t) => val params = getConstructorParameters(t) @@ -382,23 +403,6 @@ object ScalaReflection extends ScalaReflection { } else { newInstance } - - case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) => - val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() - val obj = NewInstance( - udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), - Nil, - dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt())) - Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil) - - case t if UDTRegistration.exists(getClassNameFromType(t)) => - val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance() - .asInstanceOf[UserDefinedType[_]] - val obj = NewInstance( - udt.getClass, - Nil, - dataType = ObjectType(udt.getClass)) - Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil) } } @@ -417,7 +421,7 @@ object ScalaReflection extends ScalaReflection { def serializerFor[T : TypeTag](inputObject: Expression): 
CreateNamedStruct = { val tpe = localTypeOf[T] val clsName = getClassNameFromType(tpe) - val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil + val walkedTypePath = s"""- root class: "$clsName"""" :: Nil serializerFor(inputObject, tpe, walkedTypePath) match { case expressions.If(_, _, s: CreateNamedStruct) if definedByConstructorParams(tpe) => s case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil) @@ -431,202 +435,170 @@ object ScalaReflection extends ScalaReflection { walkedTypePath: Seq[String]): Expression = ScalaReflectionLock.synchronized { def toCatalystArray(input: Expression, elementType: `Type`): Expression = { - val externalDataType = dataTypeFor(elementType) - val Schema(catalystType, nullable) = silentSchemaFor(elementType) - if (isNativeType(externalDataType)) { - NewInstance( - classOf[GenericArrayData], - input :: Nil, - dataType = ArrayType(catalystType, nullable)) - } else { - val clsName = getClassNameFromType(elementType) - val newPath = s"""- array element class: "$clsName"""" +: walkedTypePath - MapObjects(serializerFor(_, elementType, newPath), input, externalDataType) + dataTypeFor(elementType) match { + case dt: ObjectType => + val clsName = getClassNameFromType(elementType) + val newPath = s"""- array element class: "$clsName"""" +: walkedTypePath + MapObjects(serializerFor(_, elementType, newPath), input, dt) + + case dt => + NewInstance( + classOf[GenericArrayData], + input :: Nil, + dataType = ArrayType(dt, schemaFor(elementType).nullable)) } } - if (!inputObject.dataType.isInstanceOf[ObjectType]) { - inputObject - } else { - tpe match { - case t if t <:< localTypeOf[Option[_]] => - val TypeRef(_, _, Seq(optType)) = t - optType match { - // For primitive types we must manually unbox the value of the object. - case t if t <:< definitions.IntTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Integer]), inputObject), - "intValue", - IntegerType) - case t if t <:< definitions.LongTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Long]), inputObject), - "longValue", - LongType) - case t if t <:< definitions.DoubleTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Double]), inputObject), - "doubleValue", - DoubleType) - case t if t <:< definitions.FloatTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Float]), inputObject), - "floatValue", - FloatType) - case t if t <:< definitions.ShortTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Short]), inputObject), - "shortValue", - ShortType) - case t if t <:< definitions.ByteTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Byte]), inputObject), - "byteValue", - ByteType) - case t if t <:< definitions.BooleanTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Boolean]), inputObject), - "booleanValue", - BooleanType) - - // For non-primitives, we can just extract the object from the Option and then recurse. - case other => - val className = getClassNameFromType(optType) - val newPath = s"""- option value class: "$className"""" +: walkedTypePath - - val optionObjectType: DataType = other match { - // Special handling is required for arrays, as getClassFromType() will fail - // since Scala Arrays map to native Java constructs. E.g. "Array[Int]" will map to - // the Java type "[I". 
- case arr if arr <:< localTypeOf[Array[_]] => arrayClassFor(t) - case cls => ObjectType(getClassFromType(cls)) - } - val unwrapped = UnwrapOption(optionObjectType, inputObject) - - expressions.If( - IsNull(unwrapped), - expressions.Literal.create(null, silentSchemaFor(optType).dataType), - serializerFor(unwrapped, optType, newPath)) + tpe match { + case _ if !inputObject.dataType.isInstanceOf[ObjectType] => inputObject + + case t if t <:< localTypeOf[Option[_]] => + val TypeRef(_, _, Seq(optType)) = t + val className = getClassNameFromType(optType) + val newPath = s"""- option value class: "$className"""" +: walkedTypePath + val unwrapped = UnwrapOption(dataTypeFor(optType), inputObject) + serializerFor(unwrapped, optType, newPath) + + // Since List[_] also belongs to localTypeOf[Product], we put this case before + // "case t if definedByConstructorParams(t)" to make sure it will match to the + // case "localTypeOf[Seq[_]]" + case t if t <:< localTypeOf[Seq[_]] => + val TypeRef(_, _, Seq(elementType)) = t + toCatalystArray(inputObject, elementType) + + case t if t <:< localTypeOf[Array[_]] => + val TypeRef(_, _, Seq(elementType)) = t + toCatalystArray(inputObject, elementType) + + case t if t <:< localTypeOf[Map[_, _]] => + val TypeRef(_, _, Seq(keyType, valueType)) = t + + val keys = + Invoke( + Invoke(inputObject, "keysIterator", + ObjectType(classOf[scala.collection.Iterator[_]])), + "toSeq", + ObjectType(classOf[scala.collection.Seq[_]])) + val convertedKeys = toCatalystArray(keys, keyType) + + val values = + Invoke( + Invoke(inputObject, "valuesIterator", + ObjectType(classOf[scala.collection.Iterator[_]])), + "toSeq", + ObjectType(classOf[scala.collection.Seq[_]])) + val convertedValues = toCatalystArray(values, valueType) + + val Schema(keyDataType, _) = schemaFor(keyType) + val Schema(valueDataType, valueNullable) = schemaFor(valueType) + NewInstance( + classOf[ArrayBasedMapData], + convertedKeys :: convertedValues :: Nil, + dataType = MapType(keyDataType, valueDataType, valueNullable)) + + case t if t <:< localTypeOf[String] => + StaticInvoke( + classOf[UTF8String], + StringType, + "fromString", + inputObject :: Nil) + + case t if t <:< localTypeOf[java.sql.Timestamp] => + StaticInvoke( + DateTimeUtils.getClass, + TimestampType, + "fromJavaTimestamp", + inputObject :: Nil) + + case t if t <:< localTypeOf[java.sql.Date] => + StaticInvoke( + DateTimeUtils.getClass, + DateType, + "fromJavaDate", + inputObject :: Nil) + + case t if t <:< localTypeOf[BigDecimal] => + StaticInvoke( + Decimal.getClass, + DecimalType.SYSTEM_DEFAULT, + "apply", + inputObject :: Nil) + + case t if t <:< localTypeOf[java.math.BigDecimal] => + StaticInvoke( + Decimal.getClass, + DecimalType.SYSTEM_DEFAULT, + "apply", + inputObject :: Nil) + + case t if t <:< localTypeOf[java.math.BigInteger] => + StaticInvoke( + Decimal.getClass, + DecimalType.BigIntDecimal, + "apply", + inputObject :: Nil) + + case t if t <:< localTypeOf[scala.math.BigInt] => + StaticInvoke( + Decimal.getClass, + DecimalType.BigIntDecimal, + "apply", + inputObject :: Nil) + + case t if t <:< localTypeOf[java.lang.Integer] => + Invoke(inputObject, "intValue", IntegerType) + case t if t <:< localTypeOf[java.lang.Long] => + Invoke(inputObject, "longValue", LongType) + case t if t <:< localTypeOf[java.lang.Double] => + Invoke(inputObject, "doubleValue", DoubleType) + case t if t <:< localTypeOf[java.lang.Float] => + Invoke(inputObject, "floatValue", FloatType) + case t if t <:< localTypeOf[java.lang.Short] => + Invoke(inputObject, "shortValue", 
ShortType) + case t if t <:< localTypeOf[java.lang.Byte] => + Invoke(inputObject, "byteValue", ByteType) + case t if t <:< localTypeOf[java.lang.Boolean] => + Invoke(inputObject, "booleanValue", BooleanType) + + case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) => + val udt = getClassFromType(t) + .getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() + val obj = NewInstance( + udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), + Nil, + dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt())) + Invoke(obj, "serialize", udt, inputObject :: Nil) + + case t if UDTRegistration.exists(getClassNameFromType(t)) => + val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance() + .asInstanceOf[UserDefinedType[_]] + val obj = NewInstance( + udt.getClass, + Nil, + dataType = ObjectType(udt.getClass)) + Invoke(obj, "serialize", udt, inputObject :: Nil) + + case t if definedByConstructorParams(t) => + val params = getConstructorParameters(t) + val nonNullOutput = CreateNamedStruct(params.flatMap { case (fieldName, fieldType) => + if (javaKeywords.contains(fieldName)) { + throw new UnsupportedOperationException(s"`$fieldName` is a reserved keyword and " + + "cannot be used as field name\n" + walkedTypePath.mkString("\n")) } - case t if definedByConstructorParams(t) => - val params = getConstructorParameters(t) - val nonNullOutput = CreateNamedStruct(params.flatMap { case (fieldName, fieldType) => - val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType)) - val clsName = getClassNameFromType(fieldType) - val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath - expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType, newPath) :: Nil - }) - val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType) - expressions.If(IsNull(inputObject), nullOutput, nonNullOutput) - - case t if t <:< localTypeOf[Array[_]] => - val TypeRef(_, _, Seq(elementType)) = t - toCatalystArray(inputObject, elementType) - - case t if t <:< localTypeOf[Seq[_]] => - val TypeRef(_, _, Seq(elementType)) = t - toCatalystArray(inputObject, elementType) - - case t if t <:< localTypeOf[Map[_, _]] => - val TypeRef(_, _, Seq(keyType, valueType)) = t - - val keys = - Invoke( - Invoke(inputObject, "keysIterator", - ObjectType(classOf[scala.collection.Iterator[_]])), - "toSeq", - ObjectType(classOf[scala.collection.Seq[_]])) - val convertedKeys = toCatalystArray(keys, keyType) - - val values = - Invoke( - Invoke(inputObject, "valuesIterator", - ObjectType(classOf[scala.collection.Iterator[_]])), - "toSeq", - ObjectType(classOf[scala.collection.Seq[_]])) - val convertedValues = toCatalystArray(values, valueType) - - val Schema(keyDataType, _) = schemaFor(keyType) - val Schema(valueDataType, valueNullable) = schemaFor(valueType) - NewInstance( - classOf[ArrayBasedMapData], - convertedKeys :: convertedValues :: Nil, - dataType = MapType(keyDataType, valueDataType, valueNullable)) - - case t if t <:< localTypeOf[String] => - StaticInvoke( - classOf[UTF8String], - StringType, - "fromString", - inputObject :: Nil) - - case t if t <:< localTypeOf[java.sql.Timestamp] => - StaticInvoke( - DateTimeUtils.getClass, - TimestampType, - "fromJavaTimestamp", - inputObject :: Nil) - - case t if t <:< localTypeOf[java.sql.Date] => - StaticInvoke( - DateTimeUtils.getClass, - DateType, - "fromJavaDate", - inputObject :: Nil) - - case t if t <:< localTypeOf[BigDecimal] => - StaticInvoke( - 
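To illustrate the reserved-keyword check added above (illustrative only; `Conf` is a hypothetical case class): a field whose name is a Java keyword is now rejected as soon as the encoder is derived, rather than failing later in generated Java code.

```scala
import org.apache.spark.sql.Encoders

case class Conf(`new`: Boolean, name: String)

// Encoders.product[Conf]
// => UnsupportedOperationException: `new` is a reserved keyword and cannot be used as field name
```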
Decimal.getClass, - DecimalType.SYSTEM_DEFAULT, - "apply", - inputObject :: Nil) - - case t if t <:< localTypeOf[java.math.BigDecimal] => - StaticInvoke( - Decimal.getClass, - DecimalType.SYSTEM_DEFAULT, - "apply", - inputObject :: Nil) - - case t if t <:< localTypeOf[java.lang.Integer] => - Invoke(inputObject, "intValue", IntegerType) - case t if t <:< localTypeOf[java.lang.Long] => - Invoke(inputObject, "longValue", LongType) - case t if t <:< localTypeOf[java.lang.Double] => - Invoke(inputObject, "doubleValue", DoubleType) - case t if t <:< localTypeOf[java.lang.Float] => - Invoke(inputObject, "floatValue", FloatType) - case t if t <:< localTypeOf[java.lang.Short] => - Invoke(inputObject, "shortValue", ShortType) - case t if t <:< localTypeOf[java.lang.Byte] => - Invoke(inputObject, "byteValue", ByteType) - case t if t <:< localTypeOf[java.lang.Boolean] => - Invoke(inputObject, "booleanValue", BooleanType) - - case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) => - val udt = getClassFromType(t) - .getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() - val obj = NewInstance( - udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), - Nil, - dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt())) - Invoke(obj, "serialize", udt.sqlType, inputObject :: Nil) - - case t if UDTRegistration.exists(getClassNameFromType(t)) => - val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance() - .asInstanceOf[UserDefinedType[_]] - val obj = NewInstance( - udt.getClass, - Nil, - dataType = ObjectType(udt.getClass)) - Invoke(obj, "serialize", udt.sqlType, inputObject :: Nil) - - case other => - throw new UnsupportedOperationException( - s"No Encoder found for $tpe\n" + walkedTypePath.mkString("\n")) - } + val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType)) + val clsName = getClassNameFromType(fieldType) + val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath + expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType, newPath) :: Nil + }) + val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType) + expressions.If(IsNull(inputObject), nullOutput, nonNullOutput) + + case other => + throw new UnsupportedOperationException( + s"No Encoder found for $tpe\n" + walkedTypePath.mkString("\n")) } + } /** @@ -656,6 +628,15 @@ object ScalaReflection extends ScalaReflection { constructParams(t).map(_.name.toString) } + /** + * Returns the parameter values for the primary constructor of this class. + */ + def getConstructorParameterValues(obj: DefinedByConstructorParams): Seq[AnyRef] = { + getConstructorParameterNames(obj.getClass).map { name => + obj.getClass.getMethod(name).invoke(obj) + } + } + /* * Retrieves the runtime class corresponding to the provided type. */ @@ -672,18 +653,6 @@ object ScalaReflection extends ScalaReflection { /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ def schemaFor[T: TypeTag]: Schema = schemaFor(localTypeOf[T]) - /** - * Returns a catalyst DataType and its nullability for the given Scala Type using reflection. - * - * Unlike `schemaFor`, this method won't throw exception for un-supported type, it will return - * `NullType` silently instead. 
- */ - def silentSchemaFor(tpe: `Type`): Schema = try { - schemaFor(tpe) - } catch { - case _: UnsupportedOperationException => Schema(NullType, nullable = true) - } - /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ def schemaFor(tpe: `Type`): Schema = ScalaReflectionLock.synchronized { tpe match { @@ -711,19 +680,16 @@ object ScalaReflection extends ScalaReflection { val Schema(valueDataType, valueNullable) = schemaFor(valueType) Schema(MapType(schemaFor(keyType).dataType, valueDataType, valueContainsNull = valueNullable), nullable = true) - case t if definedByConstructorParams(t) => - val params = getConstructorParameters(t) - Schema(StructType( - params.map { case (fieldName, fieldType) => - val Schema(dataType, nullable) = schemaFor(fieldType) - StructField(fieldName, dataType, nullable) - }), nullable = true) case t if t <:< localTypeOf[String] => Schema(StringType, nullable = true) case t if t <:< localTypeOf[java.sql.Timestamp] => Schema(TimestampType, nullable = true) case t if t <:< localTypeOf[java.sql.Date] => Schema(DateType, nullable = true) case t if t <:< localTypeOf[BigDecimal] => Schema(DecimalType.SYSTEM_DEFAULT, nullable = true) case t if t <:< localTypeOf[java.math.BigDecimal] => Schema(DecimalType.SYSTEM_DEFAULT, nullable = true) + case t if t <:< localTypeOf[java.math.BigInteger] => + Schema(DecimalType.BigIntDecimal, nullable = true) + case t if t <:< localTypeOf[scala.math.BigInt] => + Schema(DecimalType.BigIntDecimal, nullable = true) case t if t <:< localTypeOf[Decimal] => Schema(DecimalType.SYSTEM_DEFAULT, nullable = true) case t if t <:< localTypeOf[java.lang.Integer] => Schema(IntegerType, nullable = true) case t if t <:< localTypeOf[java.lang.Long] => Schema(LongType, nullable = true) @@ -739,6 +705,13 @@ object ScalaReflection extends ScalaReflection { case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false) case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false) case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false) + case t if definedByConstructorParams(t) => + val params = getConstructorParameters(t) + Schema(StructType( + params.map { case (fieldName, fieldType) => + val Schema(dataType, nullable) = schemaFor(fieldType) + StructField(fieldName, dataType, nullable) + }), nullable = true) case other => throw new UnsupportedOperationException(s"Schema for type $other is not supported") } @@ -747,10 +720,16 @@ object ScalaReflection extends ScalaReflection { /** * Whether the fields of the given type is defined entirely by its constructor parameters. 
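An illustrative sketch of the `schemaFor` additions above: `scala.math.BigInt` and `java.math.BigInteger` now map to `DecimalType.BigIntDecimal` when a schema is derived by reflection (`Account` is a hypothetical case class used only for the example).

```scala
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.types.StructType

case class Account(id: BigInt, owner: String)

val schema = ScalaReflection.schemaFor[Account].dataType.asInstanceOf[StructType]
println(schema("id").dataType)   // DecimalType(38,0)
```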
*/ - private[sql] def definedByConstructorParams(tpe: Type): Boolean = { + def definedByConstructorParams(tpe: Type): Boolean = { tpe <:< localTypeOf[Product] || tpe <:< localTypeOf[DefinedByConstructorParams] } + private val javaKeywords = Set("abstract", "assert", "boolean", "break", "byte", "case", "catch", + "char", "class", "const", "continue", "default", "do", "double", "else", "extends", "false", + "final", "finally", "float", "for", "goto", "if", "implements", "import", "instanceof", "int", + "interface", "long", "native", "new", "null", "package", "private", "protected", "public", + "return", "short", "static", "strictfp", "super", "switch", "synchronized", "this", "throw", + "throws", "transient", "true", "try", "void", "volatile", "while") } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala new file mode 100644 index 0000000000000..ec56fe7729c2a --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec + +/** + * Thrown by a catalog when an item already exists. The analyzer will rethrow the exception + * as an [[org.apache.spark.sql.AnalysisException]] with the correct position information. 
+ */ +class DatabaseAlreadyExistsException(db: String) + extends AnalysisException(s"Database '$db' already exists") + +class TableAlreadyExistsException(db: String, table: String) + extends AnalysisException(s"Table or view '$table' already exists in database '$db'") + +class TempTableAlreadyExistsException(table: String) + extends AnalysisException(s"Temporary table '$table' already exists") + +class PartitionAlreadyExistsException(db: String, table: String, spec: TablePartitionSpec) + extends AnalysisException( + s"Partition already exists in table '$table' database '$db':\n" + spec.mkString("\n")) + +class PartitionsAlreadyExistException(db: String, table: String, specs: Seq[TablePartitionSpec]) + extends AnalysisException( + s"The following partitions already exists in table '$table' database '$db':\n" + + specs.mkString("\n===\n")) + +class FunctionAlreadyExistsException(db: String, func: String) + extends AnalysisException(s"Function '$func' already exists in database '$db'") + +class TempFunctionAlreadyExistsException(func: String) + extends AnalysisException(s"Temporary function '$func' already exists") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index e98036a970d44..617f3e062465c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -22,20 +22,22 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{CatalystConf, ScalaReflection, SimpleCatalystConf} -import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} +import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogRelation, InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.encoders.OuterScopes import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.expressions.objects.NewInstance +import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification import org.apache.spark.sql.catalyst.planning.IntegerIndex import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, _} import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.trees.TreeNodeRef -import org.apache.spark.sql.catalyst.util.usePrettyExpression +import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.types._ /** - * A trivial [[Analyzer]] with an dummy [[SessionCatalog]] and [[EmptyFunctionRegistry]]. + * A trivial [[Analyzer]] with a dummy [[SessionCatalog]] and [[EmptyFunctionRegistry]]. * Used for testing when all relations are already filled in and the analyzer needs only * to resolve attribute references. 
*/ @@ -43,7 +45,9 @@ object SimpleAnalyzer extends Analyzer( new SessionCatalog( new InMemoryCatalog, EmptyFunctionRegistry, - new SimpleCatalystConf(caseSensitiveAnalysis = true)), + new SimpleCatalystConf(caseSensitiveAnalysis = true)) { + override def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean) {} + }, new SimpleCatalystConf(caseSensitiveAnalysis = true)) /** @@ -82,6 +86,7 @@ class Analyzer( WindowsSubstitution, EliminateUnions), Batch("Resolution", fixedPoint, + ResolveTableValuedFunctions :: ResolveRelations :: ResolveReferences :: ResolveDeserializer :: @@ -91,6 +96,7 @@ class Analyzer( ResolvePivot :: ResolveOrdinalInOrderByAndGroupBy :: ResolveMissingReferences :: + ExtractGenerator :: ResolveGenerate :: ResolveFunctions :: ResolveAliases :: @@ -102,12 +108,15 @@ class Analyzer( GlobalAggregates :: ResolveAggregateFunctions :: TimeWindowing :: - HiveTypeCoercion.typeCoercionRules ++ + ResolveInlineTables :: + TypeCoercion.typeCoercionRules ++ extendedResolutionRules : _*), Batch("Nondeterministic", Once, PullOutNondeterministic), Batch("UDF", Once, HandleNullInputsForUDF), + Batch("FixNullability", Once, + FixNullability), Batch("Cleanup", fixedPoint, CleanupAliases) ) @@ -172,14 +181,16 @@ class Analyzer( private def assignAliases(exprs: Seq[NamedExpression]) = { exprs.zipWithIndex.map { case (expr, i) => - expr.transformUp { case u @ UnresolvedAlias(child, optionalAliasName) => + expr.transformUp { case u @ UnresolvedAlias(child, optGenAliasFunc) => child match { case ne: NamedExpression => ne case e if !e.resolved => u case g: Generator => MultiAlias(g, Nil) case c @ Cast(ne: NamedExpression, _) => Alias(c, ne.name)() - case e: ExtractValue => Alias(e, usePrettyExpression(e).sql)() - case e => Alias(e, optionalAliasName.getOrElse(usePrettyExpression(e).sql))() + case e: ExtractValue => Alias(e, toPrettySQL(e))() + case e if optGenAliasFunc.isDefined => + Alias(child, optGenAliasFunc.get.apply(e))() + case e => Alias(e, toPrettySQL(e))() } } }.asInstanceOf[Seq[NamedExpression]] @@ -237,7 +248,7 @@ class Analyzer( }.isDefined } - private def hasGroupingFunction(e: Expression): Boolean = { + private[analysis] def hasGroupingFunction(e: Expression): Boolean = { e.collectFirst { case g: Grouping => g case g: GroupingID => g @@ -363,43 +374,68 @@ class Analyzer( object ResolvePivot extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case p: Pivot if !p.childrenResolved | !p.aggregates.forall(_.resolved) => p + case p: Pivot if !p.childrenResolved | !p.aggregates.forall(_.resolved) + | !p.groupByExprs.forall(_.resolved) | !p.pivotColumn.resolved => p case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child) => val singleAgg = aggregates.size == 1 - val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value => - def ifExpr(expr: Expression) = { - If(EqualTo(pivotColumn, value), expr, Literal(null)) + def outputName(value: Literal, aggregate: Expression): String = { + if (singleAgg) value.toString else value + "_" + aggregate.sql + } + if (aggregates.forall(a => PivotFirst.supportsDataType(a.dataType))) { + // Since evaluating |pivotValues| if statements for each input row can get slow this is an + // alternate plan that instead uses two steps of aggregation. 
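As a rough illustration of the query shape this pivot rewrite targets (the session, table, and column names below are invented for the example and are not part of the patch), a DataFrame pivot with explicit pivot values and a PivotFirst-supported aggregate is planned with the two aggregation steps described above instead of one conditional per pivot value per input row:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.sum

// Hypothetical data; only the shape of the query matters here.
val spark = SparkSession.builder().appName("pivot-sketch").master("local[*]").getOrCreate()
val courseSales = spark.createDataFrame(Seq(
  (2012, "dotNET", 10000), (2012, "Java", 20000),
  (2013, "dotNET", 48000), (2013, "Java", 30000)
)).toDF("year", "course", "earnings")

// With explicit pivot values and a sum aggregate, the plan aggregates once per
// (year, course) group and then reshapes the result, rather than evaluating
// one If(course = value, ...) branch per pivot value for every input row.
courseSales.groupBy("year")
  .pivot("course", Seq("dotNET", "Java"))
  .agg(sum("earnings"))
  .show()
```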
+ val namedAggExps: Seq[NamedExpression] = aggregates.map(a => Alias(a, a.sql)()) + val namedPivotCol = pivotColumn match { + case n: NamedExpression => n + case _ => Alias(pivotColumn, "__pivot_col")() + } + val bigGroup = groupByExprs :+ namedPivotCol + val firstAgg = Aggregate(bigGroup, bigGroup ++ namedAggExps, child) + val castPivotValues = pivotValues.map(Cast(_, pivotColumn.dataType).eval(EmptyRow)) + val pivotAggs = namedAggExps.map { a => + Alias(PivotFirst(namedPivotCol.toAttribute, a.toAttribute, castPivotValues) + .toAggregateExpression() + , "__pivot_" + a.sql)() + } + val secondAgg = Aggregate(groupByExprs, groupByExprs ++ pivotAggs, firstAgg) + val pivotAggAttribute = pivotAggs.map(_.toAttribute) + val pivotOutputs = pivotValues.zipWithIndex.flatMap { case (value, i) => + aggregates.zip(pivotAggAttribute).map { case (aggregate, pivotAtt) => + Alias(ExtractValue(pivotAtt, Literal(i), resolver), outputName(value, aggregate))() + } } - aggregates.map { aggregate => - val filteredAggregate = aggregate.transformDown { - // Assumption is the aggregate function ignores nulls. This is true for all current - // AggregateFunction's with the exception of First and Last in their default mode - // (which we handle) and possibly some Hive UDAF's. - case First(expr, _) => - First(ifExpr(expr), Literal(true)) - case Last(expr, _) => - Last(ifExpr(expr), Literal(true)) - case a: AggregateFunction => - a.withNewChildren(a.children.map(ifExpr)) - }.transform { - // We are duplicating aggregates that are now computing a different value for each - // pivot value. - // TODO: Don't construct the physical container until after analysis. - case ae: AggregateExpression => ae.copy(resultId = NamedExpression.newExprId) + Project(groupByExprs ++ pivotOutputs, secondAgg) + } else { + val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value => + def ifExpr(expr: Expression) = { + If(EqualTo(pivotColumn, value), expr, Literal(null)) } - if (filteredAggregate.fastEquals(aggregate)) { - throw new AnalysisException( - s"Aggregate expression required for pivot, found '$aggregate'") + aggregates.map { aggregate => + val filteredAggregate = aggregate.transformDown { + // Assumption is the aggregate function ignores nulls. This is true for all current + // AggregateFunction's with the exception of First and Last in their default mode + // (which we handle) and possibly some Hive UDAF's. + case First(expr, _) => + First(ifExpr(expr), Literal(true)) + case Last(expr, _) => + Last(ifExpr(expr), Literal(true)) + case a: AggregateFunction => + a.withNewChildren(a.children.map(ifExpr)) + }.transform { + // We are duplicating aggregates that are now computing a different value for each + // pivot value. + // TODO: Don't construct the physical container until after analysis. 
+ case ae: AggregateExpression => ae.copy(resultId = NamedExpression.newExprId) + } + if (filteredAggregate.fastEquals(aggregate)) { + throw new AnalysisException( + s"Aggregate expression required for pivot, found '$aggregate'") + } + Alias(filteredAggregate, outputName(value, aggregate))() } - val name = if (singleAgg) value.toString else value + "_" + aggregate.sql - Alias(filteredAggregate, name)() } + Aggregate(groupByExprs, groupByExprs ++ pivotAggregates, child) } - val newGroupByExprs = groupByExprs.map { - case UnresolvedAlias(e, _) => e - case e => e - } - Aggregate(newGroupByExprs, groupByExprs ++ pivotAggregates, child) } } @@ -417,7 +453,7 @@ class Analyzer( } def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - case i @ InsertIntoTable(u: UnresolvedRelation, _, _, _, _) => + case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _) if child.resolved => i.copy(table = EliminateSubqueryAliases(lookupTableFromCatalog(u))) case u: UnresolvedRelation => val table = u.tableIdentifier @@ -457,6 +493,10 @@ class Analyzer( val newVersion = oldVersion.newInstance() (oldVersion, newVersion) + case oldVersion: SerializeFromObject + if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty => + (oldVersion, oldVersion.copy(serializer = oldVersion.serializer.map(_.newInstance()))) + // Handle projects that create conflicting aliases. case oldVersion @ Project(projectList, _) if findAliases(projectList).intersect(conflictingAttributes).nonEmpty => @@ -509,8 +549,7 @@ class Analyzer( case a: Aggregate if containsStar(a.aggregateExpressions) => if (conf.groupByOrdinal && a.groupingExpressions.exists(IntegerIndex.unapply(_).nonEmpty)) { failAnalysis( - "Group by position: star is not allowed to use in the select list " + - "when using ordinals in group by") + "Star (*) is not allowed in select list when GROUP BY ordinal position is used") } else { a.copy(aggregateExpressions = buildExpandedProjectList(a.aggregateExpressions, a.child)) } @@ -649,6 +688,7 @@ class Analyzer( // Else, throw exception. try { expr transformUp { + case GetColumnByOrdinal(ordinal, _) => plan.output(ordinal) case u @ UnresolvedAttribute(nameParts) => withPosition(u) { plan.resolve(nameParts, resolver).getOrElse(u) } case UnresolvedExtractValue(child, fieldName) if child.resolved => @@ -684,9 +724,9 @@ class Analyzer( if (index > 0 && index <= child.output.size) { SortOrder(child.output(index - 1), direction) } else { - throw new UnresolvedException(s, - s"Order/sort By position: $index does not exist " + - s"The Select List is indexed from 1 to ${child.output.size}") + s.failAnalysis( + s"ORDER BY position $index is not in select list " + + s"(valid range is [1, ${child.output.size}])") } case o => o } @@ -698,17 +738,18 @@ class Analyzer( if conf.groupByOrdinal && aggs.forall(_.resolved) && groups.exists(IntegerIndex.unapply(_).nonEmpty) => val newGroups = groups.map { - case IntegerIndex(index) if index > 0 && index <= aggs.size => + case ordinal @ IntegerIndex(index) if index > 0 && index <= aggs.size => aggs(index - 1) match { case e if ResolveAggregateFunctions.containsAggregate(e) => - throw new UnresolvedException(a, - s"Group by position: the '$index'th column in the select contains an " + - s"aggregate function: ${e.sql}. 
Aggregate functions are not allowed in GROUP BY") + ordinal.failAnalysis( + s"GROUP BY position $index is an aggregate function, and " + + "aggregate functions are not allowed in GROUP BY") case o => o } - case IntegerIndex(index) => - throw new UnresolvedException(a, - s"Group by position: '$index' exceeds the size of the select list '${aggs.size}'.") + case ordinal @ IntegerIndex(index) => + ordinal.failAnalysis( + s"GROUP BY position $index is not in select list " + + s"(valid range is [1, ${aggs.size}])") case o => o } Aggregate(newGroups, aggs, child) @@ -797,6 +838,8 @@ class Analyzer( // attributes that its child might have or could have. val missing = missingAttrs -- g.child.outputSet g.copy(join = true, child = addMissingAttr(g.child, missing)) + case d: Distinct => + throw new AnalysisException(s"Can't add $missingAttrs to $d") case u: UnaryNode => u.withNewChildren(addMissingAttr(u.child, missingAttrs) :: Nil) case other => @@ -932,7 +975,8 @@ class Analyzer( localPredicateReferences -- p.outputSet } - val transformed = sub transformUp { + // Simplify the predicates before pulling them out. + val transformed = BooleanSimplification(sub) transformUp { case f @ Filter(cond, child) => // Find all predicates with an outer reference. val (correlated, local) = splitConjunctivePredicates(cond).partition(containsOuter) @@ -1056,10 +1100,10 @@ class Analyzer( // Step 2: Pull out the predicates if the plan is resolved. if (current.resolved) { // Make sure the resolved query has the required number of output columns. This is only - // needed for IN expressions. + // needed for Scalar and IN subqueries. if (requiredColumns > 0 && requiredColumns != current.output.size) { - failAnalysis(s"The number of fields in the value ($requiredColumns) does not " + - s"match with the number of columns in the subquery (${current.output.size})") + failAnalysis(s"The number of columns in the subquery (${current.output.size}) " + + s"does not match the required number of columns ($requiredColumns)") } // Pullout predicates and construct a new plan. f.tupled(rewriteSubQuery(current, plans)) @@ -1074,8 +1118,11 @@ class Analyzer( */ private def resolveSubQueries(plan: LogicalPlan, plans: Seq[LogicalPlan]): LogicalPlan = { plan transformExpressions { + case s @ ScalarSubquery(sub, conditions, exprId) + if sub.resolved && conditions.isEmpty && sub.output.size != 1 => + failAnalysis(s"Scalar subquery must return only one column, but got ${sub.output.size}") case s @ ScalarSubquery(sub, _, exprId) if !sub.resolved => - resolveSubQuery(s, plans)(ScalarSubquery(_, _, exprId)) + resolveSubQuery(s, plans, 1)(ScalarSubquery(_, _, exprId)) case e @ Exists(sub, exprId) => resolveSubQuery(e, plans)(PredicateSubquery(_, _, nullAware = false, exprId)) case In(e, Seq(l @ ListQuery(_, exprId))) if e.resolved => @@ -1164,6 +1211,19 @@ class Analyzer( val alias = Alias(ae, ae.toString)() aggregateExpressions += alias alias.toAttribute + // Grouping functions are handled in the rule [[ResolveGroupingAnalytics]]. + case e: Expression if grouping.exists(_.semanticEquals(e)) && + !ResolveGroupingAnalytics.hasGroupingFunction(e) && + !aggregate.output.exists(_.semanticEquals(e)) => + e match { + case ne: NamedExpression => + aggregateExpressions += ne + ne.toAttribute + case _ => + val alias = Alias(e, e.toString)() + aggregateExpressions += alias + alias.toAttribute + } } // Push the aggregate expressions into the aggregate (if any). 
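The reworded ordinal errors above are easiest to see against a concrete query. A minimal sketch, assuming a SparkSession `spark`, a registered table `t(a, b)`, and the group-by/order-by ordinal settings at their defaults (these names are illustrative, not taken from the patch):

```scala
// Ordinals refer to positions in the select list: 1 -> a, 2 -> count(b).
spark.sql("SELECT a, count(b) FROM t GROUP BY 1 ORDER BY 2")

// Now rejected with the clearer messages introduced above, e.g.
// "GROUP BY position 2 is an aggregate function, and aggregate functions are not allowed in GROUP BY"
spark.sql("SELECT a, count(b) FROM t GROUP BY 2")

// "ORDER BY position 5 is not in select list (valid range is [1, 2])"
spark.sql("SELECT a, count(b) FROM t GROUP BY 1 ORDER BY 5")
```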
@@ -1250,20 +1310,54 @@ class Analyzer( } /** - * Rewrites table generating expressions that either need one or more of the following in order - * to be resolved: - * - concrete attribute references for their output. - * - to be relocated from a SELECT clause (i.e. from a [[Project]]) into a [[Generate]]). + * Extracts [[Generator]] from the projectList of a [[Project]] operator and create [[Generate]] + * operator under [[Project]]. * - * Names for the output [[Attribute]]s are extracted from [[Alias]] or [[MultiAlias]] expressions - * that wrap the [[Generator]]. If more than one [[Generator]] is found in a Project, an - * [[AnalysisException]] is throw. + * This rule will throw [[AnalysisException]] for following cases: + * 1. [[Generator]] is nested in expressions, e.g. `SELECT explode(list) + 1 FROM tbl` + * 2. more than one [[Generator]] is found in projectList, + * e.g. `SELECT explode(list), explode(list) FROM tbl` + * 3. [[Generator]] is found in other operators that are not [[Project]] or [[Generate]], + * e.g. `SELECT * FROM tbl SORT BY explode(list)` */ - object ResolveGenerate extends Rule[LogicalPlan] { + object ExtractGenerator extends Rule[LogicalPlan] { + private def hasGenerator(expr: Expression): Boolean = { + expr.find(_.isInstanceOf[Generator]).isDefined + } + + private def hasNestedGenerator(expr: NamedExpression): Boolean = expr match { + case UnresolvedAlias(_: Generator, _) => false + case Alias(_: Generator, _) => false + case MultiAlias(_: Generator, _) => false + case other => hasGenerator(other) + } + + private def trimAlias(expr: NamedExpression): Expression = expr match { + case UnresolvedAlias(child, _) => child + case Alias(child, _) => child + case MultiAlias(child, _) => child + case _ => expr + } + + /** Extracts a [[Generator]] expression and any names assigned by aliases to their output. */ + private object AliasedGenerator { + def unapply(e: Expression): Option[(Generator, Seq[String])] = e match { + case Alias(g: Generator, name) if g.resolved => Some((g, name :: Nil)) + case MultiAlias(g: Generator, names) if g.resolved => Some(g, names) + case _ => None + } + } + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - case p: Generate if !p.child.resolved || !p.generator.resolved => p - case g: Generate if !g.resolved => - g.copy(generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name))) + case Project(projectList, _) if projectList.exists(hasNestedGenerator) => + val nestedGenerator = projectList.find(hasNestedGenerator).get + throw new AnalysisException("Generators are not supported when it's nested in " + + "expressions, but got: " + toPrettySQL(trimAlias(nestedGenerator))) + + case Project(projectList, _) if projectList.count(hasGenerator) > 1 => + val generators = projectList.filter(hasGenerator).map(trimAlias) + throw new AnalysisException("Only one generator allowed per select clause but found " + + generators.size + ": " + generators.map(toPrettySQL).mkString(", ")) case p @ Project(projectList, child) => // Holds the resolved generator, if one exists in the project list. @@ -1271,11 +1365,9 @@ class Analyzer( val newProjectList = projectList.flatMap { case AliasedGenerator(generator, names) if generator.childrenResolved => - if (resolvedGenerator != null) { - failAnalysis( - s"Only one generator allowed per select but ${resolvedGenerator.nodeName} and " + - s"and ${generator.nodeName} found.") - } + // It's a sanity check, this should not happen as the previous case will throw + // exception earlier. 
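A minimal sketch of the three failure cases listed in the ExtractGenerator scaladoc above, assuming a SparkSession `spark` and a table `tbl` with an array column `list` (both names are illustrative):

```scala
spark.sql("SELECT explode(list) FROM tbl")                 // ok: extracted into Generate under Project
spark.sql("SELECT explode(list) + 1 FROM tbl")             // fails: generator nested in an expression
spark.sql("SELECT explode(list), explode(list) FROM tbl")  // fails: more than one generator per select clause
spark.sql("SELECT * FROM tbl SORT BY explode(list)")       // fails: generator outside Project/Generate
```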
+ assert(resolvedGenerator == null, "More than one generator found in SELECT.") resolvedGenerator = Generate( @@ -1283,7 +1375,7 @@ class Analyzer( join = projectList.size > 1, // Only join if there are other expressions in SELECT. outer = false, qualifier = None, - generatorOutput = makeGeneratorOutput(generator, names), + generatorOutput = ResolveGenerate.makeGeneratorOutput(generator, names), child) resolvedGenerator.generatorOutput @@ -1295,49 +1387,89 @@ class Analyzer( } else { p } + + case g: Generate => g + + case p if p.expressions.exists(hasGenerator) => + throw new AnalysisException("Generators are not supported outside the SELECT clause, but " + + "got: " + p.simpleString) } + } - /** Extracts a [[Generator]] expression and any names assigned by aliases to their output. */ - private object AliasedGenerator { - def unapply(e: Expression): Option[(Generator, Seq[String])] = e match { - case Alias(g: Generator, name) if g.resolved && g.elementTypes.size > 1 => - // If not given the default names, and the TGF with multiple output columns - failAnalysis( - s"""Expect multiple names given for ${g.getClass.getName}, - |but only single name '${name}' specified""".stripMargin) - case Alias(g: Generator, name) if g.resolved => Some((g, name :: Nil)) - case MultiAlias(g: Generator, names) if g.resolved => Some(g, names) - case _ => None - } + /** + * Rewrites table generating expressions that either need one or more of the following in order + * to be resolved: + * - concrete attribute references for their output. + * - to be relocated from a SELECT clause (i.e. from a [[Project]]) into a [[Generate]]). + * + * Names for the output [[Attribute]]s are extracted from [[Alias]] or [[MultiAlias]] expressions + * that wrap the [[Generator]]. + */ + object ResolveGenerate extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case g: Generate if !g.child.resolved || !g.generator.resolved => g + case g: Generate if !g.resolved => + g.copy(generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name))) } /** * Construct the output attributes for a [[Generator]], given a list of names. If the list of * names is empty names are assigned from field names in generator. */ - private def makeGeneratorOutput( + private[analysis] def makeGeneratorOutput( generator: Generator, names: Seq[String]): Seq[Attribute] = { - val elementTypes = generator.elementTypes + val elementAttrs = generator.elementSchema.toAttributes - if (names.length == elementTypes.length) { - names.zip(elementTypes).map { - case (name, (t, nullable, _)) => - AttributeReference(name, t, nullable)() + if (names.length == elementAttrs.length) { + names.zip(elementAttrs).map { + case (name, attr) => attr.withName(name) } } else if (names.isEmpty) { - elementTypes.map { - case (t, nullable, name) => AttributeReference(name, t, nullable)() - } + elementAttrs } else { failAnalysis( "The number of aliases supplied in the AS clause does not match the number of columns " + - s"output by the UDTF expected ${elementTypes.size} aliases but got " + + s"output by the UDTF expected ${elementAttrs.size} aliases but got " + s"${names.mkString(",")} ") } } } + /** + * Fixes nullability of Attributes in a resolved LogicalPlan by using the nullability of + * corresponding Attributes of its children output Attributes. This step is needed because + * users can use a resolved AttributeReference in the Dataset API and outer joins + * can change the nullability of an AttribtueReference. 
Without the fix, a nullable column's + * nullable field can be actually set as non-nullable, which cause illegal optimization + * (e.g., NULL propagation) and wrong answers. + * See SPARK-13484 and SPARK-13801 for the concrete queries of this case. + */ + object FixNullability extends Rule[LogicalPlan] { + + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case p if !p.resolved => p // Skip unresolved nodes. + case p: LogicalPlan if p.resolved => + val childrenOutput = p.children.flatMap(c => c.output).groupBy(_.exprId).flatMap { + case (exprId, attributes) => + // If there are multiple Attributes having the same ExprId, we need to resolve + // the conflict of nullable field. We do not really expect this happen. + val nullable = attributes.exists(_.nullable) + attributes.map(attr => attr.withNullability(nullable)) + }.toSeq + // At here, we create an AttributeMap that only compare the exprId for the lookup + // operation. So, we can find the corresponding input attribute's nullability. + val attributeMap = AttributeMap[Attribute](childrenOutput.map(attr => attr -> attr)) + // For an Attribute used by the current LogicalPlan, if it is from its children, + // we fix the nullable field by using the nullability setting of the corresponding + // output Attribute from the children. + p.transformExpressions { + case attr: Attribute if attributeMap.contains(attr) => + attr.withNullability(attributeMap(attr).nullable) + } + } + } + /** * Extracts [[WindowExpression]]s from the projectList of a [[Project]] operator and * aggregateExpressions of an [[Aggregate]] operator and creates individual [[Window]] @@ -1346,7 +1478,7 @@ class Analyzer( * This rule handles three cases: * - A [[Project]] having [[WindowExpression]]s in its projectList; * - An [[Aggregate]] having [[WindowExpression]]s in its aggregateExpressions. - * - An [[Filter]]->[[Aggregate]] pattern representing GROUP BY with a HAVING + * - A [[Filter]]->[[Aggregate]] pattern representing GROUP BY with a HAVING * clause and the [[Aggregate]] has [[WindowExpression]]s in its aggregateExpressions. * Note: If there is a GROUP BY clause in the query, aggregations and corresponding * filters (expressions in the HAVING clause) should be evaluated before any @@ -1506,7 +1638,7 @@ class Analyzer( // We do a final check and see if we only have a single Window Spec defined in an // expressions. - if (distinctWindowSpec.length == 0 ) { + if (distinctWindowSpec.isEmpty) { failAnalysis(s"$expr does not have any WindowExpression.") } else if (distinctWindowSpec.length > 1) { // newExpressionsWithWindowFunctions only have expressions with a single @@ -1519,27 +1651,17 @@ class Analyzer( } }.toSeq - // Third, for every Window Spec, we add a Window operator and set currentChild as the - // child of it. - var currentChild = child - var i = 0 - while (i < groupedWindowExpressions.size) { - val ((partitionSpec, orderSpec), windowExpressions) = groupedWindowExpressions(i) - // Set currentChild to the newly created Window operator. - currentChild = - Window( - windowExpressions, - partitionSpec, - orderSpec, - currentChild) - - // Move to next Window Spec. - i += 1 - } + // Third, we aggregate them by adding each Window operator for each Window Spec and then + // setting this to the child of the next Window operator. 
+ val windowOps = + groupedWindowExpressions.foldLeft(child) { + case (last, ((partitionSpec, orderSpec), windowExpressions)) => + Window(windowExpressions, partitionSpec, orderSpec, last) + } - // Finally, we create a Project to output currentChild's output + // Finally, we create a Project to output windowOps's output // newExpressionsWithWindowFunctions. - Project(currentChild.output ++ newExpressionsWithWindowFunctions, currentChild) + Project(windowOps.output ++ newExpressionsWithWindowFunctions, windowOps) } // end of addWindow // We have to use transformDown at here to make sure the rule of @@ -1672,7 +1794,8 @@ class Analyzer( s @ WindowSpecDefinition(_, o, UnspecifiedFrame)) if wf.frame != UnspecifiedFrame => WindowExpression(wf, s.copy(frameSpecification = wf.frame)) - case we @ WindowExpression(e, s @ WindowSpecDefinition(_, o, UnspecifiedFrame)) => + case we @ WindowExpression(e, s @ WindowSpecDefinition(_, o, UnspecifiedFrame)) + if e.resolved => val frame = SpecifiedWindowFrame.defaultWindowFrame(o.nonEmpty, acceptWindowFrame = true) we.copy(windowSpec = s.copy(frameSpecification = frame)) } @@ -1686,7 +1809,9 @@ class Analyzer( def apply(plan: LogicalPlan): LogicalPlan = plan transform { case logical: LogicalPlan => logical transformExpressions { case WindowExpression(wf: WindowFunction, spec) if spec.orderSpec.isEmpty => - failAnalysis(s"WindowFunction $wf requires window to be ordered") + failAnalysis(s"Window function $wf requires window to be ordered, please add ORDER BY " + + s"clause. For example SELECT $wf(value_expr) OVER (PARTITION BY window_partition " + + s"ORDER BY window_ordering) from table") case WindowExpression(rank: RankLike, spec) if spec.resolved => val order = spec.orderSpec.map(_.child) WindowExpression(rank.withOrder(order), spec) @@ -1719,13 +1844,25 @@ class Analyzer( } private def commonNaturalJoinProcessing( - left: LogicalPlan, - right: LogicalPlan, - joinType: JoinType, - joinNames: Seq[String], - condition: Option[Expression]) = { - val leftKeys = joinNames.map(keyName => left.output.find(_.name == keyName).get) - val rightKeys = joinNames.map(keyName => right.output.find(_.name == keyName).get) + left: LogicalPlan, + right: LogicalPlan, + joinType: JoinType, + joinNames: Seq[String], + condition: Option[Expression]) = { + val leftKeys = joinNames.map { keyName => + val joinColumn = left.output.find(attr => resolver(attr.name, keyName)) + assert( + joinColumn.isDefined, + s"$keyName should exist in ${left.output.map(_.name).mkString(",")}") + joinColumn.get + } + val rightKeys = joinNames.map { keyName => + val joinColumn = right.output.find(attr => resolver(attr.name, keyName)) + assert( + joinColumn.isDefined, + s"$keyName should exist in ${right.output.map(_.name).mkString(",")}") + joinColumn.get + } val joinPairs = leftKeys.zip(rightKeys) val newCondition = (condition ++ joinPairs.map(EqualTo.tupled)).reduceOption(And) @@ -1773,10 +1910,59 @@ class Analyzer( } else { inputAttributes } - val unbound = deserializer transform { - case b: BoundReference => inputs(b.ordinal) - } - resolveExpression(unbound, LocalRelation(inputs), throws = true) + + validateTopLevelTupleFields(deserializer, inputs) + val resolved = resolveExpression( + deserializer, LocalRelation(inputs), throws = true) + validateNestedTupleFields(resolved) + resolved + } + } + + private def fail(schema: StructType, maxOrdinal: Int): Unit = { + throw new AnalysisException(s"Try to map ${schema.simpleString} to Tuple${maxOrdinal + 1}, " + + "but failed as the number of fields does 
not line up.") + } + + /** + * For each top-level Tuple field, we use [[GetColumnByOrdinal]] to get its corresponding column + * by position. However, the actual number of columns may be different from the number of Tuple + * fields. This method is used to check the number of columns and fields, and throw an + * exception if they do not match. + */ + private def validateTopLevelTupleFields( + deserializer: Expression, inputs: Seq[Attribute]): Unit = { + val ordinals = deserializer.collect { + case GetColumnByOrdinal(ordinal, _) => ordinal + }.distinct.sorted + + if (ordinals.nonEmpty && ordinals != inputs.indices) { + fail(inputs.toStructType, ordinals.last) + } + } + + /** + * For each nested Tuple field, we use [[GetStructField]] to get its corresponding struct field + * by position. However, the actual number of struct fields may be different from the number + * of nested Tuple fields. This method is used to check the number of struct fields and nested + * Tuple fields, and throw an exception if they do not match. + */ + private def validateNestedTupleFields(deserializer: Expression): Unit = { + val structChildToOrdinals = deserializer + // There are 2 kinds of `GetStructField`: + // 1. resolved from `UnresolvedExtractValue`, and it will have a `name` property. + // 2. created when we build deserializer expression for nested tuple, no `name` property. + // Here we want to validate the ordinals of nested tuple, so we should only catch + // `GetStructField` without the name property. + .collect { case g: GetStructField if g.name.isEmpty => g } + .groupBy(_.child) + .mapValues(_.map(_.ordinal).distinct.sorted) + + structChildToOrdinals.foreach { case (expr, ordinals) => + val schema = expr.dataType.asInstanceOf[StructType] + if (ordinals != schema.indices) { + fail(schema, ordinals.last) + } } } } @@ -1817,8 +2003,8 @@ class Analyzer( } private def illegalNumericPrecedence(from: DataType, to: DataType): Boolean = { - val fromPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(from) - val toPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(to) + val fromPrecedence = TypeCoercion.numericPrecedence.indexOf(from) + val toPrecedence = TypeCoercion.numericPrecedence.indexOf(to) toPrecedence > 0 && fromPrecedence > toPrecedence } @@ -2018,4 +2204,3 @@ object TimeWindowing extends Rule[LogicalPlan] { } } } - diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 61a7d9ea24f4f..790566c7659c1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -18,9 +18,10 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.catalog.SimpleCatalogRelation import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression -import org.apache.spark.sql.catalyst.plans.{Inner, RightOuter, UsingJoin} +import org.apache.spark.sql.catalyst.plans.UsingJoin import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ @@ -45,6 +46,21 @@ trait CheckAnalysis extends PredicateHelper { }).length > 1 } + private def checkLimitClause(limitExpr: Expression): Unit = { + limitExpr match { + case e if !e.foldable => failAnalysis( + "The limit expression must evaluate to a constant 
value, but got " + + limitExpr.sql) + case e if e.dataType != IntegerType => failAnalysis( + s"The limit expression must be integer type, but got " + + e.dataType.simpleString) + case e if e.eval().asInstanceOf[Int] < 0 => failAnalysis( + "The limit expression must be equal to or greater than 0, but got " + + e.eval().asInstanceOf[Int]) + case e => // OK + } + } + def checkAnalysis(plan: LogicalPlan): Unit = { // We transform up and order the rules so as to catch the first possible failure instead // of the result of cascading resolution failures. @@ -60,9 +76,6 @@ trait CheckAnalysis extends PredicateHelper { val from = operator.inputSet.map(_.name).mkString(", ") a.failAnalysis(s"cannot resolve '${a.sql}' given input columns: [$from]") - case ScalarSubquery(_, conditions, _) if conditions.nonEmpty => - failAnalysis("Correlated scalar subqueries are not supported.") - case e: Expression if e.checkInputDataTypes().isFailure => e.checkInputDataTypes() match { case TypeCheckResult.TypeCheckFailure(message) => @@ -75,9 +88,9 @@ trait CheckAnalysis extends PredicateHelper { s"invalid cast from ${c.child.dataType.simpleString} to ${c.dataType.simpleString}") case g: Grouping => - failAnalysis(s"grouping() can only be used with GroupingSets/Cube/Rollup") + failAnalysis("grouping() can only be used with GroupingSets/Cube/Rollup") case g: GroupingID => - failAnalysis(s"grouping_id() can only be used with GroupingSets/Cube/Rollup") + failAnalysis("grouping_id() can only be used with GroupingSets/Cube/Rollup") case w @ WindowExpression(AggregateExpression(_, _, true, _), _) => failAnalysis(s"Distinct window functions are not supported: $w") @@ -104,6 +117,41 @@ trait CheckAnalysis extends PredicateHelper { failAnalysis(s"Window specification $s is not valid because $m") case None => w } + + case s @ ScalarSubquery(query, conditions, _) if conditions.nonEmpty => + // Make sure we are using equi-joins. + conditions.foreach { + case _: EqualTo | _: EqualNullSafe => // ok + case e => failAnalysis( + s"The correlated scalar subquery can only contain equality predicates: $e") + } + + // Make sure correlated scalar subqueries contain one row for every outer row by + // enforcing that they are aggregates which contain exactly one aggregate expressions. + // The analyzer has already checked that subquery contained only one output column, and + // added all the grouping expressions to the aggregate. + def checkAggregate(a: Aggregate): Unit = { + val aggregates = a.expressions.flatMap(_.collect { + case a: AggregateExpression => a + }) + if (aggregates.isEmpty) { + failAnalysis("The output of a correlated scalar subquery must be aggregated") + } + } + + // Skip projects and subquery aliases added by the Analyzer and the SQLBuilder. 
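Taken together, the limit check and the correlated scalar subquery checks above admit or reject queries like the following. A minimal sketch, assuming a SparkSession `spark` and tables `t1(a, x)` and `t2(b, x)` (names invented for the example):

```scala
// Correlated scalar subquery: aggregated output and equality correlation are accepted.
spark.sql("SELECT a, (SELECT max(b) FROM t2 WHERE t2.x = t1.x) FROM t1")

// Rejected: "The output of a correlated scalar subquery must be aggregated".
spark.sql("SELECT a, (SELECT b FROM t2 WHERE t2.x = t1.x) FROM t1")

// LIMIT must be a foldable, non-negative integer expression.
spark.sql("SELECT * FROM t1 LIMIT 10")   // ok
spark.sql("SELECT * FROM t1 LIMIT -1")   // rejected: must be equal to or greater than 0
```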
+ def cleanQuery(p: LogicalPlan): LogicalPlan = p match { + case SubqueryAlias(_, child) => cleanQuery(child) + case Project(_, child) => cleanQuery(child) + case child => child + } + + cleanQuery(query) match { + case a: Aggregate => checkAggregate(a) + case Filter(_, a: Aggregate) => checkAggregate(a) + case fail => failAnalysis(s"Correlated scalar subqueries must be Aggregated: $fail") + } + s } operator match { @@ -115,8 +163,9 @@ trait CheckAnalysis extends PredicateHelper { case f @ Filter(condition, child) => splitConjunctivePredicates(condition).foreach { case _: PredicateSubquery | Not(_: PredicateSubquery) => - case e if PredicateSubquery.hasPredicateSubquery(e) => - failAnalysis(s"Predicate sub-queries cannot be used in nested conditions: $e") + case e if PredicateSubquery.hasNullAwarePredicateWithinNot(e) => + failAnalysis(s"Null-aware predicate sub-queries cannot be used in nested" + + s" conditions: $e") case e => } @@ -172,7 +221,6 @@ trait CheckAnalysis extends PredicateHelper { "Add to group by or wrap in first() (or first_value) if you don't care " + "which value you get.") case e if groupingExprs.exists(_.semanticEquals(e)) => // OK - case e if e.references.isEmpty => // OK case e => e.children.foreach(checkValidAggregateExpression) } @@ -181,7 +229,7 @@ trait CheckAnalysis extends PredicateHelper { if (!RowOrdering.isOrderable(expr.dataType)) { failAnalysis( s"expression ${expr.sql} cannot be used as a grouping expression " + - s"because its data type ${expr.dataType.simpleString} is not a orderable " + + s"because its data type ${expr.dataType.simpleString} is not an orderable " + s"data type.") } @@ -208,16 +256,26 @@ trait CheckAnalysis extends PredicateHelper { case s @ SetOperation(left, right) if left.output.length != right.output.length => failAnalysis( s"${s.nodeName} can only be performed on tables with the same number of columns, " + - s"but the left table has ${left.output.length} columns and the right has " + - s"${right.output.length}") + s"but the left table has ${left.output.length} columns and the right has " + + s"${right.output.length}") case s: Union if s.children.exists(_.output.length != s.children.head.output.length) => val firstError = s.children.find(_.output.length != s.children.head.output.length).get failAnalysis( - s""" - |Unions can only be performed on tables with the same number of columns, - | but one table has '${firstError.output.length}' columns and another table has - | '${s.children.head.output.length}' columns""".stripMargin) + s"Unions can only be performed on tables with the same number of columns, " + + s"but one table has '${firstError.output.length}' columns and another table has " + + s"'${s.children.head.output.length}' columns") + + case GlobalLimit(limitExpr, _) => checkLimitClause(limitExpr) + + case LocalLimit(limitExpr, _) => checkLimitClause(limitExpr) + + case p if p.expressions.exists(ScalarSubquery.hasCorrelatedScalarSubquery) => + p match { + case _: Filter | _: Aggregate | _: Project => // Ok + case other => failAnalysis( + s"Correlated scalar sub-queries can only be used in a Filter/Aggregate/Project: $p") + } case p if p.expressions.exists(PredicateSubquery.hasPredicateSubquery) => failAnalysis(s"Predicate sub-queries can only be used in a Filter: $p") @@ -266,6 +324,40 @@ trait CheckAnalysis extends PredicateHelper { |Conflicting attributes: ${conflictingAttributes.mkString(",")} """.stripMargin) + case s: SimpleCatalogRelation => + failAnalysis( + s""" + |Hive support is required to select over the following 
tables: + |${s.catalogTable.identifier} + """.stripMargin) + + // TODO: We need to consolidate this kind of checks for InsertIntoTable + // with the rule of PreWriteCheck defined in extendedCheckRules. + case InsertIntoTable(s: SimpleCatalogRelation, _, _, _, _) => + failAnalysis( + s""" + |Hive support is required to insert into the following tables: + |${s.catalogTable.identifier} + """.stripMargin) + + case InsertIntoTable(t, _, _, _, _) + if !t.isInstanceOf[LeafNode] || + t.isInstanceOf[Range] || + t == OneRowRelation || + t.isInstanceOf[LocalRelation] => + failAnalysis(s"Inserting into an RDD-based table is not allowed.") + + case i @ InsertIntoTable(table, partitions, query, _, _) => + val numStaticPartitions = partitions.values.count(_.isDefined) + if (table.output.size != (query.output.size + numStaticPartitions)) { + failAnalysis( + s"$table requires that the data to be inserted have the same number of " + + s"columns as the target table: target table has ${table.output.size} " + + s"column(s) but the inserted data has " + + s"${query.output.size + numStaticPartitions} column(s), including " + + s"$numStaticPartitions partition column(s) having constant value(s).") + } + case o if !o.resolved => failAnalysis( s"unresolved operator ${operator.simpleString}") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 1bada2ce67ea2..35fd800df4a4f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -25,10 +25,16 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.expressions.xml._ import org.apache.spark.sql.catalyst.util.StringKeyHashMap +import org.apache.spark.sql.types._ -/** A catalog for looking up user defined functions, used by an [[Analyzer]]. */ +/** + * A catalog for looking up user defined functions, used by an [[Analyzer]]. + * + * Note: The implementation should be thread-safe to allow concurrent access. 
+ */ trait FunctionRegistry { final def registerFunction(name: String, builder: FunctionBuilder): Unit = { @@ -62,7 +68,7 @@ trait FunctionRegistry { class SimpleFunctionRegistry extends FunctionRegistry { - private[sql] val functionBuilders = + protected val functionBuilders = StringKeyHashMap[(ExpressionInfo, FunctionBuilder)](caseSensitive = false) override def registerFunction( @@ -97,7 +103,7 @@ class SimpleFunctionRegistry extends FunctionRegistry { functionBuilders.remove(name).isDefined } - override def clear(): Unit = { + override def clear(): Unit = synchronized { functionBuilders.clear() } @@ -155,22 +161,24 @@ object FunctionRegistry { val expressions: Map[String, (ExpressionInfo, FunctionBuilder)] = Map( // misc non-aggregate functions expression[Abs]("abs"), - expression[CreateArray]("array"), expression[Coalesce]("coalesce"), expression[Explode]("explode"), expression[Greatest]("greatest"), expression[If]("if"), + expression[Inline]("inline"), expression[IsNaN]("isnan"), + expression[IfNull]("ifnull"), expression[IsNull]("isnull"), expression[IsNotNull]("isnotnull"), expression[Least]("least"), - expression[CreateMap]("map"), - expression[CreateNamedStruct]("named_struct"), expression[NaNvl]("nanvl"), - expression[Coalesce]("nvl"), + expression[NullIf]("nullif"), + expression[Nvl]("nvl"), + expression[Nvl2]("nvl2"), + expression[PosExplode]("posexplode"), expression[Rand]("rand"), expression[Randn]("randn"), - expression[CreateStruct]("struct"), + expression[Stack]("stack"), expression[CaseWhen]("when"), // math functions @@ -215,6 +223,7 @@ object FunctionRegistry { expression[Signum]("signum"), expression[Sin]("sin"), expression[Sinh]("sinh"), + expression[StringToMap]("str_to_map"), expression[Sqrt]("sqrt"), expression[Tan]("tan"), expression[Tanh]("tanh"), @@ -241,6 +250,7 @@ object FunctionRegistry { expression[Average]("mean"), expression[Min]("min"), expression[Skewness]("skewness"), + expression[StddevSamp]("std"), expression[StddevSamp]("stddev"), expression[StddevPop]("stddev_pop"), expression[StddevSamp]("stddev_samp"), @@ -248,6 +258,8 @@ object FunctionRegistry { expression[VarianceSamp]("variance"), expression[VariancePop]("var_pop"), expression[VarianceSamp]("var_samp"), + expression[CollectList]("collect_list"), + expression[CollectSet]("collect_set"), // string functions expression[Ascii]("ascii"), @@ -255,6 +267,7 @@ object FunctionRegistry { expression[Concat]("concat"), expression[ConcatWs]("concat_ws"), expression[Decode]("decode"), + expression[Elt]("elt"), expression[Encode]("encode"), expression[FindInSet]("find_in_set"), expression[FormatNumber]("format_number"), @@ -271,6 +284,7 @@ object FunctionRegistry { expression[StringLPad]("lpad"), expression[StringTrimLeft]("ltrim"), expression[JsonTuple]("json_tuple"), + expression[ParseUrl]("parse_url"), expression[FormatString]("printf"), expression[RegExpExtract]("regexp_extract"), expression[RegExpReplace]("regexp_replace"), @@ -279,6 +293,7 @@ object FunctionRegistry { expression[RLike]("rlike"), expression[StringRPad]("rpad"), expression[StringTrimRight]("rtrim"), + expression[Sentences]("sentences"), expression[SoundEx]("soundex"), expression[StringSpace]("space"), expression[StringSplit]("split"), @@ -291,6 +306,15 @@ object FunctionRegistry { expression[UnBase64]("unbase64"), expression[Unhex]("unhex"), expression[Upper]("upper"), + expression[XPathList]("xpath"), + expression[XPathBoolean]("xpath_boolean"), + expression[XPathDouble]("xpath_double"), + expression[XPathDouble]("xpath_number"), + 
expression[XPathFloat]("xpath_float"), + expression[XPathInt]("xpath_int"), + expression[XPathLong]("xpath_long"), + expression[XPathShort]("xpath_short"), + expression[XPathString]("xpath_string"), // datetime functions expression[AddMonths]("add_months"), @@ -324,9 +348,15 @@ object FunctionRegistry { expression[TimeWindow]("window"), // collection functions + expression[CreateArray]("array"), expression[ArrayContains]("array_contains"), + expression[CreateMap]("map"), + expression[CreateNamedStruct]("named_struct"), + expression[MapKeys]("map_keys"), + expression[MapValues]("map_values"), expression[Size]("size"), expression[SortArray]("sort_array"), + expression[CreateStruct]("struct"), // misc functions expression[AssertTrue]("assert_true"), @@ -340,6 +370,8 @@ object FunctionRegistry { expression[InputFileName]("input_file_name"), expression[MonotonicallyIncreasingID]("monotonically_increasing_id"), expression[CurrentDatabase]("current_database"), + expression[CallMethodViaReflection]("reflect"), + expression[CallMethodViaReflection]("java_method"), // grouping sets expression[Cube]("cube"), @@ -377,8 +409,21 @@ object FunctionRegistry { expression[BitwiseAnd]("&"), expression[BitwiseNot]("~"), expression[BitwiseOr]("|"), - expression[BitwiseXor]("^") - + expression[BitwiseXor]("^"), + + // Cast aliases (SPARK-16730) + castAlias("boolean", BooleanType), + castAlias("tinyint", ByteType), + castAlias("smallint", ShortType), + castAlias("int", IntegerType), + castAlias("bigint", LongType), + castAlias("float", FloatType), + castAlias("double", DoubleType), + castAlias("decimal", DecimalType.USER_DEFAULT), + castAlias("date", DateType), + castAlias("timestamp", TimestampType), + castAlias("binary", BinaryType), + castAlias("string", StringType) ) val builtin: SimpleFunctionRegistry = { @@ -387,6 +432,8 @@ object FunctionRegistry { fr } + val functionSet: Set[String] = builtin.listFunction().toSet + /** See usage above. */ private def expression[T <: Expression](name: String) (implicit tag: ClassTag[T]): (String, (ExpressionInfo, FunctionBuilder)) = { @@ -401,7 +448,7 @@ object FunctionRegistry { case Failure(e) => throw new AnalysisException(e.getMessage) } } else { - // Otherwise, find an ctor method that matches the number of arguments, and use that. + // Otherwise, find a constructor method that matches the number of arguments, and use that. val params = Seq.fill(expressions.size)(classOf[Expression]) val f = Try(tag.runtimeClass.getDeclaredConstructor(params : _*)) match { case Success(e) => @@ -419,14 +466,37 @@ object FunctionRegistry { } } - val clazz = tag.runtimeClass + (name, (expressionInfo[T](name), builder)) + } + + /** + * Creates a function registry lookup entry for cast aliases (SPARK-16730). + * For example, if name is "int", and dataType is IntegerType, this means int(x) would become + * an alias for cast(x as IntegerType). + * See usage above. + */ + private def castAlias( + name: String, + dataType: DataType): (String, (ExpressionInfo, FunctionBuilder)) = { + val builder = (args: Seq[Expression]) => { + if (args.size != 1) { + throw new AnalysisException(s"Function $name accepts only one argument") + } + Cast(args.head, dataType) + } + (name, (expressionInfo[Cast](name), builder)) + } + + /** + * Creates an [[ExpressionInfo]] for the function as defined by expression T using the given name. 
+ */ + private def expressionInfo[T <: Expression : ClassTag](name: String): ExpressionInfo = { + val clazz = scala.reflect.classTag[T].runtimeClass val df = clazz.getAnnotation(classOf[ExpressionDescription]) if (df != null) { - (name, - (new ExpressionInfo(clazz.getCanonicalName, name, df.usage(), df.extended()), - builder)) + new ExpressionInfo(clazz.getCanonicalName, name, df.usage(), df.extended()) } else { - (name, (new ExpressionInfo(clazz.getCanonicalName, name), builder)) + new ExpressionInfo(clazz.getCanonicalName, name) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala index 394be47a588b7..95a3837ae1ab7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan /** - * A trait that should be mixed into query operators where an single instance might appear multiple + * A trait that should be mixed into query operators where a single instance might appear multiple * times in a logical query plan. It is invalid to have multiple copies of the same attribute * produced by distinct operators in a query tree as this breaks the guarantee that expression * ids, which are used to differentiate attributes, are unique. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala index 11ef9e1160a9a..8febdcaee829b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala @@ -25,13 +25,30 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec * Thrown by a catalog when an item cannot be found. The analyzer will rethrow the exception * as an [[org.apache.spark.sql.AnalysisException]] with the correct position information. */ -class NoSuchDatabaseException(db: String) extends AnalysisException(s"Database $db not found") +class NoSuchDatabaseException(db: String) extends AnalysisException(s"Database '$db' not found") class NoSuchTableException(db: String, table: String) - extends AnalysisException(s"Table or View $table not found in database $db") + extends AnalysisException(s"Table or view '$table' not found in database '$db'") -class NoSuchPartitionException(db: String, table: String, spec: TablePartitionSpec) extends - AnalysisException(s"Partition not found in table $table database $db:\n" + spec.mkString("\n")) +class NoSuchPartitionException( + db: String, + table: String, + spec: TablePartitionSpec) + extends AnalysisException( + s"Partition not found in table '$table' database '$db':\n" + spec.mkString("\n")) + +class NoSuchPermanentFunctionException(db: String, func: String) + extends AnalysisException(s"Function '$func' not found in database '$db'") class NoSuchFunctionException(db: String, func: String) - extends AnalysisException(s"Function $func not found in database $db") + extends AnalysisException( + s"Undefined function: '$func'. 
This function is neither a registered temporary function nor " + + s"a permanent function registered in the database '$db'.") + +class NoSuchPartitionsException(db: String, table: String, specs: Seq[TablePartitionSpec]) + extends AnalysisException( + s"The following partitions not found in table '$table' database '$db':\n" + + specs.mkString("\n===\n")) + +class NoSuchTempFunctionException(func: String) + extends AnalysisException(s"Temporary function '$func' not found") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala new file mode 100644 index 0000000000000..7323197b10f6e --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import scala.util.control.NonFatal + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Cast +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.types.{StructField, StructType} + +/** + * An analyzer rule that replaces [[UnresolvedInlineTable]] with [[LocalRelation]]. + */ +object ResolveInlineTables extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case table: UnresolvedInlineTable if table.expressionsResolved => + validateInputDimension(table) + validateInputEvaluable(table) + convert(table) + } + + /** + * Validates the input data dimension: + * 1. All rows have the same cardinality. + * 2. The number of column aliases defined is consistent with the number of columns in data. + * + * This is package visible for unit testing. + */ + private[analysis] def validateInputDimension(table: UnresolvedInlineTable): Unit = { + if (table.rows.nonEmpty) { + val numCols = table.names.size + table.rows.zipWithIndex.foreach { case (row, ri) => + if (row.size != numCols) { + table.failAnalysis(s"expected $numCols columns but found ${row.size} columns in row $ri") + } + } + } + } + + /** + * Validates that all inline table data are valid expressions that can be evaluated + * (in this they must be foldable). + * + * This is package visible for unit testing. + */ + private[analysis] def validateInputEvaluable(table: UnresolvedInlineTable): Unit = { + table.rows.foreach { row => + row.foreach { e => + // Note that nondeterministic expressions are not supported since they are not foldable. 
+ if (!e.resolved || !e.foldable) { + e.failAnalysis(s"cannot evaluate expression ${e.sql} in inline table definition") + } + } + } + } + + /** + * Convert a valid (with right shape and foldable inputs) [[UnresolvedInlineTable]] + * into a [[LocalRelation]]. + * + * This function attempts to coerce inputs into consistent types. + * + * This is package visible for unit testing. + */ + private[analysis] def convert(table: UnresolvedInlineTable): LocalRelation = { + // For each column, traverse all the values and find a common data type and nullability. + val fields = table.rows.transpose.zip(table.names).map { case (column, name) => + val inputTypes = column.map(_.dataType) + val tpe = TypeCoercion.findWiderTypeWithoutStringPromotion(inputTypes).getOrElse { + table.failAnalysis(s"incompatible types found in column $name for inline table") + } + StructField(name, tpe, nullable = column.exists(_.nullable)) + } + val attributes = StructType(fields).toAttributes + assert(fields.size == table.names.size) + + val newRows: Seq[InternalRow] = table.rows.map { row => + InternalRow.fromSeq(row.zipWithIndex.map { case (e, ci) => + val targetType = fields(ci).dataType + try { + if (e.dataType.sameType(targetType)) { + e.eval() + } else { + Cast(e, targetType).eval() + } + } catch { + case NonFatal(ex) => + table.failAnalysis(s"failed to evaluate expression ${e.sql}: ${ex.getMessage}") + } + }) + } + + LocalRelation(attributes, newRows) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala new file mode 100644 index 0000000000000..6b3bb68538dd1 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range} +import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.types.{DataType, IntegerType, LongType} + +/** + * Rule that resolves table-valued function references. + */ +object ResolveTableValuedFunctions extends Rule[LogicalPlan] { + /** + * List of argument names and their types, used to declare a function. + */ + private case class ArgumentList(args: (String, DataType)*) { + /** + * Try to cast the expressions to satisfy the expected types of this argument list. If there + * are any types that cannot be casted, then None is returned. 
+ */ + def implicitCast(values: Seq[Expression]): Option[Seq[Expression]] = { + if (args.length == values.length) { + val casted = values.zip(args).map { case (value, (_, expectedType)) => + TypeCoercion.ImplicitTypeCasts.implicitCast(value, expectedType) + } + if (casted.forall(_.isDefined)) { + return Some(casted.map(_.get)) + } + } + None + } + + override def toString: String = { + args.map { a => + s"${a._1}: ${a._2.typeName}" + }.mkString(", ") + } + } + + /** + * A TVF maps argument lists to resolver functions that accept those arguments. Using a map + * here allows for function overloading. + */ + private type TVF = Map[ArgumentList, Seq[Any] => LogicalPlan] + + /** + * TVF builder. + */ + private def tvf(args: (String, DataType)*)(pf: PartialFunction[Seq[Any], LogicalPlan]) + : (ArgumentList, Seq[Any] => LogicalPlan) = { + (ArgumentList(args: _*), + pf orElse { + case args => + throw new IllegalArgumentException( + "Invalid arguments for resolved function: " + args.mkString(", ")) + }) + } + + /** + * Internal registry of table-valued functions. + */ + private val builtinFunctions: Map[String, TVF] = Map( + "range" -> Map( + /* range(end) */ + tvf("end" -> LongType) { case Seq(end: Long) => + Range(0, end, 1, None) + }, + + /* range(start, end) */ + tvf("start" -> LongType, "end" -> LongType) { case Seq(start: Long, end: Long) => + Range(start, end, 1, None) + }, + + /* range(start, end, step) */ + tvf("start" -> LongType, "end" -> LongType, "step" -> LongType) { + case Seq(start: Long, end: Long, step: Long) => + Range(start, end, step, None) + }, + + /* range(start, end, step, numPartitions) */ + tvf("start" -> LongType, "end" -> LongType, "step" -> LongType, + "numPartitions" -> IntegerType) { + case Seq(start: Long, end: Long, step: Long, numPartitions: Int) => + Range(start, end, step, Some(numPartitions)) + }) + ) + + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) => + builtinFunctions.get(u.functionName) match { + case Some(tvf) => + val resolved = tvf.flatMap { case (argList, resolver) => + argList.implicitCast(u.functionArgs) match { + case Some(casted) => + Some(resolver(casted.map(_.eval()))) + case _ => + None + } + } + resolved.headOption.getOrElse { + val argTypes = u.functionArgs.map(_.dataType.typeName).mkString(", ") + u.failAnalysis( + s"""error: table-valued function ${u.functionName} with alternatives: + |${tvf.keys.map(_.toString).toSeq.sorted.map(x => s" ($x)").mkString("\n")} + |cannot be applied to: (${argTypes})""".stripMargin) + } + case _ => + u.failAnalysis(s"could not resolve `${u.functionName}` to a table-valued function") + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCheckResult.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCheckResult.scala index 79c3528a522d3..d4350598f478c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCheckResult.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCheckResult.scala @@ -37,7 +37,7 @@ object TypeCheckResult { /** * Represents the failing result of `Expression.checkInputDataTypes`, - * with a error message to show the reason of failure. + * with an error message to show the reason of failure. 
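(Editorial aside.) Stepping back to the `range` registry defined by the rule above, here is a hedged usage sketch, assuming a `SparkSession` named `spark` whose analyzer includes `ResolveTableValuedFunctions`:

```scala
// Each call matches one of the four ArgumentList overloads registered above and is
// rewritten into a logical Range node during analysis.
spark.sql("SELECT * FROM range(10)")           // range(end)
spark.sql("SELECT * FROM range(2, 10)")        // range(start, end)
spark.sql("SELECT * FROM range(2, 10, 3)")     // range(start, end, step)
spark.sql("SELECT * FROM range(2, 10, 3, 4)")  // range(start, end, step, numPartitions)
```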
*/ case class TypeCheckFailure(message: String) extends TypeCheckResult { def isSuccess: Boolean = false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala similarity index 92% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 5323b79c57c4b..193c3ec4e585a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -24,7 +24,6 @@ import scala.collection.mutable import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types._ @@ -34,9 +33,6 @@ import org.apache.spark.sql.types._ * A collection of [[Rule]] that can be used to coerce differing types that participate in * operations into compatible ones. * - * Most of these rules are based on Hive semantics, but they do not introduce any dependencies on - * the hive codebase. - * * Notes about type widening / tightest common types: Broadly, there are two cases when we need * to widen data types (e.g. union, binary comparison). In case 1, we are looking for a common * data type for two or more data types, and in this case no loss of precision is allowed. Examples @@ -46,7 +42,7 @@ import org.apache.spark.sql.types._ * double's range is larger than decimal, and yet decimal is more precise than double, but in * union we would cast the decimal into double). */ -object HiveTypeCoercion { +object TypeCoercion { val typeCoercionRules = PropagateTypes :: @@ -67,7 +63,7 @@ object HiveTypeCoercion { // See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types. // The conversion for integral and floating point types have a linear widening hierarchy: - private[sql] val numericPrecedence = + val numericPrecedence = IndexedSeq( ByteType, ShortType, @@ -77,7 +73,7 @@ object HiveTypeCoercion { DoubleType) /** - * Case 1 type widening (see the classdoc comment above for HiveTypeCoercion). + * Case 1 type widening (see the classdoc comment above for TypeCoercion). * * Find the tightest common type of two types that might be used in a binary expression. * This handles all numeric types except fixed-precision decimals interacting with each other or @@ -95,7 +91,8 @@ object HiveTypeCoercion { Some(t1) // Promote numeric types to the highest of the two - case (t1, t2) if Seq(t1, t2).forall(numericPrecedence.contains) => + case (t1: NumericType, t2: NumericType) + if !t1.isInstanceOf[DecimalType] && !t2.isInstanceOf[DecimalType] => val index = numericPrecedence.lastIndexWhere(t => t == t1 || t == t2) Some(numericPrecedence(index)) @@ -103,7 +100,7 @@ object HiveTypeCoercion { } /** Similar to [[findTightestCommonType]], but can promote all the way to StringType. 
*/ - private def findTightestCommonTypeToString(left: DataType, right: DataType): Option[DataType] = { + def findTightestCommonTypeToString(left: DataType, right: DataType): Option[DataType] = { findTightestCommonTypeOfTwo(left, right).orElse((left, right) match { case (StringType, t2: AtomicType) if t2 != BinaryType && t2 != BooleanType => Some(StringType) case (t1: AtomicType, StringType) if t1 != BinaryType && t1 != BooleanType => Some(StringType) @@ -111,18 +108,6 @@ object HiveTypeCoercion { }) } - /** - * Similar to [[findTightestCommonType]], if can not find the TightestCommonType, try to use - * [[findTightestCommonTypeToString]] to find the TightestCommonType. - */ - private def findTightestCommonTypeAndPromoteToString(types: Seq[DataType]): Option[DataType] = { - types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { - case None => None - case Some(d) => - findTightestCommonTypeToString(d, c) - }) - } - /** * Find the tightest common type of a set of types by continuously applying * `findTightestCommonTypeOfTwo` on these types. @@ -135,7 +120,7 @@ object HiveTypeCoercion { } /** - * Case 2 type widening (see the classdoc comment above for HiveTypeCoercion). + * Case 2 type widening (see the classdoc comment above for TypeCoercion). * * i.e. the main difference with [[findTightestCommonTypeOfTwo]] is that here we allow some * loss of precision when widening decimal and double. @@ -160,6 +145,28 @@ object HiveTypeCoercion { }) } + /** + * Similar to [[findWiderCommonType]], but can't promote to string. This is also similar to + * [[findTightestCommonType]], but can handle decimal types. If the wider decimal type exceeds + * system limitation, this rule will truncate the decimal type before return it. + */ + def findWiderTypeWithoutStringPromotion(types: Seq[DataType]): Option[DataType] = { + types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { + case Some(d) => findTightestCommonTypeOfTwo(d, c).orElse((d, c) match { + case (t1: DecimalType, t2: DecimalType) => + Some(DecimalPrecision.widerDecimalType(t1, t2)) + case (t: IntegralType, d: DecimalType) => + Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d)) + case (d: DecimalType, t: IntegralType) => + Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d)) + case (_: FractionalType, _: DecimalType) | (_: DecimalType, _: FractionalType) => + Some(DoubleType) + case _ => None + }) + case None => None + }) + } + private def haveSameType(exprs: Seq[Expression]): Boolean = exprs.map(_.dataType).distinct.length == 1 @@ -181,8 +188,8 @@ object HiveTypeCoercion { q transformExpressions { case a: AttributeReference => inputMap.get(a.exprId) match { - // This can happen when a Attribute reference is born in a non-leaf node, for example - // due to a call to an external script like in the Transform operator. + // This can happen when an Attribute reference is born in a non-leaf node, for + // example due to a call to an external script like in the Transform operator. // TODO: Perhaps those should actually be aliases? case None => a // Leave the same if the dataTypes match. @@ -293,11 +300,6 @@ object HiveTypeCoercion { // Skip nodes who's children have not been resolved yet. 
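(Editorial aside.) A small sketch of how the widening helpers above are expected to behave, assuming they are reachable on the `TypeCoercion` object as declared:

```scala
import org.apache.spark.sql.catalyst.analysis.TypeCoercion
import org.apache.spark.sql.types._

// Plain numeric widening follows numericPrecedence: int and long widen to long.
TypeCoercion.findTightestCommonTypeOfTwo(IntegerType, LongType)                         // Some(LongType)

// Decimals are widened even when precision must be truncated (here to double)...
TypeCoercion.findWiderTypeWithoutStringPromotion(Seq(DecimalType(38, 18), DoubleType))  // Some(DoubleType)

// ...but a string operand never forces the other side to be promoted to string.
TypeCoercion.findWiderTypeWithoutStringPromotion(Seq(IntegerType, StringType))          // None
```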
case e if !e.childrenResolved => e - case a @ BinaryArithmetic(left @ StringType(), right @ DecimalType.Expression(_, _)) => - a.makeCopy(Array(Cast(left, DecimalType.SYSTEM_DEFAULT), right)) - case a @ BinaryArithmetic(left @ DecimalType.Expression(_, _), right @ StringType()) => - a.makeCopy(Array(left, Cast(right, DecimalType.SYSTEM_DEFAULT))) - case a @ BinaryArithmetic(left @ StringType(), right) => a.makeCopy(Array(Cast(left, DoubleType), right)) case a @ BinaryArithmetic(left, right @ StringType()) => @@ -448,7 +450,7 @@ object HiveTypeCoercion { case a @ CreateArray(children) if !haveSameType(children) => val types = children.map(_.dataType) - findTightestCommonTypeAndPromoteToString(types) match { + findWiderCommonType(types) match { case Some(finalDataType) => CreateArray(children.map(Cast(_, finalDataType))) case None => a } @@ -459,7 +461,7 @@ object HiveTypeCoercion { m.keys } else { val types = m.keys.map(_.dataType) - findTightestCommonTypeAndPromoteToString(types) match { + findWiderCommonType(types) match { case Some(finalDataType) => m.keys.map(Cast(_, finalDataType)) case None => m.keys } @@ -469,7 +471,7 @@ object HiveTypeCoercion { m.values } else { val types = m.values.map(_.dataType) - findTightestCommonTypeAndPromoteToString(types) match { + findWiderCommonType(types) match { case Some(finalDataType) => m.values.map(Cast(_, finalDataType)) case None => m.values } @@ -502,16 +504,19 @@ object HiveTypeCoercion { case None => c } + // When finding wider type for `Greatest` and `Least`, we should handle decimal types even if + // we need to truncate, but we should not promote one side to string if the other side is + // string. case g @ Greatest(children) if !haveSameType(children) => val types = children.map(_.dataType) - findTightestCommonType(types) match { + findWiderTypeWithoutStringPromotion(types) match { case Some(finalDataType) => Greatest(children.map(Cast(_, finalDataType))) case None => g } case l @ Least(children) if !haveSameType(children) => val types = children.map(_.dataType) - findTightestCommonType(types) match { + findWiderTypeWithoutStringPromotion(types) match { case Some(finalDataType) => Least(children.map(Cast(_, finalDataType))) case None => l } @@ -520,6 +525,8 @@ object HiveTypeCoercion { NaNvl(l, Cast(r, DoubleType)) case NaNvl(l, r) if l.dataType == FloatType && r.dataType == DoubleType => NaNvl(Cast(l, DoubleType), r) + + case e: RuntimeReplaceable => e.replaceForTypeCoercion() } } @@ -531,13 +538,18 @@ object HiveTypeCoercion { def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes that have not been resolved yet, // as this is an extra rule which should be applied last. - case e if !e.resolved => e + case e if !e.childrenResolved => e // Decimal and Double remain the same case d: Divide if d.dataType == DoubleType => d case d: Divide if d.dataType.isInstanceOf[DecimalType] => d + case Divide(left, right) if isNumericOrNull(left) && isNumericOrNull(right) => + Divide(Cast(left, DoubleType), Cast(right, DoubleType)) + } - case Divide(left, right) => Divide(Cast(left, DoubleType), Cast(right, DoubleType)) + private def isNumericOrNull(ex: Expression): Boolean = { + // We need to handle null types in case a query contains null literals.
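(Editorial aside.) For integer operands, the tightened `Division` rule above is expected to produce a plan equivalent to the following, purely illustrative, expression:

```scala
import org.apache.spark.sql.catalyst.expressions.{Cast, Divide, Literal}
import org.apache.spark.sql.types.DoubleType

// "SELECT 1 / 2" is coerced so both sides become doubles, yielding 0.5 rather than 0;
// operands that are already double or decimal are left untouched by the rule.
val coerced = Divide(Cast(Literal(1), DoubleType), Cast(Literal(2), DoubleType))
```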
+ ex.dataType.isInstanceOf[NumericType] || ex.dataType == NullType } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index aadc1d31bd4b2..f6e32e29ebca8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -17,9 +17,10 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.{AnalysisException, InternalOutputModes} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.streaming.OutputMode /** * Analyzes the presence of unsupported operations in a logical plan. @@ -29,8 +30,7 @@ object UnsupportedOperationChecker { def checkForBatch(plan: LogicalPlan): Unit = { plan.foreachUp { case p if p.isStreaming => - throwError( - "Queries with streaming sources must be executed with write.startStream()")(p) + throwError("Queries with streaming sources must be executed with writeStream.start()")(p) case _ => } @@ -40,13 +40,48 @@ object UnsupportedOperationChecker { if (!plan.isStreaming) { throwError( - "Queries without streaming sources cannot be executed with write.startStream()")(plan) + "Queries without streaming sources cannot be executed with writeStream.start()")(plan) + } + + // Disallow multiple streaming aggregations + val aggregates = plan.collect { case a@Aggregate(_, _, _) if a.isStreaming => a } + + if (aggregates.size > 1) { + throwError( + "Multiple streaming aggregations are not supported with " + + "streaming DataFrames/Datasets")(plan) } - plan.foreachUp { implicit plan => + // Disallow some output mode + outputMode match { + case InternalOutputModes.Append if aggregates.nonEmpty => + throwError( + s"$outputMode output mode not supported when there are streaming aggregations on " + + s"streaming DataFrames/DataSets")(plan) + + case InternalOutputModes.Complete | InternalOutputModes.Update if aggregates.isEmpty => + throwError( + s"$outputMode output mode not supported when there are no streaming aggregations on " + + s"streaming DataFrames/Datasets")(plan) + + case _ => + } + + /** + * Whether the subplan will contain complete data or incremental data in every incremental + * execution. Some operations may be allowed only when the child logical plan gives complete + * data. + */ + def containsCompleteData(subplan: LogicalPlan): Boolean = { + val aggs = plan.collect { case a@Aggregate(_, _, _) if a.isStreaming => a } + // Either the subplan has no streaming source, or it has aggregation with Complete mode + !subplan.isStreaming || (aggs.nonEmpty && outputMode == InternalOutputModes.Complete) + } + + plan.foreachUp { implicit subPlan => // Operations that cannot exists anywhere in a streaming plan - plan match { + subPlan match { case _: Command => throwError("Commands like CreateTable*, AlterTable*, Show* are not supported with " + @@ -55,11 +90,6 @@ object UnsupportedOperationChecker { case _: InsertIntoTable => throwError("InsertIntoTable is not supported with streaming DataFrames/Datasets") - case Aggregate(_, _, child) if child.isStreaming && outputMode == Append => - throwError( - "Aggregations are not supported on streaming DataFrames/Datasets in " + - "Append output mode. 
Consider changing output mode to Update.") - case Join(left, right, joinType, _) => joinType match { @@ -109,11 +139,12 @@ case GroupingSets(_, _, child, _) if child.isStreaming => throwError("GroupingSets is not supported on streaming DataFrames/Datasets") - case GlobalLimit(_, _) | LocalLimit(_, _) if plan.children.forall(_.isStreaming) => + case GlobalLimit(_, _) | LocalLimit(_, _) if subPlan.children.forall(_.isStreaming) => throwError("Limits are not supported on streaming DataFrames/Datasets") - case Sort(_, _, _) | SortPartitions(_, _) if plan.children.forall(_.isStreaming) => - throwError("Sorting is not supported on streaming DataFrames/Datasets") + case Sort(_, _, _) | SortPartitions(_, _) if !containsCompleteData(subPlan) => + throwError("Sorting is not supported on streaming DataFrames/Datasets, unless it is on " + + "aggregated DataFrame/Dataset in Complete mode") case Sample(_, _, _, _, child) if child.isStreaming => throwError("Sampling is not supported on streaming DataFrames/Datasets") @@ -123,7 +154,7 @@ case ReturnAnswer(child) if child.isStreaming => throwError("Cannot return immediate result on streaming DataFrames/Dataset. Queries " + - "with streaming DataFrames/Datasets must be executed with write.startStream().") + "with streaming DataFrames/Datasets must be executed with writeStream.start().") case _ => } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index f82b63ad96764..15239b99c946c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -49,6 +49,37 @@ case class UnresolvedRelation( override lazy val resolved = false } +/** + * An inline table that has not been resolved yet. Once resolved, it is turned by the analyzer into + * a [[org.apache.spark.sql.catalyst.plans.logical.LocalRelation]]. + * + * @param names list of column names + * @param rows expressions for the data + */ +case class UnresolvedInlineTable( + names: Seq[String], + rows: Seq[Seq[Expression]]) + extends LeafNode { + + lazy val expressionsResolved: Boolean = rows.forall(_.forall(_.resolved)) + override lazy val resolved = false + override def output: Seq[Attribute] = Nil +} + +/** + * A table-valued function, e.g. + * {{{ + * select * from range(10); + * }}} + */ +case class UnresolvedTableValuedFunction(functionName: String, functionArgs: Seq[Expression]) + extends LeafNode { + + override def output: Seq[Attribute] = Nil + + override lazy val resolved = false +} + /** * Holds the name of an attribute that has yet to be resolved.
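(Editorial aside.) Returning briefly to the `UnsupportedOperationChecker` changes above, a hedged sketch of how the new output-mode rules surface through the public API, assuming a `SparkSession` named `spark` and the built-in socket source:

```scala
// A streaming aggregation is present, so checkForStreaming rejects Append mode,
// while Complete mode is accepted (Complete or Update without any aggregation is rejected too).
val counts = spark.readStream
  .format("socket").option("host", "localhost").option("port", 9999).load()
  .groupBy("value").count()

counts.writeStream.outputMode("complete").format("console").start()   // passes the checker
// counts.writeStream.outputMode("append").format("console").start()  // would throw AnalysisException
```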
*/ @@ -142,8 +173,7 @@ object UnresolvedAttribute { case class UnresolvedGenerator(name: FunctionIdentifier, children: Seq[Expression]) extends Generator { - override def elementTypes: Seq[(DataType, Boolean, String)] = - throw new UnresolvedException(this, "elementTypes") + override def elementSchema: StructType = throw new UnresolvedException(this, "elementTypes") override def dataType: DataType = throw new UnresolvedException(this, "dataType") override def foldable: Boolean = throw new UnresolvedException(this, "foldable") override def nullable: Boolean = throw new UnresolvedException(this, "nullable") @@ -216,23 +246,20 @@ abstract class Star extends LeafExpression with NamedExpression { case class UnresolvedStar(target: Option[Seq[String]]) extends Star with Unevaluable { override def expand(input: LogicalPlan, resolver: Resolver): Seq[NamedExpression] = { + // If there is no table specified, use all input attributes. + if (target.isEmpty) return input.output - // First try to expand assuming it is table.*. - val expandedAttributes: Seq[Attribute] = target match { - // If there is no table specified, use all input attributes. - case None => input.output - // If there is a table, pick out attributes that are part of this table. - case Some(t) => if (t.size == 1) { - input.output.filter(_.qualifier.exists(resolver(_, t.head))) + val expandedAttributes = + if (target.get.size == 1) { + // If there is a table, pick out attributes that are part of this table. + input.output.filter(_.qualifier.exists(resolver(_, target.get.head))) } else { List() } - } if (expandedAttributes.nonEmpty) return expandedAttributes // Try to resolve it as a struct expansion. If there is a conflict and both are possible, // (i.e. [name].* is both a table and a struct), the struct path can always be qualified. - require(target.isDefined) val attribute = input.resolve(target.get, resolver) if (attribute.isDefined) { // This target resolved to an attribute in child. It must be a struct. Expand it. @@ -326,10 +353,13 @@ case class UnresolvedExtractValue(child: Expression, extraction: Expression) * Holds the expression that has yet to be aliased. * * @param child The computation that needs to be resolved during analysis.
- * @param aliasName The name if specified to be associated with the result of computing [[child]] + * @param aliasFunc The function if specified to be called to generate an alias to associate + * with the result of computing [[child]] * */ -case class UnresolvedAlias(child: Expression, aliasName: Option[String] = None) +case class UnresolvedAlias( + child: Expression, + aliasFunc: Option[Expression => String] = None) extends UnaryExpression with NamedExpression with Unevaluable { override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute") @@ -364,3 +394,10 @@ case class UnresolvedDeserializer(deserializer: Expression, inputAttributes: Seq override def nullable: Boolean = throw new UnresolvedException(this, "nullable") override lazy val resolved = false } + +case class GetColumnByOrdinal(ordinal: Int, dataType: DataType) extends LeafExpression + with Unevaluable with NonSQLExpression { + override def foldable: Boolean = throw new UnresolvedException(this, "foldable") + override def nullable: Boolean = throw new UnresolvedException(this, "nullable") + override lazy val resolved = false +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala index 178ae6d7c21c6..4371ff3d58a3b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.catalog -import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.{FunctionAlreadyExistsException, NoSuchDatabaseException, NoSuchFunctionException} /** @@ -27,14 +27,26 @@ import org.apache.spark.sql.AnalysisException * can be accessed in multiple threads. This is an external catalog because it is expected to * interact with external systems. * - * Implementations should throw [[AnalysisException]] when table or database don't exist. + * Implementations should throw [[NoSuchDatabaseException]] when databases don't exist. 
*/ abstract class ExternalCatalog { import CatalogTypes.TablePartitionSpec protected def requireDbExists(db: String): Unit = { if (!databaseExists(db)) { - throw new AnalysisException(s"Database '$db' does not exist") + throw new NoSuchDatabaseException(db) + } + } + + protected def requireFunctionExists(db: String, funcName: String): Unit = { + if (!functionExists(db, funcName)) { + throw new NoSuchFunctionException(db = db, func = funcName) + } + } + + protected def requireFunctionNotExists(db: String, funcName: String): Unit = { + if (functionExists(db, funcName)) { + throw new FunctionAlreadyExistsException(db = db, func = funcName) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 60eb7329f71dc..ef5a19687ce86 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -17,10 +17,17 @@ package org.apache.spark.sql.catalyst.catalog +import java.io.IOException + import scala.collection.mutable +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path + +import org.apache.spark.SparkException import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.util.StringUtils /** @@ -32,7 +39,7 @@ import org.apache.spark.sql.catalyst.util.StringUtils * * All public methods should be synchronized for thread-safety. */ -class InMemoryCatalog extends ExternalCatalog { +class InMemoryCatalog(hadoopConfig: Configuration = new Configuration) extends ExternalCatalog { import CatalogTypes.TablePartitionSpec private class TableDesc(var table: CatalogTable) { @@ -52,24 +59,37 @@ class InMemoryCatalog extends ExternalCatalog { catalog(db).tables(table).partitions.contains(spec) } - private def requireFunctionExists(db: String, funcName: String): Unit = { - if (!functionExists(db, funcName)) { - throw new AnalysisException( - s"Function not found: '$funcName' does not exist in database '$db'") + private def requireTableExists(db: String, table: String): Unit = { + if (!tableExists(db, table)) { + throw new NoSuchTableException(db = db, table = table) } } - private def requireTableExists(db: String, table: String): Unit = { - if (!tableExists(db, table)) { - throw new AnalysisException( - s"Table or view not found: '$table' does not exist in database '$db'") + private def requireTableNotExists(db: String, table: String): Unit = { + if (tableExists(db, table)) { + throw new TableAlreadyExistsException(db = db, table = table) } } - private def requirePartitionExists(db: String, table: String, spec: TablePartitionSpec): Unit = { - if (!partitionExists(db, table, spec)) { - throw new AnalysisException( - s"Partition not found: database '$db' table '$table' does not contain: '$spec'") + private def requirePartitionsExist( + db: String, + table: String, + specs: Seq[TablePartitionSpec]): Unit = { + specs.foreach { s => + if (!partitionExists(db, table, s)) { + throw new NoSuchPartitionException(db = db, table = table, spec = s) + } + } + } + + private def requirePartitionsNotExist( + db: String, + table: String, + specs: Seq[TablePartitionSpec]): Unit = { + specs.foreach { s => + if (partitionExists(db, table, s)) { + throw new PartitionAlreadyExistsException(db = 
db, table = table, spec = s) + } } } @@ -82,9 +102,18 @@ class InMemoryCatalog extends ExternalCatalog { ignoreIfExists: Boolean): Unit = synchronized { if (catalog.contains(dbDefinition.name)) { if (!ignoreIfExists) { - throw new AnalysisException(s"Database '${dbDefinition.name}' already exists.") + throw new DatabaseAlreadyExistsException(dbDefinition.name) } } else { + try { + val location = new Path(dbDefinition.locationUri) + val fs = location.getFileSystem(hadoopConfig) + fs.mkdirs(location) + } catch { + case e: IOException => + throw new SparkException(s"Unable to create database ${dbDefinition.name} as failed " + + s"to create its directory ${dbDefinition.locationUri}", e) + } catalog.put(dbDefinition.name, new DatabaseDesc(dbDefinition)) } } @@ -104,10 +133,20 @@ class InMemoryCatalog extends ExternalCatalog { } } // Remove the database. + val dbDefinition = catalog(db).db + try { + val location = new Path(dbDefinition.locationUri) + val fs = location.getFileSystem(hadoopConfig) + fs.delete(location, true) + } catch { + case e: IOException => + throw new SparkException(s"Unable to drop database ${dbDefinition.name} as failed " + + s"to delete its directory ${dbDefinition.locationUri}", e) + } catalog.remove(db) } else { if (!ignoreIfNotExists) { - throw new AnalysisException(s"Database '$db' does not exist") + throw new NoSuchDatabaseException(db) } } } @@ -127,7 +166,7 @@ class InMemoryCatalog extends ExternalCatalog { } override def listDatabases(): Seq[String] = synchronized { - catalog.keySet.toSeq + catalog.keySet.toSeq.sorted } override def listDatabases(pattern: String): Seq[String] = synchronized { @@ -148,9 +187,20 @@ class InMemoryCatalog extends ExternalCatalog { val table = tableDefinition.identifier.table if (tableExists(db, table)) { if (!ignoreIfExists) { - throw new AnalysisException(s"Table '$table' already exists in database '$db'") + throw new TableAlreadyExistsException(db = db, table = table) } } else { + if (tableDefinition.tableType == CatalogTableType.MANAGED) { + val dir = new Path(catalog(db).db.locationUri, table) + try { + val fs = dir.getFileSystem(hadoopConfig) + fs.mkdirs(dir) + } catch { + case e: IOException => + throw new SparkException(s"Unable to create table $table as failed " + + s"to create its directory $dir", e) + } + } catalog(db).tables.put(table, new TableDesc(tableDefinition)) } } @@ -161,18 +211,44 @@ class InMemoryCatalog extends ExternalCatalog { ignoreIfNotExists: Boolean): Unit = synchronized { requireDbExists(db) if (tableExists(db, table)) { + if (getTable(db, table).tableType == CatalogTableType.MANAGED) { + val dir = new Path(catalog(db).db.locationUri, table) + try { + val fs = dir.getFileSystem(hadoopConfig) + fs.delete(dir, true) + } catch { + case e: IOException => + throw new SparkException(s"Unable to drop table $table as failed " + + s"to delete its directory $dir", e) + } + } catalog(db).tables.remove(table) } else { if (!ignoreIfNotExists) { - throw new AnalysisException(s"Table or View '$table' does not exist in database '$db'") + throw new NoSuchTableException(db = db, table = table) } } } override def renameTable(db: String, oldName: String, newName: String): Unit = synchronized { requireTableExists(db, oldName) + requireTableNotExists(db, newName) val oldDesc = catalog(db).tables(oldName) oldDesc.table = oldDesc.table.copy(identifier = TableIdentifier(newName, Some(db))) + + if (oldDesc.table.tableType == CatalogTableType.MANAGED) { + val oldDir = new Path(catalog(db).db.locationUri, oldName) + val newDir = new 
Path(catalog(db).db.locationUri, newName) + try { + val fs = oldDir.getFileSystem(hadoopConfig) + fs.rename(oldDir, newDir) + } catch { + case e: IOException => + throw new SparkException(s"Unable to rename table $oldName to $newName as failed " + + s"to rename its directory $oldDir", e) + } + } + catalog(db).tables.put(newName, oldDesc) catalog(db).tables.remove(oldName) } @@ -198,7 +274,7 @@ class InMemoryCatalog extends ExternalCatalog { override def listTables(db: String): Seq[String] = synchronized { requireDbExists(db) - catalog(db).tables.keySet.toSeq + catalog(db).tables.keySet.toSeq.sorted } override def listTables(db: String, pattern: String): Seq[String] = synchronized { @@ -211,7 +287,7 @@ class InMemoryCatalog extends ExternalCatalog { loadPath: String, isOverwrite: Boolean, holdDDLTime: Boolean): Unit = { - throw new AnalysisException("loadTable is not implemented for InMemoryCatalog.") + throw new UnsupportedOperationException("loadTable is not implemented") } override def loadPartition( @@ -223,7 +299,7 @@ class InMemoryCatalog extends ExternalCatalog { holdDDLTime: Boolean, inheritTableSpecs: Boolean, isSkewedStoreAsSubdir: Boolean): Unit = { - throw new AnalysisException("loadPartition is not implemented for InMemoryCatalog.") + throw new UnsupportedOperationException("loadPartition is not implemented.") } // -------------------------------------------------------------------------- @@ -240,12 +316,30 @@ class InMemoryCatalog extends ExternalCatalog { if (!ignoreIfExists) { val dupSpecs = parts.collect { case p if existingParts.contains(p.spec) => p.spec } if (dupSpecs.nonEmpty) { - val dupSpecsStr = dupSpecs.mkString("\n===\n") - throw new AnalysisException("The following partitions already exist in database " + - s"'$db' table '$table':\n$dupSpecsStr") + throw new PartitionsAlreadyExistException(db = db, table = table, specs = dupSpecs) } } - parts.foreach { p => existingParts.put(p.spec, p) } + + val tableDir = new Path(catalog(db).db.locationUri, table) + val partitionColumnNames = getTable(db, table).partitionColumnNames + // TODO: we should follow hive to roll back if one partition path failed to create. + parts.foreach { p => + // If location is set, the partition is using an external partition location and we don't + // need to handle its directory. + if (p.storage.locationUri.isEmpty) { + val partitionPath = partitionColumnNames.flatMap { col => + p.spec.get(col).map(col + "=" + _) + }.mkString("/") + try { + val fs = tableDir.getFileSystem(hadoopConfig) + fs.mkdirs(new Path(tableDir, partitionPath)) + } catch { + case e: IOException => + throw new SparkException(s"Unable to create partition path $partitionPath", e) + } + } + existingParts.put(p.spec, p) + } } override def dropPartitions( @@ -258,12 +352,30 @@ class InMemoryCatalog extends ExternalCatalog { if (!ignoreIfNotExists) { val missingSpecs = partSpecs.collect { case s if !existingParts.contains(s) => s } if (missingSpecs.nonEmpty) { - val missingSpecsStr = missingSpecs.mkString("\n===\n") - throw new AnalysisException("The following partitions do not exist in database " + - s"'$db' table '$table':\n$missingSpecsStr") + throw new NoSuchPartitionsException(db = db, table = table, specs = missingSpecs) } } - partSpecs.foreach(existingParts.remove) + + val tableDir = new Path(catalog(db).db.locationUri, table) + val partitionColumnNames = getTable(db, table).partitionColumnNames + // TODO: we should follow hive to roll back if one partition path failed to delete. 
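(Editorial aside.) A self-contained illustration of the `col=value` path layout assumed by the partition handling above; the column names and values here are made up:

```scala
// Partition directories are built from the table's partition columns, in order, and appended
// to the table directory, e.g. <warehouse>/mydb.db/mytable/ds=2016-06-01/hr=10
val partitionColumnNames = Seq("ds", "hr")
val spec = Map("ds" -> "2016-06-01", "hr" -> "10")

val partitionPath = partitionColumnNames.flatMap { col =>
  spec.get(col).map(col + "=" + _)
}.mkString("/")
// partitionPath == "ds=2016-06-01/hr=10"
```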
+ partSpecs.foreach { p => + // If location is set, the partition is using an external partition location and we don't + // need to handle its directory. + if (existingParts.contains(p) && existingParts(p).storage.locationUri.isEmpty) { + val partitionPath = partitionColumnNames.flatMap { col => + p.get(col).map(col + "=" + _) + }.mkString("/") + try { + val fs = tableDir.getFileSystem(hadoopConfig) + fs.delete(new Path(tableDir, partitionPath), true) + } catch { + case e: IOException => + throw new SparkException(s"Unable to delete partition path $partitionPath", e) + } + } + existingParts.remove(p) + } } override def renamePartitions( @@ -272,9 +384,34 @@ class InMemoryCatalog extends ExternalCatalog { specs: Seq[TablePartitionSpec], newSpecs: Seq[TablePartitionSpec]): Unit = synchronized { require(specs.size == newSpecs.size, "number of old and new partition specs differ") + requirePartitionsExist(db, table, specs) + requirePartitionsNotExist(db, table, newSpecs) + + val tableDir = new Path(catalog(db).db.locationUri, table) + val partitionColumnNames = getTable(db, table).partitionColumnNames + // TODO: we should follow hive to roll back if one partition path failed to rename. specs.zip(newSpecs).foreach { case (oldSpec, newSpec) => val newPart = getPartition(db, table, oldSpec).copy(spec = newSpec) val existingParts = catalog(db).tables(table).partitions + + // If location is set, the partition is using an external partition location and we don't + // need to handle its directory. + if (newPart.storage.locationUri.isEmpty) { + val oldPath = partitionColumnNames.flatMap { col => + oldSpec.get(col).map(col + "=" + _) + }.mkString("/") + val newPath = partitionColumnNames.flatMap { col => + newSpec.get(col).map(col + "=" + _) + }.mkString("/") + try { + val fs = tableDir.getFileSystem(hadoopConfig) + fs.rename(new Path(tableDir, oldPath), new Path(tableDir, newPath)) + } catch { + case e: IOException => + throw new SparkException(s"Unable to rename partition path $oldPath", e) + } + } + existingParts.remove(oldSpec) existingParts.put(newSpec, newPart) } @@ -284,8 +421,8 @@ class InMemoryCatalog extends ExternalCatalog { db: String, table: String, parts: Seq[CatalogTablePartition]): Unit = synchronized { + requirePartitionsExist(db, table, parts.map(p => p.spec)) parts.foreach { p => - requirePartitionExists(db, table, p.spec) catalog(db).tables(table).partitions.put(p.spec, p) } } @@ -294,7 +431,7 @@ class InMemoryCatalog extends ExternalCatalog { db: String, table: String, spec: TablePartitionSpec): CatalogTablePartition = synchronized { - requirePartitionExists(db, table, spec) + requirePartitionsExist(db, table, Seq(spec)) catalog(db).tables(table).partitions(spec) } @@ -304,8 +441,8 @@ class InMemoryCatalog extends ExternalCatalog { partialSpec: Option[TablePartitionSpec] = None): Seq[CatalogTablePartition] = synchronized { requireTableExists(db, table) if (partialSpec.nonEmpty) { - throw new AnalysisException("listPartition does not support partition spec in " + - "InMemoryCatalog.") + throw new UnsupportedOperationException( + "listPartition with partial partition spec is not implemented") } catalog(db).tables(table).partitions.values.toSeq } @@ -316,11 +453,8 @@ class InMemoryCatalog extends ExternalCatalog { override def createFunction(db: String, func: CatalogFunction): Unit = synchronized { requireDbExists(db) - if (functionExists(db, func.identifier.funcName)) { - throw new AnalysisException(s"Function '$func' already exists in '$db' database") - } else { - 
catalog(db).functions.put(func.identifier.funcName, func) - } + requireFunctionNotExists(db, func.identifier.funcName) + catalog(db).functions.put(func.identifier.funcName, func) } override def dropFunction(db: String, funcName: String): Unit = synchronized { @@ -330,6 +464,7 @@ class InMemoryCatalog extends ExternalCatalog { override def renameFunction(db: String, oldName: String, newName: String): Unit = synchronized { requireFunctionExists(db, oldName) + requireFunctionNotExists(db, newName) val newFunc = getFunction(db, oldName).copy(identifier = FunctionIdentifier(newName, Some(db))) catalog(db).functions.remove(oldName) catalog(db).functions.put(newName, newFunc) @@ -340,7 +475,7 @@ class InMemoryCatalog extends ExternalCatalog { catalog(db).functions(funcName) } - override def functionExists(db: String, funcName: String): Boolean = { + override def functionExists(db: String, funcName: String): Boolean = synchronized { requireDbExists(db) catalog(db).functions.contains(funcName) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index a445a253ff9b2..f455cc90963f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.catalog -import java.io.File +import javax.annotation.concurrent.GuardedBy import scala.collection.mutable @@ -28,18 +28,22 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf} import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} -import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, NoSuchFunctionException, SimpleFunctionRegistry} +import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} import org.apache.spark.sql.catalyst.util.StringUtils +object SessionCatalog { + val DEFAULT_DATABASE = "default" +} + /** * An internal catalog that is used by a Spark Session. This internal catalog serves as a * proxy to the underlying metastore (e.g. Hive Metastore) and it also manages temporary * tables and functions of the Spark Session that it belongs to. * - * This class is not thread-safe. + * This class must be thread-safe. */ class SessionCatalog( externalCatalog: ExternalCatalog, @@ -47,6 +51,7 @@ class SessionCatalog( functionRegistry: FunctionRegistry, conf: CatalystConf, hadoopConf: Configuration) extends Logging { + import SessionCatalog._ import CatalogTypes.TablePartitionSpec // For testing only. @@ -68,19 +73,21 @@ class SessionCatalog( } /** List of temporary tables, mapping from table name to their logical plan. */ + @GuardedBy("this") protected val tempTables = new mutable.HashMap[String, LogicalPlan] // Note: we track current database here because certain operations do not explicitly // specify the database (e.g. DROP TABLE my_table). In these cases we must first // check whether the temporary table or function exists, then, if not, operate on // the corresponding item in the current database. 
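(Editorial aside.) A hedged sketch of the lookup order this comment describes, assuming a `SessionCatalog` instance named `catalog` and an already-built `LogicalPlan` called `viewPlan` (both hypothetical):

```scala
import org.apache.spark.sql.catalyst.TableIdentifier

catalog.createTempView("people", viewPlan, overrideIfExists = true)

// An unqualified name is checked against temporary views first...
catalog.lookupRelation(TableIdentifier("people"))                  // resolves to viewPlan
// ...while a database-qualified name always goes to the external catalog.
catalog.lookupRelation(TableIdentifier("people", Some("default"))) // metastore lookup
```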
+ @GuardedBy("this") protected var currentDb = { - val defaultName = "default" + val defaultName = DEFAULT_DATABASE val defaultDbDefinition = CatalogDatabase(defaultName, "default database", conf.warehousePath, Map()) // Initialize default database if it doesn't already exist createDatabase(defaultDbDefinition, ignoreIfExists = true) - defaultName + formatDatabaseName(defaultName) } /** @@ -90,6 +97,13 @@ class SessionCatalog( if (conf.caseSensitiveAnalysis) name else name.toLowerCase } + /** + * Format database name, taking into account case sensitivity. + */ + protected[this] def formatDatabaseName(name: String): String = { + if (conf.caseSensitiveAnalysis) name else name.toLowerCase + } + /** * This method is used to make the given path qualified before we * store this path in the underlying external catalog. So, when a path @@ -102,6 +116,25 @@ class SessionCatalog( fs.makeQualified(hadoopPath) } + private def requireDbExists(db: String): Unit = { + if (!databaseExists(db)) { + throw new NoSuchDatabaseException(db) + } + } + + private def requireTableExists(name: TableIdentifier): Unit = { + if (!tableExists(name)) { + val db = name.database.getOrElse(currentDb) + throw new NoSuchTableException(db = db, table = name.table) + } + } + + private def requireTableNotExists(name: TableIdentifier): Unit = { + if (tableExists(name)) { + val db = name.database.getOrElse(currentDb) + throw new TableAlreadyExistsException(db = db, table = name.table) + } + } // ---------------------------------------------------------------------------- // Databases // ---------------------------------------------------------------------------- @@ -110,25 +143,37 @@ class SessionCatalog( def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = { val qualifiedPath = makeQualifiedPath(dbDefinition.locationUri).toString + val dbName = formatDatabaseName(dbDefinition.name) externalCatalog.createDatabase( - dbDefinition.copy(locationUri = qualifiedPath), + dbDefinition.copy(name = dbName, locationUri = qualifiedPath), ignoreIfExists) } def dropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit = { - externalCatalog.dropDatabase(db, ignoreIfNotExists, cascade) + val dbName = formatDatabaseName(db) + if (dbName == DEFAULT_DATABASE) { + throw new AnalysisException(s"Can not drop default database") + } else if (dbName == getCurrentDatabase) { + throw new AnalysisException(s"Can not drop current database `${dbName}`") + } + externalCatalog.dropDatabase(dbName, ignoreIfNotExists, cascade) } def alterDatabase(dbDefinition: CatalogDatabase): Unit = { - externalCatalog.alterDatabase(dbDefinition) + val dbName = formatDatabaseName(dbDefinition.name) + requireDbExists(dbName) + externalCatalog.alterDatabase(dbDefinition.copy(name = dbName)) } def getDatabaseMetadata(db: String): CatalogDatabase = { - externalCatalog.getDatabase(db) + val dbName = formatDatabaseName(db) + requireDbExists(dbName) + externalCatalog.getDatabase(dbName) } def databaseExists(db: String): Boolean = { - externalCatalog.databaseExists(db) + val dbName = formatDatabaseName(db) + externalCatalog.databaseExists(dbName) } def listDatabases(): Seq[String] = { @@ -139,17 +184,20 @@ class SessionCatalog( externalCatalog.listDatabases(pattern) } - def getCurrentDatabase: String = currentDb + def getCurrentDatabase: String = synchronized { currentDb } def setCurrentDatabase(db: String): Unit = { - if (!databaseExists(db)) { - throw new AnalysisException(s"Database '$db' does not exist.") - } - currentDb = db + val 
dbName = formatDatabaseName(db) + requireDbExists(dbName) + synchronized { currentDb = dbName } } + /** + * Get the path for creating a non-default database when database location is not provided + * by users. + */ def getDefaultDBPath(db: String): String = { - val database = if (conf.caseSensitiveAnalysis) db else db.toLowerCase + val database = formatDatabaseName(db) new Path(new Path(conf.warehousePath), database + ".db").toString } @@ -171,9 +219,10 @@ class SessionCatalog( * If no such database is specified, create it in the current database. */ def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = { - val db = tableDefinition.identifier.database.getOrElse(currentDb) + val db = formatDatabaseName(tableDefinition.identifier.database.getOrElse(getCurrentDatabase)) val table = formatTableName(tableDefinition.identifier.table) val newTableDefinition = tableDefinition.copy(identifier = TableIdentifier(table, Some(db))) + requireDbExists(db) externalCatalog.createTable(db, newTableDefinition, ignoreIfExists) } @@ -187,20 +236,35 @@ class SessionCatalog( * this becomes a no-op. */ def alterTable(tableDefinition: CatalogTable): Unit = { - val db = tableDefinition.identifier.database.getOrElse(currentDb) + val db = formatDatabaseName(tableDefinition.identifier.database.getOrElse(getCurrentDatabase)) val table = formatTableName(tableDefinition.identifier.table) - val newTableDefinition = tableDefinition.copy(identifier = TableIdentifier(table, Some(db))) + val tableIdentifier = TableIdentifier(table, Some(db)) + val newTableDefinition = tableDefinition.copy(identifier = tableIdentifier) + requireDbExists(db) + requireTableExists(tableIdentifier) externalCatalog.alterTable(db, newTableDefinition) } /** - * Retrieve the metadata of an existing metastore table. - * If no database is specified, assume the table is in the current database. - * If the specified table is not found in the database then an [[AnalysisException]] is thrown. + * Return whether a table/view with the specified name exists. If no database is specified, check + * with current database. + */ + def tableExists(name: TableIdentifier): Boolean = synchronized { + val db = formatDatabaseName(name.database.getOrElse(currentDb)) + val table = formatTableName(name.table) + externalCatalog.tableExists(db, table) + } + + /** + * Retrieve the metadata of an existing permanent table/view. If no database is specified, + * assume the table/view is in the current database. If the specified table/view is not found + * in the database then a [[NoSuchTableException]] is thrown. */ def getTableMetadata(name: TableIdentifier): CatalogTable = { - val db = name.database.getOrElse(currentDb) + val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase)) val table = formatTableName(name.table) + requireDbExists(db) + requireTableExists(TableIdentifier(table, Some(db))) externalCatalog.getTable(db, table) } @@ -210,30 +274,33 @@ class SessionCatalog( * If the specified table is not found in the database then return None if it doesn't exist. */ def getTableMetadataOption(name: TableIdentifier): Option[CatalogTable] = { - val db = name.database.getOrElse(currentDb) + val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase)) val table = formatTableName(name.table) + requireDbExists(db) externalCatalog.getTableOption(db, table) } /** * Load files stored in given path into an existing metastore table. * If no database is specified, assume the table is in the current database. 
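(Editorial aside.) The normalization introduced above means that, under the default case-insensitive analysis, database and table names are lower-cased before reaching the external catalog. A tiny standalone restatement of that logic:

```scala
// Same normalization as formatDatabaseName / formatTableName above, pulled out for illustration.
def formatName(name: String, caseSensitiveAnalysis: Boolean): String =
  if (caseSensitiveAnalysis) name else name.toLowerCase

formatName("SALES", caseSensitiveAnalysis = false)  // "sales": SALES and sales are the same database
formatName("SALES", caseSensitiveAnalysis = true)   // "SALES": names are kept verbatim
```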
- * If the specified table is not found in the database then an [[AnalysisException]] is thrown. + * If the specified table is not found in the database then a [[NoSuchTableException]] is thrown. */ def loadTable( name: TableIdentifier, loadPath: String, isOverwrite: Boolean, holdDDLTime: Boolean): Unit = { - val db = name.database.getOrElse(currentDb) + val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase)) val table = formatTableName(name.table) + requireDbExists(db) + requireTableExists(TableIdentifier(table, Some(db))) externalCatalog.loadTable(db, table, loadPath, isOverwrite, holdDDLTime) } /** * Load files stored in given path into the partition of an existing metastore table. * If no database is specified, assume the table is in the current database. - * If the specified table is not found in the database then an [[AnalysisException]] is thrown. + * If the specified table is not found in the database then a [[NoSuchTableException]] is thrown. */ def loadPartition( name: TableIdentifier, @@ -243,37 +310,89 @@ class SessionCatalog( holdDDLTime: Boolean, inheritTableSpecs: Boolean, isSkewedStoreAsSubdir: Boolean): Unit = { - val db = name.database.getOrElse(currentDb) + val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase)) val table = formatTableName(name.table) + requireDbExists(db) + requireTableExists(TableIdentifier(table, Some(db))) externalCatalog.loadPartition(db, table, loadPath, partition, isOverwrite, holdDDLTime, inheritTableSpecs, isSkewedStoreAsSubdir) } def defaultTablePath(tableIdent: TableIdentifier): String = { - val dbName = tableIdent.database.getOrElse(currentDb) + val dbName = formatDatabaseName(tableIdent.database.getOrElse(getCurrentDatabase)) val dbLocation = getDatabaseMetadata(dbName).locationUri new Path(new Path(dbLocation), formatTableName(tableIdent.table)).toString } - // ------------------------------------------------------------- - // | Methods that interact with temporary and metastore tables | - // ------------------------------------------------------------- + // ---------------------------------------------- + // | Methods that interact with temp views only | + // ---------------------------------------------- /** * Create a temporary table. */ - def createTempTable( + def createTempView( name: String, tableDefinition: LogicalPlan, - overrideIfExists: Boolean): Unit = { + overrideIfExists: Boolean): Unit = synchronized { val table = formatTableName(name) if (tempTables.contains(table) && !overrideIfExists) { - throw new AnalysisException(s"Temporary table '$name' already exists.") + throw new TempTableAlreadyExistsException(name) } tempTables.put(table, tableDefinition) } + /** + * Return a temporary view exactly as it was stored. + */ + def getTempView(name: String): Option[LogicalPlan] = synchronized { + tempTables.get(formatTableName(name)) + } + + /** + * Drop a temporary view. + */ + def dropTempView(name: String): Unit = synchronized { + tempTables.remove(formatTableName(name)) + } + + // ------------------------------------------------------------- + // | Methods that interact with temporary and metastore tables | + // ------------------------------------------------------------- + + /** + * Retrieve the metadata of an existing temporary view or permanent table/view. + * + * If a database is specified in `name`, this will return the metadata of table/view in that + * database. 
+ * If no database is specified, this will first attempt to get the metadata of a temporary view + * with the same name, then, if that does not exist, return the metadata of table/view in the + * current database. + */ + def getTempViewOrPermanentTableMetadata(name: TableIdentifier): CatalogTable = synchronized { + val table = formatTableName(name.table) + if (name.database.isDefined) { + getTableMetadata(name) + } else { + getTempView(table).map { plan => + CatalogTable( + identifier = TableIdentifier(table), + tableType = CatalogTableType.VIEW, + storage = CatalogStorageFormat.empty, + schema = plan.output.map { c => + CatalogColumn( + name = c.name, + dataType = c.dataType.catalogString, + nullable = c.nullable + ) + }, + properties = Map(), + viewText = None) + }.getOrElse(getTableMetadata(name)) + } + } + /** * Rename a table. * @@ -283,16 +402,30 @@ class SessionCatalog( * * This assumes the database specified in `oldName` matches the one specified in `newName`. */ - def renameTable(oldName: TableIdentifier, newName: TableIdentifier): Unit = { - if (oldName.database != newName.database) { - throw new AnalysisException("rename does not support moving tables across databases") + def renameTable(oldName: TableIdentifier, newName: TableIdentifier): Unit = synchronized { + val db = formatDatabaseName(oldName.database.getOrElse(currentDb)) + requireDbExists(db) + val newDb = formatDatabaseName(newName.database.getOrElse(currentDb)) + if (db != newDb) { + throw new AnalysisException( + s"RENAME TABLE source and destination databases do not match: '$db' != '$newDb'") } - val db = oldName.database.getOrElse(currentDb) val oldTableName = formatTableName(oldName.table) val newTableName = formatTableName(newName.table) if (oldName.database.isDefined || !tempTables.contains(oldTableName)) { + requireTableExists(TableIdentifier(oldTableName, Some(db))) + requireTableNotExists(TableIdentifier(newTableName, Some(db))) externalCatalog.renameTable(db, oldTableName, newTableName) } else { + if (newName.database.isDefined) { + throw new AnalysisException( + s"RENAME TEMPORARY TABLE from '$oldName' to '$newName': cannot specify database " + + s"name '${newName.database.get}' in the destination table") + } + if (tempTables.contains(newTableName)) { + throw new AnalysisException( + s"RENAME TEMPORARY TABLE from '$oldName' to '$newName': destination table already exists") + } val table = tempTables(oldTableName) tempTables.remove(oldTableName) tempTables.put(newTableName, table) @@ -306,16 +439,17 @@ class SessionCatalog( * If no database is specified, this will first attempt to drop a temporary table with * the same name, then, if that does not exist, drop the table from the current database. */ - def dropTable(name: TableIdentifier, ignoreIfNotExists: Boolean): Unit = { - val db = name.database.getOrElse(currentDb) + def dropTable(name: TableIdentifier, ignoreIfNotExists: Boolean): Unit = synchronized { + val db = formatDatabaseName(name.database.getOrElse(currentDb)) val table = formatTableName(name.table) if (name.database.isDefined || !tempTables.contains(table)) { + requireDbExists(db) // When ignoreIfNotExists is false, no exception is issued when the table does not exist. // Instead, log it as an error message. 
- if (externalCatalog.tableExists(db, table)) { + if (tableExists(TableIdentifier(table, Option(db)))) { externalCatalog.dropTable(db, table, ignoreIfNotExists = true) } else if (!ignoreIfNotExists) { - logError(s"Table or View '${name.quotedString}' does not exist") + throw new NoSuchTableException(db = db, table = table) } } else { tempTables.remove(table) @@ -330,36 +464,20 @@ class SessionCatalog( * the same name, then, if that does not exist, return the table from the current database. */ def lookupRelation(name: TableIdentifier, alias: Option[String] = None): LogicalPlan = { - val db = name.database.getOrElse(currentDb) - val table = formatTableName(name.table) - val relation = - if (name.database.isDefined || !tempTables.contains(table)) { - val metadata = externalCatalog.getTable(db, table) - SimpleCatalogRelation(db, metadata, alias) - } else { - tempTables(table) - } - val qualifiedTable = SubqueryAlias(table, relation) - // If an alias was specified by the lookup, wrap the plan in a subquery so that - // attributes are properly qualified with this alias. - alias.map(a => SubqueryAlias(a, qualifiedTable)).getOrElse(qualifiedTable) - } - - /** - * Return whether a table with the specified name exists. - * - * Note: If a database is explicitly specified, then this will return whether the table - * exists in that particular database instead. In that case, even if there is a temporary - * table with the same name, we will return false if the specified database does not - * contain the table. - */ - def tableExists(name: TableIdentifier): Boolean = { - val db = name.database.getOrElse(currentDb) - val table = formatTableName(name.table) - if (name.database.isDefined || !tempTables.contains(table)) { - externalCatalog.tableExists(db, table) - } else { - true // it's a temporary table + synchronized { + val db = formatDatabaseName(name.database.getOrElse(currentDb)) + val table = formatTableName(name.table) + val relation = + if (name.database.isDefined || !tempTables.contains(table)) { + val metadata = externalCatalog.getTable(db, table) + SimpleCatalogRelation(db, metadata, alias) + } else { + tempTables(table) + } + val qualifiedTable = SubqueryAlias(table, relation) + // If an alias was specified by the lookup, wrap the plan in a subquery so that + // attributes are properly qualified with this alias. + alias.map(a => SubqueryAlias(a, qualifiedTable)).getOrElse(qualifiedTable) } } @@ -369,7 +487,7 @@ class SessionCatalog( * Note: The temporary table cache is checked only when database is not * explicitly specified. */ - def isTemporaryTable(name: TableIdentifier): Boolean = { + def isTemporaryTable(name: TableIdentifier): Boolean = synchronized { name.database.isEmpty && tempTables.contains(formatTableName(name.table)) } @@ -382,41 +500,37 @@ class SessionCatalog( * List all matching tables in the specified database, including temporary tables. 
*/ def listTables(db: String, pattern: String): Seq[TableIdentifier] = { + val dbName = formatDatabaseName(db) + requireDbExists(dbName) val dbTables = - externalCatalog.listTables(db, pattern).map { t => TableIdentifier(t, Some(db)) } - val _tempTables = StringUtils.filterPattern(tempTables.keys.toSeq, pattern) - .map { t => TableIdentifier(t) } - dbTables ++ _tempTables + externalCatalog.listTables(dbName, pattern).map { t => TableIdentifier(t, Some(dbName)) } + synchronized { + val _tempTables = StringUtils.filterPattern(tempTables.keys.toSeq, pattern) + .map { t => TableIdentifier(t) } + dbTables ++ _tempTables + } } - // TODO: It's strange that we have both refresh and invalidate here. - /** * Refresh the cache entry for a metastore table, if any. */ - def refreshTable(name: TableIdentifier): Unit = { /* no-op */ } - - /** - * Invalidate the cache entry for a metastore table, if any. - */ - def invalidateTable(name: TableIdentifier): Unit = { /* no-op */ } + def refreshTable(name: TableIdentifier): Unit = { + // Go through temporary tables and invalidate them. + // If the database is defined, this is definitely not a temp table. + // If the database is not defined, there is a good chance this is a temp table. + if (name.database.isEmpty) { + tempTables.get(formatTableName(name.table)).foreach(_.refresh()) + } + } /** * Drop all existing temporary tables. * For testing only. */ - def clearTempTables(): Unit = { + def clearTempTables(): Unit = synchronized { tempTables.clear() } - /** - * Return a temporary table exactly as it was stored. - * For testing only. - */ - private[catalog] def getTempTable(name: String): Option[LogicalPlan] = { - tempTables.get(name) - } - // ---------------------------------------------------------------------------- // Partitions // ---------------------------------------------------------------------------- @@ -437,8 +551,11 @@ class SessionCatalog( tableName: TableIdentifier, parts: Seq[CatalogTablePartition], ignoreIfExists: Boolean): Unit = { - val db = tableName.database.getOrElse(currentDb) + val db = formatDatabaseName(tableName.database.getOrElse(getCurrentDatabase)) val table = formatTableName(tableName.table) + requireDbExists(db) + requireTableExists(TableIdentifier(table, Option(db))) + requireExactMatchedPartitionSpec(parts.map(_.spec), getTableMetadata(tableName)) externalCatalog.createPartitions(db, table, parts, ignoreIfExists) } @@ -448,11 +565,14 @@ class SessionCatalog( */ def dropPartitions( tableName: TableIdentifier, - parts: Seq[TablePartitionSpec], + specs: Seq[TablePartitionSpec], ignoreIfNotExists: Boolean): Unit = { - val db = tableName.database.getOrElse(currentDb) + val db = formatDatabaseName(tableName.database.getOrElse(getCurrentDatabase)) val table = formatTableName(tableName.table) - externalCatalog.dropPartitions(db, table, parts, ignoreIfNotExists) + requireDbExists(db) + requireTableExists(TableIdentifier(table, Option(db))) + requirePartialMatchedPartitionSpec(specs, getTableMetadata(tableName)) + externalCatalog.dropPartitions(db, table, specs, ignoreIfNotExists) } /** @@ -465,8 +585,13 @@ class SessionCatalog( tableName: TableIdentifier, specs: Seq[TablePartitionSpec], newSpecs: Seq[TablePartitionSpec]): Unit = { - val db = tableName.database.getOrElse(currentDb) + val tableMetadata = getTableMetadata(tableName) + val db = formatDatabaseName(tableName.database.getOrElse(getCurrentDatabase)) val table = formatTableName(tableName.table) + requireDbExists(db) + requireTableExists(TableIdentifier(table, Option(db))) + 
requireExactMatchedPartitionSpec(specs, tableMetadata) + requireExactMatchedPartitionSpec(newSpecs, tableMetadata) externalCatalog.renamePartitions(db, table, specs, newSpecs) } @@ -480,8 +605,11 @@ class SessionCatalog( * this becomes a no-op. */ def alterPartitions(tableName: TableIdentifier, parts: Seq[CatalogTablePartition]): Unit = { - val db = tableName.database.getOrElse(currentDb) + val db = formatDatabaseName(tableName.database.getOrElse(getCurrentDatabase)) val table = formatTableName(tableName.table) + requireDbExists(db) + requireTableExists(TableIdentifier(table, Option(db))) + requireExactMatchedPartitionSpec(parts.map(_.spec), getTableMetadata(tableName)) externalCatalog.alterPartitions(db, table, parts) } @@ -490,8 +618,11 @@ class SessionCatalog( * If no database is specified, assume the table is in the current database. */ def getPartition(tableName: TableIdentifier, spec: TablePartitionSpec): CatalogTablePartition = { - val db = tableName.database.getOrElse(currentDb) + val db = formatDatabaseName(tableName.database.getOrElse(getCurrentDatabase)) val table = formatTableName(tableName.table) + requireDbExists(db) + requireTableExists(TableIdentifier(table, Option(db))) + requireExactMatchedPartitionSpec(Seq(spec), getTableMetadata(tableName)) externalCatalog.getPartition(db, table, spec) } @@ -505,11 +636,49 @@ class SessionCatalog( def listPartitions( tableName: TableIdentifier, partialSpec: Option[TablePartitionSpec] = None): Seq[CatalogTablePartition] = { - val db = tableName.database.getOrElse(currentDb) + val db = formatDatabaseName(tableName.database.getOrElse(getCurrentDatabase)) val table = formatTableName(tableName.table) + requireDbExists(db) + requireTableExists(TableIdentifier(table, Option(db))) externalCatalog.listPartitions(db, table, partialSpec) } + /** + * Verify if the input partition spec exactly matches the existing defined partition spec + * The columns must be the same but the orders could be different. + */ + private def requireExactMatchedPartitionSpec( + specs: Seq[TablePartitionSpec], + table: CatalogTable): Unit = { + val defined = table.partitionColumnNames.sorted + specs.foreach { s => + if (s.keys.toSeq.sorted != defined) { + throw new AnalysisException( + s"Partition spec is invalid. The spec (${s.keys.mkString(", ")}) must match " + + s"the partition spec (${table.partitionColumnNames.mkString(", ")}) defined in " + + s"table '${table.identifier}'") + } + } + } + + /** + * Verify if the input partition spec partially matches the existing defined partition spec + * That is, the columns of partition spec should be part of the defined partition spec. + */ + private def requirePartialMatchedPartitionSpec( + specs: Seq[TablePartitionSpec], + table: CatalogTable): Unit = { + val defined = table.partitionColumnNames + specs.foreach { s => + if (!s.keys.forall(defined.contains)) { + throw new AnalysisException( + s"Partition spec is invalid. The spec (${s.keys.mkString(", ")}) must be contained " + + s"within the partition spec (${table.partitionColumnNames.mkString(", ")}) defined " + + s"in table '${table.identifier}'") + } + } + } + // ---------------------------------------------------------------------------- // Functions // ---------------------------------------------------------------------------- @@ -528,13 +697,14 @@ class SessionCatalog( * If no such database is specified, create it in the current database. 
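The two partition-spec validation helpers introduced above encode slightly different rules; a short sketch with made-up partition columns shows which specs each one accepts (a `TablePartitionSpec` is a map from partition column name to value):

```scala
// Illustrative only: `year`/`month` stand in for a table partitioned by (year, month).
val definedPartCols = Seq("year", "month")

def exactMatch(spec: Map[String, String]): Boolean =
  spec.keys.toSeq.sorted == definedPartCols.sorted         // same columns, order-insensitive

def partialMatch(spec: Map[String, String]): Boolean =
  spec.keys.forall(definedPartCols.contains)               // subset of the defined columns

exactMatch(Map("month" -> "2", "year" -> "2016"))   // true:  accepted by create/rename/alter/get
exactMatch(Map("year" -> "2016"))                   // false: rejected with an AnalysisException
partialMatch(Map("year" -> "2016"))                 // true:  accepted by dropPartitions
```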
*/ def createFunction(funcDefinition: CatalogFunction, ignoreIfExists: Boolean): Unit = { - val db = funcDefinition.identifier.database.getOrElse(currentDb) + val db = formatDatabaseName(funcDefinition.identifier.database.getOrElse(getCurrentDatabase)) + requireDbExists(db) val identifier = FunctionIdentifier(funcDefinition.identifier.funcName, Some(db)) val newFuncDefinition = funcDefinition.copy(identifier = identifier) if (!functionExists(identifier)) { externalCatalog.createFunction(db, newFuncDefinition) } else if (!ignoreIfExists) { - throw new AnalysisException(s"function '$identifier' already exists in database '$db'") + throw new FunctionAlreadyExistsException(db = db, func = identifier.toString) } } @@ -543,7 +713,8 @@ class SessionCatalog( * If no database is specified, assume the function is in the current database. */ def dropFunction(name: FunctionIdentifier, ignoreIfNotExists: Boolean): Unit = { - val db = name.database.getOrElse(currentDb) + val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase)) + requireDbExists(db) val identifier = name.copy(database = Some(db)) if (functionExists(identifier)) { // TODO: registry should just take in FunctionIdentifier for type safety @@ -556,7 +727,7 @@ class SessionCatalog( } externalCatalog.dropFunction(db, name.funcName) } else if (!ignoreIfNotExists) { - throw new AnalysisException(s"function '$identifier' does not exist in database '$db'") + throw new NoSuchFunctionException(db = db, func = identifier.toString) } } @@ -567,7 +738,8 @@ class SessionCatalog( * If no database is specified, this will return the function in the current database. */ def getFunctionMetadata(name: FunctionIdentifier): CatalogFunction = { - val db = name.database.getOrElse(currentDb) + val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase)) + requireDbExists(db) externalCatalog.getFunction(db, name.funcName) } @@ -575,7 +747,8 @@ class SessionCatalog( * Check if the specified function exists. */ def functionExists(name: FunctionIdentifier): Boolean = { - val db = name.database.getOrElse(currentDb) + val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase)) + requireDbExists(db) functionRegistry.functionExists(name.unquotedString) || externalCatalog.functionExists(db, name.funcName) } @@ -589,7 +762,7 @@ class SessionCatalog( * * This performs reflection to decide what type of [[Expression]] to return in the builder. */ - private[sql] def makeFunctionBuilder(name: String, functionClassName: String): FunctionBuilder = { + def makeFunctionBuilder(name: String, functionClassName: String): FunctionBuilder = { // TODO: at least support UDAFs here throw new UnsupportedOperationException("Use sqlContext.udf.register(...) instead.") } @@ -598,12 +771,8 @@ class SessionCatalog( * Loads resources such as JARs and Files for a function. Every resource is represented * by a tuple (resource type, resource uri). 
*/ - def loadFunctionResources(resources: Seq[(String, String)]): Unit = { - resources.foreach { case (resourceType, uri) => - val functionResource = - FunctionResource(FunctionResourceType.fromString(resourceType.toLowerCase), uri) - functionResourceLoader.loadResource(functionResource) - } + def loadFunctionResources(resources: Seq[FunctionResource]): Unit = { + resources.foreach(functionResourceLoader.loadResource) } /** @@ -616,7 +785,7 @@ class SessionCatalog( funcDefinition: FunctionBuilder, ignoreIfExists: Boolean): Unit = { if (functionRegistry.lookupFunctionBuilder(name).isDefined && !ignoreIfExists) { - throw new AnalysisException(s"Temporary function '$name' already exists.") + throw new TempFunctionAlreadyExistsException(name) } functionRegistry.registerFunction(name, info, funcDefinition) } @@ -626,27 +795,26 @@ class SessionCatalog( */ def dropTempFunction(name: String, ignoreIfNotExists: Boolean): Unit = { if (!functionRegistry.dropFunction(name) && !ignoreIfNotExists) { - throw new AnalysisException( - s"Temporary function '$name' cannot be dropped because it does not exist!") + throw new NoSuchTempFunctionException(name) } } protected def failFunctionLookup(name: String): Nothing = { - throw new AnalysisException(s"Undefined function: $name. This function is " + - s"neither a registered temporary function nor " + - s"a permanent function registered in the database $currentDb.") + throw new NoSuchFunctionException(db = currentDb, func = name) } /** * Look up the [[ExpressionInfo]] associated with the specified function, assuming it exists. */ - private[spark] def lookupFunctionInfo(name: FunctionIdentifier): ExpressionInfo = { + def lookupFunctionInfo(name: FunctionIdentifier): ExpressionInfo = synchronized { // TODO: just make function registry take in FunctionIdentifier instead of duplicating this - val qualifiedName = name.copy(database = name.database.orElse(Some(currentDb))) + val database = name.database.orElse(Some(currentDb)).map(formatDatabaseName) + val qualifiedName = name.copy(database = database) functionRegistry.lookupFunction(name.funcName) .orElse(functionRegistry.lookupFunction(qualifiedName.unquotedString)) .getOrElse { val db = qualifiedName.database.get + requireDbExists(db) if (externalCatalog.functionExists(db, name.funcName)) { val metadata = externalCatalog.getFunction(db, name.funcName) new ExpressionInfo(metadata.className, qualifiedName.unquotedString) @@ -669,7 +837,9 @@ class SessionCatalog( * based on the function class and put the builder into the FunctionRegistry. * The name of this function in the FunctionRegistry will be `databaseName.functionName`. */ - def lookupFunction(name: FunctionIdentifier, children: Seq[Expression]): Expression = { + def lookupFunction( + name: FunctionIdentifier, + children: Seq[Expression]): Expression = synchronized { // Note: the implementation of this function is a little bit convoluted. // We probably shouldn't use a single FunctionRegistry to register all three kinds of functions // (built-in, temp, and external). @@ -679,7 +849,8 @@ class SessionCatalog( } // If the name itself is not qualified, add the current database to it. - val qualifiedName = if (name.database.isEmpty) name.copy(database = Some(currentDb)) else name + val database = name.database.orElse(Some(currentDb)).map(formatDatabaseName) + val qualifiedName = name.copy(database = database) if (functionRegistry.functionExists(qualifiedName.unquotedString)) { // This function has been already loaded into the function registry. 
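For reference, a hedged sketch of calling the reworked `loadFunctionResources` (the paths are invented, and `catalog` is assumed to be a `SessionCatalog` constructed elsewhere): callers now pass `FunctionResource` values directly instead of `(type, uri)` string pairs.

```scala
import org.apache.spark.sql.catalyst.catalog.{FileResource, FunctionResource, JarResource}

val resources = Seq(
  FunctionResource(JarResource, "/tmp/my_udfs.jar"),       // illustrative paths
  FunctionResource(FileResource, "/tmp/lookup_table.txt"))

// catalog.loadFunctionResources(resources)   // `catalog`: an assumed SessionCatalog instance
```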
@@ -695,7 +866,7 @@ class SessionCatalog( externalCatalog.getFunction(currentDb, name.funcName) } catch { case e: AnalysisException => failFunctionLookup(name.funcName) - case e: NoSuchFunctionException => failFunctionLookup(name.funcName) + case e: NoSuchPermanentFunctionException => failFunctionLookup(name.funcName) } loadFunctionResources(catalogFunction.resources) // Please note that qualifiedName is provided by the user. However, @@ -711,19 +882,29 @@ class SessionCatalog( } /** - * List all functions in the specified database, including temporary functions. + * List all functions in the specified database, including temporary functions. This + * returns the function identifier and the scope in which it was defined (system or user + * defined). */ - def listFunctions(db: String): Seq[FunctionIdentifier] = listFunctions(db, "*") + def listFunctions(db: String): Seq[(FunctionIdentifier, String)] = listFunctions(db, "*") /** - * List all matching functions in the specified database, including temporary functions. + * List all matching functions in the specified database, including temporary functions. This + * returns the function identifier and the scope in which it was defined (system or user + * defined). */ - def listFunctions(db: String, pattern: String): Seq[FunctionIdentifier] = { - val dbFunctions = - externalCatalog.listFunctions(db, pattern).map { f => FunctionIdentifier(f, Some(db)) } + def listFunctions(db: String, pattern: String): Seq[(FunctionIdentifier, String)] = { + val dbName = formatDatabaseName(db) + requireDbExists(dbName) + val dbFunctions = externalCatalog.listFunctions(dbName, pattern) + .map { f => FunctionIdentifier(f, Some(dbName)) } val loadedFunctions = StringUtils.filterPattern(functionRegistry.listFunction(), pattern) .map { f => FunctionIdentifier(f) } - dbFunctions ++ loadedFunctions + val functions = dbFunctions ++ loadedFunctions + functions.map { + case f if FunctionRegistry.functionSet.contains(f.funcName) => (f, "SYSTEM") + case f => (f, "USER") + } } @@ -737,15 +918,15 @@ class SessionCatalog( * * This is mainly used for tests. 
*/ - private[sql] def reset(): Unit = { - val default = "default" - listDatabases().filter(_ != default).foreach { db => + def reset(): Unit = synchronized { + setCurrentDatabase(DEFAULT_DATABASE) + listDatabases().filter(_ != DEFAULT_DATABASE).foreach { db => dropDatabase(db, ignoreIfNotExists = false, cascade = true) } - listTables(default).foreach { table => + listTables(DEFAULT_DATABASE).foreach { table => dropTable(table, ignoreIfNotExists = false) } - listFunctions(default).foreach { func => + listFunctions(DEFAULT_DATABASE).map(_._1).foreach { func => if (func.database.isDefined) { dropFunction(func, ignoreIfNotExists = false) } else { @@ -762,7 +943,6 @@ class SessionCatalog( require(functionBuilder.isDefined, s"built-in function '$f' is missing function builder") functionRegistry.registerFunction(f, expressionInfo.get, functionBuilder.get) } - setCurrentDatabase(default) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/functionResources.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/functionResources.scala index 5adcc892cf682..8e46b962ff432 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/functionResources.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/functionResources.scala @@ -19,17 +19,17 @@ package org.apache.spark.sql.catalyst.catalog import org.apache.spark.sql.AnalysisException -/** An trait that represents the type of a resourced needed by a function. */ -sealed trait FunctionResourceType +/** A trait that represents the type of a resourced needed by a function. */ +abstract class FunctionResourceType(val resourceType: String) -object JarResource extends FunctionResourceType +object JarResource extends FunctionResourceType("jar") -object FileResource extends FunctionResourceType +object FileResource extends FunctionResourceType("file") -// We do not allow users to specify a archive because it is YARN specific. +// We do not allow users to specify an archive because it is YARN specific. // When loading resources, we will throw an exception and ask users to // use --archive with spark submit. -object ArchiveResource extends FunctionResourceType +object ArchiveResource extends FunctionResourceType("archive") object FunctionResourceType { def fromString(resourceType: String): FunctionResourceType = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 3851e4c706742..83428924b1600 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.catalog +import java.util.Date import javax.annotation.Nullable import org.apache.spark.sql.AnalysisException @@ -33,11 +34,10 @@ import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan} * @param className fully qualified class name, e.g. "org.apache.spark.util.MyFunc" * @param resources resource types and Uris used by the function */ -// TODO: Use FunctionResource instead of (String, String) as the element type of resources. 
case class CatalogFunction( identifier: FunctionIdentifier, className: String, - resources: Seq[(String, String)]) + resources: Seq[FunctionResource]) /** @@ -48,8 +48,33 @@ case class CatalogStorageFormat( inputFormat: Option[String], outputFormat: Option[String], serde: Option[String], - serdeProperties: Map[String, String]) + compressed: Boolean, + serdeProperties: Map[String, String]) { + + override def toString: String = { + val serdePropsToString = + if (serdeProperties.nonEmpty) { + s"Properties: " + serdeProperties.map(p => p._1 + "=" + p._2).mkString("[", ", ", "]") + } else { + "" + } + val output = + Seq(locationUri.map("Location: " + _).getOrElse(""), + inputFormat.map("InputFormat: " + _).getOrElse(""), + outputFormat.map("OutputFormat: " + _).getOrElse(""), + if (compressed) "Compressed" else "", + serde.map("Serde: " + _).getOrElse(""), + serdePropsToString) + output.filter(_.nonEmpty).mkString("Storage(", ", ", ")") + } + +} +object CatalogStorageFormat { + /** Empty storage format for default values and copies. */ + val empty = CatalogStorageFormat(locationUri = None, inputFormat = None, + outputFormat = None, serde = None, compressed = false, serdeProperties = Map.empty) +} /** * A column in a table. @@ -60,18 +85,40 @@ case class CatalogColumn( // as a string due to issues in converting Hive varchars to and from SparkSQL strings. @Nullable dataType: String, nullable: Boolean = true, - comment: Option[String] = None) + comment: Option[String] = None) { + + override def toString: String = { + val output = + Seq(s"`$name`", + dataType, + if (!nullable) "NOT NULL" else "", + comment.map("(" + _ + ")").getOrElse("")) + output.filter(_.nonEmpty).mkString(" ") + } +} /** * A partition (Hive style) defined in the catalog. * * @param spec partition spec values indexed by column name * @param storage storage format of the partition + * @param parameters some parameters for the partition, for example, stats. */ case class CatalogTablePartition( spec: CatalogTypes.TablePartitionSpec, - storage: CatalogStorageFormat) + storage: CatalogStorageFormat, + parameters: Map[String, String] = Map.empty) { + + override def toString: String = { + val output = + Seq( + s"Partition Values: [${spec.values.mkString(", ")}]", + s"$storage", + s"Partition Parameters:{${parameters.map(p => p._1 + "=" + p._2).mkString(", ")}}") + output.filter(_.nonEmpty).mkString("CatalogPartition(\n\t", "\n\t", ")") + } +} /** @@ -79,6 +126,9 @@ case class CatalogTablePartition( * * Note that Hive's metastore also tracks skewed columns. We should consider adding that in the * future once we have a better understanding of how we want to handle skewed columns. + * + * @param unsupportedFeatures is a list of string descriptions of features that are used by the + * underlying table but not supported by Spark SQL yet. 
*/ case class CatalogTable( identifier: TableIdentifier, @@ -89,12 +139,14 @@ case class CatalogTable( sortColumnNames: Seq[String] = Seq.empty, bucketColumnNames: Seq[String] = Seq.empty, numBuckets: Int = -1, + owner: String = "", createTime: Long = System.currentTimeMillis, lastAccessTime: Long = -1, properties: Map[String, String] = Map.empty, viewOriginalText: Option[String] = None, viewText: Option[String] = None, - comment: Option[String] = None) { + comment: Option[String] = None, + unsupportedFeatures: Seq[String] = Seq.empty) { // Verify that the provided columns are part of the schema private val colNames = schema.map(_.name).toSet @@ -123,10 +175,37 @@ case class CatalogTable( locationUri: Option[String] = storage.locationUri, inputFormat: Option[String] = storage.inputFormat, outputFormat: Option[String] = storage.outputFormat, + compressed: Boolean = false, serde: Option[String] = storage.serde, serdeProperties: Map[String, String] = storage.serdeProperties): CatalogTable = { copy(storage = CatalogStorageFormat( - locationUri, inputFormat, outputFormat, serde, serdeProperties)) + locationUri, inputFormat, outputFormat, serde, compressed, serdeProperties)) + } + + override def toString: String = { + val tableProperties = properties.map(p => p._1 + "=" + p._2).mkString("[", ", ", "]") + val partitionColumns = partitionColumnNames.map("`" + _ + "`").mkString("[", ", ", "]") + val sortColumns = sortColumnNames.map("`" + _ + "`").mkString("[", ", ", "]") + val bucketColumns = bucketColumnNames.map("`" + _ + "`").mkString("[", ", ", "]") + + val output = + Seq(s"Table: ${identifier.quotedString}", + if (owner.nonEmpty) s"Owner: $owner" else "", + s"Created: ${new Date(createTime).toString}", + s"Last Access: ${new Date(lastAccessTime).toString}", + s"Type: ${tableType.name}", + if (schema.nonEmpty) s"Schema: ${schema.mkString("[", ", ", "]")}" else "", + if (partitionColumnNames.nonEmpty) s"Partition Columns: $partitionColumns" else "", + if (numBuckets != -1) s"Num Buckets: $numBuckets" else "", + if (bucketColumnNames.nonEmpty) s"Bucket Columns: $bucketColumns" else "", + if (sortColumnNames.nonEmpty) s"Sort Columns: $sortColumns" else "", + viewOriginalText.map("Original View: " + _).getOrElse(""), + viewText.map("View: " + _).getOrElse(""), + comment.map("Comment: " + _).getOrElse(""), + if (properties.nonEmpty) s"Properties: $tableProperties" else "", + s"$storage") + + output.filter(_.nonEmpty).mkString("CatalogTable(\n\t", "\n\t", ")") } } @@ -136,7 +215,6 @@ case class CatalogTableType private(name: String) object CatalogTableType { val EXTERNAL = new CatalogTableType("EXTERNAL") val MANAGED = new CatalogTableType("MANAGED") - val INDEX = new CatalogTableType("INDEX") val VIEW = new CatalogTableType("VIEW") } @@ -183,6 +261,8 @@ case class SimpleCatalogRelation( override def catalogTable: CatalogTable = metadata + override lazy val resolved: Boolean = false + override val output: Seq[Attribute] = { val cols = catalogTable.schema .filter { c => !catalogTable.partitionColumnNames.contains(c.name) } @@ -196,6 +276,7 @@ case class SimpleCatalogRelation( } } - require(metadata.identifier.database == Some(databaseName), + require( + metadata.identifier.database == Some(databaseName), "provided database does not match the one specified in the table definition") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index b5d10e4a584f3..2ca990d19a2cb 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.expressions.objects.Invoke import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 5d294485afd79..1fac26c4388a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -17,19 +17,18 @@ package org.apache.spark.sql.catalyst.encoders -import java.util.concurrent.ConcurrentMap - import scala.reflect.ClassTag import scala.reflect.runtime.universe.{typeTag, TypeTag} -import org.apache.spark.sql.{AnalysisException, Encoder} +import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.{InternalRow, JavaTypeInference, ScalaReflection} -import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedAttribute, UnresolvedDeserializer, UnresolvedExtractValue} +import org.apache.spark.sql.catalyst.analysis.{Analyzer, GetColumnByOrdinal, SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, Invoke, NewInstance} import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} -import org.apache.spark.sql.types.{ObjectType, StructField, StructType} +import org.apache.spark.sql.catalyst.plans.logical.{CatalystSerde, DeserializeToObject, LocalRelation} +import org.apache.spark.sql.types.{BooleanType, ObjectType, StructField, StructType} import org.apache.spark.util.Utils /** @@ -51,8 +50,15 @@ object ExpressionEncoder { val cls = mirror.runtimeClass(tpe) val flat = !ScalaReflection.definedByConstructorParams(tpe) - val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = false) - val serializer = ScalaReflection.serializerFor[T](inputObject) + val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = true) + val nullSafeInput = if (flat) { + inputObject + } else { + // For input object of non-flat type, we can't encode it to row if it's null, as Spark SQL + // doesn't allow top-level row to be null, only its columns can be null. 
+ AssertNotNull(inputObject, Seq("top level non-flat input object")) + } + val serializer = ScalaReflection.serializerFor[T](nullSafeInput) val deserializer = ScalaReflection.deserializerFor[T] val schema = ScalaReflection.schemaFor[T] match { @@ -104,32 +110,51 @@ object ExpressionEncoder { val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}") - val serializer = encoders.map { - case e if e.flat => e.serializer.head - case other => CreateStruct(other.serializer) - }.zipWithIndex.map { case (expr, index) => - expr.transformUp { - case BoundReference(0, t, _) => - Invoke( - BoundReference(0, ObjectType(cls), nullable = true), - s"_${index + 1}", - t) + val serializer = encoders.zipWithIndex.map { case (enc, index) => + val originalInputObject = enc.serializer.head.collect { case b: BoundReference => b }.head + val newInputObject = Invoke( + BoundReference(0, ObjectType(cls), nullable = true), + s"_${index + 1}", + originalInputObject.dataType) + + val newSerializer = enc.serializer.map(_.transformUp { + case b: BoundReference if b == originalInputObject => newInputObject + }) + + if (enc.flat) { + newSerializer.head + } else { + // For non-flat encoder, the input object is not top level anymore after being combined to + // a tuple encoder, thus it can be null and we should wrap the `CreateStruct` with `If` and + // null check to handle null case correctly. + // e.g. for Encoder[(Int, String)], the serializer expressions will create 2 columns, and are + // not able to handle the case when the input tuple is null. This is not a problem as there + // is a check to make sure the input object won't be null. However, if this encoder is used + // to create a bigger tuple encoder, the original input object becomes a field of the new + // input tuple and can be null. So instead of creating a struct directly here, we should add + // a null/None check and return a null struct if the null/None check fails. + val struct = CreateStruct(newSerializer) + val nullCheck = Or( + IsNull(newInputObject), + Invoke(Literal.fromObject(None), "equals", BooleanType, newInputObject :: Nil)) + If(nullCheck, Literal.create(null, struct.dataType), struct) } } val childrenDeserializers = encoders.zipWithIndex.map { case (enc, index) => if (enc.flat) { enc.deserializer.transform { - case b: BoundReference => b.copy(ordinal = index) + case g: GetColumnByOrdinal => g.copy(ordinal = index) } } else { - val input = BoundReference(index, enc.schema, nullable = true) - enc.deserializer.transformUp { + val input = GetColumnByOrdinal(index, enc.schema) + val deserialized = enc.deserializer.transformUp { case UnresolvedAttribute(nameParts) => assert(nameParts.length == 1) UnresolvedExtractValue(input, Literal(nameParts.head)) - case BoundReference(ordinal, dt, _) => GetStructField(input, ordinal) + case GetColumnByOrdinal(ordinal, _) => GetStructField(input, ordinal) } + If(IsNull(input), Literal.create(null, deserialized.dataType), deserialized) } } @@ -190,6 +215,39 @@ case class ExpressionEncoder[T]( if (flat) require(serializer.size == 1) + // serializer expressions are used to encode an object to a row, while the object is usually an + // intermediate value produced inside an operator, not from the output of the child operator. This + // is quite different from normal expressions, and `AttributeReference` doesn't work here + // (intermediate value is not an attribute).
We assume that all serializer expressions use the same + // `BoundReference` to refer to the object, and throw an exception if they don't. + assert(serializer.forall(_.references.isEmpty), "serializer cannot reference any attributes.") + assert(serializer.flatMap { ser => + val boundRefs = ser.collect { case b: BoundReference => b } + assert(boundRefs.nonEmpty, + "each serializer expression should contain at least one `BoundReference`") + boundRefs + }.distinct.length <= 1, "all serializer expressions must use the same BoundReference.") + + /** + * Returns a new copy of this encoder, where the `deserializer` is resolved and bound to the + * given schema. + * + * Note that, ideally, the encoder is used as a container of serde expressions, and the + * resolution and binding should happen inside the query framework. However, in some cases we + * need to use the encoder as a function to do serialization directly (e.g. Dataset.collect); + * in that case this method can be used to do resolution and binding outside of the query framework. + */ + def resolveAndBind( + attrs: Seq[Attribute] = schema.toAttributes, + analyzer: Analyzer = SimpleAnalyzer): ExpressionEncoder[T] = { + val dummyPlan = CatalystSerde.deserialize(LocalRelation(attrs))(this) + val analyzedPlan = analyzer.execute(dummyPlan) + analyzer.checkAnalysis(analyzedPlan) + val resolved = SimplifyCasts(analyzedPlan).asInstanceOf[DeserializeToObject].deserializer + val bound = BindReferences.bindReference(resolved, attrs) + copy(deserializer = bound) + } + @transient private lazy val extractProjection = GenerateUnsafeProjection.generate(serializer) @@ -199,16 +257,6 @@ case class ExpressionEncoder[T]( @transient private lazy val constructProjection = GenerateSafeProjection.generate(deserializer :: Nil) - /** - * Returns this encoder where it has been bound to its own output (i.e. no remaping of columns - * is performed). - */ - def defaultBinding: ExpressionEncoder[T] = { - val attrs = schema.toAttributes - resolve(attrs, OuterScopes.outerScopes).bind(attrs) - } - - /** * Returns a new set (with unique ids) of [[NamedExpression]] that represent the serialized form * of this object. @@ -234,7 +282,7 @@ case class ExpressionEncoder[T]( /** * Returns an object of type `T`, extracting the required values from the provided row. Note that - * you must `resolve` and `bind` an encoder to a specific schema before you can call this + * you must `resolveAndBind` an encoder to a specific schema before you can call this * function. */ def fromRow(row: InternalRow): T = try { @@ -257,94 +305,6 @@ case class ExpressionEncoder[T]( }) } - /** - * Validates `deserializer` to make sure it can be resolved by given schema, and produce - * friendly error messages to explain why it fails to resolve if there is something wrong. - */ - def validate(schema: Seq[Attribute]): Unit = { - def fail(st: StructType, maxOrdinal: Int): Unit = { - throw new AnalysisException(s"Try to map ${st.simpleString} to Tuple${maxOrdinal + 1}, " + - "but failed as the number of fields does not line up.\n" + - " - Input schema: " + StructType.fromAttributes(schema).simpleString + "\n" + - " - Target schema: " + this.schema.simpleString) - } - - // If this is a tuple encoder or tupled encoder, which means its leaf nodes are all - // `BoundReference`, make sure their ordinals are all valid.
- var maxOrdinal = -1 - deserializer.foreach { - case b: BoundReference => if (b.ordinal > maxOrdinal) maxOrdinal = b.ordinal - case _ => - } - if (maxOrdinal >= 0 && maxOrdinal != schema.length - 1) { - fail(StructType.fromAttributes(schema), maxOrdinal) - } - - // If we have nested tuple, the `fromRowExpression` will contains `GetStructField` instead of - // `UnresolvedExtractValue`, so we need to check if their ordinals are all valid. - // Note that, `BoundReference` contains the expected type, but here we need the actual type, so - // we unbound it by the given `schema` and propagate the actual type to `GetStructField`, after - // we resolve the `fromRowExpression`. - val resolved = SimpleAnalyzer.resolveExpression( - deserializer, - LocalRelation(schema), - throws = true) - - val unbound = resolved transform { - case b: BoundReference => schema(b.ordinal) - } - - val exprToMaxOrdinal = scala.collection.mutable.HashMap.empty[Expression, Int] - unbound.foreach { - case g: GetStructField => - val maxOrdinal = exprToMaxOrdinal.getOrElse(g.child, -1) - if (maxOrdinal < g.ordinal) { - exprToMaxOrdinal.update(g.child, g.ordinal) - } - case _ => - } - exprToMaxOrdinal.foreach { - case (expr, maxOrdinal) => - val schema = expr.dataType.asInstanceOf[StructType] - if (maxOrdinal != schema.length - 1) { - fail(schema, maxOrdinal) - } - } - } - - /** - * Returns a new copy of this encoder, where the `deserializer` is resolved to the given schema. - */ - def resolve( - schema: Seq[Attribute], - outerScopes: ConcurrentMap[String, AnyRef]): ExpressionEncoder[T] = { - // Make a fake plan to wrap the deserializer, so that we can go though the whole analyzer, check - // analysis, go through optimizer, etc. - val plan = Project( - Alias(UnresolvedDeserializer(deserializer, schema), "")() :: Nil, - LocalRelation(schema)) - val analyzedPlan = SimpleAnalyzer.execute(plan) - SimpleAnalyzer.checkAnalysis(analyzedPlan) - copy(deserializer = SimplifyCasts(analyzedPlan).expressions.head.children.head) - } - - /** - * Returns a copy of this encoder where the `deserializer` has been bound to the - * ordinals of the given schema. Note that you need to first call resolve before bind. 
- */ - def bind(schema: Seq[Attribute]): ExpressionEncoder[T] = { - copy(deserializer = BindReferences.bindReference(deserializer, schema)) - } - - /** - * Returns a new encoder with input columns shifted by `delta` ordinals - */ - def shift(delta: Int): ExpressionEncoder[T] = { - copy(deserializer = deserializer transform { - case r: BoundReference => r.copy(ordinal = r.ordinal + delta) - }) - } - protected val attrs = serializer.flatMap(_.collect { case _: UnresolvedAttribute => "" case a: Attribute => s"#${a.exprId}" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 44e135cbf8354..2a6fcd03a26b0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -25,24 +25,45 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData} import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal +import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String /** * A factory for constructing encoders that convert external row to/from the Spark SQL * internal binary representation. + * + * The following is a mapping between Spark SQL types and its allowed external types: + * {{{ + * BooleanType -> java.lang.Boolean + * ByteType -> java.lang.Byte + * ShortType -> java.lang.Short + * IntegerType -> java.lang.Integer + * FloatType -> java.lang.Float + * DoubleType -> java.lang.Double + * StringType -> String + * DecimalType -> java.math.BigDecimal or scala.math.BigDecimal or Decimal + * + * DateType -> java.sql.Date + * TimestampType -> java.sql.Timestamp + * + * BinaryType -> byte array + * ArrayType -> scala.collection.Seq or Array + * MapType -> scala.collection.Map + * StructType -> org.apache.spark.sql.Row + * }}} */ object RowEncoder { def apply(schema: StructType): ExpressionEncoder[Row] = { val cls = classOf[Row] val inputObject = BoundReference(0, ObjectType(cls), nullable = true) - // We use an If expression to wrap extractorsFor result of StructType - val serializer = serializerFor(inputObject, schema).asInstanceOf[If].falseValue + val serializer = serializerFor(AssertNotNull(inputObject, Seq("top level row object")), schema) val deserializer = deserializerFor(schema) new ExpressionEncoder[Row]( schema, flat = false, - serializer.asInstanceOf[CreateStruct].children, + serializer.asInstanceOf[CreateNamedStruct].flatten, deserializer, ClassTag(cls)) } @@ -50,8 +71,7 @@ object RowEncoder { private def serializerFor( inputObject: Expression, inputType: DataType): Expression = inputType match { - case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType | - FloatType | DoubleType | BinaryType | CalendarIntervalType => inputObject + case dt if ScalaReflection.isNativeType(dt) => inputObject case p: PythonUserDefinedType => serializerFor(inputObject, p.sqlType) @@ -69,7 +89,7 @@ object RowEncoder { udtClass, Nil, dataType = ObjectType(udtClass), false) - Invoke(obj, "serialize", udt.sqlType, inputObject :: Nil) + Invoke(obj, "serialize", udt, inputObject :: Nil) case TimestampType => StaticInvoke( @@ -85,10 +105,10 @@ object RowEncoder { "fromJavaDate", 
inputObject :: Nil) - case _: DecimalType => + case d: DecimalType => StaticInvoke( Decimal.getClass, - DecimalType.SYSTEM_DEFAULT, + d, "fromDecimal", inputObject :: Nil) @@ -101,11 +121,15 @@ object RowEncoder { case t @ ArrayType(et, _) => et match { case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => + // TODO: validate input type for primitive array. NewInstance( classOf[GenericArrayData], inputObject :: Nil, dataType = t) - case _ => MapObjects(serializerFor(_, et), inputObject, externalDataTypeForInput(et)) + case _ => MapObjects( + element => serializerFor(ValidateExternalType(element, et), et), + inputObject, + ObjectType(classOf[Object])) } case t @ MapType(kt, vt, valueNullable) => @@ -129,22 +153,31 @@ object RowEncoder { dataType = t) case StructType(fields) => - val convertedFields = fields.zipWithIndex.map { case (f, i) => - val method = if (f.dataType.isInstanceOf[StructType]) { - "getStruct" + val nonNullOutput = CreateNamedStruct(fields.zipWithIndex.flatMap { case (field, index) => + val fieldValue = serializerFor( + ValidateExternalType( + GetExternalRowField(inputObject, index, field.name), + field.dataType), + field.dataType) + val convertedField = if (field.nullable) { + If( + Invoke(inputObject, "isNullAt", BooleanType, Literal(index) :: Nil), + Literal.create(null, field.dataType), + fieldValue + ) } else { - "get" + fieldValue } - If( - Invoke(inputObject, "isNullAt", BooleanType, Literal(i) :: Nil), - Literal.create(null, f.dataType), - serializerFor( - Invoke(inputObject, method, externalDataTypeForInput(f.dataType), Literal(i) :: Nil), - f.dataType)) + Literal(field.name) :: convertedField :: Nil + }) + + if (inputObject.nullable) { + If(IsNull(inputObject), + Literal.create(null, inputType), + nonNullOutput) + } else { + nonNullOutput } - If(IsNull(inputObject), - Literal.create(null, inputType), - CreateStruct(convertedFields)) } /** @@ -155,16 +188,17 @@ object RowEncoder { * can be `scala.math.BigDecimal`, `java.math.BigDecimal`, or * `org.apache.spark.sql.types.Decimal`. */ - private def externalDataTypeForInput(dt: DataType): DataType = dt match { - // In order to support both Decimal and java BigDecimal in external row, we make this + def externalDataTypeForInput(dt: DataType): DataType = dt match { + // In order to support both Decimal and java/scala BigDecimal in external row, we make this // as java.lang.Object. case _: DecimalType => ObjectType(classOf[java.lang.Object]) + // In order to support both Array and Seq in external row, we make this as java.lang.Object. 
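A hedged usage sketch of the documented type mapping (the schema and values are invented; `resolveAndBind` is the helper added in the `ExpressionEncoder` changes above): an external `Row` built from the allowed Java/Scala types round-trips through the internal binary format.

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types._

val schema = new StructType()
  .add("id", LongType)
  .add("tags", ArrayType(StringType))      // a Seq or an Array is accepted externally

val encoder = RowEncoder(schema).resolveAndBind()
val internal = encoder.toRow(Row(1L, Seq("a", "b")))   // external Row -> InternalRow
val roundTripped = encoder.fromRow(internal)           // InternalRow -> external Row
```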
+ case _: ArrayType => ObjectType(classOf[java.lang.Object]) case _ => externalDataTypeFor(dt) } - private def externalDataTypeFor(dt: DataType): DataType = dt match { + def externalDataTypeFor(dt: DataType): DataType = dt match { case _ if ScalaReflection.isNativeType(dt) => dt - case CalendarIntervalType => dt case TimestampType => ObjectType(classOf[java.sql.Timestamp]) case DateType => ObjectType(classOf[java.sql.Date]) case _: DecimalType => ObjectType(classOf[java.math.BigDecimal]) @@ -172,8 +206,8 @@ object RowEncoder { case _: ArrayType => ObjectType(classOf[scala.collection.Seq[_]]) case _: MapType => ObjectType(classOf[scala.collection.Map[_, _]]) case _: StructType => ObjectType(classOf[Row]) + case p: PythonUserDefinedType => externalDataTypeFor(p.sqlType) case udt: UserDefinedType[_] => ObjectType(udt.userClass) - case _: NullType => ObjectType(classOf[java.lang.Object]) } private def deserializerFor(schema: StructType): Expression = { @@ -182,19 +216,19 @@ object RowEncoder { case p: PythonUserDefinedType => p.sqlType case other => other } - val field = BoundReference(i, dt, f.nullable) - If( - IsNull(field), - Literal.create(null, externalDataTypeFor(dt)), - deserializerFor(field) - ) + deserializerFor(GetColumnByOrdinal(i, dt)) } CreateExternalRow(fields, schema) } - private def deserializerFor(input: Expression): Expression = input.dataType match { - case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType | - FloatType | DoubleType | BinaryType | CalendarIntervalType => input + private def deserializerFor(input: Expression): Expression = { + deserializerFor(input, input.dataType) + } + + private def deserializerFor(input: Expression, dataType: DataType): Expression = dataType match { + case dt if ScalaReflection.isNativeType(dt) => input + + case p: PythonUserDefinedType => deserializerFor(input, p.sqlType) case udt: UserDefinedType[_] => val annotation = udt.userClass.getAnnotation(classOf[SQLUserDefinedType]) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala index 03708fb7afd44..59f7969e56144 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala @@ -26,7 +26,7 @@ package object encoders { * references from a specific schema.) This requirement allows us to preserve whether a given * object type is being bound by name or by ordinal when doing resolution. */ - private[sql] def encoderFor[A : Encoder]: ExpressionEncoder[A] = implicitly[Encoder[A]] match { + def encoderFor[A : Encoder]: ExpressionEncoder[A] = implicitly[Encoder[A]] match { case e: ExpressionEncoder[A] => e.assertUnresolved() e diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala index 0420b4b5387c0..0d45f371fa0cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala @@ -17,7 +17,10 @@ package org.apache.spark.sql.catalyst +import scala.util.control.NonFatal + import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.SparkException /** * Functions for attaching and retrieving trees that are associated with errors. 
@@ -47,7 +50,10 @@ package object errors { */ def attachTree[TreeType <: TreeNode[_], A](tree: TreeType, msg: String = "")(f: => A): A = { try f catch { - case e: Exception => throw new TreeNodeException(tree, msg, e) + // SPARK-16748: We do not want SparkExceptions from job failures in the planning phase + // to create TreeNodeException. Hence, wrap exception only if it is not SparkException. + case NonFatal(e) if !e.isInstanceOf[SparkException] => + throw new TreeNodeException(tree, msg, e) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala index ef3cc554b79c0..96a11e352ec50 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala @@ -26,13 +26,6 @@ object AttributeMap { def apply[A](kvs: Seq[(Attribute, A)]): AttributeMap[A] = { new AttributeMap(kvs.map(kv => (kv._1.exprId, kv)).toMap) } - - /** Given a schema, constructs an [[AttributeMap]] from [[Attribute]] to ordinal */ - def byIndex(schema: Seq[Attribute]): AttributeMap[Int] = apply(schema.zipWithIndex) - - /** Given a schema, constructs a map from ordinal to Attribute. */ - def toIndex(schema: Seq[Attribute]): Map[Int, Attribute] = - schema.zipWithIndex.map { case (a, i) => i -> a }.toMap } class AttributeMap[A](baseMap: Map[ExprId, (Attribute, A)]) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 99f156a935b50..7d16118c9d59f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.types._ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) extends LeafExpression { - override def toString: String = s"input[$ordinal, ${dataType.simpleString}]" + override def toString: String = s"input[$ordinal, ${dataType.simpleString}, $nullable]" // Use special getter for primitive types (for UnsafeRow) override def eval(input: InternalRow): Any = { @@ -82,16 +82,16 @@ object BindReferences extends Logging { def bindReference[A <: Expression]( expression: A, - input: Seq[Attribute], + input: AttributeSeq, allowFailures: Boolean = false): A = { expression.transform { case a: AttributeReference => attachTree(a, "Binding attribute") { - val ordinal = input.indexWhere(_.exprId == a.exprId) + val ordinal = input.indexOf(a.exprId) if (ordinal == -1) { if (allowFailures) { a } else { - sys.error(s"Couldn't find $a in ${input.mkString("[", ",", "]")}") + sys.error(s"Couldn't find $a in ${input.attrs.mkString("[", ",", "]")}") } } else { BoundReference(ordinal, a.dataType, input(ordinal).nullable) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala new file mode 100644 index 0000000000000..fe24c0489fc98 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import java.lang.reflect.{Method, Modifier} + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.Utils + +/** + * An expression that invokes a method on a class via reflection. + * + * For now, only types defined in `Reflect.typeMapping` are supported (basically primitives + * and string) as input types, and the output is turned automatically to a string. + * + * Note that unlike Hive's reflect function, this expression calls only static methods + * (i.e. does not support calling non-static methods). + * + * We should also look into how to consolidate this expression with + * [[org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke]] in the future. + * + * @param children the first element should be a literal string for the class name, + * and the second element should be a literal string for the method name, + * and the remaining are input arguments to the Java method. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(class,method[,arg1[,arg2..]]) calls method with reflection", + extended = "> SELECT _FUNC_('java.util.UUID', 'randomUUID');\n c33fb387-8500-4bfa-81d2-6e0e3e930df2") +// scalastyle:on line.size.limit +case class CallMethodViaReflection(children: Seq[Expression]) + extends Expression with CodegenFallback { + + override def prettyName: String = "reflect" + + override def checkInputDataTypes(): TypeCheckResult = { + if (children.size < 2) { + TypeCheckFailure("requires at least two arguments") + } else if (!children.take(2).forall(e => e.dataType == StringType && e.foldable)) { + // The first two arguments must be string type. + TypeCheckFailure("first two arguments should be string literals") + } else if (!classExists) { + TypeCheckFailure(s"class $className not found") + } else if (method == null) { + TypeCheckFailure(s"cannot find a static method that matches the argument types in $className") + } else { + TypeCheckSuccess + } + } + + override def deterministic: Boolean = false + override def nullable: Boolean = true + override val dataType: DataType = StringType + + override def eval(input: InternalRow): Any = { + var i = 0 + while (i < argExprs.length) { + buffer(i) = argExprs(i).eval(input).asInstanceOf[Object] + // Convert if necessary. Based on the types defined in typeMapping, string is the only + // type that needs conversion. 
If we support timestamps, dates, decimals, arrays, or maps + // in the future, proper conversion needs to happen here too. + if (buffer(i).isInstanceOf[UTF8String]) { + buffer(i) = buffer(i).toString + } + i += 1 + } + val ret = method.invoke(null, buffer : _*) + UTF8String.fromString(String.valueOf(ret)) + } + + @transient private lazy val argExprs: Array[Expression] = children.drop(2).toArray + + /** Name of the class -- this has to be called after we verify children has at least two exprs. */ + @transient private lazy val className = children(0).eval().asInstanceOf[UTF8String].toString + + /** True if the class exists and can be loaded. */ + @transient private lazy val classExists = CallMethodViaReflection.classExists(className) + + /** The reflection method. */ + @transient lazy val method: Method = { + val methodName = children(1).eval(null).asInstanceOf[UTF8String].toString + CallMethodViaReflection.findMethod(className, methodName, argExprs.map(_.dataType)).orNull + } + + /** A temporary buffer used to hold intermediate results returned by children. */ + @transient private lazy val buffer = new Array[Object](argExprs.length) +} + +object CallMethodViaReflection { + /** Mapping from Spark's type to acceptable JVM types. */ + val typeMapping = Map[DataType, Seq[Class[_]]]( + BooleanType -> Seq(classOf[java.lang.Boolean], classOf[Boolean]), + ByteType -> Seq(classOf[java.lang.Byte], classOf[Byte]), + ShortType -> Seq(classOf[java.lang.Short], classOf[Short]), + IntegerType -> Seq(classOf[java.lang.Integer], classOf[Int]), + LongType -> Seq(classOf[java.lang.Long], classOf[Long]), + FloatType -> Seq(classOf[java.lang.Float], classOf[Float]), + DoubleType -> Seq(classOf[java.lang.Double], classOf[Double]), + StringType -> Seq(classOf[String]) + ) + + /** + * Returns true if the class can be found and loaded. + */ + private def classExists(className: String): Boolean = { + try { + Utils.classForName(className) + true + } catch { + case e: ClassNotFoundException => false + } + } + + /** + * Finds a Java static method using reflection that matches the given argument types, + * and whose return type is string. + * + * The types sequence must be the valid types defined in [[typeMapping]]. + * + * This is made public for unit testing. + */ + def findMethod(className: String, methodName: String, argTypes: Seq[DataType]): Option[Method] = { + val clazz: Class[_] = Utils.classForName(className) + clazz.getMethods.find { method => + val candidateTypes = method.getParameterTypes + if (method.getName != methodName) { + // Name must match + false + } else if (!Modifier.isStatic(method.getModifiers)) { + // Method must be static + false + } else if (candidateTypes.length != argTypes.length) { + // Argument length must match + false + } else { + // Argument type must match. That is, either the method's argument type matches one of the + // acceptable types defined in typeMapping, or it is a super type of the acceptable types. 
+ candidateTypes.zip(argTypes).forall { case (candidateType, argType) => + typeMapping(argType).exists(candidateType.isAssignableFrom) + } + } + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index b1e89b5de833f..a53468cfd3011 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -52,7 +52,8 @@ object Cast { case (DateType, TimestampType) => true case (_: NumericType, TimestampType) => true - case (_, DateType) => true + case (StringType, DateType) => true + case (TimestampType, DateType) => true case (StringType, CalendarIntervalType) => true @@ -112,6 +113,9 @@ object Cast { } /** Cast the child expression to the target data type. */ +@ExpressionDescription( + usage = " - Cast value v to the target data type.", + extended = "> SELECT _FUNC_('10' as int);\n 10") case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with NullIntolerant { override def toString: String = s"cast($child as ${dataType.simpleString})" @@ -228,18 +232,12 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // throw valid precision more than seconds, according to Hive. // Timestamp.nanos is in 0 to 999,999,999, no more than a second. buildCast[Long](_, t => DateTimeUtils.millisToDays(t / 1000L)) - // Hive throws this exception as a Semantic Exception - // It is never possible to compare result when hive return with exception, - // so we can return null - // NULL is more reasonable here, since the query itself obeys the grammar. - case _ => _ => null } // IntervalConverter private[this] def castToInterval(from: DataType): Any => Any = from match { case StringType => buildCast[UTF8String](_, s => CalendarInterval.fromString(s.toString)) - case _ => _ => null } // LongConverter @@ -418,7 +416,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w } private[this] def cast(from: DataType, to: DataType): Any => Any = to match { - case dt if dt == child.dataType => identity[Any] + case dt if dt == from => identity[Any] case StringType => castToString(from) case BinaryType => castToBinary(from) case DateType => castToDate(from) @@ -659,7 +657,12 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w private[this] def castToIntervalCode(from: DataType): CastFunction = from match { case StringType => (c, evPrim, evNull) => - s"$evPrim = CalendarInterval.fromString($c.toString());" + s"""$evPrim = CalendarInterval.fromString($c.toString()); + if(${evPrim} == null) { + ${evNull} = true; + } + """.stripMargin + } private[this] def decimalToTimestampCode(d: String): String = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala index d0ad7a05a0c37..b8e2b67b2fe9c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -68,7 +68,10 @@ class EquivalentExpressions { * is found. That is, if `expr` has already been added, its children are not added. * If ignoreLeaf is true, leaf nodes are ignored. 
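   *
   * For example, a typical usage (an illustrative sketch) looks roughly like:
   * {{{
   *   val equivalentExpressions = new EquivalentExpressions
   *   exprs.foreach(equivalentExpressions.addExprTree(_))
   *   // sub-expressions appearing more than once are candidates for reuse
   *   val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1)
   * }}}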
*/ - def addExprTree(root: Expression, ignoreLeaf: Boolean = true): Unit = { + def addExprTree( + root: Expression, + ignoreLeaf: Boolean = true, + skipReferenceToExpressions: Boolean = true): Unit = { val skip = root.isInstanceOf[LeafExpression] && ignoreLeaf // There are some special expressions that we should not recurse into children. // 1. CodegenFallback: it's children will not be used to generate code (call eval() instead) @@ -77,7 +80,7 @@ class EquivalentExpressions { // TODO: some expressions implements `CodegenFallback` but can still do codegen, // e.g. `CaseWhen`, we should support them. case _: CodegenFallback => false - case _: ReferenceToExpressions => false + case _: ReferenceToExpressions if skipReferenceToExpressions => false case _ => true } if (!skip && !addExpr(root) && shouldRecurse) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala index b3dfac806f7fe..98f25a9ad7597 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.types.AbstractDataType /** - * An trait that gets mixin to define the expected input types of an expression. + * A trait that gets mixin to define the expected input types of an expression. * * This trait is typically used by operator expressions (e.g. [[Add]], [[Subtract]]) to define * expected input types without any implicit casting. @@ -57,7 +57,8 @@ trait ExpectsInputTypes extends Expression { /** - * A mixin for the analyzer to perform implicit type casting using [[ImplicitTypeCasts]]. + * A mixin for the analyzer to perform implicit type casting using + * [[org.apache.spark.sql.catalyst.analysis.TypeCoercion.ImplicitTypeCasts]]. */ trait ImplicitCastInputTypes extends ExpectsInputTypes { // No other methods diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index c26faee2f4876..0f6a896a72be5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -21,8 +21,8 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.trees.TreeNode -import org.apache.spark.sql.catalyst.util.toCommentSafeString import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines the basic expression abstract classes in Catalyst. @@ -97,15 +97,14 @@ abstract class Expression extends TreeNode[Expression] { ctx.subExprEliminationExprs.get(this).map { subExprState => // This expression is repeated which means that the code to evaluate it has already been added // as a function before. In that case, we just re-use it. 
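      // For example, if `a + b` appears in several projected columns, only the first occurrence
      // emits evaluation code; later occurrences simply reuse the cached isNull/value terms
      // recorded in the SubExprEliminationState.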
- val code = s"/* ${toCommentSafeString(this.toString)} */" - ExprCode(code, subExprState.isNull, subExprState.value) + ExprCode(ctx.registerComment(this.toString), subExprState.isNull, subExprState.value) }.getOrElse { val isNull = ctx.freshName("isNull") val value = ctx.freshName("value") val ve = doGenCode(ctx, ExprCode("", isNull, value)) if (ve.code.nonEmpty) { // Add `this` in the comment. - ve.copy(s"/* ${toCommentSafeString(this.toString)} */\n" + ve.code.trim) + ve.copy(code = s"${ctx.registerComment(this.toString)}\n" + ve.code.trim) } else { ve } @@ -187,14 +186,19 @@ abstract class Expression extends TreeNode[Expression] { */ def prettyName: String = nodeName.toLowerCase - private def flatArguments = productIterator.flatMap { + protected def flatArguments = productIterator.flatMap { case t: Traversable[_] => t case single => single :: Nil } + // Marks this as final, Expression.verboseString should never be called, and thus shouldn't be + // overridden by concrete classes. + final override def verboseString: String = simpleString + override def simpleString: String = toString - override def toString: String = prettyName + flatArguments.mkString("(", ", ", ")") + override def toString: String = prettyName + Utils.truncatedString( + flatArguments.toSeq, "(", ", ", ")") /** * Returns SQL representation of this expression. For expressions extending [[NonSQLExpression]], @@ -221,6 +225,33 @@ trait Unevaluable extends Expression { } +/** + * An expression that gets replaced at runtime (currently by the optimizer) into a different + * expression for evaluation. This is mainly used to provide compatibility with other databases. + * For example, we use this to support "nvl" by replacing it with "coalesce". + */ +trait RuntimeReplaceable extends Unevaluable { + /** + * Method for concrete implementations to override that specifies how to construct the expression + * that should replace the current one. + */ + def replaceForEvaluation(): Expression + + /** + * Method for concrete implementations to override that specifies how to coerce the input types. + */ + def replaceForTypeCoercion(): Expression + + /** The expression that should be used during evaluation. */ + lazy val replaced: Expression = replaceForEvaluation() + + override def nullable: Boolean = replaced.nullable + override def foldable: Boolean = replaced.foldable + override def dataType: DataType = replaced.dataType + override def checkInputDataTypes(): TypeCheckResult = replaced.checkInputDataTypes() +} + + /** * Expressions that don't have SQL representation should extend this trait. Examples are * `ScalaUDF`, `ScalaUDAF`, and object expressions like `MapObjects` and `Invoke`. @@ -346,6 +377,7 @@ abstract class UnaryExpression extends Expression { """) } else { ev.copy(code = s""" + boolean ${ev.isNull} = false; ${childGen.code} ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; $resultCode""", isNull = "false") @@ -444,6 +476,7 @@ abstract class BinaryExpression extends Expression { """) } else { ev.copy(code = s""" + boolean ${ev.isNull} = false; ${leftGen.code} ${rightGen.code} ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; @@ -478,7 +511,7 @@ abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes { override def checkInputDataTypes(): TypeCheckResult = { // First check whether left and right have the same type, then check if the type is acceptable. 
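    // Note that sameType ignores nullability, e.g. ArrayType(IntegerType, containsNull = true)
    // and ArrayType(IntegerType, containsNull = false) are considered the same type here.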
- if (left.dataType != right.dataType) { + if (!left.dataType.sameType(right.dataType)) { TypeCheckResult.TypeCheckFailure(s"differing types in '$sql' " + s"(${left.dataType.simpleString} and ${right.dataType.simpleString}).") } else if (!inputType.acceptsType(left.dataType)) { @@ -493,7 +526,7 @@ abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes { } -private[sql] object BinaryOperator { +object BinaryOperator { def unapply(e: BinaryOperator): Option[(Expression, Expression)] = Some((e.left, e.right)) } @@ -586,6 +619,7 @@ abstract class TernaryExpression extends Expression { $nullSafeEval""") } else { ev.copy(code = s""" + boolean ${ev.isNull} = false; ${leftGen.code} ${midGen.code} ${rightGen.code} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala index 75c6bb2d84dfb..5b4922e0cf2b7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala @@ -40,7 +40,7 @@ import org.apache.spark.sql.types.{DataType, LongType} represent the record number within each partition. The assumption is that the data frame has less than 1 billion partitions, and each partition has less than 8 billion records.""", extended = "> SELECT _FUNC_();\n 0") -private[sql] case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterministic { +case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterministic { /** * Record ID within each partition. By being transient, count's value is reset to 0 every time diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 27ad8e4cf22ce..c8d18667f7c4a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -158,7 +158,7 @@ object UnsafeProjection { object FromUnsafeProjection { /** - * Returns an Projection for given StructType. + * Returns a Projection for given StructType. */ def apply(schema: StructType): Projection = { apply(schema.fields.map(_.dataType)) @@ -172,7 +172,7 @@ object FromUnsafeProjection { } /** - * Returns an Projection for given sequence of Expressions (bounded). + * Returns a Projection for given sequence of Expressions (bounded). 
*/ private def create(exprs: Seq[Expression]): Projection = { GenerateSafeProjection.generate(exprs) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala index c4cc6c39b0477..502d791c6e85c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.objects.LambdaVariable import org.apache.spark.sql.types.DataType /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 0038cf65e2993..6cfdea9fdf9c5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -17,12 +17,14 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types.DataType /** * User-defined function. + * Note that the user-defined functions must be deterministic. * @param function The user defined scala function to run. 
* Note that if you use primitive parameters, you are not able to check if it is * null or not, and the UDF will return null for you if the primitive input is @@ -993,20 +995,15 @@ case class ScalaUDF( ctx: CodegenContext, ev: ExprCode): ExprCode = { - ctx.references += this - - val scalaUDFClassName = classOf[ScalaUDF].getName + val scalaUDF = ctx.addReferenceObj("scalaUDF", this) val converterClassName = classOf[Any => Any].getName val typeConvertersClassName = CatalystTypeConverters.getClass.getName + ".MODULE$" - val expressionClassName = classOf[Expression].getName // Generate codes used to convert the returned value of user-defined functions to Catalyst type val catalystConverterTerm = ctx.freshName("catalystConverter") - val catalystConverterTermIdx = ctx.references.size - 1 ctx.addMutableState(converterClassName, catalystConverterTerm, s"this.$catalystConverterTerm = ($converterClassName)$typeConvertersClassName" + - s".createToCatalystConverter((($scalaUDFClassName)references" + - s"[$catalystConverterTermIdx]).dataType());") + s".createToCatalystConverter($scalaUDF.dataType());") val resultTerm = ctx.freshName("result") @@ -1018,10 +1015,8 @@ case class ScalaUDF( val funcClassName = s"scala.Function${children.size}" val funcTerm = ctx.freshName("udf") - val funcExpressionIdx = ctx.references.size - 1 ctx.addMutableState(funcClassName, funcTerm, - s"this.$funcTerm = ($funcClassName)((($scalaUDFClassName)references" + - s"[$funcExpressionIdx]).userDefinedFunc());") + s"this.$funcTerm = ($funcClassName)$scalaUDF.userDefinedFunc();") // codegen for children expressions val evals = children.map(_.genCode(ctx)) @@ -1038,9 +1033,16 @@ case class ScalaUDF( (convert, argTerm) }.unzip - val callFunc = s"${ctx.boxedType(dataType)} $resultTerm = " + - s"(${ctx.boxedType(dataType)})${catalystConverterTerm}" + - s".apply($funcTerm.apply(${funcArguments.mkString(", ")}));" + val getFuncResult = s"$funcTerm.apply(${funcArguments.mkString(", ")})" + val callFunc = + s""" + ${ctx.boxedType(dataType)} $resultTerm = null; + try { + $resultTerm = (${ctx.boxedType(dataType)})$catalystConverterTerm.apply($getFuncResult); + } catch (Exception e) { + throw new org.apache.spark.SparkException($scalaUDF.udfErrorMessage(), e); + } + """ ev.copy(code = s""" $evalCode @@ -1056,5 +1058,20 @@ case class ScalaUDF( private[this] val converter = CatalystTypeConverters.createToCatalystConverter(dataType) - override def eval(input: InternalRow): Any = converter(f(input)) + lazy val udfErrorMessage = { + val funcCls = function.getClass.getSimpleName + val inputTypes = children.map(_.dataType.simpleString).mkString(", ") + s"Failed to execute user defined function($funcCls: ($inputTypes) => ${dataType.simpleString})" + } + + override def eval(input: InternalRow): Any = { + val result = try { + f(input) + } catch { + case e: Exception => + throw new SparkException(udfErrorMessage, e) + } + + converter(result) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala index 71af59a7a8529..1f675d5b07270 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.types.{DataType, IntegerType} @ExpressionDescription( usage = "_FUNC_() - Returns the current partition id of the Spark task", 
extended = "> SELECT _FUNC_();\n 0") -private[sql] case class SparkPartitionID() extends LeafExpression with Nondeterministic { +case class SparkPartitionID() extends LeafExpression with Nondeterministic { override def nullable: Boolean = false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala index 83fa447cf8c85..7ff61ee479452 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.commons.lang.StringUtils +import org.apache.commons.lang3.StringUtils import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.TypeCheckResult @@ -45,12 +45,12 @@ case class TimeWindow( slideDuration: Expression, startTime: Expression) = { this(timeColumn, TimeWindow.parseExpression(windowDuration), - TimeWindow.parseExpression(windowDuration), TimeWindow.parseExpression(startTime)) + TimeWindow.parseExpression(slideDuration), TimeWindow.parseExpression(startTime)) } def this(timeColumn: Expression, windowDuration: Expression, slideDuration: Expression) = { this(timeColumn, TimeWindow.parseExpression(windowDuration), - TimeWindow.parseExpression(windowDuration), 0) + TimeWindow.parseExpression(slideDuration), 0) } def this(timeColumn: Expression, windowDuration: Expression) = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala index b8ab0364dd8f3..d702c08cfd342 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.types._ /** * Returns the first value of `child` for a group of rows. If the first value of `child` - * is `null`, it returns `null` (respecting nulls). Even if [[First]] is used on a already + * is `null`, it returns `null` (respecting nulls). Even if [[First]] is used on an already * sorted column, if we do partial aggregation and final aggregation (when mergeExpression * is used) its result will not be deterministic (unless the input table is sorted and has * a single partition, and we use a single reducer to do the aggregation.). @@ -43,7 +43,7 @@ case class First(child: Expression, ignoreNullsExpr: Expression) extends Declara throw new AnalysisException("The second argument of First should be a boolean literal.") } - override def children: Seq[Expression] = child :: Nil + override def children: Seq[Expression] = child :: ignoreNullsExpr :: Nil override def nullable: Boolean = true @@ -54,7 +54,7 @@ case class First(child: Expression, ignoreNullsExpr: Expression) extends Declara override def dataType: DataType = child.dataType // Expected input data type. 
- override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) + override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, BooleanType) private lazy val first = AttributeReference("first", child.dataType)() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala index b05d74b49b591..8579f7292d3ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.types._ /** * Returns the last value of `child` for a group of rows. If the last value of `child` - * is `null`, it returns `null` (respecting nulls). Even if [[Last]] is used on a already + * is `null`, it returns `null` (respecting nulls). Even if [[Last]] is used on an already * sorted column, if we do partial aggregation and final aggregation (when mergeExpression * is used) its result will not be deterministic (unless the input table is sorted and has * a single partition, and we use a single reducer to do the aggregation.). @@ -40,7 +40,7 @@ case class Last(child: Expression, ignoreNullsExpr: Expression) extends Declarat throw new AnalysisException("The second argument of First should be a boolean literal.") } - override def children: Seq[Expression] = child :: Nil + override def children: Seq[Expression] = child :: ignoreNullsExpr :: Nil override def nullable: Boolean = true @@ -51,38 +51,39 @@ case class Last(child: Expression, ignoreNullsExpr: Expression) extends Declarat override def dataType: DataType = child.dataType // Expected input data type. - override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) + override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, BooleanType) private lazy val last = AttributeReference("last", child.dataType)() - override lazy val aggBufferAttributes: Seq[AttributeReference] = last :: Nil + private lazy val valueSet = AttributeReference("valueSet", BooleanType)() + + override lazy val aggBufferAttributes: Seq[AttributeReference] = last :: valueSet :: Nil override lazy val initialValues: Seq[Literal] = Seq( - /* last = */ Literal.create(null, child.dataType) + /* last = */ Literal.create(null, child.dataType), + /* valueSet = */ Literal.create(false, BooleanType) ) override lazy val updateExpressions: Seq[Expression] = { if (ignoreNulls) { Seq( - /* last = */ If(IsNull(child), last, child) + /* last = */ If(IsNull(child), last, child), + /* valueSet = */ Or(valueSet, IsNotNull(child)) ) } else { Seq( - /* last = */ child + /* last = */ child, + /* valueSet = */ Literal.create(true, BooleanType) ) } } override lazy val mergeExpressions: Seq[Expression] = { - if (ignoreNulls) { - Seq( - /* last = */ If(IsNull(last.right), last.left, last.right) - ) - } else { - Seq( - /* last = */ last.right - ) - } + // Prefer the right hand expression if it has been set. 
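+    // For example, when the right-hand buffer comes from an empty partition, valueSet.right is
+    // false, so the merged result keeps the left buffer's last value instead of overwriting it
+    // with an uninitialized null.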
+ Seq( + /* last = */ If(valueSet.right, last.right, last.left), + /* valueSet = */ Or(valueSet.right, valueSet.left) + ) } override lazy val evaluateExpression: AttributeReference = last diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala new file mode 100644 index 0000000000000..16c03c500ad08 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import scala.collection.immutable.HashMap + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.GenericArrayData +import org.apache.spark.sql.types._ + +object PivotFirst { + + def supportsDataType(dataType: DataType): Boolean = updateFunction.isDefinedAt(dataType) + + // Currently UnsafeRow does not support the generic update method (throws + // UnsupportedOperationException), so we need to explicitly support each DataType. + private val updateFunction: PartialFunction[DataType, (MutableRow, Int, Any) => Unit] = { + case DoubleType => + (row, offset, value) => row.setDouble(offset, value.asInstanceOf[Double]) + case IntegerType => + (row, offset, value) => row.setInt(offset, value.asInstanceOf[Int]) + case LongType => + (row, offset, value) => row.setLong(offset, value.asInstanceOf[Long]) + case FloatType => + (row, offset, value) => row.setFloat(offset, value.asInstanceOf[Float]) + case BooleanType => + (row, offset, value) => row.setBoolean(offset, value.asInstanceOf[Boolean]) + case ShortType => + (row, offset, value) => row.setShort(offset, value.asInstanceOf[Short]) + case ByteType => + (row, offset, value) => row.setByte(offset, value.asInstanceOf[Byte]) + case d: DecimalType => + (row, offset, value) => row.setDecimal(offset, value.asInstanceOf[Decimal], d.precision) + } +} + +/** + * PivotFirst is an aggregate function used in the second phase of a two phase pivot to do the + * required rearrangement of values into pivoted form. + * + * For example on an input of + * A | B + * --+-- + * x | 1 + * y | 2 + * z | 3 + * + * with pivotColumn=A, valueColumn=B, and pivotColumnValues=[z,y] the output is [3,2]. + * + * @param pivotColumn column that determines which output position to put valueColumn in. + * @param valueColumn the column that is being rearranged. + * @param pivotColumnValues the list of pivotColumn values in the order of desired output. Values + * not listed here will be ignored. 
+ */ +case class PivotFirst( + pivotColumn: Expression, + valueColumn: Expression, + pivotColumnValues: Seq[Any], + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) extends ImperativeAggregate { + + override val children: Seq[Expression] = pivotColumn :: valueColumn :: Nil + + override lazy val inputTypes: Seq[AbstractDataType] = children.map(_.dataType) + + override val nullable: Boolean = false + + val valueDataType = valueColumn.dataType + + override val dataType: DataType = ArrayType(valueDataType) + + val pivotIndex = HashMap(pivotColumnValues.zipWithIndex: _*) + + val indexSize = pivotIndex.size + + private val updateRow: (MutableRow, Int, Any) => Unit = PivotFirst.updateFunction(valueDataType) + + override def update(mutableAggBuffer: MutableRow, inputRow: InternalRow): Unit = { + val pivotColValue = pivotColumn.eval(inputRow) + if (pivotColValue != null) { + // We ignore rows whose pivot column value is not in the list of pivot column values. + val index = pivotIndex.getOrElse(pivotColValue, -1) + if (index >= 0) { + val value = valueColumn.eval(inputRow) + if (value != null) { + updateRow(mutableAggBuffer, mutableAggBufferOffset + index, value) + } + } + } + } + + override def merge(mutableAggBuffer: MutableRow, inputAggBuffer: InternalRow): Unit = { + for (i <- 0 until indexSize) { + if (!inputAggBuffer.isNullAt(inputAggBufferOffset + i)) { + val value = inputAggBuffer.get(inputAggBufferOffset + i, valueDataType) + updateRow(mutableAggBuffer, mutableAggBufferOffset + i, value) + } + } + } + + override def initialize(mutableAggBuffer: MutableRow): Unit = valueDataType match { + case d: DecimalType => + // Per doc of setDecimal we need to do this instead of setNullAt for DecimalType. + for (i <- 0 until indexSize) { + mutableAggBuffer.setDecimal(mutableAggBufferOffset + i, null, d.precision) + } + case _ => + for (i <- 0 until indexSize) { + mutableAggBuffer.setNullAt(mutableAggBufferOffset + i) + } + } + + override def eval(input: InternalRow): Any = { + val result = new Array[Any](indexSize) + for (i <- 0 until indexSize) { + result(i) = input.get(mutableAggBufferOffset + i, valueDataType) + } + new GenericArrayData(result) + } + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) + + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + + override val aggBufferAttributes: Seq[AttributeReference] = + pivotIndex.toList.sortBy(_._2).map(kv => AttributeReference(kv._1.toString, valueDataType)()) + + override val aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) + + override val inputAggBufferAttributes: Seq[AttributeReference] = + aggBufferAttributes.map(_.newInstance()) +} + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala new file mode 100644 index 0000000000000..78a388d20630b --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import scala.collection.generic.Growable +import scala.collection.mutable + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.GenericArrayData +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types._ + +/** + * The Collect aggregate function collects all seen expression values into a list of values. + * + * The operator is bound to the slower sort based aggregation path because the number of + * elements (and their memory usage) can not be determined in advance. This also means that the + * collected elements are stored on heap, and that too many elements can cause GC pauses and + * eventually Out of Memory Errors. + */ +abstract class Collect extends ImperativeAggregate { + + val child: Expression + + override def children: Seq[Expression] = child :: Nil + + override def nullable: Boolean = true + + override def dataType: DataType = ArrayType(child.dataType) + + override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) + + override def supportsPartial: Boolean = false + + override def aggBufferAttributes: Seq[AttributeReference] = Nil + + override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) + + override def inputAggBufferAttributes: Seq[AttributeReference] = Nil + + // Both `CollectList` and `CollectSet` are non-deterministic since their results depend on the + // actual order of input rows. + override def deterministic: Boolean = false + + protected[this] val buffer: Growable[Any] with Iterable[Any] + + override def initialize(b: MutableRow): Unit = { + buffer.clear() + } + + override def update(b: MutableRow, input: InternalRow): Unit = { + // Do not allow null values. We follow the semantics of Hive's collect_list/collect_set here. + // See: org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMkCollectionEvaluator + val value = child.eval(input) + if (value != null) { + buffer += value + } + } + + override def merge(buffer: MutableRow, input: InternalRow): Unit = { + sys.error("Collect cannot be used in partial aggregations.") + } + + override def eval(input: InternalRow): Any = { + new GenericArrayData(buffer.toArray) + } +} + +/** + * Collect a list of elements. 
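+ *
+ * For example (an illustrative sketch using the DataFrame API):
+ * {{{
+ *   import org.apache.spark.sql.functions.collect_list
+ *   df.groupBy("key").agg(collect_list("value"))
+ * }}}
+ * Null values are skipped, and the order of the collected elements depends on the order in
+ * which the input rows are processed.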
+ */ +@ExpressionDescription( + usage = "_FUNC_(expr) - Collects and returns a list of non-unique elements.") +case class CollectList( + child: Expression, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) extends Collect { + + def this(child: Expression) = this(child, 0, 0) + + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) + + override def prettyName: String = "collect_list" + + override protected[this] val buffer: mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty +} + +/** + * Collect a list of unique elements. + */ +@ExpressionDescription( + usage = "_FUNC_(expr) - Collects and returns a set of unique elements.") +case class CollectSet( + child: Expression, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) extends Collect { + + def this(child: Expression) = this(child, 0, 0) + + override def checkInputDataTypes(): TypeCheckResult = { + if (!child.dataType.existsRecursively(_.isInstanceOf[MapType])) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure("collect_set() cannot have map type data") + } + } + + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) + + override def prettyName: String = "collect_set" + + override protected[this] val buffer: mutable.HashSet[Any] = mutable.HashSet.empty +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index d31ccf9985360..7a39e568fa289 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -24,14 +24,14 @@ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.types._ /** The mode of an [[AggregateFunction]]. */ -private[sql] sealed trait AggregateMode +sealed trait AggregateMode /** * An [[AggregateFunction]] with [[Partial]] mode is used for partial aggregation. * This function updates the given aggregation buffer with the original input of this * function. When it has processed all input rows, the aggregation buffer is returned. */ -private[sql] case object Partial extends AggregateMode +case object Partial extends AggregateMode /** * An [[AggregateFunction]] with [[PartialMerge]] mode is used to merge aggregation buffers @@ -39,7 +39,7 @@ private[sql] case object Partial extends AggregateMode * This function updates the given aggregation buffer by merging multiple aggregation buffers. * When it has processed all input rows, the aggregation buffer is returned. 
*/ -private[sql] case object PartialMerge extends AggregateMode +case object PartialMerge extends AggregateMode /** * An [[AggregateFunction]] with [[Final]] mode is used to merge aggregation buffers @@ -47,7 +47,7 @@ private[sql] case object PartialMerge extends AggregateMode * This function updates the given aggregation buffer by merging multiple aggregation buffers. * When it has processed all input rows, the final result of this function is returned. */ -private[sql] case object Final extends AggregateMode +case object Final extends AggregateMode /** * An [[AggregateFunction]] with [[Complete]] mode is used to evaluate this function directly @@ -55,13 +55,13 @@ private[sql] case object Final extends AggregateMode * This function updates the given aggregation buffer with the original input of this * function. When it has processed all input rows, the final result of this function is returned. */ -private[sql] case object Complete extends AggregateMode +case object Complete extends AggregateMode /** * A place holder expressions used in code-gen, it does not change the corresponding value * in the row. */ -private[sql] case object NoOp extends Expression with Unevaluable { +case object NoOp extends Expression with Unevaluable { override def nullable: Boolean = true override def dataType: DataType = NullType override def children: Seq[Expression] = Nil @@ -84,7 +84,7 @@ object AggregateExpression { * A container for an [[AggregateFunction]] with its [[AggregateMode]] and a field * (`isDistinct`) indicating if DISTINCT keyword is specified for this function. */ -private[sql] case class AggregateExpression( +case class AggregateExpression( aggregateFunction: AggregateFunction, mode: AggregateMode, isDistinct: Boolean, @@ -126,7 +126,14 @@ private[sql] case class AggregateExpression( AttributeSet(childReferences) } - override def toString: String = s"($aggregateFunction,mode=$mode,isDistinct=$isDistinct)" + override def toString: String = { + val prefix = mode match { + case Partial => "partial_" + case PartialMerge => "merge_" + case Final | Complete => "" + } + prefix + aggregateFunction.toAggString(isDistinct) + } override def sql: String = aggregateFunction.sql(isDistinct) } @@ -203,6 +210,12 @@ sealed abstract class AggregateFunction extends Expression with ImplicitCastInpu val distinct = if (isDistinct) "DISTINCT " else "" s"$prettyName($distinct${children.map(_.sql).mkString(", ")})" } + + /** String representation used in explain plans. 
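+   * For example, `count(x)` evaluated in Partial mode renders roughly as `partial_count(x)`.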
*/ + def toAggString(isDistinct: Boolean): String = { + val start = if (isDistinct) "(distinct " else "(" + prettyName + flatArguments.mkString(start, ", ", ")") + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index b2df79a58884b..01c5d8217074f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -57,7 +57,7 @@ case class UnaryMinus(child: Expression) extends UnaryExpression } } - override def sql: String = s"(-${child.sql})" + override def sql: String = s"(- ${child.sql})" } @ExpressionDescription( @@ -75,7 +75,7 @@ case class UnaryPositive(child: Expression) protected override def nullSafeEval(input: Any): Any = input - override def sql: String = s"(+${child.sql})" + override def sql: String = s"(+ ${child.sql})" } /** @@ -125,7 +125,7 @@ abstract class BinaryArithmetic extends BinaryOperator { } } -private[sql] object BinaryArithmetic { +object BinaryArithmetic { def unapply(e: BinaryArithmetic): Option[(Expression, Expression)] = Some((e.left, e.right)) } @@ -213,7 +213,7 @@ case class Multiply(left: Expression, right: Expression) case class Divide(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant { - override def inputType: AbstractDataType = NumericType + override def inputType: AbstractDataType = TypeCollection(DoubleType, DecimalType) override def symbol: String = "/" override def decimalMethod: String = "$div" @@ -221,7 +221,6 @@ case class Divide(left: Expression, right: Expression) private lazy val div: (Any, Any) => Any = dataType match { case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div - case it: IntegralType => it.integral.asInstanceOf[Integral[Any]].quot } override def eval(input: InternalRow): Any = { @@ -310,7 +309,11 @@ case class Remainder(left: Expression, right: Expression) if (input1 == null) { null } else { - integral.rem(input1, input2) + input1 match { + case d: Double => d % input2.asInstanceOf[java.lang.Double] + case f: Float => f % input2.asInstanceOf[java.lang.Float] + case _ => integral.rem(input1, input2) + } } } } @@ -499,34 +502,35 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic wi override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (eval1, eval2) => { + val remainder = ctx.freshName("remainder") dataType match { case dt: DecimalType => val decimalAdd = "$plus" s""" - ${ctx.javaType(dataType)} r = $eval1.remainder($eval2); - if (r.compare(new org.apache.spark.sql.types.Decimal().set(0)) < 0) { - ${ev.value} = (r.$decimalAdd($eval2)).remainder($eval2); + ${ctx.javaType(dataType)} $remainder = $eval1.remainder($eval2); + if ($remainder.compare(new org.apache.spark.sql.types.Decimal().set(0)) < 0) { + ${ev.value} = ($remainder.$decimalAdd($eval2)).remainder($eval2); } else { - ${ev.value} = r; + ${ev.value} = $remainder; } """ // byte and short are casted into int when add, minus, times or divide case ByteType | ShortType => s""" - ${ctx.javaType(dataType)} r = (${ctx.javaType(dataType)})($eval1 % $eval2); - if (r < 0) { - ${ev.value} = (${ctx.javaType(dataType)})((r + $eval2) % $eval2); + ${ctx.javaType(dataType)} $remainder = (${ctx.javaType(dataType)})($eval1 % $eval2); + if ($remainder < 0) { + ${ev.value} = 
(${ctx.javaType(dataType)})(($remainder + $eval2) % $eval2); } else { - ${ev.value} = r; + ${ev.value} = $remainder; } """ case _ => s""" - ${ctx.javaType(dataType)} r = $eval1 % $eval2; - if (r < 0) { - ${ev.value} = (r + $eval2) % $eval2; + ${ctx.javaType(dataType)} $remainder = $eval1 % $eval2; + if ($remainder < 0) { + ${ev.value} = ($remainder + $eval2) % $eval2; } else { - ${ev.value} = r; + ${ev.value} = $remainder; } """ } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala index ab4831f7abdd0..05b7c96e44c02 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions.codegen +import java.util.regex.Matcher + /** * An utility class that indents a block of code based on the curly braces and parentheses. * This is used to prettify generated code when in debug mode (or exceptions). @@ -24,13 +26,25 @@ package org.apache.spark.sql.catalyst.expressions.codegen * Written by Matei Zaharia. */ object CodeFormatter { - def format(code: String): String = new CodeFormatter().addLines(code).result() + val commentHolder = """\/\*(.+?)\*\/""".r + + def format(code: CodeAndComment): String = { + val formatter = new CodeFormatter + code.body.split("\n").foreach { line => + val commentReplaced = commentHolder.replaceAllIn( + line.trim, + m => code.comment.get(m.group(1)).map(Matcher.quoteReplacement).getOrElse(m.group(0))) + formatter.addLine(commentReplaced) + } + formatter.result() + } + def stripExtraNewLines(input: String): String = { val code = new StringBuilder var lastLine: String = "dummy" input.split('\n').foreach { l => val line = l.trim() - val skip = line == "" && (lastLine == "" || lastLine.endsWith("{")) + val skip = line == "" && (lastLine == "" || lastLine.endsWith("{") || lastLine.endsWith("*/")) if (!skip) { code.append(line) code.append("\n") @@ -39,6 +53,36 @@ object CodeFormatter { } code.result() } + + def stripOverlappingComments(codeAndComment: CodeAndComment): CodeAndComment = { + val code = new StringBuilder + val map = codeAndComment.comment + + def getComment(line: String): Option[String] = { + if (line.startsWith("/*") && line.endsWith("*/")) { + map.get(line.substring(2, line.length - 2)) + } else { + None + } + } + + var lastLine: String = "dummy" + codeAndComment.body.split('\n').foreach { l => + val line = l.trim() + + val skip = getComment(lastLine).zip(getComment(line)).exists { + case (lastComment, currentComment) => + lastComment.substring(3).contains(currentComment.substring(3)) + } + + if (!skip) { + code.append(line).append("\n") + } + + lastLine = line + } + new CodeAndComment(code.result().trim(), map) + } } private class CodeFormatter { @@ -89,19 +133,18 @@ private class CodeFormatter { } else { indentString } - code.append(f"/* ${currentLine}%03d */ ") - code.append(thisLineIndent) - code.append(line) + code.append(f"/* ${currentLine}%03d */") + if (line.trim().length > 0) { + code.append(" ") // add a space after the line number comment. 
+ code.append(thisLineIndent) + if (inCommentBlock && line.startsWith("*") || line.startsWith("*/")) code.append(" ") + code.append(line) + } code.append("\n") indentLevel = newIndentLevel indentString = " " * (indentSize * newIndentLevel) currentLine += 1 } - private def addLines(code: String): CodeFormatter = { - code.split('\n').foreach(s => addLine(s.trim())) - this - } - private def result(): String = code.result() } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index e4fa429b37546..929f2da07531e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -17,21 +17,29 @@ package org.apache.spark.sql.catalyst.expressions.codegen +import java.io.ByteArrayInputStream +import java.util.{Map => JavaMap} + +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import scala.language.existentials +import scala.util.control.NonFatal import com.google.common.cache.{CacheBuilder, CacheLoader} -import org.codehaus.janino.ClassBodyEvaluator +import org.codehaus.janino.{ByteArrayClassLoader, ClassBodyEvaluator, SimpleCompiler} +import org.codehaus.janino.util.ClassFile +import scala.language.existentials +import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging +import org.apache.spark.metrics.source.CodegenMetrics import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.types._ -import org.apache.spark.util.Utils +import org.apache.spark.util.{ParentClassLoader, Utils} /** * Java source for evaluating an [[Expression]] given a [[InternalRow]] of input. @@ -46,6 +54,25 @@ import org.apache.spark.util.Utils */ case class ExprCode(var code: String, var isNull: String, var value: String) +/** + * State used for subexpression elimination. + * + * @param isNull A term that holds a boolean value representing whether the expression evaluated + * to null. + * @param value A term for a value of a common sub-expression. Not valid if `isNull` + * is set to `true`. + */ +case class SubExprEliminationState(isNull: String, value: String) + +/** + * Codes and common subexpressions mapping used for subexpression elimination. + * + * @param codes Strings representing the codes that evaluate common subexpressions. + * @param states Foreach expression that is participating in subexpression elimination, + * the state to use. + */ +case class SubExprCodes(codes: Seq[String], states: Map[Expression, SubExprEliminationState]) + /** * A context for codegen, tracking a list of objects that could be passed into generated Java * function. @@ -109,6 +136,22 @@ class CodegenContext { mutableStates += ((javaType, variableName, initCode)) } + /** + * Add buffer variable which stores data coming from an [[InternalRow]]. This methods guarantees + * that the variable is safely stored, which is important for (potentially) byte array backed + * data types like: UTF8String, ArrayData, MapData & InternalRow. 
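+   *
+   * For example (an illustrative sketch, where `ctx` is a CodegenContext and `ev` is the
+   * child's ExprCode), buffering a string-typed value:
+   * {{{
+   *   val buffered = ctx.addBufferedState(StringType, "bufValue", ev.value)
+   *   // buffered.code is roughly: bufValue = <ev.value>.clone();
+   * }}}
+   * The clone/copy keeps the buffered value valid even if the bytes backing the input row are
+   * later reused.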
+ */ + def addBufferedState(dataType: DataType, variableName: String, initCode: String): ExprCode = { + val value = freshName(variableName) + addMutableState(javaType(dataType), value, "") + val code = dataType match { + case StringType => s"$value = $initCode.clone();" + case _: StructType | _: ArrayType | _: MapType => s"$value = $initCode.copy();" + case _ => s"$value = $initCode;" + } + ExprCode(code, "false", value) + } + def declareMutableStates(): String = { // It's possible that we add same mutable state twice, e.g. the `mergeExpressions` in // `TypedAggregateExpression`, we should call `distinct` here to remove the duplicated ones. @@ -148,9 +191,6 @@ class CodegenContext { */ val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions - // State used for subexpression elimination. - case class SubExprEliminationState(isNull: String, value: String) - // Foreach expression that is participating in subexpression elimination, the state to use. val subExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState] @@ -183,6 +223,11 @@ class CodegenContext { */ var freshNamePrefix = "" + /** + * The map from a place holder to a corresponding comment + */ + private val placeHolderToComments = new mutable.HashMap[String, String] + /** * Returns a term name that is unique within this instance of a `CodegenContext`. */ @@ -468,6 +513,7 @@ class CodegenContext { addNewFunction(compareFunc, funcCode) s"this.$compareFunc($c1, $c2)" case schema: StructType => + INPUT_ROW = "i" val comparisons = GenerateOrdering.genComparisons(this, schema) val compareFunc = freshName("compareStruct") val funcCode: String = @@ -539,11 +585,18 @@ class CodegenContext { * @param expressions the codes to evaluate expressions. */ def splitExpressions(row: String, expressions: Seq[String]): String = { + if (row == null || currentVars != null) { + // Cannot split these expressions because they are not created from a row object. + return expressions.mkString("\n") + } val blocks = new ArrayBuffer[String]() val blockBuilder = new StringBuilder() for (code <- expressions) { - // We can't know how many byte code will be generated, so use the number of bytes as limit - if (blockBuilder.length > 64 * 1000) { + // We can't know how many bytecode will be generated, so use the length of source code + // as metric. A method should not go beyond 8K, otherwise it will not be JITted, should + // also not be too small, or it will have many function calls (for wide table), see the + // results in BenchmarkWideTable. + if (blockBuilder.length > 1024) { blocks.append(blockBuilder.toString()) blockBuilder.clear() } @@ -571,6 +624,58 @@ class CodegenContext { } } + /** + * Perform a function which generates a sequence of ExprCodes with a given mapping between + * expressions and common expressions, instead of using the mapping in current context. + */ + def withSubExprEliminationExprs( + newSubExprEliminationExprs: Map[Expression, SubExprEliminationState])( + f: => Seq[ExprCode]): Seq[ExprCode] = { + val oldsubExprEliminationExprs = subExprEliminationExprs + subExprEliminationExprs.clear + newSubExprEliminationExprs.foreach(subExprEliminationExprs += _) + + val genCodes = f + + // Restore previous subExprEliminationExprs + subExprEliminationExprs.clear + oldsubExprEliminationExprs.foreach(subExprEliminationExprs += _) + genCodes + } + + /** + * Checks and sets up the state and codegen for subexpression elimination. 
This finds the + * common subexpressions, generates the code snippets that evaluate those expressions and + * populates the mapping of common subexpressions to the generated code snippets. The generated + * code snippets will be returned and should be inserted into generated codes before these + * common subexpressions actually are used first time. + */ + def subexpressionEliminationForWholeStageCodegen(expressions: Seq[Expression]): SubExprCodes = { + // Create a clear EquivalentExpressions and SubExprEliminationState mapping + val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions + val subExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState] + + // Add each expression tree and compute the common subexpressions. + expressions.foreach(equivalentExpressions.addExprTree(_, true, false)) + + // Get all the expressions that appear at least twice and set up the state for subexpression + // elimination. + val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1) + val codes = commonExprs.map { e => + val expr = e.head + val fnName = freshName("evalExpr") + val isNull = s"${fnName}IsNull" + val value = s"${fnName}Value" + + // Generate the code for this expression tree. + val code = expr.genCode(this) + val state = SubExprEliminationState(code.isNull, code.value) + e.foreach(subExprEliminationExprs.put(_, state)) + code.code.trim + } + SubExprCodes(codes, subExprEliminationExprs.toMap) + } + /** * Checks and sets up the state and codegen for subexpression elimination. This finds the * common subexpressions, generates the functions that evaluate those expressions and populates @@ -638,6 +743,33 @@ class CodegenContext { if (doSubexpressionElimination) subexpressionElimination(expressions) expressions.map(e => e.genCode(this)) } + + /** + * get a map of the pair of a place holder and a corresponding comment + */ + def getPlaceHolderToComments(): collection.Map[String, String] = placeHolderToComments + + /** + * Register a comment and return the corresponding place holder + */ + def registerComment(text: => String): String = { + // By default, disable comments in generated code because computing the comments themselves can + // be extremely expensive in certain cases, such as deeply-nested expressions which operate over + // inputs with wide schemas. For more details on the performance issues that motivated this + // flat, see SPARK-15680. + if (SparkEnv.get != null && SparkEnv.get.conf.getBoolean("spark.sql.codegen.comments", false)) { + val name = freshName("c") + val comment = if (text.contains("\n") || text.contains("\r")) { + text.split("(\r\n)|\r|\n").mkString("/**\n * ", "\n * ", "\n */") + } else { + s"// $text" + } + placeHolderToComments += (name -> comment) + s"/*$name*/" + } else { + "" + } + } } /** @@ -648,6 +780,19 @@ abstract class GeneratedClass { def generate(references: Array[Any]): Any } +/** + * A wrapper for the source code to be compiled by [[CodeGenerator]]. + */ +class CodeAndComment(val body: String, val comment: collection.Map[String, String]) + extends Serializable { + override def equals(that: Any): Boolean = that match { + case t: CodeAndComment if t.body == body => true + case _ => false + } + + override def hashCode(): Int = body.hashCode +} + /** * A base class for generators of byte code to perform expression evaluation. 
Includes a set of * helpers for referring to Catalyst types and building trees that perform evaluation of individual @@ -692,16 +837,26 @@ object CodeGenerator extends Logging { /** * Compile the Java source code into a Java class, using Janino. */ - def compile(code: String): GeneratedClass = { + def compile(code: CodeAndComment): GeneratedClass = { cache.get(code) } /** * Compile the Java source code into a Java class, using Janino. */ - private[this] def doCompile(code: String): GeneratedClass = { + private[this] def doCompile(code: CodeAndComment): GeneratedClass = { val evaluator = new ClassBodyEvaluator() - evaluator.setParentClassLoader(Utils.getContextOrSparkClassLoader) + + // A special classloader used to wrap the actual parent classloader of + // [[org.codehaus.janino.ClassBodyEvaluator]] (see CodeGenerator.doCompile). This classloader + // does not throw a ClassNotFoundException with a cause set (i.e. exception.getCause returns + // a null). This classloader is needed because janino will throw the exception directly if + // the parent classloader throws a ClassNotFoundException with cause set instead of trying to + // find other possible classes (see org.codehaus.janinoClassLoaderIClassLoader's + // findIClass method). Please also see https://issues.apache.org/jira/browse/SPARK-15622 and + // https://issues.apache.org/jira/browse/SPARK-11636. + val parentClassLoader = new ParentClassLoader(Utils.getContextOrSparkClassLoader) + evaluator.setParentClassLoader(parentClassLoader) // Cannot be under package codegen, or fail with java.lang.InstantiationException evaluator.setClassName("org.apache.spark.sql.catalyst.expressions.GeneratedClass") evaluator.setDefaultImports(Array( @@ -720,7 +875,7 @@ object CodeGenerator extends Logging { )) evaluator.setExtendedClass(classOf[GeneratedClass]) - def formatted = CodeFormatter.format(code) + lazy val formatted = CodeFormatter.format(code) logDebug({ // Only add extra debugging info to byte code when we are going to print the source code. @@ -729,7 +884,8 @@ object CodeGenerator extends Logging { }) try { - evaluator.cook("generated.java", code) + evaluator.cook("generated.java", code.body) + recordCompilationStats(evaluator) } catch { case e: Exception => val msg = s"failed to compile: $e\n$formatted" @@ -739,6 +895,43 @@ object CodeGenerator extends Logging { evaluator.getClazz().newInstance().asInstanceOf[GeneratedClass] } + /** + * Records the generated class and method bytecode sizes by inspecting janino private fields. + */ + private def recordCompilationStats(evaluator: ClassBodyEvaluator): Unit = { + // First retrieve the generated classes. + val classes = { + val resultField = classOf[SimpleCompiler].getDeclaredField("result") + resultField.setAccessible(true) + val loader = resultField.get(evaluator).asInstanceOf[ByteArrayClassLoader] + val classesField = loader.getClass.getDeclaredField("classes") + classesField.setAccessible(true) + classesField.get(loader).asInstanceOf[JavaMap[String, Array[Byte]]].asScala + } + + // Then walk the classes to get at the method bytecode. 
+ val codeAttr = Utils.classForName("org.codehaus.janino.util.ClassFile$CodeAttribute") + val codeAttrField = codeAttr.getDeclaredField("code") + codeAttrField.setAccessible(true) + classes.foreach { case (_, classBytes) => + CodegenMetrics.METRIC_GENERATED_CLASS_BYTECODE_SIZE.update(classBytes.length) + try { + val cf = new ClassFile(new ByteArrayInputStream(classBytes)) + cf.methodInfos.asScala.foreach { method => + method.getAttributes().foreach { a => + if (a.getClass.getName == codeAttr.getName) { + CodegenMetrics.METRIC_GENERATED_METHOD_BYTECODE_SIZE.update( + codeAttrField.get(a).asInstanceOf[Array[Byte]].length) + } + } + } + } catch { + case NonFatal(e) => + logWarning("Error calculating stats of compiled class.", e) + } + } + } + /** * A cache of generated classes. * @@ -751,12 +944,14 @@ object CodeGenerator extends Logging { private val cache = CacheBuilder.newBuilder() .maximumSize(100) .build( - new CacheLoader[String, GeneratedClass]() { - override def load(code: String): GeneratedClass = { + new CacheLoader[CodeAndComment, GeneratedClass]() { + override def load(code: CodeAndComment): GeneratedClass = { val startTime = System.nanoTime() val result = doCompile(code) val endTime = System.nanoTime() def timeMs: Double = (endTime - startTime).toDouble / 1000000 + CodegenMetrics.METRIC_SOURCE_CODE_SIZE.update(code.body.length) + CodegenMetrics.METRIC_COMPILATION_TIME.update(timeMs.toLong) logInfo(s"Code generated in $timeMs ms") result } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala index 2bd77c65c31cf..6a5a3e7933eea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression, Nondeterministic} -import org.apache.spark.sql.catalyst.util.toCommentSafeString /** * A trait that can be used to provide a fallback mode for expression code generation. 
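The compilation cache above is now keyed by `CodeAndComment`, whose `equals` and `hashCode` intentionally look only at the source body, so two generation runs that differ only in their registered comments still resolve to the same compiled class. A minimal standalone sketch of that cache-key design (class and method names here are illustrative, not Spark's):

```scala
import scala.collection.mutable

// A cache key that carries auxiliary data (the comment map) but compares only
// on the body, so cosmetic differences never force a recompilation.
final class SourceKey(val body: String, val comments: Map[String, String]) {
  override def equals(other: Any): Boolean = other match {
    case that: SourceKey => that.body == body
    case _ => false
  }
  override def hashCode(): Int = body.hashCode
}

object SourceKeyDemo {
  private val cache = mutable.HashMap.empty[SourceKey, String]

  private def compile(key: SourceKey): String =
    cache.getOrElseUpdate(key, s"<compiled class for ${key.body.length} chars of source>")

  def main(args: Array[String]): Unit = {
    val first = new SourceKey("class Foo {}", Map("c1" -> "/* comment from run 1 */"))
    val second = new SourceKey("class Foo {}", Map("c1" -> "/* different comment, same body */"))
    compile(first)
    compile(second)
    println(cache.size) // 1 -- the differing comments did not defeat the cache
  }
}
```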
@@ -36,9 +35,10 @@ trait CodegenFallback extends Expression { val idx = ctx.references.length ctx.references += this val objectTerm = ctx.freshName("obj") + val placeHolder = ctx.registerComment(this.toString) if (nullable) { ev.copy(code = s""" - /* expression: ${toCommentSafeString(this.toString)} */ + $placeHolder Object $objectTerm = ((Expression) references[$idx]).eval($input); boolean ${ev.isNull} = $objectTerm == null; ${ctx.javaType(this.dataType)} ${ev.value} = ${ctx.defaultValue(this.dataType)}; @@ -47,7 +47,7 @@ trait CodegenFallback extends Expression { }""") } else { ev.copy(code = s""" - /* expression: ${toCommentSafeString(this.toString)} */ + $placeHolder Object $objectTerm = ((Expression) references[$idx]).eval($input); ${ctx.javaType(this.dataType)} ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm; """, isNull = "false") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index f143b40443836..0f82d2e613c73 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -94,7 +94,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP val allProjections = ctx.splitExpressions(ctx.INPUT_ROW, projectionCodes) val allUpdates = ctx.splitExpressions(ctx.INPUT_ROW, updates) - val code = s""" + val codeBody = s""" public java.lang.Object generate(Object[] references) { return new SpecificMutableProjection(references); } @@ -133,6 +133,8 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP } """ + val code = CodeFormatter.stripOverlappingComments( + new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())) logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}") val c = CodeGenerator.compile(code) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index 5635c91830f4a..f4d35d232e691 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -19,6 +19,9 @@ package org.apache.spark.sql.catalyst.expressions.codegen import java.io.ObjectInputStream +import com.esotericsoftware.kryo.{Kryo, KryoSerializable} +import com.esotericsoftware.kryo.io.{Input, Output} + import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -113,7 +116,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR protected def create(ordering: Seq[SortOrder]): BaseOrdering = { val ctx = newCodeGenContext() val comparisons = genComparisons(ctx, ordering) - val code = s""" + val codeBody = s""" public SpecificOrdering generate(Object[] references) { return new SpecificOrdering(references); } @@ -136,7 +139,9 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR } }""" - logDebug(s"Generated Ordering: ${CodeFormatter.format(code)}") + val code = CodeFormatter.stripOverlappingComments( + new 
CodeAndComment(codeBody, ctx.getPlaceHolderToComments())) + logDebug(s"Generated Ordering by ${ordering.mkString(",")}:\n${CodeFormatter.format(code)}") CodeGenerator.compile(code).generate(ctx.references.toArray).asInstanceOf[BaseOrdering] } @@ -145,7 +150,8 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR /** * A lazily generated row ordering comparator. */ -class LazilyGeneratedOrdering(val ordering: Seq[SortOrder]) extends Ordering[InternalRow] { +class LazilyGeneratedOrdering(val ordering: Seq[SortOrder]) + extends Ordering[InternalRow] with KryoSerializable { def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) = this(ordering.map(BindReferences.bindReference(_, inputSchema))) @@ -161,6 +167,14 @@ class LazilyGeneratedOrdering(val ordering: Seq[SortOrder]) extends Ordering[Int in.defaultReadObject() generatedOrdering = GenerateOrdering.generate(ordering) } + + override def write(kryo: Kryo, out: Output): Unit = Utils.tryOrIOException { + kryo.writeObject(out, ordering.toArray) + } + + override def read(kryo: Kryo, in: Input): Unit = Utils.tryOrIOException { + generatedOrdering = GenerateOrdering.generate(kryo.readObject(in, classOf[Array[SortOrder]])) + } } object LazilyGeneratedOrdering { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index dd8e2a289a661..106bb27964cab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -40,7 +40,7 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool protected def create(predicate: Expression): ((InternalRow) => Boolean) = { val ctx = newCodeGenContext() val eval = predicate.genCode(ctx) - val code = s""" + val codeBody = s""" public SpecificPredicate generate(Object[] references) { return new SpecificPredicate(references); } @@ -61,6 +61,8 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool } }""" + val code = CodeFormatter.stripOverlappingComments( + new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())) logDebug(s"Generated predicate '$predicate':\n${CodeFormatter.format(code)}") val p = CodeGenerator.compile(code).generate(ctx.references.toArray).asInstanceOf[Predicate] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index ee1a363145dad..b891f94673752 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -48,7 +48,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val tmp = ctx.freshName("tmp") val output = ctx.freshName("safeRow") val values = ctx.freshName("values") - // These expressions could be splitted into multiple functions + // These expressions could be split into multiple functions ctx.addMutableState("Object[]", values, s"this.$values = null;") val rowClass = classOf[GenericInternalRow].getName @@ -155,7 +155,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], 
Projection] """ } val allExpressions = ctx.splitExpressions(ctx.INPUT_ROW, expressionCodes) - val code = s""" + val codeBody = s""" public java.lang.Object generate(Object[] references) { return new SpecificSafeProjection(references); } @@ -181,6 +181,8 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] } """ + val code = CodeFormatter.stripOverlappingComments( + new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())) logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}") val c = CodeGenerator.compile(code) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 6aa9cbf08bdb9..5efba4b3a6087 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -362,7 +362,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val ctx = newCodeGenContext() val eval = createCode(ctx, expressions, subexpressionEliminationEnabled) - val code = s""" + val codeBody = s""" public java.lang.Object generate(Object[] references) { return new SpecificUnsafeProjection(references); } @@ -390,6 +390,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro } """ + val code = CodeFormatter.stripOverlappingComments( + new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())) logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}") val c = CodeGenerator.compile(code) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala index b1ffbaa3e94ec..4aa5ec82471ec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala @@ -157,7 +157,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U }.mkString("\n") // ------------------------ Finally, put everything together --------------------------- // - val code = s""" + val codeBody = s""" |public java.lang.Object generate(Object[] references) { | return new SpecificUnsafeRowJoiner(); |} @@ -193,7 +193,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U | } |} """.stripMargin - + val code = CodeFormatter.stripOverlappingComments(new CodeAndComment(codeBody, Map.empty)) logDebug(s"SpecificUnsafeRowJoiner($schema1, $schema2):\n${CodeFormatter.format(code)}") val c = CodeGenerator.compile(code) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index c71cb73d65bf6..1efe2cb336b77 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -43,6 +43,54 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType } } +/** + * Returns an 
unordered array containing the keys of the map. + */ +@ExpressionDescription( + usage = "_FUNC_(map) - Returns an unordered array containing the keys of the map.", + extended = " > SELECT _FUNC_(map(1, 'a', 2, 'b'));\n [1,2]") +case class MapKeys(child: Expression) + extends UnaryExpression with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(MapType) + + override def dataType: DataType = ArrayType(child.dataType.asInstanceOf[MapType].keyType) + + override def nullSafeEval(map: Any): Any = { + map.asInstanceOf[MapData].keyArray() + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, c => s"${ev.value} = ($c).keyArray();") + } + + override def prettyName: String = "map_keys" +} + +/** + * Returns an unordered array containing the values of the map. + */ +@ExpressionDescription( + usage = "_FUNC_(map) - Returns an unordered array containing the values of the map.", + extended = " > SELECT _FUNC_(map(1, 'a', 2, 'b'));\n [\"a\",\"b\"]") +case class MapValues(child: Expression) + extends UnaryExpression with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(MapType) + + override def dataType: DataType = ArrayType(child.dataType.asInstanceOf[MapType].valueType) + + override def nullSafeEval(map: Any): Any = { + map.asInstanceOf[MapData].valueArray() + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, c => s"${ev.value} = ($c).valueArray();") + } + + override def prettyName: String = "map_values" +} + /** * Sorts the input array in ascending / descending order according to the natural ordering of * the array elements and returns it. @@ -64,7 +112,13 @@ case class SortArray(base: Expression, ascendingOrder: Expression) override def checkInputDataTypes(): TypeCheckResult = base.dataType match { case ArrayType(dt, _) if RowOrdering.isOrderable(dt) => - TypeCheckResult.TypeCheckSuccess + ascendingOrder match { + case Literal(_: Boolean, BooleanType) => + TypeCheckResult.TypeCheckSuccess + case _ => + TypeCheckResult.TypeCheckFailure( + "Sort order in second argument requires a boolean literal.") + } case ArrayType(dt, _) => TypeCheckResult.TypeCheckFailure( s"$prettyName does not support sorting array of type ${dt.simpleString}") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index d986d9dca6ed4..09e22aaf3e3d8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, TypeUtils} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData, TypeUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -84,14 +84,14 @@ case class CreateArray(children: Seq[Expression]) extends Expression { @ExpressionDescription( usage = "_FUNC_(key0, value0, key1, value1...) 
- Creates a map with the given key/value pairs.") case class CreateMap(children: Seq[Expression]) extends Expression { - private[sql] lazy val keys = children.indices.filter(_ % 2 == 0).map(children) - private[sql] lazy val values = children.indices.filter(_ % 2 != 0).map(children) + lazy val keys = children.indices.filter(_ % 2 == 0).map(children) + lazy val values = children.indices.filter(_ % 2 != 0).map(children) override def foldable: Boolean = children.forall(_.foldable) override def checkInputDataTypes(): TypeCheckResult = { if (children.size % 2 != 0) { - TypeCheckResult.TypeCheckFailure(s"$prettyName expects an positive even number of arguments.") + TypeCheckResult.TypeCheckFailure(s"$prettyName expects a positive even number of arguments.") } else if (keys.map(_.dataType).distinct.length > 1) { TypeCheckResult.TypeCheckFailure("The given keys of function map should all be the same " + "type, but they are " + keys.map(_.dataType.simpleString).mkString("[", ", ", "]")) @@ -252,9 +252,13 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { private lazy val names = nameExprs.map(_.eval(EmptyRow)) override lazy val dataType: StructType = { - val fields = names.zip(valExprs).map { case (name, valExpr) => - StructField(name.asInstanceOf[UTF8String].toString, - valExpr.dataType, valExpr.nullable, Metadata.empty) + val fields = names.zip(valExprs).map { + case (name, valExpr: NamedExpression) => + StructField(name.asInstanceOf[UTF8String].toString, + valExpr.dataType, valExpr.nullable, valExpr.metadata) + case (name, valExpr) => + StructField(name.asInstanceOf[UTF8String].toString, + valExpr.dataType, valExpr.nullable, Metadata.empty) } StructType(fields) } @@ -365,8 +369,11 @@ case class CreateNamedStructUnsafe(children: Seq[Expression]) extends Expression private lazy val names = nameExprs.map(_.eval(EmptyRow).toString) override lazy val dataType: StructType = { - val fields = names.zip(valExprs).map { case (name, valExpr) => - StructField(name, valExpr.dataType, valExpr.nullable, Metadata.empty) + val fields = names.zip(valExprs).map { + case (name, valExpr: NamedExpression) => + StructField(name, valExpr.dataType, valExpr.nullable, valExpr.metadata) + case (name, valExpr) => + StructField(name, valExpr.dataType, valExpr.nullable, Metadata.empty) } StructType(fields) } @@ -386,3 +393,53 @@ case class CreateNamedStructUnsafe(children: Seq[Expression]) extends Expression override def prettyName: String = "named_struct_unsafe" } + +/** + * Creates a map after splitting the input text into key/value pairs using delimiters + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(text[, pairDelim, keyValueDelim]) - Creates a map after splitting the text into key/value pairs using delimiters. 
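As a quick usage sketch for the two new collection functions introduced above (assuming a local `SparkSession` named `spark`; nothing here is part of the patch itself):

```scala
import org.apache.spark.sql.SparkSession

object MapFunctionsDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("map-functions").getOrCreate()

    // map_keys / map_values return unordered arrays of the map's keys and values.
    spark.sql(
      "SELECT map_keys(map(1, 'a', 2, 'b')) AS ks, map_values(map(1, 'a', 2, 'b')) AS vs"
    ).show(truncate = false)
    // ks: [1, 2]   vs: [a, b]

    spark.stop()
  }
}
```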
Default delimiters are ',' for pairDelim and ':' for keyValueDelim.", + extended = """ > SELECT _FUNC_('a:1,b:2,c:3',',',':');\n map("a":"1","b":"2","c":"3") """) +// scalastyle:on line.size.limit +case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: Expression) + extends TernaryExpression with CodegenFallback with ExpectsInputTypes { + + def this(child: Expression, pairDelim: Expression) = { + this(child, pairDelim, Literal(":")) + } + + def this(child: Expression) = { + this(child, Literal(","), Literal(":")) + } + + override def children: Seq[Expression] = Seq(text, pairDelim, keyValueDelim) + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, StringType) + + override def dataType: DataType = MapType(StringType, StringType, valueContainsNull = false) + + override def checkInputDataTypes(): TypeCheckResult = { + if (Seq(pairDelim, keyValueDelim).exists(! _.foldable)) { + TypeCheckResult.TypeCheckFailure(s"$prettyName's delimiters must be foldable.") + } else { + super.checkInputDataTypes() + } + } + + override def nullSafeEval(str: Any, delim1: Any, delim2: Any): Any = { + val array = str.asInstanceOf[UTF8String] + .split(delim1.asInstanceOf[UTF8String], -1) + .map { kv => + val arr = kv.split(delim2.asInstanceOf[UTF8String], 2) + if (arr.length < 2) { + Array(arr(0), null) + } else { + arr + } + } + ArrayBasedMapData(array.map(_ (0)), array.map(_ (1))) + } + + override def prettyName: String = "str_to_map" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 3b4468f55ca73..abb5594bfa7f8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -106,7 +106,7 @@ trait ExtractValue extends Expression case class GetStructField(child: Expression, ordinal: Int, name: Option[String] = None) extends UnaryExpression with ExtractValue { - private[sql] lazy val childSchema = child.dataType.asInstanceOf[StructType] + lazy val childSchema = child.dataType.asInstanceOf[StructType] override def dataType: DataType = childSchema(ordinal).dataType override def nullable: Boolean = child.nullable || childSchema(ordinal).nullable diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index e97e08947a500..f9499cf78569e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -126,7 +126,8 @@ abstract class CaseWhenBase( override def eval(input: InternalRow): Any = { var i = 0 - while (i < branches.size) { + val size = branches.size + while (i < size) { if (java.lang.Boolean.TRUE.equals(branches(i)._1.eval(input))) { return branches(i)._2.eval(input) } @@ -299,7 +300,7 @@ case class Least(children: Seq[Expression]) extends Expression { } else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) { TypeCheckResult.TypeCheckFailure( s"The expressions should all have the same type," + - s" got LEAST (${children.map(_.dataType)}).") + s" got LEAST(${children.map(_.dataType.simpleString).mkString(", ")}).") } else { 
TypeUtils.checkForOrderingExpr(dataType, "function " + prettyName) } @@ -359,7 +360,7 @@ case class Greatest(children: Seq[Expression]) extends Expression { } else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) { TypeCheckResult.TypeCheckFailure( s"The expressions should all have the same type," + - s" got GREATEST (${children.map(_.dataType)}).") + s" got GREATEST(${children.map(_.dataType.simpleString).mkString(", ")}).") } else { TypeUtils.checkForOrderingExpr(dataType, "function " + prettyName) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 69c32f447e867..7ab68a13e09cf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -399,6 +399,8 @@ abstract class UnixTime extends BinaryExpression with ExpectsInputTypes { override def nullable: Boolean = true private lazy val constFormat: UTF8String = right.eval().asInstanceOf[UTF8String] + private lazy val formatter: SimpleDateFormat = + Try(new SimpleDateFormat(constFormat.toString)).getOrElse(null) override def eval(input: InternalRow): Any = { val t = left.eval(input) @@ -411,11 +413,11 @@ abstract class UnixTime extends BinaryExpression with ExpectsInputTypes { case TimestampType => t.asInstanceOf[Long] / 1000000L case StringType if right.foldable => - if (constFormat != null) { - Try(new SimpleDateFormat(constFormat.toString).parse( - t.asInstanceOf[UTF8String].toString).getTime / 1000L).getOrElse(null) - } else { + if (constFormat == null || formatter == null) { null + } else { + Try(formatter.parse( + t.asInstanceOf[UTF8String].toString).getTime / 1000L).getOrElse(null) } case StringType => val f = right.eval(input) @@ -434,13 +436,10 @@ abstract class UnixTime extends BinaryExpression with ExpectsInputTypes { left.dataType match { case StringType if right.foldable => val sdf = classOf[SimpleDateFormat].getName - val fString = if (constFormat == null) null else constFormat.toString - val formatter = ctx.freshName("formatter") - if (fString == null) { - ev.copy(code = s""" - boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};""") + if (formatter == null) { + ExprCode("", "true", ctx.defaultValue(dataType)) } else { + val formatterName = ctx.addReferenceObj("formatter", formatter, sdf) val eval1 = left.genCode(ctx) ev.copy(code = s""" ${eval1.code} @@ -448,10 +447,8 @@ abstract class UnixTime extends BinaryExpression with ExpectsInputTypes { ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { try { - $sdf $formatter = new $sdf("$fString"); - ${ev.value} = - $formatter.parse(${eval1.value}.toString()).getTime() / 1000L; - } catch (java.lang.Throwable e) { + ${ev.value} = $formatterName.parse(${eval1.value}.toString()).getTime() / 1000L; + } catch (java.text.ParseException e) { ${ev.isNull} = true; } }""") @@ -463,7 +460,9 @@ abstract class UnixTime extends BinaryExpression with ExpectsInputTypes { try { ${ev.value} = (new $sdf($format.toString())).parse($string.toString()).getTime() / 1000L; - } catch (java.lang.Throwable e) { + } catch (java.lang.IllegalArgumentException e) { + ${ev.isNull} = true; + } catch (java.text.ParseException e) { ${ev.isNull} = true; } """ @@ -520,6 +519,8 @@ case class FromUnixTime(sec: 
Expression, format: Expression) override def inputTypes: Seq[AbstractDataType] = Seq(LongType, StringType) private lazy val constFormat: UTF8String = right.eval().asInstanceOf[UTF8String] + private lazy val formatter: SimpleDateFormat = + Try(new SimpleDateFormat(constFormat.toString)).getOrElse(null) override def eval(input: InternalRow): Any = { val time = left.eval(input) @@ -527,10 +528,10 @@ case class FromUnixTime(sec: Expression, format: Expression) null } else { if (format.foldable) { - if (constFormat == null) { + if (constFormat == null || formatter == null) { null } else { - Try(UTF8String.fromString(new SimpleDateFormat(constFormat.toString).format( + Try(UTF8String.fromString(formatter.format( new java.util.Date(time.asInstanceOf[Long] * 1000L)))).getOrElse(null) } } else { @@ -549,11 +550,10 @@ case class FromUnixTime(sec: Expression, format: Expression) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val sdf = classOf[SimpleDateFormat].getName if (format.foldable) { - if (constFormat == null) { - ev.copy(code = s""" - boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};""") + if (formatter == null) { + ExprCode("", "true", "(UTF8String) null") } else { + val formatterName = ctx.addReferenceObj("formatter", formatter, sdf) val t = left.genCode(ctx) ev.copy(code = s""" ${t.code} @@ -561,9 +561,9 @@ case class FromUnixTime(sec: Expression, format: Expression) ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { try { - ${ev.value} = UTF8String.fromString(new $sdf("${constFormat.toString}").format( + ${ev.value} = UTF8String.fromString($formatterName.format( new java.util.Date(${t.value} * 1000L))); - } catch (java.lang.Throwable e) { + } catch (java.lang.IllegalArgumentException e) { ${ev.isNull} = true; } }""") @@ -574,7 +574,7 @@ case class FromUnixTime(sec: Expression, format: Expression) try { ${ev.value} = UTF8String.fromString((new $sdf($f.toString())).format( new java.util.Date($seconds * 1000L))); - } catch (java.lang.Throwable e) { + } catch (java.lang.IllegalArgumentException e) { ${ev.isNull} = true; }""".stripMargin }) @@ -682,6 +682,7 @@ case class TimeAdd(start: Expression, interval: Expression) override def right: Expression = interval override def toString: String = s"$left + $right" + override def sql: String = s"${left.sql} + ${right.sql}" override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, CalendarIntervalType) override def dataType: DataType = TimestampType @@ -730,16 +731,17 @@ case class FromUTCTimestamp(left: Expression, right: Expression) """.stripMargin) } else { val tzTerm = ctx.freshName("tz") + val utcTerm = ctx.freshName("utc") val tzClass = classOf[TimeZone].getName ctx.addMutableState(tzClass, tzTerm, s"""$tzTerm = $tzClass.getTimeZone("$tz");""") + ctx.addMutableState(tzClass, utcTerm, s"""$utcTerm = $tzClass.getTimeZone("UTC");""") val eval = left.genCode(ctx) ev.copy(code = s""" |${eval.code} |boolean ${ev.isNull} = ${eval.isNull}; |long ${ev.value} = 0; |if (!${ev.isNull}) { - | ${ev.value} = ${eval.value} + - | ${tzTerm}.getOffset(${eval.value} / 1000) * 1000L; + | ${ev.value} = $dtu.convertTz(${eval.value}, $utcTerm, $tzTerm); |} """.stripMargin) } @@ -761,6 +763,7 @@ case class TimeSub(start: Expression, interval: Expression) override def right: Expression = interval override def toString: String = s"$left - $right" + override def sql: String = s"${left.sql} - ${right.sql}" override def inputTypes: 
Seq[AbstractDataType] = Seq(TimestampType, CalendarIntervalType) override def dataType: DataType = TimestampType @@ -869,16 +872,17 @@ case class ToUTCTimestamp(left: Expression, right: Expression) """.stripMargin) } else { val tzTerm = ctx.freshName("tz") + val utcTerm = ctx.freshName("utc") val tzClass = classOf[TimeZone].getName ctx.addMutableState(tzClass, tzTerm, s"""$tzTerm = $tzClass.getTimeZone("$tz");""") + ctx.addMutableState(tzClass, utcTerm, s"""$utcTerm = $tzClass.getTimeZone("UTC");""") val eval = left.genCode(ctx) ev.copy(code = s""" |${eval.code} |boolean ${ev.isNull} = ${eval.isNull}; |long ${ev.value} = 0; |if (!${ev.isNull}) { - | ${ev.value} = ${eval.value} - - | ${tzTerm}.getOffset(${eval.value} / 1000) * 1000L; + | ${ev.value} = $dtu.convertTz(${eval.value}, $tzTerm, $utcTerm); |} """.stripMargin) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 65d7a1d5a0904..9d5c856a23e2a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -41,19 +41,16 @@ import org.apache.spark.sql.types._ */ trait Generator extends Expression { - // TODO ideally we should return the type of ArrayType(StructType), - // however, we don't keep the output field names in the Generator. - override def dataType: DataType = throw new UnsupportedOperationException + override def dataType: DataType = ArrayType(elementSchema) override def foldable: Boolean = false override def nullable: Boolean = false /** - * The output element data types in structure of Seq[(DataType, Nullable)] - * TODO we probably need to add more information like metadata etc. + * The output element schema. */ - def elementTypes: Seq[(DataType, Boolean, String)] + def elementSchema: StructType /** Should be implemented by child classes to perform specific Generators. */ override def eval(input: InternalRow): TraversableOnce[InternalRow] @@ -69,7 +66,7 @@ trait Generator extends Expression { * A generator that produces its output using the provided lambda function. */ case class UserDefinedGenerator( - elementTypes: Seq[(DataType, Boolean, String)], + elementSchema: StructType, function: Row => TraversableOnce[InternalRow], children: Seq[Expression]) extends Generator with CodegenFallback { @@ -97,13 +94,63 @@ case class UserDefinedGenerator( } /** - * Given an input array produces a sequence of rows for each value in the array. + * Separate v1, ..., vk into n rows. Each row will have k/n columns. n must be constant. 
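The row/column bookkeeping behind `stack` (n rows, `ceil(k / n)` columns, nulls padding the last row) can be illustrated with a tiny standalone sketch that uses plain Scala collections in place of Catalyst rows (names are illustrative only):

```scala
object StackLayoutSketch {
  // Lay out k values into numRows rows of ceil(k / numRows) columns,
  // padding the final row with nulls -- the same arithmetic Stack.eval performs.
  def layout(numRows: Int, values: Seq[Any]): Seq[Seq[Any]] = {
    val numFields = math.ceil(values.length.toDouble / numRows).toInt
    (0 until numRows).map { row =>
      (0 until numFields).map { col =>
        val index = row * numFields + col
        if (index < values.length) values(index) else null
      }
    }
  }

  def main(args: Array[String]): Unit = {
    // stack(2, 1, 2, 3) => two rows: [1, 2] and [3, null]
    println(layout(2, Seq(1, 2, 3)))
  }
}
```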
+ * {{{ + * SELECT stack(2, 1, 2, 3) -> + * 1 2 + * 3 NULL + * }}} */ -// scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(a) - Separates the elements of array a into multiple rows, or the elements of a map into multiple rows and columns.") -// scalastyle:on line.size.limit -case class Explode(child: Expression) extends UnaryExpression with Generator with CodegenFallback { + usage = "_FUNC_(n, v1, ..., vk) - Separate v1, ..., vk into n rows.", + extended = "> SELECT _FUNC_(2, 1, 2, 3);\n [1,2]\n [3,null]") +case class Stack(children: Seq[Expression]) + extends Expression with Generator with CodegenFallback { + + private lazy val numRows = children.head.eval().asInstanceOf[Int] + private lazy val numFields = Math.ceil((children.length - 1.0) / numRows).toInt + + override def checkInputDataTypes(): TypeCheckResult = { + if (children.length <= 1) { + TypeCheckResult.TypeCheckFailure(s"$prettyName requires at least 2 arguments.") + } else if (children.head.dataType != IntegerType || !children.head.foldable || numRows < 1) { + TypeCheckResult.TypeCheckFailure("The number of rows must be a positive constant integer.") + } else { + for (i <- 1 until children.length) { + val j = (i - 1) % numFields + if (children(i).dataType != elementSchema.fields(j).dataType) { + return TypeCheckResult.TypeCheckFailure( + s"Argument ${j + 1} (${elementSchema.fields(j).dataType}) != " + + s"Argument $i (${children(i).dataType})") + } + } + TypeCheckResult.TypeCheckSuccess + } + } + + override def elementSchema: StructType = + StructType(children.tail.take(numFields).zipWithIndex.map { + case (e, index) => StructField(s"col$index", e.dataType) + }) + + override def eval(input: InternalRow): TraversableOnce[InternalRow] = { + val values = children.tail.map(_.eval(input)).toArray + for (row <- 0 until numRows) yield { + val fields = new Array[Any](numFields) + for (col <- 0 until numFields) { + val index = row * numFields + col + fields.update(col, if (index < values.length) values(index) else null) + } + InternalRow(fields: _*) + } + } +} + +/** + * A base class for Explode and PosExplode + */ +abstract class ExplodeBase(child: Expression, position: Boolean) + extends UnaryExpression with Generator with CodegenFallback with Serializable { override def children: Seq[Expression] = child :: Nil @@ -117,10 +164,27 @@ case class Explode(child: Expression) extends UnaryExpression with Generator wit } // hive-compatible default alias for explode function ("col" for array, "key", "value" for map) - override def elementTypes: Seq[(DataType, Boolean, String)] = child.dataType match { - case ArrayType(et, containsNull) => (et, containsNull, "col") :: Nil + override def elementSchema: StructType = child.dataType match { + case ArrayType(et, containsNull) => + if (position) { + new StructType() + .add("pos", IntegerType, false) + .add("col", et, containsNull) + } else { + new StructType() + .add("col", et, containsNull) + } case MapType(kt, vt, valueContainsNull) => - (kt, false, "key") :: (vt, valueContainsNull, "value") :: Nil + if (position) { + new StructType() + .add("pos", IntegerType, false) + .add("key", kt, false) + .add("value", vt, valueContainsNull) + } else { + new StructType() + .add("key", kt, false) + .add("value", vt, valueContainsNull) + } } override def eval(input: InternalRow): TraversableOnce[InternalRow] = { @@ -132,7 +196,7 @@ case class Explode(child: Expression) extends UnaryExpression with Generator wit } else { val rows = new Array[InternalRow](inputArray.numElements()) 
inputArray.foreach(et, (i, e) => { - rows(i) = InternalRow(e) + rows(i) = if (position) InternalRow(i, e) else InternalRow(e) }) rows } @@ -144,7 +208,7 @@ case class Explode(child: Expression) extends UnaryExpression with Generator wit val rows = new Array[InternalRow](inputMap.numElements()) var i = 0 inputMap.foreach(kt, vt, (k, v) => { - rows(i) = InternalRow(k, v) + rows(i) = if (position) InternalRow(i, k, v) else InternalRow(k, v) i += 1 }) rows @@ -152,3 +216,70 @@ case class Explode(child: Expression) extends UnaryExpression with Generator wit } } } + +/** + * Given an input array produces a sequence of rows for each value in the array. + * + * {{{ + * SELECT explode(array(10,20)) -> + * 10 + * 20 + * }}} + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(a) - Separates the elements of array a into multiple rows, or the elements of map a into multiple rows and columns.", + extended = "> SELECT _FUNC_(array(10,20));\n 10\n 20") +// scalastyle:on line.size.limit +case class Explode(child: Expression) extends ExplodeBase(child, position = false) + +/** + * Given an input array produces a sequence of rows for each position and value in the array. + * + * {{{ + * SELECT posexplode(array(10,20)) -> + * 0 10 + * 1 20 + * }}} + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(a) - Separates the elements of array a into multiple rows with positions, or the elements of a map into multiple rows and columns with positions.", + extended = "> SELECT _FUNC_(array(10,20));\n 0\t10\n 1\t20") +// scalastyle:on line.size.limit +case class PosExplode(child: Expression) extends ExplodeBase(child, position = true) + +/** + * Explodes an array of structs into a table. + */ +@ExpressionDescription( + usage = "_FUNC_(a) - Explodes an array of structs into a table.", + extended = "> SELECT _FUNC_(array(struct(1, 'a'), struct(2, 'b')));\n [1,a]\n [2,b]") +case class Inline(child: Expression) extends UnaryExpression with Generator with CodegenFallback { + + override def children: Seq[Expression] = child :: Nil + + override def checkInputDataTypes(): TypeCheckResult = child.dataType match { + case ArrayType(et, _) if et.isInstanceOf[StructType] => + TypeCheckResult.TypeCheckSuccess + case _ => + TypeCheckResult.TypeCheckFailure( + s"input to function $prettyName should be array of struct type, not ${child.dataType}") + } + + override def elementSchema: StructType = child.dataType match { + case ArrayType(et : StructType, _) => et + } + + private lazy val numFields = elementSchema.fields.length + + override def eval(input: InternalRow): TraversableOnce[InternalRow] = { + val inputArray = child.eval(input).asInstanceOf[ArrayData] + if (inputArray == null) { + Nil + } else { + for (i <- 0 until inputArray.numElements()) + yield inputArray.getStruct(i, numFields) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index ecd09b7083f2e..c14a2fb122618 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -26,7 +26,7 @@ import com.fasterxml.jackson.core._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback -import 
org.apache.spark.sql.types.{DataType, StringType} +import org.apache.spark.sql.types.{DataType, StringType, StructField, StructType} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -356,9 +356,9 @@ case class JsonTuple(children: Seq[Expression]) // and count the number of foldable fields, we'll use this later to optimize evaluation @transient private lazy val constantFields: Int = foldableFieldNames.count(_ != null) - override def elementTypes: Seq[(DataType, Boolean, String)] = fieldExpressions.zipWithIndex.map { - case (_, idx) => (StringType, true, s"c$idx") - } + override def elementSchema: StructType = StructType(fieldExpressions.zipWithIndex.map { + case (_, idx) => StructField(s"c$idx", StringType, nullable = true) + }) override def prettyName: String = "json_tuple" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 7e3683e482df1..41e3952f0e253 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -163,8 +163,7 @@ object DecimalLiteral { /** * In order to do type checking, use Literal.create() instead of constructor */ -case class Literal protected (value: Any, dataType: DataType) - extends LeafExpression with CodegenFallback { +case class Literal (value: Any, dataType: DataType) extends LeafExpression with CodegenFallback { override def foldable: Boolean = true override def nullable: Boolean = value == null @@ -182,7 +181,7 @@ case class Literal protected (value: Any, dataType: DataType) override protected def jsonFields: List[JField] = { // Turns all kinds of literal values to string in json field, as the type info is hard to - // retain in json format, e.g. {"a": 123} can be a int, or double, or decimal, etc. + // retain in json format, e.g. {"a": 123} can be an int, or double, or decimal, etc. val jsonValue = (value, dataType) match { case (null, _) => JNull case (i: Int, DateType) => JString(DateTimeUtils.toJavaDate(i).toString) @@ -246,15 +245,28 @@ case class Literal protected (value: Any, dataType: DataType) case (_, NullType | _: ArrayType | _: MapType | _: StructType) if value == null => "NULL" case _ if value == null => s"CAST(NULL AS ${dataType.sql})" case (v: UTF8String, StringType) => - // Escapes all backslashes and double quotes. - "\"" + v.toString.replace("\\", "\\\\").replace("\"", "\\\"") + "\"" + // Escapes all backslashes and single quotes. 
+ "'" + v.toString.replace("\\", "\\\\").replace("'", "\\'") + "'" case (v: Byte, ByteType) => v + "Y" case (v: Short, ShortType) => v + "S" case (v: Long, LongType) => v + "L" // Float type doesn't have a suffix - case (v: Float, FloatType) => s"CAST($v AS ${FloatType.sql})" - case (v: Double, DoubleType) => v + "D" - case (v: Decimal, t: DecimalType) => s"CAST($v AS ${t.sql})" + case (v: Float, FloatType) => + val castedValue = v match { + case _ if v.isNaN => "'NaN'" + case Float.PositiveInfinity => "'Infinity'" + case Float.NegativeInfinity => "'-Infinity'" + case _ => v + } + s"CAST($castedValue AS ${FloatType.sql})" + case (v: Double, DoubleType) => + v match { + case _ if v.isNaN => s"CAST('NaN' AS ${DoubleType.sql})" + case Double.PositiveInfinity => s"CAST('Infinity' AS ${DoubleType.sql})" + case Double.NegativeInfinity => s"CAST('-Infinity' AS ${DoubleType.sql})" + case _ => v + "D" + } + case (v: Decimal, t: DecimalType) => v + "BD" case (v: Int, DateType) => s"DATE '${DateTimeUtils.toJavaDate(v)}'" case (v: Long, TimestampType) => s"TIMESTAMP('${DateTimeUtils.toJavaTimestamp(v)}')" case _ => value.toString diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 1c0787bf9227f..5c4436ffaab72 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -175,11 +175,12 @@ case class Crc32(child: Expression) extends UnaryExpression with ImplicitCastInp override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val CRC32 = "java.util.zip.CRC32" + val checksum = ctx.freshName("checksum") nullSafeCodeGen(ctx, ev, value => { s""" - $CRC32 checksum = new $CRC32(); - checksum.update($value, 0, $value.length); - ${ev.value} = checksum.getValue(); + $CRC32 $checksum = new $CRC32(); + $checksum.update($value, 0, $value.length); + ${ev.value} = $checksum.getValue(); """ }) } @@ -476,10 +477,13 @@ case class PrintToStderr(child: Expression) extends UnaryExpression { protected override def nullSafeEval(input: Any): Any = input + private val outputPrefix = s"Result of ${child.simpleString} is " + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val outputPrefixField = ctx.addReferenceObj("outputPrefix", outputPrefix) nullSafeCodeGen(ctx, ev, c => s""" - | System.err.println("Result of ${child.simpleString} is " + $c); + | System.err.println($outputPrefixField + $c); | ${ev.value} = $c; """.stripMargin) } @@ -500,10 +504,12 @@ case class AssertTrue(child: Expression) extends UnaryExpression with ImplicitCa override def prettyName: String = "assert_true" + private val errMsg = s"'${child.simpleString}' is not true!" 
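The `Literal.sql` changes shown earlier render non-finite floating-point values as explicit casts, since a bare `NaN` or `Infinity` token would not parse back as a SQL literal. A minimal standalone sketch of that rendering rule for doubles (a hypothetical helper, not the actual `Literal` code):

```scala
object DoubleLiteralSql {
  // Mirror of the special-casing for Double literals: finite values get a "D"
  // suffix, non-finite values are rendered as a CAST from a string literal.
  def sql(v: Double): String = v match {
    case _ if v.isNaN            => "CAST('NaN' AS DOUBLE)"
    case Double.PositiveInfinity => "CAST('Infinity' AS DOUBLE)"
    case Double.NegativeInfinity => "CAST('-Infinity' AS DOUBLE)"
    case _                       => v.toString + "D"
  }

  def main(args: Array[String]): Unit = {
    Seq(1.5, Double.NaN, Double.PositiveInfinity, Double.NegativeInfinity).foreach { d =>
      println(s"$d -> ${sql(d)}")
    }
  }
}
```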
+ override def eval(input: InternalRow) : Any = { val v = child.eval(input) if (v == null || java.lang.Boolean.FALSE.equals(v)) { - throw new RuntimeException(s"'${child.simpleString}' is not true!") + throw new RuntimeException(errMsg) } else { null } @@ -511,9 +517,10 @@ case class AssertTrue(child: Expression) extends UnaryExpression with ImplicitCa override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) + val errMsgField = ctx.addReferenceObj("errMsg", errMsg) ExprCode(code = s"""${eval.code} |if (${eval.isNull} || !${eval.value}) { - | throw new RuntimeException("'${child.simpleString}' is not true."); + | throw new RuntimeException($errMsgField); |}""".stripMargin, isNull = "true", value = "null") } @@ -553,8 +560,9 @@ object XxHash64Function extends InterpretedHashFunction { @ExpressionDescription( usage = "_FUNC_() - Returns the current database.", extended = "> SELECT _FUNC_()") -private[sql] case class CurrentDatabase() extends LeafExpression with Unevaluable { +case class CurrentDatabase() extends LeafExpression with Unevaluable { override def dataType: DataType = StringType override def foldable: Boolean = true override def nullable: Boolean = false + override def prettyName: String = "current_database" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 421200e147b7a..1c18265e0fed4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -88,6 +88,82 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } +@ExpressionDescription(usage = "_FUNC_(a,b) - Returns b if a is null, or a otherwise.") +case class IfNull(left: Expression, right: Expression) extends RuntimeReplaceable { + override def children: Seq[Expression] = Seq(left, right) + + override def replaceForEvaluation(): Expression = Coalesce(Seq(left, right)) + + override def replaceForTypeCoercion(): Expression = { + if (left.dataType != right.dataType) { + TypeCoercion.findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { dtype => + copy(left = Cast(left, dtype), right = Cast(right, dtype)) + }.getOrElse(this) + } else { + this + } + } +} + + +@ExpressionDescription(usage = "_FUNC_(a,b) - Returns null if a equals to b, or a otherwise.") +case class NullIf(left: Expression, right: Expression) extends RuntimeReplaceable { + override def children: Seq[Expression] = Seq(left, right) + + override def replaceForEvaluation(): Expression = { + If(EqualTo(left, right), Literal.create(null, left.dataType), left) + } + + override def replaceForTypeCoercion(): Expression = { + if (left.dataType != right.dataType) { + TypeCoercion.findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { dtype => + copy(left = Cast(left, dtype), right = Cast(right, dtype)) + }.getOrElse(this) + } else { + this + } + } +} + + 
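`IfNull` and `NullIf` above, together with `Nvl` and `Nvl2` defined just below, are thin `RuntimeReplaceable` wrappers that rewrite themselves into `Coalesce`/`If` before evaluation. Assuming a local `SparkSession`, they surface in SQL roughly like this (usage sketch only):

```scala
import org.apache.spark.sql.SparkSession

object NullFunctionsDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("null-functions").getOrCreate()

    // ifnull(a, b) -> b when a is null;   nullif(a, b)   -> null when a equals b
    // nvl(a, b)    -> b when a is null;   nvl2(a, b, c)  -> b when a is not null, else c
    spark.sql(
      "SELECT ifnull(NULL, 2) AS i, nullif(1, 1) AS n, nvl(NULL, 'x') AS v, nvl2(1, 'a', 'b') AS v2"
    ).show()

    spark.stop()
  }
}
```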
+@ExpressionDescription(usage = "_FUNC_(a,b) - Returns b if a is null, or a otherwise.") +case class Nvl(left: Expression, right: Expression) extends RuntimeReplaceable { + override def children: Seq[Expression] = Seq(left, right) + + override def replaceForEvaluation(): Expression = Coalesce(Seq(left, right)) + + override def replaceForTypeCoercion(): Expression = { + if (left.dataType != right.dataType) { + TypeCoercion.findTightestCommonTypeToString(left.dataType, right.dataType).map { dtype => + copy(left = Cast(left, dtype), right = Cast(right, dtype)) + }.getOrElse(this) + } else { + this + } + } +} + + +@ExpressionDescription(usage = "_FUNC_(a,b,c) - Returns b if a is not null, or c otherwise.") +case class Nvl2(expr1: Expression, expr2: Expression, expr3: Expression) + extends RuntimeReplaceable { + + override def replaceForEvaluation(): Expression = If(IsNotNull(expr1), expr2, expr3) + + override def children: Seq[Expression] = Seq(expr1, expr2, expr3) + + override def replaceForTypeCoercion(): Expression = { + if (expr2.dataType != expr3.dataType) { + TypeCoercion.findTightestCommonTypeOfTwo(expr2.dataType, expr3.dataType).map { dtype => + copy(expr2 = Cast(expr2, dtype), expr3 = Cast(expr3, dtype)) + }.getOrElse(this) + } else { + this + } + } +} + + /** * Evaluates to `true` iff it's NaN. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala similarity index 59% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 1e418540a2624..691edd5c2e7be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -15,18 +15,19 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.catalyst.expressions +package org.apache.spark.sql.catalyst.expressions.objects import java.lang.reflect.Modifier -import scala.annotation.tailrec import scala.language.existentials import scala.reflect.ClassTag -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.serializer._ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.RowEncoder +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.types._ @@ -64,33 +65,29 @@ case class StaticInvoke( val argGen = arguments.map(_.genCode(ctx)) val argString = argGen.map(_.value).mkString(", ") - if (propagateNull) { - val objNullCheck = if (ctx.defaultValue(dataType) == "null") { - s"${ev.isNull} = ${ev.value} == null;" - } else { - "" - } - - val argsNonNull = s"!(${argGen.map(_.isNull).mkString(" || ")})" - ev.copy(code = s""" - ${argGen.map(_.code).mkString("\n")} - - boolean ${ev.isNull} = !$argsNonNull; - $javaType ${ev.value} = ${ctx.defaultValue(dataType)}; + val callFunc = s"$objectName.$functionName($argString)" - if ($argsNonNull) { - ${ev.value} = $objectName.$functionName($argString); - $objNullCheck - } - """) + val setIsNull = if (propagateNull && arguments.nonEmpty) { + s"boolean ${ev.isNull} = ${argGen.map(_.isNull).mkString(" || ")};" } else { - ev.copy(code = s""" - ${argGen.map(_.code).mkString("\n")} + s"boolean ${ev.isNull} = false;" + } - $javaType ${ev.value} = $objectName.$functionName($argString); - final boolean ${ev.isNull} = ${ev.value} == null; - """) + // If the function can return null, we do an extra check to make sure our null bit is still set + // correctly. + val postNullCheck = if (ctx.defaultValue(dataType) == "null") { + s"${ev.isNull} = ${ev.value} == null;" + } else { + "" } + + val code = s""" + ${argGen.map(_.code).mkString("\n")} + $setIsNull + final $javaType ${ev.value} = ${ev.isNull} ? 
${ctx.defaultValue(dataType)} : $callFunc; + $postNullCheck + """ + ev.copy(code = code) } } @@ -111,7 +108,8 @@ case class Invoke( targetObject: Expression, functionName: String, dataType: DataType, - arguments: Seq[Expression] = Nil) extends Expression with NonSQLExpression { + arguments: Seq[Expression] = Nil, + propagateNull: Boolean = true) extends Expression with NonSQLExpression { override def nullable: Boolean = true override def children: Seq[Expression] = targetObject +: arguments @@ -130,60 +128,53 @@ case class Invoke( case _ => None } - lazy val unboxer = (dataType, method.map(_.getReturnType.getName).getOrElse("")) match { - case (IntegerType, "java.lang.Object") => (s: String) => - s"((java.lang.Integer)$s).intValue()" - case (LongType, "java.lang.Object") => (s: String) => - s"((java.lang.Long)$s).longValue()" - case (FloatType, "java.lang.Object") => (s: String) => - s"((java.lang.Float)$s).floatValue()" - case (ShortType, "java.lang.Object") => (s: String) => - s"((java.lang.Short)$s).shortValue()" - case (ByteType, "java.lang.Object") => (s: String) => - s"((java.lang.Byte)$s).byteValue()" - case (DoubleType, "java.lang.Object") => (s: String) => - s"((java.lang.Double)$s).doubleValue()" - case (BooleanType, "java.lang.Object") => (s: String) => - s"((java.lang.Boolean)$s).booleanValue()" - case _ => identity[String] _ - } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaType = ctx.javaType(dataType) val obj = targetObject.genCode(ctx) val argGen = arguments.map(_.genCode(ctx)) val argString = argGen.map(_.value).mkString(", ") - // If the function can return null, we do an extra check to make sure our null bit is still set - // correctly. - val objNullCheck = if (ctx.defaultValue(dataType) == "null") { - s"boolean ${ev.isNull} = ${ev.value} == null;" + val callFunc = if (method.isDefined && method.get.getReturnType.isPrimitive) { + s"${obj.value}.$functionName($argString)" } else { - ev.isNull = obj.isNull - "" + s"(${ctx.boxedType(javaType)}) ${obj.value}.$functionName($argString)" } - val value = unboxer(s"${obj.value}.$functionName($argString)") + val setIsNull = if (propagateNull && arguments.nonEmpty) { + s"boolean ${ev.isNull} = ${obj.isNull} || ${argGen.map(_.isNull).mkString(" || ")};" + } else { + s"boolean ${ev.isNull} = ${obj.isNull};" + } val evaluate = if (method.forall(_.getExceptionTypes.isEmpty)) { - s"$javaType ${ev.value} = ${obj.isNull} ? ${ctx.defaultValue(dataType)} : ($javaType) $value;" + s"final $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : $callFunc;" } else { s""" $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; try { - ${ev.value} = ${obj.isNull} ? ${ctx.defaultValue(javaType)} : ($javaType) $value; + ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(javaType)} : $callFunc; } catch (Exception e) { org.apache.spark.unsafe.Platform.throwException(e); } """ } - ev.copy(code = s""" + // If the function can return null, we do an extra check to make sure our null bit is still set + // correctly. 
+ val postNullCheck = if (ctx.defaultValue(dataType) == "null") { + s"${ev.isNull} = ${ev.value} == null;" + } else { + "" + } + + val code = s""" ${obj.code} ${argGen.map(_.code).mkString("\n")} + $setIsNull $evaluate - $objNullCheck - """) + $postNullCheck + """ + ev.copy(code = code) } override def toString: String = s"$targetObject.$functionName" @@ -241,44 +232,52 @@ case class NewInstance( override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaType = ctx.javaType(dataType) - val argGen = arguments.map(_.genCode(ctx)) - val argString = argGen.map(_.value).mkString(", ") + val argIsNulls = ctx.freshName("argIsNulls") + ctx.addMutableState("boolean[]", argIsNulls, + s"$argIsNulls = new boolean[${arguments.size}];") + val argValues = arguments.zipWithIndex.map { case (e, i) => + val argValue = ctx.freshName("argValue") + ctx.addMutableState(ctx.javaType(e.dataType), argValue, "") + argValue + } + + val argCodes = arguments.zipWithIndex.map { case (e, i) => + val expr = e.genCode(ctx) + expr.code + s""" + $argIsNulls[$i] = ${expr.isNull}; + ${argValues(i)} = ${expr.value}; + """ + } + val argCode = ctx.splitExpressions(ctx.INPUT_ROW, argCodes) val outer = outerPointer.map(func => Literal.fromObject(func()).genCode(ctx)) - val setup = + var isNull = ev.isNull + val setIsNull = if (propagateNull && arguments.nonEmpty) { s""" - ${argGen.map(_.code).mkString("\n")} - ${outer.map(_.code).getOrElse("")} - """.stripMargin + boolean $isNull = false; + for (int idx = 0; idx < ${arguments.length}; idx++) { + if ($argIsNulls[idx]) { $isNull = true; break; } + } + """ + } else { + isNull = "false" + "" + } val constructorCall = outer.map { gen => - s"""${gen.value}.new ${cls.getSimpleName}($argString)""" + s"""${gen.value}.new ${cls.getSimpleName}(${argValues.mkString(", ")})""" }.getOrElse { - s"new $className($argString)" + s"new $className(${argValues.mkString(", ")})" } - if (propagateNull && argGen.nonEmpty) { - val argsNonNull = s"!(${argGen.map(_.isNull).mkString(" || ")})" - - ev.copy(code = s""" - $setup - - boolean ${ev.isNull} = true; - $javaType ${ev.value} = ${ctx.defaultValue(dataType)}; - if ($argsNonNull) { - ${ev.value} = $constructorCall; - ${ev.isNull} = false; - } - """) - } else { - ev.copy(code = s""" - $setup - - final $javaType ${ev.value} = $constructorCall; - final boolean ${ev.isNull} = false; - """) - } + val code = s""" + $argCode + ${outer.map(_.code).getOrElse("")} + $setIsNull + final $javaType ${ev.value} = $isNull ? ${ctx.defaultValue(javaType)} : $constructorCall; + """ + ev.copy(code = code, isNull = isNull) } override def toString: String = s"newInstance($cls)" @@ -306,13 +305,14 @@ case class UnwrapOption( val javaType = ctx.javaType(dataType) val inputObject = child.genCode(ctx) - ev.copy(code = s""" + val code = s""" ${inputObject.code} - boolean ${ev.isNull} = ${inputObject.value} == null || ${inputObject.value}.isEmpty(); - $javaType ${ev.value} = - ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($javaType)${inputObject.value}.get(); - """) + final boolean ${ev.isNull} = ${inputObject.isNull} || ${inputObject.value}.isEmpty(); + $javaType ${ev.value} = ${ev.isNull} ? 
+ ${ctx.defaultValue(javaType)} : (${ctx.boxedType(javaType)}) ${inputObject.value}.get(); + """ + ev.copy(code = code) } }
@@ -338,14 +338,14 @@ case class WrapOption(child: Expression, optType: DataType) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val inputObject = child.genCode(ctx) - ev.copy(code = s""" + val code = s""" ${inputObject.code} - boolean ${ev.isNull} = false; scala.Option ${ev.value} = ${inputObject.isNull} ? scala.Option$$.MODULE$$.apply(null) : new scala.Some(${inputObject.value}); - """) + """ + ev.copy(code = code, isNull = "false") } }
@@ -366,6 +366,13 @@ case class LambdaVariable(value: String, isNull: String, dataType: DataType) ext object MapObjects { private val curId = new java.util.concurrent.atomic.AtomicInteger() + /** + * Constructs an instance of the MapObjects case class. + * + * @param function The function applied on the collection elements. + * @param inputData An expression that when evaluated returns a collection object. + * @param elementType The data type of elements in the collection. + */ def apply( function: Expression => Expression, inputData: Expression,
@@ -373,7 +380,7 @@ object MapObjects { val loopValue = "MapObjects_loopValue" + curId.getAndIncrement() val loopIsNull = "MapObjects_loopIsNull" + curId.getAndIncrement() val loopVar = LambdaVariable(loopValue, loopIsNull, elementType) - MapObjects(loopVar, function(loopVar), inputData) + MapObjects(loopValue, loopIsNull, elementType, function(loopVar), inputData) } }
@@ -385,56 +392,23 @@ object MapObjects { * The following collection ObjectTypes are currently supported: * Seq, Array, ArrayData, java.util.List * - * @param loopVar A place holder that used as the loop variable when iterate the collection, and - * used as input for the `lambdaFunction`. It also carries the element type info. + * @param loopValue the name of the loop variable that is used when iterating over the collection, + * and that is used as input for the `lambdaFunction` + * @param loopIsNull the nullability of the loop variable that is used when iterating over the + * collection, and that is used as input for the `lambdaFunction` + * @param loopVarDataType the data type of the loop variable that is used when iterating over the + * collection, and that is used as input for the `lambdaFunction` * @param lambdaFunction A function that take the `loopVar` as input, and used as lambda function * to handle collection elements. * @param inputData An expression that when evaluated returns a collection object. 
*/ case class MapObjects private( - loopVar: LambdaVariable, + loopValue: String, + loopIsNull: String, + loopVarDataType: DataType, lambdaFunction: Expression, inputData: Expression) extends Expression with NonSQLExpression { - @tailrec - private def itemAccessorMethod(dataType: DataType): String => String = dataType match { - case NullType => - val nullTypeClassName = NullType.getClass.getName + ".MODULE$" - (i: String) => s".get($i, $nullTypeClassName)" - case IntegerType => (i: String) => s".getInt($i)" - case LongType => (i: String) => s".getLong($i)" - case FloatType => (i: String) => s".getFloat($i)" - case DoubleType => (i: String) => s".getDouble($i)" - case ByteType => (i: String) => s".getByte($i)" - case ShortType => (i: String) => s".getShort($i)" - case BooleanType => (i: String) => s".getBoolean($i)" - case StringType => (i: String) => s".getUTF8String($i)" - case s: StructType => (i: String) => s".getStruct($i, ${s.size})" - case a: ArrayType => (i: String) => s".getArray($i)" - case _: MapType => (i: String) => s".getMap($i)" - case udt: UserDefinedType[_] => itemAccessorMethod(udt.sqlType) - case DecimalType.Fixed(p, s) => (i: String) => s".getDecimal($i, $p, $s)" - case DateType => (i: String) => s".getInt($i)" - } - - private lazy val (lengthFunction, itemAccessor, primitiveElement) = inputData.dataType match { - case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) => - (".size()", (i: String) => s".apply($i)", false) - case ObjectType(cls) if cls.isArray => - (".length", (i: String) => s"[$i]", false) - case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) => - (".size()", (i: String) => s".get($i)", false) - case ArrayType(t, _) => - val (sqlType, primitiveElement) = t match { - case m: MapType => (m, false) - case s: StructType => (s, false) - case s: StringType => (s, false) - case udt: UserDefinedType[_] => (udt.sqlType, false) - case o => (o, true) - } - (".numElements()", itemAccessorMethod(sqlType), primitiveElement) - } - override def nullable: Boolean = true override def children: Seq[Expression] = lambdaFunction :: inputData :: Nil @@ -445,10 +419,9 @@ case class MapObjects private( override def dataType: DataType = ArrayType(lambdaFunction.dataType) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val javaType = ctx.javaType(dataType) - val elementJavaType = ctx.javaType(loopVar.dataType) - ctx.addMutableState("boolean", loopVar.isNull, "") - ctx.addMutableState(elementJavaType, loopVar.value, "") + val elementJavaType = ctx.javaType(loopVarDataType) + ctx.addMutableState("boolean", loopIsNull, "") + ctx.addMutableState(elementJavaType, loopValue, "") val genInputData = inputData.genCode(ctx) val genFunction = lambdaFunction.genCode(ctx) val dataLength = ctx.freshName("dataLength") @@ -468,43 +441,93 @@ case class MapObjects private( s"new $convertedType[$dataLength]" } - val loopNullCheck = if (primitiveElement) { - s"${loopVar.isNull} = ${genInputData.value}.isNullAt($loopIndex);" - } else { - s"${loopVar.isNull} = ${genInputData.isNull} || ${loopVar.value} == null;" + // In RowEncoder, we use `Object` to represent Array or Seq, so we need to determine the type + // of input collection at runtime for this case. 
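+    // Illustrative shape of the generated dispatch for an Object-typed input (variable names are
+    // hypothetical; ctx.freshName picks the real ones):
+    //   scala.collection.Seq seq1 = null;
+    //   UTF8String[] array1 = null;
+    //   if (inputValue.getClass().isArray()) {
+    //     array1 = (UTF8String[]) inputValue;
+    //   } else {
+    //     seq1 = (scala.collection.Seq) inputValue;
+    //   }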
+ val seq = ctx.freshName("seq") + val array = ctx.freshName("array") + val determineCollectionType = inputData.dataType match { + case ObjectType(cls) if cls == classOf[Object] => + val seqClass = classOf[Seq[_]].getName + s""" + $seqClass $seq = null; + $elementJavaType[] $array = null; + if (${genInputData.value}.getClass().isArray()) { + $array = ($elementJavaType[]) ${genInputData.value}; + } else { + $seq = ($seqClass) ${genInputData.value}; + } + """ + case _ => "" } - ev.copy(code = s""" - ${genInputData.code} + // The data with PythonUserDefinedType are actually stored with the data type of its sqlType. + // When we want to apply MapObjects on it, we have to use it. + val inputDataType = inputData.dataType match { + case p: PythonUserDefinedType => p.sqlType + case _ => inputData.dataType + } - boolean ${ev.isNull} = ${genInputData.value} == null; - $javaType ${ev.value} = ${ctx.defaultValue(dataType)}; + val (getLength, getLoopVar) = inputDataType match { + case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) => + s"${genInputData.value}.size()" -> s"${genInputData.value}.apply($loopIndex)" + case ObjectType(cls) if cls.isArray => + s"${genInputData.value}.length" -> s"${genInputData.value}[$loopIndex]" + case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) => + s"${genInputData.value}.size()" -> s"${genInputData.value}.get($loopIndex)" + case ArrayType(et, _) => + s"${genInputData.value}.numElements()" -> ctx.getValue(genInputData.value, et, loopIndex) + case ObjectType(cls) if cls == classOf[Object] => + s"$seq == null ? $array.length : $seq.size()" -> + s"$seq == null ? $array[$loopIndex] : $seq.apply($loopIndex)" + } + + // Make a copy of the data if it's unsafe-backed + def makeCopyIfInstanceOf(clazz: Class[_ <: Any], value: String) = + s"$value instanceof ${clazz.getSimpleName}? ${value}.copy() : $value" + val genFunctionValue = lambdaFunction.dataType match { + case StructType(_) => makeCopyIfInstanceOf(classOf[UnsafeRow], genFunction.value) + case ArrayType(_, _) => makeCopyIfInstanceOf(classOf[UnsafeArrayData], genFunction.value) + case MapType(_, _, _) => makeCopyIfInstanceOf(classOf[UnsafeMapData], genFunction.value) + case _ => genFunction.value + } + + val loopNullCheck = inputDataType match { + case _: ArrayType => s"$loopIsNull = ${genInputData.value}.isNullAt($loopIndex);" + // The element of primitive array will never be null. 
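+      // e.g. for an int[] input the generated loop body reduces to (illustrative):
+      //   loopValue = (int) (input[i]); loopIsNull = false;
+      // so no per-element null check is needed for primitive arrays.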
+ case ObjectType(cls) if cls.isArray && cls.getComponentType.isPrimitive => + s"$loopIsNull = false" + case _ => s"$loopIsNull = $loopValue == null;" + } + + val code = s""" + ${genInputData.code} + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { + if (!${genInputData.isNull}) { + $determineCollectionType $convertedType[] $convertedArray = null; - int $dataLength = ${genInputData.value}$lengthFunction; + int $dataLength = $getLength; $convertedArray = $arrayConstructor; int $loopIndex = 0; while ($loopIndex < $dataLength) { - ${loopVar.value} = - ($elementJavaType)${genInputData.value}${itemAccessor(loopIndex)}; + $loopValue = ($elementJavaType) ($getLoopVar); $loopNullCheck ${genFunction.code} if (${genFunction.isNull}) { $convertedArray[$loopIndex] = null; } else { - $convertedArray[$loopIndex] = ${genFunction.value}; + $convertedArray[$loopIndex] = $genFunctionValue; } $loopIndex += 1; } - ${ev.isNull} = false; ${ev.value} = new ${classOf[GenericArrayData].getName}($convertedArray); } - """) + """ + ev.copy(code = code, isNull = genInputData.isNull) } } @@ -539,14 +562,16 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType) } """ } + val childrenCode = ctx.splitExpressions(ctx.INPUT_ROW, childrenCodes) val schemaField = ctx.addReferenceObj("schema", schema) - ev.copy(code = s""" - boolean ${ev.isNull} = false; + + val code = s""" $values = new Object[${children.size}]; $childrenCode - final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, this.$schemaField); - """) + final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, $schemaField); + """ + ev.copy(code = code, isNull = "false") } } @@ -571,22 +596,28 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean) (classOf[JavaSerializer].getName, classOf[JavaSerializerInstance].getName) } } + // try conf from env, otherwise create a new one + val env = s"${classOf[SparkEnv].getName}.get()" val sparkConf = s"new ${classOf[SparkConf].getName}()" - ctx.addMutableState( - serializerInstanceClass, - serializer, - s"$serializer = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance();") + val serializerInit = s""" + if ($env == null) { + $serializer = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance(); + } else { + $serializer = ($serializerInstanceClass) new $serializerClass($env.conf()).newInstance(); + } + """ + ctx.addMutableState(serializerInstanceClass, serializer, serializerInit) // Code to serialize. val input = child.genCode(ctx) - ev.copy(code = s""" + val javaType = ctx.javaType(dataType) + val serialize = s"$serializer.serialize(${input.value}, null).array()" + + val code = s""" ${input.code} - final boolean ${ev.isNull} = ${input.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${ev.value} = $serializer.serialize(${input.value}, null).array(); - } - """) + final $javaType ${ev.value} = ${input.isNull} ? 
${ctx.defaultValue(javaType)} : $serialize; + """ + ev.copy(code = code, isNull = input.isNull) } override def dataType: DataType = BinaryType @@ -611,23 +642,29 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B (classOf[JavaSerializer].getName, classOf[JavaSerializerInstance].getName) } } + // try conf from env, otherwise create a new one + val env = s"${classOf[SparkEnv].getName}.get()" val sparkConf = s"new ${classOf[SparkConf].getName}()" - ctx.addMutableState( - serializerInstanceClass, - serializer, - s"$serializer = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance();") - - // Code to serialize. + val serializerInit = s""" + if ($env == null) { + $serializer = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance(); + } else { + $serializer = ($serializerInstanceClass) new $serializerClass($env.conf()).newInstance(); + } + """ + ctx.addMutableState(serializerInstanceClass, serializer, serializerInit) + + // Code to deserialize. val input = child.genCode(ctx) - ev.copy(code = s""" + val javaType = ctx.javaType(dataType) + val deserialize = + s"($javaType) $serializer.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null)" + + val code = s""" ${input.code} - final boolean ${ev.isNull} = ${input.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${ev.value} = (${ctx.javaType(dataType)}) - $serializer.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null); - } - """) + final $javaType ${ev.value} = ${input.isNull} ? ${ctx.defaultValue(javaType)} : $deserialize; + """ + ev.copy(code = code, isNull = input.isNull) } override def dataType: DataType = ObjectType(tag.runtimeClass) @@ -658,15 +695,13 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp """ } - ev.isNull = instanceGen.isNull - ev.value = instanceGen.value - - ev.copy(code = s""" + val code = s""" ${instanceGen.code} if (!${instanceGen.isNull}) { ${initialize.mkString("\n")} } - """) + """ + ev.copy(code = code, isNull = instanceGen.isNull, value = instanceGen.value) } } @@ -682,7 +717,7 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String]) extends UnaryExpression with NonSQLExpression { override def dataType: DataType = child.dataType - + override def foldable: Boolean = false override def nullable: Boolean = false override def eval(input: InternalRow): Any = @@ -696,13 +731,105 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String]) "If the schema is inferred from a Scala tuple/case class, or a Java bean, " + "please try to use scala.Option[_] or other nullable types " + "(e.g. java.lang.Integer instead of int/scala.Int)." - val idx = ctx.references.length - ctx.references += errMsg - ExprCode(code = s""" + val errMsgField = ctx.addReferenceObj("errMsg", errMsg) + + val code = s""" ${childGen.code} if (${childGen.isNull}) { - throw new RuntimeException((String) references[$idx]); - }""", isNull = "false", value = childGen.value) + throw new RuntimeException($errMsgField); + } + """ + ev.copy(code = code, isNull = "false", value = childGen.value) + } +} + +/** + * Returns the value of field at index `index` from the external row `child`. + * This class can be viewed as [[GetStructField]] for [[Row]]s instead of [[InternalRow]]s. + * + * Note that the input row and the field we try to get are both guaranteed to be not null, if they + * are null, a runtime exception will be thrown. 
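+ *
+ * For example (illustrative), given an external row with schema (a int, b string),
+ * GetExternalRowField(input, 0, "a") evaluates to row.get(0), and throws a RuntimeException
+ * if the row itself is null or if row.isNullAt(0) is true.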
+ */ +case class GetExternalRowField( + child: Expression, + index: Int, + fieldName: String) extends UnaryExpression with NonSQLExpression { + + override def nullable: Boolean = false + + override def dataType: DataType = ObjectType(classOf[Object]) + + override def eval(input: InternalRow): Any = + throw new UnsupportedOperationException("Only code-generated evaluation is supported") + + private val errMsg = s"The ${index}th field '$fieldName' of input row cannot be null." + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val errMsgField = ctx.addReferenceObj("errMsg", errMsg) + val row = child.genCode(ctx) + val code = s""" + ${row.code} + + if (${row.isNull}) { + throw new RuntimeException("The input external row cannot be null."); + } + + if (${row.value}.isNullAt($index)) { + throw new RuntimeException($errMsgField); + } + + final Object ${ev.value} = ${row.value}.get($index); + """ + ev.copy(code = code, isNull = "false") + } +} + +/** + * Validates the actual data type of input expression at runtime. If it doesn't match the + * expectation, throw an exception. + */ +case class ValidateExternalType(child: Expression, expected: DataType) + extends UnaryExpression with NonSQLExpression with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(ObjectType(classOf[Object])) + + override def nullable: Boolean = child.nullable + + override def dataType: DataType = RowEncoder.externalDataTypeForInput(expected) + + override def eval(input: InternalRow): Any = + throw new UnsupportedOperationException("Only code-generated evaluation is supported") + + private val errMsg = s" is not a valid external type for schema of ${expected.simpleString}" + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val errMsgField = ctx.addReferenceObj("errMsg", errMsg) + val input = child.genCode(ctx) + val obj = input.value + + val typeCheck = expected match { + case _: DecimalType => + Seq(classOf[java.math.BigDecimal], classOf[scala.math.BigDecimal], classOf[Decimal]) + .map(cls => s"$obj instanceof ${cls.getName}").mkString(" || ") + case _: ArrayType => + s"$obj instanceof ${classOf[Seq[_]].getName} || $obj.getClass().isArray()" + case _ => + s"$obj instanceof ${ctx.boxedType(dataType)}" + } + + val code = s""" + ${input.code} + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + if (!${input.isNull}) { + if ($typeCheck) { + ${ev.value} = (${ctx.boxedType(dataType)}) $obj; + } else { + throw new RuntimeException($obj.getClass().getName() + $errMsgField); + } + } + + """ + ev.copy(code = code, isNull = input.isNull) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala index 6112259fed619..9a892905f5186 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala @@ -31,7 +31,8 @@ class InterpretedOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow def compare(a: InternalRow, b: InternalRow): Int = { var i = 0 - while (i < ordering.size) { + val size = ordering.size + while (i < size) { val order = ordering(i) val left = order.child.eval(a) val right = order.child.eval(b) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index 23baa6f7837fb..a6125c61e508a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst +import com.google.common.collect.Maps + import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.{StructField, StructType} @@ -86,10 +88,40 @@ package object expressions { /** * Helper functions for working with `Seq[Attribute]`. */ - implicit class AttributeSeq(attrs: Seq[Attribute]) { + implicit class AttributeSeq(val attrs: Seq[Attribute]) extends Serializable { /** Creates a StructType with a schema matching this `Seq[Attribute]`. */ def toStructType: StructType = { - StructType(attrs.map(a => StructField(a.name, a.dataType, a.nullable))) + StructType(attrs.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))) + } + + // It's possible that `attrs` is a linked list, which can lead to bad O(n^2) loops when + // accessing attributes by their ordinals. To avoid this performance penalty, convert the input + // to an array. + @transient private lazy val attrsArray = attrs.toArray + + @transient private lazy val exprIdToOrdinal = { + val arr = attrsArray + val map = Maps.newHashMapWithExpectedSize[ExprId, Int](arr.length) + // Iterate over the array in reverse order so that the final map value is the first attribute + // with a given expression id. + var index = arr.length - 1 + while (index >= 0) { + map.put(arr(index).exprId, index) + index -= 1 + } + map + } + + /** + * Returns the attribute at the given index. + */ + def apply(ordinal: Int): Attribute = attrsArray(ordinal) + + /** + * Returns the index of first attribute with a matching expression id, or -1 if no match exists. + */ + def indexOf(exprId: ExprId): Int = { + Option(exprIdToOrdinal.get(exprId)).getOrElse(-1) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 8a6cf53782b91..abe0f08e135d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -69,8 +69,11 @@ trait PredicateHelper { protected def replaceAlias( condition: Expression, aliases: AttributeMap[Expression]): Expression = { - condition.transform { - case a: Attribute => aliases.getOrElse(a, a) + // Use transformUp to prevent infinite recursion when the replacement expression + // redefines the same ExprId, + condition.transformUp { + case a: Attribute => + aliases.getOrElse(a, a) } } @@ -81,8 +84,9 @@ trait PredicateHelper { * * For example consider a join between two relations R(a, b) and S(c, d). * - * `canEvaluate(EqualTo(a,b), R)` returns `true` where as `canEvaluate(EqualTo(a,c), R)` returns - * `false`. 
+ * - `canEvaluate(EqualTo(a,b), R)` returns `true` + * - `canEvaluate(EqualTo(a,c), R)` returns `false` + * - `canEvaluate(Literal(1), R)` returns `true` as literals CAN be evaluated on any plan */ protected def canEvaluate(expr: Expression, plan: LogicalPlan): Boolean = expr.references.subsetOf(plan.outputSet) @@ -390,13 +394,13 @@ abstract class BinaryComparison extends BinaryOperator with Predicate { } -private[sql] object BinaryComparison { +object BinaryComparison { def unapply(e: BinaryComparison): Option[(Expression, Expression)] = Some((e.left, e.right)) } /** An extractor that matches both standard 3VL equality and null-safe equality. */ -private[sql] object Equality { +object Equality { def unapply(e: BinaryComparison): Option[(Expression, Expression)] = e match { case EqualTo(l, r) => Some((l, r)) case EqualNullSafe(l, r) => Some((l, r)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 541b8601a344b..d25da3fd587b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -108,10 +108,11 @@ case class Like(left: Expression, right: Expression) """) } } else { + val rightStr = ctx.freshName("rightStr") nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s""" - String rightStr = ${eval2}.toString(); - ${patternClass} $pattern = ${patternClass}.compile($escapeFunc(rightStr)); + String $rightStr = ${eval2}.toString(); + ${patternClass} $pattern = ${patternClass}.compile($escapeFunc($rightStr)); ${ev.value} = $pattern.matcher(${eval1}.toString()).matches(); """ }) @@ -157,10 +158,11 @@ case class RLike(left: Expression, right: Expression) """) } } else { + val rightStr = ctx.freshName("rightStr") nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s""" - String rightStr = ${eval2}.toString(); - ${patternClass} $pattern = ${patternClass}.compile(rightStr); + String $rightStr = ${eval2}.toString(); + ${patternClass} $pattern = ${patternClass}.compile($rightStr); ${ev.value} = $pattern.matcher(${eval1}.toString()).find(0); """ }) @@ -259,6 +261,8 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio val classNamePattern = classOf[Pattern].getCanonicalName val classNameStringBuffer = classOf[java.lang.StringBuffer].getCanonicalName + val matcher = ctx.freshName("matcher") + ctx.addMutableState("UTF8String", termLastRegex, s"${termLastRegex} = null;") ctx.addMutableState(classNamePattern, termPattern, s"${termPattern} = null;") ctx.addMutableState("String", termLastReplacement, s"${termLastReplacement} = null;") @@ -267,6 +271,12 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio ctx.addMutableState(classNameStringBuffer, termResult, s"${termResult} = new $classNameStringBuffer();") + val setEvNotNull = if (nullable) { + s"${ev.isNull} = false;" + } else { + "" + } + nullSafeCodeGen(ctx, ev, (subject, regexp, rep) => { s""" if (!$regexp.equals(${termLastRegex})) { @@ -280,14 +290,14 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio ${termLastReplacement} = ${termLastReplacementInUTF8}.toString(); } ${termResult}.delete(0, ${termResult}.length()); - java.util.regex.Matcher m = ${termPattern}.matcher($subject.toString()); + java.util.regex.Matcher ${matcher} = ${termPattern}.matcher($subject.toString()); 
- while (m.find()) { - m.appendReplacement(${termResult}, ${termLastReplacement}); + while (${matcher}.find()) { + ${matcher}.appendReplacement(${termResult}, ${termLastReplacement}); } - m.appendTail(${termResult}); + ${matcher}.appendTail(${termResult}); ${ev.value} = UTF8String.fromString(${termResult}.toString()); - ${ev.isNull} = false; + $setEvNotNull """ }) } @@ -319,7 +329,12 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio val m = pattern.matcher(s.toString) if (m.find) { val mr: MatchResult = m.toMatchResult - UTF8String.fromString(mr.group(r.asInstanceOf[Int])) + val group = mr.group(r.asInstanceOf[Int]) + if (group == null) { // Pattern matched, but not optional group + UTF8String.EMPTY_UTF8 + } else { + UTF8String.fromString(group) + } } else { UTF8String.EMPTY_UTF8 } @@ -334,10 +349,18 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio val termLastRegex = ctx.freshName("lastRegex") val termPattern = ctx.freshName("pattern") val classNamePattern = classOf[Pattern].getCanonicalName + val matcher = ctx.freshName("matcher") + val matchResult = ctx.freshName("matchResult") ctx.addMutableState("UTF8String", termLastRegex, s"${termLastRegex} = null;") ctx.addMutableState(classNamePattern, termPattern, s"${termPattern} = null;") + val setEvNotNull = if (nullable) { + s"${ev.isNull} = false;" + } else { + "" + } + nullSafeCodeGen(ctx, ev, (subject, regexp, idx) => { s""" if (!$regexp.equals(${termLastRegex})) { @@ -345,15 +368,19 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio ${termLastRegex} = $regexp.clone(); ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString()); } - java.util.regex.Matcher m = + java.util.regex.Matcher ${matcher} = ${termPattern}.matcher($subject.toString()); - if (m.find()) { - java.util.regex.MatchResult mr = m.toMatchResult(); - ${ev.value} = UTF8String.fromString(mr.group($idx)); - ${ev.isNull} = false; + if (${matcher}.find()) { + java.util.regex.MatchResult ${matchResult} = ${matcher}.toMatchResult(); + if (${matchResult}.group($idx) == null) { + ${ev.value} = UTF8String.EMPTY_UTF8; + } else { + ${ev.value} = UTF8String.fromString(${matchResult}.group($idx)); + } + $setEvNotNull } else { ${ev.value} = UTF8String.EMPTY_UTF8; - ${ev.isNull} = false; + $setEvNotNull }""" }) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 93a8278528697..73dceb35ac50e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -214,11 +214,11 @@ class GenericRowWithSchema(values: Array[Any], override val schema: StructType) } /** - * A internal row implementation that uses an array of objects as the underlying storage. + * An internal row implementation that uses an array of objects as the underlying storage. * Note that, while the array is not copied, and thus could technically be mutated after creation, * this is not allowed. */ -class GenericInternalRow(private[sql] val values: Array[Any]) extends BaseGenericInternalRow { +class GenericInternalRow(val values: Array[Any]) extends BaseGenericInternalRow { /** No-arg constructor for serialization. 
*/ protected def this() = this(null) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 78e846d3f580e..61549c9a23685 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -17,12 +17,17 @@ package org.apache.spark.sql.catalyst.expressions -import java.text.{DecimalFormat, DecimalFormatSymbols} +import java.net.{MalformedURLException, URL} +import java.text.{BreakIterator, DecimalFormat, DecimalFormatSymbols} import java.util.{HashMap, Locale, Map => JMap} +import java.util.regex.Pattern + +import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{ByteArray, UTF8String} @@ -162,6 +167,46 @@ case class ConcatWs(children: Seq[Expression]) } } +@ExpressionDescription( + usage = "_FUNC_(n, str1, str2, ...) - returns the n-th string, e.g. returns str2 when n is 2", + extended = "> SELECT _FUNC_(1, 'scala', 'java') FROM src LIMIT 1;\n" + "'scala'") +case class Elt(children: Seq[Expression]) + extends Expression with ImplicitCastInputTypes with CodegenFallback { + + private lazy val indexExpr = children.head + private lazy val stringExprs = children.tail.toArray + + /** This expression is always nullable because it returns null if index is out of range. 
*/ + override def nullable: Boolean = true + + override def dataType: DataType = StringType + + override def inputTypes: Seq[DataType] = IntegerType +: Seq.fill(children.size - 1)(StringType) + + override def checkInputDataTypes(): TypeCheckResult = { + if (children.size < 2) { + TypeCheckResult.TypeCheckFailure("elt function requires at least two arguments") + } else { + super[ImplicitCastInputTypes].checkInputDataTypes() + } + } + + override def eval(input: InternalRow): Any = { + val indexObj = indexExpr.eval(input) + if (indexObj == null) { + null + } else { + val index = indexObj.asInstanceOf[Int] + if (index <= 0 || index > stringExprs.length) { + null + } else { + stringExprs(index - 1).eval(input) + } + } + } +} + + trait String2StringExpression extends ImplicitCastInputTypes { self: UnaryExpression => @@ -494,7 +539,7 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression) extends TernaryExpression with ImplicitCastInputTypes { def this(substr: Expression, str: Expression) = { - this(substr, str, Literal(0)) + this(substr, str, Literal(1)) } override def children: Seq[Expression] = substr :: str :: start :: Nil @@ -516,9 +561,14 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression) if (l == null) { null } else { - l.asInstanceOf[UTF8String].indexOf( - r.asInstanceOf[UTF8String], - s.asInstanceOf[Int]) + 1 + val sVal = s.asInstanceOf[Int] + if (sVal < 1) { + 0 + } else { + l.asInstanceOf[UTF8String].indexOf( + r.asInstanceOf[UTF8String], + s.asInstanceOf[Int] - 1) + 1 + } } } } @@ -537,8 +587,10 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression) if (!${substrGen.isNull}) { ${strGen.code} if (!${strGen.isNull}) { - ${ev.value} = ${strGen.value}.indexOf(${substrGen.value}, - ${startGen.value}) + 1; + if (${startGen.value} > 0) { + ${ev.value} = ${strGen.value}.indexOf(${substrGen.value}, + ${startGen.value} - 1) + 1; + } } else { ${ev.isNull} = true; } @@ -604,6 +656,154 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression) override def prettyName: String = "rpad" } +object ParseUrl { + private val HOST = UTF8String.fromString("HOST") + private val PATH = UTF8String.fromString("PATH") + private val QUERY = UTF8String.fromString("QUERY") + private val REF = UTF8String.fromString("REF") + private val PROTOCOL = UTF8String.fromString("PROTOCOL") + private val FILE = UTF8String.fromString("FILE") + private val AUTHORITY = UTF8String.fromString("AUTHORITY") + private val USERINFO = UTF8String.fromString("USERINFO") + private val REGEXPREFIX = "(&|^)" + private val REGEXSUBFIX = "=([^&]*)" +} + +/** + * Extracts a part from a URL + */ +@ExpressionDescription( + usage = "_FUNC_(url, partToExtract[, key]) - extracts a part from a URL", + extended = """Parts: HOST, PATH, QUERY, REF, PROTOCOL, AUTHORITY, FILE, USERINFO. + Key specifies which query to extract. 
+ Examples: + > SELECT _FUNC_('http://spark.apache.org/path?query=1', 'HOST') + 'spark.apache.org' + > SELECT _FUNC_('http://spark.apache.org/path?query=1', 'QUERY') + 'query=1' + > SELECT _FUNC_('http://spark.apache.org/path?query=1', 'QUERY', 'query') + '1'""") +case class ParseUrl(children: Seq[Expression]) + extends Expression with ExpectsInputTypes with CodegenFallback { + + override def nullable: Boolean = true + override def inputTypes: Seq[DataType] = Seq.fill(children.size)(StringType) + override def dataType: DataType = StringType + override def prettyName: String = "parse_url" + + // If the url is a constant, cache the URL object so that we don't need to convert url + // from UTF8String to String to URL for every row. + @transient private lazy val cachedUrl = children(0) match { + case Literal(url: UTF8String, _) if url ne null => getUrl(url) + case _ => null + } + + // If the key is a constant, cache the Pattern object so that we don't need to convert key + // from UTF8String to String to StringBuilder to String to Pattern for every row. + @transient private lazy val cachedPattern = children(2) match { + case Literal(key: UTF8String, _) if key ne null => getPattern(key) + case _ => null + } + + // If the partToExtract is a constant, cache the Extract part function so that we don't need + // to check the partToExtract for every row. + @transient private lazy val cachedExtractPartFunc = children(1) match { + case Literal(part: UTF8String, _) => getExtractPartFunc(part) + case _ => null + } + + import ParseUrl._ + + override def checkInputDataTypes(): TypeCheckResult = { + if (children.size > 3 || children.size < 2) { + TypeCheckResult.TypeCheckFailure(s"$prettyName function requires two or three arguments") + } else { + super[ExpectsInputTypes].checkInputDataTypes() + } + } + + private def getPattern(key: UTF8String): Pattern = { + Pattern.compile(REGEXPREFIX + key.toString + REGEXSUBFIX) + } + + private def getUrl(url: UTF8String): URL = { + try { + new URL(url.toString) + } catch { + case e: MalformedURLException => null + } + } + + private def getExtractPartFunc(partToExtract: UTF8String): URL => String = { + partToExtract match { + case HOST => _.getHost + case PATH => _.getPath + case QUERY => _.getQuery + case REF => _.getRef + case PROTOCOL => _.getProtocol + case FILE => _.getFile + case AUTHORITY => _.getAuthority + case USERINFO => _.getUserInfo + case _ => (url: URL) => null + } + } + + private def extractValueFromQuery(query: UTF8String, pattern: Pattern): UTF8String = { + val m = pattern.matcher(query.toString) + if (m.find()) { + UTF8String.fromString(m.group(2)) + } else { + null + } + } + + private def extractFromUrl(url: URL, partToExtract: UTF8String): UTF8String = { + if (cachedExtractPartFunc ne null) { + UTF8String.fromString(cachedExtractPartFunc.apply(url)) + } else { + UTF8String.fromString(getExtractPartFunc(partToExtract).apply(url)) + } + } + + private def parseUrlWithoutKey(url: UTF8String, partToExtract: UTF8String): UTF8String = { + if (cachedUrl ne null) { + extractFromUrl(cachedUrl, partToExtract) + } else { + val currentUrl = getUrl(url) + if (currentUrl ne null) { + extractFromUrl(currentUrl, partToExtract) + } else { + null + } + } + } + + override def eval(input: InternalRow): Any = { + val evaluated = children.map{e => e.eval(input).asInstanceOf[UTF8String]} + if (evaluated.contains(null)) return null + if (evaluated.size == 2) { + parseUrlWithoutKey(evaluated(0), evaluated(1)) + } else { + // 3-arg, i.e. 
QUERY with key + assert(evaluated.size == 3) + if (evaluated(1) != QUERY) { + return null + } + + val query = parseUrlWithoutKey(evaluated(0), evaluated(1)) + if (query eq null) { + return null + } + + if (cachedPattern ne null) { + extractValueFromQuery(query, cachedPattern) + } else { + extractValueFromQuery(query, getPattern(evaluated(2))) + } + } + } +} + /** * Returns the input formatted according do printf-style format strings */ @@ -1140,3 +1340,65 @@ case class FormatNumber(x: Expression, d: Expression) override def prettyName: String = "format_number" } + +/** + * Splits a string into arrays of sentences, where each sentence is an array of words. + * The 'lang' and 'country' arguments are optional, and if omitted, the default locale is used. + */ +@ExpressionDescription( + usage = "_FUNC_(str[, lang, country]) - Splits str into an array of array of words.", + extended = "> SELECT _FUNC_('Hi there! Good morning.');\n [['Hi','there'], ['Good','morning']]") +case class Sentences( + str: Expression, + language: Expression = Literal(""), + country: Expression = Literal("")) + extends Expression with ImplicitCastInputTypes with CodegenFallback { + + def this(str: Expression) = this(str, Literal(""), Literal("")) + def this(str: Expression, language: Expression) = this(str, language, Literal("")) + + override def nullable: Boolean = true + override def dataType: DataType = + ArrayType(ArrayType(StringType, containsNull = false), containsNull = false) + override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, StringType) + override def children: Seq[Expression] = str :: language :: country :: Nil + + override def eval(input: InternalRow): Any = { + val string = str.eval(input) + if (string == null) { + null + } else { + val languageStr = language.eval(input).asInstanceOf[UTF8String] + val countryStr = country.eval(input).asInstanceOf[UTF8String] + val locale = if (languageStr != null && countryStr != null) { + new Locale(languageStr.toString, countryStr.toString) + } else { + Locale.getDefault + } + getSentences(string.asInstanceOf[UTF8String].toString, locale) + } + } + + private def getSentences(sentences: String, locale: Locale) = { + val bi = BreakIterator.getSentenceInstance(locale) + bi.setText(sentences) + var idx = 0 + val result = new ArrayBuffer[GenericArrayData] + while (bi.next != BreakIterator.DONE) { + val sentence = sentences.substring(idx, bi.current) + idx = bi.current + + val wi = BreakIterator.getWordInstance(locale) + var widx = 0 + wi.setText(sentence) + val words = new ArrayBuffer[UTF8String] + while (wi.next != BreakIterator.DONE) { + val word = sentence.substring(widx, wi.current) + widx = wi.current + if (Character.isLetterOrDigit(word.charAt(0))) words += UTF8String.fromString(word) + } + result += new GenericArrayData(words) + } + new GenericArrayData(result) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index cd6d3a00b7cfb..08cb6c0134e3a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -44,6 +44,15 @@ abstract class SubqueryExpression extends Expression { protected def conditionString: String = children.mkString("[", " && ", "]") } +object SubqueryExpression { + def hasCorrelatedSubquery(e: Expression): Boolean = { + e.find { + case e: SubqueryExpression if 
e.children.nonEmpty => true + case _ => false + }.isDefined + } +} + /** * A subquery that will return only one row and one column. This will be converted into a physical * scalar subquery during planning. @@ -55,28 +64,26 @@ case class ScalarSubquery( children: Seq[Expression] = Seq.empty, exprId: ExprId = NamedExpression.newExprId) extends SubqueryExpression with Unevaluable { - - override def plan: LogicalPlan = SubqueryAlias(toString, query) - override lazy val resolved: Boolean = childrenResolved && query.resolved - - override def dataType: DataType = query.schema.fields.head.dataType - - override def checkInputDataTypes(): TypeCheckResult = { - if (query.schema.length != 1) { - TypeCheckResult.TypeCheckFailure("Scalar subquery must return only one column, but got " + - query.schema.length.toString) - } else { - TypeCheckResult.TypeCheckSuccess - } + override lazy val references: AttributeSet = { + if (query.resolved) super.references -- query.outputSet + else super.references } - + override def dataType: DataType = query.schema.fields.head.dataType override def foldable: Boolean = false override def nullable: Boolean = true - + override def plan: LogicalPlan = SubqueryAlias(toString, query) override def withNewPlan(plan: LogicalPlan): ScalarSubquery = copy(query = plan) + override def toString: String = s"scalar-subquery#${exprId.id} $conditionString" +} - override def toString: String = s"subquery#${exprId.id} $conditionString" +object ScalarSubquery { + def hasCorrelatedScalarSubquery(e: Expression): Boolean = { + e.find { + case e: ScalarSubquery if e.children.nonEmpty => true + case _ => false + }.isDefined + } } /** @@ -92,7 +99,7 @@ case class PredicateSubquery( extends SubqueryExpression with Predicate with Unevaluable { override lazy val resolved = childrenResolved && query.resolved override lazy val references: AttributeSet = super.references -- query.outputSet - override def nullable: Boolean = false + override def nullable: Boolean = nullAware override def plan: LogicalPlan = SubqueryAlias(toString, query) override def withNewPlan(plan: LogicalPlan): PredicateSubquery = copy(query = plan) override def toString: String = s"predicate-subquery#${exprId.id} $conditionString" @@ -105,11 +112,24 @@ object PredicateSubquery { case _ => false }.isDefined } + + /** + * Returns whether there are any null-aware predicate subqueries inside Not. If not, we could + * turn the null-aware predicate into not-null-aware predicate. + */ + def hasNullAwarePredicateWithinNot(e: Expression): Boolean = { + e.find{ x => + x.isInstanceOf[Not] && e.find { + case p: PredicateSubquery => p.nullAware + case _ => false + }.isDefined + }.isDefined + } } /** * A [[ListQuery]] expression defines the query which we want to search in an IN subquery - * expression. It should and can only be used in conjunction with a IN expression. + * expression. It should and can only be used in conjunction with an IN expression. 
* * For example (SQL): * {{{ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index c0b453dccf5e9..6806591f68bc1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -82,16 +82,16 @@ case class WindowSpecDefinition( val partition = if (partitionSpec.isEmpty) { "" } else { - "PARTITION BY " + partitionSpec.map(_.sql).mkString(", ") + "PARTITION BY " + partitionSpec.map(_.sql).mkString(", ") + " " } val order = if (orderSpec.isEmpty) { "" } else { - "ORDER BY " + orderSpec.map(_.sql).mkString(", ") + "ORDER BY " + orderSpec.map(_.sql).mkString(", ") + " " } - s"($partition $order ${frameSpecification.toString})" + s"($partition$order${frameSpecification.toString})" } } @@ -321,8 +321,7 @@ abstract class OffsetWindowFunction val input: Expression /** - * Default result value for the function when the input expression returns NULL. The default will - * evaluated against the current row instead of the offset row. + * Default result value for the function when the 'offset'th row does not exist. */ val default: Expression @@ -348,7 +347,7 @@ abstract class OffsetWindowFunction */ override def foldable: Boolean = false - override def nullable: Boolean = default == null || default.nullable + override def nullable: Boolean = default == null || default.nullable || input.nullable override lazy val frame = { // This will be triggered by the Analyzer. @@ -373,20 +372,22 @@ abstract class OffsetWindowFunction } /** - * The Lead function returns the value of 'x' at 'offset' rows after the current row in the window. - * Offsets start at 0, which is the current row. The offset must be constant integer value. The - * default offset is 1. When the value of 'x' is null at the offset, or when the offset is larger - * than the window, the default expression is evaluated. - * - * This documentation has been based upon similar documentation for the Hive and Presto projects. + * The Lead function returns the value of 'x' at the 'offset'th row after the current row in + * the window. Offsets start at 0, which is the current row. The offset must be constant + * integer value. The default offset is 1. When the value of 'x' is null at the 'offset'th row, + * null is returned. If there is no such offset row, the default expression is evaluated. * * @param input expression to evaluate 'offset' rows after the current row. * @param offset rows to jump ahead in the partition. - * @param default to use when the input value is null or when the offset is larger than the window. + * @param default to use when the offset is larger than the window. The default value is null. */ @ExpressionDescription(usage = - """_FUNC_(input, offset, default) - LEAD returns the value of 'x' at 'offset' rows - after the current row in the window""") + """_FUNC_(input, offset, default) - LEAD returns the value of 'x' at the 'offset'th row + after the current row in the window. + The default value of 'offset' is 1 and the default value of 'default' is null. + If the value of 'x' at the 'offset'th row is null, null is returned. + If there is no such offset row (e.g. 
when the offset is 1, the last row of the window + does not have any subsequent row), 'default' is returned.""") case class Lead(input: Expression, offset: Expression, default: Expression) extends OffsetWindowFunction { @@ -400,20 +401,22 @@ case class Lead(input: Expression, offset: Expression, default: Expression) } /** - * The Lag function returns the value of 'x' at 'offset' rows before the current row in the window. - * Offsets start at 0, which is the current row. The offset must be constant integer value. The - * default offset is 1. When the value of 'x' is null at the offset, or when the offset is smaller - * than the window, the default expression is evaluated. - * - * This documentation has been based upon similar documentation for the Hive and Presto projects. + * The Lag function returns the value of 'x' at the 'offset'th row before the current row in + * the window. Offsets start at 0, which is the current row. The offset must be constant + * integer value. The default offset is 1. When the value of 'x' is null at the 'offset'th row, + * null is returned. If there is no such offset row, the default expression is evaluated. * * @param input expression to evaluate 'offset' rows before the current row. * @param offset rows to jump back in the partition. - * @param default to use when the input value is null or when the offset is smaller than the window. + * @param default to use when the offset row does not exist. */ @ExpressionDescription(usage = - """_FUNC_(input, offset, default) - LAG returns the value of 'x' at 'offset' rows - before the current row in the window""") + """_FUNC_(input, offset, default) - LAG returns the value of 'x' at the 'offset'th row + before the current row in the window. + The default value of 'offset' is 1 and the default value of 'default' is null. + If the value of 'x' at the 'offset'th row is null, null is returned. + If there is no such offset row (e.g. when the offset is 1, the first row of the window + does not have any previous row), 'default' is returned.""") case class Lag(input: Expression, offset: Expression, default: Expression) extends OffsetWindowFunction { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala new file mode 100644 index 0000000000000..47f039e6a4cc4 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions.xml + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.util.GenericArrayData +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +/** + * Base class for xpath_boolean, xpath_double, xpath_int, etc. + * + * This is not the world's most efficient implementation due to type conversion, but works. + */ +abstract class XPathExtract extends BinaryExpression with ExpectsInputTypes with CodegenFallback { + override def left: Expression = xml + override def right: Expression = path + + /** XPath expressions are always nullable, e.g. if the xml string is empty. */ + override def nullable: Boolean = true + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType) + + override def checkInputDataTypes(): TypeCheckResult = { + if (!path.foldable) { + TypeCheckFailure("path should be a string literal") + } else { + super.checkInputDataTypes() + } + } + + @transient protected lazy val xpathUtil = new UDFXPathUtil + @transient protected lazy val pathString: String = path.eval().asInstanceOf[UTF8String].toString + + /** Concrete implementations need to override the following three methods. */ + def xml: Expression + def path: Expression +} + +@ExpressionDescription( + usage = "_FUNC_(xml, xpath) - Evaluates a boolean xpath expression.", + extended = "> SELECT _FUNC_('1','a/b');\ntrue") +case class XPathBoolean(xml: Expression, path: Expression) extends XPathExtract { + + override def prettyName: String = "xpath_boolean" + override def dataType: DataType = BooleanType + + override def nullSafeEval(xml: Any, path: Any): Any = { + xpathUtil.evalBoolean(xml.asInstanceOf[UTF8String].toString, pathString) + } +} + +@ExpressionDescription( + usage = "_FUNC_(xml, xpath) - Returns a short value that matches the xpath expression", + extended = "> SELECT _FUNC_('12','sum(a/b)');\n3") +case class XPathShort(xml: Expression, path: Expression) extends XPathExtract { + override def prettyName: String = "xpath_int" + override def dataType: DataType = ShortType + + override def nullSafeEval(xml: Any, path: Any): Any = { + val ret = xpathUtil.evalNumber(xml.asInstanceOf[UTF8String].toString, pathString) + if (ret eq null) null else ret.shortValue() + } +} + +@ExpressionDescription( + usage = "_FUNC_(xml, xpath) - Returns an integer value that matches the xpath expression", + extended = "> SELECT _FUNC_('12','sum(a/b)');\n3") +case class XPathInt(xml: Expression, path: Expression) extends XPathExtract { + override def prettyName: String = "xpath_int" + override def dataType: DataType = IntegerType + + override def nullSafeEval(xml: Any, path: Any): Any = { + val ret = xpathUtil.evalNumber(xml.asInstanceOf[UTF8String].toString, pathString) + if (ret eq null) null else ret.intValue() + } +} + +@ExpressionDescription( + usage = "_FUNC_(xml, xpath) - Returns a long value that matches the xpath expression", + extended = "> SELECT _FUNC_('12','sum(a/b)');\n3") +case class XPathLong(xml: Expression, path: Expression) extends XPathExtract { + override def prettyName: String = "xpath_long" + override def dataType: DataType = LongType + + override def nullSafeEval(xml: Any, path: Any): Any = { + val ret = xpathUtil.evalNumber(xml.asInstanceOf[UTF8String].toString, 
pathString) + if (ret eq null) null else ret.longValue() + } +} + +@ExpressionDescription( + usage = "_FUNC_(xml, xpath) - Returns a float value that matches the xpath expression", + extended = "> SELECT _FUNC_('12','sum(a/b)');\n3.0") +case class XPathFloat(xml: Expression, path: Expression) extends XPathExtract { + override def prettyName: String = "xpath_float" + override def dataType: DataType = FloatType + + override def nullSafeEval(xml: Any, path: Any): Any = { + val ret = xpathUtil.evalNumber(xml.asInstanceOf[UTF8String].toString, pathString) + if (ret eq null) null else ret.floatValue() + } +} + +@ExpressionDescription( + usage = "_FUNC_(xml, xpath) - Returns a double value that matches the xpath expression", + extended = "> SELECT _FUNC_('12','sum(a/b)');\n3.0") +case class XPathDouble(xml: Expression, path: Expression) extends XPathExtract { + override def prettyName: String = "xpath_float" + override def dataType: DataType = DoubleType + + override def nullSafeEval(xml: Any, path: Any): Any = { + val ret = xpathUtil.evalNumber(xml.asInstanceOf[UTF8String].toString, pathString) + if (ret eq null) null else ret.doubleValue() + } +} + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(xml, xpath) - Returns the text contents of the first xml node that matches the xpath expression", + extended = "> SELECT _FUNC_('bcc','a/c');\ncc") +// scalastyle:on line.size.limit +case class XPathString(xml: Expression, path: Expression) extends XPathExtract { + override def prettyName: String = "xpath_string" + override def dataType: DataType = StringType + + override def nullSafeEval(xml: Any, path: Any): Any = { + val ret = xpathUtil.evalString(xml.asInstanceOf[UTF8String].toString, pathString) + UTF8String.fromString(ret) + } +} + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(xml, xpath) - Returns a string array of values within xml nodes that match the xpath expression", + extended = "> SELECT _FUNC_('b1b2b3c1c2','a/b/text()');\n['b1','b2','b3']") +// scalastyle:on line.size.limit +case class XPathList(xml: Expression, path: Expression) extends XPathExtract { + override def prettyName: String = "xpath" + override def dataType: DataType = ArrayType(StringType, containsNull = false) + + override def nullSafeEval(xml: Any, path: Any): Any = { + val nodeList = xpathUtil.evalNodeList(xml.asInstanceOf[UTF8String].toString, pathString) + if (nodeList ne null) { + val ret = new Array[UTF8String](nodeList.getLength) + var i = 0 + while (i < nodeList.getLength) { + ret(i) = UTF8String.fromString(nodeList.item(i).getNodeValue) + i += 1 + } + new GenericArrayData(ret) + } else { + null + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala index 7d0584511fafe..834897b85023d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst - /** * An identifier that optionally specifies a database. * @@ -29,8 +28,16 @@ sealed trait IdentifierWithDatabase { def database: Option[String] + /* + * Escapes back-ticks within the identifier name with double-back-ticks. 
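+   * For example (illustrative), an identifier named a`b is quoted as `a``b`.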
+ */ + private def quoteIdentifier(name: String): String = name.replace("`", "``") + def quotedString: String = { - if (database.isDefined) s"`${database.get}`.`$identifier`" else s"`$identifier`" + val replacedId = quoteIdentifier(identifier) + val replacedDb = database.map(quoteIdentifier(_)) + + if (replacedDb.isDefined) s"`${replacedDb.get}`.`$replacedId`" else s"`$replacedId`" } def unquotedString: String = { @@ -44,7 +51,7 @@ sealed trait IdentifierWithDatabase { /** * Identifies a table in a database. * If `database` is not defined, the current database is used. - * When we register a permenent function in the FunctionRegistry, we use + * When we register a permanent function in the FunctionRegistry, we use * unquotedString as the function name. */ case class TableIdentifier(table: String, database: Option[String]) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 0b70edec8e37a..4c06038185c58 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -19,9 +19,10 @@ package org.apache.spark.sql.catalyst.optimizer import scala.annotation.tailrec import scala.collection.immutable.HashSet +import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf} -import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, DistinctAggregationRewriter, EliminateSubqueryAliases, EmptyFunctionRegistry} +import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ @@ -48,9 +49,10 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) // we do not eliminate subqueries or compute current time in the analyzer. 
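For context, the XPath expressions added above are exposed as SQL functions. A minimal sketch of how they can be exercised from a `SparkSession`, assuming the expressions are registered in the FunctionRegistry under the Hive-compatible names `xpath_boolean`, `xpath_int`, `xpath_string` and `xpath` (the literal XML strings below are just examples):

```scala
import org.apache.spark.sql.SparkSession

object XPathExamples {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("xpath-examples").getOrCreate()

    // Boolean test: does <a> have a <b> child?
    spark.sql("SELECT xpath_boolean('<a><b>1</b></a>', 'a/b')").show()          // expected: true

    // Numeric extraction: sum of all <b> values, returned as an int.
    spark.sql("SELECT xpath_int('<a><b>1</b><b>2</b></a>', 'sum(a/b)')").show() // expected: 3

    // Text of the first matching node.
    spark.sql("SELECT xpath_string('<a><b>b</b><c>cc</c></a>', 'a/c')").show()  // expected: cc

    // All matching text nodes as an array of strings.
    spark.sql("SELECT xpath('<a><b>b1</b><b>b2</b></a>', 'a/b/text()')").show() // expected: [b1, b2]

    spark.stop()
  }
}
```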
Batch("Finish Analysis", Once, EliminateSubqueryAliases, + ReplaceExpressions, ComputeCurrentTime, GetCurrentDatabase(sessionCatalog), - DistinctAggregationRewriter) :: + RewriteDistinctAggregates) :: ////////////////////////////////////////////////////////////////////////////////////////// // Optimizer rules start here ////////////////////////////////////////////////////////////////////////////////////////// @@ -68,13 +70,13 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) ReplaceExceptWithAntiJoin, ReplaceDistinctWithAggregate) :: Batch("Aggregate", fixedPoint, - RemoveLiteralFromGroupExpressions) :: + RemoveLiteralFromGroupExpressions, + RemoveRepetitionFromGroupExpressions) :: Batch("Operator Optimizations", fixedPoint, // Operator push down - SetOperationPushDown, - SamplePushDown, + PushThroughSetOperations, ReorderJoin, - OuterJoinElimination, + EliminateOuterJoin, PushPredicateThroughJoin, PushDownPredicate, LimitPushDown, @@ -88,27 +90,33 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) CombineUnions, // Constant folding and strength reduction NullPropagation, + FoldablePropagation, OptimizeIn(conf), ConstantFolding, LikeSimplification, BooleanSimplification, SimplifyConditionals, RemoveDispensableExpressions, - BinaryComparisonSimplification, + SimplifyBinaryComparison, PruneFilters, EliminateSorts, SimplifyCasts, SimplifyCaseConversionExpressions, + RewriteCorrelatedScalarSubquery, EliminateSerialization, - RewritePredicateSubquery) :: + RemoveAliasOnlyProject) :: Batch("Decimal Optimizations", fixedPoint, DecimalAggregates) :: Batch("Typed Filter Optimization", fixedPoint, - EmbedSerializerInFilter) :: + EmbedSerializerInFilter, + RemoveAliasOnlyProject) :: Batch("LocalRelation", fixedPoint, ConvertToLocalRelation) :: Batch("OptimizeCodegen", Once, - OptimizeCodegen(conf)) :: Nil + OptimizeCodegen(conf)) :: + Batch("RewriteSubquery", Once, + RewritePredicateSubquery, + CollapseProject) :: Nil } /** @@ -138,38 +146,43 @@ class SimpleTestOptimizer extends Optimizer( new SimpleCatalystConf(caseSensitiveAnalysis = true)) /** - * Pushes operations down into a Sample. + * Removes the Project only conducting Alias of its child node. + * It is created mainly for removing extra Project added in EliminateSerialization rule, + * but can also benefit other operators. */ -object SamplePushDown extends Rule[LogicalPlan] { - - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - // Push down projection into sample - case Project(projectList, s @ Sample(lb, up, replace, seed, child)) => - Sample(lb, up, replace, seed, - Project(projectList, child))() +object RemoveAliasOnlyProject extends Rule[LogicalPlan] { + // Check if projectList in the Project node has the same attribute names and ordering + // as its child node. + private def isAliasOnly( + projectList: Seq[NamedExpression], + childOutput: Seq[Attribute]): Boolean = { + if (!projectList.forall(_.isInstanceOf[Alias]) || projectList.length != childOutput.length) { + false + } else { + projectList.map(_.asInstanceOf[Alias]).zip(childOutput).forall { case (a, o) => + a.child match { + case attr: Attribute if a.name == attr.name && attr.semanticEquals(o) => true + case _ => false + } + } + } } -} -/** - * Removes cases where we are unnecessarily going between the object and serialized (InternalRow) - * representation of data item. For example back to back map operations. 
- */ -object EliminateSerialization extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case d @ DeserializeToObject(_, _, s: SerializeFromObject) - if d.outputObjectType == s.inputObjectType => - // A workaround for SPARK-14803. Remove this after it is fixed. - if (d.outputObjectType.isInstanceOf[ObjectType] && - d.outputObjectType.asInstanceOf[ObjectType].cls == classOf[org.apache.spark.sql.Row]) { - s.child - } else { - // Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`. - val objAttr = Alias(s.child.output.head, "obj")(exprId = d.output.head.exprId) - Project(objAttr :: Nil, s.child) + def apply(plan: LogicalPlan): LogicalPlan = { + val aliasOnlyProject = plan.find { + case Project(pList, child) if isAliasOnly(pList, child.output) => true + case _ => false + } + + aliasOnlyProject.map { case p: Project => + val aliases = p.projectList.map(_.asInstanceOf[Alias]) + val attrMap = AttributeMap(aliases.map(a => (a.toAttribute, a.child))) + plan.transformAllExpressions { + case a: Attribute if attrMap.contains(a) => attrMap(a) + }.transform { + case op: Project if op.eq(p) => op.child } - case a @ AppendColumns(_, _, _, s: SerializeFromObject) - if a.deserializer.dataType == s.inputObjectType => - AppendColumnsWithObject(a.func, s.serializer, a.serializer, s.child) + }.getOrElse(plan) } } @@ -180,7 +193,7 @@ object LimitPushDown extends Rule[LogicalPlan] { private def stripGlobalLimitIfPresent(plan: LogicalPlan): LogicalPlan = { plan match { - case GlobalLimit(expr, child) => child + case GlobalLimit(_, child) => child case _ => plan } } @@ -213,7 +226,7 @@ object LimitPushDown extends Rule[LogicalPlan] { // - If one side is already limited, stack another limit on top if the new limit is smaller. // The redundant limit will be collapsed by the CombineLimits rule. // - If neither side is limited, limit the side that is estimated to be bigger. - case LocalLimit(exp, join @ Join(left, right, joinType, condition)) => + case LocalLimit(exp, join @ Join(left, right, joinType, _)) => val newJoin = joinType match { case RightOuter => join.copy(right = maybePushLimit(exp, right)) case LeftOuter => join.copy(left = maybePushLimit(exp, left)) @@ -244,7 +257,7 @@ object LimitPushDown extends Rule[LogicalPlan] { * safe to pushdown Filters and Projections through it. Once we add UNION DISTINCT, * we will not be able to pushdown Projections. */ -object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper { +object PushThroughSetOperations extends Rule[LogicalPlan] with PredicateHelper { /** * Maps Attributes from the left side to the corresponding Attribute on the right side. @@ -357,11 +370,12 @@ object ColumnPruning extends Rule[LogicalPlan] { g.copy(child = prunedChild(g.child, g.references)) // Turn off `join` for Generate if no column from it's child is used - case p @ Project(_, g: Generate) if g.join && p.references.subsetOf(g.generatedSet) => + case p @ Project(_, g: Generate) + if g.join && !g.outer && p.references.subsetOf(g.generatedSet) => p.copy(child = g.copy(join = false)) // Eliminate unneeded attributes from right side of a Left Existence Join. 
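As a rough illustration of the LimitPushDown join case above: for a LEFT OUTER join, every outer-side row yields at least one output row, so a limit on top of the join can also be applied to the outer side. The sketch below uses only the public DataFrame API; table and column names are invented for the example, and the claim about the optimized plan is what the rule is expected to produce, not a guarantee.

```scala
import org.apache.spark.sql.SparkSession

object LimitPushDownExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("limit-pushdown").getOrCreate()
    import spark.implicits._

    val orders = Seq((1, "apple"), (2, "pear"), (3, "plum")).toDF("id", "item")
    val customers = Seq((1, "alice"), (2, "bob")).toDF("id", "name")

    // LIMIT on top of a LEFT OUTER join: the limit can be pushed into the left
    // (outer) side, since every left row produces at least one output row.
    val q = orders.join(customers, Seq("id"), "left_outer").limit(2)

    // The optimized plan should show a LocalLimit below the Join on the left side.
    q.explain(true)
    q.show()

    spark.stop()
  }
}
```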
- case j @ Join(left, right, LeftExistence(_), condition) => + case j @ Join(_, right, LeftExistence(_), _) => j.copy(right = prunedChild(right, j.references)) // all the columns will be used to compare, so we can't prune them @@ -393,10 +407,10 @@ object ColumnPruning extends Rule[LogicalPlan] { case w: Window if w.windowExpressions.isEmpty => w.child // Eliminate no-op Projects - case p @ Project(projectList, child) if sameOutput(child.output, p.output) => child + case p @ Project(_, child) if sameOutput(child.output, p.output) => child // Can't prune the columns on LeafNode - case p @ Project(_, l: LeafNode) => p + case p @ Project(_, _: LeafNode) => p // for all other logical plans that inherits the output from it's children case p @ Project(_, child) => @@ -479,7 +493,8 @@ object CollapseProject extends Rule[LogicalPlan] { // Substitute any attributes that are produced by the lower projection, so that we safely // eliminate it. // e.g., 'SELECT c + 1 FROM (SELECT a + b AS C ...' produces 'SELECT a + b + 1 ...' - val rewrittenUpper = upper.map(_.transform { + // Use transformUp to prevent infinite recursion. + val rewrittenUpper = upper.map(_.transformUp { case a: Attribute => aliases.getOrElse(a, a) }) // collapse upper and lower Projects may introduce unnecessary Aliases, trim them here. @@ -494,7 +509,7 @@ object CollapseProject extends Rule[LogicalPlan] { */ object CollapseRepartition extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case r @ Repartition(numPartitions, shuffle, Repartition(_, _, child)) => + case Repartition(numPartitions, shuffle, Repartition(_, _, child)) => Repartition(numPartitions, shuffle, child) } } @@ -548,6 +563,8 @@ object NullPropagation extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsUp { + case e @ WindowExpression(Cast(Literal(0L, _), _), _) => + Cast(Literal(0L), e.dataType) case e @ AggregateExpression(Count(exprs), _, _, _) if !exprs.exists(nonNullLiteral) => Cast(Literal(0L), e.dataType) case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType) @@ -568,7 +585,7 @@ object NullPropagation extends Rule[LogicalPlan] { // For Coalesce, remove null literals. case e @ Coalesce(children) => val newChildren = children.filter(nonNullLiteral) - if (newChildren.length == 0) { + if (newChildren.isEmpty) { Literal.create(null, e.dataType) } else if (newChildren.length == 1) { newChildren.head @@ -611,6 +628,66 @@ object NullPropagation extends Rule[LogicalPlan] { } } +/** + * Propagate foldable expressions: + * Replace attributes with aliases of the original foldable expressions if possible. + * Other optimizations will take advantage of the propagated foldable expressions. 
+ * + * {{{ + * SELECT 1.0 x, 'abc' y, Now() z ORDER BY x, y, 3 + * ==> SELECT 1.0 x, 'abc' y, Now() z ORDER BY 1.0, 'abc', Now() + * }}} + */ +object FoldablePropagation extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = { + val foldableMap = AttributeMap(plan.flatMap { + case Project(projectList, _) => projectList.collect { + case a: Alias if a.child.foldable => (a.toAttribute, a) + } + case _ => Nil + }) + + if (foldableMap.isEmpty) { + plan + } else { + var stop = false + CleanupAliases(plan.transformUp { + case u: Union => + stop = true + u + case c: Command => + stop = true + c + // For outer join, although its output attributes are derived from its children, they are + // actually different attributes: the output of outer join is not always picked from its + // children, but can also be null. + // TODO(cloud-fan): It seems more reasonable to use new attributes as the output attributes + // of outer join. + case j @ Join(_, _, LeftOuter | RightOuter | FullOuter, _) => + stop = true + j + + // These 3 operators take attributes as constructor parameters, and these attributes + // can't be replaced by alias. + case m: MapGroups => + stop = true + m + case f: FlatMapGroupsInR => + stop = true + f + case c: CoGroup => + stop = true + c + + case p: LogicalPlan if !stop => p.transformExpressions { + case a: AttributeReference if foldableMap.contains(a) => + foldableMap(a) + } + }) + } + } +} + /** * Generate a list of additional filters from an operator's existing constraint but remove those * that are either already part of the operator's condition or are part of the operator's child @@ -791,7 +868,7 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { * 2) Replace '=', '<=', and '>=' with 'true' literal if both operands are non-nullable. * 3) Replace '<' and '>' with 'false' literal if both operands are non-nullable. */ -object BinaryComparisonSimplification extends Rule[LogicalPlan] with PredicateHelper { +object SimplifyBinaryComparison extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsUp { // True with equality @@ -839,7 +916,7 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { case e @ CaseWhen(branches, _) if branches.headOption.map(_._1) == Some(TrueLiteral) => // If the first branch is a true literal, remove the entire CaseWhen and use the value // from that. Note that CaseWhen.branches should never be empty, and as a result the - // headOption (rather than head) added above is just a extra (and unnecessary) safeguard. + // headOption (rather than head) added above is just an extra (and unnecessary) safeguard. 
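A small end-to-end version of the query shape targeted by FoldablePropagation above: constant select-list columns referenced from ORDER BY. This is only an illustrative sketch; the table name and data are made up.

```scala
import org.apache.spark.sql.SparkSession

object FoldablePropagationExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("foldable-propagation").getOrCreate()
    import spark.implicits._

    Seq(1, 2, 3).toDF("a").createOrReplaceTempView("t")

    // x and y are foldable (constant) aliases; the optimizer can substitute the
    // constants into the ORDER BY instead of sorting on the aliased attributes.
    val q = spark.sql("SELECT 1.0 AS x, 'abc' AS y, a FROM t ORDER BY x, y, a")

    // After optimization the Sort should no longer need to sort on x or y.
    q.explain(true)
    q.show()

    spark.stop()
  }
}
```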
branches.head._2 } } @@ -850,8 +927,12 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { */ case class OptimizeCodegen(conf: CatalystConf) extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case e @ CaseWhen(branches, _) if branches.size < conf.maxCaseBranchesForCodegen => - e.toCodegen() + case e: CaseWhen if canCodegen(e) => e.toCodegen() + } + + private def canCodegen(e: CaseWhen): Boolean = { + val numBranches = e.branches.size + e.elseValue.size + numBranches <= conf.maxCaseBranchesForCodegen } } @@ -870,11 +951,11 @@ object CombineUnions extends Rule[LogicalPlan] { */ object CombineFilters extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case ff @ Filter(fc, nf @ Filter(nc, grandChild)) => + case Filter(fc, nf @ Filter(nc, grandChild)) => (ExpressionSet(splitConjunctivePredicates(fc)) -- ExpressionSet(splitConjunctivePredicates(nc))).reduceOption(And) match { case Some(ac) => - Filter(And(ac, nc), grandChild) + Filter(And(nc, ac), grandChild) case None => nf } @@ -950,19 +1031,23 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { project.copy(child = Filter(replaceAlias(condition, aliasMap), grandChild)) // Push [[Filter]] operators through [[Window]] operators. Parts of the predicate that can be - // pushed beneath must satisfy the following two conditions: + // pushed beneath must satisfy the following conditions: // 1. All the expressions are part of window partitioning key. The expressions can be compound. - // 2. Deterministic + // 2. Deterministic. + // 3. Placed before any non-deterministic predicates. case filter @ Filter(condition, w: Window) if w.partitionSpec.forall(_.isInstanceOf[AttributeReference]) => val partitionAttrs = AttributeSet(w.partitionSpec.flatMap(_.references)) - val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { cond => - cond.references.subsetOf(partitionAttrs) && cond.deterministic && - // This is for ensuring all the partitioning expressions have been converted to alias - // in Analyzer. Thus, we do not need to check if the expressions in conditions are - // the same as the expressions used in partitioning columns. - partitionAttrs.forall(_.isInstanceOf[Attribute]) + + val (candidates, containingNonDeterministic) = + splitConjunctivePredicates(condition).span(_.deterministic) + + val (pushDown, rest) = candidates.partition { cond => + cond.references.subsetOf(partitionAttrs) } + + val stayUp = rest ++ containingNonDeterministic + if (pushDown.nonEmpty) { val pushDownPredicate = pushDown.reduce(And) val newWindow = w.copy(child = Filter(pushDownPredicate, w.child)) @@ -981,11 +1066,16 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { // For each filter, expand the alias and check if the filter can be evaluated using // attributes produced by the aggregate operator's child operator. 
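The Window case above only pushes predicates that are expressed on the partitioning columns and that precede any non-deterministic predicate. A hedged sketch of a query where this applies (names invented for the example):

```scala
import org.apache.spark.sql.SparkSession

object FilterThroughWindowExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("filter-through-window").getOrCreate()
    import spark.implicits._

    Seq(("a", 1), ("a", 2), ("b", 3)).toDF("grp", "value").createOrReplaceTempView("t")

    // The outer predicate grp = 'a' references only the PARTITION BY column, so it is
    // safe to evaluate below the Window operator: rows from other partitions cannot
    // influence the ranks computed for partition 'a'.
    val q = spark.sql(
      """SELECT grp, value, rnk
        |FROM (SELECT grp, value, rank() OVER (PARTITION BY grp ORDER BY value) AS rnk FROM t)
        |WHERE grp = 'a'""".stripMargin)

    // The optimized plan should show the Filter beneath the Window node.
    q.explain(true)
    q.show()

    spark.stop()
  }
}
```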
- val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { cond => + val (candidates, containingNonDeterministic) = + splitConjunctivePredicates(condition).span(_.deterministic) + + val (pushDown, rest) = candidates.partition { cond => val replaced = replaceAlias(cond, aliasMap) - replaced.references.subsetOf(aggregate.child.outputSet) && replaced.deterministic + cond.references.nonEmpty && replaced.references.subsetOf(aggregate.child.outputSet) } + val stayUp = rest ++ containingNonDeterministic + if (pushDown.nonEmpty) { val pushDownPredicate = pushDown.reduce(And) val replaced = replaceAlias(pushDownPredicate, aliasMap) @@ -999,9 +1089,8 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { case filter @ Filter(condition, union: Union) => // Union could change the rows, so non-deterministic predicate can't be pushed down - val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { cond => - cond.deterministic - } + val (pushDown, stayUp) = splitConjunctivePredicates(condition).span(_.deterministic) + if (pushDown.nonEmpty) { val pushDownCond = pushDown.reduceLeft(And) val output = union.output @@ -1023,17 +1112,28 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { filter } - // two filters should be combine together by other rules - case filter @ Filter(_, f: Filter) => filter - // should not push predicates through sample, or will generate different results. - case filter @ Filter(_, s: Sample) => filter - - case filter @ Filter(condition, u: UnaryNode) if u.expressions.forall(_.deterministic) => + case filter @ Filter(condition, u: UnaryNode) + if canPushThrough(u) && u.expressions.forall(_.deterministic) => pushDownPredicate(filter, u.child) { predicate => u.withNewChildren(Seq(Filter(predicate, u.child))) } } + private def canPushThrough(p: UnaryNode): Boolean = p match { + // Note that some operators (e.g. project, aggregate, union) are being handled separately + // (earlier in this rule). + case _: AppendColumns => true + case _: BroadcastHint => true + case _: Distinct => true + case _: Generate => true + case _: Pivot => true + case _: RedistributeData => true + case _: Repartition => true + case _: ScriptTransformation => true + case _: Sort => true + case _ => false + } + private def pushDownPredicate( filter: Filter, grandchild: LogicalPlan)(insertFilter: Expression => LogicalPlan): LogicalPlan = { @@ -1041,9 +1141,15 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { // come from grandchild. // TODO: non-deterministic predicates could be pushed through some operators that do not change // the rows. - val (pushDown, stayUp) = splitConjunctivePredicates(filter.condition).partition { cond => - cond.deterministic && cond.references.subsetOf(grandchild.outputSet) + val (candidates, containingNonDeterministic) = + splitConjunctivePredicates(filter.condition).span(_.deterministic) + + val (pushDown, rest) = candidates.partition { cond => + cond.references.subsetOf(grandchild.outputSet) } + + val stayUp = rest ++ containingNonDeterministic + if (pushDown.nonEmpty) { val newChild = insertFilter(pushDown.reduceLeft(And)) if (stayUp.nonEmpty) { @@ -1057,110 +1163,6 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { } } -/** - * Reorder the joins and push all the conditions into join, so that the bottom ones have at least - * one condition. - * - * The order of joins will not be changed if all of them already have at least one condition. 
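The aggregate case works the same way: deterministic predicates over the grouping columns can be pushed below the Aggregate, while everything from the first non-deterministic predicate onwards stays above it. An illustrative sketch (the rand() threshold and names are arbitrary):

```scala
import org.apache.spark.sql.SparkSession

object FilterThroughAggregateExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("filter-through-aggregate").getOrCreate()
    import spark.implicits._

    Seq(("a", 1), ("a", 2), ("b", 3)).toDF("grp", "value").createOrReplaceTempView("t")

    // grp <> 'b' is deterministic and references only the grouping column, so it can be
    // evaluated before aggregating. rand() < 0.5 is non-deterministic and must remain in
    // the Filter above the Aggregate, together with anything written after it.
    val q = spark.sql(
      """SELECT grp, total
        |FROM (SELECT grp, SUM(value) AS total FROM t GROUP BY grp) aggregated
        |WHERE grp <> 'b' AND rand() < 0.5""".stripMargin)

    q.explain(true)

    spark.stop()
  }
}
```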
- */ -object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper { - - /** - * Join a list of plans together and push down the conditions into them. - * - * The joined plan are picked from left to right, prefer those has at least one join condition. - * - * @param input a list of LogicalPlans to join. - * @param conditions a list of condition for join. - */ - @tailrec - def createOrderedJoin(input: Seq[LogicalPlan], conditions: Seq[Expression]): LogicalPlan = { - assert(input.size >= 2) - if (input.size == 2) { - Join(input(0), input(1), Inner, conditions.reduceLeftOption(And)) - } else { - val left :: rest = input.toList - // find out the first join that have at least one join condition - val conditionalJoin = rest.find { plan => - val refs = left.outputSet ++ plan.outputSet - conditions.filterNot(canEvaluate(_, left)).filterNot(canEvaluate(_, plan)) - .exists(_.references.subsetOf(refs)) - } - // pick the next one if no condition left - val right = conditionalJoin.getOrElse(rest.head) - - val joinedRefs = left.outputSet ++ right.outputSet - val (joinConditions, others) = conditions.partition(_.references.subsetOf(joinedRefs)) - val joined = Join(left, right, Inner, joinConditions.reduceLeftOption(And)) - - // should not have reference to same logical plan - createOrderedJoin(Seq(joined) ++ rest.filterNot(_ eq right), others) - } - } - - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case j @ ExtractFiltersAndInnerJoins(input, conditions) - if input.size > 2 && conditions.nonEmpty => - createOrderedJoin(input, conditions) - } -} - -/** - * Elimination of outer joins, if the predicates can restrict the result sets so that - * all null-supplying rows are eliminated - * - * - full outer -> inner if both sides have such predicates - * - left outer -> inner if the right side has such predicates - * - right outer -> inner if the left side has such predicates - * - full outer -> left outer if only the left side has such predicates - * - full outer -> right outer if only the right side has such predicates - * - * This rule should be executed before pushing down the Filter - */ -object OuterJoinElimination extends Rule[LogicalPlan] with PredicateHelper { - - /** - * Returns whether the expression returns null or false when all inputs are nulls. 

- */ - private def canFilterOutNull(e: Expression): Boolean = { - if (!e.deterministic || PredicateSubquery.hasPredicateSubquery(e)) return false - val attributes = e.references.toSeq - val emptyRow = new GenericInternalRow(attributes.length) - val v = BindReferences.bindReference(e, attributes).eval(emptyRow) - v == null || v == false - } - - private def buildNewJoinType(filter: Filter, join: Join): JoinType = { - val splitConjunctiveConditions: Seq[Expression] = splitConjunctivePredicates(filter.condition) - val leftConditions = splitConjunctiveConditions - .filter(_.references.subsetOf(join.left.outputSet)) - val rightConditions = splitConjunctiveConditions - .filter(_.references.subsetOf(join.right.outputSet)) - - val leftHasNonNullPredicate = leftConditions.exists(canFilterOutNull) || - filter.constraints.filter(_.isInstanceOf[IsNotNull]) - .exists(expr => join.left.outputSet.intersect(expr.references).nonEmpty) - val rightHasNonNullPredicate = rightConditions.exists(canFilterOutNull) || - filter.constraints.filter(_.isInstanceOf[IsNotNull]) - .exists(expr => join.right.outputSet.intersect(expr.references).nonEmpty) - - join.joinType match { - case RightOuter if leftHasNonNullPredicate => Inner - case LeftOuter if rightHasNonNullPredicate => Inner - case FullOuter if leftHasNonNullPredicate && rightHasNonNullPredicate => Inner - case FullOuter if leftHasNonNullPredicate => LeftOuter - case FullOuter if rightHasNonNullPredicate => RightOuter - case o => o - } - } - - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case f @ Filter(condition, j @ Join(_, _, RightOuter | LeftOuter | FullOuter, _)) => - val newJoinType = buildNewJoinType(f, j) - if (j.joinType == newJoinType) f else Filter(condition, j.copy(joinType = newJoinType)) - } -} - /** * Pushes down [[Filter]] operators where the `condition` can be * evaluated using only the attributes of the left or right side of a join. Other @@ -1173,18 +1175,25 @@ object OuterJoinElimination extends Rule[LogicalPlan] with PredicateHelper { */ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { /** - * Splits join condition expressions into three categories based on the attributes required - * to evaluate them. + * Splits join condition expressions or filter predicates (on a given join's output) into three + * categories based on the attributes required to evaluate them. Note that we explicitly exclude + * on-deterministic (i.e., stateful) condition expressions in canEvaluateInLeft or + * canEvaluateInRight to prevent pushing these predicates on either side of the join. * * @return (canEvaluateInLeft, canEvaluateInRight, haveToEvaluateInBoth) */ private def split(condition: Seq[Expression], left: LogicalPlan, right: LogicalPlan) = { + // Note: In order to ensure correctness, it's important to not change the relative ordering of + // any deterministic expression that follows a non-deterministic expression. To achieve this, + // we only consider pushing down those expressions that precede the first non-deterministic + // expression in the condition. 
+ val (pushDownCandidates, containingNonDeterministic) = condition.span(_.deterministic) val (leftEvaluateCondition, rest) = - condition.partition(_.references subsetOf left.outputSet) + pushDownCandidates.partition(_.references.subsetOf(left.outputSet)) val (rightEvaluateCondition, commonCondition) = - rest.partition(_.references subsetOf right.outputSet) + rest.partition(expr => expr.references.subsetOf(right.outputSet)) - (leftEvaluateCondition, rightEvaluateCondition, commonCondition) + (leftEvaluateCondition, rightEvaluateCondition, commonCondition ++ containingNonDeterministic) } def apply(plan: LogicalPlan): LogicalPlan = plan transform { @@ -1192,7 +1201,6 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { case f @ Filter(filterCondition, Join(left, right, joinType, joinCondition)) => val (leftFilterConditions, rightFilterConditions, commonFilterCondition) = split(splitConjunctivePredicates(filterCondition), left, right) - joinType match { case Inner => // push down the single side `where` condition into respective sides @@ -1200,9 +1208,16 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { reduceLeftOption(And).map(Filter(_, left)).getOrElse(left) val newRight = rightFilterConditions. reduceLeftOption(And).map(Filter(_, right)).getOrElse(right) - val newJoinCond = (commonFilterCondition ++ joinCondition).reduceLeftOption(And) + val (newJoinConditions, others) = + commonFilterCondition.partition(e => !SubqueryExpression.hasCorrelatedSubquery(e)) + val newJoinCond = (newJoinConditions ++ joinCondition).reduceLeftOption(And) - Join(newLeft, newRight, Inner, newJoinCond) + val join = Join(newLeft, newRight, Inner, newJoinCond) + if (others.nonEmpty) { + Filter(others.reduceLeft(And), join) + } else { + join + } case RightOuter => // push down the right side only `where` condition val newLeft = left @@ -1229,7 +1244,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { } // push down the join filter into sub query scanning if applicable - case f @ Join(left, right, joinType, joinCondition) => + case j @ Join(left, right, joinType, joinCondition) => val (leftJoinConditions, rightJoinConditions, commonJoinCondition) = split(joinCondition.map(splitConjunctivePredicates).getOrElse(Nil), left, right) @@ -1259,7 +1274,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { val newJoinCond = (leftJoinConditions ++ commonJoinCondition).reduceLeftOption(And) Join(newLeft, newRight, LeftOuter, newJoinCond) - case FullOuter => f + case FullOuter => j case NaturalJoin(_) => sys.error("Untransformed NaturalJoin node") case UsingJoin(_, _) => sys.error("Untransformed Using join node") } @@ -1291,11 +1306,11 @@ object RemoveDispensableExpressions extends Rule[LogicalPlan] { */ object CombineLimits extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case ll @ GlobalLimit(le, nl @ GlobalLimit(ne, grandChild)) => + case GlobalLimit(le, GlobalLimit(ne, grandChild)) => GlobalLimit(Least(Seq(ne, le)), grandChild) - case ll @ LocalLimit(le, nl @ LocalLimit(ne, grandChild)) => + case LocalLimit(le, LocalLimit(ne, grandChild)) => LocalLimit(Least(Seq(ne, le)), grandChild) - case ll @ Limit(le, nl @ Limit(ne, grandChild)) => + case Limit(le, Limit(ne, grandChild)) => Limit(Least(Seq(ne, le)), grandChild) } } @@ -1367,10 +1382,15 @@ object DecimalAggregates extends Rule[LogicalPlan] { */ object ConvertToLocalRelation extends Rule[LogicalPlan] { 
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case Project(projectList, LocalRelation(output, data)) =>
+    case Project(projectList, LocalRelation(output, data))
+        if !projectList.exists(hasUnevaluableExpr) =>
       val projection = new InterpretedProjection(projectList, output)
       LocalRelation(projectList.map(_.toAttribute), data.map(projection))
   }
+
+  private def hasUnevaluableExpr(expr: Expression): Boolean = {
+    expr.find(e => e.isInstanceOf[Unevaluable] && !e.isInstanceOf[AttributeReference]).isDefined
+  }
 }
 
 /**
@@ -1433,63 +1453,28 @@ object ReplaceExceptWithAntiJoin extends Rule[LogicalPlan] {
  */
 object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case a @ Aggregate(grouping, _, _) =>
+    case a @ Aggregate(grouping, _, _) if grouping.nonEmpty =>
       val newGrouping = grouping.filter(!_.foldable)
-      a.copy(groupingExpressions = newGrouping)
-  }
-}
-
-/**
- * Computes the current date and time to make sure we return the same result in a single query.
- */
-object ComputeCurrentTime extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = {
-    val dateExpr = CurrentDate()
-    val timeExpr = CurrentTimestamp()
-    val currentDate = Literal.create(dateExpr.eval(EmptyRow), dateExpr.dataType)
-    val currentTime = Literal.create(timeExpr.eval(EmptyRow), timeExpr.dataType)
-
-    plan transformAllExpressions {
-      case CurrentDate() => currentDate
-      case CurrentTimestamp() => currentTime
-    }
-  }
-}
-
-/** Replaces the expression of CurrentDatabase with the current database name. */
-case class GetCurrentDatabase(sessionCatalog: SessionCatalog) extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = {
-    plan transformAllExpressions {
-      case CurrentDatabase() =>
-        Literal.create(sessionCatalog.getCurrentDatabase, StringType)
-    }
+      if (newGrouping.nonEmpty) {
+        a.copy(groupingExpressions = newGrouping)
+      } else {
+        // All grouping expressions are literals. We should not drop them all, because this can
+        // change the return semantics when the input of the Aggregate is empty (SPARK-17114). We
+        // instead replace this with a single, easy to hash/sort, literal expression.
+        a.copy(groupingExpressions = Seq(Literal(0, IntegerType)))
+      }
   }
 }
 
 /**
- * Typed [[Filter]] is by default surrounded by a [[DeserializeToObject]] beneath it and a
- * [[SerializeFromObject]] above it. If these serializations can't be eliminated, we should embed
- * the deserializer in filter condition to save the extra serialization at last.
+ * Removes repetition from group expressions in [[Aggregate]], as they have no effect on the result
+ * and only make the grouping key bigger.
  */
-object EmbedSerializerInFilter extends Rule[LogicalPlan] {
+object RemoveRepetitionFromGroupExpressions extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case s @ SerializeFromObject(_, Filter(condition, d: DeserializeToObject)) =>
-      val numObjects = condition.collect {
-        case a: Attribute if a == d.output.head => a
-      }.length
-
-      if (numObjects > 1) {
-        // If the filter condition references the object more than one times, we should not embed
-        // deserializer in it as the deserialization will happen many times and slow down the
-        // execution.
-        // TODO: we can still embed it if we can make sure subexpression elimination works here.
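The SPARK-17114 comment above is about the difference, on empty input, between a grouped aggregate whose grouping keys are all literals and a global aggregate; dropping every literal grouping expression would silently convert the former into the latter. A sketch of the observable difference (table name invented for the example):

```scala
import org.apache.spark.sql.SparkSession

object LiteralGroupByExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("literal-group-by").getOrCreate()
    import spark.implicits._

    // An empty input relation.
    Seq.empty[Int].toDF("a").createOrReplaceTempView("empty_t")

    // Global aggregate: always returns exactly one row (count = 0).
    spark.sql("SELECT COUNT(*) FROM empty_t").show()

    // Grouped aggregate with a literal grouping key: returns zero rows on empty input,
    // which is why the literals cannot simply all be removed from the grouping expressions.
    spark.sql("SELECT COUNT(*) FROM empty_t GROUP BY 'const'").show()

    spark.stop()
  }
}
```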
- s - } else { - val newCondition = condition transform { - case a: Attribute if a == d.output.head => d.deserializer - } - Filter(newCondition, d.child) - } + case a @ Aggregate(grouping, _, _) => + val newGrouping = ExpressionSet(grouping).toSeq + a.copy(groupingExpressions = newGrouping) } } @@ -1504,7 +1489,7 @@ object EmbedSerializerInFilter extends Rule[LogicalPlan] { */ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case f @ Filter(condition, child) => + case Filter(condition, child) => val (withSubquery, withoutSubquery) = splitConjunctivePredicates(condition).partition(PredicateSubquery.hasPredicateSubquery) @@ -1517,19 +1502,301 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { // Filter the plan by applying left semi and left anti joins. withSubquery.foldLeft(newFilter) { case (p, PredicateSubquery(sub, conditions, _, _)) => - Join(p, sub, LeftSemi, conditions.reduceOption(And)) + val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p) + Join(outerPlan, sub, LeftSemi, joinCond) case (p, Not(PredicateSubquery(sub, conditions, false, _))) => - Join(p, sub, LeftAnti, conditions.reduceOption(And)) + val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p) + Join(outerPlan, sub, LeftAnti, joinCond) case (p, Not(PredicateSubquery(sub, conditions, true, _))) => - // This is a NULL-aware (left) anti join (NAAJ). + // This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr // Construct the condition. A NULL in one of the conditions is regarded as a positive // result; such a row will be filtered out by the Anti-Join operator. - val anyNull = conditions.map(IsNull).reduceLeft(Or) - val condition = conditions.reduceLeft(And) - // Note that will almost certainly be planned as a Broadcast Nested Loop join. Use EXISTS - // if performance matters to you. - Join(p, sub, LeftAnti, Option(Or(anyNull, condition))) + // Note that will almost certainly be planned as a Broadcast Nested Loop join. + // Use EXISTS if performance matters to you. + val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p) + val anyNull = splitConjunctivePredicates(joinCond.get).map(IsNull).reduceLeft(Or) + Join(outerPlan, sub, LeftAnti, Option(Or(anyNull, joinCond.get))) + case (p, predicate) => + val (newCond, inputPlan) = rewriteExistentialExpr(Seq(predicate), p) + Project(p.output, Filter(newCond.get, inputPlan)) + } + } + + /** + * Given a predicate expression and an input plan, it rewrites + * any embedded existential sub-query into an existential join. + * It returns the rewritten expression together with the updated plan. + * Currently, it does not support null-aware joins. Embedded NOT IN predicates + * are blocked in the Analyzer. + */ + private def rewriteExistentialExpr( + exprs: Seq[Expression], + plan: LogicalPlan): (Option[Expression], LogicalPlan) = { + var newPlan = plan + val newExprs = exprs.map { e => + e transformUp { + case PredicateSubquery(sub, conditions, nullAware, _) => + // TODO: support null-aware join + val exists = AttributeReference("exists", BooleanType, nullable = false)() + newPlan = Join(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And)) + exists + } + } + (newExprs.reduceOption(And), newPlan) + } +} + +/** + * This rule rewrites correlated [[ScalarSubquery]] expressions into LEFT OUTER joins. 
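For the RewritePredicateSubquery rule above, an illustrative pair of queries: EXISTS is rewritten into a left semi join and NOT EXISTS into a left anti join. The schema and data below are invented for the sketch; the join types mentioned in the comments are what the optimized plan is expected to contain.

```scala
import org.apache.spark.sql.SparkSession

object PredicateSubqueryExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("predicate-subquery").getOrCreate()
    import spark.implicits._

    Seq((1, "alice"), (2, "bob")).toDF("id", "name").createOrReplaceTempView("customers")
    Seq((1, 100), (1, 50)).toDF("cid", "amount").createOrReplaceTempView("orders")

    // Expected to be rewritten into a LEFT SEMI join between customers and orders.
    spark.sql(
      "SELECT * FROM customers c WHERE EXISTS (SELECT 1 FROM orders o WHERE o.cid = c.id)")
      .explain(true)

    // Expected to be rewritten into a LEFT ANTI join.
    spark.sql(
      "SELECT * FROM customers c WHERE NOT EXISTS (SELECT 1 FROM orders o WHERE o.cid = c.id)")
      .explain(true)

    spark.stop()
  }
}
```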
+ */ +object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { + /** + * Extract all correlated scalar subqueries from an expression. The subqueries are collected using + * the given collector. The expression is rewritten and returned. + */ + private def extractCorrelatedScalarSubqueries[E <: Expression]( + expression: E, + subqueries: ArrayBuffer[ScalarSubquery]): E = { + val newExpression = expression transform { + case s: ScalarSubquery if s.children.nonEmpty => + subqueries += s + s.query.output.head + } + newExpression.asInstanceOf[E] + } + + /** + * Statically evaluate an expression containing zero or more placeholders, given a set + * of bindings for placeholder values. + */ + private def evalExpr(expr: Expression, bindings: Map[ExprId, Option[Any]]) : Option[Any] = { + val rewrittenExpr = expr transform { + case r: AttributeReference => + bindings(r.exprId) match { + case Some(v) => Literal.create(v, r.dataType) + case None => Literal.default(NullType) + } + } + Option(rewrittenExpr.eval()) + } + + /** + * Statically evaluate an expression containing one or more aggregates on an empty input. + */ + private def evalAggOnZeroTups(expr: Expression) : Option[Any] = { + // AggregateExpressions are Unevaluable, so we need to replace all aggregates + // in the expression with the value they would return for zero input tuples. + // Also replace attribute refs (for example, for grouping columns) with NULL. + val rewrittenExpr = expr transform { + case a @ AggregateExpression(aggFunc, _, _, resultId) => + aggFunc.defaultResult.getOrElse(Literal.default(NullType)) + + case _: AttributeReference => Literal.default(NullType) + } + Option(rewrittenExpr.eval()) + } + + /** + * Statically evaluate a scalar subquery on an empty input. + * + * WARNING: This method only covers subqueries that pass the checks under + * [[org.apache.spark.sql.catalyst.analysis.CheckAnalysis]]. If the checks in + * CheckAnalysis become less restrictive, this method will need to change. + */ + private def evalSubqueryOnZeroTups(plan: LogicalPlan) : Option[Any] = { + // Inputs to this method will start with a chain of zero or more SubqueryAlias + // and Project operators, followed by an optional Filter, followed by an + // Aggregate. Traverse the operators recursively. + def evalPlan(lp : LogicalPlan) : Map[ExprId, Option[Any]] = lp match { + case SubqueryAlias(_, child) => evalPlan(child) + case Filter(condition, child) => + val bindings = evalPlan(child) + if (bindings.isEmpty) bindings + else { + val exprResult = evalExpr(condition, bindings).getOrElse(false) + .asInstanceOf[Boolean] + if (exprResult) bindings else Map.empty + } + + case Project(projectList, child) => + val bindings = evalPlan(child) + if (bindings.isEmpty) { + bindings + } else { + projectList.map(ne => (ne.exprId, evalExpr(ne, bindings))).toMap + } + + case Aggregate(_, aggExprs, _) => + // Some of the expressions under the Aggregate node are the join columns + // for joining with the outer query block. Fill those expressions in with + // nulls and statically evaluate the remainder. + aggExprs.map { + case ref: AttributeReference => (ref.exprId, None) + case alias @ Alias(_: AttributeReference, _) => (alias.exprId, None) + case ne => (ne.exprId, evalAggOnZeroTups(ne)) + }.toMap + + case _ => sys.error(s"Unexpected operator in scalar subquery: $lp") + } + + val resultMap = evalPlan(plan) + + // By convention, the scalar subquery result is the leftmost field. 
+ resultMap(plan.output.head.exprId) + } + + /** + * Split the plan for a scalar subquery into the parts above the innermost query block + * (first part of returned value), the HAVING clause of the innermost query block + * (optional second part) and the parts below the HAVING CLAUSE (third part). + */ + private def splitSubquery(plan: LogicalPlan) : (Seq[LogicalPlan], Option[Filter], Aggregate) = { + val topPart = ArrayBuffer.empty[LogicalPlan] + var bottomPart: LogicalPlan = plan + while (true) { + bottomPart match { + case havingPart @ Filter(_, aggPart: Aggregate) => + return (topPart, Option(havingPart), aggPart) + + case aggPart: Aggregate => + // No HAVING clause + return (topPart, None, aggPart) + + case p @ Project(_, child) => + topPart += p + bottomPart = child + + case s @ SubqueryAlias(_, child) => + topPart += s + bottomPart = child + + case Filter(_, op) => + sys.error(s"Correlated subquery has unexpected operator $op below filter") + + case op @ _ => sys.error(s"Unexpected operator $op in correlated subquery") + } + } + + sys.error("This line should be unreachable") + } + + // Name of generated column used in rewrite below + val ALWAYS_TRUE_COLNAME = "alwaysTrue" + + /** + * Construct a new child plan by left joining the given subqueries to a base plan. + */ + private def constructLeftJoins( + child: LogicalPlan, + subqueries: ArrayBuffer[ScalarSubquery]): LogicalPlan = { + subqueries.foldLeft(child) { + case (currentChild, ScalarSubquery(query, conditions, _)) => + val origOutput = query.output.head + + val resultWithZeroTups = evalSubqueryOnZeroTups(query) + if (resultWithZeroTups.isEmpty) { + // CASE 1: Subquery guaranteed not to have the COUNT bug + Project( + currentChild.output :+ origOutput, + Join(currentChild, query, LeftOuter, conditions.reduceOption(And))) + } else { + // Subquery might have the COUNT bug. Add appropriate corrections. + val (topPart, havingNode, aggNode) = splitSubquery(query) + + // The next two cases add a leading column to the outer join input to make it + // possible to distinguish between the case when no tuples join and the case + // when the tuple that joins contains null values. + // The leading column always has the value TRUE. + val alwaysTrueExprId = NamedExpression.newExprId + val alwaysTrueExpr = Alias(Literal.TrueLiteral, + ALWAYS_TRUE_COLNAME)(exprId = alwaysTrueExprId) + val alwaysTrueRef = AttributeReference(ALWAYS_TRUE_COLNAME, + BooleanType)(exprId = alwaysTrueExprId) + + val aggValRef = query.output.head + + if (havingNode.isEmpty) { + // CASE 2: Subquery with no HAVING clause + Project( + currentChild.output :+ + Alias( + If(IsNull(alwaysTrueRef), + Literal.create(resultWithZeroTups.get, origOutput.dataType), + aggValRef), origOutput.name)(exprId = origOutput.exprId), + Join(currentChild, + Project(query.output :+ alwaysTrueExpr, query), + LeftOuter, conditions.reduceOption(And))) + + } else { + // CASE 3: Subquery with HAVING clause. Pull the HAVING clause above the join. + // Need to modify any operators below the join to pass through all columns + // referenced in the HAVING clause. 
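The "COUNT bug" handling above matters for subqueries like the one in this sketch: for a customer with no orders the scalar subquery must yield 0, not NULL, even though the rewritten LEFT OUTER join produces no matching row for that customer. Schema and data are invented for the example.

```scala
import org.apache.spark.sql.SparkSession

object CountBugExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("count-bug").getOrCreate()
    import spark.implicits._

    Seq((1, "alice"), (2, "bob")).toDF("id", "name").createOrReplaceTempView("customers")
    Seq((1, 100), (1, 50)).toDF("cid", "amount").createOrReplaceTempView("orders")

    // Correlated scalar subquery containing an aggregate. Customer 2 has no orders, so the
    // subquery is evaluated over zero tuples; COUNT(*) must still come back as 0 for it.
    val q = spark.sql(
      """SELECT c.id, c.name,
        |       (SELECT COUNT(*) FROM orders o WHERE o.cid = c.id) AS num_orders
        |FROM customers c""".stripMargin)

    // The optimized plan should contain a LEFT OUTER join plus the null-compensating
    // projection built by constructLeftJoins.
    q.explain(true)
    q.show()

    spark.stop()
  }
}
```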
+            var subqueryRoot: UnaryNode = aggNode
+            val havingInputs: Seq[NamedExpression] = aggNode.output
+
+            topPart.reverse.foreach {
+              case Project(projList, _) =>
+                subqueryRoot = Project(projList ++ havingInputs, subqueryRoot)
+              case s @ SubqueryAlias(alias, _) =>
+                subqueryRoot = SubqueryAlias(alias, subqueryRoot)
+              case op => sys.error(s"Unexpected operator $op in correlated subquery")
+            }
+
+            // CASE WHEN alwaysTrue IS NULL THEN resultOnZeroTups
+            //      WHEN NOT (original HAVING clause expr) THEN CAST(null AS <type of aggVal>)
+            //      ELSE (aggregate value) END AS (original column name)
+            val caseExpr = Alias(CaseWhen(Seq(
+              (IsNull(alwaysTrueRef), Literal.create(resultWithZeroTups.get, origOutput.dataType)),
+              (Not(havingNode.get.condition), Literal.create(null, aggValRef.dataType))),
+              aggValRef),
+              origOutput.name)(exprId = origOutput.exprId)
+
+            Project(
+              currentChild.output :+ caseExpr,
+              Join(currentChild,
+                Project(subqueryRoot.output :+ alwaysTrueExpr, subqueryRoot),
+                LeftOuter, conditions.reduceOption(And)))
+
+          }
+      }
+  }
+
+  /**
+   * Rewrite [[Filter]], [[Project]] and [[Aggregate]] plans containing correlated scalar
+   * subqueries.
+   */
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case a @ Aggregate(grouping, expressions, child) =>
+      val subqueries = ArrayBuffer.empty[ScalarSubquery]
+      val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
+      if (subqueries.nonEmpty) {
+        // We currently only allow correlated subqueries in an aggregate if they are part of the
+        // grouping expressions. As a result we need to replace all the scalar subqueries in the
+        // grouping expressions by their result.
+        val newGrouping = grouping.map { e =>
+          subqueries.find(_.semanticEquals(e)).map(_.query.output.head).getOrElse(e)
+        }
+        Aggregate(newGrouping, newExpressions, constructLeftJoins(child, subqueries))
+      } else {
+        a
+      }
+    case p @ Project(expressions, child) =>
+      val subqueries = ArrayBuffer.empty[ScalarSubquery]
+      val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
+      if (subqueries.nonEmpty) {
+        Project(newExpressions, constructLeftJoins(child, subqueries))
+      } else {
+        p
+      }
+    case f @ Filter(condition, child) =>
+      val subqueries = ArrayBuffer.empty[ScalarSubquery]
+      val newCondition = extractCorrelatedScalarSubqueries(condition, subqueries)
+      if (subqueries.nonEmpty) {
+        Project(f.output, Filter(newCondition, constructLeftJoins(child, subqueries)))
+      } else {
+        f
+      }
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
similarity index 93%
rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala
rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
index 2e30d83a60970..d6a39ecf53b86 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
@@ -15,7 +15,7 @@
  * limitations under the License.
*/ -package org.apache.spark.sql.catalyst.analysis +package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Complete} @@ -35,7 +35,7 @@ import org.apache.spark.sql.types.IntegerType * ("a", "ca1", "cb2", 5), * ("b", "ca1", "cb1", 13)) * .toDF("key", "cat1", "cat2", "value") - * data.registerTempTable("data") + * data.createOrReplaceTempView("data") * * val agg = data.groupBy($"key") * .agg( @@ -99,7 +99,7 @@ import org.apache.spark.sql.types.IntegerType * we could improve this in the current rule by applying more advanced expression canonicalization * techniques. */ -object DistinctAggregationRewriter extends Rule[LogicalPlan] { +object RewriteDistinctAggregates extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case a: Aggregate => rewrite(a) @@ -119,14 +119,16 @@ object DistinctAggregationRewriter extends Rule[LogicalPlan] { .filter(_.isDistinct) .groupBy(_.aggregateFunction.children.toSet) - // Aggregation strategy can handle the query with single distinct - if (distinctAggGroups.size > 1) { + // Check if the aggregates contains functions that do not support partial aggregation. + val existsNonPartial = aggExpressions.exists(!_.aggregateFunction.supportsPartial) + + // Aggregation strategy can handle queries with a single distinct group and partial aggregates. + if (distinctAggGroups.size > 1 || (distinctAggGroups.size == 1 && existsNonPartial)) { // Create the attributes for the grouping id and the group by clause. - val gid = - new AttributeReference("gid", IntegerType, false)(isGenerated = true) + val gid = AttributeReference("gid", IntegerType, nullable = false)(isGenerated = true) val groupByMap = a.groupingExpressions.collect { case ne: NamedExpression => ne -> ne.toAttribute - case e => e -> new AttributeReference(e.sql, e.dataType, e.nullable)() + case e => e -> AttributeReference(e.sql, e.dataType, e.nullable)() } val groupByAttrs = groupByMap.map(_._2) @@ -135,9 +137,7 @@ object DistinctAggregationRewriter extends Rule[LogicalPlan] { def patchAggregateFunctionChildren( af: AggregateFunction)( attrs: Expression => Expression): AggregateFunction = { - af.withNewChildren(af.children.map { - case afc => attrs(afc) - }).asInstanceOf[AggregateFunction] + af.withNewChildren(af.children.map(attrs)).asInstanceOf[AggregateFunction] } // Setup unique distinct aggregate children. @@ -265,5 +265,5 @@ object DistinctAggregationRewriter extends Rule[LogicalPlan] { // NamedExpression. This is done to prevent collisions between distinct and regular aggregate // children, in this case attribute reuse causes the input of the regular aggregate to bound to // the (nulled out) input of the distinct aggregate. - e -> new AttributeReference(e.sql, e.dataType, true)() + e -> AttributeReference(e.sql, e.dataType, nullable = true)() } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala new file mode 100644 index 0000000000000..7c667315870f5 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.catalog.SessionCatalog +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.types._ + + +/** + * Finds all [[RuntimeReplaceable]] expressions and replace them with the expressions that can + * be evaluated. This is mainly used to provide compatibility with other databases. + * For example, we use this to support "nvl" by replacing it with "coalesce". + */ +object ReplaceExpressions extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + case e: RuntimeReplaceable => e.replaced + } +} + + +/** + * Computes the current date and time to make sure we return the same result in a single query. + */ +object ComputeCurrentTime extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = { + val dateExpr = CurrentDate() + val timeExpr = CurrentTimestamp() + val currentDate = Literal.create(dateExpr.eval(EmptyRow), dateExpr.dataType) + val currentTime = Literal.create(timeExpr.eval(EmptyRow), timeExpr.dataType) + + plan transformAllExpressions { + case CurrentDate() => currentDate + case CurrentTimestamp() => currentTime + } + } +} + + +/** Replaces the expression of CurrentDatabase with the current database name. */ +case class GetCurrentDatabase(sessionCatalog: SessionCatalog) extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = { + plan transformAllExpressions { + case CurrentDatabase() => + Literal.create(sessionCatalog.getCurrentDatabase, StringType) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala new file mode 100644 index 0000000000000..08062bdba6822 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import scala.annotation.tailrec + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ + + +/** + * Reorder the joins and push all the conditions into join, so that the bottom ones have at least + * one condition. + * + * The order of joins will not be changed if all of them already have at least one condition. + */ +object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper { + + /** + * Join a list of plans together and push down the conditions into them. + * + * The joined plan are picked from left to right, prefer those has at least one join condition. + * + * @param input a list of LogicalPlans to join. + * @param conditions a list of condition for join. + */ + @tailrec + def createOrderedJoin(input: Seq[LogicalPlan], conditions: Seq[Expression]): LogicalPlan = { + assert(input.size >= 2) + if (input.size == 2) { + val (joinConditions, others) = conditions.partition( + e => !SubqueryExpression.hasCorrelatedSubquery(e)) + val join = Join(input(0), input(1), Inner, joinConditions.reduceLeftOption(And)) + if (others.nonEmpty) { + Filter(others.reduceLeft(And), join) + } else { + join + } + } else { + val left :: rest = input.toList + // find out the first join that have at least one join condition + val conditionalJoin = rest.find { plan => + val refs = left.outputSet ++ plan.outputSet + conditions + .filterNot(l => l.references.nonEmpty && canEvaluate(l, left)) + .filterNot(r => r.references.nonEmpty && canEvaluate(r, plan)) + .exists(_.references.subsetOf(refs)) + } + // pick the next one if no condition left + val right = conditionalJoin.getOrElse(rest.head) + + val joinedRefs = left.outputSet ++ right.outputSet + val (joinConditions, others) = conditions.partition( + e => e.references.subsetOf(joinedRefs) && !SubqueryExpression.hasCorrelatedSubquery(e)) + val joined = Join(left, right, Inner, joinConditions.reduceLeftOption(And)) + + // should not have reference to same logical plan + createOrderedJoin(Seq(joined) ++ rest.filterNot(_ eq right), others) + } + } + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case j @ ExtractFiltersAndInnerJoins(input, conditions) + if input.size > 2 && conditions.nonEmpty => + createOrderedJoin(input, conditions) + } +} + + +/** + * Elimination of outer joins, if the predicates can restrict the result sets so that + * all null-supplying rows are eliminated + * + * - full outer -> inner if both sides have such predicates + * - left outer -> inner if the right side has such predicates + * - right outer -> inner if the left side has such predicates + * - full outer -> left outer if only the left side has such predicates + * - full outer -> right outer if only the right side has such predicates + * + * This rule should be executed before pushing down the Filter + */ +object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper { + + /** + * Returns whether the expression returns null or false when all inputs are nulls. 
+ */ + private def canFilterOutNull(e: Expression): Boolean = { + if (!e.deterministic || SubqueryExpression.hasCorrelatedSubquery(e)) return false + val attributes = e.references.toSeq + val emptyRow = new GenericInternalRow(attributes.length) + val boundE = BindReferences.bindReference(e, attributes) + if (boundE.find(_.isInstanceOf[Unevaluable]).isDefined) return false + val v = boundE.eval(emptyRow) + v == null || v == false + } + + private def buildNewJoinType(filter: Filter, join: Join): JoinType = { + val conditions = splitConjunctivePredicates(filter.condition) ++ filter.constraints + val leftConditions = conditions.filter(_.references.subsetOf(join.left.outputSet)) + val rightConditions = conditions.filter(_.references.subsetOf(join.right.outputSet)) + + val leftHasNonNullPredicate = leftConditions.exists(canFilterOutNull) + val rightHasNonNullPredicate = rightConditions.exists(canFilterOutNull) + + join.joinType match { + case RightOuter if leftHasNonNullPredicate => Inner + case LeftOuter if rightHasNonNullPredicate => Inner + case FullOuter if leftHasNonNullPredicate && rightHasNonNullPredicate => Inner + case FullOuter if leftHasNonNullPredicate => LeftOuter + case FullOuter if rightHasNonNullPredicate => RightOuter + case o => o + } + } + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case f @ Filter(condition, j @ Join(_, _, RightOuter | LeftOuter | FullOuter, _)) => + val newJoinType = buildNewJoinType(f, j) + if (j.joinType == newJoinType) f else Filter(condition, j.copy(joinType = newJoinType)) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala new file mode 100644 index 0000000000000..8a25cee614c6f --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.types.StructType + +/* + * This file defines optimization rules related to object manipulation (for the Dataset API). + */ + + +/** + * Removes cases where we are unnecessarily going between the object and serialized (InternalRow) + * representation of data item. For example back to back map operations. 
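A sketch of the null-filtering behaviour that the new EliminateOuterJoin rule above keys on: a WHERE predicate on the null-supplying side discards every null-extended row, so the left outer join can be narrowed to an inner join. Column names and data are invented for the example.

```scala
import org.apache.spark.sql.SparkSession

object EliminateOuterJoinExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("eliminate-outer-join").getOrCreate()
    import spark.implicits._

    val left = Seq((1, "a"), (2, "b")).toDF("id", "l")
    val right = Seq((1, "x")).toDF("id", "r")

    // The filter on the right-side column r evaluates to null (hence false) on every
    // null-extended row, so no row produced only by the outer join can survive it.
    val q = left.join(right, Seq("id"), "left_outer").where($"r" === "x")

    // The optimized plan should show an Inner join instead of LeftOuter.
    q.explain(true)
    q.show()

    spark.stop()
  }
}
```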
+ */ +object EliminateSerialization extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case d @ DeserializeToObject(_, _, s: SerializeFromObject) + if d.outputObjectType == s.inputObjectType => + // Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`. + // We will remove it later in RemoveAliasOnlyProject rule. + val objAttr = + Alias(s.child.output.head, s.child.output.head.name)(exprId = d.output.head.exprId) + Project(objAttr :: Nil, s.child) + case a @ AppendColumns(_, _, _, s: SerializeFromObject) + if a.deserializer.dataType == s.inputObjectType => + AppendColumnsWithObject(a.func, s.serializer, a.serializer, s.child) + } +} + + +/** + * Typed [[Filter]] is by default surrounded by a [[DeserializeToObject]] beneath it and a + * [[SerializeFromObject]] above it. If these serializations can't be eliminated, we should embed + * the deserializer in filter condition to save the extra serialization at last. + */ +object EmbedSerializerInFilter extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case s @ SerializeFromObject(_, Filter(condition, d: DeserializeToObject)) + // SPARK-15632: Conceptually, filter operator should never introduce schema change. This + // optimization rule also relies on this assumption. However, Dataset typed filter operator + // does introduce schema changes in some cases. Thus, we only enable this optimization when + // + // 1. either input and output schemata are exactly the same, or + // 2. both input and output schemata are single-field schema and share the same type. + // + // The 2nd case is included because encoders for primitive types always have only a single + // field with hard-coded field name "value". + // TODO Cleans this up after fixing SPARK-15632. + if s.schema == d.child.schema || samePrimitiveType(s.schema, d.child.schema) => + + val numObjects = condition.collect { + case a: Attribute if a == d.output.head => a + }.length + + if (numObjects > 1) { + // If the filter condition references the object more than one times, we should not embed + // deserializer in it as the deserialization will happen many times and slow down the + // execution. + // TODO: we can still embed it if we can make sure subexpression elimination works here. + s + } else { + val newCondition = condition transform { + case a: Attribute if a == d.output.head => d.deserializer + } + val filter = Filter(newCondition, d.child) + + // Adds an extra Project here, to preserve the output expr id of `SerializeFromObject`. + // We will remove it later in RemoveAliasOnlyProject rule. + val objAttrs = filter.output.zip(s.output).map { case (fout, sout) => + Alias(fout, fout.name)(exprId = sout.exprId) + } + Project(objAttrs, filter) + } + } + + def samePrimitiveType(lhs: StructType, rhs: StructType): Boolean = { + (lhs, rhs) match { + case (StructType(Array(f1)), StructType(Array(f2))) => f1.dataType == f2.dataType + case _ => false + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 1f923f47dd0ba..e5ecefde41a49 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -14,6 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
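For the EmbedSerializerInFilter rule above, a small public-API sketch of the shape it rewrites (hypothetical data; assumes the same `spark` session and implicits as the earlier sketches):

```scala
// Assumes the same `spark` session and implicits as above.
val people = Seq(("alice", 29), ("bob", 17)).toDS()

// A typed filter is planned as SerializeFromObject(Filter(cond, DeserializeToObject(child))).
// When the input and output schemata match, the deserializer can be embedded into the
// filter condition so the surrounding (de)serialization pair can be removed.
val adults = people.filter(_._2 >= 18)
adults.explain(true)
```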
*/ + package org.apache.spark.sql.catalyst.parser import java.sql.{Date, Timestamp} @@ -25,7 +26,8 @@ import org.antlr.v4.runtime.{ParserRuleContext, Token} import org.antlr.v4.runtime.tree.{ParseTree, RuleNode, TerminalNode} import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier} +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ @@ -81,49 +83,6 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { * ******************************************************************************************** */ protected def plan(tree: ParserRuleContext): LogicalPlan = typedVisit(tree) - /** - * Make sure we do not try to create a plan for a native command. - */ - override def visitExecuteNativeCommand(ctx: ExecuteNativeCommandContext): LogicalPlan = null - - /** - * Create a plan for a SHOW FUNCTIONS command. - */ - override def visitShowFunctions(ctx: ShowFunctionsContext): LogicalPlan = withOrigin(ctx) { - import ctx._ - if (qualifiedName != null) { - val names = qualifiedName().identifier().asScala.map(_.getText).toList - names match { - case db :: name :: Nil => - ShowFunctions(Some(db), Some(name)) - case name :: Nil => - ShowFunctions(None, Some(name)) - case _ => - throw new ParseException("SHOW FUNCTIONS unsupported name", ctx) - } - } else if (pattern != null) { - ShowFunctions(None, Some(string(pattern))) - } else { - ShowFunctions(None, None) - } - } - - /** - * Create a plan for a DESCRIBE FUNCTION command. - */ - override def visitDescribeFunction(ctx: DescribeFunctionContext): LogicalPlan = withOrigin(ctx) { - import ctx._ - val functionName = - if (describeFuncName.STRING() != null) { - string(describeFuncName.STRING()) - } else if (describeFuncName.qualifiedName() != null) { - describeFuncName.qualifiedName().identifier().asScala.map(_.getText).mkString(".") - } else { - describeFuncName.getText - } - DescribeFunction(functionName, EXTENDED != null) - } - /** * Create a top-level plan with Common Table Expressions. */ @@ -132,19 +91,12 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { // Apply CTEs query.optional(ctx.ctes) { - val ctes = ctx.ctes.namedQuery.asScala.map { - case nCtx => - val namedQuery = visitNamedQuery(nCtx) - (namedQuery.alias, namedQuery) + val ctes = ctx.ctes.namedQuery.asScala.map { nCtx => + val namedQuery = visitNamedQuery(nCtx) + (namedQuery.alias, namedQuery) } - // Check for duplicate names. - ctes.groupBy(_._1).filter(_._2.size > 1).foreach { - case (name, _) => - throw new ParseException( - s"Name '$name' is used for multiple common table expressions", ctx) - } - + checkDuplicateKeys(ctes, ctx) With(query, ctes.toMap) } } @@ -180,7 +132,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { // Build the insert clauses. 
val inserts = ctx.multiInsertQueryBody.asScala.map { body => - assert(body.querySpecification.fromClause == null, + validate(body.querySpecification.fromClause == null, "Multi-Insert queries cannot have a FROM clause in their individual SELECT statements", body) @@ -219,6 +171,12 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { val tableIdent = visitTableIdentifier(ctx.tableIdentifier) val partitionKeys = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty) + val dynamicPartitionKeys = partitionKeys.filter(_._2.isEmpty) + if (ctx.EXISTS != null && dynamicPartitionKeys.nonEmpty) { + throw new ParseException(s"Dynamic partitions do not support IF NOT EXISTS. Specified " + + "partitions with value: " + dynamicPartitionKeys.keys.mkString("[", ",", "]"), ctx) + } + InsertIntoTable( UnresolvedRelation(tableIdent, None), partitionKeys, @@ -232,11 +190,14 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { */ override def visitPartitionSpec( ctx: PartitionSpecContext): Map[String, Option[String]] = withOrigin(ctx) { - ctx.partitionVal.asScala.map { pVal => + val parts = ctx.partitionVal.asScala.map { pVal => val name = pVal.identifier.getText.toLowerCase val value = Option(pVal.constant).map(visitStringConstant) name -> value - }.toMap + } + // Check for duplicate partition columns in one spec. + checkDuplicateKeys(parts, ctx) + parts.toMap } /** @@ -438,7 +399,11 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { * separated) relations here, these get converted into a single plan by condition-less inner join. */ override def visitFromClause(ctx: FromClauseContext): LogicalPlan = withOrigin(ctx) { - val from = ctx.relation.asScala.map(plan).reduceLeft(Join(_, _, Inner, None)) + val from = ctx.relation.asScala.foldLeft(null: LogicalPlan) { (left, relation) => + val right = plan(relation.relationPrimary) + val join = right.optionalMap(left)(Join(_, _, Inner, None)) + withJoinRelations(join, relation) + } ctx.lateralView.asScala.foldLeft(from)(withGenerate) } @@ -554,19 +519,8 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { query: LogicalPlan, ctx: LateralViewContext): LogicalPlan = withOrigin(ctx) { val expressions = expressionList(ctx.expression) - - // Create the generator. - val generator = ctx.qualifiedName.getText.toLowerCase match { - case "explode" if expressions.size == 1 => - Explode(expressions.head) - case "json_tuple" => - JsonTuple(expressions) - case name => - UnresolvedGenerator(visitFunctionName(ctx.qualifiedName), expressions) - } - Generate( - generator, + UnresolvedGenerator(visitFunctionName(ctx.qualifiedName), expressions), join = true, outer = ctx.OUTER != null, Some(ctx.tblName.getText.toLowerCase), @@ -575,54 +529,49 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { } /** - * Create a joins between two or more logical plans. + * Create a single relation referenced in a FROM claused. This method is used when a part of the + * join condition is nested, for example: + * {{{ + * select * from t1 join (t2 cross join t3) on col1 = col2 + * }}} */ - override def visitJoinRelation(ctx: JoinRelationContext): LogicalPlan = withOrigin(ctx) { - /** Build a join between two plans. 
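A hedged illustration of the dynamic-partition check added above (table and column names are hypothetical; the second statement should fail at parse time with the message built in the diff):

```scala
// Static partition value: IF NOT EXISTS is well defined and accepted.
spark.sql(
  "INSERT OVERWRITE TABLE logs PARTITION (dt='2016-01-01') IF NOT EXISTS SELECT * FROM staging")

// Dynamic partition (no value for `dt`): rejected with a ParseException like
// "Dynamic partitions do not support IF NOT EXISTS. Specified partitions with value: [dt]".
spark.sql(
  "INSERT OVERWRITE TABLE logs PARTITION (dt) IF NOT EXISTS SELECT * FROM staging")
```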
*/ - def join(ctx: JoinRelationContext, left: LogicalPlan, right: LogicalPlan): Join = { - val baseJoinType = ctx.joinType match { - case null => Inner - case jt if jt.FULL != null => FullOuter - case jt if jt.SEMI != null => LeftSemi - case jt if jt.ANTI != null => LeftAnti - case jt if jt.LEFT != null => LeftOuter - case jt if jt.RIGHT != null => RightOuter - case _ => Inner - } + override def visitRelation(ctx: RelationContext): LogicalPlan = withOrigin(ctx) { + withJoinRelations(plan(ctx.relationPrimary), ctx) + } - // Resolve the join type and join condition - val (joinType, condition) = Option(ctx.joinCriteria) match { - case Some(c) if c.USING != null => - val columns = c.identifier.asScala.map { column => - UnresolvedAttribute.quoted(column.getText) - } - (UsingJoin(baseJoinType, columns), None) - case Some(c) if c.booleanExpression != null => - (baseJoinType, Option(expression(c.booleanExpression))) - case None if ctx.NATURAL != null => - (NaturalJoin(baseJoinType), None) - case None => - (baseJoinType, None) - } - Join(left, right, joinType, condition) - } + /** + * Join one more [[LogicalPlan]]s to the current logical plan. + */ + private def withJoinRelations(base: LogicalPlan, ctx: RelationContext): LogicalPlan = { + ctx.joinRelation.asScala.foldLeft(base) { (left, join) => + withOrigin(join) { + val baseJoinType = join.joinType match { + case null => Inner + case jt if jt.FULL != null => FullOuter + case jt if jt.SEMI != null => LeftSemi + case jt if jt.ANTI != null => LeftAnti + case jt if jt.LEFT != null => LeftOuter + case jt if jt.RIGHT != null => RightOuter + case _ => Inner + } - // Handle all consecutive join clauses. ANTLR produces a right nested tree in which the the - // first join clause is at the top. However fields of previously referenced tables can be used - // in following join clauses. The tree needs to be reversed in order to make this work. - var result = plan(ctx.left) - var current = ctx - while (current != null) { - current.right match { - case right: JoinRelationContext => - result = join(current, result, plan(right.left)) - current = right - case right => - result = join(current, result, plan(right)) - current = null + // Resolve the join type and join condition + val (joinType, condition) = Option(join.joinCriteria) match { + case Some(c) if c.USING != null => + val columns = c.identifier.asScala.map { column => + UnresolvedAttribute.quoted(column.getText) + } + (UsingJoin(baseJoinType, columns), None) + case Some(c) if c.booleanExpression != null => + (baseJoinType, Option(expression(c.booleanExpression))) + case None if join.NATURAL != null => + (NaturalJoin(baseJoinType), None) + case None => + (baseJoinType, None) + } + Join(left, plan(join.right), joinType, condition) } } - result } /** @@ -641,7 +590,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { // function takes X PERCENT as the input and the range of X is [0, 100], we need to // adjust the fraction. 
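The visitFromClause/withJoinRelations rewrite above folds comma-separated relations (and any trailing join clauses) into a left-deep tree; a minimal stand-alone sketch of that fold, using hypothetical placeholder node types rather than Catalyst's:

```scala
object FromClauseFoldSketch extends App {
  sealed trait Plan
  case class Table(name: String) extends Plan
  case class InnerJoin(left: Plan, right: Plan) extends Plan

  // "FROM a, b, c" becomes InnerJoin(InnerJoin(Table("a"), Table("b")), Table("c")):
  // each relation is joined onto the plan built so far, from left to right.
  def fromClause(relations: Seq[Plan]): Plan =
    relations.reduceLeft(InnerJoin(_, _))

  println(fromClause(Seq(Table("a"), Table("b"), Table("c"))))
}
```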
val eps = RandomSampler.roundingEpsilon - assert(fraction >= 0.0 - eps && fraction <= 1.0 + eps, + validate(fraction >= 0.0 - eps && fraction <= 1.0 + eps, s"Sampling fraction ($fraction) must be on interval [0, 1]", ctx) Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, query)(true) @@ -655,8 +604,18 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { val fraction = ctx.percentage.getText.toDouble sample(fraction / 100.0d) + case SqlBaseParser.BYTELENGTH_LITERAL => + throw new ParseException( + "TABLESAMPLE(byteLengthLiteral) is not supported", ctx) + case SqlBaseParser.BUCKET if ctx.ON != null => - throw new ParseException("TABLESAMPLE(BUCKET x OUT OF y ON id) is not supported", ctx) + if (ctx.identifier != null) { + throw new ParseException( + "TABLESAMPLE(BUCKET x OUT OF y ON colname) is not supported", ctx) + } else { + throw new ParseException( + "TABLESAMPLE(BUCKET x OUT OF y ON function) is not supported", ctx) + } case SqlBaseParser.BUCKET => sample(ctx.numerator.getText.toDouble / ctx.denominator.getText.toDouble) @@ -688,48 +647,41 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { override def visitTableName(ctx: TableNameContext): LogicalPlan = withOrigin(ctx) { val table = UnresolvedRelation( visitTableIdentifier(ctx.tableIdentifier), - Option(ctx.identifier).map(_.getText)) + Option(ctx.strictIdentifier).map(_.getText)) table.optionalMap(ctx.sample)(withSample) } + /** + * Create a table-valued function call with arguments, e.g. range(1000) + */ + override def visitTableValuedFunction(ctx: TableValuedFunctionContext) + : LogicalPlan = withOrigin(ctx) { + UnresolvedTableValuedFunction(ctx.identifier.getText, ctx.expression.asScala.map(expression)) + } + /** * Create an inline table (a virtual table in Hive parlance). */ override def visitInlineTable(ctx: InlineTableContext): LogicalPlan = withOrigin(ctx) { // Get the backing expressions. - val expressions = ctx.expression.asScala.map { eCtx => - val e = expression(eCtx) - assert(e.foldable, "All expressions in an inline table must be constants.", eCtx) - e - } - - // Validate and evaluate the rows. - val (structType, structConstructor) = expressions.head.dataType match { - case st: StructType => - (st, (e: Expression) => e) - case dt => - val st = CreateStruct(Seq(expressions.head)).dataType - (st, (e: Expression) => CreateStruct(Seq(e))) - } - val rows = expressions.map { - case expression => - val safe = Cast(structConstructor(expression), structType) - safe.eval().asInstanceOf[InternalRow] + val rows = ctx.expression.asScala.map { e => + expression(e) match { + // inline table comes in two styles: + // style 1: values (1), (2), (3) -- multiple columns are supported + // style 2: values 1, 2, 3 -- only a single column is supported here + case CreateStruct(children) => children // style 1 + case child => Seq(child) // style 2 + } } - // Construct attributes. - val baseAttributes = structType.toAttributes.map(_.withNullability(true)) - val attributes = if (ctx.identifierList != null) { - val aliases = visitIdentifierList(ctx.identifierList) - assert(aliases.size == baseAttributes.size, - "Number of aliases must match the number of fields in an inline table.", ctx) - baseAttributes.zip(aliases).map(p => p._1.withName(p._2)) + val aliases = if (ctx.identifierList != null) { + visitIdentifierList(ctx.identifierList) } else { - baseAttributes + Seq.tabulate(rows.head.size)(i => s"col${i + 1}") } - // Create plan and add an alias if a name has been defined. 
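The inline-table change above defers type coercion to the analyzer by producing an UnresolvedInlineTable; a hedged SQL-level illustration of the two styles described in its comments (default column names col1, col2, ... come from the aliasing logic shown above; exact syntax support depends on the grammar version):

```scala
// Style 1: each row is a struct, multiple columns, explicit aliases.
spark.sql("SELECT * FROM VALUES (1, 'a'), (2, 'b') AS data(id, name)").show()

// Style 2: bare expressions, a single column, default name `col1`.
spark.sql("SELECT * FROM VALUES 1, 2, 3").show()
```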
- LocalRelation(attributes, rows).optionalMap(ctx.identifier)(aliasPlan) + val table = UnresolvedInlineTable(aliases, rows) + table.optionalMap(ctx.identifier)(aliasPlan) } /** @@ -738,7 +690,9 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { * hooks. */ override def visitAliasedRelation(ctx: AliasedRelationContext): LogicalPlan = withOrigin(ctx) { - plan(ctx.relation).optionalMap(ctx.sample)(withSample).optionalMap(ctx.identifier)(aliasPlan) + plan(ctx.relation) + .optionalMap(ctx.sample)(withSample) + .optionalMap(ctx.strictIdentifier)(aliasPlan) } /** @@ -747,13 +701,15 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { * hooks. */ override def visitAliasedQuery(ctx: AliasedQueryContext): LogicalPlan = withOrigin(ctx) { - plan(ctx.queryNoWith).optionalMap(ctx.sample)(withSample).optionalMap(ctx.identifier)(aliasPlan) + plan(ctx.queryNoWith) + .optionalMap(ctx.sample)(withSample) + .optionalMap(ctx.strictIdentifier)(aliasPlan) } /** * Create an alias (SubqueryAlias) for a LogicalPlan. */ - private def aliasPlan(alias: IdentifierContext, plan: LogicalPlan): LogicalPlan = { + private def aliasPlan(alias: ParserRuleContext, plan: LogicalPlan): LogicalPlan = { SubqueryAlias(alias.getText, plan) } @@ -787,7 +743,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { * ******************************************************************************************** */ /** * Create an expression from the given context. This method just passes the context on to the - * vistor and only takes care of typing (We assume that the visitor returns an Expression here). + * visitor and only takes care of typing (We assume that the visitor returns an Expression here). */ protected def expression(ctx: ParserRuleContext): Expression = typedVisit(ctx) @@ -1058,6 +1014,19 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { } } + /** + * Create a current timestamp/date expression. These are different from regular function because + * they do not require the user to specify braces when calling them. + */ + override def visitTimeFunctionCall(ctx: TimeFunctionCallContext): Expression = withOrigin(ctx) { + ctx.name.getType match { + case SqlBaseParser.CURRENT_DATE => + CurrentDate() + case SqlBaseParser.CURRENT_TIMESTAMP => + CurrentTimestamp() + } + } + /** * Create a function database (optional) and name pair. */ @@ -1112,7 +1081,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { // We currently only allow foldable integers. def value: Int = { val e = expression(ctx.expression) - assert(e.resolved && e.foldable && e.dataType == IntegerType, + validate(e.resolved && e.foldable && e.dataType == IntegerType, "Frame bound value must be a constant integer.", ctx) e.eval().asInstanceOf[Int] @@ -1159,7 +1128,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { * }}} */ override def visitSimpleCase(ctx: SimpleCaseContext): Expression = withOrigin(ctx) { - val e = expression(ctx.valueExpression) + val e = expression(ctx.value) val branches = ctx.whenClause.asScala.map { wCtx => (EqualTo(e, expression(wCtx.condition)), expression(wCtx.result)) } @@ -1303,10 +1272,17 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { } /** Create a numeric literal expression. 
*/ - private def numericLiteral(ctx: NumberContext)(f: String => Any): Literal = withOrigin(ctx) { - val raw = ctx.getText + private def numericLiteral + (ctx: NumberContext, minValue: BigDecimal, maxValue: BigDecimal, typeName: String) + (converter: String => Any): Literal = withOrigin(ctx) { + val rawStrippedQualifier = ctx.getText.substring(0, ctx.getText.length - 1) try { - Literal(f(raw.substring(0, raw.length - 1))) + val rawBigDecimal = BigDecimal(rawStrippedQualifier) + if (rawBigDecimal < minValue || rawBigDecimal > maxValue) { + throw new ParseException(s"Numeric literal ${rawStrippedQualifier} does not " + + s"fit in range [${minValue}, ${maxValue}] for type ${typeName}", ctx) + } + Literal(converter(rawStrippedQualifier)) } catch { case e: NumberFormatException => throw new ParseException(e.getMessage, ctx) @@ -1316,29 +1292,42 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { /** * Create a Byte Literal expression. */ - override def visitTinyIntLiteral(ctx: TinyIntLiteralContext): Literal = numericLiteral(ctx) { - _.toByte + override def visitTinyIntLiteral(ctx: TinyIntLiteralContext): Literal = { + numericLiteral(ctx, Byte.MinValue, Byte.MaxValue, ByteType.simpleString)(_.toByte) } /** * Create a Short Literal expression. */ - override def visitSmallIntLiteral(ctx: SmallIntLiteralContext): Literal = numericLiteral(ctx) { - _.toShort + override def visitSmallIntLiteral(ctx: SmallIntLiteralContext): Literal = { + numericLiteral(ctx, Short.MinValue, Short.MaxValue, ShortType.simpleString)(_.toShort) } /** * Create a Long Literal expression. */ - override def visitBigIntLiteral(ctx: BigIntLiteralContext): Literal = numericLiteral(ctx) { - _.toLong + override def visitBigIntLiteral(ctx: BigIntLiteralContext): Literal = { + numericLiteral(ctx, Long.MinValue, Long.MaxValue, LongType.simpleString)(_.toLong) } /** * Create a Double Literal expression. */ - override def visitDoubleLiteral(ctx: DoubleLiteralContext): Literal = numericLiteral(ctx) { - _.toDouble + override def visitDoubleLiteral(ctx: DoubleLiteralContext): Literal = { + numericLiteral(ctx, Double.MinValue, Double.MaxValue, DoubleType.simpleString)(_.toDouble) + } + + /** + * Create a BigDecimal Literal expression. 
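The new numericLiteral helper above validates that a typed literal fits its target type before conversion; a hedged illustration of accepted and rejected forms (the Y suffix denotes a TINYINT literal):

```scala
spark.sql("SELECT 100Y").show()   // fits in [-128, 127], parsed as a ByteType literal

// Out of range: fails with a ParseException like
// "Numeric literal 300 does not fit in range [-128, 127] for type tinyint".
spark.sql("SELECT 300Y")
```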
+ */ + override def visitBigDecimalLiteral(ctx: BigDecimalLiteralContext): Literal = { + val raw = ctx.getText.substring(0, ctx.getText.length - 2) + try { + Literal(BigDecimal(raw).underlying()) + } catch { + case e: AnalysisException => + throw new ParseException(e.message, ctx) + } } /** @@ -1365,7 +1354,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { */ override def visitInterval(ctx: IntervalContext): Literal = withOrigin(ctx) { val intervals = ctx.intervalField.asScala.map(visitIntervalField) - assert(intervals.nonEmpty, "at least one time unit should be given for interval literal", ctx) + validate(intervals.nonEmpty, "at least one time unit should be given for interval literal", ctx) Literal(intervals.reduce(_.add(_))) } @@ -1392,7 +1381,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { case (from, Some(t)) => throw new ParseException(s"Intervals FROM $from TO $t are not supported.", ctx) } - assert(interval != null, "No interval can be constructed", ctx) + validate(interval != null, "No interval can be constructed", ctx) interval } catch { // Handle Exceptions thrown by CalendarInterval diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala index d0132529f18ea..d687a85c18b63 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala @@ -53,19 +53,15 @@ abstract class AbstractSqlParser extends ParserInterface with Logging { override def parsePlan(sqlText: String): LogicalPlan = parse(sqlText) { parser => astBuilder.visitSingleStatement(parser.singleStatement()) match { case plan: LogicalPlan => plan - case _ => nativeCommand(sqlText) + case _ => + val position = Origin(None, None) + throw new ParseException(Option(sqlText), "Unsupported SQL statement", position, position) } } - /** Get the builder (visitor) which converts a ParseTree into a AST. */ + /** Get the builder (visitor) which converts a ParseTree into an AST. */ protected def astBuilder: AstBuilder - /** Create a native command, or fail when this is not supported. */ - protected def nativeCommand(sqlText: String): LogicalPlan = { - val position = Origin(None, None) - throw new ParseException(Option(sqlText), "Unsupported SQL statement", position, position) - } - protected def parse[T](command: String)(toResult: SqlBaseParser => T): T = { logInfo(s"Parsing command: $command") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala index 64713cddf4e0d..9506de229b878 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala @@ -39,8 +39,15 @@ object ParserUtils { stream.getText(Interval.of(0, stream.size())) } - def parseException(message: String, ctx: ParserRuleContext): ParseException = { - new ParseException(s"Operation not allowed: $message", ctx) + def operationNotAllowed(message: String, ctx: ParserRuleContext): Nothing = { + throw new ParseException(s"Operation not allowed: $message", ctx) + } + + /** Check if duplicate keys exist in a set of key-value pairs. 
*/ + def checkDuplicateKeys[T](keyPairs: Seq[(String, T)], ctx: ParserRuleContext): Unit = { + keyPairs.groupBy(_._1).filter(_._2.size > 1).foreach { case (key, _) => + throw new ParseException(s"Found duplicate keys '$key'.", ctx) + } } /** Get the code that creates the given node. */ @@ -70,8 +77,8 @@ object ParserUtils { Origin(Option(token.getLine), Option(token.getCharPositionInLine)) } - /** Assert if a condition holds. If it doesn't throw a parse exception. */ - def assert(f: => Boolean, message: String, ctx: ParserRuleContext): Unit = { + /** Validate the condition. If it doesn't throw a parse exception. */ + def validate(f: => Boolean, message: String, ctx: ParserRuleContext): Unit = { if (!f) { throw new ParseException(message, ctx) } @@ -185,9 +192,7 @@ object ParserUtils { * Map a [[LogicalPlan]] to another [[LogicalPlan]] if the passed context exists using the * passed function. The original plan is returned when the context does not exist. */ - def optionalMap[C <: ParserRuleContext]( - ctx: C)( - f: (C, LogicalPlan) => LogicalPlan): LogicalPlan = { + def optionalMap[C](ctx: C)(f: (C, LogicalPlan) => LogicalPlan): LogicalPlan = { if (ctx != null) { f(ctx, plan) } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala index 516b41cb138b4..327a0481acf7e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala @@ -22,18 +22,27 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.TreeNode /** - * Given a [[plans.logical.LogicalPlan LogicalPlan]], returns a list of `PhysicalPlan`s that can + * Given a [[LogicalPlan]], returns a list of `PhysicalPlan`s that can * be used for execution. If this strategy does not apply to the give logical operation then an * empty list should be returned. */ abstract class GenericStrategy[PhysicalPlan <: TreeNode[PhysicalPlan]] extends Logging { + + /** + * Returns a placeholder for a physical plan that executes `plan`. This placeholder will be + * filled in automatically by the QueryPlanner using the other execution strategies that are + * available. + */ + protected def planLater(plan: LogicalPlan): PhysicalPlan + def apply(plan: LogicalPlan): Seq[PhysicalPlan] } /** - * Abstract class for transforming [[plans.logical.LogicalPlan LogicalPlan]]s into physical plans. - * Child classes are responsible for specifying a list of [[Strategy]] objects that each of which - * can return a list of possible physical plan options. If a given strategy is unable to plan all + * Abstract class for transforming [[LogicalPlan]]s into physical plans. + * Child classes are responsible for specifying a list of [[GenericStrategy]] objects that + * each of which can return a list of possible physical plan options. + * If a given strategy is unable to plan all * of the remaining operators in the tree, it can call [[planLater]], which returns a placeholder * object that will be filled in using other available strategies. * @@ -46,13 +55,6 @@ abstract class QueryPlanner[PhysicalPlan <: TreeNode[PhysicalPlan]] { /** A list of execution strategies that can be used by the planner */ def strategies: Seq[GenericStrategy[PhysicalPlan]] - /** - * Returns a placeholder for a physical plan that executes `plan`. 
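With planLater moved onto GenericStrategy, each strategy can plan the node it understands and leave placeholders for the rest; a minimal stand-alone sketch of that pattern, using toy node types rather than Catalyst's:

```scala
object PlanLaterSketch extends App {
  // Toy logical and physical nodes, standing in for Catalyst's.
  sealed trait Logical
  case class LFilter(cond: String, child: Logical) extends Logical
  case class LScan(table: String) extends Logical

  sealed trait Physical
  case class PFilter(cond: String, child: Physical) extends Physical
  case class PScan(table: String) extends Physical
  case class PlaceHolder(plan: Logical) extends Physical

  // A strategy plans the node it understands and calls planLater for its children,
  // leaving a placeholder the planner later fills in with whatever strategy applies there.
  abstract class Strategy {
    protected def planLater(plan: Logical): Physical = PlaceHolder(plan)
    def apply(plan: Logical): Seq[Physical]
  }

  object FilterStrategy extends Strategy {
    def apply(plan: Logical): Seq[Physical] = plan match {
      case LFilter(cond, child) => Seq(PFilter(cond, planLater(child)))
      case _ => Nil
    }
  }

  // Prints a physical filter wrapping a placeholder for the unplanned scan.
  println(FilterStrategy(LFilter("a > 1", LScan("t"))))
}
```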
This placeholder will be - * filled in automatically by the QueryPlanner using the other execution strategies that are - * available. - */ - protected def planLater(plan: LogicalPlan): PhysicalPlan = this.plan(plan).next() - def plan(plan: LogicalPlan): Iterator[PhysicalPlan] = { // Obviously a lot to do here still... val iter = strategies.view.flatMap(_(plan)).toIterator diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 00656191354f2..d952c9eeabaeb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -112,6 +112,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { // as join keys. val predicates = condition.map(splitConjunctivePredicates).getOrElse(Nil) val joinKeys = predicates.flatMap { + case EqualTo(l, r) if l.references.isEmpty || r.references.isEmpty => None case EqualTo(l, r) if canEvaluate(l, left) && canEvaluate(r, right) => Some((l, r)) case EqualTo(l, r) if canEvaluate(l, right) && canEvaluate(r, left) => Some((r, l)) // Replace null with default value for joining key, then those rows with null in it could @@ -125,6 +126,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { case other => None } val otherPredicates = predicates.filterNot { + case EqualTo(l, r) if l.references.isEmpty || r.references.isEmpty => false case EqualTo(l, r) => canEvaluate(l, left) && canEvaluate(r, right) || canEvaluate(l, right) && canEvaluate(r, left) @@ -227,8 +229,8 @@ object IntegerIndex { * - Unnamed grouping expressions are named so that they can be referred to across phases of * aggregation * - Aggregations that appear multiple times are deduplicated. - * - The compution of the aggregations themselves is separated from the final result. For example, - * the `count` in `count + 1` will be split into an [[AggregateExpression]] and a final + * - The computation of the aggregations themselves is separated from the final result. For + * example, the `count` in `count + 1` will be split into an [[AggregateExpression]] and a final * computation that computes `count.resultAttribute + 1`. */ object PhysicalAggregation { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index d4447ca32d5a0..41c4e0085a93e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -35,7 +35,8 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT .union(inferAdditionalConstraints(constraints)) .union(constructIsNotNullConstraints(constraints)) .filter(constraint => - constraint.references.nonEmpty && constraint.references.subsetOf(outputSet)) + constraint.references.nonEmpty && constraint.references.subsetOf(outputSet) && + constraint.deterministic) } /** @@ -67,26 +68,104 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT case _ => Seq.empty[Attribute] } + // Collect aliases from expressions, so we may avoid producing recursive constraints. 
+ private lazy val aliasMap = AttributeMap( + (expressions ++ children.flatMap(_.expressions)).collect { + case a: Alias => (a.toAttribute, a.child) + }) + /** * Infers an additional set of constraints from a given set of equality constraints. * For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an - * additional constraint of the form `b = 5` + * additional constraint of the form `b = 5`. + * + * [SPARK-17733] We explicitly prevent producing recursive constraints of the form `a = f(a, b)` + * as they are often useless and can lead to a non-converging set of constraints. */ private def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = { + val constraintClasses = generateEquivalentConstraintClasses(constraints) + var inferredConstraints = Set.empty[Expression] constraints.foreach { case eq @ EqualTo(l: Attribute, r: Attribute) => - inferredConstraints ++= (constraints - eq).map(_ transform { - case a: Attribute if a.semanticEquals(l) => r + val candidateConstraints = constraints - eq + inferredConstraints ++= candidateConstraints.map(_ transform { + case a: Attribute if a.semanticEquals(l) && + !isRecursiveDeduction(r, constraintClasses) => r }) - inferredConstraints ++= (constraints - eq).map(_ transform { - case a: Attribute if a.semanticEquals(r) => l + inferredConstraints ++= candidateConstraints.map(_ transform { + case a: Attribute if a.semanticEquals(r) && + !isRecursiveDeduction(l, constraintClasses) => l }) case _ => // No inference } inferredConstraints -- constraints } + /* + * Generate a sequence of expression sets from constraints, where each set stores an equivalence + * class of expressions. For example, Set(`a = b`, `b = c`, `e = f`) will generate the following + * expression sets: (Set(a, b, c), Set(e, f)). This will be used to search all expressions equal + * to an selected attribute. + */ + private def generateEquivalentConstraintClasses( + constraints: Set[Expression]): Seq[Set[Expression]] = { + var constraintClasses = Seq.empty[Set[Expression]] + constraints.foreach { + case eq @ EqualTo(l: Attribute, r: Attribute) => + // Transform [[Alias]] to its child. + val left = aliasMap.getOrElse(l, l) + val right = aliasMap.getOrElse(r, r) + // Get the expression set for an equivalence constraint class. + val leftConstraintClass = getConstraintClass(left, constraintClasses) + val rightConstraintClass = getConstraintClass(right, constraintClasses) + if (leftConstraintClass.nonEmpty && rightConstraintClass.nonEmpty) { + // Combine the two sets. + constraintClasses = constraintClasses + .diff(leftConstraintClass :: rightConstraintClass :: Nil) :+ + (leftConstraintClass ++ rightConstraintClass) + } else if (leftConstraintClass.nonEmpty) { // && rightConstraintClass.isEmpty + // Update equivalence class of `left` expression. + constraintClasses = constraintClasses + .diff(leftConstraintClass :: Nil) :+ (leftConstraintClass + right) + } else if (rightConstraintClass.nonEmpty) { // && leftConstraintClass.isEmpty + // Update equivalence class of `right` expression. + constraintClasses = constraintClasses + .diff(rightConstraintClass :: Nil) :+ (rightConstraintClass + left) + } else { // leftConstraintClass.isEmpty && rightConstraintClass.isEmpty + // Create new equivalence constraint class since neither expression presents + // in any classes. + constraintClasses = constraintClasses :+ Set(left, right) + } + case _ => // Skip + } + + constraintClasses + } + + /* + * Get all expressions equivalent to the selected expression. 
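The equivalence-class bookkeeping above exists so constraint inference never produces recursive constraints like `a = f(a, b)`; a small stand-alone sketch of how equality constraints partition into classes (hypothetical helper, not Catalyst's implementation):

```scala
object ConstraintClassSketch extends App {
  // Group attributes related by equality into equivalence classes:
  // Seq(a = b, b = c, e = f) => List(Set(a, b, c), Set(e, f))
  def classes(equalities: Seq[(String, String)]): Seq[Set[String]] =
    equalities.foldLeft(Seq.empty[Set[String]]) { case (acc, (l, r)) =>
      val (touching, rest) = acc.partition(c => c.contains(l) || c.contains(r))
      rest :+ touching.fold(Set(l, r))(_ ++ _)
    }

  println(classes(Seq("a" -> "b", "b" -> "c", "e" -> "f")))
}
```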
+ */ + private def getConstraintClass( + expr: Expression, + constraintClasses: Seq[Set[Expression]]): Set[Expression] = + constraintClasses.find(_.contains(expr)).getOrElse(Set.empty[Expression]) + + /* + * Check whether replace by an [[Attribute]] will cause a recursive deduction. Generally it + * has the form like: `a -> f(a, b)`, where `a` and `b` are expressions and `f` is a function. + * Here we first get all expressions equal to `attr` and then check whether at least one of them + * is a child of the referenced expression. + */ + private def isRecursiveDeduction( + attr: Attribute, + constraintClasses: Seq[Set[Expression]]): Boolean = { + val expr = aliasMap.getOrElse(attr, attr) + getConstraintClass(expr, constraintClasses).exists { e => + expr.children.exists(_.semanticEquals(e)) + } + } + /** * An [[ExpressionSet]] that contains invariants about the rows output by this operator. For * example, if this set contains the expression `a = 2` then that expression is guaranteed to @@ -172,7 +251,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT case null => null } - val newArgs = productIterator.map(recursiveTransform).toArray + val newArgs = mapProductIterator(recursiveTransform) if (changed) makeCopy(newArgs).asInstanceOf[this.type] else this } @@ -206,7 +285,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT case null => null } - val newArgs = productIterator.map(recursiveTransform).toArray + val newArgs = mapProductIterator(recursiveTransform) if (changed) makeCopy(newArgs).asInstanceOf[this.type] else this } @@ -257,6 +336,8 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT override def simpleString: String = statePrefix + super.simpleString + override def verboseString: String = simpleString + /** * All the subqueries of current plan. */ @@ -264,7 +345,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT expressions.flatMap(_.collect {case e: SubqueryExpression => e.plan.asInstanceOf[PlanType]}) } - override def innerChildren: Seq[PlanType] = subqueries + override protected def innerChildren: Seq[QueryPlan[_]] = subqueries /** * Canonicalized copy of this query plan. @@ -296,7 +377,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT /** * All the attributes that are used for this plan. 
*/ - lazy val allAttributes: Seq[Attribute] = children.flatMap(_.output) + lazy val allAttributes: AttributeSeq = children.flatMap(_.output) private def cleanExpression(e: Expression): Expression = e match { case a: Alias => @@ -318,7 +399,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT case other => other } - productIterator.map { + mapProductIterator { case s: Option[_] => s.map(cleanArg) case s: Seq[_] => s.map(cleanArg) case m: Map[_, _] => m.mapValues(cleanArg) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala index 13f57c54a5623..80674d9b4bc9c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions.Attribute object JoinType { def apply(typ: String): JoinType = typ.toLowerCase.replace("_", "") match { @@ -69,6 +70,14 @@ case object LeftAnti extends JoinType { override def sql: String = "LEFT ANTI" } +case class ExistenceJoin(exists: Attribute) extends JoinType { + override def sql: String = { + // This join type is only used in the end of optimizer and physical plans, we will not + // generate SQL for this join type + throw new UnsupportedOperationException + } +} + case class NaturalJoin(tpe: JoinType) extends JoinType { require(Seq(Inner, LeftOuter, RightOuter, FullOuter).contains(tpe), "Unsupported natural join type " + tpe) @@ -84,6 +93,7 @@ case class UsingJoin(tpe: JoinType, usingColumns: Seq[UnresolvedAttribute]) exte object LeftExistence { def unapply(joinType: JoinType): Option[JoinType] = joinType match { case LeftSemi | LeftAnti => Some(joinType) + case j: ExistenceJoin => Some(joinType) case _ => None } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala new file mode 100644 index 0000000000000..64f57835c8898 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.Attribute + +/** + * A logical node that represents a non-query command to be executed by the system. For example, + * commands can be used by parsers to represent DDL operations. Commands, unlike queries, are + * eagerly executed. 
+ */ +trait Command extends LeafNode { + final override def children: Seq[LogicalPlan] = Seq.empty + override def output: Seq[Attribute] = Seq.empty +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala index 5813b74c770d8..890865d177845 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala @@ -18,8 +18,9 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.{analysis, CatalystTypeConverters, InternalRow} -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.analysis +import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal} import org.apache.spark.sql.types.{StructField, StructType} object LocalRelation { @@ -57,14 +58,34 @@ case class LocalRelation(output: Seq[Attribute], data: Seq[InternalRow] = Nil) LocalRelation(output.map(_.newInstance()), data).asInstanceOf[this.type] } - override protected def stringArgs = Iterator(output) + override protected def stringArgs: Iterator[Any] = { + if (data.isEmpty) { + Iterator("", output) + } else { + Iterator(output) + } + } - override def sameResult(plan: LogicalPlan): Boolean = plan match { - case LocalRelation(otherOutput, otherData) => - otherOutput.map(_.dataType) == output.map(_.dataType) && otherData == data - case _ => false + override def sameResult(plan: LogicalPlan): Boolean = { + plan.canonicalized match { + case LocalRelation(otherOutput, otherData) => + otherOutput.map(_.dataType) == output.map(_.dataType) && otherData == data + case _ => false + } } override lazy val statistics = Statistics(sizeInBytes = output.map(_.dataType.defaultSize).sum * data.length) + + def toSQL(inlineTableName: String): String = { + require(data.nonEmpty) + val types = output.map(_.dataType) + val rows = data.map { row => + val cells = row.toSeq(types).zip(types).map { case (v, tpe) => Literal(v, tpe).sql } + cells.mkString("(", ", ", ")") + } + "VALUES " + rows.mkString(", ") + + " AS " + inlineTableName + + output.map(_.name).mkString("(", ", ", ")") + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 45ac126a72f5c..9c152fb88412a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -127,7 +127,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { */ def resolve(schema: StructType, resolver: Resolver): Seq[Attribute] = { schema.map { field => - resolveQuoted(field.name, resolver).map { + resolve(field.name :: Nil, resolver).map { case a: AttributeReference => a case other => sys.error(s"can not handle nested schema yet... plan $this") }.getOrElse { @@ -265,6 +265,11 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { s"Reference '$name' is ambiguous, could be: $referenceNames.") } } + + /** + * Refreshes (or invalidates) any metadata/data cached in the plan recursively. 
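The new LocalRelation.toSQL above renders the relation's rows as an inline VALUES table; a hedged sketch that mirrors the same string-building shape on plain Scala data (literal rendering details may vary):

```scala
object InlineTableSqlSketch extends App {
  // Mirrors the shape of LocalRelation.toSQL: rows rendered as an inline VALUES table.
  def toSQL(inlineTableName: String, columns: Seq[String], rows: Seq[Seq[String]]): String =
    "VALUES " + rows.map(_.mkString("(", ", ", ")")).mkString(", ") +
      " AS " + inlineTableName + columns.mkString("(", ", ", ")")

  println(toSQL("t", Seq("id", "name"), Seq(Seq("1", "'a'"), Seq("2", "'b'"))))
  // VALUES (1, 'a'), (2, 'b') AS t(id, name)
}
```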
+ */ + def refresh(): Unit = children.foreach(_.refresh()) } /** @@ -288,15 +293,19 @@ abstract class UnaryNode extends LogicalPlan { * expressions with the corresponding alias */ protected def getAliasedConstraints(projectList: Seq[NamedExpression]): Set[Expression] = { - projectList.flatMap { + var allConstraints = child.constraints.asInstanceOf[Set[Expression]] + projectList.foreach { case a @ Alias(e, _) => - child.constraints.map(_ transform { + // For every alias in `projectList`, replace the reference in constraints by its attribute. + allConstraints ++= allConstraints.map(_ transform { case expr: Expression if expr.semanticEquals(e) => a.toAttribute - }).union(Set(EqualNullSafe(e, a.toAttribute))) - case _ => - Set.empty[Expression] - }.toSet + }) + allConstraints += EqualNullSafe(e, a.toAttribute) + case _ => // Don't change. + } + + allConstraints -- child.constraints } override protected def validConstraints: Set[Expression] = child.constraints @@ -313,7 +322,8 @@ abstract class UnaryNode extends LogicalPlan { // (product of children). sizeInBytes = 1 } - Statistics(sizeInBytes = sizeInBytes) + + child.statistics.copy(sizeInBytes = sizeInBytes) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala index 9ac4c3a2a56c8..63f86ad09412c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala @@ -32,4 +32,4 @@ package org.apache.spark.sql.catalyst.plans.logical * @param sizeInBytes Physical size in bytes. For leaf operators this defaults to 1, otherwise it * defaults to the product of children's `sizeInBytes`. 
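Together with the Statistics change that follows, operators now propagate the child's statistics with copy() instead of rebuilding them; a minimal sketch of that pattern using a toy case class, not the Catalyst one:

```scala
object StatisticsSketch extends App {
  // Toy version of the planner statistics: a size estimate plus a broadcast hint flag.
  case class Statistics(sizeInBytes: BigInt, isBroadcastable: Boolean = false)

  val child = Statistics(sizeInBytes = BigInt(1) << 20)

  // A unary operator keeps the child's flags and only overrides the size estimate,
  // which is why the diff switches from `Statistics(...)` to `child.statistics.copy(...)`.
  val filtered = child.copy(sizeInBytes = child.sizeInBytes / 2)

  // A broadcast hint flips the flag so join planning can pick a broadcast join later.
  val hinted = child.copy(isBroadcastable = true)

  println(Seq(child, filtered, hinted).mkString("\n"))
}
```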
*/ -private[sql] case class Statistics(sizeInBytes: BigInt) +private[sql] case class Statistics(sizeInBytes: BigInt, isBroadcastable: Boolean = false) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index b2297bbcaa9c5..2fdaa057bab18 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -85,7 +85,7 @@ case class Generate( override lazy val resolved: Boolean = { generator.resolved && childrenResolved && - generator.elementTypes.length == generatorOutput.length && + generator.elementSchema.length == generatorOutput.length && generatorOutput.forall(_.resolved) } @@ -109,13 +109,15 @@ case class Filter(condition: Expression, child: LogicalPlan) override protected def validConstraints: Set[Expression] = { val predicates = splitConjunctivePredicates(condition) - .filterNot(PredicateSubquery.hasPredicateSubquery) + .filterNot(SubqueryExpression.hasCorrelatedSubquery) child.constraints.union(predicates.toSet) } } abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { + def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty + protected def leftConstraints: Set[Expression] = left.constraints protected def rightConstraints: Set[Expression] = { @@ -125,16 +127,21 @@ abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends Binar case a: Attribute => attributeRewrites(a) }) } + + override lazy val resolved: Boolean = + childrenResolved && + left.output.length == right.output.length && + left.output.zip(right.output).forall { + case (l, r) => l.dataType.asNullable == r.dataType.asNullable } && + duplicateResolved } -private[sql] object SetOperation { +object SetOperation { def unapply(p: SetOperation): Option[(LogicalPlan, LogicalPlan)] = Some((p.left, p.right)) } case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) { - def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty - override def output: Seq[Attribute] = left.output.zip(right.output).map { case (leftAttr, rightAttr) => leftAttr.withNullability(leftAttr.nullable && rightAttr.nullable) @@ -143,14 +150,6 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation override protected def validConstraints: Set[Expression] = leftConstraints.union(rightConstraints) - // Intersect are only resolved if they don't introduce ambiguous expression ids, - // since the Optimizer will convert Intersect to Join. 
- override lazy val resolved: Boolean = - childrenResolved && - left.output.length == right.output.length && - left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } && - duplicateResolved - override def maxRows: Option[Long] = { if (children.exists(_.maxRows.isEmpty)) { None @@ -159,31 +158,25 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation } } - override def statistics: Statistics = { + override lazy val statistics: Statistics = { val leftSize = left.statistics.sizeInBytes val rightSize = right.statistics.sizeInBytes val sizeInBytes = if (leftSize < rightSize) leftSize else rightSize - Statistics(sizeInBytes = sizeInBytes) + val isBroadcastable = left.statistics.isBroadcastable || right.statistics.isBroadcastable + + Statistics(sizeInBytes = sizeInBytes, isBroadcastable = isBroadcastable) } } case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) { - def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty - /** We don't use right.output because those rows get excluded from the set. */ override def output: Seq[Attribute] = left.output override protected def validConstraints: Set[Expression] = leftConstraints - override lazy val resolved: Boolean = - childrenResolved && - left.output.length == right.output.length && - left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } && - duplicateResolved - - override def statistics: Statistics = { - Statistics(sizeInBytes = left.statistics.sizeInBytes) + override lazy val statistics: Statistics = { + left.statistics.copy() } } @@ -216,13 +209,13 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan { child.output.length == children.head.output.length && // compare the data types with the first child child.output.zip(children.head.output).forall { - case (l, r) => l.dataType == r.dataType } + case (l, r) => l.dataType.asNullable == r.dataType.asNullable } ) children.length > 1 && childrenResolved && allChildrenCompatible } - override def statistics: Statistics = { + override lazy val statistics: Statistics = { val sizeInBytes = children.map(_.statistics.sizeInBytes).sum Statistics(sizeInBytes = sizeInBytes) } @@ -273,6 +266,8 @@ case class Join( override def output: Seq[Attribute] = { joinType match { + case j: ExistenceJoin => + left.output :+ j.exists case LeftExistence(_) => left.output case LeftOuter => @@ -295,6 +290,8 @@ case class Join( case LeftSemi if condition.isDefined => left.constraints .union(splitConjunctivePredicates(condition.get).toSet) + case j: ExistenceJoin => + left.constraints case Inner => left.constraints.union(right.constraints) case LeftExistence(_) => @@ -326,6 +323,16 @@ case class Join( case UsingJoin(_, _) => false case _ => resolvedExceptNatural } + + override lazy val statistics: Statistics = joinType match { + case LeftAnti | LeftSemi => + // LeftSemi and LeftAnti won't ever be bigger than left + left.statistics.copy() + case _ => + // make sure we don't propagate isBroadcastable in other joins, because + // they could explode the size. + super.statistics.copy(isBroadcastable = false) + } } /** @@ -334,9 +341,8 @@ case class Join( case class BroadcastHint(child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output - // We manually set statistics of BroadcastHint to smallest value to make sure - // the plan wrapped by BroadcastHint will be considered to broadcast later. 
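BroadcastHint is the node the user-facing broadcast() function inserts; a hedged end-to-end example using the public API (assumes a `spark` session and implicits as in the earlier sketches):

```scala
import org.apache.spark.sql.functions.broadcast

// Assumes `spark: SparkSession` with `import spark.implicits._`, as in the earlier sketches.
val facts = spark.range(0, 1000000).toDF("id")
val dims = Seq((1L, "x"), (2L, "y")).toDF("id", "label")

// broadcast() wraps `dims` in a BroadcastHint; with the change below its statistics are
// marked isBroadcastable = true instead of being forced to sizeInBytes = 1, so the hint
// no longer distorts other size-based planning decisions.
val joined = facts.join(broadcast(dims), "id")
joined.explain()  // should show a broadcast join for the hinted side
```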
- override def statistics: Statistics = Statistics(sizeInBytes = 1) + // set isBroadcastable to true so the child will be broadcasted + override lazy val statistics: Statistics = super.statistics.copy(isBroadcastable = true) } case class InsertIntoTable( @@ -350,10 +356,25 @@ case class InsertIntoTable( override def children: Seq[LogicalPlan] = child :: Nil override def output: Seq[Attribute] = Seq.empty + lazy val expectedColumns = { + if (table.output.isEmpty) { + None + } else { + // Note: The parser (visitPartitionSpec in AstBuilder) already turns + // keys in partition to their lowercase forms. + val staticPartCols = partition.filter(_._2.isDefined).keySet + Some(table.output.filterNot(a => staticPartCols.contains(a.name))) + } + } + assert(overwrite || !ifNotExists) - override lazy val resolved: Boolean = childrenResolved && child.output.zip(table.output).forall { - case (childAttr, tableAttr) => - DataType.equalsIgnoreCompatibleNullability(childAttr.dataType, tableAttr.dataType) + assert(partition.values.forall(_.nonEmpty) || !ifNotExists) + override lazy val resolved: Boolean = + childrenResolved && table.resolved && expectedColumns.forall { expected => + child.output.size == expected.size && child.output.zip(expected).forall { + case (childAttr, tableAttr) => + DataType.equalsIgnoreCompatibleNullability(childAttr.dataType, tableAttr.dataType) + } } } @@ -392,19 +413,25 @@ case class Sort( /** Factory for constructing new `Range` nodes. */ object Range { - def apply(start: Long, end: Long, step: Long, numSlices: Int): Range = { + def apply(start: Long, end: Long, step: Long, numSlices: Option[Int]): Range = { val output = StructType(StructField("id", LongType, nullable = false) :: Nil).toAttributes new Range(start, end, step, numSlices, output) } + def apply(start: Long, end: Long, step: Long, numSlices: Int): Range = { + Range(start, end, step, Some(numSlices)) + } } case class Range( start: Long, end: Long, step: Long, - numSlices: Int, - output: Seq[Attribute]) extends LeafNode with MultiInstanceRelation { - require(step != 0, "step cannot be 0") + numSlices: Option[Int], + output: Seq[Attribute]) + extends LeafNode with MultiInstanceRelation { + + require(step != 0, s"step ($step) cannot be 0") + val numElements: BigInt = { val safeStart = BigInt(start) val safeEnd = BigInt(end) @@ -416,13 +443,24 @@ case class Range( } } - override def newInstance(): Range = - Range(start, end, step, numSlices, output.map(_.newInstance())) + def toSQL(): String = { + if (numSlices.isDefined) { + s"SELECT id AS `${output.head.name}` FROM range($start, $end, $step, ${numSlices.get})" + } else { + s"SELECT id AS `${output.head.name}` FROM range($start, $end, $step)" + } + } + + override def newInstance(): Range = copy(output = output.map(_.newInstance())) - override def statistics: Statistics = { + override lazy val statistics: Statistics = { val sizeInBytes = LongType.defaultSize * numElements Statistics( sizeInBytes = sizeInBytes ) } + + override def simpleString: String = { + s"Range ($start, $end, step=$step, splits=$numSlices)" + } } case class Aggregate( @@ -443,12 +481,14 @@ case class Aggregate( override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute) override def maxRows: Option[Long] = child.maxRows - override def validConstraints: Set[Expression] = - child.constraints.union(getAliasedConstraints(aggregateExpressions)) + override def validConstraints: Set[Expression] = { + val nonAgg = aggregateExpressions.filter(_.find(_.isInstanceOf[AggregateExpression]).isEmpty) 
+ child.constraints.union(getAliasedConstraints(nonAgg)) + } - override def statistics: Statistics = { + override lazy val statistics: Statistics = { if (groupingExpressions.isEmpty) { - Statistics(sizeInBytes = 1) + super.statistics.copy(sizeInBytes = 1) } else { super.statistics } @@ -467,7 +507,7 @@ case class Window( def windowOutputSet: AttributeSet = AttributeSet(windowExpressions.map(_.toAttribute)) } -private[sql] object Expand { +object Expand { /** * Extract attribute set according to the grouping id. * @@ -491,7 +531,7 @@ private[sql] object Expand { /** * Apply the all of the GroupExpressions to every input row, hence we will get - * multiple output rows for a input row. + * multiple output rows for an input row. * * @param bitmasks The bitmask set represents the grouping sets * @param groupByAliases The aliased original group by expressions @@ -533,7 +573,7 @@ private[sql] object Expand { /** * Apply a number of projections to every input row, hence we will get multiple output rows for - * a input row. + * an input row. * * @param projections to apply * @param output of all projections. @@ -546,13 +586,13 @@ case class Expand( override def references: AttributeSet = AttributeSet(projections.flatten.flatMap(_.references)) - override def statistics: Statistics = { + override lazy val statistics: Statistics = { val sizeInBytes = super.statistics.sizeInBytes * projections.length Statistics(sizeInBytes = sizeInBytes) } // This operator can reuse attributes (for example making them null when doing a roll up) so - // the contraints of the child may no longer be valid. + // the constraints of the child may no longer be valid. override protected def validConstraints: Set[Expression] = Set.empty[Expression] } @@ -620,8 +660,14 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryN } override lazy val statistics: Statistics = { val limit = limitExpr.eval().asInstanceOf[Int] - val sizeInBytes = (limit: Long) * output.map(a => a.dataType.defaultSize).sum - Statistics(sizeInBytes = sizeInBytes) + val sizeInBytes = if (limit == 0) { + // sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero + // (product of children). + 1 + } else { + (limit: Long) * output.map(a => a.dataType.defaultSize).sum + } + child.statistics.copy(sizeInBytes = sizeInBytes) } } @@ -635,8 +681,14 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNo } override lazy val statistics: Statistics = { val limit = limitExpr.eval().asInstanceOf[Int] - val sizeInBytes = (limit: Long) * output.map(a => a.dataType.defaultSize).sum - Statistics(sizeInBytes = sizeInBytes) + val sizeInBytes = if (limit == 0) { + // sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero + // (product of children). 
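// Minimal sketch (assumed row sizes, illustrative only) of why the limit-0 branch here
// reports sizeInBytes = 1 rather than 0: a binary node may estimate its size as the
// product of its children's sizes, so a zero factor would wipe out the other side.
object LimitSizeSketch extends App {
  def limitSizeInBytes(limit: Int, perRowBytes: Long): BigInt =
    if (limit == 0) BigInt(1) else BigInt(limit) * perRowBytes

  val leftSize  = limitSizeInBytes(0, perRowBytes = 8) // clamped to 1 instead of 0
  val rightSize = BigInt(1000000) * 8                  // some other child's estimate
  println(leftSize * rightSize) // stays non-zero, so downstream estimates remain usable
}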
+ 1 + } else { + (limit: Long) * output.map(a => a.dataType.defaultSize).sum + } + child.statistics.copy(sizeInBytes = sizeInBytes) } } @@ -666,14 +718,14 @@ case class Sample( override def output: Seq[Attribute] = child.output - override def statistics: Statistics = { + override lazy val statistics: Statistics = { val ratio = upperBound - lowerBound // BigInt can't multiply with Double var sizeInBytes = child.statistics.sizeInBytes * (ratio * 100).toInt / 100 if (sizeInBytes == 0) { sizeInBytes = 1 } - Statistics(sizeInBytes = sizeInBytes) + child.statistics.copy(sizeInBytes = sizeInBytes) } override protected def otherCopyArgs: Seq[AnyRef] = isTableSample :: Nil @@ -695,6 +747,7 @@ case class Distinct(child: LogicalPlan) extends UnaryNode { */ case class Repartition(numPartitions: Int, shuffle: Boolean, child: LogicalPlan) extends UnaryNode { + require(numPartitions > 0, s"Number of partitions ($numPartitions) must be positive.") override def output: Seq[Attribute] = child.output } @@ -712,5 +765,5 @@ case object OneRowRelation extends LeafNode { * * [[LeafNode]]s must override this. */ - override def statistics: Statistics = Statistics(sizeInBytes = 1) + override lazy val statistics: Statistics = Statistics(sizeInBytes = 1) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala deleted file mode 100644 index fcffdbaaf07b7..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.plans.logical - -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} -import org.apache.spark.sql.types.StringType - -/** - * A logical node that represents a non-query command to be executed by the system. For example, - * commands can be used by parsers to represent DDL operations. Commands, unlike queries, are - * eagerly executed. - */ -trait Command - -/** - * Returned for the "DESCRIBE FUNCTION [EXTENDED] functionName" command. - * @param functionName The function to be described. - * @param isExtended True if "DESCRIBE FUNCTION EXTENDED" is used. Otherwise, false. - */ -private[sql] case class DescribeFunction( - functionName: String, - isExtended: Boolean) extends LogicalPlan with Command { - - override def children: Seq[LogicalPlan] = Seq.empty - override val output: Seq[Attribute] = Seq( - AttributeReference("function_desc", StringType, nullable = false)()) -} - -/** - * Returned for the "SHOW FUNCTIONS" command, which will list all of the - * registered function list. 
- */ -private[sql] case class ShowFunctions( - db: Option[String], pattern: Option[String]) extends LogicalPlan with Command { - override def children: Seq[LogicalPlan] = Seq.empty - override val output: Seq[Attribute] = Seq( - AttributeReference("function", StringType, nullable = false)()) -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index 84339f439a666..7beeeb4f04bf0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -30,26 +30,13 @@ object CatalystSerde { DeserializeToObject(deserializer, generateObjAttr[T], child) } - def deserialize(child: LogicalPlan, encoder: ExpressionEncoder[Row]): DeserializeToObject = { - val deserializer = UnresolvedDeserializer(encoder.deserializer) - DeserializeToObject(deserializer, generateObjAttrForRow(encoder), child) - } - def serialize[T : Encoder](child: LogicalPlan): SerializeFromObject = { SerializeFromObject(encoderFor[T].namedExpressions, child) } - def serialize(child: LogicalPlan, encoder: ExpressionEncoder[Row]): SerializeFromObject = { - SerializeFromObject(encoder.namedExpressions, child) - } - def generateObjAttr[T : Encoder]: Attribute = { AttributeReference("obj", encoderFor[T].deserializer.dataType, nullable = false)() } - - def generateObjAttrForRow(encoder: ExpressionEncoder[Row]): Attribute = { - AttributeReference("obj", encoder.deserializer.dataType, nullable = false)() - } } /** @@ -94,7 +81,7 @@ case class DeserializeToObject( */ case class SerializeFromObject( serializer: Seq[NamedExpression], - child: LogicalPlan) extends UnaryNode with ObjectConsumer { + child: LogicalPlan) extends ObjectConsumer { override def output: Seq[Attribute] = serializer.map(_.toAttribute) } @@ -118,7 +105,7 @@ object MapPartitions { case class MapPartitions( func: Iterator[Any] => Iterator[Any], outputObjAttr: Attribute, - child: LogicalPlan) extends UnaryNode with ObjectConsumer with ObjectProducer + child: LogicalPlan) extends ObjectConsumer with ObjectProducer object MapPartitionsInR { def apply( @@ -128,16 +115,16 @@ object MapPartitionsInR { schema: StructType, encoder: ExpressionEncoder[Row], child: LogicalPlan): LogicalPlan = { - val deserialized = CatalystSerde.deserialize(child, encoder) + val deserialized = CatalystSerde.deserialize(child)(encoder) val mapped = MapPartitionsInR( func, packageNames, broadcastVars, encoder.schema, schema, - CatalystSerde.generateObjAttrForRow(RowEncoder(schema)), + CatalystSerde.generateObjAttr(RowEncoder(schema)), deserialized) - CatalystSerde.serialize(mapped, RowEncoder(schema)) + CatalystSerde.serialize(mapped)(RowEncoder(schema)) } } @@ -152,8 +139,11 @@ case class MapPartitionsInR( inputSchema: StructType, outputSchema: StructType, outputObjAttr: Attribute, - child: LogicalPlan) extends UnaryNode with ObjectConsumer with ObjectProducer { + child: LogicalPlan) extends ObjectConsumer with ObjectProducer { override lazy val schema = outputSchema + + override protected def stringArgs: Iterator[Any] = Iterator(inputSchema, outputSchema, + outputObjAttr, child) } object MapElements { @@ -175,7 +165,7 @@ object MapElements { case class MapElements( func: AnyRef, outputObjAttr: Attribute, - child: LogicalPlan) extends UnaryNode with ObjectConsumer with ObjectProducer + child: LogicalPlan) extends ObjectConsumer with ObjectProducer /** 
Factory for constructing new `AppendColumn` nodes. */ object AppendColumns { @@ -215,7 +205,7 @@ case class AppendColumnsWithObject( func: Any => Any, childSerializer: Seq[NamedExpression], newColumnsSerializer: Seq[NamedExpression], - child: LogicalPlan) extends UnaryNode with ObjectConsumer { + child: LogicalPlan) extends ObjectConsumer { override def output: Seq[Attribute] = (childSerializer ++ newColumnsSerializer).map(_.toAttribute) } @@ -256,6 +246,55 @@ case class MapGroups( outputObjAttr: Attribute, child: LogicalPlan) extends UnaryNode with ObjectProducer +/** Factory for constructing new `FlatMapGroupsInR` nodes. */ +object FlatMapGroupsInR { + def apply( + func: Array[Byte], + packageNames: Array[Byte], + broadcastVars: Array[Broadcast[Object]], + schema: StructType, + keyDeserializer: Expression, + valueDeserializer: Expression, + inputSchema: StructType, + groupingAttributes: Seq[Attribute], + dataAttributes: Seq[Attribute], + child: LogicalPlan): LogicalPlan = { + val mapped = FlatMapGroupsInR( + func, + packageNames, + broadcastVars, + inputSchema, + schema, + UnresolvedDeserializer(keyDeserializer, groupingAttributes), + UnresolvedDeserializer(valueDeserializer, dataAttributes), + groupingAttributes, + dataAttributes, + CatalystSerde.generateObjAttr(RowEncoder(schema)), + child) + CatalystSerde.serialize(mapped)(RowEncoder(schema)) + } +} + +case class FlatMapGroupsInR( + func: Array[Byte], + packageNames: Array[Byte], + broadcastVars: Array[Broadcast[Object]], + inputSchema: StructType, + outputSchema: StructType, + keyDeserializer: Expression, + valueDeserializer: Expression, + groupingAttributes: Seq[Attribute], + dataAttributes: Seq[Attribute], + outputObjAttr: Attribute, + child: LogicalPlan) extends UnaryNode with ObjectProducer{ + + override lazy val schema = outputSchema + + override protected def stringArgs: Iterator[Any] = Iterator(inputSchema, outputSchema, + keyDeserializer, valueDeserializer, groupingAttributes, dataAttributes, outputObjAttr, + child) +} + /** Factory for constructing new `CoGroup` nodes. */ object CoGroup { def apply[K : Encoder, L : Encoder, R : Encoder, OUT : Encoder]( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala index a5bdee1b854ce..28cbce8748fcd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala @@ -43,7 +43,7 @@ case class RepartitionByExpression( child: LogicalPlan, numPartitions: Option[Int] = None) extends RedistributeData { numPartitions match { - case Some(n) => require(n > 0, "numPartitions must be greater than 0.") + case Some(n) => require(n > 0, s"Number of partitions ($n) must be positive.") case None => // Ok } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/package.scala index 42bdab42b79ff..b46f7a6d5a138 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/package.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst /** - * A a collection of common abstractions for query plans as well as + * A collection of common abstractions for query plans as well as * a base logical plan representation. 
*/ package object plans diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index d449088498c8e..51d78dd1233fe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -67,7 +67,7 @@ case class ClusteredDistribution(clustering: Seq[Expression]) extends Distributi case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution { require( ordering != Nil, - "The ordering expressions of a OrderedDistribution should not be Nil. " + + "The ordering expressions of an OrderedDistribution should not be Nil. " + "An AllTuples should be used to represent a distribution that only has " + "a single partition.") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 5eb8fdf048e28..931d14dc18eea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -21,8 +21,9 @@ import java.util.UUID import scala.collection.Map import scala.collection.mutable.Stack +import scala.reflect.ClassTag -import org.apache.commons.lang.ClassUtils +import org.apache.commons.lang3.ClassUtils import org.json4s.JsonAST._ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ @@ -104,7 +105,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { */ def find(f: BaseType => Boolean): Option[BaseType] = f(this) match { case true => Some(this) - case false => children.foldLeft(None: Option[BaseType]) { (l, r) => l.orElse(r.find(f)) } + case false => children.foldLeft(Option.empty[BaseType]) { (l, r) => l.orElse(r.find(f)) } } /** @@ -157,6 +158,13 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { ret } + /** + * Returns a Seq containing the leaves in this tree. + */ + def collectLeaves(): Seq[BaseType] = { + this.collect { case p if p.children.isEmpty => p } + } + /** * Finds and returns the first [[TreeNode]] of the tree for which the given partial function * is defined (pre-order), and applies the partial function to it. @@ -164,16 +172,29 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { def collectFirst[B](pf: PartialFunction[BaseType, B]): Option[B] = { val lifted = pf.lift lifted(this).orElse { - children.foldLeft(None: Option[B]) { (l, r) => l.orElse(r.collectFirst(pf)) } + children.foldLeft(Option.empty[B]) { (l, r) => l.orElse(r.collectFirst(pf)) } } } + /** + * Efficient alternative to `productIterator.map(f).toArray`. + */ + protected def mapProductIterator[B: ClassTag](f: Any => B): Array[B] = { + val arr = Array.ofDim[B](productArity) + var i = 0 + while (i < arr.length) { + arr(i) = f(productElement(i)) + i += 1 + } + arr + } + /** * Returns a copy of this node where `f` has been applied to all the nodes children. 
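// Hedged sketch of the `mapProductIterator` helper added above: map over a Product's
// constructor arguments into a pre-sized array with a while loop, avoiding the
// intermediate collections that productIterator.map(f).toArray would allocate.
// Standalone, simplified code; not the TreeNode implementation itself.
import scala.reflect.ClassTag

object MapProductIteratorSketch extends App {
  def mapProductIterator[B: ClassTag](p: Product)(f: Any => B): Array[B] = {
    val arr = Array.ofDim[B](p.productArity)
    var i = 0
    while (i < arr.length) {
      arr(i) = f(p.productElement(i))
      i += 1
    }
    arr
  }

  case class Node(name: String, weight: Int, enabled: Boolean)
  // Stringify every field of the case class without building an iterator pipeline.
  println(mapProductIterator(Node("scan", 3, true))(_.toString).toList) // List(scan, 3, true)
}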
*/ def mapChildren(f: BaseType => BaseType): BaseType = { var changed = false - val newArgs = productIterator.map { + val newArgs = mapProductIterator { case arg: TreeNode[_] if containsChild(arg) => val newChild = f(arg.asInstanceOf[BaseType]) if (newChild fastEquals arg) { @@ -184,7 +205,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { } case nonChild: AnyRef => nonChild case null => null - }.toArray + } if (changed) makeCopy(newArgs) else this } @@ -197,7 +218,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { var changed = false val remainingNewChildren = newChildren.toBuffer val remainingOldChildren = children.toBuffer - val newArgs = productIterator.map { + val newArgs = mapProductIterator { case s: StructType => s // Don't convert struct types to some other type of Seq[StructField] // Handle Seq[TreeNode] in TreeNode parameters. case s: Seq[_] => s.map { @@ -237,7 +258,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { } case nonChild: AnyRef => nonChild case null => null - }.toArray + } if (changed) makeCopy(newArgs) else this } @@ -302,7 +323,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { rule: PartialFunction[BaseType, BaseType], nextOperation: (BaseType, PartialFunction[BaseType, BaseType]) => BaseType): BaseType = { var changed = false - val newArgs = productIterator.map { + val newArgs = mapProductIterator { case arg: TreeNode[_] if containsChild(arg) => val newChild = nextOperation(arg.asInstanceOf[BaseType], rule) if (!(newChild fastEquals arg)) { @@ -353,7 +374,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { } case nonChild: AnyRef => nonChild case null => null - }.toArray + } if (changed) makeCopy(newArgs) else this } @@ -424,23 +445,41 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { */ protected def stringArgs: Iterator[Any] = productIterator + private lazy val allChildren: Set[TreeNode[_]] = (children ++ innerChildren).toSet[TreeNode[_]] + /** Returns a string representing the arguments to this node, minus any children */ - def argString: String = productIterator.flatMap { - case tn: TreeNode[_] if containsChild(tn) => Nil - case tn: TreeNode[_] => s"${tn.simpleString}" :: Nil - case seq: Seq[BaseType] if seq.toSet.subsetOf(children.toSet) => Nil - case seq: Seq[_] => seq.mkString("[", ",", "]") :: Nil - case set: Set[_] => set.mkString("{", ",", "}") :: Nil + def argString: String = stringArgs.flatMap { + case tn: TreeNode[_] if allChildren.contains(tn) => Nil + case Some(tn: TreeNode[_]) if allChildren.contains(tn) => Nil + case Some(tn: TreeNode[_]) => tn.simpleString :: Nil + case tn: TreeNode[_] => tn.simpleString :: Nil + case seq: Seq[Any] if seq.toSet.subsetOf(allChildren.asInstanceOf[Set[Any]]) => Nil + case iter: Iterable[_] if iter.isEmpty => Nil + case seq: Seq[_] => Utils.truncatedString(seq, "[", ", ", "]") :: Nil + case set: Set[_] => Utils.truncatedString(set.toSeq, "{", ", ", "}") :: Nil + case array: Array[_] if array.isEmpty => Nil + case array: Array[_] => Utils.truncatedString(array, "[", ", ", "]") :: Nil + case null => Nil + case None => Nil + case Some(null) => Nil + case Some(any) => any :: Nil case other => other :: Nil }.mkString(", ") - /** String representation of this node without any children. */ + /** ONE line description of this node. 
*/ def simpleString: String = s"$nodeName $argString".trim + /** ONE line description of this node with more information */ + def verboseString: String + override def toString: String = treeString /** Returns a string representation of the nodes in this tree */ - def treeString: String = generateTreeString(0, Nil, new StringBuilder).toString + def treeString: String = treeString(verbose = true) + + def treeString(verbose: Boolean): String = { + generateTreeString(0, Nil, new StringBuilder, verbose).toString + } /** * Returns a string representation of the nodes in this tree, where each operator is numbered. @@ -467,52 +506,10 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { } /** - * All the nodes that will be used to generate tree string. - * - * For example: - * - * WholeStageCodegen - * +-- SortMergeJoin - * |-- InputAdapter - * | +-- Sort - * +-- InputAdapter - * +-- Sort - * - * the treeChildren of WholeStageCodegen will be Seq(Sort, Sort), it will generate a tree string - * like this: - * - * WholeStageCodegen - * : +- SortMergeJoin - * : :- INPUT - * : :- INPUT - * :- Sort - * :- Sort - */ - protected def treeChildren: Seq[BaseType] = children - - /** - * All the nodes that are parts of this node. - * - * For example: - * - * WholeStageCodegen - * +- SortMergeJoin - * |-- InputAdapter - * | +-- Sort - * +-- InputAdapter - * +-- Sort - * - * the innerChildren of WholeStageCodegen will be Seq(SortMergeJoin), it will generate a tree - * string like this: - * - * WholeStageCodegen - * : +- SortMergeJoin - * : :- INPUT - * : :- INPUT - * :- Sort - * :- Sort + * All the nodes that should be shown as a inner nested tree of this node. + * For example, this can be used to show sub-queries. */ - protected def innerChildren: Seq[BaseType] = Nil + protected def innerChildren: Seq[TreeNode[_]] = Seq.empty /** * Appends the string represent of this node and its children to the given StringBuilder. @@ -522,7 +519,11 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { * `lastChildren` for the root node should be empty. 
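// A simplified, illustrative re-creation of the indentation scheme generateTreeString
// follows (not the code in this diff): each ancestor level contributes spaces if it was
// the last child at that level, otherwise ":  ", and a node is introduced with "+- "
// when it is the last child, ":- " otherwise.
object TreeStringSketch extends App {
  case class Node(name: String, children: Seq[Node] = Nil)

  def render(node: Node, lastChildren: Seq[Boolean], sb: StringBuilder): StringBuilder = {
    if (lastChildren.nonEmpty) {
      lastChildren.init.foreach(isLast => sb.append(if (isLast) "   " else ":  "))
      sb.append(if (lastChildren.last) "+- " else ":- ")
    }
    sb.append(node.name).append('\n')
    if (node.children.nonEmpty) {
      node.children.init.foreach(c => render(c, lastChildren :+ false, sb))
      render(node.children.last, lastChildren :+ true, sb)
    }
    sb
  }

  val plan = Node("Project", Seq(Node("Filter", Seq(Node("Scan A"))), Node("Scan B")))
  print(render(plan, Nil, new StringBuilder))
  // Project
  // :- Filter
  // :  +- Scan A
  // +- Scan B
}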
*/ def generateTreeString( - depth: Int, lastChildren: Seq[Boolean], builder: StringBuilder): StringBuilder = { + depth: Int, + lastChildren: Seq[Boolean], + builder: StringBuilder, + verbose: Boolean, + prefix: String = ""): StringBuilder = { if (depth > 0) { lastChildren.init.foreach { isLast => val prefixFragment = if (isLast) " " else ": " @@ -533,18 +534,22 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { builder.append(branch) } - builder.append(simpleString) + builder.append(prefix) + val headline = if (verbose) verboseString else simpleString + builder.append(headline) builder.append("\n") if (innerChildren.nonEmpty) { innerChildren.init.foreach(_.generateTreeString( - depth + 2, lastChildren :+ false :+ false, builder)) - innerChildren.last.generateTreeString(depth + 2, lastChildren :+ false :+ true, builder) + depth + 2, lastChildren :+ false :+ false, builder, verbose)) + innerChildren.last.generateTreeString( + depth + 2, lastChildren :+ false :+ true, builder, verbose) } - if (treeChildren.nonEmpty) { - treeChildren.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder)) - treeChildren.last.generateTreeString(depth + 1, lastChildren :+ true, builder) + if (children.nonEmpty) { + children.init.foreach( + _.generateTreeString(depth + 1, lastChildren :+ false, builder, verbose, prefix)) + children.last.generateTreeString(depth + 1, lastChildren :+ true, builder, verbose, prefix) } builder @@ -615,7 +620,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { case s: String => JString(s) case u: UUID => JString(u.toString) case dt: DataType => dt.jsonValue - case m: Metadata => m.jsonValue + // SPARK-17356: In usage of mllib, Metadata may store a huge vector of data, transforming + // it to JSON may trigger OutOfMemoryError. + case m: Metadata => Metadata.empty.jsonValue case s: StorageLevel => ("useDisk" -> s.useDisk) ~ ("useMemory" -> s.useMemory) ~ ("useOffHeap" -> s.useOffHeap) ~ ("deserialized" -> s.deserialized) ~ ("replication" -> s.replication) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/AbstractScalaRowIterator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/AbstractScalaRowIterator.scala index 6d35f140cf23f..0c7205b3c6651 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/AbstractScalaRowIterator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/AbstractScalaRowIterator.scala @@ -23,7 +23,7 @@ package org.apache.spark.sql.catalyst.util * `Row` in order to work around a spurious IntelliJ compiler error. This cannot be an abstract * class because that leads to compilation errors under Scala 2.11. 
*/ -private[spark] class AbstractScalaRowIterator[T] extends Iterator[T] { +class AbstractScalaRowIterator[T] extends Iterator[T] { override def hasNext: Boolean = throw new NotImplementedError override def next(): T = throw new NotImplementedError diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala index d46f03ad8fbb3..4449da13c083c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala @@ -24,23 +24,6 @@ class ArrayBasedMapData(val keyArray: ArrayData, val valueArray: ArrayData) exte override def copy(): MapData = new ArrayBasedMapData(keyArray.copy(), valueArray.copy()) - override def equals(o: Any): Boolean = { - if (!o.isInstanceOf[ArrayBasedMapData]) { - return false - } - - val other = o.asInstanceOf[ArrayBasedMapData] - if (other eq null) { - return false - } - - this.keyArray == other.keyArray && this.valueArray == other.valueArray - } - - override def hashCode: Int = { - keyArray.hashCode() * 37 + valueArray.hashCode() - } - override def toString: String = { s"keys: $keyArray, values: $valueArray" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index f84c6592c6ae5..0b643a5b84268 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -58,6 +58,7 @@ object DateTimeUtils { final val YearZero = -17999 final val toYearZero = to2001 + 7304850 final val TimeZoneGMT = TimeZone.getTimeZone("GMT") + final val MonthOf31Days = Set(1, 3, 5, 7, 8, 10, 12) @transient lazy val defaultTimeZone = TimeZone.getDefault @@ -99,8 +100,8 @@ object DateTimeUtils { // reverse of millisToDays def daysToMillis(days: SQLDate): Long = { - val millisUtc = days.toLong * MILLIS_PER_DAY - millisUtc - threadLocalLocalTimeZone.get().getOffset(millisUtc) + val millisLocal = days.toLong * MILLIS_PER_DAY + millisLocal - getOffsetFromLocalMillis(millisLocal, threadLocalLocalTimeZone.get()) } def dateToString(days: SQLDate): String = @@ -333,8 +334,7 @@ object DateTimeUtils { digitsMilli += 1 } - if (!justTime && (segments(0) < 0 || segments(0) > 9999 || segments(1) < 1 || - segments(1) > 12 || segments(2) < 1 || segments(2) > 31)) { + if (!justTime && isInvalidDate(segments(0), segments(1), segments(2))) { return None } @@ -414,10 +414,10 @@ object DateTimeUtils { return None } segments(i) = currentSegmentValue - if (segments(0) < 0 || segments(0) > 9999 || segments(1) < 1 || segments(1) > 12 || - segments(2) < 1 || segments(2) > 31) { + if (isInvalidDate(segments(0), segments(1), segments(2))) { return None } + val c = threadLocalGmtCalendar.get() c.clear() c.set(segments(0), segments(1) - 1, segments(2), 0, 0, 0) @@ -425,6 +425,25 @@ object DateTimeUtils { Some((c.getTimeInMillis / MILLIS_PER_DAY).toInt) } + /** + * Return true if the date is invalid. 
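// isLeapYear is referenced by the new date-validity check just below but is not shown in
// this diff; the standard Gregorian rule it corresponds to (assumed here, illustrative
// code only, not the DateTimeUtils implementation):
object LeapYearSketch extends App {
  def isLeapYear(year: Int): Boolean =
    year % 4 == 0 && (year % 100 != 0 || year % 400 == 0)

  println(Seq(1900, 2000, 2015, 2016).map(y => y -> isLeapYear(y)))
  // List((1900,false), (2000,true), (2015,false), (2016,true))
}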
+ */ + private def isInvalidDate(year: Int, month: Int, day: Int): Boolean = { + if (year < 0 || year > 9999 || month < 1 || month > 12 || day < 1 || day > 31) { + return true + } + if (month == 2) { + if (isLeapYear(year) && day > 29) { + return true + } else if (!isLeapYear(year) && day > 28) { + return true + } + } else if (!MonthOf31Days.contains(month) && day > 30) { + return true + } + false + } + /** * Returns the microseconds since year zero (-17999) from microseconds since epoch. */ @@ -831,14 +850,75 @@ object DateTimeUtils { } } + /** + * Lookup the offset for given millis seconds since 1970-01-01 00:00:00 in given timezone. + * TODO: Improve handling of normalization differences. + * TODO: Replace with JSR-310 or similar system - see SPARK-16788 + */ + private[sql] def getOffsetFromLocalMillis(millisLocal: Long, tz: TimeZone): Long = { + var guess = tz.getRawOffset + // the actual offset should be calculated based on milliseconds in UTC + val offset = tz.getOffset(millisLocal - guess) + if (offset != guess) { + guess = tz.getOffset(millisLocal - offset) + if (guess != offset) { + // fallback to do the reverse lookup using java.sql.Timestamp + // this should only happen near the start or end of DST + val days = Math.floor(millisLocal.toDouble / MILLIS_PER_DAY).toInt + val year = getYear(days) + val month = getMonth(days) + val day = getDayOfMonth(days) + + var millisOfDay = (millisLocal % MILLIS_PER_DAY).toInt + if (millisOfDay < 0) { + millisOfDay += MILLIS_PER_DAY.toInt + } + val seconds = (millisOfDay / 1000L).toInt + val hh = seconds / 3600 + val mm = seconds / 60 % 60 + val ss = seconds % 60 + val ms = millisOfDay % 1000 + val calendar = Calendar.getInstance(tz) + calendar.set(year, month - 1, day, hh, mm, ss) + calendar.set(Calendar.MILLISECOND, ms) + guess = (millisLocal - calendar.getTimeInMillis()).toInt + } + } + guess + } + + /** + * Convert the timestamp `ts` from one timezone to another. + * + * TODO: Because of DST, the conversion between UTC and human time is not exactly one-to-one + * mapping, the conversion here may return wrong result, we should make the timestamp + * timezone-aware. + */ + def convertTz(ts: SQLTimestamp, fromZone: TimeZone, toZone: TimeZone): SQLTimestamp = { + // We always use local timezone to parse or format a timestamp + val localZone = threadLocalLocalTimeZone.get() + val utcTs = if (fromZone.getID == localZone.getID) { + ts + } else { + // get the human time using local time zone, that actually is in fromZone. + val localTs = ts + localZone.getOffset(ts / 1000L) * 1000L // in fromZone + localTs - getOffsetFromLocalMillis(localTs / 1000L, fromZone) * 1000L + } + if (toZone.getID == localZone.getID) { + utcTs + } else { + val localTs2 = utcTs + toZone.getOffset(utcTs / 1000L) * 1000L // in toZone + // treat it as local timezone, convert to UTC (we could get the expected human time back) + localTs2 - getOffsetFromLocalMillis(localTs2 / 1000L, localZone) * 1000L + } + } + /** * Returns a timestamp of given timezone from utc timestamp, with the same string * representation in their timezone. */ def fromUTCTime(time: SQLTimestamp, timeZone: String): SQLTimestamp = { - val tz = TimeZone.getTimeZone(timeZone) - val offset = tz.getOffset(time / 1000L) - time + offset * 1000L + convertTz(time, TimeZoneGMT, TimeZone.getTimeZone(timeZone)) } /** @@ -846,8 +926,16 @@ object DateTimeUtils { * string representation in their timezone. 
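// Hedged sketch of the idea behind fromUTCTime above (illustrative only; the real code
// now goes through convertTz so that DST edge cases are handled more carefully):
// shifting an instant by the target zone's offset yields a timestamp whose rendered
// wall-clock string matches that zone.
object TimeZoneShiftSketch extends App {
  import java.util.TimeZone

  def fromUTCTimeMillis(utcMillis: Long, zoneId: String): Long = {
    val tz = TimeZone.getTimeZone(zoneId)
    utcMillis + tz.getOffset(utcMillis)
  }

  val utc = 0L // 1970-01-01 00:00:00 UTC
  println(fromUTCTimeMillis(utc, "GMT+08:00")) // 28800000, i.e. 8 hours later in wall-clock terms
}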
*/ def toUTCTime(time: SQLTimestamp, timeZone: String): SQLTimestamp = { - val tz = TimeZone.getTimeZone(timeZone) - val offset = tz.getOffset(time / 1000L) - time - offset * 1000L + convertTz(time, TimeZone.getTimeZone(timeZone), TimeZoneGMT) + } + + /** + * Re-initialize the current thread's thread locals. Exposed for testing. + */ + private[util] def resetThreadLocals(): Unit = { + threadLocalGmtCalendar.remove() + threadLocalLocalTimeZone.remove() + threadLocalTimestampFormat.remove() + threadLocalDateFormat.remove() } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala index 2b8cdc1e23ab3..7ee9581b63af5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala @@ -23,6 +23,16 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types.{DataType, Decimal} import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} +private object GenericArrayData { + + // SPARK-16634: Workaround for JVM bug present in some 1.7 versions. + def anyToSeq(seqOrArray: Any): Seq[Any] = seqOrArray match { + case seq: Seq[Any] => seq + case array: Array[_] => array.toSeq + } + +} + class GenericArrayData(val array: Array[Any]) extends ArrayData { def this(seq: Seq[Any]) = this(seq.toArray) @@ -37,6 +47,8 @@ class GenericArrayData(val array: Array[Any]) extends ArrayData { def this(primitiveArray: Array[Byte]) = this(primitiveArray.toSeq) def this(primitiveArray: Array[Boolean]) = this(primitiveArray.toSeq) + def this(seqOrArray: Any) = this(GenericArrayData.anyToSeq(seqOrArray)) + override def copy(): ArrayData = new GenericArrayData(array.clone()) override def numElements(): Int = array.length diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MapData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MapData.scala index 40db6067adf71..94e8824cd18cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MapData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MapData.scala @@ -19,6 +19,11 @@ package org.apache.spark.sql.catalyst.util import org.apache.spark.sql.types.DataType +/** + * This is an internal data representation for map type in Spark SQL. This should not implement + * `equals` and `hashCode` because the type cannot be used as join keys, grouping keys, or + * in equality tests. See SPARK-9415 and PR#13847 for the discussions. + */ abstract class MapData extends Serializable { def numElements(): Int diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala index da90ddbd63afb..9c3f6b7c5d245 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala @@ -21,8 +21,6 @@ import org.apache.spark.unsafe.types.UTF8String object NumberConverter { - private val value = new Array[Byte](64) - /** * Divide x by m as if x is an unsigned 64-bit integer. 
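// For reference, the behaviour of the unsignedLongDiv helper documented here matches the
// JDK's unsigned division (Java 8+); shown purely as an illustrative cross-check, this is
// not the code in the diff.
object UnsignedDivSketch extends App {
  println(java.lang.Long.divideUnsigned(-1L, 2L) == Long.MaxValue) // true: (2^64 - 1) / 2
  println(java.lang.Long.divideUnsigned(6L, 3L))                   // 2
}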
Examples: * unsignedLongDiv(-1, 2) == Long.MAX_VALUE unsignedLongDiv(6, 3) == 2 @@ -49,7 +47,7 @@ object NumberConverter { * @param v is treated as an unsigned 64-bit integer * @param radix must be between MIN_RADIX and MAX_RADIX */ - private def decode(v: Long, radix: Int): Unit = { + private def decode(v: Long, radix: Int, value: Array[Byte]): Unit = { var tmpV = v java.util.Arrays.fill(value, 0.asInstanceOf[Byte]) var i = value.length - 1 @@ -69,11 +67,9 @@ object NumberConverter { * @param fromPos is the first element that should be considered * @return the result should be treated as an unsigned 64-bit integer. */ - private def encode(radix: Int, fromPos: Int): Long = { + private def encode(radix: Int, fromPos: Int, value: Array[Byte]): Long = { var v: Long = 0L val bound = unsignedLongDiv(-1 - radix, radix) // Possible overflow once - // val - // exceeds this value var i = fromPos while (i < value.length && value(i) >= 0) { if (v >= bound) { @@ -94,7 +90,7 @@ object NumberConverter { * @param radix must be between MIN_RADIX and MAX_RADIX * @param fromPos is the first nonzero element */ - private def byte2char(radix: Int, fromPos: Int): Unit = { + private def byte2char(radix: Int, fromPos: Int, value: Array[Byte]): Unit = { var i = fromPos while (i < value.length) { value(i) = Character.toUpperCase(Character.forDigit(value(i), radix)).asInstanceOf[Byte] @@ -109,9 +105,9 @@ object NumberConverter { * @param radix must be between MIN_RADIX and MAX_RADIX * @param fromPos is the first nonzero element */ - private def char2byte(radix: Int, fromPos: Int): Unit = { + private def char2byte(radix: Int, fromPos: Int, value: Array[Byte]): Unit = { var i = fromPos - while ( i < value.length) { + while (i < value.length) { value(i) = Character.digit(value(i), radix).asInstanceOf[Byte] i += 1 } @@ -124,8 +120,8 @@ object NumberConverter { */ def convert(n: Array[Byte], fromBase: Int, toBase: Int ): UTF8String = { if (fromBase < Character.MIN_RADIX || fromBase > Character.MAX_RADIX - || Math.abs(toBase) < Character.MIN_RADIX - || Math.abs(toBase) > Character.MAX_RADIX) { + || Math.abs(toBase) < Character.MIN_RADIX + || Math.abs(toBase) > Character.MAX_RADIX) { return null } @@ -136,15 +132,16 @@ object NumberConverter { var (negative, first) = if (n(0) == '-') (true, 1) else (false, 0) // Copy the digits in the right side of the array + val temp = new Array[Byte](64) var i = 1 while (i <= n.length - first) { - value(value.length - i) = n(n.length - i) + temp(temp.length - i) = n(n.length - i) i += 1 } - char2byte(fromBase, value.length - n.length + first) + char2byte(fromBase, temp.length - n.length + first, temp) // Do the conversion by going through a 64 bit integer - var v = encode(fromBase, value.length - n.length + first) + var v = encode(fromBase, temp.length - n.length + first, temp) if (negative && toBase > 0) { if (v < 0) { v = -1 @@ -156,21 +153,20 @@ object NumberConverter { v = -v negative = true } - decode(v, Math.abs(toBase)) + decode(v, Math.abs(toBase), temp) // Find the first non-zero digit or the last digits if all are zero. 
val firstNonZeroPos = { - val firstNonZero = value.indexWhere( _ != 0) - if (firstNonZero != -1) firstNonZero else value.length - 1 + val firstNonZero = temp.indexWhere( _ != 0) + if (firstNonZero != -1) firstNonZero else temp.length - 1 } - - byte2char(Math.abs(toBase), firstNonZeroPos) + byte2char(Math.abs(toBase), firstNonZeroPos, temp) var resultStartPos = firstNonZeroPos if (negative && toBase < 0) { resultStartPos = firstNonZeroPos - 1 - value(resultStartPos) = '-' + temp(resultStartPos) = '-' } - UTF8String.fromBytes(java.util.Arrays.copyOfRange(value, resultStartPos, value.length)) + UTF8String.fromBytes(java.util.Arrays.copyOfRange(temp, resultStartPos, temp.length)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index f603cbfb0cc21..7101ca5a17de9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -42,11 +42,17 @@ object TypeUtils { } def checkForSameTypeInputExpr(types: Seq[DataType], caller: String): TypeCheckResult = { - if (types.distinct.size > 1) { - TypeCheckResult.TypeCheckFailure( - s"input to $caller should all be the same type, but it's " + - types.map(_.simpleString).mkString("[", ", ", "]")) + if (types.size <= 1) { + TypeCheckResult.TypeCheckSuccess } else { + val firstType = types.head + types.foreach { t => + if (!t.sameType(firstType)) { + return TypeCheckResult.TypeCheckFailure( + s"input to $caller should all be the same type, but it's " + + types.map(_.simpleString).mkString("[", ", ", "]")) + } + } TypeCheckResult.TypeCheckSuccess } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala index f879b34358a9b..4005087dad05a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala @@ -153,15 +153,7 @@ package object util { "`" + name.replace("`", "``") + "`" } - /** - * Returns the string representation of this expression that is safe to be put in - * code comments of generated code. The length is capped at 128 characters. - */ - def toCommentSafeString(str: String): String = { - val len = math.min(str.length, 128) - val suffix = if (str.length > len) "..." 
else "" - str.substring(0, len).replace("*/", "\\*\\/").replace("\\u", "\\\\u") + suffix - } + def toPrettySQL(e: Expression): String = usePrettyExpression(e).sql /* FIX ME implicit class debugLogging(a: Any) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala index 90af10f7a6b1e..1981fd8f0a1b5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala @@ -17,12 +17,10 @@ package org.apache.spark.sql.types -import scala.reflect.ClassTag -import scala.reflect.runtime.universe.{runtimeMirror, TypeTag} +import scala.reflect.runtime.universe.TypeTag -import org.apache.spark.sql.catalyst.ScalaReflectionLock +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.util.Utils /** * A non-concrete data type, reserved for internal uses. @@ -129,11 +127,6 @@ protected[sql] abstract class AtomicType extends DataType { private[sql] type InternalType private[sql] val tag: TypeTag[InternalType] private[sql] val ordering: Ordering[InternalType] - - @transient private[sql] val classTag = ScalaReflectionLock.synchronized { - val mirror = runtimeMirror(Utils.getSparkClassLoader) - ClassTag[InternalType](mirror.runtimeClass(tag.tpe)) - } } @@ -141,12 +134,13 @@ protected[sql] abstract class AtomicType extends DataType { * :: DeveloperApi :: * Numeric data types. */ +@DeveloperApi abstract class NumericType extends AtomicType { // Unfortunately we can't get this implicitly as that breaks Spark Serialization. In order for // implicitly[Numeric[JvmType]] to be valid, we have to change JvmType from a type variable to a // type parameter and add a numeric annotation (i.e., [JvmType : Numeric]). This gets // desugared by the compiler into an argument to the objects constructor. This means there is no - // longer an no argument constructor and thus the JVM cannot serialize the object anymore. + // longer a no argument constructor and thus the JVM cannot serialize the object anymore. private[sql] val numeric: Numeric[InternalType] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala index 520e344361625..82a03b0afc002 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala @@ -77,6 +77,8 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT override def simpleString: String = s"array<${elementType.simpleString}>" + override def catalogString: String = s"array<${elementType.catalogString}>" + override def sql: String = s"ARRAY<${elementType.sql}>" override private[spark] def asNullable: ArrayType = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateType.scala index 1d73e40ffcd36..2c966230e447e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateType.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.ScalaReflectionLock * * Please use the singleton [[DataTypes.DateType]]. * - * Internally, this is represented as the number of days from epoch (1970-01-01 00:00:00 UTC). 
+ * Internally, this is represented as the number of days from 1970-01-01. */ @DeveloperApi class DateType private() extends AtomicType { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 6f4ec6b701919..70859052872dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql.types -import java.math.{MathContext, RoundingMode} +import java.lang.{Long => JLong} +import java.math.{BigInteger, MathContext, RoundingMode} import org.apache.spark.annotation.DeveloperApi @@ -128,6 +129,21 @@ final class Decimal extends Ordered[Decimal] with Serializable { this } + /** + * Set this Decimal to the given BigInteger value. Will have precision 38 and scale 0. + */ + def set(bigintval: BigInteger): Decimal = { + // TODO: Remove this once we migrate to java8 and use longValueExact() instead. + require( + bigintval.compareTo(LONG_MAX_BIG_INT) <= 0 && bigintval.compareTo(LONG_MIN_BIG_INT) >= 0, + s"BigInteger $bigintval too large for decimal") + this.decimalVal = null + this.longVal = bigintval.longValue() + this._precision = DecimalType.MAX_PRECISION + this._scale = 0 + this + } + /** * Set this Decimal to the given Decimal value. */ @@ -155,6 +171,10 @@ final class Decimal extends Ordered[Decimal] with Serializable { } } + def toScalaBigInt: BigInt = BigInt(toLong) + + def toJavaBigInteger: java.math.BigInteger = java.math.BigInteger.valueOf(toLong) + def toUnscaledLong: Long = { if (decimalVal.ne(null)) { decimalVal.underlying().unscaledValue().longValue() @@ -222,10 +242,30 @@ final class Decimal extends Ordered[Decimal] with Serializable { if (scale < _scale) { // Easier case: we just need to divide our scale down val diff = _scale - scale - val droppedDigits = longVal % POW_10(diff) - longVal /= POW_10(diff) - if (math.abs(droppedDigits) * 2 >= POW_10(diff)) { - longVal += (if (longVal < 0) -1L else 1L) + val pow10diff = POW_10(diff) + // % and / always round to 0 + val droppedDigits = longVal % pow10diff + longVal /= pow10diff + roundMode match { + case ROUND_FLOOR => + if (droppedDigits < 0) { + longVal += -1L + } + case ROUND_CEILING => + if (droppedDigits > 0) { + longVal += 1L + } + case ROUND_HALF_UP => + if (math.abs(droppedDigits) * 2 >= pow10diff) { + longVal += (if (droppedDigits < 0) -1L else 1L) + } + case ROUND_HALF_EVEN => + val doubled = math.abs(droppedDigits) * 2 + if (doubled > pow10diff || doubled == pow10diff && longVal % 2 != 0) { + longVal += (if (droppedDigits < 0) -1L else 1L) + } + case _ => + sys.error(s"Not supported rounding mode: $roundMode") } } else if (scale > _scale) { // We might be able to multiply longVal by a power of 10 and not overflow, but if not, @@ -302,7 +342,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { } } - // HiveTypeCoercion will take care of the precision, scale of result + // TypeCoercion will take care of the precision, scale of result def * (that: Decimal): Decimal = Decimal(toJavaBigDecimal.multiply(that.toJavaBigDecimal, MATH_CONTEXT)) @@ -346,7 +386,7 @@ object Decimal { val ROUND_CEILING = BigDecimal.RoundingMode.CEILING val ROUND_FLOOR = BigDecimal.RoundingMode.FLOOR - /** Maximum number of decimal digits a Int can represent */ + /** Maximum number of decimal digits an Int can represent */ val MAX_INT_DIGITS = 9 /** Maximum number of decimal digits a Long can represent */ 
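// The changeScale branch added above mirrors the standard rounding modes for the
// Long-backed (compact) representation; the same behaviour is shown here with
// java.math.BigDecimal purely for clarity, not as the Decimal implementation.
object RoundingModeSketch extends App {
  import java.math.{BigDecimal => JBigDecimal, RoundingMode}

  val x = new JBigDecimal("2.5")
  println(x.setScale(0, RoundingMode.FLOOR))     // 2
  println(x.setScale(0, RoundingMode.CEILING))   // 3
  println(x.setScale(0, RoundingMode.HALF_UP))   // 3
  println(x.setScale(0, RoundingMode.HALF_EVEN)) // 2 (ties round to the even neighbour)
}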
@@ -361,6 +401,9 @@ object Decimal { private[sql] val ZERO = Decimal(0) private[sql] val ONE = Decimal(1) + private val LONG_MAX_BIG_INT = BigInteger.valueOf(JLong.MAX_VALUE) + private val LONG_MIN_BIG_INT = BigInteger.valueOf(JLong.MIN_VALUE) + def apply(value: Double): Decimal = new Decimal().set(value) def apply(value: Long): Decimal = new Decimal().set(value) @@ -371,6 +414,10 @@ object Decimal { def apply(value: java.math.BigDecimal): Decimal = new Decimal().set(value) + def apply(value: java.math.BigInteger): Decimal = new Decimal().set(value) + + def apply(value: scala.math.BigInt): Decimal = new Decimal().set(value.bigInteger) + def apply(value: BigDecimal, precision: Int, scale: Int): Decimal = new Decimal().set(value, precision, scale) @@ -386,6 +433,9 @@ object Decimal { def fromDecimal(value: Any): Decimal = { value match { case j: java.math.BigDecimal => apply(j) + case d: BigDecimal => apply(d) + case k: scala.math.BigInt => apply(k) + case l: java.math.BigInteger => apply(l) case d: Decimal => d } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index 9c1319c1c5e6f..6500875f95e54 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -117,6 +117,7 @@ object DecimalType extends AbstractDataType { private[sql] val LongDecimal = DecimalType(20, 0) private[sql] val FloatDecimal = DecimalType(14, 7) private[sql] val DoubleDecimal = DecimalType(30, 15) + private[sql] val BigIntDecimal = DecimalType(38, 0) private[sql] def forType(dataType: DataType): DecimalType = dataType match { case ByteType => ByteDecimal @@ -151,7 +152,7 @@ object DecimalType extends AbstractDataType { } /** - * Returns if dt is a DecimalType that fits inside a int + * Returns if dt is a DecimalType that fits inside an int */ def is32BitDecimalType(dt: DataType): Boolean = { dt match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala index 5474954af70e1..178960929bd83 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.types import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ +import org.apache.spark.annotation.DeveloperApi /** * :: DeveloperApi :: @@ -31,6 +32,7 @@ import org.json4s.JsonDSL._ * @param valueType The data type of map values. * @param valueContainsNull Indicates if map values have `null` values. 
*/ +@DeveloperApi case class MapType( keyType: DataType, valueType: DataType, @@ -62,6 +64,8 @@ case class MapType( override def simpleString: String = s"map<${keyType.simpleString},${valueType.simpleString}>" + override def catalogString: String = s"map<${keyType.catalogString},${valueType.catalogString}>" + override def sql: String = s"MAP<${keyType.sql}, ${valueType.sql}>" override private[spark] def asNullable: MapType = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala index 66f123682e117..657bd86ce17d9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala @@ -84,25 +84,28 @@ sealed class Metadata private[types] (private[types] val map: Map[String, Any]) override def equals(obj: Any): Boolean = { obj match { - case that: Metadata => - if (map.keySet == that.map.keySet) { - map.keys.forall { k => - (map(k), that.map(k)) match { - case (v0: Array[_], v1: Array[_]) => - v0.view == v1.view - case (v0, v1) => - v0 == v1 - } + case that: Metadata if map.size == that.map.size => + map.keysIterator.forall { key => + that.map.get(key) match { + case Some(otherValue) => + val ourValue = map.get(key).get + (ourValue, otherValue) match { + case (v0: Array[Long], v1: Array[Long]) => java.util.Arrays.equals(v0, v1) + case (v0: Array[Double], v1: Array[Double]) => java.util.Arrays.equals(v0, v1) + case (v0: Array[Boolean], v1: Array[Boolean]) => java.util.Arrays.equals(v0, v1) + case (v0: Array[AnyRef], v1: Array[AnyRef]) => java.util.Arrays.equals(v0, v1) + case (v0, v1) => v0 == v1 + } + case None => false } - } else { - false } case other => false } } - override def hashCode: Int = Metadata.hash(this) + private lazy val _hashCode: Int = Metadata.hash(this) + override def hashCode: Int = _hashCode private def get[T](key: String): T = { map(key).asInstanceOf[T] @@ -113,8 +116,10 @@ sealed class Metadata private[types] (private[types] val map: Map[String, Any]) object Metadata { + private[this] val _empty = new Metadata(Map.empty) + /** Returns an empty Metadata. */ - def empty: Metadata = new Metadata(Map.empty) + def empty: Metadata = _empty /** Creates a Metadata instance from JSON. 
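// Why the Metadata.equals rewrite above special-cases arrays: == on JVM arrays compares
// references, so element-wise comparison has to go through java.util.Arrays.equals.
// Illustrative check only.
object ArrayEqualitySketch extends App {
  val a = Array(1L, 2L, 3L)
  val b = Array(1L, 2L, 3L)
  println(a == b)                        // false: two distinct array instances
  println(java.util.Arrays.equals(a, b)) // true: same elements
}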
*/ def fromJson(json: String): Metadata = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index b06aa7bc5245d..55fdfbe3e0464 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -22,11 +22,12 @@ import scala.util.Try import org.json4s.JsonDSL._ -import org.apache.spark.SparkException +import org.apache.spark.{SparkEnv, SparkException} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, InterpretedOrdering} import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, LegacyTypeStringParser} import org.apache.spark.sql.catalyst.util.quoteIdentifier +import org.apache.spark.util.Utils /** * :: DeveloperApi :: @@ -103,6 +104,18 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru private lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap private lazy val nameToIndex: Map[String, Int] = fieldNames.zipWithIndex.toMap + override def equals(that: Any): Boolean = { + that match { + case StructType(otherFields) => + java.util.Arrays.equals( + fields.asInstanceOf[Array[AnyRef]], otherFields.asInstanceOf[Array[AnyRef]]) + case _ => false + } + } + + private lazy val _hashCode: Int = java.util.Arrays.hashCode(fields.asInstanceOf[Array[AnyRef]]) + override def hashCode(): Int = _hashCode + /** * Creates a new [[StructType]] by adding a new field. * {{{ @@ -281,7 +294,13 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru override def defaultSize: Int = fields.map(_.dataType.defaultSize).sum override def simpleString: String = { - val fieldTypes = fields.map(field => s"${field.name}:${field.dataType.simpleString}") + val fieldTypes = fields.view.map(field => s"${field.name}:${field.dataType.simpleString}") + Utils.truncatedString(fieldTypes, "struct<", ",", ">") + } + + override def catalogString: String = { + // in catalogString, we should not truncate + val fieldTypes = fields.map(field => s"${field.name}:${field.dataType.catalogString}") s"struct<${fieldTypes.mkString(",")}>" } @@ -365,10 +384,10 @@ object StructType extends AbstractDataType { StructType(fields.asScala) } - protected[sql] def fromAttributes(attributes: Seq[Attribute]): StructType = + private[sql] def fromAttributes(attributes: Seq[Attribute]): StructType = StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))) - def removeMetadata(key: String, dt: DataType): DataType = + private[sql] def removeMetadata(key: String, dt: DataType): DataType = dt match { case StructType(fields) => val newFields = fields.map { f => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UDTRegistration.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UDTRegistration.scala index 0f24e51ed2b76..20ec75c70615b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UDTRegistration.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UDTRegistration.scala @@ -27,7 +27,7 @@ import org.apache.spark.util.Utils * This object keeps the mappings between user classes and their User Defined Types (UDTs). * Previously we use the annotation `SQLUserDefinedType` to register UDTs for user classes. * However, by doing this, we add SparkSQL dependency on user classes. 
This object provides - * alterntive approach to register UDTs for user classes. + * alternative approach to register UDTs for user classes. */ private[spark] object UDTRegistration extends Serializable with Logging { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala index aa36121bdee7f..894631382f8ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala @@ -96,11 +96,12 @@ abstract class UserDefinedType[UserType >: Null] extends DataType with Serializa } /** - * ::DeveloperApi:: + * :: DeveloperApi :: * The user defined type in Python. * * Note: This can only be accessed via Python UDF, or accessed as serialized object. */ +@DeveloperApi private[sql] class PythonUserDefinedType( val sqlType: DataType, override val pyUDT: String, diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/streaming/JavaOutputModeSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/streaming/JavaOutputModeSuite.java new file mode 100644 index 0000000000000..e0a54fe30ac7d --- /dev/null +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/streaming/JavaOutputModeSuite.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming; + +import org.junit.Test; + +public class JavaOutputModeSuite { + + @Test + public void testOutputModes() { + OutputMode o1 = OutputMode.Append(); + assert(o1.toString().toLowerCase().contains("append")); + OutputMode o2 = OutputMode.Complete(); + assert (o2.toString().toLowerCase().contains("complete")); + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala index 711e8707116cf..850869799507f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala @@ -236,9 +236,8 @@ object RandomDataGenerator { // convert it to catalyst value to call udt's deserialize. 
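// The RandomDataGenerator refactor just below replaces an isDefined/get branch with
// Option.map; a minimal sketch of the equivalence (illustrative values and names only):
object OptionMapSketch extends App {
  val maybeGen: Option[() => Int] = Some(() => 42)

  // if/else with isDefined/get (the shape being removed)
  val before: Option[() => String] =
    if (maybeGen.isDefined) {
      val g = maybeGen.get
      Some(() => s"value=${g()}")
    } else {
      None
    }

  // equivalent Option.map (the shape being introduced)
  val after: Option[() => String] = maybeGen.map(g => () => s"value=${g()}")

  println(before.map(_.apply())) // Some(value=42)
  println(after.map(_.apply()))  // Some(value=42)
}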
val toCatalystType = CatalystTypeConverters.createToCatalystConverter(udt.sqlType) - if (maybeSqlTypeGenerator.isDefined) { - val sqlTypeGenerator = maybeSqlTypeGenerator.get - val generator = () => { + maybeSqlTypeGenerator.map { sqlTypeGenerator => + () => { val generatedScalaValue = sqlTypeGenerator.apply() if (generatedScalaValue == null) { null @@ -246,9 +245,6 @@ object RandomDataGenerator { udt.deserialize(toCatalystType(generatedScalaValue)) } } - Some(generator) - } else { - None } case unsupportedType => None } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 0672551b2972d..85563ddedc165 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -23,8 +23,10 @@ import java.sql.{Date, Timestamp} import scala.reflect.runtime.universe.typeOf import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.{BoundReference, SpecificMutableRow} +import org.apache.spark.sql.catalyst.expressions.{BoundReference, Literal, SpecificMutableRow} +import org.apache.spark.sql.catalyst.expressions.objects.NewInstance import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils case class PrimitiveData( @@ -277,6 +279,18 @@ class ScalaReflectionSuite extends SparkFunSuite { assert(anyTypes === Seq(classOf[java.lang.Object], classOf[java.lang.Object])) } + test("SPARK-15062: Get correct serializer for List[_]") { + val list = List(1, 2, 3) + val serializer = serializerFor[List[Int]](BoundReference( + 0, ObjectType(list.getClass), nullable = false)) + assert(serializer.children.size == 2) + assert(serializer.children.head.isInstanceOf[Literal]) + assert(serializer.children.head.asInstanceOf[Literal].value === UTF8String.fromString("value")) + assert(serializer.children.last.isInstanceOf[NewInstance]) + assert(serializer.children.last.asInstanceOf[NewInstance] + .cls.isAssignableFrom(classOf[org.apache.spark.sql.catalyst.util.GenericArrayData])) + } + private val dataTypeForComplexData = dataTypeFor[ComplexData] private val typeOfComplexData = typeOf[ComplexData] diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 1b08913ddd0e8..ff112c51697ad 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -23,8 +23,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Count} -import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Count, Max} import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData} @@ -111,7 +110,8 @@ class AnalysisErrorSuite extends AnalysisTest { 
"scalar subquery with 2 columns", testRelation.select( (ScalarSubquery(testRelation.select('a, dateLit.as('b))) + Literal(1)).as('a)), - "Scalar subquery must return only one column, but got 2" :: Nil) + "The number of columns in the subquery (2)" :: + "does not match the required number of columns (1)":: Nil) errorTest( "scalar subquery with no column", @@ -162,6 +162,16 @@ class AnalysisErrorSuite extends AnalysisTest { UnspecifiedFrame)).as('window)), "Distinct window functions are not supported" :: Nil) + errorTest( + "nested aggregate functions", + testRelation.groupBy('a)( + AggregateExpression( + Max(AggregateExpression(Count(Literal(1)), Complete, isDistinct = false)), + Complete, + isDistinct = false)), + "not allowed to use an aggregate function in the argument of another aggregate function." :: Nil + ) + errorTest( "offset window function", testRelation2.select( @@ -329,6 +339,31 @@ class AnalysisErrorSuite extends AnalysisTest { "The start time" :: "must be greater than or equal to 0." :: Nil ) + errorTest( + "generator nested in expressions", + listRelation.select(Explode('list) + 1), + "Generators are not supported when it's nested in expressions, but got: (explode(list) + 1)" + :: Nil + ) + + errorTest( + "generator appears in operator which is not Project", + listRelation.sortBy(Explode('list).asc), + "Generators are not supported outside the SELECT clause, but got: Sort" :: Nil + ) + + errorTest( + "num_rows in limit clause must be equal to or greater than 0", + listRelation.limit(-1), + "The limit expression must be equal to or greater than 0, but got -1" :: Nil + ) + + errorTest( + "more than one generators in SELECT", + listRelation.select(Explode('list), Explode('list)), + "Only one generator allowed per select clause but found 2: explode(list), explode(list)" :: Nil + ) + test("SPARK-6452 regression test") { // CheckAnalysis should throw AnalysisException when Aggregate contains missing attribute(s) // Since we manually construct the logical plan at here and Sum only accept @@ -459,11 +494,14 @@ class AnalysisErrorSuite extends AnalysisTest { val a = AttributeReference("a", IntegerType)() val b = AttributeReference("b", IntegerType)() val c = AttributeReference("c", BooleanType)() - val plan1 = Filter(Cast(In(a, Seq(ListQuery(LocalRelation(b)))), BooleanType), LocalRelation(a)) - assertAnalysisError(plan1, "Predicate sub-queries cannot be used in nested conditions" :: Nil) + val plan1 = Filter(Cast(Not(In(a, Seq(ListQuery(LocalRelation(b))))), BooleanType), + LocalRelation(a)) + assertAnalysisError(plan1, + "Null-aware predicate sub-queries cannot be used in nested conditions" :: Nil) - val plan2 = Filter(Or(In(a, Seq(ListQuery(LocalRelation(b)))), c), LocalRelation(a, c)) - assertAnalysisError(plan2, "Predicate sub-queries cannot be used in nested conditions" :: Nil) + val plan2 = Filter(Or(Not(In(a, Seq(ListQuery(LocalRelation(b))))), c), LocalRelation(a, c)) + assertAnalysisError(plan2, + "Null-aware predicate sub-queries cannot be used in nested conditions" :: Nil) } test("PredicateSubQuery correlated predicate is nested in an illegal plan") { @@ -496,12 +534,4 @@ class AnalysisErrorSuite extends AnalysisTest { LocalRelation(a)) assertAnalysisError(plan3, "Accessing outer query column is not allowed in" :: Nil) } - - test("Correlated Scalar Subquery") { - val a = AttributeReference("a", IntegerType)() - val b = AttributeReference("b", IntegerType)() - val sub = Project(Seq(b), Filter(EqualTo(UnresolvedAttribute("a"), b), LocalRelation(b))) - val plan = 
Project(Seq(a, Alias(ScalarSubquery(sub), "b")()), LocalRelation(a)) - assertAnalysisError(plan, "Correlated scalar subqueries are not supported." :: Nil) - } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index a63d1770f3255..1ba4f4a36d1a8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -182,8 +182,7 @@ class AnalysisSuite extends AnalysisTest { assert(pl(0).dataType == DoubleType) assert(pl(1).dataType == DoubleType) assert(pl(2).dataType == DoubleType) - // StringType will be promoted into Decimal(38, 18) - assert(pl(3).dataType == DecimalType(38, 22)) + assert(pl(3).dataType == DoubleType) assert(pl(4).dataType == DoubleType) } @@ -346,4 +345,50 @@ class AnalysisSuite extends AnalysisTest { assertAnalysisSuccess(query) } + + private def assertExpressionType( + expression: Expression, + expectedDataType: DataType): Unit = { + val afterAnalyze = + Project(Seq(Alias(expression, "a")()), OneRowRelation).analyze.expressions.head + if (!afterAnalyze.dataType.equals(expectedDataType)) { + fail( + s""" + |data type of expression $expression doesn't match expected: + |Actual data type: + |${afterAnalyze.dataType} + | + |Expected data type: + |${expectedDataType} + """.stripMargin) + } + } + + test("SPARK-15776: test whether Divide expression's data type can be deduced correctly by " + + "analyzer") { + assertExpressionType(sum(Divide(1, 2)), DoubleType) + assertExpressionType(sum(Divide(1.0, 2)), DoubleType) + assertExpressionType(sum(Divide(1, 2.0)), DoubleType) + assertExpressionType(sum(Divide(1.0, 2.0)), DoubleType) + assertExpressionType(sum(Divide(1, 2.0f)), DoubleType) + assertExpressionType(sum(Divide(1.0f, 2)), DoubleType) + assertExpressionType(sum(Divide(1, Decimal(2))), DecimalType(31, 11)) + assertExpressionType(sum(Divide(Decimal(1), 2)), DecimalType(31, 11)) + assertExpressionType(sum(Divide(Decimal(1), 2.0)), DoubleType) + assertExpressionType(sum(Divide(1.0, Decimal(2.0))), DoubleType) + } + + test("SPARK-18058: union and set operations shall not care about the nullability" + + " when comparing column types") { + val firstTable = LocalRelation( + AttributeReference("a", + StructType(Seq(StructField("a", IntegerType, nullable = true))), nullable = false)()) + val secondTable = LocalRelation( + AttributeReference("a", + StructType(Seq(StructField("a", IntegerType, nullable = false))), nullable = false)()) + + assertAnalysisSuccess(Union(firstTable, secondTable)) + assertAnalysisSuccess(Except(firstTable, secondTable)) + assertAnalysisSuccess(Intersect(firstTable, secondTable)) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index b1fcf011f43e6..3acb261800c0e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -31,7 +31,7 @@ trait AnalysisTest extends PlanTest { private def makeAnalyzer(caseSensitive: Boolean): Analyzer = { val conf = new SimpleCatalystConf(caseSensitive) val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf) - catalog.createTempTable("TaBlE", 
TestRelations.testRelation, overrideIfExists = true) + catalog.createTempView("TaBlE", TestRelations.testRelation, overrideIfExists = true) new Analyzer(catalog, conf) { override val extendedResolutionRules = EliminateSubqueryAliases :: Nil } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index b3b1f5b920a53..66d9b4c8e351f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -52,7 +52,7 @@ class DecimalPrecisionSuite extends PlanTest with BeforeAndAfter { private val b: Expression = UnresolvedAttribute("b") before { - catalog.createTempTable("table", relation, overrideIfExists = true) + catalog.createTempView("table", relation, overrideIfExists = true) } private def checkType(expression: Expression, expectedType: DataType): Unit = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 660dc86c3e284..3aefb3cfc3336 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -85,7 +85,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertError(Subtract('booleanField, 'booleanField), "requires (numeric or calendarinterval) type") assertError(Multiply('booleanField, 'booleanField), "requires numeric type") - assertError(Divide('booleanField, 'booleanField), "requires numeric type") + assertError(Divide('booleanField, 'booleanField), "requires (double or decimal) type") assertError(Remainder('booleanField, 'booleanField), "requires numeric type") assertError(BitwiseAnd('booleanField, 'booleanField), "requires integral type") @@ -166,6 +166,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertError(new Murmur3Hash(Nil), "function hash requires at least one argument") assertError(Explode('intField), "input to function explode should be array or map type") + assertError(PosExplode('intField), + "input to function explode should be array or map type") } test("check types for CreateNamedStruct") { @@ -214,7 +216,6 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) { assertError(operator(Seq('booleanField)), "requires at least 2 arguments") assertError(operator(Seq('intField, 'stringField)), "should all have the same type") - assertError(operator(Seq('intField, 'decimalField)), "should all have the same type") assertError(operator(Seq('mapField, 'mapField)), "does not support ordering") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala new file mode 100644 index 0000000000000..920c6ea50f4ba --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions.{Literal, Rand} +import org.apache.spark.sql.catalyst.expressions.aggregate.Count +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.types.{LongType, NullType} + +/** + * Unit tests for [[ResolveInlineTables]]. Note that there are also test cases defined in + * end-to-end tests (in sql/core module) for verifying the correct error messages are shown + * in negative cases. + */ +class ResolveInlineTablesSuite extends PlanTest with BeforeAndAfter { + + private def lit(v: Any): Literal = Literal(v) + + test("validate inputs are foldable") { + ResolveInlineTables.validateInputEvaluable( + UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(lit(1))))) + + // nondeterministic (rand) should not work + intercept[AnalysisException] { + ResolveInlineTables.validateInputEvaluable( + UnresolvedInlineTable(Seq("c1"), Seq(Seq(Rand(1))))) + } + + // aggregate should not work + intercept[AnalysisException] { + ResolveInlineTables.validateInputEvaluable( + UnresolvedInlineTable(Seq("c1"), Seq(Seq(Count(lit(1)))))) + } + + // unresolved attribute should not work + intercept[AnalysisException] { + ResolveInlineTables.validateInputEvaluable( + UnresolvedInlineTable(Seq("c1"), Seq(Seq(UnresolvedAttribute("A"))))) + } + } + + test("validate input dimensions") { + ResolveInlineTables.validateInputDimension( + UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2))))) + + // num alias != data dimension + intercept[AnalysisException] { + ResolveInlineTables.validateInputDimension( + UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(lit(1)), Seq(lit(2))))) + } + + // num alias == data dimension, but data themselves are inconsistent + intercept[AnalysisException] { + ResolveInlineTables.validateInputDimension( + UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(21), lit(22))))) + } + } + + test("do not fire the rule if not all expressions are resolved") { + val table = UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(UnresolvedAttribute("A")))) + assert(ResolveInlineTables(table) == table) + } + + test("convert") { + val table = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L)))) + val converted = ResolveInlineTables.convert(table) + + assert(converted.output.map(_.dataType) == Seq(LongType)) + assert(converted.data.size == 2) + assert(converted.data(0).getLong(0) == 1L) + assert(converted.data(1).getLong(0) == 2L) + } + + test("nullability inference in convert") { + val table1 = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L)))) + val converted1 = ResolveInlineTables.convert(table1) + assert(!converted1.schema.fields(0).nullable) + + val table2 = UnresolvedInlineTable(Seq("c1"), 
Seq(Seq(lit(1)), Seq(Literal(null, NullType)))) + val converted2 = ResolveInlineTables.convert(table2) + assert(converted2.schema.fields(0).nullable) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala index 1423a8705af27..100ec4d53fb81 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala @@ -100,7 +100,7 @@ class ResolveNaturalJoinSuite extends AnalysisTest { val naturalPlan = r3.join(r4, NaturalJoin(FullOuter), None) val usingPlan = r3.join(r4, UsingJoin(FullOuter, Seq(UnresolvedAttribute("b"))), None) val expected = r3.join(r4, FullOuter, Some(EqualTo(bNotNull, bNotNull))).select( - Alias(Coalesce(Seq(bNotNull, bNotNull)), "b")(), a, c) + Alias(Coalesce(Seq(b, b)), "b")(), a, c) checkAnalysis(naturalPlan, expected) checkAnalysis(usingPlan, expected) } @@ -113,4 +113,34 @@ class ResolveNaturalJoinSuite extends AnalysisTest { assert(error.message.contains( "using columns ['d] can not be resolved given input columns: [b, a, c]")) } + + test("using join with a case sensitive analyzer") { + val expected = r1.join(r2, Inner, Some(EqualTo(a, a))).select(a, b, c) + + { + val usingPlan = r1.join(r2, UsingJoin(Inner, Seq(UnresolvedAttribute("a"))), None) + checkAnalysis(usingPlan, expected, caseSensitive = true) + } + + { + val usingPlan = r1.join(r2, UsingJoin(Inner, Seq(UnresolvedAttribute("A"))), None) + assertAnalysisError( + usingPlan, + Seq("using columns ['A] can not be resolved given input columns: [b, a, c, a]")) + } + } + + test("using join with a case insensitive analyzer") { + val expected = r1.join(r2, Inner, Some(EqualTo(a, a))).select(a, b, c) + + { + val usingPlan = r1.join(r2, UsingJoin(Inner, Seq(UnresolvedAttribute("a"))), None) + checkAnalysis(usingPlan, expected, caseSensitive = false) + } + + { + val usingPlan = r1.join(r2, UsingJoin(Inner, Seq(UnresolvedAttribute("A"))), None) + checkAnalysis(usingPlan, expected, caseSensitive = false) + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala similarity index 78% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index b591861ac094c..9560563a8ca56 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -19,18 +19,20 @@ package org.apache.spark.sql.catalyst.analysis import java.sql.Timestamp +import org.apache.spark.sql.catalyst.analysis.TypeCoercion._ +import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval -class HiveTypeCoercionSuite extends PlanTest { +class TypeCoercionSuite extends PlanTest { 
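A brief note ahead of the renamed TypeCoercionSuite body that follows: `findTightestCommonTypeOfTwo`, exercised by the `widenTest` cases below, picks the wider of two types and is expected to behave symmetrically. The toy sketch below illustrates that idea over a hypothetical precedence list; `ToyTypeWidening` and the string type names are stand-ins, not Spark's actual implementation.

```scala
// Hypothetical toy version of "tightest common numeric type": widen along a fixed
// precedence list and return None when either type is outside the list.
object ToyTypeWidening {
  private val precedence: Seq[String] =
    Seq("ByteType", "ShortType", "IntegerType", "LongType", "FloatType", "DoubleType")

  def findTightestCommonType(t1: String, t2: String): Option[String] = {
    val (i1, i2) = (precedence.indexOf(t1), precedence.indexOf(t2))
    // Symmetric by construction: the result depends only on the larger index.
    if (i1 < 0 || i2 < 0) None else Some(precedence(math.max(i1, i2)))
  }

  def main(args: Array[String]): Unit = {
    println(findTightestCommonType("IntegerType", "LongType"))   // Some(LongType)
    println(findTightestCommonType("FloatType", "DoubleType"))   // Some(DoubleType)
    println(findTightestCommonType("IntegerType", "StringType")) // None
  }
}
```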
test("eligible implicit type cast") { def shouldCast(from: DataType, to: AbstractDataType, expected: DataType): Unit = { - val got = HiveTypeCoercion.ImplicitTypeCasts.implicitCast(Literal.create(null, from), to) + val got = TypeCoercion.ImplicitTypeCasts.implicitCast(Literal.create(null, from), to) assert(got.map(_.dataType) == Option(expected), s"Failed to cast $from to $to") } @@ -101,7 +103,7 @@ class HiveTypeCoercionSuite extends PlanTest { test("ineligible implicit type cast") { def shouldNotCast(from: DataType, to: AbstractDataType): Unit = { - val got = HiveTypeCoercion.ImplicitTypeCasts.implicitCast(Literal.create(null, from), to) + val got = TypeCoercion.ImplicitTypeCasts.implicitCast(Literal.create(null, from), to) assert(got.isEmpty, s"Should not be able to cast $from to $to, but got $got") } @@ -129,11 +131,11 @@ class HiveTypeCoercionSuite extends PlanTest { test("tightest common bound for types") { def widenTest(t1: DataType, t2: DataType, tightestCommon: Option[DataType]) { - var found = HiveTypeCoercion.findTightestCommonTypeOfTwo(t1, t2) + var found = TypeCoercion.findTightestCommonTypeOfTwo(t1, t2) assert(found == tightestCommon, s"Expected $tightestCommon as tightest common type for $t1 and $t2, found $found") // Test both directions to make sure the widening is symmetric. - found = HiveTypeCoercion.findTightestCommonTypeOfTwo(t2, t1) + found = TypeCoercion.findTightestCommonTypeOfTwo(t2, t1) assert(found == tightestCommon, s"Expected $tightestCommon as tightest common type for $t2 and $t1, found $found") } @@ -199,38 +201,49 @@ class HiveTypeCoercionSuite extends PlanTest { } private def ruleTest(rule: Rule[LogicalPlan], initial: Expression, transformed: Expression) { + ruleTest(Seq(rule), initial, transformed) + } + + private def ruleTest( + rules: Seq[Rule[LogicalPlan]], + initial: Expression, + transformed: Expression): Unit = { val testRelation = LocalRelation(AttributeReference("a", IntegerType)()) + val analyzer = new RuleExecutor[LogicalPlan] { + override val batches = Seq(Batch("Resolution", FixedPoint(3), rules: _*)) + } + comparePlans( - rule(Project(Seq(Alias(initial, "a")()), testRelation)), + analyzer.execute(Project(Seq(Alias(initial, "a")()), testRelation)), Project(Seq(Alias(transformed, "a")()), testRelation)) } test("cast NullType for expressions that implement ExpectsInputTypes") { - import HiveTypeCoercionSuite._ + import TypeCoercionSuite._ - ruleTest(HiveTypeCoercion.ImplicitTypeCasts, + ruleTest(TypeCoercion.ImplicitTypeCasts, AnyTypeUnaryExpression(Literal.create(null, NullType)), AnyTypeUnaryExpression(Literal.create(null, NullType))) - ruleTest(HiveTypeCoercion.ImplicitTypeCasts, + ruleTest(TypeCoercion.ImplicitTypeCasts, NumericTypeUnaryExpression(Literal.create(null, NullType)), NumericTypeUnaryExpression(Literal.create(null, DoubleType))) } test("cast NullType for binary operators") { - import HiveTypeCoercionSuite._ + import TypeCoercionSuite._ - ruleTest(HiveTypeCoercion.ImplicitTypeCasts, + ruleTest(TypeCoercion.ImplicitTypeCasts, AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)), AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType))) - ruleTest(HiveTypeCoercion.ImplicitTypeCasts, + ruleTest(TypeCoercion.ImplicitTypeCasts, NumericTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)), NumericTypeBinaryOperator(Literal.create(null, DoubleType), Literal.create(null, DoubleType))) } test("coalesce casts") { - 
ruleTest(HiveTypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion, Coalesce(Literal(1.0) :: Literal(1) :: Literal.create(1.0, FloatType) @@ -239,7 +252,7 @@ class HiveTypeCoercionSuite extends PlanTest { :: Cast(Literal(1), DoubleType) :: Cast(Literal.create(1.0, FloatType), DoubleType) :: Nil)) - ruleTest(HiveTypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion, Coalesce(Literal(1L) :: Literal(1) :: Literal(new java.math.BigDecimal("1000000000000000000000")) @@ -251,7 +264,7 @@ class HiveTypeCoercionSuite extends PlanTest { } test("CreateArray casts") { - ruleTest(HiveTypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion, CreateArray(Literal(1.0) :: Literal(1) :: Literal.create(1.0, FloatType) @@ -261,7 +274,7 @@ class HiveTypeCoercionSuite extends PlanTest { :: Cast(Literal.create(1.0, FloatType), DoubleType) :: Nil)) - ruleTest(HiveTypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion, CreateArray(Literal(1.0) :: Literal(1) :: Literal("a") @@ -270,11 +283,29 @@ class HiveTypeCoercionSuite extends PlanTest { :: Cast(Literal(1), StringType) :: Cast(Literal("a"), StringType) :: Nil)) + + ruleTest(TypeCoercion.FunctionArgumentConversion, + CreateArray(Literal.create(null, DecimalType(5, 3)) + :: Literal(1) + :: Nil), + CreateArray(Literal.create(null, DecimalType(5, 3)).cast(DecimalType(13, 3)) + :: Literal(1).cast(DecimalType(13, 3)) + :: Nil)) + + ruleTest(TypeCoercion.FunctionArgumentConversion, + CreateArray(Literal.create(null, DecimalType(5, 3)) + :: Literal.create(null, DecimalType(22, 10)) + :: Literal.create(null, DecimalType(38, 38)) + :: Nil), + CreateArray(Literal.create(null, DecimalType(5, 3)).cast(DecimalType(38, 38)) + :: Literal.create(null, DecimalType(22, 10)).cast(DecimalType(38, 38)) + :: Literal.create(null, DecimalType(38, 38)).cast(DecimalType(38, 38)) + :: Nil)) } test("CreateMap casts") { // type coercion for map keys - ruleTest(HiveTypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion, CreateMap(Literal(1) :: Literal("a") :: Literal.create(2.0, FloatType) @@ -285,8 +316,19 @@ class HiveTypeCoercionSuite extends PlanTest { :: Cast(Literal.create(2.0, FloatType), FloatType) :: Literal("b") :: Nil)) + ruleTest(TypeCoercion.FunctionArgumentConversion, + CreateMap(Literal.create(null, DecimalType(5, 3)) + :: Literal("a") + :: Literal.create(2.0, FloatType) + :: Literal("b") + :: Nil), + CreateMap(Literal.create(null, DecimalType(5, 3)).cast(DoubleType) + :: Literal("a") + :: Literal.create(2.0, FloatType).cast(DoubleType) + :: Literal("b") + :: Nil)) // type coercion for map values - ruleTest(HiveTypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion, CreateMap(Literal(1) :: Literal("a") :: Literal(2) @@ -297,8 +339,19 @@ class HiveTypeCoercionSuite extends PlanTest { :: Literal(2) :: Cast(Literal(3.0), StringType) :: Nil)) + ruleTest(TypeCoercion.FunctionArgumentConversion, + CreateMap(Literal(1) + :: Literal.create(null, DecimalType(38, 0)) + :: Literal(2) + :: Literal.create(null, DecimalType(38, 38)) + :: Nil), + CreateMap(Literal(1) + :: Literal.create(null, DecimalType(38, 0)).cast(DecimalType(38, 38)) + :: Literal(2) + :: Literal.create(null, DecimalType(38, 38)).cast(DecimalType(38, 38)) + :: Nil)) // type coercion for both map keys and values - ruleTest(HiveTypeCoercion.FunctionArgumentConversion, + 
ruleTest(TypeCoercion.FunctionArgumentConversion, CreateMap(Literal(1) :: Literal("a") :: Literal(2.0) @@ -313,7 +366,7 @@ class HiveTypeCoercionSuite extends PlanTest { test("greatest/least cast") { for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) { - ruleTest(HiveTypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion, operator(Literal(1.0) :: Literal(1) :: Literal.create(1.0, FloatType) @@ -322,7 +375,7 @@ class HiveTypeCoercionSuite extends PlanTest { :: Cast(Literal(1), DoubleType) :: Cast(Literal.create(1.0, FloatType), DoubleType) :: Nil)) - ruleTest(HiveTypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion, operator(Literal(1L) :: Literal(1) :: Literal(new java.math.BigDecimal("1000000000000000000000")) @@ -331,23 +384,50 @@ class HiveTypeCoercionSuite extends PlanTest { :: Cast(Literal(1), DecimalType(22, 0)) :: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType(22, 0)) :: Nil)) + ruleTest(TypeCoercion.FunctionArgumentConversion, + operator(Literal(1.0) + :: Literal.create(null, DecimalType(10, 5)) + :: Literal(1) + :: Nil), + operator(Literal(1.0).cast(DoubleType) + :: Literal.create(null, DecimalType(10, 5)).cast(DoubleType) + :: Literal(1).cast(DoubleType) + :: Nil)) + ruleTest(TypeCoercion.FunctionArgumentConversion, + operator(Literal.create(null, DecimalType(15, 0)) + :: Literal.create(null, DecimalType(10, 5)) + :: Literal(1) + :: Nil), + operator(Literal.create(null, DecimalType(15, 0)).cast(DecimalType(20, 5)) + :: Literal.create(null, DecimalType(10, 5)).cast(DecimalType(20, 5)) + :: Literal(1).cast(DecimalType(20, 5)) + :: Nil)) + ruleTest(TypeCoercion.FunctionArgumentConversion, + operator(Literal.create(2L, LongType) + :: Literal(1) + :: Literal.create(null, DecimalType(10, 5)) + :: Nil), + operator(Literal.create(2L, LongType).cast(DecimalType(25, 5)) + :: Literal(1).cast(DecimalType(25, 5)) + :: Literal.create(null, DecimalType(10, 5)).cast(DecimalType(25, 5)) + :: Nil)) } } test("nanvl casts") { - ruleTest(HiveTypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion, NaNvl(Literal.create(1.0, FloatType), Literal.create(1.0, DoubleType)), NaNvl(Cast(Literal.create(1.0, FloatType), DoubleType), Literal.create(1.0, DoubleType))) - ruleTest(HiveTypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion, NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, FloatType)), NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(1.0, FloatType), DoubleType))) - ruleTest(HiveTypeCoercion.FunctionArgumentConversion, + ruleTest(TypeCoercion.FunctionArgumentConversion, NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType)), NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType))) } test("type coercion for If") { - val rule = HiveTypeCoercion.IfCoercion + val rule = TypeCoercion.IfCoercion ruleTest(rule, If(Literal(true), Literal(1), Literal(1L)), @@ -367,20 +447,20 @@ class HiveTypeCoercionSuite extends PlanTest { } test("type coercion for CaseKeyWhen") { - ruleTest(HiveTypeCoercion.ImplicitTypeCasts, + ruleTest(TypeCoercion.ImplicitTypeCasts, CaseKeyWhen(Literal(1.toShort), Seq(Literal(1), Literal("a"))), CaseKeyWhen(Cast(Literal(1.toShort), IntegerType), Seq(Literal(1), Literal("a"))) ) - ruleTest(HiveTypeCoercion.CaseWhenCoercion, + ruleTest(TypeCoercion.CaseWhenCoercion, CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))), 
CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))) ) - ruleTest(HiveTypeCoercion.CaseWhenCoercion, + ruleTest(TypeCoercion.CaseWhenCoercion, CaseWhen(Seq((Literal(true), Literal(1.2))), Literal.create(1, DecimalType(7, 2))), CaseWhen(Seq((Literal(true), Literal(1.2))), Cast(Literal.create(1, DecimalType(7, 2)), DoubleType)) ) - ruleTest(HiveTypeCoercion.CaseWhenCoercion, + ruleTest(TypeCoercion.CaseWhenCoercion, CaseWhen(Seq((Literal(true), Literal(100L))), Literal.create(1, DecimalType(7, 2))), CaseWhen(Seq((Literal(true), Cast(Literal(100L), DecimalType(22, 2)))), Cast(Literal.create(1, DecimalType(7, 2)), DecimalType(22, 2))) @@ -388,7 +468,7 @@ class HiveTypeCoercionSuite extends PlanTest { } test("BooleanEquality type cast") { - val be = HiveTypeCoercion.BooleanEquality + val be = TypeCoercion.BooleanEquality // Use something more than a literal to avoid triggering the simplification rules. val one = Add(Literal(Decimal(1)), Literal(Decimal(0))) @@ -414,7 +494,7 @@ class HiveTypeCoercionSuite extends PlanTest { } test("BooleanEquality simplification") { - val be = HiveTypeCoercion.BooleanEquality + val be = TypeCoercion.BooleanEquality ruleTest(be, EqualTo(Literal(true), Literal(1)), @@ -473,7 +553,7 @@ class HiveTypeCoercionSuite extends PlanTest { AttributeReference("f", FloatType)(), AttributeReference("l", LongType)()) - val wt = HiveTypeCoercion.WidenSetOperationTypes + val wt = TypeCoercion.WidenSetOperationTypes val expectedTypes = Seq(StringType, DecimalType.SYSTEM_DEFAULT, FloatType, DoubleType) val r1 = wt(Except(firstTable, secondTable)).asInstanceOf[Except] @@ -512,7 +592,7 @@ class HiveTypeCoercionSuite extends PlanTest { AttributeReference("p", ByteType)(), AttributeReference("q", DoubleType)()) - val wt = HiveTypeCoercion.WidenSetOperationTypes + val wt = TypeCoercion.WidenSetOperationTypes val expectedTypes = Seq(StringType, DecimalType.SYSTEM_DEFAULT, FloatType, DoubleType) val unionRelation = wt( @@ -536,7 +616,7 @@ class HiveTypeCoercionSuite extends PlanTest { } } - val dp = HiveTypeCoercion.WidenSetOperationTypes + val dp = TypeCoercion.WidenSetOperationTypes val left1 = LocalRelation( AttributeReference("l", DecimalType(10, 8))()) @@ -584,7 +664,7 @@ class HiveTypeCoercionSuite extends PlanTest { } test("rule for date/timestamp operations") { - val dateTimeOperations = HiveTypeCoercion.DateTimeOperations + val dateTimeOperations = TypeCoercion.DateTimeOperations val date = Literal(new java.sql.Date(0L)) val timestamp = Literal(new Timestamp(0L)) val interval = Literal(new CalendarInterval(0, 0)) @@ -615,7 +695,7 @@ class HiveTypeCoercionSuite extends PlanTest { */ test("make sure rules do not fire early") { // InConversion - val inConversion = HiveTypeCoercion.InConversion + val inConversion = TypeCoercion.InConversion ruleTest(inConversion, In(UnresolvedAttribute("a"), Seq(Literal(1))), In(UnresolvedAttribute("a"), Seq(Literal(1))) @@ -630,10 +710,37 @@ class HiveTypeCoercionSuite extends PlanTest { Seq(Cast(Literal(1), StringType), Cast(Literal("b"), StringType))) ) } + + test("SPARK-15776 Divide expression's dataType should be casted to Double or Decimal " + + "in aggregation function like sum") { + val rules = Seq(FunctionArgumentConversion, Division) + // Casts Integer to Double + ruleTest(rules, sum(Divide(4, 3)), sum(Divide(Cast(4, DoubleType), Cast(3, DoubleType)))) + // Left expression is Double, right expression is Int. Another rule ImplicitTypeCasts will + // cast the right expression to Double. 
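Editorial aside on the SPARK-15776 cases around this point: the Division coercion exists because integral division truncates, so `Divide` operands that are not already double or decimal are cast to `DoubleType` before evaluation. The snippet below is a plain-Scala illustration of that difference; `DivisionPromotionDemo` is a hypothetical name and nothing here is Spark-specific.

```scala
// Plain Scala: integral division truncates, floating-point division does not.
// This is the behaviour the Division rule avoids by casting Int operands to DoubleType,
// as the surrounding ruleTest cases expect.
object DivisionPromotionDemo {
  def main(args: Array[String]): Unit = {
    val truncated: Int = 4 / 3             // 1
    val promoted: Double = 4.toDouble / 3  // 1.3333333333333333
    println(s"integer division: $truncated, double division: $promoted")
  }
}
```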
+ ruleTest(rules, sum(Divide(4.0, 3)), sum(Divide(4.0, 3))) + // Left expression is Int, right expression is Double + ruleTest(rules, sum(Divide(4, 3.0)), sum(Divide(Cast(4, DoubleType), Cast(3.0, DoubleType)))) + // Casts Float to Double + ruleTest( + rules, + sum(Divide(4.0f, 3)), + sum(Divide(Cast(4.0f, DoubleType), Cast(3, DoubleType)))) + // Left expression is Decimal, right expression is Int. Another rule DecimalPrecision will cast + // the right expression to Decimal. + ruleTest(rules, sum(Divide(Decimal(4.0), 3)), sum(Divide(Decimal(4.0), 3))) + } + + test("SPARK-17117 null type coercion in divide") { + val rules = Seq(FunctionArgumentConversion, Division, ImplicitTypeCasts) + val nullLit = Literal.create(null, NullType) + ruleTest(rules, Divide(1L, nullLit), Divide(Cast(1L, DoubleType), Cast(nullLit, DoubleType))) + ruleTest(rules, Divide(nullLit, 1L), Divide(Cast(nullLit, DoubleType), Cast(1L, DoubleType))) + } } -object HiveTypeCoercionSuite { +object TypeCoercionSuite { case class AnyTypeUnaryExpression(child: Expression) extends UnaryExpression with ExpectsInputTypes with Unevaluable { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index ce00a03e764fd..ff1bb126f463d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -19,14 +19,20 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.InternalOutputModes._ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, NamedExpression} +import org.apache.spark.sql.catalyst.expressions.aggregate.Count import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.IntegerType +/** A dummy command for testing unsupported operations. 
*/ +case class DummyCommand() extends Command + class UnsupportedOperationsSuite extends SparkFunSuite { val attribute = AttributeReference("a", IntegerType, nullable = true)() @@ -44,12 +50,12 @@ class UnsupportedOperationsSuite extends SparkFunSuite { assertNotSupportedInBatchPlan( "streaming source", streamRelation, - Seq("with streaming source", "startStream")) + Seq("with streaming source", "start")) assertNotSupportedInBatchPlan( "select on streaming source", streamRelation.select($"count(*)"), - Seq("with streaming source", "startStream")) + Seq("with streaming source", "start")) /* @@ -61,38 +67,36 @@ class UnsupportedOperationsSuite extends SparkFunSuite { // Batch plan in streaming query testError( "streaming plan - no streaming source", - Seq("without streaming source", "startStream")) { + Seq("without streaming source", "start")) { UnsupportedOperationChecker.checkForStreaming(batchRelation.select($"count(*)"), Append) } // Commands assertNotSupportedInStreamingPlan( "commmands", - DescribeFunction("func", true), + DummyCommand(), outputMode = Append, expectedMsgs = "commands" :: Nil) - // Aggregates: Not supported on streams in Append mode - assertSupportedInStreamingPlan( - "aggregate - batch with update output mode", - batchRelation.groupBy("a")("count(*)"), - outputMode = Update) + // Aggregation: Multiple streaming aggregations not supported + def aggExprs(name: String): Seq[NamedExpression] = Seq(Count("*").as(name)) assertSupportedInStreamingPlan( - "aggregate - batch with append output mode", - batchRelation.groupBy("a")("count(*)"), - outputMode = Append) + "aggregate - multiple batch aggregations", + Aggregate(Nil, aggExprs("c"), Aggregate(Nil, aggExprs("d"), batchRelation)), + Append) assertSupportedInStreamingPlan( - "aggregate - stream with update output mode", - streamRelation.groupBy("a")("count(*)"), - outputMode = Update) + "aggregate - multiple aggregations but only one streaming aggregation", + Aggregate(Nil, aggExprs("c"), batchRelation).join( + Aggregate(Nil, aggExprs("d"), streamRelation), joinType = Inner), + Update) assertNotSupportedInStreamingPlan( - "aggregate - stream with append output mode", - streamRelation.groupBy("a")("count(*)"), - outputMode = Append, - Seq("aggregation", "append output mode")) + "aggregate - multiple streaming aggregations", + Aggregate(Nil, aggExprs("c"), Aggregate(Nil, aggExprs("d"), streamRelation)), + outputMode = Update, + expectedMsgs = Seq("multiple streaming aggregations")) // Inner joins: Stream-stream not supported testBinaryOperationInStreamingPlan( @@ -182,15 +186,30 @@ class UnsupportedOperationsSuite extends SparkFunSuite { _.intersect(_), streamStreamSupported = false) - - // Unary operations + // Sort: supported only on batch subplans and on aggregation + complete output mode testUnaryOperatorInStreamingPlan("sort", Sort(Nil, true, _)) + assertSupportedInStreamingPlan( + "sort - sort over aggregated data in Complete output mode", + streamRelation.groupBy()(Count("*")).sortBy(), + Complete) + assertNotSupportedInStreamingPlan( + "sort - sort over aggregated data in Update output mode", + streamRelation.groupBy()(Count("*")).sortBy(), + Update, + Seq("sort", "aggregat", "complete")) // sort on aggregations is supported on Complete mode only + + + // Other unary operations testUnaryOperatorInStreamingPlan("sort partitions", SortPartitions(Nil, _), expectedMsg = "sort") testUnaryOperatorInStreamingPlan( "sample", Sample(0.1, 1, true, 1L, _)(), expectedMsg = "sampling") testUnaryOperatorInStreamingPlan( "window", 
Window(Nil, Nil, Nil, _), expectedMsg = "non-time-based windows") + // Output modes with aggregation and non-aggregation plans + testOutputMode(Append, shouldSupportAggregation = false) + testOutputMode(Update, shouldSupportAggregation = true) + testOutputMode(Complete, shouldSupportAggregation = true) /* ======================================================================================= @@ -289,6 +308,38 @@ class UnsupportedOperationsSuite extends SparkFunSuite { outputMode) } + /** Test output mode with and without aggregation in the streaming plan */ + def testOutputMode( + outputMode: OutputMode, + shouldSupportAggregation: Boolean): Unit = { + + // aggregation + if (shouldSupportAggregation) { + assertNotSupportedInStreamingPlan( + s"$outputMode output mode - no aggregation", + streamRelation.where($"a" > 1), + outputMode = outputMode, + Seq("aggregation", s"$outputMode output mode")) + + assertSupportedInStreamingPlan( + s"$outputMode output mode - aggregation", + streamRelation.groupBy("a")("count(*)"), + outputMode = outputMode) + + } else { + assertSupportedInStreamingPlan( + s"$outputMode output mode - no aggregation", + streamRelation.where($"a" > 1), + outputMode = outputMode) + + assertNotSupportedInStreamingPlan( + s"$outputMode output mode - aggregation", + streamRelation.groupBy("a")("count(*)"), + outputMode = outputMode, + Seq("aggregation", s"$outputMode output mode")) + } + } + /** * Assert that the logical plan is supported as subplan insider a streaming plan. * @@ -353,17 +404,11 @@ class UnsupportedOperationsSuite extends SparkFunSuite { val e = intercept[AnalysisException] { testBody } - - if (!expectedMsgs.map(_.toLowerCase).forall(e.getMessage.toLowerCase.contains)) { - fail( - s"""Exception message should contain the following substrings: - | - | ${expectedMsgs.mkString("\n ")} - | - |Actual exception message: - | - | ${e.getMessage} - """.stripMargin) + expectedMsgs.foreach { m => + if (!e.getMessage.toLowerCase.contains(m.toLowerCase)) { + fail(s"Exception message should contain: '$m', " + + s"actual exception message:\n\t'${e.getMessage}'") + } } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala index d739b177430c7..31e422e91aea3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala @@ -17,11 +17,15 @@ package org.apache.spark.sql.catalyst.catalog +import java.io.File +import java.net.URI + import org.scalatest.BeforeAndAfterEach import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.analysis.{FunctionAlreadyExistsException, NoSuchDatabaseException, NoSuchFunctionException} import org.apache.spark.util.Utils @@ -198,6 +202,13 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac } } + test("rename table when destination table already exists") { + val catalog = newBasicCatalog() + intercept[AnalysisException] { + catalog.renameTable("db2", "tbl1", "tbl2") + } + } + test("alter table") { val catalog = newBasicCatalog() val tbl1 = catalog.getTable("db2", "tbl1") @@ -356,6 +367,13 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac } } + test("rename 
partitions when the new partition already exists") { + val catalog = newBasicCatalog() + intercept[AnalysisException] { + catalog.renamePartitions("db2", "tbl2", Seq(part1.spec), Seq(part2.spec)) + } + } + test("alter partitions") { val catalog = newBasicCatalog() try { @@ -365,6 +383,8 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac // See HIVE-2742 for more detail. catalog.setCurrentDatabase("db2") val newLocation = newUriForDatabase() + val newSerde = "com.sparkbricks.text.EasySerde" + val newSerdeProps = Map("spark" -> "bricks", "compressed" -> "false") // alter but keep spec the same val oldPart1 = catalog.getPartition("db2", "tbl2", part1.spec) val oldPart2 = catalog.getPartition("db2", "tbl2", part2.spec) @@ -377,6 +397,14 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac assert(newPart2.storage.locationUri == Some(newLocation)) assert(oldPart1.storage.locationUri != Some(newLocation)) assert(oldPart2.storage.locationUri != Some(newLocation)) + // alter other storage information + catalog.alterPartitions("db2", "tbl2", Seq( + oldPart1.copy(storage = storageFormat.copy(serde = Some(newSerde))), + oldPart2.copy(storage = storageFormat.copy(serdeProperties = newSerdeProps)))) + val newPart1b = catalog.getPartition("db2", "tbl2", part1.spec) + val newPart2b = catalog.getPartition("db2", "tbl2", part2.spec) + assert(newPart1b.storage.serde == Some(newSerde)) + assert(newPart2b.storage.serdeProperties == newSerdeProps) // alter but change spec, should fail because new partition specs do not exist yet val badPart1 = part1.copy(spec = Map("a" -> "v1", "b" -> "v2")) val badPart2 = part2.copy(spec = Map("a" -> "v3", "b" -> "v4")) @@ -412,14 +440,14 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac test("create function when database does not exist") { val catalog = newBasicCatalog() - intercept[AnalysisException] { + intercept[NoSuchDatabaseException] { catalog.createFunction("does_not_exist", newFunc()) } } test("create function that already exists") { val catalog = newBasicCatalog() - intercept[AnalysisException] { + intercept[FunctionAlreadyExistsException] { catalog.createFunction("db2", newFunc("func1")) } } @@ -433,14 +461,14 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac test("drop function when database does not exist") { val catalog = newBasicCatalog() - intercept[AnalysisException] { + intercept[NoSuchDatabaseException] { catalog.dropFunction("does_not_exist", "something") } } test("drop function that does not exist") { val catalog = newBasicCatalog() - intercept[AnalysisException] { + intercept[NoSuchFunctionException] { catalog.dropFunction("db2", "does_not_exist") } } @@ -449,15 +477,15 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac val catalog = newBasicCatalog() assert(catalog.getFunction("db2", "func1") == CatalogFunction(FunctionIdentifier("func1", Some("db2")), funcClass, - Seq.empty[(String, String)])) - intercept[AnalysisException] { + Seq.empty[FunctionResource])) + intercept[NoSuchFunctionException] { catalog.getFunction("db2", "does_not_exist") } } test("get function when database does not exist") { val catalog = newBasicCatalog() - intercept[AnalysisException] { + intercept[NoSuchDatabaseException] { catalog.getFunction("does_not_exist", "func1") } } @@ -467,19 +495,27 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac val newName = "funcky" 
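Editorial aside on the `intercept` changes above: narrowing `intercept[AnalysisException]` to subclasses such as `NoSuchDatabaseException` or `FunctionAlreadyExistsException` makes the tests fail if the catalog ever starts throwing the wrong, but still `AnalysisException`-derived, error. Below is a minimal ScalaTest-style sketch of the same pattern; the `Toy*` exception classes and `InterceptNarrowingSuite` are hypothetical stand-ins, not Spark classes.

```scala
import org.scalatest.FunSuite

// Hypothetical stand-ins for a hierarchy like AnalysisException and its
// NoSuchDatabaseException subclass.
class ToyAnalysisException(msg: String) extends Exception(msg)
class ToyNoSuchDatabaseException(db: String)
  extends ToyAnalysisException(s"Database '$db' not found")

class InterceptNarrowingSuite extends FunSuite {
  private def getDatabase(name: String): Nothing =
    throw new ToyNoSuchDatabaseException(name)

  test("narrow intercept pins down the exact failure") {
    // Passes only if the specific subclass is thrown, not just any ToyAnalysisException.
    intercept[ToyNoSuchDatabaseException] {
      getDatabase("does_not_exist")
    }
  }
}
```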
assert(catalog.getFunction("db2", "func1").className == funcClass) catalog.renameFunction("db2", "func1", newName) - intercept[AnalysisException] { catalog.getFunction("db2", "func1") } + intercept[NoSuchFunctionException] { catalog.getFunction("db2", "func1") } assert(catalog.getFunction("db2", newName).identifier.funcName == newName) assert(catalog.getFunction("db2", newName).className == funcClass) - intercept[AnalysisException] { catalog.renameFunction("db2", "does_not_exist", "me") } + intercept[NoSuchFunctionException] { catalog.renameFunction("db2", "does_not_exist", "me") } } test("rename function when database does not exist") { val catalog = newBasicCatalog() - intercept[AnalysisException] { + intercept[NoSuchDatabaseException] { catalog.renameFunction("does_not_exist", "func1", "func5") } } + test("rename function when new function already exists") { + val catalog = newBasicCatalog() + catalog.createFunction("db2", newFunc("func2", Some("db2"))) + intercept[FunctionAlreadyExistsException] { + catalog.renameFunction("db2", "func1", "func2") + } + } + test("list functions") { val catalog = newBasicCatalog() catalog.createFunction("db2", newFunc("func2")) @@ -488,6 +524,96 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac assert(catalog.listFunctions("db2", "func*").toSet == Set("func1", "func2")) } + // -------------------------------------------------------------------------- + // File System operations + // -------------------------------------------------------------------------- + + private def exists(uri: String, children: String*): Boolean = { + val base = new File(new URI(uri)) + children.foldLeft(base) { + case (parent, child) => new File(parent, child) + }.exists() + } + + test("create/drop database should create/delete the directory") { + val catalog = newBasicCatalog() + val db = newDb("mydb") + catalog.createDatabase(db, ignoreIfExists = false) + assert(exists(db.locationUri)) + + catalog.dropDatabase("mydb", ignoreIfNotExists = false, cascade = false) + assert(!exists(db.locationUri)) + } + + test("create/drop/rename table should create/delete/rename the directory") { + val catalog = newBasicCatalog() + val db = catalog.getDatabase("db1") + val table = CatalogTable( + identifier = TableIdentifier("my_table", Some("db1")), + tableType = CatalogTableType.MANAGED, + storage = CatalogStorageFormat(None, None, None, None, false, Map.empty), + schema = Seq(CatalogColumn("a", "int"), CatalogColumn("b", "string")) + ) + + catalog.createTable("db1", table, ignoreIfExists = false) + assert(exists(db.locationUri, "my_table")) + + catalog.renameTable("db1", "my_table", "your_table") + assert(!exists(db.locationUri, "my_table")) + assert(exists(db.locationUri, "your_table")) + + catalog.dropTable("db1", "your_table", ignoreIfNotExists = false) + assert(!exists(db.locationUri, "your_table")) + + val externalTable = CatalogTable( + identifier = TableIdentifier("external_table", Some("db1")), + tableType = CatalogTableType.EXTERNAL, + storage = CatalogStorageFormat( + Some(Utils.createTempDir().getAbsolutePath), + None, None, None, false, Map.empty), + schema = Seq(CatalogColumn("a", "int"), CatalogColumn("b", "string")) + ) + catalog.createTable("db1", externalTable, ignoreIfExists = false) + assert(!exists(db.locationUri, "external_table")) + } + + test("create/drop/rename partitions should create/delete/rename the directory") { + val catalog = newBasicCatalog() + val databaseDir = catalog.getDatabase("db1").locationUri + val table = CatalogTable( 
+ identifier = TableIdentifier("tbl", Some("db1")), + tableType = CatalogTableType.MANAGED, + storage = CatalogStorageFormat(None, None, None, None, false, Map.empty), + schema = Seq( + CatalogColumn("col1", "int"), + CatalogColumn("col2", "string"), + CatalogColumn("a", "int"), + CatalogColumn("b", "string")), + partitionColumnNames = Seq("a", "b") + ) + catalog.createTable("db1", table, ignoreIfExists = false) + + catalog.createPartitions("db1", "tbl", Seq(part1, part2), ignoreIfExists = false) + assert(exists(databaseDir, "tbl", "a=1", "b=2")) + assert(exists(databaseDir, "tbl", "a=3", "b=4")) + + catalog.renamePartitions("db1", "tbl", Seq(part1.spec), Seq(part3.spec)) + assert(!exists(databaseDir, "tbl", "a=1", "b=2")) + assert(exists(databaseDir, "tbl", "a=5", "b=6")) + + catalog.dropPartitions("db1", "tbl", Seq(part2.spec, part3.spec), ignoreIfNotExists = false) + assert(!exists(databaseDir, "tbl", "a=3", "b=4")) + assert(!exists(databaseDir, "tbl", "a=5", "b=6")) + + val externalPartition = CatalogTablePartition( + Map("a" -> "7", "b" -> "8"), + CatalogStorageFormat( + Some(Utils.createTempDir().getAbsolutePath), + None, None, None, false, Map.empty) + ) + catalog.createPartitions("db1", "tbl", Seq(externalPartition), ignoreIfExists = false) + assert(!exists(databaseDir, "tbl", "a=7", "b=8")) + } } @@ -507,10 +633,17 @@ abstract class CatalogTestUtils { inputFormat = Some(tableInputFormat), outputFormat = Some(tableOutputFormat), serde = None, + compressed = false, serdeProperties = Map.empty) lazy val part1 = CatalogTablePartition(Map("a" -> "1", "b" -> "2"), storageFormat) lazy val part2 = CatalogTablePartition(Map("a" -> "3", "b" -> "4"), storageFormat) lazy val part3 = CatalogTablePartition(Map("a" -> "5", "b" -> "6"), storageFormat) + lazy val partWithMixedOrder = CatalogTablePartition(Map("b" -> "6", "a" -> "6"), storageFormat) + lazy val partWithLessColumns = CatalogTablePartition(Map("a" -> "1"), storageFormat) + lazy val partWithMoreColumns = + CatalogTablePartition(Map("a" -> "5", "b" -> "6", "c" -> "7"), storageFormat) + lazy val partWithUnknownColumns = + CatalogTablePartition(Map("a" -> "5", "unknown" -> "6"), storageFormat) lazy val funcClass = "org.apache.spark.myFunc" /** @@ -540,7 +673,7 @@ abstract class CatalogTestUtils { def newFunc(): CatalogFunction = newFunc("funcName") - def newUriForDatabase(): String = Utils.createTempDir().getAbsolutePath + def newUriForDatabase(): String = Utils.createTempDir().toURI.toString.stripSuffix("/") def newDb(name: String): CatalogDatabase = { CatalogDatabase(name, name + " description", newUriForDatabase(), Map.empty) @@ -563,7 +696,7 @@ abstract class CatalogTestUtils { } def newFunc(name: String, database: Option[String] = None): CatalogFunction = { - CatalogFunction(FunctionIdentifier(name, database), funcClass, Seq.empty[(String, String)]) + CatalogFunction(FunctionIdentifier(name, database), funcClass, Seq.empty[FunctionResource]) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index ba5d8ce0f48f8..399b7067f4a09 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.catalog import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException import 
org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, Literal} import org.apache.spark.sql.catalyst.plans.logical.{Range, SubqueryAlias} @@ -69,7 +70,7 @@ class SessionCatalogSuite extends SparkFunSuite { test("get database should throw exception when the database does not exist") { val catalog = new SessionCatalog(newBasicCatalog()) - intercept[AnalysisException] { + intercept[NoSuchDatabaseException] { catalog.getDatabaseMetadata("db_that_does_not_exist") } } @@ -120,7 +121,7 @@ class SessionCatalogSuite extends SparkFunSuite { test("drop database when the database does not exist") { val catalog = new SessionCatalog(newBasicCatalog()) - intercept[AnalysisException] { + intercept[NoSuchDatabaseException] { catalog.dropDatabase("db_that_does_not_exist", ignoreIfNotExists = false, cascade = false) } catalog.dropDatabase("db_that_does_not_exist", ignoreIfNotExists = true, cascade = false) @@ -140,8 +141,8 @@ class SessionCatalogSuite extends SparkFunSuite { test("alter database should throw exception when the database does not exist") { val catalog = new SessionCatalog(newBasicCatalog()) - intercept[AnalysisException] { - catalog.alterDatabase(newDb("does_not_exist")) + intercept[NoSuchDatabaseException] { + catalog.alterDatabase(newDb("unknown_db")) } } @@ -150,7 +151,7 @@ class SessionCatalogSuite extends SparkFunSuite { assert(catalog.getCurrentDatabase == "default") catalog.setCurrentDatabase("db2") assert(catalog.getCurrentDatabase == "db2") - intercept[AnalysisException] { + intercept[NoSuchDatabaseException] { catalog.setCurrentDatabase("deebo") } catalog.createDatabase(newDb("deebo"), ignoreIfExists = false) @@ -181,14 +182,14 @@ class SessionCatalogSuite extends SparkFunSuite { test("create table when database does not exist") { val catalog = new SessionCatalog(newBasicCatalog()) // Creating table in non-existent database should always fail - intercept[AnalysisException] { + intercept[NoSuchDatabaseException] { catalog.createTable(newTable("tbl1", "does_not_exist"), ignoreIfExists = false) } - intercept[AnalysisException] { + intercept[NoSuchDatabaseException] { catalog.createTable(newTable("tbl1", "does_not_exist"), ignoreIfExists = true) } // Table already exists - intercept[AnalysisException] { + intercept[TableAlreadyExistsException] { catalog.createTable(newTable("tbl1", "db2"), ignoreIfExists = false) } catalog.createTable(newTable("tbl1", "db2"), ignoreIfExists = true) @@ -196,20 +197,20 @@ class SessionCatalogSuite extends SparkFunSuite { test("create temp table") { val catalog = new SessionCatalog(newBasicCatalog()) - val tempTable1 = Range(1, 10, 1, 10, Seq()) - val tempTable2 = Range(1, 20, 2, 10, Seq()) - catalog.createTempTable("tbl1", tempTable1, overrideIfExists = false) - catalog.createTempTable("tbl2", tempTable2, overrideIfExists = false) - assert(catalog.getTempTable("tbl1") == Some(tempTable1)) - assert(catalog.getTempTable("tbl2") == Some(tempTable2)) - assert(catalog.getTempTable("tbl3") == None) + val tempTable1 = Range(1, 10, 1, 10) + val tempTable2 = Range(1, 20, 2, 10) + catalog.createTempView("tbl1", tempTable1, overrideIfExists = false) + catalog.createTempView("tbl2", tempTable2, overrideIfExists = false) + assert(catalog.getTempView("tbl1") == Option(tempTable1)) + assert(catalog.getTempView("tbl2") == Option(tempTable2)) + assert(catalog.getTempView("tbl3").isEmpty) // Temporary table already exists - 
intercept[AnalysisException] { - catalog.createTempTable("tbl1", tempTable1, overrideIfExists = false) + intercept[TempTableAlreadyExistsException] { + catalog.createTempView("tbl1", tempTable1, overrideIfExists = false) } // Temporary table already exists but we override it - catalog.createTempTable("tbl1", tempTable2, overrideIfExists = true) - assert(catalog.getTempTable("tbl1") == Some(tempTable2)) + catalog.createTempView("tbl1", tempTable2, overrideIfExists = true) + assert(catalog.getTempView("tbl1") == Option(tempTable2)) } test("drop table") { @@ -227,38 +228,38 @@ class SessionCatalogSuite extends SparkFunSuite { test("drop table when database/table does not exist") { val catalog = new SessionCatalog(newBasicCatalog()) // Should always throw exception when the database does not exist - intercept[AnalysisException] { + intercept[NoSuchDatabaseException] { catalog.dropTable(TableIdentifier("tbl1", Some("unknown_db")), ignoreIfNotExists = false) } - intercept[AnalysisException] { + intercept[NoSuchDatabaseException] { catalog.dropTable(TableIdentifier("tbl1", Some("unknown_db")), ignoreIfNotExists = true) } - // If the table does not exist, we do not issue an exception. Instead, we output an error log - // message to console when ignoreIfNotExists is set to false. - catalog.dropTable(TableIdentifier("unknown_table", Some("db2")), ignoreIfNotExists = false) + intercept[NoSuchTableException] { + catalog.dropTable(TableIdentifier("unknown_table", Some("db2")), ignoreIfNotExists = false) + } catalog.dropTable(TableIdentifier("unknown_table", Some("db2")), ignoreIfNotExists = true) } test("drop temp table") { val externalCatalog = newBasicCatalog() val sessionCatalog = new SessionCatalog(externalCatalog) - val tempTable = Range(1, 10, 2, 10, Seq()) - sessionCatalog.createTempTable("tbl1", tempTable, overrideIfExists = false) + val tempTable = Range(1, 10, 2, 10) + sessionCatalog.createTempView("tbl1", tempTable, overrideIfExists = false) sessionCatalog.setCurrentDatabase("db2") - assert(sessionCatalog.getTempTable("tbl1") == Some(tempTable)) + assert(sessionCatalog.getTempView("tbl1") == Some(tempTable)) assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) // If database is not specified, temp table should be dropped first sessionCatalog.dropTable(TableIdentifier("tbl1"), ignoreIfNotExists = false) - assert(sessionCatalog.getTempTable("tbl1") == None) + assert(sessionCatalog.getTempView("tbl1") == None) assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) // If temp table does not exist, the table in the current database should be dropped sessionCatalog.dropTable(TableIdentifier("tbl1"), ignoreIfNotExists = false) assert(externalCatalog.listTables("db2").toSet == Set("tbl2")) // If database is specified, temp tables are never dropped - sessionCatalog.createTempTable("tbl1", tempTable, overrideIfExists = false) + sessionCatalog.createTempView("tbl1", tempTable, overrideIfExists = false) sessionCatalog.createTable(newTable("tbl1", "db2"), ignoreIfExists = false) sessionCatalog.dropTable(TableIdentifier("tbl1", Some("db2")), ignoreIfNotExists = false) - assert(sessionCatalog.getTempTable("tbl1") == Some(tempTable)) + assert(sessionCatalog.getTempView("tbl1") == Some(tempTable)) assert(externalCatalog.listTables("db2").toSet == Set("tbl2")) } @@ -281,15 +282,20 @@ class SessionCatalogSuite extends SparkFunSuite { sessionCatalog.renameTable( TableIdentifier("tblone", Some("db2")), TableIdentifier("tblones", Some("db1"))) } + // The new table already exists 
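(Editorial aside, not part of the patch: the hunks above also rename the temporary-table API, with createTempTable/getTempTable becoming createTempView/getTempView. A minimal sketch of the renamed calls follows; it assumes the no-argument InMemoryCatalog constructor and reuses the four-argument catalyst Range plan that the suite itself registers. The intercept right after the sketch covers the other new behavior here: renaming onto an existing table now raises TableAlreadyExistsException.)

```scala
// Sketch only: exercising the renamed temp-view API on an in-memory catalog.
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.plans.logical.Range

val catalog = new SessionCatalog(new InMemoryCatalog)
val view = Range(1, 10, 1, 10)                               // a tiny logical plan to register
catalog.createTempView("v", view, overrideIfExists = false)  // was createTempTable
assert(catalog.getTempView("v") == Option(view))             // was getTempTable
catalog.createTempView("v", Range(1, 20, 2, 10), overrideIfExists = true)  // silently replaces
assert(catalog.getTempView("v") != Option(view))
```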
+ intercept[TableAlreadyExistsException] { + sessionCatalog.renameTable( + TableIdentifier("tblone", Some("db2")), TableIdentifier("table_two", Some("db2"))) + } } test("rename table when database/table does not exist") { val catalog = new SessionCatalog(newBasicCatalog()) - intercept[AnalysisException] { + intercept[NoSuchDatabaseException] { catalog.renameTable( TableIdentifier("tbl1", Some("unknown_db")), TableIdentifier("tbl2", Some("unknown_db"))) } - intercept[AnalysisException] { + intercept[NoSuchTableException] { catalog.renameTable( TableIdentifier("unknown_table", Some("db2")), TableIdentifier("tbl2", Some("db2"))) } @@ -298,21 +304,21 @@ class SessionCatalogSuite extends SparkFunSuite { test("rename temp table") { val externalCatalog = newBasicCatalog() val sessionCatalog = new SessionCatalog(externalCatalog) - val tempTable = Range(1, 10, 2, 10, Seq()) - sessionCatalog.createTempTable("tbl1", tempTable, overrideIfExists = false) + val tempTable = Range(1, 10, 2, 10) + sessionCatalog.createTempView("tbl1", tempTable, overrideIfExists = false) sessionCatalog.setCurrentDatabase("db2") - assert(sessionCatalog.getTempTable("tbl1") == Some(tempTable)) + assert(sessionCatalog.getTempView("tbl1") == Option(tempTable)) assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) // If database is not specified, temp table should be renamed first sessionCatalog.renameTable(TableIdentifier("tbl1"), TableIdentifier("tbl3")) - assert(sessionCatalog.getTempTable("tbl1") == None) - assert(sessionCatalog.getTempTable("tbl3") == Some(tempTable)) + assert(sessionCatalog.getTempView("tbl1").isEmpty) + assert(sessionCatalog.getTempView("tbl3") == Option(tempTable)) assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) // If database is specified, temp tables are never renamed sessionCatalog.renameTable( TableIdentifier("tbl2", Some("db2")), TableIdentifier("tbl4", Some("db2"))) - assert(sessionCatalog.getTempTable("tbl3") == Some(tempTable)) - assert(sessionCatalog.getTempTable("tbl4") == None) + assert(sessionCatalog.getTempView("tbl3") == Option(tempTable)) + assert(sessionCatalog.getTempView("tbl4").isEmpty) assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl4")) } @@ -334,10 +340,10 @@ class SessionCatalogSuite extends SparkFunSuite { test("alter table when database/table does not exist") { val catalog = new SessionCatalog(newBasicCatalog()) - intercept[AnalysisException] { + intercept[NoSuchDatabaseException] { catalog.alterTable(newTable("tbl1", "unknown_db")) } - intercept[AnalysisException] { + intercept[NoSuchTableException] { catalog.alterTable(newTable("unknown_table", "db2")) } } @@ -355,20 +361,31 @@ class SessionCatalogSuite extends SparkFunSuite { test("get table when database/table does not exist") { val catalog = new SessionCatalog(newBasicCatalog()) - intercept[AnalysisException] { + intercept[NoSuchDatabaseException] { catalog.getTableMetadata(TableIdentifier("tbl1", Some("unknown_db"))) } - intercept[AnalysisException] { + intercept[NoSuchTableException] { catalog.getTableMetadata(TableIdentifier("unknown_table", Some("db2"))) } } + test("get option of table metadata") { + val externalCatalog = newBasicCatalog() + val catalog = new SessionCatalog(externalCatalog) + assert(catalog.getTableMetadataOption(TableIdentifier("tbl1", Some("db2"))) + == Option(externalCatalog.getTable("db2", "tbl1"))) + assert(catalog.getTableMetadataOption(TableIdentifier("unknown_table", Some("db2"))).isEmpty) + intercept[NoSuchDatabaseException] { + 
catalog.getTableMetadataOption(TableIdentifier("tbl1", Some("unknown_db"))) + } + } + test("lookup table relation") { val externalCatalog = newBasicCatalog() val sessionCatalog = new SessionCatalog(externalCatalog) - val tempTable1 = Range(1, 10, 1, 10, Seq()) + val tempTable1 = Range(1, 10, 1, 10) val metastoreTable1 = externalCatalog.getTable("db2", "tbl1") - sessionCatalog.createTempTable("tbl1", tempTable1, overrideIfExists = false) + sessionCatalog.createTempView("tbl1", tempTable1, overrideIfExists = false) sessionCatalog.setCurrentDatabase("db2") // If we explicitly specify the database, we'll look up the relation in that database assert(sessionCatalog.lookupRelation(TableIdentifier("tbl1", Some("db2"))) @@ -405,21 +422,45 @@ class SessionCatalogSuite extends SparkFunSuite { assert(!catalog.tableExists(TableIdentifier("tbl1", Some("db1")))) assert(!catalog.tableExists(TableIdentifier("tbl2", Some("db1")))) // If database is explicitly specified, do not check temporary tables - val tempTable = Range(1, 10, 1, 10, Seq()) - catalog.createTempTable("tbl3", tempTable, overrideIfExists = false) + val tempTable = Range(1, 10, 1, 10) assert(!catalog.tableExists(TableIdentifier("tbl3", Some("db2")))) // If database is not explicitly specified, check the current database catalog.setCurrentDatabase("db2") assert(catalog.tableExists(TableIdentifier("tbl1"))) assert(catalog.tableExists(TableIdentifier("tbl2"))) - assert(catalog.tableExists(TableIdentifier("tbl3"))) + + catalog.createTempView("tbl3", tempTable, overrideIfExists = false) + // tableExists should not check temp view. + assert(!catalog.tableExists(TableIdentifier("tbl3"))) + } + + test("getTempViewOrPermanentTableMetadata on temporary views") { + val catalog = new SessionCatalog(newBasicCatalog()) + val tempTable = Range(1, 10, 2, 10) + intercept[NoSuchTableException] { + catalog.getTempViewOrPermanentTableMetadata(TableIdentifier("view1")) + }.getMessage + + intercept[NoSuchTableException] { + catalog.getTempViewOrPermanentTableMetadata(TableIdentifier("view1", Some("default"))) + }.getMessage + + catalog.createTempView("view1", tempTable, overrideIfExists = false) + assert(catalog.getTempViewOrPermanentTableMetadata( + TableIdentifier("view1")).identifier.table == "view1") + assert(catalog.getTempViewOrPermanentTableMetadata( + TableIdentifier("view1")).schema(0).name == "id") + + intercept[NoSuchTableException] { + catalog.getTempViewOrPermanentTableMetadata(TableIdentifier("view1", Some("default"))) + }.getMessage } test("list tables without pattern") { val catalog = new SessionCatalog(newBasicCatalog()) - val tempTable = Range(1, 10, 2, 10, Seq()) - catalog.createTempTable("tbl1", tempTable, overrideIfExists = false) - catalog.createTempTable("tbl4", tempTable, overrideIfExists = false) + val tempTable = Range(1, 10, 2, 10) + catalog.createTempView("tbl1", tempTable, overrideIfExists = false) + catalog.createTempView("tbl4", tempTable, overrideIfExists = false) assert(catalog.listTables("db1").toSet == Set(TableIdentifier("tbl1"), TableIdentifier("tbl4"))) assert(catalog.listTables("db2").toSet == @@ -427,16 +468,16 @@ class SessionCatalogSuite extends SparkFunSuite { TableIdentifier("tbl4"), TableIdentifier("tbl1", Some("db2")), TableIdentifier("tbl2", Some("db2")))) - intercept[AnalysisException] { + intercept[NoSuchDatabaseException] { catalog.listTables("unknown_db") } } test("list tables with pattern") { val catalog = new SessionCatalog(newBasicCatalog()) - val tempTable = Range(1, 10, 2, 10, Seq()) - 
catalog.createTempTable("tbl1", tempTable, overrideIfExists = false) - catalog.createTempTable("tbl4", tempTable, overrideIfExists = false) + val tempTable = Range(1, 10, 2, 10) + catalog.createTempView("tbl1", tempTable, overrideIfExists = false) + catalog.createTempView("tbl4", tempTable, overrideIfExists = false) assert(catalog.listTables("db1", "*").toSet == catalog.listTables("db1").toSet) assert(catalog.listTables("db2", "*").toSet == catalog.listTables("db2").toSet) assert(catalog.listTables("db2", "tbl*").toSet == @@ -446,7 +487,7 @@ class SessionCatalogSuite extends SparkFunSuite { TableIdentifier("tbl2", Some("db2")))) assert(catalog.listTables("db2", "*1").toSet == Set(TableIdentifier("tbl1"), TableIdentifier("tbl1", Some("db2")))) - intercept[AnalysisException] { + intercept[NoSuchDatabaseException] { catalog.listTables("unknown_db", "*") } } @@ -465,17 +506,19 @@ class SessionCatalogSuite extends SparkFunSuite { assert(catalogPartitionsEqual(externalCatalog, "mydb", "tbl", Seq(part1, part2))) // Create partitions without explicitly specifying database sessionCatalog.setCurrentDatabase("mydb") - sessionCatalog.createPartitions(TableIdentifier("tbl"), Seq(part3), ignoreIfExists = false) - assert(catalogPartitionsEqual(externalCatalog, "mydb", "tbl", Seq(part1, part2, part3))) + sessionCatalog.createPartitions( + TableIdentifier("tbl"), Seq(partWithMixedOrder), ignoreIfExists = false) + assert(catalogPartitionsEqual( + externalCatalog, "mydb", "tbl", Seq(part1, part2, partWithMixedOrder))) } test("create partitions when database/table does not exist") { val catalog = new SessionCatalog(newBasicCatalog()) - intercept[AnalysisException] { + intercept[NoSuchDatabaseException] { catalog.createPartitions( - TableIdentifier("tbl1", Some("does_not_exist")), Seq(), ignoreIfExists = false) + TableIdentifier("tbl1", Some("unknown_db")), Seq(), ignoreIfExists = false) } - intercept[AnalysisException] { + intercept[NoSuchTableException] { catalog.createPartitions( TableIdentifier("does_not_exist", Some("db2")), Seq(), ignoreIfExists = false) } @@ -491,6 +534,31 @@ class SessionCatalogSuite extends SparkFunSuite { TableIdentifier("tbl2", Some("db2")), Seq(part1), ignoreIfExists = true) } + test("create partitions with invalid part spec") { + val catalog = new SessionCatalog(newBasicCatalog()) + var e = intercept[AnalysisException] { + catalog.createPartitions( + TableIdentifier("tbl2", Some("db2")), + Seq(part1, partWithLessColumns), ignoreIfExists = false) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec (a) must match " + + "the partition spec (a, b) defined in table '`db2`.`tbl2`'")) + e = intercept[AnalysisException] { + catalog.createPartitions( + TableIdentifier("tbl2", Some("db2")), + Seq(part1, partWithMoreColumns), ignoreIfExists = true) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec (a, b, c) must match " + + "the partition spec (a, b) defined in table '`db2`.`tbl2`'")) + e = intercept[AnalysisException] { + catalog.createPartitions( + TableIdentifier("tbl2", Some("db2")), + Seq(partWithUnknownColumns, part1), ignoreIfExists = true) + } + assert(e.getMessage.contains("Partition spec is invalid. 
The spec (a, unknown) must match " + + "the partition spec (a, b) defined in table '`db2`.`tbl2`'")) + } + test("drop partitions") { val externalCatalog = newBasicCatalog() val sessionCatalog = new SessionCatalog(externalCatalog) @@ -520,13 +588,13 @@ class SessionCatalogSuite extends SparkFunSuite { test("drop partitions when database/table does not exist") { val catalog = new SessionCatalog(newBasicCatalog()) - intercept[AnalysisException] { + intercept[NoSuchDatabaseException] { catalog.dropPartitions( - TableIdentifier("tbl1", Some("does_not_exist")), + TableIdentifier("tbl1", Some("unknown_db")), Seq(), ignoreIfNotExists = false) } - intercept[AnalysisException] { + intercept[NoSuchTableException] { catalog.dropPartitions( TableIdentifier("does_not_exist", Some("db2")), Seq(), @@ -548,6 +616,28 @@ class SessionCatalogSuite extends SparkFunSuite { ignoreIfNotExists = true) } + test("drop partitions with invalid partition spec") { + val catalog = new SessionCatalog(newBasicCatalog()) + var e = intercept[AnalysisException] { + catalog.dropPartitions( + TableIdentifier("tbl2", Some("db2")), + Seq(partWithMoreColumns.spec), + ignoreIfNotExists = false) + } + assert(e.getMessage.contains( + "Partition spec is invalid. The spec (a, b, c) must be contained within " + + "the partition spec (a, b) defined in table '`db2`.`tbl2`'")) + e = intercept[AnalysisException] { + catalog.dropPartitions( + TableIdentifier("tbl2", Some("db2")), + Seq(partWithUnknownColumns.spec), + ignoreIfNotExists = false) + } + assert(e.getMessage.contains( + "Partition spec is invalid. The spec (a, unknown) must be contained within " + + "the partition spec (a, b) defined in table '`db2`.`tbl2`'")) + } + test("get partition") { val catalog = new SessionCatalog(newBasicCatalog()) assert(catalog.getPartition( @@ -566,14 +656,33 @@ class SessionCatalogSuite extends SparkFunSuite { test("get partition when database/table does not exist") { val catalog = new SessionCatalog(newBasicCatalog()) - intercept[AnalysisException] { - catalog.getPartition(TableIdentifier("tbl1", Some("does_not_exist")), part1.spec) + intercept[NoSuchDatabaseException] { + catalog.getPartition(TableIdentifier("tbl1", Some("unknown_db")), part1.spec) } - intercept[AnalysisException] { + intercept[NoSuchTableException] { catalog.getPartition(TableIdentifier("does_not_exist", Some("db2")), part1.spec) } } + test("get partition with invalid partition spec") { + val catalog = new SessionCatalog(newBasicCatalog()) + var e = intercept[AnalysisException] { + catalog.getPartition(TableIdentifier("tbl1", Some("db2")), partWithLessColumns.spec) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec (a) must match " + + "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) + e = intercept[AnalysisException] { + catalog.getPartition(TableIdentifier("tbl1", Some("db2")), partWithMoreColumns.spec) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec (a, b, c) must match " + + "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) + e = intercept[AnalysisException] { + catalog.getPartition(TableIdentifier("tbl1", Some("db2")), partWithUnknownColumns.spec) + } + assert(e.getMessage.contains("Partition spec is invalid. 
The spec (a, unknown) must match " + + "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) + } + test("rename partitions") { val catalog = new SessionCatalog(newBasicCatalog()) val newPart1 = part1.copy(spec = Map("a" -> "100", "b" -> "101")) @@ -606,16 +715,41 @@ class SessionCatalogSuite extends SparkFunSuite { test("rename partitions when database/table does not exist") { val catalog = new SessionCatalog(newBasicCatalog()) - intercept[AnalysisException] { + intercept[NoSuchDatabaseException] { catalog.renamePartitions( - TableIdentifier("tbl1", Some("does_not_exist")), Seq(part1.spec), Seq(part2.spec)) + TableIdentifier("tbl1", Some("unknown_db")), Seq(part1.spec), Seq(part2.spec)) } - intercept[AnalysisException] { + intercept[NoSuchTableException] { catalog.renamePartitions( TableIdentifier("does_not_exist", Some("db2")), Seq(part1.spec), Seq(part2.spec)) } } + test("rename partition with invalid partition spec") { + val catalog = new SessionCatalog(newBasicCatalog()) + var e = intercept[AnalysisException] { + catalog.renamePartitions( + TableIdentifier("tbl1", Some("db2")), + Seq(part1.spec), Seq(partWithLessColumns.spec)) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec (a) must match " + + "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) + e = intercept[AnalysisException] { + catalog.renamePartitions( + TableIdentifier("tbl1", Some("db2")), + Seq(part1.spec), Seq(partWithMoreColumns.spec)) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec (a, b, c) must match " + + "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) + e = intercept[AnalysisException] { + catalog.renamePartitions( + TableIdentifier("tbl1", Some("db2")), + Seq(part1.spec), Seq(partWithUnknownColumns.spec)) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec (a, unknown) must match " + + "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) + } + test("alter partitions") { val catalog = new SessionCatalog(newBasicCatalog()) val newLocation = newUriForDatabase() @@ -648,14 +782,33 @@ class SessionCatalogSuite extends SparkFunSuite { test("alter partitions when database/table does not exist") { val catalog = new SessionCatalog(newBasicCatalog()) - intercept[AnalysisException] { - catalog.alterPartitions(TableIdentifier("tbl1", Some("does_not_exist")), Seq(part1)) + intercept[NoSuchDatabaseException] { + catalog.alterPartitions(TableIdentifier("tbl1", Some("unknown_db")), Seq(part1)) } - intercept[AnalysisException] { + intercept[NoSuchTableException] { catalog.alterPartitions(TableIdentifier("does_not_exist", Some("db2")), Seq(part1)) } } + test("alter partition with invalid partition spec") { + val catalog = new SessionCatalog(newBasicCatalog()) + var e = intercept[AnalysisException] { + catalog.alterPartitions(TableIdentifier("tbl1", Some("db2")), Seq(partWithLessColumns)) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec (a) must match " + + "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) + e = intercept[AnalysisException] { + catalog.alterPartitions(TableIdentifier("tbl1", Some("db2")), Seq(partWithMoreColumns)) + } + assert(e.getMessage.contains("Partition spec is invalid. The spec (a, b, c) must match " + + "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) + e = intercept[AnalysisException] { + catalog.alterPartitions(TableIdentifier("tbl1", Some("db2")), Seq(partWithUnknownColumns)) + } + assert(e.getMessage.contains("Partition spec is invalid. 
The spec (a, unknown) must match " + + "the partition spec (a, b) defined in table '`db2`.`tbl1`'")) + } + test("list partitions") { val catalog = new SessionCatalog(newBasicCatalog()) assert(catalog.listPartitions(TableIdentifier("tbl2", Some("db2"))).toSet == Set(part1, part2)) @@ -664,6 +817,16 @@ class SessionCatalogSuite extends SparkFunSuite { assert(catalog.listPartitions(TableIdentifier("tbl2")).toSet == Set(part1, part2)) } + test("list partitions when database/table does not exist") { + val catalog = new SessionCatalog(newBasicCatalog()) + intercept[NoSuchDatabaseException] { + catalog.listPartitions(TableIdentifier("tbl1", Some("unknown_db"))) + } + intercept[NoSuchTableException] { + catalog.listPartitions(TableIdentifier("does_not_exist", Some("db2"))) + } + } + // -------------------------------------------------------------------------- // Functions // -------------------------------------------------------------------------- @@ -682,7 +845,7 @@ class SessionCatalogSuite extends SparkFunSuite { test("create function when database does not exist") { val catalog = new SessionCatalog(newBasicCatalog()) - intercept[AnalysisException] { + intercept[NoSuchDatabaseException] { catalog.createFunction( newFunc("func5", Some("does_not_exist")), ignoreIfExists = false) } @@ -690,7 +853,7 @@ class SessionCatalogSuite extends SparkFunSuite { test("create function that already exists") { val catalog = new SessionCatalog(newBasicCatalog()) - intercept[AnalysisException] { + intercept[FunctionAlreadyExistsException] { catalog.createFunction(newFunc("func1", Some("db2")), ignoreIfExists = false) } catalog.createFunction(newFunc("func1", Some("db2")), ignoreIfExists = true) @@ -708,13 +871,13 @@ class SessionCatalogSuite extends SparkFunSuite { assert(catalog.lookupFunction(FunctionIdentifier("temp1"), arguments) === Literal(1)) assert(catalog.lookupFunction(FunctionIdentifier("temp2"), arguments) === Literal(3)) // Temporary function does not exist. 
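(Editorial aside, not part of the patch: across this suite the generic intercept[AnalysisException] assertions are tightened to the concrete types brought in by the new org.apache.spark.sql.catalyst.analysis._ import, such as NoSuchDatabaseException, NoSuchTableException, NoSuchFunctionException and the *AlreadyExistsException variants. A minimal sketch of a caller taking advantage of that is below; describeTable is a hypothetical helper, and the sketch assumes the NoSuch* types subclass AnalysisException, which is why the specific cases are listed before the generic fallback. The intercepts that follow continue with the temporary-function lookups.)

```scala
// Sketch only: a hypothetical caller branching on the specific lookup failures
// that the updated tests now intercept, instead of parsing error messages.
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{NoSuchDatabaseException, NoSuchTableException}
import org.apache.spark.sql.catalyst.catalog.SessionCatalog

def describeTable(catalog: SessionCatalog, table: TableIdentifier): String =
  try {
    "found: " + catalog.getTableMetadata(table).identifier.table
  } catch {
    case _: NoSuchDatabaseException => "unknown database: " + table.database.getOrElse("")
    case _: NoSuchTableException    => "unknown table: " + table.table
    case e: AnalysisException       => "other analysis error: " + e.message
  }
```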
- intercept[AnalysisException] { + intercept[NoSuchFunctionException] { catalog.lookupFunction(FunctionIdentifier("temp3"), arguments) } val tempFunc3 = (e: Seq[Expression]) => Literal(e.size) val info3 = new ExpressionInfo("tempFunc3", "temp1") // Temporary function already exists - intercept[AnalysisException] { + intercept[TempFunctionAlreadyExistsException] { catalog.createTempFunction("temp1", info3, tempFunc3, ignoreIfExists = false) } // Temporary function is overridden @@ -740,11 +903,11 @@ class SessionCatalogSuite extends SparkFunSuite { test("drop function when database/function does not exist") { val catalog = new SessionCatalog(newBasicCatalog()) - intercept[AnalysisException] { + intercept[NoSuchDatabaseException] { catalog.dropFunction( - FunctionIdentifier("something", Some("does_not_exist")), ignoreIfNotExists = false) + FunctionIdentifier("something", Some("unknown_db")), ignoreIfNotExists = false) } - intercept[AnalysisException] { + intercept[NoSuchFunctionException] { catalog.dropFunction(FunctionIdentifier("does_not_exist"), ignoreIfNotExists = false) } catalog.dropFunction(FunctionIdentifier("does_not_exist"), ignoreIfNotExists = true) @@ -758,10 +921,10 @@ class SessionCatalogSuite extends SparkFunSuite { val arguments = Seq(Literal(1), Literal(2), Literal(3)) assert(catalog.lookupFunction(FunctionIdentifier("func1"), arguments) === Literal(1)) catalog.dropTempFunction("func1", ignoreIfNotExists = false) - intercept[AnalysisException] { + intercept[NoSuchFunctionException] { catalog.lookupFunction(FunctionIdentifier("func1"), arguments) } - intercept[AnalysisException] { + intercept[NoSuchTempFunctionException] { catalog.dropTempFunction("func1", ignoreIfNotExists = false) } catalog.dropTempFunction("func1", ignoreIfNotExists = true) @@ -771,7 +934,7 @@ class SessionCatalogSuite extends SparkFunSuite { val catalog = new SessionCatalog(newBasicCatalog()) val expected = CatalogFunction(FunctionIdentifier("func1", Some("db2")), funcClass, - Seq.empty[(String, String)]) + Seq.empty[FunctionResource]) assert(catalog.getFunctionMetadata(FunctionIdentifier("func1", Some("db2"))) == expected) // Get function without explicitly specifying database catalog.setCurrentDatabase("db2") @@ -780,10 +943,10 @@ class SessionCatalogSuite extends SparkFunSuite { test("get function when database/function does not exist") { val catalog = new SessionCatalog(newBasicCatalog()) - intercept[AnalysisException] { - catalog.getFunctionMetadata(FunctionIdentifier("func1", Some("does_not_exist"))) + intercept[NoSuchDatabaseException] { + catalog.getFunctionMetadata(FunctionIdentifier("func1", Some("unknown_db"))) } - intercept[AnalysisException] { + intercept[NoSuchFunctionException] { catalog.getFunctionMetadata(FunctionIdentifier("does_not_exist", Some("db2"))) } } @@ -796,7 +959,7 @@ class SessionCatalogSuite extends SparkFunSuite { assert(catalog.lookupFunction( FunctionIdentifier("func1"), Seq(Literal(1), Literal(2), Literal(3))) == Literal(1)) catalog.dropTempFunction("func1", ignoreIfNotExists = false) - intercept[AnalysisException] { + intercept[NoSuchFunctionException] { catalog.lookupFunction(FunctionIdentifier("func1"), Seq(Literal(1), Literal(2), Literal(3))) } } @@ -811,19 +974,26 @@ class SessionCatalogSuite extends SparkFunSuite { catalog.createFunction(newFunc("not_me", Some("db2")), ignoreIfExists = false) catalog.createTempFunction("func1", info1, tempFunc1, ignoreIfExists = false) catalog.createTempFunction("yes_me", info2, tempFunc2, ignoreIfExists = false) - 
assert(catalog.listFunctions("db1", "*").toSet == + assert(catalog.listFunctions("db1", "*").map(_._1).toSet == Set(FunctionIdentifier("func1"), FunctionIdentifier("yes_me"))) - assert(catalog.listFunctions("db2", "*").toSet == + assert(catalog.listFunctions("db2", "*").map(_._1).toSet == Set(FunctionIdentifier("func1"), FunctionIdentifier("yes_me"), FunctionIdentifier("func1", Some("db2")), FunctionIdentifier("func2", Some("db2")), FunctionIdentifier("not_me", Some("db2")))) - assert(catalog.listFunctions("db2", "func*").toSet == + assert(catalog.listFunctions("db2", "func*").map(_._1).toSet == Set(FunctionIdentifier("func1"), FunctionIdentifier("func1", Some("db2")), FunctionIdentifier("func2", Some("db2")))) } + test("list functions when database does not exist") { + val catalog = new SessionCatalog(newBasicCatalog()) + intercept[NoSuchDatabaseException] { + catalog.listFunctions("unknown_db", "func*") + } + } + } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala index 3ad0dae767be3..802397d50e85c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala @@ -41,17 +41,17 @@ class EncoderResolutionSuite extends PlanTest { // int type can be up cast to long type val attrs1 = Seq('a.string, 'b.int) - encoder.resolve(attrs1, null).bind(attrs1).fromRow(InternalRow(str, 1)) + encoder.resolveAndBind(attrs1).fromRow(InternalRow(str, 1)) // int type can be up cast to string type val attrs2 = Seq('a.int, 'b.long) - encoder.resolve(attrs2, null).bind(attrs2).fromRow(InternalRow(1, 2L)) + encoder.resolveAndBind(attrs2).fromRow(InternalRow(1, 2L)) } test("real type doesn't match encoder schema but they are compatible: nested product") { val encoder = ExpressionEncoder[ComplexClass] val attrs = Seq('a.int, 'b.struct('a.int, 'b.long)) - encoder.resolve(attrs, null).bind(attrs).fromRow(InternalRow(1, InternalRow(2, 3L))) + encoder.resolveAndBind(attrs).fromRow(InternalRow(1, InternalRow(2, 3L))) } test("real type doesn't match encoder schema but they are compatible: tupled encoder") { @@ -59,7 +59,7 @@ class EncoderResolutionSuite extends PlanTest { ExpressionEncoder[StringLongClass], ExpressionEncoder[Long]) val attrs = Seq('a.struct('a.string, 'b.byte), 'b.int) - encoder.resolve(attrs, null).bind(attrs).fromRow(InternalRow(InternalRow(str, 1.toByte), 2)) + encoder.resolveAndBind(attrs).fromRow(InternalRow(InternalRow(str, 1.toByte), 2)) } test("nullability of array type element should not fail analysis") { @@ -67,7 +67,7 @@ class EncoderResolutionSuite extends PlanTest { val attrs = 'a.array(IntegerType) :: Nil // It should pass analysis - val bound = encoder.resolve(attrs, null).bind(attrs) + val bound = encoder.resolveAndBind(attrs) // If no null values appear, it should works fine bound.fromRow(InternalRow(new GenericArrayData(Array(1, 2)))) @@ -84,20 +84,16 @@ class EncoderResolutionSuite extends PlanTest { { val attrs = Seq('a.string, 'b.long, 'c.int) - assert(intercept[AnalysisException](encoder.validate(attrs)).message == + assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message == "Try to map struct to Tuple2, " + - "but failed as the number of fields does not line up.\n" + - " - Input schema: struct\n" + - " - Target schema: struct<_1:string,_2:bigint>") + "but failed 
as the number of fields does not line up.") } { val attrs = Seq('a.string) - assert(intercept[AnalysisException](encoder.validate(attrs)).message == + assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message == "Try to map struct to Tuple2, " + - "but failed as the number of fields does not line up.\n" + - " - Input schema: struct\n" + - " - Target schema: struct<_1:string,_2:bigint>") + "but failed as the number of fields does not line up.") } } @@ -106,26 +102,28 @@ class EncoderResolutionSuite extends PlanTest { { val attrs = Seq('a.string, 'b.struct('x.long, 'y.string, 'z.int)) - assert(intercept[AnalysisException](encoder.validate(attrs)).message == + assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message == "Try to map struct to Tuple2, " + - "but failed as the number of fields does not line up.\n" + - " - Input schema: struct>\n" + - " - Target schema: struct<_1:string,_2:struct<_1:bigint,_2:string>>") + "but failed as the number of fields does not line up.") } { val attrs = Seq('a.string, 'b.struct('x.long)) - assert(intercept[AnalysisException](encoder.validate(attrs)).message == + assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message == "Try to map struct to Tuple2, " + - "but failed as the number of fields does not line up.\n" + - " - Input schema: struct>\n" + - " - Target schema: struct<_1:string,_2:struct<_1:bigint,_2:string>>") + "but failed as the number of fields does not line up.") } } + test("nested case class can have different number of fields from the real schema") { + val encoder = ExpressionEncoder[(String, StringIntClass)] + val attrs = Seq('a.string, 'b.struct('a.string, 'b.int, 'c.int)) + encoder.resolveAndBind(attrs) + } + test("throw exception if real type is not compatible with encoder schema") { val msg1 = intercept[AnalysisException] { - ExpressionEncoder[StringIntClass].resolve(Seq('a.string, 'b.long), null) + ExpressionEncoder[StringIntClass].resolveAndBind(Seq('a.string, 'b.long)) }.message assert(msg1 == s""" @@ -138,7 +136,7 @@ class EncoderResolutionSuite extends PlanTest { val msg2 = intercept[AnalysisException] { val structType = new StructType().add("a", StringType).add("b", DecimalType.SYSTEM_DEFAULT) - ExpressionEncoder[ComplexClass].resolve(Seq('a.long, 'b.struct(structType)), null) + ExpressionEncoder[ComplexClass].resolveAndBind(Seq('a.long, 'b.struct(structType))) }.message assert(msg2 == s""" @@ -171,7 +169,7 @@ class EncoderResolutionSuite extends PlanTest { val to = ExpressionEncoder[U] val catalystType = from.schema.head.dataType.simpleString test(s"cast from $catalystType to ${implicitly[TypeTag[U]].tpe} should success") { - to.resolve(from.schema.toAttributes, null) + to.resolveAndBind(from.schema.toAttributes) } } @@ -180,7 +178,7 @@ class EncoderResolutionSuite extends PlanTest { val to = ExpressionEncoder[U] val catalystType = from.schema.head.dataType.simpleString test(s"cast from $catalystType to ${implicitly[TypeTag[U]].tpe} should fail") { - intercept[AnalysisException](to.resolve(from.schema.toAttributes, null)) + intercept[AnalysisException](to.resolveAndBind(from.schema.toAttributes)) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index c3b20e2cc00a2..a1f9259f139ed 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.encoders +import java.math.BigInteger import java.sql.{Date, Timestamp} import java.util.Arrays @@ -26,11 +27,13 @@ import scala.reflect.runtime.universe.TypeTag import org.apache.spark.sql.Encoders import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData} import org.apache.spark.sql.catalyst.analysis.AnalysisTest +import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} import org.apache.spark.sql.catalyst.util.ArrayData -import org.apache.spark.sql.types.{ArrayType, Decimal, ObjectType, StructType} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String case class RepeatedStruct(s: Seq[PrimitiveData]) @@ -85,6 +88,25 @@ class JavaSerializable(val value: Int) extends Serializable { } } +/** For testing UDT for a case class */ +@SQLUserDefinedType(udt = classOf[UDTForCaseClass]) +case class UDTCaseClass(uri: java.net.URI) + +class UDTForCaseClass extends UserDefinedType[UDTCaseClass] { + + override def sqlType: DataType = StringType + + override def serialize(obj: UDTCaseClass): UTF8String = { + UTF8String.fromString(obj.uri.toString) + } + + override def userClass: Class[UDTCaseClass] = classOf[UDTCaseClass] + + override def deserialize(datum: Any): UDTCaseClass = datum match { + case uri: UTF8String => UDTCaseClass(new java.net.URI(uri.toString)) + } +} + class ExpressionEncoderSuite extends PlanTest with AnalysisTest { OuterScopes.addOuterScope(this) @@ -108,13 +130,15 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest { encodeDecodeTest(new java.lang.Double(-3.7), "boxed double") encodeDecodeTest(BigDecimal("32131413.211321313"), "scala decimal") - // encodeDecodeTest(new java.math.BigDecimal("231341.23123"), "java decimal") - + encodeDecodeTest(new java.math.BigDecimal("231341.23123"), "java decimal") + encodeDecodeTest(BigInt("23134123123"), "scala biginteger") + encodeDecodeTest(new BigInteger("23134123123"), "java BigInteger") encodeDecodeTest(Decimal("32131413.211321313"), "catalyst decimal") encodeDecodeTest("hello", "string") encodeDecodeTest(Date.valueOf("2012-12-23"), "date") encodeDecodeTest(Timestamp.valueOf("2016-01-29 10:00:00"), "timestamp") + encodeDecodeTest(Array(Timestamp.valueOf("2016-01-29 10:00:00")), "array of timestamp") encodeDecodeTest(Array[Byte](13, 21, -23), "binary") encodeDecodeTest(Seq(31, -123, 4), "seq of int") @@ -144,6 +168,12 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest { encodeDecodeTest(Tuple1[Seq[Int]](null), "null seq in tuple") encodeDecodeTest(Tuple1[Map[String, String]](null), "null map in tuple") + encodeDecodeTest(List(1, 2), "list of int") + encodeDecodeTest(List("a", null), "list with String and null") + + encodeDecodeTest( + UDTCaseClass(new java.net.URI("http://spark.apache.org/")), "udt with case class") + // Kryo encoders encodeDecodeTest("hello", "kryo string")(encoderFor(Encoders.kryo[String])) encodeDecodeTest(new KryoSerializable(15), "kryo object")( @@ -305,7 +335,7 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest { val encoder = implicitly[ExpressionEncoder[T]] val row = encoder.toRow(input) val schema = encoder.schema.toAttributes - val boundEncoder = encoder.defaultBinding + val 
boundEncoder = encoder.resolveAndBind() val convertedBack = try boundEncoder.fromRow(row) catch { case e: Exception => fail( @@ -321,12 +351,8 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest { } // Test the correct resolution of serialization / deserialization. - val attr = AttributeReference("obj", ObjectType(encoder.clsTag.runtimeClass))() - val inputPlan = LocalRelation(attr) - val plan = - Project(Alias(encoder.deserializer, "obj")() :: Nil, - Project(encoder.namedExpressions, - inputPlan)) + val attr = AttributeReference("obj", encoder.deserializer.dataType)() + val plan = LocalRelation(attr).serialize[T].deserialize[T] assertAnalysisSuccess(plan) val isCorrect = (input, convertedBack) match { @@ -336,6 +362,8 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest { Arrays.deepEquals(b1.asInstanceOf[Array[AnyRef]], b2.asInstanceOf[Array[AnyRef]]) case (b1: Array[_], b2: Array[_]) => Arrays.equals(b1.asInstanceOf[Array[AnyRef]], b2.asInstanceOf[Array[AnyRef]]) + case (left: Comparable[_], right: Comparable[_]) => + left.asInstanceOf[Comparable[Any]].compareTo(right) == 0 case _ => input == convertedBack } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala index a8fa372b1ee3d..2e513ea22c151 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala @@ -127,42 +127,129 @@ class RowEncoderSuite extends SparkFunSuite { new StructType().add("array", arrayOfString).add("map", mapOfString)) .add("structOfUDT", structOfUDT)) - test(s"encode/decode: Product") { + test("encode/decode decimal type") { val schema = new StructType() - .add("structAsProduct", - new StructType() - .add("int", IntegerType) - .add("string", StringType) - .add("double", DoubleType)) + .add("int", IntegerType) + .add("string", StringType) + .add("double", DoubleType) + .add("java_decimal", DecimalType.SYSTEM_DEFAULT) + .add("scala_decimal", DecimalType.SYSTEM_DEFAULT) + .add("catalyst_decimal", DecimalType.SYSTEM_DEFAULT) - val encoder = RowEncoder(schema) + val encoder = RowEncoder(schema).resolveAndBind() + + val javaDecimal = new java.math.BigDecimal("1234.5678") + val scalaDecimal = BigDecimal("1234.5678") + val catalystDecimal = Decimal("1234.5678") - val input: Row = Row((100, "test", 0.123)) + val input = Row(100, "test", 0.123, javaDecimal, scalaDecimal, catalystDecimal) val row = encoder.toRow(input) val convertedBack = encoder.fromRow(row) - assert(input.getStruct(0) == convertedBack.getStruct(0)) + // Decimal will be converted back to Java BigDecimal when decoding. 
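(Editorial aside, not part of the patch: a compact version of that decimal round trip on its own, using only the calls this hunk already relies on, namely RowEncoder, resolveAndBind, toRow and fromRow; resolveAndBind() replaces the earlier resolve(attrs, null).bind(attrs) two-step seen elsewhere in this diff. The asserts that follow check the same conversion for the Java, Scala and catalyst decimal inputs.)

```scala
// Sketch only: whatever decimal representation goes in, fromRow hands back
// a java.math.BigDecimal.
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{DecimalType, StructType}

val schema = new StructType().add("d", DecimalType.SYSTEM_DEFAULT)
val encoder = RowEncoder(schema).resolveAndBind()

val decoded = encoder.fromRow(encoder.toRow(Row(BigDecimal("1234.5678"))))
assert(decoded.getDecimal(0).compareTo(new java.math.BigDecimal("1234.5678")) == 0)
```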
+ assert(convertedBack.getDecimal(3).compareTo(javaDecimal) == 0) + assert(convertedBack.getDecimal(4).compareTo(scalaDecimal.bigDecimal) == 0) + assert(convertedBack.getDecimal(5).compareTo(catalystDecimal.toJavaBigDecimal) == 0) } - test("encode/decode Decimal") { - val schema = new StructType() - .add("int", IntegerType) - .add("string", StringType) - .add("double", DoubleType) - .add("decimal", DecimalType.SYSTEM_DEFAULT) + test("RowEncoder should preserve decimal precision and scale") { + val schema = new StructType().add("decimal", DecimalType(10, 5), false) + val encoder = RowEncoder(schema).resolveAndBind() + val decimal = Decimal("67123.45") + val input = Row(decimal) + val row = encoder.toRow(input) - val encoder = RowEncoder(schema) + assert(row.toSeq(schema).head == decimal) + } - val input: Row = Row(100, "test", 0.123, Decimal(1234.5678)) + test("RowEncoder should preserve schema nullability") { + val schema = new StructType().add("int", IntegerType, nullable = false) + val encoder = RowEncoder(schema).resolveAndBind() + assert(encoder.serializer.length == 1) + assert(encoder.serializer.head.dataType == IntegerType) + assert(encoder.serializer.head.nullable == false) + } + + test("RowEncoder should preserve nested column name") { + val schema = new StructType().add( + "struct", + new StructType() + .add("i", IntegerType, nullable = false) + .add( + "s", + new StructType().add("int", IntegerType, nullable = false), + nullable = false), + nullable = false) + val encoder = RowEncoder(schema).resolveAndBind() + assert(encoder.serializer.length == 1) + assert(encoder.serializer.head.dataType == + new StructType() + .add("i", IntegerType, nullable = false) + .add( + "s", + new StructType().add("int", IntegerType, nullable = false), + nullable = false)) + assert(encoder.serializer.head.nullable == false) + } + + test("RowEncoder should support array as the external type for ArrayType") { + val schema = new StructType() + .add("array", ArrayType(IntegerType)) + .add("nestedArray", ArrayType(ArrayType(StringType))) + .add("deepNestedArray", ArrayType(ArrayType(ArrayType(LongType)))) + val encoder = RowEncoder(schema).resolveAndBind() + val input = Row( + Array(1, 2, null), + Array(Array("abc", null), null), + Array(Seq(Array(0L, null), null), null)) val row = encoder.toRow(input) val convertedBack = encoder.fromRow(row) - // Decimal inside external row will be converted back to Java BigDecimal when decoding. 
- assert(input.get(3).asInstanceOf[Decimal].toJavaBigDecimal - .compareTo(convertedBack.getDecimal(3)) == 0) + assert(convertedBack.getSeq(0) == Seq(1, 2, null)) + assert(convertedBack.getSeq(1) == Seq(Seq("abc", null), null)) + assert(convertedBack.getSeq(2) == Seq(Seq(Seq(0L, null), null), null)) + } + + test("RowEncoder should throw RuntimeException if input row object is null") { + val schema = new StructType().add("int", IntegerType) + val encoder = RowEncoder(schema) + val e = intercept[RuntimeException](encoder.toRow(null)) + assert(e.getMessage.contains("Null value appeared in non-nullable field")) + assert(e.getMessage.contains("top level row object")) + } + + test("RowEncoder should validate external type") { + val e1 = intercept[RuntimeException] { + val schema = new StructType().add("a", IntegerType) + val encoder = RowEncoder(schema) + encoder.toRow(Row(1.toShort)) + } + assert(e1.getMessage.contains("java.lang.Short is not a valid external type")) + + val e2 = intercept[RuntimeException] { + val schema = new StructType().add("a", StringType) + val encoder = RowEncoder(schema) + encoder.toRow(Row(1)) + } + assert(e2.getMessage.contains("java.lang.Integer is not a valid external type")) + + val e3 = intercept[RuntimeException] { + val schema = new StructType().add("a", + new StructType().add("b", IntegerType).add("c", StringType)) + val encoder = RowEncoder(schema) + encoder.toRow(Row(1 -> "a")) + } + assert(e3.getMessage.contains("scala.Tuple2 is not a valid external type")) + + val e4 = intercept[RuntimeException] { + val schema = new StructType().add("a", ArrayType(TimestampType)) + val encoder = RowEncoder(schema) + encoder.toRow(Row(Array("a"))) + } + assert(e4.getMessage.contains("java.lang.String is not a valid external type")) } private def encodeDecodeTest(schema: StructType): Unit = { test(s"encode/decode: ${schema.simpleString}") { - val encoder = RowEncoder(schema) + val encoder = RowEncoder(schema).resolveAndBind() val inputGenerator = RandomDataGenerator.forType(schema, nullable = false).get var input: Row = null diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index 72285c6a24199..069c3b37ffa96 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -117,8 +117,13 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper } } + private def testDecimalAndDoubleType(testFunc: (Int => Any) => Unit): Unit = { + testFunc(_.toDouble) + testFunc(Decimal(_)) + } + test("/ (Divide) basic") { - testNumericDataTypes { convert => + testDecimalAndDoubleType { convert => val left = Literal(convert(2)) val right = Literal(convert(1)) val dataType = left.dataType @@ -128,12 +133,14 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Divide(left, Literal(convert(0))), null) // divide by zero } - DataTypeTestUtils.numericTypeWithoutDecimal.foreach { tpe => + Seq(DoubleType, DecimalType.SYSTEM_DEFAULT).foreach { tpe => checkConsistencyBetweenInterpretedAndCodegen(Divide, tpe, tpe) } } - test("/ (Divide) for integral type") { + // By fixing SPARK-15776, Divide's inputType is required to be DoubleType of DecimalType. 
+ // TODO: in future release, we should add a IntegerDivide to support integral types. + ignore("/ (Divide) for integral type") { checkEvaluation(Divide(Literal(1.toByte), Literal(2.toByte)), 0.toByte) checkEvaluation(Divide(Literal(1.toShort), Literal(2.toShort)), 0.toShort) checkEvaluation(Divide(Literal(1), Literal(2)), 0) @@ -143,12 +150,6 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Divide(positiveLongLit, negativeLongLit), 0L) } - test("/ (Divide) for floating point") { - checkEvaluation(Divide(Literal(1.0f), Literal(2.0f)), 0.5f) - checkEvaluation(Divide(Literal(1.0), Literal(2.0)), 0.5) - checkEvaluation(Divide(Literal(Decimal(1.0)), Literal(Decimal(2.0))), Decimal(0.5)) - } - test("% (Remainder)") { testNumericDataTypes { convert => val left = Literal(convert(1)) @@ -172,6 +173,17 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper // } } + test("SPARK-17617: % (Remainder) double % double on super big double") { + val leftDouble = Literal(-5083676433652386516D) + val rightDouble = Literal(10D) + checkEvaluation(Remainder(leftDouble, rightDouble), -6.0D) + + // Float has smaller precision + val leftFloat = Literal(-5083676433652386516F) + val rightFloat = Literal(10F) + checkEvaluation(Remainder(leftFloat, rightFloat), -2.0F) + } + test("Abs") { testNumericDataTypes { convert => val input = Literal(convert(1)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflectionSuite.scala new file mode 100644 index 0000000000000..43367c7e14c34 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflectionSuite.scala @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure +import org.apache.spark.sql.types.{IntegerType, StringType} + +/** A static class for testing purpose. */ +object ReflectStaticClass { + def method1(): String = "m1" + def method2(v1: Int): String = "m" + v1 + def method3(v1: java.lang.Integer): String = "m" + v1 + def method4(v1: Int, v2: String): String = "m" + v1 + v2 +} + +/** A non-static class for testing purpose. */ +class ReflectDynamicClass { + def method1(): String = "m1" +} + +/** + * Test suite for [[CallMethodViaReflection]] and its companion object. + */ +class CallMethodViaReflectionSuite extends SparkFunSuite with ExpressionEvalHelper { + + import CallMethodViaReflection._ + + // Get rid of the $ so we are getting the companion object's name. 
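(Editorial aside, not part of the patch: a Scala object such as ReflectStaticClass compiles to a module class named ReflectStaticClass$, while the static forwarder methods that reflection can see live on the plain ReflectStaticClass class, which is why the declaration that follows strips the trailing $. Below is a minimal sketch of constructing the expression the same way the suite's createExpr helper does further down, here against java.util.UUID, which the findMethod test also exercises; the method and argument values are illustrative only.)

```scala
// Sketch only: children are the class name, the method name, then argument
// literals, in that order.
import org.apache.spark.sql.catalyst.expressions.{CallMethodViaReflection, Literal}

val expr = CallMethodViaReflection(Seq(
  Literal("java.util.UUID"),                          // class exposing a static method
  Literal("fromString"),                              // static UUID.fromString(String)
  Literal("00000000-0000-0000-0000-000000000000")))
assert(expr.checkInputDataTypes().isSuccess)
```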
+ private val staticClassName = ReflectStaticClass.getClass.getName.stripSuffix("$") + private val dynamicClassName = classOf[ReflectDynamicClass].getName + + test("findMethod via reflection for static methods") { + assert(findMethod(staticClassName, "method1", Seq.empty).exists(_.getName == "method1")) + assert(findMethod(staticClassName, "method2", Seq(IntegerType)).isDefined) + assert(findMethod(staticClassName, "method3", Seq(IntegerType)).isDefined) + assert(findMethod(staticClassName, "method4", Seq(IntegerType, StringType)).isDefined) + } + + test("findMethod for a JDK library") { + assert(findMethod(classOf[java.util.UUID].getName, "randomUUID", Seq.empty).isDefined) + } + + test("class not found") { + val ret = createExpr("some-random-class", "method").checkInputDataTypes() + assert(ret.isFailure) + val errorMsg = ret.asInstanceOf[TypeCheckFailure].message + assert(errorMsg.contains("not found") && errorMsg.contains("class")) + } + + test("method not found because name does not match") { + val ret = createExpr(staticClassName, "notfoundmethod").checkInputDataTypes() + assert(ret.isFailure) + val errorMsg = ret.asInstanceOf[TypeCheckFailure].message + assert(errorMsg.contains("cannot find a static method")) + } + + test("method not found because there is no static method") { + val ret = createExpr(dynamicClassName, "method1").checkInputDataTypes() + assert(ret.isFailure) + val errorMsg = ret.asInstanceOf[TypeCheckFailure].message + assert(errorMsg.contains("cannot find a static method")) + } + + test("input type checking") { + assert(CallMethodViaReflection(Seq.empty).checkInputDataTypes().isFailure) + assert(CallMethodViaReflection(Seq(Literal(staticClassName))).checkInputDataTypes().isFailure) + assert(CallMethodViaReflection( + Seq(Literal(staticClassName), Literal(1))).checkInputDataTypes().isFailure) + assert(createExpr(staticClassName, "method1").checkInputDataTypes().isSuccess) + } + + test("invoking methods using acceptable types") { + checkEvaluation(createExpr(staticClassName, "method1"), "m1") + checkEvaluation(createExpr(staticClassName, "method2", 2), "m2") + checkEvaluation(createExpr(staticClassName, "method3", 3), "m3") + checkEvaluation(createExpr(staticClassName, "method4", 4, "four"), "m4four") + } + + private def createExpr(className: String, methodName: String, args: Any*) = { + CallMethodViaReflection( + Literal.create(className, StringType) +: + Literal.create(methodName, StringType) +: + args.map(Literal.apply) + ) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 43af3592070fe..b748595fc4f2d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -70,7 +70,8 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkNullCast(DateType, TimestampType) numericTypes.foreach(dt => checkNullCast(dt, TimestampType)) - atomicTypes.foreach(dt => checkNullCast(dt, DateType)) + checkNullCast(StringType, DateType) + checkNullCast(TimestampType, DateType) checkNullCast(StringType, CalendarIntervalType) numericTypes.foreach(dt => checkNullCast(StringType, dt)) @@ -366,7 +367,6 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast("2012-12-11", DoubleType), null) checkEvaluation(cast(123, IntegerType), 123) - 
checkEvaluation(cast(Literal.create(null, IntegerType), ShortType), null) } @@ -548,7 +548,6 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { { val ret = cast(array_notNull, ArrayType(BooleanType, containsNull = false)) assert(ret.resolved === false) - checkEvaluation(ret, Seq(null, true, false)) } { @@ -607,7 +606,6 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { { val ret = cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = false)) assert(ret.resolved === false) - checkEvaluation(ret, Map("a" -> null, "b" -> true, "c" -> false)) } { val ret = cast(map_notNull, MapType(IntegerType, StringType, valueContainsNull = true)) @@ -714,7 +712,6 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { StructField("b", BooleanType, nullable = true), StructField("c", BooleanType, nullable = false)))) assert(ret.resolved === false) - checkEvaluation(ret, InternalRow(null, true, false)) } { @@ -730,6 +727,16 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { } } + test("cast struct with a timestamp field") { + val originalSchema = new StructType().add("tsField", TimestampType, nullable = false) + // nine out of ten times I'm casting a struct, it's to normalize its fields nullability + val targetSchema = new StructType().add("tsField", TimestampType, nullable = true) + + val inp = Literal.create(InternalRow(0L), originalSchema) + val expected = InternalRow(0L) + checkEvaluation(cast(inp, targetSchema), expected) + } + test("complex casting") { val complex = Literal.create( Row( @@ -755,15 +762,12 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { StructField("l", LongType, nullable = true))))))) assert(ret.resolved === false) - checkEvaluation(ret, Row( - Seq(123, null, null), - Map("a" -> null, "b" -> true, "c" -> false), - Row(0L))) } test("cast between string and interval") { import org.apache.spark.unsafe.types.CalendarInterval + checkEvaluation(Cast(Literal(""), CalendarIntervalType), null) checkEvaluation(Cast(Literal("interval -3 month 7 hours"), CalendarIntervalType), new CalendarInterval(-3, 7 * CalendarInterval.MICROS_PER_HOUR)) checkEvaluation(Cast(Literal.create( @@ -790,4 +794,16 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast("abc", BooleanType), null) checkEvaluation(cast("", BooleanType), null) } + + test("SPARK-16729 type checking for casting to date type") { + assert(cast("1234", DateType).checkInputDataTypes().isSuccess) + assert(cast(new Timestamp(1), DateType).checkInputDataTypes().isSuccess) + assert(cast(false, DateType).checkInputDataTypes().isFailure) + assert(cast(1.toByte, DateType).checkInputDataTypes().isFailure) + assert(cast(1.toShort, DateType).checkInputDataTypes().isFailure) + assert(cast(1, DateType).checkInputDataTypes().isFailure) + assert(cast(1L, DateType).checkInputDataTypes().isFailure) + assert(cast(1.0.toFloat, DateType).checkInputDataTypes().isFailure) + assert(cast(1.0, DateType).checkInputDataTypes().isFailure) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 2082cea0f60f7..45dcfcaf23132 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -18,10 +18,12 @@ package 
org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite +import org.apache.spark.metrics.source.CodegenMetrics import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.objects.{CreateExternalRow, GetExternalRowField, ValidateExternalType} import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -48,6 +50,18 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { futures.foreach(ThreadUtils.awaitResult(_, 10.seconds)) } + test("metrics are recorded on compile") { + val startCount1 = CodegenMetrics.METRIC_COMPILATION_TIME.getCount() + val startCount2 = CodegenMetrics.METRIC_SOURCE_CODE_SIZE.getCount() + val startCount3 = CodegenMetrics.METRIC_GENERATED_CLASS_BYTECODE_SIZE.getCount() + val startCount4 = CodegenMetrics.METRIC_GENERATED_METHOD_BYTECODE_SIZE.getCount() + GenerateOrdering.generate(Add(Literal(123), Literal(1)).asc :: Nil) + assert(CodegenMetrics.METRIC_COMPILATION_TIME.getCount() == startCount1 + 1) + assert(CodegenMetrics.METRIC_SOURCE_CODE_SIZE.getCount() == startCount2 + 1) + assert(CodegenMetrics.METRIC_GENERATED_CLASS_BYTECODE_SIZE.getCount() > startCount3) + assert(CodegenMetrics.METRIC_GENERATED_METHOD_BYTECODE_SIZE.getCount() > startCount4) + } + test("SPARK-8443: split wide projections into blocks due to JVM code size limit") { val length = 5000 val expressions = List.fill(length)(EqualTo(Literal(1), Literal(1))) @@ -100,10 +114,10 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { case (expr, i) => Seq(Literal(i), expr) })) val plan = GenerateMutableProjection.generate(expressions) - val actual = plan(new GenericMutableRow(length)).toSeq(expressions.map(_.dataType)) - val expected = Seq(new ArrayBasedMapData( - new GenericArrayData(0 until length), - new GenericArrayData(Seq.fill(length)(true)))) + val actual = plan(new GenericMutableRow(length)).toSeq(expressions.map(_.dataType)).map { + case m: ArrayBasedMapData => ArrayBasedMapData.toScalaMap(m) + } + val expected = (0 until length).map((_, true)).toMap :: Nil if (!checkResult(actual, expected)) { fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected") @@ -137,6 +151,19 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { } } + test("SPARK-14224: split wide external row creation into blocks due to JVM code size limit") { + val length = 5000 + val schema = StructType(Seq.fill(length)(StructField("int", IntegerType))) + val expressions = Seq(CreateExternalRow(Seq.fill(length)(Literal(1)), schema)) + val plan = GenerateMutableProjection.generate(expressions) + val actual = plan(new GenericMutableRow(length)).toSeq(expressions.map(_.dataType)) + val expected = Seq(Row.fromSeq(Seq.fill(length)(1))) + + if (!checkResult(actual, expected)) { + fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected") + } + } + test("test generated safe and unsafe projection") { val schema = new StructType(Array( StructField("a", StringType, true), @@ -194,4 +221,59 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { true, InternalRow(UTF8String.fromString("\\u"))) } + + test("check compilation error doesn't occur caused by specific literal") { + // 
The end of comment (*/) should be escaped. + GenerateUnsafeProjection.generate( + Literal.create("*/Compilation error occurs/*", StringType) :: Nil) + + // `\u002A` is `*` and `\u002F` is `/` + // so if the end of comment consists of those characters in queries, we need to escape them. + GenerateUnsafeProjection.generate( + Literal.create("\\u002A/Compilation error occurs/*", StringType) :: Nil) + GenerateUnsafeProjection.generate( + Literal.create("\\\\u002A/Compilation error occurs/*", StringType) :: Nil) + GenerateUnsafeProjection.generate( + Literal.create("\\u002a/Compilation error occurs/*", StringType) :: Nil) + GenerateUnsafeProjection.generate( + Literal.create("\\\\u002a/Compilation error occurs/*", StringType) :: Nil) + GenerateUnsafeProjection.generate( + Literal.create("*\\u002FCompilation error occurs/*", StringType) :: Nil) + GenerateUnsafeProjection.generate( + Literal.create("*\\\\u002FCompilation error occurs/*", StringType) :: Nil) + GenerateUnsafeProjection.generate( + Literal.create("*\\002fCompilation error occurs/*", StringType) :: Nil) + GenerateUnsafeProjection.generate( + Literal.create("*\\\\002fCompilation error occurs/*", StringType) :: Nil) + GenerateUnsafeProjection.generate( + Literal.create("\\002A\\002FCompilation error occurs/*", StringType) :: Nil) + GenerateUnsafeProjection.generate( + Literal.create("\\\\002A\\002FCompilation error occurs/*", StringType) :: Nil) + GenerateUnsafeProjection.generate( + Literal.create("\\002A\\\\002FCompilation error occurs/*", StringType) :: Nil) + + // \ u002X is an invalid unicode literal so it should be escaped. + GenerateUnsafeProjection.generate( + Literal.create("\\u002X/Compilation error occurs", StringType) :: Nil) + GenerateUnsafeProjection.generate( + Literal.create("\\\\u002X/Compilation error occurs", StringType) :: Nil) + + // \ u001 is an invalid unicode literal so it should be escaped. 
+ GenerateUnsafeProjection.generate( + Literal.create("\\u001/Compilation error occurs", StringType) :: Nil) + GenerateUnsafeProjection.generate( + Literal.create("\\\\u001/Compilation error occurs", StringType) :: Nil) + + } + + test("SPARK-17160: field names are properly escaped by GetExternalRowField") { + val inputObject = BoundReference(0, ObjectType(classOf[Row]), nullable = true) + GenerateUnsafeProjection.generate( + ValidateExternalType( + GetExternalRowField(inputObject, index = 0, fieldName = "\"quote"), IntegerType) :: Nil) + } + + test("SPARK-17160: field names are properly escaped by AssertTrue") { + GenerateUnsafeProjection.generate(AssertTrue(Cast(Literal("\""), BooleanType)) :: Nil) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala index 1aae4678d6278..a5f784fdcc13c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala @@ -44,6 +44,19 @@ class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Literal.create(null, ArrayType(StringType)), null) } + test("MapKeys/MapValues") { + val m0 = Literal.create(Map("a" -> "1", "b" -> "2"), MapType(StringType, StringType)) + val m1 = Literal.create(Map[String, String](), MapType(StringType, StringType)) + val m2 = Literal.create(null, MapType(StringType, StringType)) + + checkEvaluation(MapKeys(m0), Seq("a", "b")) + checkEvaluation(MapValues(m0), Seq("1", "2")) + checkEvaluation(MapKeys(m1), Seq()) + checkEvaluation(MapValues(m1), Seq()) + checkEvaluation(MapKeys(m2), null) + checkEvaluation(MapValues(m2), null) + } + test("Sort Array") { val a0 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType)) val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 7c009a7360b6f..0c307b2b8576b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -228,4 +228,58 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { checkErrorMessage(structType, IntegerType, "Field name should be String Literal") checkErrorMessage(otherType, StringType, "Can't extract value from") } + + test("ensure to preserve metadata") { + val metadata = new MetadataBuilder() + .putString("key", "value") + .build() + + def checkMetadata(expr: Expression): Unit = { + assert(expr.dataType.asInstanceOf[StructType]("a").metadata === metadata) + assert(expr.dataType.asInstanceOf[StructType]("b").metadata === Metadata.empty) + } + + val a = AttributeReference("a", IntegerType, metadata = metadata)() + val b = AttributeReference("b", IntegerType)() + checkMetadata(CreateStruct(Seq(a, b))) + checkMetadata(CreateNamedStruct(Seq("a", a, "b", b))) + checkMetadata(CreateStructUnsafe(Seq(a, b))) + checkMetadata(CreateNamedStructUnsafe(Seq("a", a, "b", b))) + } + + test("StringToMap") { + val s0 = Literal("a:1,b:2,c:3") + val m0 = Map("a" -> "1", "b" -> "2", "c" -> "3") + checkEvaluation(new StringToMap(s0), m0) + + val s1 = 
Literal("a: ,b:2") + val m1 = Map("a" -> " ", "b" -> "2") + checkEvaluation(new StringToMap(s1), m1) + + val s2 = Literal("a=1,b=2,c=3") + val m2 = Map("a" -> "1", "b" -> "2", "c" -> "3") + checkEvaluation(StringToMap(s2, Literal(","), Literal("=")), m2) + + val s3 = Literal("") + val m3 = Map[String, String]("" -> null) + checkEvaluation(StringToMap(s3, Literal(","), Literal("=")), m3) + + val s4 = Literal("a:1_b:2_c:3") + val m4 = Map("a" -> "1", "b" -> "2", "c" -> "3") + checkEvaluation(new StringToMap(s4, Literal("_")), m4) + + // arguments checking + assert(new StringToMap(Literal("a:1,b:2,c:3")).checkInputDataTypes().isSuccess) + assert(new StringToMap(Literal(null)).checkInputDataTypes().isFailure) + assert(new StringToMap(Literal("a:1,b:2,c:3"), Literal(null)).checkInputDataTypes().isFailure) + assert(StringToMap(Literal("a:1,b:2,c:3"), Literal(null), Literal(null)) + .checkInputDataTypes().isFailure) + assert(new StringToMap(Literal(null), Literal(null)).checkInputDataTypes().isFailure) + + assert(new StringToMap(Literal("a:1_b:2_c:3"), NonFoldableLiteral("_")) + .checkInputDataTypes().isFailure) + assert( + new StringToMap(Literal("a=1_b=2_c=3"), Literal("_"), NonFoldableLiteral("=")) + .checkInputDataTypes().isFailure) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala index 3c581ecdaf068..36185b8c637a1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala @@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp} import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.types._ @@ -181,6 +182,12 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper Literal(Timestamp.valueOf("2015-07-01 10:00:00")))), Timestamp.valueOf("2015-07-01 08:00:00"), InternalRow.empty) + // Type checking error + assert( + Least(Seq(Literal(1), Literal("1"))).checkInputDataTypes() == + TypeCheckFailure("The expressions should all have the same type, " + + "got LEAST(int, string).")) + DataTypeTestUtils.ordered.foreach { dt => checkConsistencyBetweenInterpretedAndCodegen(Least, dt, 2) } @@ -227,6 +234,12 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper Literal(Timestamp.valueOf("2015-07-01 10:00:00")))), Timestamp.valueOf("2015-07-01 10:00:00"), InternalRow.empty) + // Type checking error + assert( + Greatest(Seq(Literal(1), Literal("1"))).checkInputDataTypes() == + TypeCheckFailure("The expressions should all have the same type, " + + "got GREATEST(int, string).")) + DataTypeTestUtils.ordered.foreach { dt => checkConsistencyBetweenInterpretedAndCodegen(Greatest, dt, 2) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index 53c66d8a754ed..6118a34d29eaa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala 
@@ -143,7 +143,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("Seconds") { - checkEvaluation(Second(Literal.create(null, DateType)), null) + assert(Second(Literal.create(null, DateType)).resolved === false) checkEvaluation(Second(Cast(Literal(d), TimestampType)), 0) checkEvaluation(Second(Cast(Literal(sdf.format(d)), TimestampType)), 15) checkEvaluation(Second(Literal(ts)), 15) @@ -176,7 +176,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("Hour") { - checkEvaluation(Hour(Literal.create(null, DateType)), null) + assert(Hour(Literal.create(null, DateType)).resolved === false) checkEvaluation(Hour(Cast(Literal(d), TimestampType)), 0) checkEvaluation(Hour(Cast(Literal(sdf.format(d)), TimestampType)), 13) checkEvaluation(Hour(Literal(ts)), 13) @@ -195,7 +195,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("Minute") { - checkEvaluation(Minute(Literal.create(null, DateType)), null) + assert(Minute(Literal.create(null, DateType)).resolved === false) checkEvaluation(Minute(Cast(Literal(d), TimestampType)), 0) checkEvaluation(Minute(Cast(Literal(sdf.format(d)), TimestampType)), 10) checkEvaluation(Minute(Literal(ts)), 10) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 8a9617cfbf5df..668543a28bd30 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} +import org.apache.spark.sql.catalyst.util.MapData import org.apache.spark.sql.types.DataType import org.apache.spark.util.Utils @@ -52,7 +53,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { /** * Check the equality between result of expression and expected value, it will handle - * Array[Byte] and Spread[Double]. + * Array[Byte], Spread[Double], and MapData. */ protected def checkResult(result: Any, expected: Any): Boolean = { (result, expected) match { @@ -60,7 +61,14 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { java.util.Arrays.equals(result, expected) case (result: Double, expected: Spread[Double @unchecked]) => expected.asInstanceOf[Spread[Double]].isWithin(result) - case _ => result == expected + case (result: MapData, expected: MapData) => + result.keyArray() == expected.keyArray() && result.valueArray() == expected.valueArray() + case (result: Double, expected: Double) => + if (expected.isNaN) result.isNaN else expected == result + case (result: Float, expected: Float) => + if (expected.isNaN) result.isNaN else expected == result + case _ => + result == expected } } @@ -124,9 +132,13 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { expression: Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { - + // SPARK-16489 Explicitly doing code generation twice so code gen will fail if + // some expression is reusing variable names across different instances. + // This behavior is tested in ExpressionEvalHelperSuite. 
val plan = generateProject( - GenerateUnsafeProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), + UnsafeProjection.create( + Alias(expression, s"Optimized($expression)1")() :: + Alias(expression, s"Optimized($expression)2")() :: Nil), expression) val unsafeRow = plan(inputRow) @@ -134,13 +146,14 @@ if (expected == null) { if (!unsafeRow.isNullAt(0)) { - val expectedRow = InternalRow(expected) + val expectedRow = InternalRow(expected, expected) fail("Incorrect evaluation in unsafe mode: " + s"$expression, actual: $unsafeRow, expected: $expectedRow$input") } } else { - val lit = InternalRow(expected) - val expectedRow = UnsafeProjection.create(Array(expression.dataType)).apply(lit) + val lit = InternalRow(expected, expected) + val expectedRow = + UnsafeProjection.create(Array(expression.dataType, expression.dataType)).apply(lit) if (unsafeRow != expectedRow) { fail("Incorrect evaluation in unsafe mode: " + s"$expression, actual: $unsafeRow, expected: $expectedRow$input") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala new file mode 100644 index 0000000000000..64b65e2070ed6 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.types.{DataType, IntegerType} + +/** + * A test suite for testing [[ExpressionEvalHelper]]. + * + * Yes, we should write test cases for test harnesses, in case + * they have behaviors that are easy to break. + */ +class ExpressionEvalHelperSuite extends SparkFunSuite with ExpressionEvalHelper { + + test("SPARK-16489 checkEvaluation should fail if expression reuses variable names") { + val e = intercept[RuntimeException] { checkEvaluation(BadCodegenExpression(), 10) } + assert(e.getMessage.contains("some_variable")) + } +} + +/** + * An expression that generates bad code (the variable name "some_variable" is not unique across + * instances of the expression).
+ */ +case class BadCodegenExpression() extends LeafExpression { + override def nullable: Boolean = false + override def eval(input: InternalRow): Any = 10 + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + ev.copy(code = + s""" + |int some_variable = 11; + |int ${ev.value} = 10; + """.stripMargin) + } + override def dataType: DataType = IntegerType +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratorExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratorExpressionSuite.scala new file mode 100644 index 0000000000000..e29dfa41f1cc5 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratorExpressionSuite.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types._ + +class GeneratorExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { + private def checkTuple(actual: Expression, expected: Seq[InternalRow]): Unit = { + assert(actual.eval(null).asInstanceOf[TraversableOnce[InternalRow]].toSeq === expected) + } + + private final val empty_array = CreateArray(Seq.empty) + private final val int_array = CreateArray(Seq(1, 2, 3).map(Literal(_))) + private final val str_array = CreateArray(Seq("a", "b", "c").map(Literal(_))) + + test("explode") { + val int_correct_answer = Seq(create_row(1), create_row(2), create_row(3)) + val str_correct_answer = Seq(create_row("a"), create_row("b"), create_row("c")) + + checkTuple(Explode(empty_array), Seq.empty) + checkTuple(Explode(int_array), int_correct_answer) + checkTuple(Explode(str_array), str_correct_answer) + } + + test("posexplode") { + val int_correct_answer = Seq(create_row(0, 1), create_row(1, 2), create_row(2, 3)) + val str_correct_answer = Seq(create_row(0, "a"), create_row(1, "b"), create_row(2, "c")) + + checkTuple(PosExplode(CreateArray(Seq.empty)), Seq.empty) + checkTuple(PosExplode(int_array), int_correct_answer) + checkTuple(PosExplode(str_array), str_correct_answer) + } + + test("inline") { + val correct_answer = Seq(create_row(0, "a"), create_row(1, "b"), create_row(2, "c")) + + checkTuple( + Inline(Literal.create(Array(), ArrayType(new StructType().add("id", LongType)))), + Seq.empty) + + checkTuple( + Inline(CreateArray(Seq( + CreateStruct(Seq(Literal(0), Literal("a"))), + CreateStruct(Seq(Literal(1), Literal("b"))), + CreateStruct(Seq(Literal(2), Literal("c"))) + ))), + correct_answer) + } + + test("stack") { + checkTuple(Stack(Seq(1, 1).map(Literal(_))), Seq(create_row(1))) + checkTuple(Stack(Seq(1, 1, 2).map(Literal(_))), 
Seq(create_row(1, 2))) + checkTuple(Stack(Seq(2, 1, 2).map(Literal(_))), Seq(create_row(1), create_row(2))) + checkTuple(Stack(Seq(2, 1, 2, 3).map(Literal(_))), Seq(create_row(1, 2), create_row(3, null))) + checkTuple(Stack(Seq(3, 1, 2, 3).map(Literal(_))), Seq(1, 2, 3).map(create_row(_))) + checkTuple(Stack(Seq(4, 1, 2, 3).map(Literal(_))), Seq(1, 2, 3, null).map(create_row(_))) + + checkTuple( + Stack(Seq(3, 1, 1.0, "a", 2, 2.0, "b", 3, 3.0, "c").map(Literal(_))), + Seq(create_row(1, 1.0, "a"), create_row(2, 2.0, "b"), create_row(3, 3.0, "c"))) + + assert(Stack(Seq(Literal(1))).checkInputDataTypes().isFailure) + assert(Stack(Seq(Literal(1.0))).checkInputDataTypes().isFailure) + assert(Stack(Seq(Literal(1), Literal(1), Literal(1.0))).checkInputDataTypes().isSuccess) + assert(Stack(Seq(Literal(2), Literal(1), Literal(1.0))).checkInputDataTypes().isFailure) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MapDataSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MapDataSuite.scala new file mode 100644 index 0000000000000..0f1264c7c3269 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MapDataSuite.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import scala.collection._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.util.ArrayBasedMapData +import org.apache.spark.sql.types.{DataType, IntegerType, MapType, StringType} +import org.apache.spark.unsafe.types.UTF8String + +class MapDataSuite extends SparkFunSuite { + + test("inequality tests") { + def u(str: String): UTF8String = UTF8String.fromString(str) + + // test data + val testMap1 = Map(u("key1") -> 1) + val testMap2 = Map(u("key1") -> 1, u("key2") -> 2) + val testMap3 = Map(u("key1") -> 1) + val testMap4 = Map(u("key1") -> 1, u("key2") -> 2) + + // ArrayBasedMapData + val testArrayMap1 = ArrayBasedMapData(testMap1.toMap) + val testArrayMap2 = ArrayBasedMapData(testMap2.toMap) + val testArrayMap3 = ArrayBasedMapData(testMap3.toMap) + val testArrayMap4 = ArrayBasedMapData(testMap4.toMap) + assert(testArrayMap1 !== testArrayMap3) + assert(testArrayMap2 !== testArrayMap4) + + // UnsafeMapData + val unsafeConverter = UnsafeProjection.create(Array[DataType](MapType(StringType, IntegerType))) + val row = new GenericMutableRow(1) + def toUnsafeMap(map: ArrayBasedMapData): UnsafeMapData = { + row.update(0, map) + val unsafeRow = unsafeConverter.apply(row) + unsafeRow.getMap(0).copy + } + assert(toUnsafeMap(testArrayMap1) !== toUnsafeMap(testArrayMap3)) + assert(toUnsafeMap(testArrayMap2) !== toUnsafeMap(testArrayMap4)) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala index ace6c15dc8418..712fe35f477b3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala @@ -77,6 +77,21 @@ class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } + test("SPARK-16602 Nvl should support numeric-string cases") { + val intLit = Literal.create(1, IntegerType) + val doubleLit = Literal.create(2.2, DoubleType) + val stringLit = Literal.create("c", StringType) + val nullLit = Literal.create(null, NullType) + + assert(Nvl(intLit, doubleLit).replaceForTypeCoercion().dataType == DoubleType) + assert(Nvl(intLit, stringLit).replaceForTypeCoercion().dataType == StringType) + assert(Nvl(stringLit, doubleLit).replaceForTypeCoercion().dataType == StringType) + + assert(Nvl(nullLit, intLit).replaceForTypeCoercion().dataType == IntegerType) + assert(Nvl(doubleLit, nullLit).replaceForTypeCoercion().dataType == DoubleType) + assert(Nvl(nullLit, stringLit).replaceForTypeCoercion().dataType == StringType) + } + test("AtLeastNNonNulls") { val mix = Seq(Literal("x"), Literal.create(null, StringType), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionSuite.scala new file mode 100644 index 0000000000000..b6263e77c141f --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionSuite.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} + +class ObjectExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { + test("MapObjects should make copies of unsafe-backed data") { + // test UnsafeRow-backed data + val structEncoder = ExpressionEncoder[Array[(java.lang.Integer, java.lang.Integer)]]() + val structInputRow = InternalRow.fromSeq(Seq(Array((1, 2), (3, 4)))) + val structExpected = new GenericArrayData( + Array(InternalRow.fromSeq(Seq(1, 2)), InternalRow.fromSeq(Seq(3, 4)))) + checkEvalutionWithUnsafeProjection( + structEncoder.serializer.head, structExpected, structInputRow) + + // test UnsafeArray-backed data + val arrayEncoder = ExpressionEncoder[Array[Array[Int]]]() + val arrayInputRow = InternalRow.fromSeq(Seq(Array(Array(1, 2), Array(3, 4)))) + val arrayExpected = new GenericArrayData( + Array(new GenericArrayData(Array(1, 2)), new GenericArrayData(Array(3, 4)))) + checkEvalutionWithUnsafeProjection( + arrayEncoder.serializer.head, arrayExpected, arrayInputRow) + + // test UnsafeMap-backed data + val mapEncoder = ExpressionEncoder[Array[Map[Int, Int]]]() + val mapInputRow = InternalRow.fromSeq(Seq(Array( + Map(1 -> 100, 2 -> 200), Map(3 -> 300, 4 -> 400)))) + val mapExpected = new GenericArrayData(Seq( + new ArrayBasedMapData( + new GenericArrayData(Array(1, 2)), + new GenericArrayData(Array(100, 200))), + new ArrayBasedMapData( + new GenericArrayData(Array(3, 4)), + new GenericArrayData(Array(300, 400))))) + checkEvalutionWithUnsafeProjection( + mapEncoder.serializer.head, mapExpected, mapInputRow) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala index b190d3a00dfb8..8cc2ab46c0c85 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala @@ -19,11 +19,12 @@ package org.apache.spark.sql.catalyst.expressions import scala.math._ -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.serializer.KryoSerializer import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering +import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateOrdering, LazilyGeneratedOrdering} import org.apache.spark.sql.types._ class OrderingSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -44,9 +45,14 @@ class 
OrderingSuite extends SparkFunSuite with ExpressionEvalHelper { case Ascending => signum(expected) case Descending => -1 * signum(expected) } + + val kryo = new KryoSerializer(new SparkConf).newInstance() val intOrdering = new InterpretedOrdering(sortOrder :: Nil) - val genOrdering = GenerateOrdering.generate(sortOrder :: Nil) - Seq(intOrdering, genOrdering).foreach { ordering => + val genOrdering = new LazilyGeneratedOrdering(sortOrder :: Nil) + val kryoIntOrdering = kryo.deserialize[InterpretedOrdering](kryo.serialize(intOrdering)) + val kryoGenOrdering = kryo.deserialize[LazilyGeneratedOrdering](kryo.serialize(genOrdering)) + + Seq(intOrdering, genOrdering, kryoIntOrdering, kryoGenOrdering).foreach { ordering => assert(ordering.compare(rowA, rowA) === 0) assert(ordering.compare(rowB, rowB) === 0) assert(signum(ordering.compare(rowA, rowB)) === expectedCompareResult) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala new file mode 100644 index 0000000000000..5299549e7b4da --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.types.StringType + +/** + * Unit tests for regular expression (regexp) related SQL expressions. 
+ */ +class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { + + test("LIKE literal Regular Expression") { + checkEvaluation(Literal.create(null, StringType).like("a"), null) + checkEvaluation(Literal.create("a", StringType).like(Literal.create(null, StringType)), null) + checkEvaluation(Literal.create(null, StringType).like(Literal.create(null, StringType)), null) + checkEvaluation( + Literal.create("a", StringType).like(NonFoldableLiteral.create("a", StringType)), true) + checkEvaluation( + Literal.create("a", StringType).like(NonFoldableLiteral.create(null, StringType)), null) + checkEvaluation( + Literal.create(null, StringType).like(NonFoldableLiteral.create("a", StringType)), null) + checkEvaluation( + Literal.create(null, StringType).like(NonFoldableLiteral.create(null, StringType)), null) + + checkEvaluation("abdef" like "abdef", true) + checkEvaluation("a_%b" like "a\\__b", true) + checkEvaluation("addb" like "a_%b", true) + checkEvaluation("addb" like "a\\__b", false) + checkEvaluation("addb" like "a%\\%b", false) + checkEvaluation("a_%b" like "a%\\%b", true) + checkEvaluation("addb" like "a%", true) + checkEvaluation("addb" like "**", false) + checkEvaluation("abc" like "a%", true) + checkEvaluation("abc" like "b%", false) + checkEvaluation("abc" like "bc%", false) + checkEvaluation("a\nb" like "a_b", true) + checkEvaluation("ab" like "a%b", true) + checkEvaluation("a\nb" like "a%b", true) + } + + test("LIKE Non-literal Regular Expression") { + val regEx = 'a.string.at(0) + checkEvaluation("abcd" like regEx, null, create_row(null)) + checkEvaluation("abdef" like regEx, true, create_row("abdef")) + checkEvaluation("a_%b" like regEx, true, create_row("a\\__b")) + checkEvaluation("addb" like regEx, true, create_row("a_%b")) + checkEvaluation("addb" like regEx, false, create_row("a\\__b")) + checkEvaluation("addb" like regEx, false, create_row("a%\\%b")) + checkEvaluation("a_%b" like regEx, true, create_row("a%\\%b")) + checkEvaluation("addb" like regEx, true, create_row("a%")) + checkEvaluation("addb" like regEx, false, create_row("**")) + checkEvaluation("abc" like regEx, true, create_row("a%")) + checkEvaluation("abc" like regEx, false, create_row("b%")) + checkEvaluation("abc" like regEx, false, create_row("bc%")) + checkEvaluation("a\nb" like regEx, true, create_row("a_b")) + checkEvaluation("ab" like regEx, true, create_row("a%b")) + checkEvaluation("a\nb" like regEx, true, create_row("a%b")) + + checkEvaluation(Literal.create(null, StringType) like regEx, null, create_row("bc%")) + } + + test("RLIKE literal Regular Expression") { + checkEvaluation(Literal.create(null, StringType) rlike "abdef", null) + checkEvaluation("abdef" rlike Literal.create(null, StringType), null) + checkEvaluation(Literal.create(null, StringType) rlike Literal.create(null, StringType), null) + checkEvaluation("abdef" rlike NonFoldableLiteral.create("abdef", StringType), true) + checkEvaluation("abdef" rlike NonFoldableLiteral.create(null, StringType), null) + checkEvaluation( + Literal.create(null, StringType) rlike NonFoldableLiteral.create("abdef", StringType), null) + checkEvaluation( + Literal.create(null, StringType) rlike NonFoldableLiteral.create(null, StringType), null) + + checkEvaluation("abdef" rlike "abdef", true) + checkEvaluation("abbbbc" rlike "a.*c", true) + + checkEvaluation("fofo" rlike "^fo", true) + checkEvaluation("fo\no" rlike "^fo\no$", true) + checkEvaluation("Bn" rlike "^Ba*n", true) + checkEvaluation("afofo" rlike "fo", true) + checkEvaluation("afofo" 
rlike "^fo", false) + checkEvaluation("Baan" rlike "^Ba?n", false) + checkEvaluation("axe" rlike "pi|apa", false) + checkEvaluation("pip" rlike "^(pi)*$", false) + + checkEvaluation("abc" rlike "^ab", true) + checkEvaluation("abc" rlike "^bc", false) + checkEvaluation("abc" rlike "^ab", true) + checkEvaluation("abc" rlike "^bc", false) + + intercept[java.util.regex.PatternSyntaxException] { + evaluate("abbbbc" rlike "**") + } + } + + test("RLIKE Non-literal Regular Expression") { + val regEx = 'a.string.at(0) + checkEvaluation("abdef" rlike regEx, true, create_row("abdef")) + checkEvaluation("abbbbc" rlike regEx, true, create_row("a.*c")) + checkEvaluation("fofo" rlike regEx, true, create_row("^fo")) + checkEvaluation("fo\no" rlike regEx, true, create_row("^fo\no$")) + checkEvaluation("Bn" rlike regEx, true, create_row("^Ba*n")) + + intercept[java.util.regex.PatternSyntaxException] { + evaluate("abbbbc" rlike regEx, create_row("**")) + } + } + + + test("RegexReplace") { + val row1 = create_row("100-200", "(\\d+)", "num") + val row2 = create_row("100-200", "(\\d+)", "###") + val row3 = create_row("100-200", "(-)", "###") + val row4 = create_row(null, "(\\d+)", "###") + val row5 = create_row("100-200", null, "###") + val row6 = create_row("100-200", "(-)", null) + + val s = 's.string.at(0) + val p = 'p.string.at(1) + val r = 'r.string.at(2) + + val expr = RegExpReplace(s, p, r) + checkEvaluation(expr, "num-num", row1) + checkEvaluation(expr, "###-###", row2) + checkEvaluation(expr, "100###200", row3) + checkEvaluation(expr, null, row4) + checkEvaluation(expr, null, row5) + checkEvaluation(expr, null, row6) + + val nonNullExpr = RegExpReplace(Literal("100-200"), Literal("(\\d+)"), Literal("num")) + checkEvaluation(nonNullExpr, "num-num", row1) + } + + test("RegexExtract") { + val row1 = create_row("100-200", "(\\d+)-(\\d+)", 1) + val row2 = create_row("100-200", "(\\d+)-(\\d+)", 2) + val row3 = create_row("100-200", "(\\d+).*", 1) + val row4 = create_row("100-200", "([a-z])", 1) + val row5 = create_row(null, "([a-z])", 1) + val row6 = create_row("100-200", null, 1) + val row7 = create_row("100-200", "([a-z])", null) + + val s = 's.string.at(0) + val p = 'p.string.at(1) + val r = 'r.int.at(2) + + val expr = RegExpExtract(s, p, r) + checkEvaluation(expr, "100", row1) + checkEvaluation(expr, "200", row2) + checkEvaluation(expr, "100", row3) + checkEvaluation(expr, "", row4) // will not match anything, empty string get + checkEvaluation(expr, null, row5) + checkEvaluation(expr, null, row6) + checkEvaluation(expr, null, row7) + + val expr1 = new RegExpExtract(s, p) + checkEvaluation(expr1, "100", row1) + + val nonNullExpr = RegExpExtract(Literal("100-200"), Literal("(\\d+)-(\\d+)"), Literal(1)) + checkEvaluation(nonNullExpr, "100", row1) + } + + test("SPLIT") { + val s1 = 'a.string.at(0) + val s2 = 'b.string.at(1) + val row1 = create_row("aa2bb3cc", "[1-9]+") + val row2 = create_row(null, "[1-9]+") + val row3 = create_row("aa2bb3cc", null) + + checkEvaluation( + StringSplit(Literal("aa2bb3cc"), Literal("[1-9]+")), Seq("aa", "bb", "cc"), row1) + checkEvaluation( + StringSplit(s1, s2), Seq("aa", "bb", "cc"), row1) + checkEvaluation(StringSplit(s1, s2), null, row2) + checkEvaluation(StringSplit(s1, s2), null, row3) + } + +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala new file mode 100644 index 0000000000000..7e45028653e36 --- /dev/null +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.sql.types.{IntegerType, StringType} + +class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper { + + test("basic") { + val intUdf = ScalaUDF((i: Int) => i + 1, IntegerType, Literal(1) :: Nil) + checkEvaluation(intUdf, 2) + + val stringUdf = ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil) + checkEvaluation(stringUdf, "ax") + } + + test("better error message for NPE") { + val udf = ScalaUDF( + (s: String) => s.toLowerCase, + StringType, + Literal.create(null, StringType) :: Nil) + + val e1 = intercept[SparkException](udf.eval()) + assert(e1.getMessage.contains("Failed to execute user defined function")) + + val e2 = intercept[SparkException] { + checkEvalutionWithUnsafeProjection(udf, null) + } + assert(e2.getMessage.contains("Failed to execute user defined function")) + } + +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 2cf8ca7000edc..fdb9fa31f09c8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -75,6 +75,29 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { // scalastyle:on } + test("elt") { + def testElt(result: String, n: java.lang.Integer, args: String*): Unit = { + checkEvaluation( + Elt(Literal.create(n, IntegerType) +: args.map(Literal.create(_, StringType))), + result) + } + + testElt("hello", 1, "hello", "world") + testElt(null, 1, null, "world") + testElt(null, null, "hello", "world") + + // Invalid ranges + testElt(null, 3, "hello", "world") + testElt(null, 0, "hello", "world") + testElt(null, -1, "hello", "world") + + // type checking + assert(Elt(Seq.empty).checkInputDataTypes().isFailure) + assert(Elt(Seq(Literal(1))).checkInputDataTypes().isFailure) + assert(Elt(Seq(Literal(1), Literal("A"))).checkInputDataTypes().isSuccess) + assert(Elt(Seq(Literal(1), Literal(2))).checkInputDataTypes().isFailure) + } + test("StringComparison") { val row = create_row("abc", null) val c1 = 'a.string.at(0) @@ -231,102 +254,6 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { SubstringIndex(Literal("www||apache||org"), Literal( "||"), Literal(2)), "www||apache") } - test("LIKE literal Regular Expression") { -
checkEvaluation(Literal.create(null, StringType).like("a"), null) - checkEvaluation(Literal.create("a", StringType).like(Literal.create(null, StringType)), null) - checkEvaluation(Literal.create(null, StringType).like(Literal.create(null, StringType)), null) - checkEvaluation( - Literal.create("a", StringType).like(NonFoldableLiteral.create("a", StringType)), true) - checkEvaluation( - Literal.create("a", StringType).like(NonFoldableLiteral.create(null, StringType)), null) - checkEvaluation( - Literal.create(null, StringType).like(NonFoldableLiteral.create("a", StringType)), null) - checkEvaluation( - Literal.create(null, StringType).like(NonFoldableLiteral.create(null, StringType)), null) - - checkEvaluation("abdef" like "abdef", true) - checkEvaluation("a_%b" like "a\\__b", true) - checkEvaluation("addb" like "a_%b", true) - checkEvaluation("addb" like "a\\__b", false) - checkEvaluation("addb" like "a%\\%b", false) - checkEvaluation("a_%b" like "a%\\%b", true) - checkEvaluation("addb" like "a%", true) - checkEvaluation("addb" like "**", false) - checkEvaluation("abc" like "a%", true) - checkEvaluation("abc" like "b%", false) - checkEvaluation("abc" like "bc%", false) - checkEvaluation("a\nb" like "a_b", true) - checkEvaluation("ab" like "a%b", true) - checkEvaluation("a\nb" like "a%b", true) - } - - test("LIKE Non-literal Regular Expression") { - val regEx = 'a.string.at(0) - checkEvaluation("abcd" like regEx, null, create_row(null)) - checkEvaluation("abdef" like regEx, true, create_row("abdef")) - checkEvaluation("a_%b" like regEx, true, create_row("a\\__b")) - checkEvaluation("addb" like regEx, true, create_row("a_%b")) - checkEvaluation("addb" like regEx, false, create_row("a\\__b")) - checkEvaluation("addb" like regEx, false, create_row("a%\\%b")) - checkEvaluation("a_%b" like regEx, true, create_row("a%\\%b")) - checkEvaluation("addb" like regEx, true, create_row("a%")) - checkEvaluation("addb" like regEx, false, create_row("**")) - checkEvaluation("abc" like regEx, true, create_row("a%")) - checkEvaluation("abc" like regEx, false, create_row("b%")) - checkEvaluation("abc" like regEx, false, create_row("bc%")) - checkEvaluation("a\nb" like regEx, true, create_row("a_b")) - checkEvaluation("ab" like regEx, true, create_row("a%b")) - checkEvaluation("a\nb" like regEx, true, create_row("a%b")) - - checkEvaluation(Literal.create(null, StringType) like regEx, null, create_row("bc%")) - } - - test("RLIKE literal Regular Expression") { - checkEvaluation(Literal.create(null, StringType) rlike "abdef", null) - checkEvaluation("abdef" rlike Literal.create(null, StringType), null) - checkEvaluation(Literal.create(null, StringType) rlike Literal.create(null, StringType), null) - checkEvaluation("abdef" rlike NonFoldableLiteral.create("abdef", StringType), true) - checkEvaluation("abdef" rlike NonFoldableLiteral.create(null, StringType), null) - checkEvaluation( - Literal.create(null, StringType) rlike NonFoldableLiteral.create("abdef", StringType), null) - checkEvaluation( - Literal.create(null, StringType) rlike NonFoldableLiteral.create(null, StringType), null) - - checkEvaluation("abdef" rlike "abdef", true) - checkEvaluation("abbbbc" rlike "a.*c", true) - - checkEvaluation("fofo" rlike "^fo", true) - checkEvaluation("fo\no" rlike "^fo\no$", true) - checkEvaluation("Bn" rlike "^Ba*n", true) - checkEvaluation("afofo" rlike "fo", true) - checkEvaluation("afofo" rlike "^fo", false) - checkEvaluation("Baan" rlike "^Ba?n", false) - checkEvaluation("axe" rlike "pi|apa", false) - 
checkEvaluation("pip" rlike "^(pi)*$", false) - - checkEvaluation("abc" rlike "^ab", true) - checkEvaluation("abc" rlike "^bc", false) - checkEvaluation("abc" rlike "^ab", true) - checkEvaluation("abc" rlike "^bc", false) - - intercept[java.util.regex.PatternSyntaxException] { - evaluate("abbbbc" rlike "**") - } - } - - test("RLIKE Non-literal Regular Expression") { - val regEx = 'a.string.at(0) - checkEvaluation("abdef" rlike regEx, true, create_row("abdef")) - checkEvaluation("abbbbc" rlike regEx, true, create_row("a.*c")) - checkEvaluation("fofo" rlike regEx, true, create_row("^fo")) - checkEvaluation("fo\no" rlike regEx, true, create_row("^fo\no$")) - checkEvaluation("Bn" rlike regEx, true, create_row("^Ba*n")) - - intercept[java.util.regex.PatternSyntaxException] { - evaluate("abbbbc" rlike regEx, create_row("**")) - } - } - test("ascii for string") { val a = 'a.string.at(0) checkEvaluation(Ascii(Literal("efg")), 101, create_row("abdef")) @@ -508,16 +435,18 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val s2 = 'b.string.at(1) val s3 = 'c.string.at(2) val s4 = 'd.int.at(3) - val row1 = create_row("aaads", "aa", "zz", 1) - val row2 = create_row(null, "aa", "zz", 0) - val row3 = create_row("aaads", null, "zz", 0) - val row4 = create_row(null, null, null, 0) + val row1 = create_row("aaads", "aa", "zz", 2) + val row2 = create_row(null, "aa", "zz", 1) + val row3 = create_row("aaads", null, "zz", 1) + val row4 = create_row(null, null, null, 1) checkEvaluation(new StringLocate(Literal("aa"), Literal("aaads")), 1, row1) - checkEvaluation(StringLocate(Literal("aa"), Literal("aaads"), Literal(1)), 2, row1) - checkEvaluation(StringLocate(Literal("aa"), Literal("aaads"), Literal(2)), 0, row1) + checkEvaluation(StringLocate(Literal("aa"), Literal("aaads"), Literal(0)), 0, row1) + checkEvaluation(StringLocate(Literal("aa"), Literal("aaads"), Literal(1)), 1, row1) + checkEvaluation(StringLocate(Literal("aa"), Literal("aaads"), Literal(2)), 2, row1) + checkEvaluation(StringLocate(Literal("aa"), Literal("aaads"), Literal(3)), 0, row1) checkEvaluation(new StringLocate(Literal("de"), Literal("aaads")), 0, row1) - checkEvaluation(StringLocate(Literal("de"), Literal("aaads"), 1), 0, row1) + checkEvaluation(StringLocate(Literal("de"), Literal("aaads"), 2), 0, row1) checkEvaluation(new StringLocate(s2, s1), 1, row1) checkEvaluation(StringLocate(s2, s1, s4), 2, row1) @@ -587,68 +516,6 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(StringSpace(s1), null, row2) } - test("RegexReplace") { - val row1 = create_row("100-200", "(\\d+)", "num") - val row2 = create_row("100-200", "(\\d+)", "###") - val row3 = create_row("100-200", "(-)", "###") - val row4 = create_row(null, "(\\d+)", "###") - val row5 = create_row("100-200", null, "###") - val row6 = create_row("100-200", "(-)", null) - - val s = 's.string.at(0) - val p = 'p.string.at(1) - val r = 'r.string.at(2) - - val expr = RegExpReplace(s, p, r) - checkEvaluation(expr, "num-num", row1) - checkEvaluation(expr, "###-###", row2) - checkEvaluation(expr, "100###200", row3) - checkEvaluation(expr, null, row4) - checkEvaluation(expr, null, row5) - checkEvaluation(expr, null, row6) - } - - test("RegexExtract") { - val row1 = create_row("100-200", "(\\d+)-(\\d+)", 1) - val row2 = create_row("100-200", "(\\d+)-(\\d+)", 2) - val row3 = create_row("100-200", "(\\d+).*", 1) - val row4 = create_row("100-200", "([a-z])", 1) - val row5 = create_row(null, "([a-z])", 1) - val row6 = 
create_row("100-200", null, 1) - val row7 = create_row("100-200", "([a-z])", null) - - val s = 's.string.at(0) - val p = 'p.string.at(1) - val r = 'r.int.at(2) - - val expr = RegExpExtract(s, p, r) - checkEvaluation(expr, "100", row1) - checkEvaluation(expr, "200", row2) - checkEvaluation(expr, "100", row3) - checkEvaluation(expr, "", row4) // will not match anything, empty string get - checkEvaluation(expr, null, row5) - checkEvaluation(expr, null, row6) - checkEvaluation(expr, null, row7) - - val expr1 = new RegExpExtract(s, p) - checkEvaluation(expr1, "100", row1) - } - - test("SPLIT") { - val s1 = 'a.string.at(0) - val s2 = 'b.string.at(1) - val row1 = create_row("aa2bb3cc", "[1-9]+") - val row2 = create_row(null, "[1-9]+") - val row3 = create_row("aa2bb3cc", null) - - checkEvaluation( - StringSplit(Literal("aa2bb3cc"), Literal("[1-9]+")), Seq("aa", "bb", "cc"), row1) - checkEvaluation( - StringSplit(s1, s2), Seq("aa", "bb", "cc"), row1) - checkEvaluation(StringSplit(s1, s2), null, row2) - checkEvaluation(StringSplit(s1, s2), null, row3) - } - test("length for string / binary") { val a = 'a.string.at(0) val b = 'b.binary.at(0) @@ -688,7 +555,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { Literal(Decimal(123123324123L) * Decimal(123123.21234d)), Literal(4)), "15,159,339,180,002,773.2778") checkEvaluation(FormatNumber(Literal.create(null, IntegerType), Literal(3)), null) - checkEvaluation(FormatNumber(Literal.create(null, NullType), Literal(3)), null) + assert(FormatNumber(Literal.create(null, NullType), Literal(3)).resolved === false) } test("find in set") { @@ -700,4 +567,78 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(FindInSet(Literal("abf"), Literal("abc,b,ab,c,def")), 0) checkEvaluation(FindInSet(Literal("ab,"), Literal("abc,b,ab,c,def")), 0) } + + test("ParseUrl") { + def checkParseUrl(expected: String, urlStr: String, partToExtract: String): Unit = { + checkEvaluation( + ParseUrl(Seq(Literal(urlStr), Literal(partToExtract))), expected) + } + def checkParseUrlWithKey( + expected: String, + urlStr: String, + partToExtract: String, + key: String): Unit = { + checkEvaluation( + ParseUrl(Seq(Literal(urlStr), Literal(partToExtract), Literal(key))), expected) + } + + checkParseUrl("spark.apache.org", "http://spark.apache.org/path?query=1", "HOST") + checkParseUrl("/path", "http://spark.apache.org/path?query=1", "PATH") + checkParseUrl("query=1", "http://spark.apache.org/path?query=1", "QUERY") + checkParseUrl("Ref", "http://spark.apache.org/path?query=1#Ref", "REF") + checkParseUrl("http", "http://spark.apache.org/path?query=1", "PROTOCOL") + checkParseUrl("/path?query=1", "http://spark.apache.org/path?query=1", "FILE") + checkParseUrl("spark.apache.org:8080", "http://spark.apache.org:8080/path?query=1", "AUTHORITY") + checkParseUrl("userinfo", "http://userinfo@spark.apache.org/path?query=1", "USERINFO") + checkParseUrlWithKey("1", "http://spark.apache.org/path?query=1", "QUERY", "query") + + // Null checking + checkParseUrl(null, null, "HOST") + checkParseUrl(null, "http://spark.apache.org/path?query=1", null) + checkParseUrl(null, null, null) + checkParseUrl(null, "test", "HOST") + checkParseUrl(null, "http://spark.apache.org/path?query=1", "NO") + checkParseUrl(null, "http://spark.apache.org/path?query=1", "USERINFO") + checkParseUrlWithKey(null, "http://spark.apache.org/path?query=1", "HOST", "query") + checkParseUrlWithKey(null, "http://spark.apache.org/path?query=1", "QUERY", "quer") + 
checkParseUrlWithKey(null, "http://spark.apache.org/path?query=1", "QUERY", null) + checkParseUrlWithKey(null, "http://spark.apache.org/path?query=1", "QUERY", "") + + // exceptional cases + intercept[java.util.regex.PatternSyntaxException] { + evaluate(ParseUrl(Seq(Literal("http://spark.apache.org/path?"), + Literal("QUERY"), Literal("???")))) + } + + // arguments checking + assert(ParseUrl(Seq(Literal("1"))).checkInputDataTypes().isFailure) + assert(ParseUrl(Seq(Literal("1"), Literal("2"), Literal("3"), Literal("4"))) + .checkInputDataTypes().isFailure) + assert(ParseUrl(Seq(Literal("1"), Literal(2))).checkInputDataTypes().isFailure) + assert(ParseUrl(Seq(Literal(1), Literal("2"))).checkInputDataTypes().isFailure) + assert(ParseUrl(Seq(Literal("1"), Literal("2"), Literal(3))).checkInputDataTypes().isFailure) + } + + test("Sentences") { + val nullString = Literal.create(null, StringType) + checkEvaluation(Sentences(nullString, nullString, nullString), null) + checkEvaluation(Sentences(nullString, nullString), null) + checkEvaluation(Sentences(nullString), null) + checkEvaluation(Sentences(Literal.create(null, NullType)), null) + checkEvaluation(Sentences("", nullString, nullString), Seq.empty) + checkEvaluation(Sentences("", nullString), Seq.empty) + checkEvaluation(Sentences(""), Seq.empty) + + val answer = Seq( + Seq("Hi", "there"), + Seq("The", "price", "was"), + Seq("But", "not", "now")) + + checkEvaluation(Sentences("Hi there! The price was $1,234.56.... But, not now."), answer) + checkEvaluation(Sentences("Hi there! The price was $1,234.56.... But, not now.", "en"), answer) + checkEvaluation(Sentences("Hi there! The price was $1,234.56.... But, not now.", "en", "US"), + answer) + checkEvaluation(Sentences("Hi there! The price was $1,234.56.... 
But, not now.", "XXX", "YYY"), + answer) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala index 90e97d718a9fc..1e39b24fe8770 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala @@ -21,8 +21,12 @@ import org.apache.spark.sql.types.IntegerType class SubexpressionEliminationSuite extends SparkFunSuite { test("Semantic equals and hash") { - val id = ExprId(1) val a: AttributeReference = AttributeReference("name", IntegerType)() + val id = { + // Make sure we use a "ExprId" different from "a.exprId" + val _id = ExprId(1) + if (a.exprId == _id) ExprId(2) else _id + } val b1 = a.withName("name2").withExprId(id) val b2 = a.withExprId(id) val b3 = a.withQualifier(Some("qualifierName")) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala index b82cf8d1693e2..d6c8fcf291842 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala @@ -108,4 +108,16 @@ class TimeWindowSuite extends SparkFunSuite with ExpressionEvalHelper with Priva TimeWindow.invokePrivate(parseExpression(Rand(123))) } } + + test("SPARK-16837: TimeWindow.apply equivalent to TimeWindow constructor") { + val slideLength = "1 second" + for (windowLength <- Seq("10 second", "1 minute", "2 hours")) { + val applyValue = TimeWindow(Literal(10L), windowLength, slideLength, "0 seconds") + val constructed = new TimeWindow(Literal(10L), + Literal(windowLength), + Literal(slideLength), + Literal("0 seconds")) + assert(applyValue == constructed) + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/DeclarativeAggregateEvaluator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/DeclarativeAggregateEvaluator.scala new file mode 100644 index 0000000000000..614f24db0aafb --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/DeclarativeAggregateEvaluator.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, JoinedRow} +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection + +/** + * Evaluator for a [[DeclarativeAggregate]]. + */ +case class DeclarativeAggregateEvaluator(function: DeclarativeAggregate, input: Seq[Attribute]) { + + lazy val initializer = GenerateSafeProjection.generate(function.initialValues) + + lazy val updater = GenerateSafeProjection.generate( + function.updateExpressions, + function.aggBufferAttributes ++ input) + + lazy val merger = GenerateSafeProjection.generate( + function.mergeExpressions, + function.aggBufferAttributes ++ function.inputAggBufferAttributes) + + lazy val evaluator = GenerateSafeProjection.generate( + function.evaluateExpression :: Nil, + function.aggBufferAttributes) + + def initialize(): InternalRow = initializer.apply(InternalRow.empty).copy() + + def update(values: InternalRow*): InternalRow = { + val joiner = new JoinedRow + val buffer = values.foldLeft(initialize()) { (buffer, input) => + updater(joiner(buffer, input)) + } + buffer.copy() + } + + def merge(buffers: InternalRow*): InternalRow = { + val joiner = new JoinedRow + val buffer = buffers.foldLeft(initialize()) { (left, right) => + merger(joiner(left, right)) + } + buffer.copy() + } + + def eval(buffer: InternalRow): InternalRow = evaluator(buffer).copy() +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/LastTestSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/LastTestSuite.scala new file mode 100644 index 0000000000000..ba36bc074e154 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/LastTestSuite.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Literal} +import org.apache.spark.sql.types.IntegerType + +class LastTestSuite extends SparkFunSuite { + val input = AttributeReference("input", IntegerType, nullable = true)() + val evaluator = DeclarativeAggregateEvaluator(Last(input, Literal(false)), Seq(input)) + val evaluatorIgnoreNulls = DeclarativeAggregateEvaluator(Last(input, Literal(true)), Seq(input)) + + test("empty buffer") { + assert(evaluator.initialize() === InternalRow(null, false)) + } + + test("update") { + val result = evaluator.update( + InternalRow(1), + InternalRow(9), + InternalRow(-1)) + assert(result === InternalRow(-1, true)) + } + + test("update - ignore nulls") { + val result1 = evaluatorIgnoreNulls.update( + InternalRow(null), + InternalRow(9), + InternalRow(null)) + assert(result1 === InternalRow(9, true)) + + val result2 = evaluatorIgnoreNulls.update( + InternalRow(null), + InternalRow(null)) + assert(result2 === InternalRow(null, false)) + } + + test("merge") { + // Empty merge + val p0 = evaluator.initialize() + assert(evaluator.merge(p0) === InternalRow(null, false)) + + // Single merge + val p1 = evaluator.update(InternalRow(1), InternalRow(-99)) + assert(evaluator.merge(p1) === p1) + + // Multiple merges. + val p2 = evaluator.update(InternalRow(2), InternalRow(10)) + assert(evaluator.merge(p1, p2) === p2) + + // Empty partitions (p0 is empty) + assert(evaluator.merge(p1, p0, p2) === p2) + assert(evaluator.merge(p2, p1, p0) === p1) + } + + test("merge - ignore nulls") { + // Multi merges + val p1 = evaluatorIgnoreNulls.update(InternalRow(1), InternalRow(null)) + val p2 = evaluatorIgnoreNulls.update(InternalRow(null), InternalRow(null)) + assert(evaluatorIgnoreNulls.merge(p1, p2) === p1) + } + + test("eval") { + // Null Eval + assert(evaluator.eval(InternalRow(null, true)) === InternalRow(null)) + assert(evaluator.eval(InternalRow(null, false)) === InternalRow(null)) + + // Empty Eval + val p0 = evaluator.initialize() + assert(evaluator.eval(p0) === InternalRow(null)) + + // Update - Eval + val p1 = evaluator.update(InternalRow(1), InternalRow(-99)) + assert(evaluator.eval(p1) === InternalRow(-99)) + + // Update - Merge - Eval + val p2 = evaluator.update(InternalRow(2), InternalRow(10)) + val m1 = evaluator.merge(p1, p0, p2) + assert(evaluator.eval(m1) === InternalRow(10)) + + // Update - Merge - Eval (empty partition at the end) + val m2 = evaluator.merge(p2, p1, p0) + assert(evaluator.eval(m2) === InternalRow(-99)) + } + + test("eval - ignore nulls") { + // Update - Merge - Eval + val p1 = evaluatorIgnoreNulls.update(InternalRow(1), InternalRow(null)) + val p2 = evaluatorIgnoreNulls.update(InternalRow(null), InternalRow(null)) + val m1 = evaluatorIgnoreNulls.merge(p1, p2) + assert(evaluatorIgnoreNulls.eval(m1) === InternalRow(1)) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSuite.scala new file mode 100644 index 0000000000000..c7c386b5b838a --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSuite.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.UnsafeRow + +class BufferHolderSuite extends SparkFunSuite { + + test("SPARK-16071 Check the size limit to avoid integer overflow") { + var e = intercept[UnsupportedOperationException] { + new BufferHolder(new UnsafeRow(Int.MaxValue / 8)) + } + assert(e.getMessage.contains("too many fields")) + + val holder = new BufferHolder(new UnsafeRow(1000)) + holder.reset() + holder.grow(1000) + e = intercept[UnsupportedOperationException] { + holder.grow(Integer.MAX_VALUE) + } + assert(e.getMessage.contains("exceeds size limitation")) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala index f57b82bb96399..bc5a8f078234a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala @@ -23,22 +23,42 @@ import org.apache.spark.sql.catalyst.util._ class CodeFormatterSuite extends SparkFunSuite { - def testCase(name: String)(input: String)(expected: String): Unit = { + def testCase(name: String)( + input: String, comment: Map[String, String] = Map.empty)(expected: String): Unit = { test(name) { - if (CodeFormatter.format(input).trim !== expected.trim) { + val sourceCode = new CodeAndComment(input.trim, comment) + if (CodeFormatter.format(sourceCode).trim !== expected.trim) { fail( s""" |== FAIL: Formatted code doesn't match === - |${sideBySide(CodeFormatter.format(input).trim, expected.trim).mkString("\n")} + |${sideBySide(CodeFormatter.format(sourceCode).trim, expected.trim).mkString("\n")} """.stripMargin) } } } + test("removing overlapping comments") { + val code = new CodeAndComment( + """/*project_c4*/ + |/*project_c3*/ + |/*project_c2*/ + """.stripMargin, + Map( + "project_c4" -> "// (((input[0, bigint, false] + 1) + 2) + 3))", + "project_c3" -> "// ((input[0, bigint, false] + 1) + 2)", + "project_c2" -> "// (input[0, bigint, false] + 1)" + )) + + val reducedCode = CodeFormatter.stripOverlappingComments(code) + assert(reducedCode.body === "/*project_c4*/") + } + testCase("basic example") { - """class A { + """ + |class A { |blahblah; - |}""".stripMargin + |} + """.stripMargin }{ """ |/* 001 */ class A { @@ -48,11 +68,13 @@ class CodeFormatterSuite extends SparkFunSuite { } testCase("nested example") { - """class A { + """ + |class A { | if (c) { |duh; |} - |}""".stripMargin + |} + """.stripMargin } { """ |/* 001 */ class A { @@ -64,9 +86,11 @@ class CodeFormatterSuite extends SparkFunSuite { } testCase("single line") { - """class 
A { + """ + |class A { | if (c) {duh;} - |}""".stripMargin + |} + """.stripMargin }{ """ |/* 001 */ class A { @@ -76,9 +100,11 @@ class CodeFormatterSuite extends SparkFunSuite { } testCase("if else on the same line") { - """class A { + """ + |class A { | if (c) {duh;} else {boo;} - |}""".stripMargin + |} + """.stripMargin }{ """ |/* 001 */ class A { @@ -88,10 +114,12 @@ class CodeFormatterSuite extends SparkFunSuite { } testCase("function calls") { - """foo( + """ + |foo( |a, |b, - |c)""".stripMargin + |c) + """.stripMargin }{ """ |/* 001 */ foo( @@ -102,10 +130,12 @@ class CodeFormatterSuite extends SparkFunSuite { } testCase("single line comments") { - """// This is a comment about class A { { { ( ( + """ + |// This is a comment about class A { { { ( ( |class A { |class body; - |}""".stripMargin + |} + """.stripMargin }{ """ |/* 001 */ // This is a comment about class A { { { ( ( @@ -116,10 +146,12 @@ class CodeFormatterSuite extends SparkFunSuite { } testCase("single line comments /* */ ") { - """/** This is a comment about class A { { { ( ( */ + """ + |/** This is a comment about class A { { { ( ( */ |class A { |class body; - |}""".stripMargin + |} + """.stripMargin }{ """ |/* 001 */ /** This is a comment about class A { { { ( ( */ @@ -130,12 +162,14 @@ class CodeFormatterSuite extends SparkFunSuite { } testCase("multi-line comments") { - """ /* This is a comment about + """ + | /* This is a comment about |class A { |class body; ...*/ |class A { |class body; - |}""".stripMargin + |} + """.stripMargin }{ """ |/* 001 */ /* This is a comment about @@ -146,4 +180,57 @@ class CodeFormatterSuite extends SparkFunSuite { |/* 006 */ } """.stripMargin } + + testCase("reduce empty lines") { + CodeFormatter.stripExtraNewLines( + """ + |class A { + | + | + | /* + | * multi + | * line + | * comment + | */ + | + | class body; + | + | + | if (c) {duh;} + | else {boo;} + |} + """.stripMargin.trim) + }{ + """ + |/* 001 */ class A { + |/* 002 */ /* + |/* 003 */ * multi + |/* 004 */ * line + |/* 005 */ * comment + |/* 006 */ */ + |/* 007 */ class body; + |/* 008 */ + |/* 009 */ if (c) {duh;} + |/* 010 */ else {boo;} + |/* 011 */ } + """.stripMargin + } + + testCase("comment place holder")( + """ + |/*c1*/ + |class A + |/*c2*/ + |class B + |/*c1*//*c2*/ + """.stripMargin, Map("c1" -> "/*abc*/", "c2" -> "/*xyz*/") + ) { + """ + |/* 001 */ /*abc*/ + |/* 002 */ class A + |/* 003 */ /*xyz*/ + |/* 004 */ class B + |/* 005 */ /*abc*//*xyz*/ + """.stripMargin + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/ReusableStringReaderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/ReusableStringReaderSuite.scala new file mode 100644 index 0000000000000..e06d209c474be --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/ReusableStringReaderSuite.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.xml + +import java.io.IOException + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.xml.UDFXPathUtil.ReusableStringReader + +/** + * Unit tests for [[UDFXPathUtil.ReusableStringReader]]. + * + * Loosely based on Hive's TestReusableStringReader.java. + */ +class ReusableStringReaderSuite extends SparkFunSuite { + + private val fox = "Quick brown fox jumps over the lazy dog." + + test("empty reader") { + val reader = new ReusableStringReader + + intercept[IOException] { + reader.read() + } + + intercept[IOException] { + reader.ready() + } + + reader.close() + } + + test("mark reset") { + val reader = new ReusableStringReader + + if (reader.markSupported()) { + reader.asInstanceOf[ReusableStringReader].set(fox) + assert(reader.ready()) + + val cc = new Array[Char](6) + var read = reader.read(cc) + assert(read == 6) + assert("Quick " == new String(cc)) + + reader.mark(100) + + read = reader.read(cc) + assert(read == 6) + assert("brown " == new String(cc)) + + reader.reset() + read = reader.read(cc) + assert(read == 6) + assert("brown " == new String(cc)) + } + reader.close() + } + + test("skip") { + val reader = new ReusableStringReader + reader.asInstanceOf[ReusableStringReader].set(fox) + + // skip entire the data: + var skipped = reader.skip(fox.length() + 1) + assert(fox.length() == skipped) + assert(-1 == reader.read()) + + reader.asInstanceOf[ReusableStringReader].set(fox) // reset the data + val cc = new Array[Char](6) + var read = reader.read(cc) + assert(read == 6) + assert("Quick " == new String(cc)) + + // skip some piece of data: + skipped = reader.skip(30) + assert(skipped == 30) + read = reader.read(cc) + assert(read == 4) + assert("dog." == new String(cc, 0, read)) + + // skip when already at EOF: + skipped = reader.skip(300) + assert(skipped == 0, skipped) + assert(reader.read() == -1) + + reader.close() + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtilSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtilSuite.scala new file mode 100644 index 0000000000000..c4cde7091154b --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtilSuite.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.xml
+
+import javax.xml.xpath.XPathConstants.STRING
+
+import org.w3c.dom.Node
+import org.w3c.dom.NodeList
+
+import org.apache.spark.SparkFunSuite
+
+/**
+ * Unit tests for [[UDFXPathUtil]]. Loosely based on Hive's TestUDFXPathUtil.java.
+ */
+class UDFXPathUtilSuite extends SparkFunSuite {
+
+  private lazy val util = new UDFXPathUtil
+
+  test("illegal arguments") {
+    // null args
+    assert(util.eval(null, "a/text()", STRING) == null)
+    assert(util.eval("<a><b>b1</b><b>b2</b><b>b3</b><c>c1</c><c>c2</c></a>", null, STRING) == null)
+    assert(
+      util.eval("<a><b>b1</b><b>b2</b><b>b3</b><c>c1</c><c>c2</c></a>", "a/text()", null) == null)
+
+    // empty String args
+    assert(util.eval("", "a/text()", STRING) == null)
+    assert(util.eval("<a><b>b1</b><b>b2</b><b>b3</b><c>c1</c><c>c2</c></a>", "", STRING) == null)
+
+    // wrong expression:
+    intercept[RuntimeException] {
+      util.eval("<a><b>b1</b><b>b2</b><b>b3</b><c>c1</c><c>c2</c></a>", "a/text(", STRING)
+    }
+  }
+
+  test("generic eval") {
+    val ret =
+      util.eval("<a><b>b1</b><b>b2</b><b>b3</b><c>c1</c><c>c2</c></a>", "a/c[2]/text()", STRING)
+    assert(ret == "c2")
+  }
+
+  test("boolean eval") {
+    var ret =
+      util.evalBoolean("<a><b>true</b><b>false</b><b>b3</b><c>c1</c><c>c2</c></a>", "a/b[1]/text()")
+    assert(ret == true)
+
+    ret = util.evalBoolean("<a><b>true</b><b>false</b><b>b3</b><c>c1</c><c>c2</c></a>", "a/b[4]")
+    assert(ret == false)
+  }
+
+  test("string eval") {
+    var ret =
+      util.evalString("<a><b>true</b><b>false</b><b>b3</b><c>c1</c><c>c2</c></a>", "a/b[3]/text()")
+    assert(ret == "b3")
+
+    ret =
+      util.evalString("<a><b>true</b><b>false</b><b>b3</b><c>c1</c><c>c2</c></a>", "a/b[4]/text()")
+    assert(ret == "")
+
+    ret = util.evalString(
+      "<a><b>true</b><b k=\"foo\">FALSE</b><b>b3</b><c>c1</c><c>c2</c></a>", "a/b[2]/@k")
+    assert(ret == "foo")
+  }
+
+  test("number eval") {
+    var ret =
+      util.evalNumber("<a><b>true</b><b>false</b><b>b3</b><c>c1</c><c>-77</c></a>", "a/c[2]")
+    assert(ret == -77.0d)
+
+    ret = util.evalNumber(
+      "<a><b>true</b><b k=\"foo\">FALSE</b><b>b3</b><c>c1</c><c>c2</c></a>", "a/b[2]/@k")
+    assert(ret.isNaN)
+  }
+
+  test("node eval") {
+    val ret = util.evalNode("<a><b>true</b><b>false</b><b>b3</b><c>c1</c><c>-77</c></a>", "a/c[2]")
+    assert(ret != null && ret.isInstanceOf[Node])
+  }
+
+  test("node list eval") {
+    val ret = util.evalNodeList("<a><b>true</b><b>false</b><b>b3</b><c>c1</c><c>-77</c></a>", "a/*")
+    assert(ret != null && ret.isInstanceOf[NodeList])
+    assert(ret.asInstanceOf[NodeList].getLength == 5)
+  }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala
new file mode 100644
index 0000000000000..bfa18a0919e45
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala
@@ -0,0 +1,208 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.xml
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types.StringType
+
+/**
+ * Test suite for various xpath functions.
+ */
+class XPathExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
+
+  /** A helper function that tests null and error behaviors for xpath expressions. */
+  private def testNullAndErrorBehavior[T <: AnyRef](testExpr: (String, String, T) => Unit): Unit = {
+    // null input should lead to null output
+    testExpr("<a><b>b1</b><b>b2</b></a>", null, null.asInstanceOf[T])
+    testExpr(null, "a", null.asInstanceOf[T])
+    testExpr(null, null, null.asInstanceOf[T])
+
+    // Empty input should also lead to null output
+    testExpr("", "a", null.asInstanceOf[T])
+    testExpr("", "", null.asInstanceOf[T])
+    testExpr("", "", null.asInstanceOf[T])
+
+    // Test error message for invalid XML document
+    val e1 = intercept[RuntimeException] { testExpr("<a>/a>", "a", null.asInstanceOf[T]) }
+    assert(e1.getCause.getMessage.contains("Invalid XML document") &&
+      e1.getCause.getMessage.contains("/a>"))
+
+    // Test error message for invalid xpath
+    val e2 = intercept[RuntimeException] { testExpr("<a></a>", "!#$", null.asInstanceOf[T]) }
+    assert(e2.getCause.getMessage.contains("Invalid XPath") &&
+      e2.getCause.getMessage.contains("!#$"))
+  }
+
+  test("xpath_boolean") {
+    def testExpr[T](xml: String, path: String, expected: java.lang.Boolean): Unit = {
+      checkEvaluation(
+        XPathBoolean(Literal.create(xml, StringType), Literal.create(path, StringType)),
+        expected)
+    }
+
+    testExpr("<a><b>b</b></a>", "a/b", true)
+    testExpr("<a><b>b</b></a>", "a/c", false)
+    testExpr("<a><b>b</b></a>", "a/b = \"b\"", true)
+    testExpr("<a><b>b</b></a>", "a/b = \"c\"", false)
+    testExpr("<a><b>10</b></a>", "a/b < 10", false)
+    testExpr("<a><b>10</b></a>", "a/b = 10", true)
+
+    testNullAndErrorBehavior(testExpr)
+  }
+
+  test("xpath_short") {
+    def testExpr[T](xml: String, path: String, expected: java.lang.Short): Unit = {
+      checkEvaluation(
+        XPathShort(Literal.create(xml, StringType), Literal.create(path, StringType)),
+        expected)
+    }
+
+    testExpr("<a>this is not a number</a>", "a", 0.toShort)
+    testExpr("<a>try a boolean</a>", "a = 10", 0.toShort)
+    testExpr(
+      "<a><b class=\"odd\">10000</b><b class=\"even\">2</b><b class=\"odd\">4</b><c>8</c></a>",
+      "sum(a/b[@class=\"odd\"])",
+      10004.toShort)
+
+    testNullAndErrorBehavior(testExpr)
+  }
+
+  test("xpath_int") {
+    def testExpr[T](xml: String, path: String, expected: java.lang.Integer): Unit = {
+      checkEvaluation(
+        XPathInt(Literal.create(xml, StringType), Literal.create(path, StringType)),
+        expected)
+    }
+
+    testExpr("<a>this is not a number</a>", "a", 0)
+    testExpr("<a>try a boolean</a>", "a = 10", 0)
+    testExpr(
+      "<a><b class=\"odd\">100000</b><b class=\"even\">2</b><b class=\"odd\">4</b><c>8</c></a>",
+      "sum(a/b[@class=\"odd\"])",
+      100004)
+
+    testNullAndErrorBehavior(testExpr)
+  }
+
+  test("xpath_long") {
+    def testExpr[T](xml: String, path: String, expected: java.lang.Long): Unit = {
+      checkEvaluation(
+        XPathLong(Literal.create(xml, StringType), Literal.create(path, StringType)),
+        expected)
+    }
+
+    testExpr("<a>this is not a number</a>", "a", 0L)
+    testExpr("<a>try a boolean</a>", "a = 10", 0L)
+    testExpr(
+      "<a><b class=\"odd\">9000000000</b><b class=\"even\">2</b><b class=\"odd\">4</b><c>8</c></a>",
+      "sum(a/b[@class=\"odd\"])",
+      9000000004L)
+
+    testNullAndErrorBehavior(testExpr)
+  }
+
+  test("xpath_float") {
+    def testExpr[T](xml: String, path: String, expected: java.lang.Float): Unit = {
+      checkEvaluation(
+        XPathFloat(Literal.create(xml, StringType), Literal.create(path, StringType)),
+        expected)
+    }
+
+    testExpr("<a>this is not a number</a>", "a", Float.NaN)
+    testExpr("<a>try a boolean</a>", "a = 10", 0.0F)
+    testExpr("<a><b class=\"odd\">1</b><b class=\"even\">2</b><b class=\"odd\">4</b><c>8</c></a>",
+      "sum(a/b[@class=\"odd\"])",
+      5.0F)
+
+    testNullAndErrorBehavior(testExpr)
+  }
+
+  test("xpath_double") {
+    def testExpr[T](xml: String, path: String, expected: java.lang.Double): Unit = {
+      checkEvaluation(
+        XPathDouble(Literal.create(xml, StringType), Literal.create(path, StringType)),
+        expected)
+    }
+
+    testExpr("<a>this is not a number</a>", "a", Double.NaN)
+    testExpr("<a>try a boolean</a>", "a = 10", 0.0)
+    testExpr("<a><b class=\"odd\">1</b><b class=\"even\">2</b><b class=\"odd\">4</b><c>8</c></a>",
+      "sum(a/b[@class=\"odd\"])",
+      5.0)
+
+    testNullAndErrorBehavior(testExpr)
+  }
+
+  test("xpath_string") {
+    def testExpr[T](xml: String, path: String, expected: String): Unit = {
+      checkEvaluation(
+        XPathString(Literal.create(xml, StringType), Literal.create(path, StringType)),
+        expected)
+    }
+
+    testExpr("<a><b>bb</b><c>cc</c></a>", "a", "bbcc")
+    testExpr("<a><b>bb</b><c>cc</c></a>", "a/b", "bb")
+    testExpr("<a><b>bb</b><c>cc</c></a>", "a/c", "cc")
+    testExpr("<a><b>bb</b><c>cc</c></a>", "a/d", "")
+    testExpr("<a><b>b1</b><b>b2</b></a>", "//b", "b1")
+    testExpr("<a><b>b1</b><b>b2</b></a>", "a/b[1]", "b1")
+    testExpr("<a><b>b1</b><b id=\"b_2\">b2</b></a>", "a/b[@id='b_2']", "b2")
+
+    testNullAndErrorBehavior(testExpr)
+  }
+
+  test("xpath") {
+    def testExpr[T](xml: String, path: String, expected: Seq[String]): Unit = {
+      checkEvaluation(
+        XPathList(Literal.create(xml, StringType), Literal.create(path, StringType)),
+        expected)
+    }
+
+    testExpr("<a><b>b1</b><b>b2</b><b>b3</b><c>c1</c><c>c2</c></a>", "a/text()", Seq.empty[String])
+    testExpr("<a><b>b1</b><b>b2</b><b>b3</b><c>c1</c><c>c2</c></a>", "a/*/text()",
+      Seq("b1", "b2", "b3", "c1", "c2"))
+    testExpr("<a><b>b1</b><b>b2</b><b>b3</b><c>c1</c><c>c2</c></a>", "a/b/text()",
+      Seq("b1", "b2", "b3"))
+    testExpr("<a><b>b1</b><b>b2</b><b>b3</b><c>c1</c><c>c2</c></a>", "a/c/text()", Seq("c1", "c2"))
+    testExpr("<a><b class=\"bb\">b1</b><b>b2</b><b>b3</b><c class=\"bb\">c1</c><c>c2</c></a>",
+      "a/*[@class='bb']/text()", Seq("b1", "c1"))
+
+    testNullAndErrorBehavior(testExpr)
+  }
+
+  test("accept only literal path") {
+    def testExpr(exprCtor: (Expression, Expression) => Expression): Unit = {
+      // Validate that literal (technically this is foldable) paths are supported
+      val litPath = exprCtor(Literal("abcd"), Concat(Literal("/") :: Literal("/") :: Nil))
+      assert(litPath.checkInputDataTypes().isSuccess)
+
+      // Validate that non-foldable paths are not supported.
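+      // (NonFoldableLiteral is a test-only expression that behaves like a literal but is not
+      // foldable, so checkInputDataTypes is expected to reject it as a path.)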
+ val nonLitPath = exprCtor(Literal("abcd"), NonFoldableLiteral("/")) + assert(nonLitPath.checkInputDataTypes().isFailure) + } + + testExpr(XPathBoolean) + testExpr(XPathShort) + testExpr(XPathInt) + testExpr(XPathLong) + testExpr(XPathFloat) + testExpr(XPathDouble) + testExpr(XPathString) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala index e458eb8a1d362..aecf59aee6a9b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.catalyst.optimizer +import org.apache.spark.sql.catalyst.SimpleCatalystConf +import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry} +import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.Literal @@ -25,20 +28,48 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor class AggregateOptimizeSuite extends PlanTest { + val conf = SimpleCatalystConf(caseSensitiveAnalysis = false, groupByOrdinal = false) + val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf) + val analyzer = new Analyzer(catalog, conf) object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Aggregate", FixedPoint(100), - RemoveLiteralFromGroupExpressions) :: Nil + FoldablePropagation, + RemoveLiteralFromGroupExpressions, + RemoveRepetitionFromGroupExpressions) :: Nil } + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + test("remove literals in grouping expression") { - val input = LocalRelation('a.int, 'b.int) + val query = testRelation.groupBy('a, Literal("1"), Literal(1) + Literal(2))(sum('b)) + val optimized = Optimize.execute(analyzer.execute(query)) + val correctAnswer = testRelation.groupBy('a)(sum('b)).analyze + + comparePlans(optimized, correctAnswer) + } + + test("do not remove all grouping expressions if they are all literals") { + val query = testRelation.groupBy(Literal("1"), Literal(1) + Literal(2))(sum('b)) + val optimized = Optimize.execute(analyzer.execute(query)) + val correctAnswer = analyzer.execute(testRelation.groupBy(Literal(0))(sum('b))) - val query = - input.groupBy('a, Literal(1), Literal(1) + Literal(2))(sum('b)) - val optimized = Optimize.execute(query) + comparePlans(optimized, correctAnswer) + } + + test("Remove aliased literals") { + val query = testRelation.select('a, Literal(1).as('y)).groupBy('a, 'y)(sum('b)) + val optimized = Optimize.execute(analyzer.execute(query)) + val correctAnswer = testRelation.select('a, Literal(1).as('y)).groupBy('a)(sum('b)).analyze + + comparePlans(optimized, correctAnswer) + } - val correctAnswer = input.groupBy('a)(sum('b)) + test("remove repetition in grouping expression") { + val input = LocalRelation('a.int, 'b.int, 'c.int) + val query = input.groupBy('a + 1, 'b + 2, Literal(1) + 'A, Literal(2) + 'B)(sum('c)) + val optimized = Optimize.execute(analyzer.execute(query)) + val correctAnswer = input.groupBy('a + 1, 'b + 2)(sum('c)).analyze comparePlans(optimized, correctAnswer) } diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala index 7cd038570bbdf..a313681eeb8f0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala @@ -36,7 +36,7 @@ class BinaryComparisonSimplificationSuite extends PlanTest with PredicateHelper NullPropagation, ConstantFolding, BooleanSimplification, - BinaryComparisonSimplification, + SimplifyBinaryComparison, PruneFilters) :: Nil } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index b5664a5e699e8..589607e3ad5cb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -346,5 +346,20 @@ class ColumnPruningSuite extends PlanTest { comparePlans(Optimize.execute(plan1.analyze), correctAnswer1) } + test("push project down into sample") { + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + val x = testRelation.subquery('x) + + val query1 = Sample(0.0, 0.6, false, 11L, x)().select('a) + val optimized1 = Optimize.execute(query1.analyze) + val expected1 = Sample(0.0, 0.6, false, 11L, x.select('a))() + comparePlans(optimized1, expected1.analyze) + + val query2 = Sample(0.0, 0.6, false, 11L, x)().select('a as 'aa) + val optimized2 = Optimize.execute(query2.analyze) + val expected2 = Sample(0.0, 0.6, false, 11L, x.select('a))().select('a as 'aa) + comparePlans(optimized2, expected2.analyze) + } + // todo: add more tests for column pruning } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala index 8c92ad82ac5be..7402918c1bbba 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala @@ -34,7 +34,8 @@ class EliminateSortsSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = - Batch("Eliminate Sorts", Once, + Batch("Eliminate Sorts", FixedPoint(10), + FoldablePropagation, EliminateSorts) :: Nil } @@ -69,4 +70,16 @@ class EliminateSortsSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("Remove no-op alias") { + val x = testRelation + + val query = x.select('a.as('x), Year(CurrentDate()).as('y), 'b) + .orderBy('x.asc, 'y.asc, 'b.desc) + val optimized = Optimize.execute(analyzer.execute(query)) + val correctAnswer = analyzer.execute( + x.select('a.as('x), Year(CurrentDate()).as('y), 'b).orderBy('x.asc, 'b.desc)) + + comparePlans(optimized, correctAnswer) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index fcc14a803beaa..019f132d94cb2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -34,7 +34,6 @@ class FilterPushdownSuite extends PlanTest { Batch("Subqueries", Once, EliminateSubqueryAliases) :: Batch("Filter Pushdown", FixedPoint(10), - SamplePushDown, CombineFilters, PushDownPredicate, BooleanSimplification, @@ -94,6 +93,30 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("SPARK-16164: Filter pushdown should keep the ordering in the logical plan") { + val originalQuery = + testRelation + .where('a === 1) + .select('a, 'b) + .where('b === 1) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + testRelation + .where('a === 1 && 'b === 1) + .select('a, 'b) + .analyze + + // We can not use comparePlans here because it normalized the plan. + assert(optimized == correctAnswer) + } + + test("SPARK-16994: filter should not be pushed through limit") { + val originalQuery = testRelation.limit(10).where('a === 1).analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, originalQuery) + } + test("can't push without rewrite") { val originalQuery = testRelation @@ -513,14 +536,14 @@ class FilterPushdownSuite extends PlanTest { val originalQuery = { testRelationWithArrayType .generate(Explode('c_arr), true, false, Some("arr")) - .where(('b >= 5) && ('a + Rand(10).as("rnd") > 6)) + .where(('b >= 5) && ('a + Rand(10).as("rnd") > 6) && ('c > 6)) } val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = { testRelationWithArrayType .where('b >= 5) .generate(Explode('c_arr), true, false, Some("arr")) - .where('a + Rand(10).as("rnd") > 6) + .where('a + Rand(10).as("rnd") > 6 && 'c > 6) .analyze } @@ -567,22 +590,6 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, originalQuery) } - test("push project and filter down into sample") { - val x = testRelation.subquery('x) - val originalQuery = - Sample(0.0, 0.6, false, 11L, x)().select('a) - - val originalQueryAnalyzed = - EliminateSubqueryAliases(analysis.SimpleAnalyzer.execute(originalQuery)) - - val optimized = Optimize.execute(originalQueryAnalyzed) - - val correctAnswer = - Sample(0.0, 0.6, false, 11L, x.select('a))() - - comparePlans(optimized, correctAnswer.analyze) - } - test("aggregate: push down filter when filter on group by expression") { val originalQuery = testRelation .groupBy('a)('a, count('b) as 'c) @@ -680,6 +687,23 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("SPARK-17712: aggregate: don't push down filters that are data-independent") { + val originalQuery = LocalRelation.apply(testRelation.output, Seq.empty) + .select('a, 'b) + .groupBy('a)(count('a)) + .where(false) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = testRelation + .select('a, 'b) + .groupBy('a)(count('a)) + .where(false) + .analyze + + comparePlans(optimized, correctAnswer) + } + test("broadcast hint") { val originalQuery = BroadcastHint(testRelation) .where('a === 2L && 'b + Rand(10).as("rnd") === 3) @@ -697,14 +721,14 @@ class FilterPushdownSuite extends PlanTest { val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int) val originalQuery = Union(Seq(testRelation, testRelation2)) - .where('a === 2L && 'b + Rand(10).as("rnd") === 3) + .where('a === 2L && 'b + Rand(10).as("rnd") === 3 && 'c > 5L) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = Union(Seq( testRelation.where('a === 2L), 
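// The pushed-down predicate is rewritten per Union child, so 'a === 2L becomes 'd === 2L
// against the second relation's output attributes.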
testRelation2.where('d === 2L))) - .where('b + Rand(10).as("rnd") === 3) + .where('b + Rand(10).as("rnd") === 3 && 'c > 5L) .analyze comparePlans(optimized, correctAnswer) @@ -980,4 +1004,18 @@ class FilterPushdownSuite extends PlanTest { comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer) } + + test("join condition pushdown: deterministic and non-deterministic") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + + // Verify that all conditions preceding the first non-deterministic condition are pushed down + // by the optimizer and others are not. + val originalQuery = x.join(y, condition = Some("x.a".attr === 5 && "y.a".attr === 5 && + "x.a".attr === Rand(10) && "y.b".attr === 5)) + val correctAnswer = x.where("x.a".attr === 5).join(y.where("y.a".attr === 5), + condition = Some("x.a".attr === Rand(10) && "y.b".attr === 5)) + + comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala new file mode 100644 index 0000000000000..355b3fc4aa637 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ + +class FoldablePropagationSuite extends PlanTest { + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Foldable Propagation", FixedPoint(20), + FoldablePropagation) :: Nil + } + + val testRelation = LocalRelation('a.int, 'b.int) + + test("Propagate from subquery") { + val query = OneRowRelation + .select(Literal(1).as('a), Literal(2).as('b)) + .subquery('T) + .select('a, 'b) + val optimized = Optimize.execute(query.analyze) + val correctAnswer = OneRowRelation + .select(Literal(1).as('a), Literal(2).as('b)) + .subquery('T) + .select(Literal(1).as('a), Literal(2).as('b)).analyze + + comparePlans(optimized, correctAnswer) + } + + test("Propagate to select clause") { + val query = testRelation + .select('a.as('x), "str".as('y), 'b.as('z)) + .select('x, 'y, 'z) + val optimized = Optimize.execute(query.analyze) + val correctAnswer = testRelation + .select('a.as('x), "str".as('y), 'b.as('z)) + .select('x, "str".as('y), 'z).analyze + + comparePlans(optimized, correctAnswer) + } + + test("Propagate to where clause") { + val query = testRelation + .select("str".as('y)) + .where('y === "str" && "str" === 'y) + val optimized = Optimize.execute(query.analyze) + val correctAnswer = testRelation + .select("str".as('y)) + .where("str".as('y) === "str" && "str" === "str".as('y)).analyze + + comparePlans(optimized, correctAnswer) + } + + test("Propagate to orderBy clause") { + val query = testRelation + .select('a.as('x), Year(CurrentDate()).as('y), 'b) + .orderBy('x.asc, 'y.asc, 'b.desc) + val optimized = Optimize.execute(query.analyze) + val correctAnswer = testRelation + .select('a.as('x), Year(CurrentDate()).as('y), 'b) + .orderBy('x.asc, SortOrder(Year(CurrentDate()), Ascending), 'b.desc).analyze + + comparePlans(optimized, correctAnswer) + } + + test("Propagate to groupBy clause") { + val query = testRelation + .select('a.as('x), Year(CurrentDate()).as('y), 'b) + .groupBy('x, 'y, 'b)(sum('x), avg('y).as('AVG), count('b)) + val optimized = Optimize.execute(query.analyze) + val correctAnswer = testRelation + .select('a.as('x), Year(CurrentDate()).as('y), 'b) + .groupBy('x, Year(CurrentDate()).as('y), 'b)(sum('x), avg(Year(CurrentDate())).as('AVG), + count('b)).analyze + + comparePlans(optimized, correctAnswer) + } + + test("Propagate in a complex query") { + val query = testRelation + .select('a.as('x), Year(CurrentDate()).as('y), 'b) + .where('x > 1 && 'y === 2016 && 'b > 1) + .groupBy('x, 'y, 'b)(sum('x), avg('y).as('AVG), count('b)) + .orderBy('x.asc, 'AVG.asc) + val optimized = Optimize.execute(query.analyze) + val correctAnswer = testRelation + .select('a.as('x), Year(CurrentDate()).as('y), 'b) + .where('x > 1 && Year(CurrentDate()).as('y) === 2016 && 'b > 1) + .groupBy('x, Year(CurrentDate()).as("y"), 'b)(sum('x), avg(Year(CurrentDate())).as('AVG), + count('b)) + .orderBy('x.asc, 'AVG.asc).analyze + + comparePlans(optimized, correctAnswer) + } + + test("Propagate in subqueries of Union queries") { + val query = Union( + Seq( + testRelation.select(Literal(1).as('x), 'a).select('x + 'a), + testRelation.select(Literal(2).as('x), 'a).select('x + 'a))) + .select('x) + val optimized = Optimize.execute(query.analyze) + 
val correctAnswer = Union( + Seq( + testRelation.select(Literal(1).as('x), 'a).select((Literal(1).as('x) + 'a).as("(x + a)")), + testRelation.select(Literal(2).as('x), 'a).select((Literal(2).as('x) + 'a).as("(x + a)")))) + .select('x).analyze + + comparePlans(optimized, correctAnswer) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index e7fdd5a6202b6..9f57f66a2ea20 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -27,9 +27,12 @@ import org.apache.spark.sql.catalyst.rules._ class InferFiltersFromConstraintsSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { - val batches = Batch("InferFilters", FixedPoint(5), InferFiltersFromConstraints) :: - Batch("PredicatePushdown", FixedPoint(5), PushPredicateThroughJoin) :: - Batch("CombineFilters", FixedPoint(5), CombineFilters) :: Nil + val batches = + Batch("InferAndPushDownFilters", FixedPoint(100), + PushPredicateThroughJoin, + PushDownPredicate, + InferFiltersFromConstraints, + CombineFilters) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int) @@ -120,4 +123,82 @@ class InferFiltersFromConstraintsSuite extends PlanTest { val optimized = Optimize.execute(originalQuery) comparePlans(optimized, correctAnswer) } + + test("inner join with alias: alias contains multiple attributes") { + val t1 = testRelation.subquery('t1) + val t2 = testRelation.subquery('t2) + + val originalQuery = t1.select('a, Coalesce(Seq('a, 'b)).as('int_col)).as("t") + .join(t2, Inner, Some("t.a".attr === "t2.a".attr && "t.int_col".attr === "t2.a".attr)) + .analyze + val correctAnswer = t1 + .where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'b))) && 'a === Coalesce(Seq('a, 'b))) + .select('a, Coalesce(Seq('a, 'b)).as('int_col)).as("t") + .join(t2.where(IsNotNull('a)), Inner, + Some("t.a".attr === "t2.a".attr && "t.int_col".attr === "t2.a".attr)) + .analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, correctAnswer) + } + + test("inner join with alias: alias contains single attributes") { + val t1 = testRelation.subquery('t1) + val t2 = testRelation.subquery('t2) + + val originalQuery = t1.select('a, 'b.as('d)).as("t") + .join(t2, Inner, Some("t.a".attr === "t2.a".attr && "t.d".attr === "t2.a".attr)) + .analyze + val correctAnswer = t1 + .where(IsNotNull('a) && IsNotNull('b) && 'a <=> 'a && 'b <=> 'b &&'a === 'b) + .select('a, 'b.as('d)).as("t") + .join(t2.where(IsNotNull('a) && 'a <=> 'a), Inner, + Some("t.a".attr === "t2.a".attr && "t.d".attr === "t2.a".attr)) + .analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, correctAnswer) + } + + test("inner join with alias: don't generate constraints for recursive functions") { + val t1 = testRelation.subquery('t1) + val t2 = testRelation.subquery('t2) + + val originalQuery = t1.select('a, 'b.as('d), Coalesce(Seq('a, 'b)).as('int_col)).as("t") + .join(t2, Inner, + Some("t.a".attr === "t2.a".attr + && "t.d".attr === "t2.a".attr + && "t.int_col".attr === "t2.a".attr)) + .analyze + val correctAnswer = t1 + .where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a))) + && 'a === Coalesce(Seq('a, 'a)) && 'a <=> Coalesce(Seq('a, 'a)) && 'a <=> 'a + && Coalesce(Seq('a, 'a)) <=> 
'b && Coalesce(Seq('a, 'a)) <=> Coalesce(Seq('a, 'a)) + && 'a === 'b && IsNotNull(Coalesce(Seq('a, 'b))) && 'a === Coalesce(Seq('a, 'b)) + && Coalesce(Seq('a, 'b)) <=> Coalesce(Seq('b, 'b)) && Coalesce(Seq('a, 'b)) === 'b + && IsNotNull('b) && IsNotNull(Coalesce(Seq('b, 'b))) + && 'b === Coalesce(Seq('b, 'b)) && 'b <=> Coalesce(Seq('b, 'b)) + && Coalesce(Seq('b, 'b)) <=> Coalesce(Seq('b, 'b)) && 'b <=> 'b) + .select('a, 'b.as('d), Coalesce(Seq('a, 'b)).as('int_col)).as("t") + .join(t2 + .where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a))) + && 'a === Coalesce(Seq('a, 'a)) && 'a <=> Coalesce(Seq('a, 'a)) && 'a <=> 'a + && Coalesce(Seq('a, 'a)) <=> Coalesce(Seq('a, 'a))), Inner, + Some("t.a".attr === "t2.a".attr + && "t.d".attr === "t2.a".attr + && "t.int_col".attr === "t2.a".attr + && Coalesce(Seq("t.d".attr, "t.d".attr)) <=> "t.int_col".attr)) + .analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, correctAnswer) + } + + test("generate correct filters for alias that don't produce recursive constraints") { + val t1 = testRelation.subquery('t1) + + val originalQuery = t1.select('a.as('x), 'b.as('y)).where('x === 1 && 'x === 'y).analyze + val correctAnswer = + t1.where('a === 1 && 'b === 1 && 'a === 'b && IsNotNull('a) && IsNotNull('b)) + .select('a.as('x), 'b.as('y)).analyze + val optimized = Optimize.execute(originalQuery) + comparePlans(optimized, correctAnswer) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala index 5e6e54dc741f3..c168a55e40c54 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.{Coalesce, IsNotNull} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ @@ -30,7 +31,7 @@ class OuterJoinEliminationSuite extends PlanTest { Batch("Subqueries", Once, EliminateSubqueryAliases) :: Batch("Outer Join Elimination", Once, - OuterJoinElimination, + EliminateOuterJoin, PushPredicateThroughJoin) :: Nil } @@ -192,4 +193,42 @@ class OuterJoinEliminationSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("joins: no outer join elimination if the filter is not NULL eliminated") { + val x = testRelation.subquery('x) + val y = testRelation1.subquery('y) + + val originalQuery = + x.join(y, FullOuter, Option("x.a".attr === "y.d".attr)) + .where(Coalesce("y.e".attr :: "x.a".attr :: Nil)) + + val optimized = Optimize.execute(originalQuery.analyze) + + val left = testRelation + val right = testRelation1 + val correctAnswer = + left.join(right, FullOuter, Option("a".attr === "d".attr)) + .where(Coalesce("e".attr :: "a".attr :: Nil)).analyze + + comparePlans(optimized, correctAnswer) + } + + test("joins: no outer join elimination if the filter's constraints are not NULL eliminated") { + val x = testRelation.subquery('x) + val y = testRelation1.subquery('y) + + val originalQuery = + x.join(y, FullOuter, Option("x.a".attr === "y.d".attr)) 
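+ // IsNotNull(Coalesce(e, a)) can be satisfied by either side alone, so the filter is not
+ // null-filtering on either join side and the FullOuter join must be preserved.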
+ .where(IsNotNull(Coalesce("y.e".attr :: "x.a".attr :: Nil))) + + val optimized = Optimize.execute(originalQuery.analyze) + + val left = testRelation + val right = testRelation1 + val correctAnswer = + left.join(right, FullOuter, Option("a".attr === "d".attr)) + .where(IsNotNull(Coalesce("e".attr :: "a".attr :: Nil))).analyze + + comparePlans(optimized, correctAnswer) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala new file mode 100644 index 0000000000000..0b973c3b659cf --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.SimpleCatalystConf +import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry} +import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.{If, Literal} +import org.apache.spark.sql.catalyst.expressions.aggregate.{CollectSet, Count} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan} +import org.apache.spark.sql.types.{IntegerType, StringType} + +class RewriteDistinctAggregatesSuite extends PlanTest { + val conf = SimpleCatalystConf(caseSensitiveAnalysis = false, groupByOrdinal = false) + val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf) + val analyzer = new Analyzer(catalog, conf) + + val nullInt = Literal(null, IntegerType) + val nullString = Literal(null, StringType) + val testRelation = LocalRelation('a.string, 'b.string, 'c.string, 'd.string, 'e.int) + + private def checkRewrite(rewrite: LogicalPlan): Unit = rewrite match { + case Aggregate(_, _, Aggregate(_, _, _: Expand)) => + case _ => fail(s"Plan is not rewritten:\n$rewrite") + } + + test("single distinct group") { + val input = testRelation + .groupBy('a)(countDistinct('e)) + .analyze + val rewrite = RewriteDistinctAggregates(input) + comparePlans(input, rewrite) + } + + test("single distinct group with partial aggregates") { + val input = testRelation + .groupBy('a, 'd)( + countDistinct('e, 'c).as('agg1), + max('b).as('agg2)) + .analyze + val rewrite = RewriteDistinctAggregates(input) + comparePlans(input, rewrite) + } + + test("single distinct group with non-partial aggregates") { + val input = testRelation + .groupBy('a, 'd)( + 
countDistinct('e, 'c).as('agg1), + CollectSet('b).toAggregateExpression().as('agg2)) + .analyze + checkRewrite(RewriteDistinctAggregates(input)) + } + + test("multiple distinct groups") { + val input = testRelation + .groupBy('a)(countDistinct('b, 'c), countDistinct('d)) + .analyze + checkRewrite(RewriteDistinctAggregates(input)) + } + + test("multiple distinct groups with partial aggregates") { + val input = testRelation + .groupBy('a)(countDistinct('b, 'c), countDistinct('d), sum('e)) + .analyze + checkRewrite(RewriteDistinctAggregates(input)) + } + + test("multiple distinct groups with non-partial aggregates") { + val input = testRelation + .groupBy('a)( + countDistinct('b, 'c), + countDistinct('d), + CollectSet('b).toAggregateExpression()) + .analyze + checkRewrite(RewriteDistinctAggregates(input)) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala index 83ca9d5ec9f80..dab45a6b166be 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala @@ -31,7 +31,7 @@ class SetOperationSuite extends PlanTest { EliminateSubqueryAliases) :: Batch("Union Pushdown", Once, CombineUnions, - SetOperationPushDown, + PushThroughSetOperations, PruneFilters) :: Nil } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala index 1fae64e3bc6b1..63d87bfb6d24d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala @@ -57,7 +57,9 @@ class TypedFilterOptimizationSuite extends PlanTest { comparePlans(optimized, expected) } - test("embed deserializer in filter condition if there is only one filter") { + // TODO: Remove this after we completely fix SPARK-15632 by adding optimization rules + // for typed filters. 
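+ // Until then, ScalaTest's ignore(...) keeps the test below compiled but skipped.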
+ ignore("embed deserializer in typed filter condition if there is only one filter") { val input = LocalRelation('_1.int, '_2.int) val f = (i: (Int, Int)) => i._1 > 0 @@ -67,7 +69,7 @@ class TypedFilterOptimizationSuite extends PlanTest { val deserializer = UnresolvedDeserializer(encoderFor[(Int, Int)].deserializer) val condition = callFunction(f, BooleanType, deserializer) - val expected = input.where(condition).analyze + val expected = input.where(condition).select('_1.as("_1"), '_2.as("_2")).analyze comparePlans(optimized, expected) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala index 40782978a7299..020fb16f6f3d5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.parser import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types._ -class CatalystQlDataTypeParserSuite extends SparkFunSuite { +class DataTypeParserSuite extends SparkFunSuite { def parse(sql: String): DataType = CatalystSqlParser.parseDataType(sql) @@ -133,4 +133,8 @@ class CatalystQlDataTypeParserSuite extends SparkFunSuite { checkDataType( "struct<`x``y` int>", (new StructType).add("x`y", IntegerType)) + + // Use SQL keywords. + checkDataType("struct", + (new StructType).add("end", LongType).add("select", IntegerType).add("from", StringType)) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala index 6da3eaea3d850..f67697eb86c26 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala @@ -39,8 +39,6 @@ class ErrorParserSuite extends SparkFunSuite { } test("no viable input") { - intercept("select from tbl", 1, 7, "no viable alternative at input", "-------^^^") - intercept("select\nfrom tbl", 2, 0, "no viable alternative at input", "^^^") intercept("select ((r + 1) ", 1, 16, "no viable alternative at input", "----------------^^^") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index e73592c7afa28..15ebc196e0f2b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -292,6 +292,10 @@ class ExpressionParserSuite extends PlanTest { test("case when") { assertEqual("case a when 1 then b when 2 then c else d end", CaseKeyWhen('a, Seq(1, 'b, 2, 'c, 'd))) + assertEqual("case (a or b) when true then c when false then d else e end", + CaseKeyWhen('a || 'b, Seq(true, 'c, false, 'd, 'e))) + assertEqual("case 'a'='a' when true then 1 end", + CaseKeyWhen("a" === "a", Seq(true, 1))) assertEqual("case when a = 1 then b when a = 2 then c else d end", CaseWhen(Seq(('a === 1, 'b.expr), ('a === 2, 'c.expr)), 'd)) } @@ -375,20 +379,30 @@ class ExpressionParserSuite extends PlanTest { // Tiny Int Literal assertEqual("10Y", Literal(10.toByte)) - intercept("-1000Y") + intercept("-1000Y", s"does not fit in 
range [${Byte.MinValue}, ${Byte.MaxValue}]") // Small Int Literal assertEqual("10S", Literal(10.toShort)) - intercept("40000S") + intercept("40000S", s"does not fit in range [${Short.MinValue}, ${Short.MaxValue}]") // Long Int Literal assertEqual("10L", Literal(10L)) - intercept("78732472347982492793712334L") + intercept("78732472347982492793712334L", + s"does not fit in range [${Long.MinValue}, ${Long.MaxValue}]") // Double Literal assertEqual("10.0D", Literal(10.0D)) + intercept("-1.8E308D", s"does not fit in range") + intercept("1.8E308D", s"does not fit in range") // TODO we need to figure out if we should throw an exception here! assertEqual("1E309", Literal(Double.PositiveInfinity)) + + // BigDecimal Literal + assertEqual("90912830918230182310293801923652346786BD", + Literal(BigDecimal("90912830918230182310293801923652346786").underlying())) + assertEqual("123.0E-28BD", Literal(BigDecimal("123.0E-28").underlying())) + assertEqual("123.08BD", Literal(BigDecimal("123.08").underlying())) + intercept("1.20E-38BD", "DecimalType can only support precision up to 38") } test("strings") { @@ -502,4 +516,31 @@ class ExpressionParserSuite extends PlanTest { assertEqual("1 - f('o', o(bar))", Literal(1) - 'f.function("o", 'o.function('bar))) intercept("1 - f('o', o(bar)) hello * world", "mismatched input '*'") } + + test("current date/timestamp braceless expressions") { + assertEqual("current_date", CurrentDate()) + assertEqual("current_timestamp", CurrentTimestamp()) + } + + test("SPARK-17364, fully qualified column name which starts with number") { + assertEqual("123_", UnresolvedAttribute("123_")) + assertEqual("1a.123_", UnresolvedAttribute("1a.123_")) + // ".123" should not be treated as token of type DECIMAL_VALUE + assertEqual("a.123A", UnresolvedAttribute("a.123A")) + // ".123E3" should not be treated as token of type SCIENTIFIC_DECIMAL_VALUE + assertEqual("a.123E3_column", UnresolvedAttribute("a.123E3_column")) + // ".123D" should not be treated as token of type DOUBLE_LITERAL + assertEqual("a.123D_column", UnresolvedAttribute("a.123D_column")) + // ".123BD" should not be treated as token of type BIGDECIMAL_LITERAL + assertEqual("a.123BD_column", UnresolvedAttribute("a.123BD_column")) + } + + test("SPARK-17832 function identifier contains backtick") { + val complexName = FunctionIdentifier("`ba`r", Some("`fo`o")) + assertEqual(complexName.quotedString, UnresolvedAttribute("`fo`o.`ba`r")) + intercept(complexName.unquotedString, "mismatched input") + // Function identifier contains countious backticks should be treated correctly. + val complexName2 = FunctionIdentifier("ba``r", Some("fo``o")) + assertEqual(complexName2.quotedString, UnresolvedAttribute("fo``o.ba``r")) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 5e896a33bd3bf..ac9c494a0b433 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -14,27 +14,31 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ + package org.apache.spark.sql.catalyst.parser -import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.FunctionIdentifier -import org.apache.spark.sql.catalyst.analysis.UnresolvedGenerator +import org.apache.spark.sql.catalyst.analysis.{UnresolvedGenerator, UnresolvedInlineTable, UnresolvedTableValuedFunction} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types.IntegerType - +/** + * Parser test cases for rules defined in [[CatalystSqlParser]] / [[AstBuilder]]. + * + * There is also SparkSqlParserSuite in sql/core module for parser rules defined in sql/core module. + */ class PlanParserSuite extends PlanTest { import CatalystSqlParser._ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ - def assertEqual(sqlCommand: String, plan: LogicalPlan): Unit = { + private def assertEqual(sqlCommand: String, plan: LogicalPlan): Unit = { comparePlans(parsePlan(sqlCommand), plan) } - def intercept(sqlCommand: String, messages: String*): Unit = { + private def intercept(sqlCommand: String, messages: String*): Unit = { val e = intercept[ParseException](parsePlan(sqlCommand)) messages.foreach { message => assert(e.message.contains(message)) @@ -48,19 +52,9 @@ class PlanParserSuite extends PlanTest { assertEqual("SELECT * FROM a", plan) } - test("show functions") { - assertEqual("show functions", ShowFunctions(None, None)) - assertEqual("show functions foo", ShowFunctions(None, Some("foo"))) - assertEqual("show functions foo.bar", ShowFunctions(Some("foo"), Some("bar"))) - assertEqual("show functions 'foo\\\\.*'", ShowFunctions(None, Some("foo\\.*"))) - intercept("show functions foo.bar.baz", "SHOW FUNCTIONS unsupported name") - } - - test("describe function") { - assertEqual("describe function bar", DescribeFunction("bar", isExtended = false)) - assertEqual("describe function extended bar", DescribeFunction("bar", isExtended = true)) - assertEqual("describe function foo.bar", DescribeFunction("foo.bar", isExtended = false)) - assertEqual("describe function extended f.bar", DescribeFunction("f.bar", isExtended = true)) + test("explain") { + intercept("EXPLAIN logical SELECT 1", "Unsupported SQL statement") + intercept("EXPLAIN formatted SELECT 1", "Unsupported SQL statement") } test("set operations") { @@ -99,7 +93,7 @@ class PlanParserSuite extends PlanTest { "cte2" -> table("cte1").select(star()))) intercept( "with cte1 (select 1), cte1 as (select 1 from cte1) select * from cte1", - "Name 'cte1' is used for multiple common table expressions") + "Found duplicate keys 'cte1'") } test("simple select query") { @@ -112,6 +106,7 @@ class PlanParserSuite extends PlanTest { table("db", "c").select('a, 'b).where('x < 1)) assertEqual("select distinct a, b from db.c", Distinct(table("db", "c").select('a, 'b))) assertEqual("select all a, b from db.c", table("db", "c").select('a, 'b)) + assertEqual("select from tbl", OneRowRelation.select('from.as("tbl"))) } test("reverse select query") { @@ -187,14 +182,12 @@ class PlanParserSuite extends PlanTest { // Single inserts assertEqual(s"insert overwrite table s $sql", insert(Map.empty, overwrite = true)) - assertEqual(s"insert overwrite table s if not exists $sql", - insert(Map.empty, overwrite = true, ifNotExists = true)) + assertEqual(s"insert overwrite table s partition (e = 1) if not exists $sql", + insert(Map("e" -> Option("1")), overwrite = true, ifNotExists = true)) 
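For context, a minimal sketch (not part of the patch; table names are illustrative) of the rule that this assertion and the new `insert with if not exists` test below encode: IF NOT EXISTS is only accepted together with a fully static partition spec, and a dynamic partition column makes the statement fail already at parse time.

```scala
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException}

object InsertIfNotExistsSketch extends App {
  // Fully static partition spec: accepted by the parser.
  CatalystSqlParser.parsePlan(
    "insert overwrite table s partition (e = 1) if not exists select * from t")

  // Dynamic partition column `x` combined with IF NOT EXISTS: rejected while parsing.
  try {
    CatalystSqlParser.parsePlan(
      "insert overwrite table s partition (e = 1, x) if not exists select * from t")
  } catch {
    case e: ParseException => println(e.getMessage)
  }
}
```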
assertEqual(s"insert into s $sql", insert(Map.empty)) assertEqual(s"insert into table s partition (c = 'd', e = 1) $sql", insert(Map("c" -> Option("d"), "e" -> Option("1")))) - assertEqual(s"insert overwrite table s partition (c = 'd', x) if not exists $sql", - insert(Map("c" -> Option("d"), "x" -> None), overwrite = true, ifNotExists = true)) // Multi insert val plan2 = table("t").where('x > 5).select(star()) @@ -205,6 +198,13 @@ class PlanParserSuite extends PlanTest { table("u"), Map.empty, plan2, overwrite = false, ifNotExists = false))) } + test ("insert with if not exists") { + val sql = "select * from t" + intercept(s"insert overwrite table s partition (e = 1, x) if not exists $sql", + "Dynamic partitions do not support IF NOT EXISTS. Specified partitions with value: [x]") + intercept[ParseException](parsePlan(s"insert overwrite table s if not exists $sql")) + } + test("aggregation") { val sql = "select a, b, sum(c) as c from d group by a, b" @@ -263,11 +263,14 @@ class PlanParserSuite extends PlanTest { } test("lateral view") { + val explode = UnresolvedGenerator(FunctionIdentifier("explode"), Seq('x)) + val jsonTuple = UnresolvedGenerator(FunctionIdentifier("json_tuple"), Seq('x, 'y)) + // Single lateral view assertEqual( "select * from t lateral view explode(x) expl as x", table("t") - .generate(Explode('x), join = true, outer = false, Some("expl"), Seq("x")) + .generate(explode, join = true, outer = false, Some("expl"), Seq("x")) .select(star())) // Multiple lateral views @@ -277,12 +280,12 @@ class PlanParserSuite extends PlanTest { |lateral view explode(x) expl |lateral view outer json_tuple(x, y) jtup q, z""".stripMargin, table("t") - .generate(Explode('x), join = true, outer = false, Some("expl"), Seq.empty) - .generate(JsonTuple(Seq('x, 'y)), join = true, outer = true, Some("jtup"), Seq("q", "z")) + .generate(explode, join = true, outer = false, Some("expl"), Seq.empty) + .generate(jsonTuple, join = true, outer = true, Some("jtup"), Seq("q", "z")) .select(star())) // Multi-Insert lateral views. - val from = table("t1").generate(Explode('x), join = true, outer = false, Some("expl"), Seq("x")) + val from = table("t1").generate(explode, join = true, outer = false, Some("expl"), Seq("x")) assertEqual( """from t1 |lateral view explode(x) expl as x @@ -294,7 +297,7 @@ class PlanParserSuite extends PlanTest { |where s < 10 """.stripMargin, Union(from - .generate(JsonTuple(Seq('x, 'y)), join = true, outer = false, Some("jtup"), Seq("q", "z")) + .generate(jsonTuple, join = true, outer = false, Some("jtup"), Seq("q", "z")) .select(star()) .insertInto("t2"), from.where('s < 10).select(star()).insertInto("t3"))) @@ -354,10 +357,54 @@ class PlanParserSuite extends PlanTest { test("left anti join", LeftAnti, testExistence) test("anti join", LeftAnti, testExistence) + // Test natural cross join + intercept("select * from a natural cross join b") + + // Test natural join with a condition + intercept("select * from a natural join b on a.id = b.id") + // Test multiple consecutive joins assertEqual( "select * from a join b join c right join d", table("a").join(table("b")).join(table("c")).join(table("d"), RightOuter).select(star())) + + // SPARK-17296 + assertEqual( + "select * from t1 cross join t2 join t3 on t3.id = t1.id join t4 on t4.id = t1.id", + table("t1") + .join(table("t2"), Inner) + .join(table("t3"), Inner, Option(Symbol("t3.id") === Symbol("t1.id"))) + .join(table("t4"), Inner, Option(Symbol("t4.id") === Symbol("t1.id"))) + .select(star())) + + // Test multiple on clauses. 
+ intercept("select * from t1 inner join t2 inner join t3 on col3 = col2 on col3 = col1") + + // Parenthesis + assertEqual( + "select * from t1 inner join (t2 inner join t3 on col3 = col2) on col3 = col1", + table("t1") + .join(table("t2") + .join(table("t3"), Inner, Option('col3 === 'col2)), Inner, Option('col3 === 'col1)) + .select(star())) + assertEqual( + "select * from t1 inner join (t2 inner join t3) on col3 = col2", + table("t1") + .join(table("t2").join(table("t3"), Inner, None), Inner, Option('col3 === 'col2)) + .select(star())) + assertEqual( + "select * from t1 inner join (t2 inner join t3 on col3 = col2)", + table("t1") + .join(table("t2").join(table("t3"), Inner, Option('col3 === 'col2)), Inner, None) + .select(star())) + + // Implicit joins. + assertEqual( + "select * from t1, t3 join t2 on t1.col1 = t2.col2", + table("t1") + .join(table("t3")) + .join(table("t2"), Inner, Option(Symbol("t1.col1") === Symbol("t2.col2"))) + .select(star())) } test("sampled relations") { @@ -369,9 +416,13 @@ class PlanParserSuite extends PlanTest { assertEqual(s"$sql tablesample(bucket 4 out of 10) as x", Sample(0, .4d, withReplacement = false, 10L, table("t").as("x"))(true).select(star())) intercept(s"$sql tablesample(bucket 4 out of 10 on x) as x", - "TABLESAMPLE(BUCKET x OUT OF y ON id) is not supported") + "TABLESAMPLE(BUCKET x OUT OF y ON colname) is not supported") intercept(s"$sql tablesample(bucket 11 out of 10) as x", s"Sampling fraction (${11.0/10.0}) must be on interval [0, 1]") + intercept("SELECT * FROM parquet_t0 TABLESAMPLE(300M) s", + "TABLESAMPLE(byteLengthLiteral) is not supported") + intercept("SELECT * FROM parquet_t0 TABLESAMPLE(BUCKET 3 OUT OF 32 ON rand()) s", + "TABLESAMPLE(BUCKET x OUT OF y ON function) is not supported") } test("sub-query") { @@ -415,20 +466,21 @@ class PlanParserSuite extends PlanTest { assertEqual("table d.t", table("d", "t")) } + test("table valued function") { + assertEqual( + "select * from range(2)", + UnresolvedTableValuedFunction("range", Literal(2) :: Nil).select(star())) + } + test("inline table") { - assertEqual("values 1, 2, 3, 4", LocalRelation.fromExternalRows( - Seq('col1.int), - Seq(1, 2, 3, 4).map(x => Row(x)))) + assertEqual("values 1, 2, 3, 4", + UnresolvedInlineTable(Seq("col1"), Seq(1, 2, 3, 4).map(x => Seq(Literal(x))))) + assertEqual( - "values (1, 'a'), (2, 'b'), (3, 'c') as tbl(a, b)", - LocalRelation.fromExternalRows( - Seq('a.int, 'b.string), - Seq((1, "a"), (2, "b"), (3, "c")).map(x => Row(x._1, x._2))).as("tbl")) - intercept("values (a, 'a'), (b, 'b')", - "All expressions in an inline table must be constants.") - intercept("values (1, 'a'), (2, 'b') as tbl(a, b, c)", - "Number of aliases must match the number of fields in an inline table.") - intercept[ArrayIndexOutOfBoundsException](parsePlan("values (1, 'a'), (2, 'b', 5Y)")) + "values (1, 'a'), (2, 'b') as tbl(a, b)", + UnresolvedInlineTable( + Seq("a", "b"), + Seq(Literal(1), Literal("a")) :: Seq(Literal(2), Literal("b")) :: Nil).as("tbl")) } test("simple select query with !> and !<") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala index bef7d38f1a40e..cb57fb69b2555 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala @@ -53,8 +53,9 @@ class 
TableIdentifierParserSuite extends SparkFunSuite { "bigint", "binary", "boolean", "current_date", "current_timestamp", "date", "double", "float", "int", "smallint", "timestamp", "at") - val hiveNonReservedRegression = Seq("left", "right", "left", "right", "full", "inner", "semi", - "union", "except", "intersect", "schema", "database") + val hiveStrictNonReservedKeyword = Seq("anti", "full", "inner", "left", "semi", "right", + "natural", "union", "intersect", "except", "database", "on", "join", "cross", "select", "from", + "where", "having", "from", "to", "table", "with", "not") test("table identifier") { // Regular names. @@ -67,11 +68,10 @@ class TableIdentifierParserSuite extends SparkFunSuite { } } - test("table identifier - keywords") { + test("table identifier - strict keywords") { // SQL Keywords. - val keywords = Seq("select", "from", "where") ++ hiveNonReservedRegression - keywords.foreach { keyword => - intercept[ParseException](parseTableIdentifier(keyword)) + hiveStrictNonReservedKeyword.foreach { keyword => + assert(TableIdentifier(keyword) === parseTableIdentifier(keyword)) assert(TableIdentifier(keyword) === parseTableIdentifier(s"`$keyword`")) assert(TableIdentifier(keyword, Option("db")) === parseTableIdentifier(s"db.`$keyword`")) } @@ -83,4 +83,27 @@ class TableIdentifierParserSuite extends SparkFunSuite { assert(TableIdentifier(nonReserved) === parseTableIdentifier(nonReserved)) } } + + test("SPARK-17364 table identifier - contains number") { + assert(parseTableIdentifier("123_") == TableIdentifier("123_")) + assert(parseTableIdentifier("1a.123_") == TableIdentifier("123_", Some("1a"))) + // ".123" should not be treated as token of type DECIMAL_VALUE + assert(parseTableIdentifier("a.123A") == TableIdentifier("123A", Some("a"))) + // ".123E3" should not be treated as token of type SCIENTIFIC_DECIMAL_VALUE + assert(parseTableIdentifier("a.123E3_LIST") == TableIdentifier("123E3_LIST", Some("a"))) + // ".123D" should not be treated as token of type DOUBLE_LITERAL + assert(parseTableIdentifier("a.123D_LIST") == TableIdentifier("123D_LIST", Some("a"))) + // ".123BD" should not be treated as token of type BIGDECIMAL_LITERAL + assert(parseTableIdentifier("a.123BD_LIST") == TableIdentifier("123BD_LIST", Some("a"))) + } + + test("SPARK-17832 table identifier - contains backtick") { + val complexName = TableIdentifier("`weird`table`name", Some("`d`b`1")) + assert(complexName === parseTableIdentifier("```d``b``1`.```weird``table``name`")) + assert(complexName === parseTableIdentifier(complexName.quotedString)) + intercept[ParseException](parseTableIdentifier(complexName.unquotedString)) + // Table identifier contains countious backticks should be treated correctly. 
+ val complexName2 = TableIdentifier("x``y", Some("d``b")) + assert(complexName2 === parseTableIdentifier(complexName2.quotedString)) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala index 81cc6b123cdd4..8068ce922e636 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -79,13 +79,15 @@ class ConstraintPropagationSuite extends SparkFunSuite { assert(tr.analyze.constraints.isEmpty) val aliasedRelation = tr.where('c.attr > 10 && 'a.attr < 5) - .groupBy('a, 'c, 'b)('a, 'c.as("c1"), count('a).as("a3")).select('c1, 'a).analyze + .groupBy('a, 'c, 'b)('a, 'c.as("c1"), count('a).as("a3")).select('c1, 'a, 'a3).analyze + // SPARK-16644: aggregate expression count(a) should not appear in the constraints. verifyConstraints(aliasedRelation.analyze.constraints, ExpressionSet(Seq(resolveColumn(aliasedRelation.analyze, "c1") > 10, IsNotNull(resolveColumn(aliasedRelation.analyze, "c1")), resolveColumn(aliasedRelation.analyze, "a") < 5, - IsNotNull(resolveColumn(aliasedRelation.analyze, "a"))))) + IsNotNull(resolveColumn(aliasedRelation.analyze, "a")), + IsNotNull(resolveColumn(aliasedRelation.analyze, "a3"))))) } test("propagating constraints in expand") { @@ -126,8 +128,16 @@ class ConstraintPropagationSuite extends SparkFunSuite { ExpressionSet(Seq(resolveColumn(aliasedRelation.analyze, "x") > 10, IsNotNull(resolveColumn(aliasedRelation.analyze, "x")), resolveColumn(aliasedRelation.analyze, "b") <=> resolveColumn(aliasedRelation.analyze, "y"), + resolveColumn(aliasedRelation.analyze, "z") <=> resolveColumn(aliasedRelation.analyze, "x"), resolveColumn(aliasedRelation.analyze, "z") > 10, IsNotNull(resolveColumn(aliasedRelation.analyze, "z"))))) + + val multiAlias = tr.where('a === 'c + 10).select('a.as('x), 'c.as('y)) + verifyConstraints(multiAlias.analyze.constraints, + ExpressionSet(Seq(IsNotNull(resolveColumn(multiAlias.analyze, "x")), + IsNotNull(resolveColumn(multiAlias.analyze, "y")), + resolveColumn(multiAlias.analyze, "x") === resolveColumn(multiAlias.analyze, "y") + 10)) + ) } test("propagating constraints in union") { @@ -298,7 +308,7 @@ class ConstraintPropagationSuite extends SparkFunSuite { Cast(resolveColumn(tr, "a"), LongType) * resolveColumn(tr, "b") + Cast(100, LongType) === Cast(resolveColumn(tr, "c"), LongType), Cast(resolveColumn(tr, "d"), DoubleType) / - Cast(Cast(10, LongType), DoubleType) === + Cast(10, DoubleType) === Cast(resolveColumn(tr, "e"), DoubleType), IsNotNull(resolveColumn(tr, "a")), IsNotNull(resolveColumn(tr, "b")), @@ -312,7 +322,7 @@ class ConstraintPropagationSuite extends SparkFunSuite { Cast(resolveColumn(tr, "a"), LongType) * resolveColumn(tr, "b") - Cast(10, LongType) >= Cast(resolveColumn(tr, "c"), LongType), Cast(resolveColumn(tr, "d"), DoubleType) / - Cast(Cast(10, LongType), DoubleType) < + Cast(10, DoubleType) < Cast(resolveColumn(tr, "e"), DoubleType), IsNotNull(resolveColumn(tr, "a")), IsNotNull(resolveColumn(tr, "b")), @@ -350,4 +360,21 @@ class ConstraintPropagationSuite extends SparkFunSuite { verifyConstraints(tr.analyze.constraints, ExpressionSet(Seq(IsNotNull(resolveColumn(tr, "b")), IsNotNull(resolveColumn(tr, "c"))))) } + + test("not infer non-deterministic constraints") { + val tr = LocalRelation('a.int, 'b.string, 'c.int) + 
+ verifyConstraints(tr + .where('a.attr === Rand(0)) + .analyze.constraints, + ExpressionSet(Seq(IsNotNull(resolveColumn(tr, "a"))))) + + verifyConstraints(tr + .where('a.attr === InputFileName()) + .where('a.attr =!= 'c.attr) + .analyze.constraints, + ExpressionSet(Seq(resolveColumn(tr, "a") =!= resolveColumn(tr, "c"), + IsNotNull(resolveColumn(tr, "a")), + IsNotNull(resolveColumn(tr, "c"))))) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 6310f0c2bc0ed..64e268703bf5e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression -import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, OneRowRelation, Sample} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util._ /** @@ -56,16 +56,37 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper { * ((expr 1 && expr 2) && expr 3), (expr 1 && expr 2 && expr 3), (expr 3 && (expr 1 && expr 2) * etc., will all now be equivalent. * - Sample the seed will replaced by 0L. + * - Join conditions will be resorted by hashCode. */ private def normalizePlan(plan: LogicalPlan): LogicalPlan = { plan transform { case filter @ Filter(condition: Expression, child: LogicalPlan) => - Filter(splitConjunctivePredicates(condition).sortBy(_.hashCode()).reduce(And), child) + Filter(splitConjunctivePredicates(condition).map(rewriteEqual(_)).sortBy(_.hashCode()) + .reduce(And), child) case sample: Sample => sample.copy(seed = 0L)(true) + case join @ Join(left, right, joinType, condition) if condition.isDefined => + val newCondition = + splitConjunctivePredicates(condition.get).map(rewriteEqual(_)).sortBy(_.hashCode()) + .reduce(And) + Join(left, right, joinType, Some(newCondition)) } } + /** + * Rewrite [[EqualTo]] and [[EqualNullSafe]] operator to keep order. The following cases will be + * equivalent: + * 1. (a = b), (b = a); + * 2. (a <=> b), (b <=> a). + */ + private def rewriteEqual(condition: Expression): Expression = condition match { + case eq @ EqualTo(l: Expression, r: Expression) => + Seq(l, r).sortBy(_.hashCode()).reduce(EqualTo) + case eq @ EqualNullSafe(l: Expression, r: Expression) => + Seq(l, r).sortBy(_.hashCode()).reduce(EqualNullSafe) + case _ => condition // Don't reorder. + } + /** Fails the test if the two plans do not match */ protected def comparePlans(plan1: LogicalPlan, plan2: LogicalPlan) { val normalized1 = normalizePlan(normalizeExprIds(plan1)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeTestUtils.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeTestUtils.scala new file mode 100644 index 0000000000000..0c1feb3aa0882 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeTestUtils.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import java.util.TimeZone + +/** + * Helper functions for testing date and time functionality. + */ +object DateTimeTestUtils { + + val ALL_TIMEZONES: Seq[TimeZone] = TimeZone.getAvailableIDs.toSeq.map(TimeZone.getTimeZone) + + def withDefaultTimeZone[T](newDefaultTimeZone: TimeZone)(block: => T): T = { + val originalDefaultTimeZone = TimeZone.getDefault + try { + DateTimeUtils.resetThreadLocals() + TimeZone.setDefault(newDefaultTimeZone) + block + } finally { + TimeZone.setDefault(originalDefaultTimeZone) + DateTimeUtils.resetThreadLocals() + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index 6745b4b6c3c67..4f516d006458e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -353,6 +353,25 @@ class DateTimeUtilsSuite extends SparkFunSuite { c.getTimeInMillis * 1000 + 123456) } + test("SPARK-15379: special invalid date string") { + // Test stringToDate + assert(stringToDate( + UTF8String.fromString("2015-02-29 00:00:00")).isEmpty) + assert(stringToDate( + UTF8String.fromString("2015-04-31 00:00:00")).isEmpty) + assert(stringToDate(UTF8String.fromString("2015-02-29")).isEmpty) + assert(stringToDate(UTF8String.fromString("2015-04-31")).isEmpty) + + + // Test stringToTimestamp + assert(stringToTimestamp( + UTF8String.fromString("2015-02-29 00:00:00")).isEmpty) + assert(stringToTimestamp( + UTF8String.fromString("2015-04-31 00:00:00")).isEmpty) + assert(stringToTimestamp(UTF8String.fromString("2015-02-29")).isEmpty) + assert(stringToTimestamp(UTF8String.fromString("2015-04-31")).isEmpty) + } + test("hours") { val c = Calendar.getInstance() c.set(2015, 2, 18, 13, 2, 11) @@ -469,10 +488,23 @@ class DateTimeUtilsSuite extends SparkFunSuite { assert(toJavaTimestamp(fromUTCTime(fromJavaTimestamp(Timestamp.valueOf(utc)), tz)).toString === expected) } - test("2011-12-25 09:00:00.123456", "UTC", "2011-12-25 09:00:00.123456") - test("2011-12-25 09:00:00.123456", "JST", "2011-12-25 18:00:00.123456") - test("2011-12-25 09:00:00.123456", "PST", "2011-12-25 01:00:00.123456") - test("2011-12-25 09:00:00.123456", "Asia/Shanghai", "2011-12-25 17:00:00.123456") + for (tz <- DateTimeTestUtils.ALL_TIMEZONES) { + DateTimeTestUtils.withDefaultTimeZone(tz) { + test("2011-12-25 09:00:00.123456", "UTC", "2011-12-25 09:00:00.123456") + test("2011-12-25 09:00:00.123456", "JST", "2011-12-25 18:00:00.123456") + test("2011-12-25 09:00:00.123456", "PST", "2011-12-25 01:00:00.123456") + test("2011-12-25 09:00:00.123456", "Asia/Shanghai", "2011-12-25 17:00:00.123456") + } + } + + DateTimeTestUtils.withDefaultTimeZone(TimeZone.getTimeZone("PST")) { + // Daylight Saving Time + test("2016-03-13 09:59:59.0", 
"PST", "2016-03-13 01:59:59.0") + test("2016-03-13 10:00:00.0", "PST", "2016-03-13 03:00:00.0") + test("2016-11-06 08:59:59.0", "PST", "2016-11-06 01:59:59.0") + test("2016-11-06 09:00:00.0", "PST", "2016-11-06 01:00:00.0") + test("2016-11-06 10:00:00.0", "PST", "2016-11-06 02:00:00.0") + } } test("to UTC timestamp") { @@ -480,9 +512,50 @@ class DateTimeUtilsSuite extends SparkFunSuite { assert(toJavaTimestamp(toUTCTime(fromJavaTimestamp(Timestamp.valueOf(utc)), tz)).toString === expected) } - test("2011-12-25 09:00:00.123456", "UTC", "2011-12-25 09:00:00.123456") - test("2011-12-25 18:00:00.123456", "JST", "2011-12-25 09:00:00.123456") - test("2011-12-25 01:00:00.123456", "PST", "2011-12-25 09:00:00.123456") - test("2011-12-25 17:00:00.123456", "Asia/Shanghai", "2011-12-25 09:00:00.123456") + + for (tz <- DateTimeTestUtils.ALL_TIMEZONES) { + DateTimeTestUtils.withDefaultTimeZone(tz) { + test("2011-12-25 09:00:00.123456", "UTC", "2011-12-25 09:00:00.123456") + test("2011-12-25 18:00:00.123456", "JST", "2011-12-25 09:00:00.123456") + test("2011-12-25 01:00:00.123456", "PST", "2011-12-25 09:00:00.123456") + test("2011-12-25 17:00:00.123456", "Asia/Shanghai", "2011-12-25 09:00:00.123456") + } + } + + DateTimeTestUtils.withDefaultTimeZone(TimeZone.getTimeZone("PST")) { + // Daylight Saving Time + test("2016-03-13 01:59:59", "PST", "2016-03-13 09:59:59.0") + // 2016-03-13 02:00:00 PST does not exists + test("2016-03-13 02:00:00", "PST", "2016-03-13 10:00:00.0") + test("2016-03-13 03:00:00", "PST", "2016-03-13 10:00:00.0") + test("2016-11-06 00:59:59", "PST", "2016-11-06 07:59:59.0") + // 2016-11-06 01:00:00 PST could be 2016-11-06 08:00:00 UTC or 2016-11-06 09:00:00 UTC + test("2016-11-06 01:00:00", "PST", "2016-11-06 09:00:00.0") + test("2016-11-06 01:59:59", "PST", "2016-11-06 09:59:59.0") + test("2016-11-06 02:00:00", "PST", "2016-11-06 10:00:00.0") + } + } + + test("daysToMillis and millisToDays") { + // There are some days are skipped entirely in some timezone, skip them here. + val skipped_days = Map[String, Int]( + "Kwajalein" -> 8632, + "Pacific/Apia" -> 15338, + "Pacific/Enderbury" -> 9131, + "Pacific/Fakaofo" -> 15338, + "Pacific/Kiritimati" -> 9131, + "Pacific/Kwajalein" -> 8632, + "MIT" -> 15338) + for (tz <- DateTimeTestUtils.ALL_TIMEZONES) { + DateTimeTestUtils.withDefaultTimeZone(tz) { + val skipped = skipped_days.getOrElse(tz.getID, Int.MinValue) + (-20000 to 20000).foreach { d => + if (d != skipped) { + assert(millisToDays(daysToMillis(d)) === d, + s"Round trip of ${d} did not work in tz ${tz}") + } + } + } + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/TypeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/TypeUtilsSuite.scala new file mode 100644 index 0000000000000..bc6852ca7e1fd --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/TypeUtilsSuite.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} +import org.apache.spark.sql.types._ + +class TypeUtilsSuite extends SparkFunSuite { + + private def typeCheckPass(types: Seq[DataType]): Unit = { + assert(TypeUtils.checkForSameTypeInputExpr(types, "a") == TypeCheckSuccess) + } + + private def typeCheckFail(types: Seq[DataType]): Unit = { + assert(TypeUtils.checkForSameTypeInputExpr(types, "a").isInstanceOf[TypeCheckFailure]) + } + + test("checkForSameTypeInputExpr") { + typeCheckPass(Nil) + typeCheckPass(StringType :: Nil) + typeCheckPass(StringType :: StringType :: Nil) + + typeCheckFail(StringType :: IntegerType :: Nil) + typeCheckFail(StringType :: IntegerType :: Nil) + + // Should also work on arrays. See SPARK-14990 + typeCheckPass(ArrayType(StringType, containsNull = true) :: + ArrayType(StringType, containsNull = false) :: Nil) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index 6b85f12521c2a..569230accfd74 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.types import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser class DataTypeSuite extends SparkFunSuite { @@ -342,4 +343,33 @@ class DataTypeSuite extends SparkFunSuite { StructField("a", StringType, nullable = false) :: StructField("b", StringType, nullable = false) :: Nil), expected = false) + + def checkCatalogString(dt: DataType): Unit = { + test(s"catalogString: $dt") { + val dt2 = CatalystSqlParser.parseDataType(dt.catalogString) + assert(dt === dt2) + } + } + def createStruct(n: Int): StructType = new StructType(Array.tabulate(n) { + i => StructField(s"col$i", IntegerType, nullable = true) + }) + + checkCatalogString(BooleanType) + checkCatalogString(ByteType) + checkCatalogString(ShortType) + checkCatalogString(IntegerType) + checkCatalogString(LongType) + checkCatalogString(FloatType) + checkCatalogString(DoubleType) + checkCatalogString(DecimalType(10, 5)) + checkCatalogString(BinaryType) + checkCatalogString(StringType) + checkCatalogString(DateType) + checkCatalogString(TimestampType) + checkCatalogString(createStruct(4)) + checkCatalogString(createStruct(40)) + checkCatalogString(ArrayType(IntegerType)) + checkCatalogString(ArrayType(createStruct(40))) + checkCatalogString(MapType(IntegerType, StringType)) + checkCatalogString(MapType(IntegerType, createStruct(40))) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala index e1675c95907af..4cf329ddee213 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala @@ -22,6 +22,7 @@ import scala.language.postfixOps import org.scalatest.PrivateMethodTester import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types.Decimal._ class DecimalSuite extends SparkFunSuite with PrivateMethodTester { /** Check that a Decimal has the given string representation, precision and scale */ @@ -193,4 +194,18 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester { assert(new Decimal().set(100L, 10, 0).toUnscaledLong === 100L) assert(Decimal(Long.MaxValue, 100, 0).toUnscaledLong === Long.MaxValue) } + + test("changePrecision() on compact decimal should respect rounding mode") { + Seq(ROUND_FLOOR, ROUND_CEILING, ROUND_HALF_UP, ROUND_HALF_EVEN).foreach { mode => + Seq("0.4", "0.5", "0.6", "1.0", "1.1", "1.6", "2.5", "5.5").foreach { n => + Seq("", "-").foreach { sign => + val bd = BigDecimal(sign + n) + val unscaled = (bd * 10).toLongExact + val d = Decimal(unscaled, 8, 1) + assert(d.changePrecision(10, 0, mode)) + assert(d.toString === bd.setScale(0, mode).toString(), s"num: $sign$n, mode: $mode") + } + } + } + } } diff --git a/sql/core/benchmarks/WideSchemaBenchmark-results.txt b/sql/core/benchmarks/WideSchemaBenchmark-results.txt new file mode 100644 index 0000000000000..0b9f791ac85e4 --- /dev/null +++ b/sql/core/benchmarks/WideSchemaBenchmark-results.txt @@ -0,0 +1,117 @@ +Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.11.6 +Intel(R) Core(TM) i7-4980HQ CPU @ 2.80GHz + +parsing large select: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +1 select expressions 2 / 4 0.0 2050147.0 1.0X +100 select expressions 6 / 7 0.0 6123412.0 0.3X +2500 select expressions 135 / 141 0.0 134623148.0 0.0X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.11.6 +Intel(R) Core(TM) i7-4980HQ CPU @ 2.80GHz + +many column field r/w: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +1 cols x 100000 rows (read in-mem) 16 / 18 6.3 158.6 1.0X +1 cols x 100000 rows (exec in-mem) 17 / 19 6.0 166.7 1.0X +1 cols x 100000 rows (read parquet) 24 / 26 4.3 235.1 0.7X +1 cols x 100000 rows (write parquet) 81 / 85 1.2 811.3 0.2X +100 cols x 1000 rows (read in-mem) 17 / 19 6.0 166.2 1.0X +100 cols x 1000 rows (exec in-mem) 25 / 27 4.0 249.2 0.6X +100 cols x 1000 rows (read parquet) 23 / 25 4.4 226.0 0.7X +100 cols x 1000 rows (write parquet) 83 / 87 1.2 831.0 0.2X +2500 cols x 40 rows (read in-mem) 132 / 137 0.8 1322.9 0.1X +2500 cols x 40 rows (exec in-mem) 326 / 330 0.3 3260.6 0.0X +2500 cols x 40 rows (read parquet) 831 / 839 0.1 8305.8 0.0X +2500 cols x 40 rows (write parquet) 237 / 245 0.4 2372.6 0.1X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.11.6 +Intel(R) Core(TM) i7-4980HQ CPU @ 2.80GHz + +wide shallowly nested struct field r/w: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +1 wide x 100000 rows (read in-mem) 15 / 17 6.6 151.0 1.0X +1 wide x 100000 rows (exec in-mem) 20 / 22 5.1 196.6 0.8X +1 wide x 100000 rows (read parquet) 59 / 63 1.7 592.8 0.3X +1 wide x 100000 rows (write parquet) 81 / 87 1.2 814.6 0.2X +100 wide x 1000 rows (read in-mem) 21 / 25 4.8 208.7 0.7X +100 wide x 1000 rows (exec in-mem) 72 / 81 1.4 718.5 0.2X +100 wide x 1000 rows (read 
parquet) 75 / 85 1.3 752.6 0.2X +100 wide x 1000 rows (write parquet) 88 / 95 1.1 876.7 0.2X +2500 wide x 40 rows (read in-mem) 28 / 34 3.5 282.2 0.5X +2500 wide x 40 rows (exec in-mem) 1269 / 1284 0.1 12688.1 0.0X +2500 wide x 40 rows (read parquet) 549 / 578 0.2 5493.4 0.0X +2500 wide x 40 rows (write parquet) 96 / 104 1.0 959.1 0.2X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.11.6 +Intel(R) Core(TM) i7-4980HQ CPU @ 2.80GHz + +deeply nested struct field r/w: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +1 deep x 100000 rows (read in-mem) 14 / 16 7.0 143.8 1.0X +1 deep x 100000 rows (exec in-mem) 17 / 19 5.9 169.7 0.8X +1 deep x 100000 rows (read parquet) 33 / 35 3.1 327.0 0.4X +1 deep x 100000 rows (write parquet) 79 / 84 1.3 786.9 0.2X +100 deep x 1000 rows (read in-mem) 21 / 24 4.7 211.3 0.7X +100 deep x 1000 rows (exec in-mem) 221 / 235 0.5 2214.5 0.1X +100 deep x 1000 rows (read parquet) 1928 / 1952 0.1 19277.1 0.0X +100 deep x 1000 rows (write parquet) 91 / 96 1.1 909.5 0.2X +250 deep x 400 rows (read in-mem) 57 / 61 1.8 567.1 0.3X +250 deep x 400 rows (exec in-mem) 1329 / 1385 0.1 13291.8 0.0X +250 deep x 400 rows (read parquet) 36563 / 36750 0.0 365630.2 0.0X +250 deep x 400 rows (write parquet) 126 / 130 0.8 1262.0 0.1X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.11.6 +Intel(R) Core(TM) i7-4980HQ CPU @ 2.80GHz + +bushy struct field r/w: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +1 x 1 deep x 100000 rows (read in-mem) 13 / 15 7.8 127.7 1.0X +1 x 1 deep x 100000 rows (exec in-mem) 15 / 17 6.6 151.5 0.8X +1 x 1 deep x 100000 rows (read parquet) 20 / 23 5.0 198.3 0.6X +1 x 1 deep x 100000 rows (write parquet) 77 / 82 1.3 770.4 0.2X +128 x 8 deep x 1000 rows (read in-mem) 12 / 14 8.2 122.5 1.0X +128 x 8 deep x 1000 rows (exec in-mem) 124 / 140 0.8 1241.2 0.1X +128 x 8 deep x 1000 rows (read parquet) 69 / 74 1.4 693.9 0.2X +128 x 8 deep x 1000 rows (write parquet) 78 / 83 1.3 777.7 0.2X +1024 x 11 deep x 100 rows (read in-mem) 25 / 29 4.1 246.1 0.5X +1024 x 11 deep x 100 rows (exec in-mem) 1197 / 1223 0.1 11974.6 0.0X +1024 x 11 deep x 100 rows (read parquet) 426 / 433 0.2 4263.7 0.0X +1024 x 11 deep x 100 rows (write parquet) 91 / 98 1.1 913.5 0.1X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.11.6 +Intel(R) Core(TM) i7-4980HQ CPU @ 2.80GHz + +wide array field r/w: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +1 wide x 100000 rows (read in-mem) 14 / 16 7.0 143.2 1.0X +1 wide x 100000 rows (exec in-mem) 17 / 19 5.9 170.9 0.8X +1 wide x 100000 rows (read parquet) 43 / 46 2.3 434.1 0.3X +1 wide x 100000 rows (write parquet) 78 / 83 1.3 777.6 0.2X +100 wide x 1000 rows (read in-mem) 11 / 13 9.0 111.5 1.3X +100 wide x 1000 rows (exec in-mem) 13 / 15 7.8 128.3 1.1X +100 wide x 1000 rows (read parquet) 24 / 27 4.1 245.0 0.6X +100 wide x 1000 rows (write parquet) 74 / 80 1.4 740.5 0.2X +2500 wide x 40 rows (read in-mem) 11 / 13 9.1 109.5 1.3X +2500 wide x 40 rows (exec in-mem) 13 / 15 7.7 129.4 1.1X +2500 wide x 40 rows (read parquet) 24 / 26 4.1 241.3 0.6X +2500 wide x 40 rows (write parquet) 75 / 81 1.3 751.8 0.2X + +Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.11.6 +Intel(R) Core(TM) i7-4980HQ CPU @ 2.80GHz + 
+wide map field r/w: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------ +1 wide x 100000 rows (read in-mem) 16 / 18 6.2 162.6 1.0X +1 wide x 100000 rows (exec in-mem) 21 / 23 4.8 208.2 0.8X +1 wide x 100000 rows (read parquet) 54 / 59 1.8 543.6 0.3X +1 wide x 100000 rows (write parquet) 80 / 86 1.2 804.5 0.2X +100 wide x 1000 rows (read in-mem) 11 / 13 8.7 114.5 1.4X +100 wide x 1000 rows (exec in-mem) 14 / 16 7.0 143.5 1.1X +100 wide x 1000 rows (read parquet) 30 / 32 3.3 300.4 0.5X +100 wide x 1000 rows (write parquet) 75 / 80 1.3 749.9 0.2X +2500 wide x 40 rows (read in-mem) 13 / 15 7.8 128.1 1.3X +2500 wide x 40 rows (exec in-mem) 15 / 18 6.5 153.6 1.1X +2500 wide x 40 rows (read parquet) 30 / 33 3.3 304.4 0.5X +2500 wide x 40 rows (write parquet) 77 / 83 1.3 768.5 0.2X + diff --git a/sql/core/pom.xml b/sql/core/pom.xml index e1071ebfb5a61..e208fd355b9a7 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.0.0-SNAPSHOT + 2.0.3-SNAPSHOT ../../pom.xml @@ -39,7 +39,7 @@ com.univocity univocity-parsers - 2.0.2 + 2.1.1 jar @@ -73,7 +73,7 @@ org.apache.spark - spark-test-tags_${scala.binary.version} + spark-tags_${scala.binary.version} org.apache.parquet @@ -133,6 +133,33 @@ target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes + + + org.apache.maven.plugins + maven-jar-plugin + + + + test-jar + + + + test-jar-on-test-compile + test-compile + + test-jar + + + + org.codehaus.mojo build-helper-maven-plugin diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java index 1f1b5389aa7d4..cd521c52d1b21 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -29,6 +29,7 @@ import org.apache.spark.unsafe.KVIterator; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.map.BytesToBytesMap; +import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter; /** * Unsafe-based HashMap for performing aggregations where the aggregated values are fixed-width. 
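The hunks that follow thread a new element-count spill threshold from the configuration down into `UnsafeKVExternalSorter`. As a hedged illustration (not part of the patch), this is roughly how such a knob is read from a `SparkConf`; the default value below is only a placeholder for `UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD`.

```scala
import org.apache.spark.SparkConf

object SpillThresholdSketch extends App {
  val conf = new SparkConf()
    // Property name as used in the following hunk; the value set here is arbitrary.
    .set("spark.shuffle.spill.numElementsForceSpillThreshold", (4 * 1000 * 1000).toString)

  val numElementsForSpillThreshold: Long = conf.getLong(
    "spark.shuffle.spill.numElementsForceSpillThreshold",
    Long.MaxValue) // placeholder default; the patch uses the sorter's built-in constant

  println(numElementsForSpillThreshold)
}
```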
@@ -246,6 +247,8 @@ public UnsafeKVExternalSorter destructAndCreateExternalSorter() throws IOExcepti SparkEnv.get().blockManager(), SparkEnv.get().serializerManager(), map.getPageSizeBytes(), + SparkEnv.get().conf().getLong("spark.shuffle.spill.numElementsForceSpillThreshold", + UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD), map); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java index 38dbfef76caee..82ee5b0d77713 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java @@ -54,8 +54,10 @@ public UnsafeKVExternalSorter( StructType valueSchema, BlockManager blockManager, SerializerManager serializerManager, - long pageSizeBytes) throws IOException { - this(keySchema, valueSchema, blockManager, serializerManager, pageSizeBytes, null); + long pageSizeBytes, + long numElementsForSpillThreshold) throws IOException { + this(keySchema, valueSchema, blockManager, serializerManager, pageSizeBytes, + numElementsForSpillThreshold, null); } public UnsafeKVExternalSorter( @@ -64,6 +66,7 @@ public UnsafeKVExternalSorter( BlockManager blockManager, SerializerManager serializerManager, long pageSizeBytes, + long numElementsForSpillThreshold, @Nullable BytesToBytesMap map) throws IOException { this.keySchema = keySchema; this.valueSchema = valueSchema; @@ -73,6 +76,8 @@ public UnsafeKVExternalSorter( PrefixComparator prefixComparator = SortPrefixUtils.getPrefixComparator(keySchema); BaseOrdering ordering = GenerateOrdering.create(keySchema); KVComparator recordComparator = new KVComparator(ordering, keySchema.length()); + boolean canUseRadixSort = keySchema.length() == 1 && + SortPrefixUtils.canSortFullyWithPrefix(keySchema.apply(0)); TaskMemoryManager taskMemoryManager = taskContext.taskMemoryManager(); @@ -86,14 +91,17 @@ public UnsafeKVExternalSorter( prefixComparator, /* initialSize */ 4096, pageSizeBytes, - keySchema.length() == 1 && SortPrefixUtils.canSortFullyWithPrefix(keySchema.apply(0))); + numElementsForSpillThreshold, + canUseRadixSort); } else { + // The array will be used to do in-place sort, which require half of the space to be empty. + assert(map.numKeys() <= map.getArray().size() / 2); // During spilling, the array in map will not be used, so we can borrow that and use it // as the underline array for in-memory sorter (it's always large enough). // Since we will not grow the array, it's fine to pass `null` as consumer. final UnsafeInMemorySorter inMemSorter = new UnsafeInMemorySorter( null, taskMemoryManager, recordComparator, prefixComparator, map.getArray(), - false /* TODO(ekl) we can only radix sort if the BytesToBytes load factor is <= 0.5 */); + canUseRadixSort); // We cannot use the destructive iterator here because we are reusing the existing memory // pages in BytesToBytesMap to hold records during sorting. @@ -128,6 +136,7 @@ public UnsafeKVExternalSorter( prefixComparator, /* initialSize */ 4096, pageSizeBytes, + numElementsForSpillThreshold, inMemSorter); // reset the map, so we can re-use it to insert new records. 
the inMemSorter will not used diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java index cbe8f78164ae7..b903aeeb5c9ea 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java @@ -31,6 +31,8 @@ import java.util.Map; import java.util.Set; +import scala.Option; + import static org.apache.parquet.filter2.compat.RowGroupFilter.filterRowGroups; import static org.apache.parquet.format.converter.ParquetMetadataConverter.NO_FILTER; import static org.apache.parquet.format.converter.ParquetMetadataConverter.range; @@ -59,7 +61,12 @@ import org.apache.parquet.hadoop.util.ConfigurationUtil; import org.apache.parquet.schema.MessageType; import org.apache.parquet.schema.Types; +import org.apache.spark.TaskContext; +import org.apache.spark.TaskContext$; import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.StructType$; +import org.apache.spark.util.AccumulatorV2; +import org.apache.spark.util.LongAccumulator; /** * Base class for custom RecordReaders for Parquet that directly materialize to `T`. @@ -136,11 +143,25 @@ public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptCont ReadSupport.ReadContext readContext = readSupport.init(new InitContext( taskAttemptContext.getConfiguration(), toSetMultiMap(fileMetadata), fileSchema)); this.requestedSchema = readContext.getRequestedSchema(); - this.sparkSchema = new CatalystSchemaConverter(configuration).convert(requestedSchema); + String sparkRequestedSchemaString = + configuration.get(ParquetReadSupport$.MODULE$.SPARK_ROW_REQUESTED_SCHEMA()); + this.sparkSchema = StructType$.MODULE$.fromString(sparkRequestedSchemaString); this.reader = new ParquetFileReader(configuration, file, blocks, requestedSchema.getColumns()); for (BlockMetaData block : blocks) { this.totalRowCount += block.getRowCount(); } + + // For test purpose. + // If the predefined accumulator exists, the row group number to read will be updated + // to the accumulator. So we can check if the row groups are filtered or not in test case. 
+ TaskContext taskContext = TaskContext$.MODULE$.get(); + if (taskContext != null) { + Option> accu = (Option>) taskContext.taskMetrics() + .lookForAccumulatorByName("numRowGroups"); + if (accu.isDefined()) { + ((LongAccumulator)accu.get()).add((long)blocks.size()); + } + } } /** @@ -196,7 +217,7 @@ protected void initialize(String path, List columns) throws IOException } this.requestedSchema = builder.named("spark_schema"); } - this.sparkSchema = new CatalystSchemaConverter(config).convert(requestedSchema); + this.sparkSchema = new ParquetSchemaConverter(config).convert(requestedSchema); this.reader = new ParquetFileReader(config, file, blocks, requestedSchema.getColumns()); for (BlockMetaData block : blocks) { this.totalRowCount += block.getRowCount(); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index ea37a08ab5f55..3141edd06879c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -19,7 +19,6 @@ import java.io.IOException; -import org.apache.commons.lang.NotImplementedException; import org.apache.parquet.bytes.BytesUtils; import org.apache.parquet.column.ColumnDescriptor; import org.apache.parquet.column.Dictionary; @@ -60,7 +59,7 @@ public class VectorizedColumnReader { /** * If true, the current page is dictionary encoded. */ - private boolean useDictionary; + private boolean isCurrentPageDictionaryEncoded; /** * Maximum definition level for this column. @@ -101,13 +100,13 @@ public VectorizedColumnReader(ColumnDescriptor descriptor, PageReader pageReader if (dictionaryPage != null) { try { this.dictionary = dictionaryPage.getEncoding().initDictionary(descriptor, dictionaryPage); - this.useDictionary = true; + this.isCurrentPageDictionaryEncoded = true; } catch (IOException e) { throw new IOException("could not decode the dictionary for " + descriptor, e); } } else { this.dictionary = null; - this.useDictionary = false; + this.isCurrentPageDictionaryEncoded = false; } this.totalValueCount = pageReader.getTotalValueCount(); if (totalValueCount == 0) { @@ -137,6 +136,13 @@ private boolean next() throws IOException { */ void readBatch(int total, ColumnVector column) throws IOException { int rowId = 0; + ColumnVector dictionaryIds = null; + if (dictionary != null) { + // SPARK-16334: We only maintain a single dictionary per row batch, so that it can be used to + // decode all previous dictionary encoded pages if we ever encounter a non-dictionary encoded + // page. + dictionaryIds = column.reserveDictionaryIds(total); + } while (total > 0) { // Compute the number of values we want to read in this page. int leftInPage = (int) (endOfPageValueCount - valuesRead); @@ -145,12 +151,10 @@ void readBatch(int total, ColumnVector column) throws IOException { leftInPage = (int) (endOfPageValueCount - valuesRead); } int num = Math.min(total, leftInPage); - if (useDictionary) { + if (isCurrentPageDictionaryEncoded) { // Read and decode dictionary ids. 
- ColumnVector dictionaryIds = column.reserveDictionaryIds(total); defColumn.readIntegers( num, dictionaryIds, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); - if (column.hasDictionary() || (rowId == 0 && (descriptor.getType() == PrimitiveType.PrimitiveTypeName.INT32 || descriptor.getType() == PrimitiveType.PrimitiveTypeName.INT64 || @@ -217,18 +221,24 @@ private void decodeDictionaryIds(int rowId, int num, ColumnVector column, if (column.dataType() == DataTypes.IntegerType || DecimalType.is32BitDecimalType(column.dataType())) { for (int i = rowId; i < rowId + num; ++i) { - column.putInt(i, dictionary.decodeToInt(dictionaryIds.getInt(i))); + if (!column.isNullAt(i)) { + column.putInt(i, dictionary.decodeToInt(dictionaryIds.getInt(i))); + } } } else if (column.dataType() == DataTypes.ByteType) { for (int i = rowId; i < rowId + num; ++i) { - column.putByte(i, (byte) dictionary.decodeToInt(dictionaryIds.getInt(i))); + if (!column.isNullAt(i)) { + column.putByte(i, (byte) dictionary.decodeToInt(dictionaryIds.getInt(i))); + } } } else if (column.dataType() == DataTypes.ShortType) { for (int i = rowId; i < rowId + num; ++i) { - column.putShort(i, (short) dictionary.decodeToInt(dictionaryIds.getInt(i))); + if (!column.isNullAt(i)) { + column.putShort(i, (short) dictionary.decodeToInt(dictionaryIds.getInt(i))); + } } } else { - throw new NotImplementedException("Unimplemented type: " + column.dataType()); + throw new UnsupportedOperationException("Unimplemented type: " + column.dataType()); } break; @@ -236,33 +246,41 @@ private void decodeDictionaryIds(int rowId, int num, ColumnVector column, if (column.dataType() == DataTypes.LongType || DecimalType.is64BitDecimalType(column.dataType())) { for (int i = rowId; i < rowId + num; ++i) { - column.putLong(i, dictionary.decodeToLong(dictionaryIds.getInt(i))); + if (!column.isNullAt(i)) { + column.putLong(i, dictionary.decodeToLong(dictionaryIds.getInt(i))); + } } } else { - throw new NotImplementedException("Unimplemented type: " + column.dataType()); + throw new UnsupportedOperationException("Unimplemented type: " + column.dataType()); } break; case FLOAT: for (int i = rowId; i < rowId + num; ++i) { - column.putFloat(i, dictionary.decodeToFloat(dictionaryIds.getInt(i))); + if (!column.isNullAt(i)) { + column.putFloat(i, dictionary.decodeToFloat(dictionaryIds.getInt(i))); + } } break; case DOUBLE: for (int i = rowId; i < rowId + num; ++i) { - column.putDouble(i, dictionary.decodeToDouble(dictionaryIds.getInt(i))); + if (!column.isNullAt(i)) { + column.putDouble(i, dictionary.decodeToDouble(dictionaryIds.getInt(i))); + } } break; case INT96: if (column.dataType() == DataTypes.TimestampType) { for (int i = rowId; i < rowId + num; ++i) { // TODO: Convert dictionary of Binaries to dictionary of Longs - Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i)); - column.putLong(i, CatalystRowConverter.binaryToSQLTimestamp(v)); + if (!column.isNullAt(i)) { + Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i)); + column.putLong(i, ParquetRowConverter.binaryToSQLTimestamp(v)); + } } } else { - throw new NotImplementedException(); + throw new UnsupportedOperationException(); } break; case BINARY: @@ -271,34 +289,42 @@ private void decodeDictionaryIds(int rowId, int num, ColumnVector column, // and reuse it across batches. This should mean adding a ByteArray would just update // the length and offset. 
for (int i = rowId; i < rowId + num; ++i) { - Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i)); - column.putByteArray(i, v.getBytes()); + if (!column.isNullAt(i)) { + Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i)); + column.putByteArray(i, v.getBytes()); + } } break; case FIXED_LEN_BYTE_ARRAY: // DecimalType written in the legacy mode if (DecimalType.is32BitDecimalType(column.dataType())) { for (int i = rowId; i < rowId + num; ++i) { - Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i)); - column.putInt(i, (int) CatalystRowConverter.binaryToUnscaledLong(v)); + if (!column.isNullAt(i)) { + Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i)); + column.putInt(i, (int) ParquetRowConverter.binaryToUnscaledLong(v)); + } } } else if (DecimalType.is64BitDecimalType(column.dataType())) { for (int i = rowId; i < rowId + num; ++i) { - Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i)); - column.putLong(i, CatalystRowConverter.binaryToUnscaledLong(v)); + if (!column.isNullAt(i)) { + Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i)); + column.putLong(i, ParquetRowConverter.binaryToUnscaledLong(v)); + } } } else if (DecimalType.isByteArrayDecimalType(column.dataType())) { for (int i = rowId; i < rowId + num; ++i) { - Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i)); - column.putByteArray(i, v.getBytes()); + if (!column.isNullAt(i)) { + Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i)); + column.putByteArray(i, v.getBytes()); + } } } else { - throw new NotImplementedException(); + throw new UnsupportedOperationException(); } break; default: - throw new NotImplementedException("Unsupported type: " + descriptor.getType()); + throw new UnsupportedOperationException("Unsupported type: " + descriptor.getType()); } } @@ -327,7 +353,7 @@ private void readIntBatch(int rowId, int num, ColumnVector column) throws IOExce defColumn.readShorts( num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); } else { - throw new NotImplementedException("Unimplemented type: " + column.dataType()); + throw new UnsupportedOperationException("Unimplemented type: " + column.dataType()); } } @@ -360,7 +386,7 @@ private void readDoubleBatch(int rowId, int num, ColumnVector column) throws IOE defColumn.readDoubles( num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); } else { - throw new NotImplementedException("Unimplemented type: " + column.dataType()); + throw new UnsupportedOperationException("Unimplemented type: " + column.dataType()); } } @@ -375,13 +401,13 @@ private void readBinaryBatch(int rowId, int num, ColumnVector column) throws IOE if (defColumn.readInteger() == maxDefLevel) { column.putLong(rowId + i, // Read 12 bytes for INT96 - CatalystRowConverter.binaryToSQLTimestamp(data.readBinary(12))); + ParquetRowConverter.binaryToSQLTimestamp(data.readBinary(12))); } else { column.putNull(rowId + i); } } } else { - throw new NotImplementedException("Unimplemented type: " + column.dataType()); + throw new UnsupportedOperationException("Unimplemented type: " + column.dataType()); } } @@ -394,7 +420,7 @@ private void readFixedLenByteArrayBatch(int rowId, int num, for (int i = 0; i < num; i++) { if (defColumn.readInteger() == maxDefLevel) { column.putInt(rowId + i, - (int) CatalystRowConverter.binaryToUnscaledLong(data.readBinary(arrayLen))); + (int) ParquetRowConverter.binaryToUnscaledLong(data.readBinary(arrayLen))); } else { column.putNull(rowId + i); } @@ -403,7 +429,7 @@ private void 
readFixedLenByteArrayBatch(int rowId, int num, for (int i = 0; i < num; i++) { if (defColumn.readInteger() == maxDefLevel) { column.putLong(rowId + i, - CatalystRowConverter.binaryToUnscaledLong(data.readBinary(arrayLen))); + ParquetRowConverter.binaryToUnscaledLong(data.readBinary(arrayLen))); } else { column.putNull(rowId + i); } @@ -417,7 +443,7 @@ private void readFixedLenByteArrayBatch(int rowId, int num, } } } else { - throw new NotImplementedException("Unimplemented type: " + column.dataType()); + throw new UnsupportedOperationException("Unimplemented type: " + column.dataType()); } } @@ -459,16 +485,16 @@ private void initDataReader(Encoding dataEncoding, byte[] bytes, int offset) thr @SuppressWarnings("deprecation") Encoding plainDict = Encoding.PLAIN_DICTIONARY; // var to allow warning suppression if (dataEncoding != plainDict && dataEncoding != Encoding.RLE_DICTIONARY) { - throw new NotImplementedException("Unsupported encoding: " + dataEncoding); + throw new UnsupportedOperationException("Unsupported encoding: " + dataEncoding); } this.dataColumn = new VectorizedRleValuesReader(); - this.useDictionary = true; + this.isCurrentPageDictionaryEncoded = true; } else { if (dataEncoding != Encoding.PLAIN) { - throw new NotImplementedException("Unsupported encoding: " + dataEncoding); + throw new UnsupportedOperationException("Unsupported encoding: " + dataEncoding); } this.dataColumn = new VectorizedPlainValuesReader(); - this.useDictionary = false; + this.isCurrentPageDictionaryEncoded = false; } try { @@ -485,7 +511,7 @@ private void readPageV1(DataPageV1 page) throws IOException { // Initialize the decoders. if (page.getDlEncoding() != Encoding.RLE && descriptor.getMaxDefinitionLevel() != 0) { - throw new NotImplementedException("Unsupported encoding: " + page.getDlEncoding()); + throw new UnsupportedOperationException("Unsupported encoding: " + page.getDlEncoding()); } int bitWidth = BytesUtils.getWidthFromMaxInt(descriptor.getMaxDefinitionLevel()); this.defColumn = new VectorizedRleValuesReader(bitWidth); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java index 2672e0453b392..9def4559d214e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources.parquet; import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; import org.apache.spark.sql.execution.vectorized.ColumnVector; import org.apache.spark.unsafe.Platform; @@ -31,6 +33,10 @@ public class VectorizedPlainValuesReader extends ValuesReader implements Vectori private byte[] buffer; private int offset; private int bitOffset; // Only used for booleans. 
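Aside (not part of the patch): the VectorizedPlainValuesReader changes around here make plain-encoded reads endian-safe. Parquet's plain encoding is little-endian, so a big-endian JVM must either byte-reverse the raw value or read it through a ByteBuffer forced to LITTLE_ENDIAN order. A minimal standalone sketch of that idea, using only java.nio and hypothetical names:

```scala
import java.nio.{ByteBuffer, ByteOrder}

// Sketch only: shows why forcing the byte order makes the read portable.
object LittleEndianRead {
  private val bigEndianPlatform: Boolean =
    ByteOrder.nativeOrder().equals(ByteOrder.BIG_ENDIAN)

  def readInt(bytes: Array[Byte], offset: Int): Int = {
    // Identical result on both platforms because the order is fixed explicitly.
    ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN).getInt(offset)
  }

  def main(args: Array[String]): Unit = {
    val encodedOne = Array[Byte](1, 0, 0, 0) // little-endian encoding of the int 1
    println(readInt(encodedOne, 0))          // prints 1 regardless of native byte order
    println(s"running on a big-endian JVM: $bigEndianPlatform")
  }
}
```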
+ private ByteBuffer byteBuffer; // used to wrap the byte array buffer + + private static final boolean bigEndianPlatform = + ByteOrder.nativeOrder().equals(ByteOrder.BIG_ENDIAN); public VectorizedPlainValuesReader() { } @@ -39,6 +45,9 @@ public VectorizedPlainValuesReader() { public void initFromPage(int valueCount, byte[] bytes, int offset) throws IOException { this.buffer = bytes; this.offset = offset + Platform.BYTE_ARRAY_OFFSET; + if (bigEndianPlatform) { + byteBuffer = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN); + } } @Override @@ -103,6 +112,9 @@ public final boolean readBoolean() { @Override public final int readInteger() { int v = Platform.getInt(buffer, offset); + if (bigEndianPlatform) { + v = java.lang.Integer.reverseBytes(v); + } offset += 4; return v; } @@ -110,6 +122,9 @@ public final int readInteger() { @Override public final long readLong() { long v = Platform.getLong(buffer, offset); + if (bigEndianPlatform) { + v = java.lang.Long.reverseBytes(v); + } offset += 8; return v; } @@ -121,14 +136,24 @@ public final byte readByte() { @Override public final float readFloat() { - float v = Platform.getFloat(buffer, offset); + float v; + if (!bigEndianPlatform) { + v = Platform.getFloat(buffer, offset); + } else { + v = byteBuffer.getFloat(offset - Platform.BYTE_ARRAY_OFFSET); + } offset += 4; return v; } @Override public final double readDouble() { - double v = Platform.getDouble(buffer, offset); + double v; + if (!bigEndianPlatform) { + v = Platform.getDouble(buffer, offset); + } else { + v = byteBuffer.getDouble(offset - Platform.BYTE_ARRAY_OFFSET); + } offset += 8; return v; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java index 69ce54390fead..25a565d32638d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java @@ -30,7 +30,7 @@ * This is an illustrative implementation of an append-only single-key/single value aggregate hash * map that can act as a 'cache' for extremely fast key-value lookups while evaluating aggregates * (and fall back to the `BytesToBytesMap` if a given key isn't found). This can be potentially - * 'codegened' in TungstenAggregate to speed up aggregates w/ key. + * 'codegened' in HashAggregate to speed up aggregates w/ key. * * It is backed by a power-of-2-sized array for index lookups and a columnar batch that stores the * key-value pairs. 
The index lookups in the array rely on linear probing (with a small number of diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index e7dccd1b2261f..9b0ed802fc9d2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -19,7 +19,7 @@ import java.math.BigDecimal; import java.math.BigInteger; -import org.apache.commons.lang.NotImplementedException; +import com.google.common.annotations.VisibleForTesting; import org.apache.parquet.column.Dictionary; import org.apache.parquet.io.api.Binary; @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.util.ArrayData; import org.apache.spark.sql.catalyst.util.MapData; +import org.apache.spark.sql.internal.SQLConf; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; @@ -98,7 +99,7 @@ protected Array(ColumnVector data) { @Override public ArrayData copy() { - throw new NotImplementedException(); + throw new UnsupportedOperationException(); } // TODO: this is extremely expensive. @@ -169,7 +170,7 @@ public Object[] array() { } } } else { - throw new NotImplementedException("Type " + dt); + throw new UnsupportedOperationException("Type " + dt); } return list; } @@ -179,7 +180,7 @@ public Object[] array() { @Override public boolean getBoolean(int ordinal) { - throw new NotImplementedException(); + throw new UnsupportedOperationException(); } @Override @@ -187,7 +188,7 @@ public boolean getBoolean(int ordinal) { @Override public short getShort(int ordinal) { - throw new NotImplementedException(); + throw new UnsupportedOperationException(); } @Override @@ -198,7 +199,7 @@ public short getShort(int ordinal) { @Override public float getFloat(int ordinal) { - throw new NotImplementedException(); + throw new UnsupportedOperationException(); } @Override @@ -238,12 +239,12 @@ public ArrayData getArray(int ordinal) { @Override public MapData getMap(int ordinal) { - throw new NotImplementedException(); + throw new UnsupportedOperationException(); } @Override public Object get(int ordinal, DataType dataType) { - throw new NotImplementedException(); + throw new UnsupportedOperationException(); } } @@ -277,11 +278,39 @@ public void reset() { */ public abstract void close(); - /* + public void reserve(int requiredCapacity) { + if (requiredCapacity > capacity) { + int newCapacity = (int) Math.min(MAX_CAPACITY, requiredCapacity * 2L); + if (requiredCapacity <= newCapacity) { + try { + reserveInternal(newCapacity); + } catch (OutOfMemoryError outOfMemoryError) { + throwUnsupportedException(requiredCapacity, outOfMemoryError); + } + } else { + throwUnsupportedException(requiredCapacity, null); + } + } + } + + private void throwUnsupportedException(int requiredCapacity, Throwable cause) { + String message = "Cannot reserve additional contiguous bytes in the vectorized reader " + + "(requested = " + requiredCapacity + " bytes). As a workaround, you can disable the " + + "vectorized reader by setting " + SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key() + + " to false."; + + if (cause != null) { + throw new RuntimeException(message, cause); + } else { + throw new RuntimeException(message); + } + } + + /** * Ensures that there is enough storage to store capcity elements. 
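Aside (not part of the patch): the new `reserve(requiredCapacity)` above centralizes the growth policy for column vectors — grow to twice the requested capacity, clamp at `MAX_CAPACITY`, and fail with a message that points at disabling the vectorized reader as a workaround. A minimal sketch of that policy under hypothetical names (not Spark's class):

```scala
// Sketch of the reserve() growth policy: geometric growth, clamped, with an actionable error.
object ReservePolicy {
  val MaxCapacity: Int = Int.MaxValue

  def newCapacity(current: Int, required: Int): Int = {
    if (required <= current) current
    else {
      val doubled = math.min(MaxCapacity.toLong, required.toLong * 2L).toInt
      if (required <= doubled) doubled
      else throw new RuntimeException(
        s"Cannot reserve additional contiguous bytes (requested = $required); " +
          "as a workaround, disable the vectorized reader.")
    }
  }

  def main(args: Array[String]): Unit = {
    println(newCapacity(1024, 4096)) // 8192
    println(newCapacity(1024, 512))  // 1024, no growth needed
  }
}
```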
That is, the put() APIs * must work for all rowIds < capcity. */ - public abstract void reserve(int capacity); + protected abstract void reserveInternal(int capacity); /** * Returns the number of nulls in this column. @@ -504,7 +533,7 @@ public ColumnarBatch.Row getStruct(int rowId) { /** * Returns a utility object to get structs. - * provided to keep API compabilitity with InternalRow for code generation + * provided to keep API compatibility with InternalRow for code generation */ public ColumnarBatch.Row getStruct(int rowId, int size) { resultStruct.rowId = rowId; @@ -546,7 +575,7 @@ private Array getByteArray(int rowId) { * Returns the value for rowId. */ public MapData getMap(int ordinal) { - throw new NotImplementedException(); + throw new UnsupportedOperationException(); } /** @@ -846,6 +875,12 @@ public final int appendStruct(boolean isNull) { */ protected int capacity; + /** + * Upper limit for the maximum capacity for this column. + */ + @VisibleForTesting + protected int MAX_CAPACITY = Integer.MAX_VALUE; + /** * Data type for this column. */ diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java index f50c35fc64a75..900d7c431e723 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java @@ -23,8 +23,6 @@ import java.util.Iterator; import java.util.List; -import org.apache.commons.lang.NotImplementedException; - import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.Row; import org.apache.spark.sql.catalyst.InternalRow; @@ -88,8 +86,9 @@ public static void populate(ColumnVector col, InternalRow row, int fieldIdx) { col.getChildColumn(0).putInts(0, capacity, c.months); col.getChildColumn(1).putLongs(0, capacity, c.microseconds); } else if (t instanceof DateType) { - Date date = (Date)row.get(fieldIdx, t); - col.putInts(0, capacity, DateTimeUtils.fromJavaDate(date)); + col.putInts(0, capacity, row.getInt(fieldIdx)); + } else if (t instanceof TimestampType) { + col.putLongs(0, capacity, row.getLong(fieldIdx)); } } } @@ -112,7 +111,7 @@ public static Object toPrimitiveJavaArray(ColumnVector.Array array) { } return result; } else { - throw new NotImplementedException(); + throw new UnsupportedOperationException(); } } @@ -161,7 +160,7 @@ private static void appendValue(ColumnVector dst, DataType t, Object o) { } else if (t instanceof DateType) { dst.appendInt(DateTimeUtils.fromJavaDate((Date)o)); } else { - throw new NotImplementedException("Type " + t); + throw new UnsupportedOperationException("Type " + t); } } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java index 8cece73faa4b9..62abc2a821a3a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java @@ -19,8 +19,6 @@ import java.math.BigDecimal; import java.util.*; -import org.apache.commons.lang.NotImplementedException; - import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericMutableRow; @@ -139,6 +137,10 @@ public InternalRow copy() { DataType dt = columns[i].dataType(); if (dt instanceof 
BooleanType) { row.setBoolean(i, getBoolean(i)); + } else if (dt instanceof ByteType) { + row.setByte(i, getByte(i)); + } else if (dt instanceof ShortType) { + row.setShort(i, getShort(i)); } else if (dt instanceof IntegerType) { row.setInt(i, getInt(i)); } else if (dt instanceof LongType) { @@ -156,6 +158,8 @@ public InternalRow copy() { row.setDecimal(i, getDecimal(i, t.precision(), t.scale()), t.precision()); } else if (dt instanceof DateType) { row.setInt(i, getInt(i)); + } else if (dt instanceof TimestampType) { + row.setLong(i, getLong(i)); } else { throw new RuntimeException("Not implemented. " + dt); } @@ -166,7 +170,7 @@ public InternalRow copy() { @Override public boolean anyNull() { - throw new NotImplementedException(); + throw new UnsupportedOperationException(); } @Override @@ -227,12 +231,12 @@ public ArrayData getArray(int ordinal) { @Override public MapData getMap(int ordinal) { - throw new NotImplementedException(); + throw new UnsupportedOperationException(); } @Override public Object get(int ordinal, DataType dataType) { - throw new NotImplementedException(); + throw new UnsupportedOperationException(); } @Override @@ -258,7 +262,7 @@ public void update(int ordinal, Object value) { setDecimal(ordinal, Decimal.apply((BigDecimal) value, t.precision(), t.scale()), t.precision()); } else { - throw new NotImplementedException("Datatype not supported " + dt); + throw new UnsupportedOperationException("Datatype not supported " + dt); } } } @@ -430,7 +434,7 @@ public int numValidRows() { */ public void setColumn(int ordinal, ColumnVector column) { if (column instanceof OffHeapColumnVector) { - throw new NotImplementedException("Need to ref count columns."); + throw new UnsupportedOperationException("Need to ref count columns."); } columns[ordinal] = column; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index b1901411351a2..913a05a0aa0ec 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -16,10 +16,9 @@ */ package org.apache.spark.sql.execution.vectorized; +import java.nio.ByteBuffer; import java.nio.ByteOrder; -import org.apache.commons.lang.NotImplementedException; - import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.Platform; @@ -28,6 +27,10 @@ * Column data backed using offheap memory. */ public final class OffHeapColumnVector extends ColumnVector { + + private static final boolean bigEndianPlatform = + ByteOrder.nativeOrder().equals(ByteOrder.BIG_ENDIAN); + // The data stored in these two allocations need to maintain binary compatible. We can // directly pass this buffer to external components. 
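Aside (not part of the patch): the `populate` and row-copy changes above work because Catalyst's internal representation keeps DateType as an Int (days since 1970-01-01) and TimestampType as a Long (microseconds since the epoch), so the columnar code can use putInts/putLongs and setInt/setLong directly. A simplified, UTC-only illustration; Spark's own DateTimeUtils additionally handles time zones:

```scala
import java.util.concurrent.TimeUnit

// Illustrative only: internal Int days / Long microseconds representation.
object InternalDateTimeRepr {
  def toInternalDate(d: java.sql.Date): Int =
    TimeUnit.MILLISECONDS.toDays(d.getTime).toInt

  def toInternalTimestamp(t: java.sql.Timestamp): Long =
    t.getTime * 1000L + (t.getNanos % 1000000) / 1000

  def main(args: Array[String]): Unit = {
    println(toInternalDate(new java.sql.Date(86400000L)))       // 1 (one day after the epoch)
    println(toInternalTimestamp(new java.sql.Timestamp(1500L))) // 1500000 microseconds
  }
}
```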
private long nulls; @@ -39,9 +42,7 @@ public final class OffHeapColumnVector extends ColumnVector { protected OffHeapColumnVector(int capacity, DataType type) { super(capacity, type, MemoryMode.OFF_HEAP); - if (!ByteOrder.nativeOrder().equals(ByteOrder.LITTLE_ENDIAN)) { - throw new NotImplementedException("Only little endian is supported."); - } + nulls = 0; data = 0; lengthData = 0; @@ -221,8 +222,17 @@ public void putInts(int rowId, int count, int[] src, int srcIndex) { @Override public void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { - Platform.copyMemory(src, srcIndex + Platform.BYTE_ARRAY_OFFSET, - null, data + 4 * rowId, count * 4); + if (!bigEndianPlatform) { + Platform.copyMemory(src, srcIndex + Platform.BYTE_ARRAY_OFFSET, + null, data + 4 * rowId, count * 4); + } else { + int srcOffset = srcIndex + Platform.BYTE_ARRAY_OFFSET; + long offset = data + 4 * rowId; + for (int i = 0; i < count; ++i, offset += 4, srcOffset += 4) { + Platform.putInt(null, offset, + java.lang.Integer.reverseBytes(Platform.getInt(src, srcOffset))); + } + } } @Override @@ -259,8 +269,17 @@ public void putLongs(int rowId, int count, long[] src, int srcIndex) { @Override public void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { - Platform.copyMemory(src, srcIndex + Platform.BYTE_ARRAY_OFFSET, - null, data + 8 * rowId, count * 8); + if (!bigEndianPlatform) { + Platform.copyMemory(src, srcIndex + Platform.BYTE_ARRAY_OFFSET, + null, data + 8 * rowId, count * 8); + } else { + int srcOffset = srcIndex + Platform.BYTE_ARRAY_OFFSET; + long offset = data + 8 * rowId; + for (int i = 0; i < count; ++i, offset += 8, srcOffset += 8) { + Platform.putLong(null, offset, + java.lang.Long.reverseBytes(Platform.getLong(src, srcOffset))); + } + } } @Override @@ -297,8 +316,16 @@ public void putFloats(int rowId, int count, float[] src, int srcIndex) { @Override public void putFloats(int rowId, int count, byte[] src, int srcIndex) { - Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, - null, data + rowId * 4, count * 4); + if (!bigEndianPlatform) { + Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, + null, data + rowId * 4, count * 4); + } else { + ByteBuffer bb = ByteBuffer.wrap(src).order(ByteOrder.LITTLE_ENDIAN); + long offset = data + 4 * rowId; + for (int i = 0; i < count; ++i, offset += 4) { + Platform.putFloat(null, offset, bb.getFloat(srcIndex + (4 * i))); + } + } } @Override @@ -336,8 +363,16 @@ public void putDoubles(int rowId, int count, double[] src, int srcIndex) { @Override public void putDoubles(int rowId, int count, byte[] src, int srcIndex) { - Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, + if (!bigEndianPlatform) { + Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, null, data + rowId * 8, count * 8); + } else { + ByteBuffer bb = ByteBuffer.wrap(src).order(ByteOrder.LITTLE_ENDIAN); + long offset = data + 8 * rowId; + for (int i = 0; i < count; ++i, offset += 8) { + Platform.putDouble(null, offset, bb.getDouble(srcIndex + (8 * i))); + } + } } @Override @@ -387,13 +422,9 @@ public void loadBytes(ColumnVector.Array array) { array.byteArrayOffset = 0; } - @Override - public void reserve(int requiredCapacity) { - if (requiredCapacity > capacity) reserveInternal(requiredCapacity * 2); - } - // Split out the slow path. 
- private void reserveInternal(int newCapacity) { + @Override + protected void reserveInternal(int newCapacity) { if (this.resultArray != null) { this.lengthData = Platform.reallocateMemory(lengthData, elementsAppended * 4, newCapacity * 4); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index e97276800daa8..85067df4ebf91 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -16,6 +16,8 @@ */ package org.apache.spark.sql.execution.vectorized; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; import java.util.Arrays; import org.apache.spark.memory.MemoryMode; @@ -27,6 +29,10 @@ * and a java array for the values. */ public final class OnHeapColumnVector extends ColumnVector { + + private static final boolean bigEndianPlatform = + ByteOrder.nativeOrder().equals(ByteOrder.BIG_ENDIAN); + // The data stored in these arrays need to maintain binary compatible. We can // directly pass this buffer to external components. @@ -211,10 +217,11 @@ public void putInts(int rowId, int count, int[] src, int srcIndex) { @Override public void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { int srcOffset = srcIndex + Platform.BYTE_ARRAY_OFFSET; - for (int i = 0; i < count; ++i) { + for (int i = 0; i < count; ++i, srcOffset += 4) { intData[i + rowId] = Platform.getInt(src, srcOffset); - srcIndex += 4; - srcOffset += 4; + if (bigEndianPlatform) { + intData[i + rowId] = java.lang.Integer.reverseBytes(intData[i + rowId]); + } } } @@ -251,10 +258,11 @@ public void putLongs(int rowId, int count, long[] src, int srcIndex) { @Override public void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { int srcOffset = srcIndex + Platform.BYTE_ARRAY_OFFSET; - for (int i = 0; i < count; ++i) { + for (int i = 0; i < count; ++i, srcOffset += 8) { longData[i + rowId] = Platform.getLong(src, srcOffset); - srcIndex += 8; - srcOffset += 8; + if (bigEndianPlatform) { + longData[i + rowId] = java.lang.Long.reverseBytes(longData[i + rowId]); + } } } @@ -286,8 +294,15 @@ public void putFloats(int rowId, int count, float[] src, int srcIndex) { @Override public void putFloats(int rowId, int count, byte[] src, int srcIndex) { - Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, - floatData, Platform.DOUBLE_ARRAY_OFFSET + rowId * 4, count * 4); + if (!bigEndianPlatform) { + Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, floatData, + Platform.DOUBLE_ARRAY_OFFSET + rowId * 4, count * 4); + } else { + ByteBuffer bb = ByteBuffer.wrap(src).order(ByteOrder.LITTLE_ENDIAN); + for (int i = 0; i < count; ++i) { + floatData[i + rowId] = bb.getFloat(srcIndex + (4 * i)); + } + } } @Override @@ -320,8 +335,15 @@ public void putDoubles(int rowId, int count, double[] src, int srcIndex) { @Override public void putDoubles(int rowId, int count, byte[] src, int srcIndex) { - Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, doubleData, - Platform.DOUBLE_ARRAY_OFFSET + rowId * 8, count * 8); + if (!bigEndianPlatform) { + Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, doubleData, + Platform.DOUBLE_ARRAY_OFFSET + rowId * 8, count * 8); + } else { + ByteBuffer bb = ByteBuffer.wrap(src).order(ByteOrder.LITTLE_ENDIAN); + for (int i = 0; i < count; ++i) { + doubleData[i + 
rowId] = bb.getDouble(srcIndex + (8 * i)); + } + } } @Override @@ -370,13 +392,9 @@ public int putByteArray(int rowId, byte[] value, int offset, int length) { return result; } - @Override - public void reserve(int requiredCapacity) { - if (requiredCapacity > capacity) reserveInternal(requiredCapacity * 2); - } - // Spilt this function out since it is the slow path. - private void reserveInternal(int newCapacity) { + @Override + protected void reserveInternal(int newCapacity) { if (this.resultArray != null || DecimalType.isByteArrayDecimalType(type)) { int[] newLengths = new int[newCapacity]; int[] newOffsets = new int[newCapacity]; diff --git a/sql/core/src/main/java/org/apache/spark/sql/expressions/java/typed.java b/sql/core/src/main/java/org/apache/spark/sql/expressions/javalang/typed.java similarity index 97% rename from sql/core/src/main/java/org/apache/spark/sql/expressions/java/typed.java rename to sql/core/src/main/java/org/apache/spark/sql/expressions/javalang/typed.java index c7c6e3868f9bb..247e94b86c349 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/expressions/java/typed.java +++ b/sql/core/src/main/java/org/apache/spark/sql/expressions/javalang/typed.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.expressions.java; +package org.apache.spark.sql.expressions.javalang; import org.apache.spark.annotation.Experimental; import org.apache.spark.api.java.function.MapFunction; @@ -29,7 +29,7 @@ * :: Experimental :: * Type-safe functions available for {@link org.apache.spark.sql.Dataset} operations in Java. * - * Scala users should use {@link org.apache.spark.sql.expressions.scala.typed}. + * Scala users should use {@link org.apache.spark.sql.expressions.scalalang.typed}. * * @since 2.0.0 */ diff --git a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index 226d59d0eae88..27d32b5dca431 100644 --- a/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -1,5 +1,7 @@ -org.apache.spark.sql.execution.datasources.csv.DefaultSource -org.apache.spark.sql.execution.datasources.jdbc.DefaultSource -org.apache.spark.sql.execution.datasources.json.DefaultSource -org.apache.spark.sql.execution.datasources.parquet.DefaultSource -org.apache.spark.sql.execution.datasources.text.DefaultSource +org.apache.spark.sql.execution.datasources.csv.CSVFileFormat +org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider +org.apache.spark.sql.execution.datasources.json.JsonFileFormat +org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat +org.apache.spark.sql.execution.datasources.text.TextFileFormat +org.apache.spark.sql.execution.streaming.ConsoleSinkProvider +org.apache.spark.sql.execution.streaming.TextSocketSourceProvider diff --git a/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.css b/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.css index 303f8ebb8814c..594e747a8d3a5 100644 --- a/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.css +++ b/sql/core/src/main/resources/org/apache/spark/sql/execution/ui/static/spark-sql-viz.css @@ -41,3 +41,8 @@ stroke: #444; stroke-width: 1.5px; } + +/* Breaks the long string like file path when showing tooltips */ 
+.tooltip-inner { + word-wrap:break-word; +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index c58addaf903fd..a46d1949e94ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -24,9 +24,11 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.util.usePrettyExpression import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression +import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions.lit import org.apache.spark.sql.types._ @@ -37,6 +39,14 @@ private[sql] object Column { def apply(expr: Expression): Column = new Column(expr) def unapply(col: Column): Option[Expression] = Some(col.expr) + + private[sql] def generateAlias(e: Expression): String = { + e match { + case a: AggregateExpression if a.aggregateFunction.isInstanceOf[TypedAggregateExpression] => + a.aggregateFunction.toString + case expr => usePrettyExpression(expr).sql + } + } } /** @@ -68,10 +78,21 @@ class TypedColumn[-T, U]( } new TypedColumn[T, U](newExpr, encoder) } + + /** + * Gives the TypedColumn a name (alias). + * If the current TypedColumn has metadata associated with it, this metadata will be propagated + * to the new column. + * + * @group expr_ops + * @since 2.0.0 + */ + override def name(alias: String): TypedColumn[T, U] = + new TypedColumn[T, U](super.name(alias).expr, encoder) + } /** - * :: Experimental :: * A column that will be computed based on the data in a [[DataFrame]]. * * A new column is constructed based on the input columns present in a dataframe: @@ -100,7 +121,6 @@ class TypedColumn[-T, U]( * * @since 1.3.0 */ -@Experimental class Column(protected[sql] val expr: Expression) extends Logging { def this(name: String) = this(name match { @@ -111,6 +131,15 @@ class Column(protected[sql] val expr: Expression) extends Logging { case _ => UnresolvedAttribute.quotedString(name) }) + override def toString: String = usePrettyExpression(expr).sql + + override def equals(that: Any): Boolean = that match { + case that: Column => that.expr.equals(this.expr) + case _ => false + } + + override def hashCode: Int = this.expr.hashCode() + /** Creates a column based on the given expression. */ private def withExpr(newExpr: Expression): Column = new Column(newExpr) @@ -130,32 +159,28 @@ class Column(protected[sql] val expr: Expression) extends Logging { // Leave an unaliased generator with an empty list of names since the analyzer will generate // the correct defaults after the nested expression's type has been resolved. case explode: Explode => MultiAlias(explode, Nil) + case explode: PosExplode => MultiAlias(explode, Nil) case jt: JsonTuple => MultiAlias(jt, Nil) - case func: UnresolvedFunction => UnresolvedAlias(func, Some(usePrettyExpression(func).sql)) + case func: UnresolvedFunction => UnresolvedAlias(func, Some(Column.generateAlias)) // If we have a top level Cast, there is a chance to give it a better alias, if there is a // NamedExpression under this Cast. 
- case c: Cast => c.transformUp { - case Cast(ne: NamedExpression, to) => UnresolvedAlias(Cast(ne, to)) - } match { - case ne: NamedExpression => ne - case other => Alias(expr, usePrettyExpression(expr).sql)() - } - - case expr: Expression => Alias(expr, usePrettyExpression(expr).sql)() - } + case c: Cast => + c.transformUp { + case Cast(ne: NamedExpression, to) => UnresolvedAlias(Cast(ne, to)) + } match { + case ne: NamedExpression => ne + case other => Alias(expr, usePrettyExpression(expr).sql)() + } - override def toString: String = usePrettyExpression(expr).sql + case a: AggregateExpression if a.aggregateFunction.isInstanceOf[TypedAggregateExpression] => + UnresolvedAlias(a, Some(Column.generateAlias)) - override def equals(that: Any): Boolean = that match { - case that: Column => that.expr.equals(this.expr) - case _ => false + case expr: Expression => Alias(expr, usePrettyExpression(expr).sql)() } - override def hashCode: Int = this.expr.hashCode - /** * Provides a type hint about the expected return value of this column. This information can * be used by operations such as `select` on a [[Dataset]] to automatically convert the @@ -910,12 +935,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def as(alias: Symbol): Column = withExpr { - expr match { - case ne: NamedExpression => Alias(expr, alias.name)(explicitMetadata = Some(ne.metadata)) - case other => Alias(other, alias.name)() - } - } + def as(alias: Symbol): Column = name(alias.name) /** * Gives the column an alias with metadata. @@ -1076,6 +1096,22 @@ class Column(protected[sql] val expr: Expression) extends Logging { */ def over(window: expressions.WindowSpec): Column = window.withAggregate(this) + /** + * Define a empty analytic clause. In this case the analytic function is applied + * and presented for all rows in the result set. + * + * {{{ + * df.select( + * sum("price").over(), + * avg("price").over() + * ) + * }}} + * + * @group expr_ops + * @since 2.0.0 + */ + def over(): Column = over(Window.spec) + } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 3d43f2022f669..fe3da25a4c4e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -22,7 +22,6 @@ import java.util.Properties import scala.collection.JavaConverters._ import org.apache.spark.Partition -import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD @@ -30,17 +29,14 @@ import org.apache.spark.sql.execution.LogicalRDD import org.apache.spark.sql.execution.datasources.{DataSource, LogicalRelation} import org.apache.spark.sql.execution.datasources.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation} import org.apache.spark.sql.execution.datasources.json.{InferSchema, JacksonParser, JSONOptions} -import org.apache.spark.sql.execution.streaming.StreamingRelation import org.apache.spark.sql.types.StructType /** - * :: Experimental :: - * Interface used to load a [[DataFrame]] from external storage systems (e.g. file systems, - * key-value stores, etc) or data streams. Use [[SparkSession.read]] to access this. + * Interface used to load a [[Dataset]] from external storage systems (e.g. file systems, + * key-value stores, etc). Use [[SparkSession.read]] to access this. 
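Aside (not part of the patch): the new no-argument `over()` added above applies an aggregate over an empty window specification, i.e. the aggregate is computed over the whole result set and repeated on every row. A small hypothetical usage sketch (local SparkSession and made-up data, not taken from the patch):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

object EmptyOverExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("over-demo").getOrCreate()
    import spark.implicits._

    val df = Seq(("a", 10.0), ("b", 20.0), ("c", 30.0)).toDF("item", "price")

    // total and mean are the same on every row: the window covers the whole result set.
    df.select($"item", $"price", sum("price").over().as("total"), avg("price").over().as("mean"))
      .show()

    spark.stop()
  }
}
```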
* * @since 1.4.0 */ -@Experimental class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { /** @@ -123,13 +119,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * @since 1.4.0 */ def load(): DataFrame = { - val dataSource = - DataSource( - sparkSession, - userSpecifiedSchema = userSpecifiedSchema, - className = source, - options = extraOptions.toMap) - Dataset.ofRows(sparkSession, LogicalRelation(dataSource.resolveRelation())) + load(Seq.empty: _*) // force invocation of `load(...varargs...)` } /** @@ -139,7 +129,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * @since 1.4.0 */ def load(path: String): DataFrame = { - option("path", path).load() + option("path", path).load(Seq.empty: _*) // force invocation of `load(...varargs...)` } /** @@ -150,42 +140,13 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { */ @scala.annotation.varargs def load(paths: String*): DataFrame = { - if (paths.isEmpty) { - sparkSession.emptyDataFrame - } else { - sparkSession.baseRelationToDataFrame( - DataSource.apply( - sparkSession, - paths = paths, - userSpecifiedSchema = userSpecifiedSchema, - className = source, - options = extraOptions.toMap).resolveRelation()) - } - } - - /** - * Loads input data stream in as a [[DataFrame]], for data streams that don't require a path - * (e.g. external key-value stores). - * - * @since 2.0.0 - */ - def stream(): DataFrame = { - val dataSource = - DataSource( + sparkSession.baseRelationToDataFrame( + DataSource.apply( sparkSession, + paths = paths, userSpecifiedSchema = userSpecifiedSchema, className = source, - options = extraOptions.toMap) - Dataset.ofRows(sparkSession, StreamingRelation(dataSource)) - } - - /** - * Loads input in as a [[DataFrame]], for data streams that read from some path. - * - * @since 2.0.0 - */ - def stream(path: String): DataFrame = { - option("path", path).stream() + options = extraOptions.toMap).resolveRelation()) } /** @@ -216,7 +177,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * clause expressions used to split the column `columnName` evenly. * @param connectionProperties JDBC database connection arguments, a list of arbitrary string * tag/value. Normally at least a "user" and "password" property - * should be included. + * should be included. "fetchsize" can be used to control the + * number of rows per fetch. * @since 1.4.0 */ def jdbc( @@ -246,7 +208,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * @param predicates Condition in the where clause for each partition. * @param connectionProperties JDBC database connection arguments, a list of arbitrary string * tag/value. Normally at least a "user" and "password" property - * should be included. + * should be included. "fetchsize" can be used to control the + * number of rows per fetch. * @since 1.4.0 */ def jdbc( @@ -277,35 +240,14 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { /** * Loads a JSON file (one object per line) and returns the result as a [[DataFrame]]. - * - * This function goes through the input once to determine the input schema. If you know the - * schema in advance, use the version that specifies the schema to avoid the extra scan. - * - * You can set the following JSON-specific options to deal with non-standard JSON files: - *
- *   • `primitivesAsString` (default `false`): infers all primitive values as a string type
- *   • `allowComments` (default `false`): ignores Java/C++ style comment in JSON records
- *   • `allowUnquotedFieldNames` (default `false`): allows unquoted JSON field names
- *   • `allowSingleQuotes` (default `true`): allows single quotes in addition to double quotes
- *   • `allowNumericLeadingZeros` (default `false`): allows leading zeros in numbers (e.g. 00012)
- *   • `mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records during parsing.
- *     • `PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts the malformed string into a new field configured by `columnNameOfCorruptRecord`. When a schema is set by user, it sets `null` for extra fields.
- *     • `DROPMALFORMED` : ignores the whole corrupted records.
- *     • `FAILFAST` : throws an exception when it meets corrupted records.
- *   • `columnNameOfCorruptRecord` (default `_corrupt_record`): allows renaming the new field having malformed string created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.
  • + * See the documentation on the overloaded `json()` method with varargs for more details. * * @since 1.4.0 */ - // TODO: Remove this one in Spark 2.0. - def json(path: String): DataFrame = format("json").load(path) + def json(path: String): DataFrame = { + // This method ensures that calls that explicit need single argument works, see SPARK-16009 + json(Seq(path): _*) + } /** * Loads a JSON file (one object per line) and returns the result as a [[DataFrame]]. @@ -314,6 +256,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * schema in advance, use the version that specifies the schema to avoid the extra scan. * * You can set the following JSON-specific options to deal with non-standard JSON files: + *
 *   • `primitivesAsString` (default `false`): infers all primitive values as a string type
 *   • `prefersDecimal` (default `false`): infers all floating-point values as a decimal type. If the values do not fit in decimal, then it infers them as doubles.
@@ -326,24 +269,32 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
 *
 *   • `allowBackslashEscapingAnyCharacter` (default `false`): allows accepting quoting of all character using backslash quoting mechanism
 *   • `mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records
- * during parsing.
- *     • `PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts the malformed string into a new field configured by `columnNameOfCorruptRecord`. When a schema is set by user, it sets `null` for extra fields.
- *     • `DROPMALFORMED` : ignores the whole corrupted records.
- *     • `FAILFAST` : throws an exception when it meets corrupted records.
+ * during parsing.
+ *     • `PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts the malformed string into a new field configured by `columnNameOfCorruptRecord`. When a schema is set by user, it sets `null` for extra fields.
+ *     • `DROPMALFORMED` : ignores the whole corrupted records.
+ *     • `FAILFAST` : throws an exception when it meets corrupted records.
+ *   • `columnNameOfCorruptRecord` (default is the value specified in `spark.sql.columnNameOfCorruptRecord`): allows renaming the new field having malformed string created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.
+ *   • `dateFormat` (default `yyyy-MM-dd`): sets the string that indicates a date format. Custom date formats follow the formats at `java.text.SimpleDateFormat`. This applies to date type.
+ *   • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSZZ`): sets the string that indicates a timestamp format. Custom date formats follow the formats at `java.text.SimpleDateFormat`. This applies to timestamp type.
- *   • `columnNameOfCorruptRecord` (default `_corrupt_record`): allows renaming the new field having malformed string created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.
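Aside (not part of the patch): a hedged usage sketch of the JSON reader options listed above. The path, schema and option values are placeholders, not taken from the patch:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types._

object JsonReadExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("json-demo").getOrCreate()

    val schema = new StructType()
      .add("name", StringType)
      .add("age", IntegerType)
      .add("_corrupt_record", StringType) // populated in PERMISSIVE mode

    val df = spark.read
      .schema(schema)                                   // skip the schema-inference pass
      .option("mode", "PERMISSIVE")                     // keep corrupt records
      .option("columnNameOfCorruptRecord", "_corrupt_record")
      .option("dateFormat", "yyyy-MM-dd")
      .json("/path/to/data.json")                       // placeholder path

    df.printSchema()
    spark.stop()
  }
}
```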
    • - * - * @since 1.6.0 + * @since 2.0.0 */ + @scala.annotation.varargs def json(paths: String*): DataFrame = format("json").load(paths : _*) /** - * Loads an `JavaRDD[String]` storing JSON objects (one object per record) and + * Loads a `JavaRDD[String]` storing JSON objects (one object per record) and * returns the result as a [[DataFrame]]. * * Unless the schema is specified using [[schema]] function, this function goes through the @@ -387,21 +338,100 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { parsedOptions))(sparkSession)) } + /** + * Loads a CSV file and returns the result as a [[DataFrame]]. See the documentation on the + * other overloaded `csv()` method for more details. + * + * @since 2.0.0 + */ + def csv(path: String): DataFrame = { + // This method ensures that calls that explicit need single argument works, see SPARK-16009 + csv(Seq(path): _*) + } + /** * Loads a CSV file and returns the result as a [[DataFrame]]. * - * This function goes through the input once to determine the input schema. To avoid going - * through the entire data once, specify the schema explicitly using [[schema]]. + * This function will go through the input once to determine the input schema if `inferSchema` + * is enabled. To avoid going through the entire data once, disable `inferSchema` option or + * specify the schema explicitly using [[schema]]. * + * You can set the following CSV-specific options to deal with CSV files: + *
+ *   • `sep` (default `,`): sets the single character as a separator for each field and value.
+ *   • `encoding` (default `UTF-8`): decodes the CSV files by the given encoding type.
+ *   • `quote` (default `"`): sets the single character used for escaping quoted values where the separator can be part of the value. If you would like to turn off quotations, you need to set not `null` but an empty string. This behaviour is different from `com.databricks.spark.csv`.
+ *   • `escape` (default `\`): sets the single character used for escaping quotes inside an already quoted value.
+ *   • `comment` (default empty string): sets the single character used for skipping lines beginning with this character. By default, it is disabled.
+ *   • `header` (default `false`): uses the first line as names of columns.
+ *   • `inferSchema` (default `false`): infers the input schema automatically from data. It requires one extra pass over the data.
+ *   • `ignoreLeadingWhiteSpace` (default `false`): defines whether or not leading whitespaces from values being read should be skipped.
+ *   • `ignoreTrailingWhiteSpace` (default `false`): defines whether or not trailing whitespaces from values being read should be skipped.
+ *   • `nullValue` (default empty string): sets the string representation of a null value. Since 2.0.1, this applies to all supported types including the string type.
+ *   • `nanValue` (default `NaN`): sets the string representation of a non-number value.
+ *   • `positiveInf` (default `Inf`): sets the string representation of a positive infinity value.
+ *   • `negativeInf` (default `-Inf`): sets the string representation of a negative infinity value.
+ *   • `dateFormat` (default `yyyy-MM-dd`): sets the string that indicates a date format. Custom date formats follow the formats at `java.text.SimpleDateFormat`. This applies to date type.
+ *   • `timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSZZ`): sets the string that indicates a timestamp format. Custom date formats follow the formats at `java.text.SimpleDateFormat`. This applies to timestamp type. `java.sql.Timestamp.valueOf()` and `java.sql.Date.valueOf()` or ISO 8601 format.
+ *   • `maxColumns` (default `20480`): defines a hard limit of how many columns a record can have.
+ *   • `maxCharsPerColumn` (default `1000000`): defines the maximum number of characters allowed for any given value being read.
+ *   • `maxMalformedLogPerPartition` (default `10`): sets the maximum number of malformed rows Spark will log for each partition. Malformed records beyond this number will be ignored.
+ *   • `mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records during parsing.
+ *     • `PERMISSIVE` : sets other fields to `null` when it meets a corrupted record. When a schema is set by user, it sets `null` for extra fields.
+ *     • `DROPMALFORMED` : ignores the whole corrupted records.
+ *     • `FAILFAST` : throws an exception when it meets corrupted records.
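Aside (not part of the patch): a hedged usage sketch combining several of the CSV options documented above. The path is a placeholder and the option values are illustrative:

```scala
import org.apache.spark.sql.SparkSession

object CsvReadExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("csv-demo").getOrCreate()

    val df = spark.read
      .option("header", "true")        // first line holds column names
      .option("inferSchema", "true")   // one extra pass over the data to infer types
      .option("sep", ",")
      .option("nullValue", "NA")
      .option("mode", "DROPMALFORMED") // silently drop corrupt records
      .csv("/path/to/data.csv")        // placeholder path

    df.printSchema()
    spark.stop()
  }
}
```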
      * @since 2.0.0 */ @scala.annotation.varargs def csv(paths: String*): DataFrame = format("csv").load(paths : _*) /** - * Loads a Parquet file, returning the result as a [[DataFrame]]. This function returns an empty - * [[DataFrame]] if no paths are passed in. + * Loads a Parquet file, returning the result as a [[DataFrame]]. See the documentation + * on the other overloaded `parquet()` method for more details. + * + * @since 2.0.0 + */ + def parquet(path: String): DataFrame = { + // This method ensures that calls that explicit need single argument works, see SPARK-16009 + parquet(Seq(path): _*) + } + + /** + * Loads a Parquet file, returning the result as a [[DataFrame]]. * + * You can set the following Parquet-specific option(s) for reading Parquet files: + *
+ *   • `mergeSchema` (default is the value specified in `spark.sql.parquet.mergeSchema`): sets whether we should merge schemas collected from all Parquet part-files. This will override `spark.sql.parquet.mergeSchema`.
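Aside (not part of the patch): a sketch of the `mergeSchema` option in use. Paths are placeholders, and the two writes exist only to produce differing schemas under a common directory:

```scala
import org.apache.spark.sql.SparkSession

object ParquetMergeSchemaExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("parquet-demo").getOrCreate()
    import spark.implicits._

    // Two part directories with different column sets.
    spark.range(5).select($"id").write.mode("overwrite").parquet("/tmp/demo/part=1")
    spark.range(5).select($"id", ($"id" * 2).as("doubled")).write.mode("overwrite").parquet("/tmp/demo/part=2")

    // Merge the two schemas while reading the whole directory.
    val merged = spark.read.option("mergeSchema", "true").parquet("/tmp/demo")
    merged.printSchema() // id, doubled, plus the discovered partition column `part`

    spark.stop()
  }
}
```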
      * @since 1.4.0 */ @scala.annotation.varargs @@ -414,9 +444,22 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * * @param path input path * @since 1.5.0 - * @note Currently, this method can only be used together with `HiveContext`. + * @note Currently, this method can only be used after enabling Hive support. */ - def orc(path: String): DataFrame = format("orc").load(path) + def orc(path: String): DataFrame = { + // This method ensures that calls that explicit need single argument works, see SPARK-16009 + orc(Seq(path): _*) + } + + /** + * Loads an ORC file and returns the result as a [[DataFrame]]. + * + * @param paths input paths + * @since 2.0.0 + * @note Currently, this method can only be used after enabling Hive support. + */ + @scala.annotation.varargs + def orc(paths: String*): DataFrame = format("orc").load(paths: _*) /** * Returns the specified table as a [[DataFrame]]. @@ -430,24 +473,71 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { } /** - * Loads a text file and returns a [[Dataset]] of String. The underlying schema of the Dataset + * Loads text files and returns a [[DataFrame]] whose schema starts with a string column named + * "value", and followed by partitioned columns if there are any. See the documentation on + * the other overloaded `text()` method for more details. + * + * @since 2.0.0 + */ + def text(path: String): DataFrame = { + // This method ensures that calls that explicit need single argument works, see SPARK-16009 + text(Seq(path): _*) + } + + /** + * Loads text files and returns a [[DataFrame]] whose schema starts with a string column named + * "value", and followed by partitioned columns if there are any. + * + * Each line in the text files is a new row in the resulting DataFrame. For example: + * {{{ + * // Scala: + * spark.read.text("/path/to/spark/README.md") + * + * // Java: + * spark.read().text("/path/to/spark/README.md") + * }}} + * + * @param paths input paths + * @since 1.6.0 + */ + @scala.annotation.varargs + def text(paths: String*): DataFrame = format("text").load(paths : _*) + + /** + * Loads text files and returns a [[Dataset]] of String. See the documentation on the + * other overloaded `textFile()` method for more details. + * @since 2.0.0 + */ + def textFile(path: String): Dataset[String] = { + // This method ensures that calls that explicit need single argument works, see SPARK-16009 + textFile(Seq(path): _*) + } + + /** + * Loads text files and returns a [[Dataset]] of String. The underlying schema of the Dataset * contains a single string column named "value". * - * Each line in the text file is a new row in the resulting Dataset. For example: + * If the directory structure of the text files contains partitioning information, those are + * ignored in the resulting Dataset. To include partitioning information as columns, use `text`. + * + * Each line in the text files is a new element in the resulting Dataset. 
For example: * {{{ * // Scala: - * sqlContext.read.text("/path/to/spark/README.md") + * spark.read.textFile("/path/to/spark/README.md") * * // Java: - * sqlContext.read().text("/path/to/spark/README.md") + * spark.read().textFile("/path/to/spark/README.md") * }}} * * @param paths input path * @since 2.0.0 */ @scala.annotation.varargs - def text(paths: String*): Dataset[String] = { - format("text").load(paths : _*).as[String](sparkSession.implicits.newStringEncoder) + def textFile(paths: String*): Dataset[String] = { + if (userSpecifiedSchema.nonEmpty) { + throw new AnalysisException("User specified schema not supported with `textFile`") + } + text(paths : _*).select("value").as[String](sparkSession.implicits.newStringEncoder) } /////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index 3eb1f0f0d58ff..1855eab96eaa5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -160,8 +160,8 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * @return A DataFrame containing for the contingency table. * * {{{ - * val df = sqlContext.createDataFrame(Seq((1, 1), (1, 2), (2, 1), (2, 1), (2, 3), (3, 2), - * (3, 3))).toDF("key", "value") + * val df = spark.createDataFrame(Seq((1, 1), (1, 2), (2, 1), (2, 1), (2, 3), (3, 2), (3, 3))) + * .toDF("key", "value") * val ct = df.stat.crosstab("key", "value") * ct.show() * +---------+---+---+---+ @@ -197,7 +197,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * val rows = Seq.tabulate(100) { i => * if (i % 2 == 0) (1, -1.0) else (i, i * -1.0) * } - * val df = sqlContext.createDataFrame(rows).toDF("a", "b") + * val df = spark.createDataFrame(rows).toDF("a", "b") * // find the items with a frequency greater than 0.4 (observed 40% of the time) for columns * // "a" and "b" * val freqSingles = df.stat.freqItems(Array("a", "b"), 0.4) @@ -258,7 +258,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * val rows = Seq.tabulate(100) { i => * if (i % 2 == 0) (1, -1.0) else (i, i * -1.0) * } - * val df = sqlContext.createDataFrame(rows).toDF("a", "b") + * val df = spark.createDataFrame(rows).toDF("a", "b") * // find the items with a frequency greater than 0.4 (observed 40% of the time) for columns * // "a" and "b" * val freqSingles = df.stat.freqItems(Seq("a", "b"), 0.4) @@ -314,7 +314,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * @return a new [[DataFrame]] that represents the stratified sample * * {{{ - * val df = sqlContext.createDataFrame(Seq((1, 1), (1, 2), (2, 1), (2, 1), (2, 3), (3, 2), + * val df = spark.createDataFrame(Seq((1, 1), (1, 2), (2, 1), (2, 1), (2, 3), (3, 2), * (3, 3))).toDF("key", "value") * val fractions = Map(1 -> 1.0, 3 -> 0.5) * df.stat.sampleBy("key", fractions, 36L).show() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 28f5ccd26bc52..a4c4a5defa1b3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -21,27 +21,21 @@ import java.util.Properties import scala.collection.JavaConverters._ -import org.apache.hadoop.fs.Path - -import 
org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, Project} import org.apache.spark.sql.execution.datasources.{BucketSpec, CreateTableUsingAsSelect, DataSource, HadoopFsRelation} import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils -import org.apache.spark.sql.execution.streaming.{MemoryPlan, MemorySink, StreamExecution} -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.util.Utils /** - * :: Experimental :: - * Interface used to write a [[DataFrame]] to external storage systems (e.g. file systems, - * key-value stores, etc) or data streams. Use [[DataFrame.write]] to access this. + * Interface used to write a [[Dataset]] to external storage systems (e.g. file systems, + * key-value stores, etc). Use [[Dataset.write]] to access this. * * @since 1.4.0 */ -@Experimental -final class DataFrameWriter private[sql](df: DataFrame) { +final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { + + private val df = ds.toDF() /** * Specifies the behavior when data or table already exists. Options include: @@ -52,7 +46,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { * * @since 1.4.0 */ - def mode(saveMode: SaveMode): DataFrameWriter = { + def mode(saveMode: SaveMode): DataFrameWriter[T] = { this.mode = saveMode this } @@ -66,53 +60,24 @@ final class DataFrameWriter private[sql](df: DataFrame) { * * @since 1.4.0 */ - def mode(saveMode: String): DataFrameWriter = { + def mode(saveMode: String): DataFrameWriter[T] = { this.mode = saveMode.toLowerCase match { case "overwrite" => SaveMode.Overwrite case "append" => SaveMode.Append case "ignore" => SaveMode.Ignore case "error" | "default" => SaveMode.ErrorIfExists case _ => throw new IllegalArgumentException(s"Unknown save mode: $saveMode. " + - "Accepted modes are 'overwrite', 'append', 'ignore', 'error'.") + "Accepted save modes are 'overwrite', 'append', 'ignore', 'error'.") } this } - /** - * :: Experimental :: - * Set the trigger for the stream query. The default value is `ProcessingTime(0)` and it will run - * the query as fast as possible. - * - * Scala Example: - * {{{ - * df.write.trigger(ProcessingTime("10 seconds")) - * - * import scala.concurrent.duration._ - * df.write.trigger(ProcessingTime(10.seconds)) - * }}} - * - * Java Example: - * {{{ - * df.write.trigger(ProcessingTime.create("10 seconds")) - * - * import java.util.concurrent.TimeUnit - * df.write.trigger(ProcessingTime.create(10, TimeUnit.SECONDS)) - * }}} - * - * @since 2.0.0 - */ - @Experimental - def trigger(trigger: Trigger): DataFrameWriter = { - this.trigger = trigger - this - } - /** * Specifies the underlying output data source. Built-in options include "parquet", "json", etc. 
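Aside (not part of the patch): with the writer now typed as `DataFrameWriter[T]`, the usual chain reads the same; `mode` accepts either a `SaveMode` or one of the strings "overwrite", "append", "ignore", "error"/"default" listed above. A hypothetical sketch (local session, made-up data and output path):

```scala
import org.apache.spark.sql.{SaveMode, SparkSession}

object WriterModeExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("writer-demo").getOrCreate()
    import spark.implicits._

    val ds = Seq((1, "a"), (2, "b")).toDS()
    ds.write
      .mode(SaveMode.Overwrite)   // equivalent to .mode("overwrite")
      .format("parquet")
      .save("/tmp/writer-demo")   // placeholder output path

    spark.stop()
  }
}
```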
* * @since 1.4.0 */ - def format(source: String): DataFrameWriter = { + def format(source: String): DataFrameWriter[T] = { this.source = source this } @@ -122,7 +87,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { * * @since 1.4.0 */ - def option(key: String, value: String): DataFrameWriter = { + def option(key: String, value: String): DataFrameWriter[T] = { this.extraOptions += (key -> value) this } @@ -132,28 +97,28 @@ final class DataFrameWriter private[sql](df: DataFrame) { * * @since 2.0.0 */ - def option(key: String, value: Boolean): DataFrameWriter = option(key, value.toString) + def option(key: String, value: Boolean): DataFrameWriter[T] = option(key, value.toString) /** * Adds an output option for the underlying data source. * * @since 2.0.0 */ - def option(key: String, value: Long): DataFrameWriter = option(key, value.toString) + def option(key: String, value: Long): DataFrameWriter[T] = option(key, value.toString) /** * Adds an output option for the underlying data source. * * @since 2.0.0 */ - def option(key: String, value: Double): DataFrameWriter = option(key, value.toString) + def option(key: String, value: Double): DataFrameWriter[T] = option(key, value.toString) /** * (Scala-specific) Adds output options for the underlying data source. * * @since 1.4.0 */ - def options(options: scala.collection.Map[String, String]): DataFrameWriter = { + def options(options: scala.collection.Map[String, String]): DataFrameWriter[T] = { this.extraOptions ++= options this } @@ -163,7 +128,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { * * @since 1.4.0 */ - def options(options: java.util.Map[String, String]): DataFrameWriter = { + def options(options: java.util.Map[String, String]): DataFrameWriter[T] = { this.options(options.asScala) this } @@ -186,7 +151,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @since 1.4.0 */ @scala.annotation.varargs - def partitionBy(colNames: String*): DataFrameWriter = { + def partitionBy(colNames: String*): DataFrameWriter[T] = { this.partitioningColumns = Option(colNames) this } @@ -200,7 +165,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @since 2.0 */ @scala.annotation.varargs - def bucketBy(numBuckets: Int, colName: String, colNames: String*): DataFrameWriter = { + def bucketBy(numBuckets: Int, colName: String, colNames: String*): DataFrameWriter[T] = { this.numBuckets = Option(numBuckets) this.bucketColumnNames = Option(colName +: colNames) this @@ -214,7 +179,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @since 2.0 */ @scala.annotation.varargs - def sortBy(colName: String, colNames: String*): DataFrameWriter = { + def sortBy(colName: String, colNames: String*): DataFrameWriter[T] = { this.sortColumnNames = Option(colName +: colNames) this } @@ -235,7 +200,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @since 1.4.0 */ def save(): Unit = { - assertNotBucketed() + assertNotBucketed("save") val dataSource = DataSource( df.sparkSession, className = source, @@ -245,99 +210,27 @@ final class DataFrameWriter private[sql](df: DataFrame) { dataSource.write(mode, df) } - - /** - * Specifies the name of the [[ContinuousQuery]] that can be started with `startStream()`. - * This name must be unique among all the currently active queries in the associated SQLContext. 
- * - * @since 2.0.0 - */ - def queryName(queryName: String): DataFrameWriter = { - this.extraOptions += ("queryName" -> queryName) - this - } - - /** - * Starts the execution of the streaming query, which will continually output results to the given - * path as new data arrives. The returned [[ContinuousQuery]] object can be used to interact with - * the stream. - * - * @since 2.0.0 - */ - def startStream(path: String): ContinuousQuery = { - option("path", path).startStream() - } - - /** - * Starts the execution of the streaming query, which will continually output results to the given - * path as new data arrives. The returned [[ContinuousQuery]] object can be used to interact with - * the stream. - * - * @since 2.0.0 - */ - def startStream(): ContinuousQuery = { - if (source == "memory") { - val queryName = - extraOptions.getOrElse( - "queryName", throw new AnalysisException("queryName must be specified for memory sink")) - val checkpointLocation = extraOptions.get("checkpointLocation").map { userSpecified => - new Path(userSpecified).toUri.toString - }.orElse { - val checkpointConfig: Option[String] = - df.sparkSession.conf.get(SQLConf.CHECKPOINT_LOCATION, None) - - checkpointConfig.map { location => - new Path(location, queryName).toUri.toString - } - }.getOrElse { - Utils.createTempDir(namePrefix = "memory.stream").getCanonicalPath - } - - // If offsets have already been created, we trying to resume a query. - val checkpointPath = new Path(checkpointLocation, "offsets") - val fs = checkpointPath.getFileSystem(df.sparkSession.sessionState.newHadoopConf()) - if (fs.exists(checkpointPath)) { - throw new AnalysisException( - s"Unable to resume query written to memory sink. Delete $checkpointPath to start over.") - } else { - checkpointPath.toUri.toString - } - - val sink = new MemorySink(df.schema) - val resultDf = Dataset.ofRows(df.sparkSession, new MemoryPlan(sink)) - resultDf.registerTempTable(queryName) - val continuousQuery = df.sparkSession.sessionState.continuousQueryManager.startQuery( - queryName, - checkpointLocation, - df, - sink, - trigger) - continuousQuery - } else { - val dataSource = - DataSource( - df.sparkSession, - className = source, - options = extraOptions.toMap, - partitionColumns = normalizedParCols.getOrElse(Nil)) - - val queryName = extraOptions.getOrElse("queryName", StreamExecution.nextName) - val checkpointLocation = extraOptions.getOrElse("checkpointLocation", { - new Path(df.sparkSession.sessionState.conf.checkpointLocation, queryName).toUri.toString - }) - df.sparkSession.sessionState.continuousQueryManager.startQuery( - queryName, - checkpointLocation, - df, - dataSource.createSink(), - trigger) - } - } - /** * Inserts the content of the [[DataFrame]] to the specified table. It requires that * the schema of the [[DataFrame]] is the same as the schema of the table. * + * Note: Unlike `saveAsTable`, `insertInto` ignores the column names and just uses position-based + * resolution. For example: + * + * {{{ + * scala> Seq((1, 2)).toDF("i", "j").write.mode("overwrite").saveAsTable("t1") + * scala> Seq((3, 4)).toDF("j", "i").write.insertInto("t1") + * scala> Seq((5, 6)).toDF("a", "b").write.insertInto("t1") + * scala> sql("select * from t1").show + * +---+---+ + * | i| j| + * +---+---+ + * | 5| 6| + * | 3| 4| + * | 1| 2| + * +---+---+ + * }}} + * * Because it inserts data to an existing table, format or options will be ignored. 
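As the `assertNotBucketed` checks above suggest, `bucketBy`/`sortBy` definitions only take effect through `saveAsTable`; `save()` and `insertInto()` reject them. A minimal sketch, with a placeholder table name and column:

```scala
// df is any DataFrame with an "id" column; the table name is a placeholder.
df.write
  .bucketBy(8, "id")
  .sortBy("id")
  .saveAsTable("events_bucketed")

// df.write.bucketBy(8, "id").save("/tmp/out")   // would throw an AnalysisException
```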
* * @since 1.4.0 @@ -347,26 +240,22 @@ final class DataFrameWriter private[sql](df: DataFrame) { } private def insertInto(tableIdent: TableIdentifier): Unit = { - assertNotBucketed() - val partitions = normalizedParCols.map(_.map(col => col -> (None: Option[String])).toMap) - val overwrite = mode == SaveMode.Overwrite - - // A partitioned relation's schema can be different from the input logicalPlan, since - // partition columns are all moved after data columns. We Project to adjust the ordering. - // TODO: this belongs to the analyzer. - val input = normalizedParCols.map { parCols => - val (inputPartCols, inputDataCols) = df.logicalPlan.output.partition { attr => - parCols.contains(attr.name) - } - Project(inputDataCols ++ inputPartCols, df.logicalPlan) - }.getOrElse(df.logicalPlan) + assertNotBucketed("insertInto") + + if (partitioningColumns.isDefined) { + throw new AnalysisException( + "insertInto() can't be used together with partitionBy(). " + + "Partition columns have already be defined for the table. " + + "It is not necessary to use partitionBy()." + ) + } - df.sparkSession.executePlan( + df.sparkSession.sessionState.executePlan( InsertIntoTable( - UnresolvedRelation(tableIdent), - partitions.getOrElse(Map.empty[String, Option[String]]), - input, - overwrite, + table = UnresolvedRelation(tableIdent), + partition = Map.empty[String, Option[String]], + child = df.logicalPlan, + overwrite = mode == SaveMode.Overwrite, ifNotExists = false)).toRdd } @@ -416,10 +305,15 @@ final class DataFrameWriter private[sql](df: DataFrame) { s"existing columns (${validColumnNames.mkString(", ")})")) } - private def assertNotBucketed(): Unit = { + private def assertNotBucketed(operation: String): Unit = { if (numBuckets.isDefined || sortColumnNames.isDefined) { - throw new IllegalArgumentException( - "Currently we don't support writing bucketed data to this data source.") + throw new AnalysisException(s"'$operation' does not support bucketing right now") + } + } + + private def assertNotPartitioned(operation: String): Unit = { + if (partitioningColumns.isDefined) { + throw new AnalysisException( s"'$operation' does not support partitioning") } } @@ -430,8 +324,23 @@ final class DataFrameWriter private[sql](df: DataFrame) { * save mode, specified by the `mode` function (default to throwing an exception). * When `mode` is `Overwrite`, the schema of the [[DataFrame]] does not need to be * the same as that of the existing table. - * When `mode` is `Append`, the schema of the [[DataFrame]] need to be - * the same as that of the existing table, and format or options will be ignored. + * + * When `mode` is `Append`, if there is an existing table, we will use the format and options of + * the existing table. The column order in the schema of the [[DataFrame]] doesn't need to be same + * as that of the existing table. Unlike `insertInto`, `saveAsTable` will use the column names to + * find the correct column positions. For example: + * + * {{{ + * scala> Seq((1, 2)).toDF("i", "j").write.mode("overwrite").saveAsTable("t1") + * scala> Seq((3, 4)).toDF("j", "i").write.mode("append").saveAsTable("t1") + * scala> sql("select * from t1").show + * +---+---+ + * | i| j| + * +---+---+ + * | 1| 2| + * | 4| 3| + * +---+---+ + * }}} * * When the DataFrame is created from a non-partitioned [[HadoopFsRelation]] with a single input * path, and the data source provider can be mapped to an existing Hive builtin SerDe (i.e. 
ORC @@ -446,6 +355,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { } private def saveAsTable(tableIdent: TableIdentifier): Unit = { + val tableExists = df.sparkSession.sessionState.catalog.tableExists(tableIdent) (tableExists, mode) match { @@ -460,18 +370,17 @@ final class DataFrameWriter private[sql](df: DataFrame) { CreateTableUsingAsSelect( tableIdent, source, - temporary = false, partitioningColumns.map(_.toArray).getOrElse(Array.empty[String]), getBucketSpec, mode, extraOptions.toMap, df.logicalPlan) - df.sparkSession.executePlan(cmd).toRdd + df.sparkSession.sessionState.executePlan(cmd).toRdd } } /** - * Saves the content of the [[DataFrame]] to a external database table via JDBC. In the case the + * Saves the content of the [[DataFrame]] to an external database table via JDBC. In the case the * table already exists in the external database, behavior of this function depends on the * save mode, specified by the `mode` function (default to throwing an exception). * @@ -482,10 +391,14 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @param table Name of the table in the external database. * @param connectionProperties JDBC database connection arguments, a list of arbitrary string * tag/value. Normally at least a "user" and "password" property - * should be included. + * should be included. "batchsize" can be used to control the + * number of rows per insert. * @since 1.4.0 */ def jdbc(url: String, table: String, connectionProperties: Properties): Unit = { + assertNotPartitioned("jdbc") + assertNotBucketed("jdbc") + val props = new Properties() extraOptions.foreach { case (key, value) => props.put(key, value) @@ -536,13 +449,23 @@ final class DataFrameWriter private[sql](df: DataFrame) { * }}} * * You can set the following JSON-specific option(s) for writing JSON files: + *
*
* <li>`compression` (default `null`): compression codec to use when saving to file. This can be
* one of the known case-insensitive shorten names (`none`, `bzip2`, `gzip`, `lz4`,
* `snappy` and `deflate`).</li>
+ * <li>`dateFormat` (default `yyyy-MM-dd`): sets the string that indicates a date format.
+ * Custom date formats follow the formats at `java.text.SimpleDateFormat`. This applies to
+ * date type.</li>
+ * <li>`timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSZZ`): sets the string that
+ * indicates a timestamp format. Custom date formats follow the formats at
+ * `java.text.SimpleDateFormat`. This applies to timestamp type.</li>
+ *
      * * @since 1.4.0 */ - def json(path: String): Unit = format("json").save(path) + def json(path: String): Unit = { + format("json").save(path) + } /** * Saves the content of the [[DataFrame]] in Parquet format at the specified path. @@ -552,13 +475,18 @@ final class DataFrameWriter private[sql](df: DataFrame) { * }}} * * You can set the following Parquet-specific option(s) for writing Parquet files: - *
- * <li>`compression` (default `null`): compression codec to use when saving to file. This can be
- * one of the known case-insensitive shorten names(`none`, `snappy`, `gzip`, and `lzo`).
- * This will overwrite `spark.sql.parquet.compression.codec`.</li>
+ *
+ * <li>`compression` (default is the value specified in `spark.sql.parquet.compression.codec`):
+ * compression codec to use when saving to file. This can be one of the known case-insensitive
+ * shorten names(none, `snappy`, `gzip`, and `lzo`). This will override
+ * `spark.sql.parquet.compression.codec`.</li>
+ *
      * * @since 1.4.0 */ - def parquet(path: String): Unit = format("parquet").save(path) + def parquet(path: String): Unit = { + format("parquet").save(path) + } /** * Saves the content of the [[DataFrame]] in ORC format at the specified path. @@ -568,14 +496,18 @@ final class DataFrameWriter private[sql](df: DataFrame) { * }}} * * You can set the following ORC-specific option(s) for writing ORC files: - *
- * <li>`compression` (default `null`): compression codec to use when saving to file. This can be
+ *
+ * <li>`compression` (default `snappy`): compression codec to use when saving to file. This can be
* one of the known case-insensitive shorten names(`none`, `snappy`, `zlib`, and `lzo`).
- * This will overwrite `orc.compress`.</li>
+ * This will override `orc.compress`.</li>
+ *
      * * @since 1.5.0 - * @note Currently, this method can only be used together with `HiveContext`. + * @note Currently, this method can only be used after enabling Hive support */ - def orc(path: String): Unit = format("orc").save(path) + def orc(path: String): Unit = { + format("orc").save(path) + } /** * Saves the content of the [[DataFrame]] in a text file at the specified path. @@ -590,13 +522,17 @@ final class DataFrameWriter private[sql](df: DataFrame) { * }}} * * You can set the following option(s) for writing text files: + *
*
* <li>`compression` (default `null`): compression codec to use when saving to file. This can be
* one of the known case-insensitive shorten names (`none`, `bzip2`, `gzip`, `lz4`,
* `snappy` and `deflate`).</li>
+ *
      * * @since 1.6.0 */ - def text(path: String): Unit = format("text").save(path) + def text(path: String): Unit = { + format("text").save(path) + } /** * Saves the content of the [[DataFrame]] in CSV format at the specified path. @@ -606,13 +542,36 @@ final class DataFrameWriter private[sql](df: DataFrame) { * }}} * * You can set the following CSV-specific option(s) for writing CSV files: + *
+ *
+ * <li>`sep` (default `,`): sets the single character as a separator for each
+ * field and value.</li>
+ * <li>`quote` (default `"`): sets the single character used for escaping quoted values where
+ * the separator can be part of the value.</li>
+ * <li>`escape` (default `\`): sets the single character used for escaping quotes inside
+ * an already quoted value.</li>
+ * <li>`escapeQuotes` (default `true`): a flag indicating whether values containing
+ * quotes should always be enclosed in quotes. Default is to escape all values containing
+ * a quote character.</li>
+ * <li>`quoteAll` (default `false`): A flag indicating whether all values should always be
+ * enclosed in quotes. Default is to only escape values containing a quote character.</li>
+ * <li>`header` (default `false`): writes the names of columns as the first line.</li>
+ * <li>`nullValue` (default empty string): sets the string representation of a null value.</li>
* <li>`compression` (default `null`): compression codec to use when saving to file. This can be
* one of the known case-insensitive shorten names (`none`, `bzip2`, `gzip`, `lz4`,
* `snappy` and `deflate`).</li>
+ * <li>`dateFormat` (default `yyyy-MM-dd`): sets the string that indicates a date format.
+ * Custom date formats follow the formats at `java.text.SimpleDateFormat`. This applies to
+ * date type.</li>
+ * <li>`timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSZZ`): sets the string that
+ * indicates a timestamp format. Custom date formats follow the formats at
+ * `java.text.SimpleDateFormat`. This applies to timestamp type.</li>
+ *
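A short sketch exercising a few of the CSV options listed above; the path and settings are placeholders:

```scala
df.write
  .option("header", true)
  .option("sep", "|")
  .option("nullValue", "NA")
  .option("dateFormat", "yyyy/MM/dd")
  .option("compression", "gzip")
  .csv("/tmp/sketch/csv-out")
```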
      * * @since 2.0.0 */ - def csv(path: String): Unit = format("csv").save(path) + def csv(path: String): Unit = { + format("csv").save(path) + } /////////////////////////////////////////////////////////////////////////////////////// // Builder pattern config options @@ -622,8 +581,6 @@ final class DataFrameWriter private[sql](df: DataFrame) { private var mode: SaveMode = SaveMode.ErrorIfExists - private var trigger: Trigger = ProcessingTime(0L) - private var extraOptions = new scala.collection.mutable.HashMap[String, String] private var partitioningColumns: Option[Seq[String]] = None diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 08be94e8d4f12..4946bbe634d98 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1,19 +1,19 @@ /* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ package org.apache.spark.sql @@ -30,23 +30,26 @@ import org.apache.commons.lang3.StringUtils import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.java.function._ -import org.apache.spark.api.python.PythonRDD +import org.apache.spark.api.python.{PythonRDD, SerDeUtil} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.expressions.objects.Invoke import org.apache.spark.sql.catalyst.optimizer.CombineUnions import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.usePrettyExpression import org.apache.spark.sql.execution.{FileRelation, LogicalRDD, QueryExecution, SQLExecution} -import org.apache.spark.sql.execution.command.ExplainCommand +import org.apache.spark.sql.execution.command.{CreateViewCommand, ExplainCommand} import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation} import org.apache.spark.sql.execution.datasources.json.JacksonGenerator import org.apache.spark.sql.execution.python.EvaluatePython +import org.apache.spark.sql.streaming.{DataStreamWriter, StreamingQuery} import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils @@ -57,14 +60,14 @@ private[sql] object Dataset { } def ofRows(sparkSession: SparkSession, logicalPlan: LogicalPlan): DataFrame = { - val qe = sparkSession.executePlan(logicalPlan) + val qe = sparkSession.sessionState.executePlan(logicalPlan) qe.assertAnalyzed() new Dataset[Row](sparkSession, logicalPlan, RowEncoder(qe.analyzed.schema)) } } /** - * A [[Dataset]] is a strongly typed collection of domain-specific objects that can be transformed + * A Dataset is a strongly typed collection of domain-specific objects that can be transformed * in parallel using functional or relational operations. Each Dataset also has an untyped view * called a [[DataFrame]], which is a Dataset of [[Row]]. * @@ -91,24 +94,24 @@ private[sql] object Dataset { * to some files on storage systems, using the `read` function available on a `SparkSession`. * {{{ * val people = spark.read.parquet("...").as[Person] // Scala - * Dataset people = spark.read().parquet("...").as(Encoders.bean(Person.class) // Java + * Dataset people = spark.read().parquet("...").as(Encoders.bean(Person.class)); // Java * }}} * * Datasets can also be created through transformations available on existing Datasets. For example, * the following creates a new Dataset by applying a filter on the existing one: * {{{ * val names = people.map(_.name) // in Scala; names is a Dataset[String] - * Dataset names = people.map((Person p) -> p.name, Encoders.STRING)) // in Java 8 + * Dataset names = people.map((Person p) -> p.name, Encoders.STRING)); // in Java 8 * }}} * * Dataset operations can also be untyped, through various domain-specific-language (DSL) - * functions defined in: [[Dataset]] (this class), [[Column]], and [[functions]]. These operations + * functions defined in: Dataset (this class), [[Column]], and [[functions]]. 
These operations * are very similar to the operations available in the data frame abstraction in R or Python. * * To select a column from the Dataset, use `apply` method in Scala and `col` in Java. * {{{ * val ageCol = people("age") // in Scala - * Column ageCol = people.col("age") // in Java + * Column ageCol = people.col("age"); // in Java * }}} * * Note that the [[Column]] type can also be manipulated through its various functions. @@ -120,7 +123,7 @@ private[sql] object Dataset { * * A more concrete example in Scala: * {{{ - * // To create Dataset[Row] using SQLContext + * // To create Dataset[Row] using SparkSession * val people = spark.read.parquet("...") * val department = spark.read.parquet("...") * @@ -132,7 +135,7 @@ private[sql] object Dataset { * * and in Java: * {{{ - * // To create Dataset using SQLContext + * // To create Dataset using SparkSession * Dataset people = spark.read().parquet("..."); * Dataset department = spark.read().parquet("..."); * @@ -144,11 +147,8 @@ private[sql] object Dataset { * * @groupname basic Basic Dataset functions * @groupname action Actions - * @groupname untypedrel Untyped Language Integrated Relational Queries - * @groupname typedrel Typed Language Integrated Relational Queries - * @groupname func Functional Transformations - * @groupname rdd RDD Operations - * @groupname output Output Operations + * @groupname untypedrel Untyped transformations + * @groupname typedrel Typed transformations * * @since 1.6.0 */ @@ -164,14 +164,14 @@ class Dataset[T] private[sql]( // you wrap it with `withNewExecutionId` if this actions doesn't call other action. def this(sparkSession: SparkSession, logicalPlan: LogicalPlan, encoder: Encoder[T]) = { - this(sparkSession, sparkSession.executePlan(logicalPlan), encoder) + this(sparkSession, sparkSession.sessionState.executePlan(logicalPlan), encoder) } def this(sqlContext: SQLContext, logicalPlan: LogicalPlan, encoder: Encoder[T]) = { this(sqlContext.sparkSession, logicalPlan, encoder) } - @transient protected[sql] val logicalPlan: LogicalPlan = { + @transient private[sql] val logicalPlan: LogicalPlan = { def hasSideEffects(plan: LogicalPlan): Boolean = plan match { case _: Command | _: InsertIntoTable | @@ -179,7 +179,7 @@ class Dataset[T] private[sql]( case _ => false } - queryExecution.logical match { + queryExecution.analyzed match { // For various commands (like DDL) and queries with side effects, we force query execution // to happen right away to let these side effects take place eagerly. case p if hasSideEffects(p) => @@ -192,28 +192,29 @@ class Dataset[T] private[sql]( } /** - * An unresolved version of the internal encoder for the type of this [[Dataset]]. This one is - * marked implicit so that we can use it when constructing new [[Dataset]] objects that have the - * same object type (that will be possibly resolved to a different schema). + * Currently [[ExpressionEncoder]] is the only implementation of [[Encoder]], here we turn the + * passed in encoder to [[ExpressionEncoder]] explicitly, and mark it implicit so that we can use + * it when constructing new Dataset objects that have the same object type (that will be + * possibly resolved to a different schema). */ - private[sql] implicit val unresolvedTEncoder: ExpressionEncoder[T] = encoderFor(encoder) - unresolvedTEncoder.validate(logicalPlan.output) - - /** The encoder for this [[Dataset]] that has been resolved to its output schema. 
*/ - private[sql] val resolvedTEncoder: ExpressionEncoder[T] = - unresolvedTEncoder.resolve(logicalPlan.output, OuterScopes.outerScopes) + private[sql] implicit val exprEnc: ExpressionEncoder[T] = encoderFor(encoder) /** - * The encoder where the expressions used to construct an object from an input row have been - * bound to the ordinals of this [[Dataset]]'s output schema. + * Encoder is used mostly as a container of serde expressions in Dataset. We build logical + * plans by these serde expressions and execute it within the query framework. However, for + * performance reasons we may want to use encoder as a function to deserialize internal rows to + * custom objects, e.g. collect. Here we resolve and bind the encoder so that we can call its + * `fromRow` method later. */ - private[sql] val boundTEncoder = resolvedTEncoder.bind(logicalPlan.output) + private val boundEnc = + exprEnc.resolveAndBind(logicalPlan.output, sparkSession.sessionState.analyzer) - private implicit def classTag = unresolvedTEncoder.clsTag + private implicit def classTag = exprEnc.clsTag - def sqlContext: SQLContext = sparkSession.wrapped + // sqlContext must be val because a stable identifier is expected when you import implicits + @transient lazy val sqlContext: SQLContext = sparkSession.sqlContext - protected[sql] def resolve(colName: String): NamedExpression = { + private[sql] def resolve(colName: String): NamedExpression = { queryExecution.analyzed.resolveQuoted(colName, sparkSession.sessionState.analyzer.resolver) .getOrElse { throw new AnalysisException( @@ -221,7 +222,7 @@ class Dataset[T] private[sql]( } } - protected[sql] def numericColumns: Seq[Expression] = { + private[sql] def numericColumns: Seq[Expression] = { schema.fields.filter(_.dataType.isInstanceOf[NumericType]).map { n => queryExecution.analyzed.resolveQuoted(n.name, sparkSession.sessionState.analyzer.resolver).get } @@ -235,17 +236,13 @@ class Dataset[T] private[sql]( */ private[sql] def showString(_numRows: Int, truncate: Boolean = true): String = { val numRows = _numRows.max(0) - val takeResult = take(numRows + 1) + val takeResult = toDF().take(numRows + 1) val hasMoreData = takeResult.length > numRows val data = takeResult.take(numRows) // For array values, replace Seq and Array with square brackets // For cells that are beyond 20 characters, replace it with the first 17 and "..." - val rows: Seq[Seq[String]] = schema.fieldNames.toSeq +: data.map { - case r: Row => r - case tuple: Product => Row.fromTuple(tuple) - case o => Row(o) - }.map { row => + val rows: Seq[Seq[String]] = schema.fieldNames.toSeq +: data.map { row => row.toSeq.map { cell => val str = cell match { case null => "null" @@ -343,7 +340,7 @@ class Dataset[T] private[sql]( /** * :: Experimental :: - * Returns a new [[Dataset]] where each record has been mapped on to the specified type. The + * Returns a new Dataset where each record has been mapped on to the specified type. The * method used to map columns depend on the type of `U`: * - When `U` is a class, fields for the class will be mapped to columns of the same name * (case sensitivity is determined by `spark.sql.caseSensitive`). @@ -352,7 +349,7 @@ class Dataset[T] private[sql]( * - When `U` is a primitive type (i.e. String, Int, etc), then the first column of the * [[DataFrame]] will be used. 
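A brief sketch of the `as[U]` mapping rules described above, assuming `spark.implicits._` is in scope; `Person` is a placeholder case class:

```scala
import org.apache.spark.sql.Dataset

case class Person(name: String, age: Long)

val df = Seq(("Alice", 29L), ("Bob", 31L)).toDF("name", "age")

val people: Dataset[Person] = df.as[Person]                  // class: columns matched by name
val pairs: Dataset[(String, Long)] = df.as[(String, Long)]   // tuple: columns matched by position
val names: Dataset[String] = df.select($"name").as[String]   // primitive: first column is used
```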
* - * If the schema of the [[Dataset]] does not match the desired `U` type, you can use `select` + * If the schema of the Dataset does not match the desired `U` type, you can use `select` * along with `alias` or `as` to rearrange or rename as required. * * @group basic @@ -388,7 +385,7 @@ class Dataset[T] private[sql]( } /** - * Returns the schema of this [[Dataset]]. + * Returns the schema of this Dataset. * * @group basic * @since 1.6.0 @@ -413,7 +410,7 @@ class Dataset[T] private[sql]( */ def explain(extended: Boolean): Unit = { val explain = ExplainCommand(queryExecution.logical, extended = extended) - sparkSession.executePlan(explain).executedPlan.executeCollect().foreach { + sparkSession.sessionState.executePlan(explain).executedPlan.executeCollect().foreach { // scalastyle:off println r => println(r.getString(0)) // scalastyle:on println @@ -456,10 +453,10 @@ class Dataset[T] private[sql]( def isLocal: Boolean = logicalPlan.isInstanceOf[LocalRelation] /** - * Returns true if this [[Dataset]] contains one or more sources that continuously - * return data as it arrives. A [[Dataset]] that reads data from a streaming source - * must be executed as a [[ContinuousQuery]] using the `startStream()` method in - * [[DataFrameWriter]]. Methods that return a single answer, e.g. `count()` or + * Returns true if this Dataset contains one or more sources that continuously + * return data as it arrives. A Dataset that reads data from a streaming source + * must be executed as a [[StreamingQuery]] using the `start()` method in + * [[DataStreamWriter]]. Methods that return a single answer, e.g. `count()` or * `collect()`, will throw an [[AnalysisException]] when there is a streaming * source present. * @@ -470,7 +467,7 @@ class Dataset[T] private[sql]( def isStreaming: Boolean = logicalPlan.isStreaming /** - * Displays the [[Dataset]] in a tabular form. Strings more than 20 characters will be truncated, + * Displays the Dataset in a tabular form. Strings more than 20 characters will be truncated, * and all cells will be aligned right. For example: * {{{ * year month AVG('Adj Close) MAX('Adj Close) @@ -489,7 +486,7 @@ class Dataset[T] private[sql]( def show(numRows: Int): Unit = show(numRows, truncate = true) /** - * Displays the top 20 rows of [[Dataset]] in a tabular form. Strings more than 20 characters + * Displays the top 20 rows of Dataset in a tabular form. Strings more than 20 characters * will be truncated, and all cells will be aligned right. * * @group action @@ -498,7 +495,7 @@ class Dataset[T] private[sql]( def show(): Unit = show(20) /** - * Displays the top 20 rows of [[Dataset]] in a tabular form. + * Displays the top 20 rows of Dataset in a tabular form. * * @param truncate Whether truncate long strings. If true, strings more than 20 characters will * be truncated and all cells will be aligned right @@ -509,7 +506,7 @@ class Dataset[T] private[sql]( def show(truncate: Boolean): Unit = show(20, truncate) /** - * Displays the [[Dataset]] in a tabular form. For example: + * Displays the Dataset in a tabular form. 
For example: * {{{ * year month AVG('Adj Close) MAX('Adj Close) * 1980 12 0.503218 0.595103 @@ -563,7 +560,7 @@ class Dataset[T] private[sql]( * @group untypedrel * @since 2.0.0 */ - def join(right: DataFrame): DataFrame = withPlan { + def join(right: Dataset[_]): DataFrame = withPlan { Join(logicalPlan, right.logicalPlan, joinType = Inner, None) } @@ -588,7 +585,7 @@ class Dataset[T] private[sql]( * @group untypedrel * @since 2.0.0 */ - def join(right: DataFrame, usingColumn: String): DataFrame = { + def join(right: Dataset[_], usingColumn: String): DataFrame = { join(right, Seq(usingColumn)) } @@ -613,7 +610,7 @@ class Dataset[T] private[sql]( * @group untypedrel * @since 2.0.0 */ - def join(right: DataFrame, usingColumns: Seq[String]): DataFrame = { + def join(right: Dataset[_], usingColumns: Seq[String]): DataFrame = { join(right, usingColumns, "inner") } @@ -634,10 +631,10 @@ class Dataset[T] private[sql]( * @group untypedrel * @since 2.0.0 */ - def join(right: DataFrame, usingColumns: Seq[String], joinType: String): DataFrame = { + def join(right: Dataset[_], usingColumns: Seq[String], joinType: String): DataFrame = { // Analyze the self join. The assumption is that the analyzer will disambiguate left vs right // by creating a new instance for one of the branch. - val joined = sparkSession.executePlan( + val joined = sparkSession.sessionState.executePlan( Join(logicalPlan, right.logicalPlan, joinType = JoinType(joinType), None)) .analyzed.asInstanceOf[Join] @@ -662,7 +659,7 @@ class Dataset[T] private[sql]( * @group untypedrel * @since 2.0.0 */ - def join(right: DataFrame, joinExprs: Column): DataFrame = join(right, joinExprs, "inner") + def join(right: Dataset[_], joinExprs: Column): DataFrame = join(right, joinExprs, "inner") /** * Join with another [[DataFrame]], using the given join expression. The following performs @@ -685,7 +682,7 @@ class Dataset[T] private[sql]( * @group untypedrel * @since 2.0.0 */ - def join(right: DataFrame, joinExprs: Column, joinType: String): DataFrame = { + def join(right: Dataset[_], joinExprs: Column, joinType: String): DataFrame = { // Note that in this function, we introduce a hack in the case of self-join to automatically // resolve ambiguous join conditions into ones that might make sense [SPARK-6231]. // Consider this case: df.join(df, df("key") === df("key")) @@ -730,7 +727,7 @@ class Dataset[T] private[sql]( /** * :: Experimental :: - * Joins this [[Dataset]] returning a [[Tuple2]] for each pair where `condition` evaluates to + * Joins this Dataset returning a [[Tuple2]] for each pair where `condition` evaluates to * true. * * This is similar to the relation `join` function with one important difference in the @@ -750,36 +747,67 @@ class Dataset[T] private[sql]( */ @Experimental def joinWith[U](other: Dataset[U], condition: Column, joinType: String): Dataset[(T, U)] = { - val left = this.logicalPlan - val right = other.logicalPlan - - val joined = sparkSession.executePlan(Join(left, right, joinType = - JoinType(joinType), Some(condition.expr))) - val leftOutput = joined.analyzed.output.take(left.output.length) - val rightOutput = joined.analyzed.output.takeRight(right.output.length) + // Creates a Join node and resolve it first, to get join condition resolved, self-join resolved, + // etc. 
+ val joined = sparkSession.sessionState.executePlan( + Join( + this.logicalPlan, + other.logicalPlan, + JoinType(joinType), + Some(condition.expr))).analyzed.asInstanceOf[Join] + + // For both join side, combine all outputs into a single column and alias it with "_1" or "_2", + // to match the schema for the encoder of the join result. + // Note that we do this before joining them, to enable the join operator to return null for one + // side, in cases like outer-join. + val left = { + val combined = if (this.exprEnc.flat) { + assert(joined.left.output.length == 1) + Alias(joined.left.output.head, "_1")() + } else { + Alias(CreateStruct(joined.left.output), "_1")() + } + Project(combined :: Nil, joined.left) + } - val leftData = this.unresolvedTEncoder match { - case e if e.flat => Alias(leftOutput.head, "_1")() - case _ => Alias(CreateStruct(leftOutput), "_1")() + val right = { + val combined = if (other.exprEnc.flat) { + assert(joined.right.output.length == 1) + Alias(joined.right.output.head, "_2")() + } else { + Alias(CreateStruct(joined.right.output), "_2")() + } + Project(combined :: Nil, joined.right) } - val rightData = other.unresolvedTEncoder match { - case e if e.flat => Alias(rightOutput.head, "_2")() - case _ => Alias(CreateStruct(rightOutput), "_2")() + + // Rewrites the join condition to make the attribute point to correct column/field, after we + // combine the outputs of each join side. + val conditionExpr = joined.condition.get transformUp { + case a: Attribute if joined.left.outputSet.contains(a) => + if (this.exprEnc.flat) { + left.output.head + } else { + val index = joined.left.output.indexWhere(_.exprId == a.exprId) + GetStructField(left.output.head, index) + } + case a: Attribute if joined.right.outputSet.contains(a) => + if (other.exprEnc.flat) { + right.output.head + } else { + val index = joined.right.output.indexWhere(_.exprId == a.exprId) + GetStructField(right.output.head, index) + } } implicit val tuple2Encoder: Encoder[(T, U)] = - ExpressionEncoder.tuple(this.unresolvedTEncoder, other.unresolvedTEncoder) + ExpressionEncoder.tuple(this.exprEnc, other.exprEnc) - withTypedPlan { - Project( - leftData :: rightData :: Nil, - joined.analyzed) - } + withTypedPlan(Join(left, right, joined.joinType, Some(conditionExpr))) } /** * :: Experimental :: - * Using inner equi-join to join this [[Dataset]] returning a [[Tuple2]] for each pair + * Using inner equi-join to join this Dataset returning a [[Tuple2]] for each pair * where `condition` evaluates to true. * * @param other Right side of the join. @@ -794,7 +822,7 @@ class Dataset[T] private[sql]( } /** - * Returns a new [[Dataset]] with each partition sorted by the given expressions. + * Returns a new Dataset with each partition sorted by the given expressions. * * This is the same operation as "SORT BY" in SQL (Hive QL). * @@ -807,7 +835,7 @@ class Dataset[T] private[sql]( } /** - * Returns a new [[Dataset]] with each partition sorted by the given expressions. + * Returns a new Dataset with each partition sorted by the given expressions. * * This is the same operation as "SORT BY" in SQL (Hive QL). * @@ -820,7 +848,7 @@ class Dataset[T] private[sql]( } /** - * Returns a new [[Dataset]] sorted by the specified column, all in ascending order. + * Returns a new Dataset sorted by the specified column, all in ascending order. * {{{ * // The following 3 are equivalent * ds.sort("sortcol") @@ -837,7 +865,7 @@ class Dataset[T] private[sql]( } /** - * Returns a new [[Dataset]] sorted by the given expressions. 
For example: + * Returns a new Dataset sorted by the given expressions. For example: * {{{ * ds.sort($"col1", $"col2".desc) * }}} @@ -851,7 +879,7 @@ class Dataset[T] private[sql]( } /** - * Returns a new [[Dataset]] sorted by the given expressions. + * Returns a new Dataset sorted by the given expressions. * This is an alias of the `sort` function. * * @group typedrel @@ -861,7 +889,7 @@ class Dataset[T] private[sql]( def orderBy(sortCol: String, sortCols: String*): Dataset[T] = sort(sortCol, sortCols : _*) /** - * Returns a new [[Dataset]] sorted by the given expressions. + * Returns a new Dataset sorted by the given expressions. * This is an alias of the `sort` function. * * @group typedrel @@ -895,7 +923,7 @@ class Dataset[T] private[sql]( } /** - * Returns a new [[Dataset]] with an alias set. + * Returns a new Dataset with an alias set. * * @group typedrel * @since 1.6.0 @@ -905,7 +933,7 @@ class Dataset[T] private[sql]( } /** - * (Scala-specific) Returns a new [[Dataset]] with an alias set. + * (Scala-specific) Returns a new Dataset with an alias set. * * @group typedrel * @since 2.0.0 @@ -913,7 +941,7 @@ class Dataset[T] private[sql]( def as(alias: Symbol): Dataset[T] = as(alias.name) /** - * Returns a new [[Dataset]] with an alias set. Same as `as`. + * Returns a new Dataset with an alias set. Same as `as`. * * @group typedrel * @since 2.0.0 @@ -921,7 +949,7 @@ class Dataset[T] private[sql]( def alias(alias: String): Dataset[T] = as(alias) /** - * (Scala-specific) Returns a new [[Dataset]] with an alias set. Same as `as`. + * (Scala-specific) Returns a new Dataset with an alias set. Same as `as`. * * @group typedrel * @since 2.0.0 @@ -980,7 +1008,7 @@ class Dataset[T] private[sql]( /** * :: Experimental :: - * Returns a new [[Dataset]] by computing the given [[Column]] expression for each element. + * Returns a new Dataset by computing the given [[Column]] expression for each element. * * {{{ * val ds = Seq(1, 2, 3).toDS() @@ -996,7 +1024,7 @@ class Dataset[T] private[sql]( sparkSession, Project( c1.withInputType( - unresolvedTEncoder.deserializer, + exprEnc.deserializer, logicalPlan.output).named :: Nil, logicalPlan), implicitly[Encoder[U1]]) @@ -1010,14 +1038,14 @@ class Dataset[T] private[sql]( protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { val encoders = columns.map(_.encoder) val namedColumns = - columns.map(_.withInputType(unresolvedTEncoder.deserializer, logicalPlan.output).named) + columns.map(_.withInputType(exprEnc.deserializer, logicalPlan.output).named) val execution = new QueryExecution(sparkSession, Project(namedColumns, logicalPlan)) new Dataset(sparkSession, execution, ExpressionEncoder.tuple(encoders)) } /** * :: Experimental :: - * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element. + * Returns a new Dataset by computing the given [[Column]] expressions for each element. * * @group typedrel * @since 1.6.0 @@ -1028,7 +1056,7 @@ class Dataset[T] private[sql]( /** * :: Experimental :: - * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element. + * Returns a new Dataset by computing the given [[Column]] expressions for each element. * * @group typedrel * @since 1.6.0 @@ -1042,7 +1070,7 @@ class Dataset[T] private[sql]( /** * :: Experimental :: - * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element. + * Returns a new Dataset by computing the given [[Column]] expressions for each element. 
* * @group typedrel * @since 1.6.0 @@ -1057,7 +1085,7 @@ class Dataset[T] private[sql]( /** * :: Experimental :: - * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element. + * Returns a new Dataset by computing the given [[Column]] expressions for each element. * * @group typedrel * @since 1.6.0 @@ -1126,7 +1154,7 @@ class Dataset[T] private[sql]( } /** - * Groups the [[Dataset]] using the specified columns, so we can run aggregation on them. See + * Groups the Dataset using the specified columns, so we can run aggregation on them. See * [[RelationalGroupedDataset]] for all the available aggregate functions. * * {{{ @@ -1149,7 +1177,7 @@ class Dataset[T] private[sql]( } /** - * Create a multi-dimensional rollup for the current [[Dataset]] using the specified columns, + * Create a multi-dimensional rollup for the current Dataset using the specified columns, * so we can run aggregation on them. * See [[RelationalGroupedDataset]] for all the available aggregate functions. * @@ -1173,7 +1201,7 @@ class Dataset[T] private[sql]( } /** - * Create a multi-dimensional cube for the current [[Dataset]] using the specified columns, + * Create a multi-dimensional cube for the current Dataset using the specified columns, * so we can run aggregation on them. * See [[RelationalGroupedDataset]] for all the available aggregate functions. * @@ -1197,7 +1225,7 @@ class Dataset[T] private[sql]( } /** - * Groups the [[Dataset]] using the specified columns, so that we can run aggregation on them. + * Groups the Dataset using the specified columns, so that we can run aggregation on them. * See [[RelationalGroupedDataset]] for all the available aggregate functions. * * This is a variant of groupBy that can only group by existing columns using column names @@ -1226,7 +1254,7 @@ class Dataset[T] private[sql]( /** * :: Experimental :: * (Scala-specific) - * Reduces the elements of this [[Dataset]] using the specified binary function. The given `func` + * Reduces the elements of this Dataset using the specified binary function. The given `func` * must be commutative and associative or the result may be non-deterministic. * * @group action @@ -1259,7 +1287,7 @@ class Dataset[T] private[sql]( def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T] = { val inputPlan = logicalPlan val withGroupingKey = AppendColumns(func, inputPlan) - val executed = sparkSession.executePlan(withGroupingKey) + val executed = sparkSession.sessionState.executePlan(withGroupingKey) new KeyValueGroupedDataset( encoderFor[K], @@ -1282,7 +1310,7 @@ class Dataset[T] private[sql]( groupByKey(func.call(_))(encoder) /** - * Create a multi-dimensional rollup for the current [[Dataset]] using the specified columns, + * Create a multi-dimensional rollup for the current Dataset using the specified columns, * so we can run aggregation on them. * See [[RelationalGroupedDataset]] for all the available aggregate functions. * @@ -1311,7 +1339,7 @@ class Dataset[T] private[sql]( } /** - * Create a multi-dimensional cube for the current [[Dataset]] using the specified columns, + * Create a multi-dimensional cube for the current Dataset using the specified columns, * so we can run aggregation on them. * See [[RelationalGroupedDataset]] for all the available aggregate functions. * @@ -1339,7 +1367,7 @@ class Dataset[T] private[sql]( } /** - * (Scala-specific) Aggregates on the entire [[Dataset]] without groups. + * (Scala-specific) Aggregates on the entire Dataset without groups. * {{{ * // ds.agg(...) 
is a shorthand for ds.groupBy().agg(...) * ds.agg("age" -> "max", "salary" -> "avg") @@ -1354,7 +1382,7 @@ class Dataset[T] private[sql]( } /** - * (Scala-specific) Aggregates on the entire [[Dataset]] without groups. + * (Scala-specific) Aggregates on the entire Dataset without groups. * {{{ * // ds.agg(...) is a shorthand for ds.groupBy().agg(...) * ds.agg(Map("age" -> "max", "salary" -> "avg")) @@ -1367,7 +1395,7 @@ class Dataset[T] private[sql]( def agg(exprs: Map[String, String]): DataFrame = groupBy().agg(exprs) /** - * (Java-specific) Aggregates on the entire [[Dataset]] without groups. + * (Java-specific) Aggregates on the entire Dataset without groups. * {{{ * // ds.agg(...) is a shorthand for ds.groupBy().agg(...) * ds.agg(Map("age" -> "max", "salary" -> "avg")) @@ -1380,7 +1408,7 @@ class Dataset[T] private[sql]( def agg(exprs: java.util.Map[String, String]): DataFrame = groupBy().agg(exprs) /** - * Aggregates on the entire [[Dataset]] without groups. + * Aggregates on the entire Dataset without groups. * {{{ * // ds.agg(...) is a shorthand for ds.groupBy().agg(...) * ds.agg(max($"age"), avg($"salary")) @@ -1394,9 +1422,9 @@ class Dataset[T] private[sql]( def agg(expr: Column, exprs: Column*): DataFrame = groupBy().agg(expr, exprs : _*) /** - * Returns a new [[Dataset]] by taking the first `n` rows. The difference between this function + * Returns a new Dataset by taking the first `n` rows. The difference between this function * and `head` is that `head` is an action and returns an array (by triggering query execution) - * while `limit` returns a new [[Dataset]]. + * while `limit` returns a new Dataset. * * @group typedrel * @since 2.0.0 @@ -1406,7 +1434,7 @@ class Dataset[T] private[sql]( } /** - * Returns a new [[Dataset]] containing union of rows in this Dataset and another Dataset. + * Returns a new Dataset containing union of rows in this Dataset and another Dataset. * This is equivalent to `UNION ALL` in SQL. * * To do a SQL-style set union (that does deduplication of elements), use this function followed @@ -1419,7 +1447,7 @@ class Dataset[T] private[sql]( def unionAll(other: Dataset[T]): Dataset[T] = union(other) /** - * Returns a new [[Dataset]] containing union of rows in this Dataset and another Dataset. + * Returns a new Dataset containing union of rows in this Dataset and another Dataset. * This is equivalent to `UNION ALL` in SQL. * * To do a SQL-style set union (that does deduplication of elements), use this function followed @@ -1428,14 +1456,14 @@ class Dataset[T] private[sql]( * @group typedrel * @since 2.0.0 */ - def union(other: Dataset[T]): Dataset[T] = withTypedPlan { + def union(other: Dataset[T]): Dataset[T] = withSetOperator { // This breaks caching, but it's usually ok because it addresses a very specific use case: // using union to union many files or partitions. CombineUnions(Union(logicalPlan, other.logicalPlan)) } /** - * Returns a new [[Dataset]] containing rows only in both this Dataset and another Dataset. + * Returns a new Dataset containing rows only in both this Dataset and another Dataset. * This is equivalent to `INTERSECT` in SQL. 
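A small sketch of the set operators discussed above; `union` keeps duplicates (UNION ALL semantics) unless it is followed by `distinct()`:

```scala
import spark.implicits._

val a = Seq(1, 2, 2, 3).toDS()
val b = Seq(2, 3, 4).toDS()

a.union(b)             // duplicates kept: the multiset {1, 2, 2, 2, 3, 3, 4}
a.union(b).distinct()  // SQL-style UNION: {1, 2, 3, 4}
a.intersect(b)         // {2, 3}
a.except(b)            // {1}
```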
* * Note that, equality checking is performed directly on the encoded representation of the data @@ -1444,12 +1472,12 @@ class Dataset[T] private[sql]( * @group typedrel * @since 1.6.0 */ - def intersect(other: Dataset[T]): Dataset[T] = withTypedPlan { + def intersect(other: Dataset[T]): Dataset[T] = withSetOperator { Intersect(logicalPlan, other.logicalPlan) } /** - * Returns a new [[Dataset]] containing rows in this Dataset but not in another Dataset. + * Returns a new Dataset containing rows in this Dataset but not in another Dataset. * This is equivalent to `EXCEPT` in SQL. * * Note that, equality checking is performed directly on the encoded representation of the data @@ -1458,12 +1486,12 @@ class Dataset[T] private[sql]( * @group typedrel * @since 2.0.0 */ - def except(other: Dataset[T]): Dataset[T] = withTypedPlan { + def except(other: Dataset[T]): Dataset[T] = withSetOperator { Except(logicalPlan, other.logicalPlan) } /** - * Returns a new [[Dataset]] by sampling a fraction of rows. + * Returns a new Dataset by sampling a fraction of rows. * * @param withReplacement Sample with replacement or not. * @param fraction Fraction of rows to generate. @@ -1472,12 +1500,17 @@ class Dataset[T] private[sql]( * @group typedrel * @since 1.6.0 */ - def sample(withReplacement: Boolean, fraction: Double, seed: Long): Dataset[T] = withTypedPlan { - Sample(0.0, fraction, withReplacement, seed, logicalPlan)() + def sample(withReplacement: Boolean, fraction: Double, seed: Long): Dataset[T] = { + require(fraction >= 0, + s"Fraction must be nonnegative, but got ${fraction}") + + withTypedPlan { + Sample(0.0, fraction, withReplacement, seed, logicalPlan)() + } } /** - * Returns a new [[Dataset]] by sampling a fraction of rows, using a random seed. + * Returns a new Dataset by sampling a fraction of rows, using a random seed. * * @param withReplacement Sample with replacement or not. * @param fraction Fraction of rows to generate. @@ -1490,7 +1523,7 @@ class Dataset[T] private[sql]( } /** - * Randomly splits this [[Dataset]] with the provided weights. + * Randomly splits this Dataset with the provided weights. * * @param weights weights for splits, will be normalized if they don't sum to 1. * @param seed Seed for sampling. @@ -1501,6 +1534,11 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ def randomSplit(weights: Array[Double], seed: Long): Array[Dataset[T]] = { + require(weights.forall(_ >= 0), + s"Weights must be nonnegative, but got ${weights.mkString("[", ",", "]")}") + require(weights.sum > 0, + s"Sum of weights must be positive, but got ${weights.mkString("[", ",", "]")}") + // It is possible that the underlying dataframe doesn't guarantee the ordering of rows in its // constituent partitions each time a split is materialized which could result in // overlapping splits. To prevent this, we explicitly sort each input partition to make the @@ -1517,7 +1555,7 @@ class Dataset[T] private[sql]( } /** - * Returns a Java list that contains randomly split [[Dataset]] with the provided weights. + * Returns a Java list that contains randomly split Dataset with the provided weights. * * @param weights weights for splits, will be normalized if they don't sum to 1. * @param seed Seed for sampling. @@ -1531,7 +1569,7 @@ class Dataset[T] private[sql]( } /** - * Randomly splits this [[Dataset]] with the provided weights. + * Randomly splits this Dataset with the provided weights. * * @param weights weights for splits, will be normalized if they don't sum to 1. 
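A sketch of `sample` and `randomSplit` as documented above; the added `require` checks reject negative fractions and weights, and weights are normalized before use:

```scala
val ds = spark.range(0, 1000)

val sampled = ds.sample(withReplacement = false, fraction = 0.1, seed = 42L)

// 8.0/2.0 is normalized to 0.8/0.2
val Array(train, test) = ds.randomSplit(Array(8.0, 2.0), seed = 42L)
```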
* @group typedrel @@ -1542,7 +1580,7 @@ class Dataset[T] private[sql]( } /** - * Randomly splits this [[Dataset]] with the provided weights. Provided for the Python Api. + * Randomly splits this Dataset with the provided weights. Provided for the Python Api. * * @param weights weights for splits, will be normalized if they don't sum to 1. * @param seed Seed for sampling. @@ -1552,41 +1590,41 @@ class Dataset[T] private[sql]( } /** - * :: Experimental :: - * (Scala-specific) Returns a new [[Dataset]] where each row has been expanded to zero or more + * (Scala-specific) Returns a new Dataset where each row has been expanded to zero or more * rows by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. The columns of * the input row are implicitly joined with each row that is output by the function. * - * The following example uses this function to count the number of books which contain - * a given word: + * Given that this is deprecated, as an alternative, you can explode columns either using + * `functions.explode()` or `flatMap()`. The following example uses these alternatives to count + * the number of books that contain a given word: * * {{{ * case class Book(title: String, words: String) * val ds: Dataset[Book] * - * case class Word(word: String) - * val allWords = ds.explode('words) { - * case Row(words: String) => words.split(" ").map(Word(_)) - * } + * val allWords = ds.select('title, explode(split('words, " ")).as("word")) * * val bookCountPerWord = allWords.groupBy("word").agg(countDistinct("title")) * }}} * + * Using `flatMap()` this can similarly be exploded as: + * + * {{{ + * ds.flatMap(_.words.split(" ")) + * }}} + * * @group untypedrel * @since 2.0.0 */ - @Experimental + @deprecated("use flatMap() or select() with functions.explode() instead", "2.0.0") def explode[A <: Product : TypeTag](input: Column*)(f: Row => TraversableOnce[A]): DataFrame = { - val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] + val elementSchema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] - val elementTypes = schema.toAttributes.map { - attr => (attr.dataType, attr.nullable, attr.name) } - val names = schema.toAttributes.map(_.name) - val convert = CatalystTypeConverters.createToCatalystConverter(schema) + val convert = CatalystTypeConverters.createToCatalystConverter(elementSchema) val rowFunction = f.andThen(_.map(convert(_).asInstanceOf[InternalRow])) - val generator = UserDefinedGenerator(elementTypes, rowFunction, input.map(_.expr)) + val generator = UserDefinedGenerator(elementSchema, rowFunction, input.map(_.expr)) withPlan { Generate(generator, join = true, outer = false, @@ -1595,31 +1633,39 @@ class Dataset[T] private[sql]( } /** - * :: Experimental :: - * (Scala-specific) Returns a new [[Dataset]] where a single column has been expanded to zero + * (Scala-specific) Returns a new Dataset where a single column has been expanded to zero * or more rows by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. All * columns of the input row are implicitly joined with each value that is output by the function. 
* + * Given that this is deprecated, as an alternative, you can explode columns either using + * `functions.explode()`: + * * {{{ - * ds.explode("words", "word") {words: String => words.split(" ")} + * ds.select(explode(split('words, " ")).as("word")) + * }}} + * + * or `flatMap()`: + * + * {{{ + * ds.flatMap(_.words.split(" ")) * }}} * * @group untypedrel * @since 2.0.0 */ - @Experimental + @deprecated("use flatMap() or select() with functions.explode() instead", "2.0.0") def explode[A, B : TypeTag](inputColumn: String, outputColumn: String)(f: A => TraversableOnce[B]) : DataFrame = { val dataType = ScalaReflection.schemaFor[B].dataType val attributes = AttributeReference(outputColumn, dataType)() :: Nil // TODO handle the metadata? - val elementTypes = attributes.map { attr => (attr.dataType, attr.nullable, attr.name) } + val elementSchema = attributes.toStructType def rowFunction(row: Row): TraversableOnce[InternalRow] = { val convert = CatalystTypeConverters.createToCatalystConverter(dataType) f(row(0).asInstanceOf[A]).map(o => InternalRow(convert(o))) } - val generator = UserDefinedGenerator(elementTypes, rowFunction, apply(inputColumn).expr :: Nil) + val generator = UserDefinedGenerator(elementSchema, rowFunction, apply(inputColumn).expr :: Nil) withPlan { Generate(generator, join = true, outer = false, @@ -1628,7 +1674,7 @@ class Dataset[T] private[sql]( } /** - * Returns a new [[Dataset]] by adding a column or replacing the existing column that has + * Returns a new Dataset by adding a column or replacing the existing column that has * the same name. * * @group untypedrel @@ -1653,7 +1699,7 @@ class Dataset[T] private[sql]( } /** - * Returns a new [[Dataset]] by adding a column with metadata. + * Returns a new Dataset by adding a column with metadata. */ private[spark] def withColumn(colName: String, col: Column, metadata: Metadata): DataFrame = { val resolver = sparkSession.sessionState.analyzer.resolver @@ -1674,7 +1720,7 @@ class Dataset[T] private[sql]( } /** - * Returns a new [[Dataset]] with a column renamed. + * Returns a new Dataset with a column renamed. * This is a no-op if schema doesn't contain existingName. * * @group untypedrel @@ -1699,8 +1745,11 @@ class Dataset[T] private[sql]( } /** - * Returns a new [[Dataset]] with a column dropped. - * This is a no-op if schema doesn't contain column name. + * Returns a new Dataset with a column dropped. This is a no-op if schema doesn't contain + * column name. + * + * This method can only be used to drop top level columns. the colName string is treated + * literally without further interpretation. * * @group untypedrel * @since 2.0.0 @@ -1710,18 +1759,23 @@ class Dataset[T] private[sql]( } /** - * Returns a new [[Dataset]] with columns dropped. + * Returns a new Dataset with columns dropped. * This is a no-op if schema doesn't contain column name(s). * + * This method can only be used to drop top level columns. the colName string is treated literally + * without further interpretation. 
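A sketch of the `drop` behavior noted above: string arguments are matched literally against top-level column names, while the `Column`-based overload (shown just below) takes an expression:

```scala
val df = Seq((1, "a", true)).toDF("id", "value", "flag")

df.drop("value")           // drops the top-level column "value"
df.drop("value", "flag")   // varargs overload
df.drop("nested.field")    // no-op here: the name is taken literally, not parsed as a path
df.drop(df("flag"))        // Column-based overload
```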
+ * * @group untypedrel * @since 2.0.0 */ @scala.annotation.varargs def drop(colNames: String*): DataFrame = { val resolver = sparkSession.sessionState.analyzer.resolver - val remainingCols = - schema.filter(f => colNames.forall(n => !resolver(f.name, n))).map(f => Column(f.name)) - if (remainingCols.size == this.schema.size) { + val allColumns = queryExecution.analyzed.output + val remainingCols = allColumns.filter { attribute => + colNames.forall(n => !resolver(attribute.name, n)) + }.map(attribute => Column(attribute)) + if (remainingCols.size == allColumns.size) { toDF() } else { this.select(remainingCols: _*) @@ -1729,8 +1783,8 @@ class Dataset[T] private[sql]( } /** - * Returns a new [[Dataset]] with a column dropped. - * This version of drop accepts a Column rather than a name. + * Returns a new Dataset with a column dropped. + * This version of drop accepts a [[Column]] rather than a name. * This is a no-op if the Dataset doesn't have a column * with an equivalent expression. * @@ -1752,7 +1806,7 @@ class Dataset[T] private[sql]( } /** - * Returns a new [[Dataset]] that contains only the unique rows from this [[Dataset]]. + * Returns a new Dataset that contains only the unique rows from this Dataset. * This is an alias for `distinct`. * * @group typedrel @@ -1761,14 +1815,20 @@ class Dataset[T] private[sql]( def dropDuplicates(): Dataset[T] = dropDuplicates(this.columns) /** - * (Scala-specific) Returns a new [[Dataset]] with duplicate rows removed, considering only + * (Scala-specific) Returns a new Dataset with duplicate rows removed, considering only * the subset of columns. * * @group typedrel * @since 2.0.0 */ def dropDuplicates(colNames: Seq[String]): Dataset[T] = withTypedPlan { - val groupCols = colNames.map(resolve) + val resolver = sparkSession.sessionState.analyzer.resolver + val allColumns = queryExecution.analyzed.output + val groupCols = colNames.map { colName => + allColumns.find(col => resolver(col.name, colName)).getOrElse( + throw new AnalysisException( + s"""Cannot resolve column name "$colName" among (${schema.fieldNames.mkString(", ")})""")) + } val groupColExprIds = groupCols.map(_.exprId) val aggCols = logicalPlan.output.map { attr => if (groupColExprIds.contains(attr.exprId)) { @@ -1781,7 +1841,7 @@ class Dataset[T] private[sql]( } /** - * Returns a new [[Dataset]] with duplicate rows removed, considering only + * Returns a new Dataset with duplicate rows removed, considering only * the subset of columns. * * @group typedrel @@ -1789,12 +1849,25 @@ class Dataset[T] private[sql]( */ def dropDuplicates(colNames: Array[String]): Dataset[T] = dropDuplicates(colNames.toSeq) + /** + * Returns a new [[Dataset]] with duplicate rows removed, considering only + * the subset of columns. + * + * @group typedrel + * @since 2.0.0 + */ + @scala.annotation.varargs + def dropDuplicates(col1: String, cols: String*): Dataset[T] = { + val colNames: Seq[String] = col1 +: cols + dropDuplicates(colNames) + } + /** * Computes statistics for numeric columns, including count, mean, stddev, min, and max. * If no columns are given, this function computes statistics for all numerical columns. * * This function is meant for exploratory data analysis, as we make no guarantee about the - * backward compatibility of the schema of the resulting [[Dataset]]. If you want to + * backward compatibility of the schema of the resulting Dataset. If you want to * programmatically compute summary statistics, use the `agg` function instead. 
* * {{{ @@ -1845,7 +1918,8 @@ class Dataset[T] private[sql]( // All columns are string type val schema = StructType( StructField("summary", StringType) :: outputCols.map(StructField(_, StringType))).toAttributes - LocalRelation.fromExternalRows(schema, ret) + // `toArray` forces materialization to make the seq serializable + LocalRelation.fromExternalRows(schema, ret.toArray.toSeq) } /** @@ -1885,7 +1959,7 @@ class Dataset[T] private[sql]( * .transform(...) * }}} * - * @group func + * @group typedrel * @since 1.6.0 */ def transform[U](t: Dataset[T] => Dataset[U]): Dataset[U] = t(this) @@ -1893,43 +1967,43 @@ class Dataset[T] private[sql]( /** * :: Experimental :: * (Scala-specific) - * Returns a new [[Dataset]] that only contains elements where `func` returns `true`. + * Returns a new Dataset that only contains elements where `func` returns `true`. * - * @group func + * @group typedrel * @since 1.6.0 */ @Experimental def filter(func: T => Boolean): Dataset[T] = { - val deserialized = CatalystSerde.deserialize[T](logicalPlan) + val deserializer = UnresolvedDeserializer(encoderFor[T].deserializer) val function = Literal.create(func, ObjectType(classOf[T => Boolean])) - val condition = Invoke(function, "apply", BooleanType, deserialized.output) - val filter = Filter(condition, deserialized) - withTypedPlan(CatalystSerde.serialize[T](filter)) + val condition = Invoke(function, "apply", BooleanType, deserializer :: Nil) + val filter = Filter(condition, logicalPlan) + withTypedPlan(filter) } /** * :: Experimental :: * (Java-specific) - * Returns a new [[Dataset]] that only contains elements where `func` returns `true`. + * Returns a new Dataset that only contains elements where `func` returns `true`. * - * @group func + * @group typedrel * @since 1.6.0 */ @Experimental def filter(func: FilterFunction[T]): Dataset[T] = { - val deserialized = CatalystSerde.deserialize[T](logicalPlan) + val deserializer = UnresolvedDeserializer(encoderFor[T].deserializer) val function = Literal.create(func, ObjectType(classOf[FilterFunction[T]])) - val condition = Invoke(function, "call", BooleanType, deserialized.output) - val filter = Filter(condition, deserialized) - withTypedPlan(CatalystSerde.serialize[T](filter)) + val condition = Invoke(function, "call", BooleanType, deserializer :: Nil) + val filter = Filter(condition, logicalPlan) + withTypedPlan(filter) } /** * :: Experimental :: * (Scala-specific) - * Returns a new [[Dataset]] that contains the result of applying `func` to each element. + * Returns a new Dataset that contains the result of applying `func` to each element. * - * @group func + * @group typedrel * @since 1.6.0 */ @Experimental @@ -1940,9 +2014,9 @@ class Dataset[T] private[sql]( /** * :: Experimental :: * (Java-specific) - * Returns a new [[Dataset]] that contains the result of applying `func` to each element. + * Returns a new Dataset that contains the result of applying `func` to each element. * - * @group func + * @group typedrel * @since 1.6.0 */ @Experimental @@ -1954,9 +2028,9 @@ class Dataset[T] private[sql]( /** * :: Experimental :: * (Scala-specific) - * Returns a new [[Dataset]] that contains the result of applying `func` to each partition. + * Returns a new Dataset that contains the result of applying `func` to each partition. * - * @group func + * @group typedrel * @since 1.6.0 */ @Experimental @@ -1970,9 +2044,9 @@ class Dataset[T] private[sql]( /** * :: Experimental :: * (Java-specific) - * Returns a new [[Dataset]] that contains the result of applying `f` to each partition. 
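For reference, a minimal sketch of the typed `filter`/`map` operators whose planning changes here; the closures now run against deserialized objects without an explicit serialize/deserialize round trip in the plan. `Person` and the sample data are assumptions for illustration.

```scala
case class Person(name: String, age: Int)            // hypothetical element type

import spark.implicits._
val people = Seq(Person("Alice", 29), Person("Bob", 17)).toDS()

// Typed predicate: the closure receives deserialized Person objects.
val adults = people.filter(_.age >= 18)

// Typed projection: the result is a Dataset[String].
val names = adults.map(_.name)
```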
+ * Returns a new Dataset that contains the result of applying `f` to each partition. * - * @group func + * @group typedrel * @since 1.6.0 */ @Experimental @@ -1984,8 +2058,6 @@ class Dataset[T] private[sql]( /** * Returns a new [[DataFrame]] that contains the result of applying a serialized R function * `func` to each partition. - * - * @group func */ private[sql] def mapPartitionsInR( func: Array[Byte], @@ -2001,10 +2073,10 @@ class Dataset[T] private[sql]( /** * :: Experimental :: * (Scala-specific) - * Returns a new [[Dataset]] by first applying a function to all elements of this [[Dataset]], + * Returns a new Dataset by first applying a function to all elements of this Dataset, * and then flattening the results. * - * @group func + * @group typedrel * @since 1.6.0 */ @Experimental @@ -2014,10 +2086,10 @@ class Dataset[T] private[sql]( /** * :: Experimental :: * (Java-specific) - * Returns a new [[Dataset]] by first applying a function to all elements of this [[Dataset]], + * Returns a new Dataset by first applying a function to all elements of this Dataset, * and then flattening the results. * - * @group func + * @group typedrel * @since 1.6.0 */ @Experimental @@ -2038,7 +2110,7 @@ class Dataset[T] private[sql]( /** * (Java-specific) - * Runs `func` on each element of this [[Dataset]]. + * Runs `func` on each element of this Dataset. * * @group action * @since 1.6.0 @@ -2046,7 +2118,7 @@ class Dataset[T] private[sql]( def foreach(func: ForeachFunction[T]): Unit = foreach(func.call(_)) /** - * Applies a function `f` to each partition of this [[Dataset]]. + * Applies a function `f` to each partition of this Dataset. * * @group action * @since 1.6.0 @@ -2057,7 +2129,7 @@ class Dataset[T] private[sql]( /** * (Java-specific) - * Runs `func` on each partition of this [[Dataset]]. + * Runs `func` on each partition of this Dataset. * * @group action * @since 1.6.0 @@ -2066,7 +2138,7 @@ class Dataset[T] private[sql]( foreachPartition(it => func.call(it.asJava)) /** - * Returns the first `n` rows in the [[Dataset]]. + * Returns the first `n` rows in the Dataset. * * Running take requires moving data into the application's driver process, and doing so with * a very large `n` can crash the driver process with OutOfMemoryError. @@ -2077,7 +2149,7 @@ class Dataset[T] private[sql]( def take(n: Int): Array[T] = head(n) /** - * Returns the first `n` rows in the [[Dataset]] as a list. + * Returns the first `n` rows in the Dataset as a list. * * Running take requires moving data into the application's driver process, and doing so with * a very large `n` can crash the driver process with OutOfMemoryError. @@ -2088,7 +2160,7 @@ class Dataset[T] private[sql]( def takeAsList(n: Int): java.util.List[T] = java.util.Arrays.asList(take(n) : _*) /** - * Returns an array that contains all of [[Row]]s in this [[Dataset]]. + * Returns an array that contains all of [[Row]]s in this Dataset. * * Running collect requires moving all the data into the application's driver process, and * doing so on a very large dataset can crash the driver process with OutOfMemoryError. @@ -2101,7 +2173,7 @@ class Dataset[T] private[sql]( def collect(): Array[T] = collect(needCallback = true) /** - * Returns a Java list that contains all of [[Row]]s in this [[Dataset]]. + * Returns a Java list that contains all of [[Row]]s in this Dataset. * * Running collect requires moving all the data into the application's driver process, and * doing so on a very large dataset can crash the driver process with OutOfMemoryError. 
@@ -2111,14 +2183,14 @@ class Dataset[T] private[sql]( */ def collectAsList(): java.util.List[T] = withCallback("collectAsList", toDF()) { _ => withNewExecutionId { - val values = queryExecution.executedPlan.executeCollect().map(boundTEncoder.fromRow) + val values = queryExecution.executedPlan.executeCollect().map(boundEnc.fromRow) java.util.Arrays.asList(values : _*) } } private def collect(needCallback: Boolean): Array[T] = { def execute(): Array[T] = withNewExecutionId { - queryExecution.executedPlan.executeCollect().map(boundTEncoder.fromRow) + queryExecution.executedPlan.executeCollect().map(boundEnc.fromRow) } if (needCallback) { @@ -2129,9 +2201,9 @@ class Dataset[T] private[sql]( } /** - * Return an iterator that contains all of [[Row]]s in this [[Dataset]]. + * Return an iterator that contains all of [[Row]]s in this Dataset. * - * The iterator will consume as much memory as the largest partition in this [[Dataset]]. + * The iterator will consume as much memory as the largest partition in this Dataset. * * Note: this results in multiple Spark jobs, and if the input Dataset is the result * of a wide transformation (e.g. join with different partitioners), to avoid @@ -2142,12 +2214,12 @@ class Dataset[T] private[sql]( */ def toLocalIterator(): java.util.Iterator[T] = withCallback("toLocalIterator", toDF()) { _ => withNewExecutionId { - queryExecution.executedPlan.executeToIterator().map(boundTEncoder.fromRow).asJava + queryExecution.executedPlan.executeToIterator().map(boundEnc.fromRow).asJava } } /** - * Returns the number of rows in the [[Dataset]]. + * Returns the number of rows in the Dataset. * @group action * @since 1.6.0 */ @@ -2156,7 +2228,7 @@ class Dataset[T] private[sql]( } /** - * Returns a new [[Dataset]] that has exactly `numPartitions` partitions. + * Returns a new Dataset that has exactly `numPartitions` partitions. * * @group typedrel * @since 1.6.0 @@ -2166,7 +2238,7 @@ class Dataset[T] private[sql]( } /** - * Returns a new [[Dataset]] partitioned by the given partitioning expressions into + * Returns a new Dataset partitioned by the given partitioning expressions into * `numPartitions`. The resulting Dataset is hash partitioned. * * This is the same operation as "DISTRIBUTE BY" in SQL (Hive QL). @@ -2180,8 +2252,9 @@ class Dataset[T] private[sql]( } /** - * Returns a new [[Dataset]] partitioned by the given partitioning expressions preserving - * the existing number of partitions. The resulting Datasetis hash partitioned. + * Returns a new Dataset partitioned by the given partitioning expressions, using + * `spark.sql.shuffle.partitions` as number of partitions. + * The resulting Dataset is hash partitioned. * * This is the same operation as "DISTRIBUTE BY" in SQL (Hive QL). * @@ -2194,12 +2267,12 @@ class Dataset[T] private[sql]( } /** - * Returns a new [[Dataset]] that has exactly `numPartitions` partitions. + * Returns a new Dataset that has exactly `numPartitions` partitions. * Similar to coalesce defined on an [[RDD]], this operation results in a narrow dependency, e.g. * if you go from 1000 partitions to 100 partitions, there will not be a shuffle, instead each of * the 100 new partitions will claim 10 of the current partitions. * - * @group rdd + * @group typedrel * @since 1.6.0 */ def coalesce(numPartitions: Int): Dataset[T] = withTypedPlan { @@ -2207,7 +2280,7 @@ class Dataset[T] private[sql]( } /** - * Returns a new [[Dataset]] that contains only the unique rows from this [[Dataset]]. 
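A sketch of the `repartition` and `coalesce` behavior the corrected Scaladoc above describes; `customerId` is a hypothetical column.

```scala
// Hash-partition by a column into exactly 200 partitions.
val byKey = df.repartition(200, df("customerId"))

// Without an explicit count, repartitioning by expressions falls back to
// spark.sql.shuffle.partitions for the number of partitions.
val byKeyDefault = df.repartition(df("customerId"))

// coalesce narrows 200 partitions down to 10 without a shuffle.
val fewer = byKey.coalesce(10)
```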
+ * Returns a new Dataset that contains only the unique rows from this Dataset. * This is an alias for `dropDuplicates`. * * Note that, equality checking is performed directly on the encoded representation of the data @@ -2219,18 +2292,18 @@ class Dataset[T] private[sql]( def distinct(): Dataset[T] = dropDuplicates() /** - * Persist this [[Dataset]] with the default storage level (`MEMORY_AND_DISK`). + * Persist this Dataset with the default storage level (`MEMORY_AND_DISK`). * * @group basic * @since 1.6.0 */ def persist(): this.type = { - sparkSession.cacheManager.cacheQuery(this) + sparkSession.sharedState.cacheManager.cacheQuery(this) this } /** - * Persist this [[Dataset]] with the default storage level (`MEMORY_AND_DISK`). + * Persist this Dataset with the default storage level (`MEMORY_AND_DISK`). * * @group basic * @since 1.6.0 @@ -2238,7 +2311,7 @@ class Dataset[T] private[sql]( def cache(): this.type = persist() /** - * Persist this [[Dataset]] with the given storage level. + * Persist this Dataset with the given storage level. * @param newLevel One of: `MEMORY_ONLY`, `MEMORY_AND_DISK`, `MEMORY_ONLY_SER`, * `MEMORY_AND_DISK_SER`, `DISK_ONLY`, `MEMORY_ONLY_2`, * `MEMORY_AND_DISK_2`, etc. @@ -2247,12 +2320,12 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def persist(newLevel: StorageLevel): this.type = { - sparkSession.cacheManager.cacheQuery(this, None, newLevel) + sparkSession.sharedState.cacheManager.cacheQuery(this, None, newLevel) this } /** - * Mark the [[Dataset]] as non-persistent, and remove all blocks for it from memory and disk. + * Mark the Dataset as non-persistent, and remove all blocks for it from memory and disk. * * @param blocking Whether to block until all blocks are deleted. * @@ -2260,12 +2333,12 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def unpersist(blocking: Boolean): this.type = { - sparkSession.cacheManager.tryUncacheQuery(this, blocking) + sparkSession.sharedState.cacheManager.uncacheQuery(this, blocking) this } /** - * Mark the [[Dataset]] as non-persistent, and remove all blocks for it from memory and disk. + * Mark the Dataset as non-persistent, and remove all blocks for it from memory and disk. * * @group basic * @since 1.6.0 @@ -2273,56 +2346,117 @@ class Dataset[T] private[sql]( def unpersist(): this.type = unpersist(blocking = false) /** - * Represents the content of the [[Dataset]] as an [[RDD]] of [[T]]. + * Represents the content of the Dataset as an [[RDD]] of [[T]]. * - * @group rdd + * @group basic * @since 1.6.0 */ lazy val rdd: RDD[T] = { - val objectType = unresolvedTEncoder.deserializer.dataType + val objectType = exprEnc.deserializer.dataType val deserialized = CatalystSerde.deserialize[T](logicalPlan) - sparkSession.executePlan(deserialized).toRdd.mapPartitions { rows => + sparkSession.sessionState.executePlan(deserialized).toRdd.mapPartitions { rows => rows.map(_.get(0, objectType).asInstanceOf[T]) } } /** - * Returns the content of the [[Dataset]] as a [[JavaRDD]] of [[Row]]s. - * @group rdd + * Returns the content of the Dataset as a [[JavaRDD]] of [[T]]s. + * @group basic * @since 1.6.0 */ def toJavaRDD: JavaRDD[T] = rdd.toJavaRDD() /** - * Returns the content of the [[Dataset]] as a [[JavaRDD]] of [[Row]]s. - * @group rdd + * Returns the content of the Dataset as a [[JavaRDD]] of [[T]]s. + * @group basic * @since 1.6.0 */ def javaRDD: JavaRDD[T] = toJavaRDD /** - * Registers this [[Dataset]] as a temporary table using the given name. 
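A small caching sketch to accompany the `persist`/`unpersist` changes (which now route through `sharedState.cacheManager`); `df` and `otherDF` are placeholder Datasets.

```scala
import org.apache.spark.storage.StorageLevel

// Default storage level (MEMORY_AND_DISK); cache() is shorthand for persist().
df.persist()
df.count()                          // the first action materializes the cache

// Later, free the blocks; blocking = true waits until they are deleted.
df.unpersist(blocking = true)

// An explicit storage level can be requested instead of the default.
otherDF.persist(StorageLevel.MEMORY_ONLY_SER)
```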
The lifetime of this - * temporary table is tied to the [[SQLContext]] that was used to create this Dataset. + * Registers this Dataset as a temporary table using the given name. The lifetime of this + * temporary table is tied to the [[SparkSession]] that was used to create this Dataset. * * @group basic * @since 1.6.0 */ + @deprecated("Use createOrReplaceTempView(viewName) instead.", "2.0.0") def registerTempTable(tableName: String): Unit = { - sparkSession.registerTable(toDF(), tableName) + createOrReplaceTempView(tableName) + } + + /** + * Creates a temporary view using the given name. The lifetime of this + * temporary view is tied to the [[SparkSession]] that was used to create this Dataset. + * + * @throws AnalysisException if the view name already exists + * + * @group basic + * @since 2.0.0 + */ + @throws[AnalysisException] + def createTempView(viewName: String): Unit = withPlan { + val tableDesc = CatalogTable( + identifier = sparkSession.sessionState.sqlParser.parseTableIdentifier(viewName), + tableType = CatalogTableType.VIEW, + schema = Seq.empty[CatalogColumn], + storage = CatalogStorageFormat.empty) + CreateViewCommand(tableDesc, logicalPlan, allowExisting = false, replace = false, + isTemporary = true) + } + + /** + * Creates a temporary view using the given name. The lifetime of this + * temporary view is tied to the [[SparkSession]] that was used to create this Dataset. + * + * @group basic + * @since 2.0.0 + */ + def createOrReplaceTempView(viewName: String): Unit = withPlan { + val tableDesc = CatalogTable( + identifier = sparkSession.sessionState.sqlParser.parseTableIdentifier(viewName), + tableType = CatalogTableType.VIEW, + schema = Seq.empty[CatalogColumn], + storage = CatalogStorageFormat.empty) + CreateViewCommand(tableDesc, logicalPlan, allowExisting = false, replace = true, + isTemporary = true) } /** * :: Experimental :: - * Interface for saving the content of the [[Dataset]] out into external storage or streams. + * Interface for saving the content of the non-streaming Dataset out into external storage. * - * @group output + * @group basic * @since 1.6.0 */ @Experimental - def write: DataFrameWriter = new DataFrameWriter(toDF()) + def write: DataFrameWriter[T] = { + if (isStreaming) { + logicalPlan.failAnalysis( + "'write' can not be called on streaming Dataset/DataFrame") + } + new DataFrameWriter[T](this) + } /** - * Returns the content of the [[Dataset]] as a Dataset of JSON strings. + * :: Experimental :: + * Interface for saving the content of the streaming Dataset out into external storage. + * + * @group basic + * @since 2.0.0 + */ + @Experimental + def writeStream: DataStreamWriter[T] = { + if (!isStreaming) { + logicalPlan.failAnalysis( + "'writeStream' can be called only on streaming Dataset/DataFrame") + } + new DataStreamWriter[T](this) + } + + + /** + * Returns the content of the Dataset as a Dataset of JSON strings. * @since 2.0.0 */ def toJSON: Dataset[String] = { @@ -2378,19 +2512,23 @@ class Dataset[T] private[sql]( /** * Converts a JavaRDD to a PythonRDD. 
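A usage sketch of the temp-view methods that replace `registerTempTable`, and of the batch/streaming split between `write` and `writeStream`. The view name, column names, paths, and `streamingDF` are assumptions for illustration.

```scala
// createTempView throws AnalysisException if "people" already exists;
// createOrReplaceTempView overwrites it. Both replace the deprecated
// registerTempTable and are scoped to the owning SparkSession.
df.createOrReplaceTempView("people")
val adults = spark.sql("SELECT name FROM people WHERE age >= 18")

// Batch output goes through write; streaming output must use writeStream.
// Each call fails analysis when used on the wrong kind of Dataset.
adults.write.mode("overwrite").parquet("/tmp/people/adults")

val query = streamingDF.writeStream
  .format("parquet")
  .option("checkpointLocation", "/tmp/checkpoints/people")
  .start("/tmp/people/stream")
```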
*/ - protected[sql] def javaToPython: JavaRDD[Array[Byte]] = { + private[sql] def javaToPython: JavaRDD[Array[Byte]] = { val structType = schema // capture it for closure val rdd = queryExecution.toRdd.map(EvaluatePython.toJava(_, structType)) EvaluatePython.javaToPython(rdd) } - protected[sql] def collectToPython(): Int = { + private[sql] def collectToPython(): Int = { + EvaluatePython.registerPicklers() withNewExecutionId { - PythonRDD.collectAndServe(javaToPython.rdd) + val toJava: (Any) => Any = EvaluatePython.toJava(_, schema) + val iter = new SerDeUtil.AutoBatchedPickler( + queryExecution.executedPlan.executeCollect().iterator.map(toJava)) + PythonRDD.serveIterator(iter, "serve-DataFrame") } } - protected[sql] def toPythonIterator(): Int = { + private[sql] def toPythonIterator(): Int = { withNewExecutionId { PythonRDD.toLocalIteratorAndServe(javaToPython.rdd) } @@ -2469,4 +2607,14 @@ class Dataset[T] private[sql]( @inline private def withTypedPlan[U : Encoder](logicalPlan: => LogicalPlan): Dataset[U] = { Dataset(sparkSession, logicalPlan) } + + /** A convenient function to wrap a set based logical plan and produce a Dataset. */ + @inline private def withSetOperator[U : Encoder](logicalPlan: => LogicalPlan): Dataset[U] = { + if (classTag.runtimeClass.isAssignableFrom(classOf[Row])) { + // Set operators widen types (change the schema), so we cannot reuse the row encoder. + Dataset.ofRows(sparkSession, logicalPlan).asInstanceOf[Dataset[U]] + } else { + Dataset(sparkSession, logicalPlan) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala index c5df028485373..a435734b0caef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.rules.Rule * regarding binary compatibility and source compatibility of methods here. * * {{{ - * sqlContext.experimental.extraStrategies += ... + * spark.experimental.extraStrategies += ... * }}} * * @since 1.3.0 @@ -42,9 +42,9 @@ class ExperimentalMethods private[sql]() { * @since 1.3.0 */ @Experimental - var extraStrategies: Seq[Strategy] = Nil + @volatile var extraStrategies: Seq[Strategy] = Nil @Experimental - var extraOptimizations: Seq[Rule[LogicalPlan]] = Nil + @volatile var extraOptimizations: Seq[Rule[LogicalPlan]] = Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala new file mode 100644 index 0000000000000..f56b25b5576f1 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/ForeachWriter.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.streaming.StreamingQuery + +/** + * :: Experimental :: + * A class to consume data generated by a [[StreamingQuery]]. Typically this is used to send the + * generated data to external systems. Each partition will use a new deserialized instance, so you + * usually should do all the initialization (e.g. opening a connection or initiating a transaction) + * in the `open` method. + * + * Scala example: + * {{{ + * datasetOfString.write.foreach(new ForeachWriter[String] { + * + * def open(partitionId: Long, version: Long): Boolean = { + * // open connection + * } + * + * def process(record: String) = { + * // write string to connection + * } + * + * def close(errorOrNull: Throwable): Unit = { + * // close the connection + * } + * }) + * }}} + * + * Java example: + * {{{ + * datasetOfString.write().foreach(new ForeachWriter() { + * + * @Override + * public boolean open(long partitionId, long version) { + * // open connection + * } + * + * @Override + * public void process(String value) { + * // write string to connection + * } + * + * @Override + * public void close(Throwable errorOrNull) { + * // close the connection + * } + * }); + * }}} + * @since 2.0.0 + */ +@Experimental +abstract class ForeachWriter[T] extends Serializable { + + /** + * Called when starting to process one partition of new data in the executor. The `version` is + * for data deduplication when there are failures. When recovering from a failure, some data may + * be generated multiple times but they will always have the same version. + * + * If this method finds using the `partitionId` and `version` that this partition has already been + * processed, it can return `false` to skip the further data processing. However, `close` still + * will be called for cleaning up resources. + * + * @param partitionId the partition id. + * @param version a unique id for data deduplication. + * @return `true` if the corresponding partition and version id should be processed. `false` + * indicates the partition should be skipped. + */ + def open(partitionId: Long, version: Long): Boolean + + /** + * Called to process the data in the executor side. This method will be called only when `open` + * returns `true`. + */ + def process(value: T): Unit + + /** + * Called when stopping to process one partition of new data in the executor side. This is + * guaranteed to be called either `open` returns `true` or `false`. However, + * `close` won't be called in the following cases: + * - JVM crashes without throwing a `Throwable` + * - `open` throws a `Throwable`. + * + * @param errorOrNull the error thrown during processing data or null if there was no error. 
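To make the `open`/`process`/`close` contract concrete, here is a minimal sketch of a `ForeachWriter` wired into a streaming query. `streamingStrings` stands in for a streaming `Dataset[String]`, and the body of each method is a placeholder for whatever client library the real sink uses.

```scala
import org.apache.spark.sql.ForeachWriter

class StdoutWriter extends ForeachWriter[String] {
  override def open(partitionId: Long, version: Long): Boolean = {
    // Return false to skip a (partitionId, version) pair that was already handled.
    true
  }
  override def process(value: String): Unit = {
    println(s"record: $value")        // send the record to the external system
  }
  override def close(errorOrNull: Throwable): Unit = {
    // Release per-partition resources; errorOrNull is non-null on failure.
  }
}

val query = streamingStrings.writeStream
  .foreach(new StdoutWriter)
  .start()
```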
+ */ + def close(errorOrNull: Throwable): Unit +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 3a5ea19b8ad14..8eec42aab4fa3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -21,16 +21,17 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.function._ -import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, OuterScopes} +import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CreateStruct} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.expressions.ReduceAggregator /** * :: Experimental :: * A [[Dataset]] has been logically grouped by a user specified grouping key. Users should not - * construct a [[KeyValueGroupedDataset]] directly, but should instead call `groupBy` on an existing - * [[Dataset]]. + * construct a [[KeyValueGroupedDataset]] directly, but should instead call `groupByKey` on + * an existing [[Dataset]]. * * @since 2.0.0 */ @@ -42,17 +43,9 @@ class KeyValueGroupedDataset[K, V] private[sql]( private val dataAttributes: Seq[Attribute], private val groupingAttributes: Seq[Attribute]) extends Serializable { - // Similar to [[Dataset]], we use unresolved encoders for later composition and resolved encoders - // when constructing new logical plans that will operate on the output of the current - // queryexecution. - - private implicit val unresolvedKEncoder = encoderFor(kEncoder) - private implicit val unresolvedVEncoder = encoderFor(vEncoder) - - private val resolvedKEncoder = - unresolvedKEncoder.resolve(groupingAttributes, OuterScopes.outerScopes) - private val resolvedVEncoder = - unresolvedVEncoder.resolve(dataAttributes, OuterScopes.outerScopes) + // Similar to [[Dataset]], we turn the passed in encoder to `ExpressionEncoder` explicitly. + private implicit val kExprEnc = encoderFor(kEncoder) + private implicit val vExprEnc = encoderFor(vEncoder) private def logicalPlan = queryExecution.analyzed private def sparkSession = queryExecution.sparkSession @@ -67,13 +60,14 @@ class KeyValueGroupedDataset[K, V] private[sql]( def keyAs[L : Encoder]: KeyValueGroupedDataset[L, V] = new KeyValueGroupedDataset( encoderFor[L], - unresolvedVEncoder, + vExprEnc, queryExecution, dataAttributes, groupingAttributes) /** - * Returns a [[Dataset]] that contains each unique key. + * Returns a [[Dataset]] that contains each unique key. This is equivalent to doing mapping + * over the Dataset to extract the keys and then running a distinct operation on those. * * @since 1.6.0 */ @@ -184,10 +178,9 @@ class KeyValueGroupedDataset[K, V] private[sql]( * @since 1.6.0 */ def reduceGroups(f: (V, V) => V): Dataset[(K, V)] = { - val func = (key: K, it: Iterator[V]) => Iterator((key, it.reduce(f))) - - implicit val resultEncoder = ExpressionEncoder.tuple(unresolvedKEncoder, unresolvedVEncoder) - flatMapGroups(func) + val vEncoder = encoderFor[V] + val aggregator: TypedColumn[V, V] = new ReduceAggregator[V](f)(vEncoder).toColumn + agg(aggregator) } /** @@ -204,13 +197,12 @@ class KeyValueGroupedDataset[K, V] private[sql]( * Internal helper function for building typed aggregations that return tuples. 
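A sketch of the `groupByKey`/`keys`/`reduceGroups` surface touched above (with `reduceGroups` now implemented via `ReduceAggregator`). The `Sale` case class and data are illustrative only.

```scala
case class Sale(product: String, amount: Double)     // hypothetical element type
import spark.implicits._

val sales = Seq(Sale("a", 1.0), Sale("b", 2.0), Sale("a", 3.0)).toDS()
val grouped = sales.groupByKey(_.product)

// Distinct grouping keys, as described in the updated keys() doc.
val products = grouped.keys

// Reduce each group with a binary function; returns a Dataset[(String, Sale)].
val totals = grouped.reduceGroups((x, y) => Sale(x.product, x.amount + y.amount))
```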
For simplicity * and code reuse, we do this without the help of the type system and then use helper functions * that cast appropriately for the user facing interface. - * TODO: does not handle aggregations that return nonflat results, */ protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { val encoders = columns.map(_.encoder) val namedColumns = - columns.map(_.withInputType(unresolvedVEncoder.deserializer, dataAttributes).named) - val keyColumn = if (resolvedKEncoder.flat) { + columns.map(_.withInputType(vExprEnc.deserializer, dataAttributes).named) + val keyColumn = if (kExprEnc.flat) { assert(groupingAttributes.length == 1) groupingAttributes.head } else { @@ -222,7 +214,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( new Dataset( sparkSession, execution, - ExpressionEncoder.tuple(unresolvedKEncoder +: encoders)) + ExpressionEncoder.tuple(kExprEnc +: encoders)) } /** @@ -287,7 +279,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( def cogroup[U, R : Encoder]( other: KeyValueGroupedDataset[K, U])( f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R] = { - implicit val uEncoder = other.unresolvedVEncoder + implicit val uEncoder = other.vExprEnc Dataset[R]( sparkSession, CoGroup( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 4f5bf633fab2e..6148ddfe05acb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -20,13 +20,18 @@ package org.apache.spark.sql import scala.collection.JavaConverters._ import scala.language.implicitConversions +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.sql.api.r.SQLUtils._ import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction} +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Pivot} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, FlatMapGroupsInR, Pivot} import org.apache.spark.sql.catalyst.util.usePrettyExpression +import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.NumericType +import org.apache.spark.sql.types.StructType /** * A set of methods for aggregations on a [[DataFrame]], created by [[Dataset.groupBy]]. @@ -34,6 +39,8 @@ import org.apache.spark.sql.types.NumericType * The main method is the agg function, which has multiple variants. This class also contains * convenience some first order statistics such as mean, sum for convenience. * + * This class was named `GroupedData` in Spark 1.x. 
+ * * @since 2.0.0 */ class RelationalGroupedDataset protected[sql]( @@ -73,6 +80,8 @@ class RelationalGroupedDataset protected[sql]( private[this] def alias(expr: Expression): NamedExpression = expr match { case u: UnresolvedAttribute => UnresolvedAlias(u) case expr: NamedExpression => expr + case a: AggregateExpression if a.aggregateFunction.isInstanceOf[TypedAggregateExpression] => + UnresolvedAlias(a, Some(Column.generateAlias)) case expr: Expression => Alias(expr, usePrettyExpression(expr).sql)() } @@ -119,7 +128,7 @@ class RelationalGroupedDataset protected[sql]( } /** - * (Scala-specific) Compute aggregates by specifying a map from column name to + * (Scala-specific) Compute aggregates by specifying the column names and * aggregate methods. The resulting [[DataFrame]] will also contain the grouping columns. * * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`. @@ -134,7 +143,9 @@ class RelationalGroupedDataset protected[sql]( * @since 1.3.0 */ def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = { - agg((aggExpr +: aggExprs).toMap) + toDF((aggExpr +: aggExprs).map { case (colName, expr) => + strToExpr(expr)(df(colName).expr) + }) } /** @@ -210,7 +221,7 @@ class RelationalGroupedDataset protected[sql]( def agg(expr: Column, exprs: Column*): DataFrame = { toDF((expr +: exprs).map { case typed: TypedColumn[_, _] => - typed.withInputType(df.unresolvedTEncoder.deserializer, df.logicalPlan.output).expr + typed.withInputType(df.exprEnc.deserializer, df.logicalPlan.output).expr case c => c.expr }) } @@ -376,6 +387,48 @@ class RelationalGroupedDataset protected[sql]( def pivot(pivotColumn: String, values: java.util.List[Any]): RelationalGroupedDataset = { pivot(pivotColumn, values.asScala) } + + /** + * Applies the given serialized R function `func` to each group of data. For each unique group, + * the function will be passed the group key and an iterator that contains all of the elements in + * the group. The function can return an iterator containing elements of an arbitrary type which + * will be returned as a new [[DataFrame]]. + * + * This function does not support partial aggregation, and as a result requires shuffling all + * the data in the [[Dataset]]. If an application intends to perform an aggregation over each + * key, it is best to use the reduce function or an + * [[org.apache.spark.sql.expressions#Aggregator Aggregator]]. + * + * Internally, the implementation will spill to disk if any given group is too large to fit into + * memory. However, users must take care to avoid materializing the whole iterator for a group + * (for example, by calling `toList`) unless they are sure that this is possible given the memory + * constraints of their cluster. 
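A usage sketch of the `(column name, aggregate method)` pair form of `agg` whose implementation changes above; `department`, `salary`, and `age` are placeholder columns.

```scala
// Each pair is (column name, aggregate method); available methods include
// avg, max, min, sum, and count.
val summary = df.groupBy("department").agg(
  "salary" -> "avg",
  "age"    -> "max")
```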
+ * + * @since 2.0.0 + */ + private[sql] def flatMapGroupsInR( + f: Array[Byte], + packageNames: Array[Byte], + broadcastVars: Array[Broadcast[Object]], + outputSchema: StructType): DataFrame = { + val groupingNamedExpressions = groupingExprs.map(alias) + val groupingCols = groupingNamedExpressions.map(Column(_)) + val groupingDataFrame = df.select(groupingCols : _*) + val groupingAttributes = groupingNamedExpressions.map(_.toAttribute) + Dataset.ofRows( + df.sparkSession, + FlatMapGroupsInR( + f, + packageNames, + broadcastVars, + outputSchema, + groupingDataFrame.exprEnc.deserializer, + df.exprEnc.deserializer, + df.exprEnc.schema, + groupingAttributes, + df.logicalPlan.output, + df.logicalPlan)) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala b/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala index 670288b23400e..7e07e0cb84a87 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import org.apache.spark.internal.config.ConfigEntry +import org.apache.spark.internal.config.{ConfigEntry, OptionalConfigEntry} import org.apache.spark.sql.internal.SQLConf @@ -35,9 +35,8 @@ class RuntimeConfig private[sql](sqlConf: SQLConf = new SQLConf) { * * @since 2.0.0 */ - def set(key: String, value: String): RuntimeConfig = { + def set(key: String, value: String): Unit = { sqlConf.setConfString(key, value) - this } /** @@ -45,7 +44,7 @@ class RuntimeConfig private[sql](sqlConf: SQLConf = new SQLConf) { * * @since 2.0.0 */ - def set(key: String, value: Boolean): RuntimeConfig = { + def set(key: String, value: Boolean): Unit = { set(key, value.toString) } @@ -54,7 +53,7 @@ class RuntimeConfig private[sql](sqlConf: SQLConf = new SQLConf) { * * @since 2.0.0 */ - def set(key: String, value: Long): RuntimeConfig = { + def set(key: String, value: Long): Unit = { set(key, value.toString) } @@ -86,6 +85,10 @@ class RuntimeConfig private[sql](sqlConf: SQLConf = new SQLConf) { sqlConf.getConf(entry) } + protected[sql] def get[T](entry: OptionalConfigEntry[T]): Option[T] = { + sqlConf.getConf(entry) + } + /** * Returns the value of Spark runtime configuration property for the given key. 
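A sketch of the `RuntimeConfig` accessors touched above (note that `set` now returns `Unit` rather than chaining); the unset key is a made-up example.

```scala
// `spark` is an existing SparkSession; conf exposes its RuntimeConfig.
spark.conf.set("spark.sql.shuffle.partitions", "64")

val partitions  = spark.conf.get("spark.sql.shuffle.partitions")
val withDefault = spark.conf.get("spark.some.unset.key", "fallback")
```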
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 168ac7e04b923..e7627ac2c95ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -19,35 +19,30 @@ package org.apache.spark.sql import java.beans.BeanInfo import java.util.Properties -import java.util.concurrent.atomic.AtomicReference import scala.collection.immutable import scala.reflect.runtime.universe.TypeTag -import org.apache.spark.{SparkConf, SparkContext, SparkException} +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.internal.Logging import org.apache.spark.internal.config.ConfigEntry import org.apache.spark.rdd.RDD -import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd} import org.apache.spark.sql.catalyst._ -import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.command.ShowTablesCommand -import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab} import org.apache.spark.sql.internal.{SessionState, SharedState, SQLConf} import org.apache.spark.sql.sources.BaseRelation +import org.apache.spark.sql.streaming.{DataStreamReader, StreamingQueryManager} import org.apache.spark.sql.types._ import org.apache.spark.sql.util.ExecutionListenerManager /** - * The entry point for working with structured data (rows and columns) in Spark, in Spark 1.x. + * The entry point for working with structured data (rows and columns) in Spark 1.x. * - * As of Spark 2.0, this is replaced by [[SparkSession]]. However, we are keeping the class here - * for backward compatibility. + * As of Spark 2.0, this is replaced by [[SparkSession]]. However, we are keeping the class + * here for backward compatibility. * * @groupname basic Basic Operations * @groupname ddl_ops Persistent Catalog DDL @@ -56,61 +51,33 @@ import org.apache.spark.sql.util.ExecutionListenerManager * @groupname specificdata Specific Data Sources * @groupname config Configuration * @groupname dataframes Custom DataFrame Creation - * @groupname dataset Custom DataFrame Creation + * @groupname dataset Custom Dataset Creation * @groupname Ungrouped Support functions for language integrated queries * @since 1.0.0 */ -class SQLContext private[sql]( - val sparkSession: SparkSession, - val isRootContext: Boolean) +class SQLContext private[sql](val sparkSession: SparkSession) extends Logging with Serializable { self => + sparkSession.sparkContext.assertNotStopped() + // Note: Since Spark 2.0 this class has become a wrapper of SparkSession, where the // real functionality resides. This class remains mainly for backward compatibility. 
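Since the `SQLContext` constructors are deprecated below in favor of the builder, a minimal sketch of the intended 2.0 entry point and the backward-compatible wrapper:

```scala
import org.apache.spark.sql.SparkSession

// Preferred 2.0 entry point; replaces `new SQLContext(sc)`.
val spark = SparkSession.builder()
  .appName("example")
  .getOrCreate()

// A SQLContext wrapping the same session, for code that still expects one.
val sqlContext = spark.sqlContext
```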
- private[sql] def this(sparkSession: SparkSession) = { - this(sparkSession, true) - } - + @deprecated("Use SparkSession.builder instead", "2.0.0") def this(sc: SparkContext) = { - this(new SparkSession(sc)) + this(SparkSession.builder().sparkContext(sc).getOrCreate()) } + @deprecated("Use SparkSession.builder instead", "2.0.0") def this(sparkContext: JavaSparkContext) = this(sparkContext.sc) // TODO: move this logic into SparkSession - // If spark.sql.allowMultipleContexts is true, we will throw an exception if a user - // wants to create a new root SQLContext (a SQLContext that is not created by newSession). - private val allowMultipleContexts = - sparkContext.conf.getBoolean( - SQLConf.ALLOW_MULTIPLE_CONTEXTS.key, - SQLConf.ALLOW_MULTIPLE_CONTEXTS.defaultValue.get) - - // Assert no root SQLContext is running when allowMultipleContexts is false. - { - if (!allowMultipleContexts && isRootContext) { - SQLContext.getInstantiatedContextOption() match { - case Some(rootSQLContext) => - val errMsg = "Only one SQLContext/HiveContext may be running in this JVM. " + - s"It is recommended to use SQLContext.getOrCreate to get the instantiated " + - s"SQLContext/HiveContext. To ignore this error, " + - s"set ${SQLConf.ALLOW_MULTIPLE_CONTEXTS.key} = true in SparkConf." - throw new SparkException(errMsg) - case None => // OK - } - } - } - - protected[sql] def sessionState: SessionState = sparkSession.sessionState - protected[sql] def sharedState: SharedState = sparkSession.sharedState - protected[sql] def conf: SQLConf = sessionState.conf - protected[sql] def runtimeConf: RuntimeConfig = sparkSession.conf - protected[sql] def cacheManager: CacheManager = sparkSession.cacheManager - protected[sql] def listener: SQLListener = sparkSession.listener - protected[sql] def externalCatalog: ExternalCatalog = sparkSession.externalCatalog + private[sql] def sessionState: SessionState = sparkSession.sessionState + private[sql] def sharedState: SharedState = sparkSession.sharedState + private[sql] def conf: SQLConf = sessionState.conf def sparkContext: SparkContext = sparkSession.sparkContext @@ -121,7 +88,7 @@ class SQLContext private[sql]( * * @since 1.6.0 */ - def newSession(): SQLContext = sparkSession.newSession().wrapped + def newSession(): SQLContext = sparkSession.newSession().sqlContext /** * An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s @@ -189,14 +156,6 @@ class SQLContext private[sql]( sparkSession.conf.getAll } - protected[sql] def parseSql(sql: String): LogicalPlan = sparkSession.parseSql(sql) - - protected[sql] def executeSql(sql: String): QueryExecution = sparkSession.executeSql(sql) - - protected[sql] def executePlan(plan: LogicalPlan): QueryExecution = { - sparkSession.executePlan(plan) - } - /** * :: Experimental :: * A collection of methods that are considered experimental, but can be used to hook into @@ -210,17 +169,18 @@ class SQLContext private[sql]( def experimental: ExperimentalMethods = sparkSession.experimental /** - * :: Experimental :: * Returns a [[DataFrame]] with no rows or columns. * * @group basic * @since 1.3.0 */ - @Experimental def emptyDataFrame: DataFrame = sparkSession.emptyDataFrame /** * A collection of methods for registering user-defined functions (UDF). + * Note that the user-defined functions must be deterministic. Due to optimization, + * duplicate invocations may be eliminated or the function may even be invoked more times than + * it is present in the query. 
* * The following example registers a Scala closure as UDF: * {{{ @@ -259,15 +219,6 @@ class SQLContext private[sql]( sparkSession.catalog.isCached(tableName) } - /** - * Returns true if the [[Dataset]] is currently cached in-memory. - * @group cachemgmt - * @since 1.3.0 - */ - private[sql] def isCached(qName: Dataset[_]): Boolean = { - sparkSession.cacheManager.lookupCachedData(qName).nonEmpty - } - /** * Caches the specified table in-memory. * @group cachemgmt @@ -374,7 +325,7 @@ class SQLContext private[sql]( * // |-- name: string (nullable = false) * // |-- age: integer (nullable = true) * - * dataFrame.registerTempTable("people") + * dataFrame.createOrReplaceTempView("people") * sqlContext.sql("select name from people").collect.foreach(println) * }}} * @@ -395,15 +346,73 @@ class SQLContext private[sql]( sparkSession.createDataFrame(rowRDD, schema, needsConversion) } - + /** + * :: Experimental :: + * Creates a [[Dataset]] from a local Seq of data of a given type. This method requires an + * encoder (to convert a JVM object of type `T` to and from the internal Spark SQL representation) + * that is generally created automatically through implicits from a `SparkSession`, or can be + * created explicitly by calling static methods on [[Encoders]]. + * + * == Example == + * + * {{{ + * + * import spark.implicits._ + * case class Person(name: String, age: Long) + * val data = Seq(Person("Michael", 29), Person("Andy", 30), Person("Justin", 19)) + * val ds = spark.createDataset(data) + * + * ds.show() + * // +-------+---+ + * // | name|age| + * // +-------+---+ + * // |Michael| 29| + * // | Andy| 30| + * // | Justin| 19| + * // +-------+---+ + * }}} + * + * @since 2.0.0 + * @group dataset + */ + @Experimental def createDataset[T : Encoder](data: Seq[T]): Dataset[T] = { sparkSession.createDataset(data) } + /** + * :: Experimental :: + * Creates a [[Dataset]] from an RDD of a given type. This method requires an + * encoder (to convert a JVM object of type `T` to and from the internal Spark SQL representation) + * that is generally created automatically through implicits from a `SparkSession`, or can be + * created explicitly by calling static methods on [[Encoders]]. + * + * @since 2.0.0 + * @group dataset + */ + @Experimental def createDataset[T : Encoder](data: RDD[T]): Dataset[T] = { sparkSession.createDataset(data) } + /** + * :: Experimental :: + * Creates a [[Dataset]] from a [[java.util.List]] of a given type. This method requires an + * encoder (to convert a JVM object of type `T` to and from the internal Spark SQL representation) + * that is generally created automatically through implicits from a `SparkSession`, or can be + * created explicitly by calling static methods on [[Encoders]]. + * + * == Java Example == + * + * {{{ + * List data = Arrays.asList("hello", "world"); + * Dataset ds = spark.createDataset(data, Encoders.STRING()); + * }}} + * + * @since 2.0.0 + * @group dataset + */ + @Experimental def createDataset[T : Encoder](data: java.util.List[T]): Dataset[T] = { sparkSession.createDataset(data) } @@ -419,7 +428,7 @@ class SQLContext private[sql]( /** * :: DeveloperApi :: - * Creates a [[DataFrame]] from an [[JavaRDD]] containing [[Row]]s using the given schema. + * Creates a [[DataFrame]] from a [[JavaRDD]] containing [[Row]]s using the given schema. * It is important to make sure that the structure of every [[Row]] of the provided RDD matches * the provided schema. Otherwise, there will be runtime exception. 
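A sketch of `createDataFrame` with an explicit schema, illustrating the warning above that each `Row` must line up with the schema or the mismatch only surfaces at runtime. The column names and sample rows are assumptions.

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

val schema = StructType(Seq(
  StructField("name", StringType, nullable = false),
  StructField("age",  IntegerType, nullable = true)))

// Each Row must match the schema positionally: (String, Int).
val rows   = sqlContext.sparkContext.parallelize(Seq(Row("Alice", 29), Row("Bob", 31)))
val people = sqlContext.createDataFrame(rows, schema)
```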
* @@ -433,7 +442,7 @@ class SQLContext private[sql]( /** * :: DeveloperApi :: - * Creates a [[DataFrame]] from an [[java.util.List]] containing [[Row]]s using the given schema. + * Creates a [[DataFrame]] from a [[java.util.List]] containing [[Row]]s using the given schema. * It is important to make sure that the structure of every [[Row]] of the provided List matches * the provided schema. Otherwise, there will be runtime exception. * @@ -470,7 +479,7 @@ class SQLContext private[sql]( } /** - * Applies a schema to an List of Java Beans. + * Applies a schema to a List of Java Beans. * * WARNING: Since there is no guaranteed ordering for fields in a Java Bean, * SELECT * queries will return the columns in an undefined order. @@ -482,8 +491,8 @@ class SQLContext private[sql]( } /** - * :: Experimental :: - * Returns a [[DataFrameReader]] that can be used to read data and streams in as a [[DataFrame]]. + * Returns a [[DataFrameReader]] that can be used to read non-streaming data in as a + * [[DataFrame]]. * {{{ * sqlContext.read.parquet("/path/to/file.parquet") * sqlContext.read.schema(schema).json("/path/to/file.json") @@ -492,31 +501,41 @@ class SQLContext private[sql]( * @group genericdata * @since 1.4.0 */ - @Experimental def read: DataFrameReader = sparkSession.read + /** * :: Experimental :: + * Returns a [[DataStreamReader]] that can be used to read streaming data in as a [[DataFrame]]. + * {{{ + * sparkSession.readStream.parquet("/path/to/directory/of/parquet/files") + * sparkSession.readStream.schema(schema).json("/path/to/directory/of/json/files") + * }}} + * + * @since 2.0.0 + */ + @Experimental + def readStream: DataStreamReader = sparkSession.readStream + + + /** * Creates an external table from the given path and returns the corresponding DataFrame. * It will use the default data source configured by spark.sql.sources.default. * * @group ddl_ops * @since 1.3.0 */ - @Experimental def createExternalTable(tableName: String, path: String): DataFrame = { sparkSession.catalog.createExternalTable(tableName, path) } /** - * :: Experimental :: * Creates an external table from the given path based on a data source * and returns the corresponding DataFrame. * * @group ddl_ops * @since 1.3.0 */ - @Experimental def createExternalTable( tableName: String, path: String, @@ -525,14 +544,12 @@ class SQLContext private[sql]( } /** - * :: Experimental :: * Creates an external table from the given path based on a data source and a set of options. * Then, returns the corresponding DataFrame. * * @group ddl_ops * @since 1.3.0 */ - @Experimental def createExternalTable( tableName: String, source: String, @@ -541,7 +558,6 @@ class SQLContext private[sql]( } /** - * :: Experimental :: * (Scala-specific) * Creates an external table from the given path based on a data source and a set of options. * Then, returns the corresponding DataFrame. @@ -549,7 +565,6 @@ class SQLContext private[sql]( * @group ddl_ops * @since 1.3.0 */ - @Experimental def createExternalTable( tableName: String, source: String, @@ -558,14 +573,12 @@ class SQLContext private[sql]( } /** - * :: Experimental :: * Create an external table from the given path based on a data source, a schema and * a set of options. Then, returns the corresponding DataFrame. 
* * @group ddl_ops * @since 1.3.0 */ - @Experimental def createExternalTable( tableName: String, source: String, @@ -575,7 +588,6 @@ class SQLContext private[sql]( } /** - * :: Experimental :: * (Scala-specific) * Create an external table from the given path based on a data source, a schema and * a set of options. Then, returns the corresponding DataFrame. @@ -583,7 +595,6 @@ class SQLContext private[sql]( * @group ddl_ops * @since 1.3.0 */ - @Experimental def createExternalTable( tableName: String, source: String, @@ -597,7 +608,7 @@ class SQLContext private[sql]( * only during the lifetime of this instance of SQLContext. */ private[sql] def registerDataFrameAsTable(df: DataFrame, tableName: String): Unit = { - sparkSession.registerTable(df, tableName) + df.createOrReplaceTempView(tableName) } /** @@ -609,56 +620,56 @@ class SQLContext private[sql]( * @since 1.3.0 */ def dropTempTable(tableName: String): Unit = { - sparkSession.catalog.dropTempTable(tableName) + sparkSession.catalog.dropTempView(tableName) } /** * :: Experimental :: - * Creates a [[Dataset]] with a single [[LongType]] column named `id`, containing elements - * in an range from 0 to `end` (exclusive) with step value 1. + * Creates a [[DataFrame]] with a single [[LongType]] column named `id`, containing elements + * in a range from 0 to `end` (exclusive) with step value 1. * - * @since 2.0.0 - * @group dataset + * @since 1.4.1 + * @group dataframe */ @Experimental - def range(end: Long): Dataset[java.lang.Long] = sparkSession.range(end) + def range(end: Long): DataFrame = sparkSession.range(end).toDF() /** * :: Experimental :: - * Creates a [[Dataset]] with a single [[LongType]] column named `id`, containing elements - * in an range from `start` to `end` (exclusive) with step value 1. + * Creates a [[DataFrame]] with a single [[LongType]] column named `id`, containing elements + * in a range from `start` to `end` (exclusive) with step value 1. * - * @since 2.0.0 - * @group dataset + * @since 1.4.0 + * @group dataframe */ @Experimental - def range(start: Long, end: Long): Dataset[java.lang.Long] = sparkSession.range(start, end) + def range(start: Long, end: Long): DataFrame = sparkSession.range(start, end).toDF() /** * :: Experimental :: - * Creates a [[Dataset]] with a single [[LongType]] column named `id`, containing elements - * in an range from `start` to `end` (exclusive) with an step value. + * Creates a [[DataFrame]] with a single [[LongType]] column named `id`, containing elements + * in a range from `start` to `end` (exclusive) with a step value. * * @since 2.0.0 - * @group dataset + * @group dataframe */ @Experimental - def range(start: Long, end: Long, step: Long): Dataset[java.lang.Long] = { - sparkSession.range(start, end, step) + def range(start: Long, end: Long, step: Long): DataFrame = { + sparkSession.range(start, end, step).toDF() } /** * :: Experimental :: - * Creates a [[Dataset]] with a single [[LongType]] column named `id`, containing elements + * Creates a [[DataFrame]] with a single [[LongType]] column named `id`, containing elements * in an range from `start` to `end` (exclusive) with an step value, with partition number * specified. 
* - * @since 2.0.0 - * @group dataset + * @since 1.4.0 + * @group dataframe */ @Experimental - def range(start: Long, end: Long, step: Long, numPartitions: Int): Dataset[java.lang.Long] = { - sparkSession.range(start, end, step, numPartitions) + def range(start: Long, end: Long, step: Long, numPartitions: Int): DataFrame = { + sparkSession.range(start, end, step, numPartitions).toDF() } /** @@ -705,12 +716,12 @@ class SQLContext private[sql]( } /** - * Returns a [[ContinuousQueryManager]] that allows managing all the - * [[org.apache.spark.sql.ContinuousQuery ContinuousQueries]] active on `this` context. + * Returns a [[StreamingQueryManager]] that allows managing all the + * [[org.apache.spark.sql.streaming.StreamingQuery StreamingQueries]] active on `this` context. * * @since 2.0.0 */ - def streams: ContinuousQueryManager = sparkSession.streams + def streams: StreamingQueryManager = sparkSession.streams /** * Returns the names of tables in the current database as an array. @@ -737,42 +748,294 @@ class SQLContext private[sql]( * have the same format as the one generated by `toString` in scala. * It is only used by PySpark. */ - protected[sql] def parseDataType(dataTypeString: String): DataType = { - sparkSession.parseDataType(dataTypeString) + // TODO: Remove this function (would require updating PySpark). + private[sql] def parseDataType(dataTypeString: String): DataType = { + DataType.fromJson(dataTypeString) } + //////////////////////////////////////////////////////////////////////////// + //////////////////////////////////////////////////////////////////////////// + // Deprecated methods + //////////////////////////////////////////////////////////////////////////// + //////////////////////////////////////////////////////////////////////////// + /** - * Apply a schema defined by the schemaString to an RDD. It is only used by PySpark. + * @deprecated As of 1.3.0, replaced by `createDataFrame()`. */ - protected[sql] def applySchemaToPythonRDD( - rdd: RDD[Array[Any]], - schemaString: String): DataFrame = { - sparkSession.applySchemaToPythonRDD(rdd, schemaString) + @deprecated("Use createDataFrame instead.", "1.3.0") + def applySchema(rowRDD: RDD[Row], schema: StructType): DataFrame = { + createDataFrame(rowRDD, schema) } /** - * Apply a schema defined by the schema to an RDD. It is only used by PySpark. + * @deprecated As of 1.3.0, replaced by `createDataFrame()`. */ - protected[sql] def applySchemaToPythonRDD( - rdd: RDD[Array[Any]], - schema: StructType): DataFrame = { - sparkSession.applySchemaToPythonRDD(rdd, schema) + @deprecated("Use createDataFrame instead.", "1.3.0") + def applySchema(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = { + createDataFrame(rowRDD, schema) } - // TODO: move this logic into SparkSession + /** + * @deprecated As of 1.3.0, replaced by `createDataFrame()`. + */ + @deprecated("Use createDataFrame instead.", "1.3.0") + def applySchema(rdd: RDD[_], beanClass: Class[_]): DataFrame = { + createDataFrame(rdd, beanClass) + } + + /** + * @deprecated As of 1.3.0, replaced by `createDataFrame()`. + */ + @deprecated("Use createDataFrame instead.", "1.3.0") + def applySchema(rdd: JavaRDD[_], beanClass: Class[_]): DataFrame = { + createDataFrame(rdd, beanClass) + } - // Register a successfully instantiated context to the singleton. This should be at the end of - // the class definition so that the singleton is updated only if there is no exception in the - // construction of the instance. 
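A short sketch contrasting the `SQLContext.range` variants (which now return a `DataFrame`, as documented above) with the `SparkSession` equivalent:

```scala
// Single LongType column named `id`: 0, 2, 4, ..., 98, spread over 4 partitions.
val ids = sqlContext.range(0L, 100L, 2L, 4)

// The SparkSession equivalent returns a Dataset[java.lang.Long] instead.
val typedIds = spark.range(0L, 100L)
```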
- sparkContext.addSparkListener(new SparkListener { - override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = { - SQLContext.clearInstantiatedContext() - SQLContext.clearSqlListener() + /** + * Loads a Parquet file, returning the result as a [[DataFrame]]. This function returns an empty + * [[DataFrame]] if no paths are passed in. + * + * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().parquet()`. + */ + @deprecated("Use read.parquet() instead.", "1.4.0") + @scala.annotation.varargs + def parquetFile(paths: String*): DataFrame = { + if (paths.isEmpty) { + emptyDataFrame + } else { + read.parquet(paths : _*) } - }) + } + + /** + * Loads a JSON file (one object per line), returning the result as a [[DataFrame]]. + * It goes through the entire dataset once to determine the schema. + * + * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().json()`. + */ + @deprecated("Use read.json() instead.", "1.4.0") + def jsonFile(path: String): DataFrame = { + read.json(path) + } + + /** + * Loads a JSON file (one object per line) and applies the given schema, + * returning the result as a [[DataFrame]]. + * + * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().json()`. + */ + @deprecated("Use read.json() instead.", "1.4.0") + def jsonFile(path: String, schema: StructType): DataFrame = { + read.schema(schema).json(path) + } + + /** + * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().json()`. + */ + @deprecated("Use read.json() instead.", "1.4.0") + def jsonFile(path: String, samplingRatio: Double): DataFrame = { + read.option("samplingRatio", samplingRatio.toString).json(path) + } - sparkSession.setWrappedContext(self) - SQLContext.setInstantiatedContext(self) + /** + * Loads an RDD[String] storing JSON objects (one object per record), returning the result as a + * [[DataFrame]]. + * It goes through the entire dataset once to determine the schema. + * + * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().json()`. + */ + @deprecated("Use read.json() instead.", "1.4.0") + def jsonRDD(json: RDD[String]): DataFrame = read.json(json) + + /** + * Loads an RDD[String] storing JSON objects (one object per record), returning the result as a + * [[DataFrame]]. + * It goes through the entire dataset once to determine the schema. + * + * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().json()`. + */ + @deprecated("Use read.json() instead.", "1.4.0") + def jsonRDD(json: JavaRDD[String]): DataFrame = read.json(json) + + /** + * Loads an RDD[String] storing JSON objects (one object per record) and applies the given schema, + * returning the result as a [[DataFrame]]. + * + * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().json()`. + */ + @deprecated("Use read.json() instead.", "1.4.0") + def jsonRDD(json: RDD[String], schema: StructType): DataFrame = { + read.schema(schema).json(json) + } + + /** + * Loads an JavaRDD storing JSON objects (one object per record) and applies the given + * schema, returning the result as a [[DataFrame]]. + * + * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().json()`. + */ + @deprecated("Use read.json() instead.", "1.4.0") + def jsonRDD(json: JavaRDD[String], schema: StructType): DataFrame = { + read.schema(schema).json(json) + } + + /** + * Loads an RDD[String] storing JSON objects (one object per record) inferring the + * schema, returning the result as a [[DataFrame]]. 
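A migration sketch for the deprecated `parquetFile`/`jsonFile`/`jsonRDD` helpers; the paths, `schema`, and `stringRDD` are placeholders assumed to exist.

```scala
// Old: sqlContext.parquetFile("/data/events.parquet")
val events = sqlContext.read.parquet("/data/events.parquet")

// Old: sqlContext.jsonFile("/data/users.json", schema)
val users = sqlContext.read.schema(schema).json("/data/users.json")

// Old: sqlContext.jsonRDD(stringRDD)
val fromRdd = sqlContext.read.json(stringRDD)
```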
+ * + * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().json()`. + */ + @deprecated("Use read.json() instead.", "1.4.0") + def jsonRDD(json: RDD[String], samplingRatio: Double): DataFrame = { + read.option("samplingRatio", samplingRatio.toString).json(json) + } + + /** + * Loads a JavaRDD[String] storing JSON objects (one object per record) inferring the + * schema, returning the result as a [[DataFrame]]. + * + * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().json()`. + */ + @deprecated("Use read.json() instead.", "1.4.0") + def jsonRDD(json: JavaRDD[String], samplingRatio: Double): DataFrame = { + read.option("samplingRatio", samplingRatio.toString).json(json) + } + + /** + * Returns the dataset stored at path as a DataFrame, + * using the default data source configured by spark.sql.sources.default. + * + * @group genericdata + * @deprecated As of 1.4.0, replaced by `read().load(path)`. + */ + @deprecated("Use read.load(path) instead.", "1.4.0") + def load(path: String): DataFrame = { + read.load(path) + } + + /** + * Returns the dataset stored at path as a DataFrame, using the given data source. + * + * @group genericdata + * @deprecated As of 1.4.0, replaced by `read().format(source).load(path)`. + */ + @deprecated("Use read.format(source).load(path) instead.", "1.4.0") + def load(path: String, source: String): DataFrame = { + read.format(source).load(path) + } + + /** + * (Java-specific) Returns the dataset specified by the given data source and + * a set of options as a DataFrame. + * + * @group genericdata + * @deprecated As of 1.4.0, replaced by `read().format(source).options(options).load()`. + */ + @deprecated("Use read.format(source).options(options).load() instead.", "1.4.0") + def load(source: String, options: java.util.Map[String, String]): DataFrame = { + read.options(options).format(source).load() + } + + /** + * (Scala-specific) Returns the dataset specified by the given data source and + * a set of options as a DataFrame. + * + * @group genericdata + * @deprecated As of 1.4.0, replaced by `read().format(source).options(options).load()`. + */ + @deprecated("Use read.format(source).options(options).load() instead.", "1.4.0") + def load(source: String, options: Map[String, String]): DataFrame = { + read.options(options).format(source).load() + } + + /** + * (Java-specific) Returns the dataset specified by the given data source and + * a set of options as a DataFrame, using the given schema as the schema of the DataFrame. + * + * @group genericdata + * @deprecated As of 1.4.0, replaced by + * `read().format(source).schema(schema).options(options).load()`. + */ + @deprecated("Use read.format(source).schema(schema).options(options).load() instead.", "1.4.0") + def load( + source: String, + schema: StructType, + options: java.util.Map[String, String]): DataFrame = { + read.format(source).schema(schema).options(options).load() + } + + /** + * (Scala-specific) Returns the dataset specified by the given data source and + * a set of options as a DataFrame, using the given schema as the schema of the DataFrame. + * + * @group genericdata + * @deprecated As of 1.4.0, replaced by + * `read().format(source).schema(schema).options(options).load()`. 
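The generic `load` overloads map one-to-one onto the reader builder chain in the same way. A sketch under the assumption that the input is a headered CSV file at a placeholder path:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType}

object LoadMigration {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("load-migration").getOrCreate()

    val schema = StructType(Seq(
      StructField("symbol", StringType),
      StructField("price", DoubleType)))
    val options = Map("header" -> "true", "path" -> "/path/to/quotes.csv")

    // load(source, schema, options)  ->  read.format(source).schema(schema).options(options).load()
    val quotes = spark.read.format("csv").schema(schema).options(options).load()
    quotes.printSchema()

    spark.stop()
  }
}
```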
+ */ + @deprecated("Use read.format(source).schema(schema).options(options).load() instead.", "1.4.0") + def load(source: String, schema: StructType, options: Map[String, String]): DataFrame = { + read.format(source).schema(schema).options(options).load() + } + + /** + * Construct a [[DataFrame]] representing the database table accessible via JDBC URL + * url named table. + * + * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().jdbc()`. + */ + @deprecated("Use read.jdbc() instead.", "1.4.0") + def jdbc(url: String, table: String): DataFrame = { + read.jdbc(url, table, new Properties) + } + + /** + * Construct a [[DataFrame]] representing the database table accessible via JDBC URL + * url named table. Partitions of the table will be retrieved in parallel based on the parameters + * passed to this function. + * + * @param columnName the name of a column of integral type that will be used for partitioning. + * @param lowerBound the minimum value of `columnName` used to decide partition stride + * @param upperBound the maximum value of `columnName` used to decide partition stride + * @param numPartitions the number of partitions. the range `minValue`-`maxValue` will be split + * evenly into this many partitions + * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().jdbc()`. + */ + @deprecated("Use read.jdbc() instead.", "1.4.0") + def jdbc( + url: String, + table: String, + columnName: String, + lowerBound: Long, + upperBound: Long, + numPartitions: Int): DataFrame = { + read.jdbc(url, table, columnName, lowerBound, upperBound, numPartitions, new Properties) + } + + /** + * Construct a [[DataFrame]] representing the database table accessible via JDBC URL + * url named table. The theParts parameter gives a list expressions + * suitable for inclusion in WHERE clauses; each one defines one partition + * of the [[DataFrame]]. + * + * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().jdbc()`. + */ + @deprecated("Use read.jdbc() instead.", "1.4.0") + def jdbc(url: String, table: String, theParts: Array[String]): DataFrame = { + read.jdbc(url, table, theParts, new Properties) + } } /** @@ -785,19 +1048,6 @@ class SQLContext private[sql]( */ object SQLContext { - /** - * The active SQLContext for the current thread. - */ - private val activeContext: InheritableThreadLocal[SQLContext] = - new InheritableThreadLocal[SQLContext] - - /** - * Reference to the created SQLContext. - */ - @transient private val instantiatedContext = new AtomicReference[SQLContext]() - - @transient private val sqlListener = new AtomicReference[SQLListener]() - /** * Get the singleton SQLContext if it exists or create a new one using the given SparkContext. 
* @@ -809,41 +1059,9 @@ object SQLContext { * * @since 1.5.0 */ + @deprecated("Use SparkSession.builder instead", "2.0.0") def getOrCreate(sparkContext: SparkContext): SQLContext = { - val ctx = activeContext.get() - if (ctx != null && !ctx.sparkContext.isStopped) { - return ctx - } - - synchronized { - val ctx = instantiatedContext.get() - if (ctx == null || ctx.sparkContext.isStopped) { - new SQLContext(sparkContext) - } else { - ctx - } - } - } - - private[sql] def clearInstantiatedContext(): Unit = { - instantiatedContext.set(null) - } - - private[sql] def setInstantiatedContext(sqlContext: SQLContext): Unit = { - synchronized { - val ctx = instantiatedContext.get() - if (ctx == null || ctx.sparkContext.isStopped) { - instantiatedContext.set(sqlContext) - } - } - } - - private[sql] def getInstantiatedContextOption(): Option[SQLContext] = { - Option(instantiatedContext.get()) - } - - private[sql] def clearSqlListener(): Unit = { - sqlListener.set(null) + SparkSession.builder().sparkContext(sparkContext).getOrCreate().sqlContext } /** @@ -853,8 +1071,9 @@ object SQLContext { * * @since 1.6.0 */ + @deprecated("Use SparkSession.setActiveSession instead", "2.0.0") def setActive(sqlContext: SQLContext): Unit = { - activeContext.set(sqlContext) + SparkSession.setActiveSession(sqlContext.sparkSession) } /** @@ -863,12 +1082,9 @@ object SQLContext { * * @since 1.6.0 */ + @deprecated("Use SparkSession.clearActiveSession instead", "2.0.0") def clearActive(): Unit = { - activeContext.remove() - } - - private[sql] def getActive(): Option[SQLContext] = { - Option(activeContext.get()) + SparkSession.clearActiveSession() } /** @@ -892,20 +1108,6 @@ object SQLContext { } } - /** - * Create a SQLListener then add it into SparkContext, and create an SQLTab if there is SparkUI. - */ - private[sql] def createListenerAndUI(sc: SparkContext): SQLListener = { - if (sqlListener.get() == null) { - val listener = new SQLListener(sc.conf) - if (sqlListener.compareAndSet(null, listener)) { - sc.addSparkListener(listener) - sc.ui.foreach(new SQLTab(listener, _)) - } - } - sqlListener.get() - } - /** * Extract `spark.sql.*` properties from the conf and return them as a [[Properties]]. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index f423e7d6b5765..440952572d8c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -24,7 +24,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder /** - * A collection of implicit methods for converting common Scala objects into [[DataFrame]]s. + * A collection of implicit methods for converting common Scala objects into [[Dataset]]s. * * @since 1.6.0 */ @@ -33,7 +33,7 @@ abstract class SQLImplicits { protected def _sqlContext: SQLContext /** - * Converts $"col name" into an [[Column]]. + * Converts $"col name" into a [[Column]]. 
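With `SQLContext.getOrCreate`, `setActive`, and `clearActive` deprecated, the session-level equivalents look roughly as follows; the wrapped `SQLContext` remains reachable for code that has not been migrated yet:

```scala
import org.apache.spark.sql.SparkSession

object SessionMigration {
  def main(args: Array[String]): Unit = {
    // SQLContext.getOrCreate(sc)  ->  SparkSession.builder().getOrCreate()
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("session-migration")
      .getOrCreate()

    // Legacy code paths can keep using the backward-compatibility SQLContext.
    val sqlContext = spark.sqlContext

    // SQLContext.setActive / clearActive  ->  SparkSession.setActiveSession / clearActiveSession
    SparkSession.setActiveSession(spark)
    SparkSession.clearActiveSession()

    println(spark.version)
    spark.stop()
  }
}
```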
* * @since 2.0.0 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 4c2a7b8ae9060..13d3e75e3d393 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -18,72 +18,84 @@ package org.apache.spark.sql import java.beans.Introspector -import java.util.Properties +import java.util.concurrent.atomic.AtomicReference -import scala.collection.immutable import scala.collection.JavaConverters._ import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext} import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging -import org.apache.spark.internal.config.{CATALOG_IMPLEMENTATION, ConfigEntry} +import org.apache.spark.internal.config.CATALOG_IMPLEMENTATION import org.apache.spark.rdd.RDD +import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd} import org.apache.spark.sql.catalog.Catalog import org.apache.spark.sql.catalyst._ -import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions.AttributeReference -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Range} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Range} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.ui.SQLListener -import org.apache.spark.sql.internal.{CatalogImpl, SessionState, SharedState, SQLConf} +import org.apache.spark.sql.internal.{CatalogImpl, SessionState, SharedState} import org.apache.spark.sql.sources.BaseRelation +import org.apache.spark.sql.streaming._ import org.apache.spark.sql.types.{DataType, LongType, StructType} import org.apache.spark.sql.util.ExecutionListenerManager import org.apache.spark.util.Utils /** - * The entry point to Spark execution. + * The entry point to programming Spark with the Dataset and DataFrame API. + * + * In environments that this has been created upfront (e.g. REPL, notebooks), use the builder + * to get an existing session: + * + * {{{ + * SparkSession.builder().getOrCreate() + * }}} + * + * The builder can also be used to create a new session: + * + * {{{ + * SparkSession.builder() + * .master("local") + * .appName("Word Count") + * .config("spark.some.config.option", "some-value") + * .getOrCreate() + * }}} */ class SparkSession private( @transient val sparkContext: SparkContext, @transient private val existingSharedState: Option[SharedState]) extends Serializable with Logging { self => - def this(sc: SparkContext) { + private[sql] def this(sc: SparkContext) { this(sc, None) } + sparkContext.assertNotStopped() + + /** + * The version of Spark on which this application is running. 
+ * + * @since 2.0.0 + */ + def version: String = SPARK_VERSION /* ----------------------- * | Session-related state | * ----------------------- */ - { - val defaultWarehousePath = - SQLConf.WAREHOUSE_PATH - .defaultValueString - .replace("${system:user.dir}", System.getProperty("user.dir")) - val warehousePath = sparkContext.conf.get( - SQLConf.WAREHOUSE_PATH.key, - defaultWarehousePath) - sparkContext.conf.set(SQLConf.WAREHOUSE_PATH.key, warehousePath) - sparkContext.conf.set("hive.metastore.warehouse.dir", warehousePath) - logInfo(s"Setting warehouse location to $warehousePath") - } - /** * State shared across sessions, including the [[SparkContext]], cached data, listener, * and a catalog that interacts with external systems. */ @transient - protected[sql] lazy val sharedState: SharedState = { + private[sql] lazy val sharedState: SharedState = { existingSharedState.getOrElse( SparkSession.reflect[SharedState, SparkContext]( SparkSession.sharedStateClassName(sparkContext.conf), @@ -95,32 +107,19 @@ class SparkSession private( * functions, and everything else that accepts a [[org.apache.spark.sql.internal.SQLConf]]. */ @transient - protected[sql] lazy val sessionState: SessionState = { + private[sql] lazy val sessionState: SessionState = { SparkSession.reflect[SessionState, SparkSession]( SparkSession.sessionStateClassName(sparkContext.conf), self) } /** - * A wrapped version of this session in the form of a [[SQLContext]]. + * A wrapped version of this session in the form of a [[SQLContext]], for backward compatibility. + * + * @since 2.0.0 */ @transient - private var _wrapped: SQLContext = _ - - protected[sql] def wrapped: SQLContext = { - if (_wrapped == null) { - _wrapped = new SQLContext(self, isRootContext = false) - } - _wrapped - } - - protected[sql] def setWrappedContext(sqlContext: SQLContext): Unit = { - _wrapped = sqlContext - } - - protected[sql] def cacheManager: CacheManager = sharedState.cacheManager - protected[sql] def listener: SQLListener = sharedState.listener - protected[sql] def externalCatalog: ExternalCatalog = sharedState.externalCatalog + val sqlContext: SQLContext = new SQLContext(this) /** * Runtime configuration interface for Spark. @@ -129,7 +128,6 @@ class SparkSession private( * configurations that are relevant to Spark SQL. When getting the value of a config, * this defaults to the value set in the underlying [[SparkContext]], if any. * - * @group config * @since 2.0.0 */ @transient lazy val conf: RuntimeConfig = new RuntimeConfig(sessionState.conf) @@ -139,7 +137,6 @@ class SparkSession private( * An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s * that listen for execution metrics. * - * @group basic * @since 2.0.0 */ @Experimental @@ -150,7 +147,6 @@ class SparkSession private( * A collection of methods that are considered experimental, but can be used to hook into * the query planner for advanced functionality. * - * @group basic * @since 2.0.0 */ @Experimental @@ -158,6 +154,9 @@ class SparkSession private( /** * A collection of methods for registering user-defined functions (UDF). + * Note that the user-defined functions must be deterministic. Due to optimization, + * duplicate invocations may be eliminated or the function may even be invoked more times than + * it is present in the query. 
* * The following example registers a Scala closure as UDF: * {{{ @@ -182,19 +181,19 @@ class SparkSession private( * DataTypes.StringType); * }}} * - * @group basic * @since 2.0.0 */ def udf: UDFRegistration = sessionState.udf /** - * Returns a [[ContinuousQueryManager]] that allows managing all the - * [[org.apache.spark.sql.ContinuousQuery ContinuousQueries]] active on `this`. + * :: Experimental :: + * Returns a [[StreamingQueryManager]] that allows managing all the + * [[StreamingQuery StreamingQueries]] active on `this`. * - * @group basic * @since 2.0.0 */ - def streams: ContinuousQueryManager = sessionState.continuousQueryManager + @Experimental + def streams: StreamingQueryManager = sessionState.streamingQueryManager /** * Start a new session with isolated SQL configurations, temporary tables, registered @@ -205,7 +204,6 @@ class SparkSession private( * and child sessions are set up with the same shared state. If the underlying catalog * implementation is Hive, this will initialize the metastore, which may take some time. * - * @group basic * @since 2.0.0 */ def newSession(): SparkSession = { @@ -218,28 +216,36 @@ class SparkSession private( * --------------------------------- */ /** - * :: Experimental :: * Returns a [[DataFrame]] with no rows or columns. * - * @group dataframes * @since 2.0.0 */ - @Experimental @transient lazy val emptyDataFrame: DataFrame = { createDataFrame(sparkContext.emptyRDD[Row], StructType(Nil)) } + /** + * :: Experimental :: + * Creates a new [[Dataset]] of type T containing zero elements. + * + * @return 2.0.0 + */ + @Experimental + def emptyDataset[T: Encoder]: Dataset[T] = { + val encoder = implicitly[Encoder[T]] + new Dataset(self, LocalRelation(encoder.schema.toAttributes), encoder) + } + /** * :: Experimental :: * Creates a [[DataFrame]] from an RDD of Product (e.g. case classes, tuples). * - * @group dataframes * @since 2.0.0 */ @Experimental def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = { - SQLContext.setActive(wrapped) + SparkSession.setActiveSession(this) val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] val attributeSeq = schema.toAttributes val rowRDD = RDDConversions.productToRowRdd(rdd, schema.map(_.dataType)) @@ -250,12 +256,11 @@ class SparkSession private( * :: Experimental :: * Creates a [[DataFrame]] from a local Seq of Product. * - * @group dataframes * @since 2.0.0 */ @Experimental def createDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = { - SQLContext.setActive(wrapped) + SparkSession.setActiveSession(this) val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] val attributeSeq = schema.toAttributes Dataset.ofRows(self, LocalRelation.fromProduct(attributeSeq, data)) @@ -286,11 +291,10 @@ class SparkSession private( * // |-- name: string (nullable = false) * // |-- age: integer (nullable = true) * - * dataFrame.registerTempTable("people") + * dataFrame.createOrReplaceTempView("people") * sparkSession.sql("select name from people").collect.foreach(println) * }}} * - * @group dataframes * @since 2.0.0 */ @DeveloperApi @@ -300,11 +304,10 @@ class SparkSession private( /** * :: DeveloperApi :: - * Creates a [[DataFrame]] from an [[JavaRDD]] containing [[Row]]s using the given schema. + * Creates a [[DataFrame]] from a [[JavaRDD]] containing [[Row]]s using the given schema. * It is important to make sure that the structure of every [[Row]] of the provided RDD matches * the provided schema. Otherwise, there will be runtime exception. 
* - * @group dataframes * @since 2.0.0 */ @DeveloperApi @@ -314,11 +317,10 @@ class SparkSession private( /** * :: DeveloperApi :: - * Creates a [[DataFrame]] from an [[java.util.List]] containing [[Row]]s using the given schema. + * Creates a [[DataFrame]] from a [[java.util.List]] containing [[Row]]s using the given schema. * It is important to make sure that the structure of every [[Row]] of the provided List matches * the provided schema. Otherwise, there will be runtime exception. * - * @group dataframes * @since 2.0.0 */ @DeveloperApi @@ -332,7 +334,6 @@ class SparkSession private( * WARNING: Since there is no guaranteed ordering for fields in a Java Bean, * SELECT * queries will return the columns in an undefined order. * - * @group dataframes * @since 2.0.0 */ def createDataFrame(rdd: RDD[_], beanClass: Class[_]): DataFrame = { @@ -352,7 +353,6 @@ class SparkSession private( * WARNING: Since there is no guaranteed ordering for fields in a Java Bean, * SELECT * queries will return the columns in an undefined order. * - * @group dataframes * @since 2.0.0 */ def createDataFrame(rdd: JavaRDD[_], beanClass: Class[_]): DataFrame = { @@ -360,11 +360,10 @@ class SparkSession private( } /** - * Applies a schema to an List of Java Beans. + * Applies a schema to a List of Java Beans. * * WARNING: Since there is no guaranteed ordering for fields in a Java Bean, * SELECT * queries will return the columns in an undefined order. - * @group dataframes * @since 1.6.0 */ def createDataFrame(data: java.util.List[_], beanClass: Class[_]): DataFrame = { @@ -377,13 +376,45 @@ class SparkSession private( /** * Convert a [[BaseRelation]] created for external data sources into a [[DataFrame]]. * - * @group dataframes * @since 2.0.0 */ def baseRelationToDataFrame(baseRelation: BaseRelation): DataFrame = { Dataset.ofRows(self, LogicalRelation(baseRelation)) } + /* ------------------------------- * + | Methods for creating DataSets | + * ------------------------------- */ + + /** + * :: Experimental :: + * Creates a [[Dataset]] from a local Seq of data of a given type. This method requires an + * encoder (to convert a JVM object of type `T` to and from the internal Spark SQL representation) + * that is generally created automatically through implicits from a `SparkSession`, or can be + * created explicitly by calling static methods on [[Encoders]]. + * + * == Example == + * + * {{{ + * + * import spark.implicits._ + * case class Person(name: String, age: Long) + * val data = Seq(Person("Michael", 29), Person("Andy", 30), Person("Justin", 19)) + * val ds = spark.createDataset(data) + * + * ds.show() + * // +-------+---+ + * // | name|age| + * // +-------+---+ + * // |Michael| 29| + * // | Andy| 30| + * // | Justin| 19| + * // +-------+---+ + * }}} + * + * @since 2.0.0 + */ + @Experimental def createDataset[T : Encoder](data: Seq[T]): Dataset[T] = { val enc = encoderFor[T] val attributes = enc.schema.toAttributes @@ -392,6 +423,16 @@ class SparkSession private( Dataset[T](self, plan) } + /** + * :: Experimental :: + * Creates a [[Dataset]] from an RDD of a given type. This method requires an + * encoder (to convert a JVM object of type `T` to and from the internal Spark SQL representation) + * that is generally created automatically through implicits from a `SparkSession`, or can be + * created explicitly by calling static methods on [[Encoders]]. 
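To complement the Seq example in the Scaladoc above, a sketch that also exercises the RDD overload of `createDataset`; the `Person` case class and sample data are made up for illustration:

```scala
import org.apache.spark.sql.SparkSession

object CreateDatasetExample {
  case class Person(name: String, age: Long)

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("create-dataset").getOrCreate()
    import spark.implicits._   // supplies the implicit Encoder[Person]

    val data = Seq(Person("Michael", 29), Person("Andy", 30), Person("Justin", 19))

    // From a local Seq ...
    val ds = spark.createDataset(data)

    // ... or from an RDD, resolved against the same implicit encoder.
    val rddDS = spark.createDataset(spark.sparkContext.parallelize(data))

    ds.show()
    rddDS.filter(_.age > 21).show()
    spark.stop()
  }
}
```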
+ * + * @since 2.0.0 + */ + @Experimental def createDataset[T : Encoder](data: RDD[T]): Dataset[T] = { val enc = encoderFor[T] val attributes = enc.schema.toAttributes @@ -400,6 +441,23 @@ class SparkSession private( Dataset[T](self, plan) } + /** + * :: Experimental :: + * Creates a [[Dataset]] from a [[java.util.List]] of a given type. This method requires an + * encoder (to convert a JVM object of type `T` to and from the internal Spark SQL representation) + * that is generally created automatically through implicits from a `SparkSession`, or can be + * created explicitly by calling static methods on [[Encoders]]. + * + * == Java Example == + * + * {{{ + * List data = Arrays.asList("hello", "world"); + * Dataset ds = spark.createDataset(data, Encoders.STRING()); + * }}} + * + * @since 2.0.0 + */ + @Experimental def createDataset[T : Encoder](data: java.util.List[T]): Dataset[T] = { createDataset(data.asScala) } @@ -407,10 +465,9 @@ class SparkSession private( /** * :: Experimental :: * Creates a [[Dataset]] with a single [[LongType]] column named `id`, containing elements - * in an range from 0 to `end` (exclusive) with step value 1. + * in a range from 0 to `end` (exclusive) with step value 1. * * @since 2.0.0 - * @group dataset */ @Experimental def range(end: Long): Dataset[java.lang.Long] = range(0, end) @@ -418,10 +475,9 @@ class SparkSession private( /** * :: Experimental :: * Creates a [[Dataset]] with a single [[LongType]] column named `id`, containing elements - * in an range from `start` to `end` (exclusive) with step value 1. + * in a range from `start` to `end` (exclusive) with step value 1. * * @since 2.0.0 - * @group dataset */ @Experimental def range(start: Long, end: Long): Dataset[java.lang.Long] = { @@ -431,10 +487,9 @@ class SparkSession private( /** * :: Experimental :: * Creates a [[Dataset]] with a single [[LongType]] column named `id`, containing elements - * in an range from `start` to `end` (exclusive) with an step value. + * in a range from `start` to `end` (exclusive) with a step value. * * @since 2.0.0 - * @group dataset */ @Experimental def range(start: Long, end: Long, step: Long): Dataset[java.lang.Long] = { @@ -444,11 +499,10 @@ class SparkSession private( /** * :: Experimental :: * Creates a [[Dataset]] with a single [[LongType]] column named `id`, containing elements - * in an range from `start` to `end` (exclusive) with an step value, with partition number + * in a range from `start` to `end` (exclusive) with a step value, with partition number * specified. * * @since 2.0.0 - * @group dataset */ @Experimental def range(start: Long, end: Long, step: Long, numPartitions: Int): Dataset[java.lang.Long] = { @@ -459,7 +513,7 @@ class SparkSession private( * Creates a [[DataFrame]] from an RDD[Row]. * User can specify whether the input rows should be converted to Catalyst rows. */ - protected[sql] def internalCreateDataFrame( + private[sql] def internalCreateDataFrame( catalystRows: RDD[InternalRow], schema: StructType): DataFrame = { // TODO: use MutableProjection when rowRDD is another DataFrame and the applied @@ -472,15 +526,15 @@ class SparkSession private( * Creates a [[DataFrame]] from an RDD[Row]. * User can specify whether the input rows should be converted to Catalyst rows. 
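A short sketch of the `range` variants documented above; the partition count is only illustrative:

```scala
import org.apache.spark.sql.SparkSession

object RangeExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("range-example").getOrCreate()

    // Single LongType column named `id`: 0, 1, ..., 999
    val simple = spark.range(1000)

    // start / end / step / numPartitions: 0, 5, 10, ..., 995 spread over 4 partitions
    val stepped = spark.range(0, 1000, 5, 4)

    println(simple.count())                // 1000
    println(stepped.rdd.getNumPartitions)  // 4

    spark.stop()
  }
}
```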
*/ - protected[sql] def createDataFrame( + private[sql] def createDataFrame( rowRDD: RDD[Row], schema: StructType, needsConversion: Boolean) = { // TODO: use MutableProjection when rowRDD is another DataFrame and the applied // schema differs from the existing schema on any field data type. val catalystRows = if (needsConversion) { - val converter = CatalystTypeConverters.createToCatalystConverter(schema) - rowRDD.map(converter(_).asInstanceOf[InternalRow]) + val encoder = RowEncoder(schema) + rowRDD.map(encoder.toRow) } else { rowRDD.map{r: Row => InternalRow.fromSeq(r.toSeq)} } @@ -489,15 +543,14 @@ class SparkSession private( } - /* ------------------------ * - | Catalog-related methods | - * ----------------- ------ */ + /* ------------------------- * + | Catalog-related methods | + * ------------------------- */ /** * Interface through which the user may create, drop, alter or query underlying * databases, tables, functions etc. * - * @group ddl_ops * @since 2.0.0 */ @transient lazy val catalog: Catalog = new CatalogImpl(self) @@ -505,29 +558,16 @@ class SparkSession private( /** * Returns the specified table as a [[DataFrame]]. * - * @group ddl_ops * @since 2.0.0 */ def table(tableName: String): DataFrame = { table(sessionState.sqlParser.parseTableIdentifier(tableName)) } - protected[sql] def table(tableIdent: TableIdentifier): DataFrame = { + private[sql] def table(tableIdent: TableIdentifier): DataFrame = { Dataset.ofRows(self, sessionState.catalog.lookupRelation(tableIdent)) } - /** - * Registers the given [[DataFrame]] as a temporary table in the catalog. - * Temporary tables exist only during the lifetime of this instance of [[SparkSession]]. - */ - protected[sql] def registerTable(df: DataFrame, tableName: String): Unit = { - sessionState.catalog.createTempTable( - sessionState.sqlParser.parseTableIdentifier(tableName).table, - df.logicalPlan, - overrideIfExists = true) - } - - /* ----------------- * | Everything else | * ----------------- */ @@ -536,27 +576,37 @@ class SparkSession private( * Executes a SQL query using Spark, returning the result as a [[DataFrame]]. * The dialect that is used for SQL parsing can be configured with 'spark.sql.dialect'. * - * @group basic * @since 2.0.0 */ def sql(sqlText: String): DataFrame = { - Dataset.ofRows(self, parseSql(sqlText)) + Dataset.ofRows(self, sessionState.sqlParser.parsePlan(sqlText)) } /** - * :: Experimental :: - * Returns a [[DataFrameReader]] that can be used to read data and streams in as a [[DataFrame]]. + * Returns a [[DataFrameReader]] that can be used to read non-streaming data in as a + * [[DataFrame]]. * {{{ * sparkSession.read.parquet("/path/to/file.parquet") * sparkSession.read.schema(schema).json("/path/to/file.json") * }}} * - * @group genericdata * @since 2.0.0 */ - @Experimental def read: DataFrameReader = new DataFrameReader(self) + /** + * :: Experimental :: + * Returns a [[DataStreamReader]] that can be used to read streaming data in as a [[DataFrame]]. + * {{{ + * sparkSession.readStream.parquet("/path/to/directory/of/parquet/files") + * sparkSession.readStream.schema(schema).json("/path/to/directory/of/json/files") + * }}} + * + * @since 2.0.0 + */ + @Experimental + def readStream: DataStreamReader = new DataStreamReader(self) + // scalastyle:off // Disable style checker so "implicits" object can start with lowercase i @@ -566,29 +616,25 @@ class SparkSession private( * common Scala objects into [[DataFrame]]s. 
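Putting the `sql`, temp-view, and `table` entry points together, a short end-to-end sketch; the view name and rows are arbitrary, and `createOrReplaceTempView` is used as in the updated Scaladoc example above:

```scala
import org.apache.spark.sql.SparkSession

object SqlEntryPoints {
  case class Person(name: String, age: Long)

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("sql-entry-points").getOrCreate()
    import spark.implicits._

    val people = Seq(Person("Michael", 29), Person("Andy", 30)).toDS()
    people.createOrReplaceTempView("people")

    // Parsed by sessionState.sqlParser and executed as a DataFrame.
    spark.sql("SELECT name FROM people WHERE age > 29").show()

    // Looks the view up through the session catalog.
    spark.table("people").printSchema()

    spark.stop()
  }
}
```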
* * {{{ - * val sparkSession = new SparkSession(sc) + * val sparkSession = SparkSession.builder.getOrCreate() * import sparkSession.implicits._ * }}} * - * @group basic * @since 2.0.0 */ @Experimental object implicits extends SQLImplicits with Serializable { - protected override def _sqlContext: SQLContext = wrapped + protected override def _sqlContext: SQLContext = SparkSession.this.sqlContext } // scalastyle:on - protected[sql] def parseSql(sql: String): LogicalPlan = { - sessionState.sqlParser.parsePlan(sql) - } - - protected[sql] def executeSql(sql: String): QueryExecution = { - executePlan(parseSql(sql)) - } - - protected[sql] def executePlan(plan: LogicalPlan): QueryExecution = { - sessionState.executePlan(plan) + /** + * Stop the underlying [[SparkContext]]. + * + * @since 2.0.0 + */ + def stop(): Unit = { + sparkContext.stop() } /** @@ -603,17 +649,17 @@ class SparkSession private( /** * Apply a schema defined by the schemaString to an RDD. It is only used by PySpark. */ - protected[sql] def applySchemaToPythonRDD( + private[sql] def applySchemaToPythonRDD( rdd: RDD[Array[Any]], schemaString: String): DataFrame = { - val schema = parseDataType(schemaString).asInstanceOf[StructType] + val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] applySchemaToPythonRDD(rdd, schema) } /** * Apply a schema defined by the schema to an RDD. It is only used by PySpark. */ - protected[sql] def applySchemaToPythonRDD( + private[sql] def applySchemaToPythonRDD( rdd: RDD[Array[Any]], schema: StructType): DataFrame = { val rowRdd = rdd.map(r => python.EvaluatePython.fromJava(r, schema).asInstanceOf[InternalRow]) @@ -635,6 +681,244 @@ class SparkSession private( object SparkSession { + /** + * Builder for [[SparkSession]]. + */ + class Builder extends Logging { + + private[this] val options = new scala.collection.mutable.HashMap[String, String] + + private[this] var userSuppliedContext: Option[SparkContext] = None + + private[spark] def sparkContext(sparkContext: SparkContext): Builder = synchronized { + userSuppliedContext = Option(sparkContext) + this + } + + /** + * Sets a name for the application, which will be shown in the Spark web UI. + * If no application name is set, a randomly generated name will be used. + * + * @since 2.0.0 + */ + def appName(name: String): Builder = config("spark.app.name", name) + + /** + * Sets a config option. Options set using this method are automatically propagated to + * both [[SparkConf]] and SparkSession's own configuration. + * + * @since 2.0.0 + */ + def config(key: String, value: String): Builder = synchronized { + options += key -> value + this + } + + /** + * Sets a config option. Options set using this method are automatically propagated to + * both [[SparkConf]] and SparkSession's own configuration. + * + * @since 2.0.0 + */ + def config(key: String, value: Long): Builder = synchronized { + options += key -> value.toString + this + } + + /** + * Sets a config option. Options set using this method are automatically propagated to + * both [[SparkConf]] and SparkSession's own configuration. + * + * @since 2.0.0 + */ + def config(key: String, value: Double): Builder = synchronized { + options += key -> value.toString + this + } + + /** + * Sets a config option. Options set using this method are automatically propagated to + * both [[SparkConf]] and SparkSession's own configuration. 
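A sketch showing how a pre-built `SparkConf` and the typed `config` overloads feed the same builder; the option keys are examples only:

```scala
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

object BuilderConfigExample {
  def main(args: Array[String]): Unit = {
    // Options can come from an existing SparkConf ...
    val conf = new SparkConf().set("spark.ui.enabled", "false")

    // ... and/or from the typed config() overloads; either way they are propagated
    // to both the SparkConf and the session's own configuration.
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("builder-config")
      .config(conf)
      .config("spark.sql.shuffle.partitions", 4L)    // Long overload
      .config("spark.sql.crossJoin.enabled", true)   // Boolean overload
      .getOrCreate()

    println(spark.conf.get("spark.sql.shuffle.partitions"))
    spark.stop()
  }
}
```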
+ * + * @since 2.0.0 + */ + def config(key: String, value: Boolean): Builder = synchronized { + options += key -> value.toString + this + } + + /** + * Sets a list of config options based on the given [[SparkConf]]. + * + * @since 2.0.0 + */ + def config(conf: SparkConf): Builder = synchronized { + conf.getAll.foreach { case (k, v) => options += k -> v } + this + } + + /** + * Sets the Spark master URL to connect to, such as "local" to run locally, "local[4]" to + * run locally with 4 cores, or "spark://master:7077" to run on a Spark standalone cluster. + * + * @since 2.0.0 + */ + def master(master: String): Builder = config("spark.master", master) + + /** + * Enables Hive support, including connectivity to a persistent Hive metastore, support for + * Hive serdes, and Hive user-defined functions. + * + * @since 2.0.0 + */ + def enableHiveSupport(): Builder = synchronized { + if (hiveClassesArePresent) { + config(CATALOG_IMPLEMENTATION.key, "hive") + } else { + throw new IllegalArgumentException( + "Unable to instantiate SparkSession with Hive support because " + + "Hive classes are not found.") + } + } + + /** + * Gets an existing [[SparkSession]] or, if there is no existing one, creates a new + * one based on the options set in this builder. + * + * This method first checks whether there is a valid thread-local SparkSession, + * and if yes, return that one. It then checks whether there is a valid global + * default SparkSession, and if yes, return that one. If no valid global default + * SparkSession exists, the method creates a new SparkSession and assigns the + * newly created SparkSession as the global default. + * + * In case an existing SparkSession is returned, the config options specified in + * this builder will be applied to the existing SparkSession. + * + * @since 2.0.0 + */ + def getOrCreate(): SparkSession = synchronized { + // Get the session from current thread's active session. + var session = activeThreadSession.get() + if ((session ne null) && !session.sparkContext.isStopped) { + options.foreach { case (k, v) => session.conf.set(k, v) } + if (options.nonEmpty) { + logWarning("Using an existing SparkSession; some configuration may not take effect.") + } + return session + } + + // Global synchronization so we will only set the default session once. + SparkSession.synchronized { + // If the current thread does not have an active session, get it from the global session. + session = defaultSession.get() + if ((session ne null) && !session.sparkContext.isStopped) { + options.foreach { case (k, v) => session.conf.set(k, v) } + if (options.nonEmpty) { + logWarning("Using an existing SparkSession; some configuration may not take effect.") + } + return session + } + + // No active nor global default session. Create a new one. 
+ val sparkContext = userSuppliedContext.getOrElse { + // set app name if not given + val randomAppName = java.util.UUID.randomUUID().toString + val sparkConf = new SparkConf() + options.foreach { case (k, v) => sparkConf.set(k, v) } + if (!sparkConf.contains("spark.app.name")) { + sparkConf.setAppName(randomAppName) + } + val sc = SparkContext.getOrCreate(sparkConf) + // maybe this is an existing SparkContext, update its SparkConf which maybe used + // by SparkSession + options.foreach { case (k, v) => sc.conf.set(k, v) } + if (!sc.conf.contains("spark.app.name")) { + sc.conf.setAppName(randomAppName) + } + sc + } + session = new SparkSession(sparkContext) + options.foreach { case (k, v) => session.conf.set(k, v) } + defaultSession.set(session) + + // Register a successfully instantiated context to the singleton. This should be at the + // end of the class definition so that the singleton is updated only if there is no + // exception in the construction of the instance. + sparkContext.addSparkListener(new SparkListener { + override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = { + defaultSession.set(null) + sqlListener.set(null) + } + }) + } + + return session + } + } + + /** + * Creates a [[SparkSession.Builder]] for constructing a [[SparkSession]]. + * + * @since 2.0.0 + */ + def builder(): Builder = new Builder + + /** + * Changes the SparkSession that will be returned in this thread and its children when + * SparkSession.getOrCreate() is called. This can be used to ensure that a given thread receives + * a SparkSession with an isolated session, instead of the global (first created) context. + * + * @since 2.0.0 + */ + def setActiveSession(session: SparkSession): Unit = { + activeThreadSession.set(session) + } + + /** + * Clears the active SparkSession for current thread. Subsequent calls to getOrCreate will + * return the first created context instead of a thread-local override. + * + * @since 2.0.0 + */ + def clearActiveSession(): Unit = { + activeThreadSession.remove() + } + + /** + * Sets the default SparkSession that is returned by the builder. + * + * @since 2.0.0 + */ + def setDefaultSession(session: SparkSession): Unit = { + defaultSession.set(session) + } + + /** + * Clears the default SparkSession that is returned by the builder. + * + * @since 2.0.0 + */ + def clearDefaultSession(): Unit = { + defaultSession.set(null) + } + + private[sql] def getActiveSession: Option[SparkSession] = Option(activeThreadSession.get) + + private[sql] def getDefaultSession: Option[SparkSession] = Option(defaultSession.get) + + /** A global SQL listener used for the SQL UI. */ + private[sql] val sqlListener = new AtomicReference[SQLListener]() + + //////////////////////////////////////////////////////////////////////////////////////// + // Private methods from now on + //////////////////////////////////////////////////////////////////////////////////////// + + /** The active SparkSession for the current thread. */ + private val activeThreadSession = new InheritableThreadLocal[SparkSession] + + /** Reference to the root SparkSession. */ + private val defaultSession = new AtomicReference[SparkSession] + private val HIVE_SHARED_STATE_CLASS_NAME = "org.apache.spark.sql.hive.HiveSharedState" private val HIVE_SESSION_STATE_CLASS_NAME = "org.apache.spark.sql.hive.HiveSessionState" @@ -683,17 +967,4 @@ object SparkSession { } } - /** - * Create a new [[SparkSession]] with a catalog backed by Hive. 
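A sketch of how the thread-local/default session pair can be combined with `newSession()` to give a worker thread its own isolated temp views and SQL conf; the thread body is illustrative only:

```scala
import org.apache.spark.sql.SparkSession

object ActiveSessionExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("active-session").getOrCreate()

    val worker = new Thread(new Runnable {
      override def run(): Unit = {
        // Isolated temp views and conf, but the same SparkContext and shared state.
        val isolated = spark.newSession()
        SparkSession.setActiveSession(isolated)
        try {
          isolated.range(10).createOrReplaceTempView("nums")
          // With an active session set, getOrCreate() on this thread returns it
          // instead of the global default session.
          assert(SparkSession.builder().getOrCreate() eq isolated)
        } finally {
          SparkSession.clearActiveSession()
        }
      }
    })
    worker.start()
    worker.join()

    spark.stop()
  }
}
```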
- */ - def withHiveSupport(sc: SparkContext): SparkSession = { - if (hiveClassesArePresent) { - sc.conf.set(CATALOG_IMPLEMENTATION.key, "hive") - new SparkSession(sc) - } else { - throw new IllegalArgumentException( - "Unable to instantiate SparkSession with Hive support because Hive classes are not found.") - } - } - } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index 3a043dcc6af22..b006236481a29 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.types.DataType /** * Functions for registering user-defined functions. Use [[SQLContext.udf]] to access this. + * Note that the user-defined functions must be deterministic. * * @since 1.3.0 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index 36173a49250b5..7d8ea03a27910 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -18,28 +18,65 @@ package org.apache.spark.sql.api.r import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} +import java.util.{Map => JMap} +import scala.collection.JavaConverters._ import scala.util.matching.Regex +import org.apache.spark.internal.Logging +import org.apache.spark.SparkContext import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.api.r.SerDe import org.apache.spark.broadcast.Broadcast +import org.apache.spark.internal.config.CATALOG_IMPLEMENTATION import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row, SaveMode, SQLContext} -import org.apache.spark.sql.catalyst.encoders.RowEncoder +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema -import org.apache.spark.sql.Encoder +import org.apache.spark.sql.execution.command.ShowTablesCommand import org.apache.spark.sql.types._ -private[sql] object SQLUtils { +private[sql] object SQLUtils extends Logging { SerDe.registerSqlSerDe((readSqlObject, writeSqlObject)) - def createSQLContext(jsc: JavaSparkContext): SQLContext = { - SQLContext.getOrCreate(jsc.sc) + private[this] def withHiveExternalCatalog(sc: SparkContext): SparkContext = { + sc.conf.set(CATALOG_IMPLEMENTATION.key, "hive") + sc } - def getJavaSparkContext(sqlCtx: SQLContext): JavaSparkContext = { - new JavaSparkContext(sqlCtx.sparkContext) + def getOrCreateSparkSession( + jsc: JavaSparkContext, + sparkConfigMap: JMap[Object, Object], + enableHiveSupport: Boolean): SparkSession = { + val spark = if (SparkSession.hiveClassesArePresent && enableHiveSupport) { + SparkSession.builder().sparkContext(withHiveExternalCatalog(jsc.sc)).getOrCreate() + } else { + if (enableHiveSupport) { + logWarning("SparkR: enableHiveSupport is requested for SparkSession but " + + "Spark is not built with Hive; falling back to without Hive support.") + } + SparkSession.builder().sparkContext(jsc.sc).getOrCreate() + } + setSparkContextSessionConf(spark, sparkConfigMap) + spark + } + + def setSparkContextSessionConf( + spark: SparkSession, + sparkConfigMap: JMap[Object, Object]): Unit = { + for ((name, value) <- sparkConfigMap.asScala) { + spark.conf.set(name.toString, value.toString) + } + for ((name, value) <- sparkConfigMap.asScala) { + 
spark.sparkContext.conf.set(name.toString, value.toString) + } + } + + def getSessionConf(spark: SparkSession): JMap[String, String] = { + spark.conf.getAll.asJava + } + + def getJavaSparkContext(spark: SparkSession): JavaSparkContext = { + new JavaSparkContext(spark.sparkContext) } def createStructType(fields : Seq[StructField]): StructType = { @@ -75,7 +112,7 @@ private[sql] object SQLUtils { org.apache.spark.sql.types.MapType(getSQLDataType(keyType), getSQLDataType(valueType)) case r"\Astruct<(.+)${fieldsStr}>\Z" => if (fieldsStr(fieldsStr.length - 1) == ',') { - throw new IllegalArgumentException(s"Invaid type $dataType") + throw new IllegalArgumentException(s"Invalid type $dataType") } val fields = fieldsStr.split(",") val structFields = fields.map { field => @@ -83,11 +120,11 @@ private[sql] object SQLUtils { case r"\A(.+)${fieldName}:(.+)${fieldType}\Z" => createStructField(fieldName, fieldType, true) - case _ => throw new IllegalArgumentException(s"Invaid type $dataType") + case _ => throw new IllegalArgumentException(s"Invalid type $dataType") } } createStructType(structFields) - case _ => throw new IllegalArgumentException(s"Invaid type $dataType") + case _ => throw new IllegalArgumentException(s"Invalid type $dataType") } } @@ -96,10 +133,10 @@ private[sql] object SQLUtils { StructField(name, dtObj, nullable) } - def createDF(rdd: RDD[Array[Byte]], schema: StructType, sqlContext: SQLContext): DataFrame = { + def createDF(rdd: RDD[Array[Byte]], schema: StructType, sparkSession: SparkSession): DataFrame = { val num = schema.fields.length val rowRDD = rdd.map(bytesToRow(_, schema)) - sqlContext.createDataFrame(rowRDD, schema) + sparkSession.createDataFrame(rowRDD, schema) } def dfToRowRDD(df: DataFrame): JavaRDD[Array[Byte]] = { @@ -110,6 +147,8 @@ private[sql] object SQLUtils { data match { case d: java.lang.Double if dataType == FloatType => new java.lang.Float(d) + // Scala Map is the only allowed external type of map type in Row. + case m: java.util.Map[_, _] => m.asScala case _ => data } } @@ -120,7 +159,7 @@ private[sql] object SQLUtils { val num = SerDe.readInt(dis) Row.fromSeq((0 until num).map { i => doConversion(SerDe.readObject(dis), schema.fields(i).dataType) - }.toSeq) + }) } private[sql] def rowToRBytes(row: Row): Array[Byte] = { @@ -145,16 +184,26 @@ private[sql] object SQLUtils { packageNames: Array[Byte], broadcastVars: Array[Object], schema: StructType): DataFrame = { - val bv = broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]]) - val realSchema = - if (schema == null) { - SERIALIZED_R_DATA_SCHEMA - } else { - schema - } + val bv = broadcastVars.map(_.asInstanceOf[Broadcast[Object]]) + val realSchema = if (schema == null) SERIALIZED_R_DATA_SCHEMA else schema df.mapPartitionsInR(func, packageNames, bv, realSchema) } + /** + * The helper function for gapply() on R side. 
+ */ + def gapply( + gd: RelationalGroupedDataset, + func: Array[Byte], + packageNames: Array[Byte], + broadcastVars: Array[Object], + schema: StructType): DataFrame = { + val bv = broadcastVars.map(_.asInstanceOf[Broadcast[Object]]) + val realSchema = if (schema == null) SERIALIZED_R_DATA_SCHEMA else schema + gd.flatMapGroupsInR(func, packageNames, bv, realSchema) + } + + def dfToCols(df: DataFrame): Array[Array[Any]] = { val localDF: Array[Row] = df.collect() val numCols = df.columns.length @@ -180,18 +229,18 @@ private[sql] object SQLUtils { } def loadDF( - sqlContext: SQLContext, + sparkSession: SparkSession, source: String, options: java.util.Map[String, String]): DataFrame = { - sqlContext.read.format(source).options(options).load() + sparkSession.read.format(source).options(options).load() } def loadDF( - sqlContext: SQLContext, + sparkSession: SparkSession, source: String, schema: StructType, options: java.util.Map[String, String]): DataFrame = { - sqlContext.read.format(source).schema(schema).options(options).load() + sparkSession.read.format(source).schema(schema).options(options).load() } def readSqlObject(dis: DataInputStream, dataType: Char): Object = { @@ -216,4 +265,22 @@ private[sql] object SQLUtils { false } } + + def getTables(sparkSession: SparkSession, databaseName: String): DataFrame = { + databaseName match { + case n: String if n != null && n.trim.nonEmpty => + Dataset.ofRows(sparkSession, ShowTablesCommand(Some(n), None)) + case _ => + Dataset.ofRows(sparkSession, ShowTablesCommand(None, None)) + } + } + + def getTableNames(sparkSession: SparkSession, databaseName: String): Array[String] = { + databaseName match { + case n: String if n != null && n.trim.nonEmpty => + sparkSession.catalog.listTables(n).collect().map(_.name) + case _ => + sparkSession.catalog.listTables().collect().map(_.name) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala index 7a815c1f99336..1aed245fdd332 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala @@ -24,6 +24,8 @@ import org.apache.spark.sql.types.StructType /** * Catalog interface for Spark. To access this, use `SparkSession.catalog`. + * + * @since 2.0.0 */ abstract class Catalog { @@ -83,7 +85,8 @@ abstract class Catalog { def listFunctions(dbName: String): Dataset[Function] /** - * Returns a list of columns for the given table in the current database. + * Returns a list of columns for the given table in the current database or + * the given temporary table. * * @since 2.0.0 */ @@ -175,13 +178,13 @@ abstract class Catalog { options: Map[String, String]): DataFrame /** - * Drops the temporary table with the given table name in the catalog. - * If the table has been cached before, then it will also be uncached. + * Drops the temporary view with the given view name in the catalog. + * If the view has been cached before, then it will also be uncached. * - * @param tableName the name of the table to be dropped. + * @param viewName the name of the view to be dropped. * @since 2.0.0 */ - def dropTempTable(tableName: String): Unit + def dropTempView(viewName: String): Unit /** * Returns true if the table is currently cached in-memory. @@ -211,4 +214,24 @@ abstract class Catalog { */ def clearCache(): Unit + /** + * Invalidate and refresh all the cached metadata of the given table. 
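A sketch tying together the catalog calls touched above (temp views, caching, and the new refresh methods); the table and path names in the commented lines are placeholders:

```scala
import org.apache.spark.sql.SparkSession

object CatalogExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("catalog-example").getOrCreate()

    spark.range(100).createOrReplaceTempView("nums")

    // Metadata queries return Datasets of the classes defined in catalog/interface.scala.
    spark.catalog.listTables().show()
    spark.catalog.listColumns("nums").show()

    // Caching
    spark.catalog.cacheTable("nums")
    println(spark.catalog.isCached("nums"))   // true
    spark.catalog.uncacheTable("nums")

    // dropTempTable has been renamed; temporary views are dropped with dropTempView.
    spark.catalog.dropTempView("nums")

    // For external data sources, stale metadata or data can be invalidated explicitly:
    // spark.catalog.refreshTable("db.some_table")
    // spark.catalog.refreshByPath("/path/to/parquet")

    spark.stop()
  }
}
```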
For performance reasons, + * Spark SQL or the external data source library it uses might cache certain metadata about a + * table, such as the location of blocks. When those change outside of Spark SQL, users should + * call this function to invalidate the cache. + * + * If this table is cached as an InMemoryRelation, drop the original cached version and make the + * new version cached lazily. + * + * @since 2.0.0 + */ + def refreshTable(tableName: String): Unit + + /** + * Invalidate and refresh all the cached data (and the associated metadata) for any dataframe that + * contains the given data source path. + * + * @since 2.0.0 + */ + def refreshByPath(path: String): Unit } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalog/interface.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalog/interface.scala index 0f7feb8eee7be..33032f07f7bea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalog/interface.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalog/interface.scala @@ -25,6 +25,14 @@ import org.apache.spark.sql.catalyst.DefinedByConstructorParams // Note: all classes here are expected to be wrapped in Datasets and so must extend // DefinedByConstructorParams for the catalog to be able to create encoders for them. +/** + * A database in Spark, as returned by the `listDatabases` method defined in [[Catalog]]. + * + * @param name name of the database. + * @param description description of the database. + * @param locationUri path (in the form of a uri) to data files. + * @since 2.0.0 + */ class Database( val name: String, @Nullable val description: String, @@ -41,6 +49,16 @@ class Database( } +/** + * A table in Spark, as returned by the `listTables` method in [[Catalog]]. + * + * @param name name of the table. + * @param database name of the database the table belongs to. + * @param description description of the table. + * @param tableType type of the table (e.g. view, table). + * @param isTemporary whether the table is a temporary table. + * @since 2.0.0 + */ class Table( val name: String, @Nullable val database: String, @@ -61,6 +79,17 @@ class Table( } +/** + * A column in Spark, as returned by `listColumns` method in [[Catalog]]. + * + * @param name name of the column. + * @param description description of the column. + * @param dataType data type of the column. + * @param nullable whether the column is nullable. + * @param isPartition whether the column is a partition column. + * @param isBucket whether the column is a bucket column. + * @since 2.0.0 + */ class Column( val name: String, @Nullable val description: String, @@ -83,9 +112,19 @@ class Column( } -// TODO(andrew): should we include the database here? +/** + * A user-defined function in Spark, as returned by `listFunctions` method in [[Catalog]]. + * + * @param name name of the function. + * @param database name of the database the function belongs to. + * @param description description of the function; description can be null. + * @param className the fully qualified class name of the function. + * @param isTemporary whether the function is a temporary function or not. 
+ * @since 2.0.0 + */ class Function( val name: String, + @Nullable val database: String, @Nullable val description: String, val className: String, val isTemporary: Boolean) @@ -94,6 +133,7 @@ class Function( override def toString: String = { "Function[" + s"name='$name', " + + Option(database).map { d => s"database='$d', " }.getOrElse("") + Option(description).map { d => s"description='$d', " }.getOrElse("") + s"className='$className', " + s"isTemporary='$isTemporary']" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala index 7f26a7e411f73..5e0263ec5b4c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst import java.util.concurrent.atomic.AtomicLong +import scala.collection.mutable.Map import scala.util.control.NonFatal import org.apache.spark.internal.Logging @@ -38,14 +39,23 @@ import org.apache.spark.sql.types.{ByteType, DataType, IntegerType, NullType} * representations (e.g. logical plans that operate on local Scala collections), or are simply not * supported by this builder (yet). */ -class SQLBuilder(logicalPlan: LogicalPlan) extends Logging { +class SQLBuilder private ( + logicalPlan: LogicalPlan, + nextSubqueryId: AtomicLong, + nextGenAttrId: AtomicLong, + exprIdMap: Map[Long, Long]) extends Logging { require(logicalPlan.resolved, "SQLBuilder only supports resolved logical query plans. Current plan:\n" + logicalPlan) + def this(logicalPlan: LogicalPlan) = + this(logicalPlan, new AtomicLong(0), new AtomicLong(0), Map.empty[Long, Long]) + def this(df: Dataset[_]) = this(df.queryExecution.analyzed) - private val nextSubqueryId = new AtomicLong(0) private def newSubqueryName(): String = s"gen_subquery_${nextSubqueryId.getAndIncrement()}" + private def normalizedName(n: NamedExpression): String = synchronized { + "gen_attr_" + exprIdMap.getOrElseUpdate(n.exprId.id, nextGenAttrId.getAndIncrement()) + } def toSQL: String = { val canonicalizedPlan = Canonicalizer.execute(logicalPlan) @@ -69,8 +79,14 @@ class SQLBuilder(logicalPlan: LogicalPlan) extends Logging { try { val replaced = finalPlan.transformAllExpressions { - case e: SubqueryExpression => - SubqueryHolder(new SQLBuilder(e.query).toSQL) + case s: SubqueryExpression => + val query = new SQLBuilder(s.query, nextSubqueryId, nextGenAttrId, exprIdMap).toSQL + val sql = s match { + case _: ListQuery => query + case _: Exists => s"EXISTS($query)" + case _ => s"($query)" + } + SubqueryHolder(sql) case e: NonSQLExpression => throw new UnsupportedOperationException( s"Expression $e doesn't have a SQL representation" @@ -163,6 +179,11 @@ class SQLBuilder(logicalPlan: LogicalPlan) extends Logging { qualifiedName + " TABLESAMPLE(" + fraction + " PERCENT)" }.getOrElse(qualifiedName) + case relation: CatalogRelation => + val m = relation.catalogTable + val qualifiedName = s"${quoteIdentifier(m.database)}.${quoteIdentifier(m.identifier.table)}" + qualifiedName + case Sort(orders, _, RepartitionByExpression(partitionExprs, child, _)) if orders.map(_.child) == partitionExprs => build(toSQL(child), "CLUSTER BY", partitionExprs.map(_.sql).mkString(", ")) @@ -184,6 +205,12 @@ class SQLBuilder(logicalPlan: LogicalPlan) extends Logging { case p: ScriptTransformation => scriptTransformationToSQL(p) + case p: LocalRelation => + p.toSQL(newSubqueryName()) + + case p: Range => + 
p.toSQL() + case OneRowRelation => "" @@ -268,7 +295,7 @@ class SQLBuilder(logicalPlan: LogicalPlan) extends Logging { // 5. the table alias for output columns of generator. // 6. the AS keyword // 7. the column alias, can be more than one, e.g. AS key, value - // An concrete example: "tbl LATERAL VIEW EXPLODE(map_col) sub_q AS key, value", and the builder + // A concrete example: "tbl LATERAL VIEW EXPLODE(map_col) sub_q AS key, value", and the builder // will put it in FROM clause later. build( childSQL, @@ -370,8 +397,6 @@ class SQLBuilder(logicalPlan: LogicalPlan) extends Logging { ) } - private def normalizedName(n: NamedExpression): String = "gen_attr_" + n.exprId.id - object Canonicalizer extends RuleExecutor[LogicalPlan] { override protected def batches: Seq[Batch] = Seq( Batch("Prepare", FixedPoint(100), @@ -404,7 +429,9 @@ class SQLBuilder(logicalPlan: LogicalPlan) extends Logging { // that table relation, as we can only convert table sample to standard SQL string. ResolveSQLTable, // Insert sub queries on top of operators that need to appear after FROM clause. - AddSubquery + AddSubquery, + // Reconstruct subquery expressions. + ConstructSubqueryExpressions ) ) @@ -484,6 +511,57 @@ class SQLBuilder(logicalPlan: LogicalPlan) extends Logging { } } + object ConstructSubqueryExpressions extends Rule[LogicalPlan] { + def apply(tree: LogicalPlan): LogicalPlan = tree transformAllExpressions { + case ScalarSubquery(query, conditions, exprId) if conditions.nonEmpty => + def rewriteAggregate(a: Aggregate): Aggregate = { + val filter = Filter(conditions.reduce(And), addSubqueryIfNeeded(a.child)) + Aggregate(Nil, a.aggregateExpressions.take(1), filter) + } + val cleaned = query match { + case Project(_, child) => child + case child => child + } + val rewrite = cleaned match { + case a: Aggregate => + rewriteAggregate(a) + case Filter(c, a: Aggregate) => + Filter(c, rewriteAggregate(a)) + } + ScalarSubquery(rewrite, Seq.empty, exprId) + + case PredicateSubquery(query, conditions, false, exprId) => + val subquery = addSubqueryIfNeeded(query) + val plan = if (conditions.isEmpty) { + subquery + } else { + Project(Seq(Alias(Literal(1), "1")()), + Filter(conditions.reduce(And), subquery)) + } + Exists(plan, exprId) + + case PredicateSubquery(query, conditions, true, exprId) => + val (in, correlated) = conditions.partition(_.isInstanceOf[EqualTo]) + val (outer, inner) = in.zipWithIndex.map { + case (EqualTo(l, r), i) if query.outputSet.intersect(r.references).nonEmpty => + (l, Alias(r, s"_c$i")()) + case (EqualTo(r, l), i) => + (l, Alias(r, s"_c$i")()) + }.unzip + val wrapped = addSubqueryIfNeeded(query) + val filtered = if (correlated.nonEmpty) { + Filter(conditions.reduce(And), wrapped) + } else { + wrapped + } + val value = outer match { + case Seq(expr) => expr + case exprs => CreateStruct(exprs) + } + In(value, Seq(ListQuery(Project(inner, filtered), exprId))) + } + } + private def addSubquery(plan: LogicalPlan): SubqueryAlias = { SubqueryAlias(newSubqueryName(), plan) } @@ -526,9 +604,8 @@ class SQLBuilder(logicalPlan: LogicalPlan) extends Logging { /** * A place holder for generated SQL for subquery expression. 
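`SQLBuilder` is an internal helper, but a rough sketch of how it is exercised: take the analyzed plan of a Dataset and regenerate SQL from it. The `Try` wrapper reflects that not every plan is convertible; the view and query are arbitrary:

```scala
import scala.util.{Failure, Success, Try}

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.SQLBuilder

object SqlBuilderExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("sqlbuilder-example").getOrCreate()

    spark.range(10).createOrReplaceTempView("t")
    val df = spark.sql("SELECT id FROM t WHERE id > 5")

    // Subqueries now get stable gen_subquery_N / gen_attr_N names because the
    // counters and the exprId map are shared with nested SQLBuilder instances.
    Try(new SQLBuilder(df).toSQL) match {
      case Success(sql) => println(sql)
      case Failure(e)   => println(s"Plan not convertible to SQL: ${e.getMessage}")
    }

    spark.stop()
  }
}
```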
*/ - case class SubqueryHolder(query: String) extends LeafExpression with Unevaluable { + case class SubqueryHolder(override val sql: String) extends LeafExpression with Unevaluable { override def dataType: DataType = NullType override def nullable: Boolean = true - override def sql: String = s"($query)" } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index f601138a9d72b..83b7c779ab818 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -19,15 +19,19 @@ package org.apache.spark.sql.execution import java.util.concurrent.locks.ReentrantReadWriteLock +import org.apache.hadoop.fs.{FileSystem, Path} + import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.sql.Dataset +import org.apache.spark.sql.execution.columnar.InMemoryRelation +import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} +import org.apache.spark.sql.SparkSession import org.apache.spark.storage.StorageLevel import org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK /** Holds a cached logical plan and its data */ -private[sql] case class CachedData(plan: LogicalPlan, cachedRepresentation: InMemoryRelation) +case class CachedData(plan: LogicalPlan, cachedRepresentation: InMemoryRelation) /** * Provides support in a SQLContext for caching query results and automatically using these cached @@ -37,7 +41,7 @@ private[sql] case class CachedData(plan: LogicalPlan, cachedRepresentation: InMe * * Internal to Spark SQL. */ -private[sql] class CacheManager extends Logging { +class CacheManager extends Logging { @transient private val cachedData = new scala.collection.mutable.ArrayBuffer[CachedData] @@ -64,13 +68,13 @@ private[sql] class CacheManager extends Logging { } /** Clears all cached tables. */ - private[sql] def clearCache(): Unit = writeLock { + def clearCache(): Unit = writeLock { cachedData.foreach(_.cachedRepresentation.cachedColumnBuffers.unpersist()) cachedData.clear() } /** Checks if the cache is empty. */ - private[sql] def isEmpty: Boolean = readLock { + def isEmpty: Boolean = readLock { cachedData.isEmpty } @@ -79,7 +83,7 @@ private[sql] class CacheManager extends Logging { * Unlike `RDD.cache()`, the default storage level is set to be `MEMORY_AND_DISK` because * recomputing the in-memory columnar representation of the underlying table is expensive. 
*/ - private[sql] def cacheQuery( + def cacheQuery( query: Dataset[_], tableName: Option[String] = None, storageLevel: StorageLevel = MEMORY_AND_DISK): Unit = writeLock { @@ -95,27 +99,16 @@ private[sql] class CacheManager extends Logging { sparkSession.sessionState.conf.useCompression, sparkSession.sessionState.conf.columnBatchSize, storageLevel, - sparkSession.executePlan(planToCache).executedPlan, + sparkSession.sessionState.executePlan(planToCache).executedPlan, tableName)) } } - /** Removes the data for the given [[Dataset]] from the cache */ - private[sql] def uncacheQuery(query: Dataset[_], blocking: Boolean = true): Unit = writeLock { - val planToCache = query.queryExecution.analyzed - val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan)) - require(dataIndex >= 0, s"Table $query is not cached.") - cachedData(dataIndex).cachedRepresentation.uncache(blocking) - cachedData.remove(dataIndex) - } - /** - * Tries to remove the data for the given [[Dataset]] from the cache - * if it's cached + * Tries to remove the data for the given [[Dataset]] from the cache. + * No operation, if it's already uncached. */ - private[sql] def tryUncacheQuery( - query: Dataset[_], - blocking: Boolean = true): Boolean = writeLock { + def uncacheQuery(query: Dataset[_], blocking: Boolean = true): Boolean = writeLock { val planToCache = query.queryExecution.analyzed val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan)) val found = dataIndex >= 0 @@ -127,17 +120,17 @@ private[sql] class CacheManager extends Logging { } /** Optionally returns cached data for the given [[Dataset]] */ - private[sql] def lookupCachedData(query: Dataset[_]): Option[CachedData] = readLock { + def lookupCachedData(query: Dataset[_]): Option[CachedData] = readLock { lookupCachedData(query.queryExecution.analyzed) } /** Optionally returns cached data for the given [[LogicalPlan]]. */ - private[sql] def lookupCachedData(plan: LogicalPlan): Option[CachedData] = readLock { + def lookupCachedData(plan: LogicalPlan): Option[CachedData] = readLock { cachedData.find(cd => plan.sameResult(cd.plan)) } /** Replaces segments of the given logical plan with cached versions where possible. */ - private[sql] def useCachedData(plan: LogicalPlan): LogicalPlan = { + def useCachedData(plan: LogicalPlan): LogicalPlan = { plan transformDown { case currentFragment => lookupCachedData(currentFragment) @@ -150,11 +143,56 @@ private[sql] class CacheManager extends Logging { * Invalidates the cache of any data that contains `plan`. Note that it is possible that this * function will over invalidate. */ - private[sql] def invalidateCache(plan: LogicalPlan): Unit = writeLock { + def invalidateCache(plan: LogicalPlan): Unit = writeLock { cachedData.foreach { case data if data.plan.collect { case p if p.sameResult(plan) => p }.nonEmpty => data.cachedRepresentation.recache() case _ => } } + + /** + * Invalidates the cache of any data that contains `resourcePath` in one or more + * `HadoopFsRelation` node(s) as part of its logical plan. 
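
The CacheManager change above folds the old `uncacheQuery`/`tryUncacheQuery` pair into a single `uncacheQuery` that is a no-op when the plan is not cached and reports whether anything was removed. Below is a minimal sketch of that remove-if-present bookkeeping under a write lock; `MiniCache` is a made-up stand-in, not the real cache manager.

```scala
import java.util.concurrent.locks.ReentrantReadWriteLock
import scala.collection.mutable.ArrayBuffer

// Removing an absent entry is not an error: it simply returns false.
class MiniCache[K] {
  private val lock = new ReentrantReadWriteLock()
  private val entries = new ArrayBuffer[K]

  private def withWriteLock[A](f: => A): A = {
    val l = lock.writeLock(); l.lock()
    try f finally l.unlock()
  }

  def cache(key: K): Unit = withWriteLock {
    if (!entries.contains(key)) entries += key
  }

  def uncache(key: K): Boolean = withWriteLock {
    val idx = entries.indexOf(key)
    val found = idx >= 0
    if (found) entries.remove(idx)
    found
  }
}
```
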
+ */ + def invalidateCachedPath( + sparkSession: SparkSession, resourcePath: String): Unit = writeLock { + val (fs, qualifiedPath) = { + val path = new Path(resourcePath) + val fs = path.getFileSystem(sparkSession.sessionState.newHadoopConf()) + (fs, path.makeQualified(fs.getUri, fs.getWorkingDirectory)) + } + + cachedData.foreach { + case data if data.plan.find(lookupAndRefresh(_, fs, qualifiedPath)).isDefined => + val dataIndex = cachedData.indexWhere(cd => data.plan.sameResult(cd.plan)) + if (dataIndex >= 0) { + data.cachedRepresentation.cachedColumnBuffers.unpersist(blocking = true) + cachedData.remove(dataIndex) + } + sparkSession.sharedState.cacheManager.cacheQuery(Dataset.ofRows(sparkSession, data.plan)) + case _ => // Do Nothing + } + } + + /** + * Traverses a given `plan` and searches for the occurrences of `qualifiedPath` in the + * [[org.apache.spark.sql.execution.datasources.FileCatalog]] of any [[HadoopFsRelation]] nodes + * in the plan. If found, we refresh the metadata and return true. Otherwise, this method returns + * false. + */ + private def lookupAndRefresh(plan: LogicalPlan, fs: FileSystem, qualifiedPath: Path): Boolean = { + plan match { + case lr: LogicalRelation => lr.relation match { + case hr: HadoopFsRelation => + val invalidate = hr.location.paths + .map(_.makeQualified(fs.getUri, fs.getWorkingDirectory)) + .contains(qualifiedPath) + if (invalidate) hr.location.refresh() + invalidate + case _ => false + } + case _ => false + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index d6516f26a70f3..ba30bed0b450e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -17,21 +17,23 @@ package org.apache.spark.sql.execution +import org.apache.commons.lang3.StringUtils + import org.apache.spark.rdd.RDD import org.apache.spark.sql.{AnalysisException, Row, SparkSession, SQLContext} -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, UnknownPartitioning} -import org.apache.spark.sql.catalyst.util.toCommentSafeString import org.apache.spark.sql.execution.datasources.HadoopFsRelation -import org.apache.spark.sql.execution.datasources.parquet.{DefaultSource => ParquetSource} +import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat => ParquetSource} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.util.Utils object RDDConversions { def productToRowRdd[A <: Product](data: RDD[A], outputTypes: Seq[DataType]): RDD[InternalRow] = { @@ -73,7 +75,7 @@ object RDDConversions { } /** Logical plan node for scanning data from an RDD. 
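
`invalidateCachedPath`/`lookupAndRefresh` above qualify the incoming resource path against the filesystem before comparing it with each relation's paths, and only refresh relations that actually cover that path. A rough sketch of the matching step is below, using `java.nio` paths purely for illustration instead of Hadoop's `Path`/`FileSystem`; `Relation` and `invalidate` are hypothetical names.

```scala
import java.nio.file.{Path, Paths}

final case class Relation(paths: Seq[Path]) {
  def refresh(): Unit = println(s"refreshing metadata for $paths")
}

// Normalize both sides to a fully qualified form, then refresh only the
// relations whose path set contains the given resource.
def invalidate(resource: String, relations: Seq[Relation]): Seq[Relation] = {
  val qualified = Paths.get(resource).toAbsolutePath.normalize()
  relations.filter { r =>
    val hit = r.paths.map(_.toAbsolutePath.normalize()).contains(qualified)
    if (hit) r.refresh()
    hit
  }
}
```
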
*/ -private[sql] case class LogicalRDD( +case class LogicalRDD( output: Seq[Attribute], rdd: RDD[InternalRow])(session: SparkSession) extends LogicalPlan with MultiInstanceRelation { @@ -85,11 +87,15 @@ private[sql] case class LogicalRDD( override def newInstance(): LogicalRDD.this.type = LogicalRDD(output.map(_.newInstance()), rdd)(session).asInstanceOf[this.type] - override def sameResult(plan: LogicalPlan): Boolean = plan match { - case LogicalRDD(_, otherRDD) => rdd.id == otherRDD.id - case _ => false + override def sameResult(plan: LogicalPlan): Boolean = { + plan.canonicalized match { + case LogicalRDD(_, otherRDD) => rdd.id == otherRDD.id + case _ => false + } } + override protected def stringArgs: Iterator[Any] = Iterator(output) + override def producedAttributes: AttributeSet = outputSet @transient override lazy val statistics: Statistics = Statistics( @@ -100,12 +106,12 @@ private[sql] case class LogicalRDD( } /** Physical plan node for scanning data from an RDD. */ -private[sql] case class RDDScanExec( +case class RDDScanExec( output: Seq[Attribute], rdd: RDD[InternalRow], override val nodeName: String) extends LeafExecNode { - private[sql] override lazy val metrics = Map( + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) protected override def doExecute(): RDD[InternalRow] = { @@ -120,15 +126,18 @@ private[sql] case class RDDScanExec( } override def simpleString: String = { - s"Scan $nodeName${output.mkString("[", ",", "]")}" + s"Scan $nodeName${Utils.truncatedString(output, "[", ",", "]")}" } } -private[sql] trait DataSourceScanExec extends LeafExecNode { +trait DataSourceScanExec extends LeafExecNode { val rdd: RDD[InternalRow] val relation: BaseRelation + val metastoreTableIdentifier: Option[TableIdentifier] - override val nodeName: String = relation.toString + override val nodeName: String = { + s"Scan $relation ${metastoreTableIdentifier.map(_.unquotedString).getOrElse("")}" + } // Ignore rdd when checking results override def sameResult(plan: SparkPlan): Boolean = plan match { @@ -138,20 +147,22 @@ private[sql] trait DataSourceScanExec extends LeafExecNode { } /** Physical plan node for scanning data from a relation. 
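
The `LogicalRDD.sameResult` fix above matches on `plan.canonicalized` rather than on the raw plan, so wrappers that canonicalization strips away no longer defeat the RDD-identity comparison. A toy illustration of that pattern follows; the plan classes here are invented for the example.

```scala
// Compare against the canonical form of the other plan, not its raw shape.
sealed trait Plan { def canonicalized: Plan = this }
final case class RddScan(rddId: Int) extends Plan
final case class Wrapper(name: String, child: Plan) extends Plan {
  override def canonicalized: Plan = child.canonicalized
}

def sameResult(self: RddScan, other: Plan): Boolean = other.canonicalized match {
  case RddScan(id) => id == self.rddId
  case _           => false
}

// sameResult(RddScan(7), Wrapper("alias", RddScan(7))) == true
```
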
*/ -private[sql] case class RowDataSourceScanExec( +case class RowDataSourceScanExec( output: Seq[Attribute], rdd: RDD[InternalRow], @transient relation: BaseRelation, override val outputPartitioning: Partitioning, - override val metadata: Map[String, String] = Map.empty) + override val metadata: Map[String, String], + override val metastoreTableIdentifier: Option[TableIdentifier]) extends DataSourceScanExec with CodegenSupport { - private[sql] override lazy val metrics = + override lazy val metrics = Map("numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) val outputUnsafeRows = relation match { case r: HadoopFsRelation if r.fileFormat.isInstanceOf[ParquetSource] => - !SQLContext.getActive().get.conf.getConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED) + !SparkSession.getActiveSession.get.sessionState.conf.getConf( + SQLConf.PARQUET_VECTORIZED_READER_ENABLED) case _: HadoopFsRelation => true case _ => false } @@ -174,8 +185,12 @@ private[sql] case class RowDataSourceScanExec( } override def simpleString: String = { - val metadataEntries = for ((key, value) <- metadata.toSeq.sorted) yield s"$key: $value" - s"Scan $nodeName${output.mkString("[", ",", "]")}${metadataEntries.mkString(" ", ", ", "")}" + val metadataEntries = for ((key, value) <- metadata.toSeq.sorted) yield { + key + ": " + StringUtils.abbreviate(value, 100) + } + + s"$nodeName${Utils.truncatedString(output, "[", ",", "]")}" + + s"${Utils.truncatedString(metadataEntries, " ", ", ", "")}" } override def inputRDDs(): Seq[RDD[InternalRow]] = { @@ -207,26 +222,32 @@ private[sql] case class RowDataSourceScanExec( } /** Physical plan node for scanning data from a batched relation. */ -private[sql] case class BatchedDataSourceScanExec( +case class BatchedDataSourceScanExec( output: Seq[Attribute], rdd: RDD[InternalRow], @transient relation: BaseRelation, override val outputPartitioning: Partitioning, - override val metadata: Map[String, String] = Map.empty) + override val metadata: Map[String, String], + override val metastoreTableIdentifier: Option[TableIdentifier]) extends DataSourceScanExec with CodegenSupport { - private[sql] override lazy val metrics = + override lazy val metrics = Map("numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), "scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time")) protected override def doExecute(): RDD[InternalRow] = { - throw new UnsupportedOperationException + // in the case of fallback, this batched scan should never fail because of: + // 1) only primitive types are supported + // 2) the number of columns should be smaller than spark.sql.codegen.maxFields + WholeStageCodegenExec(this).execute() } override def simpleString: String = { - val metadataEntries = for ((key, value) <- metadata.toSeq.sorted) yield s"$key: $value" - val metadataStr = metadataEntries.mkString(" ", ", ", "") - s"BatchedScan $nodeName${output.mkString("[", ",", "]")}$metadataStr" + val metadataEntries = for ((key, value) <- metadata.toSeq.sorted) yield { + key + ": " + StringUtils.abbreviate(value, 100) + } + val metadataStr = Utils.truncatedString(metadataEntries, " ", ", ", "") + s"Batched$nodeName${Utils.truncatedString(output, "[", ",", "]")}$metadataStr" } override def inputRDDs(): Seq[RDD[InternalRow]] = { @@ -240,7 +261,7 @@ private[sql] case class BatchedDataSourceScanExec( val isNullVar = if (nullable) { ctx.freshName("isNull") } else { "false" } val valueVar = ctx.freshName("value") val str = s"columnVector[$columnVar, $ordinal, 
${dataType.simpleString}]" - val code = s"/* ${toCommentSafeString(str)} */\n" + (if (nullable) { + val code = s"${ctx.registerComment(str)}\n" + (if (nullable) { s""" boolean ${isNullVar} = ${columnVar}.isNullAt($ordinal); $javaType ${valueVar} = ${isNullVar} ? ${ctx.defaultValue(dataType)} : ($value); @@ -316,7 +337,7 @@ private[sql] case class BatchedDataSourceScanExec( } } -private[sql] object DataSourceScanExec { +object DataSourceScanExec { // Metadata keys val INPUT_PATHS = "InputPaths" val PUSHED_FILTERS = "PushedFilters" @@ -325,7 +346,8 @@ private[sql] object DataSourceScanExec { output: Seq[Attribute], rdd: RDD[InternalRow], relation: BaseRelation, - metadata: Map[String, String] = Map.empty): DataSourceScanExec = { + metadata: Map[String, String] = Map.empty, + metastoreTableIdentifier: Option[TableIdentifier] = None): DataSourceScanExec = { val outputPartitioning = { val bucketSpec = relation match { // TODO: this should be closer to bucket planning. @@ -334,15 +356,14 @@ private[sql] object DataSourceScanExec { case _ => None } - def toAttribute(colName: String): Attribute = output.find(_.name == colName).getOrElse { - throw new AnalysisException(s"bucket column $colName not found in existing columns " + - s"(${output.map(_.name).mkString(", ")})") - } - bucketSpec.map { spec => val numBuckets = spec.numBuckets - val bucketColumns = spec.bucketColumnNames.map(toAttribute) - HashPartitioning(bucketColumns, numBuckets) + val bucketColumns = spec.bucketColumnNames.flatMap { n => output.find(_.name == n) } + if (bucketColumns.size == spec.bucketColumnNames.size) { + HashPartitioning(bucketColumns, numBuckets) + } else { + UnknownPartitioning(0) + } }.getOrElse { UnknownPartitioning(0) } @@ -351,9 +372,11 @@ private[sql] object DataSourceScanExec { relation match { case r: HadoopFsRelation if r.fileFormat.supportBatch(r.sparkSession, StructType.fromAttributes(output)) => - BatchedDataSourceScanExec(output, rdd, relation, outputPartitioning, metadata) + BatchedDataSourceScanExec( + output, rdd, relation, outputPartitioning, metadata, metastoreTableIdentifier) case _ => - RowDataSourceScanExec(output, rdd, relation, outputPartitioning, metadata) + RowDataSourceScanExec( + output, rdd, relation, outputPartitioning, metadata, metastoreTableIdentifier) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala index c201822d4479a..d5603b3b00914 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala @@ -26,8 +26,8 @@ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartit import org.apache.spark.sql.execution.metric.SQLMetrics /** - * Apply the all of the GroupExpressions to every input row, hence we will get - * multiple output rows for a input row. + * Apply all of the GroupExpressions to every input row, hence we will get + * multiple output rows for an input row. 
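
Earlier in this hunk, `DataSourceScanExec.create` stops throwing when a bucket column is missing from the scan output: it only claims hash partitioning when every declared bucket column resolves, and otherwise falls back to an unknown partitioning. A small self-contained sketch of that lookup, with invented types, is shown here.

```scala
final case class Attr(name: String)
sealed trait Partitioning
final case class HashPartitioning(cols: Seq[Attr], numBuckets: Int) extends Partitioning
final case class UnknownPartitioning(numPartitions: Int) extends Partitioning

// Resolve every bucket column against the output; only report hash
// partitioning when all of them are present.
def outputPartitioning(
    output: Seq[Attr],
    bucketColumnNames: Seq[String],
    numBuckets: Int): Partitioning = {
  val bucketColumns = bucketColumnNames.flatMap(n => output.find(_.name == n))
  if (bucketColumns.size == bucketColumnNames.size) {
    HashPartitioning(bucketColumns, numBuckets)
  } else {
    UnknownPartitioning(0)
  }
}
```
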
* @param projections The group of expressions, all of the group expressions should * output the same schema specified bye the parameter `output` * @param output The output Schema @@ -39,7 +39,7 @@ case class ExpandExec( child: SparkPlan) extends UnaryExecNode with CodegenSupport { - private[sql] override lazy val metrics = Map( + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) // The GroupExpressions can output data with arbitrary partitioning, so set it diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/FileRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/FileRelation.scala index 7a2a9eed5807d..a299fed7fd14a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/FileRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/FileRelation.scala @@ -22,7 +22,7 @@ package org.apache.spark.sql.execution * the list of paths that it returns will be returned to a user who calls `inputPaths` on any * DataFrame that queries this relation. */ -private[sql] trait FileRelation { +trait FileRelation { /** Returns the list of files that will be read when scanning this relation. */ def inputFiles: Array[String] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala index 934bc38dc47cb..39189a2b0c72c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala @@ -55,7 +55,7 @@ case class GenerateExec( child: SparkPlan) extends UnaryExecNode { - private[sql] override lazy val metrics = Map( + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) override def producedAttributes: AttributeSet = AttributeSet(output) @@ -66,7 +66,7 @@ case class GenerateExec( // boundGenerator.terminate() should be triggered after all of the rows in the partition val rows = if (join) { child.execute().mapPartitionsInternal { iter => - val generatorNullRow = new GenericInternalRow(generator.elementTypes.size) + val generatorNullRow = new GenericInternalRow(generator.elementSchema.length) val joinedRow = new JoinedRow iter.flatMap { row => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala index c5e78b033359d..c998e04223f39 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala @@ -26,11 +26,11 @@ import org.apache.spark.sql.execution.metric.SQLMetrics /** * Physical plan node for scanning data from a local collection. 
*/ -private[sql] case class LocalTableScanExec( +case class LocalTableScanExec( output: Seq[Attribute], rows: Seq[InternalRow]) extends LeafExecNode { - private[sql] override lazy val metrics = Map( + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) private val unsafeRows: Array[InternalRow] = { @@ -48,11 +48,22 @@ private[sql] case class LocalTableScanExec( } } + override protected def stringArgs: Iterator[Any] = { + if (rows.isEmpty) { + Iterator("", output) + } else { + Iterator(output) + } + } + override def executeCollect(): Array[InternalRow] = { + longMetric("numOutputRows").add(unsafeRows.size) unsafeRows } override def executeTake(limit: Int): Array[InternalRow] = { - unsafeRows.take(limit) + val taken = unsafeRows.take(limit) + longMetric("numOutputRows").add(taken.size) + taken } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 3e772286e0e55..ba2332398f962 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.execution.command.{DescribeTableCommand, ExecutedCom import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{BinaryType, DateType, DecimalType, TimestampType, _} +import org.apache.spark.util.Utils /** * The primary workflow for executing relational queries using Spark. Designed to allow easy @@ -60,20 +61,20 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { } lazy val analyzed: LogicalPlan = { - SQLContext.setActive(sparkSession.wrapped) + SparkSession.setActiveSession(sparkSession) sparkSession.sessionState.analyzer.execute(logical) } lazy val withCachedData: LogicalPlan = { assertAnalyzed() assertSupported() - sparkSession.cacheManager.useCachedData(analyzed) + sparkSession.sharedState.cacheManager.useCachedData(analyzed) } lazy val optimizedPlan: LogicalPlan = sparkSession.sessionState.optimizer.execute(withCachedData) lazy val sparkPlan: SparkPlan = { - SQLContext.setActive(sparkSession.wrapped) + SparkSession.setActiveSession(sparkSession) planner.plan(ReturnAnswer(optimizedPlan)).next() } @@ -110,24 +111,27 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { */ def hiveResultString(): Seq[String] = executedPlan match { case ExecutedCommandExec(desc: DescribeTableCommand) => - // If it is a describe command for a Hive table, we want to have the output format - // be similar with Hive. - desc.run(sparkSession).map { - case Row(name: String, dataType: String, comment) => - Seq(name, dataType, - Option(comment.asInstanceOf[String]).getOrElse("")) - .map(s => String.format(s"%-20s", s)) - .mkString("\t") + SQLExecution.withNewExecutionId(sparkSession, this) { + // If it is a describe command for a Hive table, we want to have the output format + // be similar with Hive. 
+ desc.run(sparkSession).map { + case Row(name: String, dataType: String, comment) => + Seq(name, dataType, + Option(comment.asInstanceOf[String]).getOrElse("")) + .map(s => String.format(s"%-20s", s)) + .mkString("\t") + } } case command: ExecutedCommandExec => command.executeCollect().map(_.getString(0)) - case other => - val result: Seq[Seq[Any]] = other.executeCollectPublic().map(_.toSeq).toSeq - // We need the types so we can output struct field names - val types = analyzed.output.map(_.dataType) - // Reformat to match hive tab delimited output. - result.map(_.zip(types).map(toHiveString)).map(_.mkString("\t")).toSeq + SQLExecution.withNewExecutionId(sparkSession, this) { + val result: Seq[Seq[Any]] = other.executeCollectPublic().map(_.toSeq).toSeq + // We need the types so we can output struct field names + val types = analyzed.output.map(_.dataType) + // Reformat to match hive tab delimited output. + result.map(_.zip(types).map(toHiveString)).map(_.mkString("\t")).toSeq + } } /** Formats a datum (based on the given data type) and returns the string representation. */ @@ -201,23 +205,26 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { def simpleString: String = { s"""== Physical Plan == - |${stringOrError(executedPlan)} + |${stringOrError(executedPlan.treeString(verbose = false))} """.stripMargin.trim } override def toString: String = { - def output = - analyzed.output.map(o => s"${o.name}: ${o.dataType.simpleString}").mkString(", ") + def output = Utils.truncatedString( + analyzed.output.map(o => s"${o.name}: ${o.dataType.simpleString}"), ", ") + val analyzedPlan = Seq( + stringOrError(output), + stringOrError(analyzed.treeString(verbose = true)) + ).filter(_.nonEmpty).mkString("\n") s"""== Parsed Logical Plan == - |${stringOrError(logical)} + |${stringOrError(logical.treeString(verbose = true))} |== Analyzed Logical Plan == - |${stringOrError(output)} - |${stringOrError(analyzed)} + |$analyzedPlan |== Optimized Logical Plan == - |${stringOrError(optimizedPlan)} + |${stringOrError(optimizedPlan.treeString(verbose = true))} |== Physical Plan == - |${stringOrError(executedPlan)} + |${stringOrError(executedPlan.treeString(verbose = true))} """.stripMargin.trim } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala index 7462dbc4eba3a..717ff93eab5d4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.InternalRow * iterator to consume the next row, whereas RowIterator combines these calls into a single * [[advanceNext()]] method. */ -private[sql] abstract class RowIterator { +abstract class RowIterator { /** * Advance this iterator by a single row. Returns `false` if this iterator has no more rows * and `true` otherwise. 
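
The `RowIterator` contract described just above combines `hasNext`/`next` into a single `advanceNext()` call plus a `getRow` accessor. As a rough sketch of how such an iterator can still be exposed as an ordinary Scala `Iterator`, consider the following; `MiniRowIterator` is an illustrative name only.

```scala
abstract class MiniRowIterator[T] {
  // Returns false when exhausted; on true, the row is available via getRow.
  def advanceNext(): Boolean
  def getRow: T

  // Adapter back to the standard hasNext/next protocol.
  def toScala: Iterator[T] = new Iterator[T] {
    private var advanced = false
    private var hasMore = false

    override def hasNext: Boolean = {
      if (!advanced) { hasMore = advanceNext(); advanced = true }
      hasMore
    }

    override def next(): T = {
      if (!hasNext) throw new NoSuchElementException("end of iterator")
      advanced = false
      getRow
    }
  }
}
```
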
If this returns `true`, then the new row can be retrieved by calling diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index 397d66b31153f..ec07aab359ac6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -23,9 +23,8 @@ import org.apache.spark.SparkContext import org.apache.spark.sql.SparkSession import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionEnd, SparkListenerSQLExecutionStart} -import org.apache.spark.util.Utils -private[sql] object SQLExecution { +object SQLExecution { val EXECUTION_ID_KEY = "spark.sql.execution.id" @@ -46,7 +45,11 @@ private[sql] object SQLExecution { val executionId = SQLExecution.nextExecutionId sc.setLocalProperty(EXECUTION_ID_KEY, executionId.toString) val r = try { - val callSite = Utils.getCallSite() + // sparkContext.getCallSite() would first try to pick up any call site that was previously + // set, then fall back to Utils.getCallSite(); call Utils.getCallSite() directly on + // streaming queries would give us call site like "run at :0" + val callSite = sparkSession.sparkContext.getCallSite() + sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionStart( executionId, callSite.shortForm, callSite.longForm, queryExecution.toString, SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis())) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala index e81cd28ea34d1..5f0c26441692d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala @@ -95,7 +95,7 @@ class CoalescedPartitioner(val parent: Partitioner, val partitionStartIndices: A * interfaces / internals. * * This RDD takes a [[ShuffleDependency]] (`dependency`), - * and a optional array of partition start indices as input arguments + * and an optional array of partition start indices as input arguments * (`specifiedPartitionStartIndices`). * * The `dependency` has the parent RDD of this RDD, which represents the dataset before shuffle diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala index 0e4d6d72c6b5b..cde3ed48ffeaf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala @@ -22,11 +22,9 @@ import org.apache.spark.executor.TaskMetrics import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.plans.physical.{Distribution, OrderedDistribution, UnspecifiedDistribution} import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.types._ -import org.apache.spark.util.collection.unsafe.sort.RadixSort; /** * Performs (external) sorting. 
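
The QueryExecution and SQLExecution hunks above wrap command execution in `SQLExecution.withNewExecutionId`, so that work done while producing Hive-style output is attributed to a tracked execution. A generic sketch of that scoping pattern, assuming nothing about Spark's listener bus and using invented names, might look like this:

```scala
import java.util.concurrent.atomic.AtomicLong

object MiniExecution {
  private val nextId = new AtomicLong(0)
  private val current = new ThreadLocal[Option[Long]] {
    override def initialValue(): Option[Long] = None
  }

  // Allocate an id, expose it for the duration of the body, always restore
  // the previous value afterwards (even on failure).
  def withNewExecutionId[T](body: => T): T = {
    val id = nextId.getAndIncrement()
    val previous = current.get()
    current.set(Some(id))
    try body finally current.set(previous)
  }

  def currentExecutionId: Option[Long] = current.get()
}
```
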
@@ -52,7 +50,7 @@ case class SortExec( private val enableRadixSort = sqlContext.conf.enableRadixSort - override private[sql] lazy val metrics = Map( + override lazy val metrics = Map( "sortTime" -> SQLMetrics.createTimingMetric(sparkContext, "sort time"), "peakMemory" -> SQLMetrics.createSizeMetric(sparkContext, "peak memory"), "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size")) @@ -97,11 +95,8 @@ case class SortExec( // Remember spill data size of this task before execute this operator so that we can // figure out how many bytes we spilled for this operator. val spillSizeBefore = metrics.memoryBytesSpilled - val beforeSort = System.nanoTime() - val sortedIterator = sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]]) - - sortTime += (System.nanoTime() - beforeSort) / 1000000 + sortTime += sorter.getSortTimeNanos / 1000000 peakMemory += sorter.getPeakMemoryUsage spillSize += metrics.memoryBytesSpilled - spillSizeBefore metrics.incPeakExecutionMemory(sorter.getPeakMemoryUsage) @@ -151,15 +146,13 @@ case class SortExec( val peakMemory = metricTerm(ctx, "peakMemory") val spillSize = metricTerm(ctx, "spillSize") val spillSizeBefore = ctx.freshName("spillSizeBefore") - val startTime = ctx.freshName("startTime") val sortTime = metricTerm(ctx, "sortTime") s""" | if ($needToSort) { | long $spillSizeBefore = $metrics.memoryBytesSpilled(); - | long $startTime = System.nanoTime(); | $addToSorter(); | $sortedIterator = $sorterVariable.sort(); - | $sortTime.add((System.nanoTime() - $startTime) / 1000000); + | $sortTime.add($sorterVariable.getSortTimeNanos() / 1000000); | $peakMemory.add($sorterVariable.getPeakMemoryUsage()); | $spillSize.add($metrics.memoryBytesSpilled() - $spillSizeBefore); | $metrics.incPeakExecutionMemory($sorterVariable.getPeakMemoryUsage()); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index 08b2d7fcd4882..12a10cba20fe9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.ExperimentalMethods import org.apache.spark.sql.catalyst.catalog.SessionCatalog import org.apache.spark.sql.catalyst.optimizer.Optimizer +import org.apache.spark.sql.execution.python.ExtractPythonUDFFromAggregate import org.apache.spark.sql.internal.SQLConf class SparkOptimizer( @@ -28,6 +29,7 @@ class SparkOptimizer( experimentalMethods: ExperimentalMethods) extends Optimizer(catalog, conf) { - override def batches: Seq[Batch] = super.batches :+ Batch( - "User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*) + override def batches: Seq[Batch] = super.batches :+ + Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+ + Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 0bbe970420707..fa40414bcea86 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -27,7 +27,7 @@ import org.apache.spark.{broadcast, SparkEnv} import org.apache.spark.internal.Logging import org.apache.spark.io.CompressionCodec import 
org.apache.spark.rdd.{RDD, RDDOperationScope} -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.{Row, SparkSession, SQLContext} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ @@ -50,7 +50,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ * populated by the query planning infrastructure. */ @transient - protected[spark] final val sqlContext = SQLContext.getActive().orNull + final val sqlContext = SparkSession.getActiveSession.map(_.sqlContext).orNull protected def sparkContext = sqlContext.sparkContext @@ -65,31 +65,31 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ /** Overridden make copy also propagates sqlContext to copied plan. */ override def makeCopy(newArgs: Array[AnyRef]): SparkPlan = { - SQLContext.setActive(sqlContext) + SparkSession.setActiveSession(sqlContext.sparkSession) super.makeCopy(newArgs) } /** * Return all metadata that describes more details of this SparkPlan. */ - private[sql] def metadata: Map[String, String] = Map.empty + def metadata: Map[String, String] = Map.empty /** * Return all metrics containing metrics of this SparkPlan. */ - private[sql] def metrics: Map[String, SQLMetric] = Map.empty + def metrics: Map[String, SQLMetric] = Map.empty /** * Reset all the metrics. */ - private[sql] def resetMetrics(): Unit = { + def resetMetrics(): Unit = { metrics.valuesIterator.foreach(_.reset()) } /** * Return a LongSQLMetric according to the name. */ - private[sql] def longMetric(name: String): SQLMetric = metrics(name) + def longMetric(name: String): SQLMetric = metrics(name) // TODO: Move to `DistributedPlan` /** Specifies how data is partitioned across different nodes in the cluster. */ @@ -106,16 +106,20 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil) /** - * Returns the result of this query as an RDD[InternalRow] by delegating to doExecute after - * preparations. Concrete implementations of SparkPlan should override doExecute. + * Returns the result of this query as an RDD[InternalRow] by delegating to `doExecute` after + * preparations. + * + * Concrete implementations of SparkPlan should override `doExecute`. */ final def execute(): RDD[InternalRow] = executeQuery { doExecute() } /** - * Returns the result of this query as a broadcast variable by delegating to doBroadcast after - * preparations. Concrete implementations of SparkPlan should override doBroadcast. + * Returns the result of this query as a broadcast variable by delegating to `doExecuteBroadcast` + * after preparations. + * + * Concrete implementations of SparkPlan should override `doExecuteBroadcast`. 
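
The reworded SparkPlan docs above spell out the template-method split: the final `execute()`/`executeBroadcast()` methods handle shared preparation, while concrete operators only override the protected `doExecute`/`doExecuteBroadcast` hooks. A compact illustration of that structure, with toy types, follows.

```scala
abstract class MiniPlan {
  private var prepared = false

  // Public entry point: runs shared preparation once, then delegates.
  final def execute(): Seq[Int] = {
    if (!prepared) { prepare(); prepared = true }
    doExecute()
  }

  protected def prepare(): Unit = ()

  // The only method a concrete operator has to implement.
  protected def doExecute(): Seq[Int]
}

class Range10 extends MiniPlan {
  override protected def doExecute(): Seq[Int] = 0 until 10
}
```
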
*/ final def executeBroadcast[T](): broadcast.Broadcast[T] = executeQuery { doExecuteBroadcast() @@ -162,7 +166,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ protected def waitForSubqueries(): Unit = synchronized { // fill in the result of subqueries subqueryResults.foreach { case (e, futureResult) => - val rows = ThreadUtils.awaitResult(futureResult, Duration.Inf) + val rows = ThreadUtils.awaitResultInForkJoinSafely(futureResult, Duration.Inf) if (rows.length > 1) { sys.error(s"more than one row returned by a subquery used as an expression:\n${e.plan}") } @@ -391,7 +395,7 @@ object SparkPlan { ThreadUtils.newDaemonCachedThreadPool("subquery", 16)) } -private[sql] trait LeafExecNode extends SparkPlan { +trait LeafExecNode extends SparkPlan { override def children: Seq[SparkPlan] = Nil override def producedAttributes: AttributeSet = outputSet } @@ -403,7 +407,7 @@ object UnaryExecNode { } } -private[sql] trait UnaryExecNode extends SparkPlan { +trait UnaryExecNode extends SparkPlan { def child: SparkPlan override def children: Seq[SparkPlan] = child :: Nil @@ -411,7 +415,7 @@ private[sql] trait UnaryExecNode extends SparkPlan { override def outputPartitioning: Partitioning = child.outputPartitioning } -private[sql] trait BinaryExecNode extends SparkPlan { +trait BinaryExecNode extends SparkPlan { def left: SparkPlan def right: SparkPlan diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala index f84070a0c4bcb..7aa93126fdabd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala @@ -47,7 +47,7 @@ class SparkPlanInfo( } } -private[sql] object SparkPlanInfo { +private[execution] object SparkPlanInfo { def fromSparkPlan(plan: SparkPlan): SparkPlanInfo = { val children = plan match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index de832ec70b4df..319fff1f7c122 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.SparkContext import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, FileSourceStrategy} import org.apache.spark.sql.internal.SQLConf @@ -42,6 +43,16 @@ class SparkPlanner( InMemoryScans :: BasicOperators :: Nil) + override def plan(plan: LogicalPlan): Iterator[SparkPlan] = { + super.plan(plan).map { + _.transformUp { + case PlanLater(p) => + // TODO: use the first plan for now, but we will implement plan space exploaration later. + this.plan(p).next() + } + } + } + /** * Used to build table scan operators where complex projection and filtering are done using * separate physical operators. 
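
The new `SparkPlanner.plan` override above walks each candidate physical plan bottom-up and recursively plans any `PlanLater` placeholder, currently just taking the first candidate. The recursion can be illustrated with a self-contained toy planner; all of the types below are invented for the example.

```scala
sealed trait Logical
final case class LeafL(name: String) extends Logical
final case class JoinL(left: Logical, right: Logical) extends Logical

sealed trait Physical
final case class ScanP(name: String) extends Physical
final case class JoinP(left: Physical, right: Physical) extends Physical
final case class PlanLater(plan: Logical) extends Physical

// One planning step may leave PlanLater placeholders for its children.
def planOnce(plan: Logical): Physical = plan match {
  case LeafL(n)    => ScanP(n)
  case JoinL(l, r) => JoinP(PlanLater(l), PlanLater(r))
}

// Full planning: fill every placeholder by planning it recursively,
// taking the first (here: only) candidate.
def plan(logical: Logical): Physical = {
  def fill(p: Physical): Physical = p match {
    case PlanLater(l) => plan(l)
    case JoinP(l, r)  => JoinP(fill(l), fill(r))
    case s: ScanP     => s
  }
  fill(planOnce(logical))
}
```
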
This function returns the given scan operator with Project and diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 8128a6efe3ccb..f092360a98bae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -18,20 +18,20 @@ package org.apache.spark.sql.execution import scala.collection.JavaConverters._ -import scala.util.Try import org.antlr.v4.runtime.{ParserRuleContext, Token} +import org.antlr.v4.runtime.tree.TerminalNode import org.apache.spark.sql.SaveMode -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.{CatalogColumn, CatalogStorageFormat, CatalogTable, CatalogTableType} +import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.parser._ import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation, ScriptInputOutputSchema} import org.apache.spark.sql.execution.command._ -import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.datasources.{CreateTempViewUsing, _} import org.apache.spark.sql.internal.{HiveSerDe, SQLConf, VariableSubstitution} - +import org.apache.spark.sql.types.DataType /** * Concrete parser for Spark SQL statements. @@ -75,8 +75,20 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { } /** - * Create an [[AnalyzeTable]] command. This currently only implements the NOSCAN option (other - * options are passed on to Hive) e.g.: + * Create a [[ResetCommand]] logical plan. + * Example SQL : + * {{{ + * RESET; + * }}} + */ + override def visitResetConfiguration( + ctx: ResetConfigurationContext): LogicalPlan = withOrigin(ctx) { + ResetCommand + } + + /** + * Create an [[AnalyzeTableCommand]] command. This currently only implements the NOSCAN + * option (other options are passed on to Hive) e.g.: * {{{ * ANALYZE TABLE table COMPUTE STATISTICS NOSCAN; * }}} @@ -85,11 +97,11 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { if (ctx.partitionSpec == null && ctx.identifier != null && ctx.identifier.getText.toLowerCase == "noscan") { - AnalyzeTable(visitTableIdentifier(ctx.tableIdentifier).toString) + AnalyzeTableCommand(visitTableIdentifier(ctx.tableIdentifier).toString) } else { // Always just run the no scan analyze. We should fix this and implement full analyze // command in the future. - AnalyzeTable(visitTableIdentifier(ctx.tableIdentifier).toString) + AnalyzeTableCommand(visitTableIdentifier(ctx.tableIdentifier).toString) } } @@ -141,7 +153,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { } /** - * A command for users to list the columm names for a table. + * A command for users to list the column names for a table. * This function creates a [[ShowColumnsCommand]] logical plan. 
* * The syntax of using this command in SQL is: @@ -154,8 +166,10 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { val lookupTable = Option(ctx.db) match { case None => table - case Some(db) if table.database.isDefined => - throw new ParseException("Duplicates the declaration for database", ctx) + case Some(db) if table.database.exists(_ != db) => + operationNotAllowed( + s"SHOW COLUMNS with conflicting databases: '$db' != '${table.database.get}'", + ctx) case Some(db) => TableIdentifier(table.identifier, Some(db.getText)) } ShowColumnsCommand(lookupTable) @@ -178,6 +192,14 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { ShowPartitionsCommand(table, partitionKeys) } + /** + * Creates a [[ShowCreateTableCommand]] + */ + override def visitShowCreateTable(ctx: ShowCreateTableContext): LogicalPlan = withOrigin(ctx) { + val table = visitTableIdentifier(ctx.tableIdentifier()) + ShowCreateTableCommand(table) + } + /** * Create a [[RefreshTable]] logical plan. */ @@ -185,19 +207,33 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { RefreshTable(visitTableIdentifier(ctx.tableIdentifier)) } + /** + * Create a [[RefreshTable]] logical plan. + */ + override def visitRefreshResource(ctx: RefreshResourceContext): LogicalPlan = withOrigin(ctx) { + val resourcePath = remainder(ctx.REFRESH.getSymbol).trim + RefreshResource(resourcePath) + } + /** * Create a [[CacheTableCommand]] logical plan. */ override def visitCacheTable(ctx: CacheTableContext): LogicalPlan = withOrigin(ctx) { val query = Option(ctx.query).map(plan) - CacheTableCommand(ctx.identifier.getText, query, ctx.LAZY != null) + val tableIdent = visitTableIdentifier(ctx.tableIdentifier) + if (query.isDefined && tableIdent.database.isDefined) { + val database = tableIdent.database.get + throw new ParseException(s"It is not allowed to add database prefix `$database` to " + + s"the table name in CACHE TABLE AS SELECT", ctx) + } + CacheTableCommand(tableIdent, query, ctx.LAZY != null) } /** * Create an [[UncacheTableCommand]] logical plan. */ override def visitUncacheTable(ctx: UncacheTableContext): LogicalPlan = withOrigin(ctx) { - UncacheTableCommand(ctx.identifier.getText) + UncacheTableCommand(visitTableIdentifier(ctx.tableIdentifier)) } /** @@ -209,18 +245,22 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { /** * Create an [[ExplainCommand]] logical plan. + * The syntax of using this command in SQL is: + * {{{ + * EXPLAIN (EXTENDED | CODEGEN) SELECT * FROM ... + * }}} */ override def visitExplain(ctx: ExplainContext): LogicalPlan = withOrigin(ctx) { - val options = ctx.explainOption.asScala - if (options.exists(_.FORMATTED != null)) { - logWarning("Unsupported operation: EXPLAIN FORMATTED option") + if (ctx.FORMATTED != null) { + operationNotAllowed("EXPLAIN FORMATTED", ctx) + } + if (ctx.LOGICAL != null) { + operationNotAllowed("EXPLAIN LOGICAL", ctx) } - // Create the explain comment. val statement = plan(ctx.statement) if (isExplainableStatement(statement)) { - ExplainCommand(statement, extended = options.exists(_.EXTENDED != null), - codegen = options.exists(_.CODEGEN != null)) + ExplainCommand(statement, extended = ctx.EXTENDED != null, codegen = ctx.CODEGEN != null) } else { ExplainCommand(OneRowRelation) } @@ -238,12 +278,26 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { * Create a [[DescribeTableCommand]] logical plan. */ override def visitDescribeTable(ctx: DescribeTableContext): LogicalPlan = withOrigin(ctx) { - // FORMATTED and columns are not supported. 
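
The SHOW COLUMNS change above (and the similar CACHE TABLE guard) only rejects an explicit database when it conflicts with one already embedded in the table identifier. A minimal sketch of that resolution rule is below; `TableIdent` and `resolveLookupTable` are hypothetical names used only for illustration.

```scala
final case class TableIdent(table: String, database: Option[String])

// An explicit database is fine unless it contradicts the identifier's own.
def resolveLookupTable(ident: TableIdent, explicitDb: Option[String]): TableIdent =
  explicitDb match {
    case None => ident
    case Some(db) if ident.database.exists(_ != db) =>
      throw new IllegalArgumentException(
        s"SHOW COLUMNS with conflicting databases: '$db' != '${ident.database.get}'")
    case Some(db) => TableIdent(ident.table, Some(db))
  }
```
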
Return null and let the parser decide what to do - // with this (create an exception or pass it on to a different system). - if (ctx.describeColName != null || ctx.FORMATTED != null || ctx.partitionSpec != null) { + // Describe column are not supported yet. Return null and let the parser decide + // what to do with this (create an exception or pass it on to a different system). + if (ctx.describeColName != null) { null } else { - DescribeTableCommand(visitTableIdentifier(ctx.tableIdentifier), ctx.EXTENDED != null) + val partitionSpec = if (ctx.partitionSpec != null) { + // According to the syntax, visitPartitionSpec returns `Map[String, Option[String]]`. + visitPartitionSpec(ctx.partitionSpec).map { + case (key, Some(value)) => key -> value + case (key, _) => + throw new ParseException(s"PARTITION specification is incomplete: `$key`", ctx) + } + } else { + Map.empty[String, String] + } + DescribeTableCommand( + visitTableIdentifier(ctx.tableIdentifier), + partitionSpec, + ctx.EXTENDED != null, + ctx.FORMATTED != null) } } @@ -259,9 +313,9 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { ctx: CreateTableHeaderContext): TableHeader = withOrigin(ctx) { val temporary = ctx.TEMPORARY != null val ifNotExists = ctx.EXISTS != null - assert(!temporary || !ifNotExists, - "a CREATE TEMPORARY TABLE statement does not allow IF NOT EXISTS clause.", - ctx) + if (temporary && ifNotExists) { + operationNotAllowed("CREATE TEMPORARY TABLE ... IF NOT EXISTS", ctx) + } (visitTableIdentifier(ctx.tableIdentifier), temporary, ifNotExists, ctx.EXTERNAL != null) } @@ -273,40 +327,63 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { override def visitCreateTableUsing(ctx: CreateTableUsingContext): LogicalPlan = withOrigin(ctx) { val (table, temp, ifNotExists, external) = visitCreateTableHeader(ctx.createTableHeader) if (external) { - throw new ParseException("Unsupported operation: EXTERNAL option", ctx) + operationNotAllowed("CREATE EXTERNAL TABLE ... USING", ctx) } - val options = Option(ctx.tablePropertyList).map(visitTablePropertyList).getOrElse(Map.empty) + val options = Option(ctx.tablePropertyList).map(visitPropertyKeyValues).getOrElse(Map.empty) val provider = ctx.tableProvider.qualifiedName.getText + val partitionColumnNames = + Option(ctx.partitionColumnNames) + .map(visitIdentifierList(_).toArray) + .getOrElse(Array.empty[String]) val bucketSpec = Option(ctx.bucketSpec()).map(visitBucketSpec) if (ctx.query != null) { // Get the backing query. val query = plan(ctx.query) + if (temp) { + operationNotAllowed("CREATE TEMPORARY TABLE ... USING ... AS query", ctx) + } + // Determine the storage mode. val mode = if (ifNotExists) { SaveMode.Ignore - } else if (temp) { - SaveMode.Overwrite } else { SaveMode.ErrorIfExists } - val partitionColumnNames = - Option(ctx.partitionColumnNames) - .map(visitIdentifierList(_).toArray) - .getOrElse(Array.empty[String]) - CreateTableUsingAsSelect( - table, provider, temp, partitionColumnNames, bucketSpec, mode, options, query) + table, provider, partitionColumnNames, bucketSpec, mode, options, query) } else { val struct = Option(ctx.colTypeList()).map(createStructType) - CreateTableUsing(table, struct, provider, temp, options, ifNotExists, managedIfNoPath = false) + CreateTableUsing( + table, + struct, + provider, + temp, + options, + partitionColumnNames, + bucketSpec, + ifNotExists, + managedIfNoPath = true) } } /** - * Create a [[LoadData]] command. + * Creates a [[CreateTempViewUsing]] logical plan. 
+ */ + override def visitCreateTempViewUsing( + ctx: CreateTempViewUsingContext): LogicalPlan = withOrigin(ctx) { + CreateTempViewUsing( + tableIdent = visitTableIdentifier(ctx.tableIdentifier()), + userSpecifiedSchema = Option(ctx.colTypeList()).map(createStructType), + replace = ctx.REPLACE != null, + provider = ctx.tableProvider.qualifiedName.getText, + options = Option(ctx.tablePropertyList).map(visitPropertyKeyValues).getOrElse(Map.empty)) + } + + /** + * Create a [[LoadDataCommand]] command. * * For example: * {{{ @@ -315,7 +392,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { * }}} */ override def visitLoadData(ctx: LoadDataContext): LogicalPlan = withOrigin(ctx) { - LoadData( + LoadDataCommand( table = visitTableIdentifier(ctx.tableIdentifier), path = string(ctx.path), isLocal = ctx.LOCAL != null, @@ -324,16 +401,74 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { ) } + /** + * Create a [[TruncateTableCommand]] command. + * + * For example: + * {{{ + * TRUNCATE TABLE tablename [PARTITION (partcol1=val1, partcol2=val2 ...)] + * }}} + */ + override def visitTruncateTable(ctx: TruncateTableContext): LogicalPlan = withOrigin(ctx) { + TruncateTableCommand( + visitTableIdentifier(ctx.tableIdentifier), + Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec)) + } + + /** + * Create a [[AlterTableRecoverPartitionsCommand]] command. + * + * For example: + * {{{ + * MSCK REPAIR TABLE tablename + * }}} + */ + override def visitRepairTable(ctx: RepairTableContext): LogicalPlan = withOrigin(ctx) { + AlterTableRecoverPartitionsCommand( + visitTableIdentifier(ctx.tableIdentifier), + "MSCK REPAIR TABLE") + } + /** * Convert a table property list into a key-value map. + * This should be called through [[visitPropertyKeyValues]] or [[visitPropertyKeys]]. */ override def visitTablePropertyList( ctx: TablePropertyListContext): Map[String, String] = withOrigin(ctx) { - ctx.tableProperty.asScala.map { property => + val properties = ctx.tableProperty.asScala.map { property => val key = visitTablePropertyKey(property.key) val value = Option(property.value).map(string).orNull key -> value - }.toMap + } + // Check for duplicate property names. + checkDuplicateKeys(properties, ctx) + properties.toMap + } + + /** + * Parse a key-value map from a [[TablePropertyListContext]], assuming all values are specified. + */ + private def visitPropertyKeyValues(ctx: TablePropertyListContext): Map[String, String] = { + val props = visitTablePropertyList(ctx) + val badKeys = props.filter { case (_, v) => v == null }.keys + if (badKeys.nonEmpty) { + operationNotAllowed( + s"Values must be specified for key(s): ${badKeys.mkString("[", ",", "]")}", ctx) + } + props + } + + /** + * Parse a list of keys from a [[TablePropertyListContext]], assuming no values are specified. + */ + private def visitPropertyKeys(ctx: TablePropertyListContext): Seq[String] = { + val props = visitTablePropertyList(ctx) + val badKeys = props.filter { case (_, v) => v != null }.keys + if (badKeys.nonEmpty) { + operationNotAllowed( + s"Values should not be specified for key(s): ${badKeys.mkString("[", ",", "]")}", ctx) + } + props.keys.toSeq } /** @@ -350,7 +485,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { } /** - * Create a [[CreateDatabase]] command. + * Create a [[CreateDatabaseCommand]] command. 
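
The parser helpers introduced above split property-list handling into a duplicate-key check plus two variants: `visitPropertyKeyValues`, which requires every key to carry a value, and `visitPropertyKeys`, which forbids values. A standalone sketch of that validation, using `Option[String]` in place of the parser's nullable values, is shown here; the function names are stand-ins.

```scala
// Reject duplicate keys up front, then build the map.
def parseProperties(pairs: Seq[(String, Option[String])]): Map[String, Option[String]] = {
  val dups = pairs.groupBy(_._1).collect { case (k, vs) if vs.size > 1 => k }
  require(dups.isEmpty, s"Found duplicate keys: ${dups.mkString(", ")}")
  pairs.toMap
}

// Variant used where every key must have a value (e.g. OPTIONS, TBLPROPERTIES).
def propertyKeyValues(pairs: Seq[(String, Option[String])]): Map[String, String] = {
  val props = parseProperties(pairs)
  val badKeys = props.collect { case (k, None) => k }
  require(badKeys.isEmpty,
    s"Values must be specified for key(s): ${badKeys.mkString("[", ",", "]")}")
  props.map { case (k, v) => k -> v.get }
}

// Variant used where only bare keys are allowed (e.g. UNSET TBLPROPERTIES).
def propertyKeys(pairs: Seq[(String, Option[String])]): Seq[String] = {
  val props = parseProperties(pairs)
  val badKeys = props.collect { case (k, Some(_)) => k }
  require(badKeys.isEmpty,
    s"Values should not be specified for key(s): ${badKeys.mkString("[", ",", "]")}")
  props.keys.toSeq
}
```
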
* * For example: * {{{ @@ -359,16 +494,16 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { * }}} */ override def visitCreateDatabase(ctx: CreateDatabaseContext): LogicalPlan = withOrigin(ctx) { - CreateDatabase( + CreateDatabaseCommand( ctx.identifier.getText, ctx.EXISTS != null, Option(ctx.locationSpec).map(visitLocationSpec), Option(ctx.comment).map(string), - Option(ctx.tablePropertyList).map(visitTablePropertyList).getOrElse(Map.empty)) + Option(ctx.tablePropertyList).map(visitPropertyKeyValues).getOrElse(Map.empty)) } /** - * Create an [[AlterDatabaseProperties]] command. + * Create an [[AlterDatabasePropertiesCommand]] command. * * For example: * {{{ @@ -377,13 +512,13 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { */ override def visitSetDatabaseProperties( ctx: SetDatabasePropertiesContext): LogicalPlan = withOrigin(ctx) { - AlterDatabaseProperties( + AlterDatabasePropertiesCommand( ctx.identifier.getText, - visitTablePropertyList(ctx.tablePropertyList)) + visitPropertyKeyValues(ctx.tablePropertyList)) } /** - * Create a [[DropDatabase]] command. + * Create a [[DropDatabaseCommand]] command. * * For example: * {{{ @@ -391,11 +526,11 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { * }}} */ override def visitDropDatabase(ctx: DropDatabaseContext): LogicalPlan = withOrigin(ctx) { - DropDatabase(ctx.identifier.getText, ctx.EXISTS != null, ctx.CASCADE != null) + DropDatabaseCommand(ctx.identifier.getText, ctx.EXISTS != null, ctx.CASCADE != null) } /** - * Create a [[DescribeDatabase]] command. + * Create a [[DescribeDatabaseCommand]] command. * * For example: * {{{ @@ -403,11 +538,51 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { * }}} */ override def visitDescribeDatabase(ctx: DescribeDatabaseContext): LogicalPlan = withOrigin(ctx) { - DescribeDatabase(ctx.identifier.getText, ctx.EXTENDED != null) + DescribeDatabaseCommand(ctx.identifier.getText, ctx.EXTENDED != null) } /** - * Create a [[CreateFunction]] command. + * Create a plan for a DESCRIBE FUNCTION command. + */ + override def visitDescribeFunction(ctx: DescribeFunctionContext): LogicalPlan = withOrigin(ctx) { + import ctx._ + val functionName = + if (describeFuncName.STRING() != null) { + FunctionIdentifier(string(describeFuncName.STRING()), database = None) + } else if (describeFuncName.qualifiedName() != null) { + visitFunctionName(describeFuncName.qualifiedName) + } else { + FunctionIdentifier(describeFuncName.getText, database = None) + } + DescribeFunctionCommand(functionName, EXTENDED != null) + } + + /** + * Create a plan for a SHOW FUNCTIONS command. + */ + override def visitShowFunctions(ctx: ShowFunctionsContext): LogicalPlan = withOrigin(ctx) { + import ctx._ + val (user, system) = Option(ctx.identifier).map(_.getText.toLowerCase) match { + case None | Some("all") => (true, true) + case Some("system") => (false, true) + case Some("user") => (true, false) + case Some(x) => throw new ParseException(s"SHOW $x FUNCTIONS not supported", ctx) + } + + val (db, pat) = if (qualifiedName != null) { + val name = visitFunctionName(qualifiedName) + (name.database, Some(name.funcName)) + } else if (pattern != null) { + (None, Some(string(pattern))) + } else { + (None, None) + } + + ShowFunctionsCommand(db, pat, user, system) + } + + /** + * Create a [[CreateFunctionCommand]] command. 
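
`visitShowFunctions` above maps the optional scope identifier onto a pair of flags selecting user functions, system functions, or both. That small piece of logic, isolated from the ANTLR context objects, can be sketched as follows; the function name is illustrative.

```scala
// Returns (showUserFunctions, showSystemFunctions).
def functionScope(identifier: Option[String]): (Boolean, Boolean) =
  identifier.map(_.toLowerCase) match {
    case None | Some("all") => (true, true)
    case Some("system")     => (false, true)
    case Some("user")       => (true, false)
    case Some(other) =>
      throw new IllegalArgumentException(s"SHOW $other FUNCTIONS not supported")
  }
```
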
* * For example: * {{{ @@ -420,15 +595,15 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { val resourceType = resource.identifier.getText.toLowerCase resourceType match { case "jar" | "file" | "archive" => - resourceType -> string(resource.STRING) + FunctionResource(FunctionResourceType.fromString(resourceType), string(resource.STRING)) case other => - throw new ParseException(s"Resource Type '$resourceType' is not supported.", ctx) + operationNotAllowed(s"CREATE FUNCTION with resource type '$resourceType'", ctx) } } // Extract database, name & alias. val functionIdentifier = visitFunctionName(ctx.qualifiedName) - CreateFunction( + CreateFunctionCommand( functionIdentifier.database, functionIdentifier.funcName, string(ctx.className), @@ -437,7 +612,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { } /** - * Create a [[DropFunction]] command. + * Create a [[DropFunctionCommand]] command. * * For example: * {{{ @@ -446,7 +621,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { */ override def visitDropFunction(ctx: DropFunctionContext): LogicalPlan = withOrigin(ctx) { val functionIdentifier = visitFunctionName(ctx.qualifiedName) - DropFunction( + DropFunctionCommand( functionIdentifier.database, functionIdentifier.funcName, ctx.EXISTS != null, @@ -454,23 +629,20 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { } /** - * Create a [[DropTable]] command. + * Create a [[DropTableCommand]] command. */ override def visitDropTable(ctx: DropTableContext): LogicalPlan = withOrigin(ctx) { if (ctx.PURGE != null) { - throw new ParseException("Unsupported operation: PURGE option", ctx) - } - if (ctx.REPLICATION != null) { - throw new ParseException("Unsupported operation: REPLICATION clause", ctx) + operationNotAllowed("DROP TABLE ... PURGE", ctx) } - DropTable( + DropTableCommand( visitTableIdentifier(ctx.tableIdentifier), ctx.EXISTS != null, ctx.VIEW != null) } /** - * Create a [[AlterTableRename]] command. + * Create a [[AlterTableRenameCommand]] command. * * For example: * {{{ @@ -479,14 +651,14 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { * }}} */ override def visitRenameTable(ctx: RenameTableContext): LogicalPlan = withOrigin(ctx) { - AlterTableRename( + AlterTableRenameCommand( visitTableIdentifier(ctx.from), visitTableIdentifier(ctx.to), ctx.VIEW != null) } /** - * Create an [[AlterTableSetProperties]] command. + * Create an [[AlterTableSetPropertiesCommand]] command. * * For example: * {{{ @@ -496,14 +668,14 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { */ override def visitSetTableProperties( ctx: SetTablePropertiesContext): LogicalPlan = withOrigin(ctx) { - AlterTableSetProperties( + AlterTableSetPropertiesCommand( visitTableIdentifier(ctx.tableIdentifier), - visitTablePropertyList(ctx.tablePropertyList), + visitPropertyKeyValues(ctx.tablePropertyList), ctx.VIEW != null) } /** - * Create an [[AlterTableUnsetProperties]] command. + * Create an [[AlterTableUnsetPropertiesCommand]] command. 
* * For example: * {{{ @@ -513,15 +685,15 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { */ override def visitUnsetTableProperties( ctx: UnsetTablePropertiesContext): LogicalPlan = withOrigin(ctx) { - AlterTableUnsetProperties( + AlterTableUnsetPropertiesCommand( visitTableIdentifier(ctx.tableIdentifier), - visitTablePropertyList(ctx.tablePropertyList).keys.toSeq, + visitPropertyKeys(ctx.tablePropertyList), ctx.EXISTS != null, ctx.VIEW != null) } /** - * Create an [[AlterTableSerDeProperties]] command. + * Create an [[AlterTableSerDePropertiesCommand]] command. * * For example: * {{{ @@ -530,16 +702,16 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { * }}} */ override def visitSetTableSerDe(ctx: SetTableSerDeContext): LogicalPlan = withOrigin(ctx) { - AlterTableSerDeProperties( + AlterTableSerDePropertiesCommand( visitTableIdentifier(ctx.tableIdentifier), Option(ctx.STRING).map(string), - Option(ctx.tablePropertyList).map(visitTablePropertyList), + Option(ctx.tablePropertyList).map(visitPropertyKeyValues), // TODO a partition spec is allowed to have optional values. This is currently violated. Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec)) } /** - * Create an [[AlterTableAddPartition]] command. + * Create an [[AlterTableAddPartitionCommand]] command. * * For example: * {{{ @@ -553,7 +725,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { override def visitAddTablePartition( ctx: AddTablePartitionContext): LogicalPlan = withOrigin(ctx) { if (ctx.VIEW != null) { - throw new ParseException(s"Operation not allowed: partitioned views", ctx) + operationNotAllowed("ALTER VIEW ... ADD PARTITION", ctx) } // Create partition spec to location mapping. val specsAndLocs = if (ctx.partitionSpec.isEmpty) { @@ -567,14 +739,14 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { // Alter View: the location clauses are not allowed. ctx.partitionSpec.asScala.map(visitNonOptionalPartitionSpec(_) -> None) } - AlterTableAddPartition( + AlterTableAddPartitionCommand( visitTableIdentifier(ctx.tableIdentifier), specsAndLocs, ctx.EXISTS != null) } /** - * Create an [[AlterTableRenamePartition]] command + * Create an [[AlterTableRenamePartitionCommand]] command * * For example: * {{{ @@ -583,14 +755,14 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { */ override def visitRenameTablePartition( ctx: RenameTablePartitionContext): LogicalPlan = withOrigin(ctx) { - AlterTableRenamePartition( + AlterTableRenamePartitionCommand( visitTableIdentifier(ctx.tableIdentifier), visitNonOptionalPartitionSpec(ctx.from), visitNonOptionalPartitionSpec(ctx.to)) } /** - * Create an [[AlterTableDropPartition]] command + * Create an [[AlterTableDropPartitionCommand]] command * * For example: * {{{ @@ -604,48 +776,32 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { override def visitDropTablePartitions( ctx: DropTablePartitionsContext): LogicalPlan = withOrigin(ctx) { if (ctx.VIEW != null) { - throw new ParseException(s"Operation not allowed: partitioned views", ctx) + operationNotAllowed("ALTER VIEW ... DROP PARTITION", ctx) } if (ctx.PURGE != null) { - throw new ParseException(s"Operation not allowed: PURGE", ctx) + operationNotAllowed("ALTER TABLE ... DROP PARTITION ... 
PURGE", ctx) } - AlterTableDropPartition( + AlterTableDropPartitionCommand( visitTableIdentifier(ctx.tableIdentifier), ctx.partitionSpec.asScala.map(visitNonOptionalPartitionSpec), ctx.EXISTS != null) } /** - * Create an [[AlterTableSetFileFormat]] command + * Create an [[AlterTableRecoverPartitionsCommand]] command * * For example: * {{{ - * ALTER TABLE table [PARTITION spec] SET FILEFORMAT file_format; + * ALTER TABLE table RECOVER PARTITIONS; * }}} */ - override def visitSetTableFileFormat( - ctx: SetTableFileFormatContext): LogicalPlan = withOrigin(ctx) { - // AlterTableSetFileFormat currently takes both a GenericFileFormat and a - // TableFileFormatContext. This is a bit weird because it should only take one. It also should - // use a CatalogFileFormat instead of either a String or a Sequence of Strings. We will address - // this in a follow-up PR. - val (fileFormat, genericFormat) = ctx.fileFormat match { - case s: GenericFileFormatContext => - (Seq.empty[String], Option(s.identifier.getText)) - case s: TableFileFormatContext => - val elements = Seq(s.inFmt, s.outFmt) ++ Option(s.serdeCls).toSeq - (elements.map(string), None) - } - AlterTableSetFileFormat( - visitTableIdentifier(ctx.tableIdentifier), - Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec), - fileFormat, - genericFormat)( - parseException("ALTER TABLE SET FILEFORMAT", ctx)) + override def visitRecoverPartitions( + ctx: RecoverPartitionsContext): LogicalPlan = withOrigin(ctx) { + AlterTableRecoverPartitionsCommand(visitTableIdentifier(ctx.tableIdentifier)) } /** - * Create an [[AlterTableSetLocation]] command + * Create an [[AlterTableSetLocationCommand]] command * * For example: * {{{ @@ -653,85 +809,12 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { * }}} */ override def visitSetTableLocation(ctx: SetTableLocationContext): LogicalPlan = withOrigin(ctx) { - AlterTableSetLocation( + AlterTableSetLocationCommand( visitTableIdentifier(ctx.tableIdentifier), Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec), visitLocationSpec(ctx.locationSpec)) } - /** - * Create an [[AlterTableChangeCol]] command - * - * For example: - * {{{ - * ALTER TABLE tableIdentifier [PARTITION spec] - * CHANGE [COLUMN] col_old_name col_new_name column_type [COMMENT col_comment] - * [FIRST|AFTER column_name] [CASCADE|RESTRICT]; - * }}} - */ - override def visitChangeColumn(ctx: ChangeColumnContext): LogicalPlan = withOrigin(ctx) { - val col = visitColType(ctx.colType()) - val comment = if (col.metadata.contains("comment")) { - Option(col.metadata.getString("comment")) - } else { - None - } - - AlterTableChangeCol( - visitTableIdentifier(ctx.tableIdentifier), - Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec), - ctx.oldName.getText, - // We could also pass in a struct field - seems easier. - col.name, - col.dataType, - comment, - Option(ctx.after).map(_.getText), - // Note that Restrict and Cascade are mutually exclusive. - ctx.RESTRICT != null, - ctx.CASCADE != null)( - parseException("ALTER TABLE CHANGE COLUMN", ctx)) - } - - /** - * Create an [[AlterTableAddCol]] command - * - * For example: - * {{{ - * ALTER TABLE tableIdentifier [PARTITION spec] - * ADD COLUMNS (name type [COMMENT comment], ...) 
[CASCADE|RESTRICT] - * }}} - */ - override def visitAddColumns(ctx: AddColumnsContext): LogicalPlan = withOrigin(ctx) { - AlterTableAddCol( - visitTableIdentifier(ctx.tableIdentifier), - Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec), - createStructType(ctx.colTypeList), - // Note that Restrict and Cascade are mutually exclusive. - ctx.RESTRICT != null, - ctx.CASCADE != null)( - parseException("ALTER TABLE ADD COLUMNS", ctx)) - } - - /** - * Create an [[AlterTableReplaceCol]] command - * - * For example: - * {{{ - * ALTER TABLE tableIdentifier [PARTITION spec] - * REPLACE COLUMNS (name type [COMMENT comment], ...) [CASCADE|RESTRICT] - * }}} - */ - override def visitReplaceColumns(ctx: ReplaceColumnsContext): LogicalPlan = withOrigin(ctx) { - AlterTableReplaceCol( - visitTableIdentifier(ctx.tableIdentifier), - Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec), - createStructType(ctx.colTypeList), - // Note that Restrict and Cascade are mutually exclusive. - ctx.RESTRICT != null, - ctx.CASCADE != null)( - parseException("ALTER TABLE REPLACE COLUMNS", ctx)) - } - /** * Create location string. */ @@ -752,7 +835,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { .map { orderedIdCtx => Option(orderedIdCtx.ordering).map(_.getText).foreach { dir => if (dir.toLowerCase != "asc") { - throw parseException("Only ASC ordering is supported for sorting columns", ctx) + operationNotAllowed(s"Column ordering must be ASC, was '$dir'", ctx) } } @@ -780,29 +863,58 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { */ override def visitFailNativeCommand( ctx: FailNativeCommandContext): LogicalPlan = withOrigin(ctx) { - val keywords = if (ctx.kws != null) { - Seq(ctx.kws.kw1, ctx.kws.kw2, ctx.kws.kw3, ctx.kws.kw4, ctx.kws.kw5, ctx.kws.kw6) - .filter(_ != null).map(_.getText).mkString(" ") + val keywords = if (ctx.unsupportedHiveNativeCommands != null) { + ctx.unsupportedHiveNativeCommands.children.asScala.collect { + case n: TerminalNode => n.getText + }.mkString(" ") } else { // SET ROLE is the exception to the rule, because we handle this before other SET commands. "SET ROLE" } - throw parseException(keywords, ctx) + operationNotAllowed(keywords, ctx) } /** - * Create an [[AddJar]] or [[AddFile]] command depending on the requested resource. + * Create a [[AddFileCommand]], [[AddJarCommand]], [[ListFilesCommand]] or [[ListJarsCommand]] + * command depending on the requested operation on resources. + * Expected format: + * {{{ + * ADD (FILE[S] | JAR[S] ) + * LIST (FILE[S] [filepath ...] 
| JAR[S] [jarpath ...]) + * }}} */ - override def visitAddResource(ctx: AddResourceContext): LogicalPlan = withOrigin(ctx) { - ctx.identifier.getText.toLowerCase match { - case "file" => AddFile(remainder(ctx.identifier).trim) - case "jar" => AddJar(remainder(ctx.identifier).trim) - case other => throw new ParseException(s"Unsupported resource type '$other'.", ctx) + override def visitManageResource(ctx: ManageResourceContext): LogicalPlan = withOrigin(ctx) { + val mayebePaths = remainder(ctx.identifier).trim + ctx.op.getType match { + case SqlBaseParser.ADD => + ctx.identifier.getText.toLowerCase match { + case "file" => AddFileCommand(mayebePaths) + case "jar" => AddJarCommand(mayebePaths) + case other => operationNotAllowed(s"ADD with resource type '$other'", ctx) + } + case SqlBaseParser.LIST => + ctx.identifier.getText.toLowerCase match { + case "files" | "file" => + if (mayebePaths.length > 0) { + ListFilesCommand(mayebePaths.split("\\s+")) + } else { + ListFilesCommand() + } + case "jars" | "jar" => + if (mayebePaths.length > 0) { + ListJarsCommand(mayebePaths.split("\\s+")) + } else { + ListJarsCommand() + } + case other => operationNotAllowed(s"LIST with resource type '$other'", ctx) + } + case _ => operationNotAllowed(s"Other types of operation on resources", ctx) } } /** - * Create a table, returning either a [[CreateTable]] or a [[CreateTableAsSelectLogicalPlan]]. + * Create a table, returning either a [[CreateTableCommand]] or a + * [[CreateHiveTableAsSelectLogicalPlan]]. * * This is not used to create datasource tables, which is handled through * "CREATE TABLE ... USING ...". @@ -812,14 +924,12 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { * * Expected format: * {{{ - * CREATE [TEMPORARY] [EXTERNAL] TABLE [IF NOT EXISTS] [db_name.]table_name - * [(col1 data_type [COMMENT col_comment], ...)] + * CREATE [EXTERNAL] TABLE [IF NOT EXISTS] [db_name.]table_name + * [(col1[:] data_type [COMMENT col_comment], ...)] * [COMMENT table_comment] - * [PARTITIONED BY (col3 data_type [COMMENT col_comment], ...)] - * [CLUSTERED BY (col1, ...) [SORTED BY (col1 [ASC|DESC], ...)] INTO num_buckets BUCKETS] - * [SKEWED BY (col1, col2, ...) ON ((col_value, col_value, ...), ...) [STORED AS DIRECTORIES]] + * [PARTITIONED BY (col2[:] data_type [COMMENT col_comment], ...)] * [ROW FORMAT row_format] - * [STORED AS file_format | STORED BY storage_handler_class [WITH SERDEPROPERTIES (...)]] + * [STORED AS file_format] * [LOCATION path] * [TBLPROPERTIES (property_name=property_value, ...)] * [AS select_statement]; @@ -831,25 +941,37 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { if (temp) { throw new ParseException( "CREATE TEMPORARY TABLE is not supported yet. " + - "Please use registerTempTable as an alternative.", ctx) + "Please use CREATE TEMPORARY VIEW as an alternative.", ctx) } if (ctx.skewSpec != null) { - throw new ParseException("Operation not allowed: CREATE TABLE ... SKEWED BY ...", ctx) + operationNotAllowed("CREATE TABLE ... SKEWED BY", ctx) } if (ctx.bucketSpec != null) { - throw new ParseException("Operation not allowed: CREATE TABLE ... CLUSTERED BY ...", ctx) - } - val tableType = if (external) { - CatalogTableType.EXTERNAL - } else { - CatalogTableType.MANAGED + operationNotAllowed("CREATE TABLE ... 
CLUSTERED BY", ctx) } val comment = Option(ctx.STRING).map(string) val partitionCols = Option(ctx.partitionColumns).toSeq.flatMap(visitCatalogColumns) val cols = Option(ctx.columns).toSeq.flatMap(visitCatalogColumns) - val properties = Option(ctx.tablePropertyList).map(visitTablePropertyList).getOrElse(Map.empty) + val properties = Option(ctx.tablePropertyList).map(visitPropertyKeyValues).getOrElse(Map.empty) val selectQuery = Option(ctx.query).map(plan) + // Ensuring whether no duplicate name is used in table definition + val colNames = cols.map(_.name) + if (colNames.length != colNames.distinct.length) { + val duplicateColumns = colNames.groupBy(identity).collect { + case (x, ys) if ys.length > 1 => "\"" + x + "\"" + } + operationNotAllowed(s"Duplicated column names found in table definition of $name: " + + duplicateColumns.mkString("[", ",", "]"), ctx) + } + + // For Hive tables, partition columns must not be part of the schema + val badPartCols = partitionCols.map(_.name).toSet.intersect(colNames.toSet) + if (badPartCols.nonEmpty) { + operationNotAllowed(s"Partition columns may not be specified in the schema: " + + badPartCols.map("\"" + _ + "\"").mkString("[", ",", "]"), ctx) + } + // Note: Hive requires partition columns to be distinct from the schema, so we need // to include the partition columns here explicitly val schema = cols ++ partitionCols @@ -867,18 +989,33 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { // Note: Keep this unspecified because we use the presence of the serde to decide // whether to convert a table created by CTAS to a datasource table. serde = None, + compressed = false, serdeProperties = Map()) } + validateRowFormatFileFormat(ctx.rowFormat, ctx.createFileFormat, ctx) val fileStorage = Option(ctx.createFileFormat).map(visitCreateFileFormat) - .getOrElse(EmptyStorageFormat) - val rowStorage = Option(ctx.rowFormat).map(visitRowFormat).getOrElse(EmptyStorageFormat) + .getOrElse(CatalogStorageFormat.empty) + val rowStorage = Option(ctx.rowFormat).map(visitRowFormat) + .getOrElse(CatalogStorageFormat.empty) val location = Option(ctx.locationSpec).map(visitLocationSpec) + // If we are creating an EXTERNAL table, then the LOCATION field is required + if (external && location.isEmpty) { + operationNotAllowed("CREATE EXTERNAL TABLE must be accompanied by LOCATION", ctx) + } val storage = CatalogStorageFormat( locationUri = location, inputFormat = fileStorage.inputFormat.orElse(defaultStorage.inputFormat), outputFormat = fileStorage.outputFormat.orElse(defaultStorage.outputFormat), serde = rowStorage.serde.orElse(fileStorage.serde).orElse(defaultStorage.serde), + compressed = false, serdeProperties = rowStorage.serdeProperties ++ fileStorage.serdeProperties) + // If location is defined, we'll assume this is an external table. + // Otherwise, we may accidentally delete existing data. + val tableType = if (external || location.isDefined) { + CatalogTableType.EXTERNAL + } else { + CatalogTableType.MANAGED + } // TODO support the sql text - have a proper location for this! val tableDesc = CatalogTable( @@ -891,13 +1028,51 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { comment = comment) selectQuery match { - case Some(q) => CreateTableAsSelectLogicalPlan(tableDesc, q, ifNotExists) - case None => CreateTable(tableDesc, ifNotExists) + case Some(q) => + // Hive does not allow to use a CTAS statement to create a partitioned table. 
+ if (tableDesc.partitionColumnNames.nonEmpty) { + val errorMessage = "A Create Table As Select (CTAS) statement is not allowed to " + + "create a partitioned table using Hive's file formats. " + + "Please use the syntax of \"CREATE TABLE tableName USING dataSource " + + "OPTIONS (...) PARTITIONED BY ...\" to create a partitioned table through a " + + "CTAS statement." + operationNotAllowed(errorMessage, ctx) + } + // Just use whatever is projected in the select statement as our schema + if (schema.nonEmpty) { + operationNotAllowed( + "Schema may not be specified in a Create Table As Select (CTAS) statement", + ctx) + } + + val hasStorageProperties = (ctx.createFileFormat != null) || (ctx.rowFormat != null) + if (conf.convertCTAS && !hasStorageProperties) { + val mode = if (ifNotExists) SaveMode.Ignore else SaveMode.ErrorIfExists + // At here, both rowStorage.serdeProperties and fileStorage.serdeProperties + // are empty Maps. + val optionsWithPath = if (location.isDefined) { + Map("path" -> location.get) + } else { + Map.empty[String, String] + } + CreateTableUsingAsSelect( + tableIdent = tableDesc.identifier, + provider = conf.defaultDataSourceName, + partitionColumns = tableDesc.partitionColumnNames.toArray, + bucketSpec = None, + mode = mode, + options = optionsWithPath, + q + ) + } else { + CreateHiveTableAsSelectLogicalPlan(tableDesc, q, ifNotExists) + } + case None => CreateTableCommand(tableDesc, ifNotExists) } } /** - * Create a [[CreateTableLike]] command. + * Create a [[CreateTableLikeCommand]] command. * * For example: * {{{ @@ -908,11 +1083,13 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { override def visitCreateTableLike(ctx: CreateTableLikeContext): LogicalPlan = withOrigin(ctx) { val targetTable = visitTableIdentifier(ctx.target) val sourceTable = visitTableIdentifier(ctx.source) - CreateTableLike(targetTable, sourceTable, ctx.EXISTS != null) + CreateTableLikeCommand(targetTable, sourceTable, ctx.EXISTS != null) } /** * Create a [[CatalogStorageFormat]] for creating tables. + * + * Format: STORED AS ... */ override def visitCreateFileFormat( ctx: CreateFileFormatContext): CatalogStorageFormat = withOrigin(ctx) { @@ -924,25 +1101,20 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { case (c: GenericFileFormatContext, null) => visitGenericFileFormat(c) case (null, storageHandler) => - throw new ParseException("Operation not allowed: ... STORED BY storage_handler ...", ctx) + operationNotAllowed("STORED BY", ctx) case _ => - throw new ParseException("expected either STORED AS or STORED BY, not both", ctx) + throw new ParseException("Expected either STORED AS or STORED BY, not both", ctx) } } - /** Empty storage format for default values and copies. */ - private val EmptyStorageFormat = CatalogStorageFormat(None, None, None, None, Map.empty) - /** * Create a [[CatalogStorageFormat]]. 
*/ override def visitTableFileFormat( ctx: TableFileFormatContext): CatalogStorageFormat = withOrigin(ctx) { - EmptyStorageFormat.copy( + CatalogStorageFormat.empty.copy( inputFormat = Option(string(ctx.inFmt)), - outputFormat = Option(string(ctx.outFmt)), - serde = Option(ctx.serdeCls).map(string) - ) + outputFormat = Option(string(ctx.outFmt))) } /** @@ -953,12 +1125,12 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { val source = ctx.identifier.getText HiveSerDe.sourceToSerDe(source, conf) match { case Some(s) => - EmptyStorageFormat.copy( + CatalogStorageFormat.empty.copy( inputFormat = s.inputFormat, outputFormat = s.outputFormat, serde = s.serde) case None => - throw new ParseException(s"Unrecognized file format in STORED AS clause: $source", ctx) + operationNotAllowed(s"STORED AS with file format '$source'", ctx) } } @@ -993,9 +1165,9 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { override def visitRowFormatSerde( ctx: RowFormatSerdeContext): CatalogStorageFormat = withOrigin(ctx) { import ctx._ - EmptyStorageFormat.copy( + CatalogStorageFormat.empty.copy( serde = Option(string(name)), - serdeProperties = Option(tablePropertyList).map(visitTablePropertyList).getOrElse(Map.empty)) + serdeProperties = Option(tablePropertyList).map(visitPropertyKeyValues).getOrElse(Map.empty)) } /** @@ -1012,17 +1184,61 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { entry("field.delim", ctx.fieldsTerminatedBy) ++ entry("serialization.format", ctx.fieldsTerminatedBy) ++ entry("escape.delim", ctx.escapedBy) ++ + // The following typo is inherited from Hive... entry("colelction.delim", ctx.collectionItemsTerminatedBy) ++ entry("mapkey.delim", ctx.keysTerminatedBy) ++ Option(ctx.linesSeparatedBy).toSeq.map { token => val value = string(token) - assert( + validate( value == "\n", s"LINES TERMINATED BY only supports newline '\\n' right now: $value", ctx) "line.delim" -> value } - EmptyStorageFormat.copy(serdeProperties = entries.toMap) + CatalogStorageFormat.empty.copy(serdeProperties = entries.toMap) + } + + /** + * Throw a [[ParseException]] if the user specified incompatible SerDes through ROW FORMAT + * and STORED AS. + * + * The following are allowed. Anything else is not: + * ROW FORMAT SERDE ... STORED AS [SEQUENCEFILE | RCFILE | TEXTFILE] + * ROW FORMAT DELIMITED ... STORED AS TEXTFILE + * ROW FORMAT ... STORED AS INPUTFORMAT ... OUTPUTFORMAT ... 
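 * As a hedged illustration only (the table and column names below are arbitrary), these rules
 * accept
 * {{{
 *   CREATE TABLE t (id INT) ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' STORED AS TEXTFILE
 * }}}
 * but reject the same statement with STORED AS PARQUET, because a DELIMITED row format is only
 * compatible with TEXTFILE.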
+ */ + private def validateRowFormatFileFormat( + rowFormatCtx: RowFormatContext, + createFileFormatCtx: CreateFileFormatContext, + parentCtx: ParserRuleContext): Unit = { + if (rowFormatCtx == null || createFileFormatCtx == null) { + return + } + (rowFormatCtx, createFileFormatCtx.fileFormat) match { + case (_, ffTable: TableFileFormatContext) => // OK + case (rfSerde: RowFormatSerdeContext, ffGeneric: GenericFileFormatContext) => + ffGeneric.identifier.getText.toLowerCase match { + case ("sequencefile" | "textfile" | "rcfile") => // OK + case fmt => + operationNotAllowed( + s"ROW FORMAT SERDE is incompatible with format '$fmt', which also specifies a serde", + parentCtx) + } + case (rfDelimited: RowFormatDelimitedContext, ffGeneric: GenericFileFormatContext) => + ffGeneric.identifier.getText.toLowerCase match { + case "textfile" => // OK + case fmt => operationNotAllowed( + s"ROW FORMAT DELIMITED is only compatible with 'textfile', not '$fmt'", parentCtx) + } + case _ => + // should never happen + def str(ctx: ParserRuleContext): String = { + (0 until ctx.getChildCount).map { i => ctx.getChild(i).getText }.mkString(" ") + } + operationNotAllowed( + s"Unexpected combination of ${str(rowFormatCtx)} and ${str(createFileFormatCtx)}", + parentCtx) + } } /** @@ -1030,7 +1246,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { * * For example: * {{{ - * CREATE VIEW [IF NOT EXISTS] [db_name.]view_name + * CREATE [OR REPLACE] [TEMPORARY] VIEW [IF NOT EXISTS] [db_name.]view_name * [(column_name [COMMENT column_comment], ...) ] * [COMMENT view_comment] * [TBLPROPERTIES (property_name = property_value, ...)] @@ -1039,63 +1255,50 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { */ override def visitCreateView(ctx: CreateViewContext): LogicalPlan = withOrigin(ctx) { if (ctx.identifierList != null) { - throw new ParseException(s"Operation not allowed: partitioned views", ctx) + operationNotAllowed("CREATE VIEW ... PARTITIONED ON", ctx) } else { val identifiers = Option(ctx.identifierCommentList).toSeq.flatMap(_.identifierComment.asScala) val schema = identifiers.map { ic => CatalogColumn(ic.identifier.getText, null, nullable = true, Option(ic.STRING).map(string)) } - createView( - ctx, - ctx.tableIdentifier, - comment = Option(ctx.STRING).map(string), - schema, - ctx.query, - Option(ctx.tablePropertyList).map(visitTablePropertyList).getOrElse(Map.empty), - ctx.EXISTS != null, - ctx.REPLACE != null - ) + + val sql = Option(source(ctx.query)) + val tableDesc = CatalogTable( + identifier = visitTableIdentifier(ctx.tableIdentifier), + tableType = CatalogTableType.VIEW, + schema = schema, + storage = CatalogStorageFormat.empty, + properties = Option(ctx.tablePropertyList).map(visitPropertyKeyValues).getOrElse(Map.empty), + viewOriginalText = sql, + viewText = sql, + comment = Option(ctx.STRING).map(string)) + + CreateViewCommand( + tableDesc, + plan(ctx.query), + allowExisting = ctx.EXISTS != null, + replace = ctx.REPLACE != null, + isTemporary = ctx.TEMPORARY != null) } } /** - * Alter the query of a view. This creates a [[CreateViewCommand]] command. + * Alter the query of a view. This creates a [[AlterViewAsCommand]] command. 
+ * + * For example: + * {{{ + * ALTER VIEW [db_name.]view_name AS SELECT ...; + * }}} */ override def visitAlterViewQuery(ctx: AlterViewQueryContext): LogicalPlan = withOrigin(ctx) { - createView( - ctx, - ctx.tableIdentifier, - comment = None, - Seq.empty, - ctx.query, - Map.empty, - allowExist = false, - replace = true) - } - - /** - * Create a [[CreateViewCommand]] command. - */ - private def createView( - ctx: ParserRuleContext, - name: TableIdentifierContext, - comment: Option[String], - schema: Seq[CatalogColumn], - query: QueryContext, - properties: Map[String, String], - allowExist: Boolean, - replace: Boolean): LogicalPlan = { - val sql = Option(source(query)) val tableDesc = CatalogTable( - identifier = visitTableIdentifier(name), + identifier = visitTableIdentifier(ctx.tableIdentifier), tableType = CatalogTableType.VIEW, - schema = schema, - storage = EmptyStorageFormat, - properties = properties, - viewOriginalText = sql, - viewText = sql, - comment = comment) - CreateViewCommand(tableDesc, plan(query), allowExist, replace, command(ctx)) + storage = CatalogStorageFormat.empty, + schema = Nil, + viewOriginalText = Option(source(ctx.query))) + + AlterViewAsCommand(tableDesc, plan(ctx.query)) } /** @@ -1109,7 +1312,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { // just convert the whole type string to lower case, otherwise the struct field names // will no longer be case sensitive. Instead, we rely on our parser to get the proper // case before passing it to Hive. - CatalystSqlParser.parseDataType(col.dataType.getText).catalogString, + typedVisit[DataType](col.dataType).catalogString, nullable = true, Option(col.STRING).map(string)) } @@ -1126,13 +1329,17 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { recordReader: Token, schemaLess: Boolean): ScriptInputOutputSchema = { if (recordWriter != null || recordReader != null) { + // TODO: what does this message mean? throw new ParseException( "Unsupported operation: Used defined record reader/writer classes.", ctx) } // Decode and input/output format. type Format = (Seq[(String, String)], Option[String], Seq[(String, String)], Option[String]) - def format(fmt: RowFormatContext, configKey: String): Format = fmt match { + def format( + fmt: RowFormatContext, + configKey: String, + defaultConfigValue: String): Format = fmt match { case c: RowFormatDelimitedContext => // TODO we should use the visitRowFormatDelimited function here. However HiveScriptIOSchema // expects a seq of pairs in which the old parsers' token names are used as keys. @@ -1151,11 +1358,11 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { case c: RowFormatSerdeContext => // Use a serde format. 
- val CatalogStorageFormat(None, None, None, Some(name), props) = visitRowFormatSerde(c) + val CatalogStorageFormat(None, None, None, Some(name), _, props) = visitRowFormatSerde(c) // SPARK-10310: Special cases LazySimpleSerDe val recordHandler = if (name == "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe") { - Try(conf.getConfString(configKey)).toOption + Option(conf.getConfString(configKey, defaultConfigValue)) } else { None } @@ -1166,15 +1373,18 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { val name = conf.getConfString("hive.script.serde", "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe") val props = Seq("field.delim" -> "\t") - val recordHandler = Try(conf.getConfString(configKey)).toOption + val recordHandler = Option(conf.getConfString(configKey, defaultConfigValue)) (Nil, Option(name), props, recordHandler) } val (inFormat, inSerdeClass, inSerdeProps, reader) = - format(inRowFormat, "hive.script.recordreader") + format( + inRowFormat, "hive.script.recordreader", "org.apache.hadoop.hive.ql.exec.TextRecordReader") val (outFormat, outSerdeClass, outSerdeProps, writer) = - format(outRowFormat, "hive.script.recordwriter") + format( + outRowFormat, "hive.script.recordwriter", + "org.apache.hadoop.hive.ql.exec.TextRecordWriter") ScriptInputOutputSchema( inFormat, outFormat, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala deleted file mode 100644 index c590f7c6c3e8b..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala +++ /dev/null @@ -1,108 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution - -import java.nio.ByteBuffer -import java.util.{HashMap => JavaHashMap} - -import scala.reflect.ClassTag - -import com.esotericsoftware.kryo.{Kryo, Serializer} -import com.esotericsoftware.kryo.io.{Input, Output} -import com.twitter.chill.ResourcePool - -import org.apache.spark.{SparkConf, SparkEnv} -import org.apache.spark.network.util.JavaUtils -import org.apache.spark.serializer.{KryoSerializer, SerializerInstance} -import org.apache.spark.sql.types.Decimal -import org.apache.spark.util.MutablePair - -private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(conf) { - override def newKryo(): Kryo = { - val kryo = super.newKryo() - kryo.setRegistrationRequired(false) - kryo.register(classOf[MutablePair[_, _]]) - kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericRow]) - kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericInternalRow]) - kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericMutableRow]) - kryo.register(classOf[java.math.BigDecimal], new JavaBigDecimalSerializer) - kryo.register(classOf[BigDecimal], new ScalaBigDecimalSerializer) - - kryo.register(classOf[Decimal]) - kryo.register(classOf[JavaHashMap[_, _]]) - - kryo.setReferences(false) - kryo - } -} - -private[execution] class KryoResourcePool(size: Int) - extends ResourcePool[SerializerInstance](size) { - - val ser: SparkSqlSerializer = { - val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf()) - new SparkSqlSerializer(sparkConf) - } - - def newInstance(): SerializerInstance = ser.newInstance() -} - -private[sql] object SparkSqlSerializer { - @transient lazy val resourcePool = new KryoResourcePool(30) - - private[this] def acquireRelease[O](fn: SerializerInstance => O): O = { - val kryo = resourcePool.borrow - try { - fn(kryo) - } finally { - resourcePool.release(kryo) - } - } - - def serialize[T: ClassTag](o: T): Array[Byte] = - acquireRelease { k => - JavaUtils.bufferToArray(k.serialize(o)) - } - - def deserialize[T: ClassTag](bytes: Array[Byte]): T = - acquireRelease { k => - k.deserialize[T](ByteBuffer.wrap(bytes)) - } -} - -private[sql] class JavaBigDecimalSerializer extends Serializer[java.math.BigDecimal] { - def write(kryo: Kryo, output: Output, bd: java.math.BigDecimal) { - // TODO: There are probably more efficient representations than strings... - output.writeString(bd.toString) - } - - def read(kryo: Kryo, input: Input, tpe: Class[java.math.BigDecimal]): java.math.BigDecimal = { - new java.math.BigDecimal(input.readString()) - } -} - -private[sql] class ScalaBigDecimalSerializer extends Serializer[BigDecimal] { - def write(kryo: Kryo, output: Output, bd: BigDecimal) { - // TODO: There are probably more efficient representations than strings... 
- output.writeString(bd.toString) - } - - def read(kryo: Kryo, input: Input, tpe: Class[BigDecimal]): BigDecimal = { - new java.math.BigDecimal(input.readString()) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 238334e26b45c..ccfb95454b08c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution +import org.apache.spark.rdd.RDD import org.apache.spark.sql.{AnalysisException, Strategy} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.RowEncoder @@ -31,10 +32,31 @@ import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.exchange.ShuffleExchange import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight} -import org.apache.spark.sql.execution.streaming.MemoryPlan +import org.apache.spark.sql.execution.streaming.{MemoryPlan, StreamingExecutionRelation, StreamingRelation, StreamingRelationExec} import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.streaming.StreamingQuery -private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { +/** + * Converts a logical plan into zero or more SparkPlans. This API is exposed for experimenting + * with the query planner and is not designed to be stable across spark releases. Developers + * writing libraries should instead consider using the stable APIs provided in + * [[org.apache.spark.sql.sources]] + */ +abstract class SparkStrategy extends GenericStrategy[SparkPlan] { + + override protected def planLater(plan: LogicalPlan): SparkPlan = PlanLater(plan) +} + +case class PlanLater(plan: LogicalPlan) extends LeafExecNode { + + override def output: Seq[Attribute] = plan.output + + protected override def doExecute(): RDD[InternalRow] = { + throw new UnsupportedOperationException() + } +} + +abstract class SparkStrategies extends QueryPlanner[SparkPlan] { self: SparkPlanner => /** @@ -44,22 +66,22 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.ReturnAnswer(rootPlan) => rootPlan match { case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) => - execution.TakeOrderedAndProjectExec(limit, order, None, planLater(child)) :: Nil + execution.TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil case logical.Limit( IntegerLiteral(limit), logical.Project(projectList, logical.Sort(order, true, child))) => execution.TakeOrderedAndProjectExec( - limit, order, Some(projectList), planLater(child)) :: Nil + limit, order, projectList, planLater(child)) :: Nil case logical.Limit(IntegerLiteral(limit), child) => execution.CollectLimitExec(limit, planLater(child)) :: Nil case other => planLater(other) :: Nil } case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) => - execution.TakeOrderedAndProjectExec(limit, order, None, planLater(child)) :: Nil + execution.TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil case logical.Limit( IntegerLiteral(limit), logical.Project(projectList, logical.Sort(order, true, child))) => execution.TakeOrderedAndProjectExec( - limit, order, Some(projectList), 
planLater(child)) :: Nil + limit, order, projectList, planLater(child)) :: Nil case _ => Nil } } @@ -92,7 +114,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * Matches a plan whose output should be small enough to be used in broadcast join. */ private def canBroadcast(plan: LogicalPlan): Boolean = { - plan.statistics.sizeInBytes <= conf.autoBroadcastJoinThreshold + plan.statistics.isBroadcastable || + plan.statistics.sizeInBytes <= conf.autoBroadcastJoinThreshold } /** @@ -118,6 +141,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { private def canBuildRight(joinType: JoinType): Boolean = joinType match { case Inner | LeftOuter | LeftSemi | LeftAnti => true + case j: ExistenceJoin => true case _ => false } @@ -188,7 +212,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } // This join could be very slow or OOM joins.BroadcastNestedLoopJoinExec( - planLater(left), planLater(right), buildSide, joinType, condition) :: Nil + planLater(left), planLater(right), buildSide, joinType, condition, + withinBroadcastThreshold = false) :: Nil // --- Cases where this strategy does not apply --------------------------------------------- @@ -198,7 +223,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { /** * Used to plan aggregation queries that are computed incrementally as part of a - * [[org.apache.spark.sql.ContinuousQuery]]. Currently this rule is injected into the planner + * [[StreamingQuery]]. Currently this rule is injected into the planner * on-demand, only when planning in a [[org.apache.spark.sql.execution.streaming.StreamExecution]] */ object StatefulAggregationStrategy extends Strategy { @@ -206,7 +231,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case PhysicalAggregation( namedGroupingExpressions, aggregateExpressions, rewrittenResultExpressions, child) => - aggregate.Utils.planStreamingAggregation( + aggregate.AggUtils.planStreamingAggregation( namedGroupingExpressions, aggregateExpressions, rewrittenResultExpressions, @@ -239,20 +264,20 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { sys.error("Distinct columns cannot exist in Aggregate operator containing " + "aggregate functions which don't support partial aggregation.") } else { - aggregate.Utils.planAggregateWithoutPartial( + aggregate.AggUtils.planAggregateWithoutPartial( groupingExpressions, aggregateExpressions, resultExpressions, planLater(child)) } } else if (functionsWithDistinct.isEmpty) { - aggregate.Utils.planAggregateWithoutDistinct( + aggregate.AggUtils.planAggregateWithoutDistinct( groupingExpressions, aggregateExpressions, resultExpressions, planLater(child)) } else { - aggregate.Utils.planAggregateWithOneDistinct( + aggregate.AggUtils.planAggregateWithOneDistinct( groupingExpressions, functionsWithDistinct, functionsWithoutDistinct, @@ -280,6 +305,22 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } + /** + * This strategy is just for explaining `Dataset/DataFrame` created by `spark.readStream`. + * It won't affect the execution, because `StreamingRelation` will be replaced with + * `StreamingExecutionRelation` in `StreamingQueryManager` and `StreamingExecutionRelation` will + * be replaced with the real relation using the `Source` in `StreamExecution`. 
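 * A hedged sketch of what this enables (the socket source and its options are just an
 * illustrative choice, not mandated by this rule):
 * {{{
 *   val streamingDF = spark.readStream
 *     .format("socket").option("host", "localhost").option("port", "9999").load()
 *   streamingDF.explain()  // rendered via StreamingRelationExec instead of failing to plan
 * }}}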
+ */ + object StreamingRelationStrategy extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case s: StreamingRelation => + StreamingRelationExec(s.sourceName, s.output) :: Nil + case s: StreamingExecutionRelation => + StreamingRelationExec(s.toString, s.output) :: Nil + case _ => Nil + } + } + // Can we automate these 'pass through' operations? object BasicOperators extends Strategy { def numPartitions: Int = self.numPartitions @@ -302,7 +343,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { "logical except operator should have been replaced by anti-join in the optimizer") case logical.DeserializeToObject(deserializer, objAttr, child) => - execution.DeserializeToObject(deserializer, objAttr, planLater(child)) :: Nil + execution.DeserializeToObjectExec(deserializer, objAttr, planLater(child)) :: Nil case logical.SerializeFromObject(serializer, child) => execution.SerializeFromObjectExec(serializer, planLater(child)) :: Nil case logical.MapPartitions(f, objAttr, child) => @@ -310,6 +351,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.MapPartitionsInR(f, p, b, is, os, objAttr, child) => execution.MapPartitionsExec( execution.r.MapPartitionsRWrapper(f, p, b, is, os), objAttr, planLater(child)) :: Nil + case logical.FlatMapGroupsInR(f, p, b, is, os, key, value, grouping, data, objAttr, child) => + execution.FlatMapGroupsInRExec(f, p, b, is, os, key, value, grouping, + data, objAttr, planLater(child)) :: Nil case logical.MapElements(f, objAttr, child) => execution.MapElementsExec(f, objAttr, planLater(child)) :: Nil case logical.AppendColumns(f, in, out, child) => @@ -358,8 +402,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { generator, join = join, outer = outer, g.output, planLater(child)) :: Nil case logical.OneRowRelation => execution.RDDScanExec(Nil, singleRowRdd, "OneRowRelation") :: Nil - case r @ logical.Range(start, end, step, numSlices, output) => - execution.RangeExec(start, step, numSlices, r.numElements, output) :: Nil + case r : logical.Range => + execution.RangeExec(r) :: Nil case logical.RepartitionByExpression(expressions, child, nPartitions) => exchange.ShuffleExchange(HashPartitioning( expressions, nPartitions.getOrElse(numPartitions)), planLater(child)) :: Nil @@ -371,10 +415,13 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object DDLStrategy extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case CreateTableUsing(tableIdent, userSpecifiedSchema, provider, true, opts, false, _) => + case c: CreateTableUsing if c.temporary && !c.allowExisting => + logWarning( + s"CREATE TEMPORARY TABLE ${c.tableIdent.identifier} USING... is deprecated, " + + s"please use CREATE TEMPORARY VIEW viewName USING... 
instead") ExecutedCommandExec( - CreateTempTableUsing( - tableIdent, userSpecifiedSchema, provider, opts)) :: Nil + CreateTempViewUsing( + c.tableIdent, c.userSpecifiedSchema, replace = true, c.provider, c.options)) :: Nil case c: CreateTableUsing if !c.temporary => val cmd = @@ -383,6 +430,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { c.userSpecifiedSchema, c.provider, c.options, + c.partitionColumns, + c.bucketSpec, c.allowExisting, c.managedIfNoPath) ExecutedCommandExec(cmd) :: Nil @@ -391,15 +440,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { throw new AnalysisException( "allowExisting should be set to false when creating a temporary table.") - case c: CreateTableUsingAsSelect if c.temporary && c.partitionColumns.nonEmpty => - sys.error("Cannot create temporary partitioned table.") - - case c: CreateTableUsingAsSelect if c.temporary => - val cmd = CreateTempTableUsingAsSelect( - c.tableIdent, c.provider, Array.empty[String], c.mode, c.options, c.child) - ExecutedCommandExec(cmd) :: Nil - - case c: CreateTableUsingAsSelect if !c.temporary => + case c: CreateTableUsingAsSelect => val cmd = CreateDataSourceTableAsSelectCommand( c.tableIdent, @@ -408,15 +449,11 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { c.bucketSpec, c.mode, c.options, - c.child) + c.query) ExecutedCommandExec(cmd) :: Nil - case logical.ShowFunctions(db, pattern) => - ExecutedCommandExec(ShowFunctions(db, pattern)) :: Nil - - case logical.DescribeFunction(function, extended) => - ExecutedCommandExec(DescribeFunction(function, extended)) :: Nil - + case c: CreateTempViewUsing => + ExecutedCommandExec(c) :: Nil case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala index 484923428f4ad..8ab553369de6d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala @@ -40,12 +40,12 @@ import org.apache.spark.unsafe.Platform * * @param numFields the number of fields in the row being serialized. 
*/ -private[sql] class UnsafeRowSerializer( +class UnsafeRowSerializer( numFields: Int, dataSize: SQLMetric = null) extends Serializer with Serializable { override def newInstance(): SerializerInstance = new UnsafeRowSerializerInstance(numFields, dataSize) - override private[spark] def supportsRelocationOfSerializedObjects: Boolean = true + override def supportsRelocationOfSerializedObjects: Boolean = true } private class UnsafeRowSerializerInstance( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 15b4abe806678..fb57ed7692de4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -24,12 +24,12 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.util.toCommentSafeString -import org.apache.spark.sql.execution.aggregate.TungstenAggregate +import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils /** * An interface for those physical operators that support codegen. @@ -38,7 +38,7 @@ trait CodegenSupport extends SparkPlan { /** Prefix used in the current operator's variable names. */ private def variablePrefix: String = this match { - case _: TungstenAggregate => "agg" + case _: HashAggregateExec => "agg" case _: BroadcastHashJoinExec => "bhj" case _: SortMergeJoinExec => "smj" case _: RDDScanExec => "rdd" @@ -79,7 +79,7 @@ trait CodegenSupport extends SparkPlan { this.parent = parent ctx.freshNamePrefix = variablePrefix s""" - |/*** PRODUCE: ${toCommentSafeString(this.simpleString)} */ + |${ctx.registerComment(s"PRODUCE: ${this.simpleString}")} |${doProduce(ctx)} """.stripMargin } @@ -105,7 +105,7 @@ trait CodegenSupport extends SparkPlan { protected def doProduce(ctx: CodegenContext): String /** - * Consume the generated columns or row from current SparkPlan, call it's parent's doConsume(). + * Consume the generated columns or row from current SparkPlan, call its parent's `doConsume()`. */ final def consume(ctx: CodegenContext, outputVars: Seq[ExprCode], row: String = null): String = { val inputVars = @@ -131,6 +131,7 @@ trait CodegenSupport extends SparkPlan { } val evaluateInputs = evaluateVariables(outputVars) // generate the code to create a UnsafeRow + ctx.INPUT_ROW = row ctx.currentVars = outputVars val ev = GenerateUnsafeProjection.createCode(ctx, colExprs, false) val code = s""" @@ -147,8 +148,7 @@ trait CodegenSupport extends SparkPlan { ctx.freshNamePrefix = parent.variablePrefix val evaluated = evaluateRequiredVariables(output, inputVars, parent.usedInputs) s""" - | - |/*** CONSUME: ${toCommentSafeString(parent.simpleString)} */ + |${ctx.registerComment(s"CONSUME: ${parent.simpleString}")} |$evaluated |${parent.doConsume(ctx, inputVars, rowVar)} """.stripMargin @@ -212,8 +212,8 @@ trait CodegenSupport extends SparkPlan { /** * InputAdapter is used to hide a SparkPlan from a subtree that support codegen. 
* - * This is the leaf node of a tree with WholeStageCodegen, is used to generate code that consumes - * an RDD iterator of InternalRow. + * This is the leaf node of a tree with WholeStageCodegen that is used to generate code + * that consumes an RDD iterator of InternalRow. */ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with CodegenSupport { @@ -247,9 +247,14 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with CodegenSupp """.stripMargin } - override def simpleString: String = "INPUT" - - override def treeChildren: Seq[SparkPlan] = Nil + override def generateTreeString( + depth: Int, + lastChildren: Seq[Boolean], + builder: StringBuilder, + verbose: Boolean, + prefix: String = ""): StringBuilder = { + child.generateTreeString(depth, lastChildren, builder, verbose, "") + } } object WholeStageCodegenExec { @@ -290,7 +295,7 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co override def outputPartitioning: Partitioning = child.outputPartitioning override def outputOrdering: Seq[SortOrder] = child.outputOrdering - override private[sql] lazy val metrics = Map( + override lazy val metrics = Map( "pipelineTime" -> SQLMetrics.createTimingMetric(sparkContext, WholeStageCodegenExec.PIPELINE_DURATION_METRIC)) @@ -299,7 +304,7 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co * * @return the tuple of the codegen context and the actual generated source. */ - def doCodeGen(): (CodegenContext, String) = { + def doCodeGen(): (CodegenContext, CodeAndComment) = { val ctx = new CodegenContext val code = child.asInstanceOf[CodegenSupport].produce(ctx, this) val source = s""" @@ -307,9 +312,7 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co return new GeneratedIterator(references); } - /** Codegened pipeline for: - * ${toCommentSafeString(child.treeString.trim)} - */ + ${ctx.registerComment(s"""Codegend pipeline for\n${child.treeString.trim}""")} final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { private Object[] references; @@ -333,14 +336,24 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co """.trim // try to compile, helpful for debug - val cleanedSource = CodeFormatter.stripExtraNewLines(source) + val cleanedSource = CodeFormatter.stripOverlappingComments( + new CodeAndComment(CodeFormatter.stripExtraNewLines(source), ctx.getPlaceHolderToComments())) + logDebug(s"\n${CodeFormatter.format(cleanedSource)}") - CodeGenerator.compile(cleanedSource) (ctx, cleanedSource) } override def doExecute(): RDD[InternalRow] = { val (ctx, cleanedSource) = doCodeGen() + // try to compile and fallback if it failed + try { + CodeGenerator.compile(cleanedSource) + } catch { + case e: Exception if !Utils.isTesting && sqlContext.conf.wholeStageFallback => + // We should already saw the error message + logWarning(s"Whole-stage codegen disabled for this plan:\n $treeString") + return child.execute() + } val references = ctx.references.toArray val durationMs = longMetric("pipelineTime") @@ -400,20 +413,14 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co """.stripMargin.trim } - override def innerChildren: Seq[SparkPlan] = { - child :: Nil + override def generateTreeString( + depth: Int, + lastChildren: Seq[Boolean], + builder: StringBuilder, + verbose: Boolean, + prefix: String = ""): StringBuilder = { + child.generateTreeString(depth, lastChildren, builder, verbose, "*") } - - 
private def collectInputs(plan: SparkPlan): Seq[SparkPlan] = plan match { - case InputAdapter(c) => c :: Nil - case other => other.children.flatMap(collectInputs) - } - - override def treeChildren: Seq[SparkPlan] = { - collectInputs(child) - } - - override def simpleString: String = "WholeStageCodegen" } @@ -450,7 +457,7 @@ case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] { } /** - * Inserts a InputAdapter on top of those that do not support codegen. + * Inserts an InputAdapter on top of those that do not support codegen. */ private def insertInputAdapter(plan: SparkPlan): SparkPlan = plan match { case j @ SortMergeJoinExec(_, _, _, _, left, right) if j.supportCodegen => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala index 97bbab65af1de..3927a5011ba4f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala @@ -345,6 +345,8 @@ case class WindowExec( null, 1024, SparkEnv.get.memoryManager.pageSizeBytes, + SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", + UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD), false) rows.foreach { r => sorter.insertRecord(r.getBaseObject, r.getBaseOffset, r.getSizeInBytes, 0) @@ -580,25 +582,43 @@ private[execution] final class OffsetWindowFunctionFrame( /** Row used to combine the offset and the current row. */ private[this] val join = new JoinedRow - /** Create the projection. */ + /** + * Create the projection used when the offset row exists. + * Please note that this projection always respects null input values (like PostgreSQL). + */ private[this] val projection = { // Collect the expressions and bind them. val inputAttrs = inputSchema.map(_.withNullability(true)) - val numInputAttributes = inputAttrs.size val boundExpressions = Seq.fill(ordinal)(NoOp) ++ expressions.toSeq.map { case e: OffsetWindowFunction => val input = BindReferences.bindReference(e.input, inputAttrs) + input + case e => + BindReferences.bindReference(e, inputAttrs) + } + + // Create the projection. + newMutableProjection(boundExpressions, Nil).target(target) + } + + /** Create the projection used when the offset row DOES NOT exist. */ + private[this] val fillDefaultValue = { + // Collect the expressions and bind them. + val inputAttrs = inputSchema.map(_.withNullability(true)) + val numInputAttributes = inputAttrs.size + val boundExpressions = Seq.fill(ordinal)(NoOp) ++ expressions.toSeq.map { + case e: OffsetWindowFunction => if (e.default == null || e.default.foldable && e.default.eval() == null) { - // Without default value. - input + // The default value is null. + Literal.create(null, e.dataType) } else { - // With default value. + // The default value is an expression. val default = BindReferences.bindReference(e.default, inputAttrs).transform { // Shift the input reference to its default version.
case BoundReference(o, dataType, nullable) => BoundReference(o + numInputAttributes, dataType, nullable) } - org.apache.spark.sql.catalyst.expressions.Coalesce(input :: default :: Nil) + default } case e => BindReferences.bindReference(e, inputAttrs) @@ -623,10 +643,12 @@ private[execution] final class OffsetWindowFunctionFrame( if (inputIndex >= 0 && inputIndex < input.size) { val r = input.next() join(r, current) + projection(join) } else { join(emptyRow, current) + // Use default values since the offset row does not exist. + fillDefaultValue(join) } - projection(join) inputIndex += 1 } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala similarity index 97% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index f93c446007422..4fbb9d554c9bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.execution.streaming.{StateStoreRestoreExec, StateSto /** * Utility functions used by the query planner to convert our plan to new aggregation code path. */ -object Utils { +object AggUtils { def planAggregateWithoutPartial( groupingExpressions: Seq[NamedExpression], @@ -35,7 +35,7 @@ object Utils { val completeAggregateExpressions = aggregateExpressions.map(_.copy(mode = Complete)) val completeAggregateAttributes = completeAggregateExpressions.map(_.resultAttribute) - SortBasedAggregateExec( + SortAggregateExec( requiredChildDistributionExpressions = Some(groupingExpressions), groupingExpressions = groupingExpressions, aggregateExpressions = completeAggregateExpressions, @@ -54,10 +54,10 @@ object Utils { initialInputBufferOffset: Int = 0, resultExpressions: Seq[NamedExpression] = Nil, child: SparkPlan): SparkPlan = { - val usesTungstenAggregate = TungstenAggregate.supportsAggregate( + val useHash = HashAggregateExec.supportsAggregate( aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)) - if (usesTungstenAggregate) { - TungstenAggregate( + if (useHash) { + HashAggregateExec( requiredChildDistributionExpressions = requiredChildDistributionExpressions, groupingExpressions = groupingExpressions, aggregateExpressions = aggregateExpressions, @@ -66,7 +66,7 @@ object Utils { resultExpressions = resultExpressions, child = child) } else { - SortBasedAggregateExec( + SortAggregateExec( requiredChildDistributionExpressions = requiredChildDistributionExpressions, groupingExpressions = groupingExpressions, aggregateExpressions = aggregateExpressions, @@ -82,7 +82,7 @@ object Utils { aggregateExpressions: Seq[AggregateExpression], resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { - // Check if we can use TungstenAggregate. + // Check if we can use HashAggregate. // 1. Create an Aggregate Operator for partial aggregations. @@ -311,8 +311,10 @@ object Utils { aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), child = restored) } - - val saved = StateStoreSaveExec(groupingAttributes, None, partialMerged2) + // Note: stateId and returnAllStates are filled in later with preparation rules + // in IncrementalExecution. 
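To make the intent of the two projections above concrete, here is a small usage sketch (illustrative only, with made-up data, not taken from this patch): with the separate `fillDefaultValue` path, `lag`/`lead` return the user-supplied default directly when the offset row does not exist, rather than coalescing it with a null input value.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, lag}

object LagDefaultSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("lag-default").getOrCreate()
    import spark.implicits._

    val sales = Seq(("a", 1, 10), ("a", 2, 20), ("b", 1, 5)).toDF("key", "seq", "amount")
    val w = Window.partitionBy(col("key")).orderBy(col("seq"))

    // For the first row of each partition there is no preceding row, so the default (-1) is
    // produced by the "offset row does not exist" projection instead of a coalesced null.
    sales.withColumn("prev_amount", lag(col("amount"), 1, -1).over(w)).show()

    spark.stop()
  }
}
```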
+ val saved = StateStoreSaveExec( + groupingAttributes, stateId = None, returnAllStates = None, partialMerged2) val finalAndCompleteAggregate: SparkPlan = { val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala index 81aacb437ba54..6ca36e4acb448 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala @@ -52,7 +52,7 @@ abstract class AggregationIterator( * - PartialMerge (for single distinct) * - Partial and PartialMerge (for single distinct) * - Final - * - Complete (for SortBasedAggregate with functions that does not support Partial) + * - Complete (for SortAggregate with functions that does not support Partial) * - Final and Complete (currently not used) * * TODO: AggregateMode should have only two modes: Update and Merge, AggregateExpression @@ -73,9 +73,10 @@ abstract class AggregationIterator( startingInputBufferOffset: Int): Array[AggregateFunction] = { var mutableBufferOffset = 0 var inputBufferOffset: Int = startingInputBufferOffset - val functions = new Array[AggregateFunction](expressions.length) + val expressionsLength = expressions.length + val functions = new Array[AggregateFunction](expressionsLength) var i = 0 - while (i < expressions.length) { + while (i < expressionsLength) { val func = expressions(i).aggregateFunction val funcWithBoundReferences: AggregateFunction = expressions(i).mode match { case Partial | Complete if func.isInstanceOf[ImperativeAggregate] => @@ -171,7 +172,7 @@ abstract class AggregationIterator( case PartialMerge | Final => (buffer: MutableRow, row: InternalRow) => ae.merge(buffer, row) } - } + }.toArray // This projection is used to merge buffer values for all expression-based aggregates. val aggregationBufferSchema = functions.flatMap(_.aggBufferAttributes) val updateProjection = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala similarity index 92% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index d0ba37ee1338b..bd7efa606e0ce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -29,8 +29,12 @@ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.types.{DecimalType, StringType, StructType} import org.apache.spark.unsafe.KVIterator +import org.apache.spark.util.Utils -case class TungstenAggregate( +/** + * Hash-based aggregate operator that can also fallback to sorting when data exceeds memory size. 
+ */ +case class HashAggregateExec( requiredChildDistributionExpressions: Option[Seq[Expression]], groupingExpressions: Seq[NamedExpression], aggregateExpressions: Seq[AggregateExpression], @@ -44,13 +48,13 @@ case class TungstenAggregate( aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) } - require(TungstenAggregate.supportsAggregate(aggregateBufferAttributes)) + require(HashAggregateExec.supportsAggregate(aggregateBufferAttributes)) - override lazy val allAttributes: Seq[Attribute] = + override lazy val allAttributes: AttributeSeq = child.output ++ aggregateBufferAttributes ++ aggregateAttributes ++ aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) - override private[sql] lazy val metrics = Map( + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), "peakMemory" -> SQLMetrics.createSizeMetric(sparkContext, "peak memory"), "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"), @@ -244,8 +248,12 @@ case class TungstenAggregate( } } ctx.currentVars = bufVars ++ input - // TODO: support subexpression elimination - val aggVals = updateExpr.map(BindReferences.bindReference(_, inputAttrs).genCode(ctx)) + val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttrs)) + val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) + val effectiveCodes = subExprs.codes.mkString("\n") + val aggVals = ctx.withSubExprEliminationExprs(subExprs.states) { + boundUpdateExpr.map(_.genCode(ctx)) + } // aggregate buffer should be updated atomic val updates = aggVals.zipWithIndex.map { case (ev, i) => s""" @@ -255,6 +263,9 @@ case class TungstenAggregate( } s""" | // do aggregate + | // common sub-expressions + | $effectiveCodes + | // evaluate aggregate function | ${evaluateVariables(aggVals)} | // update aggregation buffer | ${updates.mkString("\n").trim} @@ -448,7 +459,7 @@ case class TungstenAggregate( } /** - * Using the vectorized hash map in TungstenAggregate is currently supported for all primitive + * Using the vectorized hash map in HashAggregate is currently supported for all primitive * data types during partial aggregation. 
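For orientation only (a sketch under the assumption of default settings, not part of the change), a query whose aggregation buffer is fixed-width continues to plan through the hash-based operator renamed here, and with the new `simpleString` it appears as `HashAggregate` in `explain()` output:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{avg, count}

object HashAggregateSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("hash-agg").getOrCreate()

    val df = spark.range(0, 10000).selectExpr("id % 10 AS k", "id AS v")
    val agg = df.groupBy("k").agg(count("v"), avg("v"))

    // The physical plan should contain partial and final HashAggregate nodes around an exchange.
    agg.explain()
    agg.show()

    spark.stop()
  }
}
```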
However, we currently only enable the hash map for a * subset of cases that've been verified to show performance improvements on our benchmarks * subject to an internal conf that sets an upper limit on the maximum length of the aggregate @@ -650,8 +661,12 @@ case class TungstenAggregate( val updateRowInVectorizedHashMap: Option[String] = { if (isVectorizedHashMapEnabled) { ctx.INPUT_ROW = vectorizedRowBuffer - val vectorizedRowEvals = - updateExpr.map(BindReferences.bindReference(_, inputAttr).genCode(ctx)) + val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr)) + val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) + val effectiveCodes = subExprs.codes.mkString("\n") + val vectorizedRowEvals = ctx.withSubExprEliminationExprs(subExprs.states) { + boundUpdateExpr.map(_.genCode(ctx)) + } val updateVectorizedRow = vectorizedRowEvals.zipWithIndex.map { case (ev, i) => val dt = updateExpr(i).dataType ctx.updateColumn(vectorizedRowBuffer, dt, i, ev, updateExpr(i).nullable, @@ -659,6 +674,8 @@ case class TungstenAggregate( } Option( s""" + |// common sub-expressions + |$effectiveCodes |// evaluate aggregate function |${evaluateVariables(vectorizedRowEvals)} |// update vectorized row @@ -701,13 +718,19 @@ case class TungstenAggregate( val updateRowInUnsafeRowMap: String = { ctx.INPUT_ROW = unsafeRowBuffer - val unsafeRowBufferEvals = - updateExpr.map(BindReferences.bindReference(_, inputAttr).genCode(ctx)) + val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr)) + val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) + val effectiveCodes = subExprs.codes.mkString("\n") + val unsafeRowBufferEvals = ctx.withSubExprEliminationExprs(subExprs.states) { + boundUpdateExpr.map(_.genCode(ctx)) + } val updateUnsafeRowBuffer = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) => val dt = updateExpr(i).dataType ctx.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable) } s""" + |// common sub-expressions + |$effectiveCodes |// evaluate aggregate function |${evaluateVariables(unsafeRowBufferEvals)} |// update unsafe row buffer @@ -740,23 +763,31 @@ case class TungstenAggregate( """ } - override def simpleString: String = { + override def verboseString: String = toString(verbose = true) + + override def simpleString: String = toString(verbose = false) + + private def toString(verbose: Boolean): String = { val allAggregateExpressions = aggregateExpressions testFallbackStartsAt match { case None => - val keyString = groupingExpressions.mkString("[", ",", "]") - val functionString = allAggregateExpressions.mkString("[", ",", "]") - val outputString = output.mkString("[", ",", "]") - s"TungstenAggregate(key=$keyString, functions=$functionString, output=$outputString)" + val keyString = Utils.truncatedString(groupingExpressions, "[", ", ", "]") + val functionString = Utils.truncatedString(allAggregateExpressions, "[", ", ", "]") + val outputString = Utils.truncatedString(output, "[", ", ", "]") + if (verbose) { + s"HashAggregate(keys=$keyString, functions=$functionString, output=$outputString)" + } else { + s"HashAggregate(keys=$keyString, functions=$functionString)" + } case Some(fallbackStartsAt) => - s"TungstenAggregateWithControlledFallback $groupingExpressions " + + s"HashAggregateWithControlledFallback $groupingExpressions " + s"$allAggregateExpressions $resultExpressions fallbackStartsAt=$fallbackStartsAt" } } } -object TungstenAggregate { +object HashAggregateExec { def 
supportsAggregate(aggregateBufferAttributes: Seq[Attribute]): Boolean = { val aggregationBufferSchema = StructType.fromAttributes(aggregateBufferAttributes) UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala similarity index 81% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregateExec.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala index 2e74d59c5f5b6..7c41e5e4c28aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala @@ -25,8 +25,12 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, UnspecifiedDistribution} import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.util.Utils -case class SortBasedAggregateExec( +/** + * Sort-based aggregate operator. + */ +case class SortAggregateExec( requiredChildDistributionExpressions: Option[Seq[Expression]], groupingExpressions: Seq[NamedExpression], aggregateExpressions: Seq[AggregateExpression], @@ -45,15 +49,15 @@ case class SortBasedAggregateExec( AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++ AttributeSet(aggregateBufferAttributes) - override private[sql] lazy val metrics = Map( + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) override def requiredChildDistribution: List[Distribution] = { requiredChildDistributionExpressions match { - case Some(exprs) if exprs.length == 0 => AllTuples :: Nil - case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil + case Some(exprs) if exprs.isEmpty => AllTuples :: Nil + case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil case None => UnspecifiedDistribution :: Nil } } @@ -100,12 +104,20 @@ case class SortBasedAggregateExec( } } - override def simpleString: String = { + override def simpleString: String = toString(verbose = false) + + override def verboseString: String = toString(verbose = true) + + private def toString(verbose: Boolean): String = { val allAggregateExpressions = aggregateExpressions - val keyString = groupingExpressions.mkString("[", ",", "]") - val functionString = allAggregateExpressions.mkString("[", ",", "]") - val outputString = output.mkString("[", ",", "]") - s"SortBasedAggregate(key=$keyString, functions=$functionString, output=$outputString)" + val keyString = Utils.truncatedString(groupingExpressions, "[", ",", "]") + val functionString = Utils.truncatedString(allAggregateExpressions, "[", ",", "]") + val outputString = Utils.truncatedString(output, "[", ",", "]") + if (verbose) { + s"SortAggregate(key=$keyString, functions=$functionString, output=$outputString)" + } else { + s"SortAggregate(key=$keyString, functions=$functionString)" + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala index f392b135ce787..3f7f84988594a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala @@ -86,7 +86,7 @@ class SortBasedAggregationIterator( // The aggregation buffer used by the sort-based aggregation. private[this] val sortBasedAggregationBuffer: MutableRow = newBuffer - // An SafeProjection to turn UnsafeRow into GenericInternalRow, because UnsafeRow can't be + // A SafeProjection to turn UnsafeRow into GenericInternalRow, because UnsafeRow can't be // compared to MutableRow (aggregation buffer) directly. private[this] val safeProj: Projection = FromUnsafeProjection(valueAttributes.map(_.dataType)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 243aa15deba3d..4b8adf5230717 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -41,7 +41,7 @@ import org.apache.spark.unsafe.KVIterator * - Step 0: Do hash-based aggregation. * - Step 1: Sort all entries of the hash map based on values of grouping expressions and * spill them to disk. - * - Step 2: Create a external sorter based on the spilled sorted map entries and reset the map. + * - Step 2: Create an external sorter based on the spilled sorted map entries and reset the map. * - Step 3: Get a sorted [[KVIterator]] from the external sorter. * - Step 4: Repeat step 0 until no more input. * - Step 5: Initialize sort-based aggregation on the sorted iterator. @@ -434,12 +434,12 @@ class TungstenAggregationIterator( /////////////////////////////////////////////////////////////////////////// /** - * Generate a output row when there is no input and there is no grouping expression. + * Generate an output row when there is no input and there is no grouping expression. */ def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = { if (groupingExpressions.isEmpty) { sortBasedAggregationBuffer.copyFrom(initialAggregationBuffer) - // We create a output row and copy it. So, we can free the map. + // We create an output row and copy it. So, we can free the map. 
val resultCopy = generateOutput(UnsafeRow.createFromByteArray(0, 0), sortBasedAggregationBuffer).copy() hashMap.free() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index 535e64cb34442..8bdfa48a30c9f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedDe import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate +import org.apache.spark.sql.catalyst.expressions.objects.Invoke import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.types._ @@ -31,32 +32,10 @@ object TypedAggregateExpression { def apply[BUF : Encoder, OUT : Encoder]( aggregator: Aggregator[_, BUF, OUT]): TypedAggregateExpression = { val bufferEncoder = encoderFor[BUF] - // We will insert the deserializer and function call expression at the bottom of each serializer - // expression while executing `TypedAggregateExpression`, which means multiply serializer - // expressions will all evaluate the same sub-expression at bottom. To avoid the re-evaluating, - // here we always use one single serializer expression to serialize the buffer object into a - // single-field row, no matter whether the encoder is flat or not. We also need to update the - // deserializer to read in all fields from that single-field row. - // TODO: remove this trick after we have better integration of subexpression elimination and - // whole stage codegen. 
- val bufferSerializer = if (bufferEncoder.flat) { - bufferEncoder.namedExpressions.head - } else { - Alias(CreateStruct(bufferEncoder.serializer), "buffer")() - } - - val bufferDeserializer = if (bufferEncoder.flat) { - bufferEncoder.deserializer transformUp { - case b: BoundReference => bufferSerializer.toAttribute - } - } else { - bufferEncoder.deserializer transformUp { - case UnresolvedAttribute(nameParts) => - assert(nameParts.length == 1) - UnresolvedExtractValue(bufferSerializer.toAttribute, Literal(nameParts.head)) - case BoundReference(ordinal, dt, _) => GetStructField(bufferSerializer.toAttribute, ordinal) - } - } + val bufferSerializer = bufferEncoder.namedExpressions + val bufferDeserializer = UnresolvedDeserializer( + bufferEncoder.deserializer, + bufferSerializer.map(_.toAttribute)) val outputEncoder = encoderFor[OUT] val outputType = if (outputEncoder.flat) { @@ -82,7 +61,7 @@ object TypedAggregateExpression { case class TypedAggregateExpression( aggregator: Aggregator[Any, Any, Any], inputDeserializer: Option[Expression], - bufferSerializer: NamedExpression, + bufferSerializer: Seq[NamedExpression], bufferDeserializer: Expression, outputSerializer: Seq[Expression], outputExternalType: DataType, @@ -106,11 +85,11 @@ case class TypedAggregateExpression( private def bufferExternalType = bufferDeserializer.dataType override lazy val aggBufferAttributes: Seq[AttributeReference] = - bufferSerializer.toAttribute.asInstanceOf[AttributeReference] :: Nil + bufferSerializer.map(_.toAttribute.asInstanceOf[AttributeReference]) override lazy val initialValues: Seq[Expression] = { val zero = Literal.fromObject(aggregator.zero, bufferExternalType) - ReferenceToExpressions(bufferSerializer, zero :: Nil) :: Nil + bufferSerializer.map(ReferenceToExpressions(_, zero :: Nil)) } override lazy val updateExpressions: Seq[Expression] = { @@ -120,7 +99,7 @@ case class TypedAggregateExpression( bufferExternalType, bufferDeserializer :: inputDeserializer.get :: Nil) - ReferenceToExpressions(bufferSerializer, reduced :: Nil) :: Nil + bufferSerializer.map(ReferenceToExpressions(_, reduced :: Nil)) } override lazy val mergeExpressions: Seq[Expression] = { @@ -136,7 +115,7 @@ case class TypedAggregateExpression( bufferExternalType, leftBuffer :: rightBuffer :: Nil) - ReferenceToExpressions(bufferSerializer, merged :: Nil) :: Nil + bufferSerializer.map(ReferenceToExpressions(_, merged :: Nil)) } override lazy val evaluateExpression: Expression = { @@ -148,7 +127,12 @@ case class TypedAggregateExpression( dataType match { case s: StructType => - ReferenceToExpressions(CreateStruct(outputSerializer), resultObj :: Nil) + val objRef = outputSerializer.head.find(_.isInstanceOf[BoundReference]).get + val struct = If( + IsNull(objRef), + Literal.create(null, dataType), + CreateStruct(outputSerializer)) + ReferenceToExpressions(struct, resultObj :: Nil) case _ => assert(outputSerializer.length == 1) outputSerializer.head transform { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala index 61bd6eb3cde66..b4a9059299539 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.types._ /** * This is a helper class to generate an append-only vectorized hash map 
that can act as a 'cache' * for extremely fast key-value lookups while evaluating aggregates (and fall back to the - * `BytesToBytesMap` if a given key isn't found). This is 'codegened' in TungstenAggregate to speed + * `BytesToBytesMap` if a given key isn't found). This is 'codegened' in HashAggregate to speed * up aggregates w/ key. * * It is backed by a power-of-2-sized array for index lookups and a columnar batch that stores the @@ -127,7 +127,7 @@ class VectorizedHashMapGenerator( | public $generatedClassName() { | batch = org.apache.spark.sql.execution.vectorized.ColumnarBatch.allocate(schema, | org.apache.spark.memory.MemoryMode.ON_HEAP, capacity); - | // TODO: Possibly generate this projection in TungstenAggregate directly + | // TODO: Possibly generate this projection in HashAggregate directly | aggregateBufferBatch = org.apache.spark.sql.execution.vectorized.ColumnarBatch.allocate( | aggregateBufferSchema, org.apache.spark.memory.MemoryMode.ON_HEAP, capacity); | for (int i = 0 ; i < aggregateBufferBatch.numCols(); i++) { @@ -313,10 +313,12 @@ class VectorizedHashMapGenerator( def hashLong(l: String): String = s"long $result = $l;" def hashBytes(b: String): String = { val hash = ctx.freshName("hash") + val bytes = ctx.freshName("bytes") s""" |int $result = 0; - |for (int i = 0; i < $b.length; i++) { - | ${genComputeHash(ctx, s"$b[i]", ByteType, hash)} + |byte[] $bytes = $b; + |for (int i = 0; i < $bytes.length; i++) { + | ${genComputeHash(ctx, s"$bytes[i]", ByteType, hash)} | $result = ($result ^ (0x9e3779b9)) + $hash + ($result << 6) + ($result >>> 2); |} """.stripMargin diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index 4ceb710f4b2b4..586e1456ac69e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -202,9 +202,9 @@ sealed trait BufferSetterGetterUtils { } /** - * A Mutable [[Row]] representing an mutable aggregation buffer. + * A Mutable [[Row]] representing a mutable aggregation buffer. */ -private[sql] class MutableAggregationBufferImpl ( +private[aggregate] class MutableAggregationBufferImpl( schema: StructType, toCatalystConverters: Array[Any => Any], toScalaConverters: Array[Any => Any], @@ -266,7 +266,7 @@ private[sql] class MutableAggregationBufferImpl ( /** * A [[Row]] representing an immutable aggregation buffer. */ -private[sql] class InputAggregationBuffer private[sql] ( +private[aggregate] class InputAggregationBuffer( schema: StructType, toCatalystConverters: Array[Any => Any], toScalaConverters: Array[Any => Any], @@ -319,7 +319,7 @@ private[sql] class InputAggregationBuffer private[sql] ( * The internal wrapper used to hook a [[UserDefinedAggregateFunction]] `udaf` in the * internal aggregation code path. 
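As a usage sketch only (the aggregate function and data below are made up), a `UserDefinedAggregateFunction` like the following is the kind of object that ends up wrapped by the `ScalaUDAF` expression when it is used in an aggregation:

```scala
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types._

// A trivial UDAF that sums LongType values.
class LongSum extends UserDefinedAggregateFunction {
  override def inputSchema: StructType = StructType(StructField("value", LongType) :: Nil)
  override def bufferSchema: StructType = StructType(StructField("sum", LongType) :: Nil)
  override def dataType: DataType = LongType
  override def deterministic: Boolean = true
  override def initialize(buffer: MutableAggregationBuffer): Unit = buffer(0) = 0L
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit =
    if (!input.isNullAt(0)) buffer(0) = buffer.getLong(0) + input.getLong(0)
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit =
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
  override def evaluate(buffer: Row): Any = buffer.getLong(0)
}

object ScalaUdafSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("udaf").getOrCreate()
    val longSum = new LongSum
    spark.range(0, 100).selectExpr("id % 3 AS k", "id AS v")
      .groupBy("k")
      .agg(longSum(col("v")).as("sum_v"))
      .show()
    spark.stop()
  }
}
```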
*/ -private[sql] case class ScalaUDAF( +case class ScalaUDAF( children: Seq[Expression], udaf: UserDefinedAggregateFunction, mutableAggBufferOffset: Int = 0, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index d492fa7c412df..a544371ffee7a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExpressionCanonicalizer} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.types.LongType +import org.apache.spark.sql.types.{LongType, StructField, StructType} import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler} /** Physical plan for Project. */ @@ -102,7 +102,7 @@ case class FilterExec(condition: Expression, child: SparkPlan) } } - private[sql] override lazy val metrics = Map( + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) override def inputRDDs(): Seq[RDD[InternalRow]] = { @@ -142,9 +142,9 @@ case class FilterExec(condition: Expression, child: SparkPlan) // To generate the predicates we will follow this algorithm. // For each predicate that is not IsNotNull, we will generate them one by one loading attributes - // as necessary. For each of both attributes, if there is a IsNotNull predicate we will generate - // that check *before* the predicate. After all of these predicates, we will generate the - // remaining IsNotNull checks that were not part of other predicates. + // as necessary. For each of both attributes, if there is an IsNotNull predicate we will + // generate that check *before* the predicate. After all of these predicates, we will generate + // the remaining IsNotNull checks that were not part of other predicates. // This has the property of not doing redundant IsNotNull checks and taking better advantage of // short-circuiting, not loading attributes until they are needed. // This is very perf sensitive. @@ -228,7 +228,7 @@ case class SampleExec( child: SparkPlan) extends UnaryExecNode with CodegenSupport { override def output: Seq[Attribute] = child.output - private[sql] override lazy val metrics = Map( + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) protected override def doExecute(): RDD[InternalRow] = { @@ -260,6 +260,7 @@ case class SampleExec( if (withReplacement) { val samplerClass = classOf[PoissonSampler[UnsafeRow]].getName val initSampler = ctx.freshName("initSampler") + ctx.copyResult = true ctx.addMutableState(s"$samplerClass", sampler, s"$initSampler();") @@ -305,23 +306,19 @@ case class SampleExec( /** - * Physical plan for range (generating a range of 64 bit numbers. - * - * @param start first number in the range, inclusive. - * @param step size of the step increment. - * @param numSlices number of partitions. - * @param numElements total number of elements to output. - * @param output output attributes. + * Physical plan for range (generating a range of 64 bit numbers). 
*/ -case class RangeExec( - start: Long, - step: Long, - numSlices: Int, - numElements: BigInt, - output: Seq[Attribute]) +case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) extends LeafExecNode with CodegenSupport { - private[sql] override lazy val metrics = Map( + def start: Long = range.start + def step: Long = range.step + def numSlices: Int = range.numSlices.getOrElse(sparkContext.defaultParallelism) + def numElements: BigInt = range.numElements + + override val output: Seq[Attribute] = range.output + + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) // output attributes should not affect the results @@ -458,6 +455,8 @@ case class RangeExec( } } } + + override def simpleString: String = range.simpleString } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala index 5d4476989a369..470307bd940ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala @@ -33,9 +33,9 @@ private[columnar] class ColumnStatisticsSchema(a: Attribute) extends Serializabl } private[columnar] class PartitionStatistics(tableSchema: Seq[Attribute]) extends Serializable { - val (forAttribute, schema) = { + val (forAttribute: AttributeMap[ColumnStatisticsSchema], schema: Seq[AttributeReference]) = { val allStats = tableSchema.map(a => a -> new ColumnStatisticsSchema(a)) - (AttributeMap(allStats), allStats.map(_._2.schema).foldLeft(Seq.empty[Attribute])(_ ++ _)) + (AttributeMap(allStats), allStats.flatMap(_._2.schema)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index e2e33e32463fc..96bd338f092e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.columnar import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeFormatter, CodeGenerator, UnsafeRowWriter} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeFormatter, CodeGenerator, UnsafeRowWriter} import org.apache.spark.sql.types._ /** @@ -58,7 +58,7 @@ class MutableUnsafeRow(val writer: UnsafeRowWriter) extends GenericMutableRow(nu } /** - * Generates bytecode for an [[ColumnarIterator]] for columnar cache. + * Generates bytecode for a [[ColumnarIterator]] for columnar cache. 
*/ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarIterator] with Logging { @@ -127,7 +127,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera val groupedAccessorsItr = initializeAccessors.grouped(numberOfStatementsThreshold) val groupedExtractorsItr = extractors.grouped(numberOfStatementsThreshold) var groupedAccessorsLength = 0 - groupedAccessorsItr.zipWithIndex.map { case (body, i) => + groupedAccessorsItr.zipWithIndex.foreach { case (body, i) => groupedAccessorsLength += 1 val funcName = s"accessors$i" val funcCode = s""" @@ -137,7 +137,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera """.stripMargin ctx.addNewFunction(funcName, funcCode) } - groupedExtractorsItr.zipWithIndex.map { case (body, i) => + groupedExtractorsItr.zipWithIndex.foreach { case (body, i) => val funcName = s"extractors$i" val funcCode = s""" |private void $funcName() { @@ -150,7 +150,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera (0 to groupedAccessorsLength - 1).map { i => s"extractors$i();" }.mkString("\n")) } - val code = s""" + val codeBody = s""" import java.nio.ByteBuffer; import java.nio.ByteOrder; import scala.collection.Iterator; @@ -224,7 +224,9 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera } }""" - logDebug(s"Generated ColumnarIterator: ${CodeFormatter.format(code)}") + val code = CodeFormatter.stripOverlappingComments( + new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())) + logDebug(s"Generated ColumnarIterator:\n${CodeFormatter.format(code)}") CodeGenerator.compile(code).generate(Array.empty).asInstanceOf[ColumnarIterator] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala new file mode 100644 index 0000000000000..56bd5c1891e8d --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.columnar + +import org.apache.commons.lang3.StringUtils + +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.catalyst.plans.logical.Statistics +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.LongAccumulator + + +object InMemoryRelation { + def apply( + useCompression: Boolean, + batchSize: Int, + storageLevel: StorageLevel, + child: SparkPlan, + tableName: Option[String]): InMemoryRelation = + new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, child, tableName)() +} + + +/** + * CachedBatch is a cached batch of rows. + * + * @param numRows The total number of rows in this batch + * @param buffers The buffers for serialized columns + * @param stats The stat of columns + */ +private[columnar] +case class CachedBatch(numRows: Int, buffers: Array[Array[Byte]], stats: InternalRow) + +case class InMemoryRelation( + output: Seq[Attribute], + useCompression: Boolean, + batchSize: Int, + storageLevel: StorageLevel, + @transient child: SparkPlan, + tableName: Option[String])( + @transient var _cachedColumnBuffers: RDD[CachedBatch] = null, + val batchStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator) + extends logical.LeafNode with MultiInstanceRelation { + + override protected def innerChildren: Seq[QueryPlan[_]] = Seq(child) + + override def producedAttributes: AttributeSet = outputSet + + @transient val partitionStatistics = new PartitionStatistics(output) + + override lazy val statistics: Statistics = { + if (batchStats.value == 0L) { + // Underlying columnar RDD hasn't been materialized, no useful statistics information + // available, return the default statistics. + Statistics(sizeInBytes = child.sqlContext.conf.defaultSizeInBytes) + } else { + Statistics(sizeInBytes = batchStats.value.longValue) + } + } + + // If the cached column buffers were not passed in, we calculate them in the constructor. + // As in Spark, the actual work of caching is lazy. + if (_cachedColumnBuffers == null) { + buildBuffers() + } + + def recache(): Unit = { + _cachedColumnBuffers.unpersist() + _cachedColumnBuffers = null + buildBuffers() + } + + private def buildBuffers(): Unit = { + val output = child.output + val cached = child.execute().mapPartitionsInternal { rowIterator => + new Iterator[CachedBatch] { + def next(): CachedBatch = { + val columnBuilders = output.map { attribute => + ColumnBuilder(attribute.dataType, batchSize, attribute.name, useCompression) + }.toArray + + var rowCount = 0 + var totalSize = 0L + while (rowIterator.hasNext && rowCount < batchSize + && totalSize < ColumnBuilder.MAX_BATCH_SIZE_IN_BYTE) { + val row = rowIterator.next() + + // Added for SPARK-6082. This assertion can be useful for scenarios when something + // like Hive TRANSFORM is used. The external data generation script used in TRANSFORM + // may result malformed rows, causing ArrayIndexOutOfBoundsException, which is somewhat + // hard to decipher. + assert( + row.numFields == columnBuilders.length, + s"Row column number mismatch, expected ${output.size} columns, " + + s"but got ${row.numFields}." 
+ + s"\nRow content: $row") + + var i = 0 + totalSize = 0 + while (i < row.numFields) { + columnBuilders(i).appendFrom(row, i) + totalSize += columnBuilders(i).columnStats.sizeInBytes + i += 1 + } + rowCount += 1 + } + + batchStats.add(totalSize) + + val stats = InternalRow.fromSeq(columnBuilders.map(_.columnStats.collectedStatistics) + .flatMap(_.values)) + CachedBatch(rowCount, columnBuilders.map { builder => + JavaUtils.bufferToArray(builder.build()) + }, stats) + } + + def hasNext: Boolean = rowIterator.hasNext + } + }.persist(storageLevel) + + cached.setName( + tableName.map(n => s"In-memory table $n") + .getOrElse(StringUtils.abbreviate(child.toString, 1024))) + _cachedColumnBuffers = cached + } + + def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = { + InMemoryRelation( + newOutput, useCompression, batchSize, storageLevel, child, tableName)( + _cachedColumnBuffers, batchStats) + } + + override def newInstance(): this.type = { + new InMemoryRelation( + output.map(_.newInstance()), + useCompression, + batchSize, + storageLevel, + child, + tableName)( + _cachedColumnBuffers, + batchStats).asInstanceOf[this.type] + } + + def cachedColumnBuffers: RDD[CachedBatch] = _cachedColumnBuffers + + override protected def otherCopyArgs: Seq[AnyRef] = + Seq(_cachedColumnBuffers, batchStats) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 94b87a5812e4c..e63b313cb1d5d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -17,206 +17,26 @@ package org.apache.spark.sql.execution.columnar -import scala.collection.mutable.ArrayBuffer - -import org.apache.spark.{Accumulable, Accumulator, AccumulatorContext} -import org.apache.spark.network.util.JavaUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.plans.logical.Statistics +import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical.Partitioning -import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan} +import org.apache.spark.sql.execution.LeafExecNode import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.UserDefinedType -import org.apache.spark.storage.StorageLevel - -private[sql] object InMemoryRelation { - def apply( - useCompression: Boolean, - batchSize: Int, - storageLevel: StorageLevel, - child: SparkPlan, - tableName: Option[String]): InMemoryRelation = - new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, child, tableName)() -} - -/** - * CachedBatch is a cached batch of rows. 
- * - * @param numRows The total number of rows in this batch - * @param buffers The buffers for serialized columns - * @param stats The stat of columns - */ -private[columnar] -case class CachedBatch(numRows: Int, buffers: Array[Array[Byte]], stats: InternalRow) - -private[sql] case class InMemoryRelation( - output: Seq[Attribute], - useCompression: Boolean, - batchSize: Int, - storageLevel: StorageLevel, - @transient child: SparkPlan, - tableName: Option[String])( - @transient private[sql] var _cachedColumnBuffers: RDD[CachedBatch] = null, - @transient private[sql] var _statistics: Statistics = null, - private[sql] var _batchStats: Accumulable[ArrayBuffer[InternalRow], InternalRow] = null) - extends logical.LeafNode with MultiInstanceRelation { - - override def producedAttributes: AttributeSet = outputSet - - private[sql] val batchStats: Accumulable[ArrayBuffer[InternalRow], InternalRow] = - if (_batchStats == null) { - child.sqlContext.sparkContext.accumulableCollection(ArrayBuffer.empty[InternalRow]) - } else { - _batchStats - } - - @transient val partitionStatistics = new PartitionStatistics(output) - - private def computeSizeInBytes = { - val sizeOfRow: Expression = - BindReferences.bindReference( - output.map(a => partitionStatistics.forAttribute(a).sizeInBytes).reduce(Add), - partitionStatistics.schema) - - batchStats.value.map(row => sizeOfRow.eval(row).asInstanceOf[Long]).sum - } - - // Statistics propagation contracts: - // 1. Non-null `_statistics` must reflect the actual statistics of the underlying data - // 2. Only propagate statistics when `_statistics` is non-null - private def statisticsToBePropagated = if (_statistics == null) { - val updatedStats = statistics - if (_statistics == null) null else updatedStats - } else { - _statistics - } - - override def statistics: Statistics = { - if (_statistics == null) { - if (batchStats.value.isEmpty) { - // Underlying columnar RDD hasn't been materialized, no useful statistics information - // available, return the default statistics. - Statistics(sizeInBytes = child.sqlContext.conf.defaultSizeInBytes) - } else { - // Underlying columnar RDD has been materialized, required information has also been - // collected via the `batchStats` accumulator, compute the final statistics, - // and update `_statistics`. - _statistics = Statistics(sizeInBytes = computeSizeInBytes) - _statistics - } - } else { - // Pre-computed statistics - _statistics - } - } - - // If the cached column buffers were not passed in, we calculate them in the constructor. - // As in Spark, the actual work of caching is lazy. - if (_cachedColumnBuffers == null) { - buildBuffers() - } - - def recache(): Unit = { - _cachedColumnBuffers.unpersist() - _cachedColumnBuffers = null - buildBuffers() - } - private def buildBuffers(): Unit = { - val output = child.output - val cached = child.execute().mapPartitionsInternal { rowIterator => - new Iterator[CachedBatch] { - def next(): CachedBatch = { - val columnBuilders = output.map { attribute => - ColumnBuilder(attribute.dataType, batchSize, attribute.name, useCompression) - }.toArray - - var rowCount = 0 - var totalSize = 0L - while (rowIterator.hasNext && rowCount < batchSize - && totalSize < ColumnBuilder.MAX_BATCH_SIZE_IN_BYTE) { - val row = rowIterator.next() - - // Added for SPARK-6082. This assertion can be useful for scenarios when something - // like Hive TRANSFORM is used. 
The external data generation script used in TRANSFORM - // may result malformed rows, causing ArrayIndexOutOfBoundsException, which is somewhat - // hard to decipher. - assert( - row.numFields == columnBuilders.length, - s"Row column number mismatch, expected ${output.size} columns, " + - s"but got ${row.numFields}." + - s"\nRow content: $row") - - var i = 0 - totalSize = 0 - while (i < row.numFields) { - columnBuilders(i).appendFrom(row, i) - totalSize += columnBuilders(i).columnStats.sizeInBytes - i += 1 - } - rowCount += 1 - } - - val stats = InternalRow.fromSeq(columnBuilders.map(_.columnStats.collectedStatistics) - .flatMap(_.values)) - - batchStats += stats - CachedBatch(rowCount, columnBuilders.map { builder => - JavaUtils.bufferToArray(builder.build()) - }, stats) - } - - def hasNext: Boolean = rowIterator.hasNext - } - }.persist(storageLevel) - - cached.setName(tableName.map(n => s"In-memory table $n").getOrElse(child.toString)) - _cachedColumnBuffers = cached - } - - def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = { - InMemoryRelation( - newOutput, useCompression, batchSize, storageLevel, child, tableName)( - _cachedColumnBuffers, statisticsToBePropagated, batchStats) - } - - override def newInstance(): this.type = { - new InMemoryRelation( - output.map(_.newInstance()), - useCompression, - batchSize, - storageLevel, - child, - tableName)( - _cachedColumnBuffers, - statisticsToBePropagated, - batchStats).asInstanceOf[this.type] - } - - def cachedColumnBuffers: RDD[CachedBatch] = _cachedColumnBuffers - - override protected def otherCopyArgs: Seq[AnyRef] = - Seq(_cachedColumnBuffers, statisticsToBePropagated, batchStats) - private[sql] def uncache(blocking: Boolean): Unit = { - AccumulatorContext.remove(batchStats.id) - cachedColumnBuffers.unpersist(blocking) - _cachedColumnBuffers = null - } -} - -private[sql] case class InMemoryTableScanExec( +case class InMemoryTableScanExec( attributes: Seq[Attribute], predicates: Seq[Expression], @transient relation: InMemoryRelation) extends LeafExecNode { - private[sql] override lazy val metrics = Map( + override protected def innerChildren: Seq[QueryPlan[_]] = Seq(relation) ++ super.innerChildren + + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) override def output: Seq[Attribute] = attributes @@ -283,8 +103,8 @@ private[sql] case class InMemoryTableScanExec( sqlContext.getConf("spark.sql.inMemoryTableScanStatistics.enable", "false").toBoolean // Accumulators used for testing purposes - lazy val readPartitions: Accumulator[Int] = sparkContext.accumulator(0) - lazy val readBatches: Accumulator[Int] = sparkContext.accumulator(0) + lazy val readPartitions = sparkContext.longAccumulator + lazy val readBatches = sparkContext.longAccumulator private val inMemoryPartitionPruningEnabled = sqlContext.conf.inMemoryPartitionPruning @@ -300,7 +120,7 @@ private[sql] case class InMemoryTableScanExec( // within the map Partitions closure. val schema = relation.partitionStatistics.schema val schemaIndex = schema.zipWithIndex - val relOutput = relation.output + val relOutput: AttributeSeq = relation.output val buffers = relation.cachedColumnBuffers buffers.mapPartitionsInternal { cachedBatchIterator => @@ -311,7 +131,7 @@ private[sql] case class InMemoryTableScanExec( // Find the ordinals and data types of the requested columns. 
val (requestedColumnIndices, requestedColumnDataTypes) = attributes.map { a => - relOutput.indexWhere(_.exprId == a.exprId) -> a.dataType + relOutput.indexOf(a.exprId) -> a.dataType }.unzip // Do partition batch pruning if enabled @@ -327,9 +147,6 @@ private[sql] case class InMemoryTableScanExec( logInfo(s"Skipping partition based on stats $statsString") false } else { - if (enableAccumulators) { - readBatches += 1 - } true } } @@ -339,6 +156,9 @@ private[sql] case class InMemoryTableScanExec( // update SQL metrics val withMetrics = cachedBatchesToScan.map { batch => + if (enableAccumulators) { + readBatches.add(1) + } numOutputRows += batch.numRows batch } @@ -350,7 +170,7 @@ private[sql] case class InMemoryTableScanExec( val columnarIterator = GenerateColumnAccessor.generate(columnTypes) columnarIterator.initialize(withMetrics, columnTypes, requestedColumnIndices.toArray) if (enableAccumulators && columnarIterator.hasNext) { - readPartitions += 1 + readPartitions.add(1) } columnarIterator } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnBuilder.scala index 63eae1b8685ac..0f4680e502781 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnBuilder.scala @@ -66,11 +66,7 @@ private[columnar] trait CompressibleColumnBuilder[T <: AtomicType] } private def gatherCompressibilityStats(row: InternalRow, ordinal: Int): Unit = { - var i = 0 - while (i < compressionEncoders.length) { - compressionEncoders(i).gatherCompressibilityStats(row, ordinal) - i += 1 - } + compressionEncoders.foreach(_.gatherCompressibilityStats(row, ordinal)) } abstract override def appendFrom(row: InternalRow, ordinal: Int): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala similarity index 83% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTable.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala index 54ff5ae7d9d0e..07127533b0349 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala @@ -21,9 +21,10 @@ import scala.util.control.NonFatal import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.{AnalysisException, Row, SparkSession} +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases -import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogTable} +import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogTable, SimpleCatalogRelation} /** @@ -33,15 +34,17 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogTable} * Right now, it only supports Hive tables and it only updates the size of a Hive table * in the Hive metastore. 
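As a minimal sketch (the table name `sales` is hypothetical), the statement handled by the renamed `AnalyzeTableCommand` is issued through SQL; per the comment above it currently only updates the total size of Hive tables in the metastore, so Hive support must be enabled:

```scala
import org.apache.spark.sql.SparkSession

object AnalyzeTableSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("analyze-table")
      .enableHiveSupport() // required so that `sales` can be a Hive (CatalogRelation) table
      .getOrCreate()

    // NOSCAN gathers size information from the file system without scanning the data itself.
    spark.sql("ANALYZE TABLE sales COMPUTE STATISTICS NOSCAN")

    spark.stop()
  }
}
```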
*/ -case class AnalyzeTable(tableName: String) extends RunnableCommand { +case class AnalyzeTableCommand(tableName: String) extends RunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { val sessionState = sparkSession.sessionState val tableIdent = sessionState.sqlParser.parseTableIdentifier(tableName) - val relation = EliminateSubqueryAliases(sessionState.catalog.lookupRelation(tableIdent)) + val db = tableIdent.database.getOrElse(sessionState.catalog.getCurrentDatabase) + val tableIdentwithDB = TableIdentifier(tableIdent.table, Some(db)) + val relation = EliminateSubqueryAliases(sessionState.catalog.lookupRelation(tableIdentwithDB)) relation match { - case relation: CatalogRelation => + case relation: CatalogRelation if !relation.isInstanceOf[SimpleCatalogRelation] => val catalogTable: CatalogTable = relation.catalogTable // This method is mainly based on // org.apache.hadoop.hive.ql.stats.StatsUtils.getFileSizeForTable(HiveConf, Table) @@ -95,17 +98,17 @@ case class AnalyzeTable(tableName: String) extends RunnableCommand { sessionState.catalog.alterTable( catalogTable.copy( properties = relation.catalogTable.properties + - (AnalyzeTable.TOTAL_SIZE_FIELD -> newTotalSize.toString))) + (AnalyzeTableCommand.TOTAL_SIZE_FIELD -> newTotalSize.toString))) } case otherRelation => - throw new UnsupportedOperationException( - s"Analyze only works for Hive tables, but $tableName is a ${otherRelation.nodeName}") + throw new AnalysisException(s"ANALYZE TABLE is only supported for Hive tables, " + + s"but '${tableIdent.unquotedString}' is a ${otherRelation.nodeName}.") } Seq.empty[Row] } } -object AnalyzeTable { +object AnalyzeTableCommand { val TOTAL_SIZE_FIELD = "totalSize" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala index 2409b5d203f40..af6def52d07d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.execution.command -import java.util.NoSuchElementException - import org.apache.spark.internal.Logging import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.expressions.Attribute @@ -88,7 +86,7 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm } val schema = StructType( StructField("key", StringType, nullable = false) :: - StructField("default", StringType, nullable = false) :: + StructField("value", StringType, nullable = false) :: StructField("meaning", StringType, nullable = false) :: Nil) (schema.toAttributes, runFunc) @@ -118,3 +116,17 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm override def run(sparkSession: SparkSession): Seq[Row] = runFunc(sparkSession) } + +/** + * This command is for resetting SQLConf to the default values. 
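A short usage sketch, not part of the patch, for the `SET` output schema change and the new `RESET` command introduced around here:

```scala
import org.apache.spark.sql.SparkSession

object SetResetSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("set-reset").getOrCreate()

    // Override a session-level SQL conf.
    spark.sql("SET spark.sql.shuffle.partitions=10")

    // With the change above, the second column of SET -v is labelled "value" rather than "default".
    spark.sql("SET -v").select("key", "value").show(5, truncate = false)

    // RESET (the new ResetCommand) clears session overrides and restores the defaults.
    spark.sql("RESET")

    spark.stop()
  }
}
```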
Command that runs + * {{{ + * reset; + * }}} + */ +case object ResetCommand extends RunnableCommand with Logging { + + override def run(sparkSession: SparkSession): Seq[Row] = { + sparkSession.sessionState.conf.clear() + Seq.empty[Row] + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala index f05401b02b16c..605e49c5a87c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala @@ -18,42 +18,44 @@ package org.apache.spark.sql.execution.command import org.apache.spark.sql.{Dataset, Row, SparkSession} +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan - case class CacheTableCommand( - tableName: String, - plan: Option[LogicalPlan], - isLazy: Boolean) - extends RunnableCommand { + tableIdent: TableIdentifier, + plan: Option[LogicalPlan], + isLazy: Boolean) extends RunnableCommand { + require(plan.isEmpty || tableIdent.database.isEmpty, + "Database name is not allowed in CACHE TABLE AS SELECT") + + override protected def innerChildren: Seq[QueryPlan[_]] = { + plan.toSeq + } override def run(sparkSession: SparkSession): Seq[Row] = { plan.foreach { logicalPlan => - sparkSession.registerTable(Dataset.ofRows(sparkSession, logicalPlan), tableName) + Dataset.ofRows(sparkSession, logicalPlan).createTempView(tableIdent.quotedString) } - sparkSession.catalog.cacheTable(tableName) + sparkSession.catalog.cacheTable(tableIdent.quotedString) if (!isLazy) { // Performs eager caching - sparkSession.table(tableName).count() + sparkSession.table(tableIdent).count() } Seq.empty[Row] } - - override def output: Seq[Attribute] = Seq.empty } -case class UncacheTableCommand(tableName: String) extends RunnableCommand { +case class UncacheTableCommand(tableIdent: TableIdentifier) extends RunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { - sparkSession.table(tableName).unpersist(blocking = false) + sparkSession.catalog.uncacheTable(tableIdent.quotedString) Seq.empty[Row] } - - override def output: Seq[Attribute] = Seq.empty } /** @@ -65,6 +67,4 @@ case object ClearCacheCommand extends RunnableCommand { sparkSession.catalog.clearCache() Seq.empty[Row] } - - override def output: Seq[Attribute] = Seq.empty } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index 5f9287b3b55a9..d82e54e57564c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -17,29 +17,25 @@ package org.apache.spark.sql.execution.command -import java.io.File - import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{AnalysisException, Row, SparkSession} -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, TableIdentifier} -import org.apache.spark.sql.catalyst.catalog.CatalogTableType -import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec +import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.errors.TreeNodeException import 
org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} +import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.datasources.PartitioningUtils import org.apache.spark.sql.execution.debug._ +import org.apache.spark.sql.execution.streaming.IncrementalExecution +import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types._ /** * A logical command that is executed for its side-effects. `RunnableCommand`s are * wrapped in `ExecutedCommand` during execution. */ -private[sql] trait RunnableCommand extends LogicalPlan with logical.Command { - override def output: Seq[Attribute] = Seq.empty - override def children: Seq[LogicalPlan] = Seq.empty +trait RunnableCommand extends logical.Command { def run(sparkSession: SparkSession): Seq[Row] } @@ -47,7 +43,7 @@ private[sql] trait RunnableCommand extends LogicalPlan with logical.Command { * A physical operator that executes the run method of a `RunnableCommand` and * saves the result to prevent multiple executions. */ -private[sql] case class ExecutedCommandExec(cmd: RunnableCommand) extends SparkPlan { +case class ExecutedCommandExec(cmd: RunnableCommand) extends SparkPlan { /** * A concrete command should override this lazy field to wrap up any side effects caused by the * command or any other computation that should be evaluated exactly once. The value of this field @@ -62,22 +58,23 @@ private[sql] case class ExecutedCommandExec(cmd: RunnableCommand) extends SparkP cmd.run(sqlContext.sparkSession).map(converter(_).asInstanceOf[InternalRow]) } + override protected def innerChildren: Seq[QueryPlan[_]] = cmd :: Nil + override def output: Seq[Attribute] = cmd.output override def children: Seq[SparkPlan] = Nil override def executeCollect(): Array[InternalRow] = sideEffectResult.toArray + override def executeToIterator: Iterator[InternalRow] = sideEffectResult.toIterator + override def executeTake(limit: Int): Array[InternalRow] = sideEffectResult.take(limit).toArray protected override def doExecute(): RDD[InternalRow] = { sqlContext.sparkContext.parallelize(sideEffectResult, 1) } - - override def argString: String = cmd.toString } - /** * An explain command for users to see how a command will be executed. * @@ -85,7 +82,7 @@ private[sql] case class ExecutedCommandExec(cmd: RunnableCommand) extends SparkP * (but do NOT actually execute it). * * {{{ - * EXPLAIN (EXTENDED|CODEGEN) SELECT * FROM ... + * EXPLAIN (EXTENDED | CODEGEN) SELECT * FROM ... * }}} * * @param logicalPlan plan to explain @@ -103,7 +100,14 @@ case class ExplainCommand( // Run through the optimizer to generate the physical plan. override def run(sparkSession: SparkSession): Seq[Row] = try { - val queryExecution = sparkSession.executePlan(logicalPlan) + val queryExecution = + if (logicalPlan.isStreaming) { + // This is used only by explaining `Dataset/DataFrame` created by `spark.readStream`, so the + // output mode does not matter since there is no `Sink`. 
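For context, a minimal sketch (not part of this patch) of the case the new branch handles: calling `explain` on a streaming Dataset that has no sink attached yet. The socket source, host/port, and local master are purely illustrative.

```scala
import org.apache.spark.sql.SparkSession

object StreamingExplainSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("streaming-explain").getOrCreate()
    // A streaming DataFrame with no sink; the query is never started.
    val lines = spark.readStream
      .format("socket")
      .option("host", "localhost")
      .option("port", "9999")
      .load()
    // Goes through ExplainCommand with logicalPlan.isStreaming = true, so planning
    // uses IncrementalExecution and the chosen output mode is irrelevant.
    lines.explain(extended = true)
    spark.stop()
  }
}
```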
+ new IncrementalExecution(sparkSession, logicalPlan, OutputMode.Append(), "", 0) + } else { + sparkSession.sessionState.executePlan(logicalPlan) + } val outputString = if (codegen) { codegenString(queryExecution.executedPlan) @@ -117,101 +121,3 @@ case class ExplainCommand( ("Error occurred during query planning: \n" + cause.getMessage).split("\n").map(Row(_)) } } - -/** - * A command to list the column names for a table. This function creates a - * [[ShowColumnsCommand]] logical plan. - * - * The syntax of using this command in SQL is: - * {{{ - * SHOW COLUMNS (FROM | IN) table_identifier [(FROM | IN) database]; - * }}} - */ -case class ShowColumnsCommand(table: TableIdentifier) extends RunnableCommand { - // The result of SHOW COLUMNS has one column called 'result' - override val output: Seq[Attribute] = { - AttributeReference("result", StringType, nullable = false)() :: Nil - } - - override def run(sparkSession: SparkSession): Seq[Row] = { - sparkSession.sessionState.catalog.getTableMetadata(table).schema.map { c => - Row(c.name) - } - } -} - -/** - * A command to list the partition names of a table. If the partition spec is specified, - * partitions that match the spec are returned. [[AnalysisException]] exception is thrown under - * the following conditions: - * - * 1. If the command is called for a non partitioned table. - * 2. If the partition spec refers to the columns that are not defined as partitioning columns. - * - * This function creates a [[ShowPartitionsCommand]] logical plan - * - * The syntax of using this command in SQL is: - * {{{ - * SHOW PARTITIONS [db_name.]table_name [PARTITION(partition_spec)] - * }}} - */ -case class ShowPartitionsCommand( - table: TableIdentifier, - spec: Option[TablePartitionSpec]) extends RunnableCommand { - // The result of SHOW PARTITIONS has one column called 'result' - override val output: Seq[Attribute] = { - AttributeReference("result", StringType, nullable = false)() :: Nil - } - - private def getPartName(spec: TablePartitionSpec, partColNames: Seq[String]): String = { - partColNames.map { name => - PartitioningUtils.escapePathName(name) + "=" + PartitioningUtils.escapePathName(spec(name)) - }.mkString(File.separator) - } - - override def run(sparkSession: SparkSession): Seq[Row] = { - val catalog = sparkSession.sessionState.catalog - val db = table.database.getOrElse(catalog.getCurrentDatabase) - if (catalog.isTemporaryTable(table)) { - throw new AnalysisException("SHOW PARTITIONS is not allowed on a temporary table: " + - s"${table.unquotedString}") - } else { - val tab = catalog.getTableMetadata(table) - /** - * Validate and throws an [[AnalysisException]] exception under the following conditions: - * 1. If the table is not partitioned. - * 2. If it is a datasource table. - * 3. If it is a view or index table. - */ - if (tab.tableType == CatalogTableType.VIEW || - tab.tableType == CatalogTableType.INDEX) { - throw new AnalysisException("SHOW PARTITIONS is not allowed on a view or index table: " + - s"${tab.qualifiedName}") - } - if (!DDLUtils.isTablePartitioned(tab)) { - throw new AnalysisException("SHOW PARTITIONS is not allowed on a table that is not " + - s"partitioned: ${tab.qualifiedName}") - } - if (DDLUtils.isDatasourceTable(tab)) { - throw new AnalysisException("SHOW PARTITIONS is not allowed on a datasource table: " + - s"${tab.qualifiedName}") - } - /** - * Validate the partitioning spec by making sure all the referenced columns are - * defined as partitioning columns in table definition. 
An AnalysisException exception is - * thrown if the partitioning spec is invalid. - */ - if (spec.isDefined) { - val badColumns = spec.get.keySet.filterNot(tab.partitionColumns.map(_.name).contains) - if (badColumns.nonEmpty) { - throw new AnalysisException( - s"Non-partitioning column(s) [${badColumns.mkString(", ")}] are " + - s"specified for SHOW PARTITIONS") - } - } - val partNames = - catalog.listPartitions(table, spec).map(p => getPartName(p.spec, tab.partitionColumnNames)) - partNames.map { p => Row(p) } - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala index f670f63472bf7..de7d1fa0afe69 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala @@ -26,9 +26,9 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases -import org.apache.spark.sql.catalyst.catalog.{CatalogColumn, CatalogStorageFormat, CatalogTable, CatalogTableType} +import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.datasources.{BucketSpec, DataSource, HadoopFsRelation, LogicalRelation} +import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.internal.HiveSerDe import org.apache.spark.sql.sources.InsertableRelation import org.apache.spark.sql.types._ @@ -36,7 +36,7 @@ import org.apache.spark.sql.types._ /** * A command used to create a data source table. * - * Note: This is different from [[CreateTable]]. Please check the syntax for difference. + * Note: This is different from [[CreateTableCommand]]. Please check the syntax for difference. * This is not intended for temporary tables. 
* * The syntax of using this command in SQL is: @@ -51,6 +51,8 @@ case class CreateDataSourceTableCommand( userSpecifiedSchema: Option[StructType], provider: String, options: Map[String, String], + partitionColumns: Array[String], + bucketSpec: Option[BucketSpec], ignoreIfExists: Boolean, managedIfNoPath: Boolean) extends RunnableCommand { @@ -71,20 +73,18 @@ case class CreateDataSourceTableCommand( s"characters, numbers and _.") } - val tableName = tableIdent.unquotedString val sessionState = sparkSession.sessionState - if (sessionState.catalog.tableExists(tableIdent)) { if (ignoreIfExists) { return Seq.empty[Row] } else { - throw new AnalysisException(s"Table $tableName already exists.") + throw new AnalysisException(s"Table ${tableIdent.unquotedString} already exists.") } } var isExternal = true val optionsWithPath = - if (!options.contains("path") && managedIfNoPath) { + if (!new CaseInsensitiveMap(options).contains("path") && managedIfNoPath) { isExternal = false options + ("path" -> sessionState.catalog.defaultTablePath(tableIdent)) } else { @@ -97,14 +97,14 @@ case class CreateDataSourceTableCommand( userSpecifiedSchema = userSpecifiedSchema, className = provider, bucketSpec = None, - options = optionsWithPath).resolveRelation() + options = optionsWithPath).resolveRelation(checkPathExist = false) CreateDataSourceTableUtils.createDataSourceTable( sparkSession = sparkSession, tableIdent = tableIdent, userSpecifiedSchema = userSpecifiedSchema, - partitionColumns = Array.empty[String], - bucketSpec = None, + partitionColumns = partitionColumns, + bucketSpec = bucketSpec, provider = provider, options = optionsWithPath, isExternal = isExternal) @@ -116,8 +116,8 @@ case class CreateDataSourceTableCommand( /** * A command used to create a data source table using the result of a query. * - * Note: This is different from [[CreateTableAsSelect]]. Please check the syntax for difference. - * This is not intended for temporary tables. + * Note: This is different from [[CreateTableAsSelectLogicalPlan]]. Please check the syntax for + * difference. This is not intended for temporary tables. * * The syntax of using this command in SQL is: * {{{ @@ -136,6 +136,8 @@ case class CreateDataSourceTableAsSelectCommand( query: LogicalPlan) extends RunnableCommand { + override protected def innerChildren: Seq[LogicalPlan] = Seq(query) + override def run(sparkSession: SparkSession): Seq[Row] = { // Since we are saving metadata to metastore, we need to check if metastore supports // the table name and database name we have for this query. MetaStoreUtils.validateName @@ -152,20 +154,25 @@ case class CreateDataSourceTableAsSelectCommand( s"characters, numbers and _.") } - val tableName = tableIdent.unquotedString val sessionState = sparkSession.sessionState + val db = tableIdent.database.getOrElse(sessionState.catalog.getCurrentDatabase) + val tableIdentWithDB = tableIdent.copy(database = Some(db)) + val tableName = tableIdentWithDB.unquotedString + var createMetastoreTable = false var isExternal = true val optionsWithPath = - if (!options.contains("path")) { + if (!new CaseInsensitiveMap(options).contains("path")) { isExternal = false options + ("path" -> sessionState.catalog.defaultTablePath(tableIdent)) } else { options } - var existingSchema = None: Option[StructType] - if (sparkSession.sessionState.catalog.tableExists(tableIdent)) { + var existingSchema = Option.empty[StructType] + // Pass a table identifier with database part, so that `tableExists` won't check temp views + // unexpectedly. 
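As a small illustration of the qualification step described in the comment above (table and database names are made up, not taken from the patch):

```scala
import org.apache.spark.sql.catalyst.TableIdentifier

// Sketch: qualify an unqualified identifier with the current database so that
// catalog calls such as tableExists and lookupRelation skip temporary views.
val tableIdent = TableIdentifier("events")
val currentDb = "default" // sessionState.catalog.getCurrentDatabase in the command itself
val tableIdentWithDB = tableIdent.copy(database = Some(currentDb))
assert(tableIdentWithDB.unquotedString == "default.events")
```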
+ if (sparkSession.sessionState.catalog.tableExists(tableIdentWithDB)) { // Check if we need to throw an exception or just return. mode match { case SaveMode.ErrorIfExists => @@ -189,15 +196,32 @@ case class CreateDataSourceTableAsSelectCommand( // TODO: Check that options from the resolved relation match the relation that we are // inserting into (i.e. using the same compression). - EliminateSubqueryAliases( - sessionState.catalog.lookupRelation(tableIdent)) match { + // Pass a table identifier with database part, so that `tableExists` won't check temp + // views unexpectedly. + EliminateSubqueryAliases(sessionState.catalog.lookupRelation(tableIdentWithDB)) match { case l @ LogicalRelation(_: InsertableRelation | _: HadoopFsRelation, _, _) => + // check if the file formats match + l.relation match { + case r: HadoopFsRelation if r.fileFormat.getClass != dataSource.providingClass => + throw new AnalysisException( + s"The file format of the existing table $tableName is " + + s"`${r.fileFormat.getClass.getName}`. It doesn't match the specified " + + s"format `$provider`") + case _ => + } + if (query.schema.size != l.schema.size) { + throw new AnalysisException( + s"The column number of the existing schema[${l.schema}] " + + s"doesn't match the data schema[${query.schema}]'s") + } existingSchema = Some(l.schema) + case s: SimpleCatalogRelation if DDLUtils.isDatasourceTable(s.metadata) => + existingSchema = DDLUtils.getSchemaFromTableProperties(s.metadata) case o => throw new AnalysisException(s"Saving data in ${o.toString} is not supported.") } case SaveMode.Overwrite => - sparkSession.sql(s"DROP TABLE IF EXISTS $tableName") + sessionState.catalog.dropTable(tableIdentWithDB, ignoreIfNotExists = true) // Need to create the table again. createMetastoreTable = true } @@ -221,8 +245,13 @@ case class CreateDataSourceTableAsSelectCommand( bucketSpec = bucketSpec, options = optionsWithPath) - val result = dataSource.write(mode, df) - + val result = try { + dataSource.write(mode, df) + } catch { + case ex: AnalysisException => + logError(s"Failed to write to table $tableName in $mode mode", ex) + throw ex + } if (createMetastoreTable) { // We will use the schema of resolved.relation as the schema of the table (instead of // the schema of df). It is important since the nullability may be changed by the relation @@ -239,12 +268,30 @@ case class CreateDataSourceTableAsSelectCommand( } // Refresh the cache of the table in the catalog. - sessionState.catalog.refreshTable(tableIdent) + sessionState.catalog.refreshTable(tableIdentWithDB) Seq.empty[Row] } } + object CreateDataSourceTableUtils extends Logging { + + val DATASOURCE_PREFIX = "spark.sql.sources." + val DATASOURCE_PROVIDER = DATASOURCE_PREFIX + "provider" + val DATASOURCE_WRITEJOBUUID = DATASOURCE_PREFIX + "writeJobUUID" + val DATASOURCE_OUTPUTPATH = DATASOURCE_PREFIX + "output.path" + val DATASOURCE_SCHEMA = DATASOURCE_PREFIX + "schema" + val DATASOURCE_SCHEMA_PREFIX = DATASOURCE_SCHEMA + "." + val DATASOURCE_SCHEMA_NUMPARTS = DATASOURCE_SCHEMA_PREFIX + "numParts" + val DATASOURCE_SCHEMA_NUMPARTCOLS = DATASOURCE_SCHEMA_PREFIX + "numPartCols" + val DATASOURCE_SCHEMA_NUMSORTCOLS = DATASOURCE_SCHEMA_PREFIX + "numSortCols" + val DATASOURCE_SCHEMA_NUMBUCKETS = DATASOURCE_SCHEMA_PREFIX + "numBuckets" + val DATASOURCE_SCHEMA_NUMBUCKETCOLS = DATASOURCE_SCHEMA_PREFIX + "numBucketCols" + val DATASOURCE_SCHEMA_PART_PREFIX = DATASOURCE_SCHEMA_PREFIX + "part." + val DATASOURCE_SCHEMA_PARTCOL_PREFIX = DATASOURCE_SCHEMA_PREFIX + "partCol." 
+ val DATASOURCE_SCHEMA_BUCKETCOL_PREFIX = DATASOURCE_SCHEMA_PREFIX + "bucketCol." + val DATASOURCE_SCHEMA_SORTCOL_PREFIX = DATASOURCE_SCHEMA_PREFIX + "sortCol." + /** * Checks if the given name conforms the Hive standard ("[a-zA-z_0-9]+"), * i.e. if this name only contains characters, numbers, and _. @@ -269,7 +316,7 @@ object CreateDataSourceTableUtils extends Logging { options: Map[String, String], isExternal: Boolean): Unit = { val tableProperties = new mutable.HashMap[String, String] - tableProperties.put("spark.sql.sources.provider", provider) + tableProperties.put(DATASOURCE_PROVIDER, provider) // Saves optional user specified schema. Serialized JSON schema string may be too long to be // stored into a single metastore SerDe property. In this case, we split the JSON string and @@ -279,34 +326,32 @@ object CreateDataSourceTableUtils extends Logging { val schemaJsonString = schema.json // Split the JSON string. val parts = schemaJsonString.grouped(threshold).toSeq - tableProperties.put("spark.sql.sources.schema.numParts", parts.size.toString) + tableProperties.put(DATASOURCE_SCHEMA_NUMPARTS, parts.size.toString) parts.zipWithIndex.foreach { case (part, index) => - tableProperties.put(s"spark.sql.sources.schema.part.$index", part) + tableProperties.put(s"$DATASOURCE_SCHEMA_PART_PREFIX$index", part) } } if (userSpecifiedSchema.isDefined && partitionColumns.length > 0) { - tableProperties.put("spark.sql.sources.schema.numPartCols", partitionColumns.length.toString) + tableProperties.put(DATASOURCE_SCHEMA_NUMPARTCOLS, partitionColumns.length.toString) partitionColumns.zipWithIndex.foreach { case (partCol, index) => - tableProperties.put(s"spark.sql.sources.schema.partCol.$index", partCol) + tableProperties.put(s"$DATASOURCE_SCHEMA_PARTCOL_PREFIX$index", partCol) } } if (userSpecifiedSchema.isDefined && bucketSpec.isDefined) { val BucketSpec(numBuckets, bucketColumnNames, sortColumnNames) = bucketSpec.get - tableProperties.put("spark.sql.sources.schema.numBuckets", numBuckets.toString) - tableProperties.put("spark.sql.sources.schema.numBucketCols", - bucketColumnNames.length.toString) + tableProperties.put(DATASOURCE_SCHEMA_NUMBUCKETS, numBuckets.toString) + tableProperties.put(DATASOURCE_SCHEMA_NUMBUCKETCOLS, bucketColumnNames.length.toString) bucketColumnNames.zipWithIndex.foreach { case (bucketCol, index) => - tableProperties.put(s"spark.sql.sources.schema.bucketCol.$index", bucketCol) + tableProperties.put(s"$DATASOURCE_SCHEMA_BUCKETCOL_PREFIX$index", bucketCol) } if (sortColumnNames.nonEmpty) { - tableProperties.put("spark.sql.sources.schema.numSortCols", - sortColumnNames.length.toString) + tableProperties.put(DATASOURCE_SCHEMA_NUMSORTCOLS, sortColumnNames.length.toString) sortColumnNames.zipWithIndex.foreach { case (sortCol, index) => - tableProperties.put(s"spark.sql.sources.schema.sortCol.$index", sortCol) + tableProperties.put(s"$DATASOURCE_SCHEMA_SORTCOL_PREFIX$index", sortCol) } } } @@ -349,6 +394,7 @@ object CreateDataSourceTableUtils extends Logging { inputFormat = None, outputFormat = None, serde = None, + compressed = false, serdeProperties = options ), properties = tableProperties.toMap) @@ -368,6 +414,7 @@ object CreateDataSourceTableUtils extends Logging { inputFormat = serde.inputFormat, outputFormat = serde.outputFormat, serde = serde.serde, + compressed = false, serdeProperties = options ), schema = relation.schema.map { f => @@ -380,15 +427,16 @@ object CreateDataSourceTableUtils extends Logging { // TODO: Support persisting partitioned data source relations in Hive 
compatible format val qualifiedTableName = tableIdent.quotedString val skipHiveMetadata = options.getOrElse("skipHiveMetadata", "false").toBoolean - val (hiveCompatibleTable, logMessage) = (maybeSerDe, dataSource.resolveRelation()) match { + val resolvedRelation = dataSource.resolveRelation(checkPathExist = false) + val (hiveCompatibleTable, logMessage) = (maybeSerDe, resolvedRelation) match { case _ if skipHiveMetadata => val message = s"Persisting partitioned data source relation $qualifiedTableName into " + "Hive metastore in Spark SQL specific format, which is NOT compatible with Hive." (None, message) - case (Some(serde), relation: HadoopFsRelation) - if relation.location.paths.length == 1 && relation.partitionSchema.isEmpty => + case (Some(serde), relation: HadoopFsRelation) if relation.location.paths.length == 1 && + relation.partitionSchema.isEmpty && relation.bucketSpec.isEmpty => val hiveTable = newHiveCompatibleMetastoreTable(relation, serde) val message = s"Persisting data source relation $qualifiedTableName with a single input path " + @@ -403,6 +451,13 @@ object CreateDataSourceTableUtils extends Logging { "Input path(s): " + relation.location.paths.mkString("\n", "\n", "") (None, message) + case (Some(serde), relation: HadoopFsRelation) if relation.bucketSpec.nonEmpty => + val message = + s"Persisting bucketed data source relation $qualifiedTableName into " + + "Hive metastore in Spark SQL specific format, which is NOT compatible with Hive. " + + "Input path(s): " + relation.location.paths.mkString("\n", "\n", "") + (None, message) + case (Some(serde), relation: HadoopFsRelation) => val message = s"Persisting data source relation $qualifiedTableName with multiple input paths into " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/databases.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/databases.scala index cefe0f6e629b0..e5a6a5f60b8a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/databases.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/databases.scala @@ -33,9 +33,9 @@ import org.apache.spark.sql.types.StringType */ case class ShowDatabasesCommand(databasePattern: Option[String]) extends RunnableCommand { - // The result of SHOW DATABASES has one column called 'result' + // The result of SHOW DATABASES has one column called 'databaseName' override val output: Seq[Attribute] = { - AttributeReference("result", StringType, nullable = false)() :: Nil + AttributeReference("databaseName", StringType, nullable = false)() :: Nil } override def run(sparkSession: SparkSession): Seq[Row] = { @@ -59,6 +59,4 @@ case class SetDatabaseCommand(databaseName: String) extends RunnableCommand { sparkSession.sessionState.catalog.setCurrentDatabase(databaseName) Seq.empty[Row] } - - override val output: Seq[Attribute] = Seq.empty } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 44647116b4889..d63d29defdd67 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -17,43 +17,29 @@ package org.apache.spark.sql.execution.command +import scala.collection.{GenMap, GenSeq} +import scala.collection.parallel.ForkJoinTaskSupport +import scala.concurrent.forkjoin.ForkJoinPool import scala.util.control.NonFatal -import org.apache.spark.internal.Logging +import 
org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs._ +import org.apache.hadoop.mapred.{FileInputFormat, JobConf} + import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogTable} -import org.apache.spark.sql.catalyst.catalog.{CatalogTablePartition, CatalogTableType, SessionCatalog} +import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogTable, CatalogTablePartition, CatalogTableType, SessionCatalog} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} -import org.apache.spark.sql.catalyst.parser.ParseException +import org.apache.spark.sql.execution.command.CreateDataSourceTableUtils._ +import org.apache.spark.sql.execution.datasources.BucketSpec +import org.apache.spark.sql.execution.datasources.PartitioningUtils import org.apache.spark.sql.types._ - - +import org.apache.spark.util.SerializableConfiguration // Note: The definition of these commands are based on the ones described in // https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL -/** - * A DDL command that is not supported right now. Since we have already implemented - * the parsing rules for some commands that are not allowed, we use this as the base class - * of those commands. - */ -abstract class UnsupportedCommand(exception: ParseException) extends RunnableCommand { - - // Throws the ParseException when we create this command. - throw exception - - override def run(sparkSession: SparkSession): Seq[Row] = { - Seq.empty[Row] - } - - override val output: Seq[Attribute] = { - Seq(AttributeReference("result", StringType, nullable = false)()) - } - -} - /** * A command for users to create a new database. * @@ -61,10 +47,13 @@ abstract class UnsupportedCommand(exception: ParseException) extends RunnableCom * unless 'ifNotExists' is true. * The syntax of using this command in SQL is: * {{{ - * CREATE DATABASE|SCHEMA [IF NOT EXISTS] database_name + * CREATE (DATABASE|SCHEMA) [IF NOT EXISTS] database_name + * [COMMENT database_comment] + * [LOCATION database_directory] + * [WITH DBPROPERTIES (property_name=property_value, ...)]; * }}} */ -case class CreateDatabase( +case class CreateDatabaseCommand( databaseName: String, ifNotExists: Boolean, path: Option[String], @@ -83,8 +72,6 @@ case class CreateDatabase( ifNotExists) Seq.empty[Row] } - - override val output: Seq[Attribute] = Seq.empty } @@ -104,7 +91,7 @@ case class CreateDatabase( * DROP DATABASE [IF EXISTS] database_name [RESTRICT|CASCADE]; * }}} */ -case class DropDatabase( +case class DropDatabaseCommand( databaseName: String, ifExists: Boolean, cascade: Boolean) @@ -114,8 +101,6 @@ case class DropDatabase( sparkSession.sessionState.catalog.dropDatabase(databaseName, ifExists, cascade) Seq.empty[Row] } - - override val output: Seq[Attribute] = Seq.empty } /** @@ -127,7 +112,7 @@ case class DropDatabase( * ALTER (DATABASE|SCHEMA) database_name SET DBPROPERTIES (property_name=property_value, ...) 
* }}} */ -case class AlterDatabaseProperties( +case class AlterDatabasePropertiesCommand( databaseName: String, props: Map[String, String]) extends RunnableCommand { @@ -139,8 +124,6 @@ case class AlterDatabaseProperties( Seq.empty[Row] } - - override val output: Seq[Attribute] = Seq.empty } /** @@ -153,7 +136,7 @@ case class AlterDatabaseProperties( * DESCRIBE DATABASE [EXTENDED] db_name * }}} */ -case class DescribeDatabase( +case class DescribeDatabaseCommand( databaseName: String, extended: Boolean) extends RunnableCommand { @@ -194,38 +177,32 @@ case class DescribeDatabase( * DROP VIEW [IF EXISTS] [db_name.]view_name; * }}} */ -case class DropTable( +case class DropTableCommand( tableName: TableIdentifier, ifExists: Boolean, isView: Boolean) extends RunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog - if (!catalog.tableExists(tableName)) { - if (!ifExists) { - val objectName = if (isView) "View" else "Table" - logError(s"$objectName '${tableName.quotedString}' does not exist") - } - } else { - // If the command DROP VIEW is to drop a table or DROP TABLE is to drop a view - // issue an exception. - catalog.getTableMetadataOption(tableName).map(_.tableType match { - case CatalogTableType.VIEW if !isView => - throw new AnalysisException( - "Cannot drop a view with DROP TABLE. Please use DROP VIEW instead") - case o if o != CatalogTableType.VIEW && isView => - throw new AnalysisException( - s"Cannot drop a table with DROP VIEW. Please use DROP TABLE instead") - case _ => - }) - try { - sparkSession.cacheManager.tryUncacheQuery(sparkSession.table(tableName.quotedString)) - } catch { - case NonFatal(e) => log.warn(s"${e.getMessage}", e) - } - catalog.invalidateTable(tableName) - catalog.dropTable(tableName, ifExists) + // If the command DROP VIEW is to drop a table or DROP TABLE is to drop a view + // issue an exception. + catalog.getTableMetadataOption(tableName).map(_.tableType match { + case CatalogTableType.VIEW if !isView => + throw new AnalysisException( + "Cannot drop a view with DROP TABLE. Please use DROP VIEW instead") + case o if o != CatalogTableType.VIEW && isView => + throw new AnalysisException( + s"Cannot drop a table with DROP VIEW. 
Please use DROP TABLE instead") + case _ => + }) + try { + sparkSession.sharedState.cacheManager.uncacheQuery( + sparkSession.table(tableName.quotedString)) + } catch { + case NonFatal(e) => log.warn(e.toString, e) } + catalog.refreshTable(tableName) + catalog.dropTable(tableName, ifExists) Seq.empty[Row] } } @@ -239,22 +216,20 @@ case class DropTable( * ALTER VIEW view1 SET TBLPROPERTIES ('key1' = 'val1', 'key2' = 'val2', ...); * }}} */ -case class AlterTableSetProperties( +case class AlterTableSetPropertiesCommand( tableName: TableIdentifier, properties: Map[String, String], isView: Boolean) extends RunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { + val ident = if (isView) "VIEW" else "TABLE" val catalog = sparkSession.sessionState.catalog DDLUtils.verifyAlterTableType(catalog, tableName, isView) + DDLUtils.verifyTableProperties(properties.keys.toSeq, s"ALTER $ident") val table = catalog.getTableMetadata(tableName) - val newProperties = table.properties ++ properties - if (DDLUtils.isDatasourceTable(newProperties)) { - throw new AnalysisException( - "alter table properties is not supported for tables defined using the datasource API") - } - val newTable = table.copy(properties = newProperties) + // This overrides old properties + val newTable = table.copy(properties = table.properties ++ properties) catalog.alterTable(newTable) Seq.empty[Row] } @@ -270,7 +245,7 @@ case class AlterTableSetProperties( * ALTER VIEW view1 UNSET TBLPROPERTIES [IF EXISTS] ('key1', 'key2', ...); * }}} */ -case class AlterTableUnsetProperties( +case class AlterTableUnsetPropertiesCommand( tableName: TableIdentifier, propKeys: Seq[String], ifExists: Boolean, @@ -278,18 +253,16 @@ case class AlterTableUnsetProperties( extends RunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { + val ident = if (isView) "VIEW" else "TABLE" val catalog = sparkSession.sessionState.catalog DDLUtils.verifyAlterTableType(catalog, tableName, isView) + DDLUtils.verifyTableProperties(propKeys, s"ALTER $ident") val table = catalog.getTableMetadata(tableName) - if (DDLUtils.isDatasourceTable(table)) { - throw new AnalysisException( - "alter table properties is not supported for datasource tables") - } if (!ifExists) { propKeys.foreach { k => if (!table.properties.contains(k)) { throw new AnalysisException( - s"attempted to unset non-existent property '$k' in table '$tableName'") + s"Attempted to unset non-existent property '$k' in table '${table.identifier}'") } } } @@ -310,29 +283,46 @@ case class AlterTableUnsetProperties( * ALTER TABLE table [PARTITION spec] SET SERDEPROPERTIES serde_properties; * }}} */ -case class AlterTableSerDeProperties( +case class AlterTableSerDePropertiesCommand( tableName: TableIdentifier, serdeClassName: Option[String], serdeProperties: Option[Map[String, String]], - partition: Option[Map[String, String]]) + partSpec: Option[TablePartitionSpec]) extends RunnableCommand { // should never happen if we parsed things correctly require(serdeClassName.isDefined || serdeProperties.isDefined, - "alter table attempted to set neither serde class name nor serde properties") + "ALTER TABLE attempted to set neither serde class name nor serde properties") override def run(sparkSession: SparkSession): Seq[Row] = { + DDLUtils.verifyTableProperties( + serdeProperties.toSeq.flatMap(_.keys.toSeq), + "ALTER TABLE SERDEPROPERTIES") val catalog = sparkSession.sessionState.catalog val table = catalog.getTableMetadata(tableName) - // Do not support setting serde for datasource tables 
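For illustration only (made-up table name and properties; requires a SparkSession with Hive support), the kind of statements the reworked serde-properties command below accepts, including the new per-partition form:

```scala
// Assumes an active SparkSession named `spark` with Hive support enabled.
spark.sql("ALTER TABLE logs SET SERDEPROPERTIES ('field.delim' = ',')")
spark.sql("ALTER TABLE logs PARTITION (ds = '2016-06-01') SET SERDEPROPERTIES ('field.delim' = ';')")
```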
+ // For datasource tables, disallow setting serde or specifying partition + if (partSpec.isDefined && DDLUtils.isDatasourceTable(table)) { + throw new AnalysisException("Operation not allowed: ALTER TABLE SET " + + "[SERDE | SERDEPROPERTIES] for a specific partition is not supported " + + "for tables created with the datasource API") + } if (serdeClassName.isDefined && DDLUtils.isDatasourceTable(table)) { - throw new AnalysisException( - "alter table serde is not supported for datasource tables") + throw new AnalysisException("Operation not allowed: ALTER TABLE SET SERDE is " + + "not supported for tables created with the datasource API") + } + if (partSpec.isEmpty) { + val newTable = table.withNewStorage( + serde = serdeClassName.orElse(table.storage.serde), + serdeProperties = table.storage.serdeProperties ++ serdeProperties.getOrElse(Map())) + catalog.alterTable(newTable) + } else { + val spec = partSpec.get + val part = catalog.getPartition(table.identifier, spec) + val newPart = part.copy(storage = part.storage.copy( + serde = serdeClassName.orElse(part.storage.serde), + serdeProperties = part.storage.serdeProperties ++ serdeProperties.getOrElse(Map()))) + catalog.alterPartitions(table.identifier, Seq(newPart)) } - val newTable = table.withNewStorage( - serde = serdeClassName.orElse(table.storage.serde), - serdeProperties = table.storage.serdeProperties ++ serdeProperties.getOrElse(Map())) - catalog.alterTable(newTable) Seq.empty[Row] } @@ -350,7 +340,7 @@ case class AlterTableSerDeProperties( * ALTER TABLE table ADD [IF NOT EXISTS] PARTITION spec [LOCATION 'loc1'] * }}} */ -case class AlterTableAddPartition( +case class AlterTableAddPartitionCommand( tableName: TableIdentifier, partitionSpecsAndLocs: Seq[(TablePartitionSpec, Option[String])], ifNotExists: Boolean) @@ -361,13 +351,13 @@ case class AlterTableAddPartition( val table = catalog.getTableMetadata(tableName) if (DDLUtils.isDatasourceTable(table)) { throw new AnalysisException( - "alter table add partition is not allowed for tables defined using the datasource API") + "ALTER TABLE ADD PARTITION is not allowed for tables defined using the datasource API") } val parts = partitionSpecsAndLocs.map { case (spec, location) => // inherit table storage format (possibly except for location) CatalogTablePartition(spec, table.storage.copy(locationUri = location)) } - catalog.createPartitions(tableName, parts, ignoreIfExists = ifNotExists) + catalog.createPartitions(table.identifier, parts, ignoreIfExists = ifNotExists) Seq.empty[Row] } @@ -381,7 +371,7 @@ case class AlterTableAddPartition( * ALTER TABLE table PARTITION spec1 RENAME TO PARTITION spec2; * }}} */ -case class AlterTableRenamePartition( +case class AlterTableRenamePartitionCommand( tableName: TableIdentifier, oldPartition: TablePartitionSpec, newPartition: TablePartitionSpec) @@ -409,7 +399,7 @@ case class AlterTableRenamePartition( * ALTER TABLE table DROP [IF EXISTS] PARTITION spec1[, PARTITION spec2, ...] 
[PURGE]; * }}} */ -case class AlterTableDropPartition( +case class AlterTableDropPartitionCommand( tableName: TableIdentifier, specs: Seq[TablePartitionSpec], ifExists: Boolean) @@ -420,20 +410,213 @@ case class AlterTableDropPartition( val table = catalog.getTableMetadata(tableName) if (DDLUtils.isDatasourceTable(table)) { throw new AnalysisException( - "alter table drop partition is not allowed for tables defined using the datasource API") + "ALTER TABLE DROP PARTITIONS is not allowed for tables defined using the datasource API") } - catalog.dropPartitions(tableName, specs, ignoreIfNotExists = ifExists) + catalog.dropPartitions(table.identifier, specs, ignoreIfNotExists = ifExists) Seq.empty[Row] } } -case class AlterTableSetFileFormat( + +case class PartitionStatistics(numFiles: Int, totalSize: Long) + +/** + * Recover Partitions in ALTER TABLE: recover all the partition in the directory of a table and + * update the catalog. + * + * The syntax of this command is: + * {{{ + * ALTER TABLE table RECOVER PARTITIONS; + * MSCK REPAIR TABLE table; + * }}} + */ +case class AlterTableRecoverPartitionsCommand( tableName: TableIdentifier, - partitionSpec: Option[TablePartitionSpec], - fileFormat: Seq[String], - genericFormat: Option[String])(exception: ParseException) - extends UnsupportedCommand(exception) with Logging + cmd: String = "ALTER TABLE RECOVER PARTITIONS") extends RunnableCommand { + + // These are list of statistics that can be collected quickly without requiring a scan of the data + // see https://github.com/apache/hive/blob/master/ + // common/src/java/org/apache/hadoop/hive/common/StatsSetupConst.java + val NUM_FILES = "numFiles" + val TOTAL_SIZE = "totalSize" + val DDL_TIME = "transient_lastDdlTime" + + private def getPathFilter(hadoopConf: Configuration): PathFilter = { + // Dummy jobconf to get to the pathFilter defined in configuration + // It's very expensive to create a JobConf(ClassUtil.findContainingJar() is slow) + val jobConf = new JobConf(hadoopConf, this.getClass) + val pathFilter = FileInputFormat.getInputPathFilter(jobConf) + new PathFilter { + override def accept(path: Path): Boolean = { + val name = path.getName + if (name != "_SUCCESS" && name != "_temporary" && !name.startsWith(".")) { + pathFilter == null || pathFilter.accept(path) + } else { + false + } + } + } + } + + override def run(spark: SparkSession): Seq[Row] = { + val catalog = spark.sessionState.catalog + val table = catalog.getTableMetadata(tableName) + val tableIdentWithDB = table.identifier.quotedString + if (DDLUtils.isDatasourceTable(table)) { + throw new AnalysisException( + s"Operation not allowed: $cmd on datasource tables: $tableIdentWithDB") + } + if (!DDLUtils.isTablePartitioned(table)) { + throw new AnalysisException( + s"Operation not allowed: $cmd only works on partitioned tables: $tableIdentWithDB") + } + if (table.storage.locationUri.isEmpty) { + throw new AnalysisException(s"Operation not allowed: $cmd only works on table with " + + s"location provided: $tableIdentWithDB") + } + + val root = new Path(table.storage.locationUri.get) + logInfo(s"Recover all the partitions in $root") + val fs = root.getFileSystem(spark.sparkContext.hadoopConfiguration) + + val threshold = spark.conf.get("spark.rdd.parallelListingThreshold", "10").toInt + val hadoopConf = spark.sparkContext.hadoopConfiguration + val pathFilter = getPathFilter(hadoopConf) + val partitionSpecsAndLocs = scanPartitions( + spark, fs, pathFilter, root, Map(), table.partitionColumnNames.map(_.toLowerCase), threshold) + val total 
= partitionSpecsAndLocs.length + logInfo(s"Found $total partitions in $root") + + val partitionStats = if (spark.sqlContext.conf.gatherFastStats) { + gatherPartitionStats(spark, partitionSpecsAndLocs, fs, pathFilter, threshold) + } else { + GenMap.empty[String, PartitionStatistics] + } + logInfo(s"Finished to gather the fast stats for all $total partitions.") + + addPartitions(spark, table, partitionSpecsAndLocs, partitionStats) + logInfo(s"Recovered all partitions ($total).") + Seq.empty[Row] + } + + @transient private lazy val evalTaskSupport = new ForkJoinTaskSupport(new ForkJoinPool(8)) + + private def scanPartitions( + spark: SparkSession, + fs: FileSystem, + filter: PathFilter, + path: Path, + spec: TablePartitionSpec, + partitionNames: Seq[String], + threshold: Int): GenSeq[(TablePartitionSpec, Path)] = { + if (partitionNames.isEmpty) { + return Seq(spec -> path) + } + + val statuses = fs.listStatus(path, filter) + val statusPar: GenSeq[FileStatus] = + if (partitionNames.length > 1 && statuses.length > threshold || partitionNames.length > 2) { + // parallelize the list of partitions here, then we can have better parallelism later. + val parArray = statuses.par + parArray.tasksupport = evalTaskSupport + parArray + } else { + statuses + } + statusPar.flatMap { st => + val name = st.getPath.getName + if (st.isDirectory && name.contains("=")) { + val ps = name.split("=", 2) + val columnName = PartitioningUtils.unescapePathName(ps(0)).toLowerCase + // TODO: Validate the value + val value = PartitioningUtils.unescapePathName(ps(1)) + // comparing with case-insensitive, but preserve the case + if (columnName == partitionNames.head) { + scanPartitions(spark, fs, filter, st.getPath, spec ++ Map(columnName -> value), + partitionNames.drop(1), threshold) + } else { + logWarning(s"expect partition column ${partitionNames.head}, but got ${ps(0)}, ignore it") + Seq() + } + } else { + logWarning(s"ignore ${new Path(path, name)}") + Seq() + } + } + } + + private def gatherPartitionStats( + spark: SparkSession, + partitionSpecsAndLocs: GenSeq[(TablePartitionSpec, Path)], + fs: FileSystem, + pathFilter: PathFilter, + threshold: Int): GenMap[String, PartitionStatistics] = { + if (partitionSpecsAndLocs.length > threshold) { + val hadoopConf = spark.sparkContext.hadoopConfiguration + val serializableConfiguration = new SerializableConfiguration(hadoopConf) + val serializedPaths = partitionSpecsAndLocs.map(_._2.toString).toArray + + // Set the number of parallelism to prevent following file listing from generating many tasks + // in case of large #defaultParallelism. + val numParallelism = Math.min(serializedPaths.length, + Math.min(spark.sparkContext.defaultParallelism, 10000)) + // gather the fast stats for all the partitions otherwise Hive metastore will list all the + // files for all the new partitions in sequential way, which is super slow. 
+ logInfo(s"Gather the fast stats in parallel using $numParallelism tasks.") + spark.sparkContext.parallelize(serializedPaths, numParallelism) + .mapPartitions { paths => + val pathFilter = getPathFilter(serializableConfiguration.value) + paths.map(new Path(_)).map{ path => + val fs = path.getFileSystem(serializableConfiguration.value) + val statuses = fs.listStatus(path, pathFilter) + (path.toString, PartitionStatistics(statuses.length, statuses.map(_.getLen).sum)) + } + }.collectAsMap() + } else { + partitionSpecsAndLocs.map { case (_, location) => + val statuses = fs.listStatus(location, pathFilter) + (location.toString, PartitionStatistics(statuses.length, statuses.map(_.getLen).sum)) + }.toMap + } + } + + private def addPartitions( + spark: SparkSession, + table: CatalogTable, + partitionSpecsAndLocs: GenSeq[(TablePartitionSpec, Path)], + partitionStats: GenMap[String, PartitionStatistics]): Unit = { + val total = partitionSpecsAndLocs.length + var done = 0L + // Hive metastore may not have enough memory to handle millions of partitions in single RPC, + // we should split them into smaller batches. Since Hive client is not thread safe, we cannot + // do this in parallel. + val batchSize = 100 + partitionSpecsAndLocs.toIterator.grouped(batchSize).foreach { batch => + val now = System.currentTimeMillis() / 1000 + val parts = batch.map { case (spec, location) => + val params = partitionStats.get(location.toString).map { + case PartitionStatistics(numFiles, totalSize) => + // This two fast stat could prevent Hive metastore to list the files again. + Map(NUM_FILES -> numFiles.toString, + TOTAL_SIZE -> totalSize.toString, + // Workaround a bug in HiveMetastore that try to mutate a read-only parameters. + // see metastore/src/java/org/apache/hadoop/hive/metastore/HiveMetaStore.java + DDL_TIME -> now.toString) + }.getOrElse(Map.empty) + // inherit table storage format (possibly except for location) + CatalogTablePartition( + spec, + table.storage.copy(locationUri = Some(location.toUri.toString)), + params) + } + spark.sessionState.catalog.createPartitions(tableName, parts, ignoreIfExists = true) + done += parts.length + logDebug(s"Recovered ${parts.length} partitions ($done/$total so far)") + } + } +} + /** * A command that sets the location of a table or a partition. 
@@ -446,7 +629,7 @@ case class AlterTableSetFileFormat( * ALTER TABLE table_name [PARTITION partition_spec] SET LOCATION "loc"; * }}} */ -case class AlterTableSetLocation( +case class AlterTableSetLocationCommand( tableName: TableIdentifier, partitionSpec: Option[TablePartitionSpec], location: String) @@ -458,16 +641,16 @@ case class AlterTableSetLocation( partitionSpec match { case Some(spec) => // Partition spec is specified, so we set the location only for this partition - val part = catalog.getPartition(tableName, spec) + val part = catalog.getPartition(table.identifier, spec) val newPart = if (DDLUtils.isDatasourceTable(table)) { throw new AnalysisException( - "alter table set location for partition is not allowed for tables defined " + + "ALTER TABLE SET LOCATION for partition is not allowed for tables defined " + "using the datasource API") } else { part.copy(storage = part.storage.copy(locationUri = Some(location))) } - catalog.alterPartitions(tableName, Seq(newPart)) + catalog.alterPartitions(table.identifier, Seq(newPart)) case None => // No partition spec is specified, so we set the location for the table itself val newTable = @@ -482,42 +665,13 @@ case class AlterTableSetLocation( } Seq.empty[Row] } - } -case class AlterTableChangeCol( - tableName: TableIdentifier, - partitionSpec: Option[TablePartitionSpec], - oldColName: String, - newColName: String, - dataType: DataType, - comment: Option[String], - afterColName: Option[String], - restrict: Boolean, - cascade: Boolean)(exception: ParseException) - extends UnsupportedCommand(exception) with Logging - -case class AlterTableAddCol( - tableName: TableIdentifier, - partitionSpec: Option[TablePartitionSpec], - columns: StructType, - restrict: Boolean, - cascade: Boolean)(exception: ParseException) - extends UnsupportedCommand(exception) with Logging - -case class AlterTableReplaceCol( - tableName: TableIdentifier, - partitionSpec: Option[TablePartitionSpec], - columns: StructType, - restrict: Boolean, - cascade: Boolean)(exception: ParseException) - extends UnsupportedCommand(exception) with Logging - -private[sql] object DDLUtils { +object DDLUtils { def isDatasourceTable(props: Map[String, String]): Boolean = { - props.contains("spark.sql.sources.provider") + props.contains(DATASOURCE_PROVIDER) } def isDatasourceTable(table: CatalogTable): Boolean = { @@ -542,9 +696,80 @@ private[sql] object DDLUtils { case _ => }) } + + /** + * If the given table properties (or SerDe properties) contains datasource properties, + * throw an exception. + */ + def verifyTableProperties(propKeys: Seq[String], operation: String): Unit = { + val datasourceKeys = propKeys.filter(_.startsWith(DATASOURCE_PREFIX)) + if (datasourceKeys.nonEmpty) { + throw new AnalysisException(s"Operation not allowed: $operation property keys may not " + + s"start with '$DATASOURCE_PREFIX': ${datasourceKeys.mkString("[", ", ", "]")}") + } + } + def isTablePartitioned(table: CatalogTable): Boolean = { - table.partitionColumns.size > 0 || - table.properties.contains("spark.sql.sources.schema.numPartCols") + table.partitionColumns.nonEmpty || table.properties.contains(DATASOURCE_SCHEMA_NUMPARTCOLS) } -} + // A persisted data source table may not store its schema in the catalog. In this case, its schema + // will be inferred at runtime when the table is referenced. 
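A sketch of the property layout this helper reads back, using a made-up split threshold (the real limit comes from the metastore's per-property size cap):

```scala
import org.apache.spark.sql.types._

val schema = new StructType().add("a", IntegerType).add("b", StringType)
val threshold = 40 // illustrative only
val parts = schema.json.grouped(threshold).toSeq
val props: Map[String, String] =
  Map("spark.sql.sources.schema.numParts" -> parts.size.toString) ++
    parts.zipWithIndex.map { case (part, i) => s"spark.sql.sources.schema.part.$i" -> part }

// Reading it back, essentially what getSchemaFromTableProperties does:
val restored = DataType.fromJson(
  (0 until props("spark.sql.sources.schema.numParts").toInt)
    .map(i => props(s"spark.sql.sources.schema.part.$i"))
    .mkString
).asInstanceOf[StructType]
assert(restored == schema)
```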
+ def getSchemaFromTableProperties(metadata: CatalogTable): Option[StructType] = { + require(isDatasourceTable(metadata)) + val props = metadata.properties + if (props.isDefinedAt(DATASOURCE_SCHEMA)) { + // Originally, we used spark.sql.sources.schema to store the schema of a data source table. + // After SPARK-6024, we removed this flag. + // Although we are not using spark.sql.sources.schema any more, we need to still support. + props.get(DATASOURCE_SCHEMA).map(DataType.fromJson(_).asInstanceOf[StructType]) + } else { + metadata.properties.get(DATASOURCE_SCHEMA_NUMPARTS).map { numParts => + val parts = (0 until numParts.toInt).map { index => + val part = metadata.properties.get(s"$DATASOURCE_SCHEMA_PART_PREFIX$index").orNull + if (part == null) { + throw new AnalysisException( + "Could not read schema from the metastore because it is corrupted " + + s"(missing part $index of the schema, $numParts parts are expected).") + } + + part + } + // Stick all parts back to a single schema string. + DataType.fromJson(parts.mkString).asInstanceOf[StructType] + } + } + } + + private def getColumnNamesByType( + props: Map[String, String], colType: String, typeName: String): Seq[String] = { + require(isDatasourceTable(props)) + + for { + numCols <- props.get(s"spark.sql.sources.schema.num${colType.capitalize}Cols").toSeq + index <- 0 until numCols.toInt + } yield props.getOrElse( + s"$DATASOURCE_SCHEMA_PREFIX${colType}Col.$index", + throw new AnalysisException( + s"Corrupted $typeName in catalog: $numCols parts expected, but part $index is missing." + ) + ) + } + + def getPartitionColumnsFromTableProperties(metadata: CatalogTable): Seq[String] = { + getColumnNamesByType(metadata.properties, "part", "partitioning columns") + } + + def getBucketSpecFromTableProperties(metadata: CatalogTable): Option[BucketSpec] = { + if (isDatasourceTable(metadata)) { + metadata.properties.get(DATASOURCE_SCHEMA_NUMBUCKETS).map { numBuckets => + BucketSpec( + numBuckets.toInt, + getColumnNamesByType(metadata.properties, "bucket", "bucketing columns"), + getColumnNamesByType(metadata.properties, "sort", "sorting columns")) + } + } else { + None + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala index 5aa779ddeb6c7..26593d2918a6e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala @@ -19,7 +19,8 @@ package org.apache.spark.sql.execution.command import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.FunctionIdentifier -import org.apache.spark.sql.catalyst.catalog.CatalogFunction +import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, NoSuchFunctionException} +import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, FunctionResource} import org.apache.spark.sql.catalyst.expressions.{Attribute, ExpressionInfo} import org.apache.spark.sql.types.{StringType, StructField, StructType} @@ -38,12 +39,11 @@ import org.apache.spark.sql.types.{StringType, StructField, StructType} * AS className [USING JAR\FILE 'uri' [, JAR|FILE 'uri']] * }}} */ -// TODO: Use Seq[FunctionResource] instead of Seq[(String, String)] for resources. 
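A brief sketch of the typed FunctionResource that replaces the raw (type, uri) pairs in the function command below (the jar URI is made up):

```scala
import org.apache.spark.sql.catalyst.catalog.{FunctionResource, JarResource}

// One entry per USING clause item, e.g. CREATE FUNCTION ... USING JAR 'hdfs:///udfs/my-udfs.jar'
val resources = Seq(FunctionResource(JarResource, "hdfs:///udfs/my-udfs.jar"))
```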
-case class CreateFunction( +case class CreateFunctionCommand( databaseName: Option[String], functionName: String, className: String, - resources: Seq[(String, String)], + resources: Seq[FunctionResource], isTemp: Boolean) extends RunnableCommand { @@ -51,9 +51,8 @@ case class CreateFunction( val catalog = sparkSession.sessionState.catalog if (isTemp) { if (databaseName.isDefined) { - throw new AnalysisException( - s"It is not allowed to provide database name when defining a temporary function. " + - s"However, database name ${databaseName.get} is provided.") + throw new AnalysisException(s"Specifying a database in CREATE TEMPORARY FUNCTION " + + s"is not allowed: '${databaseName.get}'") } // We first load resources and then put the builder in the function registry. // Please note that it is allowed to overwrite an existing temp function. @@ -82,8 +81,8 @@ case class CreateFunction( * DESCRIBE FUNCTION [EXTENDED] upper; * }}} */ -case class DescribeFunction( - functionName: String, +case class DescribeFunctionCommand( + functionName: FunctionIdentifier, isExtended: Boolean) extends RunnableCommand { override val output: Seq[Attribute] = { @@ -93,7 +92,7 @@ case class DescribeFunction( private def replaceFunctionName(usage: String, functionName: String): String = { if (usage == null) { - "To be added." + "N/A." } else { usage.replaceAll("_FUNC_", functionName) } @@ -101,7 +100,7 @@ case class DescribeFunction( override def run(sparkSession: SparkSession): Seq[Row] = { // Hard code "<>", "!=", "between", and "case" for now as there is no corresponding functions. - functionName.toLowerCase match { + functionName.funcName.toLowerCase match { case "<>" => Row(s"Function: $functionName") :: Row(s"Usage: a <> b - Returns TRUE if a is not equal to b") :: Nil @@ -116,12 +115,13 @@ case class DescribeFunction( Row(s"Function: case") :: Row(s"Usage: CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END - " + s"When a = b, returns c; when a = d, return e; else return f") :: Nil - case _ => sparkSession.sessionState.functionRegistry.lookupFunction(functionName) match { - case Some(info) => + case _ => + try { + val info = sparkSession.sessionState.catalog.lookupFunctionInfo(functionName) val result = Row(s"Function: ${info.getName}") :: Row(s"Class: ${info.getClassName}") :: - Row(s"Usage: ${replaceFunctionName(info.getUsage(), info.getName)}") :: Nil + Row(s"Usage: ${replaceFunctionName(info.getUsage, info.getName)}") :: Nil if (isExtended) { result :+ @@ -129,9 +129,9 @@ case class DescribeFunction( } else { result } - - case None => Seq(Row(s"Function: $functionName not found.")) - } + } catch { + case _: NoSuchFunctionException => Seq(Row(s"Function: $functionName not found.")) + } } } } @@ -142,7 +142,7 @@ case class DescribeFunction( * ifExists: returns an error if the function doesn't exist, unless this is true. * isTemp: indicates if it is a temporary function. */ -case class DropFunction( +case class DropFunctionCommand( databaseName: Option[String], functionName: String, ifExists: Boolean, @@ -153,9 +153,11 @@ case class DropFunction( val catalog = sparkSession.sessionState.catalog if (isTemp) { if (databaseName.isDefined) { - throw new AnalysisException( - s"It is not allowed to provide database name when dropping a temporary function. 
" + - s"However, database name ${databaseName.get} is provided.") + throw new AnalysisException(s"Specifying a database in DROP TEMPORARY FUNCTION " + + s"is not allowed: '${databaseName.get}'") + } + if (FunctionRegistry.builtin.functionExists(functionName)) { + throw new AnalysisException(s"Cannot drop native function '$functionName'") } catalog.dropTempFunction(functionName, ifExists) } else { @@ -178,10 +180,13 @@ case class DropFunction( * For the pattern, '*' matches any sequence of characters (including no characters) and * '|' is for alternation. * For example, "show functions like 'yea*|windo*'" will return "window" and "year". - * - * TODO currently we are simply ignore the db */ -case class ShowFunctions(db: Option[String], pattern: Option[String]) extends RunnableCommand { +case class ShowFunctionsCommand( + db: Option[String], + pattern: Option[String], + showUserFunctions: Boolean, + showSystemFunctions: Boolean) extends RunnableCommand { + override val output: Seq[Attribute] = { val schema = StructType(StructField("function", StringType, nullable = false) :: Nil) schema.toAttributes @@ -194,7 +199,10 @@ case class ShowFunctions(db: Option[String], pattern: Option[String]) extends Ru val functionNames = sparkSession.sessionState.catalog .listFunctions(dbName, pattern.getOrElse("*")) - .map(_.unquotedString) + .collect { + case (f, "USER") if showUserFunctions => f.unquotedString + case (f, "SYSTEM") if showSystemFunctions => f.unquotedString + } // The session catalog caches some persistent functions in the FunctionRegistry // so there can be duplicates. functionNames.distinct.sorted.map(Row(_)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/resources.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/resources.scala index 29bcb30592516..20b08946675d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/resources.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/resources.scala @@ -17,14 +17,19 @@ package org.apache.spark.sql.execution.command +import java.io.File +import java.net.URI + +import org.apache.hadoop.fs.Path + import org.apache.spark.sql.{Row, SparkSession} -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.types.{IntegerType, StructField, StructType} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} +import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} /** * Adds a jar to the current session so it can be used (for UDFs or serdes). */ -case class AddJar(path: String) extends RunnableCommand { +case class AddJarCommand(path: String) extends RunnableCommand { override val output: Seq[Attribute] = { val schema = StructType( StructField("result", IntegerType, nullable = false) :: Nil) @@ -40,9 +45,57 @@ case class AddJar(path: String) extends RunnableCommand { /** * Adds a file to the current session so it can be used. */ -case class AddFile(path: String) extends RunnableCommand { +case class AddFileCommand(path: String) extends RunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { sparkSession.sparkContext.addFile(path) Seq.empty[Row] } } + +/** + * Returns a list of file paths that are added to resources. + * If file paths are provided, return the ones that are added to resources. 
+ */ +case class ListFilesCommand(files: Seq[String] = Seq.empty[String]) extends RunnableCommand { + override val output: Seq[Attribute] = { + AttributeReference("Results", StringType, nullable = false)() :: Nil + } + override def run(sparkSession: SparkSession): Seq[Row] = { + val fileList = sparkSession.sparkContext.listFiles() + if (files.size > 0) { + files.map { f => + val uri = new URI(f) + val schemeCorrectedPath = uri.getScheme match { + case null | "local" => new File(f).getCanonicalFile.toURI.toString + case _ => f + } + new Path(schemeCorrectedPath).toUri.toString + }.collect { + case f if fileList.contains(f) => f + }.map(Row(_)) + } else { + fileList.map(Row(_)) + } + } +} + +/** + * Returns a list of jar files that are added to resources. + * If jar files are provided, return the ones that are added to resources. + */ +case class ListJarsCommand(jars: Seq[String] = Seq.empty[String]) extends RunnableCommand { + override val output: Seq[Attribute] = { + AttributeReference("Results", StringType, nullable = false)() :: Nil + } + override def run(sparkSession: SparkSession): Seq[Row] = { + val jarList = sparkSession.sparkContext.listJars() + if (jars.nonEmpty) { + for { + jarName <- jars.map(f => new Path(f).getName) + jarPath <- jarList if jarPath.contains(jarName) + } yield Row(jarPath) + } else { + jarList.map(Row(_)) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 6078918316d9e..ad0c779ff29ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -19,24 +19,35 @@ package org.apache.spark.sql.execution.command import java.io.File import java.net.URI +import java.util.Date import scala.collection.mutable.ArrayBuffer +import scala.util.control.NonFatal +import scala.util.Try + +import org.apache.hadoop.fs.Path import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogTable, CatalogTableType} +import org.apache.spark.sql.catalyst.catalog._ +import org.apache.spark.sql.catalyst.catalog.CatalogTableType._ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} -import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan, UnaryNode} -import org.apache.spark.sql.types.{BooleanType, MetadataBuilder, StringType} +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan} +import org.apache.spark.sql.catalyst.util.quoteIdentifier +import org.apache.spark.sql.execution.command.CreateDataSourceTableUtils._ +import org.apache.spark.sql.execution.datasources.PartitioningUtils +import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -case class CreateTableAsSelectLogicalPlan( - tableDesc: CatalogTable, - child: LogicalPlan, - allowExisting: Boolean) extends UnaryNode with Command { +case class CreateHiveTableAsSelectLogicalPlan( + tableDesc: CatalogTable, + query: LogicalPlan, + allowExisting: Boolean) extends Command { - override def output: Seq[Attribute] = Seq.empty[Attribute] + override def innerChildren: Seq[QueryPlan[_]] = Seq(query) override lazy 
val resolved: Boolean = tableDesc.identifier.database.isDefined && @@ -44,11 +55,16 @@ case class CreateTableAsSelectLogicalPlan( tableDesc.storage.serde.isDefined && tableDesc.storage.inputFormat.isDefined && tableDesc.storage.outputFormat.isDefined && - childrenResolved + query.resolved } /** - * A command to create a table with the same definition of the given existing table. + * A command to create a MANAGED table with the same definition of the given existing table. + * In the target table definition, the table comment is always empty but the column comments + * are identical to the ones defined in the source table. + * + * The CatalogTable attributes copied from the source table are storage(inputFormat, outputFormat, + * serde, compressed, properties), schema, provider, partitionColumnNames, bucketSpec. * * The syntax of using this command in SQL is: * {{{ @@ -56,29 +72,61 @@ case class CreateTableAsSelectLogicalPlan( * LIKE [other_db_name.]existing_table_name * }}} */ -case class CreateTableLike( +case class CreateTableLikeCommand( targetTable: TableIdentifier, sourceTable: TableIdentifier, ifNotExists: Boolean) extends RunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog - if (!catalog.tableExists(sourceTable)) { - throw new AnalysisException( - s"Source table in CREATE TABLE LIKE does not exist: '$sourceTable'") - } - if (catalog.isTemporaryTable(sourceTable)) { - throw new AnalysisException( - s"Source table in CREATE TABLE LIKE cannot be temporary: '$sourceTable'") + val sourceTableDesc = catalog.getTempViewOrPermanentTableMetadata(sourceTable) + + if (DDLUtils.isDatasourceTable(sourceTableDesc) || + sourceTableDesc.tableType == CatalogTableType.VIEW) { + val outputSchema = + StructType(sourceTableDesc.schema.map { c => + val builder = new MetadataBuilder + c.comment.map(comment => builder.putString("comment", comment)) + StructField( + c.name, + CatalystSqlParser.parseDataType(c.dataType), + c.nullable, + metadata = builder.build()) + }) + val (schema, provider) = if (DDLUtils.isDatasourceTable(sourceTableDesc)) { + (DDLUtils.getSchemaFromTableProperties(sourceTableDesc).getOrElse(outputSchema), + sourceTableDesc.properties(CreateDataSourceTableUtils.DATASOURCE_PROVIDER)) + } else { // VIEW + (outputSchema, sparkSession.sessionState.conf.defaultDataSourceName) + } + createDataSourceTable( + sparkSession = sparkSession, + tableIdent = targetTable, + userSpecifiedSchema = Some(schema), + partitionColumns = Array.empty[String], + bucketSpec = None, + provider = provider, + options = Map("path" -> catalog.defaultTablePath(targetTable)), + isExternal = false) + } else { + val newStorage = + sourceTableDesc.storage.copy( + locationUri = None, + serdeProperties = sourceTableDesc.storage.serdeProperties) + val newTableDesc = + CatalogTable( + identifier = targetTable, + tableType = CatalogTableType.MANAGED, + storage = newStorage, + schema = sourceTableDesc.schema, + partitionColumnNames = sourceTableDesc.partitionColumnNames, + sortColumnNames = sourceTableDesc.sortColumnNames, + bucketColumnNames = sourceTableDesc.bucketColumnNames, + numBuckets = sourceTableDesc.numBuckets) + + catalog.createTable(newTableDesc, ifNotExists) } - val tableToCreate = catalog.getTableMetadata(sourceTable).copy( - identifier = targetTable, - tableType = CatalogTableType.MANAGED, - createTime = System.currentTimeMillis, - lastAccessTime = -1).withNewStorage(locationUri = None) - - catalog.createTable(tableToCreate, ifNotExists) 
Seq.empty[Row] } } @@ -108,9 +156,11 @@ case class CreateTableLike( * [AS select_statement]; * }}} */ -case class CreateTable(table: CatalogTable, ifNotExists: Boolean) extends RunnableCommand { +case class CreateTableCommand(table: CatalogTable, ifNotExists: Boolean) extends RunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { + DDLUtils.verifyTableProperties(table.properties.keys.toSeq, "CREATE TABLE") + DDLUtils.verifyTableProperties(table.storage.serdeProperties.keys.toSeq, "CREATE TABLE") sparkSession.sessionState.catalog.createTable(table, ifNotExists) Seq.empty[Row] } @@ -127,7 +177,7 @@ case class CreateTable(table: CatalogTable, ifNotExists: Boolean) extends Runnab * ALTER VIEW view1 RENAME TO view2; * }}} */ -case class AlterTableRename( +case class AlterTableRenameCommand( oldName: TableIdentifier, newName: TableIdentifier, isView: Boolean) @@ -135,9 +185,38 @@ case class AlterTableRename( override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog - DDLUtils.verifyAlterTableType(catalog, oldName, isView) - catalog.invalidateTable(oldName) - catalog.renameTable(oldName, newName) + // If this is a temp view, just rename the view. + // Otherwise, if this is a real table, we also need to uncache and invalidate the table. + if (catalog.isTemporaryTable(oldName)) { + catalog.renameTable(oldName, newName) + } else { + val table = catalog.getTableMetadata(oldName) + DDLUtils.verifyAlterTableType(catalog, table.identifier, isView) + // If an exception is thrown here we can just assume the table is uncached; + // this can happen with Hive tables when the underlying catalog is in-memory. + val wasCached = Try(sparkSession.catalog.isCached(oldName.unquotedString)).getOrElse(false) + if (wasCached) { + try { + sparkSession.catalog.uncacheTable(oldName.unquotedString) + } catch { + case NonFatal(e) => log.warn(e.toString, e) + } + } + // For datasource tables, we also need to update the "path" serde property + if (DDLUtils.isDatasourceTable(table) && table.tableType == CatalogTableType.MANAGED) { + val newPath = catalog.defaultTablePath(newName) + val newTable = table.withNewStorage( + serdeProperties = table.storage.serdeProperties ++ Map("path" -> newPath)) + catalog.alterTable(newTable) + } + // Invalidate the table last, otherwise uncaching the table would load the logical plan + // back into the hive metastore cache + catalog.refreshTable(oldName) + catalog.renameTable(oldName, newName) + if (wasCached) { + sparkSession.catalog.cacheTable(newName.unquotedString) + } + } Seq.empty[Row] } @@ -152,7 +231,7 @@ case class AlterTableRename( * [PARTITION (partcol1=val1, partcol2=val2 ...)] * }}} */ -case class LoadData( +case class LoadDataCommand( table: TableIdentifier, path: String, isLocal: Boolean, @@ -161,38 +240,34 @@ case class LoadData( override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog - if (!catalog.tableExists(table)) { - throw new AnalysisException( - s"Table in LOAD DATA does not exist: '$table'") - } - - val targetTable = catalog.getTableMetadataOption(table).getOrElse { - throw new AnalysisException( - s"Table in LOAD DATA cannot be temporary: '$table'") - } - + val targetTable = catalog.getTableMetadata(table) + val tableIdentwithDB = targetTable.identifier.quotedString if (DDLUtils.isDatasourceTable(targetTable)) { throw new AnalysisException( - "LOAD DATA is not supported for datasource tables") + s"LOAD DATA is not supported for datasource tables: 
'$tableIdentwithDB'") } - if (targetTable.partitionColumnNames.nonEmpty) { - if (partition.isEmpty || targetTable.partitionColumnNames.size != partition.get.size) { - throw new AnalysisException( - "LOAD DATA to partitioned table must specify a specific partition of " + - "the table by specifying values for all of the partitioning columns.") + if (partition.isEmpty) { + throw new AnalysisException(s"LOAD DATA target table '$tableIdentwithDB' is partitioned, " + + s"but no partition spec is provided") + } + if (targetTable.partitionColumnNames.size != partition.get.size) { + throw new AnalysisException(s"LOAD DATA target table '$tableIdentwithDB' is partitioned, " + + s"but number of columns in provided partition spec (${partition.get.size}) " + + s"does not match number of partitioned columns in table " + + s"(${targetTable.partitionColumnNames.size})") } - partition.get.keys.foreach { colName => if (!targetTable.partitionColumnNames.contains(colName)) { - throw new AnalysisException( - s"LOAD DATA to partitioned table specifies a non-existing partition column: '$colName'") + throw new AnalysisException(s"LOAD DATA target table '$tableIdentwithDB' is " + + s"partitioned, but the specified partition spec refers to a column that is " + + s"not partitioned: '$colName'") } } } else { if (partition.nonEmpty) { - throw new AnalysisException( - "LOAD DATA to non-partitioned table cannot specify partition.") + throw new AnalysisException(s"LOAD DATA target table '$tableIdentwithDB' is not " + + s"partitioned, but a partition spec was provided.") } } @@ -200,7 +275,7 @@ case class LoadData( if (isLocal) { val uri = Utils.resolveURI(path) if (!new File(uri.getPath()).exists()) { - throw new AnalysisException(s"LOAD DATA with non-existing path: $path") + throw new AnalysisException(s"LOAD DATA input path does not exist: $path") } uri } else { @@ -231,7 +306,7 @@ case class LoadData( if (scheme == null) { throw new AnalysisException( - "LOAD DATA with non-local path must specify URI Scheme.") + s"LOAD DATA: URI scheme is required for non-local input paths: '$path'") } // Follow Hive's behavior: @@ -267,13 +342,90 @@ case class LoadData( } } +/** + * A command to truncate table. + * + * The syntax of this command is: + * {{{ + * TRUNCATE TABLE tablename [PARTITION (partcol1=val1, partcol2=val2 ...)] + * }}} + */ +case class TruncateTableCommand( + tableName: TableIdentifier, + partitionSpec: Option[TablePartitionSpec]) extends RunnableCommand { + + override def run(spark: SparkSession): Seq[Row] = { + val catalog = spark.sessionState.catalog + val table = catalog.getTableMetadata(tableName) + val tableIdentwithDB = table.identifier.quotedString + if (table.tableType == CatalogTableType.EXTERNAL) { + throw new AnalysisException( + s"Operation not allowed: TRUNCATE TABLE on external tables: '$tableIdentwithDB'") + } + if (table.tableType == CatalogTableType.VIEW) { + throw new AnalysisException( + s"Operation not allowed: TRUNCATE TABLE on views: '$tableIdentwithDB'") + } + val isDatasourceTable = DDLUtils.isDatasourceTable(table) + if (isDatasourceTable && partitionSpec.isDefined) { + throw new AnalysisException( + s"Operation not allowed: TRUNCATE TABLE ... PARTITION is not supported " + + s"for tables created using the data sources API: '$tableIdentwithDB'") + } + if (table.partitionColumnNames.isEmpty && partitionSpec.isDefined) { + throw new AnalysisException( + s"Operation not allowed: TRUNCATE TABLE ... 
PARTITION is not supported " + + s"for tables that are not partitioned: '$tableIdentwithDB'") + } + val locations = + if (isDatasourceTable) { + Seq(table.storage.serdeProperties.get("path")) + } else if (table.partitionColumnNames.isEmpty) { + Seq(table.storage.locationUri) + } else { + catalog.listPartitions(table.identifier, partitionSpec).map(_.storage.locationUri) + } + val hadoopConf = spark.sessionState.newHadoopConf() + locations.foreach { location => + if (location.isDefined) { + val path = new Path(location.get) + try { + val fs = path.getFileSystem(hadoopConf) + fs.delete(path, true) + fs.mkdirs(path) + } catch { + case NonFatal(e) => + throw new AnalysisException( + s"Failed to truncate table '$tableIdentwithDB' when removing data of the path: " + + s"$path because of ${e.toString}") + } + } + } + // After deleting the data, invalidate the table to make sure we don't keep around a stale + // file relation in the metastore cache. + spark.sessionState.refreshTable(tableName.unquotedString) + // Also try to drop the contents of the table from the columnar cache + try { + spark.sharedState.cacheManager.uncacheQuery(spark.table(table.identifier)) + } catch { + case NonFatal(e) => + log.warn(s"Exception when attempting to uncache table '$tableIdentwithDB'", e) + } + Seq.empty[Row] + } +} + /** * Command that looks like * {{{ - * DESCRIBE (EXTENDED) table_name; + * DESCRIBE [EXTENDED|FORMATTED] table_name partitionSpec?; * }}} */ -case class DescribeTableCommand(table: TableIdentifier, isExtended: Boolean) +case class DescribeTableCommand( + table: TableIdentifier, + partitionSpec: Map[String, String], + isExtended: Boolean, + isFormatted: Boolean) extends RunnableCommand { override val output: Seq[Attribute] = Seq( @@ -288,31 +440,188 @@ case class DescribeTableCommand(table: TableIdentifier, isExtended: Boolean) override def run(sparkSession: SparkSession): Seq[Row] = { val result = new ArrayBuffer[Row] - sparkSession.sessionState.catalog.lookupRelation(table) match { - case catalogRelation: CatalogRelation => - catalogRelation.catalogTable.schema.foreach { column => - result += Row(column.name, column.dataType, column.comment.orNull) - } + val catalog = sparkSession.sessionState.catalog - if (catalogRelation.catalogTable.partitionColumns.nonEmpty) { - result += Row("# Partition Information", "", "") - result += Row(s"# ${output(0).name}", output(1).name, output(2).name) + if (catalog.isTemporaryTable(table)) { + if (partitionSpec.nonEmpty) { + throw new AnalysisException( + s"DESC PARTITION is not allowed on a temporary view: ${table.identifier}") + } + describeSchema(catalog.lookupRelation(table).schema, result) + } else { + val metadata = catalog.getTableMetadata(table) - catalogRelation.catalogTable.partitionColumns.foreach { col => - result += Row(col.name, col.dataType, col.comment.orNull) - } + if (DDLUtils.isDatasourceTable(metadata)) { + DDLUtils.getSchemaFromTableProperties(metadata) match { + case Some(userSpecifiedSchema) => describeSchema(userSpecifiedSchema, result) + case None => describeSchema(catalog.lookupRelation(table).schema, result) } + } else { + describeSchema(metadata.schema, result) + } + + describePartitionInfo(metadata, result) - case relation => - relation.schema.fields.foreach { field => - val comment = - if (field.metadata.contains("comment")) field.metadata.getString("comment") else "" - result += Row(field.name, field.dataType.simpleString, comment) + if (partitionSpec.isEmpty) { + if (isExtended) { + describeExtendedTableInfo(metadata, result) + } 
else if (isFormatted) { + describeFormattedTableInfo(metadata, result) } + } else { + describeDetailedPartitionInfo(catalog, metadata, result) + } } result } + + private def describePartitionInfo(table: CatalogTable, buffer: ArrayBuffer[Row]): Unit = { + if (DDLUtils.isDatasourceTable(table)) { + val userSpecifiedSchema = DDLUtils.getSchemaFromTableProperties(table) + val partColNames = DDLUtils.getPartitionColumnsFromTableProperties(table) + for (schema <- userSpecifiedSchema if partColNames.nonEmpty) { + append(buffer, "# Partition Information", "", "") + append(buffer, s"# ${output.head.name}", output(1).name, output(2).name) + describeSchema(StructType(partColNames.map(schema(_))), buffer) + } + } else { + if (table.partitionColumns.nonEmpty) { + append(buffer, "# Partition Information", "", "") + append(buffer, s"# ${output.head.name}", output(1).name, output(2).name) + describeSchema(table.partitionColumns, buffer) + } + } + } + + private def describeExtendedTableInfo(table: CatalogTable, buffer: ArrayBuffer[Row]): Unit = { + append(buffer, "", "", "") + append(buffer, "# Detailed Table Information", table.toString, "") + } + + private def describeFormattedTableInfo(table: CatalogTable, buffer: ArrayBuffer[Row]): Unit = { + append(buffer, "", "", "") + append(buffer, "# Detailed Table Information", "", "") + append(buffer, "Database:", table.database, "") + append(buffer, "Owner:", table.owner, "") + append(buffer, "Create Time:", new Date(table.createTime).toString, "") + append(buffer, "Last Access Time:", new Date(table.lastAccessTime).toString, "") + append(buffer, "Location:", table.storage.locationUri.getOrElse(""), "") + append(buffer, "Table Type:", table.tableType.name, "") + + append(buffer, "Table Parameters:", "", "") + table.properties.filterNot { + // Hides schema properties that hold user-defined schema, partition columns, and bucketing + // information since they are already extracted and shown in other parts. 
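The DESCRIBE variants handled by DescribeTableCommand can be exercised directly from SQL. Below is a minimal sketch, assuming a SparkSession built with Hive support and a hypothetical partitioned table `sales` with partition column `dt`; none of these names come from the patch itself.

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .master("local[*]")
  .appName("describe-table-sketch")
  .enableHiveSupport()
  .getOrCreate()

// `sales` and its partition column `dt` are placeholders for an existing table.
spark.sql("DESCRIBE sales").show(truncate = false)            // columns plus "# Partition Information"
spark.sql("DESCRIBE EXTENDED sales").show(truncate = false)   // adds "# Detailed Table Information"
spark.sql("DESCRIBE FORMATTED sales").show(truncate = false)  // adds the storage and table-parameter sections
spark.sql("DESCRIBE FORMATTED sales PARTITION (dt='2016-06-01')").show(truncate = false)
```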
+ case (key, _) => key.startsWith(CreateDataSourceTableUtils.DATASOURCE_SCHEMA) + }.foreach { case (key, value) => + append(buffer, s" $key", value, "") + } + + describeStorageInfo(table, buffer) + } + + private def describeStorageInfo(metadata: CatalogTable, buffer: ArrayBuffer[Row]): Unit = { + append(buffer, "", "", "") + append(buffer, "# Storage Information", "", "") + metadata.storage.serde.foreach(serdeLib => append(buffer, "SerDe Library:", serdeLib, "")) + metadata.storage.inputFormat.foreach(format => append(buffer, "InputFormat:", format, "")) + metadata.storage.outputFormat.foreach(format => append(buffer, "OutputFormat:", format, "")) + append(buffer, "Compressed:", if (metadata.storage.compressed) "Yes" else "No", "") + describeBucketingInfo(metadata, buffer) + + append(buffer, "Storage Desc Parameters:", "", "") + metadata.storage.serdeProperties.foreach { case (key, value) => + append(buffer, s" $key", value, "") + } + } + + private def describeBucketingInfo(metadata: CatalogTable, buffer: ArrayBuffer[Row]): Unit = { + def appendBucketInfo(numBuckets: Int, bucketColumns: Seq[String], sortColumns: Seq[String]) = { + append(buffer, "Num Buckets:", numBuckets.toString, "") + append(buffer, "Bucket Columns:", bucketColumns.mkString("[", ", ", "]"), "") + append(buffer, "Sort Columns:", sortColumns.mkString("[", ", ", "]"), "") + } + + DDLUtils.getBucketSpecFromTableProperties(metadata) match { + case Some(bucketSpec) => + appendBucketInfo( + bucketSpec.numBuckets, + bucketSpec.bucketColumnNames, + bucketSpec.sortColumnNames) + case None => + appendBucketInfo( + metadata.numBuckets, + metadata.bucketColumnNames, + metadata.sortColumnNames) + } + } + + private def describeDetailedPartitionInfo( + catalog: SessionCatalog, + metadata: CatalogTable, + result: ArrayBuffer[Row]): Unit = { + if (metadata.tableType == CatalogTableType.VIEW) { + throw new AnalysisException( + s"DESC PARTITION is not allowed on a view: ${table.identifier}") + } + if (DDLUtils.isDatasourceTable(metadata)) { + throw new AnalysisException( + s"DESC PARTITION is not allowed on a datasource table: ${table.identifier}") + } + val partition = catalog.getPartition(table, partitionSpec) + if (isExtended) { + describeExtendedDetailedPartitionInfo(table, metadata, partition, result) + } else if (isFormatted) { + describeFormattedDetailedPartitionInfo(table, metadata, partition, result) + describeStorageInfo(metadata, result) + } + } + + private def describeExtendedDetailedPartitionInfo( + tableIdentifier: TableIdentifier, + table: CatalogTable, + partition: CatalogTablePartition, + buffer: ArrayBuffer[Row]): Unit = { + append(buffer, "", "", "") + append(buffer, "Detailed Partition Information " + partition.toString, "", "") + } + + private def describeFormattedDetailedPartitionInfo( + tableIdentifier: TableIdentifier, + table: CatalogTable, + partition: CatalogTablePartition, + buffer: ArrayBuffer[Row]): Unit = { + append(buffer, "", "", "") + append(buffer, "# Detailed Partition Information", "", "") + append(buffer, "Partition Value:", s"[${partition.spec.values.mkString(", ")}]", "") + append(buffer, "Database:", table.database, "") + append(buffer, "Table:", tableIdentifier.table, "") + append(buffer, "Location:", partition.storage.locationUri.getOrElse(""), "") + append(buffer, "Partition Parameters:", "", "") + partition.parameters.foreach { case (key, value) => + append(buffer, s" $key", value, "") + } + } + + private def describeSchema(schema: Seq[CatalogColumn], buffer: ArrayBuffer[Row]): Unit = { + 
schema.foreach { column => + append(buffer, column.name, column.dataType.toLowerCase, column.comment.orNull) + } + } + + private def describeSchema(schema: StructType, buffer: ArrayBuffer[Row]): Unit = { + schema.foreach { column => + val comment = + if (column.metadata.contains("comment")) column.metadata.getString("comment") else null + append(buffer, column.name, column.dataType.simpleString, comment) + } + } + + private def append( + buffer: ArrayBuffer[Row], column: String, dataType: String, comment: String): Unit = { + buffer += Row(column, dataType, comment) + } } @@ -389,3 +698,315 @@ case class ShowTablePropertiesCommand(table: TableIdentifier, propertyKey: Optio } } } + +/** + * A command to list the column names for a table. This function creates a + * [[ShowColumnsCommand]] logical plan. + * + * The syntax of using this command in SQL is: + * {{{ + * SHOW COLUMNS (FROM | IN) table_identifier [(FROM | IN) database]; + * }}} + */ +case class ShowColumnsCommand(tableName: TableIdentifier) extends RunnableCommand { + // The result of SHOW COLUMNS has one column called 'result' + override val output: Seq[Attribute] = { + AttributeReference("result", StringType, nullable = false)() :: Nil + } + + override def run(sparkSession: SparkSession): Seq[Row] = { + val catalog = sparkSession.sessionState.catalog + val table = catalog.getTempViewOrPermanentTableMetadata(tableName) + table.schema.map { c => + Row(c.name) + } + } +} + +/** + * A command to list the partition names of a table. If the partition spec is specified, + * partitions that match the spec are returned. [[AnalysisException]] exception is thrown under + * the following conditions: + * + * 1. If the command is called for a non partitioned table. + * 2. If the partition spec refers to the columns that are not defined as partitioning columns. + * + * This function creates a [[ShowPartitionsCommand]] logical plan + * + * The syntax of using this command in SQL is: + * {{{ + * SHOW PARTITIONS [db_name.]table_name [PARTITION(partition_spec)] + * }}} + */ +case class ShowPartitionsCommand( + tableName: TableIdentifier, + spec: Option[TablePartitionSpec]) extends RunnableCommand { + // The result of SHOW PARTITIONS has one column called 'result' + override val output: Seq[Attribute] = { + AttributeReference("result", StringType, nullable = false)() :: Nil + } + + private def getPartName(spec: TablePartitionSpec, partColNames: Seq[String]): String = { + partColNames.map { name => + PartitioningUtils.escapePathName(name) + "=" + PartitioningUtils.escapePathName(spec(name)) + }.mkString(File.separator) + } + + override def run(sparkSession: SparkSession): Seq[Row] = { + val catalog = sparkSession.sessionState.catalog + val table = catalog.getTableMetadata(tableName) + val tableIdentWithDB = table.identifier.quotedString + + /** + * Validate and throws an [[AnalysisException]] exception under the following conditions: + * 1. If the table is not partitioned. + * 2. If it is a datasource table. + * 3. If it is a view. 
+ */ + if (table.tableType == VIEW) { + throw new AnalysisException(s"SHOW PARTITIONS is not allowed on a view: $tableIdentWithDB") + } + + if (!DDLUtils.isTablePartitioned(table)) { + throw new AnalysisException( + s"SHOW PARTITIONS is not allowed on a table that is not partitioned: $tableIdentWithDB") + } + + if (DDLUtils.isDatasourceTable(table)) { + throw new AnalysisException( + s"SHOW PARTITIONS is not allowed on a datasource table: $tableIdentWithDB") + } + + /** + * Validate the partitioning spec by making sure all the referenced columns are + * defined as partitioning columns in table definition. An AnalysisException exception is + * thrown if the partitioning spec is invalid. + */ + if (spec.isDefined) { + val badColumns = spec.get.keySet.filterNot(table.partitionColumns.map(_.name).contains) + if (badColumns.nonEmpty) { + val badCols = badColumns.mkString("[", ", ", "]") + throw new AnalysisException( + s"Non-partitioning column(s) $badCols are specified for SHOW PARTITIONS") + } + } + + val partNames = catalog.listPartitions(tableName, spec).map { p => + getPartName(p.spec, table.partitionColumnNames) + } + + partNames.map(Row(_)) + } +} + +case class ShowCreateTableCommand(table: TableIdentifier) extends RunnableCommand { + override val output: Seq[Attribute] = Seq( + AttributeReference("createtab_stmt", StringType, nullable = false)() + ) + + override def run(sparkSession: SparkSession): Seq[Row] = { + val catalog = sparkSession.sessionState.catalog + val tableMetadata = catalog.getTableMetadata(table) + + val stmt = if (DDLUtils.isDatasourceTable(tableMetadata)) { + showCreateDataSourceTable(tableMetadata) + } else { + showCreateHiveTable(tableMetadata) + } + + Seq(Row(stmt)) + } + + private def showCreateHiveTable(metadata: CatalogTable): String = { + def reportUnsupportedError(features: Seq[String]): Unit = { + throw new AnalysisException( + s"Failed to execute SHOW CREATE TABLE against table ${metadata.identifier.quotedString}, " + + "which is created by Hive and uses the following unsupported feature(s)\n" + + features.map(" - " + _).mkString("\n") + ) + } + + if (metadata.unsupportedFeatures.nonEmpty) { + reportUnsupportedError(metadata.unsupportedFeatures) + } + + val builder = StringBuilder.newBuilder + + val tableTypeString = metadata.tableType match { + case EXTERNAL => " EXTERNAL TABLE" + case VIEW => " VIEW" + case MANAGED => " TABLE" + } + + builder ++= s"CREATE$tableTypeString ${table.quotedString}" + + if (metadata.tableType == VIEW) { + if (metadata.schema.nonEmpty) { + builder ++= metadata.schema.map(_.name).mkString("(", ", ", ")") + } + builder ++= metadata.viewText.mkString(" AS\n", "", "\n") + } else { + showHiveTableHeader(metadata, builder) + showHiveTableNonDataColumns(metadata, builder) + showHiveTableStorageInfo(metadata, builder) + showHiveTableProperties(metadata, builder) + } + + builder.toString() + } + + private def showHiveTableHeader(metadata: CatalogTable, builder: StringBuilder): Unit = { + val columns = metadata.schema.filterNot { column => + metadata.partitionColumnNames.contains(column.name) + }.map(columnToDDLFragment) + + if (columns.nonEmpty) { + builder ++= columns.mkString("(", ", ", ")\n") + } + + metadata + .comment + .map("COMMENT '" + escapeSingleQuotedString(_) + "'\n") + .foreach(builder.append) + } + + private def columnToDDLFragment(column: CatalogColumn): String = { + val comment = column.comment.map(escapeSingleQuotedString).map(" COMMENT '" + _ + "'") + s"${quoteIdentifier(column.name)} 
${column.dataType}${comment.getOrElse("")}" + } + + private def showHiveTableNonDataColumns(metadata: CatalogTable, builder: StringBuilder): Unit = { + if (metadata.partitionColumns.nonEmpty) { + val partCols = metadata.partitionColumns.map(columnToDDLFragment) + builder ++= partCols.mkString("PARTITIONED BY (", ", ", ")\n") + } + + if (metadata.bucketColumnNames.nonEmpty) { + throw new UnsupportedOperationException( + "Creating Hive table with bucket spec is not supported yet.") + } + } + + private def showHiveTableStorageInfo(metadata: CatalogTable, builder: StringBuilder): Unit = { + val storage = metadata.storage + + storage.serde.foreach { serde => + builder ++= s"ROW FORMAT SERDE '$serde'\n" + + val serdeProps = metadata.storage.serdeProperties.map { + case (key, value) => + s"'${escapeSingleQuotedString(key)}' = '${escapeSingleQuotedString(value)}'" + } + + builder ++= serdeProps.mkString("WITH SERDEPROPERTIES (\n ", ",\n ", "\n)\n") + } + + if (storage.inputFormat.isDefined || storage.outputFormat.isDefined) { + builder ++= "STORED AS\n" + + storage.inputFormat.foreach { format => + builder ++= s" INPUTFORMAT '${escapeSingleQuotedString(format)}'\n" + } + + storage.outputFormat.foreach { format => + builder ++= s" OUTPUTFORMAT '${escapeSingleQuotedString(format)}'\n" + } + } + + if (metadata.tableType == EXTERNAL) { + storage.locationUri.foreach { uri => + builder ++= s"LOCATION '$uri'\n" + } + } + } + + private def showHiveTableProperties(metadata: CatalogTable, builder: StringBuilder): Unit = { + if (metadata.properties.nonEmpty) { + val filteredProps = metadata.properties.filterNot { + // Skips "EXTERNAL" property for external tables + case (key, _) => key == "EXTERNAL" && metadata.tableType == EXTERNAL + } + + val props = filteredProps.map { case (key, value) => + s"'${escapeSingleQuotedString(key)}' = '${escapeSingleQuotedString(value)}'" + } + + if (props.nonEmpty) { + builder ++= props.mkString("TBLPROPERTIES (\n ", ",\n ", "\n)\n") + } + } + } + + private def showCreateDataSourceTable(metadata: CatalogTable): String = { + val builder = StringBuilder.newBuilder + + builder ++= s"CREATE TABLE ${table.quotedString} " + showDataSourceTableDataColumns(metadata, builder) + showDataSourceTableOptions(metadata, builder) + showDataSourceTableNonDataColumns(metadata, builder) + + builder.toString() + } + + private def showDataSourceTableDataColumns( + metadata: CatalogTable, builder: StringBuilder): Unit = { + DDLUtils.getSchemaFromTableProperties(metadata).foreach { schema => + val columns = schema.fields.map(f => s"${quoteIdentifier(f.name)} ${f.dataType.sql}") + builder ++= columns.mkString("(", ", ", ")") + } + + builder ++= "\n" + } + + private def showDataSourceTableOptions(metadata: CatalogTable, builder: StringBuilder): Unit = { + val props = metadata.properties + + builder ++= s"USING ${props(CreateDataSourceTableUtils.DATASOURCE_PROVIDER)}\n" + + val dataSourceOptions = metadata.storage.serdeProperties.filterNot { + case (key, value) => + // If it's a managed table, omit PATH option. Spark SQL always creates external table + // when the table creation DDL contains the PATH option. 
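For data source tables, the generated DDL is reassembled from the table properties and serde properties handled above, with the "path" option dropped for managed tables. A rough usage sketch, assuming the SparkSession `spark` from the earlier sketch and a hypothetical table name:

```scala
// A managed data source table; its "path" option is omitted from the generated DDL.
spark.sql("CREATE TABLE events (id BIGINT, payload STRING) USING parquet")
spark.sql("SHOW CREATE TABLE events").collect().foreach(row => println(row.getString(0)))
```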
+ key.toLowerCase == "path" && metadata.tableType == MANAGED + }.map { + case (key, value) => s"${quoteIdentifier(key)} '${escapeSingleQuotedString(value)}'" + } + + if (dataSourceOptions.nonEmpty) { + builder ++= "OPTIONS (\n" + builder ++= dataSourceOptions.mkString(" ", ",\n ", "\n") + builder ++= ")\n" + } + } + + private def showDataSourceTableNonDataColumns( + metadata: CatalogTable, builder: StringBuilder): Unit = { + val partCols = DDLUtils.getPartitionColumnsFromTableProperties(metadata) + if (partCols.nonEmpty) { + builder ++= s"PARTITIONED BY ${partCols.mkString("(", ", ", ")")}\n" + } + + DDLUtils.getBucketSpecFromTableProperties(metadata).foreach { spec => + if (spec.bucketColumnNames.nonEmpty) { + builder ++= s"CLUSTERED BY ${spec.bucketColumnNames.mkString("(", ", ", ")")}\n" + + if (spec.sortColumnNames.nonEmpty) { + builder ++= s"SORTED BY ${spec.sortColumnNames.mkString("(", ", ", ")")}\n" + } + + builder ++= s"INTO ${spec.numBuckets} BUCKETS\n" + } + } + } + + private def escapeSingleQuotedString(str: String): String = { + val builder = StringBuilder.newBuilder + + str.foreach { + case '\'' => builder ++= s"\\\'" + case ch => builder += ch + } + + builder.toString() + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala index 1641780db8bc2..125b3d1b05878 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala @@ -20,9 +20,10 @@ package org.apache.spark.sql.execution.command import scala.util.control.NonFatal import org.apache.spark.sql.{AnalysisException, Row, SparkSession} -import org.apache.spark.sql.catalyst.SQLBuilder +import org.apache.spark.sql.catalyst.{SQLBuilder, TableIdentifier} import org.apache.spark.sql.catalyst.catalog.{CatalogColumn, CatalogTable, CatalogTableType} -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute} +import org.apache.spark.sql.catalyst.expressions.Alias +import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} @@ -37,63 +38,114 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} * already exists, throws analysis exception. * @param replace if true, and if the view already exists, updates it; if false, and if the view * already exists, throws analysis exception. - * @param sql the original sql + * @param isTemporary if true, the view is created as a temporary view. Temporary views are dropped + * at the end of current Spark session. Existing permanent relations with the same + * name are not visible to the current session while the temporary view exists, + * unless they are specified with full qualified table name with database prefix. */ case class CreateViewCommand( tableDesc: CatalogTable, child: LogicalPlan, allowExisting: Boolean, replace: Boolean, - sql: String) + isTemporary: Boolean) extends RunnableCommand { + override protected def innerChildren: Seq[QueryPlan[_]] = Seq(child) + // TODO: Note that this class can NOT canonicalize the view SQL string entirely, which is // different from Hive and may not work for some cases like create view on self join. 
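A short sketch of the view-creation paths described above, assuming the SparkSession `spark` from the earlier sketch and a hypothetical source table `people`:

```scala
// Temporary view: visible only in the current session; a database prefix is rejected,
// and IF NOT EXISTS is not allowed for temporary views.
spark.sql("CREATE OR REPLACE TEMPORARY VIEW adults AS SELECT name, age FROM people WHERE age >= 18")

// Permanent view: CREATE OR REPLACE updates the definition if the view already exists,
// while plain CREATE VIEW fails with an AnalysisException in that case.
spark.sql("CREATE OR REPLACE VIEW adults_v AS SELECT name FROM people WHERE age >= 18")
```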
- override def output: Seq[Attribute] = Seq.empty[Attribute] + require(tableDesc.tableType == CatalogTableType.VIEW, + "The type of the table to be created with CREATE VIEW must be 'CatalogTableType.VIEW'.") + if (!isTemporary) { + require(tableDesc.viewText.isDefined, + "The table to be created with CREATE VIEW must have 'viewText'.") + } - require(tableDesc.tableType == CatalogTableType.VIEW) - require(tableDesc.viewText.isDefined) + if (allowExisting && replace) { + throw new AnalysisException("CREATE VIEW with both IF NOT EXISTS and REPLACE is not allowed.") + } - private val tableIdentifier = tableDesc.identifier + // Disallows 'CREATE TEMPORARY VIEW IF NOT EXISTS' to be consistent with 'CREATE TEMPORARY TABLE' + if (allowExisting && isTemporary) { + throw new AnalysisException( + "It is not allowed to define a TEMPORARY view with IF NOT EXISTS.") + } - if (allowExisting && replace) { + // Temporary view names should NOT contain database prefix like "database.table" + if (isTemporary && tableDesc.identifier.database.isDefined) { + val database = tableDesc.identifier.database.get throw new AnalysisException( - "It is not allowed to define a view with both IF NOT EXISTS and OR REPLACE.") + s"It is not allowed to add database prefix `$database` for the TEMPORARY view name.") } override def run(sparkSession: SparkSession): Seq[Row] = { // If the plan cannot be analyzed, throw an exception and don't proceed. - val qe = sparkSession.executePlan(child) + val qe = sparkSession.sessionState.executePlan(child) qe.assertAnalyzed() val analyzedPlan = qe.analyzed - require(tableDesc.schema == Nil || tableDesc.schema.length == analyzedPlan.output.length) + if (tableDesc.schema != Nil && tableDesc.schema.length != analyzedPlan.output.length) { + throw new AnalysisException(s"The number of columns produced by the SELECT clause " + + s"(num: `${analyzedPlan.output.length}`) does not match the number of column names " + + s"specified by CREATE VIEW (num: `${tableDesc.schema.length}`).") + } val sessionState = sparkSession.sessionState - if (sessionState.catalog.tableExists(tableIdentifier)) { - if (allowExisting) { - // Handles `CREATE VIEW IF NOT EXISTS v0 AS SELECT ...`. Does nothing when the target view - // already exists. - } else if (replace) { - // Handles `CREATE OR REPLACE VIEW v0 AS SELECT ...` - sessionState.catalog.alterTable(prepareTable(sparkSession, analyzedPlan)) + if (isTemporary) { + createTemporaryView(tableDesc.identifier, sparkSession, analyzedPlan) + } else { + // Adds default database for permanent table if it doesn't exist, so that tableExists() + // only checks permanent tables. + val database = tableDesc.identifier.database.getOrElse( + sessionState.catalog.getCurrentDatabase) + val tableIdentifier = tableDesc.identifier.copy(database = Option(database)) + + if (sessionState.catalog.tableExists(tableIdentifier)) { + if (allowExisting) { + // Handles `CREATE VIEW IF NOT EXISTS v0 AS SELECT ...`. Does nothing when the target view + // already exists. + } else if (replace) { + // Handles `CREATE OR REPLACE VIEW v0 AS SELECT ...` + sessionState.catalog.alterTable(prepareTable(sparkSession, analyzedPlan)) + } else { + // Handles `CREATE VIEW v0 AS SELECT ...`. Throws exception when the target view already + // exists. + throw new AnalysisException( + s"View $tableIdentifier already exists. If you want to update the view definition, " + + "please use ALTER VIEW AS or CREATE OR REPLACE VIEW AS") + } } else { - // Handles `CREATE VIEW v0 AS SELECT ...`. 
Throws exception when the target view already - // exists. - throw new AnalysisException(s"View $tableIdentifier already exists. " + - "If you want to update the view definition, please use ALTER VIEW AS or " + - "CREATE OR REPLACE VIEW AS") + // Create the view if it doesn't exist. + sessionState.catalog.createTable( + prepareTable(sparkSession, analyzedPlan), ignoreIfExists = false) } - } else { - // Create the view if it doesn't exist. - sessionState.catalog.createTable( - prepareTable(sparkSession, analyzedPlan), ignoreIfExists = false) } - Seq.empty[Row] } + private def createTemporaryView( + table: TableIdentifier, sparkSession: SparkSession, analyzedPlan: LogicalPlan): Unit = { + + val sessionState = sparkSession.sessionState + val catalog = sessionState.catalog + + // Projects column names to alias names + val logicalPlan = { + if (tableDesc.schema.isEmpty) { + analyzedPlan + } else { + val projectList = analyzedPlan.output.zip(tableDesc.schema).map { + case (attr, col) => Alias(attr, col.name)() + } + sparkSession.sessionState.executePlan(Project(projectList, analyzedPlan)).analyzed + } + } + + catalog.createTempView(table.table, logicalPlan, replace) + } + /** * Returns a [[CatalogTable]] that can be used to save in the catalog. This comment canonicalize * SQL based on the analyzed plan, and also creates the proper schema for the view. @@ -108,7 +160,7 @@ case class CreateViewCommand( val projectList = analyzedPlan.output.zip(tableDesc.schema).map { case (attr, col) => Alias(attr, col.name)() } - sparkSession.executePlan(Project(projectList, analyzedPlan)).analyzed + sparkSession.sessionState.executePlan(Project(projectList, analyzedPlan)).analyzed } new SQLBuilder(logicalPlan).toSQL } else { @@ -137,8 +189,7 @@ case class CreateViewCommand( sparkSession.sql(viewSQL).queryExecution.assertAnalyzed() } catch { case NonFatal(e) => - throw new RuntimeException( - "Failed to analyze the canonicalized SQL. It is possible there is a bug in Spark.", e) + throw new RuntimeException(s"Failed to analyze the canonicalized SQL: $viewSQL", e) } val viewSchema: Seq[CatalogColumn] = { @@ -159,3 +210,68 @@ case class CreateViewCommand( /** Escape backtick with double-backtick in column name and wrap it with backtick. */ private def quote(name: String) = s"`${name.replaceAll("`", "``")}`" } + +/** + * Alter a view with given query plan. If the view name contains database prefix, this command will + * alter a permanent view matching the given name, or throw an exception if view not exist. Else, + * this command will try to alter a temporary view first, if view not exist, try permanent view + * next, if still not exist, throw an exception. + * + * @param tableDesc the catalog table + * @param query the logical plan that represents the view; this is used to generate a canonicalized + * version of the SQL that can be saved in the catalog. + */ +case class AlterViewAsCommand( + tableDesc: CatalogTable, + query: LogicalPlan) extends RunnableCommand { + + override protected def innerChildren: Seq[QueryPlan[_]] = Seq(query) + + override def run(session: SparkSession): Seq[Row] = { + // If the plan cannot be analyzed, throw an exception and don't proceed. 
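AlterViewAsCommand follows the same temporary-first resolution: a temporary view with the given name is replaced in place, otherwise the permanent view metadata is altered. A sketch continuing the hypothetical names above, assuming the same SparkSession `spark`:

```scala
// Redefines the permanent view; if `adults_v` were a temporary view, the temp view
// would be replaced in the session instead of the catalog entry being altered.
spark.sql("ALTER VIEW adults_v AS SELECT name, age FROM people WHERE age >= 21")
```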
+ val qe = session.sessionState.executePlan(query) + qe.assertAnalyzed() + val analyzedPlan = qe.analyzed + + if (session.sessionState.catalog.isTemporaryTable(tableDesc.identifier)) { + session.sessionState.catalog.createTempView( + tableDesc.identifier.table, + analyzedPlan, + overrideIfExists = true) + } else { + alterPermanentView(session, analyzedPlan) + } + + Seq.empty[Row] + } + + private def alterPermanentView(session: SparkSession, analyzedPlan: LogicalPlan): Unit = { + val viewMeta = session.sessionState.catalog.getTableMetadata(tableDesc.identifier) + if (viewMeta.tableType != CatalogTableType.VIEW) { + throw new AnalysisException(s"${viewMeta.identifier} is not a view.") + } + + val viewSQL: String = new SQLBuilder(analyzedPlan).toSQL + // Validate the view SQL - make sure we can parse it and analyze it. + // If we cannot analyze the generated query, there is probably a bug in SQL generation. + try { + session.sql(viewSQL).queryExecution.assertAnalyzed() + } catch { + case NonFatal(e) => + throw new RuntimeException(s"Failed to analyze the canonicalized SQL: $viewSQL", e) + } + + val viewSchema: Seq[CatalogColumn] = { + analyzedPlan.output.map { a => + CatalogColumn(a.name, a.dataType.catalogString) + } + } + + val updatedViewMeta = viewMeta.copy( + schema = viewSchema, + viewOriginalText = tableDesc.viewOriginalText, + viewText = Some(viewSQL)) + + session.sessionState.catalog.alterTable(updatedViewMeta) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 63dc1fd71e6d8..9d2fee8b75976 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -1,23 +1,23 @@ /* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ package org.apache.spark.sql.execution.datasources -import java.util.ServiceLoader +import java.util.{ServiceConfigurationError, ServiceLoader} import scala.collection.JavaConverters._ import scala.language.{existentials, implicitConversions} @@ -30,8 +30,15 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat +import org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider +import org.apache.spark.sql.execution.datasources.json.JsonFileFormat +import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ +import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.{CalendarIntervalType, StructType} import org.apache.spark.util.Utils @@ -67,21 +74,48 @@ case class DataSource( bucketSpec: Option[BucketSpec] = None, options: Map[String, String] = Map.empty) extends Logging { - case class SourceInfo(name: String, schema: StructType) + case class SourceInfo(name: String, schema: StructType, partitionColumns: Seq[String]) lazy val providingClass: Class[_] = lookupDataSource(className) lazy val sourceInfo = sourceSchema() /** A map to maintain backward compatibility in case we move data sources around. */ - private val backwardCompatibilityMap = Map( - "org.apache.spark.sql.jdbc" -> classOf[jdbc.DefaultSource].getCanonicalName, - "org.apache.spark.sql.jdbc.DefaultSource" -> classOf[jdbc.DefaultSource].getCanonicalName, - "org.apache.spark.sql.json" -> classOf[json.DefaultSource].getCanonicalName, - "org.apache.spark.sql.json.DefaultSource" -> classOf[json.DefaultSource].getCanonicalName, - "org.apache.spark.sql.parquet" -> classOf[parquet.DefaultSource].getCanonicalName, - "org.apache.spark.sql.parquet.DefaultSource" -> classOf[parquet.DefaultSource].getCanonicalName, - "com.databricks.spark.csv" -> classOf[csv.DefaultSource].getCanonicalName - ) + private val backwardCompatibilityMap: Map[String, String] = { + val jdbc = classOf[JdbcRelationProvider].getCanonicalName + val json = classOf[JsonFileFormat].getCanonicalName + val parquet = classOf[ParquetFileFormat].getCanonicalName + val csv = classOf[CSVFileFormat].getCanonicalName + val libsvm = "org.apache.spark.ml.source.libsvm.LibSVMFileFormat" + val orc = "org.apache.spark.sql.hive.orc.OrcFileFormat" + + Map( + "org.apache.spark.sql.jdbc" -> jdbc, + "org.apache.spark.sql.jdbc.DefaultSource" -> jdbc, + "org.apache.spark.sql.execution.datasources.jdbc.DefaultSource" -> jdbc, + "org.apache.spark.sql.execution.datasources.jdbc" -> jdbc, + "org.apache.spark.sql.json" -> json, + "org.apache.spark.sql.json.DefaultSource" -> json, + "org.apache.spark.sql.execution.datasources.json" -> json, + "org.apache.spark.sql.execution.datasources.json.DefaultSource" -> json, + "org.apache.spark.sql.parquet" -> parquet, + "org.apache.spark.sql.parquet.DefaultSource" -> parquet, + "org.apache.spark.sql.execution.datasources.parquet" -> parquet, + "org.apache.spark.sql.execution.datasources.parquet.DefaultSource" -> parquet, + "org.apache.spark.sql.hive.orc.DefaultSource" -> orc, + "org.apache.spark.sql.hive.orc" -> orc, + "org.apache.spark.ml.source.libsvm.DefaultSource" -> libsvm, + "org.apache.spark.ml.source.libsvm" 
-> libsvm, + "com.databricks.spark.csv" -> csv + ) + } + + /** + * Class that were removed in Spark 2.0. Used to detect incompatibility libraries for Spark 2.0. + */ + private val spark2RemovedClasses = Set( + "org.apache.spark.sql.DataFrame", + "org.apache.spark.sql.sources.HadoopFsRelationProvider", + "org.apache.spark.Logging") /** Given a provider name, look up the data source class definition. */ private def lookupDataSource(provider0: String): Class[_] = { @@ -90,44 +124,72 @@ case class DataSource( val loader = Utils.getContextOrSparkClassLoader val serviceLoader = ServiceLoader.load(classOf[DataSourceRegister], loader) - serviceLoader.asScala.filter(_.shortName().equalsIgnoreCase(provider)).toList match { - // the provider format did not match any given registered aliases - case Nil => - Try(loader.loadClass(provider)).orElse(Try(loader.loadClass(provider2))) match { - case Success(dataSource) => - // Found the data source using fully qualified path - dataSource - case Failure(error) => - if (provider.startsWith("org.apache.spark.sql.hive.orc")) { - throw new ClassNotFoundException( - "The ORC data source must be used with Hive support enabled.", error) - } else { - if (provider == "avro" || provider == "com.databricks.spark.avro") { - throw new ClassNotFoundException( - s"Failed to find data source: $provider. Please use Spark package " + - "http://spark-packages.org/package/databricks/spark-avro", - error) + try { + serviceLoader.asScala.filter(_.shortName().equalsIgnoreCase(provider)).toList match { + // the provider format did not match any given registered aliases + case Nil => + try { + Try(loader.loadClass(provider)).orElse(Try(loader.loadClass(provider2))) match { + case Success(dataSource) => + // Found the data source using fully qualified path + dataSource + case Failure(error) => + if (provider.toLowerCase == "orc" || + provider.startsWith("org.apache.spark.sql.hive.orc")) { + throw new AnalysisException( + "The ORC data source must be used with Hive support enabled") + } else if (provider.toLowerCase == "avro" || + provider == "com.databricks.spark.avro") { + throw new AnalysisException( + s"Failed to find data source: ${provider.toLowerCase}. Please find an Avro " + + "package at " + + "https://cwiki.apache.org/confluence/display/SPARK/Third+Party+Projects") + } else { + throw new ClassNotFoundException( + s"Failed to find data source: $provider. Please find packages at " + + "https://cwiki.apache.org/confluence/display/SPARK/Third+Party+Projects", + error) + } + } + } catch { + case e: NoClassDefFoundError => // This one won't be caught by Scala NonFatal + // NoClassDefFoundError's class name uses "/" rather than "." for packages + val className = e.getMessage.replaceAll("/", ".") + if (spark2RemovedClasses.contains(className)) { + throw new ClassNotFoundException(s"$className was removed in Spark 2.0. " + + "Please check if your library is compatible with Spark 2.0", e) } else { - throw new ClassNotFoundException( - s"Failed to find data source: $provider. 
Please find packages at " + - "http://spark-packages.org", - error) + throw e } - } + } + case head :: Nil => + // there is exactly one registered alias + head.getClass + case sources => + // There are multiple registered aliases for the input + sys.error(s"Multiple sources found for $provider " + + s"(${sources.map(_.getClass.getName).mkString(", ")}), " + + "please specify the fully qualified class name.") + } + } catch { + case e: ServiceConfigurationError if e.getCause.isInstanceOf[NoClassDefFoundError] => + // NoClassDefFoundError's class name uses "/" rather than "." for packages + val className = e.getCause.getMessage.replaceAll("/", ".") + if (spark2RemovedClasses.contains(className)) { + throw new ClassNotFoundException(s"Detected an incompatible DataSourceRegister. " + + "Please remove the incompatible library from classpath or upgrade it. " + + s"Error: ${e.getMessage}", e) + } else { + throw e } - case head :: Nil => - // there is exactly one registered alias - head.getClass - case sources => - // There are multiple registered aliases for the input - sys.error(s"Multiple sources found for $provider " + - s"(${sources.map(_.getClass.getName).mkString(", ")}), " + - "please specify the fully qualified class name.") } } - private def inferFileFormatSchema(format: FileFormat): StructType = { - userSpecifiedSchema.orElse { + /** + * Infer the schema of the given FileFormat, returns a pair of schema and partition column names. + */ + private def inferFileFormatSchema(format: FileFormat): (StructType, Seq[String]) = { + userSpecifiedSchema.map(_ -> partitionColumns).orElse { val caseInsensitiveOptions = new CaseInsensitiveMap(options) val allPaths = caseInsensitiveOptions.get("path") val globbedPaths = allPaths.toSeq.flatMap { path => @@ -136,11 +198,16 @@ case class DataSource( val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) SparkHadoopUtil.get.globPathIfNecessary(qualified) }.toArray - val fileCatalog: FileCatalog = new HDFSFileCatalog(sparkSession, options, globbedPaths, None) - format.inferSchema( + val fileCatalog = new ListingFileCatalog(sparkSession, globbedPaths, options, None) + val partitionSchema = fileCatalog.partitionSpec().partitionColumns + val inferred = format.inferSchema( sparkSession, caseInsensitiveOptions, fileCatalog.allFiles()) + + inferred.map { inferredSchema => + StructType(inferredSchema ++ partitionSchema) -> partitionSchema.map(_.name) + } }.getOrElse { throw new AnalysisException("Unable to infer schema. It must be specified manually.") } @@ -151,15 +218,38 @@ case class DataSource( providingClass.newInstance() match { case s: StreamSourceProvider => val (name, schema) = s.sourceSchema( - sparkSession.wrapped, userSpecifiedSchema, className, options) - SourceInfo(name, schema) + sparkSession.sqlContext, userSpecifiedSchema, className, options) + SourceInfo(name, schema, Nil) case format: FileFormat => val caseInsensitiveOptions = new CaseInsensitiveMap(options) val path = caseInsensitiveOptions.getOrElse("path", { throw new IllegalArgumentException("'path' is not specified") }) - SourceInfo(s"FileSource[$path]", inferFileFormatSchema(format)) + + // Check whether the path exists if it is not a glob pattern. + // For glob pattern, we do not check it because the glob pattern might only make sense + // once the streaming job starts and some upstream source starts dropping data. 
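The checks added below require that non-glob input paths already exist and that a schema is supplied for file-based streaming sources unless spark.sql.streaming.schemaInference is enabled (text sources are exempt). A minimal conforming reader, as a sketch assuming the SparkSession `spark` from the earlier sketch and a hypothetical input directory:

```scala
import org.apache.spark.sql.types.{LongType, StringType, StructType, TimestampType}

// Explicit schema for the streaming JSON source; the directory is a placeholder and,
// being a non-glob path, must already exist when the stream is defined.
val eventSchema = new StructType()
  .add("id", LongType)
  .add("kind", StringType)
  .add("ts", TimestampType)

val events = spark.readStream
  .schema(eventSchema)
  .json("/data/incoming/events")
```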
+ val hdfsPath = new Path(path) + if (!SparkHadoopUtil.get.isGlobPath(hdfsPath)) { + val fs = hdfsPath.getFileSystem(sparkSession.sessionState.newHadoopConf()) + if (!fs.exists(hdfsPath)) { + throw new AnalysisException(s"Path does not exist: $path") + } + } + + val isSchemaInferenceEnabled = sparkSession.conf.get(SQLConf.STREAMING_SCHEMA_INFERENCE) + val isTextSource = providingClass == classOf[text.TextFileFormat] + // If the schema inference is disabled, only text sources require schema to be specified + if (!isSchemaInferenceEnabled && !isTextSource && userSpecifiedSchema.isEmpty) { + throw new IllegalArgumentException( + "Schema must be specified when creating a streaming source DataFrame. " + + "If some files already exist in the directory, then depending on the file format " + + "you may be able to create a static DataFrame on that directory with " + + "'spark.read.load(directory)' and infer schema from it.") + } + val (schema, partCols) = inferFileFormatSchema(format) + SourceInfo(s"FileSource[$path]", schema, partCols) case _ => throw new UnsupportedOperationException( @@ -171,28 +261,21 @@ case class DataSource( def createSource(metadataPath: String): Source = { providingClass.newInstance() match { case s: StreamSourceProvider => - s.createSource(sparkSession.wrapped, metadataPath, userSpecifiedSchema, className, options) + s.createSource( + sparkSession.sqlContext, metadataPath, userSpecifiedSchema, className, options) case format: FileFormat => - val caseInsensitiveOptions = new CaseInsensitiveMap(options) - val path = caseInsensitiveOptions.getOrElse("path", { + val path = new CaseInsensitiveMap(options).getOrElse("path", { throw new IllegalArgumentException("'path' is not specified") }) - - def dataFrameBuilder(files: Array[String]): DataFrame = { - val newOptions = options.filterKeys(_ != "path") + ("basePath" -> path) - val newDataSource = - DataSource( - sparkSession, - paths = files, - userSpecifiedSchema = Some(sourceInfo.schema), - className = className, - options = new CaseInsensitiveMap(newOptions)) - Dataset.ofRows(sparkSession, LogicalRelation(newDataSource.resolveRelation())) - } - new FileStreamSource( - sparkSession, metadataPath, path, sourceInfo.schema, dataFrameBuilder) + sparkSession = sparkSession, + path = path, + fileFormatClassName = className, + schema = sourceInfo.schema, + partitionColumns = sourceInfo.partitionColumns, + metadataPath = metadataPath, + options = options) case _ => throw new UnsupportedOperationException( s"Data source $className does not support streamed reading") @@ -200,16 +283,22 @@ case class DataSource( } /** Returns a sink that can be used to continually write data. 
*/ - def createSink(): Sink = { + def createSink(outputMode: OutputMode): Sink = { providingClass.newInstance() match { - case s: StreamSinkProvider => s.createSink(sparkSession.wrapped, options, partitionColumns) - case format: FileFormat => + case s: StreamSinkProvider => + s.createSink(sparkSession.sqlContext, options, partitionColumns, outputMode) + + case parquet: parquet.ParquetFileFormat => val caseInsensitiveOptions = new CaseInsensitiveMap(options) val path = caseInsensitiveOptions.getOrElse("path", { throw new IllegalArgumentException("'path' is not specified") }) + if (outputMode != OutputMode.Append) { + throw new IllegalArgumentException( + s"Data source $className does not support $outputMode output mode") + } + new FileStreamSink(sparkSession, path, parquet, partitionColumns, options) - new FileStreamSink(sparkSession, path, format) case _ => throw new UnsupportedOperationException( s"Data source $className does not support streamed writing") @@ -238,15 +327,22 @@ case class DataSource( } } - /** Create a resolved [[BaseRelation]] that can be used to read data from this [[DataSource]] */ - def resolveRelation(): BaseRelation = { + /** + * Create a resolved [[BaseRelation]] that can be used to read data from or write data into this + * [[DataSource]] + * + * @param checkPathExist A flag to indicate whether to check the existence of path or not. + * This flag will be set to false when we create an empty table (the + * path of the table does not exist). + */ + def resolveRelation(checkPathExist: Boolean = true): BaseRelation = { val caseInsensitiveOptions = new CaseInsensitiveMap(options) val relation = (providingClass.newInstance(), userSpecifiedSchema) match { // TODO: Throw when too much is given. case (dataSource: SchemaRelationProvider, Some(schema)) => - dataSource.createRelation(sparkSession.wrapped, caseInsensitiveOptions, schema) + dataSource.createRelation(sparkSession.sqlContext, caseInsensitiveOptions, schema) case (dataSource: RelationProvider, None) => - dataSource.createRelation(sparkSession.wrapped, caseInsensitiveOptions) + dataSource.createRelation(sparkSession.sqlContext, caseInsensitiveOptions) case (_: SchemaRelationProvider, None) => throw new AnalysisException(s"A schema needs to be specified when using $className.") case (_: RelationProvider, Some(_)) => @@ -257,7 +353,7 @@ case class DataSource( case (format: FileFormat, _) if hasMetadata(caseInsensitiveOptions.get("path").toSeq ++ paths) => val basePath = new Path((caseInsensitiveOptions.get("path").toSeq ++ paths).head) - val fileCatalog = new StreamFileCatalog(sparkSession, basePath) + val fileCatalog = new MetadataLogFileCatalog(sparkSession, basePath) val dataSchema = userSpecifiedSchema.orElse { format.inferSchema( sparkSession, @@ -270,13 +366,12 @@ case class DataSource( } HadoopFsRelation( - sparkSession, fileCatalog, partitionSchema = fileCatalog.partitionSpec().partitionColumns, dataSchema = dataSchema, bucketSpec = None, format, - options) + options)(sparkSession) // This is a non-streaming file based datasource. 
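The `(providingClass.newInstance(), userSpecifiedSchema)` match above dispatches on which sources-API trait a provider implements. A minimal, hypothetical provider covering both arms might look like the following sketch (class names are illustrative, not part of Spark):

```scala
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.sources.{BaseRelation, RelationProvider, SchemaRelationProvider}
import org.apache.spark.sql.types.{StringType, StructField, StructType}

// A relation that only carries a schema; a real source would also mix in a scan trait.
class ExampleRelation(val sqlContext: SQLContext, val schema: StructType) extends BaseRelation

class ExampleProvider extends RelationProvider with SchemaRelationProvider {
  // Chosen when the user did not supply a schema: the source must produce one itself.
  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String]): BaseRelation =
    new ExampleRelation(sqlContext, StructType(StructField("value", StringType) :: Nil))

  // Chosen when resolveRelation is handed a user-specified schema.
  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String],
      schema: StructType): BaseRelation =
    new ExampleRelation(sqlContext, schema)
}
```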
case (format: FileFormat, _) => @@ -287,11 +382,11 @@ case class DataSource( val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) val globPath = SparkHadoopUtil.get.globPathIfNecessary(qualified) - if (globPath.isEmpty) { + if (checkPathExist && globPath.isEmpty) { throw new AnalysisException(s"Path does not exist: $qualified") } // Sufficient to check head of the globPath seq for non-glob scenario - if (!fs.exists(globPath.head)) { + if (checkPathExist && !fs.exists(globPath.head)) { throw new AnalysisException(s"Path does not exist: ${globPath.head}") } globPath @@ -309,8 +404,9 @@ case class DataSource( }) } - val fileCatalog: FileCatalog = - new HDFSFileCatalog(sparkSession, options, globbedPaths, partitionSchema) + val fileCatalog = + new ListingFileCatalog( + sparkSession, globbedPaths, options, partitionSchema, !checkPathExist) val dataSchema = userSpecifiedSchema.map { schema => val equality = @@ -332,17 +428,13 @@ case class DataSource( "It must be specified manually") } - val enrichedOptions = - format.prepareRead(sparkSession, caseInsensitiveOptions, fileCatalog.allFiles()) - HadoopFsRelation( - sparkSession, fileCatalog, partitionSchema = fileCatalog.partitionSpec().partitionColumns, dataSchema = dataSchema.asNullable, bucketSpec = bucketSpec, format, - enrichedOptions) + caseInsensitiveOptions)(sparkSession) case _ => throw new AnalysisException( @@ -362,7 +454,7 @@ case class DataSource( providingClass.newInstance() match { case dataSource: CreatableRelationProvider => - dataSource.createRelation(sparkSession.wrapped, mode, options, data) + dataSource.createRelation(sparkSession.sqlContext, mode, options, data) case format: FileFormat => // Don't glob path for the write path. The contracts here are: // 1. Only one output path can be specified on the write path; @@ -378,55 +470,64 @@ case class DataSource( } val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis - PartitioningUtils.validatePartitionColumnDataTypes( + PartitioningUtils.validatePartitionColumn( data.schema, partitionColumns, caseSensitive) // If we are appending to a table that already exists, make sure the partitioning matches // up. If we fail to load the table for whatever reason, ignore the check. if (mode == SaveMode.Append) { - val existingPartitionColumnSet = try { - Some( - resolveRelation() - .asInstanceOf[HadoopFsRelation] - .location - .partitionSpec() - .partitionColumns - .fieldNames - .toSet) - } catch { - case e: Exception => - None - } - - existingPartitionColumnSet.foreach { ex => - if (ex.map(_.toLowerCase) != partitionColumns.map(_.toLowerCase()).toSet) { - throw new AnalysisException( - s"Requested partitioning does not equal existing partitioning: " + - s"$ex != ${partitionColumns.toSet}.") - } + val existingPartitionColumns = Try { + resolveRelation() + .asInstanceOf[HadoopFsRelation] + .location + .partitionSpec() + .partitionColumns + .fieldNames + .toSeq + }.getOrElse(Seq.empty[String]) + // TODO: Case sensitivity. + val sameColumns = + existingPartitionColumns.map(_.toLowerCase()) == partitionColumns.map(_.toLowerCase()) + if (existingPartitionColumns.size > 0 && !sameColumns) { + throw new AnalysisException( + s"""Requested partitioning does not match existing partitioning. 
+ |Existing partitioning columns: + | ${existingPartitionColumns.mkString(", ")} + |Requested partitioning columns: + | ${partitionColumns.mkString(", ")} + |""".stripMargin) } } + // SPARK-17230: Resolve the partition columns so InsertIntoHadoopFsRelationCommand does + // not need to have the query as child, to avoid to analyze an optimized query, + // because InsertIntoHadoopFsRelationCommand will be optimized first. + val columns = partitionColumns.map { name => + val plan = data.logicalPlan + plan.resolve(name :: Nil, data.sparkSession.sessionState.analyzer.resolver).getOrElse { + throw new AnalysisException( + s"Unable to resolve ${name} given [${plan.output.map(_.name).mkString(", ")}]") + }.asInstanceOf[Attribute] + } // For partitioned relation r, r.schema's column ordering can be different from the column // ordering of data.logicalPlan (partition columns are all moved after data column). This // will be adjusted within InsertIntoHadoopFsRelation. val plan = - InsertIntoHadoopFsRelation( + InsertIntoHadoopFsRelationCommand( outputPath, - partitionColumns.map(UnresolvedAttribute.quoted), + columns, bucketSpec, format, () => Unit, // No existing table needs to be refreshed. options, data.logicalPlan, mode) - sparkSession.executePlan(plan).toRdd + sparkSession.sessionState.executePlan(plan).toRdd + // Replace the schema with that of the DataFrame we just wrote out to avoid re-inferring it. + copy(userSpecifiedSchema = Some(data.schema.asNullable)).resolveRelation() case _ => sys.error(s"${providingClass.getCanonicalName} does not allow create table as select.") } - - // We replace the schema with that of the DataFrame we just wrote out to avoid re-inferring it. - copy(userSpecifiedSchema = Some(data.schema.asNullable)).resolveRelation() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 9bebd74b4b3a2..277969465e424 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -22,18 +22,19 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.{CatalystConf, CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.catalog.{CatalogTable, SimpleCatalogRelation} import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.DataSourceScanExec.{INPUT_PATHS, PUSHED_FILTERS} -import org.apache.spark.sql.execution.command.ExecutedCommandExec +import org.apache.spark.sql.execution.DataSourceScanExec.PUSHED_FILTERS +import 
org.apache.spark.sql.execution.command.{CreateDataSourceTableUtils, DDLUtils, ExecutedCommandExec} import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -42,8 +43,127 @@ import org.apache.spark.unsafe.types.UTF8String * Replaces generic operations with specific variants that are designed to work with Spark * SQL Data Sources. */ -private[sql] object DataSourceAnalysis extends Rule[LogicalPlan] { +case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] { + + def resolver: Resolver = { + if (conf.caseSensitiveAnalysis) { + caseSensitiveResolution + } else { + caseInsensitiveResolution + } + } + + // The access modifier is used to expose this method to tests. + def convertStaticPartitions( + sourceAttributes: Seq[Attribute], + providedPartitions: Map[String, Option[String]], + targetAttributes: Seq[Attribute], + targetPartitionSchema: StructType): Seq[NamedExpression] = { + + assert(providedPartitions.exists(_._2.isDefined)) + + val staticPartitions = providedPartitions.flatMap { + case (partKey, Some(partValue)) => (partKey, partValue) :: Nil + case (_, None) => Nil + } + + // The sum of the number of static partition columns and columns provided in the SELECT + // clause needs to match the number of columns of the target table. + if (staticPartitions.size + sourceAttributes.size != targetAttributes.size) { + throw new AnalysisException( + s"The data to be inserted needs to have the same number of " + + s"columns as the target table: target table has ${targetAttributes.size} " + + s"column(s) but the inserted data has ${sourceAttributes.size + staticPartitions.size} " + + s"column(s), which contain ${staticPartitions.size} partition column(s) having " + + s"assigned constant values.") + } + + if (providedPartitions.size != targetPartitionSchema.fields.size) { + throw new AnalysisException( + s"The data to be inserted needs to have the same number of " + + s"partition columns as the target table: target table " + + s"has ${targetPartitionSchema.fields.size} partition column(s) but the inserted " + + s"data has ${providedPartitions.size} partition columns specified.") + } + + staticPartitions.foreach { + case (partKey, partValue) => + if (!targetPartitionSchema.fields.exists(field => resolver(field.name, partKey))) { + throw new AnalysisException( + s"$partKey is not a partition column. Partition columns are " + + s"${targetPartitionSchema.fields.map(_.name).mkString("[", ",", "]")}") + } + } + + val partitionList = targetPartitionSchema.fields.map { field => + val potentialSpecs = staticPartitions.filter { + case (partKey, partValue) => resolver(field.name, partKey) + } + if (potentialSpecs.size == 0) { + None + } else if (potentialSpecs.size == 1) { + val partValue = potentialSpecs.head._2 + Some(Alias(Cast(Literal(partValue), field.dataType), "_staticPart")()) + } else { + throw new AnalysisException( + s"Partition column ${field.name} have multiple values specified, " + + s"${potentialSpecs.mkString("[", ", ", "]")}. Please only specify a single value.") + } + } + + // We first drop all leading static partitions using dropWhile and check if there is + // any static partition appear after dynamic partitions. + partitionList.dropWhile(_.isDefined).collectFirst { + case Some(_) => + throw new AnalysisException( + s"The ordering of partition columns is " + + s"${targetPartitionSchema.fields.map(_.name).mkString("[", ",", "]")}. 
" + + "All partition columns having constant values need to appear before other " + + "partition columns that do not have an assigned constant value.") + } + + assert(partitionList.take(staticPartitions.size).forall(_.isDefined)) + val projectList = + sourceAttributes.take(targetAttributes.size - targetPartitionSchema.fields.size) ++ + partitionList.take(staticPartitions.size).map(_.get) ++ + sourceAttributes.takeRight(targetPartitionSchema.fields.size - staticPartitions.size) + + projectList + } + override def apply(plan: LogicalPlan): LogicalPlan = plan transform { + // If the InsertIntoTable command is for a partitioned HadoopFsRelation and + // the user has specified static partitions, we add a Project operator on top of the query + // to include those constant column values in the query result. + // + // Example: + // Let's say that we have a table "t", which is created by + // CREATE TABLE t (a INT, b INT, c INT) USING parquet PARTITIONED BY (b, c) + // The statement of "INSERT INTO TABLE t PARTITION (b=2, c) SELECT 1, 3" + // will be converted to "INSERT INTO TABLE t PARTITION (b, c) SELECT 1, 2, 3". + // + // Basically, we will put those partition columns having a assigned value back + // to the SELECT clause. The output of the SELECT clause is organized as + // normal_columns static_partitioning_columns dynamic_partitioning_columns. + // static_partitioning_columns are partitioning columns having assigned + // values in the PARTITION clause (e.g. b in the above example). + // dynamic_partitioning_columns are partitioning columns that do not assigned + // values in the PARTITION clause (e.g. c in the above example). + case insert @ logical.InsertIntoTable( + relation @ LogicalRelation(t: HadoopFsRelation, _, _), parts, query, overwrite, false) + if query.resolved && parts.exists(_._2.isDefined) => + + val projectList = convertStaticPartitions( + sourceAttributes = query.output, + providedPartitions = parts, + targetAttributes = relation.output, + targetPartitionSchema = t.partitionSchema) + + // We will remove all assigned values to static partitions because they have been + // moved to the projectList. + insert.copy(partition = parts.map(p => (p._1, None)), child = Project(projectList, query)) + + case i @ logical.InsertIntoTable( l @ LogicalRelation(t: HadoopFsRelation, _, _), part, query, overwrite, false) if query.resolved && t.schema.asNullable == query.schema.asNullable => @@ -65,9 +185,9 @@ private[sql] object DataSourceAnalysis extends Rule[LogicalPlan] { "Cannot overwrite a path that is also being read from.") } - InsertIntoHadoopFsRelation( + InsertIntoHadoopFsRelationCommand( outputPath, - t.partitionSchema.fields.map(_.name).map(UnresolvedAttribute(_)), + query.resolve(t.partitionSchema, t.sparkSession.sessionState.analyzer.resolver), t.bucketSpec, t.fileFormat, () => t.refresh(), @@ -77,10 +197,52 @@ private[sql] object DataSourceAnalysis extends Rule[LogicalPlan] { } } + +/** + * Replaces [[SimpleCatalogRelation]] with data source table if its table property contains data + * source information. + */ +class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan] { + private def readDataSourceTable(sparkSession: SparkSession, table: CatalogTable): LogicalPlan = { + val userSpecifiedSchema = DDLUtils.getSchemaFromTableProperties(table) + + // We only need names at here since userSpecifiedSchema we loaded from the metastore + // contains partition columns. We can always get datatypes of partitioning columns + // from userSpecifiedSchema. 
+ val partitionColumns = DDLUtils.getPartitionColumnsFromTableProperties(table) + + val bucketSpec = DDLUtils.getBucketSpecFromTableProperties(table) + + val options = table.storage.serdeProperties + val dataSource = + DataSource( + sparkSession, + userSpecifiedSchema = userSpecifiedSchema, + partitionColumns = partitionColumns, + bucketSpec = bucketSpec, + className = table.properties(CreateDataSourceTableUtils.DATASOURCE_PROVIDER), + options = options) + + LogicalRelation( + dataSource.resolveRelation(), + metastoreTableIdentifier = Some(table.identifier)) + } + + override def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case i @ logical.InsertIntoTable(s: SimpleCatalogRelation, _, _, _, _) + if DDLUtils.isDatasourceTable(s.metadata) => + i.copy(table = readDataSourceTable(sparkSession, s.metadata)) + + case s: SimpleCatalogRelation if DDLUtils.isDatasourceTable(s.metadata) => + readDataSourceTable(sparkSession, s.metadata) + } +} + + /** * A Strategy for planning scans over data sources defined using the sources API. */ -private[sql] object DataSourceStrategy extends Strategy with Logging { +object DataSourceStrategy extends Strategy with Logging { def apply(plan: LogicalPlan): Seq[execution.SparkPlan] = plan match { case PhysicalOperation(projects, filters, l @ LogicalRelation(t: CatalystScan, _, _)) => pruneFilterProjectRaw( @@ -110,7 +272,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { case i @ logical.InsertIntoTable(l @ LogicalRelation(t: InsertableRelation, _, _), part, query, overwrite, false) if part.isEmpty => - ExecutedCommandExec(InsertIntoDataSource(l, query, overwrite)) :: Nil + ExecutedCommandExec(InsertIntoDataSourceCommand(l, query, overwrite)) :: Nil case _ => Nil } @@ -185,18 +347,16 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { // `Filter`s or cannot be handled by `relation`. val filterCondition = unhandledPredicates.reduceLeftOption(expressions.And) + // These metadata values make scan plans uniquely identifiable for equality checking. + // TODO(SPARK-17701) using strings for equality checking is brittle val metadata: Map[String, String] = { val pairs = ArrayBuffer.empty[(String, String)] if (pushedFilters.nonEmpty) { pairs += (PUSHED_FILTERS -> pushedFilters.mkString("[", ", ", "]")) } - - relation.relation match { - case r: HadoopFsRelation => pairs += INPUT_PATHS -> r.location.paths.mkString(", ") - case _ => - } - + pairs += ("ReadSchema" -> + StructType.fromAttributes(projects.map(_.toAttribute)).catalogString) pairs.toMap } @@ -217,7 +377,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { val scan = execution.DataSourceScanExec.create( projects.map(_.toAttribute), scanBuilder(requestedColumns, candidatePredicates, pushedFilters), - relation.relation, metadata) + relation.relation, metadata, relation.metastoreTableIdentifier) filterCondition.map(execution.FilterExec(_, scan)).getOrElse(scan) } else { // Don't request columns that are only referenced by pushed filters. 
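For context on what `pruneFilterProject` above plans against, here is a small hypothetical sources-API relation: Spark hands it the pruned column list and the pushable filters, and re-applies anything reported as unhandled in a `FilterExec` on top of the scan.

```scala
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.sources.{BaseRelation, Filter, PrunedFilteredScan}
import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}

class TinyRelation(val sqlContext: SQLContext) extends BaseRelation with PrunedFilteredScan {
  override def schema: StructType =
    StructType(StructField("id", LongType) :: StructField("name", StringType) :: Nil)

  // Claim every filter as unhandled, so Spark keeps them in the Filter node above the scan.
  override def unhandledFilters(filters: Array[Filter]): Array[Filter] = filters

  override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
    val data = Seq(Row(1L, "a"), Row(2L, "b"))
    sqlContext.sparkContext.parallelize(data).map { row =>
      // Return only the columns the planner actually asked for, in the requested order.
      Row.fromSeq(requiredColumns.map {
        case "id" => row.getLong(0)
        case "name" => row.getString(1)
      })
    }
  }
}
```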
@@ -227,7 +387,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { val scan = execution.DataSourceScanExec.create( requestedColumns, scanBuilder(requestedColumns, candidatePredicates, pushedFilters), - relation.relation, metadata) + relation.relation, metadata, relation.metastoreTableIdentifier) execution.ProjectExec( projects, filterCondition.map(execution.FilterExec(_, scan)).getOrElse(scan)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala index f7f68b1eb90da..1314c94d42cf9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala @@ -111,7 +111,20 @@ class FileScanRDD( currentFile = files.next() logInfo(s"Reading File $currentFile") InputFileNameHolder.setInputFileName(currentFile.filePath) - currentIterator = readFunction(currentFile) + + try { + currentIterator = readFunction(currentFile) + } catch { + case e: java.io.FileNotFoundException => + throw new java.io.FileNotFoundException( + e.getMessage + "\n" + + "It is possible the underlying files have been updated. " + + "You can explicitly invalidate the cache in Spark by " + + "running 'REFRESH TABLE tableName' command in SQL or " + + "by recreating the Dataset/DataFrame involved." + ) + } + hasNext } else { currentFile = null diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index 615906a52e8de..74510f9c08b6f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -27,7 +27,9 @@ import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.{DataSourceScanExec, SparkPlan} +import org.apache.spark.sql.execution.DataSourceScanExec +import org.apache.spark.sql.execution.DataSourceScanExec.{INPUT_PATHS, PUSHED_FILTERS} +import org.apache.spark.sql.execution.SparkPlan /** * A strategy for planning scans over collections of files that might be partitioned or bucketed @@ -39,7 +41,7 @@ import org.apache.spark.sql.execution.{DataSourceScanExec, SparkPlan} * is only done on top level columns, but formats should support pruning of nested columns as * well. * - Construct a reader function by passing filters and the schema into the FileFormat. - * - Using an partition pruning predicates, enumerate the list of files that should be read. + * - Using a partition pruning predicates, enumerate the list of files that should be read. * - Split the files into tasks and construct a FileScanRDD. * - Add any projection or filters that must be evaluated after the scan. * @@ -52,9 +54,10 @@ import org.apache.spark.sql.execution.{DataSourceScanExec, SparkPlan} * is under the threshold with the addition of the next file, add it. If not, open a new bucket * and add it. Proceed to the next file. 
*/ -private[sql] object FileSourceStrategy extends Strategy with Logging { +object FileSourceStrategy extends Strategy with Logging { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case PhysicalOperation(projects, filters, l @ LogicalRelation(files: HadoopFsRelation, _, _)) => + case PhysicalOperation(projects, filters, + l @ LogicalRelation(files: HadoopFsRelation, _, table)) => // Filters on this relation fall into four categories based on where we can use them to avoid // reading unneeded data: // - partition keys only - used to prune directories to read @@ -106,7 +109,7 @@ private[sql] object FileSourceStrategy extends Strategy with Logging { val pushedDownFilters = dataFilters.flatMap(DataSourceStrategy.translateFilter) logInfo(s"Pushed Filters: ${pushedDownFilters.mkString(",")}") - val readFile = files.fileFormat.buildReader( + val readFile = files.fileFormat.buildReaderWithPartitionValues( sparkSession = files.sparkSession, dataSchema = files.dataSchema, partitionSchema = files.partitionSchema, @@ -148,11 +151,18 @@ private[sql] object FileSourceStrategy extends Strategy with Logging { val splitFiles = selectedPartitions.flatMap { partition => partition.files.flatMap { file => val blockLocations = getBlockLocations(file) - (0L until file.getLen by maxSplitBytes).map { offset => - val remaining = file.getLen - offset - val size = if (remaining > maxSplitBytes) maxSplitBytes else remaining - val hosts = getBlockHosts(blockLocations, offset, size) - PartitionedFile(partition.values, file.getPath.toUri.toString, offset, size, hosts) + if (files.fileFormat.isSplitable(files.sparkSession, files.options, file.getPath)) { + (0L until file.getLen by maxSplitBytes).map { offset => + val remaining = file.getLen - offset + val size = if (remaining > maxSplitBytes) maxSplitBytes else remaining + val hosts = getBlockHosts(blockLocations, offset, size) + PartitionedFile( + partition.values, file.getPath.toUri.toString, offset, size, hosts) + } + } else { + val hosts = getBlockHosts(blockLocations, 0, file.getLen) + Seq(PartitionedFile( + partition.values, file.getPath.toUri.toString, 0, file.getLen, hosts)) } } }.toArray.sortBy(_.length)(implicitly[Ordering[Long]].reverse) @@ -192,6 +202,14 @@ private[sql] object FileSourceStrategy extends Strategy with Logging { partitions } + // These metadata values make scan plans uniquely identifiable for equality checking. 
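The new `isSplitable` branch in the planning code above changes how many `PartitionedFile`s a single input file produces; a back-of-the-envelope sketch with purely illustrative numbers:

```scala
// A splittable 1 GiB file with a 128 MiB target split size yields 8 ranges;
// a non-splittable file (e.g. gzip'd text) instead becomes one full-length range.
val fileLen = 1024L * 1024 * 1024
val maxSplitBytes = 128L * 1024 * 1024
val splitRanges = (0L until fileLen by maxSplitBytes).map { offset =>
  offset -> math.min(maxSplitBytes, fileLen - offset)
}
assert(splitRanges.length == 8)
assert(splitRanges.forall { case (_, size) => size <= maxSplitBytes })
```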
+ val meta = Map( + "PartitionFilters" -> partitionKeyFilters.mkString("[", ", ", "]"), + "Format" -> files.fileFormat.toString, + "ReadSchema" -> prunedDataSchema.simpleString, + PUSHED_FILTERS -> pushedDownFilters.mkString("[", ", ", "]"), + INPUT_PATHS -> files.location.paths.mkString(", ")) + val scan = DataSourceScanExec.create( readDataColumns ++ partitionColumns, @@ -200,10 +218,8 @@ private[sql] object FileSourceStrategy extends Strategy with Logging { readFile, plannedPartitions), files, - Map( - "Format" -> files.fileFormat.toString, - "PushedFilters" -> pushedDownFilters.mkString("[", ", ", "]"), - "ReadSchema" -> prunedDataSchema.simpleString)) + meta, + table) val afterScanFilter = afterScanFilters.toSeq.reduceOption(expressions.And) val withFilter = afterScanFilter.map(execution.FilterExec(_, scan)).getOrElse(scan) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala index 18f9b55895a64..83cf26c63a175 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.datasources +import java.io.Closeable import java.net.URI import org.apache.hadoop.conf.Configuration @@ -30,7 +31,8 @@ import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl * An adaptor from a [[PartitionedFile]] to an [[Iterator]] of [[Text]], which are all of the lines * in that file. */ -class HadoopFileLinesReader(file: PartitionedFile, conf: Configuration) extends Iterator[Text] { +class HadoopFileLinesReader( + file: PartitionedFile, conf: Configuration) extends Iterator[Text] with Closeable { private val iterator = { val fileSplit = new FileSplit( new Path(new URI(file.filePath)), @@ -48,4 +50,6 @@ class HadoopFileLinesReader(file: PartitionedFile, conf: Configuration) extends override def hasNext: Boolean = iterator.hasNext override def next(): Text = iterator.next() + + override def close(): Unit = iterator.close() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala similarity index 87% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSource.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala index 7b15e496414e3..b2ff68a833fea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.command.RunnableCommand import org.apache.spark.sql.sources.InsertableRelation @@ -26,12 +27,14 @@ import org.apache.spark.sql.sources.InsertableRelation /** * Inserts the results of `query` in to a relation that extends [[InsertableRelation]]. 
*/ -private[sql] case class InsertIntoDataSource( +case class InsertIntoDataSourceCommand( logicalRelation: LogicalRelation, query: LogicalPlan, overwrite: Boolean) extends RunnableCommand { + override protected def innerChildren: Seq[QueryPlan[_]] = Seq(query) + override def run(sparkSession: SparkSession): Seq[Row] = { val relation = logicalRelation.relation.asInstanceOf[InsertableRelation] val data = Dataset.ofRows(sparkSession, query) @@ -40,7 +43,7 @@ private[sql] case class InsertIntoDataSource( relation.insert(df, overwrite) // Invalidate the cache. - sparkSession.cacheManager.invalidateCache(logicalRelation) + sparkSession.sharedState.cacheManager.invalidateCache(logicalRelation) Seq.empty[Row] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala similarity index 93% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala index 4921e4ca6bb7d..518b02b718757 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala @@ -35,11 +35,11 @@ import org.apache.spark.sql.internal.SQLConf /** * A command for writing data to a [[HadoopFsRelation]]. Supports both overwriting and appending. - * Writing to dynamic partitions is also supported. Each [[InsertIntoHadoopFsRelation]] issues a - * single write job, and owns a UUID that identifies this job. Each concrete implementation of - * [[HadoopFsRelation]] should use this UUID together with task id to generate unique file path for - * each task output file. This UUID is passed to executor side via a property named - * `spark.sql.sources.writeJobUUID`. + * Writing to dynamic partitions is also supported. Each [[InsertIntoHadoopFsRelationCommand]] + * issues a single write job, and owns a UUID that identifies this job. Each concrete + * implementation of [[HadoopFsRelation]] should use this UUID together with task id to generate + * unique file path for each task output file. This UUID is passed to executor side via a + * property named `spark.sql.sources.writeJobUUID`. * * Different writer containers, [[DefaultWriterContainer]] and [[DynamicPartitionWriterContainer]] * are used to write to normal tables and tables with dynamic partitions. @@ -55,7 +55,7 @@ import org.apache.spark.sql.internal.SQLConf * 4. If all tasks are committed, commit the job, otherwise aborts the job; If any exception is * thrown during job commitment, also aborts the job. 
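`InsertIntoDataSourceCommand` above ultimately builds a `DataFrame` from `query` and hands it to the relation's `insert` method. A hypothetical relation accepting such inserts, as a sketch only:

```scala
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.sources.{BaseRelation, InsertableRelation}
import org.apache.spark.sql.types.StructType

class SinkRelation(val sqlContext: SQLContext, val schema: StructType)
  extends BaseRelation with InsertableRelation {

  // Called by InsertIntoDataSourceCommand.run(); `overwrite` mirrors INSERT OVERWRITE.
  override def insert(data: DataFrame, overwrite: Boolean): Unit = {
    // A real source would persist `data`; this sketch only forces evaluation.
    println(s"received ${data.count()} rows, overwrite = $overwrite")
  }
}
```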
*/ -private[sql] case class InsertIntoHadoopFsRelation( +case class InsertIntoHadoopFsRelationCommand( outputPath: Path, partitionColumns: Seq[Attribute], bucketSpec: Option[BucketSpec], @@ -66,7 +66,7 @@ private[sql] case class InsertIntoHadoopFsRelation( mode: SaveMode) extends RunnableCommand { - override def children: Seq[LogicalPlan] = query :: Nil + override protected def innerChildren: Seq[LogicalPlan] = query :: Nil override def run(sparkSession: SparkSession): Seq[Row] = { // Most formats don't do well with duplicate columns, so lets not allow that diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ListingFileCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ListingFileCatalog.scala new file mode 100644 index 0000000000000..706ec6b9b36c7 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ListingFileCatalog.scala @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import java.io.FileNotFoundException + +import scala.collection.mutable +import scala.util.Try + +import org.apache.hadoop.fs.{FileStatus, LocatedFileStatus, Path} +import org.apache.hadoop.mapred.{FileInputFormat, JobConf} + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.types.StructType + + +/** + * A [[FileCatalog]] that generates the list of files to process by recursively listing all the + * files present in `paths`. + * + * @param parameters as set of options to control discovery + * @param paths a list of paths to scan + * @param partitionSchema an optional partition schema that will be use to provide types for the + * discovered partitions + * @param ignoreFileNotFound if true, return empty file list when encountering a + * [[FileNotFoundException]] in file listing. Note that this is a hack + * for SPARK-16313. We should get rid of this flag in the future. 
+ */ +class ListingFileCatalog( + sparkSession: SparkSession, + override val paths: Seq[Path], + parameters: Map[String, String], + partitionSchema: Option[StructType], + ignoreFileNotFound: Boolean = false) + extends PartitioningAwareFileCatalog(sparkSession, parameters, partitionSchema) { + + @volatile private var cachedLeafFiles: mutable.LinkedHashMap[Path, FileStatus] = _ + @volatile private var cachedLeafDirToChildrenFiles: Map[Path, Array[FileStatus]] = _ + @volatile private var cachedPartitionSpec: PartitionSpec = _ + + refresh() + + override def partitionSpec(): PartitionSpec = { + if (cachedPartitionSpec == null) { + cachedPartitionSpec = inferPartitioning() + } + logTrace(s"Partition spec: $cachedPartitionSpec") + cachedPartitionSpec + } + + override protected def leafFiles: mutable.LinkedHashMap[Path, FileStatus] = { + cachedLeafFiles + } + + override protected def leafDirToChildrenFiles: Map[Path, Array[FileStatus]] = { + cachedLeafDirToChildrenFiles + } + + override def refresh(): Unit = { + val files = listLeafFiles(paths) + cachedLeafFiles = + new mutable.LinkedHashMap[Path, FileStatus]() ++= files.map(f => f.getPath -> f) + cachedLeafDirToChildrenFiles = files.toArray.groupBy(_.getPath.getParent) + cachedPartitionSpec = null + } + + /** + * List leaf files of given paths. This method will submit a Spark job to do parallel + * listing whenever there is a path having more files than the parallel partition discovery + * discovery threshold. + * + * This is publicly visible for testing. + */ + def listLeafFiles(paths: Seq[Path]): mutable.LinkedHashSet[FileStatus] = { + if (paths.length >= sparkSession.sessionState.conf.parallelPartitionDiscoveryThreshold) { + HadoopFsRelation.listLeafFilesInParallel(paths, hadoopConf, sparkSession, ignoreFileNotFound) + } else { + // Right now, the number of paths is less than the value of + // parallelPartitionDiscoveryThreshold. So, we will list file statues at the driver. + // If there is any child that has more files than the threshold, we will use parallel + // listing. + + // Dummy jobconf to get to the pathFilter defined in configuration + val jobConf = new JobConf(hadoopConf, this.getClass) + val pathFilter = FileInputFormat.getInputPathFilter(jobConf) + + val statuses: Seq[FileStatus] = paths.flatMap { path => + val fs = path.getFileSystem(hadoopConf) + logTrace(s"Listing $path on driver") + + val childStatuses = { + val stats = + try { + fs.listStatus(path) + } catch { + case e: FileNotFoundException if ignoreFileNotFound => Array.empty[FileStatus] + } + if (pathFilter != null) stats.filter(f => pathFilter.accept(f.getPath)) else stats + } + + childStatuses.map { + case f: LocatedFileStatus => f + + // NOTE: + // + // - Although S3/S3A/S3N file system can be quite slow for remote file metadata + // operations, calling `getFileBlockLocations` does no harm here since these file system + // implementations don't actually issue RPC for this method. + // + // - Here we are calling `getFileBlockLocations` in a sequential manner, but it should not + // be a big deal since we always use to `listLeafFilesInParallel` when the number of + // paths exceeds threshold. + case f => + if (f.isDirectory ) { + // If f is a directory, we do not need to call getFileBlockLocations (SPARK-14959). 
+ f + } else { + HadoopFsRelation.createLocatedFileStatus(f, fs.getFileBlockLocations(f, 0, f.getLen)) + } + } + }.filterNot { status => + val name = status.getPath.getName + HadoopFsRelation.shouldFilterOut(name) + } + + val (dirs, files) = statuses.partition(_.isDirectory) + + // It uses [[LinkedHashSet]] since the order of files can affect the results. (SPARK-11500) + if (dirs.isEmpty) { + mutable.LinkedHashSet(files: _*) + } else { + mutable.LinkedHashSet(files: _*) ++ listLeafFiles(dirs.map(_.getPath)) + } + } + } + + override def equals(other: Any): Boolean = other match { + case hdfs: ListingFileCatalog => paths.toSet == hdfs.paths.toSet + case _ => false + } + + override def hashCode(): Int = paths.toSet.hashCode() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala index 0e0748ff32df3..2a8e147011f55 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.sources.BaseRelation +import org.apache.spark.util.Utils /** * Used to link a [[BaseRelation]] in to a logical query plan. @@ -59,9 +60,11 @@ case class LogicalRelation( com.google.common.base.Objects.hashCode(relation, output) } - override def sameResult(otherPlan: LogicalPlan): Boolean = otherPlan match { - case LogicalRelation(otherRelation, _, _) => relation == otherRelation - case _ => false + override def sameResult(otherPlan: LogicalPlan): Boolean = { + otherPlan.canonicalized match { + case LogicalRelation(otherRelation, _, _) => relation == otherRelation + case _ => false + } } // When comparing two LogicalRelations from within LogicalPlan.sameResult, we only need @@ -76,11 +79,23 @@ case class LogicalRelation( /** Used to lookup original attribute capitalization */ val attributeMap: AttributeMap[AttributeReference] = AttributeMap(output.map(o => (o, o))) - def newInstance(): this.type = + /** + * Returns a new instance of this LogicalRelation. According to the semantics of + * MultiInstanceRelation, this method returns a copy of this object with + * unique expression ids. We respect the `expectedOutputAttributes` and create + * new instances of attributes in it. + */ + override def newInstance(): this.type = { LogicalRelation( relation, - expectedOutputAttributes, + expectedOutputAttributes.map(_.map(_.newInstance())), metastoreTableIdentifier).asInstanceOf[this.type] + } + + override def refresh(): Unit = relation match { + case fs: HadoopFsRelation => fs.refresh() + case _ => // Do nothing. 
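The `refresh()` hook just above, together with the friendlier `FileNotFoundException` message added to `FileScanRDD` earlier in this patch, points users at the same remedy. A usage sketch, assuming `spark` is an active session and `events` is a hypothetical table name:

```scala
// Re-list the files backing the table after they have changed underneath Spark.
spark.sql("REFRESH TABLE events")
// Equivalent programmatic form on the catalog API.
spark.catalog.refreshTable("events")
```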
+ } - override def simpleString: String = s"Relation[${output.mkString(",")}] $relation" + override def simpleString: String = s"Relation[${Utils.truncatedString(output, ",")}] $relation" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileCatalog.scala new file mode 100644 index 0000000000000..2130c27ebd8de --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileCatalog.scala @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import scala.collection.mutable + +import org.apache.hadoop.fs.{FileStatus, Path} + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.{expressions, InternalRow} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.{StringType, StructType} + + +/** + * An abstract class that represents [[FileCatalog]]s that are aware of partitioned tables. + * It provides the necessary methods to parse partition data based on a set of files. 
+ * + * @param parameters as set of options to control partition discovery + * @param partitionSchema an optional partition schema that will be use to provide types for the + * discovered partitions +*/ +abstract class PartitioningAwareFileCatalog( + sparkSession: SparkSession, + parameters: Map[String, String], + partitionSchema: Option[StructType]) + extends FileCatalog with Logging { + + protected val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(parameters) + + protected def leafFiles: mutable.LinkedHashMap[Path, FileStatus] + + protected def leafDirToChildrenFiles: Map[Path, Array[FileStatus]] + + override def listFiles(filters: Seq[Expression]): Seq[Partition] = { + val selectedPartitions = if (partitionSpec().partitionColumns.isEmpty) { + Partition(InternalRow.empty, allFiles().filter(f => isDataPath(f.getPath))) :: Nil + } else { + prunePartitions(filters, partitionSpec()).map { + case PartitionDirectory(values, path) => + val files: Seq[FileStatus] = leafDirToChildrenFiles.get(path) match { + case Some(existingDir) => + // Directory has children files in it, return them + existingDir.filter(f => isDataPath(f.getPath)) + + case None => + // Directory does not exist, or has no children files + Nil + } + Partition(values, files) + } + } + logTrace("Selected files after partition pruning:\n\t" + selectedPartitions.mkString("\n\t")) + selectedPartitions + } + + override def allFiles(): Seq[FileStatus] = { + if (partitionSpec().partitionColumns.isEmpty) { + // For each of the input paths, get the list of files inside them + paths.flatMap { path => + // Make the path qualified (consistent with listLeafFiles and listLeafFilesInParallel). + val fs = path.getFileSystem(hadoopConf) + val qualifiedPathPre = fs.makeQualified(path) + val qualifiedPath: Path = if (qualifiedPathPre.isRoot && !qualifiedPathPre.isAbsolute) { + // SPARK-17613: Always append `Path.SEPARATOR` to the end of parent directories, + // because the `leafFile.getParent` would have returned an absolute path with the + // separator at the end. + new Path(qualifiedPathPre, Path.SEPARATOR) + } else { + qualifiedPathPre + } + + // There are three cases possible with each path + // 1. The path is a directory and has children files in it. Then it must be present in + // leafDirToChildrenFiles as those children files will have been found as leaf files. + // Find its children files from leafDirToChildrenFiles and include them. + // 2. The path is a file, then it will be present in leafFiles. Include this path. + // 3. The path is a directory, but has no children files. Do not include this path. + + leafDirToChildrenFiles.get(qualifiedPath) + .orElse { leafFiles.get(qualifiedPath).map(Array(_)) } + .getOrElse(Array.empty) + } + } else { + leafFiles.values.toSeq + } + } + + protected def inferPartitioning(): PartitionSpec = { + // We use leaf dirs containing data files to discover the schema. + val leafDirs = leafDirToChildrenFiles.filter { case (_, files) => + // SPARK-15895: Metadata files (e.g. Parquet summary files) and temporary files should not be + // counted as data files, so that they shouldn't participate partition discovery. 
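As a concrete picture of what `inferPartitioning` above works from, consider a hypothetical directory layout (paths and column names are illustrative, and `spark` is assumed to be an active session):

```scala
// Leaf directories containing data files drive the discovery, e.g.:
//   /data/events/year=2016/month=05/part-00000.parquet
//   /data/events/year=2016/month=06/part-00000.parquet
val events = spark.read.parquet("/data/events")
events.printSchema()
// In addition to the data columns, two discovered partition columns appear:
//   |-- year: integer (nullable = true)
//   |-- month: integer (nullable = true)
```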
+ files.exists(f => isDataPath(f.getPath)) + }.keys.toSeq + partitionSchema match { + case Some(userProvidedSchema) if userProvidedSchema.nonEmpty => + val spec = PartitioningUtils.parsePartitions( + leafDirs, + PartitioningUtils.DEFAULT_PARTITION_NAME, + typeInference = false, + basePaths = basePaths) + + // Without auto inference, all of value in the `row` should be null or in StringType, + // we need to cast into the data type that user specified. + def castPartitionValuesToUserSchema(row: InternalRow) = { + InternalRow((0 until row.numFields).map { i => + Cast( + Literal.create(row.getUTF8String(i), StringType), + userProvidedSchema.fields(i).dataType).eval() + }: _*) + } + + PartitionSpec(userProvidedSchema, spec.partitions.map { part => + part.copy(values = castPartitionValuesToUserSchema(part.values)) + }) + case _ => + PartitioningUtils.parsePartitions( + leafDirs, + PartitioningUtils.DEFAULT_PARTITION_NAME, + typeInference = sparkSession.sessionState.conf.partitionColumnTypeInferenceEnabled(), + basePaths = basePaths) + } + } + + private def prunePartitions( + predicates: Seq[Expression], + partitionSpec: PartitionSpec): Seq[PartitionDirectory] = { + val PartitionSpec(partitionColumns, partitions) = partitionSpec + val partitionColumnNames = partitionColumns.map(_.name).toSet + val partitionPruningPredicates = predicates.filter { + _.references.map(_.name).toSet.subsetOf(partitionColumnNames) + } + + if (partitionPruningPredicates.nonEmpty) { + val predicate = partitionPruningPredicates.reduce(expressions.And) + + val boundPredicate = InterpretedPredicate.create(predicate.transform { + case a: AttributeReference => + val index = partitionColumns.indexWhere(a.name == _.name) + BoundReference(index, partitionColumns(index).dataType, nullable = true) + }) + + val selected = partitions.filter { + case PartitionDirectory(values, _) => boundPredicate(values) + } + logInfo { + val total = partitions.length + val selectedSize = selected.length + val percentPruned = (1 - selectedSize.toDouble / total.toDouble) * 100 + s"Selected $selectedSize partitions out of $total, pruned $percentPruned% partitions." + } + + selected + } else { + partitions + } + } + + /** + * Contains a set of paths that are considered as the base dirs of the input datasets. + * The partitioning discovery logic will make sure it will stop when it reaches any + * base path. + * + * By default, the paths of the dataset provided by users will be base paths. + * Below are three typical examples, + * Case 1) `spark.read.parquet("/path/something=true/")`: the base path will be + * `/path/something=true/`, and the returned DataFrame will not contain a column of `something`. + * Case 2) `spark.read.parquet("/path/something=true/a.parquet")`: the base path will be + * still `/path/something=true/`, and the returned DataFrame will also not contain a column of + * `something`. + * Case 3) `spark.read.parquet("/path/")`: the base path will be `/path/`, and the returned + * DataFrame will have the column of `something`. + * + * Users also can override the basePath by setting `basePath` in the options to pass the new base + * path to the data source. + * For example, `spark.read.option("basePath", "/path/").parquet("/path/something=true/")`, + * and the returned DataFrame will have the column of `something`. 
+ */ + private def basePaths: Set[Path] = { + parameters.get("basePath").map(new Path(_)) match { + case Some(userDefinedBasePath) => + val fs = userDefinedBasePath.getFileSystem(hadoopConf) + if (!fs.isDirectory(userDefinedBasePath)) { + throw new IllegalArgumentException("Option 'basePath' must be a directory") + } + Set(fs.makeQualified(userDefinedBasePath)) + + case None => + paths.map { path => + // Make the path qualified (consistent with listLeafFiles and listLeafFilesInParallel). + val qualifiedPath = path.getFileSystem(hadoopConf).makeQualified(path) + if (leafFiles.contains(qualifiedPath)) qualifiedPath.getParent else qualifiedPath }.toSet + } + } + + private def isDataPath(path: Path): Boolean = { + val name = path.getName + !((name.startsWith("_") && !name.contains("=")) || name.startsWith(".")) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index 74f2993754f8f..504464216e5a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} import org.apache.spark.sql.types._ +// TODO: We should tighten up visibility of the classes here once we clean up Hive coupling. object PartitionDirectory { def apply(values: InternalRow, path: String): PartitionDirectory = @@ -41,22 +42,23 @@ object PartitionDirectory { * Holds a directory in a partitioned collection of files as well as as the partition values * in the form of a Row. Before scanning, the files at `path` need to be enumerated. */ -private[sql] case class PartitionDirectory(values: InternalRow, path: Path) +case class PartitionDirectory(values: InternalRow, path: Path) -private[sql] case class PartitionSpec( +case class PartitionSpec( partitionColumns: StructType, partitions: Seq[PartitionDirectory]) -private[sql] object PartitionSpec { +object PartitionSpec { val emptySpec = PartitionSpec(StructType(Seq.empty[StructField]), Seq.empty[PartitionDirectory]) } -private[sql] object PartitioningUtils { +object PartitioningUtils { // This duplicates default value of Hive `ConfVars.DEFAULTPARTITIONNAME`, since sql/core doesn't // depend on Hive. 
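The `basePath` behaviour documented in the scaladoc above, restated as a short usage sketch (paths hypothetical, `spark` assumed to be an active session):

```scala
// Cases 1 and 2: reading at or below a partition directory drops the `something` column ...
val withoutCol = spark.read.parquet("/path/something=true/")
// ... unless basePath is pointed at the table root, which keeps it (the user-override case).
val withCol = spark.read
  .option("basePath", "/path/")
  .parquet("/path/something=true/")
```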
- private[sql] val DEFAULT_PARTITION_NAME = "__HIVE_DEFAULT_PARTITION__" + val DEFAULT_PARTITION_NAME = "__HIVE_DEFAULT_PARTITION__" - private[sql] case class PartitionValues(columnNames: Seq[String], literals: Seq[Literal]) { + private[datasources] case class PartitionValues(columnNames: Seq[String], literals: Seq[Literal]) + { require(columnNames.size == literals.size) } @@ -83,7 +85,7 @@ private[sql] object PartitioningUtils { * path = "hdfs://:/path/to/partition/a=2/b=world/c=6.28"))) * }}} */ - private[sql] def parsePartitions( + private[datasources] def parsePartitions( paths: Seq[Path], defaultPartitionName: String, typeInference: Boolean, @@ -159,14 +161,14 @@ private[sql] object PartitioningUtils { * Seq( * Literal.create(42, IntegerType), * Literal.create("hello", StringType), - * Literal.create(3.14, FloatType))) + * Literal.create(3.14, DoubleType))) * }}} * and the path when we stop the discovery is: * {{{ * hdfs://:/path/to/partition * }}} */ - private[sql] def parsePartition( + private[datasources] def parsePartition( path: Path, defaultPartitionName: String, typeInference: Boolean, @@ -249,7 +251,7 @@ private[sql] object PartitioningUtils { * DoubleType -> StringType * }}} */ - private[sql] def resolvePartitions( + def resolvePartitions( pathsWithPartitionValues: Seq[(Path, PartitionValues)]): Seq[PartitionValues] = { if (pathsWithPartitionValues.isEmpty) { Seq.empty @@ -275,7 +277,7 @@ private[sql] object PartitioningUtils { } } - private[sql] def listConflictingPartitionColumns( + private[datasources] def listConflictingPartitionColumns( pathWithPartitionValues: Seq[(Path, PartitionValues)]): String = { val distinctPartColNames = pathWithPartitionValues.map(_._2.columnNames).distinct @@ -308,7 +310,7 @@ private[sql] object PartitioningUtils { * [[IntegerType]], [[LongType]], [[DoubleType]], [[DecimalType.SYSTEM_DEFAULT]], and * [[StringType]]. 
*/ - private[sql] def inferPartitionColumnValue( + private[datasources] def inferPartitionColumnValue( raw: String, defaultPartitionName: String, typeInference: Boolean): Literal = { @@ -339,7 +341,7 @@ private[sql] object PartitioningUtils { private val upCastingOrder: Seq[DataType] = Seq(NullType, IntegerType, LongType, FloatType, DoubleType, StringType) - def validatePartitionColumnDataTypes( + def validatePartitionColumn( schema: StructType, partitionColumns: Seq[String], caseSensitive: Boolean): Unit = { @@ -350,6 +352,10 @@ private[sql] object PartitioningUtils { case _ => throw new AnalysisException(s"Cannot use ${field.dataType} for partition column") } } + + if (partitionColumns.nonEmpty && partitionColumns.size == schema.fields.length) { + throw new AnalysisException(s"Cannot use all columns for partition columns") + } } def partitionColumnsSchema( @@ -359,7 +365,7 @@ private[sql] object PartitioningUtils { val equality = columnNameEquality(caseSensitive) StructType(partitionColumns.map { col => schema.find(f => equality(f.name, col)).getOrElse { - throw new RuntimeException(s"Partition column $col not found in schema $schema") + throw new AnalysisException(s"Partition column $col not found in schema $schema") } }).asNullable } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/RecordReaderIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/RecordReaderIterator.scala index f03ae94d55838..938af25a96844 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/RecordReaderIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/RecordReaderIterator.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources +import java.io.Closeable + import org.apache.hadoop.mapreduce.RecordReader import org.apache.spark.sql.catalyst.InternalRow @@ -27,7 +29,8 @@ import org.apache.spark.sql.catalyst.InternalRow * Note that this returns [[Object]]s instead of [[InternalRow]] because we rely on erasure to pass * column batches by pretending they are rows. */ -class RecordReaderIterator[T](rowReader: RecordReader[_, T]) extends Iterator[T] { +class RecordReaderIterator[T]( + private[this] var rowReader: RecordReader[_, T]) extends Iterator[T] with Closeable { private[this] var havePair = false private[this] var finished = false @@ -38,7 +41,7 @@ class RecordReaderIterator[T](rowReader: RecordReader[_, T]) extends Iterator[T] // Close and release the reader here; close() will also be called when the task // completes, but for tasks that read from many files, it helps to release the // resources early. - rowReader.close() + close() } havePair = !finished } @@ -52,4 +55,18 @@ class RecordReaderIterator[T](rowReader: RecordReader[_, T]) extends Iterator[T] havePair = false rowReader.getCurrentValue } + + override def close(): Unit = { + if (rowReader != null) { + try { + // Close the reader and release it. Note: it's very important that we don't close the + // reader more than once, since that exposes us to MAPREDUCE-5918 when running against + // older Hadoop 2.x releases. That bug can lead to non-deterministic corruption issues + // when reading compressed input. 
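One user-visible effect of the `validatePartitionColumn` change above is that partitioning by every column of the data is now rejected up front; a hedged example with hypothetical output paths, assuming `spark` is an active session:

```scala
val df = spark.range(10).selectExpr("id", "id % 2 AS bucket")
// Partitioning by a strict subset of the columns is fine.
df.write.partitionBy("bucket").parquet("/tmp/by_bucket")
// Partitioning by all columns now fails analysis with:
//   org.apache.spark.sql.AnalysisException: Cannot use all columns for partition columns
df.write.partitionBy("id", "bucket").parquet("/tmp/by_everything")
```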
+ rowReader.close() + } finally { + rowReader = null + } + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index 3b064a5bc489f..e25924b1ba1ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -32,20 +32,22 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.UnsafeKVExternalSorter +import org.apache.spark.sql.execution.command.CreateDataSourceTableUtils._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} import org.apache.spark.util.{SerializableConfiguration, Utils} +import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter /** A container for all the details required when writing to a table. */ -case class WriteRelation( +private[datasources] case class WriteRelation( sparkSession: SparkSession, dataSchema: StructType, path: String, prepareJobForWrite: Job => OutputWriterFactory, bucketSpec: Option[BucketSpec]) -private[sql] abstract class BaseWriterContainer( +private[datasources] abstract class BaseWriterContainer( @transient val relation: WriteRelation, @transient private val job: Job, isAppend: Boolean) @@ -91,7 +93,7 @@ private[sql] abstract class BaseWriterContainer( // This UUID is sent to executor side together with the serialized `Configuration` object within // the `Job` instance. `OutputWriters` on the executor side should use this UUID to generate // unique task output files. - job.getConfiguration.set("spark.sql.sources.writeJobUUID", uniqueWriteJobId.toString) + job.getConfiguration.set(DATASOURCE_WRITEJOBUUID, uniqueWriteJobId.toString) // Order of the following two lines is important. For Hadoop 1, TaskAttemptContext constructor // clones the Configuration object passed in. If we initialize the TaskAttemptContext first, @@ -176,7 +178,7 @@ private[sql] abstract class BaseWriterContainer( val ctor = clazz.getDeclaredConstructor(classOf[Path], classOf[TaskAttemptContext]) ctor.newInstance(new Path(outputPath), context) } else { - // The specified output committer is just a OutputCommitter. + // The specified output committer is just an OutputCommitter. // So, we will use the no-argument constructor. val ctor = clazz.getDeclaredConstructor() ctor.newInstance() @@ -232,7 +234,7 @@ private[sql] abstract class BaseWriterContainer( /** * A writer that writes all of the rows in a partition to a single file. */ -private[sql] class DefaultWriterContainer( +private[datasources] class DefaultWriterContainer( relation: WriteRelation, job: Job, isAppend: Boolean) @@ -241,7 +243,7 @@ private[sql] class DefaultWriterContainer( def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = { executorSideSetup(taskContext) val configuration = taskAttemptContext.getConfiguration - configuration.set("spark.sql.sources.output.path", outputPath) + configuration.set(DATASOURCE_OUTPUTPATH, outputPath) var writer = newOutputWriter(getWorkPath) writer.initConverter(dataSchema) @@ -291,7 +293,7 @@ private[sql] class DefaultWriterContainer( * done by maintaining a HashMap of open files until `maxFiles` is reached. 
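To make the write strategy sketched in the surrounding comment concrete (several writers kept open per partition key up to a cap, with the remainder buffered and written out afterwards, as the comment goes on to describe), here is a hedged, self-contained sketch; `SimpleWriter`, the string keys, and the in-memory overflow buffer are stand-ins for the real OutputWriter, partition paths, and external sorter:

```scala
import scala.collection.mutable

// Stand-in for an OutputWriter; the real code writes InternalRows to files.
trait SimpleWriter { def write(line: String): Unit; def close(): Unit }

class BoundedPartitionWriters(maxOpenWriters: Int, newWriter: String => SimpleWriter) {
  private val openWriters = mutable.Map.empty[String, SimpleWriter]
  // Stand-in for the external sorter that the real container spills to.
  private val overflow = mutable.ArrayBuffer.empty[(String, String)]

  def write(partitionKey: String, line: String): Unit = openWriters.get(partitionKey) match {
    case Some(w) => w.write(line)
    case None if openWriters.size < maxOpenWriters =>
      val w = newWriter(partitionKey)
      openWriters(partitionKey) = w
      w.write(line)
    case None =>
      // Too many open files: buffer the row and write it out later, one partition at a time.
      overflow += (partitionKey -> line)
  }

  def close(): Unit = {
    openWriters.values.foreach(_.close())
    overflow.groupBy(_._1).foreach { case (key, rows) =>
      val w = newWriter(key)
      rows.foreach { case (_, l) => w.write(l) }
      w.close()
    }
  }
}
```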
If this occurs, the * writer externally sorts the remaining rows and then writes out them out one file at a time. */ -private[sql] class DynamicPartitionWriterContainer( +private[datasources] class DynamicPartitionWriterContainer( relation: WriteRelation, job: Job, partitionColumns: Seq[Attribute], @@ -349,11 +351,10 @@ private[sql] class DynamicPartitionWriterContainer( val configuration = taskAttemptContext.getConfiguration val path = if (partitionColumns.nonEmpty) { val partitionPath = getPartitionString(key).getString(0) - configuration.set( - "spark.sql.sources.output.path", new Path(outputPath, partitionPath).toString) + configuration.set(DATASOURCE_OUTPUTPATH, new Path(outputPath, partitionPath).toString) new Path(getWorkPath, partitionPath).toString } else { - configuration.set("spark.sql.sources.output.path", outputPath) + configuration.set(DATASOURCE_OUTPUTPATH, outputPath) getWorkPath } val bucketId = getBucketIdFromKey(key) @@ -389,7 +390,9 @@ private[sql] class DynamicPartitionWriterContainer( StructType.fromAttributes(dataColumns), SparkEnv.get.blockManager, SparkEnv.get.serializerManager, - TaskContext.get().taskMemoryManager().pageSizeBytes) + TaskContext.get().taskMemoryManager().pageSizeBytes, + SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", + UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD)) while (iterator.hasNext) { val currentRow = iterator.next() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala index 6008d73717f77..2bafe967993b9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala @@ -31,7 +31,7 @@ private[sql] case class BucketSpec( bucketColumnNames: Seq[String], sortColumnNames: Seq[String]) -private[sql] object BucketingUtils { +object BucketingUtils { // The file name of bucketed data should have 3 parts: // 1. some other information in the head of file name // 2. 
bucket id part, some numbers, starts with "_" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala similarity index 77% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index 75143e609aaf7..107b6007ce467 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -25,20 +25,19 @@ import org.apache.hadoop.io.{LongWritable, Text} import org.apache.hadoop.mapred.TextInputFormat import org.apache.hadoop.mapreduce._ +import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.JoinedRow -import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.{StringType, StructField, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.util.SerializableConfiguration /** * Provides access to CSV data from pure SQL statements. */ -class DefaultSource extends FileFormat with DataSourceRegister { +class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { override def shortName(): String = "csv" @@ -46,7 +45,7 @@ class DefaultSource extends FileFormat with DataSourceRegister { override def hashCode(): Int = getClass.hashCode() - override def equals(other: Any): Boolean = other.isInstanceOf[DefaultSource] + override def equals(other: Any): Boolean = other.isInstanceOf[CSVFileFormat] override def inferSchema( sparkSession: SparkSession, @@ -61,9 +60,11 @@ class DefaultSource extends FileFormat with DataSourceRegister { val firstRow = new LineCsvReader(csvOptions).parseLine(firstLine) val header = if (csvOptions.headerFlag) { - firstRow + firstRow.zipWithIndex.map { case (value, index) => + if (value == null || value.isEmpty || value == csvOptions.nullValue) s"_c$index" else value + } } else { - firstRow.zipWithIndex.map { case (value, index) => s"C$index" } + firstRow.zipWithIndex.map { case (value, index) => s"_c$index" } } val parsedRdd = tokenRdd(sparkSession, csvOptions, header, paths) @@ -84,6 +85,7 @@ class DefaultSource extends FileFormat with DataSourceRegister { job: Job, options: Map[String, String], dataSchema: StructType): OutputWriterFactory = { + verifySchema(dataSchema) val conf = job.getConfiguration val csvOptions = new CSVOptions(options) csvOptions.compressionCodec.foreach { codec => @@ -110,26 +112,24 @@ class DefaultSource extends FileFormat with DataSourceRegister { (file: PartitionedFile) => { val lineIterator = { val conf = broadcastedHadoopConf.value.value - new HadoopFileLinesReader(file, conf).map { line => + val linesReader = new HadoopFileLinesReader(file, conf) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) + linesReader.map { line => new String(line.getBytes, 0, line.getLength, csvOptions.charset) } } CSVRelation.dropHeaderLine(file, lineIterator, csvOptions) - val unsafeRowIterator = { - val tokenizedIterator = new BulkCsvReader(lineIterator, csvOptions, headers) - val parser = 
CSVRelation.csvParser(dataSchema, requiredSchema.fieldNames, csvOptions) - tokenizedIterator.flatMap(parser(_).toSeq) - } - - // Appends partition values - val fullOutput = requiredSchema.toAttributes ++ partitionSchema.toAttributes - val joinedRow = new JoinedRow() - val appendPartitionColumns = GenerateUnsafeProjection.generate(fullOutput, fullOutput) - - unsafeRowIterator.map { dataRow => - appendPartitionColumns(joinedRow(dataRow, file.partitionValues)) + val tokenizedIterator = new BulkCsvReader(lineIterator, csvOptions, headers) + val parser = CSVRelation.csvParser(dataSchema, requiredSchema.fieldNames, csvOptions) + var numMalformedRecords = 0 + tokenizedIterator.flatMap { recordTokens => + val row = parser(recordTokens, numMalformedRecords) + if (row.isEmpty) { + numMalformedRecords += 1 + } + row } } } @@ -181,4 +181,20 @@ class DefaultSource extends FileFormat with DataSourceRegister { .mapPartitions(_.map(pair => new String(pair._2.getBytes, 0, pair._2.getLength, charset))) } } + + private def verifySchema(schema: StructType): Unit = { + def verifyType(dataType: DataType): Unit = dataType match { + case ByteType | ShortType | IntegerType | LongType | FloatType | + DoubleType | BooleanType | _: DecimalType | TimestampType | + DateType | StringType => + + case udt: UserDefinedType[_] => verifyType(udt.sqlType) + + case _ => + throw new UnsupportedOperationException( + s"CSV data source does not support ${dataType.simpleString} data type.") + } + + schema.foreach(field => verifyType(field.dataType)) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala index cfd66af18892b..3ab775c909238 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala @@ -18,14 +18,14 @@ package org.apache.spark.sql.execution.datasources.csv import java.math.BigDecimal -import java.text.{NumberFormat, SimpleDateFormat} +import java.text.NumberFormat import java.util.Locale import scala.util.control.Exception._ import scala.util.Try import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion +import org.apache.spark.sql.catalyst.analysis.TypeCoercion import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -85,6 +85,7 @@ private[csv] object CSVInferSchema { case NullType => tryParseInteger(field, options) case IntegerType => tryParseInteger(field, options) case LongType => tryParseLong(field, options) + case _: DecimalType => tryParseDecimal(field, options) case DoubleType => tryParseDouble(field, options) case TimestampType => tryParseTimestamp(field, options) case BooleanType => tryParseBoolean(field, options) @@ -107,10 +108,28 @@ private[csv] object CSVInferSchema { if ((allCatch opt field.toLong).isDefined) { LongType } else { - tryParseDouble(field, options) + tryParseDecimal(field, options) } } + private def tryParseDecimal(field: String, options: CSVOptions): DataType = { + val decimalTry = allCatch opt { + // `BigDecimal` conversion can fail when the `field` is not a form of number. + val bigDecimal = new BigDecimal(field) + // Because many other formats do not support decimal, it reduces the cases for + // decimals by disallowing values having scale (eg. `1.1`). 
+ if (bigDecimal.scale <= 0) { + // `DecimalType` conversion can fail when + // 1. The precision is bigger than 38. + // 2. The scale is bigger than the precision. + DecimalType(bigDecimal.precision, bigDecimal.scale) + } else { + tryParseDouble(field, options) + } + } + decimalTry.getOrElse(tryParseDouble(field, options)) + } + private def tryParseDouble(field: String, options: CSVOptions): DataType = { if ((allCatch opt field.toDouble).isDefined) { DoubleType @@ -120,20 +139,14 @@ private[csv] object CSVInferSchema { } private def tryParseTimestamp(field: String, options: CSVOptions): DataType = { - if (options.dateFormat != null) { - // This case infers a custom `dataFormat` is set. - if ((allCatch opt options.dateFormat.parse(field)).isDefined) { - TimestampType - } else { - tryParseBoolean(field, options) - } - } else { + // This case applies when a custom `timestampFormat` is set. + if ((allCatch opt options.timestampFormat.parse(field)).isDefined) { + TimestampType + } else if ((allCatch opt DateTimeUtils.stringToTime(field)).isDefined) { // We keep this for backwards compatibility. - if ((allCatch opt DateTimeUtils.stringToTime(field)).isDefined) { - TimestampType - } else { - tryParseBoolean(field, options) - } + TimestampType + } else { + tryParseBoolean(field, options) } } @@ -152,11 +165,11 @@ private[csv] object CSVInferSchema { StringType } - private val numericPrecedence: IndexedSeq[DataType] = HiveTypeCoercion.numericPrecedence + private val numericPrecedence: IndexedSeq[DataType] = TypeCoercion.numericPrecedence /** * Copied from internal Spark api - * [[org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion]] + * [[org.apache.spark.sql.catalyst.analysis.TypeCoercion]] */ val findTightestCommonType: (DataType, DataType) => Option[DataType] = { case (t1, t2) if t1 == t2 => Some(t1) @@ -170,6 +183,33 @@ private[csv] object CSVInferSchema { val index = numericPrecedence.lastIndexWhere(t => t == t1 || t == t2) Some(numericPrecedence(index)) + // These two cases below deal with when `DecimalType` is larger than `IntegralType`. + case (t1: IntegralType, t2: DecimalType) if t2.isWiderThan(t1) => + Some(t2) + case (t1: DecimalType, t2: IntegralType) if t1.isWiderThan(t2) => + Some(t1) + + // These two cases below deal with when `IntegralType` is larger than `DecimalType`. + case (t1: IntegralType, t2: DecimalType) => + findTightestCommonType(DecimalType.forType(t1), t2) + case (t1: DecimalType, t2: IntegralType) => + findTightestCommonType(t1, DecimalType.forType(t2)) + + // Double supports a larger range than fixed decimal; DecimalType.Maximum should be enough + // in most cases, and also has better precision.
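To make the `DecimalType` merging case that follows concrete, here is a small standalone calculation of the same rule (a sketch for illustration, not the patch itself; `mergeDecimals` is a hypothetical helper name):

```scala
import org.apache.spark.sql.types.{DataType, DecimalType, DoubleType}

// Keep the widest integral range and the largest scale; fall back to Double
// once the merged precision would exceed DecimalType's 38-digit limit.
def mergeDecimals(t1: DecimalType, t2: DecimalType): DataType = {
  val scale = math.max(t1.scale, t2.scale)
  val range = math.max(t1.precision - t1.scale, t2.precision - t2.scale)
  if (range + scale > 38) DoubleType else DecimalType(range + scale, scale)
}

// mergeDecimals(DecimalType(5, 2), DecimalType(10, 0))  // DecimalType(12, 2)
// mergeDecimals(DecimalType(38, 0), DecimalType(5, 3))  // DoubleType (38 + 3 > 38)
```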
+ case (DoubleType, _: DecimalType) | (_: DecimalType, DoubleType) => + Some(DoubleType) + + case (t1: DecimalType, t2: DecimalType) => + val scale = math.max(t1.scale, t2.scale) + val range = math.max(t1.precision - t1.scale, t2.precision - t2.scale) + if (range + scale > 38) { + // DecimalType can't support precision > 38 + Some(DoubleType) + } else { + Some(DecimalType(range + scale, scale)) + } + case _ => None } } @@ -192,59 +232,58 @@ private[csv] object CSVTypeCast { nullable: Boolean = true, options: CSVOptions = CSVOptions()): Any = { - castType match { - case _: ByteType => if (datum == options.nullValue && nullable) null else datum.toByte - case _: ShortType => if (datum == options.nullValue && nullable) null else datum.toShort - case _: IntegerType => if (datum == options.nullValue && nullable) null else datum.toInt - case _: LongType => if (datum == options.nullValue && nullable) null else datum.toLong - case _: FloatType => - if (datum == options.nullValue && nullable) { - null - } else if (datum == options.nanValue) { - Float.NaN - } else if (datum == options.negativeInf) { - Float.NegativeInfinity - } else if (datum == options.positiveInf) { - Float.PositiveInfinity - } else { - Try(datum.toFloat) - .getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).floatValue()) - } - case _: DoubleType => - if (datum == options.nullValue && nullable) { - null - } else if (datum == options.nanValue) { - Double.NaN - } else if (datum == options.negativeInf) { - Double.NegativeInfinity - } else if (datum == options.positiveInf) { - Double.PositiveInfinity - } else { - Try(datum.toDouble) - .getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).doubleValue()) - } - case _: BooleanType => datum.toBoolean - case dt: DecimalType => - if (datum == options.nullValue && nullable) { - null - } else { + if (nullable && datum == options.nullValue) { + null + } else { + castType match { + case _: ByteType => datum.toByte + case _: ShortType => datum.toShort + case _: IntegerType => datum.toInt + case _: LongType => datum.toLong + case _: FloatType => + datum match { + case options.nanValue => Float.NaN + case options.negativeInf => Float.NegativeInfinity + case options.positiveInf => Float.PositiveInfinity + case _ => + Try(datum.toFloat) + .getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).floatValue()) + } + case _: DoubleType => + datum match { + case options.nanValue => Double.NaN + case options.negativeInf => Double.NegativeInfinity + case options.positiveInf => Double.PositiveInfinity + case _ => + Try(datum.toDouble) + .getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).doubleValue()) + } + case _: BooleanType => datum.toBoolean + case dt: DecimalType => val value = new BigDecimal(datum.replaceAll(",", "")) Decimal(value, dt.precision, dt.scale) - } - case _: TimestampType if options.dateFormat != null => - // This one will lose microseconds parts. - // See https://issues.apache.org/jira/browse/SPARK-10681. - options.dateFormat.parse(datum).getTime * 1000L - case _: TimestampType => - // This one will lose microseconds parts. - // See https://issues.apache.org/jira/browse/SPARK-10681. 
- DateTimeUtils.stringToTime(datum).getTime * 1000L - case _: DateType if options.dateFormat != null => - DateTimeUtils.millisToDays(options.dateFormat.parse(datum).getTime) - case _: DateType => - DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(datum).getTime) - case _: StringType => UTF8String.fromString(datum) - case _ => throw new RuntimeException(s"Unsupported type: ${castType.typeName}") + case _: TimestampType => + // This one will lose microseconds parts. + // See https://issues.apache.org/jira/browse/SPARK-10681. + Try(options.timestampFormat.parse(datum).getTime * 1000L) + .getOrElse { + // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards + // compatibility. + DateTimeUtils.stringToTime(datum).getTime * 1000L + } + case _: DateType => + // This one will lose microseconds parts. + // See https://issues.apache.org/jira/browse/SPARK-10681.x + Try(DateTimeUtils.millisToDays(options.dateFormat.parse(datum).getTime)) + .getOrElse { + // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards + // compatibility. + DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(datum).getTime) + } + case _: StringType => UTF8String.fromString(datum) + case udt: UserDefinedType[_] => castTo(datum, udt.sqlType, nullable, options) + case _ => throw new RuntimeException(s"Unsupported type: ${castType.typeName}") + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index e4fd09462fabd..364d7c831eb44 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -18,12 +18,13 @@ package org.apache.spark.sql.execution.datasources.csv import java.nio.charset.StandardCharsets -import java.text.SimpleDateFormat + +import org.apache.commons.lang3.time.FastDateFormat import org.apache.spark.internal.Logging import org.apache.spark.sql.execution.datasources.{CompressionCodecs, ParseModes} -private[sql] class CSVOptions(@transient private val parameters: Map[String, String]) +private[csv] class CSVOptions(@transient private val parameters: Map[String, String]) extends Logging with Serializable { private def getChar(paramName: String, default: Char): Char = { @@ -101,16 +102,24 @@ private[sql] class CSVOptions(@transient private val parameters: Map[String, Str name.map(CompressionCodecs.getCodecClassName) } - // Share date format object as it is expensive to parse date pattern. - val dateFormat: SimpleDateFormat = { - val dateFormat = parameters.get("dateFormat") - dateFormat.map(new SimpleDateFormat(_)).orNull - } + // Uses `FastDateFormat` which can be direct replacement for `SimpleDateFormat` and thread-safe. 
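For context on the switch to `FastDateFormat` just below, a brief hedged usage sketch; the pattern and sample value are only illustrative:

```scala
import org.apache.commons.lang3.time.FastDateFormat

// FastDateFormat instances are immutable and thread-safe, so a single shared
// instance can parse and format across tasks (unlike SimpleDateFormat).
val dateFmt = FastDateFormat.getInstance("yyyy-MM-dd")
val parsed: java.util.Date = dateFmt.parse("2016-06-01")
val rendered: String = dateFmt.format(parsed)
```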
+ val dateFormat: FastDateFormat = + FastDateFormat.getInstance(parameters.getOrElse("dateFormat", "yyyy-MM-dd")) + + val timestampFormat: FastDateFormat = + FastDateFormat.getInstance( + parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSZZ")) val maxColumns = getInt("maxColumns", 20480) val maxCharsPerColumn = getInt("maxCharsPerColumn", 1000000) + val escapeQuotes = getBool("escapeQuotes", true) + + val maxMalformedLogPerPartition = getInt("maxMalformedLogPerPartition", 10) + + val quoteAll = getBool("quoteAll", false) + val inputBufferSize = 128 val isCommentSet = this.comment != '\u0000' diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala index c3d863f547dab..0a996547d2536 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.datasources.csv import java.io.{ByteArrayOutputStream, OutputStreamWriter, StringReader} import java.nio.charset.StandardCharsets -import com.univocity.parsers.csv.{CsvParser, CsvParserSettings, CsvWriter, CsvWriterSettings} +import com.univocity.parsers.csv._ import org.apache.spark.internal.Logging @@ -30,7 +30,7 @@ import org.apache.spark.internal.Logging * @param params Parameters object * @param headers headers for the columns */ -private[sql] abstract class CsvReader(params: CSVOptions, headers: Seq[String]) { +private[csv] abstract class CsvReader(params: CSVOptions, headers: Seq[String]) { protected lazy val parser: CsvParser = { val settings = new CsvParserSettings() @@ -47,7 +47,7 @@ private[sql] abstract class CsvReader(params: CSVOptions, headers: Seq[String]) settings.setMaxColumns(params.maxColumns) settings.setNullValue(params.nullValue) settings.setMaxCharsPerColumn(params.maxCharsPerColumn) - settings.setParseUnescapedQuotesUntilDelimiter(true) + settings.setUnescapedQuoteHandling(UnescapedQuoteHandling.STOP_AT_DELIMITER) if (headers != null) settings.setHeaders(headers: _*) new CsvParser(settings) @@ -60,7 +60,7 @@ private[sql] abstract class CsvReader(params: CSVOptions, headers: Seq[String]) * @param params Parameters object for configuration * @param headers headers for columns */ -private[sql] class LineCsvWriter(params: CSVOptions, headers: Seq[String]) extends Logging { +private[csv] class LineCsvWriter(params: CSVOptions, headers: Seq[String]) extends Logging { private val writerSettings = new CsvWriterSettings private val format = writerSettings.getFormat @@ -73,20 +73,30 @@ private[sql] class LineCsvWriter(params: CSVOptions, headers: Seq[String]) exten writerSettings.setNullValue(params.nullValue) writerSettings.setEmptyValue(params.nullValue) writerSettings.setSkipEmptyLines(true) - writerSettings.setQuoteAllFields(false) + writerSettings.setQuoteAllFields(params.quoteAll) writerSettings.setHeaders(headers: _*) + writerSettings.setQuoteEscapingEnabled(params.escapeQuotes) - def writeRow(row: Seq[String], includeHeader: Boolean): String = { - val buffer = new ByteArrayOutputStream() - val outputWriter = new OutputStreamWriter(buffer, StandardCharsets.UTF_8) - val writer = new CsvWriter(outputWriter, writerSettings) + private var buffer = new ByteArrayOutputStream() + private var writer = new CsvWriter( + new OutputStreamWriter(buffer, StandardCharsets.UTF_8), + writerSettings) + def 
writeRow(row: Seq[String], includeHeader: Boolean): Unit = { if (includeHeader) { writer.writeHeaders() } writer.writeRow(row.toArray: _*) + } + + def flush(): String = { writer.close() - buffer.toString.stripLineEnd + val lines = buffer.toString.stripLineEnd + buffer = new ByteArrayOutputStream() + writer = new CsvWriter( + new OutputStreamWriter(buffer, StandardCharsets.UTF_8), + writerSettings) + lines } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala index 4f2d4387b180f..d0d5ce06cf8b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala @@ -30,6 +30,8 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericMutableRow +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.execution.command.CreateDataSourceTableUtils import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory, PartitionedFile} import org.apache.spark.sql.types._ @@ -49,10 +51,19 @@ object CSVRelation extends Logging { } } + /** + * Returns a function that parses a single CSV record (in the form of an array of strings in which + * each element represents a column) and turns it into either one resulting row or no row (if the + * the record is malformed). + * + * The 2nd argument in the returned function represents the total number of malformed rows + * observed so far. + */ + // This is pretty convoluted and we should probably rewrite the entire CSV parsing soon. def csvParser( schema: StructType, requiredColumns: Array[String], - params: CSVOptions): Array[String] => Option[InternalRow] = { + params: CSVOptions): (Array[String], Int) => Option[InternalRow] = { val schemaFields = schema.fields val requiredFields = StructType(requiredColumns.map(schema(_))).fields val safeRequiredFields = if (params.dropMalformed) { @@ -71,9 +82,16 @@ object CSVRelation extends Logging { val requiredSize = requiredFields.length val row = new GenericMutableRow(requiredSize) - (tokens: Array[String]) => { + (tokens: Array[String], numMalformedRows) => { if (params.dropMalformed && schemaFields.length != tokens.length) { - logWarning(s"Dropping malformed line: ${tokens.mkString(params.delimiter.toString)}") + if (numMalformedRows < params.maxMalformedLogPerPartition) { + logWarning(s"Dropping malformed line: ${tokens.mkString(params.delimiter.toString)}") + } + if (numMalformedRows == params.maxMalformedLogPerPartition - 1) { + logWarning( + s"More than ${params.maxMalformedLogPerPartition} malformed records have been " + + "found on this partition. Malformed records from now on will not be logged.") + } None } else if (params.failFast && schemaFields.length != tokens.length) { throw new RuntimeException(s"Malformed line in FAILFAST mode: " + @@ -108,23 +126,21 @@ object CSVRelation extends Logging { Some(row) } catch { case NonFatal(e) if params.dropMalformed => - logWarning("Parse exception. " + - s"Dropping malformed line: ${tokens.mkString(params.delimiter.toString)}") + if (numMalformedRows < params.maxMalformedLogPerPartition) { + logWarning("Parse exception. 
" + + s"Dropping malformed line: ${tokens.mkString(params.delimiter.toString)}") + } + if (numMalformedRows == params.maxMalformedLogPerPartition - 1) { + logWarning( + s"More than ${params.maxMalformedLogPerPartition} malformed records have been " + + "found on this partition. Malformed records from now on will not be logged.") + } None } } } } - def parseCsv( - tokenizedRDD: RDD[Array[String]], - schema: StructType, - requiredColumns: Array[String], - options: CSVOptions): RDD[InternalRow] = { - val parser = csvParser(schema, requiredColumns, options) - tokenizedRDD.flatMap(parser(_).toSeq) - } - // Skips the header line of each file if the `header` option is set to true. def dropHeaderLine( file: PartitionedFile, lines: Iterator[String], csvOptions: CSVOptions): Unit = { @@ -144,7 +160,7 @@ object CSVRelation extends Logging { } } -private[sql] class CSVOutputWriterFactory(params: CSVOptions) extends OutputWriterFactory { +private[csv] class CSVOutputWriterFactory(params: CSVOptions) extends OutputWriterFactory { override def newInstance( path: String, bucketId: Option[Int], @@ -155,7 +171,7 @@ private[sql] class CSVOutputWriterFactory(params: CSVOptions) extends OutputWrit } } -private[sql] class CsvOutputWriter( +private[csv] class CsvOutputWriter( path: String, dataSchema: StructType, context: TaskAttemptContext, @@ -164,11 +180,19 @@ private[sql] class CsvOutputWriter( // create the Generator without separator inserted between 2 records private[this] val text = new Text() + // A `ValueConverter` is responsible for converting a value of an `InternalRow` to `String`. + // When the value is null, this converter should not be called. + private type ValueConverter = (InternalRow, Int) => String + + // `ValueConverter`s for all values in the fields of the schema + private val valueConverters: Array[ValueConverter] = + dataSchema.map(_.dataType).map(makeConverter).toArray + private val recordWriter: RecordWriter[NullWritable, Text] = { new TextOutputFormat[NullWritable, Text]() { override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { val configuration = context.getConfiguration - val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID") + val uniqueWriteJobId = configuration.get(CreateDataSourceTableUtils.DATASOURCE_WRITEJOBUUID) val taskAttemptId = context.getTaskAttemptID val split = taskAttemptId.getTaskID.getId new Path(path, f"part-r-$split%05d-$uniqueWriteJobId.csv$extension") @@ -176,31 +200,60 @@ private[sql] class CsvOutputWriter( }.getRecordWriter(context) } - private var firstRow: Boolean = params.headerFlag - + private val FLUSH_BATCH_SIZE = 1024L + private var records: Long = 0L private val csvWriter = new LineCsvWriter(params, dataSchema.fieldNames.toSeq) - private def rowToString(row: Seq[Any]): Seq[String] = row.map { field => - if (field != null) { - field.toString - } else { - params.nullValue + private def rowToString(row: InternalRow): Seq[String] = { + var i = 0 + val values = new Array[String](row.numFields) + while (i < row.numFields) { + if (!row.isNullAt(i)) { + values(i) = valueConverters(i).apply(row, i) + } else { + values(i) = params.nullValue + } + i += 1 } + values + } + + private def makeConverter(dataType: DataType): ValueConverter = dataType match { + case DateType => + (row: InternalRow, ordinal: Int) => + params.dateFormat.format(DateTimeUtils.toJavaDate(row.getInt(ordinal))) + + case TimestampType => + (row: InternalRow, ordinal: Int) => + 
params.timestampFormat.format(DateTimeUtils.toJavaTimestamp(row.getLong(ordinal))) + + case udt: UserDefinedType[_] => makeConverter(udt.sqlType) + + case dt: DataType => + (row: InternalRow, ordinal: Int) => + row.get(ordinal, dt).toString } override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal") override protected[sql] def writeInternal(row: InternalRow): Unit = { - // TODO: Instead of converting and writing every row, we should use the univocity buffer - val resultString = csvWriter.writeRow(rowToString(row.toSeq(dataSchema)), firstRow) - if (firstRow) { - firstRow = false + csvWriter.writeRow(rowToString(row), records == 0L && params.headerFlag) + records += 1 + if (records % FLUSH_BATCH_SIZE == 0) { + flush() + } + } + + private def flush(): Unit = { + val lines = csvWriter.flush() + if (lines.nonEmpty) { + text.set(lines) + recordWriter.write(NullWritable.get(), text) } - text.set(resultString) - recordWriter.write(NullWritable.get(), text) } override def close(): Unit = { + flush() recordWriter.close(context) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index 7d0a3d9756e90..857047f820ce1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.command.RunnableCommand @@ -38,34 +38,31 @@ case class CreateTableUsing( provider: String, temporary: Boolean, options: Map[String, String], + partitionColumns: Array[String], + bucketSpec: Option[BucketSpec], allowExisting: Boolean, - managedIfNoPath: Boolean) extends LogicalPlan with logical.Command { - - override def output: Seq[Attribute] = Seq.empty - override def children: Seq[LogicalPlan] = Seq.empty -} + managedIfNoPath: Boolean) extends logical.Command /** * A node used to support CTAS statements and saveAsTable for the data source API. - * This node is a [[logical.UnaryNode]] instead of a [[logical.Command]] because we want the - * analyzer can analyze the logical plan that will be used to populate the table. - * So, [[PreWriteCheck]] can detect cases that are not allowed. 
*/ case class CreateTableUsingAsSelect( tableIdent: TableIdentifier, provider: String, - temporary: Boolean, partitionColumns: Array[String], bucketSpec: Option[BucketSpec], mode: SaveMode, options: Map[String, String], - child: LogicalPlan) extends logical.UnaryNode { - override def output: Seq[Attribute] = Seq.empty[Attribute] + query: LogicalPlan) extends logical.Command { + + override def innerChildren: Seq[QueryPlan[_]] = Seq(query) + override lazy val resolved: Boolean = query.resolved } -case class CreateTempTableUsing( +case class CreateTempViewUsing( tableIdent: TableIdentifier, userSpecifiedSchema: Option[StructType], + replace: Boolean, provider: String, options: Map[String, String]) extends RunnableCommand { @@ -80,68 +77,31 @@ case class CreateTempTableUsing( userSpecifiedSchema = userSpecifiedSchema, className = provider, options = options) - sparkSession.sessionState.catalog.createTempTable( + sparkSession.sessionState.catalog.createTempView( tableIdent.table, Dataset.ofRows(sparkSession, LogicalRelation(dataSource.resolveRelation())).logicalPlan, - overrideIfExists = true) + replace) Seq.empty[Row] } } -case class CreateTempTableUsingAsSelect( - tableIdent: TableIdentifier, - provider: String, - partitionColumns: Array[String], - mode: SaveMode, - options: Map[String, String], - query: LogicalPlan) extends RunnableCommand { - - if (tableIdent.database.isDefined) { - throw new AnalysisException( - s"Temporary table '$tableIdent' should not have specified a database") - } +case class RefreshTable(tableIdent: TableIdentifier) + extends RunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { - val df = Dataset.ofRows(sparkSession, query) - val dataSource = DataSource( - sparkSession, - className = provider, - partitionColumns = partitionColumns, - bucketSpec = None, - options = options) - val result = dataSource.write(mode, df) - sparkSession.sessionState.catalog.createTempTable( - tableIdent.table, - Dataset.ofRows(sparkSession, LogicalRelation(result)).logicalPlan, - overrideIfExists = true) - + // Refresh the given table's metadata. If this table is cached as an InMemoryRelation, + // drop the original cached version and make the new version cached lazily. + sparkSession.catalog.refreshTable(tableIdent.quotedString) Seq.empty[Row] } } -case class RefreshTable(tableIdent: TableIdentifier) +case class RefreshResource(path: String) extends RunnableCommand { override def run(sparkSession: SparkSession): Seq[Row] = { - // Refresh the given table's metadata first. - sparkSession.sessionState.catalog.refreshTable(tableIdent) - - // If this table is cached as a InMemoryColumnarRelation, drop the original - // cached version and make the new version cached lazily. - val logicalPlan = sparkSession.sessionState.catalog.lookupRelation(tableIdent) - // Use lookupCachedData directly since RefreshTable also takes databaseName. - val isCached = sparkSession.cacheManager.lookupCachedData(logicalPlan).nonEmpty - if (isCached) { - // Create a data frame to represent the table. - // TODO: Use uncacheTable once it supports database name. - val df = Dataset.ofRows(sparkSession, logicalPlan) - // Uncache the logicalPlan. - sparkSession.cacheManager.tryUncacheQuery(df, blocking = true) - // Cache it again. 
- sparkSession.cacheManager.cacheQuery(df, Some(tableIdent.table)) - } - + sparkSession.catalog.refreshByPath(path) Seq.empty[Row] } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala index 25f88d9c39487..ea614e55b540a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala @@ -22,18 +22,19 @@ import scala.util.Try import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs._ +import org.apache.hadoop.io.compress.{CompressionCodecFactory, SplittableCompressionCodec} import org.apache.hadoop.mapred.{FileInputFormat, JobConf} import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} -import org.apache.spark.SparkContext import org.apache.spark.annotation.Experimental import org.apache.spark.internal.Logging import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.{expressions, CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.execution.FileRelation -import org.apache.spark.sql.sources.{BaseRelation, Filter} -import org.apache.spark.sql.types.{StringType, StructType} +import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, Filter} +import org.apache.spark.sql.types.StructType import org.apache.spark.util.SerializableConfiguration /** @@ -64,6 +65,20 @@ abstract class OutputWriterFactory extends Serializable { bucketId: Option[Int], // TODO: This doesn't belong here... dataSchema: StructType, context: TaskAttemptContext): OutputWriter + + /** + * Returns a new instance of [[OutputWriter]] that will write data to the given path. + * This method gets called by each task on executor to write [[InternalRow]]s to + * format-specific files. Compared to the other `newInstance()`, this is a newer API that + * passes only the path that the writer must write to. The writer must write to the exact path + * and not modify it (do not add subdirectories, extensions, etc.). All other + * file-format-specific information needed to create the writer must be passed + * through the [[OutputWriterFactory]] implementation. + * @since 2.0.0 + */ + def newWriter(path: String): OutputWriter = { + throw new UnsupportedOperationException("newInstance with just path not supported") + } } /** @@ -119,15 +134,15 @@ abstract class OutputWriter { * @param options Configuration used when reading / writing data. 
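Stepping back to the `RefreshTable` and `RefreshResource` commands rewritten above, a hedged usage sketch of what they delegate to; the `spark` session, table name, and path are placeholders:

```scala
// After external processes add or rewrite files, refresh the cached metadata
// (and any cached data) before querying again.
spark.catalog.refreshTable("events")         // what RefreshTable delegates to
spark.catalog.refreshByPath("/data/events")  // what RefreshResource delegates to
```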
*/ case class HadoopFsRelation( - sparkSession: SparkSession, location: FileCatalog, partitionSchema: StructType, dataSchema: StructType, bucketSpec: Option[BucketSpec], fileFormat: FileFormat, - options: Map[String, String]) extends BaseRelation with FileRelation { + options: Map[String, String])(val sparkSession: SparkSession) + extends BaseRelation with FileRelation { - override def sqlContext: SQLContext = sparkSession.wrapped + override def sqlContext: SQLContext = sparkSession.sqlContext val schema: StructType = { val dataSchemaColumnNames = dataSchema.map(_.name.toLowerCase).toSet @@ -142,8 +157,12 @@ case class HadoopFsRelation( def refresh(): Unit = location.refresh() - override def toString: String = - s"HadoopFiles" + override def toString: String = { + fileFormat match { + case source: DataSourceRegister => source.shortName() + case _ => "HadoopFiles" + } + } /** Returns the list of files that will be read when scanning this relation. */ override def inputFiles: Array[String] = @@ -166,15 +185,6 @@ trait FileFormat { options: Map[String, String], files: Seq[FileStatus]): Option[StructType] - /** - * Prepares a read job and returns a potentially updated data source option [[Map]]. This method - * can be useful for collecting necessary global information for scanning input data. - */ - def prepareRead( - sparkSession: SparkSession, - options: Map[String, String], - files: Seq[FileStatus]): Map[String, String] = options - /** * Prepares a write job and returns an [[OutputWriterFactory]]. Client side job preparation can * be put here. For example, user defined output committer can be configured here @@ -195,6 +205,16 @@ trait FileFormat { false } + /** + * Returns whether a file with `path` could be splitted or not. + */ + def isSplitable( + sparkSession: SparkSession, + options: Map[String, String], + path: Path): Boolean = { + false + } + /** * Returns a function that can be used to read a single file in as an Iterator of InternalRow. * @@ -223,6 +243,77 @@ trait FileFormat { // Until then we guard in [[FileSourceStrategy]] to only call this method on supported formats. throw new UnsupportedOperationException(s"buildReader is not supported for $this") } + + /** + * Exactly the same as [[buildReader]] except that the reader function returned by this method + * appends partition values to [[InternalRow]]s produced by the reader function [[buildReader]] + * returns. + */ + def buildReaderWithPartitionValues( + sparkSession: SparkSession, + dataSchema: StructType, + partitionSchema: StructType, + requiredSchema: StructType, + filters: Seq[Filter], + options: Map[String, String], + hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = { + val dataReader = buildReader( + sparkSession, dataSchema, partitionSchema, requiredSchema, filters, options, hadoopConf) + + new (PartitionedFile => Iterator[InternalRow]) with Serializable { + private val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes + + private val joinedRow = new JoinedRow() + + // Using lazy val to avoid serialization + private lazy val appendPartitionColumns = + GenerateUnsafeProjection.generate(fullSchema, fullSchema) + + override def apply(file: PartitionedFile): Iterator[InternalRow] = { + // Using local val to avoid per-row lazy val check (pre-mature optimization?...) + val converter = appendPartitionColumns + + // Note that we have to apply the converter even though `file.partitionValues` is empty. 

+ // This is because the converter is also responsible for converting safe `InternalRow`s into + // `UnsafeRow`s. + dataReader(file).map { dataRow => + converter(joinedRow(dataRow, file.partitionValues)) + } + } + } + } + + /** + * Returns a [[OutputWriterFactory]] for generating output writers that can write data. + * This method is currently used only by FileStreamSinkWriter to generate output writers that + * do not use output committers to write data. The OutputWriter generated by the returned + * [[OutputWriterFactory]] must implement the method `newWriter(path)`. + */ + def buildWriter( + sqlContext: SQLContext, + dataSchema: StructType, + options: Map[String, String]): OutputWriterFactory = { + // TODO: Remove this default implementation when the other formats have been ported + throw new UnsupportedOperationException(s"buildWriter is not supported for $this") + } +} + +/** + * The base class for file formats that are based on text files. + */ +abstract class TextBasedFileFormat extends FileFormat { + private var codecFactory: CompressionCodecFactory = null + override def isSplitable( + sparkSession: SparkSession, + options: Map[String, String], + path: Path): Boolean = { + if (codecFactory == null) { + codecFactory = new CompressionCodecFactory( + sparkSession.sessionState.newHadoopConfWithOptions(options)) + } + val codec = codecFactory.getCodec(path) + codec == null || codec.isInstanceOf[SplittableCompressionCodec] + } } /** @@ -236,245 +327,59 @@ case class Partition(values: InternalRow, files: Seq[FileStatus]) * as the partitioning characteristics of those files. */ trait FileCatalog { + + /** Returns the list of input paths from which the catalog will get files. */ def paths: Seq[Path] + /** Returns the specification of the partitions inferred from the data. */ def partitionSpec(): PartitionSpec /** * Returns all valid files grouped into partitions when the data is partitioned. If the data is - * unpartitioned, this will return a single partition with not partition values. + * unpartitioned, this will return a single partition with no partition values. * - * @param filters the filters used to prune which partitions are returned. These filters must + * @param filters The filters used to prune which partitions are returned. These filters must * only refer to partition columns and this method will only return files * where these predicates are guaranteed to evaluate to `true`. Thus, these * filters will not need to be evaluated again on the returned data. */ def listFiles(filters: Seq[Expression]): Seq[Partition] + /** Returns all the valid files. */ def allFiles(): Seq[FileStatus] - def getStatus(path: Path): Array[FileStatus] - + /** Refresh the file listing */ def refresh(): Unit } + /** - * A file catalog that caches metadata gathered by scanning all the files present in `paths` - * recursively. - * - * @param parameters as set of options to control discovery - * @param paths a list of paths to scan - * @param partitionSchema an optional partition schema that will be use to provide types for the - * discovered partitions + * Helper methods for gathering metadata from HDFS.
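To illustrate the `TextBasedFileFormat.isSplitable` check introduced above in isolation, here is a hedged sketch using Hadoop's codec factory directly; the file paths are hypothetical:

```scala
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.compress.{CompressionCodecFactory, SplittableCompressionCodec}

// A file is splittable when it is uncompressed, or when its codec supports splitting.
def isSplittable(path: Path, conf: Configuration): Boolean = {
  val codec = new CompressionCodecFactory(conf).getCodec(path)
  codec == null || codec.isInstanceOf[SplittableCompressionCodec]
}

// isSplittable(new Path("/data/part-00000.csv"), new Configuration())     // true (no codec)
// isSplittable(new Path("/data/part-00000.csv.gz"), new Configuration())  // typically false (gzip)
// isSplittable(new Path("/data/part-00000.csv.bz2"), new Configuration()) // typically true (bzip2)
```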
*/ -class HDFSFileCatalog( - sparkSession: SparkSession, - parameters: Map[String, String], - override val paths: Seq[Path], - partitionSchema: Option[StructType]) - extends FileCatalog with Logging { - - private val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(parameters) - - var leafFiles = mutable.LinkedHashMap.empty[Path, FileStatus] - var leafDirToChildrenFiles = mutable.Map.empty[Path, Array[FileStatus]] - var cachedPartitionSpec: PartitionSpec = _ - - def partitionSpec(): PartitionSpec = { - if (cachedPartitionSpec == null) { - cachedPartitionSpec = inferPartitioning(partitionSchema) - } - - cachedPartitionSpec - } - - refresh() +object HadoopFsRelation extends Logging { - override def listFiles(filters: Seq[Expression]): Seq[Partition] = { - if (partitionSpec().partitionColumns.isEmpty) { - Partition(InternalRow.empty, allFiles().filterNot(_.getPath.getName startsWith "_")) :: Nil - } else { - prunePartitions(filters, partitionSpec()).map { - case PartitionDirectory(values, path) => - Partition( - values, - getStatus(path).filterNot(_.getPath.getName startsWith "_")) - } - } - } - - protected def prunePartitions( - predicates: Seq[Expression], - partitionSpec: PartitionSpec): Seq[PartitionDirectory] = { - val PartitionSpec(partitionColumns, partitions) = partitionSpec - val partitionColumnNames = partitionColumns.map(_.name).toSet - val partitionPruningPredicates = predicates.filter { - _.references.map(_.name).toSet.subsetOf(partitionColumnNames) - } - - if (partitionPruningPredicates.nonEmpty) { - val predicate = partitionPruningPredicates.reduce(expressions.And) - - val boundPredicate = InterpretedPredicate.create(predicate.transform { - case a: AttributeReference => - val index = partitionColumns.indexWhere(a.name == _.name) - BoundReference(index, partitionColumns(index).dataType, nullable = true) - }) - - val selected = partitions.filter { - case PartitionDirectory(values, _) => boundPredicate(values) - } - logInfo { - val total = partitions.length - val selectedSize = selected.length - val percentPruned = (1 - selectedSize.toDouble / total.toDouble) * 100 - s"Selected $selectedSize partitions out of $total, pruned $percentPruned% partitions." - } - - selected - } else { - partitions - } - } - - def allFiles(): Seq[FileStatus] = leafFiles.values.toSeq - - def getStatus(path: Path): Array[FileStatus] = leafDirToChildrenFiles(path) - - private def listLeafFiles(paths: Seq[Path]): mutable.LinkedHashSet[FileStatus] = { - if (paths.length >= sparkSession.sessionState.conf.parallelPartitionDiscoveryThreshold) { - HadoopFsRelation.listLeafFilesInParallel(paths, hadoopConf, sparkSession.sparkContext) - } else { - val statuses: Seq[FileStatus] = paths.flatMap { path => - val fs = path.getFileSystem(hadoopConf) - logInfo(s"Listing $path on driver") - // Dummy jobconf to get to the pathFilter defined in configuration - val jobConf = new JobConf(hadoopConf, this.getClass) - val pathFilter = FileInputFormat.getInputPathFilter(jobConf) - - val statuses = { - val stats = Try(fs.listStatus(path)).getOrElse(Array.empty[FileStatus]) - if (pathFilter != null) stats.filter(f => pathFilter.accept(f.getPath)) else stats - } - - statuses.map { - case f: LocatedFileStatus => f - - // NOTE: - // - // - Although S3/S3A/S3N file system can be quite slow for remote file metadata - // operations, calling `getFileBlockLocations` does no harm here since these file system - // implementations don't actually issue RPC for this method. 
- // - // - Here we are calling `getFileBlockLocations` in a sequential manner, but it should a - // a big deal since we always use to `listLeafFilesInParallel` when the number of paths - // exceeds threshold. - case f => new LocatedFileStatus(f, fs.getFileBlockLocations(f, 0, f.getLen)) - } - }.filterNot { status => - val name = status.getPath.getName - HadoopFsRelation.shouldFilterOut(name) - } - - val (dirs, files) = statuses.partition(_.isDirectory) - - // It uses [[LinkedHashSet]] since the order of files can affect the results. (SPARK-11500) - if (dirs.isEmpty) { - mutable.LinkedHashSet(files: _*) - } else { - mutable.LinkedHashSet(files: _*) ++ listLeafFiles(dirs.map(_.getPath)) - } - } - } - - def inferPartitioning(schema: Option[StructType]): PartitionSpec = { - // We use leaf dirs containing data files to discover the schema. - val leafDirs = leafDirToChildrenFiles.keys.toSeq - schema match { - case Some(userProvidedSchema) if userProvidedSchema.nonEmpty => - val spec = PartitioningUtils.parsePartitions( - leafDirs, - PartitioningUtils.DEFAULT_PARTITION_NAME, - typeInference = false, - basePaths = basePaths) - - // Without auto inference, all of value in the `row` should be null or in StringType, - // we need to cast into the data type that user specified. - def castPartitionValuesToUserSchema(row: InternalRow) = { - InternalRow((0 until row.numFields).map { i => - Cast( - Literal.create(row.getUTF8String(i), StringType), - userProvidedSchema.fields(i).dataType).eval() - }: _*) - } - - PartitionSpec(userProvidedSchema, spec.partitions.map { part => - part.copy(values = castPartitionValuesToUserSchema(part.values)) - }) - case _ => - PartitioningUtils.parsePartitions( - leafDirs, - PartitioningUtils.DEFAULT_PARTITION_NAME, - typeInference = sparkSession.sessionState.conf.partitionColumnTypeInferenceEnabled(), - basePaths = basePaths) - } + /** Checks if we should filter out this path name. */ + def shouldFilterOut(pathName: String): Boolean = { + // We filter everything that starts with _ and ., except _common_metadata and _metadata + // because Parquet needs to find those metadata files from leaf files returned by this method. + // We should refactor this logic to not mix metadata files with data files. + ((pathName.startsWith("_") && !pathName.contains("=")) || pathName.startsWith(".")) && + !pathName.startsWith("_common_metadata") && !pathName.startsWith("_metadata") } /** - * Contains a set of paths that are considered as the base dirs of the input datasets. - * The partitioning discovery logic will make sure it will stop when it reaches any - * base path. By default, the paths of the dataset provided by users will be base paths. - * For example, if a user uses `sqlContext.read.parquet("/path/something=true/")`, the base path - * will be `/path/something=true/`, and the returned DataFrame will not contain a column of - * `something`. If users want to override the basePath. They can set `basePath` in the options - * to pass the new base path to the data source. - * For the above example, if the user-provided base path is `/path/`, the returned - * DataFrame will have the column of `something`. + * Create a LocatedFileStatus using FileStatus and block locations. */ - private def basePaths: Set[Path] = { - val userDefinedBasePath = parameters.get("basePath").map(basePath => Set(new Path(basePath))) - userDefinedBasePath.getOrElse { - // If the user does not provide basePath, we will just use paths. 
- paths.toSet - }.map { hdfsPath => - // Make the path qualified (consistent with listLeafFiles and listLeafFilesInParallel). - val fs = hdfsPath.getFileSystem(hadoopConf) - hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + def createLocatedFileStatus(f: FileStatus, locations: Array[BlockLocation]): LocatedFileStatus = { + // The other constructor of LocatedFileStatus will call FileStatus.getPermission(), which is + // very slow on some file system (RawLocalFileSystem, which is launch a subprocess and parse the + // stdout). + val lfs = new LocatedFileStatus(f.getLen, f.isDirectory, f.getReplication, f.getBlockSize, + f.getModificationTime, 0, null, null, null, null, f.getPath, locations) + if (f.isSymlink) { + lfs.setSymlink(f.getSymlink) } - } - - def refresh(): Unit = { - val files = listLeafFiles(paths) - - leafFiles.clear() - leafDirToChildrenFiles.clear() - - leafFiles ++= files.map(f => f.getPath -> f) - leafDirToChildrenFiles ++= files.toArray.groupBy(_.getPath.getParent) - - cachedPartitionSpec = null - } - - override def equals(other: Any): Boolean = other match { - case hdfs: HDFSFileCatalog => paths.toSet == hdfs.paths.toSet - case _ => false - } - - override def hashCode(): Int = paths.toSet.hashCode() -} - -/** - * Helper methods for gathering metadata from HDFS. - */ -private[sql] object HadoopFsRelation extends Logging { - - /** Checks if we should filter out this path name. */ - def shouldFilterOut(pathName: String): Boolean = { - // TODO: We should try to filter out all files/dirs starting with "." or "_". - // The only reason that we are not doing it now is that Parquet needs to find those - // metadata files from leaf files returned by this methods. We should refactor - // this logic to not mix metadata files with data files. - pathName == "_SUCCESS" || pathName == "_temporary" || pathName.startsWith(".") + lfs } // We don't filter files/directories whose name start with "_" except "_temporary" here, as @@ -482,23 +387,31 @@ private[sql] object HadoopFsRelation extends Logging { // _common_metadata files). "_temporary" directories are explicitly ignored since failed // tasks/jobs may leave partial/corrupted data files there. Files and directories whose name // start with "." are also ignored. - def listLeafFiles(fs: FileSystem, status: FileStatus): Array[FileStatus] = { - logInfo(s"Listing ${status.getPath}") + def listLeafFiles(fs: FileSystem, status: FileStatus, filter: PathFilter): Array[FileStatus] = { + logTrace(s"Listing ${status.getPath}") val name = status.getPath.getName.toLowerCase if (shouldFilterOut(name)) { - Array.empty + Array.empty[FileStatus] } else { - // Dummy jobconf to get to the pathFilter defined in configuration - val jobConf = new JobConf(fs.getConf, this.getClass()) - val pathFilter = FileInputFormat.getInputPathFilter(jobConf) val statuses = { val (dirs, files) = fs.listStatus(status.getPath).partition(_.isDirectory) - val stats = files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) - if (pathFilter != null) stats.filter(f => pathFilter.accept(f.getPath)) else stats + val stats = files ++ dirs.flatMap(dir => listLeafFiles(fs, dir, filter)) + if (filter != null) stats.filter(f => filter.accept(f.getPath)) else stats } + // statuses do not have any dirs. 
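As an aside, a few concrete names and how the rewritten `shouldFilterOut` rule above is expected to treat them (an illustrative check under those assumptions, not an exhaustive test):

```scala
import org.apache.spark.sql.execution.datasources.HadoopFsRelation

// Markers and hidden files are skipped; Parquet summary files and partition-style
// directory names that happen to start with "_" are kept.
Seq(
  "_SUCCESS"         -> true,
  "_temporary"       -> true,
  ".part-0.crc"      -> true,
  "_common_metadata" -> false,
  "_metadata"        -> false,
  "_myCol=1"         -> false,
  "part-00000"       -> false
).foreach { case (name, expectFilteredOut) =>
  assert(HadoopFsRelation.shouldFilterOut(name) == expectFilteredOut)
}
```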
statuses.filterNot(status => shouldFilterOut(status.getPath.getName)).map { case f: LocatedFileStatus => f - case f => new LocatedFileStatus(f, fs.getFileBlockLocations(f, 0, f.getLen)) + + // NOTE: + // + // - Although S3/S3A/S3N file system can be quite slow for remote file metadata + // operations, calling `getFileBlockLocations` does no harm here since these file system + // implementations don't actually issue RPC for this method. + // + // - Here we are calling `getFileBlockLocations` in a sequential manner, but it should not + // be a big deal since we always use to `listLeafFilesInParallel` when the number of + // paths exceeds threshold. + case f => createLocatedFileStatus(f, fs.getFileBlockLocations(f, 0, f.getLen)) } } } @@ -526,15 +439,34 @@ private[sql] object HadoopFsRelation extends Logging { def listLeafFilesInParallel( paths: Seq[Path], hadoopConf: Configuration, - sparkContext: SparkContext): mutable.LinkedHashSet[FileStatus] = { + sparkSession: SparkSession, + ignoreFileNotFound: Boolean): mutable.LinkedHashSet[FileStatus] = { + assert(paths.size >= sparkSession.sessionState.conf.parallelPartitionDiscoveryThreshold) logInfo(s"Listing leaf files and directories in parallel under: ${paths.mkString(", ")}") + val sparkContext = sparkSession.sparkContext val serializableConfiguration = new SerializableConfiguration(hadoopConf) val serializedPaths = paths.map(_.toString) - val fakeStatuses = sparkContext.parallelize(serializedPaths).map(new Path(_)).flatMap { path => - val fs = path.getFileSystem(serializableConfiguration.value) - Try(listLeafFiles(fs, fs.getFileStatus(path))).getOrElse(Array.empty) + // Set the number of parallelism to prevent following file listing from generating many tasks + // in case of large #defaultParallelism. + val numParallelism = Math.min(paths.size, 10000) + + val fakeStatuses = sparkContext + .parallelize(serializedPaths, numParallelism) + .mapPartitions { paths => + // Dummy jobconf to get to the pathFilter defined in configuration + // It's very expensive to create a JobConf(ClassUtil.findContainingJar() is slow) + val jobConf = new JobConf(serializableConfiguration.value, this.getClass) + val pathFilter = FileInputFormat.getInputPathFilter(jobConf) + paths.map(new Path(_)).flatMap { path => + val fs = path.getFileSystem(serializableConfiguration.value) + try { + listLeafFiles(fs, fs.getFileStatus(path), pathFilter) + } catch { + case e: java.io.FileNotFoundException if ignoreFileNotFound => Array.empty[FileStatus] + } + } }.map { status => val blockLocations = status match { case f: LocatedFileStatus => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala new file mode 100644 index 0000000000000..6c6ec89746ee1 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.jdbc + +/** + * Options for the JDBC data source. + */ +private[jdbc] class JDBCOptions( + @transient private val parameters: Map[String, String]) + extends Serializable { + + // a JDBC URL + val url = parameters.getOrElse("url", sys.error("Option 'url' not specified")) + // name of table + val table = parameters.getOrElse("dbtable", sys.error("Option 'dbtable' not specified")) + // the column used to partition + val partitionColumn = parameters.getOrElse("partitionColumn", null) + // the lower bound of partition column + val lowerBound = parameters.getOrElse("lowerBound", null) + // the upper bound of the partition column + val upperBound = parameters.getOrElse("upperBound", null) + // the number of partitions + val numPartitions = parameters.getOrElse("numPartitions", null) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 6a5564addf48c..82bc9d75beff6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -38,11 +38,11 @@ import org.apache.spark.unsafe.types.UTF8String /** * Data corresponding to one partition of a JDBCRDD. */ -private[sql] case class JDBCPartition(whereClause: String, idx: Int) extends Partition { +case class JDBCPartition(whereClause: String, idx: Int) extends Partition { override def index: Int = idx } -private[sql] object JDBCRDD extends Logging { +object JDBCRDD extends Logging { /** * Maps a JDBC type to a Catalyst type. This function is called only when @@ -136,7 +136,16 @@ private[sql] object JDBCRDD extends Logging { val typeName = rsmd.getColumnTypeName(i + 1) val fieldSize = rsmd.getPrecision(i + 1) val fieldScale = rsmd.getScale(i + 1) - val isSigned = rsmd.isSigned(i + 1) + val isSigned = { + try { + rsmd.isSigned(i + 1) + } catch { + // Workaround for HIVE-14684: + case e: SQLException if + e.getMessage == "Method not supported" && + rsmd.getClass.getName == "org.apache.hive.jdbc.HiveResultSetMetaData" => true + } + } val nullable = rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls val metadata = new MetadataBuilder() .putString("name", columnName) @@ -192,7 +201,7 @@ private[sql] object JDBCRDD extends Logging { * Turns a single Filter into a String representing a SQL expression. * Returns None for an unhandled filter. */ - private[jdbc] def compileFilter(f: Filter): Option[String] = { + def compileFilter(f: Filter): Option[String] = { Option(f match { case EqualTo(attr, value) => s"$attr = ${compileValue(value)}" case EqualNullSafe(attr, value) => @@ -275,7 +284,7 @@ private[sql] object JDBCRDD extends Logging { * driver code and the workers must be able to access the database; the driver * needs to fetch the schema while the workers need to fetch the data. 
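The new `JDBCOptions` wrapper reads the same string options that users pass through the DataFrame reader. As a point of reference, a read that exercises the partitioning options might look like the sketch below; the URL, table, bounds, and the `spark` session are assumed placeholders, not values taken from this patch:

```scala
// Hedged usage sketch for the JDBC options consumed by JDBCOptions / JDBCRelation.
// All connection details below are invented placeholders; `spark` is an existing SparkSession.
val ordersDF = spark.read
  .format("jdbc")
  .option("url", "jdbc:postgresql://db-host:5432/shop")
  .option("dbtable", "orders")
  .option("partitionColumn", "order_id") // integral column used to split the reads
  .option("lowerBound", "1")
  .option("upperBound", "1000000")
  .option("numPartitions", "8")
  .option("fetchsize", "1000")           // JdbcUtils.JDBC_BATCH_FETCH_SIZE; 0 lets the driver decide
  .load()
```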
*/ -private[sql] class JDBCRDD( +private[jdbc] class JDBCRDD( sc: SparkContext, getConnection: () => Connection, schema: StructType, @@ -305,14 +314,14 @@ private[sql] class JDBCRDD( * `filters`, but as a WHERE clause suitable for injection into a SQL query. */ private val filterWhereClause: String = - filters.flatMap(JDBCRDD.compileFilter).mkString(" AND ") + filters.flatMap(JDBCRDD.compileFilter).map(p => s"($p)").mkString(" AND ") /** * A WHERE clause representing both `filters`, if any, and the current partition. */ private def getWhereClause(part: JDBCPartition): String = { if (part.whereClause != null && filterWhereClause.length > 0) { - "WHERE " + filterWhereClause + " AND " + part.whereClause + "WHERE " + s"($filterWhereClause)" + " AND " + s"(${part.whereClause})" } else if (part.whereClause != null) { "WHERE " + part.whereClause } else if (filterWhereClause.length > 0) { @@ -374,6 +383,7 @@ private[sql] class JDBCRDD( var nextValue: InternalRow = null context.addTaskCompletionListener{ context => close() } + val inputMetrics = context.taskMetrics().inputMetrics val part = thePart.asInstanceOf[JDBCPartition] val conn = getConnection() val dialect = JdbcDialects.get(url) @@ -389,7 +399,11 @@ private[sql] class JDBCRDD( val sqlText = s"SELECT $columnList FROM $fqTable $myWhereClause" val stmt = conn.prepareStatement(sqlText, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY) - val fetchSize = properties.getProperty("fetchsize", "0").toInt + val fetchSize = properties.getProperty(JdbcUtils.JDBC_BATCH_FETCH_SIZE, "0").toInt + require(fetchSize >= 0, + s"Invalid value `${fetchSize.toString}` for parameter " + + s"`${JdbcUtils.JDBC_BATCH_FETCH_SIZE}`. The minimum value is 0. When the value is 0, " + + "the JDBC driver ignores the value and does the estimates.") stmt.setFetchSize(fetchSize) val rs = stmt.executeQuery() @@ -398,6 +412,7 @@ private[sql] class JDBCRDD( def getNext(): InternalRow = { if (rs.next()) { + inputMetrics.incRecordsRead(1) var i = 0 while (i < conversions.length) { val pos = i + 1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index bcf70fdc4a497..11613dd912eca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -21,6 +21,7 @@ import java.util.Properties import scala.collection.mutable.ArrayBuffer +import org.apache.spark.internal.Logging import org.apache.spark.Partition import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, SQLContext} @@ -36,7 +37,7 @@ private[sql] case class JDBCPartitioningInfo( upperBound: Long, numPartitions: Int) -private[sql] object JDBCRelation { +private[sql] object JDBCRelation extends Logging { /** * Given a partitioning schematic (a column of integral type, a number of * partitions, and upper and lower bounds on the column's value), generate @@ -52,29 +53,46 @@ private[sql] object JDBCRelation { * @return an array of partitions with where clause for each partition */ def columnPartition(partitioning: JDBCPartitioningInfo): Array[Partition] = { - if (partitioning == null) return Array[Partition](JDBCPartition(null, 0)) + if (partitioning == null || partitioning.numPartitions <= 1 || + partitioning.lowerBound == partitioning.upperBound) { + return Array[Partition](JDBCPartition(null, 0)) + 
} - val numPartitions = partitioning.numPartitions - val column = partitioning.column - if (numPartitions == 1) return Array[Partition](JDBCPartition(null, 0)) + val lowerBound = partitioning.lowerBound + val upperBound = partitioning.upperBound + require (lowerBound <= upperBound, + "Operation not allowed: the lower bound of partitioning column is larger than the upper " + + s"bound. Lower bound: $lowerBound; Upper bound: $upperBound") + + val numPartitions = + if ((upperBound - lowerBound) >= partitioning.numPartitions) { + partitioning.numPartitions + } else { + logWarning("The number of partitions is reduced because the specified number of " + + "partitions is less than the difference between upper bound and lower bound. " + + s"Updated number of partitions: ${upperBound - lowerBound}; Input number of " + + s"partitions: ${partitioning.numPartitions}; Lower bound: $lowerBound; " + + s"Upper bound: $upperBound.") + upperBound - lowerBound + } // Overflow and silliness can happen if you subtract then divide. // Here we get a little roundoff, but that's (hopefully) OK. - val stride: Long = (partitioning.upperBound / numPartitions - - partitioning.lowerBound / numPartitions) + val stride: Long = upperBound / numPartitions - lowerBound / numPartitions + val column = partitioning.column var i: Int = 0 - var currentValue: Long = partitioning.lowerBound + var currentValue: Long = lowerBound var ans = new ArrayBuffer[Partition]() while (i < numPartitions) { - val lowerBound = if (i != 0) s"$column >= $currentValue" else null + val lBound = if (i != 0) s"$column >= $currentValue" else null currentValue += stride - val upperBound = if (i != numPartitions - 1) s"$column < $currentValue" else null + val uBound = if (i != numPartitions - 1) s"$column < $currentValue" else null val whereClause = - if (upperBound == null) { - lowerBound - } else if (lowerBound == null) { - s"$upperBound or $column is null" + if (uBound == null) { + lBound + } else if (lBound == null) { + s"$uBound or $column is null" } else { - s"$lowerBound AND $upperBound" + s"$lBound AND $uBound" } ans += JDBCPartition(whereClause, i) i = i + 1 @@ -92,7 +110,7 @@ private[sql] case class JDBCRelation( with PrunedFilteredScan with InsertableRelation { - override def sqlContext: SQLContext = sparkSession.wrapped + override def sqlContext: SQLContext = sparkSession.sqlContext override val needConversion: Boolean = false diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala similarity index 64% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala index 6ff50a3c61223..106ed1d440102 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala @@ -22,7 +22,7 @@ import java.util.Properties import org.apache.spark.sql.SQLContext import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, RelationProvider} -class DefaultSource extends RelationProvider with DataSourceRegister { +class JdbcRelationProvider extends RelationProvider with DataSourceRegister { override def shortName(): String = "jdbc" @@ -30,30 +30,26 @@ class DefaultSource extends RelationProvider with 
DataSourceRegister { override def createRelation( sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = { - val url = parameters.getOrElse("url", sys.error("Option 'url' not specified")) - val table = parameters.getOrElse("dbtable", sys.error("Option 'dbtable' not specified")) - val partitionColumn = parameters.getOrElse("partitionColumn", null) - val lowerBound = parameters.getOrElse("lowerBound", null) - val upperBound = parameters.getOrElse("upperBound", null) - val numPartitions = parameters.getOrElse("numPartitions", null) - - if (partitionColumn != null - && (lowerBound == null || upperBound == null || numPartitions == null)) { + val jdbcOptions = new JDBCOptions(parameters) + if (jdbcOptions.partitionColumn != null + && (jdbcOptions.lowerBound == null + || jdbcOptions.upperBound == null + || jdbcOptions.numPartitions == null)) { sys.error("Partitioning incompletely specified") } - val partitionInfo = if (partitionColumn == null) { + val partitionInfo = if (jdbcOptions.partitionColumn == null) { null } else { JDBCPartitioningInfo( - partitionColumn, - lowerBound.toLong, - upperBound.toLong, - numPartitions.toInt) + jdbcOptions.partitionColumn, + jdbcOptions.lowerBound.toLong, + jdbcOptions.upperBound.toLong, + jdbcOptions.numPartitions.toInt) } val parts = JDBCRelation.columnPartition(partitionInfo) val properties = new Properties() // Additional properties that we will pass to getConnection parameters.foreach(kv => properties.setProperty(kv._1, kv._2)) - JDBCRelation(url, table, parts, properties)(sqlContext.sparkSession) + JDBCRelation(jdbcOptions.url, jdbcOptions.table, parts, properties)(sqlContext.sparkSession) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 065c8572b06a2..2869e801deb96 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources.jdbc -import java.sql.{Connection, Driver, DriverManager, PreparedStatement} +import java.sql.{Connection, Driver, DriverManager, PreparedStatement, SQLException} import java.util.Properties import scala.collection.JavaConverters._ @@ -34,6 +34,10 @@ import org.apache.spark.sql.types._ */ object JdbcUtils extends Logging { + // the property names are case sensitive + val JDBC_BATCH_FETCH_SIZE = "fetchsize" + val JDBC_BATCH_INSERT_SIZE = "batchsize" + /** * Returns a factory for creating connections to the given JDBC URL. * @@ -50,7 +54,7 @@ object JdbcUtils extends Logging { DriverManager.getDriver(url).getClass.getCanonicalName } () => { - userSpecifiedDriverClass.foreach(DriverRegistry.register) + DriverRegistry.register(driverClass) val driver: Driver = DriverManager.getDrivers.asScala.collectFirst { case d: DriverWrapper if d.wrapped.getClass.getCanonicalName == driverClass => d case d if d.getClass.getCanonicalName == driverClass => d @@ -96,8 +100,9 @@ object JdbcUtils extends Logging { /** * Returns a PreparedStatement that inserts a row into table via conn. 
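To make the revised `columnPartition` logic shown a little earlier concrete: for a hypothetical partitioning of column `id` with `lowerBound = 0`, `upperBound = 100`, and `numPartitions = 4`, the stride is 25 and the generated predicates should be the ones below (derived by reading the code above, not taken verbatim from the patch):

```scala
// Expected per-partition WHERE clauses for the hypothetical example above (illustrative only).
val expectedWhereClauses = Seq(
  "id < 25 or id is null", // first partition also collects NULL values
  "id >= 25 AND id < 50",
  "id >= 50 AND id < 75",
  "id >= 75"               // last partition is unbounded above
)
```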
*/ - def insertStatement(conn: Connection, table: String, rddSchema: StructType): PreparedStatement = { - val columns = rddSchema.fields.map(_.name).mkString(",") + def insertStatement(conn: Connection, table: String, rddSchema: StructType, dialect: JdbcDialect) + : PreparedStatement = { + val columns = rddSchema.fields.map(x => dialect.quoteIdentifier(x.name)).mkString(",") val placeholders = rddSchema.fields.map(_ => "?").mkString(",") val sql = s"INSERT INTO $table ($columns) VALUES ($placeholders)" conn.prepareStatement(sql) @@ -154,6 +159,10 @@ object JdbcUtils extends Logging { nullTypes: Array[Int], batchSize: Int, dialect: JdbcDialect): Iterator[Byte] = { + require(batchSize >= 1, + s"Invalid value `${batchSize.toString}` for parameter " + + s"`${JdbcUtils.JDBC_BATCH_INSERT_SIZE}`. The minimum value is 1.") + val conn = getConnection() var committed = false val supportsTransactions = try { @@ -169,7 +178,7 @@ object JdbcUtils extends Logging { if (supportsTransactions) { conn.setAutoCommit(false) // Everything in the same db transaction. } - val stmt = insertStatement(conn, table, rddSchema) + val stmt = insertStatement(conn, table, rddSchema, dialect) try { var rowCount = 0 while (iterator.hasNext) { @@ -224,6 +233,17 @@ object JdbcUtils extends Logging { conn.commit() } committed = true + } catch { + case e: SQLException => + val cause = e.getNextException + if (cause != null && e.getCause != cause) { + if (e.getCause == null) { + e.initCause(cause) + } else { + e.addSuppressed(cause) + } + } + throw e } finally { if (!committed) { // The stage must fail. We got here through an exception path, so @@ -252,7 +272,7 @@ object JdbcUtils extends Logging { val sb = new StringBuilder() val dialect = JdbcDialects.get(url) df.schema.fields foreach { field => - val name = field.name + val name = dialect.quoteIdentifier(field.name) val typ: String = getJdbcType(field.dataType, dialect).databaseTypeDefinition val nullable = if (field.nullable) "" else "NOT NULL" sb.append(s", $name $typ $nullable") @@ -275,7 +295,7 @@ object JdbcUtils extends Logging { val rddSchema = df.schema val getConnection: () => Connection = createConnectionFactory(url, properties) - val batchSize = properties.getProperty("batchsize", "1000").toInt + val batchSize = properties.getProperty(JDBC_BATCH_INSERT_SIZE, "1000").toInt df.foreachPartition { iterator => savePartition(getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala index 8e8238a594a03..579b036417d24 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala @@ -17,10 +17,12 @@ package org.apache.spark.sql.execution.datasources.json +import java.util.Comparator + import com.fasterxml.jackson.core._ import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion +import org.apache.spark.sql.catalyst.analysis.TypeCoercion import org.apache.spark.sql.execution.datasources.json.JacksonUtils.nextUntil import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -63,9 +65,7 @@ private[sql] object InferSchema { None } } - }.treeAggregate[DataType]( - StructType(Seq()))( - compatibleRootType(columnNameOfCorruptRecords, shouldHandleCorruptRecord), + 
}.fold(StructType(Seq()))( compatibleRootType(columnNameOfCorruptRecords, shouldHandleCorruptRecord)) canonicalizeType(rootType) match { @@ -76,6 +76,23 @@ private[sql] object InferSchema { } } + private[this] val structFieldComparator = new Comparator[StructField] { + override def compare(o1: StructField, o2: StructField): Int = { + o1.name.compare(o2.name) + } + } + + private def isSorted(arr: Array[StructField]): Boolean = { + var i: Int = 0 + while (i < arr.length - 1) { + if (structFieldComparator.compare(arr(i), arr(i + 1)) > 0) { + return false + } + i += 1 + } + true + } + /** * Infer the type of a json document from the parser's token stream */ @@ -99,15 +116,17 @@ private[sql] object InferSchema { case VALUE_STRING => StringType case START_OBJECT => - val builder = Seq.newBuilder[StructField] + val builder = Array.newBuilder[StructField] while (nextUntil(parser, END_OBJECT)) { builder += StructField( parser.getCurrentName, inferField(parser, configOptions), nullable = true) } - - StructType(builder.result().sortBy(_.name)) + val fields: Array[StructField] = builder.result() + // Note: other code relies on this sorting for correctness, so don't remove it! + java.util.Arrays.sort(fields, structFieldComparator) + StructType(fields) case START_ARRAY => // If this JSON array is empty, we use NullType as a placeholder. @@ -191,7 +210,11 @@ private[sql] object InferSchema { if (!struct.fieldNames.contains(columnNameOfCorruptRecords)) { // If this given struct does not have a column used for corrupt records, // add this field. - struct.add(columnNameOfCorruptRecords, StringType, nullable = true) + val newFields: Array[StructField] = + StructField(columnNameOfCorruptRecords, StringType, nullable = true) +: struct.fields + // Note: other code relies on this sorting for correctness, so don't remove it! + java.util.Arrays.sort(newFields, structFieldComparator) + StructType(newFields) } else { // Otherwise, just return this struct. struct @@ -223,11 +246,13 @@ private[sql] object InferSchema { case (ty1, ty2) => compatibleType(ty1, ty2) } + private[this] val emptyStructFieldArray = Array.empty[StructField] + /** * Returns the most general data type for two given data types. */ def compatibleType(t1: DataType, t2: DataType): DataType = { - HiveTypeCoercion.findTightestCommonTypeOfTwo(t1, t2).getOrElse { + TypeCoercion.findTightestCommonTypeOfTwo(t1, t2).getOrElse { // t1 or t2 is a StructType, ArrayType, or an unexpected type. (t1, t2) match { // Double support larger range than fixed decimal, DecimalType.Maximum should be enough @@ -246,12 +271,43 @@ private[sql] object InferSchema { } case (StructType(fields1), StructType(fields2)) => - val newFields = (fields1 ++ fields2).groupBy(field => field.name).map { - case (name, fieldTypes) => - val dataType = fieldTypes.view.map(_.dataType).reduce(compatibleType) - StructField(name, dataType, nullable = true) + // Both fields1 and fields2 should be sorted by name, since inferField performs sorting. + // Therefore, we can take advantage of the fact that we're merging sorted lists and skip + // building a hash map or performing additional sorting. 
+ assert(isSorted(fields1), s"StructType's fields were not sorted: ${fields1.toSeq}") + assert(isSorted(fields2), s"StructType's fields were not sorted: ${fields2.toSeq}") + + val newFields = new java.util.ArrayList[StructField]() + + var f1Idx = 0 + var f2Idx = 0 + + while (f1Idx < fields1.length && f2Idx < fields2.length) { + val f1Name = fields1(f1Idx).name + val f2Name = fields2(f2Idx).name + val comp = f1Name.compareTo(f2Name) + if (comp == 0) { + val dataType = compatibleType(fields1(f1Idx).dataType, fields2(f2Idx).dataType) + newFields.add(StructField(f1Name, dataType, nullable = true)) + f1Idx += 1 + f2Idx += 1 + } else if (comp < 0) { // f1Name < f2Name + newFields.add(fields1(f1Idx)) + f1Idx += 1 + } else { // f1Name > f2Name + newFields.add(fields2(f2Idx)) + f2Idx += 1 + } + } + while (f1Idx < fields1.length) { + newFields.add(fields1(f1Idx)) + f1Idx += 1 + } + while (f2Idx < fields2.length) { + newFields.add(fields2(f2Idx)) + f2Idx += 1 } - StructType(newFields.toSeq.sortBy(_.name)) + StructType(newFields.toArray(emptyStructFieldArray)) case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) => ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala index 66f1126fb9ae6..02d211d04265e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.json import com.fasterxml.jackson.core.{JsonFactory, JsonParser} +import org.apache.commons.lang3.time.FastDateFormat import org.apache.spark.internal.Logging import org.apache.spark.sql.execution.datasources.{CompressionCodecs, ParseModes} @@ -53,6 +54,14 @@ private[sql] class JSONOptions( private val parseMode = parameters.getOrElse("mode", "PERMISSIVE") val columnNameOfCorruptRecord = parameters.get("columnNameOfCorruptRecord") + // Uses `FastDateFormat` which can be direct replacement for `SimpleDateFormat` and thread-safe. + val dateFormat: FastDateFormat = + FastDateFormat.getInstance(parameters.getOrElse("dateFormat", "yyyy-MM-dd")) + + val timestampFormat: FastDateFormat = + FastDateFormat.getInstance( + parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSZZ")) + // Parse mode flags if (!ParseModes.isValidMode(parseMode)) { logWarning(s"$parseMode is not a valid parse mode. 
Using ${ParseModes.DEFAULT}.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala index 8b920ecafaeed..800d43f3039c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala @@ -32,11 +32,17 @@ private[sql] object JacksonGenerator { * @param gen a JsonGenerator object * @param row The row to convert */ - def apply(rowSchema: StructType, gen: JsonGenerator)(row: InternalRow): Unit = { + def apply( + rowSchema: StructType, + gen: JsonGenerator, + options: JSONOptions = new JSONOptions(Map.empty[String, String])) + (row: InternalRow): Unit = { def valWriter: (DataType, Any) => Unit = { case (_, null) | (NullType, _) => gen.writeNull() case (StringType, v) => gen.writeString(v.toString) - case (TimestampType, v: Long) => gen.writeString(DateTimeUtils.toJavaTimestamp(v).toString) + case (TimestampType, v: Long) => + val timestampString = options.timestampFormat.format(DateTimeUtils.toJavaTimestamp(v)) + gen.writeString(timestampString) case (IntegerType, v: Int) => gen.writeNumber(v) case (ShortType, v: Short) => gen.writeNumber(v) case (FloatType, v: Float) => gen.writeNumber(v) @@ -46,7 +52,9 @@ private[sql] object JacksonGenerator { case (ByteType, v: Byte) => gen.writeNumber(v.toInt) case (BinaryType, v: Array[Byte]) => gen.writeBinary(v) case (BooleanType, v: Boolean) => gen.writeBoolean(v) - case (DateType, v: Int) => gen.writeString(DateTimeUtils.toJavaDate(v).toString) + case (DateType, v: Int) => + val dateString = options.dateFormat.format(DateTimeUtils.toJavaDate(v)) + gen.writeString(dateString) // For UDT values, they should be in the SQL type's corresponding value type. // We should not see values in the user-defined class at here. // For example, VectorUDT's SQL type is an array of double. So, we should expect that v is diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala index aeee2600a19ee..a5417dc4a0e19 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.json import java.io.ByteArrayOutputStream import scala.collection.mutable.ArrayBuffer +import scala.util.Try import com.fasterxml.jackson.core._ @@ -50,34 +51,36 @@ object JacksonParser extends Logging { /** * Parse the current token (and related children) according to a desired schema - * This is an wrapper for the method `convertField()` to handle a row wrapped + * This is a wrapper for the method `convertField()` to handle a row wrapped * with an array. 
*/ def convertRootField( factory: JsonFactory, parser: JsonParser, - schema: DataType): Any = { + schema: DataType, + configOptions: JSONOptions): Any = { import com.fasterxml.jackson.core.JsonToken._ (parser.getCurrentToken, schema) match { case (START_ARRAY, st: StructType) => // SPARK-3308: support reading top level JSON arrays and take every element // in such an array as a row - convertArray(factory, parser, st) + convertArray(factory, parser, st, configOptions) case (START_OBJECT, ArrayType(st, _)) => // the business end of SPARK-3308: // when an object is found but an array is requested just wrap it in a list - convertField(factory, parser, st) :: Nil + convertField(factory, parser, st, configOptions) :: Nil case _ => - convertField(factory, parser, schema) + convertField(factory, parser, schema, configOptions) } } private def convertField( factory: JsonFactory, parser: JsonParser, - schema: DataType): Any = { + schema: DataType, + configOptions: JSONOptions): Any = { import com.fasterxml.jackson.core.JsonToken._ (parser.getCurrentToken, schema) match { case (null | VALUE_NULL, _) => @@ -85,7 +88,7 @@ object JacksonParser extends Logging { case (FIELD_NAME, _) => parser.nextToken() - convertField(factory, parser, schema) + convertField(factory, parser, schema, configOptions) case (VALUE_STRING, StringType) => UTF8String.fromString(parser.getText) @@ -99,19 +102,29 @@ object JacksonParser extends Logging { case (VALUE_STRING, DateType) => val stringValue = parser.getText - if (stringValue.contains("-")) { - // The format of this string will probably be "yyyy-mm-dd". - DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(parser.getText).getTime) - } else { - // In Spark 1.5.0, we store the data as number of days since epoch in string. - // So, we just convert it to Int. - stringValue.toInt - } + // This one will lose microseconds parts. + // See https://issues.apache.org/jira/browse/SPARK-10681.x + Try(DateTimeUtils.millisToDays(configOptions.dateFormat.parse(parser.getText).getTime)) + .getOrElse { + // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards + // compatibility. + Try(DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(parser.getText).getTime)) + .getOrElse { + // In Spark 1.5.0, we store the data as number of days since epoch in string. + // So, we just convert it to Int. + stringValue.toInt + } + } case (VALUE_STRING, TimestampType) => // This one will lose microseconds parts. // See https://issues.apache.org/jira/browse/SPARK-10681. - DateTimeUtils.stringToTime(parser.getText).getTime * 1000L + Try(configOptions.timestampFormat.parse(parser.getText).getTime * 1000L) + .getOrElse { + // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards + // compatibility. 
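The `dateFormat` and `timestampFormat` options threaded through `JSONOptions`, `JacksonGenerator`, and the parsing code above are ordinary read/write options. A hedged round-trip sketch follows; `df`, `schema`, the patterns, and the output path are assumed examples rather than values from this patch:

```scala
// Illustrative round trip with custom date/timestamp patterns.
// Assumes an existing DataFrame `df` and a StructType `schema` with DateType/TimestampType columns.
df.write
  .option("dateFormat", "yyyy/MM/dd")
  .option("timestampFormat", "yyyy/MM/dd HH:mm:ss")
  .json("/tmp/events-json")

val reread = spark.read
  .schema(schema)
  .option("dateFormat", "yyyy/MM/dd")
  .option("timestampFormat", "yyyy/MM/dd HH:mm:ss")
  .json("/tmp/events-json")
```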
+ DateTimeUtils.stringToTime(parser.getText).getTime * 1000L + } case (VALUE_NUMBER_INT, TimestampType) => parser.getLongValue * 1000000L @@ -179,16 +192,16 @@ object JacksonParser extends Logging { false case (START_OBJECT, st: StructType) => - convertObject(factory, parser, st) + convertObject(factory, parser, st, configOptions) case (START_ARRAY, ArrayType(st, _)) => - convertArray(factory, parser, st) + convertArray(factory, parser, st, configOptions) case (START_OBJECT, MapType(StringType, kt, _)) => - convertMap(factory, parser, kt) + convertMap(factory, parser, kt, configOptions) case (_, udt: UserDefinedType[_]) => - convertField(factory, parser, udt.sqlType) + convertField(factory, parser, udt.sqlType, configOptions) case (token, dataType) => // We cannot parse this token based on the given data type. So, we throw a @@ -207,12 +220,13 @@ object JacksonParser extends Logging { private def convertObject( factory: JsonFactory, parser: JsonParser, - schema: StructType): InternalRow = { + schema: StructType, + configOptions: JSONOptions): InternalRow = { val row = new GenericMutableRow(schema.length) while (nextUntil(parser, JsonToken.END_OBJECT)) { schema.getFieldIndex(parser.getCurrentName) match { case Some(index) => - row.update(index, convertField(factory, parser, schema(index).dataType)) + row.update(index, convertField(factory, parser, schema(index).dataType, configOptions)) case None => parser.skipChildren() @@ -228,12 +242,13 @@ object JacksonParser extends Logging { private def convertMap( factory: JsonFactory, parser: JsonParser, - valueType: DataType): MapData = { + valueType: DataType, + configOptions: JSONOptions): MapData = { val keys = ArrayBuffer.empty[UTF8String] val values = ArrayBuffer.empty[Any] while (nextUntil(parser, JsonToken.END_OBJECT)) { keys += UTF8String.fromString(parser.getCurrentName) - values += convertField(factory, parser, valueType) + values += convertField(factory, parser, valueType, configOptions) } ArrayBasedMapData(keys.toArray, values.toArray) } @@ -241,10 +256,11 @@ object JacksonParser extends Logging { private def convertArray( factory: JsonFactory, parser: JsonParser, - elementType: DataType): ArrayData = { + elementType: DataType, + configOptions: JSONOptions): ArrayData = { val values = ArrayBuffer.empty[Any] while (nextUntil(parser, JsonToken.END_ARRAY)) { - values += convertField(factory, parser, elementType) + values += convertField(factory, parser, elementType, configOptions) } new GenericArrayData(values.toArray) @@ -285,7 +301,7 @@ object JacksonParser extends Logging { Utils.tryWithResource(factory.createParser(record)) { parser => parser.nextToken() - convertRootField(factory, parser, schema) match { + convertRootField(factory, parser, schema, configOptions) match { case null => failedRecord(record) case row: InternalRow => row :: Nil case array: ArrayData => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala similarity index 87% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala index 62446583a55b7..cba3255d86f42 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala @@ -28,18 +28,18 @@ 
import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat +import org.apache.spark.TaskContext import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.JoinedRow -import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.execution.command.CreateDataSourceTableUtils import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.StructType import org.apache.spark.util.SerializableConfiguration -class DefaultSource extends FileFormat with DataSourceRegister { +class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { override def shortName(): String = "json" @@ -56,7 +56,7 @@ class DefaultSource extends FileFormat with DataSourceRegister { .getOrElse(sparkSession.sessionState.conf.columnNameOfCorruptRecord) val jsonFiles = files.filterNot { status => val name = status.getPath.getName - name.startsWith("_") || name.startsWith(".") + (name.startsWith("_") && !name.contains("=")) || name.startsWith(".") }.toArray val jsonSchema = InferSchema.infer( @@ -86,7 +86,7 @@ class DefaultSource extends FileFormat with DataSourceRegister { bucketId: Option[Int], dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new JsonOutputWriter(path, bucketId, dataSchema, context) + new JsonOutputWriter(path, parsedOptions, bucketId, dataSchema, context) } } } @@ -106,22 +106,16 @@ class DefaultSource extends FileFormat with DataSourceRegister { val columnNameOfCorruptRecord = parsedOptions.columnNameOfCorruptRecord .getOrElse(sparkSession.sessionState.conf.columnNameOfCorruptRecord) - val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes - val joinedRow = new JoinedRow() - (file: PartitionedFile) => { - val lines = new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value).map(_.toString) + val linesReader = new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) + val lines = linesReader.map(_.toString) - val rows = JacksonParser.parseJson( + JacksonParser.parseJson( lines, requiredSchema, columnNameOfCorruptRecord, parsedOptions) - - val appendPartitionColumns = GenerateUnsafeProjection.generate(fullSchema, fullSchema) - rows.map { row => - appendPartitionColumns(joinedRow(row, file.partitionValues)) - } } } @@ -159,11 +153,12 @@ class DefaultSource extends FileFormat with DataSourceRegister { override def hashCode(): Int = getClass.hashCode() - override def equals(other: Any): Boolean = other.isInstanceOf[DefaultSource] + override def equals(other: Any): Boolean = other.isInstanceOf[JsonFileFormat] } private[json] class JsonOutputWriter( path: String, + options: JSONOptions, bucketId: Option[Int], dataSchema: StructType, context: TaskAttemptContext) @@ -178,7 +173,7 @@ private[json] class JsonOutputWriter( new TextOutputFormat[NullWritable, Text]() { override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { val configuration = context.getConfiguration - val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID") + val uniqueWriteJobId = 
configuration.get(CreateDataSourceTableUtils.DATASOURCE_WRITEJOBUUID) val taskAttemptId = context.getTaskAttemptID val split = taskAttemptId.getTaskID.getId val bucketString = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("") @@ -190,7 +185,7 @@ private[json] class JsonOutputWriter( override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal") override protected[sql] def writeInternal(row: InternalRow): Unit = { - JacksonGenerator(dataSchema, gen)(row) + JacksonGenerator(dataSchema, gen, options)(row) gen.flush() result.set(writer.toString) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala similarity index 78% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index b1513bbe945bb..aef0f1b4157f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -37,20 +37,22 @@ import org.apache.parquet.hadoop.util.ContextUtil import org.apache.parquet.schema.MessageType import org.slf4j.bridge.SLF4JBridgeHandler -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.JoinedRow +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.parser.LegacyTypeStringParser +import org.apache.spark.sql.execution.command.CreateDataSourceTableUtils import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.{AtomicType, DataType, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.util.SerializableConfiguration -private[sql] class DefaultSource +class ParquetFileFormat extends FileFormat with DataSourceRegister with Logging @@ -62,7 +64,7 @@ private[sql] class DefaultSource override def hashCode(): Int = getClass.hashCode() - override def equals(other: Any): Boolean = other.isInstanceOf[DefaultSource] + override def equals(other: Any): Boolean = other.isInstanceOf[ParquetFileFormat] override def prepareWrite( sparkSession: SparkSession, @@ -98,13 +100,13 @@ private[sql] class DefaultSource // bundled with `ParquetOutputFormat[Row]`. job.setOutputFormatClass(classOf[ParquetOutputFormat[Row]]) - ParquetOutputFormat.setWriteSupportClass(job, classOf[CatalystWriteSupport]) + ParquetOutputFormat.setWriteSupportClass(job, classOf[ParquetWriteSupport]) // We want to clear this temporary metadata from saving into Parquet file. // This metadata is only useful for detecting optional columns when pushdowning filters. 
val dataSchemaToWrite = StructType.removeMetadata(StructType.metadataKeyForOptionalField, dataSchema).asInstanceOf[StructType] - CatalystWriteSupport.setSchema(dataSchemaToWrite, conf) + ParquetWriteSupport.setSchema(dataSchemaToWrite, conf) // Sets flags for `CatalystSchemaConverter` (which converts Catalyst schema to Parquet schema) // and `CatalystWriteSupport` (writing actual rows to Parquet files). @@ -123,6 +125,11 @@ private[sql] class DefaultSource // Sets compression scheme conf.set(ParquetOutputFormat.COMPRESSION, parquetOptions.compressionCodec) + // SPARK-15719: Disables writing Parquet summary files by default. + if (conf.get(ParquetOutputFormat.ENABLE_JOB_SUMMARY) == null) { + conf.setBoolean(ParquetOutputFormat.ENABLE_JOB_SUMMARY, false) + } + new OutputWriterFactory { override def newInstance( path: String, @@ -138,12 +145,10 @@ private[sql] class DefaultSource sparkSession: SparkSession, parameters: Map[String, String], files: Seq[FileStatus]): Option[StructType] = { + val parquetOptions = new ParquetOptions(parameters, sparkSession.sessionState.conf) + // Should we merge schemas from all Parquet part-files? - val shouldMergeSchemas = - parameters - .get(ParquetRelation.MERGE_SCHEMA) - .map(_.toBoolean) - .getOrElse(sparkSession.conf.get(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED)) + val shouldMergeSchemas = parquetOptions.mergeSchema val mergeRespectSummaries = sparkSession.conf.get(SQLConf.PARQUET_SCHEMA_RESPECT_SUMMARIES) @@ -217,7 +222,7 @@ private[sql] class DefaultSource .orElse(filesByType.data.headOption) .toSeq } - ParquetRelation.mergeSchemasInParallel(filesToTouch, sparkSession) + ParquetFileFormat.mergeSchemasInParallel(filesToTouch, sparkSession) } case class FileTypes( @@ -229,7 +234,8 @@ private[sql] class DefaultSource // Lists `FileStatus`es of all leaf nodes (files) under all base directories. val leaves = allFiles.filter { f => isSummaryFile(f.getPath) || - !(f.getPath.getName.startsWith("_") || f.getPath.getName.startsWith(".")) + !((f.getPath.getName.startsWith("_") && !f.getPath.getName.contains("=")) || + f.getPath.getName.startsWith(".")) }.toArray.sortBy(_.getPath.toString) FileTypes( @@ -255,6 +261,27 @@ private[sql] class DefaultSource schema.forall(_.dataType.isInstanceOf[AtomicType]) } + override def isSplitable( + sparkSession: SparkSession, + options: Map[String, String], + path: Path): Boolean = { + true + } + + override def buildReaderWithPartitionValues( + sparkSession: SparkSession, + dataSchema: StructType, + partitionSchema: StructType, + requiredSchema: StructType, + filters: Seq[Filter], + options: Map[String, String], + hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { + // For Parquet data source, `buildReader` already handles partition values appending. Here we + // simply delegate to `buildReader`. 
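With `shouldMergeSchemas` now sourced from `ParquetOptions`, the per-read switch remains the familiar data source option; a brief usage sketch, with a placeholder path, is:

```scala
// Illustrative only: merge the schemas of all Parquet part-files for this read,
// overriding the session-wide spark.sql.parquet.mergeSchema default. Path is a placeholder.
val merged = spark.read
  .option("mergeSchema", "true")
  .parquet("/data/events")
merged.printSchema()
```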
+ buildReader( + sparkSession, dataSchema, partitionSchema, requiredSchema, filters, options, hadoopConf) + } + override def buildReader( sparkSession: SparkSession, dataSchema: StructType, @@ -263,19 +290,19 @@ private[sql] class DefaultSource filters: Seq[Filter], options: Map[String, String], hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = { - hadoopConf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[CatalystReadSupport].getName) + hadoopConf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[ParquetReadSupport].getName) hadoopConf.set( - CatalystReadSupport.SPARK_ROW_REQUESTED_SCHEMA, - CatalystSchemaConverter.checkFieldNames(requiredSchema).json) + ParquetReadSupport.SPARK_ROW_REQUESTED_SCHEMA, + ParquetSchemaConverter.checkFieldNames(requiredSchema).json) hadoopConf.set( - CatalystWriteSupport.SPARK_ROW_SCHEMA, - CatalystSchemaConverter.checkFieldNames(requiredSchema).json) + ParquetWriteSupport.SPARK_ROW_SCHEMA, + ParquetSchemaConverter.checkFieldNames(requiredSchema).json) // We want to clear this temporary metadata from saving into Parquet file. // This metadata is only useful for detecting optional columns when pushdowning filters. val dataSchemaToWrite = StructType.removeMetadata(StructType.metadataKeyForOptionalField, requiredSchema).asInstanceOf[StructType] - CatalystWriteSupport.setSchema(dataSchemaToWrite, hadoopConf) + ParquetWriteSupport.setSchema(dataSchemaToWrite, hadoopConf) // Sets flags for `CatalystSchemaConverter` hadoopConf.setBoolean( @@ -330,6 +357,11 @@ private[sql] class DefaultSource val hadoopAttemptContext = new TaskAttemptContextImpl(broadcastedHadoopConf.value.value, attemptId) + // Try to push down filters when filter push-down is enabled. + // Notice: This push-down is RowGroups level, not individual records. + if (pushed.isDefined) { + ParquetInputFormat.setFilterPredicate(hadoopAttemptContext.getConfiguration, pushed.get) + } val parquetReader = if (enableVectorizedReader) { val vectorizedReader = new VectorizedParquetRecordReader() vectorizedReader.initialize(split, hadoopAttemptContext) @@ -341,19 +373,21 @@ private[sql] class DefaultSource vectorizedReader } else { logDebug(s"Falling back to parquet-mr") + // ParquetRecordReader returns UnsafeRow val reader = pushed match { case Some(filter) => new ParquetRecordReader[InternalRow]( - new CatalystReadSupport, + new ParquetReadSupport, FilterCompat.get(filter, null)) case _ => - new ParquetRecordReader[InternalRow](new CatalystReadSupport) + new ParquetRecordReader[InternalRow](new ParquetReadSupport) } reader.initialize(split, hadoopAttemptContext) reader } val iter = new RecordReaderIterator(parquetReader) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close())) // UnsafeRowParquetRecordReader appends the columns internally to avoid another copy. if (parquetReader.isInstanceOf[VectorizedParquetRecordReader] && @@ -367,15 +401,132 @@ private[sql] class DefaultSource // This is a horrible erasure hack... if we type the iterator above, then it actually check // the type in next() and we get a class cast exception. If we make that function return // Object, then we can defer the cast until later! 
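The branch above between `VectorizedParquetRecordReader` and the parquet-mr `ParquetRecordReader` is driven by a session configuration. The sketch below shows how that switch is typically flipped; the configuration key is quoted from memory of `SQLConf`, not from this patch, so treat it as an assumption:

```scala
// Assumed conf key for toggling the vectorized Parquet reader (verify against SQLConf).
spark.conf.set("spark.sql.parquet.enableVectorizedReader", "false") // force the parquet-mr fallback
val df = spark.read.parquet("/tmp/example-parquet")                 // placeholder path
```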
- iter.asInstanceOf[Iterator[InternalRow]] + if (partitionSchema.length == 0) { + // There are no partition columns + iter.asInstanceOf[Iterator[InternalRow]] + } else { + iter.asInstanceOf[Iterator[InternalRow]] .map(d => appendPartitionColumns(joinedRow(d, file.partitionValues))) + } } } } + + override def buildWriter( + sqlContext: SQLContext, + dataSchema: StructType, + options: Map[String, String]): OutputWriterFactory = { + new ParquetOutputWriterFactory( + sqlContext.conf, + dataSchema, + sqlContext.sessionState.newHadoopConf(), + options) + } } +/** + * A factory for generating OutputWriters for writing parquet files. This implementation is different + * from the [[ParquetOutputWriter]] as this does not use any [[OutputCommitter]]. It simply + * writes the data to the path used to generate the output writer. Callers of this factory + * have to ensure which files are to be considered as committed. + */ +private[parquet] class ParquetOutputWriterFactory( + sqlConf: SQLConf, + dataSchema: StructType, + hadoopConf: Configuration, + options: Map[String, String]) extends OutputWriterFactory { + + private val serializableConf: SerializableConfiguration = { + val job = Job.getInstance(hadoopConf) + val conf = ContextUtil.getConfiguration(job) + val parquetOptions = new ParquetOptions(options, sqlConf) + + // We're not really using `ParquetOutputFormat[Row]` for writing data here, because we override + // it in `ParquetOutputWriter` to support appending and dynamic partitioning. The reason why + // we set it here is to set up the output committer class to `ParquetOutputCommitter`, which is + // bundled with `ParquetOutputFormat[Row]`. + job.setOutputFormatClass(classOf[ParquetOutputFormat[Row]]) + + ParquetOutputFormat.setWriteSupportClass(job, classOf[ParquetWriteSupport]) + + // We want to clear this temporary metadata from saving into Parquet file. + // This metadata is only useful for detecting optional columns when pushing down filters. + val dataSchemaToWrite = StructType.removeMetadata( + StructType.metadataKeyForOptionalField, + dataSchema).asInstanceOf[StructType] + ParquetWriteSupport.setSchema(dataSchemaToWrite, conf) + + // Sets flags for `CatalystSchemaConverter` (which converts Catalyst schema to Parquet schema) + // and `CatalystWriteSupport` (writing actual rows to Parquet files). + conf.set( + SQLConf.PARQUET_BINARY_AS_STRING.key, + sqlConf.isParquetBinaryAsString.toString) + + conf.set( + SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, + sqlConf.isParquetINT96AsTimestamp.toString) + + conf.set( + SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key, + sqlConf.writeLegacyParquetFormat.toString) + + // Sets compression scheme + conf.set(ParquetOutputFormat.COMPRESSION, parquetOptions.compressionCodec) + new SerializableConfiguration(conf) + } + + /** + * Returns a [[OutputWriter]] that writes data to the given path without using + * [[OutputCommitter]].
+ */ + override def newWriter(path: String): OutputWriter = new OutputWriter { + + // Create TaskAttemptContext that is used to pass on Configuration to the ParquetRecordWriter + private val hadoopTaskAttemptId = new TaskAttemptID(new TaskID(new JobID, TaskType.MAP, 0), 0) + private val hadoopAttemptContext = new TaskAttemptContextImpl( + serializableConf.value, hadoopTaskAttemptId) + + // Instance of ParquetRecordWriter that does not use OutputCommitter + private val recordWriter = createNoCommitterRecordWriter(path, hadoopAttemptContext) + + override def write(row: Row): Unit = { + throw new UnsupportedOperationException("call writeInternal") + } + + protected[sql] override def writeInternal(row: InternalRow): Unit = { + recordWriter.write(null, row) + } + + override def close(): Unit = recordWriter.close(hadoopAttemptContext) + } + + /** Create a [[ParquetRecordWriter]] that writes to the given path without using OutputCommitter */ + private def createNoCommitterRecordWriter( + path: String, + hadoopAttemptContext: TaskAttemptContext): RecordWriter[Void, InternalRow] = { + // Custom ParquetOutputFormat that disables use of the committer and writes to the given path + val outputFormat = new ParquetOutputFormat[InternalRow]() { + override def getOutputCommitter(c: TaskAttemptContext): OutputCommitter = { null } + override def getDefaultWorkFile(c: TaskAttemptContext, ext: String): Path = { new Path(path) } + } + outputFormat.getRecordWriter(hadoopAttemptContext) + } + + /** Disable the use of the older API. */ + def newInstance( + path: String, + bucketId: Option[Int], + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + throw new UnsupportedOperationException( + "this version of newInstance not supported for " + + "ParquetOutputWriterFactory") + } +} + + // NOTE: This class is instantiated and used on executor side only, no need to be serializable. -private[sql] class ParquetOutputWriter( +private[parquet] class ParquetOutputWriter( path: String, bucketId: Option[Int], context: TaskAttemptContext) @@ -395,7 +546,8 @@ private[sql] class ParquetOutputWriter( // partitions in the case of dynamic partitioning. override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { val configuration = context.getConfiguration - val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID") + val uniqueWriteJobId = configuration.get( + CreateDataSourceTableUtils.DATASOURCE_WRITEJOBUUID) val taskAttemptId = context.getTaskAttemptID val split = taskAttemptId.getTaskID.getId val bucketString = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("") @@ -412,23 +564,13 @@ private[sql] class ParquetOutputWriter( override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal") - override protected[sql] def writeInternal(row: InternalRow): Unit = recordWriter.write(null, row) + override def writeInternal(row: InternalRow): Unit = recordWriter.write(null, row) override def close(): Unit = recordWriter.close(context) } -private[sql] object ParquetRelation extends Logging { - // Whether we should merge schemas collected from all Parquet part-files. - private[sql] val MERGE_SCHEMA = "mergeSchema" - - // Hive Metastore schema, used when converting Metastore Parquet tables. This option is only used - // internally. - private[sql] val METASTORE_SCHEMA = "metastoreSchema" - - // If a ParquetRelation is converted from a Hive metastore table, this option is set to the - // original Hive table name.
- private[sql] val METASTORE_TABLE_NAME = "metastoreTableName" +object ParquetFileFormat extends Logging { /** * If parquet's block size (row group size) setting is larger than the min split size, * we use parquet's block size setting as the min split size. Otherwise, we will create @@ -463,7 +605,7 @@ private[sql] object ParquetRelation extends Logging { assumeBinaryIsString: Boolean, assumeInt96IsTimestamp: Boolean)(job: Job): Unit = { val conf = job.getConfiguration - conf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[CatalystReadSupport].getName) + conf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[ParquetReadSupport].getName) // Try to push down filters when filter push-down is enabled. if (parquetFilterPushDown) { @@ -476,14 +618,14 @@ private[sql] object ParquetRelation extends Logging { .foreach(ParquetInputFormat.setFilterPredicate(conf, _)) } - conf.set(CatalystReadSupport.SPARK_ROW_REQUESTED_SCHEMA, { + conf.set(ParquetReadSupport.SPARK_ROW_REQUESTED_SCHEMA, { val requestedSchema = StructType(requiredColumns.map(dataSchema(_))) - CatalystSchemaConverter.checkFieldNames(requestedSchema).json + ParquetSchemaConverter.checkFieldNames(requestedSchema).json }) conf.set( - CatalystWriteSupport.SPARK_ROW_SCHEMA, - CatalystSchemaConverter.checkFieldNames(dataSchema).json) + ParquetWriteSupport.SPARK_ROW_SCHEMA, + ParquetSchemaConverter.checkFieldNames(dataSchema).json) // Tell FilteringParquetRowInputFormat whether it's okay to cache Parquet and FS metadata conf.setBoolean(SQLConf.PARQUET_CACHE_METADATA.key, useMetadataCache) @@ -512,7 +654,7 @@ private[sql] object ParquetRelation extends Logging { footers: Seq[Footer], sparkSession: SparkSession): Option[StructType] = { def parseParquetSchema(schema: MessageType): StructType = { - val converter = new CatalystSchemaConverter( + val converter = new ParquetSchemaConverter( sparkSession.sessionState.conf.isParquetBinaryAsString, sparkSession.sessionState.conf.isParquetBinaryAsString, sparkSession.sessionState.conf.writeLegacyParquetFormat) @@ -526,7 +668,7 @@ private[sql] object ParquetRelation extends Logging { val serializedSchema = metadata .getKeyValueMetaData .asScala.toMap - .get(CatalystReadSupport.SPARK_METADATA_KEY) + .get(ParquetReadSupport.SPARK_METADATA_KEY) if (serializedSchema.isEmpty) { // Falls back to Parquet schema if no Spark SQL schema found. Some(parseParquetSchema(metadata.getSchema)) @@ -575,7 +717,7 @@ private[sql] object ParquetRelation extends Logging { * distinguish binary and string). This method generates a correct schema by merging Metastore * schema data types and Parquet schema field names. */ - private[sql] def mergeMetastoreParquetSchema( + def mergeMetastoreParquetSchema( metastoreSchema: StructType, parquetSchema: StructType): StructType = { def schemaConflictMessage: String = @@ -651,14 +793,13 @@ private[sql] object ParquetRelation extends Logging { val assumeBinaryIsString = sparkSession.sessionState.conf.isParquetBinaryAsString val assumeInt96IsTimestamp = sparkSession.sessionState.conf.isParquetINT96AsTimestamp val writeLegacyParquetFormat = sparkSession.sessionState.conf.writeLegacyParquetFormat - val serializedConf = - new SerializableConfiguration(sparkSession.sessionState.newHadoopConf()) + val serializedConf = new SerializableConfiguration(sparkSession.sessionState.newHadoopConf()) // !! HACK ALERT !! // // Parquet requires `FileStatus`es to read footers. Here we try to send cached `FileStatus`es // to executor side to avoid fetching them again. 
However, `FileStatus` is not `Serializable` - // but only `Writable`. What makes it worth, for some reason, `FileStatus` doesn't play well + // but only `Writable`. What makes it worse, for some reason, `FileStatus` doesn't play well // with `SerializableWritable[T]` and always causes a weird `IllegalStateException`. These // facts virtually prevents us to serialize `FileStatus`es. // @@ -667,11 +808,16 @@ private[sql] object ParquetRelation extends Logging { // side, and resemble fake `FileStatus`es there. val partialFileStatusInfo = filesToTouch.map(f => (f.getPath.toString, f.getLen)) + // Set the number of partitions to prevent following schema reads from generating many tasks + // in case of a small number of parquet files. + val numParallelism = Math.min(Math.max(partialFileStatusInfo.size, 1), + sparkSession.sparkContext.defaultParallelism) + // Issues a Spark job to read Parquet schema in parallel. val partiallyMergedSchemas = sparkSession .sparkContext - .parallelize(partialFileStatusInfo) + .parallelize(partialFileStatusInfo, numParallelism) .mapPartitions { iterator => // Resembles fake `FileStatus`es with serialized path and length information. val fakeFileStatuses = iterator.map { case (path, length) => @@ -688,7 +834,7 @@ private[sql] object ParquetRelation extends Logging { // Converter used to convert Parquet `MessageType` to Spark SQL `StructType` val converter = - new CatalystSchemaConverter( + new ParquetSchemaConverter( assumeBinaryIsString = assumeBinaryIsString, assumeInt96IsTimestamp = assumeInt96IsTimestamp, writeLegacyParquetFormat = writeLegacyParquetFormat) @@ -696,9 +842,9 @@ private[sql] object ParquetRelation extends Logging { if (footers.isEmpty) { Iterator.empty } else { - var mergedSchema = ParquetRelation.readSchemaFromFooter(footers.head, converter) + var mergedSchema = ParquetFileFormat.readSchemaFromFooter(footers.head, converter) footers.tail.foreach { footer => - val schema = ParquetRelation.readSchemaFromFooter(footer, converter) + val schema = ParquetFileFormat.readSchemaFromFooter(footer, converter) try { mergedSchema = mergedSchema.merge(schema) } catch { case cause: SparkException => @@ -732,12 +878,12 @@ private[sql] object ParquetRelation extends Logging { * a [[StructType]] converted from the [[MessageType]] stored in this footer. 
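The parallelism clamp above keeps the schema-merging job from spawning hundreds of tasks when only a handful of footers need reading. The arithmetic, extracted as a pure function for clarity (names are illustrative):

```scala
// One task per file, but at least 1 and at most the cluster's default parallelism.
def schemaReadParallelism(numFiles: Int, defaultParallelism: Int): Int =
  math.min(math.max(numFiles, 1), defaultParallelism)

assert(schemaReadParallelism(3, 200) == 3)      // 3 Parquet files on a 200-core cluster -> 3 tasks
assert(schemaReadParallelism(0, 200) == 1)      // never zero partitions
assert(schemaReadParallelism(5000, 200) == 200) // capped at defaultParallelism
```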
*/ def readSchemaFromFooter( - footer: Footer, converter: CatalystSchemaConverter): StructType = { + footer: Footer, converter: ParquetSchemaConverter): StructType = { val fileMetaData = footer.getParquetMetadata.getFileMetaData fileMetaData .getKeyValueMetaData .asScala.toMap - .get(CatalystReadSupport.SPARK_METADATA_KEY) + .get(ParquetReadSupport.SPARK_METADATA_KEY) .flatMap(deserializeSchemaString) .getOrElse(converter.convert(fileMetaData.getSchema)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index 95afdc789f322..2edd2757428aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -28,7 +28,7 @@ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName import org.apache.spark.sql.sources import org.apache.spark.sql.types._ -private[sql] object ParquetFilters { +object ParquetFilters { case class SetInFilter[T <: Comparable[T]]( valueSet: Set[T]) extends UserDefinedPredicate[T] with Serializable { @@ -215,10 +215,13 @@ private[sql] object ParquetFilters { */ private def getFieldMap(dataType: DataType): Array[(String, DataType)] = dataType match { case StructType(fields) => + // Here we don't flatten the fields in the nested schema but just look up through + // root fields. Currently, accessing to nested fields does not push down filters + // and it does not support to create filters for them. fields.filter { f => !f.metadata.contains(StructType.metadataKeyForOptionalField) || !f.metadata.getBoolean(StructType.metadataKeyForOptionalField) - }.map(f => f.name -> f.dataType) ++ fields.flatMap { f => getFieldMap(f.dataType) } + }.map(f => f.name -> f.dataType) case _ => Array.empty[(String, DataType)] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala index 00352f23ae660..3eec582714e15 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala @@ -19,16 +19,15 @@ package org.apache.spark.sql.execution.datasources.parquet import org.apache.parquet.hadoop.metadata.CompressionCodecName -import org.apache.spark.internal.Logging import org.apache.spark.sql.internal.SQLConf /** * Options for the Parquet data source. */ -class ParquetOptions( +private[parquet] class ParquetOptions( @transient private val parameters: Map[String, String], @transient private val sqlConf: SQLConf) - extends Logging with Serializable { + extends Serializable { import ParquetOptions._ @@ -45,10 +44,21 @@ class ParquetOptions( } shortParquetCompressionCodecNames(codecName).name() } + + /** + * Whether it merges schemas or not. When the given Parquet files have different schemas, + * the schemas can be merged. By default use the value specified in SQLConf. 
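The `getFieldMap` change above means only root-level columns can produce Parquet push-down filters; predicates on nested fields are evaluated by Spark after the scan. A simplified sketch of the root-only lookup (it omits the optional-field metadata filtering done in the real code):

```scala
import org.apache.spark.sql.types._

def rootFieldMap(dataType: DataType): Map[String, DataType] = dataType match {
  case StructType(fields) => fields.map(f => f.name -> f.dataType).toMap
  case _ => Map.empty
}

val schema = StructType(Seq(
  StructField("id", LongType),
  StructField("address", StructType(Seq(StructField("city", StringType))))))

// "address.city" is absent, so a filter on it is never converted into a Parquet predicate.
assert(rootFieldMap(schema).keySet == Set("id", "address"))
```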
+ */ + val mergeSchema: Boolean = parameters + .get(MERGE_SCHEMA) + .map(_.toBoolean) + .getOrElse(sqlConf.getConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED)) } object ParquetOptions { + val MERGE_SCHEMA = "mergeSchema" + // The parquet compression short names private val shortParquetCompressionCodecNames = Map( "none" -> CompressionCodecName.UNCOMPRESSED, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala similarity index 95% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala index 850e807b8677e..12f497421f4b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala @@ -48,7 +48,7 @@ import org.apache.spark.sql.types._ * Due to this reason, we no longer rely on [[ReadContext]] to pass requested schema from [[init()]] * to [[prepareForRead()]], but use a private `var` for simplicity. */ -private[parquet] class CatalystReadSupport extends ReadSupport[InternalRow] with Logging { +private[parquet] class ParquetReadSupport extends ReadSupport[InternalRow] with Logging { private var catalystRequestedSchema: StructType = _ /** @@ -58,13 +58,13 @@ private[parquet] class CatalystReadSupport extends ReadSupport[InternalRow] with override def init(context: InitContext): ReadContext = { catalystRequestedSchema = { val conf = context.getConfiguration - val schemaString = conf.get(CatalystReadSupport.SPARK_ROW_REQUESTED_SCHEMA) + val schemaString = conf.get(ParquetReadSupport.SPARK_ROW_REQUESTED_SCHEMA) assert(schemaString != null, "Parquet requested schema not set.") StructType.fromString(schemaString) } val parquetRequestedSchema = - CatalystReadSupport.clipParquetSchema(context.getFileSchema, catalystRequestedSchema) + ParquetReadSupport.clipParquetSchema(context.getFileSchema, catalystRequestedSchema) new ReadContext(parquetRequestedSchema, Map.empty[String, String].asJava) } @@ -92,13 +92,13 @@ private[parquet] class CatalystReadSupport extends ReadSupport[InternalRow] with """.stripMargin } - new CatalystRecordMaterializer( + new ParquetRecordMaterializer( parquetRequestedSchema, - CatalystReadSupport.expandUDT(catalystRequestedSchema)) + ParquetReadSupport.expandUDT(catalystRequestedSchema)) } } -private[parquet] object CatalystReadSupport { +private[parquet] object ParquetReadSupport { val SPARK_ROW_REQUESTED_SCHEMA = "org.apache.spark.sql.parquet.row.requested_schema" val SPARK_METADATA_KEY = "org.apache.spark.sql.parquet.row.metadata" @@ -112,7 +112,7 @@ private[parquet] object CatalystReadSupport { Types .buildMessage() .addFields(clippedParquetFields: _*) - .named(CatalystSchemaConverter.SPARK_PARQUET_SCHEMA_NAME) + .named(ParquetSchemaConverter.SPARK_PARQUET_SCHEMA_NAME) } private def clipParquetType(parquetType: Type, catalystType: DataType): Type = { @@ -265,7 +265,7 @@ private[parquet] object CatalystReadSupport { private def clipParquetGroupFields( parquetRecord: GroupType, structType: StructType): Seq[Type] = { val parquetFieldMap = parquetRecord.getFields.asScala.map(f => f.getName -> f).toMap - val toParquet = new CatalystSchemaConverter(writeLegacyParquetFormat = false) + val toParquet = new 
ParquetSchemaConverter(writeLegacyParquetFormat = false) structType.map { f => parquetFieldMap .get(f.name) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRecordMaterializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRecordMaterializer.scala similarity index 90% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRecordMaterializer.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRecordMaterializer.scala index eeead9f5d88a2..0818d802b077a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRecordMaterializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRecordMaterializer.scala @@ -29,11 +29,11 @@ import org.apache.spark.sql.types.StructType * @param parquetSchema Parquet schema of the records to be read * @param catalystSchema Catalyst schema of the rows to be constructed */ -private[parquet] class CatalystRecordMaterializer( +private[parquet] class ParquetRecordMaterializer( parquetSchema: MessageType, catalystSchema: StructType) extends RecordMaterializer[InternalRow] { - private val rootConverter = new CatalystRowConverter(parquetSchema, catalystSchema, NoopUpdater) + private val rootConverter = new ParquetRowConverter(parquetSchema, catalystSchema, NoopUpdater) override def getCurrentRecord: InternalRow = rootConverter.currentRecord diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala similarity index 89% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala index 6bf82bee67881..9dad59647e0db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala @@ -68,9 +68,9 @@ private[parquet] trait HasParentContainerUpdater { } /** - * A convenient converter class for Parquet group types with an [[HasParentContainerUpdater]]. + * A convenient converter class for Parquet group types with a [[HasParentContainerUpdater]]. */ -private[parquet] abstract class CatalystGroupConverter(val updater: ParentContainerUpdater) +private[parquet] abstract class ParquetGroupConverter(val updater: ParentContainerUpdater) extends GroupConverter with HasParentContainerUpdater /** @@ -78,7 +78,7 @@ private[parquet] abstract class CatalystGroupConverter(val updater: ParentContai * are handled by this converter. Parquet primitive types are only a subset of those of Spark * SQL. For example, BYTE, SHORT, and INT in Spark SQL are all covered by INT32 in Parquet. 
*/ -private[parquet] class CatalystPrimitiveConverter(val updater: ParentContainerUpdater) +private[parquet] class ParquetPrimitiveConverter(val updater: ParentContainerUpdater) extends PrimitiveConverter with HasParentContainerUpdater { override def addBoolean(value: Boolean): Unit = updater.setBoolean(value) @@ -90,7 +90,7 @@ private[parquet] class CatalystPrimitiveConverter(val updater: ParentContainerUp } /** - * A [[CatalystRowConverter]] is used to convert Parquet records into Catalyst [[InternalRow]]s. + * A [[ParquetRowConverter]] is used to convert Parquet records into Catalyst [[InternalRow]]s. * Since Catalyst `StructType` is also a Parquet record, this converter can be used as root * converter. Take the following Parquet type as an example: * {{{ @@ -104,11 +104,11 @@ private[parquet] class CatalystPrimitiveConverter(val updater: ParentContainerUp * }}} * 5 converters will be created: * - * - a root [[CatalystRowConverter]] for [[MessageType]] `root`, which contains: - * - a [[CatalystPrimitiveConverter]] for required [[INT_32]] field `f1`, and - * - a nested [[CatalystRowConverter]] for optional [[GroupType]] `f2`, which contains: - * - a [[CatalystPrimitiveConverter]] for required [[DOUBLE]] field `f21`, and - * - a [[CatalystStringConverter]] for optional [[UTF8]] string field `f22` + * - a root [[ParquetRowConverter]] for [[MessageType]] `root`, which contains: + * - a [[ParquetPrimitiveConverter]] for required [[INT_32]] field `f1`, and + * - a nested [[ParquetRowConverter]] for optional [[GroupType]] `f2`, which contains: + * - a [[ParquetPrimitiveConverter]] for required [[DOUBLE]] field `f21`, and + * - a [[ParquetStringConverter]] for optional [[UTF8]] string field `f22` * * When used as a root converter, [[NoopUpdater]] should be used since root converters don't have * any "parent" container. @@ -118,11 +118,11 @@ private[parquet] class CatalystPrimitiveConverter(val updater: ParentContainerUp * types should have been expanded. * @param updater An updater which propagates converted field values to the parent container */ -private[parquet] class CatalystRowConverter( +private[parquet] class ParquetRowConverter( parquetType: GroupType, catalystType: StructType, updater: ParentContainerUpdater) - extends CatalystGroupConverter(updater) with Logging { + extends ParquetGroupConverter(updater) with Logging { assert( parquetType.getFieldCount == catalystType.length, @@ -150,7 +150,7 @@ private[parquet] class CatalystRowConverter( """.stripMargin) /** - * Updater used together with field converters within a [[CatalystRowConverter]]. It propagates + * Updater used together with field converters within a [[ParquetRowConverter]]. It propagates * converted filed values to the `ordinal`-th cell in `currentRow`. 
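For readers who want to reproduce the converter tree described in the Scaladoc above, the example message type can be written out with parquet-mr's schema parser. The schema text below is reconstructed from the field list in the comment (`f1`, `f2`, `f21`, `f22`), not copied from the elided `{{{ }}}` block:

```scala
import org.apache.parquet.schema.MessageTypeParser

// Parsing this yields the MessageType for which the five converters described above are built.
val root = MessageTypeParser.parseMessageType(
  """message root {
    |  required int32 f1;
    |  optional group f2 {
    |    required double f21;
    |    optional binary f22 (UTF8);
    |  }
    |}""".stripMargin)
```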
*/ private final class RowUpdater(row: MutableRow, ordinal: Int) extends ParentContainerUpdater { @@ -213,33 +213,33 @@ private[parquet] class CatalystRowConverter( catalystType match { case BooleanType | IntegerType | LongType | FloatType | DoubleType | BinaryType => - new CatalystPrimitiveConverter(updater) + new ParquetPrimitiveConverter(updater) case ByteType => - new CatalystPrimitiveConverter(updater) { + new ParquetPrimitiveConverter(updater) { override def addInt(value: Int): Unit = updater.setByte(value.asInstanceOf[ByteType#InternalType]) } case ShortType => - new CatalystPrimitiveConverter(updater) { + new ParquetPrimitiveConverter(updater) { override def addInt(value: Int): Unit = updater.setShort(value.asInstanceOf[ShortType#InternalType]) } // For INT32 backed decimals case t: DecimalType if parquetType.asPrimitiveType().getPrimitiveTypeName == INT32 => - new CatalystIntDictionaryAwareDecimalConverter(t.precision, t.scale, updater) + new ParquetIntDictionaryAwareDecimalConverter(t.precision, t.scale, updater) // For INT64 backed decimals case t: DecimalType if parquetType.asPrimitiveType().getPrimitiveTypeName == INT64 => - new CatalystLongDictionaryAwareDecimalConverter(t.precision, t.scale, updater) + new ParquetLongDictionaryAwareDecimalConverter(t.precision, t.scale, updater) // For BINARY and FIXED_LEN_BYTE_ARRAY backed decimals case t: DecimalType if parquetType.asPrimitiveType().getPrimitiveTypeName == FIXED_LEN_BYTE_ARRAY || parquetType.asPrimitiveType().getPrimitiveTypeName == BINARY => - new CatalystBinaryDictionaryAwareDecimalConverter(t.precision, t.scale, updater) + new ParquetBinaryDictionaryAwareDecimalConverter(t.precision, t.scale, updater) case t: DecimalType => throw new RuntimeException( @@ -248,11 +248,11 @@ private[parquet] class CatalystRowConverter( "FIXED_LEN_BYTE_ARRAY, or BINARY.") case StringType => - new CatalystStringConverter(updater) + new ParquetStringConverter(updater) case TimestampType => // TODO Implements `TIMESTAMP_MICROS` once parquet-mr has that. - new CatalystPrimitiveConverter(updater) { + new ParquetPrimitiveConverter(updater) { // Converts nanosecond timestamps stored as INT96 override def addBinary(value: Binary): Unit = { assert( @@ -268,7 +268,7 @@ private[parquet] class CatalystRowConverter( } case DateType => - new CatalystPrimitiveConverter(updater) { + new ParquetPrimitiveConverter(updater) { override def addInt(value: Int): Unit = { // DateType is not specialized in `SpecificMutableRow`, have to box it here. updater.set(value.asInstanceOf[DateType#InternalType]) @@ -286,13 +286,13 @@ private[parquet] class CatalystRowConverter( } case t: ArrayType => - new CatalystArrayConverter(parquetType.asGroupType(), t, updater) + new ParquetArrayConverter(parquetType.asGroupType(), t, updater) case t: MapType => - new CatalystMapConverter(parquetType.asGroupType(), t, updater) + new ParquetMapConverter(parquetType.asGroupType(), t, updater) case t: StructType => - new CatalystRowConverter(parquetType.asGroupType(), t, new ParentContainerUpdater { + new ParquetRowConverter(parquetType.asGroupType(), t, new ParentContainerUpdater { override def set(value: Any): Unit = updater.set(value.asInstanceOf[InternalRow].copy()) }) @@ -306,8 +306,8 @@ private[parquet] class CatalystRowConverter( /** * Parquet converter for strings. A dictionary is used to minimize string decoding cost. 
*/ - private final class CatalystStringConverter(updater: ParentContainerUpdater) - extends CatalystPrimitiveConverter(updater) { + private final class ParquetStringConverter(updater: ParentContainerUpdater) + extends ParquetPrimitiveConverter(updater) { private var expandedDictionary: Array[UTF8String] = null @@ -337,9 +337,9 @@ private[parquet] class CatalystRowConverter( /** * Parquet converter for fixed-precision decimals. */ - private abstract class CatalystDecimalConverter( + private abstract class ParquetDecimalConverter( precision: Int, scale: Int, updater: ParentContainerUpdater) - extends CatalystPrimitiveConverter(updater) { + extends ParquetPrimitiveConverter(updater) { protected var expandedDictionary: Array[Decimal] = _ @@ -371,7 +371,7 @@ private[parquet] class CatalystRowConverter( protected def decimalFromBinary(value: Binary): Decimal = { if (precision <= Decimal.MAX_LONG_DIGITS) { // Constructs a `Decimal` with an unscaled `Long` value if possible. - val unscaled = CatalystRowConverter.binaryToUnscaledLong(value) + val unscaled = ParquetRowConverter.binaryToUnscaledLong(value) Decimal(unscaled, precision, scale) } else { // Otherwise, resorts to an unscaled `BigInteger` instead. @@ -380,9 +380,9 @@ private[parquet] class CatalystRowConverter( } } - private class CatalystIntDictionaryAwareDecimalConverter( + private class ParquetIntDictionaryAwareDecimalConverter( precision: Int, scale: Int, updater: ParentContainerUpdater) - extends CatalystDecimalConverter(precision, scale, updater) { + extends ParquetDecimalConverter(precision, scale, updater) { override def setDictionary(dictionary: Dictionary): Unit = { this.expandedDictionary = Array.tabulate(dictionary.getMaxId + 1) { id => @@ -391,9 +391,9 @@ private[parquet] class CatalystRowConverter( } } - private class CatalystLongDictionaryAwareDecimalConverter( + private class ParquetLongDictionaryAwareDecimalConverter( precision: Int, scale: Int, updater: ParentContainerUpdater) - extends CatalystDecimalConverter(precision, scale, updater) { + extends ParquetDecimalConverter(precision, scale, updater) { override def setDictionary(dictionary: Dictionary): Unit = { this.expandedDictionary = Array.tabulate(dictionary.getMaxId + 1) { id => @@ -402,9 +402,9 @@ private[parquet] class CatalystRowConverter( } } - private class CatalystBinaryDictionaryAwareDecimalConverter( + private class ParquetBinaryDictionaryAwareDecimalConverter( precision: Int, scale: Int, updater: ParentContainerUpdater) - extends CatalystDecimalConverter(precision, scale, updater) { + extends ParquetDecimalConverter(precision, scale, updater) { override def setDictionary(dictionary: Dictionary): Unit = { this.expandedDictionary = Array.tabulate(dictionary.getMaxId + 1) { id => @@ -431,11 +431,11 @@ private[parquet] class CatalystRowConverter( * * @see https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#lists */ - private final class CatalystArrayConverter( + private final class ParquetArrayConverter( parquetSchema: GroupType, catalystSchema: ArrayType, updater: ParentContainerUpdater) - extends CatalystGroupConverter(updater) { + extends ParquetGroupConverter(updater) { private var currentArray: ArrayBuffer[Any] = _ @@ -512,11 +512,11 @@ private[parquet] class CatalystRowConverter( } /** Parquet converter for maps */ - private final class CatalystMapConverter( + private final class ParquetMapConverter( parquetType: GroupType, catalystType: MapType, updater: ParentContainerUpdater) - extends CatalystGroupConverter(updater) { + extends 
ParquetGroupConverter(updater) { private var currentKeys: ArrayBuffer[Any] = _ private var currentValues: ArrayBuffer[Any] = _ @@ -638,7 +638,7 @@ private[parquet] class CatalystRowConverter( } } -private[parquet] object CatalystRowConverter { +private[parquet] object ParquetRowConverter { def binaryToUnscaledLong(binary: Binary): Long = { // The underlying `ByteBuffer` implementation is guaranteed to be `HeapByteBuffer`, so here // we are using `Binary.toByteBuffer.array()` to steal the underlying byte array without diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala similarity index 96% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala index 6f6340f541ada..1ac083f48a8c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala @@ -26,7 +26,7 @@ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ import org.apache.parquet.schema.Type.Repetition._ import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter.maxPrecisionForBytes +import org.apache.spark.sql.execution.datasources.parquet.ParquetSchemaConverter.maxPrecisionForBytes import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -52,7 +52,7 @@ import org.apache.spark.sql.types._ * When set to false, use standard format defined in parquet-format spec. This argument only * affects Parquet write path. */ -private[parquet] class CatalystSchemaConverter( +private[parquet] class ParquetSchemaConverter( assumeBinaryIsString: Boolean = SQLConf.PARQUET_BINARY_AS_STRING.defaultValue.get, assumeInt96IsTimestamp: Boolean = SQLConf.PARQUET_INT96_AS_TIMESTAMP.defaultValue.get, writeLegacyParquetFormat: Boolean = SQLConf.PARQUET_WRITE_LEGACY_FORMAT.defaultValue.get) { @@ -125,7 +125,7 @@ private[parquet] class CatalystSchemaConverter( val precision = field.getDecimalMetadata.getPrecision val scale = field.getDecimalMetadata.getScale - CatalystSchemaConverter.checkConversionRequirement( + ParquetSchemaConverter.checkConversionRequirement( maxPrecision == -1 || 1 <= precision && precision <= maxPrecision, s"Invalid decimal precision: $typeName cannot store $precision digits (max $maxPrecision)") @@ -163,7 +163,7 @@ private[parquet] class CatalystSchemaConverter( } case INT96 => - CatalystSchemaConverter.checkConversionRequirement( + ParquetSchemaConverter.checkConversionRequirement( assumeInt96IsTimestamp, "INT96 is not supported unless it's interpreted as timestamp. 
" + s"Please try to set ${SQLConf.PARQUET_INT96_AS_TIMESTAMP.key} to true.") @@ -206,11 +206,11 @@ private[parquet] class CatalystSchemaConverter( // // See: https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#lists case LIST => - CatalystSchemaConverter.checkConversionRequirement( + ParquetSchemaConverter.checkConversionRequirement( field.getFieldCount == 1, s"Invalid list type $field") val repeatedType = field.getType(0) - CatalystSchemaConverter.checkConversionRequirement( + ParquetSchemaConverter.checkConversionRequirement( repeatedType.isRepetition(REPEATED), s"Invalid list type $field") if (isElementType(repeatedType, field.getName)) { @@ -226,17 +226,17 @@ private[parquet] class CatalystSchemaConverter( // See: https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#backward-compatibility-rules-1 // scalastyle:on case MAP | MAP_KEY_VALUE => - CatalystSchemaConverter.checkConversionRequirement( + ParquetSchemaConverter.checkConversionRequirement( field.getFieldCount == 1 && !field.getType(0).isPrimitive, s"Invalid map type: $field") val keyValueType = field.getType(0).asGroupType() - CatalystSchemaConverter.checkConversionRequirement( + ParquetSchemaConverter.checkConversionRequirement( keyValueType.isRepetition(REPEATED) && keyValueType.getFieldCount == 2, s"Invalid map type: $field") val keyType = keyValueType.getType(0) - CatalystSchemaConverter.checkConversionRequirement( + ParquetSchemaConverter.checkConversionRequirement( keyType.isPrimitive, s"Map key type is expected to be a primitive type, but found: $keyType") @@ -311,7 +311,7 @@ private[parquet] class CatalystSchemaConverter( Types .buildMessage() .addFields(catalystSchema.map(convertField): _*) - .named(CatalystSchemaConverter.SPARK_PARQUET_SCHEMA_NAME) + .named(ParquetSchemaConverter.SPARK_PARQUET_SCHEMA_NAME) } /** @@ -322,7 +322,7 @@ private[parquet] class CatalystSchemaConverter( } private def convertField(field: StructField, repetition: Type.Repetition): Type = { - CatalystSchemaConverter.checkFieldName(field.name) + ParquetSchemaConverter.checkFieldName(field.name) field.dataType match { // =================== @@ -394,7 +394,7 @@ private[parquet] class CatalystSchemaConverter( .as(DECIMAL) .precision(precision) .scale(scale) - .length(CatalystSchemaConverter.minBytesForPrecision(precision)) + .length(ParquetSchemaConverter.minBytesForPrecision(precision)) .named(field.name) // ======================== @@ -428,7 +428,7 @@ private[parquet] class CatalystSchemaConverter( .as(DECIMAL) .precision(precision) .scale(scale) - .length(CatalystSchemaConverter.minBytesForPrecision(precision)) + .length(ParquetSchemaConverter.minBytesForPrecision(precision)) .named(field.name) // =================================== @@ -535,7 +535,7 @@ private[parquet] class CatalystSchemaConverter( } } -private[parquet] object CatalystSchemaConverter { +private[parquet] object ParquetSchemaConverter { val SPARK_PARQUET_SCHEMA_NAME = "spark_schema" def checkFieldName(name: String): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala similarity index 96% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala index 67bfd39697ed7..307c64d6e895b 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala @@ -33,7 +33,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecializedGetters import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter.minBytesForPrecision +import org.apache.spark.sql.execution.datasources.parquet.ParquetSchemaConverter.minBytesForPrecision import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -48,7 +48,7 @@ import org.apache.spark.sql.types._ * of this option is propagated to this class by the `init()` method and its Hadoop configuration * argument. */ -private[parquet] class CatalystWriteSupport extends WriteSupport[InternalRow] with Logging { +private[parquet] class ParquetWriteSupport extends WriteSupport[InternalRow] with Logging { // A `ValueWriter` is responsible for writing a field of an `InternalRow` to the record consumer. // Here we are using `SpecializedGetters` rather than `InternalRow` so that we can directly access // data in `ArrayData` without the help of `SpecificMutableRow`. @@ -73,7 +73,7 @@ private[parquet] class CatalystWriteSupport extends WriteSupport[InternalRow] wi private val decimalBuffer = new Array[Byte](minBytesForPrecision(DecimalType.MAX_PRECISION)) override def init(configuration: Configuration): WriteContext = { - val schemaString = configuration.get(CatalystWriteSupport.SPARK_ROW_SCHEMA) + val schemaString = configuration.get(ParquetWriteSupport.SPARK_ROW_SCHEMA) this.schema = StructType.fromString(schemaString) this.writeLegacyParquetFormat = { // `SQLConf.PARQUET_WRITE_LEGACY_FORMAT` should always be explicitly set in ParquetRelation @@ -82,8 +82,8 @@ private[parquet] class CatalystWriteSupport extends WriteSupport[InternalRow] wi } this.rootFieldWriters = schema.map(_.dataType).map(makeWriter) - val messageType = new CatalystSchemaConverter(configuration).convert(schema) - val metadata = Map(CatalystReadSupport.SPARK_METADATA_KEY -> schemaString).asJava + val messageType = new ParquetSchemaConverter(configuration).convert(schema) + val metadata = Map(ParquetReadSupport.SPARK_METADATA_KEY -> schemaString).asJava logInfo( s"""Initialized Parquet WriteSupport with Catalyst schema: @@ -423,11 +423,11 @@ private[parquet] class CatalystWriteSupport extends WriteSupport[InternalRow] wi } } -private[parquet] object CatalystWriteSupport { +private[parquet] object ParquetWriteSupport { val SPARK_ROW_SCHEMA: String = "org.apache.spark.sql.parquet.row.attributes" def setSchema(schema: StructType, configuration: Configuration): Unit = { - schema.map(_.name).foreach(CatalystSchemaConverter.checkFieldName) + schema.map(_.name).foreach(ParquetSchemaConverter.checkFieldName) configuration.set(SPARK_ROW_SCHEMA, schema.json) configuration.setIfUnset( ParquetOutputFormat.WRITER_VERSION, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index b622f859413a6..27420d58e56a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -17,20 +17,23 @@ package org.apache.spark.sql.execution.datasources 
+import scala.util.control.NonFatal + import org.apache.spark.sql.{AnalysisException, SaveMode, SparkSession} import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.catalog.SessionCatalog +import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, SessionCatalog} import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, RowOrdering} import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.command.CreateHiveTableAsSelectLogicalPlan import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{BaseRelation, InsertableRelation} /** - * Try to replaces [[UnresolvedRelation]]s with [[ResolvedDataSource]]. + * Try to replaces [[UnresolvedRelation]]s with [[ResolveDataSource]]. */ -private[sql] class ResolveDataSource(sparkSession: SparkSession) extends Rule[LogicalPlan] { +class ResolveDataSource(sparkSession: SparkSession) extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case u: UnresolvedRelation if u.tableIdentifier.database.isDefined => try { @@ -38,6 +41,16 @@ private[sql] class ResolveDataSource(sparkSession: SparkSession) extends Rule[Lo sparkSession, paths = u.tableIdentifier.table :: Nil, className = u.tableIdentifier.database.get) + + val notSupportDirectQuery = try { + !classOf[FileFormat].isAssignableFrom(dataSource.providingClass) + } catch { + case NonFatal(e) => false + } + if (notSupportDirectQuery) { + throw new AnalysisException("Unsupported data source type for direct query on files: " + + s"${u.tableIdentifier.database.get}") + } val plan = LogicalRelation(dataSource.resolveRelation()) u.alias.map(a => SubqueryAlias(u.alias.get, plan)).getOrElse(plan) } catch { @@ -50,58 +63,111 @@ private[sql] class ResolveDataSource(sparkSession: SparkSession) extends Rule[Lo } /** - * A rule to do pre-insert data type casting and field renaming. Before we insert into - * an [[InsertableRelation]], we will use this rule to make sure that - * the columns to be inserted have the correct data type and fields have the correct names. + * Analyze the query in CREATE TABLE AS SELECT (CTAS). After analysis, [[PreWriteCheck]] also + * can detect the cases that are not allowed. */ -private[sql] object PreInsertCastAndRename extends Rule[LogicalPlan] { +case class AnalyzeCreateTableAsSelect(sparkSession: SparkSession) extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - // Wait until children are resolved. - case p: LogicalPlan if !p.childrenResolved => p - - // We are inserting into an InsertableRelation or HadoopFsRelation. - case i @ InsertIntoTable( - l @ LogicalRelation(_: InsertableRelation | _: HadoopFsRelation, _, _), _, child, _, _) => - // First, make sure the data to be inserted have the same number of fields with the - // schema of the relation. 
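The `notSupportDirectQuery` guard above targets Spark's "run SQL directly on files" syntax: file-based providers keep working, while a provider that is not a `FileFormat` now fails analysis with a clear message. A usage sketch (the `spark` session and the jdbc example are placeholders for any non-file provider):

```scala
// File-backed providers resolve through FileFormat and are still allowed:
spark.sql("SELECT * FROM parquet.`/tmp/events`").show()

// A non-FileFormat provider in the same position is rejected by the new check with
// "Unsupported data source type for direct query on files: ..." (wording from the rule above):
spark.sql("SELECT * FROM jdbc.`some_table`")   // throws AnalysisException
```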
- if (l.output.size != child.output.size) { - sys.error( - s"$l requires that the query in the SELECT clause of the INSERT INTO/OVERWRITE " + - s"statement generates the same number of columns as its schema.") + case c: CreateTableUsingAsSelect if !c.query.resolved => + c.copy(query = analyzeQuery(c.query)) + case c: CreateHiveTableAsSelectLogicalPlan if !c.query.resolved => + c.copy(query = analyzeQuery(c.query)) + } + + private def analyzeQuery(query: LogicalPlan): LogicalPlan = { + val qe = sparkSession.sessionState.executePlan(query) + qe.assertAnalyzed() + qe.analyzed + } +} + +/** + * Preprocess the [[InsertIntoTable]] plan. Throws exception if the number of columns mismatch, or + * specified partition columns are different from the existing partition columns in the target + * table. It also does data type casting and field renaming, to make sure that the columns to be + * inserted have the correct data type and fields have the correct names. + */ +case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] { + private def preprocess( + insert: InsertIntoTable, + tblName: String, + partColNames: Seq[String]): InsertIntoTable = { + + val expectedColumns = insert.expectedColumns + if (expectedColumns.isDefined && expectedColumns.get.length != insert.child.schema.length) { + throw new AnalysisException( + s"Cannot insert into table $tblName because the number of columns are different: " + + s"need ${expectedColumns.get.length} columns, " + + s"but query has ${insert.child.schema.length} columns.") + } + + if (insert.partition.nonEmpty) { + // the query's partitioning must match the table's partitioning + // this is set for queries like: insert into ... partition (one = "a", two = ) + val samePartitionColumns = + if (conf.caseSensitiveAnalysis) { + insert.partition.keySet == partColNames.toSet + } else { + insert.partition.keySet.map(_.toLowerCase) == partColNames.map(_.toLowerCase).toSet } - castAndRenameChildOutput(i, l.output, child) + if (!samePartitionColumns) { + throw new AnalysisException( + s""" + |Requested partitioning does not match the table $tblName: + |Requested partitions: ${insert.partition.keys.mkString(",")} + |Table partitions: ${partColNames.mkString(",")} + """.stripMargin) + } + expectedColumns.map(castAndRenameChildOutput(insert, _)).getOrElse(insert) + } else { + // All partition columns are dynamic because because the InsertIntoTable command does + // not explicitly specify partitioning columns. + expectedColumns.map(castAndRenameChildOutput(insert, _)).getOrElse(insert) + .copy(partition = partColNames.map(_ -> None).toMap) + } } - /** If necessary, cast data types and rename fields to the expected types and names. */ + // TODO: do we really need to rename? def castAndRenameChildOutput( - insertInto: InsertIntoTable, - expectedOutput: Seq[Attribute], - child: LogicalPlan): InsertIntoTable = { - val newChildOutput = expectedOutput.zip(child.output).map { + insert: InsertIntoTable, + expectedOutput: Seq[Attribute]): InsertIntoTable = { + val newChildOutput = expectedOutput.zip(insert.child.output).map { case (expected, actual) => - val needCast = !expected.dataType.sameType(actual.dataType) - // We want to make sure the filed names in the data to be inserted exactly match - // names in the schema. 
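The partition-spec check above only compares column names, with case sensitivity taken from the analyzer setting. Extracted as a standalone predicate (names are illustrative):

```scala
def samePartitionColumns(
    specified: Set[String],
    tablePartCols: Seq[String],
    caseSensitive: Boolean): Boolean = {
  if (caseSensitive) specified == tablePartCols.toSet
  else specified.map(_.toLowerCase) == tablePartCols.map(_.toLowerCase).toSet
}

// INSERT INTO ... PARTITION (Year=2016, month=7) against a table partitioned by (year, month):
assert(samePartitionColumns(Set("Year", "month"), Seq("year", "month"), caseSensitive = false))
assert(!samePartitionColumns(Set("Year", "month"), Seq("year", "month"), caseSensitive = true))
```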
- val needRename = expected.name != actual.name - (needCast, needRename) match { - case (true, _) => Alias(Cast(actual, expected.dataType), expected.name)() - case (false, true) => Alias(actual, expected.name)() - case (_, _) => actual + if (expected.dataType.sameType(actual.dataType) && expected.name == actual.name) { + actual + } else { + Alias(Cast(actual, expected.dataType), expected.name)() } } - if (newChildOutput == child.output) { - insertInto + if (newChildOutput == insert.child.output) { + insert } else { - insertInto.copy(child = Project(newChildOutput, child)) + insert.copy(child = Project(newChildOutput, insert.child)) } } + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case i @ InsertIntoTable(table, partition, child, _, _) if table.resolved && child.resolved => + table match { + case relation: CatalogRelation => + val metadata = relation.catalogTable + preprocess(i, metadata.identifier.quotedString, metadata.partitionColumnNames) + case LogicalRelation(h: HadoopFsRelation, _, identifier) => + val tblName = identifier.map(_.quotedString).getOrElse("unknown") + preprocess(i, tblName, h.partitionSchema.map(_.name)) + case LogicalRelation(_: InsertableRelation, _, identifier) => + val tblName = identifier.map(_.quotedString).getOrElse("unknown") + preprocess(i, tblName, Nil) + case other => i + } + } } /** * A rule to do various checks before inserting into or writing to a data source table. */ -private[sql] case class PreWriteCheck(conf: SQLConf, catalog: SessionCatalog) +case class PreWriteCheck(conf: SQLConf, catalog: SessionCatalog) extends (LogicalPlan => Unit) { def failAnalysis(msg: String): Unit = { throw new AnalysisException(msg) } @@ -142,7 +208,7 @@ private[sql] case class PreWriteCheck(conf: SQLConf, catalog: SessionCatalog) // OK } - PartitioningUtils.validatePartitionColumnDataTypes( + PartitioningUtils.validatePartitionColumn( r.schema, part.keySet.toSeq, conf.caseSensitiveAnalysis) // Get all input data source relations of the query. @@ -160,13 +226,6 @@ private[sql] case class PreWriteCheck(conf: SQLConf, catalog: SessionCatalog) // The relation in l is not an InsertableRelation. failAnalysis(s"$l does not allow insertion.") - case logical.InsertIntoTable(t, _, _, _, _) => - if (!t.isInstanceOf[LeafNode] || t == OneRowRelation || t.isInstanceOf[LocalRelation]) { - failAnalysis(s"Inserting into an RDD-based table is not allowed.") - } else { - // OK - } - case c: CreateTableUsingAsSelect => // When the SaveMode is Overwrite, we need to check if the table is an input table of // the query. If so, we will throw an AnalysisException to let users know it is not allowed. @@ -177,7 +236,7 @@ private[sql] case class PreWriteCheck(conf: SQLConf, catalog: SessionCatalog) // (the relation is a BaseRelation). case l @ LogicalRelation(dest: BaseRelation, _, _) => // Get all input data source relations of the query. 
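The rewritten `castAndRenameChildOutput` keeps the child's attribute untouched when both type and name already line up, and otherwise wraps it in an aliased `Cast`. A hedged, single-column sketch (the real rule uses a nullability-tolerant `sameType` comparison; plain type equality is used here to keep the example self-contained):

```scala
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Cast, NamedExpression}
import org.apache.spark.sql.types._

def fitToExpected(expected: Attribute, actual: Attribute): NamedExpression =
  if (expected.dataType == actual.dataType && expected.name == actual.name) actual
  else Alias(Cast(actual, expected.dataType), expected.name)()

val expected = AttributeReference("price", DecimalType(10, 2))()
val actual   = AttributeReference("price", IntegerType)()

// The INT column is cast to DECIMAL(10,2) (and keeps the name "price") before the insert.
fitToExpected(expected, actual)
```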
- val srcRelations = c.child.collect { + val srcRelations = c.query.collect { case LogicalRelation(src: BaseRelation, _, _) => src } if (srcRelations.contains(dest)) { @@ -193,13 +252,13 @@ private[sql] case class PreWriteCheck(conf: SQLConf, catalog: SessionCatalog) // OK } - PartitioningUtils.validatePartitionColumnDataTypes( - c.child.schema, c.partitionColumns, conf.caseSensitiveAnalysis) + PartitioningUtils.validatePartitionColumn( + c.query.schema, c.partitionColumns, conf.caseSensitiveAnalysis) for { spec <- c.bucketSpec sortColumnName <- spec.sortColumnNames - sortColumn <- c.child.schema.find(_.name == sortColumnName) + sortColumn <- c.query.schema.find(_.name == sortColumnName) } { if (!RowOrdering.isOrderable(sortColumn.dataType)) { failAnalysis(s"Cannot use ${sortColumn.dataType.simpleString} for sorting column.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala similarity index 80% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala index 348edfcf7a851..6aa078af3ea19 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala @@ -23,10 +23,12 @@ import org.apache.hadoop.io.{NullWritable, Text} import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat +import org.apache.spark.TaskContext import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter} +import org.apache.spark.sql.execution.command.CreateDataSourceTableUtils import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{StringType, StructType} @@ -35,7 +37,7 @@ import org.apache.spark.util.SerializableConfiguration /** * A data source for reading text files. 
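The `srcRelations.contains(dest)` check above is what rejects overwriting a table from a query that also reads it. A usage sketch (the session and table name are placeholders):

```scala
spark.range(10).write.saveAsTable("t")

// Rejected at analysis time: the overwrite destination `t` is also one of the query's sources.
spark.table("t").write.mode("overwrite").saveAsTable("t")
```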
*/ -class DefaultSource extends FileFormat with DataSourceRegister { +class TextFileFormat extends TextBasedFileFormat with DataSourceRegister { override def shortName(): String = "text" @@ -91,20 +93,32 @@ class DefaultSource extends FileFormat with DataSourceRegister { filters: Seq[Filter], options: Map[String, String], hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = { + assert( + requiredSchema.length <= 1, + "Text data source only produces a single data column named \"value\".") + val broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) (file: PartitionedFile) => { - val unsafeRow = new UnsafeRow(1) - val bufferHolder = new BufferHolder(unsafeRow) - val unsafeRowWriter = new UnsafeRowWriter(bufferHolder, 1) - - new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value).map { line => - // Writes to an UnsafeRow directly - bufferHolder.reset() - unsafeRowWriter.write(0, line.getBytes, 0, line.getLength) - unsafeRow.setTotalSize(bufferHolder.totalSize()) - unsafeRow + val reader = new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => reader.close())) + + if (requiredSchema.isEmpty) { + val emptyUnsafeRow = new UnsafeRow(0) + reader.map(_ => emptyUnsafeRow) + } else { + val unsafeRow = new UnsafeRow(1) + val bufferHolder = new BufferHolder(unsafeRow) + val unsafeRowWriter = new UnsafeRowWriter(bufferHolder, 1) + + reader.map { line => + // Writes to an UnsafeRow directly + bufferHolder.reset() + unsafeRowWriter.write(0, line.getBytes, 0, line.getLength) + unsafeRow.setTotalSize(bufferHolder.totalSize()) + unsafeRow + } } } } @@ -119,7 +133,7 @@ class TextOutputWriter(path: String, dataSchema: StructType, context: TaskAttemp new TextOutputFormat[NullWritable, Text]() { override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { val configuration = context.getConfiguration - val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID") + val uniqueWriteJobId = configuration.get(CreateDataSourceTableUtils.DATASOURCE_WRITEJOBUUID) val taskAttemptId = context.getTaskAttemptID val split = taskAttemptId.getTaskID.getId new Path(path, f"part-r-$split%05d-$uniqueWriteJobId.txt$extension") @@ -139,4 +153,3 @@ class TextOutputWriter(path: String, dataSchema: StructType, context: TaskAttemp recordWriter.close(context) } } - diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index c77c889a1b7b8..dd9d83767e221 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -17,9 +17,10 @@ package org.apache.spark.sql.execution -import scala.collection.mutable.HashSet +import java.util.Collections + +import scala.collection.JavaConverters._ -import org.apache.spark.{Accumulator, AccumulatorParam} import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql._ @@ -28,6 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.expressions.codegen.{CodeFormatter, CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.trees.TreeNodeRef import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.{AccumulatorV2, LongAccumulator} /** * Contains methods for debugging query execution. 
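Usage-wise, the text source above always exposes a single string column named `value`, and a query that needs no columns at all (such as a plain count) is now served with zero-field rows instead of materialized strings. A sketch against a placeholder path:

```scala
import spark.implicits._

val logs = spark.read.text("/tmp/logs")             // schema: value: string
logs.printSchema()

logs.count()                                        // requiredSchema is empty -> 0-field UnsafeRows
logs.filter($"value".contains("ERROR")).count()     // requiredSchema is the single "value" column
```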
@@ -67,15 +69,6 @@ package object debug { output } - /** - * Augments [[SparkSession]] with debug methods. - */ - implicit class DebugSQLContext(sparkSession: SparkSession) { - def debug(): Unit = { - sparkSession.conf.set(SQLConf.DATAFRAME_EAGER_ANALYSIS.key, false) - } - } - /** * Augments [[Dataset]]s with debug methods. */ @@ -104,31 +97,34 @@ package object debug { } } - private[sql] case class DebugExec(child: SparkPlan) extends UnaryExecNode with CodegenSupport { + case class DebugExec(child: SparkPlan) extends UnaryExecNode with CodegenSupport { def output: Seq[Attribute] = child.output - implicit object SetAccumulatorParam extends AccumulatorParam[HashSet[String]] { - def zero(initialValue: HashSet[String]): HashSet[String] = { - initialValue.clear() - initialValue + class SetAccumulator[T] extends AccumulatorV2[T, java.util.Set[T]] { + private val _set = Collections.synchronizedSet(new java.util.HashSet[T]()) + override def isZero: Boolean = _set.isEmpty + override def copy(): AccumulatorV2[T, java.util.Set[T]] = { + val newAcc = new SetAccumulator[T]() + newAcc._set.addAll(_set) + newAcc } - - def addInPlace(v1: HashSet[String], v2: HashSet[String]): HashSet[String] = { - v1 ++= v2 - v1 + override def reset(): Unit = _set.clear() + override def add(v: T): Unit = _set.add(v) + override def merge(other: AccumulatorV2[T, java.util.Set[T]]): Unit = { + _set.addAll(other.value) } + override def value: java.util.Set[T] = _set } /** * A collection of metrics for each column of output. - * - * @param elementTypes the actual runtime types for the output. Useful when there are bugs - * causing the wrong data to be projected. */ - case class ColumnMetrics( - elementTypes: Accumulator[HashSet[String]] = sparkContext.accumulator(HashSet.empty)) + case class ColumnMetrics() { + val elementTypes = new SetAccumulator[String] + sparkContext.register(elementTypes) + } - val tupleCount: Accumulator[Int] = sparkContext.accumulator[Int](0) + val tupleCount: LongAccumulator = sparkContext.longAccumulator val numColumns: Int = child.output.size val columnStats: Array[ColumnMetrics] = Array.fill(child.output.size)(new ColumnMetrics()) @@ -137,7 +133,9 @@ package object debug { debugPrint(s"== ${child.simpleString} ==") debugPrint(s"Tuples output: ${tupleCount.value}") child.output.zip(columnStats).foreach { case (attr, metric) => - val actualDataTypes = metric.elementTypes.value.mkString("{", ",", "}") + // This is called on driver. All accumulator updates have a fixed value. So it's safe to use + // `asScala` which accesses the internal values using `java.util.Iterator`. 
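The `SetAccumulator` above is a straight `AccumulatorV2` implementation, so it follows the usual register-on-driver / add-in-tasks / read-on-driver lifecycle. A usage sketch, assuming a standalone copy of the class (in the diff it is nested inside `DebugExec`):

```scala
val seenTypes = new SetAccumulator[String]()
spark.sparkContext.register(seenTypes, "seen element types")

spark.sparkContext.parallelize(Seq[Any](1, "a", 2.5)).foreach { v =>
  seenTypes.add(v.getClass.getName)   // task-side updates, combined by merge() on the driver
}
seenTypes.value                       // driver-side java.util.Set with the three class names
```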
+ val actualDataTypes = metric.elementTypes.value.asScala.mkString("{", ",", "}") debugPrint(s" ${attr.name} ${attr.dataType}: $actualDataTypes") } } @@ -149,12 +147,12 @@ package object debug { def next(): InternalRow = { val currentRow = iter.next() - tupleCount += 1 + tupleCount.add(1) var i = 0 while (i < numColumns) { val value = currentRow.get(i, output(i).dataType) if (value != null) { - columnStats(i).elementTypes += HashSet(value.getClass.getName) + columnStats(i).elementTypes.add(value.getClass.getName) } i += 1 } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala index b6ecd3cb065ae..4b45ce9c73ee1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala @@ -20,13 +20,14 @@ package org.apache.spark.sql.execution.exchange import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration._ -import org.apache.spark.broadcast +import org.apache.spark.{broadcast, SparkException} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, BroadcastPartitioning, Partitioning} import org.apache.spark.sql.execution.{SparkPlan, SQLExecution} import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.execution.ui.SparkListenerDriverAccumUpdates import org.apache.spark.util.ThreadUtils /** @@ -37,7 +38,7 @@ case class BroadcastExchangeExec( mode: BroadcastMode, child: SparkPlan) extends Exchange { - override private[sql] lazy val metrics = Map( + override lazy val metrics = Map( "dataSize" -> SQLMetrics.createMetric(sparkContext, "data size (bytes)"), "collectTime" -> SQLMetrics.createMetric(sparkContext, "time to collect (ms)"), "buildTime" -> SQLMetrics.createMetric(sparkContext, "time to build (ms)"), @@ -72,9 +73,18 @@ case class BroadcastExchangeExec( val beforeCollect = System.nanoTime() // Note that we use .executeCollect() because we don't want to convert data to Scala types val input: Array[InternalRow] = child.executeCollect() + if (input.length >= 512000000) { + throw new SparkException( + s"Cannot broadcast the table with more than 512 millions rows: ${input.length} rows") + } val beforeBuild = System.nanoTime() longMetric("collectTime") += (beforeBuild - beforeCollect) / 1000000 - longMetric("dataSize") += input.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum + val dataSize = input.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum + longMetric("dataSize") += dataSize + if (dataSize >= (8L << 30)) { + throw new SparkException( + s"Cannot broadcast the table that is larger than 8GB: ${dataSize >> 30} GB") + } // Construct and broadcast the relation. val relation = mode.transform(input) @@ -83,6 +93,14 @@ case class BroadcastExchangeExec( val broadcasted = sparkContext.broadcast(relation) longMetric("broadcastTime") += (System.nanoTime() - beforeBroadcast) / 1000000 + + // There are some cases we don't care about the metrics and call `SparkPlan.doExecute` + // directly without setting an execution id. We should be tolerant to it. 
+ if (executionId != null) { + sparkContext.listenerBus.post(SparkListenerDriverAccumUpdates( + executionId.toLong, metrics.values.map(m => m.id -> m.value).toSeq)) + } + broadcasted } }(BroadcastExchangeExec.executionContext) @@ -99,7 +117,8 @@ case class BroadcastExchangeExec( } override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = { - ThreadUtils.awaitResult(relationFuture, timeout).asInstanceOf[broadcast.Broadcast[T]] + ThreadUtils.awaitResultInForkJoinSafely(relationFuture, timeout) + .asInstanceOf[broadcast.Broadcast[T]] } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 446571aa8409f..f17049949aa47 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -236,7 +236,16 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { children = children.zip(requiredChildOrderings).map { case (child, requiredOrdering) => if (requiredOrdering.nonEmpty) { // If child.outputOrdering is [a, b] and requiredOrdering is [a], we do not need to sort. - if (requiredOrdering != child.outputOrdering.take(requiredOrdering.length)) { + val orderingMatched = if (requiredOrdering.length > child.outputOrdering.length) { + false + } else { + requiredOrdering.zip(child.outputOrdering).forall { + case (requiredOrder, childOutputOrder) => + requiredOrder.semanticEquals(childOutputOrder) + } + } + + if (!orderingMatched) { SortExec(requiredOrdering, global = false, child = child) } else { child diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala index 9da9df617405d..9a9597d3733e0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala @@ -60,9 +60,6 @@ case class ReusedExchangeExec(override val output: Seq[Attribute], child: Exchan override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = { child.executeBroadcast() } - - // Do not repeat the same tree in explain. - override def treeChildren: Seq[SparkPlan] = Nil } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala index fb60d68f986d6..57da85fa84f99 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala @@ -47,10 +47,10 @@ import org.apache.spark.sql.execution.{ShuffledRowRDD, SparkPlan} * partitions. * * The workflow of this coordinator is described as follows: - * - Before the execution of a [[SparkPlan]], for an [[ShuffleExchange]] operator, + * - Before the execution of a [[SparkPlan]], for a [[ShuffleExchange]] operator, * if an [[ExchangeCoordinator]] is assigned to it, it registers itself to this coordinator. * This happens in the `doPrepare` method. 
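The `orderingMatched` computation above replaces a plain equality test with a semantic prefix check, so a child ordering of `[a, b]` satisfies a required ordering of `[a]` without an extra sort. The shape of the check, as a generic function:

```scala
// Required ordering is satisfied when it is a (semantic) prefix of the child's output ordering.
def orderingSatisfied[T](required: Seq[T], childOrdering: Seq[T], same: (T, T) => Boolean): Boolean =
  required.length <= childOrdering.length &&
    required.zip(childOrdering).forall { case (r, c) => same(r, c) }

assert(orderingSatisfied(Seq("a"), Seq("a", "b"), (x: String, y: String) => x == y))    // no SortExec needed
assert(!orderingSatisfied(Seq("a", "b"), Seq("a"), (x: String, y: String) => x == y))   // SortExec inserted
```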
- * - Once we start to execute a physical plan, an [[ShuffleExchange]] registered to this + * - Once we start to execute a physical plan, a [[ShuffleExchange]] registered to this * coordinator will call `postShuffleRDD` to get its corresponding post-shuffle * [[ShuffledRowRDD]]. * If this coordinator has made the decision on how to shuffle data, this [[ShuffleExchange]] @@ -61,7 +61,7 @@ import org.apache.spark.sql.execution.{ShuffledRowRDD, SparkPlan} * post-shuffle partitions and pack multiple pre-shuffle partitions with continuous indices * to a single post-shuffle partition whenever necessary. * - Finally, this coordinator will create post-shuffle [[ShuffledRowRDD]]s for all registered - * [[ShuffleExchange]]s. So, when an [[ShuffleExchange]] calls `postShuffleRDD`, this coordinator + * [[ShuffleExchange]]s. So, when a [[ShuffleExchange]] calls `postShuffleRDD`, this coordinator * can lookup the corresponding [[RDD]]. * * The strategy used to determine the number of post-shuffle partitions is described as follows. @@ -79,7 +79,7 @@ import org.apache.spark.sql.execution.{ShuffledRowRDD, SparkPlan} * - post-shuffle partition 1: pre-shuffle partition 2 * - post-shuffle partition 2: pre-shuffle partition 3 and 4 */ -private[sql] class ExchangeCoordinator( +class ExchangeCoordinator( numExchanges: Int, advisoryTargetPostShuffleInputSize: Long, minNumPostShufflePartitions: Option[Int] = None) @@ -98,8 +98,8 @@ private[sql] class ExchangeCoordinator( @volatile private[this] var estimated: Boolean = false /** - * Registers an [[ShuffleExchange]] operator to this coordinator. This method is only allowed to - * be called in the `doPrepare` method of an [[ShuffleExchange]] operator. + * Registers a [[ShuffleExchange]] operator to this coordinator. This method is only allowed to + * be called in the `doPrepare` method of a [[ShuffleExchange]] operator. */ @GuardedBy("this") def registerExchange(exchange: ShuffleExchange): Unit = synchronized { @@ -112,7 +112,7 @@ private[sql] class ExchangeCoordinator( * Estimates partition start indices for post-shuffle partitions based on * mapOutputStatistics provided by all pre-shuffle stages. */ - private[sql] def estimatePartitionStartIndices( + def estimatePartitionStartIndices( mapOutputStatistics: Array[MapOutputStatistics]): Array[Int] = { // If we have mapOutputStatistics.length < numExchange, it is because we do not submit // a stage when the number of partitions of this dependency is 0. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala index e18b59f49b984..7a4a251370706 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala @@ -40,7 +40,7 @@ case class ShuffleExchange( child: SparkPlan, @transient coordinator: Option[ExchangeCoordinator]) extends Exchange { - override private[sql] lazy val metrics = Map( + override lazy val metrics = Map( "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size")) override def nodeName: String = { @@ -81,7 +81,8 @@ case class ShuffleExchange( * the partitioning scheme defined in `newPartitioning`. Those partitions of * the returned ShuffleDependency will be the input of shuffle. 
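The coordinator's packing of "continuous pre-shuffle partitions" can be pictured with a small greedy sketch: start a new post-shuffle partition whenever adding the next pre-shuffle partition would push the running size past the advisory target. The exact boundary rule and the `minNumPostShufflePartitions` handling in `estimatePartitionStartIndices` may differ; the sizes below are purely illustrative:

```scala
def packPartitions(sizes: Seq[Long], targetSize: Long): Seq[Int] = {
  val startIndices = scala.collection.mutable.ArrayBuffer(0)
  var current = 0L
  sizes.zipWithIndex.foreach { case (size, i) =>
    if (i != 0 && current + size > targetSize) {
      startIndices += i   // this pre-shuffle partition starts a new post-shuffle partition
      current = 0L
    }
    current += size
  }
  startIndices.toSeq
}

val mb = 1L << 20
// Pre-shuffle sizes 100, 20, 120, 30, 50 MB with a 128 MB target -> groups {0,1}, {2}, {3,4}.
assert(packPartitions(Seq(100, 20, 120, 30, 50).map(_ * mb), 128 * mb) == Seq(0, 2, 3))
```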
*/ - private[sql] def prepareShuffleDependency(): ShuffleDependency[Int, InternalRow, InternalRow] = { + private[exchange] def prepareShuffleDependency() + : ShuffleDependency[Int, InternalRow, InternalRow] = { ShuffleExchange.prepareShuffleDependency( child.execute(), child.output, newPartitioning, serializer) } @@ -92,7 +93,7 @@ case class ShuffleExchange( * partition start indices array. If this optional array is defined, the returned * [[ShuffledRowRDD]] will fetch pre-shuffle partitions based on indices of this array. */ - private[sql] def preparePostShuffleRDD( + private[exchange] def preparePostShuffleRDD( shuffleDependency: ShuffleDependency[Int, InternalRow, InternalRow], specifiedPartitionStartIndices: Option[Array[Int]] = None): ShuffledRowRDD = { // If an array of partition start indices is provided, we need to use this array @@ -129,7 +130,7 @@ case class ShuffleExchange( object ShuffleExchange { def apply(newPartitioning: Partitioning, child: SparkPlan): ShuffleExchange = { - ShuffleExchange(newPartitioning, child, coordinator = None: Option[ExchangeCoordinator]) + ShuffleExchange(newPartitioning, child, coordinator = Option.empty[ExchangeCoordinator]) } /** @@ -194,7 +195,7 @@ object ShuffleExchange { * the partitioning scheme defined in `newPartitioning`. Those partitions of * the returned ShuffleDependency will be the input of shuffle. */ - private[sql] def prepareShuffleDependency( + def prepareShuffleDependency( rdd: RDD[InternalRow], outputAttributes: Seq[Attribute], newPartitioning: Partitioning, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index 587c603192cce..0f24baacd18d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -45,11 +45,9 @@ case class BroadcastHashJoinExec( right: SparkPlan) extends BinaryExecNode with HashJoin with CodegenSupport { - override private[sql] lazy val metrics = Map( + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) - override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning - override def requiredChildDistribution: Seq[Distribution] = { val mode = HashedRelationBroadcastMode(buildKeys) buildSide match { @@ -85,6 +83,7 @@ case class BroadcastHashJoinExec( case LeftOuter | RightOuter => codegenOuter(ctx, input) case LeftSemi => codegenSemi(ctx, input) case LeftAnti => codegenAnti(ctx, input) + case j: ExistenceJoin => codegenExistence(ctx, input) case x => throw new IllegalArgumentException( s"BroadcastHashJoin should not take $x as the JoinType") @@ -407,4 +406,67 @@ case class BroadcastHashJoinExec( """.stripMargin } } + + /** + * Generates the code for existence join. 
+ */ + private def codegenExistence(ctx: CodegenContext, input: Seq[ExprCode]): String = { + val (broadcastRelation, relationTerm) = prepareBroadcast(ctx) + val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input) + val numOutput = metricTerm(ctx, "numOutputRows") + val existsVar = ctx.freshName("exists") + + val matched = ctx.freshName("matched") + val buildVars = genBuildSideVars(ctx, matched) + val checkCondition = if (condition.isDefined) { + val expr = condition.get + // evaluate the variables from build side that used by condition + val eval = evaluateRequiredVariables(buildPlan.output, buildVars, expr.references) + // filter the output via condition + ctx.currentVars = input ++ buildVars + val ev = + BindReferences.bindReference(expr, streamedPlan.output ++ buildPlan.output).genCode(ctx) + s""" + |$eval + |${ev.code} + |$existsVar = !${ev.isNull} && ${ev.value}; + """.stripMargin + } else { + s"$existsVar = true;" + } + + val resultVar = input ++ Seq(ExprCode("", "false", existsVar)) + if (broadcastRelation.value.keyIsUnique) { + s""" + |// generate join key for stream side + |${keyEv.code} + |// find matches from HashedRelation + |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value}); + |boolean $existsVar = false; + |if ($matched != null) { + | $checkCondition + |} + |$numOutput.add(1); + |${consume(ctx, resultVar)} + """.stripMargin + } else { + val matches = ctx.freshName("matches") + val iteratorCls = classOf[Iterator[UnsafeRow]].getName + s""" + |// generate join key for stream side + |${keyEv.code} + |// find matches from HashRelation + |$iteratorCls $matches = $anyNull ? null : ($iteratorCls)$relationTerm.get(${keyEv.value}); + |boolean $existsVar = false; + |if ($matches != null) { + | while (!$existsVar && $matches.hasNext()) { + | UnsafeRow $matched = (UnsafeRow) $matches.next(); + | $checkCondition + | } + |} + |$numOutput.add(1); + |${consume(ctx, resultVar)} + """.stripMargin + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala index a659bf26e32df..6a9965f1a24cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala @@ -19,12 +19,14 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.collection.{BitSet, CompactBuffer} case class BroadcastNestedLoopJoinExec( @@ -32,9 +34,10 @@ case class BroadcastNestedLoopJoinExec( right: SparkPlan, buildSide: BuildSide, joinType: JoinType, - condition: Option[Expression]) extends BinaryExecNode { + condition: Option[Expression], + withinBroadcastThreshold: Boolean = true) extends BinaryExecNode { - override private[sql] lazy val metrics = Map( + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) /** BuildRight means the 
right relation <=> the broadcast relation. */ @@ -50,19 +53,16 @@ case class BroadcastNestedLoopJoinExec( UnspecifiedDistribution :: BroadcastDistribution(IdentityBroadcastMode) :: Nil } - private[this] def genResultProjection: InternalRow => InternalRow = { - if (joinType == LeftSemi) { + private[this] def genResultProjection: InternalRow => InternalRow = joinType match { + case LeftExistence(j) => UnsafeProjection.create(output, output) - } else { + case other => // Always put the stream side on left to simplify implementation // both of left and right side could be null UnsafeProjection.create( output, (streamed.output ++ broadcast.output).map(_.withNullability(true))) - } } - override def outputPartitioning: Partitioning = streamed.outputPartitioning - override def output: Seq[Attribute] = { joinType match { case Inner => @@ -73,6 +73,8 @@ case class BroadcastNestedLoopJoinExec( left.output.map(_.withNullability(true)) ++ right.output case FullOuter => left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) + case j: ExistenceJoin => + left.output :+ j.exists case LeftExistence(_) => left.output case x => @@ -197,6 +199,28 @@ case class BroadcastNestedLoopJoinExec( } } + private def existenceJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = { + assert(buildSide == BuildRight) + streamed.execute().mapPartitionsInternal { streamedIter => + val buildRows = relation.value + val joinedRow = new JoinedRow + + if (condition.isDefined) { + val resultRow = new GenericMutableRow(Array[Any](null)) + streamedIter.map { row => + val result = buildRows.exists(r => boundCondition(joinedRow(row, r))) + resultRow.setBoolean(0, result) + joinedRow(row, resultRow) + } + } else { + val resultRow = new GenericMutableRow(Array[Any](buildRows.nonEmpty)) + streamedIter.map { row => + joinedRow(row, resultRow) + } + } + } + } + /** * The implementation for these joins: * @@ -204,7 +228,8 @@ case class BroadcastNestedLoopJoinExec( * RightOuter with BuildRight * FullOuter * LeftSemi with BuildLeft - * Anti with BuildLeft + * LeftAnti with BuildLeft + * ExistenceJoin with BuildLeft */ private def defaultJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = { /** All rows that either match both-way, or rows from streamed joined with nulls. 
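The existenceJoin method added above keeps every streamed row and appends a boolean recording whether any build row matched, which is how EXISTS/IN predicates that cannot be planned as semi joins get evaluated. A plain-Scala sketch of that contract, using hypothetical row types rather than the InternalRow/JoinedRow machinery:

```scala
final case class Customer(id: Int, name: String)
final case class Order(customerId: Int, amount: Double)

// Every left row is emitted exactly once, tagged with whether any right row matched.
def existenceJoin(
    left: Seq[Customer],
    right: Seq[Order],
    cond: (Customer, Order) => Boolean): Seq[(Customer, Boolean)] =
  left.map(c => c -> right.exists(o => cond(c, o)))

val customers = Seq(Customer(1, "ann"), Customer(2, "bob"))
val orders = Seq(Order(1, 10.0))
// List((Customer(1,ann),true), (Customer(2,bob),false))
println(existenceJoin(customers, orders, (c, o) => c.id == o.customerId))
```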
*/ @@ -231,27 +256,50 @@ case class BroadcastNestedLoopJoinExec( new BitSet(relation.value.length) )(_ | _) - if (joinType == LeftSemi) { - assert(buildSide == BuildLeft) - val buf: CompactBuffer[InternalRow] = new CompactBuffer() - var i = 0 - val rel = relation.value - while (i < rel.length) { - if (matchedBroadcastRows.get(i)) { - buf += rel(i).copy() + joinType match { + case LeftSemi => + assert(buildSide == BuildLeft) + val buf: CompactBuffer[InternalRow] = new CompactBuffer() + var i = 0 + val rel = relation.value + while (i < rel.length) { + if (matchedBroadcastRows.get(i)) { + buf += rel(i).copy() + } + i += 1 } - i += 1 - } - return sparkContext.makeRDD(buf) + return sparkContext.makeRDD(buf) + case j: ExistenceJoin => + val buf: CompactBuffer[InternalRow] = new CompactBuffer() + var i = 0 + val rel = relation.value + while (i < rel.length) { + val result = new GenericInternalRow(Array[Any](matchedBroadcastRows.get(i))) + buf += new JoinedRow(rel(i).copy(), result) + i += 1 + } + return sparkContext.makeRDD(buf) + case LeftAnti => + val notMatched: CompactBuffer[InternalRow] = new CompactBuffer() + var i = 0 + val rel = relation.value + while (i < rel.length) { + if (!matchedBroadcastRows.get(i)) { + notMatched += rel(i).copy() + } + i += 1 + } + return sparkContext.makeRDD(notMatched) + case o => } val notMatchedBroadcastRows: Seq[InternalRow] = { val nulls = new GenericMutableRow(streamed.output.size) val buf: CompactBuffer[InternalRow] = new CompactBuffer() - var i = 0 - val buildRows = relation.value val joinedRow = new JoinedRow joinedRow.withLeft(nulls) + var i = 0 + val buildRows = relation.value while (i < buildRows.length) { if (!matchedBroadcastRows.get(i)) { buf += joinedRow.withRight(buildRows(i)).copy() @@ -261,10 +309,6 @@ case class BroadcastNestedLoopJoinExec( buf } - if (joinType == LeftAnti) { - return sparkContext.makeRDD(notMatchedBroadcastRows) - } - val matchedStreamRows = streamRdd.mapPartitionsInternal { streamedIter => val buildRows = relation.value val joinedRow = new JoinedRow @@ -296,6 +340,15 @@ case class BroadcastNestedLoopJoinExec( ) } + protected override def doPrepare(): Unit = { + if (!withinBroadcastThreshold && !sqlContext.conf.crossJoinEnabled) { + throw new AnalysisException("Both sides of this join are outside the broadcasting " + + "threshold and computing it could be prohibitively expensive. 
To explicitly enable it, " + + s"please set ${SQLConf.CROSS_JOINS_ENABLED.key} = true") + } + super.doPrepare() + } + protected override def doExecute(): RDD[InternalRow] = { val broadcastedRelation = broadcast.executeBroadcast[Array[InternalRow]]() @@ -308,13 +361,16 @@ case class BroadcastNestedLoopJoinExec( leftExistenceJoin(broadcastedRelation, exists = true) case (LeftAnti, BuildRight) => leftExistenceJoin(broadcastedRelation, exists = false) + case (j: ExistenceJoin, BuildRight) => + existenceJoin(broadcastedRelation) case _ => /** * LeftOuter with BuildLeft * RightOuter with BuildRight * FullOuter * LeftSemi with BuildLeft - * Anti with BuildLeft + * LeftAnti with BuildLeft + * ExistenceJoin with BuildLeft */ defaultJoin(broadcastedRelation) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala index 8d7ecc442a9e1..c97fffe88b719 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala @@ -19,11 +19,13 @@ package org.apache.spark.sql.execution.joins import org.apache.spark._ import org.apache.spark.rdd.{CartesianPartition, CartesianRDD, RDD} +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, JoinedRow, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter @@ -32,7 +34,6 @@ import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter * will be much faster than building the right partition for every row in left RDD, it also * materialize the right RDD (in case of the right RDD is nondeterministic). 
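BroadcastNestedLoopJoinExec above and CartesianProductExec below now fail fast in doPrepare unless cross joins are explicitly enabled. Assuming the key behind SQLConf.CROSS_JOINS_ENABLED is spark.sql.crossJoin.enabled and a SparkSession named spark, a usage sketch:

```scala
// Opt in to potentially expensive cross joins for this session.
spark.conf.set("spark.sql.crossJoin.enabled", "true")

val left = spark.range(3).toDF("a")
val right = spark.range(2).toDF("b")

// With no join condition this plans as a cartesian product; without the flag above,
// doPrepare would raise an AnalysisException instead of running the join.
left.join(right).show()
```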
*/ -private[spark] class UnsafeCartesianRDD(left : RDD[UnsafeRow], right : RDD[UnsafeRow], numFieldsOfRight: Int) extends CartesianRDD[UnsafeRow, UnsafeRow](left.sparkContext, left, right) { @@ -47,6 +48,8 @@ class UnsafeCartesianRDD(left : RDD[UnsafeRow], right : RDD[UnsafeRow], numField null, 1024, SparkEnv.get.memoryManager.pageSizeBytes, + SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", + UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD), false) val partition = split.asInstanceOf[CartesianPartition] @@ -74,7 +77,7 @@ class UnsafeCartesianRDD(left : RDD[UnsafeRow], right : RDD[UnsafeRow], numField for (x <- rdd1.iterator(partition.s1, context); y <- createIter()) yield (x, y) CompletionIterator[(UnsafeRow, UnsafeRow), Iterator[(UnsafeRow, UnsafeRow)]]( - resultIter, sorter.cleanupResources) + resultIter, sorter.cleanupResources()) } } @@ -85,9 +88,18 @@ case class CartesianProductExec( condition: Option[Expression]) extends BinaryExecNode { override def output: Seq[Attribute] = left.output ++ right.output - override private[sql] lazy val metrics = Map( + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) + protected override def doPrepare(): Unit = { + if (!sqlContext.conf.crossJoinEnabled) { + throw new AnalysisException("Cartesian joins could be prohibitively expensive and are " + + "disabled by default. To explicitly enable them, please set " + + s"${SQLConf.CROSS_JOINS_ENABLED.key} = true") + } + super.doPrepare() + } + protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 9c173d7bf1011..d11f7e623b5f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.{RowIterator, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.types.{IntegralType, LongType} @@ -43,6 +44,8 @@ trait HashJoin { left.output ++ right.output.map(_.withNullability(true)) case RightOuter => left.output.map(_.withNullability(true)) ++ right.output + case j: ExistenceJoin => + left.output :+ j.exists case LeftExistence(_) => left.output case x => @@ -50,6 +53,8 @@ trait HashJoin { } } + override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning + protected lazy val (buildPlan, streamedPlan) = buildSide match { case BuildLeft => (left, right) case BuildRight => (right, left) @@ -58,45 +63,16 @@ trait HashJoin { protected lazy val (buildKeys, streamedKeys) = { require(leftKeys.map(_.dataType) == rightKeys.map(_.dataType), "Join keys from two sides should have same types") - val lkeys = rewriteKeyExpr(leftKeys).map(BindReferences.bindReference(_, left.output)) - val rkeys = rewriteKeyExpr(rightKeys).map(BindReferences.bindReference(_, right.output)) + val lkeys = HashJoin.rewriteKeyExpr(leftKeys).map(BindReferences.bindReference(_, left.output)) + val rkeys = HashJoin.rewriteKeyExpr(rightKeys) + 
.map(BindReferences.bindReference(_, right.output)) buildSide match { case BuildLeft => (lkeys, rkeys) case BuildRight => (rkeys, lkeys) } } - /** - * Try to rewrite the key as LongType so we can use getLong(), if they key can fit with a long. - * - * If not, returns the original expressions. - */ - private def rewriteKeyExpr(keys: Seq[Expression]): Seq[Expression] = { - var keyExpr: Expression = null - var width = 0 - keys.foreach { e => - e.dataType match { - case dt: IntegralType if dt.defaultSize <= 8 - width => - if (width == 0) { - if (e.dataType != LongType) { - keyExpr = Cast(e, LongType) - } else { - keyExpr = e - } - width = dt.defaultSize - } else { - val bits = dt.defaultSize * 8 - keyExpr = BitwiseOr(ShiftLeft(keyExpr, Literal(bits)), - BitwiseAnd(Cast(e, LongType), Literal((1L << bits) - 1))) - width -= bits - } - // TODO: support BooleanType, DateType and TimestampType - case other => - return keys - } - } - keyExpr :: Nil - } + protected def buildSideKeyGenerator(): Projection = UnsafeProjection.create(buildKeys) @@ -110,15 +86,14 @@ trait HashJoin { (r: InternalRow) => true } - protected def createResultProjection(): (InternalRow) => InternalRow = { - if (joinType == LeftSemi) { + protected def createResultProjection(): (InternalRow) => InternalRow = joinType match { + case LeftExistence(_) => UnsafeProjection.create(output, output) - } else { + case _ => // Always put the stream side on left to simplify implementation // both of left and right side could be null UnsafeProjection.create( output, (streamedPlan.output ++ buildPlan.output).map(_.withNullability(true))) - } } private def innerJoin( @@ -184,6 +159,23 @@ trait HashJoin { } } + private def existenceJoin( + streamIter: Iterator[InternalRow], + hashedRelation: HashedRelation): Iterator[InternalRow] = { + val joinKeys = streamSideKeyGenerator() + val result = new GenericMutableRow(Array[Any](null)) + val joinedRow = new JoinedRow + streamIter.map { current => + val key = joinKeys(current) + lazy val buildIter = hashedRelation.get(key) + val exists = !key.anyNull && buildIter != null && (condition.isEmpty || buildIter.exists { + (row: InternalRow) => boundCondition(joinedRow(current, row)) + }) + result.setBoolean(0, exists) + joinedRow(current, result) + } + } + private def antiJoin( streamIter: Iterator[InternalRow], hashedRelation: HashedRelation): Iterator[InternalRow] = { @@ -212,6 +204,8 @@ trait HashJoin { semiJoin(streamedIter, hashed) case LeftAnti => antiJoin(streamedIter, hashed) + case j: ExistenceJoin => + existenceJoin(streamedIter, hashed) case x => throw new IllegalArgumentException( s"BroadcastHashJoin should not take $x as the JoinType") @@ -224,3 +218,31 @@ trait HashJoin { } } } + +object HashJoin { + /** + * Try to rewrite the key as LongType so we can use getLong(), if they key can fit with a long. + * + * If not, returns the original expressions. 
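rewriteKeyExpr, moved into the HashJoin companion object below, packs several narrow integral join keys into one long so the hashed relation can be probed with getLong. A self-contained sketch of the same bit arithmetic on plain values (the idea only, not the Expression-based implementation):

```scala
// Pack an Int (4 bytes) and a Short (2 bytes) into a single Long, widest key first.
def pack(intKey: Int, shortKey: Short): Long = {
  val bits = 2 * 8                                        // width of the second key in bits
  (intKey.toLong << bits) | (shortKey.toLong & ((1L << bits) - 1))
}

// Reverse the shift; the mask keeps only the low 16 bits for the second key.
def unpack(packed: Long): (Int, Short) =
  ((packed >> 16).toInt, (packed & 0xFFFFL).toShort)

println(unpack(pack(42, -7)))   // (42,-7)
```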
+ */ + private[joins] def rewriteKeyExpr(keys: Seq[Expression]): Seq[Expression] = { + assert(keys.nonEmpty) + // TODO: support BooleanType, DateType and TimestampType + if (keys.exists(!_.dataType.isInstanceOf[IntegralType]) + || keys.map(_.dataType.defaultSize).sum > 8) { + return keys + } + + var keyExpr: Expression = if (keys.head.dataType != LongType) { + Cast(keys.head, LongType) + } else { + keys.head + } + keys.tail.foreach { e => + val bits = e.dataType.defaultSize * 8 + keyExpr = BitwiseOr(ShiftLeft(keyExpr, Literal(bits)), + BitwiseAnd(Cast(e, LongType), Literal((1L << bits) - 1))) + } + keyExpr :: Nil + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index b280c76c70a61..8821c0dea9ee5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -17,7 +17,10 @@ package org.apache.spark.sql.execution.joins -import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} +import java.io._ + +import com.esotericsoftware.kryo.{Kryo, KryoSerializable} +import com.esotericsoftware.kryo.io.{Input, Output} import org.apache.spark.{SparkConf, SparkEnv, SparkException} import org.apache.spark.memory.{MemoryConsumer, MemoryMode, StaticMemoryManager, TaskMemoryManager} @@ -116,7 +119,7 @@ private[execution] object HashedRelation { private[joins] class UnsafeHashedRelation( private var numFields: Int, private var binaryMap: BytesToBytesMap) - extends HashedRelation with Externalizable { + extends HashedRelation with Externalizable with KryoSerializable { private[joins] def this() = this(0, null) // Needed for serialization @@ -171,10 +174,21 @@ private[joins] class UnsafeHashedRelation( } override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { - out.writeInt(numFields) + write(out.writeInt, out.writeLong, out.write) + } + + override def write(kryo: Kryo, out: Output): Unit = Utils.tryOrIOException { + write(out.writeInt, out.writeLong, out.write) + } + + private def write( + writeInt: (Int) => Unit, + writeLong: (Long) => Unit, + writeBuffer: (Array[Byte], Int, Int) => Unit) : Unit = { + writeInt(numFields) // TODO: move these into BytesToBytesMap - out.writeLong(binaryMap.numKeys()) - out.writeLong(binaryMap.numValues()) + writeLong(binaryMap.numKeys()) + writeLong(binaryMap.numValues()) var buffer = new Array[Byte](64) def write(base: Object, offset: Long, length: Int): Unit = { @@ -182,25 +196,32 @@ private[joins] class UnsafeHashedRelation( buffer = new Array[Byte](length) } Platform.copyMemory(base, offset, buffer, Platform.BYTE_ARRAY_OFFSET, length) - out.write(buffer, 0, length) + writeBuffer(buffer, 0, length) } val iter = binaryMap.iterator() while (iter.hasNext) { val loc = iter.next() // [key size] [values size] [key bytes] [value bytes] - out.writeInt(loc.getKeyLength) - out.writeInt(loc.getValueLength) + writeInt(loc.getKeyLength) + writeInt(loc.getValueLength) write(loc.getKeyBase, loc.getKeyOffset, loc.getKeyLength) write(loc.getValueBase, loc.getValueOffset, loc.getValueLength) } } override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { - numFields = in.readInt() + read(in.readInt, in.readLong, in.readFully) + } + + private def read( + readInt: () => Int, + readLong: () => Long, + readBuffer: (Array[Byte], Int, Int) => Unit): Unit = { + numFields = readInt() resultRow = 
new UnsafeRow(numFields) - val nKeys = in.readLong() - val nValues = in.readLong() + val nKeys = readLong() + val nValues = readLong() // This is used in Broadcast, shared by multiple tasks, so we use on-heap memory // TODO(josh): This needs to be revisited before we merge this patch; making this change now // so that tests compile: @@ -227,16 +248,16 @@ private[joins] class UnsafeHashedRelation( var keyBuffer = new Array[Byte](1024) var valuesBuffer = new Array[Byte](1024) while (i < nValues) { - val keySize = in.readInt() - val valuesSize = in.readInt() + val keySize = readInt() + val valuesSize = readInt() if (keySize > keyBuffer.length) { keyBuffer = new Array[Byte](keySize) } - in.readFully(keyBuffer, 0, keySize) + readBuffer(keyBuffer, 0, keySize) if (valuesSize > valuesBuffer.length) { valuesBuffer = new Array[Byte](valuesSize) } - in.readFully(valuesBuffer, 0, valuesSize) + readBuffer(valuesBuffer, 0, valuesSize) val loc = binaryMap.lookup(keyBuffer, Platform.BYTE_ARRAY_OFFSET, keySize) val putSuceeded = loc.append(keyBuffer, Platform.BYTE_ARRAY_OFFSET, keySize, @@ -248,6 +269,10 @@ private[joins] class UnsafeHashedRelation( i += 1 } } + + override def read(kryo: Kryo, in: Input): Unit = Utils.tryOrIOException { + read(in.readInt, in.readLong, in.readBytes) + } } private[joins] object UnsafeHashedRelation { @@ -324,8 +349,8 @@ private[joins] object UnsafeHashedRelation { * * see http://java-performance.info/implementing-world-fastest-java-int-to-int-hash-map/ */ -private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, capacity: Int) - extends MemoryConsumer(mm) with Externalizable { +private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, capacity: Int) + extends MemoryConsumer(mm) with Externalizable with KryoSerializable { // Whether the keys are stored in dense mode or not. 
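UnsafeHashedRelation and LongToUnsafeRowMap now support both Java serialization (Externalizable) and Kryo (KryoSerializable) by routing both paths through private write/read helpers that take the primitive writers and readers as function arguments. A small self-contained class using the same pattern; the field layout is illustrative, not the hashed-relation format:

```scala
import java.io.{Externalizable, ObjectInput, ObjectOutput}

import com.esotericsoftware.kryo.{Kryo, KryoSerializable}
import com.esotericsoftware.kryo.io.{Input, Output}

class PackedBytes(private var count: Long, private var bytes: Array[Byte])
  extends Externalizable with KryoSerializable {

  def this() = this(0L, Array.emptyByteArray)  // no-arg constructor needed for deserialization

  // One serialization routine, parameterized over the underlying writers.
  private def write(
      writeLong: Long => Unit,
      writeBuffer: (Array[Byte], Int, Int) => Unit): Unit = {
    writeLong(count)
    writeLong(bytes.length.toLong)
    writeBuffer(bytes, 0, bytes.length)
  }

  private def read(
      readLong: () => Long,
      readBuffer: (Array[Byte], Int, Int) => Unit): Unit = {
    count = readLong()
    bytes = new Array[Byte](readLong().toInt)
    readBuffer(bytes, 0, bytes.length)
  }

  // Java serialization path.
  override def writeExternal(out: ObjectOutput): Unit = write(out.writeLong, out.write)
  override def readExternal(in: ObjectInput): Unit = read(in.readLong, in.readFully)

  // Kryo path reuses exactly the same logic.
  override def write(kryo: Kryo, out: Output): Unit = write(out.writeLong, out.write)
  override def read(kryo: Kryo, in: Input): Unit = read(in.readLong, in.readBytes)
}
```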
private var isDense = false @@ -373,9 +398,9 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap 0) } - private def acquireMemory(size: Long): Unit = { + private def ensureAcquireMemory(size: Long): Unit = { // do not support spilling - val got = mm.acquireExecutionMemory(size, MemoryMode.ON_HEAP, this) + val got = acquireMemory(size) if (got < size) { freeMemory(got) throw new SparkException(s"Can't acquire $size bytes memory to build hash relation, " + @@ -383,15 +408,12 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap } } - private def freeMemory(size: Long): Unit = { - mm.releaseExecutionMemory(size, MemoryMode.ON_HEAP, this) - } - private def init(): Unit = { if (mm != null) { + require(capacity < 512000000, "Cannot broadcast more than 512 millions rows") var n = 1 while (n < capacity) n *= 2 - acquireMemory(n * 2 * 8 + (1 << 20)) + ensureAcquireMemory(n * 2L * 8 + (1 << 20)) array = new Array[Long](n * 2) mask = n * 2 - 2 page = new Array[Long](1 << 17) // 1M bytes @@ -425,10 +447,20 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap */ private def nextSlot(pos: Int): Int = (pos + 2) & mask + private[this] def toAddress(offset: Long, size: Int): Long = { + ((offset - Platform.LONG_ARRAY_OFFSET) << SIZE_BITS) | size + } + + private[this] def toOffset(address: Long): Long = { + (address >>> SIZE_BITS) + Platform.LONG_ARRAY_OFFSET + } + + private[this] def toSize(address: Long): Int = { + (address & SIZE_MASK).toInt + } + private def getRow(address: Long, resultRow: UnsafeRow): UnsafeRow = { - val offset = address >>> SIZE_BITS - val size = address & SIZE_MASK - resultRow.pointTo(page, offset, size.toInt) + resultRow.pointTo(page, toOffset(address), toSize(address)) resultRow } @@ -437,9 +469,11 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap */ def getValue(key: Long, resultRow: UnsafeRow): UnsafeRow = { if (isDense) { - val idx = (key - minKey).toInt - if (idx >= 0 && key <= maxKey && array(idx) > 0) { - return getRow(array(idx), resultRow) + if (key >= minKey && key <= maxKey) { + val value = array((key - minKey).toInt) + if (value > 0) { + return getRow(value, resultRow) + } } } else { var pos = firstSlot(key) @@ -461,9 +495,9 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap var addr = address override def hasNext: Boolean = addr != 0 override def next(): UnsafeRow = { - val offset = addr >>> SIZE_BITS - val size = addr & SIZE_MASK - resultRow.pointTo(page, offset, size.toInt) + val offset = toOffset(addr) + val size = toSize(addr) + resultRow.pointTo(page, offset, size) addr = Platform.getLong(page, offset + size) resultRow } @@ -475,9 +509,11 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap */ def get(key: Long, resultRow: UnsafeRow): Iterator[UnsafeRow] = { if (isDense) { - val idx = (key - minKey).toInt - if (idx >=0 && key <= maxKey && array(idx) > 0) { - return valueIter(array(idx), resultRow) + if (key >= minKey && key <= maxKey) { + val value = array((key - minKey).toInt) + if (value > 0) { + return valueIter(value, resultRow) + } } } else { var pos = firstSlot(key) @@ -513,12 +549,12 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap if (used >= (1 << 30)) { sys.error("Can not build a HashedRelation that is larger than 8G") } - acquireMemory(used * 8L * 2) + ensureAcquireMemory(used * 8L * 2) val newPage = new Array[Long](used * 2) 
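The new toAddress/toOffset/toSize helpers above make explicit how a value's page offset and byte size are packed into one long. A standalone sketch of that encoding; SIZE_BITS and SIZE_MASK here are illustrative constants, not necessarily the values defined by LongToUnsafeRowMap:

```scala
// Low SIZE_BITS bits hold the record size; the remaining high bits hold the page offset.
val SIZE_BITS = 28
val SIZE_MASK = (1L << SIZE_BITS) - 1

def toAddress(offset: Long, size: Int): Long = (offset << SIZE_BITS) | size.toLong
def toOffset(address: Long): Long = address >>> SIZE_BITS
def toSize(address: Long): Int = (address & SIZE_MASK).toInt

val addr = toAddress(offset = 1024L, size = 56)
println((toOffset(addr), toSize(addr)))   // (1024,56)
```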
Platform.copyMemory(page, Platform.LONG_ARRAY_OFFSET, newPage, Platform.LONG_ARRAY_OFFSET, cursor - Platform.LONG_ARRAY_OFFSET) page = newPage - freeMemory(used * 8) + freeMemory(used * 8L) } // copy the bytes of UnsafeRow @@ -528,7 +564,7 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap Platform.putLong(page, cursor, 0) cursor += 8 numValues += 1 - updateIndex(key, (offset.toLong << SIZE_BITS) | row.getSizeInBytes) + updateIndex(key, toAddress(offset, row.getSizeInBytes)) } /** @@ -536,6 +572,7 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap */ private def updateIndex(key: Long, address: Long): Unit = { var pos = firstSlot(key) + assert(numKeys < array.length / 2) while (array(pos) != key && array(pos + 1) != 0) { pos = nextSlot(pos) } @@ -556,7 +593,7 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap } } else { // there are some values for this key, put the address in the front of them. - val pointer = (address >>> SIZE_BITS) + (address & SIZE_MASK) + val pointer = toOffset(address) + toSize(address) Platform.putLong(page, pointer, array(pos + 1)) array(pos + 1) = address } @@ -566,7 +603,7 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap var old_array = array val n = array.length numKeys = 0 - acquireMemory(n * 2 * 8L) + ensureAcquireMemory(n * 2 * 8L) array = new Array[Long](n * 2) mask = n * 2 - 2 var i = 0 @@ -577,7 +614,7 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap i += 2 } old_array = null // release the reference to old array - freeMemory(n * 8) + freeMemory(n * 8L) } /** @@ -586,9 +623,10 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap def optimize(): Unit = { val range = maxKey - minKey // Convert to dense mode if it does not require more memory or could fit within L1 cache - if (range < array.length || range < 1024) { + // SPARK-16740: Make sure range doesn't overflow if minKey has a large negative value + if (range >= 0 && (range < array.length || range < 1024)) { try { - acquireMemory((range + 1) * 8) + ensureAcquireMemory((range + 1) * 8L) } catch { case e: SparkException => // there is no enough memory to convert @@ -606,7 +644,7 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap val old_length = array.length array = denseArray isDense = true - freeMemory(old_length * 8) + freeMemory(old_length * 8L) } } @@ -615,67 +653,94 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap */ def free(): Unit = { if (page != null) { - freeMemory(page.length * 8) + freeMemory(page.length * 8L) page = null } if (array != null) { - freeMemory(array.length * 8) + freeMemory(array.length * 8L) array = null } } - private def writeLongArray(out: ObjectOutput, arr: Array[Long], len: Int): Unit = { + private def writeLongArray( + writeBuffer: (Array[Byte], Int, Int) => Unit, + arr: Array[Long], + len: Int): Unit = { val buffer = new Array[Byte](4 << 10) var offset: Long = Platform.LONG_ARRAY_OFFSET val end = len * 8L + Platform.LONG_ARRAY_OFFSET while (offset < end) { val size = Math.min(buffer.length, (end - offset).toInt) Platform.copyMemory(arr, offset, buffer, Platform.BYTE_ARRAY_OFFSET, size) - out.write(buffer, 0, size) + writeBuffer(buffer, 0, size) offset += size } } - override def writeExternal(out: ObjectOutput): Unit = { - out.writeBoolean(isDense) - out.writeLong(minKey) - out.writeLong(maxKey) - 
out.writeLong(numKeys) - out.writeLong(numValues) - - out.writeLong(array.length) - writeLongArray(out, array, array.length) + private def write( + writeBoolean: (Boolean) => Unit, + writeLong: (Long) => Unit, + writeBuffer: (Array[Byte], Int, Int) => Unit): Unit = { + writeBoolean(isDense) + writeLong(minKey) + writeLong(maxKey) + writeLong(numKeys) + writeLong(numValues) + + writeLong(array.length) + writeLongArray(writeBuffer, array, array.length) val used = ((cursor - Platform.LONG_ARRAY_OFFSET) / 8).toInt - out.writeLong(used) - writeLongArray(out, page, used) + writeLong(used) + writeLongArray(writeBuffer, page, used) + } + + override def writeExternal(output: ObjectOutput): Unit = { + write(output.writeBoolean, output.writeLong, output.write) } - private def readLongArray(in: ObjectInput, length: Int): Array[Long] = { + override def write(kryo: Kryo, out: Output): Unit = { + write(out.writeBoolean, out.writeLong, out.write) + } + + private def readLongArray( + readBuffer: (Array[Byte], Int, Int) => Unit, + length: Int): Array[Long] = { val array = new Array[Long](length) val buffer = new Array[Byte](4 << 10) var offset: Long = Platform.LONG_ARRAY_OFFSET val end = length * 8L + Platform.LONG_ARRAY_OFFSET while (offset < end) { val size = Math.min(buffer.length, (end - offset).toInt) - in.readFully(buffer, 0, size) + readBuffer(buffer, 0, size) Platform.copyMemory(buffer, Platform.BYTE_ARRAY_OFFSET, array, offset, size) offset += size } array } - override def readExternal(in: ObjectInput): Unit = { - isDense = in.readBoolean() - minKey = in.readLong() - maxKey = in.readLong() - numKeys = in.readLong - numValues = in.readLong() + private def read( + readBoolean: () => Boolean, + readLong: () => Long, + readBuffer: (Array[Byte], Int, Int) => Unit): Unit = { + isDense = readBoolean() + minKey = readLong() + maxKey = readLong() + numKeys = readLong() + numValues = readLong() - val length = in.readLong().toInt + val length = readLong().toInt mask = length - 2 - array = readLongArray(in, length) - val pageLength = in.readLong().toInt - page = readLongArray(in, pageLength) + array = readLongArray(readBuffer, length) + val pageLength = readLong().toInt + page = readLongArray(readBuffer, pageLength) + } + + override def readExternal(in: ObjectInput): Unit = { + read(in.readBoolean, in.readLong, in.readFully) + } + + override def read(kryo: Kryo, in: Input): Unit = { + read(in.readBoolean, in.readLong, in.readBytes) } } @@ -740,7 +805,7 @@ private[joins] object LongHashedRelation { sizeEstimate: Int, taskMemoryManager: TaskMemoryManager): LongHashedRelation = { - val map: LongToUnsafeRowMap = new LongToUnsafeRowMap(taskMemoryManager, sizeEstimate) + val map = new LongToUnsafeRowMap(taskMemoryManager, sizeEstimate) val keyGenerator = UnsafeProjection.create(key) // Create a mapping of key -> rows diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala index 3ef2fec352203..afb6e5e3dd235 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Expression +import 
org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan} @@ -39,22 +39,11 @@ case class ShuffledHashJoinExec( right: SparkPlan) extends BinaryExecNode with HashJoin { - override private[sql] lazy val metrics = Map( + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), "buildDataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size of build side"), "buildTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to build hash map")) - override def outputPartitioning: Partitioning = joinType match { - case Inner => PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning)) - case LeftAnti => left.outputPartitioning - case LeftSemi => left.outputPartitioning - case LeftOuter => left.outputPartitioning - case RightOuter => right.outputPartitioning - case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) - case x => - throw new IllegalArgumentException(s"ShuffledHashJoin should not take $x as the JoinType") - } - override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 775f8ac50818f..5c9c1e6062f0d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.util.collection.BitSet /** - * Performs an sort merge join of two child relations. + * Performs a sort merge join of two child relations. 
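SortMergeJoinExec below walks two inputs that are already sorted on the join keys, advancing whichever side is behind and emitting matches when the keys line up. A compact, self-contained sketch of that merge walk for an inner join over plain sorted sequences (not the SortMergeJoinScanner API):

```scala
/** Inner sort-merge join of two sequences that are pre-sorted by key. */
def sortMergeInnerJoin[K: Ordering, L, R](
    left: Seq[(K, L)], right: Seq[(K, R)]): Seq[(K, L, R)] = {
  val ord = implicitly[Ordering[K]]
  val out = scala.collection.mutable.ArrayBuffer.empty[(K, L, R)]
  var i = 0
  var j = 0
  while (i < left.length && j < right.length) {
    val cmp = ord.compare(left(i)._1, right(j)._1)
    if (cmp < 0) i += 1             // left key is behind, advance left
    else if (cmp > 0) j += 1        // right key is behind, advance right
    else {
      // Emit the cross product of the runs sharing this key, then skip past both runs.
      val key = left(i)._1
      val leftRun = left.drop(i).takeWhile(p => ord.equiv(p._1, key))
      val rightRun = right.drop(j).takeWhile(p => ord.equiv(p._1, key))
      for ((_, l) <- leftRun; (_, r) <- rightRun) out += ((key, l, r))
      i += leftRun.length
      j += rightRun.length
    }
  }
  out.toSeq
}

// Keys 2 and 3 match; keys 1 and 4 have no partner: List((2,b,10), (3,c,20))
println(sortMergeInnerJoin(Seq(1 -> "a", 2 -> "b", 3 -> "c"), Seq(2 -> 10, 3 -> 20, 4 -> 30)))
```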
*/ case class SortMergeJoinExec( leftKeys: Seq[Expression], @@ -40,7 +40,7 @@ case class SortMergeJoinExec( left: SparkPlan, right: SparkPlan) extends BinaryExecNode with CodegenSupport { - override private[sql] lazy val metrics = Map( + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) override def output: Seq[Attribute] = { @@ -53,6 +53,8 @@ case class SortMergeJoinExec( left.output.map(_.withNullability(true)) ++ right.output case FullOuter => (left.output ++ right.output).map(_.withNullability(true)) + case j: ExistenceJoin => + left.output :+ j.exists case LeftExistence(_) => left.output case x => @@ -269,6 +271,44 @@ case class SortMergeJoinExec( override def getRow: InternalRow = currentLeftRow }.toScala + case j: ExistenceJoin => + new RowIterator { + private[this] var currentLeftRow: InternalRow = _ + private[this] val result: MutableRow = new GenericMutableRow(Array[Any](null)) + private[this] val smjScanner = new SortMergeJoinScanner( + createLeftKeyGenerator(), + createRightKeyGenerator(), + keyOrdering, + RowIterator.fromScala(leftIter), + RowIterator.fromScala(rightIter) + ) + private[this] val joinRow = new JoinedRow + + override def advanceNext(): Boolean = { + while (smjScanner.findNextOuterJoinRows()) { + currentLeftRow = smjScanner.getStreamedRow + val currentRightMatches = smjScanner.getBufferedMatches + var found = false + if (currentRightMatches != null) { + var i = 0 + while (!found && i < currentRightMatches.length) { + joinRow(currentLeftRow, currentRightMatches(i)) + if (boundCondition(joinRow)) { + found = true + } + i += 1 + } + } + result.setBoolean(0, found) + numOutputRows += 1 + return true + } + false + } + + override def getRow: InternalRow = resultProj(joinRow(currentLeftRow, result)) + }.toScala + case x => throw new IllegalArgumentException( s"SortMergeJoin should not take $x as the JoinType") @@ -296,13 +336,7 @@ case class SortMergeJoinExec( private def copyKeys(ctx: CodegenContext, vars: Seq[ExprCode]): Seq[ExprCode] = { vars.zipWithIndex.map { case (ev, i) => - val value = ctx.freshName("value") - ctx.addMutableState(ctx.javaType(leftKeys(i).dataType), value, "") - val code = - s""" - |$value = ${ev.value}; - """.stripMargin - ExprCode(code, "false", value) + ctx.addBufferedState(leftKeys(i).dataType, "value", ev.value) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index b71f3335c99e5..86a8770715600 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, LazilyGeneratedOrdering} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.exchange.ShuffleExchange +import org.apache.spark.util.Utils /** @@ -38,9 +39,10 @@ case class CollectLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode override def executeCollect(): Array[InternalRow] = child.executeTake(limit) private val serializer: Serializer = new UnsafeRowSerializer(child.output.size) protected override def doExecute(): RDD[InternalRow] = { + val locallyLimited = child.execute().mapPartitionsInternal(_.take(limit)) val shuffled = new ShuffledRowRDD( ShuffleExchange.prepareShuffleDependency( - child.execute(), child.output, 
SinglePartition, serializer)) + locallyLimited, child.output, SinglePartition, serializer)) shuffled.mapPartitionsInternal(_.take(limit)) } } @@ -113,11 +115,11 @@ case class GlobalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec { case class TakeOrderedAndProjectExec( limit: Int, sortOrder: Seq[SortOrder], - projectList: Option[Seq[NamedExpression]], + projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryExecNode { override def output: Seq[Attribute] = { - projectList.map(_.map(_.toAttribute)).getOrElse(child.output) + projectList.map(_.toAttribute) } override def outputPartitioning: Partitioning = SinglePartition @@ -125,8 +127,8 @@ case class TakeOrderedAndProjectExec( override def executeCollect(): Array[InternalRow] = { val ord = new LazilyGeneratedOrdering(sortOrder, child.output) val data = child.execute().map(_.copy()).takeOrdered(limit)(ord) - if (projectList.isDefined) { - val proj = UnsafeProjection.create(projectList.get, child.output) + if (projectList != child.output) { + val proj = UnsafeProjection.create(projectList, child.output) data.map(r => proj(r).copy()) } else { data @@ -147,8 +149,8 @@ case class TakeOrderedAndProjectExec( localTopK, child.output, SinglePartition, serializer)) shuffled.mapPartitions { iter => val topK = org.apache.spark.util.collection.Utils.takeOrdered(iter.map(_.copy()), limit)(ord) - if (projectList.isDefined) { - val proj = UnsafeProjection.create(projectList.get, child.output) + if (projectList != child.output) { + val proj = UnsafeProjection.create(projectList, child.output) topK.map(r => proj(r)) } else { topK @@ -159,8 +161,8 @@ case class TakeOrderedAndProjectExec( override def outputOrdering: Seq[SortOrder] = sortOrder override def simpleString: String = { - val orderByString = sortOrder.mkString("[", ",", "]") - val outputString = output.mkString("[", ",", "]") + val orderByString = Utils.truncatedString(sortOrder, "[", ",", "]") + val outputString = Utils.truncatedString(output, "[", ",", "]") s"TakeOrderedAndProject(limit=$limit, orderBy=$orderByString, output=$outputString)" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index 7bf9225272612..0cc1edd196bc8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -18,55 +18,59 @@ package org.apache.spark.sql.execution.metric import java.text.NumberFormat +import java.util.Locale -import org.apache.spark.{NewAccumulator, SparkContext} +import org.apache.spark.SparkContext import org.apache.spark.scheduler.AccumulableInfo -import org.apache.spark.util.Utils +import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, Utils} -class SQLMetric(val metricType: String, initValue: Long = 0L) extends NewAccumulator[Long, Long] { +class SQLMetric(val metricType: String, initValue: Long = 0L) extends AccumulatorV2[Long, Long] { // This is a workaround for SPARK-11013. // We may use -1 as initial value of the accumulator, if the accumulator is valid, we will // update it at the end of task and the value will be at least 0. Then we can filter out the -1 // values before calculate max, min, etc. 
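SQLMetric now sits on top of AccumulatorV2, so it has to provide copy/reset/merge/isZero alongside add and value. A minimal long metric in the same style, assuming a SparkContext named sc; this sketches the AccumulatorV2 contract, not the SQLMetric class itself:

```scala
import org.apache.spark.util.AccumulatorV2

/** A bare-bones long metric: sums values added on the executors. */
class SimpleLongMetric(initValue: Long = 0L) extends AccumulatorV2[Long, Long] {
  private var _value = initValue

  override def isZero: Boolean = _value == initValue
  override def copy(): SimpleLongMetric = {
    val acc = new SimpleLongMetric(initValue)
    acc._value = _value
    acc
  }
  override def reset(): Unit = _value = initValue
  override def add(v: Long): Unit = _value += v
  override def merge(other: AccumulatorV2[Long, Long]): Unit = _value += other.value
  override def value: Long = _value
}

val rowCount = new SimpleLongMetric()
sc.register(rowCount, "number of rows seen")
sc.parallelize(1 to 100, numSlices = 4).foreach(_ => rowCount.add(1L))
println(rowCount.value)   // 100
```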
private[this] var _value = initValue + private var _zeroValue = initValue - override def copyAndReset(): SQLMetric = new SQLMetric(metricType, initValue) + override def copy(): SQLMetric = { + val newAcc = new SQLMetric(metricType, _value) + newAcc._zeroValue = initValue + newAcc + } + + override def reset(): Unit = _value = _zeroValue - override def merge(other: NewAccumulator[Long, Long]): Unit = other match { - case o: SQLMetric => _value += o.localValue + override def merge(other: AccumulatorV2[Long, Long]): Unit = other match { + case o: SQLMetric => _value += o.value case _ => throw new UnsupportedOperationException( s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}") } - override def isZero(): Boolean = _value == initValue + override def isZero(): Boolean = _value == _zeroValue override def add(v: Long): Unit = _value += v def +=(v: Long): Unit = _value += v - override def localValue: Long = _value + override def value: Long = _value // Provide special identifier as metadata so we can tell that this is a `SQLMetric` later - private[spark] override def toInfo(update: Option[Any], value: Option[Any]): AccumulableInfo = { - new AccumulableInfo(id, name, update, value, true, true, Some(SQLMetrics.ACCUM_IDENTIFIER)) + override def toInfo(update: Option[Any], value: Option[Any]): AccumulableInfo = { + new AccumulableInfo( + id, name, update, value, true, true, Some(AccumulatorContext.SQL_ACCUM_IDENTIFIER)) } - - def reset(): Unit = _value = initValue } -private[sql] object SQLMetrics { - // Identifier for distinguishing SQL metrics from other accumulators - private[sql] val ACCUM_IDENTIFIER = "sql" - - private[sql] val SUM_METRIC = "sum" - private[sql] val SIZE_METRIC = "size" - private[sql] val TIMING_METRIC = "timing" +object SQLMetrics { + private val SUM_METRIC = "sum" + private val SIZE_METRIC = "size" + private val TIMING_METRIC = "timing" def createMetric(sc: SparkContext, name: String): SQLMetric = { val acc = new SQLMetric(SUM_METRIC) - acc.register(sc, name = Some(name), countFailedValues = true) + acc.register(sc, name = Some(name), countFailedValues = false) acc } @@ -79,7 +83,7 @@ private[sql] object SQLMetrics { // data size total (min, med, max): // 100GB (100MB, 1GB, 10GB) val acc = new SQLMetric(SIZE_METRIC, -1) - acc.register(sc, name = Some(s"$name total (min, med, max)"), countFailedValues = true) + acc.register(sc, name = Some(s"$name total (min, med, max)"), countFailedValues = false) acc } @@ -88,7 +92,7 @@ private[sql] object SQLMetrics { // duration(min, med, max): // 5s (800ms, 1s, 2s) val acc = new SQLMetric(TIMING_METRIC, -1) - acc.register(sc, name = Some(s"$name total (min, med, max)"), countFailedValues = true) + acc.register(sc, name = Some(s"$name total (min, med, max)"), countFailedValues = false) acc } @@ -98,7 +102,8 @@ private[sql] object SQLMetrics { */ def stringValue(metricsType: String, values: Seq[Long]): String = { if (metricsType == SUM_METRIC) { - NumberFormat.getInstance().format(values.sum) + val numberFormat = NumberFormat.getIntegerInstance(Locale.ENGLISH) + numberFormat.format(values.sum) } else { val strFormat: Long => String = if (metricsType == SIZE_METRIC) { Utils.bytesToString @@ -110,7 +115,7 @@ private[sql] object SQLMetrics { val validValues = values.filter(_ >= 0) val Seq(sum, min, med, max) = { - val metric = if (validValues.length == 0) { + val metric = if (validValues.isEmpty) { Seq.fill(4)(0L) } else { val sorted = validValues.sorted diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index 56a39069511d7..c7e267152b5cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -20,24 +20,53 @@ package org.apache.spark.sql.execution import scala.language.existentials import org.apache.spark.api.java.function.MapFunction +import org.apache.spark.api.r._ +import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD +import org.apache.spark.sql.api.r.SQLUtils._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.objects.Invoke import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.types.{DataType, ObjectType} +import org.apache.spark.sql.Row +import org.apache.spark.sql.types.{DataType, ObjectType, StructType} + + +/** + * Physical version of `ObjectProducer`. + */ +trait ObjectProducerExec extends SparkPlan { + // The attribute that reference to the single object field this operator outputs. + protected def outputObjAttr: Attribute + + override def output: Seq[Attribute] = outputObjAttr :: Nil + + override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr) + + def outputObjectType: DataType = outputObjAttr.dataType +} + +/** + * Physical version of `ObjectConsumer`. + */ +trait ObjectConsumerExec extends UnaryExecNode { + assert(child.output.length == 1) + + // This operator always need all columns of its child, even it doesn't reference to. + override def references: AttributeSet = child.outputSet + + def inputObjectType: DataType = child.output.head.dataType +} /** * Takes the input row from child and turns it into object using the given deserializer expression. * The output of this operator is a single-field safe row containing the deserialized object. */ -case class DeserializeToObject( +case class DeserializeToObjectExec( deserializer: Expression, outputObjAttr: Attribute, - child: SparkPlan) extends UnaryExecNode with CodegenSupport { - - override def output: Seq[Attribute] = outputObjAttr :: Nil - override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr) + child: SparkPlan) extends UnaryExecNode with ObjectProducerExec with CodegenSupport { override def inputRDDs(): Seq[RDD[InternalRow]] = { child.asInstanceOf[CodegenSupport].inputRDDs() @@ -69,7 +98,7 @@ case class DeserializeToObject( */ case class SerializeFromObjectExec( serializer: Seq[NamedExpression], - child: SparkPlan) extends UnaryExecNode with CodegenSupport { + child: SparkPlan) extends ObjectConsumerExec with CodegenSupport { override def output: Seq[Attribute] = serializer.map(_.toAttribute) @@ -101,7 +130,7 @@ case class SerializeFromObjectExec( /** * Helper functions for physical operators that work with user defined objects. 
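DeserializeToObjectExec, SerializeFromObjectExec and the operators now mixing in ObjectProducerExec/ObjectConsumerExec back the typed Dataset API: a typed transformation works object to object, so the planner brackets it with a deserialize step before and a serialize step after. Assuming a SparkSession named spark with its implicits imported, a quick way to see this (exact explain output varies by version):

```scala
import spark.implicits._

val ds = Seq(1, 2, 3).toDS().map(_ * 2)   // typed map over JVM objects
ds.explain()
// The physical plan typically contains:
//   SerializeFromObject <- MapElements <- DeserializeToObject <- LocalTableScan
println(ds.collect().mkString(","))       // 2,4,6
```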
*/ -trait ObjectOperator extends SparkPlan { +object ObjectOperator { def deserializeRowToObject( deserializer: Expression, inputSchema: Seq[Attribute]): InternalRow => Any = { @@ -140,15 +169,12 @@ case class MapPartitionsExec( func: Iterator[Any] => Iterator[Any], outputObjAttr: Attribute, child: SparkPlan) - extends UnaryExecNode with ObjectOperator { - - override def output: Seq[Attribute] = outputObjAttr :: Nil - override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr) + extends ObjectConsumerExec with ObjectProducerExec { override protected def doExecute(): RDD[InternalRow] = { child.execute().mapPartitionsInternal { iter => - val getObject = unwrapObjectFromRow(child.output.head.dataType) - val outputObject = wrapObjectToRow(outputObjAttr.dataType) + val getObject = ObjectOperator.unwrapObjectFromRow(child.output.head.dataType) + val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType) func(iter.map(getObject)).map(outputObject) } } @@ -158,17 +184,14 @@ case class MapPartitionsExec( * Applies the given function to each input object. * The output of its child must be a single-field row containing the input object. * - * This operator is kind of a safe version of [[ProjectExec]], as it's output is custom object, + * This operator is kind of a safe version of [[ProjectExec]], as its output is custom object, * we need to use safe row to contain it. */ case class MapElementsExec( func: AnyRef, outputObjAttr: Attribute, child: SparkPlan) - extends UnaryExecNode with ObjectOperator with CodegenSupport { - - override def output: Seq[Attribute] = outputObjAttr :: Nil - override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr) + extends ObjectConsumerExec with ObjectProducerExec with CodegenSupport { override def inputRDDs(): Seq[RDD[InternalRow]] = { child.asInstanceOf[CodegenSupport].inputRDDs() @@ -201,8 +224,8 @@ case class MapElementsExec( } child.execute().mapPartitionsInternal { iter => - val getObject = unwrapObjectFromRow(child.output.head.dataType) - val outputObject = wrapObjectToRow(outputObjAttr.dataType) + val getObject = ObjectOperator.unwrapObjectFromRow(child.output.head.dataType) + val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType) iter.map(row => outputObject(callFunc(getObject(row)))) } } @@ -217,7 +240,7 @@ case class AppendColumnsExec( func: Any => Any, deserializer: Expression, serializer: Seq[NamedExpression], - child: SparkPlan) extends UnaryExecNode with ObjectOperator { + child: SparkPlan) extends UnaryExecNode { override def output: Seq[Attribute] = child.output ++ serializer.map(_.toAttribute) @@ -225,9 +248,9 @@ case class AppendColumnsExec( override protected def doExecute(): RDD[InternalRow] = { child.execute().mapPartitionsInternal { iter => - val getObject = deserializeRowToObject(deserializer, child.output) + val getObject = ObjectOperator.deserializeRowToObject(deserializer, child.output) val combiner = GenerateUnsafeRowJoiner.create(child.schema, newColumnSchema) - val outputObject = serializeObjectToRow(serializer) + val outputObject = ObjectOperator.serializeObjectToRow(serializer) iter.map { row => val newColumns = outputObject(func(getObject(row))) @@ -245,7 +268,7 @@ case class AppendColumnsWithObjectExec( func: Any => Any, inputSerializer: Seq[NamedExpression], newColumnsSerializer: Seq[NamedExpression], - child: SparkPlan) extends UnaryExecNode with ObjectOperator { + child: SparkPlan) extends ObjectConsumerExec { override def output: Seq[Attribute] = (inputSerializer ++ 
newColumnsSerializer).map(_.toAttribute) @@ -254,9 +277,9 @@ case class AppendColumnsWithObjectExec( override protected def doExecute(): RDD[InternalRow] = { child.execute().mapPartitionsInternal { iter => - val getChildObject = unwrapObjectFromRow(child.output.head.dataType) - val outputChildObject = serializeObjectToRow(inputSerializer) - val outputNewColumnOjb = serializeObjectToRow(newColumnsSerializer) + val getChildObject = ObjectOperator.unwrapObjectFromRow(child.output.head.dataType) + val outputChildObject = ObjectOperator.serializeObjectToRow(inputSerializer) + val outputNewColumnOjb = ObjectOperator.serializeObjectToRow(newColumnsSerializer) val combiner = GenerateUnsafeRowJoiner.create(inputSchema, newColumnSchema) iter.map { row => @@ -279,10 +302,7 @@ case class MapGroupsExec( groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], outputObjAttr: Attribute, - child: SparkPlan) extends UnaryExecNode with ObjectOperator { - - override def output: Seq[Attribute] = outputObjAttr :: Nil - override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr) + child: SparkPlan) extends UnaryExecNode with ObjectProducerExec { override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(groupingAttributes) :: Nil @@ -294,9 +314,9 @@ case class MapGroupsExec( child.execute().mapPartitionsInternal { iter => val grouped = GroupedIterator(iter, groupingAttributes, child.output) - val getKey = deserializeRowToObject(keyDeserializer, groupingAttributes) - val getValue = deserializeRowToObject(valueDeserializer, dataAttributes) - val outputObject = wrapObjectToRow(outputObjAttr.dataType) + val getKey = ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes) + val getValue = ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes) + val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType) grouped.flatMap { case (key, rowIter) => val result = func( @@ -308,6 +328,72 @@ case class MapGroupsExec( } } +/** + * Groups the input rows together and calls the R function with each group and an iterator + * containing all elements in the group. + * The result of this function is flattened before being output. 
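MapGroupsExec above is the physical operator behind KeyValueGroupedDataset.mapGroups, and FlatMapGroupsInRExec below plays the analogous role for SparkR's gapply. A small mapGroups usage sketch, again assuming a SparkSession named spark with implicits imported:

```scala
import spark.implicits._

final case class Sale(city: String, amount: Double)

val sales = Seq(Sale("nyc", 10.0), Sale("nyc", 5.0), Sale("sf", 7.0)).toDS()

// Each key and an iterator over its rows is passed to the function, which is what
// MapGroupsExec evaluates after clustering and sorting the child by the grouping key.
val totals = sales
  .groupByKey(_.city)
  .mapGroups((city, rows) => (city, rows.map(_.amount).sum))

println(totals.collect().toMap)   // Map(nyc -> 15.0, sf -> 7.0)
```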
+ */ +case class FlatMapGroupsInRExec( + func: Array[Byte], + packageNames: Array[Byte], + broadcastVars: Array[Broadcast[Object]], + inputSchema: StructType, + outputSchema: StructType, + keyDeserializer: Expression, + valueDeserializer: Expression, + groupingAttributes: Seq[Attribute], + dataAttributes: Seq[Attribute], + outputObjAttr: Attribute, + child: SparkPlan) extends UnaryExecNode with ObjectProducerExec { + + override def output: Seq[Attribute] = outputObjAttr :: Nil + override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr) + + override def requiredChildDistribution: Seq[Distribution] = + ClusteredDistribution(groupingAttributes) :: Nil + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + Seq(groupingAttributes.map(SortOrder(_, Ascending))) + + override protected def doExecute(): RDD[InternalRow] = { + val isSerializedRData = + if (outputSchema == SERIALIZED_R_DATA_SCHEMA) true else false + val serializerForR = if (!isSerializedRData) { + SerializationFormats.ROW + } else { + SerializationFormats.BYTE + } + + child.execute().mapPartitionsInternal { iter => + val grouped = GroupedIterator(iter, groupingAttributes, child.output) + val getKey = ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes) + val getValue = ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes) + val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType) + val runner = new RRunner[Array[Byte]]( + func, SerializationFormats.ROW, serializerForR, packageNames, broadcastVars, + isDataFrame = true, colNames = inputSchema.fieldNames, + mode = RRunnerModes.DATAFRAME_GAPPLY) + + val groupedRBytes = grouped.map { case (key, rowIter) => + val deserializedIter = rowIter.map(getValue) + val newIter = + deserializedIter.asInstanceOf[Iterator[Row]].map { row => rowToRBytes(row) } + val newKey = rowToRBytes(getKey(key).asInstanceOf[Row]) + (newKey, newIter) + } + + val outputIter = runner.compute(groupedRBytes, -1) + if (!isSerializedRData) { + val result = outputIter.map { bytes => bytesToRow(bytes, outputSchema) } + result.map(outputObject) + } else { + val result = outputIter.map { bytes => Row.fromSeq(Seq(bytes)) } + result.map(outputObject) + } + } + } +} + /** * Co-groups the data from left and right children, and calls the function with each group and 2 * iterators containing all elements in the group from left and right side. 
@@ -324,10 +410,7 @@ case class CoGroupExec( rightAttr: Seq[Attribute], outputObjAttr: Attribute, left: SparkPlan, - right: SparkPlan) extends BinaryExecNode with ObjectOperator { - - override def output: Seq[Attribute] = outputObjAttr :: Nil - override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr) + right: SparkPlan) extends BinaryExecNode with ObjectProducerExec { override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftGroup) :: ClusteredDistribution(rightGroup) :: Nil @@ -340,10 +423,10 @@ case class CoGroupExec( val leftGrouped = GroupedIterator(leftData, leftGroup, left.output) val rightGrouped = GroupedIterator(rightData, rightGroup, right.output) - val getKey = deserializeRowToObject(keyDeserializer, leftGroup) - val getLeft = deserializeRowToObject(leftDeserializer, leftAttr) - val getRight = deserializeRowToObject(rightDeserializer, rightAttr) - val outputObject = wrapObjectToRow(outputObjAttr.dataType) + val getKey = ObjectOperator.deserializeRowToObject(keyDeserializer, leftGroup) + val getLeft = ObjectOperator.deserializeRowToObject(leftDeserializer, leftAttr) + val getRight = ObjectOperator.deserializeRowToObject(rightDeserializer, rightAttr) + val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType) new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup).flatMap { case (key, leftResult, rightResult) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala index 061d7c7f79de8..d9bf4d3ccf698 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala @@ -46,6 +46,8 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi def children: Seq[SparkPlan] = child :: Nil + override def producedAttributes: AttributeSet = AttributeSet(output.drop(child.output.length)) + private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = { udf.children match { case Seq(u: PythonUDF) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala index cf68ed4ec36a8..724025b4647f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala @@ -24,9 +24,8 @@ import scala.collection.JavaConverters._ import net.razorvine.pickle.{IObjectPickler, Opcodes, Pickler} -import org.apache.spark.api.python.{PythonRDD, SerDeUtil} +import org.apache.spark.api.python.SerDeUtil import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} @@ -34,16 +33,6 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String object EvaluatePython { - def takeAndServe(df: DataFrame, n: Int): Int = { - registerPicklers() - df.withNewExecutionId { - val iter = new SerDeUtil.AutoBatchedPickler( - df.queryExecution.executedPlan.executeTake(n).iterator.map { row => - EvaluatePython.toJava(row, df.schema) - }) - PythonRDD.serveIterator(iter, s"serve-DataFrame") - } 
- } def needConversionInPython(dt: DataType): Boolean = dt match { case DateType | TimestampType => true diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index ab192360e1c1f..16e44845d5283 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -18,12 +18,68 @@ package org.apache.spark.sql.execution.python import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution import org.apache.spark.sql.execution.SparkPlan + +/** + * Extracts all the Python UDFs in a logical aggregate that depend on an aggregate expression or + * grouping key, and evaluates them after the aggregate. + */ +object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] { + + /** + * Returns whether the expression could only be evaluated within aggregate. + */ + private def belongAggregate(e: Expression, agg: Aggregate): Boolean = { + e.isInstanceOf[AggregateExpression] || + agg.groupingExpressions.exists(_.semanticEquals(e)) + } + + private def hasPythonUdfOverAggregate(expr: Expression, agg: Aggregate): Boolean = { + expr.find { + e => e.isInstanceOf[PythonUDF] && e.find(belongAggregate(_, agg)).isDefined + }.isDefined + } + + private def extract(agg: Aggregate): LogicalPlan = { + val projList = new ArrayBuffer[NamedExpression]() + val aggExpr = new ArrayBuffer[NamedExpression]() + agg.aggregateExpressions.foreach { expr => + if (hasPythonUdfOverAggregate(expr, agg)) { + // Python UDF can only be evaluated after aggregate + val newE = expr transformDown { + case e: Expression if belongAggregate(e, agg) => + val alias = e match { + case a: NamedExpression => a + case o => Alias(e, "agg")() + } + aggExpr += alias + alias.toAttribute + } + projList += newE.asInstanceOf[NamedExpression] + } else { + aggExpr += expr + projList += expr.toAttribute + } + } + // There is no Python UDF over aggregate expression + Project(projList, agg.copy(aggregateExpressions = aggExpr)) + } + + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case agg: Aggregate if agg.aggregateExpressions.exists(hasPythonUdfOverAggregate(_, agg)) => + extract(agg) + } +} + + /** * Extracts PythonUDFs from operators, rewriting the query plan so that the UDF can be evaluated * alone in a batch. @@ -34,7 +90,7 @@ import org.apache.spark.sql.execution.SparkPlan * This has the limitation that the input to the Python UDF is not allowed to include attributes from * multiple child operators. */ -private[spark] object ExtractPythonUDFs extends Rule[SparkPlan] { +object ExtractPythonUDFs extends Rule[SparkPlan] { private def hasPythonUDF(e: Expression): Boolean = { e.find(_.isInstanceOf[PythonUDF]).isDefined @@ -59,10 +115,12 @@ private[spark] object ExtractPythonUDFs extends Rule[SparkPlan] { } /** - * Extract all the PythonUDFs from the current operator. + * Extract all the PythonUDFs from the current operator and evaluate them before the operator.
*/ - def extract(plan: SparkPlan): SparkPlan = { + private def extract(plan: SparkPlan): SparkPlan = { val udfs = plan.expressions.flatMap(collectEvaluatableUDF) + // ignore the PythonUDF that come from second/third aggregate, which is not used + .filter(udf => udf.references.subsetOf(plan.inputSet)) if (udfs.isEmpty) { // If there aren't any, we are done. plan @@ -74,7 +132,7 @@ private[spark] object ExtractPythonUDFs extends Rule[SparkPlan] { val validUdfs = udfs.filter { case udf => // Check to make sure that the UDF can be evaluated with only the input of this child. udf.references.subsetOf(child.outputSet) - } + }.toArray // Turn it into an array since iterators cannot be serialized in Scala 2.10 if (validUdfs.nonEmpty) { val resultAttrs = udfs.zipWithIndex.map { case (u, i) => AttributeReference(s"pythonUDF$i", u.dataType)() @@ -89,17 +147,13 @@ private[spark] object ExtractPythonUDFs extends Rule[SparkPlan] { // Other cases are disallowed as they are ambiguous or would require a cartesian // product. udfs.filterNot(attributeMap.contains).foreach { udf => - if (udf.references.subsetOf(plan.inputSet)) { - sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.") - } else { - sys.error(s"Unable to evaluate PythonUDF $udf. Missing input attributes.") - } + sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.") } - val rewritten = plan.transformExpressions { + val rewritten = plan.withNewChildren(newChildren).transformExpressions { case p: PythonUDF if attributeMap.contains(p) => attributeMap(p) - }.withNewChildren(newChildren) + } // extract remaining python UDFs recursively val newPlan = extract(rewritten) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/r/MapPartitionsRWrapper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/r/MapPartitionsRWrapper.scala index dc6f2ef371584..d2178e971ec20 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/r/MapPartitionsRWrapper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/r/MapPartitionsRWrapper.scala @@ -17,17 +17,16 @@ package org.apache.spark.sql.execution.r -import org.apache.spark.api.r.RRunner -import org.apache.spark.api.r.SerializationFormats +import org.apache.spark.api.r._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.api.r.SQLUtils._ import org.apache.spark.sql.Row -import org.apache.spark.sql.types.{BinaryType, StructField, StructType} +import org.apache.spark.sql.types.StructType /** * A function wrapper that applies the given R function to each partition. */ -private[sql] case class MapPartitionsRWrapper( +case class MapPartitionsRWrapper( func: Array[Byte], packageNames: Array[Byte], broadcastVars: Array[Broadcast[Object]], @@ -40,7 +39,7 @@ private[sql] case class MapPartitionsRWrapper( val (newIter, deserializer, colNames) = if (!isSerializedRData) { - // Serialize each row into an byte array that can be deserialized in the R worker + // Serialize each row into a byte array that can be deserialized in the R worker (iter.asInstanceOf[Iterator[Row]].map {row => rowToRBytes(row)}, SerializationFormats.ROW, inputSchema.fieldNames) } else { @@ -55,7 +54,7 @@ private[sql] case class MapPartitionsRWrapper( val runner = new RRunner[Array[Byte]]( func, deserializer, serializer, packageNames, broadcastVars, - isDataFrame = true, colNames = colNames) + isDataFrame = true, colNames = colNames, mode = RRunnerModes.DATAFRAME_DAPPLY) // Partition index is ignored. 
Dataset has no support for mapPartitionsWithIndex. val outputIter = runner.compute(newIter, -1) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala index 34bd243d58de9..b9dbfcf7734c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.types._ -private[sql] object FrequentItems extends Logging { +object FrequentItems extends Logging { /** A helper class wrapping `MutableMap[Any, Long]` for simplicity. */ private class FreqItemCounter(size: Int) extends Serializable { @@ -40,7 +40,7 @@ private[sql] object FrequentItems extends Logging { if (baseMap.size < size) { baseMap += key -> count } else { - val minCount = baseMap.values.min + val minCount = if (baseMap.values.isEmpty) 0 else baseMap.values.min val remainder = count - minCount if (remainder >= 0) { baseMap += key -> count // something will get kicked out, so we can add this @@ -79,11 +79,11 @@ private[sql] object FrequentItems extends Logging { * than 1e-4. * @return A Local DataFrame with the Array of frequent items for each column. */ - private[sql] def singlePassFreqItems( + def singlePassFreqItems( df: DataFrame, cols: Seq[String], support: Double): DataFrame = { - require(support >= 1e-4, s"support ($support) must be greater than 1e-4.") + require(support >= 1e-4 && support <= 1.0, s"Support must be in [1e-4, 1], but got $support.") val numCols = cols.length // number of max items to keep counts for val sizeOfMap = (1 / support).toInt diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index 9c0406168e6e3..7e2ebe856bc89 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.stat -import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.{ArrayBuffer, ListBuffer} import org.apache.spark.internal.Logging import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} @@ -27,7 +27,7 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -private[sql] object StatFunctions extends Logging { +object StatFunctions extends Logging { import QuantileSummaries.Stats @@ -119,7 +119,7 @@ private[sql] object StatFunctions extends Logging { class QuantileSummaries( val compressThreshold: Int, val relativeError: Double, - val sampled: ArrayBuffer[Stats] = ArrayBuffer.empty, + val sampled: Array[Stats] = Array.empty, private[stat] var count: Long = 0L, val headSampled: ArrayBuffer[Double] = ArrayBuffer.empty) extends Serializable { @@ -134,7 +134,12 @@ private[sql] object StatFunctions extends Logging { def insert(x: Double): QuantileSummaries = { headSampled.append(x) if (headSampled.size >= defaultHeadSize) { - this.withHeadBufferInserted + val result = this.withHeadBufferInserted + if (result.sampled.length >= compressThreshold) { + result.compress() + } else { + result + } } else { this } @@ -186,7 +191,7 @@ private[sql] object StatFunctions 
extends Logging { newSamples.append(sampled(sampleIdx)) sampleIdx += 1 } - new QuantileSummaries(compressThreshold, relativeError, newSamples, currentCount) + new QuantileSummaries(compressThreshold, relativeError, newSamples.toArray, currentCount) } /** @@ -305,10 +310,10 @@ private[sql] object StatFunctions extends Logging { private def compressImmut( currentSamples: IndexedSeq[Stats], - mergeThreshold: Double): ArrayBuffer[Stats] = { - val res: ArrayBuffer[Stats] = ArrayBuffer.empty + mergeThreshold: Double): Array[Stats] = { + val res = ListBuffer.empty[Stats] if (currentSamples.isEmpty) { - return res + return res.toArray } // Start for the last element, which is always part of the set. // The head contains the current new head, that may be merged with the current element. @@ -331,13 +336,16 @@ private[sql] object StatFunctions extends Logging { } res.prepend(head) // If necessary, add the minimum element: - res.prepend(currentSamples.head) - res + val currHead = currentSamples.head + if (currHead.value < head.value) { + res.prepend(currentSamples.head) + } + res.toArray } } /** Calculate the Pearson Correlation Coefficient for the given columns */ - private[sql] def pearsonCorrelation(df: DataFrame, cols: Seq[String]): Double = { + def pearsonCorrelation(df: DataFrame, cols: Seq[String]): Double = { val counts = collectStatisticalData(df, cols, "correlation") counts.Ck / math.sqrt(counts.MkX * counts.MkY) } @@ -407,13 +415,13 @@ private[sql] object StatFunctions extends Logging { * @param cols the column names * @return the covariance of the two columns. */ - private[sql] def calculateCov(df: DataFrame, cols: Seq[String]): Double = { + def calculateCov(df: DataFrame, cols: Seq[String]): Double = { val counts = collectStatisticalData(df, cols, "covariance") counts.cov } /** Generate a table of frequencies for the elements of two columns. */ - private[sql] def crossTabulate(df: DataFrame, col1: String, col2: String): DataFrame = { + def crossTabulate(df: DataFrame, col1: String, col2: String): DataFrame = { val tableName = s"${col1}_$col2" val counts = df.groupBy(col1, col2).agg(count("*")).take(1e6.toInt) if (counts.length == 1e6.toInt) { @@ -423,9 +431,9 @@ private[sql] object StatFunctions extends Logging { def cleanElement(element: Any): String = { if (element == null) "null" else element.toString } - // get the distinct values of column 2, so that we can make them the column names + // get the distinct sorted values of column 2, so that we can make them the column names val distinctCol2: Map[Any, Int] = - counts.map(e => cleanElement(e.get(1))).distinct.zipWithIndex.toMap + counts.map(e => cleanElement(e.get(1))).distinct.sorted.zipWithIndex.toMap val columnSize = distinctCol2.size require(columnSize < 1e4, s"The number of distinct values for $col2, can't " + s"exceed 1e4. Currently $columnSize") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala new file mode 100644 index 0000000000000..c14feea91ed7d --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala @@ -0,0 +1,251 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.io.{InputStream, IOException, OutputStream} +import java.nio.charset.StandardCharsets.UTF_8 + +import scala.io.{Source => IOSource} +import scala.reflect.ClassTag + +import org.apache.hadoop.fs.{Path, PathFilter} + +import org.apache.spark.sql.SparkSession + +/** + * An abstract class for compactible metadata logs. It will write one log file for each batch. + * The first line of the log file is the version number, and there are multiple serialized + * metadata lines following. + * + * As reading from many small files is usually pretty slow, and too many + * small files in one folder will mess up the FS, [[CompactibleFileStreamLog]] will + * compact log files every 10 batches by default into a big file. When + * doing a compaction, it will read all old log files and merge them with the new batch. + */ +abstract class CompactibleFileStreamLog[T: ClassTag]( + metadataLogVersion: String, + sparkSession: SparkSession, + path: String) + extends HDFSMetadataLog[Array[T]](sparkSession, path) { + + import CompactibleFileStreamLog._ + + /** + * If we delete the old files after compaction at once, there is a race condition in S3: other + * processes may see the old files are deleted but still cannot see the compaction file using + * "list". The `allFiles` handles this by looking for the next compaction file directly, however, + * a live lock may happen if the compaction happens too frequently: one process keeps deleting + * old files while another one keeps retrying. Setting a reasonable cleanup delay could avoid it. + */ + protected def fileCleanupDelayMs: Long + + protected def isDeletingExpiredLog: Boolean + + protected def compactInterval: Int + + /** + * Serialize the data into encoded string. + */ + protected def serializeData(t: T): String + + /** + * Deserialize the string into data object. + */ + protected def deserializeData(encodedString: String): T + + /** + * Filter out the obsolete logs.
+ */ + def compactLogs(logs: Seq[T]): Seq[T] + + override def batchIdToPath(batchId: Long): Path = { + if (isCompactionBatch(batchId, compactInterval)) { + new Path(metadataPath, s"$batchId$COMPACT_FILE_SUFFIX") + } else { + new Path(metadataPath, batchId.toString) + } + } + + override def pathToBatchId(path: Path): Long = { + getBatchIdFromFileName(path.getName) + } + + override def isBatchFile(path: Path): Boolean = { + try { + getBatchIdFromFileName(path.getName) + true + } catch { + case _: NumberFormatException => false + } + } + + override def serialize(logData: Array[T], out: OutputStream): Unit = { + // called inside a try-finally where the underlying stream is closed in the caller + out.write(metadataLogVersion.getBytes(UTF_8)) + logData.foreach { data => + out.write('\n') + out.write(serializeData(data).getBytes(UTF_8)) + } + } + + override def deserialize(in: InputStream): Array[T] = { + val lines = IOSource.fromInputStream(in, UTF_8.name()).getLines() + if (!lines.hasNext) { + throw new IllegalStateException("Incomplete log file") + } + val version = lines.next() + if (version != metadataLogVersion) { + throw new IllegalStateException(s"Unknown log version: ${version}") + } + lines.map(deserializeData).toArray + } + + override def add(batchId: Long, logs: Array[T]): Boolean = { + if (isCompactionBatch(batchId, compactInterval)) { + compact(batchId, logs) + } else { + super.add(batchId, logs) + } + } + + /** + * Compacts all logs before `batchId` plus the provided `logs`, and writes them into the + * corresponding `batchId` file. It will delete expired files as well if enabled. + */ + private def compact(batchId: Long, logs: Array[T]): Boolean = { + val validBatches = getValidBatchesBeforeCompactionBatch(batchId, compactInterval) + val allLogs = validBatches.flatMap(batchId => super.get(batchId)).flatten ++ logs + if (super.add(batchId, compactLogs(allLogs).toArray)) { + if (isDeletingExpiredLog) { + deleteExpiredLog(batchId) + } + true + } else { + // Return false as there is another writer. + false + } + } + + /** + * Returns all files except the deleted ones. + */ + def allFiles(): Array[T] = { + var latestId = getLatest().map(_._1).getOrElse(-1L) + // There is a race condition when `FileStreamSink` is deleting old files and `StreamFileCatalog` + // is calling this method. This loop will retry the reading to deal with the + // race condition. + while (true) { + if (latestId >= 0) { + try { + val logs = + getAllValidBatches(latestId, compactInterval).flatMap(id => super.get(id)).flatten + return compactLogs(logs).toArray + } catch { + case e: IOException => + // Another process using `CompactibleFileStreamLog` may delete the batch files when + // `StreamFileCatalog` are reading. However, it only happens when a compaction is + // deleting old files. If so, let's try the next compaction batch and we should find it. + // Otherwise, this is a real IO issue and we should throw it. + latestId = nextCompactionBatchId(latestId, compactInterval) + super.get(latestId).getOrElse { + throw e + } + } + } else { + return Array.empty + } + } + Array.empty + } + + /** + * Since all logs before `compactionBatchId` are compacted and written into the + * `compactionBatchId` log file, they can be removed. However, due to the eventual consistency of + * S3, the compaction file may not be seen by other processes at once. So we only delete files + * created `fileCleanupDelayMs` milliseconds ago. 
+ */ + private def deleteExpiredLog(compactionBatchId: Long): Unit = { + val expiredTime = System.currentTimeMillis() - fileCleanupDelayMs + fileManager.list(metadataPath, new PathFilter { + override def accept(path: Path): Boolean = { + try { + val batchId = getBatchIdFromFileName(path.getName) + batchId < compactionBatchId + } catch { + case _: NumberFormatException => + false + } + } + }).foreach { f => + if (f.getModificationTime <= expiredTime) { + fileManager.delete(f.getPath) + } + } + } +} + +object CompactibleFileStreamLog { + val COMPACT_FILE_SUFFIX = ".compact" + + def getBatchIdFromFileName(fileName: String): Long = { + fileName.stripSuffix(COMPACT_FILE_SUFFIX).toLong + } + + /** + * Returns if this is a compaction batch. FileStreamSinkLog will compact old logs every + * `compactInterval` commits. + * + * E.g., if `compactInterval` is 3, then 2, 5, 8, ... are all compaction batches. + */ + def isCompactionBatch(batchId: Long, compactInterval: Int): Boolean = { + (batchId + 1) % compactInterval == 0 + } + + /** + * Returns all valid batches before the specified `compactionBatchId`. They contain all logs we + * need to do a new compaction. + * + * E.g., if `compactInterval` is 3 and `compactionBatchId` is 5, this method should returns + * `Seq(2, 3, 4)` (Note: it includes the previous compaction batch 2). + */ + def getValidBatchesBeforeCompactionBatch( + compactionBatchId: Long, + compactInterval: Int): Seq[Long] = { + assert(isCompactionBatch(compactionBatchId, compactInterval), + s"$compactionBatchId is not a compaction batch") + (math.max(0, compactionBatchId - compactInterval)) until compactionBatchId + } + + /** + * Returns all necessary logs before `batchId` (inclusive). If `batchId` is a compaction, just + * return itself. Otherwise, it will find the previous compaction batch and return all batches + * between it and `batchId`. + */ + def getAllValidBatches(batchId: Long, compactInterval: Long): Seq[Long] = { + assert(batchId >= 0) + val start = math.max(0, (batchId + 1) / compactInterval * compactInterval - 1) + start to batchId + } + + /** + * Returns the next compaction batch id after `batchId`. + */ + def nextCompactionBatchId(batchId: Long, compactInterval: Long): Long = { + (batchId + compactInterval + 1) / compactInterval * compactInterval - 1 + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompositeOffset.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompositeOffset.scala index 729c8462fed65..ebc6ee8184902 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompositeOffset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompositeOffset.scala @@ -23,36 +23,6 @@ package org.apache.spark.sql.execution.streaming * vector clock that must progress linearly forward. */ case class CompositeOffset(offsets: Seq[Option[Offset]]) extends Offset { - /** - * Returns a negative integer, zero, or a positive integer as this object is less than, equal to, - * or greater than the specified object. 
- */ - override def compareTo(other: Offset): Int = other match { - case otherComposite: CompositeOffset if otherComposite.offsets.size == offsets.size => - val comparisons = offsets.zip(otherComposite.offsets).map { - case (Some(a), Some(b)) => a compareTo b - case (None, None) => 0 - case (None, _) => -1 - case (_, None) => 1 - } - val nonZeroSigns = comparisons.map(sign).filter(_ != 0).toSet - nonZeroSigns.size match { - case 0 => 0 // if both empty or only 0s - case 1 => nonZeroSigns.head // if there are only (0s and 1s) or (0s and -1s) - case _ => // there are both 1s and -1s - throw new IllegalArgumentException( - s"Invalid comparison between non-linear histories: $this <=> $other") - } - case _ => - throw new IllegalArgumentException(s"Cannot compare $this <=> $other") - } - - private def sign(num: Int): Int = num match { - case i if i < 0 => -1 - case i if i == 0 => 0 - case i if i > 0 => 1 - } - /** * Unpacks an offset into [[StreamProgress]] by associating each offset with the order list of * sources. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ContinuousQueryListenerBus.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ContinuousQueryListenerBus.scala deleted file mode 100644 index b1d24b6cfc0bd..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ContinuousQueryListenerBus.scala +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming - -import org.apache.spark.scheduler.{LiveListenerBus, SparkListener, SparkListenerEvent} -import org.apache.spark.sql.util.ContinuousQueryListener -import org.apache.spark.sql.util.ContinuousQueryListener._ -import org.apache.spark.util.ListenerBus - -/** - * A bus to forward events to [[ContinuousQueryListener]]s. This one will wrap received - * [[ContinuousQueryListener.Event]]s as WrappedContinuousQueryListenerEvents and send them to the - * Spark listener bus. It also registers itself with Spark listener bus, so that it can receive - * WrappedContinuousQueryListenerEvents, unwrap them as ContinuousQueryListener.Events and - * dispatch them to ContinuousQueryListener. - */ -class ContinuousQueryListenerBus(sparkListenerBus: LiveListenerBus) - extends SparkListener with ListenerBus[ContinuousQueryListener, ContinuousQueryListener.Event] { - - sparkListenerBus.addListener(this) - - /** - * Post a ContinuousQueryListener event to the Spark listener bus asynchronously. This event will - * be dispatched to all ContinuousQueryListener in the thread of the Spark listener bus. 
- */ - def post(event: ContinuousQueryListener.Event) { - event match { - case s: QueryStarted => - postToAll(s) - case _ => - sparkListenerBus.post(new WrappedContinuousQueryListenerEvent(event)) - } - } - - override def onOtherEvent(event: SparkListenerEvent): Unit = { - event match { - case WrappedContinuousQueryListenerEvent(e) => - postToAll(e) - case _ => - } - } - - override protected def doPostEvent( - listener: ContinuousQueryListener, - event: ContinuousQueryListener.Event): Unit = { - event match { - case queryStarted: QueryStarted => - listener.onQueryStarted(queryStarted) - case queryProgress: QueryProgress => - listener.onQueryProgress(queryProgress) - case queryTerminated: QueryTerminated => - listener.onQueryTerminated(queryTerminated) - case _ => - } - } - - /** - * Wrapper for StreamingListenerEvent as SparkListenerEvent so that it can be posted to Spark - * listener bus. - */ - private case class WrappedContinuousQueryListenerEvent( - streamingListenerEvent: ContinuousQueryListener.Event) extends SparkListenerEvent { - - // Do not log streaming events in event log as history server does not support these events. - protected[spark] override def logEvent: Boolean = false - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala new file mode 100644 index 0000000000000..3efc20c1d662d --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamOptions.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import scala.util.Try + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.execution.datasources.CaseInsensitiveMap +import org.apache.spark.util.Utils + +/** + * User specified options for file streams. + */ +class FileStreamOptions(parameters: Map[String, String]) extends Logging { + + val maxFilesPerTrigger: Option[Int] = parameters.get("maxFilesPerTrigger").map { str => + Try(str.toInt).toOption.filter(_ > 0).getOrElse { + throw new IllegalArgumentException( + s"Invalid value '$str' for option 'maxFilesPerTrigger', must be a positive integer") + } + } + + /** + * Maximum age of a file that can be found in this directory, before it is deleted. + * + * The max age is specified with respect to the timestamp of the latest file, and not the + * timestamp of the current system. This means that if the last file has timestamp 1000, and the + * current system time is 2000, and max age is 200, the system will purge files older than + * 800 (rather than 1800) from the internal state. + * + * Defaults to a week.
+ */ + val maxFileAgeMs: Long = + Utils.timeStringAsMs(parameters.getOrElse("maxFileAge", "7d")) + + /** Options as specified by the user, in a case-insensitive map, without "path" set. */ + val optionMapWithoutPath: Map[String, String] = + new CaseInsensitiveMap(parameters).filterKeys(_ != "path") +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala index 70aea7fa49a1e..b5f73fb2591dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala @@ -19,11 +19,21 @@ package org.apache.spark.sql.execution.streaming import java.util.UUID +import scala.collection.mutable.ArrayBuffer + +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.apache.spark.{SparkEnv, SparkException, TaskContext, TaskContextImpl} import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, SparkSession} -import org.apache.spark.sql.execution.datasources.FileFormat +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.UnsafeKVExternalSorter +import org.apache.spark.sql.execution.datasources.{FileFormat, OutputWriter, PartitioningUtils} +import org.apache.spark.sql.types.{StringType, StructType} +import org.apache.spark.util.SerializableConfiguration +import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter object FileStreamSink { // The name of the subdirectory that is used to store metadata about which files are valid. @@ -40,28 +50,25 @@ object FileStreamSink { class FileStreamSink( sparkSession: SparkSession, path: String, - fileFormat: FileFormat) extends Sink with Logging { + fileFormat: FileFormat, + partitionColumnNames: Seq[String], + options: Map[String, String]) extends Sink with Logging { private val basePath = new Path(path) private val logPath = new Path(basePath, FileStreamSink.metadataDir) - private val fileLog = new FileStreamSinkLog(sparkSession, logPath.toUri.toString) - private val fs = basePath.getFileSystem(sparkSession.sessionState.newHadoopConf()) + private val fileLog = + new FileStreamSinkLog(FileStreamSinkLog.VERSION, sparkSession, logPath.toUri.toString) + private val hadoopConf = sparkSession.sessionState.newHadoopConf() + private val fs = basePath.getFileSystem(hadoopConf) override def addBatch(batchId: Long, data: DataFrame): Unit = { if (batchId <= fileLog.getLatest().map(_._1).getOrElse(-1L)) { logInfo(s"Skipping already committed batch $batchId") } else { - val files = fs.listStatus(writeFiles(data)).map { f => - SinkFileStatus( - path = f.getPath.toUri.toString, - size = f.getLen, - isDir = f.isDirectory, - modificationTime = f.getModificationTime, - blockReplication = f.getReplication, - blockSize = f.getBlockSize, - action = FileStreamSinkLog.ADD_ACTION) - } - if (fileLog.add(batchId, files)) { + val writer = new FileStreamSinkWriter( + data, fileFormat, path, partitionColumnNames, hadoopConf, options) + val fileStatuses = writer.write() + if (fileLog.add(batchId, fileStatuses)) { logInfo(s"Committed batch $batchId") } else { throw new IllegalStateException(s"Race while writing batch $batchId") @@ -69,17 +76,194 @@ class FileStreamSink( } } - /** Writes the [[DataFrame]] to a UUID-named dir, returning the list of files paths. 
*/ - private def writeFiles(data: DataFrame): Array[Path] = { - val file = new Path(basePath, UUID.randomUUID().toString).toUri.toString - data.write.parquet(file) - sparkSession.read - .schema(data.schema) - .parquet(file) - .inputFiles - .map(new Path(_)) - .filterNot(_.getName.startsWith("_")) + override def toString: String = s"FileSink[$path]" +} + + +/** + * Writes data given to a [[FileStreamSink]] to the given `basePath` in the given `fileFormat`, + * partitioned by the given `partitionColumnNames`. This writer always appends data to the + * directory if it already has data. + */ +class FileStreamSinkWriter( + data: DataFrame, + fileFormat: FileFormat, + basePath: String, + partitionColumnNames: Seq[String], + hadoopConf: Configuration, + options: Map[String, String]) extends Serializable with Logging { + + PartitioningUtils.validatePartitionColumn( + data.schema, partitionColumnNames, data.sqlContext.conf.caseSensitiveAnalysis) + + private val serializableConf = new SerializableConfiguration(hadoopConf) + private val dataSchema = data.schema + private val dataColumns = data.logicalPlan.output + + // Get the actual partition columns as attributes after matching them by name with + // the given columns names. + private val partitionColumns = partitionColumnNames.map { col => + val nameEquality = if (data.sparkSession.sessionState.conf.caseSensitiveAnalysis) { + org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution + } else { + org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution + } + data.logicalPlan.output.find(f => nameEquality(f.name, col)).getOrElse { + throw new RuntimeException(s"Partition column $col not found in schema $dataSchema") + } + } + + // Columns that are to be written to the files. If there are partitioning columns, then + // those will not be written to the files. + private val writeColumns = { + val partitionSet = AttributeSet(partitionColumns) + dataColumns.filterNot(partitionSet.contains) } - override def toString: String = s"FileSink[$path]" + // An OutputWriterFactory for generating writers in the executors for writing the files. + private val outputWriterFactory = + fileFormat.buildWriter(data.sqlContext, writeColumns.toStructType, options) + + /** Expressions that given a partition key build a string like: col1=val/col2=val/... */ + private def partitionStringExpression: Seq[Expression] = { + partitionColumns.zipWithIndex.flatMap { case (c, i) => + val escaped = + ScalaUDF( + PartitioningUtils.escapePathName _, + StringType, + Seq(Cast(c, StringType)), + Seq(StringType)) + val str = If(IsNull(c), Literal(PartitioningUtils.DEFAULT_PARTITION_NAME), escaped) + val partitionName = Literal(c.name + "=") :: str :: Nil + if (i == 0) partitionName else Literal(Path.SEPARATOR) :: partitionName + } + } + + /** Generate a new output writer from the writer factory */ + private def newOutputWriter(path: Path): OutputWriter = { + val newWriter = outputWriterFactory.newWriter(path.toString) + newWriter.initConverter(dataSchema) + newWriter + } + + /** Write the dataframe to files. This gets called in the driver by the [[FileStreamSink]]. */ + def write(): Array[SinkFileStatus] = { + data.sqlContext.sparkContext.runJob( + data.queryExecution.toRdd, + (taskContext: TaskContext, iterator: Iterator[InternalRow]) => { + if (partitionColumns.isEmpty) { + Seq(writePartitionToSingleFile(iterator)) + } else { + writePartitionToPartitionedFiles(iterator) + } + }).flatten + } + + /** + * Writes a RDD partition to a single file without dynamic partitioning. 
+ * This gets called in the executor, and it uses a [[OutputWriter]] to write the data. + */ + def writePartitionToSingleFile(iterator: Iterator[InternalRow]): SinkFileStatus = { + var writer: OutputWriter = null + try { + val path = new Path(basePath, UUID.randomUUID.toString) + val fs = path.getFileSystem(serializableConf.value) + writer = newOutputWriter(path) + while (iterator.hasNext) { + writer.writeInternal(iterator.next) + } + writer.close() + writer = null + SinkFileStatus(fs.getFileStatus(path)) + } catch { + case cause: Throwable => + logError("Aborting task.", cause) + // call failure callbacks first, so we could have a chance to cleanup the writer. + TaskContext.get().asInstanceOf[TaskContextImpl].markTaskFailed(cause) + throw new SparkException("Task failed while writing rows.", cause) + } finally { + if (writer != null) { + writer.close() + } + } + } + + /** + * Writes a RDD partition to multiple dynamically partitioned files. + * This gets called in the executor. It first sorts the data based on the partitioning columns + * and then writes the data of each key to separate files using [[OutputWriter]]s. + */ + def writePartitionToPartitionedFiles(iterator: Iterator[InternalRow]): Seq[SinkFileStatus] = { + + // Returns the partitioning columns for sorting + val getSortingKey = UnsafeProjection.create(partitionColumns, dataColumns) + + // Returns the data columns to be written given an input row + val getOutputRow = UnsafeProjection.create(writeColumns, dataColumns) + + // Returns the partition path given a partition key + val getPartitionString = + UnsafeProjection.create(Concat(partitionStringExpression) :: Nil, partitionColumns) + + // Sort the data before write, so that we only need one writer at the same time. + val sorter = new UnsafeKVExternalSorter( + partitionColumns.toStructType, + StructType.fromAttributes(writeColumns), + SparkEnv.get.blockManager, + SparkEnv.get.serializerManager, + TaskContext.get().taskMemoryManager().pageSizeBytes, + SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", + UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD)) + + while (iterator.hasNext) { + val currentRow = iterator.next() + sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow)) + } + logDebug(s"Sorting complete. 
Writing out partition files one at a time.") + + val sortedIterator = sorter.sortedIterator() + val paths = new ArrayBuffer[Path] + + // Write the sorted data to partitioned files, one for each unique key + var currentWriter: OutputWriter = null + try { + var currentKey: UnsafeRow = null + while (sortedIterator.next()) { + val nextKey = sortedIterator.getKey + + // If key changes, close current writer, and open a new writer to a new partitioned file + if (currentKey != nextKey) { + if (currentWriter != null) { + currentWriter.close() + currentWriter = null + } + currentKey = nextKey.copy() + val partitionPath = getPartitionString(currentKey).getString(0) + val path = new Path(new Path(basePath, partitionPath), UUID.randomUUID.toString) + paths += path + currentWriter = newOutputWriter(path) + logInfo(s"Writing partition $currentKey to $path") + } + currentWriter.writeInternal(sortedIterator.getValue) + } + if (currentWriter != null) { + currentWriter.close() + currentWriter = null + } + if (paths.nonEmpty) { + val fs = paths.head.getFileSystem(serializableConf.value) + paths.map(p => SinkFileStatus(fs.getFileStatus(p))) + } else Seq.empty + } catch { + case cause: Throwable => + logError("Aborting task.", cause) + // call failure callbacks first, so we could have a chance to cleanup the writer. + TaskContext.get().asInstanceOf[TaskContextImpl].markTaskFailed(cause) + throw new SparkException("Task failed while writing rows.", cause) + } finally { + if (currentWriter != null) { + currentWriter.close() + } + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala index b694b6155ad99..1ed5486f382a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala @@ -17,10 +17,7 @@ package org.apache.spark.sql.execution.streaming -import java.io.IOException -import java.nio.charset.StandardCharsets.UTF_8 - -import org.apache.hadoop.fs.{FileStatus, Path, PathFilter} +import org.apache.hadoop.fs.{FileStatus, Path} import org.json4s.NoTypeHints import org.json4s.jackson.Serialization import org.json4s.jackson.Serialization.{read, write} @@ -54,6 +51,19 @@ case class SinkFileStatus( } } +object SinkFileStatus { + def apply(f: FileStatus): SinkFileStatus = { + SinkFileStatus( + path = f.getPath.toUri.toString, + size = f.getLen, + isDir = f.isDirectory, + modificationTime = f.getModificationTime, + blockReplication = f.getReplication, + blockSize = f.getBlockSize, + action = FileStreamSinkLog.ADD_ACTION) + } +} + /** * A special log for [[FileStreamSink]]. It will write one log file for each batch. The first line * of the log file is the version number, and there are multiple JSON lines following. Each JSON @@ -66,213 +76,47 @@ case class SinkFileStatus( * When the reader uses `allFiles` to list all files, this method only returns the visible files * (drops the deleted files). 
*/ -class FileStreamSinkLog(sparkSession: SparkSession, path: String) - extends HDFSMetadataLog[Seq[SinkFileStatus]](sparkSession, path) { - - import FileStreamSinkLog._ +class FileStreamSinkLog( + metadataLogVersion: String, + sparkSession: SparkSession, + path: String) + extends CompactibleFileStreamLog[SinkFileStatus](metadataLogVersion, sparkSession, path) { private implicit val formats = Serialization.formats(NoTypeHints) - /** - * If we delete the old files after compaction at once, there is a race condition in S3: other - * processes may see the old files are deleted but still cannot see the compaction file using - * "list". The `allFiles` handles this by looking for the next compaction file directly, however, - * a live lock may happen if the compaction happens too frequently: one processing keeps deleting - * old files while another one keeps retrying. Setting a reasonable cleanup delay could avoid it. - */ - private val fileCleanupDelayMs = sparkSession.conf.get(SQLConf.FILE_SINK_LOG_CLEANUP_DELAY) + protected override val fileCleanupDelayMs = + sparkSession.conf.get(SQLConf.FILE_SINK_LOG_CLEANUP_DELAY) + + protected override val isDeletingExpiredLog = + sparkSession.conf.get(SQLConf.FILE_SINK_LOG_DELETION) - private val isDeletingExpiredLog = sparkSession.conf.get(SQLConf.FILE_SINK_LOG_DELETION) + protected override val compactInterval = + sparkSession.conf.get(SQLConf.FILE_SINK_LOG_COMPACT_INTERVAL) - private val compactInterval = sparkSession.conf.get(SQLConf.FILE_SINK_LOG_COMPACT_INTERVAL) require(compactInterval > 0, s"Please set ${SQLConf.FILE_SINK_LOG_COMPACT_INTERVAL.key} (was $compactInterval) " + "to a positive value.") - override def batchIdToPath(batchId: Long): Path = { - if (isCompactionBatch(batchId, compactInterval)) { - new Path(metadataPath, s"$batchId$COMPACT_FILE_SUFFIX") - } else { - new Path(metadataPath, batchId.toString) - } - } - - override def pathToBatchId(path: Path): Long = { - getBatchIdFromFileName(path.getName) - } - - override def isBatchFile(path: Path): Boolean = { - try { - getBatchIdFromFileName(path.getName) - true - } catch { - case _: NumberFormatException => false - } - } - - override def serialize(logData: Seq[SinkFileStatus]): Array[Byte] = { - (VERSION +: logData.map(write(_))).mkString("\n").getBytes(UTF_8) - } - - override def deserialize(bytes: Array[Byte]): Seq[SinkFileStatus] = { - val lines = new String(bytes, UTF_8).split("\n") - if (lines.length == 0) { - throw new IllegalStateException("Incomplete log file") - } - val version = lines(0) - if (version != VERSION) { - throw new IllegalStateException(s"Unknown log version: ${version}") - } - lines.toSeq.slice(1, lines.length).map(read[SinkFileStatus](_)) - } - - override def add(batchId: Long, logs: Seq[SinkFileStatus]): Boolean = { - if (isCompactionBatch(batchId, compactInterval)) { - compact(batchId, logs) - } else { - super.add(batchId, logs) - } + protected override def serializeData(data: SinkFileStatus): String = { + write(data) } - /** - * Returns all files except the deleted ones. - */ - def allFiles(): Array[SinkFileStatus] = { - var latestId = getLatest().map(_._1).getOrElse(-1L) - // There is a race condition when `FileStreamSink` is deleting old files and `StreamFileCatalog` - // is calling this method. This loop will retry the reading to deal with the - // race condition. 
- while (true) { - if (latestId >= 0) { - val startId = getAllValidBatches(latestId, compactInterval)(0) - try { - val logs = get(Some(startId), Some(latestId)).flatMap(_._2) - return compactLogs(logs).toArray - } catch { - case e: IOException => - // Another process using `FileStreamSink` may delete the batch files when - // `StreamFileCatalog` are reading. However, it only happens when a compaction is - // deleting old files. If so, let's try the next compaction batch and we should find it. - // Otherwise, this is a real IO issue and we should throw it. - latestId = nextCompactionBatchId(latestId, compactInterval) - get(latestId).getOrElse { - throw e - } - } - } else { - return Array.empty - } - } - Array.empty + protected override def deserializeData(encodedString: String): SinkFileStatus = { + read[SinkFileStatus](encodedString) } - /** - * Compacts all logs before `batchId` plus the provided `logs`, and writes them into the - * corresponding `batchId` file. It will delete expired files as well if enabled. - */ - private def compact(batchId: Long, logs: Seq[SinkFileStatus]): Boolean = { - val validBatches = getValidBatchesBeforeCompactionBatch(batchId, compactInterval) - val allLogs = validBatches.flatMap(batchId => get(batchId)).flatten ++ logs - if (super.add(batchId, compactLogs(allLogs))) { - if (isDeletingExpiredLog) { - deleteExpiredLog(batchId) - } - true + override def compactLogs(logs: Seq[SinkFileStatus]): Seq[SinkFileStatus] = { + val deletedFiles = logs.filter(_.action == FileStreamSinkLog.DELETE_ACTION).map(_.path).toSet + if (deletedFiles.isEmpty) { + logs } else { - // Return false as there is another writer. - false - } - } - - /** - * Since all logs before `compactionBatchId` are compacted and written into the - * `compactionBatchId` log file, they can be removed. However, due to the eventual consistency of - * S3, the compaction file may not be seen by other processes at once. So we only delete files - * created `fileCleanupDelayMs` milliseconds ago. - */ - private def deleteExpiredLog(compactionBatchId: Long): Unit = { - val expiredTime = System.currentTimeMillis() - fileCleanupDelayMs - fileManager.list(metadataPath, new PathFilter { - override def accept(path: Path): Boolean = { - try { - val batchId = getBatchIdFromFileName(path.getName) - batchId < compactionBatchId - } catch { - case _: NumberFormatException => - false - } - } - }).foreach { f => - if (f.getModificationTime <= expiredTime) { - fileManager.delete(f.getPath) - } + logs.filter(f => !deletedFiles.contains(f.path)) } } } object FileStreamSinkLog { val VERSION = "v1" - val COMPACT_FILE_SUFFIX = ".compact" val DELETE_ACTION = "delete" val ADD_ACTION = "add" - - def getBatchIdFromFileName(fileName: String): Long = { - fileName.stripSuffix(COMPACT_FILE_SUFFIX).toLong - } - - /** - * Returns if this is a compaction batch. FileStreamSinkLog will compact old logs every - * `compactInterval` commits. - * - * E.g., if `compactInterval` is 3, then 2, 5, 8, ... are all compaction batches. - */ - def isCompactionBatch(batchId: Long, compactInterval: Int): Boolean = { - (batchId + 1) % compactInterval == 0 - } - - /** - * Returns all valid batches before the specified `compactionBatchId`. They contain all logs we - * need to do a new compaction. - * - * E.g., if `compactInterval` is 3 and `compactionBatchId` is 5, this method should returns - * `Seq(2, 3, 4)` (Note: it includes the previous compaction batch 2). 
- */ - def getValidBatchesBeforeCompactionBatch( - compactionBatchId: Long, - compactInterval: Int): Seq[Long] = { - assert(isCompactionBatch(compactionBatchId, compactInterval), - s"$compactionBatchId is not a compaction batch") - (math.max(0, compactionBatchId - compactInterval)) until compactionBatchId - } - - /** - * Returns all necessary logs before `batchId` (inclusive). If `batchId` is a compaction, just - * return itself. Otherwise, it will find the previous compaction batch and return all batches - * between it and `batchId`. - */ - def getAllValidBatches(batchId: Long, compactInterval: Long): Seq[Long] = { - assert(batchId >= 0) - val start = math.max(0, (batchId + 1) / compactInterval * compactInterval - 1) - start to batchId - } - - /** - * Removes all deleted files from logs. It assumes once one file is deleted, it won't be added to - * the log in future. - */ - def compactLogs(logs: Seq[SinkFileStatus]): Seq[SinkFileStatus] = { - val deletedFiles = logs.filter(_.action == DELETE_ACTION).map(_.path).toSet - if (deletedFiles.isEmpty) { - logs - } else { - logs.filter(f => !deletedFiles.contains(f.path)) - } - } - - /** - * Returns the next compaction batch id after `batchId`. - */ - def nextCompactionBatchId(batchId: Long, compactInterval: Long): Long = { - (batchId + compactInterval + 1) / compactInterval * compactInterval - 1 - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala index 8e66538575b0c..c47033a07a19d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala @@ -17,35 +17,61 @@ package org.apache.spark.sql.execution.streaming -import scala.collection.mutable.ArrayBuffer +import scala.collection.JavaConverters._ import org.apache.hadoop.fs.Path +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging -import org.apache.spark.sql.{DataFrame, SparkSession} -import org.apache.spark.sql.types.{StringType, StructType} -import org.apache.spark.util.collection.OpenHashSet +import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} +import org.apache.spark.sql.execution.datasources.{DataSource, ListingFileCatalog, LogicalRelation} +import org.apache.spark.sql.types.StructType /** - * A very simple source that reads text files from the given directory as they appear. - * - * TODO Clean up the metadata files periodically + * A very simple source that reads files from the given directory as they appear. 
*/ class FileStreamSource( sparkSession: SparkSession, - metadataPath: String, path: String, + fileFormatClassName: String, override val schema: StructType, - dataFrameBuilder: Array[String] => DataFrame) extends Source with Logging { + partitionColumns: Seq[String], + metadataPath: String, + options: Map[String, String]) extends Source with Logging { + + import FileStreamSource._ + + private val sourceOptions = new FileStreamOptions(options) + + private val qualifiedBasePath: Path = { + val fs = new Path(path).getFileSystem(sparkSession.sessionState.newHadoopConf()) + fs.makeQualified(new Path(path)) // can contains glob patterns + } - private val fs = new Path(path).getFileSystem(sparkSession.sessionState.newHadoopConf()) - private val metadataLog = new HDFSMetadataLog[Seq[String]](sparkSession, metadataPath) + private val optionsWithPartitionBasePath = sourceOptions.optionMapWithoutPath ++ { + if (!SparkHadoopUtil.get.isGlobPath(new Path(path)) && options.contains("path")) { + Map("basePath" -> path) + } else { + Map() + }} + + private val metadataLog = + new FileStreamSourceLog(FileStreamSourceLog.VERSION, sparkSession, metadataPath) private var maxBatchId = metadataLog.getLatest().map(_._1).getOrElse(-1L) - private val seenFiles = new OpenHashSet[String] - metadataLog.get(None, Some(maxBatchId)).foreach { case (batchId, files) => - files.foreach(seenFiles.add) + /** Maximum number of new files to be considered in each batch */ + private val maxFilesPerBatch = sourceOptions.maxFilesPerTrigger + + /** A mapping from a file that we have processed to some timestamp it was last modified. */ + // Visible for testing and debugging in production. + val seenFiles = new SeenFilesMap(sourceOptions.maxFileAgeMs) + + metadataLog.allFiles().foreach { entry => + seenFiles.add(entry.path, entry.timestamp) } + seenFiles.purge() + + logInfo(s"maxFilesPerBatch = $maxFilesPerBatch, maxFileAge = ${sourceOptions.maxFileAgeMs}") /** * Returns the maximum offset that can be retrieved from the source. @@ -54,21 +80,35 @@ class FileStreamSource( * there is no race here, so the cost of `synchronized` should be rare. */ private def fetchMaxOffset(): LongOffset = synchronized { - val filesPresent = fetchAllFiles() - val newFiles = new ArrayBuffer[String]() - filesPresent.foreach { file => - if (!seenFiles.contains(file)) { - logDebug(s"new file: $file") - newFiles.append(file) - seenFiles.add(file) - } else { - logDebug(s"old file: $file") - } + // All the new files found - ignore aged files and files that we have seen. + val newFiles = fetchAllFiles().filter { + case (path, timestamp) => seenFiles.isNewFile(path, timestamp) } - if (newFiles.nonEmpty) { + // Obey user's setting to limit the number of files in this batch trigger. 
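`fetchMaxOffset` now keeps only files the source has not seen yet and, when `maxFilesPerTrigger` is configured, caps the batch with a plain `take`. Here is a small sketch of that selection step; it uses `(path, timestamp)` pairs and a simple set of seen paths in place of the real `SeenFilesMap`, so the age check is simplified away, and `BatchSelectionSketch` is an illustrative name only.

```scala
object BatchSelectionSketch {
  // Pick the files for the next batch: drop already-seen paths, then cap the
  // batch size when a maxFilesPerTrigger-style limit is configured.
  def selectBatch(
      listed: Seq[(String, Long)],
      seen: Set[String],
      maxFilesPerBatch: Option[Int]): Seq[(String, Long)] = {
    val newFiles = listed.filter { case (path, _) => !seen.contains(path) }
    maxFilesPerBatch.map(n => newFiles.take(n)).getOrElse(newFiles)
  }

  def main(args: Array[String]): Unit = {
    val listed = Seq(("f1", 1L), ("f2", 2L), ("f3", 3L), ("f4", 4L))
    // "f1" was processed earlier; with a limit of 2 only f2 and f3 are selected.
    println(selectBatch(listed, Set("f1"), Some(2)))
  }
}
```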
+ val batchFiles = + if (maxFilesPerBatch.nonEmpty) newFiles.take(maxFilesPerBatch.get) else newFiles + + batchFiles.foreach { file => + seenFiles.add(file._1, file._2) + logDebug(s"New file: $file") + } + val numPurged = seenFiles.purge() + + logTrace( + s""" + |Number of new files = ${newFiles.size} + |Number of files selected for batch = ${batchFiles.size} + |Number of seen files = ${seenFiles.size} + |Number of files purged from tracking map = $numPurged + """.stripMargin) + + if (batchFiles.nonEmpty) { maxBatchId += 1 - metadataLog.add(maxBatchId, newFiles) + metadataLog.add(maxBatchId, batchFiles.map { case (path, timestamp) => + FileEntry(path = path, timestamp = timestamp, batchId = maxBatchId) + }.toArray) + logInfo(s"Max batch id increased to $maxBatchId with ${batchFiles.size} new files") } new LongOffset(maxBatchId) @@ -95,23 +135,122 @@ class FileStreamSource( val endId = end.asInstanceOf[LongOffset].offset assert(startId <= endId) - val files = metadataLog.get(Some(startId + 1), Some(endId)).map(_._2).flatten + val files = metadataLog.get(Some(startId + 1), Some(endId)).flatMap(_._2) logInfo(s"Processing ${files.length} files from ${startId + 1}:$endId") - logDebug(s"Streaming ${files.mkString(", ")}") - dataFrameBuilder(files) + logTrace(s"Files are:\n\t" + files.mkString("\n\t")) + val newDataSource = + DataSource( + sparkSession, + paths = files.map(_.path), + userSpecifiedSchema = Some(schema), + partitionColumns = partitionColumns, + className = fileFormatClassName, + options = optionsWithPartitionBasePath) + Dataset.ofRows(sparkSession, LogicalRelation(newDataSource.resolveRelation( + checkPathExist = false))) } - private def fetchAllFiles(): Seq[String] = { - val startTime = System.nanoTime() - val files = fs.listStatus(new Path(path)) - .filterNot(_.getPath.getName.startsWith("_")) - .map(_.getPath.toUri.toString) - val endTime = System.nanoTime() - logDebug(s"Listed ${files.size} in ${(endTime.toDouble - startTime) / 1000000}ms") + /** + * Returns a list of files found, sorted by their timestamp. + */ + private def fetchAllFiles(): Seq[(String, Long)] = { + val startTime = System.nanoTime + val globbedPaths = SparkHadoopUtil.get.globPathIfNecessary(qualifiedBasePath) + val catalog = new ListingFileCatalog(sparkSession, globbedPaths, options, Some(new StructType)) + val files = catalog.allFiles().sortBy(_.getModificationTime).map { status => + (status.getPath.toUri.toString, status.getModificationTime) + } + val endTime = System.nanoTime + val listingTimeMs = (endTime.toDouble - startTime) / 1000000 + if (listingTimeMs > 2000) { + // Output a warning when listing files uses more than 2 seconds. + logWarning(s"Listed ${files.size} file(s) in $listingTimeMs ms") + } else { + logTrace(s"Listed ${files.size} file(s) in $listingTimeMs ms") + } + logTrace(s"Files are:\n\t" + files.mkString("\n\t")) files } override def getOffset: Option[Offset] = Some(fetchMaxOffset()).filterNot(_.offset == -1) - override def toString: String = s"FileSource[$path]" + override def toString: String = s"FileStreamSource[$qualifiedBasePath]" + + /** + * Informs the source that Spark has completed processing all data for offsets less than or + * equal to `end` and will only request offsets greater than `end` in the future. + */ + override def commit(end: Offset): Unit = { + // No-op for now; FileStreamSource currently garbage-collects files based on timestamp + // and the value of the maxFileAge parameter. 
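The `SeenFilesMap` defined just below keeps one timestamp per tracked file and forgets entries older than `latestTimestamp - maxAgeMs`; a file counts as "new" only if it is untracked and not older than the last purge threshold, so a file cannot slip back in right after a purge. A small standalone sketch of that bookkeeping, assuming millisecond timestamps (`SeenFilesSketch` is illustrative, not the class in the patch):

```scala
import scala.collection.mutable

// Track path -> timestamp, purge entries older than (latest - maxAgeMs), and
// treat a file as new only if it is untracked and not already aged out.
class SeenFilesSketch(maxAgeMs: Long) {
  private val map = mutable.HashMap.empty[String, Long]
  private var latestTimestamp = 0L
  private var lastPurgeTimestamp = 0L

  def add(path: String, timestamp: Long): Unit = {
    map(path) = timestamp
    if (timestamp > latestTimestamp) latestTimestamp = timestamp
  }

  def isNewFile(path: String, timestamp: Long): Boolean =
    timestamp >= lastPurgeTimestamp && !map.contains(path)

  def purge(): Int = {
    lastPurgeTimestamp = latestTimestamp - maxAgeMs
    val before = map.size
    map.retain((_, ts) => ts >= lastPurgeTimestamp)
    before - map.size
  }
}

object SeenFilesSketch {
  def main(args: Array[String]): Unit = {
    val seen = new SeenFilesSketch(maxAgeMs = 100L)
    seen.add("old", 10L)
    seen.add("recent", 500L)
    println(seen.purge())                      // 1: "old" is dropped
    println(seen.isNewFile("old", 10L))        // false: older than the purge threshold
    println(seen.isNewFile("brand-new", 600L)) // true
  }
}
```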
+ } + + override def stop() {} +} + + +object FileStreamSource { + + /** Timestamp for file modification time, in ms since January 1, 1970 UTC. */ + type Timestamp = Long + + case class FileEntry(path: String, timestamp: Timestamp, batchId: Long) extends Serializable + + /** + * A custom hash map used to track the list of files seen. This map is not thread-safe. + * + * To prevent the hash map from growing indefinitely, a purge function is available to + * remove files "maxAgeMs" older than the latest file. + */ + class SeenFilesMap(maxAgeMs: Long) { + require(maxAgeMs >= 0) + + /** Mapping from file to its timestamp. */ + private val map = new java.util.HashMap[String, Timestamp] + + /** Timestamp of the latest file. */ + private var latestTimestamp: Timestamp = 0L + + /** Timestamp for the last purge operation. */ + private var lastPurgeTimestamp: Timestamp = 0L + + /** Add a new file to the map. */ + def add(path: String, timestamp: Timestamp): Unit = { + map.put(path, timestamp) + if (timestamp > latestTimestamp) { + latestTimestamp = timestamp + } + } + + /** + * Returns true if we should consider this file a new file. The file is only considered "new" + * if it is new enough that we are still tracking, and we have not seen it before. + */ + def isNewFile(path: String, timestamp: Timestamp): Boolean = { + // Note that we are testing against lastPurgeTimestamp here so we'd never miss a file that + // is older than (latestTimestamp - maxAgeMs) but has not been purged yet. + timestamp >= lastPurgeTimestamp && !map.containsKey(path) + } + + /** Removes aged entries and returns the number of files removed. */ + def purge(): Int = { + lastPurgeTimestamp = latestTimestamp - maxAgeMs + val iter = map.entrySet().iterator() + var count = 0 + while (iter.hasNext) { + val entry = iter.next() + if (entry.getValue < lastPurgeTimestamp) { + count += 1 + iter.remove() + } + } + count + } + + def size: Int = map.size() + + def allEntries: Seq[(String, Timestamp)] = { + map.asScala.toSeq + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceLog.scala new file mode 100644 index 0000000000000..8103309aff2ae --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceLog.scala @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming + +import java.util.{LinkedHashMap => JLinkedHashMap} +import java.util.Map.Entry + +import scala.collection.mutable + +import org.json4s.NoTypeHints +import org.json4s.jackson.Serialization + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.execution.streaming.FileStreamSource.FileEntry +import org.apache.spark.sql.internal.SQLConf + +class FileStreamSourceLog( + metadataLogVersion: String, + sparkSession: SparkSession, + path: String) + extends CompactibleFileStreamLog[FileEntry](metadataLogVersion, sparkSession, path) { + + import CompactibleFileStreamLog._ + + // Configurations about metadata compaction + protected override val compactInterval = + sparkSession.conf.get(SQLConf.FILE_SOURCE_LOG_COMPACT_INTERVAL) + require(compactInterval > 0, + s"Please set ${SQLConf.FILE_SOURCE_LOG_COMPACT_INTERVAL.key} (was $compactInterval) to a " + + s"positive value.") + + protected override val fileCleanupDelayMs = + sparkSession.conf.get(SQLConf.FILE_SOURCE_LOG_CLEANUP_DELAY) + + protected override val isDeletingExpiredLog = + sparkSession.conf.get(SQLConf.FILE_SOURCE_LOG_DELETION) + + private implicit val formats = Serialization.formats(NoTypeHints) + + // A fixed size log entry cache to cache the file entries belong to the compaction batch. It is + // used to avoid scanning the compacted log file to retrieve it's own batch data. + private val cacheSize = compactInterval + private val fileEntryCache = new JLinkedHashMap[Long, Array[FileEntry]] { + override def removeEldestEntry(eldest: Entry[Long, Array[FileEntry]]): Boolean = { + size() > cacheSize + } + } + + protected override def serializeData(data: FileEntry): String = { + Serialization.write(data) + } + + protected override def deserializeData(encodedString: String): FileEntry = { + Serialization.read[FileEntry](encodedString) + } + + def compactLogs(logs: Seq[FileEntry]): Seq[FileEntry] = { + logs + } + + override def add(batchId: Long, logs: Array[FileEntry]): Boolean = { + if (super.add(batchId, logs)) { + if (isCompactionBatch(batchId, compactInterval)) { + fileEntryCache.put(batchId, logs) + } + true + } else { + false + } + } + + override def get(startId: Option[Long], endId: Option[Long]): Array[(Long, Array[FileEntry])] = { + val startBatchId = startId.getOrElse(0L) + val endBatchId = getLatest().map(_._1).getOrElse(0L) + + val (existedBatches, removedBatches) = (startBatchId to endBatchId).map { id => + if (isCompactionBatch(id, compactInterval) && fileEntryCache.containsKey(id)) { + (id, Some(fileEntryCache.get(id))) + } else { + val logs = super.get(id).map(_.filter(_.batchId == id)) + (id, logs) + } + }.partition(_._2.isDefined) + + // The below code may only be happened when original metadata log file has been removed, so we + // have to get the batch from latest compacted log file. This is quite time-consuming and may + // not be happened in the current FileStreamSource code path, since we only fetch the + // latest metadata log file. 
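The `fileEntryCache` above bounds itself by overriding `removeEldestEntry` on a `java.util.LinkedHashMap`: once more than `compactInterval` compaction batches have been inserted, the next `put` evicts the oldest one. A hypothetical generic helper showing the same idiom (`BoundedCacheSketch` is not part of the patch):

```scala
import java.util.{LinkedHashMap => JLinkedHashMap}
import java.util.Map.Entry

// A java.util.LinkedHashMap that evicts its oldest entry once the configured
// capacity is exceeded, mirroring the fileEntryCache construction above.
object BoundedCacheSketch {
  def boundedCache[K, V](capacity: Int): JLinkedHashMap[K, V] =
    new JLinkedHashMap[K, V] {
      override def removeEldestEntry(eldest: Entry[K, V]): Boolean =
        size() > capacity
    }

  def main(args: Array[String]): Unit = {
    val cache = boundedCache[Long, String](capacity = 2)
    cache.put(1L, "batch-1")
    cache.put(2L, "batch-2")
    cache.put(3L, "batch-3")       // evicts the entry for batch 1
    println(cache.containsKey(1L)) // false
    println(cache.keySet())        // [2, 3]
  }
}
```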
+ val searchKeys = removedBatches.map(_._1) + val retrievedBatches = if (searchKeys.nonEmpty) { + logWarning(s"Get batches from removed files, this is unexpected in the current code path!!!") + val latestBatchId = getLatest().map(_._1).getOrElse(-1L) + if (latestBatchId < 0) { + Map.empty[Long, Option[Array[FileEntry]]] + } else { + val latestCompactedBatchId = getAllValidBatches(latestBatchId, compactInterval)(0) + val allLogs = new mutable.HashMap[Long, mutable.ArrayBuffer[FileEntry]] + + super.get(latestCompactedBatchId).foreach { entries => + entries.foreach { e => + allLogs.put(e.batchId, allLogs.getOrElse(e.batchId, mutable.ArrayBuffer()) += e) + } + } + + searchKeys.map(id => id -> allLogs.get(id).map(_.toArray)).filter(_._2.isDefined).toMap + } + } else { + Map.empty[Long, Option[Array[FileEntry]]] + } + + (existedBatches ++ retrievedBatches).map(i => i._1 -> i._2.get).toArray.sortBy(_._1) + } +} + +object FileStreamSourceLog { + val VERSION = "v1" +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala new file mode 100644 index 0000000000000..24f98b9211f12 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.TaskContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Dataset, Encoder, ForeachWriter} +import org.apache.spark.sql.catalyst.plans.logical.CatalystSerde + +/** + * A [[Sink]] that forwards all data into [[ForeachWriter]] according to the contract defined by + * [[ForeachWriter]]. + * + * @param writer The [[ForeachWriter]] to process all data. + * @tparam T The expected type of the sink. + */ +class ForeachSink[T : Encoder](writer: ForeachWriter[T]) extends Sink with Serializable { + + override def addBatch(batchId: Long, data: DataFrame): Unit = { + // TODO: Refine this method when SPARK-16264 is resolved; see comments below. + + // This logic should've been as simple as: + // ``` + // data.as[T].foreachPartition { iter => ... } + // ``` + // + // Unfortunately, doing that would just break the incremental planing. The reason is, + // `Dataset.foreachPartition()` would further call `Dataset.rdd()`, but `Dataset.rdd()` just + // does not support `IncrementalExecution`. + // + // So as a provisional fix, below we've made a special version of `Dataset` with its `rdd()` + // method supporting incremental planning. But in the long run, we should generally make newly + // created Datasets use `IncrementalExecution` where necessary (which is SPARK-16264 tries to + // resolve). 
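The partition loop that follows drives the user-supplied `ForeachWriter`: `open` is called once per partition with the batch id, `process` once per row, and `close` exactly once with either the failure or `null` (also when `open` declined the partition). Below is a sketch of a writer honoring that contract, assuming the standard `ForeachWriter[T]` API; it only prints rows, so it stands in for whatever external system a real writer would talk to.

```scala
import org.apache.spark.sql.ForeachWriter

// A ForeachWriter obeying the contract ForeachSink drives: open() decides
// whether this partition/epoch should be written, process() handles each row,
// close() is always invoked exactly once afterwards.
class PrintlnWriter extends ForeachWriter[String] {
  override def open(partitionId: Long, version: Long): Boolean = {
    // `version` is the batch id passed by the sink. Returning false skips
    // process() for this partition; close(null) is still invoked.
    true
  }

  override def process(value: String): Unit = {
    println(s"row: $value")
  }

  override def close(errorOrNull: Throwable): Unit = {
    if (errorOrNull != null) {
      // The failure that interrupted process(), if any.
      errorOrNull.printStackTrace()
    }
  }
}
```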
+ + val datasetWithIncrementalExecution = + new Dataset(data.sparkSession, data.logicalPlan, implicitly[Encoder[T]]) { + override lazy val rdd: RDD[T] = { + val objectType = exprEnc.deserializer.dataType + val deserialized = CatalystSerde.deserialize[T](logicalPlan) + + // was originally: sparkSession.sessionState.executePlan(deserialized) ... + val incrementalExecution = new IncrementalExecution( + this.sparkSession, + deserialized, + data.queryExecution.asInstanceOf[IncrementalExecution].outputMode, + data.queryExecution.asInstanceOf[IncrementalExecution].checkpointLocation, + data.queryExecution.asInstanceOf[IncrementalExecution].currentBatchId) + incrementalExecution.toRdd.mapPartitions { rows => + rows.map(_.get(0, objectType)) + }.asInstanceOf[RDD[T]] + } + } + datasetWithIncrementalExecution.foreachPartition { iter => + if (writer.open(TaskContext.getPartitionId(), batchId)) { + try { + while (iter.hasNext) { + writer.process(iter.next()) + } + } catch { + case e: Throwable => + writer.close(e) + throw e + } + writer.close(null) + } else { + writer.close(null) + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala index 9fe06a6c36cb9..c7235320fd6bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.execution.streaming -import java.io.{FileNotFoundException, IOException} -import java.nio.ByteBuffer +import java.io.{FileNotFoundException, InputStream, IOException, OutputStream} import java.util.{ConcurrentModificationException, EnumSet, UUID} import scala.reflect.ClassTag @@ -29,9 +28,9 @@ import org.apache.hadoop.fs._ import org.apache.hadoop.fs.permission.FsPermission import org.apache.spark.internal.Logging -import org.apache.spark.network.util.JavaUtils import org.apache.spark.serializer.JavaSerializer import org.apache.spark.sql.SparkSession +import org.apache.spark.util.UninterruptibleThread /** @@ -48,6 +47,10 @@ import org.apache.spark.sql.SparkSession class HDFSMetadataLog[T: ClassTag](sparkSession: SparkSession, path: String) extends MetadataLog[T] with Logging { + // Avoid serializing generic sequences, see SPARK-17372 + require(implicitly[ClassTag[T]].runtimeClass != classOf[Seq[_]], + "Should not create a log with type Seq, use Arrays instead - see SPARK-17372") + import HDFSMetadataLog._ val metadataPath = new Path(path) @@ -83,26 +86,42 @@ class HDFSMetadataLog[T: ClassTag](sparkSession: SparkSession, path: String) } } - protected def serialize(metadata: T): Array[Byte] = { - JavaUtils.bufferToArray(serializer.serialize(metadata)) + protected def serialize(metadata: T, out: OutputStream): Unit = { + // called inside a try-finally where the underlying stream is closed in the caller + val outStream = serializer.serializeStream(out) + outStream.writeObject(metadata) } - protected def deserialize(bytes: Array[Byte]): T = { - serializer.deserialize[T](ByteBuffer.wrap(bytes)) + protected def deserialize(in: InputStream): T = { + // called inside a try-finally where the underlying stream is closed in the caller + val inStream = serializer.deserializeStream(in) + inStream.readObject[T]() } + /** + * Store the metadata for the specified batchId and return `true` if successful. 
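The log now serializes directly to the output stream instead of materializing a byte array first, and it refuses `Seq`-typed logs per SPARK-17372 (arrays are required instead). The sketch below shows the same stream-based round trip using plain `java.io` object streams so it stands alone; the real code goes through Spark's `JavaSerializer`, and `BatchMetadata` is a made-up payload type.

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, InputStream,
  ObjectInputStream, ObjectOutputStream, OutputStream}

object StreamSerializationSketch {
  // Made-up payload; note it carries an Array, not a Seq.
  case class BatchMetadata(batchId: Long, files: Array[String])

  // Write the object straight to the caller-provided stream.
  def serialize(metadata: BatchMetadata, out: OutputStream): Unit = {
    val objOut = new ObjectOutputStream(out)
    objOut.writeObject(metadata)
    objOut.flush()
  }

  // Read one object back from the caller-provided stream.
  def deserialize(in: InputStream): BatchMetadata = {
    val objIn = new ObjectInputStream(in)
    objIn.readObject().asInstanceOf[BatchMetadata]
  }

  def main(args: Array[String]): Unit = {
    val out = new ByteArrayOutputStream()
    serialize(BatchMetadata(0L, Array("f1", "f2")), out)
    val roundTripped = deserialize(new ByteArrayInputStream(out.toByteArray))
    println(roundTripped.batchId)              // 0
    println(roundTripped.files.mkString(","))  // f1,f2
  }
}
```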
If the batchId's + * metadata has already been stored, this method will return `false`. + * + * Note that this method must be called on a [[org.apache.spark.util.UninterruptibleThread]] + * so that interrupts can be disabled while writing the batch file. This is because there is a + * potential dead-lock in Hadoop "Shell.runCommand" before 2.5.0 (HADOOP-10622). If the thread + * running "Shell.runCommand" is interrupted, then the thread can get deadlocked. In our + * case, `writeBatch` creates a file using HDFS API and calls "Shell.runCommand" to set the + * file permissions, and can get deadlocked if the stream execution thread is stopped by + * interrupt. Hence, we make sure that this method is called on [[UninterruptibleThread]] which + * allows us to disable interrupts here. Also see SPARK-14131. + */ override def add(batchId: Long, metadata: T): Boolean = { get(batchId).map(_ => false).getOrElse { - // Only write metadata when the batch has not yet been written. - try { - writeBatch(batchId, serialize(metadata)) - true - } catch { - case e: IOException if "java.lang.InterruptedException" == e.getMessage => - // create may convert InterruptedException to IOException. Let's convert it back to - // InterruptedException so that this failure won't crash StreamExecution - throw new InterruptedException("Creating file is interrupted") + // Only write metadata when the batch has not yet been written + Thread.currentThread match { + case ut: UninterruptibleThread => + ut.runUninterruptibly { writeBatch(batchId, metadata, serialize) } + case _ => + throw new IllegalStateException( + "HDFSMetadataLog.add() must be executed on a o.a.spark.util.UninterruptibleThread") } + true } } @@ -112,7 +131,7 @@ class HDFSMetadataLog[T: ClassTag](sparkSession: SparkSession, path: String) * There may be multiple [[HDFSMetadataLog]] using the same metadata path. Although it is not a * valid behavior, we still need to prevent it from destroying the files. */ - private def writeBatch(batchId: Long, bytes: Array[Byte]): Unit = { + private def writeBatch(batchId: Long, metadata: T, writer: (T, OutputStream) => Unit): Unit = { // Use nextId to create a temp file var nextId = 0 while (true) { @@ -120,9 +139,9 @@ class HDFSMetadataLog[T: ClassTag](sparkSession: SparkSession, path: String) try { val output = fileManager.create(tempPath) try { - output.write(bytes) + writer(metadata, output) } finally { - output.close() + IOUtils.closeQuietly(output) } try { // Try to commit the batch @@ -167,7 +186,7 @@ class HDFSMetadataLog[T: ClassTag](sparkSession: SparkSession, path: String) private def isFileAlreadyExistsException(e: IOException): Boolean = { e.isInstanceOf[FileAlreadyExistsException] || // Old Hadoop versions don't throw FileAlreadyExistsException. Although it's fixed in - // HADOOP-9361, we still need to support old Hadoop versions. + // HADOOP-9361 in Hadoop 2.5, we still need to support old Hadoop versions. 
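`writeBatch` writes the batch to a temporary file and then commits it under the batch file's final name, and `isFileAlreadyExistsException` treats an already-existing destination as a sign that a concurrent writer won the race. A simplified local-filesystem sketch of that write-then-commit pattern using `java.nio` (the real code goes through the Hadoop `FileManager` abstraction):

```scala
import java.nio.charset.StandardCharsets
import java.nio.file.{FileAlreadyExistsException, Files, Path}

object AtomicBatchWriteSketch {
  // Write to a temp file first; the batch file only appears under its final
  // name via the move, and an existing destination means another writer
  // already committed this batch.
  def writeBatch(dir: Path, batchId: Long, payload: String): Boolean = {
    val temp = Files.createTempFile(dir, s".$batchId", ".tmp")
    Files.write(temp, payload.getBytes(StandardCharsets.UTF_8))
    val target = dir.resolve(batchId.toString)
    try {
      Files.move(temp, target)
      true
    } catch {
      case _: FileAlreadyExistsException =>
        Files.deleteIfExists(temp)
        false // another writer committed this batch first
    }
  }

  def main(args: Array[String]): Unit = {
    val dir = Files.createTempDirectory("metadata-log")
    println(writeBatch(dir, 0L, "offsets for batch 0")) // true
    println(writeBatch(dir, 0L, "offsets for batch 0")) // false: already committed
  }
}
```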
(e.getMessage != null && e.getMessage.startsWith("File already exists: ")) } @@ -175,8 +194,11 @@ class HDFSMetadataLog[T: ClassTag](sparkSession: SparkSession, path: String) val batchMetadataFile = batchIdToPath(batchId) if (fileManager.exists(batchMetadataFile)) { val input = fileManager.open(batchMetadataFile) - val bytes = IOUtils.toByteArray(input) - Some(deserialize(bytes)) + try { + Some(deserialize(input)) + } finally { + IOUtils.closeQuietly(input) + } } else { logDebug(s"Unable to find batch $batchMetadataFile") None @@ -210,14 +232,29 @@ class HDFSMetadataLog[T: ClassTag](sparkSession: SparkSession, path: String) None } + /** + * Removes all the log entry earlier than thresholdBatchId (exclusive). + */ + override def purge(thresholdBatchId: Long): Unit = { + val batchIds = fileManager.list(metadataPath, batchFilesFilter) + .map(f => pathToBatchId(f.getPath)) + + for (batchId <- batchIds if batchId < thresholdBatchId) { + val path = batchIdToPath(batchId) + fileManager.delete(path) + logTrace(s"Removed metadata log file: $path") + } + } + private def createFileManager(): FileManager = { val hadoopConf = sparkSession.sessionState.newHadoopConf() try { new FileContextManager(metadataPath, hadoopConf) } catch { case e: UnsupportedFileSystemException => - logWarning("Could not use FileContext API for managing metadata log file. The log may be" + - "inconsistent under failures.", e) + logWarning("Could not use FileContext API for managing metadata log files at path " + + s"$metadataPath. Using FileSystem API instead for managing log files. The log may be " + + s"inconsistent under failures.") new FileSystemManager(metadataPath, hadoopConf) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index b89144d727514..05294df2673dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.execution.streaming -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.analysis.OutputMode +import org.apache.spark.sql.{InternalOutputModes, SparkSession} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SparkPlanner, UnaryExecNode} +import org.apache.spark.sql.streaming.OutputMode /** * A variant of [[QueryExecution]] that allows the execution of the given [[LogicalPlan]] @@ -30,13 +30,15 @@ import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SparkPlanner, class IncrementalExecution( sparkSession: SparkSession, logicalPlan: LogicalPlan, - outputMode: OutputMode, - checkpointLocation: String, - currentBatchId: Long) + val outputMode: OutputMode, + val checkpointLocation: String, + val currentBatchId: Long) extends QueryExecution(sparkSession, logicalPlan) { // TODO: make this always part of planning. - val stateStrategy = sparkSession.sessionState.planner.StatefulAggregationStrategy :: Nil + val stateStrategy = sparkSession.sessionState.planner.StatefulAggregationStrategy +: + sparkSession.sessionState.planner.StreamingRelationStrategy +: + sparkSession.sessionState.experimentalMethods.extraStrategies // Modified planner with stateful operations. 
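The new `purge(thresholdBatchId)` deletes every batch file whose id is strictly below the threshold (the threshold itself survives) and is meant to be safe to call repeatedly. A tiny in-memory sketch of that contract, not the HDFS-backed implementation:

```scala
import scala.collection.mutable

// In-memory stand-in for a metadata log, showing the exclusive, idempotent
// purge(threshold) contract added in this patch.
class InMemoryLogSketch[T] {
  private val batches = mutable.HashMap.empty[Long, T]

  def add(batchId: Long, metadata: T): Boolean =
    if (batches.contains(batchId)) false else { batches(batchId) = metadata; true }

  // Exclusive: only ids strictly below the threshold are removed.
  def purge(thresholdBatchId: Long): Unit =
    batches.keys.filter(_ < thresholdBatchId).toList.foreach(batches.remove)

  def batchIds: List[Long] = batches.keys.toList.sorted
}

object InMemoryLogSketch {
  def main(args: Array[String]): Unit = {
    val log = new InMemoryLogSketch[String]
    (0L to 4L).foreach(id => log.add(id, s"offsets-$id"))
    log.purge(3L)
    println(log.batchIds) // List(3, 4): 0..2 removed, the threshold batch survives
    log.purge(3L)         // idempotent: second call changes nothing
    println(log.batchIds) // List(3, 4)
  }
}
```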
override def planner: SparkPlanner = @@ -47,22 +49,25 @@ class IncrementalExecution( /** * Records the current id for a given stateful operator in the query plan as the `state` - * preperation walks the query plan. + * preparation walks the query plan. */ private var operatorId = 0 /** Locates save/restore pairs surrounding aggregation. */ val state = new Rule[SparkPlan] { + override def apply(plan: SparkPlan): SparkPlan = plan transform { - case StateStoreSaveExec(keys, None, + case StateStoreSaveExec(keys, None, None, UnaryExecNode(agg, StateStoreRestoreExec(keys2, None, child))) => - val stateId = OperatorStateId(checkpointLocation, operatorId, currentBatchId - 1) + val stateId = OperatorStateId(checkpointLocation, operatorId, currentBatchId) + val returnAllStates = if (outputMode == InternalOutputModes.Complete) true else false operatorId += 1 StateStoreSaveExec( keys, Some(stateId), + Some(returnAllStates), agg.withNewChildren( StateStoreRestoreExec( keys, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala index bb176408d8f59..c5e8827777792 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala @@ -22,12 +22,6 @@ package org.apache.spark.sql.execution.streaming */ case class LongOffset(offset: Long) extends Offset { - override def compareTo(other: Offset): Int = other match { - case l: LongOffset => offset.compareTo(l.offset) - case _ => - throw new IllegalArgumentException(s"Invalid comparison of $getClass with ${other.getClass}") - } - def +(increment: Long): LongOffset = new LongOffset(offset + increment) def -(decrement: Long): LongOffset = new LongOffset(offset - decrement) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLog.scala index cc70e1d314d1d..9e2604c9c069f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLog.scala @@ -24,6 +24,7 @@ package org.apache.spark.sql.execution.streaming * - Allow the user to query the latest batch id. * - Allow the user to query the metadata object of a specified batch id. * - Allow the user to query metadata objects in a range of batch ids. + * - Allow the user to remove obsolete metadata */ trait MetadataLog[T] { @@ -48,4 +49,10 @@ trait MetadataLog[T] { * Return the latest batch Id and its metadata if exist. */ def getLatest(): Option[(Long, T)] + + /** + * Removes all the log entry earlier than thresholdBatchId (exclusive). + * This operation should be idempotent. + */ + def purge(thresholdBatchId: Long): Unit } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileCatalog.scala new file mode 100644 index 0000000000000..a32c4671e3475 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileCatalog.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import scala.collection.mutable + +import org.apache.hadoop.fs.{FileStatus, Path} + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.execution.datasources._ + + +/** + * A [[FileCatalog]] that generates the list of files to processing by reading them from the + * metadata log files generated by the [[FileStreamSink]]. + */ +class MetadataLogFileCatalog(sparkSession: SparkSession, path: Path) + extends PartitioningAwareFileCatalog(sparkSession, Map.empty, None) { + + private val metadataDirectory = new Path(path, FileStreamSink.metadataDir) + logInfo(s"Reading streaming file log from $metadataDirectory") + private val metadataLog = + new FileStreamSinkLog(FileStreamSinkLog.VERSION, sparkSession, metadataDirectory.toUri.toString) + private val allFilesFromLog = metadataLog.allFiles().map(_.toFileStatus).filterNot(_.isDirectory) + private var cachedPartitionSpec: PartitionSpec = _ + + override protected val leafFiles: mutable.LinkedHashMap[Path, FileStatus] = { + new mutable.LinkedHashMap ++= allFilesFromLog.map(f => f.getPath -> f) + } + + override protected val leafDirToChildrenFiles: Map[Path, Array[FileStatus]] = { + allFilesFromLog.toArray.groupBy(_.getPath.getParent) + } + + override def paths: Seq[Path] = path :: Nil + + override def refresh(): Unit = { } + + override def partitionSpec(): PartitionSpec = { + if (cachedPartitionSpec == null) { + cachedPartitionSpec = inferPartitioning() + } + cachedPartitionSpec + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.scala index 0f5d6445b1e2b..1f52abf277581 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.scala @@ -18,20 +18,9 @@ package org.apache.spark.sql.execution.streaming /** - * A offset is a monotonically increasing metric used to track progress in the computation of a - * stream. An [[Offset]] must be comparable, and the result of `compareTo` must be consistent - * with `equals` and `hashcode`. + * An offset is a monotonically increasing metric used to track progress in the computation of a + * stream. Since offsets are retrieved from a [[Source]] by a single thread, we know the global + * ordering of two [[Offset]] instances. We do assume that if two offsets are `equal` then no + * new data has arrived. */ -trait Offset extends Serializable { - - /** - * Returns a negative integer, zero, or a positive integer as this object is less than, equal to, - * or greater than the specified object. 
- */ - def compareTo(other: Offset): Int - - def >(other: Offset): Boolean = compareTo(other) > 0 - def <(other: Offset): Boolean = compareTo(other) < 0 - def <=(other: Offset): Boolean = compareTo(other) <= 0 - def >=(other: Offset): Boolean = compareTo(other) >= 0 -} +trait Offset extends Serializable {} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Sink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Sink.scala index e641e09b56adf..2571b59be54f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Sink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Sink.scala @@ -30,6 +30,9 @@ trait Sink { * Adds a batch of data to this sink. The data for a given `batchId` is deterministic and if * this method is called more than once with the same batchId (which will happen in the case of * failures), then `data` should only be added once. + * + * Note: You cannot apply any operators on `data` except consuming it (e.g., `collect/foreach`). + * Otherwise, you may get a wrong result. */ def addBatch(batchId: Long, data: DataFrame): Unit } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala index 14450c2e2fd14..f3bd5bfe23fdf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala @@ -30,13 +30,30 @@ trait Source { /** Returns the schema of the data from this source */ def schema: StructType - /** Returns the maximum available offset for this source. */ + /** + * Returns the maximum available offset for this source. + * Returns `None` if this source has never received any data. + */ def getOffset: Option[Offset] /** - * Returns the data that is between the offsets (`start`, `end`]. When `start` is `None` then - * the batch should begin with the first available record. This method must always return the - * same data for a particular `start` and `end` pair. + * Returns the data that is between the offsets (`start`, `end`]. When `start` is `None`, + * then the batch should begin with the first record. This method must always return the + * same data for a particular `start` and `end` pair; even after the Source has been restarted + * on a different node. + * + * Higher layers will always call this method with a value of `start` greater than or equal + * to the last value passed to `commit` and a value of `end` less than or equal to the + * last value returned by `getOffset` */ def getBatch(start: Option[Offset], end: Offset): DataFrame + + /** + * Informs the source that Spark has completed processing all data for offsets less than or + * equal to `end` and will only request offsets greater than `end` in the future. + */ + def commit(end: Offset) : Unit = {} + + /** Stop this source and free any resources it has allocated. 
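The expanded `Source` contract reads: `getOffset` returns `None` until the source has ever seen data, `getBatch(start, end)` must deterministically return the rows in `(start, end]` even across restarts, `commit(end)` lets the source discard everything at or below `end`, and `stop()` releases resources. Below is a toy illustration of that contract with a `Long` offset and `Seq[String]` batches standing in for `Offset` and `DataFrame`; `ToyMemorySource` is not a real Spark class.

```scala
import scala.collection.mutable

class ToyMemorySource {
  private val rows = mutable.ArrayBuffer.empty[String]

  def addData(data: String*): Unit = rows ++= data

  // None until the source has ever received data, otherwise the max offset.
  def getOffset: Option[Long] = if (rows.isEmpty) None else Some(rows.size - 1L)

  // Rows in the half-open-on-the-left range (start, end]; must be repeatable
  // for the same (start, end) pair so a restarted query can re-read a batch.
  def getBatch(start: Option[Long], end: Long): Seq[String] = {
    val from = start.map(_ + 1).getOrElse(0L).toInt
    rows.slice(from, end.toInt + 1).toList
  }

  // Everything at or below `end` has been processed and may be discarded.
  def commit(end: Long): Unit = ()

  // Release any resources; nothing to do for an in-memory toy.
  def stop(): Unit = ()
}

object ToyMemorySource {
  def main(args: Array[String]): Unit = {
    val src = new ToyMemorySource
    println(src.getOffset)              // None
    src.addData("a", "b", "c")
    println(src.getOffset)              // Some(2)
    println(src.getBatch(None, 1L))     // List(a, b)
    println(src.getBatch(Some(1L), 2L)) // List(c)
  }
}
```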
*/ + def stop(): Unit } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala index d5e4dd8f78ac2..587ea7d02acab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.execution +import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.execution.SparkPlan @@ -56,7 +57,12 @@ case class StateStoreRestoreExec( child: SparkPlan) extends execution.UnaryExecNode with StatefulOperator { + override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) + override protected def doExecute(): RDD[InternalRow] = { + val numOutputRows = longMetric("numOutputRows") + child.execute().mapPartitionsWithStateStore( getStateId.checkpointLocation, operatorId = getStateId.operatorId, @@ -69,6 +75,7 @@ case class StateStoreRestoreExec( iter.flatMap { row => val key = getKey(row) val savedState = store.get(key) + numOutputRows += 1 row +: savedState.toSeq } } @@ -82,10 +89,20 @@ case class StateStoreRestoreExec( case class StateStoreSaveExec( keyExpressions: Seq[Attribute], stateId: Option[OperatorStateId], + returnAllStates: Option[Boolean], child: SparkPlan) extends execution.UnaryExecNode with StatefulOperator { + override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "numTotalStateRows" -> SQLMetrics.createMetric(sparkContext, "number of total state rows"), + "numUpdatedStateRows" -> SQLMetrics.createMetric(sparkContext, "number of updated state rows")) + override protected def doExecute(): RDD[InternalRow] = { + metrics // force lazy init at driver + assert(returnAllStates.nonEmpty, + "Incorrect planning in IncrementalExecution, returnAllStates have not been set") + val saveAndReturnFunc = if (returnAllStates.get) saveAndReturnAll _ else saveAndReturnUpdated _ child.execute().mapPartitionsWithStateStore( getStateId.checkpointLocation, operatorId = getStateId.operatorId, @@ -93,29 +110,73 @@ case class StateStoreSaveExec( keyExpressions.toStructType, child.output.toStructType, sqlContext.sessionState, - Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) => - new Iterator[InternalRow] { - private[this] val baseIterator = iter - private[this] val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) - - override def hasNext: Boolean = { - if (!baseIterator.hasNext) { - store.commit() - false - } else { - true - } - } - - override def next(): InternalRow = { - val row = baseIterator.next().asInstanceOf[UnsafeRow] - val key = getKey(row) - store.put(key.copy(), row.copy()) - row - } + Some(sqlContext.streams.stateStoreCoordinator) + )(saveAndReturnFunc) + } + + override def output: Seq[Attribute] = child.output + + /** + * Save all the rows to the state store, and return all the rows in the state store. + * Note that this returns an iterator that pipelines the saving to store with downstream + * processing. 
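Depending on `returnAllStates`, the operator either pipelines rows downstream while upserting them into the state store (returning only the current trigger's rows) or blocks until everything is saved and then returns the entire store, which is what Complete output mode needs. A map-based sketch of that difference; the tuple `Row` type and `SaveModesSketch` name are illustrative only.

```scala
import scala.collection.mutable

object SaveModesSketch {
  type Key = String
  type Row = (String, Int)

  // Upsert this trigger's rows and return only them ("updated" path).
  def saveAndReturnUpdated(store: mutable.Map[Key, Row], rows: Seq[Row]): Seq[Row] = {
    rows.foreach(r => store(r._1) = r)
    rows
  }

  // Upsert this trigger's rows and return the whole store ("all" path, Complete mode).
  def saveAndReturnAll(store: mutable.Map[Key, Row], rows: Seq[Row]): Seq[Row] = {
    rows.foreach(r => store(r._1) = r)
    store.values.toSeq
  }

  def main(args: Array[String]): Unit = {
    val store = mutable.Map("a" -> ("a", 1))
    val trigger = Seq(("b", 2), ("a", 5))
    println(saveAndReturnUpdated(store.clone(), trigger)) // only b=2 and a=5
    println(saveAndReturnAll(store.clone(), trigger))     // full store: a=5 and b=2
  }
}
```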
+ */ + private def saveAndReturnUpdated( + store: StateStore, + iter: Iterator[InternalRow]): Iterator[InternalRow] = { + val numOutputRows = longMetric("numOutputRows") + val numTotalStateRows = longMetric("numTotalStateRows") + val numUpdatedStateRows = longMetric("numUpdatedStateRows") + + new Iterator[InternalRow] { + private[this] val baseIterator = iter + private[this] val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) + + override def hasNext: Boolean = { + if (!baseIterator.hasNext) { + store.commit() + numTotalStateRows += store.numKeys() + false + } else { + true } + } + + override def next(): InternalRow = { + val row = baseIterator.next().asInstanceOf[UnsafeRow] + val key = getKey(row) + store.put(key.copy(), row.copy()) + numOutputRows += 1 + numUpdatedStateRows += 1 + row + } } } - override def output: Seq[Attribute] = child.output + /** + * Save all the rows to the state store, and return all the rows in the state store. + * Note that the saving to store is blocking; only after all the rows have been saved + * is the iterator on the update store data is generated. + */ + private def saveAndReturnAll( + store: StateStore, + iter: Iterator[InternalRow]): Iterator[InternalRow] = { + val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) + val numOutputRows = longMetric("numOutputRows") + val numTotalStateRows = longMetric("numTotalStateRows") + val numUpdatedStateRows = longMetric("numUpdatedStateRows") + + while (iter.hasNext) { + val row = iter.next().asInstanceOf[UnsafeRow] + val key = getKey(row) + store.put(key.copy(), row.copy()) + numUpdatedStateRows += 1 + } + store.commit() + numTotalStateRows += store.numKeys() + store.iterator().map { case (k, v) => + numOutputRows += 1 + v.asInstanceOf[InternalRow] + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index fc18e5f065a04..4707bfbce8e0d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -18,7 +18,8 @@ package org.apache.spark.sql.execution.streaming import java.util.concurrent.{CountDownLatch, TimeUnit} -import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.atomic.AtomicLong +import java.util.concurrent.locks.ReentrantLock import scala.collection.mutable.ArrayBuffer import scala.util.control.NonFatal @@ -27,15 +28,15 @@ import org.apache.hadoop.fs.Path import org.apache.spark.internal.Logging import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.OutputMode import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.execution.QueryExecution -import org.apache.spark.sql.util.ContinuousQueryListener -import org.apache.spark.sql.util.ContinuousQueryListener._ -import org.apache.spark.util.UninterruptibleThread +import org.apache.spark.sql.execution.{QueryExecution, SparkPlan} +import org.apache.spark.sql.execution.command.ExplainCommand +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.streaming._ +import org.apache.spark.util.{Clock, UninterruptibleThread, Utils} /** * Manages the execution of a 
streaming Spark SQL query that is occurring in a separate thread. @@ -45,29 +46,46 @@ import org.apache.spark.util.UninterruptibleThread */ class StreamExecution( override val sparkSession: SparkSession, + override val id: Long, override val name: String, checkpointRoot: String, - private[sql] val logicalPlan: LogicalPlan, + val logicalPlan: LogicalPlan, val sink: Sink, - val outputMode: OutputMode, - val trigger: Trigger) - extends ContinuousQuery with Logging { + val trigger: Trigger, + val triggerClock: Clock, + val outputMode: OutputMode) + extends StreamingQuery with Logging { + + import org.apache.spark.sql.streaming.StreamingQueryListener._ + import StreamMetrics._ + + private val pollingDelayMs = sparkSession.conf.get(SQLConf.STREAMING_POLLING_DELAY) + + /** + * A lock used to wait/notify when batches complete. Use a fair lock to avoid thread starvation. + */ + private val awaitBatchLock = new ReentrantLock(true) + private val awaitBatchLockCondition = awaitBatchLock.newCondition() - /** An monitor used to wait/notify when batches complete. */ - private val awaitBatchLock = new Object private val startLatch = new CountDownLatch(1) private val terminationLatch = new CountDownLatch(1) /** * Tracks how much data we have processed and committed to the sink or state store from each * input source. + * Only the scheduler thread should modify this field, and only in atomic steps. + * Other threads should make a shallow copy if they are going to access this field more than + * once, since the field's value may change at any time. */ @volatile - private[sql] var committedOffsets = new StreamProgress + var committedOffsets = new StreamProgress /** * Tracks the offsets that are available to be processed, but have not yet be committed to the * sink. + * Only the scheduler thread should modify this field, and only in atomic steps. + * Other threads should make a shallow copy if they are going to access this field more than + * once, since the field's value may change at any time. */ @volatile private var availableOffsets = new StreamProgress @@ -83,7 +101,7 @@ class StreamExecution( private val uniqueSources = sources.distinct private val triggerExecutor = trigger match { - case t: ProcessingTime => ProcessingTimeExecutor(t) + case t: ProcessingTime => ProcessingTimeExecutor(t, triggerClock) } /** Defines the internal state of execution */ @@ -91,15 +109,38 @@ class StreamExecution( private var state: State = INITIALIZED @volatile - private[sql] var lastExecution: QueryExecution = null + var lastExecution: QueryExecution = null + + @volatile + private var streamDeathCause: StreamingQueryException = null + + /* Get the call site in the caller thread; will pass this into the micro batch thread */ + private val callSite = Utils.getCallSite() + + /** Metrics for this query */ + private val streamMetrics = + new StreamMetrics(uniqueSources.toSet, triggerClock, s"StructuredStreaming.$name") + + @volatile + private var currentStatus: StreamingQueryStatus = null + /** Flag that signals whether any error with input metrics have already been logged */ @volatile - private[sql] var streamDeathCause: ContinuousQueryException = null + private var metricWarningLogged: Boolean = false - /** The thread that runs the micro-batches of this stream. */ - private[sql] val microBatchThread = - new UninterruptibleThread(s"stream execution thread for $name") { - override def run(): Unit = { runBatches() } + /** + * The thread that runs the micro-batches of this stream. 
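The plain `Object` monitor used for batch signalling is replaced here by a fair `ReentrantLock` with an explicit `Condition`, so waiters are served in arrival order and waits can carry a timeout. A minimal sketch of the same wait/notify pattern (names like `awaitBatch` and `completeBatch` are illustrative):

```scala
import java.util.concurrent.TimeUnit
import java.util.concurrent.locks.ReentrantLock

object BatchSignalSketch {
  private val lock = new ReentrantLock(true) // fair: longest-waiting thread first
  private val condition = lock.newCondition()
  @volatile private var batchesCompleted = 0

  // Block (with a polling timeout) until at least `target` batches completed.
  def awaitBatch(target: Int): Unit = {
    lock.lock()
    try {
      while (batchesCompleted < target) {
        condition.await(100, TimeUnit.MILLISECONDS)
      }
    } finally {
      lock.unlock()
    }
  }

  // Record progress and wake every waiter.
  def completeBatch(): Unit = {
    lock.lock()
    try {
      batchesCompleted += 1
      condition.signalAll()
    } finally {
      lock.unlock()
    }
  }

  def main(args: Array[String]): Unit = {
    val waiter = new Thread(new Runnable {
      override def run(): Unit = { awaitBatch(1); println("saw batch 1") }
    })
    waiter.start()
    completeBatch()
    waiter.join()
  }
}
```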
Note that this thread must be + * [[org.apache.spark.util.UninterruptibleThread]] to avoid potential deadlocks in using + * [[HDFSMetadataLog]]. See SPARK-14131 for more details. + */ + val microBatchThread = + new StreamExecutionThread(s"stream execution thread for $name") { + override def run(): Unit = { + // To fix call site like "run at :0", we bridge the call site from the caller + // thread to this micro batch thread + sparkSession.sparkContext.setCallSite(callSite) + runBatches() + } } /** @@ -108,34 +149,32 @@ class StreamExecution( * processing is done. Thus, the Nth record in this log indicated data that is currently being * processed and the N-1th entry indicates which offsets have been durably committed to the sink. */ - private val offsetLog = - new HDFSMetadataLog[CompositeOffset](sparkSession, checkpointFile("offsets")) + val offsetLog = new HDFSMetadataLog[CompositeOffset](sparkSession, checkpointFile("offsets")) /** Whether the query is currently active or not */ override def isActive: Boolean = state == ACTIVE + /** Returns the current status of the query. */ + override def status: StreamingQueryStatus = currentStatus + /** Returns current status of all the sources. */ - override def sourceStatuses: Array[SourceStatus] = { - val localAvailableOffsets = availableOffsets - sources.map(s => new SourceStatus(s.toString, localAvailableOffsets.get(s))).toArray - } + override def sourceStatuses: Array[SourceStatus] = currentStatus.sourceStatuses.toArray /** Returns current status of the sink. */ - override def sinkStatus: SinkStatus = - new SinkStatus(sink.toString, committedOffsets.toCompositeOffset(sources)) + override def sinkStatus: SinkStatus = currentStatus.sinkStatus - /** Returns the [[ContinuousQueryException]] if the query was terminated by an exception. */ - override def exception: Option[ContinuousQueryException] = Option(streamDeathCause) + /** Returns the [[StreamingQueryException]] if the query was terminated by an exception. */ + override def exception: Option[StreamingQueryException] = Option(streamDeathCause) /** Returns the path of a file with `name` in the checkpoint directory. */ private def checkpointFile(name: String): String = new Path(new Path(checkpointRoot), name).toUri.toString /** - * Starts the execution. This returns only after the thread has started and [[QueryStarted]] event + * Starts the execution. This returns only after the thread has started and [[QueryStartedEvent]] * has been posted to all the listeners. */ - private[sql] def start(): Unit = { + def start(): Unit = { microBatchThread.setDaemon(true) microBatchThread.start() startLatch.await() // Wait until thread started and QueryStart event has been posted @@ -144,46 +183,91 @@ class StreamExecution( /** * Repeatedly attempts to run batches as data arrives. * - * Note that this method ensures that [[QueryStarted]] and [[QueryTerminated]] events are posted - * such that listeners are guaranteed to get a start event before a termination. Furthermore, this - * method also ensures that [[QueryStarted]] event is posted before the `start()` method returns. + * Note that this method ensures that [[QueryStartedEvent]] and [[QueryTerminatedEvent]] are + * posted such that listeners are guaranteed to get a start event before a termination. + * Furthermore, this method also ensures that [[QueryStartedEvent]] event is posted before the + * `start()` method returns. */ private def runBatches(): Unit = { try { // Mark ACTIVE and then post the event. 
QueryStarted event is synchronously sent to listeners, // so must mark this as ACTIVE first. state = ACTIVE - postEvent(new QueryStarted(this)) // Assumption: Does not throw exception. + if (sparkSession.conf.get(SQLConf.STREAMING_METRICS_ENABLED)) { + sparkSession.sparkContext.env.metricsSystem.registerSource(streamMetrics) + } + updateStatus() + postEvent(new QueryStartedEvent(currentStatus)) // Assumption: Does not throw exception. // Unblock starting thread startLatch.countDown() // While active, repeatedly attempt to run batches. - SQLContext.setActive(sparkSession.wrapped) - populateStartOffsets() - logDebug(s"Stream running from $committedOffsets to $availableOffsets") + SparkSession.setActiveSession(sparkSession) + triggerExecutor.execute(() => { - if (isActive) { - if (dataAvailable) runBatch() - constructNextBatch() - true - } else { - false + streamMetrics.reportTriggerStarted(currentBatchId) + streamMetrics.reportTriggerDetail(STATUS_MESSAGE, "Finding new data from sources") + updateStatus() + val isTerminated = reportTimeTaken(TRIGGER_LATENCY) { + if (isActive) { + if (currentBatchId < 0) { + // We'll do this initialization only once + populateStartOffsets() + logDebug(s"Stream running from $committedOffsets to $availableOffsets") + } else { + constructNextBatch() + } + if (dataAvailable) { + streamMetrics.reportTriggerDetail(IS_DATA_PRESENT_IN_TRIGGER, true) + streamMetrics.reportTriggerDetail(STATUS_MESSAGE, "Processing new data") + updateStatus() + runBatch() + // We'll increase currentBatchId after we complete processing current batch's data + currentBatchId += 1 + } else { + streamMetrics.reportTriggerDetail(IS_DATA_PRESENT_IN_TRIGGER, false) + streamMetrics.reportTriggerDetail(STATUS_MESSAGE, "No new data") + updateStatus() + Thread.sleep(pollingDelayMs) + } + true + } else { + false + } } + // Update metrics and notify others + streamMetrics.reportTriggerFinished() + updateStatus() + postEvent(new QueryProgressEvent(currentStatus)) + isTerminated }) } catch { case _: InterruptedException if state == TERMINATED => // interrupted by stop() - case NonFatal(e) => - streamDeathCause = new ContinuousQueryException( + case e: Throwable => + streamDeathCause = new StreamingQueryException( this, s"Query $name terminated with exception: ${e.getMessage}", e, Some(committedOffsets.toCompositeOffset(sources))) logError(s"Query $name terminated with error", e) + // Rethrow the fatal errors to allow the user using `Thread.UncaughtExceptionHandler` to + // handle them + if (!NonFatal(e)) { + throw e + } } finally { state = TERMINATED + + // Update metrics and status + streamMetrics.stop() + sparkSession.sparkContext.env.metricsSystem.removeSource(streamMetrics) + updateStatus() + + // Notify others sparkSession.streams.notifyQueryTermination(StreamExecution.this) - postEvent(new QueryTerminated(this)) + postEvent( + new QueryTerminatedEvent(currentStatus, exception.map(_.cause).map(Utils.exceptionString))) terminationLatch.countDown() } } @@ -199,8 +283,8 @@ class StreamExecution( private def populateStartOffsets(): Unit = { offsetLog.getLatest() match { case Some((batchId, nextOffsets)) => - logInfo(s"Resuming continuous query, starting with batch $batchId") - currentBatchId = batchId + 1 + logInfo(s"Resuming streaming query, starting with batch $batchId") + currentBatchId = batchId availableOffsets = nextOffsets.toStreamProgress(sources) logDebug(s"Found possibly uncommitted offsets $availableOffsets") @@ -209,9 +293,8 @@ class StreamExecution( committedOffsets = 
lastOffsets.toStreamProgress(sources) logDebug(s"Resuming with committed offsets: $committedOffsets") } - case None => // We are starting this stream for the first time. - logInfo(s"Starting new continuous query.") + logInfo(s"Starting new streaming query.") currentBatchId = 0 constructNextBatch() } @@ -225,7 +308,7 @@ class StreamExecution( case (source, available) => committedOffsets .get(source) - .map(committed => committed < available) + .map(committed => committed != available) .getOrElse(true) } } @@ -235,64 +318,86 @@ class StreamExecution( * batchId counter is incremented and a new log entry is written with the newest offsets. */ private def constructNextBatch(): Unit = { - // There is a potential dead-lock in Hadoop "Shell.runCommand" before 2.5.0 (HADOOP-10622). - // If we interrupt some thread running Shell.runCommand, we may hit this issue. - // As "FileStreamSource.getOffset" will create a file using HDFS API and call "Shell.runCommand" - // to set the file permission, we should not interrupt "microBatchThread" when running this - // method. See SPARK-14131. - // // Check to see what new data is available. - val newData = microBatchThread.runUninterruptibly { - uniqueSources.flatMap(s => s.getOffset.map(o => s -> o)) - } - availableOffsets ++= newData + val hasNewData = { + awaitBatchLock.lock() + try { + reportTimeTaken(GET_OFFSET_LATENCY) { + val latestOffsets: Map[Source, Option[Offset]] = uniqueSources.map { s => + reportTimeTaken(s, SOURCE_GET_OFFSET_LATENCY) { + (s, s.getOffset) + } + }.toMap + availableOffsets ++= latestOffsets.filter { case (s, o) => o.nonEmpty }.mapValues(_.get) + } - val hasNewData = awaitBatchLock.synchronized { - if (dataAvailable) { - true - } else { - noNewData = true - false + if (dataAvailable) { + true + } else { + noNewData = true + false + } + } finally { + awaitBatchLock.unlock() } } if (hasNewData) { - // There is a potential dead-lock in Hadoop "Shell.runCommand" before 2.5.0 (HADOOP-10622). - // If we interrupt some thread running Shell.runCommand, we may hit this issue. - // As "offsetLog.add" will create a file using HDFS API and call "Shell.runCommand" to set - // the file permission, we should not interrupt "microBatchThread" when running this method. - // See SPARK-14131. - microBatchThread.runUninterruptibly { - assert( - offsetLog.add(currentBatchId, availableOffsets.toCompositeOffset(sources)), - s"Concurrent update to the log. Multiple streaming jobs detected for $currentBatchId") + reportTimeTaken(OFFSET_WAL_WRITE_LATENCY) { + assert(offsetLog.add(currentBatchId, availableOffsets.toCompositeOffset(sources)), + s"Concurrent update to the log. Multiple streaming jobs detected for $currentBatchId") + logInfo(s"Committed offsets for batch $currentBatchId.") + + // NOTE: The following code is correct because runBatches() processes exactly one + // batch at a time. If we add pipeline parallelism (multiple batches in flight at + // the same time), this cleanup logic will need to change. + + // Now that we've updated the scheduler's persistent checkpoint, it is safe for the + // sources to discard data from the previous batch. + val prevBatchOff = offsetLog.get(currentBatchId - 1) + if (prevBatchOff.isDefined) { + prevBatchOff.get.toStreamProgress(sources).foreach { + case (src, off) => src.commit(off) + } + } + + // Now that we have logged the new batch, no further processing will happen for + // the batch before the previous batch, and it is safe to discard the old metadata. + // Note that purge is exclusive, i.e. 
it purges everything before the target ID. + offsetLog.purge(currentBatchId - 1) } - currentBatchId += 1 - logInfo(s"Committed offsets for batch $currentBatchId.") } else { - awaitBatchLock.synchronized { + awaitBatchLock.lock() + try { // Wake up any threads that are waiting for the stream to progress. - awaitBatchLock.notifyAll() + awaitBatchLockCondition.signalAll() + } finally { + awaitBatchLock.unlock() } } + reportTimestamp(GET_OFFSET_TIMESTAMP) } /** * Processes any data available between `availableOffsets` and `committedOffsets`. */ private def runBatch(): Unit = { - val startTime = System.nanoTime() - // TODO: Move this to IncrementalExecution. // Request unprocessed data from all sources. - val newData = availableOffsets.flatMap { - case (source, available) if committedOffsets.get(source).map(_ < available).getOrElse(true) => - val current = committedOffsets.get(source) - val batch = source.getBatch(current, available) - logDebug(s"Retrieving data from $source: $current -> $available") - Some(source -> batch) - case _ => None - }.toMap + val newData = reportTimeTaken(GET_BATCH_LATENCY) { + availableOffsets.flatMap { + case (source, available) + if committedOffsets.get(source).map(_ != available).getOrElse(true) => + val current = committedOffsets.get(source) + val batch = reportTimeTaken(source, SOURCE_GET_BATCH_LATENCY) { + source.getBatch(current, available) + } + logDebug(s"Retrieving data from $source: $current -> $available") + Some(source -> batch) + case _ => None + } + } + reportTimestamp(GET_BATCH_TIMESTAMP) // A list of attributes that will need to be updated. var replacements = new ArrayBuffer[(Attribute, Attribute)] @@ -302,7 +407,8 @@ class StreamExecution( newData.get(source).map { data => val newPlan = data.logicalPlan assert(output.size == newPlan.output.size, - s"Invalid batch: ${output.mkString(",")} != ${newPlan.output.mkString(",")}") + s"Invalid batch: ${Utils.truncatedString(output, ",")} != " + + s"${Utils.truncatedString(newPlan.output, ",")}") replacements ++= output.zip(newPlan.output) newPlan }.getOrElse { @@ -312,39 +418,38 @@ class StreamExecution( // Rewire the plan to use the new attributes that were returned by the source. val replacementMap = AttributeMap(replacements) - val newPlan = withNewSources transformAllExpressions { + val triggerLogicalPlan = withNewSources transformAllExpressions { case a: Attribute if replacementMap.contains(a) => replacementMap(a) } - val optimizerStart = System.nanoTime() - lastExecution = new IncrementalExecution( - sparkSession, - newPlan, - outputMode, - checkpointFile("state"), - currentBatchId) - - lastExecution.executedPlan - val optimizerTime = (System.nanoTime() - optimizerStart).toDouble / 1000000 - logDebug(s"Optimized batch in ${optimizerTime}ms") + val executedPlan = reportTimeTaken(OPTIMIZER_LATENCY) { + lastExecution = new IncrementalExecution( + sparkSession, + triggerLogicalPlan, + outputMode, + checkpointFile("state"), + currentBatchId) + lastExecution.executedPlan // Force the lazy generation of execution plan + } val nextBatch = new Dataset(sparkSession, lastExecution, RowEncoder(lastExecution.analyzed.schema)) - sink.addBatch(currentBatchId - 1, nextBatch) + sink.addBatch(currentBatchId, nextBatch) + reportNumRows(executedPlan, triggerLogicalPlan, newData) - awaitBatchLock.synchronized { + awaitBatchLock.lock() + try { // Wake up any threads that are waiting for the stream to progress. 
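Taken together, a trigger now proceeds as: ask each source for its latest offset, write the intended offsets to the offset log before doing any processing, let the sources commit (discard) data up to the previous batch, purge stale log entries, then read the rows in `(committed, available]` and hand them to the sink under the current batch id. The toy, single-source walk-through below mirrors that ordering with plain collections; every name in it is illustrative.

```scala
import scala.collection.mutable

object MicroBatchLoopSketch {
  def main(args: Array[String]): Unit = {
    val rows = Seq("a", "b", "c")                      // pretend source data
    val offsetLog = mutable.HashMap.empty[Long, Long]  // batchId -> offset (write-ahead)
    val sink = mutable.ArrayBuffer.empty[(Long, Seq[String])]
    var committed: Option[Long] = None
    val currentBatchId = 0L

    val available = Some(rows.size - 1L)               // 1) source.getOffset
    if (available != committed) {                      //    dataAvailable now compares with !=, not <
      offsetLog(currentBatchId) = available.get        // 2) offsetLog.add before any processing
      // 3) source.commit(previous offsets) and offsetLog.purge(currentBatchId - 1) happen here
      offsetLog.keys.filter(_ < currentBatchId - 1).toList.foreach(offsetLog.remove)
      val from = committed.map(_ + 1).getOrElse(0L).toInt
      val batch = rows.slice(from, available.get.toInt + 1) // 4) source.getBatch((committed, available])
      sink += currentBatchId -> batch                  // 5) sink.addBatch(currentBatchId, ...)
      committed = available                            // 6) committedOffsets ++= availableOffsets
    }
    println(sink.toList)                               // List((0,List(a, b, c)))
  }
}
```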
- awaitBatchLock.notifyAll() + awaitBatchLockCondition.signalAll() + } finally { + awaitBatchLock.unlock() } - val batchTime = (System.nanoTime() - startTime).toDouble / 1000000 - logInfo(s"Completed up to $availableOffsets in ${batchTime}ms") // Update committed offsets. committedOffsets ++= availableOffsets - postEvent(new QueryProgress(this)) } - private def postEvent(event: ContinuousQueryListener.Event) { + private def postEvent(event: StreamingQueryListener.Event) { sparkSession.streams.postListenerEvent(event) } @@ -360,22 +465,30 @@ class StreamExecution( microBatchThread.interrupt() microBatchThread.join() } + uniqueSources.foreach(_.stop()) logInfo(s"Query $name was stopped") } /** * Blocks the current thread until processing for data from the given `source` has reached at - * least the given `Offset`. This method is indented for use primarily when writing tests. + * least the given `Offset`. This method is intended for use primarily when writing tests. */ - def awaitOffset(source: Source, newOffset: Offset): Unit = { + private[sql] def awaitOffset(source: Source, newOffset: Offset): Unit = { def notDone = { val localCommittedOffsets = committedOffsets - !localCommittedOffsets.contains(source) || localCommittedOffsets(source) < newOffset + !localCommittedOffsets.contains(source) || localCommittedOffsets(source) != newOffset } while (notDone) { - logInfo(s"Waiting until $newOffset at $source") - awaitBatchLock.synchronized { awaitBatchLock.wait(100) } + awaitBatchLock.lock() + try { + awaitBatchLockCondition.await(100, TimeUnit.MILLISECONDS) + if (streamDeathCause != null) { + throw streamDeathCause + } + } finally { + awaitBatchLock.unlock() + } } logDebug(s"Unblocked at $newOffset for $source") } @@ -383,16 +496,21 @@ class StreamExecution( /** A flag to indicate that a batch has completed with no new data available. */ @volatile private var noNewData = false - override def processAllAvailable(): Unit = awaitBatchLock.synchronized { - noNewData = false - while (true) { - awaitBatchLock.wait(10000) - if (streamDeathCause != null) { - throw streamDeathCause - } - if (noNewData) { - return + override def processAllAvailable(): Unit = { + awaitBatchLock.lock() + try { + noNewData = false + while (true) { + awaitBatchLockCondition.await(10000, TimeUnit.MILLISECONDS) + if (streamDeathCause != null) { + throw streamDeathCause + } + if (noNewData) { + return + } } + } finally { + awaitBatchLock.unlock() } } @@ -419,8 +537,27 @@ class StreamExecution( } } + /** Expose for tests */ + def explainInternal(extended: Boolean): String = { + if (lastExecution == null) { + "No physical plan. Waiting for data." 
+ } else { + val explain = ExplainCommand(lastExecution.logical, extended = extended) + sparkSession.sessionState.executePlan(explain).executedPlan.executeCollect() + .map(_.getString(0)).mkString("\n") + } + } + + override def explain(extended: Boolean): Unit = { + // scalastyle:off println + println(explainInternal(extended)) + // scalastyle:on println + } + + override def explain(): Unit = explain(extended = false) + override def toString: String = { - s"Continuous Query - $name [state = $state]" + s"Streaming Query - $name [state = $state]" } def toDebugString: String = { @@ -428,7 +565,7 @@ class StreamExecution( "Error:\n" + stackTraceToString(streamDeathCause.cause) } else "" s""" - |=== Continuous Query === + |=== Streaming Query === |Name: $name |Current Offsets: $committedOffsets | @@ -442,14 +579,147 @@ class StreamExecution( """.stripMargin } + /** + * Report row metrics of the executed trigger + * @param triggerExecutionPlan Execution plan of the trigger + * @param triggerLogicalPlan Logical plan of the trigger, generated from the query logical plan + * @param sourceToDF Source to DataFrame returned by the source.getBatch in this trigger + */ + private def reportNumRows( + triggerExecutionPlan: SparkPlan, + triggerLogicalPlan: LogicalPlan, + sourceToDF: Map[Source, DataFrame]): Unit = { + // We want to associate execution plan leaves to sources that generate them, so that we match + // the their metrics (e.g. numOutputRows) to the sources. To do this we do the following. + // Consider the translation from the streaming logical plan to the final executed plan. + // + // streaming logical plan (with sources) <==> trigger's logical plan <==> executed plan + // + // 1. We keep track of streaming sources associated with each leaf in the trigger's logical plan + // - Each logical plan leaf will be associated with a single streaming source. + // - There can be multiple logical plan leaves associated with a streaming source. + // - There can be leaves not associated with any streaming source, because they were + // generated from a batch source (e.g. stream-batch joins) + // + // 2. Assuming that the executed plan has same number of leaves in the same order as that of + // the trigger logical plan, we associate executed plan leaves with corresponding + // streaming sources. + // + // 3. For each source, we sum the metrics of the associated execution plan leaves. 
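The three steps in the comment above reduce to a zip-group-sum over parallel leaf sequences. A toy version, with `LogicalLeaf`, `ExecLeaf` and `SourceId` standing in for the real plan-node and `Source` types, shows the shape of the computation:

```scala
object LeafToSourceSketch extends App {
  case class LogicalLeaf(name: String)
  case class ExecLeaf(name: String, numOutputRows: Long)
  case class SourceId(name: String)

  def rowsPerSource(
      logicalLeaves: Seq[LogicalLeaf],
      execLeaves: Seq[ExecLeaf],
      leafToSource: Map[LogicalLeaf, SourceId]): Map[SourceId, Long] = {
    // The association is only meaningful if both plans expose the same leaves in the same order.
    require(logicalLeaves.size == execLeaves.size, "plans have different numbers of leaves")
    logicalLeaves.zip(execLeaves)
      .flatMap { case (lp, ep) => leafToSource.get(lp).map(src => src -> ep.numOutputRows) }
      .groupBy(_._1)
      .map { case (src, pairs) => src -> pairs.map(_._2).sum }
  }

  // Two leaves come from streaming source "s1"; the third is a static (batch) table.
  val logical = Seq(LogicalLeaf("s1-a"), LogicalLeaf("s1-b"), LogicalLeaf("static"))
  val exec = Seq(ExecLeaf("scan1", 10), ExecLeaf("scan2", 5), ExecLeaf("scan3", 99))
  val leafToSource = Map(LogicalLeaf("s1-a") -> SourceId("s1"), LogicalLeaf("s1-b") -> SourceId("s1"))

  println(rowsPerSource(logical, exec, leafToSource))   // Map(SourceId(s1) -> 15)
}
```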
+ // + val logicalPlanLeafToSource = sourceToDF.flatMap { case (source, df) => + df.logicalPlan.collectLeaves().map { leaf => leaf -> source } + } + val allLogicalPlanLeaves = triggerLogicalPlan.collectLeaves() // includes non-streaming sources + val allExecPlanLeaves = triggerExecutionPlan.collectLeaves() + val sourceToNumInputRows: Map[Source, Long] = + if (allLogicalPlanLeaves.size == allExecPlanLeaves.size) { + val execLeafToSource = allLogicalPlanLeaves.zip(allExecPlanLeaves).flatMap { + case (lp, ep) => logicalPlanLeafToSource.get(lp).map { source => ep -> source } + } + val sourceToNumInputRows = execLeafToSource.map { case (execLeaf, source) => + val numRows = execLeaf.metrics.get("numOutputRows").map(_.value).getOrElse(0L) + source -> numRows + } + sourceToNumInputRows.groupBy(_._1).mapValues(_.map(_._2).sum) // sum up rows for each source + } else { + if (!metricWarningLogged) { + def toString[T](seq: Seq[T]): String = s"(size = ${seq.size}), ${seq.mkString(", ")}" + logWarning( + "Could not report metrics as number leaves in trigger logical plan did not match that" + + s" of the execution plan:\n" + + s"logical plan leaves: ${toString(allLogicalPlanLeaves)}\n" + + s"execution plan leaves: ${toString(allExecPlanLeaves)}\n") + metricWarningLogged = true + } + Map.empty + } + val numOutputRows = triggerExecutionPlan.metrics.get("numOutputRows").map(_.value) + val stateNodes = triggerExecutionPlan.collect { + case p if p.isInstanceOf[StateStoreSaveExec] => p + } + + streamMetrics.reportNumInputRows(sourceToNumInputRows) + stateNodes.zipWithIndex.foreach { case (s, i) => + streamMetrics.reportTriggerDetail( + NUM_TOTAL_STATE_ROWS(i + 1), + s.metrics.get("numTotalStateRows").map(_.value).getOrElse(0L)) + streamMetrics.reportTriggerDetail( + NUM_UPDATED_STATE_ROWS(i + 1), + s.metrics.get("numUpdatedStateRows").map(_.value).getOrElse(0L)) + } + updateStatus() + } + + private def reportTimeTaken[T](triggerDetailKey: String)(body: => T): T = { + val startTime = triggerClock.getTimeMillis() + val result = body + val endTime = triggerClock.getTimeMillis() + val timeTaken = math.max(endTime - startTime, 0) + streamMetrics.reportTriggerDetail(triggerDetailKey, timeTaken) + updateStatus() + if (triggerDetailKey == TRIGGER_LATENCY) { + logInfo(s"Completed up to $availableOffsets in $timeTaken ms") + } + result + } + + private def reportTimeTaken[T](source: Source, triggerDetailKey: String)(body: => T): T = { + val startTime = triggerClock.getTimeMillis() + val result = body + val endTime = triggerClock.getTimeMillis() + streamMetrics.reportSourceTriggerDetail( + source, triggerDetailKey, math.max(endTime - startTime, 0)) + updateStatus() + result + } + + private def reportTimestamp(triggerDetailKey: String): Unit = { + streamMetrics.reportTriggerDetail(triggerDetailKey, triggerClock.getTimeMillis) + updateStatus() + } + + private def updateStatus(): Unit = { + val localAvailableOffsets = availableOffsets + val sourceStatuses = sources.map { s => + SourceStatus( + s.toString, + localAvailableOffsets.get(s).map(_.toString).getOrElse("-"), // TODO: use json if available + streamMetrics.currentSourceInputRate(s), + streamMetrics.currentSourceProcessingRate(s), + streamMetrics.currentSourceTriggerDetails(s)) + }.toArray + val sinkStatus = SinkStatus( + sink.toString, + committedOffsets.toCompositeOffset(sources).toString) + + currentStatus = + StreamingQueryStatus( + name = name, + id = id, + timestamp = triggerClock.getTimeMillis(), + inputRate = streamMetrics.currentInputRate(), + processingRate = 
streamMetrics.currentProcessingRate(), + latency = streamMetrics.currentLatency(), + sourceStatuses = sourceStatuses, + sinkStatus = sinkStatus, + triggerDetails = streamMetrics.currentTriggerDetails()) + } + trait State case object INITIALIZED extends State case object ACTIVE extends State case object TERMINATED extends State } -private[sql] object StreamExecution { - private val nextId = new AtomicInteger() +object StreamExecution { + private val _nextId = new AtomicLong(0) - def nextName: String = s"query-${nextId.getAndIncrement}" + def nextId: Long = _nextId.getAndIncrement() } + +/** + * A special thread to run the stream query. Some codes require to run in the StreamExecutionThread + * and will use `classOf[StreamExecutionThread]` to check. + */ +abstract class StreamExecutionThread(name: String) extends UninterruptibleThread(name) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamFileCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamFileCatalog.scala deleted file mode 100644 index 4f699719c2768..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamFileCatalog.scala +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming - -import org.apache.hadoop.fs.{FileStatus, Path} - -import org.apache.spark.internal.Logging -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.execution.datasources.{FileCatalog, Partition, PartitionSpec} -import org.apache.spark.sql.types.StructType - -class StreamFileCatalog(sparkSession: SparkSession, path: Path) extends FileCatalog with Logging { - val metadataDirectory = new Path(path, FileStreamSink.metadataDir) - logInfo(s"Reading streaming file log from $metadataDirectory") - val metadataLog = new FileStreamSinkLog(sparkSession, metadataDirectory.toUri.toString) - val fs = path.getFileSystem(sparkSession.sessionState.newHadoopConf()) - - override def paths: Seq[Path] = path :: Nil - - override def partitionSpec(): PartitionSpec = PartitionSpec(StructType(Nil), Nil) - - /** - * Returns all valid files grouped into partitions when the data is partitioned. If the data is - * unpartitioned, this will return a single partition with not partition values. - * - * @param filters the filters used to prune which partitions are returned. These filters must - * only refer to partition columns and this method will only return files - * where these predicates are guaranteed to evaluate to `true`. Thus, these - * filters will not need to be evaluated again on the returned data. 
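Stepping back to the `reportTimeTaken` helpers defined earlier in `StreamExecution`: the pattern is to time a by-name block against a clock, record the duration under a key, and pass the block's result through unchanged. A minimal stand-alone version, using `System.currentTimeMillis` and a plain map rather than `StreamMetrics`, is sketched below (names are illustrative).

```scala
import scala.collection.mutable

object TimingSketch extends App {
  // Simplified stand-in for StreamMetrics.reportTriggerDetail: key -> latency in ms.
  private val triggerDetails = mutable.Map[String, Long]()

  def reportTimeTaken[T](key: String)(body: => T): T = {
    val start = System.currentTimeMillis()
    val result = body                                             // run the timed block
    triggerDetails(key) = math.max(System.currentTimeMillis() - start, 0L)
    result                                                        // pass the result through
  }

  val answer = reportTimeTaken("latency.example") { Thread.sleep(10); 42 }
  println(s"$answer computed, details = $triggerDetails")         // latency.example -> ~10 ms
}
```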
- */ - override def listFiles(filters: Seq[Expression]): Seq[Partition] = - Partition(InternalRow.empty, allFiles()) :: Nil - - override def getStatus(path: Path): Array[FileStatus] = fs.listStatus(path) - - override def refresh(): Unit = {} - - override def allFiles(): Seq[FileStatus] = { - metadataLog.allFiles().map(_.toFileStatus) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamMetrics.scala new file mode 100644 index 0000000000000..e98d1883e4596 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamMetrics.scala @@ -0,0 +1,242 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.{util => ju} + +import scala.collection.mutable + +import com.codahale.metrics.{Gauge, MetricRegistry} + +import org.apache.spark.internal.Logging +import org.apache.spark.metrics.source.{Source => CodahaleSource} +import org.apache.spark.util.Clock + +/** + * Class that manages all the metrics related to a StreamingQuery. It does the following. + * - Calculates metrics (rates, latencies, etc.) based on information reported by StreamExecution. + * - Allows the current metric values to be queried + * - Serves some of the metrics through Codahale/DropWizard metrics + * + * @param sources Unique set of sources in a query + * @param triggerClock Clock used for triggering in StreamExecution + * @param codahaleSourceName Root name for all the Codahale metrics + */ +class StreamMetrics(sources: Set[Source], triggerClock: Clock, codahaleSourceName: String) + extends CodahaleSource with Logging { + + import StreamMetrics._ + + // Trigger infos + private val triggerDetails = new mutable.HashMap[String, String] + private val sourceTriggerDetails = new mutable.HashMap[Source, mutable.HashMap[String, String]] + + // Rate estimators for sources and sinks + private val inputRates = new mutable.HashMap[Source, RateCalculator] + private val processingRates = new mutable.HashMap[Source, RateCalculator] + + // Number of input rows in the current trigger + private val numInputRows = new mutable.HashMap[Source, Long] + private var currentTriggerStartTimestamp: Long = -1 + private var previousTriggerStartTimestamp: Long = -1 + private var latency: Option[Double] = None + + override val sourceName: String = codahaleSourceName + override val metricRegistry: MetricRegistry = new MetricRegistry + + // =========== Initialization =========== + + // Metric names should not have . 
in them, so that all the metrics of a query are identified + // together in Ganglia as a single metric group + registerGauge("inputRate-total", currentInputRate) + registerGauge("processingRate-total", () => currentProcessingRate) + registerGauge("latency", () => currentLatency().getOrElse(-1.0)) + + sources.foreach { s => + inputRates.put(s, new RateCalculator) + processingRates.put(s, new RateCalculator) + sourceTriggerDetails.put(s, new mutable.HashMap[String, String]) + + registerGauge(s"inputRate-${s.toString}", () => currentSourceInputRate(s)) + registerGauge(s"processingRate-${s.toString}", () => currentSourceProcessingRate(s)) + } + + // =========== Setter methods =========== + + def reportTriggerStarted(triggerId: Long): Unit = synchronized { + numInputRows.clear() + triggerDetails.clear() + sourceTriggerDetails.values.foreach(_.clear()) + + reportTriggerDetail(TRIGGER_ID, triggerId) + sources.foreach(s => reportSourceTriggerDetail(s, TRIGGER_ID, triggerId)) + reportTriggerDetail(IS_TRIGGER_ACTIVE, true) + currentTriggerStartTimestamp = triggerClock.getTimeMillis() + reportTriggerDetail(START_TIMESTAMP, currentTriggerStartTimestamp) + } + + def reportTriggerDetail[T](key: String, value: T): Unit = synchronized { + triggerDetails.put(key, value.toString) + } + + def reportSourceTriggerDetail[T](source: Source, key: String, value: T): Unit = synchronized { + sourceTriggerDetails(source).put(key, value.toString) + } + + def reportNumInputRows(inputRows: Map[Source, Long]): Unit = synchronized { + numInputRows ++= inputRows + } + + def reportTriggerFinished(): Unit = synchronized { + require(currentTriggerStartTimestamp >= 0) + val currentTriggerFinishTimestamp = triggerClock.getTimeMillis() + reportTriggerDetail(FINISH_TIMESTAMP, currentTriggerFinishTimestamp) + triggerDetails.remove(STATUS_MESSAGE) + reportTriggerDetail(IS_TRIGGER_ACTIVE, false) + + // Report number of rows + val totalNumInputRows = numInputRows.values.sum + reportTriggerDetail(NUM_INPUT_ROWS, totalNumInputRows) + numInputRows.foreach { case (s, r) => + reportSourceTriggerDetail(s, NUM_SOURCE_INPUT_ROWS, r) + } + + val currentTriggerDuration = currentTriggerFinishTimestamp - currentTriggerStartTimestamp + val previousInputIntervalOption = if (previousTriggerStartTimestamp >= 0) { + Some(currentTriggerStartTimestamp - previousTriggerStartTimestamp) + } else None + + // Update input rate = num rows received by each source during the previous trigger interval + // Interval is measures as interval between start times of previous and current trigger. + // + // TODO: Instead of trigger start, we should use time when getOffset was called on each source + // as this may be different for each source if there are many sources in the query plan + // and getOffset is called serially on them. 
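To make the rate and latency arithmetic concrete: the input rate divides the rows seen in the current trigger by the interval between the previous and current trigger starts, the processing rate divides the same row count by the current trigger's duration, and the latency estimate is half the previous interval plus the current duration. A toy illustration with made-up numbers (`SimpleRate` is not the real calculator class):

```scala
object RateSketch extends App {
  // Simplified stand-in for the per-source rate calculator.
  class SimpleRate {
    private var rate: Option[Double] = None
    def update(numRows: Long, timeGapMs: Long): Unit =
      rate = if (timeGapMs > 0) Some(numRows.toDouble * 1000 / timeGapMs) else None
    def current: Double = rate.getOrElse(0.0)
  }

  val previousTriggerIntervalMs = 1000L   // previous trigger started 1 s before this one
  val currentTriggerDurationMs = 400L     // this trigger took 400 ms
  val rowsInTrigger = 250L                // rows read in this trigger

  val inputRate = new SimpleRate
  val processingRate = new SimpleRate
  inputRate.update(rowsInTrigger, previousTriggerIntervalMs)      // 250 rows/s
  processingRate.update(rowsInTrigger, currentTriggerDurationMs)  // 625 rows/s

  // Latency when data was present: 0.5 * 1000 + 400 = 900 ms.
  val latencyMs = previousTriggerIntervalMs.toDouble / 2 + currentTriggerDurationMs

  println(s"input=${inputRate.current} rows/s, " +
    s"processing=${processingRate.current} rows/s, latency=$latencyMs ms")
}
```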
+ if (previousInputIntervalOption.nonEmpty) { + sources.foreach { s => + inputRates(s).update(numInputRows.getOrElse(s, 0), previousInputIntervalOption.get) + } + } + + // Update processing rate = num rows processed for each source in current trigger duration + sources.foreach { s => + processingRates(s).update(numInputRows.getOrElse(s, 0), currentTriggerDuration) + } + + // Update latency = if data present, 0.5 * previous trigger interval + current trigger duration + if (previousInputIntervalOption.nonEmpty && totalNumInputRows > 0) { + latency = Some((previousInputIntervalOption.get.toDouble / 2) + currentTriggerDuration) + } else { + latency = None + } + + previousTriggerStartTimestamp = currentTriggerStartTimestamp + currentTriggerStartTimestamp = -1 + } + + // =========== Getter methods =========== + + def currentInputRate(): Double = synchronized { + // Since we are calculating source input rates using the same time interval for all sources + // it is fine to calculate total input rate as the sum of per source input rate. + inputRates.map(_._2.currentRate).sum + } + + def currentSourceInputRate(source: Source): Double = synchronized { + inputRates(source).currentRate + } + + def currentProcessingRate(): Double = synchronized { + // Since we are calculating source processing rates using the same time interval for all sources + // it is fine to calculate total processing rate as the sum of per source processing rate. + processingRates.map(_._2.currentRate).sum + } + + def currentSourceProcessingRate(source: Source): Double = synchronized { + processingRates(source).currentRate + } + + def currentLatency(): Option[Double] = synchronized { latency } + + def currentTriggerDetails(): Map[String, String] = synchronized { triggerDetails.toMap } + + def currentSourceTriggerDetails(source: Source): Map[String, String] = synchronized { + sourceTriggerDetails(source).toMap + } + + // =========== Other methods =========== + + private def registerGauge[T](name: String, f: () => T)(implicit num: Numeric[T]): Unit = { + synchronized { + metricRegistry.register(name, new Gauge[T] { + override def getValue: T = f() + }) + } + } + + def stop(): Unit = synchronized { + triggerDetails.clear() + inputRates.valuesIterator.foreach { _.stop() } + processingRates.valuesIterator.foreach { _.stop() } + latency = None + } +} + +object StreamMetrics extends Logging { + /** Simple utility class to calculate rate while avoiding DivideByZero */ + class RateCalculator { + @volatile private var rate: Option[Double] = None + + def update(numRows: Long, timeGapMs: Long): Unit = { + if (timeGapMs > 0) { + rate = Some(numRows.toDouble * 1000 / timeGapMs) + } else { + rate = None + logDebug(s"Rate updates cannot with zero or negative time gap $timeGapMs") + } + } + + def currentRate: Double = rate.getOrElse(0.0) + + def stop(): Unit = { rate = None } + } + + + val TRIGGER_ID = "triggerId" + val IS_TRIGGER_ACTIVE = "isTriggerActive" + val IS_DATA_PRESENT_IN_TRIGGER = "isDataPresentInTrigger" + val STATUS_MESSAGE = "statusMessage" + + val START_TIMESTAMP = "timestamp.triggerStart" + val GET_OFFSET_TIMESTAMP = "timestamp.afterGetOffset" + val GET_BATCH_TIMESTAMP = "timestamp.afterGetBatch" + val FINISH_TIMESTAMP = "timestamp.triggerFinish" + + val GET_OFFSET_LATENCY = "latency.getOffset.total" + val GET_BATCH_LATENCY = "latency.getBatch.total" + val OFFSET_WAL_WRITE_LATENCY = "latency.offsetLogWrite" + val OPTIMIZER_LATENCY = "latency.optimizer" + val TRIGGER_LATENCY = "latency.fullTrigger" + val SOURCE_GET_OFFSET_LATENCY = 
"latency.getOffset.source" + val SOURCE_GET_BATCH_LATENCY = "latency.getBatch.source" + + val NUM_INPUT_ROWS = "numRows.input.total" + val NUM_SOURCE_INPUT_ROWS = "numRows.input.source" + def NUM_TOTAL_STATE_ROWS(aggId: Int): String = s"numRows.state.aggregation$aggId.total" + def NUM_UPDATED_STATE_ROWS(aggId: Int): String = s"numRows.state.aggregation$aggId.updated" +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala index 405a5f0387a7e..db0bd9e6bc6f0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala @@ -26,7 +26,7 @@ class StreamProgress( val baseMap: immutable.Map[Source, Offset] = new immutable.HashMap[Source, Offset]) extends scala.collection.immutable.Map[Source, Offset] { - private[sql] def toCompositeOffset(source: Seq[Source]): CompositeOffset = { + def toCompositeOffset(source: Seq[Source]): CompositeOffset = { CompositeOffset(source.map(get)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala new file mode 100644 index 0000000000000..fc2190d39da4f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.scheduler.{LiveListenerBus, SparkListener, SparkListenerEvent} +import org.apache.spark.sql.streaming.StreamingQueryListener +import org.apache.spark.util.ListenerBus + +/** + * A bus to forward events to [[StreamingQueryListener]]s. This one will send received + * [[StreamingQueryListener.Event]]s to the Spark listener bus. It also registers itself with + * Spark listener bus, so that it can receive [[StreamingQueryListener.Event]]s and dispatch them + * to StreamingQueryListener. + */ +class StreamingQueryListenerBus(sparkListenerBus: LiveListenerBus) + extends SparkListener with ListenerBus[StreamingQueryListener, StreamingQueryListener.Event] { + + import StreamingQueryListener._ + + sparkListenerBus.addListener(this) + + /** + * Post a StreamingQueryListener event to the Spark listener bus asynchronously. This event will + * be dispatched to all StreamingQueryListener in the thread of the Spark listener bus. 
+ */ + def post(event: StreamingQueryListener.Event) { + event match { + case s: QueryStartedEvent => + postToAll(s) + case _ => + sparkListenerBus.post(event) + } + } + + override def onOtherEvent(event: SparkListenerEvent): Unit = { + event match { + case e: StreamingQueryListener.Event => + postToAll(e) + case _ => + } + } + + override protected def doPostEvent( + listener: StreamingQueryListener, + event: StreamingQueryListener.Event): Unit = { + event match { + case queryStarted: QueryStartedEvent => + listener.onQueryStarted(queryStarted) + case queryProgress: QueryProgressEvent => + listener.onQueryProgress(queryProgress) + case queryTerminated: QueryTerminatedEvent => + listener.onQueryTerminated(queryTerminated) + case _ => + } + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala index 4d65d2f4f57fc..e8b00094add3a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala @@ -17,8 +17,11 @@ package org.apache.spark.sql.execution.streaming +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LeafNode +import org.apache.spark.sql.execution.LeafExecNode import org.apache.spark.sql.execution.datasources.DataSource object StreamingRelation { @@ -50,6 +53,17 @@ case class StreamingExecutionRelation(source: Source, output: Seq[Attribute]) ex override def toString: String = source.toString } +/** + * A dummy physical plan for [[StreamingRelation]] to support + * [[org.apache.spark.sql.Dataset.explain]] + */ +case class StreamingRelationExec(sourceName: String, output: Seq[Attribute]) extends LeafExecNode { + override def toString: String = sourceName + override protected def doExecute(): RDD[InternalRow] = { + throw new UnsupportedOperationException("StreamingRelationExec cannot be executed") + } +} + object StreamingExecutionRelation { def apply(source: Source): StreamingExecutionRelation = { StreamingExecutionRelation(source, source.schema.toAttributes) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala index a1132d510685c..ac510df209f0a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.internal.Logging -import org.apache.spark.sql.ProcessingTime +import org.apache.spark.sql.streaming.ProcessingTime import org.apache.spark.util.{Clock, SystemClock} trait TriggerExecutor { @@ -65,8 +65,13 @@ case class ProcessingTimeExecutor(processingTime: ProcessingTime, clock: Clock = s"${intervalMs} milliseconds, but spent ${realElapsedTimeMs} milliseconds") } - /** Return the next multiple of intervalMs */ + /** + * Returns the start time in milliseconds for the next batch interval, given the current time. + * Note that a batch interval is inclusive with respect to its start time, and thus calling + * `nextBatchTime` with the result of a previous call should return the next interval. (i.e. 
given + * an interval of `100 ms`, `nextBatchTime(nextBatchTime(0)) = 200` rather than `0`). + */ def nextBatchTime(now: Long): Long = { - (now - 1) / intervalMs * intervalMs + intervalMs + now / intervalMs * intervalMs + intervalMs } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala new file mode 100644 index 0000000000000..e8b9712d19cd5 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider} +import org.apache.spark.sql.streaming.OutputMode + +class ConsoleSink(options: Map[String, String]) extends Sink with Logging { + // Number of rows to display, by default 20 rows + private val numRowsToShow = options.get("numRows").map(_.toInt).getOrElse(20) + + // Truncate the displayed data if it is too long, by default it is true + private val isTruncated = options.get("truncate").map(_.toBoolean).getOrElse(true) + + // Track the batch id + private var lastBatchId = -1L + + override def addBatch(batchId: Long, data: DataFrame): Unit = synchronized { + val batchIdStr = if (batchId <= lastBatchId) { + s"Rerun batch: $batchId" + } else { + lastBatchId = batchId + s"Batch: $batchId" + } + + // scalastyle:off println + println("-------------------------------------------") + println(batchIdStr) + println("-------------------------------------------") + // scalastyle:off println + data.sparkSession.createDataFrame( + data.sparkSession.sparkContext.parallelize(data.collect()), data.schema) + .show(numRowsToShow, isTruncated) + } +} + +class ConsoleSinkProvider extends StreamSinkProvider with DataSourceRegister { + def createSink( + sqlContext: SQLContext, + parameters: Map[String, String], + partitionColumns: Seq[String], + outputMode: OutputMode): Sink = { + new ConsoleSink(parameters) + } + + def shortName(): String = "console" +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index a34927ff994a4..66dc20471b3c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -20,15 +20,17 @@ package org.apache.spark.sql.execution.streaming import java.util.concurrent.atomic.AtomicInteger import javax.annotation.concurrent.GuardedBy -import 
scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.{ArrayBuffer, ListBuffer} import scala.util.control.NonFatal import org.apache.spark.internal.Logging -import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Row, SQLContext} +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LeafNode +import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType +import org.apache.spark.util.Utils object MemoryStream { protected val currentBlockId = new AtomicInteger(0) @@ -49,12 +51,23 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) protected val logicalPlan = StreamingExecutionRelation(this) protected val output = logicalPlan.output + /** + * All batches from `lastCommittedOffset + 1` to `currentOffset`, inclusive. + * Stored in a ListBuffer to facilitate removing committed batches. + */ @GuardedBy("this") - protected val batches = new ArrayBuffer[Dataset[A]] + protected val batches = new ListBuffer[Dataset[A]] @GuardedBy("this") protected var currentOffset: LongOffset = new LongOffset(-1) + /** + * Last offset that was discarded, or -1 if no commits have occurred. Note that the value + * -1 is used in calculations below and isn't just an arbitrary constant. + */ + @GuardedBy("this") + protected var lastOffsetCommitted : LongOffset = new LongOffset(-1) + def schema: StructType = encoder.schema def toDS()(implicit sqlContext: SQLContext): Dataset[A] = { @@ -80,24 +93,28 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) } } - override def toString: String = s"MemoryStream[${output.mkString(",")}]" + override def toString: String = s"MemoryStream[${Utils.truncatedString(output, ",")}]" override def getOffset: Option[Offset] = synchronized { - if (batches.isEmpty) { + if (currentOffset.offset == -1) { None } else { Some(currentOffset) } } - /** - * Returns the data that is between the offsets (`start`, `end`]. - */ override def getBatch(start: Option[Offset], end: Offset): DataFrame = { + // Compute the internal batch numbers to fetch: [startOrdinal, endOrdinal) val startOrdinal = start.map(_.asInstanceOf[LongOffset]).getOrElse(LongOffset(-1)).offset.toInt + 1 val endOrdinal = end.asInstanceOf[LongOffset].offset.toInt + 1 - val newBlocks = synchronized { batches.slice(startOrdinal, endOrdinal) } + + // Internal buffer only holds the batches after lastCommittedOffset. 
+ val newBlocks = synchronized { + val sliceStart = startOrdinal - lastOffsetCommitted.offset.toInt - 1 + val sliceEnd = endOrdinal - lastOffsetCommitted.offset.toInt - 1 + batches.slice(sliceStart, sliceEnd) + } logDebug( s"MemoryBatch [$startOrdinal, $endOrdinal]: ${newBlocks.flatMap(_.collect()).mkString(", ")}") @@ -108,41 +125,86 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) sys.error("No data selected!") } } + + override def commit(end: Offset): Unit = synchronized { + end match { + case newOffset: LongOffset => + val offsetDiff = (newOffset.offset - lastOffsetCommitted.offset).toInt + + if (offsetDiff < 0) { + sys.error(s"Offsets committed out of order: $lastOffsetCommitted followed by $end") + } + + batches.trimStart(offsetDiff) + lastOffsetCommitted = newOffset + case _ => + sys.error(s"MemoryStream.commit() received an offset ($end) that did not originate with " + + "an instance of this class") + } + } + + override def stop() {} + + def reset(): Unit = synchronized { + batches.clear() + currentOffset = new LongOffset(-1) + lastOffsetCommitted = new LongOffset(-1) + } } /** * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit * tests and does not provide durability. */ -class MemorySink(val schema: StructType) extends Sink with Logging { +class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink with Logging { + + private case class AddedData(batchId: Long, data: Array[Row]) + /** An order list of batches that have been written to this [[Sink]]. */ @GuardedBy("this") - private val batches = new ArrayBuffer[Array[Row]]() + private val batches = new ArrayBuffer[AddedData]() /** Returns all rows that are stored in this [[Sink]]. */ def allData: Seq[Row] = synchronized { - batches.flatten + batches.map(_.data).flatten + } + + def latestBatchId: Option[Long] = synchronized { + batches.lastOption.map(_.batchId) } - def lastBatch: Seq[Row] = synchronized { batches.last } + def latestBatchData: Seq[Row] = synchronized { batches.lastOption.toSeq.flatten(_.data) } def toDebugString: String = synchronized { - batches.zipWithIndex.map { case (b, i) => - val dataStr = try b.mkString(" ") catch { + batches.map { case AddedData(batchId, data) => + val dataStr = try data.mkString(" ") catch { case NonFatal(e) => "[Error converting to string]" } - s"$i: $dataStr" + s"$batchId: $dataStr" }.mkString("\n") } override def addBatch(batchId: Long, data: DataFrame): Unit = synchronized { - if (batchId == batches.size) { - logDebug(s"Committing batch $batchId") - batches.append(data.collect()) + if (latestBatchId.isEmpty || batchId > latestBatchId.get) { + logDebug(s"Committing batch $batchId to $this") + outputMode match { + case InternalOutputModes.Append | InternalOutputModes.Update => + batches.append(AddedData(batchId, data.collect())) + + case InternalOutputModes.Complete => + batches.clear() + batches.append(AddedData(batchId, data.collect())) + + case _ => + throw new IllegalArgumentException( + s"Output mode $outputMode is not supported by MemorySink") + } } else { logDebug(s"Skipping already committed batch: $batchId") } } + + override def toString(): String = "MemorySink" } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala new file mode 100644 index 0000000000000..c662e7c6bc775 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala @@ -0,0 
+1,219 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.io.{BufferedReader, InputStreamReader, IOException} +import java.net.Socket +import java.sql.Timestamp +import java.text.SimpleDateFormat +import java.util.Calendar +import javax.annotation.concurrent.GuardedBy + +import scala.collection.mutable.ListBuffer +import scala.util.{Failure, Success, Try} + +import org.apache.spark.internal.Logging +import org.apache.spark.sql._ +import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider} +import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} + + +object TextSocketSource { + val SCHEMA_REGULAR = StructType(StructField("value", StringType) :: Nil) + val SCHEMA_TIMESTAMP = StructType(StructField("value", StringType) :: + StructField("timestamp", TimestampType) :: Nil) + val DATE_FORMAT = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") +} + +/** + * A source that reads text lines through a TCP socket, designed only for tutorials and debugging. + * This source will *not* work in production applications due to multiple reasons, including no + * support for fault recovery and keeping all of the text read in memory forever. + */ +class TextSocketSource(host: String, port: Int, includeTimestamp: Boolean, sqlContext: SQLContext) + extends Source with Logging +{ + @GuardedBy("this") + private var socket: Socket = null + + @GuardedBy("this") + private var readThread: Thread = null + + /** + * All batches from `lastCommittedOffset + 1` to `currentOffset`, inclusive. + * Stored in a ListBuffer to facilitate removing committed batches. 
+ */ + @GuardedBy("this") + protected val batches = new ListBuffer[(String, Timestamp)] + + @GuardedBy("this") + protected var currentOffset: LongOffset = new LongOffset(-1) + + @GuardedBy("this") + protected var lastOffsetCommitted : LongOffset = new LongOffset(-1) + + initialize() + + private def initialize(): Unit = synchronized { + socket = new Socket(host, port) + val reader = new BufferedReader(new InputStreamReader(socket.getInputStream)) + readThread = new Thread(s"TextSocketSource($host, $port)") { + setDaemon(true) + + override def run(): Unit = { + try { + while (true) { + val line = reader.readLine() + if (line == null) { + // End of file reached + logWarning(s"Stream closed by $host:$port") + return + } + TextSocketSource.this.synchronized { + val newData = (line, + Timestamp.valueOf( + TextSocketSource.DATE_FORMAT.format(Calendar.getInstance().getTime())) + ) + currentOffset = currentOffset + 1 + batches.append(newData) + } + } + } catch { + case e: IOException => + } + } + } + readThread.start() + } + + /** Returns the schema of the data from this source */ + override def schema: StructType = if (includeTimestamp) TextSocketSource.SCHEMA_TIMESTAMP + else TextSocketSource.SCHEMA_REGULAR + + override def getOffset: Option[Offset] = synchronized { + if (currentOffset.offset == -1) { + None + } else { + Some(currentOffset) + } + } + + /** Returns the data that is between the offsets (`start`, `end`]. */ + override def getBatch(start: Option[Offset], end: Offset): DataFrame = synchronized { + val startOrdinal = + start.map(_.asInstanceOf[LongOffset]).getOrElse(LongOffset(-1)).offset.toInt + 1 + val endOrdinal = end.asInstanceOf[LongOffset].offset.toInt + 1 + + // Internal buffer only holds the batches after lastOffsetCommitted + val rawList = synchronized { + val sliceStart = startOrdinal - lastOffsetCommitted.offset.toInt - 1 + val sliceEnd = endOrdinal - lastOffsetCommitted.offset.toInt - 1 + batches.slice(sliceStart, sliceEnd) + } + + import sqlContext.implicits._ + val rawBatch = sqlContext.createDataset(rawList) + + // Underlying MemoryStream has schema (String, Timestamp); strip out the timestamp + // if requested. + if (includeTimestamp) { + rawBatch.toDF("value", "timestamp") + } else { + // Strip out timestamp + rawBatch.select("_1").toDF("value") + } + } + + override def commit(end: Offset): Unit = synchronized { + if (end.isInstanceOf[LongOffset]) { + val newOffset = end.asInstanceOf[LongOffset] + val offsetDiff = (newOffset.offset - lastOffsetCommitted.offset).toInt + + if (offsetDiff < 0) { + sys.error(s"Offsets committed out of order: $lastOffsetCommitted followed by $end") + } + + batches.trimStart(offsetDiff) + lastOffsetCommitted = newOffset + } else { + sys.error(s"TextSocketStream.commit() received an offset ($end) that did not " + + s"originate with an instance of this class") + } + } + + /** Stop this source. */ + override def stop(): Unit = synchronized { + if (socket != null) { + try { + // Unfortunately, BufferedReader.readLine() cannot be interrupted, so the only way to + // stop the readThread is to close the socket. 
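Both `MemoryStream` and `TextSocketSource` keep only uncommitted batches in memory and translate `(start, end]` offsets into buffer indices relative to `lastOffsetCommitted`. The index arithmetic is easier to see in a stripped-down sketch with `String` batches and plain `Long` offsets; the class below is illustrative, not either of the real sources.

```scala
import scala.collection.mutable.ListBuffer

class ToyBufferedSource {
  private val batches = ListBuffer[String]()   // holds only uncommitted batches
  private var currentOffset = -1L              // offset of the latest batch
  private var lastOffsetCommitted = -1L        // -1 means nothing committed yet

  def add(data: String): Unit = synchronized {
    currentOffset += 1
    batches += data
  }

  /** Returns the data in the offset range (start, end], like Source.getBatch. */
  def getBatch(start: Option[Long], end: Long): Seq[String] = synchronized {
    val startOrdinal = start.getOrElse(-1L).toInt + 1
    val endOrdinal = end.toInt + 1
    // The buffer begins at lastOffsetCommitted + 1, so shift ordinals accordingly.
    val sliceStart = startOrdinal - lastOffsetCommitted.toInt - 1
    val sliceEnd = endOrdinal - lastOffsetCommitted.toInt - 1
    batches.slice(sliceStart, sliceEnd).toList
  }

  /** Drops batches up to and including `end` once the engine has checkpointed them. */
  def commit(end: Long): Unit = synchronized {
    val offsetDiff = (end - lastOffsetCommitted).toInt
    require(offsetDiff >= 0, s"offsets committed out of order: $lastOffsetCommitted then $end")
    batches.trimStart(offsetDiff)
    lastOffsetCommitted = end
  }
}
```

For example, after three `add` calls and `commit(1)`, the buffer holds one element and `getBatch(Some(1), 2)` maps to `slice(0, 1)` over what remains.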
+ socket.close() + } catch { + case e: IOException => + } + socket = null + } + } +} + +class TextSocketSourceProvider extends StreamSourceProvider with DataSourceRegister with Logging { + private def parseIncludeTimestamp(params: Map[String, String]): Boolean = { + Try(params.getOrElse("includeTimestamp", "false").toBoolean) match { + case Success(bool) => bool + case Failure(_) => + throw new AnalysisException("includeTimestamp must be set to either \"true\" or \"false\"") + } + } + + /** Returns the name and schema of the source that can be used to continually read data. */ + override def sourceSchema( + sqlContext: SQLContext, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): (String, StructType) = { + logWarning("The socket source should not be used for production applications! " + + "It does not support recovery.") + if (!parameters.contains("host")) { + throw new AnalysisException("Set a host to read from with option(\"host\", ...).") + } + if (!parameters.contains("port")) { + throw new AnalysisException("Set a port to read from with option(\"port\", ...).") + } + val schema = + if (parseIncludeTimestamp(parameters)) { + TextSocketSource.SCHEMA_TIMESTAMP + } else { + TextSocketSource.SCHEMA_REGULAR + } + ("textSocket", schema) + } + + override def createSource( + sqlContext: SQLContext, + metadataPath: String, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): Source = { + val host = parameters("host") + val port = parameters("port").toInt + new TextSocketSource(host, port, parseIncludeTimestamp(parameters), sqlContext) + } + + /** String that represents the format that this data source provider uses. */ + override def shortName(): String = "socket" +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 3335755fd3b67..bbbb6d39c63f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -159,7 +159,7 @@ private[state] class HDFSBackedStateStoreProvider( } catch { case NonFatal(e) => throw new IllegalStateException( - s"Error committing version $newVersion into ${HDFSBackedStateStoreProvider.this}", e) + s"Error committing version $newVersion into $this", e) } } @@ -197,12 +197,18 @@ private[state] class HDFSBackedStateStoreProvider( allUpdates.values().asScala.toIterator } + override def numKeys(): Long = mapToUpdate.size() + /** * Whether all updates have been committed */ override private[state] def hasCommitted: Boolean = { state == COMMITTED } + + override def toString(): String = { + s"HDFSStateStore[id = (op=${id.operatorId}, part=${id.partitionId}), dir = $baseDir]" + } } /** Get the state store for making updates to create a new `version` of the store. 
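The `TextSocketSourceProvider` and `ConsoleSink` added in this patch are intended for tutorials and debugging. Assuming something like `nc -lk 9999` is producing lines locally, a typical way to wire the two together from user code is sketched below; `numRows` and `truncate` are the console options defined above.

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("socket-to-console").master("local[2]").getOrCreate()

// Read lines from a local TCP socket; "host" and "port" are required options.
val lines = spark.readStream
  .format("socket")
  .option("host", "localhost")
  .option("port", 9999)
  .load()

// Print each micro-batch to stdout via the console sink.
val query = lines.writeStream
  .format("console")
  .option("numRows", 50)
  .option("truncate", false)
  .outputMode("append")
  .start()

query.awaitTermination()
```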
*/ @@ -213,7 +219,7 @@ private[state] class HDFSBackedStateStoreProvider( newMap.putAll(loadMap(version)) } val store = new HDFSBackedStateStore(version, newMap) - logInfo(s"Retrieved version $version of $this for update") + logInfo(s"Retrieved version $version of ${HDFSBackedStateStoreProvider.this} for update") store } @@ -229,7 +235,7 @@ private[state] class HDFSBackedStateStoreProvider( } override def toString(): String = { - s"StateStore[id = (op=${id.operatorId}, part=${id.partitionId}), dir = $baseDir]" + s"HDFSStateStoreProvider[id = (op=${id.operatorId}, part=${id.partitionId}), dir = $baseDir]" } /* Internal classes and methods */ @@ -489,10 +495,12 @@ private[state] class HDFSBackedStateStoreProvider( val mapsToRemove = loadedMaps.keys.filter(_ < earliestVersionToRetain).toSeq mapsToRemove.foreach(loadedMaps.remove) } - files.filter(_.version < earliestFileToRetain.version).foreach { f => + val filesToDelete = files.filter(_.version < earliestFileToRetain.version) + filesToDelete.foreach { f => fs.delete(f.path, true) } - logInfo(s"Deleted files older than ${earliestFileToRetain.version} for $this") + logInfo(s"Deleted files older than ${earliestFileToRetain.version} for $this: " + + filesToDelete.mkString(", ")) } } } catch { @@ -556,7 +564,7 @@ private[state] class HDFSBackedStateStoreProvider( } } val storeFiles = versionToFiles.values.toSeq.sortBy(_.version) - logDebug(s"Current set of files for $this: $storeFiles") + logDebug(s"Current set of files for $this: ${storeFiles.mkString(", ")}") storeFiles } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 952150632519e..7132e284c28f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -77,6 +77,9 @@ trait StateStore { */ def updates(): Iterator[StoreUpdate] + /** Number of keys in the state store */ + def numKeys(): Long + /** * Whether all updates have been committed */ @@ -113,9 +116,9 @@ case class KeyRemoved(key: UnsafeRow) extends StoreUpdate * the store is the active instance. Accordingly, it either keeps it loaded and performs * maintenance, or unloads the store. 
*/ -private[state] object StateStore extends Logging { +object StateStore extends Logging { - val MAINTENANCE_INTERVAL_CONFIG = "spark.streaming.stateStore.maintenanceInterval" + val MAINTENANCE_INTERVAL_CONFIG = "spark.sql.streaming.stateStore.maintenanceInterval" val MAINTENANCE_INTERVAL_DEFAULT_SECS = 60 private val loadedProviders = new mutable.HashMap[StateStoreId, StateStoreProvider]() @@ -155,6 +158,10 @@ private[state] object StateStore extends Logging { loadedProviders.contains(storeId) } + def isMaintenanceRunning: Boolean = loadedProviders.synchronized { + maintenanceTask != null + } + /** Unload and stop all state store providers */ def stop(): Unit = loadedProviders.synchronized { loadedProviders.clear() @@ -187,44 +194,44 @@ private[state] object StateStore extends Logging { */ private def doMaintenance(): Unit = { logDebug("Doing maintenance") - loadedProviders.synchronized { loadedProviders.toSeq }.foreach { case (id, provider) => - try { - if (verifyIfStoreInstanceActive(id)) { - provider.doMaintenance() - } else { - unload(id) - logInfo(s"Unloaded $provider") + if (SparkEnv.get == null) { + stop() + } else { + loadedProviders.synchronized { loadedProviders.toSeq }.foreach { case (id, provider) => + try { + if (verifyIfStoreInstanceActive(id)) { + provider.doMaintenance() + } else { + unload(id) + logInfo(s"Unloaded $provider") + } + } catch { + case NonFatal(e) => + logWarning(s"Error managing $provider, stopping management thread") + stop() } - } catch { - case NonFatal(e) => - logWarning(s"Error managing $provider") } } } private def reportActiveStoreInstance(storeId: StateStoreId): Unit = { - try { + if (SparkEnv.get != null) { val host = SparkEnv.get.blockManager.blockManagerId.host val executorId = SparkEnv.get.blockManager.blockManagerId.executorId coordinatorRef.foreach(_.reportActiveInstance(storeId, host, executorId)) logDebug(s"Reported that the loaded instance $storeId is active") - } catch { - case NonFatal(e) => - logWarning(s"Error reporting active instance of $storeId") } } private def verifyIfStoreInstanceActive(storeId: StateStoreId): Boolean = { - try { + if (SparkEnv.get != null) { val executorId = SparkEnv.get.blockManager.blockManagerId.executorId val verified = coordinatorRef.map(_.verifyIfInstanceActive(storeId, executorId)).getOrElse(false) - logDebug(s"Verified whether the loaded instance $storeId is active: $verified" ) + logDebug(s"Verified whether the loaded instance $storeId is active: $verified") verified - } catch { - case NonFatal(e) => - logWarning(s"Error verifying active instance of $storeId") - false + } else { + false } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala index e418217238cca..267d17623d5e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala @@ -38,14 +38,14 @@ private case class VerifyIfInstanceActive(storeId: StateStoreId, executorId: Str private case class GetLocation(storeId: StateStoreId) extends StateStoreCoordinatorMessage -private case class DeactivateInstances(storeRootLocation: String) +private case class DeactivateInstances(checkpointLocation: String) extends StateStoreCoordinatorMessage private object StopCoordinator extends StateStoreCoordinatorMessage /** Helper object used to create 
reference to [[StateStoreCoordinator]]. */ -private[sql] object StateStoreCoordinatorRef extends Logging { +object StateStoreCoordinatorRef extends Logging { private val endpointName = "StateStoreCoordinator" @@ -77,7 +77,7 @@ private[sql] object StateStoreCoordinatorRef extends Logging { * Reference to a [[StateStoreCoordinator]] that can be used to coordinate instances of * [[StateStore]]s across all the executors, and get their locations for job scheduling. */ -private[sql] class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) { +class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) { private[state] def reportActiveInstance( storeId: StateStoreId, @@ -111,11 +111,13 @@ private[sql] class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointR * Class for coordinating instances of [[StateStore]]s loaded in executors across the cluster, * and get their locations for job scheduling. */ -private class StateStoreCoordinator(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint { +private class StateStoreCoordinator(override val rpcEnv: RpcEnv) + extends ThreadSafeRpcEndpoint with Logging { private val instances = new mutable.HashMap[StateStoreId, ExecutorCacheTaskLocation] override def receive: PartialFunction[Any, Unit] = { case ReportActiveInstance(id, host, executorId) => + logDebug(s"Reported state store $id is active at $executorId") instances.put(id, ExecutorCacheTaskLocation(host, executorId)) } @@ -125,19 +127,25 @@ private class StateStoreCoordinator(override val rpcEnv: RpcEnv) extends ThreadS case Some(location) => location.executorId == execId case None => false } + logDebug(s"Verified that state store $id is active: $response") context.reply(response) case GetLocation(id) => - context.reply(instances.get(id).map(_.toString)) + val executorId = instances.get(id).map(_.toString) + logDebug(s"Got location of the state store $id: $executorId") + context.reply(executorId) - case DeactivateInstances(loc) => + case DeactivateInstances(checkpointLocation) => val storeIdsToRemove = - instances.keys.filter(_.checkpointLocation == loc).toSeq + instances.keys.filter(_.checkpointLocation == checkpointLocation).toSeq instances --= storeIdsToRemove + logDebug(s"Deactivating instances related to checkpoint location $checkpointLocation: " + + storeIdsToRemove.mkString(", ")) context.reply(true) case StopCoordinator => stop() // Stop before replying to ensure that endpoint name has been deregistered + logInfo("StateStoreCoordinator stopped") context.reply(true) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala index 4a1f12d68519a..461d3010ada7e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.types.DataType * This is the physical copy of ScalarSubquery to be used inside SparkPlan. 
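The coordinator's bookkeeping shown above is essentially a map from store id to the executor that last reported it, plus bulk deactivation by checkpoint location. A toy, non-RPC version of the same state machine (`ToyStoreId` and `ToyCoordinator` are illustrative names, not the real endpoint classes):

```scala
import scala.collection.mutable

case class ToyStoreId(checkpointLocation: String, operatorId: Long, partitionId: Int)

class ToyCoordinator {
  private val instances = mutable.HashMap[ToyStoreId, String]()   // store id -> executor id

  def reportActiveInstance(id: ToyStoreId, executorId: String): Unit =
    instances(id) = executorId

  def verifyIfInstanceActive(id: ToyStoreId, executorId: String): Boolean =
    instances.get(id).contains(executorId)

  def getLocation(id: ToyStoreId): Option[String] = instances.get(id)

  // Deactivate every store that belongs to a given query's checkpoint location.
  def deactivateInstances(checkpointLocation: String): Unit = {
    val toRemove = instances.keys.filter(_.checkpointLocation == checkpointLocation).toSeq
    instances --= toRemove
  }
}
```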
*/ case class ScalarSubquery( - @transient executedPlan: SparkPlan, + executedPlan: SparkPlan, exprId: ExprId) extends SubqueryExpression { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala index 4b4fa126b85f3..23fc0bd0bce13 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala @@ -24,7 +24,7 @@ import scala.xml.Node import org.apache.spark.internal.Logging import org.apache.spark.ui.{UIUtils, WebUIPage} -private[sql] class ExecutionPage(parent: SQLTab) extends WebUIPage("execution") with Logging { +class ExecutionPage(parent: SQLTab) extends WebUIPage("execution") with Logging { private val listener = parent.listener diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala index 9118593c0e4ce..60f13432d78d2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala @@ -26,6 +26,7 @@ import org.apache.spark.scheduler._ import org.apache.spark.sql.execution.{SparkPlanInfo, SQLExecution} import org.apache.spark.sql.execution.metric._ import org.apache.spark.ui.SparkUI +import org.apache.spark.util.AccumulatorContext @DeveloperApi case class SparkListenerSQLExecutionStart( @@ -41,14 +42,18 @@ case class SparkListenerSQLExecutionStart( case class SparkListenerSQLExecutionEnd(executionId: Long, time: Long) extends SparkListenerEvent -private[sql] class SQLHistoryListenerFactory extends SparkHistoryListenerFactory { +@DeveloperApi +case class SparkListenerDriverAccumUpdates(executionId: Long, accumUpdates: Seq[(Long, Long)]) + extends SparkListenerEvent + +class SQLHistoryListenerFactory extends SparkHistoryListenerFactory { override def createListeners(conf: SparkConf, sparkUI: SparkUI): Seq[SparkListener] = { List(new SQLHistoryListener(conf, sparkUI)) } } -private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Logging { +class SQLListener(conf: SparkConf) extends SparkListener with Logging { private val retainedExecutions = conf.getInt("spark.sql.ui.retainedExecutions", 1000) @@ -164,7 +169,7 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi taskEnd.taskInfo.taskId, taskEnd.stageId, taskEnd.stageAttemptId, - taskEnd.taskMetrics.accumulators().map(a => a.toInfo(Some(a.localValue), None)), + taskEnd.taskMetrics.externalAccums.map(a => a.toInfo(Some(a.value), None)), finishTask = true) } } @@ -177,8 +182,10 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi taskId: Long, stageId: Int, stageAttemptID: Int, - accumulatorUpdates: Seq[AccumulableInfo], + _accumulatorUpdates: Seq[AccumulableInfo], finishTask: Boolean): Unit = { + val accumulatorUpdates = + _accumulatorUpdates.filter(_.update.isDefined).map(accum => (accum.id, accum.update.get)) _stageIdToStageMetrics.get(stageId) match { case Some(stageMetrics) => @@ -248,6 +255,13 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi } } } + case SparkListenerDriverAccumUpdates(executionId, accumUpdates) => synchronized { + _executionIdToData.get(executionId).foreach { executionUIData => + for ((accId, accValue) <- accumUpdates) { + executionUIData.driverAccumUpdates(accId) = 
accValue + } + } + } case _ => // Ignore } @@ -290,12 +304,12 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi stageMetrics <- _stageIdToStageMetrics.get(stageId).toIterable; taskMetrics <- stageMetrics.taskIdToMetricUpdates.values; accumulatorUpdate <- taskMetrics.accumulatorUpdates) yield { - assert(accumulatorUpdate.update.isDefined, s"accumulator update from " + - s"task did not have a partial value: ${accumulatorUpdate.name}") - (accumulatorUpdate.id, accumulatorUpdate.update.get) + (accumulatorUpdate._1, accumulatorUpdate._2) } }.filter { case (id, _) => executionUIData.accumulatorMetrics.contains(id) } - mergeAccumulatorUpdates(accumulatorUpdates, accumulatorId => + + val driverUpdates = executionUIData.driverAccumUpdates.toSeq + mergeAccumulatorUpdates(accumulatorUpdates ++ driverUpdates, accumulatorId => executionUIData.accumulatorMetrics(accumulatorId).metricType) case None => // This execution has been dropped @@ -319,7 +333,7 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi /** * A [[SQLListener]] for rendering the SQL UI in the history server. */ -private[spark] class SQLHistoryListener(conf: SparkConf, sparkUI: SparkUI) +class SQLHistoryListener(conf: SparkConf, sparkUI: SparkUI) extends SQLListener(conf) { private var sqlTabAttached = false @@ -336,7 +350,7 @@ private[spark] class SQLHistoryListener(conf: SparkConf, sparkUI: SparkUI) taskEnd.taskInfo.accumulables.flatMap { a => // Filter out accumulators that are not SQL metrics // For now we assume all SQL metrics are Long's that have been JSON serialized as String's - if (a.metadata == Some(SQLMetrics.ACCUM_IDENTIFIER)) { + if (a.metadata == Some(AccumulatorContext.SQL_ACCUM_IDENTIFIER)) { val newValue = a.update.map(_.toString.toLong).getOrElse(0L) Some(a.copy(update = Some(newValue))) } else { @@ -367,10 +381,15 @@ private[ui] class SQLExecutionUIData( val physicalPlanDescription: String, val physicalPlanGraph: SparkPlanGraph, val accumulatorMetrics: Map[Long, SQLPlanMetric], - val submissionTime: Long, - var completionTime: Option[Long] = None, - val jobs: mutable.HashMap[Long, JobExecutionStatus] = mutable.HashMap.empty, - val stages: mutable.ArrayBuffer[Int] = mutable.ArrayBuffer()) { + val submissionTime: Long) { + + var completionTime: Option[Long] = None + + val jobs: mutable.HashMap[Long, JobExecutionStatus] = mutable.HashMap.empty + + val stages: mutable.ArrayBuffer[Int] = mutable.ArrayBuffer() + + val driverAccumUpdates: mutable.HashMap[Long, Long] = mutable.HashMap.empty /** * Return whether there are running jobs in this execution. 
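Because `SparkListenerDriverAccumUpdates` is a public `@DeveloperApi` case class, external listeners can also observe driver-side metric updates. A minimal sketch, assuming you register the listener yourself; the class name and the logging behaviour are made up for illustration:

```scala
import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}
import org.apache.spark.sql.execution.ui.SparkListenerDriverAccumUpdates

// Sketch only: logs how many driver-side accumulator updates each SQL execution reports.
class DriverAccumLogger extends SparkListener {
  override def onOtherEvent(event: SparkListenerEvent): Unit = event match {
    case SparkListenerDriverAccumUpdates(executionId, accumUpdates) =>
      println(s"Execution $executionId reported ${accumUpdates.size} driver accumulator updates")
    case _ => // not interested in other events
  }
}

// Hypothetical registration on an existing SparkContext `sc`:
// sc.addSparkListener(new DriverAccumLogger)
```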
@@ -418,4 +437,4 @@ private[ui] class SQLStageMetrics( private[ui] class SQLTaskMetrics( val attemptId: Long, // TODO not used yet var finished: Boolean, - var accumulatorUpdates: Seq[AccumulableInfo]) + var accumulatorUpdates: Seq[(Long, Any)]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala index e8675ce749a2b..d0376af3e31ca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLTab.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.ui import org.apache.spark.internal.Logging import org.apache.spark.ui.{SparkUI, SparkUITab} -private[sql] class SQLTab(val listener: SQLListener, sparkUI: SparkUI) +class SQLTab(val listener: SQLListener, sparkUI: SparkUI) extends SparkUITab(sparkUI, "SQL") with Logging { val parent = sparkUI @@ -32,6 +32,6 @@ private[sql] class SQLTab(val listener: SQLListener, sparkUI: SparkUI) parent.addStaticHandler(SQLTab.STATIC_RESOURCE_DIR, "/static/sql") } -private[sql] object SQLTab { +object SQLTab { private val STATIC_RESOURCE_DIR = "org/apache/spark/sql/execution/ui/static" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala index 8f5681bfc7cc6..4bb9d6fef4c1d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala @@ -24,7 +24,7 @@ import scala.collection.mutable import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.sql.execution.{SparkPlanInfo, WholeStageCodegenExec} -import org.apache.spark.sql.execution.metric.SQLMetrics + /** * A graph used for storing information of an executionPlan of DataFrame. @@ -32,7 +32,7 @@ import org.apache.spark.sql.execution.metric.SQLMetrics * Each graph is defined with a set of nodes and a set of edges. Each node represents a node in the * SparkPlan tree, and each edge represents a parent-child relationship between two nodes. */ -private[ui] case class SparkPlanGraph( +case class SparkPlanGraph( nodes: Seq[SparkPlanGraphNode], edges: Seq[SparkPlanGraphEdge]) { def makeDotFile(metrics: Map[Long, String]): String = { @@ -55,7 +55,7 @@ private[ui] case class SparkPlanGraph( } } -private[sql] object SparkPlanGraph { +object SparkPlanGraph { /** * Build a SparkPlanGraph from the root of a SparkPlan tree. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala index baae9dd2d5e3e..51179a528c503 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala @@ -17,12 +17,14 @@ package org.apache.spark.sql.expressions +import org.apache.spark.annotation.Experimental import org.apache.spark.sql.{Dataset, Encoder, TypedColumn} import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete} import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression /** + * :: Experimental :: * A base class for user-defined aggregations, which can be used in [[Dataset]] operations to take * all of the elements of a group and reduce them to a single value. 
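For reference, a minimal sketch of subclassing the (now `@Experimental`) `Aggregator`, assuming a plain sum over `Long` inputs; the object name is a made-up example:

```scala
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

// Sketch only: a typed sum aggregator over Long values.
object LongSum extends Aggregator[Long, Long, Long] {
  def zero: Long = 0L                                   // initial buffer for each group
  def reduce(buffer: Long, value: Long): Long = buffer + value
  def merge(b1: Long, b2: Long): Long = b1 + b2         // combine partial buffers
  def finish(reduction: Long): Long = reduction
  def bufferEncoder: Encoder[Long] = Encoders.scalaLong
  def outputEncoder: Encoder[Long] = Encoders.scalaLong
}

// Used on a Dataset[Long], e.g.: ds.select(LongSum.toColumn)
```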
* @@ -48,6 +50,7 @@ import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression * @tparam OUT The type of the final output result. * @since 1.6.0 */ +@Experimental abstract class Aggregator[-IN, BUF, OUT] extends Serializable { /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala new file mode 100644 index 0000000000000..174378304d4a5 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.expressions + +import org.apache.spark.sql.Encoder +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder + +/** + * An aggregator that uses a single associative and commutative reduce function. This reduce + * function can be used to go through all input values and reduces them to a single value. + * If there is no input, a null value is returned. + * + * This class currently assumes there is at least one input row. + */ +private[sql] class ReduceAggregator[T: Encoder](func: (T, T) => T) + extends Aggregator[T, (Boolean, T), T] { + + private val encoder = implicitly[Encoder[T]] + + override def zero: (Boolean, T) = (false, null.asInstanceOf[T]) + + override def bufferEncoder: Encoder[(Boolean, T)] = + ExpressionEncoder.tuple( + ExpressionEncoder[Boolean](), + encoder.asInstanceOf[ExpressionEncoder[T]]) + + override def outputEncoder: Encoder[T] = encoder + + override def reduce(b: (Boolean, T), a: T): (Boolean, T) = { + if (b._1) { + (true, func(b._2, a)) + } else { + (true, a) + } + } + + override def merge(b1: (Boolean, T), b2: (Boolean, T)): (Boolean, T) = { + if (!b1._1) { + b2 + } else if (!b2._1) { + b1 + } else { + (true, func(b1._2, b2._2)) + } + } + + override def finish(reduction: (Boolean, T)): T = { + if (!reduction._1) { + throw new IllegalStateException("ReduceAggregator requires at least one input row") + } + reduction._2 + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index bd35d19aa20bb..49fdec57558e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -25,6 +25,9 @@ import org.apache.spark.sql.types.DataType /** * A user-defined function. To create one, use the `udf` functions in [[functions]]. + * Note that the user-defined functions must be deterministic. 
Due to optimization, + * duplicate invocations may be eliminated or the function may even be invoked more times than + * it is present in the query. * As an example: * {{{ * // Defined a UDF that returns true or false based on some numeric score. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala index 350c2836461e2..c29ec6f426789 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala @@ -74,7 +74,7 @@ object Window { spec.orderBy(cols : _*) } - private def spec: WindowSpec = { + private[sql] def spec: WindowSpec = { new WindowSpec(Seq.empty, Seq.empty, UnspecifiedFrame) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/scala/typed.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/scalalang/typed.scala similarity index 94% rename from sql/core/src/main/scala/org/apache/spark/sql/expressions/scala/typed.scala rename to sql/core/src/main/scala/org/apache/spark/sql/expressions/scalalang/typed.scala index d0eb190afd036..60d7b7d0894d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/scala/typed.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/scalalang/typed.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.expressions.scala +package org.apache.spark.sql.expressions.scalalang import org.apache.spark.annotation.Experimental import org.apache.spark.sql._ @@ -25,7 +25,7 @@ import org.apache.spark.sql.execution.aggregate._ * :: Experimental :: * Type-safe functions available for [[Dataset]] operations in Scala. * - * Java users should use [[org.apache.spark.sql.expressions.java.typed]]. + * Java users should use [[org.apache.spark.sql.expressions.javalang.typed]]. * * @since 2.0.0 */ @@ -38,7 +38,7 @@ object typed { // The reason we have separate files for Java and Scala is because in the Scala version, we can // use tighter types (primitive types) for return types, whereas in the Java version we can only // use boxed primitive types. - // For example, avg in the Scala veresion returns Scala primitive Double, whose bytecode + // For example, avg in the Scala version returns Scala primitive Double, whose bytecode // signature is just a java.lang.Object; avg in the Java version returns java.lang.Double. // TODO: This is pretty hacky. Maybe we should have an object for implicit encoders. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala index 48925910ac8cf..eac658c6176cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala @@ -133,7 +133,7 @@ abstract class UserDefinedAggregateFunction extends Serializable { /** * :: Experimental :: - * A [[Row]] representing an mutable aggregation buffer. + * A [[Row]] representing a mutable aggregation buffer. * * This is not meant to be extended outside of Spark. 
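As a usage sketch tied to the determinism note added above (the DataFrame `df` and its `score` column are assumptions):

```scala
import org.apache.spark.sql.functions.udf

// Sketch only: a deterministic UDF. Because of optimizations, Spark may invoke it more or
// fewer times than it appears in the query, so it must not rely on side effects or call counts.
val squared = udf((x: Int) => x * x)
val scored = df.withColumn("score_squared", squared(df("score")))  // assumes df has an Int column `score`
```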
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index fe63c80815a99..eb504c81bd80f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -195,18 +195,14 @@ object functions { /** * Aggregate function: returns a list of objects with duplicates. * - * For now this is an alias for the collect_list Hive UDAF. - * * @group agg_funcs * @since 1.6.0 */ - def collect_list(e: Column): Column = callUDF("collect_list", e) + def collect_list(e: Column): Column = withAggregateFunction { CollectList(e.expr) } /** * Aggregate function: returns a list of objects with duplicates. * - * For now this is an alias for the collect_list Hive UDAF. - * * @group agg_funcs * @since 1.6.0 */ @@ -215,18 +211,14 @@ object functions { /** * Aggregate function: returns a set of objects with duplicate elements eliminated. * - * For now this is an alias for the collect_set Hive UDAF. - * * @group agg_funcs * @since 1.6.0 */ - def collect_set(e: Column): Column = callUDF("collect_set", e) + def collect_set(e: Column): Column = withAggregateFunction { CollectSet(e.expr) } /** * Aggregate function: returns a set of objects with duplicate elements eliminated. * - * For now this is an alias for the collect_set Hive UDAF. - * * @group agg_funcs * @since 1.6.0 */ @@ -931,8 +923,8 @@ object functions { * @group normal_funcs * @since 1.5.0 */ - def broadcast(df: DataFrame): DataFrame = { - Dataset.ofRows(df.sparkSession, BroadcastHint(df.logicalPlan)) + def broadcast[T](df: Dataset[T]): Dataset[T] = { + Dataset[T](df.sparkSession, BroadcastHint(df.logicalPlan))(df.exprEnc) } /** @@ -986,6 +978,7 @@ object functions { * @group normal_funcs * @since 1.4.0 */ + @deprecated("Use monotonically_increasing_id()", "2.0.0") def monotonicallyIncreasingId(): Column = monotonically_increasing_id() /** @@ -1176,7 +1169,7 @@ object functions { * @group normal_funcs */ def expr(expr: String): Column = { - val parser = SQLContext.getActive().map(_.sessionState.sqlParser).getOrElse { + val parser = SparkSession.getActiveSession.map(_.sessionState.sqlParser).getOrElse { new SparkSqlParser(new SQLConf) } Column(parser.parseExpression(expr)) @@ -2004,7 +1997,7 @@ object functions { /** * Computes the numeric value of the first character of the string column, and returns the - * result as a int column. + * result as an int column. * * @group string_funcs * @since 1.5.0 @@ -2182,7 +2175,8 @@ object functions { def ltrim(e: Column): Column = withExpr {StringTrimLeft(e.expr) } /** - * Extract a specific(idx) group identified by a java regex, from the specified string column. + * Extract a specific group matched by a Java regex, from the specified string column. + * If the regex did not match, or the specified group did not match, an empty string is returned. * * @group string_funcs * @since 1.5.0 @@ -2445,7 +2439,7 @@ object functions { */ def minute(e: Column): Column = withExpr { Minute(e.expr) } - /* + /** * Returns number of months between dates `date1` and `date2`. * @group datetime_funcs * @since 1.5.0 @@ -2595,19 +2589,22 @@ object functions { * 09:00:25-09:01:25 ... * }}} * - * For a continuous query, you may use the function `current_timestamp` to generate windows on + * For a streaming query, you may use the function `current_timestamp` to generate windows on * processing time. 
* * @param timeColumn The column or the expression to use as the timestamp for windowing by time. * The time column must be of TimestampType. * @param windowDuration A string specifying the width of the window, e.g. `10 minutes`, * `1 second`. Check [[org.apache.spark.unsafe.types.CalendarInterval]] for - * valid duration identifiers. + * valid duration identifiers. Note that the duration is a fixed length of + * time, and does not vary over time according to a calendar. For example, + * `1 day` always means 86,400,000 milliseconds, not a calendar day. * @param slideDuration A string specifying the sliding interval of the window, e.g. `1 minute`. * A new window will be generated every `slideDuration`. Must be less than * or equal to the `windowDuration`. Check * [[org.apache.spark.unsafe.types.CalendarInterval]] for valid duration - * identifiers. + * identifiers. This duration is likewise absolute, and does not vary + * according to a calendar. * @param startTime The offset with respect to 1970-01-01 00:00:00 UTC with which to start * window intervals. For example, in order to have hourly tumbling windows that * start 15 minutes past the hour, e.g. 12:15-13:15, 13:15-14:15... provide @@ -2649,18 +2646,22 @@ object functions { * 09:00:20-09:01:20 ... * }}} * - * For a continuous query, you may use the function `current_timestamp` to generate windows on + * For a streaming query, you may use the function `current_timestamp` to generate windows on * processing time. * * @param timeColumn The column or the expression to use as the timestamp for windowing by time. * The time column must be of TimestampType. * @param windowDuration A string specifying the width of the window, e.g. `10 minutes`, * `1 second`. Check [[org.apache.spark.unsafe.types.CalendarInterval]] for - * valid duration identifiers. + * valid duration identifiers. Note that the duration is a fixed length of + * time, and does not vary over time according to a calendar. For example, + * `1 day` always means 86,400,000 milliseconds, not a calendar day. * @param slideDuration A string specifying the sliding interval of the window, e.g. `1 minute`. * A new window will be generated every `slideDuration`. Must be less than * or equal to the `windowDuration`. Check - * [[org.apache.spark.unsafe.types.CalendarInterval]] for valid duration. + * [[org.apache.spark.unsafe.types.CalendarInterval]] for valid duration + * identifiers. This duration is likewise absolute, and does not vary + * according to a calendar. * * @group datetime_funcs * @since 2.0.0 @@ -2691,7 +2692,7 @@ object functions { * 09:02:00-09:03:00 ... * }}} * - * For a continuous query, you may use the function `current_timestamp` to generate windows on + * For a streaming query, you may use the function `current_timestamp` to generate windows on * processing time. * * @param timeColumn The column or the expression to use as the timestamp for windowing by time. @@ -2729,6 +2730,14 @@ object functions { */ def explode(e: Column): Column = withExpr { Explode(e.expr) } + /** + * Creates a new row for each element with position in the given array or map column. + * + * @group collection_funcs + * @since 2.1.0 + */ + def posexplode(e: Column): Column = withExpr { PosExplode(e.expr) } + /** * Extracts json object from a json string based on json path specified, and returns json string * of the extracted json object. It will return null if the input json string is invalid. 
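A usage sketch for the fixed-duration windows described above (the streaming DataFrame `events` and its `timestamp` and `word` columns are assumptions):

```scala
import org.apache.spark.sql.functions.window

// Sketch only: 10-minute windows sliding every 5 minutes over the event's own timestamp column.
// Per the docs above, "10 minutes" is a fixed 600,000 ms, independent of any calendar.
val windowedCounts = events
  .groupBy(window(events("timestamp"), "10 minutes", "5 minutes"), events("word"))
  .count()
```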
@@ -2960,8 +2969,8 @@ object functions { * import org.apache.spark.sql._ * * val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value") - * val sqlContext = df.sqlContext - * sqlContext.udf.register("simpleUDF", (v: Int) => v * v) + * val spark = df.sparkSession + * spark.udf.register("simpleUDF", (v: Int) => v * v) * df.select($"id", callUDF("simpleUDF", $"value")) * }}} * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index 976c9c53de139..414a4a5ed9107 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -50,14 +50,6 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { } } - private def makeDataset[T <: DefinedByConstructorParams: TypeTag](data: Seq[T]): Dataset[T] = { - val enc = ExpressionEncoder[T]() - val encoded = data.map(d => enc.toRow(d).copy()) - val plan = new LocalRelation(enc.schema.toAttributes, encoded) - val queryExecution = sparkSession.executePlan(plan) - new Dataset[T](sparkSession, queryExecution, enc) - } - /** * Returns the current default database in this session. */ @@ -83,7 +75,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { description = metadata.description, locationUri = metadata.locationUri) } - makeDataset(databases) + CatalogImpl.makeDataset(databases, sparkSession) } /** @@ -111,7 +103,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { tableType = metadata.map(_.tableType.name).getOrElse("TEMPORARY"), isTemporary = isTemp) } - makeDataset(tables) + CatalogImpl.makeDataset(tables, sparkSession) } /** @@ -129,15 +121,16 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { @throws[AnalysisException]("database does not exist") override def listFunctions(dbName: String): Dataset[Function] = { requireDatabaseExists(dbName) - val functions = sessionCatalog.listFunctions(dbName).map { funcIdent => + val functions = sessionCatalog.listFunctions(dbName).map { case (funcIdent, _) => val metadata = sessionCatalog.lookupFunctionInfo(funcIdent) new Function( name = funcIdent.identifier, + database = funcIdent.database.orNull, description = null, // for now, this is always undefined className = metadata.getClassName, isTemporary = funcIdent.database.isEmpty) } - makeDataset(functions) + CatalogImpl.makeDataset(functions, sparkSession) } /** @@ -145,7 +138,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { */ @throws[AnalysisException]("table does not exist") override def listColumns(tableName: String): Dataset[Column] = { - listColumns(currentDatabase, tableName) + listColumns(TableIdentifier(tableName, None)) } /** @@ -154,7 +147,11 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { @throws[AnalysisException]("database or table does not exist") override def listColumns(dbName: String, tableName: String): Dataset[Column] = { requireTableExists(dbName, tableName) - val tableMetadata = sessionCatalog.getTableMetadata(TableIdentifier(tableName, Some(dbName))) + listColumns(TableIdentifier(tableName, Some(dbName))) + } + + private def listColumns(tableIdentifier: TableIdentifier): Dataset[Column] = { + val tableMetadata = sessionCatalog.getTempViewOrPermanentTableMetadata(tableIdentifier) val partitionColumnNames = tableMetadata.partitionColumnNames.toSet val bucketColumnNames = tableMetadata.bucketColumnNames.toSet val columns = 
tableMetadata.schema.map { c => @@ -166,7 +163,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { isPartition = partitionColumnNames.contains(c.name), isBucket = bucketColumnNames.contains(c.name)) } - makeDataset(columns) + CatalogImpl.makeDataset(columns, sparkSession) } /** @@ -233,10 +230,12 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { userSpecifiedSchema = None, source, temporary = false, - options, + options = options, + partitionColumns = Array.empty[String], + bucketSpec = None, allowExisting = false, managedIfNoPath = false) - sparkSession.executePlan(cmd).toRdd + sparkSession.sessionState.executePlan(cmd).toRdd sparkSession.table(tableIdent) } @@ -280,23 +279,27 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { source, temporary = false, options, + partitionColumns = Array.empty[String], + bucketSpec = None, allowExisting = false, managedIfNoPath = false) - sparkSession.executePlan(cmd).toRdd + sparkSession.sessionState.executePlan(cmd).toRdd sparkSession.table(tableIdent) } /** - * Drops the temporary table with the given table name in the catalog. - * If the table has been cached/persisted before, it's also unpersisted. + * Drops the temporary view with the given view name in the catalog. + * If the view has been cached/persisted before, it's also unpersisted. * - * @param tableName the name of the table to be unregistered. + * @param viewName the name of the view to be dropped. * @group ddl_ops * @since 2.0.0 */ - override def dropTempTable(tableName: String): Unit = { - sparkSession.cacheManager.tryUncacheQuery(sparkSession.table(tableName)) - sessionCatalog.dropTable(TableIdentifier(tableName), ignoreIfNotExists = true) + override def dropTempView(viewName: String): Unit = { + sparkSession.sessionState.catalog.getTempView(viewName).foreach { tempView => + sparkSession.sharedState.cacheManager.uncacheQuery(Dataset.ofRows(sparkSession, tempView)) + sessionCatalog.dropTempView(viewName) + } } /** @@ -306,7 +309,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { * @since 2.0.0 */ override def isCached(tableName: String): Boolean = { - sparkSession.cacheManager.lookupCachedData(sparkSession.table(tableName)).nonEmpty + sparkSession.sharedState.cacheManager.lookupCachedData(sparkSession.table(tableName)).nonEmpty } /** @@ -316,7 +319,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { * @since 2.0.0 */ override def cacheTable(tableName: String): Unit = { - sparkSession.cacheManager.cacheQuery(sparkSession.table(tableName), Some(tableName)) + sparkSession.sharedState.cacheManager.cacheQuery(sparkSession.table(tableName), Some(tableName)) } /** @@ -326,7 +329,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { * @since 2.0.0 */ override def uncacheTable(tableName: String): Unit = { - sparkSession.cacheManager.uncacheQuery(sparkSession.table(tableName)) + sparkSession.sharedState.cacheManager.uncacheQuery(query = sparkSession.table(tableName)) } /** @@ -336,7 +339,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { * @since 2.0.0 */ override def clearCache(): Unit = { - sparkSession.cacheManager.clearCache() + sparkSession.sharedState.cacheManager.clearCache() } /** @@ -346,7 +349,59 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { * @since 2.0.0 */ protected[sql] def isCached(qName: Dataset[_]): Boolean = { - sparkSession.cacheManager.lookupCachedData(qName).nonEmpty + sparkSession.sharedState.cacheManager.lookupCachedData(qName).nonEmpty + 
} + + /** + * Refresh the cache entry for a table, if any. For Hive metastore table, the metadata + * is refreshed. + * + * @group cachemgmt + * @since 2.0.0 + */ + override def refreshTable(tableName: String): Unit = { + val tableIdent = sparkSession.sessionState.sqlParser.parseTableIdentifier(tableName) + sessionCatalog.refreshTable(tableIdent) + + // If this table is cached as an InMemoryRelation, drop the original + // cached version and make the new version cached lazily. + val logicalPlan = sparkSession.sessionState.catalog.lookupRelation(tableIdent) + // Use lookupCachedData directly since RefreshTable also takes databaseName. + val isCached = sparkSession.sharedState.cacheManager.lookupCachedData(logicalPlan).nonEmpty + if (isCached) { + // Create a data frame to represent the table. + // TODO: Use uncacheTable once it supports database name. + val df = Dataset.ofRows(sparkSession, logicalPlan) + // Uncache the logicalPlan. + sparkSession.sharedState.cacheManager.uncacheQuery(df, blocking = true) + // Cache it again. + sparkSession.sharedState.cacheManager.cacheQuery(df, Some(tableIdent.table)) + } + } + + /** + * Refresh the cache entry and the associated metadata for all dataframes (if any), that contain + * the given data source path. + * + * @group cachemgmt + * @since 2.0.0 + */ + override def refreshByPath(resourcePath: String): Unit = { + sparkSession.sharedState.cacheManager.invalidateCachedPath(sparkSession, resourcePath) + } +} + + +private[sql] object CatalogImpl { + + def makeDataset[T <: DefinedByConstructorParams: TypeTag]( + data: Seq[T], + sparkSession: SparkSession): Dataset[T] = { + val enc = ExpressionEncoder[T]() + val encoded = data.map(d => enc.toRow(d).copy()) + val plan = new LocalRelation(enc.schema.toAttributes, encoded) + val queryExecution = sparkSession.sessionState.executePlan(plan) + new Dataset[T](sparkSession, queryExecution, enc) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala index 38317d46dd82d..ad69137f7401b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/HiveSerDe.scala @@ -71,10 +71,12 @@ object HiveSerDe { val key = source.toLowerCase match { case s if s.startsWith("org.apache.spark.sql.parquet") => "parquet" case s if s.startsWith("org.apache.spark.sql.orc") => "orc" + case s if s.equals("orcfile") => "orc" + case s if s.equals("parquetfile") => "parquet" + case s if s.equals("avrofile") => "avro" case s => s } serdeMap.get(key) } } - diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 0bcf0f817a1d7..7598d475042ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -23,12 +23,14 @@ import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ import scala.collection.immutable +import org.apache.hadoop.fs.Path import org.apache.parquet.hadoop.ParquetOutputCommitter import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit import org.apache.spark.sql.catalyst.CatalystConf +import org.apache.spark.util.Utils //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines the configuration 
options for Spark SQL. @@ -55,7 +57,7 @@ object SQLConf { val WAREHOUSE_PATH = SQLConfigBuilder("spark.sql.warehouse.dir") .doc("The default location for managed databases and tables.") .stringConf - .createWithDefault("${system:user.dir}/spark-warehouse") + .createWithDefault(Utils.resolveURI("spark-warehouse").toString) val OPTIMIZER_MAX_ITERATIONS = SQLConfigBuilder("spark.sql.optimizer.maxIterations") .internal() @@ -70,16 +72,6 @@ object SQLConf { .intConf .createWithDefault(10) - val ALLOW_MULTIPLE_CONTEXTS = SQLConfigBuilder("spark.sql.allowMultipleContexts") - .doc("When set to true, creating multiple SQLContexts/HiveContexts is allowed. " + - "When set to false, only one SQLContext/HiveContext is allowed to be created " + - "through the constructor (new SQLContexts/HiveContexts created through newSession " + - "method is allowed). Please note that this conf needs to be set in Spark Conf. Once " + - "a SQLContext/HiveContext has been created, changing the value of this conf will not " + - "have effect.") - .booleanConf - .createWithDefault(true) - val COMPRESS_CACHED = SQLConfigBuilder("spark.sql.inMemoryColumnarStorage.compressed") .internal() .doc("When set to true Spark SQL will automatically select a compression codec for each " + @@ -120,8 +112,15 @@ object SQLConf { "nodes when performing a join. By setting this value to -1 broadcasting can be disabled. " + "Note that currently statistics are only supported for Hive Metastore tables where the " + "commandANALYZE TABLE <tableName> COMPUTE STATISTICS noscan has been run.") - .intConf - .createWithDefault(10 * 1024 * 1024) + .longConf + .createWithDefault(10L * 1024 * 1024) + + val ENABLE_FALL_BACK_TO_HDFS_FOR_STATS = + SQLConfigBuilder("spark.sql.statistics.fallBackToHdfs") + .doc("If the table statistics are not available from table metadata enable fall back to hdfs." + + " This is useful in determining if a table is small enough to use auto broadcast joins.") + .booleanConf + .createWithDefault(false) val DEFAULT_SIZE_IN_BYTES = SQLConfigBuilder("spark.sql.defaultSizeInBytes") .internal() @@ -167,7 +166,9 @@ object SQLConf { .createWithDefault(true) val CASE_SENSITIVE = SQLConfigBuilder("spark.sql.caseSensitive") - .doc("Whether the query analyzer should be case sensitive or not. Default to case insensitive.") + .internal() + .doc("Whether the query analyzer should be case sensitive or not. " + + "Default to case insensitive. It is highly discouraged to turn on case sensitive mode.") .booleanConf .createWithDefault(false) @@ -311,6 +312,22 @@ object SQLConf { .stringConf .createWithDefault("parquet") + val CONVERT_CTAS = SQLConfigBuilder("spark.sql.hive.convertCTAS") + .internal() + .doc("When true, a table created by a Hive CTAS statement (no USING clause) " + + "without specifying any storage property will be converted to a data source table, " + + "using the data source set by spark.sql.sources.default.") + .booleanConf + .createWithDefault(false) + + val GATHER_FASTSTAT = SQLConfigBuilder("spark.sql.hive.gatherFastStats") + .internal() + .doc("When true, fast stats (number of files and total size of all files) will be gathered" + + " in parallel while repairing table partitions to avoid the sequential listing in Hive" + + " metastore.") + .booleanConf + .createWithDefault(true) + // This is used to control the when we will split a schema's JSON string to multiple pieces // in order to fit the JSON string in metastore's table property (by default, the value has // a length restriction of 4000 characters). 
We will split the JSON string of a schema @@ -346,9 +363,14 @@ object SQLConf { .booleanConf .createWithDefault(true) + val CROSS_JOINS_ENABLED = SQLConfigBuilder("spark.sql.crossJoin.enabled") + .doc("When false, we will throw an error if a query contains a cross join") + .booleanConf + .createWithDefault(false) + val ORDER_BY_ORDINAL = SQLConfigBuilder("spark.sql.orderByOrdinal") .doc("When true, the ordinal numbers are treated as the position in the select list. " + - "When false, the ordinal numbers in order/sort By clause are ignored.") + "When false, the ordinal numbers in order/sort by clause are ignored.") .booleanConf .createWithDefault(true) @@ -375,14 +397,6 @@ object SQLConf { .intConf .createWithDefault(32) - // Whether to perform eager analysis when constructing a dataframe. - // Set to false when debugging requires the ability to look at invalid query plans. - val DATAFRAME_EAGER_ANALYSIS = SQLConfigBuilder("spark.sql.eagerAnalysis") - .internal() - .doc("When true, eagerly applies query analysis on DataFrame operations.") - .booleanConf - .createWithDefault(true) - // Whether to automatically resolve ambiguity in join conditions for self-joins. // See SPARK-6231. val DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY = @@ -421,7 +435,14 @@ object SQLConf { .doc("The maximum number of fields (including nested fields) that will be supported before" + " deactivating whole-stage codegen.") .intConf - .createWithDefault(200) + .createWithDefault(100) + + val WHOLESTAGE_FALLBACK = SQLConfigBuilder("spark.sql.codegen.fallback") + .internal() + .doc("When true, whole stage codegen could be temporary disabled for the part of query that" + + " fail to compile generated code") + .booleanConf + .createWithDefault(true) val MAX_CASES_BRANCHES = SQLConfigBuilder("spark.sql.codegen.maxCaseBranches") .internal() @@ -465,14 +486,14 @@ object SQLConf { .createWithDefault(2) val CHECKPOINT_LOCATION = SQLConfigBuilder("spark.sql.streaming.checkpointLocation") - .doc("The default location for storing checkpoint data for continuously executing queries.") + .doc("The default location for storing checkpoint data for streaming queries.") .stringConf .createOptional val UNSUPPORTED_OPERATION_CHECK_ENABLED = SQLConfigBuilder("spark.sql.streaming.unsupportedOperationCheck") .internal() - .doc("When true, the logical plan for continuous query will be checked for unsupported" + + .doc("When true, the logical plan for streaming query will be checked for unsupported" + " operations.") .booleanConf .createWithDefault(true) @@ -514,10 +535,51 @@ object SQLConf { val FILE_SINK_LOG_CLEANUP_DELAY = SQLConfigBuilder("spark.sql.streaming.fileSink.log.cleanupDelay") + .internal() + .doc("How long that a file is guaranteed to be visible for all readers.") + .timeConf(TimeUnit.MILLISECONDS) + .createWithDefault(TimeUnit.MINUTES.toMillis(10)) // 10 minutes + + val FILE_SOURCE_LOG_DELETION = SQLConfigBuilder("spark.sql.streaming.fileSource.log.deletion") + .internal() + .doc("Whether to delete the expired log files in file stream source.") + .booleanConf + .createWithDefault(true) + + val FILE_SOURCE_LOG_COMPACT_INTERVAL = + SQLConfigBuilder("spark.sql.streaming.fileSource.log.compactInterval") + .internal() + .doc("Number of log files after which all the previous files " + + "are compacted into the next log file.") + .intConf + .createWithDefault(10) + + val FILE_SOURCE_LOG_CLEANUP_DELAY = + SQLConfigBuilder("spark.sql.streaming.fileSource.log.cleanupDelay") .internal() .doc("How long in milliseconds a file is guaranteed to 
be visible for all readers.") .timeConf(TimeUnit.MILLISECONDS) - .createWithDefault(60 * 1000L) // 10 minutes + .createWithDefault(TimeUnit.MINUTES.toMillis(10)) // 10 minutes + + val STREAMING_SCHEMA_INFERENCE = + SQLConfigBuilder("spark.sql.streaming.schemaInference") + .internal() + .doc("Whether file-based streaming sources will infer its own schema") + .booleanConf + .createWithDefault(false) + + val STREAMING_POLLING_DELAY = + SQLConfigBuilder("spark.sql.streaming.pollingDelay") + .internal() + .doc("How long to delay polling new data when no data is available") + .timeConf(TimeUnit.MILLISECONDS) + .createWithDefault(10L) + + val STREAMING_METRICS_ENABLED = + SQLConfigBuilder("spark.sql.streaming.metricsEnabled") + .doc("Whether Dropwizard/Codahale metrics will be reported for active streaming queries.") + .booleanConf + .createWithDefault(false) object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" @@ -546,7 +608,7 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def optimizerInSetConversionThreshold: Int = getConf(OPTIMIZER_INSET_CONVERSION_THRESHOLD) - def checkpointLocation: String = getConf(CHECKPOINT_LOCATION) + def checkpointLocation: Option[String] = getConf(CHECKPOINT_LOCATION) def filesMaxPartitionBytes: Long = getConf(FILES_MAX_PARTITION_BYTES) @@ -582,10 +644,14 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def nativeView: Boolean = getConf(NATIVE_VIEW) + def gatherFastStats: Boolean = getConf(GATHER_FASTSTAT) + def wholeStageEnabled: Boolean = getConf(WHOLESTAGE_CODEGEN_ENABLED) def wholeStageMaxNumFields: Int = getConf(WHOLESTAGE_MAX_NUM_FIELDS) + def wholeStageFallback: Boolean = getConf(WHOLESTAGE_FALLBACK) + def maxCaseBranchesForCodegen: Int = getConf(MAX_CASES_BRANCHES) def exchangeReuseEnabled: Boolean = getConf(EXCHANGE_REUSE_ENABLED) @@ -597,14 +663,15 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def subexpressionEliminationEnabled: Boolean = getConf(SUBEXPRESSION_ELIMINATION_ENABLED) - def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD) + def autoBroadcastJoinThreshold: Long = getConf(AUTO_BROADCASTJOIN_THRESHOLD) + + def fallBackToHdfsForStatsEnabled: Boolean = getConf(ENABLE_FALL_BACK_TO_HDFS_FOR_STATS) def preferSortMergeJoin: Boolean = getConf(PREFER_SORTMERGEJOIN) def enableRadixSort: Boolean = getConf(RADIX_SORT_ENABLED) - def defaultSizeInBytes: Long = - getConf(DEFAULT_SIZE_IN_BYTES, autoBroadcastJoinThreshold + 1L) + def defaultSizeInBytes: Long = getConf(DEFAULT_SIZE_IN_BYTES, Long.MaxValue) def isParquetBinaryAsString: Boolean = getConf(PARQUET_BINARY_AS_STRING) @@ -620,6 +687,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def defaultDataSourceName: String = getConf(DEFAULT_DATA_SOURCE_NAME) + def convertCTAS: Boolean = getConf(CONVERT_CTAS) + def partitionDiscoveryEnabled(): Boolean = getConf(SQLConf.PARTITION_DISCOVERY_ENABLED) @@ -631,12 +700,12 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def bucketingEnabled: Boolean = getConf(SQLConf.BUCKETING_ENABLED) + def crossJoinEnabled: Boolean = getConf(SQLConf.CROSS_JOINS_ENABLED) + // Do not use a value larger than 4000 as the default value of this property. // See the comments of SCHEMA_STRING_LENGTH_THRESHOLD above for more information. 
def schemaStringLengthThreshold: Int = getConf(SCHEMA_STRING_LENGTH_THRESHOLD) - def dataFrameEagerAnalysis: Boolean = getConf(DATAFRAME_EAGER_ANALYSIS) - def dataFrameSelfJoinAutoResolveAmbiguity: Boolean = getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY) @@ -651,9 +720,9 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def variableSubstituteDepth: Int = getConf(VARIABLE_SUBSTITUTE_DEPTH) def warehousePath: String = { - getConf(WAREHOUSE_PATH).replace("${system:user.dir}", System.getProperty("user.dir")) + new Path(getConf(WAREHOUSE_PATH).replace("${system:user.dir}", + System.getProperty("user.dir"))).toString } - override def orderByOrdinal: Boolean = getConf(ORDER_BY_ORDINAL) override def groupByOrdinal: Boolean = getConf(GROUP_BY_ORDINAL) @@ -717,12 +786,11 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { /** * Return the value of an optional Spark SQL configuration property for the given key. If the key - * is not set yet, throw an exception. + * is not set yet, returns None. */ - def getConf[T](entry: OptionalConfigEntry[T]): T = { + def getConf[T](entry: OptionalConfigEntry[T]): Option[T] = { require(sqlConfEntries.get(entry.key) == entry, s"$entry is not registered") - Option(settings.get(entry.key)).map(entry.rawValueConverter). - getOrElse(throw new NoSuchElementException(entry.key)) + Option(settings.get(entry.key)).map(entry.rawValueConverter) } /** @@ -751,7 +819,7 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { */ def getAllDefinedConfs: Seq[(String, String, String)] = sqlConfEntries.synchronized { sqlConfEntries.values.asScala.filter(_.isPublic).map { entry => - (entry.key, entry.defaultValueString, entry.doc) + (entry.key, getConfString(entry.key, entry.defaultValueString), entry.doc) }.toSeq } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index ebff7569798a6..e054ef2bcf75b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -18,14 +18,10 @@ package org.apache.spark.sql.internal import java.io.File -import java.util.Properties - -import scala.collection.JavaConverters._ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.spark.internal.config.ConfigEntry import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry} import org.apache.spark.sql.catalyst.catalog.{ArchiveResource, _} @@ -33,8 +29,9 @@ import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.command.AnalyzeTable -import org.apache.spark.sql.execution.datasources.{DataSourceAnalysis, PreInsertCastAndRename, ResolveDataSource} +import org.apache.spark.sql.execution.command.AnalyzeTableCommand +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryManager} import org.apache.spark.sql.util.ExecutionListenerManager @@ -96,7 +93,7 @@ private[sql] class SessionState(sparkSession: SparkSession) { * Internal catalog for managing table and database states. 
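A small usage sketch for some of the options introduced above, set on an existing `SparkSession` named `spark`; the values are illustrative assumptions:

```scala
// Sketch only: enabling a few of the new/renamed streaming and join options at runtime.
spark.conf.set("spark.sql.streaming.metricsEnabled", "true")
spark.conf.set("spark.sql.streaming.checkpointLocation", "/tmp/checkpoints")
spark.conf.set("spark.sql.crossJoin.enabled", "true")   // allow explicit cross joins
```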
*/ lazy val catalog = new SessionCatalog( - sparkSession.externalCatalog, + sparkSession.sharedState.externalCatalog, functionResourceLoader, functionRegistry, conf, @@ -104,6 +101,7 @@ private[sql] class SessionState(sparkSession: SparkSession) { /** * Interface exposed to the user for registering user-defined functions. + * Note that the user-defined functions must be deterministic. */ lazy val udf: UDFRegistration = new UDFRegistration(functionRegistry) @@ -113,8 +111,10 @@ private[sql] class SessionState(sparkSession: SparkSession) { lazy val analyzer: Analyzer = { new Analyzer(catalog, conf) { override val extendedResolutionRules = - PreInsertCastAndRename :: - DataSourceAnalysis :: + AnalyzeCreateTableAsSelect(sparkSession) :: + PreprocessTableInsertion(conf) :: + new FindDataSourceTable(sparkSession) :: + DataSourceAnalysis(conf) :: (if (conf.runSQLonFile) new ResolveDataSource(sparkSession) :: Nil else Nil) override val extendedCheckRules = Seq(datasources.PreWriteCheck(conf, catalog)) @@ -144,10 +144,10 @@ private[sql] class SessionState(sparkSession: SparkSession) { lazy val listenerManager: ExecutionListenerManager = new ExecutionListenerManager /** - * Interface to start and stop [[org.apache.spark.sql.ContinuousQuery]]s. + * Interface to start and stop [[StreamingQuery]]s. */ - lazy val continuousQueryManager: ContinuousQueryManager = { - new ContinuousQueryManager(sparkSession) + lazy val streamingQueryManager: StreamingQueryManager = { + new StreamingQueryManager(sparkSession) } private val jarClassLoader: NonClosableMutableURLClassLoader = @@ -163,16 +163,14 @@ private[sql] class SessionState(sparkSession: SparkSession) { // Helper methods, partially leftover from pre-2.0 days // ------------------------------------------------------ + def executeSql(sql: String): QueryExecution = executePlan(sqlParser.parsePlan(sql)) + def executePlan(plan: LogicalPlan): QueryExecution = new QueryExecution(sparkSession, plan) def refreshTable(tableName: String): Unit = { catalog.refreshTable(sqlParser.parseTableIdentifier(tableName)) } - def invalidateTable(tableName: String): Unit = { - catalog.invalidateTable(sqlParser.parseTableIdentifier(tableName)) - } - def addJar(path: String): Unit = { sparkSession.sparkContext.addJar(path) @@ -196,6 +194,6 @@ private[sql] class SessionState(sparkSession: SparkSession) { * in the external catalog. 
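Since `ContinuousQueryManager` is renamed to `StreamingQueryManager` (reachable via `SparkSession.streams`), a brief usage sketch, assuming an existing session `spark`:

```scala
// Sketch only: inspecting and waiting on streaming queries through the renamed manager.
val active = spark.streams.active        // all currently running StreamingQuery instances
// spark.streams.awaitAnyTermination()   // would block until any active query stops
```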
*/ def analyze(tableName: String): Unit = { - AnalyzeTable(tableName).run(sparkSession) + AnalyzeTableCommand(tableName).run(sparkSession) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala index ab4af8d142d31..54aee5e02bb9c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala @@ -18,17 +18,18 @@ package org.apache.spark.sql.internal import org.apache.spark.SparkContext -import org.apache.spark.sql.SQLContext +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{SparkSession, SQLContext} import org.apache.spark.sql.catalyst.catalog.{ExternalCatalog, InMemoryCatalog} import org.apache.spark.sql.execution.CacheManager -import org.apache.spark.sql.execution.ui.SQLListener -import org.apache.spark.util.MutableURLClassLoader +import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab} +import org.apache.spark.util.{MutableURLClassLoader, Utils} /** * A class that holds all state shared across sessions in a given [[SQLContext]]. */ -private[sql] class SharedState(val sparkContext: SparkContext) { +private[sql] class SharedState(val sparkContext: SparkContext) extends Logging { /** * Class for caching query results reused in future executions. @@ -38,12 +39,19 @@ private[sql] class SharedState(val sparkContext: SparkContext) { /** * A listener for SQL-specific [[org.apache.spark.scheduler.SparkListenerEvent]]s. */ - val listener: SQLListener = SQLContext.createListenerAndUI(sparkContext) + val listener: SQLListener = createListenerAndUI(sparkContext) + + { + val configFile = Utils.getContextOrSparkClassLoader.getResource("hive-site.xml") + if (configFile != null) { + sparkContext.hadoopConfiguration.addResource(configFile) + } + } /** * A catalog that interacts with external systems. */ - lazy val externalCatalog: ExternalCatalog = new InMemoryCatalog + lazy val externalCatalog: ExternalCatalog = new InMemoryCatalog(sparkContext.hadoopConfiguration) /** * A classloader used to load all user-added jar. @@ -51,6 +59,43 @@ private[sql] class SharedState(val sparkContext: SparkContext) { val jarClassLoader = new NonClosableMutableURLClassLoader( org.apache.spark.util.Utils.getContextOrSparkClassLoader) + { + // Set the Hive metastore warehouse path to the one we use + val tempConf = new SQLConf + sparkContext.conf.getAll.foreach { case (k, v) => tempConf.setConfString(k, v) } + val hiveWarehouseDir = sparkContext.hadoopConfiguration.get("hive.metastore.warehouse.dir") + if (hiveWarehouseDir != null && !tempConf.contains(SQLConf.WAREHOUSE_PATH.key)) { + // If hive.metastore.warehouse.dir is set and spark.sql.warehouse.dir is not set, + // we will respect the value of hive.metastore.warehouse.dir. + tempConf.setConfString(SQLConf.WAREHOUSE_PATH.key, hiveWarehouseDir) + sparkContext.conf.set(SQLConf.WAREHOUSE_PATH.key, hiveWarehouseDir) + logInfo(s"${SQLConf.WAREHOUSE_PATH.key} is not set, but hive.metastore.warehouse.dir " + + s"is set. Setting ${SQLConf.WAREHOUSE_PATH.key} to the value of " + + s"hive.metastore.warehouse.dir ('$hiveWarehouseDir').") + } else { + // If spark.sql.warehouse.dir is set, we will override hive.metastore.warehouse.dir using + // the value of spark.sql.warehouse.dir. 
+ // When neither spark.sql.warehouse.dir nor hive.metastore.warehouse.dir is set, + // we will set hive.metastore.warehouse.dir to the default value of spark.sql.warehouse.dir. + sparkContext.conf.set("hive.metastore.warehouse.dir", tempConf.warehousePath) + } + + logInfo(s"Warehouse path is '${tempConf.warehousePath}'.") + } + + /** + * Create a SQLListener then add it into SparkContext, and create a SQLTab if there is SparkUI. + */ + private def createListenerAndUI(sc: SparkContext): SQLListener = { + if (SparkSession.sqlListener.get() == null) { + val listener = new SQLListener(sc.conf) + if (SparkSession.sqlListener.compareAndSet(null, listener)) { + sc.addSparkListener(listener) + sc.ui.foreach(new SQLTab(listener, _)) + } + } + SparkSession.sqlListener.get() + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala index 46b3877a7cab3..b795e8b42df0f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala @@ -38,6 +38,12 @@ private case object OracleDialect extends JdbcDialect { // This is sub-optimal as we have to pick a precision/scale in advance whereas the data // in Oracle is allowed to have different precision/scale for each value. Option(DecimalType(DecimalType.MAX_PRECISION, 10)) + } else if (sqlType == Types.NUMERIC && md.build().getLong("scale") == -127) { + // Handle FLOAT fields in a special way because JDBC ResultSetMetaData converts + // this to NUMERIC with -127 scale + // Not sure if there is a more robust way to identify the field as a float (or other + // numeric types that do not specify a scale. + Option(DecimalType(DecimalType.MAX_PRECISION, 10)) } else { None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala index 2d6c3974a833e..6baf1b6f16cd2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -89,7 +89,7 @@ private object PostgresDialect extends JdbcDialect { // // See: https://jdbc.postgresql.org/documentation/head/query.html#query-with-cursor // - if (properties.getOrElse("fetchsize", "0").toInt > 0) { + if (properties.getOrElse(JdbcUtils.JDBC_BATCH_FETCH_SIZE, "0").toInt > 0) { connection.setAutoCommit(false) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index 97e35bb10407e..28d8bc3de68b8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -18,7 +18,7 @@ package org.apache.spark import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.{SparkPlan, SparkStrategy} /** * Allows the execution of relational queries, including those expressed in SQL using Spark. 
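A small sketch of how the warehouse-path precedence described above looks from the user side; the application name and path are assumptions:

```scala
import org.apache.spark.sql.SparkSession

// Sketch only: an explicit spark.sql.warehouse.dir wins; if it is omitted and
// hive.metastore.warehouse.dir is set in hive-site.xml, that value is used instead.
val spark = SparkSession.builder()
  .appName("warehouse-demo")
  .config("spark.sql.warehouse.dir", "/tmp/spark-warehouse")
  .getOrCreate()
```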
@@ -40,7 +40,7 @@ package object sql { * [[org.apache.spark.sql.sources]] */ @DeveloperApi - type Strategy = org.apache.spark.sql.catalyst.planning.GenericStrategy[SparkPlan] + type Strategy = SparkStrategy type DataFrame = Dataset[Row] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 26285bde31ad0..b84953deac9e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.streaming.{Sink, Source} +import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType /** @@ -111,8 +112,10 @@ trait SchemaRelationProvider { } /** + * ::Experimental:: * Implemented by objects that can produce a streaming [[Source]] for a specific format or system. */ +@Experimental trait StreamSourceProvider { /** Returns the name and schema of the source that can be used to continually read data. */ @@ -131,13 +134,16 @@ trait StreamSourceProvider { } /** + * ::Experimental:: * Implemented by objects that can produce a streaming [[Sink]] for a specific format or system. */ +@Experimental trait StreamSinkProvider { def createSink( sqlContext: SQLContext, parameters: Map[String, String], - partitionColumns: Seq[String]): Sink + partitionColumns: Seq[String], + outputMode: OutputMode): Sink } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala new file mode 100644 index 0000000000000..5141a014f8f79 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -0,0 +1,295 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import scala.collection.JavaConverters._ + +import org.apache.spark.annotation.Experimental +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} +import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.execution.streaming.StreamingRelation +import org.apache.spark.sql.types.StructType + +/** + * Interface used to load a streaming [[Dataset]] from external storage systems (e.g. file systems, + * key-value stores, etc). Use [[SparkSession.readStream]] to access this. + * + * @since 2.0.0 + */ +@Experimental +final class DataStreamReader private[sql](sparkSession: SparkSession) extends Logging { + /** + * Specifies the input data source format. 
+ * + * @since 2.0.0 + */ + def format(source: String): DataStreamReader = { + this.source = source + this + } + + /** + * Specifies the input schema. Some data sources (e.g. JSON) can infer the input schema + * automatically from data. By specifying the schema here, the underlying data source can + * skip the schema inference step, and thus speed up data loading. + * + * @since 2.0.0 + */ + def schema(schema: StructType): DataStreamReader = { + this.userSpecifiedSchema = Option(schema) + this + } + + /** + * Adds an input option for the underlying data source. + * + * @since 2.0.0 + */ + def option(key: String, value: String): DataStreamReader = { + this.extraOptions += (key -> value) + this + } + + /** + * Adds an input option for the underlying data source. + * + * @since 2.0.0 + */ + def option(key: String, value: Boolean): DataStreamReader = option(key, value.toString) + + /** + * Adds an input option for the underlying data source. + * + * @since 2.0.0 + */ + def option(key: String, value: Long): DataStreamReader = option(key, value.toString) + + /** + * Adds an input option for the underlying data source. + * + * @since 2.0.0 + */ + def option(key: String, value: Double): DataStreamReader = option(key, value.toString) + + /** + * (Scala-specific) Adds input options for the underlying data source. + * + * @since 2.0.0 + */ + def options(options: scala.collection.Map[String, String]): DataStreamReader = { + this.extraOptions ++= options + this + } + + /** + * Adds input options for the underlying data source. + * + * @since 2.0.0 + */ + def options(options: java.util.Map[String, String]): DataStreamReader = { + this.options(options.asScala) + this + } + + + /** + * Loads input data stream in as a [[DataFrame]], for data streams that don't require a path + * (e.g. external key-value stores). + * + * @since 2.0.0 + */ + def load(): DataFrame = { + val dataSource = + DataSource( + sparkSession, + userSpecifiedSchema = userSpecifiedSchema, + className = source, + options = extraOptions.toMap) + Dataset.ofRows(sparkSession, StreamingRelation(dataSource)) + } + + /** + * Loads input in as a [[DataFrame]], for data streams that read from some path. + * + * @since 2.0.0 + */ + def load(path: String): DataFrame = { + option("path", path).load() + } + + /** + * Loads a JSON file stream (one object per line) and returns the result as a [[DataFrame]]. + * + * This function goes through the input once to determine the input schema. If you know the + * schema in advance, use the version that specifies the schema to avoid the extra scan. + * + * You can set the following JSON-specific options to deal with non-standard JSON files: + *
+ * <ul>
+ * <li>`maxFilesPerTrigger` (default: no max limit): sets the maximum number of new files to be
+ * considered in every trigger.</li>
+ * <li>`primitivesAsString` (default `false`): infers all primitive values as a string type</li>
+ * <li>`prefersDecimal` (default `false`): infers all floating-point values as a decimal
+ * type. If the values do not fit in decimal, then it infers them as doubles.</li>
+ * <li>`allowComments` (default `false`): ignores Java/C++ style comments in JSON records</li>
+ * <li>`allowUnquotedFieldNames` (default `false`): allows unquoted JSON field names</li>
+ * <li>`allowSingleQuotes` (default `true`): allows single quotes in addition to double quotes</li>
+ * <li>`allowNumericLeadingZeros` (default `false`): allows leading zeros in numbers
+ * (e.g. 00012)</li>
+ * <li>`allowBackslashEscapingAnyCharacter` (default `false`): allows accepting quoting of all
+ * characters using the backslash quoting mechanism</li>
+ * <li>`mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records
+ * during parsing.
+ *   <ul>
+ *     <li>`PERMISSIVE` : sets other fields to `null` when it meets a corrupted record, and puts
+ *     the malformed string into a new field configured by `columnNameOfCorruptRecord`. When
+ *     a schema is set by user, it sets `null` for extra fields.</li>
+ *     <li>`DROPMALFORMED` : ignores the whole corrupted records.</li>
+ *     <li>`FAILFAST` : throws an exception when it meets corrupted records.</li>
+ *   </ul>
+ * </li>
+ * <li>`columnNameOfCorruptRecord` (default is the value specified in
+ * `spark.sql.columnNameOfCorruptRecord`): allows renaming the new field having a malformed string
+ * created by `PERMISSIVE` mode. This overrides `spark.sql.columnNameOfCorruptRecord`.</li>
+ * <li>`dateFormat` (default `yyyy-MM-dd`): sets the string that indicates a date format.
+ * Custom date formats follow the formats at `java.text.SimpleDateFormat`. This applies to
+ * date type.</li>
+ * <li>`timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSZZ`): sets the string that
+ * indicates a timestamp format. Custom date formats follow the formats at
+ * `java.text.SimpleDateFormat`. This applies to timestamp type.</li>
+ * </ul>
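
To ground the JSON options above, here is a minimal, illustrative sketch of reading a JSON file stream with an explicit schema through the new `DataStreamReader` API; the input directory `/tmp/events-json`, the column names, and the chosen option values are placeholders, not part of this patch.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{StringType, StructType, TimestampType}

object JsonStreamSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("json-stream-sketch").master("local[*]").getOrCreate()

    // Supplying the schema up front skips the inference scan mentioned above.
    val schema = new StructType()
      .add("user", StringType)
      .add("event", StringType)
      .add("ts", TimestampType)

    val events = spark.readStream
      .schema(schema)
      .option("maxFilesPerTrigger", "10")   // throttle how many new files each trigger picks up
      .option("mode", "DROPMALFORMED")      // drop corrupt records instead of failing
      .json("/tmp/events-json")             // placeholder directory of newline-delimited JSON files

    events.printSchema()
  }
}
```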
      + * + * @since 2.0.0 + */ + def json(path: String): DataFrame = format("json").load(path) + + /** + * Loads a CSV file stream and returns the result as a [[DataFrame]]. + * + * This function will go through the input once to determine the input schema if `inferSchema` + * is enabled. To avoid going through the entire data once, disable `inferSchema` option or + * specify the schema explicitly using [[schema]]. + * + * You can set the following CSV-specific options to deal with CSV files: + *
+ * <ul>
+ * <li>`maxFilesPerTrigger` (default: no max limit): sets the maximum number of new files to be
+ * considered in every trigger.</li>
+ * <li>`sep` (default `,`): sets the single character as a separator for each
+ * field and value.</li>
+ * <li>`encoding` (default `UTF-8`): decodes the CSV files by the given encoding
+ * type.</li>
+ * <li>`quote` (default `"`): sets the single character used for escaping quoted values where
+ * the separator can be part of the value. If you would like to turn off quotations, you need to
+ * set not `null` but an empty string. This behaviour is different from
+ * `com.databricks.spark.csv`.</li>
+ * <li>`escape` (default `\`): sets the single character used for escaping quotes inside
+ * an already quoted value.</li>
+ * <li>`comment` (default empty string): sets the single character used for skipping lines
+ * beginning with this character. By default, it is disabled.</li>
+ * <li>`header` (default `false`): uses the first line as names of columns.</li>
+ * <li>`inferSchema` (default `false`): infers the input schema automatically from data. It
+ * requires one extra pass over the data.</li>
+ * <li>`ignoreLeadingWhiteSpace` (default `false`): defines whether or not leading whitespaces
+ * from values being read should be skipped.</li>
+ * <li>`ignoreTrailingWhiteSpace` (default `false`): defines whether or not trailing
+ * whitespaces from values being read should be skipped.</li>
+ * <li>`nullValue` (default empty string): sets the string representation of a null value. Since
+ * 2.0.1, this applies to all supported types including the string type.</li>
+ * <li>`nanValue` (default `NaN`): sets the string representation of a non-number value.</li>
+ * <li>`positiveInf` (default `Inf`): sets the string representation of a positive infinity
+ * value.</li>
+ * <li>`negativeInf` (default `-Inf`): sets the string representation of a negative infinity
+ * value.</li>
+ * <li>`dateFormat` (default `yyyy-MM-dd`): sets the string that indicates a date format.
+ * Custom date formats follow the formats at `java.text.SimpleDateFormat`. This applies to
+ * date type.</li>
+ * <li>`timestampFormat` (default `yyyy-MM-dd'T'HH:mm:ss.SSSZZ`): sets the string that
+ * indicates a timestamp format. Custom date formats follow the formats at
+ * `java.text.SimpleDateFormat`. This applies to timestamp type.</li>
+ * <li>`maxColumns` (default `20480`): defines a hard limit of how many columns
+ * a record can have.</li>
+ * <li>`maxCharsPerColumn` (default `1000000`): defines the maximum number of characters allowed
+ * for any given value being read.</li>
+ * <li>`mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records
+ * during parsing.
+ *   <ul>
+ *     <li>`PERMISSIVE` : sets other fields to `null` when it meets a corrupted record. When
+ *     a schema is set by user, it sets `null` for extra fields.</li>
+ *     <li>`DROPMALFORMED` : ignores the whole corrupted records.</li>
+ *     <li>`FAILFAST` : throws an exception when it meets corrupted records.</li>
+ *   </ul>
+ * </li>
+ * </ul>
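
Similarly, a hedged sketch of how the CSV-specific options might be combined on a file stream; the input path, separator, and schema below are illustrative only.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{DoubleType, StringType, StructType}

object CsvStreamSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("csv-stream-sketch").master("local[*]").getOrCreate()

    // An explicit schema avoids the extra pass that `inferSchema` would need.
    val schema = new StructType()
      .add("symbol", StringType)
      .add("price", DoubleType)

    val quotes = spark.readStream
      .schema(schema)
      .option("header", "true")       // first line of each file holds column names
      .option("sep", ",")
      .option("nullValue", "NA")      // treat "NA" as null
      .csv("/tmp/quotes-csv")         // placeholder input directory

    quotes.printSchema()
  }
}
```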
      + * + * @since 2.0.0 + */ + def csv(path: String): DataFrame = format("csv").load(path) + + /** + * Loads a Parquet file stream, returning the result as a [[DataFrame]]. + * + * You can set the following Parquet-specific option(s) for reading Parquet files: + *
+ * <ul>
+ * <li>`maxFilesPerTrigger` (default: no max limit): sets the maximum number of new files to be
+ * considered in every trigger.</li>
+ * <li>`mergeSchema` (default is the value specified in `spark.sql.parquet.mergeSchema`): sets
+ * whether we should merge schemas collected from all Parquet part-files. This will override
+ * `spark.sql.parquet.mergeSchema`.</li>
+ * </ul>
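
For completeness, a small sketch of the Parquet variant; it assumes a `spark: SparkSession` as in the sketches above, and the path is again a placeholder.

```scala
// Parquet files carry their own schema, so no explicit schema call is needed here.
val parquetEvents = spark.readStream
  .option("mergeSchema", "true")          // override spark.sql.parquet.mergeSchema for this read
  .option("maxFilesPerTrigger", "5")
  .parquet("/tmp/events-parquet")         // placeholder input directory
```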
      + * + * @since 2.0.0 + */ + def parquet(path: String): DataFrame = { + format("parquet").load(path) + } + + /** + * Loads text files and returns a [[DataFrame]] whose schema starts with a string column named + * "value", and followed by partitioned columns if there are any. + * + * Each line in the text files is a new row in the resulting DataFrame. For example: + * {{{ + * // Scala: + * spark.readStream.text("/path/to/directory/") + * + * // Java: + * spark.readStream().text("/path/to/directory/") + * }}} + * + * You can set the following text-specific options to deal with text files: + *
+ * <ul>
+ * <li>`maxFilesPerTrigger` (default: no max limit): sets the maximum number of new files to be
+ * considered in every trigger.</li>
+ * </ul>
      + * + * @since 2.0.0 + */ + def text(path: String): DataFrame = format("text").load(path) + + + /////////////////////////////////////////////////////////////////////////////////////// + // Builder pattern config options + /////////////////////////////////////////////////////////////////////////////////////// + + private var source: String = sparkSession.sessionState.conf.defaultDataSourceName + + private var userSpecifiedSchema: Option[StructType] = None + + private var extraOptions = new scala.collection.mutable.HashMap[String, String] +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala new file mode 100644 index 0000000000000..b959444b49298 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -0,0 +1,365 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import scala.collection.JavaConverters._ + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, ForeachWriter} +import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.execution.streaming.{ForeachSink, MemoryPlan, MemorySink} + +/** + * :: Experimental :: + * Interface used to write a streaming [[Dataset]] to external storage systems (e.g. file systems, + * key-value stores, etc). Use [[Dataset.writeStream]] to access this. + * + * @since 2.0.0 + */ +@Experimental +final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { + + private val df = ds.toDF() + + /** + * Specifies how data of a streaming DataFrame/Dataset is written to a streaming sink. + * - `OutputMode.Append()`: only the new rows in the streaming DataFrame/Dataset will be + * written to the sink + * - `OutputMode.Complete()`: all the rows in the streaming DataFrame/Dataset will be written + * to the sink every time these is some updates + * + * @since 2.0.0 + */ + def outputMode(outputMode: OutputMode): DataStreamWriter[T] = { + this.outputMode = outputMode + this + } + + /** + * Specifies how data of a streaming DataFrame/Dataset is written to a streaming sink. + * - `append`: only the new rows in the streaming DataFrame/Dataset will be written to + * the sink + * - `complete`: all the rows in the streaming DataFrame/Dataset will be written to the sink + * every time these is some updates + * + * @since 2.0.0 + */ + def outputMode(outputMode: String): DataStreamWriter[T] = { + this.outputMode = outputMode.toLowerCase match { + case "append" => + OutputMode.Append + case "complete" => + OutputMode.Complete + case _ => + throw new IllegalArgumentException(s"Unknown output mode $outputMode. 
" + + "Accepted output modes are 'append' and 'complete'") + } + this + } + + /** + * Set the trigger for the stream query. The default value is `ProcessingTime(0)` and it will run + * the query as fast as possible. + * + * Scala Example: + * {{{ + * df.writeStream.trigger(ProcessingTime("10 seconds")) + * + * import scala.concurrent.duration._ + * df.writeStream.trigger(ProcessingTime(10.seconds)) + * }}} + * + * Java Example: + * {{{ + * df.writeStream().trigger(ProcessingTime.create("10 seconds")) + * + * import java.util.concurrent.TimeUnit + * df.writeStream().trigger(ProcessingTime.create(10, TimeUnit.SECONDS)) + * }}} + * + * @since 2.0.0 + */ + def trigger(trigger: Trigger): DataStreamWriter[T] = { + this.trigger = trigger + this + } + + + /** + * Specifies the name of the [[StreamingQuery]] that can be started with `start()`. + * This name must be unique among all the currently active queries in the associated SQLContext. + * + * @since 2.0.0 + */ + def queryName(queryName: String): DataStreamWriter[T] = { + this.extraOptions += ("queryName" -> queryName) + this + } + + /** + * Specifies the underlying output data source. Built-in options include "parquet" for now. + * + * @since 2.0.0 + */ + def format(source: String): DataStreamWriter[T] = { + this.source = source + this + } + + /** + * Partitions the output by the given columns on the file system. If specified, the output is + * laid out on the file system similar to Hive's partitioning scheme. As an example, when we + * partition a dataset by year and then month, the directory layout would look like: + * + * - year=2016/month=01/ + * - year=2016/month=02/ + * + * Partitioning is one of the most widely used techniques to optimize physical data layout. + * It provides a coarse-grained index for skipping unnecessary data reads when queries have + * predicates on the partitioned columns. In order for partitioning to work well, the number + * of distinct values in each column should typically be less than tens of thousands. + * + * This was initially applicable for Parquet but in 1.5+ covers JSON, text, ORC and avro as well. + * + * @since 1.4.0 + */ + @scala.annotation.varargs + def partitionBy(colNames: String*): DataStreamWriter[T] = { + this.partitioningColumns = Option(colNames) + this + } + + /** + * Adds an output option for the underlying data source. + * + * @since 2.0.0 + */ + def option(key: String, value: String): DataStreamWriter[T] = { + this.extraOptions += (key -> value) + this + } + + /** + * Adds an output option for the underlying data source. + * + * @since 2.0.0 + */ + def option(key: String, value: Boolean): DataStreamWriter[T] = option(key, value.toString) + + /** + * Adds an output option for the underlying data source. + * + * @since 2.0.0 + */ + def option(key: String, value: Long): DataStreamWriter[T] = option(key, value.toString) + + /** + * Adds an output option for the underlying data source. + * + * @since 2.0.0 + */ + def option(key: String, value: Double): DataStreamWriter[T] = option(key, value.toString) + + /** + * (Scala-specific) Adds output options for the underlying data source. + * + * @since 2.0.0 + */ + def options(options: scala.collection.Map[String, String]): DataStreamWriter[T] = { + this.extraOptions ++= options + this + } + + /** + * Adds output options for the underlying data source. 
+ * + * @since 2.0.0 + */ + def options(options: java.util.Map[String, String]): DataStreamWriter[T] = { + this.options(options.asScala) + this + } + + /** + * Starts the execution of the streaming query, which will continually output results to the given + * path as new data arrives. The returned [[StreamingQuery]] object can be used to interact with + * the stream. + * + * @since 2.0.0 + */ + def start(path: String): StreamingQuery = { + option("path", path).start() + } + + /** + * Starts the execution of the streaming query, which will continually output results to the given + * path as new data arrives. The returned [[StreamingQuery]] object can be used to interact with + * the stream. + * + * @since 2.0.0 + */ + def start(): StreamingQuery = { + if (source == "memory") { + assertNotPartitioned("memory") + if (extraOptions.get("queryName").isEmpty) { + throw new AnalysisException("queryName must be specified for memory sink") + } + + val sink = new MemorySink(df.schema, outputMode) + val resultDf = Dataset.ofRows(df.sparkSession, new MemoryPlan(sink)) + val query = df.sparkSession.sessionState.streamingQueryManager.startQuery( + extraOptions.get("queryName"), + extraOptions.get("checkpointLocation"), + df, + sink, + outputMode, + useTempCheckpointLocation = true, + recoverFromCheckpointLocation = false, + trigger = trigger) + resultDf.createOrReplaceTempView(query.name) + query + } else if (source == "foreach") { + assertNotPartitioned("foreach") + val sink = new ForeachSink[T](foreachWriter)(ds.exprEnc) + df.sparkSession.sessionState.streamingQueryManager.startQuery( + extraOptions.get("queryName"), + extraOptions.get("checkpointLocation"), + df, + sink, + outputMode, + useTempCheckpointLocation = true, + trigger = trigger) + } else { + val (useTempCheckpointLocation, recoverFromCheckpointLocation) = + if (source == "console") { + (true, false) + } else { + (false, true) + } + val dataSource = + DataSource( + df.sparkSession, + className = source, + options = extraOptions.toMap, + partitionColumns = normalizedParCols.getOrElse(Nil)) + df.sparkSession.sessionState.streamingQueryManager.startQuery( + extraOptions.get("queryName"), + extraOptions.get("checkpointLocation"), + df, + dataSource.createSink(outputMode), + outputMode, + useTempCheckpointLocation = useTempCheckpointLocation, + recoverFromCheckpointLocation = recoverFromCheckpointLocation, + trigger = trigger) + } + } + + /** + * Starts the execution of the streaming query, which will continually send results to the given + * [[ForeachWriter]] as as new data arrives. The [[ForeachWriter]] can be used to send the data + * generated by the [[DataFrame]]/[[Dataset]] to an external system. 
+ * + * Scala example: + * {{{ + * datasetOfString.writeStream.foreach(new ForeachWriter[String] { + * + * def open(partitionId: Long, version: Long): Boolean = { + * // open connection + * } + * + * def process(record: String) = { + * // write string to connection + * } + * + * def close(errorOrNull: Throwable): Unit = { + * // close the connection + * } + * }).start() + * }}} + * + * Java example: + * {{{ + * datasetOfString.writeStream().foreach(new ForeachWriter() { + * + * @Override + * public boolean open(long partitionId, long version) { + * // open connection + * } + * + * @Override + * public void process(String value) { + * // write string to connection + * } + * + * @Override + * public void close(Throwable errorOrNull) { + * // close the connection + * } + * }).start(); + * }}} + * + * @since 2.0.0 + */ + def foreach(writer: ForeachWriter[T]): DataStreamWriter[T] = { + this.source = "foreach" + this.foreachWriter = if (writer != null) { + ds.sparkSession.sparkContext.clean(writer) + } else { + throw new IllegalArgumentException("foreach writer cannot be null") + } + this + } + + private def normalizedParCols: Option[Seq[String]] = partitioningColumns.map { cols => + cols.map(normalize(_, "Partition")) + } + + /** + * The given column name may not be equal to any of the existing column names if we were in + * case-insensitive context. Normalize the given column name to the real one so that we don't + * need to care about case sensitivity afterwards. + */ + private def normalize(columnName: String, columnType: String): String = { + val validColumnNames = df.logicalPlan.output.map(_.name) + validColumnNames.find(df.sparkSession.sessionState.analyzer.resolver(_, columnName)) + .getOrElse(throw new AnalysisException(s"$columnType column $columnName not found in " + + s"existing columns (${validColumnNames.mkString(", ")})")) + } + + private def assertNotPartitioned(operation: String): Unit = { + if (partitioningColumns.isDefined) { + throw new AnalysisException(s"'$operation' does not support partitioning") + } + } + + /////////////////////////////////////////////////////////////////////////////////////// + // Builder pattern config options + /////////////////////////////////////////////////////////////////////////////////////// + + private var source: String = df.sparkSession.sessionState.conf.defaultDataSourceName + + private var outputMode: OutputMode = OutputMode.Append + + private var trigger: Trigger = ProcessingTime(0L) + + private var extraOptions = new scala.collection.mutable.HashMap[String, String] + + private var foreachWriter: ForeachWriter[T] = null + + private var partitioningColumns: Option[Seq[String]] = None +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/SinkStatus.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/SinkStatus.scala new file mode 100644 index 0000000000000..ab19602207ad8 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/SinkStatus.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import org.json4s._ +import org.json4s.JsonAST.JValue +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.streaming.StreamingQueryStatus.indent + +/** + * :: Experimental :: + * Status and metrics of a streaming sink. + * + * @param description Description of the source corresponding to this status. + * @param offsetDesc Description of the current offsets up to which data has been written + * by the sink. + * @since 2.0.0 + */ +@Experimental +class SinkStatus private( + val description: String, + val offsetDesc: String) { + + /** The compact JSON representation of this status. */ + def json: String = compact(render(jsonValue)) + + /** The pretty (i.e. indented) JSON representation of this status. */ + def prettyJson: String = pretty(render(jsonValue)) + + override def toString: String = + "Status of sink " + indent(prettyString).trim + + private[sql] def jsonValue: JValue = { + ("description" -> JString(description)) ~ + ("offsetDesc" -> JString(offsetDesc)) + } + + private[sql] def prettyString: String = { + s"""$description + |Committed offsets: $offsetDesc + |""".stripMargin + } +} + +/** Companion object, primarily for creating SinkStatus instances internally */ +private[sql] object SinkStatus { + def apply(desc: String, offsetDesc: String): SinkStatus = new SinkStatus(desc, offsetDesc) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/SourceStatus.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/SourceStatus.scala new file mode 100644 index 0000000000000..cfdf11370e06d --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/SourceStatus.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import java.{util => ju} + +import scala.collection.JavaConverters._ + +import org.json4s._ +import org.json4s.JsonAST.JValue +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.streaming.StreamingQueryStatus.indent +import org.apache.spark.util.JsonProtocol + +/** + * :: Experimental :: + * Status and metrics of a streaming Source. 
+ * + * @param description Description of the source corresponding to this status. + * @param offsetDesc Description of the current offset if known. + * @param inputRate Current rate (rows/sec) at which data is being generated by the source. + * @param processingRate Current rate (rows/sec) at which the query is processing data from + * the source. + * @param triggerDetails Low-level details of the currently active trigger (e.g. number of + * rows processed in trigger, latency of intermediate steps, etc.). + * If no trigger is active, then it will have details of the last completed + * trigger. + * @since 2.0.0 + */ +@Experimental +class SourceStatus private( + val description: String, + val offsetDesc: String, + val inputRate: Double, + val processingRate: Double, + val triggerDetails: ju.Map[String, String]) { + + /** The compact JSON representation of this status. */ + def json: String = compact(render(jsonValue)) + + /** The pretty (i.e. indented) JSON representation of this status. */ + def prettyJson: String = pretty(render(jsonValue)) + + override def toString: String = + "Status of source " + indent(prettyString).trim + + private[sql] def jsonValue: JValue = { + ("description" -> JString(description)) ~ + ("offsetDesc" -> JString(offsetDesc)) ~ + ("inputRate" -> JDouble(inputRate)) ~ + ("processingRate" -> JDouble(processingRate)) ~ + ("triggerDetails" -> JsonProtocol.mapToJson(triggerDetails.asScala)) + } + + private[sql] def prettyString: String = { + val triggerDetailsLines = + triggerDetails.asScala.map { case (k, v) => s"$k: $v" } + s"""$description + |Available offset: $offsetDesc + |Input rate: $inputRate rows/sec + |Processing rate: $processingRate rows/sec + |Trigger details: + |""".stripMargin + indent(triggerDetailsLines) + } +} + +/** Companion object, primarily for creating SourceStatus instances internally */ +private[sql] object SourceStatus { + def apply( + desc: String, + offsetDesc: String, + inputRate: Double, + processingRate: Double, + triggerDetails: Map[String, String]): SourceStatus = { + new SourceStatus(desc, offsetDesc, inputRate, processingRate, triggerDetails.asJava) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala similarity index 68% rename from sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala rename to sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala index 4d5afe2eb5f4c..0a85414451981 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala @@ -15,9 +15,10 @@ * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.streaming import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.SparkSession /** * :: Experimental :: @@ -26,14 +27,23 @@ import org.apache.spark.annotation.Experimental * @since 2.0.0 */ @Experimental -trait ContinuousQuery { +trait StreamingQuery { /** - * Returns the name of the query. + * Returns the name of the query. This name is unique across all active queries. This can be + * set in the [[org.apache.spark.sql.DataStreamWriter DataStreamWriter]] as + * `dataframe.writeStream.queryName("query").start()`. * @since 2.0.0 */ def name: String + /** + * Returns the unique id of this query. This id is automatically generated and is unique across + * all queries that have been started in the current process. 
+ * @since 2.0.0 + */ + def id: Long + /** * Returns the [[SparkSession]] associated with `this`. * @since 2.0.0 @@ -47,18 +57,29 @@ trait ContinuousQuery { def isActive: Boolean /** - * Returns the [[ContinuousQueryException]] if the query was terminated by an exception. + * Returns the [[StreamingQueryException]] if the query was terminated by an exception. * @since 2.0.0 */ - def exception: Option[ContinuousQueryException] + def exception: Option[StreamingQueryException] + + /** + * Returns the current status of the query. + * @since 2.0.2 + */ + def status: StreamingQueryStatus /** * Returns current status of all the sources. * @since 2.0.0 */ + @deprecated("use status.sourceStatuses", "2.0.2") def sourceStatuses: Array[SourceStatus] - /** Returns current status of the sink. */ + /** + * Returns current status of the sink. + * @since 2.0.0 + */ + @deprecated("use status.sinkStatus", "2.0.2") def sinkStatus: SinkStatus /** @@ -69,7 +90,7 @@ trait ContinuousQuery { * immediately (if the query was terminated by `stop()`), or throw the exception * immediately (if the query has terminated with exception). * - * @throws ContinuousQueryException, if `this` query has terminated with an exception. + * @throws StreamingQueryException, if `this` query has terminated with an exception. * * @since 2.0.0 */ @@ -85,18 +106,19 @@ trait ContinuousQuery { * `true` immediately (if the query was terminated by `stop()`), or throw the exception * immediately (if the query has terminated with exception). * - * @throws ContinuousQueryException, if `this` query has terminated with an exception + * @throws StreamingQueryException, if `this` query has terminated with an exception * * @since 2.0.0 */ def awaitTermination(timeoutMs: Long): Boolean /** - * Blocks until all available data in the source has been processed an committed to the sink. + * Blocks until all available data in the source has been processed and committed to the sink. * This method is intended for testing. Note that in the case of continually arriving data, this * method may block forever. Additionally, this method is only guaranteed to block until data that * has been synchronously appended data to a [[org.apache.spark.sql.execution.streaming.Source]] * prior to invocation. (i.e. `getOffset` must immediately reflect the addition). + * @since 2.0.0 */ def processAllAvailable(): Unit @@ -106,4 +128,18 @@ trait ContinuousQuery { * @since 2.0.0 */ def stop(): Unit + + /** + * Prints the physical plan to the console for debugging purposes. + * @since 2.0.0 + */ + def explain(): Unit + + /** + * Prints the physical plan to the console for debugging purposes. + * + * @param extended whether to do extended explain or not + * @since 2.0.0 + */ + def explain(extended: Boolean): Unit } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryException.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala similarity index 88% rename from sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryException.scala rename to sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala index fec38629d914e..bd3e5a5618ec4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryException.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala @@ -15,14 +15,15 @@ * limitations under the License. 
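
A sketch of driving the renamed `StreamingQuery` handle defined above; it builds on the word-count sketch earlier (so `spark` and the streaming `counts` Dataset are assumed to be in scope), and the memory-sink table name is illustrative.

```scala
import org.apache.spark.sql.streaming.StreamingQueryException

// The memory sink requires a query name; results land in a temp view of that name.
val query = counts.writeStream
  .outputMode("complete")
  .format("memory")
  .queryName("word_counts")
  .start()

println(s"id=${query.id}, name=${query.name}, active=${query.isActive}")
query.explain()                         // physical plan of the running query

try {
  query.awaitTermination(30 * 1000L)    // wait up to 30 seconds, then fall through
} catch {
  case e: StreamingQueryException =>
    println(s"query failed: ${e.message}")
} finally {
  query.stop()
}

spark.sql("SELECT * FROM word_counts").show()
```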
*/ -package org.apache.spark.sql +package org.apache.spark.sql.streaming import org.apache.spark.annotation.Experimental import org.apache.spark.sql.execution.streaming.{Offset, StreamExecution} /** * :: Experimental :: - * Exception that stopped a [[ContinuousQuery]]. + * Exception that stopped a [[StreamingQuery]]. Use `cause` get the actual exception + * that caused the failure. * @param query Query that caused the exception * @param message Message of this exception * @param cause Internal cause of this exception @@ -31,8 +32,8 @@ import org.apache.spark.sql.execution.streaming.{Offset, StreamExecution} * @since 2.0.0 */ @Experimental -class ContinuousQueryException private[sql]( - @transient val query: ContinuousQuery, +class StreamingQueryException private[sql]( + @transient val query: StreamingQuery, val message: String, val cause: Throwable, val startOffset: Option[Offset] = None, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala new file mode 100644 index 0000000000000..9e311fae842be --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import org.apache.spark.annotation.Experimental +import org.apache.spark.scheduler.SparkListenerEvent + +/** + * :: Experimental :: + * Interface for listening to events related to [[StreamingQuery StreamingQueries]]. + * @note The methods are not thread-safe as they may be called from different threads. + * + * @since 2.0.0 + */ +@Experimental +abstract class StreamingQueryListener { + + import StreamingQueryListener._ + + /** + * Called when a query is started. + * @note This is called synchronously with + * [[org.apache.spark.sql.streaming.DataStreamWriter `DataStreamWriter.start()`]], + * that is, `onQueryStart` will be called on all listeners before + * `DataStreamWriter.start()` returns the corresponding [[StreamingQuery]]. Please + * don't block this method as it will block your query. + * @since 2.0.0 + */ + def onQueryStarted(event: QueryStartedEvent): Unit + + /** + * Called when there is some status update (ingestion rate updated, etc.) + * + * @note This method is asynchronous. The status in [[StreamingQuery]] will always be + * latest no matter when this method is called. Therefore, the status of [[StreamingQuery]] + * may be changed before/when you process the event. E.g., you may find [[StreamingQuery]] + * is terminated when you are processing [[QueryProgressEvent]]. 
+ * @since 2.0.0 + */ + def onQueryProgress(event: QueryProgressEvent): Unit + + /** + * Called when a query is stopped, with or without error. + * @since 2.0.0 + */ + def onQueryTerminated(event: QueryTerminatedEvent): Unit +} + + +/** + * :: Experimental :: + * Companion object of [[StreamingQueryListener]] that defines the listener events. + * @since 2.0.0 + */ +@Experimental +object StreamingQueryListener { + + /** + * :: Experimental :: + * Base type of [[StreamingQueryListener]] events + * @since 2.0.0 + */ + @Experimental + trait Event extends SparkListenerEvent + + /** + * :: Experimental :: + * Event representing the start of a query + * @since 2.0.0 + */ + @Experimental + class QueryStartedEvent private[sql](val queryStatus: StreamingQueryStatus) extends Event + + /** + * :: Experimental :: + * Event representing any progress updates in a query + * @since 2.0.0 + */ + @Experimental + class QueryProgressEvent private[sql](val queryStatus: StreamingQueryStatus) extends Event + + /** + * :: Experimental :: + * Event representing that termination of a query + * + * @param queryStatus Information about the status of the query. + * @param exception The exception message of the [[StreamingQuery]] if the query was terminated + * with an exception. Otherwise, it will be `None`. + * @since 2.0.0 + */ + @Experimental + class QueryTerminatedEvent private[sql]( + val queryStatus: StreamingQueryStatus, + val exception: Option[String]) extends Event +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala similarity index 60% rename from sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala rename to sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index f82130cfa849d..bae7f56a23f81 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -15,53 +15,54 @@ * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.streaming import scala.collection.mutable +import org.apache.hadoop.fs.Path + import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.catalyst.analysis.{Append, OutputMode, UnsupportedOperationChecker} +import org.apache.spark.sql.{AnalysisException, DataFrame, SparkSession} +import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.util.ContinuousQueryListener +import org.apache.spark.util.{Clock, SystemClock, Utils} /** * :: Experimental :: - * A class to manage all the [[org.apache.spark.sql.ContinuousQuery ContinuousQueries]] active - * on a [[SparkSession]]. + * A class to manage all the [[StreamingQuery]] active on a [[SparkSession]]. 
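
Finally, a minimal sketch of the listener API as reshaped here; it assumes `SparkSession.streams` exposes the `StreamingQueryManager` that follows, and the log lines are illustrative.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.StreamingQueryListener
import org.apache.spark.sql.streaming.StreamingQueryListener.{QueryProgressEvent, QueryStartedEvent, QueryTerminatedEvent}

// Logs query life-cycle events; keep onQueryStarted cheap, since it is called synchronously
// from DataStreamWriter.start() as noted above.
class LoggingQueryListener extends StreamingQueryListener {
  override def onQueryStarted(event: QueryStartedEvent): Unit =
    println(s"started: ${event.queryStatus.name}")

  override def onQueryProgress(event: QueryProgressEvent): Unit =
    println(s"progress: ${event.queryStatus.name} at ${event.queryStatus.inputRate} rows/sec")

  override def onQueryTerminated(event: QueryTerminatedEvent): Unit =
    println(s"terminated: ${event.queryStatus.name}, error = ${event.exception.getOrElse("none")}")
}

object ListenerSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("listener-sketch").master("local[*]").getOrCreate()
    spark.streams.addListener(new LoggingQueryListener)
    // ... start queries as in the earlier sketches; their events will now be logged.
  }
}
```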
* * @since 2.0.0 */ @Experimental -class ContinuousQueryManager(sparkSession: SparkSession) { +class StreamingQueryManager private[sql] (sparkSession: SparkSession) { private[sql] val stateStoreCoordinator = StateStoreCoordinatorRef.forDriver(sparkSession.sparkContext.env) - private val listenerBus = new ContinuousQueryListenerBus(sparkSession.sparkContext.listenerBus) - private val activeQueries = new mutable.HashMap[String, ContinuousQuery] + private val listenerBus = new StreamingQueryListenerBus(sparkSession.sparkContext.listenerBus) + private val activeQueries = new mutable.HashMap[Long, StreamingQuery] private val activeQueriesLock = new Object private val awaitTerminationLock = new Object - private var lastTerminatedQuery: ContinuousQuery = null + private var lastTerminatedQuery: StreamingQuery = null /** * Returns a list of active queries associated with this SQLContext * * @since 2.0.0 */ - def active: Array[ContinuousQuery] = activeQueriesLock.synchronized { + def active: Array[StreamingQuery] = activeQueriesLock.synchronized { activeQueries.values.toArray } /** - * Returns an active query from this SQLContext or throws exception if bad name + * Returns the query if there is an active query with the given id, or null. * * @since 2.0.0 */ - def get(name: String): ContinuousQuery = activeQueriesLock.synchronized { - activeQueries.getOrElse(name, - throw new IllegalArgumentException(s"There is no active query with name $name")) + def get(id: Long): StreamingQuery = activeQueriesLock.synchronized { + activeQueries.get(id).orNull } /** @@ -80,7 +81,7 @@ class ContinuousQueryManager(sparkSession: SparkSession) { * users need to stop all of them after any of them terminates with exception, and then check the * `query.exception()` for each query. * - * @throws ContinuousQueryException, if any query has terminated with an exception + * @throws StreamingQueryException, if any query has terminated with an exception * * @since 2.0.0 */ @@ -112,7 +113,7 @@ class ContinuousQueryManager(sparkSession: SparkSession) { * users need to stop all of them after any of them terminates with exception, and then check the * `query.exception()` for each query. * - * @throws ContinuousQueryException, if any query has terminated with an exception + * @throws StreamingQueryException, if any query has terminated with an exception * * @since 2.0.0 */ @@ -145,42 +146,89 @@ class ContinuousQueryManager(sparkSession: SparkSession) { } /** - * Register a [[ContinuousQueryListener]] to receive up-calls for life cycle events of - * [[org.apache.spark.sql.ContinuousQuery ContinuousQueries]]. + * Register a [[StreamingQueryListener]] to receive up-calls for life cycle events of + * [[StreamingQuery]]. * * @since 2.0.0 */ - def addListener(listener: ContinuousQueryListener): Unit = { + def addListener(listener: StreamingQueryListener): Unit = { listenerBus.addListener(listener) } /** - * Deregister a [[ContinuousQueryListener]]. + * Deregister a [[StreamingQueryListener]]. * * @since 2.0.0 */ - def removeListener(listener: ContinuousQueryListener): Unit = { + def removeListener(listener: StreamingQueryListener): Unit = { listenerBus.removeListener(listener) } /** Post a listener event */ - private[sql] def postListenerEvent(event: ContinuousQueryListener.Event): Unit = { + private[sql] def postListenerEvent(event: StreamingQueryListener.Event): Unit = { listenerBus.post(event) } - /** Start a query */ + /** + * Start a [[StreamingQuery]]. + * @param userSpecifiedName Query name optionally specified by the user. 
+ * @param userSpecifiedCheckpointLocation Checkpoint location optionally specified by the user. + * @param df Streaming DataFrame. + * @param sink Sink to write the streaming outputs. + * @param outputMode Output mode for the sink. + * @param useTempCheckpointLocation Whether to use a temporary checkpoint location when the user + * has not specified one. If false, then error will be thrown. + * @param recoverFromCheckpointLocation Whether to recover query from the checkpoint location. + * If false and the checkpoint location exists, then error + * will be thrown. + * @param trigger [[Trigger]] for the query. + * @param triggerClock [[Clock]] to use for the triggering. + */ private[sql] def startQuery( - name: String, - checkpointLocation: String, + userSpecifiedName: Option[String], + userSpecifiedCheckpointLocation: Option[String], df: DataFrame, sink: Sink, + outputMode: OutputMode, + useTempCheckpointLocation: Boolean = false, + recoverFromCheckpointLocation: Boolean = true, trigger: Trigger = ProcessingTime(0), - outputMode: OutputMode = Append): ContinuousQuery = { + triggerClock: Clock = new SystemClock()): StreamingQuery = { activeQueriesLock.synchronized { - if (activeQueries.contains(name)) { + val id = StreamExecution.nextId + val name = userSpecifiedName.getOrElse(s"query-$id") + if (activeQueries.values.exists(_.name == name)) { throw new IllegalArgumentException( s"Cannot start query with name $name as a query with that name is already active") } + val checkpointLocation = userSpecifiedCheckpointLocation.map { userSpecified => + new Path(userSpecified).toUri.toString + }.orElse { + df.sparkSession.conf.get(SQLConf.CHECKPOINT_LOCATION).map { location => + new Path(location, name).toUri.toString + } + }.getOrElse { + if (useTempCheckpointLocation) { + Utils.createTempDir(namePrefix = s"temporary").getCanonicalPath + } else { + throw new AnalysisException( + "checkpointLocation must be specified either " + + """through option("checkpointLocation", ...) or """ + + s"""SparkSession.conf.set("${SQLConf.CHECKPOINT_LOCATION.key}", ...)""") + } + } + + // If offsets have already been created, we trying to resume a query. + if (!recoverFromCheckpointLocation) { + val checkpointPath = new Path(checkpointLocation, "offsets") + val fs = checkpointPath.getFileSystem(df.sparkSession.sessionState.newHadoopConf()) + if (fs.exists(checkpointPath)) { + throw new AnalysisException( + s"This query does not support recovering from checkpoint location. 
" + + s"Delete $checkpointPath to start over.") + } + } + val analyzedPlan = df.queryExecution.analyzed df.queryExecution.assertAnalyzed() @@ -202,22 +250,24 @@ class ContinuousQueryManager(sparkSession: SparkSession) { } val query = new StreamExecution( sparkSession, + id, name, checkpointLocation, logicalPlan, sink, - outputMode, - trigger) + trigger, + triggerClock, + outputMode) query.start() - activeQueries.put(name, query) + activeQueries.put(id, query) query } } - /** Notify (by the ContinuousQuery) that the query has been terminated */ - private[sql] def notifyQueryTermination(terminatedQuery: ContinuousQuery): Unit = { + /** Notify (by the StreamingQuery) that the query has been terminated */ + private[sql] def notifyQueryTermination(terminatedQuery: StreamingQuery): Unit = { activeQueriesLock.synchronized { - activeQueries -= terminatedQuery.name + activeQueries -= terminatedQuery.id } awaitTerminationLock.synchronized { if (lastTerminatedQuery == null || terminatedQuery.exception.nonEmpty) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala new file mode 100644 index 0000000000000..a50b0d96c13f7 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import java.{util => ju} + +import scala.collection.JavaConverters._ + +import org.json4s._ +import org.json4s.JsonAST.JValue +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.execution.streaming.{CompositeOffset, LongOffset} +import org.apache.spark.util.JsonProtocol + +/** + * :: Experimental :: + * A class used to report information about the progress of a [[StreamingQuery]]. + * + * @param name Name of the query. This name is unique across all active queries. + * @param id Id of the query. This id is unique across + * all queries that have been started in the current process. + * @param timestamp Timestamp (ms) of when this query was generated. + * @param inputRate Current rate (rows/sec) at which data is being generated by all the sources. + * @param processingRate Current rate (rows/sec) at which the query is processing data from + * all the sources. + * @param latency Current average latency between the data being available in source and the sink + * writing the corresponding output. + * @param sourceStatuses Current statuses of the sources. + * @param sinkStatus Current status of the sink. + * @param triggerDetails Low-level details of the currently active trigger (e.g. 
number of + * rows processed in trigger, latency of intermediate steps, etc.). + * If no trigger is active, then it will have details of the last completed + * trigger. + * @since 2.0.0 + */ +@Experimental +class StreamingQueryStatus private( + val name: String, + val id: Long, + val timestamp: Long, + val inputRate: Double, + val processingRate: Double, + val latency: Option[Double], + val sourceStatuses: Array[SourceStatus], + val sinkStatus: SinkStatus, + val triggerDetails: ju.Map[String, String]) { + + import StreamingQueryStatus._ + + /** The compact JSON representation of this status. */ + def json: String = compact(render(jsonValue)) + + /** The pretty (i.e. indented) JSON representation of this status. */ + def prettyJson: String = pretty(render(jsonValue)) + + override def toString: String = { + val sourceStatusLines = sourceStatuses.zipWithIndex.map { case (s, i) => + s"Source ${i + 1} - " + indent(s.prettyString).trim + } + val sinkStatusLines = sinkStatus.prettyString.trim + val triggerDetailsLines = triggerDetails.asScala.map { case (k, v) => s"$k: $v" }.toSeq.sorted + val numSources = sourceStatuses.length + val numSourcesString = s"$numSources source" + { if (numSources > 1) "s" else "" } + + val allLines = + s"""|Query id: $id + |Status timestamp: $timestamp + |Input rate: $inputRate rows/sec + |Processing rate $processingRate rows/sec + |Latency: ${latency.getOrElse("-")} ms + |Trigger details: + |${indent(triggerDetailsLines)} + |Source statuses [$numSourcesString]: + |${indent(sourceStatusLines)} + |Sink status - ${indent(sinkStatusLines).trim}""".stripMargin + + s"Status of query '$name'\n${indent(allLines)}" + } + + private[sql] def jsonValue: JValue = { + ("name" -> JString(name)) ~ + ("id" -> JInt(id)) ~ + ("timestamp" -> JInt(timestamp)) ~ + ("inputRate" -> JDouble(inputRate)) ~ + ("processingRate" -> JDouble(processingRate)) ~ + ("latency" -> latency.map(JDouble).getOrElse(JNothing)) ~ + ("triggerDetails" -> JsonProtocol.mapToJson(triggerDetails.asScala)) + ("sourceStatuses" -> JArray(sourceStatuses.map(_.jsonValue).toList)) ~ + ("sinkStatus" -> sinkStatus.jsonValue) + } +} + +/** Companion object, primarily for creating StreamingQueryInfo instances internally */ +private[sql] object StreamingQueryStatus { + def apply( + name: String, + id: Long, + timestamp: Long, + inputRate: Double, + processingRate: Double, + latency: Option[Double], + sourceStatuses: Array[SourceStatus], + sinkStatus: SinkStatus, + triggerDetails: Map[String, String]): StreamingQueryStatus = { + new StreamingQueryStatus(name, id, timestamp, inputRate, processingRate, + latency, sourceStatuses, sinkStatus, triggerDetails.asJava) + } + + def indent(strings: Iterable[String]): String = strings.map(indent).mkString("\n") + def indent(string: String): String = string.split("\n").map(" " + _).mkString("\n") + + /** Create an instance of status for python testing */ + def testStatus(): StreamingQueryStatus = { + import org.apache.spark.sql.execution.streaming.StreamMetrics._ + StreamingQueryStatus( + name = "query", + id = 1, + timestamp = 123, + inputRate = 15.5, + processingRate = 23.5, + latency = Some(345), + sourceStatuses = Array( + SourceStatus( + desc = "MySource1", + offsetDesc = LongOffset(0).toString, + inputRate = 15.5, + processingRate = 23.5, + triggerDetails = Map( + NUM_SOURCE_INPUT_ROWS -> "100", + SOURCE_GET_OFFSET_LATENCY -> "10", + SOURCE_GET_BATCH_LATENCY -> "20"))), + sinkStatus = SinkStatus( + desc = "MySink", + offsetDesc = CompositeOffset(Some(LongOffset(1)) :: None :: 
Nil).toString), + triggerDetails = Map( + TRIGGER_ID -> "5", + IS_TRIGGER_ACTIVE -> "true", + IS_DATA_PRESENT_IN_TRIGGER -> "true", + GET_OFFSET_LATENCY -> "10", + GET_BATCH_LATENCY -> "20", + NUM_INPUT_ROWS -> "100" + )) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Trigger.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/Trigger.scala similarity index 92% rename from sql/core/src/main/scala/org/apache/spark/sql/Trigger.scala rename to sql/core/src/main/scala/org/apache/spark/sql/streaming/Trigger.scala index 256e8a47a4665..55be7a711adb6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Trigger.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/Trigger.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.streaming import java.util.concurrent.TimeUnit @@ -28,10 +28,12 @@ import org.apache.spark.unsafe.types.CalendarInterval /** * :: Experimental :: - * Used to indicate how often results should be produced by a [[ContinuousQuery]]. + * Used to indicate how often results should be produced by a [[StreamingQuery]]. + * + * @since 2.0.0 */ @Experimental -sealed trait Trigger {} +sealed trait Trigger /** * :: Experimental :: @@ -53,6 +55,8 @@ sealed trait Trigger {} * import java.util.concurrent.TimeUnit * df.write.trigger(ProcessingTime.create(10, TimeUnit.SECONDS)) * }}} + * + * @since 2.0.0 */ @Experimental case class ProcessingTime(intervalMs: Long) extends Trigger { @@ -61,7 +65,9 @@ case class ProcessingTime(intervalMs: Long) extends Trigger { /** * :: Experimental :: - * Used to create [[ProcessingTime]] triggers for [[ContinuousQuery]]s. + * Used to create [[ProcessingTime]] triggers for [[StreamingQuery]]s. + * + * @since 2.0.0 */ @Experimental object ProcessingTime { @@ -73,6 +79,8 @@ object ProcessingTime { * {{{ * df.write.trigger(ProcessingTime("10 seconds")) * }}} + * + * @since 2.0.0 */ def apply(interval: String): ProcessingTime = { if (StringUtils.isBlank(interval)) { @@ -101,6 +109,8 @@ object ProcessingTime { * import scala.concurrent.duration._ * df.write.trigger(ProcessingTime(10.seconds)) * }}} + * + * @since 2.0.0 */ def apply(interval: Duration): ProcessingTime = { new ProcessingTime(interval.toMillis) @@ -113,6 +123,8 @@ object ProcessingTime { * {{{ * df.write.trigger(ProcessingTime.create("10 seconds")) * }}} + * + * @since 2.0.0 */ def create(interval: String): ProcessingTime = { apply(interval) @@ -126,6 +138,8 @@ object ProcessingTime { * import java.util.concurrent.TimeUnit * df.write.trigger(ProcessingTime.create(10, TimeUnit.SECONDS)) * }}} + * + * @since 2.0.0 */ def create(interval: Long, unit: TimeUnit): ProcessingTime = { new ProcessingTime(unit.toMillis(interval)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/util/ContinuousQueryListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/util/ContinuousQueryListener.scala deleted file mode 100644 index ba1facf11b7d5..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/util/ContinuousQueryListener.scala +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.util - -import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.ContinuousQuery -import org.apache.spark.sql.util.ContinuousQueryListener._ - -/** - * :: Experimental :: - * Interface for listening to events related to [[ContinuousQuery ContinuousQueries]]. - * @note The methods are not thread-safe as they may be called from different threads. - */ -@Experimental -abstract class ContinuousQueryListener { - - /** - * Called when a query is started. - * @note This is called synchronously with - * [[org.apache.spark.sql.DataFrameWriter `DataFrameWriter.startStream()`]], - * that is, `onQueryStart` will be called on all listeners before - * `DataFrameWriter.startStream()` returns the corresponding [[ContinuousQuery]]. Please - * don't block this method as it will block your query. - */ - def onQueryStarted(queryStarted: QueryStarted): Unit - - /** - * Called when there is some status update (ingestion rate updated, etc.) - * - * @note This method is asynchronous. The status in [[ContinuousQuery]] will always be - * latest no matter when this method is called. Therefore, the status of [[ContinuousQuery]] - * may be changed before/when you process the event. E.g., you may find [[ContinuousQuery]] - * is terminated when you are processing [[QueryProgress]]. - */ - def onQueryProgress(queryProgress: QueryProgress): Unit - - /** Called when a query is stopped, with or without error */ - def onQueryTerminated(queryTerminated: QueryTerminated): Unit -} - - -/** - * :: Experimental :: - * Companion object of [[ContinuousQueryListener]] that defines the listener events. 
- */ -@Experimental -object ContinuousQueryListener { - - /** Base type of [[ContinuousQueryListener]] events */ - trait Event - - /** Event representing the start of a query */ - class QueryStarted private[sql](val query: ContinuousQuery) extends Event - - /** Event representing any progress updates in a query */ - class QueryProgress private[sql](val query: ContinuousQuery) extends Event - - /** Event representing that termination of a query */ - class QueryTerminated private[sql](val query: ContinuousQuery) extends Event -} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java index 189cc3972c9ba..573d0e3594363 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java @@ -28,14 +28,13 @@ import org.junit.Before; import org.junit.Test; -import org.apache.spark.SparkContext; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; @@ -44,21 +43,22 @@ // serialized, as an alternative to converting these anonymous classes to static inner classes; // see http://stackoverflow.com/questions/758570/. public class JavaApplySchemaSuite implements Serializable { - private transient JavaSparkContext javaCtx; - private transient SQLContext sqlContext; + private transient SparkSession spark; + private transient JavaSparkContext jsc; @Before public void setUp() { - SparkContext context = new SparkContext("local[*]", "testing"); - javaCtx = new JavaSparkContext(context); - sqlContext = new SQLContext(context); + spark = SparkSession.builder() + .master("local[*]") + .appName("testing") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); } @After public void tearDown() { - sqlContext.sparkContext().stop(); - sqlContext = null; - javaCtx = null; + spark.stop(); + spark = null; } public static class Person implements Serializable { @@ -94,7 +94,7 @@ public void applySchema() { person2.setAge(28); personList.add(person2); - JavaRDD rowRDD = javaCtx.parallelize(personList).map( + JavaRDD rowRDD = jsc.parallelize(personList).map( new Function() { @Override public Row call(Person person) throws Exception { @@ -107,9 +107,9 @@ public Row call(Person person) throws Exception { fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false)); StructType schema = DataTypes.createStructType(fields); - Dataset df = sqlContext.createDataFrame(rowRDD, schema); - df.registerTempTable("people"); - List actual = sqlContext.sql("SELECT * FROM people").collectAsList(); + Dataset df = spark.createDataFrame(rowRDD, schema); + df.createOrReplaceTempView("people"); + List actual = spark.sql("SELECT * FROM people").collectAsList(); List expected = new ArrayList<>(2); expected.add(RowFactory.create("Michael", 29)); @@ -130,7 +130,7 @@ public void dataFrameRDDOperations() { person2.setAge(28); personList.add(person2); - JavaRDD rowRDD = javaCtx.parallelize(personList).map( + JavaRDD rowRDD = jsc.parallelize(personList).map( new Function() { @Override public Row 
call(Person person) { @@ -143,9 +143,9 @@ public Row call(Person person) { fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false)); StructType schema = DataTypes.createStructType(fields); - Dataset df = sqlContext.createDataFrame(rowRDD, schema); - df.registerTempTable("people"); - List actual = sqlContext.sql("SELECT * FROM people").toJavaRDD() + Dataset df = spark.createDataFrame(rowRDD, schema); + df.createOrReplaceTempView("people"); + List actual = spark.sql("SELECT * FROM people").toJavaRDD() .map(new Function() { @Override public String call(Row row) { @@ -162,7 +162,7 @@ public String call(Row row) { @Test public void applySchemaToJSON() { - JavaRDD jsonRDD = javaCtx.parallelize(Arrays.asList( + JavaRDD jsonRDD = jsc.parallelize(Arrays.asList( "{\"string\":\"this is a simple string.\", \"integer\":10, \"long\":21474836470, " + "\"bigInteger\":92233720368547758070, \"double\":1.7976931348623157E308, " + "\"boolean\":true, \"null\":null}", @@ -199,18 +199,18 @@ public void applySchemaToJSON() { null, "this is another simple string.")); - Dataset df1 = sqlContext.read().json(jsonRDD); + Dataset df1 = spark.read().json(jsonRDD); StructType actualSchema1 = df1.schema(); Assert.assertEquals(expectedSchema, actualSchema1); - df1.registerTempTable("jsonTable1"); - List actual1 = sqlContext.sql("select * from jsonTable1").collectAsList(); + df1.createOrReplaceTempView("jsonTable1"); + List actual1 = spark.sql("select * from jsonTable1").collectAsList(); Assert.assertEquals(expectedResult, actual1); - Dataset df2 = sqlContext.read().schema(expectedSchema).json(jsonRDD); + Dataset df2 = spark.read().schema(expectedSchema).json(jsonRDD); StructType actualSchema2 = df2.schema(); Assert.assertEquals(expectedSchema, actualSchema2); - df2.registerTempTable("jsonTable2"); - List actual2 = sqlContext.sql("select * from jsonTable2").collectAsList(); + df2.createOrReplaceTempView("jsonTable2"); + List actual2 = spark.sql("select * from jsonTable2").collectAsList(); Assert.assertEquals(expectedResult, actual2); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameReaderWriterSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameReaderWriterSuite.java new file mode 100644 index 0000000000000..7babf7573c075 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameReaderWriterSuite.java @@ -0,0 +1,158 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package test.org.apache.spark.sql; + +import java.io.File; +import java.util.HashMap; + +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.test.TestSparkSession; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.util.Utils; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +public class JavaDataFrameReaderWriterSuite { + private SparkSession spark = new TestSparkSession(); + private StructType schema = new StructType().add("s", "string"); + private transient String input; + private transient String output; + + @Before + public void setUp() { + input = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "input").toString(); + File f = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "output"); + f.delete(); + output = f.toString(); + } + + @After + public void tearDown() { + spark.stop(); + spark = null; + } + + @Test + public void testFormatAPI() { + spark + .read() + .format("org.apache.spark.sql.test") + .load() + .write() + .format("org.apache.spark.sql.test") + .save(); + } + + @Test + public void testOptionsAPI() { + HashMap map = new HashMap(); + map.put("e", "1"); + spark + .read() + .option("a", "1") + .option("b", 1) + .option("c", 1.0) + .option("d", true) + .options(map) + .text() + .write() + .option("a", "1") + .option("b", 1) + .option("c", 1.0) + .option("d", true) + .options(map) + .format("org.apache.spark.sql.test") + .save(); + } + + @Test + public void testSaveModeAPI() { + spark + .range(10) + .write() + .format("org.apache.spark.sql.test") + .mode(SaveMode.ErrorIfExists) + .save(); + } + + @Test + public void testLoadAPI() { + spark.read().format("org.apache.spark.sql.test").load(); + spark.read().format("org.apache.spark.sql.test").load(input); + spark.read().format("org.apache.spark.sql.test").load(input, input, input); + spark.read().format("org.apache.spark.sql.test").load(new String[]{input, input}); + } + + @Test + public void testTextAPI() { + spark.read().text(); + spark.read().text(input); + spark.read().text(input, input, input); + spark.read().text(new String[]{input, input}) + .write().text(output); + } + + @Test + public void testTextFileAPI() { + spark.read().textFile(); + spark.read().textFile(input); + spark.read().textFile(input, input, input); + spark.read().textFile(new String[]{input, input}); + } + + @Test + public void testCsvAPI() { + spark.read().schema(schema).csv(); + spark.read().schema(schema).csv(input); + spark.read().schema(schema).csv(input, input, input); + spark.read().schema(schema).csv(new String[]{input, input}) + .write().csv(output); + } + + @Test + public void testJsonAPI() { + spark.read().schema(schema).json(); + spark.read().schema(schema).json(input); + spark.read().schema(schema).json(input, input, input); + spark.read().schema(schema).json(new String[]{input, input}) + .write().json(output); + } + + @Test + public void testParquetAPI() { + spark.read().schema(schema).parquet(); + spark.read().schema(schema).parquet(input); + spark.read().schema(schema).parquet(input, input, input); + spark.read().schema(schema).parquet(new String[] { input, input }) + .write().parquet(output); + } + + /** + * This only tests whether API compiles, but does not run it as orc() + * cannot be run without Hive classes. 
+ */ + public void testOrcAPI() { + spark.read().schema(schema).orc(); + spark.read().schema(schema).orc(input); + spark.read().schema(schema).orc(input, input, input); + spark.read().schema(schema).orc(new String[]{input, input}) + .write().orc(output); + } +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 1eb680dc4c029..c44fc3d393862 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -20,12 +20,9 @@ import java.io.Serializable; import java.net.URISyntaxException; import java.net.URL; -import java.util.Arrays; -import java.util.Collections; -import java.util.Comparator; -import java.util.List; -import java.util.Map; -import java.util.ArrayList; +import java.util.*; +import java.math.BigInteger; +import java.math.BigDecimal; import scala.collection.JavaConverters; import scala.collection.Seq; @@ -34,46 +31,45 @@ import com.google.common.primitives.Ints; import org.junit.*; -import org.apache.spark.SparkContext; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.*; -import org.apache.spark.sql.test.TestSQLContext; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.test.TestSparkSession; import org.apache.spark.sql.types.*; +import org.apache.spark.util.sketch.BloomFilter; import org.apache.spark.util.sketch.CountMinSketch; import static org.apache.spark.sql.functions.*; import static org.apache.spark.sql.types.DataTypes.*; -import org.apache.spark.util.sketch.BloomFilter; public class JavaDataFrameSuite { + private transient TestSparkSession spark; private transient JavaSparkContext jsc; - private transient TestSQLContext context; @Before public void setUp() { // Trigger static initializer of TestData - SparkContext sc = new SparkContext("local[*]", "testing"); - jsc = new JavaSparkContext(sc); - context = new TestSQLContext(sc); - context.loadTestData(); + spark = new TestSparkSession(); + jsc = new JavaSparkContext(spark.sparkContext()); + spark.loadTestData(); } @After public void tearDown() { - context.sparkContext().stop(); - context = null; - jsc = null; + spark.stop(); + spark = null; } @Test public void testExecution() { - Dataset df = context.table("testData").filter("key = 1"); + Dataset df = spark.table("testData").filter("key = 1"); Assert.assertEquals(1, df.select("key").collectAsList().get(0).get(0)); } @Test public void testCollectAndTake() { - Dataset df = context.table("testData").filter("key = 1 or key = 2 or key = 3"); + Dataset df = spark.table("testData").filter("key = 1 or key = 2 or key = 3"); Assert.assertEquals(3, df.select("key").collectAsList().size()); Assert.assertEquals(2, df.select("key").takeAsList(2).size()); } @@ -83,7 +79,7 @@ public void testCollectAndTake() { */ @Test public void testVarargMethods() { - Dataset df = context.table("testData"); + Dataset df = spark.table("testData"); df.toDF("key1", "value1"); @@ -112,7 +108,7 @@ public void testVarargMethods() { df.select(coalesce(col("key"))); // Varargs with mathfunctions - Dataset df2 = context.table("testData2"); + Dataset df2 = spark.table("testData2"); df2.select(exp("a"), exp("b")); df2.select(exp(log("a"))); df2.select(pow("a", "a"), pow("b", 2.0)); @@ -126,7 +122,7 @@ public void testVarargMethods() { 
@Ignore public void testShow() { // This test case is intended ignored, but to make sure it compiles correctly - Dataset df = context.table("testData"); + Dataset df = spark.table("testData"); df.show(); df.show(1000); } @@ -136,6 +132,7 @@ public static class Bean implements Serializable { private Integer[] b = { 0, 1 }; private Map c = ImmutableMap.of("hello", new int[] { 1, 2 }); private List d = Arrays.asList("floppy", "disk"); + private BigInteger e = new BigInteger("1234567"); public double getA() { return a; @@ -152,6 +149,8 @@ public Map getC() { public List getD() { return d; } + + public BigInteger getE() { return e; } } void validateDataFrameWithBeans(Bean bean, Dataset df) { @@ -169,7 +168,9 @@ void validateDataFrameWithBeans(Bean bean, Dataset df) { Assert.assertEquals( new StructField("d", new ArrayType(DataTypes.StringType, true), true, Metadata.empty()), schema.apply("d")); - Row first = df.select("a", "b", "c", "d").first(); + Assert.assertEquals(new StructField("e", DataTypes.createDecimalType(38,0), true, + Metadata.empty()), schema.apply("e")); + Row first = df.select("a", "b", "c", "d", "e").first(); Assert.assertEquals(bean.getA(), first.getDouble(0), 0.0); // Now Java lists and maps are converted to Scala Seq's and Map's. Once we get a Seq below, // verify that it has the expected length, and contains expected elements. @@ -188,13 +189,15 @@ void validateDataFrameWithBeans(Bean bean, Dataset df) { for (int i = 0; i < d.length(); i++) { Assert.assertEquals(bean.getD().get(i), d.apply(i)); } + // Java.math.BigInteger is equavient to Spark Decimal(38,0) + Assert.assertEquals(new BigDecimal(bean.getE()), first.getDecimal(4)); } @Test public void testCreateDataFrameFromLocalJavaBeans() { Bean bean = new Bean(); List data = Arrays.asList(bean); - Dataset df = context.createDataFrame(data, Bean.class); + Dataset df = spark.createDataFrame(data, Bean.class); validateDataFrameWithBeans(bean, df); } @@ -202,7 +205,7 @@ public void testCreateDataFrameFromLocalJavaBeans() { public void testCreateDataFrameFromJavaBeans() { Bean bean = new Bean(); JavaRDD rdd = jsc.parallelize(Arrays.asList(bean)); - Dataset df = context.createDataFrame(rdd, Bean.class); + Dataset df = spark.createDataFrame(rdd, Bean.class); validateDataFrameWithBeans(bean, df); } @@ -210,7 +213,7 @@ public void testCreateDataFrameFromJavaBeans() { public void testCreateDataFromFromList() { StructType schema = createStructType(Arrays.asList(createStructField("i", IntegerType, true))); List rows = Arrays.asList(RowFactory.create(0)); - Dataset df = context.createDataFrame(rows, schema); + Dataset df = spark.createDataFrame(rows, schema); List result = df.collectAsList(); Assert.assertEquals(1, result.size()); } @@ -239,12 +242,12 @@ public int compare(Row row1, Row row2) { @Test public void testCrosstab() { - Dataset df = context.table("testData2"); + Dataset df = spark.table("testData2"); Dataset crosstab = df.stat().crosstab("a", "b"); String[] columnNames = crosstab.schema().fieldNames(); Assert.assertEquals("a_b", columnNames[0]); - Assert.assertEquals("2", columnNames[1]); - Assert.assertEquals("1", columnNames[2]); + Assert.assertEquals("1", columnNames[1]); + Assert.assertEquals("2", columnNames[2]); List rows = crosstab.collectAsList(); Collections.sort(rows, crosstabRowComparator); Integer count = 1; @@ -258,7 +261,7 @@ public void testCrosstab() { @Test public void testFrequentItems() { - Dataset df = context.table("testData2"); + Dataset df = spark.table("testData2"); String[] cols = {"a"}; Dataset 
results = df.stat().freqItems(cols, 0.2); Assert.assertTrue(results.collectAsList().get(0).getSeq(0).contains(1)); @@ -266,21 +269,21 @@ public void testFrequentItems() { @Test public void testCorrelation() { - Dataset df = context.table("testData2"); + Dataset df = spark.table("testData2"); Double pearsonCorr = df.stat().corr("a", "b", "pearson"); Assert.assertTrue(Math.abs(pearsonCorr) < 1.0e-6); } @Test public void testCovariance() { - Dataset df = context.table("testData2"); + Dataset df = spark.table("testData2"); Double result = df.stat().cov("a", "b"); Assert.assertTrue(Math.abs(result) < 1.0e-6); } @Test public void testSampleBy() { - Dataset df = context.range(0, 100, 1, 2).select(col("id").mod(3).as("key")); + Dataset df = spark.range(0, 100, 1, 2).select(col("id").mod(3).as("key")); Dataset sampled = df.stat().sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L); List actual = sampled.groupBy("key").count().orderBy("key").collectAsList(); Assert.assertEquals(0, actual.get(0).getLong(0)); @@ -291,7 +294,7 @@ public void testSampleBy() { @Test public void pivot() { - Dataset df = context.table("courseSales"); + Dataset df = spark.table("courseSales"); List actual = df.groupBy("year") .pivot("course", Arrays.asList("dotNET", "Java")) .agg(sum("earnings")).orderBy("year").collectAsList(); @@ -324,29 +327,29 @@ private String getResource(String resource) { @Test public void testGenericLoad() { - Dataset df1 = context.read().format("text").load(getResource("text-suite.txt")); + Dataset df1 = spark.read().format("text").load(getResource("test-data/text-suite.txt")); Assert.assertEquals(4L, df1.count()); - Dataset df2 = context.read().format("text").load( - getResource("text-suite.txt"), - getResource("text-suite2.txt")); + Dataset df2 = spark.read().format("text").load( + getResource("test-data/text-suite.txt"), + getResource("test-data/text-suite2.txt")); Assert.assertEquals(5L, df2.count()); } @Test public void testTextLoad() { - Dataset ds1 = context.read().text(getResource("text-suite.txt")); + Dataset ds1 = spark.read().textFile(getResource("test-data/text-suite.txt")); Assert.assertEquals(4L, ds1.count()); - Dataset ds2 = context.read().text( - getResource("text-suite.txt"), - getResource("text-suite2.txt")); + Dataset ds2 = spark.read().textFile( + getResource("test-data/text-suite.txt"), + getResource("test-data/text-suite2.txt")); Assert.assertEquals(5L, ds2.count()); } @Test public void testCountMinSketch() { - Dataset df = context.range(1000); + Dataset df = spark.range(1000); CountMinSketch sketch1 = df.stat().countMinSketch("id", 10, 20, 42); Assert.assertEquals(sketch1.totalCount(), 1000); @@ -371,7 +374,7 @@ public void testCountMinSketch() { @Test public void testBloomFilter() { - Dataset df = context.range(1000); + Dataset df = spark.range(1000); BloomFilter filter1 = df.stat().bloomFilter("id", 1000, 0.03); Assert.assertTrue(filter1.expectedFpp() - 0.03 < 1e-3); diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetAggregatorSuite.java similarity index 97% rename from sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java rename to sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetAggregatorSuite.java index 0e49f871de5c4..fe863715162f5 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java +++ 
b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetAggregatorSuite.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package test.org.apache.spark.sql.sources; +package test.org.apache.spark.sql; import java.util.Arrays; @@ -30,7 +30,7 @@ import org.apache.spark.sql.Encoders; import org.apache.spark.sql.KeyValueGroupedDataset; import org.apache.spark.sql.expressions.Aggregator; -import org.apache.spark.sql.expressions.java.typed; +import org.apache.spark.sql.expressions.javalang.typed; /** * Suite for testing the aggregate functionality of Datasets in Java. diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuiteBase.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetAggregatorSuiteBase.java similarity index 78% rename from sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuiteBase.java rename to sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetAggregatorSuiteBase.java index 7863177093c15..8fc4eff55ddd0 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuiteBase.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetAggregatorSuiteBase.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package test.org.apache.spark.sql.sources; +package test.org.apache.spark.sql; import java.io.Serializable; import java.util.Arrays; @@ -26,36 +26,30 @@ import org.junit.After; import org.junit.Before; -import org.apache.spark.SparkContext; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.MapFunction; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Encoder; import org.apache.spark.sql.Encoders; import org.apache.spark.sql.KeyValueGroupedDataset; -import org.apache.spark.sql.test.TestSQLContext; +import org.apache.spark.sql.test.TestSparkSession; /** * Common test base shared across this and Java8DatasetAggregatorSuite. 
*/ public class JavaDatasetAggregatorSuiteBase implements Serializable { - protected transient JavaSparkContext jsc; - protected transient TestSQLContext context; + private transient TestSparkSession spark; @Before public void setUp() { // Trigger static initializer of TestData - SparkContext sc = new SparkContext("local[*]", "testing"); - jsc = new JavaSparkContext(sc); - context = new TestSQLContext(sc); - context.loadTestData(); + spark = new TestSparkSession(); + spark.loadTestData(); } @After public void tearDown() { - context.sparkContext().stop(); - context = null; - jsc = null; + spark.stop(); + spark = null; } protected Tuple2 tuple2(T1 t1, T2 t2) { @@ -66,7 +60,7 @@ protected KeyValueGroupedDataset> generateGroupe Encoder> encoder = Encoders.tuple(Encoders.STRING(), Encoders.INT()); List> data = Arrays.asList(tuple2("a", 1), tuple2("a", 2), tuple2("b", 3)); - Dataset> ds = context.createDataset(data, encoder); + Dataset> ds = spark.createDataset(data, encoder); return ds.groupByKey( new MapFunction, String>() { diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index f1b1c22e4a6a8..a711811f418c6 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -23,46 +23,43 @@ import java.sql.Timestamp; import java.util.*; -import com.google.common.base.Objects; -import org.junit.rules.ExpectedException; import scala.Tuple2; import scala.Tuple3; import scala.Tuple4; import scala.Tuple5; +import com.google.common.base.Objects; import org.junit.*; +import org.junit.rules.ExpectedException; -import org.apache.spark.Accumulator; -import org.apache.spark.SparkContext; -import org.apache.spark.api.java.function.*; import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.*; import org.apache.spark.sql.*; -import org.apache.spark.sql.test.TestSQLContext; import org.apache.spark.sql.catalyst.encoders.OuterScopes; import org.apache.spark.sql.catalyst.expressions.GenericRow; +import org.apache.spark.sql.test.TestSparkSession; import org.apache.spark.sql.types.StructType; - -import static org.apache.spark.sql.functions.*; +import org.apache.spark.util.LongAccumulator; +import static org.apache.spark.sql.functions.col; +import static org.apache.spark.sql.functions.expr; import static org.apache.spark.sql.types.DataTypes.*; public class JavaDatasetSuite implements Serializable { + private transient TestSparkSession spark; private transient JavaSparkContext jsc; - private transient TestSQLContext context; @Before public void setUp() { // Trigger static initializer of TestData - SparkContext sc = new SparkContext("local[*]", "testing"); - jsc = new JavaSparkContext(sc); - context = new TestSQLContext(sc); - context.loadTestData(); + spark = new TestSparkSession(); + jsc = new JavaSparkContext(spark.sparkContext()); + spark.loadTestData(); } @After public void tearDown() { - context.sparkContext().stop(); - context = null; - jsc = null; + spark.stop(); + spark = null; } private Tuple2 tuple2(T1 t1, T2 t2) { @@ -72,7 +69,7 @@ private Tuple2 tuple2(T1 t1, T2 t2) { @Test public void testCollect() { List data = Arrays.asList("hello", "world"); - Dataset ds = context.createDataset(data, Encoders.STRING()); + Dataset ds = spark.createDataset(data, Encoders.STRING()); List collected = ds.collectAsList(); Assert.assertEquals(Arrays.asList("hello", "world"), 
collected); } @@ -80,7 +77,7 @@ public void testCollect() { @Test public void testTake() { List data = Arrays.asList("hello", "world"); - Dataset ds = context.createDataset(data, Encoders.STRING()); + Dataset ds = spark.createDataset(data, Encoders.STRING()); List collected = ds.takeAsList(1); Assert.assertEquals(Arrays.asList("hello"), collected); } @@ -88,17 +85,30 @@ public void testTake() { @Test public void testToLocalIterator() { List data = Arrays.asList("hello", "world"); - Dataset ds = context.createDataset(data, Encoders.STRING()); + Dataset ds = spark.createDataset(data, Encoders.STRING()); Iterator iter = ds.toLocalIterator(); Assert.assertEquals("hello", iter.next()); Assert.assertEquals("world", iter.next()); Assert.assertFalse(iter.hasNext()); } + // SPARK-15632: typed filter should preserve the underlying logical schema + @Test + public void testTypedFilterPreservingSchema() { + Dataset ds = spark.range(10); + Dataset ds2 = ds.filter(new FilterFunction() { + @Override + public boolean call(Long value) throws Exception { + return value > 3; + } + }); + Assert.assertEquals(ds.schema(), ds2.schema()); + } + @Test public void testCommonOperation() { List data = Arrays.asList("hello", "world"); - Dataset ds = context.createDataset(data, Encoders.STRING()); + Dataset ds = spark.createDataset(data, Encoders.STRING()); Assert.assertEquals("hello", ds.first()); Dataset filtered = ds.filter(new FilterFunction() { @@ -147,9 +157,9 @@ public Iterator call(String s) { @Test public void testForeach() { - final Accumulator accum = jsc.accumulator(0); + final LongAccumulator accum = jsc.sc().longAccumulator(); List data = Arrays.asList("a", "b", "c"); - Dataset ds = context.createDataset(data, Encoders.STRING()); + Dataset ds = spark.createDataset(data, Encoders.STRING()); ds.foreach(new ForeachFunction() { @Override @@ -163,7 +173,7 @@ public void call(String s) throws Exception { @Test public void testReduce() { List data = Arrays.asList(1, 2, 3); - Dataset ds = context.createDataset(data, Encoders.INT()); + Dataset ds = spark.createDataset(data, Encoders.INT()); int reduced = ds.reduce(new ReduceFunction() { @Override @@ -177,7 +187,7 @@ public Integer call(Integer v1, Integer v2) throws Exception { @Test public void testGroupBy() { List data = Arrays.asList("a", "foo", "bar"); - Dataset ds = context.createDataset(data, Encoders.STRING()); + Dataset ds = spark.createDataset(data, Encoders.STRING()); KeyValueGroupedDataset grouped = ds.groupByKey( new MapFunction() { @Override @@ -227,7 +237,7 @@ public String call(String v1, String v2) throws Exception { toSet(reduced.collectAsList())); List data2 = Arrays.asList(2, 6, 10); - Dataset ds2 = context.createDataset(data2, Encoders.INT()); + Dataset ds2 = spark.createDataset(data2, Encoders.INT()); KeyValueGroupedDataset grouped2 = ds2.groupByKey( new MapFunction() { @Override @@ -261,7 +271,7 @@ public Iterator call(Integer key, Iterator left, Iterator data = Arrays.asList(2, 6); - Dataset ds = context.createDataset(data, Encoders.INT()); + Dataset ds = spark.createDataset(data, Encoders.INT()); Dataset> selected = ds.select( expr("value + 1"), @@ -275,12 +285,12 @@ public void testSelect() { @Test public void testSetOperation() { List data = Arrays.asList("abc", "abc", "xyz"); - Dataset ds = context.createDataset(data, Encoders.STRING()); + Dataset ds = spark.createDataset(data, Encoders.STRING()); Assert.assertEquals(asSet("abc", "xyz"), toSet(ds.distinct().collectAsList())); List data2 = Arrays.asList("xyz", "foo", "foo"); - Dataset ds2 
= context.createDataset(data2, Encoders.STRING()); + Dataset ds2 = spark.createDataset(data2, Encoders.STRING()); Dataset intersected = ds.intersect(ds2); Assert.assertEquals(Arrays.asList("xyz"), intersected.collectAsList()); @@ -307,9 +317,9 @@ private static Set asSet(T... records) { @Test public void testJoin() { List data = Arrays.asList(1, 2, 3); - Dataset ds = context.createDataset(data, Encoders.INT()).as("a"); + Dataset ds = spark.createDataset(data, Encoders.INT()).as("a"); List data2 = Arrays.asList(2, 3, 4); - Dataset ds2 = context.createDataset(data2, Encoders.INT()).as("b"); + Dataset ds2 = spark.createDataset(data2, Encoders.INT()).as("b"); Dataset> joined = ds.joinWith(ds2, col("a.value").equalTo(col("b.value"))); @@ -322,21 +332,21 @@ public void testJoin() { public void testTupleEncoder() { Encoder> encoder2 = Encoders.tuple(Encoders.INT(), Encoders.STRING()); List> data2 = Arrays.asList(tuple2(1, "a"), tuple2(2, "b")); - Dataset> ds2 = context.createDataset(data2, encoder2); + Dataset> ds2 = spark.createDataset(data2, encoder2); Assert.assertEquals(data2, ds2.collectAsList()); Encoder> encoder3 = Encoders.tuple(Encoders.INT(), Encoders.LONG(), Encoders.STRING()); List> data3 = Arrays.asList(new Tuple3<>(1, 2L, "a")); - Dataset> ds3 = context.createDataset(data3, encoder3); + Dataset> ds3 = spark.createDataset(data3, encoder3); Assert.assertEquals(data3, ds3.collectAsList()); Encoder> encoder4 = Encoders.tuple(Encoders.INT(), Encoders.STRING(), Encoders.LONG(), Encoders.STRING()); List> data4 = Arrays.asList(new Tuple4<>(1, "b", 2L, "a")); - Dataset> ds4 = context.createDataset(data4, encoder4); + Dataset> ds4 = spark.createDataset(data4, encoder4); Assert.assertEquals(data4, ds4.collectAsList()); Encoder> encoder5 = @@ -345,7 +355,7 @@ public void testTupleEncoder() { List> data5 = Arrays.asList(new Tuple5<>(1, "b", 2L, "a", true)); Dataset> ds5 = - context.createDataset(data5, encoder5); + spark.createDataset(data5, encoder5); Assert.assertEquals(data5, ds5.collectAsList()); } @@ -356,7 +366,7 @@ public void testNestedTupleEncoder() { Encoders.tuple(Encoders.tuple(Encoders.INT(), Encoders.STRING()), Encoders.STRING()); List, String>> data = Arrays.asList(tuple2(tuple2(1, "a"), "a"), tuple2(tuple2(2, "b"), "b")); - Dataset, String>> ds = context.createDataset(data, encoder); + Dataset, String>> ds = spark.createDataset(data, encoder); Assert.assertEquals(data, ds.collectAsList()); // test (int, (string, string, long)) @@ -366,7 +376,7 @@ public void testNestedTupleEncoder() { List>> data2 = Arrays.asList(tuple2(1, new Tuple3<>("a", "b", 3L))); Dataset>> ds2 = - context.createDataset(data2, encoder2); + spark.createDataset(data2, encoder2); Assert.assertEquals(data2, ds2.collectAsList()); // test (int, ((string, long), string)) @@ -376,7 +386,7 @@ public void testNestedTupleEncoder() { List, String>>> data3 = Arrays.asList(tuple2(1, tuple2(tuple2("a", 2L), "b"))); Dataset, String>>> ds3 = - context.createDataset(data3, encoder3); + spark.createDataset(data3, encoder3); Assert.assertEquals(data3, ds3.collectAsList()); } @@ -390,7 +400,7 @@ public void testPrimitiveEncoder() { 1.7976931348623157E308, new BigDecimal("0.922337203685477589"), Date.valueOf("1970-01-01"), new Timestamp(System.currentTimeMillis()), Float.MAX_VALUE)); Dataset> ds = - context.createDataset(data, encoder); + spark.createDataset(data, encoder); Assert.assertEquals(data, ds.collectAsList()); } @@ -441,7 +451,7 @@ public void testKryoEncoder() { Encoder encoder = Encoders.kryo(KryoSerializable.class); 
List data = Arrays.asList( new KryoSerializable("hello"), new KryoSerializable("world")); - Dataset ds = context.createDataset(data, encoder); + Dataset ds = spark.createDataset(data, encoder); Assert.assertEquals(data, ds.collectAsList()); } @@ -450,14 +460,14 @@ public void testJavaEncoder() { Encoder encoder = Encoders.javaSerialization(JavaSerializable.class); List data = Arrays.asList( new JavaSerializable("hello"), new JavaSerializable("world")); - Dataset ds = context.createDataset(data, encoder); + Dataset ds = spark.createDataset(data, encoder); Assert.assertEquals(data, ds.collectAsList()); } @Test public void testRandomSplit() { List data = Arrays.asList("hello", "world", "from", "spark"); - Dataset ds = context.createDataset(data, Encoders.STRING()); + Dataset ds = spark.createDataset(data, Encoders.STRING()); double[] arraySplit = {1, 2, 3}; List> randomSplit = ds.randomSplitAsList(arraySplit, 1); @@ -647,14 +657,14 @@ public void testJavaBeanEncoder() { obj2.setF(Arrays.asList(300L, null, 400L)); List data = Arrays.asList(obj1, obj2); - Dataset ds = context.createDataset(data, Encoders.bean(SimpleJavaBean.class)); + Dataset ds = spark.createDataset(data, Encoders.bean(SimpleJavaBean.class)); Assert.assertEquals(data, ds.collectAsList()); NestedJavaBean obj3 = new NestedJavaBean(); obj3.setA(obj1); List data2 = Arrays.asList(obj3); - Dataset ds2 = context.createDataset(data2, Encoders.bean(NestedJavaBean.class)); + Dataset ds2 = spark.createDataset(data2, Encoders.bean(NestedJavaBean.class)); Assert.assertEquals(data2, ds2.collectAsList()); Row row1 = new GenericRow(new Object[]{ @@ -678,7 +688,7 @@ public void testJavaBeanEncoder() { .add("d", createArrayType(StringType)) .add("e", createArrayType(StringType)) .add("f", createArrayType(LongType)); - Dataset ds3 = context.createDataFrame(Arrays.asList(row1, row2), schema) + Dataset ds3 = spark.createDataFrame(Arrays.asList(row1, row2), schema) .as(Encoders.bean(SimpleJavaBean.class)); Assert.assertEquals(data, ds3.collectAsList()); } @@ -692,7 +702,7 @@ public void testJavaBeanEncoder2() { obj.setB(new Date(0)); obj.setC(java.math.BigDecimal.valueOf(1)); Dataset ds = - context.createDataset(Arrays.asList(obj), Encoders.bean(SimpleJavaBean2.class)); + spark.createDataset(Arrays.asList(obj), Encoders.bean(SimpleJavaBean2.class)); ds.collect(); } @@ -776,7 +786,7 @@ public void testRuntimeNullabilityCheck() { }) }); - Dataset df = context.createDataFrame(Collections.singletonList(row), schema); + Dataset df = spark.createDataFrame(Collections.singletonList(row), schema); Dataset ds = df.as(Encoders.bean(NestedSmallBean.class)); SmallBean smallBean = new SmallBean(); @@ -793,7 +803,7 @@ public void testRuntimeNullabilityCheck() { { Row row = new GenericRow(new Object[] { null }); - Dataset df = context.createDataFrame(Collections.singletonList(row), schema); + Dataset df = spark.createDataFrame(Collections.singletonList(row), schema); Dataset ds = df.as(Encoders.bean(NestedSmallBean.class)); NestedSmallBean nestedSmallBean = new NestedSmallBean(); @@ -810,7 +820,7 @@ public void testRuntimeNullabilityCheck() { }) }); - Dataset df = context.createDataFrame(Collections.singletonList(row), schema); + Dataset df = spark.createDataFrame(Collections.singletonList(row), schema); Dataset ds = df.as(Encoders.bean(NestedSmallBean.class)); ds.collect(); diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaSaveLoadSuite.java similarity index 76% 
rename from sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java rename to sql/core/src/test/java/test/org/apache/spark/sql/JavaSaveLoadSuite.java index 9e65158eb0a33..6941c86dfcd4b 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaSaveLoadSuite.java @@ -15,18 +15,20 @@ * limitations under the License. */ -package test.org.apache.spark.sql.sources; +package test.org.apache.spark.sql; import java.io.File; import java.io.IOException; -import java.util.*; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import org.junit.After; import org.junit.Assert; import org.junit.Before; import org.junit.Test; -import org.apache.spark.SparkContext; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.*; @@ -37,8 +39,8 @@ public class JavaSaveLoadSuite { - private transient JavaSparkContext sc; - private transient SQLContext sqlContext; + private transient SparkSession spark; + private transient JavaSparkContext jsc; File path; Dataset df; @@ -52,9 +54,11 @@ private static void checkAnswer(Dataset actual, List expected) { @Before public void setUp() throws IOException { - SparkContext _sc = new SparkContext("local[*]", "testing"); - sqlContext = new SQLContext(_sc); - sc = new JavaSparkContext(_sc); + spark = SparkSession.builder() + .master("local[*]") + .appName("testing") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); path = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource").getCanonicalFile(); @@ -66,16 +70,15 @@ public void setUp() throws IOException { for (int i = 0; i < 10; i++) { jsonObjects.add("{\"a\":" + i + ", \"b\":\"str" + i + "\"}"); } - JavaRDD rdd = sc.parallelize(jsonObjects); - df = sqlContext.read().json(rdd); - df.registerTempTable("jsonTable"); + JavaRDD rdd = jsc.parallelize(jsonObjects); + df = spark.read().json(rdd); + df.createOrReplaceTempView("jsonTable"); } @After public void tearDown() { - sqlContext.sparkContext().stop(); - sqlContext = null; - sc = null; + spark.stop(); + spark = null; } @Test @@ -83,7 +86,7 @@ public void saveAndLoad() { Map options = new HashMap<>(); options.put("path", path.toString()); df.write().mode(SaveMode.ErrorIfExists).format("json").options(options).save(); - Dataset loadedDF = sqlContext.read().format("json").options(options).load(); + Dataset loadedDF = spark.read().format("json").options(options).load(); checkAnswer(loadedDF, df.collectAsList()); } @@ -96,8 +99,8 @@ public void saveAndLoadWithSchema() { List fields = new ArrayList<>(); fields.add(DataTypes.createStructField("b", DataTypes.StringType, true)); StructType schema = DataTypes.createStructType(fields); - Dataset loadedDF = sqlContext.read().format("json").schema(schema).options(options).load(); + Dataset loadedDF = spark.read().format("json").schema(schema).options(options).load(); - checkAnswer(loadedDF, sqlContext.sql("SELECT b FROM jsonTable").collectAsList()); + checkAnswer(loadedDF, spark.sql("SELECT b FROM jsonTable").collectAsList()); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java index 4a78dca7fea66..2274912521a56 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java @@ -24,33 +24,30 
@@ import org.junit.Before; import org.junit.Test; -import org.apache.spark.SparkContext; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.api.java.UDF1; import org.apache.spark.sql.api.java.UDF2; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.types.DataTypes; // The test suite itself is Serializable so that anonymous Function implementations can be // serialized, as an alternative to converting these anonymous classes to static inner classes; // see http://stackoverflow.com/questions/758570/. public class JavaUDFSuite implements Serializable { - private transient JavaSparkContext sc; - private transient SQLContext sqlContext; + private transient SparkSession spark; @Before public void setUp() { - SparkContext _sc = new SparkContext("local[*]", "testing"); - sqlContext = new SQLContext(_sc); - sc = new JavaSparkContext(_sc); + spark = SparkSession.builder() + .master("local[*]") + .appName("testing") + .getOrCreate(); } @After public void tearDown() { - sqlContext.sparkContext().stop(); - sqlContext = null; - sc = null; + spark.stop(); + spark = null; } @SuppressWarnings("unchecked") @@ -60,14 +57,14 @@ public void udf1Test() { // sqlContext.registerFunction( // "stringLengthTest", (String str) -> str.length(), DataType.IntegerType); - sqlContext.udf().register("stringLengthTest", new UDF1() { + spark.udf().register("stringLengthTest", new UDF1() { @Override public Integer call(String str) { return str.length(); } }, DataTypes.IntegerType); - Row result = sqlContext.sql("SELECT stringLengthTest('test')").head(); + Row result = spark.sql("SELECT stringLengthTest('test')").head(); Assert.assertEquals(4, result.getInt(0)); } @@ -80,14 +77,14 @@ public void udf2Test() { // (String str1, String str2) -> str1.length() + str2.length, // DataType.IntegerType); - sqlContext.udf().register("stringLengthTest", new UDF2() { + spark.udf().register("stringLengthTest", new UDF2() { @Override public Integer call(String str1, String str2) { return str1.length() + str2.length(); } }, DataTypes.IntegerType); - Row result = sqlContext.sql("SELECT stringLengthTest('test', 'test2')").head(); + Row result = spark.sql("SELECT stringLengthTest('test', 'test2')").head(); Assert.assertEquals(9, result.getInt(0)); } } diff --git a/sql/core/src/test/resources/hive-site.xml b/sql/core/src/test/resources/hive-site.xml new file mode 100644 index 0000000000000..17297b3e22a7e --- /dev/null +++ b/sql/core/src/test/resources/hive-site.xml @@ -0,0 +1,26 @@ + + + + + + + hive.in.test + true + Internal marker for test. 
+ + diff --git a/sql/core/src/test/resources/log4j.properties b/sql/core/src/test/resources/log4j.properties index e53cb1f4e681d..33b9ecf1e2826 100644 --- a/sql/core/src/test/resources/log4j.properties +++ b/sql/core/src/test/resources/log4j.properties @@ -16,7 +16,7 @@ # # Set everything to be logged to the file core/target/unit-tests.log -log4j.rootLogger=DEBUG, CA, FA +log4j.rootLogger=INFO, CA, FA #Console Appender log4j.appender.CA=org.apache.log4j.ConsoleAppender diff --git a/sql/core/src/test/resources/old-repeated.parquet b/sql/core/src/test/resources/old-repeated.parquet deleted file mode 100644 index 213f1a90291b3..0000000000000 Binary files a/sql/core/src/test/resources/old-repeated.parquet and /dev/null differ diff --git a/sql/core/src/test/resources/sql-tests/inputs/arithmetic.sql b/sql/core/src/test/resources/sql-tests/inputs/arithmetic.sql new file mode 100644 index 0000000000000..f62b10ca0037b --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/arithmetic.sql @@ -0,0 +1,34 @@ + +-- unary minus and plus +select -100; +select +230; +select -5.2; +select +6.8e0; +select -key, +key from testdata where key = 2; +select -(key + 1), - key + 1, +(key + 5) from testdata where key = 1; +select -max(key), +max(key) from testdata; +select - (-10); +select + (-key) from testdata where key = 32; +select - (+max(key)) from testdata; +select - - 3; +select - + 20; +select + + 100; +select - - max(key) from testdata; +select + - key from testdata where key = 33; + +-- div +select 5 / 2; +select 5 / 0; +select 5 / null; +select null / 5; +select 5 div 2; +select 5 div 0; +select 5 div null; +select null div 5; + +-- other arithmetics +select 1 + 2; +select 1 - 2; +select 2 * 5; +select 5 % 3; +select pmod(-7, 3); diff --git a/sql/core/src/test/resources/sql-tests/inputs/array.sql b/sql/core/src/test/resources/sql-tests/inputs/array.sql new file mode 100644 index 0000000000000..984321ab795fc --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/array.sql @@ -0,0 +1,92 @@ +-- test cases for array functions + +create temporary view data as select * from values + ("one", array(11, 12, 13), array(array(111, 112, 113), array(121, 122, 123))), + ("two", array(21, 22, 23), array(array(211, 212, 213), array(221, 222, 223))) + as data(a, b, c); + +select * from data; + +-- index into array +select a, b[0], b[0] + b[1] from data; + +-- index into array of arrays +select a, c[0][0] + c[0][0 + 1] from data; + + +create temporary view primitive_arrays as select * from values ( + array(true), + array(2Y, 1Y), + array(2S, 1S), + array(2, 1), + array(2L, 1L), + array(9223372036854775809, 9223372036854775808), + array(2.0D, 1.0D), + array(float(2.0), float(1.0)), + array(date '2016-03-14', date '2016-03-13'), + array(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000') +) as primitive_arrays( + boolean_array, + tinyint_array, + smallint_array, + int_array, + bigint_array, + decimal_array, + double_array, + float_array, + date_array, + timestamp_array +); + +select * from primitive_arrays; + +-- array_contains on all primitive types: result should alternate between true and false +select + array_contains(boolean_array, true), array_contains(boolean_array, false), + array_contains(tinyint_array, 2Y), array_contains(tinyint_array, 0Y), + array_contains(smallint_array, 2S), array_contains(smallint_array, 0S), + array_contains(int_array, 2), array_contains(int_array, 0), + array_contains(bigint_array, 2L), array_contains(bigint_array, 0L), + 
array_contains(decimal_array, 9223372036854775809), array_contains(decimal_array, 1), + array_contains(double_array, 2.0D), array_contains(double_array, 0.0D), + array_contains(float_array, float(2.0)), array_contains(float_array, float(0.0)), + array_contains(date_array, date '2016-03-14'), array_contains(date_array, date '2016-01-01'), + array_contains(timestamp_array, timestamp '2016-11-15 20:54:00.000'), array_contains(timestamp_array, timestamp '2016-01-01 20:54:00.000') +from primitive_arrays; + +-- array_contains on nested arrays +select array_contains(b, 11), array_contains(c, array(111, 112, 113)) from data; + +-- sort_array +select + sort_array(boolean_array), + sort_array(tinyint_array), + sort_array(smallint_array), + sort_array(int_array), + sort_array(bigint_array), + sort_array(decimal_array), + sort_array(double_array), + sort_array(float_array), + sort_array(date_array), + sort_array(timestamp_array) +from primitive_arrays; + +-- sort_array with an invalid string literal for the argument of sort order. +select sort_array(array('b', 'd'), '1'); + +-- sort_array with an invalid null literal casted as boolean for the argument of sort order. +select sort_array(array('b', 'd'), cast(NULL as boolean)); + +-- size +select + size(boolean_array), + size(tinyint_array), + size(smallint_array), + size(int_array), + size(bigint_array), + size(decimal_array), + size(double_array), + size(float_array), + size(date_array), + size(timestamp_array) +from primitive_arrays; diff --git a/sql/core/src/test/resources/sql-tests/inputs/blacklist.sql b/sql/core/src/test/resources/sql-tests/inputs/blacklist.sql new file mode 100644 index 0000000000000..d69f8147a5264 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/blacklist.sql @@ -0,0 +1,4 @@ +-- This is a query file that has been blacklisted. +-- It includes a query that should crash Spark. +-- If the test case is run, the whole suite would fail. +some random not working query that should crash Spark. diff --git a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql new file mode 100644 index 0000000000000..3fd1c37e71795 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql @@ -0,0 +1,4 @@ +-- date time functions + +-- [SPARK-16836] current_date and current_timestamp literals +select current_date = current_date(), current_timestamp = current_timestamp(); diff --git a/sql/core/src/test/resources/sql-tests/inputs/describe.sql b/sql/core/src/test/resources/sql-tests/inputs/describe.sql new file mode 100644 index 0000000000000..2e916d9edc73b --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/describe.sql @@ -0,0 +1,27 @@ +CREATE TABLE t (a STRING, b INT) PARTITIONED BY (c STRING, d STRING); + +ALTER TABLE t ADD PARTITION (c='Us', d=1); + +DESC t; + +-- Ignore these because there exist timestamp results, e.g., `Create Table`. +-- DESC EXTENDED t; +-- DESC FORMATTED t; + +DESC t PARTITION (c='Us', d=1); + +-- Ignore these because there exist timestamp results, e.g., transient_lastDdlTime. 
+-- DESC EXTENDED t PARTITION (c='Us', d=1); +-- DESC FORMATTED t PARTITION (c='Us', d=1); + +-- NoSuchPartitionException: Partition not found in table +DESC t PARTITION (c='Us', d=2); + +-- AnalysisException: Partition spec is invalid +DESC t PARTITION (c='Us'); + +-- ParseException: PARTITION specification is incomplete +DESC t PARTITION (c='Us', d); + +-- DROP TEST TABLE +DROP TABLE t; \ No newline at end of file diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql new file mode 100644 index 0000000000000..36b469c61788c --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql @@ -0,0 +1,50 @@ +-- group by ordinal positions + +create temporary view data as select * from values + (1, 1), + (1, 2), + (2, 1), + (2, 2), + (3, 1), + (3, 2) + as data(a, b); + +-- basic case +select a, sum(b) from data group by 1; + +-- constant case +select 1, 2, sum(b) from data group by 1, 2; + +-- duplicate group by column +select a, 1, sum(b) from data group by a, 1; +select a, 1, sum(b) from data group by 1, 2; + +-- group by a non-aggregate expression's ordinal +select a, b + 2, count(2) from data group by a, 2; + +-- with alias +select a as aa, b + 2 as bb, count(2) from data group by 1, 2; + +-- foldable non-literal: this should be the same as no grouping. +select sum(b) from data group by 1 + 0; + +-- negative cases: ordinal out of range +select a, b from data group by -1; +select a, b from data group by 0; +select a, b from data group by 3; + +-- negative case: position is an aggregate expression +select a, b, sum(b) from data group by 3; +select a, b, sum(b) + 2 from data group by 3; + +-- negative case: nondeterministic expression +select a, rand(0), sum(b) from data group by a, 2; + +-- negative case: star +select * from data group by a, b, 1; + +-- turn of group by ordinal +set spark.sql.groupByOrdinal=false; + +-- can now group by negative literal +select sum(b) from data group by -1; diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql new file mode 100644 index 0000000000000..6741703d9d82c --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -0,0 +1,17 @@ +-- Temporary data. +create temporary view myview as values 128, 256 as v(int_col); + +-- group by should produce all input rows, +select int_col, count(*) from myview group by int_col; + +-- group by should produce a single row. +select 'foo', count(*) from myview group by 1; + +-- group-by should not produce any rows (whole stage code generation). +select 'foo' from myview where int_col == 0 group by 1; + +-- group-by should not produce any rows (hash aggregate). +select 'foo', approx_count_distinct(int_col) from myview where int_col == 0 group by 1; + +-- group-by should not produce any rows (sort aggregate). 
+select 'foo', max(struct(int_col)) from myview where int_col == 0 group by 1; diff --git a/sql/core/src/test/resources/sql-tests/inputs/having.sql b/sql/core/src/test/resources/sql-tests/inputs/having.sql new file mode 100644 index 0000000000000..364c022d959dc --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/having.sql @@ -0,0 +1,15 @@ +create temporary view hav as select * from values + ("one", 1), + ("two", 2), + ("three", 3), + ("one", 5) + as hav(k, v); + +-- having clause +SELECT k, sum(v) FROM hav GROUP BY k HAVING sum(v) > 2; + +-- having condition contains grouping column +SELECT count(k) FROM hav GROUP BY v + 1 HAVING v + 1 = 2; + +-- SPARK-11032: resolve having correctly +SELECT MIN(t.v) FROM (SELECT * FROM hav WHERE v > 0) t HAVING(COUNT(1) > 0); diff --git a/sql/core/src/test/resources/sql-tests/inputs/inline-table.sql b/sql/core/src/test/resources/sql-tests/inputs/inline-table.sql new file mode 100644 index 0000000000000..5107fa4d55537 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/inline-table.sql @@ -0,0 +1,48 @@ + +-- single row, without table and column alias +select * from values ("one", 1); + +-- single row, without column alias +select * from values ("one", 1) as data; + +-- single row +select * from values ("one", 1) as data(a, b); + +-- single column multiple rows +select * from values 1, 2, 3 as data(a); + +-- three rows +select * from values ("one", 1), ("two", 2), ("three", null) as data(a, b); + +-- null type +select * from values ("one", null), ("two", null) as data(a, b); + +-- int and long coercion +select * from values ("one", 1), ("two", 2L) as data(a, b); + +-- foldable expressions +select * from values ("one", 1 + 0), ("two", 1 + 3L) as data(a, b); + +-- complex types +select * from values ("one", array(0, 1)), ("two", array(2, 3)) as data(a, b); + +-- decimal and double coercion +select * from values ("one", 2.0), ("two", 3.0D) as data(a, b); + +-- error reporting: nondeterministic function rand +select * from values ("one", rand(5)), ("two", 3.0D) as data(a, b); + +-- error reporting: different number of columns +select * from values ("one", 2.0), ("two") as data(a, b); + +-- error reporting: types that are incompatible +select * from values ("one", array(0, 1)), ("two", struct(1, 2)) as data(a, b); + +-- error reporting: number aliases different from number data values +select * from values ("one"), ("two") as data(a, b); + +-- error reporting: unresolved expression +select * from values ("one", random_not_exist_func(1)), ("two", 2) as data(a, b); + +-- error reporting: aggregate expression +select * from values ("one", count(1)), ("two", 2) as data(a, b); diff --git a/sql/core/src/test/resources/sql-tests/inputs/limit.sql b/sql/core/src/test/resources/sql-tests/inputs/limit.sql new file mode 100644 index 0000000000000..2ea35f7f3a5c8 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/limit.sql @@ -0,0 +1,23 @@ + +-- limit on various data types +select * from testdata limit 2; +select * from arraydata limit 2; +select * from mapdata limit 2; + +-- foldable non-literal in limit +select * from testdata limit 2 + 1; + +select * from testdata limit CAST(1 AS int); + +-- limit must be non-negative +select * from testdata limit -1; + +-- limit must be foldable +select * from testdata limit key > 3; + +-- limit must be integer +select * from testdata limit true; +select * from testdata limit 'a'; + +-- limit within a subquery +select * from (select * from range(10) limit 5) where id > 3; diff --git 
a/sql/core/src/test/resources/sql-tests/inputs/literals.sql b/sql/core/src/test/resources/sql-tests/inputs/literals.sql new file mode 100644 index 0000000000000..a532a598c6bf9 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/literals.sql @@ -0,0 +1,98 @@ +-- Literal parsing + +-- null +select null, Null, nUll; + +-- boolean +select true, tRue, false, fALse; + +-- byte (tinyint) +select 1Y; +select 127Y, -128Y; + +-- out of range byte +select 128Y; + +-- short (smallint) +select 1S; +select 32767S, -32768S; + +-- out of range short +select 32768S; + +-- long (bigint) +select 1L, 2147483648L; +select 9223372036854775807L, -9223372036854775808L; + +-- out of range long +select 9223372036854775808L; + +-- integral parsing + +-- parse int +select 1, -1; + +-- parse int max and min value as int +select 2147483647, -2147483648; + +-- parse long max and min value as long +select 9223372036854775807, -9223372036854775808; + +-- parse as decimals (Long.MaxValue + 1, and Long.MinValue - 1) +select 9223372036854775808, -9223372036854775809; + +-- out of range decimal numbers +select 1234567890123456789012345678901234567890; +select 1234567890123456789012345678901234567890.0; + +-- double +select 1D, 1.2D, 1e10, 1.5e5, .10D, 0.10D, .1e5, .9e+2, 0.9e+2, 900e-1, 9.e+1; +select -1D, -1.2D, -1e10, -1.5e5, -.10D, -0.10D, -.1e5; +-- negative double +select .e3; +-- inf and -inf +select 1E309, -1E309; + +-- decimal parsing +select 0.3, -0.8, .5, -.18, 0.1111, .1111; + +-- super large scientific notation numbers should still be valid doubles +select 123456789012345678901234567890123456789e10, 123456789012345678901234567890123456789.1e10; + +-- string +select "Hello Peter!", 'hello lee!'; +-- multi string +select 'hello' 'world', 'hello' " " 'lee'; +-- single quote within double quotes +select "hello 'peter'"; +select 'pattern%', 'no-pattern\%', 'pattern\\%', 'pattern\\\%'; +select '\'', '"', '\n', '\r', '\t', 'Z'; +-- "Hello!" 
in octals +select '\110\145\154\154\157\041'; +-- "World :)" in unicode +select '\u0057\u006F\u0072\u006C\u0064\u0020\u003A\u0029'; + +-- date +select dAte '2016-03-12'; +-- invalid date +select date 'mar 11 2016'; + +-- timestamp +select tImEstAmp '2016-03-11 20:54:00.000'; +-- invalid timestamp +select timestamp '2016-33-11 20:54:00.000'; + +-- interval +select interval 13.123456789 seconds, interval -13.123456789 second; +select interval 1 year 2 month 3 week 4 day 5 hour 6 minute 7 seconds 8 millisecond, 9 microsecond; +-- ns is not supported +select interval 10 nanoseconds; + +-- unsupported data type +select GEO '(10,-6)'; + +-- big decimal parsing +select 90912830918230182310293801923652346786BD, 123.0E-28BD, 123.08BD; + +-- out of range big decimal +select 1.20E-38BD; diff --git a/sql/core/src/test/resources/sql-tests/inputs/natural-join.sql b/sql/core/src/test/resources/sql-tests/inputs/natural-join.sql new file mode 100644 index 0000000000000..71a50157b766c --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/natural-join.sql @@ -0,0 +1,20 @@ +create temporary view nt1 as select * from values + ("one", 1), + ("two", 2), + ("three", 3) + as nt1(k, v1); + +create temporary view nt2 as select * from values + ("one", 1), + ("two", 22), + ("one", 5) + as nt2(k, v2); + + +SELECT * FROM nt1 natural join nt2 where k = "one"; + +SELECT * FROM nt1 natural left join nt2 order by v1, v2; + +SELECT * FROM nt1 natural right join nt2 order by v1, v2; + +SELECT count(*) FROM nt1 natural full outer join nt2; diff --git a/sql/core/src/test/resources/sql-tests/inputs/null-propagation.sql b/sql/core/src/test/resources/sql-tests/inputs/null-propagation.sql new file mode 100644 index 0000000000000..66549da7971d3 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/null-propagation.sql @@ -0,0 +1,9 @@ + +-- count(null) should be 0 +SELECT COUNT(NULL) FROM VALUES 1, 2, 3; +SELECT COUNT(1 + NULL) FROM VALUES 1, 2, 3; + +-- count(null) on window should be 0 +SELECT COUNT(NULL) OVER () FROM VALUES 1, 2, 3; +SELECT COUNT(1 + NULL) OVER () FROM VALUES 1, 2, 3; + diff --git a/sql/core/src/test/resources/sql-tests/inputs/order-by-ordinal.sql b/sql/core/src/test/resources/sql-tests/inputs/order-by-ordinal.sql new file mode 100644 index 0000000000000..8d733e77fa8d3 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/order-by-ordinal.sql @@ -0,0 +1,36 @@ +-- order by and sort by ordinal positions + +create temporary view data as select * from values + (1, 1), + (1, 2), + (2, 1), + (2, 2), + (3, 1), + (3, 2) + as data(a, b); + +select * from data order by 1 desc; + +-- mix ordinal and column name +select * from data order by 1 desc, b desc; + +-- order by multiple ordinals +select * from data order by 1 desc, 2 desc; + +-- 1 + 0 is considered a constant (not an ordinal) and thus ignored +select * from data order by 1 + 0 desc, b desc; + +-- negative cases: ordinal position out of range +select * from data order by 0; +select * from data order by -1; +select * from data order by 3; + +-- sort by ordinal +select * from data sort by 1 desc; + +-- turn off order by ordinal +set spark.sql.orderByOrdinal=false; + +-- 0 is now a valid literal +select * from data order by 0; +select * from data sort by 0; diff --git a/sql/core/src/test/resources/sql-tests/inputs/outer-join.sql b/sql/core/src/test/resources/sql-tests/inputs/outer-join.sql new file mode 100644 index 0000000000000..f50f1ebad970e --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/outer-join.sql @@ -0,0 +1,36 
@@ +-- SPARK-17099: Incorrect result when HAVING clause is added to group by query +CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES +(-234), (145), (367), (975), (298) +as t1(int_col1); + +CREATE OR REPLACE TEMPORARY VIEW t2 AS SELECT * FROM VALUES +(-769, -244), (-800, -409), (940, 86), (-507, 304), (-367, 158) +as t2(int_col0, int_col1); + +SELECT + (SUM(COALESCE(t1.int_col1, t2.int_col0))), + ((COALESCE(t1.int_col1, t2.int_col0)) * 2) +FROM t1 +RIGHT JOIN t2 + ON (t2.int_col0) = (t1.int_col1) +GROUP BY GREATEST(COALESCE(t2.int_col1, 109), COALESCE(t1.int_col1, -449)), + COALESCE(t1.int_col1, t2.int_col0) +HAVING (SUM(COALESCE(t1.int_col1, t2.int_col0))) + > ((COALESCE(t1.int_col1, t2.int_col0)) * 2); + + +-- SPARK-17120: Analyzer incorrectly optimizes plan to empty LocalRelation +CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES (97) as t1(int_col1); + +CREATE OR REPLACE TEMPORARY VIEW t2 AS SELECT * FROM VALUES (0) as t2(int_col1); + +SELECT * +FROM ( +SELECT + COALESCE(t2.int_col1, t1.int_col1) AS int_col + FROM t1 + LEFT JOIN t2 ON false +) t where (t.int_col) is not null; + + + diff --git a/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql new file mode 100644 index 0000000000000..2e6dcd538b7ac --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql @@ -0,0 +1,20 @@ +-- unresolved function +select * from dummy(3); + +-- range call with end +select * from range(6 + cos(3)); + +-- range call with start and end +select * from range(5, 10); + +-- range call with step +select * from range(0, 10, 2); + +-- range call with numPartitions +select * from range(0, 10, 1, 200); + +-- range call error +select * from range(1, 1, 1, 1, 1); + +-- range call with null +select * from range(1, null); diff --git a/sql/core/src/test/resources/sql-tests/results/arithmetic.sql.out b/sql/core/src/test/resources/sql-tests/results/arithmetic.sql.out new file mode 100644 index 0000000000000..6abe048af477d --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/arithmetic.sql.out @@ -0,0 +1,226 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 28 + + +-- !query 0 +select -100 +-- !query 0 schema +struct<-100:int> +-- !query 0 output +-100 + + +-- !query 1 +select +230 +-- !query 1 schema +struct<230:int> +-- !query 1 output +230 + + +-- !query 2 +select -5.2 +-- !query 2 schema +struct<-5.2:decimal(2,1)> +-- !query 2 output +-5.2 + + +-- !query 3 +select +6.8e0 +-- !query 3 schema +struct<6.8:double> +-- !query 3 output +6.8 + + +-- !query 4 +select -key, +key from testdata where key = 2 +-- !query 4 schema +struct<(- key):int,key:int> +-- !query 4 output +-2 2 + + +-- !query 5 +select -(key + 1), - key + 1, +(key + 5) from testdata where key = 1 +-- !query 5 schema +struct<(- (key + 1)):int,((- key) + 1):int,(key + 5):int> +-- !query 5 output +-2 0 6 + + +-- !query 6 +select -max(key), +max(key) from testdata +-- !query 6 schema +struct<(- max(key)):int,max(key):int> +-- !query 6 output +-100 100 + + +-- !query 7 +select - (-10) +-- !query 7 schema +struct<(- -10):int> +-- !query 7 output +10 + + +-- !query 8 +select + (-key) from testdata where key = 32 +-- !query 8 schema +struct<(- key):int> +-- !query 8 output +-32 + + +-- !query 9 +select - (+max(key)) from testdata +-- !query 9 schema +struct<(- max(key)):int> +-- !query 9 output +-100 + + +-- !query 10 +select - - 3 +-- !query 10 schema +struct<(- -3):int> +-- !query 
10 output +3 + + +-- !query 11 +select - + 20 +-- !query 11 schema +struct<(- 20):int> +-- !query 11 output +-20 + + +-- !query 12 +select + + 100 +-- !query 12 schema +struct<100:int> +-- !query 12 output +100 + + +-- !query 13 +select - - max(key) from testdata +-- !query 13 schema +struct<(- (- max(key))):int> +-- !query 13 output +100 + + +-- !query 14 +select + - key from testdata where key = 33 +-- !query 14 schema +struct<(- key):int> +-- !query 14 output +-33 + + +-- !query 15 +select 5 / 2 +-- !query 15 schema +struct<(CAST(5 AS DOUBLE) / CAST(2 AS DOUBLE)):double> +-- !query 15 output +2.5 + + +-- !query 16 +select 5 / 0 +-- !query 16 schema +struct<(CAST(5 AS DOUBLE) / CAST(0 AS DOUBLE)):double> +-- !query 16 output +NULL + + +-- !query 17 +select 5 / null +-- !query 17 schema +struct<(CAST(5 AS DOUBLE) / CAST(NULL AS DOUBLE)):double> +-- !query 17 output +NULL + + +-- !query 18 +select null / 5 +-- !query 18 schema +struct<(CAST(NULL AS DOUBLE) / CAST(5 AS DOUBLE)):double> +-- !query 18 output +NULL + + +-- !query 19 +select 5 div 2 +-- !query 19 schema +struct +-- !query 19 output +2 + + +-- !query 20 +select 5 div 0 +-- !query 20 schema +struct +-- !query 20 output +NULL + + +-- !query 21 +select 5 div null +-- !query 21 schema +struct +-- !query 21 output +NULL + + +-- !query 22 +select null div 5 +-- !query 22 schema +struct +-- !query 22 output +NULL + + +-- !query 23 +select 1 + 2 +-- !query 23 schema +struct<(1 + 2):int> +-- !query 23 output +3 + + +-- !query 24 +select 1 - 2 +-- !query 24 schema +struct<(1 - 2):int> +-- !query 24 output +-1 + + +-- !query 25 +select 2 * 5 +-- !query 25 schema +struct<(2 * 5):int> +-- !query 25 output +10 + + +-- !query 26 +select 5 % 3 +-- !query 26 schema +struct<(5 % 3):int> +-- !query 26 output +2 + + +-- !query 27 +select pmod(-7, 3) +-- !query 27 schema +struct +-- !query 27 output +2 diff --git a/sql/core/src/test/resources/sql-tests/results/array.sql.out b/sql/core/src/test/resources/sql-tests/results/array.sql.out new file mode 100644 index 0000000000000..499a3d5fb72f6 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/array.sql.out @@ -0,0 +1,159 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 10 + + +-- !query 0 +create temporary view data as select * from values + ("one", array(11, 12, 13), array(array(111, 112, 113), array(121, 122, 123))), + ("two", array(21, 22, 23), array(array(211, 212, 213), array(221, 222, 223))) + as data(a, b, c) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +select * from data +-- !query 1 schema +struct,c:array>> +-- !query 1 output +one [11,12,13] [[111,112,113],[121,122,123]] +two [21,22,23] [[211,212,213],[221,222,223]] + + +-- !query 2 +select a, b[0], b[0] + b[1] from data +-- !query 2 schema +struct +-- !query 2 output +one 11 23 +two 21 43 + + +-- !query 3 +select a, c[0][0] + c[0][0 + 1] from data +-- !query 3 schema +struct +-- !query 3 output +one 223 +two 423 + + +-- !query 4 +create temporary view primitive_arrays as select * from values ( + array(true), + array(2Y, 1Y), + array(2S, 1S), + array(2, 1), + array(2L, 1L), + array(9223372036854775809, 9223372036854775808), + array(2.0D, 1.0D), + array(float(2.0), float(1.0)), + array(date '2016-03-14', date '2016-03-13'), + array(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000') +) as primitive_arrays( + boolean_array, + tinyint_array, + smallint_array, + int_array, + bigint_array, + decimal_array, + double_array, + float_array, + date_array, + 
timestamp_array +) +-- !query 4 schema +struct<> +-- !query 4 output + + + +-- !query 5 +select * from primitive_arrays +-- !query 5 schema +struct,tinyint_array:array,smallint_array:array,int_array:array,bigint_array:array,decimal_array:array,double_array:array,float_array:array,date_array:array,timestamp_array:array> +-- !query 5 output +[true] [2,1] [2,1] [2,1] [2,1] [9223372036854775809,9223372036854775808] [2.0,1.0] [2.0,1.0] [2016-03-14,2016-03-13] [2016-11-15 20:54:00.0,2016-11-12 20:54:00.0] + + +-- !query 6 +select + array_contains(boolean_array, true), array_contains(boolean_array, false), + array_contains(tinyint_array, 2Y), array_contains(tinyint_array, 0Y), + array_contains(smallint_array, 2S), array_contains(smallint_array, 0S), + array_contains(int_array, 2), array_contains(int_array, 0), + array_contains(bigint_array, 2L), array_contains(bigint_array, 0L), + array_contains(decimal_array, 9223372036854775809), array_contains(decimal_array, 1), + array_contains(double_array, 2.0D), array_contains(double_array, 0.0D), + array_contains(float_array, float(2.0)), array_contains(float_array, float(0.0)), + array_contains(date_array, date '2016-03-14'), array_contains(date_array, date '2016-01-01'), + array_contains(timestamp_array, timestamp '2016-11-15 20:54:00.000'), array_contains(timestamp_array, timestamp '2016-01-01 20:54:00.000') +from primitive_arrays +-- !query 6 schema +struct +-- !query 6 output +true false true false true false true false true false true false true false true false true false true false + + +-- !query 7 +select array_contains(b, 11), array_contains(c, array(111, 112, 113)) from data +-- !query 7 schema +struct +-- !query 7 output +false false +true true + + +-- !query 8 +select + sort_array(boolean_array), + sort_array(tinyint_array), + sort_array(smallint_array), + sort_array(int_array), + sort_array(bigint_array), + sort_array(decimal_array), + sort_array(double_array), + sort_array(float_array), + sort_array(date_array), + sort_array(timestamp_array) +from primitive_arrays +-- !query 8 schema +struct,sort_array(tinyint_array, true):array,sort_array(smallint_array, true):array,sort_array(int_array, true):array,sort_array(bigint_array, true):array,sort_array(decimal_array, true):array,sort_array(double_array, true):array,sort_array(float_array, true):array,sort_array(date_array, true):array,sort_array(timestamp_array, true):array> +-- !query 8 output +[true] [1,2] [1,2] [1,2] [1,2] [9223372036854775808,9223372036854775809] [1.0,2.0] [1.0,2.0] [2016-03-13,2016-03-14] [2016-11-12 20:54:00.0,2016-11-15 20:54:00.0] + +-- !query 9 +select sort_array(array('b', 'd'), '1') +-- !query 9 schema +struct<> +-- !query 9 output +org.apache.spark.sql.AnalysisException +cannot resolve 'sort_array(array('b', 'd'), '1')' due to data type mismatch: Sort order in second argument requires a boolean literal.; line 1 pos 7 + +-- !query 10 +select sort_array(array('b', 'd'), cast(NULL as boolean)) +-- !query 10 schema +struct<> +-- !query 10 output +org.apache.spark.sql.AnalysisException +cannot resolve 'sort_array(array('b', 'd'), CAST(NULL AS BOOLEAN))' due to data type mismatch: Sort order in second argument requires a boolean literal.; line 1 pos 7 + +-- !query 11 +select + size(boolean_array), + size(tinyint_array), + size(smallint_array), + size(int_array), + size(bigint_array), + size(decimal_array), + size(double_array), + size(float_array), + size(date_array), + size(timestamp_array) +from primitive_arrays +-- !query 11 schema +struct +-- !query 11 output +1 2 2 2 
2 2 2 2 2 2 diff --git a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out new file mode 100644 index 0000000000000..032e4258500fb --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out @@ -0,0 +1,10 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 1 + + +-- !query 0 +select current_date = current_date(), current_timestamp = current_timestamp() +-- !query 0 schema +struct<(current_date() = current_date()):boolean,(current_timestamp() = current_timestamp()):boolean> +-- !query 0 output +true true diff --git a/sql/core/src/test/resources/sql-tests/results/describe.sql.out b/sql/core/src/test/resources/sql-tests/results/describe.sql.out new file mode 100644 index 0000000000000..37bf303f1bfe4 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/describe.sql.out @@ -0,0 +1,90 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 8 + + +-- !query 0 +CREATE TABLE t (a STRING, b INT) PARTITIONED BY (c STRING, d STRING) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +ALTER TABLE t ADD PARTITION (c='Us', d=1) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +DESC t +-- !query 2 schema +struct +-- !query 2 output +# Partition Information +# col_name data_type comment +a string +b int +c string +c string +d string +d string + + +-- !query 3 +DESC t PARTITION (c='Us', d=1) +-- !query 3 schema +struct +-- !query 3 output +# Partition Information +# col_name data_type comment +a string +b int +c string +c string +d string +d string + + +-- !query 4 +DESC t PARTITION (c='Us', d=2) +-- !query 4 schema +struct<> +-- !query 4 output +org.apache.spark.sql.catalyst.analysis.NoSuchPartitionException +Partition not found in table 't' database 'default': +c -> Us +d -> 2; + + +-- !query 5 +DESC t PARTITION (c='Us') +-- !query 5 schema +struct<> +-- !query 5 output +org.apache.spark.sql.AnalysisException +Partition spec is invalid. 
The spec (c) must match the partition spec (c, d) defined in table '`default`.`t`'; + + +-- !query 6 +DESC t PARTITION (c='Us', d) +-- !query 6 schema +struct<> +-- !query 6 output +org.apache.spark.sql.catalyst.parser.ParseException + +PARTITION specification is incomplete: `d`(line 1, pos 0) + +== SQL == +DESC t PARTITION (c='Us', d) +^^^ + + +-- !query 7 +DROP TABLE t +-- !query 7 schema +struct<> +-- !query 7 output + diff --git a/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out new file mode 100644 index 0000000000000..2f10b7ebc6d32 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out @@ -0,0 +1,168 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 17 + + +-- !query 0 +create temporary view data as select * from values + (1, 1), + (1, 2), + (2, 1), + (2, 2), + (3, 1), + (3, 2) + as data(a, b) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +select a, sum(b) from data group by 1 +-- !query 1 schema +struct +-- !query 1 output +1 3 +2 3 +3 3 + + +-- !query 2 +select 1, 2, sum(b) from data group by 1, 2 +-- !query 2 schema +struct<1:int,2:int,sum(b):bigint> +-- !query 2 output +1 2 9 + + +-- !query 3 +select a, 1, sum(b) from data group by a, 1 +-- !query 3 schema +struct +-- !query 3 output +1 1 3 +2 1 3 +3 1 3 + + +-- !query 4 +select a, 1, sum(b) from data group by 1, 2 +-- !query 4 schema +struct +-- !query 4 output +1 1 3 +2 1 3 +3 1 3 + + +-- !query 5 +select a, b + 2, count(2) from data group by a, 2 +-- !query 5 schema +struct +-- !query 5 output +1 3 1 +1 4 1 +2 3 1 +2 4 1 +3 3 1 +3 4 1 + + +-- !query 6 +select a as aa, b + 2 as bb, count(2) from data group by 1, 2 +-- !query 6 schema +struct +-- !query 6 output +1 3 1 +1 4 1 +2 3 1 +2 4 1 +3 3 1 +3 4 1 + + +-- !query 7 +select sum(b) from data group by 1 + 0 +-- !query 7 schema +struct +-- !query 7 output +9 + + +-- !query 8 +select a, b from data group by -1 +-- !query 8 schema +struct<> +-- !query 8 output +org.apache.spark.sql.AnalysisException +GROUP BY position -1 is not in select list (valid range is [1, 2]); line 1 pos 31 + + +-- !query 9 +select a, b from data group by 0 +-- !query 9 schema +struct<> +-- !query 9 output +org.apache.spark.sql.AnalysisException +GROUP BY position 0 is not in select list (valid range is [1, 2]); line 1 pos 31 + + +-- !query 10 +select a, b from data group by 3 +-- !query 10 schema +struct<> +-- !query 10 output +org.apache.spark.sql.AnalysisException +GROUP BY position 3 is not in select list (valid range is [1, 2]); line 1 pos 31 + + +-- !query 11 +select a, b, sum(b) from data group by 3 +-- !query 11 schema +struct<> +-- !query 11 output +org.apache.spark.sql.AnalysisException +GROUP BY position 3 is an aggregate function, and aggregate functions are not allowed in GROUP BY; line 1 pos 39 + + +-- !query 12 +select a, b, sum(b) + 2 from data group by 3 +-- !query 12 schema +struct<> +-- !query 12 output +org.apache.spark.sql.AnalysisException +GROUP BY position 3 is an aggregate function, and aggregate functions are not allowed in GROUP BY; line 1 pos 43 + + +-- !query 13 +select a, rand(0), sum(b) from data group by a, 2 +-- !query 13 schema +struct<> +-- !query 13 output +org.apache.spark.sql.AnalysisException +nondeterministic expression rand(0) should not appear in grouping expression.; + + +-- !query 14 +select * from data group by a, b, 1 +-- !query 14 schema +struct<> +-- !query 14 output 
+org.apache.spark.sql.AnalysisException +Star (*) is not allowed in select list when GROUP BY ordinal position is used; + + +-- !query 15 +set spark.sql.groupByOrdinal=false +-- !query 15 schema +struct +-- !query 15 output +spark.sql.groupByOrdinal + + +-- !query 16 +select sum(b) from data group by -1 +-- !query 16 schema +struct +-- !query 16 output +9 diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out new file mode 100644 index 0000000000000..9127bd4dd4c6f --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -0,0 +1,51 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 6 + + +-- !query 0 +create temporary view myview as values 128, 256 as v(int_col) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +select int_col, count(*) from myview group by int_col +-- !query 1 schema +struct +-- !query 1 output +128 1 +256 1 + + +-- !query 2 +select 'foo', count(*) from myview group by 1 +-- !query 2 schema +struct +-- !query 2 output +foo 2 + + +-- !query 3 +select 'foo' from myview where int_col == 0 group by 1 +-- !query 3 schema +struct +-- !query 3 output + + + +-- !query 4 +select 'foo', approx_count_distinct(int_col) from myview where int_col == 0 group by 1 +-- !query 4 schema +struct +-- !query 4 output + + + +-- !query 5 +select 'foo', max(struct(int_col)) from myview where int_col == 0 group by 1 +-- !query 5 schema +struct> +-- !query 5 output + diff --git a/sql/core/src/test/resources/sql-tests/results/having.sql.out b/sql/core/src/test/resources/sql-tests/results/having.sql.out new file mode 100644 index 0000000000000..e0923832673cb --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/having.sql.out @@ -0,0 +1,40 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 4 + + +-- !query 0 +create temporary view hav as select * from values + ("one", 1), + ("two", 2), + ("three", 3), + ("one", 5) + as hav(k, v) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +SELECT k, sum(v) FROM hav GROUP BY k HAVING sum(v) > 2 +-- !query 1 schema +struct +-- !query 1 output +one 6 +three 3 + + +-- !query 2 +SELECT count(k) FROM hav GROUP BY v + 1 HAVING v + 1 = 2 +-- !query 2 schema +struct +-- !query 2 output +1 + + +-- !query 3 +SELECT MIN(t.v) FROM (SELECT * FROM hav WHERE v > 0) t HAVING(COUNT(1) > 0) +-- !query 3 schema +struct +-- !query 3 output +1 diff --git a/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out b/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out new file mode 100644 index 0000000000000..de6f01b8de772 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out @@ -0,0 +1,145 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 16 + + +-- !query 0 +select * from values ("one", 1) +-- !query 0 schema +struct +-- !query 0 output +one 1 + + +-- !query 1 +select * from values ("one", 1) as data +-- !query 1 schema +struct +-- !query 1 output +one 1 + + +-- !query 2 +select * from values ("one", 1) as data(a, b) +-- !query 2 schema +struct +-- !query 2 output +one 1 + + +-- !query 3 +select * from values 1, 2, 3 as data(a) +-- !query 3 schema +struct +-- !query 3 output +1 +2 +3 + + +-- !query 4 +select * from values ("one", 1), ("two", 2), ("three", null) as data(a, b) +-- !query 4 schema +struct +-- !query 4 output +one 1 +three NULL +two 2 + + +-- !query 5 +select * from values ("one", 
null), ("two", null) as data(a, b) +-- !query 5 schema +struct +-- !query 5 output +one NULL +two NULL + + +-- !query 6 +select * from values ("one", 1), ("two", 2L) as data(a, b) +-- !query 6 schema +struct +-- !query 6 output +one 1 +two 2 + + +-- !query 7 +select * from values ("one", 1 + 0), ("two", 1 + 3L) as data(a, b) +-- !query 7 schema +struct +-- !query 7 output +one 1 +two 4 + + +-- !query 8 +select * from values ("one", array(0, 1)), ("two", array(2, 3)) as data(a, b) +-- !query 8 schema +struct> +-- !query 8 output +one [0,1] +two [2,3] + + +-- !query 9 +select * from values ("one", 2.0), ("two", 3.0D) as data(a, b) +-- !query 9 schema +struct +-- !query 9 output +one 2.0 +two 3.0 + + +-- !query 10 +select * from values ("one", rand(5)), ("two", 3.0D) as data(a, b) +-- !query 10 schema +struct<> +-- !query 10 output +org.apache.spark.sql.AnalysisException +cannot evaluate expression rand(5) in inline table definition; line 1 pos 29 + + +-- !query 11 +select * from values ("one", 2.0), ("two") as data(a, b) +-- !query 11 schema +struct<> +-- !query 11 output +org.apache.spark.sql.AnalysisException +expected 2 columns but found 1 columns in row 1; line 1 pos 14 + + +-- !query 12 +select * from values ("one", array(0, 1)), ("two", struct(1, 2)) as data(a, b) +-- !query 12 schema +struct<> +-- !query 12 output +org.apache.spark.sql.AnalysisException +incompatible types found in column b for inline table; line 1 pos 14 + + +-- !query 13 +select * from values ("one"), ("two") as data(a, b) +-- !query 13 schema +struct<> +-- !query 13 output +org.apache.spark.sql.AnalysisException +expected 2 columns but found 1 columns in row 0; line 1 pos 14 + + +-- !query 14 +select * from values ("one", random_not_exist_func(1)), ("two", 2) as data(a, b) +-- !query 14 schema +struct<> +-- !query 14 output +org.apache.spark.sql.AnalysisException +Undefined function: 'random_not_exist_func'. 
This function is neither a registered temporary function nor a permanent function registered in the database 'default'.; line 1 pos 29 + + +-- !query 15 +select * from values ("one", count(1)), ("two", 2) as data(a, b) +-- !query 15 schema +struct<> +-- !query 15 output +org.apache.spark.sql.AnalysisException +cannot evaluate expression count(1) in inline table definition; line 1 pos 29 diff --git a/sql/core/src/test/resources/sql-tests/results/limit.sql.out b/sql/core/src/test/resources/sql-tests/results/limit.sql.out new file mode 100644 index 0000000000000..cb4e4d04810d0 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/limit.sql.out @@ -0,0 +1,91 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 10 + + +-- !query 0 +select * from testdata limit 2 +-- !query 0 schema +struct +-- !query 0 output +1 1 +2 2 + + +-- !query 1 +select * from arraydata limit 2 +-- !query 1 schema +struct,nestedarraycol:array>> +-- !query 1 output +[1,2,3] [[1,2,3]] +[2,3,4] [[2,3,4]] + + +-- !query 2 +select * from mapdata limit 2 +-- !query 2 schema +struct> +-- !query 2 output +{1:"a1",2:"b1",3:"c1",4:"d1",5:"e1"} +{1:"a2",2:"b2",3:"c2",4:"d2"} + + +-- !query 3 +select * from testdata limit 2 + 1 +-- !query 3 schema +struct +-- !query 3 output +1 1 +2 2 +3 3 + + +-- !query 4 +select * from testdata limit CAST(1 AS int) +-- !query 4 schema +struct +-- !query 4 output +1 1 + + +-- !query 5 +select * from testdata limit -1 +-- !query 5 schema +struct<> +-- !query 5 output +org.apache.spark.sql.AnalysisException +The limit expression must be equal to or greater than 0, but got -1; + + +-- !query 6 +select * from testdata limit key > 3 +-- !query 6 schema +struct<> +-- !query 6 output +org.apache.spark.sql.AnalysisException +The limit expression must evaluate to a constant value, but got (testdata.`key` > 3); + + +-- !query 7 +select * from testdata limit true +-- !query 7 schema +struct<> +-- !query 7 output +org.apache.spark.sql.AnalysisException +The limit expression must be integer type, but got boolean; + + +-- !query 8 +select * from testdata limit 'a' +-- !query 8 schema +struct<> +-- !query 8 output +org.apache.spark.sql.AnalysisException +The limit expression must be integer type, but got string; + + +-- !query 9 +select * from (select * from range(10) limit 5) where id > 3 +-- !query 9 schema +struct +-- !query 9 output +4 diff --git a/sql/core/src/test/resources/sql-tests/results/literals.sql.out b/sql/core/src/test/resources/sql-tests/results/literals.sql.out new file mode 100644 index 0000000000000..85629f7ba813a --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/literals.sql.out @@ -0,0 +1,378 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 40 + + +-- !query 0 +select null, Null, nUll +-- !query 0 schema +struct +-- !query 0 output +NULL NULL NULL + + +-- !query 1 +select true, tRue, false, fALse +-- !query 1 schema +struct +-- !query 1 output +true true false false + + +-- !query 2 +select 1Y +-- !query 2 schema +struct<1:tinyint> +-- !query 2 output +1 + + +-- !query 3 +select 127Y, -128Y +-- !query 3 schema +struct<127:tinyint,-128:tinyint> +-- !query 3 output +127 -128 + + +-- !query 4 +select 128Y +-- !query 4 schema +struct<> +-- !query 4 output +org.apache.spark.sql.catalyst.parser.ParseException + +Numeric literal 128 does not fit in range [-128, 127] for type tinyint(line 1, pos 7) + +== SQL == +select 128Y +-------^^^ + + +-- !query 5 +select 1S +-- !query 5 schema +struct<1:smallint> +-- !query 5 output +1 
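
The literal tests around this point check Spark's typed integral suffixes: Y for tinyint, S for smallint, and L for bigint, with a parse error whenever the digits fall outside the target type's range. A minimal sketch of the same rules, assuming a Spark SQL session (the commented-out line is the failing case):

```sql
-- Assumes a Spark SQL session.
SELECT 127Y;          -- tinyint; 127 is the largest value the type can hold
SELECT 32767S;        -- smallint; 32767 is its upper bound
SELECT 2147483648L;   -- bigint; the L suffix makes the target type explicit
-- SELECT 128Y;       -- rejected at parse time: 128 is outside tinyint's [-128, 127] range
```
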
+ + +-- !query 6 +select 32767S, -32768S +-- !query 6 schema +struct<32767:smallint,-32768:smallint> +-- !query 6 output +32767 -32768 + + +-- !query 7 +select 32768S +-- !query 7 schema +struct<> +-- !query 7 output +org.apache.spark.sql.catalyst.parser.ParseException + +Numeric literal 32768 does not fit in range [-32768, 32767] for type smallint(line 1, pos 7) + +== SQL == +select 32768S +-------^^^ + + +-- !query 8 +select 1L, 2147483648L +-- !query 8 schema +struct<1:bigint,2147483648:bigint> +-- !query 8 output +1 2147483648 + + +-- !query 9 +select 9223372036854775807L, -9223372036854775808L +-- !query 9 schema +struct<9223372036854775807:bigint,-9223372036854775808:bigint> +-- !query 9 output +9223372036854775807 -9223372036854775808 + + +-- !query 10 +select 9223372036854775808L +-- !query 10 schema +struct<> +-- !query 10 output +org.apache.spark.sql.catalyst.parser.ParseException + +Numeric literal 9223372036854775808 does not fit in range [-9223372036854775808, 9223372036854775807] for type bigint(line 1, pos 7) + +== SQL == +select 9223372036854775808L +-------^^^ + + +-- !query 11 +select 1, -1 +-- !query 11 schema +struct<1:int,-1:int> +-- !query 11 output +1 -1 + + +-- !query 12 +select 2147483647, -2147483648 +-- !query 12 schema +struct<2147483647:int,-2147483648:int> +-- !query 12 output +2147483647 -2147483648 + + +-- !query 13 +select 9223372036854775807, -9223372036854775808 +-- !query 13 schema +struct<9223372036854775807:bigint,-9223372036854775808:bigint> +-- !query 13 output +9223372036854775807 -9223372036854775808 + + +-- !query 14 +select 9223372036854775808, -9223372036854775809 +-- !query 14 schema +struct<9223372036854775808:decimal(19,0),-9223372036854775809:decimal(19,0)> +-- !query 14 output +9223372036854775808 -9223372036854775809 + + +-- !query 15 +select 1234567890123456789012345678901234567890 +-- !query 15 schema +struct<> +-- !query 15 output +org.apache.spark.sql.catalyst.parser.ParseException + +DecimalType can only support precision up to 38 +== SQL == +select 1234567890123456789012345678901234567890 + + +-- !query 16 +select 1234567890123456789012345678901234567890.0 +-- !query 16 schema +struct<> +-- !query 16 output +org.apache.spark.sql.catalyst.parser.ParseException + +DecimalType can only support precision up to 38 +== SQL == +select 1234567890123456789012345678901234567890.0 + + +-- !query 17 +select 1D, 1.2D, 1e10, 1.5e5, .10D, 0.10D, .1e5, .9e+2, 0.9e+2, 900e-1, 9.e+1 +-- !query 17 schema +struct<1.0:double,1.2:double,1.0E10:double,150000.0:double,0.1:double,0.1:double,10000.0:double,90.0:double,90.0:double,90.0:double,90.0:double> +-- !query 17 output +1.0 1.2 1.0E10 150000.0 0.1 0.1 10000.0 90.0 90.0 90.0 90.0 + + +-- !query 18 +select -1D, -1.2D, -1e10, -1.5e5, -.10D, -0.10D, -.1e5 +-- !query 18 schema +struct<-1.0:double,-1.2:double,-1.0E10:double,-150000.0:double,-0.1:double,-0.1:double,-10000.0:double> +-- !query 18 output +-1.0 -1.2 -1.0E10 -150000.0 -0.1 -0.1 -10000.0 + + +-- !query 19 +select .e3 +-- !query 19 schema +struct<> +-- !query 19 output +org.apache.spark.sql.catalyst.parser.ParseException + +no viable alternative at input 'select .'(line 1, pos 7) + +== SQL == +select .e3 +-------^^^ + + +-- !query 20 +select 1E309, -1E309 +-- !query 20 schema +struct +-- !query 20 output +Infinity -Infinity + + +-- !query 21 +select 0.3, -0.8, .5, -.18, 0.1111, .1111 +-- !query 21 schema +struct<0.3:decimal(1,1),-0.8:decimal(1,1),0.5:decimal(1,1),-0.18:decimal(2,2),0.1111:decimal(4,4),0.1111:decimal(4,4)> +-- !query 21 output +0.3 
-0.8 0.5 -0.18 0.1111 0.1111 + + +-- !query 22 +select 123456789012345678901234567890123456789e10, 123456789012345678901234567890123456789.1e10 +-- !query 22 schema +struct<1.2345678901234568E48:double,1.2345678901234568E48:double> +-- !query 22 output +1.2345678901234568E48 1.2345678901234568E48 + + +-- !query 23 +select "Hello Peter!", 'hello lee!' +-- !query 23 schema +struct +-- !query 23 output +Hello Peter! hello lee! + + +-- !query 24 +select 'hello' 'world', 'hello' " " 'lee' +-- !query 24 schema +struct +-- !query 24 output +helloworld hello lee + + +-- !query 25 +select "hello 'peter'" +-- !query 25 schema +struct +-- !query 25 output +hello 'peter' + + +-- !query 26 +select 'pattern%', 'no-pattern\%', 'pattern\\%', 'pattern\\\%' +-- !query 26 schema +struct +-- !query 26 output +pattern% no-pattern\% pattern\% pattern\\% + + +-- !query 27 +select '\'', '"', '\n', '\r', '\t', 'Z' +-- !query 27 schema +struct<':string,":string, +:string, :string, :string,Z:string> +-- !query 27 output +' " + Z + + +-- !query 28 +select '\110\145\154\154\157\041' +-- !query 28 schema +struct +-- !query 28 output +Hello! + + +-- !query 29 +select '\u0057\u006F\u0072\u006C\u0064\u0020\u003A\u0029' +-- !query 29 schema +struct +-- !query 29 output +World :) + + +-- !query 30 +select dAte '2016-03-12' +-- !query 30 schema +struct +-- !query 30 output +2016-03-12 + + +-- !query 31 +select date 'mar 11 2016' +-- !query 31 schema +struct<> +-- !query 31 output +java.lang.IllegalArgumentException +null + + +-- !query 32 +select tImEstAmp '2016-03-11 20:54:00.000' +-- !query 32 schema +struct +-- !query 32 output +2016-03-11 20:54:00 + + +-- !query 33 +select timestamp '2016-33-11 20:54:00.000' +-- !query 33 schema +struct<> +-- !query 33 output +java.lang.IllegalArgumentException +Timestamp format must be yyyy-mm-dd hh:mm:ss[.fffffffff] + + +-- !query 34 +select interval 13.123456789 seconds, interval -13.123456789 second +-- !query 34 schema +struct<> +-- !query 34 output +scala.MatchError +(interval 13 seconds 123 milliseconds 456 microseconds,CalendarIntervalType) (of class scala.Tuple2) + + +-- !query 35 +select interval 1 year 2 month 3 week 4 day 5 hour 6 minute 7 seconds 8 millisecond, 9 microsecond +-- !query 35 schema +struct<> +-- !query 35 output +scala.MatchError +(interval 1 years 2 months 3 weeks 4 days 5 hours 6 minutes 7 seconds 8 milliseconds,CalendarIntervalType) (of class scala.Tuple2) + + +-- !query 36 +select interval 10 nanoseconds +-- !query 36 schema +struct<> +-- !query 36 output +org.apache.spark.sql.catalyst.parser.ParseException + +No interval can be constructed(line 1, pos 16) + +== SQL == +select interval 10 nanoseconds +----------------^^^ + + +-- !query 37 +select GEO '(10,-6)' +-- !query 37 schema +struct<> +-- !query 37 output +org.apache.spark.sql.catalyst.parser.ParseException + +Literals of type 'GEO' are currently not supported.(line 1, pos 7) + +== SQL == +select GEO '(10,-6)' +-------^^^ + + +-- !query 38 +select 90912830918230182310293801923652346786BD, 123.0E-28BD, 123.08BD +-- !query 38 schema +struct<90912830918230182310293801923652346786:decimal(38,0),1.230E-26:decimal(29,29),123.08:decimal(5,2)> +-- !query 38 output +90912830918230182310293801923652346786 0.0000000000000000000000000123 123.08 + + +-- !query 39 +select 1.20E-38BD +-- !query 39 schema +struct<> +-- !query 39 output +org.apache.spark.sql.catalyst.parser.ParseException + +DecimalType can only support precision up to 38(line 1, pos 7) + +== SQL == +select 1.20E-38BD +-------^^^ diff --git 
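
literals.sql.out closes with the BD suffix, which forces a literal onto DecimalType; precision and scale are inferred from the digits and exponent, and anything that would need more than 38 digits of precision is rejected at parse time. A minimal sketch, assuming a Spark SQL session (the commented-out line is the failing case):

```sql
-- Assumes a Spark SQL session.
SELECT 123.08BD;       -- decimal(5,2): five significant digits, two after the point
SELECT 123.0E-28BD;    -- decimal(29,29): the negative exponent shifts the value entirely into fractional digits
-- SELECT 1.20E-38BD;  -- rejected: the required precision exceeds DecimalType's limit of 38
```
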
a/sql/core/src/test/resources/sql-tests/results/natural-join.sql.out b/sql/core/src/test/resources/sql-tests/results/natural-join.sql.out new file mode 100644 index 0000000000000..43f2f9af61d9b --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/natural-join.sql.out @@ -0,0 +1,64 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 6 + + +-- !query 0 +create temporary view nt1 as select * from values + ("one", 1), + ("two", 2), + ("three", 3) + as nt1(k, v1) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +create temporary view nt2 as select * from values + ("one", 1), + ("two", 22), + ("one", 5) + as nt2(k, v2) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +SELECT * FROM nt1 natural join nt2 where k = "one" +-- !query 2 schema +struct +-- !query 2 output +one 1 1 +one 1 5 + + +-- !query 3 +SELECT * FROM nt1 natural left join nt2 order by v1, v2 +-- !query 3 schema +struct +-- !query 3 output +one 1 1 +one 1 5 +two 2 22 +three 3 NULL + + +-- !query 4 +SELECT * FROM nt1 natural right join nt2 order by v1, v2 +-- !query 4 schema +struct +-- !query 4 output +one 1 1 +one 1 5 +two 2 22 + + +-- !query 5 +SELECT count(*) FROM nt1 natural full outer join nt2 +-- !query 5 schema +struct +-- !query 5 output +4 diff --git a/sql/core/src/test/resources/sql-tests/results/null-propagation.sql.out b/sql/core/src/test/resources/sql-tests/results/null-propagation.sql.out new file mode 100644 index 0000000000000..ed3a651aa6614 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/null-propagation.sql.out @@ -0,0 +1,38 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 4 + + +-- !query 0 +SELECT COUNT(NULL) FROM VALUES 1, 2, 3 +-- !query 0 schema +struct +-- !query 0 output +0 + + +-- !query 1 +SELECT COUNT(1 + NULL) FROM VALUES 1, 2, 3 +-- !query 1 schema +struct +-- !query 1 output +0 + + +-- !query 2 +SELECT COUNT(NULL) OVER () FROM VALUES 1, 2, 3 +-- !query 2 schema +struct +-- !query 2 output +0 +0 +0 + + +-- !query 3 +SELECT COUNT(1 + NULL) OVER () FROM VALUES 1, 2, 3 +-- !query 3 schema +struct +-- !query 3 output +0 +0 +0 diff --git a/sql/core/src/test/resources/sql-tests/results/order-by-ordinal.sql.out b/sql/core/src/test/resources/sql-tests/results/order-by-ordinal.sql.out new file mode 100644 index 0000000000000..03a4e72d0fa3e --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/order-by-ordinal.sql.out @@ -0,0 +1,143 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 12 + + +-- !query 0 +create temporary view data as select * from values + (1, 1), + (1, 2), + (2, 1), + (2, 2), + (3, 1), + (3, 2) + as data(a, b) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +select * from data order by 1 desc +-- !query 1 schema +struct +-- !query 1 output +3 1 +3 2 +2 1 +2 2 +1 1 +1 2 + + +-- !query 2 +select * from data order by 1 desc, b desc +-- !query 2 schema +struct +-- !query 2 output +3 2 +3 1 +2 2 +2 1 +1 2 +1 1 + + +-- !query 3 +select * from data order by 1 desc, 2 desc +-- !query 3 schema +struct +-- !query 3 output +3 2 +3 1 +2 2 +2 1 +1 2 +1 1 + + +-- !query 4 +select * from data order by 1 + 0 desc, b desc +-- !query 4 schema +struct +-- !query 4 output +1 2 +2 2 +3 2 +1 1 +2 1 +3 1 + + +-- !query 5 +select * from data order by 0 +-- !query 5 schema +struct<> +-- !query 5 output +org.apache.spark.sql.AnalysisException +ORDER BY position 0 is not in select list (valid range is [1, 2]); line 1 pos 28 + + +-- !query 6 +select 
* from data order by -1 +-- !query 6 schema +struct<> +-- !query 6 output +org.apache.spark.sql.AnalysisException +ORDER BY position -1 is not in select list (valid range is [1, 2]); line 1 pos 28 + + +-- !query 7 +select * from data order by 3 +-- !query 7 schema +struct<> +-- !query 7 output +org.apache.spark.sql.AnalysisException +ORDER BY position 3 is not in select list (valid range is [1, 2]); line 1 pos 28 + + +-- !query 8 +select * from data sort by 1 desc +-- !query 8 schema +struct +-- !query 8 output +1 1 +1 2 +2 1 +2 2 +3 1 +3 2 + + +-- !query 9 +set spark.sql.orderByOrdinal=false +-- !query 9 schema +struct +-- !query 9 output +spark.sql.orderByOrdinal + + +-- !query 10 +select * from data order by 0 +-- !query 10 schema +struct +-- !query 10 output +1 1 +1 2 +2 1 +2 2 +3 1 +3 2 + + +-- !query 11 +select * from data sort by 0 +-- !query 11 schema +struct +-- !query 11 output +1 1 +1 2 +2 1 +2 2 +3 1 +3 2 diff --git a/sql/core/src/test/resources/sql-tests/results/outer-join.sql.out b/sql/core/src/test/resources/sql-tests/results/outer-join.sql.out new file mode 100644 index 0000000000000..b39fdb0e58720 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/outer-join.sql.out @@ -0,0 +1,72 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 6 + + +-- !query 0 +CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES +(-234), (145), (367), (975), (298) +as t1(int_col1) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +CREATE OR REPLACE TEMPORARY VIEW t2 AS SELECT * FROM VALUES +(-769, -244), (-800, -409), (940, 86), (-507, 304), (-367, 158) +as t2(int_col0, int_col1) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +SELECT + (SUM(COALESCE(t1.int_col1, t2.int_col0))), + ((COALESCE(t1.int_col1, t2.int_col0)) * 2) +FROM t1 +RIGHT JOIN t2 + ON (t2.int_col0) = (t1.int_col1) +GROUP BY GREATEST(COALESCE(t2.int_col1, 109), COALESCE(t1.int_col1, -449)), + COALESCE(t1.int_col1, t2.int_col0) +HAVING (SUM(COALESCE(t1.int_col1, t2.int_col0))) + > ((COALESCE(t1.int_col1, t2.int_col0)) * 2) +-- !query 2 schema +struct +-- !query 2 output +-367 -734 +-507 -1014 +-769 -1538 +-800 -1600 + + +-- !query 3 +CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES (97) as t1(int_col1) +-- !query 3 schema +struct<> +-- !query 3 output + + + +-- !query 4 +CREATE OR REPLACE TEMPORARY VIEW t2 AS SELECT * FROM VALUES (0) as t2(int_col1) +-- !query 4 schema +struct<> +-- !query 4 output + + + +-- !query 5 +SELECT * +FROM ( +SELECT + COALESCE(t2.int_col1, t1.int_col1) AS int_col + FROM t1 + LEFT JOIN t2 ON false +) t where (t.int_col) is not null +-- !query 5 schema +struct +-- !query 5 output +97 diff --git a/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out new file mode 100644 index 0000000000000..d769bcef0aca7 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out @@ -0,0 +1,87 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 7 + + +-- !query 0 +select * from dummy(3) +-- !query 0 schema +struct<> +-- !query 0 output +org.apache.spark.sql.AnalysisException +could not resolve `dummy` to a table-valued function; line 1 pos 14 + + +-- !query 1 +select * from range(6 + cos(3)) +-- !query 1 schema +struct +-- !query 1 output +0 +1 +2 +3 +4 + + +-- !query 2 +select * from range(5, 10) +-- !query 2 schema +struct +-- !query 2 output +5 +6 +7 +8 +9 + + +-- !query 
3 +select * from range(0, 10, 2) +-- !query 3 schema +struct +-- !query 3 output +0 +2 +4 +6 +8 + + +-- !query 4 +select * from range(0, 10, 1, 200) +-- !query 4 schema +struct +-- !query 4 output +0 +1 +2 +3 +4 +5 +6 +7 +8 +9 + + +-- !query 5 +select * from range(1, 1, 1, 1, 1) +-- !query 5 schema +struct<> +-- !query 5 output +org.apache.spark.sql.AnalysisException +error: table-valued function range with alternatives: + (end: long) + (start: long, end: long) + (start: long, end: long, step: long) + (start: long, end: long, step: long, numPartitions: integer) +cannot be applied to: (integer, integer, integer, integer, integer); line 1 pos 14 + + +-- !query 6 +select * from range(1, null) +-- !query 6 schema +struct<> +-- !query 6 output +java.lang.IllegalArgumentException +Invalid arguments for resolved function: 1, null diff --git a/sql/core/src/test/resources/bool.csv b/sql/core/src/test/resources/test-data/bool.csv similarity index 100% rename from sql/core/src/test/resources/bool.csv rename to sql/core/src/test/resources/test-data/bool.csv diff --git a/sql/core/src/test/resources/cars-alternative.csv b/sql/core/src/test/resources/test-data/cars-alternative.csv similarity index 100% rename from sql/core/src/test/resources/cars-alternative.csv rename to sql/core/src/test/resources/test-data/cars-alternative.csv diff --git a/sql/core/src/test/resources/test-data/cars-blank-column-name.csv b/sql/core/src/test/resources/test-data/cars-blank-column-name.csv new file mode 100644 index 0000000000000..0b804b1614d60 --- /dev/null +++ b/sql/core/src/test/resources/test-data/cars-blank-column-name.csv @@ -0,0 +1,3 @@ +"",,make,customer,comment +2012,"Tesla","S","bill","blank" +2013,"Tesla","S","c","something" diff --git a/sql/core/src/test/resources/cars-malformed.csv b/sql/core/src/test/resources/test-data/cars-malformed.csv similarity index 100% rename from sql/core/src/test/resources/cars-malformed.csv rename to sql/core/src/test/resources/test-data/cars-malformed.csv diff --git a/sql/core/src/test/resources/cars-null.csv b/sql/core/src/test/resources/test-data/cars-null.csv similarity index 100% rename from sql/core/src/test/resources/cars-null.csv rename to sql/core/src/test/resources/test-data/cars-null.csv diff --git a/sql/core/src/test/resources/cars-unbalanced-quotes.csv b/sql/core/src/test/resources/test-data/cars-unbalanced-quotes.csv similarity index 100% rename from sql/core/src/test/resources/cars-unbalanced-quotes.csv rename to sql/core/src/test/resources/test-data/cars-unbalanced-quotes.csv diff --git a/sql/core/src/test/resources/cars.csv b/sql/core/src/test/resources/test-data/cars.csv similarity index 100% rename from sql/core/src/test/resources/cars.csv rename to sql/core/src/test/resources/test-data/cars.csv diff --git a/sql/core/src/test/resources/cars.tsv b/sql/core/src/test/resources/test-data/cars.tsv similarity index 100% rename from sql/core/src/test/resources/cars.tsv rename to sql/core/src/test/resources/test-data/cars.tsv diff --git a/sql/core/src/test/resources/cars_iso-8859-1.csv b/sql/core/src/test/resources/test-data/cars_iso-8859-1.csv similarity index 100% rename from sql/core/src/test/resources/cars_iso-8859-1.csv rename to sql/core/src/test/resources/test-data/cars_iso-8859-1.csv diff --git a/sql/core/src/test/resources/comments.csv b/sql/core/src/test/resources/test-data/comments.csv similarity index 100% rename from sql/core/src/test/resources/comments.csv rename to sql/core/src/test/resources/test-data/comments.csv diff --git 
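
Looking back at the table-valued-functions results just above: range is the one built-in table-valued function exercised there, and the analyzer matches a call against four fixed signatures before rejecting anything else. A minimal sketch, assuming a Spark SQL session:

```sql
-- Assumes a Spark SQL session.
SELECT * FROM range(3);            -- end only: rows 0, 1, 2
SELECT * FROM range(2, 6);         -- start, end: rows 2 through 5
SELECT * FROM range(0, 10, 3);     -- start, end, step: rows 0, 3, 6, 9
SELECT * FROM range(0, 4, 1, 2);   -- start, end, step, numPartitions
-- A fifth argument matches none of the alternatives and fails analysis,
-- as the range(1, 1, 1, 1, 1) case above shows.
```
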
a/sql/core/src/test/resources/dates.csv b/sql/core/src/test/resources/test-data/dates.csv similarity index 100% rename from sql/core/src/test/resources/dates.csv rename to sql/core/src/test/resources/test-data/dates.csv diff --git a/sql/core/src/test/resources/dec-in-fixed-len.parquet b/sql/core/src/test/resources/test-data/dec-in-fixed-len.parquet similarity index 100% rename from sql/core/src/test/resources/dec-in-fixed-len.parquet rename to sql/core/src/test/resources/test-data/dec-in-fixed-len.parquet diff --git a/sql/core/src/test/resources/dec-in-i32.parquet b/sql/core/src/test/resources/test-data/dec-in-i32.parquet similarity index 100% rename from sql/core/src/test/resources/dec-in-i32.parquet rename to sql/core/src/test/resources/test-data/dec-in-i32.parquet diff --git a/sql/core/src/test/resources/dec-in-i64.parquet b/sql/core/src/test/resources/test-data/dec-in-i64.parquet similarity index 100% rename from sql/core/src/test/resources/dec-in-i64.parquet rename to sql/core/src/test/resources/test-data/dec-in-i64.parquet diff --git a/sql/core/src/test/resources/test-data/decimal.csv b/sql/core/src/test/resources/test-data/decimal.csv new file mode 100644 index 0000000000000..870f6aaf1bb4c --- /dev/null +++ b/sql/core/src/test/resources/test-data/decimal.csv @@ -0,0 +1,7 @@ +~ decimal field has integer, integer and decimal values. The last value cannot fit to a long +~ long field has integer, long and integer values. +~ double field has double, double and decimal values. +decimal,long,double +1,1,0.1 +1,9223372036854775807,1.0 +92233720368547758070,1,92233720368547758070 diff --git a/sql/core/src/test/resources/disable_comments.csv b/sql/core/src/test/resources/test-data/disable_comments.csv similarity index 100% rename from sql/core/src/test/resources/disable_comments.csv rename to sql/core/src/test/resources/test-data/disable_comments.csv diff --git a/sql/core/src/test/resources/empty.csv b/sql/core/src/test/resources/test-data/empty.csv similarity index 100% rename from sql/core/src/test/resources/empty.csv rename to sql/core/src/test/resources/test-data/empty.csv diff --git a/sql/core/src/test/resources/nested-array-struct.parquet b/sql/core/src/test/resources/test-data/nested-array-struct.parquet similarity index 100% rename from sql/core/src/test/resources/nested-array-struct.parquet rename to sql/core/src/test/resources/test-data/nested-array-struct.parquet diff --git a/sql/core/src/test/resources/numbers.csv b/sql/core/src/test/resources/test-data/numbers.csv similarity index 100% rename from sql/core/src/test/resources/numbers.csv rename to sql/core/src/test/resources/test-data/numbers.csv diff --git a/sql/core/src/test/resources/old-repeated-int.parquet b/sql/core/src/test/resources/test-data/old-repeated-int.parquet similarity index 100% rename from sql/core/src/test/resources/old-repeated-int.parquet rename to sql/core/src/test/resources/test-data/old-repeated-int.parquet diff --git a/sql/core/src/test/resources/old-repeated-message.parquet b/sql/core/src/test/resources/test-data/old-repeated-message.parquet similarity index 100% rename from sql/core/src/test/resources/old-repeated-message.parquet rename to sql/core/src/test/resources/test-data/old-repeated-message.parquet diff --git a/sql/core/src/test/resources/parquet-thrift-compat.snappy.parquet b/sql/core/src/test/resources/test-data/parquet-thrift-compat.snappy.parquet similarity index 100% rename from sql/core/src/test/resources/parquet-thrift-compat.snappy.parquet rename to 
sql/core/src/test/resources/test-data/parquet-thrift-compat.snappy.parquet diff --git a/sql/core/src/test/resources/proto-repeated-string.parquet b/sql/core/src/test/resources/test-data/proto-repeated-string.parquet similarity index 100% rename from sql/core/src/test/resources/proto-repeated-string.parquet rename to sql/core/src/test/resources/test-data/proto-repeated-string.parquet diff --git a/sql/core/src/test/resources/proto-repeated-struct.parquet b/sql/core/src/test/resources/test-data/proto-repeated-struct.parquet similarity index 100% rename from sql/core/src/test/resources/proto-repeated-struct.parquet rename to sql/core/src/test/resources/test-data/proto-repeated-struct.parquet diff --git a/sql/core/src/test/resources/proto-struct-with-array-many.parquet b/sql/core/src/test/resources/test-data/proto-struct-with-array-many.parquet similarity index 100% rename from sql/core/src/test/resources/proto-struct-with-array-many.parquet rename to sql/core/src/test/resources/test-data/proto-struct-with-array-many.parquet diff --git a/sql/core/src/test/resources/proto-struct-with-array.parquet b/sql/core/src/test/resources/test-data/proto-struct-with-array.parquet similarity index 100% rename from sql/core/src/test/resources/proto-struct-with-array.parquet rename to sql/core/src/test/resources/test-data/proto-struct-with-array.parquet diff --git a/sql/core/src/test/resources/simple_sparse.csv b/sql/core/src/test/resources/test-data/simple_sparse.csv similarity index 100% rename from sql/core/src/test/resources/simple_sparse.csv rename to sql/core/src/test/resources/test-data/simple_sparse.csv diff --git a/sql/core/src/test/resources/test-data/text-partitioned/year=2014/data.txt b/sql/core/src/test/resources/test-data/text-partitioned/year=2014/data.txt new file mode 100644 index 0000000000000..e2719428bb28e --- /dev/null +++ b/sql/core/src/test/resources/test-data/text-partitioned/year=2014/data.txt @@ -0,0 +1 @@ +2014-test diff --git a/sql/core/src/test/resources/test-data/text-partitioned/year=2015/data.txt b/sql/core/src/test/resources/test-data/text-partitioned/year=2015/data.txt new file mode 100644 index 0000000000000..b8c03daa8c199 --- /dev/null +++ b/sql/core/src/test/resources/test-data/text-partitioned/year=2015/data.txt @@ -0,0 +1 @@ +2015-test diff --git a/sql/core/src/test/resources/text-suite.txt b/sql/core/src/test/resources/test-data/text-suite.txt similarity index 100% rename from sql/core/src/test/resources/text-suite.txt rename to sql/core/src/test/resources/test-data/text-suite.txt diff --git a/sql/core/src/test/resources/text-suite2.txt b/sql/core/src/test/resources/test-data/text-suite2.txt similarity index 100% rename from sql/core/src/test/resources/text-suite2.txt rename to sql/core/src/test/resources/test-data/text-suite2.txt diff --git a/sql/core/src/test/resources/unescaped-quotes.csv b/sql/core/src/test/resources/test-data/unescaped-quotes.csv similarity index 100% rename from sql/core/src/test/resources/unescaped-quotes.csv rename to sql/core/src/test/resources/test-data/unescaped-quotes.csv diff --git a/sql/core/src/test/resources/tpcds/q1.sql b/sql/core/src/test/resources/tpcds/q1.sql new file mode 100755 index 0000000000000..4d20faad8ef58 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q1.sql @@ -0,0 +1,19 @@ +WITH customer_total_return AS +( SELECT + sr_customer_sk AS ctr_customer_sk, + sr_store_sk AS ctr_store_sk, + sum(sr_return_amt) AS ctr_total_return + FROM store_returns, date_dim + WHERE sr_returned_date_sk = d_date_sk AND d_year = 2000 + GROUP 
BY sr_customer_sk, sr_store_sk) +SELECT c_customer_id +FROM customer_total_return ctr1, store, customer +WHERE ctr1.ctr_total_return > + (SELECT avg(ctr_total_return) * 1.2 + FROM customer_total_return ctr2 + WHERE ctr1.ctr_store_sk = ctr2.ctr_store_sk) + AND s_store_sk = ctr1.ctr_store_sk + AND s_state = 'TN' + AND ctr1.ctr_customer_sk = c_customer_sk +ORDER BY c_customer_id +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q10.sql b/sql/core/src/test/resources/tpcds/q10.sql new file mode 100755 index 0000000000000..5500e1aea1552 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q10.sql @@ -0,0 +1,57 @@ +SELECT + cd_gender, + cd_marital_status, + cd_education_status, + count(*) cnt1, + cd_purchase_estimate, + count(*) cnt2, + cd_credit_rating, + count(*) cnt3, + cd_dep_count, + count(*) cnt4, + cd_dep_employed_count, + count(*) cnt5, + cd_dep_college_count, + count(*) cnt6 +FROM + customer c, customer_address ca, customer_demographics +WHERE + c.c_current_addr_sk = ca.ca_address_sk AND + ca_county IN ('Rush County', 'Toole County', 'Jefferson County', + 'Dona Ana County', 'La Porte County') AND + cd_demo_sk = c.c_current_cdemo_sk AND + exists(SELECT * + FROM store_sales, date_dim + WHERE c.c_customer_sk = ss_customer_sk AND + ss_sold_date_sk = d_date_sk AND + d_year = 2002 AND + d_moy BETWEEN 1 AND 1 + 3) AND + (exists(SELECT * + FROM web_sales, date_dim + WHERE c.c_customer_sk = ws_bill_customer_sk AND + ws_sold_date_sk = d_date_sk AND + d_year = 2002 AND + d_moy BETWEEN 1 AND 1 + 3) OR + exists(SELECT * + FROM catalog_sales, date_dim + WHERE c.c_customer_sk = cs_ship_customer_sk AND + cs_sold_date_sk = d_date_sk AND + d_year = 2002 AND + d_moy BETWEEN 1 AND 1 + 3)) +GROUP BY cd_gender, + cd_marital_status, + cd_education_status, + cd_purchase_estimate, + cd_credit_rating, + cd_dep_count, + cd_dep_employed_count, + cd_dep_college_count +ORDER BY cd_gender, + cd_marital_status, + cd_education_status, + cd_purchase_estimate, + cd_credit_rating, + cd_dep_count, + cd_dep_employed_count, + cd_dep_college_count +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q11.sql b/sql/core/src/test/resources/tpcds/q11.sql new file mode 100755 index 0000000000000..3618fb14fa39c --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q11.sql @@ -0,0 +1,68 @@ +WITH year_total AS ( + SELECT + c_customer_id customer_id, + c_first_name customer_first_name, + c_last_name customer_last_name, + c_preferred_cust_flag customer_preferred_cust_flag, + c_birth_country customer_birth_country, + c_login customer_login, + c_email_address customer_email_address, + d_year dyear, + sum(ss_ext_list_price - ss_ext_discount_amt) year_total, + 's' sale_type + FROM customer, store_sales, date_dim + WHERE c_customer_sk = ss_customer_sk + AND ss_sold_date_sk = d_date_sk + GROUP BY c_customer_id + , c_first_name + , c_last_name + , d_year + , c_preferred_cust_flag + , c_birth_country + , c_login + , c_email_address + , d_year + UNION ALL + SELECT + c_customer_id customer_id, + c_first_name customer_first_name, + c_last_name customer_last_name, + c_preferred_cust_flag customer_preferred_cust_flag, + c_birth_country customer_birth_country, + c_login customer_login, + c_email_address customer_email_address, + d_year dyear, + sum(ws_ext_list_price - ws_ext_discount_amt) year_total, + 'w' sale_type + FROM customer, web_sales, date_dim + WHERE c_customer_sk = ws_bill_customer_sk + AND ws_sold_date_sk = d_date_sk + GROUP BY + c_customer_id, c_first_name, c_last_name, c_preferred_cust_flag, c_birth_country, + c_login, 
c_email_address, d_year) +SELECT t_s_secyear.customer_preferred_cust_flag +FROM year_total t_s_firstyear + , year_total t_s_secyear + , year_total t_w_firstyear + , year_total t_w_secyear +WHERE t_s_secyear.customer_id = t_s_firstyear.customer_id + AND t_s_firstyear.customer_id = t_w_secyear.customer_id + AND t_s_firstyear.customer_id = t_w_firstyear.customer_id + AND t_s_firstyear.sale_type = 's' + AND t_w_firstyear.sale_type = 'w' + AND t_s_secyear.sale_type = 's' + AND t_w_secyear.sale_type = 'w' + AND t_s_firstyear.dyear = 2001 + AND t_s_secyear.dyear = 2001 + 1 + AND t_w_firstyear.dyear = 2001 + AND t_w_secyear.dyear = 2001 + 1 + AND t_s_firstyear.year_total > 0 + AND t_w_firstyear.year_total > 0 + AND CASE WHEN t_w_firstyear.year_total > 0 + THEN t_w_secyear.year_total / t_w_firstyear.year_total + ELSE NULL END + > CASE WHEN t_s_firstyear.year_total > 0 + THEN t_s_secyear.year_total / t_s_firstyear.year_total + ELSE NULL END +ORDER BY t_s_secyear.customer_preferred_cust_flag +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q12.sql b/sql/core/src/test/resources/tpcds/q12.sql new file mode 100755 index 0000000000000..0382737f5aa2c --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q12.sql @@ -0,0 +1,22 @@ +SELECT + i_item_desc, + i_category, + i_class, + i_current_price, + sum(ws_ext_sales_price) AS itemrevenue, + sum(ws_ext_sales_price) * 100 / sum(sum(ws_ext_sales_price)) + OVER + (PARTITION BY i_class) AS revenueratio +FROM + web_sales, item, date_dim +WHERE + ws_item_sk = i_item_sk + AND i_category IN ('Sports', 'Books', 'Home') + AND ws_sold_date_sk = d_date_sk + AND d_date BETWEEN cast('1999-02-22' AS DATE) + AND (cast('1999-02-22' AS DATE) + INTERVAL 30 days) +GROUP BY + i_item_id, i_item_desc, i_category, i_class, i_current_price +ORDER BY + i_category, i_class, i_item_id, i_item_desc, revenueratio +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q13.sql b/sql/core/src/test/resources/tpcds/q13.sql new file mode 100755 index 0000000000000..32dc9e26097ba --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q13.sql @@ -0,0 +1,49 @@ +SELECT + avg(ss_quantity), + avg(ss_ext_sales_price), + avg(ss_ext_wholesale_cost), + sum(ss_ext_wholesale_cost) +FROM store_sales + , store + , customer_demographics + , household_demographics + , customer_address + , date_dim +WHERE s_store_sk = ss_store_sk + AND ss_sold_date_sk = d_date_sk AND d_year = 2001 + AND ((ss_hdemo_sk = hd_demo_sk + AND cd_demo_sk = ss_cdemo_sk + AND cd_marital_status = 'M' + AND cd_education_status = 'Advanced Degree' + AND ss_sales_price BETWEEN 100.00 AND 150.00 + AND hd_dep_count = 3 +) OR + (ss_hdemo_sk = hd_demo_sk + AND cd_demo_sk = ss_cdemo_sk + AND cd_marital_status = 'S' + AND cd_education_status = 'College' + AND ss_sales_price BETWEEN 50.00 AND 100.00 + AND hd_dep_count = 1 + ) OR + (ss_hdemo_sk = hd_demo_sk + AND cd_demo_sk = ss_cdemo_sk + AND cd_marital_status = 'W' + AND cd_education_status = '2 yr Degree' + AND ss_sales_price BETWEEN 150.00 AND 200.00 + AND hd_dep_count = 1 + )) + AND ((ss_addr_sk = ca_address_sk + AND ca_country = 'United States' + AND ca_state IN ('TX', 'OH', 'TX') + AND ss_net_profit BETWEEN 100 AND 200 +) OR + (ss_addr_sk = ca_address_sk + AND ca_country = 'United States' + AND ca_state IN ('OR', 'NM', 'KY') + AND ss_net_profit BETWEEN 150 AND 300 + ) OR + (ss_addr_sk = ca_address_sk + AND ca_country = 'United States' + AND ca_state IN ('VA', 'TX', 'MS') + AND ss_net_profit BETWEEN 50 AND 250 + )) diff --git a/sql/core/src/test/resources/tpcds/q14a.sql 
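
q12 above combines a regular aggregate with a window over that aggregate: sum(ws_ext_sales_price) is computed per group, and the OVER (PARTITION BY i_class) sum then totals those group results across the whole class, so revenueratio is each item's percentage share of its class revenue. A minimal sketch of the same shape against an illustrative sales(category, item, amount) table (an assumption, not part of the patch):

```sql
-- Illustrative table; only the shape of the query mirrors q12.
SELECT
  category,
  item,
  SUM(amount) AS item_revenue,
  SUM(amount) * 100 / SUM(SUM(amount)) OVER (PARTITION BY category) AS revenue_ratio
FROM sales
GROUP BY category, item;
-- The inner SUM(amount) is the per-(category, item) aggregate; the outer
-- SUM(...) OVER (...) adds those aggregates up across the whole category,
-- so each row reports its item's percentage of the category total.
```
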
b/sql/core/src/test/resources/tpcds/q14a.sql new file mode 100755 index 0000000000000..954ddd41be0e6 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q14a.sql @@ -0,0 +1,120 @@ +WITH cross_items AS +(SELECT i_item_sk ss_item_sk + FROM item, + (SELECT + iss.i_brand_id brand_id, + iss.i_class_id class_id, + iss.i_category_id category_id + FROM store_sales, item iss, date_dim d1 + WHERE ss_item_sk = iss.i_item_sk + AND ss_sold_date_sk = d1.d_date_sk + AND d1.d_year BETWEEN 1999 AND 1999 + 2 + INTERSECT + SELECT + ics.i_brand_id, + ics.i_class_id, + ics.i_category_id + FROM catalog_sales, item ics, date_dim d2 + WHERE cs_item_sk = ics.i_item_sk + AND cs_sold_date_sk = d2.d_date_sk + AND d2.d_year BETWEEN 1999 AND 1999 + 2 + INTERSECT + SELECT + iws.i_brand_id, + iws.i_class_id, + iws.i_category_id + FROM web_sales, item iws, date_dim d3 + WHERE ws_item_sk = iws.i_item_sk + AND ws_sold_date_sk = d3.d_date_sk + AND d3.d_year BETWEEN 1999 AND 1999 + 2) x + WHERE i_brand_id = brand_id + AND i_class_id = class_id + AND i_category_id = category_id +), + avg_sales AS + (SELECT avg(quantity * list_price) average_sales + FROM ( + SELECT + ss_quantity quantity, + ss_list_price list_price + FROM store_sales, date_dim + WHERE ss_sold_date_sk = d_date_sk + AND d_year BETWEEN 1999 AND 2001 + UNION ALL + SELECT + cs_quantity quantity, + cs_list_price list_price + FROM catalog_sales, date_dim + WHERE cs_sold_date_sk = d_date_sk + AND d_year BETWEEN 1999 AND 1999 + 2 + UNION ALL + SELECT + ws_quantity quantity, + ws_list_price list_price + FROM web_sales, date_dim + WHERE ws_sold_date_sk = d_date_sk + AND d_year BETWEEN 1999 AND 1999 + 2) x) +SELECT + channel, + i_brand_id, + i_class_id, + i_category_id, + sum(sales), + sum(number_sales) +FROM ( + SELECT + 'store' channel, + i_brand_id, + i_class_id, + i_category_id, + sum(ss_quantity * ss_list_price) sales, + count(*) number_sales + FROM store_sales, item, date_dim + WHERE ss_item_sk IN (SELECT ss_item_sk + FROM cross_items) + AND ss_item_sk = i_item_sk + AND ss_sold_date_sk = d_date_sk + AND d_year = 1999 + 2 + AND d_moy = 11 + GROUP BY i_brand_id, i_class_id, i_category_id + HAVING sum(ss_quantity * ss_list_price) > (SELECT average_sales + FROM avg_sales) + UNION ALL + SELECT + 'catalog' channel, + i_brand_id, + i_class_id, + i_category_id, + sum(cs_quantity * cs_list_price) sales, + count(*) number_sales + FROM catalog_sales, item, date_dim + WHERE cs_item_sk IN (SELECT ss_item_sk + FROM cross_items) + AND cs_item_sk = i_item_sk + AND cs_sold_date_sk = d_date_sk + AND d_year = 1999 + 2 + AND d_moy = 11 + GROUP BY i_brand_id, i_class_id, i_category_id + HAVING sum(cs_quantity * cs_list_price) > (SELECT average_sales FROM avg_sales) + UNION ALL + SELECT + 'web' channel, + i_brand_id, + i_class_id, + i_category_id, + sum(ws_quantity * ws_list_price) sales, + count(*) number_sales + FROM web_sales, item, date_dim + WHERE ws_item_sk IN (SELECT ss_item_sk + FROM cross_items) + AND ws_item_sk = i_item_sk + AND ws_sold_date_sk = d_date_sk + AND d_year = 1999 + 2 + AND d_moy = 11 + GROUP BY i_brand_id, i_class_id, i_category_id + HAVING sum(ws_quantity * ws_list_price) > (SELECT average_sales + FROM avg_sales) + ) y +GROUP BY ROLLUP (channel, i_brand_id, i_class_id, i_category_id) +ORDER BY channel, i_brand_id, i_class_id, i_category_id +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q14b.sql b/sql/core/src/test/resources/tpcds/q14b.sql new file mode 100755 index 0000000000000..929a8484bf9b4 --- /dev/null +++ 
b/sql/core/src/test/resources/tpcds/q14b.sql @@ -0,0 +1,95 @@ +WITH cross_items AS +(SELECT i_item_sk ss_item_sk + FROM item, + (SELECT + iss.i_brand_id brand_id, + iss.i_class_id class_id, + iss.i_category_id category_id + FROM store_sales, item iss, date_dim d1 + WHERE ss_item_sk = iss.i_item_sk + AND ss_sold_date_sk = d1.d_date_sk + AND d1.d_year BETWEEN 1999 AND 1999 + 2 + INTERSECT + SELECT + ics.i_brand_id, + ics.i_class_id, + ics.i_category_id + FROM catalog_sales, item ics, date_dim d2 + WHERE cs_item_sk = ics.i_item_sk + AND cs_sold_date_sk = d2.d_date_sk + AND d2.d_year BETWEEN 1999 AND 1999 + 2 + INTERSECT + SELECT + iws.i_brand_id, + iws.i_class_id, + iws.i_category_id + FROM web_sales, item iws, date_dim d3 + WHERE ws_item_sk = iws.i_item_sk + AND ws_sold_date_sk = d3.d_date_sk + AND d3.d_year BETWEEN 1999 AND 1999 + 2) x + WHERE i_brand_id = brand_id + AND i_class_id = class_id + AND i_category_id = category_id +), + avg_sales AS + (SELECT avg(quantity * list_price) average_sales + FROM (SELECT + ss_quantity quantity, + ss_list_price list_price + FROM store_sales, date_dim + WHERE ss_sold_date_sk = d_date_sk AND d_year BETWEEN 1999 AND 1999 + 2 + UNION ALL + SELECT + cs_quantity quantity, + cs_list_price list_price + FROM catalog_sales, date_dim + WHERE cs_sold_date_sk = d_date_sk AND d_year BETWEEN 1999 AND 1999 + 2 + UNION ALL + SELECT + ws_quantity quantity, + ws_list_price list_price + FROM web_sales, date_dim + WHERE ws_sold_date_sk = d_date_sk AND d_year BETWEEN 1999 AND 1999 + 2) x) +SELECT * +FROM + (SELECT + 'store' channel, + i_brand_id, + i_class_id, + i_category_id, + sum(ss_quantity * ss_list_price) sales, + count(*) number_sales + FROM store_sales, item, date_dim + WHERE ss_item_sk IN (SELECT ss_item_sk + FROM cross_items) + AND ss_item_sk = i_item_sk + AND ss_sold_date_sk = d_date_sk + AND d_week_seq = (SELECT d_week_seq + FROM date_dim + WHERE d_year = 1999 + 1 AND d_moy = 12 AND d_dom = 11) + GROUP BY i_brand_id, i_class_id, i_category_id + HAVING sum(ss_quantity * ss_list_price) > (SELECT average_sales + FROM avg_sales)) this_year, + (SELECT + 'store' channel, + i_brand_id, + i_class_id, + i_category_id, + sum(ss_quantity * ss_list_price) sales, + count(*) number_sales + FROM store_sales, item, date_dim + WHERE ss_item_sk IN (SELECT ss_item_sk + FROM cross_items) + AND ss_item_sk = i_item_sk + AND ss_sold_date_sk = d_date_sk + AND d_week_seq = (SELECT d_week_seq + FROM date_dim + WHERE d_year = 1999 AND d_moy = 12 AND d_dom = 11) + GROUP BY i_brand_id, i_class_id, i_category_id + HAVING sum(ss_quantity * ss_list_price) > (SELECT average_sales + FROM avg_sales)) last_year +WHERE this_year.i_brand_id = last_year.i_brand_id + AND this_year.i_class_id = last_year.i_class_id + AND this_year.i_category_id = last_year.i_category_id +ORDER BY this_year.channel, this_year.i_brand_id, this_year.i_class_id, this_year.i_category_id +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q15.sql b/sql/core/src/test/resources/tpcds/q15.sql new file mode 100755 index 0000000000000..b8182e23b0195 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q15.sql @@ -0,0 +1,15 @@ +SELECT + ca_zip, + sum(cs_sales_price) +FROM catalog_sales, customer, customer_address, date_dim +WHERE cs_bill_customer_sk = c_customer_sk + AND c_current_addr_sk = ca_address_sk + AND (substr(ca_zip, 1, 5) IN ('85669', '86197', '88274', '83405', '86475', + '85392', '85460', '80348', '81792') + OR ca_state IN ('CA', 'WA', 'GA') + OR cs_sales_price > 500) + AND cs_sold_date_sk = d_date_sk + AND d_qoy 
= 2 AND d_year = 2001 +GROUP BY ca_zip +ORDER BY ca_zip +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q16.sql b/sql/core/src/test/resources/tpcds/q16.sql new file mode 100755 index 0000000000000..732ad0d848071 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q16.sql @@ -0,0 +1,23 @@ +SELECT + count(DISTINCT cs_order_number) AS `order count `, + sum(cs_ext_ship_cost) AS `total shipping cost `, + sum(cs_net_profit) AS `total net profit ` +FROM + catalog_sales cs1, date_dim, customer_address, call_center +WHERE + d_date BETWEEN '2002-02-01' AND (CAST('2002-02-01' AS DATE) + INTERVAL 60 days) + AND cs1.cs_ship_date_sk = d_date_sk + AND cs1.cs_ship_addr_sk = ca_address_sk + AND ca_state = 'GA' + AND cs1.cs_call_center_sk = cc_call_center_sk + AND cc_county IN + ('Williamson County', 'Williamson County', 'Williamson County', 'Williamson County', 'Williamson County') + AND EXISTS(SELECT * + FROM catalog_sales cs2 + WHERE cs1.cs_order_number = cs2.cs_order_number + AND cs1.cs_warehouse_sk <> cs2.cs_warehouse_sk) + AND NOT EXISTS(SELECT * + FROM catalog_returns cr1 + WHERE cs1.cs_order_number = cr1.cr_order_number) +ORDER BY count(DISTINCT cs_order_number) +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q17.sql b/sql/core/src/test/resources/tpcds/q17.sql new file mode 100755 index 0000000000000..4d647f7956004 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q17.sql @@ -0,0 +1,33 @@ +SELECT + i_item_id, + i_item_desc, + s_state, + count(ss_quantity) AS store_sales_quantitycount, + avg(ss_quantity) AS store_sales_quantityave, + stddev_samp(ss_quantity) AS store_sales_quantitystdev, + stddev_samp(ss_quantity) / avg(ss_quantity) AS store_sales_quantitycov, + count(sr_return_quantity) as_store_returns_quantitycount, + avg(sr_return_quantity) as_store_returns_quantityave, + stddev_samp(sr_return_quantity) as_store_returns_quantitystdev, + stddev_samp(sr_return_quantity) / avg(sr_return_quantity) AS store_returns_quantitycov, + count(cs_quantity) AS catalog_sales_quantitycount, + avg(cs_quantity) AS catalog_sales_quantityave, + stddev_samp(cs_quantity) / avg(cs_quantity) AS catalog_sales_quantitystdev, + stddev_samp(cs_quantity) / avg(cs_quantity) AS catalog_sales_quantitycov +FROM store_sales, store_returns, catalog_sales, date_dim d1, date_dim d2, date_dim d3, store, item +WHERE d1.d_quarter_name = '2001Q1' + AND d1.d_date_sk = ss_sold_date_sk + AND i_item_sk = ss_item_sk + AND s_store_sk = ss_store_sk + AND ss_customer_sk = sr_customer_sk + AND ss_item_sk = sr_item_sk + AND ss_ticket_number = sr_ticket_number + AND sr_returned_date_sk = d2.d_date_sk + AND d2.d_quarter_name IN ('2001Q1', '2001Q2', '2001Q3') + AND sr_customer_sk = cs_bill_customer_sk + AND sr_item_sk = cs_item_sk + AND cs_sold_date_sk = d3.d_date_sk + AND d3.d_quarter_name IN ('2001Q1', '2001Q2', '2001Q3') +GROUP BY i_item_id, i_item_desc, s_state +ORDER BY i_item_id, i_item_desc, s_state +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q18.sql b/sql/core/src/test/resources/tpcds/q18.sql new file mode 100755 index 0000000000000..4055c80fdef51 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q18.sql @@ -0,0 +1,28 @@ +SELECT + i_item_id, + ca_country, + ca_state, + ca_county, + avg(cast(cs_quantity AS DECIMAL(12, 2))) agg1, + avg(cast(cs_list_price AS DECIMAL(12, 2))) agg2, + avg(cast(cs_coupon_amt AS DECIMAL(12, 2))) agg3, + avg(cast(cs_sales_price AS DECIMAL(12, 2))) agg4, + avg(cast(cs_net_profit AS DECIMAL(12, 2))) agg5, + avg(cast(c_birth_year AS DECIMAL(12, 2))) agg6, + 
avg(cast(cd1.cd_dep_count AS DECIMAL(12, 2))) agg7 +FROM catalog_sales, customer_demographics cd1, + customer_demographics cd2, customer, customer_address, date_dim, item +WHERE cs_sold_date_sk = d_date_sk AND + cs_item_sk = i_item_sk AND + cs_bill_cdemo_sk = cd1.cd_demo_sk AND + cs_bill_customer_sk = c_customer_sk AND + cd1.cd_gender = 'F' AND + cd1.cd_education_status = 'Unknown' AND + c_current_cdemo_sk = cd2.cd_demo_sk AND + c_current_addr_sk = ca_address_sk AND + c_birth_month IN (1, 6, 8, 9, 12, 2) AND + d_year = 1998 AND + ca_state IN ('MS', 'IN', 'ND', 'OK', 'NM', 'VA', 'MS') +GROUP BY ROLLUP (i_item_id, ca_country, ca_state, ca_county) +ORDER BY ca_country, ca_state, ca_county, i_item_id +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q19.sql b/sql/core/src/test/resources/tpcds/q19.sql new file mode 100755 index 0000000000000..e38ab7f2683f4 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q19.sql @@ -0,0 +1,19 @@ +SELECT + i_brand_id brand_id, + i_brand brand, + i_manufact_id, + i_manufact, + sum(ss_ext_sales_price) ext_price +FROM date_dim, store_sales, item, customer, customer_address, store +WHERE d_date_sk = ss_sold_date_sk + AND ss_item_sk = i_item_sk + AND i_manager_id = 8 + AND d_moy = 11 + AND d_year = 1998 + AND ss_customer_sk = c_customer_sk + AND c_current_addr_sk = ca_address_sk + AND substr(ca_zip, 1, 5) <> substr(s_zip, 1, 5) + AND ss_store_sk = s_store_sk +GROUP BY i_brand, i_brand_id, i_manufact_id, i_manufact +ORDER BY ext_price DESC, brand, brand_id, i_manufact_id, i_manufact +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q2.sql b/sql/core/src/test/resources/tpcds/q2.sql new file mode 100755 index 0000000000000..52c0e90c46740 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q2.sql @@ -0,0 +1,81 @@ +WITH wscs AS +( SELECT + sold_date_sk, + sales_price + FROM (SELECT + ws_sold_date_sk sold_date_sk, + ws_ext_sales_price sales_price + FROM web_sales) x + UNION ALL + (SELECT + cs_sold_date_sk sold_date_sk, + cs_ext_sales_price sales_price + FROM catalog_sales)), + wswscs AS + ( SELECT + d_week_seq, + sum(CASE WHEN (d_day_name = 'Sunday') + THEN sales_price + ELSE NULL END) + sun_sales, + sum(CASE WHEN (d_day_name = 'Monday') + THEN sales_price + ELSE NULL END) + mon_sales, + sum(CASE WHEN (d_day_name = 'Tuesday') + THEN sales_price + ELSE NULL END) + tue_sales, + sum(CASE WHEN (d_day_name = 'Wednesday') + THEN sales_price + ELSE NULL END) + wed_sales, + sum(CASE WHEN (d_day_name = 'Thursday') + THEN sales_price + ELSE NULL END) + thu_sales, + sum(CASE WHEN (d_day_name = 'Friday') + THEN sales_price + ELSE NULL END) + fri_sales, + sum(CASE WHEN (d_day_name = 'Saturday') + THEN sales_price + ELSE NULL END) + sat_sales + FROM wscs, date_dim + WHERE d_date_sk = sold_date_sk + GROUP BY d_week_seq) +SELECT + d_week_seq1, + round(sun_sales1 / sun_sales2, 2), + round(mon_sales1 / mon_sales2, 2), + round(tue_sales1 / tue_sales2, 2), + round(wed_sales1 / wed_sales2, 2), + round(thu_sales1 / thu_sales2, 2), + round(fri_sales1 / fri_sales2, 2), + round(sat_sales1 / sat_sales2, 2) +FROM + (SELECT + wswscs.d_week_seq d_week_seq1, + sun_sales sun_sales1, + mon_sales mon_sales1, + tue_sales tue_sales1, + wed_sales wed_sales1, + thu_sales thu_sales1, + fri_sales fri_sales1, + sat_sales sat_sales1 + FROM wswscs, date_dim + WHERE date_dim.d_week_seq = wswscs.d_week_seq AND d_year = 2001) y, + (SELECT + wswscs.d_week_seq d_week_seq2, + sun_sales sun_sales2, + mon_sales mon_sales2, + tue_sales tue_sales2, + wed_sales wed_sales2, + thu_sales thu_sales2, + 
fri_sales fri_sales2, + sat_sales sat_sales2 + FROM wswscs, date_dim + WHERE date_dim.d_week_seq = wswscs.d_week_seq AND d_year = 2001 + 1) z +WHERE d_week_seq1 = d_week_seq2 - 53 +ORDER BY d_week_seq1 diff --git a/sql/core/src/test/resources/tpcds/q20.sql b/sql/core/src/test/resources/tpcds/q20.sql new file mode 100755 index 0000000000000..7ac6c7a75d8ea --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q20.sql @@ -0,0 +1,18 @@ +SELECT + i_item_desc, + i_category, + i_class, + i_current_price, + sum(cs_ext_sales_price) AS itemrevenue, + sum(cs_ext_sales_price) * 100 / sum(sum(cs_ext_sales_price)) + OVER + (PARTITION BY i_class) AS revenueratio +FROM catalog_sales, item, date_dim +WHERE cs_item_sk = i_item_sk + AND i_category IN ('Sports', 'Books', 'Home') + AND cs_sold_date_sk = d_date_sk + AND d_date BETWEEN cast('1999-02-22' AS DATE) +AND (cast('1999-02-22' AS DATE) + INTERVAL 30 days) +GROUP BY i_item_id, i_item_desc, i_category, i_class, i_current_price +ORDER BY i_category, i_class, i_item_id, i_item_desc, revenueratio +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q21.sql b/sql/core/src/test/resources/tpcds/q21.sql new file mode 100755 index 0000000000000..550881143f809 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q21.sql @@ -0,0 +1,25 @@ +SELECT * +FROM ( + SELECT + w_warehouse_name, + i_item_id, + sum(CASE WHEN (cast(d_date AS DATE) < cast('2000-03-11' AS DATE)) + THEN inv_quantity_on_hand + ELSE 0 END) AS inv_before, + sum(CASE WHEN (cast(d_date AS DATE) >= cast('2000-03-11' AS DATE)) + THEN inv_quantity_on_hand + ELSE 0 END) AS inv_after + FROM inventory, warehouse, item, date_dim + WHERE i_current_price BETWEEN 0.99 AND 1.49 + AND i_item_sk = inv_item_sk + AND inv_warehouse_sk = w_warehouse_sk + AND inv_date_sk = d_date_sk + AND d_date BETWEEN (cast('2000-03-11' AS DATE) - INTERVAL 30 days) + AND (cast('2000-03-11' AS DATE) + INTERVAL 30 days) + GROUP BY w_warehouse_name, i_item_id) x +WHERE (CASE WHEN inv_before > 0 + THEN inv_after / inv_before + ELSE NULL + END) BETWEEN 2.0 / 3.0 AND 3.0 / 2.0 +ORDER BY w_warehouse_name, i_item_id +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q22.sql b/sql/core/src/test/resources/tpcds/q22.sql new file mode 100755 index 0000000000000..add3b41f7c76c --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q22.sql @@ -0,0 +1,14 @@ +SELECT + i_product_name, + i_brand, + i_class, + i_category, + avg(inv_quantity_on_hand) qoh +FROM inventory, date_dim, item, warehouse +WHERE inv_date_sk = d_date_sk + AND inv_item_sk = i_item_sk + AND inv_warehouse_sk = w_warehouse_sk + AND d_month_seq BETWEEN 1200 AND 1200 + 11 +GROUP BY ROLLUP (i_product_name, i_brand, i_class, i_category) +ORDER BY qoh, i_product_name, i_brand, i_class, i_category +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q23a.sql b/sql/core/src/test/resources/tpcds/q23a.sql new file mode 100755 index 0000000000000..37791f643375c --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q23a.sql @@ -0,0 +1,53 @@ +WITH frequent_ss_items AS +(SELECT + substr(i_item_desc, 1, 30) itemdesc, + i_item_sk item_sk, + d_date solddate, + count(*) cnt + FROM store_sales, date_dim, item + WHERE ss_sold_date_sk = d_date_sk + AND ss_item_sk = i_item_sk + AND d_year IN (2000, 2000 + 1, 2000 + 2, 2000 + 3) + GROUP BY substr(i_item_desc, 1, 30), i_item_sk, d_date + HAVING count(*) > 4), + max_store_sales AS + (SELECT max(csales) tpcds_cmax + FROM (SELECT + c_customer_sk, + sum(ss_quantity * ss_sales_price) csales + FROM store_sales, customer, date_dim + WHERE 
ss_customer_sk = c_customer_sk + AND ss_sold_date_sk = d_date_sk + AND d_year IN (2000, 2000 + 1, 2000 + 2, 2000 + 3) + GROUP BY c_customer_sk) x), + best_ss_customer AS + (SELECT + c_customer_sk, + sum(ss_quantity * ss_sales_price) ssales + FROM store_sales, customer + WHERE ss_customer_sk = c_customer_sk + GROUP BY c_customer_sk + HAVING sum(ss_quantity * ss_sales_price) > (50 / 100.0) * + (SELECT * + FROM max_store_sales)) +SELECT sum(sales) +FROM ((SELECT cs_quantity * cs_list_price sales +FROM catalog_sales, date_dim +WHERE d_year = 2000 + AND d_moy = 2 + AND cs_sold_date_sk = d_date_sk + AND cs_item_sk IN (SELECT item_sk +FROM frequent_ss_items) + AND cs_bill_customer_sk IN (SELECT c_customer_sk +FROM best_ss_customer)) + UNION ALL + (SELECT ws_quantity * ws_list_price sales + FROM web_sales, date_dim + WHERE d_year = 2000 + AND d_moy = 2 + AND ws_sold_date_sk = d_date_sk + AND ws_item_sk IN (SELECT item_sk + FROM frequent_ss_items) + AND ws_bill_customer_sk IN (SELECT c_customer_sk + FROM best_ss_customer))) y +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q23b.sql b/sql/core/src/test/resources/tpcds/q23b.sql new file mode 100755 index 0000000000000..01150197af2ba --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q23b.sql @@ -0,0 +1,68 @@ +WITH frequent_ss_items AS +(SELECT + substr(i_item_desc, 1, 30) itemdesc, + i_item_sk item_sk, + d_date solddate, + count(*) cnt + FROM store_sales, date_dim, item + WHERE ss_sold_date_sk = d_date_sk + AND ss_item_sk = i_item_sk + AND d_year IN (2000, 2000 + 1, 2000 + 2, 2000 + 3) + GROUP BY substr(i_item_desc, 1, 30), i_item_sk, d_date + HAVING count(*) > 4), + max_store_sales AS + (SELECT max(csales) tpcds_cmax + FROM (SELECT + c_customer_sk, + sum(ss_quantity * ss_sales_price) csales + FROM store_sales, customer, date_dim + WHERE ss_customer_sk = c_customer_sk + AND ss_sold_date_sk = d_date_sk + AND d_year IN (2000, 2000 + 1, 2000 + 2, 2000 + 3) + GROUP BY c_customer_sk) x), + best_ss_customer AS + (SELECT + c_customer_sk, + sum(ss_quantity * ss_sales_price) ssales + FROM store_sales + , customer + WHERE ss_customer_sk = c_customer_sk + GROUP BY c_customer_sk + HAVING sum(ss_quantity * ss_sales_price) > (50 / 100.0) * + (SELECT * + FROM max_store_sales)) +SELECT + c_last_name, + c_first_name, + sales +FROM ((SELECT + c_last_name, + c_first_name, + sum(cs_quantity * cs_list_price) sales +FROM catalog_sales, customer, date_dim +WHERE d_year = 2000 + AND d_moy = 2 + AND cs_sold_date_sk = d_date_sk + AND cs_item_sk IN (SELECT item_sk +FROM frequent_ss_items) + AND cs_bill_customer_sk IN (SELECT c_customer_sk +FROM best_ss_customer) + AND cs_bill_customer_sk = c_customer_sk +GROUP BY c_last_name, c_first_name) + UNION ALL + (SELECT + c_last_name, + c_first_name, + sum(ws_quantity * ws_list_price) sales + FROM web_sales, customer, date_dim + WHERE d_year = 2000 + AND d_moy = 2 + AND ws_sold_date_sk = d_date_sk + AND ws_item_sk IN (SELECT item_sk + FROM frequent_ss_items) + AND ws_bill_customer_sk IN (SELECT c_customer_sk + FROM best_ss_customer) + AND ws_bill_customer_sk = c_customer_sk + GROUP BY c_last_name, c_first_name)) y +ORDER BY c_last_name, c_first_name, sales +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q24a.sql b/sql/core/src/test/resources/tpcds/q24a.sql new file mode 100755 index 0000000000000..bcc189486634d --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q24a.sql @@ -0,0 +1,34 @@ +WITH ssales AS +(SELECT + c_last_name, + c_first_name, + s_store_name, + ca_state, + s_state, + i_color, + i_current_price, 
+ i_manager_id, + i_units, + i_size, + sum(ss_net_paid) netpaid + FROM store_sales, store_returns, store, item, customer, customer_address + WHERE ss_ticket_number = sr_ticket_number + AND ss_item_sk = sr_item_sk + AND ss_customer_sk = c_customer_sk + AND ss_item_sk = i_item_sk + AND ss_store_sk = s_store_sk + AND c_birth_country = upper(ca_country) + AND s_zip = ca_zip + AND s_market_id = 8 + GROUP BY c_last_name, c_first_name, s_store_name, ca_state, s_state, i_color, + i_current_price, i_manager_id, i_units, i_size) +SELECT + c_last_name, + c_first_name, + s_store_name, + sum(netpaid) paid +FROM ssales +WHERE i_color = 'pale' +GROUP BY c_last_name, c_first_name, s_store_name +HAVING sum(netpaid) > (SELECT 0.05 * avg(netpaid) +FROM ssales) diff --git a/sql/core/src/test/resources/tpcds/q24b.sql b/sql/core/src/test/resources/tpcds/q24b.sql new file mode 100755 index 0000000000000..830eb670bcdd2 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q24b.sql @@ -0,0 +1,34 @@ +WITH ssales AS +(SELECT + c_last_name, + c_first_name, + s_store_name, + ca_state, + s_state, + i_color, + i_current_price, + i_manager_id, + i_units, + i_size, + sum(ss_net_paid) netpaid + FROM store_sales, store_returns, store, item, customer, customer_address + WHERE ss_ticket_number = sr_ticket_number + AND ss_item_sk = sr_item_sk + AND ss_customer_sk = c_customer_sk + AND ss_item_sk = i_item_sk + AND ss_store_sk = s_store_sk + AND c_birth_country = upper(ca_country) + AND s_zip = ca_zip + AND s_market_id = 8 + GROUP BY c_last_name, c_first_name, s_store_name, ca_state, s_state, + i_color, i_current_price, i_manager_id, i_units, i_size) +SELECT + c_last_name, + c_first_name, + s_store_name, + sum(netpaid) paid +FROM ssales +WHERE i_color = 'chiffon' +GROUP BY c_last_name, c_first_name, s_store_name +HAVING sum(netpaid) > (SELECT 0.05 * avg(netpaid) +FROM ssales) diff --git a/sql/core/src/test/resources/tpcds/q25.sql b/sql/core/src/test/resources/tpcds/q25.sql new file mode 100755 index 0000000000000..a4d78a3c56adc --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q25.sql @@ -0,0 +1,33 @@ +SELECT + i_item_id, + i_item_desc, + s_store_id, + s_store_name, + sum(ss_net_profit) AS store_sales_profit, + sum(sr_net_loss) AS store_returns_loss, + sum(cs_net_profit) AS catalog_sales_profit +FROM + store_sales, store_returns, catalog_sales, date_dim d1, date_dim d2, date_dim d3, + store, item +WHERE + d1.d_moy = 4 + AND d1.d_year = 2001 + AND d1.d_date_sk = ss_sold_date_sk + AND i_item_sk = ss_item_sk + AND s_store_sk = ss_store_sk + AND ss_customer_sk = sr_customer_sk + AND ss_item_sk = sr_item_sk + AND ss_ticket_number = sr_ticket_number + AND sr_returned_date_sk = d2.d_date_sk + AND d2.d_moy BETWEEN 4 AND 10 + AND d2.d_year = 2001 + AND sr_customer_sk = cs_bill_customer_sk + AND sr_item_sk = cs_item_sk + AND cs_sold_date_sk = d3.d_date_sk + AND d3.d_moy BETWEEN 4 AND 10 + AND d3.d_year = 2001 +GROUP BY + i_item_id, i_item_desc, s_store_id, s_store_name +ORDER BY + i_item_id, i_item_desc, s_store_id, s_store_name +LIMIT 100 \ No newline at end of file diff --git a/sql/core/src/test/resources/tpcds/q26.sql b/sql/core/src/test/resources/tpcds/q26.sql new file mode 100755 index 0000000000000..6d395a1d791dd --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q26.sql @@ -0,0 +1,19 @@ +SELECT + i_item_id, + avg(cs_quantity) agg1, + avg(cs_list_price) agg2, + avg(cs_coupon_amt) agg3, + avg(cs_sales_price) agg4 +FROM catalog_sales, customer_demographics, date_dim, item, promotion +WHERE cs_sold_date_sk = d_date_sk AND + 
cs_item_sk = i_item_sk AND + cs_bill_cdemo_sk = cd_demo_sk AND + cs_promo_sk = p_promo_sk AND + cd_gender = 'M' AND + cd_marital_status = 'S' AND + cd_education_status = 'College' AND + (p_channel_email = 'N' OR p_channel_event = 'N') AND + d_year = 2000 +GROUP BY i_item_id +ORDER BY i_item_id +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q27.sql b/sql/core/src/test/resources/tpcds/q27.sql new file mode 100755 index 0000000000000..b0e2fd95fd159 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q27.sql @@ -0,0 +1,21 @@ +SELECT + i_item_id, + s_state, + grouping(s_state) g_state, + avg(ss_quantity) agg1, + avg(ss_list_price) agg2, + avg(ss_coupon_amt) agg3, + avg(ss_sales_price) agg4 +FROM store_sales, customer_demographics, date_dim, store, item +WHERE ss_sold_date_sk = d_date_sk AND + ss_item_sk = i_item_sk AND + ss_store_sk = s_store_sk AND + ss_cdemo_sk = cd_demo_sk AND + cd_gender = 'M' AND + cd_marital_status = 'S' AND + cd_education_status = 'College' AND + d_year = 2002 AND + s_state IN ('TN', 'TN', 'TN', 'TN', 'TN', 'TN') +GROUP BY ROLLUP (i_item_id, s_state) +ORDER BY i_item_id, s_state +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q28.sql b/sql/core/src/test/resources/tpcds/q28.sql new file mode 100755 index 0000000000000..f34c2bb0e34e1 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q28.sql @@ -0,0 +1,56 @@ +SELECT * +FROM (SELECT + avg(ss_list_price) B1_LP, + count(ss_list_price) B1_CNT, + count(DISTINCT ss_list_price) B1_CNTD +FROM store_sales +WHERE ss_quantity BETWEEN 0 AND 5 + AND (ss_list_price BETWEEN 8 AND 8 + 10 + OR ss_coupon_amt BETWEEN 459 AND 459 + 1000 + OR ss_wholesale_cost BETWEEN 57 AND 57 + 20)) B1, + (SELECT + avg(ss_list_price) B2_LP, + count(ss_list_price) B2_CNT, + count(DISTINCT ss_list_price) B2_CNTD + FROM store_sales + WHERE ss_quantity BETWEEN 6 AND 10 + AND (ss_list_price BETWEEN 90 AND 90 + 10 + OR ss_coupon_amt BETWEEN 2323 AND 2323 + 1000 + OR ss_wholesale_cost BETWEEN 31 AND 31 + 20)) B2, + (SELECT + avg(ss_list_price) B3_LP, + count(ss_list_price) B3_CNT, + count(DISTINCT ss_list_price) B3_CNTD + FROM store_sales + WHERE ss_quantity BETWEEN 11 AND 15 + AND (ss_list_price BETWEEN 142 AND 142 + 10 + OR ss_coupon_amt BETWEEN 12214 AND 12214 + 1000 + OR ss_wholesale_cost BETWEEN 79 AND 79 + 20)) B3, + (SELECT + avg(ss_list_price) B4_LP, + count(ss_list_price) B4_CNT, + count(DISTINCT ss_list_price) B4_CNTD + FROM store_sales + WHERE ss_quantity BETWEEN 16 AND 20 + AND (ss_list_price BETWEEN 135 AND 135 + 10 + OR ss_coupon_amt BETWEEN 6071 AND 6071 + 1000 + OR ss_wholesale_cost BETWEEN 38 AND 38 + 20)) B4, + (SELECT + avg(ss_list_price) B5_LP, + count(ss_list_price) B5_CNT, + count(DISTINCT ss_list_price) B5_CNTD + FROM store_sales + WHERE ss_quantity BETWEEN 21 AND 25 + AND (ss_list_price BETWEEN 122 AND 122 + 10 + OR ss_coupon_amt BETWEEN 836 AND 836 + 1000 + OR ss_wholesale_cost BETWEEN 17 AND 17 + 20)) B5, + (SELECT + avg(ss_list_price) B6_LP, + count(ss_list_price) B6_CNT, + count(DISTINCT ss_list_price) B6_CNTD + FROM store_sales + WHERE ss_quantity BETWEEN 26 AND 30 + AND (ss_list_price BETWEEN 154 AND 154 + 10 + OR ss_coupon_amt BETWEEN 7326 AND 7326 + 1000 + OR ss_wholesale_cost BETWEEN 7 AND 7 + 20)) B6 +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q29.sql b/sql/core/src/test/resources/tpcds/q29.sql new file mode 100755 index 0000000000000..3f1fd553f6da8 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q29.sql @@ -0,0 +1,32 @@ +SELECT + i_item_id, + i_item_desc, + s_store_id, + s_store_name, 
+ sum(ss_quantity) AS store_sales_quantity, + sum(sr_return_quantity) AS store_returns_quantity, + sum(cs_quantity) AS catalog_sales_quantity +FROM + store_sales, store_returns, catalog_sales, date_dim d1, date_dim d2, + date_dim d3, store, item +WHERE + d1.d_moy = 9 + AND d1.d_year = 1999 + AND d1.d_date_sk = ss_sold_date_sk + AND i_item_sk = ss_item_sk + AND s_store_sk = ss_store_sk + AND ss_customer_sk = sr_customer_sk + AND ss_item_sk = sr_item_sk + AND ss_ticket_number = sr_ticket_number + AND sr_returned_date_sk = d2.d_date_sk + AND d2.d_moy BETWEEN 9 AND 9 + 3 + AND d2.d_year = 1999 + AND sr_customer_sk = cs_bill_customer_sk + AND sr_item_sk = cs_item_sk + AND cs_sold_date_sk = d3.d_date_sk + AND d3.d_year IN (1999, 1999 + 1, 1999 + 2) +GROUP BY + i_item_id, i_item_desc, s_store_id, s_store_name +ORDER BY + i_item_id, i_item_desc, s_store_id, s_store_name +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q3.sql b/sql/core/src/test/resources/tpcds/q3.sql new file mode 100755 index 0000000000000..181509df9deb7 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q3.sql @@ -0,0 +1,13 @@ +SELECT + dt.d_year, + item.i_brand_id brand_id, + item.i_brand brand, + SUM(ss_ext_sales_price) sum_agg +FROM date_dim dt, store_sales, item +WHERE dt.d_date_sk = store_sales.ss_sold_date_sk + AND store_sales.ss_item_sk = item.i_item_sk + AND item.i_manufact_id = 128 + AND dt.d_moy = 11 +GROUP BY dt.d_year, item.i_brand, item.i_brand_id +ORDER BY dt.d_year, sum_agg DESC, brand_id +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q30.sql b/sql/core/src/test/resources/tpcds/q30.sql new file mode 100755 index 0000000000000..986bef566d2c8 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q30.sql @@ -0,0 +1,35 @@ +WITH customer_total_return AS +(SELECT + wr_returning_customer_sk AS ctr_customer_sk, + ca_state AS ctr_state, + sum(wr_return_amt) AS ctr_total_return + FROM web_returns, date_dim, customer_address + WHERE wr_returned_date_sk = d_date_sk + AND d_year = 2002 + AND wr_returning_addr_sk = ca_address_sk + GROUP BY wr_returning_customer_sk, ca_state) +SELECT + c_customer_id, + c_salutation, + c_first_name, + c_last_name, + c_preferred_cust_flag, + c_birth_day, + c_birth_month, + c_birth_year, + c_birth_country, + c_login, + c_email_address, + c_last_review_date, + ctr_total_return +FROM customer_total_return ctr1, customer_address, customer +WHERE ctr1.ctr_total_return > (SELECT avg(ctr_total_return) * 1.2 +FROM customer_total_return ctr2 +WHERE ctr1.ctr_state = ctr2.ctr_state) + AND ca_address_sk = c_current_addr_sk + AND ca_state = 'GA' + AND ctr1.ctr_customer_sk = c_customer_sk +ORDER BY c_customer_id, c_salutation, c_first_name, c_last_name, c_preferred_cust_flag + , c_birth_day, c_birth_month, c_birth_year, c_birth_country, c_login, c_email_address + , c_last_review_date, ctr_total_return +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q31.sql b/sql/core/src/test/resources/tpcds/q31.sql new file mode 100755 index 0000000000000..3e543d5436407 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q31.sql @@ -0,0 +1,60 @@ +WITH ss AS +(SELECT + ca_county, + d_qoy, + d_year, + sum(ss_ext_sales_price) AS store_sales + FROM store_sales, date_dim, customer_address + WHERE ss_sold_date_sk = d_date_sk + AND ss_addr_sk = ca_address_sk + GROUP BY ca_county, d_qoy, d_year), + ws AS + (SELECT + ca_county, + d_qoy, + d_year, + sum(ws_ext_sales_price) AS web_sales + FROM web_sales, date_dim, customer_address + WHERE ws_sold_date_sk = d_date_sk + AND ws_bill_addr_sk = ca_address_sk 
+ GROUP BY ca_county, d_qoy, d_year) +SELECT + ss1.ca_county, + ss1.d_year, + ws2.web_sales / ws1.web_sales web_q1_q2_increase, + ss2.store_sales / ss1.store_sales store_q1_q2_increase, + ws3.web_sales / ws2.web_sales web_q2_q3_increase, + ss3.store_sales / ss2.store_sales store_q2_q3_increase +FROM + ss ss1, ss ss2, ss ss3, ws ws1, ws ws2, ws ws3 +WHERE + ss1.d_qoy = 1 + AND ss1.d_year = 2000 + AND ss1.ca_county = ss2.ca_county + AND ss2.d_qoy = 2 + AND ss2.d_year = 2000 + AND ss2.ca_county = ss3.ca_county + AND ss3.d_qoy = 3 + AND ss3.d_year = 2000 + AND ss1.ca_county = ws1.ca_county + AND ws1.d_qoy = 1 + AND ws1.d_year = 2000 + AND ws1.ca_county = ws2.ca_county + AND ws2.d_qoy = 2 + AND ws2.d_year = 2000 + AND ws1.ca_county = ws3.ca_county + AND ws3.d_qoy = 3 + AND ws3.d_year = 2000 + AND CASE WHEN ws1.web_sales > 0 + THEN ws2.web_sales / ws1.web_sales + ELSE NULL END + > CASE WHEN ss1.store_sales > 0 + THEN ss2.store_sales / ss1.store_sales + ELSE NULL END + AND CASE WHEN ws2.web_sales > 0 + THEN ws3.web_sales / ws2.web_sales + ELSE NULL END + > CASE WHEN ss2.store_sales > 0 + THEN ss3.store_sales / ss2.store_sales + ELSE NULL END +ORDER BY ss1.ca_county diff --git a/sql/core/src/test/resources/tpcds/q32.sql b/sql/core/src/test/resources/tpcds/q32.sql new file mode 100755 index 0000000000000..1a907961e74bb --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q32.sql @@ -0,0 +1,15 @@ +SELECT 1 AS `excess discount amount ` +FROM + catalog_sales, item, date_dim +WHERE + i_manufact_id = 977 + AND i_item_sk = cs_item_sk + AND d_date BETWEEN '2000-01-27' AND (cast('2000-01-27' AS DATE) + interval 90 days) + AND d_date_sk = cs_sold_date_sk + AND cs_ext_discount_amt > ( + SELECT 1.3 * avg(cs_ext_discount_amt) + FROM catalog_sales, date_dim + WHERE cs_item_sk = i_item_sk + AND d_date BETWEEN '2000-01-27' AND (cast('2000-01-27' AS DATE) + interval 90 days) + AND d_date_sk = cs_sold_date_sk) +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q33.sql b/sql/core/src/test/resources/tpcds/q33.sql new file mode 100755 index 0000000000000..d24856aa5c1eb --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q33.sql @@ -0,0 +1,65 @@ +WITH ss AS ( + SELECT + i_manufact_id, + sum(ss_ext_sales_price) total_sales + FROM + store_sales, date_dim, customer_address, item + WHERE + i_manufact_id IN (SELECT i_manufact_id + FROM item + WHERE i_category IN ('Electronics')) + AND ss_item_sk = i_item_sk + AND ss_sold_date_sk = d_date_sk + AND d_year = 1998 + AND d_moy = 5 + AND ss_addr_sk = ca_address_sk + AND ca_gmt_offset = -5 + GROUP BY i_manufact_id), cs AS +(SELECT + i_manufact_id, + sum(cs_ext_sales_price) total_sales + FROM catalog_sales, date_dim, customer_address, item + WHERE + i_manufact_id IN ( + SELECT i_manufact_id + FROM item + WHERE + i_category IN ('Electronics')) + AND cs_item_sk = i_item_sk + AND cs_sold_date_sk = d_date_sk + AND d_year = 1998 + AND d_moy = 5 + AND cs_bill_addr_sk = ca_address_sk + AND ca_gmt_offset = -5 + GROUP BY i_manufact_id), + ws AS ( + SELECT + i_manufact_id, + sum(ws_ext_sales_price) total_sales + FROM + web_sales, date_dim, customer_address, item + WHERE + i_manufact_id IN (SELECT i_manufact_id + FROM item + WHERE i_category IN ('Electronics')) + AND ws_item_sk = i_item_sk + AND ws_sold_date_sk = d_date_sk + AND d_year = 1998 + AND d_moy = 5 + AND ws_bill_addr_sk = ca_address_sk + AND ca_gmt_offset = -5 + GROUP BY i_manufact_id) +SELECT + i_manufact_id, + sum(total_sales) total_sales +FROM (SELECT * + FROM ss + UNION ALL + SELECT * + FROM cs + UNION ALL + SELECT * 
+ FROM ws) tmp1 +GROUP BY i_manufact_id +ORDER BY total_sales +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q34.sql b/sql/core/src/test/resources/tpcds/q34.sql new file mode 100755 index 0000000000000..33396bf16e574 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q34.sql @@ -0,0 +1,32 @@ +SELECT + c_last_name, + c_first_name, + c_salutation, + c_preferred_cust_flag, + ss_ticket_number, + cnt +FROM + (SELECT + ss_ticket_number, + ss_customer_sk, + count(*) cnt + FROM store_sales, date_dim, store, household_demographics + WHERE store_sales.ss_sold_date_sk = date_dim.d_date_sk + AND store_sales.ss_store_sk = store.s_store_sk + AND store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk + AND (date_dim.d_dom BETWEEN 1 AND 3 OR date_dim.d_dom BETWEEN 25 AND 28) + AND (household_demographics.hd_buy_potential = '>10000' OR + household_demographics.hd_buy_potential = 'unknown') + AND household_demographics.hd_vehicle_count > 0 + AND (CASE WHEN household_demographics.hd_vehicle_count > 0 + THEN household_demographics.hd_dep_count / household_demographics.hd_vehicle_count + ELSE NULL + END) > 1.2 + AND date_dim.d_year IN (1999, 1999 + 1, 1999 + 2) + AND store.s_county IN + ('Williamson County', 'Williamson County', 'Williamson County', 'Williamson County', + 'Williamson County', 'Williamson County', 'Williamson County', 'Williamson County') + GROUP BY ss_ticket_number, ss_customer_sk) dn, customer +WHERE ss_customer_sk = c_customer_sk + AND cnt BETWEEN 15 AND 20 +ORDER BY c_last_name, c_first_name, c_salutation, c_preferred_cust_flag DESC diff --git a/sql/core/src/test/resources/tpcds/q35.sql b/sql/core/src/test/resources/tpcds/q35.sql new file mode 100755 index 0000000000000..cfe4342d8be86 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q35.sql @@ -0,0 +1,46 @@ +SELECT + ca_state, + cd_gender, + cd_marital_status, + count(*) cnt1, + min(cd_dep_count), + max(cd_dep_count), + avg(cd_dep_count), + cd_dep_employed_count, + count(*) cnt2, + min(cd_dep_employed_count), + max(cd_dep_employed_count), + avg(cd_dep_employed_count), + cd_dep_college_count, + count(*) cnt3, + min(cd_dep_college_count), + max(cd_dep_college_count), + avg(cd_dep_college_count) +FROM + customer c, customer_address ca, customer_demographics +WHERE + c.c_current_addr_sk = ca.ca_address_sk AND + cd_demo_sk = c.c_current_cdemo_sk AND + exists(SELECT * + FROM store_sales, date_dim + WHERE c.c_customer_sk = ss_customer_sk AND + ss_sold_date_sk = d_date_sk AND + d_year = 2002 AND + d_qoy < 4) AND + (exists(SELECT * + FROM web_sales, date_dim + WHERE c.c_customer_sk = ws_bill_customer_sk AND + ws_sold_date_sk = d_date_sk AND + d_year = 2002 AND + d_qoy < 4) OR + exists(SELECT * + FROM catalog_sales, date_dim + WHERE c.c_customer_sk = cs_ship_customer_sk AND + cs_sold_date_sk = d_date_sk AND + d_year = 2002 AND + d_qoy < 4)) +GROUP BY ca_state, cd_gender, cd_marital_status, cd_dep_count, + cd_dep_employed_count, cd_dep_college_count +ORDER BY ca_state, cd_gender, cd_marital_status, cd_dep_count, + cd_dep_employed_count, cd_dep_college_count +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q36.sql b/sql/core/src/test/resources/tpcds/q36.sql new file mode 100755 index 0000000000000..a8f93df76a34b --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q36.sql @@ -0,0 +1,26 @@ +SELECT + sum(ss_net_profit) / sum(ss_ext_sales_price) AS gross_margin, + i_category, + i_class, + grouping(i_category) + grouping(i_class) AS lochierarchy, + rank() + OVER ( + PARTITION BY grouping(i_category) + grouping(i_class), + 
CASE WHEN grouping(i_class) = 0 + THEN i_category END + ORDER BY sum(ss_net_profit) / sum(ss_ext_sales_price) ASC) AS rank_within_parent +FROM + store_sales, date_dim d1, item, store +WHERE + d1.d_year = 2001 + AND d1.d_date_sk = ss_sold_date_sk + AND i_item_sk = ss_item_sk + AND s_store_sk = ss_store_sk + AND s_state IN ('TN', 'TN', 'TN', 'TN', 'TN', 'TN', 'TN', 'TN') +GROUP BY ROLLUP (i_category, i_class) +ORDER BY + lochierarchy DESC + , CASE WHEN lochierarchy = 0 + THEN i_category END + , rank_within_parent +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q37.sql b/sql/core/src/test/resources/tpcds/q37.sql new file mode 100755 index 0000000000000..11b3821fa48b8 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q37.sql @@ -0,0 +1,15 @@ +SELECT + i_item_id, + i_item_desc, + i_current_price +FROM item, inventory, date_dim, catalog_sales +WHERE i_current_price BETWEEN 68 AND 68 + 30 + AND inv_item_sk = i_item_sk + AND d_date_sk = inv_date_sk + AND d_date BETWEEN cast('2000-02-01' AS DATE) AND (cast('2000-02-01' AS DATE) + INTERVAL 60 days) + AND i_manufact_id IN (677, 940, 694, 808) + AND inv_quantity_on_hand BETWEEN 100 AND 500 + AND cs_item_sk = i_item_sk +GROUP BY i_item_id, i_item_desc, i_current_price +ORDER BY i_item_id +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q38.sql b/sql/core/src/test/resources/tpcds/q38.sql new file mode 100755 index 0000000000000..1c8d53ee2bbfc --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q38.sql @@ -0,0 +1,30 @@ +SELECT count(*) +FROM ( + SELECT DISTINCT + c_last_name, + c_first_name, + d_date + FROM store_sales, date_dim, customer + WHERE store_sales.ss_sold_date_sk = date_dim.d_date_sk + AND store_sales.ss_customer_sk = customer.c_customer_sk + AND d_month_seq BETWEEN 1200 AND 1200 + 11 + INTERSECT + SELECT DISTINCT + c_last_name, + c_first_name, + d_date + FROM catalog_sales, date_dim, customer + WHERE catalog_sales.cs_sold_date_sk = date_dim.d_date_sk + AND catalog_sales.cs_bill_customer_sk = customer.c_customer_sk + AND d_month_seq BETWEEN 1200 AND 1200 + 11 + INTERSECT + SELECT DISTINCT + c_last_name, + c_first_name, + d_date + FROM web_sales, date_dim, customer + WHERE web_sales.ws_sold_date_sk = date_dim.d_date_sk + AND web_sales.ws_bill_customer_sk = customer.c_customer_sk + AND d_month_seq BETWEEN 1200 AND 1200 + 11 + ) hot_cust +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q39a.sql b/sql/core/src/test/resources/tpcds/q39a.sql new file mode 100755 index 0000000000000..9fc4c1701cf21 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q39a.sql @@ -0,0 +1,47 @@ +WITH inv AS +(SELECT + w_warehouse_name, + w_warehouse_sk, + i_item_sk, + d_moy, + stdev, + mean, + CASE mean + WHEN 0 + THEN NULL + ELSE stdev / mean END cov + FROM (SELECT + w_warehouse_name, + w_warehouse_sk, + i_item_sk, + d_moy, + stddev_samp(inv_quantity_on_hand) stdev, + avg(inv_quantity_on_hand) mean + FROM inventory, item, warehouse, date_dim + WHERE inv_item_sk = i_item_sk + AND inv_warehouse_sk = w_warehouse_sk + AND inv_date_sk = d_date_sk + AND d_year = 2001 + GROUP BY w_warehouse_name, w_warehouse_sk, i_item_sk, d_moy) foo + WHERE CASE mean + WHEN 0 + THEN 0 + ELSE stdev / mean END > 1) +SELECT + inv1.w_warehouse_sk, + inv1.i_item_sk, + inv1.d_moy, + inv1.mean, + inv1.cov, + inv2.w_warehouse_sk, + inv2.i_item_sk, + inv2.d_moy, + inv2.mean, + inv2.cov +FROM inv inv1, inv inv2 +WHERE inv1.i_item_sk = inv2.i_item_sk + AND inv1.w_warehouse_sk = inv2.w_warehouse_sk + AND inv1.d_moy = 1 + AND inv2.d_moy = 1 + 1 +ORDER BY 
inv1.w_warehouse_sk, inv1.i_item_sk, inv1.d_moy, inv1.mean, inv1.cov + , inv2.d_moy, inv2.mean, inv2.cov diff --git a/sql/core/src/test/resources/tpcds/q39b.sql b/sql/core/src/test/resources/tpcds/q39b.sql new file mode 100755 index 0000000000000..6f8493029fab4 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q39b.sql @@ -0,0 +1,48 @@ +WITH inv AS +(SELECT + w_warehouse_name, + w_warehouse_sk, + i_item_sk, + d_moy, + stdev, + mean, + CASE mean + WHEN 0 + THEN NULL + ELSE stdev / mean END cov + FROM (SELECT + w_warehouse_name, + w_warehouse_sk, + i_item_sk, + d_moy, + stddev_samp(inv_quantity_on_hand) stdev, + avg(inv_quantity_on_hand) mean + FROM inventory, item, warehouse, date_dim + WHERE inv_item_sk = i_item_sk + AND inv_warehouse_sk = w_warehouse_sk + AND inv_date_sk = d_date_sk + AND d_year = 2001 + GROUP BY w_warehouse_name, w_warehouse_sk, i_item_sk, d_moy) foo + WHERE CASE mean + WHEN 0 + THEN 0 + ELSE stdev / mean END > 1) +SELECT + inv1.w_warehouse_sk, + inv1.i_item_sk, + inv1.d_moy, + inv1.mean, + inv1.cov, + inv2.w_warehouse_sk, + inv2.i_item_sk, + inv2.d_moy, + inv2.mean, + inv2.cov +FROM inv inv1, inv inv2 +WHERE inv1.i_item_sk = inv2.i_item_sk + AND inv1.w_warehouse_sk = inv2.w_warehouse_sk + AND inv1.d_moy = 1 + AND inv2.d_moy = 1 + 1 + AND inv1.cov > 1.5 +ORDER BY inv1.w_warehouse_sk, inv1.i_item_sk, inv1.d_moy, inv1.mean, inv1.cov + , inv2.d_moy, inv2.mean, inv2.cov diff --git a/sql/core/src/test/resources/tpcds/q4.sql b/sql/core/src/test/resources/tpcds/q4.sql new file mode 100755 index 0000000000000..b9f27fbc9a4a6 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q4.sql @@ -0,0 +1,120 @@ +WITH year_total AS ( + SELECT + c_customer_id customer_id, + c_first_name customer_first_name, + c_last_name customer_last_name, + c_preferred_cust_flag customer_preferred_cust_flag, + c_birth_country customer_birth_country, + c_login customer_login, + c_email_address customer_email_address, + d_year dyear, + sum(((ss_ext_list_price - ss_ext_wholesale_cost - ss_ext_discount_amt) + + ss_ext_sales_price) / 2) year_total, + 's' sale_type + FROM customer, store_sales, date_dim + WHERE c_customer_sk = ss_customer_sk AND ss_sold_date_sk = d_date_sk + GROUP BY c_customer_id, + c_first_name, + c_last_name, + c_preferred_cust_flag, + c_birth_country, + c_login, + c_email_address, + d_year + UNION ALL + SELECT + c_customer_id customer_id, + c_first_name customer_first_name, + c_last_name customer_last_name, + c_preferred_cust_flag customer_preferred_cust_flag, + c_birth_country customer_birth_country, + c_login customer_login, + c_email_address customer_email_address, + d_year dyear, + sum((((cs_ext_list_price - cs_ext_wholesale_cost - cs_ext_discount_amt) + + cs_ext_sales_price) / 2)) year_total, + 'c' sale_type + FROM customer, catalog_sales, date_dim + WHERE c_customer_sk = cs_bill_customer_sk AND cs_sold_date_sk = d_date_sk + GROUP BY c_customer_id, + c_first_name, + c_last_name, + c_preferred_cust_flag, + c_birth_country, + c_login, + c_email_address, + d_year + UNION ALL + SELECT + c_customer_id customer_id, + c_first_name customer_first_name, + c_last_name customer_last_name, + c_preferred_cust_flag customer_preferred_cust_flag, + c_birth_country customer_birth_country, + c_login customer_login, + c_email_address customer_email_address, + d_year dyear, + sum((((ws_ext_list_price - ws_ext_wholesale_cost - ws_ext_discount_amt) + ws_ext_sales_price) / + 2)) year_total, + 'w' sale_type + FROM customer, web_sales, date_dim + WHERE c_customer_sk = ws_bill_customer_sk AND 
ws_sold_date_sk = d_date_sk + GROUP BY c_customer_id, + c_first_name, + c_last_name, + c_preferred_cust_flag, + c_birth_country, + c_login, + c_email_address, + d_year) +SELECT + t_s_secyear.customer_id, + t_s_secyear.customer_first_name, + t_s_secyear.customer_last_name, + t_s_secyear.customer_preferred_cust_flag, + t_s_secyear.customer_birth_country, + t_s_secyear.customer_login, + t_s_secyear.customer_email_address +FROM year_total t_s_firstyear, year_total t_s_secyear, year_total t_c_firstyear, + year_total t_c_secyear, year_total t_w_firstyear, year_total t_w_secyear +WHERE t_s_secyear.customer_id = t_s_firstyear.customer_id + AND t_s_firstyear.customer_id = t_c_secyear.customer_id + AND t_s_firstyear.customer_id = t_c_firstyear.customer_id + AND t_s_firstyear.customer_id = t_w_firstyear.customer_id + AND t_s_firstyear.customer_id = t_w_secyear.customer_id + AND t_s_firstyear.sale_type = 's' + AND t_c_firstyear.sale_type = 'c' + AND t_w_firstyear.sale_type = 'w' + AND t_s_secyear.sale_type = 's' + AND t_c_secyear.sale_type = 'c' + AND t_w_secyear.sale_type = 'w' + AND t_s_firstyear.dyear = 2001 + AND t_s_secyear.dyear = 2001 + 1 + AND t_c_firstyear.dyear = 2001 + AND t_c_secyear.dyear = 2001 + 1 + AND t_w_firstyear.dyear = 2001 + AND t_w_secyear.dyear = 2001 + 1 + AND t_s_firstyear.year_total > 0 + AND t_c_firstyear.year_total > 0 + AND t_w_firstyear.year_total > 0 + AND CASE WHEN t_c_firstyear.year_total > 0 + THEN t_c_secyear.year_total / t_c_firstyear.year_total + ELSE NULL END + > CASE WHEN t_s_firstyear.year_total > 0 + THEN t_s_secyear.year_total / t_s_firstyear.year_total + ELSE NULL END + AND CASE WHEN t_c_firstyear.year_total > 0 + THEN t_c_secyear.year_total / t_c_firstyear.year_total + ELSE NULL END + > CASE WHEN t_w_firstyear.year_total > 0 + THEN t_w_secyear.year_total / t_w_firstyear.year_total + ELSE NULL END +ORDER BY + t_s_secyear.customer_id, + t_s_secyear.customer_first_name, + t_s_secyear.customer_last_name, + t_s_secyear.customer_preferred_cust_flag, + t_s_secyear.customer_birth_country, + t_s_secyear.customer_login, + t_s_secyear.customer_email_address +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q40.sql b/sql/core/src/test/resources/tpcds/q40.sql new file mode 100755 index 0000000000000..66d8b73ac1c15 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q40.sql @@ -0,0 +1,25 @@ +SELECT + w_state, + i_item_id, + sum(CASE WHEN (cast(d_date AS DATE) < cast('2000-03-11' AS DATE)) + THEN cs_sales_price - coalesce(cr_refunded_cash, 0) + ELSE 0 END) AS sales_before, + sum(CASE WHEN (cast(d_date AS DATE) >= cast('2000-03-11' AS DATE)) + THEN cs_sales_price - coalesce(cr_refunded_cash, 0) + ELSE 0 END) AS sales_after +FROM + catalog_sales + LEFT OUTER JOIN catalog_returns ON + (cs_order_number = cr_order_number + AND cs_item_sk = cr_item_sk) + , warehouse, item, date_dim +WHERE + i_current_price BETWEEN 0.99 AND 1.49 + AND i_item_sk = cs_item_sk + AND cs_warehouse_sk = w_warehouse_sk + AND cs_sold_date_sk = d_date_sk + AND d_date BETWEEN (cast('2000-03-11' AS DATE) - INTERVAL 30 days) + AND (cast('2000-03-11' AS DATE) + INTERVAL 30 days) +GROUP BY w_state, i_item_id +ORDER BY w_state, i_item_id +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q41.sql b/sql/core/src/test/resources/tpcds/q41.sql new file mode 100755 index 0000000000000..25e317e0e201a --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q41.sql @@ -0,0 +1,49 @@ +SELECT DISTINCT (i_product_name) +FROM item i1 +WHERE i_manufact_id BETWEEN 738 AND 738 + 40 + AND (SELECT count(*) AS 
item_cnt +FROM item +WHERE (i_manufact = i1.i_manufact AND + ((i_category = 'Women' AND + (i_color = 'powder' OR i_color = 'khaki') AND + (i_units = 'Ounce' OR i_units = 'Oz') AND + (i_size = 'medium' OR i_size = 'extra large') + ) OR + (i_category = 'Women' AND + (i_color = 'brown' OR i_color = 'honeydew') AND + (i_units = 'Bunch' OR i_units = 'Ton') AND + (i_size = 'N/A' OR i_size = 'small') + ) OR + (i_category = 'Men' AND + (i_color = 'floral' OR i_color = 'deep') AND + (i_units = 'N/A' OR i_units = 'Dozen') AND + (i_size = 'petite' OR i_size = 'large') + ) OR + (i_category = 'Men' AND + (i_color = 'light' OR i_color = 'cornflower') AND + (i_units = 'Box' OR i_units = 'Pound') AND + (i_size = 'medium' OR i_size = 'extra large') + ))) OR + (i_manufact = i1.i_manufact AND + ((i_category = 'Women' AND + (i_color = 'midnight' OR i_color = 'snow') AND + (i_units = 'Pallet' OR i_units = 'Gross') AND + (i_size = 'medium' OR i_size = 'extra large') + ) OR + (i_category = 'Women' AND + (i_color = 'cyan' OR i_color = 'papaya') AND + (i_units = 'Cup' OR i_units = 'Dram') AND + (i_size = 'N/A' OR i_size = 'small') + ) OR + (i_category = 'Men' AND + (i_color = 'orange' OR i_color = 'frosted') AND + (i_units = 'Each' OR i_units = 'Tbl') AND + (i_size = 'petite' OR i_size = 'large') + ) OR + (i_category = 'Men' AND + (i_color = 'forest' OR i_color = 'ghost') AND + (i_units = 'Lb' OR i_units = 'Bundle') AND + (i_size = 'medium' OR i_size = 'extra large') + )))) > 0 +ORDER BY i_product_name +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q42.sql b/sql/core/src/test/resources/tpcds/q42.sql new file mode 100755 index 0000000000000..4d2e71760d870 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q42.sql @@ -0,0 +1,18 @@ +SELECT + dt.d_year, + item.i_category_id, + item.i_category, + sum(ss_ext_sales_price) +FROM date_dim dt, store_sales, item +WHERE dt.d_date_sk = store_sales.ss_sold_date_sk + AND store_sales.ss_item_sk = item.i_item_sk + AND item.i_manager_id = 1 + AND dt.d_moy = 11 + AND dt.d_year = 2000 +GROUP BY dt.d_year + , item.i_category_id + , item.i_category +ORDER BY sum(ss_ext_sales_price) DESC, dt.d_year + , item.i_category_id + , item.i_category +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q43.sql b/sql/core/src/test/resources/tpcds/q43.sql new file mode 100755 index 0000000000000..45411772c1b54 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q43.sql @@ -0,0 +1,33 @@ +SELECT + s_store_name, + s_store_id, + sum(CASE WHEN (d_day_name = 'Sunday') + THEN ss_sales_price + ELSE NULL END) sun_sales, + sum(CASE WHEN (d_day_name = 'Monday') + THEN ss_sales_price + ELSE NULL END) mon_sales, + sum(CASE WHEN (d_day_name = 'Tuesday') + THEN ss_sales_price + ELSE NULL END) tue_sales, + sum(CASE WHEN (d_day_name = 'Wednesday') + THEN ss_sales_price + ELSE NULL END) wed_sales, + sum(CASE WHEN (d_day_name = 'Thursday') + THEN ss_sales_price + ELSE NULL END) thu_sales, + sum(CASE WHEN (d_day_name = 'Friday') + THEN ss_sales_price + ELSE NULL END) fri_sales, + sum(CASE WHEN (d_day_name = 'Saturday') + THEN ss_sales_price + ELSE NULL END) sat_sales +FROM date_dim, store_sales, store +WHERE d_date_sk = ss_sold_date_sk AND + s_store_sk = ss_store_sk AND + s_gmt_offset = -5 AND + d_year = 2000 +GROUP BY s_store_name, s_store_id +ORDER BY s_store_name, s_store_id, sun_sales, mon_sales, tue_sales, wed_sales, + thu_sales, fri_sales, sat_sales +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q44.sql b/sql/core/src/test/resources/tpcds/q44.sql new file mode 100755 index 
0000000000000..379e604788625 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q44.sql @@ -0,0 +1,46 @@ +SELECT + asceding.rnk, + i1.i_product_name best_performing, + i2.i_product_name worst_performing +FROM (SELECT * +FROM (SELECT + item_sk, + rank() + OVER ( + ORDER BY rank_col ASC) rnk +FROM (SELECT + ss_item_sk item_sk, + avg(ss_net_profit) rank_col +FROM store_sales ss1 +WHERE ss_store_sk = 4 +GROUP BY ss_item_sk +HAVING avg(ss_net_profit) > 0.9 * (SELECT avg(ss_net_profit) rank_col +FROM store_sales +WHERE ss_store_sk = 4 + AND ss_addr_sk IS NULL +GROUP BY ss_store_sk)) V1) V11 +WHERE rnk < 11) asceding, + (SELECT * + FROM (SELECT + item_sk, + rank() + OVER ( + ORDER BY rank_col DESC) rnk + FROM (SELECT + ss_item_sk item_sk, + avg(ss_net_profit) rank_col + FROM store_sales ss1 + WHERE ss_store_sk = 4 + GROUP BY ss_item_sk + HAVING avg(ss_net_profit) > 0.9 * (SELECT avg(ss_net_profit) rank_col + FROM store_sales + WHERE ss_store_sk = 4 + AND ss_addr_sk IS NULL + GROUP BY ss_store_sk)) V2) V21 + WHERE rnk < 11) descending, + item i1, item i2 +WHERE asceding.rnk = descending.rnk + AND i1.i_item_sk = asceding.item_sk + AND i2.i_item_sk = descending.item_sk +ORDER BY asceding.rnk +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q45.sql b/sql/core/src/test/resources/tpcds/q45.sql new file mode 100755 index 0000000000000..907438f196c4c --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q45.sql @@ -0,0 +1,21 @@ +SELECT + ca_zip, + ca_city, + sum(ws_sales_price) +FROM web_sales, customer, customer_address, date_dim, item +WHERE ws_bill_customer_sk = c_customer_sk + AND c_current_addr_sk = ca_address_sk + AND ws_item_sk = i_item_sk + AND (substr(ca_zip, 1, 5) IN + ('85669', '86197', '88274', '83405', '86475', '85392', '85460', '80348', '81792') + OR + i_item_id IN (SELECT i_item_id + FROM item + WHERE i_item_sk IN (2, 3, 5, 7, 11, 13, 17, 19, 23, 29) + ) +) + AND ws_sold_date_sk = d_date_sk + AND d_qoy = 2 AND d_year = 2001 +GROUP BY ca_zip, ca_city +ORDER BY ca_zip, ca_city +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q46.sql b/sql/core/src/test/resources/tpcds/q46.sql new file mode 100755 index 0000000000000..0911677dff206 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q46.sql @@ -0,0 +1,32 @@ +SELECT + c_last_name, + c_first_name, + ca_city, + bought_city, + ss_ticket_number, + amt, + profit +FROM + (SELECT + ss_ticket_number, + ss_customer_sk, + ca_city bought_city, + sum(ss_coupon_amt) amt, + sum(ss_net_profit) profit + FROM store_sales, date_dim, store, household_demographics, customer_address + WHERE store_sales.ss_sold_date_sk = date_dim.d_date_sk + AND store_sales.ss_store_sk = store.s_store_sk + AND store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk + AND store_sales.ss_addr_sk = customer_address.ca_address_sk + AND (household_demographics.hd_dep_count = 4 OR + household_demographics.hd_vehicle_count = 3) + AND date_dim.d_dow IN (6, 0) + AND date_dim.d_year IN (1999, 1999 + 1, 1999 + 2) + AND store.s_city IN ('Fairview', 'Midway', 'Fairview', 'Fairview', 'Fairview') + GROUP BY ss_ticket_number, ss_customer_sk, ss_addr_sk, ca_city) dn, customer, + customer_address current_addr +WHERE ss_customer_sk = c_customer_sk + AND customer.c_current_addr_sk = current_addr.ca_address_sk + AND current_addr.ca_city <> bought_city +ORDER BY c_last_name, c_first_name, ca_city, bought_city, ss_ticket_number +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q47.sql b/sql/core/src/test/resources/tpcds/q47.sql new file mode 100755 index 
0000000000000..cfc37a4cece66 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q47.sql @@ -0,0 +1,63 @@ +WITH v1 AS ( + SELECT + i_category, + i_brand, + s_store_name, + s_company_name, + d_year, + d_moy, + sum(ss_sales_price) sum_sales, + avg(sum(ss_sales_price)) + OVER + (PARTITION BY i_category, i_brand, + s_store_name, s_company_name, d_year) + avg_monthly_sales, + rank() + OVER + (PARTITION BY i_category, i_brand, + s_store_name, s_company_name + ORDER BY d_year, d_moy) rn + FROM item, store_sales, date_dim, store + WHERE ss_item_sk = i_item_sk AND + ss_sold_date_sk = d_date_sk AND + ss_store_sk = s_store_sk AND + ( + d_year = 1999 OR + (d_year = 1999 - 1 AND d_moy = 12) OR + (d_year = 1999 + 1 AND d_moy = 1) + ) + GROUP BY i_category, i_brand, + s_store_name, s_company_name, + d_year, d_moy), + v2 AS ( + SELECT + v1.i_category, + v1.i_brand, + v1.s_store_name, + v1.s_company_name, + v1.d_year, + v1.d_moy, + v1.avg_monthly_sales, + v1.sum_sales, + v1_lag.sum_sales psum, + v1_lead.sum_sales nsum + FROM v1, v1 v1_lag, v1 v1_lead + WHERE v1.i_category = v1_lag.i_category AND + v1.i_category = v1_lead.i_category AND + v1.i_brand = v1_lag.i_brand AND + v1.i_brand = v1_lead.i_brand AND + v1.s_store_name = v1_lag.s_store_name AND + v1.s_store_name = v1_lead.s_store_name AND + v1.s_company_name = v1_lag.s_company_name AND + v1.s_company_name = v1_lead.s_company_name AND + v1.rn = v1_lag.rn + 1 AND + v1.rn = v1_lead.rn - 1) +SELECT * +FROM v2 +WHERE d_year = 1999 AND + avg_monthly_sales > 0 AND + CASE WHEN avg_monthly_sales > 0 + THEN abs(sum_sales - avg_monthly_sales) / avg_monthly_sales + ELSE NULL END > 0.1 +ORDER BY sum_sales - avg_monthly_sales, 3 +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q48.sql b/sql/core/src/test/resources/tpcds/q48.sql new file mode 100755 index 0000000000000..fdb9f38e294f7 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q48.sql @@ -0,0 +1,63 @@ +SELECT sum(ss_quantity) +FROM store_sales, store, customer_demographics, customer_address, date_dim +WHERE s_store_sk = ss_store_sk + AND ss_sold_date_sk = d_date_sk AND d_year = 2001 + AND + ( + ( + cd_demo_sk = ss_cdemo_sk + AND + cd_marital_status = 'M' + AND + cd_education_status = '4 yr Degree' + AND + ss_sales_price BETWEEN 100.00 AND 150.00 + ) + OR + ( + cd_demo_sk = ss_cdemo_sk + AND + cd_marital_status = 'D' + AND + cd_education_status = '2 yr Degree' + AND + ss_sales_price BETWEEN 50.00 AND 100.00 + ) + OR + ( + cd_demo_sk = ss_cdemo_sk + AND + cd_marital_status = 'S' + AND + cd_education_status = 'College' + AND + ss_sales_price BETWEEN 150.00 AND 200.00 + ) + ) + AND + ( + ( + ss_addr_sk = ca_address_sk + AND + ca_country = 'United States' + AND + ca_state IN ('CO', 'OH', 'TX') + AND ss_net_profit BETWEEN 0 AND 2000 + ) + OR + (ss_addr_sk = ca_address_sk + AND + ca_country = 'United States' + AND + ca_state IN ('OR', 'MN', 'KY') + AND ss_net_profit BETWEEN 150 AND 3000 + ) + OR + (ss_addr_sk = ca_address_sk + AND + ca_country = 'United States' + AND + ca_state IN ('VA', 'CA', 'MS') + AND ss_net_profit BETWEEN 50 AND 25000 + ) + ) diff --git a/sql/core/src/test/resources/tpcds/q49.sql b/sql/core/src/test/resources/tpcds/q49.sql new file mode 100755 index 0000000000000..9568d8b92d10a --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q49.sql @@ -0,0 +1,126 @@ +SELECT + 'web' AS channel, + web.item, + web.return_ratio, + web.return_rank, + web.currency_rank +FROM ( + SELECT + item, + return_ratio, + currency_ratio, + rank() + OVER ( + ORDER BY return_ratio) AS return_rank, + rank() + 
OVER ( + ORDER BY currency_ratio) AS currency_rank + FROM + (SELECT + ws.ws_item_sk AS item, + (cast(sum(coalesce(wr.wr_return_quantity, 0)) AS DECIMAL(15, 4)) / + cast(sum(coalesce(ws.ws_quantity, 0)) AS DECIMAL(15, 4))) AS return_ratio, + (cast(sum(coalesce(wr.wr_return_amt, 0)) AS DECIMAL(15, 4)) / + cast(sum(coalesce(ws.ws_net_paid, 0)) AS DECIMAL(15, 4))) AS currency_ratio + FROM + web_sales ws LEFT OUTER JOIN web_returns wr + ON (ws.ws_order_number = wr.wr_order_number AND + ws.ws_item_sk = wr.wr_item_sk) + , date_dim + WHERE + wr.wr_return_amt > 10000 + AND ws.ws_net_profit > 1 + AND ws.ws_net_paid > 0 + AND ws.ws_quantity > 0 + AND ws_sold_date_sk = d_date_sk + AND d_year = 2001 + AND d_moy = 12 + GROUP BY ws.ws_item_sk + ) in_web + ) web +WHERE (web.return_rank <= 10 OR web.currency_rank <= 10) +UNION +SELECT + 'catalog' AS channel, + catalog.item, + catalog.return_ratio, + catalog.return_rank, + catalog.currency_rank +FROM ( + SELECT + item, + return_ratio, + currency_ratio, + rank() + OVER ( + ORDER BY return_ratio) AS return_rank, + rank() + OVER ( + ORDER BY currency_ratio) AS currency_rank + FROM + (SELECT + cs.cs_item_sk AS item, + (cast(sum(coalesce(cr.cr_return_quantity, 0)) AS DECIMAL(15, 4)) / + cast(sum(coalesce(cs.cs_quantity, 0)) AS DECIMAL(15, 4))) AS return_ratio, + (cast(sum(coalesce(cr.cr_return_amount, 0)) AS DECIMAL(15, 4)) / + cast(sum(coalesce(cs.cs_net_paid, 0)) AS DECIMAL(15, 4))) AS currency_ratio + FROM + catalog_sales cs LEFT OUTER JOIN catalog_returns cr + ON (cs.cs_order_number = cr.cr_order_number AND + cs.cs_item_sk = cr.cr_item_sk) + , date_dim + WHERE + cr.cr_return_amount > 10000 + AND cs.cs_net_profit > 1 + AND cs.cs_net_paid > 0 + AND cs.cs_quantity > 0 + AND cs_sold_date_sk = d_date_sk + AND d_year = 2001 + AND d_moy = 12 + GROUP BY cs.cs_item_sk + ) in_cat + ) catalog +WHERE (catalog.return_rank <= 10 OR catalog.currency_rank <= 10) +UNION +SELECT + 'store' AS channel, + store.item, + store.return_ratio, + store.return_rank, + store.currency_rank +FROM ( + SELECT + item, + return_ratio, + currency_ratio, + rank() + OVER ( + ORDER BY return_ratio) AS return_rank, + rank() + OVER ( + ORDER BY currency_ratio) AS currency_rank + FROM + (SELECT + sts.ss_item_sk AS item, + (cast(sum(coalesce(sr.sr_return_quantity, 0)) AS DECIMAL(15, 4)) / + cast(sum(coalesce(sts.ss_quantity, 0)) AS DECIMAL(15, 4))) AS return_ratio, + (cast(sum(coalesce(sr.sr_return_amt, 0)) AS DECIMAL(15, 4)) / + cast(sum(coalesce(sts.ss_net_paid, 0)) AS DECIMAL(15, 4))) AS currency_ratio + FROM + store_sales sts LEFT OUTER JOIN store_returns sr + ON (sts.ss_ticket_number = sr.sr_ticket_number AND sts.ss_item_sk = sr.sr_item_sk) + , date_dim + WHERE + sr.sr_return_amt > 10000 + AND sts.ss_net_profit > 1 + AND sts.ss_net_paid > 0 + AND sts.ss_quantity > 0 + AND ss_sold_date_sk = d_date_sk + AND d_year = 2001 + AND d_moy = 12 + GROUP BY sts.ss_item_sk + ) in_store + ) store +WHERE (store.return_rank <= 10 OR store.currency_rank <= 10) +ORDER BY 1, 4, 5 +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q5.sql b/sql/core/src/test/resources/tpcds/q5.sql new file mode 100755 index 0000000000000..b87cf3a44827b --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q5.sql @@ -0,0 +1,131 @@ +WITH ssr AS +( SELECT + s_store_id, + sum(sales_price) AS sales, + sum(profit) AS profit, + sum(return_amt) AS RETURNS, + sum(net_loss) AS profit_loss + FROM + (SELECT + ss_store_sk AS store_sk, + ss_sold_date_sk AS date_sk, + ss_ext_sales_price AS sales_price, + ss_net_profit AS profit, + cast(0 
AS DECIMAL(7, 2)) AS return_amt, + cast(0 AS DECIMAL(7, 2)) AS net_loss + FROM store_sales + UNION ALL + SELECT + sr_store_sk AS store_sk, + sr_returned_date_sk AS date_sk, + cast(0 AS DECIMAL(7, 2)) AS sales_price, + cast(0 AS DECIMAL(7, 2)) AS profit, + sr_return_amt AS return_amt, + sr_net_loss AS net_loss + FROM store_returns) + salesreturns, date_dim, store + WHERE date_sk = d_date_sk + AND d_date BETWEEN cast('2000-08-23' AS DATE) + AND ((cast('2000-08-23' AS DATE) + INTERVAL 14 days)) + AND store_sk = s_store_sk + GROUP BY s_store_id), + csr AS + ( SELECT + cp_catalog_page_id, + sum(sales_price) AS sales, + sum(profit) AS profit, + sum(return_amt) AS RETURNS, + sum(net_loss) AS profit_loss + FROM + (SELECT + cs_catalog_page_sk AS page_sk, + cs_sold_date_sk AS date_sk, + cs_ext_sales_price AS sales_price, + cs_net_profit AS profit, + cast(0 AS DECIMAL(7, 2)) AS return_amt, + cast(0 AS DECIMAL(7, 2)) AS net_loss + FROM catalog_sales + UNION ALL + SELECT + cr_catalog_page_sk AS page_sk, + cr_returned_date_sk AS date_sk, + cast(0 AS DECIMAL(7, 2)) AS sales_price, + cast(0 AS DECIMAL(7, 2)) AS profit, + cr_return_amount AS return_amt, + cr_net_loss AS net_loss + FROM catalog_returns + ) salesreturns, date_dim, catalog_page + WHERE date_sk = d_date_sk + AND d_date BETWEEN cast('2000-08-23' AS DATE) + AND ((cast('2000-08-23' AS DATE) + INTERVAL 14 days)) + AND page_sk = cp_catalog_page_sk + GROUP BY cp_catalog_page_id) + , + wsr AS + ( SELECT + web_site_id, + sum(sales_price) AS sales, + sum(profit) AS profit, + sum(return_amt) AS RETURNS, + sum(net_loss) AS profit_loss + FROM + (SELECT + ws_web_site_sk AS wsr_web_site_sk, + ws_sold_date_sk AS date_sk, + ws_ext_sales_price AS sales_price, + ws_net_profit AS profit, + cast(0 AS DECIMAL(7, 2)) AS return_amt, + cast(0 AS DECIMAL(7, 2)) AS net_loss + FROM web_sales + UNION ALL + SELECT + ws_web_site_sk AS wsr_web_site_sk, + wr_returned_date_sk AS date_sk, + cast(0 AS DECIMAL(7, 2)) AS sales_price, + cast(0 AS DECIMAL(7, 2)) AS profit, + wr_return_amt AS return_amt, + wr_net_loss AS net_loss + FROM web_returns + LEFT OUTER JOIN web_sales ON + (wr_item_sk = ws_item_sk + AND wr_order_number = ws_order_number) + ) salesreturns, date_dim, web_site + WHERE date_sk = d_date_sk + AND d_date BETWEEN cast('2000-08-23' AS DATE) + AND ((cast('2000-08-23' AS DATE) + INTERVAL 14 days)) + AND wsr_web_site_sk = web_site_sk + GROUP BY web_site_id) +SELECT + channel, + id, + sum(sales) AS sales, + sum(returns) AS returns, + sum(profit) AS profit +FROM + (SELECT + 'store channel' AS channel, + concat('store', s_store_id) AS id, + sales, + returns, + (profit - profit_loss) AS profit + FROM ssr + UNION ALL + SELECT + 'catalog channel' AS channel, + concat('catalog_page', cp_catalog_page_id) AS id, + sales, + returns, + (profit - profit_loss) AS profit + FROM csr + UNION ALL + SELECT + 'web channel' AS channel, + concat('web_site', web_site_id) AS id, + sales, + returns, + (profit - profit_loss) AS profit + FROM wsr + ) x +GROUP BY ROLLUP (channel, id) +ORDER BY channel, id +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q50.sql b/sql/core/src/test/resources/tpcds/q50.sql new file mode 100755 index 0000000000000..f1d4b15449edd --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q50.sql @@ -0,0 +1,47 @@ +SELECT + s_store_name, + s_company_id, + s_street_number, + s_street_name, + s_street_type, + s_suite_number, + s_city, + s_county, + s_state, + s_zip, + sum(CASE WHEN (sr_returned_date_sk - ss_sold_date_sk <= 30) + THEN 1 + ELSE 0 END) AS `30 days `, + 
sum(CASE WHEN (sr_returned_date_sk - ss_sold_date_sk > 30) AND + (sr_returned_date_sk - ss_sold_date_sk <= 60) + THEN 1 + ELSE 0 END) AS `31 - 60 days `, + sum(CASE WHEN (sr_returned_date_sk - ss_sold_date_sk > 60) AND + (sr_returned_date_sk - ss_sold_date_sk <= 90) + THEN 1 + ELSE 0 END) AS `61 - 90 days `, + sum(CASE WHEN (sr_returned_date_sk - ss_sold_date_sk > 90) AND + (sr_returned_date_sk - ss_sold_date_sk <= 120) + THEN 1 + ELSE 0 END) AS `91 - 120 days `, + sum(CASE WHEN (sr_returned_date_sk - ss_sold_date_sk > 120) + THEN 1 + ELSE 0 END) AS `>120 days ` +FROM + store_sales, store_returns, store, date_dim d1, date_dim d2 +WHERE + d2.d_year = 2001 + AND d2.d_moy = 8 + AND ss_ticket_number = sr_ticket_number + AND ss_item_sk = sr_item_sk + AND ss_sold_date_sk = d1.d_date_sk + AND sr_returned_date_sk = d2.d_date_sk + AND ss_customer_sk = sr_customer_sk + AND ss_store_sk = s_store_sk +GROUP BY + s_store_name, s_company_id, s_street_number, s_street_name, s_street_type, + s_suite_number, s_city, s_county, s_state, s_zip +ORDER BY + s_store_name, s_company_id, s_street_number, s_street_name, s_street_type, + s_suite_number, s_city, s_county, s_state, s_zip +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q51.sql b/sql/core/src/test/resources/tpcds/q51.sql new file mode 100755 index 0000000000000..62b003eb67b9b --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q51.sql @@ -0,0 +1,55 @@ +WITH web_v1 AS ( + SELECT + ws_item_sk item_sk, + d_date, + sum(sum(ws_sales_price)) + OVER (PARTITION BY ws_item_sk + ORDER BY d_date + ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) cume_sales + FROM web_sales, date_dim + WHERE ws_sold_date_sk = d_date_sk + AND d_month_seq BETWEEN 1200 AND 1200 + 11 + AND ws_item_sk IS NOT NULL + GROUP BY ws_item_sk, d_date), + store_v1 AS ( + SELECT + ss_item_sk item_sk, + d_date, + sum(sum(ss_sales_price)) + OVER (PARTITION BY ss_item_sk + ORDER BY d_date + ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) cume_sales + FROM store_sales, date_dim + WHERE ss_sold_date_sk = d_date_sk + AND d_month_seq BETWEEN 1200 AND 1200 + 11 + AND ss_item_sk IS NOT NULL + GROUP BY ss_item_sk, d_date) +SELECT * +FROM (SELECT + item_sk, + d_date, + web_sales, + store_sales, + max(web_sales) + OVER (PARTITION BY item_sk + ORDER BY d_date + ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) web_cumulative, + max(store_sales) + OVER (PARTITION BY item_sk + ORDER BY d_date + ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) store_cumulative +FROM (SELECT + CASE WHEN web.item_sk IS NOT NULL + THEN web.item_sk + ELSE store.item_sk END item_sk, + CASE WHEN web.d_date IS NOT NULL + THEN web.d_date + ELSE store.d_date END d_date, + web.cume_sales web_sales, + store.cume_sales store_sales +FROM web_v1 web FULL OUTER JOIN store_v1 store ON (web.item_sk = store.item_sk + AND web.d_date = store.d_date) + ) x) y +WHERE web_cumulative > store_cumulative +ORDER BY item_sk, d_date +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q52.sql b/sql/core/src/test/resources/tpcds/q52.sql new file mode 100755 index 0000000000000..467d1ae050455 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q52.sql @@ -0,0 +1,14 @@ +SELECT + dt.d_year, + item.i_brand_id brand_id, + item.i_brand brand, + sum(ss_ext_sales_price) ext_price +FROM date_dim dt, store_sales, item +WHERE dt.d_date_sk = store_sales.ss_sold_date_sk + AND store_sales.ss_item_sk = item.i_item_sk + AND item.i_manager_id = 1 + AND dt.d_moy = 11 + AND dt.d_year = 2000 +GROUP BY dt.d_year, item.i_brand, item.i_brand_id +ORDER BY 
dt.d_year, ext_price DESC, brand_id +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q53.sql b/sql/core/src/test/resources/tpcds/q53.sql new file mode 100755 index 0000000000000..b42c68dcf871b --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q53.sql @@ -0,0 +1,30 @@ +SELECT * +FROM + (SELECT + i_manufact_id, + sum(ss_sales_price) sum_sales, + avg(sum(ss_sales_price)) + OVER (PARTITION BY i_manufact_id) avg_quarterly_sales + FROM item, store_sales, date_dim, store + WHERE ss_item_sk = i_item_sk AND + ss_sold_date_sk = d_date_sk AND + ss_store_sk = s_store_sk AND + d_month_seq IN (1200, 1200 + 1, 1200 + 2, 1200 + 3, 1200 + 4, 1200 + 5, 1200 + 6, + 1200 + 7, 1200 + 8, 1200 + 9, 1200 + 10, 1200 + 11) AND + ((i_category IN ('Books', 'Children', 'Electronics') AND + i_class IN ('personal', 'portable', 'reference', 'self-help') AND + i_brand IN ('scholaramalgamalg #14', 'scholaramalgamalg #7', + 'exportiunivamalg #9', 'scholaramalgamalg #9')) + OR + (i_category IN ('Women', 'Music', 'Men') AND + i_class IN ('accessories', 'classical', 'fragrances', 'pants') AND + i_brand IN ('amalgimporto #1', 'edu packscholar #1', 'exportiimporto #1', + 'importoamalg #1'))) + GROUP BY i_manufact_id, d_qoy) tmp1 +WHERE CASE WHEN avg_quarterly_sales > 0 + THEN abs(sum_sales - avg_quarterly_sales) / avg_quarterly_sales + ELSE NULL END > 0.1 +ORDER BY avg_quarterly_sales, + sum_sales, + i_manufact_id +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q54.sql b/sql/core/src/test/resources/tpcds/q54.sql new file mode 100755 index 0000000000000..897237fb6e10b --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q54.sql @@ -0,0 +1,61 @@ +WITH my_customers AS ( + SELECT DISTINCT + c_customer_sk, + c_current_addr_sk + FROM + (SELECT + cs_sold_date_sk sold_date_sk, + cs_bill_customer_sk customer_sk, + cs_item_sk item_sk + FROM catalog_sales + UNION ALL + SELECT + ws_sold_date_sk sold_date_sk, + ws_bill_customer_sk customer_sk, + ws_item_sk item_sk + FROM web_sales + ) cs_or_ws_sales, + item, + date_dim, + customer + WHERE sold_date_sk = d_date_sk + AND item_sk = i_item_sk + AND i_category = 'Women' + AND i_class = 'maternity' + AND c_customer_sk = cs_or_ws_sales.customer_sk + AND d_moy = 12 + AND d_year = 1998 +) + , my_revenue AS ( + SELECT + c_customer_sk, + sum(ss_ext_sales_price) AS revenue + FROM my_customers, + store_sales, + customer_address, + store, + date_dim + WHERE c_current_addr_sk = ca_address_sk + AND ca_county = s_county + AND ca_state = s_state + AND ss_sold_date_sk = d_date_sk + AND c_customer_sk = ss_customer_sk + AND d_month_seq BETWEEN (SELECT DISTINCT d_month_seq + 1 + FROM date_dim + WHERE d_year = 1998 AND d_moy = 12) + AND (SELECT DISTINCT d_month_seq + 3 + FROM date_dim + WHERE d_year = 1998 AND d_moy = 12) + GROUP BY c_customer_sk +) + , segments AS +(SELECT cast((revenue / 50) AS INT) AS segment + FROM my_revenue) +SELECT + segment, + count(*) AS num_customers, + segment * 50 AS segment_base +FROM segments +GROUP BY segment +ORDER BY segment, num_customers +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q55.sql b/sql/core/src/test/resources/tpcds/q55.sql new file mode 100755 index 0000000000000..bc5d888c9ac58 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q55.sql @@ -0,0 +1,13 @@ +SELECT + i_brand_id brand_id, + i_brand brand, + sum(ss_ext_sales_price) ext_price +FROM date_dim, store_sales, item +WHERE d_date_sk = ss_sold_date_sk + AND ss_item_sk = i_item_sk + AND i_manager_id = 28 + AND d_moy = 11 + AND d_year = 1999 +GROUP BY i_brand, i_brand_id +ORDER BY 
ext_price DESC, brand_id +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q56.sql b/sql/core/src/test/resources/tpcds/q56.sql new file mode 100755 index 0000000000000..2fa1738dcfee6 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q56.sql @@ -0,0 +1,65 @@ +WITH ss AS ( + SELECT + i_item_id, + sum(ss_ext_sales_price) total_sales + FROM + store_sales, date_dim, customer_address, item + WHERE + i_item_id IN (SELECT i_item_id + FROM item + WHERE i_color IN ('slate', 'blanched', 'burnished')) + AND ss_item_sk = i_item_sk + AND ss_sold_date_sk = d_date_sk + AND d_year = 2001 + AND d_moy = 2 + AND ss_addr_sk = ca_address_sk + AND ca_gmt_offset = -5 + GROUP BY i_item_id), + cs AS ( + SELECT + i_item_id, + sum(cs_ext_sales_price) total_sales + FROM + catalog_sales, date_dim, customer_address, item + WHERE + i_item_id IN (SELECT i_item_id + FROM item + WHERE i_color IN ('slate', 'blanched', 'burnished')) + AND cs_item_sk = i_item_sk + AND cs_sold_date_sk = d_date_sk + AND d_year = 2001 + AND d_moy = 2 + AND cs_bill_addr_sk = ca_address_sk + AND ca_gmt_offset = -5 + GROUP BY i_item_id), + ws AS ( + SELECT + i_item_id, + sum(ws_ext_sales_price) total_sales + FROM + web_sales, date_dim, customer_address, item + WHERE + i_item_id IN (SELECT i_item_id + FROM item + WHERE i_color IN ('slate', 'blanched', 'burnished')) + AND ws_item_sk = i_item_sk + AND ws_sold_date_sk = d_date_sk + AND d_year = 2001 + AND d_moy = 2 + AND ws_bill_addr_sk = ca_address_sk + AND ca_gmt_offset = -5 + GROUP BY i_item_id) +SELECT + i_item_id, + sum(total_sales) total_sales +FROM (SELECT * + FROM ss + UNION ALL + SELECT * + FROM cs + UNION ALL + SELECT * + FROM ws) tmp1 +GROUP BY i_item_id +ORDER BY total_sales +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q57.sql b/sql/core/src/test/resources/tpcds/q57.sql new file mode 100755 index 0000000000000..cf70d4b905b55 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q57.sql @@ -0,0 +1,56 @@ +WITH v1 AS ( + SELECT + i_category, + i_brand, + cc_name, + d_year, + d_moy, + sum(cs_sales_price) sum_sales, + avg(sum(cs_sales_price)) + OVER + (PARTITION BY i_category, i_brand, cc_name, d_year) + avg_monthly_sales, + rank() + OVER + (PARTITION BY i_category, i_brand, cc_name + ORDER BY d_year, d_moy) rn + FROM item, catalog_sales, date_dim, call_center + WHERE cs_item_sk = i_item_sk AND + cs_sold_date_sk = d_date_sk AND + cc_call_center_sk = cs_call_center_sk AND + ( + d_year = 1999 OR + (d_year = 1999 - 1 AND d_moy = 12) OR + (d_year = 1999 + 1 AND d_moy = 1) + ) + GROUP BY i_category, i_brand, + cc_name, d_year, d_moy), + v2 AS ( + SELECT + v1.i_category, + v1.i_brand, + v1.cc_name, + v1.d_year, + v1.d_moy, + v1.avg_monthly_sales, + v1.sum_sales, + v1_lag.sum_sales psum, + v1_lead.sum_sales nsum + FROM v1, v1 v1_lag, v1 v1_lead + WHERE v1.i_category = v1_lag.i_category AND + v1.i_category = v1_lead.i_category AND + v1.i_brand = v1_lag.i_brand AND + v1.i_brand = v1_lead.i_brand AND + v1.cc_name = v1_lag.cc_name AND + v1.cc_name = v1_lead.cc_name AND + v1.rn = v1_lag.rn + 1 AND + v1.rn = v1_lead.rn - 1) +SELECT * +FROM v2 +WHERE d_year = 1999 AND + avg_monthly_sales > 0 AND + CASE WHEN avg_monthly_sales > 0 + THEN abs(sum_sales - avg_monthly_sales) / avg_monthly_sales + ELSE NULL END > 0.1 +ORDER BY sum_sales - avg_monthly_sales, 3 +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q58.sql b/sql/core/src/test/resources/tpcds/q58.sql new file mode 100755 index 0000000000000..5f63f33dc927c --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q58.sql @@ -0,0 
+1,59 @@ +WITH ss_items AS +(SELECT + i_item_id item_id, + sum(ss_ext_sales_price) ss_item_rev + FROM store_sales, item, date_dim + WHERE ss_item_sk = i_item_sk + AND d_date IN (SELECT d_date + FROM date_dim + WHERE d_week_seq = (SELECT d_week_seq + FROM date_dim + WHERE d_date = '2000-01-03')) + AND ss_sold_date_sk = d_date_sk + GROUP BY i_item_id), + cs_items AS + (SELECT + i_item_id item_id, + sum(cs_ext_sales_price) cs_item_rev + FROM catalog_sales, item, date_dim + WHERE cs_item_sk = i_item_sk + AND d_date IN (SELECT d_date + FROM date_dim + WHERE d_week_seq = (SELECT d_week_seq + FROM date_dim + WHERE d_date = '2000-01-03')) + AND cs_sold_date_sk = d_date_sk + GROUP BY i_item_id), + ws_items AS + (SELECT + i_item_id item_id, + sum(ws_ext_sales_price) ws_item_rev + FROM web_sales, item, date_dim + WHERE ws_item_sk = i_item_sk + AND d_date IN (SELECT d_date + FROM date_dim + WHERE d_week_seq = (SELECT d_week_seq + FROM date_dim + WHERE d_date = '2000-01-03')) + AND ws_sold_date_sk = d_date_sk + GROUP BY i_item_id) +SELECT + ss_items.item_id, + ss_item_rev, + ss_item_rev / (ss_item_rev + cs_item_rev + ws_item_rev) / 3 * 100 ss_dev, + cs_item_rev, + cs_item_rev / (ss_item_rev + cs_item_rev + ws_item_rev) / 3 * 100 cs_dev, + ws_item_rev, + ws_item_rev / (ss_item_rev + cs_item_rev + ws_item_rev) / 3 * 100 ws_dev, + (ss_item_rev + cs_item_rev + ws_item_rev) / 3 average +FROM ss_items, cs_items, ws_items +WHERE ss_items.item_id = cs_items.item_id + AND ss_items.item_id = ws_items.item_id + AND ss_item_rev BETWEEN 0.9 * cs_item_rev AND 1.1 * cs_item_rev + AND ss_item_rev BETWEEN 0.9 * ws_item_rev AND 1.1 * ws_item_rev + AND cs_item_rev BETWEEN 0.9 * ss_item_rev AND 1.1 * ss_item_rev + AND cs_item_rev BETWEEN 0.9 * ws_item_rev AND 1.1 * ws_item_rev + AND ws_item_rev BETWEEN 0.9 * ss_item_rev AND 1.1 * ss_item_rev + AND ws_item_rev BETWEEN 0.9 * cs_item_rev AND 1.1 * cs_item_rev +ORDER BY item_id, ss_item_rev +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q59.sql b/sql/core/src/test/resources/tpcds/q59.sql new file mode 100755 index 0000000000000..3cef2027680b0 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q59.sql @@ -0,0 +1,75 @@ +WITH wss AS +(SELECT + d_week_seq, + ss_store_sk, + sum(CASE WHEN (d_day_name = 'Sunday') + THEN ss_sales_price + ELSE NULL END) sun_sales, + sum(CASE WHEN (d_day_name = 'Monday') + THEN ss_sales_price + ELSE NULL END) mon_sales, + sum(CASE WHEN (d_day_name = 'Tuesday') + THEN ss_sales_price + ELSE NULL END) tue_sales, + sum(CASE WHEN (d_day_name = 'Wednesday') + THEN ss_sales_price + ELSE NULL END) wed_sales, + sum(CASE WHEN (d_day_name = 'Thursday') + THEN ss_sales_price + ELSE NULL END) thu_sales, + sum(CASE WHEN (d_day_name = 'Friday') + THEN ss_sales_price + ELSE NULL END) fri_sales, + sum(CASE WHEN (d_day_name = 'Saturday') + THEN ss_sales_price + ELSE NULL END) sat_sales + FROM store_sales, date_dim + WHERE d_date_sk = ss_sold_date_sk + GROUP BY d_week_seq, ss_store_sk +) +SELECT + s_store_name1, + s_store_id1, + d_week_seq1, + sun_sales1 / sun_sales2, + mon_sales1 / mon_sales2, + tue_sales1 / tue_sales2, + wed_sales1 / wed_sales2, + thu_sales1 / thu_sales2, + fri_sales1 / fri_sales2, + sat_sales1 / sat_sales2 +FROM + (SELECT + s_store_name s_store_name1, + wss.d_week_seq d_week_seq1, + s_store_id s_store_id1, + sun_sales sun_sales1, + mon_sales mon_sales1, + tue_sales tue_sales1, + wed_sales wed_sales1, + thu_sales thu_sales1, + fri_sales fri_sales1, + sat_sales sat_sales1 + FROM wss, store, date_dim d + WHERE d.d_week_seq = wss.d_week_seq 
AND + ss_store_sk = s_store_sk AND + d_month_seq BETWEEN 1212 AND 1212 + 11) y, + (SELECT + s_store_name s_store_name2, + wss.d_week_seq d_week_seq2, + s_store_id s_store_id2, + sun_sales sun_sales2, + mon_sales mon_sales2, + tue_sales tue_sales2, + wed_sales wed_sales2, + thu_sales thu_sales2, + fri_sales fri_sales2, + sat_sales sat_sales2 + FROM wss, store, date_dim d + WHERE d.d_week_seq = wss.d_week_seq AND + ss_store_sk = s_store_sk AND + d_month_seq BETWEEN 1212 + 12 AND 1212 + 23) x +WHERE s_store_id1 = s_store_id2 + AND d_week_seq1 = d_week_seq2 - 52 +ORDER BY s_store_name1, s_store_id1, d_week_seq1 +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q6.sql b/sql/core/src/test/resources/tpcds/q6.sql new file mode 100755 index 0000000000000..f0f5cf05aebda --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q6.sql @@ -0,0 +1,21 @@ +SELECT + a.ca_state state, + count(*) cnt +FROM + customer_address a, customer c, store_sales s, date_dim d, item i +WHERE a.ca_address_sk = c.c_current_addr_sk + AND c.c_customer_sk = s.ss_customer_sk + AND s.ss_sold_date_sk = d.d_date_sk + AND s.ss_item_sk = i.i_item_sk + AND d.d_month_seq = + (SELECT DISTINCT (d_month_seq) + FROM date_dim + WHERE d_year = 2000 AND d_moy = 1) + AND i.i_current_price > 1.2 * + (SELECT avg(j.i_current_price) + FROM item j + WHERE j.i_category = i.i_category) +GROUP BY a.ca_state +HAVING count(*) >= 10 +ORDER BY cnt +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q60.sql b/sql/core/src/test/resources/tpcds/q60.sql new file mode 100755 index 0000000000000..41b963f44ba13 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q60.sql @@ -0,0 +1,62 @@ +WITH ss AS ( + SELECT + i_item_id, + sum(ss_ext_sales_price) total_sales + FROM store_sales, date_dim, customer_address, item + WHERE + i_item_id IN (SELECT i_item_id + FROM item + WHERE i_category IN ('Music')) + AND ss_item_sk = i_item_sk + AND ss_sold_date_sk = d_date_sk + AND d_year = 1998 + AND d_moy = 9 + AND ss_addr_sk = ca_address_sk + AND ca_gmt_offset = -5 + GROUP BY i_item_id), + cs AS ( + SELECT + i_item_id, + sum(cs_ext_sales_price) total_sales + FROM catalog_sales, date_dim, customer_address, item + WHERE + i_item_id IN (SELECT i_item_id + FROM item + WHERE i_category IN ('Music')) + AND cs_item_sk = i_item_sk + AND cs_sold_date_sk = d_date_sk + AND d_year = 1998 + AND d_moy = 9 + AND cs_bill_addr_sk = ca_address_sk + AND ca_gmt_offset = -5 + GROUP BY i_item_id), + ws AS ( + SELECT + i_item_id, + sum(ws_ext_sales_price) total_sales + FROM web_sales, date_dim, customer_address, item + WHERE + i_item_id IN (SELECT i_item_id + FROM item + WHERE i_category IN ('Music')) + AND ws_item_sk = i_item_sk + AND ws_sold_date_sk = d_date_sk + AND d_year = 1998 + AND d_moy = 9 + AND ws_bill_addr_sk = ca_address_sk + AND ca_gmt_offset = -5 + GROUP BY i_item_id) +SELECT + i_item_id, + sum(total_sales) total_sales +FROM (SELECT * + FROM ss + UNION ALL + SELECT * + FROM cs + UNION ALL + SELECT * + FROM ws) tmp1 +GROUP BY i_item_id +ORDER BY i_item_id, total_sales +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q61.sql b/sql/core/src/test/resources/tpcds/q61.sql new file mode 100755 index 0000000000000..b0a872b4b80e1 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q61.sql @@ -0,0 +1,33 @@ +SELECT + promotions, + total, + cast(promotions AS DECIMAL(15, 4)) / cast(total AS DECIMAL(15, 4)) * 100 +FROM + (SELECT sum(ss_ext_sales_price) promotions + FROM store_sales, store, promotion, date_dim, customer, customer_address, item + WHERE ss_sold_date_sk = 
d_date_sk + AND ss_store_sk = s_store_sk + AND ss_promo_sk = p_promo_sk + AND ss_customer_sk = c_customer_sk + AND ca_address_sk = c_current_addr_sk + AND ss_item_sk = i_item_sk + AND ca_gmt_offset = -5 + AND i_category = 'Jewelry' + AND (p_channel_dmail = 'Y' OR p_channel_email = 'Y' OR p_channel_tv = 'Y') + AND s_gmt_offset = -5 + AND d_year = 1998 + AND d_moy = 11) promotional_sales, + (SELECT sum(ss_ext_sales_price) total + FROM store_sales, store, date_dim, customer, customer_address, item + WHERE ss_sold_date_sk = d_date_sk + AND ss_store_sk = s_store_sk + AND ss_customer_sk = c_customer_sk + AND ca_address_sk = c_current_addr_sk + AND ss_item_sk = i_item_sk + AND ca_gmt_offset = -5 + AND i_category = 'Jewelry' + AND s_gmt_offset = -5 + AND d_year = 1998 + AND d_moy = 11) all_sales +ORDER BY promotions, total +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q62.sql b/sql/core/src/test/resources/tpcds/q62.sql new file mode 100755 index 0000000000000..8a414f154bdc8 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q62.sql @@ -0,0 +1,35 @@ +SELECT + substr(w_warehouse_name, 1, 20), + sm_type, + web_name, + sum(CASE WHEN (ws_ship_date_sk - ws_sold_date_sk <= 30) + THEN 1 + ELSE 0 END) AS `30 days `, + sum(CASE WHEN (ws_ship_date_sk - ws_sold_date_sk > 30) AND + (ws_ship_date_sk - ws_sold_date_sk <= 60) + THEN 1 + ELSE 0 END) AS `31 - 60 days `, + sum(CASE WHEN (ws_ship_date_sk - ws_sold_date_sk > 60) AND + (ws_ship_date_sk - ws_sold_date_sk <= 90) + THEN 1 + ELSE 0 END) AS `61 - 90 days `, + sum(CASE WHEN (ws_ship_date_sk - ws_sold_date_sk > 90) AND + (ws_ship_date_sk - ws_sold_date_sk <= 120) + THEN 1 + ELSE 0 END) AS `91 - 120 days `, + sum(CASE WHEN (ws_ship_date_sk - ws_sold_date_sk > 120) + THEN 1 + ELSE 0 END) AS `>120 days ` +FROM + web_sales, warehouse, ship_mode, web_site, date_dim +WHERE + d_month_seq BETWEEN 1200 AND 1200 + 11 + AND ws_ship_date_sk = d_date_sk + AND ws_warehouse_sk = w_warehouse_sk + AND ws_ship_mode_sk = sm_ship_mode_sk + AND ws_web_site_sk = web_site_sk +GROUP BY + substr(w_warehouse_name, 1, 20), sm_type, web_name +ORDER BY + substr(w_warehouse_name, 1, 20), sm_type, web_name +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q63.sql b/sql/core/src/test/resources/tpcds/q63.sql new file mode 100755 index 0000000000000..ef6867e0a9451 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q63.sql @@ -0,0 +1,31 @@ +SELECT * +FROM (SELECT + i_manager_id, + sum(ss_sales_price) sum_sales, + avg(sum(ss_sales_price)) + OVER (PARTITION BY i_manager_id) avg_monthly_sales +FROM item + , store_sales + , date_dim + , store +WHERE ss_item_sk = i_item_sk + AND ss_sold_date_sk = d_date_sk + AND ss_store_sk = s_store_sk + AND d_month_seq IN (1200, 1200 + 1, 1200 + 2, 1200 + 3, 1200 + 4, 1200 + 5, 1200 + 6, 1200 + 7, + 1200 + 8, 1200 + 9, 1200 + 10, 1200 + 11) + AND ((i_category IN ('Books', 'Children', 'Electronics') + AND i_class IN ('personal', 'portable', 'refernece', 'self-help') + AND i_brand IN ('scholaramalgamalg #14', 'scholaramalgamalg #7', + 'exportiunivamalg #9', 'scholaramalgamalg #9')) + OR (i_category IN ('Women', 'Music', 'Men') + AND i_class IN ('accessories', 'classical', 'fragrances', 'pants') + AND i_brand IN ('amalgimporto #1', 'edu packscholar #1', 'exportiimporto #1', + 'importoamalg #1'))) +GROUP BY i_manager_id, d_moy) tmp1 +WHERE CASE WHEN avg_monthly_sales > 0 + THEN abs(sum_sales - avg_monthly_sales) / avg_monthly_sales + ELSE NULL END > 0.1 +ORDER BY i_manager_id + , avg_monthly_sales + , sum_sales +LIMIT 100 diff --git 
a/sql/core/src/test/resources/tpcds/q64.sql b/sql/core/src/test/resources/tpcds/q64.sql new file mode 100755 index 0000000000000..8ec1d31b61afe --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q64.sql @@ -0,0 +1,92 @@ +WITH cs_ui AS +(SELECT + cs_item_sk, + sum(cs_ext_list_price) AS sale, + sum(cr_refunded_cash + cr_reversed_charge + cr_store_credit) AS refund + FROM catalog_sales + , catalog_returns + WHERE cs_item_sk = cr_item_sk + AND cs_order_number = cr_order_number + GROUP BY cs_item_sk + HAVING sum(cs_ext_list_price) > 2 * sum(cr_refunded_cash + cr_reversed_charge + cr_store_credit)), + cross_sales AS + (SELECT + i_product_name product_name, + i_item_sk item_sk, + s_store_name store_name, + s_zip store_zip, + ad1.ca_street_number b_street_number, + ad1.ca_street_name b_streen_name, + ad1.ca_city b_city, + ad1.ca_zip b_zip, + ad2.ca_street_number c_street_number, + ad2.ca_street_name c_street_name, + ad2.ca_city c_city, + ad2.ca_zip c_zip, + d1.d_year AS syear, + d2.d_year AS fsyear, + d3.d_year s2year, + count(*) cnt, + sum(ss_wholesale_cost) s1, + sum(ss_list_price) s2, + sum(ss_coupon_amt) s3 + FROM store_sales, store_returns, cs_ui, date_dim d1, date_dim d2, date_dim d3, + store, customer, customer_demographics cd1, customer_demographics cd2, + promotion, household_demographics hd1, household_demographics hd2, + customer_address ad1, customer_address ad2, income_band ib1, income_band ib2, item + WHERE ss_store_sk = s_store_sk AND + ss_sold_date_sk = d1.d_date_sk AND + ss_customer_sk = c_customer_sk AND + ss_cdemo_sk = cd1.cd_demo_sk AND + ss_hdemo_sk = hd1.hd_demo_sk AND + ss_addr_sk = ad1.ca_address_sk AND + ss_item_sk = i_item_sk AND + ss_item_sk = sr_item_sk AND + ss_ticket_number = sr_ticket_number AND + ss_item_sk = cs_ui.cs_item_sk AND + c_current_cdemo_sk = cd2.cd_demo_sk AND + c_current_hdemo_sk = hd2.hd_demo_sk AND + c_current_addr_sk = ad2.ca_address_sk AND + c_first_sales_date_sk = d2.d_date_sk AND + c_first_shipto_date_sk = d3.d_date_sk AND + ss_promo_sk = p_promo_sk AND + hd1.hd_income_band_sk = ib1.ib_income_band_sk AND + hd2.hd_income_band_sk = ib2.ib_income_band_sk AND + cd1.cd_marital_status <> cd2.cd_marital_status AND + i_color IN ('purple', 'burlywood', 'indian', 'spring', 'floral', 'medium') AND + i_current_price BETWEEN 64 AND 64 + 10 AND + i_current_price BETWEEN 64 + 1 AND 64 + 15 + GROUP BY i_product_name, i_item_sk, s_store_name, s_zip, ad1.ca_street_number, + ad1.ca_street_name, ad1.ca_city, ad1.ca_zip, ad2.ca_street_number, + ad2.ca_street_name, ad2.ca_city, ad2.ca_zip, d1.d_year, d2.d_year, d3.d_year + ) +SELECT + cs1.product_name, + cs1.store_name, + cs1.store_zip, + cs1.b_street_number, + cs1.b_streen_name, + cs1.b_city, + cs1.b_zip, + cs1.c_street_number, + cs1.c_street_name, + cs1.c_city, + cs1.c_zip, + cs1.syear, + cs1.cnt, + cs1.s1, + cs1.s2, + cs1.s3, + cs2.s1, + cs2.s2, + cs2.s3, + cs2.syear, + cs2.cnt +FROM cross_sales cs1, cross_sales cs2 +WHERE cs1.item_sk = cs2.item_sk AND + cs1.syear = 1999 AND + cs2.syear = 1999 + 1 AND + cs2.cnt <= cs1.cnt AND + cs1.store_name = cs2.store_name AND + cs1.store_zip = cs2.store_zip +ORDER BY cs1.product_name, cs1.store_name, cs2.cnt diff --git a/sql/core/src/test/resources/tpcds/q65.sql b/sql/core/src/test/resources/tpcds/q65.sql new file mode 100755 index 0000000000000..aad04be1bcdf0 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q65.sql @@ -0,0 +1,33 @@ +SELECT + s_store_name, + i_item_desc, + sc.revenue, + i_current_price, + i_wholesale_cost, + i_brand +FROM store, item, + (SELECT + 
ss_store_sk, + avg(revenue) AS ave + FROM + (SELECT + ss_store_sk, + ss_item_sk, + sum(ss_sales_price) AS revenue + FROM store_sales, date_dim + WHERE ss_sold_date_sk = d_date_sk AND d_month_seq BETWEEN 1176 AND 1176 + 11 + GROUP BY ss_store_sk, ss_item_sk) sa + GROUP BY ss_store_sk) sb, + (SELECT + ss_store_sk, + ss_item_sk, + sum(ss_sales_price) AS revenue + FROM store_sales, date_dim + WHERE ss_sold_date_sk = d_date_sk AND d_month_seq BETWEEN 1176 AND 1176 + 11 + GROUP BY ss_store_sk, ss_item_sk) sc +WHERE sb.ss_store_sk = sc.ss_store_sk AND + sc.revenue <= 0.1 * sb.ave AND + s_store_sk = sc.ss_store_sk AND + i_item_sk = sc.ss_item_sk +ORDER BY s_store_name, i_item_desc +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q66.sql b/sql/core/src/test/resources/tpcds/q66.sql new file mode 100755 index 0000000000000..f826b4164372a --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q66.sql @@ -0,0 +1,240 @@ +SELECT + w_warehouse_name, + w_warehouse_sq_ft, + w_city, + w_county, + w_state, + w_country, + ship_carriers, + year, + sum(jan_sales) AS jan_sales, + sum(feb_sales) AS feb_sales, + sum(mar_sales) AS mar_sales, + sum(apr_sales) AS apr_sales, + sum(may_sales) AS may_sales, + sum(jun_sales) AS jun_sales, + sum(jul_sales) AS jul_sales, + sum(aug_sales) AS aug_sales, + sum(sep_sales) AS sep_sales, + sum(oct_sales) AS oct_sales, + sum(nov_sales) AS nov_sales, + sum(dec_sales) AS dec_sales, + sum(jan_sales / w_warehouse_sq_ft) AS jan_sales_per_sq_foot, + sum(feb_sales / w_warehouse_sq_ft) AS feb_sales_per_sq_foot, + sum(mar_sales / w_warehouse_sq_ft) AS mar_sales_per_sq_foot, + sum(apr_sales / w_warehouse_sq_ft) AS apr_sales_per_sq_foot, + sum(may_sales / w_warehouse_sq_ft) AS may_sales_per_sq_foot, + sum(jun_sales / w_warehouse_sq_ft) AS jun_sales_per_sq_foot, + sum(jul_sales / w_warehouse_sq_ft) AS jul_sales_per_sq_foot, + sum(aug_sales / w_warehouse_sq_ft) AS aug_sales_per_sq_foot, + sum(sep_sales / w_warehouse_sq_ft) AS sep_sales_per_sq_foot, + sum(oct_sales / w_warehouse_sq_ft) AS oct_sales_per_sq_foot, + sum(nov_sales / w_warehouse_sq_ft) AS nov_sales_per_sq_foot, + sum(dec_sales / w_warehouse_sq_ft) AS dec_sales_per_sq_foot, + sum(jan_net) AS jan_net, + sum(feb_net) AS feb_net, + sum(mar_net) AS mar_net, + sum(apr_net) AS apr_net, + sum(may_net) AS may_net, + sum(jun_net) AS jun_net, + sum(jul_net) AS jul_net, + sum(aug_net) AS aug_net, + sum(sep_net) AS sep_net, + sum(oct_net) AS oct_net, + sum(nov_net) AS nov_net, + sum(dec_net) AS dec_net +FROM ( + (SELECT + w_warehouse_name, + w_warehouse_sq_ft, + w_city, + w_county, + w_state, + w_country, + concat('DHL', ',', 'BARIAN') AS ship_carriers, + d_year AS year, + sum(CASE WHEN d_moy = 1 + THEN ws_ext_sales_price * ws_quantity + ELSE 0 END) AS jan_sales, + sum(CASE WHEN d_moy = 2 + THEN ws_ext_sales_price * ws_quantity + ELSE 0 END) AS feb_sales, + sum(CASE WHEN d_moy = 3 + THEN ws_ext_sales_price * ws_quantity + ELSE 0 END) AS mar_sales, + sum(CASE WHEN d_moy = 4 + THEN ws_ext_sales_price * ws_quantity + ELSE 0 END) AS apr_sales, + sum(CASE WHEN d_moy = 5 + THEN ws_ext_sales_price * ws_quantity + ELSE 0 END) AS may_sales, + sum(CASE WHEN d_moy = 6 + THEN ws_ext_sales_price * ws_quantity + ELSE 0 END) AS jun_sales, + sum(CASE WHEN d_moy = 7 + THEN ws_ext_sales_price * ws_quantity + ELSE 0 END) AS jul_sales, + sum(CASE WHEN d_moy = 8 + THEN ws_ext_sales_price * ws_quantity + ELSE 0 END) AS aug_sales, + sum(CASE WHEN d_moy = 9 + THEN ws_ext_sales_price * ws_quantity + ELSE 0 END) AS sep_sales, + sum(CASE WHEN d_moy = 10 + THEN 
ws_ext_sales_price * ws_quantity + ELSE 0 END) AS oct_sales, + sum(CASE WHEN d_moy = 11 + THEN ws_ext_sales_price * ws_quantity + ELSE 0 END) AS nov_sales, + sum(CASE WHEN d_moy = 12 + THEN ws_ext_sales_price * ws_quantity + ELSE 0 END) AS dec_sales, + sum(CASE WHEN d_moy = 1 + THEN ws_net_paid * ws_quantity + ELSE 0 END) AS jan_net, + sum(CASE WHEN d_moy = 2 + THEN ws_net_paid * ws_quantity + ELSE 0 END) AS feb_net, + sum(CASE WHEN d_moy = 3 + THEN ws_net_paid * ws_quantity + ELSE 0 END) AS mar_net, + sum(CASE WHEN d_moy = 4 + THEN ws_net_paid * ws_quantity + ELSE 0 END) AS apr_net, + sum(CASE WHEN d_moy = 5 + THEN ws_net_paid * ws_quantity + ELSE 0 END) AS may_net, + sum(CASE WHEN d_moy = 6 + THEN ws_net_paid * ws_quantity + ELSE 0 END) AS jun_net, + sum(CASE WHEN d_moy = 7 + THEN ws_net_paid * ws_quantity + ELSE 0 END) AS jul_net, + sum(CASE WHEN d_moy = 8 + THEN ws_net_paid * ws_quantity + ELSE 0 END) AS aug_net, + sum(CASE WHEN d_moy = 9 + THEN ws_net_paid * ws_quantity + ELSE 0 END) AS sep_net, + sum(CASE WHEN d_moy = 10 + THEN ws_net_paid * ws_quantity + ELSE 0 END) AS oct_net, + sum(CASE WHEN d_moy = 11 + THEN ws_net_paid * ws_quantity + ELSE 0 END) AS nov_net, + sum(CASE WHEN d_moy = 12 + THEN ws_net_paid * ws_quantity + ELSE 0 END) AS dec_net + FROM + web_sales, warehouse, date_dim, time_dim, ship_mode + WHERE + ws_warehouse_sk = w_warehouse_sk + AND ws_sold_date_sk = d_date_sk + AND ws_sold_time_sk = t_time_sk + AND ws_ship_mode_sk = sm_ship_mode_sk + AND d_year = 2001 + AND t_time BETWEEN 30838 AND 30838 + 28800 + AND sm_carrier IN ('DHL', 'BARIAN') + GROUP BY + w_warehouse_name, w_warehouse_sq_ft, w_city, w_county, w_state, w_country, d_year) + UNION ALL + (SELECT + w_warehouse_name, + w_warehouse_sq_ft, + w_city, + w_county, + w_state, + w_country, + concat('DHL', ',', 'BARIAN') AS ship_carriers, + d_year AS year, + sum(CASE WHEN d_moy = 1 + THEN cs_sales_price * cs_quantity + ELSE 0 END) AS jan_sales, + sum(CASE WHEN d_moy = 2 + THEN cs_sales_price * cs_quantity + ELSE 0 END) AS feb_sales, + sum(CASE WHEN d_moy = 3 + THEN cs_sales_price * cs_quantity + ELSE 0 END) AS mar_sales, + sum(CASE WHEN d_moy = 4 + THEN cs_sales_price * cs_quantity + ELSE 0 END) AS apr_sales, + sum(CASE WHEN d_moy = 5 + THEN cs_sales_price * cs_quantity + ELSE 0 END) AS may_sales, + sum(CASE WHEN d_moy = 6 + THEN cs_sales_price * cs_quantity + ELSE 0 END) AS jun_sales, + sum(CASE WHEN d_moy = 7 + THEN cs_sales_price * cs_quantity + ELSE 0 END) AS jul_sales, + sum(CASE WHEN d_moy = 8 + THEN cs_sales_price * cs_quantity + ELSE 0 END) AS aug_sales, + sum(CASE WHEN d_moy = 9 + THEN cs_sales_price * cs_quantity + ELSE 0 END) AS sep_sales, + sum(CASE WHEN d_moy = 10 + THEN cs_sales_price * cs_quantity + ELSE 0 END) AS oct_sales, + sum(CASE WHEN d_moy = 11 + THEN cs_sales_price * cs_quantity + ELSE 0 END) AS nov_sales, + sum(CASE WHEN d_moy = 12 + THEN cs_sales_price * cs_quantity + ELSE 0 END) AS dec_sales, + sum(CASE WHEN d_moy = 1 + THEN cs_net_paid_inc_tax * cs_quantity + ELSE 0 END) AS jan_net, + sum(CASE WHEN d_moy = 2 + THEN cs_net_paid_inc_tax * cs_quantity + ELSE 0 END) AS feb_net, + sum(CASE WHEN d_moy = 3 + THEN cs_net_paid_inc_tax * cs_quantity + ELSE 0 END) AS mar_net, + sum(CASE WHEN d_moy = 4 + THEN cs_net_paid_inc_tax * cs_quantity + ELSE 0 END) AS apr_net, + sum(CASE WHEN d_moy = 5 + THEN cs_net_paid_inc_tax * cs_quantity + ELSE 0 END) AS may_net, + sum(CASE WHEN d_moy = 6 + THEN cs_net_paid_inc_tax * cs_quantity + ELSE 0 END) AS jun_net, + sum(CASE WHEN d_moy = 7 + THEN cs_net_paid_inc_tax 
* cs_quantity + ELSE 0 END) AS jul_net, + sum(CASE WHEN d_moy = 8 + THEN cs_net_paid_inc_tax * cs_quantity + ELSE 0 END) AS aug_net, + sum(CASE WHEN d_moy = 9 + THEN cs_net_paid_inc_tax * cs_quantity + ELSE 0 END) AS sep_net, + sum(CASE WHEN d_moy = 10 + THEN cs_net_paid_inc_tax * cs_quantity + ELSE 0 END) AS oct_net, + sum(CASE WHEN d_moy = 11 + THEN cs_net_paid_inc_tax * cs_quantity + ELSE 0 END) AS nov_net, + sum(CASE WHEN d_moy = 12 + THEN cs_net_paid_inc_tax * cs_quantity + ELSE 0 END) AS dec_net + FROM + catalog_sales, warehouse, date_dim, time_dim, ship_mode + WHERE + cs_warehouse_sk = w_warehouse_sk + AND cs_sold_date_sk = d_date_sk + AND cs_sold_time_sk = t_time_sk + AND cs_ship_mode_sk = sm_ship_mode_sk + AND d_year = 2001 + AND t_time BETWEEN 30838 AND 30838 + 28800 + AND sm_carrier IN ('DHL', 'BARIAN') + GROUP BY + w_warehouse_name, w_warehouse_sq_ft, w_city, w_county, w_state, w_country, d_year + ) + ) x +GROUP BY + w_warehouse_name, w_warehouse_sq_ft, w_city, w_county, w_state, w_country, + ship_carriers, year +ORDER BY w_warehouse_name +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q67.sql b/sql/core/src/test/resources/tpcds/q67.sql new file mode 100755 index 0000000000000..f66e2252bdbd4 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q67.sql @@ -0,0 +1,38 @@ +SELECT * +FROM + (SELECT + i_category, + i_class, + i_brand, + i_product_name, + d_year, + d_qoy, + d_moy, + s_store_id, + sumsales, + rank() + OVER (PARTITION BY i_category + ORDER BY sumsales DESC) rk + FROM + (SELECT + i_category, + i_class, + i_brand, + i_product_name, + d_year, + d_qoy, + d_moy, + s_store_id, + sum(coalesce(ss_sales_price * ss_quantity, 0)) sumsales + FROM store_sales, date_dim, store, item + WHERE ss_sold_date_sk = d_date_sk + AND ss_item_sk = i_item_sk + AND ss_store_sk = s_store_sk + AND d_month_seq BETWEEN 1200 AND 1200 + 11 + GROUP BY ROLLUP (i_category, i_class, i_brand, i_product_name, d_year, d_qoy, + d_moy, s_store_id)) dw1) dw2 +WHERE rk <= 100 +ORDER BY + i_category, i_class, i_brand, i_product_name, d_year, + d_qoy, d_moy, s_store_id, sumsales, rk +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q68.sql b/sql/core/src/test/resources/tpcds/q68.sql new file mode 100755 index 0000000000000..adb8a7189dad7 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q68.sql @@ -0,0 +1,34 @@ +SELECT + c_last_name, + c_first_name, + ca_city, + bought_city, + ss_ticket_number, + extended_price, + extended_tax, + list_price +FROM (SELECT + ss_ticket_number, + ss_customer_sk, + ca_city bought_city, + sum(ss_ext_sales_price) extended_price, + sum(ss_ext_list_price) list_price, + sum(ss_ext_tax) extended_tax +FROM store_sales, date_dim, store, household_demographics, customer_address +WHERE store_sales.ss_sold_date_sk = date_dim.d_date_sk + AND store_sales.ss_store_sk = store.s_store_sk + AND store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk + AND store_sales.ss_addr_sk = customer_address.ca_address_sk + AND date_dim.d_dom BETWEEN 1 AND 2 + AND (household_demographics.hd_dep_count = 4 OR + household_demographics.hd_vehicle_count = 3) + AND date_dim.d_year IN (1999, 1999 + 1, 1999 + 2) + AND store.s_city IN ('Midway', 'Fairview') +GROUP BY ss_ticket_number, ss_customer_sk, ss_addr_sk, ca_city) dn, + customer, + customer_address current_addr +WHERE ss_customer_sk = c_customer_sk + AND customer.c_current_addr_sk = current_addr.ca_address_sk + AND current_addr.ca_city <> bought_city +ORDER BY c_last_name, ss_ticket_number +LIMIT 100 diff --git 
a/sql/core/src/test/resources/tpcds/q69.sql b/sql/core/src/test/resources/tpcds/q69.sql new file mode 100755 index 0000000000000..1f0ee64f565a3 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q69.sql @@ -0,0 +1,38 @@ +SELECT + cd_gender, + cd_marital_status, + cd_education_status, + count(*) cnt1, + cd_purchase_estimate, + count(*) cnt2, + cd_credit_rating, + count(*) cnt3 +FROM + customer c, customer_address ca, customer_demographics +WHERE + c.c_current_addr_sk = ca.ca_address_sk AND + ca_state IN ('KY', 'GA', 'NM') AND + cd_demo_sk = c.c_current_cdemo_sk AND + exists(SELECT * + FROM store_sales, date_dim + WHERE c.c_customer_sk = ss_customer_sk AND + ss_sold_date_sk = d_date_sk AND + d_year = 2001 AND + d_moy BETWEEN 4 AND 4 + 2) AND + (NOT exists(SELECT * + FROM web_sales, date_dim + WHERE c.c_customer_sk = ws_bill_customer_sk AND + ws_sold_date_sk = d_date_sk AND + d_year = 2001 AND + d_moy BETWEEN 4 AND 4 + 2) AND + NOT exists(SELECT * + FROM catalog_sales, date_dim + WHERE c.c_customer_sk = cs_ship_customer_sk AND + cs_sold_date_sk = d_date_sk AND + d_year = 2001 AND + d_moy BETWEEN 4 AND 4 + 2)) +GROUP BY cd_gender, cd_marital_status, cd_education_status, + cd_purchase_estimate, cd_credit_rating +ORDER BY cd_gender, cd_marital_status, cd_education_status, + cd_purchase_estimate, cd_credit_rating +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q7.sql b/sql/core/src/test/resources/tpcds/q7.sql new file mode 100755 index 0000000000000..6630a00548403 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q7.sql @@ -0,0 +1,19 @@ +SELECT + i_item_id, + avg(ss_quantity) agg1, + avg(ss_list_price) agg2, + avg(ss_coupon_amt) agg3, + avg(ss_sales_price) agg4 +FROM store_sales, customer_demographics, date_dim, item, promotion +WHERE ss_sold_date_sk = d_date_sk AND + ss_item_sk = i_item_sk AND + ss_cdemo_sk = cd_demo_sk AND + ss_promo_sk = p_promo_sk AND + cd_gender = 'M' AND + cd_marital_status = 'S' AND + cd_education_status = 'College' AND + (p_channel_email = 'N' OR p_channel_event = 'N') AND + d_year = 2000 +GROUP BY i_item_id +ORDER BY i_item_id +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q70.sql b/sql/core/src/test/resources/tpcds/q70.sql new file mode 100755 index 0000000000000..625011b212fe0 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q70.sql @@ -0,0 +1,38 @@ +SELECT + sum(ss_net_profit) AS total_sum, + s_state, + s_county, + grouping(s_state) + grouping(s_county) AS lochierarchy, + rank() + OVER ( + PARTITION BY grouping(s_state) + grouping(s_county), + CASE WHEN grouping(s_county) = 0 + THEN s_state END + ORDER BY sum(ss_net_profit) DESC) AS rank_within_parent +FROM + store_sales, date_dim d1, store +WHERE + d1.d_month_seq BETWEEN 1200 AND 1200 + 11 + AND d1.d_date_sk = ss_sold_date_sk + AND s_store_sk = ss_store_sk + AND s_state IN + (SELECT s_state + FROM + (SELECT + s_state AS s_state, + rank() + OVER (PARTITION BY s_state + ORDER BY sum(ss_net_profit) DESC) AS ranking + FROM store_sales, store, date_dim + WHERE d_month_seq BETWEEN 1200 AND 1200 + 11 + AND d_date_sk = ss_sold_date_sk + AND s_store_sk = ss_store_sk + GROUP BY s_state) tmp1 + WHERE ranking <= 5) +GROUP BY ROLLUP (s_state, s_county) +ORDER BY + lochierarchy DESC + , CASE WHEN lochierarchy = 0 + THEN s_state END + , rank_within_parent +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q71.sql b/sql/core/src/test/resources/tpcds/q71.sql new file mode 100755 index 0000000000000..8d724b9244e11 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q71.sql @@ -0,0 +1,44 @@ 
+SELECT + i_brand_id brand_id, + i_brand brand, + t_hour, + t_minute, + sum(ext_price) ext_price +FROM item, + (SELECT + ws_ext_sales_price AS ext_price, + ws_sold_date_sk AS sold_date_sk, + ws_item_sk AS sold_item_sk, + ws_sold_time_sk AS time_sk + FROM web_sales, date_dim + WHERE d_date_sk = ws_sold_date_sk + AND d_moy = 11 + AND d_year = 1999 + UNION ALL + SELECT + cs_ext_sales_price AS ext_price, + cs_sold_date_sk AS sold_date_sk, + cs_item_sk AS sold_item_sk, + cs_sold_time_sk AS time_sk + FROM catalog_sales, date_dim + WHERE d_date_sk = cs_sold_date_sk + AND d_moy = 11 + AND d_year = 1999 + UNION ALL + SELECT + ss_ext_sales_price AS ext_price, + ss_sold_date_sk AS sold_date_sk, + ss_item_sk AS sold_item_sk, + ss_sold_time_sk AS time_sk + FROM store_sales, date_dim + WHERE d_date_sk = ss_sold_date_sk + AND d_moy = 11 + AND d_year = 1999 + ) AS tmp, time_dim +WHERE + sold_item_sk = i_item_sk + AND i_manager_id = 1 + AND time_sk = t_time_sk + AND (t_meal_time = 'breakfast' OR t_meal_time = 'dinner') +GROUP BY i_brand, i_brand_id, t_hour, t_minute +ORDER BY ext_price DESC, brand_id diff --git a/sql/core/src/test/resources/tpcds/q72.sql b/sql/core/src/test/resources/tpcds/q72.sql new file mode 100755 index 0000000000000..99b3eee54aa1a --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q72.sql @@ -0,0 +1,33 @@ +SELECT + i_item_desc, + w_warehouse_name, + d1.d_week_seq, + count(CASE WHEN p_promo_sk IS NULL + THEN 1 + ELSE 0 END) no_promo, + count(CASE WHEN p_promo_sk IS NOT NULL + THEN 1 + ELSE 0 END) promo, + count(*) total_cnt +FROM catalog_sales + JOIN inventory ON (cs_item_sk = inv_item_sk) + JOIN warehouse ON (w_warehouse_sk = inv_warehouse_sk) + JOIN item ON (i_item_sk = cs_item_sk) + JOIN customer_demographics ON (cs_bill_cdemo_sk = cd_demo_sk) + JOIN household_demographics ON (cs_bill_hdemo_sk = hd_demo_sk) + JOIN date_dim d1 ON (cs_sold_date_sk = d1.d_date_sk) + JOIN date_dim d2 ON (inv_date_sk = d2.d_date_sk) + JOIN date_dim d3 ON (cs_ship_date_sk = d3.d_date_sk) + LEFT OUTER JOIN promotion ON (cs_promo_sk = p_promo_sk) + LEFT OUTER JOIN catalog_returns ON (cr_item_sk = cs_item_sk AND cr_order_number = cs_order_number) +WHERE d1.d_week_seq = d2.d_week_seq + AND inv_quantity_on_hand < cs_quantity + AND d3.d_date > (cast(d1.d_date AS DATE) + interval 5 days) + AND hd_buy_potential = '>10000' + AND d1.d_year = 1999 + AND hd_buy_potential = '>10000' + AND cd_marital_status = 'D' + AND d1.d_year = 1999 +GROUP BY i_item_desc, w_warehouse_name, d1.d_week_seq +ORDER BY total_cnt DESC, i_item_desc, w_warehouse_name, d_week_seq +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q73.sql b/sql/core/src/test/resources/tpcds/q73.sql new file mode 100755 index 0000000000000..881be2e9024d7 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q73.sql @@ -0,0 +1,30 @@ +SELECT + c_last_name, + c_first_name, + c_salutation, + c_preferred_cust_flag, + ss_ticket_number, + cnt +FROM + (SELECT + ss_ticket_number, + ss_customer_sk, + count(*) cnt + FROM store_sales, date_dim, store, household_demographics + WHERE store_sales.ss_sold_date_sk = date_dim.d_date_sk + AND store_sales.ss_store_sk = store.s_store_sk + AND store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk + AND date_dim.d_dom BETWEEN 1 AND 2 + AND (household_demographics.hd_buy_potential = '>10000' OR + household_demographics.hd_buy_potential = 'unknown') + AND household_demographics.hd_vehicle_count > 0 + AND CASE WHEN household_demographics.hd_vehicle_count > 0 + THEN + household_demographics.hd_dep_count / 
household_demographics.hd_vehicle_count + ELSE NULL END > 1 + AND date_dim.d_year IN (1999, 1999 + 1, 1999 + 2) + AND store.s_county IN ('Williamson County', 'Franklin Parish', 'Bronx County', 'Orange County') + GROUP BY ss_ticket_number, ss_customer_sk) dj, customer +WHERE ss_customer_sk = c_customer_sk + AND cnt BETWEEN 1 AND 5 +ORDER BY cnt DESC diff --git a/sql/core/src/test/resources/tpcds/q74.sql b/sql/core/src/test/resources/tpcds/q74.sql new file mode 100755 index 0000000000000..154b26d6802a3 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q74.sql @@ -0,0 +1,58 @@ +WITH year_total AS ( + SELECT + c_customer_id customer_id, + c_first_name customer_first_name, + c_last_name customer_last_name, + d_year AS year, + sum(ss_net_paid) year_total, + 's' sale_type + FROM + customer, store_sales, date_dim + WHERE c_customer_sk = ss_customer_sk + AND ss_sold_date_sk = d_date_sk + AND d_year IN (2001, 2001 + 1) + GROUP BY + c_customer_id, c_first_name, c_last_name, d_year + UNION ALL + SELECT + c_customer_id customer_id, + c_first_name customer_first_name, + c_last_name customer_last_name, + d_year AS year, + sum(ws_net_paid) year_total, + 'w' sale_type + FROM + customer, web_sales, date_dim + WHERE c_customer_sk = ws_bill_customer_sk + AND ws_sold_date_sk = d_date_sk + AND d_year IN (2001, 2001 + 1) + GROUP BY + c_customer_id, c_first_name, c_last_name, d_year) +SELECT + t_s_secyear.customer_id, + t_s_secyear.customer_first_name, + t_s_secyear.customer_last_name +FROM + year_total t_s_firstyear, year_total t_s_secyear, + year_total t_w_firstyear, year_total t_w_secyear +WHERE t_s_secyear.customer_id = t_s_firstyear.customer_id + AND t_s_firstyear.customer_id = t_w_secyear.customer_id + AND t_s_firstyear.customer_id = t_w_firstyear.customer_id + AND t_s_firstyear.sale_type = 's' + AND t_w_firstyear.sale_type = 'w' + AND t_s_secyear.sale_type = 's' + AND t_w_secyear.sale_type = 'w' + AND t_s_firstyear.year = 2001 + AND t_s_secyear.year = 2001 + 1 + AND t_w_firstyear.year = 2001 + AND t_w_secyear.year = 2001 + 1 + AND t_s_firstyear.year_total > 0 + AND t_w_firstyear.year_total > 0 + AND CASE WHEN t_w_firstyear.year_total > 0 + THEN t_w_secyear.year_total / t_w_firstyear.year_total + ELSE NULL END + > CASE WHEN t_s_firstyear.year_total > 0 + THEN t_s_secyear.year_total / t_s_firstyear.year_total + ELSE NULL END +ORDER BY 1, 1, 1 +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q75.sql b/sql/core/src/test/resources/tpcds/q75.sql new file mode 100755 index 0000000000000..2a143232b5196 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q75.sql @@ -0,0 +1,76 @@ +WITH all_sales AS ( + SELECT + d_year, + i_brand_id, + i_class_id, + i_category_id, + i_manufact_id, + SUM(sales_cnt) AS sales_cnt, + SUM(sales_amt) AS sales_amt + FROM ( + SELECT + d_year, + i_brand_id, + i_class_id, + i_category_id, + i_manufact_id, + cs_quantity - COALESCE(cr_return_quantity, 0) AS sales_cnt, + cs_ext_sales_price - COALESCE(cr_return_amount, 0.0) AS sales_amt + FROM catalog_sales + JOIN item ON i_item_sk = cs_item_sk + JOIN date_dim ON d_date_sk = cs_sold_date_sk + LEFT JOIN catalog_returns ON (cs_order_number = cr_order_number + AND cs_item_sk = cr_item_sk) + WHERE i_category = 'Books' + UNION + SELECT + d_year, + i_brand_id, + i_class_id, + i_category_id, + i_manufact_id, + ss_quantity - COALESCE(sr_return_quantity, 0) AS sales_cnt, + ss_ext_sales_price - COALESCE(sr_return_amt, 0.0) AS sales_amt + FROM store_sales + JOIN item ON i_item_sk = ss_item_sk + JOIN date_dim ON d_date_sk = ss_sold_date_sk + 
LEFT JOIN store_returns ON (ss_ticket_number = sr_ticket_number + AND ss_item_sk = sr_item_sk) + WHERE i_category = 'Books' + UNION + SELECT + d_year, + i_brand_id, + i_class_id, + i_category_id, + i_manufact_id, + ws_quantity - COALESCE(wr_return_quantity, 0) AS sales_cnt, + ws_ext_sales_price - COALESCE(wr_return_amt, 0.0) AS sales_amt + FROM web_sales + JOIN item ON i_item_sk = ws_item_sk + JOIN date_dim ON d_date_sk = ws_sold_date_sk + LEFT JOIN web_returns ON (ws_order_number = wr_order_number + AND ws_item_sk = wr_item_sk) + WHERE i_category = 'Books') sales_detail + GROUP BY d_year, i_brand_id, i_class_id, i_category_id, i_manufact_id) +SELECT + prev_yr.d_year AS prev_year, + curr_yr.d_year AS year, + curr_yr.i_brand_id, + curr_yr.i_class_id, + curr_yr.i_category_id, + curr_yr.i_manufact_id, + prev_yr.sales_cnt AS prev_yr_cnt, + curr_yr.sales_cnt AS curr_yr_cnt, + curr_yr.sales_cnt - prev_yr.sales_cnt AS sales_cnt_diff, + curr_yr.sales_amt - prev_yr.sales_amt AS sales_amt_diff +FROM all_sales curr_yr, all_sales prev_yr +WHERE curr_yr.i_brand_id = prev_yr.i_brand_id + AND curr_yr.i_class_id = prev_yr.i_class_id + AND curr_yr.i_category_id = prev_yr.i_category_id + AND curr_yr.i_manufact_id = prev_yr.i_manufact_id + AND curr_yr.d_year = 2002 + AND prev_yr.d_year = 2002 - 1 + AND CAST(curr_yr.sales_cnt AS DECIMAL(17, 2)) / CAST(prev_yr.sales_cnt AS DECIMAL(17, 2)) < 0.9 +ORDER BY sales_cnt_diff +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q76.sql b/sql/core/src/test/resources/tpcds/q76.sql new file mode 100755 index 0000000000000..815fa922be19d --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q76.sql @@ -0,0 +1,47 @@ +SELECT + channel, + col_name, + d_year, + d_qoy, + i_category, + COUNT(*) sales_cnt, + SUM(ext_sales_price) sales_amt +FROM ( + SELECT + 'store' AS channel, + ss_store_sk col_name, + d_year, + d_qoy, + i_category, + ss_ext_sales_price ext_sales_price + FROM store_sales, item, date_dim + WHERE ss_store_sk IS NULL + AND ss_sold_date_sk = d_date_sk + AND ss_item_sk = i_item_sk + UNION ALL + SELECT + 'web' AS channel, + ws_ship_customer_sk col_name, + d_year, + d_qoy, + i_category, + ws_ext_sales_price ext_sales_price + FROM web_sales, item, date_dim + WHERE ws_ship_customer_sk IS NULL + AND ws_sold_date_sk = d_date_sk + AND ws_item_sk = i_item_sk + UNION ALL + SELECT + 'catalog' AS channel, + cs_ship_addr_sk col_name, + d_year, + d_qoy, + i_category, + cs_ext_sales_price ext_sales_price + FROM catalog_sales, item, date_dim + WHERE cs_ship_addr_sk IS NULL + AND cs_sold_date_sk = d_date_sk + AND cs_item_sk = i_item_sk) foo +GROUP BY channel, col_name, d_year, d_qoy, i_category +ORDER BY channel, col_name, d_year, d_qoy, i_category +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q77.sql b/sql/core/src/test/resources/tpcds/q77.sql new file mode 100755 index 0000000000000..7830f96e76515 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q77.sql @@ -0,0 +1,100 @@ +WITH ss AS +(SELECT + s_store_sk, + sum(ss_ext_sales_price) AS sales, + sum(ss_net_profit) AS profit + FROM store_sales, date_dim, store + WHERE ss_sold_date_sk = d_date_sk + AND d_date BETWEEN cast('2000-08-03' AS DATE) AND + (cast('2000-08-03' AS DATE) + INTERVAL 30 days) + AND ss_store_sk = s_store_sk + GROUP BY s_store_sk), + sr AS + (SELECT + s_store_sk, + sum(sr_return_amt) AS returns, + sum(sr_net_loss) AS profit_loss + FROM store_returns, date_dim, store + WHERE sr_returned_date_sk = d_date_sk + AND d_date BETWEEN cast('2000-08-03' AS DATE) AND + (cast('2000-08-03' AS DATE) + 
INTERVAL 30 days) + AND sr_store_sk = s_store_sk + GROUP BY s_store_sk), + cs AS + (SELECT + cs_call_center_sk, + sum(cs_ext_sales_price) AS sales, + sum(cs_net_profit) AS profit + FROM catalog_sales, date_dim + WHERE cs_sold_date_sk = d_date_sk + AND d_date BETWEEN cast('2000-08-03' AS DATE) AND + (cast('2000-08-03' AS DATE) + INTERVAL 30 days) + GROUP BY cs_call_center_sk), + cr AS + (SELECT + sum(cr_return_amount) AS returns, + sum(cr_net_loss) AS profit_loss + FROM catalog_returns, date_dim + WHERE cr_returned_date_sk = d_date_sk + AND d_date BETWEEN cast('2000-08-03' AS DATE) AND + (cast('2000-08-03' AS DATE) + INTERVAL 30 days)), + ws AS + (SELECT + wp_web_page_sk, + sum(ws_ext_sales_price) AS sales, + sum(ws_net_profit) AS profit + FROM web_sales, date_dim, web_page + WHERE ws_sold_date_sk = d_date_sk + AND d_date BETWEEN cast('2000-08-03' AS DATE) AND + (cast('2000-08-03' AS DATE) + INTERVAL 30 days) + AND ws_web_page_sk = wp_web_page_sk + GROUP BY wp_web_page_sk), + wr AS + (SELECT + wp_web_page_sk, + sum(wr_return_amt) AS returns, + sum(wr_net_loss) AS profit_loss + FROM web_returns, date_dim, web_page + WHERE wr_returned_date_sk = d_date_sk + AND d_date BETWEEN cast('2000-08-03' AS DATE) AND + (cast('2000-08-03' AS DATE) + INTERVAL 30 days) + AND wr_web_page_sk = wp_web_page_sk + GROUP BY wp_web_page_sk) +SELECT + channel, + id, + sum(sales) AS sales, + sum(returns) AS returns, + sum(profit) AS profit +FROM + (SELECT + 'store channel' AS channel, + ss.s_store_sk AS id, + sales, + coalesce(returns, 0) AS returns, + (profit - coalesce(profit_loss, 0)) AS profit + FROM ss + LEFT JOIN sr + ON ss.s_store_sk = sr.s_store_sk + UNION ALL + SELECT + 'catalog channel' AS channel, + cs_call_center_sk AS id, + sales, + returns, + (profit - profit_loss) AS profit + FROM cs, cr + UNION ALL + SELECT + 'web channel' AS channel, + ws.wp_web_page_sk AS id, + sales, + coalesce(returns, 0) returns, + (profit - coalesce(profit_loss, 0)) AS profit + FROM ws + LEFT JOIN wr + ON ws.wp_web_page_sk = wr.wp_web_page_sk + ) x +GROUP BY ROLLUP (channel, id) +ORDER BY channel, id +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q78.sql b/sql/core/src/test/resources/tpcds/q78.sql new file mode 100755 index 0000000000000..07b0940e26882 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q78.sql @@ -0,0 +1,64 @@ +WITH ws AS +(SELECT + d_year AS ws_sold_year, + ws_item_sk, + ws_bill_customer_sk ws_customer_sk, + sum(ws_quantity) ws_qty, + sum(ws_wholesale_cost) ws_wc, + sum(ws_sales_price) ws_sp + FROM web_sales + LEFT JOIN web_returns ON wr_order_number = ws_order_number AND ws_item_sk = wr_item_sk + JOIN date_dim ON ws_sold_date_sk = d_date_sk + WHERE wr_order_number IS NULL + GROUP BY d_year, ws_item_sk, ws_bill_customer_sk +), + cs AS + (SELECT + d_year AS cs_sold_year, + cs_item_sk, + cs_bill_customer_sk cs_customer_sk, + sum(cs_quantity) cs_qty, + sum(cs_wholesale_cost) cs_wc, + sum(cs_sales_price) cs_sp + FROM catalog_sales + LEFT JOIN catalog_returns ON cr_order_number = cs_order_number AND cs_item_sk = cr_item_sk + JOIN date_dim ON cs_sold_date_sk = d_date_sk + WHERE cr_order_number IS NULL + GROUP BY d_year, cs_item_sk, cs_bill_customer_sk + ), + ss AS + (SELECT + d_year AS ss_sold_year, + ss_item_sk, + ss_customer_sk, + sum(ss_quantity) ss_qty, + sum(ss_wholesale_cost) ss_wc, + sum(ss_sales_price) ss_sp + FROM store_sales + LEFT JOIN store_returns ON sr_ticket_number = ss_ticket_number AND ss_item_sk = sr_item_sk + JOIN date_dim ON ss_sold_date_sk = d_date_sk + WHERE sr_ticket_number IS
NULL + GROUP BY d_year, ss_item_sk, ss_customer_sk + ) +SELECT + round(ss_qty / (coalesce(ws_qty + cs_qty, 1)), 2) ratio, + ss_qty store_qty, + ss_wc store_wholesale_cost, + ss_sp store_sales_price, + coalesce(ws_qty, 0) + coalesce(cs_qty, 0) other_chan_qty, + coalesce(ws_wc, 0) + coalesce(cs_wc, 0) other_chan_wholesale_cost, + coalesce(ws_sp, 0) + coalesce(cs_sp, 0) other_chan_sales_price +FROM ss + LEFT JOIN ws + ON (ws_sold_year = ss_sold_year AND ws_item_sk = ss_item_sk AND ws_customer_sk = ss_customer_sk) + LEFT JOIN cs + ON (cs_sold_year = ss_sold_year AND cs_item_sk = ss_item_sk AND cs_customer_sk = ss_customer_sk) +WHERE coalesce(ws_qty, 0) > 0 AND coalesce(cs_qty, 0) > 0 AND ss_sold_year = 2000 +ORDER BY + ratio, + ss_qty DESC, ss_wc DESC, ss_sp DESC, + other_chan_qty, + other_chan_wholesale_cost, + other_chan_sales_price, + round(ss_qty / (coalesce(ws_qty + cs_qty, 1)), 2) +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q79.sql b/sql/core/src/test/resources/tpcds/q79.sql new file mode 100755 index 0000000000000..08f86dc2032aa --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q79.sql @@ -0,0 +1,27 @@ +SELECT + c_last_name, + c_first_name, + substr(s_city, 1, 30), + ss_ticket_number, + amt, + profit +FROM + (SELECT + ss_ticket_number, + ss_customer_sk, + store.s_city, + sum(ss_coupon_amt) amt, + sum(ss_net_profit) profit + FROM store_sales, date_dim, store, household_demographics + WHERE store_sales.ss_sold_date_sk = date_dim.d_date_sk + AND store_sales.ss_store_sk = store.s_store_sk + AND store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk + AND (household_demographics.hd_dep_count = 6 OR + household_demographics.hd_vehicle_count > 2) + AND date_dim.d_dow = 1 + AND date_dim.d_year IN (1999, 1999 + 1, 1999 + 2) + AND store.s_number_employees BETWEEN 200 AND 295 + GROUP BY ss_ticket_number, ss_customer_sk, ss_addr_sk, store.s_city) ms, customer +WHERE ss_customer_sk = c_customer_sk +ORDER BY c_last_name, c_first_name, substr(s_city, 1, 30), profit +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q8.sql b/sql/core/src/test/resources/tpcds/q8.sql new file mode 100755 index 0000000000000..497725111f4f3 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q8.sql @@ -0,0 +1,87 @@ +SELECT + s_store_name, + sum(ss_net_profit) +FROM store_sales, date_dim, store, + (SELECT ca_zip + FROM ( + (SELECT substr(ca_zip, 1, 5) ca_zip + FROM customer_address + WHERE substr(ca_zip, 1, 5) IN ( + '24128','76232','65084','87816','83926','77556','20548', + '26231','43848','15126','91137','61265','98294','25782', + '17920','18426','98235','40081','84093','28577','55565', + '17183','54601','67897','22752','86284','18376','38607', + '45200','21756','29741','96765','23932','89360','29839', + '25989','28898','91068','72550','10390','18845','47770', + '82636','41367','76638','86198','81312','37126','39192', + '88424','72175','81426','53672','10445','42666','66864', + '66708','41248','48583','82276','18842','78890','49448', + '14089','38122','34425','79077','19849','43285','39861', + '66162','77610','13695','99543','83444','83041','12305', + '57665','68341','25003','57834','62878','49130','81096', + '18840','27700','23470','50412','21195','16021','76107', + '71954','68309','18119','98359','64544','10336','86379', + '27068','39736','98569','28915','24206','56529','57647', + '54917','42961','91110','63981','14922','36420','23006', + '67467','32754','30903','20260','31671','51798','72325', + '85816','68621','13955','36446','41766','68806','16725', + 
'15146','22744','35850','88086','51649','18270','52867', + '39972','96976','63792','11376','94898','13595','10516', + '90225','58943','39371','94945','28587','96576','57855', + '28488','26105','83933','25858','34322','44438','73171', + '30122','34102','22685','71256','78451','54364','13354', + '45375','40558','56458','28286','45266','47305','69399', + '83921','26233','11101','15371','69913','35942','15882', + '25631','24610','44165','99076','33786','70738','26653', + '14328','72305','62496','22152','10144','64147','48425', + '14663','21076','18799','30450','63089','81019','68893', + '24996','51200','51211','45692','92712','70466','79994', + '22437','25280','38935','71791','73134','56571','14060', + '19505','72425','56575','74351','68786','51650','20004', + '18383','76614','11634','18906','15765','41368','73241', + '76698','78567','97189','28545','76231','75691','22246', + '51061','90578','56691','68014','51103','94167','57047', + '14867','73520','15734','63435','25733','35474','24676', + '94627','53535','17879','15559','53268','59166','11928', + '59402','33282','45721','43933','68101','33515','36634', + '71286','19736','58058','55253','67473','41918','19515', + '36495','19430','22351','77191','91393','49156','50298', + '87501','18652','53179','18767','63193','23968','65164', + '68880','21286','72823','58470','67301','13394','31016', + '70372','67030','40604','24317','45748','39127','26065', + '77721','31029','31880','60576','24671','45549','13376', + '50016','33123','19769','22927','97789','46081','72151', + '15723','46136','51949','68100','96888','64528','14171', + '79777','28709','11489','25103','32213','78668','22245', + '15798','27156','37930','62971','21337','51622','67853', + '10567','38415','15455','58263','42029','60279','37125', + '56240','88190','50308','26859','64457','89091','82136', + '62377','36233','63837','58078','17043','30010','60099', + '28810','98025','29178','87343','73273','30469','64034', + '39516','86057','21309','90257','67875','40162','11356', + '73650','61810','72013','30431','22461','19512','13375', + '55307','30625','83849','68908','26689','96451','38193', + '46820','88885','84935','69035','83144','47537','56616', + '94983','48033','69952','25486','61547','27385','61860', + '58048','56910','16807','17871','35258','31387','35458', + '35576')) + INTERSECT + (SELECT ca_zip + FROM + (SELECT + substr(ca_zip, 1, 5) ca_zip, + count(*) cnt + FROM customer_address, customer + WHERE ca_address_sk = c_current_addr_sk AND + c_preferred_cust_flag = 'Y' + GROUP BY ca_zip + HAVING count(*) > 10) A1) + ) A2 + ) V1 +WHERE ss_store_sk = s_store_sk + AND ss_sold_date_sk = d_date_sk + AND d_qoy = 2 AND d_year = 1998 + AND (substr(s_zip, 1, 2) = substr(V1.ca_zip, 1, 2)) +GROUP BY s_store_name +ORDER BY s_store_name +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q80.sql b/sql/core/src/test/resources/tpcds/q80.sql new file mode 100755 index 0000000000000..433db87d2a858 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q80.sql @@ -0,0 +1,94 @@ +WITH ssr AS +(SELECT + s_store_id AS store_id, + sum(ss_ext_sales_price) AS sales, + sum(coalesce(sr_return_amt, 0)) AS returns, + sum(ss_net_profit - coalesce(sr_net_loss, 0)) AS profit + FROM store_sales + LEFT OUTER JOIN store_returns ON + (ss_item_sk = sr_item_sk AND + ss_ticket_number = sr_ticket_number) + , + date_dim, store, item, promotion + WHERE ss_sold_date_sk = d_date_sk + AND d_date BETWEEN cast('2000-08-23' AS DATE) + AND (cast('2000-08-23' AS DATE) + INTERVAL 30 days) + AND ss_store_sk = s_store_sk + AND ss_item_sk 
= i_item_sk + AND i_current_price > 50 + AND ss_promo_sk = p_promo_sk + AND p_channel_tv = 'N' + GROUP BY s_store_id), + csr AS + (SELECT + cp_catalog_page_id AS catalog_page_id, + sum(cs_ext_sales_price) AS sales, + sum(coalesce(cr_return_amount, 0)) AS returns, + sum(cs_net_profit - coalesce(cr_net_loss, 0)) AS profit + FROM catalog_sales + LEFT OUTER JOIN catalog_returns ON + (cs_item_sk = cr_item_sk AND + cs_order_number = cr_order_number) + , + date_dim, catalog_page, item, promotion + WHERE cs_sold_date_sk = d_date_sk + AND d_date BETWEEN cast('2000-08-23' AS DATE) + AND (cast('2000-08-23' AS DATE) + INTERVAL 30 days) + AND cs_catalog_page_sk = cp_catalog_page_sk + AND cs_item_sk = i_item_sk + AND i_current_price > 50 + AND cs_promo_sk = p_promo_sk + AND p_channel_tv = 'N' + GROUP BY cp_catalog_page_id), + wsr AS + (SELECT + web_site_id, + sum(ws_ext_sales_price) AS sales, + sum(coalesce(wr_return_amt, 0)) AS returns, + sum(ws_net_profit - coalesce(wr_net_loss, 0)) AS profit + FROM web_sales + LEFT OUTER JOIN web_returns ON + (ws_item_sk = wr_item_sk AND ws_order_number = wr_order_number) + , + date_dim, web_site, item, promotion + WHERE ws_sold_date_sk = d_date_sk + AND d_date BETWEEN cast('2000-08-23' AS DATE) + AND (cast('2000-08-23' AS DATE) + INTERVAL 30 days) + AND ws_web_site_sk = web_site_sk + AND ws_item_sk = i_item_sk + AND i_current_price > 50 + AND ws_promo_sk = p_promo_sk + AND p_channel_tv = 'N' + GROUP BY web_site_id) +SELECT + channel, + id, + sum(sales) AS sales, + sum(returns) AS returns, + sum(profit) AS profit +FROM (SELECT + 'store channel' AS channel, + concat('store', store_id) AS id, + sales, + returns, + profit + FROM ssr + UNION ALL + SELECT + 'catalog channel' AS channel, + concat('catalog_page', catalog_page_id) AS id, + sales, + returns, + profit + FROM csr + UNION ALL + SELECT + 'web channel' AS channel, + concat('web_site', web_site_id) AS id, + sales, + returns, + profit + FROM wsr) x +GROUP BY ROLLUP (channel, id) +ORDER BY channel, id +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q81.sql b/sql/core/src/test/resources/tpcds/q81.sql new file mode 100755 index 0000000000000..18f0ffa7e8f4c --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q81.sql @@ -0,0 +1,38 @@ +WITH customer_total_return AS +(SELECT + cr_returning_customer_sk AS ctr_customer_sk, + ca_state AS ctr_state, + sum(cr_return_amt_inc_tax) AS ctr_total_return + FROM catalog_returns, date_dim, customer_address + WHERE cr_returned_date_sk = d_date_sk + AND d_year = 2000 + AND cr_returning_addr_sk = ca_address_sk + GROUP BY cr_returning_customer_sk, ca_state ) +SELECT + c_customer_id, + c_salutation, + c_first_name, + c_last_name, + ca_street_number, + ca_street_name, + ca_street_type, + ca_suite_number, + ca_city, + ca_county, + ca_state, + ca_zip, + ca_country, + ca_gmt_offset, + ca_location_type, + ctr_total_return +FROM customer_total_return ctr1, customer_address, customer +WHERE ctr1.ctr_total_return > (SELECT avg(ctr_total_return) * 1.2 +FROM customer_total_return ctr2 +WHERE ctr1.ctr_state = ctr2.ctr_state) + AND ca_address_sk = c_current_addr_sk + AND ca_state = 'GA' + AND ctr1.ctr_customer_sk = c_customer_sk +ORDER BY c_customer_id, c_salutation, c_first_name, c_last_name, ca_street_number, ca_street_name + , ca_street_type, ca_suite_number, ca_city, ca_county, ca_state, ca_zip, ca_country, ca_gmt_offset + , ca_location_type, ctr_total_return +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q82.sql b/sql/core/src/test/resources/tpcds/q82.sql new file mode 
100755 index 0000000000000..20942cfeb0787 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q82.sql @@ -0,0 +1,15 @@ +SELECT + i_item_id, + i_item_desc, + i_current_price +FROM item, inventory, date_dim, store_sales +WHERE i_current_price BETWEEN 62 AND 62 + 30 + AND inv_item_sk = i_item_sk + AND d_date_sk = inv_date_sk + AND d_date BETWEEN cast('2000-05-25' AS DATE) AND (cast('2000-05-25' AS DATE) + INTERVAL 60 days) + AND i_manufact_id IN (129, 270, 821, 423) + AND inv_quantity_on_hand BETWEEN 100 AND 500 + AND ss_item_sk = i_item_sk +GROUP BY i_item_id, i_item_desc, i_current_price +ORDER BY i_item_id +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q83.sql b/sql/core/src/test/resources/tpcds/q83.sql new file mode 100755 index 0000000000000..53c10c7ded6c1 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q83.sql @@ -0,0 +1,56 @@ +WITH sr_items AS +(SELECT + i_item_id item_id, + sum(sr_return_quantity) sr_item_qty + FROM store_returns, item, date_dim + WHERE sr_item_sk = i_item_sk + AND d_date IN (SELECT d_date + FROM date_dim + WHERE d_week_seq IN + (SELECT d_week_seq + FROM date_dim + WHERE d_date IN ('2000-06-30', '2000-09-27', '2000-11-17'))) + AND sr_returned_date_sk = d_date_sk + GROUP BY i_item_id), + cr_items AS + (SELECT + i_item_id item_id, + sum(cr_return_quantity) cr_item_qty + FROM catalog_returns, item, date_dim + WHERE cr_item_sk = i_item_sk + AND d_date IN (SELECT d_date + FROM date_dim + WHERE d_week_seq IN + (SELECT d_week_seq + FROM date_dim + WHERE d_date IN ('2000-06-30', '2000-09-27', '2000-11-17'))) + AND cr_returned_date_sk = d_date_sk + GROUP BY i_item_id), + wr_items AS + (SELECT + i_item_id item_id, + sum(wr_return_quantity) wr_item_qty + FROM web_returns, item, date_dim + WHERE wr_item_sk = i_item_sk AND d_date IN + (SELECT d_date + FROM date_dim + WHERE d_week_seq IN + (SELECT d_week_seq + FROM date_dim + WHERE d_date IN ('2000-06-30', '2000-09-27', '2000-11-17'))) + AND wr_returned_date_sk = d_date_sk + GROUP BY i_item_id) +SELECT + sr_items.item_id, + sr_item_qty, + sr_item_qty / (sr_item_qty + cr_item_qty + wr_item_qty) / 3.0 * 100 sr_dev, + cr_item_qty, + cr_item_qty / (sr_item_qty + cr_item_qty + wr_item_qty) / 3.0 * 100 cr_dev, + wr_item_qty, + wr_item_qty / (sr_item_qty + cr_item_qty + wr_item_qty) / 3.0 * 100 wr_dev, + (sr_item_qty + cr_item_qty + wr_item_qty) / 3.0 average +FROM sr_items, cr_items, wr_items +WHERE sr_items.item_id = cr_items.item_id + AND sr_items.item_id = wr_items.item_id +ORDER BY sr_items.item_id, sr_item_qty +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q84.sql b/sql/core/src/test/resources/tpcds/q84.sql new file mode 100755 index 0000000000000..a1076b57ced5c --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q84.sql @@ -0,0 +1,19 @@ +SELECT + c_customer_id AS customer_id, + concat(c_last_name, ', ', c_first_name) AS customername +FROM customer + , customer_address + , customer_demographics + , household_demographics + , income_band + , store_returns +WHERE ca_city = 'Edgewood' + AND c_current_addr_sk = ca_address_sk + AND ib_lower_bound >= 38128 + AND ib_upper_bound <= 38128 + 50000 + AND ib_income_band_sk = hd_income_band_sk + AND cd_demo_sk = c_current_cdemo_sk + AND hd_demo_sk = c_current_hdemo_sk + AND sr_cdemo_sk = cd_demo_sk +ORDER BY c_customer_id +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q85.sql b/sql/core/src/test/resources/tpcds/q85.sql new file mode 100755 index 0000000000000..cf718b0f8adec --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q85.sql @@ -0,0 +1,82 @@ 
+SELECT + substr(r_reason_desc, 1, 20), + avg(ws_quantity), + avg(wr_refunded_cash), + avg(wr_fee) +FROM web_sales, web_returns, web_page, customer_demographics cd1, + customer_demographics cd2, customer_address, date_dim, reason +WHERE ws_web_page_sk = wp_web_page_sk + AND ws_item_sk = wr_item_sk + AND ws_order_number = wr_order_number + AND ws_sold_date_sk = d_date_sk AND d_year = 2000 + AND cd1.cd_demo_sk = wr_refunded_cdemo_sk + AND cd2.cd_demo_sk = wr_returning_cdemo_sk + AND ca_address_sk = wr_refunded_addr_sk + AND r_reason_sk = wr_reason_sk + AND + ( + ( + cd1.cd_marital_status = 'M' + AND + cd1.cd_marital_status = cd2.cd_marital_status + AND + cd1.cd_education_status = 'Advanced Degree' + AND + cd1.cd_education_status = cd2.cd_education_status + AND + ws_sales_price BETWEEN 100.00 AND 150.00 + ) + OR + ( + cd1.cd_marital_status = 'S' + AND + cd1.cd_marital_status = cd2.cd_marital_status + AND + cd1.cd_education_status = 'College' + AND + cd1.cd_education_status = cd2.cd_education_status + AND + ws_sales_price BETWEEN 50.00 AND 100.00 + ) + OR + ( + cd1.cd_marital_status = 'W' + AND + cd1.cd_marital_status = cd2.cd_marital_status + AND + cd1.cd_education_status = '2 yr Degree' + AND + cd1.cd_education_status = cd2.cd_education_status + AND + ws_sales_price BETWEEN 150.00 AND 200.00 + ) + ) + AND + ( + ( + ca_country = 'United States' + AND + ca_state IN ('IN', 'OH', 'NJ') + AND ws_net_profit BETWEEN 100 AND 200 + ) + OR + ( + ca_country = 'United States' + AND + ca_state IN ('WI', 'CT', 'KY') + AND ws_net_profit BETWEEN 150 AND 300 + ) + OR + ( + ca_country = 'United States' + AND + ca_state IN ('LA', 'IA', 'AR') + AND ws_net_profit BETWEEN 50 AND 250 + ) + ) +GROUP BY r_reason_desc +ORDER BY substr(r_reason_desc, 1, 20) + , avg(ws_quantity) + , avg(wr_refunded_cash) + , avg(wr_fee) +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q86.sql b/sql/core/src/test/resources/tpcds/q86.sql new file mode 100755 index 0000000000000..789a4abf7b5f7 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q86.sql @@ -0,0 +1,24 @@ +SELECT + sum(ws_net_paid) AS total_sum, + i_category, + i_class, + grouping(i_category) + grouping(i_class) AS lochierarchy, + rank() + OVER ( + PARTITION BY grouping(i_category) + grouping(i_class), + CASE WHEN grouping(i_class) = 0 + THEN i_category END + ORDER BY sum(ws_net_paid) DESC) AS rank_within_parent +FROM + web_sales, date_dim d1, item +WHERE + d1.d_month_seq BETWEEN 1200 AND 1200 + 11 + AND d1.d_date_sk = ws_sold_date_sk + AND i_item_sk = ws_item_sk +GROUP BY ROLLUP (i_category, i_class) +ORDER BY + lochierarchy DESC, + CASE WHEN lochierarchy = 0 + THEN i_category END, + rank_within_parent +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q87.sql b/sql/core/src/test/resources/tpcds/q87.sql new file mode 100755 index 0000000000000..4aaa9f39dce9e --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q87.sql @@ -0,0 +1,28 @@ +SELECT count(*) +FROM ((SELECT DISTINCT + c_last_name, + c_first_name, + d_date +FROM store_sales, date_dim, customer +WHERE store_sales.ss_sold_date_sk = date_dim.d_date_sk + AND store_sales.ss_customer_sk = customer.c_customer_sk + AND d_month_seq BETWEEN 1200 AND 1200 + 11) + EXCEPT + (SELECT DISTINCT + c_last_name, + c_first_name, + d_date + FROM catalog_sales, date_dim, customer + WHERE catalog_sales.cs_sold_date_sk = date_dim.d_date_sk + AND catalog_sales.cs_bill_customer_sk = customer.c_customer_sk + AND d_month_seq BETWEEN 1200 AND 1200 + 11) + EXCEPT + (SELECT DISTINCT + c_last_name, + c_first_name, + d_date + FROM 
web_sales, date_dim, customer + WHERE web_sales.ws_sold_date_sk = date_dim.d_date_sk + AND web_sales.ws_bill_customer_sk = customer.c_customer_sk + AND d_month_seq BETWEEN 1200 AND 1200 + 11) + ) cool_cust diff --git a/sql/core/src/test/resources/tpcds/q88.sql b/sql/core/src/test/resources/tpcds/q88.sql new file mode 100755 index 0000000000000..25bcd90f41ab6 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q88.sql @@ -0,0 +1,122 @@ +SELECT * +FROM + (SELECT count(*) h8_30_to_9 + FROM store_sales, household_demographics, time_dim, store + WHERE ss_sold_time_sk = time_dim.t_time_sk + AND ss_hdemo_sk = household_demographics.hd_demo_sk + AND ss_store_sk = s_store_sk + AND time_dim.t_hour = 8 + AND time_dim.t_minute >= 30 + AND ( + (household_demographics.hd_dep_count = 4 AND household_demographics.hd_vehicle_count <= 4 + 2) + OR + (household_demographics.hd_dep_count = 2 AND household_demographics.hd_vehicle_count <= 2 + 2) + OR + (household_demographics.hd_dep_count = 0 AND + household_demographics.hd_vehicle_count <= 0 + 2)) + AND store.s_store_name = 'ese') s1, + (SELECT count(*) h9_to_9_30 + FROM store_sales, household_demographics, time_dim, store + WHERE ss_sold_time_sk = time_dim.t_time_sk + AND ss_hdemo_sk = household_demographics.hd_demo_sk + AND ss_store_sk = s_store_sk + AND time_dim.t_hour = 9 + AND time_dim.t_minute < 30 + AND ( + (household_demographics.hd_dep_count = 4 AND household_demographics.hd_vehicle_count <= 4 + 2) + OR + (household_demographics.hd_dep_count = 2 AND household_demographics.hd_vehicle_count <= 2 + 2) + OR + (household_demographics.hd_dep_count = 0 AND + household_demographics.hd_vehicle_count <= 0 + 2)) + AND store.s_store_name = 'ese') s2, + (SELECT count(*) h9_30_to_10 + FROM store_sales, household_demographics, time_dim, store + WHERE ss_sold_time_sk = time_dim.t_time_sk + AND ss_hdemo_sk = household_demographics.hd_demo_sk + AND ss_store_sk = s_store_sk + AND time_dim.t_hour = 9 + AND time_dim.t_minute >= 30 + AND ( + (household_demographics.hd_dep_count = 4 AND household_demographics.hd_vehicle_count <= 4 + 2) + OR + (household_demographics.hd_dep_count = 2 AND household_demographics.hd_vehicle_count <= 2 + 2) + OR + (household_demographics.hd_dep_count = 0 AND + household_demographics.hd_vehicle_count <= 0 + 2)) + AND store.s_store_name = 'ese') s3, + (SELECT count(*) h10_to_10_30 + FROM store_sales, household_demographics, time_dim, store + WHERE ss_sold_time_sk = time_dim.t_time_sk + AND ss_hdemo_sk = household_demographics.hd_demo_sk + AND ss_store_sk = s_store_sk + AND time_dim.t_hour = 10 + AND time_dim.t_minute < 30 + AND ( + (household_demographics.hd_dep_count = 4 AND household_demographics.hd_vehicle_count <= 4 + 2) + OR + (household_demographics.hd_dep_count = 2 AND household_demographics.hd_vehicle_count <= 2 + 2) + OR + (household_demographics.hd_dep_count = 0 AND + household_demographics.hd_vehicle_count <= 0 + 2)) + AND store.s_store_name = 'ese') s4, + (SELECT count(*) h10_30_to_11 + FROM store_sales, household_demographics, time_dim, store + WHERE ss_sold_time_sk = time_dim.t_time_sk + AND ss_hdemo_sk = household_demographics.hd_demo_sk + AND ss_store_sk = s_store_sk + AND time_dim.t_hour = 10 + AND time_dim.t_minute >= 30 + AND ( + (household_demographics.hd_dep_count = 4 AND household_demographics.hd_vehicle_count <= 4 + 2) + OR + (household_demographics.hd_dep_count = 2 AND household_demographics.hd_vehicle_count <= 2 + 2) + OR + (household_demographics.hd_dep_count = 0 AND + household_demographics.hd_vehicle_count <= 0 + 2)) 
+ AND store.s_store_name = 'ese') s5, + (SELECT count(*) h11_to_11_30 + FROM store_sales, household_demographics, time_dim, store + WHERE ss_sold_time_sk = time_dim.t_time_sk + AND ss_hdemo_sk = household_demographics.hd_demo_sk + AND ss_store_sk = s_store_sk + AND time_dim.t_hour = 11 + AND time_dim.t_minute < 30 + AND ( + (household_demographics.hd_dep_count = 4 AND household_demographics.hd_vehicle_count <= 4 + 2) + OR + (household_demographics.hd_dep_count = 2 AND household_demographics.hd_vehicle_count <= 2 + 2) + OR + (household_demographics.hd_dep_count = 0 AND + household_demographics.hd_vehicle_count <= 0 + 2)) + AND store.s_store_name = 'ese') s6, + (SELECT count(*) h11_30_to_12 + FROM store_sales, household_demographics, time_dim, store + WHERE ss_sold_time_sk = time_dim.t_time_sk + AND ss_hdemo_sk = household_demographics.hd_demo_sk + AND ss_store_sk = s_store_sk + AND time_dim.t_hour = 11 + AND time_dim.t_minute >= 30 + AND ( + (household_demographics.hd_dep_count = 4 AND household_demographics.hd_vehicle_count <= 4 + 2) + OR + (household_demographics.hd_dep_count = 2 AND household_demographics.hd_vehicle_count <= 2 + 2) + OR + (household_demographics.hd_dep_count = 0 AND + household_demographics.hd_vehicle_count <= 0 + 2)) + AND store.s_store_name = 'ese') s7, + (SELECT count(*) h12_to_12_30 + FROM store_sales, household_demographics, time_dim, store + WHERE ss_sold_time_sk = time_dim.t_time_sk + AND ss_hdemo_sk = household_demographics.hd_demo_sk + AND ss_store_sk = s_store_sk + AND time_dim.t_hour = 12 + AND time_dim.t_minute < 30 + AND ( + (household_demographics.hd_dep_count = 4 AND household_demographics.hd_vehicle_count <= 4 + 2) + OR + (household_demographics.hd_dep_count = 2 AND household_demographics.hd_vehicle_count <= 2 + 2) + OR + (household_demographics.hd_dep_count = 0 AND + household_demographics.hd_vehicle_count <= 0 + 2)) + AND store.s_store_name = 'ese') s8 diff --git a/sql/core/src/test/resources/tpcds/q89.sql b/sql/core/src/test/resources/tpcds/q89.sql new file mode 100755 index 0000000000000..75408cb0323f8 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q89.sql @@ -0,0 +1,30 @@ +SELECT * +FROM ( + SELECT + i_category, + i_class, + i_brand, + s_store_name, + s_company_name, + d_moy, + sum(ss_sales_price) sum_sales, + avg(sum(ss_sales_price)) + OVER + (PARTITION BY i_category, i_brand, s_store_name, s_company_name) + avg_monthly_sales + FROM item, store_sales, date_dim, store + WHERE ss_item_sk = i_item_sk AND + ss_sold_date_sk = d_date_sk AND + ss_store_sk = s_store_sk AND + d_year IN (1999) AND + ((i_category IN ('Books', 'Electronics', 'Sports') AND + i_class IN ('computers', 'stereo', 'football')) + OR (i_category IN ('Men', 'Jewelry', 'Women') AND + i_class IN ('shirts', 'birdal', 'dresses'))) + GROUP BY i_category, i_class, i_brand, + s_store_name, s_company_name, d_moy) tmp1 +WHERE CASE WHEN (avg_monthly_sales <> 0) + THEN (abs(sum_sales - avg_monthly_sales) / avg_monthly_sales) + ELSE NULL END > 0.1 +ORDER BY sum_sales - avg_monthly_sales, s_store_name +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q9.sql b/sql/core/src/test/resources/tpcds/q9.sql new file mode 100755 index 0000000000000..de3db9d988f1e --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q9.sql @@ -0,0 +1,48 @@ +SELECT + CASE WHEN (SELECT count(*) + FROM store_sales + WHERE ss_quantity BETWEEN 1 AND 20) > 62316685 + THEN (SELECT avg(ss_ext_discount_amt) + FROM store_sales + WHERE ss_quantity BETWEEN 1 AND 20) + ELSE (SELECT avg(ss_net_paid) + FROM store_sales + 
WHERE ss_quantity BETWEEN 1 AND 20) END bucket1, + CASE WHEN (SELECT count(*) + FROM store_sales + WHERE ss_quantity BETWEEN 21 AND 40) > 19045798 + THEN (SELECT avg(ss_ext_discount_amt) + FROM store_sales + WHERE ss_quantity BETWEEN 21 AND 40) + ELSE (SELECT avg(ss_net_paid) + FROM store_sales + WHERE ss_quantity BETWEEN 21 AND 40) END bucket2, + CASE WHEN (SELECT count(*) + FROM store_sales + WHERE ss_quantity BETWEEN 41 AND 60) > 365541424 + THEN (SELECT avg(ss_ext_discount_amt) + FROM store_sales + WHERE ss_quantity BETWEEN 41 AND 60) + ELSE (SELECT avg(ss_net_paid) + FROM store_sales + WHERE ss_quantity BETWEEN 41 AND 60) END bucket3, + CASE WHEN (SELECT count(*) + FROM store_sales + WHERE ss_quantity BETWEEN 61 AND 80) > 216357808 + THEN (SELECT avg(ss_ext_discount_amt) + FROM store_sales + WHERE ss_quantity BETWEEN 61 AND 80) + ELSE (SELECT avg(ss_net_paid) + FROM store_sales + WHERE ss_quantity BETWEEN 61 AND 80) END bucket4, + CASE WHEN (SELECT count(*) + FROM store_sales + WHERE ss_quantity BETWEEN 81 AND 100) > 184483884 + THEN (SELECT avg(ss_ext_discount_amt) + FROM store_sales + WHERE ss_quantity BETWEEN 81 AND 100) + ELSE (SELECT avg(ss_net_paid) + FROM store_sales + WHERE ss_quantity BETWEEN 81 AND 100) END bucket5 +FROM reason +WHERE r_reason_sk = 1 diff --git a/sql/core/src/test/resources/tpcds/q90.sql b/sql/core/src/test/resources/tpcds/q90.sql new file mode 100755 index 0000000000000..85e35bf8bf8ec --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q90.sql @@ -0,0 +1,19 @@ +SELECT cast(amc AS DECIMAL(15, 4)) / cast(pmc AS DECIMAL(15, 4)) am_pm_ratio +FROM (SELECT count(*) amc +FROM web_sales, household_demographics, time_dim, web_page +WHERE ws_sold_time_sk = time_dim.t_time_sk + AND ws_ship_hdemo_sk = household_demographics.hd_demo_sk + AND ws_web_page_sk = web_page.wp_web_page_sk + AND time_dim.t_hour BETWEEN 8 AND 8 + 1 + AND household_demographics.hd_dep_count = 6 + AND web_page.wp_char_count BETWEEN 5000 AND 5200) at, + (SELECT count(*) pmc + FROM web_sales, household_demographics, time_dim, web_page + WHERE ws_sold_time_sk = time_dim.t_time_sk + AND ws_ship_hdemo_sk = household_demographics.hd_demo_sk + AND ws_web_page_sk = web_page.wp_web_page_sk + AND time_dim.t_hour BETWEEN 19 AND 19 + 1 + AND household_demographics.hd_dep_count = 6 + AND web_page.wp_char_count BETWEEN 5000 AND 5200) pt +ORDER BY am_pm_ratio +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q91.sql b/sql/core/src/test/resources/tpcds/q91.sql new file mode 100755 index 0000000000000..9ca7ce00ac775 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q91.sql @@ -0,0 +1,23 @@ +SELECT + cc_call_center_id Call_Center, + cc_name Call_Center_Name, + cc_manager Manager, + sum(cr_net_loss) Returns_Loss +FROM + call_center, catalog_returns, date_dim, customer, customer_address, + customer_demographics, household_demographics +WHERE + cr_call_center_sk = cc_call_center_sk + AND cr_returned_date_sk = d_date_sk + AND cr_returning_customer_sk = c_customer_sk + AND cd_demo_sk = c_current_cdemo_sk + AND hd_demo_sk = c_current_hdemo_sk + AND ca_address_sk = c_current_addr_sk + AND d_year = 1998 + AND d_moy = 11 + AND ((cd_marital_status = 'M' AND cd_education_status = 'Unknown') + OR (cd_marital_status = 'W' AND cd_education_status = 'Advanced Degree')) + AND hd_buy_potential LIKE 'Unknown%' + AND ca_gmt_offset = -7 +GROUP BY cc_call_center_id, cc_name, cc_manager, cd_marital_status, cd_education_status +ORDER BY sum(cr_net_loss) DESC diff --git a/sql/core/src/test/resources/tpcds/q92.sql 
b/sql/core/src/test/resources/tpcds/q92.sql new file mode 100755 index 0000000000000..99129c3bd9e5b --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q92.sql @@ -0,0 +1,16 @@ +SELECT sum(ws_ext_discount_amt) AS `Excess Discount Amount ` +FROM web_sales, item, date_dim +WHERE i_manufact_id = 350 + AND i_item_sk = ws_item_sk + AND d_date BETWEEN '2000-01-27' AND (cast('2000-01-27' AS DATE) + INTERVAL 90 days) + AND d_date_sk = ws_sold_date_sk + AND ws_ext_discount_amt > + ( + SELECT 1.3 * avg(ws_ext_discount_amt) + FROM web_sales, date_dim + WHERE ws_item_sk = i_item_sk + AND d_date BETWEEN '2000-01-27' AND (cast('2000-01-27' AS DATE) + INTERVAL 90 days) + AND d_date_sk = ws_sold_date_sk + ) +ORDER BY sum(ws_ext_discount_amt) +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q93.sql b/sql/core/src/test/resources/tpcds/q93.sql new file mode 100755 index 0000000000000..222dc31c1f561 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q93.sql @@ -0,0 +1,19 @@ +SELECT + ss_customer_sk, + sum(act_sales) sumsales +FROM (SELECT + ss_item_sk, + ss_ticket_number, + ss_customer_sk, + CASE WHEN sr_return_quantity IS NOT NULL + THEN (ss_quantity - sr_return_quantity) * ss_sales_price + ELSE (ss_quantity * ss_sales_price) END act_sales +FROM store_sales + LEFT OUTER JOIN store_returns + ON (sr_item_sk = ss_item_sk AND sr_ticket_number = ss_ticket_number) + , + reason +WHERE sr_reason_sk = r_reason_sk AND r_reason_desc = 'reason 28') t +GROUP BY ss_customer_sk +ORDER BY sumsales, ss_customer_sk +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q94.sql b/sql/core/src/test/resources/tpcds/q94.sql new file mode 100755 index 0000000000000..d6de3d75b82d2 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q94.sql @@ -0,0 +1,23 @@ +SELECT + count(DISTINCT ws_order_number) AS `order count `, + sum(ws_ext_ship_cost) AS `total shipping cost `, + sum(ws_net_profit) AS `total net profit ` +FROM + web_sales ws1, date_dim, customer_address, web_site +WHERE + d_date BETWEEN '1999-02-01' AND + (CAST('1999-02-01' AS DATE) + INTERVAL 60 days) + AND ws1.ws_ship_date_sk = d_date_sk + AND ws1.ws_ship_addr_sk = ca_address_sk + AND ca_state = 'IL' + AND ws1.ws_web_site_sk = web_site_sk + AND web_company_name = 'pri' + AND EXISTS(SELECT * + FROM web_sales ws2 + WHERE ws1.ws_order_number = ws2.ws_order_number + AND ws1.ws_warehouse_sk <> ws2.ws_warehouse_sk) + AND NOT EXISTS(SELECT * + FROM web_returns wr1 + WHERE ws1.ws_order_number = wr1.wr_order_number) +ORDER BY count(DISTINCT ws_order_number) +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q95.sql b/sql/core/src/test/resources/tpcds/q95.sql new file mode 100755 index 0000000000000..df71f00bd6c0b --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q95.sql @@ -0,0 +1,29 @@ +WITH ws_wh AS +(SELECT + ws1.ws_order_number, + ws1.ws_warehouse_sk wh1, + ws2.ws_warehouse_sk wh2 + FROM web_sales ws1, web_sales ws2 + WHERE ws1.ws_order_number = ws2.ws_order_number + AND ws1.ws_warehouse_sk <> ws2.ws_warehouse_sk) +SELECT + count(DISTINCT ws_order_number) AS `order count `, + sum(ws_ext_ship_cost) AS `total shipping cost `, + sum(ws_net_profit) AS `total net profit ` +FROM + web_sales ws1, date_dim, customer_address, web_site +WHERE + d_date BETWEEN '1999-02-01' AND + (CAST('1999-02-01' AS DATE) + INTERVAL 60 DAY) + AND ws1.ws_ship_date_sk = d_date_sk + AND ws1.ws_ship_addr_sk = ca_address_sk + AND ca_state = 'IL' + AND ws1.ws_web_site_sk = web_site_sk + AND web_company_name = 'pri' + AND ws1.ws_order_number IN (SELECT ws_order_number + FROM ws_wh) + 
AND ws1.ws_order_number IN (SELECT wr_order_number + FROM web_returns, ws_wh + WHERE wr_order_number = ws_wh.ws_order_number) +ORDER BY count(DISTINCT ws_order_number) +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q96.sql b/sql/core/src/test/resources/tpcds/q96.sql new file mode 100755 index 0000000000000..7ab17e7bc4597 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q96.sql @@ -0,0 +1,11 @@ +SELECT count(*) +FROM store_sales, household_demographics, time_dim, store +WHERE ss_sold_time_sk = time_dim.t_time_sk + AND ss_hdemo_sk = household_demographics.hd_demo_sk + AND ss_store_sk = s_store_sk + AND time_dim.t_hour = 20 + AND time_dim.t_minute >= 30 + AND household_demographics.hd_dep_count = 7 + AND store.s_store_name = 'ese' +ORDER BY count(*) +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q97.sql b/sql/core/src/test/resources/tpcds/q97.sql new file mode 100755 index 0000000000000..e7e0b1a05259d --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q97.sql @@ -0,0 +1,30 @@ +WITH ssci AS ( + SELECT + ss_customer_sk customer_sk, + ss_item_sk item_sk + FROM store_sales, date_dim + WHERE ss_sold_date_sk = d_date_sk + AND d_month_seq BETWEEN 1200 AND 1200 + 11 + GROUP BY ss_customer_sk, ss_item_sk), + csci AS ( + SELECT + cs_bill_customer_sk customer_sk, + cs_item_sk item_sk + FROM catalog_sales, date_dim + WHERE cs_sold_date_sk = d_date_sk + AND d_month_seq BETWEEN 1200 AND 1200 + 11 + GROUP BY cs_bill_customer_sk, cs_item_sk) +SELECT + sum(CASE WHEN ssci.customer_sk IS NOT NULL AND csci.customer_sk IS NULL + THEN 1 + ELSE 0 END) store_only, + sum(CASE WHEN ssci.customer_sk IS NULL AND csci.customer_sk IS NOT NULL + THEN 1 + ELSE 0 END) catalog_only, + sum(CASE WHEN ssci.customer_sk IS NOT NULL AND csci.customer_sk IS NOT NULL + THEN 1 + ELSE 0 END) store_and_catalog +FROM ssci + FULL OUTER JOIN csci ON (ssci.customer_sk = csci.customer_sk + AND ssci.item_sk = csci.item_sk) +LIMIT 100 diff --git a/sql/core/src/test/resources/tpcds/q98.sql b/sql/core/src/test/resources/tpcds/q98.sql new file mode 100755 index 0000000000000..bb10d4bf8da23 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q98.sql @@ -0,0 +1,21 @@ +SELECT + i_item_desc, + i_category, + i_class, + i_current_price, + sum(ss_ext_sales_price) AS itemrevenue, + sum(ss_ext_sales_price) * 100 / sum(sum(ss_ext_sales_price)) + OVER + (PARTITION BY i_class) AS revenueratio +FROM + store_sales, item, date_dim +WHERE + ss_item_sk = i_item_sk + AND i_category IN ('Sports', 'Books', 'Home') + AND ss_sold_date_sk = d_date_sk + AND d_date BETWEEN cast('1999-02-22' AS DATE) + AND (cast('1999-02-22' AS DATE) + INTERVAL 30 days) +GROUP BY + i_item_id, i_item_desc, i_category, i_class, i_current_price +ORDER BY + i_category, i_class, i_item_id, i_item_desc, revenueratio diff --git a/sql/core/src/test/resources/tpcds/q99.sql b/sql/core/src/test/resources/tpcds/q99.sql new file mode 100755 index 0000000000000..f1a3d4d2b7fe9 --- /dev/null +++ b/sql/core/src/test/resources/tpcds/q99.sql @@ -0,0 +1,34 @@ +SELECT + substr(w_warehouse_name, 1, 20), + sm_type, + cc_name, + sum(CASE WHEN (cs_ship_date_sk - cs_sold_date_sk <= 30) + THEN 1 + ELSE 0 END) AS `30 days `, + sum(CASE WHEN (cs_ship_date_sk - cs_sold_date_sk > 30) AND + (cs_ship_date_sk - cs_sold_date_sk <= 60) + THEN 1 + ELSE 0 END) AS `31 - 60 days `, + sum(CASE WHEN (cs_ship_date_sk - cs_sold_date_sk > 60) AND + (cs_ship_date_sk - cs_sold_date_sk <= 90) + THEN 1 + ELSE 0 END) AS `61 - 90 days `, + sum(CASE WHEN (cs_ship_date_sk - cs_sold_date_sk > 90) AND + 
(cs_ship_date_sk - cs_sold_date_sk <= 120) + THEN 1 + ELSE 0 END) AS `91 - 120 days `, + sum(CASE WHEN (cs_ship_date_sk - cs_sold_date_sk > 120) + THEN 1 + ELSE 0 END) AS `>120 days ` +FROM + catalog_sales, warehouse, ship_mode, call_center, date_dim +WHERE + d_month_seq BETWEEN 1200 AND 1200 + 11 + AND cs_ship_date_sk = d_date_sk + AND cs_warehouse_sk = w_warehouse_sk + AND cs_ship_mode_sk = sm_ship_mode_sk + AND cs_call_center_sk = cc_call_center_sk +GROUP BY + substr(w_warehouse_name, 1, 20), sm_type, cc_name +ORDER BY substr(w_warehouse_name, 1, 20), sm_type, cc_name +LIMIT 100 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 12dbed89d561b..f42402e1cc7d2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -17,18 +17,20 @@ package org.apache.spark.sql +import scala.collection.mutable.HashSet import scala.concurrent.duration._ import scala.language.postfixOps import org.scalatest.concurrent.Eventually._ -import org.apache.spark.AccumulatorContext +import org.apache.spark.CleanerListener import org.apache.spark.sql.execution.RDDScanExec import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.execution.exchange.ShuffleExchange import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.storage.{RDDBlockId, StorageLevel} +import org.apache.spark.util.AccumulatorContext private case class BigData(s: String) @@ -36,7 +38,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext import testImplicits._ def rddIdOf(tableName: String): Int = { - val plan = sqlContext.table(tableName).queryExecution.sparkPlan + val plan = spark.table(tableName).queryExecution.sparkPlan plan.collect { case InMemoryTableScanExec(_, _, relation) => relation.cachedColumnBuffers.id @@ -71,43 +73,47 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext } test("cache temp table") { - testData.select('key).registerTempTable("tempTable") - assertCached(sql("SELECT COUNT(*) FROM tempTable"), 0) - sqlContext.cacheTable("tempTable") - assertCached(sql("SELECT COUNT(*) FROM tempTable")) - sqlContext.uncacheTable("tempTable") + withTempView("tempTable") { + testData.select('key).createOrReplaceTempView("tempTable") + assertCached(sql("SELECT COUNT(*) FROM tempTable"), 0) + spark.catalog.cacheTable("tempTable") + assertCached(sql("SELECT COUNT(*) FROM tempTable")) + spark.catalog.uncacheTable("tempTable") + } } test("unpersist an uncached table will not raise exception") { - assert(None == sqlContext.cacheManager.lookupCachedData(testData)) + assert(None == spark.sharedState.cacheManager.lookupCachedData(testData)) testData.unpersist(blocking = true) - assert(None == sqlContext.cacheManager.lookupCachedData(testData)) + assert(None == spark.sharedState.cacheManager.lookupCachedData(testData)) testData.unpersist(blocking = false) - assert(None == sqlContext.cacheManager.lookupCachedData(testData)) + assert(None == spark.sharedState.cacheManager.lookupCachedData(testData)) testData.persist() - assert(None != sqlContext.cacheManager.lookupCachedData(testData)) + assert(None != spark.sharedState.cacheManager.lookupCachedData(testData)) testData.unpersist(blocking = true) - assert(None == sqlContext.cacheManager.lookupCachedData(testData)) + assert(None == 
spark.sharedState.cacheManager.lookupCachedData(testData)) testData.unpersist(blocking = false) - assert(None == sqlContext.cacheManager.lookupCachedData(testData)) + assert(None == spark.sharedState.cacheManager.lookupCachedData(testData)) } test("cache table as select") { - sql("CACHE TABLE tempTable AS SELECT key FROM testData") - assertCached(sql("SELECT COUNT(*) FROM tempTable")) - sqlContext.uncacheTable("tempTable") + withTempView("tempTable") { + sql("CACHE TABLE tempTable AS SELECT key FROM testData") + assertCached(sql("SELECT COUNT(*) FROM tempTable")) + spark.catalog.uncacheTable("tempTable") + } } test("uncaching temp table") { - testData.select('key).registerTempTable("tempTable1") - testData.select('key).registerTempTable("tempTable2") - sqlContext.cacheTable("tempTable1") + testData.select('key).createOrReplaceTempView("tempTable1") + testData.select('key).createOrReplaceTempView("tempTable2") + spark.catalog.cacheTable("tempTable1") assertCached(sql("SELECT COUNT(*) FROM tempTable1")) assertCached(sql("SELECT COUNT(*) FROM tempTable2")) // Is this valid? - sqlContext.uncacheTable("tempTable2") + spark.catalog.uncacheTable("tempTable2") // Should this be cached? assertCached(sql("SELECT COUNT(*) FROM tempTable1"), 0) @@ -116,102 +122,96 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext test("too big for memory") { val data = "*" * 1000 sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).toDF() - .registerTempTable("bigData") - sqlContext.table("bigData").persist(StorageLevel.MEMORY_AND_DISK) - assert(sqlContext.table("bigData").count() === 200000L) - sqlContext.table("bigData").unpersist(blocking = true) + .createOrReplaceTempView("bigData") + spark.table("bigData").persist(StorageLevel.MEMORY_AND_DISK) + assert(spark.table("bigData").count() === 200000L) + spark.table("bigData").unpersist(blocking = true) } test("calling .cache() should use in-memory columnar caching") { - sqlContext.table("testData").cache() - assertCached(sqlContext.table("testData")) - sqlContext.table("testData").unpersist(blocking = true) + spark.table("testData").cache() + assertCached(spark.table("testData")) + spark.table("testData").unpersist(blocking = true) } test("calling .unpersist() should drop in-memory columnar cache") { - sqlContext.table("testData").cache() - sqlContext.table("testData").count() - sqlContext.table("testData").unpersist(blocking = true) - assertCached(sqlContext.table("testData"), 0) + spark.table("testData").cache() + spark.table("testData").count() + spark.table("testData").unpersist(blocking = true) + assertCached(spark.table("testData"), 0) } test("isCached") { - sqlContext.cacheTable("testData") + spark.catalog.cacheTable("testData") - assertCached(sqlContext.table("testData")) - assert(sqlContext.table("testData").queryExecution.withCachedData match { + assertCached(spark.table("testData")) + assert(spark.table("testData").queryExecution.withCachedData match { case _: InMemoryRelation => true case _ => false }) - sqlContext.uncacheTable("testData") - assert(!sqlContext.isCached("testData")) - assert(sqlContext.table("testData").queryExecution.withCachedData match { + spark.catalog.uncacheTable("testData") + assert(!spark.catalog.isCached("testData")) + assert(spark.table("testData").queryExecution.withCachedData match { case _: InMemoryRelation => false case _ => true }) } test("SPARK-1669: cacheTable should be idempotent") { - assume(!sqlContext.table("testData").logicalPlan.isInstanceOf[InMemoryRelation]) + 
assume(!spark.table("testData").logicalPlan.isInstanceOf[InMemoryRelation]) - sqlContext.cacheTable("testData") - assertCached(sqlContext.table("testData")) + spark.catalog.cacheTable("testData") + assertCached(spark.table("testData")) assertResult(1, "InMemoryRelation not found, testData should have been cached") { - sqlContext.table("testData").queryExecution.withCachedData.collect { + spark.table("testData").queryExecution.withCachedData.collect { case r: InMemoryRelation => r }.size } - sqlContext.cacheTable("testData") + spark.catalog.cacheTable("testData") assertResult(0, "Double InMemoryRelations found, cacheTable() is not idempotent") { - sqlContext.table("testData").queryExecution.withCachedData.collect { + spark.table("testData").queryExecution.withCachedData.collect { case r @ InMemoryRelation(_, _, _, _, _: InMemoryTableScanExec, _) => r }.size } - sqlContext.uncacheTable("testData") + spark.catalog.uncacheTable("testData") } test("read from cached table and uncache") { - sqlContext.cacheTable("testData") - checkAnswer(sqlContext.table("testData"), testData.collect().toSeq) - assertCached(sqlContext.table("testData")) + spark.catalog.cacheTable("testData") + checkAnswer(spark.table("testData"), testData.collect().toSeq) + assertCached(spark.table("testData")) - sqlContext.uncacheTable("testData") - checkAnswer(sqlContext.table("testData"), testData.collect().toSeq) - assertCached(sqlContext.table("testData"), 0) - } - - test("correct error on uncache of non-cached table") { - intercept[IllegalArgumentException] { - sqlContext.uncacheTable("testData") - } + spark.catalog.uncacheTable("testData") + checkAnswer(spark.table("testData"), testData.collect().toSeq) + assertCached(spark.table("testData"), 0) } test("SELECT star from cached table") { - sql("SELECT * FROM testData").registerTempTable("selectStar") - sqlContext.cacheTable("selectStar") + sql("SELECT * FROM testData").createOrReplaceTempView("selectStar") + spark.catalog.cacheTable("selectStar") checkAnswer( sql("SELECT * FROM selectStar WHERE key = 1"), Seq(Row(1, "1"))) - sqlContext.uncacheTable("selectStar") + spark.catalog.uncacheTable("selectStar") } test("Self-join cached") { val unCachedAnswer = sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key").collect() - sqlContext.cacheTable("testData") + spark.catalog.cacheTable("testData") checkAnswer( sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key"), unCachedAnswer.toSeq) - sqlContext.uncacheTable("testData") + spark.catalog.uncacheTable("testData") } test("'CACHE TABLE' and 'UNCACHE TABLE' SQL statement") { sql("CACHE TABLE testData") - assertCached(sqlContext.table("testData")) + assertCached(spark.table("testData")) val rddId = rddIdOf("testData") assert( @@ -219,7 +219,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext "Eagerly cached in-memory table should have already been materialized") sql("UNCACHE TABLE testData") - assert(!sqlContext.isCached("testData"), "Table 'testData' should not be cached") + assert(!spark.catalog.isCached("testData"), "Table 'testData' should not be cached") eventually(timeout(10 seconds)) { assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") @@ -227,38 +227,42 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext } test("CACHE TABLE tableName AS SELECT * FROM anotherTable") { - sql("CACHE TABLE testCacheTable AS SELECT * FROM testData") - assertCached(sqlContext.table("testCacheTable")) - - val rddId = 
rddIdOf("testCacheTable") - assert( - isMaterialized(rddId), - "Eagerly cached in-memory table should have already been materialized") - - sqlContext.uncacheTable("testCacheTable") - eventually(timeout(10 seconds)) { - assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") + withTempView("testCacheTable") { + sql("CACHE TABLE testCacheTable AS SELECT * FROM testData") + assertCached(spark.table("testCacheTable")) + + val rddId = rddIdOf("testCacheTable") + assert( + isMaterialized(rddId), + "Eagerly cached in-memory table should have already been materialized") + + spark.catalog.uncacheTable("testCacheTable") + eventually(timeout(10 seconds)) { + assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") + } } } test("CACHE TABLE tableName AS SELECT ...") { - sql("CACHE TABLE testCacheTable AS SELECT key FROM testData LIMIT 10") - assertCached(sqlContext.table("testCacheTable")) - - val rddId = rddIdOf("testCacheTable") - assert( - isMaterialized(rddId), - "Eagerly cached in-memory table should have already been materialized") - - sqlContext.uncacheTable("testCacheTable") - eventually(timeout(10 seconds)) { - assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") + withTempView("testCacheTable") { + sql("CACHE TABLE testCacheTable AS SELECT key FROM testData LIMIT 10") + assertCached(spark.table("testCacheTable")) + + val rddId = rddIdOf("testCacheTable") + assert( + isMaterialized(rddId), + "Eagerly cached in-memory table should have already been materialized") + + spark.catalog.uncacheTable("testCacheTable") + eventually(timeout(10 seconds)) { + assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") + } } } test("CACHE LAZY TABLE tableName") { sql("CACHE LAZY TABLE testData") - assertCached(sqlContext.table("testData")) + assertCached(spark.table("testData")) val rddId = rddIdOf("testData") assert( @@ -270,7 +274,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext isMaterialized(rddId), "Lazily cached in-memory table should have been materialized") - sqlContext.uncacheTable("testData") + spark.catalog.uncacheTable("testData") eventually(timeout(10 seconds)) { assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") } @@ -278,7 +282,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext test("InMemoryRelation statistics") { sql("CACHE TABLE testData") - sqlContext.table("testData").queryExecution.withCachedData.collect { + spark.table("testData").queryExecution.withCachedData.collect { case cached: InMemoryRelation => val actualSizeInBytes = (1 to 100).map(i => 4 + i.toString.length + 4).sum assert(cached.statistics.sizeInBytes === actualSizeInBytes) @@ -286,63 +290,85 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext } test("Drops temporary table") { - testData.select('key).registerTempTable("t1") - sqlContext.table("t1") - sqlContext.dropTempTable("t1") - intercept[AnalysisException](sqlContext.table("t1")) + testData.select('key).createOrReplaceTempView("t1") + spark.table("t1") + spark.catalog.dropTempView("t1") + intercept[AnalysisException](spark.table("t1")) } test("Drops cached temporary table") { - testData.select('key).registerTempTable("t1") - testData.select('key).registerTempTable("t2") - sqlContext.cacheTable("t1") + testData.select('key).createOrReplaceTempView("t1") + testData.select('key).createOrReplaceTempView("t2") 
+ spark.catalog.cacheTable("t1") - assert(sqlContext.isCached("t1")) - assert(sqlContext.isCached("t2")) + assert(spark.catalog.isCached("t1")) + assert(spark.catalog.isCached("t2")) - sqlContext.dropTempTable("t1") - intercept[AnalysisException](sqlContext.table("t1")) - assert(!sqlContext.isCached("t2")) + spark.catalog.dropTempView("t1") + intercept[AnalysisException](spark.table("t1")) + assert(!spark.catalog.isCached("t2")) } test("Clear all cache") { - sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") - sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") - sqlContext.cacheTable("t1") - sqlContext.cacheTable("t2") - sqlContext.clearCache() - assert(sqlContext.cacheManager.isEmpty) - - sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") - sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") - sqlContext.cacheTable("t1") - sqlContext.cacheTable("t2") + sql("SELECT key FROM testData LIMIT 10").createOrReplaceTempView("t1") + sql("SELECT key FROM testData LIMIT 5").createOrReplaceTempView("t2") + spark.catalog.cacheTable("t1") + spark.catalog.cacheTable("t2") + spark.catalog.clearCache() + assert(spark.sharedState.cacheManager.isEmpty) + + sql("SELECT key FROM testData LIMIT 10").createOrReplaceTempView("t1") + sql("SELECT key FROM testData LIMIT 5").createOrReplaceTempView("t2") + spark.catalog.cacheTable("t1") + spark.catalog.cacheTable("t2") sql("Clear CACHE") - assert(sqlContext.cacheManager.isEmpty) + assert(spark.sharedState.cacheManager.isEmpty) } - test("Clear accumulators when uncacheTable to prevent memory leaking") { - sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") - sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") + test("Ensure accumulators to be cleared after GC when uncacheTable") { + sql("SELECT key FROM testData LIMIT 10").createOrReplaceTempView("t1") + sql("SELECT key FROM testData LIMIT 5").createOrReplaceTempView("t2") - sqlContext.cacheTable("t1") - sqlContext.cacheTable("t2") + spark.catalog.cacheTable("t1") + spark.catalog.cacheTable("t2") sql("SELECT * FROM t1").count() sql("SELECT * FROM t2").count() sql("SELECT * FROM t1").count() sql("SELECT * FROM t2").count() - val accId1 = sqlContext.table("t1").queryExecution.withCachedData.collect { + val toBeCleanedAccIds = new HashSet[Long] + + val accId1 = spark.table("t1").queryExecution.withCachedData.collect { case i: InMemoryRelation => i.batchStats.id }.head + toBeCleanedAccIds += accId1 - val accId2 = sqlContext.table("t1").queryExecution.withCachedData.collect { + val accId2 = spark.table("t1").queryExecution.withCachedData.collect { case i: InMemoryRelation => i.batchStats.id }.head + toBeCleanedAccIds += accId2 + + val cleanerListener = new CleanerListener { + def rddCleaned(rddId: Int): Unit = {} + def shuffleCleaned(shuffleId: Int): Unit = {} + def broadcastCleaned(broadcastId: Long): Unit = {} + def accumCleaned(accId: Long): Unit = { + toBeCleanedAccIds.synchronized { toBeCleanedAccIds -= accId } + } + def checkpointCleaned(rddId: Long): Unit = {} + } + spark.sparkContext.cleaner.get.attachListener(cleanerListener) + + spark.catalog.uncacheTable("t1") + spark.catalog.uncacheTable("t2") - sqlContext.uncacheTable("t1") - sqlContext.uncacheTable("t2") + System.gc() + + eventually(timeout(10 seconds)) { + assert(toBeCleanedAccIds.synchronized { toBeCleanedAccIds.isEmpty }, + "batchStats accumulators should be cleared after GC when uncacheTable") + } assert(AccumulatorContext.get(accId1).isEmpty) 
assert(AccumulatorContext.get(accId2).isEmpty) @@ -350,8 +376,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext test("SPARK-10327 Cache Table is not working while subquery has alias in its project list") { sparkContext.parallelize((1, 1) :: (2, 2) :: Nil) - .toDF("key", "value").selectExpr("key", "value", "key+1").registerTempTable("abc") - sqlContext.cacheTable("abc") + .toDF("key", "value").selectExpr("key", "value", "key+1").createOrReplaceTempView("abc") + spark.catalog.cacheTable("abc") val sparkPlan = sql( """select a.key, b.key, c.key from @@ -371,27 +397,27 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext test("A cached table preserves the partitioning and ordering of its cached SparkPlan") { val table3x = testData.union(testData).union(testData) - table3x.registerTempTable("testData3x") + table3x.createOrReplaceTempView("testData3x") - sql("SELECT key, value FROM testData3x ORDER BY key").registerTempTable("orderedTable") - sqlContext.cacheTable("orderedTable") - assertCached(sqlContext.table("orderedTable")) + sql("SELECT key, value FROM testData3x ORDER BY key").createOrReplaceTempView("orderedTable") + spark.catalog.cacheTable("orderedTable") + assertCached(spark.table("orderedTable")) // Should not have an exchange as the query is already sorted on the group by key. verifyNumExchanges(sql("SELECT key, count(*) FROM orderedTable GROUP BY key"), 0) checkAnswer( sql("SELECT key, count(*) FROM orderedTable GROUP BY key ORDER BY key"), sql("SELECT key, count(*) FROM testData3x GROUP BY key ORDER BY key").collect()) - sqlContext.uncacheTable("orderedTable") - sqlContext.dropTempTable("orderedTable") + spark.catalog.uncacheTable("orderedTable") + spark.catalog.dropTempView("orderedTable") // Set up two tables distributed in the same way. Try this with the data distributed into // different number of partitions. for (numPartitions <- 1 until 10 by 4) { - withTempTable("t1", "t2") { - testData.repartition(numPartitions, $"key").registerTempTable("t1") - testData2.repartition(numPartitions, $"a").registerTempTable("t2") - sqlContext.cacheTable("t1") - sqlContext.cacheTable("t2") + withTempView("t1", "t2") { + testData.repartition(numPartitions, $"key").createOrReplaceTempView("t1") + testData2.repartition(numPartitions, $"a").createOrReplaceTempView("t2") + spark.catalog.cacheTable("t1") + spark.catalog.cacheTable("t2") // Joining them should result in no exchanges. verifyNumExchanges(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"), 0) @@ -403,17 +429,17 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext checkAnswer(sql("SELECT count(*) FROM t1 GROUP BY key"), sql("SELECT count(*) FROM testData GROUP BY key")) - sqlContext.uncacheTable("t1") - sqlContext.uncacheTable("t2") + spark.catalog.uncacheTable("t1") + spark.catalog.uncacheTable("t2") } } // Distribute the tables into non-matching number of partitions. Need to shuffle one side. 
- withTempTable("t1", "t2") { - testData.repartition(6, $"key").registerTempTable("t1") - testData2.repartition(3, $"a").registerTempTable("t2") - sqlContext.cacheTable("t1") - sqlContext.cacheTable("t2") + withTempView("t1", "t2") { + testData.repartition(6, $"key").createOrReplaceTempView("t1") + testData2.repartition(3, $"a").createOrReplaceTempView("t2") + spark.catalog.cacheTable("t1") + spark.catalog.cacheTable("t2") val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a") verifyNumExchanges(query, 1) @@ -421,16 +447,16 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext checkAnswer( query, testData.join(testData2, $"key" === $"a").select($"key", $"value", $"a", $"b")) - sqlContext.uncacheTable("t1") - sqlContext.uncacheTable("t2") + spark.catalog.uncacheTable("t1") + spark.catalog.uncacheTable("t2") } // One side of join is not partitioned in the desired way. Need to shuffle one side. - withTempTable("t1", "t2") { - testData.repartition(6, $"value").registerTempTable("t1") - testData2.repartition(6, $"a").registerTempTable("t2") - sqlContext.cacheTable("t1") - sqlContext.cacheTable("t2") + withTempView("t1", "t2") { + testData.repartition(6, $"value").createOrReplaceTempView("t1") + testData2.repartition(6, $"a").createOrReplaceTempView("t2") + spark.catalog.cacheTable("t1") + spark.catalog.cacheTable("t2") val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a") verifyNumExchanges(query, 1) @@ -438,15 +464,15 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext checkAnswer( query, testData.join(testData2, $"key" === $"a").select($"key", $"value", $"a", $"b")) - sqlContext.uncacheTable("t1") - sqlContext.uncacheTable("t2") + spark.catalog.uncacheTable("t1") + spark.catalog.uncacheTable("t2") } - withTempTable("t1", "t2") { - testData.repartition(6, $"value").registerTempTable("t1") - testData2.repartition(12, $"a").registerTempTable("t2") - sqlContext.cacheTable("t1") - sqlContext.cacheTable("t2") + withTempView("t1", "t2") { + testData.repartition(6, $"value").createOrReplaceTempView("t1") + testData2.repartition(12, $"a").createOrReplaceTempView("t2") + spark.catalog.cacheTable("t1") + spark.catalog.cacheTable("t2") val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a") verifyNumExchanges(query, 1) @@ -454,53 +480,53 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext checkAnswer( query, testData.join(testData2, $"key" === $"a").select($"key", $"value", $"a", $"b")) - sqlContext.uncacheTable("t1") - sqlContext.uncacheTable("t2") + spark.catalog.uncacheTable("t1") + spark.catalog.uncacheTable("t2") } // One side of join is not partitioned in the desired way. Since the number of partitions of // the side that has already partitioned is smaller than the side that is not partitioned, // we shuffle both side. 
- withTempTable("t1", "t2") { - testData.repartition(6, $"value").registerTempTable("t1") - testData2.repartition(3, $"a").registerTempTable("t2") - sqlContext.cacheTable("t1") - sqlContext.cacheTable("t2") + withTempView("t1", "t2") { + testData.repartition(6, $"value").createOrReplaceTempView("t1") + testData2.repartition(3, $"a").createOrReplaceTempView("t2") + spark.catalog.cacheTable("t1") + spark.catalog.cacheTable("t2") val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a") verifyNumExchanges(query, 2) checkAnswer( query, testData.join(testData2, $"key" === $"a").select($"key", $"value", $"a", $"b")) - sqlContext.uncacheTable("t1") - sqlContext.uncacheTable("t2") + spark.catalog.uncacheTable("t1") + spark.catalog.uncacheTable("t2") } // repartition's column ordering is different from group by column ordering. // But they use the same set of columns. - withTempTable("t1") { - testData.repartition(6, $"value", $"key").registerTempTable("t1") - sqlContext.cacheTable("t1") + withTempView("t1") { + testData.repartition(6, $"value", $"key").createOrReplaceTempView("t1") + spark.catalog.cacheTable("t1") val query = sql("SELECT value, key from t1 group by key, value") verifyNumExchanges(query, 0) checkAnswer( query, testData.distinct().select($"value", $"key")) - sqlContext.uncacheTable("t1") + spark.catalog.uncacheTable("t1") } // repartition's column ordering is different from join condition's column ordering. // We will still shuffle because hashcodes of a row depend on the column ordering. // If we do not shuffle, we may actually partition two tables in totally two different way. // See PartitioningSuite for more details. - withTempTable("t1", "t2") { + withTempView("t1", "t2") { val df1 = testData - df1.repartition(6, $"value", $"key").registerTempTable("t1") + df1.repartition(6, $"value", $"key").createOrReplaceTempView("t1") val df2 = testData2.select($"a", $"b".cast("string")) - df2.repartition(6, $"a", $"b").registerTempTable("t2") - sqlContext.cacheTable("t1") - sqlContext.cacheTable("t2") + df2.repartition(6, $"a", $"b").createOrReplaceTempView("t2") + spark.catalog.cacheTable("t1") + spark.catalog.cacheTable("t2") val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a and t1.value = t2.b") @@ -509,8 +535,34 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext checkAnswer( query, df1.join(df2, $"key" === $"a" && $"value" === $"b").select($"key", $"value", $"a", $"b")) - sqlContext.uncacheTable("t1") - sqlContext.uncacheTable("t2") + spark.catalog.uncacheTable("t1") + spark.catalog.uncacheTable("t2") } } + + test("SPARK-15870 DataFrame can't execute after uncacheTable") { + val selectStar = sql("SELECT * FROM testData WHERE key = 1") + selectStar.createOrReplaceTempView("selectStar") + + spark.catalog.cacheTable("selectStar") + checkAnswer( + selectStar, + Seq(Row(1, "1"))) + + spark.catalog.uncacheTable("selectStar") + checkAnswer( + selectStar, + Seq(Row(1, "1"))) + } + + test("SPARK-15915 Logical plans should use canonicalized plan when override sameResult") { + val localRelation = Seq(1, 2, 3).toDF() + localRelation.createOrReplaceTempView("localRelation") + + spark.catalog.cacheTable("localRelation") + assert( + localRelation.queryExecution.withCachedData.collect { + case i: InMemoryRelation => i + }.size == 1) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 
19fe29a202a64..26e1a9f75da13 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import org.apache.hadoop.io.{LongWritable, Text} +import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFormat} import org.scalatest.Matchers._ import org.apache.spark.sql.catalyst.expressions.NamedExpression @@ -29,7 +31,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { import testImplicits._ private lazy val booleanData = { - sqlContext.createDataFrame(sparkContext.parallelize( + spark.createDataFrame(sparkContext.parallelize( Row(false, false) :: Row(false, true) :: Row(true, false) :: @@ -120,66 +122,6 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { assert(newCol.expr.asInstanceOf[NamedExpression].metadata.getString("key") === "value") } - test("single explode") { - val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") - checkAnswer( - df.select(explode('intList)), - Row(1) :: Row(2) :: Row(3) :: Nil) - } - - test("explode and other columns") { - val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") - - checkAnswer( - df.select($"a", explode('intList)), - Row(1, 1) :: - Row(1, 2) :: - Row(1, 3) :: Nil) - - checkAnswer( - df.select($"*", explode('intList)), - Row(1, Seq(1, 2, 3), 1) :: - Row(1, Seq(1, 2, 3), 2) :: - Row(1, Seq(1, 2, 3), 3) :: Nil) - } - - test("aliased explode") { - val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") - - checkAnswer( - df.select(explode('intList).as('int)).select('int), - Row(1) :: Row(2) :: Row(3) :: Nil) - - checkAnswer( - df.select(explode('intList).as('int)).select(sum('int)), - Row(6) :: Nil) - } - - test("explode on map") { - val df = Seq((1, Map("a" -> "b"))).toDF("a", "map") - - checkAnswer( - df.select(explode('map)), - Row("a", "b")) - } - - test("explode on map with aliases") { - val df = Seq((1, Map("a" -> "b"))).toDF("a", "map") - - checkAnswer( - df.select(explode('map).as("key1" :: "value1" :: Nil)).select("key1", "value1"), - Row("a", "b")) - } - - test("self join explode") { - val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") - val exploded = df.select(explode('intList).as('i)) - - checkAnswer( - exploded.join(exploded, exploded("i") === exploded("i")).agg(count("*")), - Row(3) :: Nil) - } - test("collect on column produced by a binary operator") { val df = Seq((1, 2, 3)).toDF("a", "b", "c") checkAnswer(df.select(df("a") + df("b")), Seq(Row(3))) @@ -287,7 +229,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { } test("isNaN") { - val testData = sqlContext.createDataFrame(sparkContext.parallelize( + val testData = spark.createDataFrame(sparkContext.parallelize( Row(Double.NaN, Float.NaN) :: Row(math.log(-1), math.log(-3).toFloat) :: Row(null, null) :: @@ -308,7 +250,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { } test("nanvl") { - val testData = sqlContext.createDataFrame(sparkContext.parallelize( + val testData = spark.createDataFrame(sparkContext.parallelize( Row(null, 3.0, Double.NaN, Double.PositiveInfinity, 1.0f, 4) :: Nil), StructType(Seq(StructField("a", DoubleType), StructField("b", DoubleType), StructField("c", DoubleType), StructField("d", DoubleType), @@ -321,7 +263,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { nanvl($"b", $"e"), nanvl($"e", $"f")), Row(null, 3.0, 10.0, null, Double.PositiveInfinity, 3.0, 1.0) ) - 
testData.registerTempTable("t") + testData.createOrReplaceTempView("t") checkAnswer( sql( "select nanvl(a, 5), nanvl(b, 10), nanvl(10, b), nanvl(c, null), nanvl(d, 10), " + @@ -351,7 +293,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { } test("=!=") { - val nullData = sqlContext.createDataFrame(sparkContext.parallelize( + val nullData = spark.createDataFrame(sparkContext.parallelize( Row(1, 1) :: Row(1, 2) :: Row(1, null) :: @@ -370,7 +312,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { nullData.filter($"a" <=> $"b"), Row(1, 1) :: Row(null, null) :: Nil) - val nullData2 = sqlContext.createDataFrame(sparkContext.parallelize( + val nullData2 = spark.createDataFrame(sparkContext.parallelize( Row("abc") :: Row(null) :: Row("xyz") :: Nil), @@ -566,18 +508,17 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { Row("ab", "cde")) } - test("monotonicallyIncreasingId") { + test("monotonically_increasing_id") { // Make sure we have 2 partitions, each with 2 records. val df = sparkContext.parallelize(Seq[Int](), 2).mapPartitions { _ => Iterator(Tuple1(1), Tuple1(2)) }.toDF("a") checkAnswer( - df.select(monotonicallyIncreasingId()), - Row(0L) :: Row(1L) :: Row((1L << 33) + 0L) :: Row((1L << 33) + 1L) :: Nil - ) - checkAnswer( - df.select(expr("monotonically_increasing_id()")), - Row(0L) :: Row(1L) :: Row((1L << 33) + 0L) :: Row((1L << 33) + 1L) :: Nil + df.select(monotonically_increasing_id(), expr("monotonically_increasing_id()")), + Row(0L, 0L) :: + Row(1L, 1L) :: + Row((1L << 33) + 0L, (1L << 33) + 0L) :: + Row((1L << 33) + 1L, (1L << 33) + 1L) :: Nil ) } @@ -592,11 +533,11 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { ) } - test("input_file_name") { + test("input_file_name - FileScanRDD") { withTempPath { dir => val data = sparkContext.parallelize(0 to 10).toDF("id") data.write.parquet(dir.getCanonicalPath) - val answer = sqlContext.read.parquet(dir.getCanonicalPath).select(input_file_name()) + val answer = spark.read.parquet(dir.getCanonicalPath).select(input_file_name()) .head.getString(0) assert(answer.contains(dir.getCanonicalPath)) @@ -604,6 +545,35 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { } } + test("input_file_name - HadoopRDD") { + withTempPath { dir => + val data = sparkContext.parallelize((0 to 10).map(_.toString)).toDF() + data.write.text(dir.getCanonicalPath) + val df = spark.sparkContext.textFile(dir.getCanonicalPath).toDF() + val answer = df.select(input_file_name()).head.getString(0) + assert(answer.contains(dir.getCanonicalPath)) + + checkAnswer(data.select(input_file_name()).limit(1), Row("")) + } + } + + test("input_file_name - NewHadoopRDD") { + withTempPath { dir => + val data = sparkContext.parallelize((0 to 10).map(_.toString)).toDF() + data.write.text(dir.getCanonicalPath) + val rdd = spark.sparkContext.newAPIHadoopFile( + dir.getCanonicalPath, + classOf[NewTextInputFormat], + classOf[LongWritable], + classOf[Text]) + val df = rdd.map(pair => pair._2.toString).toDF() + val answer = df.select(input_file_name()).head.getString(0) + assert(answer.contains(dir.getCanonicalPath)) + + checkAnswer(data.select(input_file_name()).limit(1), Row("")) + } + } + test("columns can be compared") { assert('key.desc == 'key.desc) assert('key.desc != 'key.asc) @@ -707,5 +677,4 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { testData2.select($"a".bitwiseXOR($"b").bitwiseXOR(39)), testData2.collect().toSeq.map(r => Row(r.getInt(0) ^ 
r.getInt(1) ^ 39))) } - } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 63f4b759a00ae..3454caff6b825 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -70,7 +70,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { Row(new java.math.BigDecimal(3.0), new java.math.BigDecimal(3.0))) ) - val decimalDataWithNulls = sqlContext.sparkContext.parallelize( + val decimalDataWithNulls = spark.sparkContext.parallelize( DecimalData(1, 1) :: DecimalData(1, null) :: DecimalData(2, 1) :: @@ -87,6 +87,16 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { ) } + test("SPARK-17124 agg should be ordering preserving") { + val df = spark.range(2) + val ret = df.groupBy("id").agg("id" -> "sum", "id" -> "count", "id" -> "min") + assert(ret.schema.map(_.name) == Seq("id", "sum(id)", "count(id)", "min(id)")) + checkAnswer( + ret, + Row(0, 0, 1, 0) :: Row(1, 1, 1, 1) :: Nil + ) + } + test("rollup") { checkAnswer( courseSales.rollup("course", "year").sum("earnings"), @@ -114,7 +124,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { Row(null, null, 113000.0) :: Nil ) - val df0 = sqlContext.sparkContext.parallelize(Seq( + val df0 = spark.sparkContext.parallelize(Seq( Fact(20151123, 18, 35, "room1", 18.6), Fact(20151123, 18, 35, "room2", 22.4), Fact(20151123, 18, 36, "room1", 17.4), @@ -207,12 +217,12 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { Seq(Row(1, 3), Row(2, 3), Row(3, 3)) ) - sqlContext.conf.setConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS, false) + spark.conf.set(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS.key, false) checkAnswer( testData2.groupBy("a").agg(sum($"b")), Seq(Row(3), Row(3), Row(3)) ) - sqlContext.conf.setConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS, true) + spark.conf.set(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS.key, true) } test("agg without groups") { @@ -431,12 +441,68 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { Row(null, null, null, null, null)) } + test("collect functions") { + val df = Seq((1, 2), (2, 2), (3, 4)).toDF("a", "b") + checkAnswer( + df.select(collect_list($"a"), collect_list($"b")), + Seq(Row(Seq(1, 2, 3), Seq(2, 2, 4))) + ) + checkAnswer( + df.select(collect_set($"a"), collect_set($"b")), + Seq(Row(Seq(1, 2, 3), Seq(2, 4))) + ) + } + + test("collect functions structs") { + val df = Seq((1, 2, 2), (2, 2, 2), (3, 4, 1)) + .toDF("a", "x", "y") + .select($"a", struct($"x", $"y").as("b")) + checkAnswer( + df.select(collect_list($"a"), sort_array(collect_list($"b"))), + Seq(Row(Seq(1, 2, 3), Seq(Row(2, 2), Row(2, 2), Row(4, 1)))) + ) + checkAnswer( + df.select(collect_set($"a"), sort_array(collect_set($"b"))), + Seq(Row(Seq(1, 2, 3), Seq(Row(2, 2), Row(4, 1)))) + ) + } + + test("collect_set functions cannot have maps") { + val df = Seq((1, 3, 0), (2, 3, 0), (3, 4, 1)) + .toDF("a", "x", "y") + .select($"a", map($"x", $"y").as("b")) + val error = intercept[AnalysisException] { + df.select(collect_set($"a"), collect_set($"b")) + } + assert(error.message.contains("collect_set() cannot have map type data")) + } + + test("SPARK-17641: collect functions should not collect null values") { + val df = Seq(("1", 2), (null, 2), ("1", 4)).toDF("a", "b") + checkAnswer( + df.select(collect_list($"a"), collect_list($"b")), + 
Seq(Row(Seq("1", "1"), Seq(2, 2, 4))) + ) + checkAnswer( + df.select(collect_set($"a"), collect_set($"b")), + Seq(Row(Seq("1"), Seq(2, 4))) + ) + } + test("SPARK-14664: Decimal sum/avg over window should work.") { checkAnswer( - sqlContext.sql("select sum(a) over () from values 1.0, 2.0, 3.0 T(a)"), + spark.sql("select sum(a) over () from values 1.0, 2.0, 3.0 T(a)"), Row(6.0) :: Row(6.0) :: Row(6.0) :: Nil) checkAnswer( - sqlContext.sql("select avg(a) over () from values 1.0, 2.0, 3.0 T(a)"), + spark.sql("select avg(a) over () from values 1.0, 2.0, 3.0 T(a)"), Row(2.0) :: Row(2.0) :: Row(2.0) :: Nil) } + + test("SPARK-17616: distinct aggregate combined with a non-partial aggregate") { + val df = Seq((1, 3, "a"), (1, 2, "b"), (3, 4, "c"), (3, 4, "c"), (3, 5, "d")) + .toDF("x", "y", "z") + checkAnswer( + df.groupBy($"x").agg(countDistinct($"y"), sort_array(collect_list($"z"))), + Seq(Row(1, 2, Seq("a", "b")), Row(3, 2, Seq("c", "c", "d")))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala index 72f676e6225ee..1230b921aa279 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import org.apache.spark.sql.catalyst.DefinedByConstructorParams import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext @@ -58,4 +59,43 @@ class DataFrameComplexTypeSuite extends QueryTest with SharedSQLContext { val nullIntRow = df.selectExpr("i[1]").collect()(0) assert(nullIntRow == org.apache.spark.sql.Row(null)) } + + test("SPARK-15285 Generated SpecificSafeProjection.apply method grows beyond 64KB") { + val ds100_5 = Seq(S100_5()).toDS() + ds100_5.rdd.count + } } + +class S100( + val s1: String = "1", val s2: String = "2", val s3: String = "3", val s4: String = "4", + val s5: String = "5", val s6: String = "6", val s7: String = "7", val s8: String = "8", + val s9: String = "9", val s10: String = "10", val s11: String = "11", val s12: String = "12", + val s13: String = "13", val s14: String = "14", val s15: String = "15", val s16: String = "16", + val s17: String = "17", val s18: String = "18", val s19: String = "19", val s20: String = "20", + val s21: String = "21", val s22: String = "22", val s23: String = "23", val s24: String = "24", + val s25: String = "25", val s26: String = "26", val s27: String = "27", val s28: String = "28", + val s29: String = "29", val s30: String = "30", val s31: String = "31", val s32: String = "32", + val s33: String = "33", val s34: String = "34", val s35: String = "35", val s36: String = "36", + val s37: String = "37", val s38: String = "38", val s39: String = "39", val s40: String = "40", + val s41: String = "41", val s42: String = "42", val s43: String = "43", val s44: String = "44", + val s45: String = "45", val s46: String = "46", val s47: String = "47", val s48: String = "48", + val s49: String = "49", val s50: String = "50", val s51: String = "51", val s52: String = "52", + val s53: String = "53", val s54: String = "54", val s55: String = "55", val s56: String = "56", + val s57: String = "57", val s58: String = "58", val s59: String = "59", val s60: String = "60", + val s61: String = "61", val s62: String = "62", val s63: String = "63", val s64: String = "64", + val s65: String = "65", val s66: String = "66", val s67: String = "67", val s68: String = "68", 
+ val s69: String = "69", val s70: String = "70", val s71: String = "71", val s72: String = "72", + val s73: String = "73", val s74: String = "74", val s75: String = "75", val s76: String = "76", + val s77: String = "77", val s78: String = "78", val s79: String = "79", val s80: String = "80", + val s81: String = "81", val s82: String = "82", val s83: String = "83", val s84: String = "84", + val s85: String = "85", val s86: String = "86", val s87: String = "87", val s88: String = "88", + val s89: String = "89", val s90: String = "90", val s91: String = "91", val s92: String = "92", + val s93: String = "93", val s94: String = "94", val s95: String = "95", val s96: String = "96", + val s97: String = "97", val s98: String = "98", val s99: String = "99", val s100: String = "100") +extends DefinedByConstructorParams + +case class S100_5( + s1: S100 = new S100(), s2: S100 = new S100(), s3: S100 = new S100(), + s4: S100 = new S100(), s5: S100 = new S100()) + + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 746e25a0c3ec5..0f6c49e759590 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -152,12 +152,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Row("one", "not_one")) } - test("nvl function") { - checkAnswer( - sql("SELECT nvl(null, 'x'), nvl('y', 'x'), nvl(null, null)"), - Row("x", "y", null)) - } - test("misc md5 function") { val df = Seq(("ABC", Array[Byte](1, 2, 3, 4, 5, 6))).toDF("a", "b") checkAnswer( @@ -358,6 +352,22 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("map_keys/map_values function") { + val df = Seq( + (Map[Int, Int](1 -> 100, 2 -> 200), "x"), + (Map[Int, Int](), "y"), + (Map[Int, Int](1 -> 100, 2 -> 200, 3 -> 300), "z") + ).toDF("a", "b") + checkAnswer( + df.selectExpr("map_keys(a)"), + Seq(Row(Seq(1, 2)), Row(Seq.empty), Row(Seq(1, 2, 3))) + ) + checkAnswer( + df.selectExpr("map_values(a)"), + Seq(Row(Seq(100, 200)), Row(Seq.empty), Row(Seq(100, 200, 300))) + ) + } + test("array contains function") { val df = Seq( (Seq[Int](1, 2), "x"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index 0414fa1c91efa..4abf5e42b9c34 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -154,7 +154,7 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { // SPARK-12275: no physical plan for BroadcastHint in some condition withTempPath { path => df1.write.parquet(path.getCanonicalPath) - val pf1 = sqlContext.read.parquet(path.getCanonicalPath) + val pf1 = spark.read.parquet(path.getCanonicalPath) assert(df1.join(broadcast(pf1)).count() === 4) } } @@ -204,4 +204,33 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { leftJoin2Inner, Row(1, 2, "1", 1, 3, "1") :: Nil) } + + test("process outer join results using the non-nullable columns in the join input") { + // Filter data using a non-nullable column from a right table + val df1 = Seq((0, 0), (1, 0), (2, 0), (3, 0), (4, 0)).toDF("id", "count") + val df2 = Seq(Tuple1(0), Tuple1(1)).toDF("id").groupBy("id").count + checkAnswer( + df1.join(df2, df1("id") === df2("id"), 
"left_outer").filter(df2("count").isNull), + Row(2, 0, null, null) :: + Row(3, 0, null, null) :: + Row(4, 0, null, null) :: Nil + ) + + // Coalesce data using non-nullable columns in input tables + val df3 = Seq((1, 1)).toDF("a", "b") + val df4 = Seq((2, 2)).toDF("a", "b") + checkAnswer( + df3.join(df4, df3("a") === df4("a"), "outer") + .select(coalesce(df3("a"), df3("b")), coalesce(df4("a"), df4("b"))), + Row(1, null) :: Row(null, 2) :: Nil + ) + } + + test("SPARK-16991: Full outer join followed by inner join produces wrong results") { + val a = Seq((1, 2), (2, 3)).toDF("a", "b") + val b = Seq((2, 5), (3, 4)).toDF("a", "c") + val c = Seq((3, 1)).toDF("a", "d") + val ab = a.join(b, Seq("a"), "fullouter") + checkAnswer(ab.join(c, "a"), Row(3, null, 4, 1) :: Nil) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala index 368aa5cd141f0..d5cb5e15688e8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala @@ -17,14 +17,16 @@ package org.apache.spark.sql +import org.apache.spark.sql.catalyst.expressions.aggregate.PivotFirst import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ class DataFramePivotSuite extends QueryTest with SharedSQLContext{ import testImplicits._ - test("pivot courses with literals") { + test("pivot courses") { checkAnswer( courseSales.groupBy("year").pivot("course", Seq("dotNET", "Java")) .agg(sum($"earnings")), @@ -32,14 +34,14 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{ ) } - test("pivot year with literals") { + test("pivot year") { checkAnswer( courseSales.groupBy("course").pivot("year", Seq(2012, 2013)).agg(sum($"earnings")), Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil ) } - test("pivot courses with literals and multiple aggregations") { + test("pivot courses with multiple aggregations") { checkAnswer( courseSales.groupBy($"year") .pivot("course", Seq("dotNET", "Java")) @@ -79,11 +81,11 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{ } test("pivot max values enforced") { - sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES, 1) + spark.conf.set(SQLConf.DATAFRAME_PIVOT_MAX_VALUES.key, 1) intercept[AnalysisException]( courseSales.groupBy("year").pivot("course") ) - sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES, + spark.conf.set(SQLConf.DATAFRAME_PIVOT_MAX_VALUES.key, SQLConf.DATAFRAME_PIVOT_MAX_VALUES.defaultValue.get) } @@ -94,4 +96,105 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{ Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil ) } + + // Tests for optimized pivot (with PivotFirst) below + + test("optimized pivot planned") { + val df = courseSales.groupBy("year") + // pivot with extra columns to trigger optimization + .pivot("course", Seq("dotNET", "Java") ++ (1 to 10).map(_.toString)) + .agg(sum($"earnings")) + val queryExecution = spark.sessionState.executePlan(df.queryExecution.logical) + assert(queryExecution.simpleString.contains("pivotfirst")) + } + + + test("optimized pivot courses with literals") { + checkAnswer( + courseSales.groupBy("year") + // pivot with extra columns to trigger optimization + .pivot("course", Seq("dotNET", "Java") ++ (1 to 10).map(_.toString)) + .agg(sum($"earnings")) + 
.select("year", "dotNET", "Java"), + Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil + ) + } + + test("optimized pivot year with literals") { + checkAnswer( + courseSales.groupBy($"course") + // pivot with extra columns to trigger optimization + .pivot("year", Seq(2012, 2013) ++ (1 to 10)) + .agg(sum($"earnings")) + .select("course", "2012", "2013"), + Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil + ) + } + + test("optimized pivot year with string values (cast)") { + checkAnswer( + courseSales.groupBy("course") + // pivot with extra columns to trigger optimization + .pivot("year", Seq("2012", "2013") ++ (1 to 10).map(_.toString)) + .sum("earnings") + .select("course", "2012", "2013"), + Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil + ) + } + + test("optimized pivot DecimalType") { + val df = courseSales.select($"course", $"year", $"earnings".cast(DecimalType(10, 2))) + .groupBy("year") + // pivot with extra columns to trigger optimization + .pivot("course", Seq("dotNET", "Java") ++ (1 to 10).map(_.toString)) + .agg(sum($"earnings")) + .select("year", "dotNET", "Java") + + assertResult(IntegerType)(df.schema("year").dataType) + assertResult(DecimalType(20, 2))(df.schema("Java").dataType) + assertResult(DecimalType(20, 2))(df.schema("dotNET").dataType) + + checkAnswer(df, Row(2012, BigDecimal(1500000, 2), BigDecimal(2000000, 2)) :: + Row(2013, BigDecimal(4800000, 2), BigDecimal(3000000, 2)) :: Nil) + } + + test("PivotFirst supported datatypes") { + val supportedDataTypes: Seq[DataType] = DoubleType :: IntegerType :: LongType :: FloatType :: + BooleanType :: ShortType :: ByteType :: Nil + for (datatype <- supportedDataTypes) { + assertResult(true)(PivotFirst.supportsDataType(datatype)) + } + assertResult(true)(PivotFirst.supportsDataType(DecimalType(10, 1))) + assertResult(false)(PivotFirst.supportsDataType(null)) + assertResult(false)(PivotFirst.supportsDataType(ArrayType(IntegerType))) + } + + test("optimized pivot with multiple aggregations") { + checkAnswer( + courseSales.groupBy($"year") + // pivot with extra columns to trigger optimization + .pivot("course", Seq("dotNET", "Java") ++ (1 to 10).map(_.toString)) + .agg(sum($"earnings"), avg($"earnings")), + Row(Seq(2012, 15000.0, 7500.0, 20000.0, 20000.0) ++ Seq.fill(20)(null): _*) :: + Row(Seq(2013, 48000.0, 48000.0, 30000.0, 30000.0) ++ Seq.fill(20)(null): _*) :: Nil + ) + } + + test("pivot with datatype not supported by PivotFirst") { + checkAnswer( + complexData.groupBy().pivot("b", Seq(true, false)).agg(max("a")), + Row(Seq(1, 1, 1), Seq(2, 2, 2)) :: Nil + ) + } + + test("pivot with datatype not supported by PivotFirst 2") { + checkAnswer( + courseSales.withColumn("e", expr("array(earnings, 7.0d)")) + .groupBy("year") + .pivot("course", Seq("dotNET", "Java")) + .agg(min($"e")), + Row(2012, Seq(5000.0, 7.0), Seq(20000.0, 7.0)) :: + Row(2013, Seq(48000.0, 7.0), Seq(30000.0, 7.0)) :: Nil + ) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 0ea7727e45029..73026c749db45 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -235,8 +235,19 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { assert(items.length === 1) } + test("SPARK-15709: Prevent `UnsupportedOperationException: empty.min` in `freqItems`") { + val ds = 
spark.createDataset(Seq(1, 2, 2, 3, 3, 3)) + + intercept[IllegalArgumentException] { + ds.stat.freqItems(Seq("value"), 0) + } + intercept[IllegalArgumentException] { + ds.stat.freqItems(Seq("value"), 2) + } + } + test("sampleBy") { - val df = sqlContext.range(0, 100).select((col("id") % 3).as("key")) + val df = spark.range(0, 100).select((col("id") % 3).as("key")) val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L) checkAnswer( sampled.groupBy("key").count().orderBy("key"), @@ -247,7 +258,7 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { // `CountMinSketch`es that meet required specs. Test cases for `CountMinSketch` can be found in // `CountMinSketchSuite` in project spark-sketch. test("countMinSketch") { - val df = sqlContext.range(1000) + val df = spark.range(1000) val sketch1 = df.stat.countMinSketch("id", depth = 10, width = 20, seed = 42) assert(sketch1.totalCount() === 1000) @@ -279,7 +290,7 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { // This test only verifies some basic requirements, more correctness tests can be found in // `BloomFilterSuite` in project spark-sketch. test("Bloom filter") { - val df = sqlContext.range(1000) + val df = spark.range(1000) val filter1 = df.stat.bloomFilter("id", 1000, 0.03) assert(filter1.expectedFpp() - 0.03 < 1e-3) @@ -304,7 +315,7 @@ class DataFrameStatPerfSuite extends QueryTest with SharedSQLContext with Loggin // Turn on this test if you want to test the performance of approximate quantiles. ignore("computing quantiles should not take much longer than describe()") { - val df = sqlContext.range(5000000L).toDF("col1").cache() + val df = spark.range(5000000L).toDF("col1").cache() def seconds(f: => Any): Double = { // Do some warmup logDebug("warmup...") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 80a93ee6d4f3e..4478a9afe8585 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql import java.io.File import java.nio.charset.StandardCharsets +import java.sql.{Date, Timestamp} +import java.util.UUID import scala.language.postfixOps import scala.util.Random @@ -26,15 +28,18 @@ import scala.util.Random import org.scalatest.Matchers._ import org.apache.spark.SparkException +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Union} import org.apache.spark.sql.execution.QueryExecution -import org.apache.spark.sql.execution.aggregate.TungstenAggregate +import org.apache.spark.sql.execution.aggregate.HashAggregateExec +import org.apache.spark.sql.execution.command.CreateDataSourceTableUtils._ import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchange} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSQLContext} import org.apache.spark.sql.test.SQLTestData.TestData2 import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils class DataFrameSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -99,8 +104,8 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val rowRDD2 = sparkContext.parallelize(Seq(Row(2, new ExamplePoint(3.0, 4.0)))) val schema2 = StructType(Array(StructField("label", 
IntegerType, false), StructField("point", new ExamplePointUDT(), false))) - val df1 = sqlContext.createDataFrame(rowRDD1, schema1) - val df2 = sqlContext.createDataFrame(rowRDD2, schema2) + val df1 = spark.createDataFrame(rowRDD1, schema1) + val df2 = spark.createDataFrame(rowRDD2, schema2) checkAnswer( df1.union(df2).orderBy("label"), @@ -109,8 +114,8 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("empty data frame") { - assert(sqlContext.emptyDataFrame.columns.toSeq === Seq.empty[String]) - assert(sqlContext.emptyDataFrame.count() === 0) + assert(spark.emptyDataFrame.columns.toSeq === Seq.empty[String]) + assert(spark.emptyDataFrame.count() === 0) } test("head and take") { @@ -259,12 +264,20 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("repartition") { + intercept[IllegalArgumentException] { + testData.select('key).repartition(0) + } + checkAnswer( testData.select('key).repartition(10).select('key), testData.select('key).collect().toSeq) } test("coalesce") { + intercept[IllegalArgumentException] { + testData.select('key).coalesce(0) + } + assert(testData.select('key).coalesce(1).rdd.partitions.size === 1) checkAnswer( @@ -369,7 +382,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { // SPARK-12340: overstep the bounds of Int in SparkPlan.executeTake checkAnswer( - sqlContext.range(2).toDF().limit(2147483638), + spark.range(2).toDF().limit(2147483638), Row(0) :: Row(1) :: Nil ) } @@ -507,7 +520,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { ) } - test("callUDF in SQLContext") { + test("callUDF without Hive Support") { val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value") df.sparkSession.udf.register("simpleUDF", (v: Int) => v * v) checkAnswer( @@ -601,6 +614,27 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(df("id") == person("id")) } + test("drop top level columns that contains dot") { + val df1 = Seq((1, 2)).toDF("a.b", "a.c") + checkAnswer(df1.drop("a.b"), Row(2)) + + // Creates data set: {"a.b": 1, "a": {"b": 3}} + val df2 = Seq((1)).toDF("a.b").withColumn("a", struct(lit(3) as "b")) + // Not like select(), drop() parses the column name "a.b" literally without interpreting "." + checkAnswer(df2.drop("a.b").select("a.b"), Row(3)) + + // "`" is treated as a normal char here with no interpreting, "`a`b" is a valid column name. + assert(df2.drop("`a.b`").columns.size == 2) + } + + test("drop(name: String) search and drop all top level columns that matchs the name") { + val df1 = Seq((1, 2)).toDF("a", "b") + val df2 = Seq((3, 4)).toDF("a", "b") + checkAnswer(df1.join(df2), Row(1, 2, 3, 4)) + // Finds and drops all columns that match the name (case insensitive). 
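// --- Illustrative sketch, not part of this patch: the drop() semantics exercised in the tests
// --- above. Unlike select(), drop() treats the given name literally (no dot parsing) and removes
// --- every top-level column matching it case-insensitively. Assumes a SparkSession named `spark`.
import org.apache.spark.sql.functions.{lit, struct}
import spark.implicits._

val df = Seq(1).toDF("a.b").withColumn("a", struct(lit(3) as "b"))
df.drop("a.b").columns        // Array("a"): the literal top-level "a.b" column is gone
df.drop("a.b").select("a.b")  // still resolves to the nested field a.b inside struct "a", i.e. 3

// Mirrors the join test above: drop("A") removes both "a" columns, case-insensitively.
val joined = Seq((1, 2)).toDF("a", "b").join(Seq((3, 4)).toDF("a", "b"))
joined.drop("A").columns      // Array("b", "b")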
+ checkAnswer(df1.join(df2).drop("A"), Row(2, 4)) + } + test("withColumnRenamed") { val df = testData.toDF().withColumn("newCol", col("key") + 1) .withColumnRenamed("value", "valueRenamed") @@ -672,12 +706,12 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val parquetDir = new File(dir, "parquet").getCanonicalPath df.write.parquet(parquetDir) - val parquetDF = sqlContext.read.parquet(parquetDir) + val parquetDF = spark.read.parquet(parquetDir) assert(parquetDF.inputFiles.nonEmpty) val jsonDir = new File(dir, "json").getCanonicalPath df.write.json(jsonDir) - val jsonDF = sqlContext.read.json(jsonDir) + val jsonDF = spark.read.json(jsonDir) assert(parquetDF.inputFiles.nonEmpty) val unioned = jsonDF.union(parquetDF).inputFiles.sorted @@ -801,7 +835,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("createDataFrame(RDD[Row], StructType) should convert UDTs (SPARK-6672)") { val rowRDD = sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0)))) val schema = StructType(Array(StructField("point", new ExamplePointUDT(), false))) - val df = sqlContext.createDataFrame(rowRDD, schema) + val df = spark.createDataFrame(rowRDD, schema) df.rdd.collect() } @@ -818,14 +852,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("SPARK-7551: support backticks for DataFrame attribute resolution") { - val df = sqlContext.read.json(sparkContext.makeRDD( + val df = spark.read.json(sparkContext.makeRDD( """{"a.b": {"c": {"d..e": {"f": 1}}}}""" :: Nil)) checkAnswer( df.select(df("`a.b`.c.`d..e`.`f`")), Row(1) ) - val df2 = sqlContext.read.json(sparkContext.makeRDD( + val df2 = spark.read.json(sparkContext.makeRDD( """{"a b": {"c": {"d e": {"f": 1}}}}""" :: Nil)) checkAnswer( df2.select(df2("`a b`.c.d e.f")), @@ -877,57 +911,61 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { checkAnswer( testData.dropDuplicates(Seq("value2")), Seq(Row(2, 1, 2), Row(1, 1, 1))) + + checkAnswer( + testData.dropDuplicates("key", "value1"), + Seq(Row(2, 1, 2), Row(1, 2, 1), Row(1, 1, 1), Row(2, 2, 2))) } test("SPARK-7150 range api") { // numSlice is greater than length - val res1 = sqlContext.range(0, 10, 1, 15).select("id") + val res1 = spark.range(0, 10, 1, 15).select("id") assert(res1.count == 10) assert(res1.agg(sum("id")).as("sumid").collect() === Seq(Row(45))) - val res2 = sqlContext.range(3, 15, 3, 2).select("id") + val res2 = spark.range(3, 15, 3, 2).select("id") assert(res2.count == 4) assert(res2.agg(sum("id")).as("sumid").collect() === Seq(Row(30))) - val res3 = sqlContext.range(1, -2).select("id") + val res3 = spark.range(1, -2).select("id") assert(res3.count == 0) // start is positive, end is negative, step is negative - val res4 = sqlContext.range(1, -2, -2, 6).select("id") + val res4 = spark.range(1, -2, -2, 6).select("id") assert(res4.count == 2) assert(res4.agg(sum("id")).as("sumid").collect() === Seq(Row(0))) // start, end, step are negative - val res5 = sqlContext.range(-3, -8, -2, 1).select("id") + val res5 = spark.range(-3, -8, -2, 1).select("id") assert(res5.count == 3) assert(res5.agg(sum("id")).as("sumid").collect() === Seq(Row(-15))) // start, end are negative, step is positive - val res6 = sqlContext.range(-8, -4, 2, 1).select("id") + val res6 = spark.range(-8, -4, 2, 1).select("id") assert(res6.count == 2) assert(res6.agg(sum("id")).as("sumid").collect() === Seq(Row(-14))) - val res7 = sqlContext.range(-10, -9, -20, 1).select("id") + val res7 = spark.range(-10, -9, -20, 1).select("id") assert(res7.count == 0) - val res8 = 
sqlContext.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id") + val res8 = spark.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id") assert(res8.count == 3) assert(res8.agg(sum("id")).as("sumid").collect() === Seq(Row(-3))) - val res9 = sqlContext.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id") + val res9 = spark.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id") assert(res9.count == 2) assert(res9.agg(sum("id")).as("sumid").collect() === Seq(Row(Long.MaxValue - 1))) // only end provided as argument - val res10 = sqlContext.range(10).select("id") + val res10 = spark.range(10).select("id") assert(res10.count == 10) assert(res10.agg(sum("id")).as("sumid").collect() === Seq(Row(45))) - val res11 = sqlContext.range(-1).select("id") + val res11 = spark.range(-1).select("id") assert(res11.count == 0) // using the default slice number - val res12 = sqlContext.range(3, 15, 3).select("id") + val res12 = spark.range(3, 15, 3).select("id") assert(res12.count == 4) assert(res12.agg(sum("id")).as("sumid").collect() === Seq(Row(30))) } @@ -993,18 +1031,19 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { // pass case: parquet table (HadoopFsRelation) df.write.mode(SaveMode.Overwrite).parquet(tempParquetFile.getCanonicalPath) - val pdf = sqlContext.read.parquet(tempParquetFile.getCanonicalPath) - pdf.registerTempTable("parquet_base") + val pdf = spark.read.parquet(tempParquetFile.getCanonicalPath) + pdf.createOrReplaceTempView("parquet_base") + insertion.write.insertInto("parquet_base") // pass case: json table (InsertableRelation) df.write.mode(SaveMode.Overwrite).json(tempJsonFile.getCanonicalPath) - val jdf = sqlContext.read.json(tempJsonFile.getCanonicalPath) - jdf.registerTempTable("json_base") + val jdf = spark.read.json(tempJsonFile.getCanonicalPath) + jdf.createOrReplaceTempView("json_base") insertion.write.mode(SaveMode.Overwrite).insertInto("json_base") // error cases: insert into an RDD - df.registerTempTable("rdd_base") + df.createOrReplaceTempView("rdd_base") val e1 = intercept[AnalysisException] { insertion.write.insertInto("rdd_base") } @@ -1012,14 +1051,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { // error case: insert into a logical plan that is not a LeafNode val indirectDS = pdf.select("_1").filter($"_1" > 5) - indirectDS.registerTempTable("indirect_ds") + indirectDS.createOrReplaceTempView("indirect_ds") val e2 = intercept[AnalysisException] { insertion.write.insertInto("indirect_ds") } assert(e2.getMessage.contains("Inserting into an RDD-based table is not allowed.")) // error case: insert into an OneRowRelation - Dataset.ofRows(sqlContext.sparkSession, OneRowRelation).registerTempTable("one_row") + Dataset.ofRows(spark, OneRowRelation).createOrReplaceTempView("one_row") val e3 = intercept[AnalysisException] { insertion.write.insertInto("one_row") } @@ -1062,7 +1101,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("SPARK-9323: DataFrame.orderBy should support nested column name") { - val df = sqlContext.read.json(sparkContext.makeRDD( + val df = spark.read.json(sparkContext.makeRDD( """{"a": {"b": 1}}""" :: Nil)) checkAnswer(df.orderBy("a.b"), Row(Row(1))) } @@ -1091,10 +1130,10 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val dir2 = new File(dir, "dir2").getCanonicalPath df2.write.format("json").save(dir2) - checkAnswer(sqlContext.read.format("json").load(dir1, dir2), + checkAnswer(spark.read.format("json").load(dir1, dir2), 
Row(1, 22) :: Row(2, 23) :: Nil) - checkAnswer(sqlContext.read.format("json").load(dir1), + checkAnswer(spark.read.format("json").load(dir1), Row(1, 22) :: Nil) } } @@ -1116,7 +1155,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("SPARK-10316: respect non-deterministic expressions in PhysicalOperation") { - val input = sqlContext.read.json(sqlContext.sparkContext.makeRDD( + val input = spark.read.json(spark.sparkContext.makeRDD( (1 to 10).map(i => s"""{"id": $i}"""))) val df = input.select($"id", rand(0).as('r)) @@ -1185,7 +1224,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { withTempPath { path => Seq(2012 -> "a").toDF("year", "val").write.partitionBy("year").parquet(path.getAbsolutePath) - val df = sqlContext.read.parquet(path.getAbsolutePath) + val df = spark.read.parquet(path.getAbsolutePath) checkAnswer(df.filter($"yEAr" > 2000).select($"val"), Row("a")) } } @@ -1197,7 +1236,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { private def verifyNonExchangingAgg(df: DataFrame) = { var atFirstAgg: Boolean = false df.queryExecution.executedPlan.foreach { - case agg: TungstenAggregate => + case agg: HashAggregateExec => atFirstAgg = !atFirstAgg case _ => if (atFirstAgg) { @@ -1212,7 +1251,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { private def verifyExchangingAgg(df: DataFrame) = { var atFirstAgg: Boolean = false df.queryExecution.executedPlan.foreach { - case agg: TungstenAggregate => + case agg: HashAggregateExec => if (atFirstAgg) { fail("Should not have back to back Aggregates") } @@ -1244,7 +1283,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { verifyExchangingAgg(testData.repartition($"key", $"value") .groupBy("key").count()) - val data = sqlContext.sparkContext.parallelize( + val data = spark.sparkContext.parallelize( (1 to 100).map(i => TestData2(i % 10, i))).toDF() // Distribute and order by. 
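// --- Illustrative sketch, not part of this patch: the "distribute and order by" idea behind the
// --- exchange checks above. Repartitioning by the grouping key up front lets the following
// --- aggregation reuse that partitioning, while grouping on a different key still needs a
// --- shuffle. Assumes a SparkSession named `spark`; the data is made up.
import spark.implicits._

val data = spark.range(100).select(($"id" % 10).as("key"), $"id".as("value"))

// Already hash-partitioned by "key": the aggregation should not add another Exchange.
data.repartition($"key").groupBy("key").count().explain()

// Partitioned on "value" but grouped by "key": an Exchange has to be inserted.
data.repartition($"value").groupBy("key").count().explain()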
@@ -1308,7 +1347,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { withTempPath { path => val p = path.getAbsolutePath Seq(2012 -> "a").toDF("year", "val").write.partitionBy("yEAr").parquet(p) - checkAnswer(sqlContext.read.parquet(p).select("YeaR"), Row(2012)) + checkAnswer(spark.read.parquet(p).select("YeaR"), Row(2012)) } } } @@ -1317,7 +1356,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("SPARK-11633: LogicalRDD throws TreeNode Exception: Failed to Copy Node") { withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { val rdd = sparkContext.makeRDD(Seq(Row(1, 3), Row(2, 1))) - val df = sqlContext.createDataFrame( + val df = spark.createDataFrame( rdd, new StructType().add("f1", IntegerType).add("f2", IntegerType), needsConversion = false).select($"F1", $"f2".as("f2")) @@ -1344,7 +1383,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } checkAnswer(df.select(boxedUDF($"age")), Row(null) :: Row(-10) :: Nil) - sqlContext.udf.register("boxedUDF", + spark.udf.register("boxedUDF", (i: java.lang.Integer) => (if (i == null) -10 else null): java.lang.Integer) checkAnswer(sql("select boxedUDF(null), boxedUDF(-1)"), Row(-10, null) :: Nil) @@ -1393,7 +1432,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("reuse exchange") { withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "2") { - val df = sqlContext.range(100).toDF() + val df = spark.range(100).toDF() val join = df.join(df, "id") val plan = join.queryExecution.executedPlan checkAnswer(join, df) @@ -1415,14 +1454,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } test("sameResult() on aggregate") { - val df = sqlContext.range(100) + val df = spark.range(100) val agg1 = df.groupBy().count() val agg2 = df.groupBy().count() // two aggregates with different ExprId within them should have same result assert(agg1.queryExecution.executedPlan.sameResult(agg2.queryExecution.executedPlan)) val agg3 = df.groupBy().sum() assert(!agg1.queryExecution.executedPlan.sameResult(agg3.queryExecution.executedPlan)) - val df2 = sqlContext.range(101) + val df2 = spark.range(101) val agg4 = df2.groupBy().count() assert(!agg1.queryExecution.executedPlan.sameResult(agg4.queryExecution.executedPlan)) } @@ -1443,36 +1482,132 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("SPARK-12982: Add table name validation in temp table registration") { val df = Seq("foo", "bar").map(Tuple1.apply).toDF("col") // invalid table name test as below - intercept[AnalysisException](df.registerTempTable("t~")) + intercept[AnalysisException](df.createOrReplaceTempView("t~")) // valid table name test as below - df.registerTempTable("table1") + df.createOrReplaceTempView("table1") // another invalid table name test as below - intercept[AnalysisException](df.registerTempTable("#$@sum")) + intercept[AnalysisException](df.createOrReplaceTempView("#$@sum")) // another invalid table name test as below - intercept[AnalysisException](df.registerTempTable("table!#")) + intercept[AnalysisException](df.createOrReplaceTempView("table!#")) } test("assertAnalyzed shouldn't replace original stack trace") { val e = intercept[AnalysisException] { - sqlContext.range(1).select('id as 'a, 'id as 'b).groupBy('a).agg('b) + spark.range(1).select('id as 'a, 'id as 'b).groupBy('a).agg('b) } assert(e.getStackTrace.head.getClassName != classOf[QueryExecution].getName) } test("SPARK-13774: Check error message for non existent path without globbed paths") { - val e = intercept[AnalysisException] 
(sqlContext.read.format("csv"). - load("/xyz/file2", "/xyz/file21", "/abc/files555", "a")).getMessage() - assert(e.startsWith("Path does not exist")) + val uuid = UUID.randomUUID().toString + val baseDir = Utils.createTempDir() + try { + val e = intercept[AnalysisException] { + spark.read.format("csv").load( + new File(baseDir, "file").getAbsolutePath, + new File(baseDir, "file2").getAbsolutePath, + new File(uuid, "file3").getAbsolutePath, + uuid).rdd + } + assert(e.getMessage.startsWith("Path does not exist")) + } finally { + + } + } test("SPARK-13774: Check error message for not existent globbed paths") { - val e = intercept[AnalysisException] (sqlContext.read.format("text"). - load( "/xyz/*")).getMessage() - assert(e.startsWith("Path does not exist")) + // Non-existent initial path component: + val nonExistentBasePath = "/" + UUID.randomUUID().toString + assert(!new File(nonExistentBasePath).exists()) + val e = intercept[AnalysisException] { + spark.read.format("text").load(s"$nonExistentBasePath/*") + } + assert(e.getMessage.startsWith("Path does not exist")) + + // Existent initial path component, but no matching files: + val baseDir = Utils.createTempDir() + val childDir = Utils.createTempDir(baseDir.getAbsolutePath) + assert(childDir.exists()) + try { + val e1 = intercept[AnalysisException] { + spark.read.json(s"${baseDir.getAbsolutePath}/*/*-xyz.json").rdd + } + assert(e1.getMessage.startsWith("Path does not exist")) + } finally { + Utils.deleteRecursively(baseDir) + } + } + + test("SPARK-15230: distinct() does not handle column name with dot properly") { + val df = Seq(1, 1, 2).toDF("column.with.dot") + checkAnswer(df.distinct(), Row(1) :: Row(2) :: Nil) + } + + test("SPARK-16181: outer join with isNull filter") { + val left = Seq("x").toDF("col") + val right = Seq("y").toDF("col").withColumn("new", lit(true)) + val joined = left.join(right, left("col") === right("col"), "left_outer") + + checkAnswer(joined, Row("x", null, null)) + checkAnswer(joined.filter($"new".isNull), Row("x", null, null)) + } + + test("SPARK-16664: persist with more than 200 columns") { + val size = 201L + val rdd = sparkContext.makeRDD(Seq(Row.fromSeq(Seq.range(0, size)))) + val schemas = List.range(0, size).map(a => StructField("name" + a, LongType, true)) + val df = spark.createDataFrame(rdd, StructType(schemas), false) + assert(df.persist.take(1).apply(0).toSeq(100).asInstanceOf[Long] == 100) + } + + test("copy results for sampling with replacement") { + val df = Seq((1, 0), (2, 0), (3, 0)).toDF("a", "b") + val sampleDf = df.sample(true, 2.00) + val d = sampleDf.withColumn("c", monotonically_increasing_id).select($"c").collect + assert(d.size == d.distinct.size) + } + + test("SPARK-17409: Do Not Optimize Query in CTAS (Data source tables) More Than Once") { + withTable("bar") { + withTempView("foo") { + withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME.key -> "json") { + sql("select 0 as id").createOrReplaceTempView("foo") + val df = sql("select * from foo group by id") + // If we optimize the query in CTAS more than once, the following saveAsTable will fail + // with the error: `GROUP BY position 0 is not in select list (valid range is [1, 1])` + df.write.mode("overwrite").saveAsTable("bar") + checkAnswer(spark.table("bar"), Row(0) :: Nil) + val tableMetadata = spark.sessionState.catalog.getTableMetadata(TableIdentifier("bar")) + assert(tableMetadata.properties(DATASOURCE_PROVIDER) == "json", + "the expected table is a data source table using json") + } + } + } + } + + test("SPARK-17123: Performing set 
operations that combine non-scala native types") { + val dates = Seq( + (BigDecimal.valueOf(1), new Timestamp(2)), + (BigDecimal.valueOf(4), new Timestamp(5)) + ).toDF("decimal", "timestamp") + + val widenTypedRows = Seq( + (10.5D, "string") + ).toDF("decimal", "timestamp") + + dates.union(widenTypedRows).collect() + dates.except(widenTypedRows).collect() + dates.intersect(widenTypedRows).collect() + } - val e1 = intercept[AnalysisException] (sqlContext.read.json("/mnt/*/*-xyz.json").rdd). - getMessage() - assert(e1.startsWith("Path does not exist")) + test("SPARK-18070 binary operator should not consider nullability when comparing input types") { + val rows = Seq(Row(Seq(1), Seq(1))) + val schema = new StructType() + .add("array1", ArrayType(IntegerType)) + .add("array2", ArrayType(IntegerType, containsNull = false)) + val df = spark.createDataFrame(spark.sparkContext.makeRDD(rows), schema) + assert(df.filter($"array1" === $"array2").count() == 1) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala index 06584ec21e2f8..4296ec543e275 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala @@ -29,16 +29,6 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSQLContext with B import testImplicits._ - override def beforeEach(): Unit = { - super.beforeEach() - TimeZone.setDefault(TimeZone.getTimeZone("UTC")) - } - - override def afterEach(): Unit = { - super.beforeEach() - TimeZone.setDefault(null) - } - test("tumbling window groupBy statement") { val df = Seq( ("2016-03-27 19:39:34", 1, "a"), @@ -245,18 +235,18 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSQLContext with B Seq( ("2016-03-27 19:39:34", 1), ("2016-03-27 19:39:56", 2), - ("2016-03-27 19:39:27", 4)).toDF("time", "value").registerTempTable(tableName) + ("2016-03-27 19:39:27", 4)).toDF("time", "value").createOrReplaceTempView(tableName) try { f(tableName) } finally { - sqlContext.dropTempTable(tableName) + spark.catalog.dropTempView(tableName) } } test("time window in SQL with single string expression") { withTempTable { table => checkAnswer( - sqlContext.sql(s"""select window(time, "10 seconds"), value from $table""") + spark.sql(s"""select window(time, "10 seconds"), value from $table""") .select($"window.start".cast(StringType), $"window.end".cast(StringType), $"value"), Seq( Row("2016-03-27 19:39:20", "2016-03-27 19:39:30", 4), @@ -270,7 +260,7 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSQLContext with B test("time window in SQL with with two expressions") { withTempTable { table => checkAnswer( - sqlContext.sql( + spark.sql( s"""select window(time, "10 seconds", 10000000), value from $table""") .select($"window.start".cast(StringType), $"window.end".cast(StringType), $"value"), Seq( @@ -285,7 +275,7 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSQLContext with B test("time window in SQL with with three expressions") { withTempTable { table => checkAnswer( - sqlContext.sql( + spark.sql( s"""select window(time, "10 seconds", 10000000, "5 seconds"), value from $table""") .select($"window.start".cast(StringType), $"window.end".cast(StringType), $"value"), Seq( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala index 68e99d6a6b816..fe6ba83b4cbfb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala @@ -48,7 +48,7 @@ class DataFrameTungstenSuite extends QueryTest with SharedSQLContext { .add("b3", FloatType) .add("b4", DoubleType)) - val df = sqlContext.createDataFrame(data, schema) + val df = spark.createDataFrame(data, schema) assert(df.select("b").first() === Row(struct)) } @@ -70,7 +70,7 @@ class DataFrameTungstenSuite extends QueryTest with SharedSQLContext { .add("b5b", StringType)) .add("b6", StringType)) - val df = sqlContext.createDataFrame(data, schema) + val df = spark.createDataFrame(data, schema) assert(df.select("b").first() === Row(outerStruct)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala index 91095af0ddae7..c6f8c3ad3fc93 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala @@ -49,7 +49,7 @@ class DataFrameWindowSuite extends QueryTest with SharedSQLContext { test("lead") { val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") - df.registerTempTable("window_table") + df.createOrReplaceTempView("window_table") checkAnswer( df.select( @@ -59,7 +59,7 @@ class DataFrameWindowSuite extends QueryTest with SharedSQLContext { test("lag") { val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") - df.registerTempTable("window_table") + df.createOrReplaceTempView("window_table") checkAnswer( df.select( @@ -70,7 +70,7 @@ class DataFrameWindowSuite extends QueryTest with SharedSQLContext { test("lead with default value") { val df = Seq((1, "1"), (1, "1"), (2, "2"), (1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") - df.registerTempTable("window_table") + df.createOrReplaceTempView("window_table") checkAnswer( df.select( lead("value", 2, "n/a").over(Window.partitionBy("key").orderBy("value"))), @@ -80,7 +80,7 @@ class DataFrameWindowSuite extends QueryTest with SharedSQLContext { test("lag with default value") { val df = Seq((1, "1"), (1, "1"), (2, "2"), (1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") - df.registerTempTable("window_table") + df.createOrReplaceTempView("window_table") checkAnswer( df.select( lag("value", 2, "n/a").over(Window.partitionBy($"key").orderBy($"value"))), @@ -89,7 +89,7 @@ class DataFrameWindowSuite extends QueryTest with SharedSQLContext { test("rank functions in unspecific window") { val df = Seq((1, "1"), (2, "2"), (1, "2"), (2, "2")).toDF("key", "value") - df.registerTempTable("window_table") + df.createOrReplaceTempView("window_table") checkAnswer( df.select( $"key", @@ -110,9 +110,17 @@ class DataFrameWindowSuite extends QueryTest with SharedSQLContext { Row(2, 2, 1, 5.0d / 3.0d, 3, 5, 2, 3, 2, 2, 1.0d, 0.5d) :: Nil) } + test("window function should fail if order by clause is not specified") { + val df = Seq((1, "1"), (2, "2"), (1, "2"), (2, "2")).toDF("key", "value") + val e = intercept[AnalysisException]( + // Here we missed .orderBy("key")! 
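// --- Illustrative sketch, not part of this patch: the window-spec requirement checked by the
// --- "should fail if order by clause is not specified" test below. row_number(), lead() and
// --- lag() are order-sensitive, so their window must include an ORDER BY; a partition-only spec
// --- fails analysis with "requires window to be ordered". Assumes a SparkSession named `spark`.
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{lead, row_number}
import spark.implicits._

val df = Seq((1, "1"), (2, "2"), (1, "2"), (2, "2")).toDF("key", "value")

val ordered = Window.partitionBy("value").orderBy("key")
df.select($"key", $"value", row_number().over(ordered), lead("key", 1).over(ordered)).show()

// Missing orderBy: analysis fails for order-sensitive window functions.
// df.select(row_number().over(Window.partitionBy("value"))).collect()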
+ df.select(row_number().over(Window.partitionBy("value"))).collect()) + assert(e.message.contains("requires window to be ordered")) + } + test("aggregation and rows between") { val df = Seq((1, "1"), (2, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") - df.registerTempTable("window_table") + df.createOrReplaceTempView("window_table") checkAnswer( df.select( avg("key").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 2))), @@ -121,7 +129,7 @@ class DataFrameWindowSuite extends QueryTest with SharedSQLContext { test("aggregation and range between") { val df = Seq((1, "1"), (1, "1"), (3, "1"), (2, "2"), (2, "1"), (2, "2")).toDF("key", "value") - df.registerTempTable("window_table") + df.createOrReplaceTempView("window_table") checkAnswer( df.select( avg("key").over(Window.partitionBy($"value").orderBy($"key").rangeBetween(-1, 1))), @@ -131,7 +139,7 @@ class DataFrameWindowSuite extends QueryTest with SharedSQLContext { test("aggregation and rows between with unbounded") { val df = Seq((1, "1"), (2, "2"), (2, "3"), (1, "3"), (3, "2"), (4, "3")).toDF("key", "value") - df.registerTempTable("window_table") + df.createOrReplaceTempView("window_table") checkAnswer( df.select( $"key", @@ -146,7 +154,7 @@ class DataFrameWindowSuite extends QueryTest with SharedSQLContext { test("aggregation and range between with unbounded") { val df = Seq((5, "1"), (5, "2"), (4, "2"), (6, "2"), (3, "1"), (2, "2")).toDF("key", "value") - df.registerTempTable("window_table") + df.createOrReplaceTempView("window_table") checkAnswer( df.select( $"key", @@ -237,6 +245,18 @@ class DataFrameWindowSuite extends QueryTest with SharedSQLContext { Seq(Row("a", 6, 9), Row("b", 9, 6))) } + test("SPARK-16195 empty over spec") { + val df = Seq(("a", 1), ("a", 1), ("a", 2), ("b", 2)). 
+ toDF("key", "value") + df.createOrReplaceTempView("window_table") + checkAnswer( + df.select($"key", $"value", sum($"value").over(), avg($"value").over()), + Seq(Row("a", 1, 6, 1.5), Row("a", 1, 6, 1.5), Row("a", 2, 6, 1.5), Row("b", 2, 6, 1.5))) + checkAnswer( + sql("select key, value, sum(value) over(), avg(value) over() from window_table"), + Seq(Row("a", 1, 6, 1.5), Row("a", 1, 6, 1.5), Row("a", 2, 6, 1.5), Row("b", 2, 6, 1.5))) + } + test("window function with udaf") { val udaf = new UserDefinedAggregateFunction { def inputSchema: StructType = new StructType() @@ -357,7 +377,7 @@ class DataFrameWindowSuite extends QueryTest with SharedSQLContext { test("aggregation and rows between with unbounded + predicate pushdown") { val df = Seq((1, "1"), (2, "2"), (2, "3"), (1, "3"), (3, "2"), (4, "3")).toDF("key", "value") - df.registerTempTable("window_table") + df.createOrReplaceTempView("window_table") val selectList = Seq($"key", $"value", last("key").over( Window.partitionBy($"value").orderBy($"key").rowsBetween(0, Long.MaxValue)), @@ -372,7 +392,7 @@ class DataFrameWindowSuite extends QueryTest with SharedSQLContext { test("aggregation and range between with unbounded + predicate pushdown") { val df = Seq((5, "1"), (5, "2"), (4, "2"), (6, "2"), (3, "1"), (2, "2")).toDF("key", "value") - df.registerTempTable("window_table") + df.createOrReplaceTempView("window_table") val selectList = Seq($"key", $"value", last("value").over( Window.partitionBy($"value").orderBy($"key").rangeBetween(-2, -1)).equalTo("2") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index 6eae3ed7ad6c0..32fcf84b02f92 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -21,9 +21,10 @@ import scala.language.postfixOps import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.expressions.Aggregator -import org.apache.spark.sql.expressions.scala.typed +import org.apache.spark.sql.expressions.scalalang.typed import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.StringType object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, Long)] { @@ -52,6 +53,16 @@ object ClassInputAgg extends Aggregator[AggData, Int, Int] { } +object ClassBufferAggregator extends Aggregator[AggData, AggData, Int] { + override def zero: AggData = AggData(0, "") + override def reduce(b: AggData, a: AggData): AggData = AggData(b.a + a.a, "") + override def finish(reduction: AggData): Int = reduction.a + override def merge(b1: AggData, b2: AggData): AggData = AggData(b1.a + b2.a, "") + override def bufferEncoder: Encoder[AggData] = Encoders.product[AggData] + override def outputEncoder: Encoder[Int] = Encoders.scalaInt +} + + object ComplexBufferAgg extends Aggregator[AggData, (Int, AggData), Int] { override def zero: (Int, AggData) = 0 -> AggData(0, "0") override def reduce(b: (Int, AggData), a: AggData): (Int, AggData) = (b._1 + 1, a) @@ -63,6 +74,16 @@ object ComplexBufferAgg extends Aggregator[AggData, (Int, AggData), Int] { } +object MapTypeBufferAgg extends Aggregator[Int, Map[Int, Int], Int] { + override def zero: Map[Int, Int] = Map.empty + override def reduce(b: Map[Int, Int], a: Int): Map[Int, Int] = b + override def finish(reduction: Map[Int, Int]): Int = 1 + override def merge(b1: 
Map[Int, Int], b2: Map[Int, Int]): Map[Int, Int] = b1 + override def bufferEncoder: Encoder[Map[Int, Int]] = ExpressionEncoder() + override def outputEncoder: Encoder[Int] = ExpressionEncoder() +} + + object NameAgg extends Aggregator[AggData, String, String] { def zero: String = "" def reduce(b: String, a: AggData): String = a.b + b @@ -104,11 +125,23 @@ object RowAgg extends Aggregator[Row, Int, Int] { override def outputEncoder: Encoder[Int] = Encoders.scalaInt } +object NullResultAgg extends Aggregator[AggData, AggData, AggData] { + override def zero: AggData = AggData(0, "") + override def reduce(b: AggData, a: AggData): AggData = AggData(b.a + a.a, b.b + a.b) + override def finish(reduction: AggData): AggData = { + if (reduction.a % 2 == 0) null else reduction + } + override def merge(b1: AggData, b2: AggData): AggData = AggData(b1.a + b2.a, b1.b + b2.b) + override def bufferEncoder: Encoder[AggData] = Encoders.product[AggData] + override def outputEncoder: Encoder[AggData] = Encoders.product[AggData] +} -class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { +class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { import testImplicits._ + private implicit val ordering = Ordering.by((c: AggData) => c.a -> c.b) + test("typed aggregation: TypedAggregator") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() @@ -173,6 +206,14 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { ("one", 1)) } + test("Typed aggregation using aggregator") { + // based on Dataset complex Aggregator test of DatasetBenchmark + val ds = Seq(AggData(1, "x"), AggData(2, "y"), AggData(3, "z")).toDS() + checkDataset( + ds.select(ClassBufferAggregator.toColumn), + 6) + } + test("typed aggregation: complex input") { val ds = Seq(AggData(1, "one"), AggData(2, "two")).toDS() @@ -185,7 +226,7 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { ds.select(expr("avg(a)").as[Double], ComplexBufferAgg.toColumn), (1.5, 2)) - checkDataset( + checkDatasetUnorderly( ds.groupByKey(_.b).agg(ComplexBufferAgg.toColumn), ("one", 1), ("two", 1)) } @@ -232,4 +273,36 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { "a" -> Seq(1, 2) ) } + + test("spark-15051 alias of aggregator in DataFrame/Dataset[Row]") { + val df1 = Seq(1 -> "a", 2 -> "b", 3 -> "b").toDF("i", "j") + checkAnswer(df1.agg(RowAgg.toColumn as "b"), Row(6) :: Nil) + + val df2 = Seq(1 -> "a", 2 -> "b", 3 -> "b").toDF("i", "j") + checkAnswer(df2.agg(RowAgg.toColumn as "b").select("b"), Row(6) :: Nil) + } + + test("spark-15114 shorter system generated alias names") { + val ds = Seq(1, 3, 2, 5).toDS() + assert(ds.select(typed.sum((i: Int) => i)).columns.head === "TypedSumDouble(int)") + val ds2 = ds.select(typed.sum((i: Int) => i), typed.avg((i: Int) => i)) + assert(ds2.columns.head === "TypedSumDouble(int)") + assert(ds2.columns.last === "TypedAverage(int)") + val df = Seq(1 -> "a", 2 -> "b", 3 -> "b").toDF("i", "j") + assert(df.groupBy($"j").agg(RowAgg.toColumn).columns.last == + "RowAgg(org.apache.spark.sql.Row)") + assert(df.groupBy($"j").agg(RowAgg.toColumn as "agg1").columns.last == "agg1") + } + + test("SPARK-15814 Aggregator can return null result") { + val ds = Seq(AggData(1, "one"), AggData(2, "two")).toDS() + checkDatasetUnorderly( + ds.groupByKey(_.a).agg(NullResultAgg.toColumn), + 1 -> AggData(1, "one"), 2 -> null) + } + + test("SPARK-16100: use Map as the buffer type of Aggregator") { + val ds = Seq(1, 2, 3).toDS() + 
checkDataset(ds.select(MapTypeBufferAgg.toColumn), 1) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala index ae9fb80c68f42..4101e5c75b937 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql -import org.apache.spark.SparkContext +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.sql.expressions.Aggregator -import org.apache.spark.sql.expressions.scala.typed +import org.apache.spark.sql.expressions.scalalang.typed import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.StringType import org.apache.spark.util.Benchmark @@ -31,14 +31,14 @@ object DatasetBenchmark { case class Data(l: Long, s: String) - def backToBackMap(sqlContext: SQLContext, numRows: Long, numChains: Int): Benchmark = { - import sqlContext.implicits._ + def backToBackMap(spark: SparkSession, numRows: Long, numChains: Int): Benchmark = { + import spark.implicits._ - val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s")) + val df = spark.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s")) val benchmark = new Benchmark("back-to-back map", numRows) val func = (d: Data) => Data(d.l + 1, d.s) - val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString)) + val rdd = spark.sparkContext.range(1, numRows).map(l => Data(l, l.toString)) benchmark.addCase("RDD") { iter => var res = rdd var i = 0 @@ -72,17 +72,17 @@ object DatasetBenchmark { benchmark } - def backToBackFilter(sqlContext: SQLContext, numRows: Long, numChains: Int): Benchmark = { - import sqlContext.implicits._ + def backToBackFilter(spark: SparkSession, numRows: Long, numChains: Int): Benchmark = { + import spark.implicits._ - val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s")) + val df = spark.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s")) val benchmark = new Benchmark("back-to-back filter", numRows) val func = (d: Data, i: Int) => d.l % (100L + i) == 0L val funcs = 0.until(numChains).map { i => (d: Data) => func(d, i) } - val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString)) + val rdd = spark.sparkContext.range(1, numRows).map(l => Data(l, l.toString)) benchmark.addCase("RDD") { iter => var res = rdd var i = 0 @@ -130,13 +130,13 @@ object DatasetBenchmark { override def outputEncoder: Encoder[Long] = Encoders.scalaLong } - def aggregate(sqlContext: SQLContext, numRows: Long): Benchmark = { - import sqlContext.implicits._ + def aggregate(spark: SparkSession, numRows: Long): Benchmark = { + import spark.implicits._ - val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s")) + val df = spark.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s")) val benchmark = new Benchmark("aggregate", numRows) - val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString)) + val rdd = spark.sparkContext.range(1, numRows).map(l => Data(l, l.toString)) benchmark.addCase("RDD sum") { iter => rdd.aggregate(0L)(_ + _.l, _ + _) } @@ -157,15 +157,17 @@ object DatasetBenchmark { } def main(args: Array[String]): Unit = { - val sparkContext = new SparkContext("local[*]", "Dataset benchmark") - val sqlContext = new SQLContext(sparkContext) + val spark = 
SparkSession.builder + .master("local[*]") + .appName("Dataset benchmark") + .getOrCreate() val numRows = 100000000 val numChains = 10 - val benchmark = backToBackMap(sqlContext, numRows, numChains) - val benchmark2 = backToBackFilter(sqlContext, numRows, numChains) - val benchmark3 = aggregate(sqlContext, numRows) + val benchmark = backToBackMap(spark, numRows, numChains) + val benchmark2 = backToBackFilter(spark, numRows, numChains) + val benchmark3 = aggregate(spark, numRows) /* Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.11.4 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala index 942cc09b6d58e..ac9f6c2f38537 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala @@ -39,7 +39,8 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext { 2, 3, 4) // Drop the cache. cached.unpersist() - assert(!sqlContext.isCached(cached), "The Dataset should not be cached.") + assert(spark.sharedState.cacheManager.lookupCachedData(cached).isEmpty, + "The Dataset should not be cached.") } test("persist and then rebind right encoder when join 2 datasets") { @@ -56,9 +57,11 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext { assertCached(joined, 2) ds1.unpersist() - assert(!sqlContext.isCached(ds1), "The Dataset ds1 should not be cached.") + assert(spark.sharedState.cacheManager.lookupCachedData(ds1).isEmpty, + "The Dataset ds1 should not be cached.") ds2.unpersist() - assert(!sqlContext.isCached(ds2), "The Dataset ds2 should not be cached.") + assert(spark.sharedState.cacheManager.lookupCachedData(ds2).isEmpty, + "The Dataset ds2 should not be cached.") } test("persist and then groupBy columns asKey, map") { @@ -73,8 +76,10 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext { assertCached(agged.filter(_._1 == "b")) ds.unpersist() - assert(!sqlContext.isCached(ds), "The Dataset ds should not be cached.") + assert(spark.sharedState.cacheManager.lookupCachedData(ds).isEmpty, + "The Dataset ds should not be cached.") agged.unpersist() - assert(!sqlContext.isCached(agged), "The Dataset agged should not be cached.") + assert(spark.sharedState.cacheManager.lookupCachedData(agged).isEmpty, + "The Dataset agged should not be cached.") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index ff022b2dc45ee..6aa3d3fe808b2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -62,15 +62,15 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { test("foreach") { val ds = Seq(1, 2, 3).toDS() - val acc = sparkContext.accumulator(0) - ds.foreach(acc += _) + val acc = sparkContext.longAccumulator + ds.foreach(acc.add(_)) assert(acc.value == 6) } test("foreachPartition") { val ds = Seq(1, 2, 3).toDS() - val acc = sparkContext.accumulator(0) - ds.foreachPartition(_.foreach(acc +=)) + val acc = sparkContext.longAccumulator + ds.foreachPartition(_.foreach(acc.add(_))) assert(acc.value == 6) } @@ -82,7 +82,7 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { test("groupBy function, keys") { val ds = Seq(1, 2, 3, 4, 5).toDS() val grouped = ds.groupByKey(_ % 2) - checkDataset( + checkDatasetUnorderly( 
grouped.keys, 0, 1) } @@ -95,7 +95,7 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { (name, iter.size) } - checkDataset( + checkDatasetUnorderly( agged, ("even", 5), ("odd", 6)) } @@ -105,7 +105,7 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { val grouped = ds.groupByKey(_.length) val agged = grouped.flatMapGroups { case (g, iter) => Iterator(g.toString, iter.mkString) } - checkDataset( + checkDatasetUnorderly( agged, "1", "abc", "3", "xyz", "5", "hello") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSerializerRegistratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSerializerRegistratorSuite.scala new file mode 100644 index 0000000000000..0f3d0cefe3bb5 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSerializerRegistratorSuite.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import com.esotericsoftware.kryo.{Kryo, Serializer} +import com.esotericsoftware.kryo.io.{Input, Output} + +import org.apache.spark.serializer.KryoRegistrator +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.TestSparkSession + +/** + * Test suite to test Kryo custom registrators. + */ +class DatasetSerializerRegistratorSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + /** + * Initialize the [[TestSparkSession]] with a [[KryoRegistrator]]. + */ + protected override def beforeAll(): Unit = { + sparkConf.set("spark.kryo.registrator", TestRegistrator().getClass.getCanonicalName) + super.beforeAll() + } + + test("Kryo registrator") { + implicit val kryoEncoder = Encoders.kryo[KryoData] + val ds = Seq(KryoData(1), KryoData(2)).toDS() + assert(ds.collect().toSet == Set(KryoData(0), KryoData(0))) + } + +} + +/** Used to test user provided registrator. */ +class TestRegistrator extends KryoRegistrator { + override def registerClasses(kryo: Kryo): Unit = + kryo.register(classOf[KryoData], new ZeroKryoDataSerializer()) +} + +object TestRegistrator { + def apply(): TestRegistrator = new TestRegistrator() +} + +/** A [[Serializer]] that takes a [[KryoData]] and serializes it as KryoData(0). 
*/ +class ZeroKryoDataSerializer extends Serializer[KryoData] { + override def write(kryo: Kryo, output: Output, t: KryoData): Unit = { + output.writeInt(0) + } + + override def read(kryo: Kryo, input: Input, aClass: Class[KryoData]): KryoData = { + KryoData(input.readInt()) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index eee21acf7510b..f897cfb26d3ff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -22,7 +22,8 @@ import java.sql.{Date, Timestamp} import scala.language.postfixOps -import org.apache.spark.sql.catalyst.encoders.OuterScopes +import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder} +import org.apache.spark.sql.catalyst.util.sideBySide import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext @@ -31,6 +32,15 @@ import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructT class DatasetSuite extends QueryTest with SharedSQLContext { import testImplicits._ + private implicit val ordering = Ordering.by((c: ClassData) => c.a -> c.b) + + test("checkAnswer should compare map correctly") { + val data = Seq((1, "2", Map(1 -> 2, 2 -> 1))) + checkAnswer( + data.toDF(), + Seq(Row(1, "2", Map(2 -> 1, 1 -> 2)))) + } + test("toDS") { val data = Seq(("a", 1), ("b", 2), ("c", 3)) checkDataset( @@ -45,13 +55,19 @@ class DatasetSuite extends QueryTest with SharedSQLContext { 1, 1, 1) } + test("emptyDataset") { + val ds = spark.emptyDataset[Int] + assert(ds.count() == 0L) + assert(ds.collect() sameElements Array.empty[Int]) + } + test("range") { - assert(sqlContext.range(10).map(_ + 1).reduce(_ + _) == 55) - assert(sqlContext.range(10).map{ case i: java.lang.Long => i + 1 }.reduce(_ + _) == 55) - assert(sqlContext.range(0, 10).map(_ + 1).reduce(_ + _) == 55) - assert(sqlContext.range(0, 10).map{ case i: java.lang.Long => i + 1 }.reduce(_ + _) == 55) - assert(sqlContext.range(0, 10, 1, 2).map(_ + 1).reduce(_ + _) == 55) - assert(sqlContext.range(0, 10, 1, 2).map{ case i: java.lang.Long => i + 1 }.reduce(_ + _) == 55) + assert(spark.range(10).map(_ + 1).reduce(_ + _) == 55) + assert(spark.range(10).map{ case i: java.lang.Long => i + 1 }.reduce(_ + _) == 55) + assert(spark.range(0, 10).map(_ + 1).reduce(_ + _) == 55) + assert(spark.range(0, 10).map{ case i: java.lang.Long => i + 1 }.reduce(_ + _) == 55) + assert(spark.range(0, 10, 1, 2).map(_ + 1).reduce(_ + _) == 55) + assert(spark.range(0, 10, 1, 2).map{ case i: java.lang.Long => i + 1 }.reduce(_ + _) == 55) } test("SPARK-12404: Datatype Helper Serializability") { @@ -79,13 +95,21 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val data = (1 to 100).map(i => ClassData(i.toString, i)) val ds = data.toDS() + intercept[IllegalArgumentException] { + ds.coalesce(0) + } + + intercept[IllegalArgumentException] { + ds.repartition(0) + } + assert(ds.repartition(10).rdd.partitions.length == 10) - checkDataset( + checkDatasetUnorderly( ds.repartition(10), data: _*) assert(ds.coalesce(1).rdd.partitions.length == 1) - checkDataset( + checkDatasetUnorderly( ds.coalesce(1), data: _*) } @@ -148,7 +172,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { .map(c => ClassData(c.a, c.b + 1)) .groupByKey(p => p).count() - checkDataset( + checkDatasetUnorderly( ds, (ClassData("one", 2), 1L), (ClassData("two", 
3), 1L)) } @@ -189,7 +213,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("select 2, primitive and class, fields reordered") { val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() - checkDecoding( + checkDataset( ds.select( expr("_1").as[String], expr("named_struct('b', _2, 'a', _1)").as[ClassData]), @@ -203,17 +227,30 @@ class DatasetSuite extends QueryTest with SharedSQLContext { ("b", 2)) } + test("filter and then select") { + val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() + checkDataset( + ds.filter(_._1 == "b").select(expr("_1").as[String]), + "b") + } + + test("SPARK-15632: typed filter should preserve the underlying logical schema") { + val ds = spark.range(10) + val ds2 = ds.filter(_ > 3) + assert(ds.schema.equals(ds2.schema)) + } + test("foreach") { val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() - val acc = sparkContext.accumulator(0) - ds.foreach(v => acc += v._2) + val acc = sparkContext.longAccumulator + ds.foreach(v => acc.add(v._2)) assert(acc.value == 6) } test("foreachPartition") { val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() - val acc = sparkContext.accumulator(0) - ds.foreachPartition(_.foreach(v => acc += v._2)) + val acc = sparkContext.longAccumulator + ds.foreachPartition(_.foreach(v => acc.add(v._2))) assert(acc.value == 6) } @@ -231,21 +268,6 @@ class DatasetSuite extends QueryTest with SharedSQLContext { (1, 1), (2, 2)) } - test("joinWith, expression condition, outer join") { - val nullInteger = null.asInstanceOf[Integer] - val nullString = null.asInstanceOf[String] - val ds1 = Seq(ClassNullableData("a", 1), - ClassNullableData("c", 3)).toDS() - val ds2 = Seq(("a", new Integer(1)), - ("b", new Integer(2))).toDS() - - checkDataset( - ds1.joinWith(ds2, $"_1" === $"a", "outer"), - (ClassNullableData("a", 1), ("a", new Integer(1))), - (ClassNullableData("c", 3), (nullString, nullInteger)), - (ClassNullableData(nullString, nullInteger), ("b", new Integer(2)))) - } - test("joinWith tuple with primitive, expression") { val ds1 = Seq(1, 1, 2).toDS() val ds2 = Seq(("a", 1), ("b", 2)).toDS() @@ -278,7 +300,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("groupBy function, keys") { val ds = Seq(("a", 1), ("b", 1)).toDS() val grouped = ds.groupByKey(v => (1, v._2)) - checkDataset( + checkDatasetUnorderly( grouped.keys, (1, 1)) } @@ -288,7 +310,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val grouped = ds.groupByKey(v => (v._1, "word")) val agged = grouped.mapGroups { case (g, iter) => (g._1, iter.map(_._2).sum) } - checkDataset( + checkDatasetUnorderly( agged, ("a", 30), ("b", 3), ("c", 1)) } @@ -300,7 +322,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { Iterator(g._1, iter.map(_._2).sum.toString) } - checkDataset( + checkDatasetUnorderly( agged, "a", "30", "b", "3", "c", "1") } @@ -309,7 +331,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val ds = Seq("abc", "xyz", "hello").toDS() val agged = ds.groupByKey(_.length).reduceGroups(_ + _) - checkDataset( + checkDatasetUnorderly( agged, 3 -> "abcxyz", 5 -> "hello") } @@ -327,7 +349,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("typed aggregation: expr") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() - checkDataset( + checkDatasetUnorderly( ds.groupByKey(_._1).agg(sum("_2").as[Long]), ("a", 30L), ("b", 3L), ("c", 1L)) } @@ -335,7 +357,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("typed aggregation: expr, expr") { val ds = Seq(("a", 10), ("a", 
20), ("b", 1), ("b", 2), ("c", 1)).toDS() - checkDataset( + checkDatasetUnorderly( ds.groupByKey(_._1).agg(sum("_2").as[Long], sum($"_2" + 1).as[Long]), ("a", 30L, 32L), ("b", 3L, 5L), ("c", 1L, 2L)) } @@ -343,7 +365,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("typed aggregation: expr, expr, expr") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() - checkDataset( + checkDatasetUnorderly( ds.groupByKey(_._1).agg(sum("_2").as[Long], sum($"_2" + 1).as[Long], count("*")), ("a", 30L, 32L, 2L), ("b", 3L, 5L, 2L), ("c", 1L, 2L, 1L)) } @@ -351,7 +373,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("typed aggregation: expr, expr, expr, expr") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() - checkDataset( + checkDatasetUnorderly( ds.groupByKey(_._1).agg( sum("_2").as[Long], sum($"_2" + 1).as[Long], @@ -367,7 +389,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { Iterator(key -> (data1.map(_._2).mkString + "#" + data2.map(_._2).mkString)) } - checkDataset( + checkDatasetUnorderly( cogrouped, 1 -> "a#", 2 -> "#q", 3 -> "abcfoo#w", 5 -> "hello#er") } @@ -379,7 +401,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { Iterator(key -> (data1.map(_._2.a).mkString + data2.map(_._2.a).mkString)) } - checkDataset( + checkDatasetUnorderly( cogrouped, 1 -> "a", 2 -> "bc", 3 -> "d") } @@ -400,6 +422,31 @@ class DatasetSuite extends QueryTest with SharedSQLContext { 3, 17, 27, 58, 62) } + test("SPARK-16686: Dataset.sample with seed results shouldn't depend on downstream usage") { + val simpleUdf = udf((n: Int) => { + require(n != 1, "simpleUdf shouldn't see id=1!") + 1 + }) + + val df = Seq( + (0, "string0"), + (1, "string1"), + (2, "string2"), + (3, "string3"), + (4, "string4"), + (5, "string5"), + (6, "string6"), + (7, "string7"), + (8, "string8"), + (9, "string9") + ).toDF("id", "stringData") + val sampleDF = df.sample(false, 0.7, 50) + // After sampling, sampleDF doesn't contain id=1. + assert(!sampleDF.select("id").collect.contains(1)) + // simpleUdf should not encounter id=1. 
+ checkAnswer(sampleDF.select(simpleUdf($"id")), List.fill(sampleDF.count.toInt)(Row(1))) + } + test("SPARK-11436: we should rebind right encoder when join 2 datasets") { val ds1 = Seq("1", "2").toDS().as("a") val ds2 = Seq(2, 3).toDS().as("b") @@ -419,20 +466,6 @@ class DatasetSuite extends QueryTest with SharedSQLContext { assert(ds.toString == "[_1: int, _2: int]") } - test("showString: Kryo encoder") { - implicit val kryoEncoder = Encoders.kryo[KryoData] - val ds = Seq(KryoData(1), KryoData(2)).toDS() - - val expectedAnswer = """+-----------+ - || value| - |+-----------+ - ||KryoData(1)| - ||KryoData(2)| - |+-----------+ - |""".stripMargin - assert(ds.showString(10) === expectedAnswer) - } - test("Kryo encoder") { implicit val kryoEncoder = Encoders.kryo[KryoData] val ds = Seq(KryoData(1), KryoData(2)).toDS() @@ -452,6 +485,15 @@ class DatasetSuite extends QueryTest with SharedSQLContext { (KryoData(2), KryoData(2)))) } + test("Kryo encoder: check the schema mismatch when converting DataFrame to Dataset") { + implicit val kryoEncoder = Encoders.kryo[KryoData] + val df = Seq((1)).toDF("a") + val e = intercept[AnalysisException] { + df.as[KryoData] + }.message + assert(e.contains("cannot cast IntegerType to BinaryType")) + } + test("Java encoder") { implicit val kryoEncoder = Encoders.javaSerialization[JavaData] val ds = Seq(JavaData(1), JavaData(2)).toDS() @@ -472,7 +514,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { } test("SPARK-14696: implicit encoders for boxed types") { - assert(sqlContext.range(1).map { i => i : java.lang.Long }.head == 0L) + assert(spark.range(1).map { i => i : java.lang.Long }.head == 0L) } test("SPARK-11894: Incorrect results are returned when using null") { @@ -483,8 +525,8 @@ class DatasetSuite extends QueryTest with SharedSQLContext { checkDataset( ds1.joinWith(ds2, lit(true)), ((nullInt, "1"), (nullInt, "1")), - ((new java.lang.Integer(22), "2"), (nullInt, "1")), ((nullInt, "1"), (new java.lang.Integer(22), "2")), + ((new java.lang.Integer(22), "2"), (nullInt, "1")), ((new java.lang.Integer(22), "2"), (new java.lang.Integer(22), "2"))) } @@ -505,13 +547,13 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val schema = StructType(Seq( StructField("f", StructType(Seq( StructField("a", StringType, nullable = true), - StructField("b", IntegerType, nullable = false) + StructField("b", IntegerType, nullable = true) )), nullable = true) )) def buildDataset(rows: Row*): Dataset[NestedStruct] = { - val rowRDD = sqlContext.sparkContext.parallelize(rows) - sqlContext.createDataFrame(rowRDD, schema).as[NestedStruct] + val rowRDD = spark.sparkContext.parallelize(rows) + spark.createDataFrame(rowRDD, schema).as[NestedStruct] } checkDataset( @@ -573,18 +615,14 @@ class DatasetSuite extends QueryTest with SharedSQLContext { }.message assert(message == "Try to map struct to Tuple3, " + - "but failed as the number of fields does not line up.\n" + - " - Input schema: struct\n" + - " - Target schema: struct<_1:string,_2:int,_3:bigint>") + "but failed as the number of fields does not line up.") val message2 = intercept[AnalysisException] { ds.as[Tuple1[String]] }.message assert(message2 == "Try to map struct to Tuple1, " + - "but failed as the number of fields does not line up.\n" + - " - Input schema: struct\n" + - " - Target schema: struct<_1:string>") + "but failed as the number of fields does not line up.") } test("SPARK-13440: Resolving option fields") { @@ -626,7 +664,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { } 
test("SPARK-14554: Dataset.map may generate wrong java code for wide table") { - val wideDF = sqlContext.range(10).select(Seq.tabulate(1000) {i => ('id + i).as(s"c$i")} : _*) + val wideDF = spark.range(10).select(Seq.tabulate(1000) {i => ('id + i).as(s"c$i")} : _*) // Make sure the generated code for this plan can compile and execute. checkDataset(wideDF.map(_.getLong(0)), 0L until 10 : _*) } @@ -653,8 +691,201 @@ class DatasetSuite extends QueryTest with SharedSQLContext { dataset.join(actual, dataset("user") === actual("id")).collect() } + + test("SPARK-15097: implicits on dataset's spark can be imported") { + val dataset = Seq(1, 2, 3).toDS() + checkDataset(DatasetTransform.addOne(dataset), 2, 3, 4) + } + + test("dataset.rdd with generic case class") { + val ds = Seq(Generic(1, 1.0), Generic(2, 2.0)).toDS() + val ds2 = ds.map(g => Generic(g.id, g.value)) + assert(ds.rdd.map(r => r.id).count === 2) + assert(ds2.rdd.map(r => r.id).count === 2) + + val ds3 = ds.map(g => new java.lang.Long(g.id)) + assert(ds3.rdd.map(r => r).count === 2) + } + + test("runtime null check for RowEncoder") { + val schema = new StructType().add("i", IntegerType, nullable = false) + val df = spark.range(10).map(l => { + if (l % 5 == 0) { + Row(null) + } else { + Row(l) + } + })(RowEncoder(schema)) + + val message = intercept[Exception] { + df.collect() + }.getMessage + assert(message.contains("The 0th field 'i' of input row cannot be null")) + } + + test("row nullability mismatch") { + val schema = new StructType().add("a", StringType, true).add("b", StringType, false) + val rdd = spark.sparkContext.parallelize(Row(null, "123") :: Row("234", null) :: Nil) + val message = intercept[Exception] { + spark.createDataFrame(rdd, schema).collect() + }.getMessage + assert(message.contains("The 1th field 'b' of input row cannot be null")) + } + + test("createTempView") { + val dataset = Seq(1, 2, 3).toDS() + dataset.createOrReplaceTempView("tempView") + + // Overrides the existing temporary view with same name + // No exception should be thrown here. 
+ dataset.createOrReplaceTempView("tempView") + + // Throws AnalysisException if temp view with same name already exists + val e = intercept[AnalysisException]( + dataset.createTempView("tempView")) + intercept[AnalysisException](dataset.createTempView("tempView")) + assert(e.message.contains("already exists")) + dataset.sparkSession.catalog.dropTempView("tempView") + } + + test("SPARK-15381: physical object operator should define `reference` correctly") { + val df = Seq(1 -> 2).toDF("a", "b") + checkAnswer(df.map(row => row)(RowEncoder(df.schema)).select("b", "a"), Row(2, 1)) + } + + private def checkShowString[T](ds: Dataset[T], expected: String): Unit = { + val numRows = expected.split("\n").length - 4 + val actual = ds.showString(numRows, truncate = true) + + if (expected != actual) { + fail( + "Dataset.showString() gives wrong result:\n\n" + sideBySide( + "== Expected ==\n" + expected, + "== Actual ==\n" + actual + ).mkString("\n") + ) + } + } + + test("SPARK-15550 Dataset.show() should show contents of the underlying logical plan") { + val df = Seq((1, "foo", "extra"), (2, "bar", "extra")).toDF("b", "a", "c") + val ds = df.as[ClassData] + val expected = + """+---+---+-----+ + || b| a| c| + |+---+---+-----+ + || 1|foo|extra| + || 2|bar|extra| + |+---+---+-----+ + |""".stripMargin + + checkShowString(ds, expected) + } + + test("SPARK-15550 Dataset.show() should show inner nested products as rows") { + val ds = Seq( + NestedStruct(ClassData("foo", 1)), + NestedStruct(ClassData("bar", 2)) + ).toDS() + + val expected = + """+-------+ + || f| + |+-------+ + ||[foo,1]| + ||[bar,2]| + |+-------+ + |""".stripMargin + + checkShowString(ds, expected) + } + + test( + "SPARK-15112: EmbedDeserializerInFilter should not optimize plan fragment that changes schema" + ) { + val ds = Seq(1 -> "foo", 2 -> "bar").toDF("b", "a").as[ClassData] + + assertResult(Seq(ClassData("foo", 1), ClassData("bar", 2))) { + ds.collect().toSeq + } + + assertResult(Seq(ClassData("bar", 2))) { + ds.filter(_.b > 1).collect().toSeq + } + } + + test("mapped dataset should resolve duplicated attributes for self join") { + val ds = Seq(1, 2, 3).toDS().map(_ + 1) + val ds1 = ds.as("d1") + val ds2 = ds.as("d2") + + checkDatasetUnorderly(ds1.joinWith(ds2, $"d1.value" === $"d2.value"), (2, 2), (3, 3), (4, 4)) + checkDatasetUnorderly(ds1.intersect(ds2), 2, 3, 4) + checkDatasetUnorderly(ds1.except(ds1)) + } + + test("SPARK-15441: Dataset outer join") { + val left = Seq(ClassData("a", 1), ClassData("b", 2)).toDS().as("left") + val right = Seq(ClassData("x", 2), ClassData("y", 3)).toDS().as("right") + val joined = left.joinWith(right, $"left.b" === $"right.b", "left") + val result = joined.collect().toSet + assert(result == Set(ClassData("a", 1) -> null, ClassData("b", 2) -> ClassData("x", 2))) + } + + test("better error message when use java reserved keyword as field name") { + val e = intercept[UnsupportedOperationException] { + Seq(InvalidInJava(1)).toDS() + } + assert(e.getMessage.contains( + "`abstract` is a reserved keyword and cannot be used as field name")) + } + + test("Dataset should support flat input object to be null") { + checkDataset(Seq("a", null).toDS(), "a", null) + } + + test("Dataset should throw RuntimeException if non-flat input object is null") { + val e = intercept[RuntimeException](Seq(ClassData("a", 1), null).toDS()) + assert(e.getMessage.contains("Null value appeared in non-nullable field")) + assert(e.getMessage.contains("top level non-flat input object")) + } + + test("dropDuplicates") { + val ds = 
Seq(("a", 1), ("a", 2), ("b", 1), ("a", 1)).toDS() + checkDataset( + ds.dropDuplicates("_1"), + ("a", 1), ("b", 1)) + checkDataset( + ds.dropDuplicates("_2"), + ("a", 1), ("a", 2)) + checkDataset( + ds.dropDuplicates("_1", "_2"), + ("a", 1), ("a", 2), ("b", 1)) + } + + test("SPARK-16097: Encoders.tuple should handle null object correctly") { + val enc = Encoders.tuple(Encoders.tuple(Encoders.STRING, Encoders.STRING), Encoders.STRING) + val data = Seq((("a", "b"), "c"), (null, "d")) + val ds = spark.createDataset(data)(enc) + checkDataset(ds, (("a", "b"), "c"), (null, "d")) + } + + test("SPARK-16995: flat mapping on Dataset containing a column created with lit/expr") { + val df = Seq("1").toDF("a") + + import df.sparkSession.implicits._ + + checkDataset( + df.withColumn("b", lit(0)).as[ClassData] + .groupByKey(_.a).flatMapGroups { case (x, iter) => List[Int]() }) + checkDataset( + df.withColumn("b", expr("0")).as[ClassData] + .groupByKey(_.a).flatMapGroups { case (x, iter) => List[Int]() }) + } } +case class Generic[T](id: T, value: Double) + case class OtherTuple(_1: String, _2: Int) case class TupleClass(data: (Int, String)) @@ -674,6 +905,8 @@ case class ClassNullableData(a: String, b: Integer) case class NestedStruct(f: ClassData) case class DeepNestedStruct(f: NestedStruct) +case class InvalidInJava(`abstract`: Int) + /** * A class used to test serialization using encoders. This class throws exceptions when using * Java serialization -- so the only way it can be "serialized" is through our encoders. @@ -713,3 +946,11 @@ class JavaData(val a: Int) extends Serializable { object JavaData { def apply(a: Int): JavaData = new JavaData(a) } + +/** Used to test importing dataset.spark.implicits._ */ +object DatasetTransform { + def addOne(ds: Dataset[Int]): Dataset[Int] = { + import ds.sparkSession.implicits._ + ds.map(_ + 1) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala index b1987c690811d..a41b465548622 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala @@ -51,7 +51,7 @@ class ExtraStrategiesSuite extends QueryTest with SharedSQLContext { test("insert an extraStrategy") { try { - sqlContext.experimental.extraStrategies = TestStrategy :: Nil + spark.experimental.extraStrategies = TestStrategy :: Nil val df = sparkContext.parallelize(Seq(("so slow", 1))).toDF("a", "b") checkAnswer( @@ -62,7 +62,7 @@ class ExtraStrategiesSuite extends QueryTest with SharedSQLContext { df.select("a", "b"), Row("so slow", 1)) } finally { - sqlContext.experimental.extraStrategies = Nil + spark.experimental.extraStrategies = Nil } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala new file mode 100644 index 0000000000000..aedc0a8d6f70b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext + +class GeneratorFunctionSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("stack") { + val df = spark.range(1) + + // Empty DataFrame suppress the result generation + checkAnswer(spark.emptyDataFrame.selectExpr("stack(1, 1, 2, 3)"), Nil) + + // Rows & columns + checkAnswer(df.selectExpr("stack(1, 1, 2, 3)"), Row(1, 2, 3) :: Nil) + checkAnswer(df.selectExpr("stack(2, 1, 2, 3)"), Row(1, 2) :: Row(3, null) :: Nil) + checkAnswer(df.selectExpr("stack(3, 1, 2, 3)"), Row(1) :: Row(2) :: Row(3) :: Nil) + checkAnswer(df.selectExpr("stack(4, 1, 2, 3)"), Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil) + + // Various column types + checkAnswer(df.selectExpr("stack(3, 1, 1.1, 'a', 2, 2.2, 'b', 3, 3.3, 'c')"), + Row(1, 1.1, "a") :: Row(2, 2.2, "b") :: Row(3, 3.3, "c") :: Nil) + + // Repeat generation at every input row + checkAnswer(spark.range(2).selectExpr("stack(2, 1, 2, 3)"), + Row(1, 2) :: Row(3, null) :: Row(1, 2) :: Row(3, null) :: Nil) + + // The first argument must be a positive constant integer. + val m = intercept[AnalysisException] { + df.selectExpr("stack(1.1, 1, 2, 3)") + }.getMessage + assert(m.contains("The number of rows must be a positive constant integer.")) + val m2 = intercept[AnalysisException] { + df.selectExpr("stack(-1, 1, 2, 3)") + }.getMessage + assert(m2.contains("The number of rows must be a positive constant integer.")) + + // The data for the same column should have the same type. 
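(Editorial aside, not part of the patch: the hunks above exercise the new `stack` generator. `stack(n, v1, ..., vk)` lays the k values out row-major into n rows, padding the last row with nulls when k is not a multiple of n; a small sketch, assuming a local `SparkSession` named `spark`:)

```scala
// Illustrative sketch only; assumes a local SparkSession named `spark`.
val df = spark.range(1)

// Three values folded into two rows of two columns; the missing slot becomes null.
df.selectExpr("stack(2, 1, 2, 3)").collect()
// => rows (1, 2) and (3, null)

// One row per value when n equals the number of values.
df.selectExpr("stack(3, 1, 2, 3)").collect()
// => rows (1), (2) and (3)
```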
+ val m3 = intercept[AnalysisException] { + df.selectExpr("stack(2, 1, '2.2')") + }.getMessage + assert(m3.contains("data type mismatch: Argument 1 (IntegerType) != Argument 2 (StringType)")) + + // stack on column data + val df2 = Seq((2, 1, 2, 3)).toDF("n", "a", "b", "c") + checkAnswer(df2.selectExpr("stack(2, a, b, c)"), Row(1, 2) :: Row(3, null) :: Nil) + + val m4 = intercept[AnalysisException] { + df2.selectExpr("stack(n, a, b, c)") + }.getMessage + assert(m4.contains("The number of rows must be a positive constant integer.")) + + val df3 = Seq((2, 1, 2.0)).toDF("n", "a", "b") + val m5 = intercept[AnalysisException] { + df3.selectExpr("stack(2, a, b)") + }.getMessage + assert(m5.contains("data type mismatch: Argument 1 (IntegerType) != Argument 2 (DoubleType)")) + + } + + test("single explode") { + val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") + checkAnswer( + df.select(explode('intList)), + Row(1) :: Row(2) :: Row(3) :: Nil) + } + + test("single posexplode") { + val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") + checkAnswer( + df.select(posexplode('intList)), + Row(0, 1) :: Row(1, 2) :: Row(2, 3) :: Nil) + } + + test("explode and other columns") { + val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") + + checkAnswer( + df.select($"a", explode('intList)), + Row(1, 1) :: + Row(1, 2) :: + Row(1, 3) :: Nil) + + checkAnswer( + df.select($"*", explode('intList)), + Row(1, Seq(1, 2, 3), 1) :: + Row(1, Seq(1, 2, 3), 2) :: + Row(1, Seq(1, 2, 3), 3) :: Nil) + } + + test("aliased explode") { + val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") + + checkAnswer( + df.select(explode('intList).as('int)).select('int), + Row(1) :: Row(2) :: Row(3) :: Nil) + + checkAnswer( + df.select(explode('intList).as('int)).select(sum('int)), + Row(6) :: Nil) + } + + test("explode on map") { + val df = Seq((1, Map("a" -> "b"))).toDF("a", "map") + + checkAnswer( + df.select(explode('map)), + Row("a", "b")) + } + + test("explode on map with aliases") { + val df = Seq((1, Map("a" -> "b"))).toDF("a", "map") + + checkAnswer( + df.select(explode('map).as("key1" :: "value1" :: Nil)).select("key1", "value1"), + Row("a", "b")) + } + + test("self join explode") { + val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") + val exploded = df.select(explode('intList).as('i)) + + checkAnswer( + exploded.join(exploded, exploded("i") === exploded("i")).agg(count("*")), + Row(3) :: Nil) + } + + test("inline raises exception on array of null type") { + val m = intercept[AnalysisException] { + spark.range(2).selectExpr("inline(array())") + }.getMessage + assert(m.contains("data type mismatch")) + } + + test("inline with empty table") { + checkAnswer( + spark.range(0).selectExpr("inline(array(struct(10, 100)))"), + Nil) + } + + test("inline on literal") { + checkAnswer( + spark.range(2).selectExpr("inline(array(struct(10, 100), struct(20, 200), struct(30, 300)))"), + Row(10, 100) :: Row(20, 200) :: Row(30, 300) :: + Row(10, 100) :: Row(20, 200) :: Row(30, 300) :: Nil) + } + + test("inline on column") { + val df = Seq((1, 2)).toDF("a", "b") + + checkAnswer( + df.selectExpr("inline(array(struct(a), struct(a)))"), + Row(1) :: Row(1) :: Nil) + + checkAnswer( + df.selectExpr("inline(array(struct(a, b), struct(a, b)))"), + Row(1, 2) :: Row(1, 2) :: Nil) + + // Spark think [struct, struct] is heterogeneous due to name difference. 
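(Editorial aside, not part of the patch: the `inline` cases above hinge on struct field names being part of the type, so `array(struct(a), struct(b))` is not a homogeneous array. A small sketch of the workaround the test uses, assuming a local `SparkSession` named `spark`:)

```scala
// Illustrative sketch only; assumes a local SparkSession named `spark`.
import spark.implicits._

val df = Seq((1, 2)).toDF("a", "b")

// struct(a) has type struct<a:int> while struct(b) has type struct<b:int>, so the
// array element types differ and the analyzer reports a data type mismatch:
//   df.selectExpr("inline(array(struct(a), struct(b)))")   // AnalysisException

// Renaming the second field with named_struct makes both elements struct<a:int>,
// and inline() then expands the array into one output row per struct:
df.selectExpr("inline(array(struct(a), named_struct('a', b)))").collect()
// => rows (1) and (2)
```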
+ val m = intercept[AnalysisException] { + df.selectExpr("inline(array(struct(a), struct(b)))") + }.getMessage + assert(m.contains("data type mismatch")) + + checkAnswer( + df.selectExpr("inline(array(struct(a), named_struct('a', b)))"), + Row(1) :: Row(2) :: Nil) + + // Spark think [struct, struct] is heterogeneous due to name difference. + val m2 = intercept[AnalysisException] { + df.selectExpr("inline(array(struct(a), struct(2)))") + }.getMessage + assert(m2.contains("data type mismatch")) + + checkAnswer( + df.selectExpr("inline(array(struct(a), named_struct('a', 2)))"), + Row(1) :: Row(2) :: Nil) + + checkAnswer( + df.selectExpr("struct(a)").selectExpr("inline(array(*))"), + Row(1) :: Nil) + + checkAnswer( + df.selectExpr("array(struct(a), named_struct('a', b))").selectExpr("inline(*)"), + Row(1) :: Row(2) :: Nil) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 8cbad04e23f18..44889d92ee306 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import scala.language.existentials + import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.joins._ @@ -37,7 +39,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { val x = testData2.as("x") val y = testData2.as("y") val join = x.join(y, $"x.a" === $"y.a", "inner").queryExecution.optimizedPlan - val planned = sqlContext.sessionState.planner.JoinSelection(join) + val planned = spark.sessionState.planner.JoinSelection(join) assert(planned.size === 1) } @@ -60,9 +62,10 @@ class JoinSuite extends QueryTest with SharedSQLContext { } test("join operator selection") { - sqlContext.cacheManager.clearCache() + spark.sharedState.cacheManager.clearCache() - withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "0") { + withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "0", + SQLConf.CROSS_JOINS_ENABLED.key -> "true") { Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[SortMergeJoinExec]), @@ -112,7 +115,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { // } test("broadcasted hash join operator selection") { - sqlContext.cacheManager.clearCache() + spark.sharedState.cacheManager.clearCache() sql("CACHE TABLE testData") Seq( ("SELECT * FROM testData join testData2 ON key = a", @@ -126,7 +129,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { } test("broadcasted hash outer join operator selection") { - sqlContext.cacheManager.clearCache() + spark.sharedState.cacheManager.clearCache() sql("CACHE TABLE testData") sql("CACHE TABLE testData2") Seq( @@ -144,7 +147,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { val x = testData2.as("x") val y = testData2.as("y") val join = x.join(y, ($"x.a" === $"y.a") && ($"x.b" === $"y.b")).queryExecution.optimizedPlan - val planned = sqlContext.sessionState.planner.JoinSelection(join) + val planned = spark.sessionState.planner.JoinSelection(join) assert(planned.size === 1) } @@ -204,13 +207,27 @@ class JoinSuite extends QueryTest with SharedSQLContext { testData.rdd.flatMap(row => Seq.fill(16)(Row.merge(row, row))).collect().toSeq) } - test("cartisian product join") { - checkAnswer( - testData3.join(testData3), - Row(1, null, 1, null) :: - Row(1, null, 2, 2) :: - Row(2, 2, 1, null) :: - Row(2, 2, 2, 2) :: Nil) + 
test("cartesian product join") { + withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + checkAnswer( + testData3.join(testData3), + Row(1, null, 1, null) :: + Row(1, null, 2, 2) :: + Row(2, 2, 1, null) :: + Row(2, 2, 2, 2) :: Nil) + } + withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "false") { + val e = intercept[Exception] { + checkAnswer( + testData3.join(testData3), + Row(1, null, 1, null) :: + Row(1, null, 2, 2) :: + Row(2, 2, 1, null) :: + Row(2, 2, 2, 2) :: Nil) + } + assert(e.getMessage.contains("Cartesian joins could be prohibitively expensive and are " + + "disabled by default")) + } } test("left outer join") { @@ -344,8 +361,8 @@ class JoinSuite extends QueryTest with SharedSQLContext { } test("full outer join") { - upperCaseData.where('N <= 4).registerTempTable("`left`") - upperCaseData.where('N >= 3).registerTempTable("`right`") + upperCaseData.where('N <= 4).createOrReplaceTempView("`left`") + upperCaseData.where('N >= 3).createOrReplaceTempView("`right`") val left = UnresolvedRelation(TableIdentifier("left"), None) val right = UnresolvedRelation(TableIdentifier("right"), None) @@ -435,10 +452,10 @@ class JoinSuite extends QueryTest with SharedSQLContext { } test("broadcasted existence join operator selection") { - sqlContext.cacheManager.clearCache() + spark.sharedState.cacheManager.clearCache() sql("CACHE TABLE testData") - withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1000000000") { + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> Long.MaxValue.toString) { Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[BroadcastHashJoinExec]), @@ -461,17 +478,17 @@ class JoinSuite extends QueryTest with SharedSQLContext { test("cross join with broadcast") { sql("CACHE TABLE testData") - val sizeInByteOfTestData = statisticSizeInByte(sqlContext.table("testData")) + val sizeInByteOfTestData = statisticSizeInByte(spark.table("testData")) // we set the threshold is greater than statistic of the cached table testData withSQLConf( SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> (sizeInByteOfTestData + 1).toString()) { - assert(statisticSizeInByte(sqlContext.table("testData2")) > - sqlContext.conf.autoBroadcastJoinThreshold) + assert(statisticSizeInByte(spark.table("testData2")) > + spark.conf.get(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD)) - assert(statisticSizeInByte(sqlContext.table("testData")) < - sqlContext.conf.autoBroadcastJoinThreshold) + assert(statisticSizeInByte(spark.table("testData")) < + spark.conf.get(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD)) Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala deleted file mode 100644 index 9f6c86a5752ac..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala +++ /dev/null @@ -1,89 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. 
You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql - -import org.scalatest.BeforeAndAfter - -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType} - -class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContext { - import testImplicits._ - - private lazy val df = (1 to 10).map(i => (i, s"str$i")).toDF("key", "value") - - before { - df.registerTempTable("listtablessuitetable") - } - - after { - sqlContext.sessionState.catalog.dropTable( - TableIdentifier("listtablessuitetable"), ignoreIfNotExists = true) - } - - test("get all tables") { - checkAnswer( - sqlContext.tables().filter("tableName = 'listtablessuitetable'"), - Row("listtablessuitetable", true)) - - checkAnswer( - sql("SHOW tables").filter("tableName = 'listtablessuitetable'"), - Row("listtablessuitetable", true)) - - sqlContext.sessionState.catalog.dropTable( - TableIdentifier("listtablessuitetable"), ignoreIfNotExists = true) - assert(sqlContext.tables().filter("tableName = 'listtablessuitetable'").count() === 0) - } - - test("getting all tables with a database name has no impact on returned table names") { - checkAnswer( - sqlContext.tables("default").filter("tableName = 'listtablessuitetable'"), - Row("listtablessuitetable", true)) - - checkAnswer( - sql("show TABLES in default").filter("tableName = 'listtablessuitetable'"), - Row("listtablessuitetable", true)) - - sqlContext.sessionState.catalog.dropTable( - TableIdentifier("listtablessuitetable"), ignoreIfNotExists = true) - assert(sqlContext.tables().filter("tableName = 'listtablessuitetable'").count() === 0) - } - - test("query the returned DataFrame of tables") { - val expectedSchema = StructType( - StructField("tableName", StringType, false) :: - StructField("isTemporary", BooleanType, false) :: Nil) - - Seq(sqlContext.tables(), sql("SHOW TABLes")).foreach { - case tableDF => - assert(expectedSchema === tableDF.schema) - - tableDF.registerTempTable("tables") - checkAnswer( - sql( - "SELECT isTemporary, tableName from tables WHERE tableName = 'listtablessuitetable'"), - Row(true, "listtablessuitetable") - ) - checkAnswer( - sqlContext.tables().filter("tableName = 'tables'").select("tableName", "isTemporary"), - Row("tables", true)) - sqlContext.dropTempTable("tables") - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/LocalSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/LocalSparkSession.scala new file mode 100644 index 0000000000000..1732977ee51a9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/LocalSparkSession.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import _root_.io.netty.util.internal.logging.{InternalLoggerFactory, Slf4JLoggerFactory} +import org.scalatest.BeforeAndAfterAll +import org.scalatest.BeforeAndAfterEach +import org.scalatest.Suite + +/** Manages a local `spark` {@link SparkSession} variable, correctly stopping it after each test. */ +trait LocalSparkSession extends BeforeAndAfterEach with BeforeAndAfterAll { self: Suite => + + @transient var spark: SparkSession = _ + + override def beforeAll() { + super.beforeAll() + InternalLoggerFactory.setDefaultFactory(new Slf4JLoggerFactory()) + } + + override def afterEach() { + try { + resetSparkContext() + } finally { + super.afterEach() + } + } + + def resetSparkContext(): Unit = { + LocalSparkSession.stop(spark) + spark = null + } + +} + +object LocalSparkSession { + def stop(spark: SparkSession) { + if (spark != null) { + spark.stop() + } + // To avoid RPC rebinding to the same port, since it doesn't unbind immediately on shutdown + System.clearProperty("spark.driver.port") + } + + /** Runs `f` by passing in `sc` and ensures that `sc` is stopped. */ + def withSparkSession[T](sc: SparkSession)(f: SparkSession => T): T = { + try { + f(sc) + } finally { + stop(sc) + } + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MetadataCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MetadataCacheSuite.scala new file mode 100644 index 0000000000000..98aa447fc0560 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/MetadataCacheSuite.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.io.File + +import org.apache.spark.SparkException +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSQLContext + +/** + * Test suite to handle metadata cache related. + */ +class MetadataCacheSuite extends QueryTest with SharedSQLContext { + + /** Removes one data file in the given directory. 
*/ + private def deleteOneFileInDirectory(dir: File): Unit = { + assert(dir.isDirectory) + val oneFile = dir.listFiles().find { file => + !file.getName.startsWith("_") && !file.getName.startsWith(".") + } + assert(oneFile.isDefined) + oneFile.foreach(_.delete()) + } + + test("SPARK-16336 Suggest doing table refresh when encountering FileNotFoundException") { + withTempPath { (location: File) => + // Create a Parquet directory + spark.range(start = 0, end = 100, step = 1, numPartitions = 3) + .write.parquet(location.getAbsolutePath) + + // Read the directory in + val df = spark.read.parquet(location.getAbsolutePath) + assert(df.count() == 100) + + // Delete a file + deleteOneFileInDirectory(location) + + // Read it again and now we should see a FileNotFoundException + val e = intercept[SparkException] { + df.count() + } + assert(e.getMessage.contains("FileNotFoundException")) + assert(e.getMessage.contains("REFRESH")) + } + } + + test("SPARK-16337 temporary view refresh") { + withTempView("view_refresh") { withTempPath { (location: File) => + // Create a Parquet directory + spark.range(start = 0, end = 100, step = 1, numPartitions = 3) + .write.parquet(location.getAbsolutePath) + + // Read the directory in + spark.read.parquet(location.getAbsolutePath).createOrReplaceTempView("view_refresh") + assert(sql("select count(*) from view_refresh").first().getLong(0) == 100) + + // Delete a file + deleteOneFileInDirectory(location) + + // Read it again and now we should see a FileNotFoundException + val e = intercept[SparkException] { + sql("select count(*) from view_refresh").first() + } + assert(e.getMessage.contains("FileNotFoundException")) + assert(e.getMessage.contains("REFRESH")) + + // Refresh and we should be able to read it again. + spark.catalog.refreshTable("view_refresh") + val newCount = sql("select count(*) from view_refresh").first().getLong(0) + assert(newCount > 0 && newCount < 100) + }} + } + + test("case sensitivity support in temporary view refresh") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + withTempView("view_refresh") { + withTempPath { (location: File) => + // Create a Parquet directory + spark.range(start = 0, end = 100, step = 1, numPartitions = 3) + .write.parquet(location.getAbsolutePath) + + // Read the directory in + spark.read.parquet(location.getAbsolutePath).createOrReplaceTempView("view_refresh") + + // Delete a file + deleteOneFileInDirectory(location) + intercept[SparkException](sql("select count(*) from view_refresh").first()) + + // Refresh and we should be able to read it again. + spark.catalog.refreshTable("vIeW_reFrEsH") + val newCount = sql("select count(*) from view_refresh").first().getLong(0) + assert(newCount > 0 && newCount < 100) + } + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MiscFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MiscFunctionsSuite.scala new file mode 100644 index 0000000000000..a5b08f717767f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/MiscFunctionsSuite.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.test.SharedSQLContext + +class MiscFunctionsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("reflect and java_method") { + val df = Seq((1, "one")).toDF("a", "b") + val className = ReflectClass.getClass.getName.stripSuffix("$") + checkAnswer( + df.selectExpr( + s"reflect('$className', 'method1', a, b)", + s"java_method('$className', 'method1', a, b)"), + Row("m1one", "m1one")) + } +} + +object ReflectClass { + def method1(v1: Int, v2: String): String = "m" + v1 + v2 +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala deleted file mode 100644 index 0b5a92c256e57..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala +++ /dev/null @@ -1,100 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql - -import org.scalatest.BeforeAndAfterAll - -import org.apache.spark._ -import org.apache.spark.sql.internal.SQLConf - -class MultiSQLContextsSuite extends SparkFunSuite with BeforeAndAfterAll { - - private var originalActiveSQLContext: Option[SQLContext] = _ - private var originalInstantiatedSQLContext: Option[SQLContext] = _ - private var sparkConf: SparkConf = _ - - override protected def beforeAll(): Unit = { - originalActiveSQLContext = SQLContext.getActive() - originalInstantiatedSQLContext = SQLContext.getInstantiatedContextOption() - - SQLContext.clearActive() - SQLContext.clearInstantiatedContext() - sparkConf = - new SparkConf(false) - .setMaster("local[*]") - .setAppName("test") - .set("spark.ui.enabled", "false") - .set("spark.driver.allowMultipleContexts", "true") - } - - override protected def afterAll(): Unit = { - // Set these states back. - originalActiveSQLContext.foreach(ctx => SQLContext.setActive(ctx)) - originalInstantiatedSQLContext.foreach(ctx => SQLContext.setInstantiatedContext(ctx)) - } - - def testNewSession(rootSQLContext: SQLContext): Unit = { - // Make sure we can successfully create new Session. - rootSQLContext.newSession() - - // Reset the state. It is always safe to clear the active context. 
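For reference, a minimal sketch of how a test suite might mix in the `LocalSparkSession` trait added above; the suite name, imports, and test body here are illustrative only and are not part of this change.

```scala
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.{LocalSparkSession, SparkSession}

// Illustrative sketch only: afterEach() in LocalSparkSession stops `spark` and
// clears spark.driver.port, so each test starts from a clean session.
class ExampleLocalSessionSuite extends SparkFunSuite with LocalSparkSession {
  test("count a small range") {
    spark = SparkSession.builder()
      .master("local[2]")
      .appName("example")
      .getOrCreate()
    assert(spark.range(10).count() == 10)
  }
}
```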
- SQLContext.clearActive() - } - - def testCreatingNewSQLContext(allowsMultipleContexts: Boolean): Unit = { - val conf = - sparkConf - .clone - .set(SQLConf.ALLOW_MULTIPLE_CONTEXTS.key, allowsMultipleContexts.toString) - val sparkContext = new SparkContext(conf) - - try { - if (allowsMultipleContexts) { - new SQLContext(sparkContext) - SQLContext.clearActive() - } else { - // If allowsMultipleContexts is false, make sure we can get the error. - val message = intercept[SparkException] { - new SQLContext(sparkContext) - }.getMessage - assert(message.contains("Only one SQLContext/HiveContext may be running")) - } - } finally { - sparkContext.stop() - } - } - - test("test the flag to disallow creating multiple root SQLContext") { - Seq(false, true).foreach { allowMultipleSQLContexts => - val conf = - sparkConf - .clone - .set(SQLConf.ALLOW_MULTIPLE_CONTEXTS.key, allowMultipleSQLContexts.toString) - val sc = new SparkContext(conf) - try { - val rootSQLContext = new SQLContext(sc) - testNewSession(rootSQLContext) - testNewSession(rootSQLContext) - testCreatingNewSQLContext(allowMultipleSQLContexts) - } finally { - sc.stop() - SQLContext.clearInstantiatedContext() - } - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ProcessingTimeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ProcessingTimeSuite.scala index 0d18a645f6790..52c200796ce41 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ProcessingTimeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ProcessingTimeSuite.scala @@ -22,6 +22,7 @@ import java.util.concurrent.TimeUnit import scala.concurrent.duration._ import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.streaming.ProcessingTime class ProcessingTimeSuite extends SparkFunSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index df8b3b7d8719b..f96bd8cc8b5e2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -33,12 +33,12 @@ import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.streaming.MemoryPlan -import org.apache.spark.sql.types.ObjectType +import org.apache.spark.sql.types.{Metadata, ObjectType} abstract class QueryTest extends PlanTest { - protected def sqlContext: SQLContext + protected def spark: SparkSession // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) @@ -68,55 +68,82 @@ abstract class QueryTest extends PlanTest { /** * Evaluates a dataset to make sure that the result of calling collect matches the given * expected answer. - * - Special handling is done based on whether the query plan should be expected to return - * the results in sorted order. - * - This function also checks to make sure that the schema for serializing the expected answer - * matches that produced by the dataset (i.e. does manual construction of object match - * the constructed encoder for cases like joins, etc). Note that this means that it will fail - * for cases where reordering is done on fields. For such tests, user `checkDecoding` instead - * which performs a subset of the checks done by this function. 
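The reworked dataset helpers below compare decoded objects directly and add an unordered variant; a hedged sketch of how they would be called from inside a `QueryTest`-derived suite (data and imports are illustrative):

```scala
// Illustrative only, assuming `spark` and its implicits are in scope within a QueryTest subclass.
import spark.implicits._

checkDataset(Seq(1, 2, 3).toDS(), 1, 2, 3)            // order-sensitive comparison
checkDatasetUnorderly(Seq(3, 1, 2).toDS(), 1, 2, 3)   // sorts both sides before comparing
```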
*/ protected def checkDataset[T]( - ds: Dataset[T], + ds: => Dataset[T], expectedAnswer: T*): Unit = { - checkAnswer( - ds.toDF(), - sqlContext.createDataset(expectedAnswer)(ds.unresolvedTEncoder).toDF().collect().toSeq) + val result = getResult(ds) - checkDecoding(ds, expectedAnswer: _*) + if (!compare(result.toSeq, expectedAnswer)) { + fail( + s""" + |Decoded objects do not match expected objects: + |expected: $expectedAnswer + |actual: ${result.toSeq} + |${ds.exprEnc.deserializer.treeString} + """.stripMargin) + } } - protected def checkDecoding[T]( + /** + * Evaluates a dataset to make sure that the result of calling collect matches the given + * expected answer, after sort. + */ + protected def checkDatasetUnorderly[T : Ordering]( ds: => Dataset[T], expectedAnswer: T*): Unit = { - val decoded = try ds.collect().toSet catch { + val result = getResult(ds) + + if (!compare(result.toSeq.sorted, expectedAnswer.sorted)) { + fail( + s""" + |Decoded objects do not match expected objects: + |expected: $expectedAnswer + |actual: ${result.toSeq} + |${ds.exprEnc.deserializer.treeString} + """.stripMargin) + } + } + + private def getResult[T](ds: => Dataset[T]): Array[T] = { + val analyzedDS = try ds catch { + case ae: AnalysisException => + if (ae.plan.isDefined) { + fail( + s""" + |Failed to analyze query: $ae + |${ae.plan.get} + | + |${stackTraceToString(ae)} + """.stripMargin) + } else { + throw ae + } + } + checkJsonFormat(analyzedDS) + assertEmptyMissingInput(analyzedDS) + + try ds.collect() catch { case e: Exception => fail( s""" |Exception collecting dataset as objects - |${ds.resolvedTEncoder} - |${ds.resolvedTEncoder.deserializer.treeString} + |${ds.exprEnc} + |${ds.exprEnc.deserializer.treeString} |${ds.queryExecution} """.stripMargin, e) } + } - // Handle the case where the return type is an array - val isArray = decoded.headOption.map(_.getClass.isArray).getOrElse(false) - def normalEquality = decoded == expectedAnswer.toSet - def expectedAsSeq = expectedAnswer.map(_.asInstanceOf[Array[_]].toSeq).toSet - def decodedAsSeq = decoded.map(_.asInstanceOf[Array[_]].toSeq) - - if (!((isArray && expectedAsSeq == decodedAsSeq) || normalEquality)) { - val expected = expectedAnswer.toSet.toSeq.map((a: Any) => a.toString).sorted - val actual = decoded.toSet.toSeq.map((a: Any) => a.toString).sorted - - val comparison = sideBySide("expected" +: expected, "spark" +: actual).mkString("\n") - fail( - s"""Decoded objects do not match expected objects: - |$comparison - |${ds.resolvedTEncoder.deserializer.treeString} - """.stripMargin) - } + private def compare(obj1: Any, obj2: Any): Boolean = (obj1, obj2) match { + case (null, null) => true + case (null, _) => false + case (_, null) => false + case (a: Array[_], b: Array[_]) => + a.length == b.length && a.zip(b).forall { case (l, r) => compare(l, r)} + case (a: Iterable[_], b: Iterable[_]) => + a.size == b.size && a.zip(b).forall { case (l, r) => compare(l, r)} + case (a, b) => a == b } /** @@ -143,7 +170,7 @@ abstract class QueryTest extends PlanTest { checkJsonFormat(analyzedDF) - assertEmptyMissingInput(df) + assertEmptyMissingInput(analyzedDF) QueryTest.checkAnswer(analyzedDF, expectedAnswer) match { case Some(errorMessage) => fail(errorMessage) @@ -201,10 +228,10 @@ abstract class QueryTest extends PlanTest { planWithCaching) } - private def checkJsonFormat(df: DataFrame): Unit = { + private def checkJsonFormat(ds: Dataset[_]): Unit = { // Get the analyzed plan and rewrite the PredicateSubqueries in order to make sure that // RDD and Data 
resolution does not break. - val logicalPlan = df.queryExecution.analyzed + val logicalPlan = ds.queryExecution.analyzed // bypass some cases that we can't handle currently. logicalPlan.transform { @@ -215,9 +242,10 @@ abstract class QueryTest extends PlanTest { case p if p.getClass.getSimpleName == "MetastoreRelation" => return case _: MemoryPlan => return }.transformAllExpressions { - case a: ImperativeAggregate => return + case _: ImperativeAggregate => return case _: TypedAggregateExpression => return case Literal(_, _: ObjectType) => return + case _: UserDefinedGenerator => return } // bypass hive tests before we fix all corner cases in hive module. @@ -239,6 +267,14 @@ abstract class QueryTest extends PlanTest { val normalized1 = logicalPlan.transformAllExpressions { case udf: ScalaUDF => udf.copy(function = null) case gen: UserDefinedGenerator => gen.copy(function = null) + // After SPARK-17356: the JSON representation no longer has the Metadata. We need to remove + // the Metadata from the normalized plan so that we can compare this plan with the + // JSON-deserialzed plan. + case a @ Alias(child, name) if a.explicitMetadata.isDefined => + Alias(child, name)(a.exprId, a.qualifier, Some(Metadata.empty), a.isGenerated) + case a: AttributeReference if a.metadata != Metadata.empty => + AttributeReference(a.name, a.dataType, a.nullable, Metadata.empty)(a.exprId, a.qualifier, + a.isGenerated) } // RDDs/data are not serializable to JSON, so we need to collect LogicalPlans that contains @@ -267,7 +303,7 @@ abstract class QueryTest extends PlanTest { val jsonBackPlan = try { - TreeNode.fromJSON[LogicalPlan](jsonString, sqlContext.sparkContext) + TreeNode.fromJSON[LogicalPlan](jsonString, spark.sparkContext) } catch { case NonFatal(e) => fail( @@ -282,7 +318,7 @@ abstract class QueryTest extends PlanTest { def renormalize: PartialFunction[LogicalPlan, LogicalPlan] = { case l: LogicalRDD => val origin = logicalRDDs.pop() - LogicalRDD(l.output, origin.rdd)(sqlContext.sparkSession) + LogicalRDD(l.output, origin.rdd)(spark) case l: LocalRelation => val origin = localRelations.pop() l.copy(data = origin.data) @@ -296,8 +332,7 @@ abstract class QueryTest extends PlanTest { origin.child, l.tableName)( origin.cachedColumnBuffers, - l._statistics, - origin._batchStats) + origin.batchStats) case p => p.transformExpressions { case s: SubqueryExpression => @@ -341,10 +376,16 @@ object QueryTest { * * @param df the [[DataFrame]] to be executed * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + * @param checkToRDD whether to verify deserialization to an RDD. This runs the query twice. 
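A hedged illustration of the `checkToRDD` flag documented above, using a placeholder query; note that the `QueryTest.checkAnswer` overload here returns an error message as an `Option` rather than throwing.

```scala
// Illustrative only (assumes a `spark` session in scope). The default checkToRDD = true
// also runs df.rdd.count() to exercise deserialization to an RDD (SPARK-15791),
// which executes the query a second time.
import org.apache.spark.sql.Row

val df = spark.range(3).selectExpr("id + 1 AS v")
QueryTest.checkAnswer(df, Seq(Row(1L), Row(2L), Row(3L)))                      // runs the query twice
QueryTest.checkAnswer(df, Seq(Row(1L), Row(2L), Row(3L)), checkToRDD = false)  // runs it once
```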
*/ - def checkAnswer(df: DataFrame, expectedAnswer: Seq[Row]): Option[String] = { + def checkAnswer( + df: DataFrame, + expectedAnswer: Seq[Row], + checkToRDD: Boolean = true): Option[String] = { val isSorted = df.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty - + if (checkToRDD) { + df.rdd.count() // Also attempt to deserialize as an RDD [SPARK-15791] + } val sparkAnswer = try df.collect().toSeq catch { case e: Exception => @@ -362,6 +403,9 @@ object QueryTest { sameRows(expectedAnswer, sparkAnswer, isSorted).map { results => s""" |Results do not match for query: + |Timezone: ${TimeZone.getDefault} + |Timezone Env: ${sys.env("TZ")} + | |${df.queryExecution} |== Results == |$results diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala index 4552eb6ce00a5..34936b38fb5d4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, SpecificMutableRow} -import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -55,15 +54,6 @@ class RowSuite extends SparkFunSuite with SharedSQLContext { assert(row.isNullAt(0)) } - test("serialize w/ kryo") { - val row = Seq((1, Seq(1), Map(1 -> 1), BigDecimal(1))).toDF().first() - val serializer = new SparkSqlSerializer(sparkContext.getConf) - val instance = serializer.newInstance() - val ser = instance.serialize(row) - val de = instance.deserialize(ser).asInstanceOf[Row] - assert(de === row) - } - test("get values by field name on Row created via .toDF") { val row = Seq((1, Seq(1))).toDF("a", "b").first() assert(row.getAs[Int]("a") === 1) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/RuntimeConfigSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RuntimeConfigSuite.scala similarity index 85% rename from sql/core/src/test/scala/org/apache/spark/sql/internal/RuntimeConfigSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/RuntimeConfigSuite.scala index a629b73ac046d..cfe2e9f2dbc44 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/RuntimeConfigSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RuntimeConfigSuite.scala @@ -15,10 +15,9 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.internal +package org.apache.spark.sql import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.RuntimeConfig class RuntimeConfigSuite extends SparkFunSuite { @@ -26,10 +25,9 @@ class RuntimeConfigSuite extends SparkFunSuite { test("set and get") { val conf = newConf() - conf - .set("k1", "v1") - .set("k2", 2) - .set("k3", value = false) + conf.set("k1", "v1") + conf.set("k2", 2) + conf.set("k3", value = false) assert(conf.get("k1") == "v1") assert(conf.get("k2") == "2") @@ -41,13 +39,15 @@ class RuntimeConfigSuite extends SparkFunSuite { } test("getOption") { - val conf = newConf().set("k1", "v1") + val conf = newConf() + conf.set("k1", "v1") assert(conf.getOption("k1") == Some("v1")) assert(conf.getOption("notset") == None) } test("unset") { - val conf = newConf().set("k1", "v1") + val conf = newConf() + conf.set("k1", "v1") assert(conf.get("k1") == "v1") conf.unset("k1") intercept[NoSuchElementException] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLCompatibilityFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLCompatibilityFunctionSuite.scala new file mode 100644 index 0000000000000..27b60e0d9def8 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLCompatibilityFunctionSuite.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.math.BigDecimal +import java.sql.Timestamp + +import org.apache.spark.sql.test.SharedSQLContext + +/** + * A test suite for functions added for compatibility with other databases such as Oracle, MSSQL. + * + * These functions are typically implemented using the trait + * [[org.apache.spark.sql.catalyst.expressions.RuntimeReplaceable]]. 
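The suite below exercises functions implemented via `RuntimeReplaceable`, meaning each one is rewritten to an equivalent built-in expression during analysis; a minimal sketch of that equivalence, assuming a `spark` session is in scope.

```scala
// Sketch only: nvl should behave exactly like coalesce, since the RuntimeReplaceable
// expression is swapped for the built-in form before execution.
val row = spark.sql("SELECT nvl(NULL, 'x'), coalesce(NULL, 'x')").first()
assert(row.getString(0) == "x" && row.getString(1) == "x")
```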
+ */ +class SQLCompatibilityFunctionSuite extends QueryTest with SharedSQLContext { + + test("ifnull") { + checkAnswer( + sql("SELECT ifnull(null, 'x'), ifnull('y', 'x'), ifnull(null, null)"), + Row("x", "y", null)) + + // Type coercion + checkAnswer( + sql("SELECT ifnull(1, 2.1d), ifnull(null, 2.1d)"), + Row(1.0, 2.1)) + } + + test("nullif") { + checkAnswer( + sql("SELECT nullif('x', 'x'), nullif('x', 'y')"), + Row(null, "x")) + + // Type coercion + checkAnswer( + sql("SELECT nullif(1, 2.1d), nullif(1, 1.0d)"), + Row(1.0, null)) + } + + test("nvl") { + checkAnswer( + sql("SELECT nvl(null, 'x'), nvl('y', 'x'), nvl(null, null)"), + Row("x", "y", null)) + + // Type coercion + checkAnswer( + sql("SELECT nvl(1, 2.1d), nvl(null, 2.1d)"), + Row(1.0, 2.1)) + } + + test("nvl2") { + checkAnswer( + sql("SELECT nvl2(null, 'x', 'y'), nvl2('n', 'x', 'y'), nvl2(null, null, null)"), + Row("y", "x", null)) + + // Type coercion + checkAnswer( + sql("SELECT nvl2(null, 1, 2.1d), nvl2('n', 1, 2.1d)"), + Row(2.1, 1.0)) + } + + test("SPARK-16730 cast alias functions for Hive compatibility") { + checkAnswer( + sql("SELECT boolean(1), tinyint(1), smallint(1), int(1), bigint(1)"), + Row(true, 1.toByte, 1.toShort, 1, 1L)) + + checkAnswer( + sql("SELECT float(1), double(1), decimal(1)"), + Row(1.toFloat, 1.0, new BigDecimal(1))) + + checkAnswer( + sql("SELECT date(\"2014-04-04\"), timestamp(date(\"2014-04-04\"))"), + Row(new java.util.Date(114, 3, 4), new Timestamp(114, 3, 4, 0, 0, 0, 0))) + + checkAnswer( + sql("SELECT string(1)"), + Row("1")) + + // Error handling: only one argument + val errorMsg = intercept[AnalysisException](sql("SELECT string(1, 2)")).getMessage + assert(errorMsg.contains("Function string accepts only one argument")) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala index 1d5fc570c65d8..e57c1716a5bfd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala @@ -20,8 +20,11 @@ package org.apache.spark.sql import org.apache.spark.{SharedSparkContext, SparkFunSuite} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType} +@deprecated("This suite is deprecated to silent compiler deprecation warnings", "2.0.0") class SQLContextSuite extends SparkFunSuite with SharedSparkContext { object DummyRule extends Rule[LogicalPlan] { @@ -40,7 +43,7 @@ class SQLContextSuite extends SparkFunSuite with SharedSparkContext { val newSession = sqlContext.newSession() assert(SQLContext.getOrCreate(sc).eq(sqlContext), "SQLContext.getOrCreate after explicitly created SQLContext did not return the context") - SQLContext.setActive(newSession) + SparkSession.setActiveSession(newSession.sparkSession) assert(SQLContext.getOrCreate(sc).eq(newSession), "SQLContext.getOrCreate after explicitly setActive() did not return the active context") } @@ -60,7 +63,7 @@ class SQLContextSuite extends SparkFunSuite with SharedSparkContext { // temporary table should not be shared val df = session1.range(10) - df.registerTempTable("test1") + df.createOrReplaceTempView("test1") assert(session1.tableNames().contains("test1")) assert(!session2.tableNames().contains("test1")) @@ -79,4 +82,63 @@ class 
SQLContextSuite extends SparkFunSuite with SharedSparkContext { assert(sqlContext.sessionState.optimizer.batches.flatMap(_.rules).contains(DummyRule)) } + test("get all tables") { + val sqlContext = SQLContext.getOrCreate(sc) + val df = sqlContext.range(10) + df.createOrReplaceTempView("listtablessuitetable") + assert( + sqlContext.tables().filter("tableName = 'listtablessuitetable'").collect().toSeq == + Row("listtablessuitetable", true) :: Nil) + + assert( + sqlContext.sql("SHOW tables").filter("tableName = 'listtablessuitetable'").collect().toSeq == + Row("listtablessuitetable", true) :: Nil) + + sqlContext.sessionState.catalog.dropTable( + TableIdentifier("listtablessuitetable"), ignoreIfNotExists = true) + assert(sqlContext.tables().filter("tableName = 'listtablessuitetable'").count() === 0) + } + + test("getting all tables with a database name has no impact on returned table names") { + val sqlContext = SQLContext.getOrCreate(sc) + val df = sqlContext.range(10) + df.createOrReplaceTempView("listtablessuitetable") + assert( + sqlContext.tables("default").filter("tableName = 'listtablessuitetable'").collect().toSeq == + Row("listtablessuitetable", true) :: Nil) + + assert( + sqlContext.sql("show TABLES in default").filter("tableName = 'listtablessuitetable'") + .collect().toSeq == Row("listtablessuitetable", true) :: Nil) + + sqlContext.sessionState.catalog.dropTable( + TableIdentifier("listtablessuitetable"), ignoreIfNotExists = true) + assert(sqlContext.tables().filter("tableName = 'listtablessuitetable'").count() === 0) + } + + test("query the returned DataFrame of tables") { + val sqlContext = SQLContext.getOrCreate(sc) + val df = sqlContext.range(10) + df.createOrReplaceTempView("listtablessuitetable") + + val expectedSchema = StructType( + StructField("tableName", StringType, false) :: + StructField("isTemporary", BooleanType, false) :: Nil) + + Seq(sqlContext.tables(), sqlContext.sql("SHOW TABLes")).foreach { + case tableDF => + assert(expectedSchema === tableDF.schema) + + tableDF.createOrReplaceTempView("tables") + assert( + sqlContext.sql( + "SELECT isTemporary, tableName from tables WHERE tableName = 'listtablessuitetable'") + .collect().toSeq == Row(true, "listtablessuitetable") :: Nil) + assert( + sqlContext.tables().filter("tableName = 'tables'").select("tableName", "isTemporary") + .collect().toSeq == Row("tables", true) :: Nil) + sqlContext.dropTempTable("tables") + } + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index ec5163b658c14..8ca470873c9bc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -17,13 +17,11 @@ package org.apache.spark.sql +import java.io.File import java.math.MathContext import java.sql.Timestamp -import org.apache.spark.AccumulatorSuite -import org.apache.spark.sql.catalyst.analysis.UnresolvedException -import org.apache.spark.sql.catalyst.expressions.SortOrder -import org.apache.spark.sql.catalyst.plans.logical.Aggregate +import org.apache.spark.{AccumulatorSuite, SparkException} import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, SortMergeJoinExec} @@ -38,16 +36,9 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { setupTestData() - test("having clause") { - Seq(("one", 1), ("two", 2), 
("three", 3), ("one", 5)).toDF("k", "v").registerTempTable("hav") - checkAnswer( - sql("SELECT k, sum(v) FROM hav GROUP BY k HAVING sum(v) > 2"), - Row("one", 6) :: Row("three", 3) :: Nil) - } - test("SPARK-8010: promote numeric to string") { val df = Seq((1, 1)).toDF("key", "value") - df.registerTempTable("src") + df.createOrReplaceTempView("src") val queryCaseWhen = sql("select case when true then 1.0 else '1' end from src ") val queryCoalesce = sql("select coalesce(null, 1, '1') from src ") @@ -57,15 +48,37 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("show functions") { def getFunctions(pattern: String): Seq[Row] = { - StringUtils.filterPattern(sqlContext.sessionState.functionRegistry.listFunction(), pattern) + StringUtils.filterPattern( + spark.sessionState.catalog.listFunctions("default").map(_._1.funcName), pattern) .map(Row(_)) } + + def createFunction(names: Seq[String]): Unit = { + names.foreach { name => + spark.udf.register(name, (arg1: Int, arg2: String) => arg2 + arg1) + } + } + + def dropFunction(names: Seq[String]): Unit = { + names.foreach { name => + spark.sessionState.catalog.dropTempFunction(name, false) + } + } + + val functions = Array("ilog", "logi", "logii", "logiii", "crc32i", "cubei", "cume_disti", + "isize", "ispace", "to_datei", "date_addi", "current_datei") + + createFunction(functions) + checkAnswer(sql("SHOW functions"), getFunctions("*")) + assert(sql("SHOW functions").collect().size > 200) + Seq("^c*", "*e$", "log*", "*date*").foreach { pattern => // For the pattern part, only '*' and '|' are allowed as wildcards. // For '*', we need to replace it to '.*'. checkAnswer(sql(s"SHOW FUNCTIONS '$pattern'"), getFunctions(pattern)) } + dropFunction(functions) } test("describe functions") { @@ -88,9 +101,9 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-14415: All functions should have own descriptions") { - for (f <- sqlContext.sessionState.functionRegistry.listFunction()) { + for (f <- spark.sessionState.functionRegistry.listFunction()) { if (!Seq("cube", "grouping", "grouping_id", "rollup", "window").contains(f)) { - checkKeywordsNotExist(sql(s"describe function `$f`"), "To be added.") + checkKeywordsNotExist(sql(s"describe function `$f`"), "N/A.") } } } @@ -100,16 +113,18 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { (83, 0, 38), (26, 0, 79), (43, 81, 24) - ).toDF("a", "b", "c").registerTempTable("cachedData") + ).toDF("a", "b", "c").createOrReplaceTempView("cachedData") - sqlContext.cacheTable("cachedData") - checkAnswer( - sql("SELECT t1.b FROM cachedData, cachedData t1 GROUP BY t1.b"), - Row(0) :: Row(81) :: Nil) + spark.catalog.cacheTable("cachedData") + withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + checkAnswer( + sql("SELECT t1.b FROM cachedData, cachedData t1 GROUP BY t1.b"), + Row(0) :: Row(81) :: Nil) + } } test("self join with aliases") { - Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str").registerTempTable("df") + Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str").createOrReplaceTempView("df") checkAnswer( sql( @@ -137,7 +152,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { .toDF("int", "str") .groupBy("str") .agg($"str", count("str").as("strCount")) - .registerTempTable("df") + .createOrReplaceTempView("df") checkAnswer( sql( @@ -193,9 +208,9 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("grouping on nested fields") { - sqlContext.read.json(sparkContext.parallelize( + 
spark.read.json(sparkContext.parallelize( """{"nested": {"attribute": 1}, "value": 2}""" :: Nil)) - .registerTempTable("rows") + .createOrReplaceTempView("rows") checkAnswer( sql( @@ -211,10 +226,10 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-6201 IN type conversion") { - sqlContext.read.json( + spark.read.json( sparkContext.parallelize( Seq("{\"a\": \"1\"}}", "{\"a\": \"2\"}}", "{\"a\": \"3\"}}"))) - .registerTempTable("d") + .createOrReplaceTempView("d") checkAnswer( sql("select * from d where d.a in (1,2)"), @@ -222,10 +237,10 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-11226 Skip empty line in json file") { - sqlContext.read.json( + spark.read.json( sparkContext.parallelize( Seq("{\"a\": \"1\"}}", "{\"a\": \"2\"}}", "{\"a\": \"3\"}}", ""))) - .registerTempTable("d") + .createOrReplaceTempView("d") checkAnswer( sql("select count(1) from d"), @@ -243,12 +258,12 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { val df = sql(sqlText) // First, check if we have GeneratedAggregate. val hasGeneratedAgg = df.queryExecution.sparkPlan - .collect { case _: aggregate.TungstenAggregate => true } + .collect { case _: aggregate.HashAggregateExec => true } .nonEmpty if (!hasGeneratedAgg) { fail( s""" - |Codegen is enabled, but query $sqlText does not have TungstenAggregate in the plan. + |Codegen is enabled, but query $sqlText does not have HashAggregate in the plan. |${df.queryExecution.simpleString} """.stripMargin) } @@ -258,10 +273,10 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("aggregation with codegen") { // Prepare a table that we can group some rows. - sqlContext.table("testData") - .union(sqlContext.table("testData")) - .union(sqlContext.table("testData")) - .registerTempTable("testData3x") + spark.table("testData") + .union(spark.table("testData")) + .union(spark.table("testData")) + .createOrReplaceTempView("testData3x") try { // Just to group rows. 
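Much of the mechanical churn in this file is the rename from `registerTempTable` to `createOrReplaceTempView`; a minimal sketch of the SparkSession-era temp view lifecycle, with illustrative data and view name.

```scala
// Illustrative only (assumes `spark` in scope): the 2.0-style temp view API used throughout this diff.
import spark.implicits._

Seq((1, "a"), (2, "b")).toDF("k", "v").createOrReplaceTempView("kv")
assert(spark.sql("SELECT count(*) FROM kv").first().getLong(0) == 2L)
spark.catalog.dropTempView("kv")  // replaces the old sqlContext.dropTempTable
```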
@@ -333,7 +348,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { "SELECT sum('a'), avg('a'), count(null) FROM testData", Row(null, null, 0) :: Nil) } finally { - sqlContext.dropTempTable("testData3x") + spark.catalog.dropTempView("testData3x") } } @@ -391,7 +406,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-3173 Timestamp support in the parser") { - (0 to 3).map(i => Tuple1(new Timestamp(i))).toDF("time").registerTempTable("timestamps") + (0 to 3).map(i => Tuple1(new Timestamp(i))).toDF("time").createOrReplaceTempView("timestamps") checkAnswer(sql( "SELECT time FROM timestamps WHERE time='1969-12-31 16:00:00.0'"), @@ -427,17 +442,13 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Nil) } - test("index into array") { - checkAnswer( - sql("SELECT data, data[0], data[0] + data[1], data[0 + 1] FROM arrayData"), - arrayData.map(d => Row(d.data, d.data(0), d.data(0) + d.data(1), d.data(1))).collect()) - } - test("left semi greater than predicate") { - checkAnswer( - sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.a >= y.a + 2"), - Seq(Row(3, 1), Row(3, 2)) - ) + withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + checkAnswer( + sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.a >= y.a + 2"), + Seq(Row(3, 1), Row(3, 2)) + ) + } } test("left semi greater than predicate and equal operator") { @@ -452,119 +463,12 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { ) } - test("index into array of arrays") { - checkAnswer( - sql( - "SELECT nestedData, nestedData[0][0], nestedData[0][0] + nestedData[0][1] FROM arrayData"), - arrayData.map(d => - Row(d.nestedData, - d.nestedData(0)(0), - d.nestedData(0)(0) + d.nestedData(0)(1))).collect().toSeq) - } - test("agg") { checkAnswer( sql("SELECT a, SUM(b) FROM testData2 GROUP BY a"), Seq(Row(1, 3), Row(2, 3), Row(3, 3))) } - test("Group By Ordinal - basic") { - checkAnswer( - sql("SELECT a, sum(b) FROM testData2 GROUP BY 1"), - sql("SELECT a, sum(b) FROM testData2 GROUP BY a")) - - // duplicate group-by columns - checkAnswer( - sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1"), - sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) - - checkAnswer( - sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY 1, 2"), - sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) - } - - test("Group By Ordinal - non aggregate expressions") { - checkAnswer( - sql("SELECT a, b + 2, count(2) FROM testData2 GROUP BY a, 2"), - sql("SELECT a, b + 2, count(2) FROM testData2 GROUP BY a, b + 2")) - - checkAnswer( - sql("SELECT a, b + 2 as c, count(2) FROM testData2 GROUP BY a, 2"), - sql("SELECT a, b + 2, count(2) FROM testData2 GROUP BY a, b + 2")) - } - - test("Group By Ordinal - non-foldable constant expression") { - checkAnswer( - sql("SELECT a, b, sum(b) FROM testData2 GROUP BY a, b, 1 + 0"), - sql("SELECT a, b, sum(b) FROM testData2 GROUP BY a, b")) - - checkAnswer( - sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1 + 2"), - sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) - } - - test("Group By Ordinal - alias") { - checkAnswer( - sql("SELECT a, (b + 2) as c, count(2) FROM testData2 GROUP BY a, 2"), - sql("SELECT a, b + 2, count(2) FROM testData2 GROUP BY a, b + 2")) - - checkAnswer( - sql("SELECT a as b, b as a, sum(b) FROM testData2 GROUP BY 1, 2"), - sql("SELECT a, b, sum(b) FROM testData2 GROUP BY a, b")) - } - - test("Group By Ordinal - constants") { - checkAnswer( - sql("SELECT 1, 2, sum(b) FROM testData2 GROUP BY 1, 2"), - sql("SELECT 1, 2, 
sum(b) FROM testData2")) - } - - test("Group By Ordinal - negative cases") { - intercept[UnresolvedException[Aggregate]] { - sql("SELECT a, b FROM testData2 GROUP BY -1") - } - - intercept[UnresolvedException[Aggregate]] { - sql("SELECT a, b FROM testData2 GROUP BY 3") - } - - var e = intercept[UnresolvedException[Aggregate]]( - sql("SELECT SUM(a) FROM testData2 GROUP BY 1")) - assert(e.getMessage contains - "Invalid call to Group by position: the '1'th column in the select contains " + - "an aggregate function") - - e = intercept[UnresolvedException[Aggregate]]( - sql("SELECT SUM(a) + 1 FROM testData2 GROUP BY 1")) - assert(e.getMessage contains - "Invalid call to Group by position: the '1'th column in the select contains " + - "an aggregate function") - - var ae = intercept[AnalysisException]( - sql("SELECT a, rand(0), sum(b) FROM testData2 GROUP BY a, 2")) - assert(ae.getMessage contains - "nondeterministic expression rand(0) should not appear in grouping expression") - - ae = intercept[AnalysisException]( - sql("SELECT * FROM testData2 GROUP BY a, b, 1")) - assert(ae.getMessage contains - "Group by position: star is not allowed to use in the select list " + - "when using ordinals in group by") - } - - test("Group By Ordinal: spark.sql.groupByOrdinal=false") { - withSQLConf(SQLConf.GROUP_BY_ORDINAL.key -> "false") { - // If spark.sql.groupByOrdinal=false, ignore the position number. - intercept[AnalysisException] { - sql("SELECT a, sum(b) FROM testData2 GROUP BY 1") - } - // '*' is not allowed to use in the select list when users specify ordinals in group by - checkAnswer( - sql("SELECT * FROM testData2 GROUP BY a, b, 1"), - sql("SELECT * FROM testData2 GROUP BY a, b")) - } - } - test("aggregates with nulls") { checkAnswer( sql("SELECT SKEWNESS(a), KURTOSIS(a), MIN(a), MAX(a)," + @@ -631,18 +535,12 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { sortTest() } - test("limit") { - checkAnswer( - sql("SELECT * FROM testData LIMIT 10"), - testData.take(10).toSeq) - - checkAnswer( - sql("SELECT * FROM arrayData LIMIT 1"), - arrayData.collect().take(1).map(Row.fromTuple).toSeq) - - checkAnswer( - sql("SELECT * FROM mapData LIMIT 1"), - mapData.collect().take(1).map(Row.fromTuple).toSeq) + test("negative in LIMIT or TABLESAMPLE") { + val expected = "The limit expression must be equal to or greater than 0, but got -1" + var e = intercept[AnalysisException] { + sql("SELECT * FROM testData TABLESAMPLE (-1 rows)") + }.getMessage + assert(e.contains(expected)) } test("CTE feature") { @@ -745,8 +643,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("count of empty table") { - withTempTable("t") { - Seq.empty[(Int, Int)].toDF("a", "b").registerTempTable("t") + withTempView("t") { + Seq.empty[(Int, Int)].toDF("a", "b").createOrReplaceTempView("t") checkAnswer( sql("select count(a) from t"), Row(0)) @@ -823,12 +721,14 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("cartesian product join") { - checkAnswer( - testData3.join(testData3), - Row(1, null, 1, null) :: - Row(1, null, 2, 2) :: - Row(2, 2, 1, null) :: - Row(2, 2, 2, 2) :: Nil) + withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + checkAnswer( + testData3.join(testData3), + Row(1, null, 1, null) :: + Row(1, null, 2, 2) :: + Row(2, 2, 1, null) :: + Row(2, 2, 2, 2) :: Nil) + } } test("left outer join") { @@ -891,10 +791,10 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-3349 partitioning after limit") { sql("SELECT DISTINCT n FROM lowerCaseData 
ORDER BY n DESC") .limit(2) - .registerTempTable("subset1") + .createOrReplaceTempView("subset1") sql("SELECT DISTINCT n FROM lowerCaseData ORDER BY n ASC") .limit(2) - .registerTempTable("subset2") + .createOrReplaceTempView("subset2") checkAnswer( sql("SELECT * FROM lowerCaseData INNER JOIN subset1 ON subset1.n = lowerCaseData.n"), Row(3, "c", 3) :: @@ -1041,7 +941,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SET commands semantics using sql()") { - sqlContext.conf.clear() + spark.sessionState.conf.clear() val testKey = "test.key.0" val testVal = "test.val.0" val nonexistentKey = "nonexistent" @@ -1082,17 +982,17 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { sql(s"SET $nonexistentKey"), Row(nonexistentKey, "") ) - sqlContext.conf.clear() + spark.sessionState.conf.clear() } test("SET commands with illegal or inappropriate argument") { - sqlContext.conf.clear() + spark.sessionState.conf.clear() // Set negative mapred.reduce.tasks for automatically determining // the number of reducers is not supported intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-1")) intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-01")) intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-2")) - sqlContext.conf.clear() + spark.sessionState.conf.clear() } test("apply schema") { @@ -1110,8 +1010,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Row(values(0).toInt, values(1), values(2).toBoolean, v4) } - val df1 = sqlContext.createDataFrame(rowRDD1, schema1) - df1.registerTempTable("applySchema1") + val df1 = spark.createDataFrame(rowRDD1, schema1) + df1.createOrReplaceTempView("applySchema1") checkAnswer( sql("SELECT * FROM applySchema1"), Row(1, "A1", true, null) :: @@ -1140,8 +1040,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4)) } - val df2 = sqlContext.createDataFrame(rowRDD2, schema2) - df2.registerTempTable("applySchema2") + val df2 = spark.createDataFrame(rowRDD2, schema2) + df2.createOrReplaceTempView("applySchema2") checkAnswer( sql("SELECT * FROM applySchema2"), Row(Row(1, true), Map("A1" -> null)) :: @@ -1165,8 +1065,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Row(Row(values(0).toInt, values(2).toBoolean), scala.collection.mutable.Map(values(1) -> v4)) } - val df3 = sqlContext.createDataFrame(rowRDD3, schema2) - df3.registerTempTable("applySchema3") + val df3 = spark.createDataFrame(rowRDD3, schema2) + df3.createOrReplaceTempView("applySchema3") checkAnswer( sql("SELECT f1.f11, f2['D4'] FROM applySchema3"), @@ -1193,6 +1093,30 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { ) } + test("SPARK-17863: SELECT distinct does not work correctly if order by missing attribute") { + checkAnswer( + sql("""select distinct struct.a, struct.b + |from ( + | select named_struct('a', 1, 'b', 2, 'c', 3) as struct + | union all + | select named_struct('a', 1, 'b', 2, 'c', 4) as struct) tmp + |order by a, b + |""".stripMargin), + Row(1, 2) :: Nil) + + val error = intercept[AnalysisException] { + sql("""select distinct struct.a, struct.b + |from ( + | select named_struct('a', 1, 'b', 2, 'c', 3) as struct + | union all + | select named_struct('a', 1, 'b', 2, 'c', 4) as struct) tmp + |order by struct.a, struct.b + |""".stripMargin) + } + assert(error.message contains "cannot resolve '`struct.a`' given input columns: [a, b]") + + } + test("cast boolean to string") { // TODO Ensure true/false 
string letter casing is consistent with Hive in all cases. checkAnswer( @@ -1210,11 +1134,11 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { .build() val schemaWithMeta = new StructType(Array( schema("id"), schema("name").copy(metadata = metadata), schema("age"))) - val personWithMeta = sqlContext.createDataFrame(person.rdd, schemaWithMeta) + val personWithMeta = spark.createDataFrame(person.rdd, schemaWithMeta) def validateMetadata(rdd: DataFrame): Unit = { assert(rdd.schema("name").metadata.getString(docKey) == docValue) } - personWithMeta.registerTempTable("personWithMeta") + personWithMeta.createOrReplaceTempView("personWithMeta") validateMetadata(personWithMeta.select($"name")) validateMetadata(personWithMeta.select($"name")) validateMetadata(personWithMeta.select($"id", $"name")) @@ -1226,7 +1150,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-3371 Renaming a function expression with group by gives error") { - sqlContext.udf.register("len", (s: String) => s.length) + spark.udf.register("len", (s: String) => s.length) checkAnswer( sql("SELECT len(value) as temp FROM testData WHERE key = 1 group by len(value)"), Row(1)) @@ -1265,134 +1189,12 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { checkAggregation("SELECT key + 1 + 1, COUNT(*) FROM testData GROUP BY key + 1", false) } - test("Test to check we can use Long.MinValue") { - checkAnswer( - sql(s"SELECT ${Long.MinValue} FROM testData ORDER BY key LIMIT 1"), Row(Long.MinValue) - ) - - checkAnswer( - sql(s"SELECT key FROM testData WHERE key > ${Long.MinValue}"), - (1 to 100).map(Row(_)).toSeq - ) - } - - test("Floating point number format") { - checkAnswer( - sql("SELECT 0.3"), Row(BigDecimal(0.3)) - ) - - checkAnswer( - sql("SELECT -0.8"), Row(BigDecimal(-0.8)) - ) - - checkAnswer( - sql("SELECT .5"), Row(BigDecimal(0.5)) - ) - - checkAnswer( - sql("SELECT -.18"), Row(BigDecimal(-0.18)) - ) - } - - test("Auto cast integer type") { - checkAnswer( - sql(s"SELECT ${Int.MaxValue + 1L}"), Row(Int.MaxValue + 1L) - ) - - checkAnswer( - sql(s"SELECT ${Int.MinValue - 1L}"), Row(Int.MinValue - 1L) - ) - - checkAnswer( - sql("SELECT 9223372036854775808"), Row(new java.math.BigDecimal("9223372036854775808")) - ) - - checkAnswer( - sql("SELECT -9223372036854775809"), Row(new java.math.BigDecimal("-9223372036854775809")) - ) - } - - test("Test to check we can apply sign to expression") { - - checkAnswer( - sql("SELECT -100"), Row(-100) - ) - - checkAnswer( - sql("SELECT +230"), Row(230) - ) - - checkAnswer( - sql("SELECT -5.2"), Row(BigDecimal(-5.2)) - ) - - checkAnswer( - sql("SELECT +6.8e0"), Row(6.8d) - ) - - checkAnswer( - sql("SELECT -key FROM testData WHERE key = 2"), Row(-2) - ) - - checkAnswer( - sql("SELECT +key FROM testData WHERE key = 3"), Row(3) - ) - - checkAnswer( - sql("SELECT -(key + 1) FROM testData WHERE key = 1"), Row(-2) - ) - - checkAnswer( - sql("SELECT - key + 1 FROM testData WHERE key = 10"), Row(-9) - ) - - checkAnswer( - sql("SELECT +(key + 5) FROM testData WHERE key = 5"), Row(10) - ) - - checkAnswer( - sql("SELECT -MAX(key) FROM testData"), Row(-100) - ) - - checkAnswer( - sql("SELECT +MAX(key) FROM testData"), Row(100) - ) - - checkAnswer( - sql("SELECT - (-10)"), Row(10) - ) - - checkAnswer( - sql("SELECT + (-key) FROM testData WHERE key = 32"), Row(-32) - ) - - checkAnswer( - sql("SELECT - (+Max(key)) FROM testData"), Row(-100) - ) - - checkAnswer( - sql("SELECT - - 3"), Row(3) - ) - - checkAnswer( - sql("SELECT - + 20"), Row(-20) - ) - - checkAnswer( 
- sql("SELEcT - + 45"), Row(-45) - ) - - checkAnswer( - sql("SELECT + + 100"), Row(100) - ) - - checkAnswer( - sql("SELECT - - Max(key) FROM testData"), Row(100) - ) - - checkAnswer( - sql("SELECT + - key FROM testData WHERE key = 33"), Row(-33) - ) + testQuietly( + "SPARK-16748: SparkExceptions during planning should not wrapped in TreeNodeException") { + intercept[SparkException] { + val df = spark.range(0, 5).map(x => (1 / x).toString).toDF("a").orderBy("a") + df.queryExecution.toRdd // force physical planning, but not execution of the plan + } } test("Multiple join") { @@ -1409,7 +1211,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-3483 Special chars in column names") { val data = sparkContext.parallelize( Seq("""{"key?number1": "value1", "key.number2": "value2"}""")) - sqlContext.read.json(data).registerTempTable("records") + spark.read.json(data).createOrReplaceTempView("records") sql("SELECT `key?number1`, `key.number2` FROM records") } @@ -1450,15 +1252,15 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-4322 Grouping field with struct field as sub expression") { - sqlContext.read.json(sparkContext.makeRDD("""{"a": {"b": [{"c": 1}]}}""" :: Nil)) - .registerTempTable("data") + spark.read.json(sparkContext.makeRDD("""{"a": {"b": [{"c": 1}]}}""" :: Nil)) + .createOrReplaceTempView("data") checkAnswer(sql("SELECT a.b[0].c FROM data GROUP BY a.b[0].c"), Row(1)) - sqlContext.dropTempTable("data") + spark.catalog.dropTempView("data") - sqlContext.read.json( - sparkContext.makeRDD("""{"a": {"b": 1}}""" :: Nil)).registerTempTable("data") + spark.read.json( + sparkContext.makeRDD("""{"a": {"b": 1}}""" :: Nil)).createOrReplaceTempView("data") checkAnswer(sql("SELECT a.b + 1 FROM data GROUP BY a.b + 1"), Row(2)) - sqlContext.dropTempTable("data") + spark.catalog.dropTempView("data") } test("SPARK-4432 Fix attribute reference resolution error when using ORDER BY") { @@ -1478,10 +1280,10 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("Supporting relational operator '<=>' in Spark SQL") { val nullCheckData1 = TestData(1, "1") :: TestData(2, null) :: Nil val rdd1 = sparkContext.parallelize((0 to 1).map(i => nullCheckData1(i))) - rdd1.toDF().registerTempTable("nulldata1") + rdd1.toDF().createOrReplaceTempView("nulldata1") val nullCheckData2 = TestData(1, "1") :: TestData(2, null) :: Nil val rdd2 = sparkContext.parallelize((0 to 1).map(i => nullCheckData2(i))) - rdd2.toDF().registerTempTable("nulldata2") + rdd2.toDF().createOrReplaceTempView("nulldata2") checkAnswer(sql("SELECT nulldata1.key FROM nulldata1 join " + "nulldata2 on nulldata1.value <=> nulldata2.value"), (1 to 2).map(i => Row(i))) @@ -1490,7 +1292,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("Multi-column COUNT(DISTINCT ...)") { val data = TestData(1, "val_1") :: TestData(2, "val_2") :: Nil val rdd = sparkContext.parallelize((0 to 1).map(i => data(i))) - rdd.toDF().registerTempTable("distinctData") + rdd.toDF().createOrReplaceTempView("distinctData") checkAnswer(sql("SELECT COUNT(DISTINCT key,value) FROM distinctData"), Row(2)) } @@ -1498,15 +1300,15 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { val data = TestData(1, "val_1") :: TestData(2, "val_2") :: Nil val rdd = sparkContext.parallelize((0 to 1).map(i => data(i))) - rdd.toDF().registerTempTable("testTable1") + rdd.toDF().createOrReplaceTempView("testTable1") checkAnswer(sql("SELECT VALUE FROM 
TESTTABLE1 where KEY = 1"), Row("val_1")) } } test("SPARK-6145: ORDER BY test for nested fields") { - sqlContext.read.json(sparkContext.makeRDD( + spark.read.json(sparkContext.makeRDD( """{"a": {"b": 1, "a": {"a": 1}}, "c": [{"d": 1}]}""" :: Nil)) - .registerTempTable("nestedOrder") + .createOrReplaceTempView("nestedOrder") checkAnswer(sql("SELECT 1 FROM nestedOrder ORDER BY a.b"), Row(1)) checkAnswer(sql("SELECT a.b FROM nestedOrder ORDER BY a.b"), Row(1)) @@ -1517,23 +1319,25 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-6145: special cases") { - sqlContext.read.json(sparkContext.makeRDD( - """{"a": {"b": [1]}, "b": [{"a": 1}], "_c0": {"a": 1}}""" :: Nil)).registerTempTable("t") + spark.read + .json(sparkContext.makeRDD("""{"a": {"b": [1]}, "b": [{"a": 1}], "_c0": {"a": 1}}""" :: Nil)) + .createOrReplaceTempView("t") + checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY _c0.a"), Row(1)) checkAnswer(sql("SELECT b[0].a FROM t ORDER BY _c0.a"), Row(1)) } test("SPARK-6898: complete support for special chars in column names") { - sqlContext.read.json(sparkContext.makeRDD( + spark.read.json(sparkContext.makeRDD( """{"a": {"c.b": 1}, "b.$q": [{"a@!.q": 1}], "q.w": {"w.i&": [1]}}""" :: Nil)) - .registerTempTable("t") + .createOrReplaceTempView("t") checkAnswer(sql("SELECT a.`c.b`, `b.$q`[0].`a@!.q`, `q.w`.`w.i&`[0] FROM t"), Row(1, 1, 1)) } test("SPARK-6583 order by aggregated function") { Seq("1" -> 3, "1" -> 4, "2" -> 7, "2" -> 8, "3" -> 5, "3" -> 6, "4" -> 1, "4" -> 2) - .toDF("a", "b").registerTempTable("orderByData") + .toDF("a", "b").createOrReplaceTempView("orderByData") checkAnswer( sql( @@ -1607,7 +1411,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-7952: fix the equality check between boolean and numeric types") { - withTempTable("t") { + withTempView("t") { // numeric field i, boolean field j, result of i = j, result of i <=> j Seq[(Integer, java.lang.Boolean, java.lang.Boolean, java.lang.Boolean)]( (1, true, true, true), @@ -1619,7 +1423,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { (0, null, null, false), (1, null, null, false), (null, null, null, true) - ).toDF("i", "b", "r1", "r2").registerTempTable("t") + ).toDF("i", "b", "r1", "r2").createOrReplaceTempView("t") checkAnswer(sql("select i = b from t"), sql("select r1 from t")) checkAnswer(sql("select i <=> b from t"), sql("select r2 from t")) @@ -1627,25 +1431,25 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-7067: order by queries for complex ExtractValue chain") { - withTempTable("t") { - sqlContext.read.json(sparkContext.makeRDD( - """{"a": {"b": [{"c": 1}]}, "b": [{"d": 1}]}""" :: Nil)).registerTempTable("t") + withTempView("t") { + spark.read.json(sparkContext.makeRDD( + """{"a": {"b": [{"c": 1}]}, "b": [{"d": 1}]}""" :: Nil)).createOrReplaceTempView("t") checkAnswer(sql("SELECT a.b FROM t ORDER BY b[0].d"), Row(Seq(Row(1)))) } } test("SPARK-8782: ORDER BY NULL") { - withTempTable("t") { - Seq((1, 2), (1, 2)).toDF("a", "b").registerTempTable("t") + withTempView("t") { + Seq((1, 2), (1, 2)).toDF("a", "b").createOrReplaceTempView("t") checkAnswer(sql("SELECT * FROM t ORDER BY NULL"), Seq(Row(1, 2), Row(1, 2))) } } test("SPARK-8837: use keyword in column name") { - withTempTable("t") { + withTempView("t") { val df = Seq(1 -> "a").toDF("count", "sort") checkAnswer(df.filter("count > 0"), Row(1, "a")) - df.registerTempTable("t") + df.createOrReplaceTempView("t") checkAnswer(sql("select count, sort from t"), Row(1, 
"a")) } } @@ -1756,10 +1560,10 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-9511: error with table starting with number") { - withTempTable("1one") { + withTempView("1one") { sparkContext.parallelize(1 to 10).map(i => (i, i.toString)) .toDF("num", "str") - .registerTempTable("1one") + .createOrReplaceTempView("1one") checkAnswer(sql("select count(num) from 1one"), Row(10)) } } @@ -1776,7 +1580,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { // We don't support creating a temporary table while specifying a database intercept[AnalysisException] { - sqlContext.sql( + spark.sql( s""" |CREATE TEMPORARY TABLE db.t |USING parquet @@ -1787,7 +1591,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { }.getMessage // If you use backticks to quote the name then it's OK. - sqlContext.sql( + spark.sql( s""" |CREATE TEMPORARY TABLE `db.t` |USING parquet @@ -1795,21 +1599,21 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { | path '$path' |) """.stripMargin) - checkAnswer(sqlContext.table("`db.t`"), df) + checkAnswer(spark.table("`db.t`"), df) } } test("SPARK-10130 type coercion for IF should have children resolved first") { - withTempTable("src") { - Seq((1, 1), (-1, 1)).toDF("key", "value").registerTempTable("src") + withTempView("src") { + Seq((1, 1), (-1, 1)).toDF("key", "value").createOrReplaceTempView("src") checkAnswer( sql("SELECT IF(a > 0, a, 0) FROM (SELECT key a FROM src) temp"), Seq(Row(1), Row(0))) } } test("SPARK-10389: order by non-attribute grouping expression on Aggregate") { - withTempTable("src") { - Seq((1, 1), (-1, 1)).toDF("key", "value").registerTempTable("src") + withTempView("src") { + Seq((1, 1), (-1, 1)).toDF("key", "value").createOrReplaceTempView("src") checkAnswer(sql("SELECT MAX(value) FROM src GROUP BY key + 1 ORDER BY key + 1"), Seq(Row(1), Row(1))) checkAnswer(sql("SELECT MAX(value) FROM src GROUP BY key + 1 ORDER BY (key + 1) * 2"), @@ -1818,7 +1622,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("run sql directly on files") { - val df = sqlContext.range(100).toDF() + val df = spark.range(100).toDF() withTempPath(f => { df.write.json(f.getCanonicalPath) checkAnswer(sql(s"select id from json.`${f.getCanonicalPath}`"), @@ -1829,20 +1633,58 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { df) }) - val e1 = intercept[AnalysisException] { + var e = intercept[AnalysisException] { sql("select * from in_valid_table") } - assert(e1.message.contains("Table or view not found")) + assert(e.message.contains("Table or view not found")) - val e2 = intercept[AnalysisException] { + e = intercept[AnalysisException] { sql("select * from no_db.no_table").show() } - assert(e2.message.contains("Table or view not found")) + assert(e.message.contains("Table or view not found")) - val e3 = intercept[AnalysisException] { + e = intercept[AnalysisException] { sql("select * from json.invalid_file") } - assert(e3.message.contains("Path does not exist")) + assert(e.message.contains("Path does not exist")) + + e = intercept[AnalysisException] { + sql(s"select id from `org.apache.spark.sql.hive.orc`.`file_path`") + } + assert(e.message.contains("The ORC data source must be used with Hive support enabled")) + + e = intercept[AnalysisException] { + sql(s"select id from `com.databricks.spark.avro`.`file_path`") + } + assert(e.message.contains("Failed to find data source: com.databricks.spark.avro.")) + + // data source type is case insensitive + e = intercept[AnalysisException] { + 
sql(s"select id from Avro.`file_path`") + } + assert(e.message.contains("Failed to find data source: avro.")) + + e = intercept[AnalysisException] { + sql(s"select id from avro.`file_path`") + } + assert(e.message.contains("Failed to find data source: avro.")) + + e = intercept[AnalysisException] { + sql(s"select id from `org.apache.spark.sql.sources.HadoopFsRelationProvider`.`file_path`") + } + assert(e.message.contains("Table or view not found: " + + "`org.apache.spark.sql.sources.HadoopFsRelationProvider`.`file_path`")) + + e = intercept[AnalysisException] { + sql(s"select id from `Jdbc`.`file_path`") + } + assert(e.message.contains("Unsupported data source type for direct query on files: Jdbc")) + + e = intercept[AnalysisException] { + sql(s"select id from `org.apache.spark.sql.execution.datasources.jdbc`.`file_path`") + } + assert(e.message.contains("Unsupported data source type for direct query on files: " + + "org.apache.spark.sql.execution.datasources.jdbc")) } test("SortMergeJoin returns wrong results when using UnsafeRows") { @@ -1870,17 +1712,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } - test("SPARK-11032: resolve having correctly") { - withTempTable("src") { - Seq(1 -> "a").toDF("i", "j").registerTempTable("src") - checkAnswer( - sql("SELECT MIN(t.i) FROM (SELECT * FROM src WHERE i > 0) t HAVING(COUNT(1) > 0)"), - Row(1)) - } - } - test("SPARK-11303: filter should not be pushed down into sample") { - val df = sqlContext.range(100) + val df = spark.range(100) List(true, false).foreach { withReplacement => val sampled = df.sample(withReplacement, 0.1, 1) val sampledOdd = sampled.filter("id % 2 != 0") @@ -1910,8 +1743,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Row(1, 1, 1, 1) :: Row(1, 2, 2, 1) :: Row(2, 1, 1, 2) :: Row(2, 2, 2, 2) :: Row(3, 1, 1, 3) :: Row(3, 2, 2, 3) :: Nil) - // Try with a registered table. 
- sql("select struct(a, b) as record from testData2").registerTempTable("structTable") + // Try with a temporary view + sql("select struct(a, b) as record from testData2").createOrReplaceTempView("structTable") checkAnswer( sql("SELECT record.* FROM structTable"), Row(1, 1) :: Row(1, 2) :: Row(2, 1) :: Row(2, 2) :: Row(3, 1) :: Row(3, 2) :: Nil) @@ -1975,9 +1808,9 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { nestedStructData.select($"record.r1.*"), Row(1, 1) :: Row(1, 2) :: Row(2, 1) :: Row(2, 2) :: Row(3, 1) :: Row(3, 2) :: Nil) - // Try with a registered table - withTempTable("nestedStructTable") { - nestedStructData.registerTempTable("nestedStructTable") + // Try with a temporary view + withTempView("nestedStructTable") { + nestedStructData.createOrReplaceTempView("nestedStructTable") checkAnswer( sql("SELECT record.* FROM nestedStructTable"), nestedStructData.select($"record.*")) @@ -1999,8 +1832,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { | SELECT struct(`col$.a_`, `a.b.c.`) as `r&&b.c` FROM | (SELECT struct(a, b) as `col$.a_`, struct(b, a) as `a.b.c.` FROM testData2) tmp """.stripMargin) - withTempTable("specialCharacterTable") { - specialCharacterPath.registerTempTable("specialCharacterTable") + withTempView("specialCharacterTable") { + specialCharacterPath.createOrReplaceTempView("specialCharacterTable") checkAnswer( specialCharacterPath.select($"`r&&b.c`.*"), nestedStructData.select($"record.*")) @@ -2023,8 +1856,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("Struct Star Expansion - Name conflict") { // Create a data set that contains a naming conflict val nameConflict = sql("SELECT struct(a, b) as nameConflict, a as a FROM testData2") - withTempTable("nameConflict") { - nameConflict.registerTempTable("nameConflict") + withTempView("nameConflict") { + nameConflict.createOrReplaceTempView("nameConflict") // Unqualified should resolve to table. checkAnswer(sql("SELECT nameConflict.* FROM nameConflict"), Row(Row(1, 1), 1) :: Row(Row(1, 2), 1) :: Row(Row(2, 1), 2) :: Row(Row(2, 2), 2) :: @@ -2043,6 +1876,37 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } + test("Star Expansion - table with zero column") { + withTempView("temp_table_no_cols") { + val rddNoCols = sparkContext.parallelize(1 to 10).map(_ => Row.empty) + val dfNoCols = spark.createDataFrame(rddNoCols, StructType(Seq.empty)) + dfNoCols.createTempView("temp_table_no_cols") + + // ResolvedStar + checkAnswer( + dfNoCols, + dfNoCols.select(dfNoCols.col("*"))) + + // UnresolvedStar + checkAnswer( + dfNoCols, + sql("SELECT * FROM temp_table_no_cols")) + checkAnswer( + dfNoCols, + dfNoCols.select($"*")) + + var e = intercept[AnalysisException] { + sql("SELECT a.* FROM temp_table_no_cols a") + }.getMessage + assert(e.contains("cannot resolve 'a.*' give input columns ''")) + + e = intercept[AnalysisException] { + dfNoCols.select($"b.*") + }.getMessage + assert(e.contains("cannot resolve 'b.*' give input columns ''")) + } + } + test("Common subexpression elimination") { // TODO: support subexpression elimination in whole stage codegen withSQLConf("spark.sql.codegen.wholeStage" -> "false") { @@ -2058,9 +1922,9 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { checkAnswer(df.selectExpr("a + 1", "a + (a + 1)"), Row(2, 3)) // Identity udf that tracks the number of times it is called. 
- val countAcc = sparkContext.accumulator(0, "CallCount") - sqlContext.udf.register("testUdf", (x: Int) => { - countAcc.++=(1) + val countAcc = sparkContext.longAccumulator("CallCount") + spark.udf.register("testUdf", (x: Int) => { + countAcc.add(1) x }) @@ -2068,7 +1932,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { // is correct. def verifyCallCount(df: DataFrame, expectedResult: Row, expectedCount: Int): Unit = { countAcc.setValue(0) - checkAnswer(df, expectedResult) + QueryTest.checkAnswer( + df, Seq(expectedResult), checkToRDD = false /* avoid duplicate exec */) assert(countAcc.value == expectedCount) } @@ -2083,7 +1948,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { df.selectExpr("testUdf(a + 1) + testUdf(1 + b)", "testUdf(a + 1)"), Row(4, 2), 2) val testUdf = functions.udf((x: Int) => { - countAcc.++=(1) + countAcc.add(1) x }) verifyCallCount( @@ -2093,9 +1958,9 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 1) // Try disabling it via configuration. - sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "false") + spark.conf.set("spark.sql.subexpressionElimination.enabled", "false") verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 2) - sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "true") + spark.conf.set("spark.sql.subexpressionElimination.enabled", "true") verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) } } @@ -2327,8 +2192,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-13056: Null in map value causes NPE") { val df = Seq(1 -> Map("abc" -> "somestring", "cba" -> null)).toDF("key", "value") - withTempTable("maptest") { - df.registerTempTable("maptest") + withTempView("maptest") { + df.createOrReplaceTempView("maptest") // local optimization will by pass codegen code, so we should keep the filter `key=1` checkAnswer(sql("SELECT value['abc'] FROM maptest where key = 1"), Row("somestring")) checkAnswer(sql("SELECT value['cba'] FROM maptest where key = 1"), Row(null)) @@ -2337,8 +2202,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("hash function") { val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j") - withTempTable("tbl") { - df.registerTempTable("tbl") + withTempView("tbl") { + df.createOrReplaceTempView("tbl") checkAnswer( df.select(hash($"i", $"j")), sql("SELECT hash(i, j) from tbl") @@ -2346,70 +2211,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } - test("order by ordinal number") { - checkAnswer( - sql("SELECT * FROM testData2 ORDER BY 1 DESC"), - sql("SELECT * FROM testData2 ORDER BY a DESC")) - // If the position is not an integer, ignore it. 
- checkAnswer( - sql("SELECT * FROM testData2 ORDER BY 1 + 0 DESC, b ASC"), - sql("SELECT * FROM testData2 ORDER BY b ASC")) - checkAnswer( - sql("SELECT * FROM testData2 ORDER BY 1 DESC, b ASC"), - sql("SELECT * FROM testData2 ORDER BY a DESC, b ASC")) - checkAnswer( - sql("SELECT * FROM testData2 SORT BY 1 DESC, 2"), - sql("SELECT * FROM testData2 SORT BY a DESC, b ASC")) - checkAnswer( - sql("SELECT * FROM testData2 ORDER BY 1 ASC, b ASC"), - Seq(Row(1, 1), Row(1, 2), Row(2, 1), Row(2, 2), Row(3, 1), Row(3, 2))) - } - - test("order by ordinal number - negative cases") { - intercept[UnresolvedException[SortOrder]] { - sql("SELECT * FROM testData2 ORDER BY 0") - } - intercept[UnresolvedException[SortOrder]] { - sql("SELECT * FROM testData2 ORDER BY -1 DESC, b ASC") - } - intercept[UnresolvedException[SortOrder]] { - sql("SELECT * FROM testData2 ORDER BY 3 DESC, b ASC") - } - } - - test("order by ordinal number with conf spark.sql.orderByOrdinal=false") { - withSQLConf(SQLConf.ORDER_BY_ORDINAL.key -> "false") { - // If spark.sql.orderByOrdinal=false, ignore the position number. - checkAnswer( - sql("SELECT * FROM testData2 ORDER BY 1 DESC, b ASC"), - sql("SELECT * FROM testData2 ORDER BY b ASC")) - } - } - - test("natural join") { - val df1 = Seq(("one", 1), ("two", 2), ("three", 3)).toDF("k", "v1") - val df2 = Seq(("one", 1), ("two", 22), ("one", 5)).toDF("k", "v2") - withTempTable("nt1", "nt2") { - df1.registerTempTable("nt1") - df2.registerTempTable("nt2") - checkAnswer( - sql("SELECT * FROM nt1 natural join nt2 where k = \"one\""), - Row("one", 1, 1) :: Row("one", 1, 5) :: Nil) - - checkAnswer( - sql("SELECT * FROM nt1 natural left join nt2 order by v1, v2"), - Row("one", 1, 1) :: Row("one", 1, 5) :: Row("two", 2, 22) :: Row("three", 3, null) :: Nil) - - checkAnswer( - sql("SELECT * FROM nt1 natural right join nt2 order by v1, v2"), - Row("one", 1, 1) :: Row("one", 1, 5) :: Row("two", 2, 22) :: Nil) - - checkAnswer( - sql("SELECT count(*) FROM nt1 natural full outer join nt2"), - Row(4) :: Nil) - } - } - test("join with using clause") { val df1 = Seq(("r1c1", "r1c2", "t1r1c3"), ("r2c1", "r2c2", "t1r2c3"), ("r3c1x", "r3c2", "t1r3c3")).toDF("c1", "c2", "c3") @@ -2417,10 +2218,10 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { ("r2c1", "r2c2", "t2r2c3"), ("r3c1y", "r3c2", "t2r3c3")).toDF("c1", "c2", "c3") val df3 = Seq((null, "r1c2", "t3r1c3"), ("r2c1", "r2c2", "t3r2c3"), ("r3c1y", "r3c2", "t3r3c3")).toDF("c1", "c2", "c3") - withTempTable("t1", "t2", "t3") { - df1.registerTempTable("t1") - df2.registerTempTable("t2") - df3.registerTempTable("t3") + withTempView("t1", "t2", "t3") { + df1.createOrReplaceTempView("t1") + df2.createOrReplaceTempView("t2") + df3.createOrReplaceTempView("t3") // inner join with one using column checkAnswer( sql("SELECT * FROM t1 join t2 using (c1)"), @@ -2473,4 +2274,383 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Row("r3c1x", "r3c2", "t1r3c3", "r3c2", "t1r3c3") :: Nil) } } + + test("SPARK-15327: fail to compile generated code with complex data structure") { + withTempDir{ dir => + val json = + """ + |{"h": {"b": {"c": [{"e": "adfgd"}], "a": [{"e": "testing", "count": 3}], + |"b": [{"e": "test", "count": 1}]}}, "d": {"b": {"c": [{"e": "adfgd"}], + |"a": [{"e": "testing", "count": 3}], "b": [{"e": "test", "count": 1}]}}, + |"c": {"b": {"c": [{"e": "adfgd"}], "a": [{"count": 3}], + |"b": [{"e": "test", "count": 1}]}}, "a": {"b": {"c": [{"e": "adfgd"}], + |"a": [{"count": 3}], "b": [{"e": "test", "count": 1}]}}, + |"e": {"b": 
{"c": [{"e": "adfgd"}], "a": [{"e": "testing", "count": 3}], + |"b": [{"e": "test", "count": 1}]}}, "g": {"b": {"c": [{"e": "adfgd"}], + |"a": [{"e": "testing", "count": 3}], "b": [{"e": "test", "count": 1}]}}, + |"f": {"b": {"c": [{"e": "adfgd"}], "a": [{"e": "testing", "count": 3}], + |"b": [{"e": "test", "count": 1}]}}, "b": {"b": {"c": [{"e": "adfgd"}], + |"a": [{"count": 3}], "b": [{"e": "test", "count": 1}]}}}' + | + """.stripMargin + val rdd = sparkContext.parallelize(Array(json)) + spark.read.json(rdd).write.mode("overwrite").parquet(dir.toString) + spark.read.parquet(dir.toString).collect() + } + } + + test("SPARK-14986: Outer lateral view with empty generate expression") { + checkAnswer( + sql("select nil from (select 1 as x ) x lateral view outer explode(array()) n as nil"), + Row(null) :: Nil + ) + } + + test("data source table created in InMemoryCatalog should be able to read/write") { + withTable("tbl") { + sql("CREATE TABLE tbl(i INT, j STRING) USING parquet") + checkAnswer(sql("SELECT i, j FROM tbl"), Nil) + + Seq(1 -> "a", 2 -> "b").toDF("i", "j").write.mode("overwrite").insertInto("tbl") + checkAnswer(sql("SELECT i, j FROM tbl"), Row(1, "a") :: Row(2, "b") :: Nil) + + Seq(3 -> "c", 4 -> "d").toDF("i", "j").write.mode("append").saveAsTable("tbl") + checkAnswer( + sql("SELECT i, j FROM tbl"), + Row(1, "a") :: Row(2, "b") :: Row(3, "c") :: Row(4, "d") :: Nil) + } + } + + test("Eliminate noop ordinal ORDER BY") { + withSQLConf(SQLConf.ORDER_BY_ORDINAL.key -> "true") { + val plan1 = sql("SELECT 1.0, 'abc', year(current_date()) ORDER BY 1, 2, 3") + val plan2 = sql("SELECT 1.0, 'abc', year(current_date())") + comparePlans(plan1.queryExecution.optimizedPlan, plan2.queryExecution.optimizedPlan) + } + } + + test("check code injection is prevented") { + // The end of comment (*/) should be escaped. + var literal = + """|*/ + |{ + | new Object() { + | void f() { throw new RuntimeException("This exception is injected."); } + | }.f(); + |} + |/*""".stripMargin + var expected = + """|*/ + |{ + | new Object() { + | void f() { throw new RuntimeException("This exception is injected."); } + | }.f(); + |} + |/*""".stripMargin + checkAnswer( + sql(s"SELECT '$literal' AS DUMMY"), + Row(s"$expected") :: Nil) + + // `\u002A` is `*` and `\u002F` is `/` + // so if the end of comment consists of those characters in queries, we need to escape them. 
+ literal = + """|\\u002A/ + |{ + | new Object() { + | void f() { throw new RuntimeException("This exception is injected."); } + | }.f(); + |} + |/*""".stripMargin + expected = + s"""|${"\\u002A/"} + |{ + | new Object() { + | void f() { throw new RuntimeException("This exception is injected."); } + | }.f(); + |} + |/*""".stripMargin + checkAnswer( + sql(s"SELECT '$literal' AS DUMMY"), + Row(s"$expected") :: Nil) + + literal = + """|\\\\u002A/ + |{ + | new Object() { + | void f() { throw new RuntimeException("This exception is injected."); } + | }.f(); + |} + |/*""".stripMargin + expected = + """|\\u002A/ + |{ + | new Object() { + | void f() { throw new RuntimeException("This exception is injected."); } + | }.f(); + |} + |/*""".stripMargin + checkAnswer( + sql(s"SELECT '$literal' AS DUMMY"), + Row(s"$expected") :: Nil) + + literal = + """|\\u002a/ + |{ + | new Object() { + | void f() { throw new RuntimeException("This exception is injected."); } + | }.f(); + |} + |/*""".stripMargin + expected = + s"""|${"\\u002a/"} + |{ + | new Object() { + | void f() { throw new RuntimeException("This exception is injected."); } + | }.f(); + |} + |/*""".stripMargin + checkAnswer( + sql(s"SELECT '$literal' AS DUMMY"), + Row(s"$expected") :: Nil) + + literal = + """|\\\\u002a/ + |{ + | new Object() { + | void f() { throw new RuntimeException("This exception is injected."); } + | }.f(); + |} + |/*""".stripMargin + expected = + """|\\u002a/ + |{ + | new Object() { + | void f() { throw new RuntimeException("This exception is injected."); } + | }.f(); + |} + |/*""".stripMargin + checkAnswer( + sql(s"SELECT '$literal' AS DUMMY"), + Row(s"$expected") :: Nil) + + literal = + """|*\\u002F + |{ + | new Object() { + | void f() { throw new RuntimeException("This exception is injected."); } + | }.f(); + |} + |/*""".stripMargin + expected = + s"""|${"*\\u002F"} + |{ + | new Object() { + | void f() { throw new RuntimeException("This exception is injected."); } + | }.f(); + |} + |/*""".stripMargin + checkAnswer( + sql(s"SELECT '$literal' AS DUMMY"), + Row(s"$expected") :: Nil) + + literal = + """|*\\\\u002F + |{ + | new Object() { + | void f() { throw new RuntimeException("This exception is injected."); } + | }.f(); + |} + |/*""".stripMargin + expected = + """|*\\u002F + |{ + | new Object() { + | void f() { throw new RuntimeException("This exception is injected."); } + | }.f(); + |} + |/*""".stripMargin + checkAnswer( + sql(s"SELECT '$literal' AS DUMMY"), + Row(s"$expected") :: Nil) + + literal = + """|*\\u002f + |{ + | new Object() { + | void f() { throw new RuntimeException("This exception is injected."); } + | }.f(); + |} + |/*""".stripMargin + expected = + s"""|${"*\\u002f"} + |{ + | new Object() { + | void f() { throw new RuntimeException("This exception is injected."); } + | }.f(); + |} + |/*""".stripMargin + checkAnswer( + sql(s"SELECT '$literal' AS DUMMY"), + Row(s"$expected") :: Nil) + + literal = + """|*\\\\u002f + |{ + | new Object() { + | void f() { throw new RuntimeException("This exception is injected."); } + | }.f(); + |} + |/*""".stripMargin + expected = + """|*\\u002f + |{ + | new Object() { + | void f() { throw new RuntimeException("This exception is injected."); } + | }.f(); + |} + |/*""".stripMargin + checkAnswer( + sql(s"SELECT '$literal' AS DUMMY"), + Row(s"$expected") :: Nil) + + literal = + """|\\u002A\\u002F + |{ + | new Object() { + | void f() { throw new RuntimeException("This exception is injected."); } + | }.f(); + |} + |/*""".stripMargin + expected = + s"""|${"\\u002A\\u002F"} + |{ + | new 
Object() { + | void f() { throw new RuntimeException("This exception is injected."); } + | }.f(); + |} + |/*""".stripMargin + checkAnswer( + sql(s"SELECT '$literal' AS DUMMY"), + Row(s"$expected") :: Nil) + + literal = + """|\\\\u002A\\u002F + |{ + | new Object() { + | void f() { throw new RuntimeException("This exception is injected."); } + | }.f(); + |} + |/*""".stripMargin + expected = + s"""|${"\\\\u002A\\u002F"} + |{ + | new Object() { + | void f() { throw new RuntimeException("This exception is injected."); } + | }.f(); + |} + |/*""".stripMargin + checkAnswer( + sql(s"SELECT '$literal' AS DUMMY"), + Row(s"$expected") :: Nil) + + literal = + """|\\u002A\\\\u002F + |{ + | new Object() { + | void f() { throw new RuntimeException("This exception is injected."); } + | }.f(); + |} + |/*""".stripMargin + expected = + s"""|${"\\u002A\\\\u002F"} + |{ + | new Object() { + | void f() { throw new RuntimeException("This exception is injected."); } + | }.f(); + |} + |/*""".stripMargin + checkAnswer( + sql(s"SELECT '$literal' AS DUMMY"), + Row(s"$expected") :: Nil) + + literal = + """|\\\\u002A\\\\u002F + |{ + | new Object() { + | void f() { throw new RuntimeException("This exception is injected."); } + | }.f(); + |} + |/*""".stripMargin + expected = + """|\\u002A\\u002F + |{ + | new Object() { + | void f() { throw new RuntimeException("This exception is injected."); } + | }.f(); + |} + |/*""".stripMargin + checkAnswer( + sql(s"SELECT '$literal' AS DUMMY"), + Row(s"$expected") :: Nil) + } + + test("SPARK-16975: Column-partition path starting '_' should be handled correctly") { + withTempDir { dir => + val parquetDir = new File(dir, "parquet").getCanonicalPath + spark.range(10).withColumn("_col", $"id").write.partitionBy("_col").save(parquetDir) + spark.read.parquet(parquetDir) + } + } + + test("SPARK-16644: Aggregate should not put aggregate expressions to constraints") { + withTable("tbl") { + sql("CREATE TABLE tbl(a INT, b INT) USING parquet") + checkAnswer(sql( + """ + |SELECT + | a, + | MAX(b) AS c1, + | b AS c2 + |FROM tbl + |WHERE a = b + |GROUP BY a, b + |HAVING c1 = 1 + """.stripMargin), Nil) + } + } + + test("SPARK-16674: field names containing dots for both fields and partitioned fields") { + withTempPath { path => + val data = (1 to 10).map(i => (i, s"data-$i", i % 2, if ((i % 2) == 0) "a" else "b")) + .toDF("col.1", "col.2", "part.col1", "part.col2") + data.write + .format("parquet") + .partitionBy("part.col1", "part.col2") + .save(path.getCanonicalPath) + val readBack = spark.read.format("parquet").load(path.getCanonicalPath) + checkAnswer( + readBack.selectExpr("`part.col1`", "`col.1`"), + data.selectExpr("`part.col1`", "`col.1`")) + } + } + + test("SPARK-17515: CollectLimit.execute() should perform per-partition limits") { + val numRecordsRead = spark.sparkContext.longAccumulator + spark.range(1, 100, 1, numPartitions = 10).map { x => + numRecordsRead.add(1) + x + }.limit(1).queryExecution.toRdd.count() + assert(numRecordsRead.value === 10) + } + + test("CREATE TABLE USING should not fail if a same-name temp view exists") { + withTable("same_name") { + withTempView("same_name") { + spark.range(10).createTempView("same_name") + sql("CREATE TABLE same_name(i int) USING json") + checkAnswer(spark.table("same_name"), spark.range(10).toDF()) + assert(spark.table("default.same_name").collect().isEmpty) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala new file mode 100644 
index 0000000000000..02841d7bb03ff --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -0,0 +1,287 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.io.File +import java.util.{Locale, TimeZone} + +import scala.util.control.NonFatal + +import org.apache.spark.sql.catalyst.planning.PhysicalOperation +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.catalyst.util.{fileToString, stringToFile} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.StructType + +/** + * End-to-end test cases for SQL queries. + * + * Each case is loaded from a file in "spark/sql/core/src/test/resources/sql-tests/inputs". + * Each case has a golden result file in "spark/sql/core/src/test/resources/sql-tests/results". + * + * To run the entire test suite: + * {{{ + * build/sbt "sql/test-only *SQLQueryTestSuite" + * }}} + * + * To run a single test file upon change: + * {{{ + * build/sbt "~sql/test-only *SQLQueryTestSuite -- -z inline-table.sql" + * }}} + * + * To re-generate golden files, run: + * {{{ + * SPARK_GENERATE_GOLDEN_FILES=1 build/sbt "sql/test-only *SQLQueryTestSuite" + * }}} + * + * The format for input files is simple: + * 1. A list of SQL queries separated by semicolon. + * 2. Lines starting with -- are treated as comments and ignored. + * + * For example: + * {{{ + * -- this is a comment + * select 1, -1; + * select current_date; + * }}} + * + * The format for golden result files look roughly like: + * {{{ + * -- some header information + * + * -- !query 0 + * select 1, -1 + * -- !query 0 schema + * struct<...schema...> + * -- !query 0 output + * ... data row 1 ... + * ... data row 2 ... + * ... + * + * -- !query 1 + * ... + * }}} + */ +class SQLQueryTestSuite extends QueryTest with SharedSQLContext { + + private val regenerateGoldenFiles: Boolean = System.getenv("SPARK_GENERATE_GOLDEN_FILES") == "1" + + private val baseResourcePath = { + // If regenerateGoldenFiles is true, we must be running this in SBT and we use hard-coded + // relative path. Otherwise, we use classloader's getResource to find the location. + if (regenerateGoldenFiles) { + java.nio.file.Paths.get("src", "test", "resources", "sql-tests").toFile + } else { + val res = getClass.getClassLoader.getResource("sql-tests") + new File(res.getFile) + } + } + + private val inputFilePath = new File(baseResourcePath, "inputs").getAbsolutePath + private val goldenFilePath = new File(baseResourcePath, "results").getAbsolutePath + + /** List of test cases to ignore, in lower cases. */ + private val blackList = Set( + "blacklist.sql" // Do NOT remove this one. 
It is here to test the blacklist functionality. + ) + + // Create all the test cases. + listTestCases().foreach(createScalaTestCase) + + /** A test case. */ + private case class TestCase(name: String, inputFile: String, resultFile: String) + + /** A single SQL query's output. */ + private case class QueryOutput(sql: String, schema: String, output: String) { + def toString(queryIndex: Int): String = { + // We are explicitly not using multi-line string due to stripMargin removing "|" in output. + s"-- !query $queryIndex\n" + + sql + "\n" + + s"-- !query $queryIndex schema\n" + + schema + "\n" + + s"-- !query $queryIndex output\n" + + output + } + } + + private def createScalaTestCase(testCase: TestCase): Unit = { + if (blackList.contains(testCase.name.toLowerCase)) { + // Create a test case to ignore this case. + ignore(testCase.name) { /* Do nothing */ } + } else { + // Create a test case to run this case. + test(testCase.name) { runTest(testCase) } + } + } + + /** Run a test case. */ + private def runTest(testCase: TestCase): Unit = { + val input = fileToString(new File(testCase.inputFile)) + + // List of SQL queries to run + val queries: Seq[String] = { + val cleaned = input.split("\n").filterNot(_.startsWith("--")).mkString("\n") + // note: this is not a robust way to split queries using semicolon, but works for now. + cleaned.split("(?<=[^\\\\]);").map(_.trim).filter(_ != "").toSeq + } + + // Create a local SparkSession to have stronger isolation between different test cases. + // This does not isolate catalog changes. + val localSparkSession = spark.newSession() + loadTestData(localSparkSession) + + // Run the SQL queries preparing them for comparison. + val outputs: Seq[QueryOutput] = queries.map { sql => + val (schema, output) = getNormalizedResult(localSparkSession, sql) + // We might need to do some query canonicalization in the future. + QueryOutput( + sql = sql, + schema = schema.catalogString, + output = output.mkString("\n").trim) + } + + if (regenerateGoldenFiles) { + // Again, we are explicitly not using multi-line string due to stripMargin removing "|". + val goldenOutput = { + s"-- Automatically generated by ${getClass.getSimpleName}\n" + + s"-- Number of queries: ${outputs.size}\n\n\n" + + outputs.zipWithIndex.map{case (qr, i) => qr.toString(i)}.mkString("\n\n\n") + "\n" + } + stringToFile(new File(testCase.resultFile), goldenOutput) + } + + // Read back the golden file. + val expectedOutputs: Seq[QueryOutput] = { + val goldenOutput = fileToString(new File(testCase.resultFile)) + val segments = goldenOutput.split("-- !query.+\n") + + // each query has 3 segments, plus the header + assert(segments.size == outputs.size * 3 + 1, + s"Expected ${outputs.size * 3 + 1} blocks in result file but got ${segments.size}. " + + s"Try regenerate the result files.") + Seq.tabulate(outputs.size) { i => + QueryOutput( + sql = segments(i * 3 + 1).trim, + schema = segments(i * 3 + 2).trim, + output = segments(i * 3 + 3).trim + ) + } + } + + // Compare results. 
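Before the comparison asserts that follow, the segment arithmetic used when reading the golden file back is worth spelling out: splitting on the `-- !query.+\n` markers yields one leading header segment plus three segments per query (the SQL text, the schema, the output), which is where the `outputs.size * 3 + 1` check comes from. A minimal sketch with a hypothetical two-query golden file:

```scala
// Hypothetical two-query golden file, in the format documented in the class comment.
val golden =
  """-- Automatically generated by SQLQueryTestSuite
    |-- Number of queries: 2
    |
    |
    |-- !query 0
    |select 1
    |-- !query 0 schema
    |struct<1:int>
    |-- !query 0 output
    |1
    |
    |
    |-- !query 1
    |select 2
    |-- !query 1 schema
    |struct<2:int>
    |-- !query 1 output
    |2
    |""".stripMargin

// One header segment plus (sql, schema, output) per query: 2 * 3 + 1 = 7.
val segments = golden.split("-- !query.+\n")
assert(segments.length == 2 * 3 + 1)
```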
+ assertResult(expectedOutputs.size, s"Number of queries should be ${expectedOutputs.size}") { + outputs.size + } + + outputs.zip(expectedOutputs).zipWithIndex.foreach { case ((output, expected), i) => + assertResult(expected.sql, s"SQL query did not match for query #$i\n${expected.sql}") { + output.sql + } + assertResult(expected.schema, s"Schema did not match for query #$i\n${expected.sql}") { + output.schema + } + assertResult(expected.output, s"Result dit not match for query #$i\n${expected.sql}") { + output.output + } + } + } + + /** Executes a query and returns the result as (schema of the output, normalized output). */ + private def getNormalizedResult(session: SparkSession, sql: String): (StructType, Seq[String]) = { + // Returns true if the plan is supposed to be sorted. + def isSorted(plan: LogicalPlan): Boolean = plan match { + case _: Join | _: Aggregate | _: Generate | _: Sample | _: Distinct => false + case PhysicalOperation(_, _, Sort(_, true, _)) => true + case _ => plan.children.iterator.exists(isSorted) + } + + try { + val df = session.sql(sql) + val schema = df.schema + val answer = df.queryExecution.hiveResultString() + + // If the output is not pre-sorted, sort it. + if (isSorted(df.queryExecution.analyzed)) (schema, answer) else (schema, answer.sorted) + + } catch { + case a: AnalysisException if a.plan.nonEmpty => + // Do not output the logical plan tree which contains expression IDs. + (StructType(Seq.empty), Seq(a.getClass.getName, a.getSimpleMessage)) + case NonFatal(e) => + // If there is an exception, put the exception class followed by the message. + (StructType(Seq.empty), Seq(e.getClass.getName, e.getMessage)) + } + } + + private def listTestCases(): Seq[TestCase] = { + listFilesRecursively(new File(inputFilePath)).map { file => + val resultFile = file.getAbsolutePath.replace(inputFilePath, goldenFilePath) + ".out" + TestCase(file.getName, file.getAbsolutePath, resultFile) + } + } + + /** Returns all the files (not directories) in a directory, recursively. */ + private def listFilesRecursively(path: File): Seq[File] = { + val (dirs, files) = path.listFiles().partition(_.isDirectory) + files ++ dirs.flatMap(listFilesRecursively) + } + + /** Load built-in test tables into the SparkSession. 
*/ + private def loadTestData(session: SparkSession): Unit = { + import session.implicits._ + + (1 to 100).map(i => (i, i.toString)).toDF("key", "value").createOrReplaceTempView("testdata") + + ((Seq(1, 2, 3), Seq(Seq(1, 2, 3))) :: (Seq(2, 3, 4), Seq(Seq(2, 3, 4))) :: Nil) + .toDF("arraycol", "nestedarraycol") + .createOrReplaceTempView("arraydata") + + (Tuple1(Map(1 -> "a1", 2 -> "b1", 3 -> "c1", 4 -> "d1", 5 -> "e1")) :: + Tuple1(Map(1 -> "a2", 2 -> "b2", 3 -> "c2", 4 -> "d2")) :: + Tuple1(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")) :: + Tuple1(Map(1 -> "a4", 2 -> "b4")) :: + Tuple1(Map(1 -> "a5")) :: Nil) + .toDF("mapcol") + .createOrReplaceTempView("mapdata") + } + + private val originalTimeZone = TimeZone.getDefault + private val originalLocale = Locale.getDefault + + override def beforeAll(): Unit = { + super.beforeAll() + // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) + TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) + // Add Locale setting + Locale.setDefault(Locale.US) + RuleExecutor.resetTime() + } + + override def afterAll(): Unit = { + try { + TimeZone.setDefault(originalTimeZone) + Locale.setDefault(originalLocale) + + // For debugging dump some statistics about how much time was spent in various optimizer rules + logWarning(RuleExecutor.dumpTimeSpent()) + } finally { + super.afterAll() + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala index 295f02f9a7b5d..c9bd05d0e4e36 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ -34,7 +34,9 @@ case class ReflectData( decimalField: java.math.BigDecimal, date: Date, timestampField: Timestamp, - seqInt: Seq[Int]) + seqInt: Seq[Int], + javaBigInt: java.math.BigInteger, + scalaBigInt: scala.math.BigInt) case class NullReflectData( intField: java.lang.Integer, @@ -77,18 +79,20 @@ class ScalaReflectionRelationSuite extends SparkFunSuite with SharedSQLContext { test("query case class RDD") { val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, - new java.math.BigDecimal(1), Date.valueOf("1970-01-01"), new Timestamp(12345), Seq(1, 2, 3)) - Seq(data).toDF().registerTempTable("reflectData") + new java.math.BigDecimal(1), Date.valueOf("1970-01-01"), new Timestamp(12345), Seq(1, 2, 3), + new java.math.BigInteger("1"), scala.math.BigInt(1)) + Seq(data).toDF().createOrReplaceTempView("reflectData") assert(sql("SELECT * FROM reflectData").collect().head === Row("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, new java.math.BigDecimal(1), Date.valueOf("1970-01-01"), - new Timestamp(12345), Seq(1, 2, 3))) + new Timestamp(12345), Seq(1, 2, 3), new java.math.BigDecimal(1), + new java.math.BigDecimal(1))) } test("query case class RDD with nulls") { val data = NullReflectData(null, null, null, null, null, null, null) - Seq(data).toDF().registerTempTable("reflectNullData") + Seq(data).toDF().createOrReplaceTempView("reflectNullData") assert(sql("SELECT * FROM reflectNullData").collect().head === Row.fromSeq(Seq.fill(7)(null))) @@ -96,7 +100,7 @@ class ScalaReflectionRelationSuite extends SparkFunSuite with SharedSQLContext { test("query case class RDD with Nones") { val data = OptionalReflectData(None, None, None, None, None, None, None) - Seq(data).toDF().registerTempTable("reflectOptionalData") 
+ Seq(data).toDF().createOrReplaceTempView("reflectOptionalData") assert(sql("SELECT * FROM reflectOptionalData").collect().head === Row.fromSeq(Seq.fill(7)(null))) @@ -104,7 +108,7 @@ class ScalaReflectionRelationSuite extends SparkFunSuite with SharedSQLContext { // Equality is broken for Arrays, so we test that separately. test("query binary data") { - Seq(ReflectBinary(Array[Byte](1))).toDF().registerTempTable("reflectBinary") + Seq(ReflectBinary(Array[Byte](1))).toDF().createOrReplaceTempView("reflectBinary") val result = sql("SELECT data FROM reflectBinary") .collect().head(0).asInstanceOf[Array[Byte]] @@ -124,7 +128,7 @@ class ScalaReflectionRelationSuite extends SparkFunSuite with SharedSQLContext { Map(10 -> Some(100L), 20 -> Some(200L), 30 -> None), Nested(None, "abc"))) - Seq(data).toDF().registerTempTable("reflectComplexData") + Seq(data).toDF().createOrReplaceTempView("reflectComplexData") assert(sql("SELECT * FROM reflectComplexData").collect().head === Row( Seq(1, 2, 3), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala index ddab918629645..cd6b2647e0be6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.test.SharedSQLContext class SerializationSuite extends SparkFunSuite with SharedSQLContext { test("[SPARK-5235] SQLContext should be serializable") { - val _sqlContext = new SQLContext(sparkContext) - new JavaSerializer(new SparkConf()).newInstance().serialize(_sqlContext) + val spark = SparkSession.builder.getOrCreate() + new JavaSerializer(new SparkConf()).newInstance().serialize(spark.sqlContext) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala new file mode 100644 index 0000000000000..386d13d07a95f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} + +/** + * Test cases for the builder pattern of [[SparkSession]]. 
+ */ +class SparkSessionBuilderSuite extends SparkFunSuite { + + private var initialSession: SparkSession = _ + + private lazy val sparkContext: SparkContext = { + initialSession = SparkSession.builder() + .master("local") + .config("spark.ui.enabled", value = false) + .config("some-config", "v2") + .getOrCreate() + initialSession.sparkContext + } + + test("create with config options and propagate them to SparkContext and SparkSession") { + // Creating a new session with config - this works by just calling the lazy val + sparkContext + assert(initialSession.sparkContext.conf.get("some-config") == "v2") + assert(initialSession.conf.get("some-config") == "v2") + SparkSession.clearDefaultSession() + } + + test("use global default session") { + val session = SparkSession.builder().getOrCreate() + assert(SparkSession.builder().getOrCreate() == session) + SparkSession.clearDefaultSession() + } + + test("config options are propagated to existing SparkSession") { + val session1 = SparkSession.builder().config("spark-config1", "a").getOrCreate() + assert(session1.conf.get("spark-config1") == "a") + val session2 = SparkSession.builder().config("spark-config1", "b").getOrCreate() + assert(session1 == session2) + assert(session1.conf.get("spark-config1") == "b") + SparkSession.clearDefaultSession() + } + + test("use session from active thread session and propagate config options") { + val defaultSession = SparkSession.builder().getOrCreate() + val activeSession = defaultSession.newSession() + SparkSession.setActiveSession(activeSession) + val session = SparkSession.builder().config("spark-config2", "a").getOrCreate() + + assert(activeSession != defaultSession) + assert(session == activeSession) + assert(session.conf.get("spark-config2") == "a") + SparkSession.clearActiveSession() + + assert(SparkSession.builder().getOrCreate() == defaultSession) + SparkSession.clearDefaultSession() + } + + test("create a new session if the default session has been stopped") { + val defaultSession = SparkSession.builder().getOrCreate() + SparkSession.setDefaultSession(defaultSession) + defaultSession.stop() + val newSession = SparkSession.builder().master("local").getOrCreate() + assert(newSession != defaultSession) + newSession.stop() + } + + test("create a new session if the active thread session has been stopped") { + val activeSession = SparkSession.builder().master("local").getOrCreate() + SparkSession.setActiveSession(activeSession) + activeSession.stop() + val newSession = SparkSession.builder().master("local").getOrCreate() + assert(newSession != activeSession) + newSession.stop() + } + + test("create SparkContext first then SparkSession") { + sparkContext.stop() + val conf = new SparkConf().setAppName("test").setMaster("local").set("key1", "value1") + val sparkContext2 = new SparkContext(conf) + val session = SparkSession.builder().config("key2", "value2").getOrCreate() + assert(session.conf.get("key1") == "value1") + assert(session.conf.get("key2") == "value2") + assert(session.sparkContext.conf.get("key1") == "value1") + assert(session.sparkContext.conf.get("key2") == "value2") + assert(session.sparkContext.conf.get("spark.app.name") == "test") + session.stop() + } + + test("SPARK-15887: hive-site.xml should be loaded") { + val session = SparkSession.builder().master("local").getOrCreate() + assert(session.sessionState.newHadoopConf().get("hive.in.test") == "true") + assert(session.sparkContext.hadoopConfiguration.get("hive.in.test") == "true") + session.stop() + } + + test("SPARK-15991: Set global Hadoop 
conf") { + val session = SparkSession.builder().master("local").getOrCreate() + val mySpecialKey = "my.special.key.15991" + val mySpecialValue = "msv" + try { + session.sparkContext.hadoopConfiguration.set(mySpecialKey, mySpecialValue) + assert(session.sessionState.newHadoopConf().get(mySpecialKey) == mySpecialValue) + } finally { + session.sparkContext.hadoopConfiguration.unset(mySpecialKey) + session.stop() + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala new file mode 100644 index 0000000000000..2c81cbf15f088 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.plans.logical.{GlobalLimit, Join, LocalLimit} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ + +class StatisticsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("SPARK-15392: DataFrame created from RDD should not be broadcasted") { + val rdd = sparkContext.range(1, 100).map(i => Row(i, i)) + val df = spark.createDataFrame(rdd, new StructType().add("a", LongType).add("b", LongType)) + assert(df.queryExecution.analyzed.statistics.sizeInBytes > + spark.sessionState.conf.autoBroadcastJoinThreshold) + assert(df.selectExpr("a").queryExecution.analyzed.statistics.sizeInBytes > + spark.sessionState.conf.autoBroadcastJoinThreshold) + } + + test("estimates the size of limit") { + withTempView("test") { + Seq(("one", 1), ("two", 2), ("three", 3), ("four", 4)).toDF("k", "v") + .createOrReplaceTempView("test") + Seq((0, 1), (1, 24), (2, 48)).foreach { case (limit, expected) => + val df = sql(s"""SELECT * FROM test limit $limit""") + + val sizesGlobalLimit = df.queryExecution.analyzed.collect { case g: GlobalLimit => + g.statistics.sizeInBytes + } + assert(sizesGlobalLimit.size === 1, s"Size wrong for:\n ${df.queryExecution}") + assert(sizesGlobalLimit.head === BigInt(expected), + s"expected exact size $expected for table 'test', got: ${sizesGlobalLimit.head}") + + val sizesLocalLimit = df.queryExecution.analyzed.collect { case l: LocalLimit => + l.statistics.sizeInBytes + } + assert(sizesLocalLimit.size === 1, s"Size wrong for:\n ${df.queryExecution}") + assert(sizesLocalLimit.head === BigInt(expected), + s"expected exact size $expected for table 'test', got: ${sizesLocalLimit.head}") + } + } + } + + test("estimates the size of a limit 0 on outer join") { + withTempView("test") { + Seq(("one", 1), ("two", 2), ("three", 3), ("four", 4)).toDF("k", "v") + .createOrReplaceTempView("test") + val df1 = spark.table("test") + val df2 = 
spark.table("test").limit(0) + val df = df1.join(df2, Seq("k"), "left") + + val sizes = df.queryExecution.analyzed.collect { case g: Join => + g.statistics.sizeInBytes + } + + assert(sizes.size === 1, s"number of Join nodes is wrong:\n ${df.queryExecution}") + assert(sizes.head === BigInt(96), + s"expected exact size 96 for table 'test', got: ${sizes.head}") + } + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index 6809f26968836..9be2de9c7d719 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -48,6 +48,20 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext { Row("a||b")) } + test("string elt") { + val df = Seq[(String, String, String, Int)](("hello", "world", null, 15)) + .toDF("a", "b", "c", "d") + + checkAnswer( + df.selectExpr("elt(0, a, b, c)", "elt(1, a, b, c)", "elt(4, a, b, c)"), + Row(null, "hello", null)) + + // check implicit type cast + checkAnswer( + df.selectExpr("elt(4, a, b, c, d)", "elt('2', a, b, c, d)"), + Row("15", "world")) + } + test("string Levenshtein distance") { val df = Seq(("kitten", "sitting"), ("frog", "fog")).toDF("l", "r") checkAnswer(df.select(levenshtein($"l", $"r")), Seq(Row(3), Row(1))) @@ -78,6 +92,18 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext { Row("300", "100") :: Row("400", "100") :: Row("400-400", "100") :: Nil) } + test("non-matching optional group") { + val df = Seq(Tuple1("aaaac")).toDF("s") + checkAnswer( + df.select(regexp_extract($"s", "(foo)", 1)), + Row("") + ) + checkAnswer( + df.select(regexp_extract($"s", "(a+)(b)?(c)", 2)), + Row("") + ) + } + test("string ascii function") { val df = Seq(("abc", "")).toDF("a", "b") checkAnswer( @@ -189,15 +215,15 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext { } test("string locate function") { - val df = Seq(("aaads", "aa", "zz", 1)).toDF("a", "b", "c", "d") + val df = Seq(("aaads", "aa", "zz", 2)).toDF("a", "b", "c", "d") checkAnswer( - df.select(locate("aa", $"a"), locate("aa", $"a", 1)), - Row(1, 2)) + df.select(locate("aa", $"a"), locate("aa", $"a", 2), locate("aa", $"a", 0)), + Row(1, 2, 0)) checkAnswer( - df.selectExpr("locate(b, a)", "locate(b, a, d)"), - Row(1, 2)) + df.selectExpr("locate(b, a)", "locate(b, a, d)", "locate(b, a, 3)"), + Row(1, 2, 0)) } test("string padding functions") { @@ -212,6 +238,21 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext { Row("???hi", "hi???", "h", "h")) } + test("string parse_url function") { + val df = Seq[String](("http://userinfo@spark.apache.org/path?query=1#Ref")) + .toDF("url") + + checkAnswer( + df.selectExpr( + "parse_url(url, 'HOST')", "parse_url(url, 'PATH')", + "parse_url(url, 'QUERY')", "parse_url(url, 'REF')", + "parse_url(url, 'PROTOCOL')", "parse_url(url, 'FILE')", + "parse_url(url, 'AUTHORITY')", "parse_url(url, 'USERINFO')", + "parse_url(url, 'QUERY', 'query')"), + Row("spark.apache.org", "/path", "query=1", "Ref", + "http", "/path?query=1", "userinfo@spark.apache.org", "userinfo", "1")) + } + test("string repeat function") { val df = Seq(("hi", 2)).toDF("a", "b") @@ -281,7 +322,7 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext { } test("number format function") { - val df = sqlContext.range(1) + val df = spark.range(1) checkAnswer( df.select(format_number(lit(5L), 4)), @@ -333,4 +374,47 @@ class 
StringFunctionsSuite extends QueryTest with SharedSQLContext { df2.filter("b>0").selectExpr("format_number(a, b)"), Row("5.0000") :: Row("4.000") :: Row("4.000") :: Row("4.000") :: Row("3.00") :: Nil) } + + test("string sentences function") { + val df = Seq(("Hi there! The price was $1,234.56.... But, not now.", "en", "US")) + .toDF("str", "language", "country") + + checkAnswer( + df.selectExpr("sentences(str, language, country)"), + Row(Seq(Seq("Hi", "there"), Seq("The", "price", "was"), Seq("But", "not", "now")))) + + // Type coercion + checkAnswer( + df.selectExpr("sentences(null)", "sentences(10)", "sentences(3.14)"), + Row(null, Seq(Seq("10")), Seq(Seq("3.14")))) + + // Argument number exception + val m = intercept[AnalysisException] { + df.selectExpr("sentences()") + }.getMessage + assert(m.contains("Invalid number of arguments for function sentences")) + } + + test("str_to_map function") { + val df1 = Seq( + ("a=1,b=2", "y"), + ("a=1,b=2,c=3", "y") + ).toDF("a", "b") + + checkAnswer( + df1.selectExpr("str_to_map(a,',','=')"), + Seq( + Row(Map("a" -> "1", "b" -> "2")), + Row(Map("a" -> "1", "b" -> "2", "c" -> "3")) + ) + ) + + val df2 = Seq(("a:1,b:2,c:3", "y")).toDF("a", "b") + + checkAnswer( + df2.selectExpr("str_to_map(a)"), + Seq(Row(Map("a" -> "1", "b" -> "2", "c" -> "3"))) + ) + + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 0bf4c6f960639..afed342ff8e2a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -49,9 +49,13 @@ class SubquerySuite extends QueryTest with SharedSQLContext { protected override def beforeAll(): Unit = { super.beforeAll() - l.registerTempTable("l") - r.registerTempTable("r") - t.registerTempTable("t") + l.createOrReplaceTempView("l") + r.createOrReplaceTempView("r") + t.createOrReplaceTempView("t") + } + + test("rdd deserialization does not crash [SPARK-15791]") { + sql("select (select 1 as b) as b").rdd.count() } test("simple uncorrelated scalar subquery") { @@ -99,7 +103,7 @@ class SubquerySuite extends QueryTest with SharedSQLContext { test("uncorrelated scalar subquery on a DataFrame generated query") { val df = Seq((1, "one"), (2, "two"), (3, "three")).toDF("key", "value") - df.registerTempTable("subqueryData") + df.createOrReplaceTempView("subqueryData") checkAnswer( sql("select (select key from subqueryData where key > 2 order by key limit 1) + 1"), @@ -123,6 +127,33 @@ class SubquerySuite extends QueryTest with SharedSQLContext { ) } + test("SPARK-15677: Queries against local relations with scalar subquery in Select list") { + withTempView("t1", "t2") { + Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t1") + Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t2") + + checkAnswer( + sql("SELECT (select 1 as col) from t1"), + Row(1) :: Row(1) :: Nil) + + checkAnswer( + sql("SELECT (select max(c1) from t2) from t1"), + Row(2) :: Row(2) :: Nil) + + checkAnswer( + sql("SELECT 1 + (select 1 as col) from t1"), + Row(2) :: Row(2) :: Nil) + + checkAnswer( + sql("SELECT c1, (select max(c1) from t2) + c2 from t1"), + Row(1, 3) :: Row(2, 4) :: Nil) + + checkAnswer( + sql("SELECT c1, (select max(c1) from t2 where t1.c2 = t2.c2) from t1"), + Row(1, 1) :: Row(2, 2) :: Nil) + } + } + test("SPARK-14791: scalar subquery inside broadcast join") { val df = sql("select a, sum(b) as s from l group by a having a > (select avg(a) from l)") 
val expected = Row(3, 2.0, 3, 3.0) :: Row(6, null, 6, null) :: Nil @@ -152,6 +183,19 @@ class SubquerySuite extends QueryTest with SharedSQLContext { Row(null, null) :: Row(null, 5.0) :: Row(6, null) :: Nil) } + test("EXISTS predicate subquery within OR") { + checkAnswer( + sql("select * from l where exists (select * from r where l.a = r.c)" + + " or exists (select * from r where l.a = r.c)"), + Row(2, 1.0) :: Row(2, 1.0) :: Row(3, 3.0) :: Row(6, null) :: Nil) + + checkAnswer( + sql("select * from l where not exists (select * from r where l.a = r.c and l.b < r.d)" + + " or not exists (select * from r where l.a = r.c)"), + Row(1, 2.0) :: Row(1, 2.0) :: Row(3, 3.0) :: + Row(null, null) :: Row(null, 5.0) :: Row(6, null) :: Nil) + } + test("IN predicate subquery") { checkAnswer( sql("select * from l where l.a in (select c from r)"), @@ -187,6 +231,18 @@ class SubquerySuite extends QueryTest with SharedSQLContext { } + test("IN predicate subquery within OR") { + checkAnswer( + sql("select * from l where l.a in (select c from r)" + + " or l.a in (select c from r where l.b < r.d)"), + Row(2, 1.0) :: Row(2, 1.0) :: Row(3, 3.0) :: Row(6, null) :: Nil) + + intercept[AnalysisException] { + sql("select * from l where a not in (select c from r)" + + " or a not in (select c from r where c is not null)") + } + } + test("complex IN predicate subquery") { checkAnswer( sql("select * from l where (a, b) not in (select c, d from r)"), @@ -209,4 +265,310 @@ class SubquerySuite extends QueryTest with SharedSQLContext { sql("select a from l group by 1 having exists (select 1 from r where d < min(b))"), Row(null) :: Row(1) :: Row(3) :: Nil) } + + test("SPARK-15832: Test embedded existential predicate sub-queries") { + withTempView("t1", "t2", "t3", "t4", "t5") { + Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t1") + Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t2") + Seq((1, 1), (2, 2), (1, 2)).toDF("c1", "c2").createOrReplaceTempView("t3") + + checkAnswer( + sql( + """ + | select c1 from t1 + | where c2 IN (select c2 from t2) + | + """.stripMargin), + Row(1) :: Row(2) :: Nil) + + checkAnswer( + sql( + """ + | select c1 from t1 + | where c2 NOT IN (select c2 from t2) + | + """.stripMargin), + Nil) + + checkAnswer( + sql( + """ + | select c1 from t1 + | where EXISTS (select c2 from t2) + | + """.stripMargin), + Row(1) :: Row(2) :: Nil) + + checkAnswer( + sql( + """ + | select c1 from t1 + | where NOT EXISTS (select c2 from t2) + | + """.stripMargin), + Nil) + + checkAnswer( + sql( + """ + | select c1 from t1 + | where NOT EXISTS (select c2 from t2) and + | c2 IN (select c2 from t3) + | + """.stripMargin), + Nil) + + checkAnswer( + sql( + """ + | select c1 from t1 + | where (case when c2 IN (select 1 as one) then 1 + | else 2 end) = c1 + | + """.stripMargin), + Row(1) :: Row(2) :: Nil) + + checkAnswer( + sql( + """ + | select c1 from t1 + | where (case when c2 IN (select 1 as one) then 1 + | else 2 end) + | IN (select c2 from t2) + | + """.stripMargin), + Row(1) :: Row(2) :: Nil) + + checkAnswer( + sql( + """ + | select c1 from t1 + | where (case when c2 IN (select c2 from t2) then 1 + | else 2 end) + | IN (select c2 from t3) + | + """.stripMargin), + Row(1) :: Row(2) :: Nil) + + checkAnswer( + sql( + """ + | select c1 from t1 + | where (case when c2 IN (select c2 from t2) then 1 + | when c2 IN (select c2 from t3) then 2 + | else 3 end) + | IN (select c2 from t1) + | + """.stripMargin), + Row(1) :: Row(2) :: Nil) + + checkAnswer( + sql( + """ + | select c1 from t1 + | where (c1, 
(case when c2 IN (select c2 from t2) then 1 + | when c2 IN (select c2 from t3) then 2 + | else 3 end)) + | IN (select c1, c2 from t1) + | + """.stripMargin), + Row(1) :: Nil) + + checkAnswer( + sql( + """ + | select c1 from t3 + | where ((case when c2 IN (select c2 from t2) then 1 else 2 end), + | (case when c2 IN (select c2 from t3) then 2 else 3 end)) + | IN (select c1, c2 from t3) + | + """.stripMargin), + Row(1) :: Row(2) :: Row(1) :: Nil) + + checkAnswer( + sql( + """ + | select c1 from t1 + | where ((case when EXISTS (select c2 from t2) then 1 else 2 end), + | (case when c2 IN (select c2 from t3) then 2 else 3 end)) + | IN (select c1, c2 from t3) + | + """.stripMargin), + Row(1) :: Row(2) :: Nil) + + checkAnswer( + sql( + """ + | select c1 from t1 + | where (case when c2 IN (select c2 from t2) then 3 + | else 2 end) + | NOT IN (select c2 from t3) + | + """.stripMargin), + Row(1) :: Row(2) :: Nil) + + checkAnswer( + sql( + """ + | select c1 from t1 + | where ((case when c2 IN (select c2 from t2) then 1 else 2 end), + | (case when NOT EXISTS (select c2 from t3) then 2 + | when EXISTS (select c2 from t2) then 3 + | else 3 end)) + | NOT IN (select c1, c2 from t3) + | + """.stripMargin), + Row(1) :: Row(2) :: Nil) + + checkAnswer( + sql( + """ + | select c1 from t1 + | where (select max(c1) from t2 where c2 IN (select c2 from t3)) + | IN (select c2 from t2) + | + """.stripMargin), + Row(1) :: Row(2) :: Nil) + } + } + + test("correlated scalar subquery in where") { + checkAnswer( + sql("select * from l where b < (select max(d) from r where a = c)"), + Row(2, 1.0) :: Row(2, 1.0) :: Nil) + } + + test("correlated scalar subquery in select") { + checkAnswer( + sql("select a, (select sum(b) from l l2 where l2.a = l1.a) sum_b from l l1"), + Row(1, 4.0) :: Row(1, 4.0) :: Row(2, 2.0) :: Row(2, 2.0) :: Row(3, 3.0) :: + Row(null, null) :: Row(null, null) :: Row(6, null) :: Nil) + } + + test("correlated scalar subquery in select (null safe)") { + checkAnswer( + sql("select a, (select sum(b) from l l2 where l2.a <=> l1.a) sum_b from l l1"), + Row(1, 4.0) :: Row(1, 4.0) :: Row(2, 2.0) :: Row(2, 2.0) :: Row(3, 3.0) :: + Row(null, 5.0) :: Row(null, 5.0) :: Row(6, null) :: Nil) + } + + test("correlated scalar subquery in aggregate") { + checkAnswer( + sql("select a, (select sum(d) from r where a = c) sum_d from l l1 group by 1, 2"), + Row(1, null) :: Row(2, 6.0) :: Row(3, 2.0) :: Row(null, null) :: Row(6, null) :: Nil) + } + + test("non-aggregated correlated scalar subquery") { + val msg1 = intercept[AnalysisException] { + sql("select a, (select b from l l2 where l2.a = l1.a) sum_b from l l1") + } + assert(msg1.getMessage.contains("Correlated scalar subqueries must be Aggregated")) + + val msg2 = intercept[AnalysisException] { + sql("select a, (select b from l l2 where l2.a = l1.a group by 1) sum_b from l l1") + } + assert(msg2.getMessage.contains( + "The output of a correlated scalar subquery must be aggregated")) + } + + test("non-equal correlated scalar subquery") { + val msg1 = intercept[AnalysisException] { + sql("select a, (select b from l l2 where l2.a < l1.a) sum_b from l l1") + } + assert(msg1.getMessage.contains( + "The correlated scalar subquery can only contain equality predicates")) + } + + test("disjunctive correlated scalar subquery") { + checkAnswer( + sql(""" + |select a + |from l + |where (select count(*) + | from r + | where (a = c and d = 2.0) or (a = c and d = 1.0)) > 0 + """.stripMargin), + Row(3) :: Nil) + } + + test("SPARK-15370: COUNT bug in WHERE clause (Filter)") { + // Case 1: 
Canonical example of the COUNT bug + checkAnswer( + sql("select l.a from l where (select count(*) from r where l.a = r.c) < l.a"), + Row(1) :: Row(1) :: Row(3) :: Row(6) :: Nil) + // Case 2: count(*) = 0; could be rewritten to NOT EXISTS but currently uses + // a rewrite that is vulnerable to the COUNT bug + checkAnswer( + sql("select l.a from l where (select count(*) from r where l.a = r.c) = 0"), + Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil) + // Case 3: COUNT bug without a COUNT aggregate + checkAnswer( + sql("select l.a from l where (select sum(r.d) is null from r where l.a = r.c)"), + Row(1) :: Row(1) ::Row(null) :: Row(null) :: Row(6) :: Nil) + } + + test("SPARK-15370: COUNT bug in SELECT clause (Project)") { + checkAnswer( + sql("select a, (select count(*) from r where l.a = r.c) as cnt from l"), + Row(1, 0) :: Row(1, 0) :: Row(2, 2) :: Row(2, 2) :: Row(3, 1) :: Row(null, 0) + :: Row(null, 0) :: Row(6, 1) :: Nil) + } + + test("SPARK-15370: COUNT bug in HAVING clause (Filter)") { + checkAnswer( + sql("select l.a as grp_a from l group by l.a " + + "having (select count(*) from r where grp_a = r.c) = 0 " + + "order by grp_a"), + Row(null) :: Row(1) :: Nil) + } + + test("SPARK-15370: COUNT bug in Aggregate") { + checkAnswer( + sql("select l.a as aval, sum((select count(*) from r where l.a = r.c)) as cnt " + + "from l group by l.a order by aval"), + Row(null, 0) :: Row(1, 0) :: Row(2, 4) :: Row(3, 1) :: Row(6, 1) :: Nil) + } + + test("SPARK-15370: COUNT bug negative examples") { + // Case 1: Potential COUNT bug case that was working correctly prior to the fix + checkAnswer( + sql("select l.a from l where (select sum(r.d) from r where l.a = r.c) is null"), + Row(1) :: Row(1) :: Row(null) :: Row(null) :: Row(6) :: Nil) + // Case 2: COUNT aggregate but no COUNT bug due to > 0 test. + checkAnswer( + sql("select l.a from l where (select count(*) from r where l.a = r.c) > 0"), + Row(2) :: Row(2) :: Row(3) :: Row(6) :: Nil) + // Case 3: COUNT inside aggregate expression but no COUNT bug. 
+ checkAnswer( + sql("select l.a from l where (select count(*) + sum(r.d) from r where l.a = r.c) = 0"), + Nil) + } + + test("SPARK-15370: COUNT bug in subquery in subquery in subquery") { + checkAnswer( + sql("""select l.a from l + |where ( + | select cntPlusOne + 1 as cntPlusTwo from ( + | select cnt + 1 as cntPlusOne from ( + | select sum(r.c) s, count(*) cnt from r where l.a = r.c having cnt = 0 + | ) + | ) + |) = 2""".stripMargin), + Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil) + } + + test("SPARK-15370: COUNT bug with nasty predicate expr") { + checkAnswer( + sql("select l.a from l where " + + "(select case when count(*) = 1 then null else count(*) end as cnt " + + "from r where l.a = r.c) = 0"), + Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil) + } + + test("SPARK-15370: COUNT bug with attribute ref in subquery input and output ") { + checkAnswer( + sql("select l.b, (select (r.c + count(*)) is null from r where l.a = r.c) from l"), + Row(1.0, false) :: Row(1.0, false) :: Row(2.0, true) :: Row(2.0, true) :: + Row(3.0, false) :: Row(5.0, true) :: Row(null, false) :: Row(null, true) :: Nil) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index ec950332c5f63..547d3c1abe858 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -26,7 +26,7 @@ class UDFSuite extends QueryTest with SharedSQLContext { import testImplicits._ test("built-in fixed arity expressions") { - val df = sqlContext.emptyDataFrame + val df = spark.emptyDataFrame df.selectExpr("rand()", "randn()", "rand(5)", "randn(50)") } @@ -53,25 +53,25 @@ class UDFSuite extends QueryTest with SharedSQLContext { test("SPARK-8003 spark_partition_id") { val df = Seq((1, "Tearing down the walls that divide us")).toDF("id", "saying") - df.registerTempTable("tmp_table") + df.createOrReplaceTempView("tmp_table") checkAnswer(sql("select spark_partition_id() from tmp_table").toDF(), Row(0)) - sqlContext.dropTempTable("tmp_table") + spark.catalog.dropTempView("tmp_table") } test("SPARK-8005 input_file_name") { withTempPath { dir => val data = sparkContext.parallelize(0 to 10, 2).toDF("id") data.write.parquet(dir.getCanonicalPath) - sqlContext.read.parquet(dir.getCanonicalPath).registerTempTable("test_table") + spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("test_table") val answer = sql("select input_file_name() from test_table").head().getString(0) assert(answer.contains(dir.getCanonicalPath)) assert(sql("select input_file_name() from test_table").distinct().collect().length >= 2) - sqlContext.dropTempTable("test_table") + spark.catalog.dropTempView("test_table") } } test("error reporting for incorrect number of arguments") { - val df = sqlContext.emptyDataFrame + val df = spark.emptyDataFrame val e = intercept[AnalysisException] { df.selectExpr("substr('abcd', 2, 3, 4)") } @@ -79,7 +79,7 @@ class UDFSuite extends QueryTest with SharedSQLContext { } test("error reporting for undefined functions") { - val df = sqlContext.emptyDataFrame + val df = spark.emptyDataFrame val e = intercept[AnalysisException] { df.selectExpr("a_function_that_does_not_exist()") } @@ -88,26 +88,26 @@ class UDFSuite extends QueryTest with SharedSQLContext { } test("Simple UDF") { - sqlContext.udf.register("strLenScala", (_: String).length) + spark.udf.register("strLenScala", (_: String).length) assert(sql("SELECT strLenScala('test')").head().getInt(0) === 4) } 
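The SPARK-15370 cases above all hinge on the so-called COUNT bug: a correlated scalar subquery whose aggregate is count(*) must evaluate to 0 (not NULL) for outer rows that have no matching inner rows, so a naive rewrite of the subquery into an outer join plus aggregate silently changes the result. A minimal stand-alone sketch of that semantics follows; it uses hypothetical two-column views l(a, b) and r(c, d) rather than the suite's own fixtures, and the expected output is stated from standard SQL semantics, not from the patch.

```scala
import org.apache.spark.sql.SparkSession

object CountBugSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("count-bug-sketch")
      .getOrCreate()
    import spark.implicits._

    // Hypothetical fixtures, not the suite's `l`/`r` data.
    Seq((1, 1.0), (2, 1.0), (3, 3.0)).toDF("a", "b").createOrReplaceTempView("l")
    Seq((2, 1.0), (2, 2.0)).toDF("c", "d").createOrReplaceTempView("r")

    // For a = 1 and a = 3 there is no matching r.c, so the correlated count(*)
    // must come out as 0 rather than NULL; both rows therefore satisfy "= 0".
    spark.sql(
      "select l.a from l where (select count(*) from r where l.a = r.c) = 0"
    ).show()
    // Expected: the rows with a = 1 and a = 3. A rewrite that turns the subquery
    // into a plain outer join + aggregate would produce NULL for those groups
    // instead of 0 and drop them -- exactly the bug the tests above guard against.

    spark.stop()
  }
}
```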
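The UDFSuite changes here and below follow one mechanical pattern: the SQLContext-era helpers (sqlContext.udf.register, registerTempTable, sqlContext.dropTempTable, sqlContext.read) are replaced by their SparkSession equivalents. A small, self-contained sketch of the replacement pattern, with illustrative names that are not taken from the suite:

```scala
import org.apache.spark.sql.SparkSession

object SparkSessionApiSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("api-sketch")
      .getOrCreate()
    import spark.implicits._

    // spark.udf.register replaces sqlContext.udf.register.
    spark.udf.register("strLen", (s: String) => s.length)

    // createOrReplaceTempView replaces registerTempTable.
    val df = Seq(("red", 1), ("blue", 10)).toDF("g", "v")
    df.createOrReplaceTempView("groupData")

    // The registered UDF is usable from SQL exactly as before.
    spark.sql("SELECT g, strLen(g) FROM groupData WHERE v > 5").show()

    // spark.catalog.dropTempView replaces sqlContext.dropTempTable.
    spark.catalog.dropTempView("groupData")
    spark.stop()
  }
}
```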
test("ZeroArgument UDF") { - sqlContext.udf.register("random0", () => { Math.random()}) + spark.udf.register("random0", () => { Math.random()}) assert(sql("SELECT random0()").head().getDouble(0) >= 0.0) } test("TwoArgument UDF") { - sqlContext.udf.register("strLenScala", (_: String).length + (_: Int)) + spark.udf.register("strLenScala", (_: String).length + (_: Int)) assert(sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5) } test("UDF in a WHERE") { - sqlContext.udf.register("oneArgFilter", (n: Int) => { n > 80 }) + spark.udf.register("oneArgFilter", (n: Int) => { n > 80 }) val df = sparkContext.parallelize( (1 to 100).map(i => TestData(i, i.toString))).toDF() - df.registerTempTable("integerData") + df.createOrReplaceTempView("integerData") val result = sql("SELECT * FROM integerData WHERE oneArgFilter(key)") @@ -115,11 +115,11 @@ class UDFSuite extends QueryTest with SharedSQLContext { } test("UDF in a HAVING") { - sqlContext.udf.register("havingFilter", (n: Long) => { n > 5 }) + spark.udf.register("havingFilter", (n: Long) => { n > 5 }) val df = Seq(("red", 1), ("red", 2), ("blue", 10), ("green", 100), ("green", 200)).toDF("g", "v") - df.registerTempTable("groupData") + df.createOrReplaceTempView("groupData") val result = sql( @@ -134,11 +134,11 @@ class UDFSuite extends QueryTest with SharedSQLContext { } test("UDF in a GROUP BY") { - sqlContext.udf.register("groupFunction", (n: Int) => { n > 10 }) + spark.udf.register("groupFunction", (n: Int) => { n > 10 }) val df = Seq(("red", 1), ("red", 2), ("blue", 10), ("green", 100), ("green", 200)).toDF("g", "v") - df.registerTempTable("groupData") + df.createOrReplaceTempView("groupData") val result = sql( @@ -151,14 +151,14 @@ class UDFSuite extends QueryTest with SharedSQLContext { } test("UDFs everywhere") { - sqlContext.udf.register("groupFunction", (n: Int) => { n > 10 }) - sqlContext.udf.register("havingFilter", (n: Long) => { n > 2000 }) - sqlContext.udf.register("whereFilter", (n: Int) => { n < 150 }) - sqlContext.udf.register("timesHundred", (n: Long) => { n * 100 }) + spark.udf.register("groupFunction", (n: Int) => { n > 10 }) + spark.udf.register("havingFilter", (n: Long) => { n > 2000 }) + spark.udf.register("whereFilter", (n: Int) => { n < 150 }) + spark.udf.register("timesHundred", (n: Long) => { n * 100 }) val df = Seq(("red", 1), ("red", 2), ("blue", 10), ("green", 100), ("green", 200)).toDF("g", "v") - df.registerTempTable("groupData") + df.createOrReplaceTempView("groupData") val result = sql( @@ -173,7 +173,7 @@ class UDFSuite extends QueryTest with SharedSQLContext { } test("struct UDF") { - sqlContext.udf.register("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2)) + spark.udf.register("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2)) val result = sql("SELECT returnStruct('test', 'test2') as ret") @@ -182,27 +182,27 @@ class UDFSuite extends QueryTest with SharedSQLContext { } test("udf that is transformed") { - sqlContext.udf.register("makeStruct", (x: Int, y: Int) => (x, y)) + spark.udf.register("makeStruct", (x: Int, y: Int) => (x, y)) // 1 + 1 is constant folded causing a transformation. assert(sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2)) } test("type coercion for udf inputs") { - sqlContext.udf.register("intExpected", (x: Int) => x) + spark.udf.register("intExpected", (x: Int) => x) // pass a decimal to intExpected. 
assert(sql("SELECT intExpected(1.0)").head().getInt(0) === 1) } test("udf in different types") { - sqlContext.udf.register("testDataFunc", (n: Int, s: String) => { (n, s) }) - sqlContext.udf.register("decimalDataFunc", + spark.udf.register("testDataFunc", (n: Int, s: String) => { (n, s) }) + spark.udf.register("decimalDataFunc", (a: java.math.BigDecimal, b: java.math.BigDecimal) => { (a, b) }) - sqlContext.udf.register("binaryDataFunc", (a: Array[Byte], b: Int) => { (a, b) }) - sqlContext.udf.register("arrayDataFunc", + spark.udf.register("binaryDataFunc", (a: Array[Byte], b: Int) => { (a, b) }) + spark.udf.register("arrayDataFunc", (data: Seq[Int], nestedData: Seq[Seq[Int]]) => { (data, nestedData) }) - sqlContext.udf.register("mapDataFunc", + spark.udf.register("mapDataFunc", (data: scala.collection.Map[Int, String]) => { data }) - sqlContext.udf.register("complexDataFunc", + spark.udf.register("complexDataFunc", (m: Map[String, Int], a: Seq[Int], b: Boolean) => { (m, a, b) } ) checkAnswer( @@ -235,7 +235,7 @@ class UDFSuite extends QueryTest with SharedSQLContext { } test("SPARK-11716 UDFRegistration does not include the input data type in returned UDF") { - val myUDF = sqlContext.udf.register("testDataFunc", (n: Int, s: String) => { (n, s.toInt) }) + val myUDF = spark.udf.register("testDataFunc", (n: Int, s: String) => { (n, s.toInt) }) // Without the fix, this will fail because we fail to cast data type of b to string // because myUDF does not know its input data type. With the fix, this query should not diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index a49aaa8b73386..474f17ff7afbc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -94,8 +94,8 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT } test("UDTs and UDFs") { - sqlContext.udf.register("testType", (d: UDT.MyDenseVector) => d.isInstanceOf[UDT.MyDenseVector]) - pointsRDD.registerTempTable("points") + spark.udf.register("testType", (d: UDT.MyDenseVector) => d.isInstanceOf[UDT.MyDenseVector]) + pointsRDD.createOrReplaceTempView("points") checkAnswer( sql("SELECT testType(features) from points"), Seq(Row(true), Row(true))) @@ -106,7 +106,7 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT val path = dir.getCanonicalPath pointsRDD.write.parquet(path) checkAnswer( - sqlContext.read.parquet(path), + spark.read.parquet(path), Seq( Row(1.0, new UDT.MyDenseVector(Array(0.1, 1.0))), Row(0.0, new UDT.MyDenseVector(Array(0.2, 2.0))))) @@ -118,7 +118,7 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT val path = dir.getCanonicalPath pointsRDD.repartition(1).write.parquet(path) checkAnswer( - sqlContext.read.parquet(path), + spark.read.parquet(path), Seq( Row(1.0, new UDT.MyDenseVector(Array(0.1, 1.0))), Row(0.0, new UDT.MyDenseVector(Array(0.2, 2.0))))) @@ -146,7 +146,7 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT )) val stringRDD = sparkContext.parallelize(data) - val jsonRDD = sqlContext.read.schema(schema).json(stringRDD) + val jsonRDD = spark.read.schema(schema).json(stringRDD) checkAnswer( jsonRDD, Row(1, new UDT.MyDenseVector(Array(1.1, 2.2, 3.3, 4.4))) :: @@ -167,7 +167,7 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT )) val stringRDD 
= sparkContext.parallelize(data) - val jsonDataset = sqlContext.read.schema(schema).json(stringRDD) + val jsonDataset = spark.read.schema(schema).json(stringRDD) .as[(Int, UDT.MyDenseVector)] checkDataset( jsonDataset, @@ -188,6 +188,10 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT val toCatalystConverter = CatalystTypeConverters.createToCatalystConverter(udt) assert(toCatalystConverter(null) === null) + } + test("SPARK-15658: Analysis exception if Dataset.map returns UDT object") { + // call `collect` to make sure this query can pass analysis. + pointsRDD.as[MyLabeledPoint].map(_.copy(label = 2.0)).collect() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/XPathFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/XPathFunctionsSuite.scala new file mode 100644 index 0000000000000..1d33e7970be8e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/XPathFunctionsSuite.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.test.SharedSQLContext + +/** + * End-to-end tests for xpath expressions. + */ +class XPathFunctionsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("xpath_boolean") { + val df = Seq("b").toDF("xml") + checkAnswer(df.selectExpr("xpath_boolean(xml, 'a/b')"), Row(true)) + } + + test("xpath_short, xpath_int, xpath_long") { + val df = Seq("12").toDF("xml") + checkAnswer( + df.selectExpr( + "xpath_short(xml, 'sum(a/b)')", + "xpath_int(xml, 'sum(a/b)')", + "xpath_long(xml, 'sum(a/b)')"), + Row(3.toShort, 3, 3L)) + } + + test("xpath_float, xpath_double, xpath_number") { + val df = Seq("1.02.1").toDF("xml") + checkAnswer( + df.selectExpr( + "xpath_float(xml, 'sum(a/b)')", + "xpath_double(xml, 'sum(a/b)')", + "xpath_number(xml, 'sum(a/b)')"), + Row(3.1.toFloat, 3.1, 3.1)) + } + + test("xpath_string") { + val df = Seq("bcc").toDF("xml") + checkAnswer(df.selectExpr("xpath_string(xml, 'a/c')"), Row("cc")) + } + + test("xpath") { + val df = Seq("b1b2b3c1c2").toDF("xml") + checkAnswer(df.selectExpr("xpath(xml, 'a/*/text()')"), Row(Seq("b1", "b2", "b3", "c1", "c2"))) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala deleted file mode 100644 index 841263d3dab93..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ /dev/null @@ -1,843 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import java.util.HashMap - -import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} -import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager} -import org.apache.spark.sql.execution.joins.LongToUnsafeRowMap -import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.execution.vectorized.AggregateHashMap -import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.{IntegerType, LongType, StructType} -import org.apache.spark.unsafe.Platform -import org.apache.spark.unsafe.hash.Murmur3_x86_32 -import org.apache.spark.unsafe.map.BytesToBytesMap -import org.apache.spark.util.Benchmark - -/** - * Benchmark to measure whole stage codegen performance. - * To run this: - * build/sbt "sql/test-only *BenchmarkWholeStageCodegen" - */ -class BenchmarkWholeStageCodegen extends SparkFunSuite { - lazy val conf = new SparkConf().setMaster("local[1]").setAppName("benchmark") - .set("spark.sql.shuffle.partitions", "1") - .set("spark.sql.autoBroadcastJoinThreshold", "1") - lazy val sc = SparkContext.getOrCreate(conf) - lazy val sqlContext = SQLContext.getOrCreate(sc) - - def runBenchmark(name: String, values: Long)(f: => Unit): Unit = { - val benchmark = new Benchmark(name, values) - - Seq(false, true).foreach { enabled => - benchmark.addCase(s"$name codegen=$enabled") { iter => - sqlContext.setConf("spark.sql.codegen.wholeStage", enabled.toString) - f - } - } - - benchmark.run() - } - - // These benchmark are skipped in normal build - ignore("range/filter/sum") { - val N = 500L << 20 - runBenchmark("rang/filter/sum", N) { - sqlContext.range(N).filter("(id & 1) = 1").groupBy().sum().collect() - } - /* - Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - rang/filter/sum: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - rang/filter/sum codegen=false 14332 / 16646 36.0 27.8 1.0X - rang/filter/sum codegen=true 897 / 1022 584.6 1.7 16.4X - */ - } - - ignore("range/limit/sum") { - val N = 500L << 20 - runBenchmark("range/limit/sum", N) { - sqlContext.range(N).limit(1000000).groupBy().sum().collect() - } - /* - Westmere E56xx/L56xx/X56xx (Nehalem-C) - range/limit/sum: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - range/limit/sum codegen=false 609 / 672 861.6 1.2 1.0X - range/limit/sum codegen=true 561 / 621 935.3 1.1 1.1X - */ - } - - ignore("range/sample/sum") { - val N = 500 << 20 - runBenchmark("range/sample/sum", N) { - sqlContext.range(N).sample(true, 0.01).groupBy().sum().collect() - } - /* - Westmere E56xx/L56xx/X56xx (Nehalem-C) - range/sample/sum: Best/Avg Time(ms) Rate(M/s) Per Row(ns) 
Relative - ------------------------------------------------------------------------------------------- - range/sample/sum codegen=false 53888 / 56592 9.7 102.8 1.0X - range/sample/sum codegen=true 41614 / 42607 12.6 79.4 1.3X - */ - - runBenchmark("range/sample/sum", N) { - sqlContext.range(N).sample(false, 0.01).groupBy().sum().collect() - } - /* - Westmere E56xx/L56xx/X56xx (Nehalem-C) - range/sample/sum: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - range/sample/sum codegen=false 12982 / 13384 40.4 24.8 1.0X - range/sample/sum codegen=true 7074 / 7383 74.1 13.5 1.8X - */ - } - - ignore("stat functions") { - val N = 100L << 20 - - runBenchmark("stddev", N) { - sqlContext.range(N).groupBy().agg("id" -> "stddev").collect() - } - - runBenchmark("kurtosis", N) { - sqlContext.range(N).groupBy().agg("id" -> "kurtosis").collect() - } - - - /** - Using ImperativeAggregate (as implemented in Spark 1.6): - - Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - stddev: Avg Time(ms) Avg Rate(M/s) Relative Rate - ------------------------------------------------------------------------------- - stddev w/o codegen 2019.04 10.39 1.00 X - stddev w codegen 2097.29 10.00 0.96 X - kurtosis w/o codegen 2108.99 9.94 0.96 X - kurtosis w codegen 2090.69 10.03 0.97 X - - Using DeclarativeAggregate: - - Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - stddev: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - stddev codegen=false 5630 / 5776 18.0 55.6 1.0X - stddev codegen=true 1259 / 1314 83.0 12.0 4.5X - - Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - kurtosis: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - kurtosis codegen=false 14847 / 15084 7.0 142.9 1.0X - kurtosis codegen=true 1652 / 2124 63.0 15.9 9.0X - */ - } - - ignore("aggregate with linear keys") { - val N = 20 << 20 - - val benchmark = new Benchmark("Aggregate w keys", N) - def f(): Unit = sqlContext.range(N).selectExpr("(id & 65535) as k").groupBy("k").sum().collect() - - benchmark.addCase(s"codegen = F") { iter => - sqlContext.setConf("spark.sql.codegen.wholeStage", "false") - f() - } - - benchmark.addCase(s"codegen = T hashmap = F") { iter => - sqlContext.setConf("spark.sql.codegen.wholeStage", "true") - sqlContext.setConf("spark.sql.codegen.aggregate.map.columns.max", "0") - f() - } - - benchmark.addCase(s"codegen = T hashmap = T") { iter => - sqlContext.setConf("spark.sql.codegen.wholeStage", "true") - sqlContext.setConf("spark.sql.codegen.aggregate.map.columns.max", "3") - f() - } - - benchmark.run() - - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_73-b02 on Mac OS X 10.11.4 - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - Aggregate w keys: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - codegen = F 2067 / 2166 10.1 98.6 1.0X - codegen = T hashmap = F 1149 / 1321 18.3 54.8 1.8X - codegen = T hashmap = T 388 / 475 54.0 18.5 5.3X - */ - } - - ignore("aggregate with randomized keys") { - val N = 20 << 20 - - val benchmark = new Benchmark("Aggregate w keys", N) - sqlContext.range(N).selectExpr("id", "floor(rand() * 10000) as k").registerTempTable("test") - - def f(): Unit = sqlContext.sql("select k, k, sum(id) from test group by k, k").collect() - - benchmark.addCase(s"codegen = F") { 
iter => - sqlContext.setConf("spark.sql.codegen.wholeStage", "false") - f() - } - - benchmark.addCase(s"codegen = T hashmap = F") { iter => - sqlContext.setConf("spark.sql.codegen.wholeStage", "true") - sqlContext.setConf("spark.sql.codegen.aggregate.map.columns.max", "0") - f() - } - - benchmark.addCase(s"codegen = T hashmap = T") { iter => - sqlContext.setConf("spark.sql.codegen.wholeStage", "true") - sqlContext.setConf("spark.sql.codegen.aggregate.map.columns.max", "3") - f() - } - - benchmark.run() - - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_73-b02 on Mac OS X 10.11.4 - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - Aggregate w keys: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - codegen = F 2517 / 2608 8.3 120.0 1.0X - codegen = T hashmap = F 1484 / 1560 14.1 70.8 1.7X - codegen = T hashmap = T 794 / 908 26.4 37.9 3.2X - */ - } - - ignore("aggregate with string key") { - val N = 20 << 20 - - val benchmark = new Benchmark("Aggregate w string key", N) - def f(): Unit = sqlContext.range(N).selectExpr("id", "cast(id & 1023 as string) as k") - .groupBy("k").count().collect() - - benchmark.addCase(s"codegen = F") { iter => - sqlContext.setConf("spark.sql.codegen.wholeStage", "false") - f() - } - - benchmark.addCase(s"codegen = T hashmap = F") { iter => - sqlContext.setConf("spark.sql.codegen.wholeStage", "true") - sqlContext.setConf("spark.sql.codegen.aggregate.map.columns.max", "0") - f() - } - - benchmark.addCase(s"codegen = T hashmap = T") { iter => - sqlContext.setConf("spark.sql.codegen.wholeStage", "true") - sqlContext.setConf("spark.sql.codegen.aggregate.map.columns.max", "3") - f() - } - - benchmark.run() - - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_73-b02 on Mac OS X 10.11.4 - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - Aggregate w string key: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - codegen = F 3307 / 3376 6.3 157.7 1.0X - codegen = T hashmap = F 2364 / 2471 8.9 112.7 1.4X - codegen = T hashmap = T 1740 / 1841 12.0 83.0 1.9X - */ - } - - ignore("aggregate with decimal key") { - val N = 20 << 20 - - val benchmark = new Benchmark("Aggregate w decimal key", N) - def f(): Unit = sqlContext.range(N).selectExpr("id", "cast(id & 65535 as decimal) as k") - .groupBy("k").count().collect() - - benchmark.addCase(s"codegen = F") { iter => - sqlContext.setConf("spark.sql.codegen.wholeStage", "false") - f() - } - - benchmark.addCase(s"codegen = T hashmap = F") { iter => - sqlContext.setConf("spark.sql.codegen.wholeStage", "true") - sqlContext.setConf("spark.sql.codegen.aggregate.map.columns.max", "0") - f() - } - - benchmark.addCase(s"codegen = T hashmap = T") { iter => - sqlContext.setConf("spark.sql.codegen.wholeStage", "true") - sqlContext.setConf("spark.sql.codegen.aggregate.map.columns.max", "3") - f() - } - - benchmark.run() - - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_73-b02 on Mac OS X 10.11.4 - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - Aggregate w decimal key: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - codegen = F 2756 / 2817 7.6 131.4 1.0X - codegen = T hashmap = F 1580 / 1647 13.3 75.4 1.7X - codegen = T hashmap = T 641 / 662 32.7 30.6 4.3X - */ - } - - ignore("aggregate with multiple key types") { - val N = 20 << 20 - - val benchmark = new Benchmark("Aggregate w multiple 
keys", N) - def f(): Unit = sqlContext.range(N) - .selectExpr( - "id", - "(id & 1023) as k1", - "cast(id & 1023 as string) as k2", - "cast(id & 1023 as int) as k3", - "cast(id & 1023 as double) as k4", - "cast(id & 1023 as float) as k5", - "id > 1023 as k6") - .groupBy("k1", "k2", "k3", "k4", "k5", "k6") - .sum() - .collect() - - benchmark.addCase(s"codegen = F") { iter => - sqlContext.setConf("spark.sql.codegen.wholeStage", "false") - f() - } - - benchmark.addCase(s"codegen = T hashmap = F") { iter => - sqlContext.setConf("spark.sql.codegen.wholeStage", "true") - sqlContext.setConf("spark.sql.codegen.aggregate.map.columns.max", "0") - f() - } - - benchmark.addCase(s"codegen = T hashmap = T") { iter => - sqlContext.setConf("spark.sql.codegen.wholeStage", "true") - sqlContext.setConf("spark.sql.codegen.aggregate.map.columns.max", "10") - f() - } - - benchmark.run() - - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_73-b02 on Mac OS X 10.11.4 - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - Aggregate w decimal key: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - codegen = F 5885 / 6091 3.6 280.6 1.0X - codegen = T hashmap = F 3625 / 4009 5.8 172.8 1.6X - codegen = T hashmap = T 3204 / 3271 6.5 152.8 1.8X - */ - } - - ignore("broadcast hash join") { - val N = 20 << 20 - val M = 1 << 16 - val dim = broadcast(sqlContext.range(M).selectExpr("id as k", "cast(id as string) as v")) - - runBenchmark("Join w long", N) { - sqlContext.range(N).join(dim, (col("id") % M) === col("k")).count() - } - - /* - Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5 - Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - Join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - Join w long codegen=false 3002 / 3262 7.0 143.2 1.0X - Join w long codegen=true 321 / 371 65.3 15.3 9.3X - */ - - runBenchmark("Join w long duplicated", N) { - val dim = broadcast(sqlContext.range(M).selectExpr("cast(id/10 as long) as k")) - sqlContext.range(N).join(dim, (col("id") % M) === col("k")).count() - } - - /** - Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5 - Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - Join w long duplicated: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - Join w long duplicated codegen=false 3446 / 3478 6.1 164.3 1.0X - Join w long duplicated codegen=true 322 / 351 65.2 15.3 10.7X - */ - - val dim2 = broadcast(sqlContext.range(M) - .selectExpr("cast(id as int) as k1", "cast(id as int) as k2", "cast(id as string) as v")) - - runBenchmark("Join w 2 ints", N) { - sqlContext.range(N).join(dim2, - (col("id") % M).cast(IntegerType) === col("k1") - && (col("id") % M).cast(IntegerType) === col("k2")).count() - } - - /** - Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5 - Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - Join w 2 ints: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - Join w 2 ints codegen=false 4426 / 4501 4.7 211.1 1.0X - Join w 2 ints codegen=true 791 / 818 26.5 37.7 5.6X - */ - - val dim3 = broadcast(sqlContext.range(M) - .selectExpr("id as k1", "id as k2", "cast(id as string) as v")) - - runBenchmark("Join w 2 longs", N) { - sqlContext.range(N).join(dim3, - (col("id") % M) === col("k1") && 
(col("id") % M) === col("k2")) - .count() - } - - /** - Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5 - Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - Join w 2 longs: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - Join w 2 longs codegen=false 5905 / 6123 3.6 281.6 1.0X - Join w 2 longs codegen=true 2230 / 2529 9.4 106.3 2.6X - */ - - val dim4 = broadcast(sqlContext.range(M) - .selectExpr("cast(id/10 as long) as k1", "cast(id/10 as long) as k2")) - - runBenchmark("Join w 2 longs duplicated", N) { - sqlContext.range(N).join(dim4, - (col("id") bitwiseAND M) === col("k1") && (col("id") bitwiseAND M) === col("k2")) - .count() - } - - /** - Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - Join w 2 longs duplicated: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - Join w 2 longs duplicated codegen=false 6420 / 6587 3.3 306.1 1.0X - Join w 2 longs duplicated codegen=true 2080 / 2139 10.1 99.2 3.1X - */ - - runBenchmark("outer join w long", N) { - sqlContext.range(N).join(dim, (col("id") % M) === col("k"), "left").count() - } - - /** - Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5 - Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - outer join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - outer join w long codegen=false 3055 / 3189 6.9 145.7 1.0X - outer join w long codegen=true 261 / 276 80.5 12.4 11.7X - */ - - runBenchmark("semi join w long", N) { - sqlContext.range(N).join(dim, (col("id") % M) === col("k"), "leftsemi").count() - } - - /** - Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5 - Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - semi join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - semi join w long codegen=false 1912 / 1990 11.0 91.2 1.0X - semi join w long codegen=true 237 / 244 88.3 11.3 8.1X - */ - } - - ignore("sort merge join") { - val N = 2 << 20 - runBenchmark("merge join", N) { - val df1 = sqlContext.range(N).selectExpr(s"id * 2 as k1") - val df2 = sqlContext.range(N).selectExpr(s"id * 3 as k2") - df1.join(df2, col("k1") === col("k2")).count() - } - - /** - Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - merge join: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - merge join codegen=false 1588 / 1880 1.3 757.1 1.0X - merge join codegen=true 1477 / 1531 1.4 704.2 1.1X - */ - - runBenchmark("sort merge join", N) { - val df1 = sqlContext.range(N) - .selectExpr(s"(id * 15485863) % ${N*10} as k1") - val df2 = sqlContext.range(N) - .selectExpr(s"(id * 15485867) % ${N*10} as k2") - df1.join(df2, col("k1") === col("k2")).count() - } - - /** - Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - sort merge join: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - sort merge join codegen=false 3626 / 3667 0.6 1728.9 1.0X - sort merge join codegen=true 3405 / 3438 0.6 1623.8 1.1X - */ - } - - ignore("shuffle hash join") { - val N = 4 << 20 - sqlContext.setConf("spark.sql.shuffle.partitions", "2") - sqlContext.setConf("spark.sql.autoBroadcastJoinThreshold", "10000000") - 
sqlContext.setConf("spark.sql.join.preferSortMergeJoin", "false") - runBenchmark("shuffle hash join", N) { - val df1 = sqlContext.range(N).selectExpr(s"id as k1") - val df2 = sqlContext.range(N / 5).selectExpr(s"id * 3 as k2") - df1.join(df2, col("k1") === col("k2")).count() - } - - /** - Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5 - Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - shuffle hash join: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - shuffle hash join codegen=false 1101 / 1391 3.8 262.6 1.0X - shuffle hash join codegen=true 528 / 578 7.9 125.8 2.1X - */ - } - - ignore("cube") { - val N = 5 << 20 - - runBenchmark("cube", N) { - sqlContext.range(N).selectExpr("id", "id % 1000 as k1", "id & 256 as k2") - .cube("k1", "k2").sum("id").collect() - } - - /** - Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - cube: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - cube codegen=false 3188 / 3392 1.6 608.2 1.0X - cube codegen=true 1239 / 1394 4.2 236.3 2.6X - */ - } - - ignore("hash and BytesToBytesMap") { - val N = 20 << 20 - - val benchmark = new Benchmark("BytesToBytesMap", N) - - benchmark.addCase("UnsafeRowhash") { iter => - var i = 0 - val keyBytes = new Array[Byte](16) - val key = new UnsafeRow(1) - key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) - var s = 0 - while (i < N) { - key.setInt(0, i % 1000) - val h = Murmur3_x86_32.hashUnsafeWords( - key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, 42) - s += h - i += 1 - } - } - - benchmark.addCase("murmur3 hash") { iter => - var i = 0 - val keyBytes = new Array[Byte](16) - val key = new UnsafeRow(1) - key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) - var p = 524283 - var s = 0 - while (i < N) { - var h = Murmur3_x86_32.hashLong(i, 42) - key.setInt(0, h) - s += h - i += 1 - } - } - - benchmark.addCase("fast hash") { iter => - var i = 0 - val keyBytes = new Array[Byte](16) - val key = new UnsafeRow(1) - key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) - var p = 524283 - var s = 0 - while (i < N) { - var h = i % p - if (h < 0) { - h += p - } - key.setInt(0, h) - s += h - i += 1 - } - } - - benchmark.addCase("arrayEqual") { iter => - var i = 0 - val keyBytes = new Array[Byte](16) - val valueBytes = new Array[Byte](16) - val key = new UnsafeRow(1) - key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) - val value = new UnsafeRow(1) - value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) - value.setInt(0, 555) - var s = 0 - while (i < N) { - key.setInt(0, i % 1000) - if (key.equals(value)) { - s += 1 - } - i += 1 - } - } - - benchmark.addCase("Java HashMap (Long)") { iter => - var i = 0 - val keyBytes = new Array[Byte](16) - val valueBytes = new Array[Byte](16) - val value = new UnsafeRow(1) - value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) - value.setInt(0, 555) - val map = new HashMap[Long, UnsafeRow]() - while (i < 65536) { - value.setInt(0, i) - map.put(i.toLong, value) - i += 1 - } - var s = 0 - i = 0 - while (i < N) { - if (map.get(i % 100000) != null) { - s += 1 - } - i += 1 - } - } - - benchmark.addCase("Java HashMap (two ints) ") { iter => - var i = 0 - val valueBytes = new Array[Byte](16) - val value = new UnsafeRow(1) - value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) - value.setInt(0, 555) - val map = new HashMap[Long, UnsafeRow]() - while (i < 65536) { - value.setInt(0, i) - val 
key = (i.toLong << 32) + Integer.rotateRight(i, 15) - map.put(key, value) - i += 1 - } - var s = 0 - i = 0 - while (i < N) { - val key = ((i & 100000).toLong << 32) + Integer.rotateRight(i & 100000, 15) - if (map.get(key) != null) { - s += 1 - } - i += 1 - } - } - - benchmark.addCase("Java HashMap (UnsafeRow)") { iter => - var i = 0 - val keyBytes = new Array[Byte](16) - val valueBytes = new Array[Byte](16) - val key = new UnsafeRow(1) - key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) - val value = new UnsafeRow(1) - value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) - value.setInt(0, 555) - val map = new HashMap[UnsafeRow, UnsafeRow]() - while (i < 65536) { - key.setInt(0, i) - value.setInt(0, i) - map.put(key, value.copy()) - i += 1 - } - var s = 0 - i = 0 - while (i < N) { - key.setInt(0, i % 100000) - if (map.get(key) != null) { - s += 1 - } - i += 1 - } - } - - Seq(false, true).foreach { optimized => - benchmark.addCase(s"LongToUnsafeRowMap (opt=$optimized)") { iter => - var i = 0 - val valueBytes = new Array[Byte](16) - val value = new UnsafeRow(1) - value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) - value.setInt(0, 555) - val taskMemoryManager = new TaskMemoryManager( - new StaticMemoryManager( - new SparkConf().set("spark.memory.offHeap.enabled", "false"), - Long.MaxValue, - Long.MaxValue, - 1), - 0) - val map = new LongToUnsafeRowMap(taskMemoryManager, 64) - while (i < 65536) { - value.setInt(0, i) - val key = i % 100000 - map.append(key, value) - i += 1 - } - if (optimized) { - map.optimize() - } - var s = 0 - i = 0 - while (i < N) { - val key = i % 100000 - if (map.getValue(key, value) != null) { - s += 1 - } - i += 1 - } - } - } - - Seq("off", "on").foreach { heap => - benchmark.addCase(s"BytesToBytesMap ($heap Heap)") { iter => - val taskMemoryManager = new TaskMemoryManager( - new StaticMemoryManager( - new SparkConf().set("spark.memory.offHeap.enabled", s"${heap == "off"}") - .set("spark.memory.offHeap.size", "102400000"), - Long.MaxValue, - Long.MaxValue, - 1), - 0) - val map = new BytesToBytesMap(taskMemoryManager, 1024, 64L<<20) - val keyBytes = new Array[Byte](16) - val valueBytes = new Array[Byte](16) - val key = new UnsafeRow(1) - key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) - val value = new UnsafeRow(1) - value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) - var i = 0 - val numKeys = 65536 - while (i < numKeys) { - key.setInt(0, i % 65536) - val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, - Murmur3_x86_32.hashLong(i % 65536, 42)) - if (!loc.isDefined) { - loc.append(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, - value.getBaseObject, value.getBaseOffset, value.getSizeInBytes) - } - i += 1 - } - i = 0 - var s = 0 - while (i < N) { - key.setInt(0, i % 100000) - val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, - Murmur3_x86_32.hashLong(i % 100000, 42)) - if (loc.isDefined) { - s += 1 - } - i += 1 - } - } - } - - benchmark.addCase("Aggregate HashMap") { iter => - var i = 0 - val numKeys = 65536 - val schema = new StructType() - .add("key", LongType) - .add("value", LongType) - val map = new AggregateHashMap(schema) - while (i < numKeys) { - val row = map.findOrInsert(i.toLong) - row.setLong(1, row.getLong(1) + 1) - i += 1 - } - var s = 0 - i = 0 - while (i < N) { - if (map.find(i % 100000) != -1) { - s += 1 - } - i += 1 - } - } - - /** - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - BytesToBytesMap: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - 
------------------------------------------------------------------------------------------- - UnsafeRow hash 267 / 284 78.4 12.8 1.0X - murmur3 hash 102 / 129 205.5 4.9 2.6X - fast hash 79 / 96 263.8 3.8 3.4X - arrayEqual 164 / 172 128.2 7.8 1.6X - Java HashMap (Long) 321 / 399 65.4 15.3 0.8X - Java HashMap (two ints) 328 / 363 63.9 15.7 0.8X - Java HashMap (UnsafeRow) 1140 / 1200 18.4 54.3 0.2X - LongToUnsafeRowMap (opt=false) 378 / 400 55.5 18.0 0.7X - LongToUnsafeRowMap (opt=true) 144 / 152 145.2 6.9 1.9X - BytesToBytesMap (off Heap) 1300 / 1616 16.1 62.0 0.2X - BytesToBytesMap (on Heap) 1165 / 1202 18.0 55.5 0.2X - Aggregate HashMap 121 / 131 173.3 5.8 2.2X - */ - benchmark.run() - } - - ignore("collect") { - val N = 1 << 20 - - val benchmark = new Benchmark("collect", N) - benchmark.addCase("collect 1 million") { iter => - sqlContext.range(N).collect() - } - benchmark.addCase("collect 2 millions") { iter => - sqlContext.range(N * 2).collect() - } - benchmark.addCase("collect 4 millions") { iter => - sqlContext.range(N * 4).collect() - } - benchmark.run() - - /** - Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - collect: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - collect 1 million 439 / 654 2.4 418.7 1.0X - collect 2 millions 961 / 1907 1.1 916.4 0.5X - collect 4 millions 3193 / 3895 0.3 3044.7 0.1X - */ - } - - ignore("collect limit") { - val N = 1 << 20 - - val benchmark = new Benchmark("collect limit", N) - benchmark.addCase("collect limit 1 million") { iter => - sqlContext.range(N * 4).limit(N).collect() - } - benchmark.addCase("collect limit 2 millions") { iter => - sqlContext.range(N * 4).limit(N * 2).collect() - } - benchmark.run() - - /** - model name : Westmere E56xx/L56xx/X56xx (Nehalem-C) - collect limit: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - collect limit 1 million 833 / 1284 1.3 794.4 1.0X - collect limit 2 millions 3348 / 4005 0.3 3193.3 0.2X - */ - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala index 01d485ce2d713..2803b62462417 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala @@ -19,30 +19,29 @@ package org.apache.spark.sql.execution import org.scalatest.BeforeAndAfterAll -import org.apache.spark.{MapOutputStatistics, SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.{MapOutputStatistics, SparkConf, SparkFunSuite} import org.apache.spark.sql._ import org.apache.spark.sql.execution.exchange.{ExchangeCoordinator, ShuffleExchange} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.test.TestSQLContext class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { - private var originalActiveSQLContext: Option[SQLContext] = _ - private var originalInstantiatedSQLContext: Option[SQLContext] = _ + private var originalActiveSparkSession: Option[SparkSession] = _ + private var originalInstantiatedSparkSession: Option[SparkSession] = _ override protected def beforeAll(): Unit = { - originalActiveSQLContext = SQLContext.getActive() - originalInstantiatedSQLContext = SQLContext.getInstantiatedContextOption() 
+ originalActiveSparkSession = SparkSession.getActiveSession + originalInstantiatedSparkSession = SparkSession.getDefaultSession - SQLContext.clearActive() - SQLContext.clearInstantiatedContext() + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() } override protected def afterAll(): Unit = { // Set these states back. - originalActiveSQLContext.foreach(ctx => SQLContext.setActive(ctx)) - originalInstantiatedSQLContext.foreach(ctx => SQLContext.setInstantiatedContext(ctx)) + originalActiveSparkSession.foreach(ctx => SparkSession.setActiveSession(ctx)) + originalInstantiatedSparkSession.foreach(ctx => SparkSession.setDefaultSession(ctx)) } private def checkEstimation( @@ -250,8 +249,8 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { } } - def withSQLContext( - f: SQLContext => Unit, + def withSparkSession( + f: SparkSession => Unit, targetNumPostShufflePartitions: Int, minNumPostShufflePartitions: Option[Int]): Unit = { val sparkConf = @@ -272,9 +271,11 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { case None => sparkConf.set(SQLConf.SHUFFLE_MIN_NUM_POSTSHUFFLE_PARTITIONS.key, "-1") } - val sparkContext = new SparkContext(sparkConf) - val sqlContext = new TestSQLContext(sparkContext) - try f(sqlContext) finally sparkContext.stop() + + val spark = SparkSession.builder + .config(sparkConf) + .getOrCreate() + try f(spark) finally spark.stop() } Seq(Some(3), None).foreach { minNumPostShufflePartitions => @@ -284,9 +285,9 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { } test(s"determining the number of reducers: aggregate operator$testNameNote") { - val test = { sqlContext: SQLContext => + val test = { spark: SparkSession => val df = - sqlContext + spark .range(0, 1000, 1, numInputPartitions) .selectExpr("id % 20 as key", "id as value") val agg = df.groupBy("key").count @@ -294,7 +295,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { // Check the answer first. checkAnswer( agg, - sqlContext.range(0, 20).selectExpr("id", "50 as cnt").collect()) + spark.range(0, 20).selectExpr("id", "50 as cnt").collect()) // Then, let's look at the number of post-shuffle partitions estimated // by the ExchangeCoordinator. @@ -321,17 +322,17 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { } } - withSQLContext(test, 2000, minNumPostShufflePartitions) + withSparkSession(test, 2000, minNumPostShufflePartitions) } test(s"determining the number of reducers: join operator$testNameNote") { - val test = { sqlContext: SQLContext => + val test = { spark: SparkSession => val df1 = - sqlContext + spark .range(0, 1000, 1, numInputPartitions) .selectExpr("id % 500 as key1", "id as value1") val df2 = - sqlContext + spark .range(0, 1000, 1, numInputPartitions) .selectExpr("id % 500 as key2", "id as value2") @@ -339,10 +340,10 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { // Check the answer first. 
val expectedAnswer = - sqlContext + spark .range(0, 1000) .selectExpr("id % 500 as key", "id as value") - .union(sqlContext.range(0, 1000).selectExpr("id % 500 as key", "id as value")) + .union(spark.range(0, 1000).selectExpr("id % 500 as key", "id as value")) checkAnswer( join, expectedAnswer.collect()) @@ -372,20 +373,20 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { } } - withSQLContext(test, 16384, minNumPostShufflePartitions) + withSparkSession(test, 16384, minNumPostShufflePartitions) } test(s"determining the number of reducers: complex query 1$testNameNote") { - val test = { sqlContext: SQLContext => + val test = { spark: SparkSession => val df1 = - sqlContext + spark .range(0, 1000, 1, numInputPartitions) .selectExpr("id % 500 as key1", "id as value1") .groupBy("key1") .count .toDF("key1", "cnt1") val df2 = - sqlContext + spark .range(0, 1000, 1, numInputPartitions) .selectExpr("id % 500 as key2", "id as value2") .groupBy("key2") @@ -396,7 +397,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { // Check the answer first. val expectedAnswer = - sqlContext + spark .range(0, 500) .selectExpr("id", "2 as cnt") checkAnswer( @@ -424,20 +425,20 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { } } - withSQLContext(test, 6644, minNumPostShufflePartitions) + withSparkSession(test, 6644, minNumPostShufflePartitions) } test(s"determining the number of reducers: complex query 2$testNameNote") { - val test = { sqlContext: SQLContext => + val test = { spark: SparkSession => val df1 = - sqlContext + spark .range(0, 1000, 1, numInputPartitions) .selectExpr("id % 500 as key1", "id as value1") .groupBy("key1") .count .toDF("key1", "cnt1") val df2 = - sqlContext + spark .range(0, 1000, 1, numInputPartitions) .selectExpr("id % 500 as key2", "id as value2") @@ -448,7 +449,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { // Check the answer first. 
val expectedAnswer = - sqlContext + spark .range(0, 1000) .selectExpr("id % 500 as key", "2 as cnt", "id as value") checkAnswer( @@ -476,7 +477,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { } } - withSQLContext(test, 6144, minNumPostShufflePartitions) + withSparkSession(test, 6144, minNumPostShufflePartitions) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala index ba16810ceeb50..36cde3233dce8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala @@ -50,7 +50,7 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext { } test("BroadcastExchange same result") { - val df = sqlContext.range(10) + val df = spark.range(10) val plan = df.queryExecution.executedPlan val output = plan.output assert(plan sameResult plan) @@ -75,7 +75,7 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext { } test("ShuffleExchange same result") { - val df = sqlContext.range(10) + val df = spark.range(10) val plan = df.queryExecution.executedPlan val output = plan.output assert(plan sameResult plan) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala index 6f10e4b80577a..80340b5552c6d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala @@ -27,7 +27,7 @@ class GroupedIteratorSuite extends SparkFunSuite { test("basic") { val schema = new StructType().add("i", IntegerType).add("s", StringType) - val encoder = RowEncoder(schema) + val encoder = RowEncoder(schema).resolveAndBind() val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c")) val grouped = GroupedIterator(input.iterator.map(encoder.toRow), Seq('i.int.at(0)), schema.toAttributes) @@ -45,7 +45,7 @@ class GroupedIteratorSuite extends SparkFunSuite { test("group by 2 columns") { val schema = new StructType().add("i", IntegerType).add("l", LongType).add("s", StringType) - val encoder = RowEncoder(schema) + val encoder = RowEncoder(schema).resolveAndBind() val input = Seq( Row(1, 2L, "a"), @@ -72,7 +72,7 @@ class GroupedIteratorSuite extends SparkFunSuite { test("do nothing to the value iterator") { val schema = new StructType().add("i", IntegerType).add("s", StringType) - val encoder = RowEncoder(schema) + val encoder = RowEncoder(schema).resolveAndBind() val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c")) val grouped = GroupedIterator(input.iterator.map(encoder.toRow), Seq('i.int.at(0)), schema.toAttributes) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 3b2911d056652..375da224aaa7f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD import org.apache.spark.sql.{execution, Row} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal, SortOrder} +import org.apache.spark.sql.catalyst.expressions._ import 
org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition} import org.apache.spark.sql.catalyst.plans.physical._ @@ -38,7 +38,7 @@ class PlannerSuite extends SharedSQLContext { setupTestData() private def testPartialAggregationPlan(query: LogicalPlan): Unit = { - val planner = sqlContext.sessionState.planner + val planner = spark.sessionState.planner import planner._ val plannedOption = Aggregation(query).headOption val planned = @@ -71,14 +71,14 @@ class PlannerSuite extends SharedSQLContext { test("sizeInBytes estimation of limit operator for broadcast hash join optimization") { def checkPlan(fieldTypes: Seq[DataType]): Unit = { - withTempTable("testLimit") { + withTempView("testLimit") { val fields = fieldTypes.zipWithIndex.map { case (dataType, index) => StructField(s"c${index}", dataType, true) } :+ StructField("key", IntegerType, true) val schema = StructType(fields) val row = Row.fromSeq(Seq.fill(fields.size)(null)) val rowRDD = sparkContext.parallelize(row :: Nil) - sqlContext.createDataFrame(rowRDD, schema).registerTempTable("testLimit") + spark.createDataFrame(rowRDD, schema).createOrReplaceTempView("testLimit") val planned = sql( """ @@ -131,12 +131,12 @@ class PlannerSuite extends SharedSQLContext { test("InMemoryRelation statistics propagation") { withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "81920") { - withTempTable("tiny") { - testData.limit(3).registerTempTable("tiny") + withTempView("tiny") { + testData.limit(3).createOrReplaceTempView("tiny") sql("CACHE TABLE tiny") val a = testData.as("a") - val b = sqlContext.table("tiny").as("b") + val b = spark.table("tiny").as("b") val planned = a.join(b, $"a.key" === $"b.key").queryExecution.sparkPlan val broadcastHashJoins = planned.collect { case join: BroadcastHashJoinExec => join } @@ -145,7 +145,7 @@ class PlannerSuite extends SharedSQLContext { assert(broadcastHashJoins.size === 1, "Should use broadcast hash join") assert(sortMergeJoins.isEmpty, "Should not use shuffled hash join") - sqlContext.clearCache() + spark.catalog.clearCache() } } } @@ -154,10 +154,10 @@ class PlannerSuite extends SharedSQLContext { withTempPath { file => val path = file.getCanonicalPath testData.write.parquet(path) - val df = sqlContext.read.parquet(path) - sqlContext.registerDataFrameAsTable(df, "testPushed") + val df = spark.read.parquet(path) + df.createOrReplaceTempView("testPushed") - withTempTable("testPushed") { + withTempView("testPushed") { val exp = sql("select * from testPushed where key = 15").queryExecution.sparkPlan assert(exp.toString.contains("PushedFilters: [IsNotNull(key), EqualTo(key,15)]")) } @@ -198,10 +198,10 @@ class PlannerSuite extends SharedSQLContext { } test("PartitioningCollection") { - withTempTable("normal", "small", "tiny") { - testData.registerTempTable("normal") - testData.limit(10).registerTempTable("small") - testData.limit(3).registerTempTable("tiny") + withTempView("normal", "small", "tiny") { + testData.createOrReplaceTempView("normal") + testData.limit(10).createOrReplaceTempView("small") + testData.limit(3).createOrReplaceTempView("tiny") // Disable broadcast join withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { @@ -295,7 +295,7 @@ class PlannerSuite extends SharedSQLContext { requiredChildDistribution = Seq(distribution, distribution), requiredChildOrdering = Seq(Seq.empty, Seq.empty) ) - val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) + val outputPlan = 
EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case e: ShuffleExchange => true }.isEmpty) { fail(s"Exchange should have been added:\n$outputPlan") @@ -315,7 +315,7 @@ class PlannerSuite extends SharedSQLContext { requiredChildDistribution = Seq(distribution, distribution), requiredChildOrdering = Seq(Seq.empty, Seq.empty) ) - val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) + val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) } @@ -333,7 +333,7 @@ class PlannerSuite extends SharedSQLContext { requiredChildDistribution = Seq(distribution, distribution), requiredChildOrdering = Seq(Seq.empty, Seq.empty) ) - val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) + val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case e: ShuffleExchange => true }.isEmpty) { fail(s"Exchange should have been added:\n$outputPlan") @@ -353,7 +353,7 @@ class PlannerSuite extends SharedSQLContext { requiredChildDistribution = Seq(distribution, distribution), requiredChildOrdering = Seq(Seq.empty, Seq.empty) ) - val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) + val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case e: ShuffleExchange => true }.nonEmpty) { fail(s"Exchange should not have been added:\n$outputPlan") @@ -376,7 +376,7 @@ class PlannerSuite extends SharedSQLContext { requiredChildDistribution = Seq(distribution, distribution), requiredChildOrdering = Seq(outputOrdering, outputOrdering) ) - val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) + val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case e: ShuffleExchange => true }.nonEmpty) { fail(s"No Exchanges should have been added:\n$outputPlan") @@ -392,7 +392,7 @@ class PlannerSuite extends SharedSQLContext { requiredChildOrdering = Seq(Seq(orderingB)), requiredChildDistribution = Seq(UnspecifiedDistribution) ) - val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) + val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case s: SortExec => true }.isEmpty) { fail(s"Sort should have been added:\n$outputPlan") @@ -408,7 +408,45 @@ class PlannerSuite extends SharedSQLContext { requiredChildOrdering = Seq(Seq(orderingA)), requiredChildDistribution = Seq(UnspecifiedDistribution) ) - val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) + val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) + assertDistributionRequirementsAreSatisfied(outputPlan) + if (outputPlan.collect { case s: SortExec => true }.nonEmpty) { + fail(s"No sorts should have been added:\n$outputPlan") + } + } + + test("EnsureRequirements skips sort when required ordering is semantically equal to " + + "existing ordering") { + val exprId: ExprId = NamedExpression.newExprId + val attribute1 = + AttributeReference( + name = "col1", + dataType = LongType, + nullable = false + ) (exprId = exprId, + 
qualifier = Some("col1_qualifier") + ) + + val attribute2 = + AttributeReference( + name = "col1", + dataType = LongType, + nullable = false + ) (exprId = exprId) + + val orderingA1 = SortOrder(attribute1, Ascending) + val orderingA2 = SortOrder(attribute2, Ascending) + + assert(orderingA1 != orderingA2, s"$orderingA1 should NOT equal to $orderingA2") + assert(orderingA1.semanticEquals(orderingA2), + s"$orderingA1 should be semantically equal to $orderingA2") + + val inputPlan = DummySparkPlan( + children = DummySparkPlan(outputOrdering = Seq(orderingA1)) :: Nil, + requiredChildOrdering = Seq(Seq(orderingA2)), + requiredChildDistribution = Seq(UnspecifiedDistribution) + ) + val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case s: SortExec => true }.nonEmpty) { fail(s"No sorts should have been added:\n$outputPlan") @@ -425,7 +463,7 @@ class PlannerSuite extends SharedSQLContext { requiredChildOrdering = Seq(Seq(orderingA, orderingB)), requiredChildDistribution = Seq(UnspecifiedDistribution) ) - val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) + val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case s: SortExec => true }.isEmpty) { fail(s"Sort should have been added:\n$outputPlan") @@ -444,7 +482,7 @@ class PlannerSuite extends SharedSQLContext { requiredChildOrdering = Seq(Seq.empty)), None) - val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) + val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case e: ShuffleExchange => true }.size == 2) { fail(s"Topmost Exchange should have been eliminated:\n$outputPlan") @@ -464,7 +502,7 @@ class PlannerSuite extends SharedSQLContext { requiredChildOrdering = Seq(Seq.empty)), None) - val outputPlan = EnsureRequirements(sqlContext.sessionState.conf).apply(inputPlan) + val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) if (outputPlan.collect { case e: ShuffleExchange => true }.size == 1) { fail(s"Topmost Exchange should not have been eliminated:\n$outputPlan") @@ -493,7 +531,7 @@ class PlannerSuite extends SharedSQLContext { shuffle, shuffle) - val outputPlan = ReuseExchange(sqlContext.sessionState.conf).apply(inputPlan) + val outputPlan = ReuseExchange(spark.sessionState.conf).apply(inputPlan) if (outputPlan.collect { case e: ReusedExchangeExec => true }.size != 1) { fail(s"Should re-use the shuffle:\n$outputPlan") } @@ -510,7 +548,7 @@ class PlannerSuite extends SharedSQLContext { ShuffleExchange(finalPartitioning, inputPlan), ShuffleExchange(finalPartitioning, inputPlan)) - val outputPlan2 = ReuseExchange(sqlContext.sessionState.conf).apply(inputPlan2) + val outputPlan2 = ReuseExchange(spark.sessionState.conf).apply(inputPlan2) if (outputPlan2.collect { case e: ReusedExchangeExec => true }.size != 2) { fail(s"Should re-use the two shuffles:\n$outputPlan2") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala index c9f517ca34296..ad41111bec9d6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import java.util.Properties import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession class SQLExecutionSuite extends SparkFunSuite { @@ -50,16 +50,19 @@ class SQLExecutionSuite extends SparkFunSuite { } test("concurrent query execution with fork-join pool (SPARK-13747)") { - val sc = new SparkContext("local[*]", "test") - val sqlContext = new SQLContext(sc) - import sqlContext.implicits._ + val spark = SparkSession.builder + .master("local[*]") + .appName("test") + .getOrCreate() + + import spark.implicits._ try { // Should not throw IllegalArgumentException (1 to 100).par.foreach { _ => - sc.parallelize(1 to 5).map { i => (i, i) }.toDF("a", "b").count() + spark.sparkContext.parallelize(1 to 5).map { i => (i, i) }.toDF("a", "b").count() } } finally { - sc.stop() + spark.sparkContext.stop() } } @@ -67,8 +70,8 @@ class SQLExecutionSuite extends SparkFunSuite { * Trigger SPARK-10548 by mocking a parent and its child thread executing queries concurrently. */ private def testConcurrentQueryExecution(sc: SparkContext): Unit = { - val sqlContext = new SQLContext(sc) - import sqlContext.implicits._ + val spark = SparkSession.builder.getOrCreate() + import spark.implicits._ // Initialize local properties. This is necessary for the test to pass. sc.getLocalProperties diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLWindowFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala similarity index 80% rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLWindowFunctionSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala index d0e7552c12e49..d3cfa953a3123 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLWindowFunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala @@ -15,12 +15,10 @@ * limitations under the License. */ -package org.apache.spark.sql.hive.execution +package org.apache.spark.sql.execution import org.apache.spark.sql.{AnalysisException, QueryTest, Row} -import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.apache.spark.sql.test.SQLTestUtils - +import org.apache.spark.sql.test.SharedSQLContext case class WindowData(month: Int, area: String, product: Int) @@ -28,8 +26,9 @@ case class WindowData(month: Int, area: String, product: Int) /** * Test suite for SQL window functions. 
*/ -class SQLWindowFunctionSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { - import hiveContext.implicits._ +class SQLWindowFunctionSuite extends QueryTest with SharedSQLContext { + + import testImplicits._ test("window function: udaf with aggregate expression") { val data = Seq( @@ -40,7 +39,7 @@ class SQLWindowFunctionSuite extends QueryTest with SQLTestUtils with TestHiveSi WindowData(5, "c", 9), WindowData(6, "c", 10) ) - sparkContext.parallelize(data).toDF().registerTempTable("windowData") + sparkContext.parallelize(data).toDF().createOrReplaceTempView("windowData") checkAnswer( sql( @@ -112,7 +111,7 @@ class SQLWindowFunctionSuite extends QueryTest with SQLTestUtils with TestHiveSi WindowData(5, "c", 9), WindowData(6, "c", 10) ) - sparkContext.parallelize(data).toDF().registerTempTable("windowData") + sparkContext.parallelize(data).toDF().createOrReplaceTempView("windowData") checkAnswer( sql( @@ -139,7 +138,7 @@ class SQLWindowFunctionSuite extends QueryTest with SQLTestUtils with TestHiveSi WindowData(5, "c", 9), WindowData(6, "c", 10) ) - sparkContext.parallelize(data).toDF().registerTempTable("windowData") + sparkContext.parallelize(data).toDF().createOrReplaceTempView("windowData") checkAnswer( sql( @@ -182,7 +181,7 @@ class SQLWindowFunctionSuite extends QueryTest with SQLTestUtils with TestHiveSi WindowData(5, "c", 9), WindowData(6, "c", 10) ) - sparkContext.parallelize(data).toDF().registerTempTable("windowData") + sparkContext.parallelize(data).toDF().createOrReplaceTempView("windowData") val e = intercept[AnalysisException] { sql( @@ -203,7 +202,7 @@ class SQLWindowFunctionSuite extends QueryTest with SQLTestUtils with TestHiveSi WindowData(5, "c", 9), WindowData(6, "c", 10) ) - sparkContext.parallelize(data).toDF().registerTempTable("windowData") + sparkContext.parallelize(data).toDF().createOrReplaceTempView("windowData") checkAnswer( sql( @@ -232,7 +231,7 @@ class SQLWindowFunctionSuite extends QueryTest with SQLTestUtils with TestHiveSi WindowData(5, "c", 9), WindowData(6, "c", 11) ) - sparkContext.parallelize(data).toDF().registerTempTable("windowData") + sparkContext.parallelize(data).toDF().createOrReplaceTempView("windowData") checkAnswer( sql("select month, product, sum(product + 1) over() from windowData order by area"), @@ -301,7 +300,7 @@ class SQLWindowFunctionSuite extends QueryTest with SQLTestUtils with TestHiveSi WindowData(5, "c", 9), WindowData(6, "c", 11) ) - sparkContext.parallelize(data).toDF().registerTempTable("windowData") + sparkContext.parallelize(data).toDF().createOrReplaceTempView("windowData") checkAnswer( sql( @@ -322,7 +321,7 @@ class SQLWindowFunctionSuite extends QueryTest with SQLTestUtils with TestHiveSi test("window function: multiple window expressions in a single expression") { val nums = sparkContext.parallelize(1 to 10).map(x => (x, x % 2)).toDF("x", "y") - nums.registerTempTable("nums") + nums.createOrReplaceTempView("nums") val expected = Row(1, 1, 1, 55, 1, 57) :: @@ -353,18 +352,63 @@ class SQLWindowFunctionSuite extends QueryTest with SQLTestUtils with TestHiveSi checkAnswer(actual, expected) - sqlContext.dropTempTable("nums") + spark.catalog.dropTempView("nums") } test("SPARK-7595: Window will cause resolve failed with self join") { - sql("SELECT * FROM src") // Force loading of src table. 
- checkAnswer(sql( """ |with - | v1 as (select key, count(value) over (partition by key) cnt_val from src), + | v0 as (select 0 as key, 1 as value), + | v1 as (select key, count(value) over (partition by key) cnt_val from v0), | v2 as (select v1.key, v1_lag.cnt_val from v1, v1 v1_lag where v1.key = v1_lag.key) - | select * from v2 order by key limit 1 - """.stripMargin), Row(0, 3)) + | select key, cnt_val from v2 order by key limit 1 + """.stripMargin), Row(0, 1)) + } + + test("SPARK-16633: lead/lag should return the default value if the offset row does not exist") { + checkAnswer(sql( + """ + |SELECT + | lag(123, 100, 321) OVER (ORDER BY id) as lag, + | lead(123, 100, 321) OVER (ORDER BY id) as lead + |FROM (SELECT 1 as id) tmp + """.stripMargin), + Row(321, 321)) + + checkAnswer(sql( + """ + |SELECT + | lag(123, 100, a) OVER (ORDER BY id) as lag, + | lead(123, 100, a) OVER (ORDER BY id) as lead + |FROM (SELECT 1 as id, 2 as a) tmp + """.stripMargin), + Row(2, 2)) + } + + test("lead/lag should respect null values") { + checkAnswer(sql( + """ + |SELECT + | b, + | lag(a, 1, 321) OVER (ORDER BY b) as lag, + | lead(a, 1, 321) OVER (ORDER BY b) as lead + |FROM (SELECT cast(null as int) as a, 1 as b + | UNION ALL + | select cast(null as int) as id, 2 as b) tmp + """.stripMargin), + Row(1, 321, null) :: Row(2, null, 321) :: Nil) + + checkAnswer(sql( + """ + |SELECT + | b, + | lag(a, 1, c) OVER (ORDER BY b) as lag, + | lead(a, 1, c) OVER (ORDER BY b) as lead + |FROM (SELECT cast(null as int) as a, 1 as b, 3 as c + | UNION ALL + | select cast(null as int) as id, 2 as b, 4 as c) tmp + """.stripMargin), + Row(1, 3, null) :: Row(2, null, 4) :: Nil) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala index ebeb39b690e69..c3acf29c2d36f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala @@ -95,7 +95,7 @@ class SortSuite extends SparkPlanTest with SharedSQLContext { ) { test(s"sorting on $dataType with nullable=$nullable, sortOrder=$sortOrder") { val inputData = Seq.fill(1000)(randomDataGenerator()) - val inputDf = sqlContext.createDataFrame( + val inputDf = spark.createDataFrame( sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))), StructType(StructField("a", dataType, nullable = true) :: Nil) ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala index 073e0b3f00ffc..b29e822add8bc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -21,7 +21,7 @@ import scala.language.implicitConversions import scala.util.control.NonFatal import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.{DataFrame, Row, SparkSession, SQLContext} import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.test.SQLTestUtils @@ -30,7 +30,7 @@ import org.apache.spark.sql.test.SQLTestUtils * class's test helper methods can be used, see [[SortSuite]]. */ private[sql] abstract class SparkPlanTest extends SparkFunSuite { - protected def sqlContext: SQLContext + protected def spark: SparkSession /** * Runs the plan and makes sure the answer matches the expected result. 
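The pattern applied throughout these test changes is consistent: the `SQLContext` handle and `registerTempTable` calls give way to a `SparkSession` named `spark`, `createOrReplaceTempView`, and `spark.catalog`. As a rough standalone illustration of that API shift (the session settings, view name, and sample data below are placeholders, not part of the patch), a minimal sketch could look like this:

```scala
import org.apache.spark.sql.SparkSession

object TempViewMigrationSketch {
  def main(args: Array[String]): Unit = {
    // SparkSession.builder replaces constructing a SQLContext by hand.
    val spark = SparkSession.builder
      .master("local[*]")
      .appName("temp-view-sketch")
      .getOrCreate()
    import spark.implicits._

    // createOrReplaceTempView supersedes the deprecated registerTempTable.
    val df = Seq((1, "a"), (2, "b")).toDF("key", "value")
    df.createOrReplaceTempView("kv")

    // spark.table and spark.sql stand in for sqlContext.table / sqlContext.sql.
    spark.table("kv").show()
    spark.sql("SELECT key FROM kv WHERE key = 1").show()

    // spark.catalog.dropTempView / clearCache replace the old
    // sqlContext.dropTempTable / sqlContext.clearCache calls seen above.
    spark.catalog.dropTempView("kv")
    spark.catalog.clearCache()
    spark.stop()
  }
}
```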
@@ -90,9 +90,10 @@ private[sql] abstract class SparkPlanTest extends SparkFunSuite { planFunction: Seq[SparkPlan] => SparkPlan, expectedAnswer: Seq[Row], sortAnswers: Boolean = true): Unit = { - SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer, sortAnswers, sqlContext) match { - case Some(errorMessage) => fail(errorMessage) - case None => + SparkPlanTest + .checkAnswer(input, planFunction, expectedAnswer, sortAnswers, spark.sqlContext) match { + case Some(errorMessage) => fail(errorMessage) + case None => } } @@ -114,7 +115,7 @@ private[sql] abstract class SparkPlanTest extends SparkFunSuite { expectedPlanFunction: SparkPlan => SparkPlan, sortAnswers: Boolean = true): Unit = { SparkPlanTest.checkAnswer( - input, planFunction, expectedPlanFunction, sortAnswers, sqlContext) match { + input, planFunction, expectedPlanFunction, sortAnswers, spark.sqlContext) match { case Some(errorMessage) => fail(errorMessage) case None => } @@ -141,13 +142,13 @@ object SparkPlanTest { planFunction: SparkPlan => SparkPlan, expectedPlanFunction: SparkPlan => SparkPlan, sortAnswers: Boolean, - sqlContext: SQLContext): Option[String] = { + spark: SQLContext): Option[String] = { val outputPlan = planFunction(input.queryExecution.sparkPlan) val expectedOutputPlan = expectedPlanFunction(input.queryExecution.sparkPlan) val expectedAnswer: Seq[Row] = try { - executePlan(expectedOutputPlan, sqlContext) + executePlan(expectedOutputPlan, spark) } catch { case NonFatal(e) => val errorMessage = @@ -162,7 +163,7 @@ object SparkPlanTest { } val actualAnswer: Seq[Row] = try { - executePlan(outputPlan, sqlContext) + executePlan(outputPlan, spark) } catch { case NonFatal(e) => val errorMessage = @@ -202,12 +203,12 @@ object SparkPlanTest { planFunction: Seq[SparkPlan] => SparkPlan, expectedAnswer: Seq[Row], sortAnswers: Boolean, - sqlContext: SQLContext): Option[String] = { + spark: SQLContext): Option[String] = { val outputPlan = planFunction(input.map(_.queryExecution.sparkPlan)) val sparkAnswer: Seq[Row] = try { - executePlan(outputPlan, sqlContext) + executePlan(outputPlan, spark) } catch { case NonFatal(e) => val errorMessage = @@ -230,8 +231,13 @@ object SparkPlanTest { } } - private def executePlan(outputPlan: SparkPlan, sqlContext: SQLContext): Seq[Row] = { - val execution = new QueryExecution(sqlContext.sparkSession, null) { + /** + * Runs the plan + * @param outputPlan SparkPlan to be executed + * @param spark SqlContext used for execution of the plan + */ + def executePlan(outputPlan: SparkPlan, spark: SQLContext): Seq[Row] = { + val execution = new QueryExecution(spark.sparkSession, null) { override lazy val sparkPlan: SparkPlan = outputPlan transform { case plan: SparkPlan => val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlannerSuite.scala new file mode 100644 index 0000000000000..aecfd3062147c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlannerSuite.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.Strategy +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LocalRelation, LogicalPlan, ReturnAnswer, Union} +import org.apache.spark.sql.test.SharedSQLContext + +class SparkPlannerSuite extends SharedSQLContext { + import testImplicits._ + + test("Ensure to go down only the first branch, not any other possible branches") { + + case object NeverPlanned extends LeafNode { + override def output: Seq[Attribute] = Nil + } + + var planned = 0 + object TestStrategy extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case ReturnAnswer(child) => + planned += 1 + planLater(child) :: planLater(NeverPlanned) :: Nil + case Union(children) => + planned += 1 + UnionExec(children.map(planLater)) :: planLater(NeverPlanned) :: Nil + case LocalRelation(output, data) => + planned += 1 + LocalTableScanExec(output, data) :: planLater(NeverPlanned) :: Nil + case NeverPlanned => + fail("QueryPlanner should not go down to this branch.") + case _ => Nil + } + } + + try { + spark.experimental.extraStrategies = TestStrategy :: Nil + + val ds = Seq("a", "b", "c").toDS().union(Seq("d", "e", "f").toDS()) + + assert(ds.collect().toSeq === Seq("a", "b", "c", "d", "e", "f")) + assert(planned === 4) + } finally { + spark.experimental.extraStrategies = Nil + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala new file mode 100644 index 0000000000000..8161c08b2cb48 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.catalyst.FunctionIdentifier +import org.apache.spark.sql.catalyst.parser.ParseException +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.command.{DescribeFunctionCommand, ShowFunctionsCommand} +import org.apache.spark.sql.internal.SQLConf + +/** + * Parser test cases for rules defined in [[SparkSqlParser]]. 
+ * + * See [[org.apache.spark.sql.catalyst.parser.PlanParserSuite]] for rules + * defined in the Catalyst module. + */ +class SparkSqlParserSuite extends PlanTest { + + private lazy val parser = new SparkSqlParser(new SQLConf) + + private def assertEqual(sqlCommand: String, plan: LogicalPlan): Unit = { + comparePlans(parser.parsePlan(sqlCommand), plan) + } + + private def intercept(sqlCommand: String, messages: String*): Unit = { + val e = intercept[ParseException](parser.parsePlan(sqlCommand)) + messages.foreach { message => + assert(e.message.contains(message)) + } + } + + test("show functions") { + assertEqual("show functions", ShowFunctionsCommand(None, None, true, true)) + assertEqual("show all functions", ShowFunctionsCommand(None, None, true, true)) + assertEqual("show user functions", ShowFunctionsCommand(None, None, true, false)) + assertEqual("show system functions", ShowFunctionsCommand(None, None, false, true)) + intercept("show special functions", "SHOW special FUNCTIONS") + assertEqual("show functions foo", + ShowFunctionsCommand(None, Some("foo"), true, true)) + assertEqual("show functions foo.bar", + ShowFunctionsCommand(Some("foo"), Some("bar"), true, true)) + assertEqual("show functions 'foo\\\\.*'", + ShowFunctionsCommand(None, Some("foo\\.*"), true, true)) + intercept("show functions foo.bar.baz", "Unsupported function name") + } + + test("describe function") { + assertEqual("describe function bar", + DescribeFunctionCommand(FunctionIdentifier("bar", database = None), isExtended = false)) + assertEqual("describe function extended bar", + DescribeFunctionCommand(FunctionIdentifier("bar", database = None), isExtended = true)) + assertEqual("describe function foo.bar", + DescribeFunctionCommand( + FunctionIdentifier("bar", database = Option("foo")), isExtended = false)) + assertEqual("describe function extended f.bar", + DescribeFunctionCommand(FunctionIdentifier("bar", database = Option("f")), isExtended = true)) + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala index fba04d0cb2653..7e317a4d80265 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala @@ -42,7 +42,7 @@ class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSQLContext { .add("a", IntegerType, nullable = false) .add("b", IntegerType, nullable = false) val inputData = Seq.fill(10000)(Row(rand.nextInt(), rand.nextInt())) - sqlContext.createDataFrame(sparkContext.parallelize(Random.shuffle(inputData), 10), schema) + spark.createDataFrame(sparkContext.parallelize(Random.shuffle(inputData), 10), schema) } /** @@ -59,7 +59,7 @@ class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSQLContext { checkThatPlansAgree( generateRandomInputData(), input => - noOpFilter(TakeOrderedAndProjectExec(limit, sortOrder, None, input)), + noOpFilter(TakeOrderedAndProjectExec(limit, sortOrder, input.output, input)), input => GlobalLimitExec(limit, LocalLimitExec(limit, @@ -74,7 +74,7 @@ class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSQLContext { generateRandomInputData(), input => noOpFilter( - TakeOrderedAndProjectExec(limit, sortOrder, Some(Seq(input.output.last)), input)), + TakeOrderedAndProjectExec(limit, sortOrder, Seq(input.output.last), input)), input => GlobalLimitExec(limit, 
LocalLimitExec(limit, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala index 03d4be8ee528e..3d869c77e9608 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{InterpretedOrdering, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ +import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter /** * Test suite for [[UnsafeKVExternalSorter]], with randomly generated test data. @@ -123,7 +124,8 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext { metricsSystem = null)) val sorter = new UnsafeKVExternalSorter( - keySchema, valueSchema, SparkEnv.get.blockManager, SparkEnv.get.serializerManager, pageSize) + keySchema, valueSchema, SparkEnv.get.blockManager, SparkEnv.get.serializerManager, + pageSize, UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD) // Insert the keys and values into the sorter inputData.foreach { case (k, v) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 233104ae84fdf..f26e5e7b6990d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -18,9 +18,9 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.Row -import org.apache.spark.sql.execution.aggregate.TungstenAggregate +import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec -import org.apache.spark.sql.expressions.scala.typed +import org.apache.spark.sql.expressions.scalalang.typed import org.apache.spark.sql.functions.{avg, broadcast, col, max} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StringType, StructType} @@ -28,35 +28,35 @@ import org.apache.spark.sql.types.{IntegerType, StringType, StructType} class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { test("range/filter should be combined") { - val df = sqlContext.range(10).filter("id = 1").selectExpr("id + 1") + val df = spark.range(10).filter("id = 1").selectExpr("id + 1") val plan = df.queryExecution.executedPlan assert(plan.find(_.isInstanceOf[WholeStageCodegenExec]).isDefined) assert(df.collect() === Array(Row(2))) } test("Aggregate should be included in WholeStageCodegen") { - val df = sqlContext.range(10).groupBy().agg(max(col("id")), avg(col("id"))) + val df = spark.range(10).groupBy().agg(max(col("id")), avg(col("id"))) val plan = df.queryExecution.executedPlan assert(plan.find(p => p.isInstanceOf[WholeStageCodegenExec] && - p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[TungstenAggregate]).isDefined) + p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined) assert(df.collect() === Array(Row(9, 4.5))) } test("Aggregate with grouping keys should be included in WholeStageCodegen") { - val df = 
sqlContext.range(3).groupBy("id").count().orderBy("id") + val df = spark.range(3).groupBy("id").count().orderBy("id") val plan = df.queryExecution.executedPlan assert(plan.find(p => p.isInstanceOf[WholeStageCodegenExec] && - p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[TungstenAggregate]).isDefined) + p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined) assert(df.collect() === Array(Row(0, 1), Row(1, 1), Row(2, 1))) } test("BroadcastHashJoin should be included in WholeStageCodegen") { - val rdd = sqlContext.sparkContext.makeRDD(Seq(Row(1, "1"), Row(1, "1"), Row(2, "2"))) + val rdd = spark.sparkContext.makeRDD(Seq(Row(1, "1"), Row(1, "1"), Row(2, "2"))) val schema = new StructType().add("k", IntegerType).add("v", StringType) - val smallDF = sqlContext.createDataFrame(rdd, schema) - val df = sqlContext.range(10).join(broadcast(smallDF), col("k") === col("id")) + val smallDF = spark.createDataFrame(rdd, schema) + val df = spark.range(10).join(broadcast(smallDF), col("k") === col("id")) assert(df.queryExecution.executedPlan.find(p => p.isInstanceOf[WholeStageCodegenExec] && p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[BroadcastHashJoinExec]).isDefined) @@ -64,7 +64,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { } test("Sort should be included in WholeStageCodegen") { - val df = sqlContext.range(3, 0, -1).toDF().sort(col("id")) + val df = spark.range(3, 0, -1).toDF().sort(col("id")) val plan = df.queryExecution.executedPlan assert(plan.find(p => p.isInstanceOf[WholeStageCodegenExec] && @@ -75,7 +75,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { test("MapElements should be included in WholeStageCodegen") { import testImplicits._ - val ds = sqlContext.range(10).map(_.toString) + val ds = spark.range(10).map(_.toString) val plan = ds.queryExecution.executedPlan assert(plan.find(p => p.isInstanceOf[WholeStageCodegenExec] && @@ -84,7 +84,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { } test("typed filter should be included in WholeStageCodegen") { - val ds = sqlContext.range(10).filter(_ % 2 == 0) + val ds = spark.range(10).filter(_ % 2 == 0) val plan = ds.queryExecution.executedPlan assert(plan.find(p => p.isInstanceOf[WholeStageCodegenExec] && @@ -93,11 +93,11 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { } test("back-to-back typed filter should be included in WholeStageCodegen") { - val ds = sqlContext.range(10).filter(_ % 2 == 0).filter(_ % 3 == 0) + val ds = spark.range(10).filter(_ % 2 == 0).filter(_ % 3 == 0) val plan = ds.queryExecution.executedPlan assert(plan.find(p => p.isInstanceOf[WholeStageCodegenExec] && - p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SerializeFromObjectExec]).isDefined) + p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec]).isDefined) assert(ds.collect() === Array(0, 6)) } @@ -110,7 +110,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { val plan = ds.queryExecution.executedPlan assert(plan.find(p => p.isInstanceOf[WholeStageCodegenExec] && - p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[TungstenAggregate]).isDefined) + p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined) assert(ds.collect() === Array(("a", 10.0), ("b", 3.0), ("c", 1.0))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala new file mode 100644 index 0000000000000..bf3a39c84b3b2 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala @@ -0,0 +1,579 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.benchmark + +import java.util.HashMap + +import org.apache.spark.SparkConf +import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager} +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.joins.LongToUnsafeRowMap +import org.apache.spark.sql.execution.vectorized.AggregateHashMap +import org.apache.spark.sql.types.{LongType, StructType} +import org.apache.spark.unsafe.Platform +import org.apache.spark.unsafe.hash.Murmur3_x86_32 +import org.apache.spark.unsafe.map.BytesToBytesMap +import org.apache.spark.util.Benchmark + +/** + * Benchmark to measure performance for aggregate primitives. + * To run this: + * build/sbt "sql/test-only *benchmark.AggregateBenchmark" + * + * Benchmarks in this file are skipped in normal builds. 
+ */ +class AggregateBenchmark extends BenchmarkBase { + + ignore("aggregate without grouping") { + val N = 500L << 22 + val benchmark = new Benchmark("agg without grouping", N) + runBenchmark("agg w/o group", N) { + sparkSession.range(N).selectExpr("sum(id)").collect() + } + /* + agg w/o group: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + agg w/o group wholestage off 30136 / 31885 69.6 14.4 1.0X + agg w/o group wholestage on 1851 / 1860 1132.9 0.9 16.3X + */ + } + + ignore("stat functions") { + val N = 100L << 20 + + runBenchmark("stddev", N) { + sparkSession.range(N).groupBy().agg("id" -> "stddev").collect() + } + + runBenchmark("kurtosis", N) { + sparkSession.range(N).groupBy().agg("id" -> "kurtosis").collect() + } + + /* + Using ImperativeAggregate (as implemented in Spark 1.6): + + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + stddev: Avg Time(ms) Avg Rate(M/s) Relative Rate + ------------------------------------------------------------------------------- + stddev w/o codegen 2019.04 10.39 1.00 X + stddev w codegen 2097.29 10.00 0.96 X + kurtosis w/o codegen 2108.99 9.94 0.96 X + kurtosis w codegen 2090.69 10.03 0.97 X + + Using DeclarativeAggregate: + + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + stddev: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + stddev codegen=false 5630 / 5776 18.0 55.6 1.0X + stddev codegen=true 1259 / 1314 83.0 12.0 4.5X + + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + kurtosis: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + kurtosis codegen=false 14847 / 15084 7.0 142.9 1.0X + kurtosis codegen=true 1652 / 2124 63.0 15.9 9.0X + */ + } + + ignore("aggregate with linear keys") { + val N = 20 << 22 + + val benchmark = new Benchmark("Aggregate w keys", N) + def f(): Unit = { + sparkSession.range(N).selectExpr("(id & 65535) as k").groupBy("k").sum().collect() + } + + benchmark.addCase(s"codegen = F", numIters = 2) { iter => + sparkSession.conf.set("spark.sql.codegen.wholeStage", "false") + f() + } + + benchmark.addCase(s"codegen = T hashmap = F", numIters = 3) { iter => + sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.columns.max", "0") + f() + } + + benchmark.addCase(s"codegen = T hashmap = T", numIters = 5) { iter => + sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.columns.max", "3") + f() + } + + benchmark.run() + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.11 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz + + Aggregate w keys: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + codegen = F 6619 / 6780 12.7 78.9 1.0X + codegen = T hashmap = F 3935 / 4059 21.3 46.9 1.7X + codegen = T hashmap = T 897 / 971 93.5 10.7 7.4X + */ + } + + ignore("aggregate with randomized keys") { + val N = 20 << 22 + + val benchmark = new Benchmark("Aggregate w keys", N) + sparkSession.range(N).selectExpr("id", "floor(rand() * 10000) as k") + .createOrReplaceTempView("test") + + def f(): Unit = sparkSession.sql("select k, k, sum(id) from test group by k, k").collect() + + benchmark.addCase(s"codegen = F", numIters = 2) { iter => + 
sparkSession.conf.set("spark.sql.codegen.wholeStage", value = false) + f() + } + + benchmark.addCase(s"codegen = T hashmap = F", numIters = 3) { iter => + sparkSession.conf.set("spark.sql.codegen.wholeStage", value = true) + sparkSession.conf.set("spark.sql.codegen.aggregate.map.columns.max", 0) + f() + } + + benchmark.addCase(s"codegen = T hashmap = T", numIters = 5) { iter => + sparkSession.conf.set("spark.sql.codegen.wholeStage", value = true) + sparkSession.conf.set("spark.sql.codegen.aggregate.map.columns.max", 3) + f() + } + + benchmark.run() + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.11 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz + + Aggregate w keys: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + codegen = F 7445 / 7517 11.3 88.7 1.0X + codegen = T hashmap = F 4672 / 4703 18.0 55.7 1.6X + codegen = T hashmap = T 1764 / 1958 47.6 21.0 4.2X + */ + } + + ignore("aggregate with string key") { + val N = 20 << 20 + + val benchmark = new Benchmark("Aggregate w string key", N) + def f(): Unit = sparkSession.range(N).selectExpr("id", "cast(id & 1023 as string) as k") + .groupBy("k").count().collect() + + benchmark.addCase(s"codegen = F", numIters = 2) { iter => + sparkSession.conf.set("spark.sql.codegen.wholeStage", "false") + f() + } + + benchmark.addCase(s"codegen = T hashmap = F", numIters = 3) { iter => + sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.columns.max", "0") + f() + } + + benchmark.addCase(s"codegen = T hashmap = T", numIters = 5) { iter => + sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.columns.max", "3") + f() + } + + benchmark.run() + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_73-b02 on Mac OS X 10.11.4 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz + Aggregate w string key: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + codegen = F 3307 / 3376 6.3 157.7 1.0X + codegen = T hashmap = F 2364 / 2471 8.9 112.7 1.4X + codegen = T hashmap = T 1740 / 1841 12.0 83.0 1.9X + */ + } + + ignore("aggregate with decimal key") { + val N = 20 << 20 + + val benchmark = new Benchmark("Aggregate w decimal key", N) + def f(): Unit = sparkSession.range(N).selectExpr("id", "cast(id & 65535 as decimal) as k") + .groupBy("k").count().collect() + + benchmark.addCase(s"codegen = F") { iter => + sparkSession.conf.set("spark.sql.codegen.wholeStage", "false") + f() + } + + benchmark.addCase(s"codegen = T hashmap = F") { iter => + sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.columns.max", "0") + f() + } + + benchmark.addCase(s"codegen = T hashmap = T") { iter => + sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.columns.max", "3") + f() + } + + benchmark.run() + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_73-b02 on Mac OS X 10.11.4 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz + Aggregate w decimal key: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + codegen = F 2756 / 2817 7.6 131.4 1.0X + codegen = T hashmap = F 1580 / 1647 13.3 75.4 1.7X + codegen = T hashmap = T 641 / 662 32.7 30.6 4.3X + */ 
+ } + + ignore("aggregate with multiple key types") { + val N = 20 << 20 + + val benchmark = new Benchmark("Aggregate w multiple keys", N) + def f(): Unit = sparkSession.range(N) + .selectExpr( + "id", + "(id & 1023) as k1", + "cast(id & 1023 as string) as k2", + "cast(id & 1023 as int) as k3", + "cast(id & 1023 as double) as k4", + "cast(id & 1023 as float) as k5", + "id > 1023 as k6") + .groupBy("k1", "k2", "k3", "k4", "k5", "k6") + .sum() + .collect() + + benchmark.addCase(s"codegen = F") { iter => + sparkSession.conf.set("spark.sql.codegen.wholeStage", "false") + f() + } + + benchmark.addCase(s"codegen = T hashmap = F") { iter => + sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.columns.max", "0") + f() + } + + benchmark.addCase(s"codegen = T hashmap = T") { iter => + sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.columns.max", "10") + f() + } + + benchmark.run() + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_73-b02 on Mac OS X 10.11.4 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz + Aggregate w decimal key: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + codegen = F 5885 / 6091 3.6 280.6 1.0X + codegen = T hashmap = F 3625 / 4009 5.8 172.8 1.6X + codegen = T hashmap = T 3204 / 3271 6.5 152.8 1.8X + */ + } + + + ignore("cube") { + val N = 5 << 20 + + runBenchmark("cube", N) { + sparkSession.range(N).selectExpr("id", "id % 1000 as k1", "id & 256 as k2") + .cube("k1", "k2").sum("id").collect() + } + + /** + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + cube: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + cube codegen=false 3188 / 3392 1.6 608.2 1.0X + cube codegen=true 1239 / 1394 4.2 236.3 2.6X + */ + } + + ignore("hash and BytesToBytesMap") { + val N = 20 << 20 + + val benchmark = new Benchmark("BytesToBytesMap", N) + + benchmark.addCase("UnsafeRowhash") { iter => + var i = 0 + val keyBytes = new Array[Byte](16) + val key = new UnsafeRow(1) + key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) + var s = 0 + while (i < N) { + key.setInt(0, i % 1000) + val h = Murmur3_x86_32.hashUnsafeWords( + key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, 42) + s += h + i += 1 + } + } + + benchmark.addCase("murmur3 hash") { iter => + var i = 0 + val keyBytes = new Array[Byte](16) + val key = new UnsafeRow(1) + key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) + var p = 524283 + var s = 0 + while (i < N) { + var h = Murmur3_x86_32.hashLong(i, 42) + key.setInt(0, h) + s += h + i += 1 + } + } + + benchmark.addCase("fast hash") { iter => + var i = 0 + val keyBytes = new Array[Byte](16) + val key = new UnsafeRow(1) + key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) + var p = 524283 + var s = 0 + while (i < N) { + var h = i % p + if (h < 0) { + h += p + } + key.setInt(0, h) + s += h + i += 1 + } + } + + benchmark.addCase("arrayEqual") { iter => + var i = 0 + val keyBytes = new Array[Byte](16) + val valueBytes = new Array[Byte](16) + val key = new UnsafeRow(1) + key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) + val value = new UnsafeRow(1) + value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) + value.setInt(0, 555) + var s = 0 + while (i < N) { + key.setInt(0, i % 1000) + if (key.equals(value)) { + s += 1 + } + i += 1 + } + } + + 
benchmark.addCase("Java HashMap (Long)") { iter => + var i = 0 + val keyBytes = new Array[Byte](16) + val valueBytes = new Array[Byte](16) + val value = new UnsafeRow(1) + value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) + value.setInt(0, 555) + val map = new HashMap[Long, UnsafeRow]() + while (i < 65536) { + value.setInt(0, i) + map.put(i.toLong, value) + i += 1 + } + var s = 0 + i = 0 + while (i < N) { + if (map.get(i % 100000) != null) { + s += 1 + } + i += 1 + } + } + + benchmark.addCase("Java HashMap (two ints) ") { iter => + var i = 0 + val valueBytes = new Array[Byte](16) + val value = new UnsafeRow(1) + value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) + value.setInt(0, 555) + val map = new HashMap[Long, UnsafeRow]() + while (i < 65536) { + value.setInt(0, i) + val key = (i.toLong << 32) + Integer.rotateRight(i, 15) + map.put(key, value) + i += 1 + } + var s = 0 + i = 0 + while (i < N) { + val key = ((i & 100000).toLong << 32) + Integer.rotateRight(i & 100000, 15) + if (map.get(key) != null) { + s += 1 + } + i += 1 + } + } + + benchmark.addCase("Java HashMap (UnsafeRow)") { iter => + var i = 0 + val keyBytes = new Array[Byte](16) + val valueBytes = new Array[Byte](16) + val key = new UnsafeRow(1) + key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) + val value = new UnsafeRow(1) + value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) + value.setInt(0, 555) + val map = new HashMap[UnsafeRow, UnsafeRow]() + while (i < 65536) { + key.setInt(0, i) + value.setInt(0, i) + map.put(key, value.copy()) + i += 1 + } + var s = 0 + i = 0 + while (i < N) { + key.setInt(0, i % 100000) + if (map.get(key) != null) { + s += 1 + } + i += 1 + } + } + + Seq(false, true).foreach { optimized => + benchmark.addCase(s"LongToUnsafeRowMap (opt=$optimized)") { iter => + var i = 0 + val valueBytes = new Array[Byte](16) + val value = new UnsafeRow(1) + value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) + value.setInt(0, 555) + val taskMemoryManager = new TaskMemoryManager( + new StaticMemoryManager( + new SparkConf().set("spark.memory.offHeap.enabled", "false"), + Long.MaxValue, + Long.MaxValue, + 1), + 0) + val map = new LongToUnsafeRowMap(taskMemoryManager, 64) + while (i < 65536) { + value.setInt(0, i) + val key = i % 100000 + map.append(key, value) + i += 1 + } + if (optimized) { + map.optimize() + } + var s = 0 + i = 0 + while (i < N) { + val key = i % 100000 + if (map.getValue(key, value) != null) { + s += 1 + } + i += 1 + } + } + } + + Seq("off", "on").foreach { heap => + benchmark.addCase(s"BytesToBytesMap ($heap Heap)") { iter => + val taskMemoryManager = new TaskMemoryManager( + new StaticMemoryManager( + new SparkConf().set("spark.memory.offHeap.enabled", s"${heap == "off"}") + .set("spark.memory.offHeap.size", "102400000"), + Long.MaxValue, + Long.MaxValue, + 1), + 0) + val map = new BytesToBytesMap(taskMemoryManager, 1024, 64L<<20) + val keyBytes = new Array[Byte](16) + val valueBytes = new Array[Byte](16) + val key = new UnsafeRow(1) + key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) + val value = new UnsafeRow(1) + value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) + var i = 0 + val numKeys = 65536 + while (i < numKeys) { + key.setInt(0, i % 65536) + val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, + Murmur3_x86_32.hashLong(i % 65536, 42)) + if (!loc.isDefined) { + loc.append(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, + value.getBaseObject, value.getBaseOffset, value.getSizeInBytes) + } + i += 1 + } + i = 0 + 
var s = 0 + while (i < N) { + key.setInt(0, i % 100000) + val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, + Murmur3_x86_32.hashLong(i % 100000, 42)) + if (loc.isDefined) { + s += 1 + } + i += 1 + } + } + } + + benchmark.addCase("Aggregate HashMap") { iter => + var i = 0 + val numKeys = 65536 + val schema = new StructType() + .add("key", LongType) + .add("value", LongType) + val map = new AggregateHashMap(schema) + while (i < numKeys) { + val row = map.findOrInsert(i.toLong) + row.setLong(1, row.getLong(1) + 1) + i += 1 + } + var s = 0 + i = 0 + while (i < N) { + if (map.find(i % 100000) != -1) { + s += 1 + } + i += 1 + } + } + + /* + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz + BytesToBytesMap: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + UnsafeRow hash 267 / 284 78.4 12.8 1.0X + murmur3 hash 102 / 129 205.5 4.9 2.6X + fast hash 79 / 96 263.8 3.8 3.4X + arrayEqual 164 / 172 128.2 7.8 1.6X + Java HashMap (Long) 321 / 399 65.4 15.3 0.8X + Java HashMap (two ints) 328 / 363 63.9 15.7 0.8X + Java HashMap (UnsafeRow) 1140 / 1200 18.4 54.3 0.2X + LongToUnsafeRowMap (opt=false) 378 / 400 55.5 18.0 0.7X + LongToUnsafeRowMap (opt=true) 144 / 152 145.2 6.9 1.9X + BytesToBytesMap (off Heap) 1300 / 1616 16.1 62.0 0.2X + BytesToBytesMap (on Heap) 1165 / 1202 18.0 55.5 0.2X + Aggregate HashMap 121 / 131 173.3 5.8 2.2X + */ + benchmark.run() + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkBase.scala new file mode 100644 index 0000000000000..c99a5aec1cd6e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkBase.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.benchmark + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.SparkSession +import org.apache.spark.util.Benchmark + +/** + * Common base trait for micro benchmarks that are supposed to run standalone (i.e. not together + * with other test suites). + */ +private[benchmark] trait BenchmarkBase extends SparkFunSuite { + + lazy val sparkSession = SparkSession.builder + .master("local[1]") + .appName("microbenchmark") + .config("spark.sql.shuffle.partitions", 1) + .config("spark.sql.autoBroadcastJoinThreshold", 1) + .getOrCreate() + + /** Runs function `f` with whole stage codegen on and off. 
*/ + def runBenchmark(name: String, cardinality: Long)(f: => Unit): Unit = { + val benchmark = new Benchmark(name, cardinality) + + benchmark.addCase(s"$name wholestage off", numIters = 2) { iter => + sparkSession.conf.set("spark.sql.codegen.wholeStage", value = false) + f + } + + benchmark.addCase(s"$name wholestage on", numIters = 5) { iter => + sparkSession.conf.set("spark.sql.codegen.wholeStage", value = true) + f + } + + benchmark.run() + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkWideTable.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkWideTable.scala new file mode 100644 index 0000000000000..9dcaca0ca93ee --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/BenchmarkWideTable.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.benchmark + +import org.apache.spark.util.Benchmark + + +/** + * Benchmark to measure performance for wide table. + * To run this: + * build/sbt "sql/test-only *benchmark.BenchmarkWideTable" + * + * Benchmarks in this file are skipped in normal builds. + */ +class BenchmarkWideTable extends BenchmarkBase { + + ignore("project on wide table") { + val N = 1 << 20 + val df = sparkSession.range(N) + val columns = (0 until 400).map{ i => s"id as id$i"} + val benchmark = new Benchmark("projection on wide table", N) + benchmark.addCase("wide table", numIters = 5) { iter => + df.selectExpr(columns : _*).queryExecution.toRdd.count() + } + benchmark.run() + + /** + * Here are some numbers with different split threshold: + * + * Split threshold methods Rate(M/s) Per Row(ns) + * 10 400 0.4 2279 + * 100 200 0.6 1554 + * 1k 37 0.9 1116 + * 8k 5 0.5 2025 + * 64k 1 0.0 21649 + */ + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/JoinBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/JoinBenchmark.scala new file mode 100644 index 0000000000000..46db41a8abad9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/JoinBenchmark.scala @@ -0,0 +1,229 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.benchmark + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.IntegerType + +/** + * Benchmark to measure performance for aggregate primitives. + * To run this: + * build/sbt "sql/test-only *benchmark.JoinBenchmark" + * + * Benchmarks in this file are skipped in normal builds. + */ +class JoinBenchmark extends BenchmarkBase { + + ignore("broadcast hash join, long key") { + val N = 20 << 20 + val M = 1 << 16 + + val dim = broadcast(sparkSession.range(M).selectExpr("id as k", "cast(id as string) as v")) + runBenchmark("Join w long", N) { + sparkSession.range(N).join(dim, (col("id") % M) === col("k")).count() + } + + /* + Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5 + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + Join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + Join w long codegen=false 3002 / 3262 7.0 143.2 1.0X + Join w long codegen=true 321 / 371 65.3 15.3 9.3X + */ + } + + ignore("broadcast hash join, long key with duplicates") { + val N = 20 << 20 + val M = 1 << 16 + + val dim = broadcast(sparkSession.range(M).selectExpr("id as k", "cast(id as string) as v")) + runBenchmark("Join w long duplicated", N) { + val dim = broadcast(sparkSession.range(M).selectExpr("cast(id/10 as long) as k")) + sparkSession.range(N).join(dim, (col("id") % M) === col("k")).count() + } + + /* + *Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5 + *Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + *Join w long duplicated: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + *------------------------------------------------------------------------------------------- + *Join w long duplicated codegen=false 3446 / 3478 6.1 164.3 1.0X + *Join w long duplicated codegen=true 322 / 351 65.2 15.3 10.7X + */ + } + + ignore("broadcast hash join, two int key") { + val N = 20 << 20 + val M = 1 << 16 + val dim2 = broadcast(sparkSession.range(M) + .selectExpr("cast(id as int) as k1", "cast(id as int) as k2", "cast(id as string) as v")) + + runBenchmark("Join w 2 ints", N) { + sparkSession.range(N).join(dim2, + (col("id") % M).cast(IntegerType) === col("k1") + && (col("id") % M).cast(IntegerType) === col("k2")).count() + } + + /* + *Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5 + *Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + *Join w 2 ints: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + *------------------------------------------------------------------------------------------- + *Join w 2 ints codegen=false 4426 / 4501 4.7 211.1 1.0X + *Join w 2 ints codegen=true 791 / 818 26.5 37.7 5.6X + */ + } + + ignore("broadcast hash join, two long key") { + val N = 20 << 20 + val M = 1 << 16 + val dim3 = broadcast(sparkSession.range(M) + .selectExpr("id as k1", "id as k2", "cast(id as string) as v")) + + runBenchmark("Join w 2 longs", N) { + sparkSession.range(N).join(dim3, + (col("id") % M) === col("k1") && (col("id") % M) === col("k2")) + .count() + } + + /* + *Java HotSpot(TM) 64-Bit Server VM 
1.7.0_60-b19 on Mac OS X 10.9.5 + *Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + *Join w 2 longs: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + *------------------------------------------------------------------------------------------- + *Join w 2 longs codegen=false 5905 / 6123 3.6 281.6 1.0X + *Join w 2 longs codegen=true 2230 / 2529 9.4 106.3 2.6X + */ + } + + ignore("broadcast hash join, two long key with duplicates") { + val N = 20 << 20 + val M = 1 << 16 + val dim4 = broadcast(sparkSession.range(M) + .selectExpr("cast(id/10 as long) as k1", "cast(id/10 as long) as k2")) + + runBenchmark("Join w 2 longs duplicated", N) { + sparkSession.range(N).join(dim4, + (col("id") bitwiseAND M) === col("k1") && (col("id") bitwiseAND M) === col("k2")) + .count() + } + + /* + *Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + *Join w 2 longs duplicated: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + *------------------------------------------------------------------------------------------- + *Join w 2 longs duplicated codegen=false 6420 / 6587 3.3 306.1 1.0X + *Join w 2 longs duplicated codegen=true 2080 / 2139 10.1 99.2 3.1X + */ + } + + ignore("broadcast hash join, outer join long key") { + val N = 20 << 20 + val M = 1 << 16 + val dim = broadcast(sparkSession.range(M).selectExpr("id as k", "cast(id as string) as v")) + runBenchmark("outer join w long", N) { + sparkSession.range(N).join(dim, (col("id") % M) === col("k"), "left").count() + } + + /* + *Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5 + *Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + *outer join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + *------------------------------------------------------------------------------------------- + *outer join w long codegen=false 3055 / 3189 6.9 145.7 1.0X + *outer join w long codegen=true 261 / 276 80.5 12.4 11.7X + */ + } + + ignore("broadcast hash join, semi join long key") { + val N = 20 << 20 + val M = 1 << 16 + val dim = broadcast(sparkSession.range(M).selectExpr("id as k", "cast(id as string) as v")) + runBenchmark("semi join w long", N) { + sparkSession.range(N).join(dim, (col("id") % M) === col("k"), "leftsemi").count() + } + + /* + *Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5 + *Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + *semi join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + *------------------------------------------------------------------------------------------- + *semi join w long codegen=false 1912 / 1990 11.0 91.2 1.0X + *semi join w long codegen=true 237 / 244 88.3 11.3 8.1X + */ + } + + ignore("sort merge join") { + val N = 2 << 20 + runBenchmark("merge join", N) { + val df1 = sparkSession.range(N).selectExpr(s"id * 2 as k1") + val df2 = sparkSession.range(N).selectExpr(s"id * 3 as k2") + df1.join(df2, col("k1") === col("k2")).count() + } + + /* + *Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + *merge join: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + *------------------------------------------------------------------------------------------- + *merge join codegen=false 1588 / 1880 1.3 757.1 1.0X + *merge join codegen=true 1477 / 1531 1.4 704.2 1.1X + */ + } + + ignore("sort merge join with duplicates") { + val N = 2 << 20 + runBenchmark("sort merge join", N) { + val df1 = sparkSession.range(N) + .selectExpr(s"(id * 15485863) % ${N*10} as k1") + val df2 = sparkSession.range(N) + .selectExpr(s"(id * 15485867) % ${N*10} as k2") + df1.join(df2, col("k1") === col("k2")).count() + } + + /* + *Intel(R) Core(TM) i7-4558U 
CPU @ 2.80GHz + *sort merge join: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + *------------------------------------------------------------------------------------------- + *sort merge join codegen=false 3626 / 3667 0.6 1728.9 1.0X + *sort merge join codegen=true 3405 / 3438 0.6 1623.8 1.1X + */ + } + + ignore("shuffle hash join") { + val N = 4 << 20 + sparkSession.conf.set("spark.sql.shuffle.partitions", "2") + sparkSession.conf.set("spark.sql.autoBroadcastJoinThreshold", "10000000") + sparkSession.conf.set("spark.sql.join.preferSortMergeJoin", "false") + runBenchmark("shuffle hash join", N) { + val df1 = sparkSession.range(N).selectExpr(s"id as k1") + val df2 = sparkSession.range(N / 5).selectExpr(s"id * 3 as k2") + df1.join(df2, col("k1") === col("k2")).count() + } + + /* + *Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5 + *Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + *shuffle hash join: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + *------------------------------------------------------------------------------------------- + *shuffle hash join codegen=false 1101 / 1391 3.8 262.6 1.0X + *shuffle hash join codegen=true 528 / 578 7.9 125.8 2.1X + */ + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MiscBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MiscBenchmark.scala new file mode 100644 index 0000000000000..470c78120b194 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MiscBenchmark.scala @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.benchmark + +import org.apache.spark.util.Benchmark + +/** + * Benchmark to measure whole stage codegen performance. + * To run this: + * build/sbt "sql/test-only *benchmark.MiscBenchmark" + * + * Benchmarks in this file are skipped in normal builds. 
+ */ +class MiscBenchmark extends BenchmarkBase { + + ignore("filter & aggregate without group") { + val N = 500L << 22 + runBenchmark("range/filter/sum", N) { + sparkSession.range(N).filter("(id & 1) = 1").groupBy().sum().collect() + } + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.11 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz + + range/filter/sum: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + range/filter/sum codegen=false 30663 / 31216 68.4 14.6 1.0X + range/filter/sum codegen=true 2399 / 2409 874.1 1.1 12.8X + */ + } + + ignore("range/limit/sum") { + val N = 500L << 20 + runBenchmark("range/limit/sum", N) { + sparkSession.range(N).limit(1000000).groupBy().sum().collect() + } + /* + Westmere E56xx/L56xx/X56xx (Nehalem-C) + range/limit/sum: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + range/limit/sum codegen=false 609 / 672 861.6 1.2 1.0X + range/limit/sum codegen=true 561 / 621 935.3 1.1 1.1X + */ + } + + ignore("sample") { + val N = 500 << 18 + runBenchmark("sample with replacement", N) { + sparkSession.range(N).sample(withReplacement = true, 0.01).groupBy().sum().collect() + } + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.11 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz + + sample with replacement: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + sample with replacement codegen=false 7073 / 7227 18.5 54.0 1.0X + sample with replacement codegen=true 5199 / 5203 25.2 39.7 1.4X + */ + + runBenchmark("sample without replacement", N) { + sparkSession.range(N).sample(withReplacement = false, 0.01).groupBy().sum().collect() + } + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.11 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz + + sample without replacement: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + sample without replacement codegen=false 1508 / 1529 86.9 11.5 1.0X + sample without replacement codegen=true 644 / 662 203.5 4.9 2.3X + */ + } + + ignore("collect") { + val N = 1 << 20 + + val benchmark = new Benchmark("collect", N) + benchmark.addCase("collect 1 million") { iter => + sparkSession.range(N).collect() + } + benchmark.addCase("collect 2 millions") { iter => + sparkSession.range(N * 2).collect() + } + benchmark.addCase("collect 4 millions") { iter => + sparkSession.range(N * 4).collect() + } + benchmark.run() + + /** + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + collect: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + collect 1 million 439 / 654 2.4 418.7 1.0X + collect 2 millions 961 / 1907 1.1 916.4 0.5X + collect 4 millions 3193 / 3895 0.3 3044.7 0.1X + */ + } + + ignore("collect limit") { + val N = 1 << 20 + + val benchmark = new Benchmark("collect limit", N) + benchmark.addCase("collect limit 1 million") { iter => + sparkSession.range(N * 4).limit(N).collect() + } + benchmark.addCase("collect limit 2 millions") { iter => + sparkSession.range(N * 4).limit(N * 2).collect() + } + benchmark.run() + + /** + model name : Westmere E56xx/L56xx/X56xx (Nehalem-C) + collect limit: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + 
------------------------------------------------------------------------------------------- + collect limit 1 million 833 / 1284 1.3 794.4 1.0X + collect limit 2 millions 3348 / 4005 0.3 3193.3 0.2X + */ + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala new file mode 100644 index 0000000000000..9964b7373fc20 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.benchmark + +import java.util.{Arrays, Comparator} + +import org.apache.spark.unsafe.array.LongArray +import org.apache.spark.unsafe.memory.MemoryBlock +import org.apache.spark.util.Benchmark +import org.apache.spark.util.collection.Sorter +import org.apache.spark.util.collection.unsafe.sort._ +import org.apache.spark.util.random.XORShiftRandom + +/** + * Benchmark to measure performance for sorting primitives. + * To run this: + * build/sbt "sql/test-only *benchmark.SortBenchmark" + * + * Benchmarks in this file are skipped in normal builds.
+ */ +class SortBenchmark extends BenchmarkBase { + + private def referenceKeyPrefixSort(buf: LongArray, lo: Int, hi: Int, refCmp: PrefixComparator) { + val sortBuffer = new LongArray(MemoryBlock.fromLongArray(new Array[Long](buf.size().toInt))) + new Sorter(new UnsafeSortDataFormat(sortBuffer)).sort( + buf, lo, hi, new Comparator[RecordPointerAndKeyPrefix] { + override def compare( + r1: RecordPointerAndKeyPrefix, + r2: RecordPointerAndKeyPrefix): Int = { + refCmp.compare(r1.keyPrefix, r2.keyPrefix) + } + }) + } + + private def generateKeyPrefixTestData(size: Int, rand: => Long): (LongArray, LongArray) = { + val ref = Array.tabulate[Long](size * 2) { i => rand } + val extended = ref ++ Array.fill[Long](size * 2)(0) + (new LongArray(MemoryBlock.fromLongArray(ref)), + new LongArray(MemoryBlock.fromLongArray(extended))) + } + + ignore("sort") { + val size = 25000000 + val rand = new XORShiftRandom(123) + val benchmark = new Benchmark("radix sort " + size, size) + benchmark.addTimerCase("reference TimSort key prefix array") { timer => + val array = Array.tabulate[Long](size * 2) { i => rand.nextLong } + val buf = new LongArray(MemoryBlock.fromLongArray(array)) + timer.startTiming() + referenceKeyPrefixSort(buf, 0, size, PrefixComparators.BINARY) + timer.stopTiming() + } + benchmark.addTimerCase("reference Arrays.sort") { timer => + val ref = Array.tabulate[Long](size) { i => rand.nextLong } + timer.startTiming() + Arrays.sort(ref) + timer.stopTiming() + } + benchmark.addTimerCase("radix sort one byte") { timer => + val array = new Array[Long](size * 2) + var i = 0 + while (i < size) { + array(i) = rand.nextLong & 0xff + i += 1 + } + val buf = new LongArray(MemoryBlock.fromLongArray(array)) + timer.startTiming() + RadixSort.sort(buf, size, 0, 7, false, false) + timer.stopTiming() + } + benchmark.addTimerCase("radix sort two bytes") { timer => + val array = new Array[Long](size * 2) + var i = 0 + while (i < size) { + array(i) = rand.nextLong & 0xffff + i += 1 + } + val buf = new LongArray(MemoryBlock.fromLongArray(array)) + timer.startTiming() + RadixSort.sort(buf, size, 0, 7, false, false) + timer.stopTiming() + } + benchmark.addTimerCase("radix sort eight bytes") { timer => + val array = new Array[Long](size * 2) + var i = 0 + while (i < size) { + array(i) = rand.nextLong + i += 1 + } + val buf = new LongArray(MemoryBlock.fromLongArray(array)) + timer.startTiming() + RadixSort.sort(buf, size, 0, 7, false, false) + timer.stopTiming() + } + benchmark.addTimerCase("radix sort key prefix array") { timer => + val (_, buf2) = generateKeyPrefixTestData(size, rand.nextLong) + timer.startTiming() + RadixSort.sortKeyPrefixArray(buf2, size, 0, 7, false, false) + timer.stopTiming() + } + benchmark.run() + + /* + Running benchmark: radix sort 25000000 + Java HotSpot(TM) 64-Bit Server VM 1.8.0_66-b17 on Linux 3.13.0-44-generic + Intel(R) Core(TM) i7-4600U CPU @ 2.10GHz + + radix sort 25000000: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + reference TimSort key prefix array 15546 / 15859 1.6 621.9 1.0X + reference Arrays.sort 2416 / 2446 10.3 96.6 6.4X + radix sort one byte 133 / 137 188.4 5.3 117.2X + radix sort two bytes 255 / 258 98.2 10.2 61.1X + radix sort eight bytes 991 / 997 25.2 39.6 15.7X + radix sort key prefix array 1540 / 1563 16.2 61.6 10.1X + */ + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala new file mode 100644 index 0000000000000..957a1d6426e87 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.benchmark + +import java.io.File + +import org.apache.spark.SparkConf +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.catalyst.expressions.SubqueryExpression +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.Benchmark + +/** + * Benchmark to measure TPCDS query performance. + * To run this: + * spark-submit --class <this class> --jars <spark sql test jar> + */ +object TPCDSQueryBenchmark { + val conf = + new SparkConf() + .setMaster("local[1]") + .setAppName("test-sql-context") + .set("spark.sql.parquet.compression.codec", "snappy") + .set("spark.sql.shuffle.partitions", "4") + .set("spark.driver.memory", "3g") + .set("spark.executor.memory", "3g") + .set("spark.sql.autoBroadcastJoinThreshold", (20 * 1024 * 1024).toString) + + val spark = SparkSession.builder.config(conf).getOrCreate() + + val tables = Seq("catalog_page", "catalog_returns", "customer", "customer_address", + "customer_demographics", "date_dim", "household_demographics", "inventory", "item", + "promotion", "store", "store_returns", "catalog_sales", "web_sales", "store_sales", + "web_returns", "web_site", "reason", "call_center", "warehouse", "ship_mode", "income_band", + "time_dim", "web_page") + + def setupTables(dataLocation: String): Map[String, Long] = { + tables.map { tableName => + spark.read.parquet(s"$dataLocation/$tableName").createOrReplaceTempView(tableName) + tableName -> spark.table(tableName).count() + }.toMap + } + + def tpcdsAll(dataLocation: String, queries: Seq[String]): Unit = { + require(dataLocation.nonEmpty, + "please modify the value of dataLocation to point to your local TPCDS data") + val tableSizes = setupTables(dataLocation) + queries.foreach { name => + val queryString = fileToString(new File(Thread.currentThread().getContextClassLoader + .getResource(s"tpcds/$name.sql").getFile)) + + // This is an indirect hack to estimate the size of each query's input by traversing the + // logical plan and adding up the sizes of all tables that appear in the plan. Note that this + // currently doesn't take WITH subqueries into account which might lead to fairly inaccurate + // per-row processing time for those cases.
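+ // For example, a query that reads only store_sales and date_dim is assigned numRows equal to + // count(store_sales) + count(date_dim) by the traversal below.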
+ val queryRelations = scala.collection.mutable.HashSet[String]() + spark.sql(queryString).queryExecution.logical.map { + case ur @ UnresolvedRelation(t: TableIdentifier, _) => + queryRelations.add(t.table) + case lp: LogicalPlan => + lp.expressions.foreach { _ foreach { + case subquery: SubqueryExpression => + subquery.plan.foreach { + case ur @ UnresolvedRelation(t: TableIdentifier, _) => + queryRelations.add(t.table) + case _ => + } + case _ => + } + } + case _ => + } + val numRows = queryRelations.map(tableSizes.getOrElse(_, 0L)).sum + val benchmark = new Benchmark(s"TPCDS Snappy", numRows, 5) + benchmark.addCase(name) { i => + spark.sql(queryString).collect() + } + benchmark.run() + } + } + + def main(args: Array[String]): Unit = { + + // List of all TPC-DS queries + val tpcdsQueries = Seq( + "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", + "q12", "q13", "q14a", "q14b", "q15", "q16", "q17", "q18", "q19", "q20", + "q21", "q22", "q23a", "q23b", "q24a", "q24b", "q25", "q26", "q27", "q28", "q29", "q30", + "q31", "q32", "q33", "q34", "q35", "q36", "q37", "q38", "q39a", "q39b", "q40", + "q41", "q42", "q43", "q44", "q45", "q46", "q47", "q48", "q49", "q50", + "q51", "q52", "q53", "q54", "q55", "q56", "q57", "q58", "q59", "q60", + "q61", "q62", "q63", "q64", "q65", "q66", "q67", "q68", "q69", "q70", + "q71", "q72", "q73", "q74", "q75", "q76", "q77", "q78", "q79", "q80", + "q81", "q82", "q83", "q84", "q85", "q86", "q87", "q88", "q89", "q90", + "q91", "q92", "q93", "q94", "q95", "q96", "q97", "q98", "q99") + + // In order to run this benchmark, please follow the instructions at + // https://github.com/databricks/spark-sql-perf/blob/master/README.md to generate the TPCDS data + // locally (preferably with a scale factor of 5 for benchmarking). Thereafter, the value of + // dataLocation below needs to be set to the location where the generated data is stored. + val dataLocation = "" + + tpcdsAll(dataLocation, queries = tpcdsQueries) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/WideSchemaBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/WideSchemaBenchmark.scala new file mode 100644 index 0000000000000..d2704b3d3f371 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/WideSchemaBenchmark.scala @@ -0,0 +1,221 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.io.{File, FileOutputStream, OutputStream} + +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.functions._ +import org.apache.spark.util.{Benchmark, Utils} + +/** + * Benchmark for performance with very wide and nested DataFrames. 
+ * To run this: + * build/sbt "sql/test-only *WideSchemaBenchmark" + * + * Results will be written to "sql/core/benchmarks/WideSchemaBenchmark-results.txt". + */ +class WideSchemaBenchmark extends SparkFunSuite with BeforeAndAfterEach { + private val scaleFactor = 100000 + private val widthsToTest = Seq(1, 100, 2500) + private val depthsToTest = Seq(1, 100, 250) + assert(scaleFactor > widthsToTest.max) + + private lazy val sparkSession = SparkSession.builder + .master("local[1]") + .appName("microbenchmark") + .getOrCreate() + + import sparkSession.implicits._ + + private var tmpFiles: List[File] = Nil + private var out: OutputStream = null + + override def beforeAll() { + super.beforeAll() + out = new FileOutputStream(new File("benchmarks/WideSchemaBenchmark-results.txt")) + } + + override def afterAll() { + super.afterAll() + out.close() + } + + override def afterEach() { + super.afterEach() + for (tmpFile <- tmpFiles) { + Utils.deleteRecursively(tmpFile) + } + } + + /** + * Writes the given DataFrame to parquet at a temporary location, and returns a DataFrame + * backed by the written parquet files. + */ + private def saveAsParquet(df: DataFrame): DataFrame = { + val tmpFile = File.createTempFile("WideSchemaBenchmark", "tmp") + tmpFiles ::= tmpFile + tmpFile.delete() + df.write.parquet(tmpFile.getAbsolutePath) + assert(tmpFile.isDirectory()) + sparkSession.read.parquet(tmpFile.getAbsolutePath) + } + + /** + * Adds standard set of cases to a benchmark given a dataframe and field to select. + */ + private def addCases( + benchmark: Benchmark, + df: DataFrame, + desc: String, + selector: String): Unit = { + benchmark.addCase(desc + " (read in-mem)") { iter => + df.selectExpr(s"sum($selector)").collect() + } + benchmark.addCase(desc + " (exec in-mem)") { iter => + df.selectExpr("*", s"hash($selector) as f").selectExpr(s"sum($selector)", "sum(f)").collect() + } + val parquet = saveAsParquet(df) + benchmark.addCase(desc + " (read parquet)") { iter => + parquet.selectExpr(s"sum($selector) as f").collect() + } + benchmark.addCase(desc + " (write parquet)") { iter => + saveAsParquet(df.selectExpr(s"sum($selector) as f")) + } + } + + ignore("parsing large select expressions") { + val benchmark = new Benchmark("parsing large select", 1, output = Some(out)) + for (width <- widthsToTest) { + val selectExpr = (1 to width).map(i => s"id as a_$i") + benchmark.addCase(s"$width select expressions") { iter => + sparkSession.range(1).toDF.selectExpr(selectExpr: _*) + } + } + benchmark.run() + } + + ignore("many column field read and write") { + val benchmark = new Benchmark("many column field r/w", scaleFactor, output = Some(out)) + for (width <- widthsToTest) { + // normalize by width to keep constant data size + val numRows = scaleFactor / width + val selectExpr = (1 to width).map(i => s"id as a_$i") + val df = sparkSession.range(numRows).toDF.selectExpr(selectExpr: _*).cache() + df.count() // force caching + addCases(benchmark, df, s"$width cols x $numRows rows", "a_1") + } + benchmark.run() + } + + ignore("wide shallowly nested struct field read and write") { + val benchmark = new Benchmark( + "wide shallowly nested struct field r/w", scaleFactor, output = Some(out)) + for (width <- widthsToTest) { + val numRows = scaleFactor / width + var datum: String = "{" + for (i <- 1 to width) { + if (i == 1) { + datum += s""""value_$i": 1""" + } else { + datum += s""", "value_$i": 1""" + } + } + datum += "}" + datum = s"""{"a": {"b": {"c": $datum, "d": $datum}, "e": $datum}}""" + val df = 
sparkSession.read.json(sparkSession.range(numRows).map(_ => datum).rdd).cache() + df.count() // force caching + addCases(benchmark, df, s"$width wide x $numRows rows", "a.b.c.value_1") + } + benchmark.run() + } + + ignore("deeply nested struct field read and write") { + val benchmark = new Benchmark("deeply nested struct field r/w", scaleFactor, output = Some(out)) + for (depth <- depthsToTest) { + val numRows = scaleFactor / depth + var datum: String = "{\"value\": 1}" + var selector: String = "value" + for (i <- 1 to depth) { + datum = "{\"value\": " + datum + "}" + selector = selector + ".value" + } + val df = sparkSession.read.json(sparkSession.range(numRows).map(_ => datum).rdd).cache() + df.count() // force caching + addCases(benchmark, df, s"$depth deep x $numRows rows", selector) + } + benchmark.run() + } + + ignore("bushy struct field read and write") { + val benchmark = new Benchmark("bushy struct field r/w", scaleFactor, output = Some(out)) + for (width <- Seq(1, 100, 1000)) { + val numRows = scaleFactor / width + var numNodes = 1 + var datum: String = "{\"value\": 1}" + var selector: String = "value" + var depth = 1 + while (numNodes < width) { + numNodes *= 2 + datum = s"""{"left_$depth": $datum, "right_$depth": $datum}""" + selector = s"left_$depth." + selector + depth += 1 + } + // TODO(ekl) seems like the json parsing is actually the majority of the time, perhaps + // we should benchmark that too separately. + val df = sparkSession.read.json(sparkSession.range(numRows).map(_ => datum).rdd).cache() + df.count() // force caching + addCases(benchmark, df, s"$numNodes x $depth deep x $numRows rows", selector) + } + benchmark.run() + } + + ignore("wide array field read and write") { + val benchmark = new Benchmark("wide array field r/w", scaleFactor, output = Some(out)) + for (width <- widthsToTest) { + val numRows = scaleFactor / width + var datum: String = "{\"value\": [" + for (i <- 1 to width) { + if (i == 1) { + datum += "1" + } else { + datum += ", 1" + } + } + datum += "]}" + val df = sparkSession.read.json(sparkSession.range(numRows).map(_ => datum).rdd).cache() + df.count() // force caching + addCases(benchmark, df, s"$width wide x $numRows rows", "value[0]") + } + benchmark.run() + } + + ignore("wide map field read and write") { + val benchmark = new Benchmark("wide map field r/w", scaleFactor, output = Some(out)) + for (width <- widthsToTest) { + val numRows = scaleFactor / width + val datum = Tuple1((1 to width).map(i => ("value_" + i -> 1)).toMap) + val df = sparkSession.range(numRows).map(_ => datum).toDF.cache() + df.count() // force caching + addCases(benchmark, df, s"$width wide x $numRows rows", "_1[\"value_1\"]") + } + benchmark.run() + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala index 052f4cbaebc8e..6e5c22dab52f6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala @@ -101,14 +101,15 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { def testColumnType[JvmType](columnType: ColumnType[JvmType]): Unit = { - val buffer = ByteBuffer.allocate(DEFAULT_BUFFER_SIZE).order(ByteOrder.nativeOrder()) val proj = UnsafeProjection.create(Array[DataType](columnType.dataType)) val converter = CatalystTypeConverters.createToScalaConverter(columnType.dataType) val seq = (0 until 
4).map(_ => proj(makeRandomRow(columnType)).copy()) + val totalSize = seq.map(_.getSizeInBytes).sum + val bufferSize = Math.max(DEFAULT_BUFFER_SIZE, totalSize) test(s"$columnType append/extract") { - buffer.rewind() - seq.foreach(columnType.append(_, 0, buffer)) + val buffer = ByteBuffer.allocate(bufferSize).order(ByteOrder.nativeOrder()) + seq.foreach(r => columnType.append(columnType.getField(r, 0), buffer)) buffer.rewind() seq.foreach { row => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 50c8745a288f4..0daa29b666f62 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -21,6 +21,7 @@ import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData._ import org.apache.spark.sql.types._ @@ -32,7 +33,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { setupTestData() test("simple columnar query") { - val plan = sqlContext.executePlan(testData.logicalPlan).sparkPlan + val plan = spark.sessionState.executePlan(testData.logicalPlan).sparkPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None) checkAnswer(scan, testData.collect().toSeq) @@ -41,15 +42,15 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { test("default size avoids broadcast") { // TODO: Improve this test when we have better statistics sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)) - .toDF().registerTempTable("sizeTst") - sqlContext.cacheTable("sizeTst") + .toDF().createOrReplaceTempView("sizeTst") + spark.catalog.cacheTable("sizeTst") assert( - sqlContext.table("sizeTst").queryExecution.analyzed.statistics.sizeInBytes > - sqlContext.conf.autoBroadcastJoinThreshold) + spark.table("sizeTst").queryExecution.analyzed.statistics.sizeInBytes > + spark.conf.get(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD)) } test("projection") { - val plan = sqlContext.executePlan(testData.select('value, 'key).logicalPlan).sparkPlan + val plan = spark.sessionState.executePlan(testData.select('value, 'key).logicalPlan).sparkPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None) checkAnswer(scan, testData.collect().map { @@ -58,7 +59,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times") { - val plan = sqlContext.executePlan(testData.logicalPlan).sparkPlan + val plan = spark.sessionState.executePlan(testData.logicalPlan).sparkPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None) checkAnswer(scan, testData.collect().toSeq) @@ -70,7 +71,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { sql("SELECT * FROM repeatedData"), repeatedData.collect().toSeq.map(Row.fromTuple)) - sqlContext.cacheTable("repeatedData") + spark.catalog.cacheTable("repeatedData") checkAnswer( sql("SELECT * FROM repeatedData"), @@ -82,7 +83,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { sql("SELECT * FROM nullableRepeatedData"), 
nullableRepeatedData.collect().toSeq.map(Row.fromTuple)) - sqlContext.cacheTable("nullableRepeatedData") + spark.catalog.cacheTable("nullableRepeatedData") checkAnswer( sql("SELECT * FROM nullableRepeatedData"), @@ -91,13 +92,13 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-2729 regression: timestamp data type") { val timestamps = (0 to 3).map(i => Tuple1(new Timestamp(i))).toDF("time") - timestamps.registerTempTable("timestamps") + timestamps.createOrReplaceTempView("timestamps") checkAnswer( sql("SELECT time FROM timestamps"), timestamps.collect().toSeq) - sqlContext.cacheTable("timestamps") + spark.catalog.cacheTable("timestamps") checkAnswer( sql("SELECT time FROM timestamps"), @@ -109,7 +110,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { sql("SELECT * FROM withEmptyParts"), withEmptyParts.collect().toSeq.map(Row.fromTuple)) - sqlContext.cacheTable("withEmptyParts") + spark.catalog.cacheTable("withEmptyParts") checkAnswer( sql("SELECT * FROM withEmptyParts"), @@ -132,7 +133,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { assert(df.schema.head.dataType === DecimalType(15, 10)) - df.cache().registerTempTable("test_fixed_decimal") + df.cache().createOrReplaceTempView("test_fixed_decimal") checkAnswer( sql("SELECT * FROM test_fixed_decimal"), (1 to 10).map(i => Row(Decimal(i, 15, 10).toJavaBigDecimal))) @@ -178,35 +179,35 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { (i to i + 10).map(j => s"map_key_$j" -> (Long.MaxValue - j)).toMap, Row((i - 0.25).toFloat, Seq(true, false, null))) } - sqlContext.createDataFrame(rdd, schema).registerTempTable("InMemoryCache_different_data_types") + spark.createDataFrame(rdd, schema).createOrReplaceTempView("InMemoryCache_different_data_types") // Cache the table. sql("cache table InMemoryCache_different_data_types") // Make sure the table is indeed cached. - sqlContext.table("InMemoryCache_different_data_types").queryExecution.executedPlan + spark.table("InMemoryCache_different_data_types").queryExecution.executedPlan assert( - sqlContext.isCached("InMemoryCache_different_data_types"), + spark.catalog.isCached("InMemoryCache_different_data_types"), "InMemoryCache_different_data_types should be cached.") // Issue a query and check the results. checkAnswer( sql(s"SELECT DISTINCT ${allColumns} FROM InMemoryCache_different_data_types"), - sqlContext.table("InMemoryCache_different_data_types").collect()) - sqlContext.dropTempTable("InMemoryCache_different_data_types") + spark.table("InMemoryCache_different_data_types").collect()) + spark.catalog.dropTempView("InMemoryCache_different_data_types") } test("SPARK-10422: String column in InMemoryColumnarCache needs to override clone method") { - val df = sqlContext.range(1, 100).selectExpr("id % 10 as id") + val df = spark.range(1, 100).selectExpr("id % 10 as id") .rdd.map(id => Tuple1(s"str_$id")).toDF("i") val cached = df.cache() // count triggers the caching action. It should not throw. cached.count() // Make sure, the DataFrame is indeed cached. - assert(sqlContext.cacheManager.lookupCachedData(cached).nonEmpty) + assert(spark.sharedState.cacheManager.lookupCachedData(cached).nonEmpty) // Check result. 
checkAnswer( cached, - sqlContext.range(1, 100).selectExpr("id % 10 as id") + spark.range(1, 100).selectExpr("id % 10 as id") .rdd.map(id => Tuple1(s"str_$id")).toDF("i") ) @@ -215,7 +216,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-10859: Predicates pushed to InMemoryColumnarTableScan are not evaluated correctly") { - val data = sqlContext.range(10).selectExpr("id", "cast(id as string) as s") + val data = spark.range(10).selectExpr("id", "cast(id as string) as s") data.cache() assert(data.count() === 10) assert(data.filter($"s" === "3").count() === 1) @@ -226,8 +227,23 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { val columnTypes1 = List.fill(length1)(IntegerType) val columnarIterator1 = GenerateColumnAccessor.generate(columnTypes1) - val length2 = 10000 + // SPARK-16664: the limit of janino is 8117 + val length2 = 8117 val columnTypes2 = List.fill(length2)(IntegerType) val columnarIterator2 = GenerateColumnAccessor.generate(columnTypes2) } + + test("SPARK-17549: cached table size should be correctly calculated") { + val data = spark.sparkContext.parallelize(1 to 10, 5).toDF() + val plan = spark.sessionState.executePlan(data.logicalPlan).sparkPlan + val cached = InMemoryRelation(true, 5, MEMORY_ONLY, plan, None) + + // Materialize the data. + val expectedAnswer = data.collect() + checkAnswer(cached, expectedAnswer) + + // Check that the right size was calculated. + assert(cached.batchStats.value === expectedAnswer.size * INT.defaultSize) + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala index 9164074a3e996..7ca8e047f081d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala @@ -32,23 +32,24 @@ class PartitionBatchPruningSuite import testImplicits._ - private lazy val originalColumnBatchSize = sqlContext.conf.columnBatchSize - private lazy val originalInMemoryPartitionPruning = sqlContext.conf.inMemoryPartitionPruning + private lazy val originalColumnBatchSize = spark.conf.get(SQLConf.COLUMN_BATCH_SIZE) + private lazy val originalInMemoryPartitionPruning = + spark.conf.get(SQLConf.IN_MEMORY_PARTITION_PRUNING) override protected def beforeAll(): Unit = { super.beforeAll() // Make a table with 5 partitions, 2 batches per partition, 10 elements per batch - sqlContext.setConf(SQLConf.COLUMN_BATCH_SIZE, 10) + spark.conf.set(SQLConf.COLUMN_BATCH_SIZE.key, 10) // Enable in-memory partition pruning - sqlContext.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true) + spark.conf.set(SQLConf.IN_MEMORY_PARTITION_PRUNING.key, true) // Enable in-memory table scan accumulators - sqlContext.setConf("spark.sql.inMemoryTableScanStatistics.enable", "true") + spark.conf.set("spark.sql.inMemoryTableScanStatistics.enable", "true") } override protected def afterAll(): Unit = { try { - sqlContext.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) - sqlContext.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) + spark.conf.set(SQLConf.COLUMN_BATCH_SIZE.key, originalColumnBatchSize) + spark.conf.set(SQLConf.IN_MEMORY_PARTITION_PRUNING.key, originalInMemoryPartitionPruning) } finally { super.afterAll() } @@ -62,13 +63,13 @@ class PartitionBatchPruningSuite val string = if (((key - 1) 
/ 10) % 2 == 0) null else key.toString TestData(key, string) }, 5).toDF() - pruningData.registerTempTable("pruningData") - sqlContext.cacheTable("pruningData") + pruningData.createOrReplaceTempView("pruningData") + spark.catalog.cacheTable("pruningData") } override protected def afterEach(): Unit = { try { - sqlContext.uncacheTable("pruningData") + spark.catalog.uncacheTable("pruningData") } finally { super.afterEach() } @@ -118,6 +119,21 @@ class PartitionBatchPruningSuite } } + // With disable IN_MEMORY_PARTITION_PRUNING option + test("disable IN_MEMORY_PARTITION_PRUNING") { + spark.conf.set(SQLConf.IN_MEMORY_PARTITION_PRUNING.key, false) + + val df = sql("SELECT key FROM pruningData WHERE key = 1") + val result = df.collect().map(_(0)).toArray + assert(result.length === 1) + + val (readPartitions, readBatches) = df.queryExecution.sparkPlan.collect { + case in: InMemoryTableScanExec => (in.readPartitions.value, in.readBatches.value) + }.head + assert(readPartitions === 5) + assert(readBatches === 10) + } + def checkBatchPruning( query: String, expectedReadPartitions: Int, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala index be0f4d78a5237..8d74884df9273 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala @@ -17,23 +17,39 @@ package org.apache.spark.sql.execution.command +import scala.reflect.{classTag, ClassTag} + import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.{CatalogTableType, FunctionResource} +import org.apache.spark.sql.catalyst.catalog.FunctionResourceType import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.execution.SparkSqlParser -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types._ +import org.apache.spark.sql.execution.datasources.{BucketSpec, CreateTableUsing} +import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} +import org.apache.spark.sql.types.{IntegerType, MetadataBuilder, StringType, StructType} + // TODO: merge this with DDLSuite (SPARK-14441) class DDLCommandSuite extends PlanTest { - private val parser = new SparkSqlParser(new SQLConf) + private lazy val parser = new SparkSqlParser(new SQLConf) - private def assertUnsupported(sql: String): Unit = { + private def assertUnsupported(sql: String, containsThesePhrases: Seq[String] = Seq()): Unit = { val e = intercept[ParseException] { parser.parsePlan(sql) } assert(e.getMessage.toLowerCase.contains("operation not allowed")) + containsThesePhrases.foreach { p => assert(e.getMessage.toLowerCase.contains(p.toLowerCase)) } + } + + private def parseAs[T: ClassTag](query: String): T = { + parser.parsePlan(query) match { + case t: T => t + case other => + fail(s"Expected to parse ${classTag[T].runtimeClass} from query," + + s"got ${other.getClass.getName}: $query") + } } test("create database") { @@ -44,7 +60,7 @@ class DDLCommandSuite extends PlanTest { |WITH DBPROPERTIES ('a'='a', 'b'='b', 'c'='c') """.stripMargin val parsed = parser.parsePlan(sql) - val expected = CreateDatabase( + val expected = CreateDatabaseCommand( "database_name", ifNotExists = true, Some("/home/user/db"), @@ -53,6 +69,12 @@ class DDLCommandSuite extends 
PlanTest { comparePlans(parsed, expected) } + test("create database - property values must be set") { + assertUnsupported( + sql = "CREATE DATABASE my_db WITH DBPROPERTIES('key_without_value', 'key_with_value'='x')", + containsThesePhrases = Seq("key_without_value")) + } + test("drop database") { val sql1 = "DROP DATABASE IF EXISTS database_name RESTRICT" val sql2 = "DROP DATABASE IF EXISTS database_name CASCADE" @@ -72,19 +94,19 @@ class DDLCommandSuite extends PlanTest { val parsed6 = parser.parsePlan(sql6) val parsed7 = parser.parsePlan(sql7) - val expected1 = DropDatabase( + val expected1 = DropDatabaseCommand( "database_name", ifExists = true, cascade = false) - val expected2 = DropDatabase( + val expected2 = DropDatabaseCommand( "database_name", ifExists = true, cascade = true) - val expected3 = DropDatabase( + val expected3 = DropDatabaseCommand( "database_name", ifExists = false, cascade = false) - val expected4 = DropDatabase( + val expected4 = DropDatabaseCommand( "database_name", ifExists = false, cascade = true) @@ -106,10 +128,10 @@ class DDLCommandSuite extends PlanTest { val parsed1 = parser.parsePlan(sql1) val parsed2 = parser.parsePlan(sql2) - val expected1 = AlterDatabaseProperties( + val expected1 = AlterDatabasePropertiesCommand( "database_name", Map("a" -> "a", "b" -> "b", "c" -> "c")) - val expected2 = AlterDatabaseProperties( + val expected2 = AlterDatabasePropertiesCommand( "database_name", Map("a" -> "a")) @@ -117,6 +139,12 @@ class DDLCommandSuite extends PlanTest { comparePlans(parsed2, expected2) } + test("alter database - property values must be set") { + assertUnsupported( + sql = "ALTER DATABASE my_db SET DBPROPERTIES('key_without_value', 'key_with_value'='x')", + containsThesePhrases = Seq("key_without_value")) + } + test("describe database") { // DESCRIBE DATABASE [EXTENDED] db_name; val sql1 = "DESCRIBE DATABASE EXTENDED db_name" @@ -125,10 +153,10 @@ class DDLCommandSuite extends PlanTest { val parsed1 = parser.parsePlan(sql1) val parsed2 = parser.parsePlan(sql2) - val expected1 = DescribeDatabase( + val expected1 = DescribeDatabaseCommand( "db_name", extended = true) - val expected2 = DescribeDatabase( + val expected2 = DescribeDatabaseCommand( "db_name", extended = false) @@ -151,17 +179,21 @@ class DDLCommandSuite extends PlanTest { """.stripMargin val parsed1 = parser.parsePlan(sql1) val parsed2 = parser.parsePlan(sql2) - val expected1 = CreateFunction( + val expected1 = CreateFunctionCommand( None, "helloworld", "com.matthewrathbone.example.SimpleUDFExample", - Seq(("jar", "/path/to/jar1"), ("jar", "/path/to/jar2")), + Seq( + FunctionResource(FunctionResourceType.fromString("jar"), "/path/to/jar1"), + FunctionResource(FunctionResourceType.fromString("jar"), "/path/to/jar2")), isTemp = true) - val expected2 = CreateFunction( + val expected2 = CreateFunctionCommand( Some("hello"), "world", "com.matthewrathbone.example.SimpleUDFExample", - Seq(("archive", "/path/to/archive"), ("file", "/path/to/file")), + Seq( + FunctionResource(FunctionResourceType.fromString("archive"), "/path/to/archive"), + FunctionResource(FunctionResourceType.fromString("file"), "/path/to/file")), isTemp = false) comparePlans(parsed1, expected1) comparePlans(parsed2, expected2) @@ -178,22 +210,22 @@ class DDLCommandSuite extends PlanTest { val parsed3 = parser.parsePlan(sql3) val parsed4 = parser.parsePlan(sql4) - val expected1 = DropFunction( + val expected1 = DropFunctionCommand( None, "helloworld", ifExists = false, isTemp = true) - val expected2 = DropFunction( + val expected2 
= DropFunctionCommand( None, "helloworld", ifExists = true, isTemp = true) - val expected3 = DropFunction( + val expected3 = DropFunctionCommand( Some("hello"), "world", ifExists = false, isTemp = false) - val expected4 = DropFunction( + val expected4 = DropFunctionCommand( Some("hello"), "world", ifExists = true, @@ -205,6 +237,172 @@ class DDLCommandSuite extends PlanTest { comparePlans(parsed4, expected4) } + test("create table - table file format") { + val allSources = Seq("parquet", "parquetfile", "orc", "orcfile", "avro", "avrofile", + "sequencefile", "rcfile", "textfile") + + allSources.foreach { s => + val query = s"CREATE TABLE my_tab STORED AS $s" + val ct = parseAs[CreateTableCommand](query) + val hiveSerde = HiveSerDe.sourceToSerDe(s, new SQLConf) + assert(hiveSerde.isDefined) + assert(ct.table.storage.serde == hiveSerde.get.serde) + assert(ct.table.storage.inputFormat == hiveSerde.get.inputFormat) + assert(ct.table.storage.outputFormat == hiveSerde.get.outputFormat) + } + } + + test("create table - row format and table file format") { + val createTableStart = "CREATE TABLE my_tab ROW FORMAT" + val fileFormat = s"STORED AS INPUTFORMAT 'inputfmt' OUTPUTFORMAT 'outputfmt'" + val query1 = s"$createTableStart SERDE 'anything' $fileFormat" + val query2 = s"$createTableStart DELIMITED FIELDS TERMINATED BY ' ' $fileFormat" + + // No conflicting serdes here, OK + val parsed1 = parseAs[CreateTableCommand](query1) + assert(parsed1.table.storage.serde == Some("anything")) + assert(parsed1.table.storage.inputFormat == Some("inputfmt")) + assert(parsed1.table.storage.outputFormat == Some("outputfmt")) + val parsed2 = parseAs[CreateTableCommand](query2) + assert(parsed2.table.storage.serde.isEmpty) + assert(parsed2.table.storage.inputFormat == Some("inputfmt")) + assert(parsed2.table.storage.outputFormat == Some("outputfmt")) + } + + test("create table - row format serde and generic file format") { + val allSources = Seq("parquet", "orc", "avro", "sequencefile", "rcfile", "textfile") + val supportedSources = Set("sequencefile", "rcfile", "textfile") + + allSources.foreach { s => + val query = s"CREATE TABLE my_tab ROW FORMAT SERDE 'anything' STORED AS $s" + if (supportedSources.contains(s)) { + val ct = parseAs[CreateTableCommand](query) + val hiveSerde = HiveSerDe.sourceToSerDe(s, new SQLConf) + assert(hiveSerde.isDefined) + assert(ct.table.storage.serde == Some("anything")) + assert(ct.table.storage.inputFormat == hiveSerde.get.inputFormat) + assert(ct.table.storage.outputFormat == hiveSerde.get.outputFormat) + } else { + assertUnsupported(query, Seq("row format serde", "incompatible", s)) + } + } + } + + test("create table - row format delimited and generic file format") { + val allSources = Seq("parquet", "orc", "avro", "sequencefile", "rcfile", "textfile") + val supportedSources = Set("textfile") + + allSources.foreach { s => + val query = s"CREATE TABLE my_tab ROW FORMAT DELIMITED FIELDS TERMINATED BY ' ' STORED AS $s" + if (supportedSources.contains(s)) { + val ct = parseAs[CreateTableCommand](query) + val hiveSerde = HiveSerDe.sourceToSerDe(s, new SQLConf) + assert(hiveSerde.isDefined) + assert(ct.table.storage.serde == hiveSerde.get.serde) + assert(ct.table.storage.inputFormat == hiveSerde.get.inputFormat) + assert(ct.table.storage.outputFormat == hiveSerde.get.outputFormat) + } else { + assertUnsupported(query, Seq("row format delimited", "only compatible with 'textfile'", s)) + } + } + } + + test("create external table - location must be specified") { + assertUnsupported( + sql = 
"CREATE EXTERNAL TABLE my_tab", + containsThesePhrases = Seq("create external table", "location")) + val query = "CREATE EXTERNAL TABLE my_tab LOCATION '/something/anything'" + val ct = parseAs[CreateTableCommand](query) + assert(ct.table.tableType == CatalogTableType.EXTERNAL) + assert(ct.table.storage.locationUri == Some("/something/anything")) + } + + test("create table - property values must be set") { + assertUnsupported( + sql = "CREATE TABLE my_tab TBLPROPERTIES('key_without_value', 'key_with_value'='x')", + containsThesePhrases = Seq("key_without_value")) + assertUnsupported( + sql = "CREATE TABLE my_tab ROW FORMAT SERDE 'serde' " + + "WITH SERDEPROPERTIES('key_without_value', 'key_with_value'='x')", + containsThesePhrases = Seq("key_without_value")) + } + + test("create table - location implies external") { + val query = "CREATE TABLE my_tab LOCATION '/something/anything'" + val ct = parseAs[CreateTableCommand](query) + assert(ct.table.tableType == CatalogTableType.EXTERNAL) + assert(ct.table.storage.locationUri == Some("/something/anything")) + } + + test("create table - column repeated in partitioning columns") { + val query = "CREATE TABLE tab1 (key INT, value STRING) PARTITIONED BY (key INT, hr STRING)" + val e = intercept[ParseException] { parser.parsePlan(query) } + assert(e.getMessage.contains( + "Operation not allowed: Partition columns may not be specified in the schema: [\"key\"]")) + } + + test("create table - duplicate column names in the table definition") { + val query = "CREATE TABLE default.tab1 (key INT, key STRING)" + val e = intercept[ParseException] { parser.parsePlan(query) } + assert(e.getMessage.contains("Operation not allowed: Duplicated column names found in " + + "table definition of `default`.`tab1`: [\"key\"]")) + } + + test("create table using - with partitioned by") { + val query = "CREATE TABLE my_tab(a INT comment 'test', b STRING) " + + "USING parquet PARTITIONED BY (a)" + val expected = CreateTableUsing( + TableIdentifier("my_tab"), + Some(new StructType() + .add("a", IntegerType, nullable = true, + new MetadataBuilder().putString("comment", s"test").build()) + .add("b", StringType)), + "parquet", + false, + Map.empty, + null, + None, + false, + true) + + parser.parsePlan(query) match { + case ct: CreateTableUsing => + // We can't compare array in `CreateTableUsing` directly, so here we compare + // `partitionColumns` ahead, and make `partitionColumns` null before plan comparison. + assert(Seq("a") == ct.partitionColumns.toSeq) + comparePlans(ct.copy(partitionColumns = null), expected) + case other => + fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," + + s"got ${other.getClass.getName}: $query") + } + } + + test("create table using - with bucket") { + val query = "CREATE TABLE my_tab(a INT, b STRING) USING parquet " + + "CLUSTERED BY (a) SORTED BY (b) INTO 5 BUCKETS" + val expected = CreateTableUsing( + TableIdentifier("my_tab"), + Some(new StructType().add("a", IntegerType).add("b", StringType)), + "parquet", + false, + Map.empty, + null, + Some(BucketSpec(5, Seq("a"), Seq("b"))), + false, + true) + + parser.parsePlan(query) match { + case ct: CreateTableUsing => + // `Array.empty == Array.empty` returns false, here we set `partitionColumns` to null before + // plan comparison. 
+ assert(ct.partitionColumns.isEmpty) + comparePlans(ct.copy(partitionColumns = null), expected) + case other => + fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," + + s"got ${other.getClass.getName}: $query") + } + } + // ALTER TABLE table_name RENAME TO new_table_name; // ALTER VIEW view_name RENAME TO new_view_name; test("alter table/view: rename table/view") { @@ -212,11 +410,11 @@ class DDLCommandSuite extends PlanTest { val sql_view = sql_table.replace("TABLE", "VIEW") val parsed_table = parser.parsePlan(sql_table) val parsed_view = parser.parsePlan(sql_view) - val expected_table = AlterTableRename( + val expected_table = AlterTableRenameCommand( TableIdentifier("table_name", None), TableIdentifier("new_table_name", None), isView = false) - val expected_view = AlterTableRename( + val expected_view = AlterTableRenameCommand( TableIdentifier("table_name", None), TableIdentifier("new_table_name", None), isView = true) @@ -245,11 +443,11 @@ class DDLCommandSuite extends PlanTest { val parsed3_view = parser.parsePlan(sql3_view) val tableIdent = TableIdentifier("table_name", None) - val expected1_table = AlterTableSetProperties( + val expected1_table = AlterTableSetPropertiesCommand( tableIdent, Map("test" -> "test", "comment" -> "new_comment"), isView = false) - val expected2_table = AlterTableUnsetProperties( + val expected2_table = AlterTableUnsetPropertiesCommand( tableIdent, Seq("comment", "test"), ifExists = false, isView = false) - val expected3_table = AlterTableUnsetProperties( + val expected3_table = AlterTableUnsetPropertiesCommand( tableIdent, Seq("comment", "test"), ifExists = true, isView = false) val expected1_view = expected1_table.copy(isView = true) val expected2_view = expected2_table.copy(isView = true) @@ -263,6 +461,18 @@ class DDLCommandSuite extends PlanTest { comparePlans(parsed3_view, expected3_view) } + test("alter table - property values must be set") { + assertUnsupported( + sql = "ALTER TABLE my_tab SET TBLPROPERTIES('key_without_value', 'key_with_value'='x')", + containsThesePhrases = Seq("key_without_value")) + } + + test("alter table unset properties - property values must NOT be set") { + assertUnsupported( + sql = "ALTER TABLE my_tab UNSET TBLPROPERTIES('key_without_value', 'key_with_value'='x')", + containsThesePhrases = Seq("key_with_value")) + } + test("alter table: SerDe properties") { val sql1 = "ALTER TABLE table_name SET SERDE 'org.apache.class'" val sql2 = @@ -292,21 +502,21 @@ class DDLCommandSuite extends PlanTest { val parsed4 = parser.parsePlan(sql4) val parsed5 = parser.parsePlan(sql5) val tableIdent = TableIdentifier("table_name", None) - val expected1 = AlterTableSerDeProperties( + val expected1 = AlterTableSerDePropertiesCommand( tableIdent, Some("org.apache.class"), None, None) - val expected2 = AlterTableSerDeProperties( + val expected2 = AlterTableSerDePropertiesCommand( tableIdent, Some("org.apache.class"), Some(Map("columns" -> "foo,bar", "field.delim" -> ",")), None) - val expected3 = AlterTableSerDeProperties( + val expected3 = AlterTableSerDePropertiesCommand( tableIdent, None, Some(Map("columns" -> "foo,bar", "field.delim" -> ",")), None) - val expected4 = AlterTableSerDeProperties( + val expected4 = AlterTableSerDePropertiesCommand( tableIdent, Some("org.apache.class"), Some(Map("columns" -> "foo,bar", "field.delim" -> ",")), Some(Map("test" -> null, "dt" -> "2008-08-08", "country" -> "us"))) - val expected5 = AlterTableSerDeProperties( + val expected5 = AlterTableSerDePropertiesCommand( 
tableIdent, None, Some(Map("columns" -> "foo,bar", "field.delim" -> ",")), @@ -318,6 +528,13 @@ class DDLCommandSuite extends PlanTest { comparePlans(parsed5, expected5) } + test("alter table - SerDe property values must be set") { + assertUnsupported( + sql = "ALTER TABLE my_tab SET SERDE 'serde' " + + "WITH SERDEPROPERTIES('key_without_value', 'key_with_value'='x')", + containsThesePhrases = Seq("key_without_value")) + } + // ALTER TABLE table_name ADD [IF NOT EXISTS] PARTITION partition_spec // [LOCATION 'location1'] partition_spec [LOCATION 'location2'] ...; test("alter table: add partition") { @@ -332,13 +549,13 @@ class DDLCommandSuite extends PlanTest { val parsed1 = parser.parsePlan(sql1) val parsed2 = parser.parsePlan(sql2) - val expected1 = AlterTableAddPartition( + val expected1 = AlterTableAddPartitionCommand( TableIdentifier("table_name", None), Seq( (Map("dt" -> "2008-08-08", "country" -> "us"), Some("location1")), (Map("dt" -> "2009-09-09", "country" -> "uk"), None)), ifNotExists = true) - val expected2 = AlterTableAddPartition( + val expected2 = AlterTableAddPartitionCommand( TableIdentifier("table_name", None), Seq((Map("dt" -> "2008-08-08"), Some("loc"))), ifNotExists = false) @@ -347,27 +564,21 @@ class DDLCommandSuite extends PlanTest { comparePlans(parsed2, expected2) } - // ALTER VIEW view_name ADD [IF NOT EXISTS] PARTITION partition_spec PARTITION partition_spec ...; - test("alter view: add partition") { - val sql1 = + test("alter table: recover partitions") { + val sql = "ALTER TABLE table_name RECOVER PARTITIONS" + val parsed = parser.parsePlan(sql) + val expected = AlterTableRecoverPartitionsCommand( + TableIdentifier("table_name", None)) + comparePlans(parsed, expected) + } + + test("alter view: add partition (not supported)") { + assertUnsupported( """ |ALTER VIEW view_name ADD IF NOT EXISTS PARTITION |(dt='2008-08-08', country='us') PARTITION |(dt='2009-09-09', country='uk') - """.stripMargin - // different constant types in partitioning spec - val sql2 = - """ - |ALTER VIEW view_name ADD PARTITION - |(col1=NULL, cOL2='f', col3=5, COL4=true) - """.stripMargin - - intercept[ParseException] { - parser.parsePlan(sql1) - } - intercept[ParseException] { - parser.parsePlan(sql2) - } + """.stripMargin) } test("alter table: rename partition") { @@ -377,7 +588,7 @@ class DDLCommandSuite extends PlanTest { |RENAME TO PARTITION (dt='2008-09-09', country='uk') """.stripMargin val parsed = parser.parsePlan(sql) - val expected = AlterTableRenamePartition( + val expected = AlterTableRenamePartitionCommand( TableIdentifier("table_name", None), Map("dt" -> "2008-08-08", "country" -> "us"), Map("dt" -> "2008-09-09", "country" -> "uk")) @@ -392,7 +603,7 @@ class DDLCommandSuite extends PlanTest { """.stripMargin) } - // ALTER TABLE table_name DROP [IF EXISTS] PARTITION spec1[, PARTITION spec2, ...] [PURGE] + // ALTER TABLE table_name DROP [IF EXISTS] PARTITION spec1[, PARTITION spec2, ...] // ALTER VIEW table_name DROP [IF EXISTS] PARTITION spec1[, PARTITION spec2, ...] 
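// (Editor's sketch, not part of the patch.) The assertUnsupported helper invoked in the hunks
// above and in the drop-partitions test below is defined earlier in this suite; its body is not
// shown in this hunk, but built only from the parser API already exercised here
// (parser.parsePlan plus intercept[ParseException]), a helper of that shape roughly looks like:
private def assertUnsupportedSketch(sql: String, containsThesePhrases: Seq[String] = Seq()): Unit = {
  val e = intercept[ParseException] { parser.parsePlan(sql) }
  assert(e.getMessage.toLowerCase.contains("operation not allowed"))
  containsThesePhrases.foreach { p => assert(e.getMessage.toLowerCase.contains(p.toLowerCase)) }
}
// e.g. the unsupported PURGE variants below can then be rejected with a single call:
//   assertUnsupportedSketch("ALTER TABLE table_name DROP PARTITION (dt='2008-08-08') PURGE")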
test("alter table/view: drop partitions") { val sql1_table = @@ -403,34 +614,29 @@ class DDLCommandSuite extends PlanTest { val sql2_table = """ |ALTER TABLE table_name DROP PARTITION - |(dt='2008-08-08', country='us'), PARTITION (dt='2009-09-09', country='uk') PURGE + |(dt='2008-08-08', country='us'), PARTITION (dt='2009-09-09', country='uk') """.stripMargin val sql1_view = sql1_table.replace("TABLE", "VIEW") - // Note: ALTER VIEW DROP PARTITION does not support PURGE - val sql2_view = sql2_table.replace("TABLE", "VIEW").replace("PURGE", "") + val sql2_view = sql2_table.replace("TABLE", "VIEW") val parsed1_table = parser.parsePlan(sql1_table) - val e = intercept[ParseException] { - parser.parsePlan(sql2_table) - } - assert(e.getMessage.contains("Operation not allowed")) - - intercept[ParseException] { - parser.parsePlan(sql1_view) - } - intercept[ParseException] { - parser.parsePlan(sql2_view) - } + val parsed2_table = parser.parsePlan(sql2_table) + assertUnsupported(sql1_table + " PURGE") + assertUnsupported(sql2_table + " PURGE") + assertUnsupported(sql1_view) + assertUnsupported(sql2_view) val tableIdent = TableIdentifier("table_name", None) - val expected1_table = AlterTableDropPartition( + val expected1_table = AlterTableDropPartitionCommand( tableIdent, Seq( Map("dt" -> "2008-08-08", "country" -> "us"), Map("dt" -> "2009-09-09", "country" -> "uk")), ifExists = true) + val expected2_table = expected1_table.copy(ifExists = false) comparePlans(parsed1_table, expected1_table) + comparePlans(parsed2_table, expected2_table) } test("alter table: archive partition (not supported)") { @@ -441,33 +647,9 @@ class DDLCommandSuite extends PlanTest { assertUnsupported("ALTER TABLE table_name UNARCHIVE PARTITION (dt='2008-08-08', country='us')") } - /* - test("alter table: set file format") { - val sql1 = "ALTER TABLE table_name SET FILEFORMAT INPUTFORMAT 'test' " + - "OUTPUTFORMAT 'test' SERDE 'test'" - val sql2 = "ALTER TABLE table_name PARTITION (dt='2008-08-08', country='us') " + - "SET FILEFORMAT PARQUET" - val parsed1 = parser.parsePlan(sql1) - val parsed2 = parser.parsePlan(sql2) - val tableIdent = TableIdentifier("table_name", None) - val expected1 = AlterTableSetFileFormat( - tableIdent, - None, - List("test", "test", "test"), - None)(sql1) - val expected2 = AlterTableSetFileFormat( - tableIdent, - Some(Map("dt" -> "2008-08-08", "country" -> "us")), - Seq(), - Some("PARQUET"))(sql2) - comparePlans(parsed1, expected1) - comparePlans(parsed2, expected2) - } */ - test("alter table: set file format (not allowed)") { assertUnsupported( - "ALTER TABLE table_name SET FILEFORMAT INPUTFORMAT 'test' " + - "OUTPUTFORMAT 'test' SERDE 'test'") + "ALTER TABLE table_name SET FILEFORMAT INPUTFORMAT 'test' OUTPUTFORMAT 'test'") assertUnsupported( "ALTER TABLE table_name PARTITION (dt='2008-08-08', country='us') " + "SET FILEFORMAT PARQUET") @@ -480,11 +662,11 @@ class DDLCommandSuite extends PlanTest { val parsed1 = parser.parsePlan(sql1) val parsed2 = parser.parsePlan(sql2) val tableIdent = TableIdentifier("table_name", None) - val expected1 = AlterTableSetLocation( + val expected1 = AlterTableSetLocationCommand( tableIdent, None, "new location") - val expected2 = AlterTableSetLocation( + val expected2 = AlterTableSetLocationCommand( tableIdent, Some(Map("dt" -> "2008-08-08", "country" -> "us")), "new location") @@ -527,58 +709,6 @@ class DDLCommandSuite extends PlanTest { assertUnsupported("ALTER TABLE table_name SKEWED BY (key) ON (1,5,6) STORED AS DIRECTORIES") } - /* - test("alter table: change 
column name/type/position/comment") { - val sql1 = "ALTER TABLE table_name CHANGE col_old_name col_new_name INT" - val sql2 = - """ - |ALTER TABLE table_name CHANGE COLUMN col_old_name col_new_name INT - |COMMENT 'col_comment' FIRST CASCADE - """.stripMargin - val sql3 = - """ - |ALTER TABLE table_name CHANGE COLUMN col_old_name col_new_name INT - |COMMENT 'col_comment' AFTER column_name RESTRICT - """.stripMargin - val parsed1 = parser.parsePlan(sql1) - val parsed2 = parser.parsePlan(sql2) - val parsed3 = parser.parsePlan(sql3) - val tableIdent = TableIdentifier("table_name", None) - val expected1 = AlterTableChangeCol( - tableName = tableIdent, - partitionSpec = None, - oldColName = "col_old_name", - newColName = "col_new_name", - dataType = IntegerType, - comment = None, - afterColName = None, - restrict = false, - cascade = false)(sql1) - val expected2 = AlterTableChangeCol( - tableName = tableIdent, - partitionSpec = None, - oldColName = "col_old_name", - newColName = "col_new_name", - dataType = IntegerType, - comment = Some("col_comment"), - afterColName = None, - restrict = false, - cascade = true)(sql2) - val expected3 = AlterTableChangeCol( - tableName = tableIdent, - partitionSpec = None, - oldColName = "col_old_name", - newColName = "col_new_name", - dataType = IntegerType, - comment = Some("col_comment"), - afterColName = Some("column_name"), - restrict = true, - cascade = false)(sql3) - comparePlans(parsed1, expected1) - comparePlans(parsed2, expected2) - comparePlans(parsed3, expected3) - } */ - test("alter table: change column name/type/position/comment (not allowed)") { assertUnsupported("ALTER TABLE table_name CHANGE col_old_name col_new_name INT") assertUnsupported( @@ -592,44 +722,6 @@ class DDLCommandSuite extends PlanTest { """.stripMargin) } - /* - test("alter table: add/replace columns") { - val sql1 = - """ - |ALTER TABLE table_name PARTITION (dt='2008-08-08', country='us') - |ADD COLUMNS (new_col1 INT COMMENT 'test_comment', new_col2 LONG - |COMMENT 'test_comment2') CASCADE - """.stripMargin - val sql2 = - """ - |ALTER TABLE table_name REPLACE COLUMNS (new_col1 INT - |COMMENT 'test_comment', new_col2 LONG COMMENT 'test_comment2') RESTRICT - """.stripMargin - val parsed1 = parser.parsePlan(sql1) - val parsed2 = parser.parsePlan(sql2) - val meta1 = new MetadataBuilder().putString("comment", "test_comment").build() - val meta2 = new MetadataBuilder().putString("comment", "test_comment2").build() - val tableIdent = TableIdentifier("table_name", None) - val expected1 = AlterTableAddCol( - tableIdent, - Some(Map("dt" -> "2008-08-08", "country" -> "us")), - StructType(Seq( - StructField("new_col1", IntegerType, nullable = true, meta1), - StructField("new_col2", LongType, nullable = true, meta2))), - restrict = false, - cascade = true)(sql1) - val expected2 = AlterTableReplaceCol( - tableIdent, - None, - StructType(Seq( - StructField("new_col1", IntegerType, nullable = true, meta1), - StructField("new_col2", LongType, nullable = true, meta2))), - restrict = true, - cascade = false)(sql2) - comparePlans(parsed1, expected1) - comparePlans(parsed2, expected2) - } */ - test("alter table: add/replace columns (not allowed)") { assertUnsupported( """ @@ -670,6 +762,21 @@ class DDLCommandSuite extends PlanTest { assert(parsed.isInstanceOf[Project]) } + test("duplicate keys in table properties") { + val e = intercept[ParseException] { + parser.parsePlan("ALTER TABLE dbx.tab1 SET TBLPROPERTIES ('key1' = '1', 'key1' = '2')") + }.getMessage + assert(e.contains("Found duplicate keys 
'key1'")) + } + + test("duplicate columns in partition specs") { + val e = intercept[ParseException] { + parser.parsePlan( + "ALTER TABLE dbx.tab1 PARTITION (a='1', a='2') RENAME TO PARTITION (a='100', a='200')") + }.getMessage + assert(e.contains("Found duplicate keys 'a'")) + } + test("drop table") { val tableName1 = "db.tab" val tableName2 = "tab" @@ -678,15 +785,16 @@ class DDLCommandSuite extends PlanTest { val parsed2 = parser.parsePlan(s"DROP TABLE IF EXISTS $tableName1") val parsed3 = parser.parsePlan(s"DROP TABLE $tableName2") val parsed4 = parser.parsePlan(s"DROP TABLE IF EXISTS $tableName2") + assertUnsupported(s"DROP TABLE IF EXISTS $tableName2 PURGE") val expected1 = - DropTable(TableIdentifier("tab", Option("db")), ifExists = false, isView = false) + DropTableCommand(TableIdentifier("tab", Option("db")), ifExists = false, isView = false) val expected2 = - DropTable(TableIdentifier("tab", Option("db")), ifExists = true, isView = false) + DropTableCommand(TableIdentifier("tab", Option("db")), ifExists = true, isView = false) val expected3 = - DropTable(TableIdentifier("tab", None), ifExists = false, isView = false) + DropTableCommand(TableIdentifier("tab", None), ifExists = false, isView = false) val expected4 = - DropTable(TableIdentifier("tab", None), ifExists = true, isView = false) + DropTableCommand(TableIdentifier("tab", None), ifExists = true, isView = false) comparePlans(parsed1, expected1) comparePlans(parsed2, expected2) @@ -704,13 +812,13 @@ class DDLCommandSuite extends PlanTest { val parsed4 = parser.parsePlan(s"DROP VIEW IF EXISTS $viewName2") val expected1 = - DropTable(TableIdentifier("view", Option("db")), ifExists = false, isView = true) + DropTableCommand(TableIdentifier("view", Option("db")), ifExists = false, isView = true) val expected2 = - DropTable(TableIdentifier("view", Option("db")), ifExists = true, isView = true) + DropTableCommand(TableIdentifier("view", Option("db")), ifExists = true, isView = true) val expected3 = - DropTable(TableIdentifier("view", None), ifExists = false, isView = true) + DropTableCommand(TableIdentifier("view", None), ifExists = false, isView = true) val expected4 = - DropTable(TableIdentifier("view", None), ifExists = true, isView = true) + DropTableCommand(TableIdentifier("view", None), ifExists = true, isView = true) comparePlans(parsed1, expected1) comparePlans(parsed2, expected2) @@ -722,20 +830,20 @@ class DDLCommandSuite extends PlanTest { val sql1 = "SHOW COLUMNS FROM t1" val sql2 = "SHOW COLUMNS IN db1.t1" val sql3 = "SHOW COLUMNS FROM t1 IN db1" - val sql4 = "SHOW COLUMNS FROM db1.t1 IN db2" + val sql4 = "SHOW COLUMNS FROM db1.t1 IN db1" + val sql5 = "SHOW COLUMNS FROM db1.t1 IN db2" val parsed1 = parser.parsePlan(sql1) val expected1 = ShowColumnsCommand(TableIdentifier("t1", None)) val parsed2 = parser.parsePlan(sql2) val expected2 = ShowColumnsCommand(TableIdentifier("t1", Some("db1"))) val parsed3 = parser.parsePlan(sql3) + val parsed4 = parser.parsePlan(sql3) comparePlans(parsed1, expected1) comparePlans(parsed2, expected2) comparePlans(parsed3, expected2) - val message = intercept[ParseException] { - parser.parsePlan(sql4) - }.getMessage - assert(message.contains("Duplicates the declaration for database")) + comparePlans(parsed4, expected2) + assertUnsupported(sql5) } test("show partitions") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 12acb9f2761df..252064d1a8349 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -22,14 +22,20 @@ import java.io.File import org.apache.hadoop.fs.Path import org.scalatest.BeforeAndAfterEach +import org.apache.spark.internal.config._ import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.{DatabaseAlreadyExistsException, FunctionRegistry, NoSuchPartitionException, NoSuchTableException, TempTableAlreadyExistsException} import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogStorageFormat} -import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType} +import org.apache.spark.sql.catalyst.catalog.{CatalogColumn, CatalogTable, CatalogTableType} import org.apache.spark.sql.catalyst.catalog.{CatalogTablePartition, SessionCatalog} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec +import org.apache.spark.sql.execution.command.CreateDataSourceTableUtils._ +import org.apache.spark.sql.execution.datasources.BucketSpec import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.util.Utils class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { private val escapedIdentifier = "`(.+)`".r @@ -37,7 +43,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { override def afterEach(): Unit = { try { // drop all databases, tables and functions after each test - sqlContext.sessionState.catalog.reset() + spark.sessionState.catalog.reset() } finally { super.afterEach() } @@ -66,76 +72,146 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { private def createDatabase(catalog: SessionCatalog, name: String): Unit = { catalog.createDatabase( - CatalogDatabase(name, "", sqlContext.conf.warehousePath, Map()), ignoreIfExists = false) + CatalogDatabase(name, "", spark.sessionState.conf.warehousePath, Map()), + ignoreIfExists = false) } - private def createTable(catalog: SessionCatalog, name: TableIdentifier): Unit = { + private def generateTable(catalog: SessionCatalog, name: TableIdentifier): CatalogTable = { val storage = CatalogStorageFormat( locationUri = Some(catalog.defaultTablePath(name)), inputFormat = None, outputFormat = None, serde = None, + compressed = false, serdeProperties = Map()) - catalog.createTable(CatalogTable( + CatalogTable( identifier = name, tableType = CatalogTableType.EXTERNAL, storage = storage, - schema = Seq(), - createTime = 0L), ignoreIfExists = false) + schema = Seq( + CatalogColumn("col1", "int"), + CatalogColumn("col2", "string"), + CatalogColumn("a", "int"), + CatalogColumn("b", "int")), + partitionColumnNames = Seq("a", "b"), + createTime = 0L) + } + + private def createTable(catalog: SessionCatalog, name: TableIdentifier): Unit = { + catalog.createTable(generateTable(catalog, name), ignoreIfExists = false) } private def createTablePartition( catalog: SessionCatalog, spec: TablePartitionSpec, tableName: TableIdentifier): Unit = { - val part = CatalogTablePartition(spec, CatalogStorageFormat(None, None, None, None, Map())) + val part = CatalogTablePartition( + spec, CatalogStorageFormat(None, None, None, None, false, Map())) catalog.createPartitions(tableName, Seq(part), ignoreIfExists = false) } test("the qualified path of a 
database is stored in the catalog") { - val catalog = sqlContext.sessionState.catalog - - val path = System.getProperty("java.io.tmpdir") - // The generated temp path is not qualified. - assert(!path.startsWith("file:/")) - sql(s"CREATE DATABASE db1 LOCATION '$path'") - val pathInCatalog = new Path(catalog.getDatabaseMetadata("db1").locationUri).toUri - assert("file" === pathInCatalog.getScheme) - assert(path === pathInCatalog.getPath) - - withSQLConf( - SQLConf.WAREHOUSE_PATH.key -> (System.getProperty("java.io.tmpdir"))) { - sql(s"CREATE DATABASE db2") - val pathInCatalog = new Path(catalog.getDatabaseMetadata("db2").locationUri).toUri + val catalog = spark.sessionState.catalog + + withTempDir { tmpDir => + val path = tmpDir.getCanonicalPath + // The generated temp path is not qualified. + assert(!path.startsWith("file:/")) + val uri = tmpDir.toURI + sql(s"CREATE DATABASE db1 LOCATION '$uri'") + val pathInCatalog = new Path(catalog.getDatabaseMetadata("db1").locationUri).toUri assert("file" === pathInCatalog.getScheme) - assert(s"${sqlContext.conf.warehousePath}/db2.db" === pathInCatalog.getPath) + val expectedPath = new Path(path).toUri + assert(expectedPath.getPath === pathInCatalog.getPath) + + withSQLConf(SQLConf.WAREHOUSE_PATH.key -> path) { + sql(s"CREATE DATABASE db2") + val pathInCatalog2 = new Path(catalog.getDatabaseMetadata("db2").locationUri).toUri + assert("file" === pathInCatalog2.getScheme) + val expectedPath2 = new Path(spark.sessionState.conf.warehousePath + "/" + "db2.db").toUri + assert(expectedPath2.getPath === pathInCatalog2.getPath) + } + + sql("DROP DATABASE db1") + sql("DROP DATABASE db2") } + } - sql("DROP DATABASE db1") - sql("DROP DATABASE db2") + private def makeQualifiedPath(path: String): String = { + // copy-paste from SessionCatalog + val hadoopPath = new Path(path) + val fs = hadoopPath.getFileSystem(sparkContext.hadoopConfiguration) + fs.makeQualified(hadoopPath).toString } test("Create/Drop Database") { - withSQLConf( - SQLConf.WAREHOUSE_PATH.key -> (System.getProperty("java.io.tmpdir") + File.separator)) { - val catalog = sqlContext.sessionState.catalog + withTempDir { tmpDir => + val path = tmpDir.getCanonicalPath + withSQLConf(SQLConf.WAREHOUSE_PATH.key -> path) { + val catalog = spark.sessionState.catalog + val databaseNames = Seq("db1", "`database`") + + databaseNames.foreach { dbName => + try { + val dbNameWithoutBackTicks = cleanIdentifier(dbName) + + sql(s"CREATE DATABASE $dbName") + val db1 = catalog.getDatabaseMetadata(dbNameWithoutBackTicks) + val expectedLocation = makeQualifiedPath(s"$path/$dbNameWithoutBackTicks.db") + assert(db1 == CatalogDatabase( + dbNameWithoutBackTicks, + "", + expectedLocation, + Map.empty)) + sql(s"DROP DATABASE $dbName CASCADE") + assert(!catalog.databaseExists(dbNameWithoutBackTicks)) + } finally { + catalog.reset() + } + } + } + } + } - val databaseNames = Seq("db1", "`database`") + test("Create Database using Default Warehouse Path") { + withSQLConf(SQLConf.WAREHOUSE_PATH.key -> "") { + // Will use the default location if and only if we unset the conf + spark.conf.unset(SQLConf.WAREHOUSE_PATH.key) + val catalog = spark.sessionState.catalog + val dbName = "db1" + try { + sql(s"CREATE DATABASE $dbName") + val db1 = catalog.getDatabaseMetadata(dbName) + val expectedLocation = makeQualifiedPath(s"spark-warehouse/$dbName.db") + assert(db1 == CatalogDatabase( + dbName, + "", + expectedLocation, + Map.empty)) + sql(s"DROP DATABASE $dbName CASCADE") + assert(!catalog.databaseExists(dbName)) + } finally { + 
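// (Editor's note, not part of the patch.) Each iteration of the databaseNames loop calls
// catalog.reset() in a finally block so that the database created for "db1" is gone before the
// "`database`" iteration runs, even if an assertion in the try body fails; the suite-level
// afterEach() reset only runs once the whole test body has finished.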
catalog.reset() + } + } + } + test("Create/Drop Database - location") { + val catalog = spark.sessionState.catalog + val databaseNames = Seq("db1", "`database`") + withTempDir { tmpDir => + val path = new Path(tmpDir.getCanonicalPath).toUri databaseNames.foreach { dbName => try { val dbNameWithoutBackTicks = cleanIdentifier(dbName) - - sql(s"CREATE DATABASE $dbName") + sql(s"CREATE DATABASE $dbName Location '$path'") val db1 = catalog.getDatabaseMetadata(dbNameWithoutBackTicks) - val expectedLocation = - "file:" + System.getProperty("java.io.tmpdir") + - File.separator + s"$dbNameWithoutBackTicks.db" + val expPath = makeQualifiedPath(tmpDir.toString) assert(db1 == CatalogDatabase( dbNameWithoutBackTicks, "", - expectedLocation, + expPath, Map.empty)) sql(s"DROP DATABASE $dbName CASCADE") assert(!catalog.databaseExists(dbNameWithoutBackTicks)) @@ -147,77 +223,89 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } test("Create Database - database already exists") { - withSQLConf( - SQLConf.WAREHOUSE_PATH.key -> (System.getProperty("java.io.tmpdir") + File.separator)) { - val catalog = sqlContext.sessionState.catalog - val databaseNames = Seq("db1", "`database`") - - databaseNames.foreach { dbName => - try { - val dbNameWithoutBackTicks = cleanIdentifier(dbName) - sql(s"CREATE DATABASE $dbName") - val db1 = catalog.getDatabaseMetadata(dbNameWithoutBackTicks) - val expectedLocation = - "file:" + System.getProperty("java.io.tmpdir") + - File.separator + s"$dbNameWithoutBackTicks.db" - assert(db1 == CatalogDatabase( - dbNameWithoutBackTicks, - "", - expectedLocation, - Map.empty)) - - val message = intercept[AnalysisException] { + withTempDir { tmpDir => + val path = tmpDir.getCanonicalPath + withSQLConf(SQLConf.WAREHOUSE_PATH.key -> path) { + val catalog = spark.sessionState.catalog + val databaseNames = Seq("db1", "`database`") + + databaseNames.foreach { dbName => + try { + val dbNameWithoutBackTicks = cleanIdentifier(dbName) sql(s"CREATE DATABASE $dbName") - }.getMessage - assert(message.contains(s"Database '$dbNameWithoutBackTicks' already exists.")) - } finally { - catalog.reset() + val db1 = catalog.getDatabaseMetadata(dbNameWithoutBackTicks) + val expectedLocation = makeQualifiedPath(s"$path/$dbNameWithoutBackTicks.db") + assert(db1 == CatalogDatabase( + dbNameWithoutBackTicks, + "", + expectedLocation, + Map.empty)) + + intercept[DatabaseAlreadyExistsException] { + sql(s"CREATE DATABASE $dbName") + } + } finally { + catalog.reset() + } } } } } + test("desc table for parquet data source table using in-memory catalog") { + assume(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "in-memory") + val tabName = "tab1" + withTable(tabName) { + sql(s"CREATE TABLE $tabName(a int comment 'test') USING parquet ") + + checkAnswer( + sql(s"DESC $tabName").select("col_name", "data_type", "comment"), + Row("a", "int", "test") + ) + } + } + test("Alter/Describe Database") { - withSQLConf( - SQLConf.WAREHOUSE_PATH.key -> (System.getProperty("java.io.tmpdir") + File.separator)) { - val catalog = sqlContext.sessionState.catalog - val databaseNames = Seq("db1", "`database`") + withTempDir { tmpDir => + val path = tmpDir.getCanonicalPath + withSQLConf(SQLConf.WAREHOUSE_PATH.key -> path) { + val catalog = spark.sessionState.catalog + val databaseNames = Seq("db1", "`database`") - databaseNames.foreach { dbName => - try { - val dbNameWithoutBackTicks = cleanIdentifier(dbName) - val location = - "file:" + System.getProperty("java.io.tmpdir") + - File.separator + 
s"$dbNameWithoutBackTicks.db" - - sql(s"CREATE DATABASE $dbName") - - checkAnswer( - sql(s"DESCRIBE DATABASE EXTENDED $dbName"), - Row("Database Name", dbNameWithoutBackTicks) :: - Row("Description", "") :: - Row("Location", location) :: - Row("Properties", "") :: Nil) - - sql(s"ALTER DATABASE $dbName SET DBPROPERTIES ('a'='a', 'b'='b', 'c'='c')") - - checkAnswer( - sql(s"DESCRIBE DATABASE EXTENDED $dbName"), - Row("Database Name", dbNameWithoutBackTicks) :: - Row("Description", "") :: - Row("Location", location) :: - Row("Properties", "((a,a), (b,b), (c,c))") :: Nil) - - sql(s"ALTER DATABASE $dbName SET DBPROPERTIES ('d'='d')") - - checkAnswer( - sql(s"DESCRIBE DATABASE EXTENDED $dbName"), - Row("Database Name", dbNameWithoutBackTicks) :: - Row("Description", "") :: - Row("Location", location) :: - Row("Properties", "((a,a), (b,b), (c,c), (d,d))") :: Nil) - } finally { - catalog.reset() + databaseNames.foreach { dbName => + try { + val dbNameWithoutBackTicks = cleanIdentifier(dbName) + val location = makeQualifiedPath(s"$path/$dbNameWithoutBackTicks.db") + + sql(s"CREATE DATABASE $dbName") + + checkAnswer( + sql(s"DESCRIBE DATABASE EXTENDED $dbName"), + Row("Database Name", dbNameWithoutBackTicks) :: + Row("Description", "") :: + Row("Location", location) :: + Row("Properties", "") :: Nil) + + sql(s"ALTER DATABASE $dbName SET DBPROPERTIES ('a'='a', 'b'='b', 'c'='c')") + + checkAnswer( + sql(s"DESCRIBE DATABASE EXTENDED $dbName"), + Row("Database Name", dbNameWithoutBackTicks) :: + Row("Description", "") :: + Row("Location", location) :: + Row("Properties", "((a,a), (b,b), (c,c))") :: Nil) + + sql(s"ALTER DATABASE $dbName SET DBPROPERTIES ('d'='d')") + + checkAnswer( + sql(s"DESCRIBE DATABASE EXTENDED $dbName"), + Row("Database Name", dbNameWithoutBackTicks) :: + Row("Description", "") :: + Row("Location", location) :: + Row("Properties", "((a,a), (b,b), (c,c), (d,d))") :: Nil) + } finally { + catalog.reset() + } } } } @@ -228,79 +316,158 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { databaseNames.foreach { dbName => val dbNameWithoutBackTicks = cleanIdentifier(dbName) - assert(!sqlContext.sessionState.catalog.databaseExists(dbNameWithoutBackTicks)) + assert(!spark.sessionState.catalog.databaseExists(dbNameWithoutBackTicks)) var message = intercept[AnalysisException] { sql(s"DROP DATABASE $dbName") }.getMessage - assert(message.contains(s"Database '$dbNameWithoutBackTicks' does not exist")) + assert(message.contains(s"Database '$dbNameWithoutBackTicks' not found")) message = intercept[AnalysisException] { sql(s"ALTER DATABASE $dbName SET DBPROPERTIES ('d'='d')") }.getMessage - assert(message.contains(s"Database '$dbNameWithoutBackTicks' does not exist")) + assert(message.contains(s"Database '$dbNameWithoutBackTicks' not found")) message = intercept[AnalysisException] { sql(s"DESCRIBE DATABASE EXTENDED $dbName") }.getMessage - assert(message.contains(s"Database '$dbNameWithoutBackTicks' does not exist")) + assert(message.contains(s"Database '$dbNameWithoutBackTicks' not found")) sql(s"DROP DATABASE IF EXISTS $dbName") } } - // TODO: test drop database in restrict mode + test("drop non-empty database in restrict mode") { + val catalog = spark.sessionState.catalog + val dbName = "db1" + sql(s"CREATE DATABASE $dbName") + + // create a table in database + val tableIdent1 = TableIdentifier("tab1", Some(dbName)) + createTable(catalog, tableIdent1) + + // drop a non-empty database in Restrict mode + val message = intercept[AnalysisException] { + sql(s"DROP 
DATABASE $dbName RESTRICT") + }.getMessage + assert(message.contains(s"Database '$dbName' is not empty. One or more tables exist")) + + catalog.dropTable(tableIdent1, ignoreIfNotExists = false) + + assert(catalog.listDatabases().contains(dbName)) + sql(s"DROP DATABASE $dbName RESTRICT") + assert(!catalog.listDatabases().contains(dbName)) + } + + test("drop non-empty database in cascade mode") { + val catalog = spark.sessionState.catalog + val dbName = "db1" + sql(s"CREATE DATABASE $dbName") + + // create a table in database + val tableIdent1 = TableIdentifier("tab1", Some(dbName)) + createTable(catalog, tableIdent1) + + // drop a non-empty database in CASCADE mode + assert(catalog.listTables(dbName).contains(tableIdent1)) + assert(catalog.listDatabases().contains(dbName)) + sql(s"DROP DATABASE $dbName CASCADE") + assert(!catalog.listDatabases().contains(dbName)) + } test("create table in default db") { - val catalog = sqlContext.sessionState.catalog + val catalog = spark.sessionState.catalog val tableIdent1 = TableIdentifier("tab1", None) createTable(catalog, tableIdent1) val expectedTableIdent = tableIdent1.copy(database = Some("default")) - val expectedLocation = - catalog.getDatabaseMetadata("default").locationUri + "/tab1" - val expectedStorage = - CatalogStorageFormat( - locationUri = Some(expectedLocation), - inputFormat = None, - outputFormat = None, - serde = None, - serdeProperties = Map()) - val expectedTable = - CatalogTable( - identifier = expectedTableIdent, - tableType = CatalogTableType.EXTERNAL, - storage = expectedStorage, - schema = Seq(), - createTime = 0L) + val expectedTable = generateTable(catalog, expectedTableIdent) assert(catalog.getTableMetadata(tableIdent1) === expectedTable) } test("create table in a specific db") { - val catalog = sqlContext.sessionState.catalog + val catalog = spark.sessionState.catalog createDatabase(catalog, "dbx") val tableIdent1 = TableIdentifier("tab1", Some("dbx")) createTable(catalog, tableIdent1) - val expectedLocation = - catalog.getDatabaseMetadata("dbx").locationUri + "/tab1" - val expectedStorage = - CatalogStorageFormat( - locationUri = Some(expectedLocation), - inputFormat = None, - outputFormat = None, - serde = None, - serdeProperties = Map()) - val expectedTable = - CatalogTable( - identifier = tableIdent1, - tableType = CatalogTableType.EXTERNAL, - storage = expectedStorage, - schema = Seq(), - createTime = 0L) + val expectedTable = generateTable(catalog, tableIdent1) assert(catalog.getTableMetadata(tableIdent1) === expectedTable) } + test("Analyze in-memory cataloged tables(SimpleCatalogRelation)") { + withTable("tbl") { + sql("CREATE TABLE tbl(a INT, b INT) USING parquet") + val e = intercept[AnalysisException] { + sql("ANALYZE TABLE tbl COMPUTE STATISTICS") + }.getMessage + assert(e.contains("ANALYZE TABLE is only supported for Hive tables, " + + "but 'tbl' is a SimpleCatalogRelation")) + } + } + + test("create table using") { + val catalog = spark.sessionState.catalog + withTable("tbl") { + sql("CREATE TABLE tbl(a INT, b INT) USING parquet") + val table = catalog.getTableMetadata(TableIdentifier("tbl")) + assert(table.tableType == CatalogTableType.MANAGED) + assert(table.schema == Seq(CatalogColumn("a", "int"), CatalogColumn("b", "int"))) + assert(table.properties(DATASOURCE_PROVIDER) == "parquet") + } + } + + test("create table using - with partitioned by") { + val catalog = spark.sessionState.catalog + withTable("tbl") { + sql("CREATE TABLE tbl(a INT, b INT) USING parquet PARTITIONED BY (a)") + val table = 
catalog.getTableMetadata(TableIdentifier("tbl")) + assert(table.tableType == CatalogTableType.MANAGED) + assert(table.schema.isEmpty) // partitioned datasource table is not hive-compatible + assert(table.properties(DATASOURCE_PROVIDER) == "parquet") + assert(DDLUtils.getSchemaFromTableProperties(table) == + Some(new StructType().add("a", IntegerType).add("b", IntegerType))) + assert(DDLUtils.getPartitionColumnsFromTableProperties(table) == + Seq("a")) + } + } + + test("create table using - with bucket") { + val catalog = spark.sessionState.catalog + withTable("tbl") { + sql("CREATE TABLE tbl(a INT, b INT) USING parquet " + + "CLUSTERED BY (a) SORTED BY (b) INTO 5 BUCKETS") + val table = catalog.getTableMetadata(TableIdentifier("tbl")) + assert(table.tableType == CatalogTableType.MANAGED) + assert(table.schema.isEmpty) // partitioned datasource table is not hive-compatible + assert(table.properties(DATASOURCE_PROVIDER) == "parquet") + assert(DDLUtils.getSchemaFromTableProperties(table) == + Some(new StructType().add("a", IntegerType).add("b", IntegerType))) + assert(DDLUtils.getBucketSpecFromTableProperties(table) == + Some(BucketSpec(5, Seq("a"), Seq("b")))) + } + } + + test("create temporary view using") { + val csvFile = + Thread.currentThread().getContextClassLoader.getResource("test-data/cars.csv").toString + withView("testview") { + sql(s"CREATE OR REPLACE TEMPORARY VIEW testview (c1: String, c2: String) USING " + + "org.apache.spark.sql.execution.datasources.csv.CSVFileFormat " + + s"OPTIONS (PATH '$csvFile')") + + checkAnswer( + sql("select c1, c2 from testview order by c1 limit 1"), + Row("1997", "Ford") :: Nil) + + // Fails if creating a new view with the same name + intercept[TempTableAlreadyExistsException] { + sql(s"CREATE TEMPORARY VIEW testview USING " + + s"org.apache.spark.sql.execution.datasources.csv.CSVFileFormat OPTIONS (PATH '$csvFile')") + } + } + } + test("alter table: rename") { - val catalog = sqlContext.sessionState.catalog + val catalog = spark.sessionState.catalog val tableIdent1 = TableIdentifier("tab1", Some("dbx")) val tableIdent2 = TableIdentifier("tab2", Some("dbx")) createDatabase(catalog, "dbx") @@ -323,6 +490,84 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } } + test("alter table: rename cached table") { + import testImplicits._ + sql("CREATE TABLE students (age INT, name STRING) USING parquet") + val df = (1 to 2).map { i => (i, i.toString) }.toDF("age", "name") + df.write.insertInto("students") + spark.catalog.cacheTable("students") + assume(spark.table("students").collect().toSeq == df.collect().toSeq, "bad test: wrong data") + assume(spark.catalog.isCached("students"), "bad test: table was not cached in the first place") + sql("ALTER TABLE students RENAME TO teachers") + sql("CREATE TABLE students (age INT, name STRING) USING parquet") + // Now we have both students and teachers. + // The cached data for the old students table should not be read by the new students table. 
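// (Editor's note, not part of the patch.) The intent spelled out by the comment above: ALTER TABLE
// ... RENAME TO is expected to move the cache entry along with the table, so after the rename
// "teachers" is cached and still holds the inserted rows, while the freshly re-created "students"
// starts out uncached and empty. Condensed, using the same SparkSession catalog API as the test:
//   spark.catalog.cacheTable("students")
//   sql("ALTER TABLE students RENAME TO teachers")
//   spark.catalog.isCached("teachers")   // expected: true
//   spark.catalog.isCached("students")   // expected: false once students is re-created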
+ assert(!spark.catalog.isCached("students")) + assert(spark.catalog.isCached("teachers")) + assert(spark.table("students").collect().isEmpty) + assert(spark.table("teachers").collect().toSeq == df.collect().toSeq) + } + + test("rename temporary table - destination table with database name") { + withTempView("tab1") { + sql( + """ + |CREATE TEMPORARY TABLE tab1 + |USING org.apache.spark.sql.sources.DDLScanSource + |OPTIONS ( + | From '1', + | To '10', + | Table 'test1' + |) + """.stripMargin) + + val e = intercept[AnalysisException] { + sql("ALTER TABLE tab1 RENAME TO default.tab2") + } + assert(e.getMessage.contains( + "RENAME TEMPORARY TABLE from '`tab1`' to '`default`.`tab2`': " + + "cannot specify database name 'default' in the destination table")) + + val catalog = spark.sessionState.catalog + assert(catalog.listTables("default") == Seq(TableIdentifier("tab1"))) + } + } + + test("rename temporary table - destination table already exists") { + withTempView("tab1", "tab2") { + sql( + """ + |CREATE TEMPORARY TABLE tab1 + |USING org.apache.spark.sql.sources.DDLScanSource + |OPTIONS ( + | From '1', + | To '10', + | Table 'test1' + |) + """.stripMargin) + + sql( + """ + |CREATE TEMPORARY TABLE tab2 + |USING org.apache.spark.sql.sources.DDLScanSource + |OPTIONS ( + | From '1', + | To '10', + | Table 'test1' + |) + """.stripMargin) + + val e = intercept[AnalysisException] { + sql("ALTER TABLE tab1 RENAME TO tab2") + } + assert(e.getMessage.contains( + "RENAME TEMPORARY TABLE from '`tab1`' to '`tab2`': destination table already exists")) + + val catalog = spark.sessionState.catalog + assert(catalog.listTables("default") == Seq(TableIdentifier("tab1"), TableIdentifier("tab2"))) + } + } + test("alter table: set location") { testSetLocation(isDatasourceTable = false) } @@ -332,63 +577,19 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } test("alter table: set properties") { - val catalog = sqlContext.sessionState.catalog - val tableIdent = TableIdentifier("tab1", Some("dbx")) - createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) - assert(catalog.getTableMetadata(tableIdent).properties.isEmpty) - // set table properties - sql("ALTER TABLE dbx.tab1 SET TBLPROPERTIES ('andrew' = 'or14', 'kor' = 'bel')") - assert(catalog.getTableMetadata(tableIdent).properties == - Map("andrew" -> "or14", "kor" -> "bel")) - // set table properties without explicitly specifying database - catalog.setCurrentDatabase("dbx") - sql("ALTER TABLE tab1 SET TBLPROPERTIES ('kor' = 'belle', 'kar' = 'bol')") - assert(catalog.getTableMetadata(tableIdent).properties == - Map("andrew" -> "or14", "kor" -> "belle", "kar" -> "bol")) - // table to alter does not exist - intercept[AnalysisException] { - sql("ALTER TABLE does_not_exist SET TBLPROPERTIES ('winner' = 'loser')") - } - // throw exception for datasource tables - convertToDatasourceTable(catalog, tableIdent) - val e = intercept[AnalysisException] { - sql("ALTER TABLE tab1 SET TBLPROPERTIES ('sora' = 'bol')") - } - assert(e.getMessage.contains("datasource")) + testSetProperties(isDatasourceTable = false) + } + + test("alter table: set properties (datasource table)") { + testSetProperties(isDatasourceTable = true) } test("alter table: unset properties") { - val catalog = sqlContext.sessionState.catalog - val tableIdent = TableIdentifier("tab1", Some("dbx")) - createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) - // unset table properties - sql("ALTER TABLE dbx.tab1 SET TBLPROPERTIES ('j' = 'am', 'p' = 'an', 'c' = 
'lan')") - sql("ALTER TABLE dbx.tab1 UNSET TBLPROPERTIES ('j')") - assert(catalog.getTableMetadata(tableIdent).properties == Map("p" -> "an", "c" -> "lan")) - // unset table properties without explicitly specifying database - catalog.setCurrentDatabase("dbx") - sql("ALTER TABLE tab1 UNSET TBLPROPERTIES ('p')") - assert(catalog.getTableMetadata(tableIdent).properties == Map("c" -> "lan")) - // table to alter does not exist - intercept[AnalysisException] { - sql("ALTER TABLE does_not_exist UNSET TBLPROPERTIES ('c' = 'lan')") - } - // property to unset does not exist - val e = intercept[AnalysisException] { - sql("ALTER TABLE tab1 UNSET TBLPROPERTIES ('c', 'xyz')") - } - assert(e.getMessage.contains("xyz")) - // property to unset does not exist, but "IF EXISTS" is specified - sql("ALTER TABLE tab1 UNSET TBLPROPERTIES IF EXISTS ('c', 'xyz')") - assert(catalog.getTableMetadata(tableIdent).properties.isEmpty) - // throw exception for datasource tables - convertToDatasourceTable(catalog, tableIdent) - val e1 = intercept[AnalysisException] { - sql("ALTER TABLE tab1 UNSET TBLPROPERTIES ('sora')") - } - assert(e1.getMessage.contains("datasource")) + testUnsetProperties(isDatasourceTable = false) + } + + test("alter table: unset properties (datasource table)") { + testUnsetProperties(isDatasourceTable = true) } test("alter table: set serde") { @@ -399,8 +600,16 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { testSetSerde(isDatasourceTable = true) } + test("alter table: set serde partition") { + testSetSerdePartition(isDatasourceTable = false) + } + + test("alter table: set serde partition (datasource table)") { + testSetSerdePartition(isDatasourceTable = true) + } + test("alter table: bucketing is not supported") { - val catalog = sqlContext.sessionState.catalog + val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) createDatabase(catalog, "dbx") createTable(catalog, tableIdent) @@ -411,7 +620,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } test("alter table: skew is not supported") { - val catalog = sqlContext.sessionState.catalog + val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) createDatabase(catalog, "dbx") createTable(catalog, tableIdent) @@ -431,6 +640,64 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { testAddPartitions(isDatasourceTable = true) } + test("alter table: recover partitions (sequential)") { + withSQLConf("spark.rdd.parallelListingThreshold" -> "10") { + testRecoverPartitions() + } + } + + test("alter table: recover partition (parallel)") { + withSQLConf("spark.rdd.parallelListingThreshold" -> "1") { + testRecoverPartitions() + } + } + + private def testRecoverPartitions() { + val catalog = spark.sessionState.catalog + // table to alter does not exist + intercept[AnalysisException] { + sql("ALTER TABLE does_not_exist RECOVER PARTITIONS") + } + + val tableIdent = TableIdentifier("tab1") + createTable(catalog, tableIdent) + val part1 = Map("a" -> "1", "b" -> "5") + createTablePartition(catalog, part1, tableIdent) + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1)) + + val part2 = Map("a" -> "2", "b" -> "6") + val root = new Path(catalog.getTableMetadata(tableIdent).storage.locationUri.get) + val fs = root.getFileSystem(spark.sparkContext.hadoopConfiguration) + // valid + fs.mkdirs(new Path(new Path(root, "a=1"), "b=5")) + fs.createNewFile(new Path(new Path(root, 
"a=1/b=5"), "a.csv")) // file + fs.createNewFile(new Path(new Path(root, "a=1/b=5"), "_SUCCESS")) // file + fs.mkdirs(new Path(new Path(root, "A=2"), "B=6")) + fs.createNewFile(new Path(new Path(root, "A=2/B=6"), "b.csv")) // file + fs.createNewFile(new Path(new Path(root, "A=2/B=6"), "c.csv")) // file + fs.createNewFile(new Path(new Path(root, "A=2/B=6"), ".hiddenFile")) // file + fs.mkdirs(new Path(new Path(root, "A=2/B=6"), "_temporary")) + + // invalid + fs.mkdirs(new Path(new Path(root, "a"), "b")) // bad name + fs.mkdirs(new Path(new Path(root, "b=1"), "a=1")) // wrong order + fs.mkdirs(new Path(root, "a=4")) // not enough columns + fs.createNewFile(new Path(new Path(root, "a=1"), "b=4")) // file + fs.createNewFile(new Path(new Path(root, "a=1"), "_SUCCESS")) // _SUCCESS + fs.mkdirs(new Path(new Path(root, "a=1"), "_temporary")) // _temporary + fs.mkdirs(new Path(new Path(root, "a=1"), ".b=4")) // start with . + + try { + sql("ALTER TABLE tab1 RECOVER PARTITIONS") + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == + Set(part1, part2)) + assert(catalog.getPartition(tableIdent, part1).parameters("numFiles") == "1") + assert(catalog.getPartition(tableIdent, part2).parameters("numFiles") == "2") + } finally { + fs.delete(root, true) + } + } + test("alter table: add partition is not supported for views") { assertUnsupported("ALTER VIEW dbx.tab1 ADD IF NOT EXISTS PARTITION (b='2')") } @@ -448,11 +715,11 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } test("alter table: rename partition") { - val catalog = sqlContext.sessionState.catalog + val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) - val part1 = Map("a" -> "1") - val part2 = Map("b" -> "2") - val part3 = Map("c" -> "3") + val part1 = Map("a" -> "1", "b" -> "q") + val part2 = Map("a" -> "2", "b" -> "c") + val part3 = Map("a" -> "3", "b" -> "p") createDatabase(catalog, "dbx") createTable(catalog, tableIdent) createTablePartition(catalog, part1, tableIdent) @@ -460,27 +727,27 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { createTablePartition(catalog, part3, tableIdent) assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1, part2, part3)) - sql("ALTER TABLE dbx.tab1 PARTITION (a='1') RENAME TO PARTITION (a='100')") - sql("ALTER TABLE dbx.tab1 PARTITION (b='2') RENAME TO PARTITION (b='200')") + sql("ALTER TABLE dbx.tab1 PARTITION (a='1', b='q') RENAME TO PARTITION (a='100', b='p')") + sql("ALTER TABLE dbx.tab1 PARTITION (a='2', b='c') RENAME TO PARTITION (a='200', b='c')") assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == - Set(Map("a" -> "100"), Map("b" -> "200"), part3)) + Set(Map("a" -> "100", "b" -> "p"), Map("a" -> "200", "b" -> "c"), part3)) // rename without explicitly specifying database catalog.setCurrentDatabase("dbx") - sql("ALTER TABLE tab1 PARTITION (a='100') RENAME TO PARTITION (a='10')") + sql("ALTER TABLE tab1 PARTITION (a='100', b='p') RENAME TO PARTITION (a='10', b='p')") assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == - Set(Map("a" -> "10"), Map("b" -> "200"), part3)) + Set(Map("a" -> "10", "b" -> "p"), Map("a" -> "200", "b" -> "c"), part3)) // table to alter does not exist - intercept[AnalysisException] { + intercept[NoSuchTableException] { sql("ALTER TABLE does_not_exist PARTITION (c='3') RENAME TO PARTITION (c='333')") } // partition to rename does not exist - intercept[AnalysisException] { - sql("ALTER TABLE tab1 PARTITION (x='300') 
RENAME TO PARTITION (x='333')") + intercept[NoSuchPartitionException] { + sql("ALTER TABLE tab1 PARTITION (a='not_found', b='1') RENAME TO PARTITION (a='1', b='2')") } } test("show tables") { - withTempTable("show1a", "show2b") { + withTempView("show1a", "show2b") { sql( """ |CREATE TEMPORARY TABLE show1a @@ -524,24 +791,24 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } test("show databases") { - sql("CREATE DATABASE showdb1A") sql("CREATE DATABASE showdb2B") + sql("CREATE DATABASE showdb1A") - assert( - sql("SHOW DATABASES").count() >= 2) + // check the result as well as its order + checkDataset(sql("SHOW DATABASES"), Row("default"), Row("showdb1a"), Row("showdb2b")) checkAnswer( sql("SHOW DATABASES LIKE '*db1A'"), - Row("showdb1A") :: Nil) + Row("showdb1a") :: Nil) checkAnswer( sql("SHOW DATABASES LIKE 'showdb1A'"), - Row("showdb1A") :: Nil) + Row("showdb1a") :: Nil) checkAnswer( sql("SHOW DATABASES LIKE '*db1A|*db2B'"), - Row("showdb1A") :: - Row("showdb2B") :: Nil) + Row("showdb1a") :: + Row("showdb2b") :: Nil) checkAnswer( sql("SHOW DATABASES LIKE 'non-existentdb'"), @@ -549,7 +816,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } test("drop table - temporary table") { - val catalog = sqlContext.sessionState.catalog + val catalog = spark.sessionState.catalog sql( """ |CREATE TEMPORARY TABLE tab1 @@ -574,7 +841,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } private def testDropTable(isDatasourceTable: Boolean): Unit = { - val catalog = sqlContext.sessionState.catalog + val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) createDatabase(catalog, "dbx") createTable(catalog, tableIdent) @@ -585,15 +852,13 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { sql("DROP TABLE dbx.tab1") assert(catalog.listTables("dbx") == Nil) sql("DROP TABLE IF EXISTS dbx.tab1") - // no exception will be thrown - sql("DROP TABLE dbx.tab1") + intercept[AnalysisException] { + sql("DROP TABLE dbx.tab1") + } } - test("drop view in SQLContext") { - // SQLContext does not support create view. 
Log an error message, if tab1 does not exists - sql("DROP VIEW tab1") - - val catalog = sqlContext.sessionState.catalog + test("drop view") { + val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) createDatabase(catalog, "dbx") createTable(catalog, tableIdent) @@ -610,13 +875,85 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { catalog: SessionCatalog, tableIdent: TableIdentifier): Unit = { catalog.alterTable(catalog.getTableMetadata(tableIdent).copy( - properties = Map("spark.sql.sources.provider" -> "csv"))) + properties = Map(DATASOURCE_PROVIDER -> "csv"))) + } + + private def testSetProperties(isDatasourceTable: Boolean): Unit = { + val catalog = spark.sessionState.catalog + val tableIdent = TableIdentifier("tab1", Some("dbx")) + createDatabase(catalog, "dbx") + createTable(catalog, tableIdent) + if (isDatasourceTable) { + convertToDatasourceTable(catalog, tableIdent) + } + def getProps: Map[String, String] = { + catalog.getTableMetadata(tableIdent).properties.filterKeys { k => + !isDatasourceTable || !k.startsWith(DATASOURCE_PREFIX) + } + } + assert(getProps.isEmpty) + // set table properties + sql("ALTER TABLE dbx.tab1 SET TBLPROPERTIES ('andrew' = 'or14', 'kor' = 'bel')") + assert(getProps == Map("andrew" -> "or14", "kor" -> "bel")) + // set table properties without explicitly specifying database + catalog.setCurrentDatabase("dbx") + sql("ALTER TABLE tab1 SET TBLPROPERTIES ('kor' = 'belle', 'kar' = 'bol')") + assert(getProps == Map("andrew" -> "or14", "kor" -> "belle", "kar" -> "bol")) + // table to alter does not exist + intercept[AnalysisException] { + sql("ALTER TABLE does_not_exist SET TBLPROPERTIES ('winner' = 'loser')") + } + // datasource table property keys are not allowed + val e = intercept[AnalysisException] { + sql(s"ALTER TABLE tab1 SET TBLPROPERTIES ('${DATASOURCE_PREFIX}foo' = 'loser')") + } + assert(e.getMessage.contains(DATASOURCE_PREFIX + "foo")) + } + + private def testUnsetProperties(isDatasourceTable: Boolean): Unit = { + val catalog = spark.sessionState.catalog + val tableIdent = TableIdentifier("tab1", Some("dbx")) + createDatabase(catalog, "dbx") + createTable(catalog, tableIdent) + if (isDatasourceTable) { + convertToDatasourceTable(catalog, tableIdent) + } + def getProps: Map[String, String] = { + catalog.getTableMetadata(tableIdent).properties.filterKeys { k => + !isDatasourceTable || !k.startsWith(DATASOURCE_PREFIX) + } + } + // unset table properties + sql("ALTER TABLE dbx.tab1 SET TBLPROPERTIES ('j' = 'am', 'p' = 'an', 'c' = 'lan', 'x' = 'y')") + sql("ALTER TABLE dbx.tab1 UNSET TBLPROPERTIES ('j')") + assert(getProps == Map("p" -> "an", "c" -> "lan", "x" -> "y")) + // unset table properties without explicitly specifying database + catalog.setCurrentDatabase("dbx") + sql("ALTER TABLE tab1 UNSET TBLPROPERTIES ('p')") + assert(getProps == Map("c" -> "lan", "x" -> "y")) + // table to alter does not exist + intercept[AnalysisException] { + sql("ALTER TABLE does_not_exist UNSET TBLPROPERTIES ('c' = 'lan')") + } + // property to unset does not exist + val e = intercept[AnalysisException] { + sql("ALTER TABLE tab1 UNSET TBLPROPERTIES ('c', 'xyz')") + } + assert(e.getMessage.contains("xyz")) + // property to unset does not exist, but "IF EXISTS" is specified + sql("ALTER TABLE tab1 UNSET TBLPROPERTIES IF EXISTS ('c', 'xyz')") + assert(getProps == Map("x" -> "y")) + // datasource table property keys are not allowed + val e2 = intercept[AnalysisException] { + sql(s"ALTER TABLE tab1 UNSET 
TBLPROPERTIES ('${DATASOURCE_PREFIX}foo')") + } + assert(e2.getMessage.contains(DATASOURCE_PREFIX + "foo")) } private def testSetLocation(isDatasourceTable: Boolean): Unit = { - val catalog = sqlContext.sessionState.catalog + val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) - val partSpec = Map("a" -> "1") + val partSpec = Map("a" -> "1", "b" -> "2") createDatabase(catalog, "dbx") createTable(catalog, tableIdent) createTablePartition(catalog, partSpec, tableIdent) @@ -649,7 +986,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { verifyLocation("/path/to/your/lovely/heart") // set table partition location maybeWrapException(isDatasourceTable) { - sql("ALTER TABLE dbx.tab1 PARTITION (a='1') SET LOCATION '/path/to/part/ways'") + sql("ALTER TABLE dbx.tab1 PARTITION (a='1', b='2') SET LOCATION '/path/to/part/ways'") } verifyLocation("/path/to/part/ways", Some(partSpec)) // set table location without explicitly specifying database @@ -658,7 +995,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { verifyLocation("/swanky/steak/place") // set table partition location without explicitly specifying database maybeWrapException(isDatasourceTable) { - sql("ALTER TABLE tab1 PARTITION (a='1') SET LOCATION 'vienna'") + sql("ALTER TABLE tab1 PARTITION (a='1', b='2') SET LOCATION 'vienna'") } verifyLocation("vienna", Some(partSpec)) // table to alter does not exist @@ -672,7 +1009,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } private def testSetSerde(isDatasourceTable: Boolean): Unit = { - val catalog = sqlContext.sessionState.catalog + val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) createDatabase(catalog, "dbx") createTable(catalog, tableIdent) @@ -715,15 +1052,76 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { intercept[AnalysisException] { sql("ALTER TABLE does_not_exist SET SERDEPROPERTIES ('x' = 'y')") } + // serde properties must not be a datasource property + val e = intercept[AnalysisException] { + sql(s"ALTER TABLE tab1 SET SERDEPROPERTIES ('${DATASOURCE_PREFIX}foo'='wah')") + } + assert(e.getMessage.contains(DATASOURCE_PREFIX + "foo")) + } + + private def testSetSerdePartition(isDatasourceTable: Boolean): Unit = { + val catalog = spark.sessionState.catalog + val tableIdent = TableIdentifier("tab1", Some("dbx")) + val spec = Map("a" -> "1", "b" -> "2") + createDatabase(catalog, "dbx") + createTable(catalog, tableIdent) + createTablePartition(catalog, spec, tableIdent) + createTablePartition(catalog, Map("a" -> "1", "b" -> "3"), tableIdent) + createTablePartition(catalog, Map("a" -> "2", "b" -> "2"), tableIdent) + createTablePartition(catalog, Map("a" -> "2", "b" -> "3"), tableIdent) + if (isDatasourceTable) { + convertToDatasourceTable(catalog, tableIdent) + } + assert(catalog.getPartition(tableIdent, spec).storage.serde.isEmpty) + assert(catalog.getPartition(tableIdent, spec).storage.serdeProperties.isEmpty) + // set table serde and/or properties (should fail on datasource tables) + if (isDatasourceTable) { + val e1 = intercept[AnalysisException] { + sql("ALTER TABLE dbx.tab1 PARTITION (a=1, b=2) SET SERDE 'whatever'") + } + val e2 = intercept[AnalysisException] { + sql("ALTER TABLE dbx.tab1 PARTITION (a=1, b=2) SET SERDE 'org.apache.madoop' " + + "WITH SERDEPROPERTIES ('k' = 'v', 'kay' = 'vee')") + } + assert(e1.getMessage.contains("datasource")) + 
assert(e2.getMessage.contains("datasource")) + } else { + sql("ALTER TABLE dbx.tab1 PARTITION (a=1, b=2) SET SERDE 'org.apache.jadoop'") + assert(catalog.getPartition(tableIdent, spec).storage.serde == Some("org.apache.jadoop")) + assert(catalog.getPartition(tableIdent, spec).storage.serdeProperties.isEmpty) + sql("ALTER TABLE dbx.tab1 PARTITION (a=1, b=2) SET SERDE 'org.apache.madoop' " + + "WITH SERDEPROPERTIES ('k' = 'v', 'kay' = 'vee')") + assert(catalog.getPartition(tableIdent, spec).storage.serde == Some("org.apache.madoop")) + assert(catalog.getPartition(tableIdent, spec).storage.serdeProperties == + Map("k" -> "v", "kay" -> "vee")) + } + // set serde properties only + maybeWrapException(isDatasourceTable) { + sql("ALTER TABLE dbx.tab1 PARTITION (a=1, b=2) " + + "SET SERDEPROPERTIES ('k' = 'vvv', 'kay' = 'vee')") + assert(catalog.getPartition(tableIdent, spec).storage.serdeProperties == + Map("k" -> "vvv", "kay" -> "vee")) + } + // set things without explicitly specifying database + catalog.setCurrentDatabase("dbx") + maybeWrapException(isDatasourceTable) { + sql("ALTER TABLE tab1 PARTITION (a=1, b=2) SET SERDEPROPERTIES ('kay' = 'veee')") + assert(catalog.getPartition(tableIdent, spec).storage.serdeProperties == + Map("k" -> "vvv", "kay" -> "veee")) + } + // table to alter does not exist + intercept[AnalysisException] { + sql("ALTER TABLE does_not_exist SET SERDEPROPERTIES ('x' = 'y')") + } } private def testAddPartitions(isDatasourceTable: Boolean): Unit = { - val catalog = sqlContext.sessionState.catalog + val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) - val part1 = Map("a" -> "1") - val part2 = Map("b" -> "2") - val part3 = Map("c" -> "3") - val part4 = Map("d" -> "4") + val part1 = Map("a" -> "1", "b" -> "5") + val part2 = Map("a" -> "2", "b" -> "6") + val part3 = Map("a" -> "3", "b" -> "7") + val part4 = Map("a" -> "4", "b" -> "8") createDatabase(catalog, "dbx") createTable(catalog, tableIdent) createTablePartition(catalog, part1, tableIdent) @@ -733,18 +1131,18 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1)) maybeWrapException(isDatasourceTable) { sql("ALTER TABLE dbx.tab1 ADD IF NOT EXISTS " + - "PARTITION (b='2') LOCATION 'paris' PARTITION (c='3')") + "PARTITION (a='2', b='6') LOCATION 'paris' PARTITION (a='3', b='7')") } if (!isDatasourceTable) { assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1, part2, part3)) assert(catalog.getPartition(tableIdent, part1).storage.locationUri.isEmpty) - assert(catalog.getPartition(tableIdent, part2).storage.locationUri == Some("paris")) + assert(catalog.getPartition(tableIdent, part2).storage.locationUri == Option("paris")) assert(catalog.getPartition(tableIdent, part3).storage.locationUri.isEmpty) } // add partitions without explicitly specifying database catalog.setCurrentDatabase("dbx") maybeWrapException(isDatasourceTable) { - sql("ALTER TABLE tab1 ADD IF NOT EXISTS PARTITION (d='4')") + sql("ALTER TABLE tab1 ADD IF NOT EXISTS PARTITION (a='4', b='8')") } if (!isDatasourceTable) { assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == @@ -752,14 +1150,14 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } // table to alter does not exist intercept[AnalysisException] { - sql("ALTER TABLE does_not_exist ADD IF NOT EXISTS PARTITION (d='4')") + sql("ALTER TABLE does_not_exist ADD IF NOT EXISTS PARTITION (a='4', 
b='9')") } // partition to add already exists intercept[AnalysisException] { - sql("ALTER TABLE tab1 ADD PARTITION (d='4')") + sql("ALTER TABLE tab1 ADD PARTITION (a='4', b='8')") } maybeWrapException(isDatasourceTable) { - sql("ALTER TABLE tab1 ADD IF NOT EXISTS PARTITION (d='4')") + sql("ALTER TABLE tab1 ADD IF NOT EXISTS PARTITION (a='4', b='8')") } if (!isDatasourceTable) { assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == @@ -768,12 +1166,12 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } private def testDropPartitions(isDatasourceTable: Boolean): Unit = { - val catalog = sqlContext.sessionState.catalog + val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) - val part1 = Map("a" -> "1") - val part2 = Map("b" -> "2") - val part3 = Map("c" -> "3") - val part4 = Map("d" -> "4") + val part1 = Map("a" -> "1", "b" -> "5") + val part2 = Map("a" -> "2", "b" -> "6") + val part3 = Map("a" -> "3", "b" -> "7") + val part4 = Map("a" -> "4", "b" -> "8") createDatabase(catalog, "dbx") createTable(catalog, tableIdent) createTablePartition(catalog, part1, tableIdent) @@ -786,7 +1184,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { convertToDatasourceTable(catalog, tableIdent) } maybeWrapException(isDatasourceTable) { - sql("ALTER TABLE dbx.tab1 DROP IF EXISTS PARTITION (d='4'), PARTITION (c='3')") + sql("ALTER TABLE dbx.tab1 DROP IF EXISTS PARTITION (a='4', b='8'), PARTITION (a='3', b='7')") } if (!isDatasourceTable) { assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1, part2)) @@ -794,24 +1192,46 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { // drop partitions without explicitly specifying database catalog.setCurrentDatabase("dbx") maybeWrapException(isDatasourceTable) { - sql("ALTER TABLE tab1 DROP IF EXISTS PARTITION (b='2')") + sql("ALTER TABLE tab1 DROP IF EXISTS PARTITION (a='2', b ='6')") } if (!isDatasourceTable) { - assert(catalog.listPartitions(tableIdent).map(_.spec) == Seq(part1)) + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1)) } // table to alter does not exist intercept[AnalysisException] { - sql("ALTER TABLE does_not_exist DROP IF EXISTS PARTITION (b='2')") + sql("ALTER TABLE does_not_exist DROP IF EXISTS PARTITION (a='2')") } // partition to drop does not exist intercept[AnalysisException] { - sql("ALTER TABLE tab1 DROP PARTITION (x='300')") + sql("ALTER TABLE tab1 DROP PARTITION (a='300')") } maybeWrapException(isDatasourceTable) { - sql("ALTER TABLE tab1 DROP IF EXISTS PARTITION (x='300')") + sql("ALTER TABLE tab1 DROP IF EXISTS PARTITION (a='300')") } if (!isDatasourceTable) { - assert(catalog.listPartitions(tableIdent).map(_.spec) == Seq(part1)) + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1)) + } + } + + test("drop build-in function") { + Seq("true", "false").foreach { caseSensitive => + withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive) { + // partition to add already exists + var e = intercept[AnalysisException] { + sql("DROP TEMPORARY FUNCTION year") + } + assert(e.getMessage.contains("Cannot drop native function 'year'")) + + e = intercept[AnalysisException] { + sql("DROP TEMPORARY FUNCTION YeAr") + } + assert(e.getMessage.contains("Cannot drop native function 'YeAr'")) + + e = intercept[AnalysisException] { + sql("DROP TEMPORARY FUNCTION `YeAr`") + } + assert(e.getMessage.contains("Cannot drop native function 'YeAr'")) + } } } 
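A brief illustrative aside between these hunks (not part of the patch): the DDLSuite changes above move every partition spec from a single column to a full two-column spec, since the test tables are now partitioned by both `a` and `b`. A minimal sketch of the ALTER TABLE statements being exercised, assuming a hypothetical table `tab1` partitioned by columns `a` and `b`:

```scala
// Illustrative sketch only; table name, partition values and locations are hypothetical.
// Each ALTER TABLE carries the full (a, b) partition spec used by the updated tests.
spark.sql("ALTER TABLE tab1 ADD IF NOT EXISTS PARTITION (a='2', b='6') LOCATION 'paris'")
spark.sql("ALTER TABLE tab1 PARTITION (a='1', b='2') SET LOCATION '/path/to/part/ways'")
spark.sql("ALTER TABLE tab1 PARTITION (a='1', b='2') SET SERDEPROPERTIES ('k' = 'v')")
spark.sql("ALTER TABLE tab1 DROP IF EXISTS PARTITION (a='2', b='6')")
```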
@@ -866,4 +1286,214 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { Row("Usage: a ^ b - Bitwise exclusive OR.") :: Nil ) } + + test("select/insert into the managed table") { + assume(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "in-memory") + val tabName = "tbl" + withTable(tabName) { + sql(s"CREATE TABLE $tabName (i INT, j STRING)") + val catalogTable = + spark.sessionState.catalog.getTableMetadata(TableIdentifier(tabName, Some("default"))) + assert(catalogTable.tableType == CatalogTableType.MANAGED) + + var message = intercept[AnalysisException] { + sql(s"INSERT OVERWRITE TABLE $tabName SELECT 1, 'a'") + }.getMessage + assert(message.contains("Hive support is required to insert into the following tables")) + message = intercept[AnalysisException] { + sql(s"SELECT * FROM $tabName") + }.getMessage + assert(message.contains("Hive support is required to select over the following tables")) + } + } + + test("select/insert into external table") { + assume(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "in-memory") + withTempDir { tempDir => + val tabName = "tbl" + withTable(tabName) { + sql( + s""" + |CREATE EXTERNAL TABLE $tabName (i INT, j STRING) + |ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' + |LOCATION '$tempDir' + """.stripMargin) + val catalogTable = + spark.sessionState.catalog.getTableMetadata(TableIdentifier(tabName, Some("default"))) + assert(catalogTable.tableType == CatalogTableType.EXTERNAL) + + var message = intercept[AnalysisException] { + sql(s"INSERT OVERWRITE TABLE $tabName SELECT 1, 'a'") + }.getMessage + assert(message.contains("Hive support is required to insert into the following tables")) + message = intercept[AnalysisException] { + sql(s"SELECT * FROM $tabName") + }.getMessage + assert(message.contains("Hive support is required to select over the following tables")) + } + } + } + + test("create table with datasource properties (not allowed)") { + assertUnsupported("CREATE TABLE my_tab TBLPROPERTIES ('spark.sql.sources.me'='anything')") + assertUnsupported("CREATE TABLE my_tab ROW FORMAT SERDE 'serde' " + + "WITH SERDEPROPERTIES ('spark.sql.sources.me'='anything')") + } + + test("drop current database") { + sql("CREATE DATABASE temp") + sql("USE temp") + val m = intercept[AnalysisException] { + sql("DROP DATABASE temp") + }.getMessage + assert(m.contains("Can not drop current database `temp`")) + } + + test("drop default database") { + Seq("true", "false").foreach { caseSensitive => + withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive) { + var message = intercept[AnalysisException] { + sql("DROP DATABASE default") + }.getMessage + assert(message.contains("Can not drop default database")) + + message = intercept[AnalysisException] { + sql("DROP DATABASE DeFault") + }.getMessage + if (caseSensitive == "true") { + assert(message.contains("Database 'DeFault' not found")) + } else { + assert(message.contains("Can not drop default database")) + } + } + } + } + + test("truncate table - datasource table") { + import testImplicits._ + val data = (1 to 10).map { i => (i, i) }.toDF("width", "length") + + // Test both a Hive compatible and incompatible code path. 
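    // (Editor's aside, illustrative and not part of the patch: the two formats below are chosen
    //  so that one code path stores the table in a Hive-compatible layout and the other does not,
    //  per the comment above; TRUNCATE TABLE is expected to behave the same way for both, roughly:
    //    spark.sql("TRUNCATE TABLE rectangles"); assert(spark.table("rectangles").count() == 0)
    //  where "rectangles" is the table created just below.)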
+ Seq("json", "parquet").foreach { format => + withTable("rectangles") { + data.write.format(format).saveAsTable("rectangles") + assume(spark.table("rectangles").collect().nonEmpty, + "bad test; table was empty to begin with") + sql("TRUNCATE TABLE rectangles") + assert(spark.table("rectangles").collect().isEmpty) + } + } + + // truncating partitioned data source tables is not supported + withTable("rectangles", "rectangles2") { + data.write.saveAsTable("rectangles") + data.write.partitionBy("length").saveAsTable("rectangles2") + assertUnsupported("TRUNCATE TABLE rectangles PARTITION (width=1)") + assertUnsupported("TRUNCATE TABLE rectangles2 PARTITION (width=1)") + } + } + + test("create temporary view with mismatched schema") { + withTable("tab1") { + spark.range(10).write.saveAsTable("tab1") + withView("view1") { + val e = intercept[AnalysisException] { + sql("CREATE TEMPORARY VIEW view1 (col1, col3) AS SELECT * FROM tab1") + }.getMessage + assert(e.contains("the SELECT clause (num: `1`) does not match") + && e.contains("CREATE VIEW (num: `2`)")) + } + } + } + + test("create temporary view with specified schema") { + withView("view1") { + sql("CREATE TEMPORARY VIEW view1 (col1, col2) AS SELECT 1, 2") + checkAnswer( + sql("SELECT * FROM view1"), + Row(1, 2) :: Nil + ) + } + } + + test("truncate table - external table, temporary table, view (not allowed)") { + import testImplicits._ + val path = Utils.createTempDir().getAbsolutePath + (1 to 10).map { i => (i, i) }.toDF("a", "b").createTempView("my_temp_tab") + sql(s"CREATE EXTERNAL TABLE my_ext_tab LOCATION '$path'") + sql(s"CREATE VIEW my_view AS SELECT 1") + intercept[NoSuchTableException] { + sql("TRUNCATE TABLE my_temp_tab") + } + assertUnsupported("TRUNCATE TABLE my_ext_tab") + assertUnsupported("TRUNCATE TABLE my_view") + } + + test("truncate table - non-partitioned table (not allowed)") { + sql("CREATE TABLE my_tab (age INT, name STRING)") + assertUnsupported("TRUNCATE TABLE my_tab PARTITION (age=10)") + } + + test("SPARK-16034 Partition columns should match when appending to existing data source tables") { + import testImplicits._ + val df = Seq((1, 2, 3)).toDF("a", "b", "c") + withTable("partitionedTable") { + df.write.mode("overwrite").partitionBy("a", "b").saveAsTable("partitionedTable") + // Misses some partition columns + intercept[AnalysisException] { + df.write.mode("append").partitionBy("a").saveAsTable("partitionedTable") + } + // Wrong order + intercept[AnalysisException] { + df.write.mode("append").partitionBy("b", "a").saveAsTable("partitionedTable") + } + // Partition columns not specified + intercept[AnalysisException] { + df.write.mode("append").saveAsTable("partitionedTable") + } + assert(sql("select * from partitionedTable").collect().size == 1) + // Inserts new data successfully when partition columns are correctly specified in + // partitionBy(...). + // TODO: Right now, partition columns are always treated in a case-insensitive way. + // See the write method in DataSource.scala. 
+ Seq((4, 5, 6)).toDF("a", "B", "c") + .write + .mode("append") + .partitionBy("a", "B") + .saveAsTable("partitionedTable") + + Seq((7, 8, 9)).toDF("a", "b", "c") + .write + .mode("append") + .partitionBy("a", "b") + .saveAsTable("partitionedTable") + + checkAnswer( + sql("select a, b, c from partitionedTable"), + Row(1, 2, 3) :: Row(4, 5, 6) :: Row(7, 8, 9) :: Nil + ) + } + } + + test("show functions") { + withUserDefinedFunction("add_one" -> true) { + val numFunctions = FunctionRegistry.functionSet.size.toLong + assert(sql("show functions").count() === numFunctions) + assert(sql("show system functions").count() === numFunctions) + assert(sql("show all functions").count() === numFunctions) + assert(sql("show user functions").count() === 0L) + spark.udf.register("add_one", (x: Long) => x + 1) + assert(sql("show functions").count() === numFunctions + 1L) + assert(sql("show system functions").count() === numFunctions) + assert(sql("show all functions").count() === numFunctions + 1L) + assert(sql("show user functions").count() === 1L) + } + } + + test("SPARK-18009 calling toLocalIterator on commands") { + import scala.collection.JavaConverters._ + val df = sql("show databases") + val rows: Seq[Row] = df.toLocalIterator().asScala.toSeq + assert(rows.length > 0) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileCatalogSuite.scala new file mode 100644 index 0000000000000..4f12df9c49856 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileCatalogSuite.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources + +import java.io.File +import java.net.URI + +import scala.collection.mutable +import scala.language.reflectiveCalls + +import org.apache.hadoop.fs.{FileStatus, Path, RawLocalFileSystem} + +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.test.SharedSQLContext + +class FileCatalogSuite extends SharedSQLContext { + + test("ListingFileCatalog: leaf files are qualified paths") { + withTempDir { dir => + val file = new File(dir, "text.txt") + stringToFile(file, "text") + + val path = new Path(file.getCanonicalPath) + val catalog = new ListingFileCatalog(spark, Seq(path), Map.empty, None) { + def leafFilePaths: Seq[Path] = leafFiles.keys.toSeq + def leafDirPaths: Seq[Path] = leafDirToChildrenFiles.keys.toSeq + } + assert(catalog.leafFilePaths.forall(p => p.toString.startsWith("file:/"))) + assert(catalog.leafDirPaths.forall(p => p.toString.startsWith("file:/"))) + } + } + + test("ListingFileCatalog: input paths are converted to qualified paths") { + withTempDir { dir => + val file = new File(dir, "text.txt") + stringToFile(file, "text") + + val unqualifiedDirPath = new Path(dir.getCanonicalPath) + val unqualifiedFilePath = new Path(file.getCanonicalPath) + require(!unqualifiedDirPath.toString.contains("file:")) + require(!unqualifiedFilePath.toString.contains("file:")) + + val fs = unqualifiedDirPath.getFileSystem(sparkContext.hadoopConfiguration) + val qualifiedFilePath = fs.makeQualified(new Path(file.getCanonicalPath)) + require(qualifiedFilePath.toString.startsWith("file:")) + + val catalog1 = new ListingFileCatalog( + spark, Seq(unqualifiedDirPath), Map.empty, None) + assert(catalog1.allFiles.map(_.getPath) === Seq(qualifiedFilePath)) + + val catalog2 = new ListingFileCatalog( + spark, Seq(unqualifiedFilePath), Map.empty, None) + assert(catalog2.allFiles.map(_.getPath) === Seq(qualifiedFilePath)) + + } + } + + test("ListingFileCatalog: folders that don't exist don't throw exceptions") { + withTempDir { dir => + val deletedFolder = new File(dir, "deleted") + assert(!deletedFolder.exists()) + val catalog1 = new ListingFileCatalog( + spark, Seq(new Path(deletedFolder.getCanonicalPath)), Map.empty, None, + ignoreFileNotFound = true) + // doesn't throw an exception + assert(catalog1.listLeafFiles(catalog1.paths).isEmpty) + } + } + + test("SPARK-17613 - PartitioningAwareFileCatalog: base path w/o '/' at end") { + class MockCatalog( + override val paths: Seq[Path]) extends PartitioningAwareFileCatalog(spark, Map.empty, None) { + + override def refresh(): Unit = {} + + override def leafFiles: mutable.LinkedHashMap[Path, FileStatus] = mutable.LinkedHashMap( + new Path("mockFs://some-bucket/file1.json") -> new FileStatus() + ) + + override def leafDirToChildrenFiles: Map[Path, Array[FileStatus]] = Map( + new Path("mockFs://some-bucket/") -> Array(new FileStatus()) + ) + + override def partitionSpec(): PartitionSpec = { + PartitionSpec.emptySpec + } + } + + withSQLConf( + "fs.mockFs.impl" -> classOf[FakeParentPathFileSystem].getName, + "fs.mockFs.impl.disable.cache" -> "true") { + val pathWithSlash = new Path("mockFs://some-bucket/") + assert(pathWithSlash.getParent === null) + val pathWithoutSlash = new Path("mockFs://some-bucket") + assert(pathWithoutSlash.getParent === null) + val catalog1 = new MockCatalog(Seq(pathWithSlash)) + val catalog2 = new MockCatalog(Seq(pathWithoutSlash)) + assert(catalog1.allFiles().nonEmpty) + assert(catalog2.allFiles().nonEmpty) + } + } +} + +class FakeParentPathFileSystem extends 
RawLocalFileSystem { + override def getScheme: String = "mockFs" + + override def getUri: URI = { + URI.create("mockFs://some-bucket") + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index ac2af77a6eb21..2f1edb0974927 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -21,7 +21,7 @@ import java.io.File import java.util.concurrent.atomic.AtomicInteger import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{BlockLocation, FileStatus, RawLocalFileSystem} +import org.apache.hadoop.fs.{BlockLocation, FileStatus, Path, RawLocalFileSystem} import org.apache.hadoop.mapreduce.Job import org.apache.spark.SparkConf @@ -29,7 +29,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionSet, PredicateHelper} import org.apache.spark.sql.catalyst.util -import org.apache.spark.sql.execution.DataSourceScanExec +import org.apache.spark.sql.execution.{DataSourceScanExec, SparkPlan} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ @@ -281,7 +281,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi )) val fakeRDD = new FileScanRDD( - sqlContext.sparkSession, + spark, (file: PartitionedFile) => Iterator.empty, Seq(partition) ) @@ -340,6 +340,106 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi } } + test("SPARK-15654 do not split non-splittable files") { + // Check if a non-splittable file is not assigned into partitions + Seq("gz", "snappy", "lz4").map { suffix => + val table = createTable( + files = Seq(s"file1.${suffix}" -> 3, s"file2.${suffix}" -> 1, s"file3.${suffix}" -> 1) + ) + withSQLConf( + SQLConf.FILES_MAX_PARTITION_BYTES.key -> "2", + SQLConf.FILES_OPEN_COST_IN_BYTES.key -> "0") { + checkScan(table.select('c1)) { partitions => + assert(partitions.size == 2) + assert(partitions(0).files.size == 1) + assert(partitions(1).files.size == 2) + } + } + } + + // Check if a splittable compressed file is assigned into multiple partitions + Seq("bz2").map { suffix => + val table = createTable( + files = Seq(s"file1.${suffix}" -> 3, s"file2.${suffix}" -> 1, s"file3.${suffix}" -> 1) + ) + withSQLConf( + SQLConf.FILES_MAX_PARTITION_BYTES.key -> "2", + SQLConf.FILES_OPEN_COST_IN_BYTES.key -> "0") { + checkScan(table.select('c1)) { partitions => + assert(partitions.size == 3) + assert(partitions(0).files.size == 1) + assert(partitions(1).files.size == 2) + assert(partitions(2).files.size == 1) + } + } + } + } + + test("SPARK-14959: Do not call getFileBlockLocations on directories") { + // Setting PARALLEL_PARTITION_DISCOVERY_THRESHOLD to 2. So we will first + // list file statues at driver side and then for the level of p2, we will list + // file statues in parallel. 
+ withSQLConf( + "fs.file.impl" -> classOf[MockDistributedFileSystem].getName, + SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD.key -> "2") { + withTempPath { path => + val tempDir = path.getCanonicalPath + + Seq("p1=1/p2=2/p3=3/file1", "p1=1/p2=3/p3=3/file1").foreach { fileName => + val file = new File(tempDir, fileName) + assert(file.getParentFile.exists() || file.getParentFile.mkdirs()) + util.stringToFile(file, fileName) + } + + val fileCatalog = new ListingFileCatalog( + sparkSession = spark, + paths = Seq(new Path(tempDir)), + parameters = Map.empty[String, String], + partitionSchema = None) + // This should not fail. + fileCatalog.listLeafFiles(Seq(new Path(tempDir))) + + // Also have an integration test. + checkAnswer( + spark.read.text(tempDir).select("p1", "p2", "p3", "value"), + Row(1, 2, 3, "p1=1/p2=2/p3=3/file1") :: Row(1, 3, 3, "p1=1/p2=3/p3=3/file1") :: Nil) + } + } + } + + test("[SPARK-16818] partition pruned file scans implement sameResult correctly") { + withTempPath { path => + val tempDir = path.getCanonicalPath + spark.range(100) + .selectExpr("id", "id as b") + .write + .partitionBy("id") + .parquet(tempDir) + val df = spark.read.parquet(tempDir) + def getPlan(df: DataFrame): SparkPlan = { + df.queryExecution.executedPlan + } + assert(getPlan(df.where("id = 2")).sameResult(getPlan(df.where("id = 2")))) + assert(!getPlan(df.where("id = 2")).sameResult(getPlan(df.where("id = 3")))) + } + } + + test("[SPARK-16818] exchange reuse respects differences in partition pruning") { + spark.conf.set("spark.sql.exchange.reuse", true) + withTempPath { path => + val tempDir = path.getCanonicalPath + spark.range(10) + .selectExpr("id % 2 as a", "id % 3 as b", "id as c") + .write + .partitionBy("a") + .parquet(tempDir) + val df = spark.read.parquet(tempDir) + val df1 = df.where("a = 0").groupBy("b").agg("c" -> "sum") + val df2 = df.where("a = 1").groupBy("b").agg("c" -> "sum") + checkAnswer(df1.join(df2, "b"), Row(0, 6, 12) :: Row(1, 4, 8) :: Row(2, 10, 5) :: Nil) + } + } + // Helpers for checking the arguments passed to the FileFormat. protected val checkPartitionSchema = @@ -399,7 +499,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi util.stringToFile(file, "*" * size) } - val df = sqlContext.read + val df = spark.read .format(classOf[TestFileFormat].getName) .load(tempDir.getCanonicalPath) @@ -407,9 +507,10 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi val bucketed = df.queryExecution.analyzed transform { case l @ LogicalRelation(r: HadoopFsRelation, _, _) => l.copy(relation = - r.copy(bucketSpec = Some(BucketSpec(numBuckets = buckets, "c1" :: Nil, Nil)))) + r.copy(bucketSpec = + Some(BucketSpec(numBuckets = buckets, "c1" :: Nil, Nil)))(r.sparkSession)) } - Dataset.ofRows(sqlContext.sparkSession, bucketed) + Dataset.ofRows(spark, bucketed) } else { df } @@ -434,7 +535,7 @@ object LastArguments { } /** A test [[FileFormat]] that records the arguments passed to buildReader, and returns nothing. 
*/ -class TestFileFormat extends FileFormat { +class TestFileFormat extends TextBasedFileFormat { override def toString: String = "TestFileFormat" @@ -490,7 +591,19 @@ class LocalityTestFileSystem extends RawLocalFileSystem { override def getFileBlockLocations( file: FileStatus, start: Long, len: Long): Array[BlockLocation] = { + require(!file.isDirectory, "The file path can not be a directory.") val count = invocations.getAndAdd(1) Array(new BlockLocation(Array(s"host$count:50010"), Array(s"host$count"), 0, len)) } } + +// This file system is for SPARK-14959 (DistributedFileSystem will throw an exception +// if we call getFileBlockLocations on a dir). +class MockDistributedFileSystem extends RawLocalFileSystem { + + override def getFileBlockLocations( + file: FileStatus, start: Long, len: Long): Array[BlockLocation] = { + require(!file.isDirectory, "The file path can not be a directory.") + super.getFileBlockLocations(file, start, len) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala index 297731c70c151..3c68dc8bb98d8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala @@ -27,7 +27,7 @@ class HadoopFsRelationSuite extends QueryTest with SharedSQLContext { test("sizeInBytes should be the total size of all files") { withTempDir{ dir => dir.delete() - sqlContext.range(1000).write.parquet(dir.toString) + spark.range(1000).write.parquet(dir.toString) // ignore hidden files val allFiles = dir.listFiles(new FilenameFilter { override def accept(dir: File, name: String): Boolean = { @@ -35,8 +35,19 @@ class HadoopFsRelationSuite extends QueryTest with SharedSQLContext { } }) val totalSize = allFiles.map(_.length()).sum - val df = sqlContext.read.parquet(dir.toString) + val df = spark.read.parquet(dir.toString) assert(df.queryExecution.logical.statistics.sizeInBytes === BigInt(totalSize)) } } + + test("file filtering") { + assert(!HadoopFsRelation.shouldFilterOut("abcd")) + assert(HadoopFsRelation.shouldFilterOut(".ab")) + assert(HadoopFsRelation.shouldFilterOut("_cd")) + + assert(!HadoopFsRelation.shouldFilterOut("_metadata")) + assert(!HadoopFsRelation.shouldFilterOut("_common_metadata")) + assert(HadoopFsRelation.shouldFilterOut("_ab_metadata")) + assert(HadoopFsRelation.shouldFilterOut("_cd_common_metadata")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala index daf85be56f3d2..5e00f669b8593 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.execution.datasources.csv -import java.text.SimpleDateFormat - import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types._ @@ -35,6 +33,11 @@ class CSVInferSchemaSuite extends SparkFunSuite { assert(CSVInferSchema.inferField(NullType, "2015-08-20 15:57:00", options) == TimestampType) assert(CSVInferSchema.inferField(NullType, "True", options) == BooleanType) assert(CSVInferSchema.inferField(NullType, "FAlSE", options) == BooleanType) + + val textValueOne = 
Long.MaxValue.toString + "0" + val decimalValueOne = new java.math.BigDecimal(textValueOne) + val expectedTypeOne = DecimalType(decimalValueOne.precision, decimalValueOne.scale) + assert(CSVInferSchema.inferField(NullType, textValueOne, options) == expectedTypeOne) } test("String fields types are inferred correctly from other types") { @@ -49,12 +52,17 @@ class CSVInferSchemaSuite extends SparkFunSuite { assert(CSVInferSchema.inferField(LongType, "True", options) == BooleanType) assert(CSVInferSchema.inferField(IntegerType, "FALSE", options) == BooleanType) assert(CSVInferSchema.inferField(TimestampType, "FALSE", options) == BooleanType) + + val textValueOne = Long.MaxValue.toString + "0" + val decimalValueOne = new java.math.BigDecimal(textValueOne) + val expectedTypeOne = DecimalType(decimalValueOne.precision, decimalValueOne.scale) + assert(CSVInferSchema.inferField(IntegerType, textValueOne, options) == expectedTypeOne) } test("Timestamp field types are inferred correctly via custom data format") { - var options = new CSVOptions(Map("dateFormat" -> "yyyy-mm")) + var options = new CSVOptions(Map("timestampFormat" -> "yyyy-mm")) assert(CSVInferSchema.inferField(TimestampType, "2015-08", options) == TimestampType) - options = new CSVOptions(Map("dateFormat" -> "yyyy")) + options = new CSVOptions(Map("timestampFormat" -> "yyyy")) assert(CSVInferSchema.inferField(TimestampType, "2015", options) == TimestampType) } @@ -94,6 +102,7 @@ class CSVInferSchemaSuite extends SparkFunSuite { assert(CSVInferSchema.inferField(DoubleType, "\\N", options) == DoubleType) assert(CSVInferSchema.inferField(TimestampType, "\\N", options) == TimestampType) assert(CSVInferSchema.inferField(BooleanType, "\\N", options) == BooleanType) + assert(CSVInferSchema.inferField(DecimalType(1, 1), "\\N", options) == DecimalType(1, 1)) } test("Merging Nulltypes should yield Nulltype.") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 07f00a086865c..29aac9def6924 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -22,33 +22,35 @@ import java.nio.charset.UnsupportedCharsetException import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat -import scala.collection.JavaConverters._ - -import org.apache.hadoop.conf.Configuration +import org.apache.commons.lang3.time.FastDateFormat import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.GzipCodec import org.apache.spark.SparkException -import org.apache.spark.sql.{DataFrame, QueryTest, Row} +import org.apache.spark.sql.{DataFrame, QueryTest, Row, UDT} import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.sql.types._ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { - private val carsFile = "cars.csv" - private val carsMalformedFile = "cars-malformed.csv" - private val carsFile8859 = "cars_iso-8859-1.csv" - private val carsTsvFile = "cars.tsv" - private val carsAltFile = "cars-alternative.csv" - private val carsUnbalancedQuotesFile = "cars-unbalanced-quotes.csv" - private val carsNullFile = "cars-null.csv" - private val emptyFile = "empty.csv" - private val commentsFile = "comments.csv" - private val disableCommentsFile = "disable_comments.csv" - private val boolFile = "bool.csv" - private 
val simpleSparseFile = "simple_sparse.csv" - private val numbersFile = "numbers.csv" - private val datesFile = "dates.csv" - private val unescapedQuotesFile = "unescaped-quotes.csv" + import testImplicits._ + + private val carsFile = "test-data/cars.csv" + private val carsMalformedFile = "test-data/cars-malformed.csv" + private val carsFile8859 = "test-data/cars_iso-8859-1.csv" + private val carsTsvFile = "test-data/cars.tsv" + private val carsAltFile = "test-data/cars-alternative.csv" + private val carsUnbalancedQuotesFile = "test-data/cars-unbalanced-quotes.csv" + private val carsNullFile = "test-data/cars-null.csv" + private val carsBlankColName = "test-data/cars-blank-column-name.csv" + private val emptyFile = "test-data/empty.csv" + private val commentsFile = "test-data/comments.csv" + private val disableCommentsFile = "test-data/disable_comments.csv" + private val boolFile = "test-data/bool.csv" + private val decimalFile = "test-data/decimal.csv" + private val simpleSparseFile = "test-data/simple_sparse.csv" + private val numbersFile = "test-data/numbers.csv" + private val datesFile = "test-data/dates.csv" + private val unescapedQuotesFile = "test-data/unescaped-quotes.csv" private def testFile(fileName: String): String = { Thread.currentThread().getContextClassLoader.getResource(fileName).toString @@ -74,14 +76,14 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { if (withHeader) { assert(df.schema.fieldNames === Array("year", "make", "model", "comment", "blank")) } else { - assert(df.schema.fieldNames === Array("C0", "C1", "C2", "C3", "C4")) + assert(df.schema.fieldNames === Array("_c0", "_c1", "_c2", "_c3", "_c4")) } } if (checkValues) { val yearValues = List("2012", "1997", "2015") val actualYears = if (!withHeader) "year" :: yearValues else yearValues - val years = if (withHeader) df.select("year").collect() else df.select("C0").collect() + val years = if (withHeader) df.select("year").collect() else df.select("_c0").collect() years.zipWithIndex.foreach { case (year, index) => if (checkTypes) { @@ -94,7 +96,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("simple csv test") { - val cars = sqlContext + val cars = spark .read .format("csv") .option("header", "false") @@ -104,7 +106,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("simple csv test with calling another function to load") { - val cars = sqlContext + val cars = spark .read .option("header", "false") .csv(testFile(carsFile)) @@ -113,7 +115,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("simple csv test with type inference") { - val cars = sqlContext + val cars = spark .read .format("csv") .option("header", "true") @@ -124,7 +126,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("test inferring booleans") { - val result = sqlContext.read + val result = spark.read .format("csv") .option("header", "true") .option("inferSchema", "true") @@ -135,8 +137,22 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { assert(result.schema === expectedSchema) } + test("test inferring decimals") { + val result = spark.read + .format("csv") + .option("comment", "~") + .option("header", "true") + .option("inferSchema", "true") + .load(testFile(decimalFile)) + val expectedSchema = StructType(List( + StructField("decimal", DecimalType(20, 0), nullable = true), + StructField("long", LongType, nullable = true), + StructField("double", 
DoubleType, nullable = true))) + assert(result.schema === expectedSchema) + } + test("test with alternative delimiter and quote") { - val cars = sqlContext.read + val cars = spark.read .format("csv") .options(Map("quote" -> "\'", "delimiter" -> "|", "header" -> "true")) .load(testFile(carsAltFile)) @@ -145,7 +161,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("parse unescaped quotes with maxCharsPerColumn") { - val rows = sqlContext.read + val rows = spark.read .format("csv") .option("maxCharsPerColumn", "4") .load(testFile(unescapedQuotesFile)) @@ -157,7 +173,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { test("bad encoding name") { val exception = intercept[UnsupportedCharsetException] { - sqlContext + spark .read .format("csv") .option("charset", "1-9588-osi") @@ -169,7 +185,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { test("test different encoding") { // scalastyle:off - sqlContext.sql( + spark.sql( s""" |CREATE TEMPORARY TABLE carsTable USING csv |OPTIONS (path "${testFile(carsFile8859)}", header "true", @@ -177,12 +193,12 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { """.stripMargin.replaceAll("\n", " ")) // scalastyle:on - verifyCars(sqlContext.table("carsTable"), withHeader = true) + verifyCars(spark.table("carsTable"), withHeader = true) } test("test aliases sep and encoding for delimiter and charset") { // scalastyle:off - val cars = sqlContext + val cars = spark .read .format("csv") .option("header", "true") @@ -195,17 +211,17 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("DDL test with tab separated file") { - sqlContext.sql( + spark.sql( s""" |CREATE TEMPORARY TABLE carsTable USING csv |OPTIONS (path "${testFile(carsTsvFile)}", header "true", delimiter "\t") """.stripMargin.replaceAll("\n", " ")) - verifyCars(sqlContext.table("carsTable"), numFields = 6, withHeader = true, checkHeader = false) + verifyCars(spark.table("carsTable"), numFields = 6, withHeader = true, checkHeader = false) } test("DDL test parsing decimal type") { - sqlContext.sql( + spark.sql( s""" |CREATE TEMPORARY TABLE carsTable |(yearMade double, makeName string, modelName string, priceTag decimal, @@ -215,11 +231,11 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { """.stripMargin.replaceAll("\n", " ")) assert( - sqlContext.sql("SELECT makeName FROM carsTable where priceTag > 60000").collect().size === 1) + spark.sql("SELECT makeName FROM carsTable where priceTag > 60000").collect().size === 1) } test("test for DROPMALFORMED parsing mode") { - val cars = sqlContext.read + val cars = spark.read .format("csv") .options(Map("header" -> "true", "mode" -> "dropmalformed")) .load(testFile(carsFile)) @@ -227,9 +243,20 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { assert(cars.select("year").collect().size === 2) } + test("test for blank column names on read and select columns") { + val cars = spark.read + .format("csv") + .options(Map("header" -> "true", "inferSchema" -> "true")) + .load(testFile(carsBlankColName)) + + assert(cars.select("customer").collect().size == 2) + assert(cars.select("_c0").collect().size == 2) + assert(cars.select("_c1").collect().size == 2) + } + test("test for FAILFAST parsing mode") { val exception = intercept[SparkException]{ - sqlContext.read + spark.read .format("csv") .options(Map("header" -> "true", "mode" -> "failfast")) 
.load(testFile(carsFile)).collect() @@ -239,7 +266,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("test for tokens more than the fields in the schema") { - val cars = sqlContext + val cars = spark .read .format("csv") .option("header", "false") @@ -250,7 +277,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("test with null quote character") { - val cars = sqlContext.read + val cars = spark.read .format("csv") .option("header", "true") .option("quote", "") @@ -261,7 +288,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("test with empty file and known schema") { - val result = sqlContext.read + val result = spark.read .format("csv") .schema(StructType(List(StructField("column", StringType, false)))) .load(testFile(emptyFile)) @@ -271,25 +298,25 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("DDL test with empty file") { - sqlContext.sql(s""" + spark.sql(s""" |CREATE TEMPORARY TABLE carsTable |(yearMade double, makeName string, modelName string, comments string, grp string) |USING csv |OPTIONS (path "${testFile(emptyFile)}", header "false") """.stripMargin.replaceAll("\n", " ")) - assert(sqlContext.sql("SELECT count(*) FROM carsTable").collect().head(0) === 0) + assert(spark.sql("SELECT count(*) FROM carsTable").collect().head(0) === 0) } test("DDL test with schema") { - sqlContext.sql(s""" + spark.sql(s""" |CREATE TEMPORARY TABLE carsTable |(yearMade double, makeName string, modelName string, comments string, blank string) |USING csv |OPTIONS (path "${testFile(carsFile)}", header "true") """.stripMargin.replaceAll("\n", " ")) - val cars = sqlContext.table("carsTable") + val cars = spark.table("carsTable") verifyCars(cars, withHeader = true, checkHeader = false, checkValues = false) assert( cars.schema.fieldNames === Array("yearMade", "makeName", "modelName", "comments", "blank")) @@ -298,7 +325,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { test("save csv") { withTempDir { dir => val csvDir = new File(dir, "csv").getCanonicalPath - val cars = sqlContext.read + val cars = spark.read .format("csv") .option("header", "true") .load(testFile(carsFile)) @@ -307,7 +334,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .option("header", "true") .csv(csvDir) - val carsCopy = sqlContext.read + val carsCopy = spark.read .format("csv") .option("header", "true") .load(csvDir) @@ -319,7 +346,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { test("save csv with quote") { withTempDir { dir => val csvDir = new File(dir, "csv").getCanonicalPath - val cars = sqlContext.read + val cars = spark.read .format("csv") .option("header", "true") .load(testFile(carsFile)) @@ -330,7 +357,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .option("quote", "\"") .save(csvDir) - val carsCopy = sqlContext.read + val carsCopy = spark.read .format("csv") .option("header", "true") .option("quote", "\"") @@ -340,8 +367,85 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } } + test("save csv with quoteAll enabled") { + withTempDir { dir => + val csvDir = new File(dir, "csv").getCanonicalPath + + val data = Seq(("test \"quote\"", 123, "it \"works\"!", "\"very\" well")) + val df = spark.createDataFrame(data) + + // escapeQuotes should be true by default + df.coalesce(1).write + .format("csv") + .option("quote", "\"") + 
.option("escape", "\"") + .option("quoteAll", "true") + .save(csvDir) + + val results = spark.read + .format("text") + .load(csvDir) + .collect() + + val expected = "\"test \"\"quote\"\"\",\"123\",\"it \"\"works\"\"!\",\"\"\"very\"\" well\"" + + assert(results.toSeq.map(_.toSeq) === Seq(Seq(expected))) + } + } + + test("save csv with quote escaping enabled") { + withTempDir { dir => + val csvDir = new File(dir, "csv").getCanonicalPath + + val data = Seq(("test \"quote\"", 123, "it \"works\"!", "\"very\" well")) + val df = spark.createDataFrame(data) + + // escapeQuotes should be true by default + df.coalesce(1).write + .format("csv") + .option("quote", "\"") + .option("escape", "\"") + .save(csvDir) + + val results = spark.read + .format("text") + .load(csvDir) + .collect() + + val expected = "\"test \"\"quote\"\"\",123,\"it \"\"works\"\"!\",\"\"\"very\"\" well\"" + + assert(results.toSeq.map(_.toSeq) === Seq(Seq(expected))) + } + } + + test("save csv with quote escaping disabled") { + withTempDir { dir => + val csvDir = new File(dir, "csv").getCanonicalPath + + val data = Seq(("test \"quote\"", 123, "it \"works\"!", "\"very\" well")) + val df = spark.createDataFrame(data) + + // escapeQuotes should be true by default + df.coalesce(1).write + .format("csv") + .option("quote", "\"") + .option("escapeQuotes", "false") + .option("escape", "\"") + .save(csvDir) + + val results = spark.read + .format("text") + .load(csvDir) + .collect() + + val expected = "test \"quote\",123,it \"works\"!,\"\"\"very\"\" well\"" + + assert(results.toSeq.map(_.toSeq) === Seq(Seq(expected))) + } + } + test("commented lines in CSV data") { - val results = sqlContext.read + val results = spark.read .format("csv") .options(Map("comment" -> "~", "header" -> "false")) .load(testFile(commentsFile)) @@ -356,7 +460,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("inferring schema with commented lines in CSV data") { - val results = sqlContext.read + val results = spark.read .format("csv") .options(Map("comment" -> "~", "header" -> "false", "inferSchema" -> "true")) .load(testFile(commentsFile)) @@ -374,15 +478,15 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { val options = Map( "header" -> "true", "inferSchema" -> "true", - "dateFormat" -> "dd/MM/yyyy hh:mm") - val results = sqlContext.read + "timestampFormat" -> "dd/MM/yyyy HH:mm") + val results = spark.read .format("csv") .options(options) .load(testFile(datesFile)) .select("date") .collect() - val dateFormat = new SimpleDateFormat("dd/MM/yyyy hh:mm") + val dateFormat = new SimpleDateFormat("dd/MM/yyyy HH:mm") val expected = Seq(Seq(new Timestamp(dateFormat.parse("26/08/2015 18:00").getTime)), Seq(new Timestamp(dateFormat.parse("27/10/2014 18:30").getTime)), @@ -396,7 +500,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { "header" -> "true", "inferSchema" -> "false", "dateFormat" -> "dd/MM/yyyy hh:mm") - val results = sqlContext.read + val results = spark.read .format("csv") .options(options) .schema(customSchema) @@ -419,7 +523,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("setting comment to null disables comment support") { - val results = sqlContext.read + val results = spark.read .format("csv") .options(Map("comment" -> "", "header" -> "false")) .load(testFile(disableCommentsFile)) @@ -442,7 +546,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { StructField("model", StringType, nullable = false), 
StructField("comment", StringType, nullable = true), StructField("blank", StringType, nullable = true))) - val cars = sqlContext.read + val cars = spark.read .format("csv") .schema(dataSchema) .options(Map("header" -> "true", "nullValue" -> "null")) @@ -450,14 +554,14 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { verifyCars(cars, withHeader = true, checkValues = false) val results = cars.collect() - assert(results(0).toSeq === Array(2012, "Tesla", "S", "null", "null")) + assert(results(0).toSeq === Array(2012, "Tesla", "S", null, null)) assert(results(2).toSeq === Array(null, "Chevy", "Volt", null, null)) } test("save csv with compression codec option") { withTempDir { dir => val csvDir = new File(dir, "csv").getCanonicalPath - val cars = sqlContext.read + val cars = spark.read .format("csv") .option("header", "true") .load(testFile(carsFile)) @@ -471,7 +575,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { val compressedFiles = new File(csvDir).listFiles() assert(compressedFiles.exists(_.getName.endsWith(".csv.gz"))) - val carsCopy = sqlContext.read + val carsCopy = spark.read .format("csv") .option("header", "true") .load(csvDir) @@ -489,7 +593,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { ) withTempDir { dir => val csvDir = new File(dir, "csv").getCanonicalPath - val cars = sqlContext.read + val cars = spark.read .format("csv") .option("header", "true") .options(extraOptions) @@ -505,7 +609,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { val compressedFiles = new File(csvDir).listFiles() assert(compressedFiles.exists(!_.getName.endsWith(".csv.gz"))) - val carsCopy = sqlContext.read + val carsCopy = spark.read .format("csv") .option("header", "true") .options(extraOptions) @@ -516,7 +620,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("Schema inference correctly identifies the datatype when data is sparse.") { - val df = sqlContext.read + val df = spark.read .format("csv") .option("header", "true") .option("inferSchema", "true") @@ -528,7 +632,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("old csv data source name works") { - val cars = sqlContext + val cars = spark .read .format("com.databricks.spark.csv") .option("header", "false") @@ -538,7 +642,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("nulls, NaNs and Infinity values can be parsed") { - val numbers = sqlContext + val numbers = spark .read .format("csv") .schema(StructType(List( @@ -558,4 +662,198 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { assert(numbers.count() == 8) } + + test("error handling for unsupported data types.") { + withTempDir { dir => + val csvDir = new File(dir, "csv").getCanonicalPath + var msg = intercept[UnsupportedOperationException] { + Seq((1, "Tesla")).toDF("a", "b").selectExpr("struct(a, b)").write.csv(csvDir) + }.getMessage + assert(msg.contains("CSV data source does not support struct data type")) + + msg = intercept[UnsupportedOperationException] { + Seq((1, Map("Tesla" -> 3))).toDF("id", "cars").write.csv(csvDir) + }.getMessage + assert(msg.contains("CSV data source does not support map data type")) + + msg = intercept[UnsupportedOperationException] { + Seq((1, Array("Tesla", "Chevy", "Ford"))).toDF("id", "brands").write.csv(csvDir) + }.getMessage + assert(msg.contains("CSV data source does not support array data type")) + + 
msg = intercept[UnsupportedOperationException] { + Seq((1, new UDT.MyDenseVector(Array(0.25, 2.25, 4.25)))).toDF("id", "vectors") + .write.csv(csvDir) + }.getMessage + assert(msg.contains("CSV data source does not support array data type")) + + msg = intercept[SparkException] { + val schema = StructType(StructField("a", new UDT.MyDenseVectorUDT(), true) :: Nil) + spark.range(1).write.csv(csvDir) + spark.read.schema(schema).csv(csvDir).collect() + }.getCause.getMessage + assert(msg.contains("Unsupported type: array")) + } + } + + test("SPARK-15585 turn off quotations") { + val cars = spark.read + .format("csv") + .option("header", "true") + .option("quote", "") + .load(testFile(carsUnbalancedQuotesFile)) + + verifyCars(cars, withHeader = true, checkValues = false) + } + + test("Write timestamps correctly in ISO8601 format by default") { + withTempDir { dir => + val iso8601timestampsPath = s"${dir.getCanonicalPath}/iso8601timestamps.csv" + val timestamps = spark.read + .format("csv") + .option("inferSchema", "true") + .option("header", "true") + .option("timestampFormat", "dd/MM/yyyy HH:mm") + .load(testFile(datesFile)) + timestamps.write + .format("csv") + .option("header", "true") + .save(iso8601timestampsPath) + + // This will load back the timestamps as string. + val iso8601Timestamps = spark.read + .format("csv") + .option("header", "true") + .option("inferSchema", "false") + .load(iso8601timestampsPath) + + val iso8501 = FastDateFormat.getInstance("yyyy-MM-dd'T'HH:mm:ss.SSSZZ") + val expectedTimestamps = timestamps.collect().map { r => + // This should be ISO8601 formatted string. + Row(iso8501.format(r.toSeq.head)) + } + + checkAnswer(iso8601Timestamps, expectedTimestamps) + } + } + + test("Write dates correctly in ISO8601 format by default") { + withTempDir { dir => + val customSchema = new StructType(Array(StructField("date", DateType, true))) + val iso8601datesPath = s"${dir.getCanonicalPath}/iso8601dates.csv" + val dates = spark.read + .format("csv") + .schema(customSchema) + .option("header", "true") + .option("inferSchema", "false") + .option("dateFormat", "dd/MM/yyyy HH:mm") + .load(testFile(datesFile)) + dates.write + .format("csv") + .option("header", "true") + .save(iso8601datesPath) + + // This will load back the dates as string. + val iso8601dates = spark.read + .format("csv") + .option("header", "true") + .option("inferSchema", "false") + .load(iso8601datesPath) + + val iso8501 = FastDateFormat.getInstance("yyyy-MM-dd") + val expectedDates = dates.collect().map { r => + // This should be ISO8601 formatted string. + Row(iso8501.format(r.toSeq.head)) + } + + checkAnswer(iso8601dates, expectedDates) + } + } + + test("Roundtrip in reading and writing timestamps") { + withTempDir { dir => + val iso8601timestampsPath = s"${dir.getCanonicalPath}/iso8601timestamps.csv" + val timestamps = spark.read + .format("csv") + .option("header", "true") + .option("inferSchema", "true") + .load(testFile(datesFile)) + + timestamps.write + .format("csv") + .option("header", "true") + .save(iso8601timestampsPath) + + val iso8601timestamps = spark.read + .format("csv") + .option("header", "true") + .option("inferSchema", "true") + .load(iso8601timestampsPath) + + checkAnswer(iso8601timestamps, timestamps) + } + } + + test("Write dates correctly with dateFormat option") { + val customSchema = new StructType(Array(StructField("date", DateType, true))) + withTempDir { dir => + // With dateFormat option. 
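      // (Editor's aside, illustrative and not part of the patch: the round trip below amounts to
      //  reading with one date pattern and writing with another, e.g., with hypothetical paths in/out:
      //    spark.read.schema(customSchema).option("dateFormat", "dd/MM/yyyy HH:mm").csv(in)
      //      .write.option("dateFormat", "yyyy/MM/dd").csv(out)
      //  which the test then verifies by reading the output back as plain strings.)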
+ val datesWithFormatPath = s"${dir.getCanonicalPath}/datesWithFormat.csv" + val datesWithFormat = spark.read + .format("csv") + .schema(customSchema) + .option("header", "true") + .option("dateFormat", "dd/MM/yyyy HH:mm") + .load(testFile(datesFile)) + datesWithFormat.write + .format("csv") + .option("header", "true") + .option("dateFormat", "yyyy/MM/dd") + .save(datesWithFormatPath) + + // This will load back the dates as string. + val stringDatesWithFormat = spark.read + .format("csv") + .option("header", "true") + .option("inferSchema", "false") + .load(datesWithFormatPath) + val expectedStringDatesWithFormat = Seq( + Row("2015/08/26"), + Row("2014/10/27"), + Row("2016/01/28")) + + checkAnswer(stringDatesWithFormat, expectedStringDatesWithFormat) + } + } + + test("Write timestamps correctly with dateFormat option") { + withTempDir { dir => + // With dateFormat option. + val timestampsWithFormatPath = s"${dir.getCanonicalPath}/timestampsWithFormat.csv" + val timestampsWithFormat = spark.read + .format("csv") + .option("header", "true") + .option("inferSchema", "true") + .option("timestampFormat", "dd/MM/yyyy HH:mm") + .load(testFile(datesFile)) + timestampsWithFormat.write + .format("csv") + .option("header", "true") + .option("timestampFormat", "yyyy/MM/dd HH:mm") + .save(timestampsWithFormatPath) + + // This will load back the timestamps as string. + val stringTimestampsWithFormat = spark.read + .format("csv") + .option("header", "true") + .option("inferSchema", "false") + .load(timestampsWithFormatPath) + val expectedStringTimestampsWithFormat = Seq( + Row("2015/08/26 18:00"), + Row("2014/10/27 18:30"), + Row("2016/01/28 20:00")) + + checkAnswer(stringTimestampsWithFormat, expectedStringTimestampsWithFormat) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala index 26b33b24efc3d..dae92f626c225 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala @@ -68,16 +68,46 @@ class CSVTypeCastSuite extends SparkFunSuite { } test("Nullable types are handled") { - assert(CSVTypeCast.castTo("", IntegerType, nullable = true, CSVOptions()) == null) + assertNull( + CSVTypeCast.castTo("-", ByteType, nullable = true, CSVOptions("nullValue", "-"))) + assertNull( + CSVTypeCast.castTo("-", ShortType, nullable = true, CSVOptions("nullValue", "-"))) + assertNull( + CSVTypeCast.castTo("-", IntegerType, nullable = true, CSVOptions("nullValue", "-"))) + assertNull( + CSVTypeCast.castTo("-", LongType, nullable = true, CSVOptions("nullValue", "-"))) + assertNull( + CSVTypeCast.castTo("-", FloatType, nullable = true, CSVOptions("nullValue", "-"))) + assertNull( + CSVTypeCast.castTo("-", DoubleType, nullable = true, CSVOptions("nullValue", "-"))) + assertNull( + CSVTypeCast.castTo("-", BooleanType, nullable = true, CSVOptions("nullValue", "-"))) + assertNull( + CSVTypeCast.castTo("-", DecimalType.DoubleDecimal, true, CSVOptions("nullValue", "-"))) + assertNull( + CSVTypeCast.castTo("-", TimestampType, nullable = true, CSVOptions("nullValue", "-"))) + assertNull( + CSVTypeCast.castTo("-", DateType, nullable = true, CSVOptions("nullValue", "-"))) + assertNull( + CSVTypeCast.castTo("-", StringType, nullable = true, CSVOptions("nullValue", "-"))) } - test("String type should always return the same as the 
input") { + test("String type should also respect `nullValue`") { + assertNull( + CSVTypeCast.castTo("", StringType, nullable = true, CSVOptions())) + assert( + CSVTypeCast.castTo("", StringType, nullable = false, CSVOptions()) == + UTF8String.fromString("")) + assert( - CSVTypeCast.castTo("", StringType, nullable = true, CSVOptions()) == + CSVTypeCast.castTo("", StringType, nullable = true, CSVOptions("nullValue", "null")) == UTF8String.fromString("")) assert( - CSVTypeCast.castTo("", StringType, nullable = false, CSVOptions()) == + CSVTypeCast.castTo("", StringType, nullable = false, CSVOptions("nullValue", "null")) == UTF8String.fromString("")) + + assertNull( + CSVTypeCast.castTo(null, StringType, nullable = true, CSVOptions("nullValue", "null"))) } test("Throws exception for empty string with non null type") { @@ -96,13 +126,18 @@ class CSVTypeCastSuite extends SparkFunSuite { assert(CSVTypeCast.castTo("1.00", DoubleType) == 1.0) assert(CSVTypeCast.castTo("true", BooleanType) == true) - val options = CSVOptions("dateFormat", "dd/MM/yyyy hh:mm") + val timestampsOptions = CSVOptions("timestampFormat", "dd/MM/yyyy hh:mm") val customTimestamp = "31/01/2015 00:00" - val expectedTime = options.dateFormat.parse("31/01/2015 00:00").getTime - assert(CSVTypeCast.castTo(customTimestamp, TimestampType, nullable = true, options) == - expectedTime * 1000L) - assert(CSVTypeCast.castTo(customTimestamp, DateType, nullable = true, options) == - DateTimeUtils.millisToDays(expectedTime)) + val expectedTime = timestampsOptions.timestampFormat.parse(customTimestamp).getTime + val castedTimestamp = + CSVTypeCast.castTo(customTimestamp, TimestampType, nullable = true, timestampsOptions) + assert(castedTimestamp == expectedTime * 1000L) + + val customDate = "31/01/2015" + val dateOptions = CSVOptions("dateFormat", "dd/MM/yyyy") + val expectedDate = dateOptions.dateFormat.parse(customDate).getTime + val castedDate = CSVTypeCast.castTo(customTimestamp, DateType, nullable = true, dateOptions) + assert(castedDate == DateTimeUtils.millisToDays(expectedDate)) val timestamp = "2015-01-01 00:00:00" assert(CSVTypeCast.castTo(timestamp, TimestampType) == @@ -165,20 +200,4 @@ class CSVTypeCastSuite extends SparkFunSuite { assert(doubleVal2 == Double.PositiveInfinity) } - test("Type-specific null values are used for casting") { - assertNull( - CSVTypeCast.castTo("-", ByteType, nullable = true, CSVOptions("nullValue", "-"))) - assertNull( - CSVTypeCast.castTo("-", ShortType, nullable = true, CSVOptions("nullValue", "-"))) - assertNull( - CSVTypeCast.castTo("-", IntegerType, nullable = true, CSVOptions("nullValue", "-"))) - assertNull( - CSVTypeCast.castTo("-", LongType, nullable = true, CSVOptions("nullValue", "-"))) - assertNull( - CSVTypeCast.castTo("-", FloatType, nullable = true, CSVOptions("nullValue", "-"))) - assertNull( - CSVTypeCast.castTo("-", DoubleType, nullable = true, CSVOptions("nullValue", "-"))) - assertNull( - CSVTypeCast.castTo("-", DecimalType.DoubleDecimal, true, CSVOptions("nullValue", "-"))) - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala index 1742df31bba9a..c31dffedbdf67 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala @@ -27,16 +27,16 @@ class 
JsonParsingOptionsSuite extends QueryTest with SharedSQLContext { test("allowComments off") { val str = """{'name': /* hello */ 'Reynold Xin'}""" - val rdd = sqlContext.sparkContext.parallelize(Seq(str)) - val df = sqlContext.read.json(rdd) + val rdd = spark.sparkContext.parallelize(Seq(str)) + val df = spark.read.json(rdd) assert(df.schema.head.name == "_corrupt_record") } test("allowComments on") { val str = """{'name': /* hello */ 'Reynold Xin'}""" - val rdd = sqlContext.sparkContext.parallelize(Seq(str)) - val df = sqlContext.read.option("allowComments", "true").json(rdd) + val rdd = spark.sparkContext.parallelize(Seq(str)) + val df = spark.read.option("allowComments", "true").json(rdd) assert(df.schema.head.name == "name") assert(df.first().getString(0) == "Reynold Xin") @@ -44,16 +44,16 @@ class JsonParsingOptionsSuite extends QueryTest with SharedSQLContext { test("allowSingleQuotes off") { val str = """{'name': 'Reynold Xin'}""" - val rdd = sqlContext.sparkContext.parallelize(Seq(str)) - val df = sqlContext.read.option("allowSingleQuotes", "false").json(rdd) + val rdd = spark.sparkContext.parallelize(Seq(str)) + val df = spark.read.option("allowSingleQuotes", "false").json(rdd) assert(df.schema.head.name == "_corrupt_record") } test("allowSingleQuotes on") { val str = """{'name': 'Reynold Xin'}""" - val rdd = sqlContext.sparkContext.parallelize(Seq(str)) - val df = sqlContext.read.json(rdd) + val rdd = spark.sparkContext.parallelize(Seq(str)) + val df = spark.read.json(rdd) assert(df.schema.head.name == "name") assert(df.first().getString(0) == "Reynold Xin") @@ -61,16 +61,16 @@ class JsonParsingOptionsSuite extends QueryTest with SharedSQLContext { test("allowUnquotedFieldNames off") { val str = """{name: 'Reynold Xin'}""" - val rdd = sqlContext.sparkContext.parallelize(Seq(str)) - val df = sqlContext.read.json(rdd) + val rdd = spark.sparkContext.parallelize(Seq(str)) + val df = spark.read.json(rdd) assert(df.schema.head.name == "_corrupt_record") } test("allowUnquotedFieldNames on") { val str = """{name: 'Reynold Xin'}""" - val rdd = sqlContext.sparkContext.parallelize(Seq(str)) - val df = sqlContext.read.option("allowUnquotedFieldNames", "true").json(rdd) + val rdd = spark.sparkContext.parallelize(Seq(str)) + val df = spark.read.option("allowUnquotedFieldNames", "true").json(rdd) assert(df.schema.head.name == "name") assert(df.first().getString(0) == "Reynold Xin") @@ -78,16 +78,16 @@ class JsonParsingOptionsSuite extends QueryTest with SharedSQLContext { test("allowNumericLeadingZeros off") { val str = """{"age": 0018}""" - val rdd = sqlContext.sparkContext.parallelize(Seq(str)) - val df = sqlContext.read.json(rdd) + val rdd = spark.sparkContext.parallelize(Seq(str)) + val df = spark.read.json(rdd) assert(df.schema.head.name == "_corrupt_record") } test("allowNumericLeadingZeros on") { val str = """{"age": 0018}""" - val rdd = sqlContext.sparkContext.parallelize(Seq(str)) - val df = sqlContext.read.option("allowNumericLeadingZeros", "true").json(rdd) + val rdd = spark.sparkContext.parallelize(Seq(str)) + val df = spark.read.option("allowNumericLeadingZeros", "true").json(rdd) assert(df.schema.head.name == "age") assert(df.first().getLong(0) == 18) @@ -97,16 +97,16 @@ class JsonParsingOptionsSuite extends QueryTest with SharedSQLContext { // JsonParser.Feature.ALLOW_NON_NUMERIC_NUMBERS. 
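As a side note for readers following these tests, the lenient JSON parser options they exercise can be combined on a single read. The sketch below is illustrative only, not part of this patch, and assumes an active `SparkSession` named `spark`:

```scala
// Illustrative sketch: combining the lenient JSON parser options exercised by
// the tests above on one read. Assumes an active SparkSession named `spark`.
val lenientJson = spark.sparkContext.parallelize(Seq(
  """{name: 'Reynold Xin', /* inline comment */ age: 0018}"""))

val df = spark.read
  .option("allowComments", "true")             // tolerate /* ... */ comments
  .option("allowSingleQuotes", "true")         // accept 'single quoted' strings
  .option("allowUnquotedFieldNames", "true")   // accept bare field names
  .option("allowNumericLeadingZeros", "true")  // parse 0018 as 18
  .json(lenientJson)

df.show()  // one row with name = "Reynold Xin" and age = 18
```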
ignore("allowNonNumericNumbers off") { val str = """{"age": NaN}""" - val rdd = sqlContext.sparkContext.parallelize(Seq(str)) - val df = sqlContext.read.json(rdd) + val rdd = spark.sparkContext.parallelize(Seq(str)) + val df = spark.read.json(rdd) assert(df.schema.head.name == "_corrupt_record") } ignore("allowNonNumericNumbers on") { val str = """{"age": NaN}""" - val rdd = sqlContext.sparkContext.parallelize(Seq(str)) - val df = sqlContext.read.option("allowNonNumericNumbers", "true").json(rdd) + val rdd = spark.sparkContext.parallelize(Seq(str)) + val df = spark.read.option("allowNonNumericNumbers", "true").json(rdd) assert(df.schema.head.name == "age") assert(df.first().getDouble(0).isNaN) @@ -114,16 +114,16 @@ class JsonParsingOptionsSuite extends QueryTest with SharedSQLContext { test("allowBackslashEscapingAnyCharacter off") { val str = """{"name": "Cazen Lee", "price": "\$10"}""" - val rdd = sqlContext.sparkContext.parallelize(Seq(str)) - val df = sqlContext.read.option("allowBackslashEscapingAnyCharacter", "false").json(rdd) + val rdd = spark.sparkContext.parallelize(Seq(str)) + val df = spark.read.option("allowBackslashEscapingAnyCharacter", "false").json(rdd) assert(df.schema.head.name == "_corrupt_record") } test("allowBackslashEscapingAnyCharacter on") { val str = """{"name": "Cazen Lee", "price": "\$10"}""" - val rdd = sqlContext.sparkContext.parallelize(Seq(str)) - val df = sqlContext.read.option("allowBackslashEscapingAnyCharacter", "true").json(rdd) + val rdd = spark.sparkContext.parallelize(Seq(str)) + val df = spark.read.option("allowBackslashEscapingAnyCharacter", "true").json(rdd) assert(df.schema.head.name == "name") assert(df.schema.last.name == "price") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index b1279abd63830..1ba5b81231178 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -64,9 +64,10 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { generator.flush() } + val dummyOption = new JSONOptions(Map.empty[String, String]) Utils.tryWithResource(factory.createParser(writer.toString)) { parser => parser.nextToken() - JacksonParser.convertRootField(factory, parser, dataType) + JacksonParser.convertRootField(factory, parser, dataType, dummyOption) } } @@ -99,15 +100,15 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { DateTimeUtils.fromJavaDate(Date.valueOf(strDate)), enforceCorrectType(strDate, DateType)) val ISO8601Time1 = "1970-01-01T01:00:01.0Z" - checkTypePromotion(DateTimeUtils.fromJavaTimestamp(new Timestamp(3601000)), - enforceCorrectType(ISO8601Time1, TimestampType)) - checkTypePromotion(DateTimeUtils.millisToDays(3601000), - enforceCorrectType(ISO8601Time1, DateType)) val ISO8601Time2 = "1970-01-01T02:00:01-01:00" + checkTypePromotion(DateTimeUtils.fromJavaTimestamp(new Timestamp(3601000)), + enforceCorrectType(ISO8601Time1, TimestampType)) checkTypePromotion(DateTimeUtils.fromJavaTimestamp(new Timestamp(10801000)), - enforceCorrectType(ISO8601Time2, TimestampType)) - checkTypePromotion(DateTimeUtils.millisToDays(10801000), - enforceCorrectType(ISO8601Time2, DateType)) + enforceCorrectType(ISO8601Time2, TimestampType)) + + val ISO8601Date = "1970-01-01" + checkTypePromotion(DateTimeUtils.millisToDays(32400000), + 
enforceCorrectType(ISO8601Date, DateType)) } test("Get compatible type") { @@ -229,7 +230,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Complex field and type inferring with null in sampling") { - val jsonDF = sqlContext.read.json(jsonNullStruct) + val jsonDF = spark.read.json(jsonNullStruct) val expectedSchema = StructType( StructField("headers", StructType( StructField("Charset", StringType, true) :: @@ -239,7 +240,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { StructField("nullstr", StringType, true):: Nil) assert(expectedSchema === jsonDF.schema) - jsonDF.registerTempTable("jsonTable") + jsonDF.createOrReplaceTempView("jsonTable") checkAnswer( sql("select nullstr, headers.Host from jsonTable"), @@ -248,7 +249,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Primitive field and type inferring") { - val jsonDF = sqlContext.read.json(primitiveFieldAndType) + val jsonDF = spark.read.json(primitiveFieldAndType) val expectedSchema = StructType( StructField("bigInteger", DecimalType(20, 0), true) :: @@ -261,7 +262,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { assert(expectedSchema === jsonDF.schema) - jsonDF.registerTempTable("jsonTable") + jsonDF.createOrReplaceTempView("jsonTable") checkAnswer( sql("select * from jsonTable"), @@ -276,7 +277,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Complex field and type inferring") { - val jsonDF = sqlContext.read.json(complexFieldAndType1) + val jsonDF = spark.read.json(complexFieldAndType1) val expectedSchema = StructType( StructField("arrayOfArray1", ArrayType(ArrayType(StringType, true), true), true) :: @@ -302,7 +303,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { assert(expectedSchema === jsonDF.schema) - jsonDF.registerTempTable("jsonTable") + jsonDF.createOrReplaceTempView("jsonTable") // Access elements of a primitive array. checkAnswer( @@ -375,8 +376,8 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("GetField operation on complex data type") { - val jsonDF = sqlContext.read.json(complexFieldAndType1) - jsonDF.registerTempTable("jsonTable") + val jsonDF = spark.read.json(complexFieldAndType1) + jsonDF.createOrReplaceTempView("jsonTable") checkAnswer( sql("select arrayOfStruct[0].field1, arrayOfStruct[0].field2 from jsonTable"), @@ -391,7 +392,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Type conflict in primitive field values") { - val jsonDF = sqlContext.read.json(primitiveFieldValueTypeConflict) + val jsonDF = spark.read.json(primitiveFieldValueTypeConflict) val expectedSchema = StructType( StructField("num_bool", StringType, true) :: @@ -403,7 +404,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { assert(expectedSchema === jsonDF.schema) - jsonDF.registerTempTable("jsonTable") + jsonDF.createOrReplaceTempView("jsonTable") checkAnswer( sql("select * from jsonTable"), @@ -446,13 +447,13 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { // Number and String conflict: resolve the type as number in this query. checkAnswer( sql("select num_str + 1.2 from jsonTable where num_str > 14"), - Row(BigDecimal("92233720368547758071.2")) + Row(92233720368547758071.2) ) // Number and String conflict: resolve the type as number in this query. 
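For context on the expected-value change from `BigDecimal` to a double in the assertion above, here is a small illustrative sketch (not part of the patch; assumes a `SparkSession` named `spark`) of how a field with mixed numeric and string values is inferred and then used in arithmetic:

```scala
// Illustrative sketch: a field that is numeric in one record and a string in
// another is inferred as StringType; "+ 1.2" on it then resolves numerically.
val conflicting = spark.sparkContext.parallelize(Seq(
  """{"num_str": 13.1}""",
  """{"num_str": "92233720368547758070"}"""))

val df = spark.read.json(conflicting)
df.printSchema()                                   // num_str: string
df.createOrReplaceTempView("conflictTable")
spark.sql("select num_str + 1.2 from conflictTable").show()
```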
checkAnswer( sql("select num_str + 1.2 from jsonTable where num_str >= 92233720368547758060"), - Row(new java.math.BigDecimal("92233720368547758071.2")) + Row(new java.math.BigDecimal("92233720368547758071.2").doubleValue) ) // String and Boolean conflict: resolve the type as string. @@ -463,8 +464,8 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } ignore("Type conflict in primitive field values (Ignored)") { - val jsonDF = sqlContext.read.json(primitiveFieldValueTypeConflict) - jsonDF.registerTempTable("jsonTable") + val jsonDF = spark.read.json(primitiveFieldValueTypeConflict) + jsonDF.createOrReplaceTempView("jsonTable") // Right now, the analyzer does not promote strings in a boolean expression. // Number and Boolean conflict: resolve the type as boolean in this query. @@ -516,7 +517,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Type conflict in complex field values") { - val jsonDF = sqlContext.read.json(complexFieldValueTypeConflict) + val jsonDF = spark.read.json(complexFieldValueTypeConflict) val expectedSchema = StructType( StructField("array", ArrayType(LongType, true), true) :: @@ -528,7 +529,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { assert(expectedSchema === jsonDF.schema) - jsonDF.registerTempTable("jsonTable") + jsonDF.createOrReplaceTempView("jsonTable") checkAnswer( sql("select * from jsonTable"), @@ -540,7 +541,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Type conflict in array elements") { - val jsonDF = sqlContext.read.json(arrayElementTypeConflict) + val jsonDF = spark.read.json(arrayElementTypeConflict) val expectedSchema = StructType( StructField("array1", ArrayType(StringType, true), true) :: @@ -550,7 +551,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { assert(expectedSchema === jsonDF.schema) - jsonDF.registerTempTable("jsonTable") + jsonDF.createOrReplaceTempView("jsonTable") checkAnswer( sql("select * from jsonTable"), @@ -568,7 +569,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Handling missing fields") { - val jsonDF = sqlContext.read.json(missingFields) + val jsonDF = spark.read.json(missingFields) val expectedSchema = StructType( StructField("a", BooleanType, true) :: @@ -580,7 +581,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { assert(expectedSchema === jsonDF.schema) - jsonDF.registerTempTable("jsonTable") + jsonDF.createOrReplaceTempView("jsonTable") } test("Loading a JSON dataset from a text file") { @@ -588,7 +589,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { dir.delete() val path = dir.getCanonicalPath primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) - val jsonDF = sqlContext.read.json(path) + val jsonDF = spark.read.json(path) val expectedSchema = StructType( StructField("bigInteger", DecimalType(20, 0), true) :: @@ -601,7 +602,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { assert(expectedSchema === jsonDF.schema) - jsonDF.registerTempTable("jsonTable") + jsonDF.createOrReplaceTempView("jsonTable") checkAnswer( sql("select * from jsonTable"), @@ -620,7 +621,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { dir.delete() val path = dir.getCanonicalPath primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) - val jsonDF = 
sqlContext.read.option("primitivesAsString", "true").json(path) + val jsonDF = spark.read.option("primitivesAsString", "true").json(path) val expectedSchema = StructType( StructField("bigInteger", StringType, true) :: @@ -633,7 +634,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { assert(expectedSchema === jsonDF.schema) - jsonDF.registerTempTable("jsonTable") + jsonDF.createOrReplaceTempView("jsonTable") checkAnswer( sql("select * from jsonTable"), @@ -648,7 +649,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Loading a JSON dataset primitivesAsString returns complex fields as strings") { - val jsonDF = sqlContext.read.option("primitivesAsString", "true").json(complexFieldAndType1) + val jsonDF = spark.read.option("primitivesAsString", "true").json(complexFieldAndType1) val expectedSchema = StructType( StructField("arrayOfArray1", ArrayType(ArrayType(StringType, true), true), true) :: @@ -674,7 +675,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { assert(expectedSchema === jsonDF.schema) - jsonDF.registerTempTable("jsonTable") + jsonDF.createOrReplaceTempView("jsonTable") // Access elements of a primitive array. checkAnswer( @@ -746,7 +747,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Loading a JSON dataset prefersDecimal returns schema with float types as BigDecimal") { - val jsonDF = sqlContext.read.option("prefersDecimal", "true").json(primitiveFieldAndType) + val jsonDF = spark.read.option("prefersDecimal", "true").json(primitiveFieldAndType) val expectedSchema = StructType( StructField("bigInteger", DecimalType(20, 0), true) :: @@ -759,7 +760,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { assert(expectedSchema === jsonDF.schema) - jsonDF.registerTempTable("jsonTable") + jsonDF.createOrReplaceTempView("jsonTable") checkAnswer( sql("select * from jsonTable"), @@ -777,7 +778,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val mixedIntegerAndDoubleRecords = sparkContext.parallelize( """{"a": 3, "b": 1.1}""" :: s"""{"a": 3.1, "b": 0.${"0" * 38}1}""" :: Nil) - val jsonDF = sqlContext.read + val jsonDF = spark.read .option("prefersDecimal", "true") .json(mixedIntegerAndDoubleRecords) @@ -796,7 +797,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Infer big integers correctly even when it does not fit in decimal") { - val jsonDF = sqlContext.read + val jsonDF = spark.read .json(bigIntegerRecords) // The value in `a` field will be a double as it does not fit in decimal. 
For `b` field, @@ -810,7 +811,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Infer floating-point values correctly even when it does not fit in decimal") { - val jsonDF = sqlContext.read + val jsonDF = spark.read .option("prefersDecimal", "true") .json(floatingValueRecords) @@ -823,7 +824,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { assert(expectedSchema === jsonDF.schema) checkAnswer(jsonDF, Row(1.0E-39D, BigDecimal(0.01))) - val mergedJsonDF = sqlContext.read + val mergedJsonDF = spark.read .option("prefersDecimal", "true") .json(floatingValueRecords ++ bigIntegerRecords) @@ -847,7 +848,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { sql( s""" - |CREATE TEMPORARY TABLE jsonTableSQL + |CREATE TEMPORARY VIEW jsonTableSQL |USING org.apache.spark.sql.json |OPTIONS ( | path '$path' @@ -881,11 +882,11 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { StructField("null", StringType, true) :: StructField("string", StringType, true) :: Nil) - val jsonDF1 = sqlContext.read.schema(schema).json(path) + val jsonDF1 = spark.read.schema(schema).json(path) assert(schema === jsonDF1.schema) - jsonDF1.registerTempTable("jsonTable1") + jsonDF1.createOrReplaceTempView("jsonTable1") checkAnswer( sql("select * from jsonTable1"), @@ -898,11 +899,11 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { "this is a simple string.") ) - val jsonDF2 = sqlContext.read.schema(schema).json(primitiveFieldAndType) + val jsonDF2 = spark.read.schema(schema).json(primitiveFieldAndType) assert(schema === jsonDF2.schema) - jsonDF2.registerTempTable("jsonTable2") + jsonDF2.createOrReplaceTempView("jsonTable2") checkAnswer( sql("select * from jsonTable2"), @@ -919,9 +920,9 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { test("Applying schemas with MapType") { val schemaWithSimpleMap = StructType( StructField("map", MapType(StringType, IntegerType, true), false) :: Nil) - val jsonWithSimpleMap = sqlContext.read.schema(schemaWithSimpleMap).json(mapType1) + val jsonWithSimpleMap = spark.read.schema(schemaWithSimpleMap).json(mapType1) - jsonWithSimpleMap.registerTempTable("jsonWithSimpleMap") + jsonWithSimpleMap.createOrReplaceTempView("jsonWithSimpleMap") checkAnswer( sql("select `map` from jsonWithSimpleMap"), @@ -947,9 +948,9 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val schemaWithComplexMap = StructType( StructField("map", MapType(StringType, innerStruct, true), false) :: Nil) - val jsonWithComplexMap = sqlContext.read.schema(schemaWithComplexMap).json(mapType2) + val jsonWithComplexMap = spark.read.schema(schemaWithComplexMap).json(mapType2) - jsonWithComplexMap.registerTempTable("jsonWithComplexMap") + jsonWithComplexMap.createOrReplaceTempView("jsonWithComplexMap") checkAnswer( sql("select `map` from jsonWithComplexMap"), @@ -973,8 +974,8 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("SPARK-2096 Correctly parse dot notations") { - val jsonDF = sqlContext.read.json(complexFieldAndType2) - jsonDF.registerTempTable("jsonTable") + val jsonDF = spark.read.json(complexFieldAndType2) + jsonDF.createOrReplaceTempView("jsonTable") checkAnswer( sql("select arrayOfStruct[0].field1, arrayOfStruct[0].field2 from jsonTable"), @@ -991,8 +992,8 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("SPARK-3390 Complex arrays") { - val jsonDF 
= sqlContext.read.json(complexFieldAndType2) - jsonDF.registerTempTable("jsonTable") + val jsonDF = spark.read.json(complexFieldAndType2) + jsonDF.createOrReplaceTempView("jsonTable") checkAnswer( sql( @@ -1014,8 +1015,8 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("SPARK-3308 Read top level JSON arrays") { - val jsonDF = sqlContext.read.json(jsonArray) - jsonDF.registerTempTable("jsonTable") + val jsonDF = spark.read.json(jsonArray) + jsonDF.createOrReplaceTempView("jsonTable") checkAnswer( sql( @@ -1035,7 +1036,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { StructField("a", StringType, true) :: Nil) // `FAILFAST` mode should throw an exception for corrupt records. val exceptionOne = intercept[SparkException] { - sqlContext.read + spark.read .option("mode", "FAILFAST") .json(corruptRecords) .collect() @@ -1043,7 +1044,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { assert(exceptionOne.getMessage.contains("Malformed line in FAILFAST mode: {")) val exceptionTwo = intercept[SparkException] { - sqlContext.read + spark.read .option("mode", "FAILFAST") .schema(schema) .json(corruptRecords) @@ -1060,7 +1061,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val schemaTwo = StructType( StructField("a", StringType, true) :: Nil) // `DROPMALFORMED` mode should skip corrupt records - val jsonDFOne = sqlContext.read + val jsonDFOne = spark.read .option("mode", "DROPMALFORMED") .json(corruptRecords) checkAnswer( @@ -1069,7 +1070,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { ) assert(jsonDFOne.schema === schemaOne) - val jsonDFTwo = sqlContext.read + val jsonDFTwo = spark.read .option("mode", "DROPMALFORMED") .schema(schemaTwo) .json(corruptRecords) @@ -1082,9 +1083,9 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { test("Corrupt records: PERMISSIVE mode") { // Test if we can query corrupt records. 
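Ahead of the PERMISSIVE-mode test below, an illustrative sketch (not part of the patch; the column name and data are hypothetical, and a `SparkSession` named `spark` is assumed) of how malformed JSON lines surface in a corrupt-record column:

```scala
// Illustrative sketch: under the default PERMISSIVE mode, records that fail to
// parse are kept and their raw text lands in the corrupt-record column.
val records = spark.sparkContext.parallelize(Seq(
  """{"a": 1}""",
  """{"a": 1, b: 2}"""))          // malformed: unquoted field name

val df = spark.read
  .option("columnNameOfCorruptRecord", "_unparsed")
  .json(records)

df.createOrReplaceTempView("corruptTable")
spark.sql("select `_unparsed` from corruptTable where `_unparsed` is not null").show()
```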
withSQLConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD.key -> "_unparsed") { - withTempTable("jsonTable") { - val jsonDF = sqlContext.read.json(corruptRecords) - jsonDF.registerTempTable("jsonTable") + withTempView("jsonTable") { + val jsonDF = spark.read.json(corruptRecords) + jsonDF.createOrReplaceTempView("jsonTable") val schema = StructType( StructField("_unparsed", StringType, true) :: StructField("a", StringType, true) :: @@ -1134,7 +1135,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("SPARK-13953 Rename the corrupt record field via option") { - val jsonDF = sqlContext.read + val jsonDF = spark.read .option("columnNameOfCorruptRecord", "_malformed") .json(corruptRecords) val schema = StructType( @@ -1155,8 +1156,8 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("SPARK-4068: nulls in arrays") { - val jsonDF = sqlContext.read.json(nullsInArrays) - jsonDF.registerTempTable("jsonTable") + val jsonDF = spark.read.json(nullsInArrays) + jsonDF.createOrReplaceTempView("jsonTable") val schema = StructType( StructField("field1", @@ -1201,8 +1202,8 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { Row(values(0).toInt, values(1), values(2).toBoolean, r.split(",").toList, v5) } - val df1 = sqlContext.createDataFrame(rowRDD1, schema1) - df1.registerTempTable("applySchema1") + val df1 = spark.createDataFrame(rowRDD1, schema1) + df1.createOrReplaceTempView("applySchema1") val df2 = df1.toDF val result = df2.toJSON.collect() // scalastyle:off @@ -1224,17 +1225,17 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4)) } - val df3 = sqlContext.createDataFrame(rowRDD2, schema2) - df3.registerTempTable("applySchema2") + val df3 = spark.createDataFrame(rowRDD2, schema2) + df3.createOrReplaceTempView("applySchema2") val df4 = df3.toDF val result2 = df4.toJSON.collect() assert(result2(1) === "{\"f1\":{\"f11\":2,\"f12\":false},\"f2\":{\"B2\":null}}") assert(result2(3) === "{\"f1\":{\"f11\":4,\"f12\":true},\"f2\":{\"D4\":2147483644}}") - val jsonDF = sqlContext.read.json(primitiveFieldAndType) - val primTable = sqlContext.read.json(jsonDF.toJSON.rdd) - primTable.registerTempTable("primitiveTable") + val jsonDF = spark.read.json(primitiveFieldAndType) + val primTable = spark.read.json(jsonDF.toJSON.rdd) + primTable.createOrReplaceTempView("primitiveTable") checkAnswer( sql("select * from primitiveTable"), Row(new java.math.BigDecimal("92233720368547758070"), @@ -1245,9 +1246,9 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { "this is a simple string.") ) - val complexJsonDF = sqlContext.read.json(complexFieldAndType1) - val compTable = sqlContext.read.json(complexJsonDF.toJSON.rdd) - compTable.registerTempTable("complexTable") + val complexJsonDF = spark.read.json(complexFieldAndType1) + val compTable = spark.read.json(complexJsonDF.toJSON.rdd) + compTable.createOrReplaceTempView("complexTable") // Access elements of a primitive array. 
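The queries in this round-trip test rely on array indexing and nested-field access; a standalone illustrative sketch of the same access patterns (assuming a `SparkSession` named `spark`, with made-up data):

```scala
// Illustrative sketch: indexing into an array column and reaching into a
// struct element with dot notation, as the queries below do.
val rows = spark.sparkContext.parallelize(Seq(
  """{"arrayOfString": ["str1", "str2"], "arrayOfStruct": [{"field1": true, "field2": "str1"}]}"""))

spark.read.json(rows).createOrReplaceTempView("complexTable")

spark.sql("select arrayOfString[0], arrayOfString[1] from complexTable").show()
spark.sql("select arrayOfStruct[0].field1, arrayOfStruct[0].field2 from complexTable").show()
```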
checkAnswer( sql("select arrayOfString[0], arrayOfString[1], arrayOfString[2] from complexTable"), @@ -1316,19 +1317,19 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { .map(i => s"""{"a": 1, "b": "str$i"}""").saveAsTextFile(path) val d1 = DataSource( - sqlContext.sparkSession, + spark, userSpecifiedSchema = None, partitionColumns = Array.empty[String], bucketSpec = None, - className = classOf[DefaultSource].getCanonicalName, + className = classOf[JsonFileFormat].getCanonicalName, options = Map("path" -> path)).resolveRelation() val d2 = DataSource( - sqlContext.sparkSession, + spark, userSpecifiedSchema = None, partitionColumns = Array.empty[String], bucketSpec = None, - className = classOf[DefaultSource].getCanonicalName, + className = classOf[JsonFileFormat].getCanonicalName, options = Map("path" -> path)).resolveRelation() assert(d1 === d2) }) @@ -1345,16 +1346,16 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { withTempDir { dir => val schemaWithSimpleMap = StructType( StructField("map", MapType(StringType, IntegerType, true), false) :: Nil) - val df = sqlContext.read.schema(schemaWithSimpleMap).json(mapType1) + val df = spark.read.schema(schemaWithSimpleMap).json(mapType1) val path = dir.getAbsolutePath df.write.mode("overwrite").parquet(path) // order of MapType is not defined - assert(sqlContext.read.parquet(path).count() == 5) + assert(spark.read.parquet(path).count() == 5) - val df2 = sqlContext.read.json(corruptRecords) + val df2 = spark.read.json(corruptRecords) df2.write.mode("overwrite").parquet(path) - checkAnswer(sqlContext.read.parquet(path), df2.collect()) + checkAnswer(spark.read.parquet(path), df2.collect()) } } } @@ -1387,7 +1388,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { "col1", "abd") - sqlContext.read.json(root.getAbsolutePath).registerTempTable("test_myjson_with_part") + spark.read.json(root.getAbsolutePath).createOrReplaceTempView("test_myjson_with_part") checkAnswer(sql( "SELECT count(a) FROM test_myjson_with_part where d1 = 1 and col1='abc'"), Row(4)) checkAnswer(sql( @@ -1447,7 +1448,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { Row(4.75.toFloat, Seq(false, true)), new UDT.MyDenseVector(Array(0.25, 2.25, 4.25))) val data = - Row.fromSeq(Seq("Spark " + sqlContext.sparkContext.version) ++ constantValues) :: Nil + Row.fromSeq(Seq("Spark " + spark.sparkContext.version) ++ constantValues) :: Nil // Data generated by previous versions. // scalastyle:off @@ -1462,7 +1463,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { // scalastyle:on // Generate data for the current version. 
- val df = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(data, 1), schema) + val df = spark.createDataFrame(spark.sparkContext.parallelize(data, 1), schema) withTempPath { path => df.write.format("json").mode("overwrite").save(path.getCanonicalPath) @@ -1486,13 +1487,13 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { "Spark 1.4.1", "Spark 1.5.0", "Spark 1.5.0", - "Spark " + sqlContext.sparkContext.version, - "Spark " + sqlContext.sparkContext.version) + "Spark " + spark.sparkContext.version, + "Spark " + spark.sparkContext.version) val expectedResult = col0Values.map { v => Row.fromSeq(Seq(v) ++ constantValues) } checkAnswer( - sqlContext.read.format("json").schema(schema).load(path.getCanonicalPath), + spark.read.format("json").schema(schema).load(path.getCanonicalPath), expectedResult ) } @@ -1502,36 +1503,36 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { withTempPath { dir => val path = dir.getCanonicalPath - val df = sqlContext.range(2) + val df = spark.range(2) df.write.json(path + "/p=1") df.write.json(path + "/p=2") - assert(sqlContext.read.json(path).count() === 4) + assert(spark.read.json(path).count() === 4) val extraOptions = Map( "mapred.input.pathFilter.class" -> classOf[TestFileFilter].getName, "mapreduce.input.pathFilter.class" -> classOf[TestFileFilter].getName ) - assert(sqlContext.read.options(extraOptions).json(path).count() === 2) + assert(spark.read.options(extraOptions).json(path).count() === 2) } } test("SPARK-12057 additional corrupt records do not throw exceptions") { // Test if we can query corrupt records. withSQLConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD.key -> "_unparsed") { - withTempTable("jsonTable") { + withTempView("jsonTable") { val schema = StructType( StructField("_unparsed", StringType, true) :: StructField("dummy", StringType, true) :: Nil) { // We need to make sure we can infer the schema. - val jsonDF = sqlContext.read.json(additionalCorruptRecords) + val jsonDF = spark.read.json(additionalCorruptRecords) assert(jsonDF.schema === schema) } { - val jsonDF = sqlContext.read.schema(schema).json(additionalCorruptRecords) - jsonDF.registerTempTable("jsonTable") + val jsonDF = spark.read.schema(schema).json(additionalCorruptRecords) + jsonDF.createOrReplaceTempView("jsonTable") // In HiveContext, backticks should be used to access columns starting with a underscore. 
checkAnswer( @@ -1563,7 +1564,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { StructField("a", StructType( StructField("b", StringType) :: Nil )) :: Nil) - val jsonDF = sqlContext.read.schema(schema).json(path) + val jsonDF = spark.read.schema(schema).json(path) assert(jsonDF.count() == 2) } } @@ -1575,7 +1576,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val path = dir.getCanonicalPath primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) - val jsonDF = sqlContext.read.json(path) + val jsonDF = spark.read.json(path) val jsonDir = new File(dir, "json").getCanonicalPath jsonDF.coalesce(1).write .format("json") @@ -1585,7 +1586,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val compressedFiles = new File(jsonDir).listFiles() assert(compressedFiles.exists(_.getName.endsWith(".json.gz"))) - val jsonCopy = sqlContext.read + val jsonCopy = spark.read .format("json") .load(jsonDir) @@ -1611,7 +1612,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val path = dir.getCanonicalPath primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) - val jsonDF = sqlContext.read.json(path) + val jsonDF = spark.read.json(path) val jsonDir = new File(dir, "json").getCanonicalPath jsonDF.coalesce(1).write .format("json") @@ -1622,7 +1623,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val compressedFiles = new File(jsonDir).listFiles() assert(compressedFiles.exists(!_.getName.endsWith(".json.gz"))) - val jsonCopy = sqlContext.read + val jsonCopy = spark.read .format("json") .options(extraOptions) .load(jsonDir) @@ -1635,11 +1636,11 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("Casting long as timestamp") { - withTempTable("jsonTable") { + withTempView("jsonTable") { val schema = (new StructType).add("ts", TimestampType) - val jsonDF = sqlContext.read.schema(schema).json(timestampAsLong) + val jsonDF = spark.read.schema(schema).json(timestampAsLong) - jsonDF.registerTempTable("jsonTable") + jsonDF.createOrReplaceTempView("jsonTable") checkAnswer( sql("select ts from jsonTable"), @@ -1657,9 +1658,66 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val json = s""" |{"a": [{$nested}], "b": [{$nested}]} """.stripMargin - val rdd = sqlContext.sparkContext.makeRDD(Seq(json)) - val df = sqlContext.read.json(rdd) + val rdd = spark.sparkContext.makeRDD(Seq(json)) + val df = spark.read.json(rdd) assert(df.schema.size === 2) df.collect() } + + test("Write dates correctly with dateFormat option") { + val customSchema = new StructType(Array(StructField("date", DateType, true))) + withTempDir { dir => + // With dateFormat option. + val datesWithFormatPath = s"${dir.getCanonicalPath}/datesWithFormat.json" + val datesWithFormat = spark.read + .schema(customSchema) + .option("dateFormat", "dd/MM/yyyy HH:mm") + .json(datesRecords) + + datesWithFormat.write + .format("json") + .option("dateFormat", "yyyy/MM/dd") + .save(datesWithFormatPath) + + // This will load back the dates as string. 
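The new `dateFormat` write test above follows a read, rewrite with a different format, then read-back-as-string pattern; a condensed illustrative sketch of the same flow (not part of the patch, with a hypothetical output path and a `SparkSession` named `spark`):

```scala
// Illustrative sketch: read a custom-formatted date, rewrite it with another
// dateFormat, then read it back as a plain string to see the written value.
import org.apache.spark.sql.types._

val dateSchema = new StructType().add("date", DateType)
val datesDF = spark.read
  .schema(dateSchema)
  .option("dateFormat", "dd/MM/yyyy HH:mm")
  .json(spark.sparkContext.parallelize(Seq("""{"date": "26/08/2015 18:00"}""")))

datesDF.write
  .format("json")
  .option("dateFormat", "yyyy/MM/dd")
  .save("/tmp/datesWithFormat.json")             // hypothetical path

val stringSchema = new StructType().add("date", StringType)
spark.read.schema(stringSchema).json("/tmp/datesWithFormat.json").show()  // e.g. 2015/08/26
```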
+ val stringSchema = StructType(StructField("date", StringType, true) :: Nil) + val stringDatesWithFormat = spark.read + .schema(stringSchema) + .json(datesWithFormatPath) + val expectedStringDatesWithFormat = Seq( + Row("2015/08/26"), + Row("2014/10/27"), + Row("2016/01/28")) + + checkAnswer(stringDatesWithFormat, expectedStringDatesWithFormat) + } + } + + test("Write timestamps correctly with dateFormat option") { + val customSchema = new StructType(Array(StructField("date", TimestampType, true))) + withTempDir { dir => + // With dateFormat option. + val timestampsWithFormatPath = s"${dir.getCanonicalPath}/timestampsWithFormat.json" + val timestampsWithFormat = spark.read + .schema(customSchema) + .option("timestampFormat", "dd/MM/yyyy HH:mm") + .json(datesRecords) + timestampsWithFormat.write + .format("json") + .option("timestampFormat", "yyyy/MM/dd HH:mm") + .save(timestampsWithFormatPath) + + // This will load back the timestamps as string. + val stringSchema = StructType(StructField("date", StringType, true) :: Nil) + val stringTimestampsWithFormat = spark.read + .schema(stringSchema) + .json(timestampsWithFormatPath) + val expectedStringDatesWithFormat = Seq( + Row("2015/08/26 18:00"), + Row("2014/10/27 18:30"), + Row("2016/01/28 20:00")) + + checkAnswer(stringTimestampsWithFormat, expectedStringDatesWithFormat) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala index 2873c6a881bef..d1d82fd5658b0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala @@ -18,13 +18,13 @@ package org.apache.spark.sql.execution.datasources.json import org.apache.spark.rdd.RDD -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession private[json] trait TestJsonData { - protected def sqlContext: SQLContext + protected def spark: SparkSession def primitiveFieldAndType: RDD[String] = - sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( """{"string":"this is a simple string.", "integer":10, "long":21474836470, @@ -35,7 +35,7 @@ private[json] trait TestJsonData { }""" :: Nil) def primitiveFieldValueTypeConflict: RDD[String] = - sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( """{"num_num_1":11, "num_num_2":null, "num_num_3": 1.1, "num_bool":true, "num_str":13.1, "str_bool":"str1"}""" :: """{"num_num_1":null, "num_num_2":21474836470.9, "num_num_3": null, @@ -46,14 +46,14 @@ private[json] trait TestJsonData { "num_bool":null, "num_str":92233720368547758070, "str_bool":null}""" :: Nil) def jsonNullStruct: RDD[String] = - sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( """{"nullstr":"","ip":"27.31.100.29","headers":{"Host":"1.abc.com","Charset":"UTF-8"}}""" :: """{"nullstr":"","ip":"27.31.100.29","headers":{}}""" :: """{"nullstr":"","ip":"27.31.100.29","headers":""}""" :: """{"nullstr":null,"ip":"27.31.100.29","headers":null}""" :: Nil) def complexFieldValueTypeConflict: RDD[String] = - sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( """{"num_struct":11, "str_array":[1, 2, 3], "array":[], "struct_array":[], "struct": {}}""" :: """{"num_struct":{"field":false}, "str_array":null, @@ -64,14 +64,14 @@ private[json] trait TestJsonData { "array":[7], "struct_array":{"field": true}, "struct": {"field": 
"str"}}""" :: Nil) def arrayElementTypeConflict: RDD[String] = - sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( """{"array1": [1, 1.1, true, null, [], {}, [2,3,4], {"field":"str"}], "array2": [{"field":214748364700}, {"field":1}]}""" :: """{"array3": [{"field":"str"}, {"field":1}]}""" :: """{"array3": [1, 2, 3]}""" :: Nil) def missingFields: RDD[String] = - sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( """{"a":true}""" :: """{"b":21474836470}""" :: """{"c":[33, 44]}""" :: @@ -79,7 +79,7 @@ private[json] trait TestJsonData { """{"e":"str"}""" :: Nil) def complexFieldAndType1: RDD[String] = - sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( """{"struct":{"field1": true, "field2": 92233720368547758070}, "structWithArrayFields":{"field1":[4, 5, 6], "field2":["str1", "str2"]}, "arrayOfString":["str1", "str2"], @@ -95,7 +95,7 @@ private[json] trait TestJsonData { }""" :: Nil) def complexFieldAndType2: RDD[String] = - sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( """{"arrayOfStruct":[{"field1": true, "field2": "str1"}, {"field1": false}, {"field3": null}], "complexArrayOfStruct": [ { @@ -149,7 +149,7 @@ private[json] trait TestJsonData { }""" :: Nil) def mapType1: RDD[String] = - sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( """{"map": {"a": 1}}""" :: """{"map": {"b": 2}}""" :: """{"map": {"c": 3}}""" :: @@ -157,7 +157,7 @@ private[json] trait TestJsonData { """{"map": {"e": null}}""" :: Nil) def mapType2: RDD[String] = - sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( """{"map": {"a": {"field1": [1, 2, 3, null]}}}""" :: """{"map": {"b": {"field2": 2}}}""" :: """{"map": {"c": {"field1": [], "field2": 4}}}""" :: @@ -166,21 +166,21 @@ private[json] trait TestJsonData { """{"map": {"f": {"field1": null}}}""" :: Nil) def nullsInArrays: RDD[String] = - sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( """{"field1":[[null], [[["Test"]]]]}""" :: """{"field2":[null, [{"Test":1}]]}""" :: """{"field3":[[null], [{"Test":"2"}]]}""" :: """{"field4":[[null, [1,2,3]]]}""" :: Nil) def jsonArray: RDD[String] = - sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( """[{"a":"str_a_1"}]""" :: """[{"a":"str_a_2"}, {"b":"str_b_3"}]""" :: """{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" :: """[]""" :: Nil) def corruptRecords: RDD[String] = - sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( """{""" :: """""" :: """{"a":1, b:2}""" :: @@ -189,7 +189,7 @@ private[json] trait TestJsonData { """]""" :: Nil) def additionalCorruptRecords: RDD[String] = - sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( """{"dummy":"test"}""" :: """[1,2,3]""" :: """":"test", "a":1}""" :: @@ -197,7 +197,7 @@ private[json] trait TestJsonData { """ ","ian":"test"}""" :: Nil) def emptyRecords: RDD[String] = - sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( """{""" :: """""" :: """{"a": {}}""" :: @@ -206,23 +206,29 @@ private[json] trait TestJsonData { """]""" :: Nil) def timestampAsLong: RDD[String] = - sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( """{"ts":1451732645}""" :: Nil) def arrayAndStructRecords: RDD[String] = - sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( """{"a": {"b": 1}}""" :: """{"a": []}""" :: Nil) def floatingValueRecords: RDD[String] = - sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( s"""{"a": 0.${"0" * 
38}1, "b": 0.01}""" :: Nil) def bigIntegerRecords: RDD[String] = - sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( s"""{"a": 1${"0" * 38}, "b": 92233720368547758070}""" :: Nil) - lazy val singleRow: RDD[String] = sqlContext.sparkContext.parallelize("""{"a":123}""" :: Nil) + def datesRecords: RDD[String] = + spark.sparkContext.parallelize( + """{"date": "26/08/2015 18:00"}""" :: + """{"date": "27/10/2014 18:30"}""" :: + """{"date": "28/01/2016 20:00"}""" :: Nil) - def empty: RDD[String] = sqlContext.sparkContext.parallelize(Seq[String]()) + lazy val singleRow: RDD[String] = spark.sparkContext.parallelize("""{"a":123}""" :: Nil) + + def empty: RDD[String] = spark.sparkContext.parallelize(Seq[String]()) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala index f98ea8c5aeb80..6509e04e85167 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala @@ -67,7 +67,7 @@ class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest with Shared logParquetSchema(path) - checkAnswer(sqlContext.read.parquet(path), (0 until 10).map { i => + checkAnswer(spark.read.parquet(path), (0 until 10).map { i => Row( i % 2 == 0, i, @@ -114,7 +114,7 @@ class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest with Shared logParquetSchema(path) - checkAnswer(sqlContext.read.parquet(path), (0 until 10).map { i => + checkAnswer(spark.read.parquet(path), (0 until 10).map { i => if (i % 3 == 0) { Row.apply(Seq.fill(7)(null): _*) } else { @@ -155,7 +155,7 @@ class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest with Shared logParquetSchema(path) - checkAnswer(sqlContext.read.parquet(path), (0 until 10).map { i => + checkAnswer(spark.read.parquet(path), (0 until 10).map { i => Row( Seq.tabulate(3)(i => s"val_$i"), if (i % 3 == 0) null else Seq.tabulate(3)(identity)) @@ -182,7 +182,7 @@ class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest with Shared logParquetSchema(path) - checkAnswer(sqlContext.read.parquet(path), (0 until 10).map { i => + checkAnswer(spark.read.parquet(path), (0 until 10).map { i => Row(Seq.tabulate(3, 3)((i, j) => i * 3 + j)) }) } @@ -205,7 +205,7 @@ class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest with Shared logParquetSchema(path) - checkAnswer(sqlContext.read.parquet(path), (0 until 10).map { i => + checkAnswer(spark.read.parquet(path), (0 until 10).map { i => Row(Seq.tabulate(3)(i => i.toString -> Seq.tabulate(3)(j => i + j)).toMap) }) } @@ -221,7 +221,7 @@ class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest with Shared logParquetSchema(path) - checkAnswer(sqlContext.read.parquet(path), (0 until 10).map { i => + checkAnswer(spark.read.parquet(path), (0 until 10).map { i => Row( Seq.tabulate(3)(n => s"arr_${i + n}"), Seq.tabulate(3)(n => n.toString -> (i + n: Integer)).toMap, @@ -267,7 +267,7 @@ class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest with Shared } } - checkAnswer(sqlContext.read.parquet(path).filter('suit === "SPADES"), Row("SPADES")) + checkAnswer(spark.read.parquet(path).filter('suit === "SPADES"), Row("SPADES")) } } } diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala index 45cc6810d48b6..57cd70e1911c3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala @@ -38,7 +38,7 @@ private[sql] abstract class ParquetCompatibilityTest extends QueryTest with Parq } protected def readParquetSchema(path: String, pathFilter: Path => Boolean): MessageType = { - val hadoopConf = sqlContext.sessionState.newHadoopConf() + val hadoopConf = spark.sessionState.newHadoopConf() val fsPath = new Path(path) val fs = fsPath.getFileSystem(hadoopConf) val parquetFiles = fs.listStatus(fsPath, new PathFilter { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 65635e3c066d9..ab9250045f5b1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ +import org.apache.spark.util.{AccumulatorContext, LongAccumulator} /** * A test suite that tests Parquet filter2 API based filter pushdown optimization. @@ -304,7 +305,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex // If the "part = 1" filter gets pushed down, this query will throw an exception since // "part" is not a valid column in the actual Parquet file checkAnswer( - sqlContext.read.parquet(dir.getCanonicalPath).filter("part = 1"), + spark.read.parquet(dir.getCanonicalPath).filter("part = 1"), (1 to 3).map(i => Row(i, i.toString, 1))) } } @@ -321,7 +322,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex // If the "part = 1" filter gets pushed down, this query will throw an exception since // "part" is not a valid column in the actual Parquet file checkAnswer( - sqlContext.read.parquet(dir.getCanonicalPath).filter("a > 0 and (part = 0 or a > 1)"), + spark.read.parquet(dir.getCanonicalPath).filter("a > 0 and (part = 0 or a > 1)"), (2 to 3).map(i => Row(i, i.toString, 1))) } } @@ -339,7 +340,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex // The filter "a > 1 or b < 2" will not get pushed down, and the projection is empty, // this query will throw an exception since the project from combinedFilter expect // two projection while the - val df1 = sqlContext.read.parquet(dir.getCanonicalPath) + val df1 = spark.read.parquet(dir.getCanonicalPath) assert(df1.filter("a > 1 or b < 2").count() == 2) } @@ -358,7 +359,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex // test the generate new projection case // when projects != partitionAndNormalColumnProjs - val df1 = sqlContext.read.parquet(dir.getCanonicalPath) + val df1 = spark.read.parquet(dir.getCanonicalPath) checkAnswer( df1.filter("a > 1 or b > 2").orderBy("a").selectExpr("a", "b", "c", "d"), @@ -370,73 +371,75 @@ class ParquetFilterSuite extends 
QueryTest with ParquetTest with SharedSQLContex test("SPARK-11103: Filter applied on merged Parquet schema with new column fails") { import testImplicits._ - - withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true", - SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "true") { - withTempPath { dir => - val pathOne = s"${dir.getCanonicalPath}/table1" - (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(pathOne) - val pathTwo = s"${dir.getCanonicalPath}/table2" - (1 to 3).map(i => (i, i.toString)).toDF("c", "b").write.parquet(pathTwo) - - // If the "c = 1" filter gets pushed down, this query will throw an exception which - // Parquet emits. This is a Parquet issue (PARQUET-389). - val df = sqlContext.read.parquet(pathOne, pathTwo).filter("c = 1").selectExpr("c", "b", "a") - checkAnswer( - df, - Row(1, "1", null)) - - // The fields "a" and "c" only exist in one Parquet file. - assert(df.schema("a").metadata.getBoolean(StructType.metadataKeyForOptionalField)) - assert(df.schema("c").metadata.getBoolean(StructType.metadataKeyForOptionalField)) - - val pathThree = s"${dir.getCanonicalPath}/table3" - df.write.parquet(pathThree) - - // We will remove the temporary metadata when writing Parquet file. - val schema = sqlContext.read.parquet(pathThree).schema - assert(schema.forall(!_.metadata.contains(StructType.metadataKeyForOptionalField))) - - val pathFour = s"${dir.getCanonicalPath}/table4" - val dfStruct = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b") - dfStruct.select(struct("a").as("s")).write.parquet(pathFour) - - val pathFive = s"${dir.getCanonicalPath}/table5" - val dfStruct2 = sparkContext.parallelize(Seq((1, 1))).toDF("c", "b") - dfStruct2.select(struct("c").as("s")).write.parquet(pathFive) - - // If the "s.c = 1" filter gets pushed down, this query will throw an exception which - // Parquet emits. - val dfStruct3 = sqlContext.read.parquet(pathFour, pathFive).filter("s.c = 1") - .selectExpr("s") - checkAnswer(dfStruct3, Row(Row(null, 1))) - - // The fields "s.a" and "s.c" only exist in one Parquet file. - val field = dfStruct3.schema("s").dataType.asInstanceOf[StructType] - assert(field("a").metadata.getBoolean(StructType.metadataKeyForOptionalField)) - assert(field("c").metadata.getBoolean(StructType.metadataKeyForOptionalField)) - - val pathSix = s"${dir.getCanonicalPath}/table6" - dfStruct3.write.parquet(pathSix) - - // We will remove the temporary metadata when writing Parquet file. - val forPathSix = sqlContext.read.parquet(pathSix).schema - assert(forPathSix.forall(!_.metadata.contains(StructType.metadataKeyForOptionalField))) - - // sanity test: make sure optional metadata field is not wrongly set. - val pathSeven = s"${dir.getCanonicalPath}/table7" - (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(pathSeven) - val pathEight = s"${dir.getCanonicalPath}/table8" - (4 to 6).map(i => (i, i.toString)).toDF("a", "b").write.parquet(pathEight) - - val df2 = sqlContext.read.parquet(pathSeven, pathEight).filter("a = 1").selectExpr("a", "b") - checkAnswer( - df2, - Row(1, "1")) - - // The fields "a" and "b" exist in both two Parquet files. No metadata is set. 
- assert(!df2.schema("a").metadata.contains(StructType.metadataKeyForOptionalField)) - assert(!df2.schema("b").metadata.contains(StructType.metadataKeyForOptionalField)) + Seq("true", "false").map { vectorized => + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true", + SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "true", + SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized) { + withTempPath { dir => + val pathOne = s"${dir.getCanonicalPath}/table1" + (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(pathOne) + val pathTwo = s"${dir.getCanonicalPath}/table2" + (1 to 3).map(i => (i, i.toString)).toDF("c", "b").write.parquet(pathTwo) + + // If the "c = 1" filter gets pushed down, this query will throw an exception which + // Parquet emits. This is a Parquet issue (PARQUET-389). + val df = spark.read.parquet(pathOne, pathTwo).filter("c = 1").selectExpr("c", "b", "a") + checkAnswer( + df, + Row(1, "1", null)) + + // The fields "a" and "c" only exist in one Parquet file. + assert(df.schema("a").metadata.getBoolean(StructType.metadataKeyForOptionalField)) + assert(df.schema("c").metadata.getBoolean(StructType.metadataKeyForOptionalField)) + + val pathThree = s"${dir.getCanonicalPath}/table3" + df.write.parquet(pathThree) + + // We will remove the temporary metadata when writing Parquet file. + val schema = spark.read.parquet(pathThree).schema + assert(schema.forall(!_.metadata.contains(StructType.metadataKeyForOptionalField))) + + val pathFour = s"${dir.getCanonicalPath}/table4" + val dfStruct = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b") + dfStruct.select(struct("a").as("s")).write.parquet(pathFour) + + val pathFive = s"${dir.getCanonicalPath}/table5" + val dfStruct2 = sparkContext.parallelize(Seq((1, 1))).toDF("c", "b") + dfStruct2.select(struct("c").as("s")).write.parquet(pathFive) + + // If the "s.c = 1" filter gets pushed down, this query will throw an exception which + // Parquet emits. + val dfStruct3 = spark.read.parquet(pathFour, pathFive).filter("s.c = 1") + .selectExpr("s") + checkAnswer(dfStruct3, Row(Row(null, 1))) + + // The fields "s.a" and "s.c" only exist in one Parquet file. + val field = dfStruct3.schema("s").dataType.asInstanceOf[StructType] + assert(field("a").metadata.getBoolean(StructType.metadataKeyForOptionalField)) + assert(field("c").metadata.getBoolean(StructType.metadataKeyForOptionalField)) + + val pathSix = s"${dir.getCanonicalPath}/table6" + dfStruct3.write.parquet(pathSix) + + // We will remove the temporary metadata when writing Parquet file. + val forPathSix = spark.read.parquet(pathSix).schema + assert(forPathSix.forall(!_.metadata.contains(StructType.metadataKeyForOptionalField))) + + // sanity test: make sure optional metadata field is not wrongly set. + val pathSeven = s"${dir.getCanonicalPath}/table7" + (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(pathSeven) + val pathEight = s"${dir.getCanonicalPath}/table8" + (4 to 6).map(i => (i, i.toString)).toDF("a", "b").write.parquet(pathEight) + + val df2 = spark.read.parquet(pathSeven, pathEight).filter("a = 1").selectExpr("a", "b") + checkAnswer( + df2, + Row(1, "1")) + + // The fields "a" and "b" exist in both two Parquet files. No metadata is set. 
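Outside the test harness, the merged-schema read with the vectorized reader toggled (which the reworked test loops over above) looks roughly like the sketch below; paths are hypothetical and a `SparkSession` named `spark` is assumed:

```scala
// Illustrative sketch: schema-merging Parquet read under both settings of the
// vectorized reader, mirroring the Seq("true", "false") loop in the test.
import spark.implicits._

(1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.mode("overwrite").parquet("/tmp/table1")
(1 to 3).map(i => (i, i.toString)).toDF("c", "b").write.mode("overwrite").parquet("/tmp/table2")

Seq("true", "false").foreach { vectorized =>
  spark.conf.set("spark.sql.parquet.enableVectorizedReader", vectorized)
  val merged = spark.read
    .option("mergeSchema", "true")
    .parquet("/tmp/table1", "/tmp/table2")
  merged.filter("c = 1").selectExpr("c", "b", "a").show()   // "a" is null for this row
}
```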
+ assert(!df2.schema("a").metadata.contains(StructType.metadataKeyForOptionalField)) + assert(!df2.schema("b").metadata.contains(StructType.metadataKeyForOptionalField)) + } } } } @@ -449,7 +452,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex withTempPath { dir => val path = s"${dir.getCanonicalPath}/part=1" (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(path) - val df = sqlContext.read.parquet(path).filter("a = 2") + val df = spark.read.parquet(path).filter("a = 2") // The result should be single row. // When a filter is pushed to Parquet, Parquet can apply it to every row. @@ -470,11 +473,11 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex (1 to 5).map(i => (i, (i % 2).toString)).toDF("a", "b").write.parquet(path) checkAnswer( - sqlContext.read.parquet(path).where("not (a = 2) or not(b in ('1'))"), + spark.read.parquet(path).where("not (a = 2) or not(b in ('1'))"), (1 to 5).map(i => Row(i, (i % 2).toString))) checkAnswer( - sqlContext.read.parquet(path).where("not (a = 2 and b in ('1'))"), + spark.read.parquet(path).where("not (a = 2 and b in ('1'))"), (1 to 5).map(i => Row(i, (i % 2).toString))) } } @@ -527,22 +530,64 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex // When a filter is pushed to Parquet, Parquet can apply it to every row. // So, we can check the number of rows returned from the Parquet // to make sure our filter pushdown work. - val df = sqlContext.read.parquet(path).where("b in (0,2)") + val df = spark.read.parquet(path).where("b in (0,2)") assert(stripSparkFilter(df).count == 3) - val df1 = sqlContext.read.parquet(path).where("not (b in (1))") + val df1 = spark.read.parquet(path).where("not (b in (1))") assert(stripSparkFilter(df1).count == 3) - val df2 = sqlContext.read.parquet(path).where("not (b in (1,3) or a <= 2)") + val df2 = spark.read.parquet(path).where("not (b in (1,3) or a <= 2)") assert(stripSparkFilter(df2).count == 2) - val df3 = sqlContext.read.parquet(path).where("not (b in (1,3) and a <= 2)") + val df3 = spark.read.parquet(path).where("not (b in (1,3) and a <= 2)") assert(stripSparkFilter(df3).count == 4) - val df4 = sqlContext.read.parquet(path).where("not (a <= 2)") + val df4 = spark.read.parquet(path).where("not (a <= 2)") assert(stripSparkFilter(df4).count == 3) } } } } + + test("SPARK-16371 Do not push down filters when inner name and outer name are the same") { + withParquetDataFrame((1 to 4).map(i => Tuple1(Tuple1(i)))) { implicit df => + // Here the schema becomes as below: + // + // root + // |-- _1: struct (nullable = true) + // | |-- _1: integer (nullable = true) + // + // The inner column name, `_1` and outer column name `_1` are the same. + // Obviously this should not push down filters because the outer column is struct. 
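To observe the SPARK-16371 behaviour described in the comment above outside the suite, one can inspect the physical plan; a rough sketch (hypothetical path, `SparkSession` named `spark` assumed):

```scala
// Illustrative sketch: when the outer column is a struct whose inner field has
// the same name, the IS NOT NULL predicate should not appear as a pushed filter.
import spark.implicits._

(1 to 4).map(i => Tuple1(Tuple1(i))).toDF("_1").write.mode("overwrite").parquet("/tmp/nested")

val df = spark.read.parquet("/tmp/nested").filter("_1 IS NOT NULL")
df.explain()            // inspect the scan node: no pushed filter for "_1" expected
assert(df.count() == 4) // all rows are still returned
```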
+ assert(df.filter("_1 IS NOT NULL").count() === 4) + } + } + + test("Filters should be pushed down for vectorized Parquet reader at row group level") { + import testImplicits._ + + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true", + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") { + withTempPath { dir => + val path = s"${dir.getCanonicalPath}/table" + (1 to 1024).map(i => (101, i)).toDF("a", "b").write.parquet(path) + + Seq(("true", (x: Long) => x == 0), ("false", (x: Long) => x > 0)).map { case (push, func) => + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> push) { + val accu = new LongAccumulator + accu.register(sparkContext, Some("numRowGroups")) + + val df = spark.read.parquet(path).filter("a < 100") + df.foreachPartition(_.foreach(v => accu.add(0))) + df.collect + + val numRowGroups = AccumulatorContext.lookForAccumulatorByName("numRowGroups") + assert(numRowGroups.isDefined) + assert(func(numRowGroups.get.asInstanceOf[LongAccumulator].value)) + AccumulatorContext.remove(accu.id) + } + } + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 32fe5ba127ca4..46ccfa53bd79d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -38,11 +38,12 @@ import org.apache.parquet.schema.{MessageType, MessageTypeParser} import org.apache.spark.SparkException import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection} -import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, UnsafeRow} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String // Write support class for nested groups: ParquetWriter initializes GroupWriteSupport // with an empty configuration (it is after all not intended to be used in this way?) @@ -113,7 +114,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { withTempPath { location => val path = new Path(location.getCanonicalPath) - val conf = sqlContext.sessionState.newHadoopConf() + val conf = spark.sessionState.newHadoopConf() writeMetadata(parquetSchema, path, conf) readParquetFile(path.toString)(df => { val sparkTypes = df.schema.map(_.dataType) @@ -132,7 +133,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { testStandardAndLegacyModes("fixed-length decimals") { def makeDecimalRDD(decimal: DecimalType): DataFrame = { - sqlContext + spark .range(1000) // Parquet doesn't allow column names with spaces, have to add an alias here. // Minus 500 here so that negative decimals are also tested.
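Editorial aside, not part of the patch: most of the changes in these suites replace the old `sqlContext` entry point with the unified `SparkSession` (bound to `spark`), along with `registerTempTable` -> `createOrReplaceTempView` and `sqlContext.conf.setConfString` -> `spark.conf.set`. The following is a minimal standalone sketch of that migration pattern; the object name, master/app settings, output path, and temp-view name are illustrative assumptions, not code from this PR.

```scala
// Illustrative sketch only -- not part of the patch.
import org.apache.spark.sql.SparkSession

object SparkSessionMigrationSketch {
  def main(args: Array[String]): Unit = {
    // SparkSession replaces SQLContext as the single entry point.
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("migration-sketch")
      .getOrCreate()
    import spark.implicits._

    val path = "/tmp/migration-sketch-parquet" // hypothetical output directory

    // DataFrame creation and Parquet I/O hang off `spark` instead of `sqlContext`.
    (1 to 3).map(i => (i, i.toString)).toDF("a", "b")
      .write.mode("overwrite").parquet(path)
    val df = spark.read.parquet(path)

    // sqlContext.conf.setConfString(k, v)  ->  spark.conf.set(k, v)
    spark.conf.set("spark.sql.parquet.filterPushdown", "true")

    // df.registerTempTable("t")  ->  df.createOrReplaceTempView("t")
    df.createOrReplaceTempView("t")
    spark.sql("SELECT count(*) FROM t").show()

    spark.stop()
  }
}
```

The same substitutions recur mechanically throughout the remaining hunks below.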
@@ -250,10 +251,10 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { withTempPath { location => val path = new Path(location.getCanonicalPath) - val conf = sqlContext.sessionState.newHadoopConf() + val conf = spark.sessionState.newHadoopConf() writeMetadata(parquetSchema, path, conf) val errorMessage = intercept[Throwable] { - sqlContext.read.parquet(path.toString).printSchema() + spark.read.parquet(path.toString).printSchema() }.toString assert(errorMessage.contains("Parquet type not supported")) } @@ -271,15 +272,15 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { withTempPath { location => val path = new Path(location.getCanonicalPath) - val conf = sqlContext.sessionState.newHadoopConf() + val conf = spark.sessionState.newHadoopConf() writeMetadata(parquetSchema, path, conf) - val sparkTypes = sqlContext.read.parquet(path.toString).schema.map(_.dataType) + val sparkTypes = spark.read.parquet(path.toString).schema.map(_.dataType) assert(sparkTypes === expectedSparkTypes) } } test("compression codec") { - val hadoopConf = sqlContext.sessionState.newHadoopConf() + val hadoopConf = spark.sessionState.newHadoopConf() def compressionCodecFor(path: String, codecName: String): String = { val codecs = for { footer <- readAllFootersWithoutSummaryFiles(new Path(path), hadoopConf) @@ -296,7 +297,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { def checkCompressionCodec(codec: CompressionCodecName): Unit = { withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> codec.name()) { withParquetFile(data) { path => - assertResult(sqlContext.conf.parquetCompressionCodec.toUpperCase) { + assertResult(spark.conf.get(SQLConf.PARQUET_COMPRESSION).toUpperCase) { compressionCodecFor(path, codec.name()) } } @@ -304,7 +305,8 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { } // Checks default compression codec - checkCompressionCodec(CompressionCodecName.fromConf(sqlContext.conf.parquetCompressionCodec)) + checkCompressionCodec( + CompressionCodecName.fromConf(spark.conf.get(SQLConf.PARQUET_COMPRESSION))) checkCompressionCodec(CompressionCodecName.UNCOMPRESSED) checkCompressionCodec(CompressionCodecName.GZIP) @@ -351,7 +353,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { } test("write metadata") { - val hadoopConf = sqlContext.sessionState.newHadoopConf() + val hadoopConf = spark.sessionState.newHadoopConf() withTempPath { file => val path = new Path(file.toURI.toString) val fs = FileSystem.getLocal(hadoopConf) @@ -361,7 +363,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { assert(fs.exists(new Path(path, ParquetFileWriter.PARQUET_COMMON_METADATA_FILE))) assert(fs.exists(new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE))) - val expectedSchema = new CatalystSchemaConverter().convert(schema) + val expectedSchema = new ParquetSchemaConverter().convert(schema) val actualSchema = readFooter(path, hadoopConf).getFileMetaData.getSchema actualSchema.checkContains(expectedSchema) @@ -431,9 +433,9 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { """.stripMargin) withTempPath { location => - val extraMetadata = Map(CatalystReadSupport.SPARK_METADATA_KEY -> sparkSchema.toString) + val extraMetadata = Map(ParquetReadSupport.SPARK_METADATA_KEY -> sparkSchema.toString) val path = new Path(location.getCanonicalPath) - val conf = sqlContext.sessionState.newHadoopConf() + val conf = 
spark.sessionState.newHadoopConf() writeMetadata(parquetSchema, path, conf, extraMetadata) readParquetFile(path.toString) { df => @@ -455,7 +457,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { ) withTempPath { dir => val message = intercept[SparkException] { - sqlContext.range(0, 1).write.options(extraOptions).parquet(dir.getCanonicalPath) + spark.range(0, 1).write.options(extraOptions).parquet(dir.getCanonicalPath) }.getCause.getMessage assert(message === "Intentional exception for testing purposes") } @@ -465,10 +467,10 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { // In 1.3.0, save to fs other than file: without configuring core-site.xml would get: // IllegalArgumentException: Wrong FS: hdfs://..., expected: file:/// intercept[Throwable] { - sqlContext.read.parquet("file:///nonexistent") + spark.read.parquet("file:///nonexistent") } val errorMessage = intercept[Throwable] { - sqlContext.read.parquet("hdfs://nonexistent") + spark.read.parquet("hdfs://nonexistent") }.toString assert(errorMessage.contains("UnknownHostException")) } @@ -486,14 +488,14 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { withTempPath { dir => val m1 = intercept[SparkException] { - sqlContext.range(1).coalesce(1).write.options(extraOptions).parquet(dir.getCanonicalPath) + spark.range(1).coalesce(1).write.options(extraOptions).parquet(dir.getCanonicalPath) }.getCause.getMessage assert(m1.contains("Intentional exception for testing purposes")) } withTempPath { dir => val m2 = intercept[SparkException] { - val df = sqlContext.range(1).select('id as 'a, 'id as 'b).coalesce(1) + val df = spark.range(1).select('id as 'a, 'id as 'b).coalesce(1) df.write.partitionBy("a").options(extraOptions).parquet(dir.getCanonicalPath) }.getCause.getMessage assert(m2.contains("Intentional exception for testing purposes")) @@ -512,26 +514,28 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { ParquetOutputFormat.ENABLE_DICTIONARY -> "true" ) - val hadoopConf = sqlContext.sessionState.newHadoopConfWithOptions(extraOptions) + val hadoopConf = spark.sessionState.newHadoopConfWithOptions(extraOptions) - withTempPath { dir => - val path = s"${dir.getCanonicalPath}/part-r-0.parquet" - sqlContext.range(1 << 16).selectExpr("(id % 4) AS i") - .coalesce(1).write.options(extraOptions).mode("overwrite").parquet(path) + withSQLConf(ParquetOutputFormat.ENABLE_JOB_SUMMARY -> "true") { + withTempPath { dir => + val path = s"${dir.getCanonicalPath}/part-r-0.parquet" + spark.range(1 << 16).selectExpr("(id % 4) AS i") + .coalesce(1).write.options(extraOptions).mode("overwrite").parquet(path) - val blockMetadata = readFooter(new Path(path), hadoopConf).getBlocks.asScala.head - val columnChunkMetadata = blockMetadata.getColumns.asScala.head + val blockMetadata = readFooter(new Path(path), hadoopConf).getBlocks.asScala.head + val columnChunkMetadata = blockMetadata.getColumns.asScala.head - // If the file is written with version2, this should include - // Encoding.RLE_DICTIONARY type. For version1, it is Encoding.PLAIN_DICTIONARY - assert(columnChunkMetadata.getEncodings.contains(Encoding.RLE_DICTIONARY)) + // If the file is written with version2, this should include + // Encoding.RLE_DICTIONARY type. 
For version1, it is Encoding.PLAIN_DICTIONARY + assert(columnChunkMetadata.getEncodings.contains(Encoding.RLE_DICTIONARY)) + } } } test("null and non-null strings") { // Create a dataset where the first values are NULL and then some non-null values. The // number of non-nulls needs to be bigger than the ParquetReader batch size. - val data: Dataset[String] = sqlContext.range(200).map (i => + val data: Dataset[String] = spark.range(200).map (i => if (i < 150) null else "a" ) @@ -553,8 +557,8 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized) { checkAnswer( // Decimal column in this file is encoded using plain dictionary - readResourceParquetFile("dec-in-i32.parquet"), - sqlContext.range(1 << 4).select('id % 10 cast DecimalType(5, 2) as 'i32_dec)) + readResourceParquetFile("test-data/dec-in-i32.parquet"), + spark.range(1 << 4).select('id % 10 cast DecimalType(5, 2) as 'i32_dec)) } } } @@ -564,8 +568,8 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized) { checkAnswer( // Decimal column in this file is encoded using plain dictionary - readResourceParquetFile("dec-in-i64.parquet"), - sqlContext.range(1 << 4).select('id % 10 cast DecimalType(10, 2) as 'i64_dec)) + readResourceParquetFile("test-data/dec-in-i64.parquet"), + spark.range(1 << 4).select('id % 10 cast DecimalType(10, 2) as 'i64_dec)) } } } @@ -575,8 +579,8 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized) { checkAnswer( // Decimal column in this file is encoded using plain dictionary - readResourceParquetFile("dec-in-fixed-len.parquet"), - sqlContext.range(1 << 4).select('id % 10 cast DecimalType(10, 2) as 'fixed_len_dec)) + readResourceParquetFile("test-data/dec-in-fixed-len.parquet"), + spark.range(1 << 4).select('id % 10 cast DecimalType(10, 2) as 'fixed_len_dec)) } } } @@ -589,7 +593,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { var hash2: Int = 0 (false :: true :: Nil).foreach { v => withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> v.toString) { - val df = sqlContext.read.parquet(dir.getCanonicalPath) + val df = spark.read.parquet(dir.getCanonicalPath) val rows = df.queryExecution.toRdd.map(_.copy()).collect() val unsafeRows = rows.map(_.asInstanceOf[UnsafeRow]) if (!v) { @@ -607,7 +611,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { test("VectorizedParquetRecordReader - direct path read") { val data = (0 to 10).map(i => (i, (i + 'a').toChar.toString)) withTempPath { dir => - sqlContext.createDataFrame(data).repartition(1).write.parquet(dir.getCanonicalPath) + spark.createDataFrame(data).repartition(1).write.parquet(dir.getCanonicalPath) val file = SpecificParquetRecordReaderBase.listDirectory(dir).get(0); { val reader = new VectorizedParquetRecordReader @@ -674,6 +678,52 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { } } } + + test("VectorizedParquetRecordReader - partition column types") { + withTempPath { dir => + Seq(1).toDF().repartition(1).write.parquet(dir.getCanonicalPath) + + val dataTypes = + Seq(StringType, BooleanType, ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, DecimalType(25, 5), DateType, TimestampType) + + val constantValues = + Seq( + 
UTF8String.fromString("a string"), + true, + 1.toByte, + 2.toShort, + 3, + Long.MaxValue, + 0.25.toFloat, + 0.75D, + Decimal("1234.23456"), + DateTimeUtils.fromJavaDate(java.sql.Date.valueOf("2015-01-01")), + DateTimeUtils.fromJavaTimestamp(java.sql.Timestamp.valueOf("2015-01-01 23:50:59.123"))) + + dataTypes.zip(constantValues).foreach { case (dt, v) => + val schema = StructType(StructField("pcol", dt) :: Nil) + val vectorizedReader = new VectorizedParquetRecordReader + val partitionValues = new GenericMutableRow(Array(v)) + val file = SpecificParquetRecordReaderBase.listDirectory(dir).get(0) + + try { + vectorizedReader.initialize(file, null) + vectorizedReader.initBatch(schema, partitionValues) + vectorizedReader.nextKeyValue() + val row = vectorizedReader.getCurrentValue.asInstanceOf[InternalRow] + + // Use `GenericMutableRow` by explicitly copying rather than `ColumnarBatch` + // in order to use get(...) method which is not implemented in `ColumnarBatch`. + val actual = row.copy().get(1, dt) + val expected = v + assert(actual == expected) + } finally { + vectorizedReader.close() + } + } + } + } } class JobCommitFailureParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala index 83b65fb419ed3..9dc56292c3720 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetInteroperabilitySuite.scala @@ -81,7 +81,7 @@ class ParquetInteroperabilitySuite extends ParquetCompatibilityTest with SharedS logParquetSchema(protobufStylePath) checkAnswer( - sqlContext.read.parquet(dir.getCanonicalPath), + spark.read.parquet(dir.getCanonicalPath), Seq( Row(Seq(0, 1)), Row(Seq(2, 3)))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index 5bffb307ec80e..8d18be9300f7e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -25,11 +25,13 @@ import scala.collection.mutable.ArrayBuffer import com.google.common.io.Files import org.apache.hadoop.fs.Path +import org.apache.parquet.hadoop.ParquetOutputFormat import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation, PartitionDirectory => Partition, PartitioningUtils, PartitionSpec} +import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -191,6 +193,29 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha checkThrows[AssertionError]("file://path/a=", "Empty partition column value") } + test("parse partition with base paths") { + // when the basePaths is the same as the path to a leaf directory + val partitionSpec1: Option[PartitionValues] = parsePartition( + path = new 
Path("file://path/a=10"), + defaultPartitionName = defaultPartitionName, + typeInference = true, + basePaths = Set(new Path("file://path/a=10")))._1 + + assert(partitionSpec1.isEmpty) + + // when the basePaths is the path to a base directory of leaf directories + val partitionSpec2: Option[PartitionValues] = parsePartition( + path = new Path("file://path/a=10"), + defaultPartitionName = defaultPartitionName, + typeInference = true, + basePaths = Set(new Path("file://path")))._1 + + assert(partitionSpec2 == + Option(PartitionValues( + ArrayBuffer("a"), + ArrayBuffer(Literal.create(10, IntegerType))))) + } + test("parse partitions") { def check( paths: Seq[String], @@ -377,9 +402,9 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha // Introduce _temporary dir to the base dir the robustness of the schema discovery process. new File(base.getCanonicalPath, "_temporary").mkdir() - sqlContext.read.parquet(base.getCanonicalPath).registerTempTable("t") + spark.read.parquet(base.getCanonicalPath).createOrReplaceTempView("t") - withTempTable("t") { + withTempView("t") { checkAnswer( sql("SELECT * FROM t"), for { @@ -413,6 +438,43 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha } } + test("read partitioned table using different path options") { + withTempDir { base => + val pi = 1 + val ps = "foo" + val path = makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps) + makeParquetFile( + (1 to 10).map(i => ParquetData(i, i.toString)), path) + + // when the input is the base path containing partitioning directories + val baseDf = spark.read.parquet(base.getCanonicalPath) + assert(baseDf.schema.map(_.name) === Seq("intField", "stringField", "pi", "ps")) + + // when the input is a path to the leaf directory containing a parquet file + val partDf = spark.read.parquet(path.getCanonicalPath) + assert(partDf.schema.map(_.name) === Seq("intField", "stringField")) + + path.listFiles().foreach { f => + if (f.getName.toLowerCase().endsWith(".parquet")) { + // when the input is a path to a parquet file + val df = spark.read.parquet(f.getCanonicalPath) + assert(df.schema.map(_.name) === Seq("intField", "stringField")) + } + } + + path.listFiles().foreach { f => + if (f.getName.toLowerCase().endsWith(".parquet")) { + // when the input is a path to a parquet file but `basePath` is overridden to + // the base path containing partitioning directories + val df = spark + .read.option("basePath", base.getCanonicalPath) + .parquet(f.getCanonicalPath) + assert(df.schema.map(_.name) === Seq("intField", "stringField", "pi", "ps")) + } + } + } + } + test("read partitioned table - partition key included in Parquet file") { withTempDir { base => for { @@ -424,9 +486,9 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) } - sqlContext.read.parquet(base.getCanonicalPath).registerTempTable("t") + spark.read.parquet(base.getCanonicalPath).createOrReplaceTempView("t") - withTempTable("t") { + withTempView("t") { checkAnswer( sql("SELECT * FROM t"), for { @@ -472,10 +534,10 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) } - val parquetRelation = sqlContext.read.format("parquet").load(base.getCanonicalPath) - parquetRelation.registerTempTable("t") + val parquetRelation = spark.read.format("parquet").load(base.getCanonicalPath) + 
parquetRelation.createOrReplaceTempView("t") - withTempTable("t") { + withTempView("t") { checkAnswer( sql("SELECT * FROM t"), for { @@ -512,10 +574,10 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) } - val parquetRelation = sqlContext.read.format("parquet").load(base.getCanonicalPath) - parquetRelation.registerTempTable("t") + val parquetRelation = spark.read.format("parquet").load(base.getCanonicalPath) + parquetRelation.createOrReplaceTempView("t") - withTempTable("t") { + withTempView("t") { checkAnswer( sql("SELECT * FROM t"), for { @@ -544,14 +606,14 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha (1 to 10).map(i => (i, i.toString)).toDF("intField", "stringField"), makePartitionDir(base, defaultPartitionName, "pi" -> 2)) - sqlContext + spark .read .option("mergeSchema", "true") .format("parquet") .load(base.getCanonicalPath) - .registerTempTable("t") + .createOrReplaceTempView("t") - withTempTable("t") { + withTempView("t") { checkAnswer( sql("SELECT * FROM t"), (1 to 10).map(i => Row(i, null, 1)) ++ (1 to 10).map(i => Row(i, i.toString, 2))) @@ -562,7 +624,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha test("SPARK-7749 Non-partitioned table should have empty partition spec") { withTempPath { dir => (1 to 10).map(i => (i, i.toString)).toDF("a", "b").write.parquet(dir.getCanonicalPath) - val queryExecution = sqlContext.read.parquet(dir.getCanonicalPath).queryExecution + val queryExecution = spark.read.parquet(dir.getCanonicalPath).queryExecution queryExecution.analyzed.collectFirst { case LogicalRelation(relation: HadoopFsRelation, _, _) => assert(relation.partitionSpec === PartitionSpec.emptySpec) @@ -576,7 +638,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha withTempPath { dir => val df = Seq("/", "[]", "?").zipWithIndex.map(_.swap).toDF("i", "s") df.write.format("parquet").partitionBy("s").save(dir.getCanonicalPath) - checkAnswer(sqlContext.read.parquet(dir.getCanonicalPath), df.collect()) + checkAnswer(spark.read.parquet(dir.getCanonicalPath), df.collect()) } } @@ -616,12 +678,12 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha } val schema = StructType(partitionColumns :+ StructField(s"i", StringType)) - val df = sqlContext.createDataFrame(sparkContext.parallelize(row :: Nil), schema) + val df = spark.createDataFrame(sparkContext.parallelize(row :: Nil), schema) withTempPath { dir => df.write.format("parquet").partitionBy(partitionColumns.map(_.name): _*).save(dir.toString) val fields = schema.map(f => Column(f.name).cast(f.dataType)) - checkAnswer(sqlContext.read.load(dir.toString).select(fields: _*), row) + checkAnswer(spark.read.load(dir.toString).select(fields: _*), row) } } @@ -637,7 +699,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha Files.touch(new File(s"${dir.getCanonicalPath}/b=1", ".DS_Store")) Files.createParentDirs(new File(s"${dir.getCanonicalPath}/b=1/c=1/.foo/bar")) - checkAnswer(sqlContext.read.format("parquet").load(dir.getCanonicalPath), df) + checkAnswer(spark.read.format("parquet").load(dir.getCanonicalPath), df) } } @@ -654,7 +716,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha Files.touch(new File(s"${tablePath.getCanonicalPath}/", "_SUCCESS")) Files.createParentDirs(new File(s"${dir.getCanonicalPath}/b=1/c=1/.foo/bar")) - 
checkAnswer(sqlContext.read.format("parquet").load(tablePath.getCanonicalPath), df) + checkAnswer(spark.read.format("parquet").load(tablePath.getCanonicalPath), df) } withTempPath { dir => @@ -671,7 +733,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha Files.touch(new File(s"${tablePath.getCanonicalPath}/", "_SUCCESS")) Files.createParentDirs(new File(s"${dir.getCanonicalPath}/b=1/c=1/.foo/bar")) - checkAnswer(sqlContext.read.format("parquet").load(tablePath.getCanonicalPath), df) + checkAnswer(spark.read.format("parquet").load(tablePath.getCanonicalPath), df) } } @@ -686,7 +748,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha .save(tablePath.getCanonicalPath) val twoPartitionsDF = - sqlContext + spark .read .option("basePath", tablePath.getCanonicalPath) .parquet( @@ -696,7 +758,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha checkAnswer(twoPartitionsDF, df.filter("b != 3")) intercept[AssertionError] { - sqlContext + spark .read .parquet( s"${tablePath.getCanonicalPath}/b=1", @@ -705,6 +767,53 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha } } + test("use basePath and file globbing to selectively load partitioned table") { + withTempPath { dir => + + val df = Seq( + (1, "foo", 100), + (1, "bar", 200), + (2, "foo", 300), + (2, "bar", 400) + ).toDF("p1", "p2", "v") + df.write + .mode(SaveMode.Overwrite) + .partitionBy("p1", "p2") + .parquet(dir.getCanonicalPath) + + def check(path: String, basePath: String, expectedDf: DataFrame): Unit = { + val testDf = spark.read + .option("basePath", basePath) + .parquet(path) + checkAnswer(testDf, expectedDf) + } + + // Should find all the data with partitioning columns when base path is set to the root + val resultDf = df.select("v", "p1", "p2") + check(path = s"$dir", basePath = s"$dir", resultDf) + check(path = s"$dir/*", basePath = s"$dir", resultDf) + check(path = s"$dir/*/*", basePath = s"$dir", resultDf) + check(path = s"$dir/*/*/*", basePath = s"$dir", resultDf) + + // Should find selective partitions of the data if the base path is not set to root + + check( // read from ../p1=1 with base ../p1=1, should not infer p1 col + path = s"$dir/p1=1/*", + basePath = s"$dir/p1=1/", + resultDf.filter("p1 = 1").drop("p1")) + + check( // read from ../p1=1/p2=foo with base ../p1=1/ should not infer p1 + path = s"$dir/p1=1/p2=foo/*", + basePath = s"$dir/p1=1/", + resultDf.filter("p1 = 1").filter("p2 = 'foo'").drop("p1")) + + check( // read from ../p1=1/p2=foo with base ../p1=1/p2=foo, should not infer p1, p2 + path = s"$dir/p1=1/p2=foo/*", + basePath = s"$dir/p1=1/p2=foo/", + resultDf.filter("p1 = 1").filter("p2 = 'foo'").drop("p1", "p2")) + } + } + test("_SUCCESS should not break partitioning discovery") { Seq(1, 32).foreach { threshold => // We have two paths to list files, one at driver side, another one that we use @@ -722,7 +831,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha Files.touch(new File(s"${tablePath.getCanonicalPath}/b=1", "_SUCCESS")) Files.touch(new File(s"${tablePath.getCanonicalPath}/b=1/c=1", "_SUCCESS")) Files.touch(new File(s"${tablePath.getCanonicalPath}/b=1/c=1/d=1", "_SUCCESS")) - checkAnswer(sqlContext.read.format("parquet").load(tablePath.getCanonicalPath), df) + checkAnswer(spark.read.format("parquet").load(tablePath.getCanonicalPath), df) } } } @@ -777,10 +886,52 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
withTempPath { dir => withSQLConf(SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD.key -> "1") { val path = dir.getCanonicalPath - val df = sqlContext.range(5).select('id as 'a, 'id as 'b, 'id as 'c).coalesce(1) + val df = spark.range(5).select('id as 'a, 'id as 'b, 'id as 'c).coalesce(1) df.write.partitionBy("b", "c").parquet(path) - checkAnswer(sqlContext.read.parquet(path), df) + checkAnswer(spark.read.parquet(path), df) } } } + + test("SPARK-15895 summary files in non-leaf partition directories") { + withTempPath { dir => + val path = dir.getCanonicalPath + + withSQLConf(ParquetOutputFormat.ENABLE_JOB_SUMMARY -> "true") { + spark.range(3).write.parquet(s"$path/p0=0/p1=0") + } + + val p0 = new File(path, "p0=0") + val p1 = new File(p0, "p1=0") + + // Builds the following directory layout by: + // + // 1. copying Parquet summary files we just wrote into `p0=0`, and + // 2. touching a dot-file `.dummy` under `p0=0`. + // + // + // +- p0=0 + // |- _metadata + // |- _common_metadata + // |- .dummy + // +- p1=0 + // |- _metadata + // |- _common_metadata + // |- part-00000.parquet + // |- part-00001.parquet + // +- ... + // + // The summary files and the dot-file under `p0=0` should not fail partition discovery. + + Files.copy(new File(p1, "_metadata"), new File(p0, "_metadata")) + Files.copy(new File(p1, "_common_metadata"), new File(p0, "_common_metadata")) + Files.touch(new File(p0, ".dummy")) + + checkAnswer(spark.read.parquet(s"$path"), Seq( + Row(0, 0, 0), + Row(1, 0, 0), + Row(2, 0, 0) + )) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala index 98333e58cada8..fa88019298a69 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala @@ -22,12 +22,12 @@ import org.apache.spark.sql.test.SharedSQLContext class ParquetProtobufCompatibilitySuite extends ParquetCompatibilityTest with SharedSQLContext { test("unannotated array of primitive type") { - checkAnswer(readResourceParquetFile("old-repeated-int.parquet"), Row(Seq(1, 2, 3))) + checkAnswer(readResourceParquetFile("test-data/old-repeated-int.parquet"), Row(Seq(1, 2, 3))) } test("unannotated array of struct") { checkAnswer( - readResourceParquetFile("old-repeated-message.parquet"), + readResourceParquetFile("test-data/old-repeated-message.parquet"), Row( Seq( Row("First inner", null, null), @@ -35,14 +35,14 @@ class ParquetProtobufCompatibilitySuite extends ParquetCompatibilityTest with Sh Row(null, null, "Third inner")))) checkAnswer( - readResourceParquetFile("proto-repeated-struct.parquet"), + readResourceParquetFile("test-data/proto-repeated-struct.parquet"), Row( Seq( Row("0 - 1", "0 - 2", "0 - 3"), Row("1 - 1", "1 - 2", "1 - 3")))) checkAnswer( - readResourceParquetFile("proto-struct-with-array-many.parquet"), + readResourceParquetFile("test-data/proto-struct-with-array-many.parquet"), Seq( Row( Seq( @@ -60,13 +60,13 @@ class ParquetProtobufCompatibilitySuite extends ParquetCompatibilityTest with Sh test("struct with unannotated array") { checkAnswer( - readResourceParquetFile("proto-struct-with-array.parquet"), + readResourceParquetFile("test-data/proto-struct-with-array.parquet"), Row(10, 9, Seq.empty, null, Row(9), Seq(Row(9), Row(10)))) } 
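Editorial aside, not part of the patch: the partition-discovery tests added above revolve around the `basePath` read option, which fixes the directory that partition columns are inferred relative to. A minimal sketch of the behaviour follows; the object name, the temp directory, and the sample data are illustrative assumptions.

```scala
// Illustrative sketch only -- not part of the patch.
import org.apache.spark.sql.SparkSession

object BasePathSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("basePath-sketch")
      .getOrCreate()
    import spark.implicits._

    val root = "/tmp/basepath-sketch" // hypothetical table root
    Seq((1, "foo", 100), (1, "bar", 200), (2, "foo", 300))
      .toDF("p1", "p2", "v")
      .write.mode("overwrite").partitionBy("p1", "p2").parquet(root)

    // Reading the root: p1 and p2 are discovered as partition columns.
    spark.read.parquet(root).printSchema() // v, p1, p2

    // Reading a single leaf directory: no partition columns are inferred.
    spark.read.parquet(s"$root/p1=1/p2=foo").printSchema() // v only

    // Overriding basePath restores discovery relative to the root,
    // even when only a subtree (or a glob under it) is loaded.
    spark.read.option("basePath", root).parquet(s"$root/p1=1/*").printSchema() // v, p1, p2

    spark.stop()
  }
}
```

Setting `basePath` to the table root keeps the partition columns even when only one partition subtree or a glob under it is loaded, which is what the globbing test above asserts.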
test("unannotated array of struct with unannotated array") { checkAnswer( - readResourceParquetFile("nested-array-struct.parquet"), + readResourceParquetFile("test-data/nested-array-struct.parquet"), Seq( Row(2, Seq(Row(1, Seq(Row(3))))), Row(5, Seq(Row(4, Seq(Row(6))))), @@ -75,7 +75,7 @@ class ParquetProtobufCompatibilitySuite extends ParquetCompatibilityTest with Sh test("unannotated array of string") { checkAnswer( - readResourceParquetFile("proto-repeated-string.parquet"), + readResourceParquetFile("test-data/proto-repeated-string.parquet"), Seq( Row(Seq("hello", "world")), Row(Seq("good", "bye")), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index f1e9726c38b02..3201f8e0dea1b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -19,7 +19,8 @@ package org.apache.spark.sql.execution.datasources.parquet import java.io.File -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.parquet.hadoop.ParquetOutputFormat import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} @@ -46,27 +47,55 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext test("appending") { val data = (0 until 10).map(i => (i, i.toString)) - sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") + spark.createDataFrame(data).toDF("c1", "c2").createOrReplaceTempView("tmp") // Query appends, don't test with both read modes. 
withParquetTable(data, "t", false) { sql("INSERT INTO TABLE t SELECT * FROM tmp") - checkAnswer(sqlContext.table("t"), (data ++ data).map(Row.fromTuple)) + checkAnswer(spark.table("t"), (data ++ data).map(Row.fromTuple)) } - sqlContext.sessionState.catalog.dropTable( + spark.sessionState.catalog.dropTable( TableIdentifier("tmp"), ignoreIfNotExists = true) } test("overwriting") { val data = (0 until 10).map(i => (i, i.toString)) - sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") + spark.createDataFrame(data).toDF("c1", "c2").createOrReplaceTempView("tmp") withParquetTable(data, "t") { sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp") - checkAnswer(sqlContext.table("t"), data.map(Row.fromTuple)) + checkAnswer(spark.table("t"), data.map(Row.fromTuple)) } - sqlContext.sessionState.catalog.dropTable( + spark.sessionState.catalog.dropTable( TableIdentifier("tmp"), ignoreIfNotExists = true) } + test("SPARK-15678: not use cache on overwrite") { + withTempDir { dir => + val path = dir.toString + spark.range(1000).write.mode("overwrite").parquet(path) + val df = spark.read.parquet(path).cache() + assert(df.count() == 1000) + spark.range(10).write.mode("overwrite").parquet(path) + assert(df.count() == 1000) + spark.catalog.refreshByPath(path) + assert(df.count() == 10) + assert(spark.read.parquet(path).count() == 10) + } + } + + test("SPARK-15678: not use cache on append") { + withTempDir { dir => + val path = dir.toString + spark.range(1000).write.mode("append").parquet(path) + val df = spark.read.parquet(path).cache() + assert(df.count() == 1000) + spark.range(10).write.mode("append").parquet(path) + assert(df.count() == 1000) + spark.catalog.refreshByPath(path) + assert(df.count() == 1010) + assert(spark.read.parquet(path).count() == 1010) + } + } + test("self-join") { // 4 rows, cells of column 1 of row 2 and row 4 are null val data = (1 to 4).map { i => @@ -128,9 +157,9 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext val schema = StructType(List(StructField("d", DecimalType(18, 0), false), StructField("time", TimestampType, false)).toArray) withTempPath { file => - val df = sqlContext.createDataFrame(sparkContext.parallelize(data), schema) + val df = spark.createDataFrame(sparkContext.parallelize(data), schema) df.write.parquet(file.getCanonicalPath) - val df2 = sqlContext.read.parquet(file.getCanonicalPath) + val df2 = spark.read.parquet(file.getCanonicalPath) checkAnswer(df2, df.collect().toSeq) } } @@ -139,22 +168,27 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext def testSchemaMerging(expectedColumnNumber: Int): Unit = { withTempDir { dir => val basePath = dir.getCanonicalPath - sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) - sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString) + spark.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) + spark.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString) // delete summary files, so if we don't merge part-files, one column will not be included. 
Utils.deleteRecursively(new File(basePath + "/foo=1/_metadata")) Utils.deleteRecursively(new File(basePath + "/foo=1/_common_metadata")) - assert(sqlContext.read.parquet(basePath).columns.length === expectedColumnNumber) + assert(spark.read.parquet(basePath).columns.length === expectedColumnNumber) } } - withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "true", - SQLConf.PARQUET_SCHEMA_RESPECT_SUMMARIES.key -> "true") { + withSQLConf( + SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "true", + SQLConf.PARQUET_SCHEMA_RESPECT_SUMMARIES.key -> "true", + ParquetOutputFormat.ENABLE_JOB_SUMMARY -> "true" + ) { testSchemaMerging(2) } - withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "true", - SQLConf.PARQUET_SCHEMA_RESPECT_SUMMARIES.key -> "false") { + withSQLConf( + SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "true", + SQLConf.PARQUET_SCHEMA_RESPECT_SUMMARIES.key -> "false" + ) { testSchemaMerging(3) } } @@ -163,9 +197,9 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext def testSchemaMerging(expectedColumnNumber: Int): Unit = { withTempDir { dir => val basePath = dir.getCanonicalPath - sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) - sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString) - assert(sqlContext.read.parquet(basePath).columns.length === expectedColumnNumber) + spark.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) + spark.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString) + assert(spark.read.parquet(basePath).columns.length === expectedColumnNumber) } } @@ -181,19 +215,19 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext test("SPARK-8990 DataFrameReader.parquet() should respect user specified options") { withTempPath { dir => val basePath = dir.getCanonicalPath - sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) - sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=a").toString) + spark.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) + spark.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=a").toString) // Disables the global SQL option for schema merging withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "false") { assertResult(2) { // Disables schema merging via data source option - sqlContext.read.option("mergeSchema", "false").parquet(basePath).columns.length + spark.read.option("mergeSchema", "false").parquet(basePath).columns.length } assertResult(3) { // Enables schema merging via data source option - sqlContext.read.option("mergeSchema", "true").parquet(basePath).columns.length + spark.read.option("mergeSchema", "true").parquet(basePath).columns.length } } } @@ -204,10 +238,10 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext val basePath = dir.getCanonicalPath val schema = StructType(Array(StructField("name", DecimalType(10, 5), false))) val rowRDD = sparkContext.parallelize(Array(Row(Decimal("67123.45")))) - val df = sqlContext.createDataFrame(rowRDD, schema) + val df = spark.createDataFrame(rowRDD, schema) df.write.parquet(basePath) - val decimal = sqlContext.read.parquet(basePath).first().getDecimal(0) + val decimal = spark.read.parquet(basePath).first().getDecimal(0) assert(Decimal("67123.45") === Decimal(decimal)) } } @@ -227,7 +261,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext 
withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING.key -> "true") { checkAnswer( - sqlContext.read.option("mergeSchema", "true").parquet(path), + spark.read.option("mergeSchema", "true").parquet(path), Seq( Row(Row(1, 1, null)), Row(Row(2, 2, null)), @@ -240,7 +274,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext test("SPARK-10301 requested schema clipping - same schema") { withTempPath { dir => val path = dir.getCanonicalPath - val df = sqlContext.range(1).selectExpr("NAMED_STRUCT('a', id, 'b', id + 1) AS s").coalesce(1) + val df = spark.range(1).selectExpr("NAMED_STRUCT('a', id, 'b', id + 1) AS s").coalesce(1) df.write.parquet(path) val userDefinedSchema = @@ -253,7 +287,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext nullable = true) checkAnswer( - sqlContext.read.schema(userDefinedSchema).parquet(path), + spark.read.schema(userDefinedSchema).parquet(path), Row(Row(0L, 1L))) } } @@ -261,12 +295,12 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext test("SPARK-11997 parquet with null partition values") { withTempPath { dir => val path = dir.getCanonicalPath - sqlContext.range(1, 3) + spark.range(1, 3) .selectExpr("if(id % 2 = 0, null, id) AS n", "id") .write.partitionBy("n").parquet(path) checkAnswer( - sqlContext.read.parquet(path).filter("n is null"), + spark.read.parquet(path).filter("n is null"), Row(2, null)) } } @@ -275,7 +309,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext ignore("SPARK-10301 requested schema clipping - schemas with disjoint sets of fields") { withTempPath { dir => val path = dir.getCanonicalPath - val df = sqlContext.range(1).selectExpr("NAMED_STRUCT('a', id, 'b', id + 1) AS s").coalesce(1) + val df = spark.range(1).selectExpr("NAMED_STRUCT('a', id, 'b', id + 1) AS s").coalesce(1) df.write.parquet(path) val userDefinedSchema = @@ -288,7 +322,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext nullable = true) checkAnswer( - sqlContext.read.schema(userDefinedSchema).parquet(path), + spark.read.schema(userDefinedSchema).parquet(path), Row(Row(null, null))) } } @@ -296,7 +330,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext test("SPARK-10301 requested schema clipping - requested schema contains physical schema") { withTempPath { dir => val path = dir.getCanonicalPath - val df = sqlContext.range(1).selectExpr("NAMED_STRUCT('a', id, 'b', id + 1) AS s").coalesce(1) + val df = spark.range(1).selectExpr("NAMED_STRUCT('a', id, 'b', id + 1) AS s").coalesce(1) df.write.parquet(path) val userDefinedSchema = @@ -311,13 +345,13 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext nullable = true) checkAnswer( - sqlContext.read.schema(userDefinedSchema).parquet(path), + spark.read.schema(userDefinedSchema).parquet(path), Row(Row(0L, 1L, null, null))) } withTempPath { dir => val path = dir.getCanonicalPath - val df = sqlContext.range(1).selectExpr("NAMED_STRUCT('a', id, 'd', id + 3) AS s").coalesce(1) + val df = spark.range(1).selectExpr("NAMED_STRUCT('a', id, 'd', id + 3) AS s").coalesce(1) df.write.parquet(path) val userDefinedSchema = @@ -332,7 +366,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext nullable = true) checkAnswer( - sqlContext.read.schema(userDefinedSchema).parquet(path), + spark.read.schema(userDefinedSchema).parquet(path), Row(Row(0L, null, null, 3L))) } } @@ -340,7 +374,7 @@ class 
ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext test("SPARK-10301 requested schema clipping - physical schema contains requested schema") { withTempPath { dir => val path = dir.getCanonicalPath - val df = sqlContext + val df = spark .range(1) .selectExpr("NAMED_STRUCT('a', id, 'b', id + 1, 'c', id + 2, 'd', id + 3) AS s") .coalesce(1) @@ -357,13 +391,13 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext nullable = true) checkAnswer( - sqlContext.read.schema(userDefinedSchema).parquet(path), + spark.read.schema(userDefinedSchema).parquet(path), Row(Row(0L, 1L))) } withTempPath { dir => val path = dir.getCanonicalPath - val df = sqlContext + val df = spark .range(1) .selectExpr("NAMED_STRUCT('a', id, 'b', id + 1, 'c', id + 2, 'd', id + 3) AS s") .coalesce(1) @@ -380,7 +414,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext nullable = true) checkAnswer( - sqlContext.read.schema(userDefinedSchema).parquet(path), + spark.read.schema(userDefinedSchema).parquet(path), Row(Row(0L, 3L))) } } @@ -388,7 +422,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext test("SPARK-10301 requested schema clipping - schemas overlap but don't contain each other") { withTempPath { dir => val path = dir.getCanonicalPath - val df = sqlContext + val df = spark .range(1) .selectExpr("NAMED_STRUCT('a', id, 'b', id + 1, 'c', id + 2) AS s") .coalesce(1) @@ -406,7 +440,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext nullable = true) checkAnswer( - sqlContext.read.schema(userDefinedSchema).parquet(path), + spark.read.schema(userDefinedSchema).parquet(path), Row(Row(1L, 2L, null))) } } @@ -415,7 +449,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext withTempPath { dir => val path = dir.getCanonicalPath - val df = sqlContext + val df = spark .range(1) .selectExpr("NAMED_STRUCT('a', ARRAY(NAMED_STRUCT('b', id, 'c', id))) AS s") .coalesce(1) @@ -436,7 +470,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext nullable = true) checkAnswer( - sqlContext.read.schema(userDefinedSchema).parquet(path), + spark.read.schema(userDefinedSchema).parquet(path), Row(Row(Seq(Row(0, null))))) } } @@ -445,12 +479,12 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext withTempPath { dir => val path = dir.getCanonicalPath - val df1 = sqlContext + val df1 = spark .range(1) .selectExpr("NAMED_STRUCT('a', id, 'b', id + 1, 'c', id + 2) AS s") .coalesce(1) - val df2 = sqlContext + val df2 = spark .range(1, 2) .selectExpr("NAMED_STRUCT('c', id + 2, 'b', id + 1, 'd', id + 3) AS s") .coalesce(1) @@ -467,7 +501,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext nullable = true) checkAnswer( - sqlContext.read.schema(userDefinedSchema).parquet(path), + spark.read.schema(userDefinedSchema).parquet(path), Seq( Row(Row(0, 1, null)), Row(Row(null, 2, 4)))) @@ -478,12 +512,12 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext withTempPath { dir => val path = dir.getCanonicalPath - val df1 = sqlContext + val df1 = spark .range(1) .selectExpr("NAMED_STRUCT('a', id, 'c', id + 2) AS s") .coalesce(1) - val df2 = sqlContext + val df2 = spark .range(1, 2) .selectExpr("NAMED_STRUCT('a', id, 'b', id + 1, 'c', id + 2) AS s") .coalesce(1) @@ -492,7 +526,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext 
df2.write.mode(SaveMode.Append).parquet(path) checkAnswer( - sqlContext + spark .read .option("mergeSchema", "true") .parquet(path) @@ -507,7 +541,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext withTempPath { dir => val path = dir.getCanonicalPath - val df = sqlContext + val df = spark .range(1) .selectExpr( """NAMED_STRUCT( @@ -532,7 +566,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext nullable = true) checkAnswer( - sqlContext.read.schema(userDefinedSchema).parquet(path), + spark.read.schema(userDefinedSchema).parquet(path), Row(Row(NestedStruct(1, 2L, 3.5D)))) } } @@ -540,7 +574,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext test("expand UDT in StructType") { val schema = new StructType().add("n", new NestedStructUDT, nullable = true) val expected = new StructType().add("n", new NestedStructUDT().sqlType, nullable = true) - assert(CatalystReadSupport.expandUDT(schema) === expected) + assert(ParquetReadSupport.expandUDT(schema) === expected) } test("expand UDT in ArrayType") { @@ -558,7 +592,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext containsNull = false), nullable = true) - assert(CatalystReadSupport.expandUDT(schema) === expected) + assert(ParquetReadSupport.expandUDT(schema) === expected) } test("expand UDT in MapType") { @@ -578,34 +612,24 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext valueContainsNull = false), nullable = true) - assert(CatalystReadSupport.expandUDT(schema) === expected) - } - - test("read/write wide table") { - withTempPath { dir => - val path = dir.getCanonicalPath - - val df = sqlContext.range(1000).select(Seq.tabulate(1000) {i => ('id + i).as(s"c$i")} : _*) - df.write.mode(SaveMode.Overwrite).parquet(path) - checkAnswer(sqlContext.read.parquet(path), df) - } + assert(ParquetReadSupport.expandUDT(schema) === expected) } test("returning batch for wide table") { - withSQLConf("spark.sql.codegen.maxFields" -> "100") { + withSQLConf(SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key -> "10") { withTempPath { dir => val path = dir.getCanonicalPath - val df = sqlContext.range(100).select(Seq.tabulate(110) {i => ('id + i).as(s"c$i")} : _*) + val df = spark.range(10).select(Seq.tabulate(11) {i => ('id + i).as(s"c$i")} : _*) df.write.mode(SaveMode.Overwrite).parquet(path) // do not return batch, because whole stage codegen is disabled for wide tables (more columns than spark.sql.codegen.maxFields) - val df2 = sqlContext.read.parquet(path) + val df2 = spark.read.parquet(path) assert(df2.queryExecution.sparkPlan.find(_.isInstanceOf[BatchedDataSourceScanExec]).isEmpty, "Should not return batch") checkAnswer(df2, df) // return batch - val columns = Seq.tabulate(90) {i => s"c$i"} + val columns = Seq.tabulate(9) {i => s"c$i"} val df3 = df2.selectExpr(columns : _*) assert( df3.queryExecution.sparkPlan.find(_.isInstanceOf[BatchedDataSourceScanExec]).isDefined, @@ -614,6 +638,60 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext } } } + + test("SPARK-15719: disable writing summary files by default") { + withTempPath { dir => + val path = dir.getCanonicalPath + spark.range(3).write.parquet(path) + + val fs = FileSystem.get(sparkContext.hadoopConfiguration) + val files = fs.listFiles(new Path(path), true) + + while (files.hasNext) { + val file = files.next + assert(!file.getPath.getName.contains("_metadata")) + } + } + } + + test("SPARK-15804: write out the metadata to parquet file") { + val df = Seq((1,
"abc"), (2, "hello")).toDF("a", "b") + val md = new MetadataBuilder().putString("key", "value").build() + val dfWithmeta = df.select('a, 'b.as("b", md)) + + withTempPath { dir => + val path = dir.getCanonicalPath + dfWithmeta.write.parquet(path) + + readParquetFile(path) { df => + assert(df.schema.last.metadata.getString("key") == "value") + } + } + } + + test("SPARK-16632: read Parquet int32 as ByteType and ShortType") { + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true") { + withTempPath { dir => + val path = dir.getCanonicalPath + + // When being written to Parquet, `TINYINT` and `SMALLINT` should be converted into + // `int32 (INT_8)` and `int32 (INT_16)` respectively. However, Hive doesn't add the `INT_8` + // and `INT_16` annotation properly (HIVE-14294). Thus, when reading files written by Hive + // using Spark with the vectorized Parquet reader enabled, we may hit error due to type + // mismatch. + // + // Here we are simulating Hive's behavior by writing a single `INT` field and then read it + // back as `TINYINT` and `SMALLINT` in Spark to verify this issue. + Seq(1).toDF("f").write.parquet(path) + + val withByteField = new StructType().add("f", ByteType) + checkAnswer(spark.read.schema(withByteField).parquet(path), Row(1: Byte)) + + val withShortField = new StructType().add("f", ShortType) + checkAnswer(spark.read.schema(withShortField).parquet(path), Row(1: Short)) + } + } + } } object TestingUDT { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala index cef541f0444be..487d7a7e5ac88 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala @@ -21,9 +21,9 @@ import java.io.File import scala.collection.JavaConverters._ import scala.util.Try -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.SparkConf import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession import org.apache.spark.util.{Benchmark, Utils} /** @@ -34,12 +34,16 @@ import org.apache.spark.util.{Benchmark, Utils} object ParquetReadBenchmark { val conf = new SparkConf() conf.set("spark.sql.parquet.compression.codec", "snappy") - val sc = new SparkContext("local[1]", "test-sql-context", conf) - val sqlContext = new SQLContext(sc) + + val spark = SparkSession.builder + .master("local[1]") + .appName("test-sql-context") + .config(conf) + .getOrCreate() // Set default configs. Individual cases will change them if necessary. 
- sqlContext.conf.setConfString(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, "true") - sqlContext.conf.setConfString(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true") + spark.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, "true") + spark.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true") def withTempPath(f: File => Unit): Unit = { val path = Utils.createTempDir() @@ -48,17 +52,17 @@ object ParquetReadBenchmark { } def withTempTable(tableNames: String*)(f: => Unit): Unit = { - try f finally tableNames.foreach(sqlContext.dropTempTable) + try f finally tableNames.foreach(spark.catalog.dropTempView) } def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { val (keys, values) = pairs.unzip - val currentValues = keys.map(key => Try(sqlContext.conf.getConfString(key)).toOption) - (keys, values).zipped.foreach(sqlContext.conf.setConfString) + val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption) + (keys, values).zipped.foreach(spark.conf.set) try f finally { keys.zip(currentValues).foreach { - case (key, Some(value)) => sqlContext.conf.setConfString(key, value) - case (key, None) => sqlContext.conf.unsetConf(key) + case (key, Some(value)) => spark.conf.set(key, value) + case (key, None) => spark.conf.unset(key) } } } @@ -71,18 +75,18 @@ object ParquetReadBenchmark { withTempPath { dir => withTempTable("t1", "tempTable") { - sqlContext.range(values).registerTempTable("t1") - sqlContext.sql("select cast(id as INT) as id from t1") + spark.range(values).createOrReplaceTempView("t1") + spark.sql("select cast(id as INT) as id from t1") .write.parquet(dir.getCanonicalPath) - sqlContext.read.parquet(dir.getCanonicalPath).registerTempTable("tempTable") + spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tempTable") sqlBenchmark.addCase("SQL Parquet Vectorized") { iter => - sqlContext.sql("select sum(id) from tempTable").collect() + spark.sql("select sum(id) from tempTable").collect() } sqlBenchmark.addCase("SQL Parquet MR") { iter => withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { - sqlContext.sql("select sum(id) from tempTable").collect() + spark.sql("select sum(id) from tempTable").collect() } } @@ -155,20 +159,20 @@ object ParquetReadBenchmark { def intStringScanBenchmark(values: Int): Unit = { withTempPath { dir => withTempTable("t1", "tempTable") { - sqlContext.range(values).registerTempTable("t1") - sqlContext.sql("select cast(id as INT) as c1, cast(id as STRING) as c2 from t1") + spark.range(values).createOrReplaceTempView("t1") + spark.sql("select cast(id as INT) as c1, cast(id as STRING) as c2 from t1") .write.parquet(dir.getCanonicalPath) - sqlContext.read.parquet(dir.getCanonicalPath).registerTempTable("tempTable") + spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tempTable") val benchmark = new Benchmark("Int and String Scan", values) benchmark.addCase("SQL Parquet Vectorized") { iter => - sqlContext.sql("select sum(c1), sum(length(c2)) from tempTable").collect + spark.sql("select sum(c1), sum(length(c2)) from tempTable").collect } benchmark.addCase("SQL Parquet MR") { iter => withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { - sqlContext.sql("select sum(c1), sum(length(c2)) from tempTable").collect + spark.sql("select sum(c1), sum(length(c2)) from tempTable").collect } } @@ -189,20 +193,20 @@ object ParquetReadBenchmark { def stringDictionaryScanBenchmark(values: Int): Unit = { withTempPath { dir => withTempTable("t1", "tempTable") { - 
sqlContext.range(values).registerTempTable("t1") - sqlContext.sql("select cast((id % 200) + 10000 as STRING) as c1 from t1") + spark.range(values).createOrReplaceTempView("t1") + spark.sql("select cast((id % 200) + 10000 as STRING) as c1 from t1") .write.parquet(dir.getCanonicalPath) - sqlContext.read.parquet(dir.getCanonicalPath).registerTempTable("tempTable") + spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tempTable") val benchmark = new Benchmark("String Dictionary", values) benchmark.addCase("SQL Parquet Vectorized") { iter => - sqlContext.sql("select sum(length(c1)) from tempTable").collect + spark.sql("select sum(length(c1)) from tempTable").collect } benchmark.addCase("SQL Parquet MR") { iter => withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { - sqlContext.sql("select sum(length(c1)) from tempTable").collect + spark.sql("select sum(length(c1)) from tempTable").collect } } @@ -221,23 +225,23 @@ object ParquetReadBenchmark { def partitionTableScanBenchmark(values: Int): Unit = { withTempPath { dir => withTempTable("t1", "tempTable") { - sqlContext.range(values).registerTempTable("t1") - sqlContext.sql("select id % 2 as p, cast(id as INT) as id from t1") + spark.range(values).createOrReplaceTempView("t1") + spark.sql("select id % 2 as p, cast(id as INT) as id from t1") .write.partitionBy("p").parquet(dir.getCanonicalPath) - sqlContext.read.parquet(dir.getCanonicalPath).registerTempTable("tempTable") + spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tempTable") val benchmark = new Benchmark("Partitioned Table", values) benchmark.addCase("Read data column") { iter => - sqlContext.sql("select sum(id) from tempTable").collect + spark.sql("select sum(id) from tempTable").collect } benchmark.addCase("Read partition column") { iter => - sqlContext.sql("select sum(p) from tempTable").collect + spark.sql("select sum(p) from tempTable").collect } benchmark.addCase("Read both columns") { iter => - sqlContext.sql("select sum(p), sum(id) from tempTable").collect + spark.sql("select sum(p), sum(id) from tempTable").collect } /* @@ -256,16 +260,16 @@ object ParquetReadBenchmark { def stringWithNullsScanBenchmark(values: Int, fractionOfNulls: Double): Unit = { withTempPath { dir => withTempTable("t1", "tempTable") { - sqlContext.range(values).registerTempTable("t1") - sqlContext.sql(s"select IF(rand(1) < $fractionOfNulls, NULL, cast(id as STRING)) as c1, " + + spark.range(values).createOrReplaceTempView("t1") + spark.sql(s"select IF(rand(1) < $fractionOfNulls, NULL, cast(id as STRING)) as c1, " + s"IF(rand(2) < $fractionOfNulls, NULL, cast(id as STRING)) as c2 from t1") .write.parquet(dir.getCanonicalPath) - sqlContext.read.parquet(dir.getCanonicalPath).registerTempTable("tempTable") + spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tempTable") val benchmark = new Benchmark("String with Nulls Scan", values) benchmark.addCase("SQL Parquet Vectorized") { iter => - sqlContext.sql("select sum(length(c2)) from tempTable where c1 is " + + spark.sql("select sum(length(c2)) from tempTable where c1 is " + "not NULL and c2 is not NULL").collect() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala index 90e3d50714ef3..51bb236fe8441 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala @@ -54,7 +54,7 @@ abstract class ParquetSchemaTest extends ParquetTest with SharedSQLContext { binaryAsString: Boolean, int96AsTimestamp: Boolean, writeLegacyParquetFormat: Boolean): Unit = { - val converter = new CatalystSchemaConverter( + val converter = new ParquetSchemaConverter( assumeBinaryIsString = binaryAsString, assumeInt96IsTimestamp = int96AsTimestamp, writeLegacyParquetFormat = writeLegacyParquetFormat) @@ -78,7 +78,7 @@ abstract class ParquetSchemaTest extends ParquetTest with SharedSQLContext { binaryAsString: Boolean, int96AsTimestamp: Boolean, writeLegacyParquetFormat: Boolean): Unit = { - val converter = new CatalystSchemaConverter( + val converter = new ParquetSchemaConverter( assumeBinaryIsString = binaryAsString, assumeInt96IsTimestamp = int96AsTimestamp, writeLegacyParquetFormat = writeLegacyParquetFormat) @@ -375,7 +375,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { StructField("lowerCase", StringType), StructField("UPPERCase", DoubleType, nullable = false)))) { - ParquetRelation.mergeMetastoreParquetSchema( + ParquetFileFormat.mergeMetastoreParquetSchema( StructType(Seq( StructField("lowercase", StringType), StructField("uppercase", DoubleType, nullable = false))), @@ -390,7 +390,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { StructType(Seq( StructField("UPPERCase", DoubleType, nullable = false)))) { - ParquetRelation.mergeMetastoreParquetSchema( + ParquetFileFormat.mergeMetastoreParquetSchema( StructType(Seq( StructField("uppercase", DoubleType, nullable = false))), @@ -401,7 +401,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { // Metastore schema contains additional non-nullable fields. assert(intercept[Throwable] { - ParquetRelation.mergeMetastoreParquetSchema( + ParquetFileFormat.mergeMetastoreParquetSchema( StructType(Seq( StructField("uppercase", DoubleType, nullable = false), StructField("lowerCase", BinaryType, nullable = false))), @@ -412,7 +412,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { // Conflicting non-nullable field names intercept[Throwable] { - ParquetRelation.mergeMetastoreParquetSchema( + ParquetFileFormat.mergeMetastoreParquetSchema( StructType(Seq(StructField("lower", StringType, nullable = false))), StructType(Seq(StructField("lowerCase", BinaryType)))) } @@ -426,7 +426,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { StructField("firstField", StringType, nullable = true), StructField("secondField", StringType, nullable = true), StructField("thirdfield", StringType, nullable = true)))) { - ParquetRelation.mergeMetastoreParquetSchema( + ParquetFileFormat.mergeMetastoreParquetSchema( StructType(Seq( StructField("firstfield", StringType, nullable = true), StructField("secondfield", StringType, nullable = true), @@ -439,7 +439,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { // Merge should fail if the Metastore contains any additional fields that are not // nullable. 
assert(intercept[Throwable] { - ParquetRelation.mergeMetastoreParquetSchema( + ParquetFileFormat.mergeMetastoreParquetSchema( StructType(Seq( StructField("firstfield", StringType, nullable = true), StructField("secondfield", StringType, nullable = true), @@ -451,31 +451,18 @@ class ParquetSchemaSuite extends ParquetSchemaTest { } test("schema merging failure error message") { - withTempPath { dir => - val path = dir.getCanonicalPath - sqlContext.range(3).write.parquet(s"$path/p=1") - sqlContext.range(3).selectExpr("CAST(id AS INT) AS id").write.parquet(s"$path/p=2") - - val message = intercept[SparkException] { - sqlContext.read.option("mergeSchema", "true").parquet(path).schema - }.getMessage - - assert(message.contains("Failed merging schema of file")) - } + import testImplicits._ - // test for second merging (after read Parquet schema in parallel done) withTempPath { dir => val path = dir.getCanonicalPath - sqlContext.range(3).write.parquet(s"$path/p=1") - sqlContext.range(3).selectExpr("CAST(id AS INT) AS id").write.parquet(s"$path/p=2") - - sqlContext.sparkContext.conf.set("spark.default.parallelism", "20") + spark.range(3).write.parquet(s"$path/p=1") + spark.range(3).select('id cast IntegerType as 'id).write.parquet(s"$path/p=2") val message = intercept[SparkException] { - sqlContext.read.option("mergeSchema", "true").parquet(path).schema + spark.read.option("mergeSchema", "true").parquet(path).schema }.getMessage - assert(message.contains("Failed merging schema:")) + assert(message.contains("Failed merging schema")) } } @@ -1067,7 +1054,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { expectedSchema: String): Unit = { test(s"Clipping - $testName") { val expected = MessageTypeParser.parseMessageType(expectedSchema) - val actual = CatalystReadSupport.clipParquetSchema( + val actual = ParquetReadSupport.clipParquetSchema( MessageTypeParser.parseMessageType(parquetSchema), catalystSchema) try { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala index e8c524e9e550d..85efca3c4b24d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala @@ -52,7 +52,7 @@ private[sql] trait ParquetTest extends SQLTestUtils { (true :: false :: Nil).foreach { vectorized => if (!vectorized || testVectorized) { withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized.toString) { - f(sqlContext.read.parquet(path.toString)) + f(spark.read.parquet(path.toString)) } } } @@ -66,7 +66,7 @@ private[sql] trait ParquetTest extends SQLTestUtils { (data: Seq[T]) (f: String => Unit): Unit = { withTempPath { file => - sqlContext.createDataFrame(data).write.parquet(file.getCanonicalPath) + spark.createDataFrame(data).write.parquet(file.getCanonicalPath) f(file.getCanonicalPath) } } @@ -90,14 +90,14 @@ private[sql] trait ParquetTest extends SQLTestUtils { (data: Seq[T], tableName: String, testVectorized: Boolean = true) (f: => Unit): Unit = { withParquetDataFrame(data, testVectorized) { df => - sqlContext.registerDataFrameAsTable(df, tableName) - withTempTable(tableName)(f) + df.createOrReplaceTempView(tableName) + withTempView(tableName)(f) } } protected def makeParquetFile[T <: Product: ClassTag: TypeTag]( data: Seq[T], path: File): Unit = { - 
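
The rewritten "schema merging failure error message" test above exercises mergeSchema over Parquet parts whose shared column has conflicting types. A stripped-down sketch of the same failure mode, assuming an active SparkSession and a scratch directory path, both placeholders here:

import scala.util.{Failure, Try}
import org.apache.spark.SparkException
import org.apache.spark.sql.SparkSession

object MergeSchemaFailureSketch {
  def demo(spark: SparkSession, path: String): Unit = {
    // p=1 stores id as LongType, p=2 as IntegerType, so merging the Parquet
    // footers with mergeSchema=true fails on the conflicting column type.
    spark.range(3).write.parquet(s"$path/p=1")
    spark.range(3).selectExpr("CAST(id AS INT) AS id").write.parquet(s"$path/p=2")

    val message = Try(spark.read.option("mergeSchema", "true").parquet(path).schema) match {
      case Failure(e: SparkException) => e.getMessage
      case other => sys.error(s"expected SparkException, got $other")
    }
    assert(message.contains("Failed merging schema"))
  }
}
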
sqlContext.createDataFrame(data).write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath) + spark.createDataFrame(data).write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath) } protected def makeParquetFile[T <: Product: ClassTag: TypeTag]( @@ -124,8 +124,8 @@ private[sql] trait ParquetTest extends SQLTestUtils { protected def writeMetadata( schema: StructType, path: Path, configuration: Configuration): Unit = { - val parquetSchema = new CatalystSchemaConverter().convert(schema) - val extraMetadata = Map(CatalystReadSupport.SPARK_METADATA_KEY -> schema.json).asJava + val parquetSchema = new ParquetSchemaConverter().convert(schema) + val extraMetadata = Map(ParquetReadSupport.SPARK_METADATA_KEY -> schema.json).asJava val createdBy = s"Apache Spark ${org.apache.spark.SPARK_VERSION}" val fileMetadata = new FileMetaData(parquetSchema, extraMetadata, createdBy) val parquetMetadata = new ParquetMetadata(fileMetadata, Seq.empty[BlockMetaData].asJava) @@ -173,6 +173,6 @@ private[sql] trait ParquetTest extends SQLTestUtils { protected def readResourceParquetFile(name: String): DataFrame = { val url = Thread.currentThread().getContextClassLoader.getResource(name) - sqlContext.read.parquet(url.toString) + spark.read.parquet(url.toString) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala index 88a3d878f97fe..4157a5b46dc42 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala @@ -23,8 +23,8 @@ import org.apache.spark.sql.test.SharedSQLContext class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest with SharedSQLContext { import ParquetCompatibilityTest._ - private val parquetFilePath = - Thread.currentThread().getContextClassLoader.getResource("parquet-thrift-compat.snappy.parquet") + private val parquetFilePath = Thread.currentThread().getContextClassLoader.getResource( + "test-data/parquet-thrift-compat.snappy.parquet") test("Read Parquet file generated by parquet-thrift") { logInfo( @@ -32,7 +32,7 @@ class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest with Shar |${readParquetSchema(parquetFilePath.toString)} """.stripMargin) - checkAnswer(sqlContext.read.parquet(parquetFilePath.toString), (0 until 10).map { i => + checkAnswer(spark.read.parquet(parquetFilePath.toString), (0 until 10).map { i => val suits = Array("SPADES", "HEARTS", "DIAMONDS", "CLUBS") val nonNullablePrimitiveValues = Seq( @@ -139,7 +139,7 @@ class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest with Shar logParquetSchema(path) checkAnswer( - sqlContext.read.parquet(path), + spark.read.parquet(path), Seq( Row(Seq(Seq(0, 1), Seq(2, 3))), Row(Seq(Seq(4, 5), Seq(6, 7))))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/TPCDSBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/TPCDSBenchmark.scala deleted file mode 100644 index fd562652974ec..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/TPCDSBenchmark.scala +++ /dev/null @@ -1,1225 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.datasources.parquet - -import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation -import org.apache.spark.util.Benchmark - -/** - * Benchmark to measure TPCDS query performance. - * To run this: - * spark-submit --class --jars - */ -object TPCDSBenchmark { - val conf = new SparkConf() - conf.set("spark.sql.parquet.compression.codec", "snappy") - conf.set("spark.sql.shuffle.partitions", "4") - conf.set("spark.driver.memory", "3g") - conf.set("spark.executor.memory", "3g") - conf.set("spark.sql.autoBroadcastJoinThreshold", (20 * 1024 * 1024).toString) - - val sc = new SparkContext("local[1]", "test-sql-context", conf) - val sqlContext = new SQLContext(sc) - - // These queries a subset of the TPCDS benchmark queries and are taken from - // https://github.com/databricks/spark-sql-perf/blob/master/src/main/scala/com/databricks/spark/ - // sql/perf/tpcds/ImpalaKitQueries.scala - val tpcds = Seq( - ("q19", """ - |select - | i_brand_id, - | i_brand, - | i_manufact_id, - | i_manufact, - | sum(ss_ext_sales_price) ext_price - |from - | store_sales - | join item on (store_sales.ss_item_sk = item.i_item_sk) - | join store on (store_sales.ss_store_sk = store.s_store_sk) - | join date_dim on (store_sales.ss_sold_date_sk = date_dim.d_date_sk) - | join customer on (store_sales.ss_customer_sk = customer.c_customer_sk) - | join customer_address on - | (customer.c_current_addr_sk = customer_address.ca_address_sk) - |where - | ss_sold_date_sk between 2451484 and 2451513 - | and d_moy = 11 - | and d_year = 1999 - | and i_manager_id = 7 - | and substr(ca_zip, 1, 5) <> substr(s_zip, 1, 5) - |group by - | i_brand, - | i_brand_id, - | i_manufact_id, - | i_manufact - |order by - | ext_price desc, - | i_brand, - | i_brand_id, - | i_manufact_id, - | i_manufact - |limit 100 - """.stripMargin), - - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_73-b02 on Mac OS X 10.11.4 - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - - TPCDS Snappy (scale = 5): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - q19 1710 / 1858 8.7 114.5 1.0X - */ - - ("q27", """ - |select - | i_item_id, - | s_state, - | avg(ss_quantity) agg1, - | avg(ss_list_price) agg2, - | avg(ss_coupon_amt) agg3, - | avg(ss_sales_price) agg4 - |from - | store_sales - | join store on (store_sales.ss_store_sk = store.s_store_sk) - | join customer_demographics on - | (store_sales.ss_cdemo_sk = customer_demographics.cd_demo_sk) - | join item on (store_sales.ss_item_sk = item.i_item_sk) - | join date_dim on (store_sales.ss_sold_date_sk = date_dim.d_date_sk) - 
|where - | ss_sold_date_sk between 2450815 and 2451179 -- partition key filter - | and d_year = 1998 - | and cd_gender = 'F' - | and cd_marital_status = 'W' - | and cd_education_status = 'Primary' - | and s_state in ('WI', 'CA', 'TX', 'FL', 'WA', 'TN') - |group by - | i_item_id, - | s_state - |order by - | i_item_id, - | s_state - |limit 100 - """.stripMargin), - - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_73-b02 on Mac OS X 10.11.4 - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - - TPCDS Snappy (scale = 5): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - q27 2016 / 2180 8.2 122.6 1.0X - */ - - ("q3", """ - |select - | dt.d_year, - | item.i_brand_id brand_id, - | item.i_brand brand, - | sum(ss_ext_sales_price) sum_agg - |from - | store_sales - | join item on (store_sales.ss_item_sk = item.i_item_sk) - | join date_dim dt on (dt.d_date_sk = store_sales.ss_sold_date_sk) - |where - | item.i_manufact_id = 436 - | and dt.d_moy = 12 - | and (ss_sold_date_sk between 2451149 and 2451179 - | or ss_sold_date_sk between 2451514 and 2451544 - | or ss_sold_date_sk between 2451880 and 2451910 - | or ss_sold_date_sk between 2452245 and 2452275 - | or ss_sold_date_sk between 2452610 and 2452640) - |group by - | d_year, - | item.i_brand, - | item.i_brand_id - |order by - | d_year, - | sum_agg desc, - | brand_id - |limit 100 - """.stripMargin), - - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_73-b02 on Mac OS X 10.11.4 - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - - TPCDS Snappy (scale = 5): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - q3 1073 / 1140 13.5 73.9 1.0X - */ - - ("q34", """ - |select - | c_last_name, - | c_first_name, - | c_salutation, - | c_preferred_cust_flag, - | ss_ticket_number, - | cnt - |from - | (select - | ss_ticket_number, - | ss_customer_sk, - | count(*) cnt - | from - | store_sales - | join household_demographics on - | (store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk) - | join store on (store_sales.ss_store_sk = store.s_store_sk) - | join date_dim on (store_sales.ss_sold_date_sk = date_dim.d_date_sk) - | where - | date_dim.d_year in (1998, 1998 + 1, 1998 + 2) - | and (date_dim.d_dom between 1 and 3 - | or date_dim.d_dom between 25 and 28) - | and (household_demographics.hd_buy_potential = '>10000' - | or household_demographics.hd_buy_potential = 'unknown') - | and household_demographics.hd_vehicle_count > 0 - | and (case when household_demographics.hd_vehicle_count > 0 then - | household_demographics.hd_dep_count / household_demographics.hd_vehicle_count - | else null end) > 1.2 - | and ss_sold_date_sk between 2450816 and 2451910 -- partition key filter - | group by - | ss_ticket_number, - | ss_customer_sk - | ) dn - |join customer on (dn.ss_customer_sk = customer.c_customer_sk) - |where - | cnt between 15 and 20 - |order by - | c_last_name, - | c_first_name, - | c_salutation, - | c_preferred_cust_flag desc, - | ss_ticket_number, - | cnt - |limit 1000 - """.stripMargin), - - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_73-b02 on Mac OS X 10.11.4 - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - - TPCDS Snappy (scale = 5): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - q34 1482 / 1734 10.0 100.4 1.0X - */ - - ("q42", """ - |select - | d_year, - | i_category_id, - | i_category, - | 
sum(ss_ext_sales_price) as total_price - |from - | store_sales - | join item on (store_sales.ss_item_sk = item.i_item_sk) - | join date_dim dt on (dt.d_date_sk = store_sales.ss_sold_date_sk) - |where - | item.i_manager_id = 1 - | and dt.d_moy = 12 - | and dt.d_year = 1998 - | and ss_sold_date_sk between 2451149 and 2451179 -- partition key filter - |group by - | d_year, - | i_category_id, - | i_category - |order by - | total_price desc, - | d_year, - | i_category_id, - | i_category - |limit 100 - """.stripMargin), - - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_73-b02 on Mac OS X 10.11.4 - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - - TPCDS Snappy (scale = 5): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - q42 1125 / 1357 12.9 77.4 1.0X - */ - - ("q43", """ - |select - | s_store_name, - | s_store_id, - | sum(case when (d_day_name = 'Sunday') then ss_sales_price else null end) sun_sales, - | sum(case when (d_day_name = 'Monday') then ss_sales_price else null end) mon_sales, - | sum(case when (d_day_name = 'Tuesday') then - | ss_sales_price else null end) tue_sales, - | sum(case when (d_day_name = 'Wednesday') then - | ss_sales_price else null end) wed_sales, - | sum(case when (d_day_name = 'Thursday') then - | ss_sales_price else null end) thu_sales, - | sum(case when (d_day_name = 'Friday') then ss_sales_price else null end) fri_sales, - | sum(case when (d_day_name = 'Saturday') then - | ss_sales_price else null end) sat_sales - |from - | store_sales - | join store on (store_sales.ss_store_sk = store.s_store_sk) - | join date_dim on (store_sales.ss_sold_date_sk = date_dim.d_date_sk) - |where - | s_gmt_offset = -5 - | and d_year = 1998 - | and ss_sold_date_sk between 2450816 and 2451179 -- partition key filter - |group by - | s_store_name, - | s_store_id - |order by - | s_store_name, - | s_store_id, - | sun_sales, - | mon_sales, - | tue_sales, - | wed_sales, - | thu_sales, - | fri_sales, - | sat_sales - |limit 100 - """.stripMargin), - - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_73-b02 on Mac OS X 10.11.4 - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - - TPCDS Snappy (scale = 5): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - q43 1681 / 1985 8.6 116.1 1.0X - */ - - ("q46", """ - |select - | c_last_name, - | c_first_name, - | ca_city, - | bought_city, - | ss_ticket_number, - | amt, - | profit - |from - | (select - | ss_ticket_number, - | ss_customer_sk, - | ca_city bought_city, - | sum(ss_coupon_amt) amt, - | sum(ss_net_profit) profit - | from - | store_sales - | join store on (store_sales.ss_store_sk = store.s_store_sk) - | join household_demographics on - | (store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk) - | join date_dim on (store_sales.ss_sold_date_sk = date_dim.d_date_sk) - | join customer_address on - | (store_sales.ss_addr_sk = customer_address.ca_address_sk) - | where - | store.s_city in ('Midway', 'Concord', 'Spring Hill', 'Brownsville', 'Greenville') - | and (household_demographics.hd_dep_count = 5 - | or household_demographics.hd_vehicle_count = 3) - | and date_dim.d_dow in (6, 0) - | and date_dim.d_year in (1999, 1999 + 1, 1999 + 2) - | group by - | ss_ticket_number, - | ss_customer_sk, - | ss_addr_sk, - | ca_city - | ) dn - | join customer on (dn.ss_customer_sk = customer.c_customer_sk) - | join customer_address current_addr on - | (customer.c_current_addr_sk = 
current_addr.ca_address_sk) - |where - | current_addr.ca_city <> bought_city - |order by - | c_last_name, - | c_first_name, - | ca_city, - | bought_city, - | ss_ticket_number - |limit 100 - """.stripMargin), - - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_73-b02 on Mac OS X 10.11.4 - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - - TPCDS Snappy (scale = 5): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - q46 2948 / 3218 5.1 196.1 1.0X - */ - - ("q52", """ - |select - | d_year, - | i_brand_id, - | i_brand, - | sum(ss_ext_sales_price) ext_price - |from - | store_sales - | join item on (store_sales.ss_item_sk = item.i_item_sk) - | join date_dim dt on (store_sales.ss_sold_date_sk = dt.d_date_sk) - |where - | i_manager_id = 1 - | and d_moy = 12 - | and d_year = 1998 - | and ss_sold_date_sk between 2451149 and 2451179 -- partition key filter - |group by - | d_year, - | i_brand, - | i_brand_id - |order by - | d_year, - | ext_price desc, - | i_brand_id - |limit 100 - """.stripMargin), - - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_73-b02 on Mac OS X 10.11.4 - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - - TPCDS Snappy (scale = 5): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - q52 1099 / 1228 13.2 75.7 1.0X - */ - - ("q53", """ - |select - | * - |from - | (select - | i_manufact_id, - | sum(ss_sales_price) sum_sales - | from - | store_sales - | join item on (store_sales.ss_item_sk = item.i_item_sk) - | join store on (store_sales.ss_store_sk = store.s_store_sk) - | join date_dim on (store_sales.ss_sold_date_sk = date_dim.d_date_sk) - | where - | ss_sold_date_sk between 2451911 and 2452275 -- partition key filter - | and d_month_seq in(1212, 1212 + 1, 1212 + 2, 1212 + 3, 1212 + 4, 1212 + 5, - | 1212 + 6, 1212 + 7, 1212 + 8, 1212 + 9, 1212 + 10, 1212 + 11) - | and ( - | (i_category in('Books', 'Children', 'Electronics') - | and i_class in('personal', 'portable', 'reference', 'self-help') - | and i_brand in('scholaramalgamalg #14', 'scholaramalgamalg #7', - | 'exportiunivamalg #9', 'scholaramalgamalg #9') - | ) - | or - | (i_category in('Women', 'Music', 'Men') - | and i_class in('accessories', 'classical', 'fragrances', 'pants') - | and i_brand in('amalgimporto #1', 'edu packscholar #1', - | 'exportiimporto #1', 'importoamalg #1') - | ) - | ) - | group by - | i_manufact_id, - | d_qoy - | ) tmp1 - |order by - | sum_sales, - | i_manufact_id - |limit 100 - """.stripMargin), - - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_73-b02 on Mac OS X 10.11.4 - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - - TPCDS Snappy (scale = 5): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - q53 968 / 1020 15.0 66.6 1.0X - */ - - ("q55", """ - |select - | i_brand_id, - | i_brand, - | sum(ss_ext_sales_price) ext_price - |from - | store_sales - | join item on (store_sales.ss_item_sk = item.i_item_sk) - | join date_dim on (store_sales.ss_sold_date_sk = date_dim.d_date_sk) - |where - | i_manager_id = 36 - | and d_moy = 12 - | and d_year = 2001 - | and ss_sold_date_sk between 2452245 and 2452275 -- partition key filter - |group by - | i_brand, - | i_brand_id - |order by - | ext_price desc, - | i_brand_id - |limit 100 - """.stripMargin), - - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_73-b02 on Mac OS X 10.11.4 - Intel(R) Core(TM) i7-4960HQ CPU @ 
2.60GHz - - TPCDS Snappy (scale = 5): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - q55 1002 / 1020 14.5 69.0 1.0X - */ - - ("q59", """ - |select - | s_store_name1, - | s_store_id1, - | d_week_seq1, - | sun_sales1 / sun_sales2, - | mon_sales1 / mon_sales2, - | tue_sales1 / tue_sales2, - | wed_sales1 / wed_sales2, - | thu_sales1 / thu_sales2, - | fri_sales1 / fri_sales2, - | sat_sales1 / sat_sales2 - |from - | (select - | s_store_name s_store_name1, - | wss.d_week_seq d_week_seq1, - | s_store_id s_store_id1, - | sun_sales sun_sales1, - | mon_sales mon_sales1, - | tue_sales tue_sales1, - | wed_sales wed_sales1, - | thu_sales thu_sales1, - | fri_sales fri_sales1, - | sat_sales sat_sales1 - | from - | (select - | d_week_seq, - | ss_store_sk, - | sum(case when(d_day_name = 'Sunday') then - | ss_sales_price else null end) sun_sales, - | sum(case when(d_day_name = 'Monday') then - | ss_sales_price else null end) mon_sales, - | sum(case when(d_day_name = 'Tuesday') then - | ss_sales_price else null end) tue_sales, - | sum(case when(d_day_name = 'Wednesday') then - | ss_sales_price else null end) wed_sales, - | sum(case when(d_day_name = 'Thursday') then - | ss_sales_price else null end) thu_sales, - | sum(case when(d_day_name = 'Friday') then - | ss_sales_price else null end) fri_sales, - | sum(case when(d_day_name = 'Saturday') then - | ss_sales_price else null end) sat_sales - | from - | store_sales - | join date_dim on (store_sales.ss_sold_date_sk = date_dim.d_date_sk) - | where - | ss_sold_date_sk between 2451088 and 2451452 - | group by - | d_week_seq, - | ss_store_sk - | ) wss - | join store on (wss.ss_store_sk = store.s_store_sk) - | join date_dim d on (wss.d_week_seq = d.d_week_seq) - | where - | d_month_seq between 1185 and 1185 + 11 - | ) y - | join - | (select - | s_store_name s_store_name2, - | wss.d_week_seq d_week_seq2, - | s_store_id s_store_id2, - | sun_sales sun_sales2, - | mon_sales mon_sales2, - | tue_sales tue_sales2, - | wed_sales wed_sales2, - | thu_sales thu_sales2, - | fri_sales fri_sales2, - | sat_sales sat_sales2 - | from - | (select - | d_week_seq, - | ss_store_sk, - | sum(case when(d_day_name = 'Sunday') then - | ss_sales_price else null end) sun_sales, - | sum(case when(d_day_name = 'Monday') then - | ss_sales_price else null end) mon_sales, - | sum(case when(d_day_name = 'Tuesday') then - | ss_sales_price else null end) tue_sales, - | sum(case when(d_day_name = 'Wednesday') then - | ss_sales_price else null end) wed_sales, - | sum(case when(d_day_name = 'Thursday') then - | ss_sales_price else null end) thu_sales, - | sum(case when(d_day_name = 'Friday') then - | ss_sales_price else null end) fri_sales, - | sum(case when(d_day_name = 'Saturday') then - | ss_sales_price else null end) sat_sales - | from - | store_sales - | join date_dim on (store_sales.ss_sold_date_sk = date_dim.d_date_sk) - | where - | ss_sold_date_sk between 2451088 and 2451452 - | group by - | d_week_seq, - | ss_store_sk - | ) wss - | join store on (wss.ss_store_sk = store.s_store_sk) - | join date_dim d on (wss.d_week_seq = d.d_week_seq) - | where - | d_month_seq between 1185 + 12 and 1185 + 23 - | ) x - | on (y.s_store_id1 = x.s_store_id2) - |where - | d_week_seq1 = d_week_seq2 - 52 - |order by - | s_store_name1, - | s_store_id1, - | d_week_seq1 - |limit 100 - """.stripMargin), - - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_73-b02 on Mac OS X 10.11.4 - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - - 
TPCDS Snappy (scale = 5): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - q59 1624 / 1663 17.9 55.8 1.0X - */ - - ("q63", """ - |select - | * - |from - | (select - | i_manager_id, - | sum(ss_sales_price) sum_sales - | from - | store_sales - | join item on (store_sales.ss_item_sk = item.i_item_sk) - | join store on (store_sales.ss_store_sk = store.s_store_sk) - | join date_dim on (store_sales.ss_sold_date_sk = date_dim.d_date_sk) - | where - | ss_sold_date_sk between 2451911 and 2452275 -- partition key filter - | and d_month_seq in (1212, 1212 + 1, 1212 + 2, 1212 + 3, 1212 + 4, 1212 + 5, - | 1212 + 6, 1212 + 7, 1212 + 8, 1212 + 9, 1212 + 10, 1212 + 11) - | and ( - | (i_category in('Books', 'Children', 'Electronics') - | and i_class in('personal', 'portable', 'refernece', 'self-help') - | and i_brand in('scholaramalgamalg #14', 'scholaramalgamalg #7', - | 'exportiunivamalg #9', 'scholaramalgamalg #9') - | ) - | or - | (i_category in('Women', 'Music', 'Men') - | and i_class in('accessories', 'classical', 'fragrances', 'pants') - | and i_brand in('amalgimporto #1', 'edu packscholar #1', - | 'exportiimporto #1', 'importoamalg #1') - | ) - | ) - | group by - | i_manager_id, - | d_moy - | ) tmp1 - |order by - | i_manager_id, - | sum_sales - |limit 100 - """.stripMargin), - - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_73-b02 on Mac OS X 10.11.4 - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - - TPCDS Snappy (scale = 5): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - q63 979 / 1006 14.8 67.4 1.0X - */ - - ("q65", """ - |select - | s_store_name, - | i_item_desc, - | sc.revenue, - | i_current_price, - | i_wholesale_cost, - | i_brand - |from - | (select - | ss_store_sk, - | ss_item_sk, - | sum(ss_sales_price) as revenue - | from - | store_sales - | join date_dim on (store_sales.ss_sold_date_sk = date_dim.d_date_sk) - | where - | ss_sold_date_sk between 2451911 and 2452275 -- partition key filter - | and d_month_seq between 1212 and 1212 + 11 - | group by - | ss_store_sk, - | ss_item_sk - | ) sc - | join item on (sc.ss_item_sk = item.i_item_sk) - | join store on (sc.ss_store_sk = store.s_store_sk) - | join - | (select - | ss_store_sk, - | avg(revenue) as ave - | from - | (select - | ss_store_sk, - | ss_item_sk, - | sum(ss_sales_price) as revenue - | from - | store_sales - | join date_dim on (store_sales.ss_sold_date_sk = date_dim.d_date_sk) - | where - | ss_sold_date_sk between 2451911 and 2452275 -- partition key filter - | and d_month_seq between 1212 and 1212 + 11 - | group by - | ss_store_sk, - | ss_item_sk - | ) sa - | group by - | ss_store_sk - | ) sb on (sc.ss_store_sk = sb.ss_store_sk) -- 676 rows - |where - | sc.revenue <= 0.1 * sb.ave - |order by - | s_store_name, - | i_item_desc - |limit 100 - """.stripMargin), - - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_73-b02 on Mac OS X 10.11.4 - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - - TPCDS Snappy (scale = 5): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - q65 7770 / 8097 3.7 267.9 1.0X - */ - - ("q68", """ - |select - | c_last_name, - | c_first_name, - | ca_city, - | bought_city, - | ss_ticket_number, - | extended_price, - | extended_tax, - | list_price - |from - | (select - | ss_ticket_number, - | ss_customer_sk, - | ca_city bought_city, - | 
sum(ss_ext_sales_price) extended_price, - | sum(ss_ext_list_price) list_price, - | sum(ss_ext_tax) extended_tax - | from - | store_sales - | join store on (store_sales.ss_store_sk = store.s_store_sk) - | join household_demographics on - | (store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk) - | join date_dim on (store_sales.ss_sold_date_sk = date_dim.d_date_sk) - | join customer_address on - | (store_sales.ss_addr_sk = customer_address.ca_address_sk) - | where - | store.s_city in('Midway', 'Fairview') - | --and date_dim.d_dom between 1 and 2 - | --and date_dim.d_year in(1999, 1999 + 1, 1999 + 2) - | -- and ss_date between '1999-01-01' and '2001-12-31' - | -- and dayofmonth(ss_date) in (1,2) - | and (household_demographics.hd_dep_count = 5 - | or household_demographics.hd_vehicle_count = 3) - | and d_date between '1999-01-01' and '1999-03-31' - | and ss_sold_date_sk between 2451180 and 2451269 - | -- partition key filter (3 months) - | group by - | ss_ticket_number, - | ss_customer_sk, - | ss_addr_sk, - | ca_city - | ) dn - | join customer on (dn.ss_customer_sk = customer.c_customer_sk) - | join customer_address current_addr on - | (customer.c_current_addr_sk = current_addr.ca_address_sk) - |where - | current_addr.ca_city <> bought_city - |order by - | c_last_name, - | ss_ticket_number - |limit 100 - """.stripMargin), - - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_73-b02 on Mac OS X 10.11.4 - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - - TPCDS Snappy (scale = 5): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - q68 3105 / 3405 4.8 206.5 1.0X - */ - - ("q7", """ - |select - | i_item_id, - | avg(ss_quantity) agg1, - | avg(ss_list_price) agg2, - | avg(ss_coupon_amt) agg3, - | avg(ss_sales_price) agg4 - |from - | store_sales - | join customer_demographics on - | (store_sales.ss_cdemo_sk = customer_demographics.cd_demo_sk) - | join item on (store_sales.ss_item_sk = item.i_item_sk) - | join promotion on (store_sales.ss_promo_sk = promotion.p_promo_sk) - | join date_dim on (ss_sold_date_sk = d_date_sk) - |where - | cd_gender = 'F' - | and cd_marital_status = 'W' - | and cd_education_status = 'Primary' - | and (p_channel_email = 'N' - | or p_channel_event = 'N') - | and d_year = 1998 - | and ss_sold_date_sk between 2450815 and 2451179 -- partition key filter - |group by - | i_item_id - |order by - | i_item_id - |limit 100 - """.stripMargin), - - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_73-b02 on Mac OS X 10.11.4 - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - - TPCDS Snappy (scale = 5): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - q7 2042 / 2333 8.1 124.2 1.0X - */ - - ("q73", """ - |select - | c_last_name, - | c_first_name, - | c_salutation, - | c_preferred_cust_flag, - | ss_ticket_number, - | cnt - |from - | (select - | ss_ticket_number, - | ss_customer_sk, - | count(*) cnt - | from - | store_sales - | join household_demographics on - | (store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk) - | join store on (store_sales.ss_store_sk = store.s_store_sk) - | -- join date_dim on (store_sales.ss_sold_date_sk = date_dim.d_date_sk) - | where - | store.s_county in - | ('Williamson County','Franklin Parish','Bronx County','Orange County') - | -- and date_dim.d_dom between 1 and 2 - | -- and date_dim.d_year in(1998, 1998 + 1, 1998 + 2) - | -- and ss_date between '1999-01-01' and '2001-12-02' - 
| -- and dayofmonth(ss_date) in (1,2) - | -- partition key filter - | -- and ss_sold_date_sk in (2450816, 2450846, 2450847, 2450874, 2450875, 2450905, - | -- 2450906, 2450935, 2450936, 2450966, 2450967, - | -- 2450996, 2450997, 2451027, 2451028, 2451058, 2451059, - | -- 2451088, 2451089, 2451119, 2451120, 2451149, - | -- 2451150, 2451180, 2451181, 2451211, 2451212, 2451239, - | -- 2451240, 2451270, 2451271, 2451300, 2451301, - | -- 2451331, 2451332, 2451361, 2451362, 2451392, 2451393, - | -- 2451423, 2451424, 2451453, 2451454, 2451484, - | -- 2451485, 2451514, 2451515, 2451545, 2451546, 2451576, - | -- 2451577, 2451605, 2451606, 2451636, 2451637, - | -- 2451666, 2451667, 2451697, 2451698, 2451727, 2451728, - | -- 2451758, 2451759, 2451789, 2451790, 2451819, - | -- 2451820, 2451850, 2451851, 2451880, 2451881) - | and (household_demographics.hd_buy_potential = '>10000' - | or household_demographics.hd_buy_potential = 'unknown') - | and household_demographics.hd_vehicle_count > 0 - | and case when household_demographics.hd_vehicle_count > 0 then - | household_demographics.hd_dep_count / household_demographics.hd_vehicle_count - | else null end > 1 - | and ss_sold_date_sk between 2451180 and 2451269 - | -- partition key filter (3 months) - | group by - | ss_ticket_number, - | ss_customer_sk - | ) dj - | join customer on (dj.ss_customer_sk = customer.c_customer_sk) - |where - | cnt between 1 and 5 - |order by - | cnt desc - |limit 1000 - """.stripMargin), - - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_73-b02 on Mac OS X 10.11.4 - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - - TPCDS Snappy (scale = 5): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - q73 1124 / 1221 13.1 76.5 1.0X - */ - - ("q79", """ - |select - | c_last_name, - | c_first_name, - | substr(s_city, 1, 30) as city, - | ss_ticket_number, - | amt, - | profit - |from - | (select - | ss_ticket_number, - | ss_customer_sk, - | s_city, - | sum(ss_coupon_amt) amt, - | sum(ss_net_profit) profit - | from - | store_sales - | join household_demographics on - | (store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk) - | join date_dim on (store_sales.ss_sold_date_sk = date_dim.d_date_sk) - | join store on (store_sales.ss_store_sk = store.s_store_sk) - | where - | store.s_number_employees between 200 and 295 - | and (household_demographics.hd_dep_count = 8 - | or household_demographics.hd_vehicle_count > 0) - | and date_dim.d_dow = 1 - | and date_dim.d_year in (1998, 1998 + 1, 1998 + 2) - | -- and ss_date between '1998-01-01' and '2000-12-25' - | -- 156 days - | and d_date between '1999-01-01' and '1999-03-31' - | and ss_sold_date_sk between 2451180 and 2451269 -- partition key filter - | group by - | ss_ticket_number, - | ss_customer_sk, - | ss_addr_sk, - | s_city - | ) ms - | join customer on (ms.ss_customer_sk = customer.c_customer_sk) - |order by - | c_last_name, - | c_first_name, - | city, - | profit - |limit 100 - """.stripMargin), - - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_73-b02 on Mac OS X 10.11.4 - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - - TPCDS Snappy (scale = 5): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - q79 2029 / 2488 7.3 137.5 1.0X - */ - - ("q8", - """ - |select s_store_name - | ,sum(ss_net_profit) - | from store_sales - | ,date_dim - | ,store, - | (select distinct a01.ca_zip - | from - | (SELECT substr(ca_zip,1,5) 
ca_zip - | FROM customer_address - | WHERE substr(ca_zip,1,5) IN ('89436', '30868', '65085', '22977', '83927', '77557', - | '58429', '40697', '80614', '10502', '32779', - | '91137', '61265', '98294', '17921', '18427', '21203', '59362', '87291', '84093', - | '21505', '17184', '10866', '67898', '25797', - | '28055', '18377', '80332', '74535', '21757', '29742', '90885', '29898', '17819', - | '40811', '25990', '47513', '89531', '91068', - | '10391', '18846', '99223', '82637', '41368', '83658', '86199', '81625', '26696', - | '89338', '88425', '32200', '81427', '19053', - | '77471', '36610', '99823', '43276', '41249', '48584', '83550', '82276', '18842', - | '78890', '14090', '38123', '40936', '34425', - | '19850', '43286', '80072', '79188', '54191', '11395', '50497', '84861', '90733', - | '21068', '57666', '37119', '25004', '57835', - | '70067', '62878', '95806', '19303', '18840', '19124', '29785', '16737', '16022', - | '49613', '89977', '68310', '60069', '98360', - | '48649', '39050', '41793', '25002', '27413', '39736', '47208', '16515', '94808', - | '57648', '15009', '80015', '42961', '63982', - | '21744', '71853', '81087', '67468', '34175', '64008', '20261', '11201', '51799', - | '48043', '45645', '61163', '48375', '36447', - | '57042', '21218', '41100', '89951', '22745', '35851', '83326', '61125', '78298', - | '80752', '49858', '52940', '96976', '63792', - | '11376', '53582', '18717', '90226', '50530', '94203', '99447', '27670', '96577', - | '57856', '56372', '16165', '23427', '54561', - | '28806', '44439', '22926', '30123', '61451', '92397', '56979', '92309', '70873', - | '13355', '21801', '46346', '37562', '56458', - | '28286', '47306', '99555', '69399', '26234', '47546', '49661', '88601', '35943', - | '39936', '25632', '24611', '44166', '56648', - | '30379', '59785', '11110', '14329', '93815', '52226', '71381', '13842', '25612', - | '63294', '14664', '21077', '82626', '18799', - | '60915', '81020', '56447', '76619', '11433', '13414', '42548', '92713', '70467', - | '30884', '47484', '16072', '38936', '13036', - | '88376', '45539', '35901', '19506', '65690', '73957', '71850', '49231', '14276', - | '20005', '18384', '76615', '11635', '38177', - | '55607', '41369', '95447', '58581', '58149', '91946', '33790', '76232', '75692', - | '95464', '22246', '51061', '56692', '53121', - | '77209', '15482', '10688', '14868', '45907', '73520', '72666', '25734', '17959', - | '24677', '66446', '94627', '53535', '15560', - | '41967', '69297', '11929', '59403', '33283', '52232', '57350', '43933', '40921', - | '36635', '10827', '71286', '19736', '80619', - | '25251', '95042', '15526', '36496', '55854', '49124', '81980', '35375', '49157', - | '63512', '28944', '14946', '36503', '54010', - | '18767', '23969', '43905', '66979', '33113', '21286', '58471', '59080', '13395', - | '79144', '70373', '67031', '38360', '26705', - | '50906', '52406', '26066', '73146', '15884', '31897', '30045', '61068', '45550', - | '92454', '13376', '14354', '19770', '22928', - | '97790', '50723', '46081', '30202', '14410', '20223', '88500', '67298', '13261', - | '14172', '81410', '93578', '83583', '46047', - | '94167', '82564', '21156', '15799', '86709', '37931', '74703', '83103', '23054', - | '70470', '72008', '35709', '91911', '69998', - | '20961', '70070', '63197', '54853', '88191', '91830', '49521', '19454', '81450', - | '89091', '62378', '31904', '61869', '51744', - | '36580', '85778', '36871', '48121', '28810', '83712', '45486', '67393', '26935', - | '42393', '20132', '55349', '86057', '21309', - | '80218', '10094', '11357', 
'48819', '39734', '40758', '30432', '21204', '29467', - | '30214', '61024', '55307', '74621', '11622', - | '68908', '33032', '52868', '99194', '99900', '84936', '69036', '99149', '45013', - | '32895', '59004', '32322', '14933', '32936', - | '33562', '72550', '27385', '58049', '58200', '16808', '21360', '32961', '18586', - | '79307', '15492')) a01 - | inner join - | (select ca_zip - | from (SELECT substr(ca_zip,1,5) ca_zip,count(*) cnt - | FROM customer_address, customer - | WHERE ca_address_sk = c_current_addr_sk and - | c_preferred_cust_flag='Y' - | group by ca_zip - | having count(*) > 10)A1 - | ) b11 - | on (a01.ca_zip = b11.ca_zip )) A2 - | where ss_store_sk = s_store_sk - | and ss_sold_date_sk = d_date_sk - | and ss_sold_date_sk between 2451271 and 2451361 - | and d_qoy = 2 and d_year = 1999 - | and (substr(s_zip,1,2) = substr(a2.ca_zip,1,2)) - | group by s_store_name - | order by s_store_name - |limit 100 - """.stripMargin), - - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_73-b02 on Mac OS X 10.11.4 - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - - TPCDS Snappy (scale = 5): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - q8 1737 / 2197 8.7 115.6 1.0X - */ - - ("q82", """ - |select - | i_item_id, - | i_item_desc, - | i_current_price - |from - | store_sales - | join item on (store_sales.ss_item_sk = item.i_item_sk) - | join inventory on (item.i_item_sk = inventory.inv_item_sk) - | join date_dim on (inventory.inv_date_sk = date_dim.d_date_sk) - |where - | i_current_price between 30 and 30 + 30 - | and i_manufact_id in (437, 129, 727, 663) - | and inv_quantity_on_hand between 100 and 500 - |group by - | i_item_id, - | i_item_desc, - | i_current_price - |order by - | i_item_id - |limit 100 - """.stripMargin), - - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_73-b02 on Mac OS X 10.11.4 - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - - TPCDS Snappy (scale = 5): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - q82 9399 / 10245 6.8 147.2 1.0X - */ - - ("q89", """ - |select - | * - |from - | (select - | i_category, - | i_class, - | i_brand, - | s_store_name, - | s_company_name, - | d_moy, - | sum(ss_sales_price) sum_sales - | from - | store_sales - | join item on (store_sales.ss_item_sk = item.i_item_sk) - | join store on (store_sales.ss_store_sk = store.s_store_sk) - | join date_dim on (store_sales.ss_sold_date_sk = date_dim.d_date_sk) - | where - | ss_sold_date_sk between 2451545 and 2451910 -- partition key filter - | and d_year in (2000) - | and ((i_category in('Home', 'Books', 'Electronics') - | and i_class in('wallpaper', 'parenting', 'musical')) - | or (i_category in('Shoes', 'Jewelry', 'Men') - | and i_class in('womens', 'birdal', 'pants')) - | ) - | group by - | i_category, - | i_class, - | i_brand, - | s_store_name, - | s_company_name, - | d_moy - | ) tmp1 - |order by - | sum_sales, - | s_store_name - |limit 100 - """.stripMargin), - - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_73-b02 on Mac OS X 10.11.4 - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - - TPCDS Snappy (scale = 5): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - q89 1122 / 1274 12.9 77.2 1.0X - */ - - ("q98", """ - |select - | i_item_desc, - | i_category, - | i_class, - | i_current_price, - | sum(ss_ext_sales_price) as itemrevenue - |from - 
| store_sales - | join item on (store_sales.ss_item_sk = item.i_item_sk) - | join date_dim on (store_sales.ss_sold_date_sk = date_dim.d_date_sk) - |where - | ss_sold_date_sk between 2451911 and 2451941 - | -- partition key filter (1 calendar month) - | and d_date between '2001-01-01' and '2001-01-31' - | and i_category in('Jewelry', 'Sports', 'Books') - |group by - | i_item_id, - | i_item_desc, - | i_category, - | i_class, - | i_current_price - |order by - | i_category, - | i_class, - | i_item_id, - | i_item_desc - | -- revenueratio - |limit 1000 - """.stripMargin), - - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_73-b02 on Mac OS X 10.11.4 - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - - TPCDS Snappy (scale = 5): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - q98 1235 / 1542 11.8 85.0 1.0X - */ - - ("ss_max", """ - |select - | count(*) as total, - | max(ss_sold_date_sk) as max_ss_sold_date_sk, - | max(ss_sold_time_sk) as max_ss_sold_time_sk, - | max(ss_item_sk) as max_ss_item_sk, - | max(ss_customer_sk) as max_ss_customer_sk, - | max(ss_cdemo_sk) as max_ss_cdemo_sk, - | max(ss_hdemo_sk) as max_ss_hdemo_sk, - | max(ss_addr_sk) as max_ss_addr_sk, - | max(ss_store_sk) as max_ss_store_sk, - | max(ss_promo_sk) as max_ss_promo_sk - |from store_sales - """.stripMargin) - - /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_73-b02 on Mac OS X 10.11.4 - Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz - - TPCDS Snappy (scale = 5): Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - ss_max 2305 / 2731 6.2 160.0 1.0X - */ - - ).toArray - - val tables = Seq("customer", "customer_address", "customer_demographics", "date_dim", - "household_demographics", "inventory", "item", "promotion", "store", "catalog_sales", - "web_sales", "store_sales") - - def setupTables(dataLocation: String): Map[String, Long] = { - tables.map { tableName => - sqlContext.read.parquet(s"$dataLocation/$tableName").registerTempTable(tableName) - tableName -> sqlContext.table(tableName).count() - }.toMap - } - - def tpcdsAll(dataLocation: String): Unit = { - require(dataLocation.nonEmpty, - "please modify the value of dataLocation to point to your local TPCDS data") - val tableSizes = setupTables(dataLocation) - sqlContext.conf.setConfString(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, "true") - sqlContext.conf.setConfString(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true") - tpcds.filter(q => q._1 != "").foreach { - case (name: String, query: String) => - val numRows = sqlContext.sql(query).queryExecution.logical.map { - case ur@UnresolvedRelation(t: TableIdentifier, _) => - tableSizes.getOrElse(t.table, throw new RuntimeException(s"${t.table} not found.")) - case _ => 0L - }.sum - val benchmark = new Benchmark("TPCDS Snappy (scale = 5)", numRows, 5) - benchmark.addCase(name) { i => - sqlContext.sql(query).collect() - } - benchmark.run() - } - } - - def main(args: Array[String]): Unit = { - - // In order to run this benchmark, please follow the instructions at - // https://github.com/databricks/spark-sql-perf/blob/master/README.md to generate the TPCDS data - // locally (preferably with a scale factor of 5 for benchmarking). Thereafter, the value of - // dataLocation below needs to be set to the location where the generated data is stored. 
- val dataLocation = "" - - tpcdsAll(dataLocation) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala index 923c0b350ea79..d11c2acb815d4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala @@ -19,34 +19,33 @@ package org.apache.spark.sql.execution.datasources.text import java.io.File -import scala.collection.JavaConverters._ - -import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.GzipCodec import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, SaveMode} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.util.Utils class TextSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("reading text file") { - verifyFrame(sqlContext.read.format("text").load(testFile)) + verifyFrame(spark.read.format("text").load(testFile)) } test("SQLContext.read.text() API") { - verifyFrame(sqlContext.read.text(testFile).toDF()) + verifyFrame(spark.read.text(testFile)) } test("SPARK-12562 verify write.text() can handle column name beyond `value`") { - val df = sqlContext.read.text(testFile).withColumnRenamed("value", "adwrasdf") + val df = spark.read.text(testFile).withColumnRenamed("value", "adwrasdf") val tempFile = Utils.createTempDir() tempFile.delete() df.write.text(tempFile.getCanonicalPath) - verifyFrame(sqlContext.read.text(tempFile.getCanonicalPath).toDF()) + verifyFrame(spark.read.text(tempFile.getCanonicalPath)) Utils.deleteRecursively(tempFile) } @@ -55,18 +54,38 @@ class TextSuite extends QueryTest with SharedSQLContext { val tempFile = Utils.createTempDir() tempFile.delete() - val df = sqlContext.range(2) + val df = spark.range(2) intercept[AnalysisException] { df.write.text(tempFile.getCanonicalPath) } intercept[AnalysisException] { - sqlContext.range(2).select(df("id"), df("id") + 1).write.text(tempFile.getCanonicalPath) + spark.range(2).select(df("id"), df("id") + 1).write.text(tempFile.getCanonicalPath) } } + test("reading partitioned data using read.textFile()") { + val partitionedData = Thread.currentThread().getContextClassLoader + .getResource("test-data/text-partitioned").toString + val ds = spark.read.textFile(partitionedData) + val data = ds.collect() + + assert(ds.schema == new StructType().add("value", StringType)) + assert(data.length == 2) + } + + test("support for partitioned reading using read.text()") { + val partitionedData = Thread.currentThread().getContextClassLoader + .getResource("test-data/text-partitioned").toString + val df = spark.read.text(partitionedData) + val data = df.filter("year = '2015'").select("value").collect() + + assert(data(0) == Row("2015-test")) + assert(data.length == 1) + } + test("SPARK-13503 Support to specify the option for compression codec for TEXT") { - val testDf = sqlContext.read.text(testFile) + val testDf = spark.read.text(testFile) val extensionNameMap = Map("bzip2" -> ".bz2", "deflate" -> ".deflate", "gzip" -> ".gz") extensionNameMap.foreach { case (codecName, extension) => @@ -75,7 +94,7 @@ class TextSuite extends QueryTest with SharedSQLContext { testDf.write.option("compression", 
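
The two new partitioned-reading tests above rely on the split between read.text, which yields a DataFrame with a single value column plus any discovered partition columns, and read.textFile, which yields a Dataset[String] without partition columns. A small sketch of that distinction, assuming a partitioned layout such as dir/year=2015/part.txt; the paths and names are illustrative:

import org.apache.spark.sql.{Dataset, Row, SparkSession}

object TextReadSketch {
  def demo(spark: SparkSession, dir: String): Unit = {
    // DataFrame view: one `value` column plus discovered partition columns (here `year`).
    val df = spark.read.text(dir)
    val rows2015: Array[Row] = df.filter("year = '2015'").select("value").collect()

    // Dataset[String] view: just the lines; partition columns are not exposed.
    val ds: Dataset[String] = spark.read.textFile(dir)
    val allLines: Array[String] = ds.collect()

    println(s"2015 rows: ${rows2015.length}, total lines: ${allLines.length}")
  }
}
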
codecName).mode(SaveMode.Overwrite).text(tempDirPath) val compressedFiles = new File(tempDirPath).listFiles() assert(compressedFiles.exists(_.getName.endsWith(s".txt$extension"))) - verifyFrame(sqlContext.read.text(tempDirPath).toDF()) + verifyFrame(spark.read.text(tempDirPath)) } val errMsg = intercept[IllegalArgumentException] { @@ -95,19 +114,48 @@ class TextSuite extends QueryTest with SharedSQLContext { "mapreduce.map.output.compress.codec" -> classOf[GzipCodec].getName ) withTempDir { dir => - val testDf = sqlContext.read.text(testFile) + val testDf = spark.read.text(testFile) val tempDir = Utils.createTempDir() val tempDirPath = tempDir.getAbsolutePath testDf.write.option("compression", "none") .options(extraOptions).mode(SaveMode.Overwrite).text(tempDirPath) val compressedFiles = new File(tempDirPath).listFiles() assert(compressedFiles.exists(!_.getName.endsWith(".txt.gz"))) - verifyFrame(sqlContext.read.options(extraOptions).text(tempDirPath).toDF()) + verifyFrame(spark.read.options(extraOptions).text(tempDirPath)) + } + } + + test("SPARK-14343: select partitioning column") { + withTempPath { dir => + val path = dir.getCanonicalPath + val ds1 = spark.range(1).selectExpr("CONCAT('val_', id)") + ds1.write.text(s"$path/part=a") + ds1.write.text(s"$path/part=b") + + checkAnswer( + spark.read.format("text").load(path).select($"part"), + Row("a") :: Row("b") :: Nil) + } + } + + test("SPARK-15654: should not split gz files") { + withTempDir { dir => + val path = dir.getCanonicalPath + val df1 = spark.range(0, 1000).selectExpr("CAST(id AS STRING) AS s") + df1.write.option("compression", "gzip").mode("overwrite").text(path) + + val expected = df1.collect() + Seq(10, 100, 1000).foreach { bytes => + withSQLConf(SQLConf.FILES_MAX_PARTITION_BYTES.key -> bytes.toString) { + val df2 = spark.read.format("text").load(path) + checkAnswer(df2, expected) + } + } } } private def testFile: String = { - Thread.currentThread().getContextClassLoader.getResource("text-suite.txt").toString + Thread.currentThread().getContextClassLoader.getResource("test-data/text-suite.txt").toString } /** Verifies data and schema. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala index 8aa0114d98d74..4fc52c99fbeeb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala @@ -33,7 +33,7 @@ class DebuggingSuite extends SparkFunSuite with SharedSQLContext { } test("debugCodegen") { - val res = codegenString(sqlContext.range(10).groupBy("id").count().queryExecution.executedPlan) + val res = codegenString(spark.range(10).groupBy("id").count().queryExecution.executedPlan) assert(res.contains("Subtree 1 / 2")) assert(res.contains("Subtree 2 / 2")) assert(res.contains("Object[]")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index b9df43d04984e..c22c106fdfece 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -1,19 +1,19 @@ /* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. 
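
SPARK-13503 and SPARK-15654 above cover the compression write option for the text source and the rule that gzip output is never split into multiple read partitions. A compact sketch of both behaviours, assuming a scratch path; the object name is a placeholder:

import org.apache.spark.sql.SparkSession

object TextCompressionSketch {
  def demo(spark: SparkSession, path: String): Unit = {
    val df = spark.range(0, 1000).selectExpr("CAST(id AS STRING) AS s")

    // Write gzip-compressed text; output files end with .txt.gz.
    df.write.option("compression", "gzip").mode("overwrite").text(path)

    // Even with a tiny target partition size, a .gz file is read as a single
    // split, so the round trip still returns every row exactly once.
    spark.conf.set("spark.sql.files.maxPartitionBytes", "10")
    assert(spark.read.text(path).count() == 1000)
  }
}
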
See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.apache.spark.sql.execution.joins @@ -22,9 +22,14 @@ import scala.reflect.ClassTag import org.scalatest.BeforeAndAfterAll import org.apache.spark.{AccumulatorSuite, SparkConf, SparkContext} -import org.apache.spark.sql.{QueryTest, SQLContext} +import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession} +import org.apache.spark.sql.catalyst.expressions.{BitwiseAnd, BitwiseOr, Cast, Literal, ShiftLeft} import org.apache.spark.sql.execution.exchange.EnsureRequirements +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.types.{LongType, ShortType} /** * Test various broadcast join operators. @@ -33,53 +38,169 @@ import org.apache.spark.sql.functions._ * unsafe map in [[org.apache.spark.sql.execution.joins.UnsafeHashedRelation]] is not triggered * without serializing the hashed relation, which does not happen in local mode. */ -class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll { - protected var sqlContext: SQLContext = null +class BroadcastJoinSuite extends QueryTest with SQLTestUtils { + import testImplicits._ + + protected var spark: SparkSession = null /** - * Create a new [[SQLContext]] running in local-cluster mode with unsafe and codegen enabled. + * Create a new [[SparkSession]] running in local-cluster mode with unsafe and codegen enabled. */ override def beforeAll(): Unit = { super.beforeAll() - val conf = new SparkConf() - .setMaster("local-cluster[2,1,1024]") - .setAppName("testing") - val sc = new SparkContext(conf) - sqlContext = new SQLContext(sc) + spark = SparkSession.builder() + .master("local-cluster[2,1,1024]") + .appName("testing") + .getOrCreate() } override def afterAll(): Unit = { - sqlContext.sparkContext.stop() - sqlContext = null + spark.stop() + spark = null } /** * Test whether the specified broadcast join updates the peak execution memory accumulator. 
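
The suite now provisions its own SparkSession in local-cluster mode instead of hand-building a SparkContext plus SQLContext. A minimal sketch of that lifecycle, with the master string taken from the test and everything else illustrative; running local-cluster mode outside Spark's own test harness assumes a built Spark distribution is available:

import org.apache.spark.sql.SparkSession

object LocalClusterSessionSketch {
  def main(args: Array[String]): Unit = {
    // local-cluster[2,1,1024]: 2 workers, 1 core each, 1024 MB per worker.
    val spark = SparkSession.builder()
      .master("local-cluster[2,1,1024]")
      .appName("testing")
      .getOrCreate()
    try {
      // The underlying SparkContext is still reachable when needed.
      assert(spark.sparkContext.parallelize(1 to 10).count() == 10)
    } finally {
      spark.stop()
    }
  }
}
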
*/ - private def testBroadcastJoin[T: ClassTag](name: String, joinType: String): Unit = { - AccumulatorSuite.verifyPeakExecutionMemorySet(sqlContext.sparkContext, name) { - val df1 = sqlContext.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") - val df2 = sqlContext.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value") - // Comparison at the end is for broadcast left semi join - val joinExpression = df1("key") === df2("key") && df1("value") > df2("value") - val df3 = df1.join(broadcast(df2), joinExpression, joinType) - val plan = - EnsureRequirements(sqlContext.sessionState.conf).apply(df3.queryExecution.sparkPlan) - assert(plan.collect { case p: T => p }.size === 1) + private def testBroadcastJoinPeak[T: ClassTag](name: String, joinType: String): Unit = { + AccumulatorSuite.verifyPeakExecutionMemorySet(spark.sparkContext, name) { + val plan = testBroadcastJoin[T](joinType) plan.executeCollect() } } + private def testBroadcastJoin[T: ClassTag](joinType: String, + forceBroadcast: Boolean = false): SparkPlan = { + val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") + var df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value") + + // Comparison at the end is for broadcast left semi join + val joinExpression = df1("key") === df2("key") && df1("value") > df2("value") + val df3 = if (forceBroadcast) { + df1.join(broadcast(df2), joinExpression, joinType) + } else { + df1.join(df2, joinExpression, joinType) + } + val plan = + EnsureRequirements(spark.sessionState.conf).apply(df3.queryExecution.sparkPlan) + assert(plan.collect { case p: T => p }.size === 1) + + return plan + } + test("unsafe broadcast hash join updates peak execution memory") { - testBroadcastJoin[BroadcastHashJoinExec]("unsafe broadcast hash join", "inner") + testBroadcastJoinPeak[BroadcastHashJoinExec]("unsafe broadcast hash join", "inner") } test("unsafe broadcast hash outer join updates peak execution memory") { - testBroadcastJoin[BroadcastHashJoinExec]("unsafe broadcast hash outer join", "left_outer") + testBroadcastJoinPeak[BroadcastHashJoinExec]("unsafe broadcast hash outer join", "left_outer") } test("unsafe broadcast left semi join updates peak execution memory") { - testBroadcastJoin[BroadcastHashJoinExec]("unsafe broadcast left semi join", "leftsemi") + testBroadcastJoinPeak[BroadcastHashJoinExec]("unsafe broadcast left semi join", "leftsemi") } + test("broadcast hint isn't bothered by autoBroadcastJoinThreshold set to low values") { + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") { + testBroadcastJoin[BroadcastHashJoinExec]("inner", true) + } + } + + test("broadcast hint isn't bothered by a disabled autoBroadcastJoinThreshold") { + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + testBroadcastJoin[BroadcastHashJoinExec]("inner", true) + } + } + + test("broadcast hint isn't propagated after a join") { + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") + val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value") + val df3 = df1.join(broadcast(df2), Seq("key"), "inner").drop(df2("key")) + + val df4 = spark.createDataFrame(Seq((1, "5"), (2, "5"))).toDF("key", "value") + val df5 = df4.join(df3, Seq("key"), "inner") + + val plan = + EnsureRequirements(spark.sessionState.conf).apply(df5.queryExecution.sparkPlan) + + assert(plan.collect { case p: BroadcastHashJoinExec => p }.size === 1) + assert(plan.collect { case p:
SortMergeJoinExec => p }.size === 1) + } + } + + private def assertBroadcastJoin(df : Dataset[Row]) : Unit = { + val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") + val joined = df1.join(df, Seq("key"), "inner") + + val plan = + EnsureRequirements(spark.sessionState.conf).apply(joined.queryExecution.sparkPlan) + + assert(plan.collect { case p: BroadcastHashJoinExec => p }.size === 1) + } + + test("broadcast hint is propagated correctly") { + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"), (3, "2"))).toDF("key", "value") + val broadcasted = broadcast(df2) + val df3 = spark.createDataFrame(Seq((2, "2"), (3, "3"))).toDF("key", "value") + + val cases = Seq(broadcasted.limit(2), + broadcasted.filter("value < 10"), + broadcasted.sample(true, 0.5), + broadcasted.distinct(), + broadcasted.groupBy("value").agg(min($"key").as("key")), + // except and intersect are semi/anti-joins which won't return more data then + // their left argument, so the broadcast hint should be propagated here + broadcasted.except(df3), + broadcasted.intersect(df3)) + + cases.foreach(assertBroadcastJoin) + } + } + + test("join key rewritten") { + val l = Literal(1L) + val i = Literal(2) + val s = Literal.create(3, ShortType) + val ss = Literal("hello") + + assert(HashJoin.rewriteKeyExpr(l :: Nil) === l :: Nil) + assert(HashJoin.rewriteKeyExpr(l :: l :: Nil) === l :: l :: Nil) + assert(HashJoin.rewriteKeyExpr(l :: i :: Nil) === l :: i :: Nil) + + assert(HashJoin.rewriteKeyExpr(i :: Nil) === Cast(i, LongType) :: Nil) + assert(HashJoin.rewriteKeyExpr(i :: l :: Nil) === i :: l :: Nil) + assert(HashJoin.rewriteKeyExpr(i :: i :: Nil) === + BitwiseOr(ShiftLeft(Cast(i, LongType), Literal(32)), + BitwiseAnd(Cast(i, LongType), Literal((1L << 32) - 1))) :: Nil) + assert(HashJoin.rewriteKeyExpr(i :: i :: i :: Nil) === i :: i :: i :: Nil) + + assert(HashJoin.rewriteKeyExpr(s :: Nil) === Cast(s, LongType) :: Nil) + assert(HashJoin.rewriteKeyExpr(s :: l :: Nil) === s :: l :: Nil) + assert(HashJoin.rewriteKeyExpr(s :: s :: Nil) === + BitwiseOr(ShiftLeft(Cast(s, LongType), Literal(16)), + BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))) :: Nil) + assert(HashJoin.rewriteKeyExpr(s :: s :: s :: Nil) === + BitwiseOr(ShiftLeft( + BitwiseOr(ShiftLeft(Cast(s, LongType), Literal(16)), + BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))), + Literal(16)), + BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))) :: Nil) + assert(HashJoin.rewriteKeyExpr(s :: s :: s :: s :: Nil) === + BitwiseOr(ShiftLeft( + BitwiseOr(ShiftLeft( + BitwiseOr(ShiftLeft(Cast(s, LongType), Literal(16)), + BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))), + Literal(16)), + BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))), + Literal(16)), + BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))) :: Nil) + assert(HashJoin.rewriteKeyExpr(s :: s :: s :: s :: s :: Nil) === + s :: s :: s :: s :: s :: Nil) + + assert(HashJoin.rewriteKeyExpr(ss :: Nil) === ss :: Nil) + assert(HashJoin.rewriteKeyExpr(l :: ss :: Nil) === l :: ss :: Nil) + assert(HashJoin.rewriteKeyExpr(i :: ss :: Nil) === i :: ss :: Nil) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala index b32b6444b6d9a..38377164c10e6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala @@ -18,19 +18,19 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys -import org.apache.spark.sql.catalyst.plans.{Inner, JoinType, LeftAnti, LeftSemi} +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.Join -import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest} +import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan, SparkPlanTest} import org.apache.spark.sql.execution.exchange.EnsureRequirements import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType} +import org.apache.spark.sql.types.{BooleanType, DoubleType, IntegerType, StructType} class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext { - private lazy val left = sqlContext.createDataFrame( + private lazy val left = spark.createDataFrame( sparkContext.parallelize(Seq( Row(1, 2.0), Row(1, 2.0), @@ -42,7 +42,7 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext { Row(6, null) )), new StructType().add("a", IntegerType).add("b", DoubleType)) - private lazy val right = sqlContext.createDataFrame( + private lazy val right = spark.createDataFrame( sparkContext.parallelize(Seq( Row(2, 3.0), Row(2, 3.0), @@ -53,7 +53,7 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext { Row(6, null) )), new StructType().add("c", IntegerType).add("d", DoubleType)) - private lazy val rightUniqueKey = sqlContext.createDataFrame( + private lazy val rightUniqueKey = spark.createDataFrame( sparkContext.parallelize(Seq( Row(2, 3.0), Row(3, 2.0), @@ -89,6 +89,18 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext { ExtractEquiJoinKeys.unapply(join) } + val existsAttr = AttributeReference("exists", BooleanType, false)() + val leftSemiPlus = ExistenceJoin(existsAttr) + def createLeftSemiPlusJoin(join: SparkPlan): SparkPlan = { + val output = join.output.dropRight(1) + val condition = if (joinType == LeftSemi) { + existsAttr + } else { + Not(existsAttr) + } + ProjectExec(output, FilterExec(condition, join)) + } + test(s"$testName using ShuffledHashJoin") { extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { @@ -98,6 +110,12 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext { leftKeys, rightKeys, joinType, BuildRight, boundCondition, left, right)), expectedAnswer, sortAnswers = true) + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(left.sqlContext.sessionState.conf).apply( + createLeftSemiPlusJoin(ShuffledHashJoinExec( + leftKeys, rightKeys, leftSemiPlus, BuildRight, boundCondition, left, right))), + expectedAnswer, + sortAnswers = true) } } } @@ -111,6 +129,12 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext { leftKeys, rightKeys, joinType, BuildRight, boundCondition, left, right)), expectedAnswer, sortAnswers = true) + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(left.sqlContext.sessionState.conf).apply( + createLeftSemiPlusJoin(BroadcastHashJoinExec( + leftKeys, 
rightKeys, leftSemiPlus, BuildRight, boundCondition, left, right))), + expectedAnswer, + sortAnswers = true) } } } @@ -123,6 +147,12 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext { SortMergeJoinExec(leftKeys, rightKeys, joinType, boundCondition, left, right)), expectedAnswer, sortAnswers = true) + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(left.sqlContext.sessionState.conf).apply( + createLeftSemiPlusJoin(SortMergeJoinExec( + leftKeys, rightKeys, leftSemiPlus, boundCondition, left, right))), + expectedAnswer, + sortAnswers = true) } } } @@ -134,6 +164,12 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext { BroadcastNestedLoopJoinExec(left, right, BuildLeft, joinType, Some(condition))), expectedAnswer, sortAnswers = true) + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(left.sqlContext.sessionState.conf).apply( + createLeftSemiPlusJoin(BroadcastNestedLoopJoinExec( + left, right, BuildLeft, leftSemiPlus, Some(condition)))), + expectedAnswer, + sortAnswers = true) } } @@ -144,6 +180,12 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext { BroadcastNestedLoopJoinExec(left, right, BuildRight, joinType, Some(condition))), expectedAnswer, sortAnswers = true) + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(left.sqlContext.sessionState.conf).apply( + createLeftSemiPlusJoin(BroadcastNestedLoopJoinExec( + left, right, BuildRight, leftSemiPlus, Some(condition)))), + expectedAnswer, + sortAnswers = true) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index 3ee25c0996035..ede63fea9606f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -19,12 +19,15 @@ package org.apache.spark.sql.execution.joins import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream} +import scala.util.Random + import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager} +import org.apache.spark.serializer.KryoSerializer import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} +import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructField, StructType} import org.apache.spark.unsafe.map.BytesToBytesMap import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.collection.CompactBuffer @@ -111,14 +114,14 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { test("LongToUnsafeRowMap") { val unsafeProj = UnsafeProjection.create( - Seq(BoundReference(0, IntegerType, false), BoundReference(1, IntegerType, true))) - val rows = (0 until 100).map(i => unsafeProj(InternalRow(i, i + 1)).copy()) - val key = Seq(BoundReference(0, IntegerType, false)) + Seq(BoundReference(0, LongType, false), BoundReference(1, IntegerType, true))) + val rows = (0 until 100).map(i => unsafeProj(InternalRow(Int.int2long(i), i + 1)).copy()) + val key = Seq(BoundReference(0, LongType, false)) val longRelation = 
LongHashedRelation(rows.iterator, key, 10, mm) assert(longRelation.keyIsUnique) (0 until 100).foreach { i => val row = longRelation.getValue(i) - assert(row.getInt(0) === i) + assert(row.getLong(0) === i) assert(row.getInt(1) === i + 1) } @@ -127,9 +130,9 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { (0 until 100).foreach { i => val rows = longRelation2.get(i).toArray assert(rows.length === 2) - assert(rows(0).getInt(0) === i) + assert(rows(0).getLong(0) === i) assert(rows(0).getInt(1) === i + 1) - assert(rows(1).getInt(0) === i) + assert(rows(1).getLong(0) === i) assert(rows(1).getInt(1) === i + 1) } @@ -144,13 +147,146 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { (0 until 100).foreach { i => val rows = relation.get(i).toArray assert(rows.length === 2) - assert(rows(0).getInt(0) === i) + assert(rows(0).getLong(0) === i) assert(rows(0).getInt(1) === i + 1) - assert(rows(1).getInt(0) === i) + assert(rows(1).getLong(0) === i) assert(rows(1).getInt(1) === i + 1) } } + test("LongToUnsafeRowMap with very wide range") { + val taskMemoryManager = new TaskMemoryManager( + new StaticMemoryManager( + new SparkConf().set("spark.memory.offHeap.enabled", "false"), + Long.MaxValue, + Long.MaxValue, + 1), + 0) + val unsafeProj = UnsafeProjection.create(Seq(BoundReference(0, LongType, false))) + + { + // SPARK-16740 + val keys = Seq(0L, Long.MaxValue, Long.MaxValue) + val map = new LongToUnsafeRowMap(taskMemoryManager, 1) + keys.foreach { k => + map.append(k, unsafeProj(InternalRow(k))) + } + map.optimize() + val row = unsafeProj(InternalRow(0L)).copy() + keys.foreach { k => + assert(map.getValue(k, row) eq row) + assert(row.getLong(0) === k) + } + map.free() + } + + + { + // SPARK-16802 + val keys = Seq(Long.MaxValue, Long.MaxValue - 10) + val map = new LongToUnsafeRowMap(taskMemoryManager, 1) + keys.foreach { k => + map.append(k, unsafeProj(InternalRow(k))) + } + map.optimize() + val row = unsafeProj(InternalRow(0L)).copy() + keys.foreach { k => + assert(map.getValue(k, row) eq row) + assert(row.getLong(0) === k) + } + assert(map.getValue(Long.MinValue, row) eq null) + map.free() + } + } + + test("LongToUnsafeRowMap with random keys") { + val taskMemoryManager = new TaskMemoryManager( + new StaticMemoryManager( + new SparkConf().set("spark.memory.offHeap.enabled", "false"), + Long.MaxValue, + Long.MaxValue, + 1), + 0) + val unsafeProj = UnsafeProjection.create(Seq(BoundReference(0, LongType, false))) + + val N = 1000000 + val rand = new Random + val keys = (0 to N).map(x => rand.nextLong()).toArray + + val map = new LongToUnsafeRowMap(taskMemoryManager, 10) + keys.foreach { k => + map.append(k, unsafeProj(InternalRow(k))) + } + map.optimize() + + val os = new ByteArrayOutputStream() + val out = new ObjectOutputStream(os) + map.writeExternal(out) + out.flush() + val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray)) + val map2 = new LongToUnsafeRowMap(taskMemoryManager, 1) + map2.readExternal(in) + + val row = unsafeProj(InternalRow(0L)).copy() + keys.foreach { k => + val r = map2.get(k, row) + assert(r.hasNext) + var c = 0 + while (r.hasNext) { + val rr = r.next() + assert(rr.getLong(0) === k) + c += 1 + } + } + var i = 0 + while (i < N * 10) { + val k = rand.nextLong() + val r = map2.get(k, row) + if (r != null) { + assert(r.hasNext) + while (r.hasNext) { + assert(r.next().getLong(0) === k) + } + } + i += 1 + } + map.free() + } + + test("Spark-14521") { + val ser = new KryoSerializer( + (new 
SparkConf).set("spark.kryo.referenceTracking", "false")).newInstance() + val key = Seq(BoundReference(0, LongType, false)) + + // Testing Kryo serialization of HashedRelation + val unsafeProj = UnsafeProjection.create( + Seq(BoundReference(0, LongType, false), BoundReference(1, IntegerType, true))) + val rows = (0 until 100).map(i => unsafeProj(InternalRow(Int.int2long(i), i + 1)).copy()) + val longRelation = LongHashedRelation(rows.iterator ++ rows.iterator, key, 100, mm) + val longRelation2 = ser.deserialize[LongHashedRelation](ser.serialize(longRelation)) + (0 until 100).foreach { i => + val rows = longRelation2.get(i).toArray + assert(rows.length === 2) + assert(rows(0).getLong(0) === i) + assert(rows(0).getInt(1) === i + 1) + assert(rows(1).getLong(0) === i) + assert(rows(1).getInt(1) === i + 1) + } + + // Testing Kryo serialization of UnsafeHashedRelation + val unsafeHashed = UnsafeHashedRelation(rows.iterator, key, 1, mm) + val os = new ByteArrayOutputStream() + val out = new ObjectOutputStream(os) + unsafeHashed.asInstanceOf[UnsafeHashedRelation].writeExternal(out) + out.flush() + val unsafeHashed2 = ser.deserialize[UnsafeHashedRelation](ser.serialize(unsafeHashed)) + val os2 = new ByteArrayOutputStream() + val out2 = new ObjectOutputStream(os2) + unsafeHashed2.writeExternal(out2) + out2.flush() + assert(java.util.Arrays.equals(os.toByteArray, os2.toByteArray)) + } + // This test require 4G heap to run, should run it manually ignore("build HashedRelation that is larger than 1G") { val unsafeProj = UnsafeProjection.create( @@ -177,4 +313,20 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { assert(longRelation.estimatedSize > (2L << 30)) longRelation.close() } + + // This test require 4G heap to run, should run it manually + ignore("build HashedRelation with more than 100 millions rows") { + val unsafeProj = UnsafeProjection.create( + Seq(BoundReference(0, IntegerType, false), + BoundReference(1, StringType, true))) + val unsafeRow = unsafeProj(InternalRow(0, UTF8String.fromString(" " * 100))) + val key = Seq(BoundReference(0, IntegerType, false)) + val rows = (0 until (1 << 10)).iterator.map { i => + unsafeRow.setInt(0, i % 1000000) + unsafeRow.setInt(1, i) + unsafeRow + } + val m = LongHashedRelation(rows, key, 100 << 20, mm) + m.close() + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index 2a4a3690f2078..35dab63672c05 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -32,7 +32,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { import testImplicits.newProductEncoder import testImplicits.localSeqToDatasetHolder - private lazy val myUpperCaseData = sqlContext.createDataFrame( + private lazy val myUpperCaseData = spark.createDataFrame( sparkContext.parallelize(Seq( Row(1, "A"), Row(2, "B"), @@ -43,7 +43,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { Row(null, "G") )), new StructType().add("N", IntegerType).add("L", StringType)) - private lazy val myLowerCaseData = sqlContext.createDataFrame( + private lazy val myLowerCaseData = spark.createDataFrame( sparkContext.parallelize(Seq( Row(1, "a"), Row(2, "b"), @@ -99,7 +99,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { boundCondition, leftPlan, rightPlan) - 
EnsureRequirements(sqlContext.sessionState.conf).apply(broadcastJoin) + EnsureRequirements(spark.sessionState.conf).apply(broadcastJoin) } def makeShuffledHashJoin( @@ -113,7 +113,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { joins.ShuffledHashJoinExec(leftKeys, rightKeys, Inner, side, None, leftPlan, rightPlan) val filteredJoin = boundCondition.map(FilterExec(_, shuffledHashJoin)).getOrElse(shuffledHashJoin) - EnsureRequirements(sqlContext.sessionState.conf).apply(filteredJoin) + EnsureRequirements(spark.sessionState.conf).apply(filteredJoin) } def makeSortMergeJoin( @@ -124,7 +124,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { rightPlan: SparkPlan) = { val sortMergeJoin = joins.SortMergeJoinExec(leftKeys, rightKeys, Inner, boundCondition, leftPlan, rightPlan) - EnsureRequirements(sqlContext.sessionState.conf).apply(sortMergeJoin) + EnsureRequirements(spark.sessionState.conf).apply(sortMergeJoin) } test(s"$testName using BroadcastHashJoin (build=left)") { @@ -187,7 +187,8 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { } test(s"$testName using CartesianProduct") { - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1", + SQLConf.CROSS_JOINS_ENABLED.key -> "true") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => CartesianProductExec(left, right, Some(condition())), expectedAnswer.map(Row.fromTuple), @@ -270,4 +271,19 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { ) ) } + + { + def df: DataFrame = spark.range(3).selectExpr("struct(id, id) as key", "id as value") + lazy val left = df.selectExpr("key", "concat('L', value) as value").alias("left") + lazy val right = df.selectExpr("key", "concat('R', value) as value").alias("right") + testInnerJoin( + "SPARK-15822 - test structs as keys", + left, + right, + () => (left.col("key") === right.col("key")).expr, + Seq( + (Row(0, 0), "L0", Row(0, 0), "R0"), + (Row(1, 1), "L1", Row(1, 1), "R1"), + (Row(2, 2), "L2", Row(2, 2), "R2"))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index c26cb8483eb17..001feb0f2b399 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType} class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { - private lazy val left = sqlContext.createDataFrame( + private lazy val left = spark.createDataFrame( sparkContext.parallelize(Seq( Row(1, 2.0), Row(2, 100.0), @@ -42,7 +42,7 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { Row(null, null) )), new StructType().add("a", IntegerType).add("b", DoubleType)) - private lazy val right = sqlContext.createDataFrame( + private lazy val right = spark.createDataFrame( sparkContext.parallelize(Seq( Row(0, 0.0), Row(2, 3.0), // This row is duplicated to ensure that we will have multiple buffered matches @@ -82,7 +82,7 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { val buildSide = if (joinType == LeftOuter) BuildRight else BuildLeft checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - EnsureRequirements(sqlContext.sessionState.conf).apply( + 
EnsureRequirements(spark.sessionState.conf).apply( ShuffledHashJoinExec( leftKeys, rightKeys, joinType, buildSide, boundCondition, left, right)), expectedAnswer.map(Row.fromTuple), @@ -115,7 +115,7 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - EnsureRequirements(sqlContext.sessionState.conf).apply( + EnsureRequirements(spark.sessionState.conf).apply( SortMergeJoinExec(leftKeys, rightKeys, joinType, boundCondition, left, right)), expectedAnswer.map(Row.fromTuple), sortAnswers = true) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 8de4d8bbd4e07..229d8814e0143 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -1,64 +1,35 @@ /* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ package org.apache.spark.sql.execution.metric -import java.io.{ByteArrayInputStream, ByteArrayOutputStream} - -import scala.collection.mutable - -import org.apache.xbean.asm5._ -import org.apache.xbean.asm5.Opcodes._ - import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.execution.SparkPlanInfo import org.apache.spark.sql.execution.ui.SparkPlanGraph import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.util.{JsonProtocol, Utils} - +import org.apache.spark.util.{AccumulatorContext, JsonProtocol} class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { import testImplicits._ - test("SQLMetric should not box Long") { - val l = SQLMetrics.createMetric(sparkContext, "long") - val f = () => { - l += 1L - l.add(1L) - } - val cl = BoxingFinder.getClassReader(f.getClass) - val boxingFinder = new BoxingFinder() - cl.accept(boxingFinder, 0) - assert(boxingFinder.boxingInvokes.isEmpty, s"Found boxing: ${boxingFinder.boxingInvokes}") - } - - test("Normal accumulator should do boxing") { - // We need this test to make sure BoxingFinder works. - val l = sparkContext.accumulator(0L) - val f = () => { l += 1L } - val cl = BoxingFinder.getClassReader(f.getClass) - val boxingFinder = new BoxingFinder() - cl.accept(boxingFinder, 0) - assert(boxingFinder.boxingInvokes.nonEmpty, "Found find boxing in this test") - } - /** * Call `df.collect()` and verify if the collected metrics are same as "expectedMetrics". * @@ -71,21 +42,22 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { df: DataFrame, expectedNumOfJobs: Int, expectedMetrics: Map[Long, (String, Map[String, Any])]): Unit = { - val previousExecutionIds = sqlContext.listener.executionIdToData.keySet + val previousExecutionIds = spark.sharedState.listener.executionIdToData.keySet withSQLConf("spark.sql.codegen.wholeStage" -> "false") { df.collect() } sparkContext.listenerBus.waitUntilEmpty(10000) - val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds) + val executionIds = + spark.sharedState.listener.executionIdToData.keySet.diff(previousExecutionIds) assert(executionIds.size === 1) val executionId = executionIds.head - val jobs = sqlContext.listener.getExecution(executionId).get.jobs + val jobs = spark.sharedState.listener.getExecution(executionId).get.jobs // Use "<=" because there is a race condition that we may miss some jobs // TODO Change it to "=" once we fix the race condition that missing the JobStarted event. 
assert(jobs.size <= expectedNumOfJobs) if (jobs.size == expectedNumOfJobs) { // If we can track all jobs, check the metric values - val metricValues = sqlContext.listener.getExecutionMetrics(executionId) + val metricValues = spark.sharedState.listener.getExecutionMetrics(executionId) val actualMetrics = SparkPlanGraph(SparkPlanInfo.fromSparkPlan( df.queryExecution.executedPlan)).allNodes.filter { node => expectedMetrics.contains(node.id) @@ -114,6 +86,22 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { } } + test("LocalTableScanExec computes metrics in collect and take") { + val df1 = spark.createDataset(Seq(1, 2, 3)) + val logical = df1.queryExecution.logical + require(logical.isInstanceOf[LocalRelation]) + df1.collect() + val metrics1 = df1.queryExecution.executedPlan.collectLeaves().head.metrics + assert(metrics1.contains("numOutputRows")) + assert(metrics1("numOutputRows").value === 3) + + val df2 = spark.createDataset(Seq(1, 2, 3)).limit(2) + df2.collect() + val metrics2 = df2.queryExecution.executedPlan.collectLeaves().head.metrics + assert(metrics2.contains("numOutputRows")) + assert(metrics2("numOutputRows").value === 2) + } + test("Filter metrics") { // Assume the execution plan is // PhysicalRDD(nodeId = 1) -> Filter(nodeId = 0) @@ -128,36 +116,32 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { // Assume the execution plan is // WholeStageCodegen(nodeId = 0, Range(nodeId = 2) -> Filter(nodeId = 1)) // TODO: update metrics in generated operators - val ds = sqlContext.range(10).filter('id < 5) + val ds = spark.range(10).filter('id < 5) testSparkPlanMetrics(ds.toDF(), 1, Map.empty) } - test("TungstenAggregate metrics") { + test("Aggregate metrics") { // Assume the execution plan is - // ... -> TungstenAggregate(nodeId = 2) -> Exchange(nodeId = 1) - // -> TungstenAggregate(nodeId = 0) + // ... -> HashAggregate(nodeId = 2) -> Exchange(nodeId = 1) + // -> HashAggregate(nodeId = 0) val df = testData2.groupBy().count() // 2 partitions testSparkPlanMetrics(df, 1, Map( - 2L -> ("TungstenAggregate", Map( - "number of output rows" -> 2L)), - 0L -> ("TungstenAggregate", Map( - "number of output rows" -> 1L))) + 2L -> ("HashAggregate", Map("number of output rows" -> 2L)), + 0L -> ("HashAggregate", Map("number of output rows" -> 1L))) ) // 2 partitions and each partition contains 2 keys val df2 = testData2.groupBy('a).count() testSparkPlanMetrics(df2, 1, Map( - 2L -> ("TungstenAggregate", Map( - "number of output rows" -> 4L)), - 0L -> ("TungstenAggregate", Map( - "number of output rows" -> 3L))) + 2L -> ("HashAggregate", Map("number of output rows" -> 4L)), + 0L -> ("HashAggregate", Map("number of output rows" -> 3L))) ) } test("Sort metrics") { // Assume the execution plan is // WholeStageCodegen(nodeId = 0, Range(nodeId = 2) -> Sort(nodeId = 1)) - val ds = sqlContext.range(10).sort('id) + val ds = spark.range(10).sort('id) testSparkPlanMetrics(ds.toDF(), 2, Map.empty) } @@ -165,11 +149,11 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { // Because SortMergeJoin may skip different rows if the number of partitions is different, this // test should use the deterministic number of partitions. val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) - testDataForJoin.registerTempTable("testDataForJoin") - withTempTable("testDataForJoin") { + testDataForJoin.createOrReplaceTempView("testDataForJoin") + withTempView("testDataForJoin") { // Assume the execution plan is // ... 
-> SortMergeJoin(nodeId = 1) -> TungstenProject(nodeId = 0) - val df = sqlContext.sql( + val df = spark.sql( "SELECT * FROM testData2 JOIN testDataForJoin ON testData2.a = testDataForJoin.a") testSparkPlanMetrics(df, 1, Map( 0L -> ("SortMergeJoin", Map( @@ -183,11 +167,11 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { // Because SortMergeJoin may skip different rows if the number of partitions is different, // this test should use the deterministic number of partitions. val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) - testDataForJoin.registerTempTable("testDataForJoin") - withTempTable("testDataForJoin") { + testDataForJoin.createOrReplaceTempView("testDataForJoin") + withTempView("testDataForJoin") { // Assume the execution plan is // ... -> SortMergeJoin(nodeId = 1) -> TungstenProject(nodeId = 0) - val df = sqlContext.sql( + val df = spark.sql( "SELECT * FROM testData2 left JOIN testDataForJoin ON testData2.a = testDataForJoin.a") testSparkPlanMetrics(df, 1, Map( 0L -> ("SortMergeJoin", Map( @@ -195,7 +179,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { "number of output rows" -> 8L))) ) - val df2 = sqlContext.sql( + val df2 = spark.sql( "SELECT * FROM testDataForJoin right JOIN testData2 ON testData2.a = testDataForJoin.a") testSparkPlanMetrics(df2, 1, Map( 0L -> ("SortMergeJoin", Map( @@ -237,17 +221,19 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { test("BroadcastNestedLoopJoin metrics") { val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) - testDataForJoin.registerTempTable("testDataForJoin") - withTempTable("testDataForJoin") { - // Assume the execution plan is - // ... -> BroadcastNestedLoopJoin(nodeId = 1) -> TungstenProject(nodeId = 0) - val df = sqlContext.sql( - "SELECT * FROM testData2 left JOIN testDataForJoin ON " + - "testData2.a * testDataForJoin.a != testData2.a + testDataForJoin.a") - testSparkPlanMetrics(df, 3, Map( - 1L -> ("BroadcastNestedLoopJoin", Map( - "number of output rows" -> 12L))) - ) + testDataForJoin.createOrReplaceTempView("testDataForJoin") + withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + withTempView("testDataForJoin") { + // Assume the execution plan is + // ... -> BroadcastNestedLoopJoin(nodeId = 1) -> TungstenProject(nodeId = 0) + val df = spark.sql( + "SELECT * FROM testData2 left JOIN testDataForJoin ON " + + "testData2.a * testDataForJoin.a != testData2.a + testDataForJoin.a") + testSparkPlanMetrics(df, 3, Map( + 1L -> ("BroadcastNestedLoopJoin", Map( + "number of output rows" -> 12L))) + ) + } } } @@ -264,35 +250,37 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { } test("CartesianProduct metrics") { - val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) - testDataForJoin.registerTempTable("testDataForJoin") - withTempTable("testDataForJoin") { - // Assume the execution plan is - // ... -> CartesianProduct(nodeId = 1) -> TungstenProject(nodeId = 0) - val df = sqlContext.sql( - "SELECT * FROM testData2 JOIN testDataForJoin") - testSparkPlanMetrics(df, 1, Map( - 0L -> ("CartesianProduct", Map( - "number of output rows" -> 12L))) - ) + withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) + testDataForJoin.createOrReplaceTempView("testDataForJoin") + withTempView("testDataForJoin") { + // Assume the execution plan is + // ... 
-> CartesianProduct(nodeId = 1) -> TungstenProject(nodeId = 0) + val df = spark.sql( + "SELECT * FROM testData2 JOIN testDataForJoin") + testSparkPlanMetrics(df, 1, Map( + 0L -> ("CartesianProduct", Map("number of output rows" -> 12L))) + ) + } } } test("save metrics") { withTempPath { file => - val previousExecutionIds = sqlContext.listener.executionIdToData.keySet + val previousExecutionIds = spark.sharedState.listener.executionIdToData.keySet // Assume the execution plan is // PhysicalRDD(nodeId = 0) person.select('name).write.format("json").save(file.getAbsolutePath) sparkContext.listenerBus.waitUntilEmpty(10000) - val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds) + val executionIds = + spark.sharedState.listener.executionIdToData.keySet.diff(previousExecutionIds) assert(executionIds.size === 1) val executionId = executionIds.head - val jobs = sqlContext.listener.getExecution(executionId).get.jobs + val jobs = spark.sharedState.listener.getExecution(executionId).get.jobs // Use "<=" because there is a race condition that we may miss some jobs // TODO Change "<=" to "=" once we fix the race condition that missing the JobStarted event. assert(jobs.size <= 1) - val metricValues = sqlContext.listener.getExecutionMetrics(executionId) + val metricValues = spark.sharedState.listener.getExecutionMetrics(executionId) // Because "save" will create a new DataFrame internally, we cannot get the real metric id. // However, we still can check the value. assert(metricValues.values.toSeq.exists(_ === "2")) @@ -302,13 +290,13 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { test("metrics can be loaded by history server") { val metric = SQLMetrics.createMetric(sparkContext, "zanzibar") metric += 10L - val metricInfo = metric.toInfo(Some(metric.localValue), None) + val metricInfo = metric.toInfo(Some(metric.value), None) metricInfo.update match { case Some(v: Long) => assert(v === 10L) case Some(v) => fail(s"metric value was not a Long: ${v.getClass.getName}") case _ => fail("metric update is missing") } - assert(metricInfo.metadata === Some(SQLMetrics.ACCUM_IDENTIFIER)) + assert(metricInfo.metadata === Some(AccumulatorContext.SQL_ACCUM_IDENTIFIER)) // After serializing to JSON, the original value type is lost, but we can still // identify that it's a SQL metric from the metadata val metricInfoJson = JsonProtocol.accumulableInfoToJson(metricInfo) @@ -318,80 +306,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { case Some(v) => fail(s"deserialized metric value was not a string: ${v.getClass.getName}") case _ => fail("deserialized metric update is missing") } - assert(metricInfoDeser.metadata === Some(SQLMetrics.ACCUM_IDENTIFIER)) - } - -} - -private case class MethodIdentifier[T](cls: Class[T], name: String, desc: String) - -/** - * If `method` is null, search all methods of this class recursively to find if they do some boxing. - * If `method` is specified, only search this method of the class to speed up the searching. - * - * This method will skip the methods in `visitedMethods` to avoid potential infinite cycles. 
- */ -private class BoxingFinder( - method: MethodIdentifier[_] = null, - val boxingInvokes: mutable.Set[String] = mutable.Set.empty, - visitedMethods: mutable.Set[MethodIdentifier[_]] = mutable.Set.empty) - extends ClassVisitor(ASM5) { - - private val primitiveBoxingClassName = - Set("java/lang/Long", - "java/lang/Double", - "java/lang/Integer", - "java/lang/Float", - "java/lang/Short", - "java/lang/Character", - "java/lang/Byte", - "java/lang/Boolean") - - override def visitMethod( - access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): - MethodVisitor = { - if (method != null && (method.name != name || method.desc != desc)) { - // If method is specified, skip other methods. - return new MethodVisitor(ASM5) {} - } - - new MethodVisitor(ASM5) { - override def visitMethodInsn( - op: Int, owner: String, name: String, desc: String, itf: Boolean) { - if (op == INVOKESPECIAL && name == "" || op == INVOKESTATIC && name == "valueOf") { - if (primitiveBoxingClassName.contains(owner)) { - // Find boxing methods, e.g, new java.lang.Long(l) or java.lang.Long.valueOf(l) - boxingInvokes.add(s"$owner.$name") - } - } else { - // scalastyle:off classforname - val classOfMethodOwner = Class.forName(owner.replace('/', '.'), false, - Thread.currentThread.getContextClassLoader) - // scalastyle:on classforname - val m = MethodIdentifier(classOfMethodOwner, name, desc) - if (!visitedMethods.contains(m)) { - // Keep track of visited methods to avoid potential infinite cycles - visitedMethods += m - val cl = BoxingFinder.getClassReader(classOfMethodOwner) - visitedMethods += m - cl.accept(new BoxingFinder(m, boxingInvokes, visitedMethods), 0) - } - } - } - } - } -} - -private object BoxingFinder { - - def getClassReader(cls: Class[_]): ClassReader = { - val className = cls.getName.replaceFirst("^.*\\.", "") + ".class" - val resourceStream = cls.getResourceAsStream(className) - val baos = new ByteArrayOutputStream(128) - // Copy data over, before delegating to ClassReader - - // else we can run out of open file handles. - Utils.copyStream(resourceStream, baos, true) - new ClassReader(new ByteArrayInputStream(baos.toByteArray)) + assert(metricInfoDeser.metadata === Some(AccumulatorContext.SQL_ACCUM_IDENTIFIER)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/stat/ApproxQuantileSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/stat/ApproxQuantileSuite.scala index 0a989d026ce1c..8bd6b3c5cdc8d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/stat/ApproxQuantileSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/stat/ApproxQuantileSuite.scala @@ -42,6 +42,20 @@ class ApproxQuantileSuite extends SparkFunSuite { summary.compress() } + /** + * Interleaves compression and insertions. + */ + private def buildCompressSummary( + data: Seq[Double], + epsi: Double, + threshold: Int): QuantileSummaries = { + var summary = new QuantileSummaries(threshold, epsi) + data.foreach { x => + summary = summary.insert(x).compress() + } + summary + } + private def checkQuantile(quant: Double, data: Seq[Double], summary: QuantileSummaries): Unit = { val approx = summary.query(quant) // The rank of the approximation. 
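The new `buildCompressSummary` helper above interleaves `insert` and `compress`, so the quantile sketch is exercised both with and without intermediate compaction. For context on the user-facing API that this `QuantileSummaries` machinery backs, here is a minimal sketch; the session setup, column name, and data are illustrative and not part of this patch, and it assumes only the `DataFrameStatFunctions.approxQuantile(col, probabilities, relativeError)` method available in Spark 2.x:

```scala
import org.apache.spark.sql.SparkSession

object ApproxQuantileSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("approx-quantile-sketch")
      .getOrCreate()
    import spark.implicits._

    // 1000 doubles; approxQuantile is backed by the QuantileSummaries sketch tested above.
    val df = (1 to 1000).map(_.toDouble).toDF("x")

    // relativeError plays the role of `epsi` in the tests: the returned value's rank
    // is guaranteed to lie within +/- relativeError * n of the exact quantile's rank.
    val median = df.stat.approxQuantile("x", Array(0.5), 0.01).head
    println(median) // close to 500.0, within the epsilon rank guarantee

    spark.stop()
  }
}
```

Smaller `relativeError` values tighten the rank guarantee at the cost of a larger summary, which is the trade-off the `epsi`/`compression` combinations in these tests sweep over.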
@@ -56,8 +70,8 @@ class ApproxQuantileSuite extends SparkFunSuite { for { (seq_name, data) <- Seq(increasing, decreasing, random) - epsi <- Seq(0.1, 0.0001) - compression <- Seq(1000, 10) + epsi <- Seq(0.1, 0.0001) // With a significant value and with full precision + compression <- Seq(1000, 10) // This interleaves n so that we test without and with compression } { test(s"Extremas with epsi=$epsi and seq=$seq_name, compression=$compression") { @@ -77,6 +91,17 @@ class ApproxQuantileSuite extends SparkFunSuite { checkQuantile(0.1, data, s) checkQuantile(0.001, data, s) } + + test(s"Some quantile values with epsi=$epsi and seq=$seq_name, compression=$compression " + + s"(interleaved)") { + val s = buildCompressSummary(data, epsi, compression) + assert(s.count == data.size, s"Found count=${s.count} but data size=${data.size}") + checkQuantile(0.9999, data, s) + checkQuantile(0.9, data, s) + checkQuantile(0.5, data, s) + checkQuantile(0.1, data, s) + checkQuantile(0.001, data, s) + } } // Tests for merging procedure diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala index 7b413dda1edd9..e1bc674a28071 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.streaming +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import java.nio.charset.StandardCharsets.UTF_8 import org.apache.spark.SparkFunSuite @@ -25,13 +26,14 @@ import org.apache.spark.sql.test.SharedSQLContext class FileStreamSinkLogSuite extends SparkFunSuite with SharedSQLContext { + import CompactibleFileStreamLog._ import FileStreamSinkLog._ test("getBatchIdFromFileName") { assert(1234L === getBatchIdFromFileName("1234")) assert(1234L === getBatchIdFromFileName("1234.compact")) intercept[NumberFormatException] { - FileStreamSinkLog.getBatchIdFromFileName("1234a") + getBatchIdFromFileName("1234a") } } @@ -83,22 +85,24 @@ class FileStreamSinkLogSuite extends SparkFunSuite with SharedSQLContext { } test("compactLogs") { - val logs = Seq( - newFakeSinkFileStatus("/a/b/x", FileStreamSinkLog.ADD_ACTION), - newFakeSinkFileStatus("/a/b/y", FileStreamSinkLog.ADD_ACTION), - newFakeSinkFileStatus("/a/b/z", FileStreamSinkLog.ADD_ACTION)) - assert(logs === compactLogs(logs)) + withFileStreamSinkLog { sinkLog => + val logs = Seq( + newFakeSinkFileStatus("/a/b/x", FileStreamSinkLog.ADD_ACTION), + newFakeSinkFileStatus("/a/b/y", FileStreamSinkLog.ADD_ACTION), + newFakeSinkFileStatus("/a/b/z", FileStreamSinkLog.ADD_ACTION)) + assert(logs === sinkLog.compactLogs(logs)) - val logs2 = Seq( - newFakeSinkFileStatus("/a/b/m", FileStreamSinkLog.ADD_ACTION), - newFakeSinkFileStatus("/a/b/n", FileStreamSinkLog.ADD_ACTION), - newFakeSinkFileStatus("/a/b/z", FileStreamSinkLog.DELETE_ACTION)) - assert(logs.dropRight(1) ++ logs2.dropRight(1) === compactLogs(logs ++ logs2)) + val logs2 = Seq( + newFakeSinkFileStatus("/a/b/m", FileStreamSinkLog.ADD_ACTION), + newFakeSinkFileStatus("/a/b/n", FileStreamSinkLog.ADD_ACTION), + newFakeSinkFileStatus("/a/b/z", FileStreamSinkLog.DELETE_ACTION)) + assert(logs.dropRight(1) ++ logs2.dropRight(1) === sinkLog.compactLogs(logs ++ logs2)) + } } test("serialize") { withFileStreamSinkLog { sinkLog => - val logs = Seq( + val logs = Array( SinkFileStatus( path = 
"/a/b/x", size = 100L, @@ -125,21 +129,24 @@ class FileStreamSinkLogSuite extends SparkFunSuite with SharedSQLContext { action = FileStreamSinkLog.ADD_ACTION)) // scalastyle:off - val expected = s"""${FileStreamSinkLog.VERSION} + val expected = s"""$VERSION |{"path":"/a/b/x","size":100,"isDir":false,"modificationTime":1000,"blockReplication":1,"blockSize":10000,"action":"add"} |{"path":"/a/b/y","size":200,"isDir":false,"modificationTime":2000,"blockReplication":2,"blockSize":20000,"action":"delete"} |{"path":"/a/b/z","size":300,"isDir":false,"modificationTime":3000,"blockReplication":3,"blockSize":30000,"action":"add"}""".stripMargin // scalastyle:on - assert(expected === new String(sinkLog.serialize(logs), UTF_8)) - - assert(FileStreamSinkLog.VERSION === new String(sinkLog.serialize(Nil), UTF_8)) + val baos = new ByteArrayOutputStream() + sinkLog.serialize(logs, baos) + assert(expected === baos.toString(UTF_8.name())) + baos.reset() + sinkLog.serialize(Array(), baos) + assert(VERSION === baos.toString(UTF_8.name())) } } test("deserialize") { withFileStreamSinkLog { sinkLog => // scalastyle:off - val logs = s"""${FileStreamSinkLog.VERSION} + val logs = s"""$VERSION |{"path":"/a/b/x","size":100,"isDir":false,"modificationTime":1000,"blockReplication":1,"blockSize":10000,"action":"add"} |{"path":"/a/b/y","size":200,"isDir":false,"modificationTime":2000,"blockReplication":2,"blockSize":20000,"action":"delete"} |{"path":"/a/b/z","size":300,"isDir":false,"modificationTime":3000,"blockReplication":3,"blockSize":30000,"action":"add"}""".stripMargin @@ -171,9 +178,9 @@ class FileStreamSinkLogSuite extends SparkFunSuite with SharedSQLContext { blockSize = 30000L, action = FileStreamSinkLog.ADD_ACTION)) - assert(expected === sinkLog.deserialize(logs.getBytes(UTF_8))) + assert(expected === sinkLog.deserialize(new ByteArrayInputStream(logs.getBytes(UTF_8)))) - assert(Nil === sinkLog.deserialize(FileStreamSinkLog.VERSION.getBytes(UTF_8))) + assert(Nil === sinkLog.deserialize(new ByteArrayInputStream(VERSION.getBytes(UTF_8)))) } } @@ -190,13 +197,13 @@ class FileStreamSinkLogSuite extends SparkFunSuite with SharedSQLContext { } } - test("compact") { + testWithUninterruptibleThread("compact") { withSQLConf(SQLConf.FILE_SINK_LOG_COMPACT_INTERVAL.key -> "3") { withFileStreamSinkLog { sinkLog => for (batchId <- 0 to 10) { sinkLog.add( batchId, - Seq(newFakeSinkFileStatus("/a/b/" + batchId, FileStreamSinkLog.ADD_ACTION))) + Array(newFakeSinkFileStatus("/a/b/" + batchId, FileStreamSinkLog.ADD_ACTION))) val expectedFiles = (0 to batchId).map { id => newFakeSinkFileStatus("/a/b/" + id, FileStreamSinkLog.ADD_ACTION) } @@ -210,14 +217,14 @@ class FileStreamSinkLogSuite extends SparkFunSuite with SharedSQLContext { } } - test("delete expired file") { + testWithUninterruptibleThread("delete expired file") { // Set FILE_SINK_LOG_CLEANUP_DELAY to 0 so that we can detect the deleting behaviour // deterministically withSQLConf( SQLConf.FILE_SINK_LOG_COMPACT_INTERVAL.key -> "3", SQLConf.FILE_SINK_LOG_CLEANUP_DELAY.key -> "0") { withFileStreamSinkLog { sinkLog => - val fs = sinkLog.metadataPath.getFileSystem(sqlContext.sessionState.newHadoopConf()) + val fs = sinkLog.metadataPath.getFileSystem(spark.sessionState.newHadoopConf()) def listBatchFiles(): Set[String] = { fs.listStatus(sinkLog.metadataPath).map(_.getPath.getName).filter { fileName => @@ -230,17 +237,17 @@ class FileStreamSinkLogSuite extends SparkFunSuite with SharedSQLContext { }.toSet } - sinkLog.add(0, Seq(newFakeSinkFileStatus("/a/b/0", 
FileStreamSinkLog.ADD_ACTION))) + sinkLog.add(0, Array(newFakeSinkFileStatus("/a/b/0", FileStreamSinkLog.ADD_ACTION))) assert(Set("0") === listBatchFiles()) - sinkLog.add(1, Seq(newFakeSinkFileStatus("/a/b/1", FileStreamSinkLog.ADD_ACTION))) + sinkLog.add(1, Array(newFakeSinkFileStatus("/a/b/1", FileStreamSinkLog.ADD_ACTION))) assert(Set("0", "1") === listBatchFiles()) - sinkLog.add(2, Seq(newFakeSinkFileStatus("/a/b/2", FileStreamSinkLog.ADD_ACTION))) + sinkLog.add(2, Array(newFakeSinkFileStatus("/a/b/2", FileStreamSinkLog.ADD_ACTION))) assert(Set("2.compact") === listBatchFiles()) - sinkLog.add(3, Seq(newFakeSinkFileStatus("/a/b/3", FileStreamSinkLog.ADD_ACTION))) + sinkLog.add(3, Array(newFakeSinkFileStatus("/a/b/3", FileStreamSinkLog.ADD_ACTION))) assert(Set("2.compact", "3") === listBatchFiles()) - sinkLog.add(4, Seq(newFakeSinkFileStatus("/a/b/4", FileStreamSinkLog.ADD_ACTION))) + sinkLog.add(4, Array(newFakeSinkFileStatus("/a/b/4", FileStreamSinkLog.ADD_ACTION))) assert(Set("2.compact", "3", "4") === listBatchFiles()) - sinkLog.add(5, Seq(newFakeSinkFileStatus("/a/b/5", FileStreamSinkLog.ADD_ACTION))) + sinkLog.add(5, Array(newFakeSinkFileStatus("/a/b/5", FileStreamSinkLog.ADD_ACTION))) assert(Set("5.compact") === listBatchFiles()) } } @@ -263,7 +270,7 @@ class FileStreamSinkLogSuite extends SparkFunSuite with SharedSQLContext { private def withFileStreamSinkLog(f: FileStreamSinkLog => Unit): Unit = { withTempDir { file => - val sinkLog = new FileStreamSinkLog(sqlContext.sparkSession, file.getCanonicalPath) + val sinkLog = new FileStreamSinkLog(FileStreamSinkLog.VERSION, spark, file.getCanonicalPath) f(sinkLog) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceSuite.scala new file mode 100644 index 0000000000000..3bad5bb093b3b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceSuite.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming + +import java.io.{File, FileNotFoundException} +import java.net.URI + +import scala.util.Random + +import org.apache.hadoop.fs.{FileStatus, Path, RawLocalFileSystem} + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.execution.streaming.ExistsThrowsExceptionFileSystem._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.StructType + +class FileStreamSourceSuite extends SparkFunSuite with SharedSQLContext { + + import FileStreamSource._ + + test("SeenFilesMap") { + val map = new SeenFilesMap(maxAgeMs = 10) + + map.add("a", 5) + assert(map.size == 1) + map.purge() + assert(map.size == 1) + + // Add a new entry and purge should be no-op, since the gap is exactly 10 ms. + map.add("b", 15) + assert(map.size == 2) + map.purge() + assert(map.size == 2) + + // Add a new entry that's more than 10 ms than the first entry. We should be able to purge now. + map.add("c", 16) + assert(map.size == 3) + map.purge() + assert(map.size == 2) + + // Override existing entry shouldn't change the size + map.add("c", 25) + assert(map.size == 2) + + // Not a new file because we have seen c before + assert(!map.isNewFile("c", 20)) + + // Not a new file because timestamp is too old + assert(!map.isNewFile("d", 5)) + + // Finally a new file: never seen and not too old + assert(map.isNewFile("e", 20)) + } + + test("SeenFilesMap should only consider a file old if it is earlier than last purge time") { + val map = new SeenFilesMap(maxAgeMs = 10) + + map.add("a", 20) + assert(map.size == 1) + + // Timestamp 5 should still considered a new file because purge time should be 0 + assert(map.isNewFile("b", 9)) + assert(map.isNewFile("b", 10)) + + // Once purge, purge time should be 10 and then b would be a old file if it is less than 10. + map.purge() + assert(!map.isNewFile("b", 9)) + assert(map.isNewFile("b", 10)) + } + + testWithUninterruptibleThread("do not recheck that files exist during getBatch") { + withTempDir { temp => + spark.conf.set( + s"fs.$scheme.impl", + classOf[ExistsThrowsExceptionFileSystem].getName) + // add the metadata entries as a pre-req + val dir = new File(temp, "dir") // use non-existent directory to test whether log make the dir + val metadataLog = + new FileStreamSourceLog(FileStreamSourceLog.VERSION, spark, dir.getAbsolutePath) + assert(metadataLog.add(0, Array(FileEntry(s"$scheme:///file1", 100L, 0)))) + + val newSource = new FileStreamSource(spark, s"$scheme:///", "parquet", StructType(Nil), Nil, + dir.getAbsolutePath, Map.empty) + // this method should throw an exception if `fs.exists` is called during resolveRelation + newSource.getBatch(None, LongOffset(1)) + } + } +} + +/** Fake FileSystem to test whether the method `fs.exists` is called during + * `DataSource.resolveRelation`. + */ +class ExistsThrowsExceptionFileSystem extends RawLocalFileSystem { + override def getUri: URI = { + URI.create(s"$scheme:///") + } + + override def exists(f: Path): Boolean = { + throw new IllegalArgumentException("Exists shouldn't have been called!") + } + + /** Simply return an empty file for now. 
*/ + override def listStatus(file: Path): Array[FileStatus] = { + throw new FileNotFoundException("Folder was suddenly deleted but this should not make it fail!") + } +} + +object ExistsThrowsExceptionFileSystem { + val scheme = s"FileStreamSourceSuite${math.abs(Random.nextInt)}fs" +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala new file mode 100644 index 0000000000000..9e059216110f2 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala @@ -0,0 +1,219 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.util.concurrent.ConcurrentLinkedQueue + +import scala.collection.mutable + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.SparkException +import org.apache.spark.sql.ForeachWriter +import org.apache.spark.sql.streaming.{OutputMode, StreamingQueryException, StreamTest} +import org.apache.spark.sql.test.SharedSQLContext + +class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAfter { + + import testImplicits._ + + after { + sqlContext.streams.active.foreach(_.stop()) + } + + test("foreach() with `append` output mode") { + withTempDir { checkpointDir => + val input = MemoryStream[Int] + val query = input.toDS().repartition(2).writeStream + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .outputMode(OutputMode.Append) + .foreach(new TestForeachWriter()) + .start() + + // -- batch 0 --------------------------------------- + input.addData(1, 2, 3, 4) + query.processAllAvailable() + + var expectedEventsForPartition0 = Seq( + ForeachSinkSuite.Open(partition = 0, version = 0), + ForeachSinkSuite.Process(value = 1), + ForeachSinkSuite.Process(value = 3), + ForeachSinkSuite.Close(None) + ) + var expectedEventsForPartition1 = Seq( + ForeachSinkSuite.Open(partition = 1, version = 0), + ForeachSinkSuite.Process(value = 2), + ForeachSinkSuite.Process(value = 4), + ForeachSinkSuite.Close(None) + ) + + var allEvents = ForeachSinkSuite.allEvents() + assert(allEvents.size === 2) + assert(allEvents.toSet === Set(expectedEventsForPartition0, expectedEventsForPartition1)) + + ForeachSinkSuite.clear() + + // -- batch 1 --------------------------------------- + input.addData(5, 6, 7, 8) + query.processAllAvailable() + + expectedEventsForPartition0 = Seq( + ForeachSinkSuite.Open(partition = 0, version = 1), + ForeachSinkSuite.Process(value = 5), + ForeachSinkSuite.Process(value = 7), + ForeachSinkSuite.Close(None) + ) + expectedEventsForPartition1 = Seq( + ForeachSinkSuite.Open(partition = 1, version = 1), + ForeachSinkSuite.Process(value = 6), + 
ForeachSinkSuite.Process(value = 8), + ForeachSinkSuite.Close(None) + ) + + allEvents = ForeachSinkSuite.allEvents() + assert(allEvents.size === 2) + assert(allEvents.toSet === Set(expectedEventsForPartition0, expectedEventsForPartition1)) + + query.stop() + } + } + + test("foreach() with `complete` output mode") { + withTempDir { checkpointDir => + val input = MemoryStream[Int] + + val query = input.toDS() + .groupBy().count().as[Long].map(_.toInt) + .writeStream + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .outputMode(OutputMode.Complete) + .foreach(new TestForeachWriter()) + .start() + + // -- batch 0 --------------------------------------- + input.addData(1, 2, 3, 4) + query.processAllAvailable() + + var allEvents = ForeachSinkSuite.allEvents() + assert(allEvents.size === 1) + var expectedEvents = Seq( + ForeachSinkSuite.Open(partition = 0, version = 0), + ForeachSinkSuite.Process(value = 4), + ForeachSinkSuite.Close(None) + ) + assert(allEvents === Seq(expectedEvents)) + + ForeachSinkSuite.clear() + + // -- batch 1 --------------------------------------- + input.addData(5, 6, 7, 8) + query.processAllAvailable() + + allEvents = ForeachSinkSuite.allEvents() + assert(allEvents.size === 1) + expectedEvents = Seq( + ForeachSinkSuite.Open(partition = 0, version = 1), + ForeachSinkSuite.Process(value = 8), + ForeachSinkSuite.Close(None) + ) + assert(allEvents === Seq(expectedEvents)) + + query.stop() + } + } + + testQuietly("foreach with error") { + withTempDir { checkpointDir => + val input = MemoryStream[Int] + val query = input.toDS().repartition(1).writeStream + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .foreach(new TestForeachWriter() { + override def process(value: Int): Unit = { + super.process(value) + throw new RuntimeException("error") + } + }).start() + input.addData(1, 2, 3, 4) + + // Error in `process` should fail the Spark job + val e = intercept[StreamingQueryException] { + query.processAllAvailable() + } + assert(e.getCause.isInstanceOf[SparkException]) + assert(e.getCause.getCause.getMessage === "error") + assert(query.isActive === false) + + val allEvents = ForeachSinkSuite.allEvents() + assert(allEvents.size === 1) + assert(allEvents(0)(0) === ForeachSinkSuite.Open(partition = 0, version = 0)) + assert(allEvents(0)(1) === ForeachSinkSuite.Process(value = 1)) + + // `close` should be called with the error + val errorEvent = allEvents(0)(2).asInstanceOf[ForeachSinkSuite.Close] + assert(errorEvent.error.get.isInstanceOf[RuntimeException]) + assert(errorEvent.error.get.getMessage === "error") + } + } +} + +/** A global object to collect events in the executor */ +object ForeachSinkSuite { + + trait Event + + case class Open(partition: Long, version: Long) extends Event + + case class Process[T](value: T) extends Event + + case class Close(error: Option[Throwable]) extends Event + + private val _allEvents = new ConcurrentLinkedQueue[Seq[Event]]() + + def addEvents(events: Seq[Event]): Unit = { + _allEvents.add(events) + } + + def allEvents(): Seq[Seq[Event]] = { + _allEvents.toArray(new Array[Seq[Event]](_allEvents.size())) + } + + def clear(): Unit = { + _allEvents.clear() + } +} + +/** A [[ForeachWriter]] that writes collected events to ForeachSinkSuite */ +class TestForeachWriter extends ForeachWriter[Int] { + ForeachSinkSuite.clear() + + private val events = mutable.ArrayBuffer[ForeachSinkSuite.Event]() + + override def open(partitionId: Long, version: Long): Boolean = { + events += ForeachSinkSuite.Open(partition = partitionId, 
version = version) + true + } + + override def process(value: Int): Unit = { + events += ForeachSinkSuite.Process(value) + } + + override def close(errorOrNull: Throwable): Unit = { + events += ForeachSinkSuite.Close(error = Option(errorOrNull)) + ForeachSinkSuite.addEvents(events) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala index 5f92c5bb9b246..4259384f0bc61 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala @@ -33,6 +33,7 @@ import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.sql.execution.streaming.FakeFileSystem._ import org.apache.spark.sql.execution.streaming.HDFSMetadataLog.{FileContextManager, FileManager, FileSystemManager} import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.util.UninterruptibleThread class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { @@ -45,77 +46,97 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { test("FileManager: FileContextManager") { withTempDir { temp => val path = new Path(temp.getAbsolutePath) - testManager(path, new FileContextManager(path, new Configuration)) + testFileManager(path, new FileContextManager(path, new Configuration)) } } test("FileManager: FileSystemManager") { withTempDir { temp => val path = new Path(temp.getAbsolutePath) - testManager(path, new FileSystemManager(path, new Configuration)) + testFileManager(path, new FileSystemManager(path, new Configuration)) } } - test("HDFSMetadataLog: basic") { + testWithUninterruptibleThread("HDFSMetadataLog: basic") { withTempDir { temp => val dir = new File(temp, "dir") // use non-existent directory to test whether log make the dir - val metadataLog = new HDFSMetadataLog[String](sqlContext.sparkSession, dir.getAbsolutePath) + val metadataLog = new HDFSMetadataLog[String](spark, dir.getAbsolutePath) assert(metadataLog.add(0, "batch0")) assert(metadataLog.getLatest() === Some(0 -> "batch0")) assert(metadataLog.get(0) === Some("batch0")) assert(metadataLog.getLatest() === Some(0 -> "batch0")) - assert(metadataLog.get(None, 0) === Array(0 -> "batch0")) + assert(metadataLog.get(None, Some(0)) === Array(0 -> "batch0")) assert(metadataLog.add(1, "batch1")) assert(metadataLog.get(0) === Some("batch0")) assert(metadataLog.get(1) === Some("batch1")) assert(metadataLog.getLatest() === Some(1 -> "batch1")) - assert(metadataLog.get(None, 1) === Array(0 -> "batch0", 1 -> "batch1")) + assert(metadataLog.get(None, Some(1)) === Array(0 -> "batch0", 1 -> "batch1")) // Adding the same batch does nothing metadataLog.add(1, "batch1-duplicated") assert(metadataLog.get(0) === Some("batch0")) assert(metadataLog.get(1) === Some("batch1")) assert(metadataLog.getLatest() === Some(1 -> "batch1")) - assert(metadataLog.get(None, 1) === Array(0 -> "batch0", 1 -> "batch1")) + assert(metadataLog.get(None, Some(1)) === Array(0 -> "batch0", 1 -> "batch1")) } } - testQuietly("HDFSMetadataLog: fallback from FileContext to FileSystem") { - sqlContext.conf.setConfString( + testWithUninterruptibleThread( + "HDFSMetadataLog: fallback from FileContext to FileSystem", quietly = true) { + spark.conf.set( s"fs.$scheme.impl", classOf[FakeFileSystem].getName) withTempDir { temp => - val metadataLog = new 
HDFSMetadataLog[String](sqlContext.sparkSession, s"$scheme://$temp") + val metadataLog = new HDFSMetadataLog[String](spark, s"$scheme://$temp") assert(metadataLog.add(0, "batch0")) assert(metadataLog.getLatest() === Some(0 -> "batch0")) assert(metadataLog.get(0) === Some("batch0")) - assert(metadataLog.get(None, 0) === Array(0 -> "batch0")) + assert(metadataLog.get(None, Some(0)) === Array(0 -> "batch0")) - val metadataLog2 = new HDFSMetadataLog[String](sqlContext.sparkSession, s"$scheme://$temp") + val metadataLog2 = new HDFSMetadataLog[String](spark, s"$scheme://$temp") assert(metadataLog2.get(0) === Some("batch0")) assert(metadataLog2.getLatest() === Some(0 -> "batch0")) - assert(metadataLog2.get(None, 0) === Array(0 -> "batch0")) + assert(metadataLog2.get(None, Some(0)) === Array(0 -> "batch0")) } } - test("HDFSMetadataLog: restart") { + testWithUninterruptibleThread("HDFSMetadataLog: purge") { withTempDir { temp => - val metadataLog = new HDFSMetadataLog[String](sqlContext.sparkSession, temp.getAbsolutePath) + val metadataLog = new HDFSMetadataLog[String](spark, temp.getAbsolutePath) + assert(metadataLog.add(0, "batch0")) + assert(metadataLog.add(1, "batch1")) + assert(metadataLog.add(2, "batch2")) + assert(metadataLog.get(0).isDefined) + assert(metadataLog.get(1).isDefined) + assert(metadataLog.get(2).isDefined) + assert(metadataLog.getLatest().get._1 == 2) + + metadataLog.purge(2) + assert(metadataLog.get(0).isEmpty) + assert(metadataLog.get(1).isEmpty) + assert(metadataLog.get(2).isDefined) + assert(metadataLog.getLatest().get._1 == 2) + } + } + + testWithUninterruptibleThread("HDFSMetadataLog: restart") { + withTempDir { temp => + val metadataLog = new HDFSMetadataLog[String](spark, temp.getAbsolutePath) assert(metadataLog.add(0, "batch0")) assert(metadataLog.add(1, "batch1")) assert(metadataLog.get(0) === Some("batch0")) assert(metadataLog.get(1) === Some("batch1")) assert(metadataLog.getLatest() === Some(1 -> "batch1")) - assert(metadataLog.get(None, 1) === Array(0 -> "batch0", 1 -> "batch1")) + assert(metadataLog.get(None, Some(1)) === Array(0 -> "batch0", 1 -> "batch1")) - val metadataLog2 = new HDFSMetadataLog[String](sqlContext.sparkSession, temp.getAbsolutePath) + val metadataLog2 = new HDFSMetadataLog[String](spark, temp.getAbsolutePath) assert(metadataLog2.get(0) === Some("batch0")) assert(metadataLog2.get(1) === Some("batch1")) assert(metadataLog2.getLatest() === Some(1 -> "batch1")) - assert(metadataLog2.get(None, 1) === Array(0 -> "batch0", 1 -> "batch1")) + assert(metadataLog2.get(None, Some(1)) === Array(0 -> "batch0", 1 -> "batch1")) } } @@ -124,10 +145,10 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { val waiter = new Waiter val maxBatchId = 100 for (id <- 0 until 10) { - new Thread() { + new UninterruptibleThread(s"HDFSMetadataLog: metadata directory collision - thread $id") { override def run(): Unit = waiter { val metadataLog = - new HDFSMetadataLog[String](sqlContext.sparkSession, temp.getAbsolutePath) + new HDFSMetadataLog[String](spark, temp.getAbsolutePath) try { var nextBatchId = metadataLog.getLatest().map(_._1).getOrElse(-1L) nextBatchId += 1 @@ -146,14 +167,15 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { } waiter.await(timeout(10.seconds), dismissals(10)) - val metadataLog = new HDFSMetadataLog[String](sqlContext.sparkSession, temp.getAbsolutePath) + val metadataLog = new HDFSMetadataLog[String](spark, temp.getAbsolutePath) assert(metadataLog.getLatest() === Some(maxBatchId -> 
maxBatchId.toString)) - assert(metadataLog.get(None, maxBatchId) === (0 to maxBatchId).map(i => (i, i.toString))) + assert( + metadataLog.get(None, Some(maxBatchId)) === (0 to maxBatchId).map(i => (i, i.toString))) } } - - def testManager(basePath: Path, fm: FileManager): Unit = { + /** Basic test case for [[FileManager]] implementation. */ + private def testFileManager(basePath: Path, fm: FileManager): Unit = { // Mkdirs val dir = new Path(s"$basePath/dir/subdir/subsubdir") assert(!fm.exists(dir)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala index dd5f92248bf5c..00d5e051de357 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala @@ -20,20 +20,35 @@ package org.apache.spark.sql.execution.streaming import java.util.concurrent.{CountDownLatch, TimeUnit} import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.ProcessingTime -import org.apache.spark.util.ManualClock +import org.apache.spark.sql.streaming.ProcessingTime +import org.apache.spark.util.{Clock, ManualClock, SystemClock} class ProcessingTimeExecutorSuite extends SparkFunSuite { test("nextBatchTime") { val processingTimeExecutor = ProcessingTimeExecutor(ProcessingTime(100)) + assert(processingTimeExecutor.nextBatchTime(0) === 100) assert(processingTimeExecutor.nextBatchTime(1) === 100) assert(processingTimeExecutor.nextBatchTime(99) === 100) - assert(processingTimeExecutor.nextBatchTime(100) === 100) + assert(processingTimeExecutor.nextBatchTime(100) === 200) assert(processingTimeExecutor.nextBatchTime(101) === 200) assert(processingTimeExecutor.nextBatchTime(150) === 200) } + test("calling nextBatchTime with the result of a previous call should return the next interval") { + val intervalMS = 100 + val processingTimeExecutor = ProcessingTimeExecutor(ProcessingTime(intervalMS)) + + val ITERATION = 10 + var nextBatchTime: Long = 0 + for (it <- 1 to ITERATION) { + nextBatchTime = processingTimeExecutor.nextBatchTime(nextBatchTime) + } + + // nextBatchTime should be 1000 + assert(nextBatchTime === intervalMS * ITERATION) + } + private def testBatchTermination(intervalMs: Long): Unit = { var batchCounts = 0 val processingTimeExecutor = ProcessingTimeExecutor(ProcessingTime(intervalMs)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/StreamMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/StreamMetricsSuite.scala new file mode 100644 index 0000000000000..938423db64745 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/StreamMetricsSuite.scala @@ -0,0 +1,213 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import org.scalactic.TolerantNumerics + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.types.{StructField, StructType} +import org.apache.spark.util.ManualClock + +class StreamMetricsSuite extends SparkFunSuite { + import StreamMetrics._ + + // To make === between double tolerate inexact values + implicit val doubleEquality = TolerantNumerics.tolerantDoubleEquality(0.01) + + test("rates, latencies, trigger details - basic life cycle") { + val sm = newStreamMetrics(source) + assert(sm.currentInputRate() === 0.0) + assert(sm.currentProcessingRate() === 0.0) + assert(sm.currentSourceInputRate(source) === 0.0) + assert(sm.currentSourceProcessingRate(source) === 0.0) + assert(sm.currentLatency() === None) + assert(sm.currentTriggerDetails().isEmpty) + + // When trigger started, the rates should not change, but should return + // reported trigger details + sm.reportTriggerStarted(1) + sm.reportTriggerDetail("key", "value") + sm.reportSourceTriggerDetail(source, "key2", "value2") + assert(sm.currentInputRate() === 0.0) + assert(sm.currentProcessingRate() === 0.0) + assert(sm.currentSourceInputRate(source) === 0.0) + assert(sm.currentSourceProcessingRate(source) === 0.0) + assert(sm.currentLatency() === None) + assert(sm.currentTriggerDetails() === + Map(TRIGGER_ID -> "1", IS_TRIGGER_ACTIVE -> "true", + START_TIMESTAMP -> "0", "key" -> "value")) + assert(sm.currentSourceTriggerDetails(source) === + Map(TRIGGER_ID -> "1", "key2" -> "value2")) + + // Finishing the trigger should calculate the rates, except input rate which needs + // to have another trigger interval + sm.reportNumInputRows(Map(source -> 100L)) // 100 input rows, 10 output rows + clock.advance(1000) + sm.reportTriggerFinished() + assert(sm.currentInputRate() === 0.0) + assert(sm.currentProcessingRate() === 100.0) // 100 input rows processed in 1 sec + assert(sm.currentSourceInputRate(source) === 0.0) + assert(sm.currentSourceProcessingRate(source) === 100.0) + assert(sm.currentLatency() === None) + assert(sm.currentTriggerDetails() === + Map(TRIGGER_ID -> "1", IS_TRIGGER_ACTIVE -> "false", + START_TIMESTAMP -> "0", FINISH_TIMESTAMP -> "1000", + NUM_INPUT_ROWS -> "100", "key" -> "value")) + assert(sm.currentSourceTriggerDetails(source) === + Map(TRIGGER_ID -> "1", NUM_SOURCE_INPUT_ROWS -> "100", "key2" -> "value2")) + + // After another trigger starts, the rates and latencies should not change until + // new rows are reported + clock.advance(1000) + sm.reportTriggerStarted(2) + assert(sm.currentInputRate() === 0.0) + assert(sm.currentProcessingRate() === 100.0) + assert(sm.currentSourceInputRate(source) === 0.0) + assert(sm.currentSourceProcessingRate(source) === 100.0) + assert(sm.currentLatency() === None) + + // Reporting new rows should update the rates and latencies + sm.reportNumInputRows(Map(source -> 200L)) // 200 input rows + clock.advance(500) + sm.reportTriggerFinished() + assert(sm.currentInputRate() === 100.0) // 200 input rows generated in 2 seconds b/w starts + 
assert(sm.currentProcessingRate() === 400.0) // 200 output rows processed in 0.5 sec + assert(sm.currentSourceInputRate(source) === 100.0) + assert(sm.currentSourceProcessingRate(source) === 400.0) + assert(sm.currentLatency().get === 1500.0) // 2000 ms / 2 + 500 ms + + // Rates should be set to 0 after stop + sm.stop() + assert(sm.currentInputRate() === 0.0) + assert(sm.currentProcessingRate() === 0.0) + assert(sm.currentSourceInputRate(source) === 0.0) + assert(sm.currentSourceProcessingRate(source) === 0.0) + assert(sm.currentLatency() === None) + assert(sm.currentTriggerDetails().isEmpty) + } + + test("rates and latencies - after trigger with no data") { + val sm = newStreamMetrics(source) + // Trigger 1 with data + sm.reportTriggerStarted(1) + sm.reportNumInputRows(Map(source -> 100L)) // 100 input rows + clock.advance(1000) + sm.reportTriggerFinished() + + // Trigger 2 with data + clock.advance(1000) + sm.reportTriggerStarted(2) + sm.reportNumInputRows(Map(source -> 200L)) // 200 input rows + clock.advance(500) + sm.reportTriggerFinished() + + // Make sure that all rates are set + require(sm.currentInputRate() === 100.0) // 200 input rows generated in 2 seconds b/w starts + require(sm.currentProcessingRate() === 400.0) // 200 output rows processed in 0.5 sec + require(sm.currentSourceInputRate(source) === 100.0) + require(sm.currentSourceProcessingRate(source) === 400.0) + require(sm.currentLatency().get === 1500.0) // 2000 ms / 2 + 500 ms + + // Trigger 3 with no data + clock.advance(500) + sm.reportTriggerStarted(3) + clock.advance(500) + sm.reportTriggerFinished() + + // Rates are set to zero and latency is set to None + assert(sm.currentInputRate() === 0.0) + assert(sm.currentProcessingRate() === 0.0) + assert(sm.currentSourceInputRate(source) === 0.0) + assert(sm.currentSourceProcessingRate(source) === 0.0) + assert(sm.currentLatency() === None) + sm.stop() + } + + test("rates - after trigger with multiple sources, and one source having no info") { + val source1 = TestSource(1) + val source2 = TestSource(2) + val sm = newStreamMetrics(source1, source2) + // Trigger 1 with data + sm.reportTriggerStarted(1) + sm.reportNumInputRows(Map(source1 -> 100L, source2 -> 100L)) + clock.advance(1000) + sm.reportTriggerFinished() + + // Trigger 2 with data + clock.advance(1000) + sm.reportTriggerStarted(2) + sm.reportNumInputRows(Map(source1 -> 200L, source2 -> 200L)) + clock.advance(500) + sm.reportTriggerFinished() + + // Make sure that all rates are set + assert(sm.currentInputRate() === 200.0) // 200*2 input rows generated in 2 seconds b/w starts + assert(sm.currentProcessingRate() === 800.0) // 200*2 output rows processed in 0.5 sec + assert(sm.currentSourceInputRate(source1) === 100.0) + assert(sm.currentSourceInputRate(source2) === 100.0) + assert(sm.currentSourceProcessingRate(source1) === 400.0) + assert(sm.currentSourceProcessingRate(source2) === 400.0) + + // Trigger 3 with data from source1 only + clock.advance(500) + sm.reportTriggerStarted(3) + clock.advance(500) + sm.reportNumInputRows(Map(source1 -> 200L)) + sm.reportTriggerFinished() + + // Rates for source2, which reported no data, are zero; the other rates reflect source1 only + assert(sm.currentInputRate() === 200.0) + assert(sm.currentProcessingRate() === 400.0) + assert(sm.currentSourceInputRate(source1) === 200.0) + assert(sm.currentSourceInputRate(source2) === 0.0) + assert(sm.currentSourceProcessingRate(source1) === 400.0) + assert(sm.currentSourceProcessingRate(source2) === 0.0) + sm.stop() + } + + test("registered Codahale metrics") { + import
scala.collection.JavaConverters._ + val sm = newStreamMetrics(source) + val gaugeNames = sm.metricRegistry.getGauges().keySet().asScala + + // so that all metrics are considered as a single metric group in Ganglia + assert(!gaugeNames.exists(_.contains("."))) + assert(gaugeNames === Set( + "inputRate-total", + "inputRate-source0", + "processingRate-total", + "processingRate-source0", + "latency")) + } + + private def newStreamMetrics(sources: Source*): StreamMetrics = { + new StreamMetrics(sources.toSet, clock, "test") + } + + private val clock = new ManualClock() + private val source = TestSource(0) + + case class TestSource(id: Int) extends Source { + override def schema: StructType = StructType(Array.empty[StructField]) + override def getOffset: Option[Offset] = Some(new LongOffset(0)) + override def getBatch(start: Option[Offset], end: Offset): DataFrame = { null } + override def stop() {} + override def toString(): String = s"source$id" + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala new file mode 100644 index 0000000000000..5174a0415304c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming + +import java.io.{IOException, OutputStreamWriter} +import java.net.ServerSocket +import java.sql.Timestamp +import java.util.concurrent.LinkedBlockingQueue + +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} + +class TextSocketStreamSuite extends StreamTest with SharedSQLContext with BeforeAndAfterEach { + import testImplicits._ + + override def afterEach() { + sqlContext.streams.active.foreach(_.stop()) + if (serverThread != null) { + serverThread.interrupt() + serverThread.join() + serverThread = null + } + if (source != null) { + source.stop() + source = null + } + } + + private var serverThread: ServerThread = null + private var source: Source = null + + test("basic usage") { + serverThread = new ServerThread() + serverThread.start() + + val provider = new TextSocketSourceProvider + val parameters = Map("host" -> "localhost", "port" -> serverThread.port.toString) + val schema = provider.sourceSchema(sqlContext, None, "", parameters)._2 + assert(schema === StructType(StructField("value", StringType) :: Nil)) + + source = provider.createSource(sqlContext, "", None, "", parameters) + + failAfter(streamingTimeout) { + serverThread.enqueue("hello") + while (source.getOffset.isEmpty) { + Thread.sleep(10) + } + val offset1 = source.getOffset.get + val batch1 = source.getBatch(None, offset1) + assert(batch1.as[String].collect().toSeq === Seq("hello")) + + serverThread.enqueue("world") + while (source.getOffset.get === offset1) { + Thread.sleep(10) + } + val offset2 = source.getOffset.get + val batch2 = source.getBatch(Some(offset1), offset2) + assert(batch2.as[String].collect().toSeq === Seq("world")) + + val both = source.getBatch(None, offset2) + assert(both.as[String].collect().sorted.toSeq === Seq("hello", "world")) + + // Try stopping the source to make sure this does not block forever. + source.stop() + source = null + } + } + + test("timestamped usage") { + serverThread = new ServerThread() + serverThread.start() + + val provider = new TextSocketSourceProvider + val parameters = Map("host" -> "localhost", "port" -> serverThread.port.toString, + "includeTimestamp" -> "true") + val schema = provider.sourceSchema(sqlContext, None, "", parameters)._2 + assert(schema === StructType(StructField("value", StringType) :: + StructField("timestamp", TimestampType) :: Nil)) + + source = provider.createSource(sqlContext, "", None, "", parameters) + + failAfter(streamingTimeout) { + serverThread.enqueue("hello") + while (source.getOffset.isEmpty) { + Thread.sleep(10) + } + val offset1 = source.getOffset.get + val batch1 = source.getBatch(None, offset1) + val batch1Seq = batch1.as[(String, Timestamp)].collect().toSeq + assert(batch1Seq.map(_._1) === Seq("hello")) + val batch1Stamp = batch1Seq(0)._2 + + serverThread.enqueue("world") + while (source.getOffset.get === offset1) { + Thread.sleep(10) + } + val offset2 = source.getOffset.get + val batch2 = source.getBatch(Some(offset1), offset2) + val batch2Seq = batch2.as[(String, Timestamp)].collect().toSeq + assert(batch2Seq.map(_._1) === Seq("world")) + val batch2Stamp = batch2Seq(0)._2 + assert(!batch2Stamp.before(batch1Stamp)) + + // Try stopping the source to make sure this does not block forever. 
+ source.stop() + source = null + } + } + + test("params not given") { + val provider = new TextSocketSourceProvider + intercept[AnalysisException] { + provider.sourceSchema(sqlContext, None, "", Map()) + } + intercept[AnalysisException] { + provider.sourceSchema(sqlContext, None, "", Map("host" -> "localhost")) + } + intercept[AnalysisException] { + provider.sourceSchema(sqlContext, None, "", Map("port" -> "1234")) + } + } + + test("non-boolean includeTimestamp") { + val provider = new TextSocketSourceProvider + intercept[AnalysisException] { + provider.sourceSchema(sqlContext, None, "", Map("host" -> "localhost", + "port" -> "1234", "includeTimestamp" -> "fasle")) + } + } + + test("no server up") { + val provider = new TextSocketSourceProvider + val parameters = Map("host" -> "localhost", "port" -> "0") + intercept[IOException] { + source = provider.createSource(sqlContext, "", None, "", parameters) + } + } + + test("input row metrics") { + serverThread = new ServerThread() + serverThread.start() + + val provider = new TextSocketSourceProvider + val parameters = Map("host" -> "localhost", "port" -> serverThread.port.toString) + source = provider.createSource(sqlContext, "", None, "", parameters) + + failAfter(streamingTimeout) { + serverThread.enqueue("hello") + while (source.getOffset.isEmpty) { + Thread.sleep(10) + } + val batch = source.getBatch(None, source.getOffset.get).as[String] + batch.collect() + val numRowsMetric = + batch.queryExecution.executedPlan.collectLeaves().head.metrics.get("numOutputRows") + assert(numRowsMetric.nonEmpty) + assert(numRowsMetric.get.value === 1) + source.stop() + source = null + } + } + + private class ServerThread extends Thread with Logging { + private val serverSocket = new ServerSocket(0) + private val messageQueue = new LinkedBlockingQueue[String]() + + val port = serverSocket.getLocalPort + + override def run(): Unit = { + try { + val clientSocket = serverSocket.accept() + clientSocket.setTcpNoDelay(true) + val out = new OutputStreamWriter(clientSocket.getOutputStream) + while (true) { + val line = messageQueue.take() + out.write(line + "\n") + out.flush() + } + } catch { + case e: InterruptedException => + } finally { + serverSocket.close() + } + } + + def enqueue(line: String): Unit = { + messageQueue.put(line) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala index 6be94eb24fcf1..bd197be655d58 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala @@ -27,10 +27,11 @@ import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.sql.LocalSparkSession._ import org.apache.spark.LocalSparkContext._ import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.ExecutorCacheTaskLocation -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.util.quietly import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} import org.apache.spark.util.{CompletionIterator, Utils} @@ -54,19 +55,18 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn } test("versioning and immutability") { - withSpark(new 
SparkContext(sparkConf)) { sc => - val sqlContext = new SQLContext(sc) + withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark => val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString val opId = 0 val rdd1 = - makeRDD(sc, Seq("a", "b", "a")).mapPartitionsWithStateStore( - sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)( + makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( + spark.sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)( increment) assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) // Generate next version of stores - val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionsWithStateStore( - sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(increment) + val rdd2 = makeRDD(spark.sparkContext, Seq("a", "c")).mapPartitionsWithStateStore( + spark.sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(increment) assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1)) // Make sure the previous RDD still has the same data. @@ -79,30 +79,30 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString def makeStoreRDD( - sc: SparkContext, + spark: SparkSession, seq: Seq[String], storeVersion: Int): RDD[(String, Int)] = { - implicit val sqlContext = new SQLContext(sc) - makeRDD(sc, Seq("a")).mapPartitionsWithStateStore( + implicit val sqlContext = spark.sqlContext + makeRDD(spark.sparkContext, Seq("a")).mapPartitionsWithStateStore( sqlContext, path, opId, storeVersion, keySchema, valueSchema)(increment) } // Generate RDDs and state store data - withSpark(new SparkContext(sparkConf)) { sc => + withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark => for (i <- 1 to 20) { - require(makeStoreRDD(sc, Seq("a"), i - 1).collect().toSet === Set("a" -> i)) + require(makeStoreRDD(spark, Seq("a"), i - 1).collect().toSet === Set("a" -> i)) } } // With a new context, try using the earlier state store data - withSpark(new SparkContext(sparkConf)) { sc => - assert(makeStoreRDD(sc, Seq("a"), 20).collect().toSet === Set("a" -> 21)) + withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark => + assert(makeStoreRDD(spark, Seq("a"), 20).collect().toSet === Set("a" -> 21)) } } test("usage with iterators - only gets and only puts") { - withSpark(new SparkContext(sparkConf)) { sc => - implicit val sqlContext = new SQLContext(sc) + withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark => + implicit val sqlContext = spark.sqlContext val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString val opId = 0 @@ -130,15 +130,15 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn } } - val rddOfGets1 = makeRDD(sc, Seq("a", "b", "c")).mapPartitionsWithStateStore( - sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(iteratorOfGets) + val rddOfGets1 = makeRDD(spark.sparkContext, Seq("a", "b", "c")).mapPartitionsWithStateStore( + spark.sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(iteratorOfGets) assert(rddOfGets1.collect().toSet === Set("a" -> None, "b" -> None, "c" -> None)) - val rddOfPuts = makeRDD(sc, Seq("a", "b", "a")).mapPartitionsWithStateStore( + val rddOfPuts = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(iteratorOfPuts) 
assert(rddOfPuts.collect().toSet === Set("a" -> 1, "a" -> 2, "b" -> 1)) - val rddOfGets2 = makeRDD(sc, Seq("a", "b", "c")).mapPartitionsWithStateStore( + val rddOfGets2 = makeRDD(spark.sparkContext, Seq("a", "b", "c")).mapPartitionsWithStateStore( sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(iteratorOfGets) assert(rddOfGets2.collect().toSet === Set("a" -> Some(2), "b" -> Some(1), "c" -> None)) } @@ -149,8 +149,8 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn val opId = 0 val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString - withSpark(new SparkContext(sparkConf)) { sc => - implicit val sqlContext = new SQLContext(sc) + withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark => + implicit val sqlContext = spark.sqlContext val coordinatorRef = sqlContext.streams.stateStoreCoordinator coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 0), "host1", "exec1") coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 1), "host2", "exec2") @@ -159,7 +159,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn coordinatorRef.getLocation(StateStoreId(path, opId, 0)) === Some(ExecutorCacheTaskLocation("host1", "exec1").toString)) - val rdd = makeRDD(sc, Seq("a", "b", "a")).mapPartitionsWithStateStore( + val rdd = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(increment) require(rdd.partitions.length === 2) @@ -178,16 +178,20 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn test("distributed test") { quietly { - withSpark(new SparkContext(sparkConf.setMaster("local-cluster[2, 1, 1024]"))) { sc => - implicit val sqlContext = new SQLContext(sc) + + withSparkSession( + SparkSession.builder + .config(sparkConf.setMaster("local-cluster[2, 1, 1024]")) + .getOrCreate()) { spark => + implicit val sqlContext = spark.sqlContext val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString val opId = 0 - val rdd1 = makeRDD(sc, Seq("a", "b", "a")).mapPartitionsWithStateStore( + val rdd1 = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(increment) assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) // Generate next version of stores - val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionsWithStateStore( + val rdd2 = makeRDD(spark.sparkContext, Seq("a", "c")).mapPartitionsWithStateStore( sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(increment) assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index dd23925716b06..fcf300b3c81bb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -47,8 +47,14 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth private val keySchema = StructType(Seq(StructField("key", StringType, true))) private val valueSchema = StructType(Seq(StructField("value", IntegerType, true))) + before { + StateStore.stop() + require(!StateStore.isMaintenanceRunning) + } + after { StateStore.stop() + 
require(!StateStore.isMaintenanceRunning) } test("get, put, remove, commit, and all data iterator") { @@ -68,6 +74,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth // Verify state after updating put(store, "a", 1) + assert(store.numKeys() === 1) intercept[IllegalStateException] { store.iterator() } @@ -79,7 +86,9 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth // Make updates, commit and then verify state put(store, "b", 2) put(store, "aa", 3) + assert(store.numKeys() === 3) remove(store, _.startsWith("a")) + assert(store.numKeys() === 1) assert(store.commit() === 1) assert(store.hasCommitted) @@ -101,7 +110,9 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth val reloadedProvider = new HDFSBackedStateStoreProvider( store.id, keySchema, valueSchema, StateStoreConf.empty, new Configuration) val reloadedStore = reloadedProvider.getStore(1) + assert(reloadedStore.numKeys() === 1) put(reloadedStore, "c", 4) + assert(reloadedStore.numKeys() === 2) assert(reloadedStore.commit() === 2) assert(rowsToSet(reloadedStore.iterator()) === Set("b" -> 2, "c" -> 4)) assert(getDataFromFiles(provider) === Set("b" -> 2, "c" -> 4)) @@ -182,7 +193,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth provider.getStore(-1) } - // Prepare some data in the stoer + // Prepare some data in the store val store = provider.getStore(0) put(store, "a", 1) assert(store.commit() === 1) @@ -352,11 +363,14 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth } } - ignore("maintenance") { + test("maintenance") { val conf = new SparkConf() .setMaster("local") .setAppName("test") + // Make maintenance thread do snapshots and cleanups very fast .set(StateStore.MAINTENANCE_INTERVAL_CONFIG, "10ms") + // Make sure that when SparkContext stops, the StateStore maintenance thread 'quickly' + // fails to talk to the StateStoreCoordinator and unloads all the StateStores .set("spark.rpc.numRetries", "1") val opId = 0 val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString @@ -366,31 +380,49 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth val provider = new HDFSBackedStateStoreProvider( storeId, keySchema, valueSchema, storeConf, hadoopConf) + var latestStoreVersion = 0 + + def generateStoreVersions() { + for (i <- 1 to 20) { + val store = StateStore.get( + storeId, keySchema, valueSchema, latestStoreVersion, storeConf, hadoopConf) + put(store, "a", i) + store.commit() + latestStoreVersion += 1 + } + } + quietly { withSpark(new SparkContext(conf)) { sc => withCoordinatorRef(sc) { coordinatorRef => - for (i <- 1 to 20) { - val store = StateStore.get( - storeId, keySchema, valueSchema, i - 1, storeConf, hadoopConf) - put(store, "a", i) - store.commit() - } + require(!StateStore.isMaintenanceRunning, "StateStore is unexpectedly running") + + // Generate sufficient versions of store for snapshots + generateStoreVersions() + eventually(timeout(10 seconds)) { + // Store should have been reported to the coordinator assert(coordinatorRef.getLocation(storeId).nonEmpty, "active instance was not reported") - } - // Background maintenance should clean up and generate snapshots - eventually(timeout(10 seconds)) { - // Earliest delta file should get cleaned up - assert(!fileExists(provider, 1, isSnapshot = false), "earliest file not deleted") + // Background maintenance should clean up and generate snapshots + 
assert(StateStore.isMaintenanceRunning, "Maintenance task is not running") // Some snapshots should have been generated - val snapshotVersions = (0 to 20).filter { version => + val snapshotVersions = (1 to latestStoreVersion).filter { version => fileExists(provider, version, isSnapshot = true) } assert(snapshotVersions.nonEmpty, "no snapshot file found") } + // Generate more versions such that there is another snapshot and + // the earliest delta file will be cleaned up + generateStoreVersions() + + // Earliest delta file should get cleaned up + eventually(timeout(10 seconds)) { + assert(!fileExists(provider, 1, isSnapshot = false), "earliest file not deleted") + } + // If driver decides to deactivate all instances of the store, then this instance // should be unloaded coordinatorRef.deactivateInstances(dir) @@ -399,7 +431,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth } // Reload the store and verify - StateStore.get(storeId, keySchema, valueSchema, 20, storeConf, hadoopConf) + StateStore.get(storeId, keySchema, valueSchema, latestStoreVersion, storeConf, hadoopConf) assert(StateStore.isLoaded(storeId)) // If some other executor loads the store, then this instance should be unloaded @@ -409,15 +441,16 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth } // Reload the store and verify - StateStore.get(storeId, keySchema, valueSchema, 20, storeConf, hadoopConf) + StateStore.get(storeId, keySchema, valueSchema, latestStoreVersion, storeConf, hadoopConf) assert(StateStore.isLoaded(storeId)) } } // Verify if instance is unloaded if SparkContext is stopped - require(SparkEnv.get === null) eventually(timeout(10 seconds)) { + require(SparkEnv.get === null) assert(!StateStore.isLoaded(storeId)) + assert(!StateStore.isMaintenanceRunning) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index 8572ed16aa261..6e60b0e4fad15 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -19,17 +19,23 @@ package org.apache.spark.sql.execution.ui import java.util.Properties -import org.mockito.Mockito.{mock, when} +import org.mockito.Mockito.mock import org.apache.spark._ import org.apache.spark.executor.TaskMetrics +import org.apache.spark.rdd.RDD import org.apache.spark.scheduler._ -import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.catalyst.util.quietly -import org.apache.spark.sql.execution.{SparkPlanInfo, SQLExecution} -import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.execution.{LeafExecNode, QueryExecution, SparkPlanInfo, SQLExecution} +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.ui.SparkUI +import org.apache.spark.util.{AccumulatorMetadata, LongAccumulator} + class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { import testImplicits._ @@ -72,13 +78,13 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { ) private def createTaskMetrics(accumulatorUpdates: 
Map[Long, Long]): TaskMetrics = { - val metrics = mock(classOf[TaskMetrics]) - when(metrics.accumulators()).thenReturn(accumulatorUpdates.map { case (id, update) => + val metrics = TaskMetrics.empty + accumulatorUpdates.foreach { case (id, update) => val acc = new LongAccumulator acc.metadata = AccumulatorMetadata(id, Some(""), true) - acc.setValue(update) - acc - }.toSeq) + acc.add(update) + metrics.registerAccumulator(acc) + } metrics } @@ -96,7 +102,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { } } - val listener = new SQLListener(sqlContext.sparkContext.conf) + val listener = new SQLListener(spark.sparkContext.conf) val executionId = 0 val df = createTestDataFrame val accumulatorIds = @@ -237,7 +243,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { } test("onExecutionEnd happens before onJobEnd(JobSucceeded)") { - val listener = new SQLListener(sqlContext.sparkContext.conf) + val listener = new SQLListener(spark.sparkContext.conf) val executionId = 0 val df = createTestDataFrame listener.onOtherEvent(SparkListenerSQLExecutionStart( @@ -267,7 +273,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { } test("onExecutionEnd happens before multiple onJobEnd(JobSucceeded)s") { - val listener = new SQLListener(sqlContext.sparkContext.conf) + val listener = new SQLListener(spark.sparkContext.conf) val executionId = 0 val df = createTestDataFrame listener.onOtherEvent(SparkListenerSQLExecutionStart( @@ -308,7 +314,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { } test("onExecutionEnd happens before onJobEnd(JobFailed)") { - val listener = new SQLListener(sqlContext.sparkContext.conf) + val listener = new SQLListener(spark.sparkContext.conf) val executionId = 0 val df = createTestDataFrame listener.onOtherEvent(SparkListenerSQLExecutionStart( @@ -338,16 +344,16 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { } test("SPARK-11126: no memory leak when running non SQL jobs") { - val previousStageNumber = sqlContext.listener.stageIdToStageMetrics.size - sqlContext.sparkContext.parallelize(1 to 10).foreach(i => ()) - sqlContext.sparkContext.listenerBus.waitUntilEmpty(10000) + val previousStageNumber = spark.sharedState.listener.stageIdToStageMetrics.size + spark.sparkContext.parallelize(1 to 10).foreach(i => ()) + spark.sparkContext.listenerBus.waitUntilEmpty(10000) // listener should ignore the non SQL stage - assert(sqlContext.listener.stageIdToStageMetrics.size == previousStageNumber) + assert(spark.sharedState.listener.stageIdToStageMetrics.size == previousStageNumber) - sqlContext.sparkContext.parallelize(1 to 10).toDF().foreach(i => ()) - sqlContext.sparkContext.listenerBus.waitUntilEmpty(10000) + spark.sparkContext.parallelize(1 to 10).toDF().foreach(i => ()) + spark.sparkContext.listenerBus.waitUntilEmpty(10000) // listener should save the SQL stage - assert(sqlContext.listener.stageIdToStageMetrics.size == previousStageNumber + 1) + assert(spark.sharedState.listener.stageIdToStageMetrics.size == previousStageNumber + 1) } test("SPARK-13055: history listener only tracks SQL metrics") { @@ -363,9 +369,9 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { // This task has both accumulators that are SQL metrics and accumulators that are not. // The listener should only track the ones that are actually SQL metrics. 
val sqlMetric = SQLMetrics.createMetric(sparkContext, "beach umbrella") - val nonSqlMetric = sparkContext.accumulator[Int](0, "baseball") - val sqlMetricInfo = sqlMetric.toInfo(Some(sqlMetric.localValue), None) - val nonSqlMetricInfo = nonSqlMetric.toInfo(Some(nonSqlMetric.localValue), None) + val nonSqlMetric = sparkContext.longAccumulator("baseball") + val sqlMetricInfo = sqlMetric.toInfo(Some(sqlMetric.value), None) + val nonSqlMetricInfo = nonSqlMetric.toInfo(Some(nonSqlMetric.value), None) val taskInfo = createTaskInfo(0, 0) taskInfo.accumulables ++= Seq(sqlMetricInfo, nonSqlMetricInfo) val taskEnd = SparkListenerTaskEnd(0, 0, "just-a-task", null, taskInfo, null) @@ -381,9 +387,55 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { } // Listener tracks only SQL metrics, not other accumulators assert(trackedAccums.size === 1) - assert(trackedAccums.head === sqlMetricInfo) + assert(trackedAccums.head === (sqlMetricInfo.id, sqlMetricInfo.update.get)) } + test("driver side SQL metrics") { + val listener = new SQLListener(spark.sparkContext.conf) + val expectedAccumValue = 12345 + val physicalPlan = MyPlan(sqlContext.sparkContext, expectedAccumValue) + sqlContext.sparkContext.addSparkListener(listener) + val dummyQueryExecution = new QueryExecution(spark, LocalRelation()) { + override lazy val sparkPlan = physicalPlan + override lazy val executedPlan = physicalPlan + } + SQLExecution.withNewExecutionId(spark, dummyQueryExecution) { + physicalPlan.execute().collect() + } + + def waitTillExecutionFinished(): Unit = { + while (listener.getCompletedExecutions.isEmpty) { + Thread.sleep(100) + } + } + waitTillExecutionFinished() + + val driverUpdates = listener.getCompletedExecutions.head.driverAccumUpdates + assert(driverUpdates.size == 1) + assert(driverUpdates(physicalPlan.longMetric("dummy").id) == expectedAccumValue) + } + +} + + +/** + * A dummy [[org.apache.spark.sql.execution.SparkPlan]] that updates a [[SQLMetrics]] + * on the driver. + */ +private case class MyPlan(sc: SparkContext, expectedValue: Long) extends LeafExecNode { + override def sparkContext: SparkContext = sc + override def output: Seq[Attribute] = Seq() + + override val metrics: Map[String, SQLMetric] = Map( + "dummy" -> SQLMetrics.createMetric(sc, "dummy")) + + override def doExecute(): RDD[InternalRow] = { + longMetric("dummy") += expectedValue + sc.listenerBus.post(SparkListenerDriverAccumUpdates( + sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY).toLong, + metrics.values.map(m => m.id -> m.value).toSeq)) + sc.emptyRDD + } } @@ -398,9 +450,9 @@ class SQLListenerMemoryLeakSuite extends SparkFunSuite { .set("spark.sql.ui.retainedExecutions", "50") // Set it to 50 to run this test quickly val sc = new SparkContext(conf) try { - SQLContext.clearSqlListener() - val sqlContext = new SQLContext(sc) - import sqlContext.implicits._ + SparkSession.sqlListener.set(null) + val spark = new SparkSession(sc) + import spark.implicits._ // Run 100 successful executions and 100 failed executions. // Each execution only has one job and one stage. 
for (i <- 0 until 100) { @@ -416,12 +468,12 @@ class SQLListenerMemoryLeakSuite extends SparkFunSuite { } } sc.listenerBus.waitUntilEmpty(10000) - assert(sqlContext.listener.getCompletedExecutions.size <= 50) - assert(sqlContext.listener.getFailedExecutions.size <= 50) + assert(spark.sharedState.listener.getCompletedExecutions.size <= 50) + assert(spark.sharedState.listener.getFailedExecutions.size <= 50) // 50 for successful executions and 50 for failed executions - assert(sqlContext.listener.executionIdToData.size <= 100) - assert(sqlContext.listener.jobIdToExecutionId.size <= 100) - assert(sqlContext.listener.stageIdToStageMetrics.size <= 100) + assert(spark.sharedState.listener.executionIdToData.size <= 100) + assert(spark.sharedState.listener.jobIdToExecutionId.size <= 100) + assert(spark.sharedState.listener.stageIdToStageMetrics.size <= 100) } finally { sc.stop() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index a63007fc3bf25..e3943f31a48ba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -18,6 +18,8 @@ package org.apache.spark.sql.execution.vectorized import java.nio.charset.StandardCharsets +import java.nio.ByteBuffer +import java.nio.ByteOrder import scala.collection.JavaConverters._ import scala.collection.mutable @@ -280,6 +282,13 @@ class ColumnarBatchSuite extends SparkFunSuite { Platform.putDouble(buffer, Platform.BYTE_ARRAY_OFFSET, 2.234) Platform.putDouble(buffer, Platform.BYTE_ARRAY_OFFSET + 8, 1.123) + if (ByteOrder.nativeOrder().equals(ByteOrder.BIG_ENDIAN)) { + // Ensure array contains Little Endian doubles + var bb = ByteBuffer.wrap(buffer).order(ByteOrder.LITTLE_ENDIAN) + Platform.putDouble(buffer, Platform.BYTE_ARRAY_OFFSET, bb.getDouble(0)) + Platform.putDouble(buffer, Platform.BYTE_ARRAY_OFFSET + 8, bb.getDouble(8)) + } + column.putDoubles(idx, 1, buffer, 8) column.putDoubles(idx + 1, 1, buffer, 0) reference += 1.123 @@ -778,4 +787,23 @@ class ColumnarBatchSuite extends SparkFunSuite { } } } + + test("exceeding maximum capacity should throw an error") { + (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => + val column = ColumnVector.allocate(1, ByteType, memMode) + column.MAX_CAPACITY = 15 + column.appendBytes(5, 0.toByte) + // Successfully allocate twice the requested capacity + assert(column.capacity == 10) + column.appendBytes(10, 0.toByte) + // Allocated capacity doesn't exceed MAX_CAPACITY + assert(column.capacity == 15) + val ex = intercept[RuntimeException] { + // Over-allocating beyond MAX_CAPACITY throws an exception + column.appendBytes(10, 0.toByte) + } + assert(ex.getMessage.contains(s"Cannot reserve additional contiguous bytes in the " + + s"vectorized reader")) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ReduceAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ReduceAggregatorSuite.scala new file mode 100644 index 0000000000000..d826d3f54d922 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ReduceAggregatorSuite.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements.
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.Encoders +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder + +class ReduceAggregatorSuite extends SparkFunSuite { + + test("zero value") { + val encoder: ExpressionEncoder[Int] = ExpressionEncoder() + val func = (v1: Int, v2: Int) => v1 + v2 + val aggregator: ReduceAggregator[Int] = new ReduceAggregator(func)(Encoders.scalaInt) + assert(aggregator.zero == (false, null)) + } + + test("reduce, merge and finish") { + val encoder: ExpressionEncoder[Int] = ExpressionEncoder() + val func = (v1: Int, v2: Int) => v1 + v2 + val aggregator: ReduceAggregator[Int] = new ReduceAggregator(func)(Encoders.scalaInt) + + val firstReduce = aggregator.reduce(aggregator.zero, 1) + assert(firstReduce == (true, 1)) + + val secondReduce = aggregator.reduce(firstReduce, 2) + assert(secondReduce == (true, 3)) + + val thirdReduce = aggregator.reduce(secondReduce, 3) + assert(thirdReduce == (true, 6)) + + val mergeWithZero1 = aggregator.merge(aggregator.zero, firstReduce) + assert(mergeWithZero1 == (true, 1)) + + val mergeWithZero2 = aggregator.merge(secondReduce, aggregator.zero) + assert(mergeWithZero2 == (true, 3)) + + val mergeTwoReduced = aggregator.merge(firstReduce, secondReduce) + assert(mergeTwoReduced == (true, 4)) + + assert(aggregator.finish(firstReduce)== 1) + assert(aggregator.finish(secondReduce) == 3) + assert(aggregator.finish(thirdReduce) == 6) + assert(aggregator.finish(mergeWithZero1) == 1) + assert(aggregator.finish(mergeWithZero2) == 3) + assert(aggregator.finish(mergeTwoReduced) == 4) + } + + test("requires at least one input row") { + val encoder: ExpressionEncoder[Int] = ExpressionEncoder() + val func = (v1: Int, v2: Int) => v1 + v2 + val aggregator: ReduceAggregator[Int] = new ReduceAggregator(func)(Encoders.scalaInt) + + intercept[IllegalStateException] { + aggregator.finish(aggregator.zero) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala index 986d8f514f2fb..e62ae38cd35a7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala @@ -20,9 +20,9 @@ package org.apache.spark.sql.internal import org.scalatest.BeforeAndAfterEach import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.{AnalysisException, SparkSession} +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalog.{Column, Database, Function, Table} -import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.{FunctionIdentifier, ScalaReflection, TableIdentifier} import 
org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo} import org.apache.spark.sql.catalyst.plans.logical.Range @@ -37,13 +37,12 @@ class CatalogSuite with BeforeAndAfterEach with SharedSQLContext { - private def sparkSession: SparkSession = sqlContext.sparkSession - private def sessionCatalog: SessionCatalog = sparkSession.sessionState.catalog + private def sessionCatalog: SessionCatalog = spark.sessionState.catalog private val utils = new CatalogTestUtils { override val tableInputFormat: String = "com.fruit.eyephone.CameraInputFormat" override val tableOutputFormat: String = "com.fruit.eyephone.CameraOutputFormat" - override def newEmptyCatalog(): ExternalCatalog = sparkSession.sharedState.externalCatalog + override def newEmptyCatalog(): ExternalCatalog = spark.sharedState.externalCatalog } private def createDatabase(name: String): Unit = { @@ -59,7 +58,7 @@ class CatalogSuite } private def createTempTable(name: String): Unit = { - sessionCatalog.createTempTable(name, Range(1, 2, 3, 4, Seq()), overrideIfExists = true) + sessionCatalog.createTempView(name, Range(1, 2, 3, 4), overrideIfExists = true) } private def dropTable(name: String, db: Option[String] = None): Unit = { @@ -87,8 +86,8 @@ class CatalogSuite private def testListColumns(tableName: String, dbName: Option[String]): Unit = { val tableMetadata = sessionCatalog.getTableMetadata(TableIdentifier(tableName, dbName)) val columns = dbName - .map { db => sparkSession.catalog.listColumns(db, tableName) } - .getOrElse { sparkSession.catalog.listColumns(tableName) } + .map { db => spark.catalog.listColumns(db, tableName) } + .getOrElse { spark.catalog.listColumns(tableName) } assume(tableMetadata.schema.nonEmpty, "bad test") assume(tableMetadata.partitionColumnNames.nonEmpty, "bad test") assume(tableMetadata.bucketColumnNames.nonEmpty, "bad test") @@ -108,85 +107,85 @@ class CatalogSuite } test("current database") { - assert(sparkSession.catalog.currentDatabase == "default") + assert(spark.catalog.currentDatabase == "default") assert(sessionCatalog.getCurrentDatabase == "default") createDatabase("my_db") - sparkSession.catalog.setCurrentDatabase("my_db") - assert(sparkSession.catalog.currentDatabase == "my_db") + spark.catalog.setCurrentDatabase("my_db") + assert(spark.catalog.currentDatabase == "my_db") assert(sessionCatalog.getCurrentDatabase == "my_db") val e = intercept[AnalysisException] { - sparkSession.catalog.setCurrentDatabase("unknown_db") + spark.catalog.setCurrentDatabase("unknown_db") } assert(e.getMessage.contains("unknown_db")) } test("list databases") { - assert(sparkSession.catalog.listDatabases().collect().map(_.name).toSet == Set("default")) + assert(spark.catalog.listDatabases().collect().map(_.name).toSet == Set("default")) createDatabase("my_db1") createDatabase("my_db2") - assert(sparkSession.catalog.listDatabases().collect().map(_.name).toSet == + assert(spark.catalog.listDatabases().collect().map(_.name).toSet == Set("default", "my_db1", "my_db2")) dropDatabase("my_db1") - assert(sparkSession.catalog.listDatabases().collect().map(_.name).toSet == + assert(spark.catalog.listDatabases().collect().map(_.name).toSet == Set("default", "my_db2")) } test("list tables") { - assert(sparkSession.catalog.listTables().collect().isEmpty) + assert(spark.catalog.listTables().collect().isEmpty) createTable("my_table1") createTable("my_table2") createTempTable("my_temp_table") - assert(sparkSession.catalog.listTables().collect().map(_.name).toSet == + 
assert(spark.catalog.listTables().collect().map(_.name).toSet == Set("my_table1", "my_table2", "my_temp_table")) dropTable("my_table1") - assert(sparkSession.catalog.listTables().collect().map(_.name).toSet == + assert(spark.catalog.listTables().collect().map(_.name).toSet == Set("my_table2", "my_temp_table")) dropTable("my_temp_table") - assert(sparkSession.catalog.listTables().collect().map(_.name).toSet == Set("my_table2")) + assert(spark.catalog.listTables().collect().map(_.name).toSet == Set("my_table2")) } test("list tables with database") { - assert(sparkSession.catalog.listTables("default").collect().isEmpty) + assert(spark.catalog.listTables("default").collect().isEmpty) createDatabase("my_db1") createDatabase("my_db2") createTable("my_table1", Some("my_db1")) createTable("my_table2", Some("my_db2")) createTempTable("my_temp_table") - assert(sparkSession.catalog.listTables("default").collect().map(_.name).toSet == + assert(spark.catalog.listTables("default").collect().map(_.name).toSet == Set("my_temp_table")) - assert(sparkSession.catalog.listTables("my_db1").collect().map(_.name).toSet == + assert(spark.catalog.listTables("my_db1").collect().map(_.name).toSet == Set("my_table1", "my_temp_table")) - assert(sparkSession.catalog.listTables("my_db2").collect().map(_.name).toSet == + assert(spark.catalog.listTables("my_db2").collect().map(_.name).toSet == Set("my_table2", "my_temp_table")) dropTable("my_table1", Some("my_db1")) - assert(sparkSession.catalog.listTables("my_db1").collect().map(_.name).toSet == + assert(spark.catalog.listTables("my_db1").collect().map(_.name).toSet == Set("my_temp_table")) - assert(sparkSession.catalog.listTables("my_db2").collect().map(_.name).toSet == + assert(spark.catalog.listTables("my_db2").collect().map(_.name).toSet == Set("my_table2", "my_temp_table")) dropTable("my_temp_table") - assert(sparkSession.catalog.listTables("default").collect().map(_.name).isEmpty) - assert(sparkSession.catalog.listTables("my_db1").collect().map(_.name).isEmpty) - assert(sparkSession.catalog.listTables("my_db2").collect().map(_.name).toSet == + assert(spark.catalog.listTables("default").collect().map(_.name).isEmpty) + assert(spark.catalog.listTables("my_db1").collect().map(_.name).isEmpty) + assert(spark.catalog.listTables("my_db2").collect().map(_.name).toSet == Set("my_table2")) val e = intercept[AnalysisException] { - sparkSession.catalog.listTables("unknown_db") + spark.catalog.listTables("unknown_db") } assert(e.getMessage.contains("unknown_db")) } test("list functions") { assert(Set("+", "current_database", "window").subsetOf( - sparkSession.catalog.listFunctions().collect().map(_.name).toSet)) + spark.catalog.listFunctions().collect().map(_.name).toSet)) createFunction("my_func1") createFunction("my_func2") createTempFunction("my_temp_func") - val funcNames1 = sparkSession.catalog.listFunctions().collect().map(_.name).toSet + val funcNames1 = spark.catalog.listFunctions().collect().map(_.name).toSet assert(funcNames1.contains("my_func1")) assert(funcNames1.contains("my_func2")) assert(funcNames1.contains("my_temp_func")) dropFunction("my_func1") dropTempFunction("my_temp_func") - val funcNames2 = sparkSession.catalog.listFunctions().collect().map(_.name).toSet + val funcNames2 = spark.catalog.listFunctions().collect().map(_.name).toSet assert(!funcNames2.contains("my_func1")) assert(funcNames2.contains("my_func2")) assert(!funcNames2.contains("my_temp_func")) @@ -194,30 +193,38 @@ class CatalogSuite test("list functions with database") { assert(Set("+", 
"current_database", "window").subsetOf( - sparkSession.catalog.listFunctions("default").collect().map(_.name).toSet)) + spark.catalog.listFunctions().collect().map(_.name).toSet)) createDatabase("my_db1") createDatabase("my_db2") createFunction("my_func1", Some("my_db1")) createFunction("my_func2", Some("my_db2")) createTempFunction("my_temp_func") - val funcNames1 = sparkSession.catalog.listFunctions("my_db1").collect().map(_.name).toSet - val funcNames2 = sparkSession.catalog.listFunctions("my_db2").collect().map(_.name).toSet + val funcNames1 = spark.catalog.listFunctions("my_db1").collect().map(_.name).toSet + val funcNames2 = spark.catalog.listFunctions("my_db2").collect().map(_.name).toSet assert(funcNames1.contains("my_func1")) assert(!funcNames1.contains("my_func2")) assert(funcNames1.contains("my_temp_func")) assert(!funcNames2.contains("my_func1")) assert(funcNames2.contains("my_func2")) assert(funcNames2.contains("my_temp_func")) + + // Make sure database is set properly. + assert( + spark.catalog.listFunctions("my_db1").collect().map(_.database).toSet == Set("my_db1", null)) + assert( + spark.catalog.listFunctions("my_db2").collect().map(_.database).toSet == Set("my_db2", null)) + + // Remove the function and make sure they no longer appear. dropFunction("my_func1", Some("my_db1")) dropTempFunction("my_temp_func") - val funcNames1b = sparkSession.catalog.listFunctions("my_db1").collect().map(_.name).toSet - val funcNames2b = sparkSession.catalog.listFunctions("my_db2").collect().map(_.name).toSet + val funcNames1b = spark.catalog.listFunctions("my_db1").collect().map(_.name).toSet + val funcNames2b = spark.catalog.listFunctions("my_db2").collect().map(_.name).toSet assert(!funcNames1b.contains("my_func1")) assert(!funcNames1b.contains("my_temp_func")) assert(funcNames2b.contains("my_func2")) assert(!funcNames2b.contains("my_temp_func")) val e = intercept[AnalysisException] { - sparkSession.catalog.listFunctions("unknown_db") + spark.catalog.listFunctions("unknown_db") } assert(e.getMessage.contains("unknown_db")) } @@ -227,6 +234,11 @@ class CatalogSuite testListColumns("tab1", dbName = None) } + test("list columns in temporary table") { + createTempTable("temp1") + spark.catalog.listColumns("temp1") + } + test("list columns in database") { createDatabase("db1") createTable("tab1", Some("db1")) @@ -249,9 +261,11 @@ class CatalogSuite } test("Function.toString") { - assert(new Function("nama", "commenta", "classNameAh", isTemporary = true).toString == - "Function[name='nama', description='commenta', className='classNameAh', isTemporary='true']") - assert(new Function("nama", null, "classNameAh", isTemporary = false).toString == + assert( + new Function("nama", "databasa", "commenta", "classNameAh", isTemporary = true).toString == + "Function[name='nama', database='databasa', description='commenta', " + + "className='classNameAh', isTemporary='true']") + assert(new Function("nama", null, null, "classNameAh", isTemporary = false).toString == "Function[name='nama', className='classNameAh', isTemporary='false']") } @@ -266,6 +280,41 @@ class CatalogSuite "nullable='false', isPartition='true', isBucket='true']") } + test("catalog classes format in Dataset.show") { + val db = new Database("nama", "descripta", "locata") + val table = new Table("nama", "databasa", "descripta", "typa", isTemporary = false) + val function = new Function("nama", "databasa", "descripta", "classa", isTemporary = false) + val column = new Column( + "nama", "descripta", "typa", nullable = false, isPartition = 
true, isBucket = true) + val dbFields = ScalaReflection.getConstructorParameterValues(db) + val tableFields = ScalaReflection.getConstructorParameterValues(table) + val functionFields = ScalaReflection.getConstructorParameterValues(function) + val columnFields = ScalaReflection.getConstructorParameterValues(column) + assert(dbFields == Seq("nama", "descripta", "locata")) + assert(tableFields == Seq("nama", "databasa", "descripta", "typa", false)) + assert(functionFields == Seq("nama", "databasa", "descripta", "classa", false)) + assert(columnFields == Seq("nama", "descripta", "typa", false, true, true)) + val dbString = CatalogImpl.makeDataset(Seq(db), spark).showString(10) + val tableString = CatalogImpl.makeDataset(Seq(table), spark).showString(10) + val functionString = CatalogImpl.makeDataset(Seq(function), spark).showString(10) + val columnString = CatalogImpl.makeDataset(Seq(column), spark).showString(10) + dbFields.foreach { f => assert(dbString.contains(f.toString)) } + tableFields.foreach { f => assert(tableString.contains(f.toString)) } + functionFields.foreach { f => assert(functionString.contains(f.toString)) } + columnFields.foreach { f => assert(columnString.contains(f.toString)) } + } + + test("dropTempView should not un-cache and drop metastore table if a same-name table exists") { + withTable("same_name") { + spark.range(10).write.saveAsTable("same_name") + sql("CACHE TABLE same_name") + assert(spark.catalog.isCached("default.same_name")) + spark.catalog.dropTempView("same_name") + assert(spark.sessionState.catalog.tableExists(TableIdentifier("same_name", Some("default")))) + assert(spark.catalog.isCached("default.same_name")) + } + } + // TODO: add tests for the rest of them } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala index cc6919913948d..95bfd05c1f260 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala @@ -153,6 +153,17 @@ class SQLConfEntrySuite extends SparkFunSuite { assert(conf.getConf(confEntry, Seq("a", "b", "c")) === Seq("a", "b", "c", "d", "e")) } + test("optionalConf") { + val key = "spark.sql.SQLConfEntrySuite.optional" + val confEntry = SQLConfigBuilder(key) + .stringConf + .createOptional + + assert(conf.getConf(confEntry) === None) + conf.setConfString(key, "a") + assert(conf.getConf(confEntry) === Some("a")) + } + test("duplicate entry") { val key = "spark.sql.SQLConfEntrySuite.duplicate" SQLConfigBuilder(key).stringConf.createOptional diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala index b87f482941419..41e011b8448d3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala @@ -17,8 +17,12 @@ package org.apache.spark.sql.internal -import org.apache.spark.sql.{QueryTest, SparkSession, SQLContext} +import org.apache.hadoop.fs.Path + +import org.apache.spark.sql.{QueryTest, Row, SparkSession, SQLContext} +import org.apache.spark.sql.execution.WholeStageCodegenExec import org.apache.spark.sql.test.{SharedSQLContext, TestSQLContext} +import org.apache.spark.util.Utils class SQLConfSuite extends QueryTest with SharedSQLContext { private val testKey = "test.key.0" @@ -27,67 +31,135 @@ class 
SQLConfSuite extends QueryTest with SharedSQLContext { test("propagate from spark conf") { // We create a new context here to avoid order dependence with other tests that might call // clear(). - val newContext = new SQLContext(sparkContext) + val newContext = new SQLContext(SparkSession.builder().sparkContext(sparkContext).getOrCreate()) assert(newContext.getConf("spark.sql.testkey", "false") === "true") } test("programmatic ways of basic setting and getting") { // Set a conf first. - sqlContext.setConf(testKey, testVal) + spark.conf.set(testKey, testVal) // Clear the conf. - sqlContext.conf.clear() + spark.sessionState.conf.clear() // After clear, only overrideConfs used by unit test should be in the SQLConf. - assert(sqlContext.getAllConfs === TestSQLContext.overrideConfs) + assert(spark.conf.getAll === TestSQLContext.overrideConfs) - sqlContext.setConf(testKey, testVal) - assert(sqlContext.getConf(testKey) === testVal) - assert(sqlContext.getConf(testKey, testVal + "_") === testVal) - assert(sqlContext.getAllConfs.contains(testKey)) + spark.conf.set(testKey, testVal) + assert(spark.conf.get(testKey) === testVal) + assert(spark.conf.get(testKey, testVal + "_") === testVal) + assert(spark.conf.getAll.contains(testKey)) // Tests SQLConf as accessed from a SQLContext is mutable after // the latter is initialized, unlike SparkConf inside a SparkContext. - assert(sqlContext.getConf(testKey) === testVal) - assert(sqlContext.getConf(testKey, testVal + "_") === testVal) - assert(sqlContext.getAllConfs.contains(testKey)) + assert(spark.conf.get(testKey) === testVal) + assert(spark.conf.get(testKey, testVal + "_") === testVal) + assert(spark.conf.getAll.contains(testKey)) - sqlContext.conf.clear() + spark.sessionState.conf.clear() } test("parse SQL set commands") { - sqlContext.conf.clear() + spark.sessionState.conf.clear() sql(s"set $testKey=$testVal") - assert(sqlContext.getConf(testKey, testVal + "_") === testVal) - assert(sqlContext.getConf(testKey, testVal + "_") === testVal) + assert(spark.conf.get(testKey, testVal + "_") === testVal) + assert(spark.conf.get(testKey, testVal + "_") === testVal) sql("set some.property=20") - assert(sqlContext.getConf("some.property", "0") === "20") + assert(spark.conf.get("some.property", "0") === "20") sql("set some.property = 40") - assert(sqlContext.getConf("some.property", "0") === "40") + assert(spark.conf.get("some.property", "0") === "40") val key = "spark.sql.key" val vs = "val0,val_1,val2.3,my_table" sql(s"set $key=$vs") - assert(sqlContext.getConf(key, "0") === vs) + assert(spark.conf.get(key, "0") === vs) sql(s"set $key=") - assert(sqlContext.getConf(key, "0") === "") + assert(spark.conf.get(key, "0") === "") - sqlContext.conf.clear() + spark.sessionState.conf.clear() + } + + test("set command for display") { + spark.sessionState.conf.clear() + checkAnswer( + sql("SET").where("key = 'spark.sql.groupByOrdinal'").select("key", "value"), + Nil) + + checkAnswer( + sql("SET -v").where("key = 'spark.sql.groupByOrdinal'").select("key", "value"), + Row("spark.sql.groupByOrdinal", "true")) + + sql("SET spark.sql.groupByOrdinal=false") + + checkAnswer( + sql("SET").where("key = 'spark.sql.groupByOrdinal'").select("key", "value"), + Row("spark.sql.groupByOrdinal", "false")) + + checkAnswer( + sql("SET -v").where("key = 'spark.sql.groupByOrdinal'").select("key", "value"), + Row("spark.sql.groupByOrdinal", "false")) } test("deprecated property") { - sqlContext.conf.clear() - val original = sqlContext.conf.numShufflePartitions - try{ + 
spark.sessionState.conf.clear() + val original = spark.conf.get(SQLConf.SHUFFLE_PARTITIONS) + try { sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10") - assert(sqlContext.conf.numShufflePartitions === 10) + assert(spark.conf.get(SQLConf.SHUFFLE_PARTITIONS) === 10) } finally { sql(s"set ${SQLConf.SHUFFLE_PARTITIONS}=$original") } } + test("reset - public conf") { + spark.sessionState.conf.clear() + val original = spark.conf.get(SQLConf.GROUP_BY_ORDINAL) + try { + assert(spark.conf.get(SQLConf.GROUP_BY_ORDINAL) === true) + sql(s"set ${SQLConf.GROUP_BY_ORDINAL.key}=false") + assert(spark.conf.get(SQLConf.GROUP_BY_ORDINAL) === false) + assert(sql(s"set").where(s"key = '${SQLConf.GROUP_BY_ORDINAL.key}'").count() == 1) + sql(s"reset") + assert(spark.conf.get(SQLConf.GROUP_BY_ORDINAL) === true) + assert(sql(s"set").where(s"key = '${SQLConf.GROUP_BY_ORDINAL.key}'").count() == 0) + } finally { + sql(s"set ${SQLConf.GROUP_BY_ORDINAL}=$original") + } + } + + test("reset - internal conf") { + spark.sessionState.conf.clear() + val original = spark.conf.get(SQLConf.NATIVE_VIEW) + try { + assert(spark.conf.get(SQLConf.NATIVE_VIEW) === true) + sql(s"set ${SQLConf.NATIVE_VIEW.key}=false") + assert(spark.conf.get(SQLConf.NATIVE_VIEW) === false) + assert(sql(s"set").where(s"key = '${SQLConf.NATIVE_VIEW.key}'").count() == 1) + sql(s"reset") + assert(spark.conf.get(SQLConf.NATIVE_VIEW) === true) + assert(sql(s"set").where(s"key = '${SQLConf.NATIVE_VIEW.key}'").count() == 0) + } finally { + sql(s"set ${SQLConf.NATIVE_VIEW}=$original") + } + } + + test("reset - user-defined conf") { + spark.sessionState.conf.clear() + val userDefinedConf = "x.y.z.reset" + try { + assert(spark.conf.getOption(userDefinedConf).isEmpty) + sql(s"set $userDefinedConf=false") + assert(spark.conf.get(userDefinedConf) === "false") + assert(sql(s"set").where(s"key = '$userDefinedConf'").count() == 1) + sql(s"reset") + assert(spark.conf.getOption(userDefinedConf).isEmpty) + } finally { + spark.conf.unset(userDefinedConf) + } + } + test("invalid conf value") { - sqlContext.conf.clear() + spark.sessionState.conf.clear() val e = intercept[IllegalArgumentException] { sql(s"set ${SQLConf.CASE_SENSITIVE.key}=10") } @@ -95,35 +167,35 @@ class SQLConfSuite extends QueryTest with SharedSQLContext { } test("Test SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE's method") { - sqlContext.conf.clear() + spark.sessionState.conf.clear() - sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "100") - assert(sqlContext.conf.targetPostShuffleInputSize === 100) + spark.conf.set(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "100") + assert(spark.conf.get(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE) === 100) - sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "1k") - assert(sqlContext.conf.targetPostShuffleInputSize === 1024) + spark.conf.set(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "1k") + assert(spark.conf.get(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE) === 1024) - sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "1M") - assert(sqlContext.conf.targetPostShuffleInputSize === 1048576) + spark.conf.set(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "1M") + assert(spark.conf.get(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE) === 1048576) - sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "1g") - assert(sqlContext.conf.targetPostShuffleInputSize === 1073741824) + spark.conf.set(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "1g") + 
assert(spark.conf.get(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE) === 1073741824) - sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "-1") - assert(sqlContext.conf.targetPostShuffleInputSize === -1) + spark.conf.set(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "-1") + assert(spark.conf.get(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE) === -1) // Test overflow exception intercept[IllegalArgumentException] { // This value exceeds Long.MaxValue - sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "90000000000g") + spark.conf.set(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "90000000000g") } intercept[IllegalArgumentException] { // This value less than Long.MinValue - sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "-90000000000g") + spark.conf.set(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "-90000000000g") } - sqlContext.conf.clear() + spark.sessionState.conf.clear() } test("SparkSession can access configs set in SparkConf") { @@ -139,4 +211,47 @@ class SQLConfSuite extends QueryTest with SharedSQLContext { } } + test("default value of WAREHOUSE_PATH") { + + val original = spark.conf.get(SQLConf.WAREHOUSE_PATH) + try { + // to get the default value, always unset it + spark.conf.unset(SQLConf.WAREHOUSE_PATH.key) + // JVM adds a trailing slash if the directory exists and leaves it as-is, if it doesn't + // In our comparison, strip trailing slash off of both sides, to account for such cases + assert(new Path(Utils.resolveURI("spark-warehouse")).toString.stripSuffix("/") === spark + .sessionState.conf.warehousePath.stripSuffix("/")) + } finally { + sql(s"set ${SQLConf.WAREHOUSE_PATH}=$original") + } + } + + test("MAX_CASES_BRANCHES") { + withTable("tab1") { + spark.range(10).write.saveAsTable("tab1") + val sql_one_branch_caseWhen = "SELECT CASE WHEN id = 1 THEN 1 END FROM tab1" + val sql_two_branch_caseWhen = "SELECT CASE WHEN id = 1 THEN 1 ELSE 0 END FROM tab1" + + withSQLConf(SQLConf.MAX_CASES_BRANCHES.key -> "0") { + assert(!sql(sql_one_branch_caseWhen) + .queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec]) + assert(!sql(sql_two_branch_caseWhen) + .queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec]) + } + + withSQLConf(SQLConf.MAX_CASES_BRANCHES.key -> "1") { + assert(sql(sql_one_branch_caseWhen) + .queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec]) + assert(!sql(sql_two_branch_caseWhen) + .queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec]) + } + + withSQLConf(SQLConf.MAX_CASES_BRANCHES.key -> "2") { + assert(sql(sql_one_branch_caseWhen) + .queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec]) + assert(sql(sql_two_branch_caseWhen) + .queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec]) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 783511b781096..1a6dba82b0e27 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -24,12 +24,13 @@ import java.util.{Calendar, GregorianCalendar, Properties} import org.h2.jdbc.JdbcSQLException import org.scalatest.{BeforeAndAfter, PrivateMethodTester} -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.execution.DataSourceScanExec import 
org.apache.spark.sql.execution.command.ExplainCommand import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD +import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils import org.apache.spark.sql.sources._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -83,7 +84,7 @@ class JDBCSuite extends SparkFunSuite |CREATE TEMPORARY TABLE fetchtwo |USING org.apache.spark.sql.jdbc |OPTIONS (url '$url', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass', - | fetchSize '2') + | ${JdbcUtils.JDBC_BATCH_FETCH_SIZE} '2') """.stripMargin.replaceAll("\n", " ")) sql( @@ -184,6 +185,16 @@ class JDBCSuite extends SparkFunSuite "insert into test.emp values ('kathy', null, null)").executeUpdate() conn.commit() + conn.prepareStatement( + "create table test.seq(id INTEGER)").executeUpdate() + (0 to 6).foreach { value => + conn.prepareStatement( + s"insert into test.seq values ($value)").executeUpdate() + } + conn.prepareStatement( + "insert into test.seq values (null)").executeUpdate() + conn.commit() + sql( s""" |CREATE TEMPORARY TABLE nullparts @@ -278,7 +289,7 @@ class JDBCSuite extends SparkFunSuite assert(names(2).equals("mary")) } - test("SELECT first field when fetchSize is two") { + test("SELECT first field when fetchsize is two") { val names = sql("SELECT NAME FROM fetchtwo").collect().map(x => x.getString(0)).sortWith(_ < _) assert(names.size === 3) assert(names(0).equals("fred")) @@ -294,7 +305,7 @@ class JDBCSuite extends SparkFunSuite assert(ids(2) === 3) } - test("SELECT second field when fetchSize is two") { + test("SELECT second field when fetchsize is two") { val ids = sql("SELECT THEID FROM fetchtwo").collect().map(x => x.getInt(0)).sortWith(_ < _) assert(ids.size === 3) assert(ids(0) === 1) @@ -337,42 +348,108 @@ class JDBCSuite extends SparkFunSuite } test("Basic API") { - assert(sqlContext.read.jdbc( - urlWithUserAndPass, "TEST.PEOPLE", new Properties).collect().length === 3) + assert(spark.read.jdbc( + urlWithUserAndPass, "TEST.PEOPLE", new Properties()).collect().length === 3) + } + + test("Basic API with illegal fetchsize") { + val properties = new Properties() + properties.setProperty(JdbcUtils.JDBC_BATCH_FETCH_SIZE, "-1") + val e = intercept[SparkException] { + spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", properties).collect() + }.getMessage + assert(e.contains("Invalid value `-1` for parameter `fetchsize`")) } test("Basic API with FetchSize") { - val properties = new Properties - properties.setProperty("fetchSize", "2") - assert(sqlContext.read.jdbc( - urlWithUserAndPass, "TEST.PEOPLE", properties).collect().length === 3) + (0 to 4).foreach { size => + val properties = new Properties() + properties.setProperty(JdbcUtils.JDBC_BATCH_FETCH_SIZE, size.toString) + assert(spark.read.jdbc( + urlWithUserAndPass, "TEST.PEOPLE", properties).collect().length === 3) + } } test("Partitioning via JDBCPartitioningInfo API") { assert( - sqlContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3, new Properties) + spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3, new Properties()) .collect().length === 3) } test("Partitioning via list-of-where-clauses API") { val parts = Array[String]("THEID < 2", "THEID >= 2") - assert(sqlContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, new Properties) + assert(spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, new Properties()) .collect().length === 3) } test("Partitioning on column 
that might have null values.") { assert( - sqlContext.read.jdbc(urlWithUserAndPass, "TEST.EMP", "theid", 0, 4, 3, new Properties) + spark.read.jdbc(urlWithUserAndPass, "TEST.EMP", "theid", 0, 4, 3, new Properties()) .collect().length === 4) assert( - sqlContext.read.jdbc(urlWithUserAndPass, "TEST.EMP", "THEID", 0, 4, 3, new Properties) + spark.read.jdbc(urlWithUserAndPass, "TEST.EMP", "THEID", 0, 4, 3, new Properties()) .collect().length === 4) // partitioning on a nullable quoted column assert( - sqlContext.read.jdbc(urlWithUserAndPass, "TEST.EMP", """"Dept"""", 0, 4, 3, new Properties) + spark.read.jdbc(urlWithUserAndPass, "TEST.EMP", """"Dept"""", 0, 4, 3, new Properties()) .collect().length === 4) } + test("Partitioning on column where numPartitions is zero") { + val res = spark.read.jdbc( + url = urlWithUserAndPass, + table = "TEST.seq", + columnName = "id", + lowerBound = 0, + upperBound = 4, + numPartitions = 0, + connectionProperties = new Properties() + ) + assert(res.count() === 8) + } + + test("Partitioning on column where numPartitions are more than the number of total rows") { + val res = spark.read.jdbc( + url = urlWithUserAndPass, + table = "TEST.seq", + columnName = "id", + lowerBound = 1, + upperBound = 5, + numPartitions = 10, + connectionProperties = new Properties() + ) + assert(res.count() === 8) + } + + test("Partitioning on column where lowerBound is equal to upperBound") { + val res = spark.read.jdbc( + url = urlWithUserAndPass, + table = "TEST.seq", + columnName = "id", + lowerBound = 5, + upperBound = 5, + numPartitions = 4, + connectionProperties = new Properties() + ) + assert(res.count() === 8) + } + + test("Partitioning on column where lowerBound is larger than upperBound") { + val e = intercept[IllegalArgumentException] { + spark.read.jdbc( + url = urlWithUserAndPass, + table = "TEST.seq", + columnName = "id", + lowerBound = 5, + upperBound = 1, + numPartitions = 3, + connectionProperties = new Properties() + ) + }.getMessage + assert(e.contains("Operation not allowed: the lower bound of partitioning column " + + "is larger than the upper bound. 
Lower bound: 5; Upper bound: 1")) + } + test("SELECT * on partitioned table with a nullable partition column") { assert(sql("SELECT * FROM nullparts").collect().size == 4) } @@ -429,9 +506,9 @@ class JDBCSuite extends SparkFunSuite } test("test DATE types") { - val rows = sqlContext.read.jdbc( - urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect() - val cachedRows = sqlContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) + val rows = spark.read.jdbc( + urlWithUserAndPass, "TEST.TIMETYPES", new Properties()).collect() + val cachedRows = spark.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties()) .cache().collect() assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) assert(rows(1).getAs[java.sql.Date](1) === null) @@ -439,17 +516,17 @@ class JDBCSuite extends SparkFunSuite } test("test DATE types in cache") { - val rows = sqlContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect() - sqlContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) - .cache().registerTempTable("mycached_date") + val rows = spark.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties()).collect() + spark.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties()) + .cache().createOrReplaceTempView("mycached_date") val cachedRows = sql("select * from mycached_date").collect() assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) assert(cachedRows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) } test("test types for null value") { - val rows = sqlContext.read.jdbc( - urlWithUserAndPass, "TEST.NULLTYPES", new Properties).collect() + val rows = spark.read.jdbc( + urlWithUserAndPass, "TEST.NULLTYPES", new Properties()).collect() assert((0 to 14).forall(i => rows(0).isNullAt(i))) } @@ -495,7 +572,7 @@ class JDBCSuite extends SparkFunSuite test("Remap types via JdbcDialects") { JdbcDialects.registerDialect(testH2Dialect) - val df = sqlContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties) + val df = spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties()) assert(df.schema.filter(_.dataType != org.apache.spark.sql.types.StringType).isEmpty) val rows = df.collect() assert(rows(0).get(0).isInstanceOf[String]) @@ -600,6 +677,15 @@ class JDBCSuite extends SparkFunSuite assert(derbyDialect.getJDBCType(BooleanType).map(_.databaseTypeDefinition).get == "BOOLEAN") } + test("OracleDialect jdbc type mapping") { + val oracleDialect = JdbcDialects.get("jdbc:oracle") + val metadata = new MetadataBuilder().putString("name", "test_column").putLong("scale", -127) + assert(oracleDialect.getCatalystType(java.sql.Types.NUMERIC, "float", 1, metadata) == + Some(DecimalType(DecimalType.MAX_PRECISION, 10))) + assert(oracleDialect.getCatalystType(java.sql.Types.NUMERIC, "numeric", 0, null) == + Some(DecimalType(DecimalType.MAX_PRECISION, 10))) + } + test("table exists query by jdbc dialect") { val MySQL = JdbcDialects.get("jdbc:mysql://127.0.0.1/db") val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") @@ -620,7 +706,7 @@ class JDBCSuite extends SparkFunSuite // Regression test for bug SPARK-11788 val timestamp = java.sql.Timestamp.valueOf("2001-02-20 11:22:33.543543"); val date = java.sql.Date.valueOf("1995-01-01") - val jdbcDf = sqlContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) + val jdbcDf = spark.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties()) val rows = jdbcDf.where($"B" > date && $"C" > 
timestamp).collect() assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) assert(rows(0).getAs[java.sql.Timestamp](2) @@ -630,7 +716,7 @@ class JDBCSuite extends SparkFunSuite test("test credentials in the properties are not in plan output") { val df = sql("SELECT * FROM parts") val explain = ExplainCommand(df.queryExecution.logical, extended = true) - sqlContext.executePlan(explain).executedPlan.executeCollect().foreach { + spark.sessionState.executePlan(explain).executedPlan.executeCollect().foreach { r => assert(!List("testPass", "testUser").exists(r.toString.contains)) } // test the JdbcRelation toString output @@ -640,9 +726,9 @@ class JDBCSuite extends SparkFunSuite } test("test credentials in the connection url are not in the plan output") { - val df = sqlContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties) + val df = spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties()) val explain = ExplainCommand(df.queryExecution.logical, extended = true) - sqlContext.executePlan(explain).executedPlan.executeCollect().foreach { + spark.sessionState.executePlan(explain).executedPlan.executeCollect().foreach { r => assert(!List("testPass", "testUser").exists(r.toString.contains)) } } @@ -652,4 +738,44 @@ class JDBCSuite extends SparkFunSuite assert(oracleDialect.getJDBCType(StringType). map(_.databaseTypeDefinition).get == "VARCHAR2(255)") } + + private def assertEmptyQuery(sqlString: String): Unit = { + assert(sql(sqlString).collect().isEmpty) + } + + test("SPARK-15916: JDBC filter operator push down should respect operator precedence") { + val TRUE = "NAME != 'non_exists'" + val FALSE1 = "THEID > 1000000000" + val FALSE2 = "THEID < -1000000000" + + assertEmptyQuery(s"SELECT * FROM foobar WHERE ($TRUE OR $FALSE1) AND $FALSE2") + assertEmptyQuery(s"SELECT * FROM foobar WHERE $FALSE1 AND ($FALSE2 OR $TRUE)") + + // Tests JDBCPartition whereClause clause push down. 
+ withTempView("tempFrame") { + val jdbcPartitionWhereClause = s"$FALSE1 OR $TRUE" + val df = spark.read.jdbc( + urlWithUserAndPass, + "TEST.PEOPLE", + predicates = Array[String](jdbcPartitionWhereClause), + new Properties()) + + df.createOrReplaceTempView("tempFrame") + assertEmptyQuery(s"SELECT * FROM tempFrame where $FALSE2") + } + } + + test("SPARK-16387: Reserved SQL words are not escaped by JDBC writer") { + val df = spark.createDataset(Seq("a", "b", "c")).toDF("order") + val schema = JdbcUtils.schemaString(df, "jdbc:mysql://localhost:3306/temp") + assert(schema.contains("`order` TEXT")) + } + + test("SPARK-17673: Exchange reuse respects differences in output schema") { + val df = sql("SELECT * FROM inttypes WHERE a IS NOT NULL") + val df1 = df.groupBy("a").agg("c" -> "min") + val df2 = df.groupBy("a").agg("d" -> "min") + val res = df1.union(df2) + assert(res.distinct().count() == 2) // would be 1 if the exchange was incorrectly reused + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index e23ee6693133b..2c6449fa6870b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -22,7 +22,9 @@ import java.util.Properties import org.scalatest.BeforeAndAfter +import org.apache.spark.SparkException import org.apache.spark.sql.{Row, SaveMode} +import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -57,14 +59,14 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { sql( s""" - |CREATE TEMPORARY TABLE PEOPLE + |CREATE OR REPLACE TEMPORARY VIEW PEOPLE |USING org.apache.spark.sql.jdbc |OPTIONS (url '$url1', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) sql( s""" - |CREATE TEMPORARY TABLE PEOPLE1 + |CREATE OR REPLACE TEMPORARY VIEW PEOPLE1 |USING org.apache.spark.sql.jdbc |OPTIONS (url '$url1', dbtable 'TEST.PEOPLE1', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) @@ -88,67 +90,91 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { StructField("seq", IntegerType) :: Nil) test("Basic CREATE") { - val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) - df.write.jdbc(url, "TEST.BASICCREATETEST", new Properties) - assert(2 === sqlContext.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).count) + df.write.jdbc(url, "TEST.BASICCREATETEST", new Properties()) + assert(2 === spark.read.jdbc(url, "TEST.BASICCREATETEST", new Properties()).count()) assert( - 2 === sqlContext.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).collect()(0).length) + 2 === spark.read.jdbc(url, "TEST.BASICCREATETEST", new Properties()).collect()(0).length) + } + + test("Basic CREATE with illegal batchsize") { + val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + + (-1 to 0).foreach { size => + val properties = new Properties() + properties.setProperty(JdbcUtils.JDBC_BATCH_INSERT_SIZE, size.toString) + val e = intercept[SparkException] { + df.write.mode(SaveMode.Overwrite).jdbc(url, "TEST.BASICCREATETEST", properties) + }.getMessage + assert(e.contains(s"Invalid value `$size` for parameter `batchsize`")) + } + } + + test("Basic 
CREATE with batchsize") { + val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + + (1 to 3).foreach { size => + val properties = new Properties() + properties.setProperty(JdbcUtils.JDBC_BATCH_INSERT_SIZE, size.toString) + df.write.mode(SaveMode.Overwrite).jdbc(url, "TEST.BASICCREATETEST", properties) + assert(2 === spark.read.jdbc(url, "TEST.BASICCREATETEST", new Properties()).count()) + } } test("CREATE with overwrite") { - val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x3), schema3) - val df2 = sqlContext.createDataFrame(sparkContext.parallelize(arr1x2), schema2) + val df = spark.createDataFrame(sparkContext.parallelize(arr2x3), schema3) + val df2 = spark.createDataFrame(sparkContext.parallelize(arr1x2), schema2) df.write.jdbc(url1, "TEST.DROPTEST", properties) - assert(2 === sqlContext.read.jdbc(url1, "TEST.DROPTEST", properties).count) - assert(3 === sqlContext.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) + assert(2 === spark.read.jdbc(url1, "TEST.DROPTEST", properties).count()) + assert(3 === spark.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) df2.write.mode(SaveMode.Overwrite).jdbc(url1, "TEST.DROPTEST", properties) - assert(1 === sqlContext.read.jdbc(url1, "TEST.DROPTEST", properties).count) - assert(2 === sqlContext.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) + assert(1 === spark.read.jdbc(url1, "TEST.DROPTEST", properties).count()) + assert(2 === spark.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) } test("CREATE then INSERT to append") { - val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x2), schema2) - val df2 = sqlContext.createDataFrame(sparkContext.parallelize(arr1x2), schema2) + val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + val df2 = spark.createDataFrame(sparkContext.parallelize(arr1x2), schema2) - df.write.jdbc(url, "TEST.APPENDTEST", new Properties) - df2.write.mode(SaveMode.Append).jdbc(url, "TEST.APPENDTEST", new Properties) - assert(3 === sqlContext.read.jdbc(url, "TEST.APPENDTEST", new Properties).count) - assert(2 === sqlContext.read.jdbc(url, "TEST.APPENDTEST", new Properties).collect()(0).length) + df.write.jdbc(url, "TEST.APPENDTEST", new Properties()) + df2.write.mode(SaveMode.Append).jdbc(url, "TEST.APPENDTEST", new Properties()) + assert(3 === spark.read.jdbc(url, "TEST.APPENDTEST", new Properties()).count()) + assert(2 === spark.read.jdbc(url, "TEST.APPENDTEST", new Properties()).collect()(0).length) } test("CREATE then INSERT to truncate") { - val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x2), schema2) - val df2 = sqlContext.createDataFrame(sparkContext.parallelize(arr1x2), schema2) + val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + val df2 = spark.createDataFrame(sparkContext.parallelize(arr1x2), schema2) df.write.jdbc(url1, "TEST.TRUNCATETEST", properties) df2.write.mode(SaveMode.Overwrite).jdbc(url1, "TEST.TRUNCATETEST", properties) - assert(1 === sqlContext.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count) - assert(2 === sqlContext.read.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length) + assert(1 === spark.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count()) + assert(2 === spark.read.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length) } test("Incompatible INSERT to append") { - val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x2), schema2) - val df2 = 
sqlContext.createDataFrame(sparkContext.parallelize(arr2x3), schema3) + val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + val df2 = spark.createDataFrame(sparkContext.parallelize(arr2x3), schema3) - df.write.jdbc(url, "TEST.INCOMPATIBLETEST", new Properties) + df.write.jdbc(url, "TEST.INCOMPATIBLETEST", new Properties()) intercept[org.apache.spark.SparkException] { - df2.write.mode(SaveMode.Append).jdbc(url, "TEST.INCOMPATIBLETEST", new Properties) + df2.write.mode(SaveMode.Append).jdbc(url, "TEST.INCOMPATIBLETEST", new Properties()) } } test("INSERT to JDBC Datasource") { sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") - assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).count) - assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) + assert(2 === spark.read.jdbc(url1, "TEST.PEOPLE1", properties).count()) + assert(2 === spark.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) } test("INSERT to JDBC Datasource with overwrite") { sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") sql("INSERT OVERWRITE TABLE PEOPLE1 SELECT * FROM PEOPLE") - assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).count) - assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) + assert(2 === spark.read.jdbc(url1, "TEST.PEOPLE1", properties).count()) + assert(2 === spark.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala index cb88a1c83c999..ea1f7a5a3c344 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala @@ -17,198 +17,219 @@ package org.apache.spark.sql.sources -import java.io.{File, IOException} +import java.io.File -import org.scalatest.BeforeAndAfter +import org.scalatest.BeforeAndAfterEach -import org.apache.spark.sql.AnalysisException +import org.apache.spark.SparkException +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.parser.ParseException +import org.apache.spark.sql.execution.command.DDLUtils +import org.apache.spark.sql.execution.datasources.BucketSpec import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils -class CreateTableAsSelectSuite extends DataSourceTest with SharedSQLContext with BeforeAndAfter { - protected override lazy val sql = caseInsensitiveContext.sql _ +class CreateTableAsSelectSuite + extends DataSourceTest + with SharedSQLContext + with BeforeAndAfterEach { + + protected override lazy val sql = spark.sql _ private var path: File = null override def beforeAll(): Unit = { super.beforeAll() - path = Utils.createTempDir() val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) - caseInsensitiveContext.read.json(rdd).registerTempTable("jt") + spark.read.json(rdd).createOrReplaceTempView("jt") } override def afterAll(): Unit = { try { - caseInsensitiveContext.dropTempTable("jt") + spark.catalog.dropTempView("jt") + Utils.deleteRecursively(path) } finally { super.afterAll() } } - after { + override def beforeEach(): Unit = { + super.beforeEach() + path = Utils.createTempDir() + path.delete() + } + + override def afterEach(): Unit = { Utils.deleteRecursively(path) + super.afterEach() } - test("CREATE TEMPORARY 
TABLE AS SELECT") { - sql( - s""" - |CREATE TEMPORARY TABLE jsonTable - |USING json - |OPTIONS ( - | path '${path.toString}' - |) AS - |SELECT a, b FROM jt - """.stripMargin) - - checkAnswer( - sql("SELECT a, b FROM jsonTable"), - sql("SELECT a, b FROM jt").collect()) - - caseInsensitiveContext.dropTempTable("jsonTable") + test("CREATE TABLE USING AS SELECT") { + withTable("jsonTable") { + sql( + s""" + |CREATE TABLE jsonTable + |USING json + |OPTIONS ( + | path '${path.toString}' + |) AS + |SELECT a, b FROM jt + """.stripMargin) + + checkAnswer( + sql("SELECT a, b FROM jsonTable"), + sql("SELECT a, b FROM jt")) + } } - test("CREATE TEMPORARY TABLE AS SELECT based on the file without write permission") { + test("CREATE TABLE USING AS SELECT based on the file without write permission") { val childPath = new File(path.toString, "child") path.mkdir() - childPath.createNewFile() path.setWritable(false) - val e = intercept[IOException] { + val e = intercept[SparkException] { sql( s""" - |CREATE TEMPORARY TABLE jsonTable + |CREATE TABLE jsonTable |USING json |OPTIONS ( - | path '${path.toString}' + | path '${childPath.toString}' |) AS |SELECT a, b FROM jt - """.stripMargin) + """.stripMargin) sql("SELECT a, b FROM jsonTable").collect() } - assert(e.getMessage().contains("Unable to clear output directory")) + assert(e.getMessage().contains("Job aborted")) path.setWritable(true) } test("create a table, drop it and create another one with the same name") { - sql( - s""" - |CREATE TEMPORARY TABLE jsonTable - |USING json - |OPTIONS ( - | path '${path.toString}' - |) AS - |SELECT a, b FROM jt - """.stripMargin) - - checkAnswer( - sql("SELECT a, b FROM jsonTable"), - sql("SELECT a, b FROM jt").collect()) - - val message = intercept[AnalysisException]{ + withTable("jsonTable") { sql( s""" - |CREATE TEMPORARY TABLE IF NOT EXISTS jsonTable - |USING json - |OPTIONS ( - | path '${path.toString}' - |) AS - |SELECT a * 4 FROM jt - """.stripMargin) - }.getMessage - assert( - message.contains(s"a CREATE TEMPORARY TABLE statement does not allow IF NOT EXISTS clause."), - "CREATE TEMPORARY TABLE IF NOT EXISTS should not be allowed.") - - // Overwrite the temporary table. - sql( - s""" - |CREATE TEMPORARY TABLE jsonTable - |USING json - |OPTIONS ( - | path '${path.toString}' - |) AS - |SELECT a * 4 FROM jt - """.stripMargin) - checkAnswer( - sql("SELECT * FROM jsonTable"), - sql("SELECT a * 4 FROM jt").collect()) - - caseInsensitiveContext.dropTempTable("jsonTable") - // Explicitly delete the data. 
- if (path.exists()) Utils.deleteRecursively(path) - - sql( - s""" - |CREATE TEMPORARY TABLE jsonTable - |USING json - |OPTIONS ( - | path '${path.toString}' - |) AS - |SELECT b FROM jt - """.stripMargin) - - checkAnswer( - sql("SELECT * FROM jsonTable"), - sql("SELECT b FROM jt").collect()) - - caseInsensitiveContext.dropTempTable("jsonTable") - } + |CREATE TABLE jsonTable + |USING json + |OPTIONS ( + | path '${path.toString}' + |) AS + |SELECT a, b FROM jt + """.stripMargin) + + checkAnswer( + sql("SELECT a, b FROM jsonTable"), + sql("SELECT a, b FROM jt")) + + // Creates a table of the same name with flag "if not exists", nothing happens + sql( + s""" + |CREATE TABLE IF NOT EXISTS jsonTable + |USING json + |OPTIONS ( + | path '${path.toString}' + |) AS + |SELECT a * 4 FROM jt + """.stripMargin) + checkAnswer( + sql("SELECT * FROM jsonTable"), + sql("SELECT a, b FROM jt")) - test("CREATE TEMPORARY TABLE AS SELECT with IF NOT EXISTS is not allowed") { - val message = intercept[AnalysisException]{ + // Explicitly drops the table and deletes the underlying data. + sql("DROP TABLE jsonTable") + if (path.exists()) Utils.deleteRecursively(path) + + // Creates a table of the same name again, this time we succeed. sql( s""" - |CREATE TEMPORARY TABLE IF NOT EXISTS jsonTable - |USING json - |OPTIONS ( - | path '${path.toString}' - |) AS - |SELECT b FROM jt - """.stripMargin) - }.getMessage - assert( - message.contains("a CREATE TEMPORARY TABLE statement does not allow IF NOT EXISTS clause."), - "CREATE TEMPORARY TABLE IF NOT EXISTS should not be allowed.") + |CREATE TABLE jsonTable + |USING json + |OPTIONS ( + | path '${path.toString}' + |) AS + |SELECT b FROM jt + """.stripMargin) + + checkAnswer( + sql("SELECT * FROM jsonTable"), + sql("SELECT b FROM jt")) + } + } + + test("disallows CREATE TEMPORARY TABLE ... USING ... AS query") { + withTable("t") { + val error = intercept[ParseException] { + sql( + s""" + |CREATE TEMPORARY TABLE t USING PARQUET + |OPTIONS (PATH '${path.toString}') + |PARTITIONED BY (a) + |AS SELECT 1 AS a, 2 AS b + """.stripMargin + ) + }.getMessage + assert(error.contains("Operation not allowed") && + error.contains("CREATE TEMPORARY TABLE ... USING ... AS query")) + } } - test("a CTAS statement with column definitions is not allowed") { - intercept[AnalysisException]{ + test("disallows CREATE EXTERNAL TABLE ... USING ... AS query") { + withTable("t") { + val error = intercept[ParseException] { + sql( + s""" + |CREATE EXTERNAL TABLE t USING PARQUET + |OPTIONS (PATH '${path.toString}') + |AS SELECT 1 AS a, 2 AS b + """.stripMargin + ) + }.getMessage + + assert(error.contains("Operation not allowed") && + error.contains("CREATE EXTERNAL TABLE ... 
USING")) + } + } + + test("create table using as select - with partitioned by") { + val catalog = spark.sessionState.catalog + withTable("t") { sql( s""" - |CREATE TEMPORARY TABLE jsonTable (a int, b string) - |USING json - |OPTIONS ( - | path '${path.toString}' - |) AS - |SELECT a, b FROM jt - """.stripMargin) + |CREATE TABLE t USING PARQUET + |OPTIONS (PATH '${path.toString}') + |PARTITIONED BY (a) + |AS SELECT 1 AS a, 2 AS b + """.stripMargin + ) + val table = catalog.getTableMetadata(TableIdentifier("t")) + assert(DDLUtils.getPartitionColumnsFromTableProperties(table) == Seq("a")) } } - test("it is not allowed to write to a table while querying it.") { - sql( - s""" - |CREATE TEMPORARY TABLE jsonTable - |USING json - |OPTIONS ( - | path '${path.toString}' - |) AS - |SELECT a, b FROM jt - """.stripMargin) - - val message = intercept[AnalysisException] { + test("create table using as select - with bucket") { + val catalog = spark.sessionState.catalog + withTable("t") { sql( s""" - |CREATE TEMPORARY TABLE jsonTable - |USING json - |OPTIONS ( - | path '${path.toString}' - |) AS - |SELECT a, b FROM jsonTable - """.stripMargin) - }.getMessage - assert( - message.contains("Cannot overwrite table "), - "Writing to a table while querying it should not be allowed.") + |CREATE TABLE t USING PARQUET + |OPTIONS (PATH '${path.toString}') + |CLUSTERED BY (a) SORTED BY (b) INTO 5 BUCKETS + |AS SELECT 1 AS a, 2 AS b + """.stripMargin + ) + val table = catalog.getTableMetadata(TableIdentifier("t")) + assert(DDLUtils.getBucketSpecFromTableProperties(table) == + Some(BucketSpec(5, Seq("a"), Seq("b")))) + } + } + + test("SPARK-17409: CTAS of decimal calculation") { + withTable("tab2") { + withTempView("tab1") { + spark.range(99, 101).createOrReplaceTempView("tab1") + val sqlStmt = + "SELECT id, cast(id as long) * cast('1.0' as decimal(38, 18)) as num FROM tab1" + sql(s"CREATE TABLE tab2 USING PARQUET AS $sqlStmt") + checkAnswer(spark.table("tab2"), sql(sqlStmt)) + } + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala index 853707c036c9a..85ba33e58a787 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.{AnalysisException, SQLContext} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{StringType, StructField, StructType} @@ -27,24 +27,25 @@ class DDLSourceLoadSuite extends DataSourceTest with SharedSQLContext { test("data sources with the same name") { intercept[RuntimeException] { - caseInsensitiveContext.read.format("Fluet da Bomb").load() + spark.read.format("Fluet da Bomb").load() } } test("load data source from format alias") { - caseInsensitiveContext.read.format("gathering quorum").load().schema == + spark.read.format("gathering quorum").load().schema == StructType(Seq(StructField("stringType", StringType, nullable = false))) } test("specify full classname with duplicate formats") { - caseInsensitiveContext.read.format("org.apache.spark.sql.sources.FakeSourceOne") + spark.read.format("org.apache.spark.sql.sources.FakeSourceOne") .load().schema == StructType(Seq(StructField("stringType", StringType, nullable = false))) } - test("should fail to load ORC without HiveContext") { - 
intercept[ClassNotFoundException] { - caseInsensitiveContext.read.format("orc").load() + test("should fail to load ORC without Hive Support") { + val e = intercept[AnalysisException] { + spark.read.format("orc").load() } + assert(e.message.contains("The ORC data source must be used with Hive support enabled")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala index 612cfc7ec7bd6..c2aedfff348a3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala @@ -41,7 +41,7 @@ case class SimpleDDLScan( table: String)(@transient val sparkSession: SparkSession) extends BaseRelation with TableScan { - override def sqlContext: SQLContext = sparkSession.wrapped + override def sqlContext: SQLContext = sparkSession.sqlContext override def schema: StructType = StructType(Seq( @@ -78,7 +78,7 @@ case class SimpleDDLScan( } class DDLTestSuite extends DataSourceTest with SharedSQLContext { - protected override lazy val sql = caseInsensitiveContext.sql _ + protected override lazy val sql = spark.sql _ override def beforeAll(): Unit = { super.beforeAll() @@ -98,21 +98,21 @@ class DDLTestSuite extends DataSourceTest with SharedSQLContext { "describe ddlPeople", Seq( Row("intType", "int", "test comment test1"), - Row("stringType", "string", ""), - Row("dateType", "date", ""), - Row("timestampType", "timestamp", ""), - Row("doubleType", "double", ""), - Row("bigintType", "bigint", ""), - Row("tinyintType", "tinyint", ""), - Row("decimalType", "decimal(10,0)", ""), - Row("fixedDecimalType", "decimal(5,1)", ""), - Row("binaryType", "binary", ""), - Row("booleanType", "boolean", ""), - Row("smallIntType", "smallint", ""), - Row("floatType", "float", ""), - Row("mapType", "map", ""), - Row("arrayType", "array", ""), - Row("structType", "struct", "") + Row("stringType", "string", null), + Row("dateType", "date", null), + Row("timestampType", "timestamp", null), + Row("doubleType", "double", null), + Row("bigintType", "bigint", null), + Row("tinyintType", "tinyint", null), + Row("decimalType", "decimal(10,0)", null), + Row("fixedDecimalType", "decimal(5,1)", null), + Row("binaryType", "binary", null), + Row("booleanType", "boolean", null), + Row("smallIntType", "smallint", null), + Row("floatType", "float", null), + Row("mapType", "map", null), + Row("arrayType", "array", null), + Row("structType", "struct", null) )) test("SPARK-7686 DescribeCommand should have correct physical plan output attributes") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala new file mode 100644 index 0000000000000..448adcf11d656 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala @@ -0,0 +1,202 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.SimpleCatalystConf +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, Expression, Literal} +import org.apache.spark.sql.execution.datasources.DataSourceAnalysis +import org.apache.spark.sql.types.{IntegerType, StructType} + +class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll { + + private var targetAttributes: Seq[Attribute] = _ + private var targetPartitionSchema: StructType = _ + + override def beforeAll(): Unit = { + targetAttributes = Seq('a.int, 'd.int, 'b.int, 'c.int) + targetPartitionSchema = new StructType() + .add("b", IntegerType) + .add("c", IntegerType) + } + + private def checkProjectList(actual: Seq[Expression], expected: Seq[Expression]): Unit = { + // Remove aliases since we have no control on their exprId. + val withoutAliases = actual.map { + case alias: Alias => alias.child + case other => other + } + assert(withoutAliases === expected) + } + + Seq(true, false).foreach { caseSensitive => + val rule = DataSourceAnalysis(SimpleCatalystConf(caseSensitive)) + test( + s"convertStaticPartitions only handle INSERT having at least static partitions " + + s"(caseSensitive: $caseSensitive)") { + intercept[AssertionError] { + rule.convertStaticPartitions( + sourceAttributes = Seq('e.int, 'f.int), + providedPartitions = Map("b" -> None, "c" -> None), + targetAttributes = targetAttributes, + targetPartitionSchema = targetPartitionSchema) + } + } + + test(s"Missing columns (caseSensitive: $caseSensitive)") { + // Missing columns. + intercept[AnalysisException] { + rule.convertStaticPartitions( + sourceAttributes = Seq('e.int), + providedPartitions = Map("b" -> Some("1"), "c" -> None), + targetAttributes = targetAttributes, + targetPartitionSchema = targetPartitionSchema) + } + } + + test(s"Missing partitioning columns (caseSensitive: $caseSensitive)") { + // Missing partitioning columns. + intercept[AnalysisException] { + rule.convertStaticPartitions( + sourceAttributes = Seq('e.int, 'f.int), + providedPartitions = Map("b" -> Some("1")), + targetAttributes = targetAttributes, + targetPartitionSchema = targetPartitionSchema) + } + + // Missing partitioning columns. + intercept[AnalysisException] { + rule.convertStaticPartitions( + sourceAttributes = Seq('e.int, 'f.int, 'g.int), + providedPartitions = Map("b" -> Some("1")), + targetAttributes = targetAttributes, + targetPartitionSchema = targetPartitionSchema) + } + + // Wrong partitioning columns. + intercept[AnalysisException] { + rule.convertStaticPartitions( + sourceAttributes = Seq('e.int, 'f.int), + providedPartitions = Map("b" -> Some("1"), "d" -> None), + targetAttributes = targetAttributes, + targetPartitionSchema = targetPartitionSchema) + } + } + + test(s"Wrong partitioning columns (caseSensitive: $caseSensitive)") { + // Wrong partitioning columns. 
+ intercept[AnalysisException] { + rule.convertStaticPartitions( + sourceAttributes = Seq('e.int, 'f.int), + providedPartitions = Map("b" -> Some("1"), "d" -> Some("2")), + targetAttributes = targetAttributes, + targetPartitionSchema = targetPartitionSchema) + } + + // Wrong partitioning columns. + intercept[AnalysisException] { + rule.convertStaticPartitions( + sourceAttributes = Seq('e.int), + providedPartitions = Map("b" -> Some("1"), "c" -> Some("3"), "d" -> Some("2")), + targetAttributes = targetAttributes, + targetPartitionSchema = targetPartitionSchema) + } + + if (caseSensitive) { + // Wrong partitioning columns. + intercept[AnalysisException] { + rule.convertStaticPartitions( + sourceAttributes = Seq('e.int, 'f.int), + providedPartitions = Map("b" -> Some("1"), "C" -> Some("3")), + targetAttributes = targetAttributes, + targetPartitionSchema = targetPartitionSchema) + } + } + } + + test( + s"Static partitions need to appear before dynamic partitions" + + s" (caseSensitive: $caseSensitive)") { + // Static partitions need to appear before dynamic partitions. + intercept[AnalysisException] { + rule.convertStaticPartitions( + sourceAttributes = Seq('e.int, 'f.int), + providedPartitions = Map("b" -> None, "c" -> Some("3")), + targetAttributes = targetAttributes, + targetPartitionSchema = targetPartitionSchema) + } + } + + test(s"All static partitions (caseSensitive: $caseSensitive)") { + if (!caseSensitive) { + val nonPartitionedAttributes = Seq('e.int, 'f.int) + val expected = nonPartitionedAttributes ++ + Seq(Cast(Literal("1"), IntegerType), Cast(Literal("3"), IntegerType)) + val actual = rule.convertStaticPartitions( + sourceAttributes = nonPartitionedAttributes, + providedPartitions = Map("b" -> Some("1"), "C" -> Some("3")), + targetAttributes = targetAttributes, + targetPartitionSchema = targetPartitionSchema) + checkProjectList(actual, expected) + } + + { + val nonPartitionedAttributes = Seq('e.int, 'f.int) + val expected = nonPartitionedAttributes ++ + Seq(Cast(Literal("1"), IntegerType), Cast(Literal("3"), IntegerType)) + val actual = rule.convertStaticPartitions( + sourceAttributes = nonPartitionedAttributes, + providedPartitions = Map("b" -> Some("1"), "c" -> Some("3")), + targetAttributes = targetAttributes, + targetPartitionSchema = targetPartitionSchema) + checkProjectList(actual, expected) + } + + // Test the case having a single static partition column. 
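These `convertStaticPartitions` cases mirror the INSERT shapes a user can write against a table partitioned by `b` and `c`. As a hedged illustration (table names `target` and `src` are invented for this sketch), the corresponding SQL is roughly:

```scala
// Sketch of the INSERT shapes behind the cases above. Static partition values must
// precede dynamic ones; the rule appends casted literals for the static values
// ("1", "3") to the query's projection.
spark.sql("INSERT INTO TABLE target PARTITION (b = 1, c) SELECT e, f, g FROM src")   // static b, dynamic c
spark.sql("INSERT INTO TABLE target PARTITION (b = 1, c = 3) SELECT e, f FROM src")  // all static
```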
+ { + val nonPartitionedAttributes = Seq('e.int, 'f.int) + val expected = nonPartitionedAttributes ++ Seq(Cast(Literal("1"), IntegerType)) + val actual = rule.convertStaticPartitions( + sourceAttributes = nonPartitionedAttributes, + providedPartitions = Map("b" -> Some("1")), + targetAttributes = Seq('a.int, 'd.int, 'b.int), + targetPartitionSchema = new StructType().add("b", IntegerType)) + checkProjectList(actual, expected) + } + } + + test(s"Static partition and dynamic partition (caseSensitive: $caseSensitive)") { + val nonPartitionedAttributes = Seq('e.int, 'f.int) + val dynamicPartitionAttributes = Seq('g.int) + val expected = + nonPartitionedAttributes ++ + Seq(Cast(Literal("1"), IntegerType)) ++ + dynamicPartitionAttributes + val actual = rule.convertStaticPartitions( + sourceAttributes = nonPartitionedAttributes ++ dynamicPartitionAttributes, + providedPartitions = Map("b" -> Some("1"), "c" -> None), + targetAttributes = targetAttributes, + targetPartitionSchema = targetPartitionSchema) + checkProjectList(actual, expected) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala index 92061133cd49b..cc77d3c4b91ac 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala @@ -18,20 +18,12 @@ package org.apache.spark.sql.sources import org.apache.spark.sql._ -import org.apache.spark.sql.internal.SQLConf private[sql] abstract class DataSourceTest extends QueryTest { - // We want to test some edge cases. - protected lazy val caseInsensitiveContext: SQLContext = { - val ctx = new SQLContext(sqlContext.sparkContext) - ctx.setConf(SQLConf.CASE_SENSITIVE, false) - ctx - } - protected def sqlTest(sqlString: String, expectedAnswer: Seq[Row]) { test(sqlString) { - checkAnswer(caseInsensitiveContext.sql(sqlString), expectedAnswer) + checkAnswer(spark.sql(sqlString), expectedAnswer) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala index 51d04f2f4efc6..be56c964a18f8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -40,7 +40,7 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sparkSession: S extends BaseRelation with PrunedFilteredScan { - override def sqlContext: SQLContext = sparkSession.wrapped + override def sqlContext: SQLContext = sparkSession.sqlContext override def schema: StructType = StructType( @@ -133,13 +133,13 @@ object ColumnsRequired { } class FilteredScanSuite extends DataSourceTest with SharedSQLContext with PredicateHelper { - protected override lazy val sql = caseInsensitiveContext.sql _ + protected override lazy val sql = spark.sql _ override def beforeAll(): Unit = { super.beforeAll() sql( """ - |CREATE TEMPORARY TABLE oneToTenFiltered + |CREATE TEMPORARY VIEW oneToTenFiltered |USING org.apache.spark.sql.sources.FilteredScanSource |OPTIONS ( | from '1', @@ -310,7 +310,7 @@ class FilteredScanSuite extends DataSourceTest with SharedSQLContext with Predic test(s"PushDown Returns $expectedCount: $sqlString") { // These tests check a particular plan, disable whole stage codegen. 
- caseInsensitiveContext.conf.setConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED, false) + spark.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, false) try { val queryExecution = sql(sqlString).queryExecution val rawPlan = queryExecution.executedPlan.collect { @@ -322,7 +322,7 @@ class FilteredScanSuite extends DataSourceTest with SharedSQLContext with Predic val rawCount = rawPlan.execute().count() assert(ColumnsRequired.set === requiredColumnNames) - val table = caseInsensitiveContext.table("oneToTenFiltered") + val table = spark.table("oneToTenFiltered") val relation = table.queryExecution.logical.collectFirst { case LogicalRelation(r, _, _) => r }.get @@ -337,7 +337,7 @@ class FilteredScanSuite extends DataSourceTest with SharedSQLContext with Predic queryExecution) } } finally { - caseInsensitiveContext.conf.setConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED, + spark.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, SQLConf.WHOLESTAGE_CODEGEN_ENABLED.defaultValue.get) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index 5ac39f54b91ce..6454d716ec0db 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -20,18 +20,19 @@ package org.apache.spark.sql.sources import java.io.File import org.apache.spark.sql.{AnalysisException, Row} +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils class InsertSuite extends DataSourceTest with SharedSQLContext { - protected override lazy val sql = caseInsensitiveContext.sql _ + protected override lazy val sql = spark.sql _ private var path: File = null override def beforeAll(): Unit = { super.beforeAll() path = Utils.createTempDir() val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}""")) - caseInsensitiveContext.read.json(rdd).registerTempTable("jt") + spark.read.json(rdd).createOrReplaceTempView("jt") sql( s""" |CREATE TEMPORARY TABLE jsonTable (a int, b string) @@ -44,8 +45,8 @@ class InsertSuite extends DataSourceTest with SharedSQLContext { override def afterAll(): Unit = { try { - caseInsensitiveContext.dropTempTable("jsonTable") - caseInsensitiveContext.dropTempTable("jt") + spark.catalog.dropTempView("jsonTable") + spark.catalog.dropTempView("jt") Utils.deleteRecursively(path) } finally { super.afterAll() @@ -87,15 +88,13 @@ class InsertSuite extends DataSourceTest with SharedSQLContext { } test("SELECT clause generating a different number of columns is not allowed.") { - val message = intercept[RuntimeException] { + val message = intercept[AnalysisException] { sql( s""" |INSERT OVERWRITE TABLE jsonTable SELECT a FROM jt """.stripMargin) }.getMessage - assert( - message.contains("generates the same number of columns as its schema"), - "SELECT clause generating a different number of columns should not be not allowed." + assert(message.contains("the number of columns are different") ) } @@ -111,7 +110,7 @@ class InsertSuite extends DataSourceTest with SharedSQLContext { // Writing the table to less part files. 
val rdd1 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}"""), 5) - caseInsensitiveContext.read.json(rdd1).registerTempTable("jt1") + spark.read.json(rdd1).createOrReplaceTempView("jt1") sql( s""" |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt1 @@ -123,7 +122,7 @@ class InsertSuite extends DataSourceTest with SharedSQLContext { // Writing the table to more part files. val rdd2 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}"""), 10) - caseInsensitiveContext.read.json(rdd2).registerTempTable("jt2") + spark.read.json(rdd2).createOrReplaceTempView("jt2") sql( s""" |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt2 @@ -142,8 +141,8 @@ class InsertSuite extends DataSourceTest with SharedSQLContext { (1 to 10).map(i => Row(i * 10, s"str$i")) ) - caseInsensitiveContext.dropTempTable("jt1") - caseInsensitiveContext.dropTempTable("jt2") + spark.catalog.dropTempView("jt1") + spark.catalog.dropTempView("jt2") } test("INSERT INTO JSONRelation for now") { @@ -185,7 +184,7 @@ class InsertSuite extends DataSourceTest with SharedSQLContext { |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt """.stripMargin) // Cached Query Execution - caseInsensitiveContext.cacheTable("jsonTable") + spark.catalog.cacheTable("jsonTable") assertCached(sql("SELECT * FROM jsonTable")) checkAnswer( sql("SELECT * FROM jsonTable"), @@ -226,7 +225,7 @@ class InsertSuite extends DataSourceTest with SharedSQLContext { // sql("SELECT a * 2, b FROM jt").collect()) // // // Verify uncaching -// caseInsensitiveContext.uncacheTable("jsonTable") +// spark.catalog.uncacheTable("jsonTable") // assertCached(sql("SELECT * FROM jsonTable"), 0) } @@ -257,6 +256,30 @@ class InsertSuite extends DataSourceTest with SharedSQLContext { "It is not allowed to insert into a table that is not an InsertableRelation." 
) - caseInsensitiveContext.dropTempTable("oneToTen") + spark.catalog.dropTempView("oneToTen") + } + + test("SPARK-15824 - Execute an INSERT wrapped in a WITH statement immediately") { + withTable("target", "target2") { + sql(s"CREATE TABLE target(a INT, b STRING) USING JSON") + sql("WITH tbl AS (SELECT * FROM jt) INSERT OVERWRITE TABLE target SELECT a, b FROM tbl") + checkAnswer( + sql("SELECT a, b FROM target"), + sql("SELECT a, b FROM jt") + ) + + sql(s"CREATE TABLE target2(a INT, b STRING) USING JSON") + val e = sql( + """ + |WITH tbl AS (SELECT * FROM jt) + |FROM tbl + |INSERT INTO target2 SELECT a, b WHERE a <= 5 + |INSERT INTO target2 SELECT a, b WHERE a > 5 + """.stripMargin) + checkAnswer( + sql("SELECT a, b FROM target2"), + sql("SELECT a, b FROM jt") + ) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala index a9b1970a7c393..a2decadbe0444 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala @@ -29,11 +29,11 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { val path = Utils.createTempDir() path.delete() - val df = sqlContext.range(100).select($"id", lit(1).as("data")) + val df = spark.range(100).select($"id", lit(1).as("data")) df.write.partitionBy("id").save(path.getCanonicalPath) checkAnswer( - sqlContext.read.load(path.getCanonicalPath), + spark.read.load(path.getCanonicalPath), (0 to 99).map(Row(1, _)).toSeq) Utils.deleteRecursively(path) @@ -43,12 +43,12 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { val path = Utils.createTempDir() path.delete() - val base = sqlContext.range(100) + val base = spark.range(100) val df = base.union(base).select($"id", lit(1).as("data")) df.write.partitionBy("id").save(path.getCanonicalPath) checkAnswer( - sqlContext.read.load(path.getCanonicalPath), + spark.read.load(path.getCanonicalPath), (0 to 99).map(Row(1, _)).toSeq ++ (0 to 99).map(Row(1, _)).toSeq) Utils.deleteRecursively(path) @@ -58,7 +58,7 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { withTempPath { f => val path = f.getAbsolutePath Seq(1 -> "a").toDF("i", "j").write.partitionBy("i").parquet(path) - assert(sqlContext.read.parquet(path).schema.map(_.name) == Seq("j", "i")) + assert(spark.read.parquet(path).schema.map(_.name) == Seq("j", "i")) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala index cd0256db43aad..fb6123d1cc4b5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala @@ -37,7 +37,7 @@ case class SimplePrunedScan(from: Int, to: Int)(@transient val sparkSession: Spa extends BaseRelation with PrunedScan { - override def sqlContext: SQLContext = sparkSession.wrapped + override def sqlContext: SQLContext = sparkSession.sqlContext override def schema: StructType = StructType( @@ -56,13 +56,13 @@ case class SimplePrunedScan(from: Int, to: Int)(@transient val sparkSession: Spa } class PrunedScanSuite extends DataSourceTest with SharedSQLContext { - protected override lazy val sql = caseInsensitiveContext.sql _ + protected override lazy val sql = spark.sql _ override def beforeAll(): Unit = { super.beforeAll() sql( """ 
- |CREATE TEMPORARY TABLE oneToTenPruned + |CREATE TEMPORARY VIEW oneToTenPruned |USING org.apache.spark.sql.sources.PrunedScanSource |OPTIONS ( | from '1', @@ -122,7 +122,7 @@ class PrunedScanSuite extends DataSourceTest with SharedSQLContext { test(s"Columns output ${expectedColumns.mkString(",")}: $sqlString") { // These tests check a particular plan, disable whole stage codegen. - caseInsensitiveContext.conf.setConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED, false) + spark.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, false) try { val queryExecution = sql(sqlString).queryExecution val rawPlan = queryExecution.executedPlan.collect { @@ -145,7 +145,7 @@ class PrunedScanSuite extends DataSourceTest with SharedSQLContext { fail(s"Wrong output row. Got $rawOutput\n$queryExecution") } } finally { - caseInsensitiveContext.conf.setConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED, + spark.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, SQLConf.WHOLESTAGE_CODEGEN_ENABLED.defaultValue.get) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala index 4f6df5441736e..76ffb949f1293 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.sources import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.execution.datasources.DataSource class ResolvedDataSourceSuite extends SparkFunSuite { @@ -27,53 +28,62 @@ class ResolvedDataSourceSuite extends SparkFunSuite { test("jdbc") { assert( getProvidingClass("jdbc") === - classOf[org.apache.spark.sql.execution.datasources.jdbc.DefaultSource]) + classOf[org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider]) assert( getProvidingClass("org.apache.spark.sql.execution.datasources.jdbc") === - classOf[org.apache.spark.sql.execution.datasources.jdbc.DefaultSource]) + classOf[org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider]) assert( getProvidingClass("org.apache.spark.sql.jdbc") === - classOf[org.apache.spark.sql.execution.datasources.jdbc.DefaultSource]) + classOf[org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider]) } test("json") { assert( getProvidingClass("json") === - classOf[org.apache.spark.sql.execution.datasources.json.DefaultSource]) + classOf[org.apache.spark.sql.execution.datasources.json.JsonFileFormat]) assert( getProvidingClass("org.apache.spark.sql.execution.datasources.json") === - classOf[org.apache.spark.sql.execution.datasources.json.DefaultSource]) + classOf[org.apache.spark.sql.execution.datasources.json.JsonFileFormat]) assert( getProvidingClass("org.apache.spark.sql.json") === - classOf[org.apache.spark.sql.execution.datasources.json.DefaultSource]) + classOf[org.apache.spark.sql.execution.datasources.json.JsonFileFormat]) } test("parquet") { assert( getProvidingClass("parquet") === - classOf[org.apache.spark.sql.execution.datasources.parquet.DefaultSource]) + classOf[org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat]) assert( getProvidingClass("org.apache.spark.sql.execution.datasources.parquet") === - classOf[org.apache.spark.sql.execution.datasources.parquet.DefaultSource]) + classOf[org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat]) assert( getProvidingClass("org.apache.spark.sql.parquet") === - 
classOf[org.apache.spark.sql.execution.datasources.parquet.DefaultSource]) + classOf[org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat]) + } + + test("csv") { + assert( + getProvidingClass("csv") === + classOf[org.apache.spark.sql.execution.datasources.csv.CSVFileFormat]) + assert( + getProvidingClass("com.databricks.spark.csv") === + classOf[org.apache.spark.sql.execution.datasources.csv.CSVFileFormat]) } test("error message for unknown data sources") { - val error1 = intercept[ClassNotFoundException] { + val error1 = intercept[AnalysisException] { getProvidingClass("avro") } - assert(error1.getMessage.contains("spark-packages")) + assert(error1.getMessage.contains("Failed to find data source: avro.")) - val error2 = intercept[ClassNotFoundException] { + val error2 = intercept[AnalysisException] { getProvidingClass("com.databricks.spark.avro") } - assert(error2.getMessage.contains("spark-packages")) + assert(error2.getMessage.contains("Failed to find data source: com.databricks.spark.avro.")) val error3 = intercept[ClassNotFoundException] { getProvidingClass("asfdwefasdfasdf") } - assert(error3.getMessage.contains("spark-packages")) + assert(error3.getMessage.contains("Failed to find data source: asfdwefasdfasdf.")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala index bb2c54aa64977..b1756c27fae0a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala @@ -28,26 +28,26 @@ import org.apache.spark.sql.types._ import org.apache.spark.util.Utils class SaveLoadSuite extends DataSourceTest with SharedSQLContext with BeforeAndAfter { - protected override lazy val sql = caseInsensitiveContext.sql _ + protected override lazy val sql = spark.sql _ private var originalDefaultSource: String = null private var path: File = null private var df: DataFrame = null override def beforeAll(): Unit = { super.beforeAll() - originalDefaultSource = caseInsensitiveContext.conf.defaultDataSourceName + originalDefaultSource = spark.sessionState.conf.defaultDataSourceName path = Utils.createTempDir() path.delete() val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) - df = caseInsensitiveContext.read.json(rdd) - df.registerTempTable("jsonTable") + df = spark.read.json(rdd) + df.createOrReplaceTempView("jsonTable") } override def afterAll(): Unit = { try { - caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) + spark.conf.set(SQLConf.DEFAULT_DATA_SOURCE_NAME.key, originalDefaultSource) } finally { super.afterAll() } @@ -58,45 +58,42 @@ class SaveLoadSuite extends DataSourceTest with SharedSQLContext with BeforeAndA } def checkLoad(expectedDF: DataFrame = df, tbl: String = "jsonTable"): Unit = { - caseInsensitiveContext.conf.setConf( - SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") - checkAnswer(caseInsensitiveContext.read.load(path.toString), expectedDF.collect()) + spark.conf.set(SQLConf.DEFAULT_DATA_SOURCE_NAME.key, "org.apache.spark.sql.json") + checkAnswer(spark.read.load(path.toString), expectedDF.collect()) // Test if we can pick up the data source name passed in load. 
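The load checks here rely on the default data source only applying when no format is given on the reader; an explicit `.format(...)` always wins. A hedged one-line illustration (the path is a placeholder for previously saved JSON data):

```scala
// Sketch: an explicit .format(...) overrides the default data source, which is
// what checkLoad verifies after pointing the default at a bogus name.
import org.apache.spark.sql.internal.SQLConf

spark.conf.set(SQLConf.DEFAULT_DATA_SOURCE_NAME.key, "not a source name")
val reloaded = spark.read.format("json").load("/tmp/saved-json")  // still resolves to JSON
```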
- caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") - checkAnswer(caseInsensitiveContext.read.format("json").load(path.toString), + spark.conf.set(SQLConf.DEFAULT_DATA_SOURCE_NAME.key, "not a source name") + checkAnswer(spark.read.format("json").load(path.toString), expectedDF.collect()) - checkAnswer(caseInsensitiveContext.read.format("json").load(path.toString), + checkAnswer(spark.read.format("json").load(path.toString), expectedDF.collect()) val schema = StructType(StructField("b", StringType, true) :: Nil) checkAnswer( - caseInsensitiveContext.read.format("json").schema(schema).load(path.toString), + spark.read.format("json").schema(schema).load(path.toString), sql(s"SELECT b FROM $tbl").collect()) } test("save with path and load") { - caseInsensitiveContext.conf.setConf( - SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") + spark.conf.set(SQLConf.DEFAULT_DATA_SOURCE_NAME.key, "org.apache.spark.sql.json") df.write.save(path.toString) checkLoad() } test("save with string mode and path, and load") { - caseInsensitiveContext.conf.setConf( - SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") + spark.conf.set(SQLConf.DEFAULT_DATA_SOURCE_NAME.key, "org.apache.spark.sql.json") path.createNewFile() df.write.mode("overwrite").save(path.toString) checkLoad() } test("save with path and datasource, and load") { - caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") + spark.conf.set(SQLConf.DEFAULT_DATA_SOURCE_NAME.key, "not a source name") df.write.json(path.toString) checkLoad() } test("save with data source and options, and load") { - caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") + spark.conf.set(SQLConf.DEFAULT_DATA_SOURCE_NAME.key, "not a source name") df.write.mode(SaveMode.ErrorIfExists).json(path.toString) checkLoad() } @@ -123,7 +120,7 @@ class SaveLoadSuite extends DataSourceTest with SharedSQLContext with BeforeAndA // verify the append mode df.write.mode(SaveMode.Append).json(path.toString) val df2 = df.union(df) - df2.registerTempTable("jsonTable2") + df2.createOrReplaceTempView("jsonTable2") checkLoad(df2, "jsonTable2") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index 34b8726a922fb..0fa0706a10b1b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -38,7 +38,7 @@ class SimpleScanSource extends RelationProvider { case class SimpleScan(from: Int, to: Int)(@transient val sparkSession: SparkSession) extends BaseRelation with TableScan { - override def sqlContext: SQLContext = sparkSession.wrapped + override def sqlContext: SQLContext = sparkSession.sqlContext override def schema: StructType = StructType(StructField("i", IntegerType, nullable = false) :: Nil) @@ -70,7 +70,7 @@ case class AllDataTypesScan( extends BaseRelation with TableScan { - override def sqlContext: SQLContext = sparkSession.wrapped + override def sqlContext: SQLContext = sparkSession.sqlContext override def schema: StructType = userSpecifiedSchema @@ -106,7 +106,7 @@ case class AllDataTypesScan( } class TableScanSuite extends DataSourceTest with SharedSQLContext { - protected override lazy val sql = caseInsensitiveContext.sql _ + protected override lazy val sql = spark.sql _ private lazy val tableWithSchemaExpected = (1 to 10).map { i => 
Row( @@ -137,7 +137,7 @@ class TableScanSuite extends DataSourceTest with SharedSQLContext { super.beforeAll() sql( """ - |CREATE TEMPORARY TABLE oneToTen + |CREATE TEMPORARY VIEW oneToTen |USING org.apache.spark.sql.sources.SimpleScanSource |OPTIONS ( | From '1', @@ -149,7 +149,7 @@ class TableScanSuite extends DataSourceTest with SharedSQLContext { sql( """ - |CREATE TEMPORARY TABLE tableWithSchema ( + |CREATE TEMPORARY VIEW tableWithSchema ( |`string$%Field` stRIng, |binaryField binary, |`booleanField` boolean, @@ -241,7 +241,7 @@ class TableScanSuite extends DataSourceTest with SharedSQLContext { Nil ) - assert(expectedSchema == caseInsensitiveContext.table("tableWithSchema").schema) + assert(expectedSchema == spark.table("tableWithSchema").schema) checkAnswer( sql( @@ -297,7 +297,7 @@ class TableScanSuite extends DataSourceTest with SharedSQLContext { test("Caching") { // Cached Query Execution - caseInsensitiveContext.cacheTable("oneToTen") + spark.catalog.cacheTable("oneToTen") assertCached(sql("SELECT * FROM oneToTen")) checkAnswer( sql("SELECT * FROM oneToTen"), @@ -325,14 +325,14 @@ class TableScanSuite extends DataSourceTest with SharedSQLContext { (2 to 10).map(i => Row(i, i - 1)).toSeq) // Verify uncaching - caseInsensitiveContext.uncacheTable("oneToTen") + spark.catalog.uncacheTable("oneToTen") assertCached(sql("SELECT * FROM oneToTen"), 0) } test("defaultSource") { sql( """ - |CREATE TEMPORARY TABLE oneToTenDef + |CREATE TEMPORARY VIEW oneToTenDef |USING org.apache.spark.sql.sources |OPTIONS ( | from '1', @@ -351,7 +351,7 @@ class TableScanSuite extends DataSourceTest with SharedSQLContext { val schemaNotAllowed = intercept[Exception] { sql( """ - |CREATE TEMPORARY TABLE relationProvierWithSchema (i int) + |CREATE TEMPORARY VIEW relationProvierWithSchema (i int) |USING org.apache.spark.sql.sources.SimpleScanSource |OPTIONS ( | From '1', @@ -364,7 +364,7 @@ class TableScanSuite extends DataSourceTest with SharedSQLContext { val schemaNeeded = intercept[Exception] { sql( """ - |CREATE TEMPORARY TABLE schemaRelationProvierWithoutSchema + |CREATE TEMPORARY VIEW schemaRelationProvierWithoutSchema |USING org.apache.spark.sql.sources.AllDataTypesScanSource |OPTIONS ( | From '1', @@ -378,7 +378,7 @@ class TableScanSuite extends DataSourceTest with SharedSQLContext { test("SPARK-5196 schema field with comment") { sql( """ - |CREATE TEMPORARY TABLE student(name string comment "SN", age int comment "SA", grade int) + |CREATE TEMPORARY VIEW student(name string comment "SN", age int comment "SA", grade int) |USING org.apache.spark.sql.sources.AllDataTypesScanSource |OPTIONS ( | from '1', diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQuerySuite.scala deleted file mode 100644 index 3be0ea481dc53..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQuerySuite.scala +++ /dev/null @@ -1,140 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.streaming - -import org.apache.spark.SparkException -import org.apache.spark.sql.StreamTest -import org.apache.spark.sql.execution.streaming.{CompositeOffset, LongOffset, MemoryStream, StreamExecution} -import org.apache.spark.sql.test.SharedSQLContext - -class ContinuousQuerySuite extends StreamTest with SharedSQLContext { - - import AwaitTerminationTester._ - import testImplicits._ - - testQuietly("lifecycle states and awaitTermination") { - val inputData = MemoryStream[Int] - val mapped = inputData.toDS().map { 6 / _} - - testStream(mapped)( - AssertOnQuery(_.isActive === true), - AssertOnQuery(_.exception.isEmpty), - AddData(inputData, 1, 2), - CheckAnswer(6, 3), - TestAwaitTermination(ExpectBlocked), - TestAwaitTermination(ExpectBlocked, timeoutMs = 2000), - TestAwaitTermination(ExpectNotBlocked, timeoutMs = 10, expectedReturnValue = false), - StopStream, - AssertOnQuery(_.isActive === false), - AssertOnQuery(_.exception.isEmpty), - TestAwaitTermination(ExpectNotBlocked), - TestAwaitTermination(ExpectNotBlocked, timeoutMs = 2000, expectedReturnValue = true), - TestAwaitTermination(ExpectNotBlocked, timeoutMs = 10, expectedReturnValue = true), - StartStream, - AssertOnQuery(_.isActive === true), - AddData(inputData, 0), - ExpectFailure[SparkException], - AssertOnQuery(_.isActive === false), - TestAwaitTermination(ExpectException[SparkException]), - TestAwaitTermination(ExpectException[SparkException], timeoutMs = 2000), - TestAwaitTermination(ExpectException[SparkException], timeoutMs = 10), - AssertOnQuery( - q => - q.exception.get.startOffset.get === q.committedOffsets.toCompositeOffset(Seq(inputData)), - "incorrect start offset on exception") - ) - } - - testQuietly("source and sink statuses") { - val inputData = MemoryStream[Int] - val mapped = inputData.toDS().map(6 / _) - - testStream(mapped)( - AssertOnQuery(_.sourceStatuses.length === 1), - AssertOnQuery(_.sourceStatuses(0).description.contains("Memory")), - AssertOnQuery(_.sourceStatuses(0).offset === None), - AssertOnQuery(_.sinkStatus.description.contains("Memory")), - AssertOnQuery(_.sinkStatus.offset === new CompositeOffset(None :: Nil)), - AddData(inputData, 1, 2), - CheckAnswer(6, 3), - AssertOnQuery(_.sourceStatuses(0).offset === Some(LongOffset(0))), - AssertOnQuery(_.sinkStatus.offset === CompositeOffset.fill(LongOffset(0))), - AddData(inputData, 1, 2), - CheckAnswer(6, 3, 6, 3), - AssertOnQuery(_.sourceStatuses(0).offset === Some(LongOffset(1))), - AssertOnQuery(_.sinkStatus.offset === CompositeOffset.fill(LongOffset(1))), - AddData(inputData, 0), - ExpectFailure[SparkException], - AssertOnQuery(_.sourceStatuses(0).offset === Some(LongOffset(2))), - AssertOnQuery(_.sinkStatus.offset === CompositeOffset.fill(LongOffset(1))) - ) - } - - /** - * A [[StreamAction]] to test the behavior of `ContinuousQuery.awaitTermination()`. - * - * @param expectedBehavior Expected behavior (not blocked, blocked, or exception thrown) - * @param timeoutMs Timeout in milliseconds - * When timeoutMs <= 0, awaitTermination() is tested (i.e. 
w/o timeout) - * When timeoutMs > 0, awaitTermination(timeoutMs) is tested - * @param expectedReturnValue Expected return value when awaitTermination(timeoutMs) is used - */ - case class TestAwaitTermination( - expectedBehavior: ExpectedBehavior, - timeoutMs: Int = -1, - expectedReturnValue: Boolean = false - ) extends AssertOnQuery( - TestAwaitTermination.assertOnQueryCondition(expectedBehavior, timeoutMs, expectedReturnValue), - "Error testing awaitTermination behavior" - ) { - override def toString(): String = { - s"TestAwaitTermination($expectedBehavior, timeoutMs = $timeoutMs, " + - s"expectedReturnValue = $expectedReturnValue)" - } - } - - object TestAwaitTermination { - - /** - * Tests the behavior of `ContinuousQuery.awaitTermination`. - * - * @param expectedBehavior Expected behavior (not blocked, blocked, or exception thrown) - * @param timeoutMs Timeout in milliseconds - * When timeoutMs <= 0, awaitTermination() is tested (i.e. w/o timeout) - * When timeoutMs > 0, awaitTermination(timeoutMs) is tested - * @param expectedReturnValue Expected return value when awaitTermination(timeoutMs) is used - */ - def assertOnQueryCondition( - expectedBehavior: ExpectedBehavior, - timeoutMs: Int, - expectedReturnValue: Boolean - )(q: StreamExecution): Boolean = { - - def awaitTermFunc(): Unit = { - if (timeoutMs <= 0) { - q.awaitTermination() - } else { - val returnedValue = q.awaitTermination(timeoutMs) - assert(returnedValue === expectedReturnValue, "Returned value does not match expected") - } - } - AwaitTerminationTester.test(expectedBehavior, awaitTermFunc) - true // If the control reached here, then everything worked as expected - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala index 8cf5dedabcee1..9d0a2b3d5b462 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -17,33 +17,271 @@ package org.apache.spark.sql.streaming -import org.apache.spark.sql.StreamTest -import org.apache.spark.sql.execution.streaming.MemoryStream +import java.io.File + +import org.apache.commons.io.FileUtils +import org.apache.commons.io.filefilter.{DirectoryFileFilter, RegexFileFilter} + +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.execution.DataSourceScanExec +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.streaming.{FileStreamSinkWriter, MemoryStream, MetadataLogFileCatalog} +import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.util.Utils -class FileStreamSinkSuite extends StreamTest with SharedSQLContext { +class FileStreamSinkSuite extends StreamTest { import testImplicits._ - test("unpartitioned writing") { + + test("FileStreamSinkWriter - unpartitioned data") { + val path = Utils.createTempDir() + path.delete() + + val hadoopConf = spark.sparkContext.hadoopConfiguration + val fileFormat = new parquet.ParquetFileFormat() + + def writeRange(start: Int, end: Int, numPartitions: Int): Seq[String] = { + val df = spark + .range(start, end, 1, numPartitions) + .select($"id", lit(100).as("data")) + val writer = new FileStreamSinkWriter( + df, fileFormat, path.toString, partitionColumnNames = Nil, 
hadoopConf, Map.empty) + writer.write().map(_.path.stripPrefix("file://")) + } + + // Write and check whether new files are written correctly + val files1 = writeRange(0, 10, 2) + assert(files1.size === 2, s"unexpected number of files: $files1") + checkFilesExist(path, files1, "file not written") + checkAnswer(spark.read.load(path.getCanonicalPath), (0 until 10).map(Row(_, 100))) + + // Append and check whether new files are written correctly and old files still exist + val files2 = writeRange(10, 20, 3) + assert(files2.size === 3, s"unexpected number of files: $files2") + assert(files2.intersect(files1).isEmpty, "old files returned") + checkFilesExist(path, files2, s"New file not written") + checkFilesExist(path, files1, s"Old file not found") + checkAnswer(spark.read.load(path.getCanonicalPath), (0 until 20).map(Row(_, 100))) + } + + test("FileStreamSinkWriter - partitioned data") { + implicit val e = ExpressionEncoder[java.lang.Long] + val path = Utils.createTempDir() + path.delete() + + val hadoopConf = spark.sparkContext.hadoopConfiguration + val fileFormat = new parquet.ParquetFileFormat() + + def writeRange(start: Int, end: Int, numPartitions: Int): Seq[String] = { + val df = spark + .range(start, end, 1, numPartitions) + .flatMap(x => Iterator(x, x, x)).toDF("id") + .select($"id", lit(100).as("data1"), lit(1000).as("data2")) + + require(df.rdd.partitions.size === numPartitions) + val writer = new FileStreamSinkWriter( + df, fileFormat, path.toString, partitionColumnNames = Seq("id"), hadoopConf, Map.empty) + writer.write().map(_.path.stripPrefix("file://")) + } + + def checkOneFileWrittenPerKey(keys: Seq[Int], filesWritten: Seq[String]): Unit = { + keys.foreach { id => + assert( + filesWritten.count(_.contains(s"/id=$id/")) == 1, + s"no file for id=$id. 
all files: \n\t${filesWritten.mkString("\n\t")}" + ) + } + } + + // Write and check whether new files are written correctly + val files1 = writeRange(0, 10, 2) + assert(files1.size === 10, s"unexpected number of files:\n${files1.mkString("\n")}") + checkFilesExist(path, files1, "file not written") + checkOneFileWrittenPerKey(0 until 10, files1) + + val answer1 = (0 until 10).flatMap(x => Iterator(x, x, x)).map(Row(100, 1000, _)) + checkAnswer(spark.read.load(path.getCanonicalPath), answer1) + + // Append and check whether new files are written correctly and old files still exist + val files2 = writeRange(0, 20, 3) + assert(files2.size === 20, s"unexpected number of files:\n${files2.mkString("\n")}") + assert(files2.intersect(files1).isEmpty, "old files returned") + checkFilesExist(path, files2, s"New file not written") + checkFilesExist(path, files1, s"Old file not found") + checkOneFileWrittenPerKey(0 until 20, files2) + + val answer2 = (0 until 20).flatMap(x => Iterator(x, x, x)).map(Row(100, 1000, _)) + checkAnswer(spark.read.load(path.getCanonicalPath), answer1 ++ answer2) + } + + test("FileStreamSink - unpartitioned writing and batch reading") { val inputData = MemoryStream[Int] val df = inputData.toDF() val outputDir = Utils.createTempDir(namePrefix = "stream.output").getCanonicalPath val checkpointDir = Utils.createTempDir(namePrefix = "stream.checkpoint").getCanonicalPath - val query = - df.write - .format("parquet") - .option("checkpointLocation", checkpointDir) - .startStream(outputDir) + var query: StreamingQuery = null + + try { + query = + df.writeStream + .option("checkpointLocation", checkpointDir) + .format("parquet") + .start(outputDir) + + inputData.addData(1, 2, 3) + + failAfter(streamingTimeout) { + query.processAllAvailable() + } + + val outputDf = spark.read.parquet(outputDir).as[Int] + checkDatasetUnorderly(outputDf, 1, 2, 3) + + } finally { + if (query != null) { + query.stop() + } + } + } + + test("FileStreamSink - partitioned writing and batch reading") { + val inputData = MemoryStream[Int] + val ds = inputData.toDS() + + val outputDir = Utils.createTempDir(namePrefix = "stream.output").getCanonicalPath + val checkpointDir = Utils.createTempDir(namePrefix = "stream.checkpoint").getCanonicalPath + + var query: StreamingQuery = null + + try { + query = + ds.map(i => (i, i * 1000)) + .toDF("id", "value") + .writeStream + .partitionBy("id") + .option("checkpointLocation", checkpointDir) + .format("parquet") + .start(outputDir) + + inputData.addData(1, 2, 3) + failAfter(streamingTimeout) { + query.processAllAvailable() + } + + val outputDf = spark.read.parquet(outputDir) + val expectedSchema = new StructType() + .add(StructField("value", IntegerType)) + .add(StructField("id", IntegerType)) + assert(outputDf.schema === expectedSchema) + + // Verify that MetadataLogFileCatalog is being used and the correct partitioning schema has + // been inferred + val hadoopdFsRelations = outputDf.queryExecution.analyzed.collect { + case LogicalRelation(baseRelation, _, _) if baseRelation.isInstanceOf[HadoopFsRelation] => + baseRelation.asInstanceOf[HadoopFsRelation] + } + assert(hadoopdFsRelations.size === 1) + assert(hadoopdFsRelations.head.location.isInstanceOf[MetadataLogFileCatalog]) + assert(hadoopdFsRelations.head.partitionSchema.exists(_.name == "id")) + assert(hadoopdFsRelations.head.dataSchema.exists(_.name == "value")) + + // Verify the data is correctly read + checkDatasetUnorderly( + outputDf.as[(Int, Int)], + (1000, 1), (2000, 2), (3000, 3)) - inputData.addData(1, 2, 3) 
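The sink tests now go through the `writeStream` builder that replaces `df.write...startStream(...)`. A hedged sketch of that lifecycle outside the test harness (`streamingDf` stands for any streaming DataFrame; paths are placeholders):

```scala
// Sketch of the writeStream lifecycle used by the tests above.
val query = streamingDf.writeStream
  .format("parquet")
  .option("checkpointLocation", "/tmp/checkpoint")  // required for file sinks
  .start("/tmp/output")

query.processAllAvailable()  // drain whatever input is currently available (test-style)
query.stop()
```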
- failAfter(streamingTimeout) { query.processAllAvailable() } + /** Check some condition on the partitions of the FileScanRDD generated by a DF */ + def checkFileScanPartitions(df: DataFrame)(func: Seq[FilePartition] => Unit): Unit = { + val getFileScanRDD = df.queryExecution.executedPlan.collect { + case scan: DataSourceScanExec if scan.rdd.isInstanceOf[FileScanRDD] => + scan.rdd.asInstanceOf[FileScanRDD] + }.headOption.getOrElse { + fail(s"No FileScan in query\n${df.queryExecution}") + } + func(getFileScanRDD.filePartitions) + } - val outputDf = sqlContext.read.parquet(outputDir).as[Int] - checkDataset( - outputDf, - 1, 2, 3) + // Read without pruning + checkFileScanPartitions(outputDf) { partitions => + // There should be as many distinct partition values as there are distinct ids + assert(partitions.flatMap(_.files.map(_.partitionValues)).distinct.size === 3) + } + + // Read with pruning, should read only files in partition dir id=1 + checkFileScanPartitions(outputDf.filter("id = 1")) { partitions => + val filesToBeRead = partitions.flatMap(_.files) + assert(filesToBeRead.map(_.filePath).forall(_.contains("/id=1/"))) + assert(filesToBeRead.map(_.partitionValues).distinct.size === 1) + } + + // Read with pruning, should read only files in partition dir id=1 and id=2 + checkFileScanPartitions(outputDf.filter("id in (1,2)")) { partitions => + val filesToBeRead = partitions.flatMap(_.files) + assert(!filesToBeRead.map(_.filePath).exists(_.contains("/id=3/"))) + assert(filesToBeRead.map(_.partitionValues).distinct.size === 2) + } + } finally { + if (query != null) { + query.stop() + } + } + } + + test("FileStreamSink - supported formats") { + def testFormat(format: Option[String]): Unit = { + val inputData = MemoryStream[Int] + val ds = inputData.toDS() + + val outputDir = Utils.createTempDir(namePrefix = "stream.output").getCanonicalPath + val checkpointDir = Utils.createTempDir(namePrefix = "stream.checkpoint").getCanonicalPath + + var query: StreamingQuery = null + + try { + val writer = + ds.map(i => (i, i * 1000)) + .toDF("id", "value") + .writeStream + if (format.nonEmpty) { + writer.format(format.get) + } + query = writer + .option("checkpointLocation", checkpointDir) + .start(outputDir) + } finally { + if (query != null) { + query.stop() + } + } + } + + testFormat(None) // should not throw error as default format parquet when not specified + testFormat(Some("parquet")) + val e = intercept[UnsupportedOperationException] { + testFormat(Some("text")) + } + Seq("text", "not support", "stream").foreach { s => + assert(e.getMessage.contains(s)) + } + } + + private def checkFilesExist(dir: File, expectedFiles: Seq[String], msg: String): Unit = { + import scala.collection.JavaConverters._ + val files = + FileUtils.listFiles(dir, new RegexFileFilter("[^.]+"), DirectoryFileFilter.DIRECTORY) + .asScala + .map(_.getCanonicalPath) + .toSet + + expectedFiles.foreach { f => + assert(files.contains(f), + s"\n$msg\nexpected file:\n\t$f\nfound files:\n${files.mkString("\n\t")}") + } } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index 6b1ecd08c13c3..ca433d5b0701b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -19,14 +19,18 @@ package org.apache.spark.sql.streaming import java.io.File +import org.scalatest.PrivateMethodTester 
+import org.scalatest.time.SpanSugar._ + import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{StringType, StructField, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class FileStreamSourceTest extends StreamTest with SharedSQLContext { +class FileStreamSourceTest extends StreamTest with SharedSQLContext with PrivateMethodTester { import testImplicits._ @@ -57,7 +61,7 @@ class FileStreamSourceTest extends StreamTest with SharedSQLContext { addData(source) source.currentOffset + 1 } - logInfo(s"Added data to $source at offset $newOffset") + logInfo(s"Added file to $source at offset $newOffset") (source, newOffset) } @@ -68,8 +72,11 @@ class FileStreamSourceTest extends StreamTest with SharedSQLContext { extends AddFileData { override def addData(source: FileStreamSource): Unit = { - val file = Utils.tempFileWith(new File(tmp, "text")) - stringToFile(file, content).renameTo(new File(src, file.getName)) + val tempFile = Utils.tempFileWith(new File(tmp, "text")) + val finalFile = new File(src, tempFile.getName) + src.mkdirs() + require(stringToFile(tempFile, content).renameTo(finalFile)) + logInfo(s"Written text '$content' to file $finalFile") } } @@ -84,10 +91,20 @@ class FileStreamSourceTest extends StreamTest with SharedSQLContext { AddParquetFileData(seq.toDS().toDF(), src, tmp) } + /** Write parquet files in a temp dir, and move the individual files to the 'src' dir */ def writeToFile(df: DataFrame, src: File, tmp: File): Unit = { - val file = Utils.tempFileWith(new File(tmp, "parquet")) - df.write.parquet(file.getCanonicalPath) - file.renameTo(new File(src, file.getName)) + val tmpDir = Utils.tempFileWith(new File(tmp, "parquet")) + df.write.parquet(tmpDir.getCanonicalPath) + src.mkdirs() + tmpDir.listFiles().foreach { f => + f.renameTo(new File(src, s"${f.getName}")) + } + } + } + + case class DeleteFile(file: File) extends ExternalAction { + def runAction(): Unit = { + Utils.deleteRecursively(file) } } @@ -95,15 +112,15 @@ class FileStreamSourceTest extends StreamTest with SharedSQLContext { def createFileStream( format: String, path: String, - schema: Option[StructType] = None): DataFrame = { - + schema: Option[StructType] = None, + options: Map[String, String] = Map.empty): DataFrame = { val reader = if (schema.isDefined) { - sqlContext.read.format(format).schema(schema.get) + spark.readStream.format(format).schema(schema.get).options(options) } else { - sqlContext.read.format(format) + spark.readStream.format(format).options(options) } - reader.stream(path) + reader.load(path) } protected def getSourceFromFileStream(df: DataFrame): FileStreamSource = { @@ -129,10 +146,12 @@ class FileStreamSourceTest extends StreamTest with SharedSQLContext { val valueSchema = new StructType().add("value", StringType) } -class FileStreamSourceSuite extends FileStreamSourceTest with SharedSQLContext { +class FileStreamSourceSuite extends FileStreamSourceTest { import testImplicits._ + override val streamingTimeout = 20.seconds + /** Use `format` and `path` to create FileStreamSource via DataFrameReader */ private def createFileStreamSource( format: String, @@ -145,32 +164,55 @@ class FileStreamSourceSuite extends FileStreamSourceTest with SharedSQLContext { format: Option[String], path: Option[String], schema: Option[StructType] = None): StructType = { - val reader 
= sqlContext.read + val reader = spark.readStream format.foreach(reader.format) schema.foreach(reader.schema) val df = if (path.isDefined) { - reader.stream(path.get) + reader.load(path.get) } else { - reader.stream() + reader.load() } df.queryExecution.analyzed .collect { case s @ StreamingRelation(dataSource, _, _) => s.schema }.head } + // ============= Basic parameter exists tests ================ + test("FileStreamSource schema: no path") { - val e = intercept[IllegalArgumentException] { - createFileStreamSourceAndGetSchema(format = None, path = None, schema = None) + def testError(): Unit = { + val e = intercept[IllegalArgumentException] { + createFileStreamSourceAndGetSchema(format = None, path = None, schema = None) + } + assert(e.getMessage.contains("path")) // reason is path, not schema } - assert("'path' is not specified" === e.getMessage) + withSQLConf(SQLConf.STREAMING_SCHEMA_INFERENCE.key -> "false") { testError() } + withSQLConf(SQLConf.STREAMING_SCHEMA_INFERENCE.key -> "true") { testError() } } - test("FileStreamSource schema: path doesn't exist") { - intercept[AnalysisException] { - createFileStreamSourceAndGetSchema(format = None, path = Some("/a/b/c"), schema = None) + test("FileStreamSource schema: path doesn't exist (without schema) should throw exception") { + withTempDir { dir => + intercept[AnalysisException] { + val userSchema = new StructType().add(new StructField("value", IntegerType)) + val schema = createFileStreamSourceAndGetSchema( + format = None, path = Some(new File(dir, "1").getAbsolutePath), schema = None) + } } } + test("FileStreamSource schema: path doesn't exist (with schema) should throw exception") { + withTempDir { dir => + intercept[AnalysisException] { + val userSchema = new StructType().add(new StructField("value", IntegerType)) + val schema = createFileStreamSourceAndGetSchema( + format = None, path = Some(new File(dir, "1").getAbsolutePath), schema = Some(userSchema)) + } + } + } + + + // =============== Text file stream schema tests ================ + test("FileStreamSource schema: text, no existing files, no schema") { withTempDir { src => val schema = createFileStreamSourceAndGetSchema( @@ -198,23 +240,28 @@ class FileStreamSourceSuite extends FileStreamSourceTest with SharedSQLContext { } } - test("FileStreamSource schema: parquet, no existing files, no schema") { - withTempDir { src => - val e = intercept[AnalysisException] { - createFileStreamSourceAndGetSchema( - format = Some("parquet"), path = Some(new File(src, "1").getCanonicalPath), schema = None) - } - assert("Unable to infer schema. 
It must be specified manually.;" === e.getMessage) - } - } + // =============== Parquet file stream schema tests ================ test("FileStreamSource schema: parquet, existing files, no schema") { withTempDir { src => - Seq("a", "b", "c").toDS().as("userColumn").toDF() - .write.parquet(new File(src, "1").getCanonicalPath) - val schema = createFileStreamSourceAndGetSchema( - format = Some("parquet"), path = Some(src.getCanonicalPath), schema = None) - assert(schema === new StructType().add("value", StringType)) + Seq("a", "b", "c").toDS().as("userColumn").toDF().write + .mode(org.apache.spark.sql.SaveMode.Overwrite) + .parquet(src.getCanonicalPath) + + // Without schema inference, should throw error + withSQLConf(SQLConf.STREAMING_SCHEMA_INFERENCE.key -> "false") { + intercept[IllegalArgumentException] { + createFileStreamSourceAndGetSchema( + format = Some("parquet"), path = Some(src.getCanonicalPath), schema = None) + } + } + + // With schema inference, should infer correct schema + withSQLConf(SQLConf.STREAMING_SCHEMA_INFERENCE.key -> "true") { + val schema = createFileStreamSourceAndGetSchema( + format = Some("parquet"), path = Some(src.getCanonicalPath), schema = None) + assert(schema === new StructType().add("value", StringType)) + } } } @@ -229,22 +276,39 @@ class FileStreamSourceSuite extends FileStreamSourceTest with SharedSQLContext { } } + // =============== JSON file stream schema tests ================ + test("FileStreamSource schema: json, no existing files, no schema") { withTempDir { src => - val e = intercept[AnalysisException] { - createFileStreamSourceAndGetSchema( - format = Some("json"), path = Some(src.getCanonicalPath), schema = None) + withSQLConf(SQLConf.STREAMING_SCHEMA_INFERENCE.key -> "true") { + + val e = intercept[AnalysisException] { + createFileStreamSourceAndGetSchema( + format = Some("json"), path = Some(src.getCanonicalPath), schema = None) + } + assert("Unable to infer schema. It must be specified manually.;" === e.getMessage) } - assert("Unable to infer schema. 
It must be specified manually.;" === e.getMessage) } } test("FileStreamSource schema: json, existing files, no schema") { withTempDir { src => - stringToFile(new File(src, "1"), "{'c': '1'}\n{'c': '2'}\n{'c': '3'}") - val schema = createFileStreamSourceAndGetSchema( - format = Some("json"), path = Some(src.getCanonicalPath), schema = None) - assert(schema === new StructType().add("c", StringType)) + + // Without schema inference, should throw error + withSQLConf(SQLConf.STREAMING_SCHEMA_INFERENCE.key -> "false") { + intercept[IllegalArgumentException] { + createFileStreamSourceAndGetSchema( + format = Some("json"), path = Some(src.getCanonicalPath), schema = None) + } + } + + // With schema inference, should infer correct schema + withSQLConf(SQLConf.STREAMING_SCHEMA_INFERENCE.key -> "true") { + stringToFile(new File(src, "1"), "{'c': '1'}\n{'c': '2'}\n{'c': '3'}") + val schema = createFileStreamSourceAndGetSchema( + format = Some("json"), path = Some(src.getCanonicalPath), schema = None) + assert(schema === new StructType().add("c", StringType)) + } } } @@ -258,6 +322,8 @@ class FileStreamSourceSuite extends FileStreamSourceTest with SharedSQLContext { } } + // =============== Text file stream tests ================ + test("read from text files") { withTempDirs { case (src, tmp) => val textStream = createFileStream("text", src.getCanonicalPath) @@ -268,7 +334,7 @@ class FileStreamSourceSuite extends FileStreamSourceTest with SharedSQLContext { CheckAnswer("keep2", "keep3"), StopStream, AddTextFileData("drop4\nkeep5\nkeep6", src, tmp), - StartStream, + StartStream(), CheckAnswer("keep2", "keep3", "keep5", "keep6"), AddTextFileData("drop7\nkeep8\nkeep9", src, tmp), CheckAnswer("keep2", "keep3", "keep5", "keep6", "keep8", "keep9") @@ -276,6 +342,44 @@ class FileStreamSourceSuite extends FileStreamSourceTest with SharedSQLContext { } } + test("SPARK-17165 should not track the list of seen files indefinitely") { + // This test works by: + // 1. Create a file + // 2. Get it processed + // 3. Sleeps for a very short amount of time (larger than maxFileAge + // 4. Add another file (at this point the original file should have been purged + // 5. Test the size of the seenFiles internal data structure + + // Note that if we change maxFileAge to a very large number, the last step should fail. + withTempDirs { case (src, tmp) => + val textStream: DataFrame = + createFileStream("text", src.getCanonicalPath, options = Map("maxFileAge" -> "5ms")) + + testStream(textStream)( + AddTextFileData("a\nb", src, tmp), + CheckAnswer("a", "b"), + + // SLeeps longer than 5ms (maxFileAge) + // Unfortunately since a lot of file system does not have modification time granularity + // finer grained than 1 sec, we need to use 1 sec here. 
+ AssertOnQuery { _ => Thread.sleep(1000); true }, + + AddTextFileData("c\nd", src, tmp), + CheckAnswer("a", "b", "c", "d"), + + AssertOnQuery("seen files should contain only one entry") { streamExecution => + val source = streamExecution.logicalPlan.collect { case e: StreamingExecutionRelation => + e.source.asInstanceOf[FileStreamSource] + }.head + assert(source.seenFiles.size == 1) + true + } + ) + } + } + + // =============== JSON file stream tests ================ + test("read from json files") { withTempDirs { case (src, tmp) => val fileStream = createFileStream("json", src.getCanonicalPath, Some(valueSchema)) @@ -292,7 +396,7 @@ class FileStreamSourceSuite extends FileStreamSourceTest with SharedSQLContext { "{'value': 'drop4'}\n{'value': 'keep5'}\n{'value': 'keep6'}", src, tmp), - StartStream, + StartStream(), CheckAnswer("keep2", "keep3", "keep5", "keep6"), AddTextFileData( "{'value': 'drop7'}\n{'value': 'keep8'}\n{'value': 'keep9'}", @@ -305,76 +409,82 @@ class FileStreamSourceSuite extends FileStreamSourceTest with SharedSQLContext { test("read from json files with inferring schema") { withTempDirs { case (src, tmp) => + withSQLConf(SQLConf.STREAMING_SCHEMA_INFERENCE.key -> "true") { - // Add a file so that we can infer its schema - stringToFile(new File(src, "existing"), "{'c': 'drop1'}\n{'c': 'keep2'}\n{'c': 'keep3'}") + // Add a file so that we can infer its schema + stringToFile(new File(src, "existing"), "{'c': 'drop1'}\n{'c': 'keep2'}\n{'c': 'keep3'}") - val fileStream = createFileStream("json", src.getCanonicalPath) - assert(fileStream.schema === StructType(Seq(StructField("c", StringType)))) + val fileStream = createFileStream("json", src.getCanonicalPath) + assert(fileStream.schema === StructType(Seq(StructField("c", StringType)))) - // FileStreamSource should infer the column "c" - val filtered = fileStream.filter($"c" contains "keep") + // FileStreamSource should infer the column "c" + val filtered = fileStream.filter($"c" contains "keep") - testStream(filtered)( - AddTextFileData("{'c': 'drop4'}\n{'c': 'keep5'}\n{'c': 'keep6'}", src, tmp), - CheckAnswer("keep2", "keep3", "keep5", "keep6") - ) + testStream(filtered)( + AddTextFileData("{'c': 'drop4'}\n{'c': 'keep5'}\n{'c': 'keep6'}", src, tmp), + CheckAnswer("keep2", "keep3", "keep5", "keep6") + ) + } } } - test("reading from json files inside partitioned directory") { withTempDirs { case (baseSrc, tmp) => - val src = new File(baseSrc, "type=X") - src.mkdirs() + withSQLConf(SQLConf.STREAMING_SCHEMA_INFERENCE.key -> "true") { + val src = new File(baseSrc, "type=X") + src.mkdirs() - // Add a file so that we can infer its schema - stringToFile(new File(src, "existing"), "{'c': 'drop1'}\n{'c': 'keep2'}\n{'c': 'keep3'}") + // Add a file so that we can infer its schema + stringToFile(new File(src, "existing"), "{'c': 'drop1'}\n{'c': 'keep2'}\n{'c': 'keep3'}") - val fileStream = createFileStream("json", src.getCanonicalPath) + val fileStream = createFileStream("json", src.getCanonicalPath) - // FileStreamSource should infer the column "c" - val filtered = fileStream.filter($"c" contains "keep") + // FileStreamSource should infer the column "c" + val filtered = fileStream.filter($"c" contains "keep") - testStream(filtered)( - AddTextFileData("{'c': 'drop4'}\n{'c': 'keep5'}\n{'c': 'keep6'}", src, tmp), - CheckAnswer("keep2", "keep3", "keep5", "keep6") - ) + testStream(filtered)( + AddTextFileData("{'c': 'drop4'}\n{'c': 'keep5'}\n{'c': 'keep6'}", src, tmp), + CheckAnswer("keep2", "keep3", "keep5", "keep6") + ) + } } } - 
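The schema tests above repeatedly toggle `SQLConf.STREAMING_SCHEMA_INFERENCE` via `withSQLConf`. As a rough, user-facing sketch of what that switch controls (not part of the patch; it assumes the conf key resolves to `spark.sql.streaming.schemaInference` and uses a hypothetical input directory):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("schema-inference-sketch").getOrCreate()

// With the flag off (the default), a JSON or Parquet file stream without a user-supplied
// schema is rejected; with it on, the schema is inferred from files already in the directory.
spark.conf.set("spark.sql.streaming.schemaInference", "true")

val inferred = spark.readStream
  .format("json")
  .load("/tmp/streaming-json-input")   // hypothetical directory containing seed files

inferred.printSchema()                 // e.g. a single inferred string column such as "c"
```

The tests that follow exercise exactly this behaviour: schema fixed at stream creation, later files with extra or missing columns reconciled against that schema.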
test("reading from json files with changing schema") { withTempDirs { case (src, tmp) => + withSQLConf(SQLConf.STREAMING_SCHEMA_INFERENCE.key -> "true") { - // Add a file so that we can infer its schema - stringToFile(new File(src, "existing"), "{'k': 'value0'}") + // Add a file so that we can infer its schema + stringToFile(new File(src, "existing"), "{'k': 'value0'}") - val fileStream = createFileStream("json", src.getCanonicalPath) + val fileStream = createFileStream("json", src.getCanonicalPath) - // FileStreamSource should infer the column "k" - assert(fileStream.schema === StructType(Seq(StructField("k", StringType)))) + // FileStreamSource should infer the column "k" + assert(fileStream.schema === StructType(Seq(StructField("k", StringType)))) - // After creating DF and before starting stream, add data with different schema - // Should not affect the inferred schema any more - stringToFile(new File(src, "existing2"), "{'k': 'value1', 'v': 'new'}") + // After creating DF and before starting stream, add data with different schema + // Should not affect the inferred schema any more + stringToFile(new File(src, "existing2"), "{'k': 'value1', 'v': 'new'}") - testStream(fileStream)( + testStream(fileStream)( - // Should not pick up column v in the file added before start - AddTextFileData("{'k': 'value2'}", src, tmp), - CheckAnswer("value0", "value1", "value2"), + // Should not pick up column v in the file added before start + AddTextFileData("{'k': 'value2'}", src, tmp), + CheckAnswer("value0", "value1", "value2"), - // Should read data in column k, and ignore v - AddTextFileData("{'k': 'value3', 'v': 'new'}", src, tmp), - CheckAnswer("value0", "value1", "value2", "value3"), + // Should read data in column k, and ignore v + AddTextFileData("{'k': 'value3', 'v': 'new'}", src, tmp), + CheckAnswer("value0", "value1", "value2", "value3"), - // Should ignore rows that do not have the necessary k column - AddTextFileData("{'v': 'value4'}", src, tmp), - CheckAnswer("value0", "value1", "value2", "value3", null)) + // Should ignore rows that do not have the necessary k column + AddTextFileData("{'v': 'value4'}", src, tmp), + CheckAnswer("value0", "value1", "value2", "value3", null)) + } } } + // =============== Parquet file stream tests ================ + test("read from parquet files") { withTempDirs { case (src, tmp) => val fileStream = createFileStream("parquet", src.getCanonicalPath, Some(valueSchema)) @@ -385,7 +495,7 @@ class FileStreamSourceSuite extends FileStreamSourceTest with SharedSQLContext { CheckAnswer("keep2", "keep3"), StopStream, AddParquetFileData(Seq("drop4", "keep5", "keep6"), src, tmp), - StartStream, + StartStream(), CheckAnswer("keep2", "keep3", "keep5", "keep6"), AddParquetFileData(Seq("drop7", "keep8", "keep9"), src, tmp), CheckAnswer("keep2", "keep3", "keep5", "keep6", "keep8", "keep9") @@ -396,45 +506,185 @@ class FileStreamSourceSuite extends FileStreamSourceTest with SharedSQLContext { test("read from parquet files with changing schema") { withTempDirs { case (src, tmp) => - // Add a file so that we can infer its schema - AddParquetFileData.writeToFile(Seq("value0").toDF("k"), src, tmp) + withSQLConf(SQLConf.STREAMING_SCHEMA_INFERENCE.key -> "true") { + + // Add a file so that we can infer its schema + AddParquetFileData.writeToFile(Seq("value0").toDF("k"), src, tmp) + + val fileStream = createFileStream("parquet", src.getCanonicalPath) + + // FileStreamSource should infer the column "k" + assert(fileStream.schema === StructType(Seq(StructField("k", StringType)))) + 
+ // After creating DF and before starting stream, add data with different schema + // Should not affect the inferred schema any more + AddParquetFileData.writeToFile(Seq(("value1", 0)).toDF("k", "v"), src, tmp) + + testStream(fileStream)( + // Should not pick up column v in the file added before start + AddParquetFileData(Seq("value2").toDF("k"), src, tmp), + CheckAnswer("value0", "value1", "value2"), + + // Should read data in column k, and ignore v + AddParquetFileData(Seq(("value3", 1)).toDF("k", "v"), src, tmp), + CheckAnswer("value0", "value1", "value2", "value3"), - val fileStream = createFileStream("parquet", src.getCanonicalPath) + // Should ignore rows that do not have the necessary k column + AddParquetFileData(Seq("value5").toDF("v"), src, tmp), + CheckAnswer("value0", "value1", "value2", "value3", null) + ) + } + } + } - // FileStreamSource should infer the column "k" - assert(fileStream.schema === StructType(Seq(StructField("k", StringType)))) + // =============== file stream globbing tests ================ - // After creating DF and before starting stream, add data with different schema - // Should not affect the inferred schema any more - AddParquetFileData.writeToFile(Seq(("value1", 0)).toDF("k", "v"), src, tmp) + test("read new files in nested directories with globbing") { + withTempDirs { case (dir, tmp) => - testStream(fileStream)( - // Should not pick up column v in the file added before start - AddParquetFileData(Seq("value2").toDF("k"), src, tmp), - CheckAnswer("value0", "value1", "value2"), + // src/*/* should consider all the files and directories that matches that glob. + // So any files that matches the glob as well as any files in directories that matches + // this glob should be read. + val fileStream = createFileStream("text", s"${dir.getCanonicalPath}/*/*") + val filtered = fileStream.filter($"value" contains "keep") + val subDir = new File(dir, "subdir") + val subSubDir = new File(subDir, "subsubdir") + val subSubSubDir = new File(subSubDir, "subsubsubdir") + + require(!subDir.exists()) + require(!subSubDir.exists()) - // Should read data in column k, and ignore v - AddParquetFileData(Seq(("value3", 1)).toDF("k", "v"), src, tmp), - CheckAnswer("value0", "value1", "value2", "value3"), + testStream(filtered)( + // Create new dir/subdir and write to it, should read + AddTextFileData("drop1\nkeep2", subDir, tmp), + CheckAnswer("keep2"), - // Should ignore rows that do not have the necessary k column - AddParquetFileData(Seq("value5").toDF("v"), src, tmp), - CheckAnswer("value0", "value1", "value2", "value3", null) + // Add files to dir/subdir, should read + AddTextFileData("keep3", subDir, tmp), + CheckAnswer("keep2", "keep3"), + + // Create new dir/subdir/subsubdir and write to it, should read + AddTextFileData("keep4", subSubDir, tmp), + CheckAnswer("keep2", "keep3", "keep4"), + + // Add files to dir/subdir/subsubdir, should read + AddTextFileData("keep5", subSubDir, tmp), + CheckAnswer("keep2", "keep3", "keep4", "keep5"), + + // 1. Add file to src dir, should not read as globbing src/*/* does not capture files in + // dir, only captures files in dir/subdir/ + // 2. 
Add files to dir/subDir/subsubdir/subsubsubdir, should not read as src/*/* should + // not capture those files + AddTextFileData("keep6", dir, tmp), + AddTextFileData("keep7", subSubSubDir, tmp), + AddTextFileData("keep8", subDir, tmp), // needed to make query detect new data + CheckAnswer("keep2", "keep3", "keep4", "keep5", "keep8") ) } } - test("file stream source without schema") { - withTempDir { src => - // Only "text" doesn't need a schema - createFileStream("text", src.getCanonicalPath) + test("read new files in partitioned table with globbing, should not read partition data") { + withTempDirs { case (dir, tmp) => + val partitionFooSubDir = new File(dir, "partition=foo") + val partitionBarSubDir = new File(dir, "partition=bar") - // Both "json" and "parquet" require a schema if no existing file to infer - intercept[AnalysisException] { - createFileStream("json", src.getCanonicalPath) - } - intercept[AnalysisException] { - createFileStream("parquet", src.getCanonicalPath) + val schema = new StructType().add("value", StringType).add("partition", StringType) + val fileStream = createFileStream("json", s"${dir.getCanonicalPath}/*/*", Some(schema)) + val filtered = fileStream.filter($"value" contains "keep") + val nullStr = null.asInstanceOf[String] + testStream(filtered)( + // Create new partition=foo sub dir and write to it, should read only value, not partition + AddTextFileData("{'value': 'drop1'}\n{'value': 'keep2'}", partitionFooSubDir, tmp), + CheckAnswer(("keep2", nullStr)), + + // Append to same partition=1 sub dir, should read only value, not partition + AddTextFileData("{'value': 'keep3'}", partitionFooSubDir, tmp), + CheckAnswer(("keep2", nullStr), ("keep3", nullStr)), + + // Create new partition sub dir and write to it, should read only value, not partition + AddTextFileData("{'value': 'keep4'}", partitionBarSubDir, tmp), + CheckAnswer(("keep2", nullStr), ("keep3", nullStr), ("keep4", nullStr)), + + // Append to same partition=2 sub dir, should read only value, not partition + AddTextFileData("{'value': 'keep5'}", partitionBarSubDir, tmp), + CheckAnswer(("keep2", nullStr), ("keep3", nullStr), ("keep4", nullStr), ("keep5", nullStr)) + ) + } + } + + // =============== other tests ================ + + test("read new files in partitioned table without globbing, should read partition data") { + withTempDirs { case (dir, tmp) => + val partitionFooSubDir = new File(dir, "partition=foo") + val partitionBarSubDir = new File(dir, "partition=bar") + + val schema = new StructType().add("value", StringType).add("partition", StringType) + val fileStream = createFileStream("json", s"${dir.getCanonicalPath}", Some(schema)) + val filtered = fileStream.filter($"value" contains "keep") + testStream(filtered)( + // Create new partition=foo sub dir and write to it + AddTextFileData("{'value': 'drop1'}\n{'value': 'keep2'}", partitionFooSubDir, tmp), + CheckAnswer(("keep2", "foo")), + + // Append to same partition=foo sub dir + AddTextFileData("{'value': 'keep3'}", partitionFooSubDir, tmp), + CheckAnswer(("keep2", "foo"), ("keep3", "foo")), + + // Create new partition sub dir and write to it + AddTextFileData("{'value': 'keep4'}", partitionBarSubDir, tmp), + CheckAnswer(("keep2", "foo"), ("keep3", "foo"), ("keep4", "bar")), + + // Append to same partition=bar sub dir + AddTextFileData("{'value': 'keep5'}", partitionBarSubDir, tmp), + CheckAnswer(("keep2", "foo"), ("keep3", "foo"), ("keep4", "bar"), ("keep5", "bar")) + ) + } + } + + test("when schema inference is turned on, should read partition 
data") { + def createFile(content: String, src: File, tmp: File): Unit = { + val tempFile = Utils.tempFileWith(new File(tmp, "text")) + val finalFile = new File(src, tempFile.getName) + src.mkdirs() + require(stringToFile(tempFile, content).renameTo(finalFile)) + } + + withSQLConf(SQLConf.STREAMING_SCHEMA_INFERENCE.key -> "true") { + withTempDirs { case (dir, tmp) => + val partitionFooSubDir = new File(dir, "partition=foo") + val partitionBarSubDir = new File(dir, "partition=bar") + + // Create file in partition, so we can infer the schema. + createFile("{'value': 'drop0'}", partitionFooSubDir, tmp) + + val fileStream = createFileStream("json", s"${dir.getCanonicalPath}") + val filtered = fileStream.filter($"value" contains "keep") + testStream(filtered)( + // Append to same partition=foo sub dir + AddTextFileData("{'value': 'drop1'}\n{'value': 'keep2'}", partitionFooSubDir, tmp), + CheckAnswer(("keep2", "foo")), + + // Append to same partition=foo sub dir + AddTextFileData("{'value': 'keep3'}", partitionFooSubDir, tmp), + CheckAnswer(("keep2", "foo"), ("keep3", "foo")), + + // Create new partition sub dir and write to it + AddTextFileData("{'value': 'keep4'}", partitionBarSubDir, tmp), + CheckAnswer(("keep2", "foo"), ("keep3", "foo"), ("keep4", "bar")), + + // Append to same partition=bar sub dir + AddTextFileData("{'value': 'keep5'}", partitionBarSubDir, tmp), + CheckAnswer(("keep2", "foo"), ("keep3", "foo"), ("keep4", "bar"), ("keep5", "bar")), + + // Delete the two partition dirs + DeleteFile(partitionFooSubDir), + DeleteFile(partitionBarSubDir), + + AddTextFileData("{'value': 'keep6'}", partitionBarSubDir, tmp), + CheckAnswer(("keep2", "foo"), ("keep3", "foo"), ("keep4", "bar"), ("keep5", "bar"), + ("keep6", "bar")) + ) } } } @@ -449,16 +699,304 @@ class FileStreamSourceSuite extends FileStreamSourceTest with SharedSQLContext { CheckAnswer("keep2", "keep3"), StopStream, AddTextFileData("drop4\nkeep5\nkeep6", src, tmp), - StartStream, + StartStream(), CheckAnswer("keep2", "keep3", "keep5", "keep6"), AddTextFileData("drop7\nkeep8\nkeep9", src, tmp), CheckAnswer("keep2", "keep3", "keep5", "keep6", "keep8", "keep9") ) } } + + test("max files per trigger") { + withTempDir { case src => + var lastFileModTime: Option[Long] = None + + /** Create a text file with a single data item */ + def createFile(data: Int): File = { + val file = stringToFile(new File(src, s"$data.txt"), data.toString) + if (lastFileModTime.nonEmpty) file.setLastModified(lastFileModTime.get + 1000) + lastFileModTime = Some(file.lastModified) + file + } + + createFile(1) + createFile(2) + createFile(3) + + // Set up a query to read text files 2 at a time + val df = spark + .readStream + .option("maxFilesPerTrigger", 2) + .text(src.getCanonicalPath) + val q = df + .writeStream + .format("memory") + .queryName("file_data") + .start() + .asInstanceOf[StreamExecution] + q.processAllAvailable() + val memorySink = q.sink.asInstanceOf[MemorySink] + val fileSource = q.logicalPlan.collect { + case StreamingExecutionRelation(source, _) if source.isInstanceOf[FileStreamSource] => + source.asInstanceOf[FileStreamSource] + }.head + + /** Check the data read in the last batch */ + def checkLastBatchData(data: Int*): Unit = { + val schema = StructType(Seq(StructField("value", StringType))) + val df = spark.createDataFrame( + spark.sparkContext.makeRDD(memorySink.latestBatchData), schema) + checkAnswer(df, data.map(_.toString).toDF("value")) + } + + def checkAllData(data: Seq[Int]): Unit = { + val schema = 
StructType(Seq(StructField("value", StringType))) + val df = spark.createDataFrame( + spark.sparkContext.makeRDD(memorySink.allData), schema) + checkAnswer(df, data.map(_.toString).toDF("value")) + } + + /** Check how many batches have executed since the last time this check was made */ + var lastBatchId = -1L + def checkNumBatchesSinceLastCheck(numBatches: Int): Unit = { + require(lastBatchId >= 0) + assert(memorySink.latestBatchId.get === lastBatchId + numBatches) + lastBatchId = memorySink.latestBatchId.get + } + + checkLastBatchData(3) // (1 and 2) should be in batch 1, (3) should be in batch 2 (last) + checkAllData(1 to 3) + lastBatchId = memorySink.latestBatchId.get + + fileSource.withBatchingLocked { + createFile(4) + createFile(5) // 4 and 5 should be in a batch + createFile(6) + createFile(7) // 6 and 7 should be in the last batch + } + q.processAllAvailable() + checkNumBatchesSinceLastCheck(2) + checkLastBatchData(6, 7) + checkAllData(1 to 7) + + fileSource.withBatchingLocked { + createFile(8) + createFile(9) // 8 and 9 should be in a batch + createFile(10) + createFile(11) // 10 and 11 should be in a batch + createFile(12) // 12 should be in the last batch + } + q.processAllAvailable() + checkNumBatchesSinceLastCheck(3) + checkLastBatchData(12) + checkAllData(1 to 12) + + q.stop() + } + } + + test("max files per trigger - incorrect values") { + withTempDir { case src => + def testMaxFilePerTriggerValue(value: String): Unit = { + val df = spark.readStream.option("maxFilesPerTrigger", value).text(src.getCanonicalPath) + val e = intercept[IllegalArgumentException] { + testStream(df)() + } + Seq("maxFilesPerTrigger", value, "positive integer").foreach { s => + assert(e.getMessage.contains(s)) + } + } + + testMaxFilePerTriggerValue("not-a-integer") + testMaxFilePerTriggerValue("-1") + testMaxFilePerTriggerValue("0") + testMaxFilePerTriggerValue("10.1") + } + } + + test("explain") { + withTempDirs { case (src, tmp) => + src.mkdirs() + + val df = spark.readStream.format("text").load(src.getCanonicalPath).map(_ + "-x") + // Test `explain` not throwing errors + df.explain() + + val q = df.writeStream.queryName("file_explain").format("memory").start() + .asInstanceOf[StreamExecution] + try { + assert("No physical plan. Waiting for data." === q.explainInternal(false)) + assert("No physical plan. Waiting for data." === q.explainInternal(true)) + + val tempFile = Utils.tempFileWith(new File(tmp, "text")) + val finalFile = new File(src, tempFile.getName) + require(stringToFile(tempFile, "foo").renameTo(finalFile)) + + q.processAllAvailable() + + val explainWithoutExtended = q.explainInternal(false) + // `extended = false` only displays the physical plan. + assert("Relation.*text".r.findAllMatchIn(explainWithoutExtended).size === 0) + assert("TextFileFormat".r.findAllMatchIn(explainWithoutExtended).size === 1) + + val explainWithExtended = q.explainInternal(true) + // `extended = true` displays 3 logical plans (Parsed/Optimized/Optimized) and 1 physical + // plan. + assert("Relation.*text".r.findAllMatchIn(explainWithExtended).size === 3) + assert("TextFileFormat".r.findAllMatchIn(explainWithExtended).size === 1) + } finally { + q.stop() + } + } + } + + test("SPARK-17372 - write file names to WAL as Array[String]") { + // Note: If this test takes longer than the timeout, then its likely that this is actually + // running a Spark job with 10000 tasks. This test tries to avoid that by + // 1. Setting the threshold for parallel file listing to very high + // 2. 
Using a query that should use constant folding to eliminate reading of the files + + val numFiles = 10000 + + // This is to avoid running a spark job to list of files in parallel + // by the ListingFileCatalog. + spark.sessionState.conf.setConf(SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD, numFiles * 2) + + withTempDirs { case (root, tmp) => + val src = new File(root, "a=1") + src.mkdirs() + + (1 to numFiles).map { _.toString }.foreach { i => + val tempFile = Utils.tempFileWith(new File(tmp, "text")) + val finalFile = new File(src, tempFile.getName) + stringToFile(finalFile, i) + } + assert(src.listFiles().size === numFiles) + + val files = spark.readStream.text(root.getCanonicalPath).as[(String, Int)] + + // Note this query will use constant folding to eliminate the file scan. + // This is to avoid actually running a Spark job with 10000 tasks + val df = files.filter("1 == 0").groupBy().count() + + testStream(df, InternalOutputModes.Complete)( + AddTextFileData("0", src, tmp), + CheckAnswer(0) + ) + } + } + + test("compacat metadata log") { + val _sources = PrivateMethod[Seq[Source]]('sources) + val _metadataLog = PrivateMethod[FileStreamSourceLog]('metadataLog) + + def verify(execution: StreamExecution) + (batchId: Long, expectedBatches: Int): Boolean = { + import CompactibleFileStreamLog._ + + val fileSource = (execution invokePrivate _sources()).head.asInstanceOf[FileStreamSource] + val metadataLog = fileSource invokePrivate _metadataLog() + + if (isCompactionBatch(batchId, 2)) { + val path = metadataLog.batchIdToPath(batchId) + + // Assert path name should be ended with compact suffix. + assert(path.getName.endsWith(COMPACT_FILE_SUFFIX)) + + // Compacted batch should include all entries from start. + val entries = metadataLog.get(batchId) + assert(entries.isDefined) + assert(entries.get.length === metadataLog.allFiles().length) + assert(metadataLog.get(None, Some(batchId)).flatMap(_._2).length === entries.get.length) + } + + assert(metadataLog.allFiles().sortBy(_.batchId) === + metadataLog.get(None, Some(batchId)).flatMap(_._2).sortBy(_.batchId)) + + metadataLog.get(None, Some(batchId)).flatMap(_._2).length === expectedBatches + } + + withTempDirs { case (src, tmp) => + withSQLConf( + SQLConf.FILE_SOURCE_LOG_COMPACT_INTERVAL.key -> "2" + ) { + val fileStream = createFileStream("text", src.getCanonicalPath) + val filtered = fileStream.filter($"value" contains "keep") + + testStream(filtered)( + AddTextFileData("drop1\nkeep2\nkeep3", src, tmp), + CheckAnswer("keep2", "keep3"), + AssertOnQuery(verify(_)(0L, 1)), + AddTextFileData("drop4\nkeep5\nkeep6", src, tmp), + CheckAnswer("keep2", "keep3", "keep5", "keep6"), + AssertOnQuery(verify(_)(1L, 2)), + AddTextFileData("drop7\nkeep8\nkeep9", src, tmp), + CheckAnswer("keep2", "keep3", "keep5", "keep6", "keep8", "keep9"), + AssertOnQuery(verify(_)(2L, 3)), + StopStream, + StartStream(), + AssertOnQuery(verify(_)(2L, 3)), + AddTextFileData("drop10\nkeep11", src, tmp), + CheckAnswer("keep2", "keep3", "keep5", "keep6", "keep8", "keep9", "keep11"), + AssertOnQuery(verify(_)(3L, 4)), + AddTextFileData("drop12\nkeep13", src, tmp), + CheckAnswer("keep2", "keep3", "keep5", "keep6", "keep8", "keep9", "keep11", "keep13"), + AssertOnQuery(verify(_)(4L, 5)) + ) + } + } + } + + test("get arbitrary batch from FileStreamSource") { + withTempDirs { case (src, tmp) => + withSQLConf( + SQLConf.FILE_SOURCE_LOG_COMPACT_INTERVAL.key -> "2", + // Force deleting the old logs + SQLConf.FILE_SOURCE_LOG_CLEANUP_DELAY.key -> "1" + ) { + val fileStream = 
createFileStream("text", src.getCanonicalPath) + val filtered = fileStream.filter($"value" contains "keep") + + testStream(filtered)( + AddTextFileData("keep1", src, tmp), + CheckAnswer("keep1"), + AddTextFileData("keep2", src, tmp), + CheckAnswer("keep1", "keep2"), + AddTextFileData("keep3", src, tmp), + CheckAnswer("keep1", "keep2", "keep3"), + AssertOnQuery("check getBatch") { execution: StreamExecution => + val _sources = PrivateMethod[Seq[Source]]('sources) + val fileSource = + (execution invokePrivate _sources()).head.asInstanceOf[FileStreamSource] + assert(fileSource.getBatch(None, LongOffset(2)).as[String].collect() === + List("keep1", "keep2", "keep3")) + assert(fileSource.getBatch(Some(LongOffset(0)), LongOffset(2)).as[String].collect() === + List("keep2", "keep3")) + assert(fileSource.getBatch(Some(LongOffset(1)), LongOffset(2)).as[String].collect() === + List("keep3")) + true + } + ) + } + } + } + + test("input row metrics") { + withTempDirs { case (src, tmp) => + val input = spark.readStream.format("text").load(src.getCanonicalPath) + testStream(input)( + AddTextFileData("100", src, tmp), + CheckAnswer("100"), + AssertOnLastQueryStatus { status => + assert(status.triggerDetails.get("numRows.input.total") === "1") + assert(status.sourceStatuses(0).processingRate > 0.0) + } + ) + } + } } -class FileStreamSourceStressTestSuite extends FileStreamSourceTest with SharedSQLContext { +class FileStreamSourceStressTestSuite extends FileStreamSourceTest { import testImplicits._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStressSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStressSuite.scala index 5b49a0a86a04f..f9e236c449634 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStressSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStressSuite.scala @@ -23,9 +23,7 @@ import java.util.UUID import scala.util.Random import scala.util.control.NonFatal -import org.apache.spark.sql.{ContinuousQuery, ContinuousQueryException, StreamTest} import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils /** @@ -38,10 +36,18 @@ import org.apache.spark.util.Utils * * At the end, the resulting files are loaded and the answer is checked. 
*/ -class FileStressSuite extends StreamTest with SharedSQLContext { +class FileStressSuite extends StreamTest { import testImplicits._ - test("fault tolerance stress test") { + testQuietly("fault tolerance stress test - unpartitioned output") { + stressTest(partitionWrites = false) + } + + testQuietly("fault tolerance stress test - partitioned output") { + stressTest(partitionWrites = true) + } + + def stressTest(partitionWrites: Boolean): Unit = { val numRecords = 10000 val inputDir = Utils.createTempDir(namePrefix = "stream.input").getCanonicalPath val stagingDir = Utils.createTempDir(namePrefix = "stream.staging").getCanonicalPath @@ -51,7 +57,7 @@ class FileStressSuite extends StreamTest with SharedSQLContext { @volatile var continue = true @volatile - var stream: ContinuousQuery = null + var stream: StreamingQuery = null val writer = new Thread("stream writer") { override def run(): Unit = { @@ -92,19 +98,37 @@ class FileStressSuite extends StreamTest with SharedSQLContext { } writer.start() - val input = sqlContext.read.format("text").stream(inputDir) - def startStream(): ContinuousQuery = input + val input = spark.readStream.format("text").load(inputDir) + + def startStream(): StreamingQuery = { + val output = input .repartition(5) .as[String] .mapPartitions { iter => val rand = Random.nextInt(100) - if (rand < 5) { sys.error("failure") } + if (rand < 10) { + sys.error("failure") + } iter.map(_.toLong) } - .write - .format("parquet") - .option("checkpointLocation", checkpoint) - .startStream(outputDir) + .map(x => (x % 400, x.toString)) + .toDF("id", "data") + + if (partitionWrites) { + output + .writeStream + .partitionBy("id") + .format("parquet") + .option("checkpointLocation", checkpoint) + .start(outputDir) + } else { + output + .writeStream + .format("parquet") + .option("checkpointLocation", checkpoint) + .start(outputDir) + } + } var failures = 0 val streamThread = new Thread("stream runner") { @@ -115,7 +139,7 @@ class FileStressSuite extends StreamTest with SharedSQLContext { try { stream.awaitTermination() } catch { - case ce: ContinuousQueryException => + case ce: StreamingQueryException => failures += 1 } } @@ -124,6 +148,6 @@ class FileStressSuite extends StreamTest with SharedSQLContext { streamThread.join() logError(s"Stream restarted $failures times.") - assert(sqlContext.read.parquet(outputDir).distinct().count() == numRecords) + assert(spark.read.parquet(outputDir).distinct().count() == numRecords) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala index 1f2834054519b..310d75630272b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala @@ -17,42 +17,207 @@ package org.apache.spark.sql.streaming -import org.apache.spark.sql.{AnalysisException, Row, StreamTest} +import scala.language.implicitConversions + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.sql._ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.util.Utils -class MemorySinkSuite extends StreamTest with SharedSQLContext { +class MemorySinkSuite extends StreamTest with BeforeAndAfter { + import testImplicits._ - test("registering as a table") { + after { + sqlContext.streams.active.foreach(_.stop()) + } + + 
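The MemorySinkSuite additions around this point drive `MemorySink` both directly via `addBatch` and through the `memory` format with an output mode. As a rough end-to-end sketch of the public path these tests rely on (not part of the patch; `MemoryStream` is an internal test helper, and the session wiring here is assumed):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.streaming.MemoryStream

val spark = SparkSession.builder().appName("memory-sink-sketch").getOrCreate()
import spark.implicits._
implicit val sqlContext = spark.sqlContext   // required by MemoryStream

val input = MemoryStream[Int]

// Append mode adds each new batch to the in-memory table; Complete mode (used with
// aggregations) replaces the table's contents on every batch.
val query = input.toDF().writeStream
  .format("memory")
  .outputMode("append")
  .queryName("memStream")
  .start()

input.addData(1, 2, 3)
query.processAllAvailable()
spark.table("memStream").show()   // rows 1, 2, 3

query.stop()
```

The direct `addBatch` tests below check the same Append/Update/Complete semantics at the sink level, including that re-adding an already-committed batch id is a no-op.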
test("directly add data in Append output mode") { + implicit val schema = new StructType().add(new StructField("value", IntegerType)) + val sink = new MemorySink(schema, InternalOutputModes.Append) + + // Before adding data, check output + assert(sink.latestBatchId === None) + checkAnswer(sink.latestBatchData, Seq.empty) + checkAnswer(sink.allData, Seq.empty) + + // Add batch 0 and check outputs + sink.addBatch(0, 1 to 3) + assert(sink.latestBatchId === Some(0)) + checkAnswer(sink.latestBatchData, 1 to 3) + checkAnswer(sink.allData, 1 to 3) + + // Add batch 1 and check outputs + sink.addBatch(1, 4 to 6) + assert(sink.latestBatchId === Some(1)) + checkAnswer(sink.latestBatchData, 4 to 6) + checkAnswer(sink.allData, 1 to 6) // new data should get appended to old data + + // Re-add batch 1 with different data, should not be added and outputs should not be changed + sink.addBatch(1, 7 to 9) + assert(sink.latestBatchId === Some(1)) + checkAnswer(sink.latestBatchData, 4 to 6) + checkAnswer(sink.allData, 1 to 6) + + // Add batch 2 and check outputs + sink.addBatch(2, 7 to 9) + assert(sink.latestBatchId === Some(2)) + checkAnswer(sink.latestBatchData, 7 to 9) + checkAnswer(sink.allData, 1 to 9) + } + + test("directly add data in Update output mode") { + implicit val schema = new StructType().add(new StructField("value", IntegerType)) + val sink = new MemorySink(schema, InternalOutputModes.Update) + + // Before adding data, check output + assert(sink.latestBatchId === None) + checkAnswer(sink.latestBatchData, Seq.empty) + checkAnswer(sink.allData, Seq.empty) + + // Add batch 0 and check outputs + sink.addBatch(0, 1 to 3) + assert(sink.latestBatchId === Some(0)) + checkAnswer(sink.latestBatchData, 1 to 3) + checkAnswer(sink.allData, 1 to 3) + + // Add batch 1 and check outputs + sink.addBatch(1, 4 to 6) + assert(sink.latestBatchId === Some(1)) + checkAnswer(sink.latestBatchData, 4 to 6) + checkAnswer(sink.allData, 1 to 6) // new data should get appended to old data + + // Re-add batch 1 with different data, should not be added and outputs should not be changed + sink.addBatch(1, 7 to 9) + assert(sink.latestBatchId === Some(1)) + checkAnswer(sink.latestBatchData, 4 to 6) + checkAnswer(sink.allData, 1 to 6) + + // Add batch 2 and check outputs + sink.addBatch(2, 7 to 9) + assert(sink.latestBatchId === Some(2)) + checkAnswer(sink.latestBatchData, 7 to 9) + checkAnswer(sink.allData, 1 to 9) + } + + test("directly add data in Complete output mode") { + implicit val schema = new StructType().add(new StructField("value", IntegerType)) + val sink = new MemorySink(schema, InternalOutputModes.Complete) + + // Before adding data, check output + assert(sink.latestBatchId === None) + checkAnswer(sink.latestBatchData, Seq.empty) + checkAnswer(sink.allData, Seq.empty) + + // Add batch 0 and check outputs + sink.addBatch(0, 1 to 3) + assert(sink.latestBatchId === Some(0)) + checkAnswer(sink.latestBatchData, 1 to 3) + checkAnswer(sink.allData, 1 to 3) + + // Add batch 1 and check outputs + sink.addBatch(1, 4 to 6) + assert(sink.latestBatchId === Some(1)) + checkAnswer(sink.latestBatchData, 4 to 6) + checkAnswer(sink.allData, 4 to 6) // new data should replace old data + + // Re-add batch 1 with different data, should not be added and outputs should not be changed + sink.addBatch(1, 7 to 9) + assert(sink.latestBatchId === Some(1)) + checkAnswer(sink.latestBatchData, 4 to 6) + checkAnswer(sink.allData, 4 to 6) + + // Add batch 2 and check outputs + sink.addBatch(2, 7 to 9) + assert(sink.latestBatchId === Some(2)) + 
checkAnswer(sink.latestBatchData, 7 to 9) + checkAnswer(sink.allData, 7 to 9) + } + + + test("registering as a table in Append output mode") { val input = MemoryStream[Int] - val query = input.toDF().write + val query = input.toDF().writeStream .format("memory") + .outputMode("append") .queryName("memStream") - .startStream() + .start() input.addData(1, 2, 3) query.processAllAvailable() checkDataset( - sqlContext.table("memStream").as[Int], + spark.table("memStream").as[Int], 1, 2, 3) input.addData(4, 5, 6) query.processAllAvailable() checkDataset( - sqlContext.table("memStream").as[Int], + spark.table("memStream").as[Int], 1, 2, 3, 4, 5, 6) query.stop() } + test("registering as a table in Complete output mode") { + val input = MemoryStream[Int] + val query = input.toDF() + .groupBy("value") + .count() + .writeStream + .format("memory") + .outputMode("complete") + .queryName("memStream") + .start() + input.addData(1, 2, 3) + query.processAllAvailable() + + checkDatasetUnorderly( + spark.table("memStream").as[(Int, Long)], + (1, 1L), (2, 1L), (3, 1L)) + + input.addData(4, 5, 6) + query.processAllAvailable() + checkDatasetUnorderly( + spark.table("memStream").as[(Int, Long)], + (1, 1L), (2, 1L), (3, 1L), (4, 1L), (5, 1L), (6, 1L)) + + query.stop() + } + + ignore("stress test") { + // Ignore the stress test as it takes several minutes to run + (0 until 1000).foreach { _ => + val input = MemoryStream[Int] + val query = input.toDF().writeStream + .format("memory") + .queryName("memStream") + .start() + input.addData(1, 2, 3) + query.processAllAvailable() + + checkDataset( + spark.table("memStream").as[Int], + 1, 2, 3) + + input.addData(4, 5, 6) + query.processAllAvailable() + checkDataset( + spark.table("memStream").as[Int], + 1, 2, 3, 4, 5, 6) + + query.stop() + } + } + test("error when no name is specified") { val error = intercept[AnalysisException] { val input = MemoryStream[Int] - val query = input.toDF().write + val query = input.toDF().writeStream .format("memory") - .startStream() + .start() } assert(error.message contains "queryName must be specified") @@ -62,21 +227,32 @@ class MemorySinkSuite extends StreamTest with SharedSQLContext { val location = Utils.createTempDir(namePrefix = "steaming.checkpoint").getCanonicalPath val input = MemoryStream[Int] - val query = input.toDF().write + val query = input.toDF().writeStream .format("memory") .queryName("memStream") .option("checkpointLocation", location) - .startStream() + .start() input.addData(1, 2, 3) query.processAllAvailable() query.stop() intercept[AnalysisException] { - input.toDF().write + input.toDF().writeStream .format("memory") .queryName("memStream") .option("checkpointLocation", location) - .startStream() + .start() } } + + private def checkAnswer(rows: Seq[Row], expected: Seq[Int])(implicit schema: StructType): Unit = { + checkAnswer( + sqlContext.createDataFrame(sparkContext.makeRDD(rows), schema), + intsToDF(expected)(schema)) + } + + private implicit def intsToDF(seq: Seq[Int])(implicit schema: StructType): DataFrame = { + require(schema.fields.size === 1) + sqlContext.createDataset(seq).toDF(schema.fieldNames.head) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySourceStressSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySourceStressSuite.scala index 81760d2aa8205..7f2972edea727 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySourceStressSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySourceStressSuite.scala @@ -17,11 +17,9 @@ package org.apache.spark.sql.streaming -import org.apache.spark.sql.StreamTest import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.test.SharedSQLContext -class MemorySourceStressSuite extends StreamTest with SharedSQLContext { +class MemorySourceStressSuite extends StreamTest { import testImplicits._ test("memory stress test") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/OffsetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/OffsetSuite.scala index 9590af4e7737d..b65a987770304 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/OffsetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/OffsetSuite.scala @@ -24,44 +24,12 @@ trait OffsetSuite extends SparkFunSuite { /** Creates test to check all the comparisons of offsets given a `one` that is less than `two`. */ def compare(one: Offset, two: Offset): Unit = { test(s"comparison $one <=> $two") { - assert(one < two) - assert(one <= two) - assert(one <= one) - assert(two > one) - assert(two >= one) - assert(one >= one) assert(one == one) assert(two == two) assert(one != two) assert(two != one) } } - - /** Creates test to check that non-equality comparisons throw exception. */ - def compareInvalid(one: Offset, two: Offset): Unit = { - test(s"invalid comparison $one <=> $two") { - intercept[IllegalArgumentException] { - assert(one < two) - } - - intercept[IllegalArgumentException] { - assert(one <= two) - } - - intercept[IllegalArgumentException] { - assert(one > two) - } - - intercept[IllegalArgumentException] { - assert(one >= two) - } - - assert(!(one == two)) - assert(!(two == one)) - assert(one != two) - assert(two != one) - } - } } class LongOffsetSuite extends OffsetSuite { @@ -79,10 +47,6 @@ class CompositeOffsetSuite extends OffsetSuite { one = CompositeOffset(None :: Nil), two = CompositeOffset(Some(LongOffset(2)) :: Nil)) - compareInvalid( // sizes must be same - one = CompositeOffset(Nil), - two = CompositeOffset(Some(LongOffset(2)) :: Nil)) - compare( one = CompositeOffset.fill(LongOffset(0), LongOffset(1)), two = CompositeOffset.fill(LongOffset(1), LongOffset(2))) @@ -91,8 +55,5 @@ class CompositeOffsetSuite extends OffsetSuite { one = CompositeOffset.fill(LongOffset(1), LongOffset(1)), two = CompositeOffset.fill(LongOffset(1), LongOffset(2))) - compareInvalid( - one = CompositeOffset.fill(LongOffset(2), LongOffset(1)), // vector time inconsistent - two = CompositeOffset.fill(LongOffset(1), LongOffset(2))) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 6f3149dbc5033..6bdf47901ae68 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -17,14 +17,16 @@ package org.apache.spark.sql.streaming +import scala.reflect.ClassTag +import scala.util.control.ControlThrowable + import org.apache.spark.sql._ import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.functions._ import org.apache.spark.sql.sources.StreamSourceProvider -import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StructField, StructType} +import org.apache.spark.util.ManualClock -class StreamSuite extends StreamTest with SharedSQLContext { +class StreamSuite extends 
StreamTest { import testImplicits._ @@ -34,11 +36,11 @@ class StreamSuite extends StreamTest with SharedSQLContext { testStream(mapped)( AddData(inputData, 1, 2, 3), - StartStream, + StartStream(), CheckAnswer(2, 3, 4), StopStream, AddData(inputData, 4, 5, 6), - StartStream, + StartStream(), CheckAnswer(2, 3, 4, 5, 6, 7)) } @@ -70,14 +72,14 @@ class StreamSuite extends StreamTest with SharedSQLContext { CheckAnswer(1, 2, 3, 4, 5, 6), StopStream, AddData(inputData1, 7), - StartStream, + StartStream(), AddData(inputData2, 8), CheckAnswer(1, 2, 3, 4, 5, 6, 7, 8)) } test("sql queries") { val inputData = MemoryStream[Int] - inputData.toDF().registerTempTable("stream") + inputData.toDF().createOrReplaceTempView("stream") val evens = sql("SELECT * FROM stream WHERE value % 2 = 0") testStream(evens)( @@ -89,12 +91,12 @@ class StreamSuite extends StreamTest with SharedSQLContext { def assertDF(df: DataFrame) { withTempDir { outputDir => withTempDir { checkpointDir => - val query = df.write.format("parquet") + val query = df.writeStream.format("parquet") .option("checkpointLocation", checkpointDir.getAbsolutePath) - .startStream(outputDir.getAbsolutePath) + .start(outputDir.getAbsolutePath) try { query.processAllAvailable() - val outputDf = sqlContext.read.parquet(outputDir.getAbsolutePath).as[Long] + val outputDf = spark.read.parquet(outputDir.getAbsolutePath).as[Long] checkDataset[Long](outputDf, (0L to 10L).toArray: _*) } finally { query.stop() @@ -103,7 +105,7 @@ class StreamSuite extends StreamTest with SharedSQLContext { } } - val df = sqlContext.read.format(classOf[FakeDefaultSource].getName).stream() + val df = spark.readStream.format(classOf[FakeDefaultSource].getName).load() assertDF(df) assertDF(df) } @@ -120,12 +122,12 @@ class StreamSuite extends StreamTest with SharedSQLContext { } // Running streaming plan as a batch query - assertError("startStream" :: Nil) { + assertError("start" :: Nil) { streamInput.toDS.map { i => i }.count() } // Running non-streaming plan with as a streaming query - assertError("without streaming sources" :: "startStream" :: Nil) { + assertError("without streaming sources" :: "start" :: Nil) { val ds = batchInput.map { i => i } testStream(ds)() } @@ -136,6 +138,168 @@ class StreamSuite extends StreamTest with SharedSQLContext { testStream(ds)() } } + + test("minimize delay between batch construction and execution") { + + // For each batch, we would retrieve new data's offsets and log them before we run the execution + // This checks whether the key of the offset log is the expected batch id + def CheckOffsetLogLatestBatchId(expectedId: Int): AssertOnQuery = + AssertOnQuery(_.offsetLog.getLatest().get._1 == expectedId, + s"offsetLog's latest should be $expectedId") + + // For each batch, we would log the state change during the execution + // This checks whether the key of the state change log is the expected batch id + def CheckIncrementalExecutionCurrentBatchId(expectedId: Int): AssertOnQuery = + AssertOnQuery(_.lastExecution.asInstanceOf[IncrementalExecution].currentBatchId == expectedId, + s"lastExecution's currentBatchId should be $expectedId") + + // For each batch, we would log the sink change after the execution + // This checks whether the key of the sink change log is the expected batch id + def CheckSinkLatestBatchId(expectedId: Int): AssertOnQuery = + AssertOnQuery(_.sink.asInstanceOf[MemorySink].latestBatchId.get == expectedId, + s"sink's lastBatchId should be $expectedId") + + val inputData = MemoryStream[Int] + testStream(inputData.toDS())( + 
StartStream(ProcessingTime("10 seconds"), new StreamManualClock), + + /* -- batch 0 ----------------------- */ + // Add some data in batch 0 + AddData(inputData, 1, 2, 3), + AdvanceManualClock(10 * 1000), // 10 seconds + + /* -- batch 1 ----------------------- */ + // Check the results of batch 0 + CheckAnswer(1, 2, 3), + CheckIncrementalExecutionCurrentBatchId(0), + CheckOffsetLogLatestBatchId(0), + CheckSinkLatestBatchId(0), + // Add some data in batch 1 + AddData(inputData, 4, 5, 6), + AdvanceManualClock(10 * 1000), + + /* -- batch _ ----------------------- */ + // Check the results of batch 1 + CheckAnswer(1, 2, 3, 4, 5, 6), + CheckIncrementalExecutionCurrentBatchId(1), + CheckOffsetLogLatestBatchId(1), + CheckSinkLatestBatchId(1), + + AdvanceManualClock(10 * 1000), + AdvanceManualClock(10 * 1000), + AdvanceManualClock(10 * 1000), + + /* -- batch __ ---------------------- */ + // Check the results of batch 1 again; this is to make sure that, when there's no new data, + // the currentId does not get logged (e.g. as 2) even if the clock has advanced many times + CheckAnswer(1, 2, 3, 4, 5, 6), + CheckIncrementalExecutionCurrentBatchId(1), + CheckOffsetLogLatestBatchId(1), + CheckSinkLatestBatchId(1), + + /* Stop then restart the Stream */ + StopStream, + StartStream(ProcessingTime("10 seconds"), new StreamManualClock(60 * 1000)), + + /* -- batch 1 rerun ----------------- */ + // this batch 1 would re-run because the latest batch id logged in offset log is 1 + AdvanceManualClock(10 * 1000), + + /* -- batch 2 ----------------------- */ + // Check the results of batch 1 + CheckAnswer(1, 2, 3, 4, 5, 6), + CheckIncrementalExecutionCurrentBatchId(1), + CheckOffsetLogLatestBatchId(1), + CheckSinkLatestBatchId(1), + // Add some data in batch 2 + AddData(inputData, 7, 8, 9), + AdvanceManualClock(10 * 1000), + + /* -- batch 3 ----------------------- */ + // Check the results of batch 2 + CheckAnswer(1, 2, 3, 4, 5, 6, 7, 8, 9), + CheckIncrementalExecutionCurrentBatchId(2), + CheckOffsetLogLatestBatchId(2), + CheckSinkLatestBatchId(2)) + } + + test("insert an extraStrategy") { + try { + spark.experimental.extraStrategies = TestStrategy :: Nil + + val inputData = MemoryStream[(String, Int)] + val df = inputData.toDS().map(_._1).toDF("a") + + testStream(df)( + AddData(inputData, ("so slow", 1)), + CheckAnswer("so fast")) + } finally { + spark.experimental.extraStrategies = Nil + } + } + + testQuietly("fatal errors from a source should be sent to the user") { + for (e <- Seq( + new VirtualMachineError {}, + new ThreadDeath, + new LinkageError, + new ControlThrowable {} + )) { + val source = new Source { + override def getOffset: Option[Offset] = { + throw e + } + + override def getBatch(start: Option[Offset], end: Offset): DataFrame = { + throw e + } + + override def schema: StructType = StructType(Array(StructField("value", IntegerType))) + + override def stop(): Unit = {} + } + val df = Dataset[Int](sqlContext.sparkSession, StreamingExecutionRelation(source)) + testStream(df)( + ExpectFailure()(ClassTag(e.getClass)) + ) + } + } + + test("output mode API in Scala") { + val o1 = OutputMode.Append + assert(o1 === InternalOutputModes.Append) + val o2 = OutputMode.Complete + assert(o2 === InternalOutputModes.Complete) + } + + test("explain") { + val inputData = MemoryStream[String] + val df = inputData.toDS().map(_ + "foo") + // Test `explain` not throwing errors + df.explain() + val q = df.writeStream.queryName("memory_explain").format("memory").start() + .asInstanceOf[StreamExecution] + try { + 
assert("No physical plan. Waiting for data." === q.explainInternal(false)) + assert("No physical plan. Waiting for data." === q.explainInternal(true)) + + inputData.addData("abc") + q.processAllAvailable() + + val explainWithoutExtended = q.explainInternal(false) + // `extended = false` only displays the physical plan. + assert("LocalRelation".r.findAllMatchIn(explainWithoutExtended).size === 0) + assert("LocalTableScan".r.findAllMatchIn(explainWithoutExtended).size === 1) + + val explainWithExtended = q.explainInternal(true) + // `extended = true` displays 3 logical plans (Parsed/Optimized/Optimized) and 1 physical + // plan. + assert("LocalRelation".r.findAllMatchIn(explainWithExtended).size === 3) + assert("LocalTableScan".r.findAllMatchIn(explainWithExtended).size === 1) + } finally { + q.stop() + } + } } /** @@ -146,13 +310,13 @@ class FakeDefaultSource extends StreamSourceProvider { private val fakeSchema = StructType(StructField("a", IntegerType) :: Nil) override def sourceSchema( - sqlContext: SQLContext, + spark: SQLContext, schema: Option[StructType], providerName: String, parameters: Map[String, String]): (String, StructType) = ("fakeSource", fakeSchema) override def createSource( - sqlContext: SQLContext, + spark: SQLContext, metadataPath: String, schema: Option[StructType], providerName: String, @@ -174,8 +338,10 @@ class FakeDefaultSource extends StreamSourceProvider { override def getBatch(start: Option[Offset], end: Offset): DataFrame = { val startOffset = start.map(_.asInstanceOf[LongOffset].offset).getOrElse(-1L) + 1 - sqlContext.range(startOffset, end.asInstanceOf[LongOffset].offset + 1).toDF("a") + spark.range(startOffset, end.asInstanceOf[LongOffset].offset + 1).toDF("a") } + + override def stop() {} } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala similarity index 69% rename from sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala rename to sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index dff6acc94b3f0..742833065144d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql +package org.apache.spark.sql.streaming import java.lang.Thread.UncaughtExceptionHandler @@ -28,17 +28,21 @@ import scala.util.control.NonFatal import org.scalatest.Assertions import org.scalatest.concurrent.{Eventually, Timeouts} +import org.scalatest.concurrent.AsyncAssertions.Waiter +import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.exceptions.TestFailedDueToTimeoutException import org.scalatest.time.Span import org.scalatest.time.SpanSugar._ -import org.apache.spark.sql.catalyst.analysis.{Append, OutputMode} +import org.apache.spark.sql.{Dataset, Encoder, QueryTest, Row} import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.util.Utils +import org.apache.spark.sql.streaming.StreamingQueryListener._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.util.{Clock, ManualClock, SystemClock, Utils} /** * A framework for implementing tests for streaming queries and sources. @@ -49,11 +53,11 @@ import org.apache.spark.util.Utils * * {{{ * val inputData = MemoryStream[Int] - val mapped = inputData.toDS().map(_ + 1) - - testStream(mapped)( - AddData(inputData, 1, 2, 3), - CheckAnswer(2, 3, 4)) + * val mapped = inputData.toDS().map(_ + 1) + * + * testStream(mapped)( + * AddData(inputData, 1, 2, 3), + * CheckAnswer(2, 3, 4)) * }}} * * Note that while we do sleep to allow the other thread to progress without spinning, @@ -64,13 +68,11 @@ import org.apache.spark.util.Utils * avoid hanging forever in the case of failures. However, individual suites can change this * by overriding `streamingTimeout`. */ -trait StreamTest extends QueryTest with Timeouts { +trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { /** How long to wait for an active stream to catch up when checking a result. */ val streamingTimeout = 10.seconds - val outputMode: OutputMode = Append - /** A trait for actions that can be performed while testing a streaming DataFrame. */ trait StreamAction @@ -96,6 +98,11 @@ trait StreamTest extends QueryTest with Timeouts { def addData(query: Option[StreamExecution]): (Source, Offset) } + /** A trait that can be extended when testing a source. 
*/ + trait ExternalAction extends StreamAction { + def runAction(): Unit + } + case class AddDataMemory[A](source: MemoryStream[A], data: Seq[A]) extends AddData { override def toString: String = s"AddData to $source: ${data.mkString(",")}" @@ -111,11 +118,14 @@ trait StreamTest extends QueryTest with Timeouts { object CheckAnswer { def apply[A : Encoder](data: A*): CheckAnswerRows = { val encoder = encoderFor[A] - val toExternalRow = RowEncoder(encoder.schema) - CheckAnswerRows(data.map(d => toExternalRow.fromRow(encoder.toRow(d))), false) + val toExternalRow = RowEncoder(encoder.schema).resolveAndBind() + CheckAnswerRows( + data.map(d => toExternalRow.fromRow(encoder.toRow(d))), + lastOnly = false, + isSorted = false) } - def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows, false) + def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows, false, false) } /** @@ -124,30 +134,43 @@ trait StreamTest extends QueryTest with Timeouts { */ object CheckLastBatch { def apply[A : Encoder](data: A*): CheckAnswerRows = { + apply(isSorted = false, data: _*) + } + + def apply[A: Encoder](isSorted: Boolean, data: A*): CheckAnswerRows = { val encoder = encoderFor[A] - val toExternalRow = RowEncoder(encoder.schema) - CheckAnswerRows(data.map(d => toExternalRow.fromRow(encoder.toRow(d))), true) + val toExternalRow = RowEncoder(encoder.schema).resolveAndBind() + CheckAnswerRows( + data.map(d => toExternalRow.fromRow(encoder.toRow(d))), + lastOnly = true, + isSorted = isSorted) } - def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows, true) + def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows, true, false) } - case class CheckAnswerRows(expectedAnswer: Seq[Row], lastOnly: Boolean) + case class CheckAnswerRows(expectedAnswer: Seq[Row], lastOnly: Boolean, isSorted: Boolean) extends StreamAction with StreamMustBeRunning { override def toString: String = s"$operatorName: ${expectedAnswer.mkString(",")}" private def operatorName = if (lastOnly) "CheckLastBatch" else "CheckAnswer" } - /** Stops the stream. It must currently be running. */ + /** Stops the stream. It must currently be running. */ case object StopStream extends StreamAction with StreamMustBeRunning - /** Starts the stream, resuming if data has already been processed. It must not be running. */ - case object StartStream extends StreamAction + /** Starts the stream, resuming if data has already been processed. It must not be running. */ + case class StartStream( + trigger: Trigger = ProcessingTime(0), + triggerClock: Clock = new SystemClock) + extends StreamAction + + /** Advance the trigger clock's time manually. */ + case class AdvanceManualClock(timeToAdd: Long) extends StreamAction /** Signals that a failure is expected and should not kill the test. 
*/ case class ExpectFailure[T <: Throwable : ClassTag]() extends StreamAction { val causeClass: Class[T] = implicitly[ClassTag[T]].runtimeClass.asInstanceOf[Class[T]] - override def toString(): String = s"ExpectFailure[${causeClass.getCanonicalName}]" + override def toString(): String = s"ExpectFailure[${causeClass.getName}]" } /** Assert that a body is true */ @@ -178,6 +201,27 @@ trait StreamTest extends QueryTest with Timeouts { } } + case class AssertOnLastQueryStatus(condition: StreamingQueryStatus => Unit) + extends StreamAction + + class StreamManualClock(time: Long = 0L) extends ManualClock(time) { + private var waitStartTime: Option[Long] = None + + override def waitTillTime(targetTime: Long): Long = synchronized { + try { + waitStartTime = Some(getTimeMillis()) + super.waitTillTime(targetTime) + } finally { + waitStartTime = None + } + } + + def isStreamWaitingAt(time: Long): Boolean = synchronized { + waitStartTime == Some(time) + } + } + + /** * Executes the specified actions on the given streaming DataFrame and provides helpful * error messages in the case of failures or incorrect answers. @@ -185,22 +229,25 @@ trait StreamTest extends QueryTest with Timeouts { * Note that if the stream is not explicitly started before an action that requires it to be * running then it will be automatically started before performing any other actions. */ - def testStream(_stream: Dataset[_])(actions: StreamAction*): Unit = { + def testStream( + _stream: Dataset[_], + outputMode: OutputMode = OutputMode.Append)(actions: StreamAction*): Unit = { + val stream = _stream.toDF() var pos = 0 var currentPlan: LogicalPlan = stream.logicalPlan var currentStream: StreamExecution = null var lastStream: StreamExecution = null val awaiting = new mutable.HashMap[Int, Offset]() // source index -> offset to wait for - val sink = new MemorySink(stream.schema) + val sink = new MemorySink(stream.schema, outputMode) @volatile var streamDeathCause: Throwable = null // If the test doesn't manually start the stream, we do it automatically at the beginning. 
val startedManually = - actions.takeWhile(!_.isInstanceOf[StreamMustBeRunning]).contains(StartStream) - val startedTest = if (startedManually) actions else StartStream +: actions + actions.takeWhile(!_.isInstanceOf[StreamMustBeRunning]).exists(_.isInstanceOf[StartStream]) + val startedTest = if (startedManually) actions else StartStream() +: actions def testActions = actions.zipWithIndex.map { case (a, i) => @@ -276,31 +323,58 @@ trait StreamTest extends QueryTest with Timeouts { val testThread = Thread.currentThread() val metadataRoot = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath - + val statusCollector = new QueryStatusCollector + var manualClockExpectedTime = -1L try { + spark.streams.addListener(statusCollector) startedTest.foreach { action => + logInfo(s"Processing test stream action: $action") action match { - case StartStream => + case StartStream(trigger, triggerClock) => verify(currentStream == null, "stream already running") + verify(triggerClock.isInstanceOf[SystemClock] + || triggerClock.isInstanceOf[StreamManualClock], + "Use either SystemClock or StreamManualClock to start the stream") + if (triggerClock.isInstanceOf[StreamManualClock]) { + manualClockExpectedTime = triggerClock.asInstanceOf[StreamManualClock].getTimeMillis() + } lastStream = currentStream currentStream = - sqlContext + spark .streams .startQuery( - StreamExecution.nextName, - metadataRoot, + None, + Some(metadataRoot), stream, sink, - outputMode = outputMode) + outputMode, + trigger = trigger, + triggerClock = triggerClock) .asInstanceOf[StreamExecution] currentStream.microBatchThread.setUncaughtExceptionHandler( new UncaughtExceptionHandler { override def uncaughtException(t: Thread, e: Throwable): Unit = { streamDeathCause = e - testThread.interrupt() } }) + case AdvanceManualClock(timeToAdd) => + verify(currentStream != null, + "can not advance manual clock when a stream is not running") + verify(currentStream.triggerClock.isInstanceOf[StreamManualClock], + s"can not advance clock of type ${currentStream.triggerClock.getClass}") + val clock = currentStream.triggerClock.asInstanceOf[StreamManualClock] + assert(manualClockExpectedTime >= 0) + // Make sure we don't advance ManualClock too early. See SPARK-16002. 
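+                // The eventually block below waits until the trigger thread is actually blocked
+                // in waitTillTime at the expected time, so the clock is never advanced before the
+                // stream is waiting on it (this is the race described in SPARK-16002).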
+ eventually("StreamManualClock has not yet entered the waiting state") { + assert(clock.isStreamWaitingAt(manualClockExpectedTime)) + } + clock.advance(timeToAdd) + manualClockExpectedTime += timeToAdd + verify(clock.getTimeMillis() === manualClockExpectedTime, + s"Unexpected clock time after updating: " + + s"expecting $manualClockExpectedTime, current ${clock.getTimeMillis()}") + case StopStream => verify(currentStream != null, "can not stop a stream that is not running") try failAfter(streamingTimeout) { @@ -326,7 +400,7 @@ trait StreamTest extends QueryTest with Timeouts { case ef: ExpectFailure[_] => verify(currentStream != null, "can not expect failure when stream is not running") try failAfter(streamingTimeout) { - val thrownException = intercept[ContinuousQueryException] { + val thrownException = intercept[StreamingQueryException] { currentStream.awaitTermination() } eventually("microbatch thread not stopped after termination with failure") { @@ -363,6 +437,13 @@ trait StreamTest extends QueryTest with Timeouts { val streamToAssert = Option(currentStream).getOrElse(lastStream) verify({ a.run(); true }, s"Assert failed: ${a.message}") + case a: AssertOnLastQueryStatus => + Eventually.eventually(timeout(streamingTimeout)) { + require(statusCollector.lastTriggerStatus.nonEmpty) + } + val status = statusCollector.lastTriggerStatus.get + verify({ a.condition(status); true }, "Assert on last query status failed") + case a: AddData => try { // Add data and get the source where it was added, and the expected offset of the @@ -397,7 +478,10 @@ trait StreamTest extends QueryTest with Timeouts { failTest("Error adding data", e) } - case CheckAnswerRows(expectedAnswer, lastOnly) => + case e: ExternalAction => + e.runAction() + + case CheckAnswerRows(expectedAnswer, lastOnly, isSorted) => verify(currentStream != null, "stream not running") // Get the map of source index to the current source objects val indexToSource = currentStream @@ -414,12 +498,12 @@ trait StreamTest extends QueryTest with Timeouts { } } - val sparkAnswer = try if (lastOnly) sink.lastBatch else sink.allData catch { + val sparkAnswer = try if (lastOnly) sink.latestBatchData else sink.allData catch { case e: Exception => failTest("Exception while getting data from sink", e) } - QueryTest.sameRows(expectedAnswer, sparkAnswer).foreach { + QueryTest.sameRows(expectedAnswer, sparkAnswer, isSorted).foreach { error => failTest(error) } } @@ -434,24 +518,45 @@ trait StreamTest extends QueryTest with Timeouts { if (currentStream != null && currentStream.microBatchThread.isAlive) { currentStream.stop() } + spark.streams.removeListener(statusCollector) } } + /** * Creates a stress test that randomly starts/stops/adds data/checks the result. * - * @param ds a dataframe that executes + 1 on a stream of integers, returning the result. - * @param addData and add data action that adds the given numbers to the stream, encoding them + * @param ds a dataframe that executes + 1 on a stream of integers, returning the result + * @param addData an add data action that adds the given numbers to the stream, encoding them * as needed + * @param iterations the iteration number + */ + def runStressTest( + ds: Dataset[Int], + addData: Seq[Int] => StreamAction, + iterations: Int = 100): Unit = { + runStressTest(ds, Seq.empty, (data, running) => addData(data), iterations) + } + + /** + * Creates a stress test that randomly starts/stops/adds data/checks the result. 
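+   * The prepare actions are appended once, before the randomized start/stop/add-data loop begins.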
+ * + * @param ds a dataframe that executes + 1 on a stream of integers, returning the result + * @param prepareActions actions need to run before starting the stress test. + * @param addData an add data action that adds the given numbers to the stream, encoding them + * as needed + * @param iterations the iteration number */ def runStressTest( ds: Dataset[Int], - addData: Seq[Int] => StreamAction, - iterations: Int = 100): Unit = { + prepareActions: Seq[StreamAction], + addData: (Seq[Int], Boolean) => StreamAction, + iterations: Int): Unit = { implicit val intEncoder = ExpressionEncoder[Int]() var dataPos = 0 var running = true val actions = new ArrayBuffer[StreamAction]() + actions ++= prepareActions def addCheck() = { actions += CheckAnswer(1 to dataPos: _*) } @@ -459,7 +564,7 @@ trait StreamTest extends QueryTest with Timeouts { val numItems = Random.nextInt(10) val data = dataPos until (dataPos + numItems) dataPos += numItems - actions += addData(data) + actions += addData(data, running) } (1 to iterations).foreach { i => @@ -470,7 +575,7 @@ trait StreamTest extends QueryTest with Timeouts { addRandomData() case _ => // StartStream - actions += StartStream + actions += StartStream() running = true } } else { @@ -488,7 +593,7 @@ trait StreamTest extends QueryTest with Timeouts { } } } - if(!running) { actions += StartStream } + if(!running) { actions += StartStream() } addCheck() testStream(ds)(actions: _*) } @@ -508,7 +613,7 @@ trait StreamTest extends QueryTest with Timeouts { case class ExpectException[E <: Exception]()(implicit val t: ClassTag[E]) extends ExpectedBehavior - private val DEFAULT_TEST_TIMEOUT = 1 second + private val DEFAULT_TEST_TIMEOUT = 1.second def test( expectedBehavior: ExpectedBehavior, @@ -536,7 +641,7 @@ trait StreamTest extends QueryTest with Timeouts { case e: ExpectException[_] => val thrownException = withClue(s"Did not throw ${e.t.runtimeClass.getSimpleName} when expected.") { - intercept[ContinuousQueryException] { + intercept[StreamingQueryException] { failAfter(testTimeout) { awaitTermFunc() } @@ -547,4 +652,58 @@ trait StreamTest extends QueryTest with Timeouts { } } } + + + class QueryStatusCollector extends StreamingQueryListener { + // to catch errors in the async listener events + @volatile private var asyncTestWaiter = new Waiter + + @volatile var startStatus: StreamingQueryStatus = null + @volatile var terminationStatus: StreamingQueryStatus = null + @volatile var terminationException: Option[String] = null + + private val progressStatuses = new mutable.ArrayBuffer[StreamingQueryStatus] + + /** Get the info of the last trigger that processed data */ + def lastTriggerStatus: Option[StreamingQueryStatus] = synchronized { + progressStatuses.filter { i => + i.triggerDetails.get("isTriggerActive").toBoolean == false && + i.triggerDetails.get("isDataPresentInTrigger").toBoolean == true + }.lastOption + } + + def reset(): Unit = { + startStatus = null + terminationStatus = null + progressStatuses.clear() + asyncTestWaiter = new Waiter + } + + def checkAsyncErrors(): Unit = { + asyncTestWaiter.await(timeout(10 seconds)) + } + + + override def onQueryStarted(queryStarted: QueryStartedEvent): Unit = { + asyncTestWaiter { + startStatus = queryStarted.queryStatus + } + } + + override def onQueryProgress(queryProgress: QueryProgressEvent): Unit = { + asyncTestWaiter { + assert(startStatus != null, "onQueryProgress called before onQueryStarted") + synchronized { progressStatuses += queryProgress.queryStatus } + } + } + + override def 
onQueryTerminated(queryTerminated: QueryTerminatedEvent): Unit = { + asyncTestWaiter { + assert(startStatus != null, "onQueryTerminated called before onQueryStarted") + terminationStatus = queryTerminated.queryStatus + terminationException = queryTerminated.exception + } + asyncTestWaiter.dismiss() + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index fa3b122f6d2da..e59b5491f90b6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -17,25 +17,31 @@ package org.apache.spark.sql.streaming +import org.scalatest.BeforeAndAfterAll + import org.apache.spark.SparkException -import org.apache.spark.sql.StreamTest -import org.apache.spark.sql.catalyst.analysis.Update +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.InternalOutputModes._ +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.expressions.scala.typed +import org.apache.spark.sql.execution.streaming.state.StateStore +import org.apache.spark.sql.expressions.scalalang.typed import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.SharedSQLContext object FailureSinglton { var firstTime = true } -class StreamingAggregationSuite extends StreamTest with SharedSQLContext { +class StreamingAggregationSuite extends StreamTest with BeforeAndAfterAll { - import testImplicits._ + override def afterAll(): Unit = { + super.afterAll() + StateStore.stop() + } - override val outputMode = Update + import testImplicits._ - test("simple count") { + test("simple count, update mode") { val inputData = MemoryStream[Int] val aggregated = @@ -44,13 +50,13 @@ class StreamingAggregationSuite extends StreamTest with SharedSQLContext { .agg(count("*")) .as[(Int, Long)] - testStream(aggregated)( + testStream(aggregated, Update)( AddData(inputData, 3), CheckLastBatch((3, 1)), AddData(inputData, 3, 2), CheckLastBatch((3, 2), (2, 1)), StopStream, - StartStream, + StartStream(), AddData(inputData, 3, 2, 1), CheckLastBatch((3, 3), (2, 2), (1, 1)), // By default we run in new tuple mode. 
@@ -59,39 +65,138 @@ class StreamingAggregationSuite extends StreamTest with SharedSQLContext { ) } - test("multiple keys") { + test("simple count, complete mode") { val inputData = MemoryStream[Int] val aggregated = inputData.toDF() - .groupBy($"value", $"value" + 1) + .groupBy($"value") .agg(count("*")) - .as[(Int, Int, Long)] + .as[(Int, Long)] - testStream(aggregated)( - AddData(inputData, 1, 2), - CheckLastBatch((1, 2, 1), (2, 3, 1)), - AddData(inputData, 1, 2), - CheckLastBatch((1, 2, 2), (2, 3, 2)) + testStream(aggregated, Complete)( + AddData(inputData, 3), + CheckLastBatch((3, 1)), + AddData(inputData, 2), + CheckLastBatch((3, 1), (2, 1)), + StopStream, + StartStream(), + AddData(inputData, 3, 2, 1), + CheckLastBatch((3, 2), (2, 2), (1, 1)), + AddData(inputData, 4, 4, 4, 4), + CheckLastBatch((4, 4), (3, 2), (2, 2), (1, 1)) ) } - test("multiple aggregations") { + test("simple count, append mode") { val inputData = MemoryStream[Int] val aggregated = inputData.toDF() .groupBy($"value") - .agg(count("*") as 'count) - .groupBy($"value" % 2) - .agg(sum($"count")) + .agg(count("*")) .as[(Int, Long)] - testStream(aggregated)( - AddData(inputData, 1, 2, 3, 4), - CheckLastBatch((0, 2), (1, 2)), - AddData(inputData, 1, 3, 5), - CheckLastBatch((1, 5)) + val e = intercept[AnalysisException] { + testStream(aggregated, Append)() + } + Seq("append", "not supported").foreach { m => + assert(e.getMessage.toLowerCase.contains(m.toLowerCase)) + } + } + + test("sort after aggregate in complete mode") { + val inputData = MemoryStream[Int] + + val aggregated = + inputData.toDF() + .groupBy($"value") + .agg(count("*")) + .toDF("value", "count") + .orderBy($"count".desc) + .as[(Int, Long)] + + testStream(aggregated, Complete)( + AddData(inputData, 3), + CheckLastBatch(isSorted = true, (3, 1)), + AddData(inputData, 2, 3), + CheckLastBatch(isSorted = true, (3, 2), (2, 1)), + StopStream, + StartStream(), + AddData(inputData, 3, 2, 1), + CheckLastBatch(isSorted = true, (3, 3), (2, 2), (1, 1)), + AddData(inputData, 4, 4, 4, 4), + CheckLastBatch(isSorted = true, (4, 4), (3, 3), (2, 2), (1, 1)) + ) + } + + test("state metrics") { + val inputData = MemoryStream[Int] + + val aggregated = + inputData.toDS() + .flatMap(x => Seq(x, x + 1)) + .toDF("value") + .groupBy($"value") + .agg(count("*")) + .as[(Int, Long)] + + implicit class RichStreamExecution(query: StreamExecution) { + def stateNodes: Seq[SparkPlan] = { + query.lastExecution.executedPlan.collect { + case p if p.isInstanceOf[StateStoreSaveExec] => p + } + } + } + + // Test with Update mode + testStream(aggregated, Update)( + AddData(inputData, 1), + CheckLastBatch((1, 1), (2, 1)), + AssertOnQuery { _.stateNodes.size === 1 }, + AssertOnQuery { _.stateNodes.head.metrics.get("numOutputRows").get.value === 2 }, + AssertOnQuery { _.stateNodes.head.metrics.get("numUpdatedStateRows").get.value === 2 }, + AssertOnQuery { _.stateNodes.head.metrics.get("numTotalStateRows").get.value === 2 }, + AddData(inputData, 2, 3), + CheckLastBatch((2, 2), (3, 2), (4, 1)), + AssertOnQuery { _.stateNodes.size === 1 }, + AssertOnQuery { _.stateNodes.head.metrics.get("numOutputRows").get.value === 3 }, + AssertOnQuery { _.stateNodes.head.metrics.get("numUpdatedStateRows").get.value === 3 }, + AssertOnQuery { _.stateNodes.head.metrics.get("numTotalStateRows").get.value === 4 } + ) + + // Test with Complete mode + inputData.reset() + testStream(aggregated, Complete)( + AddData(inputData, 1), + CheckLastBatch((1, 1), (2, 1)), + AssertOnQuery { _.stateNodes.size === 1 }, + AssertOnQuery 
{ _.stateNodes.head.metrics.get("numOutputRows").get.value === 2 }, + AssertOnQuery { _.stateNodes.head.metrics.get("numUpdatedStateRows").get.value === 2 }, + AssertOnQuery { _.stateNodes.head.metrics.get("numTotalStateRows").get.value === 2 }, + AddData(inputData, 2, 3), + CheckLastBatch((1, 1), (2, 2), (3, 2), (4, 1)), + AssertOnQuery { _.stateNodes.size === 1 }, + AssertOnQuery { _.stateNodes.head.metrics.get("numOutputRows").get.value === 4 }, + AssertOnQuery { _.stateNodes.head.metrics.get("numUpdatedStateRows").get.value === 3 }, + AssertOnQuery { _.stateNodes.head.metrics.get("numTotalStateRows").get.value === 4 } + ) + } + + test("multiple keys") { + val inputData = MemoryStream[Int] + + val aggregated = + inputData.toDF() + .groupBy($"value", $"value" + 1) + .agg(count("*")) + .as[(Int, Int, Long)] + + testStream(aggregated, Update)( + AddData(inputData, 1, 2), + CheckLastBatch((1, 2, 1), (2, 3, 1)), + AddData(inputData, 1, 2), + CheckLastBatch((1, 2, 2), (2, 3, 2)) ) } @@ -112,11 +217,11 @@ class StreamingAggregationSuite extends StreamTest with SharedSQLContext { .agg(count("*")) .as[(Int, Long)] - testStream(aggregated)( - StartStream, + testStream(aggregated, Update)( + StartStream(), AddData(inputData, 1, 2, 3, 4), ExpectFailure[SparkException](), - StartStream, + StartStream(), CheckLastBatch((1, 1), (2, 1), (3, 1), (4, 1)) ) } @@ -125,7 +230,7 @@ class StreamingAggregationSuite extends StreamTest with SharedSQLContext { val inputData = MemoryStream[(String, Int)] val aggregated = inputData.toDS().groupByKey(_._1).agg(typed.sumLong(_._2)) - testStream(aggregated)( + testStream(aggregated, Update)( AddData(inputData, ("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)), CheckLastBatch(("a", 30), ("b", 3), ("c", 1)) ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala new file mode 100644 index 0000000000000..ff843865a017e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala @@ -0,0 +1,253 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
*/ + +package org.apache.spark.sql.streaming + +import org.scalactic.TolerantNumerics +import org.scalatest.BeforeAndAfter +import org.scalatest.PrivateMethodTester._ + +import org.apache.spark.SparkException +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.functions._ +import org.apache.spark.util.{JsonProtocol, ManualClock} + + +class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { + + import testImplicits._ + import StreamingQueryListenerSuite._ + + // To make === between double tolerate inexact values + implicit val doubleEquality = TolerantNumerics.tolerantDoubleEquality(0.01) + + after { + spark.streams.active.foreach(_.stop()) + assert(spark.streams.active.isEmpty) + assert(addedListeners.isEmpty) + // Make sure we don't leak any events to the next test + } + + test("single listener, check trigger statuses") { + import StreamingQueryListenerSuite._ + clock = new StreamManualClock + + /** Custom MemoryStream that waits for manual clock to reach a time */ + val inputData = new MemoryStream[Int](0, sqlContext) { + // Wait for manual clock to be 100 first time there is data + override def getOffset: Option[Offset] = { + val offset = super.getOffset + if (offset.nonEmpty) { + clock.waitTillTime(100) + } + offset + } + + // Wait for manual clock to be 300 first time there is data + override def getBatch(start: Option[Offset], end: Offset): DataFrame = { + clock.waitTillTime(300) + super.getBatch(start, end) + } + } + + // This is to make sure that the query waits for manual clock to be 600 first time there is data + val mapped = inputData.toDS().agg(count("*")).as[Long].coalesce(1).map { x => + clock.waitTillTime(600) + x + } + + testStream(mapped, OutputMode.Complete)( + StartStream(triggerClock = clock), + AddData(inputData, 1, 2), + AdvanceManualClock(100), // unblock getOffset, will block on getBatch + AdvanceManualClock(200), // unblock getBatch, will block on computation + AdvanceManualClock(300), // unblock computation + AssertOnQuery { _ => clock.getTimeMillis() === 600 }, + AssertOnLastQueryStatus { status: StreamingQueryStatus => + // Check the correctness of the trigger info of the last completed batch reported by + // onQueryProgress + assert(status.triggerDetails.containsKey("triggerId")) + assert(status.triggerDetails.get("isTriggerActive") === "false") + assert(status.triggerDetails.get("isDataPresentInTrigger") === "true") + + assert(status.triggerDetails.get("timestamp.triggerStart") === "0") + assert(status.triggerDetails.get("timestamp.afterGetOffset") === "100") + assert(status.triggerDetails.get("timestamp.afterGetBatch") === "300") + assert(status.triggerDetails.get("timestamp.triggerFinish") === "600") + + assert(status.triggerDetails.get("latency.getOffset.total") === "100") + assert(status.triggerDetails.get("latency.getBatch.total") === "200") + assert(status.triggerDetails.get("latency.optimizer") === "0") + assert(status.triggerDetails.get("latency.offsetLogWrite") === "0") + assert(status.triggerDetails.get("latency.fullTrigger") === "600") + + assert(status.triggerDetails.get("numRows.input.total") === "2") + assert(status.triggerDetails.get("numRows.state.aggregation1.total") === "1") + assert(status.triggerDetails.get("numRows.state.aggregation1.updated") === "1") + + assert(status.sourceStatuses.length === 1) + assert(status.sourceStatuses(0).triggerDetails.containsKey("triggerId")) + assert(status.sourceStatuses(0).triggerDetails.get("latency.getOffset.source") ===
"100") + assert(status.sourceStatuses(0).triggerDetails.get("latency.getBatch.source") === "200") + assert(status.sourceStatuses(0).triggerDetails.get("numRows.input.source") === "2") + }, + CheckAnswer(2) + ) + } + + test("adding and removing listener") { + def isListenerActive(listener: QueryStatusCollector): Boolean = { + listener.reset() + testStream(MemoryStream[Int].toDS)( + StartStream(), + StopStream + ) + listener.startStatus != null + } + + try { + val listener1 = new QueryStatusCollector + val listener2 = new QueryStatusCollector + + spark.streams.addListener(listener1) + assert(isListenerActive(listener1) === true) + assert(isListenerActive(listener2) === false) + spark.streams.addListener(listener2) + assert(isListenerActive(listener1) === true) + assert(isListenerActive(listener2) === true) + spark.streams.removeListener(listener1) + assert(isListenerActive(listener1) === false) + assert(isListenerActive(listener2) === true) + } finally { + addedListeners.foreach(spark.streams.removeListener) + } + } + + test("event ordering") { + val listener = new QueryStatusCollector + withListenerAdded(listener) { + for (i <- 1 to 100) { + listener.reset() + require(listener.startStatus === null) + testStream(MemoryStream[Int].toDS)( + StartStream(), + Assert(listener.startStatus !== null, "onQueryStarted not called before query returned"), + StopStream, + Assert { listener.checkAsyncErrors() } + ) + } + } + } + + testQuietly("exception should be reported in QueryTerminated") { + val listener = new QueryStatusCollector + withListenerAdded(listener) { + val input = MemoryStream[Int] + testStream(input.toDS.map(_ / 0))( + StartStream(), + AddData(input, 1), + ExpectFailure[SparkException](), + Assert { + spark.sparkContext.listenerBus.waitUntilEmpty(10000) + assert(listener.terminationStatus !== null) + assert(listener.terminationException.isDefined) + // Make sure that the exception message reported through listener + // contains the actual exception and relevant stack trace + assert(!listener.terminationException.get.contains("StreamingQueryException")) + assert(listener.terminationException.get.contains("java.lang.ArithmeticException")) + assert(listener.terminationException.get.contains("StreamingQueryListenerSuite")) + } + ) + } + } + + test("QueryStarted serialization") { + val queryStarted = new StreamingQueryListener.QueryStartedEvent(StreamingQueryStatus.testStatus) + val json = JsonProtocol.sparkEventToJson(queryStarted) + val newQueryStarted = JsonProtocol.sparkEventFromJson(json) + .asInstanceOf[StreamingQueryListener.QueryStartedEvent] + assertStreamingQueryInfoEquals(queryStarted.queryStatus, newQueryStarted.queryStatus) + } + + test("QueryProgress serialization") { + val queryProcess = new StreamingQueryListener.QueryProgressEvent( + StreamingQueryStatus.testStatus) + val json = JsonProtocol.sparkEventToJson(queryProcess) + val newQueryProcess = JsonProtocol.sparkEventFromJson(json) + .asInstanceOf[StreamingQueryListener.QueryProgressEvent] + assertStreamingQueryInfoEquals(queryProcess.queryStatus, newQueryProcess.queryStatus) + } + + test("QueryTerminated serialization") { + val exception = new RuntimeException("exception") + val queryQueryTerminated = new StreamingQueryListener.QueryTerminatedEvent( + StreamingQueryStatus.testStatus, + Some(exception.getMessage)) + val json = + JsonProtocol.sparkEventToJson(queryQueryTerminated) + val newQueryTerminated = JsonProtocol.sparkEventFromJson(json) + .asInstanceOf[StreamingQueryListener.QueryTerminatedEvent] + 
assertStreamingQueryInfoEquals(queryQueryTerminated.queryStatus, newQueryTerminated.queryStatus) + assert(queryQueryTerminated.exception === newQueryTerminated.exception) + } + + private def assertStreamingQueryInfoEquals( + expected: StreamingQueryStatus, + actual: StreamingQueryStatus): Unit = { + assert(expected.name === actual.name) + assert(expected.sourceStatuses.size === actual.sourceStatuses.size) + expected.sourceStatuses.zip(actual.sourceStatuses).foreach { + case (expectedSource, actualSource) => + assertSourceStatus(expectedSource, actualSource) + } + assertSinkStatus(expected.sinkStatus, actual.sinkStatus) + } + + private def assertSourceStatus(expected: SourceStatus, actual: SourceStatus): Unit = { + assert(expected.description === actual.description) + assert(expected.offsetDesc === actual.offsetDesc) + } + + private def assertSinkStatus(expected: SinkStatus, actual: SinkStatus): Unit = { + assert(expected.description === actual.description) + assert(expected.offsetDesc === actual.offsetDesc) + } + + private def withListenerAdded(listener: StreamingQueryListener)(body: => Unit): Unit = { + try { + failAfter(streamingTimeout) { + spark.streams.addListener(listener) + body + } + } finally { + spark.streams.removeListener(listener) + } + } + + private def addedListeners(): Array[StreamingQueryListener] = { + val listenerBusMethod = + PrivateMethod[StreamingQueryListenerBus]('listenerBus) + val listenerBus = spark.streams invokePrivate listenerBusMethod() + listenerBus.listeners.toArray.map(_.asInstanceOf[StreamingQueryListener]) + } +} + +object StreamingQueryListenerSuite { + // Singleton reference to clock that does not get serialized in task closures + @volatile var clock: ManualClock = null +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala similarity index 79% rename from sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala index 3d69c8a18711b..41ffd56cf1290 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala @@ -28,12 +28,11 @@ import org.scalatest.time.Span import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkException -import org.apache.spark.sql.{ContinuousQuery, Dataset, StreamTest} +import org.apache.spark.sql.Dataset import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils -class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with BeforeAndAfter { +class StreamingQueryManagerSuite extends StreamTest with BeforeAndAfter { import AwaitTerminationTester._ import testImplicits._ @@ -41,13 +40,13 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with override val streamingTimeout = 20.seconds before { - assert(sqlContext.streams.active.isEmpty) - sqlContext.streams.resetTerminated() + assert(spark.streams.active.isEmpty) + spark.streams.resetTerminated() } after { - assert(sqlContext.streams.active.isEmpty) - sqlContext.streams.resetTerminated() + assert(spark.streams.active.isEmpty) + spark.streams.resetTerminated() } testQuietly("listing") { @@ -57,26 +56,18 @@ class ContinuousQueryManagerSuite extends StreamTest 
with SharedSQLContext with withQueriesOn(ds1, ds2, ds3) { queries => require(queries.size === 3) - assert(sqlContext.streams.active.toSet === queries.toSet) + assert(spark.streams.active.toSet === queries.toSet) val (q1, q2, q3) = (queries(0), queries(1), queries(2)) - assert(sqlContext.streams.get(q1.name).eq(q1)) - assert(sqlContext.streams.get(q2.name).eq(q2)) - assert(sqlContext.streams.get(q3.name).eq(q3)) - intercept[IllegalArgumentException] { - sqlContext.streams.get("non-existent-name") - } - + assert(spark.streams.get(q1.id).eq(q1)) + assert(spark.streams.get(q2.id).eq(q2)) + assert(spark.streams.get(q3.id).eq(q3)) + assert(spark.streams.get(-1) === null) // non-existent id q1.stop() - assert(sqlContext.streams.active.toSet === Set(q2, q3)) - val ex1 = withClue("no error while getting non-active query") { - intercept[IllegalArgumentException] { - sqlContext.streams.get(q1.name) - } - } - assert(ex1.getMessage.contains(q1.name), "error does not contain name of query to be fetched") - assert(sqlContext.streams.get(q2.name).eq(q2)) + assert(spark.streams.active.toSet === Set(q2, q3)) + assert(spark.streams.get(q1.id) === null) + assert(spark.streams.get(q2.id).eq(q2)) m2.addData(0) // q2 should terminate with error @@ -84,13 +75,8 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with require(!q2.isActive) require(q2.exception.isDefined) } - withClue("no error while getting non-active query") { - intercept[IllegalArgumentException] { - sqlContext.streams.get(q2.name).eq(q2) - } - } - - assert(sqlContext.streams.active.toSet === Set(q3)) + assert(spark.streams.get(q2.id) === null) + assert(spark.streams.active.toSet === Set(q3)) } } @@ -98,7 +84,7 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with val datasets = Seq.fill(5)(makeDataset._2) withQueriesOn(datasets: _*) { queries => require(queries.size === datasets.size) - assert(sqlContext.streams.active.toSet === queries.toSet) + assert(spark.streams.active.toSet === queries.toSet) // awaitAnyTermination should be blocking testAwaitAnyTermination(ExpectBlocked) @@ -112,7 +98,7 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with testAwaitAnyTermination(ExpectNotBlocked) // Resetting termination should make awaitAnyTermination() blocking again - sqlContext.streams.resetTerminated() + spark.streams.resetTerminated() testAwaitAnyTermination(ExpectBlocked) // Terminate a query asynchronously with exception and see awaitAnyTermination throws @@ -125,7 +111,7 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with testAwaitAnyTermination(ExpectException[SparkException]) // Resetting termination should make awaitAnyTermination() blocking again - sqlContext.streams.resetTerminated() + spark.streams.resetTerminated() testAwaitAnyTermination(ExpectBlocked) // Terminate multiple queries, one with failure and see whether awaitAnyTermination throws @@ -144,7 +130,7 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with val datasets = Seq.fill(6)(makeDataset._2) withQueriesOn(datasets: _*) { queries => require(queries.size === datasets.size) - assert(sqlContext.streams.active.toSet === queries.toSet) + assert(spark.streams.active.toSet === queries.toSet) // awaitAnyTermination should be blocking or non-blocking depending on timeout values testAwaitAnyTermination( @@ -173,7 +159,7 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with ExpectNotBlocked, awaitTimeout = 4 seconds, 
expectedReturnedValue = true) // Resetting termination should make awaitAnyTermination() blocking again - sqlContext.streams.resetTerminated() + spark.streams.resetTerminated() testAwaitAnyTermination( ExpectBlocked, awaitTimeout = 4 seconds, @@ -196,7 +182,7 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with testBehaviorFor = 4 seconds) // Terminate a query asynchronously outside the timeout, awaitAnyTerm should be blocked - sqlContext.streams.resetTerminated() + spark.streams.resetTerminated() val q3 = stopRandomQueryAsync(2 seconds, withError = true) testAwaitAnyTermination( ExpectNotBlocked, @@ -214,7 +200,7 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with // Terminate multiple queries, one with failure and see whether awaitAnyTermination throws // the exception - sqlContext.streams.resetTerminated() + spark.streams.resetTerminated() val q4 = stopRandomQueryAsync(10 milliseconds, withError = false) testAwaitAnyTermination( @@ -228,24 +214,24 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with } - /** Run a body of code by defining a query each on multiple datasets */ - private def withQueriesOn(datasets: Dataset[_]*)(body: Seq[ContinuousQuery] => Unit): Unit = { + /** Run a body of code by defining a query on each dataset */ + private def withQueriesOn(datasets: Dataset[_]*)(body: Seq[StreamingQuery] => Unit): Unit = { failAfter(streamingTimeout) { val queries = withClue("Error starting queries") { - datasets.map { ds => + datasets.zipWithIndex.map { case (ds, i) => @volatile var query: StreamExecution = null try { val df = ds.toDF val metadataRoot = - Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath - query = sqlContext - .streams - .startQuery( - StreamExecution.nextName, - metadataRoot, - df, - new MemorySink(df.schema)) - .asInstanceOf[StreamExecution] + Utils.createTempDir(namePrefix = "streaming.checkpoint").getCanonicalPath + query = + df.writeStream + .format("memory") + .queryName(s"query$i") + .option("checkpointLocation", metadataRoot) + .outputMode("append") + .start() + .asInstanceOf[StreamExecution] } catch { case NonFatal(e) => if (query != null) query.stop() @@ -272,10 +258,10 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with def awaitTermFunc(): Unit = { if (awaitTimeout != null && awaitTimeout.toMillis > 0) { - val returnedValue = sqlContext.streams.awaitAnyTermination(awaitTimeout.toMillis) + val returnedValue = spark.streams.awaitAnyTermination(awaitTimeout.toMillis) assert(returnedValue === expectedReturnedValue, "Returned value does not match expected") } else { - sqlContext.streams.awaitAnyTermination() + spark.streams.awaitAnyTermination() } } @@ -283,11 +269,11 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with } /** Stop a random active query either with `stop()` or with an error */ - private def stopRandomQueryAsync(stopAfter: Span, withError: Boolean): ContinuousQuery = { + private def stopRandomQueryAsync(stopAfter: Span, withError: Boolean): StreamingQuery = { import scala.concurrent.ExecutionContext.Implicits.global - val activeQueries = sqlContext.streams.active + val activeQueries = spark.streams.active val queryToStop = activeQueries(Random.nextInt(activeQueries.length)) Future { Thread.sleep(stopAfter.toMillis) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusSuite.scala new file mode 100644 index 0000000000000..1a98cf2ba74e6 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusSuite.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import org.apache.spark.SparkFunSuite + +class StreamingQueryStatusSuite extends SparkFunSuite { + test("toString") { + assert(StreamingQueryStatus.testStatus.sourceStatuses(0).toString === + """ + |Status of source MySource1 + | Available offset: #0 + | Input rate: 15.5 rows/sec + | Processing rate: 23.5 rows/sec + | Trigger details: + | numRows.input.source: 100 + | latency.getOffset.source: 10 + | latency.getBatch.source: 20 + """.stripMargin.trim, "SourceStatus.toString does not match") + + assert(StreamingQueryStatus.testStatus.sinkStatus.toString === + """ + |Status of sink MySink + | Committed offsets: [#1, -] + """.stripMargin.trim, "SinkStatus.toString does not match") + + assert(StreamingQueryStatus.testStatus.toString === + """ + |Status of query 'query' + | Query id: 1 + | Status timestamp: 123 + | Input rate: 15.5 rows/sec + | Processing rate 23.5 rows/sec + | Latency: 345.0 ms + | Trigger details: + | isDataPresentInTrigger: true + | isTriggerActive: true + | latency.getBatch.total: 20 + | latency.getOffset.total: 10 + | numRows.input.total: 100 + | triggerId: 5 + | Source statuses [1 source]: + | Source 1 - MySource1 + | Available offset: #0 + | Input rate: 15.5 rows/sec + | Processing rate: 23.5 rows/sec + | Trigger details: + | numRows.input.source: 100 + | latency.getOffset.source: 10 + | latency.getBatch.source: 20 + | Sink status - MySink + | Committed offsets: [#1, -] + """.stripMargin.trim, "StreamingQueryStatus.toString does not match") + + } + + test("json") { + assert(StreamingQueryStatus.testStatus.json === + """ + |{"sourceStatuses":[{"description":"MySource1","offsetDesc":"#0","inputRate":15.5, + |"processingRate":23.5,"triggerDetails":{"numRows.input.source":"100", + |"latency.getOffset.source":"10","latency.getBatch.source":"20"}}], + |"sinkStatus":{"description":"MySink","offsetDesc":"[#1, -]"}} + """.stripMargin.replace("\n", "").trim) + } + + test("prettyJson") { + assert( + StreamingQueryStatus.testStatus.prettyJson === + """ + |{ + | "sourceStatuses" : [ { + | "description" : "MySource1", + | "offsetDesc" : "#0", + | "inputRate" : 15.5, + | "processingRate" : 23.5, + | "triggerDetails" : { + | "numRows.input.source" : "100", + | "latency.getOffset.source" : "10", + | "latency.getBatch.source" : "20" + | } + | } ], + | "sinkStatus" : { + | "description" : "MySink", + | "offsetDesc" : "[#1, -]" + | } + |} + """.stripMargin.trim) + } +} diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala new file mode 100644 index 0000000000000..464c443beb6e7 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -0,0 +1,366 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import org.scalactic.TolerantNumerics +import org.scalatest.concurrent.Eventually._ +import org.scalatest.BeforeAndAfter + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.streaming.StreamingQueryListener._ +import org.apache.spark.sql.types.StructType +import org.apache.spark.SparkException +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.util.Utils + + +class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging { + + import AwaitTerminationTester._ + import testImplicits._ + + // To make === between double tolerate inexact values + implicit val doubleEquality = TolerantNumerics.tolerantDoubleEquality(0.01) + + after { + sqlContext.streams.active.foreach(_.stop()) + } + + test("names unique across active queries, ids unique across all started queries") { + val inputData = MemoryStream[Int] + val mapped = inputData.toDS().map { 6 / _} + + def startQuery(queryName: String): StreamingQuery = { + val metadataRoot = Utils.createTempDir(namePrefix = "streaming.checkpoint").getCanonicalPath + val writer = mapped.writeStream + writer + .queryName(queryName) + .format("memory") + .option("checkpointLocation", metadataRoot) + .start() + } + + val q1 = startQuery("q1") + assert(q1.name === "q1") + + // Verify that another query with same name cannot be started + val e1 = intercept[IllegalArgumentException] { + startQuery("q1") + } + Seq("q1", "already active").foreach { s => assert(e1.getMessage.contains(s)) } + + // Verify q1 was unaffected by the above exception and stop it + assert(q1.isActive) + q1.stop() + + // Verify another query can be started with name q1, but will have different id + val q2 = startQuery("q1") + assert(q2.name === "q1") + assert(q2.id !== q1.id) + q2.stop() + } + + testQuietly("lifecycle states and awaitTermination") { + val inputData = MemoryStream[Int] + val mapped = inputData.toDS().map { 6 / _} + + testStream(mapped)( + AssertOnQuery(_.isActive === true), + AssertOnQuery(_.exception.isEmpty), + AddData(inputData, 1, 2), + CheckAnswer(6, 3), + TestAwaitTermination(ExpectBlocked), + TestAwaitTermination(ExpectBlocked, timeoutMs = 2000), + TestAwaitTermination(ExpectNotBlocked, timeoutMs = 10, expectedReturnValue = false), + StopStream, + AssertOnQuery(_.isActive === false), + 
AssertOnQuery(_.exception.isEmpty), + TestAwaitTermination(ExpectNotBlocked), + TestAwaitTermination(ExpectNotBlocked, timeoutMs = 2000, expectedReturnValue = true), + TestAwaitTermination(ExpectNotBlocked, timeoutMs = 10, expectedReturnValue = true), + StartStream(), + AssertOnQuery(_.isActive === true), + AddData(inputData, 0), + ExpectFailure[SparkException], + AssertOnQuery(_.isActive === false), + TestAwaitTermination(ExpectException[SparkException]), + TestAwaitTermination(ExpectException[SparkException], timeoutMs = 2000), + TestAwaitTermination(ExpectException[SparkException], timeoutMs = 10), + AssertOnQuery( + q => + q.exception.get.startOffset.get === q.committedOffsets.toCompositeOffset(Seq(inputData)), + "incorrect start offset on exception") + ) + } + + testQuietly("query statuses") { + val inputData = MemoryStream[Int] + val mapped = inputData.toDS().map(6 / _) + testStream(mapped)( + AssertOnQuery(q => q.status.name === q.name), + AssertOnQuery(q => q.status.id === q.id), + AssertOnQuery(_.status.timestamp <= System.currentTimeMillis), + AssertOnQuery(_.status.inputRate === 0.0), + AssertOnQuery(_.status.processingRate === 0.0), + AssertOnQuery(_.status.sourceStatuses.length === 1), + AssertOnQuery(_.status.sourceStatuses(0).description.contains("Memory")), + AssertOnQuery(_.status.sourceStatuses(0).offsetDesc === "-"), + AssertOnQuery(_.status.sourceStatuses(0).inputRate === 0.0), + AssertOnQuery(_.status.sourceStatuses(0).processingRate === 0.0), + AssertOnQuery(_.status.sinkStatus.description.contains("Memory")), + AssertOnQuery(_.status.sinkStatus.offsetDesc === CompositeOffset(None :: Nil).toString), + AssertOnQuery(_.sourceStatuses(0).description.contains("Memory")), + AssertOnQuery(_.sourceStatuses(0).offsetDesc === "-"), + AssertOnQuery(_.sourceStatuses(0).inputRate === 0.0), + AssertOnQuery(_.sourceStatuses(0).processingRate === 0.0), + AssertOnQuery(_.sinkStatus.description.contains("Memory")), + AssertOnQuery(_.sinkStatus.offsetDesc === new CompositeOffset(None :: Nil).toString), + + AddData(inputData, 1, 2), + CheckAnswer(6, 3), + AssertOnQuery(_.status.timestamp <= System.currentTimeMillis), + AssertOnQuery(_.status.inputRate >= 0.0), + AssertOnQuery(_.status.processingRate >= 0.0), + AssertOnQuery(_.status.sourceStatuses.length === 1), + AssertOnQuery(_.status.sourceStatuses(0).description.contains("Memory")), + AssertOnQuery(_.status.sourceStatuses(0).offsetDesc === LongOffset(0).toString), + AssertOnQuery(_.status.sourceStatuses(0).inputRate >= 0.0), + AssertOnQuery(_.status.sourceStatuses(0).processingRate >= 0.0), + AssertOnQuery(_.status.sinkStatus.description.contains("Memory")), + AssertOnQuery(_.status.sinkStatus.offsetDesc === + CompositeOffset.fill(LongOffset(0)).toString), + AssertOnQuery(_.sourceStatuses(0).offsetDesc === LongOffset(0).toString), + AssertOnQuery(_.sourceStatuses(0).inputRate >= 0.0), + AssertOnQuery(_.sourceStatuses(0).processingRate >= 0.0), + AssertOnQuery(_.sinkStatus.offsetDesc === CompositeOffset.fill(LongOffset(0)).toString), + + AddData(inputData, 1, 2), + CheckAnswer(6, 3, 6, 3), + AssertOnQuery(_.status.sourceStatuses(0).offsetDesc === LongOffset(1).toString), + AssertOnQuery(_.status.sinkStatus.offsetDesc === + CompositeOffset.fill(LongOffset(1)).toString), + AssertOnQuery(_.sourceStatuses(0).offsetDesc === LongOffset(1).toString), + AssertOnQuery(_.sinkStatus.offsetDesc === CompositeOffset.fill(LongOffset(1)).toString), + + StopStream, + AssertOnQuery(_.status.inputRate === 0.0), + AssertOnQuery(_.status.processingRate 
=== 0.0), + AssertOnQuery(_.status.sourceStatuses.length === 1), + AssertOnQuery(_.status.sourceStatuses(0).offsetDesc === LongOffset(1).toString), + AssertOnQuery(_.status.sourceStatuses(0).inputRate === 0.0), + AssertOnQuery(_.status.sourceStatuses(0).processingRate === 0.0), + AssertOnQuery(_.status.sinkStatus.offsetDesc === + CompositeOffset.fill(LongOffset(1)).toString), + AssertOnQuery(_.sourceStatuses(0).offsetDesc === LongOffset(1).toString), + AssertOnQuery(_.sourceStatuses(0).inputRate === 0.0), + AssertOnQuery(_.sourceStatuses(0).processingRate === 0.0), + AssertOnQuery(_.sinkStatus.offsetDesc === CompositeOffset.fill(LongOffset(1)).toString), + AssertOnQuery(_.status.triggerDetails.isEmpty), + + StartStream(), + AddData(inputData, 0), + ExpectFailure[SparkException], + AssertOnQuery(_.status.inputRate === 0.0), + AssertOnQuery(_.status.processingRate === 0.0), + AssertOnQuery(_.status.sourceStatuses.length === 1), + AssertOnQuery(_.status.sourceStatuses(0).offsetDesc === LongOffset(2).toString), + AssertOnQuery(_.status.sourceStatuses(0).inputRate === 0.0), + AssertOnQuery(_.status.sourceStatuses(0).processingRate === 0.0), + AssertOnQuery(_.status.sinkStatus.offsetDesc === + CompositeOffset.fill(LongOffset(1)).toString), + AssertOnQuery(_.sourceStatuses(0).offsetDesc === LongOffset(2).toString), + AssertOnQuery(_.sourceStatuses(0).inputRate === 0.0), + AssertOnQuery(_.sourceStatuses(0).processingRate === 0.0), + AssertOnQuery(_.sinkStatus.offsetDesc === CompositeOffset.fill(LongOffset(1)).toString) + ) + } + + test("codahale metrics") { + val inputData = MemoryStream[Int] + + /** Whether metrics of a query is registered for reporting */ + def isMetricsRegistered(query: StreamingQuery): Boolean = { + val sourceName = s"StructuredStreaming.${query.name}" + val sources = spark.sparkContext.env.metricsSystem.getSourcesByName(sourceName) + require(sources.size <= 1) + sources.nonEmpty + } + // Disabled by default + assert(spark.conf.get("spark.sql.streaming.metricsEnabled").toBoolean === false) + + withSQLConf("spark.sql.streaming.metricsEnabled" -> "false") { + testStream(inputData.toDF)( + AssertOnQuery { q => !isMetricsRegistered(q) }, + StopStream, + AssertOnQuery { q => !isMetricsRegistered(q) } + ) + } + + // Registered when enabled + withSQLConf("spark.sql.streaming.metricsEnabled" -> "true") { + testStream(inputData.toDF)( + AssertOnQuery { q => isMetricsRegistered(q) }, + StopStream, + AssertOnQuery { q => !isMetricsRegistered(q) } + ) + } + } + + test("input row calculation with mixed batch and streaming sources") { + val streamingTriggerDF = spark.createDataset(1 to 10).toDF + val streamingInputDF = createSingleTriggerStreamingDF(streamingTriggerDF).toDF("value") + val staticInputDF = spark.createDataFrame(Seq(1 -> "1", 2 -> "2")).toDF("value", "anotherValue") + + // Trigger input has 10 rows, static input has 2 rows, + // therefore after the first trigger, the calculated input rows should be 10 + val status = getFirstTriggerStatus(streamingInputDF.join(staticInputDF, "value")) + assert(status.triggerDetails.get("numRows.input.total") === "10") + assert(status.sourceStatuses.size === 1) + assert(status.sourceStatuses(0).triggerDetails.get("numRows.input.source") === "10") + } + + test("input row calculation with trigger DF having multiple leaves") { + val streamingTriggerDF = + spark.createDataset(1 to 5).toDF.union(spark.createDataset(6 to 10).toDF) + require(streamingTriggerDF.logicalPlan.collectLeaves().size > 1) + val streamingInputDF = 
createSingleTriggerStreamingDF(streamingTriggerDF) + + // After the first trigger, the calculated input rows should be 10 + val status = getFirstTriggerStatus(streamingInputDF) + assert(status.triggerDetails.get("numRows.input.total") === "10") + assert(status.sourceStatuses.size === 1) + assert(status.sourceStatuses(0).triggerDetails.get("numRows.input.source") === "10") + } + + testQuietly("StreamExecution metadata garbage collection") { + val inputData = MemoryStream[Int] + val mapped = inputData.toDS().map(6 / _) + + // Run 3 batches, and then assert that only 2 metadata files are at the end + // since the first should have been purged. + testStream(mapped)( + AddData(inputData, 1, 2), + CheckAnswer(6, 3), + AddData(inputData, 1, 2), + CheckAnswer(6, 3, 6, 3), + AddData(inputData, 4, 6), + CheckAnswer(6, 3, 6, 3, 1, 1), + + AssertOnQuery("metadata log should contain only two files") { q => + val metadataLogDir = new java.io.File(q.offsetLog.metadataPath.toString) + val logFileNames = metadataLogDir.listFiles().toSeq.map(_.getName()) + val toTest = logFileNames.filter(! _.endsWith(".crc")).sorted // Workaround for SPARK-17475 + assert(toTest.size == 2 && toTest.head == "1") + true + } + ) + } + + /** Create a streaming DF that only executes one batch in which it returns the given static DF */ + private def createSingleTriggerStreamingDF(triggerDF: DataFrame): DataFrame = { + require(!triggerDF.isStreaming) + // A streaming Source that generates only one trigger and returns the given DataFrame as the batch + val source = new Source() { + override def schema: StructType = triggerDF.schema + override def getOffset: Option[Offset] = Some(LongOffset(0)) + override def getBatch(start: Option[Offset], end: Offset): DataFrame = triggerDF + override def stop(): Unit = {} + } + StreamingExecutionRelation(source) + } + + /** Returns the query status at the end of the first trigger of streaming DF */ + private def getFirstTriggerStatus(streamingDF: DataFrame): StreamingQueryStatus = { + // A StreamingQueryListener that gets the query status after the first completed trigger + val listener = new StreamingQueryListener { + @volatile var firstStatus: StreamingQueryStatus = null + override def onQueryStarted(queryStarted: QueryStartedEvent): Unit = { } + override def onQueryProgress(queryProgress: QueryProgressEvent): Unit = { + if (firstStatus == null) firstStatus = queryProgress.queryStatus + } + override def onQueryTerminated(queryTerminated: QueryTerminatedEvent): Unit = { } + } + + try { + spark.streams.addListener(listener) + val q = streamingDF.writeStream.format("memory").queryName("test").start() + q.processAllAvailable() + eventually(timeout(streamingTimeout)) { + assert(listener.firstStatus != null) + } + listener.firstStatus + } finally { + spark.streams.active.map(_.stop()) + spark.streams.removeListener(listener) + } + } + + /** + * A [[StreamAction]] to test the behavior of `StreamingQuery.awaitTermination()`. + * + * @param expectedBehavior Expected behavior (not blocked, blocked, or exception thrown) + * @param timeoutMs Timeout in milliseconds + * When timeoutMs <= 0, awaitTermination() is tested (i.e.
w/o timeout) + * When timeoutMs > 0, awaitTermination(timeoutMs) is tested + * @param expectedReturnValue Expected return value when awaitTermination(timeoutMs) is used + */ + case class TestAwaitTermination( + expectedBehavior: ExpectedBehavior, + timeoutMs: Int = -1, + expectedReturnValue: Boolean = false + ) extends AssertOnQuery( + TestAwaitTermination.assertOnQueryCondition(expectedBehavior, timeoutMs, expectedReturnValue), + "Error testing awaitTermination behavior" + ) { + override def toString(): String = { + s"TestAwaitTermination($expectedBehavior, timeoutMs = $timeoutMs, " + + s"expectedReturnValue = $expectedReturnValue)" + } + } + + object TestAwaitTermination { + + /** + * Tests the behavior of `StreamingQuery.awaitTermination`. + * + * @param expectedBehavior Expected behavior (not blocked, blocked, or exception thrown) + * @param timeoutMs Timeout in milliseconds + * When timeoutMs <= 0, awaitTermination() is tested (i.e. w/o timeout) + * When timeoutMs > 0, awaitTermination(timeoutMs) is tested + * @param expectedReturnValue Expected return value when awaitTermination(timeoutMs) is used + */ + def assertOnQueryCondition( + expectedBehavior: ExpectedBehavior, + timeoutMs: Int, + expectedReturnValue: Boolean + )(q: StreamExecution): Boolean = { + + def awaitTermFunc(): Unit = { + if (timeoutMs <= 0) { + q.awaitTermination() + } else { + val returnedValue = q.awaitTermination(timeoutMs) + assert(returnedValue === expectedReturnValue, "Returned value does not match expected") + } + } + AwaitTerminationTester.test(expectedBehavior, awaitTermFunc) + true // If the control reached here, then everything worked as expected + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala similarity index 65% rename from sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala index 00efe21d39de4..f0994395813e4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala @@ -27,7 +27,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.sql._ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.sources.{StreamSinkProvider, StreamSourceProvider} -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, StreamingQuery, StreamTest} import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.util.Utils @@ -54,18 +54,18 @@ class DefaultSource extends StreamSourceProvider with StreamSinkProvider { private val fakeSchema = StructType(StructField("a", IntegerType) :: Nil) override def sourceSchema( - sqlContext: SQLContext, + spark: SQLContext, schema: Option[StructType], providerName: String, parameters: Map[String, String]): (String, StructType) = { LastOptions.parameters = parameters LastOptions.schema = schema - LastOptions.mockStreamSourceProvider.sourceSchema(sqlContext, schema, providerName, parameters) + LastOptions.mockStreamSourceProvider.sourceSchema(spark, schema, providerName, parameters) ("dummySource", fakeSchema) } override def createSource( - sqlContext: SQLContext, + spark: SQLContext, metadataPath: String, schema: 
Option[StructType], providerName: String, @@ -73,62 +73,77 @@ class DefaultSource extends StreamSourceProvider with StreamSinkProvider { LastOptions.parameters = parameters LastOptions.schema = schema LastOptions.mockStreamSourceProvider.createSource( - sqlContext, metadataPath, schema, providerName, parameters) + spark, metadataPath, schema, providerName, parameters) new Source { override def schema: StructType = fakeSchema override def getOffset: Option[Offset] = Some(new LongOffset(0)) override def getBatch(start: Option[Offset], end: Offset): DataFrame = { - import sqlContext.implicits._ + import spark.implicits._ Seq[Int]().toDS().toDF() } + + override def stop() {} } } override def createSink( - sqlContext: SQLContext, + spark: SQLContext, parameters: Map[String, String], - partitionColumns: Seq[String]): Sink = { + partitionColumns: Seq[String], + outputMode: OutputMode): Sink = { LastOptions.parameters = parameters LastOptions.partitionColumns = partitionColumns - LastOptions.mockStreamSinkProvider.createSink(sqlContext, parameters, partitionColumns) + LastOptions.mockStreamSinkProvider.createSink(spark, parameters, partitionColumns, outputMode) new Sink { override def addBatch(batchId: Long, data: DataFrame): Unit = {} } } } -class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with BeforeAndAfter { - import testImplicits._ +class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { private def newMetadataDir = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath after { - sqlContext.streams.active.foreach(_.stop()) + spark.streams.active.foreach(_.stop()) + } + + test("write cannot be called on streaming datasets") { + val e = intercept[AnalysisException] { + spark.readStream + .format("org.apache.spark.sql.streaming.test") + .load() + .write + .save() + } + Seq("'write'", "not", "streaming Dataset/DataFrame").foreach { s => + assert(e.getMessage.toLowerCase.contains(s.toLowerCase)) + } } test("resolve default source") { - sqlContext.read + spark.readStream .format("org.apache.spark.sql.streaming.test") - .stream() - .write + .load() + .writeStream .format("org.apache.spark.sql.streaming.test") .option("checkpointLocation", newMetadataDir) - .startStream() + .start() .stop() } test("resolve full class") { - sqlContext.read + spark.readStream .format("org.apache.spark.sql.streaming.test.DefaultSource") - .stream() - .write + .load() + .writeStream .format("org.apache.spark.sql.streaming.test") .option("checkpointLocation", newMetadataDir) - .startStream() + .start() .stop() } @@ -136,12 +151,12 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B val map = new java.util.HashMap[String, String] map.put("opt3", "3") - val df = sqlContext.read + val df = spark.readStream .format("org.apache.spark.sql.streaming.test") .option("opt1", "1") .options(Map("opt2" -> "2")) .options(map) - .stream() + .load() assert(LastOptions.parameters("opt1") == "1") assert(LastOptions.parameters("opt2") == "2") @@ -149,13 +164,13 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B LastOptions.clear() - df.write + df.writeStream .format("org.apache.spark.sql.streaming.test") .option("opt1", "1") .options(Map("opt2" -> "2")) .options(map) .option("checkpointLocation", newMetadataDir) - .startStream() + .start() .stop() assert(LastOptions.parameters("opt1") == "1") @@ -164,84 +179,84 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B } test("partitioning") 
{ - val df = sqlContext.read + val df = spark.readStream .format("org.apache.spark.sql.streaming.test") - .stream() + .load() - df.write + df.writeStream .format("org.apache.spark.sql.streaming.test") .option("checkpointLocation", newMetadataDir) - .startStream() + .start() .stop() assert(LastOptions.partitionColumns == Nil) - df.write + df.writeStream .format("org.apache.spark.sql.streaming.test") .option("checkpointLocation", newMetadataDir) .partitionBy("a") - .startStream() + .start() .stop() assert(LastOptions.partitionColumns == Seq("a")) withSQLConf("spark.sql.caseSensitive" -> "false") { - df.write + df.writeStream .format("org.apache.spark.sql.streaming.test") .option("checkpointLocation", newMetadataDir) .partitionBy("A") - .startStream() + .start() .stop() assert(LastOptions.partitionColumns == Seq("a")) } intercept[AnalysisException] { - df.write + df.writeStream .format("org.apache.spark.sql.streaming.test") .option("checkpointLocation", newMetadataDir) .partitionBy("b") - .startStream() + .start() .stop() } } test("stream paths") { - val df = sqlContext.read + val df = spark.readStream .format("org.apache.spark.sql.streaming.test") .option("checkpointLocation", newMetadataDir) - .stream("/test") + .load("/test") assert(LastOptions.parameters("path") == "/test") LastOptions.clear() - df.write + df.writeStream .format("org.apache.spark.sql.streaming.test") .option("checkpointLocation", newMetadataDir) - .startStream("/test") + .start("/test") .stop() assert(LastOptions.parameters("path") == "/test") } test("test different data types for options") { - val df = sqlContext.read + val df = spark.readStream .format("org.apache.spark.sql.streaming.test") .option("intOpt", 56) .option("boolOpt", false) .option("doubleOpt", 6.7) - .stream("/test") + .load("/test") assert(LastOptions.parameters("intOpt") == "56") assert(LastOptions.parameters("boolOpt") == "false") assert(LastOptions.parameters("doubleOpt") == "6.7") LastOptions.clear() - df.write + df.writeStream .format("org.apache.spark.sql.streaming.test") .option("intOpt", 56) .option("boolOpt", false) .option("doubleOpt", 6.7) .option("checkpointLocation", newMetadataDir) - .startStream("/test") + .start("/test") .stop() assert(LastOptions.parameters("intOpt") == "56") @@ -252,31 +267,31 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B test("unique query names") { /** Start a query with a specific name */ - def startQueryWithName(name: String = ""): ContinuousQuery = { - sqlContext.read + def startQueryWithName(name: String = ""): StreamingQuery = { + spark.readStream .format("org.apache.spark.sql.streaming.test") - .stream("/test") - .write + .load("/test") + .writeStream .format("org.apache.spark.sql.streaming.test") .option("checkpointLocation", newMetadataDir) .queryName(name) - .startStream() + .start() } /** Start a query without specifying a name */ - def startQueryWithoutName(): ContinuousQuery = { - sqlContext.read + def startQueryWithoutName(): StreamingQuery = { + spark.readStream .format("org.apache.spark.sql.streaming.test") - .stream("/test") - .write + .load("/test") + .writeStream .format("org.apache.spark.sql.streaming.test") .option("checkpointLocation", newMetadataDir) - .startStream() + .start() } /** Get the names of active streams */ def activeStreamNames: Set[String] = { - val streams = sqlContext.streams.active + val streams = spark.streams.active val names = streams.map(_.name).toSet assert(streams.length === names.size, s"names of active queries are not unique: $names") 
names @@ -307,28 +322,28 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B q1.stop() val q5 = startQueryWithName("name") assert(activeStreamNames.contains("name")) - sqlContext.streams.active.foreach(_.stop()) + spark.streams.active.foreach(_.stop()) } test("trigger") { - val df = sqlContext.read + val df = spark.readStream .format("org.apache.spark.sql.streaming.test") - .stream("/test") + .load("/test") - var q = df.write + var q = df.writeStream .format("org.apache.spark.sql.streaming.test") .option("checkpointLocation", newMetadataDir) .trigger(ProcessingTime(10.seconds)) - .startStream() + .start() q.stop() assert(q.asInstanceOf[StreamExecution].trigger == ProcessingTime(10000)) - q = df.write + q = df.writeStream .format("org.apache.spark.sql.streaming.test") .option("checkpointLocation", newMetadataDir) .trigger(ProcessingTime.create(100, TimeUnit.SECONDS)) - .startStream() + .start() q.stop() assert(q.asInstanceOf[StreamExecution].trigger == ProcessingTime(100000)) @@ -339,33 +354,117 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B val checkpointLocation = newMetadataDir - val df1 = sqlContext.read + val df1 = spark.readStream .format("org.apache.spark.sql.streaming.test") - .stream() + .load() - val df2 = sqlContext.read + val df2 = spark.readStream .format("org.apache.spark.sql.streaming.test") - .stream() + .load() - val q = df1.union(df2).write + val q = df1.union(df2).writeStream .format("org.apache.spark.sql.streaming.test") .option("checkpointLocation", checkpointLocation) .trigger(ProcessingTime(10.seconds)) - .startStream() + .start() q.stop() verify(LastOptions.mockStreamSourceProvider).createSource( - sqlContext, + spark.sqlContext, checkpointLocation + "/sources/0", None, "org.apache.spark.sql.streaming.test", Map.empty) verify(LastOptions.mockStreamSourceProvider).createSource( - sqlContext, + spark.sqlContext, checkpointLocation + "/sources/1", None, "org.apache.spark.sql.streaming.test", Map.empty) } + + private def newTextInput = Utils.createTempDir(namePrefix = "text").getCanonicalPath + + test("check outputMode(string) throws exception on unsupported modes") { + def testError(outputMode: String): Unit = { + val df = spark.readStream + .format("org.apache.spark.sql.streaming.test") + .load() + val w = df.writeStream + val e = intercept[IllegalArgumentException](w.outputMode(outputMode)) + Seq("output mode", "unknown", outputMode).foreach { s => + assert(e.getMessage.toLowerCase.contains(s.toLowerCase)) + } + } + testError("Update") + testError("Xyz") + } + + test("check foreach() catches null writers") { + val df = spark.readStream + .format("org.apache.spark.sql.streaming.test") + .load() + + var w = df.writeStream + var e = intercept[IllegalArgumentException](w.foreach(null)) + Seq("foreach", "null").foreach { s => + assert(e.getMessage.toLowerCase.contains(s.toLowerCase)) + } + } + + + test("check foreach() does not support partitioning") { + val df = spark.readStream + .format("org.apache.spark.sql.streaming.test") + .load() + val foreachWriter = new ForeachWriter[Row] { + override def open(partitionId: Long, version: Long): Boolean = false + override def process(value: Row): Unit = {} + override def close(errorOrNull: Throwable): Unit = {} + } + var w = df.writeStream.partitionBy("value") + var e = intercept[AnalysisException](w.foreach(foreachWriter).start()) + Seq("foreach", "partitioning").foreach { s => + assert(e.getMessage.toLowerCase.contains(s.toLowerCase)) + } + } + + 
test("ConsoleSink can be correctly loaded") { + LastOptions.clear() + val df = spark.readStream + .format("org.apache.spark.sql.streaming.test") + .load() + + val sq = df.writeStream + .format("console") + .option("checkpointLocation", newMetadataDir) + .trigger(ProcessingTime(2.seconds)) + .start() + + sq.awaitTermination(2000L) + } + + test("prevent all column partitioning") { + withTempDir { dir => + val path = dir.getCanonicalPath + intercept[AnalysisException] { + spark.range(10).writeStream + .outputMode("append") + .partitionBy("id") + .format("parquet") + .start(path) + } + } + } + + test("ConsoleSink should not require checkpointLocation") { + LastOptions.clear() + val df = spark.readStream + .format("org.apache.spark.sql.streaming.test") + .load() + + val sq = df.writeStream.format("console").start() + sq.stop() + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala new file mode 100644 index 0000000000000..bba265e9c9340 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -0,0 +1,518 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import java.io.File + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types.{StringType, StructField, StructType} +import org.apache.spark.util.Utils + + +object LastOptions { + + var parameters: Map[String, String] = null + var schema: Option[StructType] = null + var saveMode: SaveMode = null + + def clear(): Unit = { + parameters = null + schema = null + saveMode = null + } +} + + +/** Dummy provider. 
*/ +class DefaultSource + extends RelationProvider + with SchemaRelationProvider + with CreatableRelationProvider { + + case class FakeRelation(sqlContext: SQLContext) extends BaseRelation { + override def schema: StructType = StructType(Seq(StructField("a", StringType))) + } + + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String], + schema: StructType + ): BaseRelation = { + LastOptions.parameters = parameters + LastOptions.schema = Some(schema) + FakeRelation(sqlContext) + } + + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String] + ): BaseRelation = { + LastOptions.parameters = parameters + LastOptions.schema = None + FakeRelation(sqlContext) + } + + override def createRelation( + sqlContext: SQLContext, + mode: SaveMode, + parameters: Map[String, String], + data: DataFrame): BaseRelation = { + LastOptions.parameters = parameters + LastOptions.schema = None + LastOptions.saveMode = mode + FakeRelation(sqlContext) + } +} + +/** Dummy provider with only RelationProvider and CreatableRelationProvider. */ +class DefaultSourceWithoutUserSpecifiedSchema + extends RelationProvider + with CreatableRelationProvider { + + case class FakeRelation(sqlContext: SQLContext) extends BaseRelation { + override def schema: StructType = StructType(Seq(StructField("a", StringType))) + } + + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String]): BaseRelation = { + FakeRelation(sqlContext) + } + + override def createRelation( + sqlContext: SQLContext, + mode: SaveMode, + parameters: Map[String, String], + data: DataFrame): BaseRelation = { + FakeRelation(sqlContext) + } +} + +class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with BeforeAndAfter { + + + private val userSchema = new StructType().add("s", StringType) + private val textSchema = new StructType().add("value", StringType) + private val data = Seq("1", "2", "3") + private val dir = Utils.createTempDir(namePrefix = "input").getCanonicalPath + private implicit var enc: Encoder[String] = _ + + before { + enc = spark.implicits.newStringEncoder + Utils.deleteRecursively(new File(dir)) + } + + test("writeStream cannot be called on non-streaming datasets") { + val e = intercept[AnalysisException] { + spark.read + .format("org.apache.spark.sql.test") + .load() + .writeStream + .start() + } + Seq("'writeStream'", "only", "streaming Dataset/DataFrame").foreach { s => + assert(e.getMessage.toLowerCase.contains(s.toLowerCase)) + } + } + + + test("resolve default source") { + spark.read + .format("org.apache.spark.sql.test") + .load() + .write + .format("org.apache.spark.sql.test") + .save() + } + + test("resolve default source without extending SchemaRelationProvider") { + spark.read + .format("org.apache.spark.sql.test.DefaultSourceWithoutUserSpecifiedSchema") + .load() + .write + .format("org.apache.spark.sql.test.DefaultSourceWithoutUserSpecifiedSchema") + .save() + } + + test("resolve full class") { + spark.read + .format("org.apache.spark.sql.test.DefaultSource") + .load() + .write + .format("org.apache.spark.sql.test") + .save() + } + + test("options") { + val map = new java.util.HashMap[String, String] + map.put("opt3", "3") + + val df = spark.read + .format("org.apache.spark.sql.test") + .option("opt1", "1") + .options(Map("opt2" -> "2")) + .options(map) + .load() + + assert(LastOptions.parameters("opt1") == "1") + assert(LastOptions.parameters("opt2") == "2") + assert(LastOptions.parameters("opt3") == "3") + + 
LastOptions.clear() + + df.write + .format("org.apache.spark.sql.test") + .option("opt1", "1") + .options(Map("opt2" -> "2")) + .options(map) + .save() + + assert(LastOptions.parameters("opt1") == "1") + assert(LastOptions.parameters("opt2") == "2") + assert(LastOptions.parameters("opt3") == "3") + } + + test("save mode") { + val df = spark.read + .format("org.apache.spark.sql.test") + .load() + + df.write + .format("org.apache.spark.sql.test") + .mode(SaveMode.ErrorIfExists) + .save() + assert(LastOptions.saveMode === SaveMode.ErrorIfExists) + } + + test("test path option in load") { + spark.read + .format("org.apache.spark.sql.test") + .option("intOpt", 56) + .load("/test") + + assert(LastOptions.parameters("intOpt") == "56") + assert(LastOptions.parameters("path") == "/test") + + LastOptions.clear() + spark.read + .format("org.apache.spark.sql.test") + .option("intOpt", 55) + .load() + + assert(LastOptions.parameters("intOpt") == "55") + assert(!LastOptions.parameters.contains("path")) + + LastOptions.clear() + spark.read + .format("org.apache.spark.sql.test") + .option("intOpt", 54) + .load("/test", "/test1", "/test2") + + assert(LastOptions.parameters("intOpt") == "54") + assert(!LastOptions.parameters.contains("path")) + } + + test("test different data types for options") { + val df = spark.read + .format("org.apache.spark.sql.test") + .option("intOpt", 56) + .option("boolOpt", false) + .option("doubleOpt", 6.7) + .load("/test") + + assert(LastOptions.parameters("intOpt") == "56") + assert(LastOptions.parameters("boolOpt") == "false") + assert(LastOptions.parameters("doubleOpt") == "6.7") + + LastOptions.clear() + df.write + .format("org.apache.spark.sql.test") + .option("intOpt", 56) + .option("boolOpt", false) + .option("doubleOpt", 6.7) + .save("/test") + + assert(LastOptions.parameters("intOpt") == "56") + assert(LastOptions.parameters("boolOpt") == "false") + assert(LastOptions.parameters("doubleOpt") == "6.7") + } + + test("check jdbc() does not support partitioning or bucketing") { + val df = spark.read.text(Utils.createTempDir(namePrefix = "text").getCanonicalPath) + + var w = df.write.partitionBy("value") + var e = intercept[AnalysisException](w.jdbc(null, null, null)) + Seq("jdbc", "partitioning").foreach { s => + assert(e.getMessage.toLowerCase.contains(s.toLowerCase)) + } + + w = df.write.bucketBy(2, "value") + e = intercept[AnalysisException](w.jdbc(null, null, null)) + Seq("jdbc", "bucketing").foreach { s => + assert(e.getMessage.toLowerCase.contains(s.toLowerCase)) + } + } + + test("prevent all column partitioning") { + withTempDir { dir => + val path = dir.getCanonicalPath + intercept[AnalysisException] { + spark.range(10).write.format("parquet").mode("overwrite").partitionBy("id").save(path) + } + intercept[AnalysisException] { + spark.range(10).write.format("csv").mode("overwrite").partitionBy("id").save(path) + } + spark.emptyDataFrame.write.format("parquet").mode("overwrite").save(path) + } + } + + test("load API") { + spark.read.format("org.apache.spark.sql.test").load() + spark.read.format("org.apache.spark.sql.test").load(dir) + spark.read.format("org.apache.spark.sql.test").load(dir, dir, dir) + spark.read.format("org.apache.spark.sql.test").load(Seq(dir, dir): _*) + Option(dir).map(spark.read.format("org.apache.spark.sql.test").load) + } + + test("text - API and behavior regarding schema") { + // Writer + spark.createDataset(data).write.mode(SaveMode.Overwrite).text(dir) + testRead(spark.read.text(dir), data, textSchema) + + // Reader, without user specified 
schema + testRead(spark.read.text(), Seq.empty, textSchema) + testRead(spark.read.text(dir, dir, dir), data ++ data ++ data, textSchema) + testRead(spark.read.text(Seq(dir, dir): _*), data ++ data, textSchema) + // Test explicit calls to single arg method - SPARK-16009 + testRead(Option(dir).map(spark.read.text).get, data, textSchema) + + // Reader, with user specified schema, should just apply user schema on the file data + testRead(spark.read.schema(userSchema).text(), Seq.empty, userSchema) + testRead(spark.read.schema(userSchema).text(dir), data, userSchema) + testRead(spark.read.schema(userSchema).text(dir, dir), data ++ data, userSchema) + testRead(spark.read.schema(userSchema).text(Seq(dir, dir): _*), data ++ data, userSchema) + } + + test("textFile - API and behavior regarding schema") { + spark.createDataset(data).write.mode(SaveMode.Overwrite).text(dir) + + // Reader, without user specified schema + testRead(spark.read.textFile().toDF(), Seq.empty, textSchema) + testRead(spark.read.textFile(dir).toDF(), data, textSchema) + testRead(spark.read.textFile(dir, dir).toDF(), data ++ data, textSchema) + testRead(spark.read.textFile(Seq(dir, dir): _*).toDF(), data ++ data, textSchema) + // Test explicit calls to single arg method - SPARK-16009 + testRead(Option(dir).map(spark.read.text).get, data, textSchema) + + // Reader, with user specified schema, should just apply user schema on the file data + val e = intercept[AnalysisException] { spark.read.schema(userSchema).textFile() } + assert(e.getMessage.toLowerCase.contains("user specified schema not supported")) + intercept[AnalysisException] { spark.read.schema(userSchema).textFile(dir) } + intercept[AnalysisException] { spark.read.schema(userSchema).textFile(dir, dir) } + intercept[AnalysisException] { spark.read.schema(userSchema).textFile(Seq(dir, dir): _*) } + } + + test("csv - API and behavior regarding schema") { + // Writer + spark.createDataset(data).toDF("str").write.mode(SaveMode.Overwrite).csv(dir) + val df = spark.read.csv(dir) + checkAnswer(df, spark.createDataset(data).toDF()) + val schema = df.schema + + // Reader, without user specified schema + intercept[IllegalArgumentException] { + testRead(spark.read.csv(), Seq.empty, schema) + } + testRead(spark.read.csv(dir), data, schema) + testRead(spark.read.csv(dir, dir), data ++ data, schema) + testRead(spark.read.csv(Seq(dir, dir): _*), data ++ data, schema) + // Test explicit calls to single arg method - SPARK-16009 + testRead(Option(dir).map(spark.read.csv).get, data, schema) + + // Reader, with user specified schema, should just apply user schema on the file data + testRead(spark.read.schema(userSchema).csv(), Seq.empty, userSchema) + testRead(spark.read.schema(userSchema).csv(dir), data, userSchema) + testRead(spark.read.schema(userSchema).csv(dir, dir), data ++ data, userSchema) + testRead(spark.read.schema(userSchema).csv(Seq(dir, dir): _*), data ++ data, userSchema) + } + + test("json - API and behavior regarding schema") { + // Writer + spark.createDataset(data).toDF("str").write.mode(SaveMode.Overwrite).json(dir) + val df = spark.read.json(dir) + checkAnswer(df, spark.createDataset(data).toDF()) + val schema = df.schema + + // Reader, without user specified schema + intercept[AnalysisException] { + testRead(spark.read.json(), Seq.empty, schema) + } + testRead(spark.read.json(dir), data, schema) + testRead(spark.read.json(dir, dir), data ++ data, schema) + testRead(spark.read.json(Seq(dir, dir): _*), data ++ data, schema) + // Test explicit calls to single arg method - 
SPARK-16009 + testRead(Option(dir).map(spark.read.json).get, data, schema) + + // Reader, with user specified schema, data should be nulls as schema in file different + // from user schema + val expData = Seq[String](null, null, null) + testRead(spark.read.schema(userSchema).json(), Seq.empty, userSchema) + testRead(spark.read.schema(userSchema).json(dir), expData, userSchema) + testRead(spark.read.schema(userSchema).json(dir, dir), expData ++ expData, userSchema) + testRead(spark.read.schema(userSchema).json(Seq(dir, dir): _*), expData ++ expData, userSchema) + } + + test("parquet - API and behavior regarding schema") { + // Writer + spark.createDataset(data).toDF("str").write.mode(SaveMode.Overwrite).parquet(dir) + val df = spark.read.parquet(dir) + checkAnswer(df, spark.createDataset(data).toDF()) + val schema = df.schema + + // Reader, without user specified schema + intercept[AnalysisException] { + testRead(spark.read.parquet(), Seq.empty, schema) + } + testRead(spark.read.parquet(dir), data, schema) + testRead(spark.read.parquet(dir, dir), data ++ data, schema) + testRead(spark.read.parquet(Seq(dir, dir): _*), data ++ data, schema) + // Test explicit calls to single arg method - SPARK-16009 + testRead(Option(dir).map(spark.read.parquet).get, data, schema) + + // Reader, with user specified schema, data should be nulls as schema in file different + // from user schema + val expData = Seq[String](null, null, null) + testRead(spark.read.schema(userSchema).parquet(), Seq.empty, userSchema) + testRead(spark.read.schema(userSchema).parquet(dir), expData, userSchema) + testRead(spark.read.schema(userSchema).parquet(dir, dir), expData ++ expData, userSchema) + testRead( + spark.read.schema(userSchema).parquet(Seq(dir, dir): _*), expData ++ expData, userSchema) + } + + /** + * This only tests whether API compiles, but does not run it as orc() + * cannot be run without Hive classes. 
+ */ + ignore("orc - API") { + // Reader, with user specified schema + // Refer to csv-specific test suites for behavior without user specified schema + spark.read.schema(userSchema).orc() + spark.read.schema(userSchema).orc(dir) + spark.read.schema(userSchema).orc(dir, dir, dir) + spark.read.schema(userSchema).orc(Seq(dir, dir): _*) + Option(dir).map(spark.read.schema(userSchema).orc) + + // Writer + spark.range(10).write.orc(dir) + } + + test("SPARK-17230: write out results of decimal calculation") { + val df = spark.range(99, 101) + .selectExpr("id", "cast(id as long) * cast('1.0' as decimal(38, 18)) as num") + df.write.mode(SaveMode.Overwrite).parquet(dir) + val df2 = spark.read.parquet(dir) + checkAnswer(df2, df) + } + + private def testRead( + df: => DataFrame, + expectedResult: Seq[String], + expectedSchema: StructType): Unit = { + checkAnswer(df, spark.createDataset(expectedResult).toDF()) + assert(df.schema === expectedSchema) + } + + test("saveAsTable with mode Append should not fail if the table not exists " + + "but a same-name temp view exist") { + withTable("same_name") { + withTempView("same_name") { + spark.range(10).createTempView("same_name") + spark.range(20).write.mode(SaveMode.Append).saveAsTable("same_name") + assert( + spark.sessionState.catalog.tableExists(TableIdentifier("same_name", Some("default")))) + } + } + } + + test("saveAsTable with mode Append should not fail if the table already exists " + + "and a same-name temp view exist") { + withTable("same_name") { + withTempView("same_name") { + sql("CREATE TABLE same_name(id LONG) USING parquet") + spark.range(10).createTempView("same_name") + spark.range(20).write.mode(SaveMode.Append).saveAsTable("same_name") + checkAnswer(spark.table("same_name"), spark.range(10).toDF()) + checkAnswer(spark.table("default.same_name"), spark.range(20).toDF()) + } + } + } + + test("saveAsTable with mode ErrorIfExists should not fail if the table not exists " + + "but a same-name temp view exist") { + withTable("same_name") { + withTempView("same_name") { + spark.range(10).createTempView("same_name") + spark.range(20).write.mode(SaveMode.ErrorIfExists).saveAsTable("same_name") + assert( + spark.sessionState.catalog.tableExists(TableIdentifier("same_name", Some("default")))) + } + } + } + + test("saveAsTable with mode Overwrite should not drop the temp view if the table not exists " + + "but a same-name temp view exist") { + withTable("same_name") { + withTempView("same_name") { + spark.range(10).createTempView("same_name") + spark.range(20).write.mode(SaveMode.Overwrite).saveAsTable("same_name") + assert(spark.sessionState.catalog.getTempView("same_name").isDefined) + assert( + spark.sessionState.catalog.tableExists(TableIdentifier("same_name", Some("default")))) + } + } + } + + test("saveAsTable with mode Overwrite should not fail if the table already exists " + + "and a same-name temp view exist") { + withTable("same_name") { + withTempView("same_name") { + sql("CREATE TABLE same_name(id LONG) USING parquet") + spark.range(10).createTempView("same_name") + spark.range(20).write.mode(SaveMode.Overwrite).saveAsTable("same_name") + checkAnswer(spark.table("same_name"), spark.range(10).toDF()) + checkAnswer(spark.table("default.same_name"), spark.range(20).toDF()) + } + } + } + + test("saveAsTable with mode Ignore should create the table if the table not exists " + + "but a same-name temp view exist") { + withTable("same_name") { + withTempView("same_name") { + spark.range(10).createTempView("same_name") + 
spark.range(20).write.mode(SaveMode.Ignore).saveAsTable("same_name") + assert( + spark.sessionState.catalog.tableExists(TableIdentifier("same_name", Some("default")))) + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala index 7fa6760b71c8b..0cfe260e52152 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala @@ -20,17 +20,17 @@ package org.apache.spark.sql.test import java.nio.charset.StandardCharsets import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, SQLContext, SQLImplicits} +import org.apache.spark.sql.{DataFrame, SparkSession, SQLContext, SQLImplicits} /** * A collection of sample data used in SQL tests. */ private[sql] trait SQLTestData { self => - protected def sqlContext: SQLContext + protected def spark: SparkSession // Helper object to import SQL implicits without a concrete SQLContext private object internalImplicits extends SQLImplicits { - protected override def _sqlContext: SQLContext = self.sqlContext + protected override def _sqlContext: SQLContext = self.spark.sqlContext } import internalImplicits._ @@ -39,173 +39,173 @@ private[sql] trait SQLTestData { self => // Note: all test data should be lazy because the SQLContext is not set up yet. protected lazy val emptyTestData: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( Seq.empty[Int].map(i => TestData(i, i.toString))).toDF() - df.registerTempTable("emptyTestData") + df.createOrReplaceTempView("emptyTestData") df } protected lazy val testData: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( (1 to 100).map(i => TestData(i, i.toString))).toDF() - df.registerTempTable("testData") + df.createOrReplaceTempView("testData") df } protected lazy val testData2: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( TestData2(1, 1) :: TestData2(1, 2) :: TestData2(2, 1) :: TestData2(2, 2) :: TestData2(3, 1) :: TestData2(3, 2) :: Nil, 2).toDF() - df.registerTempTable("testData2") + df.createOrReplaceTempView("testData2") df } protected lazy val testData3: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( TestData3(1, None) :: TestData3(2, Some(2)) :: Nil).toDF() - df.registerTempTable("testData3") + df.createOrReplaceTempView("testData3") df } protected lazy val negativeData: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( (1 to 100).map(i => TestData(-i, (-i).toString))).toDF() - df.registerTempTable("negativeData") + df.createOrReplaceTempView("negativeData") df } protected lazy val largeAndSmallInts: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( LargeAndSmallInts(2147483644, 1) :: LargeAndSmallInts(1, 2) :: LargeAndSmallInts(2147483645, 1) :: LargeAndSmallInts(2, 2) :: LargeAndSmallInts(2147483646, 1) :: LargeAndSmallInts(3, 2) :: Nil).toDF() - df.registerTempTable("largeAndSmallInts") + df.createOrReplaceTempView("largeAndSmallInts") df } protected lazy val decimalData: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( DecimalData(1, 1) :: DecimalData(1, 2) :: DecimalData(2, 1) :: DecimalData(2, 2) :: DecimalData(3, 1) 
:: DecimalData(3, 2) :: Nil).toDF() - df.registerTempTable("decimalData") + df.createOrReplaceTempView("decimalData") df } protected lazy val binaryData: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( BinaryData("12".getBytes(StandardCharsets.UTF_8), 1) :: BinaryData("22".getBytes(StandardCharsets.UTF_8), 5) :: BinaryData("122".getBytes(StandardCharsets.UTF_8), 3) :: BinaryData("121".getBytes(StandardCharsets.UTF_8), 2) :: BinaryData("123".getBytes(StandardCharsets.UTF_8), 4) :: Nil).toDF() - df.registerTempTable("binaryData") + df.createOrReplaceTempView("binaryData") df } protected lazy val upperCaseData: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( UpperCaseData(1, "A") :: UpperCaseData(2, "B") :: UpperCaseData(3, "C") :: UpperCaseData(4, "D") :: UpperCaseData(5, "E") :: UpperCaseData(6, "F") :: Nil).toDF() - df.registerTempTable("upperCaseData") + df.createOrReplaceTempView("upperCaseData") df } protected lazy val lowerCaseData: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( LowerCaseData(1, "a") :: LowerCaseData(2, "b") :: LowerCaseData(3, "c") :: LowerCaseData(4, "d") :: Nil).toDF() - df.registerTempTable("lowerCaseData") + df.createOrReplaceTempView("lowerCaseData") df } protected lazy val arrayData: RDD[ArrayData] = { - val rdd = sqlContext.sparkContext.parallelize( + val rdd = spark.sparkContext.parallelize( ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3))) :: ArrayData(Seq(2, 3, 4), Seq(Seq(2, 3, 4))) :: Nil) - rdd.toDF().registerTempTable("arrayData") + rdd.toDF().createOrReplaceTempView("arrayData") rdd } protected lazy val mapData: RDD[MapData] = { - val rdd = sqlContext.sparkContext.parallelize( + val rdd = spark.sparkContext.parallelize( MapData(Map(1 -> "a1", 2 -> "b1", 3 -> "c1", 4 -> "d1", 5 -> "e1")) :: MapData(Map(1 -> "a2", 2 -> "b2", 3 -> "c2", 4 -> "d2")) :: MapData(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")) :: MapData(Map(1 -> "a4", 2 -> "b4")) :: MapData(Map(1 -> "a5")) :: Nil) - rdd.toDF().registerTempTable("mapData") + rdd.toDF().createOrReplaceTempView("mapData") rdd } protected lazy val repeatedData: RDD[StringData] = { - val rdd = sqlContext.sparkContext.parallelize(List.fill(2)(StringData("test"))) - rdd.toDF().registerTempTable("repeatedData") + val rdd = spark.sparkContext.parallelize(List.fill(2)(StringData("test"))) + rdd.toDF().createOrReplaceTempView("repeatedData") rdd } protected lazy val nullableRepeatedData: RDD[StringData] = { - val rdd = sqlContext.sparkContext.parallelize( + val rdd = spark.sparkContext.parallelize( List.fill(2)(StringData(null)) ++ List.fill(2)(StringData("test"))) - rdd.toDF().registerTempTable("nullableRepeatedData") + rdd.toDF().createOrReplaceTempView("nullableRepeatedData") rdd } protected lazy val nullInts: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( NullInts(1) :: NullInts(2) :: NullInts(3) :: NullInts(null) :: Nil).toDF() - df.registerTempTable("nullInts") + df.createOrReplaceTempView("nullInts") df } protected lazy val allNulls: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( NullInts(null) :: NullInts(null) :: NullInts(null) :: NullInts(null) :: Nil).toDF() - df.registerTempTable("allNulls") + df.createOrReplaceTempView("allNulls") df } protected lazy val nullStrings: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = 
spark.sparkContext.parallelize( NullStrings(1, "abc") :: NullStrings(2, "ABC") :: NullStrings(3, null) :: Nil).toDF() - df.registerTempTable("nullStrings") + df.createOrReplaceTempView("nullStrings") df } protected lazy val tableName: DataFrame = { - val df = sqlContext.sparkContext.parallelize(TableName("test") :: Nil).toDF() - df.registerTempTable("tableName") + val df = spark.sparkContext.parallelize(TableName("test") :: Nil).toDF() + df.createOrReplaceTempView("tableName") df } protected lazy val unparsedStrings: RDD[String] = { - sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( "1, A1, true, null" :: "2, B2, false, null" :: "3, C3, true, null" :: @@ -214,44 +214,44 @@ private[sql] trait SQLTestData { self => // An RDD with 4 elements and 8 partitions protected lazy val withEmptyParts: RDD[IntField] = { - val rdd = sqlContext.sparkContext.parallelize((1 to 4).map(IntField), 8) - rdd.toDF().registerTempTable("withEmptyParts") + val rdd = spark.sparkContext.parallelize((1 to 4).map(IntField), 8) + rdd.toDF().createOrReplaceTempView("withEmptyParts") rdd } protected lazy val person: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( Person(0, "mike", 30) :: Person(1, "jim", 20) :: Nil).toDF() - df.registerTempTable("person") + df.createOrReplaceTempView("person") df } protected lazy val salary: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( Salary(0, 2000.0) :: Salary(1, 1000.0) :: Nil).toDF() - df.registerTempTable("salary") + df.createOrReplaceTempView("salary") df } protected lazy val complexData: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( ComplexData(Map("1" -> 1), TestData(1, "1"), Seq(1, 1, 1), true) :: ComplexData(Map("2" -> 2), TestData(2, "2"), Seq(2, 2, 2), false) :: Nil).toDF() - df.registerTempTable("complexData") + df.createOrReplaceTempView("complexData") df } protected lazy val courseSales: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( CourseSales("dotNET", 2012, 10000) :: CourseSales("Java", 2012, 20000) :: CourseSales("dotNET", 2012, 5000) :: CourseSales("dotNET", 2013, 48000) :: CourseSales("Java", 2013, 30000) :: Nil).toDF() - df.registerTempTable("courseSales") + df.createOrReplaceTempView("courseSales") df } @@ -259,7 +259,7 @@ private[sql] trait SQLTestData { self => * Initialize all test data such that all temp tables are properly registered. 
*/ def loadTestData(): Unit = { - assert(sqlContext != null, "attempted to initialize test data before SQLContext.") + assert(spark != null, "attempted to initialize test data before SparkSession.") emptyTestData testData testData2 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index ffb206af0e906..d4d8e3e4e83d5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -22,6 +22,7 @@ import java.util.UUID import scala.language.implicitConversions import scala.util.Try +import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration import org.scalatest.BeforeAndAfterAll @@ -29,11 +30,12 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.analysis.NoSuchTableException +import org.apache.spark.sql.catalyst.catalog.SessionCatalog.DEFAULT_DATABASE import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.FilterExec -import org.apache.spark.util.Utils +import org.apache.spark.util.{UninterruptibleThread, Utils} /** * Helper trait that should be extended by all SQL test suites. @@ -50,23 +52,23 @@ private[sql] trait SQLTestUtils with BeforeAndAfterAll with SQLTestData { self => - protected def sparkContext = sqlContext.sparkContext + protected def sparkContext = spark.sparkContext // Whether to materialize all test data before the first test is run private var loadTestDataBeforeTests = false // Shorthand for running a query using our SQLContext - protected lazy val sql = sqlContext.sql _ + protected lazy val sql = spark.sql _ /** * A helper object for importing SQL implicits. * - * Note that the alternative of importing `sqlContext.implicits._` is not possible here. + * Note that the alternative of importing `spark.implicits._` is not possible here. * This is because we create the [[SQLContext]] immediately before the first test is run, * but the implicits import is needed in the constructor. */ protected object testImplicits extends SQLImplicits { - protected override def _sqlContext: SQLContext = self.sqlContext + protected override def _sqlContext: SQLContext = self.spark.sqlContext } /** @@ -92,12 +94,12 @@ private[sql] trait SQLTestUtils */ protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { val (keys, values) = pairs.unzip - val currentValues = keys.map(key => Try(sqlContext.conf.getConfString(key)).toOption) - (keys, values).zipped.foreach(sqlContext.conf.setConfString) + val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption) + (keys, values).zipped.foreach(spark.conf.set) try f finally { keys.zip(currentValues).foreach { - case (key, Some(value)) => sqlContext.conf.setConfString(key, value) - case (key, None) => sqlContext.conf.unsetConf(key) + case (key, Some(value)) => spark.conf.set(key, value) + case (key, None) => spark.conf.unset(key) } } } @@ -138,9 +140,9 @@ private[sql] trait SQLTestUtils // temp tables that never got created. 
functions.foreach { case (functionName, isTemporary) => val withTemporary = if (isTemporary) "TEMPORARY" else "" - sqlContext.sql(s"DROP $withTemporary FUNCTION IF EXISTS $functionName") + spark.sql(s"DROP $withTemporary FUNCTION IF EXISTS $functionName") assert( - !sqlContext.sessionState.catalog.functionExists(FunctionIdentifier(functionName)), + !spark.sessionState.catalog.functionExists(FunctionIdentifier(functionName)), s"Function $functionName should have been dropped. But, it still exists.") } } @@ -149,11 +151,11 @@ private[sql] trait SQLTestUtils /** * Drops temporary table `tableName` after calling `f`. */ - protected def withTempTable(tableNames: String*)(f: => Unit): Unit = { + protected def withTempView(tableNames: String*)(f: => Unit): Unit = { try f finally { // If the test failed part way, we don't want to mask the failure by failing to remove // temp tables that never got created. - try tableNames.foreach(sqlContext.dropTempTable) catch { + try tableNames.foreach(spark.catalog.dropTempView) catch { case _: NoSuchTableException => } } @@ -165,7 +167,7 @@ private[sql] trait SQLTestUtils protected def withTable(tableNames: String*)(f: => Unit): Unit = { try f finally { tableNames.foreach { name => - sqlContext.sql(s"DROP TABLE IF EXISTS $name") + spark.sql(s"DROP TABLE IF EXISTS $name") } } } @@ -176,7 +178,7 @@ private[sql] trait SQLTestUtils protected def withView(viewNames: String*)(f: => Unit): Unit = { try f finally { viewNames.foreach { name => - sqlContext.sql(s"DROP VIEW IF EXISTS $name") + spark.sql(s"DROP VIEW IF EXISTS $name") } } } @@ -191,12 +193,17 @@ private[sql] trait SQLTestUtils val dbName = s"db_${UUID.randomUUID().toString.replace('-', '_')}" try { - sqlContext.sql(s"CREATE DATABASE $dbName") + spark.sql(s"CREATE DATABASE $dbName") } catch { case cause: Throwable => fail("Failed to create temporary database", cause) } - try f(dbName) finally sqlContext.sql(s"DROP DATABASE $dbName CASCADE") + try f(dbName) finally { + if (spark.catalog.currentDatabase == dbName) { + spark.sql(s"USE ${DEFAULT_DATABASE}") + } + spark.sql(s"DROP DATABASE $dbName CASCADE") + } } /** @@ -204,8 +211,8 @@ private[sql] trait SQLTestUtils * `f` returns. */ protected def activateDatabase(db: String)(f: => Unit): Unit = { - sqlContext.sessionState.catalog.setCurrentDatabase(db) - try f finally sqlContext.sessionState.catalog.setCurrentDatabase("default") + spark.sessionState.catalog.setCurrentDatabase(db) + try f finally spark.sessionState.catalog.setCurrentDatabase("default") } /** @@ -213,15 +220,11 @@ private[sql] trait SQLTestUtils */ protected def stripSparkFilter(df: DataFrame): DataFrame = { val schema = df.schema - val withoutFilters = df.queryExecution.sparkPlan transform { + val withoutFilters = df.queryExecution.sparkPlan.transform { case FilterExec(_, child) => child } - val childRDD = withoutFilters - .execute() - .map(row => Row.fromSeq(row.copy().toSeq(schema))) - - sqlContext.createDataFrame(childRDD, schema) + spark.internalCreateDataFrame(withoutFilters.execute(), schema) } /** @@ -229,7 +232,7 @@ private[sql] trait SQLTestUtils * way to construct [[DataFrame]] directly out of local data without relying on implicits. */ protected implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = { - Dataset.ofRows(sqlContext.sparkSession, plan) + Dataset.ofRows(spark, plan) } /** @@ -245,6 +248,46 @@ private[sql] trait SQLTestUtils } } } + + /** Run a test on a separate [[UninterruptibleThread]]. 
*/ + protected def testWithUninterruptibleThread(name: String, quietly: Boolean = false) + (body: => Unit): Unit = { + val timeoutMillis = 10000 + @transient var ex: Throwable = null + + def runOnThread(): Unit = { + val thread = new UninterruptibleThread(s"Testing thread for test $name") { + override def run(): Unit = { + try { + body + } catch { + case NonFatal(e) => + ex = e + } + } + } + thread.setDaemon(true) + thread.start() + thread.join(timeoutMillis) + if (thread.isAlive) { + thread.interrupt() + // If this interrupt does not work, then this thread is most likely running something that + // is not interruptible. There is not much point in waiting for the thread to terminate, and + // we would rather let the JVM terminate the thread on exit. + fail( + s"Test '$name' running on o.a.s.util.UninterruptibleThread timed out after" + + s" $timeoutMillis ms") + } else if (ex != null) { + throw ex + } + } + + if (quietly) { + testQuietly(name) { runOnThread() } + } else { + test(name) { runOnThread() } + } + } } private[sql] object SQLTestUtils { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index 914c6a550900a..79c37faa4e9ba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -17,37 +17,42 @@ package org.apache.spark.sql.test -import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.sql.SQLContext +import org.apache.spark.SparkConf +import org.apache.spark.sql.{SparkSession, SQLContext} /** - * Helper trait for SQL test suites where all tests share a single [[TestSQLContext]]. + * Helper trait for SQL test suites where all tests share a single [[TestSparkSession]]. */ trait SharedSQLContext extends SQLTestUtils { protected val sparkConf = new SparkConf() /** - * The [[TestSQLContext]] to use for all tests in this suite. + * The [[TestSparkSession]] to use for all tests in this suite. * * By default, the underlying [[org.apache.spark.SparkContext]] will be run in local * mode with the default test configurations. */ - private var _ctx: TestSQLContext = null + private var _spark: TestSparkSession = null + + /** + * The [[TestSparkSession]] to use for all tests in this suite. + */ + protected implicit def spark: SparkSession = _spark /** * The [[TestSQLContext]] to use for all tests in this suite. */ - protected implicit def sqlContext: SQLContext = _ctx + protected implicit def sqlContext: SQLContext = _spark.sqlContext /** - * Initialize the [[TestSQLContext]]. + * Initialize the [[TestSparkSession]]. 
*/ protected override def beforeAll(): Unit = { - SQLContext.clearSqlListener() - if (_ctx == null) { - _ctx = new TestSQLContext(sparkConf) + SparkSession.sqlListener.set(null) + if (_spark == null) { + _spark = new TestSparkSession(sparkConf) } // Ensure we have initialized the context before calling parent code super.beforeAll() @@ -58,9 +63,9 @@ trait SharedSQLContext extends SQLTestUtils { */ protected override def afterAll(): Unit = { try { - if (_ctx != null) { - _ctx.sparkContext.stop() - _ctx = null + if (_spark != null) { + _spark.stop() + _spark = null } } finally { super.afterAll() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala index 5ef80b9aa35ab..2f247ca3e8b7f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -18,21 +18,13 @@ package org.apache.spark.sql.test import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.sql.{SparkSession, SQLContext} +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.internal.{SessionState, SQLConf} /** - * A special [[SQLContext]] prepared for testing. + * A special [[SparkSession]] prepared for testing. */ -private[sql] class TestSQLContext( - @transient override val sparkSession: SparkSession, - isRootContext: Boolean) - extends SQLContext(sparkSession, isRootContext) { self => - - def this(sc: SparkContext) { - this(new TestSparkSession(sc), true) - } - +private[sql] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) { self => def this(sparkConf: SparkConf) { this(new SparkContext("local[2]", "test-sql-context", sparkConf.set("spark.sql.testkey", "true"))) @@ -42,20 +34,6 @@ private[sql] class TestSQLContext( this(new SparkConf) } - // Needed for Java tests - def loadTestData(): Unit = { - testData.loadTestData() - } - - private object testData extends SQLTestData { - protected override def sqlContext: SQLContext = self - } - -} - - -private[sql] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) { self => - @transient protected[sql] override lazy val sessionState: SessionState = new SessionState(self) { override lazy val conf: SQLConf = { @@ -70,6 +48,14 @@ private[sql] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) { } } + // Needed for Java tests + def loadTestData(): Unit = { + testData.loadTestData() + } + + private object testData extends SQLTestData { + protected override def spark: SparkSession = self + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/ContinuousQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/ContinuousQueryListenerSuite.scala deleted file mode 100644 index 3498fe83d02eb..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/util/ContinuousQueryListenerSuite.scala +++ /dev/null @@ -1,223 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.util - -import java.util.concurrent.ConcurrentLinkedQueue - -import scala.util.control.NonFatal - -import org.scalatest.BeforeAndAfter -import org.scalatest.PrivateMethodTester._ -import org.scalatest.concurrent.AsyncAssertions.Waiter -import org.scalatest.concurrent.Eventually._ -import org.scalatest.concurrent.PatienceConfiguration.Timeout -import org.scalatest.time.SpanSugar._ - -import org.apache.spark.sql._ -import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.util.ContinuousQueryListener.{QueryProgress, QueryStarted, QueryTerminated} - -class ContinuousQueryListenerSuite extends StreamTest with SharedSQLContext with BeforeAndAfter { - - import testImplicits._ - - after { - sqlContext.streams.active.foreach(_.stop()) - assert(sqlContext.streams.active.isEmpty) - assert(addedListeners.isEmpty) - // Make sure we don't leak any events to the next test - sqlContext.sparkContext.listenerBus.waitUntilEmpty(10000) - } - - test("single listener") { - val listener = new QueryStatusCollector - val input = MemoryStream[Int] - withListenerAdded(listener) { - testStream(input.toDS)( - StartStream, - Assert("Incorrect query status in onQueryStarted") { - val status = listener.startStatus - assert(status != null) - assert(status.active == true) - assert(status.sourceStatuses.size === 1) - assert(status.sourceStatuses(0).description.contains("Memory")) - - // The source and sink offsets must be None as this must be called before the - // batches have started - assert(status.sourceStatuses(0).offset === None) - assert(status.sinkStatus.offset === CompositeOffset(None :: Nil)) - - // No progress events or termination events - assert(listener.progressStatuses.isEmpty) - assert(listener.terminationStatus === null) - }, - AddDataMemory(input, Seq(1, 2, 3)), - CheckAnswer(1, 2, 3), - Assert("Incorrect query status in onQueryProgress") { - eventually(Timeout(streamingTimeout)) { - - // There should be only on progress event as batch has been processed - assert(listener.progressStatuses.size === 1) - val status = listener.progressStatuses.peek() - assert(status != null) - assert(status.active == true) - assert(status.sourceStatuses(0).offset === Some(LongOffset(0))) - assert(status.sinkStatus.offset === CompositeOffset.fill(LongOffset(0))) - - // No termination events - assert(listener.terminationStatus === null) - } - }, - StopStream, - Assert("Incorrect query status in onQueryTerminated") { - eventually(Timeout(streamingTimeout)) { - val status = listener.terminationStatus - assert(status != null) - - assert(status.active === false) // must be inactive by the time onQueryTerm is called - assert(status.sourceStatuses(0).offset === Some(LongOffset(0))) - assert(status.sinkStatus.offset === CompositeOffset.fill(LongOffset(0))) - } - listener.checkAsyncErrors() - } - ) - } - } - - test("adding and removing listener") { - def isListenerActive(listener: QueryStatusCollector): Boolean = { - listener.reset() - testStream(MemoryStream[Int].toDS)( - StartStream, - StopStream - ) - 
listener.startStatus != null - } - - try { - val listener1 = new QueryStatusCollector - val listener2 = new QueryStatusCollector - - sqlContext.streams.addListener(listener1) - assert(isListenerActive(listener1) === true) - assert(isListenerActive(listener2) === false) - sqlContext.streams.addListener(listener2) - assert(isListenerActive(listener1) === true) - assert(isListenerActive(listener2) === true) - sqlContext.streams.removeListener(listener1) - assert(isListenerActive(listener1) === false) - assert(isListenerActive(listener2) === true) - } finally { - addedListeners.foreach(sqlContext.streams.removeListener) - } - } - - test("event ordering") { - val listener = new QueryStatusCollector - withListenerAdded(listener) { - for (i <- 1 to 100) { - listener.reset() - require(listener.startStatus === null) - testStream(MemoryStream[Int].toDS)( - StartStream, - Assert(listener.startStatus !== null, "onQueryStarted not called before query returned"), - StopStream, - Assert { listener.checkAsyncErrors() } - ) - } - } - } - - - private def withListenerAdded(listener: ContinuousQueryListener)(body: => Unit): Unit = { - try { - failAfter(1 minute) { - sqlContext.streams.addListener(listener) - body - } - } finally { - sqlContext.streams.removeListener(listener) - } - } - - private def addedListeners(): Array[ContinuousQueryListener] = { - val listenerBusMethod = - PrivateMethod[ContinuousQueryListenerBus]('listenerBus) - val listenerBus = sqlContext.streams invokePrivate listenerBusMethod() - listenerBus.listeners.toArray.map(_.asInstanceOf[ContinuousQueryListener]) - } - - class QueryStatusCollector extends ContinuousQueryListener { - - private val asyncTestWaiter = new Waiter // to catch errors in the async listener events - - @volatile var startStatus: QueryStatus = null - @volatile var terminationStatus: QueryStatus = null - val progressStatuses = new ConcurrentLinkedQueue[QueryStatus] - - def reset(): Unit = { - startStatus = null - terminationStatus = null - progressStatuses.clear() - - // To reset the waiter - try asyncTestWaiter.await(timeout(1 milliseconds)) catch { - case NonFatal(e) => - } - } - - def checkAsyncErrors(): Unit = { - asyncTestWaiter.await(timeout(streamingTimeout)) - } - - - override def onQueryStarted(queryStarted: QueryStarted): Unit = { - asyncTestWaiter { - startStatus = QueryStatus(queryStarted.query) - } - } - - override def onQueryProgress(queryProgress: QueryProgress): Unit = { - asyncTestWaiter { - assert(startStatus != null, "onQueryProgress called before onQueryStarted") - progressStatuses.add(QueryStatus(queryProgress.query)) - } - } - - override def onQueryTerminated(queryTerminated: QueryTerminated): Unit = { - asyncTestWaiter { - assert(startStatus != null, "onQueryTerminated called before onQueryStarted") - terminationStatus = QueryStatus(queryTerminated.query) - } - asyncTestWaiter.dismiss() - } - } - - case class QueryStatus( - active: Boolean, - exception: Option[Exception], - sourceStatuses: Array[SourceStatus], - sinkStatus: SinkStatus) - - object QueryStatus { - def apply(query: ContinuousQuery): QueryStatus = { - QueryStatus(query.isActive, query.exception, query.sourceStatuses, query.sinkStatus) - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala index 8a0578c1ff537..3ae5ce610d2a6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala @@ -39,7 +39,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { metrics += ((funcName, qe, duration)) } } - sqlContext.listenerManager.register(listener) + spark.listenerManager.register(listener) val df = Seq(1 -> "a").toDF("i", "j") df.select("i").collect() @@ -55,7 +55,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { assert(metrics(1)._2.analyzed.isInstanceOf[Aggregate]) assert(metrics(1)._3 > 0) - sqlContext.listenerManager.unregister(listener) + spark.listenerManager.unregister(listener) } test("execute callback functions when a DataFrame action failed") { @@ -68,7 +68,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { // Only test failed case here, so no need to implement `onSuccess` override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {} } - sqlContext.listenerManager.register(listener) + spark.listenerManager.register(listener) val errorUdf = udf[Int, Int] { _ => throw new RuntimeException("udf error") } val df = sparkContext.makeRDD(Seq(1 -> "a")).toDF("i", "j") @@ -82,7 +82,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { assert(metrics(0)._2.analyzed.isInstanceOf[Project]) assert(metrics(0)._3.getMessage == e.getMessage) - sqlContext.listenerManager.unregister(listener) + spark.listenerManager.unregister(listener) } test("get numRows metrics by callback") { @@ -99,7 +99,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { metrics += metric.value } } - sqlContext.listenerManager.register(listener) + spark.listenerManager.register(listener) val df = Seq(1 -> "a").toDF("i", "j").groupBy("i").count() df.collect() @@ -111,7 +111,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { assert(metrics(1) === 1) assert(metrics(2) === 2) - sqlContext.listenerManager.unregister(listener) + spark.listenerManager.unregister(listener) } // TODO: Currently some LongSQLMetric use -1 as initial value, so if the accumulator is never @@ -131,10 +131,10 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { metrics += bottomAgg.longMetric("dataSize").value } } - sqlContext.listenerManager.register(listener) + spark.listenerManager.register(listener) val sparkListener = new SaveInfoListener - sqlContext.sparkContext.addSparkListener(sparkListener) + spark.sparkContext.addSparkListener(sparkListener) val df = (1 to 100).map(i => i -> i.toString).toDF("i", "j") df.groupBy("i").count().collect() @@ -157,6 +157,6 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { assert(metrics(0) == topAggDataSize) assert(metrics(1) == bottomAggDataSize) - sqlContext.listenerManager.unregister(listener) + spark.listenerManager.unregister(listener) } } diff --git a/sql/hive-thriftserver/if/TCLIService.thrift b/sql/hive-thriftserver/if/TCLIService.thrift index baf583fb3ecd4..7cd6fa37cec37 100644 --- a/sql/hive-thriftserver/if/TCLIService.thrift +++ b/sql/hive-thriftserver/if/TCLIService.thrift @@ -661,7 +661,7 @@ union TGetInfoValue { // The function returns general information about the data source // using the same keys as ODBC. 
struct TGetInfoReq { - // The sesssion to run this request against + // The session to run this request against 1: required TSessionHandle sessionHandle 2: required TGetInfoType infoType @@ -1032,7 +1032,7 @@ enum TFetchOrientation { FETCH_PRIOR, // Return the rowset at the given fetch offset relative - // to the curren rowset. + // to the current rowset. // NOT SUPPORTED FETCH_RELATIVE, diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index 987103b1a9536..c0d8ca53ed2da 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.0.0-SNAPSHOT + 2.0.3-SNAPSHOT ../../pom.xml @@ -64,24 +64,17 @@ ${hive.group} hive-beeline - - com.sun.jersey - jersey-core - - - com.sun.jersey - jersey-json - - - com.sun.jersey - jersey-server - org.seleniumhq.selenium selenium-java test + + org.seleniumhq.selenium + selenium-htmlunit-driver + test + org.apache.spark spark-sql_${scala.binary.version} @@ -91,7 +84,7 @@ org.apache.spark - spark-test-tags_${scala.binary.version} + spark-tags_${scala.binary.version} net.sf.jpam @@ -118,12 +111,6 @@ - - - - org.codehaus.mojo - build-helper-maven-plugin - add-source generate-sources diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/Service.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/Service.java index 2111837cac8d1..b95077cd62186 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/Service.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/Service.java @@ -29,7 +29,7 @@ public interface Service { /** * Service states */ - public enum STATE { + enum STATE { /** Constructed but not initialized */ NOTINITED, diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/ServiceStateChangeListener.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/ServiceStateChangeListener.java index 16ad9a991e921..d1aadad04cf67 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/ServiceStateChangeListener.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/ServiceStateChangeListener.java @@ -29,7 +29,7 @@ public interface ServiceStateChangeListener { * have changed state before this callback is invoked. * * This operation is invoked on the thread that initiated the state change, - * while the service itself in in a sychronized section. + * while the service itself in in a synchronized section. *
        * 1. Any long-lived operation here will prevent the service state
        *    change from completing in a timely manner.
      2. diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/ServiceUtils.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/ServiceUtils.java index e712aaf2348f8..edb5eff9615bf 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/ServiceUtils.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/ServiceUtils.java @@ -41,4 +41,4 @@ public static int indexOfDomainMatch(String userName) { } return endIdx; } -} \ No newline at end of file +} diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HttpAuthUtils.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HttpAuthUtils.java index 3ef55779a6bde..5021528299682 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HttpAuthUtils.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/HttpAuthUtils.java @@ -56,7 +56,7 @@ public final class HttpAuthUtils { private static final String COOKIE_CLIENT_USER_NAME = "cu"; private static final String COOKIE_CLIENT_RAND_NUMBER = "rn"; private static final String COOKIE_KEY_VALUE_SEPARATOR = "="; - private final static Set COOKIE_ATTRIBUTES = + private static final Set COOKIE_ATTRIBUTES = new HashSet(Arrays.asList(COOKIE_CLIENT_USER_NAME, COOKIE_CLIENT_RAND_NUMBER)); /** @@ -93,10 +93,10 @@ public static String getKerberosServiceTicket(String principal, String host, */ public static String createCookieToken(String clientUserName) { StringBuffer sb = new StringBuffer(); - sb.append(COOKIE_CLIENT_USER_NAME).append(COOKIE_KEY_VALUE_SEPARATOR).append(clientUserName). - append(COOKIE_ATTR_SEPARATOR); - sb.append(COOKIE_CLIENT_RAND_NUMBER).append(COOKIE_KEY_VALUE_SEPARATOR). - append((new Random(System.currentTimeMillis())).nextLong()); + sb.append(COOKIE_CLIENT_USER_NAME).append(COOKIE_KEY_VALUE_SEPARATOR).append(clientUserName) + .append(COOKIE_ATTR_SEPARATOR); + sb.append(COOKIE_CLIENT_RAND_NUMBER).append(COOKIE_KEY_VALUE_SEPARATOR) + .append((new Random(System.currentTimeMillis())).nextLong()); return sb.toString(); } diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/KerberosSaslHelper.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/KerberosSaslHelper.java index 11d26699fe78d..52eb752f1e026 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/KerberosSaslHelper.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/KerberosSaslHelper.java @@ -96,7 +96,7 @@ private static class CLIServiceProcessorFactory extends TProcessorFactory { private final ThriftCLIService service; private final Server saslServer; - public CLIServiceProcessorFactory(Server saslServer, ThriftCLIService service) { + CLIServiceProcessorFactory(Server saslServer, ThriftCLIService service) { super(null); this.service = service; this.saslServer = saslServer; diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/SaslQOP.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/SaslQOP.java index 479ebf32cec3f..ab3ac6285aa02 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/SaslQOP.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/auth/SaslQOP.java @@ -25,10 +25,12 @@ * Possible values of SASL quality-of-protection value. */ public enum SaslQOP { - AUTH("auth"), // Authentication only. - AUTH_INT("auth-int"), // Authentication and integrity checking by using signatures. 
- AUTH_CONF("auth-conf"); // Authentication, integrity and confidentiality checking - // by using signatures and encryption. + // Authentication only. + AUTH("auth"), + // Authentication and integrity checking by using signatures. + AUTH_INT("auth-int"), + // Authentication, integrity and confidentiality checking by using signatures and encryption. + AUTH_CONF("auth-conf"); public final String saslQop; diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/CLIService.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/CLIService.java index a3af7b2d662b3..791ddcbd2c5b6 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/CLIService.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/CLIService.java @@ -481,8 +481,8 @@ public synchronized String getDelegationTokenFromMetaStore(String owner) @Override public String getDelegationToken(SessionHandle sessionHandle, HiveAuthFactory authFactory, String owner, String renewer) throws HiveSQLException { - String delegationToken = sessionManager.getSession(sessionHandle). - getDelegationToken(authFactory, owner, renewer); + String delegationToken = sessionManager.getSession(sessionHandle) + .getDelegationToken(authFactory, owner, renewer); LOG.info(sessionHandle + ": getDelegationToken()"); return delegationToken; } @@ -490,8 +490,7 @@ public String getDelegationToken(SessionHandle sessionHandle, HiveAuthFactory au @Override public void cancelDelegationToken(SessionHandle sessionHandle, HiveAuthFactory authFactory, String tokenStr) throws HiveSQLException { - sessionManager.getSession(sessionHandle). - cancelDelegationToken(authFactory, tokenStr); + sessionManager.getSession(sessionHandle).cancelDelegationToken(authFactory, tokenStr); LOG.info(sessionHandle + ": cancelDelegationToken()"); } diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/HiveSQLException.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/HiveSQLException.java index 1334dde66375b..86e57fbf31fe0 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/HiveSQLException.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/HiveSQLException.java @@ -111,7 +111,7 @@ public HiveSQLException(TStatus status) { /** * Converts current object to a {@link TStatus} object - * @return a {@link TStatus} object + * @return a {@link TStatus} object */ public TStatus toTStatus() { // TODO: convert sqlState, etc. 
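`HiveSQLException.toTStatus` (continued in the next hunk) serializes an exception's cause chain while dropping stack frames shared with the enclosing throwable. A self-contained sketch of that frame-trimming idea is below; the object and method names are hypothetical and do not use the Hive classes:

```scala
object StackTraceTrimming {
  /**
   * Return `cause`'s stack frames with the suffix it shares with `parent` removed,
   * mirroring the "... N more" elision that printStackTrace applies to common frames.
   */
  def trimCommonFrames(
      cause: Throwable,
      parent: Array[StackTraceElement]): Seq[StackTraceElement] = {
    val trace = cause.getStackTrace
    var m = trace.length - 1
    var n = parent.length - 1
    // Walk both traces from the bottom while the frames are identical.
    while (m >= 0 && n >= 0 && trace(m) == parent(n)) {
      m -= 1
      n -= 1
    }
    trace.take(m + 1).toSeq
  }

  def main(args: Array[String]): Unit = {
    val outer = new RuntimeException("outer", new IllegalStateException("inner"))
    trimCommonFrames(outer.getCause, outer.getStackTrace).foreach(println)
  }
}
```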
@@ -125,8 +125,8 @@ public TStatus toTStatus() { /** * Converts the specified {@link Exception} object into a {@link TStatus} object - * @param e a {@link Exception} object - * @return a {@link TStatus} object + * @param e a {@link Exception} object + * @return a {@link TStatus} object */ public static TStatus toTStatus(Exception e) { if (e instanceof HiveSQLException) { @@ -155,7 +155,8 @@ private static List toString(Throwable cause, StackTraceElement[] parent if (parent != null) { int n = parent.length - 1; while (m >= 0 && n >= 0 && trace[m].equals(parent[n])) { - m--; n--; + m--; + n--; } } List detail = enroll(cause, trace, m); diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/OperationState.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/OperationState.java index 51ffb40369b24..1165180118413 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/OperationState.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/OperationState.java @@ -85,6 +85,7 @@ public static void validateTransition(OperationState oldState, if (OperationState.CLOSED.equals(newState)) { return; } + break; default: // fall-through } diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/RowBasedSet.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/RowBasedSet.java index a0ee2109dc540..7452137f077db 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/RowBasedSet.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/RowBasedSet.java @@ -130,8 +130,8 @@ public void remove() { } private static class RemovableList extends ArrayList { - public RemovableList() { super(); } - public RemovableList(List rows) { super(rows); } + RemovableList() { super(); } + RemovableList(List rows) { super(rows); } @Override public void removeRange(int fromIndex, int toIndex) { super.removeRange(fromIndex, toIndex); diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/ClassicTableTypeMapping.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/ClassicTableTypeMapping.java index 87ac39b05186a..05a6bf938404b 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/ClassicTableTypeMapping.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/ClassicTableTypeMapping.java @@ -42,7 +42,7 @@ public enum ClassicTableTypes { private final Map hiveToClientMap = new HashMap(); private final Map clientToHiveMap = new HashMap(); - public ClassicTableTypeMapping () { + public ClassicTableTypeMapping() { hiveToClientMap.put(TableType.MANAGED_TABLE.toString(), ClassicTableTypes.TABLE.toString()); hiveToClientMap.put(TableType.EXTERNAL_TABLE.toString(), diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/GetColumnsOperation.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/GetColumnsOperation.java index 309f10f640f96..5efb0759383ac 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/GetColumnsOperation.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/GetColumnsOperation.java @@ -27,10 +27,8 @@ import java.util.Map.Entry; import java.util.regex.Pattern; -import org.apache.hadoop.hive.conf.HiveConf.ConfVars; import org.apache.hadoop.hive.metastore.IMetaStoreClient; import org.apache.hadoop.hive.metastore.api.Table; -import 
org.apache.hadoop.hive.ql.plan.HiveOperation; import org.apache.hadoop.hive.ql.security.authorization.plugin.HiveOperationType; import org.apache.hadoop.hive.ql.security.authorization.plugin.HivePrivilegeObject; import org.apache.hadoop.hive.ql.security.authorization.plugin.HivePrivilegeObject.HivePrivilegeObjectType; diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/GetFunctionsOperation.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/GetFunctionsOperation.java index 6df1e8a227f3e..5273c386b83d4 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/GetFunctionsOperation.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/GetFunctionsOperation.java @@ -23,7 +23,6 @@ import java.util.Set; import org.apache.hadoop.hive.metastore.IMetaStoreClient; -import org.apache.hadoop.hive.metastore.api.MetaException; import org.apache.hadoop.hive.ql.exec.FunctionInfo; import org.apache.hadoop.hive.ql.exec.FunctionRegistry; import org.apache.hadoop.hive.ql.security.authorization.plugin.HiveOperationType; @@ -103,7 +102,7 @@ public void runInternal() throws HiveSQLException { .getFunctionNames(CLIServiceUtils.patternToRegex(functionName)); for (String functionName : functionNames) { FunctionInfo functionInfo = FunctionRegistry.getFunctionInfo(functionName); - Object rowData[] = new Object[] { + Object[] rowData = new Object[] { null, // FUNCTION_CAT null, // FUNCTION_SCHEM functionInfo.getDisplayName(), // FUNCTION_NAME diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/GetSchemasOperation.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/GetSchemasOperation.java index e56686abb7c5a..d6f6280f1c398 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/GetSchemasOperation.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/GetSchemasOperation.java @@ -18,16 +18,8 @@ package org.apache.hive.service.cli.operation; -import java.util.ArrayList; -import java.util.List; - import org.apache.hadoop.hive.metastore.IMetaStoreClient; -import org.apache.hadoop.hive.ql.security.authorization.plugin.HiveAccessControlException; -import org.apache.hadoop.hive.ql.security.authorization.plugin.HiveAuthzContext; -import org.apache.hadoop.hive.ql.security.authorization.plugin.HiveAuthzPluginException; import org.apache.hadoop.hive.ql.security.authorization.plugin.HiveOperationType; -import org.apache.hadoop.hive.ql.security.authorization.plugin.HivePrivilegeObject; -import org.apache.hadoop.hive.ql.session.SessionState; import org.apache.hive.service.cli.FetchOrientation; import org.apache.hive.service.cli.HiveSQLException; import org.apache.hive.service.cli.OperationState; diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/GetTableTypesOperation.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/GetTableTypesOperation.java index a09b39a4e0855..3ae012a72764f 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/GetTableTypesOperation.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/GetTableTypesOperation.java @@ -44,8 +44,8 @@ public class GetTableTypesOperation extends MetadataOperation { protected GetTableTypesOperation(HiveSession parentSession) { super(parentSession, OperationType.GET_TABLE_TYPES); - String tableMappingStr = 
getParentSession().getHiveConf(). - getVar(HiveConf.ConfVars.HIVE_SERVER2_TABLE_TYPE_MAPPING); + String tableMappingStr = getParentSession().getHiveConf() + .getVar(HiveConf.ConfVars.HIVE_SERVER2_TABLE_TYPE_MAPPING); tableTypeMapping = TableTypeMappingFactory.getTableTypeMapping(tableMappingStr); rowSet = RowSetFactory.create(RESULT_SET_SCHEMA, getProtocolVersion()); diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/GetTablesOperation.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/GetTablesOperation.java index 0e2fdc657c4fa..1a7ca79163d71 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/GetTablesOperation.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/GetTablesOperation.java @@ -64,8 +64,8 @@ protected GetTablesOperation(HiveSession parentSession, this.catalogName = catalogName; this.schemaName = schemaName; this.tableName = tableName; - String tableMappingStr = getParentSession().getHiveConf(). - getVar(HiveConf.ConfVars.HIVE_SERVER2_TABLE_TYPE_MAPPING); + String tableMappingStr = getParentSession().getHiveConf() + .getVar(HiveConf.ConfVars.HIVE_SERVER2_TABLE_TYPE_MAPPING); tableTypeMapping = TableTypeMappingFactory.getTableTypeMapping(tableMappingStr); if (tableTypes != null) { diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/GetTypeInfoOperation.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/GetTypeInfoOperation.java index 2a0fec27715da..0f72071d7e7d1 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/GetTypeInfoOperation.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/GetTypeInfoOperation.java @@ -35,7 +35,7 @@ */ public class GetTypeInfoOperation extends MetadataOperation { - private final static TableSchema RESULT_SET_SCHEMA = new TableSchema() + private static final TableSchema RESULT_SET_SCHEMA = new TableSchema() .addPrimitiveColumn("TYPE_NAME", Type.STRING_TYPE, "Type name") .addPrimitiveColumn("DATA_TYPE", Type.INT_TYPE, diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/LogDivertAppender.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/LogDivertAppender.java index 70340bd13cbc1..cb804318ace9c 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/LogDivertAppender.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/LogDivertAppender.java @@ -60,15 +60,15 @@ private static class NameFilter extends Filter { /* Patterns that are excluded in verbose logging level. * Filter out messages coming from log processing classes, or we'll run an infinite loop. */ - private static final Pattern verboseExcludeNamePattern = Pattern.compile(Joiner.on("|"). - join(new String[] {LOG.getName(), OperationLog.class.getName(), + private static final Pattern verboseExcludeNamePattern = Pattern.compile(Joiner.on("|") + .join(new String[] {LOG.getName(), OperationLog.class.getName(), OperationManager.class.getName()})); /* Patterns that are included in execution logging level. * In execution mode, show only select logger messages. */ - private static final Pattern executionIncludeNamePattern = Pattern.compile(Joiner.on("|"). 
- join(new String[] {"org.apache.hadoop.mapreduce.JobSubmitter", + private static final Pattern executionIncludeNamePattern = Pattern.compile(Joiner.on("|") + .join(new String[] {"org.apache.hadoop.mapreduce.JobSubmitter", "org.apache.hadoop.mapreduce.Job", "SessionState", Task.class.getName(), "org.apache.hadoop.hive.ql.exec.spark.status.SparkJobMonitor"})); @@ -88,7 +88,7 @@ private void setCurrentNamePattern(OperationLog.LoggingLevel mode) { } } - public NameFilter( + NameFilter( OperationLog.LoggingLevel loggingMode, OperationManager op) { this.operationManager = op; this.loggingMode = loggingMode; @@ -131,7 +131,7 @@ public int decide(LoggingEvent ev) { /** This is where the log message will go to */ private final CharArrayWriter writer = new CharArrayWriter(); - private void setLayout (boolean isVerbose, Layout lo) { + private void setLayout(boolean isVerbose, Layout lo) { if (isVerbose) { if (lo == null) { lo = CLIServiceUtils.verboseLayout; diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/MetadataOperation.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/MetadataOperation.java index 4595ef56fcee4..6c819876a556d 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/MetadataOperation.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/MetadataOperation.java @@ -18,7 +18,6 @@ package org.apache.hive.service.cli.operation; -import java.util.ArrayList; import java.util.List; import org.apache.hadoop.hive.conf.HiveConf; diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/SQLOperation.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/SQLOperation.java index 33ee16b80beb2..5014cedd870b6 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/SQLOperation.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/SQLOperation.java @@ -326,7 +326,7 @@ public TableSchema getResultSetSchema() throws HiveSQLException { return resultSchema; } - private transient final List convey = new ArrayList(); + private final transient List convey = new ArrayList(); @Override public RowSet getNextRowSet(FetchOrientation orientation, long maxRows) throws HiveSQLException { @@ -456,7 +456,7 @@ private SerDe getSerDe() throws SQLException { private HiveConf getConfigForOperation() throws HiveSQLException { HiveConf sqlOperationConf = getParentSession().getHiveConf(); if (!getConfOverlay().isEmpty() || shouldRunAsync()) { - // clone the partent session config for this query + // clone the parent session config for this query sqlOperationConf = new HiveConf(sqlOperationConf); // apply overlay query specific settings, if any diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/TableTypeMapping.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/TableTypeMapping.java index 3a8a07f44f20d..e392c459cf586 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/TableTypeMapping.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/TableTypeMapping.java @@ -27,18 +27,18 @@ public interface TableTypeMapping { * @param clientTypeName * @return */ - public String mapToHiveType (String clientTypeName); + String mapToHiveType(String clientTypeName); /** * Map hive's table type name to client's table type * @param clientTypeName * @return */ - 
public String mapToClientType (String hiveTypeName); + String mapToClientType(String hiveTypeName); /** * Get all the table types of this mapping * @return */ - public Set getTableTypeNames(); + Set getTableTypeNames(); } diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionBase.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionBase.java index 9b04d679df1c7..b72c18b2b2135 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionBase.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionBase.java @@ -18,8 +18,6 @@ package org.apache.hive.service.cli.session; -import java.util.Map; - import org.apache.hadoop.hive.conf.HiveConf; import org.apache.hadoop.hive.ql.session.SessionState; import org.apache.hive.service.cli.SessionHandle; @@ -27,7 +25,6 @@ import org.apache.hive.service.cli.thrift.TProtocolVersion; import java.io.File; -import java.util.Map; /** * Methods that don't need to be executed under a doAs diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionHook.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionHook.java deleted file mode 100644 index 06388cc795b90..0000000000000 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionHook.java +++ /dev/null @@ -1,37 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.hive.service.cli.session; - -import org.apache.hadoop.hive.ql.hooks.Hook; -import org.apache.hive.service.cli.HiveSQLException; - -/** - * HiveSessionHook. - * HiveServer2 session level Hook interface. The run method is executed - * when session manager starts a new session - * - */ -public interface HiveSessionHook extends Hook { - - /** - * @param sessionHookContext context - * @throws HiveSQLException - */ - public void run(HiveSessionHookContext sessionHookContext) throws HiveSQLException; -} diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionHookContext.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionHookContext.java index 156c8147f9d93..c56a107d42466 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionHookContext.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionHookContext.java @@ -22,7 +22,7 @@ /** * HiveSessionHookContext. * Interface passed to the HiveServer2 session hook execution. 
This enables - * the hook implementation to accesss session config, user and session handle + * the hook implementation to access session config, user and session handle */ public interface HiveSessionHookContext { @@ -30,17 +30,17 @@ public interface HiveSessionHookContext { * Retrieve session conf * @return */ - public HiveConf getSessionConf(); + HiveConf getSessionConf(); /** * The get the username starting the session * @return */ - public String getSessionUser(); + String getSessionUser(); /** * Retrieve handle for the session * @return */ - public String getSessionHandle(); + String getSessionHandle(); } diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java index cc3e807e7a840..47bfaa86021d6 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImpl.java @@ -29,7 +29,7 @@ import java.util.Set; import org.apache.commons.io.FileUtils; -import org.apache.commons.lang.StringUtils; +import org.apache.commons.lang3.StringUtils; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.hadoop.hive.common.cli.HiveFileProcessor; diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImplwithUGI.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImplwithUGI.java index a29e5d1d81c17..762dbb2faadec 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImplwithUGI.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionImplwithUGI.java @@ -26,7 +26,6 @@ import org.apache.hadoop.hive.conf.HiveConf; import org.apache.hadoop.hive.ql.metadata.Hive; import org.apache.hadoop.hive.ql.metadata.HiveException; -import org.apache.hadoop.hive.shims.ShimLoader; import org.apache.hadoop.hive.shims.Utils; import org.apache.hadoop.security.UserGroupInformation; import org.apache.hive.service.auth.HiveAuthFactory; @@ -83,7 +82,7 @@ public UserGroupInformation getSessionUgi() { return this.sessionUgi; } - public String getDelegationToken () { + public String getDelegationToken() { return this.delegationTokenStr; } diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionProxy.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionProxy.java index 5b10521febfc2..8e539512f7410 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionProxy.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/HiveSessionProxy.java @@ -57,7 +57,7 @@ public Object invoke(Object arg0, final Method method, final Object[] args) return invoke(method, args); } return ugi.doAs( - new PrivilegedExceptionAction () { + new PrivilegedExceptionAction() { @Override public Object run() throws HiveSQLException { return invoke(method, args); diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/SessionManager.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/SessionManager.java index e31570bdfba57..de066dd406c7a 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/SessionManager.java +++ 
b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/session/SessionManager.java @@ -22,7 +22,6 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Date; -import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.Future; @@ -35,7 +34,6 @@ import org.apache.commons.logging.LogFactory; import org.apache.hadoop.hive.conf.HiveConf; import org.apache.hadoop.hive.conf.HiveConf.ConfVars; -import org.apache.hadoop.hive.ql.hooks.HookUtils; import org.apache.hive.service.CompositeService; import org.apache.hive.service.cli.HiveSQLException; import org.apache.hive.service.cli.SessionHandle; @@ -151,7 +149,7 @@ public synchronized void start() { } private void startTimeoutChecker() { - final long interval = Math.max(checkInterval, 3000l); // minimum 3 seconds + final long interval = Math.max(checkInterval, 3000L); // minimum 3 seconds Runnable timeoutChecker = new Runnable() { @Override public void run() { @@ -268,17 +266,6 @@ public SessionHandle openSession(TProtocolVersion protocol, String username, Str if (isOperationLogEnabled) { session.setOperationLogSessionDir(operationLogRootDir); } - try { - executeSessionHooks(session); - } catch (Exception e) { - try { - session.close(); - } catch (Throwable t) { - LOG.warn("Error closing session", t); - } - session = null; - throw new HiveSQLException("Failed to execute session hooks", e); - } handleToSession.put(session.getSessionHandle(), session); return session.getSessionHandle(); } @@ -361,15 +348,6 @@ public static void clearProxyUserName() { threadLocalProxyUserName.remove(); } - // execute session hooks - private void executeSessionHooks(HiveSession session) throws Exception { - List sessionHooks = HookUtils.getHooks(hiveConf, - HiveConf.ConfVars.HIVE_SERVER2_SESSION_HOOK, HiveSessionHook.class); - for (HiveSessionHook sessionHook : sessionHooks) { - sessionHook.run(new HiveSessionHookContextImpl(session)); - } - } - public Future submitBackgroundOperation(Runnable r) { return backgroundOperationPool.submit(r); } diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/thrift/ThriftCLIService.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/thrift/ThriftCLIService.java index 5a0f1c83c70f3..ad7a9a238f8a9 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/thrift/ThriftCLIService.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/thrift/ThriftCLIService.java @@ -18,6 +18,7 @@ package org.apache.hive.service.cli.thrift; +import javax.security.auth.login.LoginException; import java.io.IOException; import java.net.InetAddress; import java.net.UnknownHostException; @@ -25,8 +26,6 @@ import java.util.Map; import java.util.concurrent.TimeUnit; -import javax.security.auth.login.LoginException; - import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.hadoop.hive.conf.HiveConf; @@ -36,17 +35,7 @@ import org.apache.hive.service.ServiceUtils; import org.apache.hive.service.auth.HiveAuthFactory; import org.apache.hive.service.auth.TSetIpAddressProcessor; -import org.apache.hive.service.cli.CLIService; -import org.apache.hive.service.cli.FetchOrientation; -import org.apache.hive.service.cli.FetchType; -import org.apache.hive.service.cli.GetInfoType; -import org.apache.hive.service.cli.GetInfoValue; -import org.apache.hive.service.cli.HiveSQLException; -import org.apache.hive.service.cli.OperationHandle; -import 
org.apache.hive.service.cli.OperationStatus; -import org.apache.hive.service.cli.RowSet; -import org.apache.hive.service.cli.SessionHandle; -import org.apache.hive.service.cli.TableSchema; +import org.apache.hive.service.cli.*; import org.apache.hive.service.cli.session.SessionManager; import org.apache.hive.service.server.HiveServer2; import org.apache.thrift.TException; @@ -223,23 +212,7 @@ public InetAddress getServerIPAddress() { public TGetDelegationTokenResp GetDelegationToken(TGetDelegationTokenReq req) throws TException { TGetDelegationTokenResp resp = new TGetDelegationTokenResp(); - - if (hiveAuthFactory == null) { - resp.setStatus(unsecureTokenErrorStatus()); - } else { - try { - String token = cliService.getDelegationToken( - new SessionHandle(req.getSessionHandle()), - hiveAuthFactory, req.getOwner(), req.getRenewer()); - resp.setDelegationToken(token); - resp.setStatus(OK_STATUS); - } catch (HiveSQLException e) { - LOG.error("Error obtaining delegation token", e); - TStatus tokenErrorStatus = HiveSQLException.toTStatus(e); - tokenErrorStatus.setSqlState("42000"); - resp.setStatus(tokenErrorStatus); - } - } + resp.setStatus(notSupportTokenErrorStatus()); return resp; } @@ -247,19 +220,7 @@ public TGetDelegationTokenResp GetDelegationToken(TGetDelegationTokenReq req) public TCancelDelegationTokenResp CancelDelegationToken(TCancelDelegationTokenReq req) throws TException { TCancelDelegationTokenResp resp = new TCancelDelegationTokenResp(); - - if (hiveAuthFactory == null) { - resp.setStatus(unsecureTokenErrorStatus()); - } else { - try { - cliService.cancelDelegationToken(new SessionHandle(req.getSessionHandle()), - hiveAuthFactory, req.getDelegationToken()); - resp.setStatus(OK_STATUS); - } catch (HiveSQLException e) { - LOG.error("Error canceling delegation token", e); - resp.setStatus(HiveSQLException.toTStatus(e)); - } - } + resp.setStatus(notSupportTokenErrorStatus()); return resp; } @@ -267,25 +228,13 @@ public TCancelDelegationTokenResp CancelDelegationToken(TCancelDelegationTokenRe public TRenewDelegationTokenResp RenewDelegationToken(TRenewDelegationTokenReq req) throws TException { TRenewDelegationTokenResp resp = new TRenewDelegationTokenResp(); - if (hiveAuthFactory == null) { - resp.setStatus(unsecureTokenErrorStatus()); - } else { - try { - cliService.renewDelegationToken(new SessionHandle(req.getSessionHandle()), - hiveAuthFactory, req.getDelegationToken()); - resp.setStatus(OK_STATUS); - } catch (HiveSQLException e) { - LOG.error("Error obtaining renewing token", e); - resp.setStatus(HiveSQLException.toTStatus(e)); - } - } + resp.setStatus(notSupportTokenErrorStatus()); return resp; } - private TStatus unsecureTokenErrorStatus() { + private TStatus notSupportTokenErrorStatus() { TStatus errorStatus = new TStatus(TStatusCode.ERROR_STATUS); - errorStatus.setErrorMessage("Delegation token only supported over remote " + - "client with kerberos authentication"); + errorStatus.setErrorMessage("Delegation token is not supported"); return errorStatus; } @@ -722,8 +671,8 @@ private String getProxyUser(String realUser, Map sessionConf, } // If there's no authentication, then directly substitute the user - if (HiveAuthFactory.AuthTypes.NONE.toString(). 
- equalsIgnoreCase(hiveConf.getVar(ConfVars.HIVE_SERVER2_AUTHENTICATION))) { + if (HiveAuthFactory.AuthTypes.NONE.toString() + .equalsIgnoreCase(hiveConf.getVar(ConfVars.HIVE_SERVER2_AUTHENTICATION))) { return proxyUser; } diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/thrift/ThriftHttpCLIService.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/thrift/ThriftHttpCLIService.java index 3b57efa38b588..341a7fdbb59b8 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/thrift/ThriftHttpCLIService.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/thrift/ThriftHttpCLIService.java @@ -37,12 +37,15 @@ import org.apache.thrift.protocol.TBinaryProtocol; import org.apache.thrift.protocol.TProtocolFactory; import org.apache.thrift.server.TServlet; -import org.eclipse.jetty.server.nio.SelectChannelConnector; -import org.eclipse.jetty.server.ssl.SslSelectChannelConnector; +import org.eclipse.jetty.server.AbstractConnectionFactory; +import org.eclipse.jetty.server.ConnectionFactory; +import org.eclipse.jetty.server.HttpConnectionFactory; +import org.eclipse.jetty.server.ServerConnector; import org.eclipse.jetty.servlet.ServletContextHandler; import org.eclipse.jetty.servlet.ServletHolder; import org.eclipse.jetty.util.ssl.SslContextFactory; import org.eclipse.jetty.util.thread.ExecutorThreadPool; +import org.eclipse.jetty.util.thread.ScheduledExecutorScheduler; public class ThriftHttpCLIService extends ThriftCLIService { @@ -59,9 +62,6 @@ public ThriftHttpCLIService(CLIService cliService) { @Override public void run() { try { - // HTTP Server - httpServer = new org.eclipse.jetty.server.Server(); - // Server thread pool // Start with minWorkerThreads, expand till maxWorkerThreads and reject subsequent requests String threadPoolName = "HiveServer2-HttpHandler-Pool"; @@ -69,10 +69,13 @@ public void run() { workerKeepAliveTime, TimeUnit.SECONDS, new SynchronousQueue(), new ThreadFactoryWithGarbageCleanup(threadPoolName)); ExecutorThreadPool threadPool = new ExecutorThreadPool(executorService); - httpServer.setThreadPool(threadPool); + + // HTTP Server + httpServer = new org.eclipse.jetty.server.Server(threadPool); // Connector configs - SelectChannelConnector connector = new SelectChannelConnector(); + + ConnectionFactory[] connectionFactories; boolean useSsl = hiveConf.getBoolVar(ConfVars.HIVE_SERVER2_USE_SSL); String schemeName = useSsl ? 
"https" : "http"; // Change connector if SSL is used @@ -92,14 +95,27 @@ public void run() { Arrays.toString(sslContextFactory.getExcludeProtocols())); sslContextFactory.setKeyStorePath(keyStorePath); sslContextFactory.setKeyStorePassword(keyStorePassword); - connector = new SslSelectChannelConnector(sslContextFactory); + connectionFactories = AbstractConnectionFactory.getFactories( + sslContextFactory, new HttpConnectionFactory()); + } else { + connectionFactories = new ConnectionFactory[] { new HttpConnectionFactory() }; } + ServerConnector connector = new ServerConnector( + httpServer, + null, + // Call this full constructor to set this, which forces daemon threads: + new ScheduledExecutorScheduler("HiveServer2-HttpHandler-JettyScheduler", true), + null, + -1, + -1, + connectionFactories); + connector.setPort(portNum); // Linux:yes, Windows:no connector.setReuseAddress(!Shell.WINDOWS); int maxIdleTime = (int) hiveConf.getTimeVar(ConfVars.HIVE_SERVER2_THRIFT_HTTP_MAX_IDLE_TIME, TimeUnit.MILLISECONDS); - connector.setMaxIdleTime(maxIdleTime); + connector.setIdleTimeout(maxIdleTime); httpServer.addConnector(connector); diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/thrift/ThriftHttpServlet.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/thrift/ThriftHttpServlet.java index 56c8cb6e5459d..e15d2d0566d2b 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/thrift/ThriftHttpServlet.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/thrift/ThriftHttpServlet.java @@ -21,7 +21,6 @@ import java.io.IOException; import java.io.UnsupportedEncodingException; import java.security.PrivilegedExceptionAction; -import java.util.Arrays; import java.util.Map; import java.util.Random; import java.util.Set; @@ -241,9 +240,9 @@ private String getClientNameFromCookie(Cookie[] cookies) { * Each cookie is of the format [key]=[value] */ private String toCookieStr(Cookie[] cookies) { - String cookieStr = ""; + String cookieStr = ""; - for (Cookie c : cookies) { + for (Cookie c : cookies) { cookieStr += c.getName() + "=" + c.getValue() + " ;\n"; } return cookieStr; @@ -458,7 +457,7 @@ private String getPrincipalWithoutRealmAndHost(String fullPrincipal) private String getUsername(HttpServletRequest request, String authType) throws HttpAuthenticationException { - String creds[] = getAuthHeaderTokens(request, authType); + String[] creds = getAuthHeaderTokens(request, authType); // Username must be present if (creds[0] == null || creds[0].isEmpty()) { throw new HttpAuthenticationException("Authorization header received " + @@ -469,7 +468,7 @@ private String getUsername(HttpServletRequest request, String authType) private String getPassword(HttpServletRequest request, String authType) throws HttpAuthenticationException { - String creds[] = getAuthHeaderTokens(request, authType); + String[] creds = getAuthHeaderTokens(request, authType); // Password must be present if (creds[1] == null || creds[1].isEmpty()) { throw new HttpAuthenticationException("Authorization header received " + diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/server/HiveServer2.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/server/HiveServer2.java index 1500e537cef5b..9bf96cff572e8 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/server/HiveServer2.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/server/HiveServer2.java @@ -236,8 +236,8 @@ ServerOptionsExecutor 
getServerOptionsExecutor() { /** * The executor interface for running the appropriate HiveServer2 command based on parsed options */ - static interface ServerOptionsExecutor { - public void execute(); + interface ServerOptionsExecutor { + void execute(); } /** diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index 94b1ced9908d8..e3258d858f1cc 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -157,7 +157,7 @@ object HiveThriftServer2 extends Logging { /** - * A inner sparkListener called in sc.stop to clean up the HiveThriftServer2 + * An inner sparkListener called in sc.stop to clean up the HiveThriftServer2 */ private[thriftserver] class HiveThriftServer2Listener( val server: HiveServer2, diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index e8bcdd76efd7a..8a78523f56b45 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -51,18 +51,16 @@ private[hive] class SparkExecuteStatementOperation( private var result: DataFrame = _ private var iter: Iterator[SparkRow] = _ + private var iterHeader: Iterator[SparkRow] = _ private var dataTypes: Array[DataType] = _ private var statementId: String = _ private lazy val resultSchema: TableSchema = { - if (result == null || result.queryExecution.analyzed.output.size == 0) { + if (result == null || result.schema.isEmpty) { new TableSchema(Arrays.asList(new FieldSchema("Result", "string", ""))) } else { - logInfo(s"Result Schema: ${result.queryExecution.analyzed.output}") - val schema = result.queryExecution.analyzed.output.map { attr => - new FieldSchema(attr.name, attr.dataType.catalogString, "") - } - new TableSchema(schema.asJava) + logInfo(s"Result Schema: ${result.schema}") + SparkExecuteStatementOperation.getTableSchema(result.schema) } } @@ -110,6 +108,14 @@ private[hive] class SparkExecuteStatementOperation( assertState(OperationState.FINISHED) setHasResultSet(true) val resultRowSet: RowSet = RowSetFactory.create(getResultSetSchema, getProtocolVersion) + + // Reset iter to header when fetching start from first row + if (order.equals(FetchOrientation.FETCH_FIRST)) { + val (ita, itb) = iterHeader.duplicate + iter = ita + iterHeader = itb + } + if (!iter.hasNext) { resultRowSet } else { @@ -228,6 +234,9 @@ private[hive] class SparkExecuteStatementOperation( result.collect().iterator } } + val (itra, itrb) = iter.duplicate + iterHeader = itra + iter = itrb dataTypes = result.queryExecution.analyzed.output.map(_.dataType).toArray } catch { case e: HiveSQLException => @@ -269,3 +278,13 @@ private[hive] class SparkExecuteStatementOperation( } } } + +object SparkExecuteStatementOperation { + def getTableSchema(structType: StructType): TableSchema = { + val schema = structType.map { field => + val attrTypeString = if (field.dataType == NullType) "void" else field.dataType.catalogString + new FieldSchema(field.name, attrTypeString, "") + } + new 
TableSchema(schema.asJava) + } +} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 1402e0a687290..5dafec1c3021b 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -32,8 +32,7 @@ import org.apache.hadoop.hive.common.{HiveInterruptCallback, HiveInterruptUtils} import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.exec.Utilities -import org.apache.hadoop.hive.ql.processors.{AddResourceProcessor, CommandProcessor, - CommandProcessorFactory, SetProcessor} +import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.session.SessionState import org.apache.thrift.transport.TSocket @@ -157,6 +156,14 @@ private[hive] object SparkSQLCLIDriver extends Logging { // Execute -i init files (always in silent mode) cli.processInitFiles(sessionState) + // Respect the configurations set by --hiveconf from the command line + // (based on Hive's CliDriver). + val it = sessionState.getOverriddenConfigurations.entrySet().iterator() + while (it.hasNext) { + val kv = it.next() + SparkSQLEnv.sqlContext.setConf(kv.getKey, kv.getValue) + } + if (sessionState.execString != null) { System.exit(cli.processLine(sessionState.execString)) } @@ -295,9 +302,7 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { System.exit(0) } if (tokens(0).toLowerCase(Locale.ENGLISH).equals("source") || - cmd_trimmed.startsWith("!") || - tokens(0).toLowerCase.equals("list") || - isRemoteMode) { + cmd_trimmed.startsWith("!") || isRemoteMode) { val start = System.currentTimeMillis() super.processCmd(cmd) val end = System.currentTimeMillis() @@ -312,7 +317,8 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { if (proc != null) { // scalastyle:off println if (proc.isInstanceOf[Driver] || proc.isInstanceOf[SetProcessor] || - proc.isInstanceOf[AddResourceProcessor]) { + proc.isInstanceOf[AddResourceProcessor] || proc.isInstanceOf[ListResourceProcessor] || + proc.isInstanceOf[ResetProcessor] ) { val driver = new SparkSQLDriver driver.init() diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala index c24e474d9ca46..0d5dc7af5f522 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala @@ -59,7 +59,7 @@ private[hive] class SparkSQLDriver(val context: SQLContext = SparkSQLEnv.sqlCont // TODO unify the error code try { context.sparkContext.setJobDescription(command) - val execution = context.executePlan(context.sql(command).logicalPlan) + val execution = context.sessionState.executePlan(context.sql(command).logicalPlan) hiveResponse = execution.hiveResultString() tableSchema = getResultSetSchema(execution) new CommandProcessorResponse(0) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index 665a44e51a0c7..638911599aad3 100644 --- 
a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -54,13 +54,15 @@ private[hive] object SparkSQLEnv extends Logging { "spark.kryo.referenceTracking", maybeKryoReferenceTracking.getOrElse("false")) - sparkContext = new SparkContext(sparkConf) - sqlContext = SparkSession.withHiveSupport(sparkContext).wrapped - val sessionState = sqlContext.sessionState.asInstanceOf[HiveSessionState] + val sparkSession = SparkSession.builder.config(sparkConf).enableHiveSupport().getOrCreate() + sparkContext = sparkSession.sparkContext + sqlContext = sparkSession.sqlContext + + val sessionState = sparkSession.sessionState.asInstanceOf[HiveSessionState] sessionState.metadataHive.setOut(new PrintStream(System.out, true, "UTF-8")) sessionState.metadataHive.setInfo(new PrintStream(System.err, true, "UTF-8")) sessionState.metadataHive.setError(new PrintStream(System.err, true, "UTF-8")) - sqlContext.setConf("spark.sql.hive.version", HiveUtils.hiveExecutionVersion) + sparkSession.conf.set("spark.sql.hive.version", HiveUtils.hiveExecutionVersion) } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala index 1e4c4790856be..fc5b221ee3f49 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala @@ -79,6 +79,9 @@ private[hive] class SparkSQLSessionManager(hiveServer: HiveServer2, sqlContext: sqlContext.newSession() } ctx.setConf("spark.sql.hive.version", HiveUtils.hiveExecutionVersion) + if (sessionConf != null && sessionConf.containsKey("use:database")) { + ctx.sql(s"use ${sessionConf.get("use:database")}") + } sparkSqlOperationManager.sessionToContexts += sessionHandle -> ctx sessionHandle } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala index c82fa4eaaa4e5..2e0fa1ef77f88 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala @@ -30,7 +30,7 @@ import org.apache.spark.ui._ import org.apache.spark.ui.UIUtils._ -/** Page for Spark Web UI that shows statistics of a thrift server */ +/** Page for Spark Web UI that shows statistics of the thrift server */ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("") with Logging { private val listener = parent.listener diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala index 008108a5ce06d..f39e9dcd3a5bb 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.hive.thriftserver.HiveThriftServer2.{ExecutionInfo, import 
org.apache.spark.ui._ import org.apache.spark.ui.UIUtils._ -/** Page for Spark Web UI that shows statistics of a streaming job */ +/** Page for Spark Web UI that shows statistics of jobs running in the thrift server */ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) extends WebUIPage("session") with Logging { @@ -60,7 +60,7 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) UIUtils.headerSparkPage("JDBC/ODBC Session", content, parent, Some(5000)) } - /** Generate basic stats of the streaming program */ + /** Generate basic stats of the thrift server program */ private def generateBasicStats(): Seq[Node] = { val timeSinceStart = System.currentTimeMillis() - startTime.getTime
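The SparkSQLEnv and SparkSQLSessionManager changes above move session setup to `SparkSession.builder` with Hive support and honour a `use:database` request at connection time. A rough sketch of that bootstrap under those assumptions is below; the system property name and the version string are placeholders, and spark-hive is assumed to be on the classpath:

```scala
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

object HiveSessionBootstrap {
  def main(args: Array[String]): Unit = {
    val sparkConf = new SparkConf(loadDefaults = true)
      .setIfMissing("spark.master", "local[2]")
      .setAppName("SparkSQL::example")

    // SparkSession.builder with enableHiveSupport() replaces the removed
    // SparkSession.withHiveSupport(sparkContext) path.
    val sparkSession = SparkSession.builder
      .config(sparkConf)
      .enableHiveSupport()
      .getOrCreate()

    // Runtime configuration goes through sparkSession.conf rather than SQLContext.setConf.
    sparkSession.conf.set("spark.sql.hive.version", "1.2.1") // placeholder version string

    // Honour a requested default database, as the session manager does for
    // connection URIs like jdbc:hive2://host:port/my_db ("example.default.database" is made up).
    sys.props.get("example.default.database").foreach(db => sparkSession.sql(s"USE $db"))

    sparkSession.sql("SHOW TABLES").show()
    sparkSession.stop()
  }
}
```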
          diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala index 923ba8a30c5c5..db2066009b351 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.hive.thriftserver.ui.ThriftServerTab._ import org.apache.spark.ui.{SparkUI, SparkUITab} /** - * Spark Web UI tab that shows statistics of a streaming job. + * Spark Web UI tab that shows statistics of jobs running in the thrift server. * This assumes the given SparkContext has enabled its SparkUI. */ private[thriftserver] class ThriftServerTab(sparkContext: SparkContext) diff --git a/sql/hive-thriftserver/src/test/resources/TestUDTF.jar b/sql/hive-thriftserver/src/test/resources/TestUDTF.jar new file mode 100644 index 0000000000000..514f2d5d26fd3 Binary files /dev/null and b/sql/hive-thriftserver/src/test/resources/TestUDTF.jar differ diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index 3fa2f884e2c4c..d3cec11bd7567 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -62,13 +62,13 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { /** * Run a CLI operation and expect all the queries and expected answers to be returned. + * * @param timeout maximum time for the commands to complete * @param extraArgs any extra arguments * @param errorResponses a sequence of strings whose presence in the stdout of the forked process * is taken as an immediate error condition. That is: if a line containing * with one of these strings is found, fail the test immediately. * The default value is `Seq("Error:")` - * * @param queriesAndExpectedAnswers one or more tuples of query + answer */ def runCliWithin( @@ -91,6 +91,8 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$jdbcUrl | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath | --hiveconf ${ConfVars.SCRATCHDIR}=$scratchDirPath + | --hiveconf conf1=conftest + | --hiveconf conf2=1 """.stripMargin.split("\\s+").toSeq ++ extraArgs } @@ -238,4 +240,47 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { runCliWithin(2.minute, Seq("-e", "!echo \"This is a test for Spark-11624\";"))( "" -> "This is a test for Spark-11624") } + + test("list jars") { + val jarFile = Thread.currentThread().getContextClassLoader.getResource("TestUDTF.jar") + runCliWithin(2.minute)( + s"ADD JAR $jarFile;" -> "", + s"LIST JARS;" -> "TestUDTF.jar" + ) + } + + test("list jar ") { + val jarFile = Thread.currentThread().getContextClassLoader.getResource("TestUDTF.jar") + runCliWithin(2.minute)( + s"ADD JAR $jarFile;" -> "", + s"List JAR $jarFile;" -> "TestUDTF.jar" + ) + } + + test("list files") { + val dataFilePath = Thread.currentThread(). 
+ getContextClassLoader.getResource("data/files/small_kv.txt") + runCliWithin(2.minute)( + s"ADD FILE $dataFilePath;" -> "", + s"LIST FILES;" -> "small_kv.txt" + ) + } + + test("list file ") { + val dataFilePath = Thread.currentThread(). + getContextClassLoader.getResource("data/files/small_kv.txt") + runCliWithin(2.minute)( + s"ADD FILE $dataFilePath;" -> "", + s"LIST FILE $dataFilePath;" -> "small_kv.txt" + ) + } + + test("apply hiveconf from cli command") { + runCliWithin(2.minute)( + "SET conf1;" -> "conftest", + "SET conf2;" -> "1", + "SET conf3=${hiveconf:conf1};" -> "conftest", + "SET conf3;" -> "conftest" + ) + } } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index 55a93ea06ba57..8f2c4fafa0b43 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -36,6 +36,8 @@ import org.apache.hive.service.auth.PlainSaslHelper import org.apache.hive.service.cli.GetInfoType import org.apache.hive.service.cli.thrift.TCLIService.Client import org.apache.hive.service.cli.thrift.ThriftCLIServiceClient +import org.apache.hive.service.cli.FetchOrientation +import org.apache.hive.service.cli.FetchType import org.apache.thrift.protocol.TBinaryProtocol import org.apache.thrift.transport.TSocket import org.scalatest.BeforeAndAfterAll @@ -91,6 +93,52 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { } } + test("SPARK-16563 ThriftCLIService FetchResults repeat fetching result") { + withCLIServiceClient { client => + val user = System.getProperty("user.name") + val sessionHandle = client.openSession(user, "") + + withJdbcStatement { statement => + val queries = Seq( + "DROP TABLE IF EXISTS test_16563", + "CREATE TABLE test_16563(key INT, val STRING)", + s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_16563") + + queries.foreach(statement.execute) + val confOverlay = new java.util.HashMap[java.lang.String, java.lang.String] + val operationHandle = client.executeStatement( + sessionHandle, + "SELECT * FROM test_16563", + confOverlay) + + // Fetch result first time + assertResult(5, "Fetching result first time from next row") { + + val rows_next = client.fetchResults( + operationHandle, + FetchOrientation.FETCH_NEXT, + 1000, + FetchType.QUERY_OUTPUT) + + rows_next.numRows() + } + + // Fetch result second time from first row + assertResult(5, "Repeat fetching result from first row") { + + val rows_first = client.fetchResults( + operationHandle, + FetchOrientation.FETCH_FIRST, + 1000, + FetchType.QUERY_OUTPUT) + + rows_first.numRows() + } + statement.executeQuery("DROP TABLE IF EXISTS test_16563") + } + } + } + test("JDBC query execution") { withJdbcStatement { statement => val queries = Seq( @@ -533,7 +581,7 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { } assert(rs1.next()) - assert(rs1.getString(1) === "Usage: To be added.") + assert(rs1.getString(1) === "Usage: N/A.") val dataPath = "../hive/src/test/resources/data/files/kv1.txt" @@ -607,7 +655,7 @@ class SingleSessionSuite extends HiveThriftJdbcTest { } assert(rs2.next()) - assert(rs2.getString(1) === "Usage: To be added.") + assert(rs2.getString(1) === "Usage: N/A.") } finally { statement.executeQuery("DROP TEMPORARY FUNCTION udtf_count2") } 
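The test changes above all drive the Thrift server through a real client connection: the CLI suite gains `ADD JAR`/`LIST JARS`, `ADD FILE`/`LIST FILES`, and `--hiveconf` propagation checks, while the SPARK-16563 case verifies that a result set can be fetched again with `FETCH_FIRST` after a `FETCH_NEXT`. As a rough illustration of the client side these suites exercise, here is a minimal standalone JDBC sketch; it is not part of the patch, and the host, port, jar path, and table name are assumptions (a running HiveThriftServer2 and the hive-jdbc driver on the classpath are presumed):

```scala
// Minimal sketch of a JDBC client talking to a running Spark Thrift server.
// Assumptions: server listening on localhost:10000, hive-jdbc on the classpath,
// and a table named `test_16563` already created (as in the suite above).
import java.sql.DriverManager

object ThriftServerJdbcSketch {
  def main(args: Array[String]): Unit = {
    // Register the Hive JDBC driver explicitly (older driver versions need this).
    Class.forName("org.apache.hive.jdbc.HiveDriver")

    // The trailing path segment selects the session's default database,
    // which is what JdbcConnectionUriSuite below verifies.
    val url = "jdbc:hive2://localhost:10000/default"
    val connection = DriverManager.getConnection(url, System.getProperty("user.name"), "")
    try {
      val statement = connection.createStatement()

      // Session-scoped resource, analogous to the CLI `ADD JAR` / `LIST JARS` tests.
      // The jar path here is hypothetical.
      statement.execute("ADD JAR /tmp/TestUDTF.jar")

      // Ordinary query; count the rows returned to the client.
      val rs = statement.executeQuery("SELECT * FROM test_16563")
      var rows = 0
      while (rs.next()) rows += 1
      println(s"fetched $rows rows")
    } finally {
      connection.close()
    }
  }
}
```

Note that the repeat-fetch behaviour itself is only observable through the low-level Thrift `fetchResults(operationHandle, orientation, maxRows, fetchType)` calls used in the suite; plain JDBC hides the fetch orientation from the caller.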
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/JdbcConnectionUriSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/JdbcConnectionUriSuite.scala new file mode 100644 index 0000000000000..fb8a7e273ae44 --- /dev/null +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/JdbcConnectionUriSuite.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import java.sql.DriverManager + +import org.apache.hive.jdbc.HiveDriver + +import org.apache.spark.util.Utils + +class JdbcConnectionUriSuite extends HiveThriftServer2Test { + Utils.classForName(classOf[HiveDriver].getCanonicalName) + + override def mode: ServerMode.Value = ServerMode.binary + + val JDBC_TEST_DATABASE = "jdbc_test_database" + val USER = System.getProperty("user.name") + val PASSWORD = "" + + override protected def beforeAll(): Unit = { + super.beforeAll() + + val jdbcUri = s"jdbc:hive2://localhost:$serverPort/" + val connection = DriverManager.getConnection(jdbcUri, USER, PASSWORD) + val statement = connection.createStatement() + statement.execute(s"CREATE DATABASE $JDBC_TEST_DATABASE") + connection.close() + } + + override protected def afterAll(): Unit = { + try { + val jdbcUri = s"jdbc:hive2://localhost:$serverPort/" + val connection = DriverManager.getConnection(jdbcUri, USER, PASSWORD) + val statement = connection.createStatement() + statement.execute(s"DROP DATABASE $JDBC_TEST_DATABASE") + connection.close() + } finally { + super.afterAll() + } + } + + test("SPARK-17819 Support default database in connection URIs") { + val jdbcUri = s"jdbc:hive2://localhost:$serverPort/$JDBC_TEST_DATABASE" + val connection = DriverManager.getConnection(jdbcUri, USER, PASSWORD) + val statement = connection.createStatement() + try { + val resultSet = statement.executeQuery("select current_database()") + resultSet.next() + assert(resultSet.getString(1) === JDBC_TEST_DATABASE) + } finally { + statement.close() + connection.close() + } + } +} diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperationSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperationSuite.scala new file mode 100644 index 0000000000000..32ded0d254ef8 --- /dev/null +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperationSuite.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types.{NullType, StructField, StructType} + +class SparkExecuteStatementOperationSuite extends SparkFunSuite { + test("SPARK-17112 `select null` via JDBC triggers IllegalArgumentException in ThriftServer") { + val field1 = StructField("NULL", NullType) + val field2 = StructField("(IF(true, NULL, NULL))", NullType) + val tableSchema = StructType(Seq(field1, field2)) + val columns = SparkExecuteStatementOperation.getTableSchema(tableSchema).getColumnDescriptors() + assert(columns.size() == 2) + assert(columns.get(0).getType() == org.apache.hive.service.cli.Type.NULL_TYPE) + assert(columns.get(1).getType() == org.apache.hive.service.cli.Type.NULL_TYPE) + } +} diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index f082035852cc0..a54d234876256 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -40,6 +40,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { private val originalColumnBatchSize = TestHive.conf.columnBatchSize private val originalInMemoryPartitionPruning = TestHive.conf.inMemoryPartitionPruning private val originalConvertMetastoreOrc = TestHive.sessionState.convertMetastoreOrc + private val originalCrossJoinEnabled = TestHive.conf.crossJoinEnabled def testCases: Seq[(String, File)] = { hiveQueryDir.listFiles.map(f => f.getName.stripSuffix(".q") -> f) @@ -61,6 +62,8 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // Ensures that the plans generation use metastore relation and not OrcRelation // Was done because SqlBuilder does not work with plans having logical relation TestHive.setConf(HiveUtils.CONVERT_METASTORE_ORC, false) + // Ensures that cross joins are enabled so that we can test them + TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, true) RuleExecutor.resetTime() } @@ -72,6 +75,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) TestHive.setConf(HiveUtils.CONVERT_METASTORE_ORC, originalConvertMetastoreOrc) + TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, originalCrossJoinEnabled) TestHive.sessionState.functionRegistry.restore() // For debugging dump some statistics about how much time was spent in various optimizer rules @@ -503,7 +507,53 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // We have converted the useful parts of these tests to tests // in 
org.apache.spark.sql.hive.execution.SQLQuerySuite. "drop_database_removes_partition_dirs", - "drop_table_removes_partition_dirs" + "drop_table_removes_partition_dirs", + + // These tests use EXPLAIN FORMATTED, which is not supported + "input4", + "join0", + "plan_json", + + // This test uses CREATE EXTERNAL TABLE without specifying LOCATION + "alter2", + + // [SPARK-16248][SQL] Whitelist the list of Hive fallback functions + "udf_field", + "udf_reflect2", + "udf_xpath", + "udf_xpath_boolean", + "udf_xpath_double", + "udf_xpath_float", + "udf_xpath_int", + "udf_xpath_long", + "udf_xpath_short", + "udf_xpath_string", + + // These tests DROP TABLE that don't exist (but do not specify IF EXISTS) + "alter_rename_partition1", + "date_1", + "date_4", + "date_join1", + "date_serde", + "insert_compressed", + "lateral_view_cp", + "leftsemijoin", + "mapjoin_subquery2", + "nomore_ambiguous_table_col", + "partition_date", + "partition_varchar1", + "ppd_repeated_alias", + "push_or", + "reducesink_dedup", + "subquery_in", + "subquery_notin_having", + "timestamp_3", + "timestamp_lazy", + "udaf_covar_pop", + "union31", + "union_date", + "varchar_2", + "varchar_join1" ) /** @@ -516,9 +566,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "add_partition_no_whitelist", "add_partition_with_whitelist", "alias_casted_column", - "alter2", "alter_partition_with_whitelist", - "alter_rename_partition", "ambiguous_col", "annotate_stats_join", "annotate_stats_limit", @@ -595,12 +643,8 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "database_drop", "database_location", "database_properties", - "date_1", "date_2", - "date_4", "date_comparison", - "date_join1", - "date_serde", "decimal_1", "decimal_4", "decimal_join", @@ -699,7 +743,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "input26", "input28", "input2_limit", - "input4", "input40", "input41", "input49", @@ -727,8 +770,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "insert1", "insert1_overwrite_partitions", "insert2_overwrite_partitions", - "insert_compressed", - "join0", "join1", "join10", "join11", @@ -784,10 +825,8 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "join_reorder4", "join_star", "lateral_view", - "lateral_view_cp", "lateral_view_noalias", "lateral_view_ppd", - "leftsemijoin", "leftsemijoin_mr", "limit_pushdown_negative", "lineage1", @@ -815,7 +854,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "mapjoin_filter_on_outerjoin", "mapjoin_mapjoin", "mapjoin_subquery", - "mapjoin_subquery2", "mapjoin_test_outer", "mapreduce1", "mapreduce2", @@ -837,7 +875,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "multi_join_union", "multigroupby_singlemr", "noalias_subq1", - "nomore_ambiguous_table_col", "nonblock_op_deduplicate", "notable_alias1", "notable_alias2", @@ -861,12 +898,9 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "part_inherit_tbl_props", "part_inherit_tbl_props_with_star", "partcols1", - "partition_date", "partition_serde_format", "partition_type_check", - "partition_varchar1", "partition_wise_fileformat9", - "plan_json", "ppd1", "ppd2", "ppd_clusterby", @@ -885,7 +919,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "ppd_outer_join4", "ppd_outer_join5", "ppd_random", - "ppd_repeated_alias", "ppd_udf_col", "ppd_union", "ppr_allchildsarenull", @@ -893,7 
+926,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "ppr_pushdown2", "ppr_pushdown3", "progress_1", - "push_or", "query_with_semi", "quote1", "quote2", @@ -905,7 +937,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "reduce_deduplicate_exclude_gby", "reduce_deduplicate_exclude_join", "reduce_deduplicate_extended", - "reducesink_dedup", "router_join_ppr", "select_as_omitted", "select_unquote_and", @@ -924,17 +955,19 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "stats_aggregator_error_1", "stats_publisher_error_1", "subq2", + "subquery_exists", + "subquery_exists_having", + "subquery_notexists", + "subquery_notexists_having", + "subquery_in_having", "tablename_with_select", - "timestamp_3", "timestamp_comparison", - "timestamp_lazy", "timestamp_null", "transform_ppr1", "transform_ppr2", "type_cast_1", "type_widening", "udaf_collect_set", - "udaf_covar_pop", "udaf_histogram_numeric", "udf2", "udf5", @@ -946,8 +979,8 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf_PI", "udf_acos", "udf_add", - "udf_array", - "udf_array_contains", + // "udf_array", -- done in array.sql + // "udf_array_contains", -- done in array.sql "udf_ascii", "udf_asin", "udf_atan", @@ -983,7 +1016,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf_elt", "udf_equal", "udf_exp", - "udf_field", "udf_find_in_set", "udf_float", "udf_floor", @@ -1028,7 +1060,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf_power", "udf_radians", "udf_rand", - "udf_reflect2", "udf_regexp", "udf_regexp_extract", "udf_regexp_replace", @@ -1069,14 +1100,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf_variance", "udf_weekofyear", "udf_when", - "udf_xpath", - "udf_xpath_boolean", - "udf_xpath_double", - "udf_xpath_float", - "udf_xpath_int", - "udf_xpath_long", - "udf_xpath_short", - "udf_xpath_string", "union10", "union11", "union13", @@ -1098,7 +1121,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "union29", "union3", "union30", - "union31", "union33", "union34", "union4", @@ -1107,13 +1129,10 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "union7", "union8", "union9", - "union_date", "union_lateralview", "union_ppr", "union_remove_6", "union_script", - "varchar_2", - "varchar_join1", "varchar_union1", "view", "view_cast", diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala index de592f8d937dd..7ba5790c2979d 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala @@ -534,31 +534,6 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte | rows between 2 preceding and 2 following); """.stripMargin, reset = false) - // collect_set() output array in an arbitrary order, hence causes different result - // when running this test suite under Java 7 and 8. - // We change the original sql query a little bit for making the test suite passed - // under different JDK - /* Disabled because: - - Spark uses a different default stddev. 
- - Tiny numerical differences in stddev results. - createQueryTest("windowing.q -- 20. testSTATs", - """ - |select p_mfgr,p_name, p_size, sdev, sdev_pop, uniq_data, var, cor, covarp - |from ( - |select p_mfgr,p_name, p_size, - |stddev(p_retailprice) over w1 as sdev, - |stddev_pop(p_retailprice) over w1 as sdev_pop, - |collect_set(p_size) over w1 as uniq_size, - |variance(p_retailprice) over w1 as var, - |corr(p_size, p_retailprice) over w1 as cor, - |covar_pop(p_size, p_retailprice) over w1 as covarp - |from part - |window w1 as (distribute by p_mfgr sort by p_mfgr, p_name - | rows between 2 preceding and 2 following) - |) t lateral view explode(uniq_size) d as uniq_data - |order by p_mfgr,p_name, p_size, sdev, sdev_pop, uniq_data, var, cor, covarp - """.stripMargin, reset = false) - */ createQueryTest("windowing.q -- 21. testDISTs", """ |select p_mfgr,p_name, p_size, @@ -826,15 +801,17 @@ class HiveWindowFunctionQueryFileSuite "windowing_ntile", "windowing_udaf", "windowing_windowspec", - "windowing_rank" - ) + "windowing_rank", - override def whiteList: Seq[String] = Seq( - "windowing_udaf2", + // These tests DROP TABLE that don't exist (but do not specify IF EXISTS) "windowing_columnPruning", "windowing_adjust_rowcontainer_sz" ) + override def whiteList: Seq[String] = Seq( + "windowing_udaf2" + ) + // Only run those query tests in the realWhileList (do not try other ignored query files). override def testCases: Seq[(String, File)] = super.testCases.filter { case (name, _) => realWhiteList.contains(name) diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 177b6884fa13b..68825c654c7d4 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.0.0-SNAPSHOT + 2.0.3-SNAPSHOT ../../pom.xml @@ -60,7 +60,7 @@ org.apache.spark - spark-test-tags_${scala.binary.version} + spark-tags_${scala.binary.version} + + + + hive.in.test + true + Internal marker for test. + + diff --git a/sql/hive/src/test/resources/regression-test-SPARK-8489/Main.scala b/sql/hive/src/test/resources/regression-test-SPARK-8489/Main.scala index 10a017df831e0..4fbbbacb76081 100644 --- a/sql/hive/src/test/resources/regression-test-SPARK-8489/Main.scala +++ b/sql/hive/src/test/resources/regression-test-SPARK-8489/Main.scala @@ -15,7 +15,6 @@ * limitations under the License. */ -import org.apache.spark.SparkContext import org.apache.spark.sql.SparkSession /** @@ -33,15 +32,18 @@ object Main { def main(args: Array[String]) { // scalastyle:off println println("Running regression test for SPARK-8489.") - val sc = new SparkContext("local", "testing") - val sparkSession = SparkSession.withHiveSupport(sc) + val spark = SparkSession.builder + .master("local") + .appName("testing") + .enableHiveSupport() + .getOrCreate() // This line should not throw scala.reflect.internal.MissingRequirementError. // See SPARK-8470 for more detail. 
- val df = sparkSession.createDataFrame(Seq(MyCoolClass("1", "2", "3"))) + val df = spark.createDataFrame(Seq(MyCoolClass("1", "2", "3"))) df.collect() println("Regression test for SPARK-8489 success!") // scalastyle:on println - sc.stop() + spark.stop() } } diff --git a/sql/hive/src/test/resources/regression-test-SPARK-8489/test-2.10.jar b/sql/hive/src/test/resources/regression-test-SPARK-8489/test-2.10.jar index 26d410f33029b..3f28d37b93150 100644 Binary files a/sql/hive/src/test/resources/regression-test-SPARK-8489/test-2.10.jar and b/sql/hive/src/test/resources/regression-test-SPARK-8489/test-2.10.jar differ diff --git a/sql/hive/src/test/resources/regression-test-SPARK-8489/test-2.11.jar b/sql/hive/src/test/resources/regression-test-SPARK-8489/test-2.11.jar index f34784752f69f..5e093697e219a 100644 Binary files a/sql/hive/src/test/resources/regression-test-SPARK-8489/test-2.11.jar and b/sql/hive/src/test/resources/regression-test-SPARK-8489/test-2.11.jar differ diff --git a/sql/hive/src/test/resources/sqlgen/agg1.sql b/sql/hive/src/test/resources/sqlgen/agg1.sql new file mode 100644 index 0000000000000..05403a9dd8927 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/agg1.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT COUNT(value) FROM parquet_t1 GROUP BY key HAVING MAX(key) > 0 +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `count(value)` FROM (SELECT `gen_attr_0` FROM (SELECT count(`gen_attr_3`) AS `gen_attr_0`, max(`gen_attr_2`) AS `gen_attr_1` FROM (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_3` FROM `default`.`parquet_t1`) AS gen_subquery_0 GROUP BY `gen_attr_2` HAVING (`gen_attr_1` > CAST(0 AS BIGINT))) AS gen_subquery_1) AS gen_subquery_2 diff --git a/sql/hive/src/test/resources/sqlgen/agg2.sql b/sql/hive/src/test/resources/sqlgen/agg2.sql new file mode 100644 index 0000000000000..65d71714fe850 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/agg2.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT COUNT(value) FROM parquet_t1 GROUP BY key ORDER BY MAX(key) +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `count(value)` FROM (SELECT `gen_attr_0` FROM (SELECT count(`gen_attr_3`) AS `gen_attr_0`, max(`gen_attr_2`) AS `gen_attr_1` FROM (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_3` FROM `default`.`parquet_t1`) AS gen_subquery_0 GROUP BY `gen_attr_2` ORDER BY `gen_attr_1` ASC) AS gen_subquery_1) AS gen_subquery_2 diff --git a/sql/hive/src/test/resources/sqlgen/agg3.sql b/sql/hive/src/test/resources/sqlgen/agg3.sql new file mode 100644 index 0000000000000..14b19392cdce3 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/agg3.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. 
+SELECT COUNT(value) FROM parquet_t1 GROUP BY key ORDER BY key, MAX(key) +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `count(value)` FROM (SELECT `gen_attr_0` FROM (SELECT count(`gen_attr_4`) AS `gen_attr_0`, `gen_attr_3` AS `gen_attr_1`, max(`gen_attr_3`) AS `gen_attr_2` FROM (SELECT `key` AS `gen_attr_3`, `value` AS `gen_attr_4` FROM `default`.`parquet_t1`) AS gen_subquery_0 GROUP BY `gen_attr_3` ORDER BY `gen_attr_1` ASC, `gen_attr_2` ASC) AS gen_subquery_1) AS gen_subquery_2 diff --git a/sql/hive/src/test/resources/sqlgen/aggregate_functions_and_window.sql b/sql/hive/src/test/resources/sqlgen/aggregate_functions_and_window.sql new file mode 100644 index 0000000000000..e3e372d5eccdd --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/aggregate_functions_and_window.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT MAX(c) + COUNT(a) OVER () FROM parquet_t2 GROUP BY a, b +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `(max(c) + count(a) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING))` FROM (SELECT (`gen_attr_1` + `gen_attr_2`) AS `gen_attr_0` FROM (SELECT gen_subquery_1.`gen_attr_1`, gen_subquery_1.`gen_attr_3`, count(`gen_attr_3`) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS `gen_attr_2` FROM (SELECT max(`gen_attr_5`) AS `gen_attr_1`, `gen_attr_3` FROM (SELECT `a` AS `gen_attr_3`, `b` AS `gen_attr_4`, `c` AS `gen_attr_5`, `d` AS `gen_attr_6` FROM `default`.`parquet_t2`) AS gen_subquery_0 GROUP BY `gen_attr_3`, `gen_attr_4`) AS gen_subquery_1) AS gen_subquery_2) AS gen_subquery_3 diff --git a/sql/hive/src/test/resources/sqlgen/broadcast_join_subquery.sql b/sql/hive/src/test/resources/sqlgen/broadcast_join_subquery.sql new file mode 100644 index 0000000000000..ec881a216e0b0 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/broadcast_join_subquery.sql @@ -0,0 +1,8 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. 
+SELECT /*+ MAPJOIN(srcpart) */ subq.key1, z.value +FROM (SELECT x.key as key1, x.value as value1, y.key as key2, y.value as value2 + FROM src1 x JOIN src y ON (x.key = y.key)) subq +JOIN srcpart z ON (subq.key1 = z.key and z.ds='2008-04-08' and z.hr=11) +ORDER BY subq.key1, z.value +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `key1`, `gen_attr_1` AS `value` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `gen_attr_5` AS `gen_attr_0`, `gen_attr_7` AS `gen_attr_6`, `gen_attr_9` AS `gen_attr_8`, `gen_attr_11` AS `gen_attr_10` FROM (SELECT `key` AS `gen_attr_5`, `value` AS `gen_attr_7` FROM `default`.`src1`) AS gen_subquery_0 INNER JOIN (SELECT `key` AS `gen_attr_9`, `value` AS `gen_attr_11` FROM `default`.`src`) AS gen_subquery_1 ON (`gen_attr_5` = `gen_attr_9`)) AS subq INNER JOIN (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_1`, `ds` AS `gen_attr_3`, `hr` AS `gen_attr_4` FROM `default`.`srcpart`) AS gen_subquery_2 ON (((`gen_attr_0` = `gen_attr_2`) AND (`gen_attr_3` = '2008-04-08')) AND (CAST(`gen_attr_4` AS DOUBLE) = CAST(11 AS DOUBLE))) ORDER BY `gen_attr_0` ASC, `gen_attr_1` ASC) AS gen_subquery_3 diff --git a/sql/hive/src/test/resources/sqlgen/case.sql b/sql/hive/src/test/resources/sqlgen/case.sql new file mode 100644 index 0000000000000..99630e88cff66 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/case.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT CASE WHEN id % 2 > 0 THEN 0 WHEN id % 2 = 0 THEN 1 END FROM parquet_t0 +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `CASE WHEN ((id % CAST(2 AS BIGINT)) > CAST(0 AS BIGINT)) THEN 0 WHEN ((id % CAST(2 AS BIGINT)) = CAST(0 AS BIGINT)) THEN 1 END` FROM (SELECT CASE WHEN ((`gen_attr_1` % CAST(2 AS BIGINT)) > CAST(0 AS BIGINT)) THEN 0 WHEN ((`gen_attr_1` % CAST(2 AS BIGINT)) = CAST(0 AS BIGINT)) THEN 1 END AS `gen_attr_0` FROM (SELECT `id` AS `gen_attr_1` FROM `default`.`parquet_t0`) AS gen_subquery_0) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/case_with_else.sql b/sql/hive/src/test/resources/sqlgen/case_with_else.sql new file mode 100644 index 0000000000000..aed8f08804807 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/case_with_else.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT CASE WHEN id % 2 > 0 THEN 0 ELSE 1 END FROM parquet_t0 +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `CASE WHEN ((id % CAST(2 AS BIGINT)) > CAST(0 AS BIGINT)) THEN 0 ELSE 1 END` FROM (SELECT CASE WHEN ((`gen_attr_1` % CAST(2 AS BIGINT)) > CAST(0 AS BIGINT)) THEN 0 ELSE 1 END AS `gen_attr_0` FROM (SELECT `id` AS `gen_attr_1` FROM `default`.`parquet_t0`) AS gen_subquery_0) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/case_with_key.sql b/sql/hive/src/test/resources/sqlgen/case_with_key.sql new file mode 100644 index 0000000000000..e991ebafdc90e --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/case_with_key.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. 
+SELECT CASE id WHEN 0 THEN 'foo' WHEN 1 THEN 'bar' END FROM parquet_t0 +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `CASE WHEN (id = CAST(0 AS BIGINT)) THEN foo WHEN (id = CAST(1 AS BIGINT)) THEN bar END` FROM (SELECT CASE WHEN (`gen_attr_1` = CAST(0 AS BIGINT)) THEN 'foo' WHEN (`gen_attr_1` = CAST(1 AS BIGINT)) THEN 'bar' END AS `gen_attr_0` FROM (SELECT `id` AS `gen_attr_1` FROM `default`.`parquet_t0`) AS gen_subquery_0) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/case_with_key_and_else.sql b/sql/hive/src/test/resources/sqlgen/case_with_key_and_else.sql new file mode 100644 index 0000000000000..492777e376ecc --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/case_with_key_and_else.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT CASE id WHEN 0 THEN 'foo' WHEN 1 THEN 'bar' ELSE 'baz' END FROM parquet_t0 +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `CASE WHEN (id = CAST(0 AS BIGINT)) THEN foo WHEN (id = CAST(1 AS BIGINT)) THEN bar ELSE baz END` FROM (SELECT CASE WHEN (`gen_attr_1` = CAST(0 AS BIGINT)) THEN 'foo' WHEN (`gen_attr_1` = CAST(1 AS BIGINT)) THEN 'bar' ELSE 'baz' END AS `gen_attr_0` FROM (SELECT `id` AS `gen_attr_1` FROM `default`.`parquet_t0`) AS gen_subquery_0) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/cluster_by.sql b/sql/hive/src/test/resources/sqlgen/cluster_by.sql new file mode 100644 index 0000000000000..3154791c3c5fd --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/cluster_by.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT id FROM parquet_t0 CLUSTER BY id +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `id` FROM (SELECT `gen_attr_0` FROM (SELECT `id` AS `gen_attr_0` FROM `default`.`parquet_t0`) AS gen_subquery_0 CLUSTER BY `gen_attr_0`) AS parquet_t0 diff --git a/sql/hive/src/test/resources/sqlgen/data_source_json_parquet_t0.sql b/sql/hive/src/test/resources/sqlgen/data_source_json_parquet_t0.sql new file mode 100644 index 0000000000000..e41b645937d37 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/data_source_json_parquet_t0.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT id FROM json_parquet_t0 +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `id` FROM (SELECT `gen_attr_0` FROM (SELECT `id` AS `gen_attr_0` FROM `default`.`json_parquet_t0`) AS gen_subquery_0) AS json_parquet_t0 diff --git a/sql/hive/src/test/resources/sqlgen/data_source_orc_parquet_t0.sql b/sql/hive/src/test/resources/sqlgen/data_source_orc_parquet_t0.sql new file mode 100644 index 0000000000000..f5ceccde8c65b --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/data_source_orc_parquet_t0.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. 
+SELECT id FROM orc_parquet_t0 +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `id` FROM (SELECT `gen_attr_0` FROM (SELECT `id` AS `gen_attr_0` FROM `default`.`orc_parquet_t0`) AS gen_subquery_0) AS orc_parquet_t0 diff --git a/sql/hive/src/test/resources/sqlgen/data_source_parquet_parquet_t0.sql b/sql/hive/src/test/resources/sqlgen/data_source_parquet_parquet_t0.sql new file mode 100644 index 0000000000000..2bccefe55e417 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/data_source_parquet_parquet_t0.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT id FROM parquet_parquet_t0 +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `id` FROM (SELECT `gen_attr_0` FROM (SELECT `id` AS `gen_attr_0` FROM `default`.`parquet_parquet_t0`) AS gen_subquery_0) AS parquet_parquet_t0 diff --git a/sql/hive/src/test/resources/sqlgen/distinct_aggregation.sql b/sql/hive/src/test/resources/sqlgen/distinct_aggregation.sql new file mode 100644 index 0000000000000..bced711caedf4 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/distinct_aggregation.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT COUNT(DISTINCT id) FROM parquet_t0 +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `count(DISTINCT id)` FROM (SELECT count(DISTINCT `gen_attr_1`) AS `gen_attr_0` FROM (SELECT `id` AS `gen_attr_1` FROM `default`.`parquet_t0`) AS gen_subquery_0) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/distribute_by.sql b/sql/hive/src/test/resources/sqlgen/distribute_by.sql new file mode 100644 index 0000000000000..72863dcaf5c9c --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/distribute_by.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT id FROM parquet_t0 DISTRIBUTE BY id +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `id` FROM (SELECT `gen_attr_0` FROM (SELECT `id` AS `gen_attr_0` FROM `default`.`parquet_t0`) AS gen_subquery_0 DISTRIBUTE BY `gen_attr_0`) AS parquet_t0 diff --git a/sql/hive/src/test/resources/sqlgen/distribute_by_with_sort_by.sql b/sql/hive/src/test/resources/sqlgen/distribute_by_with_sort_by.sql new file mode 100644 index 0000000000000..96b9b2dae87aa --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/distribute_by_with_sort_by.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT id FROM parquet_t0 DISTRIBUTE BY id SORT BY id +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `id` FROM (SELECT `gen_attr_0` FROM (SELECT `id` AS `gen_attr_0` FROM `default`.`parquet_t0`) AS gen_subquery_0 CLUSTER BY `gen_attr_0`) AS parquet_t0 diff --git a/sql/hive/src/test/resources/sqlgen/except.sql b/sql/hive/src/test/resources/sqlgen/except.sql new file mode 100644 index 0000000000000..7a7d27fcd6336 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/except.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. 
+SELECT * FROM t0 EXCEPT SELECT * FROM t0 +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `id` FROM ((SELECT `gen_attr_0` FROM (SELECT `id` AS `gen_attr_0` FROM `default`.`t0`) AS gen_subquery_0 ) EXCEPT ( SELECT `gen_attr_1` FROM (SELECT `id` AS `gen_attr_1` FROM `default`.`t0`) AS gen_subquery_1)) AS t0 diff --git a/sql/hive/src/test/resources/sqlgen/filter_after_subquery.sql b/sql/hive/src/test/resources/sqlgen/filter_after_subquery.sql new file mode 100644 index 0000000000000..9cd6514d771ff --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/filter_after_subquery.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT a FROM (SELECT key + 1 AS a FROM parquet_t1) t WHERE a > 5 +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `a` FROM (SELECT `gen_attr_0` FROM (SELECT (`gen_attr_1` + CAST(1 AS BIGINT)) AS `gen_attr_0` FROM (SELECT `key` AS `gen_attr_1`, `value` AS `gen_attr_2` FROM `default`.`parquet_t1`) AS gen_subquery_0) AS t WHERE (`gen_attr_0` > CAST(5 AS BIGINT))) AS t diff --git a/sql/hive/src/test/resources/sqlgen/generate_with_other_1.sql b/sql/hive/src/test/resources/sqlgen/generate_with_other_1.sql new file mode 100644 index 0000000000000..805197a4ea11b --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/generate_with_other_1.sql @@ -0,0 +1,8 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT EXPLODE(arr) AS val, id +FROM parquet_t3 +WHERE id > 2 +ORDER BY val, id +LIMIT 5 +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `val`, `gen_attr_1` AS `id` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT gen_subquery_0.`gen_attr_2`, gen_subquery_0.`gen_attr_3`, gen_subquery_0.`gen_attr_4`, gen_subquery_0.`gen_attr_1` FROM (SELECT `arr` AS `gen_attr_2`, `arr2` AS `gen_attr_3`, `json` AS `gen_attr_4`, `id` AS `gen_attr_1` FROM `default`.`parquet_t3`) AS gen_subquery_0 WHERE (`gen_attr_1` > CAST(2 AS BIGINT))) AS gen_subquery_1 LATERAL VIEW explode(`gen_attr_2`) gen_subquery_2 AS `gen_attr_0` ORDER BY `gen_attr_0` ASC, `gen_attr_1` ASC LIMIT 5) AS parquet_t3 diff --git a/sql/hive/src/test/resources/sqlgen/generate_with_other_2.sql b/sql/hive/src/test/resources/sqlgen/generate_with_other_2.sql new file mode 100644 index 0000000000000..ef9a596197b8b --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/generate_with_other_2.sql @@ -0,0 +1,10 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. 
+SELECT val, id +FROM parquet_t3 +LATERAL VIEW EXPLODE(arr2) exp1 AS nested_array +LATERAL VIEW EXPLODE(nested_array) exp1 AS val +WHERE val > 2 +ORDER BY val, id +LIMIT 5 +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `val`, `gen_attr_1` AS `id` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `arr` AS `gen_attr_4`, `arr2` AS `gen_attr_3`, `json` AS `gen_attr_5`, `id` AS `gen_attr_1` FROM `default`.`parquet_t3`) AS gen_subquery_0 LATERAL VIEW explode(`gen_attr_3`) gen_subquery_2 AS `gen_attr_2` LATERAL VIEW explode(`gen_attr_2`) gen_subquery_3 AS `gen_attr_0` WHERE (`gen_attr_0` > CAST(2 AS BIGINT)) ORDER BY `gen_attr_0` ASC, `gen_attr_1` ASC LIMIT 5) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/generator_in_lateral_view_1.sql b/sql/hive/src/test/resources/sqlgen/generator_in_lateral_view_1.sql new file mode 100644 index 0000000000000..2f6596ef422b0 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/generator_in_lateral_view_1.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT val, id FROM parquet_t3 LATERAL VIEW EXPLODE(arr) exp AS val +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `val`, `gen_attr_1` AS `id` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `arr` AS `gen_attr_2`, `arr2` AS `gen_attr_3`, `json` AS `gen_attr_4`, `id` AS `gen_attr_1` FROM `default`.`parquet_t3`) AS gen_subquery_0 LATERAL VIEW explode(`gen_attr_2`) gen_subquery_2 AS `gen_attr_0`) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/generator_in_lateral_view_2.sql b/sql/hive/src/test/resources/sqlgen/generator_in_lateral_view_2.sql new file mode 100644 index 0000000000000..239980dd80bda --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/generator_in_lateral_view_2.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT val, id FROM parquet_t3 LATERAL VIEW OUTER EXPLODE(arr) exp AS val +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `val`, `gen_attr_1` AS `id` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `arr` AS `gen_attr_2`, `arr2` AS `gen_attr_3`, `json` AS `gen_attr_4`, `id` AS `gen_attr_1` FROM `default`.`parquet_t3`) AS gen_subquery_0 LATERAL VIEW OUTER explode(`gen_attr_2`) gen_subquery_2 AS `gen_attr_0`) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/generator_non_referenced_table_1.sql b/sql/hive/src/test/resources/sqlgen/generator_non_referenced_table_1.sql new file mode 100644 index 0000000000000..7fe0298c8e171 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/generator_non_referenced_table_1.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. 
+SELECT EXPLODE(ARRAY(1,2,3)) FROM t0 +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `col` FROM (SELECT `gen_attr_0` FROM (SELECT `id` AS `gen_attr_1` FROM `default`.`t0`) AS gen_subquery_0 LATERAL VIEW explode(array(1, 2, 3)) gen_subquery_2 AS `gen_attr_0`) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/generator_non_referenced_table_2.sql b/sql/hive/src/test/resources/sqlgen/generator_non_referenced_table_2.sql new file mode 100644 index 0000000000000..8db834acc73a1 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/generator_non_referenced_table_2.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT EXPLODE(ARRAY(1,2,3)) AS val FROM t0 +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `val` FROM (SELECT `gen_attr_0` FROM (SELECT `id` AS `gen_attr_1` FROM `default`.`t0`) AS gen_subquery_0 LATERAL VIEW explode(array(1, 2, 3)) gen_subquery_2 AS `gen_attr_0`) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/generator_non_udtf_1.sql b/sql/hive/src/test/resources/sqlgen/generator_non_udtf_1.sql new file mode 100644 index 0000000000000..fef65e006867c --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/generator_non_udtf_1.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT EXPLODE(arr), id FROM parquet_t3 +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `col`, `gen_attr_1` AS `id` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `arr` AS `gen_attr_2`, `arr2` AS `gen_attr_3`, `json` AS `gen_attr_4`, `id` AS `gen_attr_1` FROM `default`.`parquet_t3`) AS gen_subquery_0 LATERAL VIEW explode(`gen_attr_2`) gen_subquery_1 AS `gen_attr_0`) AS parquet_t3 diff --git a/sql/hive/src/test/resources/sqlgen/generator_non_udtf_2.sql b/sql/hive/src/test/resources/sqlgen/generator_non_udtf_2.sql new file mode 100644 index 0000000000000..e0e310888f11f --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/generator_non_udtf_2.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT EXPLODE(arr) AS val, id as a FROM parquet_t3 +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `val`, `gen_attr_1` AS `a` FROM (SELECT `gen_attr_0`, `gen_attr_2` AS `gen_attr_1` FROM (SELECT `arr` AS `gen_attr_3`, `arr2` AS `gen_attr_4`, `json` AS `gen_attr_5`, `id` AS `gen_attr_2` FROM `default`.`parquet_t3`) AS gen_subquery_0 LATERAL VIEW explode(`gen_attr_3`) gen_subquery_2 AS `gen_attr_0`) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/generator_referenced_table_1.sql b/sql/hive/src/test/resources/sqlgen/generator_referenced_table_1.sql new file mode 100644 index 0000000000000..ea5db850bef8a --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/generator_referenced_table_1.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. 
+SELECT EXPLODE(arr) FROM parquet_t3 +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `col` FROM (SELECT `gen_attr_0` FROM (SELECT `arr` AS `gen_attr_1`, `arr2` AS `gen_attr_2`, `json` AS `gen_attr_3`, `id` AS `gen_attr_4` FROM `default`.`parquet_t3`) AS gen_subquery_0 LATERAL VIEW explode(`gen_attr_1`) gen_subquery_2 AS `gen_attr_0`) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/generator_referenced_table_2.sql b/sql/hive/src/test/resources/sqlgen/generator_referenced_table_2.sql new file mode 100644 index 0000000000000..8f75b825476e0 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/generator_referenced_table_2.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT EXPLODE(arr) AS val FROM parquet_t3 +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `val` FROM (SELECT `gen_attr_0` FROM (SELECT `arr` AS `gen_attr_1`, `arr2` AS `gen_attr_2`, `json` AS `gen_attr_3`, `id` AS `gen_attr_4` FROM `default`.`parquet_t3`) AS gen_subquery_0 LATERAL VIEW explode(`gen_attr_1`) gen_subquery_2 AS `gen_attr_0`) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/generator_with_ambiguous_names_1.sql b/sql/hive/src/test/resources/sqlgen/generator_with_ambiguous_names_1.sql new file mode 100644 index 0000000000000..984cce8a2ca83 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/generator_with_ambiguous_names_1.sql @@ -0,0 +1,6 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT exp.id, parquet_t3.id +FROM parquet_t3 +LATERAL VIEW EXPLODE(arr) exp AS id +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `id`, `gen_attr_1` AS `id` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `arr` AS `gen_attr_2`, `arr2` AS `gen_attr_3`, `json` AS `gen_attr_4`, `id` AS `gen_attr_1` FROM `default`.`parquet_t3`) AS gen_subquery_0 LATERAL VIEW explode(`gen_attr_2`) gen_subquery_2 AS `gen_attr_0`) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/generator_with_ambiguous_names_2.sql b/sql/hive/src/test/resources/sqlgen/generator_with_ambiguous_names_2.sql new file mode 100644 index 0000000000000..5c55b164c7feb --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/generator_with_ambiguous_names_2.sql @@ -0,0 +1,6 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT exp.id, parquet_t3.id +FROM parquet_t3 +LATERAL VIEW OUTER EXPLODE(arr) exp AS id +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `id`, `gen_attr_1` AS `id` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `arr` AS `gen_attr_2`, `arr2` AS `gen_attr_3`, `json` AS `gen_attr_4`, `id` AS `gen_attr_1` FROM `default`.`parquet_t3`) AS gen_subquery_0 LATERAL VIEW OUTER explode(`gen_attr_2`) gen_subquery_2 AS `gen_attr_0`) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/generator_without_from_1.sql b/sql/hive/src/test/resources/sqlgen/generator_without_from_1.sql new file mode 100644 index 0000000000000..ee22fe8728995 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/generator_without_from_1.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. 
+SELECT EXPLODE(ARRAY(1,2,3)) +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `col` FROM (SELECT `gen_attr_0` FROM (SELECT 1) gen_subquery_1 LATERAL VIEW explode(array(1, 2, 3)) gen_subquery_2 AS `gen_attr_0`) AS gen_subquery_0 diff --git a/sql/hive/src/test/resources/sqlgen/generator_without_from_2.sql b/sql/hive/src/test/resources/sqlgen/generator_without_from_2.sql new file mode 100644 index 0000000000000..0acded74b3eee --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/generator_without_from_2.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT EXPLODE(ARRAY(1,2,3)) AS val +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `val` FROM (SELECT `gen_attr_0` FROM (SELECT 1) gen_subquery_1 LATERAL VIEW explode(array(1, 2, 3)) gen_subquery_2 AS `gen_attr_0`) AS gen_subquery_0 diff --git a/sql/hive/src/test/resources/sqlgen/grouping_sets_1.sql b/sql/hive/src/test/resources/sqlgen/grouping_sets_1.sql new file mode 100644 index 0000000000000..db2b2cc732889 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/grouping_sets_1.sql @@ -0,0 +1,6 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT count(*) AS cnt, key % 5 AS k1, key - 5 AS k2, grouping_id() AS k3 +FROM (SELECT key, key % 2, key - 5 FROM parquet_t1) t GROUP BY key % 5, key - 5 +GROUPING SETS (key % 5, key - 5) +-------------------------------------------------------------------------------- +SELECT `gen_attr_3` AS `cnt`, `gen_attr_4` AS `k1`, `gen_attr_5` AS `k2`, `gen_attr_6` AS `k3` FROM (SELECT count(1) AS `gen_attr_3`, (`gen_attr_7` % CAST(5 AS BIGINT)) AS `gen_attr_4`, (`gen_attr_7` - CAST(5 AS BIGINT)) AS `gen_attr_5`, grouping_id() AS `gen_attr_6` FROM (SELECT `gen_attr_7`, (`gen_attr_7` % CAST(2 AS BIGINT)) AS `gen_attr_8`, (`gen_attr_7` - CAST(5 AS BIGINT)) AS `gen_attr_9` FROM (SELECT `key` AS `gen_attr_7`, `value` AS `gen_attr_12` FROM `default`.`parquet_t1`) AS gen_subquery_0) AS t GROUP BY (`gen_attr_7` % CAST(5 AS BIGINT)), (`gen_attr_7` - CAST(5 AS BIGINT)) GROUPING SETS(((`gen_attr_7` % CAST(5 AS BIGINT))), ((`gen_attr_7` - CAST(5 AS BIGINT))))) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/grouping_sets_2_1.sql b/sql/hive/src/test/resources/sqlgen/grouping_sets_2_1.sql new file mode 100644 index 0000000000000..b2c426c660d80 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/grouping_sets_2_1.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. 
+SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b GROUPING SETS (a, b) ORDER BY a, b +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `a`, `gen_attr_1` AS `b`, `gen_attr_3` AS `sum(c)` FROM (SELECT `gen_attr_5` AS `gen_attr_0`, `gen_attr_6` AS `gen_attr_1`, sum(`gen_attr_4`) AS `gen_attr_3` FROM (SELECT `a` AS `gen_attr_5`, `b` AS `gen_attr_6`, `c` AS `gen_attr_4`, `d` AS `gen_attr_7` FROM `default`.`parquet_t2`) AS gen_subquery_0 GROUP BY `gen_attr_5`, `gen_attr_6` GROUPING SETS((`gen_attr_5`), (`gen_attr_6`)) ORDER BY `gen_attr_0` ASC, `gen_attr_1` ASC) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/grouping_sets_2_2.sql b/sql/hive/src/test/resources/sqlgen/grouping_sets_2_2.sql new file mode 100644 index 0000000000000..96ee8e85951e8 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/grouping_sets_2_2.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b GROUPING SETS (a) ORDER BY a, b +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `a`, `gen_attr_1` AS `b`, `gen_attr_3` AS `sum(c)` FROM (SELECT `gen_attr_5` AS `gen_attr_0`, `gen_attr_6` AS `gen_attr_1`, sum(`gen_attr_4`) AS `gen_attr_3` FROM (SELECT `a` AS `gen_attr_5`, `b` AS `gen_attr_6`, `c` AS `gen_attr_4`, `d` AS `gen_attr_7` FROM `default`.`parquet_t2`) AS gen_subquery_0 GROUP BY `gen_attr_5`, `gen_attr_6` GROUPING SETS((`gen_attr_5`)) ORDER BY `gen_attr_0` ASC, `gen_attr_1` ASC) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/grouping_sets_2_3.sql b/sql/hive/src/test/resources/sqlgen/grouping_sets_2_3.sql new file mode 100644 index 0000000000000..9b8b230c879c2 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/grouping_sets_2_3.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b GROUPING SETS (b) ORDER BY a, b +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `a`, `gen_attr_1` AS `b`, `gen_attr_3` AS `sum(c)` FROM (SELECT `gen_attr_5` AS `gen_attr_0`, `gen_attr_6` AS `gen_attr_1`, sum(`gen_attr_4`) AS `gen_attr_3` FROM (SELECT `a` AS `gen_attr_5`, `b` AS `gen_attr_6`, `c` AS `gen_attr_4`, `d` AS `gen_attr_7` FROM `default`.`parquet_t2`) AS gen_subquery_0 GROUP BY `gen_attr_5`, `gen_attr_6` GROUPING SETS((`gen_attr_6`)) ORDER BY `gen_attr_0` ASC, `gen_attr_1` ASC) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/grouping_sets_2_4.sql b/sql/hive/src/test/resources/sqlgen/grouping_sets_2_4.sql new file mode 100644 index 0000000000000..c35db74a5c5b5 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/grouping_sets_2_4.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. 
+SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b GROUPING SETS (()) ORDER BY a, b +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `a`, `gen_attr_1` AS `b`, `gen_attr_3` AS `sum(c)` FROM (SELECT `gen_attr_5` AS `gen_attr_0`, `gen_attr_6` AS `gen_attr_1`, sum(`gen_attr_4`) AS `gen_attr_3` FROM (SELECT `a` AS `gen_attr_5`, `b` AS `gen_attr_6`, `c` AS `gen_attr_4`, `d` AS `gen_attr_7` FROM `default`.`parquet_t2`) AS gen_subquery_0 GROUP BY `gen_attr_5`, `gen_attr_6` GROUPING SETS(()) ORDER BY `gen_attr_0` ASC, `gen_attr_1` ASC) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/grouping_sets_2_5.sql b/sql/hive/src/test/resources/sqlgen/grouping_sets_2_5.sql new file mode 100644 index 0000000000000..e47f6d5dcf465 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/grouping_sets_2_5.sql @@ -0,0 +1,5 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b +GROUPING SETS ((), (a), (a, b)) ORDER BY a, b +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `a`, `gen_attr_1` AS `b`, `gen_attr_3` AS `sum(c)` FROM (SELECT `gen_attr_5` AS `gen_attr_0`, `gen_attr_6` AS `gen_attr_1`, sum(`gen_attr_4`) AS `gen_attr_3` FROM (SELECT `a` AS `gen_attr_5`, `b` AS `gen_attr_6`, `c` AS `gen_attr_4`, `d` AS `gen_attr_7` FROM `default`.`parquet_t2`) AS gen_subquery_0 GROUP BY `gen_attr_5`, `gen_attr_6` GROUPING SETS((), (`gen_attr_5`), (`gen_attr_5`, `gen_attr_6`)) ORDER BY `gen_attr_0` ASC, `gen_attr_1` ASC) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/in.sql b/sql/hive/src/test/resources/sqlgen/in.sql new file mode 100644 index 0000000000000..7cff62b1af7df --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/in.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT id FROM parquet_t0 WHERE id IN (1, 2, 3) +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `id` FROM (SELECT `gen_attr_0` FROM (SELECT `id` AS `gen_attr_0` FROM `default`.`parquet_t0`) AS gen_subquery_0 WHERE (CAST(`gen_attr_0` AS BIGINT) IN (CAST(1 AS BIGINT), CAST(2 AS BIGINT), CAST(3 AS BIGINT)))) AS parquet_t0 diff --git a/sql/hive/src/test/resources/sqlgen/inline_tables.sql b/sql/hive/src/test/resources/sqlgen/inline_tables.sql new file mode 100644 index 0000000000000..18803a3ee59b9 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/inline_tables.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +select * from values ("one", 1), ("two", 2), ("three", null) as data(a, b) where b > 1 +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `a`, `gen_attr_1` AS `b` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (VALUES ('one', 1), ('two', 2), ('three', CAST(NULL AS INT)) AS gen_subquery_0(gen_attr_0, gen_attr_1)) AS data WHERE (`gen_attr_1` > 1)) AS data diff --git a/sql/hive/src/test/resources/sqlgen/intersect.sql b/sql/hive/src/test/resources/sqlgen/intersect.sql new file mode 100644 index 0000000000000..4143a6208d4b5 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/intersect.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. 
+SELECT * FROM t0 INTERSECT SELECT * FROM t0 +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `id` FROM ((SELECT `gen_attr_0` FROM (SELECT `id` AS `gen_attr_0` FROM `default`.`t0`) AS gen_subquery_0 ) INTERSECT ( SELECT `gen_attr_1` FROM (SELECT `id` AS `gen_attr_1` FROM `default`.`t0`) AS gen_subquery_1)) AS t0 diff --git a/sql/hive/src/test/resources/sqlgen/interval_arithmetic.sql b/sql/hive/src/test/resources/sqlgen/interval_arithmetic.sql new file mode 100644 index 0000000000000..31d00348769f5 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/interval_arithmetic.sql @@ -0,0 +1,8 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +select ts + interval 1 day, ts + interval 2 days, + ts - interval 1 day, ts - interval 2 days, + ts + interval '1' day, ts + interval '2' days, + ts - interval '1' day, ts - interval '2' days +from dates +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `CAST(ts + interval 1 days AS TIMESTAMP)`, `gen_attr_2` AS `CAST(ts + interval 2 days AS TIMESTAMP)`, `gen_attr_3` AS `CAST(ts - interval 1 days AS TIMESTAMP)`, `gen_attr_4` AS `CAST(ts - interval 2 days AS TIMESTAMP)`, `gen_attr_5` AS `CAST(ts + interval 1 days AS TIMESTAMP)`, `gen_attr_6` AS `CAST(ts + interval 2 days AS TIMESTAMP)`, `gen_attr_7` AS `CAST(ts - interval 1 days AS TIMESTAMP)`, `gen_attr_8` AS `CAST(ts - interval 2 days AS TIMESTAMP)` FROM (SELECT CAST(`gen_attr_1` + interval 1 days AS TIMESTAMP) AS `gen_attr_0`, CAST(`gen_attr_1` + interval 2 days AS TIMESTAMP) AS `gen_attr_2`, CAST(`gen_attr_1` - interval 1 days AS TIMESTAMP) AS `gen_attr_3`, CAST(`gen_attr_1` - interval 2 days AS TIMESTAMP) AS `gen_attr_4`, CAST(`gen_attr_1` + interval 1 days AS TIMESTAMP) AS `gen_attr_5`, CAST(`gen_attr_1` + interval 2 days AS TIMESTAMP) AS `gen_attr_6`, CAST(`gen_attr_1` - interval 1 days AS TIMESTAMP) AS `gen_attr_7`, CAST(`gen_attr_1` - interval 2 days AS TIMESTAMP) AS `gen_attr_8` FROM (SELECT `ts` AS `gen_attr_1` FROM `default`.`dates`) AS gen_subquery_0) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/join_2_tables.sql b/sql/hive/src/test/resources/sqlgen/join_2_tables.sql new file mode 100644 index 0000000000000..9dd200c3c0cfa --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/join_2_tables.sql @@ -0,0 +1,7 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. 
+SELECT COUNT(a.value), b.KEY, a.KEY +FROM parquet_t1 a, parquet_t1 b +GROUP BY a.KEY, b.KEY +HAVING MAX(a.KEY) > 0 +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `count(value)`, `gen_attr_1` AS `KEY`, `gen_attr_2` AS `KEY` FROM (SELECT `gen_attr_0`, `gen_attr_1`, `gen_attr_2` FROM (SELECT count(`gen_attr_4`) AS `gen_attr_0`, `gen_attr_1`, `gen_attr_2`, max(`gen_attr_2`) AS `gen_attr_3` FROM (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_4` FROM `default`.`parquet_t1`) AS gen_subquery_0 INNER JOIN (SELECT `key` AS `gen_attr_1`, `value` AS `gen_attr_5` FROM `default`.`parquet_t1`) AS gen_subquery_1 GROUP BY `gen_attr_2`, `gen_attr_1` HAVING (`gen_attr_3` > CAST(0 AS BIGINT))) AS gen_subquery_2) AS gen_subquery_3 diff --git a/sql/hive/src/test/resources/sqlgen/json_tuple_generator_1.sql b/sql/hive/src/test/resources/sqlgen/json_tuple_generator_1.sql new file mode 100644 index 0000000000000..11e45a48f1b89 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/json_tuple_generator_1.sql @@ -0,0 +1,6 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT c0, c1, c2 +FROM parquet_t3 +LATERAL VIEW JSON_TUPLE(json, 'f1', 'f2', 'f3') jt +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `c0`, `gen_attr_1` AS `c1`, `gen_attr_2` AS `c2` FROM (SELECT `gen_attr_0`, `gen_attr_1`, `gen_attr_2` FROM (SELECT `arr` AS `gen_attr_4`, `arr2` AS `gen_attr_5`, `json` AS `gen_attr_3`, `id` AS `gen_attr_6` FROM `default`.`parquet_t3`) AS gen_subquery_0 LATERAL VIEW json_tuple(`gen_attr_3`, 'f1', 'f2', 'f3') gen_subquery_1 AS `gen_attr_0`, `gen_attr_1`, `gen_attr_2`) AS jt diff --git a/sql/hive/src/test/resources/sqlgen/json_tuple_generator_2.sql b/sql/hive/src/test/resources/sqlgen/json_tuple_generator_2.sql new file mode 100644 index 0000000000000..d86b39df57442 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/json_tuple_generator_2.sql @@ -0,0 +1,6 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT a, b, c +FROM parquet_t3 +LATERAL VIEW JSON_TUPLE(json, 'f1', 'f2', 'f3') jt AS a, b, c +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `a`, `gen_attr_1` AS `b`, `gen_attr_2` AS `c` FROM (SELECT `gen_attr_0`, `gen_attr_1`, `gen_attr_2` FROM (SELECT `arr` AS `gen_attr_4`, `arr2` AS `gen_attr_5`, `json` AS `gen_attr_3`, `id` AS `gen_attr_6` FROM `default`.`parquet_t3`) AS gen_subquery_0 LATERAL VIEW json_tuple(`gen_attr_3`, 'f1', 'f2', 'f3') gen_subquery_1 AS `gen_attr_0`, `gen_attr_1`, `gen_attr_2`) AS jt diff --git a/sql/hive/src/test/resources/sqlgen/multi_distinct.sql b/sql/hive/src/test/resources/sqlgen/multi_distinct.sql new file mode 100644 index 0000000000000..3ca526fcc4415 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/multi_distinct.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. 
+SELECT a, COUNT(DISTINCT b), COUNT(DISTINCT c), SUM(d) FROM parquet_t2 GROUP BY a +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `a`, `gen_attr_1` AS `count(DISTINCT b)`, `gen_attr_3` AS `count(DISTINCT c)`, `gen_attr_5` AS `sum(d)` FROM (SELECT `gen_attr_0`, count(DISTINCT `gen_attr_2`) AS `gen_attr_1`, count(DISTINCT `gen_attr_4`) AS `gen_attr_3`, sum(`gen_attr_6`) AS `gen_attr_5` FROM (SELECT `a` AS `gen_attr_0`, `b` AS `gen_attr_2`, `c` AS `gen_attr_4`, `d` AS `gen_attr_6` FROM `default`.`parquet_t2`) AS gen_subquery_0 GROUP BY `gen_attr_0`) AS parquet_t2 diff --git a/sql/hive/src/test/resources/sqlgen/nested_generator_in_lateral_view_1.sql b/sql/hive/src/test/resources/sqlgen/nested_generator_in_lateral_view_1.sql new file mode 100644 index 0000000000000..e681c2b6354c0 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/nested_generator_in_lateral_view_1.sql @@ -0,0 +1,7 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT val, id +FROM parquet_t3 +LATERAL VIEW EXPLODE(arr2) exp1 AS nested_array +LATERAL VIEW EXPLODE(nested_array) exp1 AS val +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `val`, `gen_attr_1` AS `id` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `arr` AS `gen_attr_4`, `arr2` AS `gen_attr_3`, `json` AS `gen_attr_5`, `id` AS `gen_attr_1` FROM `default`.`parquet_t3`) AS gen_subquery_0 LATERAL VIEW explode(`gen_attr_3`) gen_subquery_2 AS `gen_attr_2` LATERAL VIEW explode(`gen_attr_2`) gen_subquery_3 AS `gen_attr_0`) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/nested_generator_in_lateral_view_2.sql b/sql/hive/src/test/resources/sqlgen/nested_generator_in_lateral_view_2.sql new file mode 100644 index 0000000000000..e9d6522c91680 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/nested_generator_in_lateral_view_2.sql @@ -0,0 +1,7 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT val, id +FROM parquet_t3 +LATERAL VIEW EXPLODE(arr2) exp1 AS nested_array +LATERAL VIEW OUTER EXPLODE(nested_array) exp1 AS val +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `val`, `gen_attr_1` AS `id` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `arr` AS `gen_attr_4`, `arr2` AS `gen_attr_3`, `json` AS `gen_attr_5`, `id` AS `gen_attr_1` FROM `default`.`parquet_t3`) AS gen_subquery_0 LATERAL VIEW explode(`gen_attr_3`) gen_subquery_2 AS `gen_attr_2` LATERAL VIEW OUTER explode(`gen_attr_2`) gen_subquery_3 AS `gen_attr_0`) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/not_in.sql b/sql/hive/src/test/resources/sqlgen/not_in.sql new file mode 100644 index 0000000000000..797d22e8e9154 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/not_in.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. 
+SELECT id FROM t0 WHERE id NOT IN (1, 2, 3) +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `id` FROM (SELECT `gen_attr_0` FROM (SELECT `id` AS `gen_attr_0` FROM `default`.`t0`) AS gen_subquery_0 WHERE (NOT (CAST(`gen_attr_0` AS BIGINT) IN (CAST(1 AS BIGINT), CAST(2 AS BIGINT), CAST(3 AS BIGINT))))) AS t0 diff --git a/sql/hive/src/test/resources/sqlgen/not_like.sql b/sql/hive/src/test/resources/sqlgen/not_like.sql new file mode 100644 index 0000000000000..22485045e212e --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/not_like.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT id FROM t0 WHERE id + 5 NOT LIKE '1%' +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `id` FROM (SELECT `gen_attr_0` FROM (SELECT `id` AS `gen_attr_0` FROM `default`.`t0`) AS gen_subquery_0 WHERE (NOT CAST((`gen_attr_0` + CAST(5 AS BIGINT)) AS STRING) LIKE '1%')) AS t0 diff --git a/sql/hive/src/test/resources/sqlgen/predicate_subquery.sql b/sql/hive/src/test/resources/sqlgen/predicate_subquery.sql new file mode 100644 index 0000000000000..6e5bd9860008c --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/predicate_subquery.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +select * from t1 b where exists (select * from t1 a) +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `a` FROM (SELECT `gen_attr_0` FROM (SELECT `a` AS `gen_attr_0` FROM `default`.`t1`) AS gen_subquery_0 WHERE EXISTS(SELECT `gen_attr_1` AS `a` FROM ((SELECT `gen_attr_1` FROM (SELECT `a` AS `gen_attr_1` FROM `default`.`t1`) AS gen_subquery_2) AS gen_subquery_1) AS gen_subquery_1)) AS b diff --git a/sql/hive/src/test/resources/sqlgen/range.sql b/sql/hive/src/test/resources/sqlgen/range.sql new file mode 100644 index 0000000000000..53c72ea71e6ac --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/range.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +select * from range(100) +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `id` FROM (SELECT `gen_attr_0` FROM (SELECT id AS `gen_attr_0` FROM range(0, 100, 1)) AS gen_subquery_0) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/range_with_splits.sql b/sql/hive/src/test/resources/sqlgen/range_with_splits.sql new file mode 100644 index 0000000000000..83d637d54a302 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/range_with_splits.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +select * from range(1, 100, 20, 10) +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `id` FROM (SELECT `gen_attr_0` FROM (SELECT id AS `gen_attr_0` FROM range(1, 100, 20, 10)) AS gen_subquery_0) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/regular_expressions_and_window.sql b/sql/hive/src/test/resources/sqlgen/regular_expressions_and_window.sql new file mode 100644 index 0000000000000..37cd5568baa7f --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/regular_expressions_and_window.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. 
+SELECT MAX(key) OVER (PARTITION BY key % 3) + key FROM parquet_t1 +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `(max(key) OVER (PARTITION BY (key % CAST(3 AS BIGINT)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) + key)` FROM (SELECT (`gen_attr_1` + `gen_attr_2`) AS `gen_attr_0` FROM (SELECT gen_subquery_1.`gen_attr_2`, gen_subquery_1.`gen_attr_3`, max(`gen_attr_2`) OVER (PARTITION BY `gen_attr_3` ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS `gen_attr_1` FROM (SELECT `gen_attr_2`, (`gen_attr_2` % CAST(3 AS BIGINT)) AS `gen_attr_3` FROM (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_4` FROM `default`.`parquet_t1`) AS gen_subquery_0) AS gen_subquery_1) AS gen_subquery_2) AS gen_subquery_3 diff --git a/sql/hive/src/test/resources/sqlgen/rollup_cube_1_1.sql b/sql/hive/src/test/resources/sqlgen/rollup_cube_1_1.sql new file mode 100644 index 0000000000000..c54963ab5c550 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/rollup_cube_1_1.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT count(*) as cnt, key%5, grouping_id() FROM parquet_t1 GROUP BY key % 5 WITH ROLLUP +-------------------------------------------------------------------------------- +SELECT `gen_attr_2` AS `cnt`, `gen_attr_3` AS `(key % CAST(5 AS BIGINT))`, `gen_attr_4` AS `grouping_id()` FROM (SELECT count(1) AS `gen_attr_2`, (`gen_attr_5` % CAST(5 AS BIGINT)) AS `gen_attr_3`, grouping_id() AS `gen_attr_4` FROM (SELECT `key` AS `gen_attr_5`, `value` AS `gen_attr_6` FROM `default`.`parquet_t1`) AS gen_subquery_0 GROUP BY (`gen_attr_5` % CAST(5 AS BIGINT)) GROUPING SETS(((`gen_attr_5` % CAST(5 AS BIGINT))), ())) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/rollup_cube_1_2.sql b/sql/hive/src/test/resources/sqlgen/rollup_cube_1_2.sql new file mode 100644 index 0000000000000..6c869063c3bec --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/rollup_cube_1_2.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT count(*) as cnt, key%5, grouping_id() FROM parquet_t1 GROUP BY key % 5 WITH CUBE +-------------------------------------------------------------------------------- +SELECT `gen_attr_2` AS `cnt`, `gen_attr_3` AS `(key % CAST(5 AS BIGINT))`, `gen_attr_4` AS `grouping_id()` FROM (SELECT count(1) AS `gen_attr_2`, (`gen_attr_5` % CAST(5 AS BIGINT)) AS `gen_attr_3`, grouping_id() AS `gen_attr_4` FROM (SELECT `key` AS `gen_attr_5`, `value` AS `gen_attr_6` FROM `default`.`parquet_t1`) AS gen_subquery_0 GROUP BY (`gen_attr_5` % CAST(5 AS BIGINT)) GROUPING SETS(((`gen_attr_5` % CAST(5 AS BIGINT))), ())) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/rollup_cube_2_1.sql b/sql/hive/src/test/resources/sqlgen/rollup_cube_2_1.sql new file mode 100644 index 0000000000000..9628e38572940 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/rollup_cube_2_1.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. 
+SELECT key, value, count(value) FROM parquet_t1 GROUP BY key, value WITH ROLLUP +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `value`, `gen_attr_3` AS `count(value)` FROM (SELECT `gen_attr_5` AS `gen_attr_0`, `gen_attr_4` AS `gen_attr_1`, count(`gen_attr_4`) AS `gen_attr_3` FROM (SELECT `key` AS `gen_attr_5`, `value` AS `gen_attr_4` FROM `default`.`parquet_t1`) AS gen_subquery_0 GROUP BY `gen_attr_5`, `gen_attr_4` GROUPING SETS((`gen_attr_5`, `gen_attr_4`), (`gen_attr_5`), ())) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/rollup_cube_2_2.sql b/sql/hive/src/test/resources/sqlgen/rollup_cube_2_2.sql new file mode 100644 index 0000000000000..d6b61929df0ad --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/rollup_cube_2_2.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT key, value, count(value) FROM parquet_t1 GROUP BY key, value WITH CUBE +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `value`, `gen_attr_3` AS `count(value)` FROM (SELECT `gen_attr_5` AS `gen_attr_0`, `gen_attr_4` AS `gen_attr_1`, count(`gen_attr_4`) AS `gen_attr_3` FROM (SELECT `key` AS `gen_attr_5`, `value` AS `gen_attr_4` FROM `default`.`parquet_t1`) AS gen_subquery_0 GROUP BY `gen_attr_5`, `gen_attr_4` GROUPING SETS((`gen_attr_5`, `gen_attr_4`), (`gen_attr_5`), (`gen_attr_4`), ())) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/rollup_cube_3_1.sql b/sql/hive/src/test/resources/sqlgen/rollup_cube_3_1.sql new file mode 100644 index 0000000000000..d04b6578fc1ce --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/rollup_cube_3_1.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT key, count(value), grouping_id() FROM parquet_t1 GROUP BY key, value WITH ROLLUP +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `key`, `gen_attr_3` AS `count(value)`, `gen_attr_5` AS `grouping_id()` FROM (SELECT `gen_attr_6` AS `gen_attr_0`, count(`gen_attr_4`) AS `gen_attr_3`, grouping_id() AS `gen_attr_5` FROM (SELECT `key` AS `gen_attr_6`, `value` AS `gen_attr_4` FROM `default`.`parquet_t1`) AS gen_subquery_0 GROUP BY `gen_attr_6`, `gen_attr_4` GROUPING SETS((`gen_attr_6`, `gen_attr_4`), (`gen_attr_6`), ())) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/rollup_cube_3_2.sql b/sql/hive/src/test/resources/sqlgen/rollup_cube_3_2.sql new file mode 100644 index 0000000000000..80a5d93438f2a --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/rollup_cube_3_2.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. 
+SELECT key, count(value), grouping_id() FROM parquet_t1 GROUP BY key, value WITH CUBE +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `key`, `gen_attr_3` AS `count(value)`, `gen_attr_5` AS `grouping_id()` FROM (SELECT `gen_attr_6` AS `gen_attr_0`, count(`gen_attr_4`) AS `gen_attr_3`, grouping_id() AS `gen_attr_5` FROM (SELECT `key` AS `gen_attr_6`, `value` AS `gen_attr_4` FROM `default`.`parquet_t1`) AS gen_subquery_0 GROUP BY `gen_attr_6`, `gen_attr_4` GROUPING SETS((`gen_attr_6`, `gen_attr_4`), (`gen_attr_6`), (`gen_attr_4`), ())) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/rollup_cube_4_1.sql b/sql/hive/src/test/resources/sqlgen/rollup_cube_4_1.sql new file mode 100644 index 0000000000000..619a554875ff0 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/rollup_cube_4_1.sql @@ -0,0 +1,5 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT count(*) as cnt, key % 5 as k1, key - 5 as k2, grouping_id() FROM parquet_t1 +GROUP BY key % 5, key - 5 WITH ROLLUP +-------------------------------------------------------------------------------- +SELECT `gen_attr_3` AS `cnt`, `gen_attr_4` AS `k1`, `gen_attr_5` AS `k2`, `gen_attr_6` AS `grouping_id()` FROM (SELECT count(1) AS `gen_attr_3`, (`gen_attr_7` % CAST(5 AS BIGINT)) AS `gen_attr_4`, (`gen_attr_7` - CAST(5 AS BIGINT)) AS `gen_attr_5`, grouping_id() AS `gen_attr_6` FROM (SELECT `key` AS `gen_attr_7`, `value` AS `gen_attr_8` FROM `default`.`parquet_t1`) AS gen_subquery_0 GROUP BY (`gen_attr_7` % CAST(5 AS BIGINT)), (`gen_attr_7` - CAST(5 AS BIGINT)) GROUPING SETS(((`gen_attr_7` % CAST(5 AS BIGINT)), (`gen_attr_7` - CAST(5 AS BIGINT))), ((`gen_attr_7` % CAST(5 AS BIGINT))), ())) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/rollup_cube_4_2.sql b/sql/hive/src/test/resources/sqlgen/rollup_cube_4_2.sql new file mode 100644 index 0000000000000..8bf164519165c --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/rollup_cube_4_2.sql @@ -0,0 +1,5 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT count(*) as cnt, key % 5 as k1, key - 5 as k2, grouping_id() FROM parquet_t1 +GROUP BY key % 5, key - 5 WITH CUBE +-------------------------------------------------------------------------------- +SELECT `gen_attr_3` AS `cnt`, `gen_attr_4` AS `k1`, `gen_attr_5` AS `k2`, `gen_attr_6` AS `grouping_id()` FROM (SELECT count(1) AS `gen_attr_3`, (`gen_attr_7` % CAST(5 AS BIGINT)) AS `gen_attr_4`, (`gen_attr_7` - CAST(5 AS BIGINT)) AS `gen_attr_5`, grouping_id() AS `gen_attr_6` FROM (SELECT `key` AS `gen_attr_7`, `value` AS `gen_attr_8` FROM `default`.`parquet_t1`) AS gen_subquery_0 GROUP BY (`gen_attr_7` % CAST(5 AS BIGINT)), (`gen_attr_7` - CAST(5 AS BIGINT)) GROUPING SETS(((`gen_attr_7` % CAST(5 AS BIGINT)), (`gen_attr_7` - CAST(5 AS BIGINT))), ((`gen_attr_7` % CAST(5 AS BIGINT))), ((`gen_attr_7` - CAST(5 AS BIGINT))), ())) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/rollup_cube_5_1.sql b/sql/hive/src/test/resources/sqlgen/rollup_cube_5_1.sql new file mode 100644 index 0000000000000..17e78a0a706a5 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/rollup_cube_5_1.sql @@ -0,0 +1,6 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. 
+SELECT count(*) AS cnt, key % 5 AS k1, key - 5 AS k2, grouping_id(key % 5, key - 5) AS k3 +FROM (SELECT key, key%2, key - 5 FROM parquet_t1) t GROUP BY key%5, key-5 +WITH ROLLUP +-------------------------------------------------------------------------------- +SELECT `gen_attr_3` AS `cnt`, `gen_attr_4` AS `k1`, `gen_attr_5` AS `k2`, `gen_attr_6` AS `k3` FROM (SELECT count(1) AS `gen_attr_3`, (`gen_attr_7` % CAST(5 AS BIGINT)) AS `gen_attr_4`, (`gen_attr_7` - CAST(5 AS BIGINT)) AS `gen_attr_5`, grouping_id() AS `gen_attr_6` FROM (SELECT `gen_attr_7`, (`gen_attr_7` % CAST(2 AS BIGINT)) AS `gen_attr_8`, (`gen_attr_7` - CAST(5 AS BIGINT)) AS `gen_attr_9` FROM (SELECT `key` AS `gen_attr_7`, `value` AS `gen_attr_12` FROM `default`.`parquet_t1`) AS gen_subquery_0) AS t GROUP BY (`gen_attr_7` % CAST(5 AS BIGINT)), (`gen_attr_7` - CAST(5 AS BIGINT)) GROUPING SETS(((`gen_attr_7` % CAST(5 AS BIGINT)), (`gen_attr_7` - CAST(5 AS BIGINT))), ((`gen_attr_7` % CAST(5 AS BIGINT))), ())) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/rollup_cube_5_2.sql b/sql/hive/src/test/resources/sqlgen/rollup_cube_5_2.sql new file mode 100644 index 0000000000000..72506ef72aecd --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/rollup_cube_5_2.sql @@ -0,0 +1,6 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT count(*) AS cnt, key % 5 AS k1, key - 5 AS k2, grouping_id(key % 5, key - 5) AS k3 +FROM (SELECT key, key % 2, key - 5 FROM parquet_t1) t GROUP BY key % 5, key - 5 +WITH CUBE +-------------------------------------------------------------------------------- +SELECT `gen_attr_3` AS `cnt`, `gen_attr_4` AS `k1`, `gen_attr_5` AS `k2`, `gen_attr_6` AS `k3` FROM (SELECT count(1) AS `gen_attr_3`, (`gen_attr_7` % CAST(5 AS BIGINT)) AS `gen_attr_4`, (`gen_attr_7` - CAST(5 AS BIGINT)) AS `gen_attr_5`, grouping_id() AS `gen_attr_6` FROM (SELECT `gen_attr_7`, (`gen_attr_7` % CAST(2 AS BIGINT)) AS `gen_attr_8`, (`gen_attr_7` - CAST(5 AS BIGINT)) AS `gen_attr_9` FROM (SELECT `key` AS `gen_attr_7`, `value` AS `gen_attr_12` FROM `default`.`parquet_t1`) AS gen_subquery_0) AS t GROUP BY (`gen_attr_7` % CAST(5 AS BIGINT)), (`gen_attr_7` - CAST(5 AS BIGINT)) GROUPING SETS(((`gen_attr_7` % CAST(5 AS BIGINT)), (`gen_attr_7` - CAST(5 AS BIGINT))), ((`gen_attr_7` % CAST(5 AS BIGINT))), ((`gen_attr_7` - CAST(5 AS BIGINT))), ())) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/rollup_cube_6_1.sql b/sql/hive/src/test/resources/sqlgen/rollup_cube_6_1.sql new file mode 100644 index 0000000000000..22df578518ef3 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/rollup_cube_6_1.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. 
+SELECT a, b, sum(c) FROM parquet_t2 GROUP BY ROLLUP(a, b) ORDER BY a, b +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `a`, `gen_attr_1` AS `b`, `gen_attr_3` AS `sum(c)` FROM (SELECT `gen_attr_5` AS `gen_attr_0`, `gen_attr_6` AS `gen_attr_1`, sum(`gen_attr_4`) AS `gen_attr_3` FROM (SELECT `a` AS `gen_attr_5`, `b` AS `gen_attr_6`, `c` AS `gen_attr_4`, `d` AS `gen_attr_7` FROM `default`.`parquet_t2`) AS gen_subquery_0 GROUP BY `gen_attr_5`, `gen_attr_6` GROUPING SETS((`gen_attr_5`, `gen_attr_6`), (`gen_attr_5`), ()) ORDER BY `gen_attr_0` ASC, `gen_attr_1` ASC) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/rollup_cube_6_2.sql b/sql/hive/src/test/resources/sqlgen/rollup_cube_6_2.sql new file mode 100644 index 0000000000000..f44b652343acb --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/rollup_cube_6_2.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT a, b, sum(c) FROM parquet_t2 GROUP BY CUBE(a, b) ORDER BY a, b +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `a`, `gen_attr_1` AS `b`, `gen_attr_3` AS `sum(c)` FROM (SELECT `gen_attr_5` AS `gen_attr_0`, `gen_attr_6` AS `gen_attr_1`, sum(`gen_attr_4`) AS `gen_attr_3` FROM (SELECT `a` AS `gen_attr_5`, `b` AS `gen_attr_6`, `c` AS `gen_attr_4`, `d` AS `gen_attr_7` FROM `default`.`parquet_t2`) AS gen_subquery_0 GROUP BY `gen_attr_5`, `gen_attr_6` GROUPING SETS((`gen_attr_5`, `gen_attr_6`), (`gen_attr_5`), (`gen_attr_6`), ()) ORDER BY `gen_attr_0` ASC, `gen_attr_1` ASC) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/rollup_cube_6_3.sql b/sql/hive/src/test/resources/sqlgen/rollup_cube_6_3.sql new file mode 100644 index 0000000000000..40f6924913765 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/rollup_cube_6_3.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT a, b, sum(a) FROM parquet_t2 GROUP BY ROLLUP(a, b) ORDER BY a, b +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `a`, `gen_attr_1` AS `b`, `gen_attr_3` AS `sum(a)` FROM (SELECT `gen_attr_4` AS `gen_attr_0`, `gen_attr_5` AS `gen_attr_1`, sum(`gen_attr_4`) AS `gen_attr_3` FROM (SELECT `a` AS `gen_attr_4`, `b` AS `gen_attr_5`, `c` AS `gen_attr_6`, `d` AS `gen_attr_7` FROM `default`.`parquet_t2`) AS gen_subquery_0 GROUP BY `gen_attr_4`, `gen_attr_5` GROUPING SETS((`gen_attr_4`, `gen_attr_5`), (`gen_attr_4`), ()) ORDER BY `gen_attr_0` ASC, `gen_attr_1` ASC) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/rollup_cube_6_4.sql b/sql/hive/src/test/resources/sqlgen/rollup_cube_6_4.sql new file mode 100644 index 0000000000000..608e644dee6d0 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/rollup_cube_6_4.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. 
+SELECT a, b, sum(a) FROM parquet_t2 GROUP BY CUBE(a, b) ORDER BY a, b +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `a`, `gen_attr_1` AS `b`, `gen_attr_3` AS `sum(a)` FROM (SELECT `gen_attr_4` AS `gen_attr_0`, `gen_attr_5` AS `gen_attr_1`, sum(`gen_attr_4`) AS `gen_attr_3` FROM (SELECT `a` AS `gen_attr_4`, `b` AS `gen_attr_5`, `c` AS `gen_attr_6`, `d` AS `gen_attr_7` FROM `default`.`parquet_t2`) AS gen_subquery_0 GROUP BY `gen_attr_4`, `gen_attr_5` GROUPING SETS((`gen_attr_4`, `gen_attr_5`), (`gen_attr_4`), (`gen_attr_5`), ()) ORDER BY `gen_attr_0` ASC, `gen_attr_1` ASC) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/rollup_cube_6_5.sql b/sql/hive/src/test/resources/sqlgen/rollup_cube_6_5.sql new file mode 100644 index 0000000000000..26885a26e2b96 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/rollup_cube_6_5.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT a + b, b, sum(a - b) FROM parquet_t2 GROUP BY a + b, b WITH ROLLUP +-------------------------------------------------------------------------------- +SELECT `gen_attr_3` AS `(a + b)`, `gen_attr_1` AS `b`, `gen_attr_4` AS `sum((a - b))` FROM (SELECT (`gen_attr_5` + `gen_attr_6`) AS `gen_attr_3`, `gen_attr_6` AS `gen_attr_1`, sum((`gen_attr_5` - `gen_attr_6`)) AS `gen_attr_4` FROM (SELECT `a` AS `gen_attr_5`, `b` AS `gen_attr_6`, `c` AS `gen_attr_7`, `d` AS `gen_attr_8` FROM `default`.`parquet_t2`) AS gen_subquery_0 GROUP BY (`gen_attr_5` + `gen_attr_6`), `gen_attr_6` GROUPING SETS(((`gen_attr_5` + `gen_attr_6`), `gen_attr_6`), ((`gen_attr_5` + `gen_attr_6`)), ())) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/rollup_cube_6_6.sql b/sql/hive/src/test/resources/sqlgen/rollup_cube_6_6.sql new file mode 100644 index 0000000000000..dd97c976afe61 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/rollup_cube_6_6.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT a + b, b, sum(a - b) FROM parquet_t2 GROUP BY a + b, b WITH CUBE +-------------------------------------------------------------------------------- +SELECT `gen_attr_3` AS `(a + b)`, `gen_attr_1` AS `b`, `gen_attr_4` AS `sum((a - b))` FROM (SELECT (`gen_attr_5` + `gen_attr_6`) AS `gen_attr_3`, `gen_attr_6` AS `gen_attr_1`, sum((`gen_attr_5` - `gen_attr_6`)) AS `gen_attr_4` FROM (SELECT `a` AS `gen_attr_5`, `b` AS `gen_attr_6`, `c` AS `gen_attr_7`, `d` AS `gen_attr_8` FROM `default`.`parquet_t2`) AS gen_subquery_0 GROUP BY (`gen_attr_5` + `gen_attr_6`), `gen_attr_6` GROUPING SETS(((`gen_attr_5` + `gen_attr_6`), `gen_attr_6`), ((`gen_attr_5` + `gen_attr_6`)), (`gen_attr_6`), ())) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/rollup_cube_7_1.sql b/sql/hive/src/test/resources/sqlgen/rollup_cube_7_1.sql new file mode 100644 index 0000000000000..aae2d75d794be --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/rollup_cube_7_1.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. 
+SELECT a, b, grouping_id(a, b) FROM parquet_t2 GROUP BY cube(a, b) +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `a`, `gen_attr_1` AS `b`, `gen_attr_3` AS `grouping_id(a, b)` FROM (SELECT `gen_attr_4` AS `gen_attr_0`, `gen_attr_5` AS `gen_attr_1`, grouping_id() AS `gen_attr_3` FROM (SELECT `a` AS `gen_attr_4`, `b` AS `gen_attr_5`, `c` AS `gen_attr_6`, `d` AS `gen_attr_7` FROM `default`.`parquet_t2`) AS gen_subquery_0 GROUP BY `gen_attr_4`, `gen_attr_5` GROUPING SETS((`gen_attr_4`, `gen_attr_5`), (`gen_attr_4`), (`gen_attr_5`), ())) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/rollup_cube_7_2.sql b/sql/hive/src/test/resources/sqlgen/rollup_cube_7_2.sql new file mode 100644 index 0000000000000..9958c8f38bc87 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/rollup_cube_7_2.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT a, b, grouping(b) FROM parquet_t2 GROUP BY cube(a, b) +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `a`, `gen_attr_1` AS `b`, `gen_attr_3` AS `grouping(b)` FROM (SELECT `gen_attr_4` AS `gen_attr_0`, `gen_attr_5` AS `gen_attr_1`, grouping(`gen_attr_5`) AS `gen_attr_3` FROM (SELECT `a` AS `gen_attr_4`, `b` AS `gen_attr_5`, `c` AS `gen_attr_6`, `d` AS `gen_attr_7` FROM `default`.`parquet_t2`) AS gen_subquery_0 GROUP BY `gen_attr_4`, `gen_attr_5` GROUPING SETS((`gen_attr_4`, `gen_attr_5`), (`gen_attr_4`), (`gen_attr_5`), ())) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/rollup_cube_7_3.sql b/sql/hive/src/test/resources/sqlgen/rollup_cube_7_3.sql new file mode 100644 index 0000000000000..fd012043cf6cb --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/rollup_cube_7_3.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT a, b, grouping(a) FROM parquet_t2 GROUP BY cube(a, b) +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `a`, `gen_attr_1` AS `b`, `gen_attr_3` AS `grouping(a)` FROM (SELECT `gen_attr_4` AS `gen_attr_0`, `gen_attr_5` AS `gen_attr_1`, grouping(`gen_attr_4`) AS `gen_attr_3` FROM (SELECT `a` AS `gen_attr_4`, `b` AS `gen_attr_5`, `c` AS `gen_attr_6`, `d` AS `gen_attr_7` FROM `default`.`parquet_t2`) AS gen_subquery_0 GROUP BY `gen_attr_4`, `gen_attr_5` GROUPING SETS((`gen_attr_4`, `gen_attr_5`), (`gen_attr_4`), (`gen_attr_5`), ())) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/rollup_cube_8_1.sql b/sql/hive/src/test/resources/sqlgen/rollup_cube_8_1.sql new file mode 100644 index 0000000000000..61c27067e1521 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/rollup_cube_8_1.sql @@ -0,0 +1,6 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. 
+SELECT hkey AS k1, value - 5 AS k2, hash(grouping_id()) AS hgid +FROM (SELECT hash(key) as hkey, key as value FROM parquet_t1) t GROUP BY hkey, value-5 +WITH ROLLUP +-------------------------------------------------------------------------------- +SELECT `gen_attr_3` AS `k1`, `gen_attr_4` AS `k2`, `gen_attr_5` AS `hgid` FROM (SELECT `gen_attr_6` AS `gen_attr_3`, (`gen_attr_7` - CAST(5 AS BIGINT)) AS `gen_attr_4`, hash(grouping_id()) AS `gen_attr_5` FROM (SELECT hash(`gen_attr_10`) AS `gen_attr_6`, `gen_attr_10` AS `gen_attr_7` FROM (SELECT `key` AS `gen_attr_10`, `value` AS `gen_attr_11` FROM `default`.`parquet_t1`) AS gen_subquery_0) AS t GROUP BY `gen_attr_6`, (`gen_attr_7` - CAST(5 AS BIGINT)) GROUPING SETS((`gen_attr_6`, (`gen_attr_7` - CAST(5 AS BIGINT))), (`gen_attr_6`), ())) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/rollup_cube_8_2.sql b/sql/hive/src/test/resources/sqlgen/rollup_cube_8_2.sql new file mode 100644 index 0000000000000..16f254fa41f78 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/rollup_cube_8_2.sql @@ -0,0 +1,6 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT hkey AS k1, value - 5 AS k2, hash(grouping_id()) AS hgid +FROM (SELECT hash(key) as hkey, key as value FROM parquet_t1) t GROUP BY hkey, value-5 +WITH CUBE +-------------------------------------------------------------------------------- +SELECT `gen_attr_3` AS `k1`, `gen_attr_4` AS `k2`, `gen_attr_5` AS `hgid` FROM (SELECT `gen_attr_6` AS `gen_attr_3`, (`gen_attr_7` - CAST(5 AS BIGINT)) AS `gen_attr_4`, hash(grouping_id()) AS `gen_attr_5` FROM (SELECT hash(`gen_attr_10`) AS `gen_attr_6`, `gen_attr_10` AS `gen_attr_7` FROM (SELECT `key` AS `gen_attr_10`, `value` AS `gen_attr_11` FROM `default`.`parquet_t1`) AS gen_subquery_0) AS t GROUP BY `gen_attr_6`, (`gen_attr_7` - CAST(5 AS BIGINT)) GROUPING SETS((`gen_attr_6`, (`gen_attr_7` - CAST(5 AS BIGINT))), (`gen_attr_6`), ((`gen_attr_7` - CAST(5 AS BIGINT))), ())) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/rollup_cube_9_1.sql b/sql/hive/src/test/resources/sqlgen/rollup_cube_9_1.sql new file mode 100644 index 0000000000000..cfce1758434de --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/rollup_cube_9_1.sql @@ -0,0 +1,8 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. 
+SELECT t.key - 5, cnt, SUM(cnt) +FROM (SELECT x.key, COUNT(*) as cnt +FROM parquet_t1 x JOIN parquet_t1 y ON x.key = y.key GROUP BY x.key) t +GROUP BY cnt, t.key - 5 +WITH ROLLUP +-------------------------------------------------------------------------------- +SELECT `gen_attr_3` AS `(key - CAST(5 AS BIGINT))`, `gen_attr_0` AS `cnt`, `gen_attr_4` AS `sum(cnt)` FROM (SELECT (`gen_attr_6` - CAST(5 AS BIGINT)) AS `gen_attr_3`, `gen_attr_5` AS `gen_attr_0`, sum(`gen_attr_5`) AS `gen_attr_4` FROM (SELECT `gen_attr_6`, count(1) AS `gen_attr_5` FROM (SELECT `key` AS `gen_attr_6`, `value` AS `gen_attr_10` FROM `default`.`parquet_t1`) AS gen_subquery_0 INNER JOIN (SELECT `key` AS `gen_attr_9`, `value` AS `gen_attr_11` FROM `default`.`parquet_t1`) AS gen_subquery_1 ON (`gen_attr_6` = `gen_attr_9`) GROUP BY `gen_attr_6`) AS t GROUP BY `gen_attr_5`, (`gen_attr_6` - CAST(5 AS BIGINT)) GROUPING SETS((`gen_attr_5`, (`gen_attr_6` - CAST(5 AS BIGINT))), (`gen_attr_5`), ())) AS gen_subquery_2 diff --git a/sql/hive/src/test/resources/sqlgen/rollup_cube_9_2.sql b/sql/hive/src/test/resources/sqlgen/rollup_cube_9_2.sql new file mode 100644 index 0000000000000..d950674b74c19 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/rollup_cube_9_2.sql @@ -0,0 +1,8 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT t.key - 5, cnt, SUM(cnt) +FROM (SELECT x.key, COUNT(*) as cnt +FROM parquet_t1 x JOIN parquet_t1 y ON x.key = y.key GROUP BY x.key) t +GROUP BY cnt, t.key - 5 +WITH CUBE +-------------------------------------------------------------------------------- +SELECT `gen_attr_3` AS `(key - CAST(5 AS BIGINT))`, `gen_attr_0` AS `cnt`, `gen_attr_4` AS `sum(cnt)` FROM (SELECT (`gen_attr_6` - CAST(5 AS BIGINT)) AS `gen_attr_3`, `gen_attr_5` AS `gen_attr_0`, sum(`gen_attr_5`) AS `gen_attr_4` FROM (SELECT `gen_attr_6`, count(1) AS `gen_attr_5` FROM (SELECT `key` AS `gen_attr_6`, `value` AS `gen_attr_10` FROM `default`.`parquet_t1`) AS gen_subquery_0 INNER JOIN (SELECT `key` AS `gen_attr_9`, `value` AS `gen_attr_11` FROM `default`.`parquet_t1`) AS gen_subquery_1 ON (`gen_attr_6` = `gen_attr_9`) GROUP BY `gen_attr_6`) AS t GROUP BY `gen_attr_5`, (`gen_attr_6` - CAST(5 AS BIGINT)) GROUPING SETS((`gen_attr_5`, (`gen_attr_6` - CAST(5 AS BIGINT))), (`gen_attr_5`), ((`gen_attr_6` - CAST(5 AS BIGINT))), ())) AS gen_subquery_2 diff --git a/sql/hive/src/test/resources/sqlgen/script_transformation_1.sql b/sql/hive/src/test/resources/sqlgen/script_transformation_1.sql new file mode 100644 index 0000000000000..1736d74b0cfa9 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/script_transformation_1.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. 
+SELECT TRANSFORM (a, b, c, d) USING 'cat' FROM parquet_t2 +-------------------------------------------------------------------------------- +SELECT `gen_attr_4` AS `key`, `gen_attr_5` AS `value` FROM (SELECT TRANSFORM (`gen_attr_0`, `gen_attr_1`, `gen_attr_2`, `gen_attr_3`) USING 'cat' AS (`gen_attr_4` string, `gen_attr_5` string) FROM (SELECT `a` AS `gen_attr_0`, `b` AS `gen_attr_1`, `c` AS `gen_attr_2`, `d` AS `gen_attr_3` FROM `default`.`parquet_t2`) AS gen_subquery_0) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/script_transformation_2.sql b/sql/hive/src/test/resources/sqlgen/script_transformation_2.sql new file mode 100644 index 0000000000000..07f59d6bffddc --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/script_transformation_2.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT TRANSFORM (*) USING 'cat' FROM parquet_t2 +-------------------------------------------------------------------------------- +SELECT `gen_attr_4` AS `key`, `gen_attr_5` AS `value` FROM (SELECT TRANSFORM (`gen_attr_0`, `gen_attr_1`, `gen_attr_2`, `gen_attr_3`) USING 'cat' AS (`gen_attr_4` string, `gen_attr_5` string) FROM (SELECT `a` AS `gen_attr_0`, `b` AS `gen_attr_1`, `c` AS `gen_attr_2`, `d` AS `gen_attr_3` FROM `default`.`parquet_t2`) AS gen_subquery_0) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/script_transformation_alias_list.sql b/sql/hive/src/test/resources/sqlgen/script_transformation_alias_list.sql new file mode 100644 index 0000000000000..fc0cabec237bc --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/script_transformation_alias_list.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT TRANSFORM (a, b, c, d) USING 'cat' AS (d1, d2, d3, d4) FROM parquet_t2 +-------------------------------------------------------------------------------- +SELECT `gen_attr_4` AS `d1`, `gen_attr_5` AS `d2`, `gen_attr_6` AS `d3`, `gen_attr_7` AS `d4` FROM (SELECT TRANSFORM (`gen_attr_0`, `gen_attr_1`, `gen_attr_2`, `gen_attr_3`) ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' WITH SERDEPROPERTIES('field.delim' = ' ') USING 'cat' AS (`gen_attr_4` string, `gen_attr_5` string, `gen_attr_6` string, `gen_attr_7` string) ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' WITH SERDEPROPERTIES('field.delim' = ' ') FROM (SELECT `a` AS `gen_attr_0`, `b` AS `gen_attr_1`, `c` AS `gen_attr_2`, `d` AS `gen_attr_3` FROM `default`.`parquet_t2`) AS gen_subquery_0) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/script_transformation_alias_list_with_type.sql b/sql/hive/src/test/resources/sqlgen/script_transformation_alias_list_with_type.sql new file mode 100644 index 0000000000000..a45f9a2c625f6 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/script_transformation_alias_list_with_type.sql @@ -0,0 +1,6 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. 
+FROM +(FROM parquet_t1 SELECT TRANSFORM(key, value) USING 'cat' AS (thing1 int, thing2 string)) t +SELECT thing1 + 1 +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `(thing1 + 1)` FROM (SELECT (`gen_attr_1` + 1) AS `gen_attr_0` FROM (SELECT TRANSFORM (`gen_attr_2`, `gen_attr_3`) ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' WITH SERDEPROPERTIES('field.delim' = ' ') USING 'cat' AS (`gen_attr_1` int, `gen_attr_4` string) ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' WITH SERDEPROPERTIES('field.delim' = ' ') FROM (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_3` FROM `default`.`parquet_t1`) AS gen_subquery_0) AS t) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/script_transformation_row_format_multiple.sql b/sql/hive/src/test/resources/sqlgen/script_transformation_row_format_multiple.sql new file mode 100644 index 0000000000000..30d37c78b58e1 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/script_transformation_row_format_multiple.sql @@ -0,0 +1,8 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT TRANSFORM (key) +ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' LINES TERMINATED BY '\t' +USING 'cat' AS (tKey) +ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' LINES TERMINATED BY '\t' +FROM parquet_t1 +-------------------------------------------------------------------------------- +SELECT `gen_attr_1` AS `tKey` FROM (SELECT TRANSFORM (`gen_attr_0`) ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' LINES TERMINATED BY '\t' USING 'cat' AS (`gen_attr_1` string) ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' LINES TERMINATED BY '\t' FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_2` FROM `default`.`parquet_t1`) AS gen_subquery_0) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/script_transformation_row_format_one.sql b/sql/hive/src/test/resources/sqlgen/script_transformation_row_format_one.sql new file mode 100644 index 0000000000000..0b694e0d6dafa --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/script_transformation_row_format_one.sql @@ -0,0 +1,6 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT TRANSFORM (key) ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' +USING 'cat' AS (tKey) ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' +FROM parquet_t1 +-------------------------------------------------------------------------------- +SELECT `gen_attr_1` AS `tKey` FROM (SELECT TRANSFORM (`gen_attr_0`) ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' USING 'cat' AS (`gen_attr_1` string) ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_2` FROM `default`.`parquet_t1`) AS gen_subquery_0) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/script_transformation_row_format_serde.sql b/sql/hive/src/test/resources/sqlgen/script_transformation_row_format_serde.sql new file mode 100644 index 0000000000000..14cff373852dd --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/script_transformation_row_format_serde.sql @@ -0,0 +1,10 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. 
+SELECT TRANSFORM (key, value) +ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' +WITH SERDEPROPERTIES('field.delim' = '|') +USING 'cat' AS (tKey, tValue) +ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' +WITH SERDEPROPERTIES('field.delim' = '|') +FROM parquet_t1 +-------------------------------------------------------------------------------- +SELECT `gen_attr_2` AS `tKey`, `gen_attr_3` AS `tValue` FROM (SELECT TRANSFORM (`gen_attr_0`, `gen_attr_1`) ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' WITH SERDEPROPERTIES('field.delim' = '|') USING 'cat' AS (`gen_attr_2` string, `gen_attr_3` string) ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' WITH SERDEPROPERTIES('field.delim' = '|') FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_1` FROM `default`.`parquet_t1`) AS gen_subquery_0) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/script_transformation_row_format_without_serde.sql b/sql/hive/src/test/resources/sqlgen/script_transformation_row_format_without_serde.sql new file mode 100644 index 0000000000000..d20caf7afcf0f --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/script_transformation_row_format_without_serde.sql @@ -0,0 +1,8 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT TRANSFORM (key, value) +ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' +USING 'cat' AS (tKey, tValue) +ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' +FROM parquet_t1 +-------------------------------------------------------------------------------- +SELECT `gen_attr_2` AS `tKey`, `gen_attr_3` AS `tValue` FROM (SELECT TRANSFORM (`gen_attr_0`, `gen_attr_1`) ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' USING 'cat' AS (`gen_attr_2` string, `gen_attr_3` string) ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_1` FROM `default`.`parquet_t1`) AS gen_subquery_0) AS gen_subquery_1 diff --git a/sql/hive/src/test/resources/sqlgen/select_distinct.sql b/sql/hive/src/test/resources/sqlgen/select_distinct.sql new file mode 100644 index 0000000000000..09d93cac8e5fd --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/select_distinct.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT DISTINCT id FROM parquet_t0 +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `id` FROM (SELECT DISTINCT `gen_attr_0` FROM (SELECT `id` AS `gen_attr_0` FROM `default`.`parquet_t0`) AS gen_subquery_0) AS parquet_t0 diff --git a/sql/hive/src/test/resources/sqlgen/select_orc_table.sql b/sql/hive/src/test/resources/sqlgen/select_orc_table.sql new file mode 100644 index 0000000000000..18ff021798972 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/select_orc_table.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. 
+select * from orc_t +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `c1`, `gen_attr_1` AS `c2` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `c1` AS `gen_attr_0`, `c2` AS `gen_attr_1` FROM `default`.`orc_t`) AS gen_subquery_0) AS orc_t diff --git a/sql/hive/src/test/resources/sqlgen/select_parquet_table.sql b/sql/hive/src/test/resources/sqlgen/select_parquet_table.sql new file mode 100644 index 0000000000000..d2eac9c08f56c --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/select_parquet_table.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +select * from parquet_t +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `c1`, `gen_attr_1` AS `c2` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `c1` AS `gen_attr_0`, `c2` AS `gen_attr_1` FROM `default`.`parquet_t`) AS gen_subquery_0) AS parquet_t diff --git a/sql/hive/src/test/resources/sqlgen/self_join.sql b/sql/hive/src/test/resources/sqlgen/self_join.sql new file mode 100644 index 0000000000000..d6dcee2f67dbd --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/self_join.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT x.key FROM parquet_t1 x JOIN parquet_t1 y ON x.key = y.key +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `key` FROM (SELECT `gen_attr_0` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_2` FROM `default`.`parquet_t1`) AS gen_subquery_0 INNER JOIN (SELECT `key` AS `gen_attr_1`, `value` AS `gen_attr_3` FROM `default`.`parquet_t1`) AS gen_subquery_1 ON (`gen_attr_0` = `gen_attr_1`)) AS x diff --git a/sql/hive/src/test/resources/sqlgen/self_join_with_group_by.sql b/sql/hive/src/test/resources/sqlgen/self_join_with_group_by.sql new file mode 100644 index 0000000000000..1dedb44dbff65 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/self_join_with_group_by.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT x.key, COUNT(*) FROM parquet_t1 x JOIN parquet_t1 y ON x.key = y.key group by x.key +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `count(1)` FROM (SELECT `gen_attr_0`, count(1) AS `gen_attr_1` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_3` FROM `default`.`parquet_t1`) AS gen_subquery_0 INNER JOIN (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_4` FROM `default`.`parquet_t1`) AS gen_subquery_1 ON (`gen_attr_0` = `gen_attr_2`) GROUP BY `gen_attr_0`) AS x diff --git a/sql/hive/src/test/resources/sqlgen/sort_by_after_having.sql b/sql/hive/src/test/resources/sqlgen/sort_by_after_having.sql new file mode 100644 index 0000000000000..da60204297a21 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/sort_by_after_having.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. 
+SELECT COUNT(value) FROM parquet_t1 GROUP BY key HAVING MAX(key) > 0 SORT BY key +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `count(value)` FROM (SELECT `gen_attr_0` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT count(`gen_attr_3`) AS `gen_attr_0`, max(`gen_attr_1`) AS `gen_attr_2`, `gen_attr_1` FROM (SELECT `key` AS `gen_attr_1`, `value` AS `gen_attr_3` FROM `default`.`parquet_t1`) AS gen_subquery_0 GROUP BY `gen_attr_1` HAVING (`gen_attr_2` > CAST(0 AS BIGINT))) AS gen_subquery_1 SORT BY `gen_attr_1` ASC) AS gen_subquery_2) AS gen_subquery_3 diff --git a/sql/hive/src/test/resources/sqlgen/subq2.sql b/sql/hive/src/test/resources/sqlgen/subq2.sql new file mode 100644 index 0000000000000..ee7e80c1fc9e2 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/subq2.sql @@ -0,0 +1,8 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT a.k, a.c +FROM (SELECT b.key as k, count(1) as c + FROM src b + GROUP BY b.key) a +WHERE a.k >= 90 +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `k`, `gen_attr_1` AS `c` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `gen_attr_2` AS `gen_attr_0`, count(1) AS `gen_attr_1` FROM (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_3` FROM `default`.`src`) AS gen_subquery_0 GROUP BY `gen_attr_2`) AS a WHERE (`gen_attr_0` >= 90)) AS a diff --git a/sql/hive/src/test/resources/sqlgen/subquery_exists_1.sql b/sql/hive/src/test/resources/sqlgen/subquery_exists_1.sql new file mode 100644 index 0000000000000..bd28d8dca94c2 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/subquery_exists_1.sql @@ -0,0 +1,8 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +select * +from src b +where exists (select a.key + from src a + where b.value = a.value and a.key = b.key and a.value > 'val_9') +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `value` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_1` FROM `default`.`src`) AS gen_subquery_0 WHERE EXISTS(SELECT `gen_attr_4` AS `1` FROM (SELECT 1 AS `gen_attr_4` FROM (SELECT `gen_attr_3`, `gen_attr_2` FROM (SELECT `key` AS `gen_attr_3`, `value` AS `gen_attr_2` FROM `default`.`src`) AS gen_subquery_2 WHERE (`gen_attr_2` > 'val_9')) AS gen_subquery_1 WHERE ((`gen_attr_1` = `gen_attr_2`) AND (`gen_attr_3` = `gen_attr_0`))) AS gen_subquery_3)) AS b diff --git a/sql/hive/src/test/resources/sqlgen/subquery_exists_2.sql b/sql/hive/src/test/resources/sqlgen/subquery_exists_2.sql new file mode 100644 index 0000000000000..d2965fc0b9b77 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/subquery_exists_2.sql @@ -0,0 +1,9 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. 
+select * +from (select * + from src b + where exists (select a.key + from src a + where b.value = a.value and a.key = b.key and a.value > 'val_9')) a +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `value` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_1` FROM `default`.`src`) AS gen_subquery_0 WHERE EXISTS(SELECT `gen_attr_4` AS `1` FROM (SELECT 1 AS `gen_attr_4` FROM (SELECT `gen_attr_3`, `gen_attr_2` FROM (SELECT `key` AS `gen_attr_3`, `value` AS `gen_attr_2` FROM `default`.`src`) AS gen_subquery_2 WHERE (`gen_attr_2` > 'val_9')) AS gen_subquery_1 WHERE ((`gen_attr_1` = `gen_attr_2`) AND (`gen_attr_3` = `gen_attr_0`))) AS gen_subquery_3)) AS a) AS a diff --git a/sql/hive/src/test/resources/sqlgen/subquery_exists_having_1.sql b/sql/hive/src/test/resources/sqlgen/subquery_exists_having_1.sql new file mode 100644 index 0000000000000..93ce902b75994 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/subquery_exists_having_1.sql @@ -0,0 +1,9 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +select b.key, count(*) +from src b +group by b.key +having exists (select a.key + from src a + where a.key = b.key and a.value > 'val_9') +-------------------------------------------------------------------------------- +SELECT `gen_attr_1` AS `key`, `gen_attr_2` AS `count(1)` FROM (SELECT `gen_attr_1`, count(1) AS `gen_attr_2` FROM (SELECT `key` AS `gen_attr_1`, `value` AS `gen_attr_3` FROM `default`.`src`) AS gen_subquery_0 GROUP BY `gen_attr_1` HAVING EXISTS(SELECT `gen_attr_4` AS `1` FROM (SELECT 1 AS `gen_attr_4` FROM (SELECT `gen_attr_0` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_5` FROM `default`.`src`) AS gen_subquery_2 WHERE (`gen_attr_5` > 'val_9')) AS gen_subquery_1 WHERE (`gen_attr_0` = `gen_attr_1`)) AS gen_subquery_3)) AS b diff --git a/sql/hive/src/test/resources/sqlgen/subquery_exists_having_2.sql b/sql/hive/src/test/resources/sqlgen/subquery_exists_having_2.sql new file mode 100644 index 0000000000000..411e073f0d280 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/subquery_exists_having_2.sql @@ -0,0 +1,10 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +select * +from (select b.key, count(*) + from src b + group by b.key + having exists (select a.key + from src a + where a.key = b.key and a.value > 'val_9')) a +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `count(1)` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `gen_attr_0`, count(1) AS `gen_attr_1` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_3` FROM `default`.`src`) AS gen_subquery_0 GROUP BY `gen_attr_0` HAVING EXISTS(SELECT `gen_attr_4` AS `1` FROM (SELECT 1 AS `gen_attr_4` FROM (SELECT `gen_attr_2` FROM (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_5` FROM `default`.`src`) AS gen_subquery_2 WHERE (`gen_attr_5` > 'val_9')) AS gen_subquery_1 WHERE (`gen_attr_2` = `gen_attr_0`)) AS gen_subquery_3)) AS a) AS a diff --git a/sql/hive/src/test/resources/sqlgen/subquery_exists_having_3.sql b/sql/hive/src/test/resources/sqlgen/subquery_exists_having_3.sql new file mode 100644 index 0000000000000..b2ed0b0557aff --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/subquery_exists_having_3.sql @@ -0,0 +1,9 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. 
+select b.key, min(b.value) +from src b +group by b.key +having exists (select a.key + from src a + where a.value > 'val_9' and a.value = min(b.value)) +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `min(value)` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `gen_attr_0`, min(`gen_attr_4`) AS `gen_attr_1`, min(`gen_attr_4`) AS `gen_attr_3` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_4` FROM `default`.`src`) AS gen_subquery_0 GROUP BY `gen_attr_0` HAVING EXISTS(SELECT `gen_attr_5` AS `1` FROM (SELECT 1 AS `gen_attr_5` FROM (SELECT `gen_attr_6`, `gen_attr_2` FROM (SELECT `key` AS `gen_attr_6`, `value` AS `gen_attr_2` FROM `default`.`src`) AS gen_subquery_3 WHERE (`gen_attr_2` > 'val_9')) AS gen_subquery_2 WHERE (`gen_attr_2` = `gen_attr_3`)) AS gen_subquery_4)) AS gen_subquery_1) AS b diff --git a/sql/hive/src/test/resources/sqlgen/subquery_in.sql b/sql/hive/src/test/resources/sqlgen/subquery_in.sql new file mode 100644 index 0000000000000..0fe62248dbfec --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/subquery_in.sql @@ -0,0 +1,6 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT key +FROM src +WHERE key in (SELECT max(key) FROM src) +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `key` FROM (SELECT `gen_attr_0` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_2` FROM `default`.`src`) AS gen_subquery_0 WHERE (`gen_attr_0` IN (SELECT `gen_attr_3` AS `_c0` FROM (SELECT `gen_attr_1` AS `gen_attr_3` FROM (SELECT max(`gen_attr_4`) AS `gen_attr_1` FROM (SELECT `key` AS `gen_attr_4`, `value` AS `gen_attr_5` FROM `default`.`src`) AS gen_subquery_2) AS gen_subquery_1) AS gen_subquery_3))) AS src diff --git a/sql/hive/src/test/resources/sqlgen/subquery_in_having_1.sql b/sql/hive/src/test/resources/sqlgen/subquery_in_having_1.sql new file mode 100644 index 0000000000000..9894f5ab39c76 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/subquery_in_having_1.sql @@ -0,0 +1,8 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +select key, count(*) +from src +group by key +having count(*) in (select count(*) from src s1 where s1.key = '90' group by s1.key) +order by key +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `count(1)` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `gen_attr_0`, count(1) AS `gen_attr_1`, count(1) AS `gen_attr_2` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_4` FROM `default`.`src`) AS gen_subquery_0 GROUP BY `gen_attr_0` HAVING (`gen_attr_2` IN (SELECT `gen_attr_5` AS `_c0` FROM (SELECT `gen_attr_3` AS `gen_attr_5` FROM (SELECT count(1) AS `gen_attr_3` FROM (SELECT `key` AS `gen_attr_6`, `value` AS `gen_attr_7` FROM `default`.`src`) AS gen_subquery_3 WHERE (CAST(`gen_attr_6` AS DOUBLE) = CAST('90' AS DOUBLE)) GROUP BY `gen_attr_6`) AS gen_subquery_2) AS gen_subquery_4))) AS gen_subquery_1 ORDER BY `gen_attr_0` ASC) AS src diff --git a/sql/hive/src/test/resources/sqlgen/subquery_in_having_2.sql b/sql/hive/src/test/resources/sqlgen/subquery_in_having_2.sql new file mode 100644 index 0000000000000..c3a122aa889b9 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/subquery_in_having_2.sql @@ -0,0 +1,10 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. 
+select b.key, min(b.value) +from src b +group by b.key +having b.key in (select a.key + from src a + where a.value > 'val_9' and a.value = min(b.value)) +order by b.key +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `min(value)` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `gen_attr_0`, min(`gen_attr_5`) AS `gen_attr_1`, min(`gen_attr_5`) AS `gen_attr_4` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_5` FROM `default`.`src`) AS gen_subquery_0 GROUP BY `gen_attr_0` HAVING (struct(`gen_attr_0`, `gen_attr_4`) IN (SELECT `gen_attr_6` AS `_c0`, `gen_attr_7` AS `_c1` FROM (SELECT `gen_attr_2` AS `gen_attr_6`, `gen_attr_3` AS `gen_attr_7` FROM (SELECT `gen_attr_2`, `gen_attr_3` FROM (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_3` FROM `default`.`src`) AS gen_subquery_3 WHERE (`gen_attr_3` > 'val_9')) AS gen_subquery_2) AS gen_subquery_4))) AS gen_subquery_1 ORDER BY `gen_attr_0` ASC) AS b diff --git a/sql/hive/src/test/resources/sqlgen/subquery_not_exists_1.sql b/sql/hive/src/test/resources/sqlgen/subquery_not_exists_1.sql new file mode 100644 index 0000000000000..eed20a5d311f3 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/subquery_not_exists_1.sql @@ -0,0 +1,8 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +select * +from src b +where not exists (select a.key + from src a + where b.value = a.value and a.key = b.key and a.value > 'val_2') +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `value` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_1` FROM `default`.`src`) AS gen_subquery_0 WHERE (NOT EXISTS(SELECT `gen_attr_4` AS `1` FROM (SELECT 1 AS `gen_attr_4` FROM (SELECT `gen_attr_3`, `gen_attr_2` FROM (SELECT `key` AS `gen_attr_3`, `value` AS `gen_attr_2` FROM `default`.`src`) AS gen_subquery_2 WHERE (`gen_attr_2` > 'val_2')) AS gen_subquery_1 WHERE ((`gen_attr_1` = `gen_attr_2`) AND (`gen_attr_3` = `gen_attr_0`))) AS gen_subquery_3))) AS b diff --git a/sql/hive/src/test/resources/sqlgen/subquery_not_exists_2.sql b/sql/hive/src/test/resources/sqlgen/subquery_not_exists_2.sql new file mode 100644 index 0000000000000..7040e106e7ba2 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/subquery_not_exists_2.sql @@ -0,0 +1,8 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. 
+select * +from src b +where not exists (select a.key + from src a + where b.value = a.value and a.value > 'val_2') +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `value` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_1` FROM `default`.`src`) AS gen_subquery_0 WHERE (NOT EXISTS(SELECT `gen_attr_3` AS `1` FROM (SELECT 1 AS `gen_attr_3` FROM (SELECT `gen_attr_4`, `gen_attr_2` FROM (SELECT `key` AS `gen_attr_4`, `value` AS `gen_attr_2` FROM `default`.`src`) AS gen_subquery_2 WHERE (`gen_attr_2` > 'val_2')) AS gen_subquery_1 WHERE (`gen_attr_1` = `gen_attr_2`)) AS gen_subquery_3))) AS b diff --git a/sql/hive/src/test/resources/sqlgen/subquery_not_exists_having_1.sql b/sql/hive/src/test/resources/sqlgen/subquery_not_exists_having_1.sql new file mode 100644 index 0000000000000..3c0e90ed42223 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/subquery_not_exists_having_1.sql @@ -0,0 +1,9 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +select * +from src b +group by key, value +having not exists (select a.key + from src a + where b.value = a.value and a.key = b.key and a.value > 'val_12') +-------------------------------------------------------------------------------- +SELECT `gen_attr_3` AS `key`, `gen_attr_0` AS `value` FROM (SELECT `gen_attr_3`, `gen_attr_0` FROM (SELECT `key` AS `gen_attr_3`, `value` AS `gen_attr_0` FROM `default`.`src`) AS gen_subquery_0 GROUP BY `gen_attr_3`, `gen_attr_0` HAVING (NOT EXISTS(SELECT `gen_attr_4` AS `1` FROM (SELECT 1 AS `gen_attr_4` FROM (SELECT `gen_attr_2`, `gen_attr_1` FROM (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_1` FROM `default`.`src`) AS gen_subquery_2 WHERE (`gen_attr_1` > 'val_12')) AS gen_subquery_1 WHERE ((`gen_attr_0` = `gen_attr_1`) AND (`gen_attr_2` = `gen_attr_3`))) AS gen_subquery_3))) AS b diff --git a/sql/hive/src/test/resources/sqlgen/subquery_not_exists_having_2.sql b/sql/hive/src/test/resources/sqlgen/subquery_not_exists_having_2.sql new file mode 100644 index 0000000000000..0c16f9e58b9b9 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/subquery_not_exists_having_2.sql @@ -0,0 +1,9 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +select * +from src b +group by key, value +having not exists (select distinct a.key + from src a + where b.value = a.value and a.value > 'val_12') +-------------------------------------------------------------------------------- +SELECT `gen_attr_2` AS `key`, `gen_attr_0` AS `value` FROM (SELECT `gen_attr_2`, `gen_attr_0` FROM (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_0` FROM `default`.`src`) AS gen_subquery_0 GROUP BY `gen_attr_2`, `gen_attr_0` HAVING (NOT EXISTS(SELECT `gen_attr_3` AS `1` FROM (SELECT 1 AS `gen_attr_3` FROM (SELECT DISTINCT `gen_attr_4`, `gen_attr_1` FROM (SELECT `key` AS `gen_attr_4`, `value` AS `gen_attr_1` FROM `default`.`src`) AS gen_subquery_2 WHERE (`gen_attr_1` > 'val_12')) AS gen_subquery_1 WHERE (`gen_attr_0` = `gen_attr_1`)) AS gen_subquery_3))) AS b diff --git a/sql/hive/src/test/resources/sqlgen/tablesample_1.sql b/sql/hive/src/test/resources/sqlgen/tablesample_1.sql new file mode 100644 index 0000000000000..291f2f59d7378 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/tablesample_1.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. 
+SELECT s.id FROM parquet_t0 TABLESAMPLE(100 PERCENT) s +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `id` FROM (SELECT `gen_attr_0` FROM (SELECT `id` AS `gen_attr_0` FROM `default`.`parquet_t0` TABLESAMPLE(100.0 PERCENT)) AS gen_subquery_0) AS s diff --git a/sql/hive/src/test/resources/sqlgen/tablesample_2.sql b/sql/hive/src/test/resources/sqlgen/tablesample_2.sql new file mode 100644 index 0000000000000..6a92d7aef72f1 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/tablesample_2.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT * FROM parquet_t0 TABLESAMPLE(100 PERCENT) +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `id` FROM (SELECT `gen_attr_0` FROM (SELECT `id` AS `gen_attr_0` FROM `default`.`parquet_t0` TABLESAMPLE(100.0 PERCENT)) AS gen_subquery_0) AS parquet_t0 diff --git a/sql/hive/src/test/resources/sqlgen/tablesample_3.sql b/sql/hive/src/test/resources/sqlgen/tablesample_3.sql new file mode 100644 index 0000000000000..4a17d7105eec6 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/tablesample_3.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT s.id FROM t0 TABLESAMPLE(100 PERCENT) s +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `id` FROM (SELECT `gen_attr_0` FROM (SELECT `id` AS `gen_attr_0` FROM `default`.`t0` TABLESAMPLE(100.0 PERCENT)) AS gen_subquery_0) AS s diff --git a/sql/hive/src/test/resources/sqlgen/tablesample_4.sql b/sql/hive/src/test/resources/sqlgen/tablesample_4.sql new file mode 100644 index 0000000000000..873de051a6bd5 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/tablesample_4.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT * FROM t0 TABLESAMPLE(100 PERCENT) +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `id` FROM (SELECT `gen_attr_0` FROM (SELECT `id` AS `gen_attr_0` FROM `default`.`t0` TABLESAMPLE(100.0 PERCENT)) AS gen_subquery_0) AS t0 diff --git a/sql/hive/src/test/resources/sqlgen/tablesample_5.sql b/sql/hive/src/test/resources/sqlgen/tablesample_5.sql new file mode 100644 index 0000000000000..f958b2f111ba2 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/tablesample_5.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT s.id FROM parquet_t0 TABLESAMPLE(0.1 PERCENT) s WHERE 1=0 +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `id` FROM (SELECT `gen_attr_0` FROM (SELECT `id` AS `gen_attr_0` FROM `default`.`parquet_t0` TABLESAMPLE(0.1 PERCENT)) AS gen_subquery_0 WHERE (1 = 0)) AS s diff --git a/sql/hive/src/test/resources/sqlgen/tablesample_6.sql b/sql/hive/src/test/resources/sqlgen/tablesample_6.sql new file mode 100644 index 0000000000000..688a102d1da4e --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/tablesample_6.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. 
+SELECT * FROM parquet_t0 TABLESAMPLE(0.1 PERCENT) WHERE 1=0 +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `id` FROM (SELECT `gen_attr_0` FROM (SELECT `id` AS `gen_attr_0` FROM `default`.`parquet_t0` TABLESAMPLE(0.1 PERCENT)) AS gen_subquery_0 WHERE (1 = 0)) AS parquet_t0 diff --git a/sql/hive/src/test/resources/sqlgen/three_child_union.sql b/sql/hive/src/test/resources/sqlgen/three_child_union.sql new file mode 100644 index 0000000000000..713c7502f5a1a --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/three_child_union.sql @@ -0,0 +1,6 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT id FROM parquet_t0 +UNION ALL SELECT id FROM parquet_t0 +UNION ALL SELECT id FROM parquet_t0 +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `id` FROM ((SELECT `gen_attr_0` FROM (SELECT `id` AS `gen_attr_0` FROM `default`.`parquet_t0`) AS gen_subquery_0) UNION ALL (SELECT `gen_attr_0` FROM (SELECT `id` AS `gen_attr_0` FROM `default`.`parquet_t0`) AS gen_subquery_1) UNION ALL (SELECT `gen_attr_0` FROM (SELECT `id` AS `gen_attr_0` FROM `default`.`parquet_t0`) AS gen_subquery_2)) AS parquet_t0 diff --git a/sql/hive/src/test/resources/sqlgen/type_widening.sql b/sql/hive/src/test/resources/sqlgen/type_widening.sql new file mode 100644 index 0000000000000..ebb8a92afd345 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/type_widening.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT id FROM parquet_t0 UNION ALL SELECT CAST(id AS INT) AS id FROM parquet_t0 +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `id` FROM ((SELECT `gen_attr_0` FROM (SELECT `id` AS `gen_attr_0` FROM `default`.`parquet_t0`) AS gen_subquery_0) UNION ALL (SELECT CAST(CAST(`gen_attr_0` AS INT) AS BIGINT) AS `gen_attr_1` FROM (SELECT `id` AS `gen_attr_0` FROM `default`.`parquet_t0`) AS gen_subquery_1)) AS parquet_t0 diff --git a/sql/hive/src/test/resources/sqlgen/union_distinct.sql b/sql/hive/src/test/resources/sqlgen/union_distinct.sql new file mode 100644 index 0000000000000..46644b89ebb04 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/union_distinct.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT * FROM t0 UNION SELECT * FROM t0 +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `id` FROM ((SELECT `gen_attr_0` FROM (SELECT `id` AS `gen_attr_0` FROM `default`.`t0`) AS gen_subquery_0) UNION DISTINCT (SELECT `gen_attr_1` FROM (SELECT `id` AS `gen_attr_1` FROM `default`.`t0`) AS gen_subquery_1)) AS t0 diff --git a/sql/hive/src/test/resources/sqlgen/window_basic_1.sql b/sql/hive/src/test/resources/sqlgen/window_basic_1.sql new file mode 100644 index 0000000000000..000c4e735ac6e --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/window_basic_1.sql @@ -0,0 +1,4 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. 
+SELECT MAX(value) OVER (PARTITION BY key % 3) FROM parquet_t1 +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `max(value) OVER (PARTITION BY (key % CAST(3 AS BIGINT)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)` FROM (SELECT `gen_attr_0` FROM (SELECT gen_subquery_1.`gen_attr_1`, gen_subquery_1.`gen_attr_2`, max(`gen_attr_1`) OVER (PARTITION BY `gen_attr_2` ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS `gen_attr_0` FROM (SELECT `gen_attr_1`, (`gen_attr_3` % CAST(3 AS BIGINT)) AS `gen_attr_2` FROM (SELECT `key` AS `gen_attr_3`, `value` AS `gen_attr_1` FROM `default`.`parquet_t1`) AS gen_subquery_0) AS gen_subquery_1) AS gen_subquery_2) AS gen_subquery_3 diff --git a/sql/hive/src/test/resources/sqlgen/window_basic_2.sql b/sql/hive/src/test/resources/sqlgen/window_basic_2.sql new file mode 100644 index 0000000000000..ec55d4b7146f2 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/window_basic_2.sql @@ -0,0 +1,5 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT key, value, ROUND(AVG(key) OVER (), 2) +FROM parquet_t1 ORDER BY key +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `value`, `gen_attr_2` AS `round(avg(key) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), 2)` FROM (SELECT `gen_attr_0`, `gen_attr_1`, round(`gen_attr_3`, 2) AS `gen_attr_2` FROM (SELECT gen_subquery_1.`gen_attr_0`, gen_subquery_1.`gen_attr_1`, avg(`gen_attr_0`) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS `gen_attr_3` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_1` FROM `default`.`parquet_t1`) AS gen_subquery_0) AS gen_subquery_1) AS gen_subquery_2 ORDER BY `gen_attr_0` ASC) AS parquet_t1 diff --git a/sql/hive/src/test/resources/sqlgen/window_basic_3.sql b/sql/hive/src/test/resources/sqlgen/window_basic_3.sql new file mode 100644 index 0000000000000..c0ac9541e67ee --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/window_basic_3.sql @@ -0,0 +1,5 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT value, MAX(key + 1) OVER (PARTITION BY key % 5 ORDER BY key % 7) AS max +FROM parquet_t1 +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `value`, `gen_attr_1` AS `max` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT gen_subquery_1.`gen_attr_0`, gen_subquery_1.`gen_attr_2`, gen_subquery_1.`gen_attr_3`, gen_subquery_1.`gen_attr_4`, max(`gen_attr_2`) OVER (PARTITION BY `gen_attr_3` ORDER BY `gen_attr_4` ASC RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS `gen_attr_1` FROM (SELECT `gen_attr_0`, (`gen_attr_5` + CAST(1 AS BIGINT)) AS `gen_attr_2`, (`gen_attr_5` % CAST(5 AS BIGINT)) AS `gen_attr_3`, (`gen_attr_5` % CAST(7 AS BIGINT)) AS `gen_attr_4` FROM (SELECT `key` AS `gen_attr_5`, `value` AS `gen_attr_0` FROM `default`.`parquet_t1`) AS gen_subquery_0) AS gen_subquery_1) AS gen_subquery_2) AS parquet_t1 diff --git a/sql/hive/src/test/resources/sqlgen/window_with_join.sql b/sql/hive/src/test/resources/sqlgen/window_with_join.sql new file mode 100644 index 0000000000000..030a4c0907a1c --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/window_with_join.sql @@ -0,0 +1,5 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. 
+SELECT x.key, MAX(y.key) OVER (PARTITION BY x.key % 5 ORDER BY x.key) +FROM parquet_t1 x JOIN parquet_t1 y ON x.key = y.key +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `max(key) OVER (PARTITION BY (key % CAST(5 AS BIGINT)) ORDER BY key ASC RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT gen_subquery_2.`gen_attr_0`, gen_subquery_2.`gen_attr_2`, gen_subquery_2.`gen_attr_3`, max(`gen_attr_2`) OVER (PARTITION BY `gen_attr_3` ORDER BY `gen_attr_0` ASC RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS `gen_attr_1` FROM (SELECT `gen_attr_0`, `gen_attr_2`, (`gen_attr_0` % CAST(5 AS BIGINT)) AS `gen_attr_3` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_4` FROM `default`.`parquet_t1`) AS gen_subquery_0 INNER JOIN (SELECT `key` AS `gen_attr_2`, `value` AS `gen_attr_5` FROM `default`.`parquet_t1`) AS gen_subquery_1 ON (`gen_attr_0` = `gen_attr_2`)) AS gen_subquery_2) AS gen_subquery_3) AS x diff --git a/sql/hive/src/test/resources/sqlgen/window_with_the_same_window_with_agg.sql b/sql/hive/src/test/resources/sqlgen/window_with_the_same_window_with_agg.sql new file mode 100644 index 0000000000000..7b99539a05480 --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/window_with_the_same_window_with_agg.sql @@ -0,0 +1,7 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT key, value, +DENSE_RANK() OVER (DISTRIBUTE BY key SORT BY key, value) AS dr, +COUNT(key) +FROM parquet_t1 GROUP BY key, value +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `value`, `gen_attr_2` AS `dr`, `gen_attr_3` AS `count(key)` FROM (SELECT `gen_attr_0`, `gen_attr_1`, `gen_attr_2`, `gen_attr_3` FROM (SELECT gen_subquery_1.`gen_attr_0`, gen_subquery_1.`gen_attr_1`, gen_subquery_1.`gen_attr_3`, DENSE_RANK() OVER (PARTITION BY `gen_attr_0` ORDER BY `gen_attr_0` ASC, `gen_attr_1` ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS `gen_attr_2` FROM (SELECT `gen_attr_0`, `gen_attr_1`, count(`gen_attr_0`) AS `gen_attr_3` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_1` FROM `default`.`parquet_t1`) AS gen_subquery_0 GROUP BY `gen_attr_0`, `gen_attr_1`) AS gen_subquery_1) AS gen_subquery_2) AS parquet_t1 diff --git a/sql/hive/src/test/resources/sqlgen/window_with_the_same_window_with_agg_filter.sql b/sql/hive/src/test/resources/sqlgen/window_with_the_same_window_with_agg_filter.sql new file mode 100644 index 0000000000000..591a654a3888e --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/window_with_the_same_window_with_agg_filter.sql @@ -0,0 +1,7 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. 
+SELECT key, value, +DENSE_RANK() OVER (DISTRIBUTE BY key SORT BY key, value) AS dr, +COUNT(key) OVER(DISTRIBUTE BY key SORT BY key, value) AS ca +FROM parquet_t1 +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `value`, `gen_attr_2` AS `dr`, `gen_attr_3` AS `ca` FROM (SELECT `gen_attr_0`, `gen_attr_1`, `gen_attr_2`, `gen_attr_3` FROM (SELECT gen_subquery_1.`gen_attr_0`, gen_subquery_1.`gen_attr_1`, DENSE_RANK() OVER (PARTITION BY `gen_attr_0` ORDER BY `gen_attr_0` ASC, `gen_attr_1` ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS `gen_attr_2`, count(`gen_attr_0`) OVER (PARTITION BY `gen_attr_0` ORDER BY `gen_attr_0` ASC, `gen_attr_1` ASC RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS `gen_attr_3` FROM (SELECT `gen_attr_0`, `gen_attr_1` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_1` FROM `default`.`parquet_t1`) AS gen_subquery_0) AS gen_subquery_1) AS gen_subquery_2) AS parquet_t1 diff --git a/sql/hive/src/test/resources/sqlgen/window_with_the_same_window_with_agg_functions.sql b/sql/hive/src/test/resources/sqlgen/window_with_the_same_window_with_agg_functions.sql new file mode 100644 index 0000000000000..d9169eab6e46a --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/window_with_the_same_window_with_agg_functions.sql @@ -0,0 +1,6 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. +SELECT key, value, +MAX(value) OVER (PARTITION BY key % 5 ORDER BY key) AS max +FROM parquet_t1 GROUP BY key, value +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `value`, `gen_attr_2` AS `max` FROM (SELECT `gen_attr_0`, `gen_attr_1`, `gen_attr_2` FROM (SELECT gen_subquery_1.`gen_attr_0`, gen_subquery_1.`gen_attr_1`, gen_subquery_1.`gen_attr_3`, max(`gen_attr_1`) OVER (PARTITION BY `gen_attr_3` ORDER BY `gen_attr_0` ASC RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS `gen_attr_2` FROM (SELECT `gen_attr_0`, `gen_attr_1`, (`gen_attr_0` % CAST(5 AS BIGINT)) AS `gen_attr_3` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_1` FROM `default`.`parquet_t1`) AS gen_subquery_0 GROUP BY `gen_attr_0`, `gen_attr_1`) AS gen_subquery_1) AS gen_subquery_2) AS parquet_t1 diff --git a/sql/hive/src/test/resources/sqlgen/window_with_the_same_window_with_agg_having.sql b/sql/hive/src/test/resources/sqlgen/window_with_the_same_window_with_agg_having.sql new file mode 100644 index 0000000000000..f0a820811ee0a --- /dev/null +++ b/sql/hive/src/test/resources/sqlgen/window_with_the_same_window_with_agg_having.sql @@ -0,0 +1,6 @@ +-- This file is automatically generated by LogicalPlanToSQLSuite. 
+SELECT key, value, +MAX(value) OVER (PARTITION BY key % 5 ORDER BY key DESC) AS max +FROM parquet_t1 GROUP BY key, value HAVING key > 5 +-------------------------------------------------------------------------------- +SELECT `gen_attr_0` AS `key`, `gen_attr_1` AS `value`, `gen_attr_2` AS `max` FROM (SELECT `gen_attr_0`, `gen_attr_1`, `gen_attr_2` FROM (SELECT gen_subquery_1.`gen_attr_0`, gen_subquery_1.`gen_attr_1`, gen_subquery_1.`gen_attr_3`, max(`gen_attr_1`) OVER (PARTITION BY `gen_attr_3` ORDER BY `gen_attr_0` DESC RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS `gen_attr_2` FROM (SELECT `gen_attr_0`, `gen_attr_1`, (`gen_attr_0` % CAST(5 AS BIGINT)) AS `gen_attr_3` FROM (SELECT `key` AS `gen_attr_0`, `value` AS `gen_attr_1` FROM `default`.`parquet_t1`) AS gen_subquery_0 GROUP BY `gen_attr_0`, `gen_attr_1` HAVING (`gen_attr_0` > CAST(5 AS BIGINT))) AS gen_subquery_1) AS gen_subquery_2) AS parquet_t1 diff --git a/sql/hive/src/test/resources/test_script.sh b/sql/hive/src/test/resources/test_script.sh new file mode 100755 index 0000000000000..eb0c50e98292c --- /dev/null +++ b/sql/hive/src/test/resources/test_script.sh @@ -0,0 +1,23 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +while read line +do + echo "$line" | sed $'s/\t/_/' +done < /dev/stdin diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala index c8bf20d13bdba..56c0736e8e462 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala @@ -20,19 +20,28 @@ package org.apache.spark.sql.catalyst import java.sql.Timestamp import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.{If, Literal} - +import org.apache.spark.sql.catalyst.expressions.{If, Literal, SpecifiedWindowFrame, TimeAdd, + TimeSub, WindowSpecDefinition} +import org.apache.spark.unsafe.types.CalendarInterval class ExpressionSQLBuilderSuite extends SQLBuilderTest { test("literal") { - checkSQL(Literal("foo"), "\"foo\"") - checkSQL(Literal("\"foo\""), "\"\\\"foo\\\"\"") + checkSQL(Literal("foo"), "'foo'") + checkSQL(Literal("\"foo\""), "'\"foo\"'") + checkSQL(Literal("'foo'"), "'\\'foo\\''") checkSQL(Literal(1: Byte), "1Y") checkSQL(Literal(2: Short), "2S") checkSQL(Literal(4: Int), "4") checkSQL(Literal(8: Long), "8L") checkSQL(Literal(1.5F), "CAST(1.5 AS FLOAT)") + checkSQL(Literal(Float.PositiveInfinity), "CAST('Infinity' AS FLOAT)") + checkSQL(Literal(Float.NegativeInfinity), "CAST('-Infinity' AS FLOAT)") + checkSQL(Literal(Float.NaN), "CAST('NaN' AS FLOAT)") checkSQL(Literal(2.5D), "2.5D") + checkSQL(Literal(Double.PositiveInfinity), "CAST('Infinity' AS DOUBLE)") + checkSQL(Literal(Double.NegativeInfinity), "CAST('-Infinity' AS DOUBLE)") + checkSQL(Literal(Double.NaN), "CAST('NaN' AS DOUBLE)") + checkSQL(Literal(BigDecimal("10.0000000").underlying), "10.0000000BD") checkSQL( Literal(Timestamp.valueOf("2016-01-01 00:00:00")), "TIMESTAMP('2016-01-01 00:00:00.0')") // TODO tests for decimals @@ -76,7 +85,53 @@ class ExpressionSQLBuilderSuite extends SQLBuilderTest { checkSQL('a.int / 'b.int, "(`a` / `b`)") checkSQL('a.int % 'b.int, "(`a` % `b`)") - checkSQL(-'a.int, "(-`a`)") - checkSQL(-('a.int + 'b.int), "(-(`a` + `b`))") + checkSQL(-'a.int, "(- `a`)") + checkSQL(-('a.int + 'b.int), "(- (`a` + `b`))") + } + + test("window specification") { + val frame = SpecifiedWindowFrame.defaultWindowFrame( + hasOrderSpecification = true, + acceptWindowFrame = true + ) + + checkSQL( + WindowSpecDefinition('a.int :: Nil, Nil, frame), + s"(PARTITION BY `a` $frame)" + ) + + checkSQL( + WindowSpecDefinition('a.int :: 'b.string :: Nil, Nil, frame), + s"(PARTITION BY `a`, `b` $frame)" + ) + + checkSQL( + WindowSpecDefinition(Nil, 'a.int.asc :: Nil, frame), + s"(ORDER BY `a` ASC $frame)" + ) + + checkSQL( + WindowSpecDefinition(Nil, 'a.int.asc :: 'b.string.desc :: Nil, frame), + s"(ORDER BY `a` ASC, `b` DESC $frame)" + ) + + checkSQL( + WindowSpecDefinition('a.int :: 'b.string :: Nil, 'c.int.asc :: 'd.string.desc :: Nil, frame), + s"(PARTITION BY `a`, `b` ORDER BY `c` ASC, `d` DESC $frame)" + ) + } + + test("interval arithmetic") { + val interval = Literal(new CalendarInterval(0, CalendarInterval.MICROS_PER_DAY)) + + checkSQL( + TimeAdd('a, interval), + "`a` + interval 1 days" + ) + + checkSQL( + TimeSub('a, interval), + "`a` - interval 1 days" + ) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala index 
a7782abc39876..b4eb50e331cf9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala @@ -34,13 +34,13 @@ class ExpressionToSQLSuite extends SQLBuilderTest with SQLTestUtils { val bytes = Array[Byte](1, 2, 3, 4) Seq((bytes, "AQIDBA==")).toDF("a", "b").write.saveAsTable("t0") - sqlContext + spark .range(10) .select('id as 'key, concat(lit("val_"), 'id) as 'value) .write .saveAsTable("t1") - sqlContext.range(10).select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd).write.saveAsTable("t2") + spark.range(10).select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd).write.saveAsTable("t2") } override protected def afterAll(): Unit = { @@ -102,7 +102,6 @@ class ExpressionToSQLSuite extends SQLBuilderTest with SQLTestUtils { checkSqlGeneration("SELECT map(1, 'a', 2, 'b')") checkSqlGeneration("SELECT named_struct('c1',1,'c2',2,'c3',3)") checkSqlGeneration("SELECT nanvl(a, 5), nanvl(b, 10), nanvl(d, c) from t2") - checkSqlGeneration("SELECT nvl(null, 1, 2)") checkSqlGeneration("SELECT rand(1)") checkSqlGeneration("SELECT randn(3)") checkSqlGeneration("SELECT struct(1,2,3)") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala index 34c2773581ff8..3794f63a6a361 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala @@ -17,32 +17,53 @@ package org.apache.spark.sql.catalyst +import java.nio.charset.StandardCharsets +import java.nio.file.{Files, NoSuchFileException, Paths} + import scala.util.control.NonFatal import org.apache.spark.sql.Column +import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.parser.ParseException +import org.apache.spark.sql.catalyst.plans.logical.LeafNode import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils +/** + * A test suite for LogicalPlan-to-SQL conversion. + * + * Each query has a golden generated SQL file in test/resources/sqlgen. The test suite also has + * built-in functionality to automatically generate these golden files. 
+ * + * To re-generate golden files, run: + * SPARK_GENERATE_GOLDEN_FILES=1 build/sbt "hive/test-only *LogicalPlanToSQLSuite" + */ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { import testImplicits._ + // Used for generating new query answer files by saving + private val regenerateGoldenFiles: Boolean = System.getenv("SPARK_GENERATE_GOLDEN_FILES") == "1" + private val goldenSQLPath = "src/test/resources/sqlgen/" + protected override def beforeAll(): Unit = { super.beforeAll() - sql("DROP TABLE IF EXISTS parquet_t0") - sql("DROP TABLE IF EXISTS parquet_t1") - sql("DROP TABLE IF EXISTS parquet_t2") + (0 to 3).foreach { i => + sql(s"DROP TABLE IF EXISTS parquet_t$i") + } sql("DROP TABLE IF EXISTS t0") - sqlContext.range(10).write.saveAsTable("parquet_t0") + spark.range(10).write.saveAsTable("parquet_t0") sql("CREATE TABLE t0 AS SELECT * FROM parquet_t0") - sqlContext + spark .range(10) .select('id as 'key, concat(lit("val_"), 'id) as 'value) .write .saveAsTable("parquet_t1") - sqlContext + spark .range(10) .select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd) .write @@ -52,7 +73,7 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { when(id % 3 === 0, lit(null)).otherwise(array('id, 'id + 1)) } - sqlContext + spark .range(10) .select( createArray('id).as("arr"), @@ -66,32 +87,66 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { override protected def afterAll(): Unit = { try { - sql("DROP TABLE IF EXISTS parquet_t0") - sql("DROP TABLE IF EXISTS parquet_t1") - sql("DROP TABLE IF EXISTS parquet_t2") - sql("DROP TABLE IF EXISTS parquet_t3") + (0 to 3).foreach { i => + sql(s"DROP TABLE IF EXISTS parquet_t$i") + } sql("DROP TABLE IF EXISTS t0") } finally { super.afterAll() } } - private def checkHiveQl(hiveQl: String): Unit = { - val df = sql(hiveQl) + /** + * Compare the generated SQL with the expected answer string. + */ + private def checkSQLStructure(originalSQL: String, convertedSQL: String, answerFile: String) = { + if (answerFile != null) { + val separator = "-" * 80 + if (regenerateGoldenFiles) { + val path = Paths.get(s"$goldenSQLPath/$answerFile.sql") + val header = "-- This file is automatically generated by LogicalPlanToSQLSuite." + val answerText = s"$header\n${originalSQL.trim()}\n${separator}\n$convertedSQL\n" + Files.write(path, answerText.getBytes(StandardCharsets.UTF_8)) + } else { + val goldenFileName = s"sqlgen/$answerFile.sql" + val resourceFile = getClass.getClassLoader.getResource(goldenFileName) + if (resourceFile == null) { + throw new NoSuchFileException(goldenFileName) + } + val path = resourceFile.getPath + val answerText = new String(Files.readAllBytes(Paths.get(path)), StandardCharsets.UTF_8) + val sqls = answerText.split(separator) + assert(sqls.length == 2, "Golden sql files should have a separator.") + val expectedSQL = sqls(1).trim() + assert(convertedSQL == expectedSQL) + } + } + } + + /** + * 1. Checks if SQL parsing succeeds. + * 2. Checks if SQL generation succeeds. + * 3. Checks the generated SQL against golden files. + * 4. Verifies the execution result stays the same. 
+ */ + private def checkSQL(sqlString: String, answerFile: String = null): Unit = { + val df = sql(sqlString) val convertedSQL = try new SQLBuilder(df).toSQL catch { case NonFatal(e) => fail( - s"""Cannot convert the following HiveQL query plan back to SQL query string: + s"""Cannot convert the following SQL query plan back to SQL query string: | - |# Original HiveQL query string: - |$hiveQl + |# Original SQL query string: + |$sqlString | |# Resolved query plan: |${df.queryExecution.analyzed.treeString} """.stripMargin, e) } + checkSQLStructure(sqlString, convertedSQL, answerFile) + try { checkAnswer(sql(convertedSQL), df) } catch { case cause: Throwable => @@ -101,8 +156,8 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { |# Converted SQL query string: |$convertedSQL | - |# Original HiveQL query string: - |$hiveQl + |# Original SQL query string: + |$sqlString | |# Resolved query plan: |${df.queryExecution.analyzed.treeString} @@ -110,24 +165,66 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { } } + // When saving golden files, these tests should be ignored to prevent making files. + if (!regenerateGoldenFiles) { + test("Test should fail if the SQL query cannot be parsed") { + val m = intercept[ParseException] { + checkSQL("SELE", "NOT_A_FILE") + }.getMessage + assert(m.contains("mismatched input")) + } + + test("Test should fail if the golden file cannot be found") { + val m2 = intercept[NoSuchFileException] { + checkSQL("SELECT 1", "NOT_A_FILE") + }.getMessage + assert(m2.contains("NOT_A_FILE")) + } + + test("Test should fail if the SQL query cannot be regenerated") { + case class Unsupported() extends LeafNode with MultiInstanceRelation { + override def newInstance(): Unsupported = copy() + override def output: Seq[Attribute] = Nil + } + Unsupported().createOrReplaceTempView("not_sql_gen_supported_table_so_far") + sql("select * from not_sql_gen_supported_table_so_far") + val m3 = intercept[org.scalatest.exceptions.TestFailedException] { + checkSQL("select * from not_sql_gen_supported_table_so_far", "in") + }.getMessage + assert(m3.contains("Cannot convert the following SQL query plan back to SQL query string")) + } + + test("Test should fail if the SQL query did not equal to the golden SQL") { + val m4 = intercept[org.scalatest.exceptions.TestFailedException] { + checkSQL("SELECT 1", "in") + }.getMessage + assert(m4.contains("did not equal")) + } + } + + test("range") { + checkSQL("select * from range(100)", "range") + checkSQL("select * from range(1, 100, 20, 10)", "range_with_splits") + } + test("in") { - checkHiveQl("SELECT id FROM parquet_t0 WHERE id IN (1, 2, 3)") + checkSQL("SELECT id FROM parquet_t0 WHERE id IN (1, 2, 3)", "in") } test("not in") { - checkHiveQl("SELECT id FROM t0 WHERE id NOT IN (1, 2, 3)") + checkSQL("SELECT id FROM t0 WHERE id NOT IN (1, 2, 3)", "not_in") } test("not like") { - checkHiveQl("SELECT id FROM t0 WHERE id + 5 NOT LIKE '1%'") + checkSQL("SELECT id FROM t0 WHERE id + 5 NOT LIKE '1%'", "not_like") } test("aggregate function in having clause") { - checkHiveQl("SELECT COUNT(value) FROM parquet_t1 GROUP BY key HAVING MAX(key) > 0") + checkSQL("SELECT COUNT(value) FROM parquet_t1 GROUP BY key HAVING MAX(key) > 0", "agg1") } test("aggregate function in order by clause") { - checkHiveQl("SELECT COUNT(value) FROM parquet_t1 GROUP BY key ORDER BY MAX(key)") + checkSQL("SELECT COUNT(value) FROM parquet_t1 GROUP BY key ORDER BY MAX(key)", "agg2") } // When there are multiple aggregate functions in ORDER BY clause, 
all of them are extracted into @@ -135,61 +232,67 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { // execution since these aliases have different expression ID. But this introduces name collision // when converting resolved plans back to SQL query strings as expression IDs are stripped. test("aggregate function in order by clause with multiple order keys") { - checkHiveQl("SELECT COUNT(value) FROM parquet_t1 GROUP BY key ORDER BY key, MAX(key)") + checkSQL("SELECT COUNT(value) FROM parquet_t1 GROUP BY key ORDER BY key, MAX(key)", "agg3") } test("type widening in union") { - checkHiveQl("SELECT id FROM parquet_t0 UNION ALL SELECT CAST(id AS INT) AS id FROM parquet_t0") + checkSQL("SELECT id FROM parquet_t0 UNION ALL SELECT CAST(id AS INT) AS id FROM parquet_t0", + "type_widening") } test("union distinct") { - checkHiveQl("SELECT * FROM t0 UNION SELECT * FROM t0") + checkSQL("SELECT * FROM t0 UNION SELECT * FROM t0", "union_distinct") } test("three-child union") { - checkHiveQl( + checkSQL( """ |SELECT id FROM parquet_t0 |UNION ALL SELECT id FROM parquet_t0 |UNION ALL SELECT id FROM parquet_t0 - """.stripMargin) + """.stripMargin, + "three_child_union") } test("intersect") { - checkHiveQl("SELECT * FROM t0 INTERSECT SELECT * FROM t0") + checkSQL("SELECT * FROM t0 INTERSECT SELECT * FROM t0", "intersect") } test("except") { - checkHiveQl("SELECT * FROM t0 EXCEPT SELECT * FROM t0") + checkSQL("SELECT * FROM t0 EXCEPT SELECT * FROM t0", "except") } test("self join") { - checkHiveQl("SELECT x.key FROM parquet_t1 x JOIN parquet_t1 y ON x.key = y.key") + checkSQL("SELECT x.key FROM parquet_t1 x JOIN parquet_t1 y ON x.key = y.key", "self_join") } test("self join with group by") { - checkHiveQl( - "SELECT x.key, COUNT(*) FROM parquet_t1 x JOIN parquet_t1 y ON x.key = y.key group by x.key") + checkSQL( + "SELECT x.key, COUNT(*) FROM parquet_t1 x JOIN parquet_t1 y ON x.key = y.key group by x.key", + "self_join_with_group_by") } test("case") { - checkHiveQl("SELECT CASE WHEN id % 2 > 0 THEN 0 WHEN id % 2 = 0 THEN 1 END FROM parquet_t0") + checkSQL("SELECT CASE WHEN id % 2 > 0 THEN 0 WHEN id % 2 = 0 THEN 1 END FROM parquet_t0", + "case") } test("case with else") { - checkHiveQl("SELECT CASE WHEN id % 2 > 0 THEN 0 ELSE 1 END FROM parquet_t0") + checkSQL("SELECT CASE WHEN id % 2 > 0 THEN 0 ELSE 1 END FROM parquet_t0", "case_with_else") } test("case with key") { - checkHiveQl("SELECT CASE id WHEN 0 THEN 'foo' WHEN 1 THEN 'bar' END FROM parquet_t0") + checkSQL("SELECT CASE id WHEN 0 THEN 'foo' WHEN 1 THEN 'bar' END FROM parquet_t0", + "case_with_key") } test("case with key and else") { - checkHiveQl("SELECT CASE id WHEN 0 THEN 'foo' WHEN 1 THEN 'bar' ELSE 'baz' END FROM parquet_t0") + checkSQL("SELECT CASE id WHEN 0 THEN 'foo' WHEN 1 THEN 'bar' ELSE 'baz' END FROM parquet_t0", + "case_with_key_and_else") } test("select distinct without aggregate functions") { - checkHiveQl("SELECT DISTINCT id FROM parquet_t0") + checkSQL("SELECT DISTINCT id FROM parquet_t0", "select_distinct") } test("rollup/cube #1") { @@ -213,146 +316,195 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { // FROM `default`.`t1` // GROUP BY (`t1`.`key` % CAST(5 AS BIGINT)) // GROUPING SETS (((`t1`.`key` % CAST(5 AS BIGINT))), ()) - checkHiveQl( - "SELECT count(*) as cnt, key%5, grouping_id() FROM parquet_t1 GROUP BY key % 5 WITH ROLLUP") - checkHiveQl( - "SELECT count(*) as cnt, key%5, grouping_id() FROM parquet_t1 GROUP BY key % 5 WITH CUBE") + checkSQL( + "SELECT count(*) as cnt, key%5, 
grouping_id() FROM parquet_t1 GROUP BY key % 5 WITH ROLLUP", + "rollup_cube_1_1") + + checkSQL( + "SELECT count(*) as cnt, key%5, grouping_id() FROM parquet_t1 GROUP BY key % 5 WITH CUBE", + "rollup_cube_1_2") } test("rollup/cube #2") { - checkHiveQl("SELECT key, value, count(value) FROM parquet_t1 GROUP BY key, value WITH ROLLUP") - checkHiveQl("SELECT key, value, count(value) FROM parquet_t1 GROUP BY key, value WITH CUBE") + checkSQL("SELECT key, value, count(value) FROM parquet_t1 GROUP BY key, value WITH ROLLUP", + "rollup_cube_2_1") + + checkSQL("SELECT key, value, count(value) FROM parquet_t1 GROUP BY key, value WITH CUBE", + "rollup_cube_2_2") } test("rollup/cube #3") { - checkHiveQl( - "SELECT key, count(value), grouping_id() FROM parquet_t1 GROUP BY key, value WITH ROLLUP") - checkHiveQl( - "SELECT key, count(value), grouping_id() FROM parquet_t1 GROUP BY key, value WITH CUBE") + checkSQL( + "SELECT key, count(value), grouping_id() FROM parquet_t1 GROUP BY key, value WITH ROLLUP", + "rollup_cube_3_1") + + checkSQL( + "SELECT key, count(value), grouping_id() FROM parquet_t1 GROUP BY key, value WITH CUBE", + "rollup_cube_3_2") } test("rollup/cube #4") { - checkHiveQl( + checkSQL( s""" |SELECT count(*) as cnt, key % 5 as k1, key - 5 as k2, grouping_id() FROM parquet_t1 |GROUP BY key % 5, key - 5 WITH ROLLUP - """.stripMargin) - checkHiveQl( + """.stripMargin, + "rollup_cube_4_1") + + checkSQL( s""" |SELECT count(*) as cnt, key % 5 as k1, key - 5 as k2, grouping_id() FROM parquet_t1 |GROUP BY key % 5, key - 5 WITH CUBE - """.stripMargin) + """.stripMargin, + "rollup_cube_4_2") } test("rollup/cube #5") { - checkHiveQl( + checkSQL( s""" |SELECT count(*) AS cnt, key % 5 AS k1, key - 5 AS k2, grouping_id(key % 5, key - 5) AS k3 |FROM (SELECT key, key%2, key - 5 FROM parquet_t1) t GROUP BY key%5, key-5 |WITH ROLLUP - """.stripMargin) - checkHiveQl( + """.stripMargin, + "rollup_cube_5_1") + + checkSQL( s""" |SELECT count(*) AS cnt, key % 5 AS k1, key - 5 AS k2, grouping_id(key % 5, key - 5) AS k3 |FROM (SELECT key, key % 2, key - 5 FROM parquet_t1) t GROUP BY key % 5, key - 5 |WITH CUBE - """.stripMargin) + """.stripMargin, + "rollup_cube_5_2") } test("rollup/cube #6") { - checkHiveQl("SELECT a, b, sum(c) FROM parquet_t2 GROUP BY ROLLUP(a, b) ORDER BY a, b") - checkHiveQl("SELECT a, b, sum(c) FROM parquet_t2 GROUP BY CUBE(a, b) ORDER BY a, b") - checkHiveQl("SELECT a, b, sum(a) FROM parquet_t2 GROUP BY ROLLUP(a, b) ORDER BY a, b") - checkHiveQl("SELECT a, b, sum(a) FROM parquet_t2 GROUP BY CUBE(a, b) ORDER BY a, b") - checkHiveQl("SELECT a + b, b, sum(a - b) FROM parquet_t2 GROUP BY a + b, b WITH ROLLUP") - checkHiveQl("SELECT a + b, b, sum(a - b) FROM parquet_t2 GROUP BY a + b, b WITH CUBE") + checkSQL("SELECT a, b, sum(c) FROM parquet_t2 GROUP BY ROLLUP(a, b) ORDER BY a, b", + "rollup_cube_6_1") + + checkSQL("SELECT a, b, sum(c) FROM parquet_t2 GROUP BY CUBE(a, b) ORDER BY a, b", + "rollup_cube_6_2") + + checkSQL("SELECT a, b, sum(a) FROM parquet_t2 GROUP BY ROLLUP(a, b) ORDER BY a, b", + "rollup_cube_6_3") + + checkSQL("SELECT a, b, sum(a) FROM parquet_t2 GROUP BY CUBE(a, b) ORDER BY a, b", + "rollup_cube_6_4") + + checkSQL("SELECT a + b, b, sum(a - b) FROM parquet_t2 GROUP BY a + b, b WITH ROLLUP", + "rollup_cube_6_5") + + checkSQL("SELECT a + b, b, sum(a - b) FROM parquet_t2 GROUP BY a + b, b WITH CUBE", + "rollup_cube_6_6") } test("rollup/cube #7") { - checkHiveQl("SELECT a, b, grouping_id(a, b) FROM parquet_t2 GROUP BY cube(a, b)") - checkHiveQl("SELECT a, b, grouping(b) FROM 
parquet_t2 GROUP BY cube(a, b)") - checkHiveQl("SELECT a, b, grouping(a) FROM parquet_t2 GROUP BY cube(a, b)") + checkSQL("SELECT a, b, grouping_id(a, b) FROM parquet_t2 GROUP BY cube(a, b)", + "rollup_cube_7_1") + + checkSQL("SELECT a, b, grouping(b) FROM parquet_t2 GROUP BY cube(a, b)", + "rollup_cube_7_2") + + checkSQL("SELECT a, b, grouping(a) FROM parquet_t2 GROUP BY cube(a, b)", + "rollup_cube_7_3") } test("rollup/cube #8") { // grouping_id() is part of another expression - checkHiveQl( + checkSQL( s""" |SELECT hkey AS k1, value - 5 AS k2, hash(grouping_id()) AS hgid |FROM (SELECT hash(key) as hkey, key as value FROM parquet_t1) t GROUP BY hkey, value-5 |WITH ROLLUP - """.stripMargin) - checkHiveQl( + """.stripMargin, + "rollup_cube_8_1") + + checkSQL( s""" |SELECT hkey AS k1, value - 5 AS k2, hash(grouping_id()) AS hgid |FROM (SELECT hash(key) as hkey, key as value FROM parquet_t1) t GROUP BY hkey, value-5 |WITH CUBE - """.stripMargin) + """.stripMargin, + "rollup_cube_8_2") } test("rollup/cube #9") { // self join is used as the child node of ROLLUP/CUBE with replaced quantifiers - checkHiveQl( + checkSQL( s""" |SELECT t.key - 5, cnt, SUM(cnt) |FROM (SELECT x.key, COUNT(*) as cnt |FROM parquet_t1 x JOIN parquet_t1 y ON x.key = y.key GROUP BY x.key) t |GROUP BY cnt, t.key - 5 |WITH ROLLUP - """.stripMargin) - checkHiveQl( + """.stripMargin, + "rollup_cube_9_1") + + checkSQL( s""" |SELECT t.key - 5, cnt, SUM(cnt) |FROM (SELECT x.key, COUNT(*) as cnt |FROM parquet_t1 x JOIN parquet_t1 y ON x.key = y.key GROUP BY x.key) t |GROUP BY cnt, t.key - 5 |WITH CUBE - """.stripMargin) + """.stripMargin, + "rollup_cube_9_2") } test("grouping sets #1") { - checkHiveQl( + checkSQL( s""" |SELECT count(*) AS cnt, key % 5 AS k1, key - 5 AS k2, grouping_id() AS k3 |FROM (SELECT key, key % 2, key - 5 FROM parquet_t1) t GROUP BY key % 5, key - 5 |GROUPING SETS (key % 5, key - 5) - """.stripMargin) + """.stripMargin, + "grouping_sets_1") } test("grouping sets #2") { - checkHiveQl( - "SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b GROUPING SETS (a, b) ORDER BY a, b") - checkHiveQl( - "SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b GROUPING SETS (a) ORDER BY a, b") - checkHiveQl( - "SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b GROUPING SETS (b) ORDER BY a, b") - checkHiveQl( - "SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b GROUPING SETS (()) ORDER BY a, b") - checkHiveQl( + checkSQL( + "SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b GROUPING SETS (a, b) ORDER BY a, b", + "grouping_sets_2_1") + + checkSQL( + "SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b GROUPING SETS (a) ORDER BY a, b", + "grouping_sets_2_2") + + checkSQL( + "SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b GROUPING SETS (b) ORDER BY a, b", + "grouping_sets_2_3") + + checkSQL( + "SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b GROUPING SETS (()) ORDER BY a, b", + "grouping_sets_2_4") + + checkSQL( s""" |SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b |GROUPING SETS ((), (a), (a, b)) ORDER BY a, b - """.stripMargin) + """.stripMargin, + "grouping_sets_2_5") } test("cluster by") { - checkHiveQl("SELECT id FROM parquet_t0 CLUSTER BY id") + checkSQL("SELECT id FROM parquet_t0 CLUSTER BY id", "cluster_by") } test("distribute by") { - checkHiveQl("SELECT id FROM parquet_t0 DISTRIBUTE BY id") + checkSQL("SELECT id FROM parquet_t0 DISTRIBUTE BY id", "distribute_by") } test("distribute by with sort by") { - checkHiveQl("SELECT id FROM parquet_t0 DISTRIBUTE BY id SORT BY id") + checkSQL("SELECT id FROM parquet_t0 DISTRIBUTE 
BY id SORT BY id", + "distribute_by_with_sort_by") } test("SPARK-13720: sort by after having") { - checkHiveQl("SELECT COUNT(value) FROM parquet_t1 GROUP BY key HAVING MAX(key) > 0 SORT BY key") + checkSQL("SELECT COUNT(value) FROM parquet_t1 GROUP BY key HAVING MAX(key) > 0 SORT BY key", + "sort_by_after_having") } test("distinct aggregation") { - checkHiveQl("SELECT COUNT(DISTINCT id) FROM parquet_t0") + checkSQL("SELECT COUNT(DISTINCT id) FROM parquet_t0", "distinct_aggregation") } test("TABLESAMPLE") { @@ -361,82 +513,89 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { // +- Subquery s // +- Subquery parquet_t0 // +- Relation[id#2L] ParquetRelation - checkHiveQl("SELECT s.id FROM parquet_t0 TABLESAMPLE(100 PERCENT) s") + checkSQL("SELECT s.id FROM parquet_t0 TABLESAMPLE(100 PERCENT) s", "tablesample_1") // Project [id#2L] // +- Sample 0.0, 1.0, false, ... // +- Subquery parquet_t0 // +- Relation[id#2L] ParquetRelation - checkHiveQl("SELECT * FROM parquet_t0 TABLESAMPLE(100 PERCENT)") + checkSQL("SELECT * FROM parquet_t0 TABLESAMPLE(100 PERCENT)", "tablesample_2") // Project [id#21L] // +- Sample 0.0, 1.0, false, ... // +- MetastoreRelation default, t0, Some(s) - checkHiveQl("SELECT s.id FROM t0 TABLESAMPLE(100 PERCENT) s") + checkSQL("SELECT s.id FROM t0 TABLESAMPLE(100 PERCENT) s", "tablesample_3") // Project [id#24L] // +- Sample 0.0, 1.0, false, ... // +- MetastoreRelation default, t0, None - checkHiveQl("SELECT * FROM t0 TABLESAMPLE(100 PERCENT)") + checkSQL("SELECT * FROM t0 TABLESAMPLE(100 PERCENT)", "tablesample_4") // When a sampling fraction is not 100%, the returned results are random. // Thus, added an always-false filter here to check if the generated plan can be successfully // executed. - checkHiveQl("SELECT s.id FROM parquet_t0 TABLESAMPLE(0.1 PERCENT) s WHERE 1=0") - checkHiveQl("SELECT * FROM parquet_t0 TABLESAMPLE(0.1 PERCENT) WHERE 1=0") + checkSQL("SELECT s.id FROM parquet_t0 TABLESAMPLE(0.1 PERCENT) s WHERE 1=0", "tablesample_5") + checkSQL("SELECT * FROM parquet_t0 TABLESAMPLE(0.1 PERCENT) WHERE 1=0", "tablesample_6") } test("multi-distinct columns") { - checkHiveQl("SELECT a, COUNT(DISTINCT b), COUNT(DISTINCT c), SUM(d) FROM parquet_t2 GROUP BY a") + checkSQL("SELECT a, COUNT(DISTINCT b), COUNT(DISTINCT c), SUM(d) FROM parquet_t2 GROUP BY a", + "multi_distinct") } test("persisted data source relations") { Seq("orc", "json", "parquet").foreach { format => val tableName = s"${format}_parquet_t0" withTable(tableName) { - sqlContext.range(10).write.format(format).saveAsTable(tableName) - checkHiveQl(s"SELECT id FROM $tableName") + spark.range(10).write.format(format).saveAsTable(tableName) + checkSQL(s"SELECT id FROM $tableName", s"data_source_$tableName") } } } test("script transformation - schemaless") { - checkHiveQl("SELECT TRANSFORM (a, b, c, d) USING 'cat' FROM parquet_t2") - checkHiveQl("SELECT TRANSFORM (*) USING 'cat' FROM parquet_t2") + checkSQL("SELECT TRANSFORM (a, b, c, d) USING 'cat' FROM parquet_t2", + "script_transformation_1") + checkSQL("SELECT TRANSFORM (*) USING 'cat' FROM parquet_t2", + "script_transformation_2") } test("script transformation - alias list") { - checkHiveQl("SELECT TRANSFORM (a, b, c, d) USING 'cat' AS (d1, d2, d3, d4) FROM parquet_t2") + checkSQL("SELECT TRANSFORM (a, b, c, d) USING 'cat' AS (d1, d2, d3, d4) FROM parquet_t2", + "script_transformation_alias_list") } test("script transformation - alias list with type") { - checkHiveQl( + checkSQL( """FROM |(FROM parquet_t1 SELECT TRANSFORM(key, value) USING 
'cat' AS (thing1 int, thing2 string)) t |SELECT thing1 + 1 - """.stripMargin) + """.stripMargin, + "script_transformation_alias_list_with_type") } test("script transformation - row format delimited clause with only one format property") { - checkHiveQl( + checkSQL( """SELECT TRANSFORM (key) ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' |USING 'cat' AS (tKey) ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' |FROM parquet_t1 - """.stripMargin) + """.stripMargin, + "script_transformation_row_format_one") } test("script transformation - row format delimited clause with multiple format properties") { - checkHiveQl( + checkSQL( """SELECT TRANSFORM (key) |ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' LINES TERMINATED BY '\t' |USING 'cat' AS (tKey) |ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' LINES TERMINATED BY '\t' |FROM parquet_t1 - """.stripMargin) + """.stripMargin, + "script_transformation_row_format_multiple") } test("script transformation - row format serde clauses with SERDEPROPERTIES") { - checkHiveQl( + checkSQL( """SELECT TRANSFORM (key, value) |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' |WITH SERDEPROPERTIES('field.delim' = '|') @@ -444,27 +603,29 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' |WITH SERDEPROPERTIES('field.delim' = '|') |FROM parquet_t1 - """.stripMargin) + """.stripMargin, + "script_transformation_row_format_serde") } test("script transformation - row format serde clauses without SERDEPROPERTIES") { - checkHiveQl( + checkSQL( """SELECT TRANSFORM (key, value) |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' |USING 'cat' AS (tKey, tValue) |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' |FROM parquet_t1 - """.stripMargin) + """.stripMargin, + "script_transformation_row_format_without_serde") } test("plans with non-SQL expressions") { - sqlContext.udf.register("foo", (_: Int) * 2) + spark.udf.register("foo", (_: Int) * 2) intercept[UnsupportedOperationException](new SQLBuilder(sql("SELECT foo(id) FROM t0")).toSQL) } test("named expression in column names shouldn't be quoted") { def checkColumnNames(query: String, expectedColNames: String*): Unit = { - checkHiveQl(query) + checkSQL(query) assert(sql(query).columns === expectedColNames) } @@ -521,21 +682,25 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { } test("window basic") { - checkHiveQl("SELECT MAX(value) OVER (PARTITION BY key % 3) FROM parquet_t1") - checkHiveQl( + checkSQL("SELECT MAX(value) OVER (PARTITION BY key % 3) FROM parquet_t1", "window_basic_1") + + checkSQL( """ |SELECT key, value, ROUND(AVG(key) OVER (), 2) |FROM parquet_t1 ORDER BY key - """.stripMargin) - checkHiveQl( + """.stripMargin, + "window_basic_2") + + checkSQL( """ |SELECT value, MAX(key + 1) OVER (PARTITION BY key % 5 ORDER BY key % 7) AS max |FROM parquet_t1 - """.stripMargin) + """.stripMargin, + "window_basic_3") } test("multiple window functions in one expression") { - checkHiveQl( + checkSQL( """ |SELECT | MAX(key) OVER (ORDER BY key DESC, value) / MIN(key) OVER (PARTITION BY key % 3) @@ -544,15 +709,17 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { } test("regular expressions and window functions in one expression") { - checkHiveQl("SELECT MAX(key) OVER (PARTITION BY key % 3) + key FROM parquet_t1") + checkSQL("SELECT MAX(key) OVER (PARTITION BY key % 3) + key FROM parquet_t1", + "regular_expressions_and_window") } 
test("aggregate functions and window functions in one expression") { - checkHiveQl("SELECT MAX(c) + COUNT(a) OVER () FROM parquet_t2 GROUP BY a, b") + checkSQL("SELECT MAX(c) + COUNT(a) OVER () FROM parquet_t2 GROUP BY a, b", + "aggregate_functions_and_window") } test("window with different window specification") { - checkHiveQl( + checkSQL( """ |SELECT key, value, |DENSE_RANK() OVER (ORDER BY key, value) AS dr, @@ -562,45 +729,49 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { } test("window with the same window specification with aggregate + having") { - checkHiveQl( + checkSQL( """ |SELECT key, value, |MAX(value) OVER (PARTITION BY key % 5 ORDER BY key DESC) AS max |FROM parquet_t1 GROUP BY key, value HAVING key > 5 - """.stripMargin) + """.stripMargin, + "window_with_the_same_window_with_agg_having") } test("window with the same window specification with aggregate functions") { - checkHiveQl( + checkSQL( """ |SELECT key, value, |MAX(value) OVER (PARTITION BY key % 5 ORDER BY key) AS max |FROM parquet_t1 GROUP BY key, value - """.stripMargin) + """.stripMargin, + "window_with_the_same_window_with_agg_functions") } test("window with the same window specification with aggregate") { - checkHiveQl( + checkSQL( """ |SELECT key, value, |DENSE_RANK() OVER (DISTRIBUTE BY key SORT BY key, value) AS dr, |COUNT(key) |FROM parquet_t1 GROUP BY key, value - """.stripMargin) + """.stripMargin, + "window_with_the_same_window_with_agg") } test("window with the same window specification without aggregate and filter") { - checkHiveQl( + checkSQL( """ |SELECT key, value, |DENSE_RANK() OVER (DISTRIBUTE BY key SORT BY key, value) AS dr, |COUNT(key) OVER(DISTRIBUTE BY key SORT BY key, value) AS ca |FROM parquet_t1 - """.stripMargin) + """.stripMargin, + "window_with_the_same_window_with_agg_filter") } test("window clause") { - checkHiveQl( + checkSQL( """ |SELECT key, MAX(value) OVER w1 AS MAX, MIN(value) OVER w2 AS min |FROM parquet_t1 @@ -609,7 +780,7 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { } test("special window functions") { - checkHiveQl( + checkSQL( """ |SELECT | RANK() OVER w, @@ -626,107 +797,120 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { } test("window with join") { - checkHiveQl( + checkSQL( """ |SELECT x.key, MAX(y.key) OVER (PARTITION BY x.key % 5 ORDER BY x.key) |FROM parquet_t1 x JOIN parquet_t1 y ON x.key = y.key - """.stripMargin) + """.stripMargin, + "window_with_join") } test("join 2 tables and aggregate function in having clause") { - checkHiveQl( + checkSQL( """ |SELECT COUNT(a.value), b.KEY, a.KEY |FROM parquet_t1 a, parquet_t1 b |GROUP BY a.KEY, b.KEY |HAVING MAX(a.KEY) > 0 - """.stripMargin) + """.stripMargin, + "join_2_tables") } test("generator in project list without FROM clause") { - checkHiveQl("SELECT EXPLODE(ARRAY(1,2,3))") - checkHiveQl("SELECT EXPLODE(ARRAY(1,2,3)) AS val") + checkSQL("SELECT EXPLODE(ARRAY(1,2,3))", "generator_without_from_1") + checkSQL("SELECT EXPLODE(ARRAY(1,2,3)) AS val", "generator_without_from_2") } test("generator in project list with non-referenced table") { - checkHiveQl("SELECT EXPLODE(ARRAY(1,2,3)) FROM t0") - checkHiveQl("SELECT EXPLODE(ARRAY(1,2,3)) AS val FROM t0") + checkSQL("SELECT EXPLODE(ARRAY(1,2,3)) FROM t0", "generator_non_referenced_table_1") + checkSQL("SELECT EXPLODE(ARRAY(1,2,3)) AS val FROM t0", "generator_non_referenced_table_2") } test("generator in project list with referenced table") { - checkHiveQl("SELECT EXPLODE(arr) FROM parquet_t3") - 
checkHiveQl("SELECT EXPLODE(arr) AS val FROM parquet_t3") + checkSQL("SELECT EXPLODE(arr) FROM parquet_t3", "generator_referenced_table_1") + checkSQL("SELECT EXPLODE(arr) AS val FROM parquet_t3", "generator_referenced_table_2") } test("generator in project list with non-UDTF expressions") { - checkHiveQl("SELECT EXPLODE(arr), id FROM parquet_t3") - checkHiveQl("SELECT EXPLODE(arr) AS val, id as a FROM parquet_t3") + checkSQL("SELECT EXPLODE(arr), id FROM parquet_t3", "generator_non_udtf_1") + checkSQL("SELECT EXPLODE(arr) AS val, id as a FROM parquet_t3", "generator_non_udtf_2") } test("generator in lateral view") { - checkHiveQl("SELECT val, id FROM parquet_t3 LATERAL VIEW EXPLODE(arr) exp AS val") - checkHiveQl("SELECT val, id FROM parquet_t3 LATERAL VIEW OUTER EXPLODE(arr) exp AS val") + checkSQL("SELECT val, id FROM parquet_t3 LATERAL VIEW EXPLODE(arr) exp AS val", + "generator_in_lateral_view_1") + checkSQL("SELECT val, id FROM parquet_t3 LATERAL VIEW OUTER EXPLODE(arr) exp AS val", + "generator_in_lateral_view_2") } test("generator in lateral view with ambiguous names") { - checkHiveQl( + checkSQL( """ |SELECT exp.id, parquet_t3.id |FROM parquet_t3 |LATERAL VIEW EXPLODE(arr) exp AS id - """.stripMargin) - checkHiveQl( + """.stripMargin, + "generator_with_ambiguous_names_1") + + checkSQL( """ |SELECT exp.id, parquet_t3.id |FROM parquet_t3 |LATERAL VIEW OUTER EXPLODE(arr) exp AS id - """.stripMargin) + """.stripMargin, + "generator_with_ambiguous_names_2") } test("use JSON_TUPLE as generator") { - checkHiveQl( + checkSQL( """ |SELECT c0, c1, c2 |FROM parquet_t3 |LATERAL VIEW JSON_TUPLE(json, 'f1', 'f2', 'f3') jt - """.stripMargin) - checkHiveQl( + """.stripMargin, + "json_tuple_generator_1") + + checkSQL( """ |SELECT a, b, c |FROM parquet_t3 |LATERAL VIEW JSON_TUPLE(json, 'f1', 'f2', 'f3') jt AS a, b, c - """.stripMargin) + """.stripMargin, + "json_tuple_generator_2") } test("nested generator in lateral view") { - checkHiveQl( + checkSQL( """ |SELECT val, id |FROM parquet_t3 |LATERAL VIEW EXPLODE(arr2) exp1 AS nested_array |LATERAL VIEW EXPLODE(nested_array) exp1 AS val - """.stripMargin) + """.stripMargin, + "nested_generator_in_lateral_view_1") - checkHiveQl( + checkSQL( """ |SELECT val, id |FROM parquet_t3 |LATERAL VIEW EXPLODE(arr2) exp1 AS nested_array |LATERAL VIEW OUTER EXPLODE(nested_array) exp1 AS val - """.stripMargin) + """.stripMargin, + "nested_generator_in_lateral_view_2") } test("generate with other operators") { - checkHiveQl( + checkSQL( """ |SELECT EXPLODE(arr) AS val, id |FROM parquet_t3 |WHERE id > 2 |ORDER BY val, id |LIMIT 5 - """.stripMargin) + """.stripMargin, + "generate_with_other_1") - checkHiveQl( + checkSQL( """ |SELECT val, id |FROM parquet_t3 @@ -735,10 +919,222 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { |WHERE val > 2 |ORDER BY val, id |LIMIT 5 - """.stripMargin) + """.stripMargin, + "generate_with_other_2") } test("filter after subquery") { - checkHiveQl("SELECT a FROM (SELECT key + 1 AS a FROM parquet_t1) t WHERE a > 5") + checkSQL("SELECT a FROM (SELECT key + 1 AS a FROM parquet_t1) t WHERE a > 5", + "filter_after_subquery") + } + + test("SPARK-14933 - select parquet table") { + withTable("parquet_t") { + sql("create table parquet_t stored as parquet as select 1 as c1, 'abc' as c2") + checkSQL("select * from parquet_t", "select_parquet_table") + } + } + + test("predicate subquery") { + withTable("t1") { + withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + sql("CREATE TABLE t1(a int)") + checkSQL("select * from 
t1 b where exists (select * from t1 a)", "predicate_subquery") + } + } + } + + test("broadcast join") { + checkSQL( + """ + |SELECT /*+ MAPJOIN(srcpart) */ subq.key1, z.value + |FROM (SELECT x.key as key1, x.value as value1, y.key as key2, y.value as value2 + | FROM src1 x JOIN src y ON (x.key = y.key)) subq + |JOIN srcpart z ON (subq.key1 = z.key and z.ds='2008-04-08' and z.hr=11) + |ORDER BY subq.key1, z.value + """.stripMargin, + "broadcast_join_subquery") + } + + test("subquery using single table") { + checkSQL( + """ + |SELECT a.k, a.c + |FROM (SELECT b.key as k, count(1) as c + | FROM src b + | GROUP BY b.key) a + |WHERE a.k >= 90 + """.stripMargin, + "subq2") + } + + test("correlated subqueries using EXISTS on where clause") { + checkSQL( + """ + |select * + |from src b + |where exists (select a.key + | from src a + | where b.value = a.value and a.key = b.key and a.value > 'val_9') + """.stripMargin, + "subquery_exists_1") + + checkSQL( + """ + |select * + |from (select * + | from src b + | where exists (select a.key + | from src a + | where b.value = a.value and a.key = b.key and a.value > 'val_9')) a + """.stripMargin, + "subquery_exists_2") + } + + test("correlated subqueries using EXISTS on having clause") { + checkSQL( + """ + |select b.key, count(*) + |from src b + |group by b.key + |having exists (select a.key + | from src a + | where a.key = b.key and a.value > 'val_9') + """.stripMargin, + "subquery_exists_having_1") + + checkSQL( + """ + |select * + |from (select b.key, count(*) + | from src b + | group by b.key + | having exists (select a.key + | from src a + | where a.key = b.key and a.value > 'val_9')) a + """.stripMargin, + "subquery_exists_having_2") + + checkSQL( + """ + |select b.key, min(b.value) + |from src b + |group by b.key + |having exists (select a.key + | from src a + | where a.value > 'val_9' and a.value = min(b.value)) + """.stripMargin, + "subquery_exists_having_3") + } + + test("correlated subqueries using NOT EXISTS on where clause") { + checkSQL( + """ + |select * + |from src b + |where not exists (select a.key + | from src a + | where b.value = a.value and a.key = b.key and a.value > 'val_2') + """.stripMargin, + "subquery_not_exists_1") + + checkSQL( + """ + |select * + |from src b + |where not exists (select a.key + | from src a + | where b.value = a.value and a.value > 'val_2') + """.stripMargin, + "subquery_not_exists_2") + } + + test("correlated subqueries using NOT EXISTS on having clause") { + checkSQL( + """ + |select * + |from src b + |group by key, value + |having not exists (select a.key + | from src a + | where b.value = a.value and a.key = b.key and a.value > 'val_12') + """.stripMargin, + "subquery_not_exists_having_1") + + checkSQL( + """ + |select * + |from src b + |group by key, value + |having not exists (select distinct a.key + | from src a + | where b.value = a.value and a.value > 'val_12') + """.stripMargin, + "subquery_not_exists_having_2") + } + + test("subquery using IN on where clause") { + checkSQL( + """ + |SELECT key + |FROM src + |WHERE key in (SELECT max(key) FROM src) + """.stripMargin, + "subquery_in") + } + + test("subquery using IN on having clause") { + checkSQL( + """ + |select key, count(*) + |from src + |group by key + |having count(*) in (select count(*) from src s1 where s1.key = '90' group by s1.key) + |order by key + """.stripMargin, + "subquery_in_having_1") + + checkSQL( + """ + |select b.key, min(b.value) + |from src b + |group by b.key + |having b.key in (select a.key + | from src a + | where a.value > 
'val_9' and a.value = min(b.value)) + |order by b.key + """.stripMargin, + "subquery_in_having_2") + } + + test("SPARK-14933 - select orc table") { + withTable("orc_t") { + sql("create table orc_t stored as orc as select 1 as c1, 'abc' as c2") + checkSQL("select * from orc_t", "select_orc_table") + } + } + + test("inline tables") { + checkSQL( + """ + |select * from values ("one", 1), ("two", 2), ("three", null) as data(a, b) where b > 1 + """.stripMargin, + "inline_tables") + } + + test("SPARK-17750 - interval arithmetic") { + withTable("dates") { + sql("create table dates (ts timestamp)") + checkSQL( + """ + |select ts + interval 1 day, ts + interval 2 days, + | ts - interval 1 day, ts - interval 2 days, + | ts + interval '1' day, ts + interval '2' days, + | ts - interval '1' day, ts - interval '2' days + |from dates + """.stripMargin, + "interval_arithmetic" + ) + } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/SQLBuilderTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/SQLBuilderTest.scala index 27c9e992de022..31755f56ece63 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/SQLBuilderTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/SQLBuilderTest.scala @@ -64,7 +64,7 @@ abstract class SQLBuilderTest extends QueryTest with TestHiveSingleton { """.stripMargin) } - checkAnswer(sqlContext.sql(generatedSQL), Dataset.ofRows(sqlContext.sparkSession, plan)) + checkAnswer(spark.sql(generatedSQL), Dataset.ofRows(spark, plan)) } protected def checkSQL(df: DataFrame, expectedSQL: String): Unit = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 4ca5619603220..7d4ef6f26a600 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -20,12 +20,15 @@ package org.apache.spark.sql.hive import java.io.File import org.apache.spark.sql.{AnalysisException, QueryTest, SaveMode} +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException +import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.storage.RDDBlockId import org.apache.spark.util.Utils -class CachedTableSuite extends QueryTest with TestHiveSingleton { +class CachedTableSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { import hiveContext._ def rddIdOf(tableName: String): Int = { @@ -95,46 +98,64 @@ class CachedTableSuite extends QueryTest with TestHiveSingleton { sql("DROP TABLE IF EXISTS nonexistantTable") } - test("correct error on uncache of non-cached table") { - intercept[IllegalArgumentException] { - hiveContext.uncacheTable("src") + test("correct error on uncache of nonexistant tables") { + intercept[NoSuchTableException] { + spark.catalog.uncacheTable("nonexistantTable") + } + intercept[NoSuchTableException] { + sql("UNCACHE TABLE nonexistantTable") + } + } + + test("no error on uncache of non-cached table") { + val tableName = "newTable" + withTable(tableName) { + sql(s"CREATE TABLE $tableName(a INT)") + // no error will be reported in the following three ways to uncache a table. 
+ spark.catalog.uncacheTable(tableName) + sql("UNCACHE TABLE newTable") + sparkSession.table(tableName).unpersist() } } test("'CACHE TABLE' and 'UNCACHE TABLE' HiveQL statement") { sql("CACHE TABLE src") assertCached(table("src")) - assert(hiveContext.isCached("src"), "Table 'src' should be cached") + assert(spark.catalog.isCached("src"), "Table 'src' should be cached") sql("UNCACHE TABLE src") assertCached(table("src"), 0) - assert(!hiveContext.isCached("src"), "Table 'src' should not be cached") + assert(!spark.catalog.isCached("src"), "Table 'src' should not be cached") } test("CACHE TABLE tableName AS SELECT * FROM anotherTable") { - sql("CACHE TABLE testCacheTable AS SELECT * FROM src") - assertCached(table("testCacheTable")) + withTempView("testCacheTable") { + sql("CACHE TABLE testCacheTable AS SELECT * FROM src") + assertCached(table("testCacheTable")) - val rddId = rddIdOf("testCacheTable") - assert( - isMaterialized(rddId), - "Eagerly cached in-memory table should have already been materialized") + val rddId = rddIdOf("testCacheTable") + assert( + isMaterialized(rddId), + "Eagerly cached in-memory table should have already been materialized") - uncacheTable("testCacheTable") - assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") + uncacheTable("testCacheTable") + assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") + } } test("CACHE TABLE tableName AS SELECT ...") { - sql("CACHE TABLE testCacheTable AS SELECT key FROM src LIMIT 10") - assertCached(table("testCacheTable")) + withTempView("testCacheTable") { + sql("CACHE TABLE testCacheTable AS SELECT key FROM src LIMIT 10") + assertCached(table("testCacheTable")) - val rddId = rddIdOf("testCacheTable") - assert( - isMaterialized(rddId), - "Eagerly cached in-memory table should have already been materialized") + val rddId = rddIdOf("testCacheTable") + assert( + isMaterialized(rddId), + "Eagerly cached in-memory table should have already been materialized") - uncacheTable("testCacheTable") - assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") + uncacheTable("testCacheTable") + assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") + } } test("CACHE LAZY TABLE tableName") { @@ -156,9 +177,11 @@ class CachedTableSuite extends QueryTest with TestHiveSingleton { } test("CACHE TABLE with Hive UDF") { - sql("CACHE TABLE udfTest AS SELECT * FROM src WHERE floor(key) = 1") - assertCached(table("udfTest")) - uncacheTable("udfTest") + withTempView("udfTest") { + sql("CACHE TABLE udfTest AS SELECT * FROM src WHERE floor(key) = 1") + assertCached(table("udfTest")) + uncacheTable("udfTest") + } } test("REFRESH TABLE also needs to recache the data (data source tables)") { @@ -206,6 +229,85 @@ class CachedTableSuite extends QueryTest with TestHiveSingleton { Utils.deleteRecursively(tempPath) } + test("SPARK-15678: REFRESH PATH") { + val tempPath: File = Utils.createTempDir() + tempPath.delete() + table("src").write.mode(SaveMode.Overwrite).parquet(tempPath.toString) + sql("DROP TABLE IF EXISTS refreshTable") + sparkSession.catalog.createExternalTable("refreshTable", tempPath.toString, "parquet") + checkAnswer( + table("refreshTable"), + table("src").collect()) + // Cache the table. + sql("CACHE TABLE refreshTable") + assertCached(table("refreshTable")) + // Append new data. + table("src").write.mode(SaveMode.Append).parquet(tempPath.toString) + // We are still using the old data. 
+ assertCached(table("refreshTable")) + checkAnswer( + table("refreshTable"), + table("src").collect()) + // Refresh the table. + sql(s"REFRESH ${tempPath.toString}") + // We are using the new data. + assertCached(table("refreshTable")) + checkAnswer( + table("refreshTable"), + table("src").union(table("src")).collect()) + + // Drop the table and create it again. + sql("DROP TABLE refreshTable") + sparkSession.catalog.createExternalTable("refreshTable", tempPath.toString, "parquet") + // It is not cached. + assert(!isCached("refreshTable"), "refreshTable should not be cached.") + // Refresh the table. REFRESH command should not make a uncached + // table cached. + sql(s"REFRESH ${tempPath.toString}") + checkAnswer( + table("refreshTable"), + table("src").union(table("src")).collect()) + // It is not cached. + assert(!isCached("refreshTable"), "refreshTable should not be cached.") + + sql("DROP TABLE refreshTable") + Utils.deleteRecursively(tempPath) + } + + test("Cache/Uncache Qualified Tables") { + withTempDatabase { db => + withTempView("cachedTable") { + sql(s"CREATE TABLE $db.cachedTable STORED AS PARQUET AS SELECT 1") + sql(s"CACHE TABLE $db.cachedTable") + assertCached(spark.table(s"$db.cachedTable")) + + activateDatabase(db) { + assertCached(spark.table("cachedTable")) + sql("UNCACHE TABLE cachedTable") + assert(!spark.catalog.isCached("cachedTable"), "Table 'cachedTable' should not be cached") + sql(s"CACHE TABLE cachedTable") + assert(spark.catalog.isCached("cachedTable"), "Table 'cachedTable' should be cached") + } + + sql(s"UNCACHE TABLE $db.cachedTable") + assert(!spark.catalog.isCached(s"$db.cachedTable"), + "Table 'cachedTable' should not be cached") + } + } + } + + test("Cache Table As Select - having database name") { + withTempDatabase { db => + withTempView("cachedTable") { + val e = intercept[ParseException] { + sql(s"CACHE TABLE $db.cachedTable AS SELECT 1") + }.getMessage + assert(e.contains("It is not allowed to add database prefix ") && + e.contains("to the table name in CACHE TABLE AS SELECT")) + } + } + } + test("SPARK-11246 cache parquet table") { sql("CREATE TABLE cachedTable STORED AS PARQUET AS SELECT 1") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala index 61910b8e6b51d..aa1973de7f678 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala @@ -26,21 +26,21 @@ import org.apache.spark.sql.catalyst.util.quietly import org.apache.spark.sql.hive.test.TestHiveSingleton class ErrorPositionSuite extends QueryTest with TestHiveSingleton with BeforeAndAfterEach { - import hiveContext.implicits._ + import spark.implicits._ override protected def beforeEach(): Unit = { super.beforeEach() - if (sqlContext.tableNames().contains("src")) { - sqlContext.dropTempTable("src") + if (spark.catalog.listTables().collect().map(_.name).contains("src")) { + spark.catalog.dropTempView("src") } - Seq((1, "")).toDF("key", "value").registerTempTable("src") - Seq((1, 1, 1)).toDF("a", "a", "b").registerTempTable("dupAttributes") + Seq((1, "")).toDF("key", "value").createOrReplaceTempView("src") + Seq((1, 1, 1)).toDF("a", "a", "b").createOrReplaceTempView("dupAttributes") } override protected def afterEach(): Unit = { try { - sqlContext.dropTempTable("src") - sqlContext.dropTempTable("dupAttributes") + spark.catalog.dropTempView("src") + 
spark.catalog.dropTempView("dupAttributes") } finally { super.afterEach() } @@ -130,12 +130,12 @@ class ErrorPositionSuite extends QueryTest with TestHiveSingleton with BeforeAnd * @param token a unique token in the string that should be indicated by the exception */ def positionTest(name: String, query: String, token: String): Unit = { - def ast = hiveContext.parseSql(query) + def ast = spark.sessionState.sqlParser.parsePlan(query) def parseTree = Try(quietly(ast.treeString)).getOrElse("") test(name) { val error = intercept[AnalysisException] { - quietly(hiveContext.sql(query)) + quietly(spark.sql(query)) } assert(!error.getMessage.contains("Seq(")) diff --git a/sql/hivecontext-compatibility/src/test/scala/org/apache/spark/sql/hive/HiveContextCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveContextCompatibilitySuite.scala similarity index 92% rename from sql/hivecontext-compatibility/src/test/scala/org/apache/spark/sql/hive/HiveContextCompatibilitySuite.scala rename to sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveContextCompatibilitySuite.scala index 5df674d60e9c4..57363b7259c61 100644 --- a/sql/hivecontext-compatibility/src/test/scala/org/apache/spark/sql/hive/HiveContextCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveContextCompatibilitySuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.hive import org.scalatest.BeforeAndAfterEach -import org.apache.spark.{SparkContext, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} class HiveContextCompatibilitySuite extends SparkFunSuite with BeforeAndAfterEach { @@ -29,7 +29,7 @@ class HiveContextCompatibilitySuite extends SparkFunSuite with BeforeAndAfterEac override def beforeAll(): Unit = { super.beforeAll() - sc = new SparkContext("local[4]", "test") + sc = SparkContext.getOrCreate(new SparkConf().setMaster("local").setAppName("test")) HiveUtils.newTemporaryConfiguration(useInMemoryDerby = true).foreach { case (k, v) => sc.hadoopConfiguration.set(k, v) } @@ -47,7 +47,6 @@ class HiveContextCompatibilitySuite extends SparkFunSuite with BeforeAndAfterEac override def afterAll(): Unit = { try { - sc.stop() sc = null hc = null } finally { @@ -66,7 +65,7 @@ class HiveContextCompatibilitySuite extends SparkFunSuite with BeforeAndAfterEac val res = df3.collect() val expected = Seq((18, 18, 8)).toDF("a", "x", "b").collect() assert(res.toSeq == expected.toSeq) - df3.registerTempTable("mai_table") + df3.createOrReplaceTempView("mai_table") val df4 = hc.table("mai_table") val res2 = df4.collect() assert(res2.toSeq == expected.toSeq) @@ -82,7 +81,7 @@ class HiveContextCompatibilitySuite extends SparkFunSuite with BeforeAndAfterEac val databases2 = hc.sql("SHOW DATABASES").collect().map(_.getString(0)) assert(databases2.toSet == Set("default", "mee_db")) val df = (1 to 10).map { i => ("bob" + i.toString, i) }.toDF("name", "age") - df.registerTempTable("mee_table") + df.createOrReplaceTempView("mee_table") hc.sql("CREATE TABLE moo_table (name string, age int)") hc.sql("INSERT INTO moo_table SELECT * FROM mee_table") assert( @@ -94,6 +93,7 @@ class HiveContextCompatibilitySuite extends SparkFunSuite with BeforeAndAfterEac hc.sql("DROP TABLE mee_table") val tables2 = hc.sql("SHOW TABLES IN mee_db").collect().map(_.getString(0)) assert(tables2.isEmpty) + hc.sql("USE default") hc.sql("DROP DATABASE mee_db CASCADE") val databases3 = hc.sql("SHOW DATABASES").collect().map(_.getString(0)) assert(databases3.toSeq == Seq("default")) diff --git 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala index c97c28c40c964..54009d4b4130a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala @@ -17,9 +17,8 @@ package org.apache.spark.sql.hive -import org.apache.hadoop.hive.serde.serdeConstants - import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.catalog.{CatalogColumn, CatalogTable, CatalogTableType} import org.apache.spark.sql.catalyst.dsl.expressions._ @@ -29,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.JsonTuple import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{Generate, ScriptTransformation} -import org.apache.spark.sql.execution.command.{CreateTable, CreateTableAsSelectLogicalPlan, CreateTableLike, CreateViewCommand, LoadData} +import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.hive.test.TestHive class HiveDDLCommandSuite extends PlanTest { @@ -37,9 +36,9 @@ class HiveDDLCommandSuite extends PlanTest { private def extractTableDesc(sql: String): (CatalogTable, Boolean) = { parser.parsePlan(sql).collect { - case CreateTable(desc, allowExisting) => (desc, allowExisting) - case CreateTableAsSelectLogicalPlan(desc, _, allowExisting) => (desc, allowExisting) - case CreateViewCommand(desc, _, allowExisting, _, _) => (desc, allowExisting) + case c: CreateTableCommand => (c.table, c.ifNotExists) + case c: CreateHiveTableAsSelectLogicalPlan => (c.tableDesc, c.allowExisting) + case c: CreateViewCommand => (c.tableDesc, c.allowExisting) }.head } @@ -53,15 +52,8 @@ class HiveDDLCommandSuite extends PlanTest { test("Test CTAS #1") { val s1 = """CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view - |(viewTime INT, - |userid BIGINT, - |page_url STRING, - |referrer_url STRING, - |ip STRING COMMENT 'IP Address of the User', - |country STRING COMMENT 'country of origination') |COMMENT 'This is the staging page view table' - |PARTITIONED BY (dt STRING COMMENT 'date type', hour STRING COMMENT 'hour of the day') - |ROW FORMAT DELIMITED FIELDS TERMINATED BY '\054' STORED AS RCFILE + |STORED AS RCFILE |LOCATION '/user/external/page_view' |TBLPROPERTIES ('p1'='v1', 'p2'='v2') |AS SELECT * FROM src""".stripMargin @@ -72,24 +64,12 @@ class HiveDDLCommandSuite extends PlanTest { assert(desc.identifier.table == "page_view") assert(desc.tableType == CatalogTableType.EXTERNAL) assert(desc.storage.locationUri == Some("/user/external/page_view")) - assert(desc.schema == - CatalogColumn("viewtime", "int") :: - CatalogColumn("userid", "bigint") :: - CatalogColumn("page_url", "string") :: - CatalogColumn("referrer_url", "string") :: - CatalogColumn("ip", "string", comment = Some("IP Address of the User")) :: - CatalogColumn("country", "string", comment = Some("country of origination")) :: - CatalogColumn("dt", "string", comment = Some("date type")) :: - CatalogColumn("hour", "string", comment = Some("hour of the day")) :: Nil) + assert(desc.schema.isEmpty) // will be populated later when the table is actually created assert(desc.comment == Some("This is the staging page view table")) // TODO will be SQLText assert(desc.viewText.isEmpty) 
assert(desc.viewOriginalText.isEmpty) - assert(desc.partitionColumns == - CatalogColumn("dt", "string", comment = Some("date type")) :: - CatalogColumn("hour", "string", comment = Some("hour of the day")) :: Nil) - assert(desc.storage.serdeProperties == - Map((serdeConstants.SERIALIZATION_FORMAT, "\u002C"), (serdeConstants.FIELD_DELIM, "\u002C"))) + assert(desc.partitionColumns == Seq.empty[CatalogColumn]) assert(desc.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileInputFormat")) assert(desc.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) assert(desc.storage.serde == @@ -100,14 +80,7 @@ class HiveDDLCommandSuite extends PlanTest { test("Test CTAS #2") { val s2 = """CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view - |(viewTime INT, - |userid BIGINT, - |page_url STRING, - |referrer_url STRING, - |ip STRING COMMENT 'IP Address of the User', - |country STRING COMMENT 'country of origination') |COMMENT 'This is the staging page view table' - |PARTITIONED BY (dt STRING COMMENT 'date type', hour STRING COMMENT 'hour of the day') |ROW FORMAT SERDE 'parquet.hive.serde.ParquetHiveSerDe' | STORED AS | INPUTFORMAT 'parquet.hive.DeprecatedParquetInputFormat' @@ -122,22 +95,12 @@ class HiveDDLCommandSuite extends PlanTest { assert(desc.identifier.table == "page_view") assert(desc.tableType == CatalogTableType.EXTERNAL) assert(desc.storage.locationUri == Some("/user/external/page_view")) - assert(desc.schema == - CatalogColumn("viewtime", "int") :: - CatalogColumn("userid", "bigint") :: - CatalogColumn("page_url", "string") :: - CatalogColumn("referrer_url", "string") :: - CatalogColumn("ip", "string", comment = Some("IP Address of the User")) :: - CatalogColumn("country", "string", comment = Some("country of origination")) :: - CatalogColumn("dt", "string", comment = Some("date type")) :: - CatalogColumn("hour", "string", comment = Some("hour of the day")) :: Nil) + assert(desc.schema.isEmpty) // will be populated later when the table is actually created // TODO will be SQLText assert(desc.comment == Some("This is the staging page view table")) assert(desc.viewText.isEmpty) assert(desc.viewOriginalText.isEmpty) - assert(desc.partitionColumns == - CatalogColumn("dt", "string", comment = Some("date type")) :: - CatalogColumn("hour", "string", comment = Some("hour of the day")) :: Nil) + assert(desc.partitionColumns == Seq.empty[CatalogColumn]) assert(desc.storage.serdeProperties == Map()) assert(desc.storage.inputFormat == Some("parquet.hive.DeprecatedParquetInputFormat")) assert(desc.storage.outputFormat == Some("parquet.hive.DeprecatedParquetOutputFormat")) @@ -199,6 +162,16 @@ class HiveDDLCommandSuite extends PlanTest { assert(desc.properties == Map(("tbl_p1" -> "p11"), ("tbl_p2" -> "p22"))) } + test("CTAS statement with a PARTITIONED BY clause is not allowed") { + assertUnsupported(s"CREATE TABLE ctas1 PARTITIONED BY (k int)" + + " AS SELECT key, value FROM (SELECT 1 as key, 2 as value) tmp") + } + + test("CTAS statement with schema") { + assertUnsupported(s"CREATE TABLE ctas1 (age INT, name STRING) AS SELECT * FROM src") + assertUnsupported(s"CREATE TABLE ctas1 (age INT, name STRING) AS SELECT 1, 'hello'") + } + test("unsupported operations") { intercept[ParseException] { parser.parsePlan( @@ -254,12 +227,13 @@ class HiveDDLCommandSuite extends PlanTest { } test("use native json_tuple instead of hive's UDTF in LATERAL VIEW") { - val plan = parser.parsePlan( + val analyzer = TestHive.sparkSession.sessionState.analyzer + val plan = 
analyzer.execute(parser.parsePlan( """ |SELECT * |FROM (SELECT '{"f1": "value1", "f2": 12}' json) test |LATERAL VIEW json_tuple(json, 'f1', 'f2') jt AS a, b - """.stripMargin) + """.stripMargin)) assert(plan.children.head.asInstanceOf[Generate].generator.isInstanceOf[JsonTuple]) } @@ -347,13 +321,14 @@ class HiveDDLCommandSuite extends PlanTest { test("create table - temporary") { val query = "CREATE TEMPORARY TABLE tab1 (id int, name string)" val e = intercept[ParseException] { parser.parsePlan(query) } - assert(e.message.contains("registerTempTable")) + assert(e.message.contains("CREATE TEMPORARY TABLE is not supported yet")) } test("create table - external") { - val query = "CREATE EXTERNAL TABLE tab1 (id int, name string)" + val query = "CREATE EXTERNAL TABLE tab1 (id int, name string) LOCATION '/path/to/nowhere'" val (desc, _) = extractTableDesc(query) assert(desc.tableType == CatalogTableType.EXTERNAL) + assert(desc.storage.locationUri == Some("/path/to/nowhere")) } test("create table - if not exists") { @@ -453,12 +428,6 @@ class HiveDDLCommandSuite extends PlanTest { assert(e2.getMessage.contains("Operation not allowed")) } - test("create table - location") { - val query = "CREATE TABLE my_table (id int, name string) LOCATION '/path/to/mars'" - val (desc, _) = extractTableDesc(query) - assert(desc.storage.locationUri == Some("/path/to/mars")) - } - test("create table - properties") { val query = "CREATE TABLE my_table (id int, name string) TBLPROPERTIES ('k1'='v1', 'k2'='v2')" val (desc, _) = extractTableDesc(query) @@ -552,14 +521,19 @@ class HiveDDLCommandSuite extends PlanTest { } } - test("MSCK repair table (not supported)") { - assertUnsupported("MSCK REPAIR TABLE tab1") + test("MSCK REPAIR table") { + val sql = "MSCK REPAIR TABLE tab1" + val parsed = parser.parsePlan(sql) + val expected = AlterTableRecoverPartitionsCommand( + TableIdentifier("tab1", None), + "MSCK REPAIR TABLE") + comparePlans(parsed, expected) } test("create table like") { val v1 = "CREATE TABLE table1 LIKE table2" val (target, source, exists) = parser.parsePlan(v1).collect { - case CreateTableLike(t, s, allowExisting) => (t, s, allowExisting) + case CreateTableLikeCommand(t, s, allowExisting) => (t, s, allowExisting) }.head assert(exists == false) assert(target.database.isEmpty) @@ -569,7 +543,7 @@ class HiveDDLCommandSuite extends PlanTest { val v2 = "CREATE TABLE IF NOT EXISTS table1 LIKE table2" val (target2, source2, exists2) = parser.parsePlan(v2).collect { - case CreateTableLike(t, s, allowExisting) => (t, s, allowExisting) + case CreateTableLikeCommand(t, s, allowExisting) => (t, s, allowExisting) }.head assert(exists2) assert(target2.database.isEmpty) @@ -578,10 +552,10 @@ class HiveDDLCommandSuite extends PlanTest { assert(source2.table == "table2") } - test("load data") { + test("load data") { val v1 = "LOAD DATA INPATH 'path' INTO TABLE table1" val (table, path, isLocal, isOverwrite, partition) = parser.parsePlan(v1).collect { - case LoadData(t, path, l, o, partition) => (t, path, l, o, partition) + case LoadDataCommand(t, path, l, o, partition) => (t, path, l, o, partition) }.head assert(table.database.isEmpty) assert(table.table == "table1") @@ -592,7 +566,7 @@ class HiveDDLCommandSuite extends PlanTest { val v2 = "LOAD DATA LOCAL INPATH 'path' OVERWRITE INTO TABLE table1 PARTITION(c='1', d='2')" val (table2, path2, isLocal2, isOverwrite2, partition2) = parser.parsePlan(v2).collect { - case LoadData(t, path, l, o, partition) => (t, path, l, o, partition) + case LoadDataCommand(t, path, l, o, 
partition) => (t, path, l, o, partition) }.head assert(table2.database.isEmpty) assert(table2.table == "table1") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala index 57f96e725a044..6477974fe713a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala @@ -27,20 +27,20 @@ import org.apache.spark.sql.hive.test.TestHiveSingleton // `hive` package is optional in compiling, however, `SQLContext.sql` doesn't // support the `cube` or `rollup` yet. class HiveDataFrameAnalyticsSuite extends QueryTest with TestHiveSingleton with BeforeAndAfterAll { - import hiveContext.implicits._ - import hiveContext.sql + import spark.implicits._ + import spark.sql private var testData: DataFrame = _ override def beforeAll() { super.beforeAll() testData = Seq((1, 2), (2, 2), (3, 4)).toDF("a", "b") - hiveContext.registerDataFrameAsTable(testData, "mytable") + testData.createOrReplaceTempView("mytable") } override def afterAll(): Unit = { try { - hiveContext.dropTempTable("mytable") + spark.catalog.dropTempView("mytable") } finally { super.afterAll() } @@ -58,17 +58,6 @@ class HiveDataFrameAnalyticsSuite extends QueryTest with TestHiveSingleton with ) } - test("collect functions") { - checkAnswer( - testData.select(collect_list($"a"), collect_list($"b")), - Seq(Row(Seq(1, 2, 3), Seq(2, 2, 4))) - ) - checkAnswer( - testData.select(collect_set($"a"), collect_set($"b")), - Seq(Row(Seq(1, 2, 3), Seq(2, 4))) - ) - } - test("cube") { checkAnswer( testData.cube($"a" + $"b", $"b").agg(sum($"a" - $"b")), diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala index 63cf5030ab8b6..cdc259d75b139 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.hive.test.TestHiveSingleton class HiveDataFrameJoinSuite extends QueryTest with TestHiveSingleton { - import hiveContext.implicits._ + import spark.implicits._ // We should move this into SQL package if we make case sensitivity configurable in SQL. 
test("join - self join auto resolve ambiguity with case insensitivity") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameSuite.scala index 7fdc5d71937ff..23798431e697f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameSuite.scala @@ -23,10 +23,15 @@ import org.apache.spark.sql.QueryTest class HiveDataFrameSuite extends QueryTest with TestHiveSingleton { test("table name with schema") { // regression test for SPARK-11778 - hiveContext.sql("create schema usrdb") - hiveContext.sql("create table usrdb.test(c int)") - hiveContext.read.table("usrdb.test") - hiveContext.sql("drop table usrdb.test") - hiveContext.sql("drop schema usrdb") + spark.sql("create schema usrdb") + spark.sql("create table usrdb.test(c int)") + spark.read.table("usrdb.test") + spark.sql("drop table usrdb.test") + spark.sql("drop schema usrdb") + } + + test("SPARK-15887: hive-site.xml should be loaded") { + val hiveClient = spark.sharedState.asInstanceOf[HiveSharedState].metadataHive + assert(hiveClient.getConf("hive.in.test", "") == "true") } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala index bf9935ae41b30..175889b08b49f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala @@ -37,7 +37,8 @@ class HiveExternalCatalogSuite extends ExternalCatalogSuite { protected override val utils: CatalogTestUtils = new CatalogTestUtils { override val tableInputFormat: String = "org.apache.hadoop.mapred.SequenceFileInputFormat" override val tableOutputFormat: String = "org.apache.hadoop.mapred.SequenceFileOutputFormat" - override def newEmptyCatalog(): ExternalCatalog = new HiveExternalCatalog(client) + override def newEmptyCatalog(): ExternalCatalog = + new HiveExternalCatalog(client, new Configuration()) } protected override def resetState(): Unit = client.reset() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala index 3b867bbfa1817..5fda367bd600d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala @@ -75,6 +75,7 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors { val data = Literal(true) :: + Literal(null) :: Literal(0.asInstanceOf[Byte]) :: Literal(0.asInstanceOf[Short]) :: Literal(0) :: diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala new file mode 100644 index 0000000000000..3414f5e0409a1 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.hadoop.fs.Path + +import org.apache.spark.SparkException +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SQLTestUtils + +/** + * Test suite to handle metadata cache related. + */ +class HiveMetadataCacheSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { + + test("SPARK-16337 temporary view refresh") { + withTempView("view_refresh") { + withTable("view_table") { + // Create a Parquet directory + spark.range(start = 0, end = 100, step = 1, numPartitions = 3) + .write.saveAsTable("view_table") + + // Read the table in + spark.table("view_table").filter("id > -1").createOrReplaceTempView("view_refresh") + assert(sql("select count(*) from view_refresh").first().getLong(0) == 100) + + // Delete a file using the Hadoop file system interface since the path returned by + // inputFiles is not recognizable by Java IO. + val p = new Path(spark.table("view_table").inputFiles.head) + assert(p.getFileSystem(hiveContext.sessionState.newHadoopConf()).delete(p, false)) + + // Read it again and now we should see a FileNotFoundException + val e = intercept[SparkException] { + sql("select count(*) from view_refresh").first() + } + assert(e.getMessage.contains("FileNotFoundException")) + assert(e.getMessage.contains("REFRESH")) + + // Refresh and we should be able to read it again. 
+ spark.catalog.refreshTable("view_refresh") + val newCount = sql("select count(*) from view_refresh").first().getLong(0) + assert(newCount > 0 && newCount < 100) + } + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index b043d291aa65f..754aabb5ac936 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -26,15 +26,15 @@ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{ExamplePointUDT, SQLTestUtils} -import org.apache.spark.sql.types.{DecimalType, StringType, StructType} +import org.apache.spark.sql.types.{DecimalType, StringType, StructField, StructType} class HiveMetastoreCatalogSuite extends TestHiveSingleton { - import hiveContext.implicits._ + import spark.implicits._ test("struct field should accept underscore in sub-column name") { val hiveTypeStr = "struct" - val dateType = CatalystSqlParser.parseDataType(hiveTypeStr) - assert(dateType.isInstanceOf[StructType]) + val dataType = CatalystSqlParser.parseDataType(hiveTypeStr) + assert(dataType.isInstanceOf[StructType]) } test("udt to metastore type conversion") { @@ -45,10 +45,18 @@ class HiveMetastoreCatalogSuite extends TestHiveSingleton { } test("duplicated metastore relations") { - val df = hiveContext.sql("SELECT * FROM src") + val df = spark.sql("SELECT * FROM src") logInfo(df.queryExecution.toString) df.as('a).join(df.as('b), $"a.key" === $"b.key") } + + test("should not truncate struct type catalog string") { + def field(n: Int): StructField = { + StructField("col" + n, StringType) + } + val dataType = StructType((1 to 100).map(field)) + assert(CatalystSqlParser.parseDataType(dataType.catalogString) == dataType) + } } class DataSourceWithHiveMetastoreCatalogSuite diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala index b5af758a65b1c..09c15473b21c1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala @@ -51,8 +51,8 @@ class HiveParquetSuite extends QueryTest with ParquetTest with TestHiveSingleton test("Converting Hive to Parquet Table via saveAsParquetFile") { withTempPath { dir => sql("SELECT * FROM src").write.parquet(dir.getCanonicalPath) - hiveContext.read.parquet(dir.getCanonicalPath).registerTempTable("p") - withTempTable("p") { + spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("p") + withTempView("p") { checkAnswer( sql("SELECT * FROM src ORDER BY key"), sql("SELECT * from p ORDER BY key").collect().toSeq) @@ -65,8 +65,8 @@ class HiveParquetSuite extends QueryTest with ParquetTest with TestHiveSingleton withParquetTable((1 to 10).map(i => (i, s"val_$i")), "t", false) { withTempPath { file => sql("SELECT * FROM t LIMIT 1").write.parquet(file.getCanonicalPath) - hiveContext.read.parquet(file.getCanonicalPath).registerTempTable("p") - withTempTable("p") { + spark.read.parquet(file.getCanonicalPath).createOrReplaceTempView("p") + withTempView("p") { // let's do three overwrites for good measure sql("INSERT OVERWRITE TABLE p SELECT * FROM t") sql("INSERT OVERWRITE TABLE p SELECT 
* FROM t") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 77a6a94a6719b..dd8fec0c15ffa 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive -import java.io.File +import java.io.{BufferedWriter, File, FileWriter} import java.sql.Timestamp import java.util.Date @@ -31,9 +31,9 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark._ import org.apache.spark.internal.Logging -import org.apache.spark.sql.{QueryTest, Row, SparkSession, SQLContext} +import org.apache.spark.sql.{QueryTest, Row, SparkSession} import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} -import org.apache.spark.sql.catalyst.catalog.CatalogFunction +import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, FunctionResource, JarResource} import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext} import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer @@ -142,8 +142,7 @@ class HiveSparkSubmitSuite runSparkSubmit(args) } - // TODO: re-enable this after rebuilding the jar (HiveContext was removed) - ignore("SPARK-8489: MissingRequirementError during reflection") { + test("SPARK-8489: MissingRequirementError during reflection") { // This test uses a pre-built jar to test SPARK-8489. In a nutshell, this test creates // a HiveContext and uses it to create a data frame from an RDD using reflection. // Before the fix in SPARK-8470, this results in a MissingRequirementError because @@ -206,7 +205,7 @@ class HiveSparkSubmitSuite val unusedJar = TestUtils.createJarWithClasses(Seq.empty) val args = Seq( "--class", SetWarehouseLocationTest.getClass.getName.stripSuffix("$"), - "--name", "SetWarehouseLocationTest", + "--name", "SetSparkWarehouseLocationTest", "--master", "local-cluster[2,1,1024]", "--conf", "spark.ui.enabled=false", "--conf", "spark.master.rest.enabled=false", @@ -215,6 +214,86 @@ class HiveSparkSubmitSuite runSparkSubmit(args) } + test("set hive.metastore.warehouse.dir") { + // In this test, we set hive.metastore.warehouse.dir in hive-site.xml but + // not set spark.sql.warehouse.dir. So, the warehouse dir should be + // the value of hive.metastore.warehouse.dir. Also, the value of + // spark.sql.warehouse.dir should be set to the value of hive.metastore.warehouse.dir. + + val hiveWarehouseLocation = Utils.createTempDir() + hiveWarehouseLocation.delete() + val hiveSiteXmlContent = + s""" + | + | + | hive.metastore.warehouse.dir + | $hiveWarehouseLocation + | + | + """.stripMargin + + // Write a hive-site.xml containing a setting of hive.metastore.warehouse.dir. 
+ val hiveSiteDir = Utils.createTempDir() + val file = new File(hiveSiteDir.getCanonicalPath, "hive-site.xml") + val bw = new BufferedWriter(new FileWriter(file)) + bw.write(hiveSiteXmlContent) + bw.close() + + val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + val args = Seq( + "--class", SetWarehouseLocationTest.getClass.getName.stripSuffix("$"), + "--name", "SetHiveWarehouseLocationTest", + "--master", "local-cluster[2,1,1024]", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", + "--conf", s"spark.sql.test.expectedWarehouseDir=$hiveWarehouseLocation", + "--conf", s"spark.driver.extraClassPath=${hiveSiteDir.getCanonicalPath}", + "--driver-java-options", "-Dderby.system.durability=test", + unusedJar.toString) + runSparkSubmit(args) + } + + test("SPARK-16901: set javax.jdo.option.ConnectionURL") { + // In this test, we set javax.jdo.option.ConnectionURL and set metastore version to + // 0.13. This test will make sure that javax.jdo.option.ConnectionURL will not be + // overridden by hive's default settings when we create a HiveConf object inside + // HiveClientImpl. Please see SPARK-16901 for more details. + + val metastoreLocation = Utils.createTempDir() + metastoreLocation.delete() + val metastoreURL = + s"jdbc:derby:memory:;databaseName=${metastoreLocation.getAbsolutePath};create=true" + val hiveSiteXmlContent = + s""" + | + | + | javax.jdo.option.ConnectionURL + | $metastoreURL + | + | + """.stripMargin + + // Write a hive-site.xml containing a setting of hive.metastore.warehouse.dir. + val hiveSiteDir = Utils.createTempDir() + val file = new File(hiveSiteDir.getCanonicalPath, "hive-site.xml") + val bw = new BufferedWriter(new FileWriter(file)) + bw.write(hiveSiteXmlContent) + bw.close() + + val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + val args = Seq( + "--class", SetMetastoreURLTest.getClass.getName.stripSuffix("$"), + "--name", "SetMetastoreURLTest", + "--master", "local[1]", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", + "--conf", s"spark.sql.test.expectedMetastoreURL=$metastoreURL", + "--conf", s"spark.driver.extraClassPath=${hiveSiteDir.getCanonicalPath}", + "--driver-java-options", "-Dderby.system.durability=test", + unusedJar.toString) + runSparkSubmit(args) + } + // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. // This is copied from org.apache.spark.deploy.SparkSubmitSuite private def runSparkSubmit(args: Seq[String]): Unit = { @@ -275,23 +354,86 @@ class HiveSparkSubmitSuite } } +object SetMetastoreURLTest extends Logging { + def main(args: Array[String]): Unit = { + Utils.configTestLog4j("INFO") + + val sparkConf = new SparkConf(loadDefaults = true) + val builder = SparkSession.builder() + .config(sparkConf) + .config("spark.ui.enabled", "false") + .config("spark.sql.hive.metastore.version", "0.13.1") + // The issue described in SPARK-16901 only appear when + // spark.sql.hive.metastore.jars is not set to builtin. + .config("spark.sql.hive.metastore.jars", "maven") + .enableHiveSupport() + + val spark = builder.getOrCreate() + val expectedMetastoreURL = + spark.conf.get("spark.sql.test.expectedMetastoreURL") + logInfo(s"spark.sql.test.expectedMetastoreURL is $expectedMetastoreURL") + + if (expectedMetastoreURL == null) { + throw new Exception( + s"spark.sql.test.expectedMetastoreURL should be set.") + } + + // HiveSharedState is used when Hive support is enabled. 
+ val actualMetastoreURL = + spark.sharedState.asInstanceOf[HiveSharedState] + .metadataHive + .getConf("javax.jdo.option.ConnectionURL", "this_is_a_wrong_URL") + logInfo(s"javax.jdo.option.ConnectionURL is $actualMetastoreURL") + + if (actualMetastoreURL != expectedMetastoreURL) { + throw new Exception( + s"Expected value of javax.jdo.option.ConnectionURL is $expectedMetastoreURL. But, " + + s"the actual value is $actualMetastoreURL") + } + } +} + object SetWarehouseLocationTest extends Logging { def main(args: Array[String]): Unit = { Utils.configTestLog4j("INFO") - val warehouseLocation = Utils.createTempDir() - warehouseLocation.delete() - val hiveWarehouseLocation = Utils.createTempDir() - hiveWarehouseLocation.delete() - val conf = new SparkConf() - conf.set("spark.ui.enabled", "false") - // We will use the value of spark.sql.warehouse.dir override the - // value of hive.metastore.warehouse.dir. - conf.set("spark.sql.warehouse.dir", warehouseLocation.toString) - conf.set("hive.metastore.warehouse.dir", hiveWarehouseLocation.toString) + val sparkConf = new SparkConf(loadDefaults = true) + val builder = SparkSession.builder() + .config(sparkConf) + .config("spark.ui.enabled", "false") + .enableHiveSupport() + val providedExpectedWarehouseLocation = + sparkConf.getOption("spark.sql.test.expectedWarehouseDir") + + val (sparkSession, expectedWarehouseLocation) = providedExpectedWarehouseLocation match { + case Some(warehouseDir) => + // If spark.sql.test.expectedWarehouseDir is set, the warehouse dir is set + // through spark-summit. So, neither spark.sql.warehouse.dir nor + // hive.metastore.warehouse.dir is set at here. + (builder.getOrCreate(), warehouseDir) + case None => + val warehouseLocation = Utils.createTempDir() + warehouseLocation.delete() + val hiveWarehouseLocation = Utils.createTempDir() + hiveWarehouseLocation.delete() + // If spark.sql.test.expectedWarehouseDir is not set, we will set + // spark.sql.warehouse.dir and hive.metastore.warehouse.dir. + // We are expecting that the value of spark.sql.warehouse.dir will override the + // value of hive.metastore.warehouse.dir. 
+ val session = builder + .config("spark.sql.warehouse.dir", warehouseLocation.toString) + .config("hive.metastore.warehouse.dir", hiveWarehouseLocation.toString) + .getOrCreate() + (session, warehouseLocation.toString) + + } + + if (sparkSession.conf.get("spark.sql.warehouse.dir") != expectedWarehouseLocation) { + throw new Exception( + "spark.sql.warehouse.dir is not set to the expected warehouse location " + + s"$expectedWarehouseLocation.") + } - val sc = new SparkContext(conf) - val sparkSession = SparkSession.withHiveSupport(sc) val catalog = sparkSession.sessionState.catalog sparkSession.sql("drop table if exists testLocation") @@ -302,7 +444,7 @@ object SetWarehouseLocationTest extends Logging { val tableMetadata = catalog.getTableMetadata(TableIdentifier("testLocation", Some("default"))) val expectedLocation = - "file:" + warehouseLocation.toString + "/testlocation" + "file:" + expectedWarehouseLocation.toString + "/testlocation" val actualLocation = tableMetadata.storage.locationUri.get if (actualLocation != expectedLocation) { throw new Exception( @@ -318,7 +460,7 @@ object SetWarehouseLocationTest extends Logging { val tableMetadata = catalog.getTableMetadata(TableIdentifier("testLocation", Some("testLocationDB"))) val expectedLocation = - "file:" + warehouseLocation.toString + "/testlocationdb.db/testlocation" + "file:" + expectedWarehouseLocation.toString + "/testlocationdb.db/testlocation" val actualLocation = tableMetadata.storage.locationUri.get if (actualLocation != expectedLocation) { throw new Exception( @@ -353,7 +495,7 @@ object TemporaryHiveUDFTest extends Logging { """.stripMargin) val source = hiveContext.createDataFrame((1 to 10).map(i => (i, s"str$i"))).toDF("key", "val") - source.registerTempTable("sourceTable") + source.createOrReplaceTempView("sourceTable") // Actually use the loaded UDF. logInfo("Using the UDF.") val result = hiveContext.sql( @@ -391,7 +533,7 @@ object PermanentHiveUDFTest1 extends Logging { """.stripMargin) val source = hiveContext.createDataFrame((1 to 10).map(i => (i, s"str$i"))).toDF("key", "val") - source.registerTempTable("sourceTable") + source.createOrReplaceTempView("sourceTable") // Actually use the loaded UDF. logInfo("Using the UDF.") val result = hiveContext.sql( @@ -423,11 +565,11 @@ object PermanentHiveUDFTest2 extends Logging { val function = CatalogFunction( FunctionIdentifier("example_max"), "org.apache.hadoop.hive.contrib.udaf.example.UDAFExampleMax", - ("JAR" -> jar) :: Nil) + FunctionResource(JarResource, jar) :: Nil) hiveContext.sessionState.catalog.createFunction(function, ignoreIfExists = false) val source = hiveContext.createDataFrame((1 to 10).map(i => (i, s"str$i"))).toDF("key", "val") - source.registerTempTable("sourceTable") + source.createOrReplaceTempView("sourceTable") // Actually use the loaded UDF. logInfo("Using the UDF.") val result = hiveContext.sql( @@ -489,7 +631,7 @@ object SparkSubmitClassLoaderTest extends Logging { """.stripMargin) val source = hiveContext.createDataFrame((1 to 10).map(i => (i, s"str$i"))).toDF("key", "val") - source.registerTempTable("sourceTable") + source.createOrReplaceTempView("sourceTable") // Load a Hive SerDe from the jar. 
logInfo("Creating a Hive table with a SerDe provided in a jar.") hiveContext.sql( @@ -553,7 +695,7 @@ object SparkSQLConfTest extends Logging { object SPARK_9757 extends QueryTest { import org.apache.spark.sql.functions._ - protected var sqlContext: SQLContext = _ + protected var spark: SparkSession = _ def main(args: Array[String]): Unit = { Utils.configTestLog4j("INFO") @@ -565,7 +707,7 @@ object SPARK_9757 extends QueryTest { .set("spark.ui.enabled", "false")) val hiveContext = new TestHiveContext(sparkContext) - sqlContext = hiveContext + spark = hiveContext.sparkSession import hiveContext.implicits._ val dir = Utils.createTempDir() @@ -600,7 +742,7 @@ object SPARK_9757 extends QueryTest { object SPARK_11009 extends QueryTest { import org.apache.spark.sql.functions._ - protected var sqlContext: SQLContext = _ + protected var spark: SparkSession = _ def main(args: Array[String]): Unit = { Utils.configTestLog4j("INFO") @@ -611,10 +753,10 @@ object SPARK_11009 extends QueryTest { .set("spark.sql.shuffle.partitions", "100")) val hiveContext = new TestHiveContext(sparkContext) - sqlContext = hiveContext + spark = hiveContext.sparkSession try { - val df = sqlContext.range(1 << 20) + val df = spark.range(1 << 20) val df2 = df.select((df("id") % 1000).alias("A"), (df("id") / 1000).alias("B")) val ws = Window.partitionBy(df2("A")).orderBy(df2("B")) val df3 = df2.select(df2("A"), df2("B"), row_number().over(ws).alias("rn")).filter("rn < 0") @@ -631,7 +773,7 @@ object SPARK_14244 extends QueryTest { import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions._ - protected var sqlContext: SQLContext = _ + protected var spark: SparkSession = _ def main(args: Array[String]): Unit = { Utils.configTestLog4j("INFO") @@ -642,13 +784,13 @@ object SPARK_14244 extends QueryTest { .set("spark.sql.shuffle.partitions", "100")) val hiveContext = new TestHiveContext(sparkContext) - sqlContext = hiveContext + spark = hiveContext.sparkSession import hiveContext.implicits._ try { val window = Window.orderBy('id) - val df = sqlContext.range(2).select(cume_dist().over(window).as('cdist)).orderBy('cdist) + val df = spark.range(2).select(cume_dist().over(window).as('cdist)).orderBy('cdist) checkAnswer(df, Seq(Row(0.5D), Row(1.0D))) } finally { sparkContext.stop() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala new file mode 100644 index 0000000000000..667a7ddd8bb61 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive + +import org.apache.hadoop.hive.conf.HiveConf.ConfVars + +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.QueryTest + +class HiveUtilsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { + + test("newTemporaryConfiguration overwrites listener configurations") { + Seq(true, false).foreach { useInMemoryDerby => + val conf = HiveUtils.newTemporaryConfiguration(useInMemoryDerby) + assert(conf(ConfVars.METASTORE_PRE_EVENT_LISTENERS.varname) === "") + assert(conf(ConfVars.METASTORE_EVENT_LISTENERS.varname) === "") + assert(conf(ConfVars.METASTORE_END_FUNCTION_LISTENERS.varname) === "") + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index baf34d1cf0aaf..d9ce1c3dc18ff 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -19,12 +19,14 @@ package org.apache.spark.sql.hive import java.io.File -import org.apache.hadoop.hive.conf.HiveConf import org.scalatest.BeforeAndAfter +import org.apache.spark.SparkException import org.apache.spark.sql.{QueryTest, _} -import org.apache.spark.sql.execution.QueryExecutionException +import org.apache.spark.sql.catalyst.parser.ParseException +import org.apache.spark.sql.catalyst.plans.logical.InsertIntoTable import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -32,19 +34,19 @@ case class TestData(key: Int, value: String) case class ThreeCloumntable(key: Int, value: String, key1: String) -class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with BeforeAndAfter { - import hiveContext.implicits._ - import hiveContext.sql +class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with BeforeAndAfter + with SQLTestUtils { + import spark.implicits._ - val testData = hiveContext.sparkContext.parallelize( + override lazy val testData = spark.sparkContext.parallelize( (1 to 100).map(i => TestData(i, i.toString))).toDF() before { // Since every we are doing tests for DDL statements, // it is better to reset before every test. hiveContext.reset() - // Register the testData, which will be used in every test. - testData.registerTempTable("testData") + // Creates a temporary view with testData, which will be used in all tests. 
+ testData.createOrReplaceTempView("testData") } test("insertInto() HiveTable") { @@ -93,10 +95,10 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef test("SPARK-4052: scala.collection.Map as value type of MapType") { val schema = StructType(StructField("m", MapType(StringType, StringType), true) :: Nil) - val rowRDD = hiveContext.sparkContext.parallelize( + val rowRDD = spark.sparkContext.parallelize( (1 to 100).map(i => Row(scala.collection.mutable.HashMap(s"key$i" -> s"value$i")))) - val df = hiveContext.createDataFrame(rowRDD, schema) - df.registerTempTable("tableWithMapValue") + val df = spark.createDataFrame(rowRDD, schema) + df.createOrReplaceTempView("tableWithMapValue") sql("CREATE TABLE hiveTableWithMapValue(m MAP <STRING, STRING>)") sql("INSERT OVERWRITE TABLE hiveTableWithMapValue SELECT m FROM tableWithMapValue") @@ -164,12 +166,80 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef sql("DROP TABLE tmp_table") } + test("INSERT OVERWRITE - partition IF NOT EXISTS") { + withTempDir { tmpDir => + val table = "table_with_partition" + withTable(table) { + val selQuery = s"select c1, p1, p2 from $table" + sql( + s""" + |CREATE TABLE $table(c1 string) + |PARTITIONED by (p1 string,p2 string) + |location '${tmpDir.toURI.toString}' + """.stripMargin) + sql( + s""" + |INSERT OVERWRITE TABLE $table + |partition (p1='a',p2='b') + |SELECT 'blarr' + """.stripMargin) + checkAnswer( + sql(selQuery), + Row("blarr", "a", "b")) + + sql( + s""" + |INSERT OVERWRITE TABLE $table + |partition (p1='a',p2='b') + |SELECT 'blarr2' + """.stripMargin) + checkAnswer( + sql(selQuery), + Row("blarr2", "a", "b")) + + var e = intercept[AnalysisException] { + sql( + s""" + |INSERT OVERWRITE TABLE $table + |partition (p1='a',p2) IF NOT EXISTS + |SELECT 'blarr3', 'newPartition' + """.stripMargin) + } + assert(e.getMessage.contains( + "Dynamic partitions do not support IF NOT EXISTS. Specified partitions with value: [p2]")) + + e = intercept[AnalysisException] { + sql( + s""" + |INSERT OVERWRITE TABLE $table + |partition (p1='a',p2) IF NOT EXISTS + |SELECT 'blarr3', 'b' + """.stripMargin) + } + assert(e.getMessage.contains( + "Dynamic partitions do not support IF NOT EXISTS. Specified partitions with value: [p2]"))
+ + // If the partition already exists, the insert will overwrite the data + // unless users specify IF NOT EXISTS + sql( + s""" + |INSERT OVERWRITE TABLE $table + |partition (p1='a',p2='b') IF NOT EXISTS + |SELECT 'blarr3' + """.stripMargin) + checkAnswer( + sql(selQuery), + Row("blarr2", "a", "b")) + } + } + } + test("Insert ArrayType.containsNull == false") { val schema = StructType(Seq( StructField("a", ArrayType(StringType, containsNull = false)))) - val rowRDD = hiveContext.sparkContext.parallelize((1 to 100).map(i => Row(Seq(s"value$i")))) - val df = hiveContext.createDataFrame(rowRDD, schema) - df.registerTempTable("tableWithArrayValue") + val rowRDD = spark.sparkContext.parallelize((1 to 100).map(i => Row(Seq(s"value$i")))) + val df = spark.createDataFrame(rowRDD, schema) + df.createOrReplaceTempView("tableWithArrayValue") sql("CREATE TABLE hiveTableWithArrayValue(a Array <STRING>)") sql("INSERT OVERWRITE TABLE hiveTableWithArrayValue SELECT a FROM tableWithArrayValue") @@ -183,10 +253,10 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef test("Insert MapType.valueContainsNull == false") { val schema = StructType(Seq( StructField("m", MapType(StringType, StringType, valueContainsNull = false)))) - val rowRDD = hiveContext.sparkContext.parallelize( + val rowRDD = spark.sparkContext.parallelize( (1 to 100).map(i => Row(Map(s"key$i" -> s"value$i")))) - val df = hiveContext.createDataFrame(rowRDD, schema) - df.registerTempTable("tableWithMapValue") + val df = spark.createDataFrame(rowRDD, schema) + df.createOrReplaceTempView("tableWithMapValue") sql("CREATE TABLE hiveTableWithMapValue(m Map <STRING, STRING>)") sql("INSERT OVERWRITE TABLE hiveTableWithMapValue SELECT m FROM tableWithMapValue") @@ -200,10 +270,10 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef test("Insert StructType.fields.exists(_.nullable == false)") { val schema = StructType(Seq( StructField("s", StructType(Seq(StructField("f", StringType, nullable = false)))))) - val rowRDD = hiveContext.sparkContext.parallelize( + val rowRDD = spark.sparkContext.parallelize( (1 to 100).map(i => Row(Row(s"value$i")))) - val df = hiveContext.createDataFrame(rowRDD, schema) - df.registerTempTable("tableWithStructValue") + val df = spark.createDataFrame(rowRDD, schema) + df.createOrReplaceTempView("tableWithStructValue") sql("CREATE TABLE hiveTableWithStructValue(s Struct <f: STRING>)") sql("INSERT OVERWRITE TABLE hiveTableWithStructValue SELECT s FROM tableWithStructValue") @@ -213,4 +283,238 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef sql("DROP TABLE hiveTableWithStructValue") } + + test("Reject partitioning that does not match table") { + withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) { + sql("CREATE TABLE partitioned (id bigint, data string) PARTITIONED BY (part string)") + val data = (1 to 10).map(i => (i, s"data-$i", if ((i % 2) == 0) "even" else "odd")) + .toDF("id", "data", "part") + + intercept[AnalysisException] { + // cannot partition by 2 fields when there is only one in the table definition + data.write.partitionBy("part", "data").insertInto("partitioned") + } + } + } + + test("Test partition mode = strict") { + withSQLConf(("hive.exec.dynamic.partition.mode", "strict")) { + sql("CREATE TABLE partitioned (id bigint, data string) PARTITIONED BY (part string)") + val data = (1 to 10).map(i => (i, s"data-$i", if ((i % 2) == 0) "even" else "odd")) + .toDF("id", "data", "part") + +
intercept[SparkException] { + data.write.insertInto("partitioned") + } + } + } + + test("Detect table partitioning") { + withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) { + sql("CREATE TABLE source (id bigint, data string, part string)") + val data = (1 to 10).map(i => (i, s"data-$i", if ((i % 2) == 0) "even" else "odd")).toDF() + + data.write.insertInto("source") + checkAnswer(sql("SELECT * FROM source"), data.collect().toSeq) + + sql("CREATE TABLE partitioned (id bigint, data string) PARTITIONED BY (part string)") + // this will pick up the output partitioning from the table definition + spark.table("source").write.insertInto("partitioned") + + checkAnswer(sql("SELECT * FROM partitioned"), data.collect().toSeq) + } + } + + private def testPartitionedHiveSerDeTable(testName: String)(f: String => Unit): Unit = { + test(s"Hive SerDe table - $testName") { + val hiveTable = "hive_table" + + withTable(hiveTable) { + withSQLConf("hive.exec.dynamic.partition.mode" -> "nonstrict") { + sql( + s""" + |CREATE TABLE $hiveTable (a INT, d INT) + |PARTITIONED BY (b INT, c INT) STORED AS TEXTFILE + """.stripMargin) + f(hiveTable) + } + } + } + } + + private def testPartitionedDataSourceTable(testName: String)(f: String => Unit): Unit = { + test(s"Data source table - $testName") { + val dsTable = "ds_table" + + withTable(dsTable) { + sql( + s""" + |CREATE TABLE $dsTable (a INT, b INT, c INT, d INT) + |USING PARQUET PARTITIONED BY (b, c) + """.stripMargin) + f(dsTable) + } + } + } + + private def testPartitionedTable(testName: String)(f: String => Unit): Unit = { + testPartitionedHiveSerDeTable(testName)(f) + testPartitionedDataSourceTable(testName)(f) + } + + testPartitionedTable("partitionBy() can't be used together with insertInto()") { tableName => + val cause = intercept[AnalysisException] { + Seq((1, 2, 3, 4)).toDF("a", "b", "c", "d").write.partitionBy("b", "c").insertInto(tableName) + } + + assert(cause.getMessage.contains("insertInto() can't be used together with partitionBy().")) + } + + test("InsertIntoTable#resolved should include dynamic partitions") { + withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) { + sql("CREATE TABLE partitioned (id bigint, data string) PARTITIONED BY (part string)") + val data = (1 to 10).map(i => (i.toLong, s"data-$i")).toDF("id", "data") + + val logical = InsertIntoTable(spark.table("partitioned").logicalPlan, + Map("part" -> None), data.logicalPlan, overwrite = false, ifNotExists = false) + assert(!logical.resolved, "Should not resolve: missing partition data") + } + } + + testPartitionedTable( + "SPARK-16036: better error message when insert into a table with mismatch schema") { + tableName => + val e = intercept[AnalysisException] { + sql(s"INSERT INTO TABLE $tableName PARTITION(b=1, c=2) SELECT 1, 2, 3") + } + assert(e.message.contains("the number of columns are different")) + } + + testPartitionedTable("SPARK-16037: INSERT statement should match columns by position") { + tableName => + withSQLConf("hive.exec.dynamic.partition.mode" -> "nonstrict") { + sql(s"INSERT INTO TABLE $tableName SELECT 1, 4, 2 AS c, 3 AS b") + checkAnswer(sql(s"SELECT a, b, c, d FROM $tableName"), Row(1, 2, 3, 4)) + sql(s"INSERT OVERWRITE TABLE $tableName SELECT 1, 4, 2, 3") + checkAnswer(sql(s"SELECT a, b, c, 4 FROM $tableName"), Row(1, 2, 3, 4)) + } + } + + testPartitionedTable("INSERT INTO a partitioned table (semantic and error handling)") { + tableName => + withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) { + sql(s"INSERT INTO TABLE 
$tableName PARTITION (b=2, c=3) SELECT 1, 4") + + sql(s"INSERT INTO TABLE $tableName PARTITION (b=6, c=7) SELECT 5, 8") + + sql(s"INSERT INTO TABLE $tableName PARTITION (c=11, b=10) SELECT 9, 12") + + // c is defined twice. Parser will complain. + intercept[ParseException] { + sql(s"INSERT INTO TABLE $tableName PARTITION (b=14, c=15, c=16) SELECT 13") + } + + // d is not a partitioning column. + intercept[AnalysisException] { + sql(s"INSERT INTO TABLE $tableName PARTITION (b=14, c=15, d=16) SELECT 13, 14") + } + + // d is not a partitioning column. The total number of columns is correct. + intercept[AnalysisException] { + sql(s"INSERT INTO TABLE $tableName PARTITION (b=14, c=15, d=16) SELECT 13") + } + + // The data is missing a column. + intercept[AnalysisException] { + sql(s"INSERT INTO TABLE $tableName PARTITION (c=15, b=16) SELECT 13") + } + + // d is not a partitioning column. + intercept[AnalysisException] { + sql(s"INSERT INTO TABLE $tableName PARTITION (b=15, d=15) SELECT 13, 14") + } + + // The statement is missing a column. + intercept[AnalysisException] { + sql(s"INSERT INTO TABLE $tableName PARTITION (b=15) SELECT 13, 14") + } + + // The statement is missing a column. + intercept[AnalysisException] { + sql(s"INSERT INTO TABLE $tableName PARTITION (b=15) SELECT 13, 14, 16") + } + + sql(s"INSERT INTO TABLE $tableName PARTITION (b=14, c) SELECT 13, 16, 15") + + // Dynamic partitioning columns need to be after static partitioning columns. + intercept[AnalysisException] { + sql(s"INSERT INTO TABLE $tableName PARTITION (b, c=19) SELECT 17, 20, 18") + } + + sql(s"INSERT INTO TABLE $tableName PARTITION (b, c) SELECT 17, 20, 18, 19") + + sql(s"INSERT INTO TABLE $tableName PARTITION (c, b) SELECT 21, 24, 22, 23") + + sql(s"INSERT INTO TABLE $tableName SELECT 25, 28, 26, 27") + + checkAnswer( + sql(s"SELECT a, b, c, d FROM $tableName"), + Row(1, 2, 3, 4) :: + Row(5, 6, 7, 8) :: + Row(9, 10, 11, 12) :: + Row(13, 14, 15, 16) :: + Row(17, 18, 19, 20) :: + Row(21, 22, 23, 24) :: + Row(25, 26, 27, 28) :: Nil + ) + } + } + + testPartitionedTable("insertInto() should match columns by position and ignore column names") { + tableName => + withSQLConf("hive.exec.dynamic.partition.mode" -> "nonstrict") { + // Columns `df.c` and `df.d` are resolved by position, and thus mapped to partition columns + // `b` and `c` of the target table. + val df = Seq((1, 2, 3, 4)).toDF("a", "b", "c", "d") + df.write.insertInto(tableName) + + checkAnswer( + sql(s"SELECT a, b, c, d FROM $tableName"), + Row(1, 3, 4, 2) + ) + } + } + + testPartitionedTable("insertInto() should match unnamed columns by position") { + tableName => + withSQLConf("hive.exec.dynamic.partition.mode" -> "nonstrict") { + // Columns `c + 1` and `d + 1` are resolved by position, and thus mapped to partition + // columns `b` and `c` of the target table. 
+ val df = Seq((1, 2, 3, 4)).toDF("a", "b", "c", "d") + df.select('a + 1, 'b + 1, 'c + 1, 'd + 1).write.insertInto(tableName) + + checkAnswer( + sql(s"SELECT a, b, c, d FROM $tableName"), + Row(2, 4, 5, 3) + ) + } + } + + testPartitionedTable("insertInto() should reject missing columns") { + tableName => + sql("CREATE TABLE t (a INT, b INT)") + + intercept[AnalysisException] { + spark.table("t").write.insertInto(tableName) + } + } + + testPartitionedTable("insertInto() should reject extra columns") { + tableName => + sql("CREATE TABLE t (a INT, b INT, c INT, d INT, e INT)") + + intercept[AnalysisException] { + spark.table("t").write.insertInto(tableName) + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala index e8188e5f02f28..8dc756b9380c3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala @@ -33,7 +33,7 @@ class ListTablesSuite extends QueryTest with TestHiveSingleton with BeforeAndAft override def beforeAll(): Unit = { super.beforeAll() // The catalog in HiveContext is a case insensitive one. - sessionState.catalog.createTempTable( + sessionState.catalog.createTempView( "ListTablesSuiteTable", df.logicalPlan, overrideIfExists = true) sql("CREATE TABLE HiveListTablesSuiteTable (key int, value string)") sql("CREATE DATABASE IF NOT EXISTS ListTablesSuiteDB") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index c3a9f2479ce74..28c8139072ead 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser -import org.apache.spark.sql.execution.command.CreateDataSourceTableUtils +import org.apache.spark.sql.execution.command.CreateDataSourceTableUtils._ import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf @@ -40,7 +40,7 @@ import org.apache.spark.util.Utils */ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { import hiveContext._ - import hiveContext.implicits._ + import spark.implicits._ var jsonFilePath: String = _ @@ -79,8 +79,8 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv |) """.stripMargin) - withTempTable("expectedJsonTable") { - read.json(jsonFilePath).registerTempTable("expectedJsonTable") + withTempView("expectedJsonTable") { + read.json(jsonFilePath).createOrReplaceTempView("expectedJsonTable") checkAnswer( sql("SELECT a, b, `c_!@(3)`, ``.`d!`, ``.`=` FROM jsonTable"), sql("SELECT a, b, `c_!@(3)`, ``.`d!`, ``.`=` FROM expectedJsonTable")) @@ -109,8 +109,8 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv assert(expectedSchema === table("jsonTable").schema) - withTempTable("expectedJsonTable") { - read.json(jsonFilePath).registerTempTable("expectedJsonTable") + withTempView("expectedJsonTable") { + 
read.json(jsonFilePath).createOrReplaceTempView("expectedJsonTable") checkAnswer( sql("SELECT b, ``.`=` FROM jsonTable"), sql("SELECT b, ``.`=` FROM expectedJsonTable")) @@ -247,21 +247,21 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv |) """.stripMargin) - withTempTable("expectedJsonTable") { - read.json(jsonFilePath).registerTempTable("expectedJsonTable") + withTempView("expectedJsonTable") { + read.json(jsonFilePath).createOrReplaceTempView("expectedJsonTable") checkAnswer( sql("SELECT * FROM jsonTable"), sql("SELECT `c_!@(3)` FROM expectedJsonTable").collect().toSeq) // Discard the cached relation. - sessionState.invalidateTable("jsonTable") + sessionState.refreshTable("jsonTable") checkAnswer( sql("SELECT * FROM jsonTable"), sql("SELECT `c_!@(3)` FROM expectedJsonTable").collect().toSeq) - sessionState.invalidateTable("jsonTable") + sessionState.refreshTable("jsonTable") val expectedSchema = StructType(StructField("c_!@(3)", IntegerType, true) :: Nil) assert(expectedSchema === table("jsonTable").schema) @@ -333,7 +333,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv }.getMessage assert( - message.contains("Table ctasJsonTable already exists."), + message.contains("Table default.ctasJsonTable already exists."), "We should complain that ctasJsonTable already exists") // The following statement should be fine if it has IF NOT EXISTS. @@ -349,7 +349,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv """.stripMargin) // Discard the cached relation. - sessionState.invalidateTable("ctasJsonTable") + sessionState.refreshTable("ctasJsonTable") // Schema should not be changed. assert(table("ctasJsonTable").schema === table("jsonTable").schema) @@ -374,7 +374,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv val expectedPath = sessionState.catalog.hiveDefaultTableFilePath(TableIdentifier("ctasJsonTable")) val filesystemPath = new Path(expectedPath) - val fs = filesystemPath.getFileSystem(sqlContext.sessionState.newHadoopConf()) + val fs = filesystemPath.getFileSystem(spark.sessionState.newHadoopConf()) if (fs.exists(filesystemPath)) fs.delete(filesystemPath, true) // It is a managed table when we do not specify the location. 
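Alongside the view API rename, this suite replaces `sessionState.invalidateTable` with `refreshTable` (on the session state or via `spark.catalog`) and `withTempTable` with `withTempView`. A short sketch of the refresh pattern these tests now exercise, assuming an existing `SparkSession` named `spark` and an illustrative table name:

```scala
// Hypothetical table name; the suite itself uses tables such as jsonTable and ctasJsonTable.
val tableName = "jsonTable"

// Drop cached metadata and data for the table so the next scan picks up changes made
// outside of Spark, for example files rewritten on disk.
spark.catalog.refreshTable(tableName)
spark.table(tableName).show()
```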
@@ -424,7 +424,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv sql("SELECT * FROM savedJsonTable tmp where tmp.a > 5"), (6 to 10).map(i => Row(i, s"str$i"))) - sessionState.invalidateTable("savedJsonTable") + sessionState.refreshTable("savedJsonTable") checkAnswer( sql("SELECT * FROM savedJsonTable where savedJsonTable.a < 5"), @@ -548,13 +548,13 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv }.getMessage.contains("Unable to infer schema"), "We should complain that path is not specified.") - sql("DROP TABLE createdJsonTable") + sql("DROP TABLE IF EXISTS createdJsonTable") } test("scan a parquet table created through a CTAS statement") { withSQLConf(HiveUtils.CONVERT_METASTORE_PARQUET.key -> "true") { - withTempTable("jt") { - (1 to 10).map(i => i -> s"str$i").toDF("a", "b").registerTempTable("jt") + withTempView("jt") { + (1 to 10).map(i => i -> s"str$i").toDF("a", "b").createOrReplaceTempView("jt") withTable("test_parquet_ctas") { sql( @@ -622,7 +622,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv .mode(SaveMode.Append) .saveAsTable("arrayInParquet") - sessionState.refreshTable("arrayInParquet") + sparkSession.catalog.refreshTable("arrayInParquet") checkAnswer( sql("SELECT a FROM arrayInParquet"), @@ -681,7 +681,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv .mode(SaveMode.Append) .saveAsTable("mapInParquet") - sessionState.refreshTable("mapInParquet") + sparkSession.catalog.refreshTable("mapInParquet") checkAnswer( sql("SELECT a FROM mapInParquet"), @@ -700,8 +700,8 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv val schema = StructType((1 to 5000).map(i => StructField(s"c_$i", StringType, true))) // Manually create a metastore data source table. - CreateDataSourceTableUtils.createDataSourceTable( - sparkSession = sqlContext.sparkSession, + createDataSourceTable( + sparkSession = spark, tableIdent = TableIdentifier("wide_schema"), userSpecifiedSchema = Some(schema), partitionColumns = Array.empty[String], @@ -710,7 +710,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv options = Map("path" -> tempDir.getCanonicalPath), isExternal = false) - sessionState.invalidateTable("wide_schema") + sessionState.refreshTable("wide_schema") val actualSchema = table("wide_schema").schema assert(schema === actualSchema) @@ -732,19 +732,23 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv inputFormat = None, outputFormat = None, serde = None, + compressed = false, serdeProperties = Map( "path" -> sessionState.catalog.hiveDefaultTableFilePath(TableIdentifier(tableName))) ), properties = Map( - "spark.sql.sources.provider" -> "json", - "spark.sql.sources.schema" -> schema.json, + DATASOURCE_PROVIDER -> "json", + DATASOURCE_SCHEMA -> schema.json, "EXTERNAL" -> "FALSE")) sharedState.externalCatalog.createTable("default", hiveTable, ignoreIfExists = false) - sessionState.invalidateTable(tableName) + sessionState.refreshTable(tableName) val actualSchema = table(tableName).schema assert(schema === actualSchema) + + // Checks the DESCRIBE output. 
+ checkAnswer(sql("DESCRIBE spark6655"), Row("int", "int", null) :: Nil) } } @@ -754,17 +758,17 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv withTable(tableName) { df.write.format("parquet").partitionBy("d", "b").saveAsTable(tableName) - sessionState.invalidateTable(tableName) + sessionState.refreshTable(tableName) val metastoreTable = sharedState.externalCatalog.getTable("default", tableName) val expectedPartitionColumns = StructType(df.schema("d") :: df.schema("b") :: Nil) - val numPartCols = metastoreTable.properties("spark.sql.sources.schema.numPartCols").toInt + val numPartCols = metastoreTable.properties(DATASOURCE_SCHEMA_NUMPARTCOLS).toInt assert(numPartCols == 2) val actualPartitionColumns = StructType( (0 until numPartCols).map { index => - df.schema(metastoreTable.properties(s"spark.sql.sources.schema.partCol.$index")) + df.schema(metastoreTable.properties(s"$DATASOURCE_SCHEMA_PARTCOL_PREFIX$index")) }) // Make sure partition columns are correctly stored in metastore. assert( @@ -789,24 +793,24 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv .bucketBy(8, "d", "b") .sortBy("c") .saveAsTable(tableName) - sessionState.invalidateTable(tableName) + sessionState.refreshTable(tableName) val metastoreTable = sharedState.externalCatalog.getTable("default", tableName) val expectedBucketByColumns = StructType(df.schema("d") :: df.schema("b") :: Nil) val expectedSortByColumns = StructType(df.schema("c") :: Nil) - val numBuckets = metastoreTable.properties("spark.sql.sources.schema.numBuckets").toInt + val numBuckets = metastoreTable.properties(DATASOURCE_SCHEMA_NUMBUCKETS).toInt assert(numBuckets == 8) - val numBucketCols = metastoreTable.properties("spark.sql.sources.schema.numBucketCols").toInt + val numBucketCols = metastoreTable.properties(DATASOURCE_SCHEMA_NUMBUCKETCOLS).toInt assert(numBucketCols == 2) - val numSortCols = metastoreTable.properties("spark.sql.sources.schema.numSortCols").toInt + val numSortCols = metastoreTable.properties(DATASOURCE_SCHEMA_NUMSORTCOLS).toInt assert(numSortCols == 1) val actualBucketByColumns = StructType( (0 until numBucketCols).map { index => - df.schema(metastoreTable.properties(s"spark.sql.sources.schema.bucketCol.$index")) + df.schema(metastoreTable.properties(s"$DATASOURCE_SCHEMA_BUCKETCOL_PREFIX$index")) }) // Make sure bucketBy columns are correctly stored in metastore. assert( @@ -817,7 +821,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv val actualSortByColumns = StructType( (0 until numSortCols).map { index => - df.schema(metastoreTable.properties(s"spark.sql.sources.schema.sortCol.$index")) + df.schema(metastoreTable.properties(s"$DATASOURCE_SCHEMA_SORTCOL_PREFIX$index")) }) // Make sure sortBy columns are correctly stored in metastore. assert( @@ -887,21 +891,96 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv } } + test("append table using different formats") { + def createDF(from: Int, to: Int): DataFrame = { + (from to to).map(i => i -> s"str$i").toDF("c1", "c2") + } + + withTable("appendOrcToParquet") { + createDF(0, 9).write.format("parquet").saveAsTable("appendOrcToParquet") + val e = intercept[AnalysisException] { + createDF(10, 19).write.mode(SaveMode.Append).format("orc").saveAsTable("appendOrcToParquet") + } + assert(e.getMessage.contains( + "The file format of the existing table default.appendOrcToParquet " + + "is `org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat`. 
" + + "It doesn't match the specified format `orc`")) + } + + withTable("appendParquetToJson") { + createDF(0, 9).write.format("json").saveAsTable("appendParquetToJson") + val e = intercept[AnalysisException] { + createDF(10, 19).write.mode(SaveMode.Append).format("parquet") + .saveAsTable("appendParquetToJson") + } + assert(e.getMessage.contains( + "The file format of the existing table default.appendParquetToJson " + + "is `org.apache.spark.sql.execution.datasources.json.JsonFileFormat`. " + + "It doesn't match the specified format `parquet`")) + } + + withTable("appendTextToJson") { + createDF(0, 9).write.format("json").saveAsTable("appendTextToJson") + val e = intercept[AnalysisException] { + createDF(10, 19).write.mode(SaveMode.Append).format("text") + .saveAsTable("appendTextToJson") + } + assert(e.getMessage.contains( + "The file format of the existing table default.appendTextToJson is " + + "`org.apache.spark.sql.execution.datasources.json.JsonFileFormat`. " + + "It doesn't match the specified format `text`")) + } + } + + test("append a table using the same formats but different names") { + def createDF(from: Int, to: Int): DataFrame = { + (from to to).map(i => i -> s"str$i").toDF("c1", "c2") + } + + withTable("appendParquet") { + createDF(0, 9).write.format("parquet").saveAsTable("appendParquet") + createDF(10, 19).write.mode(SaveMode.Append).format("org.apache.spark.sql.parquet") + .saveAsTable("appendParquet") + checkAnswer( + sql("SELECT p.c1, p.c2 FROM appendParquet p WHERE p.c1 > 5"), + (6 to 19).map(i => Row(i, s"str$i"))) + } + + withTable("appendParquet") { + createDF(0, 9).write.format("org.apache.spark.sql.parquet").saveAsTable("appendParquet") + createDF(10, 19).write.mode(SaveMode.Append).format("parquet").saveAsTable("appendParquet") + checkAnswer( + sql("SELECT p.c1, p.c2 FROM appendParquet p WHERE p.c1 > 5"), + (6 to 19).map(i => Row(i, s"str$i"))) + } + + withTable("appendParquet") { + createDF(0, 9).write.format("org.apache.spark.sql.parquet.DefaultSource") + .saveAsTable("appendParquet") + createDF(10, 19).write.mode(SaveMode.Append) + .format("org.apache.spark.sql.execution.datasources.parquet.DefaultSource") + .saveAsTable("appendParquet") + checkAnswer( + sql("SELECT p.c1, p.c2 FROM appendParquet p WHERE p.c1 > 5"), + (6 to 19).map(i => Row(i, s"str$i"))) + } + } + test("SPARK-8156:create table to specific database by 'use dbname' ") { val df = (1 to 3).map(i => (i, s"val_$i", i * 2)).toDF("a", "b", "c") - sqlContext.sql("""create database if not exists testdb8156""") - sqlContext.sql("""use testdb8156""") + spark.sql("""create database if not exists testdb8156""") + spark.sql("""use testdb8156""") df.write .format("parquet") .mode(SaveMode.Overwrite) .saveAsTable("ttt3") checkAnswer( - sqlContext.sql("show TABLES in testdb8156").filter("tableName = 'ttt3'"), + spark.sql("show TABLES in testdb8156").filter("tableName = 'ttt3'"), Row("ttt3", false)) - sqlContext.sql("""use default""") - sqlContext.sql("""drop database if exists testdb8156 CASCADE""") + spark.sql("""use default""") + spark.sql("""drop database if exists testdb8156 CASCADE""") } @@ -909,8 +988,8 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv withTempDir { tempPath => val schema = StructType((1 to 5).map(i => StructField(s"c_$i", StringType))) - CreateDataSourceTableUtils.createDataSourceTable( - sparkSession = sqlContext.sparkSession, + createDataSourceTable( + sparkSession = spark, tableIdent = TableIdentifier("not_skip_hive_metadata"), userSpecifiedSchema = 
Some(schema), partitionColumns = Array.empty[String], @@ -924,8 +1003,8 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv assert(sharedState.externalCatalog.getTable("default", "not_skip_hive_metadata").schema .forall(column => CatalystSqlParser.parseDataType(column.dataType) == StringType)) - CreateDataSourceTableUtils.createDataSourceTable( - sparkSession = sqlContext.sparkSession, + createDataSourceTable( + sparkSession = spark, tableIdent = TableIdentifier("skip_hive_metadata"), userSpecifiedSchema = Some(schema), partitionColumns = Array.empty[String], @@ -943,7 +1022,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv } test("CTAS: persisted partitioned data source table") { - withTempDir { dir => + withTempPath { dir => withTable("t") { val path = dir.getCanonicalPath @@ -956,10 +1035,10 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv ) val metastoreTable = sharedState.externalCatalog.getTable("default", "t") - assert(metastoreTable.properties("spark.sql.sources.schema.numPartCols").toInt === 1) - assert(!metastoreTable.properties.contains("spark.sql.sources.schema.numBuckets")) - assert(!metastoreTable.properties.contains("spark.sql.sources.schema.numBucketCols")) - assert(!metastoreTable.properties.contains("spark.sql.sources.schema.numSortCols")) + assert(metastoreTable.properties(DATASOURCE_SCHEMA_NUMPARTCOLS).toInt === 1) + assert(!metastoreTable.properties.contains(DATASOURCE_SCHEMA_NUMBUCKETS)) + assert(!metastoreTable.properties.contains(DATASOURCE_SCHEMA_NUMBUCKETCOLS)) + assert(!metastoreTable.properties.contains(DATASOURCE_SCHEMA_NUMSORTCOLS)) checkAnswer(table("t"), Row(2, 1)) } @@ -967,7 +1046,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv } test("CTAS: persisted bucketed data source table") { - withTempDir { dir => + withTempPath { dir => withTable("t") { val path = dir.getCanonicalPath @@ -980,14 +1059,16 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv ) val metastoreTable = sharedState.externalCatalog.getTable("default", "t") - assert(!metastoreTable.properties.contains("spark.sql.sources.schema.numPartCols")) - assert(metastoreTable.properties("spark.sql.sources.schema.numBuckets").toInt === 2) - assert(metastoreTable.properties("spark.sql.sources.schema.numBucketCols").toInt === 1) - assert(metastoreTable.properties("spark.sql.sources.schema.numSortCols").toInt === 1) + assert(!metastoreTable.properties.contains(DATASOURCE_SCHEMA_NUMPARTCOLS)) + assert(metastoreTable.properties(DATASOURCE_SCHEMA_NUMBUCKETS).toInt === 2) + assert(metastoreTable.properties(DATASOURCE_SCHEMA_NUMBUCKETCOLS).toInt === 1) + assert(metastoreTable.properties(DATASOURCE_SCHEMA_NUMSORTCOLS).toInt === 1) checkAnswer(table("t"), Row(1, 2)) } + } + withTempPath { dir => withTable("t") { val path = dir.getCanonicalPath @@ -1000,10 +1081,10 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv ) val metastoreTable = sharedState.externalCatalog.getTable("default", "t") - assert(!metastoreTable.properties.contains("spark.sql.sources.schema.numPartCols")) - assert(metastoreTable.properties("spark.sql.sources.schema.numBuckets").toInt === 2) - assert(metastoreTable.properties("spark.sql.sources.schema.numBucketCols").toInt === 1) - assert(!metastoreTable.properties.contains("spark.sql.sources.schema.numSortCols")) + assert(!metastoreTable.properties.contains(DATASOURCE_SCHEMA_NUMPARTCOLS)) + 
assert(metastoreTable.properties(DATASOURCE_SCHEMA_NUMBUCKETS).toInt === 2) + assert(metastoreTable.properties(DATASOURCE_SCHEMA_NUMBUCKETCOLS).toInt === 1) + assert(!metastoreTable.properties.contains(DATASOURCE_SCHEMA_NUMSORTCOLS)) checkAnswer(table("t"), Row(1, 2)) } @@ -1011,7 +1092,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv } test("CTAS: persisted partitioned bucketed data source table") { - withTempDir { dir => + withTempPath { dir => withTable("t") { val path = dir.getCanonicalPath @@ -1025,13 +1106,90 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv ) val metastoreTable = sharedState.externalCatalog.getTable("default", "t") - assert(metastoreTable.properties("spark.sql.sources.schema.numPartCols").toInt === 1) - assert(metastoreTable.properties("spark.sql.sources.schema.numBuckets").toInt === 2) - assert(metastoreTable.properties("spark.sql.sources.schema.numBucketCols").toInt === 1) - assert(metastoreTable.properties("spark.sql.sources.schema.numSortCols").toInt === 1) + assert(metastoreTable.properties(DATASOURCE_SCHEMA_NUMPARTCOLS).toInt === 1) + assert(metastoreTable.properties(DATASOURCE_SCHEMA_NUMBUCKETS).toInt === 2) + assert(metastoreTable.properties(DATASOURCE_SCHEMA_NUMBUCKETCOLS).toInt === 1) + assert(metastoreTable.properties(DATASOURCE_SCHEMA_NUMSORTCOLS).toInt === 1) checkAnswer(table("t"), Row(2, 3, 1)) } } } + + test("saveAsTable[append]: the column order doesn't matter") { + withTable("saveAsTable_column_order") { + Seq((1, 2)).toDF("i", "j").write.saveAsTable("saveAsTable_column_order") + Seq((3, 4)).toDF("j", "i").write.mode("append").saveAsTable("saveAsTable_column_order") + checkAnswer( + table("saveAsTable_column_order"), + Seq((1, 2), (4, 3)).toDF("i", "j")) + } + } + + test("saveAsTable[append]: mismatch column names") { + withTable("saveAsTable_mismatch_column_names") { + Seq((1, 2)).toDF("i", "j").write.saveAsTable("saveAsTable_mismatch_column_names") + val e = intercept[AnalysisException] { + Seq((3, 4)).toDF("i", "k") + .write.mode("append").saveAsTable("saveAsTable_mismatch_column_names") + } + assert(e.getMessage.contains("cannot resolve")) + } + } + + test("saveAsTable[append]: too many columns") { + withTable("saveAsTable_too_many_columns") { + Seq((1, 2)).toDF("i", "j").write.saveAsTable("saveAsTable_too_many_columns") + val e = intercept[AnalysisException] { + Seq((3, 4, 5)).toDF("i", "j", "k") + .write.mode("append").saveAsTable("saveAsTable_too_many_columns") + } + assert(e.getMessage.contains("doesn't match")) + } + } + + test("saveAsTable[append]: less columns") { + withTable("saveAsTable_less_columns") { + Seq((1, 2)).toDF("i", "j").write.saveAsTable("saveAsTable_less_columns") + val e = intercept[AnalysisException] { + Seq((4)).toDF("j") + .write.mode("append").saveAsTable("saveAsTable_less_columns") + } + assert(e.getMessage.contains("doesn't match")) + } + } + + test("SPARK-15025: create datasource table with path with select") { + withTempPath { dir => + withTable("t") { + val path = dir.getCanonicalPath + + sql( + s"""CREATE TABLE t USING PARQUET + |OPTIONS (PATH '$path') + |AS SELECT 1 AS a, 2 AS b, 3 AS c + """.stripMargin + ) + sql("insert into t values (2, 3, 4)") + checkAnswer(table("t"), Seq(Row(1, 2, 3), Row(2, 3, 4))) + val catalogTable = sharedState.externalCatalog.getTable("default", "t") + // there should not be a lowercase key 'path' now + assert(catalogTable.storage.serdeProperties.get("path").isEmpty) + 
assert(catalogTable.storage.serdeProperties.get("PATH").isDefined) + } + } + } + + test("SPARK-15269 external data source table creation") { + withTempPath { dir => + val path = dir.getCanonicalPath + spark.range(1).write.json(path) + + withTable("t") { + sql(s"CREATE TABLE t USING json OPTIONS (PATH '$path')") + sql("DROP TABLE t") + sql(s"CREATE TABLE t USING json AS SELECT 1 AS c") + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreRelationSuite.scala new file mode 100644 index 0000000000000..c6711c3236250 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreRelationSuite.scala @@ -0,0 +1,56 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.{CatalogColumn, CatalogStorageFormat, CatalogTable, CatalogTableType} +import org.apache.spark.sql.execution.command.DDLUtils +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SQLTestUtils + +class MetastoreRelationSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { + test("makeCopy and toJSON should work") { + val table = CatalogTable( + identifier = TableIdentifier("test", Some("db")), + tableType = CatalogTableType.VIEW, + storage = CatalogStorageFormat.empty, + schema = Seq.empty[CatalogColumn]) + val relation = MetastoreRelation("db", "test", None)(table, null, null) + + // No exception should be thrown + relation.makeCopy(Array("db", "test", None)) + // No exception should be thrown + relation.toJSON + } + + test("SPARK-17409: Do Not Optimize Query in CTAS (Hive Serde Table) More Than Once") { + withTable("bar") { + withTempView("foo") { + sql("select 0 as id").createOrReplaceTempView("foo") + // If we optimize the query in CTAS more than once, the following saveAsTable will fail + // with the error: `GROUP BY position 0 is not in select list (valid range is [1, 1])` + sql("CREATE TABLE bar AS SELECT * FROM foo group by id") + checkAnswer(spark.table("bar"), Row(0) :: Nil) + val tableMetadata = spark.sessionState.catalog.getTableMetadata(TableIdentifier("bar")) + assert(!DDLUtils.isDatasourceTable(tableMetadata), + "the expected table is a Hive serde table") + } + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala index 850cb1eda5809..83f1b192f7c94 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala +++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala @@ -22,26 +22,33 @@ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { - private lazy val df = sqlContext.range(10).coalesce(1).toDF() + private lazy val df = spark.range(10).coalesce(1).toDF() private def checkTablePath(dbName: String, tableName: String): Unit = { - val metastoreTable = hiveContext.sharedState.externalCatalog.getTable(dbName, tableName) + val metastoreTable = spark.sharedState.externalCatalog.getTable(dbName, tableName) val expectedPath = - hiveContext.sharedState.externalCatalog.getDatabase(dbName).locationUri + "/" + tableName + spark.sharedState.externalCatalog.getDatabase(dbName).locationUri + "/" + tableName assert(metastoreTable.storage.serdeProperties("path") === expectedPath) } + private def getTableNames(dbName: Option[String] = None): Array[String] = { + dbName match { + case Some(db) => spark.catalog.listTables(db).collect().map(_.name) + case None => spark.catalog.listTables().collect().map(_.name) + } + } + test(s"saveAsTable() to non-default database - with USE - Overwrite") { withTempDatabase { db => activateDatabase(db) { df.write.mode(SaveMode.Overwrite).saveAsTable("t") - assert(sqlContext.tableNames().contains("t")) - checkAnswer(sqlContext.table("t"), df) + assert(getTableNames().contains("t")) + checkAnswer(spark.table("t"), df) } - assert(sqlContext.tableNames(db).contains("t")) - checkAnswer(sqlContext.table(s"$db.t"), df) + assert(getTableNames(Option(db)).contains("t")) + checkAnswer(spark.table(s"$db.t"), df) checkTablePath(db, "t") } @@ -50,8 +57,8 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle test(s"saveAsTable() to non-default database - without USE - Overwrite") { withTempDatabase { db => df.write.mode(SaveMode.Overwrite).saveAsTable(s"$db.t") - assert(sqlContext.tableNames(db).contains("t")) - checkAnswer(sqlContext.table(s"$db.t"), df) + assert(getTableNames(Option(db)).contains("t")) + checkAnswer(spark.table(s"$db.t"), df) checkTablePath(db, "t") } @@ -64,9 +71,9 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle val path = dir.getCanonicalPath df.write.format("parquet").mode(SaveMode.Overwrite).save(path) - sqlContext.createExternalTable("t", path, "parquet") - assert(sqlContext.tableNames(db).contains("t")) - checkAnswer(sqlContext.table("t"), df) + spark.catalog.createExternalTable("t", path, "parquet") + assert(getTableNames(Option(db)).contains("t")) + checkAnswer(spark.table("t"), df) sql( s""" @@ -76,8 +83,8 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle | path '$path' |) """.stripMargin) - assert(sqlContext.tableNames(db).contains("t1")) - checkAnswer(sqlContext.table("t1"), df) + assert(getTableNames(Option(db)).contains("t1")) + checkAnswer(spark.table("t1"), df) } } } @@ -88,10 +95,10 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle withTempPath { dir => val path = dir.getCanonicalPath df.write.format("parquet").mode(SaveMode.Overwrite).save(path) - sqlContext.createExternalTable(s"$db.t", path, "parquet") + spark.catalog.createExternalTable(s"$db.t", path, "parquet") - assert(sqlContext.tableNames(db).contains("t")) - checkAnswer(sqlContext.table(s"$db.t"), df) + assert(getTableNames(Option(db)).contains("t")) + checkAnswer(spark.table(s"$db.t"), df) sql( s""" @@ -101,8 +108,8 
@@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle | path '$path' |) """.stripMargin) - assert(sqlContext.tableNames(db).contains("t1")) - checkAnswer(sqlContext.table(s"$db.t1"), df) + assert(getTableNames(Option(db)).contains("t1")) + checkAnswer(spark.table(s"$db.t1"), df) } } } @@ -112,12 +119,12 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle activateDatabase(db) { df.write.mode(SaveMode.Overwrite).saveAsTable("t") df.write.mode(SaveMode.Append).saveAsTable("t") - assert(sqlContext.tableNames().contains("t")) - checkAnswer(sqlContext.table("t"), df.union(df)) + assert(getTableNames().contains("t")) + checkAnswer(spark.table("t"), df.union(df)) } - assert(sqlContext.tableNames(db).contains("t")) - checkAnswer(sqlContext.table(s"$db.t"), df.union(df)) + assert(getTableNames(Option(db)).contains("t")) + checkAnswer(spark.table(s"$db.t"), df.union(df)) checkTablePath(db, "t") } @@ -127,8 +134,8 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle withTempDatabase { db => df.write.mode(SaveMode.Overwrite).saveAsTable(s"$db.t") df.write.mode(SaveMode.Append).saveAsTable(s"$db.t") - assert(sqlContext.tableNames(db).contains("t")) - checkAnswer(sqlContext.table(s"$db.t"), df.union(df)) + assert(getTableNames(Option(db)).contains("t")) + checkAnswer(spark.table(s"$db.t"), df.union(df)) checkTablePath(db, "t") } @@ -138,10 +145,10 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle withTempDatabase { db => activateDatabase(db) { df.write.mode(SaveMode.Overwrite).saveAsTable("t") - assert(sqlContext.tableNames().contains("t")) + assert(getTableNames().contains("t")) df.write.insertInto(s"$db.t") - checkAnswer(sqlContext.table(s"$db.t"), df.union(df)) + checkAnswer(spark.table(s"$db.t"), df.union(df)) } } } @@ -150,13 +157,13 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle withTempDatabase { db => activateDatabase(db) { df.write.mode(SaveMode.Overwrite).saveAsTable("t") - assert(sqlContext.tableNames().contains("t")) + assert(getTableNames().contains("t")) } - assert(sqlContext.tableNames(db).contains("t")) + assert(getTableNames(Option(db)).contains("t")) df.write.insertInto(s"$db.t") - checkAnswer(sqlContext.table(s"$db.t"), df.union(df)) + checkAnswer(spark.table(s"$db.t"), df.union(df)) } } @@ -164,10 +171,10 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle withTempDatabase { db => activateDatabase(db) { sql("CREATE TABLE t (key INT)") - checkAnswer(sqlContext.table("t"), sqlContext.emptyDataFrame) + checkAnswer(spark.table("t"), spark.emptyDataFrame) } - checkAnswer(sqlContext.table(s"$db.t"), sqlContext.emptyDataFrame) + checkAnswer(spark.table(s"$db.t"), spark.emptyDataFrame) } } @@ -175,21 +182,21 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle withTempDatabase { db => activateDatabase(db) { sql(s"CREATE TABLE t (key INT)") - assert(sqlContext.tableNames().contains("t")) - assert(!sqlContext.tableNames("default").contains("t")) + assert(getTableNames().contains("t")) + assert(!getTableNames(Option("default")).contains("t")) } - assert(!sqlContext.tableNames().contains("t")) - assert(sqlContext.tableNames(db).contains("t")) + assert(!getTableNames().contains("t")) + assert(getTableNames(Option(db)).contains("t")) activateDatabase(db) { sql(s"DROP TABLE t") - assert(!sqlContext.tableNames().contains("t")) - 
assert(!sqlContext.tableNames("default").contains("t")) + assert(!getTableNames().contains("t")) + assert(!getTableNames(Option("default")).contains("t")) } - assert(!sqlContext.tableNames().contains("t")) - assert(!sqlContext.tableNames(db).contains("t")) + assert(!getTableNames().contains("t")) + assert(!getTableNames(Option(db)).contains("t")) } } @@ -208,18 +215,18 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle |LOCATION '$path' """.stripMargin) - checkAnswer(sqlContext.table("t"), sqlContext.emptyDataFrame) + checkAnswer(spark.table("t"), spark.emptyDataFrame) df.write.parquet(s"$path/p=1") sql("ALTER TABLE t ADD PARTITION (p=1)") sql("REFRESH TABLE t") - checkAnswer(sqlContext.table("t"), df.withColumn("p", lit(1))) + checkAnswer(spark.table("t"), df.withColumn("p", lit(1))) df.write.parquet(s"$path/p=2") sql("ALTER TABLE t ADD PARTITION (p=2)") - hiveContext.sessionState.refreshTable("t") + spark.catalog.refreshTable("t") checkAnswer( - sqlContext.table("t"), + spark.table("t"), df.withColumn("p", lit(1)).union(df.withColumn("p", lit(2)))) } } @@ -240,18 +247,18 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle |LOCATION '$path' """.stripMargin) - checkAnswer(sqlContext.table(s"$db.t"), sqlContext.emptyDataFrame) + checkAnswer(spark.table(s"$db.t"), spark.emptyDataFrame) df.write.parquet(s"$path/p=1") sql(s"ALTER TABLE $db.t ADD PARTITION (p=1)") sql(s"REFRESH TABLE $db.t") - checkAnswer(sqlContext.table(s"$db.t"), df.withColumn("p", lit(1))) + checkAnswer(spark.table(s"$db.t"), df.withColumn("p", lit(1))) df.write.parquet(s"$path/p=2") sql(s"ALTER TABLE $db.t ADD PARTITION (p=2)") - hiveContext.sessionState.refreshTable(s"$db.t") + spark.catalog.refreshTable(s"$db.t") checkAnswer( - sqlContext.table(s"$db.t"), + spark.table(s"$db.t"), df.withColumn("p", lit(1)).union(df.withColumn("p", lit(2)))) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala index af4dc1beec279..7d429f4723f51 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala @@ -54,7 +54,7 @@ class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest with TestHi // Don't convert Hive metastore Parquet tables to let Hive write those Parquet files. 
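In MultiDatabaseSuite, the removed `sqlContext.tableNames(...)` calls are routed through a small `getTableNames` helper built on the `spark.catalog.listTables` API. A hedged sketch of that lookup, where the database name is a placeholder:

```scala
import org.apache.spark.sql.SparkSession

// Self-contained helper; pass in an existing SparkSession.
def tableNames(spark: SparkSession, db: Option[String] = None): Array[String] = db match {
  case Some(name) => spark.catalog.listTables(name).collect().map(_.name)
  case None => spark.catalog.listTables().collect().map(_.name)
}

// e.g. tableNames(spark, Some("testdb")).contains("t")
```

Returning to the diff, the ParquetHiveCompatibilitySuite hunk below keeps `CONVERT_METASTORE_PARQUET` disabled so that Hive itself writes the Parquet files under test, and likewise moves from `withTempTable` to `withTempView`.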
withSQLConf(HiveUtils.CONVERT_METASTORE_PARQUET.key -> "false") { - withTempTable("data") { + withTempView("data") { val fields = hiveTypes.zipWithIndex.map { case (typ, index) => s" col_$index $typ" } val ddl = @@ -70,12 +70,12 @@ class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest with TestHi |$ddl """.stripMargin) - sqlContext.sql(ddl) + spark.sql(ddl) - val schema = sqlContext.table("parquet_compat").schema - val rowRDD = sqlContext.sparkContext.parallelize(rows).coalesce(1) - sqlContext.createDataFrame(rowRDD, schema).registerTempTable("data") - sqlContext.sql("INSERT INTO TABLE parquet_compat SELECT * FROM data") + val schema = spark.table("parquet_compat").schema + val rowRDD = spark.sparkContext.parallelize(rows).coalesce(1) + spark.createDataFrame(rowRDD, schema).createOrReplaceTempView("data") + spark.sql("INSERT INTO TABLE parquet_compat SELECT * FROM data") } } @@ -84,7 +84,7 @@ class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest with TestHi // Unfortunately parquet-hive doesn't add `UTF8` annotation to BINARY when writing strings. // Have to assume all BINARY values are strings here. withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING.key -> "true") { - checkAnswer(sqlContext.read.parquet(path), rows) + checkAnswer(spark.read.parquet(path), rows) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala index 78569c58085cd..feeaade561441 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala @@ -17,7 +17,10 @@ package org.apache.spark.sql.hive +import java.io.File + import com.google.common.io.Files +import org.apache.hadoop.fs.FileSystem import org.apache.spark.sql._ import org.apache.spark.sql.hive.test.TestHiveSingleton @@ -26,13 +29,13 @@ import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.util.Utils class QueryPartitionSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { - import hiveContext.implicits._ + import spark.implicits._ test("SPARK-5068: query data when path doesn't exist") { withSQLConf((SQLConf.HIVE_VERIFY_PARTITION_PATH.key, "true")) { val testData = sparkContext.parallelize( (1 to 10).map(i => TestData(i, i.toString))).toDF() - testData.registerTempTable("testData") + testData.createOrReplaceTempView("testData") val tmpDir = Files.createTempDir() // create the table for test @@ -61,8 +64,86 @@ class QueryPartitionSuite extends QueryTest with SQLTestUtils with TestHiveSingl checkAnswer(sql("select key,value from table_with_partition"), testData.toDF.collect ++ testData.toDF.collect ++ testData.toDF.collect) - sql("DROP TABLE table_with_partition") - sql("DROP TABLE createAndInsertTest") + sql("DROP TABLE IF EXISTS table_with_partition") + sql("DROP TABLE IF EXISTS createAndInsertTest") + } + } + + test("SPARK-13709: reading partitioned Avro table with nested schema") { + withTempDir { dir => + val path = dir.getCanonicalPath + val tableName = "spark_13709" + val tempTableName = "spark_13709_temp" + + new File(path, tableName).mkdir() + new File(path, tempTableName).mkdir() + + val avroSchema = + """{ + | "name": "test_record", + | "type": "record", + | "fields": [ { + | "name": "f0", + | "type": "int" + | }, { + | "name": "f1", + | "type": { + | "type": "record", + | "name": "inner", + | "fields": [ { + | "name": "f10", + | "type": "int" + | }, { + | "name": "f11", + | 
"type": "double" + | } ] + | } + | } ] + |} + """.stripMargin + + withTable(tableName, tempTableName) { + // Creates the external partitioned Avro table to be tested. + sql( + s"""CREATE EXTERNAL TABLE $tableName + |PARTITIONED BY (ds STRING) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.avro.AvroSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerOutputFormat' + |LOCATION '$path/$tableName' + |TBLPROPERTIES ('avro.schema.literal' = '$avroSchema') + """.stripMargin + ) + + // Creates an temporary Avro table used to prepare testing Avro file. + sql( + s"""CREATE EXTERNAL TABLE $tempTableName + |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.avro.AvroSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerOutputFormat' + |LOCATION '$path/$tempTableName' + |TBLPROPERTIES ('avro.schema.literal' = '$avroSchema') + """.stripMargin + ) + + // Generates Avro data. + sql(s"INSERT OVERWRITE TABLE $tempTableName SELECT 1, STRUCT(2, 2.5)") + + // Adds generated Avro data as a new partition to the testing table. + sql(s"ALTER TABLE $tableName ADD PARTITION (ds = 'foo') LOCATION '$path/$tempTableName'") + + // The following query fails before SPARK-13709 is fixed. This is because when reading data + // from table partitions, Avro deserializer needs the Avro schema, which is defined in + // table property "avro.schema.literal". However, we only initializes the deserializer using + // partition properties, which doesn't include the wanted property entry. Merging two sets + // of properties solves the problem. + checkAnswer( + sql(s"SELECT * FROM $tableName"), + Row(1, Row(2, 2.5D), "foo") + ) + } } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ShowCreateTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ShowCreateTableSuite.scala new file mode 100644 index 0000000000000..3f3dc122093b5 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ShowCreateTableSuite.scala @@ -0,0 +1,340 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.{AnalysisException, QueryTest} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.CatalogTable +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.util.Utils + +class ShowCreateTableSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { + import testImplicits._ + + test("data source table with user specified schema") { + withTable("ddl_test") { + val jsonFilePath = Utils.getSparkClassLoader.getResource("sample.json").getFile + + sql( + s"""CREATE TABLE ddl_test ( + | a STRING, + | b STRING, + | `extra col` ARRAY, + | `` STRUCT> + |) + |USING json + |OPTIONS ( + | PATH '$jsonFilePath' + |) + """.stripMargin + ) + + checkCreateTable("ddl_test") + } + } + + test("data source table CTAS") { + withTable("ddl_test") { + sql( + s"""CREATE TABLE ddl_test + |USING json + |AS SELECT 1 AS a, "foo" AS b + """.stripMargin + ) + + checkCreateTable("ddl_test") + } + } + + test("partitioned data source table") { + withTable("ddl_test") { + sql( + s"""CREATE TABLE ddl_test + |USING json + |PARTITIONED BY (b) + |AS SELECT 1 AS a, "foo" AS b + """.stripMargin + ) + + checkCreateTable("ddl_test") + } + } + + test("bucketed data source table") { + withTable("ddl_test") { + sql( + s"""CREATE TABLE ddl_test + |USING json + |CLUSTERED BY (a) SORTED BY (b) INTO 2 BUCKETS + |AS SELECT 1 AS a, "foo" AS b + """.stripMargin + ) + + checkCreateTable("ddl_test") + } + } + + test("partitioned bucketed data source table") { + withTable("ddl_test") { + sql( + s"""CREATE TABLE ddl_test + |USING json + |PARTITIONED BY (c) + |CLUSTERED BY (a) SORTED BY (b) INTO 2 BUCKETS + |AS SELECT 1 AS a, "foo" AS b, 2.5 AS c + """.stripMargin + ) + + checkCreateTable("ddl_test") + } + } + + test("data source table using Dataset API") { + withTable("ddl_test") { + spark + .range(3) + .select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd, 'id as 'e) + .write + .mode("overwrite") + .partitionBy("a", "b") + .bucketBy(2, "c", "d") + .saveAsTable("ddl_test") + + checkCreateTable("ddl_test") + } + } + + test("simple hive table") { + withTable("t1") { + sql( + s"""CREATE TABLE t1 ( + | c1 INT COMMENT 'bla', + | c2 STRING + |) + |TBLPROPERTIES ( + | 'prop1' = 'value1', + | 'prop2' = 'value2' + |) + """.stripMargin + ) + + checkCreateTable("t1") + } + } + + test("simple external hive table") { + withTempDir { dir => + withTable("t1") { + sql( + s"""CREATE TABLE t1 ( + | c1 INT COMMENT 'bla', + | c2 STRING + |) + |LOCATION '$dir' + |TBLPROPERTIES ( + | 'prop1' = 'value1', + | 'prop2' = 'value2' + |) + """.stripMargin + ) + + checkCreateTable("t1") + } + } + } + + test("partitioned hive table") { + withTable("t1") { + sql( + s"""CREATE TABLE t1 ( + | c1 INT COMMENT 'bla', + | c2 STRING + |) + |COMMENT 'bla' + |PARTITIONED BY ( + | p1 BIGINT COMMENT 'bla', + | p2 STRING + |) + """.stripMargin + ) + + checkCreateTable("t1") + } + } + + test("hive table with explicit storage info") { + withTable("t1") { + sql( + s"""CREATE TABLE t1 ( + | c1 INT COMMENT 'bla', + | c2 STRING + |) + |ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' + |COLLECTION ITEMS TERMINATED BY '@' + |MAP KEYS TERMINATED BY '#' + |NULL DEFINED AS 'NaN' + """.stripMargin + ) + + checkCreateTable("t1") + } + } + + test("hive table with STORED AS clause") { + withTable("t1") { + sql( + s"""CREATE TABLE t1 ( + | c1 INT COMMENT 'bla', + | c2 STRING + |) + |STORED AS PARQUET + 
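// ---- Editor's aside: illustrative sketch, not part of the patch above ----
// The "data source table using Dataset API" test above drives the DataFrameWriter
// partitioning/bucketing API whose output SHOW CREATE TABLE must be able to reproduce.
// In isolation (table name "bucketed_sketch" is hypothetical; `spark` is the suite's session):
spark.range(100)
  .selectExpr("id", "id % 10 AS part_col", "id % 4 AS bucket_col")
  .write
  .mode("overwrite")
  .partitionBy("part_col")
  .bucketBy(2, "bucket_col")
  .saveAsTable("bucketed_sketch")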
""".stripMargin + ) + + checkCreateTable("t1") + } + } + + test("hive table with serde info") { + withTable("t1") { + sql( + s"""CREATE TABLE t1 ( + | c1 INT COMMENT 'bla', + | c2 STRING + |) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |WITH SERDEPROPERTIES ( + | 'mapkey.delim' = ',', + | 'field.delim' = ',' + |) + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + """.stripMargin + ) + + checkCreateTable("t1") + } + } + + test("hive view") { + withView("v1") { + sql("CREATE VIEW v1 AS SELECT 1 AS a") + checkCreateView("v1") + } + } + + test("hive view with output columns") { + withView("v1") { + sql("CREATE VIEW v1 (b) AS SELECT 1 AS a") + checkCreateView("v1") + } + } + + test("hive bucketing is not supported") { + withTable("t1") { + createRawHiveTable( + s"""CREATE TABLE t1 (a INT, b STRING) + |CLUSTERED BY (a) + |SORTED BY (b) + |INTO 2 BUCKETS + """.stripMargin + ) + + val cause = intercept[AnalysisException] { + sql("SHOW CREATE TABLE t1") + } + + assert(cause.getMessage.contains(" - bucketing")) + } + } + + private def createRawHiveTable(ddl: String): Unit = { + hiveContext.sharedState.metadataHive.runSqlHive(ddl) + } + + private def checkCreateTable(table: String): Unit = { + checkCreateTableOrView(TableIdentifier(table, Some("default")), "TABLE") + } + + private def checkCreateView(table: String): Unit = { + checkCreateTableOrView(TableIdentifier(table, Some("default")), "VIEW") + } + + private def checkCreateTableOrView(table: TableIdentifier, checkType: String): Unit = { + val db = table.database.getOrElse("default") + val expected = spark.sharedState.externalCatalog.getTable(db, table.table) + val shownDDL = sql(s"SHOW CREATE TABLE ${table.quotedString}").head().getString(0) + sql(s"DROP $checkType ${table.quotedString}") + + try { + sql(shownDDL) + val actual = spark.sharedState.externalCatalog.getTable(db, table.table) + checkCatalogTables(expected, actual) + } finally { + sql(s"DROP $checkType IF EXISTS ${table.table}") + } + } + + private def checkCatalogTables(expected: CatalogTable, actual: CatalogTable): Unit = { + def normalize(table: CatalogTable): CatalogTable = { + val nondeterministicProps = Set( + "CreateTime", + "transient_lastDdlTime", + "grantTime", + "lastUpdateTime", + "last_modified_by", + "last_modified_time", + "Owner:", + "COLUMN_STATS_ACCURATE", + // The following are hive specific schema parameters which we do not need to match exactly. + "numFiles", + "numRows", + "rawDataSize", + "totalSize", + "totalNumberFiles", + "maxFileSize", + "minFileSize", + // EXTERNAL is not non-deterministic, but it is filtered out for external tables. 
+ "EXTERNAL" + ) + + table.copy( + createTime = 0L, + lastAccessTime = 0L, + properties = table.properties.filterKeys(!nondeterministicProps.contains(_)), + // View texts are checked separately + viewOriginalText = None, + viewText = None + ) + } + + // Normalizes attributes auto-generated by Spark SQL for views + def normalizeGeneratedAttributes(str: String): String = { + str.replaceAll("gen_attr_[0-9]+", "gen_attr_0") + } + + // We use expanded canonical view text as original view text of the new table + assertResult(expected.viewText.map(normalizeGeneratedAttributes)) { + actual.viewOriginalText.map(normalizeGeneratedAttributes) + } + + assert(normalize(actual) == normalize(expected)) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 8060ef77e7586..a5975cf483c10 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -17,23 +17,25 @@ package org.apache.spark.sql.hive +import java.io.{File, PrintWriter} + import scala.reflect.ClassTag -import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.execution.command.AnalyzeTable +import org.apache.spark.sql.execution.command.AnalyzeTableCommand import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SQLTestUtils -class StatisticsSuite extends QueryTest with TestHiveSingleton { - import hiveContext.sql +class StatisticsSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { test("parse analyze commands") { def assertAnalyzeCommand(analyzeCommand: String, c: Class[_]) { - val parsed = hiveContext.parseSql(analyzeCommand) + val parsed = spark.sessionState.sqlParser.parsePlan(analyzeCommand) val operators = parsed.collect { - case a: AnalyzeTable => a + case a: AnalyzeTableCommand => a case o => o } @@ -49,29 +51,73 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton { assertAnalyzeCommand( "ANALYZE TABLE Table1 COMPUTE STATISTICS", - classOf[AnalyzeTable]) + classOf[AnalyzeTableCommand]) assertAnalyzeCommand( "ANALYZE TABLE Table1 PARTITION(ds='2008-04-09', hr=11) COMPUTE STATISTICS", - classOf[AnalyzeTable]) + classOf[AnalyzeTableCommand]) assertAnalyzeCommand( "ANALYZE TABLE Table1 PARTITION(ds='2008-04-09', hr=11) COMPUTE STATISTICS noscan", - classOf[AnalyzeTable]) + classOf[AnalyzeTableCommand]) assertAnalyzeCommand( "ANALYZE TABLE Table1 PARTITION(ds, hr) COMPUTE STATISTICS", - classOf[AnalyzeTable]) + classOf[AnalyzeTableCommand]) assertAnalyzeCommand( "ANALYZE TABLE Table1 PARTITION(ds, hr) COMPUTE STATISTICS noscan", - classOf[AnalyzeTable]) + classOf[AnalyzeTableCommand]) assertAnalyzeCommand( "ANALYZE TABLE Table1 COMPUTE STATISTICS nOscAn", - classOf[AnalyzeTable]) + classOf[AnalyzeTableCommand]) + } + + test("MetastoreRelations fallback to HDFS for size estimation") { + val enableFallBackToHdfsForStats = spark.sessionState.conf.fallBackToHdfsForStatsEnabled + try { + withTempDir { tempDir => + + // EXTERNAL OpenCSVSerde table pointing to LOCATION + + val file1 = new File(tempDir + "/data1") + val writer1 = new PrintWriter(file1) + writer1.write("1,2") + writer1.close() + + val file2 = new File(tempDir + "/data2") + val writer2 
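// ---- Editor's aside: illustrative sketch, not part of the patch above ----
// checkCreateTableOrView above validates SHOW CREATE TABLE by replaying the DDL it returns and
// comparing normalized catalog metadata. The user-visible part of that round trip, assuming an
// existing table `t` and the suite's Hive-enabled `spark` session:
val shownDDL = spark.sql("SHOW CREATE TABLE t").head().getString(0)
spark.sql("DROP TABLE t")
spark.sql(shownDDL)   // the emitted DDL should recreate an equivalent table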
= new PrintWriter(file2) + writer2.write("1,2") + writer2.close() + + sql( + s"""CREATE EXTERNAL TABLE csv_table(page_id INT, impressions INT) + ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.OpenCSVSerde' + WITH SERDEPROPERTIES ( + \"separatorChar\" = \",\", + \"quoteChar\" = \"\\\"\", + \"escapeChar\" = \"\\\\\") + LOCATION '$tempDir' + """) + + spark.conf.set(SQLConf.ENABLE_FALL_BACK_TO_HDFS_FOR_STATS.key, true) + + val relation = spark.sessionState.catalog.lookupRelation(TableIdentifier("csv_table")) + .asInstanceOf[MetastoreRelation] + + val properties = relation.hiveQlTable.getParameters + assert(properties.get("totalSize").toLong <= 0, "external table totalSize must be <= 0") + assert(properties.get("rawDataSize").toLong <= 0, "external table rawDataSize must be <= 0") + + val sizeInBytes = relation.statistics.sizeInBytes + assert(sizeInBytes === BigInt(file1.length() + file2.length())) + } + } finally { + spark.conf.set(SQLConf.ENABLE_FALL_BACK_TO_HDFS_FOR_STATS.key, enableFallBackToHdfsForStats) + sql("DROP TABLE csv_table ") + } } - ignore("analyze MetastoreRelations") { + test("analyze MetastoreRelations") { def queryTotalSize(tableName: String): BigInt = - hiveContext.sessionState.catalog.lookupRelation( - TableIdentifier(tableName)).statistics.sizeInBytes + spark.sessionState.catalog.lookupRelation(TableIdentifier(tableName)).statistics.sizeInBytes // Non-partitioned table sql("CREATE TABLE analyzeTable (key STRING, value STRING)").collect() @@ -105,21 +151,20 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton { |SELECT * FROM src """.stripMargin).collect() - assert(queryTotalSize("analyzeTable_part") === hiveContext.conf.defaultSizeInBytes) + assert(queryTotalSize("analyzeTable_part") === spark.sessionState.conf.defaultSizeInBytes) sql("ANALYZE TABLE analyzeTable_part COMPUTE STATISTICS noscan") - // This seems to be flaky. 
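// ---- Editor's aside: illustrative sketch, not part of the patch above ----
// The HDFS-fallback test above expects relation.statistics.sizeInBytes to equal the combined
// length of the data files when the metastore reports no usable totalSize/rawDataSize. The
// expected value is simply the on-disk footprint of the table location:
import java.io.File

def directoryDataSize(dir: File): Long =
  dir.listFiles().filter(_.isFile).map(_.length()).sum
// For the test above this is file1.length() + file2.length().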
- // assert(queryTotalSize("analyzeTable_part") === BigInt(17436)) + assert(queryTotalSize("analyzeTable_part") === BigInt(17436)) sql("DROP TABLE analyzeTable_part").collect() // Try to analyze a temp table - sql("""SELECT * FROM src""").registerTempTable("tempTable") - intercept[UnsupportedOperationException] { - hiveContext.sql("ANALYZE TABLE tempTable COMPUTE STATISTICS") + sql("""SELECT * FROM src""").createOrReplaceTempView("tempTable") + intercept[AnalysisException] { + sql("ANALYZE TABLE tempTable COMPUTE STATISTICS") } - hiveContext.sessionState.catalog.dropTable( + spark.sessionState.catalog.dropTable( TableIdentifier("tempTable"), ignoreIfNotExists = true) } @@ -148,8 +193,8 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton { val sizes = df.queryExecution.analyzed.collect { case r if ct.runtimeClass.isAssignableFrom(r.getClass) => r.statistics.sizeInBytes } - assert(sizes.size === 2 && sizes(0) <= hiveContext.conf.autoBroadcastJoinThreshold - && sizes(1) <= hiveContext.conf.autoBroadcastJoinThreshold, + assert(sizes.size === 2 && sizes(0) <= spark.sessionState.conf.autoBroadcastJoinThreshold + && sizes(1) <= spark.sessionState.conf.autoBroadcastJoinThreshold, s"query should contain two relations, each of which has size smaller than autoConvertSize") // Using `sparkPlan` because for relevant patterns in HashJoin to be @@ -160,8 +205,8 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton { checkAnswer(df, expectedAnswer) // check correctness of output - hiveContext.conf.settings.synchronized { - val tmp = hiveContext.conf.autoBroadcastJoinThreshold + spark.sessionState.conf.settings.synchronized { + val tmp = spark.sessionState.conf.autoBroadcastJoinThreshold sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=-1""") df = sql(query) @@ -204,8 +249,8 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton { .isAssignableFrom(r.getClass) => r.statistics.sizeInBytes } - assert(sizes.size === 2 && sizes(1) <= hiveContext.conf.autoBroadcastJoinThreshold - && sizes(0) <= hiveContext.conf.autoBroadcastJoinThreshold, + assert(sizes.size === 2 && sizes(1) <= spark.sessionState.conf.autoBroadcastJoinThreshold + && sizes(0) <= spark.sessionState.conf.autoBroadcastJoinThreshold, s"query should contain two relations, each of which has size smaller than autoConvertSize") // Using `sparkPlan` because for relevant patterns in HashJoin to be @@ -218,8 +263,8 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton { checkAnswer(df, answer) // check correctness of output - hiveContext.conf.settings.synchronized { - val tmp = hiveContext.conf.autoBroadcastJoinThreshold + spark.sessionState.conf.settings.synchronized { + val tmp = spark.sessionState.conf.autoBroadcastJoinThreshold sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=-1") df = sql(leftSemiJoinQuery) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala index d1aa5aa931947..88cc42efd0fe3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala @@ -36,7 +36,7 @@ class UDFSuite with TestHiveSingleton with BeforeAndAfterEach { - import hiveContext.implicits._ + import spark.implicits._ private[this] val functionName = "myUPper" private[this] val functionNameUpper = "MYUPPER" @@ -53,7 +53,7 @@ class UDFSuite sql("USE default") testDF = (1 to 10).map(i => s"sTr$i").toDF("value") - 
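// ---- Editor's aside: illustrative sketch, not part of the patch above ----
// The join-statistics tests above flip the broadcast threshold through a SQL SET command. The
// same knob can be toggled and restored via the public conf API (-1 disables auto broadcast);
// this assumes the suite's `spark` session and its `src` test table:
val previous = spark.conf.get("spark.sql.autoBroadcastJoinThreshold")
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")
try {
  spark.sql("SELECT * FROM src a JOIN src b ON a.key = b.key").explain()
} finally {
  spark.conf.set("spark.sql.autoBroadcastJoinThreshold", previous)
}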
testDF.registerTempTable(testTableName) + testDF.createOrReplaceTempView(testTableName) expectedDF = (1 to 10).map(i => s"STR$i").toDF("value") super.beforeAll() } @@ -64,12 +64,12 @@ class UDFSuite } test("UDF case insensitive") { - hiveContext.udf.register("random0", () => { Math.random() }) - hiveContext.udf.register("RANDOM1", () => { Math.random() }) - hiveContext.udf.register("strlenScala", (_: String).length + (_: Int)) - assert(hiveContext.sql("SELECT RANDOM0() FROM src LIMIT 1").head().getDouble(0) >= 0.0) - assert(hiveContext.sql("SELECT RANDOm1() FROM src LIMIT 1").head().getDouble(0) >= 0.0) - assert(hiveContext.sql("SELECT strlenscala('test', 1) FROM src LIMIT 1").head().getInt(0) === 5) + spark.udf.register("random0", () => { Math.random() }) + spark.udf.register("RANDOM1", () => { Math.random() }) + spark.udf.register("strlenScala", (_: String).length + (_: Int)) + assert(sql("SELECT RANDOM0() FROM src LIMIT 1").head().getDouble(0) >= 0.0) + assert(sql("SELECT RANDOm1() FROM src LIMIT 1").head().getDouble(0) >= 0.0) + assert(sql("SELECT strlenscala('test', 1) FROM src LIMIT 1").head().getInt(0) === 5) } test("temporary function: create and drop") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 9341b3816fea7..5b209acf0f212 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -17,21 +17,27 @@ package org.apache.spark.sql.hive.client -import java.io.File +import java.io.{ByteArrayOutputStream, File, PrintStream} import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat +import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe +import org.apache.hadoop.mapred.TextInputFormat import org.apache.hadoop.util.VersionInfo import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.analysis.NoSuchPermanentFunctionException import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Literal, NamedExpression} import org.apache.spark.sql.catalyst.util.quietly import org.apache.spark.sql.hive.HiveUtils import org.apache.spark.sql.types.IntegerType import org.apache.spark.tags.ExtendedHiveTest -import org.apache.spark.util.Utils +import org.apache.spark.util.{MutableURLClassLoader, Utils} /** * A simple set of tests that call the methods of a [[HiveClient]], loading different version @@ -97,12 +103,6 @@ class VersionsSuite extends SparkFunSuite with Logging { private val emptyDir = Utils.createTempDir().getCanonicalPath - private def partSpec = { - val hashMap = new java.util.LinkedHashMap[String, String] - hashMap.put("key", "1") - hashMap - } - // Its actually pretty easy to mess things up and have all of your tests "pass" by accidentally // connecting to an auto-populated, in-process metastore. Let's make sure we are getting the // versions right by forcing a known compatibility failure. 
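// ---- Editor's aside: illustrative sketch, not part of the patch above ----
// The UDF test above depends on case-insensitive function resolution. Stripped down to the
// essentials (function name "myUpper" is hypothetical; `spark` is the suite's session):
spark.udf.register("myUpper", (s: String) => s.toUpperCase)
spark.range(1).selectExpr("MYUPPER('spark') AS shouted").show()   // resolves despite the case change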
@@ -122,7 +122,7 @@ class VersionsSuite extends SparkFunSuite with Logging { assert(getNestedMessages(e) contains "Unknown column 'A0.OWNER_NAME' in 'field list'") } - private val versions = Seq("12", "13", "14", "1.0.0", "1.1.0", "1.2.0") + private val versions = Seq("0.12", "0.13", "0.14", "1.0", "1.1", "1.2") private var client: HiveClient = null @@ -130,109 +130,406 @@ class VersionsSuite extends SparkFunSuite with Logging { test(s"$version: create client") { client = null System.gc() // Hack to avoid SEGV on some JVM versions. + val hadoopConf = new Configuration(); + hadoopConf.set("test", "success") client = IsolatedClientLoader.forVersion( hiveMetastoreVersion = version, hadoopVersion = VersionInfo.getVersion, sparkConf = sparkConf, - hadoopConf = new Configuration(), + hadoopConf = hadoopConf, config = buildConf(), ivyPath = ivyPath).createClient() } + def table(database: String, tableName: String): CatalogTable = { + CatalogTable( + identifier = TableIdentifier(tableName, Some(database)), + tableType = CatalogTableType.MANAGED, + schema = Seq(CatalogColumn("key", "int")), + storage = CatalogStorageFormat( + locationUri = None, + inputFormat = Some(classOf[TextInputFormat].getName), + outputFormat = Some(classOf[HiveIgnoreKeyTextOutputFormat[_, _]].getName), + serde = Some(classOf[LazySimpleSerDe].getName()), + compressed = false, + serdeProperties = Map.empty + )) + } + + /////////////////////////////////////////////////////////////////////////// + // Database related API + /////////////////////////////////////////////////////////////////////////// + + val tempDatabasePath = Utils.createTempDir().getCanonicalPath + test(s"$version: createDatabase") { - val db = CatalogDatabase("default", "desc", "loc", Map()) - client.createDatabase(db, ignoreIfExists = true) + val defaultDB = CatalogDatabase("default", "desc", "loc", Map()) + client.createDatabase(defaultDB, ignoreIfExists = true) + val tempDB = CatalogDatabase( + "temporary", description = "test create", tempDatabasePath, Map()) + client.createDatabase(tempDB, ignoreIfExists = true) + } + + test(s"$version: setCurrentDatabase") { + client.setCurrentDatabase("default") + } + + test(s"$version: getDatabase") { + // No exception should be thrown + client.getDatabase("default") + } + + test(s"$version: getDatabaseOption") { + assert(client.getDatabaseOption("default").isDefined) + assert(client.getDatabaseOption("nonexist") == None) } + test(s"$version: listDatabases") { + assert(client.listDatabases("defau.*") == Seq("default")) + } + + test(s"$version: alterDatabase") { + val database = client.getDatabase("temporary").copy(properties = Map("flag" -> "true")) + client.alterDatabase(database) + assert(client.getDatabase("temporary").properties.contains("flag")) + } + + test(s"$version: dropDatabase") { + assert(client.getDatabaseOption("temporary").isDefined) + client.dropDatabase("temporary", ignoreIfNotExists = false, cascade = true) + assert(client.getDatabaseOption("temporary").isEmpty) + } + + /////////////////////////////////////////////////////////////////////////// + // Table related API + /////////////////////////////////////////////////////////////////////////// + test(s"$version: createTable") { - val table = - CatalogTable( - identifier = TableIdentifier("src", Some("default")), - tableType = CatalogTableType.MANAGED, - schema = Seq(CatalogColumn("key", "int")), - storage = CatalogStorageFormat( - locationUri = None, - inputFormat = Some(classOf[org.apache.hadoop.mapred.TextInputFormat].getName), - outputFormat = Some( 
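// ---- Editor's aside: illustrative sketch, not part of the patch above ----
// The version strings above name the Hive client versions this suite loads in isolation. The
// analogous user-facing knob is spark.sql.hive.metastore.version; the exact value accepted
// ("1.2.1" below) is an assumption here -- consult the docs of your Spark release:
import org.apache.spark.sql.SparkSession

val hiveSession = SparkSession.builder()
  .master("local[2]")
  .config("spark.sql.hive.metastore.version", "1.2.1")
  .config("spark.sql.hive.metastore.jars", "maven")   // fetch matching Hive client jars
  .enableHiveSupport()
  .getOrCreate()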
- classOf[org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat[_, _]].getName), - serde = Some(classOf[org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe].getName()), - serdeProperties = Map.empty - )) - - client.createTable(table, ignoreIfExists = false) + client.createTable(table("default", tableName = "src"), ignoreIfExists = false) + client.createTable(table("default", "temporary"), ignoreIfExists = false) + } + + test(s"$version: loadTable") { + client.loadTable( + emptyDir, + tableName = "src", + replace = false, + holdDDLTime = false) } test(s"$version: getTable") { + // No exception should be thrown client.getTable("default", "src") } - test(s"$version: listTables") { - assert(client.listTables("default") === Seq("src")) + test(s"$version: getTableOption") { + assert(client.getTableOption("default", "src").isDefined) } - test(s"$version: getDatabase") { - client.getDatabase("default") + test(s"$version: alterTable(table: CatalogTable)") { + val newTable = client.getTable("default", "src").copy(properties = Map("changed" -> "")) + client.alterTable(newTable) + assert(client.getTable("default", "src").properties.contains("changed")) } - test(s"$version: alterTable") { - client.alterTable(client.getTable("default", "src")) + test(s"$version: alterTable(tableName: String, table: CatalogTable)") { + val newTable = client.getTable("default", "src").copy(properties = Map("changedAgain" -> "")) + client.alterTable("src", newTable) + assert(client.getTable("default", "src").properties.contains("changedAgain")) } - test(s"$version: set command") { - client.runSqlHive("SET spark.sql.test.key=1") + test(s"$version: listTables(database)") { + assert(client.listTables("default") === Seq("src", "temporary")) + } + + test(s"$version: listTables(database, pattern)") { + assert(client.listTables("default", pattern = "src") === Seq("src")) + assert(client.listTables("default", pattern = "nonexist").isEmpty) + } + + test(s"$version: dropTable") { + client.dropTable("default", tableName = "temporary", ignoreIfNotExists = false) + assert(client.listTables("default") === Seq("src")) } - test(s"$version: create partitioned table DDL") { - client.runSqlHive("CREATE TABLE src_part (value INT) PARTITIONED BY (key INT)") - client.runSqlHive("ALTER TABLE src_part ADD PARTITION (key = '1')") + /////////////////////////////////////////////////////////////////////////// + // Partition related API + /////////////////////////////////////////////////////////////////////////// + + val storageFormat = CatalogStorageFormat( + locationUri = None, + inputFormat = None, + outputFormat = None, + serde = None, + compressed = false, + serdeProperties = Map.empty) + + test(s"$version: sql create partitioned table") { + client.runSqlHive("CREATE TABLE src_part (value INT) PARTITIONED BY (key1 INT, key2 INT)") } - test(s"$version: getPartitions") { - client.getPartitions(client.getTable("default", "src_part")) + test(s"$version: createPartitions") { + val partition1 = CatalogTablePartition(Map("key1" -> "1", "key2" -> "1"), storageFormat) + val partition2 = CatalogTablePartition(Map("key1" -> "1", "key2" -> "2"), storageFormat) + client.createPartitions( + "default", "src_part", Seq(partition1, partition2), ignoreIfExists = true) + } + + test(s"$version: getPartitions(catalogTable)") { + assert(2 == client.getPartitions(client.getTable("default", "src_part")).size) } test(s"$version: getPartitionsByFilter") { - client.getPartitionsByFilter(client.getTable("default", "src_part"), Seq(EqualTo( - AttributeReference("key", 
IntegerType, false)(NamedExpression.newExprId), - Literal(1)))) + // Only one partition [1, 1] for key2 == 1 + val result = client.getPartitionsByFilter(client.getTable("default", "src_part"), + Seq(EqualTo(AttributeReference("key2", IntegerType)(), Literal(1)))) + + // Hive 0.12 doesn't support getPartitionsByFilter, it ignores the filter condition. + if (version != "0.12") { + assert(result.size == 1) + } + } + + test(s"$version: getPartition") { + // No exception should be thrown + client.getPartition("default", "src_part", Map("key1" -> "1", "key2" -> "2")) + } + + test(s"$version: getPartitionOption(db: String, table: String, spec: TablePartitionSpec)") { + val partition = client.getPartitionOption( + "default", "src_part", Map("key1" -> "1", "key2" -> "2")) + assert(partition.isDefined) + } + + test(s"$version: getPartitionOption(table: CatalogTable, spec: TablePartitionSpec)") { + val partition = client.getPartitionOption( + client.getTable("default", "src_part"), Map("key1" -> "1", "key2" -> "2")) + assert(partition.isDefined) + } + + test(s"$version: getPartitions(db: String, table: String)") { + assert(2 == client.getPartitions("default", "src_part", None).size) } test(s"$version: loadPartition") { + val partSpec = new java.util.LinkedHashMap[String, String] + partSpec.put("key1", "1") + partSpec.put("key2", "2") + client.loadPartition( emptyDir, "default.src_part", partSpec, - false, - false, - false, - false) - } - - test(s"$version: loadTable") { - client.loadTable( - emptyDir, - "src", - false, - false) + replace = false, + holdDDLTime = false, + inheritTableSpecs = false, + isSkewedStoreAsSubdir = false) } test(s"$version: loadDynamicPartitions") { + val partSpec = new java.util.LinkedHashMap[String, String] + partSpec.put("key1", "1") + partSpec.put("key2", "") // Dynamic partition + client.loadDynamicPartitions( emptyDir, "default.src_part", partSpec, - false, - 1, + replace = false, + numDP = 1, false, false) } - test(s"$version: create index and reset") { + test(s"$version: renamePartitions") { + val oldSpec = Map("key1" -> "1", "key2" -> "1") + val newSpec = Map("key1" -> "1", "key2" -> "3") + client.renamePartitions("default", "src_part", Seq(oldSpec), Seq(newSpec)) + + // Checks the existence of the new partition (key1 = 1, key2 = 3) + assert(client.getPartitionOption("default", "src_part", newSpec).isDefined) + } + + test(s"$version: alterPartitions") { + val spec = Map("key1" -> "1", "key2" -> "2") + val newLocation = Utils.createTempDir().getPath() + val storage = storageFormat.copy( + locationUri = Some(newLocation), + // needed for 0.12 alter partitions + serde = Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) + val partition = CatalogTablePartition(spec, storage) + client.alterPartitions("default", "src_part", Seq(partition)) + assert(client.getPartition("default", "src_part", spec) + .storage.locationUri == Some(newLocation)) + } + + test(s"$version: dropPartitions") { + val spec = Map("key1" -> "1", "key2" -> "3") + client.dropPartitions("default", "src_part", Seq(spec), ignoreIfNotExists = true) + assert(client.getPartitionOption("default", "src_part", spec).isEmpty) + } + + /////////////////////////////////////////////////////////////////////////// + // Function related API + /////////////////////////////////////////////////////////////////////////// + + def function(name: String, className: String): CatalogFunction = { + CatalogFunction( + FunctionIdentifier(name, Some("default")), className, Seq.empty[FunctionResource]) + } + + test(s"$version: 
createFunction") { + val functionClass = "org.apache.spark.MyFunc1" + if (version == "0.12") { + // Hive 0.12 doesn't support creating permanent functions + intercept[AnalysisException] { + client.createFunction("default", function("func1", functionClass)) + } + } else { + client.createFunction("default", function("func1", functionClass)) + } + } + + test(s"$version: functionExists") { + if (version == "0.12") { + // Hive 0.12 doesn't allow customized permanent functions + assert(client.functionExists("default", "func1") == false) + } else { + assert(client.functionExists("default", "func1") == true) + } + } + + test(s"$version: renameFunction") { + if (version == "0.12") { + // Hive 0.12 doesn't allow customized permanent functions + intercept[NoSuchPermanentFunctionException] { + client.renameFunction("default", "func1", "func2") + } + } else { + client.renameFunction("default", "func1", "func2") + assert(client.functionExists("default", "func2") == true) + } + } + + test(s"$version: alterFunction") { + val functionClass = "org.apache.spark.MyFunc2" + if (version == "0.12") { + // Hive 0.12 doesn't allow customized permanent functions + intercept[NoSuchPermanentFunctionException] { + client.alterFunction("default", function("func2", functionClass)) + } + } else { + client.alterFunction("default", function("func2", functionClass)) + } + } + + test(s"$version: getFunction") { + if (version == "0.12") { + // Hive 0.12 doesn't allow customized permanent functions + intercept[NoSuchPermanentFunctionException] { + client.getFunction("default", "func2") + } + } else { + // No exception should be thrown + val func = client.getFunction("default", "func2") + assert(func.className == "org.apache.spark.MyFunc2") + } + } + + test(s"$version: getFunctionOption") { + if (version == "0.12") { + // Hive 0.12 doesn't allow customized permanent functions + assert(client.getFunctionOption("default", "func2").isEmpty) + } else { + assert(client.getFunctionOption("default", "func2").isDefined) + assert(client.getFunctionOption("default", "the_func_not_exists").isEmpty) + } + } + + test(s"$version: listFunctions") { + if (version == "0.12") { + // Hive 0.12 doesn't allow customized permanent functions + assert(client.listFunctions("default", "fun.*").isEmpty) + } else { + assert(client.listFunctions("default", "fun.*").size == 1) + } + } + + test(s"$version: dropFunction") { + if (version == "0.12") { + // Hive 0.12 doesn't support creating permanent functions + intercept[NoSuchPermanentFunctionException] { + client.dropFunction("default", "func2") + } + } else { + // No exception should be thrown + client.dropFunction("default", "func2") + assert(client.listFunctions("default", "fun.*").size == 0) + } + } + + /////////////////////////////////////////////////////////////////////////// + // SQL related API + /////////////////////////////////////////////////////////////////////////// + + test(s"$version: sql set command") { + client.runSqlHive("SET spark.sql.test.key=1") + } + + test(s"$version: sql create index and reset") { client.runSqlHive("CREATE TABLE indexed_table (key INT)") client.runSqlHive("CREATE INDEX index_1 ON TABLE indexed_table(key) " + "as 'COMPACT' WITH DEFERRED REBUILD") + } + + /////////////////////////////////////////////////////////////////////////// + // Miscellaneous API + /////////////////////////////////////////////////////////////////////////// + + test(s"$version: version") { + assert(client.version.fullVersion.startsWith(version)) + } + + test(s"$version: getConf") { + 
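// ---- Editor's aside: illustrative sketch, not part of the patch above ----
// On Hive 0.12 the function tests above only assert that the right exception type surfaces.
// The ScalaTest idiom they rely on, in a self-contained form (the exception and message here
// are placeholders, not the suite's own):
import org.scalatest.Assertions.intercept

val err = intercept[IllegalArgumentException] {
  require(false, "permanent functions are not supported by this metastore version")
}
assert(err.getMessage.contains("not supported"))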
assert("success" === client.getConf("test", null)) + } + + test(s"$version: setOut") { + client.setOut(new PrintStream(new ByteArrayOutputStream())) + } + + test(s"$version: setInfo") { + client.setInfo(new PrintStream(new ByteArrayOutputStream())) + } + + test(s"$version: setError") { + client.setError(new PrintStream(new ByteArrayOutputStream())) + } + + test(s"$version: newSession") { + val newClient = client.newSession() + assert(newClient != null) + } + + test(s"$version: withHiveState and addJar") { + val newClassPath = "." + client.addJar(newClassPath) + client.withHiveState { + // No exception should be thrown. + // withHiveState changes the classloader to MutableURLClassLoader + val classLoader = Thread.currentThread().getContextClassLoader + .asInstanceOf[MutableURLClassLoader] + + val urls = classLoader.getURLs() + urls.contains(new File(newClassPath).toURI.toURL) + } + } + + test(s"$version: reset") { + // Clears all database, tables, functions... client.reset() + assert(client.listTables("default").isEmpty) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 0ba72b033f725..2dcf13c02a466 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -177,30 +177,30 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te (Seq[Integer](3), null, null)).toDF("key", "value1", "value2") data3.write.saveAsTable("agg3") - val emptyDF = sqlContext.createDataFrame( + val emptyDF = spark.createDataFrame( sparkContext.emptyRDD[Row], StructType(StructField("key", StringType) :: StructField("value", IntegerType) :: Nil)) - emptyDF.registerTempTable("emptyTable") + emptyDF.createOrReplaceTempView("emptyTable") // Register UDAFs - sqlContext.udf.register("mydoublesum", new MyDoubleSum) - sqlContext.udf.register("mydoubleavg", new MyDoubleAvg) - sqlContext.udf.register("longProductSum", new LongProductSum) + spark.udf.register("mydoublesum", new MyDoubleSum) + spark.udf.register("mydoubleavg", new MyDoubleAvg) + spark.udf.register("longProductSum", new LongProductSum) } override def afterAll(): Unit = { try { - sqlContext.sql("DROP TABLE IF EXISTS agg1") - sqlContext.sql("DROP TABLE IF EXISTS agg2") - sqlContext.sql("DROP TABLE IF EXISTS agg3") - sqlContext.dropTempTable("emptyTable") + spark.sql("DROP TABLE IF EXISTS agg1") + spark.sql("DROP TABLE IF EXISTS agg2") + spark.sql("DROP TABLE IF EXISTS agg3") + spark.catalog.dropTempView("emptyTable") } finally { super.afterAll() } } test("group by function") { - Seq((1, 2)).toDF("a", "b").registerTempTable("data") + Seq((1, 2)).toDF("a", "b").createOrReplaceTempView("data") checkAnswer( sql("SELECT floor(a) AS a, collect_set(b) FROM data GROUP BY floor(a) ORDER BY a"), @@ -210,7 +210,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te test("empty table") { // If there is no GROUP BY clause and the table is empty, we will generate a single row. 
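// ---- Editor's aside: illustrative sketch, not part of the patch above ----
// The suite above registers custom aggregates (MyDoubleSum, MyDoubleAvg, LongProductSum). A
// compact example of the same UserDefinedAggregateFunction contract -- a hypothetical
// sum-of-squares aggregate, registered the same way as the UDAFs above:
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

class SumOfSquares extends UserDefinedAggregateFunction {
  override def inputSchema: StructType = StructType(StructField("value", DoubleType) :: Nil)
  override def bufferSchema: StructType = StructType(StructField("acc", DoubleType) :: Nil)
  override def dataType: DataType = DoubleType
  override def deterministic: Boolean = true
  override def initialize(buffer: MutableAggregationBuffer): Unit = { buffer(0) = 0.0 }
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) buffer(0) = buffer.getDouble(0) + input.getDouble(0) * input.getDouble(0)
  }
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
  }
  override def evaluate(buffer: Row): Any = buffer.getDouble(0)
}
// spark.udf.register("sumOfSquares", new SumOfSquares)  // then callable from SQL like the UDAFs above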
checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT | AVG(value), @@ -227,7 +227,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te Row(null, 0, 0, 0, null, null, null, null, null) :: Nil) checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT | AVG(value), @@ -246,7 +246,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te // If there is a GROUP BY clause and the table is empty, there is no output. checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT | AVG(value), @@ -266,7 +266,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te test("null literal") { checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT | AVG(null), @@ -282,7 +282,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te test("only do grouping") { checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT key |FROM agg1 @@ -291,7 +291,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil) checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT DISTINCT value1, key |FROM agg2 @@ -308,7 +308,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te Row(null, null) :: Nil) checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT value1, key |FROM agg2 @@ -326,7 +326,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te Row(null, null) :: Nil) checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT DISTINCT key |FROM agg3 @@ -341,7 +341,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te Row(Seq[Integer](3)) :: Nil) checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT value1, key |FROM agg3 @@ -363,7 +363,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te test("case in-sensitive resolution") { checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT avg(value), kEY - 100 |FROM agg1 @@ -372,7 +372,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te Row(20.0, -99) :: Row(-0.5, -98) :: Row(null, -97) :: Row(10.0, null) :: Nil) checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT sum(distinct value1), kEY - 100, count(distinct value1) |FROM agg2 @@ -381,7 +381,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te Row(40, -99, 2) :: Row(0, -98, 2) :: Row(null, -97, 0) :: Row(30, null, 3) :: Nil) checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT valUe * key - 100 |FROM agg1 @@ -397,7 +397,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te test("test average no key in output") { checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT avg(value) |FROM agg1 @@ -408,7 +408,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te test("test average") { checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT key, avg(value) |FROM agg1 @@ -417,7 +417,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te Row(1, 20.0) :: Row(2, -0.5) :: Row(3, null) :: Row(null, 10.0) :: Nil) checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT key, mean(value) |FROM agg1 @@ -426,7 +426,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te Row(1, 20.0) :: Row(2, -0.5) :: Row(3, null) :: Row(null, 10.0) :: Nil) checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT avg(value), key |FROM agg1 @@ -435,7 +435,7 @@ 
abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te Row(20.0, 1) :: Row(-0.5, 2) :: Row(null, 3) :: Row(10.0, null) :: Nil) checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT avg(value) + 1.5, key + 10 |FROM agg1 @@ -444,7 +444,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te Row(21.5, 11) :: Row(1.0, 12) :: Row(null, 13) :: Row(11.5, null) :: Nil) checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT avg(value) FROM agg1 """.stripMargin), @@ -456,7 +456,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te // deterministic. withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT | first_valUE(key), @@ -472,7 +472,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te Row(null, 3, null, 3, 1, 3, 1, 3) :: Nil) checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT | first_valUE(key), @@ -491,7 +491,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te test("udaf") { checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT | key, @@ -511,7 +511,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te test("interpreted aggregate function") { checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT mydoublesum(value), key |FROM agg1 @@ -520,14 +520,14 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te Row(60.0, 1) :: Row(-1.0, 2) :: Row(null, 3) :: Row(30.0, null) :: Nil) checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT mydoublesum(value) FROM agg1 """.stripMargin), Row(89.0) :: Nil) checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT mydoublesum(null) """.stripMargin), @@ -536,7 +536,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te test("interpreted and expression-based aggregation functions") { checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT mydoublesum(value), key, avg(value) |FROM agg1 @@ -548,7 +548,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te Row(30.0, null, 10.0) :: Nil) checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT | mydoublesum(value + 1.5 * key), @@ -568,7 +568,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te test("single distinct column set") { // DISTINCT is not meaningful with Max and Min, so we just ignore the DISTINCT keyword. 
checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT | min(distinct value1), @@ -581,7 +581,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te Row(-60, 70.0, 101.0/9.0, 5.6, 100)) checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT | mydoubleavg(distinct value1), @@ -600,7 +600,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te Row(110.0, 10.0, 20.0, null, 109.0, 11.0, 30.0) :: Nil) checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT | key, @@ -618,7 +618,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te Row(null, 110.0, 60.0, 30.0, 110.0, 110.0) :: Nil) checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT | count(value1), @@ -637,7 +637,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te test("single distinct multiple columns set") { checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT | key, @@ -653,7 +653,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te test("multiple distinct multiple columns sets") { checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT | key, @@ -681,7 +681,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te test("test count") { checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT | count(value2), @@ -704,7 +704,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te Row(0, null, 1, 1, null) :: Nil) checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT | count(value2), @@ -783,31 +783,31 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te (5, 8, 17), (6, 2, 11)).toDF("a", "b", "c") - covar_tab.registerTempTable("covar_tab") + covar_tab.createOrReplaceTempView("covar_tab") checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT corr(b, c) FROM covar_tab WHERE a < 1 """.stripMargin), Row(null) :: Nil) checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT corr(b, c) FROM covar_tab WHERE a < 3 """.stripMargin), Row(null) :: Nil) checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT corr(b, c) FROM covar_tab WHERE a = 3 """.stripMargin), Row(Double.NaN) :: Nil) checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT a, corr(b, c) FROM covar_tab GROUP BY a ORDER BY a """.stripMargin), @@ -818,7 +818,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te Row(5, Double.NaN) :: Row(6, Double.NaN) :: Nil) - val corr7 = sqlContext.sql("SELECT corr(b, c) FROM covar_tab").collect()(0).getDouble(0) + val corr7 = spark.sql("SELECT corr(b, c) FROM covar_tab").collect()(0).getDouble(0) assert(math.abs(corr7 - 0.6633880657639323) < 1e-12) } @@ -852,7 +852,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te } test("no aggregation function (SPARK-11486)") { - val df = sqlContext.range(20).selectExpr("id", "repeat(id, 1) as s") + val df = spark.range(20).selectExpr("id", "repeat(id, 1) as s") .groupBy("s").count() .groupBy().count() checkAnswer(df, Row(20) :: Nil) @@ -869,10 +869,10 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te DateType, TimestampType, ArrayType(IntegerType), MapType(StringType, LongType), struct, new UDT.MyDenseVectorUDT()) - // Right now, we will use SortBasedAggregate to handle UDAFs. - // UnsafeRow.mutableFieldTypes.asScala.toSeq will trigger SortBasedAggregate to use + // Right now, we will use SortAggregate to handle UDAFs. 
+ // UnsafeRow.mutableFieldTypes.asScala.toSeq will trigger SortAggregate to use // UnsafeRow as the aggregation buffer. While, dataTypes will trigger - // SortBasedAggregate to use a safe row as the aggregation buffer. + // SortAggregate to use a safe row as the aggregation buffer. Seq(dataTypes, UnsafeRow.mutableFieldTypes.asScala.toSeq).foreach { dataTypes => val fields = dataTypes.zipWithIndex.map { case (dataType, index) => StructField(s"col$index", dataType, nullable = true) @@ -906,8 +906,8 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te } // Create a DF for the schema with random data. - val rdd = sqlContext.sparkContext.parallelize(data, 1) - val df = sqlContext.createDataFrame(rdd, schema) + val rdd = spark.sparkContext.parallelize(data, 1) + val df = spark.createDataFrame(rdd, schema) val allColumns = df.schema.fields.map(f => col(f.name)) val expectedAnswer = @@ -923,8 +923,8 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te } test("udaf without specifying inputSchema") { - withTempTable("noInputSchemaUDAF") { - sqlContext.udf.register("noInputSchema", new ScalaAggregateFunctionWithoutInputSchema) + withTempView("noInputSchemaUDAF") { + spark.udf.register("noInputSchema", new ScalaAggregateFunctionWithoutInputSchema) val data = Row(1, Seq(Row(1), Row(2), Row(3))) :: @@ -935,13 +935,13 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te StructField("key", IntegerType) :: StructField("myArray", ArrayType(StructType(StructField("v", IntegerType) :: Nil))) :: Nil) - sqlContext.createDataFrame( + spark.createDataFrame( sparkContext.parallelize(data, 2), schema) - .registerTempTable("noInputSchemaUDAF") + .createOrReplaceTempView("noInputSchemaUDAF") checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT key, noInputSchema(myArray) |FROM noInputSchemaUDAF @@ -950,7 +950,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te Row(1, 21) :: Row(2, -10) :: Nil) checkAnswer( - sqlContext.sql( + spark.sql( """ |SELECT noInputSchema(myArray) |FROM noInputSchemaUDAF @@ -958,13 +958,44 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te Row(11) :: Nil) } } + + test("SPARK-15206: single distinct aggregate function in having clause") { + checkAnswer( + sql( + """ + |select key, count(distinct value1) + |from agg2 group by key + |having count(distinct value1) > 0 + """.stripMargin), + Seq( + Row(null, 3), + Row(1, 2), + Row(2, 2) + ) + ) + } + + test("SPARK-15206: multiple distinct aggregate function in having clause") { + checkAnswer( + sql( + """ + |select key, count(distinct value1), count(distinct value2) + |from agg2 group by key + |having count(distinct value1) > 0 and count(distinct value2) = 3 + """.stripMargin), + Seq( + Row(null, 3, 3), + Row(1, 2, 3) + ) + ) + } } -class TungstenAggregationQuerySuite extends AggregationQuerySuite +class HashAggregationQuerySuite extends AggregationQuerySuite -class TungstenAggregationQueryWithControlledFallbackSuite extends AggregationQuerySuite { +class HashAggregationQueryWithControlledFallbackSuite extends AggregationQuerySuite { override protected def checkAnswer(actual: => DataFrame, expectedAnswer: Seq[Row]): Unit = { Seq(0, 10).foreach { maxColumnarHashMapColumns => @@ -976,13 +1007,13 @@ class TungstenAggregationQueryWithControlledFallbackSuite extends AggregationQue // Create a new df to make sure its physical operator picks up // spark.sql.TungstenAggregate.testFallbackStartsAt. 
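// ---- Editor's aside: illustrative sketch, not part of the patch above ----
// The SPARK-15206 tests above exercise distinct aggregates inside a HAVING clause. The shape of
// the query, runnable against the suite's agg2 table (columns key, value1, value2):
spark.sql(
  """SELECT key, count(DISTINCT value1), count(DISTINCT value2)
    |FROM agg2
    |GROUP BY key
    |HAVING count(DISTINCT value1) > 0 AND count(DISTINCT value2) = 3
  """.stripMargin).show()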
// todo: remove it? - val newActual = Dataset.ofRows(sqlContext.sparkSession, actual.logicalPlan) + val newActual = Dataset.ofRows(spark, actual.logicalPlan) QueryTest.checkAnswer(newActual, expectedAnswer) match { case Some(errorMessage) => val newErrorMessage = s""" - |The following aggregation query failed when using TungstenAggregate with + |The following aggregation query failed when using HashAggregate with |controlled fallback (it falls back to bytes to bytes map once it has processed |${fallbackStartsAt - 1} input rows and to sort-based aggregation once it has |processed $fallbackStartsAt input rows). The query is ${actual.queryExecution} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala index f5cd73d45ed75..1583a448efaf7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala @@ -30,9 +30,9 @@ class ConcurrentHiveSuite extends SparkFunSuite with BeforeAndAfterAll { conf.set("spark.ui.enabled", "false") val ts = new TestHiveContext(new SparkContext("local", s"TestSQLContext$i", conf)) - ts.executeSql("SHOW TABLES").toRdd.collect() - ts.executeSql("SELECT * FROM src").toRdd.collect() - ts.executeSql("SHOW TABLES").toRdd.collect() + ts.sessionState.executeSql("SHOW TABLES").toRdd.collect() + ts.sessionState.executeSql("SELECT * FROM src").toRdd.collect() + ts.sessionState.executeSql("SHOW TABLES").toRdd.collect() } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala index 8b3f2d1a0cd07..446029fdc6e43 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala @@ -19,7 +19,8 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SaveMode} import org.apache.spark.sql.catalyst.analysis.NoSuchTableException -import org.apache.spark.sql.hive.test.{TestHive, TestHiveSingleton} +import org.apache.spark.sql.execution.command.CreateDataSourceTableUtils._ +import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils class HiveCommandSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { @@ -35,7 +36,7 @@ class HiveCommandSuite extends QueryTest with SQLTestUtils with TestHiveSingleto sql( """ - |CREATE EXTERNAL TABLE parquet_tab2 (c1 INT, c2 STRING) + |CREATE TABLE parquet_tab2 (c1 INT, c2 STRING) |STORED AS PARQUET |TBLPROPERTIES('prop1Key'="prop1Val", '`prop2Key`'="prop2Val") """.stripMargin) @@ -100,32 +101,30 @@ class HiveCommandSuite extends QueryTest with SQLTestUtils with TestHiveSingleto test("show tblproperties of data source tables - basic") { checkAnswer( - sql("SHOW TBLPROPERTIES parquet_tab1") - .filter(s"key = 'spark.sql.sources.provider'"), - Row("spark.sql.sources.provider", "org.apache.spark.sql.parquet.DefaultSource") :: Nil + sql("SHOW TBLPROPERTIES parquet_tab1").filter(s"key = '$DATASOURCE_PROVIDER'"), + Row(DATASOURCE_PROVIDER, "org.apache.spark.sql.parquet.DefaultSource") :: Nil ) checkAnswer( - sql("SHOW TBLPROPERTIES parquet_tab1(spark.sql.sources.provider)"), + sql(s"SHOW TBLPROPERTIES parquet_tab1($DATASOURCE_PROVIDER)"), 
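// ---- Editor's aside: illustrative sketch, not part of the patch above ----
// The assertions above read data source metadata back through SHOW TBLPROPERTIES, both listing
// every property and asking for a single key (assumes the suite's parquet_tab1 table):
spark.sql("SHOW TBLPROPERTIES parquet_tab1").show(false)
spark.sql("SHOW TBLPROPERTIES parquet_tab1('spark.sql.sources.provider')").show()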
Row("org.apache.spark.sql.parquet.DefaultSource") :: Nil ) checkAnswer( - sql("SHOW TBLPROPERTIES parquet_tab1") - .filter(s"key = 'spark.sql.sources.schema.numParts'"), - Row("spark.sql.sources.schema.numParts", "1") :: Nil + sql("SHOW TBLPROPERTIES parquet_tab1").filter(s"key = '$DATASOURCE_SCHEMA_NUMPARTS'"), + Row(DATASOURCE_SCHEMA_NUMPARTS, "1") :: Nil ) checkAnswer( - sql("SHOW TBLPROPERTIES parquet_tab1('spark.sql.sources.schema.numParts')"), + sql(s"SHOW TBLPROPERTIES parquet_tab1('$DATASOURCE_SCHEMA_NUMPARTS')"), Row("1")) } test("show tblproperties for datasource table - errors") { - val message1 = intercept[AnalysisException] { + val message1 = intercept[NoSuchTableException] { sql("SHOW TBLPROPERTIES badtable") }.getMessage - assert(message1.contains("Table or View badtable not found in database default")) + assert(message1.contains("Table or view 'badtable' not found in database 'default'")) // When key is not found, a row containing the error is returned. checkAnswer( @@ -140,7 +139,7 @@ class HiveCommandSuite extends QueryTest with SQLTestUtils with TestHiveSingleto } test("show tblproperties for spark temporary table - empty row") { - withTempTable("parquet_temp") { + withTempView("parquet_temp") { sql( """ |CREATE TEMPORARY TABLE parquet_temp (c1 INT, c2 STRING) @@ -269,6 +268,72 @@ class HiveCommandSuite extends QueryTest with SQLTestUtils with TestHiveSingleto } } + test("Truncate Table") { + withTable("non_part_table", "part_table") { + sql( + """ + |CREATE TABLE non_part_table (employeeID INT, employeeName STRING) + |ROW FORMAT DELIMITED + |FIELDS TERMINATED BY '|' + |LINES TERMINATED BY '\n' + """.stripMargin) + + val testData = hiveContext.getHiveFile("data/files/employee.dat").getCanonicalPath + + sql(s"""LOAD DATA LOCAL INPATH "$testData" INTO TABLE non_part_table""") + checkAnswer( + sql("SELECT * FROM non_part_table WHERE employeeID = 16"), + Row(16, "john") :: Nil) + + val testResults = sql("SELECT * FROM non_part_table").collect() + + sql("TRUNCATE TABLE non_part_table") + checkAnswer(sql("SELECT * FROM non_part_table"), Seq.empty[Row]) + + sql( + """ + |CREATE TABLE part_table (employeeID INT, employeeName STRING) + |PARTITIONED BY (c STRING, d STRING) + |ROW FORMAT DELIMITED + |FIELDS TERMINATED BY '|' + |LINES TERMINATED BY '\n' + """.stripMargin) + + sql(s"""LOAD DATA LOCAL INPATH "$testData" INTO TABLE part_table PARTITION(c="1", d="1")""") + checkAnswer( + sql("SELECT employeeID, employeeName FROM part_table WHERE c = '1' AND d = '1'"), + testResults) + + sql(s"""LOAD DATA LOCAL INPATH "$testData" INTO TABLE part_table PARTITION(c="1", d="2")""") + checkAnswer( + sql("SELECT employeeID, employeeName FROM part_table WHERE c = '1' AND d = '2'"), + testResults) + + sql(s"""LOAD DATA LOCAL INPATH "$testData" INTO TABLE part_table PARTITION(c="2", d="2")""") + checkAnswer( + sql("SELECT employeeID, employeeName FROM part_table WHERE c = '2' AND d = '2'"), + testResults) + + sql("TRUNCATE TABLE part_table PARTITION(c='1', d='1')") + checkAnswer( + sql("SELECT employeeID, employeeName FROM part_table WHERE c = '1' AND d = '1'"), + Seq.empty[Row]) + checkAnswer( + sql("SELECT employeeID, employeeName FROM part_table WHERE c = '1' AND d = '2'"), + testResults) + + sql("TRUNCATE TABLE part_table PARTITION(c='1')") + checkAnswer( + sql("SELECT employeeID, employeeName FROM part_table WHERE c = '1'"), + Seq.empty[Row]) + + sql("TRUNCATE TABLE part_table") + checkAnswer( + sql("SELECT employeeID, employeeName FROM part_table"), + Seq.empty[Row]) + } + } + test("show 
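// ---- Editor's aside: illustrative sketch, not part of the patch above ----
// The "Truncate Table" test above covers full-table and partition-level truncation. The
// statements it exercises, in isolation (table and partition names follow the test):
spark.sql("TRUNCATE TABLE part_table PARTITION (c = '1', d = '1')")  // a single partition
spark.sql("TRUNCATE TABLE part_table PARTITION (c = '1')")           // every partition with c = '1'
spark.sql("TRUNCATE TABLE part_table")                               // the whole table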
columns") { checkAnswer( sql("SHOW COLUMNS IN parquet_tab3"), @@ -289,7 +354,7 @@ class HiveCommandSuite extends QueryTest with SQLTestUtils with TestHiveSingleto val message = intercept[NoSuchTableException] { sql("SHOW COLUMNS IN badtable FROM default") }.getMessage - assert(message.contains("badtable not found in database")) + assert(message.contains("'badtable' not found in database")) } test("show partitions - show everything") { @@ -332,32 +397,31 @@ class HiveCommandSuite extends QueryTest with SQLTestUtils with TestHiveSingleto } test("show partitions - empty row") { - withTempTable("parquet_temp") { + withTempView("parquet_temp") { sql( """ |CREATE TEMPORARY TABLE parquet_temp (c1 INT, c2 STRING) |USING org.apache.spark.sql.parquet.DefaultSource """.stripMargin) // An empty sequence of row is returned for session temporary table. - val message1 = intercept[AnalysisException] { + intercept[NoSuchTableException] { sql("SHOW PARTITIONS parquet_temp") - }.getMessage - assert(message1.contains("is not allowed on a temporary table")) + } - val message2 = intercept[AnalysisException] { + val message1 = intercept[AnalysisException] { sql("SHOW PARTITIONS parquet_tab3") }.getMessage - assert(message2.contains("not allowed on a table that is not partitioned")) + assert(message1.contains("not allowed on a table that is not partitioned")) - val message3 = intercept[AnalysisException] { + val message2 = intercept[AnalysisException] { sql("SHOW PARTITIONS parquet_tab4 PARTITION(abcd=2015, xyz=1)") }.getMessage - assert(message3.contains("Non-partitioning column(s) [abcd, xyz] are specified")) + assert(message2.contains("Non-partitioning column(s) [abcd, xyz] are specified")) - val message4 = intercept[AnalysisException] { + val message3 = intercept[AnalysisException] { sql("SHOW PARTITIONS parquet_view1") }.getMessage - assert(message4.contains("is not allowed on a view or index table")) + assert(message3.contains("is not allowed on a view")) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index b12f3aafefb8b..f5d2f02d512be 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.hive.execution import java.io._ import java.nio.charset.StandardCharsets +import java.util import scala.util.control.NonFatal @@ -29,7 +30,7 @@ import org.apache.spark.sql.catalyst.SQLBuilder import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.execution.command.{DescribeTableCommand, ExplainCommand, SetCommand, ShowColumnsCommand} +import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.hive.{InsertIntoHiveTable => LogicalInsertIntoHiveTable} import org.apache.spark.sql.hive.test.{TestHive, TestHiveQueryExecution} @@ -37,7 +38,7 @@ import org.apache.spark.sql.hive.test.{TestHive, TestHiveQueryExecution} * Allows the creations of tests that execute the same query against both hive * and catalyst, comparing the results. * - * The "golden" results from Hive are cached in an retrieved both from the classpath and + * The "golden" results from Hive are cached in and retrieved both from the classpath and * [[answerCache]] to speed up testing. 
* * See the documentation of public vals in this class for information on how test execution can be @@ -347,6 +348,7 @@ abstract class HiveComparisonTest queryString.replace("../../data", testDataPath)) val containsCommands = originalQuery.analyzed.collectFirst { case _: Command => () + case _: InsertIntoTable => () case _: LogicalInsertIntoHiveTable => () }.nonEmpty @@ -414,8 +416,8 @@ abstract class HiveComparisonTest // We will ignore the ExplainCommand, ShowFunctions, DescribeFunction if ((!hiveQuery.logical.isInstanceOf[ExplainCommand]) && - (!hiveQuery.logical.isInstanceOf[ShowFunctions]) && - (!hiveQuery.logical.isInstanceOf[DescribeFunction]) && + (!hiveQuery.logical.isInstanceOf[ShowFunctionsCommand]) && + (!hiveQuery.logical.isInstanceOf[DescribeFunctionCommand]) && preparedHive != catalyst) { val hivePrintOut = s"== HIVE - ${preparedHive.size} row(s) ==" +: preparedHive @@ -497,6 +499,8 @@ abstract class HiveComparisonTest } } + val savedSettings = new util.HashMap[String, String] + savedSettings.putAll(TestHive.conf.settings) try { try { if (tryWithoutResettingFirst && canSpeculativelyTryWithoutReset) { @@ -515,6 +519,9 @@ abstract class HiveComparisonTest } } catch { case tf: org.scalatest.exceptions.TestFailedException => throw tf + } finally { + TestHive.conf.settings.clear() + TestHive.conf.settings.putAll(savedSettings) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 687a4a7e512ad..0a6ccbed84930 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -20,23 +20,43 @@ package org.apache.spark.sql.hive.execution import java.io.File import org.apache.hadoop.fs.Path +import org.scalatest.BeforeAndAfterEach -import org.apache.spark.sql.{AnalysisException, QueryTest, SaveMode} -import org.apache.spark.sql.catalyst.catalog.CatalogTableType +import org.apache.spark.internal.config._ +import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SaveMode} +import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogTable, CatalogTableType} import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.execution.command.{CreateDataSourceTableUtils, DDLUtils} +import org.apache.spark.sql.execution.command.CreateDataSourceTableUtils._ +import org.apache.spark.sql.execution.datasources.CaseInsensitiveMap import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils -class HiveDDLSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { - import hiveContext.implicits._ +class HiveDDLSuite + extends QueryTest with SQLTestUtils with TestHiveSingleton with BeforeAndAfterEach { + import spark.implicits._ + override def afterEach(): Unit = { + try { + // drop all databases, tables and functions after each test + spark.sessionState.catalog.reset() + } finally { + super.afterEach() + } + } // check if the directory for recording the data of the table exists. 
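As background for the path checks in this suite: a managed table's data is expected to live under the session warehouse directory, either directly (for the `default` database) or under a `<db>.db` subdirectory. A minimal sketch of that naming convention, stated as an assumption about the usual layout rather than a quote of the catalog's actual implementation:

```scala
import org.apache.spark.sql.catalyst.TableIdentifier

// Sketch only: the default location of a managed table, assuming the
// "<warehouse>/<db>.db/<table>" layout (tables in `default` sit directly under the warehouse).
def expectedManagedPath(warehouse: String, table: TableIdentifier): String = {
  val dbDir = table.database.filter(_ != "default").map(_ + ".db/").getOrElse("")
  s"$warehouse/$dbDir${table.table}"
}
```

The `tableDirectoryExists` helper that follows performs the real check against the file system, and its new `dbPath` parameter covers databases created with an explicit `LOCATION`.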
- private def tableDirectoryExists(tableIdentifier: TableIdentifier): Boolean = { + private def tableDirectoryExists( + tableIdentifier: TableIdentifier, + dbPath: Option[String] = None): Boolean = { val expectedTablePath = - hiveContext.sessionState.catalog.hiveDefaultTableFilePath(tableIdentifier) + if (dbPath.isEmpty) { + hiveContext.sessionState.catalog.hiveDefaultTableFilePath(tableIdentifier) + } else { + new Path(new Path(dbPath.get), tableIdentifier.table).toString + } val filesystemPath = new Path(expectedTablePath) - val fs = filesystemPath.getFileSystem(hiveContext.sessionState.newHadoopConf()) + val fs = filesystemPath.getFileSystem(spark.sessionState.newHadoopConf()) fs.exists(filesystemPath) } @@ -56,7 +76,7 @@ class HiveDDLSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } - test("drop managed tables") { + test("drop external tables in default database") { withTempDir { tmpDir => val tabName = "tab1" withTable(tabName) { @@ -70,20 +90,17 @@ class HiveDDLSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { """.stripMargin) val hiveTable = - hiveContext.sessionState.catalog - .getTableMetadata(TableIdentifier(tabName, Some("default"))) - // It is a managed table, although it uses external in SQL - assert(hiveTable.tableType == CatalogTableType.MANAGED) + spark.sessionState.catalog.getTableMetadata(TableIdentifier(tabName, Some("default"))) + assert(hiveTable.tableType == CatalogTableType.EXTERNAL) assert(tmpDir.listFiles.nonEmpty) sql(s"DROP TABLE $tabName") - // The data are deleted since the table type is not EXTERNAL - assert(tmpDir.listFiles == null) + assert(tmpDir.listFiles.nonEmpty) } } } - test("drop external data source table") { + test("drop external data source table in default database") { withTempDir { tmpDir => val tabName = "tab1" withTable(tabName) { @@ -99,8 +116,7 @@ class HiveDDLSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } val hiveTable = - hiveContext.sessionState.catalog - .getTableMetadata(TableIdentifier(tabName, Some("default"))) + spark.sessionState.catalog.getTableMetadata(TableIdentifier(tabName, Some("default"))) // This data source table is external table assert(hiveTable.tableType == CatalogTableType.EXTERNAL) @@ -113,7 +129,7 @@ class HiveDDLSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } test("create table and view with comment") { - val catalog = hiveContext.sessionState.catalog + val catalog = spark.sessionState.catalog val tabName = "tab1" withTable(tabName) { sql(s"CREATE TABLE $tabName(c1 int) COMMENT 'BLABLA'") @@ -122,14 +138,17 @@ class HiveDDLSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { sql(s"CREATE VIEW $viewName COMMENT 'no comment' AS SELECT * FROM $tabName") val tableMetadata = catalog.getTableMetadata(TableIdentifier(tabName, Some("default"))) val viewMetadata = catalog.getTableMetadata(TableIdentifier(viewName, Some("default"))) - assert(tableMetadata.properties.get("comment") == Option("BLABLA")) - assert(viewMetadata.properties.get("comment") == Option("no comment")) + assert(tableMetadata.comment == Option("BLABLA")) + assert(viewMetadata.comment == Option("no comment")) + // Ensure that `comment` is removed from the table property + assert(tableMetadata.properties.get("comment").isEmpty) + assert(viewMetadata.properties.get("comment").isEmpty) } } } test("add/drop partitions - external table") { - val catalog = hiveContext.sessionState.catalog + val catalog = spark.sessionState.catalog withTempDir { tmpDir => val basePath = 
tmpDir.getCanonicalPath val partitionPath_1stCol_part1 = new File(basePath + "/ds=2008-04-08") @@ -170,10 +189,17 @@ class HiveDDLSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { // After data insertion, all the directory are not empty assert(dirSet.forall(dir => dir.listFiles.nonEmpty)) + val message = intercept[AnalysisException] { + sql(s"ALTER TABLE $externalTab DROP PARTITION (ds='2008-04-09', unknownCol='12')") + } + assert(message.getMessage.contains( + "Partition spec is invalid. The spec (ds, unknowncol) must be contained within the " + + "partition spec (ds, hr) defined in table '`default`.`exttable_with_partitions`'")) + sql( s""" |ALTER TABLE $externalTab DROP PARTITION (ds='2008-04-08'), - |PARTITION (ds='2008-04-09', hr='12') + |PARTITION (hr='12') """.stripMargin) assert(catalog.listPartitions(TableIdentifier(externalTab)).map(_.spec).toSet == Set(Map("ds" -> "2008-04-09", "hr" -> "11"))) @@ -196,7 +222,7 @@ class HiveDDLSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { test("drop views") { withTable("tab1") { val tabName = "tab1" - sqlContext.range(10).write.saveAsTable("tab1") + spark.range(10).write.saveAsTable("tab1") withView("view1") { val viewName = "view1" @@ -217,11 +243,11 @@ class HiveDDLSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { test("alter views - rename") { val tabName = "tab1" withTable(tabName) { - sqlContext.range(10).write.saveAsTable(tabName) + spark.range(10).write.saveAsTable(tabName) val oldViewName = "view1" val newViewName = "view2" withView(oldViewName, newViewName) { - val catalog = hiveContext.sessionState.catalog + val catalog = spark.sessionState.catalog sql(s"CREATE VIEW $oldViewName AS SELECT * FROM $tabName") assert(catalog.tableExists(TableIdentifier(oldViewName))) @@ -236,10 +262,10 @@ class HiveDDLSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { test("alter views - set/unset tblproperties") { val tabName = "tab1" withTable(tabName) { - sqlContext.range(10).write.saveAsTable(tabName) + spark.range(10).write.saveAsTable(tabName) val viewName = "view1" withView(viewName) { - val catalog = hiveContext.sessionState.catalog + val catalog = spark.sessionState.catalog sql(s"CREATE VIEW $viewName AS SELECT * FROM $tabName") assert(catalog.getTableMetadata(TableIdentifier(viewName)) @@ -266,7 +292,7 @@ class HiveDDLSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { sql(s"ALTER VIEW $viewName UNSET TBLPROPERTIES ('p')") }.getMessage assert(message.contains( - "attempted to unset non-existent property 'p' in table '`view1`'")) + "Attempted to unset non-existent property 'p' in table '`default`.`view1`'")) } } } @@ -274,11 +300,11 @@ class HiveDDLSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { test("alter views and alter table - misuse") { val tabName = "tab1" withTable(tabName) { - sqlContext.range(10).write.saveAsTable(tabName) + spark.range(10).write.saveAsTable(tabName) val oldViewName = "view1" val newViewName = "view2" withView(oldViewName, newViewName) { - val catalog = hiveContext.sessionState.catalog + val catalog = spark.sessionState.catalog sql(s"CREATE VIEW $oldViewName AS SELECT * FROM $tabName") assert(catalog.tableExists(TableIdentifier(tabName))) @@ -326,6 +352,66 @@ class HiveDDLSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } + test("alter table partition - storage information") { + sql("CREATE TABLE boxes (height INT, length INT) PARTITIONED BY (width INT)") + sql("INSERT OVERWRITE TABLE boxes PARTITION 
(width=4) SELECT 4, 4") + val catalog = spark.sessionState.catalog + val expectedSerde = "com.sparkbricks.serde.ColumnarSerDe" + val expectedSerdeProps = Map("compress" -> "true") + val expectedSerdePropsString = + expectedSerdeProps.map { case (k, v) => s"'$k'='$v'" }.mkString(", ") + val oldPart = catalog.getPartition(TableIdentifier("boxes"), Map("width" -> "4")) + assume(oldPart.storage.serde != Some(expectedSerde), "bad test: serde was already set") + assume(oldPart.storage.serdeProperties.filterKeys(expectedSerdeProps.contains) != + expectedSerdeProps, "bad test: serde properties were already set") + sql(s"""ALTER TABLE boxes PARTITION (width=4) + | SET SERDE '$expectedSerde' + | WITH SERDEPROPERTIES ($expectedSerdePropsString) + |""".stripMargin) + val newPart = catalog.getPartition(TableIdentifier("boxes"), Map("width" -> "4")) + assert(newPart.storage.serde == Some(expectedSerde)) + assume(newPart.storage.serdeProperties.filterKeys(expectedSerdeProps.contains) == + expectedSerdeProps) + } + + test("MSCK REPAIR TABLE") { + val catalog = spark.sessionState.catalog + val tableIdent = TableIdentifier("tab1") + sql("CREATE TABLE tab1 (height INT, length INT) PARTITIONED BY (a INT, b INT)") + val part1 = Map("a" -> "1", "b" -> "5") + val part2 = Map("a" -> "2", "b" -> "6") + val root = new Path(catalog.getTableMetadata(tableIdent).storage.locationUri.get) + val fs = root.getFileSystem(spark.sparkContext.hadoopConfiguration) + // valid + fs.mkdirs(new Path(new Path(root, "a=1"), "b=5")) + fs.createNewFile(new Path(new Path(root, "a=1/b=5"), "a.csv")) // file + fs.createNewFile(new Path(new Path(root, "a=1/b=5"), "_SUCCESS")) // file + fs.mkdirs(new Path(new Path(root, "A=2"), "B=6")) + fs.createNewFile(new Path(new Path(root, "A=2/B=6"), "b.csv")) // file + fs.createNewFile(new Path(new Path(root, "A=2/B=6"), "c.csv")) // file + fs.createNewFile(new Path(new Path(root, "A=2/B=6"), ".hiddenFile")) // file + fs.mkdirs(new Path(new Path(root, "A=2/B=6"), "_temporary")) + + // invalid + fs.mkdirs(new Path(new Path(root, "a"), "b")) // bad name + fs.mkdirs(new Path(new Path(root, "b=1"), "a=1")) // wrong order + fs.mkdirs(new Path(root, "a=4")) // not enough columns + fs.createNewFile(new Path(new Path(root, "a=1"), "b=4")) // file + fs.createNewFile(new Path(new Path(root, "a=1"), "_SUCCESS")) // _SUCCESS + fs.mkdirs(new Path(new Path(root, "a=1"), "_temporary")) // _temporary + fs.mkdirs(new Path(new Path(root, "a=1"), ".b=4")) // start with . 
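The directory tree built above encodes the rule the repair command is expected to follow: only `column=value` segments that match the table's partition columns, in the right order, are registered, while plain files and names starting with `_` or `.` are skipped. A rough sketch of that filter for a single level (an approximation for illustration, not the command's actual implementation):

```scala
import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}

// Returns the children of `parent` that look like partition directories for column `col`,
// e.g. "a=1" or "A=2" for column "a"; hidden ("." prefix), temporary ("_" prefix)
// and non-directory entries are ignored.
def candidatePartitionDirs(fs: FileSystem, parent: Path, col: String): Seq[FileStatus] = {
  fs.listStatus(parent).toSeq.filter { status =>
    val name = status.getPath.getName
    status.isDirectory &&
      !name.startsWith("_") && !name.startsWith(".") &&
      name.toLowerCase.startsWith(col.toLowerCase + "=")
  }
}
```

Applied to the layout above, only `a=1/b=5` and `A=2/B=6` survive, which is exactly what the `MSCK REPAIR TABLE` call and the partition assertions in the `try` block below verify.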
+ + try { + sql("MSCK REPAIR TABLE tab1") + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == + Set(part1, part2)) + assert(catalog.getPartition(tableIdent, part1).parameters("numFiles") == "1") + assert(catalog.getPartition(tableIdent, part2).parameters("numFiles") == "2") + } finally { + fs.delete(root, true) + } + } + test("drop table using drop view") { withTable("tab1") { sql("CREATE TABLE tab1(c1 int)") @@ -338,7 +424,7 @@ class HiveDDLSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { test("drop view using drop table") { withTable("tab1") { - sqlContext.range(10).write.saveAsTable("tab1") + spark.range(10).write.saveAsTable("tab1") withView("view1") { sql("CREATE VIEW view1 AS SELECT * FROM tab1") val message = intercept[AnalysisException] { @@ -348,4 +434,538 @@ class HiveDDLSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } } + + test("create view with mismatched schema") { + withTable("tab1") { + spark.range(10).write.saveAsTable("tab1") + withView("view1") { + val e = intercept[AnalysisException] { + sql("CREATE VIEW view1 (col1, col3) AS SELECT * FROM tab1") + }.getMessage + assert(e.contains("the SELECT clause (num: `1`) does not match") + && e.contains("CREATE VIEW (num: `2`)")) + } + } + } + + test("create view with specified schema") { + withView("view1") { + sql("CREATE VIEW view1 (col1, col2) AS SELECT 1, 2") + checkAnswer( + sql("SELECT * FROM view1"), + Row(1, 2) :: Nil + ) + } + } + + test("desc table for Hive table") { + withTable("tab1") { + val tabName = "tab1" + sql(s"CREATE TABLE $tabName(c1 int)") + + assert(sql(s"DESC $tabName").collect().length == 1) + + assert( + sql(s"DESC FORMATTED $tabName").collect() + .exists(_.getString(0) == "# Storage Information")) + + assert( + sql(s"DESC EXTENDED $tabName").collect() + .exists(_.getString(0) == "# Detailed Table Information")) + } + } + + test("desc table for Hive table - partitioned table") { + withTable("tbl") { + sql("CREATE TABLE tbl(a int) PARTITIONED BY (b int)") + + assert(sql("DESC tbl").collect().containsSlice( + Seq( + Row("a", "int", null), + Row("b", "int", null), + Row("# Partition Information", "", ""), + Row("# col_name", "data_type", "comment"), + Row("b", "int", null) + ) + )) + } + } + + test("desc table for data source table using Hive Metastore") { + assume(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "hive") + val tabName = "tab1" + withTable(tabName) { + sql(s"CREATE TABLE $tabName(a int comment 'test') USING parquet ") + + checkAnswer( + sql(s"DESC $tabName").select("col_name", "data_type", "comment"), + Row("a", "int", "test") + ) + } + } + + private def createDatabaseWithLocation(tmpDir: File, dirExists: Boolean): Unit = { + val catalog = spark.sessionState.catalog + val dbName = "db1" + val tabName = "tab1" + val fs = new Path(tmpDir.toString).getFileSystem(spark.sessionState.newHadoopConf()) + withTable(tabName) { + if (dirExists) { + assert(tmpDir.listFiles.isEmpty) + } else { + assert(!fs.exists(new Path(tmpDir.toString))) + } + sql(s"CREATE DATABASE $dbName Location '$tmpDir'") + val db1 = catalog.getDatabaseMetadata(dbName) + val dbPath = "file:" + tmpDir + assert(db1 == CatalogDatabase( + dbName, + "", + if (dbPath.endsWith(File.separator)) dbPath.dropRight(1) else dbPath, + Map.empty)) + sql("USE db1") + + sql(s"CREATE TABLE $tabName as SELECT 1") + assert(tableDirectoryExists(TableIdentifier(tabName), Option(tmpDir.toString))) + + assert(tmpDir.listFiles.nonEmpty) + sql(s"DROP TABLE $tabName") + + 
assert(tmpDir.listFiles.isEmpty) + sql("USE default") + sql(s"DROP DATABASE $dbName") + assert(!fs.exists(new Path(tmpDir.toString))) + } + } + + test("create/drop database - location without pre-created directory") { + withTempPath { tmpDir => + createDatabaseWithLocation(tmpDir, dirExists = false) + } + } + + test("create/drop database - location with pre-created directory") { + withTempDir { tmpDir => + createDatabaseWithLocation(tmpDir, dirExists = true) + } + } + + private def appendTrailingSlash(path: String): String = { + if (!path.endsWith(File.separator)) path + File.separator else path + } + + private def dropDatabase(cascade: Boolean, tableExists: Boolean): Unit = { + withTempPath { tmpDir => + val path = tmpDir.toString + withSQLConf(SQLConf.WAREHOUSE_PATH.key -> path) { + val dbName = "db1" + val fs = new Path(path).getFileSystem(spark.sessionState.newHadoopConf()) + val dbPath = new Path(path) + // the database directory does not exist + assert(!fs.exists(dbPath)) + + sql(s"CREATE DATABASE $dbName") + val catalog = spark.sessionState.catalog + val expectedDBLocation = "file:" + appendTrailingSlash(dbPath.toString) + s"$dbName.db" + val db1 = catalog.getDatabaseMetadata(dbName) + assert(db1 == CatalogDatabase( + dbName, + "", + expectedDBLocation, + Map.empty)) + // the database directory was created + assert(fs.exists(dbPath) && fs.isDirectory(dbPath)) + sql(s"USE $dbName") + + val tabName = "tab1" + assert(!tableDirectoryExists(TableIdentifier(tabName), Option(expectedDBLocation))) + sql(s"CREATE TABLE $tabName as SELECT 1") + assert(tableDirectoryExists(TableIdentifier(tabName), Option(expectedDBLocation))) + + if (!tableExists) { + sql(s"DROP TABLE $tabName") + assert(!tableDirectoryExists(TableIdentifier(tabName), Option(expectedDBLocation))) + } + + sql(s"USE default") + val sqlDropDatabase = s"DROP DATABASE $dbName ${if (cascade) "CASCADE" else "RESTRICT"}" + if (tableExists && !cascade) { + val message = intercept[AnalysisException] { + sql(sqlDropDatabase) + }.getMessage + assert(message.contains(s"Database $dbName is not empty. One or more tables exist.")) + // the database directory was not removed + assert(fs.exists(new Path(expectedDBLocation))) + } else { + sql(sqlDropDatabase) + // the database directory was removed and the inclusive table directories are also removed + assert(!fs.exists(new Path(expectedDBLocation))) + } + } + } + } + + test("drop database containing tables - CASCADE") { + dropDatabase(cascade = true, tableExists = true) + } + + test("drop an empty database - CASCADE") { + dropDatabase(cascade = true, tableExists = false) + } + + test("drop database containing tables - RESTRICT") { + dropDatabase(cascade = false, tableExists = true) + } + + test("drop an empty database - RESTRICT") { + dropDatabase(cascade = false, tableExists = false) + } + + test("drop default database") { + Seq("true", "false").foreach { caseSensitive => + withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive) { + var message = intercept[AnalysisException] { + sql("DROP DATABASE default") + }.getMessage + assert(message.contains("Can not drop default database")) + + // SQLConf.CASE_SENSITIVE does not affect the result + // because the Hive metastore is not case sensitive. 
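The `dropDatabase` helper above exercises the two drop modes; outside the test harness the behaviour reduces to the following sequence (a sketch with an illustrative database name, assuming an active `SparkSession` bound to `spark`):

```scala
spark.sql("CREATE DATABASE demo_db")
spark.sql("CREATE TABLE demo_db.t AS SELECT 1 AS id")

// Fails with an AnalysisException: "Database demo_db is not empty. One or more tables exist."
// spark.sql("DROP DATABASE demo_db RESTRICT")

// Succeeds and removes <warehouse>/demo_db.db together with the table directories under it.
spark.sql("DROP DATABASE demo_db CASCADE")
```

The one database that can never be dropped is `default`, and because the Hive metastore compares identifiers case-insensitively, the check that follows repeats the attempt with a differently cased name.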
+ message = intercept[AnalysisException] { + sql("DROP DATABASE DeFault") + }.getMessage + assert(message.contains("Can not drop default database")) + } + } + } + + test("Create Cataloged Table As Select - Drop Table After Runtime Exception") { + withTable("tab") { + intercept[RuntimeException] { + sql( + """ + |CREATE TABLE tab + |STORED AS TEXTFILE + |SELECT 1 AS a, (SELECT a FROM (SELECT 1 AS a UNION ALL SELECT 2 AS a) t) AS b + """.stripMargin) + } + // After hitting runtime exception, we should drop the created table. + assert(!spark.sessionState.catalog.tableExists(TableIdentifier("tab"))) + } + } + + + test("CREATE TABLE LIKE a temporary view") { + val sourceViewName = "tab1" + val targetTabName = "tab2" + withTempView(sourceViewName) { + withTable(targetTabName) { + spark.range(10).select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd) + .createTempView(sourceViewName) + sql(s"CREATE TABLE $targetTabName LIKE $sourceViewName") + + val sourceTable = spark.sessionState.catalog.getTempViewOrPermanentTableMetadata( + TableIdentifier(sourceViewName)) + val targetTable = spark.sessionState.catalog.getTableMetadata( + TableIdentifier(targetTabName, Some("default"))) + + checkCreateTableLike(sourceTable, targetTable) + } + } + } + + test("CREATE TABLE LIKE a data source table") { + val sourceTabName = "tab1" + val targetTabName = "tab2" + withTable(sourceTabName, targetTabName) { + spark.range(10).select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd) + .write.format("json").saveAsTable(sourceTabName) + sql(s"CREATE TABLE $targetTabName LIKE $sourceTabName") + + val sourceTable = + spark.sessionState.catalog.getTableMetadata(TableIdentifier(sourceTabName, Some("default"))) + val targetTable = + spark.sessionState.catalog.getTableMetadata(TableIdentifier(targetTabName, Some("default"))) + // The table type of the source table should be a Hive-managed data source table + assert(DDLUtils.isDatasourceTable(sourceTable)) + assert(sourceTable.tableType == CatalogTableType.MANAGED) + + checkCreateTableLike(sourceTable, targetTable) + } + } + + test("CREATE TABLE LIKE an external data source table") { + val sourceTabName = "tab1" + val targetTabName = "tab2" + withTable(sourceTabName, targetTabName) { + withTempPath { dir => + val path = dir.getCanonicalPath + spark.range(10).select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd) + .write.format("parquet").save(path) + sql(s"CREATE TABLE $sourceTabName USING parquet OPTIONS (PATH '$path')") + sql(s"CREATE TABLE $targetTabName LIKE $sourceTabName") + + // The source table should be an external data source table + val sourceTable = spark.sessionState.catalog.getTableMetadata( + TableIdentifier(sourceTabName, Some("default"))) + val targetTable = spark.sessionState.catalog.getTableMetadata( + TableIdentifier(targetTabName, Some("default"))) + // The table type of the source table should be an external data source table + assert(DDLUtils.isDatasourceTable(sourceTable)) + assert(sourceTable.tableType == CatalogTableType.EXTERNAL) + + checkCreateTableLike(sourceTable, targetTable) + } + } + } + + test("CREATE TABLE LIKE a managed Hive serde table") { + val catalog = spark.sessionState.catalog + val sourceTabName = "tab1" + val targetTabName = "tab2" + withTable(sourceTabName, targetTabName) { + sql(s"CREATE TABLE $sourceTabName TBLPROPERTIES('prop1'='value1') AS SELECT 1 key, 'a'") + sql(s"CREATE TABLE $targetTabName LIKE $sourceTabName") + + val sourceTable = catalog.getTableMetadata(TableIdentifier(sourceTabName, Some("default"))) + assert(sourceTable.tableType 
== CatalogTableType.MANAGED) + assert(sourceTable.properties.get("prop1").nonEmpty) + val targetTable = catalog.getTableMetadata(TableIdentifier(targetTabName, Some("default"))) + + checkCreateTableLike(sourceTable, targetTable) + } + } + + test("CREATE TABLE LIKE an external Hive serde table") { + val catalog = spark.sessionState.catalog + withTempDir { tmpDir => + val basePath = tmpDir.getCanonicalPath + val sourceTabName = "tab1" + val targetTabName = "tab2" + withTable(sourceTabName, targetTabName) { + assert(tmpDir.listFiles.isEmpty) + sql( + s""" + |CREATE EXTERNAL TABLE $sourceTabName (key INT comment 'test', value STRING) + |COMMENT 'Apache Spark' + |PARTITIONED BY (ds STRING, hr STRING) + |LOCATION '$basePath' + """.stripMargin) + for (ds <- Seq("2008-04-08", "2008-04-09"); hr <- Seq("11", "12")) { + sql( + s""" + |INSERT OVERWRITE TABLE $sourceTabName + |partition (ds='$ds',hr='$hr') + |SELECT 1, 'a' + """.stripMargin) + } + sql(s"CREATE TABLE $targetTabName LIKE $sourceTabName") + + val sourceTable = catalog.getTableMetadata(TableIdentifier(sourceTabName, Some("default"))) + assert(sourceTable.tableType == CatalogTableType.EXTERNAL) + assert(sourceTable.comment == Option("Apache Spark")) + val targetTable = catalog.getTableMetadata(TableIdentifier(targetTabName, Some("default"))) + + checkCreateTableLike(sourceTable, targetTable) + } + } + } + + test("CREATE TABLE LIKE a view") { + val sourceTabName = "tab1" + val sourceViewName = "view" + val targetTabName = "tab2" + withTable(sourceTabName, targetTabName) { + withView(sourceViewName) { + spark.range(10).select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd) + .write.format("json").saveAsTable(sourceTabName) + sql(s"CREATE VIEW $sourceViewName AS SELECT * FROM $sourceTabName") + sql(s"CREATE TABLE $targetTabName LIKE $sourceViewName") + + val sourceView = spark.sessionState.catalog.getTableMetadata( + TableIdentifier(sourceViewName, Some("default"))) + // The original source should be a VIEW with an empty path + assert(sourceView.tableType == CatalogTableType.VIEW) + assert(sourceView.viewText.nonEmpty && sourceView.viewOriginalText.nonEmpty) + val targetTable = spark.sessionState.catalog.getTableMetadata( + TableIdentifier(targetTabName, Some("default"))) + + checkCreateTableLike(sourceView, targetTable) + } + } + } + + private def getTablePath(table: CatalogTable): Option[String] = { + if (DDLUtils.isDatasourceTable(table)) { + new CaseInsensitiveMap(table.storage.serdeProperties).get("path") + } else { + table.storage.locationUri + } + } + + private def checkCreateTableLike(sourceTable: CatalogTable, targetTable: CatalogTable): Unit = { + // The created table should be a MANAGED table with empty view text and original text. 
+ assert(targetTable.tableType == CatalogTableType.MANAGED, + "the created table must be a Hive managed table") + assert(targetTable.viewText.isEmpty && targetTable.viewOriginalText.isEmpty, + "the view text and original text in the created table must be empty") + assert(targetTable.comment.isEmpty, + "the comment in the created table must be empty") + assert(targetTable.unsupportedFeatures.isEmpty, + "the unsupportedFeatures in the create table must be empty") + + val metastoreGeneratedProperties = Seq( + "CreateTime", + "transient_lastDdlTime", + "grantTime", + "lastUpdateTime", + "last_modified_by", + "last_modified_time", + "Owner:", + "COLUMN_STATS_ACCURATE", + "numFiles", + "numRows", + "rawDataSize", + "totalSize", + "totalNumberFiles", + "maxFileSize", + "minFileSize" + ) + assert(targetTable.properties.filterKeys { key => + !metastoreGeneratedProperties.contains(key) && !key.startsWith(DATASOURCE_PREFIX) + }.isEmpty, + "the table properties of source tables should not be copied in the created table") + + if (DDLUtils.isDatasourceTable(sourceTable) || + sourceTable.tableType == CatalogTableType.VIEW) { + assert(DDLUtils.isDatasourceTable(targetTable), + "the target table should be a data source table") + } else { + assert(!DDLUtils.isDatasourceTable(targetTable), + "the target table should be a Hive serde table") + } + + if (sourceTable.tableType == CatalogTableType.VIEW) { + // Source table is a temporary/permanent view, which does not have a provider. The created + // target table uses the default data source format + assert(targetTable.properties(CreateDataSourceTableUtils.DATASOURCE_PROVIDER) == + spark.sessionState.conf.defaultDataSourceName) + } else if (DDLUtils.isDatasourceTable(sourceTable)) { + assert(targetTable.properties(CreateDataSourceTableUtils.DATASOURCE_PROVIDER) == + sourceTable.properties(CreateDataSourceTableUtils.DATASOURCE_PROVIDER)) + } + + val sourceTablePath = getTablePath(sourceTable) + val targetTablePath = getTablePath(targetTable) + assert(targetTablePath.nonEmpty, "target table path should not be empty") + assert(sourceTablePath != targetTablePath, + "source table/view path should be different from target table path") + + // The source table contents should not been seen in the target table. 
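Taken together, these assertions spell out the contract of `CREATE TABLE ... LIKE`: the target is a managed table with the same schema, but view text, comments, user-set table properties, and the data itself are not carried over. A minimal usage sketch with illustrative table names, assuming an active session `spark`:

```scala
spark.sql("CREATE TABLE src_tab (key INT, value STRING) USING parquet")
spark.sql("INSERT INTO src_tab SELECT 1, 'a'")

spark.sql("CREATE TABLE copy_tab LIKE src_tab")
spark.sql("DESC copy_tab").show()                // same columns and types as src_tab
assert(spark.table("copy_tab").count() == 0)     // the schema is copied, the rows are not
```

The count and `DESC` comparisons that follow assert the same thing for every source type covered above: the source stays non-empty, the target starts empty, and the two schemas match.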
+ assert(spark.table(sourceTable.identifier).count() != 0, "the source table should be nonempty") + assert(spark.table(targetTable.identifier).count() == 0, "the target table should be empty") + + // Their schema should be identical + checkAnswer( + sql(s"DESC ${sourceTable.identifier}").select("col_name", "data_type"), + sql(s"DESC ${targetTable.identifier}").select("col_name", "data_type")) + + withSQLConf("hive.exec.dynamic.partition.mode" -> "nonstrict") { + // Check whether the new table can be inserted using the data from the original table + sql(s"INSERT INTO TABLE ${targetTable.identifier} SELECT * FROM ${sourceTable.identifier}") + } + + // After insertion, the data should be identical + checkAnswer( + sql(s"SELECT * FROM ${sourceTable.identifier}"), + sql(s"SELECT * FROM ${targetTable.identifier}")) + } + + test("Analyze data source tables(LogicalRelation)") { + withTable("t1") { + withTempPath { dir => + val path = dir.getCanonicalPath + spark.range(1).write.format("parquet").save(path) + sql(s"CREATE TABLE t1 USING parquet OPTIONS (PATH '$path')") + val e = intercept[AnalysisException] { + sql("ANALYZE TABLE t1 COMPUTE STATISTICS") + }.getMessage + assert(e.contains("ANALYZE TABLE is only supported for Hive tables, " + + "but 't1' is a LogicalRelation")) + } + } + } + + test("desc table for data source table") { + withTable("tab1") { + val tabName = "tab1" + spark.range(1).write.format("json").saveAsTable(tabName) + + assert(sql(s"DESC $tabName").collect().length == 1) + + assert( + sql(s"DESC FORMATTED $tabName").collect() + .exists(_.getString(0) == "# Storage Information")) + + assert( + sql(s"DESC EXTENDED $tabName").collect() + .exists(_.getString(0) == "# Detailed Table Information")) + } + } + + test("desc table for data source table - no user-defined schema") { + Seq("parquet", "json", "orc").foreach { fileFormat => + withTable("t1") { + withTempPath { dir => + val path = dir.getCanonicalPath + spark.range(1).write.format(fileFormat).save(path) + sql(s"CREATE TABLE t1 USING $fileFormat OPTIONS (PATH '$path')") + + val desc = sql("DESC FORMATTED t1").collect().toSeq + + assert(desc.contains(Row("id", "bigint", null))) + } + } + } + } + + test("desc table for data source table - partitioned bucketed table") { + withTable("t1") { + spark + .range(1).select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd).write + .bucketBy(2, "b").sortBy("c").partitionBy("d") + .saveAsTable("t1") + + val formattedDesc = sql("DESC FORMATTED t1").collect() + + assert(formattedDesc.containsSlice( + Seq( + Row("a", "bigint", null), + Row("b", "bigint", null), + Row("c", "bigint", null), + Row("d", "bigint", null), + Row("# Partition Information", "", ""), + Row("# col_name", "data_type", "comment"), + Row("d", "bigint", null), + Row("", "", ""), + Row("# Detailed Table Information", "", ""), + Row("Database:", "default", "") + ) + )) + + assert(formattedDesc.containsSlice( + Seq( + Row("Table Type:", "MANAGED", "") + ) + )) + + assert(formattedDesc.containsSlice( + Seq( + Row("Num Buckets:", "2", ""), + Row("Bucket Columns:", "[b]", ""), + Row("Sort Columns:", "[c]", "") + ) + )) + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index 542de724ccbf7..ec3328c0ae2d6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -18,6 +18,7 
@@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils @@ -52,7 +53,7 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleto "== Analyzed Logical Plan ==", "== Optimized Logical Plan ==", "== Physical Plan ==", - "CreateTableAsSelect", + "CreateHiveTableAsSelect", "InsertIntoHiveTable", "Limit", "src") @@ -70,35 +71,35 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleto "== Analyzed Logical Plan ==", "== Optimized Logical Plan ==", "== Physical Plan ==", - "CreateTableAsSelect", + "CreateHiveTableAsSelect", "InsertIntoHiveTable", "Limit", "src") } - test("SPARK-6212: The EXPLAIN output of CTAS only shows the analyzed plan") { - withTempTable("jt") { + test("SPARK-17230: The EXPLAIN output of CTAS only shows the analyzed plan") { + withTempView("jt") { val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}""")) - hiveContext.read.json(rdd).registerTempTable("jt") + spark.read.json(rdd).createOrReplaceTempView("jt") val outputs = sql( s""" |EXPLAIN EXTENDED |CREATE TABLE t1 |AS |SELECT * FROM jt - """.stripMargin).collect().map(_.mkString).mkString + """.stripMargin).collect().map(_.mkString).mkString val shouldContain = "== Parsed Logical Plan ==" :: "== Analyzed Logical Plan ==" :: "Subquery" :: "== Optimized Logical Plan ==" :: "== Physical Plan ==" :: - "CreateTableAsSelect" :: "InsertIntoHiveTable" :: "jt" :: Nil + "CreateHiveTableAsSelect" :: "InsertIntoHiveTable" :: "jt" :: Nil for (key <- shouldContain) { assert(outputs.contains(key), s"$key doesn't exist in result") } val physicalIndex = outputs.indexOf("== Physical Plan ==") - assert(!outputs.substring(physicalIndex).contains("Subquery"), - "Physical Plan should not contain Subquery since it's eliminated by optimizer") + assert(outputs.substring(physicalIndex).contains("SubqueryAlias"), + "Physical Plan should contain SubqueryAlias since the query should not be optimized") } } @@ -115,19 +116,8 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleto "== Physical Plan ==" ) - checkKeywordsExist(sql("EXPLAIN EXTENDED CODEGEN SELECT 1"), - "WholeStageCodegen", - "Generated code:", - "/* 001 */ public Object generate(Object[] references) {", - "/* 002 */ return new GeneratedIterator(references);", - "/* 003 */ }" - ) - - checkKeywordsNotExist(sql("EXPLAIN EXTENDED CODEGEN SELECT 1"), - "== Parsed Logical Plan ==", - "== Analyzed Logical Plan ==", - "== Optimized Logical Plan ==", - "== Physical Plan ==" - ) + intercept[ParseException] { + sql("EXPLAIN EXTENDED CODEGEN SELECT 1") + } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala index b252c6ee2faae..0e89e990e564e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala @@ -24,13 +24,13 @@ import org.apache.spark.sql.hive.test.TestHiveSingleton * A set of tests that validates commands can also be queried by like a table */ class HiveOperatorQueryableSuite extends QueryTest with TestHiveSingleton { - import hiveContext._ + import spark._ test("SPARK-5324 query result of describe 
command") { hiveContext.loadTestTable("src") - // register a describe command to be a temp table - sql("desc src").registerTempTable("mydesc") + // Creates a temporary view with the output of a describe command + sql("desc src").createOrReplaceTempView("mydesc") checkAnswer( sql("desc mydesc"), Seq( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala index d8d3448adde0b..89e6edb6b1577 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala @@ -24,11 +24,11 @@ import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.hive.test.TestHiveSingleton class HivePlanTest extends QueryTest with TestHiveSingleton { - import hiveContext.sql - import hiveContext.implicits._ + import spark.sql + import spark.implicits._ test("udf constant folding") { - Seq.empty[Tuple1[Int]].toDF("a").registerTempTable("t") + Seq.empty[Tuple1[Int]].toDF("a").createOrReplaceTempView("t") val optimized = sql("SELECT cos(null) AS c FROM t").queryExecution.optimizedPlan val correctAnswer = sql("SELECT cast(null as double) AS c FROM t").queryExecution.optimizedPlan diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 3bf0e84267419..183a57f22cb72 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -26,15 +26,17 @@ import scala.util.Try import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.scalatest.BeforeAndAfter -import org.apache.spark.{SparkException, SparkFiles} -import org.apache.spark.sql.{AnalysisException, DataFrame, Row} +import org.apache.spark.SparkFiles +import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SparkSession} import org.apache.spark.sql.catalyst.expressions.Cast import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoinExec import org.apache.spark.sql.hive._ -import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext} +import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SQLTestUtils case class TestData(a: Int, b: String) @@ -42,12 +44,16 @@ case class TestData(a: Int, b: String) * A set of test cases expressed in Hive QL that are not covered by the tests * included in the hive distribution. 
*/ -class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { +class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAndAfter { private val originalTimeZone = TimeZone.getDefault private val originalLocale = Locale.getDefault import org.apache.spark.sql.hive.test.TestHive.implicits._ + private val originalCrossJoinEnabled = TestHive.conf.crossJoinEnabled + + def spark: SparkSession = sparkSession + override def beforeAll() { super.beforeAll() TestHive.setCacheTables(true) @@ -55,6 +61,8 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) // Add Locale setting Locale.setDefault(Locale.US) + // Ensures that cross joins are enabled so that we can test them + TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, true) } override def afterAll() { @@ -63,6 +71,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { TimeZone.setDefault(originalTimeZone) Locale.setDefault(originalLocale) sql("DROP TEMPORARY FUNCTION IF EXISTS udtf_count2") + TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, originalCrossJoinEnabled) } finally { super.afterAll() } @@ -210,15 +219,6 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { assert(new Timestamp(1000) == r1.getTimestamp(0)) } - createQueryTest("constant array", - """ - |SELECT sort_array( - | sort_array( - | array("hadoop distributed file system", - | "enterprise databases", "hadoop map-reduce"))) - |FROM src LIMIT 1; - """.stripMargin) - createQueryTest("null case", "SELECT case when(true) then 1 else null end FROM src LIMIT 1") @@ -685,12 +685,12 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { createQueryTest("case sensitivity when query Hive table", "SELECT srcalias.KEY, SRCALIAS.value FROM sRc SrCAlias WHERE SrCAlias.kEy < 15") - test("case sensitivity: registered table") { + test("case sensitivity: created temporary view") { val testData = TestHive.sparkContext.parallelize( TestData(1, "str1") :: TestData(2, "str2") :: Nil) - testData.toDF().registerTempTable("REGisteredTABle") + testData.toDF().createOrReplaceTempView("REGisteredTABle") assertResult(Array(Row(2, "str2"))) { sql("SELECT tablealias.A, TABLEALIAS.b FROM reGisteredTABle TableAlias " + @@ -715,7 +715,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { test("SPARK-2180: HAVING support in GROUP BY clauses (positive)") { val fixture = List(("foo", 2), ("bar", 1), ("foo", 4), ("bar", 3)) .zipWithIndex.map {case ((value, attr), key) => HavingRow(key, value, attr)} - TestHive.sparkContext.parallelize(fixture).toDF().registerTempTable("having_test") + TestHive.sparkContext.parallelize(fixture).toDF().createOrReplaceTempView("having_test") val results = sql("SELECT value, max(attr) AS attr FROM having_test GROUP BY value HAVING attr > 3") .collect() @@ -819,17 +819,17 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { .collect() } - // Describe a registered temporary table. + // Describe a temporary view. 
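This is the API migration that runs through the whole patch: `registerTempTable` gives way to `createOrReplaceTempView` (and `withTempTable` to `withTempView`). In user code the change is mechanical, for example (a sketch assuming an active session `spark`):

```scala
val df = spark.range(3).toDF("id")

// Before (deprecated):
// df.registerTempTable("people")

// After:
df.createOrReplaceTempView("people")
spark.sql("SELECT * FROM people").show()
```

The describe assertions right after also pick up a behaviour change: the comment column of `DESCRIBE` output is now `null` instead of an empty string.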
val testData = TestHive.sparkContext.parallelize( TestData(1, "str1") :: TestData(1, "str2") :: Nil) - testData.toDF().registerTempTable("test_describe_commands2") + testData.toDF().createOrReplaceTempView("test_describe_commands2") assertResult( Array( - Row("a", "int", ""), - Row("b", "string", "")) + Row("a", "int", null), + Row("b", "string", null)) ) { sql("DESCRIBE test_describe_commands2") .select('col_name, 'data_type, 'comment) @@ -870,6 +870,13 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { sql(s"""LOAD DATA LOCAL INPATH "$testData" INTO TABLE t1""") sql("select * from src join t1 on src.key = t1.a") sql("DROP TABLE t1") + assert(sql("list jars"). + filter(_.getString(0).contains("hive-hcatalog-core-0.13.1.jar")).count() > 0) + assert(sql("list jar"). + filter(_.getString(0).contains("hive-hcatalog-core-0.13.1.jar")).count() > 0) + val testJar2 = TestHive.getHiveFile("TestUDTF.jar").getCanonicalPath + sql(s"ADD JAR $testJar2") + assert(sql(s"list jar $testJar").count() == 1) } test("CREATE TEMPORARY FUNCTION") { @@ -893,6 +900,11 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { } assert(checkAddFileRDD.first()) + assert(sql("list files"). + filter(_.getString(0).contains("data/files/v1.txt")).count() > 0) + assert(sql("list file"). + filter(_.getString(0).contains("data/files/v1.txt")).count() > 0) + assert(sql(s"list file $testFile").count() == 1) } createQueryTest("dynamic_partition", @@ -978,7 +990,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { sql("SET hive.exec.dynamic.partition.mode=strict") // Should throw when using strict dynamic partition mode without any static partition - intercept[SparkException] { + intercept[AnalysisException] { sql( """INSERT INTO TABLE dp_test PARTITION(dp) |SELECT key, value, key % 5 FROM src @@ -988,7 +1000,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { sql("SET hive.exec.dynamic.partition.mode=nonstrict") // Should throw when a static partition appears after a dynamic partition - intercept[SparkException] { + intercept[AnalysisException] { sql( """INSERT INTO TABLE dp_test PARTITION(dp, sp = 1) |SELECT key, value, key % 5 FROM src @@ -996,9 +1008,9 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { } } - test("SPARK-3414 regression: should store analyzed logical plan when registering a temp table") { - sparkContext.makeRDD(Seq.empty[LogEntry]).toDF().registerTempTable("rawLogs") - sparkContext.makeRDD(Seq.empty[LogFile]).toDF().registerTempTable("logFiles") + test("SPARK-3414 regression: should store analyzed logical plan when creating a temporary view") { + sparkContext.makeRDD(Seq.empty[LogEntry]).toDF().createOrReplaceTempView("rawLogs") + sparkContext.makeRDD(Seq.empty[LogFile]).toDF().createOrReplaceTempView("logFiles") sql( """ @@ -1009,20 +1021,20 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { FROM logFiles ) files ON rawLogs.filename = files.name - """).registerTempTable("boom") + """).createOrReplaceTempView("boom") // This should be successfully analyzed sql("SELECT * FROM boom").queryExecution.analyzed } - test("SPARK-3810: PreInsertionCasts static partitioning support") { + test("SPARK-3810: PreprocessTableInsertion static partitioning support") { val analyzedPlan = { loadTestTable("srcpart") sql("DROP TABLE IF EXISTS withparts") sql("CREATE TABLE withparts LIKE srcpart") sql("INSERT INTO TABLE withparts PARTITION(ds='1', hr='2') SELECT key, value FROM src") 
.queryExecution.analyzed - } + } assertResult(1, "Duplicated project detected\n" + analyzedPlan) { analyzedPlan.collect { @@ -1031,7 +1043,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { } } - test("SPARK-3810: PreInsertionCasts dynamic partitioning support") { + test("SPARK-3810: PreprocessTableInsertion dynamic partitioning support") { val analyzedPlan = { loadTestTable("srcpart") sql("DROP TABLE IF EXISTS withparts") @@ -1039,11 +1051,11 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { sql("SET hive.exec.dynamic.partition.mode=nonstrict") sql("CREATE TABLE IF NOT EXISTS withparts LIKE srcpart") - sql("INSERT INTO TABLE withparts PARTITION(ds, hr) SELECT key, value FROM src") + sql("INSERT INTO TABLE withparts PARTITION(ds, hr) SELECT key, value, '1', '2' FROM src") .queryExecution.analyzed } - assertResult(1, "Duplicated project detected\n" + analyzedPlan) { + assertResult(2, "Duplicated project detected\n" + analyzedPlan) { analyzedPlan.collect { case _: Project => () }.size @@ -1166,7 +1178,6 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { } test("some show commands are not supported") { - assertUnsupportedFeature { sql("SHOW CREATE TABLE my_table") } assertUnsupportedFeature { sql("SHOW COMPACTIONS") } assertUnsupportedFeature { sql("SHOW TRANSACTIONS") } assertUnsupportedFeature { sql("SHOW INDEXES ON my_table") } @@ -1195,6 +1206,27 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { } assertUnsupportedFeature { sql("DROP TEMPORARY MACRO SIGMOID") } } + + test("dynamic partitioning is allowed when hive.exec.dynamic.partition.mode is nonstrict") { + val modeConfKey = "hive.exec.dynamic.partition.mode" + withTable("with_parts") { + sql("CREATE TABLE with_parts(key INT) PARTITIONED BY (p INT)") + + withSQLConf(modeConfKey -> "nonstrict") { + sql("INSERT OVERWRITE TABLE with_parts partition(p) select 1, 2") + assert(spark.table("with_parts").filter($"p" === 2).collect().head == Row(1, 2)) + } + + val originalValue = spark.sparkContext.hadoopConfiguration.get(modeConfKey, "nonstrict") + try { + spark.sparkContext.hadoopConfiguration.set(modeConfKey, "nonstrict") + sql("INSERT OVERWRITE TABLE with_parts partition(p) select 3, 4") + assert(spark.table("with_parts").filter($"p" === 4).collect().head == Row(3, 4)) + } finally { + spark.sparkContext.hadoopConfiguration.set(modeConfKey, originalValue) + } + } + } } // for SPARK-2180 test diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala index dd13b8392880a..b2f19d7753956 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala @@ -32,14 +32,14 @@ class HiveResolutionSuite extends HiveComparisonTest { test("SPARK-3698: case insensitive test for nested data") { read.json(sparkContext.makeRDD( - """{"a": [{"a": {"a": 1}}]}""" :: Nil)).registerTempTable("nested") + """{"a": [{"a": {"a": 1}}]}""" :: Nil)).createOrReplaceTempView("nested") // This should be successfully analyzed sql("SELECT a[0].A.A from nested").queryExecution.analyzed } test("SPARK-5278: check ambiguous reference to fields") { read.json(sparkContext.makeRDD( - """{"a": [{"b": 1, "B": 2}]}""" :: Nil)).registerTempTable("nested") + """{"a": [{"b": 1, "B": 2}]}""" :: Nil)).createOrReplaceTempView("nested") // there are 2 
filed matching field name "b", we should report Ambiguous reference error val exception = intercept[AnalysisException] { @@ -78,7 +78,7 @@ class HiveResolutionSuite extends HiveComparisonTest { test("case insensitivity with scala reflection") { // Test resolution with Scala Reflection sparkContext.parallelize(Data(1, 2, Nested(1, 2), Seq(Nested(1, 2))) :: Nil) - .toDF().registerTempTable("caseSensitivityTest") + .toDF().createOrReplaceTempView("caseSensitivityTest") val query = sql("SELECT a, b, A, B, n.a, n.b, n.A, n.B FROM caseSensitivityTest") assert(query.schema.fields.map(_.name) === Seq("a", "b", "A", "B", "a", "b", "A", "B"), @@ -89,14 +89,14 @@ class HiveResolutionSuite extends HiveComparisonTest { ignore("case insensitivity with scala reflection joins") { // Test resolution with Scala Reflection sparkContext.parallelize(Data(1, 2, Nested(1, 2), Seq(Nested(1, 2))) :: Nil) - .toDF().registerTempTable("caseSensitivityTest") + .toDF().createOrReplaceTempView("caseSensitivityTest") sql("SELECT * FROM casesensitivitytest a JOIN casesensitivitytest b ON a.a = b.a").collect() } test("nested repeated resolution") { sparkContext.parallelize(Data(1, 2, Nested(1, 2), Seq(Nested(1, 2))) :: Nil) - .toDF().registerTempTable("nestedRepeatedTest") + .toDF().createOrReplaceTempView("nestedRepeatedTest") assert(sql("SELECT nestedArray[0].a FROM nestedRepeatedTest").collect().head(0) === 1) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala index b0c0dcbe5c25c..5b464764f0a99 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala @@ -18,13 +18,14 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.Row -import org.apache.spark.sql.functions._ -import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.{TestHive, TestHiveSingleton} import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.util.Utils -class HiveTableScanSuite extends HiveComparisonTest { +class HiveTableScanSuite extends HiveComparisonTest with SQLTestUtils with TestHiveSingleton { createQueryTest("partition_based_table_scan_with_different_serde", """ @@ -65,7 +66,7 @@ class HiveTableScanSuite extends HiveComparisonTest { TestHive.sql("DROP TABLE IF EXISTS timestamp_query_null") TestHive.sql( """ - CREATE EXTERNAL TABLE timestamp_query_null (time TIMESTAMP,id INT) + CREATE TABLE timestamp_query_null (time TIMESTAMP,id INT) ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' LINES TERMINATED BY '\n' @@ -84,9 +85,62 @@ class HiveTableScanSuite extends HiveComparisonTest { sql("""insert into table spark_4959 select "hi" from src limit 1""") table("spark_4959").select( 'col1.as("CaseSensitiveColName"), - 'col1.as("CaseSensitiveColName2")).registerTempTable("spark_4959_2") + 'col1.as("CaseSensitiveColName2")).createOrReplaceTempView("spark_4959_2") assert(sql("select CaseSensitiveColName from spark_4959_2").head() === Row("hi")) assert(sql("select casesensitivecolname from spark_4959_2").head() === Row("hi")) } + + private def checkNumScannedPartitions(stmt: String, expectedNumParts: Int): Unit = { + val plan = sql(stmt).queryExecution.sparkPlan + val numPartitions = 
plan.collectFirst { + case p: HiveTableScanExec => + p.relation.getHiveQlPartitions(p.partitionPruningPred).length + }.getOrElse(0) + assert(numPartitions == expectedNumParts) + } + + test("Verify SQLConf HIVE_METASTORE_PARTITION_PRUNING") { + val view = "src" + withTempView(view) { + spark.range(1, 5).createOrReplaceTempView(view) + val table = "table_with_partition" + withTable(table) { + sql( + s""" + |CREATE TABLE $table(id string) + |PARTITIONED BY (p1 string,p2 string,p3 string,p4 string,p5 string) + """.stripMargin) + sql( + s""" + |FROM $view v + |INSERT INTO TABLE $table + |PARTITION (p1='a',p2='b',p3='c',p4='d',p5='e') + |SELECT v.id + |INSERT INTO TABLE $table + |PARTITION (p1='a',p2='c',p3='c',p4='d',p5='e') + |SELECT v.id + """.stripMargin) + + Seq("true", "false").foreach { hivePruning => + withSQLConf(SQLConf.HIVE_METASTORE_PARTITION_PRUNING.key -> hivePruning) { + // If the pruning predicate is used, getHiveQlPartitions should only return the + // qualified partition; Otherwise, it return all the partitions. + val expectedNumPartitions = if (hivePruning == "true") 1 else 2 + checkNumScannedPartitions( + stmt = s"SELECT id, p2 FROM $table WHERE p2 <= 'b'", expectedNumPartitions) + } + } + + Seq("true", "false").foreach { hivePruning => + withSQLConf(SQLConf.HIVE_METASTORE_PARTITION_PRUNING.key -> hivePruning) { + // If the pruning predicate does not exist, getHiveQlPartitions should always + // return all the partitions. + checkNumScannedPartitions( + stmt = s"SELECT id, p2 FROM $table WHERE id <= 3", expectedNumParts = 2) + } + } + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index d07ac56586744..f690035c845f7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -47,8 +47,8 @@ case class ListStringCaseClass(l: Seq[String]) */ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { - import hiveContext.udf - import hiveContext.implicits._ + import spark.udf + import spark.implicits._ test("spark sql udf test that returns a struct") { udf.register("getStruct", (_: Int) => Fields(1, 2, 3, 4, 5)) @@ -72,7 +72,7 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { test("hive struct udf") { sql( """ - |CREATE EXTERNAL TABLE hiveUDFTestTable ( + |CREATE TABLE hiveUDFTestTable ( | pair STRUCT |) |PARTITIONED BY (partition STRING) @@ -142,6 +142,13 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { sql("SELECT array(max(key), max(key)) FROM src").collect().toSeq) } + test("SPARK-16228 Percentile needs explicit cast to double") { + sql("select percentile(value, cast(0.5 as double)) from values 1,2,3 T(value)") + sql("select percentile_approx(value, cast(0.5 as double)) from values 1.0,2.0,3.0 T(value)") + sql("select percentile(value, 0.5) from values 1,2,3 T(value)") + sql("select percentile_approx(value, 0.5) from values 1.0,2.0,3.0 T(value)") + } + test("Generic UDAF aggregates") { checkAnswer(sql("SELECT ceiling(percentile_approx(key, 0.99999D)) FROM src LIMIT 1"), sql("SELECT max(key) FROM src LIMIT 1").collect().toSeq) @@ -151,9 +158,9 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { } test("UDFIntegerToString") { - val testData = hiveContext.sparkContext.parallelize( + val testData = 
spark.sparkContext.parallelize( IntegerCaseClass(1) :: IntegerCaseClass(2) :: Nil).toDF() - testData.registerTempTable("integerTable") + testData.createOrReplaceTempView("integerTable") val udfName = classOf[UDFIntegerToString].getName sql(s"CREATE TEMPORARY FUNCTION testUDFIntegerToString AS '$udfName'") @@ -166,8 +173,8 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { } test("UDFToListString") { - val testData = hiveContext.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() - testData.registerTempTable("inputTable") + val testData = spark.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() + testData.createOrReplaceTempView("inputTable") sql(s"CREATE TEMPORARY FUNCTION testUDFToListString AS '${classOf[UDFToListString].getName}'") val errMsg = intercept[AnalysisException] { @@ -181,8 +188,8 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { } test("UDFToListInt") { - val testData = hiveContext.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() - testData.registerTempTable("inputTable") + val testData = spark.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() + testData.createOrReplaceTempView("inputTable") sql(s"CREATE TEMPORARY FUNCTION testUDFToListInt AS '${classOf[UDFToListInt].getName}'") val errMsg = intercept[AnalysisException] { @@ -196,8 +203,8 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { } test("UDFToStringIntMap") { - val testData = hiveContext.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() - testData.registerTempTable("inputTable") + val testData = spark.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() + testData.createOrReplaceTempView("inputTable") sql(s"CREATE TEMPORARY FUNCTION testUDFToStringIntMap " + s"AS '${classOf[UDFToStringIntMap].getName}'") @@ -212,8 +219,8 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { } test("UDFToIntIntMap") { - val testData = hiveContext.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() - testData.registerTempTable("inputTable") + val testData = spark.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF() + testData.createOrReplaceTempView("inputTable") sql(s"CREATE TEMPORARY FUNCTION testUDFToIntIntMap " + s"AS '${classOf[UDFToIntIntMap].getName}'") @@ -228,11 +235,11 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { } test("UDFListListInt") { - val testData = hiveContext.sparkContext.parallelize( + val testData = spark.sparkContext.parallelize( ListListIntCaseClass(Nil) :: ListListIntCaseClass(Seq((1, 2, 3))) :: ListListIntCaseClass(Seq((4, 5, 6), (7, 8, 9))) :: Nil).toDF() - testData.registerTempTable("listListIntTable") + testData.createOrReplaceTempView("listListIntTable") sql(s"CREATE TEMPORARY FUNCTION testUDFListListInt AS '${classOf[UDFListListInt].getName}'") checkAnswer( @@ -244,10 +251,10 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { } test("UDFListString") { - val testData = hiveContext.sparkContext.parallelize( + val testData = spark.sparkContext.parallelize( ListStringCaseClass(Seq("a", "b", "c")) :: ListStringCaseClass(Seq("d", "e")) :: Nil).toDF() - testData.registerTempTable("listStringTable") + testData.createOrReplaceTempView("listStringTable") sql(s"CREATE TEMPORARY FUNCTION testUDFListString AS '${classOf[UDFListString].getName}'") checkAnswer( @@ -259,9 +266,9 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { 
} test("UDFStringString") { - val testData = hiveContext.sparkContext.parallelize( + val testData = spark.sparkContext.parallelize( StringCaseClass("world") :: StringCaseClass("goodbye") :: Nil).toDF() - testData.registerTempTable("stringTable") + testData.createOrReplaceTempView("stringTable") sql(s"CREATE TEMPORARY FUNCTION testStringStringUDF AS '${classOf[UDFStringString].getName}'") checkAnswer( @@ -278,12 +285,12 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { } test("UDFTwoListList") { - val testData = hiveContext.sparkContext.parallelize( + val testData = spark.sparkContext.parallelize( ListListIntCaseClass(Nil) :: ListListIntCaseClass(Seq((1, 2, 3))) :: ListListIntCaseClass(Seq((4, 5, 6), (7, 8, 9))) :: Nil).toDF() - testData.registerTempTable("TwoListTable") + testData.createOrReplaceTempView("TwoListTable") sql(s"CREATE TEMPORARY FUNCTION testUDFTwoListList AS '${classOf[UDFTwoListList].getName}'") checkAnswer( @@ -295,7 +302,7 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { } test("Hive UDFs with insufficient number of input arguments should trigger an analysis error") { - Seq((1, 2)).toDF("a", "b").registerTempTable("testUDF") + Seq((1, 2)).toDF("a", "b").createOrReplaceTempView("testUDF") { // HiveSimpleUDF @@ -347,12 +354,12 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { sql("DROP TEMPORARY FUNCTION IF EXISTS testUDTFExplode") } - sqlContext.dropTempTable("testUDF") + spark.catalog.dropTempView("testUDF") } test("Hive UDF in group by") { - withTempTable("tab1") { - Seq(Tuple1(1451400761)).toDF("test_date").registerTempTable("tab1") + withTempView("tab1") { + Seq(Tuple1(1451400761)).toDF("test_date").createOrReplaceTempView("tab1") sql(s"CREATE TEMPORARY FUNCTION testUDFToDate AS '${classOf[GenericUDFToDate].getName}'") val count = sql("select testUDFToDate(cast(test_date as timestamp))" + " from tab1 group by testUDFToDate(cast(test_date as timestamp))").count() @@ -435,10 +442,7 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { } // Non-External parquet pointing to /tmp/... 
- - sql("CREATE TABLE parquet_tmp(c1 int, c2 int) " + - " STORED AS parquet " + - " AS SELECT 1, 2") + sql("CREATE TABLE parquet_tmp STORED AS parquet AS SELECT 1, 2") val answer4 = sql("SELECT input_file_name() as file FROM parquet_tmp").head().getString(0) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index f20ab36efbf03..b421fae6dc57b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -19,12 +19,18 @@ package org.apache.spark.sql.hive.execution import java.sql.{Date, Timestamp} +import scala.sys.process.Process +import scala.util.Try + import org.apache.hadoop.fs.Path import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, FunctionRegistry} +import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, FunctionRegistry, + NoSuchPartitionException} +import org.apache.spark.sql.catalyst.catalog.CatalogTableType import org.apache.spark.sql.catalyst.parser.ParseException +import org.apache.spark.sql.execution.command.CreateDataSourceTableUtils import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.{HiveUtils, MetastoreRelation} @@ -59,7 +65,21 @@ case class Order( */ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { import hiveContext._ - import hiveContext.implicits._ + import spark.implicits._ + + test("script") { + if (testCommandAvailable("bash") && testCommandAvailable("echo | sed")) { + val df = Seq(("x1", "y1", "z1"), ("x2", "y2", "z2")).toDF("c1", "c2", "c3") + df.createOrReplaceTempView("script_table") + val query1 = sql( + """ + |SELECT col1 FROM (from(SELECT c1, c2, c3 FROM script_table) tempt_table + |REDUCE c1, c2, c3 USING 'bash src/test/resources/test_script.sh' AS + |(col1 STRING, col2 STRING)) script_test_table""".stripMargin) + checkAnswer(query1, Row("x1_y1") :: Row("x2_y2") :: Nil) + } + // else skip this test + } test("UDTF") { withUserDefinedFunction("udtf_count2" -> true) { @@ -103,14 +123,14 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { test("SPARK-6835: udtf in lateral view") { val df = Seq((1, 1)).toDF("c1", "c2") - df.registerTempTable("table1") + df.createOrReplaceTempView("table1") val query = sql("SELECT c1, v FROM table1 LATERAL VIEW stack(3, 1, c1 + 1, c1 + 2) d AS v") checkAnswer(query, Row(1, 1) :: Row(1, 2) :: Row(1, 3) :: Nil) } test("SPARK-13651: generator outputs shouldn't be resolved from its child's output") { - withTempTable("src") { - Seq(("id1", "value1")).toDF("key", "value").registerTempTable("src") + withTempView("src") { + Seq(("id1", "value1")).toDF("key", "value").createOrReplaceTempView("src") val query = sql("SELECT genoutput.* FROM src " + "LATERAL VIEW explode(map('key1', 100, 'key2', 200)) genoutput AS key, value") @@ -136,8 +156,8 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { Order(1, "Atlas", "MTB", 434, "2015-01-07", "John D", "Pacifica", "CA", 20151), Order(11, "Swift", "YFlikr", 137, "2015-01-23", "John D", "Hayward", "CA", 20151)) - orders.toDF.registerTempTable("orders1") - orderUpdates.toDF.registerTempTable("orderupdates1") + 
orders.toDF.createOrReplaceTempView("orders1") + orderUpdates.toDF.createOrReplaceTempView("orderupdates1") sql( """CREATE TABLE orders( @@ -186,30 +206,50 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { test("show functions") { val allBuiltinFunctions = FunctionRegistry.builtin.listFunction().toSet[String].toList.sorted - // The TestContext is shared by all the test cases, some functions may be registered before - // this, so we check that all the builtin functions are returned. val allFunctions = sql("SHOW functions").collect().map(r => r(0)) allBuiltinFunctions.foreach { f => assert(allFunctions.contains(f)) } - checkAnswer(sql("SHOW functions abs"), Row("abs")) - checkAnswer(sql("SHOW functions 'abs'"), Row("abs")) - checkAnswer(sql("SHOW functions abc.abs"), Row("abs")) - checkAnswer(sql("SHOW functions `abc`.`abs`"), Row("abs")) - checkAnswer(sql("SHOW functions `abc`.`abs`"), Row("abs")) - checkAnswer(sql("SHOW functions `~`"), Row("~")) - checkAnswer(sql("SHOW functions `a function doens't exist`"), Nil) - checkAnswer(sql("SHOW functions `weekofyea*`"), Row("weekofyear")) - // this probably will failed if we add more function with `sha` prefixing. - checkAnswer(sql("SHOW functions `sha*`"), Row("sha") :: Row("sha1") :: Row("sha2") :: Nil) - // Test '|' for alternation. - checkAnswer( - sql("SHOW functions 'sha*|weekofyea*'"), - Row("sha") :: Row("sha1") :: Row("sha2") :: Row("weekofyear") :: Nil) + withTempDatabase { db => + def createFunction(names: Seq[String]): Unit = { + names.foreach { name => + sql( + s""" + |CREATE TEMPORARY FUNCTION $name + |AS '${classOf[PairUDF].getName}' + """.stripMargin) + } + } + def dropFunction(names: Seq[String]): Unit = { + names.foreach { name => + sql(s"DROP TEMPORARY FUNCTION $name") + } + } + createFunction(Seq("temp_abs", "temp_weekofyear", "temp_sha", "temp_sha1", "temp_sha2")) + + checkAnswer(sql("SHOW functions temp_abs"), Row("temp_abs")) + checkAnswer(sql("SHOW functions 'temp_abs'"), Row("temp_abs")) + checkAnswer(sql(s"SHOW functions $db.temp_abs"), Row("temp_abs")) + checkAnswer(sql(s"SHOW functions `$db`.`temp_abs`"), Row("temp_abs")) + checkAnswer(sql(s"SHOW functions `$db`.`temp_abs`"), Row("temp_abs")) + checkAnswer(sql("SHOW functions `a function doens't exist`"), Nil) + checkAnswer(sql("SHOW functions `temp_weekofyea*`"), Row("temp_weekofyear")) + + // this probably will failed if we add more function with `sha` prefixing. + checkAnswer( + sql("SHOW functions `temp_sha*`"), + List(Row("temp_sha"), Row("temp_sha1"), Row("temp_sha2"))) + + // Test '|' for alternation. 
+ checkAnswer( + sql("SHOW functions 'temp_sha*|temp_weekofyea*'"), + List(Row("temp_sha"), Row("temp_sha1"), Row("temp_sha2"), Row("temp_weekofyear"))) + + dropFunction(Seq("temp_abs", "temp_weekofyear", "temp_sha", "temp_sha1", "temp_sha2")) + } } - test("describe functions") { - // The Spark SQL built-in functions + test("describe functions - built-in functions") { checkKeywordsExist(sql("describe function extended upper"), "Function: upper", "Class: org.apache.spark.sql.catalyst.expressions.Upper", @@ -253,9 +293,134 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { "When a = b, returns c; when a = d, return e; else return f") } + test("describe functions - user defined functions") { + withUserDefinedFunction("udtf_count" -> false) { + sql( + s""" + |CREATE FUNCTION udtf_count + |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2' + |USING JAR '${hiveContext.getHiveFile("TestUDTF.jar").getCanonicalPath()}' + """.stripMargin) + + checkKeywordsExist(sql("describe function udtf_count"), + "Function: default.udtf_count", + "Class: org.apache.spark.sql.hive.execution.GenericUDTFCount2", + "Usage: N/A") + + checkAnswer( + sql("SELECT udtf_count(a) FROM (SELECT 1 AS a FROM src LIMIT 3) t"), + Row(3) :: Row(3) :: Nil) + + checkKeywordsExist(sql("describe function udtf_count"), + "Function: default.udtf_count", + "Class: org.apache.spark.sql.hive.execution.GenericUDTFCount2", + "Usage: N/A") + } + } + + test("describe functions - temporary user defined functions") { + withUserDefinedFunction("udtf_count_temp" -> true) { + sql( + s""" + |CREATE TEMPORARY FUNCTION udtf_count_temp + |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2' + |USING JAR '${hiveContext.getHiveFile("TestUDTF.jar").getCanonicalPath()}' + """.stripMargin) + + checkKeywordsExist(sql("describe function udtf_count_temp"), + "Function: udtf_count_temp", + "Class: org.apache.spark.sql.hive.execution.GenericUDTFCount2", + "Usage: N/A") + + checkAnswer( + sql("SELECT udtf_count_temp(a) FROM (SELECT 1 AS a FROM src LIMIT 3) t"), + Row(3) :: Row(3) :: Nil) + + checkKeywordsExist(sql("describe function udtf_count_temp"), + "Function: udtf_count_temp", + "Class: org.apache.spark.sql.hive.execution.GenericUDTFCount2", + "Usage: N/A") + } + } + + test("describe partition") { + withTable("partitioned_table") { + sql("CREATE TABLE partitioned_table (a STRING, b INT) PARTITIONED BY (c STRING, d STRING)") + sql("ALTER TABLE partitioned_table ADD PARTITION (c='Us', d=1)") + + checkKeywordsExist(sql("DESC partitioned_table PARTITION (c='Us', d=1)"), + "# Partition Information", + "# col_name") + + checkKeywordsExist(sql("DESC EXTENDED partitioned_table PARTITION (c='Us', d=1)"), + "# Partition Information", + "# col_name", + "Detailed Partition Information CatalogPartition(", + "Partition Values: [Us, 1]", + "Storage(Location:", + "Partition Parameters") + + checkKeywordsExist(sql("DESC FORMATTED partitioned_table PARTITION (c='Us', d=1)"), + "# Partition Information", + "# col_name", + "# Detailed Partition Information", + "Partition Value:", + "Database:", + "Table:", + "Location:", + "Partition Parameters:", + "# Storage Information") + } + } + + test("describe partition - error handling") { + withTable("partitioned_table", "datasource_table") { + sql("CREATE TABLE partitioned_table (a STRING, b INT) PARTITIONED BY (c STRING, d STRING)") + sql("ALTER TABLE partitioned_table ADD PARTITION (c='Us', d=1)") + + val m = intercept[NoSuchPartitionException] { + sql("DESC partitioned_table PARTITION 
(c='Us', d=2)") + }.getMessage() + assert(m.contains("Partition not found in table")) + + val m2 = intercept[AnalysisException] { + sql("DESC partitioned_table PARTITION (c='Us')") + }.getMessage() + assert(m2.contains("Partition spec is invalid")) + + val m3 = intercept[ParseException] { + sql("DESC partitioned_table PARTITION (c='Us', d)") + }.getMessage() + assert(m3.contains("PARTITION specification is incomplete: `d`")) + + spark + .range(1).select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd).write + .partitionBy("d") + .saveAsTable("datasource_table") + val m4 = intercept[AnalysisException] { + sql("DESC datasource_table PARTITION (d=2)") + }.getMessage() + assert(m4.contains("DESC PARTITION is not allowed on a datasource table")) + + val m5 = intercept[AnalysisException] { + spark.range(10).select('id as 'a, 'id as 'b).createTempView("view1") + sql("DESC view1 PARTITION (c='Us', d=1)") + }.getMessage() + assert(m5.contains("DESC PARTITION is not allowed on a temporary view")) + + withView("permanent_view") { + val m = intercept[AnalysisException] { + sql("CREATE VIEW permanent_view AS SELECT * FROM partitioned_table") + sql("DESC permanent_view PARTITION (c='Us', d=1)") + }.getMessage() + assert(m.contains("DESC PARTITION is not allowed on a view")) + } + } + } + test("SPARK-5371: union with null and sum") { val df = Seq((1, 1)).toDF("c1", "c2") - df.registerTempTable("table1") + df.createOrReplaceTempView("table1") val query = sql( """ @@ -279,7 +444,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { test("CTAS with WITH clause") { val df = Seq((1, 1)).toDF("c1", "c2") - df.registerTempTable("table1") + df.createOrReplaceTempView("table1") sql( """ @@ -296,7 +461,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } test("explode nested Field") { - Seq(NestedArray1(NestedArray2(Seq(1, 2, 3)))).toDF.registerTempTable("nestedArray") + Seq(NestedArray1(NestedArray2(Seq(1, 2, 3)))).toDF.createOrReplaceTempView("nestedArray") checkAnswer( sql("SELECT ints FROM nestedArray LATERAL VIEW explode(a.b) a AS ints"), Row(1) :: Row(2) :: Row(3) :: Nil) @@ -326,78 +491,138 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { ) } - test("CTAS without serde") { - def checkRelation(tableName: String, isDataSourceParquet: Boolean): Unit = { - val relation = EliminateSubqueryAliases( - sessionState.catalog.lookupRelation(TableIdentifier(tableName))) - relation match { - case LogicalRelation(r: HadoopFsRelation, _, _) => - if (!isDataSourceParquet) { - fail( - s"${classOf[MetastoreRelation].getCanonicalName} is expected, but found " + + def checkRelation( + tableName: String, + isDataSourceParquet: Boolean, + format: String, + userSpecifiedLocation: Option[String] = None): Unit = { + val relation = EliminateSubqueryAliases( + sessionState.catalog.lookupRelation(TableIdentifier(tableName))) + val catalogTable = + sessionState.catalog.getTableMetadata(TableIdentifier(tableName)) + relation match { + case LogicalRelation(r: HadoopFsRelation, _, _) => + if (!isDataSourceParquet) { + fail( + s"${classOf[MetastoreRelation].getCanonicalName} is expected, but found " + s"${HadoopFsRelation.getClass.getCanonicalName}.") - } + } + userSpecifiedLocation match { + case Some(location) => + assert(r.options("path") === location) + case None => // OK. 
+ } + assert( + catalogTable.properties(CreateDataSourceTableUtils.DATASOURCE_PROVIDER) === format) - case r: MetastoreRelation => - if (isDataSourceParquet) { - fail( - s"${HadoopFsRelation.getClass.getCanonicalName} is expected, but found " + + case r: MetastoreRelation => + if (isDataSourceParquet) { + fail( + s"${HadoopFsRelation.getClass.getCanonicalName} is expected, but found " + s"${classOf[MetastoreRelation].getCanonicalName}.") - } - } + } + userSpecifiedLocation match { + case Some(location) => + assert(r.catalogTable.storage.locationUri.get === location) + case None => // OK. + } + // Also make sure that the format is the desired format. + assert(catalogTable.storage.inputFormat.get.toLowerCase.contains(format)) } - val originalConf = sessionState.convertCTAS + // When a user-specified location is defined, the table type needs to be EXTERNAL. + val actualTableType = catalogTable.tableType + userSpecifiedLocation match { + case Some(location) => + assert(actualTableType === CatalogTableType.EXTERNAL) + case None => + assert(actualTableType === CatalogTableType.MANAGED) + } + } - setConf(HiveUtils.CONVERT_CTAS, true) + test("CTAS without serde without location") { + val originalConf = sessionState.conf.convertCTAS + setConf(SQLConf.CONVERT_CTAS, true) + + val defaultDataSource = sessionState.conf.defaultDataSourceName try { sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") sql("CREATE TABLE IF NOT EXISTS ctas1 AS SELECT key k, value FROM src ORDER BY k, value") - var message = intercept[AnalysisException] { + val message = intercept[AnalysisException] { sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") }.getMessage assert(message.contains("already exists")) - checkRelation("ctas1", true) + checkRelation("ctas1", true, defaultDataSource) sql("DROP TABLE ctas1") // Specifying database name for query can be converted to data source write path // is not allowed right now. 
- message = intercept[AnalysisException] { - sql("CREATE TABLE default.ctas1 AS SELECT key k, value FROM src ORDER BY k, value") - }.getMessage - assert( - message.contains("Cannot specify database name in a CTAS statement"), - "When spark.sql.hive.convertCTAS is true, we should not allow " + - "database name specified.") + sql("CREATE TABLE default.ctas1 AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", true, defaultDataSource) + sql("DROP TABLE ctas1") sql("CREATE TABLE ctas1 stored as textfile" + " AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", true) + checkRelation("ctas1", false, "text") sql("DROP TABLE ctas1") sql("CREATE TABLE ctas1 stored as sequencefile" + " AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", true) + checkRelation("ctas1", false, "sequence") sql("DROP TABLE ctas1") sql("CREATE TABLE ctas1 stored as rcfile AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false) + checkRelation("ctas1", false, "rcfile") sql("DROP TABLE ctas1") sql("CREATE TABLE ctas1 stored as orc AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false) + checkRelation("ctas1", false, "orc") sql("DROP TABLE ctas1") sql("CREATE TABLE ctas1 stored as parquet AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false) + checkRelation("ctas1", false, "parquet") sql("DROP TABLE ctas1") } finally { - setConf(HiveUtils.CONVERT_CTAS, originalConf) + setConf(SQLConf.CONVERT_CTAS, originalConf) sql("DROP TABLE IF EXISTS ctas1") } } + test("CTAS without serde with location") { + withSQLConf(SQLConf.CONVERT_CTAS.key -> "true") { + withTempDir { dir => + val defaultDataSource = sessionState.conf.defaultDataSourceName + + val tempLocation = dir.getCanonicalPath + sql(s"CREATE TABLE ctas1 LOCATION 'file:$tempLocation/c1'" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", true, defaultDataSource, Some(s"file:$tempLocation/c1")) + sql("DROP TABLE ctas1") + + sql(s"CREATE TABLE ctas1 LOCATION 'file:$tempLocation/c2'" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", true, defaultDataSource, Some(s"file:$tempLocation/c2")) + sql("DROP TABLE ctas1") + + sql(s"CREATE TABLE ctas1 stored as textfile LOCATION 'file:$tempLocation/c3'" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", false, "text", Some(s"file:$tempLocation/c3")) + sql("DROP TABLE ctas1") + + sql(s"CREATE TABLE ctas1 stored as sequenceFile LOCATION 'file:$tempLocation/c4'" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", false, "sequence", Some(s"file:$tempLocation/c4")) + sql("DROP TABLE ctas1") + + sql(s"CREATE TABLE ctas1 stored as rcfile LOCATION 'file:$tempLocation/c5'" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", false, "rcfile", Some(s"file:$tempLocation/c5")) + sql("DROP TABLE ctas1") + } + } + } + test("CTAS with serde") { sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value").collect() sql( @@ -493,21 +718,35 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } test("specifying the column list for CTAS") { - Seq((1, "111111"), (2, "222222")).toDF("key", "value").registerTempTable("mytable1") - - sql("create table gen__tmp(a int, b string) as select key, value from mytable1") - checkAnswer( - sql("SELECT a, b from gen__tmp"), - sql("select key, value from 
mytable1").collect()) - sql("DROP TABLE gen__tmp") + withTempView("mytable1") { + Seq((1, "111111"), (2, "222222")).toDF("key", "value").createOrReplaceTempView("mytable1") + withTable("gen__tmp") { + sql("create table gen__tmp as select key as a, value as b from mytable1") + checkAnswer( + sql("SELECT a, b from gen__tmp"), + sql("select key, value from mytable1").collect()) + } - sql("create table gen__tmp(a double, b double) as select key, value from mytable1") - checkAnswer( - sql("SELECT a, b from gen__tmp"), - sql("select cast(key as double), cast(value as double) from mytable1").collect()) - sql("DROP TABLE gen__tmp") + withTable("gen__tmp") { + val e = intercept[AnalysisException] { + sql("create table gen__tmp(a int, b string) as select key, value from mytable1") + }.getMessage + assert(e.contains("Schema may not be specified in a Create Table As Select (CTAS)")) + } - sql("drop table mytable1") + withTable("gen__tmp") { + val e = intercept[AnalysisException] { + sql( + """ + |CREATE TABLE gen__tmp + |PARTITIONED BY (key string) + |AS SELECT key, value FROM mytable1 + """.stripMargin) + }.getMessage + assert(e.contains("A Create Table As Select (CTAS) statement is not allowed to " + + "create a partitioned table using Hive's file formats")) + } + } } test("command substitution") { @@ -548,7 +787,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { test("double nested data") { sparkContext.parallelize(Nested1(Nested2(Nested3(1))) :: Nil) - .toDF().registerTempTable("nested") + .toDF().createOrReplaceTempView("nested") checkAnswer( sql("SELECT f1.f2.f3 FROM nested"), Row(1)) @@ -632,7 +871,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { test("SPARK-4963 DataFrame sample on mutable row return wrong result") { sql("SELECT * FROM src WHERE key % 2 = 0") .sample(withReplacement = false, fraction = 0.3) - .registerTempTable("sampled") + .createOrReplaceTempView("sampled") (1 to 10).foreach { i => checkAnswer( sql("SELECT * FROM sampled WHERE key % 2 = 1"), @@ -640,7 +879,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } - test("SPARK-4699 HiveContext should be case insensitive by default") { + test("SPARK-4699 SparkSession with Hive Support should be case insensitive by default") { checkAnswer( sql("SELECT KEY FROM Src ORDER BY value"), sql("SELECT key FROM src ORDER BY value").collect().toSeq) @@ -657,7 +896,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { val rowRdd = sparkContext.parallelize(row :: Nil) - hiveContext.createDataFrame(rowRdd, schema).registerTempTable("testTable") + spark.createDataFrame(rowRdd, schema).createOrReplaceTempView("testTable") sql( """CREATE TABLE nullValuesInInnerComplexTypes @@ -683,14 +922,14 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { test("SPARK-4296 Grouping field with Hive UDF as sub expression") { val rdd = sparkContext.makeRDD( """{"a": "str", "b":"1", "c":"1970-01-01 00:00:00"}""" :: Nil) - read.json(rdd).registerTempTable("data") + read.json(rdd).createOrReplaceTempView("data") checkAnswer( sql("SELECT concat(a, '-', b), year(c) FROM data GROUP BY concat(a, '-', b), year(c)"), Row("str-1", 1970)) dropTempTable("data") - read.json(rdd).registerTempTable("data") + read.json(rdd).createOrReplaceTempView("data") checkAnswer(sql("SELECT year(c) + 1 FROM data GROUP BY year(c) + 1"), Row(1971)) dropTempTable("data") @@ -698,14 +937,14 @@ class SQLQuerySuite extends 
QueryTest with SQLTestUtils with TestHiveSingleton { test("resolve udtf in projection #1") { val rdd = sparkContext.makeRDD((1 to 5).map(i => s"""{"a":[$i, ${i + 1}]}""")) - read.json(rdd).registerTempTable("data") + read.json(rdd).createOrReplaceTempView("data") val df = sql("SELECT explode(a) AS val FROM data") val col = df("val") } test("resolve udtf in projection #2") { val rdd = sparkContext.makeRDD((1 to 2).map(i => s"""{"a":[$i, ${i + 1}]}""")) - read.json(rdd).registerTempTable("data") + read.json(rdd).createOrReplaceTempView("data") checkAnswer(sql("SELECT explode(map(1, 1)) FROM data LIMIT 1"), Row(1, 1) :: Nil) checkAnswer(sql("SELECT explode(map(1, 1)) as (k1, k2) FROM data LIMIT 1"), Row(1, 1) :: Nil) intercept[AnalysisException] { @@ -720,7 +959,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { // TGF with non-TGF in project is allowed in Spark SQL, but not in Hive test("TGF with non-TGF in projection") { val rdd = sparkContext.makeRDD( """{"a": "1", "b":"1"}""" :: Nil) - read.json(rdd).registerTempTable("data") + read.json(rdd).createOrReplaceTempView("data") checkAnswer( sql("SELECT explode(map(a, b)) as (k1, k2), a, b FROM data"), Row("1", "1", "1", "1") :: Nil) @@ -734,9 +973,9 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { // PreInsertionCasts will actually start to work before ImplicitGenerate and then // generates an invalid query plan. val rdd = sparkContext.makeRDD((1 to 5).map(i => s"""{"a":[$i, ${i + 1}]}""")) - read.json(rdd).registerTempTable("data") - val originalConf = sessionState.convertCTAS - setConf(HiveUtils.CONVERT_CTAS, false) + read.json(rdd).createOrReplaceTempView("data") + val originalConf = sessionState.conf.convertCTAS + setConf(SQLConf.CONVERT_CTAS, false) try { sql("CREATE TABLE explodeTest (key bigInt)") @@ -755,7 +994,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { sql("DROP TABLE explodeTest") dropTempTable("data") } finally { - setConf(HiveUtils.CONVERT_CTAS, originalConf) + setConf(SQLConf.CONVERT_CTAS, originalConf) } } @@ -774,7 +1013,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { Seq.empty[(java.math.BigDecimal, java.math.BigDecimal)] .toDF("d1", "d2") .select($"d1".cast(DecimalType(10, 5)).as("d")) - .registerTempTable("dn") + .createOrReplaceTempView("dn") sql("select d from dn union all select d * 2 from dn") .queryExecution.analyzed @@ -782,27 +1021,27 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { test("Star Expansion - script transform") { val data = (1 to 100000).map { i => (i, i, i) } - data.toDF("d1", "d2", "d3").registerTempTable("script_trans") + data.toDF("d1", "d2", "d3").createOrReplaceTempView("script_trans") assert(100000 === sql("SELECT TRANSFORM (*) USING 'cat' FROM script_trans").count()) } test("test script transform for stdout") { val data = (1 to 100000).map { i => (i, i, i) } - data.toDF("d1", "d2", "d3").registerTempTable("script_trans") + data.toDF("d1", "d2", "d3").createOrReplaceTempView("script_trans") assert(100000 === sql("SELECT TRANSFORM (d1, d2, d3) USING 'cat' AS (a,b,c) FROM script_trans").count()) } test("test script transform for stderr") { val data = (1 to 100000).map { i => (i, i, i) } - data.toDF("d1", "d2", "d3").registerTempTable("script_trans") + data.toDF("d1", "d2", "d3").createOrReplaceTempView("script_trans") assert(0 === sql("SELECT TRANSFORM (d1, d2, d3) USING 'cat 1>&2' AS (a,b,c) FROM 
script_trans").count()) } test("test script transform data type") { val data = (1 to 5).map { i => (i, i) } - data.toDF("key", "value").registerTempTable("test") + data.toDF("key", "value").createOrReplaceTempView("test") checkAnswer( sql("""FROM |(FROM test SELECT TRANSFORM(key, value) USING 'cat' AS (`thing1` int, thing2 string)) t @@ -811,10 +1050,10 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } test("Sorting columns are not in Generate") { - withTempTable("data") { - sqlContext.range(1, 5) + withTempView("data") { + spark.range(1, 5) .select(array($"id", $"id" + 1).as("a"), $"id".as("b"), (lit(10) - $"id").as("c")) - .registerTempTable("data") + .createOrReplaceTempView("data") // case 1: missing sort columns are resolvable if join is true checkAnswer( @@ -838,7 +1077,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } test("test case key when") { - (1 to 5).map(i => (i, i.toString)).toDF("k", "v").registerTempTable("t") + (1 to 5).map(i => (i, i.toString)).toDF("k", "v").createOrReplaceTempView("t") checkAnswer( sql("SELECT CASE k WHEN 2 THEN 22 WHEN 4 THEN 44 ELSE 0 END, v FROM t"), Row(0, "1") :: Row(22, "2") :: Row(0, "3") :: Row(44, "4") :: Row(0, "5") :: Nil) @@ -847,7 +1086,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { test("SPARK-7269 Check analysis failed in case in-sensitive") { Seq(1, 2, 3).map { i => (i.toString, i.toString) - }.toDF("key", "value").registerTempTable("df_analysis") + }.toDF("key", "value").createOrReplaceTempView("df_analysis") sql("SELECT kEy from df_analysis group by key").collect() sql("SELECT kEy+3 from df_analysis group by key+3").collect() sql("SELECT kEy+3, a.kEy, A.kEy from df_analysis A group by key").collect() @@ -866,29 +1105,6 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { checkAnswer(sql("SELECT CAST('775983671874188101' as BIGINT)"), Row(775983671874188101L)) } - // `Math.exp(1.0)` has different result for different jdk version, so not use createQueryTest - test("udf_java_method") { - checkAnswer(sql( - """ - |SELECT java_method("java.lang.String", "valueOf", 1), - | java_method("java.lang.String", "isEmpty"), - | java_method("java.lang.Math", "max", 2, 3), - | java_method("java.lang.Math", "min", 2, 3), - | java_method("java.lang.Math", "round", 2.5D), - | java_method("java.lang.Math", "exp", 1.0D), - | java_method("java.lang.Math", "floor", 1.9D) - |FROM src tablesample (1 rows) - """.stripMargin), - Row( - "1", - "true", - java.lang.Math.max(2, 3).toString, - java.lang.Math.min(2, 3).toString, - java.lang.Math.round(2.5).toString, - java.lang.Math.exp(1.0).toString, - java.lang.Math.floor(1.9).toString)) - } - test("dynamic partition value test") { try { sql("set hive.exec.dynamic.partition.mode=nonstrict") @@ -981,7 +1197,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { test("SPARK-8588 HiveTypeCoercion.inConversion fires too early") { val df = createDataFrame(Seq((1, "2014-01-01"), (2, "2015-01-01"), (3, "2016-01-01"))) - df.toDF("id", "datef").registerTempTable("test_SPARK8588") + df.toDF("id", "datef").createOrReplaceTempView("test_SPARK8588") checkAnswer( sql( """ @@ -996,7 +1212,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { test("SPARK-9371: fix the support for special chars in column names for hive context") { read.json(sparkContext.makeRDD( """{"a": {"c.b": 1}, "b.$q": [{"a@!.q": 1}], "q.w": {"w.i&": 
[1]}}""" :: Nil)) - .registerTempTable("t") + .createOrReplaceTempView("t") checkAnswer(sql("SELECT a.`c.b`, `b.$q`[0].`a@!.q`, `q.w`.`w.i&`[0] FROM t"), Row(1, 1, 1)) } @@ -1033,7 +1249,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { // We don't support creating a temporary table while specifying a database val message = intercept[AnalysisException] { - sqlContext.sql( + spark.sql( s""" |CREATE TEMPORARY TABLE db.t |USING parquet @@ -1044,7 +1260,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { }.getMessage // If you use backticks to quote the name then it's OK. - sqlContext.sql( + spark.sql( s""" |CREATE TEMPORARY TABLE `db.t` |USING parquet @@ -1052,12 +1268,12 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { | path '$path' |) """.stripMargin) - checkAnswer(sqlContext.table("`db.t`"), df) + checkAnswer(spark.table("`db.t`"), df) } } test("SPARK-10593 same column names in lateral view") { - val df = sqlContext.sql( + val df = spark.sql( """ |select |insideLayer2.json as a2 @@ -1072,10 +1288,10 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { ignore("SPARK-10310: " + "script transformation using default input/output SerDe and record reader/writer") { - sqlContext + spark .range(5) .selectExpr("id AS a", "id AS b") - .registerTempTable("test") + .createOrReplaceTempView("test") checkAnswer( sql( @@ -1090,10 +1306,10 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } ignore("SPARK-10310: script transformation using LazySimpleSerDe") { - sqlContext + spark .range(5) .selectExpr("id AS a", "id AS b") - .registerTempTable("test") + .createOrReplaceTempView("test") val df = sql( """FROM test @@ -1111,9 +1327,9 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { test("SPARK-10741: Sort on Aggregate using parquet") { withTable("test10741") { - withTempTable("src") { - Seq("a" -> 5, "a" -> 9, "b" -> 6).toDF().registerTempTable("src") - sql("CREATE TABLE test10741(c1 STRING, c2 INT) STORED AS PARQUET AS SELECT * FROM src") + withTempView("src") { + Seq("a" -> 5, "a" -> 9, "b" -> 6).toDF("c1", "c2").createOrReplaceTempView("src") + sql("CREATE TABLE test10741 STORED AS PARQUET AS SELECT * FROM src") } checkAnswer(sql( @@ -1134,11 +1350,12 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } - test("run sql directly on files") { - val df = sqlContext.range(100).toDF() + test("run sql directly on files - parquet") { + val df = spark.range(100).toDF() withTempPath(f => { df.write.parquet(f.getCanonicalPath) - checkAnswer(sql(s"select id from parquet.`${f.getCanonicalPath}`"), + // data source type is case insensitive + checkAnswer(sql(s"select id from Parquet.`${f.getCanonicalPath}`"), df) checkAnswer(sql(s"select id from `org.apache.spark.sql.parquet`.`${f.getCanonicalPath}`"), df) @@ -1147,6 +1364,49 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { }) } + test("run sql directly on files - orc") { + val df = spark.range(100).toDF() + withTempPath(f => { + df.write.orc(f.getCanonicalPath) + // data source type is case insensitive + checkAnswer(sql(s"select id from ORC.`${f.getCanonicalPath}`"), + df) + checkAnswer(sql(s"select id from `org.apache.spark.sql.hive.orc`.`${f.getCanonicalPath}`"), + df) + checkAnswer(sql(s"select a.id from orc.`${f.getCanonicalPath}` as a"), + df) + }) + } + + test("run sql directly on files - 
csv") { + val df = spark.range(100).toDF() + withTempPath(f => { + df.write.csv(f.getCanonicalPath) + // data source type is case insensitive + checkAnswer(sql(s"select cast(_c0 as int) id from CSV.`${f.getCanonicalPath}`"), + df) + checkAnswer( + sql(s"select cast(_c0 as int) id from `com.databricks.spark.csv`.`${f.getCanonicalPath}`"), + df) + checkAnswer(sql(s"select cast(a._c0 as int) id from csv.`${f.getCanonicalPath}` as a"), + df) + }) + } + + test("run sql directly on files - json") { + val df = spark.range(100).toDF() + withTempPath(f => { + df.write.json(f.getCanonicalPath) + // data source type is case insensitive + checkAnswer(sql(s"select id from jsoN.`${f.getCanonicalPath}`"), + df) + checkAnswer(sql(s"select id from `org.apache.spark.sql.json`.`${f.getCanonicalPath}`"), + df) + checkAnswer(sql(s"select a.id from json.`${f.getCanonicalPath}` as a"), + df) + }) + } + test("SPARK-8976 Wrong Result for Rollup #1") { checkAnswer(sql( "SELECT count(*) AS cnt, key % 5, grouping_id() FROM src GROUP BY key%5 WITH ROLLUP"), @@ -1277,14 +1537,14 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { Seq("3" -> "30").toDF("i", "j") .write.mode(SaveMode.Append).partitionBy("i").saveAsTable("tbl11453") checkAnswer( - sqlContext.read.table("tbl11453").select("i", "j").orderBy("i"), + spark.read.table("tbl11453").select("i", "j").orderBy("i"), Row("1", "10") :: Row("2", "20") :: Row("3", "30") :: Nil) // make sure case sensitivity is correct. Seq("4" -> "40").toDF("i", "j") .write.mode(SaveMode.Append).partitionBy("I").saveAsTable("tbl11453") checkAnswer( - sqlContext.read.table("tbl11453").select("i", "j").orderBy("i"), + spark.read.table("tbl11453").select("i", "j").orderBy("i"), Row("1", "10") :: Row("2", "20") :: Row("3", "30") :: Row("4", "40") :: Nil) } } @@ -1321,10 +1581,10 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } test("multi-insert with lateral view") { - withTempTable("t1") { - sqlContext.range(10) + withTempView("t1") { + spark.range(10) .select(array($"id", $"id" + 1).as("arr"), $"id") - .registerTempTable("source") + .createOrReplaceTempView("source") withTable("dest1", "dest2") { sql("CREATE TABLE dest1 (i INT)") sql("CREATE TABLE dest2 (i INT)") @@ -1340,61 +1600,15 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { """.stripMargin) checkAnswer( - sqlContext.table("dest1"), + spark.table("dest1"), sql("SELECT id FROM source WHERE id > 3")) checkAnswer( - sqlContext.table("dest2"), + spark.table("dest2"), sql("SELECT col FROM source LATERAL VIEW EXPLODE(arr) exp AS col WHERE col > 3")) } } } - test( - "SPARK-14488 \"CREATE TEMPORARY TABLE ... USING ... AS SELECT ...\" " + - "shouldn't create persisted table" - ) { - withTempPath { dir => - withTempTable("t1", "t2") { - val path = dir.getCanonicalPath - val ds = sqlContext.range(10) - ds.registerTempTable("t1") - - sql( - s"""CREATE TEMPORARY TABLE t2 - |USING PARQUET - |OPTIONS (PATH '$path') - |AS SELECT * FROM t1 - """.stripMargin) - - checkAnswer( - sqlContext.tables().select('isTemporary).filter('tableName === "t2"), - Row(true) - ) - - checkAnswer(table("t2"), table("t1")) - } - } - } - - test( - "SPARK-14493 \"CREATE TEMPORARY TABLE ... USING ... 
AS SELECT ...\" " + - "shouldn always be used together with PATH data source option" - ) { - withTempTable("t") { - sqlContext.range(10).registerTempTable("t") - - val message = intercept[IllegalArgumentException] { - sql( - s"""CREATE TEMPORARY TABLE t1 - |USING PARQUET - |AS SELECT * FROM t - """.stripMargin) - }.getMessage - - assert(message == "'path' is not specified") - } - } - test("derived from Hive query file: drop_database_removes_partition_dirs.q") { // This test verifies that if a partition exists outside a table's current location when the // database is dropped the partition's location is dropped as well. @@ -1484,10 +1698,39 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { assert(fs.listStatus(new Path(path, "part=1")).nonEmpty) sql("drop table test_table") - assert( - !fs.exists(path), - "Once a managed table has been dropped, " + - "dirs of this table should also have been deleted.") + assert(fs.exists(path), "This is an external table, so the data should not have been dropped") + } + + test("select partitioned table") { + val table = "table_with_partition" + withTable(table) { + sql( + s""" + |CREATE TABLE $table(c1 string) + |PARTITIONED BY (p1 string,p2 string,p3 string,p4 string,p5 string) + """.stripMargin) + sql( + s""" + |INSERT OVERWRITE TABLE $table + |PARTITION (p1='a',p2='b',p3='c',p4='d',p5='e') + |SELECT 'blarr' + """.stripMargin) + + // project list is the same order of paritioning columns in table definition + checkAnswer( + sql(s"SELECT p1, p2, p3, p4, p5, c1 FROM $table"), + Row("a", "b", "c", "d", "e", "blarr") :: Nil) + + // project list does not have the same order of paritioning columns in table definition + checkAnswer( + sql(s"SELECT p2, p3, p4, p1, p5, c1 FROM $table"), + Row("b", "c", "d", "a", "e", "blarr") :: Nil) + + // project list contains partial partition columns in table definition + checkAnswer( + sql(s"SELECT p2, p1, p5, c1 FROM $table"), + Row("b", "a", "e", "blarr") :: Nil) + } } test("SPARK-14981: DESC not supported for sorting columns") { @@ -1502,7 +1745,59 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { ) } - assert(cause.getMessage.contains("Only ASC ordering is supported for sorting columns")) + assert(cause.getMessage.contains("Column ordering must be ASC, was 'DESC'")) + } + } + + test("insert into datasource table") { + withTable("tbl") { + sql("CREATE TABLE tbl(i INT, j STRING) USING parquet") + Seq(1 -> "a").toDF("i", "j").write.mode("overwrite").insertInto("tbl") + checkAnswer(sql("SELECT * FROM tbl"), Row(1, "a")) } } + + test("spark-15557 promote string test") { + withTable("tbl") { + sql("CREATE TABLE tbl(c1 string, c2 string)") + sql("insert into tbl values ('3', '2.3')") + checkAnswer( + sql("select (cast (99 as decimal(19,6)) + cast('3' as decimal)) * cast('2.3' as decimal)"), + Row(204.0) + ) + checkAnswer( + sql("select (cast(99 as decimal(19,6)) + '3') *'2.3' from tbl"), + Row(234.6) + ) + checkAnswer( + sql("select (cast(99 as decimal(19,6)) + c1) * c2 from tbl"), + Row(234.6) + ) + } + } + + test("SPARK-17354: Partitioning by dates/timestamps works with Parquet vectorized reader") { + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true") { + sql( + """CREATE TABLE order(id INT) + |PARTITIONED BY (pd DATE, pt TIMESTAMP) + |STORED AS PARQUET + """.stripMargin) + + sql("set hive.exec.dynamic.partition.mode=nonstrict") + sql( + """INSERT INTO TABLE order PARTITION(pd, pt) + |SELECT 1 AS id, CAST('1990-02-24' AS DATE) AS pd, 
CAST('1990-02-24' AS TIMESTAMP) AS pt + """.stripMargin) + val actual = sql("SELECT * FROM order") + val expected = sql( + "SELECT 1 AS id, CAST('1990-02-24' AS DATE) AS pd, CAST('1990-02-24' AS TIMESTAMP) AS pt") + checkAnswer(actual, expected) + sql("DROP TABLE order") + } + } + + def testCommandAvailable(command: String): Boolean = { + Try(Process(command) !!).isSuccess + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLViewSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLViewSuite.scala index 0d88b3b87f501..490fea27de2fc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLViewSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLViewSuite.scala @@ -18,6 +18,8 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SaveMode} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils @@ -26,22 +28,82 @@ import org.apache.spark.sql.test.SQLTestUtils * A suite for testing view related functionality. */ class SQLViewSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { - import hiveContext.implicits._ + import spark.implicits._ override def beforeAll(): Unit = { // Create a simple table with two columns: id and id1 - sqlContext.range(1, 10).selectExpr("id", "id id1").write.format("json").saveAsTable("jt") + spark.range(1, 10).selectExpr("id", "id id1").write.format("json").saveAsTable("jt") } override def afterAll(): Unit = { - sqlContext.sql(s"DROP TABLE IF EXISTS jt") + spark.sql(s"DROP TABLE IF EXISTS jt") } - test("nested views") { - withView("jtv1", "jtv2") { - sql("CREATE VIEW jtv1 AS SELECT * FROM jt WHERE id > 3").collect() - sql("CREATE VIEW jtv2 AS SELECT * FROM jtv1 WHERE id < 6").collect() + test("nested views (interleaved with temporary views)") { + withView("jtv1", "jtv2", "jtv3", "temp_jtv1", "temp_jtv2", "temp_jtv3") { + sql("CREATE VIEW jtv1 AS SELECT * FROM jt WHERE id > 3") + sql("CREATE VIEW jtv2 AS SELECT * FROM jtv1 WHERE id < 6") checkAnswer(sql("select count(*) FROM jtv2"), Row(2)) + + // Checks temporary views + sql("CREATE TEMPORARY VIEW temp_jtv1 AS SELECT * FROM jt WHERE id > 3") + sql("CREATE TEMPORARY VIEW temp_jtv2 AS SELECT * FROM temp_jtv1 WHERE id < 6") + checkAnswer(sql("select count(*) FROM temp_jtv2"), Row(2)) + + // Checks interleaved temporary view and normal view + sql("CREATE TEMPORARY VIEW temp_jtv3 AS SELECT * FROM jt WHERE id > 3") + sql("CREATE VIEW jtv3 AS SELECT * FROM temp_jtv3 WHERE id < 6") + checkAnswer(sql("select count(*) FROM jtv3"), Row(2)) + } + } + + test("Issue exceptions for ALTER VIEW on the temporary view") { + val viewName = "testView" + withTempView(viewName) { + spark.range(10).createTempView(viewName) + assertNoSuchTable(s"ALTER VIEW $viewName SET TBLPROPERTIES ('p' = 'an')") + assertNoSuchTable(s"ALTER VIEW $viewName UNSET TBLPROPERTIES ('p')") + } + } + + test("Issue exceptions for ALTER TABLE on the temporary view") { + val viewName = "testView" + withTempView(viewName) { + spark.range(10).createTempView(viewName) + assertNoSuchTable(s"ALTER TABLE $viewName SET SERDE 'whatever'") + assertNoSuchTable(s"ALTER TABLE $viewName PARTITION (a=1, b=2) SET SERDE 'whatever'") + assertNoSuchTable(s"ALTER TABLE $viewName SET SERDEPROPERTIES ('p' = 'an')") 
+ assertNoSuchTable(s"ALTER TABLE $viewName SET LOCATION '/path/to/your/lovely/heart'") + assertNoSuchTable(s"ALTER TABLE $viewName PARTITION (a='4') SET LOCATION '/path/to/home'") + assertNoSuchTable(s"ALTER TABLE $viewName ADD IF NOT EXISTS PARTITION (a='4', b='8')") + assertNoSuchTable(s"ALTER TABLE $viewName DROP PARTITION (a='4', b='8')") + assertNoSuchTable(s"ALTER TABLE $viewName PARTITION (a='4') RENAME TO PARTITION (a='5')") + assertNoSuchTable(s"ALTER TABLE $viewName RECOVER PARTITIONS") + } + } + + test("Issue exceptions for other table DDL on the temporary view") { + val viewName = "testView" + withTempView(viewName) { + spark.range(10).createTempView(viewName) + + val e = intercept[AnalysisException] { + sql(s"INSERT INTO TABLE $viewName SELECT 1") + }.getMessage + assert(e.contains("Inserting into an RDD-based table is not allowed")) + + val testData = hiveContext.getHiveFile("data/files/employee.dat").getCanonicalPath + assertNoSuchTable(s"""LOAD DATA LOCAL INPATH "$testData" INTO TABLE $viewName""") + assertNoSuchTable(s"TRUNCATE TABLE $viewName") + assertNoSuchTable(s"SHOW CREATE TABLE $viewName") + assertNoSuchTable(s"SHOW PARTITIONS $viewName") + assertNoSuchTable(s"ANALYZE TABLE $viewName COMPUTE STATISTICS") + } + } + + private def assertNoSuchTable(query: String): Unit = { + intercept[NoSuchTableException] { + sql(query) } } @@ -57,6 +119,33 @@ class SQLViewSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } + test("error handling: fail if the temp view name contains the database prefix") { + // Fully qualified table name like "database.table" is not allowed for temporary view + val e = intercept[AnalysisException] { + sql("CREATE OR REPLACE TEMPORARY VIEW default.myabcdview AS SELECT * FROM jt") + } + assert(e.message.contains("It is not allowed to add database prefix")) + } + + test("error handling: disallow IF NOT EXISTS for CREATE TEMPORARY VIEW") { + val e = intercept[AnalysisException] { + sql("CREATE TEMPORARY VIEW IF NOT EXISTS myabcdview AS SELECT * FROM jt") + } + assert(e.message.contains("It is not allowed to define a TEMPORARY view with IF NOT EXISTS")) + } + + test("error handling: fail if the temp view sql itself is invalid") { + // A table that does not exist for temporary view + intercept[AnalysisException] { + sql("CREATE OR REPLACE TEMPORARY VIEW myabcdview AS SELECT * FROM table_not_exist1345") + } + + // A column that does not exist, for temporary view + intercept[AnalysisException] { + sql("CREATE OR REPLACE TEMPORARY VIEW myabcdview AS SELECT random1234 FROM jt") + } + } + test("correctly parse CREATE VIEW statement") { withSQLConf(SQLConf.NATIVE_VIEW.key -> "true") { sql( @@ -69,22 +158,138 @@ class SQLViewSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } + test("correctly parse CREATE TEMPORARY VIEW statement") { + withView("testView") { + sql( + """CREATE TEMPORARY VIEW + |testView (c1 COMMENT 'blabla', c2 COMMENT 'blabla') + |TBLPROPERTIES ('a' = 'b') + |AS SELECT * FROM jt + |""".stripMargin) + checkAnswer(sql("SELECT c1, c2 FROM testView ORDER BY c1"), (1 to 9).map(i => Row(i, i))) + } + } + + test("should NOT allow CREATE TEMPORARY VIEW when TEMPORARY VIEW with same name exists") { + withView("testView") { + sql("CREATE TEMPORARY VIEW testView AS SELECT id FROM jt") + + val e = intercept[AnalysisException] { + sql("CREATE TEMPORARY VIEW testView AS SELECT id FROM jt") + } + + assert(e.message.contains("Temporary table") && e.message.contains("already exists")) + } + } + + test("should allow CREATE 
TEMPORARY VIEW when a permanent VIEW with same name exists") { + withView("testView", "default.testView") { + sql("CREATE VIEW testView AS SELECT id FROM jt") + sql("CREATE TEMPORARY VIEW testView AS SELECT id FROM jt") + } + } + + test("should allow CREATE permanent VIEW when a TEMPORARY VIEW with same name exists") { + withView("testView", "default.testView") { + sql("CREATE TEMPORARY VIEW testView AS SELECT id FROM jt") + sql("CREATE VIEW testView AS SELECT id FROM jt") + } + } + test("correctly handle CREATE VIEW IF NOT EXISTS") { withSQLConf(SQLConf.NATIVE_VIEW.key -> "true") { withTable("jt2") { - sql("CREATE VIEW testView AS SELECT id FROM jt") + withView("testView") { + sql("CREATE VIEW testView AS SELECT id FROM jt") - val df = (1 until 10).map(i => i -> i).toDF("i", "j") - df.write.format("json").saveAsTable("jt2") - sql("CREATE VIEW IF NOT EXISTS testView AS SELECT * FROM jt2") + val df = (1 until 10).map(i => i -> i).toDF("i", "j") + df.write.format("json").saveAsTable("jt2") + sql("CREATE VIEW IF NOT EXISTS testView AS SELECT * FROM jt2") + + // make sure our view doesn't change. + checkAnswer(sql("SELECT * FROM testView ORDER BY id"), (1 to 9).map(i => Row(i))) + } + } + } + } - // make sure our view doesn't change. + test(s"correctly handle CREATE OR REPLACE TEMPORARY VIEW") { + withTable("jt2") { + withView("testView") { + sql("CREATE OR REPLACE TEMPORARY VIEW testView AS SELECT id FROM jt") checkAnswer(sql("SELECT * FROM testView ORDER BY id"), (1 to 9).map(i => Row(i))) - sql("DROP VIEW testView") + + sql("CREATE OR REPLACE TEMPORARY VIEW testView AS SELECT id AS i, id AS j FROM jt") + // make sure the view has been changed. + checkAnswer(sql("SELECT * FROM testView ORDER BY i"), (1 to 9).map(i => Row(i, i))) + } + } + } + + test("should not allow ALTER VIEW AS when the view does not exist") { + assertNoSuchTable("ALTER VIEW testView AS SELECT 1, 2") + assertNoSuchTable("ALTER VIEW default.testView AS SELECT 1, 2") + } + + test("ALTER VIEW AS should try to alter temp view first if view name has no database part") { + withView("test_view") { + withTempView("test_view") { + sql("CREATE VIEW test_view AS SELECT 1 AS a, 2 AS b") + sql("CREATE TEMP VIEW test_view AS SELECT 1 AS a, 2 AS b") + + sql("ALTER VIEW test_view AS SELECT 3 AS i, 4 AS j") + + // The temporary view should be updated. + checkAnswer(spark.table("test_view"), Row(3, 4)) + + // The permanent view should stay same. + checkAnswer(spark.table("default.test_view"), Row(1, 2)) } } } + test("ALTER VIEW AS should alter permanent view if view name has database part") { + withView("test_view") { + withTempView("test_view") { + sql("CREATE VIEW test_view AS SELECT 1 AS a, 2 AS b") + sql("CREATE TEMP VIEW test_view AS SELECT 1 AS a, 2 AS b") + + sql("ALTER VIEW default.test_view AS SELECT 3 AS i, 4 AS j") + + // The temporary view should stay same. + checkAnswer(spark.table("test_view"), Row(1, 2)) + + // The permanent view should be updated. 
+ checkAnswer(spark.table("default.test_view"), Row(3, 4)) + } + } + } + + test("ALTER VIEW AS should keep the previous table properties, comment, create_time, etc.") { + withView("test_view") { + sql( + """ + |CREATE VIEW test_view + |COMMENT 'test' + |TBLPROPERTIES ('key' = 'a') + |AS SELECT 1 AS a, 2 AS b + """.stripMargin) + + val catalog = spark.sessionState.catalog + val viewMeta = catalog.getTableMetadata(TableIdentifier("test_view")) + assert(viewMeta.comment == Some("test")) + assert(viewMeta.properties("key") == "a") + + sql("ALTER VIEW test_view AS SELECT 3 AS i, 4 AS j") + val updatedViewMeta = catalog.getTableMetadata(TableIdentifier("test_view")) + assert(updatedViewMeta.comment == Some("test")) + assert(updatedViewMeta.properties("key") == "a") + assert(updatedViewMeta.createTime == viewMeta.createTime) + // The view should be updated. + checkAnswer(spark.table("test_view"), Row(3, 4)) + } + } + Seq(true, false).foreach { enabled => val prefix = (if (enabled) "With" else "Without") + " canonical native view: " test(s"$prefix correctly handle CREATE OR REPLACE VIEW") { @@ -105,7 +310,8 @@ class SQLViewSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { val e = intercept[AnalysisException] { sql("CREATE OR REPLACE VIEW IF NOT EXISTS testView AS SELECT id FROM jt") } - assert(e.message.contains("not allowed to define a view")) + assert(e.message.contains( + "CREATE VIEW with both IF NOT EXISTS and REPLACE is not allowed")) } } } @@ -187,11 +393,11 @@ class SQLViewSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { withSQLConf( SQLConf.NATIVE_VIEW.key -> "true", SQLConf.CANONICAL_NATIVE_VIEW.key -> "true") { withTable("add_col") { - sqlContext.range(10).write.saveAsTable("add_col") + spark.range(10).write.saveAsTable("add_col") withView("v") { sql("CREATE VIEW v AS SELECT * FROM add_col") - sqlContext.range(10).select('id, 'id as 'a).write.mode("overwrite").saveAsTable("add_col") - checkAnswer(sql("SELECT * FROM v"), sqlContext.range(10).toDF()) + spark.range(10).select('id, 'id as 'a).write.mode("overwrite").saveAsTable("add_col") + checkAnswer(sql("SELECT * FROM v"), spark.range(10).toDF()) } } } @@ -201,8 +407,8 @@ class SQLViewSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { // make sure the new flag can handle some complex cases like join and schema change. 
withSQLConf(SQLConf.NATIVE_VIEW.key -> "true") { withTable("jt1", "jt2") { - sqlContext.range(1, 10).toDF("id1").write.format("json").saveAsTable("jt1") - sqlContext.range(1, 10).toDF("id2").write.format("json").saveAsTable("jt2") + spark.range(1, 10).toDF("id1").write.format("json").saveAsTable("jt1") + spark.range(1, 10).toDF("id2").write.format("json").saveAsTable("jt2") sql("CREATE VIEW testView AS SELECT * FROM jt1 JOIN jt2 ON id1 == id2") checkAnswer(sql("SELECT * FROM testView ORDER BY id1"), (1 to 9).map(i => Row(i, i))) @@ -215,4 +421,27 @@ class SQLViewSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } + test("SPARK-14933 - create view from hive parquet tabale") { + withTable("t_part") { + withView("v_part") { + spark.sql("create table t_part stored as parquet as select 1 as a, 2 as b") + spark.sql("create view v_part as select * from t_part") + checkAnswer( + sql("select * from t_part"), + sql("select * from v_part")) + } + } + } + + test("SPARK-14933 - create view from hive orc tabale") { + withTable("t_orc") { + withView("v_orc") { + spark.sql("create table t_orc stored as orc as select 1 as a, 2 as b") + spark.sql("create view v_orc as select * from t_orc") + checkAnswer( + sql("select * from t_orc"), + sql("select * from v_orc")) + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala index 19e8025d6b7c9..a8e81d7a3c42a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.hive.execution import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.scalatest.exceptions.TestFailedException -import org.apache.spark.TaskContext +import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} @@ -29,7 +29,7 @@ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.types.StringType class ScriptTransformationSuite extends SparkPlanTest with TestHiveSingleton { - import hiveContext.implicits._ + import spark.implicits._ private val noSerdeIOSchema = HiveScriptIOSchema( inputRowFormat = Seq.empty, @@ -109,6 +109,22 @@ class ScriptTransformationSuite extends SparkPlanTest with TestHiveSingleton { } assert(e.getMessage().contains("intentional exception")) } + + test("SPARK-14400 script transformation should fail for bad script command") { + val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a") + + val e = intercept[SparkException] { + val plan = + new ScriptTransformation( + input = Seq(rowsDf.col("a").expr), + script = "some_non_existent_command", + output = Seq(AttributeReference("a", StringType)()), + child = rowsDf.queryExecution.sparkPlan, + ioschema = serdeIOSchema) + SparkPlanTest.executePlan(plan, hiveContext) + } + assert(e.getMessage.contains("Subprocess exited with status")) + } } private case class ExceptionInjectingOperator(child: SparkPlan) extends UnaryExecNode { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala index c6b7eb63662c5..0ff3511c87a4f 100644 --- 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala @@ -247,4 +247,16 @@ class WindowQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleto |from part """.stripMargin)) } + + test("SPARK-16646: LAST_VALUE(FALSE) OVER ()") { + checkAnswer(sql("SELECT LAST_VALUE(FALSE) OVER ()"), Row(false)) + checkAnswer(sql("SELECT LAST_VALUE(FALSE, FALSE) OVER ()"), Row(false)) + checkAnswer(sql("SELECT LAST_VALUE(TRUE, TRUE) OVER ()"), Row(true)) + } + + test("SPARK-16646: FIRST_VALUE(FALSE) OVER ()") { + checkAnswer(sql("SELECT FIRST_VALUE(FALSE) OVER ()"), Row(false)) + checkAnswer(sql("SELECT FIRST_VALUE(FALSE, FALSE) OVER ()"), Row(false)) + checkAnswer(sql("SELECT FIRST_VALUE(TRUE, TRUE) OVER ()"), Row(true)) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala index ddabab3a14b51..8c027f9935f87 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive.orc import java.nio.charset.StandardCharsets +import java.sql.{Date, Timestamp} import scala.collection.JavaConverters._ @@ -54,7 +55,7 @@ class OrcFilterSuite extends QueryTest with OrcTest { DataSourceStrategy.selectFilters(maybeRelation.get, maybeAnalyzedPredicate.toSeq) assert(selectedFilters.nonEmpty, "No filter is pushed down") - val maybeFilter = OrcFilters.createFilter(selectedFilters.toArray) + val maybeFilter = OrcFilters.createFilter(query.schema, selectedFilters.toArray) assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for $selectedFilters") checker(maybeFilter.get) } @@ -78,10 +79,28 @@ class OrcFilterSuite extends QueryTest with OrcTest { checkFilterPredicate(df, predicate, checkLogicalOperator) } - test("filter pushdown - boolean") { - withOrcDataFrame((true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { implicit df => - checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL) - } + private def checkNoFilterPredicate + (predicate: Predicate) + (implicit df: DataFrame): Unit = { + val output = predicate.collect { case a: Attribute => a }.distinct + val query = df + .select(output.map(e => Column(e)): _*) + .where(Column(predicate)) + + var maybeRelation: Option[HadoopFsRelation] = None + val maybeAnalyzedPredicate = query.queryExecution.optimizedPlan.collect { + case PhysicalOperation(_, filters, LogicalRelation(orcRelation: HadoopFsRelation, _, _)) => + maybeRelation = Some(orcRelation) + filters + }.flatten.reduceLeftOption(_ && _) + assert(maybeAnalyzedPredicate.isDefined, "No filter is analyzed from the given query") + + val (_, selectedFilters) = + DataSourceStrategy.selectFilters(maybeRelation.get, maybeAnalyzedPredicate.toSeq) + assert(selectedFilters.nonEmpty, "No filter is pushed down") + + val maybeFilter = OrcFilters.createFilter(query.schema, selectedFilters.toArray) + assert(maybeFilter.isEmpty, s"Could generate filter predicate for $selectedFilters") } test("filter pushdown - integer") { @@ -189,16 +208,6 @@ class OrcFilterSuite extends QueryTest with OrcTest { } } - test("filter pushdown - binary") { - implicit class IntToBinary(int: Int) { - def b: Array[Byte] = int.toString.getBytes(StandardCharsets.UTF_8) - } - - withOrcDataFrame((1 to 4).map(i => Tuple1(i.b))) { implicit df => 
- checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL) - } - } - test("filter pushdown - combinations with logical operators") { withOrcDataFrame((1 to 4).map(i => Tuple1(Option(i)))) { implicit df => // Because `ExpressionTree` is not accessible at Hive 1.2.x, this should be checked @@ -238,4 +247,40 @@ class OrcFilterSuite extends QueryTest with OrcTest { ) } } + + test("no filter pushdown - non-supported types") { + implicit class IntToBinary(int: Int) { + def b: Array[Byte] = int.toString.getBytes(StandardCharsets.UTF_8) + } + // ArrayType + withOrcDataFrame((1 to 4).map(i => Tuple1(Array(i)))) { implicit df => + checkNoFilterPredicate('_1.isNull) + } + // DecimalType + withOrcDataFrame((1 to 4).map(i => Tuple1(BigDecimal.valueOf(i)))) { implicit df => + checkNoFilterPredicate('_1 <= BigDecimal.valueOf(4)) + } + // BinaryType + withOrcDataFrame((1 to 4).map(i => Tuple1(i.b))) { implicit df => + checkNoFilterPredicate('_1 <=> 1.b) + } + // BooleanType + withOrcDataFrame((true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { implicit df => + checkNoFilterPredicate('_1 === true) + } + // TimestampType + val stringTimestamp = "2015-08-20 15:57:00" + withOrcDataFrame(Seq(Tuple1(Timestamp.valueOf(stringTimestamp)))) { implicit df => + checkNoFilterPredicate('_1 <=> Timestamp.valueOf(stringTimestamp)) + } + // DateType + val stringDate = "2015-01-01" + withOrcDataFrame(Seq(Tuple1(Date.valueOf(stringDate)))) { implicit df => + checkNoFilterPredicate('_1 === Date.valueOf(stringDate)) + } + // MapType + withOrcDataFrame((1 to 4).map(i => Tuple1(Map(i -> i)))) { implicit df => + checkNoFilterPredicate('_1.isNotNull) + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala index b97da1ffdc9f6..463c368fc42b1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.hive.orc import java.io.File import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.hadoop.hive.ql.io.orc.{CompressionKind, OrcFile} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql.Row @@ -31,7 +30,7 @@ import org.apache.spark.sql.types._ class OrcHadoopFsRelationSuite extends HadoopFsRelationTest { import testImplicits._ - override val dataSourceName: String = classOf[DefaultSource].getCanonicalName + override val dataSourceName: String = classOf[OrcFileFormat].getCanonicalName // ORC does not play well with NullType and UDT. 
override protected def supportsDataType(dataType: DataType): Boolean = dataType match { @@ -60,7 +59,7 @@ class OrcHadoopFsRelationSuite extends HadoopFsRelationTest { StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) checkQueries( - hiveContext.read.options(Map( + spark.read.options(Map( "path" -> file.getCanonicalPath, "dataSchema" -> dataSchemaWithPartition.json)).format(dataSourceName).load()) } @@ -75,11 +74,11 @@ class OrcHadoopFsRelationSuite extends HadoopFsRelationTest { (1 to 5).map(i => (i, (i % 2).toString)).toDF("a", "b").write.orc(path) checkAnswer( - sqlContext.read.orc(path).where("not (a = 2) or not(b in ('1'))"), + spark.read.orc(path).where("not (a = 2) or not(b in ('1'))"), (1 to 5).map(i => Row(i, (i % 2).toString))) checkAnswer( - sqlContext.read.orc(path).where("not (a = 2 and b in ('1'))"), + spark.read.orc(path).where("not (a = 2 and b in ('1'))"), (1 to 5).map(i => Row(i, (i % 2).toString))) } } @@ -94,18 +93,29 @@ class OrcHadoopFsRelationSuite extends HadoopFsRelationTest { .orc(path) // Check if this is compressed as ZLIB. - val conf = sqlContext.sessionState.newHadoopConf() + val conf = spark.sessionState.newHadoopConf() val fs = FileSystem.getLocal(conf) val maybeOrcFile = new File(path).listFiles().find(_.getName.endsWith(".zlib.orc")) assert(maybeOrcFile.isDefined) - val orcFilePath = new Path(maybeOrcFile.get.toPath.toString) - val orcReader = OrcFile.createReader(orcFilePath, OrcFile.readerOptions(conf)) - assert(orcReader.getCompression == CompressionKind.ZLIB) + val orcFilePath = maybeOrcFile.get.toPath.toString + val expectedCompressionKind = + OrcFileOperator.getFileReader(orcFilePath).get.getCompression + assert("ZLIB" === expectedCompressionKind.name()) - val copyDf = sqlContext + val copyDf = spark .read .orc(path) checkAnswer(df, copyDf) } } + + test("Default compression codec is snappy for ORC compression") { + withTempPath { file => + spark.range(0, 10).write + .orc(file.getCanonicalPath) + val expectedCompressionKind = + OrcFileOperator.getFileReader(file.getCanonicalPath).get.getCompression + assert("SNAPPY" === expectedCompressionKind.name()) + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala index 6161412a49775..d1ce3f1e2f058 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala @@ -37,8 +37,8 @@ case class OrcParDataWithKey(intField: Int, pi: Int, stringField: String, ps: St // TODO This test suite duplicates ParquetPartitionDiscoverySuite a lot class OrcPartitionDiscoverySuite extends QueryTest with TestHiveSingleton with BeforeAndAfterAll { - import hiveContext._ - import hiveContext.implicits._ + import spark._ + import spark.implicits._ val defaultPartitionName = ConfVars.DEFAULTPARTITIONNAME.defaultStrVal @@ -59,7 +59,7 @@ class OrcPartitionDiscoverySuite extends QueryTest with TestHiveSingleton with B } protected def withTempTable(tableName: String)(f: => Unit): Unit = { - try f finally hiveContext.dropTempTable(tableName) + try f finally spark.catalog.dropTempView(tableName) } protected def makePartitionDir( @@ -90,7 +90,7 @@ class OrcPartitionDiscoverySuite extends QueryTest with TestHiveSingleton with B makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) } - 
read.orc(base.getCanonicalPath).registerTempTable("t") + read.orc(base.getCanonicalPath).createOrReplaceTempView("t") withTempTable("t") { checkAnswer( @@ -137,7 +137,7 @@ class OrcPartitionDiscoverySuite extends QueryTest with TestHiveSingleton with B makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) } - read.orc(base.getCanonicalPath).registerTempTable("t") + read.orc(base.getCanonicalPath).createOrReplaceTempView("t") withTempTable("t") { checkAnswer( @@ -189,7 +189,7 @@ class OrcPartitionDiscoverySuite extends QueryTest with TestHiveSingleton with B read .option(ConfVars.DEFAULTPARTITIONNAME.varname, defaultPartitionName) .orc(base.getCanonicalPath) - .registerTempTable("t") + .createOrReplaceTempView("t") withTempTable("t") { checkAnswer( @@ -231,7 +231,7 @@ class OrcPartitionDiscoverySuite extends QueryTest with TestHiveSingleton with B read .option(ConfVars.DEFAULTPARTITIONNAME.varname, defaultPartitionName) .orc(base.getCanonicalPath) - .registerTempTable("t") + .createOrReplaceTempView("t") withTempTable("t") { checkAnswer( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index fb678be234a27..8f8768217788e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.hive.orc -import java.io.File import java.nio.charset.StandardCharsets import org.scalatest.BeforeAndAfterAll @@ -25,7 +24,7 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.hive.HiveUtils +import org.apache.spark.sql.hive.{HiveUtils, MetastoreRelation} import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ import org.apache.spark.sql.internal.SQLConf @@ -53,12 +52,6 @@ case class Person(name: String, age: Int, contacts: Seq[Contact]) class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { - def getTempFilePath(prefix: String, suffix: String = ""): File = { - val tempFile = File.createTempFile(prefix, suffix) - tempFile.delete() - tempFile - } - test("Read/write All Types") { val data = (0 to 255).map { i => (s"$i", i, i.toLong, i.toFloat, i.toDouble, i.toShort, i.toByte, i % 2 == 0) @@ -66,7 +59,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { withOrcFile(data) { file => checkAnswer( - sqlContext.read.orc(file), + spark.read.orc(file), data.toDF().collect()) } } @@ -98,8 +91,8 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { test("Creating case class RDD table") { val data = (1 to 100).map(i => (i, s"val_$i")) - sparkContext.parallelize(data).toDF().registerTempTable("t") - withTempTable("t") { + sparkContext.parallelize(data).toDF().createOrReplaceTempView("t") + withTempView("t") { checkAnswer(sql("SELECT * FROM t"), data.toDF().collect()) } } @@ -153,11 +146,11 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { test("save and load case class RDD with `None`s as orc") { val data = ( - None: Option[Int], - None: Option[Long], - None: Option[Float], - None: Option[Double], - None: Option[Boolean] + Option.empty[Int], + Option.empty[Long], + Option.empty[Float], + Option.empty[Double], + Option.empty[Boolean] 
) :: Nil withOrcFile(data) { file => @@ -167,10 +160,10 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { } } - // Hive supports zlib, snappy and none for Hive 1.2.1. - test("Compression options for writing to an ORC file (SNAPPY, ZLIB and NONE)") { + test("SPARK-16610: Respect orc.compress option when compression is unset") { + // Respect `orc.compress`. withTempPath { file => - sqlContext.range(0, 10).write + spark.range(0, 10).write .option("orc.compress", "ZLIB") .orc(file.getCanonicalPath) val expectedCompressionKind = @@ -178,18 +171,41 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { assert("ZLIB" === expectedCompressionKind.name()) } + // `compression` overrides `orc.compress`. withTempPath { file => - sqlContext.range(0, 10).write + spark.range(0, 10).write + .option("compression", "ZLIB") .option("orc.compress", "SNAPPY") .orc(file.getCanonicalPath) val expectedCompressionKind = OrcFileOperator.getFileReader(file.getCanonicalPath).get.getCompression + assert("ZLIB" === expectedCompressionKind.name()) + } + } + + // Hive supports zlib, snappy and none for Hive 1.2.1. + test("Compression options for writing to an ORC file (SNAPPY, ZLIB and NONE)") { + withTempPath { file => + spark.range(0, 10).write + .option("compression", "ZLIB") + .orc(file.getCanonicalPath) + val expectedCompressionKind = + OrcFileOperator.getFileReader(file.getCanonicalPath).get.getCompression + assert("ZLIB" === expectedCompressionKind.name()) + } + + withTempPath { file => + spark.range(0, 10).write + .option("compression", "SNAPPY") + .orc(file.getCanonicalPath) + val expectedCompressionKind = + OrcFileOperator.getFileReader(file.getCanonicalPath).get.getCompression assert("SNAPPY" === expectedCompressionKind.name()) } withTempPath { file => - sqlContext.range(0, 10).write - .option("orc.compress", "NONE") + spark.range(0, 10).write + .option("compression", "NONE") .orc(file.getCanonicalPath) val expectedCompressionKind = OrcFileOperator.getFileReader(file.getCanonicalPath).get.getCompression @@ -200,8 +216,8 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { // Following codec is not supported in Hive 1.2.1, ignore it now ignore("LZO compression options for writing to an ORC file not supported in Hive 1.2.1") { withTempPath { file => - sqlContext.range(0, 10).write - .option("orc.compress", "LZO") + spark.range(0, 10).write + .option("compression", "LZO") .orc(file.getCanonicalPath) val expectedCompressionKind = OrcFileOperator.getFileReader(file.getCanonicalPath).get.getCompression @@ -223,7 +239,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { test("appending") { val data = (0 until 10).map(i => (i, i.toString)) - createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") + createDataFrame(data).toDF("c1", "c2").createOrReplaceTempView("tmp") withOrcTable(data, "t") { sql("INSERT INTO TABLE t SELECT * FROM tmp") checkAnswer(table("t"), (data ++ data).map(Row.fromTuple)) @@ -233,7 +249,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { test("overwriting") { val data = (0 until 10).map(i => (i, i.toString)) - createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") + createDataFrame(data).toDF("c1", "c2").createOrReplaceTempView("tmp") withOrcTable(data, "t") { sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp") checkAnswer(table("t"), data.map(Row.fromTuple)) @@ -301,12 +317,12 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with 
OrcTest { withTempPath { dir => val path = dir.getCanonicalPath - sqlContext.range(0, 10).select('id as "Acol").write.format("orc").save(path) - sqlContext.read.format("orc").load(path).schema("Acol") + spark.range(0, 10).select('id as "Acol").write.format("orc").save(path) + spark.read.format("orc").load(path).schema("Acol") intercept[IllegalArgumentException] { - sqlContext.read.format("orc").load(path).schema("acol") + spark.read.format("orc").load(path).schema("acol") } - checkAnswer(sqlContext.read.format("orc").load(path).select("acol").sort("acol"), + checkAnswer(spark.read.format("orc").load(path).select("acol").sort("acol"), (0 until 10).map(Row(_))) } } @@ -316,38 +332,38 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { val path = dir.getCanonicalPath withTable("empty_orc") { - withTempTable("empty", "single") { - sqlContext.sql( + withTempView("empty", "single") { + spark.sql( s"""CREATE TABLE empty_orc(key INT, value STRING) |STORED AS ORC |LOCATION '$path' """.stripMargin) val emptyDF = Seq.empty[(Int, String)].toDF("key", "value").coalesce(1) - emptyDF.registerTempTable("empty") + emptyDF.createOrReplaceTempView("empty") // This creates 1 empty ORC file with Hive ORC SerDe. We are using this trick because // Spark SQL ORC data source always avoids write empty ORC files. - sqlContext.sql( + spark.sql( s"""INSERT INTO TABLE empty_orc |SELECT key, value FROM empty """.stripMargin) val errorMessage = intercept[AnalysisException] { - sqlContext.read.orc(path) + spark.read.orc(path) }.getMessage assert(errorMessage.contains("Unable to infer schema for ORC")) val singleRowDF = Seq((0, "foo")).toDF("key", "value").coalesce(1) - singleRowDF.registerTempTable("single") + singleRowDF.createOrReplaceTempView("single") - sqlContext.sql( + spark.sql( s"""INSERT INTO TABLE empty_orc |SELECT key, value FROM single """.stripMargin) - val df = sqlContext.read.orc(path) + val df = spark.read.orc(path) assert(df.schema === singleRowDF.schema.asNullable) checkAnswer(df, singleRowDF) } @@ -373,7 +389,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { // It needs to repartition data so that we can have several ORC files // in order to skip stripes in ORC. 
createDataFrame(data).toDF("a", "b").repartition(10).write.orc(path) - val df = sqlContext.read.orc(path) + val df = spark.read.orc(path) def checkPredicate(pred: Column, answer: Seq[Row]): Unit = { val sourceDf = stripSparkFilter(df.where(pred)) @@ -407,40 +423,66 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { } } - test("SPARK-14070 Use ORC data source for SQL queries on ORC tables") { - withTempPath { dir => - withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> "true", - HiveUtils.CONVERT_METASTORE_ORC.key -> "true") { - val path = dir.getCanonicalPath - - withTable("dummy_orc") { - withTempTable("single") { - sqlContext.sql( - s"""CREATE TABLE dummy_orc(key INT, value STRING) - |STORED AS ORC - |LOCATION '$path' - """.stripMargin) - - val singleRowDF = Seq((0, "foo")).toDF("key", "value").coalesce(1) - singleRowDF.registerTempTable("single") - - sqlContext.sql( - s"""INSERT INTO TABLE dummy_orc - |SELECT key, value FROM single - """.stripMargin) - - val df = sqlContext.sql("SELECT * FROM dummy_orc WHERE key=0") - checkAnswer(df, singleRowDF) - - val queryExecution = df.queryExecution - queryExecution.analyzed.collectFirst { - case _: LogicalRelation => () - }.getOrElse { - fail(s"Expecting the query plan to have LogicalRelation, but got:\n$queryExecution") + test("Verify the ORC conversion parameter: CONVERT_METASTORE_ORC") { + withTempView("single") { + val singleRowDF = Seq((0, "foo")).toDF("key", "value") + singleRowDF.createOrReplaceTempView("single") + + Seq("true", "false").foreach { orcConversion => + withSQLConf(HiveUtils.CONVERT_METASTORE_ORC.key -> orcConversion) { + withTable("dummy_orc") { + withTempPath { dir => + val path = dir.getCanonicalPath + spark.sql( + s""" + |CREATE TABLE dummy_orc(key INT, value STRING) + |STORED AS ORC + |LOCATION '$path' + """.stripMargin) + + spark.sql( + s""" + |INSERT INTO TABLE dummy_orc + |SELECT key, value FROM single + """.stripMargin) + + val df = spark.sql("SELECT * FROM dummy_orc WHERE key=0") + checkAnswer(df, singleRowDF) + + val queryExecution = df.queryExecution + if (orcConversion == "true") { + queryExecution.analyzed.collectFirst { + case _: LogicalRelation => () + }.getOrElse { + fail(s"Expecting the query plan to convert orc to data sources, " + + s"but got:\n$queryExecution") + } + } else { + queryExecution.analyzed.collectFirst { + case _: MetastoreRelation => () + }.getOrElse { + fail(s"Expecting no conversion from orc to data sources, " + + s"but got:\n$queryExecution") + } + } } } } } } } + + test("SPARK-14962 Produce correct results on array type with isnotnull") { + withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> "true") { + val data = (0 until 10).map(i => Tuple1(Array(i))) + withOrcFile(data) { file => + val actual = spark + .read + .orc(file) + .where("_1 is not null") + val expected = data.toDF() + checkAnswer(actual, expected) + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala index bdd3428a89742..0f37cd7bf3652 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala @@ -24,11 +24,13 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils case 
class OrcData(intField: Int, stringField: String) abstract class OrcSuite extends QueryTest with TestHiveSingleton with BeforeAndAfterAll { - import hiveContext._ + import spark._ var orcTableDir: File = null var orcTableAsDir: File = null @@ -36,21 +38,17 @@ abstract class OrcSuite extends QueryTest with TestHiveSingleton with BeforeAndA override def beforeAll(): Unit = { super.beforeAll() - orcTableAsDir = File.createTempFile("orctests", "sparksql") - orcTableAsDir.delete() - orcTableAsDir.mkdir() + orcTableAsDir = Utils.createTempDir("orctests", "sparksql") // Hack: to prepare orc data files using hive external tables - orcTableDir = File.createTempFile("orctests", "sparksql") - orcTableDir.delete() - orcTableDir.mkdir() + orcTableDir = Utils.createTempDir("orctests", "sparksql") import org.apache.spark.sql.hive.test.TestHive.implicits._ sparkContext .makeRDD(1 to 10) .map(i => OrcData(i, s"part-$i")) .toDF() - .registerTempTable(s"orc_temp_table") + .createOrReplaceTempView(s"orc_temp_table") sql( s"""CREATE EXTERNAL TABLE normal_orc( @@ -67,15 +65,6 @@ abstract class OrcSuite extends QueryTest with TestHiveSingleton with BeforeAndA """.stripMargin) } - override def afterAll(): Unit = { - try { - orcTableDir.delete() - orcTableAsDir.delete() - } finally { - super.afterAll() - } - } - test("create temporary orc table") { checkAnswer(sql("SELECT COUNT(*) FROM normal_orc_source"), Row(10)) @@ -163,16 +152,16 @@ class OrcSourceSuite extends OrcSuite { override def beforeAll(): Unit = { super.beforeAll() - hiveContext.sql( - s"""CREATE TEMPORARY TABLE normal_orc_source + spark.sql( + s"""CREATE TEMPORARY VIEW normal_orc_source |USING org.apache.spark.sql.hive.orc |OPTIONS ( | PATH '${new File(orcTableAsDir.getAbsolutePath).getCanonicalPath}' |) """.stripMargin) - hiveContext.sql( - s"""CREATE TEMPORARY TABLE normal_orc_as_source + spark.sql( + s"""CREATE TEMPORARY VIEW normal_orc_as_source |USING org.apache.spark.sql.hive.orc |OPTIONS ( | PATH '${new File(orcTableAsDir.getAbsolutePath).getCanonicalPath}' @@ -182,12 +171,16 @@ class OrcSourceSuite extends OrcSuite { test("SPARK-12218 Converting conjunctions into ORC SearchArguments") { // The `LessThan` should be converted while the `StringContains` shouldn't + val schema = new StructType( + Array( + StructField("a", IntegerType, nullable = true), + StructField("b", StringType, nullable = true))) assertResult( """leaf-0 = (LESS_THAN a 10) |expr = leaf-0 """.stripMargin.trim ) { - OrcFilters.createFilter(Array( + OrcFilters.createFilter(schema, Array( LessThan("a", 10), StringContains("b", "prefix") )).get.toString @@ -199,7 +192,7 @@ class OrcSourceSuite extends OrcSuite { |expr = leaf-0 """.stripMargin.trim ) { - OrcFilters.createFilter(Array( + OrcFilters.createFilter(schema, Array( LessThan("a", 10), Not(And( GreaterThan("a", 1), diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala index 637c10611afc6..7226ed521ef32 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala @@ -49,7 +49,7 @@ private[sql] trait OrcTest extends SQLTestUtils with TestHiveSingleton { protected def withOrcDataFrame[T <: Product: ClassTag: TypeTag] (data: Seq[T]) (f: DataFrame => Unit): Unit = { - withOrcFile(data)(path => f(sqlContext.read.orc(path))) + withOrcFile(data)(path => f(spark.read.orc(path))) } /** @@ -61,8 +61,8 @@ private[sql] trait OrcTest extends 
SQLTestUtils with TestHiveSingleton { (data: Seq[T], tableName: String) (f: => Unit): Unit = { withOrcDataFrame(data) { df => - sqlContext.registerDataFrameAsTable(df, tableName) - withTempTable(tableName)(f) + df.createOrReplaceTempView(tableName) + withTempView(tableName)(f) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index 1c1f6d910d63d..e92bbdea75a7b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.DataSourceScanExec import org.apache.spark.sql.execution.command.ExecutedCommandExec -import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, InsertIntoDataSource, InsertIntoHadoopFsRelation, LogicalRelation} +import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, InsertIntoDataSourceCommand, InsertIntoHadoopFsRelationCommand, LogicalRelation} import org.apache.spark.sql.hive.execution.HiveTableScanExec import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf @@ -57,7 +57,7 @@ case class ParquetDataWithKeyAndComplexTypes( */ class ParquetMetastoreSuite extends ParquetPartitioningTest { import hiveContext._ - import hiveContext.implicits._ + import spark.implicits._ override def beforeAll(): Unit = { super.beforeAll() @@ -171,8 +171,9 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { sql(s"ALTER TABLE partitioned_parquet_with_complextypes ADD PARTITION (p=$p)") } - (1 to 10).map(i => (i, s"str$i")).toDF("a", "b").registerTempTable("jt") - (1 to 10).map(i => Tuple1(Seq(new Integer(i), null))).toDF("a").registerTempTable("jt_array") + (1 to 10).map(i => (i, s"str$i")).toDF("a", "b").createOrReplaceTempView("jt") + (1 to 10).map(i => Tuple1(Seq(new Integer(i), null))).toDF("a") + .createOrReplaceTempView("jt_array") setConf(HiveUtils.CONVERT_METASTORE_PARQUET, true) } @@ -307,10 +308,10 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt") df.queryExecution.sparkPlan match { - case ExecutedCommandExec(_: InsertIntoHadoopFsRelation) => // OK + case ExecutedCommandExec(_: InsertIntoHadoopFsRelationCommand) => // OK case o => fail("test_insert_parquet should be converted to a " + s"${classOf[HadoopFsRelation ].getCanonicalName} and " + - s"${classOf[InsertIntoDataSource].getCanonicalName} is expected as the SparkPlan. " + + s"${classOf[InsertIntoDataSourceCommand].getCanonicalName} should have been SparkPlan. " + s"However, found a ${o.toString} ") } @@ -337,10 +338,10 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt_array") df.queryExecution.sparkPlan match { - case ExecutedCommandExec(_: InsertIntoHadoopFsRelation) => // OK + case ExecutedCommandExec(_: InsertIntoHadoopFsRelationCommand) => // OK case o => fail("test_insert_parquet should be converted to a " + s"${classOf[HadoopFsRelation ].getCanonicalName} and " + - s"${classOf[InsertIntoDataSource].getCanonicalName} is expected as the SparkPlan." + + s"${classOf[InsertIntoDataSourceCommand].getCanonicalName} should have been SparkPlan." 
+ s"However, found a ${o.toString} ") } @@ -388,17 +389,18 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { test("SPARK-7749: non-partitioned metastore Parquet table lookup should use cached relation") { withTable("nonPartitioned") { sql( - s"""CREATE TABLE nonPartitioned ( - | key INT, - | value STRING - |) - |STORED AS PARQUET - """.stripMargin) + """ + |CREATE TABLE nonPartitioned ( + | key INT, + | value STRING + |) + |STORED AS PARQUET + """.stripMargin) // First lookup fills the cache - val r1 = collectHadoopFsRelation (table("nonPartitioned")) + val r1 = collectHadoopFsRelation(table("nonPartitioned")) // Second lookup should reuse the cache - val r2 = collectHadoopFsRelation (table("nonPartitioned")) + val r2 = collectHadoopFsRelation(table("nonPartitioned")) // They should be the same instance assert(r1 eq r2) } @@ -407,18 +409,42 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { test("SPARK-7749: partitioned metastore Parquet table lookup should use cached relation") { withTable("partitioned") { sql( - s"""CREATE TABLE partitioned ( - | key INT, - | value STRING - |) - |PARTITIONED BY (part INT) - |STORED AS PARQUET - """.stripMargin) + """ + |CREATE TABLE partitioned ( + | key INT, + | value STRING + |) + |PARTITIONED BY (part INT) + |STORED AS PARQUET + """.stripMargin) + + // First lookup fills the cache + val r1 = collectHadoopFsRelation(table("partitioned")) + // Second lookup should reuse the cache + val r2 = collectHadoopFsRelation(table("partitioned")) + // They should be the same instance + assert(r1 eq r2) + } + } + + test("SPARK-15968: nonempty partitioned metastore Parquet table lookup should use cached " + + "relation") { + withTable("partitioned") { + sql( + """ + |CREATE TABLE partitioned ( + | key INT, + | value STRING + |) + |PARTITIONED BY (part INT) + |STORED AS PARQUET + """.stripMargin) + sql("INSERT INTO TABLE partitioned PARTITION(part=0) SELECT 1 as key, 'one' as value") // First lookup fills the cache - val r1 = collectHadoopFsRelation (table("partitioned")) + val r1 = collectHadoopFsRelation(table("partitioned")) // Second lookup should reuse the cache - val r2 = collectHadoopFsRelation (table("partitioned")) + val r2 = collectHadoopFsRelation(table("partitioned")) // They should be the same instance assert(r1 eq r2) } @@ -461,7 +487,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { checkCached(tableIdentifier) // For insert into non-partitioned table, we will do the conversion, // so the converted test_insert_parquet should be cached. - sessionState.invalidateTable("test_insert_parquet") + sessionState.refreshTable("test_insert_parquet") assert(sessionState.catalog.getCachedDataSourceTable(tableIdentifier) === null) sql( """ @@ -474,7 +500,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { sql("select * from test_insert_parquet"), sql("select a, b from jt").collect()) // Invalidate the cache. - sessionState.invalidateTable("test_insert_parquet") + sessionState.refreshTable("test_insert_parquet") assert(sessionState.catalog.getCachedDataSourceTable(tableIdentifier) === null) // Create a partitioned table. 
@@ -524,11 +550,52 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { |select b, '2015-04-02', a FROM jt """.stripMargin).collect()) - sessionState.invalidateTable("test_parquet_partitioned_cache_test") + sessionState.refreshTable("test_parquet_partitioned_cache_test") assert(sessionState.catalog.getCachedDataSourceTable(tableIdentifier) === null) dropTables("test_insert_parquet", "test_parquet_partitioned_cache_test") } + + test("SPARK-15248: explicitly added partitions should be readable") { + withTable("test_added_partitions", "test_temp") { + withTempDir { src => + val partitionDir = new File(src, "partition").getCanonicalPath + sql( + """ + |CREATE TABLE test_added_partitions (a STRING) + |PARTITIONED BY (b INT) + |STORED AS PARQUET + """.stripMargin) + + // Temp view that is used to insert data into partitioned table + Seq("foo", "bar").toDF("a").createOrReplaceTempView("test_temp") + sql("INSERT INTO test_added_partitions PARTITION(b='0') SELECT a FROM test_temp") + + checkAnswer( + sql("SELECT * FROM test_added_partitions"), + Seq(("foo", 0), ("bar", 0)).toDF("a", "b")) + + // Create partition without data files and check whether it can be read + sql(s"ALTER TABLE test_added_partitions ADD PARTITION (b='1') LOCATION '$partitionDir'") + checkAnswer( + sql("SELECT * FROM test_added_partitions"), + Seq(("foo", 0), ("bar", 0)).toDF("a", "b")) + + // Add data files to partition directory and check whether they can be read + sql("INSERT INTO TABLE test_added_partitions PARTITION (b=1) select 'baz' as a") + checkAnswer( + sql("SELECT * FROM test_added_partitions"), + Seq(("foo", 0), ("bar", 0), ("baz", 1)).toDF("a", "b")) + } + } + } + + test("self-join") { + val table = spark.table("normal_parquet") + val selfJoin = table.as("t1").join(table.as("t2")) + checkAnswer(selfJoin, + sql("SELECT * FROM normal_parquet x JOIN normal_parquet y")) + } } /** @@ -536,7 +603,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { */ class ParquetSourceSuite extends ParquetPartitioningTest { import testImplicits._ - import hiveContext._ + import spark._ override def beforeAll(): Unit = { super.beforeAll() @@ -547,7 +614,7 @@ class ParquetSourceSuite extends ParquetPartitioningTest { "normal_parquet") sql( s""" - create temporary table partitioned_parquet + CREATE TEMPORARY VIEW partitioned_parquet USING org.apache.spark.sql.parquet OPTIONS ( path '${partitionedTableDir.getCanonicalPath}' @@ -555,7 +622,7 @@ class ParquetSourceSuite extends ParquetPartitioningTest { """) sql( s""" - create temporary table partitioned_parquet_with_key + CREATE TEMPORARY VIEW partitioned_parquet_with_key USING org.apache.spark.sql.parquet OPTIONS ( path '${partitionedTableDirWithKey.getCanonicalPath}' @@ -563,7 +630,7 @@ class ParquetSourceSuite extends ParquetPartitioningTest { """) sql( s""" - create temporary table normal_parquet + CREATE TEMPORARY VIEW normal_parquet USING org.apache.spark.sql.parquet OPTIONS ( path '${new File(partitionedTableDir, "p=1").getCanonicalPath}' @@ -571,7 +638,7 @@ class ParquetSourceSuite extends ParquetPartitioningTest { """) sql( s""" - CREATE TEMPORARY TABLE partitioned_parquet_with_key_and_complextypes + CREATE TEMPORARY VIEW partitioned_parquet_with_key_and_complextypes USING org.apache.spark.sql.parquet OPTIONS ( path '${partitionedTableDirWithKeyAndComplexTypes.getCanonicalPath}' @@ -579,7 +646,7 @@ class ParquetSourceSuite extends ParquetPartitioningTest { """) sql( s""" - CREATE TEMPORARY TABLE partitioned_parquet_with_complextypes + CREATE TEMPORARY VIEW 
partitioned_parquet_with_complextypes USING org.apache.spark.sql.parquet OPTIONS ( path '${partitionedTableDirWithComplexTypes.getCanonicalPath}' @@ -634,13 +701,53 @@ class ParquetSourceSuite extends ParquetPartitioningTest { """.stripMargin) checkAnswer( - sqlContext.read.parquet(path), + spark.read.parquet(path), Row("1st", "2nd", Seq(Row("val_a", "val_b")))) } } } } + test("Verify the PARQUET conversion parameter: CONVERT_METASTORE_PARQUET") { + withTempView("single") { + val singleRowDF = Seq((0, "foo")).toDF("key", "value") + singleRowDF.createOrReplaceTempView("single") + + Seq("true", "false").foreach { parquetConversion => + withSQLConf(HiveUtils.CONVERT_METASTORE_PARQUET.key -> parquetConversion) { + val tableName = "test_parquet_ctas" + withTable(tableName) { + sql( + s""" + |CREATE TABLE $tableName STORED AS PARQUET + |AS SELECT tmp.key, tmp.value FROM single tmp + """.stripMargin) + + val df = spark.sql(s"SELECT * FROM $tableName WHERE key=0") + checkAnswer(df, singleRowDF) + + val queryExecution = df.queryExecution + if (parquetConversion == "true") { + queryExecution.analyzed.collectFirst { + case _: LogicalRelation => + }.getOrElse { + fail(s"Expecting the query plan to convert parquet to data sources, " + + s"but got:\n$queryExecution") + } + } else { + queryExecution.analyzed.collectFirst { + case _: MetastoreRelation => + }.getOrElse { + fail(s"Expecting no conversion from parquet to data sources, " + + s"but got:\n$queryExecution") + } + } + } + } + } + } + } + test("values in arrays and maps stored in parquet are always nullable") { val df = createDataFrame(Tuple2(Map(2 -> 3), Seq(4, 5, 6)) :: Nil).toDF("m", "a") val mapType1 = MapType(IntegerType, IntegerType, valueContainsNull = false) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala index d271e55467c6f..3554b0e101d64 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala @@ -51,7 +51,7 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet .saveAsTable("bucketed_table") for (i <- 0 until 5) { - val table = hiveContext.table("bucketed_table").filter($"i" === i) + val table = spark.table("bucketed_table").filter($"i" === i) val query = table.queryExecution val output = query.analyzed.output val rdd = query.toRdd @@ -80,7 +80,7 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet originalDataFrame: DataFrame): Unit = { // This test verifies parts of the plan. Disable whole stage codegen. 
withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") { - val bucketedDataFrame = hiveContext.table("bucketed_table").select("i", "j", "k") + val bucketedDataFrame = spark.table("bucketed_table").select("i", "j", "k") val BucketSpec(numBuckets, bucketColumnNames, _) = bucketSpec // Limit: bucket pruning only works when the bucket column has one and only one column assert(bucketColumnNames.length == 1) @@ -234,11 +234,14 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet private def testBucketing( bucketSpecLeft: Option[BucketSpec], bucketSpecRight: Option[BucketSpec], - joinColumns: Seq[String], + joinType: String = "inner", + joinCondition: (DataFrame, DataFrame) => Column, shuffleLeft: Boolean, shuffleRight: Boolean): Unit = { withTable("bucketed_table1", "bucketed_table2") { - def withBucket(writer: DataFrameWriter, bucketSpec: Option[BucketSpec]): DataFrameWriter = { + def withBucket( + writer: DataFrameWriter[Row], + bucketSpec: Option[BucketSpec]): DataFrameWriter[Row] = { bucketSpec.map { spec => writer.bucketBy( spec.numBuckets, @@ -252,14 +255,14 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0", SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") { - val t1 = hiveContext.table("bucketed_table1") - val t2 = hiveContext.table("bucketed_table2") - val joined = t1.join(t2, joinCondition(t1, t2, joinColumns)) + val t1 = spark.table("bucketed_table1") + val t2 = spark.table("bucketed_table2") + val joined = t1.join(t2, joinCondition(t1, t2), joinType) // First check the result is corrected. checkAnswer( joined.sort("bucketed_table1.k", "bucketed_table2.k"), - df1.join(df2, joinCondition(df1, df2, joinColumns)).sort("df1.k", "df2.k")) + df1.join(df2, joinCondition(df1, df2), joinType).sort("df1.k", "df2.k")) assert(joined.queryExecution.executedPlan.isInstanceOf[SortMergeJoinExec]) val joinOperator = joined.queryExecution.executedPlan.asInstanceOf[SortMergeJoinExec] @@ -274,54 +277,96 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet } } - private def joinCondition(left: DataFrame, right: DataFrame, joinCols: Seq[String]): Column = { + private def joinCondition(joinCols: Seq[String]) (left: DataFrame, right: DataFrame): Column = { joinCols.map(col => left(col) === right(col)).reduce(_ && _) } test("avoid shuffle when join 2 bucketed tables") { val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Nil)) - testBucketing(bucketSpec, bucketSpec, Seq("i", "j"), shuffleLeft = false, shuffleRight = false) + testBucketing( + bucketSpecLeft = bucketSpec, + bucketSpecRight = bucketSpec, + joinCondition = joinCondition(Seq("i", "j")), + shuffleLeft = false, + shuffleRight = false + ) } // Enable it after fix https://issues.apache.org/jira/browse/SPARK-12704 ignore("avoid shuffle when join keys are a super-set of bucket keys") { val bucketSpec = Some(BucketSpec(8, Seq("i"), Nil)) - testBucketing(bucketSpec, bucketSpec, Seq("i", "j"), shuffleLeft = false, shuffleRight = false) + testBucketing( + bucketSpecLeft = bucketSpec, + bucketSpecRight = bucketSpec, + joinCondition = joinCondition(Seq("i", "j")), + shuffleLeft = false, + shuffleRight = false + ) } test("only shuffle one side when join bucketed table and non-bucketed table") { val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Nil)) - testBucketing(bucketSpec, None, Seq("i", "j"), shuffleLeft = false, shuffleRight = true) + testBucketing( + bucketSpecLeft = bucketSpec, + 
bucketSpecRight = None, + joinCondition = joinCondition(Seq("i", "j")), + shuffleLeft = false, + shuffleRight = true + ) } test("only shuffle one side when 2 bucketed tables have different bucket number") { val bucketSpec1 = Some(BucketSpec(8, Seq("i", "j"), Nil)) val bucketSpec2 = Some(BucketSpec(5, Seq("i", "j"), Nil)) - testBucketing(bucketSpec1, bucketSpec2, Seq("i", "j"), shuffleLeft = false, shuffleRight = true) + testBucketing( + bucketSpecLeft = bucketSpec1, + bucketSpecRight = bucketSpec2, + joinCondition = joinCondition(Seq("i", "j")), + shuffleLeft = false, + shuffleRight = true + ) } test("only shuffle one side when 2 bucketed tables have different bucket keys") { val bucketSpec1 = Some(BucketSpec(8, Seq("i"), Nil)) val bucketSpec2 = Some(BucketSpec(8, Seq("j"), Nil)) - testBucketing(bucketSpec1, bucketSpec2, Seq("i"), shuffleLeft = false, shuffleRight = true) + testBucketing( + bucketSpecLeft = bucketSpec1, + bucketSpecRight = bucketSpec2, + joinCondition = joinCondition(Seq("i")), + shuffleLeft = false, + shuffleRight = true + ) } test("shuffle when join keys are not equal to bucket keys") { val bucketSpec = Some(BucketSpec(8, Seq("i"), Nil)) - testBucketing(bucketSpec, bucketSpec, Seq("j"), shuffleLeft = true, shuffleRight = true) + testBucketing( + bucketSpecLeft = bucketSpec, + bucketSpecRight = bucketSpec, + joinCondition = joinCondition(Seq("j")), + shuffleLeft = true, + shuffleRight = true + ) } test("shuffle when join 2 bucketed tables with bucketing disabled") { val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Nil)) withSQLConf(SQLConf.BUCKETING_ENABLED.key -> "false") { - testBucketing(bucketSpec, bucketSpec, Seq("i", "j"), shuffleLeft = true, shuffleRight = true) + testBucketing( + bucketSpecLeft = bucketSpec, + bucketSpecRight = bucketSpec, + joinCondition = joinCondition(Seq("i", "j")), + shuffleLeft = true, + shuffleRight = true + ) } } test("avoid shuffle when grouping keys are equal to bucket keys") { withTable("bucketed_table") { df1.write.format("parquet").bucketBy(8, "i", "j").saveAsTable("bucketed_table") - val tbl = hiveContext.table("bucketed_table") + val tbl = spark.table("bucketed_table") val agged = tbl.groupBy("i", "j").agg(max("k")) checkAnswer( @@ -335,7 +380,7 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet test("avoid shuffle when grouping keys are a super-set of bucket keys") { withTable("bucketed_table") { df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table") - val tbl = hiveContext.table("bucketed_table") + val tbl = spark.table("bucketed_table") val agged = tbl.groupBy("i", "j").agg(max("k")) checkAnswer( @@ -346,14 +391,32 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet } } + test("SPARK-17698 Join predicates should not contain filter clauses") { + val bucketSpec = Some(BucketSpec(8, Seq("i"), Seq("i"))) + testBucketing( + bucketSpecLeft = bucketSpec, + bucketSpecRight = bucketSpec, + joinType = "fullouter", + joinCondition = (left: DataFrame, right: DataFrame) => { + val joinPredicates = left("i") === right("i") + val filterLeft = left("i") === Literal("1") + val filterRight = right("i") === Literal("1") + joinPredicates && filterLeft && filterRight + }, + shuffleLeft = false, + shuffleRight = false + ) + } + test("error if there exists any malformed bucket files") { withTable("bucketed_table") { df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table") - val tableDir = new File(hiveContext.sparkSession.warehousePath, 
"bucketed_table") + val tableDir = new File(hiveContext + .sparkSession.warehousePath, "bucketed_table") Utils.deleteRecursively(tableDir) df1.write.parquet(tableDir.getAbsolutePath) - val agged = hiveContext.table("bucketed_table").groupBy("i").count() + val agged = spark.table("bucketed_table").groupBy("i").count() val error = intercept[RuntimeException] { agged.count() } @@ -361,4 +424,15 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet assert(error.toString contains "Invalid bucket file") } } + + test("disable bucketing when the output doesn't contain all bucketing columns") { + withTable("bucketed_table") { + df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table") + + checkAnswer(hiveContext.table("bucketed_table").select("j"), df1.select("j")) + + checkAnswer(hiveContext.table("bucketed_table").groupBy("j").agg(max("k")), + df1.groupBy("j").agg(max("k"))) + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala index a3e7737a7c059..997445114ba58 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala @@ -59,17 +59,28 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle intercept[SparkException](df.write.bucketBy(3, "i").format("text").saveAsTable("tt")) } - test("write bucketed data to non-hive-table or existing hive table") { + test("write bucketed data using save()") { val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j") - intercept[IllegalArgumentException](df.write.bucketBy(2, "i").parquet("/tmp/path")) - intercept[IllegalArgumentException](df.write.bucketBy(2, "i").json("/tmp/path")) - intercept[IllegalArgumentException](df.write.bucketBy(2, "i").insertInto("tt")) + + val e = intercept[AnalysisException] { + df.write.bucketBy(2, "i").parquet("/tmp/path") + } + assert(e.getMessage == "'save' does not support bucketing right now;") + } + + test("write bucketed data using insertInto()") { + val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j") + + val e = intercept[AnalysisException] { + df.write.bucketBy(2, "i").insertInto("tt") + } + assert(e.getMessage == "'insertInto' does not support bucketing right now;") } private val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k") def tableDir: File = { - val identifier = hiveContext.sessionState.sqlParser.parseTableIdentifier("bucketed_table") + val identifier = spark.sessionState.sqlParser.parseTableIdentifier("bucketed_table") new File(URI.create(hiveContext.sessionState.catalog.hiveDefaultTableFilePath(identifier))) } @@ -105,7 +116,7 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle } // Read the bucket file into a dataframe, so that it's easier to test. 
- val readBack = sqlContext.read.format(source) + val readBack = spark.read.format(source) .load(bucketFile.getAbsolutePath) .select(columns: _*) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala index 08e83b7f6905a..f9387fae4a4ca 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala @@ -34,7 +34,7 @@ class CommitFailureTestRelationSuite extends SQLTestUtils with TestHiveSingleton // Here we coalesce partition number to 1 to ensure that only a single task is issued. This // prevents race condition happened when FileOutputCommitter tries to remove the `_temporary` // directory while committing/aborting the job. See SPARK-8513 for more details. - val df = sqlContext.range(0, 10).coalesce(1) + val df = spark.range(0, 10).coalesce(1) intercept[SparkException] { df.write.format(dataSourceName).save(file.getCanonicalPath) } @@ -49,7 +49,7 @@ class CommitFailureTestRelationSuite extends SQLTestUtils with TestHiveSingleton withTempPath { file => // fail the job in the middle of writing val divideByZero = udf((x: Int) => { x / (x - 1)}) - val df = sqlContext.range(0, 10).coalesce(1).select(divideByZero(col("id"))) + val df = spark.range(0, 10).coalesce(1).select(divideByZero(col("id"))) SimpleTextRelation.callbackCalled = false intercept[SparkException] { @@ -66,7 +66,7 @@ class CommitFailureTestRelationSuite extends SQLTestUtils with TestHiveSingleton SimpleTextRelation.failCommitter = false withTempPath { file => // fail the job in the middle of writing - val df = sqlContext.range(0, 10).coalesce(1).select(col("id").mod(2).as("key"), col("id")) + val df = spark.range(0, 10).coalesce(1).select(col("id").mod(2).as("key"), col("id")) SimpleTextRelation.callbackCalled = false SimpleTextRelation.failWriter = true diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala index 67b403a9bd3a6..97f2b23823d7a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.sources +import java.io.File + import scala.util.Random import org.apache.hadoop.fs.Path @@ -35,7 +37,7 @@ import org.apache.spark.sql.types._ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with TestHiveSingleton { - import sqlContext.implicits._ + import spark.implicits._ val dataSourceName: String @@ -89,8 +91,8 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes yield Row(s"val_$i", s"val_$i", s"val_$i", s"val_$i", 1, 1, 1, 1)) // Self-join - df.registerTempTable("t") - withTempTable("t") { + df.createOrReplaceTempView("t") + withTempView("t") { checkAnswer( sql( """SELECT l.a, r.b, l.p1, r.p2 @@ -141,8 +143,8 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes .add("index", IntegerType, nullable = false) .add("col", dataType, nullable = true) val rdd = - sqlContext.sparkContext.parallelize((1 to 10).map(i => Row(i, dataGenerator()))) - val df = sqlContext.createDataFrame(rdd, schema).orderBy("index").coalesce(1) + spark.sparkContext.parallelize((1 to 10).map(i => Row(i, dataGenerator()))) 
+ val df = spark.createDataFrame(rdd, schema).orderBy("index").coalesce(1) df.write .mode("overwrite") @@ -151,7 +153,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes .options(extraOptions) .save(path) - val loadedDF = sqlContext + val loadedDF = spark .read .format(dataSourceName) .option("dataSchema", df.schema.json) @@ -172,7 +174,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes testDF.write.mode(SaveMode.Overwrite).format(dataSourceName).save(file.getCanonicalPath) checkAnswer( - sqlContext.read.format(dataSourceName) + spark.read.format(dataSourceName) .option("path", file.getCanonicalPath) .option("dataSchema", dataSchema.json) .load(), @@ -186,7 +188,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes testDF.write.mode(SaveMode.Append).format(dataSourceName).save(file.getCanonicalPath) checkAnswer( - sqlContext.read.format(dataSourceName) + spark.read.format(dataSourceName) .option("dataSchema", dataSchema.json) .load(file.getCanonicalPath).orderBy("a"), testDF.union(testDF).orderBy("a").collect()) @@ -206,7 +208,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes testDF.write.mode(SaveMode.Ignore).format(dataSourceName).save(file.getCanonicalPath) val path = new Path(file.getCanonicalPath) - val fs = path.getFileSystem(sqlContext.sessionState.newHadoopConf()) + val fs = path.getFileSystem(spark.sessionState.newHadoopConf()) assert(fs.listStatus(path).isEmpty) } } @@ -220,7 +222,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes .save(file.getCanonicalPath) checkQueries( - sqlContext.read.format(dataSourceName) + spark.read.format(dataSourceName) .option("dataSchema", dataSchema.json) .load(file.getCanonicalPath)) } @@ -241,7 +243,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes .save(file.getCanonicalPath) checkAnswer( - sqlContext.read.format(dataSourceName) + spark.read.format(dataSourceName) .option("dataSchema", dataSchema.json) .load(file.getCanonicalPath), partitionedTestDF.collect()) @@ -263,7 +265,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes .save(file.getCanonicalPath) checkAnswer( - sqlContext.read.format(dataSourceName) + spark.read.format(dataSourceName) .option("dataSchema", dataSchema.json) .load(file.getCanonicalPath), partitionedTestDF.union(partitionedTestDF).collect()) @@ -285,7 +287,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes .save(file.getCanonicalPath) checkAnswer( - sqlContext.read.format(dataSourceName) + spark.read.format(dataSourceName) .option("dataSchema", dataSchema.json) .load(file.getCanonicalPath), partitionedTestDF.collect()) @@ -321,7 +323,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes .saveAsTable("t") withTable("t") { - checkAnswer(sqlContext.table("t"), testDF.collect()) + checkAnswer(spark.table("t"), testDF.collect()) } } @@ -330,14 +332,13 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes testDF.write.format(dataSourceName).mode(SaveMode.Append).saveAsTable("t") withTable("t") { - checkAnswer(sqlContext.table("t"), testDF.union(testDF).orderBy("a").collect()) + checkAnswer(spark.table("t"), testDF.union(testDF).orderBy("a").collect()) } } test("saveAsTable()/load() - non-partitioned table - ErrorIfExists") { - Seq.empty[(Int, String)].toDF().registerTempTable("t") - - 
withTempTable("t") { + withTable("t") { + sql("CREATE TABLE t(i INT) USING parquet") intercept[AnalysisException] { testDF.write.format(dataSourceName).mode(SaveMode.ErrorIfExists).saveAsTable("t") } @@ -345,11 +346,10 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes } test("saveAsTable()/load() - non-partitioned table - Ignore") { - Seq.empty[(Int, String)].toDF().registerTempTable("t") - - withTempTable("t") { + withTable("t") { + sql("CREATE TABLE t(i INT) USING parquet") testDF.write.format(dataSourceName).mode(SaveMode.Ignore).saveAsTable("t") - assert(sqlContext.table("t").collect().isEmpty) + assert(spark.table("t").collect().isEmpty) } } @@ -360,18 +360,18 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes .saveAsTable("t") withTable("t") { - checkQueries(sqlContext.table("t")) + checkQueries(spark.table("t")) } } test("saveAsTable()/load() - partitioned table - boolean type") { - sqlContext.range(2) + spark.range(2) .select('id, ('id % 2 === 0).as("b")) .write.partitionBy("b").saveAsTable("t") withTable("t") { checkAnswer( - sqlContext.table("t").sort('id), + spark.table("t").sort('id), Row(0, true) :: Row(1, false) :: Nil ) } @@ -393,7 +393,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes .saveAsTable("t") withTable("t") { - checkAnswer(sqlContext.table("t"), partitionedTestDF.collect()) + checkAnswer(spark.table("t"), partitionedTestDF.collect()) } } @@ -413,7 +413,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes .saveAsTable("t") withTable("t") { - checkAnswer(sqlContext.table("t"), partitionedTestDF.union(partitionedTestDF).collect()) + checkAnswer(spark.table("t"), partitionedTestDF.union(partitionedTestDF).collect()) } } @@ -433,7 +433,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes .saveAsTable("t") withTable("t") { - checkAnswer(sqlContext.table("t"), partitionedTestDF.collect()) + checkAnswer(spark.table("t"), partitionedTestDF.collect()) } } @@ -457,9 +457,9 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes } test("saveAsTable()/load() - partitioned table - ErrorIfExists") { - Seq.empty[(Int, String)].toDF().registerTempTable("t") + Seq.empty[(Int, String)].toDF().createOrReplaceTempView("t") - withTempTable("t") { + withTempView("t") { intercept[AnalysisException] { partitionedTestDF.write .format(dataSourceName) @@ -472,9 +472,9 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes } test("saveAsTable()/load() - partitioned table - Ignore") { - Seq.empty[(Int, String)].toDF().registerTempTable("t") + Seq.empty[(Int, String)].toDF().createOrReplaceTempView("t") - withTempTable("t") { + withTempView("t") { partitionedTestDF.write .format(dataSourceName) .mode(SaveMode.Ignore) @@ -482,44 +482,226 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes .partitionBy("p1", "p2") .saveAsTable("t") - assert(sqlContext.table("t").collect().isEmpty) + assert(spark.table("t").collect().isEmpty) } } - test("Hadoop style globbing") { + test("load() - with directory of unpartitioned data in nested subdirs") { + withTempPath { dir => + val subdir = new File(dir, "subdir") + + val dataInDir = Seq(1, 2, 3).toDF("value") + val dataInSubdir = Seq(4, 5, 6).toDF("value") + + /* + + Directory structure to be generated + + dir + | + |___ [ files of dataInDir ] + | + |___ subsubdir + | + |___ [ files of dataInSubdir ] + */ + 
+ // Generated dataInSubdir, not data in dir + dataInSubdir.write + .format(dataSourceName) + .mode(SaveMode.Overwrite) + .save(subdir.getCanonicalPath) + + // Inferring schema should throw error as it should not find any file to infer + val e = intercept[Exception] { + spark.read.format(dataSourceName).load(dir.getCanonicalPath) + } + + e match { + case _: AnalysisException => + assert(e.getMessage.contains("infer")) + + case _: java.util.NoSuchElementException if e.getMessage.contains("dataSchema") => + // Ignore error, the source format requires schema to be provided by user + // This is needed for SimpleTextHadoopFsRelationSuite as SimpleTextSource needs schema + + case _ => + fail("Unexpected error trying to infer schema from empty dir", e) + } + + /** Test whether data is read with the given path matches the expected answer */ + def testWithPath(path: File, expectedAnswer: Seq[Row]): Unit = { + val df = spark.read + .format(dataSourceName) + .schema(dataInDir.schema) // avoid schema inference for any format + .load(path.getCanonicalPath) + checkAnswer(df, expectedAnswer) + } + + // Verify that reading by path 'dir/' gives empty results as there are no files in 'file' + // and it should not pick up files in 'dir/subdir' + require(subdir.exists) + require(subdir.listFiles().exists(!_.isDirectory)) + testWithPath(dir, Seq.empty) + + // Verify that if there is data in dir, then reading by path 'dir/' reads only dataInDir + dataInDir.write + .format(dataSourceName) + .mode(SaveMode.Append) // append to prevent subdir from being deleted + .save(dir.getCanonicalPath) + require(dir.listFiles().exists(!_.isDirectory)) + require(subdir.exists()) + require(subdir.listFiles().exists(!_.isDirectory)) + testWithPath(dir, dataInDir.collect()) + } + } + + test("Hadoop style globbing - unpartitioned data") { withTempPath { file => + + val dir = file.getCanonicalPath + val subdir = new File(dir, "subdir") + val subsubdir = new File(subdir, "subsubdir") + val anotherSubsubdir = + new File(new File(dir, "another-subdir"), "another-subsubdir") + + val dataInSubdir = Seq(1, 2, 3).toDF("value") + val dataInSubsubdir = Seq(4, 5, 6).toDF("value") + val dataInAnotherSubsubdir = Seq(7, 8, 9).toDF("value") + + dataInSubdir.write + .format (dataSourceName) + .mode (SaveMode.Overwrite) + .save (subdir.getCanonicalPath) + + dataInSubsubdir.write + .format (dataSourceName) + .mode (SaveMode.Overwrite) + .save (subsubdir.getCanonicalPath) + + dataInAnotherSubsubdir.write + .format (dataSourceName) + .mode (SaveMode.Overwrite) + .save (anotherSubsubdir.getCanonicalPath) + + require(subdir.exists) + require(subdir.listFiles().exists(!_.isDirectory)) + require(subsubdir.exists) + require(subsubdir.listFiles().exists(!_.isDirectory)) + require(anotherSubsubdir.exists) + require(anotherSubsubdir.listFiles().exists(!_.isDirectory)) + + /* + Directory structure generated + + dir + | + |___ subdir + | | + | |___ [ files of dataInSubdir ] + | | + | |___ subsubdir + | | + | |___ [ files of dataInSubsubdir ] + | + | + |___ anotherSubdir + | + |___ anotherSubsubdir + | + |___ [ files of dataInAnotherSubsubdir ] + */ + + val schema = dataInSubdir.schema + + /** Check whether data is read with the given path matches the expected answer */ + def check(path: String, expectedDf: DataFrame): Unit = { + val df = spark.read + .format(dataSourceName) + .schema(schema) // avoid schema inference for any format, expected to be same format + .load(path) + checkAnswer(df, expectedDf) + } + + check(s"$dir/*/", dataInSubdir) + 
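The nested-subdirectory test above relies on load() reading only the files directly under the given directory, without descending into subdirectories; reaching nested data takes an explicit subpath or a glob, which the globbing tests that follow exercise. A small spark-shell style sketch of that behavior, with Parquet standing in for `dataSourceName` and hypothetical /tmp paths:

```scala
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import spark.implicits._   // `spark` as provided by spark-shell

val schema = StructType(StructField("value", IntegerType) :: Nil)

Seq(1, 2, 3).toDF("value").write.mode("overwrite").parquet("/tmp/dir")          // files directly in /tmp/dir
Seq(4, 5, 6).toDF("value").write.mode("overwrite").parquet("/tmp/dir/subdir")   // files one level down

spark.read.schema(schema).parquet("/tmp/dir").show()          // 1, 2, 3 only; subdir is not visited
spark.read.schema(schema).parquet("/tmp/dir/subdir").show()   // 4, 5, 6 via the explicit subpath
```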
check(s"$dir/sub*/*", dataInSubdir.union(dataInSubsubdir)) + check(s"$dir/another*/*", dataInAnotherSubsubdir) + check(s"$dir/*/another*", dataInAnotherSubsubdir) + check(s"$dir/*/*", dataInSubdir.union(dataInSubsubdir).union(dataInAnotherSubsubdir)) + } + } + + test("Hadoop style globbing - partitioned data with schema inference") { + + // Tests the following on partition data + // - partitions are not discovered with globbing and without base path set. + // - partitions are discovered with globbing and base path set, though more detailed + // tests for this is in ParquetPartitionDiscoverySuite + + withTempPath { path => + val dir = path.getCanonicalPath partitionedTestDF.write .format(dataSourceName) .mode(SaveMode.Overwrite) .partitionBy("p1", "p2") - .save(file.getCanonicalPath) + .save(dir) + + def check( + path: String, + expectedResult: Either[DataFrame, String], + basePath: Option[String] = None + ): Unit = { + try { + val reader = spark.read + basePath.foreach(reader.option("basePath", _)) + val testDf = reader + .format(dataSourceName) + .load(path) + assert(expectedResult.isLeft, s"Error was expected with $path but result found") + checkAnswer(testDf, expectedResult.left.get) + } catch { + case e: java.util.NoSuchElementException if e.getMessage.contains("dataSchema") => + // Ignore error, the source format requires schema to be provided by user + // This is needed for SimpleTextHadoopFsRelationSuite as SimpleTextSource needs schema + + case e: Throwable => + assert(expectedResult.isRight, s"Was not expecting error with $path: " + e) + assert( + e.getMessage.contains(expectedResult.right.get), + s"Did not find expected error message wiht $path") + } + } - val df = sqlContext.read - .format(dataSourceName) - .option("dataSchema", dataSchema.json) - .option("basePath", file.getCanonicalPath) - .load(s"${file.getCanonicalPath}/p1=*/p2=???") - - val expectedPaths = Set( - s"${file.getCanonicalFile}/p1=1/p2=foo", - s"${file.getCanonicalFile}/p1=2/p2=foo", - s"${file.getCanonicalFile}/p1=1/p2=bar", - s"${file.getCanonicalFile}/p1=2/p2=bar" - ).map { p => - val path = new Path(p) - val fs = path.getFileSystem(sqlContext.sessionState.newHadoopConf()) - path.makeQualified(fs.getUri, fs.getWorkingDirectory).toString + object Error { + def apply(msg: String): Either[DataFrame, String] = Right(msg) } - val actualPaths = df.queryExecution.analyzed.collectFirst { - case LogicalRelation(relation: HadoopFsRelation, _, _) => - relation.location.paths.map(_.toString).toSet - }.getOrElse { - fail("Expect an FSBasedRelation, but none could be found") + object Result { + def apply(df: DataFrame): Either[DataFrame, String] = Left(df) } - assert(actualPaths === expectedPaths) - checkAnswer(df, partitionedTestDF.collect()) + // ---- Without base path set ---- + // Should find all the data with partitioning columns + check(s"$dir", Result(partitionedTestDF)) + + // Should fail as globbing finds dirs without files, only subdirs in them. 
+ check(s"$dir/*/", Error("please set \"basePath\"")) + check(s"$dir/p1=*/", Error("please set \"basePath\"")) + + // Should not find partition columns as the globs resolve to p2 dirs + // with files in them + check(s"$dir/*/*", Result(partitionedTestDF.drop("p1", "p2"))) + check(s"$dir/p1=*/p2=foo", Result(partitionedTestDF.filter("p2 = 'foo'").drop("p1", "p2"))) + check(s"$dir/p1=1/p2=???", Result(partitionedTestDF.filter("p1 = 1").drop("p1", "p2"))) + + // Should find all data without the partitioning columns as the globs resolve to the files + check(s"$dir/*/*/*", Result(partitionedTestDF.drop("p1", "p2"))) + + // ---- With base path set ---- + val resultDf = partitionedTestDF.select("a", "b", "p1", "p2") + check(path = s"$dir/*", Result(resultDf), basePath = Some(dir)) + check(path = s"$dir/*/*", Result(resultDf), basePath = Some(dir)) + check(path = s"$dir/*/*/*", Result(resultDf), basePath = Some(dir)) } } @@ -538,7 +720,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes 'p3.cast(FloatType).as('pf1), 'f) - withTempTable("t") { + withTempView("t") { input .write .format(dataSourceName) @@ -555,7 +737,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes val realData = input.collect() - checkAnswer(sqlContext.table("t"), realData ++ realData) + checkAnswer(spark.table("t"), realData ++ realData) } } } @@ -570,7 +752,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes .saveAsTable("t") withTable("t") { - checkAnswer(sqlContext.table("t").select('b, 'c, 'a), df.select('b, 'c, 'a).collect()) + checkAnswer(spark.table("t").select('b, 'c, 'a), df.select('b, 'c, 'a).collect()) } } @@ -582,7 +764,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes test("SPARK-8406: Avoids name collision while writing files") { withTempPath { dir => val path = dir.getCanonicalPath - sqlContext + spark .range(10000) .repartition(250) .write @@ -591,7 +773,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes .save(path) assertResult(10000) { - sqlContext + spark .read .format(dataSourceName) .option("dataSchema", StructType(StructField("id", LongType) :: Nil).json) @@ -610,7 +792,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes classOf[AlwaysFailParquetOutputCommitter].getName ) - val df = sqlContext.range(1, 10).toDF("i") + val df = spark.range(1, 10).toDF("i") withTempPath { dir => df.write.mode("append").format(dataSourceName).save(dir.getCanonicalPath) // Because there data already exists, @@ -618,7 +800,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes // with file format and AlwaysFailOutputCommitter will not be used. 
df.write.mode("append").format(dataSourceName).save(dir.getCanonicalPath) checkAnswer( - sqlContext.read + spark.read .format(dataSourceName) .option("dataSchema", df.schema.json) .options(extraOptions) @@ -666,12 +848,12 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes ) withTempPath { dir => val path = "file://" + dir.getCanonicalPath - val df1 = sqlContext.range(4) + val df1 = spark.range(4) df1.coalesce(1).write.mode("overwrite").options(options).format(dataSourceName).save(path) df1.coalesce(1).write.mode("append").options(options).format(dataSourceName).save(path) def checkLocality(): Unit = { - val df2 = sqlContext.read + val df2 = spark.read .format(dataSourceName) .option("dataSchema", df1.schema.json) .options(options) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala index ef37787137d07..d79edee5b1a4c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/JsonHadoopFsRelationSuite.scala @@ -53,7 +53,7 @@ class JsonHadoopFsRelationSuite extends HadoopFsRelationTest { StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) checkQueries( - hiveContext.read.format(dataSourceName) + spark.read.format(dataSourceName) .option("dataSchema", dataSchemaWithPartition.json) .load(file.getCanonicalPath)) } @@ -71,14 +71,14 @@ class JsonHadoopFsRelationSuite extends HadoopFsRelationTest { val data = Row(Seq(1L, 2L, 3L), Map("m1" -> Row(4L))) :: Row(Seq(5L, 6L, 7L), Map("m2" -> Row(10L))) :: Nil - val df = hiveContext.createDataFrame(sparkContext.parallelize(data), schema) + val df = spark.createDataFrame(sparkContext.parallelize(data), schema) // Write the data out. df.write.format(dataSourceName).save(file.getCanonicalPath) // Read it back and check the result. checkAnswer( - hiveContext.read.format(dataSourceName).schema(schema).load(file.getCanonicalPath), + spark.read.format(dataSourceName).schema(schema).load(file.getCanonicalPath), df ) } @@ -96,14 +96,14 @@ class JsonHadoopFsRelationSuite extends HadoopFsRelationTest { Row(new BigDecimal("10.02")) :: Row(new BigDecimal("20000.99")) :: Row(new BigDecimal("10000")) :: Nil - val df = hiveContext.createDataFrame(sparkContext.parallelize(data), schema) + val df = spark.createDataFrame(sparkContext.parallelize(data), schema) // Write the data out. df.write.format(dataSourceName).save(file.getCanonicalPath) // Read it back and check the result. 
checkAnswer( - hiveContext.read.format(dataSourceName).schema(schema).load(file.getCanonicalPath), + spark.read.format(dataSourceName).schema(schema).load(file.getCanonicalPath), df ) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala index 1d104889feffb..8aa018d0a9ee5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala @@ -21,6 +21,7 @@ import java.io.File import com.google.common.io.Files import org.apache.hadoop.fs.Path +import org.apache.parquet.hadoop.ParquetOutputFormat import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql._ @@ -58,7 +59,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) checkQueries( - hiveContext.read.format(dataSourceName) + spark.read.format(dataSourceName) .option("dataSchema", dataSchemaWithPartition.json) .load(file.getCanonicalPath)) } @@ -76,7 +77,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { .format("parquet") .save(s"${dir.getCanonicalPath}/_temporary") - checkAnswer(hiveContext.read.format("parquet").load(dir.getCanonicalPath), df.collect()) + checkAnswer(spark.read.format("parquet").load(dir.getCanonicalPath), df.collect()) } } @@ -104,7 +105,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { // This shouldn't throw anything. df.write.format("parquet").mode(SaveMode.Overwrite).save(path) - checkAnswer(hiveContext.read.format("parquet").load(path), df) + checkAnswer(spark.read.format("parquet").load(path), df) } } @@ -114,7 +115,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { // Parquet doesn't allow field names with spaces. Here we are intentionally making an // exception thrown from the `ParquetRelation2.prepareForWriteJob()` method to trigger // the bug. Please refer to spark-8079 for more details. 
- hiveContext.range(1, 10) + spark.range(1, 10) .withColumnRenamed("id", "a b") .write .format("parquet") @@ -124,23 +125,25 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { } test("SPARK-8604: Parquet data source should write summary file while doing appending") { - withTempPath { dir => - val path = dir.getCanonicalPath - val df = sqlContext.range(0, 5).toDF() - df.write.mode(SaveMode.Overwrite).parquet(path) + withSQLConf(ParquetOutputFormat.ENABLE_JOB_SUMMARY -> "true") { + withTempPath { dir => + val path = dir.getCanonicalPath + val df = spark.range(0, 5).toDF() + df.write.mode(SaveMode.Overwrite).parquet(path) - val summaryPath = new Path(path, "_metadata") - val commonSummaryPath = new Path(path, "_common_metadata") + val summaryPath = new Path(path, "_metadata") + val commonSummaryPath = new Path(path, "_common_metadata") - val fs = summaryPath.getFileSystem(sqlContext.sessionState.newHadoopConf()) - fs.delete(summaryPath, true) - fs.delete(commonSummaryPath, true) + val fs = summaryPath.getFileSystem(spark.sessionState.newHadoopConf()) + fs.delete(summaryPath, true) + fs.delete(commonSummaryPath, true) - df.write.mode(SaveMode.Append).parquet(path) - checkAnswer(sqlContext.read.parquet(path), df.union(df)) + df.write.mode(SaveMode.Append).parquet(path) + checkAnswer(spark.read.parquet(path), df.union(df)) - assert(fs.exists(summaryPath)) - assert(fs.exists(commonSummaryPath)) + assert(fs.exists(summaryPath)) + assert(fs.exists(commonSummaryPath)) + } } } @@ -148,8 +151,8 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { withTempPath { dir => val path = dir.getCanonicalPath - sqlContext.range(2).select('id as 'a, 'id as 'b).write.partitionBy("b").parquet(path) - val df = sqlContext.read.parquet(path).filter('a === 0).select('b) + spark.range(2).select('id as 'a, 'id as 'b).write.partitionBy("b").parquet(path) + val df = spark.read.parquet(path).filter('a === 0).select('b) val physicalPlan = df.queryExecution.sparkPlan assert(physicalPlan.collect { case p: execution.ProjectExec => p }.length === 1) @@ -170,7 +173,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { // The schema consists of the leading columns of the first part-file // in the lexicographic order. 
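The SPARK-8604 test in this hunk now wraps the whole scenario in withSQLConf(ParquetOutputFormat.ENABLE_JOB_SUMMARY -> "true"), because the _metadata/_common_metadata assertion only holds when Parquet job summaries are switched on rather than left at the default. A hedged shell-level sketch; the Hadoop key is what ENABLE_JOB_SUMMARY resolves to in the Parquet version bundled with this code, and the path is illustrative:

```scala
// Session-level SQL conf entries are copied into the Hadoop Configuration used
// by file-based sources, so this setting reaches ParquetOutputFormat.
spark.conf.set("parquet.enable.summary-metadata", "true")

spark.range(5).write.mode("overwrite").parquet("/tmp/with-summary")
spark.range(5).write.mode("append").parquet("/tmp/with-summary")
// Expect /tmp/with-summary/_metadata and _common_metadata to exist afterwards.
```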
- assert(sqlContext.read.parquet(dir.getCanonicalPath).schema.map(_.name) + assert(spark.read.parquet(dir.getCanonicalPath).schema.map(_.name) === Seq("a", "b", "c", "d", "part")) } } @@ -188,8 +191,8 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { Row(5, 127.toByte), Row(6, -44.toByte), Row(7, 23.toByte), Row(8, -95.toByte), Row(9, 127.toByte), Row(10, 13.toByte)) - val rdd = sqlContext.sparkContext.parallelize(data) - val df = sqlContext.createDataFrame(rdd, schema).orderBy("index").coalesce(1) + val rdd = spark.sparkContext.parallelize(data) + val df = spark.createDataFrame(rdd, schema).orderBy("index").coalesce(1) df.write .mode("overwrite") @@ -197,7 +200,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { .option("dataSchema", df.schema.json) .save(path) - val loadedDF = sqlContext + val loadedDF = spark .read .format(dataSourceName) .option("dataSchema", df.schema.json) @@ -221,7 +224,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { val compressedFiles = new File(path).listFiles() assert(compressedFiles.exists(_.getName.endsWith(".gz.parquet"))) - val copyDf = sqlContext + val copyDf = spark .read .parquet(path) checkAnswer(df, copyDf) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala index 9ad0887609ed9..a47a2246ddc3c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala @@ -60,7 +60,7 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest with Predicat StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) checkQueries( - hiveContext.read.format(dataSourceName) + spark.read.format(dataSourceName) .option("dataSchema", dataSchemaWithPartition.json) .load(file.getCanonicalPath)) } @@ -69,7 +69,7 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest with Predicat test("test hadoop conf option propagation") { withTempPath { file => // Test write side - val df = sqlContext.range(10).selectExpr("cast(id as string)") + val df = spark.range(10).selectExpr("cast(id as string)") df.write .option("some-random-write-option", "hahah-WRITE") .option("some-null-value-option", null) // test null robustness @@ -78,7 +78,7 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest with Predicat assert(SimpleTextRelation.lastHadoopConf.get.get("some-random-write-option") == "hahah-WRITE") // Test read side - val df1 = sqlContext.read + val df1 = spark.read .option("some-random-read-option", "hahah-READ") .option("some-null-value-option", null) // test null robustness .option("dataSchema", df.schema.json) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala index 0fa18414154ce..67a58a3859b84 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala @@ -29,11 +29,12 @@ import org.apache.spark.sql.{sources, Row, SparkSession} import org.apache.spark.sql.catalyst.{expressions, InternalRow} import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, GenericInternalRow, InterpretedPredicate, InterpretedProjection, JoinedRow, Literal} import 
org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.execution.command.CreateDataSourceTableUtils._ import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.SerializableConfiguration -class SimpleTextSource extends FileFormat with DataSourceRegister { +class SimpleTextSource extends TextBasedFileFormat with DataSourceRegister { override def shortName(): String = "test" override def inferSchema( @@ -144,7 +145,7 @@ class AppendingTextOutputFormat(outputFile: Path) extends TextOutputFormat[NullW override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { val configuration = context.getConfiguration - val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID") + val uniqueWriteJobId = configuration.get(DATASOURCE_WRITEJOBUUID) val taskAttemptId = context.getTaskAttemptID val split = taskAttemptId.getTaskID.getId val name = FileOutputFormat.getOutputName(context) diff --git a/sql/hivecontext-compatibility/pom.xml b/sql/hivecontext-compatibility/pom.xml deleted file mode 100644 index ed9ef8e27919e..0000000000000 --- a/sql/hivecontext-compatibility/pom.xml +++ /dev/null @@ -1,57 +0,0 @@ - - - - - 4.0.0 - - org.apache.spark - spark-parent_2.11 - 2.0.0-SNAPSHOT - ../../pom.xml - - - org.apache.spark - spark-hivecontext-compatibility_2.11 - jar - Spark Project HiveContext Compatibility - http://spark.apache.org/ - - hivecontext-compatibility - - - - - org.apache.spark - spark-hive_${scala.binary.version} - ${project.version} - - - org.apache.spark - spark-core_${scala.binary.version} - ${project.version} - test-jar - test - - - - - target/scala-${scala.binary.version}/classes - target/scala-${scala.binary.version}/test-classes - - diff --git a/streaming/pom.xml b/streaming/pom.xml index 7d409c5d3b076..ab1f3ca190f7d 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.0.0-SNAPSHOT + 2.0.3-SNAPSHOT ../pom.xml @@ -49,7 +49,7 @@ org.apache.spark - spark-test-tags_${scala.binary.version} + spark-tags_${scala.binary.version} @@ -77,6 +77,10 @@ org.eclipse.jetty jetty-servlet + + org.eclipse.jetty + jetty-servlets + @@ -93,6 +97,11 @@ selenium-java test + + org.seleniumhq.selenium + selenium-htmlunit-driver + test + org.mockito mockito-core diff --git a/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLogRecordHandle.java b/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLogRecordHandle.java index 662889e779fb2..3c5cc7e2cae1b 100644 --- a/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLogRecordHandle.java +++ b/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLogRecordHandle.java @@ -23,7 +23,7 @@ * This abstract class represents a handle that refers to a record written in a * {@link org.apache.spark.streaming.util.WriteAheadLog WriteAheadLog}. * It must contain all the information necessary for the record to be read and returned by - * an implemenation of the WriteAheadLog class. + * an implementation of the WriteAheadLog class. 
* * @see org.apache.spark.streaming.util.WriteAheadLog */ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index 7d8b8679c5944..0b11026863199 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -84,7 +84,7 @@ class Checkpoint(ssc: StreamingContext, val checkpointTime: Time) assert(framework != null, "Checkpoint.framework is null") assert(graph != null, "Checkpoint.graph is null") assert(checkpointTime != null, "Checkpoint.checkpointTime is null") - logInfo("Checkpoint for time " + checkpointTime + " validated") + logInfo(s"Checkpoint for time $checkpointTime validated") } } @@ -103,7 +103,10 @@ object Checkpoint extends Logging { new Path(checkpointDir, PREFIX + checkpointTime.milliseconds + ".bk") } - /** Get checkpoint files present in the give directory, ordered by oldest-first */ + /** + * @param checkpointDir checkpoint directory to read checkpoint files from + * @return checkpoint files from the `checkpointDir` checkpoint directory, ordered by oldest-first + */ def getCheckpointFiles(checkpointDir: String, fsOption: Option[FileSystem] = None): Seq[Path] = { def sortFunc(path1: Path, path2: Path): Boolean = { @@ -121,11 +124,11 @@ object Checkpoint extends Logging { val filtered = paths.filter(p => REGEX.findFirstIn(p.toString).nonEmpty) filtered.sortWith(sortFunc) } else { - logWarning("Listing " + path + " returned null") + logWarning(s"Listing $path returned null") Seq.empty } } else { - logInfo("Checkpoint directory " + path + " does not exist") + logWarning(s"Checkpoint directory $path does not exist") Seq.empty } } @@ -151,7 +154,7 @@ object Checkpoint extends Logging { Utils.tryWithSafeFinally { // ObjectInputStream uses the last defined user-defined class loader in the stack - // to find classes, which maybe the wrong class loader. Hence, a inherited version + // to find classes, which maybe the wrong class loader. Hence, an inherited version // of ObjectInputStream is used to explicitly use the current thread's default class // loader to find and load classes. This is a well know Java issue and has popped up // in other places (e.g., http://jira.codehaus.org/browse/GROOVY-1627) @@ -205,7 +208,7 @@ class CheckpointWriter( // time of a batch is greater than the batch interval, checkpointing for completing an old // batch may run after checkpointing of a new batch. If this happens, checkpoint of an old // batch actually has the latest information, so we want to recovery from it. Therefore, we - // also use the latest checkpoint time as the file name, so that we can recovery from the + // also use the latest checkpoint time as the file name, so that we can recover from the // latest checkpoint file. 
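The comment on Checkpoint deserialization above describes reading serialized objects with an ObjectInputStream that resolves classes through the current thread's context class loader, because the JDK default (the last user-defined loader on the call stack) can be the wrong one. A generic sketch of that idiom, not Spark's exact class:

```scala
import java.io.{InputStream, ObjectInputStream, ObjectStreamClass}

// Resolve classes through the thread's context class loader instead of the
// last user-defined loader on the call stack, falling back to the default
// behaviour for anything it cannot find (e.g. primitive type descriptors).
class ContextClassLoaderObjectInputStream(in: InputStream) extends ObjectInputStream(in) {
  override protected def resolveClass(desc: ObjectStreamClass): Class[_] =
    try Class.forName(desc.getName, false, Thread.currentThread().getContextClassLoader)
    catch { case _: ClassNotFoundException => super.resolveClass(desc) }
}
```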
// // Note: there is only one thread writing the checkpoint files, so we don't need to worry @@ -216,8 +219,7 @@ class CheckpointWriter( while (attempts < MAX_ATTEMPTS && !stopped) { attempts += 1 try { - logInfo("Saving checkpoint for time " + checkpointTime + " to file '" + checkpointFile - + "'") + logInfo(s"Saving checkpoint for time $checkpointTime to file '$checkpointFile'") // Write checkpoint to temp file if (fs.exists(tempFile)) { @@ -237,39 +239,38 @@ class CheckpointWriter( fs.delete(backupFile, true) // just in case it exists } if (!fs.rename(checkpointFile, backupFile)) { - logWarning("Could not rename " + checkpointFile + " to " + backupFile) + logWarning(s"Could not rename $checkpointFile to $backupFile") } } // Rename temp file to the final checkpoint file if (!fs.rename(tempFile, checkpointFile)) { - logWarning("Could not rename " + tempFile + " to " + checkpointFile) + logWarning(s"Could not rename $tempFile to $checkpointFile") } // Delete old checkpoint files val allCheckpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, Some(fs)) if (allCheckpointFiles.size > 10) { allCheckpointFiles.take(allCheckpointFiles.size - 10).foreach { file => - logInfo("Deleting " + file) + logInfo(s"Deleting $file") fs.delete(file, true) } } // All done, print success val finishTime = System.currentTimeMillis() - logInfo("Checkpoint for time " + checkpointTime + " saved to file '" + checkpointFile + - "', took " + bytes.length + " bytes and " + (finishTime - startTime) + " ms") + logInfo(s"Checkpoint for time $checkpointTime saved to file '$checkpointFile'" + + s", took ${bytes.length} bytes and ${finishTime - startTime} ms") jobGenerator.onCheckpointCompletion(checkpointTime, clearCheckpointDataLater) return } catch { case ioe: IOException => - logWarning("Error in attempt " + attempts + " of writing checkpoint to " - + checkpointFile, ioe) + val msg = s"Error in attempt $attempts of writing checkpoint to '$checkpointFile'" + logWarning(msg, ioe) fs = null } } - logWarning("Could not write checkpoint for time " + checkpointTime + " to file " - + checkpointFile + "'") + logWarning(s"Could not write checkpoint for time $checkpointTime to file '$checkpointFile'") } } @@ -278,7 +279,7 @@ class CheckpointWriter( val bytes = Checkpoint.serialize(checkpoint, conf) executor.execute(new CheckpointWriteHandler( checkpoint.checkpointTime, bytes, clearCheckpointDataLater)) - logInfo("Submitted checkpoint of time " + checkpoint.checkpointTime + " writer queue") + logInfo(s"Submitted checkpoint of time ${checkpoint.checkpointTime} to writer queue") } catch { case rej: RejectedExecutionException => logError("Could not submit checkpoint task to the thread pool executor", rej) @@ -295,8 +296,8 @@ class CheckpointWriter( executor.shutdownNow() } val endTime = System.currentTimeMillis() - logInfo("CheckpointWriter executor terminated ? " + terminated + - ", waited for " + (endTime - startTime) + " ms.") + logInfo(s"CheckpointWriter executor terminated? 
$terminated," + + s" waited for ${endTime - startTime} ms.") stopped = true } } @@ -336,20 +337,20 @@ object CheckpointReader extends Logging { } // Try to read the checkpoint files in the order - logInfo("Checkpoint files found: " + checkpointFiles.mkString(",")) + logInfo(s"Checkpoint files found: ${checkpointFiles.mkString(",")}") var readError: Exception = null checkpointFiles.foreach { file => - logInfo("Attempting to load checkpoint from file " + file) + logInfo(s"Attempting to load checkpoint from file $file") try { val fis = fs.open(file) val cp = Checkpoint.deserialize(fis, conf) - logInfo("Checkpoint successfully loaded from file " + file) - logInfo("Checkpoint was generated at time " + cp.checkpointTime) + logInfo(s"Checkpoint successfully loaded from file $file") + logInfo(s"Checkpoint was generated at time ${cp.checkpointTime}") return Some(cp) } catch { case e: Exception => readError = e - logWarning("Error reading checkpoint from file " + file, e) + logWarning(s"Error reading checkpoint from file $file", e) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/State.scala b/streaming/src/main/scala/org/apache/spark/streaming/State.scala index 42424d67d8838..3f560f889f055 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/State.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/State.scala @@ -120,7 +120,7 @@ sealed abstract class State[S] { def isTimingOut(): Boolean /** - * Get the state as an [[scala.Option]]. It will be `Some(state)` if it exists, otherwise `None`. + * Get the state as a [[scala.Option]]. It will be `Some(state)` if it exists, otherwise `None`. */ @inline final def getOption(): Option[S] = if (exists) Some(get()) else None diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 928739a416f0f..6046426fdf8cb 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -26,7 +26,7 @@ import scala.collection.mutable.Queue import scala.reflect.ClassTag import scala.util.control.NonFatal -import org.apache.commons.lang.SerializationUtils +import org.apache.commons.lang3.SerializationUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.io.{BytesWritable, LongWritable, Text} @@ -322,7 +322,7 @@ class StreamingContext private[streaming] ( } /** - * Create a input stream from network source hostname:port, where data is received + * Create an input stream from network source hostname:port, where data is received * as serialized blocks (serialized using the Spark's serializer) that can be directly * pushed into the block manager without deserializing them. This is the most efficient * way to receive data. @@ -341,7 +341,7 @@ class StreamingContext private[streaming] ( } /** - * Create a input stream that monitors a Hadoop-compatible filesystem + * Create an input stream that monitors a Hadoop-compatible filesystem * for new files and reads them using the given key-value types and input format. * Files must be written to the monitored directory by "moving" them from another * location within the same file system. File names starting with . are ignored. 
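The CheckpointReader loop earlier in this hunk tries each checkpoint file in order and returns the first one that deserializes, remembering failures along the way. A generic sketch of that first-success pattern; the helper and its arguments are illustrative, not Spark API:

```scala
// Try each candidate in order, return the first successful read, keep the last error.
def firstSuccessful[A, B](candidates: Seq[A])(read: A => B): Option[B] = {
  var lastError: Option[Throwable] = None
  for (c <- candidates) {
    try {
      return Some(read(c))
    } catch {
      case e: Exception => lastError = Some(e)
    }
  }
  lastError.foreach(e => println(s"all ${candidates.size} candidates failed: ${e.getMessage}"))
  None
}

firstSuccessful(Seq("not-a-number", "42"))(_.toInt)   // Some(42)
```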
@@ -359,7 +359,7 @@ class StreamingContext private[streaming] ( } /** - * Create a input stream that monitors a Hadoop-compatible filesystem + * Create an input stream that monitors a Hadoop-compatible filesystem * for new files and reads them using the given key-value types and input format. * Files must be written to the monitored directory by "moving" them from another * location within the same file system. @@ -379,7 +379,7 @@ class StreamingContext private[streaming] ( } /** - * Create a input stream that monitors a Hadoop-compatible filesystem + * Create an input stream that monitors a Hadoop-compatible filesystem * for new files and reads them using the given key-value types and input format. * Files must be written to the monitored directory by "moving" them from another * location within the same file system. File names starting with . are ignored. @@ -403,7 +403,7 @@ class StreamingContext private[streaming] ( } /** - * Create a input stream that monitors a Hadoop-compatible filesystem + * Create an input stream that monitors a Hadoop-compatible filesystem * for new files and reads them as text files (using key as LongWritable, value * as Text and input format as TextInputFormat). Files must be written to the * monitored directory by "moving" them from another location within the same @@ -579,8 +579,7 @@ class StreamingContext private[streaming] ( sparkContext.setCallSite(startSite.get) sparkContext.clearJobGroup() sparkContext.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "false") - savedProperties.set(SerializationUtils.clone( - sparkContext.localProperties.get()).asInstanceOf[Properties]) + savedProperties.set(SerializationUtils.clone(sparkContext.localProperties.get())) scheduler.start() } state = StreamingContextState.ACTIVE diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingSource.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingSource.scala index 9697437dd2fe5..0b306a28d1a59 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingSource.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingSource.scala @@ -87,11 +87,11 @@ private[streaming] class StreamingSource(ssc: StreamingContext) extends Source { // Gauge for last received batch, useful for monitoring the streaming job's running status, // displayed data -1 for any abnormal condition. registerGaugeWithOption("lastReceivedBatch_submissionTime", - _.lastCompletedBatch.map(_.submissionTime), -1L) + _.lastReceivedBatch.map(_.submissionTime), -1L) registerGaugeWithOption("lastReceivedBatch_processingStartTime", - _.lastCompletedBatch.flatMap(_.processingStartTime), -1L) + _.lastReceivedBatch.flatMap(_.processingStartTime), -1L) registerGaugeWithOption("lastReceivedBatch_processingEndTime", - _.lastCompletedBatch.flatMap(_.processingEndTime), -1L) + _.lastReceivedBatch.flatMap(_.processingEndTime), -1L) // Gauge for last received batch records. 
registerGauge("lastReceivedBatch_records", _.lastReceivedBatchRecords.values.sum, 0L) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala index 43632f37ccb16..a0a40fcee26d9 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala @@ -240,7 +240,8 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * This is more efficient than reduceByWindow without "inverse reduce" function. * However, it is applicable to only "invertible reduce functions". * @param reduceFunc associative and commutative reduce function - * @param invReduceFunc inverse reduce function + * @param invReduceFunc inverse reduce function; such that for all y, invertible x: + * `invReduceFunc(reduceFunc(x, y), x) = y` * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval * @param slideDuration sliding interval of the window (i.e., the interval after which diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala index 2a80cf4466588..dec983165fb3b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala @@ -336,7 +336,8 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * However, it is applicable to only "invertible reduce functions". * Hash partitioning is used to generate the RDDs with Spark's default number of partitions. * @param reduceFunc associative and commutative reduce function - * @param invReduceFunc inverse function + * @param invReduceFunc inverse function; such that for all y, invertible x: + * `invReduceFunc(reduceFunc(x, y), x) = y` * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval * @param slideDuration sliding interval of the window (i.e., the interval after which diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index 7e78fa1d7e159..4c4376a089f59 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -349,7 +349,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { } /** - * Create an input stream from an queue of RDDs. In each batch, + * Create an input stream from a queue of RDDs. In each batch, * it will process either one or all of the RDDs returned by the queue. * * NOTE: @@ -369,7 +369,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { } /** - * Create an input stream from an queue of RDDs. In each batch, + * Create an input stream from a queue of RDDs. In each batch, * it will process either one or all of the RDDs returned by the queue. * * NOTE: @@ -393,7 +393,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { } /** - * Create an input stream from an queue of RDDs. In each batch, + * Create an input stream from a queue of RDDs. 
In each batch, * it will process either one or all of the RDDs returned by the queue. * * NOTE: diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala index aeff4d7a98e7a..46bfc60856453 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -24,11 +24,14 @@ import java.util.{ArrayList => JArrayList, List => JList} import scala.collection.JavaConverters._ import scala.language.existentials +import py4j.Py4JException + import org.apache.spark.SparkException import org.apache.spark.api.java._ +import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.{Duration, Interval, Time} +import org.apache.spark.streaming.{Duration, Interval, StreamingContext, Time} import org.apache.spark.streaming.api.java._ import org.apache.spark.streaming.dstream._ import org.apache.spark.util.Utils @@ -157,7 +160,7 @@ private[python] object PythonTransformFunctionSerializer { /** * Helper functions, which are called from Python via Py4J. */ -private[python] object PythonDStream { +private[streaming] object PythonDStream { /** * can not access PythonTransformFunctionSerializer.register() via Py4j @@ -184,6 +187,32 @@ private[python] object PythonDStream { rdds.asScala.foreach(queue.add) queue } + + /** + * Stop [[StreamingContext]] if the Python process crashes (E.g., OOM) in case the user cannot + * stop it in the Python side. + */ + def stopStreamingContextIfPythonProcessIsDead(e: Throwable): Unit = { + // These two special messages are from: + // scalastyle:off + // https://github.com/bartdag/py4j/blob/5cbb15a21f857e8cf334ce5f675f5543472f72eb/py4j-java/src/main/java/py4j/CallbackClient.java#L218 + // https://github.com/bartdag/py4j/blob/5cbb15a21f857e8cf334ce5f675f5543472f72eb/py4j-java/src/main/java/py4j/CallbackClient.java#L340 + // scalastyle:on + if (e.isInstanceOf[Py4JException] && + ("Cannot obtain a new communication channel" == e.getMessage || + "Error while obtaining a new communication channel" == e.getMessage)) { + // Start a new thread to stop StreamingContext to avoid deadlock. + new Thread("Stop-StreamingContext") with Logging { + setDaemon(true) + + override def run(): Unit = { + logError( + "Cannot connect to Python process. It's probably dead. Stopping StreamingContext.", e) + StreamingContext.getActive().foreach(_.stop(stopSparkContext = false)) + } + }.start() + } + } } /** diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index 583f5a48d1a6c..fa15a0bf65ab9 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -52,7 +52,7 @@ import org.apache.spark.util.{CallSite, Utils} * `join`. These operations are automatically available on any DStream of pairs * (e.g., DStream[(Int, Int)] through implicit conversions. 
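The queue-based input streams documented earlier in this hunk are easiest to see with concrete usage. A minimal local sketch, where the master, batch interval and data are all illustrative:

```scala
import scala.collection.mutable
import org.apache.spark.SparkConf
import org.apache.spark.rdd.RDD
import org.apache.spark.streaming.{Seconds, StreamingContext}

val conf = new SparkConf().setMaster("local[2]").setAppName("queue-stream-sketch")
val ssc = new StreamingContext(conf, Seconds(1))

val queue = mutable.Queue[RDD[Int]]()
val stream = ssc.queueStream(queue)        // by default, one queued RDD per batch
stream.map(_ * 2).print()

queue += ssc.sparkContext.makeRDD(1 to 5)
ssc.start()
ssc.awaitTerminationOrTimeout(3000)
ssc.stop(stopSparkContext = true)
```

By default one queued RDD is consumed per batch; the overloads that take oneAtATime = false drain everything in the queue each interval, which is the "either one or all" wording in the javadoc above.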
* - * DStreams internally is characterized by a few basic properties: + * A DStream internally is characterized by a few basic properties: * - A list of other DStreams that the DStream depends on * - A time interval at which the DStream generates an RDD * - A function that is used to generate an RDD after each time interval @@ -157,7 +157,7 @@ abstract class DStream[T: ClassTag] ( def persist(level: StorageLevel): DStream[T] = { if (this.isInitialized) { throw new UnsupportedOperationException( - "Cannot change storage level of an DStream after streaming context has started") + "Cannot change storage level of a DStream after streaming context has started") } this.storageLevel = level this @@ -176,7 +176,7 @@ abstract class DStream[T: ClassTag] ( def checkpoint(interval: Duration): DStream[T] = { if (isInitialized) { throw new UnsupportedOperationException( - "Cannot change checkpoint interval of an DStream after streaming context has started") + "Cannot change checkpoint interval of a DStream after streaming context has started") } persist() checkpointDuration = interval @@ -793,7 +793,8 @@ abstract class DStream[T: ClassTag] ( * This is more efficient than reduceByWindow without "inverse reduce" function. * However, it is applicable to only "invertible reduce functions". * @param reduceFunc associative and commutative reduce function - * @param invReduceFunc inverse reduce function + * @param invReduceFunc inverse reduce function; such that for all y, invertible x: + * `invReduceFunc(reduceFunc(x, y), x) = y` * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval * @param slideDuration sliding interval of the window (i.e., the interval after which diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala index 36f50e04db422..ed9305875cb77 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala @@ -195,10 +195,16 @@ class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( ) logDebug(s"Getting new files for time $currentTime, " + s"ignoring files older than $modTimeIgnoreThreshold") - val filter = new PathFilter { + + val newFileFilter = new PathFilter { def accept(path: Path): Boolean = isNewFile(path, currentTime, modTimeIgnoreThreshold) } - val newFiles = fs.listStatus(directoryPath, filter).map(_.getPath.toString) + val directoryFilter = new PathFilter { + override def accept(path: Path): Boolean = fs.getFileStatus(path).isDirectory + } + val directories = fs.globStatus(directoryPath, directoryFilter).map(_.getPath) + val newFiles = directories.flatMap(dir => + fs.listStatus(dir, newFileFilter).map(_.getPath.toString)) val timeTaken = clock.getTimeMillis() - lastNewFileFindingTime logInfo("Finding new files took " + timeTaken + " ms") logDebug("# cached file times = " + fileToModTime.size) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala index b6394e36b5152..2f2a6d13dd79b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala @@ -290,7 +290,8 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) * However, it is 
applicable to only "invertible reduce functions". * Hash partitioning is used to generate the RDDs with Spark's default number of partitions. * @param reduceFunc associative and commutative reduce function - * @param invReduceFunc inverse reduce function + * @param invReduceFunc inverse reduce function; such that for all y, invertible x: + * `invReduceFunc(reduceFunc(x, y), x) = y` * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval * @param slideDuration sliding interval of the window (i.e., the interval after which diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala index 47eb9b806fa7d..0dde120927576 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala @@ -29,7 +29,7 @@ class TransformedDStream[U: ClassTag] ( transformFunc: (Seq[RDD[_]], Time) => RDD[U] ) extends DStream[U](parents.head.ssc) { - require(parents.length > 0, "List of DStreams to transform is empty") + require(parents.nonEmpty, "List of DStreams to transform is empty") require(parents.map(_.ssc).distinct.size == 1, "Some of the DStreams have different contexts") require(parents.map(_.slideDuration).distinct.size == 1, "Some of the DStreams have different slide durations") diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala index 53fccd8d5e6ed..0b2ec298132ad 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala @@ -120,7 +120,7 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag]( val blockId = partition.blockId def getBlockFromBlockManager(): Option[Iterator[T]] = { - blockManager.get(blockId).map(_.data.asInstanceOf[Iterator[T]]) + blockManager.get[T](blockId).map(_.data.asInstanceOf[Iterator[T]]) } def getBlockFromWriteAheadLog(): Iterator[T] = { @@ -163,7 +163,8 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag]( dataRead.rewind() } serializerManager - .dataDeserializeStream(blockId, new ChunkedByteBuffer(dataRead).toInputStream()) + .dataDeserializeStream( + blockId, new ChunkedByteBuffer(dataRead).toInputStream())(elementClassTag) .asInstanceOf[Iterator[T]] } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala index 4592e015ed9a0..90309c0145ae1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala @@ -86,13 +86,13 @@ private[streaming] class BlockGenerator( /** * The BlockGenerator can be in 5 possible states, in the order as follows. * - * - Initialized: Nothing has been started + * - Initialized: Nothing has been started. * - Active: start() has been called, and it is generating blocks on added data. * - StoppedAddingData: stop() has been called, the adding of data has been stopped, * but blocks are still being generated and pushed. * - StoppedGeneratingBlocks: Generating of blocks has been stopped, but * they are still being pushed. 
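The invReduceFunc contract spelled out in the windowed-reduce javadocs above, `invReduceFunc(reduceFunc(x, y), x) = y`, is what lets Spark update a window incrementally instead of recomputing it from scratch. A hedged sketch using addition and subtraction as the invertible pair, assuming a fresh StreamingContext `ssc` (built as in the queue sketch) and a hypothetical local socket source:

```scala
import org.apache.spark.streaming.Seconds

ssc.checkpoint("/tmp/streaming-checkpoints")   // required by the invertible variant

val words = ssc.socketTextStream("localhost", 9999).flatMap(_.split(" "))
val counts = words.map(w => (w, 1L)).reduceByKeyAndWindow(
  (a: Long, b: Long) => a + b,   // reduceFunc: fold in values entering the window
  (a: Long, b: Long) => a - b,   // invReduceFunc: remove values leaving it; (x + y) - x == y
  Seconds(30), Seconds(10))
counts.print()
```

Each slide then only folds in the new batches and subtracts the ones that fell out of the window, rather than re-reducing all thirty seconds of data.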
- * - StoppedAll: Everything has stopped, and the BlockGenerator object can be GCed. + * - StoppedAll: Everything has been stopped, and the BlockGenerator object can be GCed. */ private object GeneratorState extends Enumeration { type GeneratorState = Value @@ -148,7 +148,7 @@ private[streaming] class BlockGenerator( blockIntervalTimer.stop(interruptTimer = false) synchronized { state = StoppedGeneratingBlocks } - // Wait for the queue to drain and mark generated as stopped + // Wait for the queue to drain and mark state as StoppedAll logInfo("Waiting for block pushing thread to terminate") blockPushingThread.join() synchronized { state = StoppedAll } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlock.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlock.scala index 47968afef2dbf..8c3a7977beae3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlock.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlock.scala @@ -31,5 +31,5 @@ private[streaming] case class ArrayBufferBlock(arrayBuffer: ArrayBuffer[_]) exte /** class representing a block received as an Iterator */ private[streaming] case class IteratorBlock(iterator: Iterator[_]) extends ReceivedBlock -/** class representing a block received as an ByteBuffer */ +/** class representing a block received as a ByteBuffer */ private[streaming] case class ByteBufferBlock(byteBuffer: ByteBuffer) extends ReceivedBlock diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala index c4bc5cf3f6a58..80c07958b41f2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala @@ -170,7 +170,7 @@ private[streaming] class WriteAheadLogBasedBlockHandler( */ def storeBlock(blockId: StreamBlockId, block: ReceivedBlock): ReceivedBlockStoreResult = { - var numRecords = None: Option[Long] + var numRecords = Option.empty[Long] // Serialize the block so that it can be inserted into both val serializedBlock = block match { case ArrayBufferBlock(arrayBuffer) => diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala index 5157ca62dc449..d91a64df321a6 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala @@ -32,7 +32,7 @@ import org.apache.spark.storage.StorageLevel * should define the setup steps necessary to start receiving data, * and `onStop()` should define the cleanup steps necessary to stop receiving data. * Exceptions while receiving can be handled either by restarting the receiver with `restart(...)` - * or stopped completely by `stop(...)` or + * or stopped completely by `stop(...)`. * * A custom receiver in Scala would look like this. * @@ -45,7 +45,7 @@ import org.apache.spark.storage.StorageLevel * // Call store(...) in those threads to store received data into Spark's memory. * * // Call stop(...), restart(...) or reportError(...) on any thread based on how - * // different errors needs to be handled. + * // different errors need to be handled. 
* * // See corresponding method documentation for more details * } @@ -71,7 +71,7 @@ import org.apache.spark.storage.StorageLevel * // Call store(...) in those threads to store received data into Spark's memory. * * // Call stop(...), restart(...) or reportError(...) on any thread based on how - * // different errors needs to be handled. + * // different errors need to be handled. * * // See corresponding method documentation for more details * } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala index 42fc84c19b971..faf6db82d5b18 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala @@ -79,7 +79,7 @@ private[streaming] abstract class ReceiverSupervisor( optionalBlockId: Option[StreamBlockId] ): Unit - /** Store a iterator of received data as a data block into Spark's memory. */ + /** Store an iterator of received data as a data block into Spark's memory. */ def pushIterator( iterator: Iterator[_], optionalMetadata: Option[Any], diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala index 4fb0f8caacbb6..5ba09a54af18d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala @@ -129,7 +129,7 @@ private[streaming] class ReceiverSupervisorImpl( pushAndReportBlock(ArrayBufferBlock(arrayBuffer), metadataOption, blockIdOption) } - /** Store a iterator of received data as a data block into Spark's memory. */ + /** Store an iterator of received data as a data block into Spark's memory. */ def pushIterator( iterator: Iterator[_], metadataOption: Option[Any], diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManager.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManager.scala index f7b6584893c6e..fb5587edeccee 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManager.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManager.scala @@ -38,7 +38,7 @@ import org.apache.spark.util.{Clock, Utils} * - Periodically take the average batch completion times and compare with the batch interval * - If (avg. proc. time / batch interval) >= scaling up ratio, then request more executors. * The number of executors requested is based on the ratio = (avg. proc. time / batch interval). - * - If (avg. proc. time / batch interval) <= scaling down ratio, then try to kill a executor that + * - If (avg. proc. time / batch interval) <= scaling down ratio, then try to kill an executor that * is not running a receiver. 
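The receiver documentation touched above keeps restating the same contract: onStart() must return quickly, the actual receiving happens on threads the receiver starts itself, and store(...), stop(...), restart(...) or reportError(...) may be called from any of them. A toy receiver illustrating that shape, with an illustrative class name and data:

```scala
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.receiver.Receiver

class ToyReceiver extends Receiver[String](StorageLevel.MEMORY_ONLY) {
  override def onStart(): Unit = {
    // Return immediately; receive on a separate thread, as the contract requires.
    new Thread("toy-receiver") {
      override def run(): Unit = {
        try {
          (1 to 3).foreach(i => store(s"record-$i"))   // push data into Spark's memory
          stop("toy source exhausted")                  // or restart(...) / reportError(...)
        } catch {
          case e: Exception => restart("error receiving data", e)
        }
      }
    }.start()
  }

  override def onStop(): Unit = { /* nothing to clean up in this sketch */ }
}
```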
* * This features should ideally be used in conjunction with backpressure, as backpressure ensures diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala index 4f124a1356b5a..8e1a090618433 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala @@ -67,7 +67,7 @@ private[streaming] class InputInfoTracker(ssc: StreamingContext) extends Logging if (inputInfos.contains(inputInfo.inputStreamId)) { throw new IllegalStateException(s"Input stream ${inputInfo.inputStreamId} for batch" + - s"$batchTime is already added into InputInfoTracker, this is a illegal state") + s"$batchTime is already added into InputInfoTracker, this is an illegal state") } inputInfos += ((inputInfo.inputStreamId, inputInfo)) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index 8f9421fc098ba..4489a5334d17e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -19,10 +19,10 @@ package org.apache.spark.streaming.scheduler import scala.util.{Failure, Success, Try} -import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.streaming.{Checkpoint, CheckpointWriter, Time} +import org.apache.spark.streaming.api.python.PythonDStream import org.apache.spark.streaming.util.RecurringTimer import org.apache.spark.util.{Clock, EventLoop, ManualClock, Utils} @@ -239,7 +239,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { logInfo("Restarted JobGenerator at " + restartTime) } - /** Generate jobs and perform checkpoint for the given `time`. */ + /** Generate jobs and perform checkpointing for the given `time`. */ private def generateJobs(time: Time) { // Checkpoint all RDDs marked for checkpointing to ensure their lineages are // truncated periodically. Otherwise, we may run into stack overflows (SPARK-6847). 
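The generateJobs comment above checkpoints RDDs marked for it so that lineage stays short across batches (SPARK-6847). A standalone sketch of why truncation matters, reusing an existing SparkContext and a hypothetical checkpoint directory:

```scala
val sc = spark.sparkContext            // any existing SparkContext
sc.setCheckpointDir("/tmp/rdd-checkpoints")

var rdd = sc.parallelize(1 to 10)
(1 to 200).foreach(_ => rdd = rdd.map(_ + 1))   // builds a 200-deep lineage
rdd.checkpoint()                                 // mark it; materialized by the next action
rdd.count()

// The dependency chain now ends at the checkpoint files instead of 200 map stages,
// which is what keeps long-running streaming jobs from overflowing the stack.
println(rdd.toDebugString)
```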
@@ -253,6 +253,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { jobScheduler.submitJobSet(JobSet(time, jobs, streamIdToInputInfos)) case Failure(e) => jobScheduler.reportError("Error generating jobs for time " + time, e) + PythonDStream.stopStreamingContextIfPythonProcessIsDead(e) } eventLoop.post(DoCheckpoint(time, clearCheckpointDataLater = false)) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index ac18f73ea86aa..f5ba5edad9e25 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -17,17 +17,17 @@ package org.apache.spark.streaming.scheduler -import java.util.Properties import java.util.concurrent.{ConcurrentHashMap, TimeUnit} import scala.collection.JavaConverters._ import scala.util.Failure -import org.apache.commons.lang.SerializationUtils +import org.apache.commons.lang3.SerializationUtils import org.apache.spark.internal.Logging import org.apache.spark.rdd.{PairRDDFunctions, RDD} import org.apache.spark.streaming._ +import org.apache.spark.streaming.api.python.PythonDStream import org.apache.spark.streaming.ui.UIUtils import org.apache.spark.util.{EventLoop, ThreadUtils} @@ -211,6 +211,7 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { private def handleError(msg: String, e: Throwable) { logError(msg, e) ssc.waiter.notifyError(e) + PythonDStream.stopStreamingContextIfPythonProcessIsDead(e) } private class JobHandler(job: Job) extends Runnable with Logging { @@ -219,8 +220,7 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { def run() { val oldProps = ssc.sparkContext.getLocalProperties try { - ssc.sparkContext.setLocalProperties( - SerializationUtils.clone(ssc.savedProperties.get()).asInstanceOf[Properties]) + ssc.sparkContext.setLocalProperties(SerializationUtils.clone(ssc.savedProperties.get())) val formattedTime = UIUtils.formatBatchTime( job.time.milliseconds, ssc.graph.batchDuration.milliseconds, showYYYYMMSS = false) val batchUrl = s"/streaming/batch/?id=${job.time.milliseconds}" diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala index 391a461f08125..4105171a3db24 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala @@ -31,7 +31,7 @@ import org.apache.spark.streaming.receiver.Receiver * all receivers at the same time. ReceiverTracker will call `scheduleReceivers` at this phase. * It will try to schedule receivers such that they are evenly distributed. ReceiverTracker * should update its `receiverTrackingInfoMap` according to the results of `scheduleReceivers`. - * `ReceiverTrackingInfo.scheduledLocations` for each receiver should be set to an location list + * `ReceiverTrackingInfo.scheduledLocations` for each receiver should be set to a location list * that contains the scheduled locations. Then when a receiver is starting, it will send a * register request and `ReceiverTracker.registerReceiver` will be called. 
In * `ReceiverTracker.registerReceiver`, if a receiver's scheduled locations is set, it should diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index 9aa2f0bbb9952..b9d898a72362e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -20,7 +20,7 @@ package org.apache.spark.streaming.scheduler import java.util.concurrent.{CountDownLatch, TimeUnit} import scala.collection.mutable.HashMap -import scala.concurrent.{ExecutionContext, Future} +import scala.concurrent.ExecutionContext import scala.language.existentials import scala.util.{Failure, Success} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala index 1ef26d2f865da..60122b48135f7 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala @@ -86,7 +86,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { /** * Generate a row for a Spark Job. Because duplicated output op infos needs to be collapsed into - * one cell, we use "rowspan" for the first row of a output op. + * one cell, we use "rowspan" for the first row of an output op. */ private def generateNormalJobRow( outputOpData: OutputOperationUIData, diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala index c086df47d9835..61f852a0d31a7 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala @@ -259,7 +259,7 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) // We use an Iterable rather than explicitly converting to a seq so that updates // will propagate val outputOpIdToSparkJobIds: Iterable[OutputOpIdAndSparkJobId] = - Option(batchTimeToOutputOpIdSparkJobIdPair.get(batchTime).asScala) + Option(batchTimeToOutputOpIdSparkJobIdPair.get(batchTime)).map(_.asScala) .getOrElse(Seq.empty) _batchUIData.outputOpIdSparkJobIdPairs = outputOpIdToSparkJobIds } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala index b97e24f28bfc6..46cd3092e9061 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala @@ -396,11 +396,11 @@ private[ui] class StreamingPage(parent: StreamingTab) .map(_.ceil.toLong) .getOrElse(0L) - val content = listener.receivedRecordRateWithBatchTime.toList.sortBy(_._1).map { + val content: Seq[Node] = listener.receivedRecordRateWithBatchTime.toList.sortBy(_._1).flatMap { case (streamId, recordRates) => generateInputDStreamRow( jsCollector, streamId, recordRates, minX, maxX, minY, maxYCalculated) - }.foldLeft[Seq[Node]](Nil)(_ ++ _) + } // scalastyle:off diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java index 01f0c4de9e3c9..3d54abd903b6d 100644 --- 
a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java @@ -36,7 +36,6 @@ import com.google.common.io.Files; import com.google.common.collect.Sets; -import org.apache.spark.Accumulator; import org.apache.spark.HashPartitioner; import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaPairRDD; @@ -46,6 +45,7 @@ import org.apache.spark.api.java.function.*; import org.apache.spark.storage.StorageLevel; import org.apache.spark.streaming.api.java.*; +import org.apache.spark.util.LongAccumulator; import org.apache.spark.util.Utils; // The test suite itself is Serializable so that anonymous Function implementations can be @@ -794,8 +794,8 @@ public Iterator call(String x) { @SuppressWarnings("unchecked") @Test public void testForeachRDD() { - final Accumulator accumRdd = ssc.sparkContext().accumulator(0); - final Accumulator accumEle = ssc.sparkContext().accumulator(0); + final LongAccumulator accumRdd = ssc.sparkContext().sc().longAccumulator(); + final LongAccumulator accumEle = ssc.sparkContext().sc().longAccumulator(); List> inputData = Arrays.asList( Arrays.asList(1,1,1), Arrays.asList(1,1,1)); diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index bdbac64b9bc79..bd8f9950bf1c7 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -71,7 +71,7 @@ trait DStreamCheckpointTester { self: SparkFunSuite => /** * Tests a streaming operation under checkpointing, by restarting the operation * from checkpoint file and verifying whether the final output is correct. - * The output is assumed to have come from a reliable queue which an replay + * The output is assumed to have come from a reliable queue which a replay * data as required. 
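The `DStreamCheckpointTester` scaladoc above describes restarting a streaming operation from its checkpoint files and verifying the output. For orientation, here is a minimal sketch of the driver-side recovery pattern such tests exercise; the checkpoint path and the socket source are placeholders, and this is not the test suite's own code:

```scala
import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}

object CheckpointRecoverySketch {
  // Hypothetical checkpoint directory; any reliable filesystem path would do.
  val checkpointDir = "/tmp/streaming-checkpoint-sketch"

  def createContext(): StreamingContext = {
    val conf = new SparkConf().setAppName("checkpoint-sketch").setMaster("local[2]")
    val ssc = new StreamingContext(conf, Seconds(2))
    ssc.checkpoint(checkpointDir)
    // Placeholder input; a real job would wire up its DStream transformations here.
    ssc.socketTextStream("localhost", 9999).print()
    ssc
  }

  def main(args: Array[String]): Unit = {
    // First run: builds a fresh context. After a failure: rebuilds the context and its
    // DStream graph from checkpointDir instead of calling createContext() again.
    val ssc = StreamingContext.getOrCreate(checkpointDir, () => createContext())
    ssc.start()
    ssc.awaitTermination()
  }
}
```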
* * NOTE: This takes into consideration that the last batch processed before diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index a2653000af557..00d506c2f18bb 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -140,10 +140,10 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { } test("binary records stream") { - val testDir: File = null + var testDir: File = null try { val batchDuration = Seconds(2) - val testDir = Utils.createTempDir() + testDir = Utils.createTempDir() // Create a file that exists before the StreamingContext is created: val existingFile = new File(testDir, "0") Files.write("0\n", existingFile, StandardCharsets.UTF_8) @@ -198,6 +198,68 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { testFileStream(newFilesOnly = false) } + test("file input stream - wildcard") { + var testDir: File = null + try { + val batchDuration = Seconds(2) + testDir = Utils.createTempDir() + val testSubDir1 = Utils.createDirectory(testDir.toString, "tmp1") + val testSubDir2 = Utils.createDirectory(testDir.toString, "tmp2") + + // Create a file that exists before the StreamingContext is created: + val existingFile = new File(testDir, "0") + Files.write("0\n", existingFile, StandardCharsets.UTF_8) + assert(existingFile.setLastModified(10000) && existingFile.lastModified === 10000) + + val pathWithWildCard = testDir.toString + "/*/" + + // Set up the streaming context and input streams + withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc => + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + clock.setTime(existingFile.lastModified + batchDuration.milliseconds) + val batchCounter = new BatchCounter(ssc) + // monitor "testDir/*/" + val fileStream = ssc.fileStream[LongWritable, Text, TextInputFormat]( + pathWithWildCard).map(_._2.toString) + val outputQueue = new ConcurrentLinkedQueue[Seq[String]] + val outputStream = new TestOutputStream(fileStream, outputQueue) + outputStream.register() + ssc.start() + + // Advance the clock so that the files are created after StreamingContext starts, but + // not enough to trigger a batch + clock.advance(batchDuration.milliseconds / 2) + + def createFileAndAdvanceTime(data: Int, dir: File): Unit = { + val file = new File(dir, data.toString) + Files.write(data + "\n", file, StandardCharsets.UTF_8) + assert(file.setLastModified(clock.getTimeMillis())) + assert(file.lastModified === clock.getTimeMillis()) + logInfo("Created file " + file) + // Advance the clock after creating the file to avoid a race when + // setting its modification time + clock.advance(batchDuration.milliseconds) + eventually(eventuallyTimeout) { + assert(batchCounter.getNumCompletedBatches === data) + } + } + // Over time, create files in the temp directory 1 + val input1 = Seq(1, 2, 3, 4, 5) + input1.foreach(i => createFileAndAdvanceTime(i, testSubDir1)) + + // Over time, create files in the temp directory 2 + val input2 = Seq(6, 7, 8, 9, 10) + input2.foreach(i => createFileAndAdvanceTime(i, testSubDir2)) + + // Verify that all the files have been read + val expectedOutput = (input1 ++ input2).map(_.toString).toSet + assert(outputQueue.asScala.flatten.toSet === expectedOutput) + } + } finally { + if (testDir != null) Utils.deleteRecursively(testDir) + } + } + test("multi-thread 
receiver") { // set up the test receiver val numThreads = 10 @@ -363,10 +425,10 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { } def testFileStream(newFilesOnly: Boolean) { - val testDir: File = null + var testDir: File = null try { val batchDuration = Seconds(2) - val testDir = Utils.createTempDir() + testDir = Utils.createTempDir() // Create a file that exists before the StreamingContext is created: val existingFile = new File(testDir, "0") Files.write("0\n", existingFile, StandardCharsets.UTF_8) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index 4be4882938df5..7e665454a5400 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -23,12 +23,14 @@ import java.nio.ByteBuffer import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ import scala.language.postfixOps +import scala.reflect.ClassTag import org.apache.hadoop.conf.Configuration import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ import org.apache.spark._ +import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.internal.Logging import org.apache.spark.memory.StaticMemoryManager import org.apache.spark.network.netty.NettyBlockTransferService @@ -46,6 +48,7 @@ class ReceivedBlockHandlerSuite extends SparkFunSuite with BeforeAndAfter with Matchers + with LocalSparkContext with Logging { import WriteAheadLogBasedBlockHandler._ @@ -57,7 +60,8 @@ class ReceivedBlockHandlerSuite val hadoopConf = new Configuration() val streamId = 1 val securityMgr = new SecurityManager(conf) - val mapOutputTracker = new MapOutputTrackerMaster(conf) + val broadcastManager = new BroadcastManager(true, conf, securityMgr) + val mapOutputTracker = new MapOutputTrackerMaster(conf, broadcastManager, true) val shuffleManager = new SortShuffleManager(conf) val serializer = new KryoSerializer(conf) var serializerManager = new SerializerManager(serializer, conf) @@ -75,8 +79,10 @@ class ReceivedBlockHandlerSuite rpcEnv = RpcEnv.create("test", "localhost", 0, conf, securityMgr) conf.set("spark.driver.port", rpcEnv.address.port.toString) + sc = new SparkContext("local", "test", conf) blockManagerMaster = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager", - new BlockManagerMasterEndpoint(rpcEnv, true, conf, new LiveListenerBus)), conf, true) + new BlockManagerMasterEndpoint(rpcEnv, true, conf, + new LiveListenerBus(sc))), conf, true) storageLevel = StorageLevel.MEMORY_ONLY_SER blockManager = createBlockManager(blockManagerSize, conf) @@ -158,7 +164,7 @@ class ReceivedBlockHandlerSuite val bytes = reader.read(fileSegment) reader.close() serializerManager.dataDeserializeStream( - generateBlockId(), new ChunkedByteBuffer(bytes).toInputStream()).toList + generateBlockId(), new ChunkedByteBuffer(bytes).toInputStream())(ClassTag.Any).toList } loggedData shouldEqual data } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala index 917232c9cdd63..1b1e21f6e5bab 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala @@ -215,7 +215,7 @@ class ReceiverSuite extends TestSuiteBase with 
Timeouts with Serializable { def getCurrentLogFiles(logDirectory: File): Seq[String] = { try { if (logDirectory.exists()) { - logDirectory1.listFiles().filter { _.getName.startsWith("log") }.map { _.toString } + logDirectory.listFiles().filter { _.getName.startsWith("log") }.map { _.toString } } else { Seq.empty } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 806e181f61980..f1482e5c06cdc 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -819,10 +819,13 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo ssc.checkpoint(checkpointDirectory) ssc.textFileStream(testDirectory).foreachRDD { rdd => rdd.count() } ssc.start() - eventually(timeout(10000 millis)) { - assert(Checkpoint.getCheckpointFiles(checkpointDirectory).size > 1) + try { + eventually(timeout(30000 millis)) { + assert(Checkpoint.getCheckpointFiles(checkpointDirectory).size > 1) + } + } finally { + ssc.stop() } - ssc.stop() checkpointDirectory } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimatorSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimatorSuite.scala index a1af95be81c8e..1a0460cd669af 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimatorSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimatorSuite.scala @@ -119,7 +119,7 @@ class PIDRateEstimatorSuite extends SparkFunSuite with Matchers { test("with no accumulated but some positive error, |I| > 0, follow the processing speed") { val p = new PIDRateEstimator(20, 1D, 1D, 0D, 10) - // prepare a series of batch updates, one every 20ms with an decreasing number of processed + // prepare a series of batch updates, one every 20ms with a decreasing number of processed // elements in each batch, but constant processing time, and no accumulated error. 
Even though // the integral part is non-zero, the estimated rate should follow only the proportional term, // asking for less and less elements diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala index 26b757cc2d535..46ab3ac8de3d4 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala @@ -68,6 +68,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { listener.waitingBatches should be (List(BatchUIData(batchInfoSubmitted))) listener.runningBatches should be (Nil) listener.retainedCompletedBatches should be (Nil) + listener.lastReceivedBatch should be (Some(BatchUIData(batchInfoSubmitted))) listener.lastCompletedBatch should be (None) listener.numUnprocessedBatches should be (1) listener.numTotalCompletedBatches should be (0) @@ -81,6 +82,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { listener.waitingBatches should be (Nil) listener.runningBatches should be (List(BatchUIData(batchInfoStarted))) listener.retainedCompletedBatches should be (Nil) + listener.lastReceivedBatch should be (Some(BatchUIData(batchInfoStarted))) listener.lastCompletedBatch should be (None) listener.numUnprocessedBatches should be (1) listener.numTotalCompletedBatches should be (0) @@ -123,6 +125,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { listener.waitingBatches should be (Nil) listener.runningBatches should be (Nil) listener.retainedCompletedBatches should be (List(BatchUIData(batchInfoCompleted))) + listener.lastReceivedBatch should be (Some(BatchUIData(batchInfoCompleted))) listener.lastCompletedBatch should be (Some(BatchUIData(batchInfoCompleted))) listener.numUnprocessedBatches should be (0) listener.numTotalCompletedBatches should be (1) diff --git a/tools/pom.xml b/tools/pom.xml index 9bb20e1381067..4c2dab471407e 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.0.0-SNAPSHOT + 2.0.3-SNAPSHOT ../pom.xml diff --git a/yarn/pom.xml b/yarn/pom.xml index 328bb6678db99..2dc61d002f5ac 100644 --- a/yarn/pom.xml +++ b/yarn/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.0.0-SNAPSHOT + 2.0.3-SNAPSHOT ../pom.xml @@ -30,6 +30,7 @@ Spark Project YARN yarn + 1.9 @@ -53,7 +54,7 @@ org.apache.spark - spark-test-tags_${scala.binary.version} + spark-tags_${scala.binary.version} org.apache.hadoop @@ -101,6 +102,10 @@ org.eclipse.jetty jetty-servlet + + org.eclipse.jetty + jetty-servlets + com.sun.jersey jersey-core test + ${jersey-1.version} com.sun.jersey jersey-json test + ${jersey-1.version} com.sun.jersey jersey-server test + ${jersey-1.version} + + + com.sun.jersey.contribs + jersey-guice + test + ${jersey-1.version}